This commit is contained in:
@@ -7,8 +7,8 @@
|
|||||||
#include "mlir/IR/Matchers.h"
|
#include "mlir/IR/Matchers.h"
|
||||||
|
|
||||||
#include "Conversion/ONNXToSpatial/Common/Common.hpp"
|
#include "Conversion/ONNXToSpatial/Common/Common.hpp"
|
||||||
|
#include "Conversion/SpatialToPim/SpatialToPimPass.hpp"
|
||||||
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
|
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
|
||||||
#include "src/Accelerators/PIM/Conversion/SpatialToPim/BatchCoreLoweringPatterns.hpp"
|
|
||||||
#include "src/Accelerators/PIM/Conversion/SpatialToPim/Common.hpp"
|
#include "src/Accelerators/PIM/Conversion/SpatialToPim/Common.hpp"
|
||||||
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
|
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
|
||||||
|
|
||||||
@@ -152,8 +152,8 @@ static Value createHostTargetOffset(IRRewriter& rewriter,
|
|||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
LogicalResult
|
LogicalResult raptor::SpatialToPimPass::lowerComputeBatchOp(spatial::SpatComputeBatch computeBatchOp,
|
||||||
lowerComputeBatchOp(spatial::SpatComputeBatch computeBatchOp, CoreLoweringState& state, IRRewriter& rewriter) {
|
IRRewriter& rewriter) {
|
||||||
Location loc = computeBatchOp.getLoc();
|
Location loc = computeBatchOp.getLoc();
|
||||||
Block& oldBlock = computeBatchOp.getBody().front();
|
Block& oldBlock = computeBatchOp.getBody().front();
|
||||||
auto oldYield = dyn_cast<spatial::SpatYieldOp>(oldBlock.getTerminator());
|
auto oldYield = dyn_cast<spatial::SpatYieldOp>(oldBlock.getTerminator());
|
||||||
@@ -167,7 +167,7 @@ lowerComputeBatchOp(spatial::SpatComputeBatch computeBatchOp, CoreLoweringState&
|
|||||||
"resultful compute_batch lowering currently requires a spat.in_parallel terminator");
|
"resultful compute_batch lowering currently requires a spat.in_parallel terminator");
|
||||||
}
|
}
|
||||||
|
|
||||||
SmallVector<int32_t> coreIds = getPimCoreIdsForBatchOp(computeBatchOp, state.nextCoreId);
|
SmallVector<int32_t> coreIds = getPimCoreIdsForBatchOp(computeBatchOp, coreId);
|
||||||
SmallVector<Value> batchWeights(computeBatchOp.getWeights().begin(), computeBatchOp.getWeights().end());
|
SmallVector<Value> batchWeights(computeBatchOp.getWeights().begin(), computeBatchOp.getWeights().end());
|
||||||
SmallVector<Value> batchInputs;
|
SmallVector<Value> batchInputs;
|
||||||
if (!computeBatchOp.getInputs().empty())
|
if (!computeBatchOp.getInputs().empty())
|
||||||
@@ -192,7 +192,7 @@ lowerComputeBatchOp(spatial::SpatComputeBatch computeBatchOp, CoreLoweringState&
|
|||||||
return computeBatchOp.emitOpError(
|
return computeBatchOp.emitOpError(
|
||||||
"resultful compute_batch lowering currently requires each result to be used directly by func.return");
|
"resultful compute_batch lowering currently requires each result to be used directly by func.return");
|
||||||
|
|
||||||
hostOutputTensors[resultIndex] = state.outputTensors[*returnOperandIndex](rewriter, loc);
|
hostOutputTensors[resultIndex] = outputTensors[*returnOperandIndex](rewriter, loc);
|
||||||
result.replaceAllUsesWith(hostOutputTensors[resultIndex]);
|
result.replaceAllUsesWith(hostOutputTensors[resultIndex]);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,10 +0,0 @@
|
|||||||
#pragma once
|
|
||||||
|
|
||||||
#include "src/Accelerators/PIM/Conversion/SpatialToPim/CoreLoweringPatterns.hpp"
|
|
||||||
|
|
||||||
namespace onnx_mlir {
|
|
||||||
|
|
||||||
mlir::LogicalResult
|
|
||||||
lowerComputeBatchOp(spatial::SpatComputeBatch computeBatchOp, CoreLoweringState& state, mlir::IRRewriter& rewriter);
|
|
||||||
|
|
||||||
} // namespace onnx_mlir
|
|
||||||
@@ -6,9 +6,9 @@
|
|||||||
#include "mlir/IR/Matchers.h"
|
#include "mlir/IR/Matchers.h"
|
||||||
|
|
||||||
#include "Conversion/ONNXToSpatial/Common/Common.hpp"
|
#include "Conversion/ONNXToSpatial/Common/Common.hpp"
|
||||||
|
#include "Conversion/SpatialToPim/SpatialToPimPass.hpp"
|
||||||
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
|
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
|
||||||
#include "src/Accelerators/PIM/Conversion/SpatialToPim/Common.hpp"
|
#include "src/Accelerators/PIM/Conversion/SpatialToPim/Common.hpp"
|
||||||
#include "src/Accelerators/PIM/Conversion/SpatialToPim/CoreLoweringPatterns.hpp"
|
|
||||||
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
|
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
|
||||||
#include "src/Dialect/ONNX/ONNXOps.hpp"
|
#include "src/Dialect/ONNX/ONNXOps.hpp"
|
||||||
|
|
||||||
@@ -148,15 +148,12 @@ static bool inlineInputlessHelperComputeForWeightLikeUsers(spatial::SpatCompute
|
|||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
void markOpToRemove(CoreLoweringState& state, Operation* op) {
|
LogicalResult raptor::SpatialToPimPass::lowerComputeOp(spatial::SpatCompute computeOp,
|
||||||
if (!llvm::is_contained(state.operationsToRemove, op))
|
IRRewriter& rewriter,
|
||||||
state.operationsToRemove.push_back(op);
|
OperationFolder& constantFolder) {
|
||||||
}
|
|
||||||
|
|
||||||
LogicalResult lowerComputeOp(spatial::SpatCompute computeOp, CoreLoweringState& state, IRRewriter& rewriter) {
|
|
||||||
Location loc = computeOp->getLoc();
|
Location loc = computeOp->getLoc();
|
||||||
|
|
||||||
if (inlineInputlessHelperComputeForWeightLikeUsers(computeOp, rewriter, state.constantFolder))
|
if (inlineInputlessHelperComputeForWeightLikeUsers(computeOp, rewriter, constantFolder))
|
||||||
return success();
|
return success();
|
||||||
|
|
||||||
SmallVector<Operation*> helperChain;
|
SmallVector<Operation*> helperChain;
|
||||||
@@ -179,7 +176,7 @@ LogicalResult lowerComputeOp(spatial::SpatCompute computeOp, CoreLoweringState&
|
|||||||
rewriter, receiveOp.getLoc(), outputBuffer.getType(), outputBuffer, sizeAttr, receiveOp.getSourceCoreId())
|
rewriter, receiveOp.getLoc(), outputBuffer.getType(), outputBuffer, sizeAttr, receiveOp.getSourceCoreId())
|
||||||
.getOutput();
|
.getOutput();
|
||||||
blockArg.replaceAllUsesWith(received);
|
blockArg.replaceAllUsesWith(received);
|
||||||
markOpToRemove(state, receiveOp);
|
markOpToRemove(receiveOp);
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -200,7 +197,7 @@ LogicalResult lowerComputeOp(spatial::SpatCompute computeOp, CoreLoweringState&
|
|||||||
rewriter.getDenseI32ArrayAttr(*sourceCoreIds))
|
rewriter.getDenseI32ArrayAttr(*sourceCoreIds))
|
||||||
.getOutput();
|
.getOutput();
|
||||||
blockArg.replaceAllUsesWith(received);
|
blockArg.replaceAllUsesWith(received);
|
||||||
markOpToRemove(state, receiveTensorOp);
|
markOpToRemove(receiveTensorOp);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -211,9 +208,8 @@ LogicalResult lowerComputeOp(spatial::SpatCompute computeOp, CoreLoweringState&
|
|||||||
if (result.use_empty())
|
if (result.use_empty())
|
||||||
continue;
|
continue;
|
||||||
|
|
||||||
ReturnPathState returnPathState {state.outputTensors, state.operationsToRemove};
|
|
||||||
ReturnPathLoweringResult returnPathResult =
|
ReturnPathLoweringResult returnPathResult =
|
||||||
lowerComputeResultReturnPath(computeOp, cast<OpResult>(result), yieldValue, returnPathState, rewriter);
|
lowerComputeResultReturnPath(computeOp, cast<OpResult>(result), yieldValue, rewriter);
|
||||||
if (returnPathResult == ReturnPathLoweringResult::Failure)
|
if (returnPathResult == ReturnPathLoweringResult::Failure)
|
||||||
return failure();
|
return failure();
|
||||||
if (returnPathResult == ReturnPathLoweringResult::Handled)
|
if (returnPathResult == ReturnPathLoweringResult::Handled)
|
||||||
@@ -240,7 +236,7 @@ LogicalResult lowerComputeOp(spatial::SpatCompute computeOp, CoreLoweringState&
|
|||||||
auto coreOp = PimCoreOp::create(rewriter,
|
auto coreOp = PimCoreOp::create(rewriter,
|
||||||
loc,
|
loc,
|
||||||
ValueRange(computeWeights),
|
ValueRange(computeWeights),
|
||||||
rewriter.getI32IntegerAttr(getPimCoreIdForComputeOp(computeOp, state.nextCoreId)));
|
rewriter.getI32IntegerAttr(getPimCoreIdForComputeOp(computeOp, coreId)));
|
||||||
rewriter.setInsertionPointToStart(&block);
|
rewriter.setInsertionPointToStart(&block);
|
||||||
auto& coreOpBlocks = coreOp.getBody().getBlocks();
|
auto& coreOpBlocks = coreOp.getBody().getBlocks();
|
||||||
for (auto [inputIndex, input] : llvm::enumerate(computeOp.getInputs())) {
|
for (auto [inputIndex, input] : llvm::enumerate(computeOp.getInputs())) {
|
||||||
@@ -249,7 +245,7 @@ LogicalResult lowerComputeOp(spatial::SpatCompute computeOp, CoreLoweringState&
|
|||||||
continue;
|
continue;
|
||||||
|
|
||||||
if (auto constantOp = input.getDefiningOp<arith::ConstantOp>()) {
|
if (auto constantOp = input.getDefiningOp<arith::ConstantOp>()) {
|
||||||
blockArg.replaceAllUsesWith(getOrCreateHostConstantLike(constantOp, state.constantFolder));
|
blockArg.replaceAllUsesWith(getOrCreateHostConstantLike(constantOp, constantFolder));
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -261,8 +257,8 @@ LogicalResult lowerComputeOp(spatial::SpatCompute computeOp, CoreLoweringState&
|
|||||||
PimMemCopyHostToDevOp::create(rewriter,
|
PimMemCopyHostToDevOp::create(rewriter,
|
||||||
loc,
|
loc,
|
||||||
outputBuffer.getType(),
|
outputBuffer.getType(),
|
||||||
getOrCreateHostIndexConstant(outputBuffer.getOperation(), 0, state.constantFolder),
|
getOrCreateHostIndexConstant(outputBuffer.getOperation(), 0, constantFolder),
|
||||||
getOrCreateHostIndexConstant(outputBuffer.getOperation(), 0, state.constantFolder),
|
getOrCreateHostIndexConstant(outputBuffer.getOperation(), 0, constantFolder),
|
||||||
outputBuffer,
|
outputBuffer,
|
||||||
input,
|
input,
|
||||||
getTensorSizeInBytesAttr(rewriter, input))
|
getTensorSizeInBytesAttr(rewriter, input))
|
||||||
|
|||||||
@@ -1,23 +0,0 @@
|
|||||||
#pragma once
|
|
||||||
|
|
||||||
#include "mlir/IR/PatternMatch.h"
|
|
||||||
#include "mlir/Transforms/FoldUtils.h"
|
|
||||||
|
|
||||||
#include "src/Accelerators/PIM/Conversion/SpatialToPim/ReturnPathNormalization.hpp"
|
|
||||||
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
|
||||||
|
|
||||||
namespace onnx_mlir {
|
|
||||||
|
|
||||||
struct CoreLoweringState {
|
|
||||||
size_t& nextCoreId;
|
|
||||||
llvm::SmallVectorImpl<OutputTensorFactory>& outputTensors;
|
|
||||||
llvm::SmallVectorImpl<mlir::Operation*>& operationsToRemove;
|
|
||||||
mlir::OperationFolder& constantFolder;
|
|
||||||
};
|
|
||||||
|
|
||||||
void markOpToRemove(CoreLoweringState& state, mlir::Operation* op);
|
|
||||||
|
|
||||||
mlir::LogicalResult
|
|
||||||
lowerComputeOp(spatial::SpatCompute computeOp, CoreLoweringState& state, mlir::IRRewriter& rewriter);
|
|
||||||
|
|
||||||
} // namespace onnx_mlir
|
|
||||||
@@ -9,9 +9,9 @@
|
|||||||
#include "mlir/Transforms/FoldUtils.h"
|
#include "mlir/Transforms/FoldUtils.h"
|
||||||
|
|
||||||
#include "Conversion/ONNXToSpatial/Common/Common.hpp"
|
#include "Conversion/ONNXToSpatial/Common/Common.hpp"
|
||||||
|
#include "Conversion/SpatialToPim/SpatialToPimPass.hpp"
|
||||||
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
|
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
|
||||||
#include "src/Accelerators/PIM/Conversion/SpatialToPim/Common.hpp"
|
#include "src/Accelerators/PIM/Conversion/SpatialToPim/Common.hpp"
|
||||||
#include "src/Accelerators/PIM/Conversion/SpatialToPim/ReturnPathNormalization.hpp"
|
|
||||||
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
|
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
|
||||||
#include "src/Dialect/ONNX/ONNXOps.hpp"
|
#include "src/Dialect/ONNX/ONNXOps.hpp"
|
||||||
|
|
||||||
@@ -44,11 +44,6 @@ static bool isReturnHelperChainOp(Operation* op) {
|
|||||||
pim::PimTransposeOp>(op);
|
pim::PimTransposeOp>(op);
|
||||||
}
|
}
|
||||||
|
|
||||||
static void markOpToRemove(ReturnPathState& state, Operation* op) {
|
|
||||||
if (!llvm::is_contained(state.operationsToRemove, op))
|
|
||||||
state.operationsToRemove.push_back(op);
|
|
||||||
}
|
|
||||||
|
|
||||||
static std::string makeUniqueSymbolName(Operation* symbolTableOp, StringRef baseName) {
|
static std::string makeUniqueSymbolName(Operation* symbolTableOp, StringRef baseName) {
|
||||||
std::string name = baseName.str();
|
std::string name = baseName.str();
|
||||||
unsigned suffix = 0;
|
unsigned suffix = 0;
|
||||||
@@ -390,9 +385,7 @@ static Value emitHostCopy(IRRewriter& rewriter,
|
|||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
void addReturnOutputBuffers(func::ReturnOp returnOp,
|
void raptor::SpatialToPimPass::addReturnOutputBuffers(func::ReturnOp returnOp, IRRewriter& rewriter) {
|
||||||
IRRewriter& rewriter,
|
|
||||||
SmallVectorImpl<OutputTensorFactory>& outputTensors) {
|
|
||||||
outputTensors.reserve(returnOp->getNumOperands());
|
outputTensors.reserve(returnOp->getNumOperands());
|
||||||
for (auto [index, returnValue] : llvm::enumerate(returnOp->getOperands())) {
|
for (auto [index, returnValue] : llvm::enumerate(returnOp->getOperands())) {
|
||||||
Value currentReturnValue = returnValue;
|
Value currentReturnValue = returnValue;
|
||||||
@@ -427,8 +420,8 @@ void addReturnOutputBuffers(func::ReturnOp returnOp,
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
ReturnPathLoweringResult lowerProducedValueReturnPath(
|
raptor::SpatialToPimPass::ReturnPathLoweringResult raptor::SpatialToPimPass::lowerProducedValueReturnPath(
|
||||||
Operation* producerOp, Value producedValue, Value storedValue, ReturnPathState& state, IRRewriter& rewriter) {
|
Operation* producerOp, Value producedValue, Value storedValue, IRRewriter& rewriter) {
|
||||||
Location loc = producerOp->getLoc();
|
Location loc = producerOp->getLoc();
|
||||||
OperationFolder constantFolder(producerOp->getContext());
|
OperationFolder constantFolder(producerOp->getContext());
|
||||||
auto storedTensorType = cast<TensorType>(storedValue.getType());
|
auto storedTensorType = cast<TensorType>(storedValue.getType());
|
||||||
@@ -437,13 +430,13 @@ ReturnPathLoweringResult lowerProducedValueReturnPath(
|
|||||||
Value currentStoredValue = storedValue;
|
Value currentStoredValue = storedValue;
|
||||||
cloneHelperChain(storedValue, returnUse->helperChain, rewriter, constantFolder, currentStoredValue);
|
cloneHelperChain(storedValue, returnUse->helperChain, rewriter, constantFolder, currentStoredValue);
|
||||||
for (Operation* op : returnUse->helperChain)
|
for (Operation* op : returnUse->helperChain)
|
||||||
markOpToRemove(state, op);
|
markOpToRemove(op);
|
||||||
|
|
||||||
auto storedType = cast<ShapedType>(currentStoredValue.getType());
|
auto storedType = cast<ShapedType>(currentStoredValue.getType());
|
||||||
size_t elementSize = storedType.getElementTypeBitWidth() / 8;
|
size_t elementSize = storedType.getElementTypeBitWidth() / 8;
|
||||||
if (auto storedOp = currentStoredValue.getDefiningOp())
|
if (auto storedOp = currentStoredValue.getDefiningOp())
|
||||||
rewriter.setInsertionPointAfter(storedOp);
|
rewriter.setInsertionPointAfter(storedOp);
|
||||||
Value outputTensor = state.outputTensors[returnUse->returnIndex](rewriter, loc);
|
Value outputTensor = outputTensors[returnUse->returnIndex](rewriter, loc);
|
||||||
emitHostCopy(rewriter,
|
emitHostCopy(rewriter,
|
||||||
loc,
|
loc,
|
||||||
outputTensor,
|
outputTensor,
|
||||||
@@ -464,7 +457,7 @@ ReturnPathLoweringResult lowerProducedValueReturnPath(
|
|||||||
size_t resultIndexInReturn = resultUse.getOperandNumber();
|
size_t resultIndexInReturn = resultUse.getOperandNumber();
|
||||||
size_t elementSize = storedTensorType.getElementType().getIntOrFloatBitWidth() / 8;
|
size_t elementSize = storedTensorType.getElementType().getIntOrFloatBitWidth() / 8;
|
||||||
rewriter.setInsertionPointAfterValue(storedValue);
|
rewriter.setInsertionPointAfterValue(storedValue);
|
||||||
Value outputTensor = state.outputTensors[resultIndexInReturn](rewriter, loc);
|
Value outputTensor = outputTensors[resultIndexInReturn](rewriter, loc);
|
||||||
emitHostCopy(rewriter,
|
emitHostCopy(rewriter,
|
||||||
loc,
|
loc,
|
||||||
outputTensor,
|
outputTensor,
|
||||||
@@ -480,11 +473,11 @@ ReturnPathLoweringResult lowerProducedValueReturnPath(
|
|||||||
if (auto concatReturnUse = analyzeConcatReturnUse(producedValue)) {
|
if (auto concatReturnUse = analyzeConcatReturnUse(producedValue)) {
|
||||||
size_t elementSize = storedTensorType.getElementTypeBitWidth() / 8;
|
size_t elementSize = storedTensorType.getElementTypeBitWidth() / 8;
|
||||||
for (Operation* concatOp : concatReturnUse->concatChain)
|
for (Operation* concatOp : concatReturnUse->concatChain)
|
||||||
markOpToRemove(state, concatOp);
|
markOpToRemove(concatOp);
|
||||||
|
|
||||||
if (concatReturnUse->helperChain.empty()) {
|
if (concatReturnUse->helperChain.empty()) {
|
||||||
rewriter.setInsertionPointAfterValue(storedValue);
|
rewriter.setInsertionPointAfterValue(storedValue);
|
||||||
Value outputTensor = state.outputTensors[concatReturnUse->returnIndex](rewriter, loc);
|
Value outputTensor = outputTensors[concatReturnUse->returnIndex](rewriter, loc);
|
||||||
auto outputType = cast<ShapedType>(outputTensor.getType());
|
auto outputType = cast<ShapedType>(outputTensor.getType());
|
||||||
int64_t flatOffset = computeFlatElementIndex(concatReturnUse->sliceOffsets, outputType.getShape());
|
int64_t flatOffset = computeFlatElementIndex(concatReturnUse->sliceOffsets, outputType.getShape());
|
||||||
emitHostCopy(rewriter,
|
emitHostCopy(rewriter,
|
||||||
@@ -505,7 +498,7 @@ ReturnPathLoweringResult lowerProducedValueReturnPath(
|
|||||||
return ReturnPathLoweringResult::Failure;
|
return ReturnPathLoweringResult::Failure;
|
||||||
}
|
}
|
||||||
rewriter.setInsertionPointAfterValue(storedValue);
|
rewriter.setInsertionPointAfterValue(storedValue);
|
||||||
Value outputTensor = state.outputTensors[concatReturnUse->returnIndex](rewriter, loc);
|
Value outputTensor = outputTensors[concatReturnUse->returnIndex](rewriter, loc);
|
||||||
auto outputType = cast<ShapedType>(outputTensor.getType());
|
auto outputType = cast<ShapedType>(outputTensor.getType());
|
||||||
for (int64_t linearIndex = 0; linearIndex < storedType.getNumElements(); ++linearIndex) {
|
for (int64_t linearIndex = 0; linearIndex < storedType.getNumElements(); ++linearIndex) {
|
||||||
SmallVector<int64_t> sourceIndices = expandFlatElementIndex(linearIndex, storedType.getShape());
|
SmallVector<int64_t> sourceIndices = expandFlatElementIndex(linearIndex, storedType.getShape());
|
||||||
@@ -553,12 +546,15 @@ ReturnPathLoweringResult lowerProducedValueReturnPath(
|
|||||||
return ReturnPathLoweringResult::NotReturnPath;
|
return ReturnPathLoweringResult::NotReturnPath;
|
||||||
}
|
}
|
||||||
|
|
||||||
ReturnPathLoweringResult lowerComputeResultReturnPath(
|
raptor::SpatialToPimPass::ReturnPathLoweringResult
|
||||||
spatial::SpatCompute computeOp, OpResult result, Value yieldValue, ReturnPathState& state, IRRewriter& rewriter) {
|
raptor::SpatialToPimPass::lowerComputeResultReturnPath(spatial::SpatCompute computeOp,
|
||||||
return lowerProducedValueReturnPath(computeOp.getOperation(), result, yieldValue, state, rewriter);
|
OpResult result,
|
||||||
|
Value yieldValue,
|
||||||
|
IRRewriter& rewriter) {
|
||||||
|
return lowerProducedValueReturnPath(computeOp.getOperation(), result, yieldValue, rewriter);
|
||||||
}
|
}
|
||||||
|
|
||||||
void replaceReturnWithOutputBuffers(func::ReturnOp returnOp, IRRewriter& rewriter, ReturnPathState& state) {
|
void raptor::SpatialToPimPass::replaceReturnWithOutputBuffers(func::ReturnOp returnOp, IRRewriter& rewriter) {
|
||||||
auto markOwnedReturnChain = [&](Operation* op, auto&& markOwnedReturnChain) -> void {
|
auto markOwnedReturnChain = [&](Operation* op, auto&& markOwnedReturnChain) -> void {
|
||||||
if (!op)
|
if (!op)
|
||||||
return;
|
return;
|
||||||
@@ -575,13 +571,13 @@ void replaceReturnWithOutputBuffers(func::ReturnOp returnOp, IRRewriter& rewrite
|
|||||||
|
|
||||||
if (isReturnHelperChainOp(op)) {
|
if (isReturnHelperChainOp(op)) {
|
||||||
Value source = op->getOperand(0);
|
Value source = op->getOperand(0);
|
||||||
markOpToRemove(state, op);
|
markOpToRemove(op);
|
||||||
markOwnedReturnChain(source.getDefiningOp(), markOwnedReturnChain);
|
markOwnedReturnChain(source.getDefiningOp(), markOwnedReturnChain);
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
if (auto computeOp = dyn_cast<spatial::SpatCompute>(op)) {
|
if (auto computeOp = dyn_cast<spatial::SpatCompute>(op)) {
|
||||||
markOpToRemove(state, computeOp);
|
markOpToRemove(computeOp);
|
||||||
if (!computeOp.getInputs().empty())
|
if (!computeOp.getInputs().empty())
|
||||||
for (Value input : computeOp.getInputs())
|
for (Value input : computeOp.getInputs())
|
||||||
markOwnedReturnChain(input.getDefiningOp(), markOwnedReturnChain);
|
markOwnedReturnChain(input.getDefiningOp(), markOwnedReturnChain);
|
||||||
@@ -589,33 +585,33 @@ void replaceReturnWithOutputBuffers(func::ReturnOp returnOp, IRRewriter& rewrite
|
|||||||
}
|
}
|
||||||
|
|
||||||
if (auto concatOp = dyn_cast<tensor::ConcatOp>(op)) {
|
if (auto concatOp = dyn_cast<tensor::ConcatOp>(op)) {
|
||||||
markOpToRemove(state, concatOp);
|
markOpToRemove(concatOp);
|
||||||
for (Value operand : concatOp.getOperands())
|
for (Value operand : concatOp.getOperands())
|
||||||
markOwnedReturnChain(operand.getDefiningOp(), markOwnedReturnChain);
|
markOwnedReturnChain(operand.getDefiningOp(), markOwnedReturnChain);
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
if (auto concatOp = dyn_cast<spatial::SpatConcatOp>(op)) {
|
if (auto concatOp = dyn_cast<spatial::SpatConcatOp>(op)) {
|
||||||
markOpToRemove(state, concatOp);
|
markOpToRemove(concatOp);
|
||||||
for (Value operand : concatOp.getInputs())
|
for (Value operand : concatOp.getInputs())
|
||||||
markOwnedReturnChain(operand.getDefiningOp(), markOwnedReturnChain);
|
markOwnedReturnChain(operand.getDefiningOp(), markOwnedReturnChain);
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
if (auto concatOp = dyn_cast<pim::PimConcatOp>(op)) {
|
if (auto concatOp = dyn_cast<pim::PimConcatOp>(op)) {
|
||||||
markOpToRemove(state, concatOp);
|
markOpToRemove(concatOp);
|
||||||
for (Value operand : concatOp.getInputs())
|
for (Value operand : concatOp.getInputs())
|
||||||
markOwnedReturnChain(operand.getDefiningOp(), markOwnedReturnChain);
|
markOwnedReturnChain(operand.getDefiningOp(), markOwnedReturnChain);
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
if (auto receiveOp = dyn_cast<spatial::SpatChannelReceiveOp>(op)) {
|
if (auto receiveOp = dyn_cast<spatial::SpatChannelReceiveOp>(op)) {
|
||||||
markOpToRemove(state, receiveOp);
|
markOpToRemove(receiveOp);
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
if (auto receiveTensorOp = dyn_cast<spatial::SpatChannelReceiveTensorOp>(op))
|
if (auto receiveTensorOp = dyn_cast<spatial::SpatChannelReceiveTensorOp>(op))
|
||||||
markOpToRemove(state, receiveTensorOp);
|
markOpToRemove(receiveTensorOp);
|
||||||
};
|
};
|
||||||
|
|
||||||
SmallVector<Value> originalOperands(returnOp.getOperands().begin(), returnOp.getOperands().end());
|
SmallVector<Value> originalOperands(returnOp.getOperands().begin(), returnOp.getOperands().end());
|
||||||
@@ -624,7 +620,7 @@ void replaceReturnWithOutputBuffers(func::ReturnOp returnOp, IRRewriter& rewrite
|
|||||||
size_t orderWithinReturn = it.index();
|
size_t orderWithinReturn = it.index();
|
||||||
Operation* returnOperand = it.value().getDefiningOp();
|
Operation* returnOperand = it.value().getDefiningOp();
|
||||||
rewriter.setInsertionPoint(returnOp);
|
rewriter.setInsertionPoint(returnOp);
|
||||||
Value outputTensor = state.outputTensors[orderWithinReturn](rewriter, loc);
|
Value outputTensor = outputTensors[orderWithinReturn](rewriter, loc);
|
||||||
rewriter.modifyOpInPlace(returnOp, [&] { returnOp.setOperand(orderWithinReturn, outputTensor); });
|
rewriter.modifyOpInPlace(returnOp, [&] { returnOp.setOperand(orderWithinReturn, outputTensor); });
|
||||||
markOwnedReturnChain(returnOperand, markOwnedReturnChain);
|
markOwnedReturnChain(returnOperand, markOwnedReturnChain);
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,43 +0,0 @@
|
|||||||
#pragma once
|
|
||||||
|
|
||||||
#include "mlir/Dialect/Func/IR/FuncOps.h"
|
|
||||||
#include "mlir/IR/PatternMatch.h"
|
|
||||||
|
|
||||||
#include <functional>
|
|
||||||
|
|
||||||
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
|
||||||
|
|
||||||
namespace onnx_mlir {
|
|
||||||
|
|
||||||
using OutputTensorFactory = std::function<mlir::Value(mlir::IRRewriter& rewriter, mlir::Location loc)>;
|
|
||||||
|
|
||||||
struct ReturnPathState {
|
|
||||||
llvm::SmallVectorImpl<OutputTensorFactory>& outputTensors;
|
|
||||||
llvm::SmallVectorImpl<mlir::Operation*>& operationsToRemove;
|
|
||||||
};
|
|
||||||
|
|
||||||
enum class ReturnPathLoweringResult {
|
|
||||||
Handled,
|
|
||||||
NotReturnPath,
|
|
||||||
Failure
|
|
||||||
};
|
|
||||||
|
|
||||||
void addReturnOutputBuffers(mlir::func::ReturnOp returnOp,
|
|
||||||
mlir::IRRewriter& rewriter,
|
|
||||||
llvm::SmallVectorImpl<OutputTensorFactory>& outputTensors);
|
|
||||||
|
|
||||||
ReturnPathLoweringResult lowerComputeResultReturnPath(spatial::SpatCompute computeOp,
|
|
||||||
mlir::OpResult result,
|
|
||||||
mlir::Value yieldValue,
|
|
||||||
ReturnPathState& state,
|
|
||||||
mlir::IRRewriter& rewriter);
|
|
||||||
|
|
||||||
ReturnPathLoweringResult lowerProducedValueReturnPath(mlir::Operation* producerOp,
|
|
||||||
mlir::Value producedValue,
|
|
||||||
mlir::Value storedValue,
|
|
||||||
ReturnPathState& state,
|
|
||||||
mlir::IRRewriter& rewriter);
|
|
||||||
|
|
||||||
void replaceReturnWithOutputBuffers(mlir::func::ReturnOp returnOp, mlir::IRRewriter& rewriter, ReturnPathState& state);
|
|
||||||
|
|
||||||
} // namespace onnx_mlir
|
|
||||||
@@ -23,54 +23,28 @@
|
|||||||
#include <cassert>
|
#include <cassert>
|
||||||
#include <utility>
|
#include <utility>
|
||||||
|
|
||||||
|
#include "Common/PimCommon.hpp"
|
||||||
#include "Conversion/ONNXToSpatial/Common/Common.hpp"
|
#include "Conversion/ONNXToSpatial/Common/Common.hpp"
|
||||||
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
|
#include "Conversion/SpatialToPim/ChannelLoweringPatterns.hpp"
|
||||||
#include "src/Accelerators/PIM/Conversion/SpatialToPim/BatchCoreLoweringPatterns.hpp"
|
#include "Conversion/SpatialToPim/Common.hpp"
|
||||||
#include "src/Accelerators/PIM/Conversion/SpatialToPim/ChannelLoweringPatterns.hpp"
|
#include "Conversion/SpatialToPim/GlobalTensorMaterialization.hpp"
|
||||||
#include "src/Accelerators/PIM/Conversion/SpatialToPim/Common.hpp"
|
#include "Conversion/SpatialToPim/PhaseVerification.hpp"
|
||||||
#include "src/Accelerators/PIM/Conversion/SpatialToPim/CoreLoweringPatterns.hpp"
|
#include "Conversion/SpatialToPim/TensorPackingPatterns.hpp"
|
||||||
#include "src/Accelerators/PIM/Conversion/SpatialToPim/GlobalTensorMaterialization.hpp"
|
#include "Dialect/Pim/PimOps.hpp"
|
||||||
#include "src/Accelerators/PIM/Conversion/SpatialToPim/PhaseVerification.hpp"
|
#include "Dialect/Spatial/SpatialOps.hpp"
|
||||||
#include "src/Accelerators/PIM/Conversion/SpatialToPim/ReturnPathNormalization.hpp"
|
#include "Pass/PIMPasses.h"
|
||||||
#include "src/Accelerators/PIM/Conversion/SpatialToPim/TensorPackingPatterns.hpp"
|
#include "SpatialToPimPass.hpp"
|
||||||
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
|
|
||||||
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
|
||||||
#include "src/Accelerators/PIM/Pass/PIMPasses.h"
|
|
||||||
|
|
||||||
using namespace mlir;
|
using namespace mlir;
|
||||||
using namespace onnx_mlir;
|
using namespace onnx_mlir;
|
||||||
using namespace pim;
|
using namespace pim;
|
||||||
|
|
||||||
namespace onnx_mlir {
|
namespace onnx_mlir {
|
||||||
|
namespace raptor {
|
||||||
namespace {
|
|
||||||
|
|
||||||
#include "src/Accelerators/PIM/Conversion/SpatialToPim/SpatialToPim.hpp.inc"
|
#include "src/Accelerators/PIM/Conversion/SpatialToPim/SpatialToPim.hpp.inc"
|
||||||
|
|
||||||
struct SpatialToPimPass : PassWrapper<SpatialToPimPass, OperationPass<ModuleOp>> {
|
} // namespace raptor
|
||||||
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(SpatialToPimPass)
|
|
||||||
StringRef getArgument() const override { return "convert-spatial-to-pim"; }
|
|
||||||
StringRef getDescription() const override { return "Lower Spatial ops to PIM-ready format"; }
|
|
||||||
|
|
||||||
SpatialToPimPass() = default;
|
|
||||||
SpatialToPimPass(const SpatialToPimPass& pass) {}
|
|
||||||
|
|
||||||
void runOnOperation() final;
|
|
||||||
|
|
||||||
private:
|
|
||||||
SmallVector<OutputTensorFactory> outputTensors;
|
|
||||||
size_t coreId = 0;
|
|
||||||
SmallVector<Operation*> operationsToRemove;
|
|
||||||
|
|
||||||
LogicalResult allocateAndInitializeCoreLocalVariables(func::FuncOp funcOp, IRRewriter& rewriter);
|
|
||||||
|
|
||||||
void markOpToRemove(Operation* op);
|
|
||||||
void eraseOpsToRemove();
|
|
||||||
|
|
||||||
void enlargeVMMOutTensorsToCrossbarSize(func::FuncOp funcOp, IRRewriter& rewriter);
|
|
||||||
};
|
|
||||||
|
|
||||||
} // namespace
|
|
||||||
|
|
||||||
static memref::GlobalOp getOrCreateZeroGlobal(IRRewriter& rewriter, Location loc, RankedTensorType tensorType) {
|
static memref::GlobalOp getOrCreateZeroGlobal(IRRewriter& rewriter, Location loc, RankedTensorType tensorType) {
|
||||||
auto moduleOp = rewriter.getBlock()->getParentOp()->getParentOfType<ModuleOp>();
|
auto moduleOp = rewriter.getBlock()->getParentOp()->getParentOfType<ModuleOp>();
|
||||||
@@ -150,8 +124,10 @@ padHVectorInputToCrossbarSize(IRRewriter& rewriter, Location loc, Value vector,
|
|||||||
return PimMemCopyOp::create(rewriter, loc, paddedType, zeroed, vector, zeroAttr, zeroAttr, sizeAttr).getOutput();
|
return PimMemCopyOp::create(rewriter, loc, paddedType, zeroed, vector, zeroAttr, zeroAttr, sizeAttr).getOutput();
|
||||||
}
|
}
|
||||||
|
|
||||||
void SpatialToPimPass::runOnOperation() {
|
void onnx_mlir::raptor::SpatialToPimPass::runOnOperation() {
|
||||||
coreId = 0;
|
coreId = 0;
|
||||||
|
outputTensors.clear();
|
||||||
|
operationsToRemove.clear();
|
||||||
ModuleOp moduleOp = getOperation();
|
ModuleOp moduleOp = getOperation();
|
||||||
MLIRContext* ctx = moduleOp.getContext();
|
MLIRContext* ctx = moduleOp.getContext();
|
||||||
|
|
||||||
@@ -197,18 +173,16 @@ void SpatialToPimPass::runOnOperation() {
|
|||||||
walkAndApplyPatterns(moduleOp, std::move(globalTensorPatterns));
|
walkAndApplyPatterns(moduleOp, std::move(globalTensorPatterns));
|
||||||
|
|
||||||
auto returnOp = cast<func::ReturnOp>(funcOp.front().getTerminator());
|
auto returnOp = cast<func::ReturnOp>(funcOp.front().getTerminator());
|
||||||
addReturnOutputBuffers(returnOp, rewriter, outputTensors);
|
addReturnOutputBuffers(returnOp, rewriter);
|
||||||
ReturnPathState returnPathState {outputTensors, operationsToRemove};
|
|
||||||
if (failed(allocateAndInitializeCoreLocalVariables(funcOp, rewriter))) {
|
if (failed(allocateAndInitializeCoreLocalVariables(funcOp, rewriter))) {
|
||||||
funcOp.emitOpError("failed to allocate or initialize core-local tensors during Spatial-to-PIM lowering");
|
funcOp.emitOpError("failed to allocate or initialize core-local tensors during Spatial-to-PIM lowering");
|
||||||
signalPassFailure();
|
signalPassFailure();
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
CoreLoweringState coreLoweringState {coreId, outputTensors, operationsToRemove, constantFolder};
|
|
||||||
for (auto computeOp : funcOp.getOps<spatial::SpatCompute>()) {
|
for (auto computeOp : funcOp.getOps<spatial::SpatCompute>()) {
|
||||||
markOpToRemove(computeOp);
|
markOpToRemove(computeOp);
|
||||||
if (failed(lowerComputeOp(computeOp, coreLoweringState, rewriter))) {
|
if (failed(lowerComputeOp(computeOp, rewriter, constantFolder))) {
|
||||||
computeOp.emitOpError("failed to lower spat.compute to pim.core");
|
computeOp.emitOpError("failed to lower spat.compute to pim.core");
|
||||||
signalPassFailure();
|
signalPassFailure();
|
||||||
return;
|
return;
|
||||||
@@ -217,7 +191,7 @@ void SpatialToPimPass::runOnOperation() {
|
|||||||
|
|
||||||
for (auto computeBatchOp : funcOp.getOps<spatial::SpatComputeBatch>()) {
|
for (auto computeBatchOp : funcOp.getOps<spatial::SpatComputeBatch>()) {
|
||||||
markOpToRemove(computeBatchOp);
|
markOpToRemove(computeBatchOp);
|
||||||
if (failed(lowerComputeBatchOp(computeBatchOp, coreLoweringState, rewriter))) {
|
if (failed(lowerComputeBatchOp(computeBatchOp, rewriter))) {
|
||||||
computeBatchOp.emitOpError("failed to lower spat.compute_batch to pim.core_batch");
|
computeBatchOp.emitOpError("failed to lower spat.compute_batch to pim.core_batch");
|
||||||
signalPassFailure();
|
signalPassFailure();
|
||||||
return;
|
return;
|
||||||
@@ -266,7 +240,7 @@ void SpatialToPimPass::runOnOperation() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
enlargeVMMOutTensorsToCrossbarSize(funcOp, rewriter);
|
enlargeVMMOutTensorsToCrossbarSize(funcOp, rewriter);
|
||||||
replaceReturnWithOutputBuffers(returnOp, rewriter, returnPathState);
|
replaceReturnWithOutputBuffers(returnOp, rewriter);
|
||||||
eraseOpsToRemove();
|
eraseOpsToRemove();
|
||||||
|
|
||||||
RewritePatternSet finalTensorPackingPatterns(ctx);
|
RewritePatternSet finalTensorPackingPatterns(ctx);
|
||||||
@@ -309,7 +283,7 @@ void SpatialToPimPass::runOnOperation() {
|
|||||||
dumpModule(moduleOp, "pim0");
|
dumpModule(moduleOp, "pim0");
|
||||||
}
|
}
|
||||||
|
|
||||||
void SpatialToPimPass::enlargeVMMOutTensorsToCrossbarSize(func::FuncOp funcOp, IRRewriter& rewriter) {
|
void raptor::SpatialToPimPass::enlargeVMMOutTensorsToCrossbarSize(func::FuncOp funcOp, IRRewriter& rewriter) {
|
||||||
OperationFolder constantFolder(funcOp.getContext());
|
OperationFolder constantFolder(funcOp.getContext());
|
||||||
funcOp.walk([&](PimVMMOp vmmOp) {
|
funcOp.walk([&](PimVMMOp vmmOp) {
|
||||||
auto outputType = cast<RankedTensorType>(vmmOp.getOutput().getType());
|
auto outputType = cast<RankedTensorType>(vmmOp.getOutput().getType());
|
||||||
@@ -343,7 +317,8 @@ void SpatialToPimPass::enlargeVMMOutTensorsToCrossbarSize(func::FuncOp funcOp, I
|
|||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
LogicalResult SpatialToPimPass::allocateAndInitializeCoreLocalVariables(func::FuncOp funcOp, IRRewriter& rewriter) {
|
LogicalResult raptor::SpatialToPimPass::allocateAndInitializeCoreLocalVariables(func::FuncOp funcOp,
|
||||||
|
IRRewriter& rewriter) {
|
||||||
Location loc = funcOp.getLoc();
|
Location loc = funcOp.getLoc();
|
||||||
OperationFolder constantFolder(funcOp.getContext());
|
OperationFolder constantFolder(funcOp.getContext());
|
||||||
|
|
||||||
@@ -387,18 +362,18 @@ LogicalResult SpatialToPimPass::allocateAndInitializeCoreLocalVariables(func::Fu
|
|||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
|
|
||||||
void SpatialToPimPass::markOpToRemove(Operation* op) {
|
void raptor::SpatialToPimPass::markOpToRemove(Operation* op) {
|
||||||
if (!llvm::is_contained(operationsToRemove, op))
|
if (!llvm::is_contained(operationsToRemove, op))
|
||||||
operationsToRemove.push_back(op);
|
operationsToRemove.push_back(op);
|
||||||
}
|
}
|
||||||
|
|
||||||
void SpatialToPimPass::eraseOpsToRemove() {
|
void raptor::SpatialToPimPass::eraseOpsToRemove() {
|
||||||
for (Operation* op : operationsToRemove) {
|
for (Operation* op : operationsToRemove) {
|
||||||
op->dropAllUses();
|
op->dropAllUses();
|
||||||
op->erase();
|
op->erase();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
std::unique_ptr<Pass> createSpatialToPimPass() { return std::make_unique<SpatialToPimPass>(); }
|
std::unique_ptr<Pass> createSpatialToPimPass() { return std::make_unique<raptor::SpatialToPimPass>(); }
|
||||||
|
|
||||||
} // namespace onnx_mlir
|
} // namespace onnx_mlir
|
||||||
|
|||||||
@@ -0,0 +1,75 @@
|
|||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include "mlir/Dialect/Arith/IR/Arith.h"
|
||||||
|
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
|
||||||
|
#include "mlir/Dialect/Func/IR/FuncOps.h"
|
||||||
|
#include "mlir/Dialect/SCF/Utils/Utils.h"
|
||||||
|
#include "mlir/IR/BuiltinOps.h"
|
||||||
|
#include "mlir/IR/PatternMatch.h"
|
||||||
|
#include "mlir/IR/Value.h"
|
||||||
|
#include "mlir/Pass/Pass.h"
|
||||||
|
#include "mlir/Transforms/FoldUtils.h"
|
||||||
|
|
||||||
|
#include "llvm/ADT/StringRef.h"
|
||||||
|
|
||||||
|
#include <functional>
|
||||||
|
|
||||||
|
#include "Conversion/ONNXToSpatial/Common/Common.hpp"
|
||||||
|
#include "Conversion/SpatialToPim/Common.hpp"
|
||||||
|
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
|
||||||
|
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
||||||
|
|
||||||
|
namespace onnx_mlir {
|
||||||
|
namespace raptor {
|
||||||
|
|
||||||
|
struct SpatialToPimPass : mlir::PassWrapper<SpatialToPimPass, mlir::OperationPass<mlir::ModuleOp>> {
|
||||||
|
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(SpatialToPimPass)
|
||||||
|
llvm::StringRef getArgument() const override { return "convert-spatial-to-pim"; }
|
||||||
|
llvm::StringRef getDescription() const override { return "Lower Spatial ops to PIM-ready format"; }
|
||||||
|
|
||||||
|
SpatialToPimPass() = default;
|
||||||
|
SpatialToPimPass(const SpatialToPimPass& pass) {}
|
||||||
|
|
||||||
|
void runOnOperation() final;
|
||||||
|
|
||||||
|
private:
|
||||||
|
using OutputTensorFactory = std::function<mlir::Value(mlir::IRRewriter& rewriter, mlir::Location loc)>;
|
||||||
|
|
||||||
|
llvm::SmallVector<OutputTensorFactory> outputTensors;
|
||||||
|
size_t coreId = 0;
|
||||||
|
llvm::SmallVector<mlir::Operation*> operationsToRemove;
|
||||||
|
|
||||||
|
mlir::LogicalResult allocateAndInitializeCoreLocalVariables(mlir::func::FuncOp funcOp,
|
||||||
|
mlir::IRRewriter& rewriter);
|
||||||
|
mlir::LogicalResult lowerComputeOp(spatial::SpatCompute computeOp,
|
||||||
|
mlir::IRRewriter& rewriter,
|
||||||
|
mlir::OperationFolder& constantFolder);
|
||||||
|
mlir::LogicalResult lowerComputeBatchOp(spatial::SpatComputeBatch computeBatchOp,
|
||||||
|
mlir::IRRewriter& rewriter);
|
||||||
|
|
||||||
|
enum class ReturnPathLoweringResult {
|
||||||
|
Handled,
|
||||||
|
NotReturnPath,
|
||||||
|
Failure
|
||||||
|
};
|
||||||
|
|
||||||
|
void addReturnOutputBuffers(mlir::func::ReturnOp returnOp, mlir::IRRewriter& rewriter);
|
||||||
|
ReturnPathLoweringResult lowerComputeResultReturnPath(spatial::SpatCompute computeOp,
|
||||||
|
mlir::OpResult result,
|
||||||
|
mlir::Value yieldValue,
|
||||||
|
mlir::IRRewriter& rewriter);
|
||||||
|
ReturnPathLoweringResult lowerProducedValueReturnPath(mlir::Operation* producerOp,
|
||||||
|
mlir::Value producedValue,
|
||||||
|
mlir::Value storedValue,
|
||||||
|
mlir::IRRewriter& rewriter);
|
||||||
|
void replaceReturnWithOutputBuffers(mlir::func::ReturnOp returnOp, mlir::IRRewriter& rewriter);
|
||||||
|
|
||||||
|
void markOpToRemove(mlir::Operation* op);
|
||||||
|
void eraseOpsToRemove();
|
||||||
|
|
||||||
|
void enlargeVMMOutTensorsToCrossbarSize(mlir::func::FuncOp funcOp, mlir::IRRewriter& rewriter);
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace raptor
|
||||||
|
|
||||||
|
} // namespace onnx_mlir
|
||||||
Reference in New Issue
Block a user