add PIM accelerator
This commit is contained in:
21
src/PIM/Conversion/SpatialToPIM/CMakeLists.txt
Normal file
21
src/PIM/Conversion/SpatialToPIM/CMakeLists.txt
Normal file
@@ -0,0 +1,21 @@
|
||||
# TableGen the Spatial->PIM rewrite patterns (SpatialToPIM.td -> SpatialToPIM.hpp.inc).
set(LLVM_TARGET_DEFINITIONS SpatialToPIM.td)
mlir_tablegen(SpatialToPIM.hpp.inc -gen-rewriters "-I${ONNX_MLIR_SRC_ROOT}")
add_public_tablegen_target(SpatialToPIMIncGen)

# Library implementing the Spatial -> PIM lowering pass.
add_onnx_mlir_library(OMSpatialToPIM
  SpatialToPIMPass.hpp
  SpatialToPIMPass.cpp
  SpatialToPIMCommon.cpp

  # Ensure the tablegen'd pattern header exists before compiling this library.
  DEPENDS
  SpatialToPIMIncGen

  LINK_LIBS PUBLIC
  OMCompilerOptions
  OMPIMCommon
  SpatialOps
  PimOps

  ACCEL_INCLUDE_DIRS PRIVATE
  ${PIM_INCLUDE_PATH}
  )
|
||||
28
src/PIM/Conversion/SpatialToPIM/SpatialToPIM.td
Normal file
28
src/PIM/Conversion/SpatialToPIM/SpatialToPIM.td
Normal file
@@ -0,0 +1,28 @@
|
||||
#ifndef SPATIAL_TO_PIM
#define SPATIAL_TO_PIM

#ifndef OP_BASE
include "mlir/IR/PatternBase.td"
include "src/Accelerators/PIM/Dialect/Spatial/Spatial.td"
include "src/Accelerators/PIM/Dialect/PIM/Pim.td"
#endif // OP_BASE

// Each pattern below rewrites a Spatial op into its PIM counterpart. The PIM
// ops take one extra trailing operand: a destination/output tensor. That
// operand is produced by getBestOutputTensorFromOperandsOrAllocate (see
// SpatialToPIMCommon.cpp), which reuses a type-matching operand of the matched
// op when possible, and otherwise materializes a fresh tensor.empty.

// spat.weighted_vmm -> pim.vmm (vector x matrix).
def spatToPimVMMOp : Pat<
  (SpatWeightedVMMOp:$srcOpRes $weightIndex, $vector),
  (PimVMMOp $weightIndex, $vector,
    (NativeCodeCall<"onnx_mlir::getBestOutputTensorFromOperandsOrAllocate($_builder, $0.getDefiningOp())"> $srcOpRes))
>;

// spat.weighted_mvm -> pim.mvm (matrix x vector).
def spatToPimMVMOp : Pat<
  (SpatWeightedMVMOp:$srcOpRes $weightIndex, $vector),
  (PimMVMOp $weightIndex, $vector,
    (NativeCodeCall<"onnx_mlir::getBestOutputTensorFromOperandsOrAllocate($_builder, $0.getDefiningOp())"> $srcOpRes))
>;

// spat.vadd -> pim.vadd (elementwise vector add).
def spatToPimVAddOp : Pat<
  (SpatVAddOp:$srcOpRes $a, $b),
  (PimVAddOp $a, $b,
    (NativeCodeCall<"onnx_mlir::getBestOutputTensorFromOperandsOrAllocate($_builder, $0.getDefiningOp())"> $srcOpRes))
>;

#endif // SPATIAL_TO_PIM
|
||||
97
src/PIM/Conversion/SpatialToPIM/SpatialToPIMCommon.cpp
Normal file
97
src/PIM/Conversion/SpatialToPIM/SpatialToPIMCommon.cpp
Normal file
@@ -0,0 +1,97 @@
|
||||
#include "mlir/IR/ValueRange.h"
|
||||
|
||||
#include "llvm/ADT/STLExtras.h"
|
||||
|
||||
#include <cassert>
|
||||
#include <cstddef>
|
||||
|
||||
#include "SpatialToPIMCommon.hpp"
|
||||
|
||||
using namespace llvm;
|
||||
|
||||
namespace onnx_mlir {
|
||||
|
||||
/// Compute the linear element offset of the starting corner of \p sliceOp
/// inside its source tensor, assuming a dense row-major layout: the offsets
/// are folded right-to-left, multiplying by the accumulated size of the
/// trailing dimensions.
///
/// NOTE(review): only getStaticOffsets() is consulted — dynamic offsets and
/// dynamic dims in `inputShape` are not handled here; confirm callers only
/// pass fully static slices.
///
/// \param sliceOp    The ExtractSliceOp whose flat start offset is wanted.
/// \param inputShape The ShapedType of the slice's source tensor.
/// \return The element (not byte) offset of the slice's first element.
size_t getSliceActualOffset(tensor::ExtractSliceOp& sliceOp, ShapedType& inputShape) {
  /*
  EXAMPLE RUN:
  [1, 10, 3, 4] inputShape
  [0, 2, 1, 3] offsets

  acc = 1
  ---
  ret = 3
  acc = 4
  ---
  ret = 3 + 4 * 1 = 7
  acc = 12
  ---
  ret = 7 + 12 * 2 = 31
  acc = 120
  ---
  ret = 31 + 120 * 0 = 31
  acc = 120
  */

  size_t returnValue = 0;

  auto sliceOffsets = sliceOp.getStaticOffsets();
  auto inputDimSizes = inputShape.getShape();

  // One offset per source dimension is required for the fold below.
  assert(sliceOffsets.size() == inputDimSizes.size());

  // Running product of the trailing dimension sizes (row-major stride).
  size_t accumulatedDimensionSize = 1;

  // Reverse iterate the two vectors
  for (auto it : reverse(zip(sliceOffsets, inputDimSizes))) {
    auto curSliceOffset = std::get<0>(it);
    auto curInputDimSize = std::get<1>(it);

    returnValue += accumulatedDimensionSize * curSliceOffset;
    accumulatedDimensionSize *= curInputDimSize;
  }

  return returnValue;
}
|
||||
|
||||
/// Return the user of \p value that appears first in program order.
///
/// Preconditions (asserted / assumed):
///  - \p value has at least one user;
///  - NOTE(review): all users appear to be assumed to live in the same block —
///    Operation::isBeforeInBlock only orders ops within one block; confirm
///    callers never pass values used across blocks.
///
/// \param value The value for which to find the earliest user operation.
/// \return The earliest user operation of \p value.
Operation* getEarliestUserWithinBlock(Value value) {
  auto users = value.getUsers();

  assert(!users.empty());

  // Linear scan: keep whichever user precedes the current candidate.
  Operation* earliestUser = *users.begin();
  for (auto curUser : users)
    if (curUser->isBeforeInBlock(earliestUser))
      earliestUser = curUser;

  return earliestUser;
}
|
||||
|
||||
/// Return the operands of \p operation sorted by ascending total use count.
///
/// Used to pick a "cheap" operand to reuse as an output buffer: operands with
/// fewer uses come first.
///
/// \param operation The operation whose operands are ranked.
/// \return Operands sorted by use count (fewest uses first).
SmallVector<Value> getOpOperandsSortedByUses(Operation* operation) {
  auto operandsAndUses = map_to_vector(operation->getOperands(), [](Value operand) -> std::pair<Value, size_t> {
    // Value has no O(1) use count; walk the use list once per operand.
    return {operand, std::distance(operand.use_begin(), operand.use_end())};
  });
  // Use a STABLE sort: operands that tie on use count keep their original
  // operand order. A plain sort() leaves the tie order unspecified, which
  // would make the chosen output buffer (and thus the emitted IR)
  // nondeterministic across runs / standard-library implementations.
  stable_sort(operandsAndUses, [](const auto& a, const auto& b) { return a.second < b.second; });
  return map_to_vector(operandsAndUses, [](auto operandAndUse) { return operandAndUse.first; });
}
|
||||
|
||||
/// Pick (or create) a tensor to serve as the destination buffer for
/// \p operation's single result.
///
/// Strategy: among the operation's operands, sorted so that the least-used
/// come first, take the first whose type exactly matches the result type
/// (cheapest candidate to overwrite). If none matches, allocate a fresh
/// tensor.empty of the result's shape right before \p operation.
///
/// Called from the tablegen'd NativeCodeCall in SpatialToPIM.td.
///
/// \param rewriter  Rewriter used to create the tensor.empty fallback.
/// \param operation The matched op; must have exactly one ShapedType result.
/// \return A value usable as the destination tensor.
Value getBestOutputTensorFromOperandsOrAllocate(PatternRewriter& rewriter, Operation* operation) {
  // ("msg" && cond) keeps the message visible in the assert failure output.
  assert("Only support operations with a single result" && operation->getNumResults() == 1);
  Value result = operation->getResult(0);
  auto resultType = result.getType();
  assert("Only support result ShapedType as result type" && isa<ShapedType>(resultType));

  // Least-used operands first, filtered down to exact type matches.
  SmallVector<Value> operands = getOpOperandsSortedByUses(operation);
  auto validOperands =
      make_filter_range(operands, [resultType](Value operand) { return operand.getType() == resultType; });
  auto bestOperand = validOperands.begin();

  if (bestOperand != validOperands.end())
    return *bestOperand;

  // No reusable operand: materialize a fresh buffer just before the op.
  auto resultShapedType = cast<ShapedType>(resultType);
  rewriter.setInsertionPoint(operation);
  return rewriter.create<tensor::EmptyOp>(
      operation->getLoc(), resultShapedType.getShape(), resultShapedType.getElementType());
}
|
||||
|
||||
} // namespace onnx_mlir
|
||||
108
src/PIM/Conversion/SpatialToPIM/SpatialToPIMCommon.hpp
Normal file
108
src/PIM/Conversion/SpatialToPIM/SpatialToPIMCommon.hpp
Normal file
@@ -0,0 +1,108 @@
|
||||
#pragma once
|
||||
|
||||
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
||||
|
||||
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
||||
|
||||
using namespace mlir;
|
||||
|
||||
namespace onnx_mlir {
|
||||
|
||||
/**
|
||||
* \brief Get the offset of the ExtractSliceOp based on its static offsets and
|
||||
* its static tensor input.
|
||||
*
|
||||
* The static offsets represent the starting position of the slice in each
|
||||
* dimension, while the static tensor input gives its dimension size.
|
||||
*
|
||||
* \param sliceOp The ExtractSliceOp for which the actual offset needs to be
|
||||
* calculated.
|
||||
* \param inputShape The ShapedType of the ExtractSliceOp's input tensor
|
||||
* \return The actual offset of the ExtractSliceOp.
|
||||
*/
|
||||
size_t getSliceActualOffset(tensor::ExtractSliceOp& sliceOp, ShapedType& inputShape);
|
||||
|
||||
template <class T>
|
||||
size_t rangeLength(const iterator_range<T> range) {
|
||||
return std::distance(range.begin(), range.end());
|
||||
}
|
||||
|
||||
/**
|
||||
* Retrieves the earliest operation that uses the given value within the value's
|
||||
* block.
|
||||
*
|
||||
* @param value The value for which to find the earliest user operation.
|
||||
* @return The earliest user operation that uses the given value within the
|
||||
* current block.
|
||||
*/
|
||||
Operation* getEarliestUserWithinBlock(Value value);
|
||||
|
||||
SmallVector<Value> getOpOperandsSortedByUses(Operation* operation);
|
||||
|
||||
Value getBestOutputTensorFromOperandsOrAllocate(PatternRewriter& rewriter, Operation* operation);
|
||||
|
||||
static bool isMemoryContiguous(const ArrayRef<int64_t> srcShape,
|
||||
const ArrayRef<int64_t> offsets,
|
||||
const ArrayRef<int64_t> sizes,
|
||||
const ArrayRef<int64_t> strides) {
|
||||
// Check that all strides are 1
|
||||
if (std::any_of(strides.begin(), strides.end(), [](int64_t stride) -> bool { return stride != 1; }))
|
||||
return false;
|
||||
|
||||
// Check offsets from right to left:
|
||||
// The first offset_n at position n different from 0:
|
||||
// - limits all sizes to the left to 1
|
||||
// - limits size_n to dimension_n - offset_n
|
||||
auto offsetsAndSizesAndShape = llvm::zip_equal(llvm::make_range(offsets.rbegin(), offsets.rend()),
|
||||
llvm::make_range(sizes.rbegin(), sizes.rend()),
|
||||
llvm::make_range(srcShape.rbegin(), srcShape.rend()));
|
||||
|
||||
auto firstNonZeroOffset = std::find_if(
|
||||
offsetsAndSizesAndShape.begin(), offsetsAndSizesAndShape.end(), [&](auto offsetAndSizeAndShape) -> bool {
|
||||
auto [offset, _size, _dimension] = offsetAndSizeAndShape;
|
||||
return offset != 0;
|
||||
});
|
||||
|
||||
if (firstNonZeroOffset != offsetsAndSizesAndShape.end()) {
|
||||
auto [offset, size, dimension] = *firstNonZeroOffset;
|
||||
if (size > dimension - offset)
|
||||
return false;
|
||||
++firstNonZeroOffset;
|
||||
|
||||
if (std::any_of(firstNonZeroOffset, offsetsAndSizesAndShape.end(), [](auto offsetAndSizeAndShape) -> bool {
|
||||
auto [_offset, size, _dimension] = offsetAndSizeAndShape;
|
||||
return size != 1;
|
||||
}))
|
||||
return false;
|
||||
}
|
||||
|
||||
// Check sizes from right to left:
|
||||
// The first size_n at position n different from shape_n limits all sizes to the left to 1
|
||||
auto sizesAndShape = llvm::zip_equal(llvm::make_range(sizes.rbegin(), sizes.rend()),
|
||||
llvm::make_range(srcShape.rbegin(), srcShape.rend()));
|
||||
|
||||
auto firstDifferentSize = std::find_if(sizesAndShape.begin(), sizesAndShape.end(), [&](auto sizeAndShape) -> bool {
|
||||
auto [size, dimension] = sizeAndShape;
|
||||
return size != dimension;
|
||||
});
|
||||
|
||||
if (firstDifferentSize != sizesAndShape.end()) {
|
||||
++firstDifferentSize;
|
||||
|
||||
if (std::any_of(firstDifferentSize, sizesAndShape.end(), [](auto sizeAndShape) -> bool {
|
||||
auto [size, _] = sizeAndShape;
|
||||
return size != 1;
|
||||
}))
|
||||
return false;
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
/// Create a tensor.empty op with the same shape and element type as
/// \p shapedType, at the rewriter's current insertion point.
inline tensor::EmptyOp createEmptyTensorFromShaped(IRRewriter& rewriter, Location loc, ShapedType shapedType) {
  auto shape = shapedType.getShape();
  auto elementType = shapedType.getElementType();
  return rewriter.create<tensor::EmptyOp>(loc, shape, elementType);
}
|
||||
|
||||
/// True iff \p op is one of the concat-like ops this lowering treats
/// uniformly (tensor.concat or spat.img_concat).
inline bool isAConcatOp(Operation* op) { return isa<tensor::ConcatOp, spatial::SpatImgConcatOp>(op); }
|
||||
|
||||
} // namespace onnx_mlir
|
||||
491
src/PIM/Conversion/SpatialToPIM/SpatialToPIMPass.cpp
Normal file
491
src/PIM/Conversion/SpatialToPIM/SpatialToPIMPass.cpp
Normal file
@@ -0,0 +1,491 @@
|
||||
#include "mlir/Dialect/Func/IR/FuncOps.h"
|
||||
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
||||
#include "mlir/Dialect/Tosa/IR/TosaOps.h"
|
||||
#include "mlir/IR/BuiltinTypeInterfaces.h"
|
||||
#include "mlir/IR/BuiltinTypes.h"
|
||||
#include "mlir/IR/PatternMatch.h"
|
||||
#include "mlir/Interfaces/FunctionInterfaces.h"
|
||||
#include "mlir/Pass/Pass.h"
|
||||
|
||||
#include "llvm/ADT/SmallSet.h"
|
||||
#include "llvm/Support/Casting.h"
|
||||
#include "llvm/Support/raw_os_ostream.h"
|
||||
|
||||
#include <cassert>
|
||||
#include <filesystem>
|
||||
#include <fstream>
|
||||
#include <string>
|
||||
#include <utility>
|
||||
|
||||
#include "SpatialToPIMPass.hpp"
|
||||
|
||||
using namespace mlir;
|
||||
|
||||
namespace onnx_mlir {
|
||||
|
||||
namespace pim {
|
||||
|
||||
void SpatialToPIMPass::runOnOperation() {
|
||||
coreId = 0;
|
||||
ModuleOp moduleOp = getOperation();
|
||||
MLIRContext* ctx = moduleOp.getContext();
|
||||
|
||||
ConversionTarget target(*ctx);
|
||||
target.addLegalDialect<PimDialect, tensor::TensorDialect, arith::ArithDialect>();
|
||||
|
||||
RewritePatternSet patterns(ctx);
|
||||
populateWithGenerated(patterns);
|
||||
|
||||
if (failed(applyPartialConversion(moduleOp, target, std::move(patterns)))) {
|
||||
signalPassFailure();
|
||||
return;
|
||||
}
|
||||
|
||||
func::FuncOp funcOp = *moduleOp.getOps<func::FuncOp>().begin();
|
||||
if (!funcOp)
|
||||
llvm_unreachable("No FuncOp found in the begin of module");
|
||||
|
||||
IRRewriter rewriter(&getContext());
|
||||
auto returnOp = cast<func::ReturnOp>(funcOp.front().getTerminator());
|
||||
|
||||
addResultBuffer(returnOp, rewriter);
|
||||
allocateAndInitializeCoreLocalVariables(funcOp, rewriter);
|
||||
|
||||
for (auto receiveOp : funcOp.getOps<spatial::SpatChannelReceiveOp>()) {
|
||||
operationsToRemove.push_back(receiveOp);
|
||||
runOnReceiveOp(receiveOp, rewriter);
|
||||
}
|
||||
for (auto computeOp : funcOp.getOps<spatial::SpatWeightedCompute>()) {
|
||||
operationsToRemove.push_back(computeOp);
|
||||
runOnComputeOp(computeOp, rewriter);
|
||||
}
|
||||
|
||||
enlargeVMMOutTensorsToCrossbarSize(funcOp, rewriter);
|
||||
replaceReturnOpOperands(returnOp, rewriter);
|
||||
|
||||
// Remove all ComputeOps
|
||||
for (auto opToRemove : llvm::reverse(operationsToRemove)) {
|
||||
if (!opToRemove->use_empty()) {
|
||||
opToRemove->dump();
|
||||
for (auto user : opToRemove->getUsers())
|
||||
user->dump();
|
||||
assert(false && "opToRemove should be unused at this point");
|
||||
}
|
||||
rewriter.eraseOp(opToRemove);
|
||||
}
|
||||
|
||||
// Dump to file for debug
|
||||
std::string outputDir = outputBaseName.substr(0, outputBaseName.find_last_of('/')).append("/dialects");
|
||||
std::filesystem::create_directory(outputDir);
|
||||
std::fstream file(outputDir + "/pim.mlir", std::ios::out);
|
||||
llvm::raw_os_ostream os(file);
|
||||
os << *moduleOp;
|
||||
os.flush();
|
||||
file.close();
|
||||
}
|
||||
|
||||
/// Lower one spat.weighted_compute to a pim.core region.
///
/// For each (result, yielded value) pair:
///  - results flowing (directly, or through a single concat) into the
///    function return are stored to the matching host output buffer with a
///    pim.memcopy_dev_to_host;
///  - all other results are routed through a newly created channel
///    (send on the producing core, receive in each consuming compute region).
/// Finally the region's spat.yield becomes pim.halt and the compute op's body
/// is spliced into a fresh pim.core op.
void SpatialToPIMPass::runOnComputeOp(spatial::SpatWeightedCompute computeOp, IRRewriter& rewriter) {
  Location loc = computeOp->getLoc();

  auto& block = computeOp.getRegion().front();
  auto yieldOp = cast<spatial::SpatYieldOp>(block.getTerminator());

  if (computeOp.getNumResults() != yieldOp.getNumOperands())
    llvm_unreachable("ComputeOp must have same number of results as yieldOp operands");

  for (auto [result, yieldValue] : llvm::zip(computeOp.getResults(), yieldOp.getOperands())) {
    // If this result has no uses, then just skip it
    if (result.use_empty())
      continue;

    auto yieldType = cast<TensorType>(yieldValue.getType());

    /*
     * Here we assume that ReturnOp are only reachable by the following patterns:
     *
     * 1)
     * %0 = spat.compute([...])
     * [%0 has one user, which is a ConcatOp]
     * %1 = tensor.concat(%0)
     * [%1 has one user, which is a ReturnOp]
     * return %1
     *
     * 2)
     * %0 = spat.compute([...])
     * [%0 has one user, which is a ReturnOp]
     * return %0
     *
     * If the IR is like 2), then we can store the tensor to the output global memory location
     */
    auto resultUses = result.getUses();
    auto numResultUses = rangeLength(resultUses);
    if (numResultUses == 1) {
      OpOperand& resultUse = *resultUses.begin();
      Operation* resultUser = resultUse.getOwner();

      // Pattern 2): result feeds the return directly — copy the whole tensor
      // into the output buffer at offset 0.
      if (isa<func::ReturnOp>(resultUser)) {
        size_t resultIndexInReturn = resultUse.getOperandNumber();
        size_t offset = 0;
        size_t numElements = yieldType.getNumElements();
        // NOTE(review): bit-width / 8 assumes byte-multiple element types.
        size_t elementSize = yieldType.getElementType().getIntOrFloatBitWidth() / 8;

        // Store to global memory
        Value outputTensor = outputTensors[resultIndexInReturn];
        rewriter.setInsertionPointAfterValue(yieldValue);
        rewriter.create<PimMemCopyDevToHostOp>(loc,
            outputTensor.getType(),
            outputTensor,
            yieldValue,
            rewriter.getI32IntegerAttr(offset),
            rewriter.getI32IntegerAttr(0),
            rewriter.getI32IntegerAttr(numElements * elementSize));
        continue;
      }

      // Pattern 1): result feeds a concat whose only user is the return —
      // copy into the output buffer at the byte offset of this concat operand.
      if (isa<tensor::ConcatOp>(resultUser) || isa<spatial::SpatImgConcatOp>(resultUser)) {
        auto concatOp = resultUser;
        auto concatValue = concatOp->getResult(0);
        auto concatUses = concatValue.getUses();
        auto numConcatUses = rangeLength(concatUses);
        if (numConcatUses == 1) {
          OpOperand& concatUse = *concatUses.begin();
          Operation* concatUser = concatUse.getOwner();
          if (isa<func::ReturnOp>(concatUser)) {
            size_t concatIndexInReturn = concatUse.getOperandNumber();
            size_t resultIndexInConcat = resultUses.begin()->getOperandNumber();
            // Byte offset = total size of the concat operands preceding ours.
            size_t offset = 0;
            for (auto operand : concatOp->getOperands().take_front(resultIndexInConcat))
              offset += cast<ShapedType>(operand.getType()).getNumElements() * cast<ShapedType>(operand.getType()).getElementTypeBitWidth() / 8;

            size_t elementSize = yieldType.getElementTypeBitWidth() / 8;

            // Store to global memory
            Value outputTensor = outputTensors[concatIndexInReturn];
            rewriter.setInsertionPointAfterValue(yieldValue);
            rewriter.create<PimMemCopyDevToHostOp>(
                loc,
                outputTensor.getType(),
                outputTensor,
                yieldValue,
                rewriter.getI32IntegerAttr(offset),
                rewriter.getI32IntegerAttr(0),
                rewriter.getI32IntegerAttr(yieldType.getNumElements() * elementSize));
            continue;
          }
        }
      }
    }

    // If this pattern was not found, then create a channel and send the value

    // 1. Create a new ChannelOp
    rewriter.setInsertionPoint(computeOp);
    auto channelType = spatial::SpatChannelType::get(computeOp.getContext());
    auto channelOp = rewriter.create<spatial::SpatChannelNewOp>(loc, channelType);

    // 2. Receive value through the channel
    // If this result is used by more than one user, then use a "Broadcast"
    // channel operation. However, there is a special case: we have a single
    // user (a ReshapeOp) which in turn is used by multiple ComputeOps. In this
    // case, we need to use a "Broadcast" channel operation. `addReceiveOps`
    // will detect this case and update `useBroadcastOp` accordingly.
    bool useBroadcastOp = (numResultUses > 1);
    addReceiveOps(result, channelOp, yieldType, useBroadcastOp, rewriter);

    // 3. Send the value through the channel
    rewriter.setInsertionPointAfterValue(yieldValue);
    if (useBroadcastOp)
      rewriter.create<spatial::SpatChannelBroadcastSendOp>(loc, channelOp, yieldValue);
    else
      rewriter.create<spatial::SpatChannelSendOp>(loc, channelOp, yieldValue);
  }

  // Use `HaltOp` instead of `YieldOp`
  rewriter.setInsertionPoint(yieldOp);
  rewriter.replaceOpWithNewOp<PimHaltOp>(yieldOp);

  // Replace `spat.compute` with `pim.core`: splice the compute region's
  // blocks into the new core op, then leave the (soon-to-be-erased) compute
  // op with a trivial halt-only block so it stays structurally valid.
  rewriter.setInsertionPointAfter(computeOp);
  auto coreOp = rewriter.create<PimCoreOp>(loc, computeOp.getWeights(), rewriter.getI32IntegerAttr(coreId++));
  auto& coreOpBlocks = coreOp.getBody().getBlocks();
  // Arguments were already replaced by receive/memcopy ops; drop them before
  // the splice so the core block has no dangling arguments.
  block.eraseArguments(0, block.getNumArguments());
  coreOpBlocks.splice(coreOpBlocks.begin(), computeOp.getBody().getBlocks());
  // Ownership of the new block passes to the region on push_back.
  Block* tempComputeBlock = new Block();
  computeOp.getBody().push_back(tempComputeBlock);
  rewriter.setInsertionPointToEnd(tempComputeBlock);
  rewriter.create<PimHaltOp>(computeOp.getLoc());
}
|
||||
|
||||
/// Widen every pim.vmm output tensor whose second dimension is smaller than
/// the crossbar width up to [rows, crossbarSize], then re-extract the
/// original [rows, oldCols] slice so all downstream users keep seeing the
/// original shape. The widened type is also propagated backwards through the
/// chain of tied destination-style (DPS) operands that produced the buffer.
void SpatialToPIMPass::enlargeVMMOutTensorsToCrossbarSize(func::FuncOp funcOp, IRRewriter& rewriter) {
  // Recursive lambda (passed to itself as `self`): walk the tied-DPS operand
  // chain backwards and retype each buffer in place.
  auto enlargeTiedDpsChain = [&](Value value, RankedTensorType newType, auto& self) -> void {
    auto* definingOp = value.getDefiningOp();
    if (!definingOp)
      return;
    auto dpsDefiningOp = dyn_cast<DestinationStyleOpInterface>(definingOp);
    if (!dpsDefiningOp)
      return;
    auto* tiedOperand = dpsDefiningOp.getTiedOpOperand(cast<OpResult>(value));
    if (!tiedOperand)
      return;
    Value tiedValue = tiedOperand->get();
    // Retyping in place is only sound if nothing else reads this buffer.
    assert(tiedValue.hasOneUse() && "Tied DPS operand expected to have a single use");
    tiedValue.setType(newType);
    self(tiedValue, newType, self);
  };

  funcOp.walk([&](PimVMMOp vmmOp) {
    auto outTensorOperand = vmmOp.getOutBuf();
    auto resultTensor = vmmOp.getOutRes();
    auto outShape = getTensorShape(outTensorOperand);
    // VMM outputs are expected to be horizontal vectors: [1, N].
    assert(isHVectorShape(outShape));
    if (outShape[1] != static_cast<int64_t>(crossbarSize)) {
      auto newShape = SmallVector<int64_t> {outShape[0], static_cast<int64_t>(crossbarSize)};
      auto newType = RankedTensorType::get(newShape, outTensorOperand.getType().getElementType());
      enlargeTiedDpsChain(outTensorOperand, newType, enlargeTiedDpsChain);
      outTensorOperand.setType(newType);
      resultTensor.setType(newType);

      // Extract the original-shaped slice [0..rows, 0..oldCols] (stride 1)
      // so existing users keep the pre-widening type.
      IntegerAttr zeroAttr = rewriter.getIndexAttr(0);
      IntegerAttr oneAttr = rewriter.getIndexAttr(1);
      IntegerAttr oldShapeZeroAttr = rewriter.getIndexAttr(outShape[0]);
      IntegerAttr oldShapeOneAttr = rewriter.getIndexAttr(outShape[1]);
      SmallVector<OpFoldResult> offsets = {zeroAttr, zeroAttr};
      SmallVector<OpFoldResult> sizes = {oldShapeZeroAttr, oldShapeOneAttr};
      SmallVector<OpFoldResult> strides = {oneAttr, oneAttr};
      rewriter.setInsertionPointAfter(vmmOp);
      auto sliceOp = rewriter.create<tensor::ExtractSliceOp>(vmmOp.getLoc(), resultTensor, offsets, sizes, strides);
      // Keep vmmOp and the new slice itself pointing at the widened tensor.
      SmallPtrSet<Operation*, 2> exceptions = {vmmOp, sliceOp};
      resultTensor.replaceAllUsesExcept(sliceOp.getResult(), exceptions);
    }
  });
}
|
||||
|
||||
/// Allocate one tensor.empty host output buffer per function result, at the
/// top of the return op's block, and record them in `outputTensors` (indexed
/// by return-operand position).
void SpatialToPIMPass::addResultBuffer(func::ReturnOp& returnOp, IRRewriter& rewriter) {
  rewriter.setInsertionPointToStart(returnOp->getBlock());
  outputTensors.reserve(returnOp->getNumOperands());
  for (Value returnValue : returnOp->getOperands()) {
    auto resultShapedType = cast<ShapedType>(returnValue.getType());
    outputTensors.push_back(createEmptyTensorFromShaped(rewriter, returnValue.getLoc(), resultShapedType));
  }
}
|
||||
|
||||
/// Rewire function inputs into core-local device buffers.
///
/// Two steps:
///  1. Replace every tensor function argument with a memref argument plus a
///     bufferization.to_tensor at the top of the function.
///  2. For every spat.weighted_compute input that comes from host memory
///     (directly or through a contiguous extract_slice), replace the compute
///     region's block argument with a device tensor filled by a
///     pim.memcopy_host_to_dev at the appropriate byte offset. Inputs
///     produced by other compute ops are skipped — those flow through
///     channels instead (see runOnComputeOp).
void SpatialToPIMPass::allocateAndInitializeCoreLocalVariables(func::FuncOp funcOp, IRRewriter& rewriter) {
  Location loc = funcOp.getLoc();

  // Allocate a device tensor, emit host->dev copy of the whole tensor from
  // `hostTensor` (starting at `elementsOffset` elements), and redirect all
  // uses of `valueToReplace` to the copy's result.
  auto insertMemCopyHostToDev = [&](auto valueToReplace, auto hostTensor, int64_t elementsOffset) {
    auto tensorType = cast<ShapedType>(valueToReplace.getType());
    Type elementType = tensorType.getElementType();
    // NOTE(review): bit-width / 8 assumes byte-multiple element types.
    size_t elementByteSize = elementType.getIntOrFloatBitWidth() / 8;
    rewriter.setInsertionPoint(getEarliestUserWithinBlock(valueToReplace));

    auto deviceTensor = rewriter.create<tensor::EmptyOp>(loc, tensorType.getShape(), elementType);

    auto memCopyHostToDevOp = rewriter.create<PimMemCopyHostToDevOp>(
        loc,
        tensorType,
        deviceTensor,
        hostTensor,
        rewriter.getI32IntegerAttr(0),
        rewriter.getI32IntegerAttr(static_cast<int32_t>(elementsOffset * elementByteSize)),
        rewriter.getI32IntegerAttr(static_cast<int32_t>(tensorType.getNumElements() * elementByteSize)));

    rewriter.replaceAllUsesWith(valueToReplace, memCopyHostToDevOp.getResult());
  };

  // Replace input tensors with memRefs. Each iteration inserts the memref
  // argument at i+1, redirects uses of the tensor argument at i, then erases
  // it — so position i holds the new memref argument afterwards.
  SmallVector<bufferization::ToTensorOp, 8> inputTensors;
  for (size_t i = 0; i < funcOp.getNumArguments(); i++) {
    BlockArgument tensorArg = funcOp.getArgument(i);
    DictionaryAttr tensorArgAttrs = funcOp.getArgAttrDict(i);
    ShapedType tensorArgType = cast<ShapedType>(tensorArg.getType());
    MemRefType memRefArgType = MemRefType::get(tensorArgType.getShape(), tensorArgType.getElementType());

    funcOp.insertArgument(i + 1, memRefArgType, tensorArgAttrs, loc);
    BlockArgument memRefArg = funcOp.getArgument(i + 1);

    Block& block = funcOp.getBody().front();
    rewriter.setInsertionPoint(&block.front());
    auto toTensorOp = rewriter.create<bufferization::ToTensorOp>(loc, tensorArgType, memRefArg, rewriter.getUnitAttr());
    inputTensors.push_back(toTensorOp);

    tensorArg.replaceAllUsesWith(toTensorOp);
    funcOp.eraseArgument(i);
  }

  llvm::SmallSet<tensor::ExtractSliceOp, 8> sliceOpsToRemove;
  for (auto& op : funcOp.getBody().getOps())
    if (auto computeOp = dyn_cast<spatial::SpatWeightedCompute>(op)) {
      unsigned numComputeWeights = computeOp.getWeights().size();
      for (auto [computeInputIdx, computeOpInput] : llvm::enumerate(computeOp.getInputs())) {
        TypedValue<TensorType> tensorSource;
        int64_t elementsOffset = 0;

        // dyn_cast_or_null: the input may be a block argument, in which case
        // getDefiningOp() is null and a plain dyn_cast would crash.
        if (auto sliceOp = dyn_cast_or_null<tensor::ExtractSliceOp>(computeOpInput.getDefiningOp())) {
          tensorSource = cast<TypedValue<TensorType>>(sliceOp.getSource());

          ArrayRef<int64_t> sourceShape = tensorSource.getType().getShape();
          ArrayRef<int64_t> sliceOffsets = sliceOp.getStaticOffsets();
          ArrayRef<int64_t> sliceSizes = sliceOp.getStaticSizes();
          ArrayRef<int64_t> sliceStrides = sliceOp.getStaticStrides();
          // A single memcopy can only handle slices contiguous in host memory.
          assert("Extracting slice non-contiguous in memory"
              && isMemoryContiguous(sourceShape, sliceOffsets, sliceSizes, sliceStrides));

          // Row-major linearization of the slice's starting corner.
          for (size_t i = 0; i < sliceOffsets.size(); i++) {
            int64_t partialOffset = sliceOffsets[i];
            if (partialOffset != 0)
              for (size_t j = i + 1; j < sourceShape.size(); j++)
                partialOffset *= sourceShape[j];
            elementsOffset += partialOffset;
          }

          // The compute op reads directly from the slice's source; the slice
          // becomes dead and is collected for removal below.
          computeOp.setOperand(numComputeWeights + computeInputIdx, tensorSource);
          sliceOpsToRemove.insert(sliceOp);
        }
        else
          tensorSource = cast<TypedValue<TensorType>>(computeOpInput);

        // Compute results must be transferred through channels via send/receive.
        // Guard against a null defining op (block-argument source): the
        // original unconditional isa<> would dereference null here.
        Operation* tensorSourceDef = tensorSource.getDefiningOp();
        if (tensorSourceDef && isa<spatial::SpatWeightedCompute>(tensorSourceDef))
          continue;

        BlockArgument computeBlockArgToReplace = computeOp.getBody().front().getArgument(computeInputIdx);
        insertMemCopyHostToDev(computeBlockArgToReplace, tensorSource, elementsOffset);
      }
    }

  // Erase the now-bypassed slices (only those that really became dead).
  for (auto sliceOp : sliceOpsToRemove)
    if (sliceOp->use_empty())
      rewriter.eraseOp(sliceOp);
}
|
||||
|
||||
/// Replace one block argument of \p computeOp's region with a channel
/// receive: a (broadcast-)receive op is created just before the argument's
/// earliest user and all uses of the argument are redirected to it.
///
/// \param computeOp      Compute op whose region consumes the channel value.
/// \param argIndex       OPERAND index on the compute op (weights included).
/// \param channel        The channel to receive from.
/// \param tensorType     Type of the transferred tensor.
/// \param useBroadcastOp Emit broadcast-receive instead of plain receive.
void SpatialToPIMPass::replaceBlockArgumentWithRecvOp(spatial::SpatWeightedCompute& computeOp,
    unsigned int argIndex,
    spatial::SpatChannelNewOp& channel,
    Type& tensorType,
    bool useBroadcastOp,
    IRRewriter& rewriter) {
  auto& computeBlock = computeOp.getRegion().front();
  //(remember that WeightedCompute have weights as first operands, however these
  // weights are not included in the block arguments. Thus, when indexing the
  // block argument we need to remove the weights count)
  auto computeWeightsCount = computeOp.getWeights().size();
  auto blockArg = computeBlock.getArgument(argIndex - computeWeightsCount);
  // Receive the tensor just before the first use of the value
  rewriter.setInsertionPoint(getEarliestUserWithinBlock(blockArg));
  Value receivedValue;
  if (useBroadcastOp)
    receivedValue = rewriter.create<spatial::SpatChannelBroadcastReceiveOp>(computeOp.getLoc(), tensorType, channel);
  else
    receivedValue = rewriter.create<spatial::SpatChannelReceiveOp>(computeOp.getLoc(), tensorType, channel);

  blockArg.replaceAllUsesWith(receivedValue);
}
|
||||
|
||||
/// Insert receive ops for every consumer of \p channelSourceOp.
///
/// Each user must be either a spat.weighted_compute (its block argument is
/// replaced with a receive) or a tosa.reshape whose users are all compute
/// ops (each gets a receive; the reshape is queued for removal and
/// \p channelTensorType is retargeted to the reshaped type).
///
/// \param useBroadcastOp [in/out] Set to true when a single reshape user
///        fans out to multiple compute ops, so the caller emits a broadcast
///        send instead of a plain send.
void SpatialToPIMPass::addReceiveOps(Value& channelSourceOp,
    spatial::SpatChannelNewOp& channel,
    Type& channelTensorType,
    bool& useBroadcastOp,
    IRRewriter& rewriter) {
  auto sourceOpUses = channelSourceOp.getUses();

  // Check if we need to update `useBroadcastOp` to true, in the case of a reshapeOp with multiple users
  if (useBroadcastOp == false) {
    // if useBroadcastOp is false, then sourceOp must have only one user
    assert(rangeLength(sourceOpUses) == 1);

    if (auto reshapeOp = dyn_cast<tosa::ReshapeOp>(sourceOpUses.begin()->getOwner())) {
      auto reshapeOpUses = reshapeOp.getOutput().getUses();
      auto reshapeOpUsesCount = rangeLength(reshapeOpUses);
      if (reshapeOpUsesCount > 1)
        useBroadcastOp = true;
    }
  }

  for (auto& resultUse : sourceOpUses) {
    // The user must be a ComputeOp, or a reshapeOp which can be used by many ComputeOps
    spatial::SpatWeightedCompute computeUser = dyn_cast<spatial::SpatWeightedCompute>(resultUse.getOwner());

    if (computeUser) {
      replaceBlockArgumentWithRecvOp(
          computeUser, resultUse.getOperandNumber(), channel, channelTensorType, useBroadcastOp, rewriter);
      continue;
    }

    if (!computeUser) {
      auto reshapeOp = dyn_cast<tosa::ReshapeOp>(resultUse.getOwner());
      if (!reshapeOp) {
        // Dump the offending op before dying to ease debugging.
        resultUse.getOwner()->dump();
        llvm_unreachable("User of Value that now needs to be received by channel is not a ComputeOp nor a ReshapeOp");
      }

      // The tensorType now becomes the one of the reshapeOp
      channelTensorType = reshapeOp.getResult().getType();

      // Every consumer of the reshape receives the (reshaped) tensor.
      for (auto& reshapeUse : reshapeOp.getOutput().getUses()) {
        computeUser = dyn_cast<spatial::SpatWeightedCompute>(reshapeUse.getOwner());

        if (!computeUser)
          llvm_unreachable("ReshapeOp users must be ComputeOps");

        replaceBlockArgumentWithRecvOp(
            computeUser, reshapeUse.getOperandNumber(), channel, channelTensorType, useBroadcastOp, rewriter);
      }

      // Remove the reshapeOp, so that the sourceOp has no users
      operationsToRemove.push_back(reshapeOp);
    }
  }
}
|
||||
|
||||
/// Rewire the function return to the pre-allocated host output buffers
/// (results were already copied into them by runOnComputeOp), and erase
/// concat ops that become dead once the return stops referencing them.
void SpatialToPIMPass::replaceReturnOpOperands(func::ReturnOp& returnOp, IRRewriter& rewriter) {
  for (auto it : llvm::enumerate(returnOp.getOperands())) {
    // May be null when the returned value is a block argument — guard below.
    // (The original code passed a possibly-null pointer straight into
    // isAConcatOp/isa, which dereferences it.)
    Operation* returnOperand = it.value().getDefiningOp();

    size_t orderWithinReturn = it.index();

    rewriter.modifyOpInPlace(returnOp,
        [&] { returnOp.setOperand(orderWithinReturn, outputTensors[orderWithinReturn]); });

    // If the operand is a concatenation operation and the returnOp was the only
    // user of the returnOperand, we can safely remove it. Note the use check
    // runs AFTER setOperand, so the return's own use is already gone.
    if (returnOperand && isAConcatOp(returnOperand)) {
      auto returnOperandUses = it.value().getUses();
      if (rangeLength(returnOperandUses) == 0)
        rewriter.eraseOp(returnOperand);
    }
  }
}
|
||||
|
||||
/// Lower a top-level spat.channel_receive: re-emit receives inside each
/// consuming compute region (via addReceiveOps) and, when the received value
/// turns out to have several consumers, upgrade the matching send to a
/// broadcast send.
void SpatialToPIMPass::runOnReceiveOp(spatial::SpatChannelReceiveOp receiveOp, IRRewriter& rewriter) {

  auto channel = cast<spatial::SpatChannelNewOp>(receiveOp.getChannel().getDefiningOp());

  // Locate the send paired with this receive on the same channel.
  auto sendOpOpt = getOtherEndOfChannel(receiveOp, true, rewriter);
  if (failed(sendOpOpt))
    llvm_unreachable("ChannelReceiveOp has no matching SendOp");

  auto sendOp = cast<spatial::SpatChannelSendOp>(*sendOpOpt);

  auto tensorType = receiveOp.getType();
  Value receiveRes = receiveOp.getResult();

  // Check if the receiveOp value has more than one user
  auto receiveUses = receiveRes.getUses();
  auto receiveUsesCount = rangeLength(receiveUses);
  assert(receiveUsesCount > 0);
  bool useBroadcastOp = receiveUsesCount > 1;
  // May flip useBroadcastOp to true (reshape fan-out case).
  addReceiveOps(receiveRes, channel, tensorType, useBroadcastOp, rewriter);

  if (useBroadcastOp) {
    // When receiving, we actually noticed that the value has more than one
    // user. This means that we need to replace the original SendOp with
    // a BroadcastSendOp.
    rewriter.setInsertionPoint(sendOp);
    rewriter.replaceOpWithNewOp<spatial::SpatChannelBroadcastSendOp>(sendOp, sendOp.getChannel(), sendOp.getData());
  }
}
|
||||
|
||||
} // namespace pim
|
||||
|
||||
} // namespace onnx_mlir
|
||||
60
src/PIM/Conversion/SpatialToPIM/SpatialToPIMPass.hpp
Normal file
60
src/PIM/Conversion/SpatialToPIM/SpatialToPIMPass.hpp
Normal file
@@ -0,0 +1,60 @@
|
||||
#pragma once
|
||||
|
||||
#include "Conversion/ONNXToSpatial/ONNXToSpatialCommon.hpp"
|
||||
#include "src/Accelerators/PIM/Common/PIMCommon.hpp"
|
||||
#include "src/Accelerators/PIM/Conversion/SpatialToPIM/SpatialToPIMCommon.hpp"
|
||||
#include "src/Accelerators/PIM/Dialect/PIM/PimOps.hpp"
|
||||
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
||||
#include "src/Accelerators/PIM/Pass/PimPasses.hpp"
|
||||
#include "src/Compiler/CompilerOptions.hpp"
|
||||
|
||||
namespace onnx_mlir {
|
||||
|
||||
namespace pim {
|
||||
|
||||
#include "src/Accelerators/PIM/Conversion/SpatialToPIM/SpatialToPIM.hpp.inc"
|
||||
|
||||
/// Module pass lowering the Spatial dialect to the PIM dialect.
///
/// Runs the tablegen'd rewrite patterns, then performs the structural
/// lowering (output buffers, host<->device copies, channels, pim.core
/// regions). See runOnOperation() for the full pipeline.
struct SpatialToPIMPass : PassWrapper<SpatialToPIMPass, OperationPass<ModuleOp>> {
  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(SpatialToPIMPass)
  StringRef getArgument() const override { return "convert-spatial-to-pim"; }
  StringRef getDescription() const override { return "Lower Spatial ops to PIM-ready format"; }

  SpatialToPIMPass() = default;
  // Copy ctor intentionally does not copy the mutable members below; they are
  // per-run scratch state (MLIR pass-cloning convention — confirm the fields
  // are reset in runOnOperation before use).
  SpatialToPIMPass(const SpatialToPIMPass& pass) {}

  void runOnOperation() final;

private:
  // One host output buffer per function result, indexed by return position.
  SmallVector<Value> outputTensors;
  // Next id handed to a newly created pim.core op.
  size_t coreId = 0;
  // Ops replaced during lowering, erased in bulk at the end of the run.
  SmallVector<Operation*> operationsToRemove;

  // Allocate the host output buffers (fills `outputTensors`).
  void addResultBuffer(func::ReturnOp& returnOp, IRRewriter& rewriter);

  // Turn tensor arguments into memrefs + host->device copies per core input.
  void allocateAndInitializeCoreLocalVariables(func::FuncOp funcOp, IRRewriter& rewriter);

  // Lower a top-level channel receive (may upgrade the paired send to broadcast).
  void runOnReceiveOp(spatial::SpatChannelReceiveOp receiveOp, IRRewriter& rewriter);
  // Insert receive ops for all consumers of a channel's source value.
  void addReceiveOps(Value& channelSourceOp,
      spatial::SpatChannelNewOp& channel,
      Type& channelTensorType,
      bool& useBroadcastOp,
      IRRewriter& rewriter);
  // Replace one compute-region block argument with a channel receive.
  void replaceBlockArgumentWithRecvOp(spatial::SpatWeightedCompute& computeOp,
      unsigned int argIndex,
      spatial::SpatChannelNewOp& channel,
      Type& tensorType,
      bool useBroadcastOp,
      IRRewriter& rewriter);

  // Lower one spat.weighted_compute into a pim.core region.
  void runOnComputeOp(spatial::SpatWeightedCompute computeOp, IRRewriter& rewriter);

  // Widen pim.vmm outputs to the crossbar width, re-slicing for users.
  void enlargeVMMOutTensorsToCrossbarSize(func::FuncOp funcOp, IRRewriter& rewriter);

  // Point the return at the host output buffers; drop dead concats.
  void replaceReturnOpOperands(func::ReturnOp& returnOp, IRRewriter& rewriter);
};
|
||||
|
||||
} // namespace pim
|
||||
|
||||
/// Factory for the Spatial->PIM lowering pass.
/// `inline` is required: this is a function DEFINITION in a header
/// (#pragma once does not prevent per-TU definitions), so without it every
/// translation unit including this file would emit its own external-linkage
/// symbol and violate the ODR (duplicate-symbol link error).
inline std::unique_ptr<Pass> createSpatialToPIMPass() { return std::make_unique<pim::SpatialToPIMPass>(); }
|
||||
|
||||
} // namespace onnx_mlir
|
||||
12
src/PIM/Conversion/SpatialToPIM/SpatialToPIMPatterns.hpp
Normal file
12
src/PIM/Conversion/SpatialToPIM/SpatialToPIMPatterns.hpp
Normal file
@@ -0,0 +1,12 @@
|
||||
#pragma once
|
||||
#include "mlir/IR/PatternMatch.h"
|
||||
|
||||
namespace onnx_mlir {

namespace spatial {

// TODO: Add here eventual patterns

} // namespace spatial

} // namespace onnx_mlir
|
||||
Reference in New Issue
Block a user