This commit is contained in:
@@ -7,8 +7,8 @@
|
||||
#include "mlir/IR/Matchers.h"
|
||||
|
||||
#include "Conversion/ONNXToSpatial/Common/Common.hpp"
|
||||
#include "Conversion/SpatialToPim/SpatialToPimPass.hpp"
|
||||
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
|
||||
#include "src/Accelerators/PIM/Conversion/SpatialToPim/BatchCoreLoweringPatterns.hpp"
|
||||
#include "src/Accelerators/PIM/Conversion/SpatialToPim/Common.hpp"
|
||||
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
|
||||
|
||||
@@ -152,8 +152,8 @@ static Value createHostTargetOffset(IRRewriter& rewriter,
|
||||
|
||||
} // namespace
|
||||
|
||||
LogicalResult
|
||||
lowerComputeBatchOp(spatial::SpatComputeBatch computeBatchOp, CoreLoweringState& state, IRRewriter& rewriter) {
|
||||
LogicalResult raptor::SpatialToPimPass::lowerComputeBatchOp(spatial::SpatComputeBatch computeBatchOp,
|
||||
IRRewriter& rewriter) {
|
||||
Location loc = computeBatchOp.getLoc();
|
||||
Block& oldBlock = computeBatchOp.getBody().front();
|
||||
auto oldYield = dyn_cast<spatial::SpatYieldOp>(oldBlock.getTerminator());
|
||||
@@ -167,7 +167,7 @@ lowerComputeBatchOp(spatial::SpatComputeBatch computeBatchOp, CoreLoweringState&
|
||||
"resultful compute_batch lowering currently requires a spat.in_parallel terminator");
|
||||
}
|
||||
|
||||
SmallVector<int32_t> coreIds = getPimCoreIdsForBatchOp(computeBatchOp, state.nextCoreId);
|
||||
SmallVector<int32_t> coreIds = getPimCoreIdsForBatchOp(computeBatchOp, coreId);
|
||||
SmallVector<Value> batchWeights(computeBatchOp.getWeights().begin(), computeBatchOp.getWeights().end());
|
||||
SmallVector<Value> batchInputs;
|
||||
if (!computeBatchOp.getInputs().empty())
|
||||
@@ -192,7 +192,7 @@ lowerComputeBatchOp(spatial::SpatComputeBatch computeBatchOp, CoreLoweringState&
|
||||
return computeBatchOp.emitOpError(
|
||||
"resultful compute_batch lowering currently requires each result to be used directly by func.return");
|
||||
|
||||
hostOutputTensors[resultIndex] = state.outputTensors[*returnOperandIndex](rewriter, loc);
|
||||
hostOutputTensors[resultIndex] = outputTensors[*returnOperandIndex](rewriter, loc);
|
||||
result.replaceAllUsesWith(hostOutputTensors[resultIndex]);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,10 +0,0 @@
|
||||
#pragma once
|
||||
|
||||
#include "src/Accelerators/PIM/Conversion/SpatialToPim/CoreLoweringPatterns.hpp"
|
||||
|
||||
namespace onnx_mlir {
|
||||
|
||||
mlir::LogicalResult
|
||||
lowerComputeBatchOp(spatial::SpatComputeBatch computeBatchOp, CoreLoweringState& state, mlir::IRRewriter& rewriter);
|
||||
|
||||
} // namespace onnx_mlir
|
||||
@@ -6,9 +6,9 @@
|
||||
#include "mlir/IR/Matchers.h"
|
||||
|
||||
#include "Conversion/ONNXToSpatial/Common/Common.hpp"
|
||||
#include "Conversion/SpatialToPim/SpatialToPimPass.hpp"
|
||||
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
|
||||
#include "src/Accelerators/PIM/Conversion/SpatialToPim/Common.hpp"
|
||||
#include "src/Accelerators/PIM/Conversion/SpatialToPim/CoreLoweringPatterns.hpp"
|
||||
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
|
||||
#include "src/Dialect/ONNX/ONNXOps.hpp"
|
||||
|
||||
@@ -148,15 +148,12 @@ static bool inlineInputlessHelperComputeForWeightLikeUsers(spatial::SpatCompute
|
||||
|
||||
} // namespace
|
||||
|
||||
void markOpToRemove(CoreLoweringState& state, Operation* op) {
|
||||
if (!llvm::is_contained(state.operationsToRemove, op))
|
||||
state.operationsToRemove.push_back(op);
|
||||
}
|
||||
|
||||
LogicalResult lowerComputeOp(spatial::SpatCompute computeOp, CoreLoweringState& state, IRRewriter& rewriter) {
|
||||
LogicalResult raptor::SpatialToPimPass::lowerComputeOp(spatial::SpatCompute computeOp,
|
||||
IRRewriter& rewriter,
|
||||
OperationFolder& constantFolder) {
|
||||
Location loc = computeOp->getLoc();
|
||||
|
||||
if (inlineInputlessHelperComputeForWeightLikeUsers(computeOp, rewriter, state.constantFolder))
|
||||
if (inlineInputlessHelperComputeForWeightLikeUsers(computeOp, rewriter, constantFolder))
|
||||
return success();
|
||||
|
||||
SmallVector<Operation*> helperChain;
|
||||
@@ -179,7 +176,7 @@ LogicalResult lowerComputeOp(spatial::SpatCompute computeOp, CoreLoweringState&
|
||||
rewriter, receiveOp.getLoc(), outputBuffer.getType(), outputBuffer, sizeAttr, receiveOp.getSourceCoreId())
|
||||
.getOutput();
|
||||
blockArg.replaceAllUsesWith(received);
|
||||
markOpToRemove(state, receiveOp);
|
||||
markOpToRemove(receiveOp);
|
||||
continue;
|
||||
}
|
||||
|
||||
@@ -200,7 +197,7 @@ LogicalResult lowerComputeOp(spatial::SpatCompute computeOp, CoreLoweringState&
|
||||
rewriter.getDenseI32ArrayAttr(*sourceCoreIds))
|
||||
.getOutput();
|
||||
blockArg.replaceAllUsesWith(received);
|
||||
markOpToRemove(state, receiveTensorOp);
|
||||
markOpToRemove(receiveTensorOp);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -211,9 +208,8 @@ LogicalResult lowerComputeOp(spatial::SpatCompute computeOp, CoreLoweringState&
|
||||
if (result.use_empty())
|
||||
continue;
|
||||
|
||||
ReturnPathState returnPathState {state.outputTensors, state.operationsToRemove};
|
||||
ReturnPathLoweringResult returnPathResult =
|
||||
lowerComputeResultReturnPath(computeOp, cast<OpResult>(result), yieldValue, returnPathState, rewriter);
|
||||
lowerComputeResultReturnPath(computeOp, cast<OpResult>(result), yieldValue, rewriter);
|
||||
if (returnPathResult == ReturnPathLoweringResult::Failure)
|
||||
return failure();
|
||||
if (returnPathResult == ReturnPathLoweringResult::Handled)
|
||||
@@ -240,7 +236,7 @@ LogicalResult lowerComputeOp(spatial::SpatCompute computeOp, CoreLoweringState&
|
||||
auto coreOp = PimCoreOp::create(rewriter,
|
||||
loc,
|
||||
ValueRange(computeWeights),
|
||||
rewriter.getI32IntegerAttr(getPimCoreIdForComputeOp(computeOp, state.nextCoreId)));
|
||||
rewriter.getI32IntegerAttr(getPimCoreIdForComputeOp(computeOp, coreId)));
|
||||
rewriter.setInsertionPointToStart(&block);
|
||||
auto& coreOpBlocks = coreOp.getBody().getBlocks();
|
||||
for (auto [inputIndex, input] : llvm::enumerate(computeOp.getInputs())) {
|
||||
@@ -249,7 +245,7 @@ LogicalResult lowerComputeOp(spatial::SpatCompute computeOp, CoreLoweringState&
|
||||
continue;
|
||||
|
||||
if (auto constantOp = input.getDefiningOp<arith::ConstantOp>()) {
|
||||
blockArg.replaceAllUsesWith(getOrCreateHostConstantLike(constantOp, state.constantFolder));
|
||||
blockArg.replaceAllUsesWith(getOrCreateHostConstantLike(constantOp, constantFolder));
|
||||
continue;
|
||||
}
|
||||
|
||||
@@ -261,8 +257,8 @@ LogicalResult lowerComputeOp(spatial::SpatCompute computeOp, CoreLoweringState&
|
||||
PimMemCopyHostToDevOp::create(rewriter,
|
||||
loc,
|
||||
outputBuffer.getType(),
|
||||
getOrCreateHostIndexConstant(outputBuffer.getOperation(), 0, state.constantFolder),
|
||||
getOrCreateHostIndexConstant(outputBuffer.getOperation(), 0, state.constantFolder),
|
||||
getOrCreateHostIndexConstant(outputBuffer.getOperation(), 0, constantFolder),
|
||||
getOrCreateHostIndexConstant(outputBuffer.getOperation(), 0, constantFolder),
|
||||
outputBuffer,
|
||||
input,
|
||||
getTensorSizeInBytesAttr(rewriter, input))
|
||||
|
||||
@@ -1,23 +0,0 @@
|
||||
#pragma once
|
||||
|
||||
#include "mlir/IR/PatternMatch.h"
|
||||
#include "mlir/Transforms/FoldUtils.h"
|
||||
|
||||
#include "src/Accelerators/PIM/Conversion/SpatialToPim/ReturnPathNormalization.hpp"
|
||||
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
||||
|
||||
namespace onnx_mlir {
|
||||
|
||||
struct CoreLoweringState {
|
||||
size_t& nextCoreId;
|
||||
llvm::SmallVectorImpl<OutputTensorFactory>& outputTensors;
|
||||
llvm::SmallVectorImpl<mlir::Operation*>& operationsToRemove;
|
||||
mlir::OperationFolder& constantFolder;
|
||||
};
|
||||
|
||||
void markOpToRemove(CoreLoweringState& state, mlir::Operation* op);
|
||||
|
||||
mlir::LogicalResult
|
||||
lowerComputeOp(spatial::SpatCompute computeOp, CoreLoweringState& state, mlir::IRRewriter& rewriter);
|
||||
|
||||
} // namespace onnx_mlir
|
||||
@@ -9,9 +9,9 @@
|
||||
#include "mlir/Transforms/FoldUtils.h"
|
||||
|
||||
#include "Conversion/ONNXToSpatial/Common/Common.hpp"
|
||||
#include "Conversion/SpatialToPim/SpatialToPimPass.hpp"
|
||||
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
|
||||
#include "src/Accelerators/PIM/Conversion/SpatialToPim/Common.hpp"
|
||||
#include "src/Accelerators/PIM/Conversion/SpatialToPim/ReturnPathNormalization.hpp"
|
||||
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
|
||||
#include "src/Dialect/ONNX/ONNXOps.hpp"
|
||||
|
||||
@@ -44,11 +44,6 @@ static bool isReturnHelperChainOp(Operation* op) {
|
||||
pim::PimTransposeOp>(op);
|
||||
}
|
||||
|
||||
static void markOpToRemove(ReturnPathState& state, Operation* op) {
|
||||
if (!llvm::is_contained(state.operationsToRemove, op))
|
||||
state.operationsToRemove.push_back(op);
|
||||
}
|
||||
|
||||
static std::string makeUniqueSymbolName(Operation* symbolTableOp, StringRef baseName) {
|
||||
std::string name = baseName.str();
|
||||
unsigned suffix = 0;
|
||||
@@ -390,9 +385,7 @@ static Value emitHostCopy(IRRewriter& rewriter,
|
||||
|
||||
} // namespace
|
||||
|
||||
void addReturnOutputBuffers(func::ReturnOp returnOp,
|
||||
IRRewriter& rewriter,
|
||||
SmallVectorImpl<OutputTensorFactory>& outputTensors) {
|
||||
void raptor::SpatialToPimPass::addReturnOutputBuffers(func::ReturnOp returnOp, IRRewriter& rewriter) {
|
||||
outputTensors.reserve(returnOp->getNumOperands());
|
||||
for (auto [index, returnValue] : llvm::enumerate(returnOp->getOperands())) {
|
||||
Value currentReturnValue = returnValue;
|
||||
@@ -427,8 +420,8 @@ void addReturnOutputBuffers(func::ReturnOp returnOp,
|
||||
}
|
||||
}
|
||||
|
||||
ReturnPathLoweringResult lowerProducedValueReturnPath(
|
||||
Operation* producerOp, Value producedValue, Value storedValue, ReturnPathState& state, IRRewriter& rewriter) {
|
||||
raptor::SpatialToPimPass::ReturnPathLoweringResult raptor::SpatialToPimPass::lowerProducedValueReturnPath(
|
||||
Operation* producerOp, Value producedValue, Value storedValue, IRRewriter& rewriter) {
|
||||
Location loc = producerOp->getLoc();
|
||||
OperationFolder constantFolder(producerOp->getContext());
|
||||
auto storedTensorType = cast<TensorType>(storedValue.getType());
|
||||
@@ -437,13 +430,13 @@ ReturnPathLoweringResult lowerProducedValueReturnPath(
|
||||
Value currentStoredValue = storedValue;
|
||||
cloneHelperChain(storedValue, returnUse->helperChain, rewriter, constantFolder, currentStoredValue);
|
||||
for (Operation* op : returnUse->helperChain)
|
||||
markOpToRemove(state, op);
|
||||
markOpToRemove(op);
|
||||
|
||||
auto storedType = cast<ShapedType>(currentStoredValue.getType());
|
||||
size_t elementSize = storedType.getElementTypeBitWidth() / 8;
|
||||
if (auto storedOp = currentStoredValue.getDefiningOp())
|
||||
rewriter.setInsertionPointAfter(storedOp);
|
||||
Value outputTensor = state.outputTensors[returnUse->returnIndex](rewriter, loc);
|
||||
Value outputTensor = outputTensors[returnUse->returnIndex](rewriter, loc);
|
||||
emitHostCopy(rewriter,
|
||||
loc,
|
||||
outputTensor,
|
||||
@@ -464,7 +457,7 @@ ReturnPathLoweringResult lowerProducedValueReturnPath(
|
||||
size_t resultIndexInReturn = resultUse.getOperandNumber();
|
||||
size_t elementSize = storedTensorType.getElementType().getIntOrFloatBitWidth() / 8;
|
||||
rewriter.setInsertionPointAfterValue(storedValue);
|
||||
Value outputTensor = state.outputTensors[resultIndexInReturn](rewriter, loc);
|
||||
Value outputTensor = outputTensors[resultIndexInReturn](rewriter, loc);
|
||||
emitHostCopy(rewriter,
|
||||
loc,
|
||||
outputTensor,
|
||||
@@ -480,11 +473,11 @@ ReturnPathLoweringResult lowerProducedValueReturnPath(
|
||||
if (auto concatReturnUse = analyzeConcatReturnUse(producedValue)) {
|
||||
size_t elementSize = storedTensorType.getElementTypeBitWidth() / 8;
|
||||
for (Operation* concatOp : concatReturnUse->concatChain)
|
||||
markOpToRemove(state, concatOp);
|
||||
markOpToRemove(concatOp);
|
||||
|
||||
if (concatReturnUse->helperChain.empty()) {
|
||||
rewriter.setInsertionPointAfterValue(storedValue);
|
||||
Value outputTensor = state.outputTensors[concatReturnUse->returnIndex](rewriter, loc);
|
||||
Value outputTensor = outputTensors[concatReturnUse->returnIndex](rewriter, loc);
|
||||
auto outputType = cast<ShapedType>(outputTensor.getType());
|
||||
int64_t flatOffset = computeFlatElementIndex(concatReturnUse->sliceOffsets, outputType.getShape());
|
||||
emitHostCopy(rewriter,
|
||||
@@ -505,7 +498,7 @@ ReturnPathLoweringResult lowerProducedValueReturnPath(
|
||||
return ReturnPathLoweringResult::Failure;
|
||||
}
|
||||
rewriter.setInsertionPointAfterValue(storedValue);
|
||||
Value outputTensor = state.outputTensors[concatReturnUse->returnIndex](rewriter, loc);
|
||||
Value outputTensor = outputTensors[concatReturnUse->returnIndex](rewriter, loc);
|
||||
auto outputType = cast<ShapedType>(outputTensor.getType());
|
||||
for (int64_t linearIndex = 0; linearIndex < storedType.getNumElements(); ++linearIndex) {
|
||||
SmallVector<int64_t> sourceIndices = expandFlatElementIndex(linearIndex, storedType.getShape());
|
||||
@@ -553,12 +546,15 @@ ReturnPathLoweringResult lowerProducedValueReturnPath(
|
||||
return ReturnPathLoweringResult::NotReturnPath;
|
||||
}
|
||||
|
||||
ReturnPathLoweringResult lowerComputeResultReturnPath(
|
||||
spatial::SpatCompute computeOp, OpResult result, Value yieldValue, ReturnPathState& state, IRRewriter& rewriter) {
|
||||
return lowerProducedValueReturnPath(computeOp.getOperation(), result, yieldValue, state, rewriter);
|
||||
raptor::SpatialToPimPass::ReturnPathLoweringResult
|
||||
raptor::SpatialToPimPass::lowerComputeResultReturnPath(spatial::SpatCompute computeOp,
|
||||
OpResult result,
|
||||
Value yieldValue,
|
||||
IRRewriter& rewriter) {
|
||||
return lowerProducedValueReturnPath(computeOp.getOperation(), result, yieldValue, rewriter);
|
||||
}
|
||||
|
||||
void replaceReturnWithOutputBuffers(func::ReturnOp returnOp, IRRewriter& rewriter, ReturnPathState& state) {
|
||||
void raptor::SpatialToPimPass::replaceReturnWithOutputBuffers(func::ReturnOp returnOp, IRRewriter& rewriter) {
|
||||
auto markOwnedReturnChain = [&](Operation* op, auto&& markOwnedReturnChain) -> void {
|
||||
if (!op)
|
||||
return;
|
||||
@@ -575,13 +571,13 @@ void replaceReturnWithOutputBuffers(func::ReturnOp returnOp, IRRewriter& rewrite
|
||||
|
||||
if (isReturnHelperChainOp(op)) {
|
||||
Value source = op->getOperand(0);
|
||||
markOpToRemove(state, op);
|
||||
markOpToRemove(op);
|
||||
markOwnedReturnChain(source.getDefiningOp(), markOwnedReturnChain);
|
||||
return;
|
||||
}
|
||||
|
||||
if (auto computeOp = dyn_cast<spatial::SpatCompute>(op)) {
|
||||
markOpToRemove(state, computeOp);
|
||||
markOpToRemove(computeOp);
|
||||
if (!computeOp.getInputs().empty())
|
||||
for (Value input : computeOp.getInputs())
|
||||
markOwnedReturnChain(input.getDefiningOp(), markOwnedReturnChain);
|
||||
@@ -589,33 +585,33 @@ void replaceReturnWithOutputBuffers(func::ReturnOp returnOp, IRRewriter& rewrite
|
||||
}
|
||||
|
||||
if (auto concatOp = dyn_cast<tensor::ConcatOp>(op)) {
|
||||
markOpToRemove(state, concatOp);
|
||||
markOpToRemove(concatOp);
|
||||
for (Value operand : concatOp.getOperands())
|
||||
markOwnedReturnChain(operand.getDefiningOp(), markOwnedReturnChain);
|
||||
return;
|
||||
}
|
||||
|
||||
if (auto concatOp = dyn_cast<spatial::SpatConcatOp>(op)) {
|
||||
markOpToRemove(state, concatOp);
|
||||
markOpToRemove(concatOp);
|
||||
for (Value operand : concatOp.getInputs())
|
||||
markOwnedReturnChain(operand.getDefiningOp(), markOwnedReturnChain);
|
||||
return;
|
||||
}
|
||||
|
||||
if (auto concatOp = dyn_cast<pim::PimConcatOp>(op)) {
|
||||
markOpToRemove(state, concatOp);
|
||||
markOpToRemove(concatOp);
|
||||
for (Value operand : concatOp.getInputs())
|
||||
markOwnedReturnChain(operand.getDefiningOp(), markOwnedReturnChain);
|
||||
return;
|
||||
}
|
||||
|
||||
if (auto receiveOp = dyn_cast<spatial::SpatChannelReceiveOp>(op)) {
|
||||
markOpToRemove(state, receiveOp);
|
||||
markOpToRemove(receiveOp);
|
||||
return;
|
||||
}
|
||||
|
||||
if (auto receiveTensorOp = dyn_cast<spatial::SpatChannelReceiveTensorOp>(op))
|
||||
markOpToRemove(state, receiveTensorOp);
|
||||
markOpToRemove(receiveTensorOp);
|
||||
};
|
||||
|
||||
SmallVector<Value> originalOperands(returnOp.getOperands().begin(), returnOp.getOperands().end());
|
||||
@@ -624,7 +620,7 @@ void replaceReturnWithOutputBuffers(func::ReturnOp returnOp, IRRewriter& rewrite
|
||||
size_t orderWithinReturn = it.index();
|
||||
Operation* returnOperand = it.value().getDefiningOp();
|
||||
rewriter.setInsertionPoint(returnOp);
|
||||
Value outputTensor = state.outputTensors[orderWithinReturn](rewriter, loc);
|
||||
Value outputTensor = outputTensors[orderWithinReturn](rewriter, loc);
|
||||
rewriter.modifyOpInPlace(returnOp, [&] { returnOp.setOperand(orderWithinReturn, outputTensor); });
|
||||
markOwnedReturnChain(returnOperand, markOwnedReturnChain);
|
||||
}
|
||||
|
||||
@@ -1,43 +0,0 @@
|
||||
#pragma once
|
||||
|
||||
#include "mlir/Dialect/Func/IR/FuncOps.h"
|
||||
#include "mlir/IR/PatternMatch.h"
|
||||
|
||||
#include <functional>
|
||||
|
||||
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
||||
|
||||
namespace onnx_mlir {
|
||||
|
||||
using OutputTensorFactory = std::function<mlir::Value(mlir::IRRewriter& rewriter, mlir::Location loc)>;
|
||||
|
||||
struct ReturnPathState {
|
||||
llvm::SmallVectorImpl<OutputTensorFactory>& outputTensors;
|
||||
llvm::SmallVectorImpl<mlir::Operation*>& operationsToRemove;
|
||||
};
|
||||
|
||||
enum class ReturnPathLoweringResult {
|
||||
Handled,
|
||||
NotReturnPath,
|
||||
Failure
|
||||
};
|
||||
|
||||
void addReturnOutputBuffers(mlir::func::ReturnOp returnOp,
|
||||
mlir::IRRewriter& rewriter,
|
||||
llvm::SmallVectorImpl<OutputTensorFactory>& outputTensors);
|
||||
|
||||
ReturnPathLoweringResult lowerComputeResultReturnPath(spatial::SpatCompute computeOp,
|
||||
mlir::OpResult result,
|
||||
mlir::Value yieldValue,
|
||||
ReturnPathState& state,
|
||||
mlir::IRRewriter& rewriter);
|
||||
|
||||
ReturnPathLoweringResult lowerProducedValueReturnPath(mlir::Operation* producerOp,
|
||||
mlir::Value producedValue,
|
||||
mlir::Value storedValue,
|
||||
ReturnPathState& state,
|
||||
mlir::IRRewriter& rewriter);
|
||||
|
||||
void replaceReturnWithOutputBuffers(mlir::func::ReturnOp returnOp, mlir::IRRewriter& rewriter, ReturnPathState& state);
|
||||
|
||||
} // namespace onnx_mlir
|
||||
@@ -23,54 +23,28 @@
|
||||
#include <cassert>
|
||||
#include <utility>
|
||||
|
||||
#include "Common/PimCommon.hpp"
|
||||
#include "Conversion/ONNXToSpatial/Common/Common.hpp"
|
||||
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
|
||||
#include "src/Accelerators/PIM/Conversion/SpatialToPim/BatchCoreLoweringPatterns.hpp"
|
||||
#include "src/Accelerators/PIM/Conversion/SpatialToPim/ChannelLoweringPatterns.hpp"
|
||||
#include "src/Accelerators/PIM/Conversion/SpatialToPim/Common.hpp"
|
||||
#include "src/Accelerators/PIM/Conversion/SpatialToPim/CoreLoweringPatterns.hpp"
|
||||
#include "src/Accelerators/PIM/Conversion/SpatialToPim/GlobalTensorMaterialization.hpp"
|
||||
#include "src/Accelerators/PIM/Conversion/SpatialToPim/PhaseVerification.hpp"
|
||||
#include "src/Accelerators/PIM/Conversion/SpatialToPim/ReturnPathNormalization.hpp"
|
||||
#include "src/Accelerators/PIM/Conversion/SpatialToPim/TensorPackingPatterns.hpp"
|
||||
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
|
||||
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
||||
#include "src/Accelerators/PIM/Pass/PIMPasses.h"
|
||||
#include "Conversion/SpatialToPim/ChannelLoweringPatterns.hpp"
|
||||
#include "Conversion/SpatialToPim/Common.hpp"
|
||||
#include "Conversion/SpatialToPim/GlobalTensorMaterialization.hpp"
|
||||
#include "Conversion/SpatialToPim/PhaseVerification.hpp"
|
||||
#include "Conversion/SpatialToPim/TensorPackingPatterns.hpp"
|
||||
#include "Dialect/Pim/PimOps.hpp"
|
||||
#include "Dialect/Spatial/SpatialOps.hpp"
|
||||
#include "Pass/PIMPasses.h"
|
||||
#include "SpatialToPimPass.hpp"
|
||||
|
||||
using namespace mlir;
|
||||
using namespace onnx_mlir;
|
||||
using namespace pim;
|
||||
|
||||
namespace onnx_mlir {
|
||||
|
||||
namespace {
|
||||
namespace raptor {
|
||||
|
||||
#include "src/Accelerators/PIM/Conversion/SpatialToPim/SpatialToPim.hpp.inc"
|
||||
|
||||
struct SpatialToPimPass : PassWrapper<SpatialToPimPass, OperationPass<ModuleOp>> {
|
||||
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(SpatialToPimPass)
|
||||
StringRef getArgument() const override { return "convert-spatial-to-pim"; }
|
||||
StringRef getDescription() const override { return "Lower Spatial ops to PIM-ready format"; }
|
||||
|
||||
SpatialToPimPass() = default;
|
||||
SpatialToPimPass(const SpatialToPimPass& pass) {}
|
||||
|
||||
void runOnOperation() final;
|
||||
|
||||
private:
|
||||
SmallVector<OutputTensorFactory> outputTensors;
|
||||
size_t coreId = 0;
|
||||
SmallVector<Operation*> operationsToRemove;
|
||||
|
||||
LogicalResult allocateAndInitializeCoreLocalVariables(func::FuncOp funcOp, IRRewriter& rewriter);
|
||||
|
||||
void markOpToRemove(Operation* op);
|
||||
void eraseOpsToRemove();
|
||||
|
||||
void enlargeVMMOutTensorsToCrossbarSize(func::FuncOp funcOp, IRRewriter& rewriter);
|
||||
};
|
||||
|
||||
} // namespace
|
||||
} // namespace raptor
|
||||
|
||||
static memref::GlobalOp getOrCreateZeroGlobal(IRRewriter& rewriter, Location loc, RankedTensorType tensorType) {
|
||||
auto moduleOp = rewriter.getBlock()->getParentOp()->getParentOfType<ModuleOp>();
|
||||
@@ -150,8 +124,10 @@ padHVectorInputToCrossbarSize(IRRewriter& rewriter, Location loc, Value vector,
|
||||
return PimMemCopyOp::create(rewriter, loc, paddedType, zeroed, vector, zeroAttr, zeroAttr, sizeAttr).getOutput();
|
||||
}
|
||||
|
||||
void SpatialToPimPass::runOnOperation() {
|
||||
void onnx_mlir::raptor::SpatialToPimPass::runOnOperation() {
|
||||
coreId = 0;
|
||||
outputTensors.clear();
|
||||
operationsToRemove.clear();
|
||||
ModuleOp moduleOp = getOperation();
|
||||
MLIRContext* ctx = moduleOp.getContext();
|
||||
|
||||
@@ -197,18 +173,16 @@ void SpatialToPimPass::runOnOperation() {
|
||||
walkAndApplyPatterns(moduleOp, std::move(globalTensorPatterns));
|
||||
|
||||
auto returnOp = cast<func::ReturnOp>(funcOp.front().getTerminator());
|
||||
addReturnOutputBuffers(returnOp, rewriter, outputTensors);
|
||||
ReturnPathState returnPathState {outputTensors, operationsToRemove};
|
||||
addReturnOutputBuffers(returnOp, rewriter);
|
||||
if (failed(allocateAndInitializeCoreLocalVariables(funcOp, rewriter))) {
|
||||
funcOp.emitOpError("failed to allocate or initialize core-local tensors during Spatial-to-PIM lowering");
|
||||
signalPassFailure();
|
||||
return;
|
||||
}
|
||||
|
||||
CoreLoweringState coreLoweringState {coreId, outputTensors, operationsToRemove, constantFolder};
|
||||
for (auto computeOp : funcOp.getOps<spatial::SpatCompute>()) {
|
||||
markOpToRemove(computeOp);
|
||||
if (failed(lowerComputeOp(computeOp, coreLoweringState, rewriter))) {
|
||||
if (failed(lowerComputeOp(computeOp, rewriter, constantFolder))) {
|
||||
computeOp.emitOpError("failed to lower spat.compute to pim.core");
|
||||
signalPassFailure();
|
||||
return;
|
||||
@@ -217,7 +191,7 @@ void SpatialToPimPass::runOnOperation() {
|
||||
|
||||
for (auto computeBatchOp : funcOp.getOps<spatial::SpatComputeBatch>()) {
|
||||
markOpToRemove(computeBatchOp);
|
||||
if (failed(lowerComputeBatchOp(computeBatchOp, coreLoweringState, rewriter))) {
|
||||
if (failed(lowerComputeBatchOp(computeBatchOp, rewriter))) {
|
||||
computeBatchOp.emitOpError("failed to lower spat.compute_batch to pim.core_batch");
|
||||
signalPassFailure();
|
||||
return;
|
||||
@@ -266,7 +240,7 @@ void SpatialToPimPass::runOnOperation() {
|
||||
}
|
||||
|
||||
enlargeVMMOutTensorsToCrossbarSize(funcOp, rewriter);
|
||||
replaceReturnWithOutputBuffers(returnOp, rewriter, returnPathState);
|
||||
replaceReturnWithOutputBuffers(returnOp, rewriter);
|
||||
eraseOpsToRemove();
|
||||
|
||||
RewritePatternSet finalTensorPackingPatterns(ctx);
|
||||
@@ -309,7 +283,7 @@ void SpatialToPimPass::runOnOperation() {
|
||||
dumpModule(moduleOp, "pim0");
|
||||
}
|
||||
|
||||
void SpatialToPimPass::enlargeVMMOutTensorsToCrossbarSize(func::FuncOp funcOp, IRRewriter& rewriter) {
|
||||
void raptor::SpatialToPimPass::enlargeVMMOutTensorsToCrossbarSize(func::FuncOp funcOp, IRRewriter& rewriter) {
|
||||
OperationFolder constantFolder(funcOp.getContext());
|
||||
funcOp.walk([&](PimVMMOp vmmOp) {
|
||||
auto outputType = cast<RankedTensorType>(vmmOp.getOutput().getType());
|
||||
@@ -343,7 +317,8 @@ void SpatialToPimPass::enlargeVMMOutTensorsToCrossbarSize(func::FuncOp funcOp, I
|
||||
});
|
||||
}
|
||||
|
||||
LogicalResult SpatialToPimPass::allocateAndInitializeCoreLocalVariables(func::FuncOp funcOp, IRRewriter& rewriter) {
|
||||
LogicalResult raptor::SpatialToPimPass::allocateAndInitializeCoreLocalVariables(func::FuncOp funcOp,
|
||||
IRRewriter& rewriter) {
|
||||
Location loc = funcOp.getLoc();
|
||||
OperationFolder constantFolder(funcOp.getContext());
|
||||
|
||||
@@ -387,18 +362,18 @@ LogicalResult SpatialToPimPass::allocateAndInitializeCoreLocalVariables(func::Fu
|
||||
return success();
|
||||
}
|
||||
|
||||
void SpatialToPimPass::markOpToRemove(Operation* op) {
|
||||
void raptor::SpatialToPimPass::markOpToRemove(Operation* op) {
|
||||
if (!llvm::is_contained(operationsToRemove, op))
|
||||
operationsToRemove.push_back(op);
|
||||
}
|
||||
|
||||
void SpatialToPimPass::eraseOpsToRemove() {
|
||||
void raptor::SpatialToPimPass::eraseOpsToRemove() {
|
||||
for (Operation* op : operationsToRemove) {
|
||||
op->dropAllUses();
|
||||
op->erase();
|
||||
}
|
||||
}
|
||||
|
||||
std::unique_ptr<Pass> createSpatialToPimPass() { return std::make_unique<SpatialToPimPass>(); }
|
||||
std::unique_ptr<Pass> createSpatialToPimPass() { return std::make_unique<raptor::SpatialToPimPass>(); }
|
||||
|
||||
} // namespace onnx_mlir
|
||||
|
||||
@@ -0,0 +1,75 @@
|
||||
#pragma once
|
||||
|
||||
#include "mlir/Dialect/Arith/IR/Arith.h"
|
||||
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
|
||||
#include "mlir/Dialect/Func/IR/FuncOps.h"
|
||||
#include "mlir/Dialect/SCF/Utils/Utils.h"
|
||||
#include "mlir/IR/BuiltinOps.h"
|
||||
#include "mlir/IR/PatternMatch.h"
|
||||
#include "mlir/IR/Value.h"
|
||||
#include "mlir/Pass/Pass.h"
|
||||
#include "mlir/Transforms/FoldUtils.h"
|
||||
|
||||
#include "llvm/ADT/StringRef.h"
|
||||
|
||||
#include <functional>
|
||||
|
||||
#include "Conversion/ONNXToSpatial/Common/Common.hpp"
|
||||
#include "Conversion/SpatialToPim/Common.hpp"
|
||||
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
|
||||
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
||||
|
||||
namespace onnx_mlir {
|
||||
namespace raptor {
|
||||
|
||||
struct SpatialToPimPass : mlir::PassWrapper<SpatialToPimPass, mlir::OperationPass<mlir::ModuleOp>> {
|
||||
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(SpatialToPimPass)
|
||||
llvm::StringRef getArgument() const override { return "convert-spatial-to-pim"; }
|
||||
llvm::StringRef getDescription() const override { return "Lower Spatial ops to PIM-ready format"; }
|
||||
|
||||
SpatialToPimPass() = default;
|
||||
SpatialToPimPass(const SpatialToPimPass& pass) {}
|
||||
|
||||
void runOnOperation() final;
|
||||
|
||||
private:
|
||||
using OutputTensorFactory = std::function<mlir::Value(mlir::IRRewriter& rewriter, mlir::Location loc)>;
|
||||
|
||||
llvm::SmallVector<OutputTensorFactory> outputTensors;
|
||||
size_t coreId = 0;
|
||||
llvm::SmallVector<mlir::Operation*> operationsToRemove;
|
||||
|
||||
mlir::LogicalResult allocateAndInitializeCoreLocalVariables(mlir::func::FuncOp funcOp,
|
||||
mlir::IRRewriter& rewriter);
|
||||
mlir::LogicalResult lowerComputeOp(spatial::SpatCompute computeOp,
|
||||
mlir::IRRewriter& rewriter,
|
||||
mlir::OperationFolder& constantFolder);
|
||||
mlir::LogicalResult lowerComputeBatchOp(spatial::SpatComputeBatch computeBatchOp,
|
||||
mlir::IRRewriter& rewriter);
|
||||
|
||||
enum class ReturnPathLoweringResult {
|
||||
Handled,
|
||||
NotReturnPath,
|
||||
Failure
|
||||
};
|
||||
|
||||
void addReturnOutputBuffers(mlir::func::ReturnOp returnOp, mlir::IRRewriter& rewriter);
|
||||
ReturnPathLoweringResult lowerComputeResultReturnPath(spatial::SpatCompute computeOp,
|
||||
mlir::OpResult result,
|
||||
mlir::Value yieldValue,
|
||||
mlir::IRRewriter& rewriter);
|
||||
ReturnPathLoweringResult lowerProducedValueReturnPath(mlir::Operation* producerOp,
|
||||
mlir::Value producedValue,
|
||||
mlir::Value storedValue,
|
||||
mlir::IRRewriter& rewriter);
|
||||
void replaceReturnWithOutputBuffers(mlir::func::ReturnOp returnOp, mlir::IRRewriter& rewriter);
|
||||
|
||||
void markOpToRemove(mlir::Operation* op);
|
||||
void eraseOpsToRemove();
|
||||
|
||||
void enlargeVMMOutTensorsToCrossbarSize(mlir::func::FuncOp funcOp, mlir::IRRewriter& rewriter);
|
||||
};
|
||||
|
||||
} // namespace raptor
|
||||
|
||||
} // namespace onnx_mlir
|
||||
Reference in New Issue
Block a user