All checks were successful
Validate Operations / validate-operations (push) Successful in 22m31s
797 lines
31 KiB
C++
797 lines
31 KiB
C++
#include "mlir/Dialect/Arith/IR/Arith.h"
|
|
#include "mlir/Dialect/Func/IR/FuncOps.h"
|
|
#include "mlir/Dialect/SCF/IR/SCF.h"
|
|
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
|
#include "mlir/Dialect/Tosa/IR/TosaOps.h"
|
|
#include "mlir/IR/BuiltinDialect.h"
|
|
#include "mlir/IR/BuiltinTypeInterfaces.h"
|
|
#include "mlir/IR/BuiltinTypes.h"
|
|
#include "mlir/IR/IRMapping.h"
|
|
#include "mlir/IR/PatternMatch.h"
|
|
#include "mlir/Interfaces/FunctionInterfaces.h"
|
|
#include "mlir/Pass/Pass.h"
|
|
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
|
|
|
|
#include "llvm/ADT/SmallSet.h"
|
|
#include "llvm/ADT/StringRef.h"
|
|
#include "llvm/Support/Casting.h"
|
|
#include "llvm/Support/raw_os_ostream.h"
|
|
|
|
#include <cassert>
|
|
#include <filesystem>
|
|
#include <string>
|
|
#include <utility>
|
|
|
|
#include "Conversion/ONNXToSpatial/Common.hpp"
|
|
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
|
|
#include "src/Accelerators/PIM/Conversion/SpatialToPim/Common.hpp"
|
|
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
|
|
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
|
#include "src/Accelerators/PIM/Pass/PIMPasses.h"
|
|
#include "src/Compiler/CompilerOptions.hpp"
|
|
|
|
using namespace mlir;
|
|
using namespace onnx_mlir;
|
|
using namespace pim;
|
|
|
|
namespace onnx_mlir {
|
|
|
|
namespace {
|
|
|
|
#include "src/Accelerators/PIM/Conversion/SpatialToPim/SpatialToPim.hpp.inc"
|
|
|
|
/// Module pass lowering Spatial dialect ops to the PIM dialect:
/// `spat.compute` regions become `pim.core` regions, cross-core values are
/// routed through channels (send/receive or broadcast variants), and final
/// results are copied back into host-side output buffers.
struct SpatialToPimPass : PassWrapper<SpatialToPimPass, OperationPass<ModuleOp>> {
  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(SpatialToPimPass)

  StringRef getArgument() const override { return "convert-spatial-to-pim"; }
  StringRef getDescription() const override { return "Lower Spatial ops to PIM-ready format"; }

  SpatialToPimPass() = default;
  // Copy constructor deliberately does not copy per-run state below; each
  // cloned pass instance starts with fresh bookkeeping.
  SpatialToPimPass(const SpatialToPimPass& pass) {}

  void runOnOperation() final;

private:
  // Host-side result buffers, indexed by the operand position in the entry
  // function's return op (populated by addResultBuffer).
  SmallVector<Value> outputTensors;
  // Next core id handed out to a `pim.core` op; reset to 1 in runOnOperation.
  size_t coreId = 0;
  // Ops scheduled for deferred erasure once all their uses are gone.
  SmallVector<Operation*> operationsToRemove;

  // Collects / creates the host output buffer for each returned value.
  void addResultBuffer(func::ReturnOp& returnOp, IRRewriter& rewriter);

  // Rewrites tensor entry arguments into memrefs and inserts host->device
  // copies for compute inputs that do not come from other computes.
  LogicalResult allocateAndInitializeCoreLocalVariables(func::FuncOp funcOp, IRRewriter& rewriter);

  // Re-routes an existing `spat.channel_receive` to its compute consumers.
  void runOnReceiveOp(spatial::SpatChannelReceiveOp receiveOp, IRRewriter& rewriter);
  // Inserts receive ops (broadcast or point-to-point) for every compute
  // consumer reachable from `channelSourceOp` through view-like ops.
  void
  addReceiveOps(Value channelSourceOp, spatial::SpatChannelNewOp& channel, bool useBroadcastOp, IRRewriter& rewriter);
  // Replaces a compute block argument with a value received over `channel`,
  // replaying any view-like chain between source and consumer.
  void replaceBlockArgumentWithRecvOp(spatial::SpatCompute& computeOp,
      unsigned int argIndex,
      Value channelSourceOp,
      Value consumerValue,
      spatial::SpatChannelNewOp& channel,
      bool useBroadcastOp,
      IRRewriter& rewriter);
  // Tracks `op` for deferred removal (idempotent).
  void markOpToRemove(Operation* op);
  // Attaches source/target core-id attributes to every `spat.channel_new`.
  void annotateChannelCoreIds(func::FuncOp funcOp);
  // Lowers broadcast send/receive pairs to `pim.send` / `pim.receive`.
  void lowerBroadcastChannelOps(func::FuncOp funcOp, IRRewriter& rewriter);

  // Lowers one `spat.compute` into a `pim.core`, wiring results to host
  // buffers or channels.
  void runOnComputeOp(spatial::SpatCompute computeOp, IRRewriter& rewriter);

  // Pads VMM output tensors up to the crossbar width, slicing back afterwards.
  void enlargeVMMOutTensorsToCrossbarSize(func::FuncOp funcOp, IRRewriter& rewriter);

  // Swaps return operands for the host output buffers and prunes the
  // now-dead device-side producer chains.
  void replaceReturnOpOperands(func::ReturnOp& returnOp, IRRewriter& rewriter);
};
|
|
|
|
} // namespace
|
|
|
|
/// Returns true for view-like / reshaping ops that may legally sit between a
/// channel endpoint and its eventual compute consumer without altering the
/// transported data.
static bool isChannelUseChainOp(Operation* op) {
  const bool isViewLikeChainOp = isa<tensor::ExtractSliceOp,
      tensor::CastOp,
      tensor::CollapseShapeOp,
      tensor::ExpandShapeOp,
      tosa::ReshapeOp,
      pim::PimTransposeOp>(op);
  return isViewLikeChainOp;
}
|
|
|
|
/// Before cloning `op` through `mapping`, clone any of its operands that are
/// simple materializations (tensor.empty / arith.constant) so the clone of
/// `op` has all operands available at the rewriter's insertion point.
/// Operands already present in `mapping`, block arguments, and ops of other
/// kinds are left untouched. Advances the insertion point past each clone so
/// subsequent clones stay in order.
static void cloneMappedHelperOperands(Operation* op, IRMapping& mapping, IRRewriter& rewriter) {
  for (Value operand : op->getOperands()) {
    // Already remapped: nothing to materialize.
    if (mapping.lookupOrNull(operand))
      continue;

    Operation* definingOp = operand.getDefiningOp();
    if (!definingOp)
      continue;

    // Only cheap, side-effect-free materializations are replicated here.
    if (!isa<tensor::EmptyOp, arith::ConstantOp>(definingOp))
      continue;

    Operation* clonedOp = rewriter.clone(*definingOp, mapping);
    for (auto [originalResult, newResult] : llvm::zip(definingOp->getResults(), clonedOp->getResults()))
      mapping.map(originalResult, newResult);
    // Keep insertion order: later clones go after this one.
    rewriter.setInsertionPointAfter(clonedOp);
  }
}
|
|
|
|
/// Counts how many `spat.compute` consumers `value` eventually reaches,
/// walking transitively through supported view-like channel-chain ops.
/// Aborts (llvm_unreachable) on any unsupported user.
static size_t countComputeLeafUsers(Value value) {
  size_t computeConsumerCount = 0;

  // Iterative traversal instead of recursion: process each value's users and
  // push single-result chain ops back onto the worklist.
  SmallVector<Value> worklist;
  worklist.push_back(value);
  while (!worklist.empty()) {
    Value currentValue = worklist.pop_back_val();
    for (OpOperand& use : currentValue.getUses()) {
      Operation* owner = use.getOwner();
      if (isa<spatial::SpatCompute>(owner)) {
        ++computeConsumerCount;
        continue;
      }

      if (!isChannelUseChainOp(owner))
        llvm_unreachable("Channel use chain contains unsupported op");

      assert(owner->getNumResults() == 1 && "Channel use chain op must have a single result");
      worklist.push_back(owner->getResult(0));
    }
  }

  return computeConsumerCount;
}
|
|
|
|
void SpatialToPimPass::runOnOperation() {
  // Core ids start at 1 (0 is presumably reserved for the host — TODO confirm).
  coreId = 1;
  ModuleOp moduleOp = getOperation();
  MLIRContext* ctx = moduleOp.getContext();

  // Phase 1: run the TableGen'd (DRR) conversion patterns. Everything in the
  // listed dialects is considered already legal.
  ConversionTarget target(*ctx);
  target.addLegalDialect<PimDialect,
      tensor::TensorDialect,
      arith::ArithDialect,
      func::FuncDialect,
      scf::SCFDialect,
      BuiltinDialect>();

  RewritePatternSet patterns(ctx);
  populateWithGenerated(patterns);

  if (failed(applyPartialConversion(moduleOp, target, std::move(patterns)))) {
    signalPassFailure();
    return;
  }

  // All remaining work happens on the single PIM entry function.
  auto entryFunc = getPimEntryFunc(moduleOp);
  if (failed(entryFunc)) {
    signalPassFailure();
    return;
  }
  func::FuncOp funcOp = *entryFunc;

  IRRewriter rewriter(&getContext());
  auto returnOp = cast<func::ReturnOp>(funcOp.front().getTerminator());

  // Phase 2: set up host output buffers and host->device input copies.
  addResultBuffer(returnOp, rewriter);
  if (failed(allocateAndInitializeCoreLocalVariables(funcOp, rewriter))) {
    signalPassFailure();
    return;
  }

  // Phase 3: lower receive ops first (they may upgrade sends to broadcasts),
  // then the compute ops themselves. Both are marked for deferred removal.
  for (auto receiveOp : funcOp.getOps<spatial::SpatChannelReceiveOp>()) {
    markOpToRemove(receiveOp);
    runOnReceiveOp(receiveOp, rewriter);
  }
  for (auto computeOp : funcOp.getOps<spatial::SpatCompute>()) {
    markOpToRemove(computeOp);
    runOnComputeOp(computeOp, rewriter);
  }

  // Phase 4: channel bookkeeping and lowering of broadcast channels.
  annotateChannelCoreIds(funcOp);
  lowerBroadcastChannelOps(funcOp, rewriter);

  // Re-run the generated patterns greedily to lower any channel ops created
  // during the previous phases.
  RewritePatternSet channelPatterns(ctx);
  populateWithGenerated(channelPatterns);
  if (failed(applyPatternsGreedily(funcOp, std::move(channelPatterns)))) {
    signalPassFailure();
    return;
  }

  // Phase 5: final shape fix-ups and rewiring of the return op.
  enlargeVMMOutTensorsToCrossbarSize(funcOp, rewriter);
  replaceReturnOpOperands(returnOp, rewriter);

  // Phase 6: erase tracked ops. Repeated sweeps handle use-def chains: an op
  // only becomes erasable once all its users have been erased in an earlier
  // sweep. If a full sweep erases nothing while ops remain, we are stuck
  // (cycle or a user that was never tracked) — dump the offenders and abort.
  SmallVector<Operation*> pendingRemovals(operationsToRemove.begin(), operationsToRemove.end());
  while (!pendingRemovals.empty()) {
    bool erasedAnyOp = false;
    for (auto it = pendingRemovals.begin(); it != pendingRemovals.end();) {
      Operation* opToRemove = *it;
      if (!opToRemove->use_empty()) {
        ++it;
        continue;
      }

      rewriter.eraseOp(opToRemove);
      it = pendingRemovals.erase(it);
      erasedAnyOp = true;
    }

    if (erasedAnyOp)
      continue;

    // No progress: print the stuck ops and their remaining users for debug.
    for (auto opToRemove : pendingRemovals) {
      opToRemove->dump();
      for (auto user : opToRemove->getUsers())
        user->dump();
    }
    assert(false && "tracked op removal reached a cycle or missed dependency");
  }

  // Dump to file for debug
  dumpModule(moduleOp, "pim0");
}
|
|
|
|
/// Lowers one `spat.compute` op. For every result the pass recognizes three
/// "store straight to host" patterns (result feeds the return op directly,
/// through a chain of view-like ops, or through a tensor.concat that feeds the
/// return); each of those becomes a `pim.memcopy_dev_to_host` into the
/// matching host output buffer. Any other result is routed to its consumers
/// through a newly created channel (broadcast when there is more than one
/// compute consumer). Finally the compute body is spliced into a fresh
/// `pim.core` op terminated by `pim.halt`.
void SpatialToPimPass::runOnComputeOp(spatial::SpatCompute computeOp, IRRewriter& rewriter) {
  Location loc = computeOp->getLoc();

  auto& block = computeOp.getRegion().front();
  auto yieldOp = cast<spatial::SpatYieldOp>(block.getTerminator());

  if (computeOp.getNumResults() != yieldOp.getNumOperands())
    llvm_unreachable("ComputeOp must have same number of results as yieldOp operands");

  // Results pair 1:1 with yield operands: `result` lives outside the compute,
  // `yieldValue` is the corresponding value inside the compute body.
  for (auto [result, yieldValue] : llvm::zip(computeOp.getResults(), yieldOp.getOperands())) {
    if (result.use_empty())
      continue;

    auto yieldType = cast<TensorType>(yieldValue.getType());

    auto resultUses = result.getUses();
    auto numResultUses = rangeLength(resultUses);
    if (numResultUses == 1) {
      OpOperand& resultUse = *resultUses.begin();
      Operation* resultUser = resultUse.getOwner();

      // Pattern A: result -> [view-like chain] -> func.return.
      if (isChannelUseChainOp(resultUser)) {
        // Walk the single-use chain of view-like ops, collecting them.
        SmallVector<Operation*> returnChain;
        Value chainedValue = result;
        Operation* chainUser = resultUser;

        while (isChannelUseChainOp(chainUser)) {
          returnChain.push_back(chainUser);
          auto chainUses = chainUser->getResult(0).getUses();
          if (rangeLength(chainUses) != 1)
            break;
          chainedValue = chainUser->getResult(0);
          chainUser = chainUses.begin()->getOwner();
        }

        if (isa<func::ReturnOp>(chainUser)) {
          size_t resultIndexInReturn = chainedValue.getUses().begin()->getOperandNumber();

          // Replay the chain inside the compute body so the final (reshaped)
          // value exists on-device, then copy it out to the host buffer.
          rewriter.setInsertionPoint(yieldOp);
          IRMapping mapping;
          mapping.map(result, yieldValue);

          Value storedValue = yieldValue;
          for (Operation* op : returnChain) {
            cloneMappedHelperOperands(op, mapping, rewriter);
            Operation* clonedOp = rewriter.clone(*op, mapping);
            for (auto [originalResult, newResult] : llvm::zip(op->getResults(), clonedOp->getResults()))
              mapping.map(originalResult, newResult);
            storedValue = clonedOp->getResult(0);
            rewriter.setInsertionPointAfter(clonedOp);
            // The original (outside) chain op becomes dead once the return is
            // rewired; remove it later.
            markOpToRemove(op);
          }

          auto storedType = cast<ShapedType>(storedValue.getType());
          size_t elementSize = storedType.getElementTypeBitWidth() / 8;

          Value outputTensor = outputTensors[resultIndexInReturn];
          if (auto storedOp = storedValue.getDefiningOp())
            rewriter.setInsertionPointAfter(storedOp);
          PimMemCopyDevToHostOp::create(rewriter,
              loc,
              outputTensor.getType(),
              outputTensor,
              storedValue,
              rewriter.getI32IntegerAttr(0),
              rewriter.getI32IntegerAttr(0),
              rewriter.getI32IntegerAttr(storedType.getNumElements() * elementSize));
          continue;
        }
      }

      // Pattern B: result feeds func.return directly — copy the yielded value
      // straight into the host output buffer.
      if (isa<func::ReturnOp>(resultUser)) {
        size_t resultIndexInReturn = resultUse.getOperandNumber();
        size_t offset = 0;
        size_t numElements = yieldType.getNumElements();
        size_t elementSize = yieldType.getElementType().getIntOrFloatBitWidth() / 8;

        // Store to global memory
        Value outputTensor = outputTensors[resultIndexInReturn];
        rewriter.setInsertionPointAfterValue(yieldValue);
        PimMemCopyDevToHostOp::create(rewriter,
            loc,
            outputTensor.getType(),
            outputTensor,
            yieldValue,
            rewriter.getI32IntegerAttr(offset),
            rewriter.getI32IntegerAttr(0),
            rewriter.getI32IntegerAttr(numElements * elementSize));
        continue;
      }

      // Pattern C: result is one operand of a tensor.concat whose (single-use,
      // possibly view-chained) value feeds func.return. Each concat operand is
      // copied into the same host buffer at its byte offset within the concat.
      if (isa<tensor::ConcatOp>(resultUser)) {
        auto concatOp = resultUser;
        auto concatValue = concatOp->getResult(0);
        auto concatUses = concatValue.getUses();
        auto numConcatUses = rangeLength(concatUses);
        if (numConcatUses == 1) {
          Value chainedValue = concatValue;
          Operation* concatUser = concatUses.begin()->getOwner();

          while (isChannelUseChainOp(concatUser)) {
            auto chainUses = concatUser->getResult(0).getUses();
            if (rangeLength(chainUses) != 1)
              break;
            chainedValue = concatUser->getResult(0);
            concatUser = chainUses.begin()->getOwner();
          }

          if (isa<func::ReturnOp>(concatUser)) {
            size_t concatIndexInReturn = chainedValue.getUses().begin()->getOperandNumber();
            size_t resultIndexInConcat = resultUses.begin()->getOperandNumber();
            // Byte offset of this operand = total size of all preceding
            // concat operands (assumes contiguous row-major layout).
            size_t offset = 0;
            for (auto operand : concatOp->getOperands().take_front(resultIndexInConcat))
              offset += cast<ShapedType>(operand.getType()).getNumElements()
                  * cast<ShapedType>(operand.getType()).getElementTypeBitWidth() / 8;

            size_t elementSize = yieldType.getElementTypeBitWidth() / 8;

            // Store to global memory
            Value outputTensor = outputTensors[concatIndexInReturn];
            rewriter.setInsertionPointAfterValue(yieldValue);
            PimMemCopyDevToHostOp::create(rewriter,
                loc,
                outputTensor.getType(),
                outputTensor,
                yieldValue,
                rewriter.getI32IntegerAttr(offset),
                rewriter.getI32IntegerAttr(0),
                rewriter.getI32IntegerAttr(yieldType.getNumElements() * elementSize));
            continue;
          }
        }
      }
    }

    // If this pattern was not found, then create a channel and send the value

    // 1. Create a new ChannelOp
    rewriter.setInsertionPoint(computeOp);
    auto channelType = spatial::SpatChannelType::get(computeOp.getContext());
    auto channelOp = spatial::SpatChannelNewOp::create(rewriter, loc, channelType);

    // 2. Receive value through the channel. Broadcast is needed whenever the
    // value eventually reaches more than one compute consumer, even through a
    // chain of view-like ops.
    bool useBroadcastOp = countComputeLeafUsers(result) > 1;
    addReceiveOps(result, channelOp, useBroadcastOp, rewriter);

    // 3. Send the value through the channel
    rewriter.setInsertionPointAfterValue(yieldValue);
    if (useBroadcastOp)
      spatial::SpatChannelBroadcastSendOp::create(rewriter, loc, channelOp, yieldValue);
    else
      spatial::SpatChannelSendOp::create(rewriter, loc, channelOp, yieldValue);
  }

  // Use `HaltOp` instead of `YieldOp`
  rewriter.setInsertionPoint(yieldOp);
  rewriter.replaceOpWithNewOp<PimHaltOp>(yieldOp);

  // Replace `spat.compute` with `pim.core`: move the (now argument-free)
  // compute body into the new core op, and leave the compute op with a
  // placeholder block so it stays verifier-valid until its deferred erasure.
  rewriter.setInsertionPointAfter(computeOp);
  auto coreOp = PimCoreOp::create(rewriter, loc, computeOp.getWeights(), rewriter.getI32IntegerAttr(coreId++));
  auto& coreOpBlocks = coreOp.getBody().getBlocks();
  block.eraseArguments(0, block.getNumArguments());
  coreOpBlocks.splice(coreOpBlocks.begin(), computeOp.getBody().getBlocks());
  Block* tempComputeBlock = new Block();
  computeOp.getBody().push_back(tempComputeBlock);
  rewriter.setInsertionPointToEnd(tempComputeBlock);
  PimHaltOp::create(rewriter, computeOp.getLoc());
}
|
|
|
|
/// Widens every `pim.vmm` output tensor whose second dimension is smaller
/// than the crossbar width up to `crossbarSize`, then inserts an
/// `tensor.extract_slice` after the op so all original consumers still see
/// the old shape. When the output buffer is tied through a chain of
/// destination-style ops, the whole chain is retyped as well.
void SpatialToPimPass::enlargeVMMOutTensorsToCrossbarSize(func::FuncOp funcOp, IRRewriter& rewriter) {
  // Recursively retypes the init/out operand chain of destination-style ops
  // feeding `value` (self-recursive lambda; stops at non-DPS producers).
  auto enlargeTiedDpsChain = [&](Value value, RankedTensorType newType, auto& self) -> void {
    auto* definingOp = value.getDefiningOp();
    if (!definingOp)
      return;
    auto dpsDefiningOp = dyn_cast<DestinationStyleOpInterface>(definingOp);
    if (!dpsDefiningOp)
      return;
    auto* tiedOperand = dpsDefiningOp.getTiedOpOperand(cast<OpResult>(value));
    if (!tiedOperand)
      return;
    Value tiedValue = tiedOperand->get();
    // Retyping in place is only safe when nobody else reads the old type.
    assert(tiedValue.hasOneUse() && "Tied DPS operand expected to have a single use");
    tiedValue.setType(newType);
    self(tiedValue, newType, self);
  };

  funcOp.walk([&](PimVMMOp vmmOp) {
    auto outTensorOperand = vmmOp.getOutputBuffer();
    auto resultTensor = vmmOp.getOutput();
    auto outShape = getTensorShape(outTensorOperand);
    // VMM outputs are expected to be horizontal vectors: shape (1, N) —
    // enforced by isHVectorShape.
    assert(isHVectorShape(outShape));
    if (outShape[1] != static_cast<int64_t>(crossbarSize)) {
      auto newShape = SmallVector<int64_t> {outShape[0], static_cast<int64_t>(crossbarSize)};
      auto newType = RankedTensorType::get(newShape, outTensorOperand.getType().getElementType());
      if (outTensorOperand == vmmOp.getInput()) {
        // Output buffer aliases the input: we cannot retype the input, so
        // give the op a fresh, correctly sized buffer instead.
        rewriter.setInsertionPoint(vmmOp);
        auto newOutputBuffer =
            tensor::EmptyOp::create(rewriter, vmmOp.getLoc(), newShape, outTensorOperand.getType().getElementType());
        vmmOp.getOutputBufferMutable().assign(newOutputBuffer);
      }
      else {
        // Retype the existing buffer and any tied DPS producers above it.
        enlargeTiedDpsChain(outTensorOperand, newType, enlargeTiedDpsChain);
        outTensorOperand.setType(newType);
      }
      resultTensor.setType(newType);

      // Slice the original (smaller) extent back out so existing consumers
      // keep their expected shape.
      IntegerAttr zeroAttr = rewriter.getIndexAttr(0);
      IntegerAttr oneAttr = rewriter.getIndexAttr(1);
      IntegerAttr oldShapeZeroAttr = rewriter.getIndexAttr(outShape[0]);
      IntegerAttr oldShapeOneAttr = rewriter.getIndexAttr(outShape[1]);
      SmallVector<OpFoldResult> offsets = {zeroAttr, zeroAttr};
      SmallVector<OpFoldResult> sizes = {oldShapeZeroAttr, oldShapeOneAttr};
      SmallVector<OpFoldResult> strides = {oneAttr, oneAttr};
      rewriter.setInsertionPointAfter(vmmOp);
      auto sliceOp = tensor::ExtractSliceOp::create(rewriter, vmmOp.getLoc(), resultTensor, offsets, sizes, strides);
      // Keep the slice itself (and the vmm op) reading the enlarged tensor.
      SmallPtrSet<Operation*, 2> exceptions = {vmmOp, sliceOp};
      resultTensor.replaceAllUsesExcept(sliceOp.getResult(), exceptions);
    }
  });
}
|
|
|
|
/// For each operand of the entry function's return op, records the host
/// tensor that will hold the corresponding result in `outputTensors`:
/// constant-like producers are returned as-is, everything else gets a fresh
/// empty tensor of the same shape (created at the start of the block).
void SpatialToPimPass::addResultBuffer(func::ReturnOp& returnOp, IRRewriter& rewriter) {
  outputTensors.reserve(returnOp->getNumOperands());
  rewriter.setInsertionPointToStart(returnOp->getBlock());
  for (auto returnValue : returnOp->getOperands()) {
    Operation* returnValueDefiningOp = returnValue.getDefiningOp();
    // `returnValueDefiningOp` is null when the function returns one of its
    // block arguments directly; guard the dereference and give such values a
    // fresh host buffer like any other non-constant result.
    if (returnValueDefiningOp && returnValueDefiningOp->hasTrait<OpTrait::ConstantLike>()) {
      // Constants are host-visible already; no separate buffer needed.
      assert(!hasWeightAlways(returnValueDefiningOp));
      outputTensors.push_back(returnValue);
    }
    else {
      auto newOutputTensor =
          createEmptyTensorFromShaped(rewriter, returnValue.getLoc(), cast<ShapedType>(returnValue.getType()));
      outputTensors.push_back(newOutputTensor);
    }
  }
}
|
|
|
|
/// Rewrites the entry function's tensor arguments into memref arguments
/// (bridged back to tensors via `bufferization.to_tensor`), then inserts a
/// `pim.memcopy_host_to_dev` for every compute input that originates from the
/// host (directly or through a contiguous `tensor.extract_slice`). Inputs
/// produced by other computes are skipped — those travel through channels.
LogicalResult SpatialToPimPass::allocateAndInitializeCoreLocalVariables(func::FuncOp funcOp, IRRewriter& rewriter) {
  Location loc = funcOp.getLoc();

  // Materializes `hostTensor` on-device and redirects all uses of
  // `valueToReplace` to the copied device tensor. `elementsOffset` is the
  // element offset into the host tensor (converted to bytes below).
  auto insertMemCopyHostToDev = [&](auto valueToReplace, auto hostTensor, int64_t elementsOffset) {
    auto tensorType = cast<ShapedType>(valueToReplace.getType());
    Type elementType = tensorType.getElementType();
    size_t elementByteSize = elementType.getIntOrFloatBitWidth() / 8;
    // Insert right before the first use so the copy dominates all consumers.
    rewriter.setInsertionPoint(getEarliestUserWithinBlock(valueToReplace));

    auto deviceTensor = tensor::EmptyOp::create(rewriter, loc, tensorType.getShape(), elementType);

    auto memCopyHostToDevOp = PimMemCopyHostToDevOp::create(
        rewriter,
        loc,
        tensorType,
        deviceTensor,
        hostTensor,
        rewriter.getI32IntegerAttr(0),
        rewriter.getI32IntegerAttr(static_cast<int32_t>(elementsOffset * elementByteSize)),
        rewriter.getI32IntegerAttr(static_cast<int32_t>(tensorType.getNumElements() * elementByteSize)));

    rewriter.replaceAllUsesWith(valueToReplace, memCopyHostToDevOp.getResult());
  };

  // Replace input tensors with memRefs. Inserting the memref argument at
  // i + 1 and then erasing the tensor argument at i keeps the loop index
  // stable across iterations.
  for (size_t i = 0; i < funcOp.getNumArguments(); i++) {
    BlockArgument tensorArg = funcOp.getArgument(i);
    DictionaryAttr tensorArgAttrs = funcOp.getArgAttrDict(i);
    ShapedType tensorArgType = cast<ShapedType>(tensorArg.getType());
    MemRefType memRefArgType = MemRefType::get(tensorArgType.getShape(), tensorArgType.getElementType());

    if (failed(funcOp.insertArgument(i + 1, memRefArgType, tensorArgAttrs, loc)))
      return funcOp.emitError("failed to insert memref argument during Spatial-to-Pim lowering");
    BlockArgument memRefArg = funcOp.getArgument(i + 1);

    Block& block = funcOp.getBody().front();
    rewriter.setInsertionPoint(&block.front());
    auto toTensorOp =
        bufferization::ToTensorOp::create(rewriter, loc, tensorArgType, memRefArg, rewriter.getUnitAttr());

    tensorArg.replaceAllUsesWith(toTensorOp);
    if (failed(funcOp.eraseArgument(i)))
      return funcOp.emitError("failed to erase tensor argument during Spatial-to-Pim lowering");
  }

  llvm::SmallSet<tensor::ExtractSliceOp, 8> sliceOpsToRemove;
  for (auto& op : funcOp.getBody().getOps())
    if (auto computeOp = dyn_cast<spatial::SpatCompute>(op)) {
      unsigned numComputeWeights = computeOp.getWeights().size();
      for (auto [computeInputIdx, computeOpInput] : llvm::enumerate(computeOp.getInputs())) {
        TypedValue<TensorType> tensorSource;
        int64_t elementsOffset = 0;

        // Null-safe lookup: `Value::getDefiningOp<OpTy>()` returns null for
        // block arguments, where a plain dyn_cast on getDefiningOp() would
        // dereference null.
        if (auto sliceOp = computeOpInput.getDefiningOp<tensor::ExtractSliceOp>()) {
          tensorSource = cast<TypedValue<TensorType>>(sliceOp.getSource());

          // Slices of compute results travel through channels instead.
          if (tensorSource.getDefiningOp<spatial::SpatCompute>())
            continue;

          ArrayRef<int64_t> sourceShape = tensorSource.getType().getShape();
          ArrayRef<int64_t> sliceOffsets = sliceOp.getStaticOffsets();
          ArrayRef<int64_t> sliceSizes = sliceOp.getStaticSizes();
          ArrayRef<int64_t> sliceStrides = sliceOp.getStaticStrides();
          assert("Extracting slice non-contiguous in memory"
              && isMemoryContiguous(sourceShape, sliceOffsets, sliceSizes, sliceStrides));

          // Linearize the multi-dimensional slice offset (row-major order).
          for (size_t i = 0; i < sliceOffsets.size(); i++) {
            int64_t partialOffset = sliceOffsets[i];
            if (partialOffset != 0)
              for (size_t j = i + 1; j < sourceShape.size(); j++)
                partialOffset *= sourceShape[j];
            elementsOffset += partialOffset;
          }

          // Feed the compute from the full source tensor; the offset is
          // carried by the host->device copy instead of the slice op.
          computeOp.setOperand(numComputeWeights + computeInputIdx, tensorSource);
          sliceOpsToRemove.insert(sliceOp);
        }
        else
          tensorSource = cast<TypedValue<TensorType>>(computeOpInput);

        // Compute results must be transferred through channels via send/receive
        if (tensorSource.getDefiningOp<spatial::SpatCompute>())
          continue;

        BlockArgument computeBlockArgToReplace = computeOp.getBody().front().getArgument(computeInputIdx);
        insertMemCopyHostToDev(computeBlockArgToReplace, tensorSource, elementsOffset);
      }
    }

  // Clean up slice ops that became dead after being bypassed above.
  for (auto sliceOp : sliceOpsToRemove)
    if (sliceOp->getUses().empty())
      rewriter.eraseOp(sliceOp);

  return success();
}
|
|
|
|
/// Replaces a consumer compute's block argument with a value received over
/// `channel`. If the value reached the compute through a chain of view-like
/// ops (`consumerValue` != `channelSourceOp`), that chain is replayed inside
/// the consumer on top of the received value so the types line up.
void SpatialToPimPass::replaceBlockArgumentWithRecvOp(spatial::SpatCompute& computeOp,
    unsigned int argIndex,
    Value channelSourceOp,
    Value consumerValue,
    spatial::SpatChannelNewOp& channel,
    bool useBroadcastOp,
    IRRewriter& rewriter) {
  auto& computeBlock = computeOp.getRegion().front();
  //(remember that WeightedCompute have weights as first operands, however these
  // weights are not included in the block arguments. Thus, when indexing the
  // block argument we need to remove the weights count)
  auto computeWeightsCount = computeOp.getWeights().size();
  auto blockArg = computeBlock.getArgument(argIndex - computeWeightsCount);
  // Receive the tensor just before the first use of the value
  rewriter.setInsertionPoint(getEarliestUserWithinBlock(blockArg));
  Value receivedValue;
  if (useBroadcastOp)
    receivedValue =
        spatial::SpatChannelBroadcastReceiveOp::create(rewriter, computeOp.getLoc(), channelSourceOp.getType(), channel);
  else
    receivedValue =
        spatial::SpatChannelReceiveOp::create(rewriter, computeOp.getLoc(), channelSourceOp.getType(), channel);

  Value replacementValue = receivedValue;
  if (consumerValue != channelSourceOp) {
    // Collect the view-like ops between the channel source and the consumer
    // value by walking the def chain backwards (consumer -> source).
    SmallVector<Operation*> clonedChain;
    Value currentValue = consumerValue;
    while (currentValue != channelSourceOp) {
      Operation* definingOp = currentValue.getDefiningOp();
      if (!definingOp || !isChannelUseChainOp(definingOp))
        llvm_unreachable("Unsupported channel use chain while replaying value into consumer compute");

      clonedChain.push_back(definingOp);
      currentValue = definingOp->getOperand(0);
    }

    // Replay the chain (source -> consumer order) on the received value.
    IRMapping mapping;
    mapping.map(channelSourceOp, receivedValue);
    for (Operation* op : llvm::reverse(clonedChain)) {
      cloneMappedHelperOperands(op, mapping, rewriter);
      Operation* clonedOp = rewriter.clone(*op, mapping);
      for (auto [originalResult, newResult] : llvm::zip(op->getResults(), clonedOp->getResults()))
        mapping.map(originalResult, newResult);
      // The original chain op outside the consumer becomes dead later.
      markOpToRemove(op);
    }

    replacementValue = cast<Value>(mapping.lookup(consumerValue));
  }

  assert(replacementValue.getType() == blockArg.getType()
      && "Replayed channel use chain must match block argument type");
  blockArg.replaceAllUsesWith(replacementValue);
}
|
|
|
|
/// Walks every (transitive) use of `channelSourceOp`: each compute consumer
/// gets its block argument replaced with a channel receive (replaying any
/// intervening view-like chain), and each view-like chain op encountered is
/// marked for deferred removal.
void SpatialToPimPass::addReceiveOps(Value channelSourceOp,
    spatial::SpatChannelNewOp& channel,
    bool useBroadcastOp,
    IRRewriter& rewriter) {
  // Self-recursive lambda over the use chain.
  auto replayUsesIntoConsumers = [&](Value currentValue, auto& self) -> void {
    for (OpOperand& use : currentValue.getUses()) {
      Operation* owner = use.getOwner();
      if (auto computeUser = dyn_cast<spatial::SpatCompute>(owner)) {
        // use.getOperandNumber() is the compute operand index (weights
        // included); replaceBlockArgumentWithRecvOp adjusts for weights.
        replaceBlockArgumentWithRecvOp(
            computeUser, use.getOperandNumber(), channelSourceOp, currentValue, channel, useBroadcastOp, rewriter);
        continue;
      }

      if (!isChannelUseChainOp(owner))
        llvm_unreachable("User of channel-carried value is not a compute nor a supported view-like op");

      markOpToRemove(owner);
      assert(owner->getNumResults() == 1 && "Channel use chain op must have a single result");
      self(owner->getResult(0), self);
    }
  };

  replayUsesIntoConsumers(channelSourceOp, replayUsesIntoConsumers);
}
|
|
|
|
/// Tracks `op` for deferred erasure at the end of runOnOperation.
/// Idempotent: an op already tracked is not added twice.
void SpatialToPimPass::markOpToRemove(Operation* op) {
  const bool alreadyTracked = llvm::is_contained(operationsToRemove, op);
  if (alreadyTracked)
    return;
  operationsToRemove.push_back(op);
}
|
|
|
|
/// Annotates every `spat.channel_new` with the core ids of its endpoints:
/// broadcast channels get only a source core id (targets are found per
/// receive), point-to-point channels get both source and target core ids,
/// looked up from the `pim.core` ops enclosing the send/receive.
void SpatialToPimPass::annotateChannelCoreIds(func::FuncOp funcOp) {
  funcOp.walk([&](spatial::SpatChannelNewOp channelNewOp) {
    // Channel ops are lowered away eventually; schedule removal up front.
    markOpToRemove(channelNewOp);

    if (channelNewOp->use_empty())
      return;

    // Classify the channel's users; at most one of each kind is expected.
    spatial::SpatChannelSendOp sendOp;
    spatial::SpatChannelReceiveOp receiveOp;
    spatial::SpatChannelBroadcastSendOp broadcastSendOp;

    for (Operation* user : channelNewOp->getUsers()) {
      if (auto op = dyn_cast<spatial::SpatChannelSendOp>(user)) {
        sendOp = op;
        continue;
      }
      if (auto op = dyn_cast<spatial::SpatChannelReceiveOp>(user)) {
        receiveOp = op;
        continue;
      }
      if (auto op = dyn_cast<spatial::SpatChannelBroadcastSendOp>(user)) {
        broadcastSendOp = op;
        continue;
      }
      // Broadcast receives are handled per-receiver in
      // lowerBroadcastChannelOps; nothing to record here.
      if (auto op = dyn_cast<spatial::SpatChannelBroadcastReceiveOp>(user))
        continue;
      llvm_unreachable("Unexpected user of spat.channel_new during Spatial-to-PIM lowering");
    }

    if (broadcastSendOp) {
      // Broadcast channel: only the source core id is attached.
      auto sourceCoreIdAttr = cast<pim::PimCoreOp>(broadcastSendOp->getParentOp()).getCoreIdAttr();
      channelNewOp->setAttr(kChannelSourceCoreIdAttrName, sourceCoreIdAttr);
      return;
    }

    if (!sendOp || !receiveOp)
      llvm_unreachable("spat.channel_new must connect exactly one send and one receive");

    // Point-to-point channel: record both endpoints' core ids.
    auto sourceCoreIdAttr = cast<pim::PimCoreOp>(sendOp->getParentOp()).getCoreIdAttr();
    auto targetCoreIdAttr = cast<pim::PimCoreOp>(receiveOp->getParentOp()).getCoreIdAttr();
    channelNewOp->setAttr(kChannelSourceCoreIdAttrName, sourceCoreIdAttr);
    channelNewOp->setAttr(kChannelTargetCoreIdAttrName, targetCoreIdAttr);
  });
}
|
|
|
|
/// Lowers broadcast channel ops: each `spat.channel_broadcast_send` becomes
/// one `pim.send` per broadcast receiver on the same channel, and each
/// `spat.channel_broadcast_receive` becomes a `pim.receive` into a fresh
/// buffer. Ops are collected first so the walk does not iterate over IR that
/// is being mutated.
void SpatialToPimPass::lowerBroadcastChannelOps(func::FuncOp funcOp, IRRewriter& rewriter) {
  SmallVector<spatial::SpatChannelBroadcastSendOp> broadcastSendOps;
  funcOp.walk([&](spatial::SpatChannelBroadcastSendOp op) { broadcastSendOps.push_back(op); });

  for (auto sendOp : broadcastSendOps) {
    auto channelNewOp = cast<spatial::SpatChannelNewOp>(sendOp.getChannel().getDefiningOp());
    auto sizeAttr = getTensorSizeInBytesAttr(rewriter, sendOp.getInput());

    // Fan the broadcast out: one point-to-point pim.send per receiver core.
    rewriter.setInsertionPoint(sendOp);
    bool foundReceiver = false;
    for (Operation* user : channelNewOp->getUsers()) {
      auto receiveOp = dyn_cast<spatial::SpatChannelBroadcastReceiveOp>(user);
      if (!receiveOp)
        continue;

      foundReceiver = true;
      auto targetCoreIdAttr = cast<pim::PimCoreOp>(receiveOp->getParentOp()).getCoreIdAttr();
      PimSendOp::create(rewriter, sendOp.getLoc(), sendOp.getInput(), sizeAttr, targetCoreIdAttr);
    }

    if (!foundReceiver)
      llvm_unreachable("spat.channel_broadcast_send has no matching broadcast receive");

    rewriter.eraseOp(sendOp);
  }

  SmallVector<spatial::SpatChannelBroadcastReceiveOp> broadcastReceiveOps;
  funcOp.walk([&](spatial::SpatChannelBroadcastReceiveOp op) { broadcastReceiveOps.push_back(op); });

  for (auto receiveOp : broadcastReceiveOps) {
    rewriter.setInsertionPoint(receiveOp);
    // pim.receive writes into an explicit destination buffer; allocate one
    // matching the received value's shape.
    auto outputType = cast<ShapedType>(receiveOp.getResult().getType());
    Value outputBuffer = createEmptyTensorFromShaped(rewriter, receiveOp.getLoc(), outputType);
    auto sizeAttr = getTensorSizeInBytesAttr(rewriter, receiveOp.getResult());
    // Source core id was stamped on the channel by annotateChannelCoreIds.
    auto sourceCoreIdAttr = getSpatialChannelSourceCoreIdAttr(rewriter, receiveOp.getChannel());
    Value receivedValue =
        PimReceiveOp::create(
            rewriter, receiveOp.getLoc(), outputBuffer.getType(), outputBuffer, sizeAttr, sourceCoreIdAttr)
            .getOutput();
    rewriter.replaceOp(receiveOp, receivedValue);
  }
}
|
|
|
|
/// Rewires each return operand to the corresponding host output buffer
/// (collected earlier in `outputTensors`), then walks the old producer chain
/// (view-like ops and tensor.concat) marking ops for deferred removal as long
/// as they are used exclusively by the return chain.
void SpatialToPimPass::replaceReturnOpOperands(func::ReturnOp& returnOp, IRRewriter& rewriter) {
  // Snapshot operands first: setOperand below mutates the list we iterate.
  SmallVector<Value> originalOperands(returnOp.getOperands().begin(), returnOp.getOperands().end());
  for (auto it : llvm::enumerate(originalOperands)) {
    size_t orderWithinReturn = it.index();
    Operation* returnOperand = it.value().getDefiningOp();

    rewriter.modifyOpInPlace(returnOp,
        [&] { returnOp.setOperand(orderWithinReturn, outputTensors[orderWithinReturn]); });

    // Walk up the now-bypassed producer chain, marking ops that nothing
    // outside the return chain still needs.
    Operation* opToErase = returnOperand;
    while (opToErase) {
      bool isExclusivelyOwnedByReturnChain = opToErase->use_empty();
      if (!isExclusivelyOwnedByReturnChain && opToErase->hasOneUse()) {
        Operation* onlyUser = *opToErase->getUsers().begin();
        isExclusivelyOwnedByReturnChain =
            isa<func::ReturnOp, tensor::ConcatOp>(onlyUser) || isChannelUseChainOp(onlyUser);
      }
      // Shared with other consumers: stop, the chain must stay alive.
      if (!isExclusivelyOwnedByReturnChain)
        break;

      // View-like op: mark it and continue up through its single source.
      if (isChannelUseChainOp(opToErase)) {
        Value source = opToErase->getOperand(0);
        markOpToRemove(opToErase);
        opToErase = source.getDefiningOp();
        continue;
      }

      // Concats are marked but not walked through (multiple operands).
      if (isa<tensor::ConcatOp>(opToErase))
        markOpToRemove(opToErase);
      break;
    }
  }
}
|
|
|
|
/// Re-routes an existing `spat.channel_receive` so its value reaches every
/// compute consumer through per-consumer receive ops. If the received value
/// fans out to more than one compute, the matching send is upgraded to a
/// broadcast send.
void SpatialToPimPass::runOnReceiveOp(spatial::SpatChannelReceiveOp receiveOp, IRRewriter& rewriter) {

  auto channel = cast<spatial::SpatChannelNewOp>(receiveOp.getChannel().getDefiningOp());

  auto sendOpOpt = getOtherEndOfChannel(receiveOp, true, rewriter);
  if (failed(sendOpOpt))
    llvm_unreachable("ChannelReceiveOp has no matching SendOp");

  auto sendOp = cast<spatial::SpatChannelSendOp>(*sendOpOpt);

  Value receiveRes = receiveOp.getResult();

  // Fan-out > 1 (counting through view-like chains) requires a broadcast.
  bool useBroadcastOp = countComputeLeafUsers(receiveRes) > 1;
  addReceiveOps(receiveRes, channel, useBroadcastOp, rewriter);

  if (useBroadcastOp) {
    // When receiving, we actually noticed that the value has more than one
    // user. This means that we need to get the replace the original SendOp with
    // a BroadcastSendOp
    rewriter.setInsertionPoint(sendOp);
    rewriter.replaceOpWithNewOp<spatial::SpatChannelBroadcastSendOp>(sendOp, sendOp.getChannel(), sendOp.getInput());
  }
}
|
|
|
|
/// Factory registered with the pass pipeline: builds a fresh
/// SpatialToPimPass instance.
std::unique_ptr<Pass> createSpatialToPimPass() {
  auto pass = std::make_unique<SpatialToPimPass>();
  return pass;
}
|
|
|
|
} // namespace onnx_mlir
|