replace old convolution support in spatial (WIP)
This commit is contained in:
Submodule onnx-mlir updated: 82018d7ce5...eb54c2afc4
@@ -3,7 +3,9 @@ mlir_tablegen(ONNXToSpatial.hpp.inc -gen-rewriters "-I${ONNX_MLIR_SRC_ROOT}")
|
||||
add_public_tablegen_target(ONNXToSpatialIncGen)
|
||||
|
||||
add_onnx_mlir_library(OMONNXToSpatial
|
||||
Math/Gemm.hpp
|
||||
Math/Gemm.cpp
|
||||
Math/Conv.hpp
|
||||
Math/Conv.cpp
|
||||
Math/ExperimentalConv.cpp
|
||||
Math/ExperimentalGemm.cpp
|
||||
|
||||
@@ -1,583 +1,247 @@
|
||||
#include "mlir/Dialect/Arith/IR/Arith.h"
|
||||
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
||||
#include "mlir/Dialect/Tosa/IR/TosaOps.h"
|
||||
#include "mlir/IR/Block.h"
|
||||
#include "mlir/IR/BuiltinAttributes.h"
|
||||
#include "mlir/IR/BuiltinTypeInterfaces.h"
|
||||
#include "mlir/IR/BuiltinTypes.h"
|
||||
#include "mlir/IR/IRMapping.h"
|
||||
#include "mlir/IR/Location.h"
|
||||
#include "mlir/IR/PatternMatch.h"
|
||||
#include "mlir/IR/Types.h"
|
||||
#include "mlir/IR/Value.h"
|
||||
#include "mlir/Support/LLVM.h"
|
||||
#include "mlir/Transforms/DialectConversion.h"
|
||||
|
||||
#include "llvm/ADT/SmallVector.h"
|
||||
#include "llvm/Support/LogicalResult.h"
|
||||
|
||||
#include <cstddef>
|
||||
#include <memory>
|
||||
#include <unordered_map>
|
||||
#include <vector>
|
||||
#include <cassert>
|
||||
|
||||
#include "src/Accelerators/PIM/Common/PIMCommon.hpp"
|
||||
#include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp"
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ONNXToSpatialCommon.hpp"
|
||||
#include "Conv.hpp"
|
||||
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
||||
#include "src/Dialect/ONNX/ONNXOps.hpp"
|
||||
|
||||
using namespace mlir;
|
||||
using namespace std;
|
||||
|
||||
namespace onnx_mlir {
|
||||
|
||||
// NOTE:
|
||||
// This might be useful to re-implement this considering for loops.
|
||||
// neededXbars = krn_h * krn_w * inputTileCount * outputTileCount;
|
||||
LogicalResult ConvToGemm::matchAndRewrite(ONNXConvOp convOp,
|
||||
ONNXConvOpAdaptor convOpAdaptor,
|
||||
ConversionPatternRewriter& rewriter) const {
|
||||
Location loc = convOp.getLoc();
|
||||
Value x = convOpAdaptor.getX();
|
||||
Value w = convOpAdaptor.getW();
|
||||
Value b = convOpAdaptor.getB();
|
||||
|
||||
/**
|
||||
* @brief A momentary representation of a core, to be used within the tiling of
|
||||
* a convolution operation.
|
||||
*/
|
||||
class Core {
|
||||
public:
|
||||
Core(const size_t coreId, ConversionPatternRewriter& rewriter)
|
||||
: coreId(coreId), rewriter(rewriter) {}
|
||||
auto xType = cast<RankedTensorType>(x.getType());
|
||||
auto wType = cast<RankedTensorType>(w.getType());
|
||||
auto outType = cast<RankedTensorType>(convOp.getY().getType());
|
||||
|
||||
/**
|
||||
* @brief Add a MVM operation to the core.
|
||||
*
|
||||
* @param inputTile The input tile to the MVM operation.
|
||||
* @param xbarIndex The index of the crossbar weight to use.
|
||||
* @param outputTileId The id of the output tile.
|
||||
* @param mvmOutType The result's shape.
|
||||
* @return Value The result of the MVM operation.
|
||||
*/
|
||||
Value addMVM(Value inputTile, size_t xbarIndex, size_t outputTileId, Type mvmOutType) {
|
||||
// Use the inputTile as the reference location for the MVM operation.
|
||||
Location loc = inputTile.getLoc();
|
||||
assert("Only support static shapes" && xType.hasStaticShape() && wType.hasStaticShape() && outType.hasStaticShape());
|
||||
assert("Only support 2D convolution" && xType.getRank() == 4);
|
||||
assert("Only support batch size 1 for input" && xType.getDimSize(0) == 1);
|
||||
|
||||
// Move the insertion point to the end of the block.
|
||||
rewriter.setInsertionPointToEnd(block.get());
|
||||
// We need to understand what is group
|
||||
assert("Only support group=1" && convOp.getGroup() == 1);
|
||||
|
||||
// Add the inputTile to the block arguments, and to the operands.
|
||||
Value operand = operandMap.lookupOrNull(inputTile);
|
||||
if (not operand) {
|
||||
operand = block->addArgument(inputTile.getType(), loc);
|
||||
operands.push_back(inputTile);
|
||||
operandMap.map(inputTile, operand);
|
||||
}
|
||||
const int64_t numChannelsIn = xType.getDimSize(1);
|
||||
const int64_t xHeight = xType.getDimSize(2);
|
||||
const int64_t xWidth = xType.getDimSize(3);
|
||||
const int64_t numChannelsOut = wType.getDimSize(0);
|
||||
const int64_t wHeight = wType.getDimSize(2);
|
||||
const int64_t wWidth = wType.getDimSize(3);
|
||||
const int64_t outHeight = outType.getDimSize(2);
|
||||
const int64_t outWidth = outType.getDimSize(3);
|
||||
|
||||
// TODO: Compute the output type using the matrix, and check if `mvmOutType`
|
||||
// is correct.
|
||||
// Read optional conv attributes (ONNX defaults: stride=1, dilation=1, pad=0)
|
||||
auto getI64 = [](ArrayAttr arr, size_t idx) -> int64_t { return cast<IntegerAttr>(arr[idx]).getInt(); };
|
||||
|
||||
// Construct the MVM operation
|
||||
Value result = rewriter.create<spatial::SpatWeightedMVMOp>(loc, mvmOutType, xbarIndex, operand);
|
||||
const auto stridesAttr = convOp.getStrides();
|
||||
const auto dilationsAttr = convOp.getDilations();
|
||||
const auto padsAttr = convOp.getPads();
|
||||
|
||||
// Since we are within the same core and no computation can happen in
|
||||
// paralllel, we can just apply a linear reduction in case we have multiple
|
||||
// MVM operations for the same outputTile.
|
||||
auto lastMVM = outputTileToMVM.find(outputTileId);
|
||||
const int64_t strideHeight = stridesAttr ? getI64(*stridesAttr, 0) : 1;
|
||||
const int64_t strideWidth = stridesAttr ? getI64(*stridesAttr, 1) : 1;
|
||||
const int64_t dilationHeight = dilationsAttr ? getI64(*dilationsAttr, 0) : 1;
|
||||
const int64_t dilationWidth = dilationsAttr ? getI64(*dilationsAttr, 1) : 1;
|
||||
|
||||
// If an entry for this outputTile already exists, apply reduction.
|
||||
if (lastMVM != outputTileToMVM.end()) {
|
||||
// MVM results should have the same type for reduction.
|
||||
assert(lastMVM->second.getType() == result.getType());
|
||||
result = rewriter.create<spatial::SpatVAddOp>(loc, mvmOutType, lastMVM->second, result);
|
||||
}
|
||||
int64_t padHeightBegin = 0;
|
||||
int64_t padHeightEnd = 0;
|
||||
int64_t padWidthBegin = 0;
|
||||
int64_t padWidthEnd = 0;
|
||||
|
||||
outputTileToMVM[outputTileId] = result;
|
||||
return result;
|
||||
if (padsAttr) {
|
||||
padHeightBegin = getI64(*padsAttr, 0);
|
||||
padWidthBegin = getI64(*padsAttr, 1);
|
||||
padHeightEnd = getI64(*padsAttr, 2);
|
||||
padWidthEnd = getI64(*padsAttr, 3);
|
||||
}
|
||||
else {
|
||||
// Compute padding from auto_pad attribute
|
||||
const auto autoPad = convOp.getAutoPad();
|
||||
if (autoPad == "SAME_UPPER" || autoPad == "SAME_LOWER") {
|
||||
const int64_t effectiveKernelH = (wHeight - 1) * dilationHeight + 1;
|
||||
const int64_t effectiveKernelW = (wWidth - 1) * dilationWidth + 1;
|
||||
const int64_t totalPadH =
|
||||
std::max(static_cast<int64_t>(0), (outHeight - 1) * strideHeight + effectiveKernelH - xHeight);
|
||||
const int64_t totalPadW =
|
||||
std::max(static_cast<int64_t>(0), (outWidth - 1) * strideWidth + effectiveKernelW - xWidth);
|
||||
|
||||
/**
|
||||
* @brief Mark a result as remappable, and return a shared pointer to it.
|
||||
*
|
||||
* This function marks a result as remappable, and returns a shared pointer to
|
||||
* it. We need to keep track of these values to generate the YieldOp at a
|
||||
* later stage.
|
||||
*
|
||||
* @param result A result to track, for later remapping.
|
||||
* @return shared_ptr<Value> A shared pointer to the result.
|
||||
*/
|
||||
shared_ptr<Value> makeResultRemappable(Value result) {
|
||||
// Verify that the result is present in the block.
|
||||
assert(result.getDefiningOp()->getBlock() == block.get());
|
||||
|
||||
shared_ptr<mlir::Value> remappableResult = make_shared<Value>(result);
|
||||
|
||||
resultsToRemap.push_back(remappableResult);
|
||||
results.push_back(result);
|
||||
|
||||
return remappableResult;
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Add a remappable operand to the core, to merge partial results
|
||||
* inter-core.
|
||||
*
|
||||
* @param remappableOperand The operand to add.
|
||||
* @return Value The block argument representing the operand.
|
||||
*/
|
||||
Value addRemappableOperand(std::shared_ptr<Value> operand) {
|
||||
// Check that the operand is not already there.
|
||||
assert(not operandMap.contains(*operand));
|
||||
|
||||
Value argument = block->addArgument(operand->getType(), operand->getLoc());
|
||||
remappableOperands.push_back(operand);
|
||||
return argument;
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Generate a spatial::SpatWeightedCompute operation from the core.
|
||||
*
|
||||
* @param loc The location of the operation.
|
||||
* @return spatial::SpatWeightedCompute
|
||||
*/
|
||||
spatial::SpatWeightedCompute createWComputeOp(Location loc) {
|
||||
// Get the shape of the results.
|
||||
SmallVector<Type> resultTypes;
|
||||
for (const auto& value : results)
|
||||
resultTypes.push_back(value.getType());
|
||||
|
||||
// Create the WComputeOp, with non-remappable operands only.
|
||||
wcomputeOp = rewriter.create<spatial::SpatWeightedCompute>(loc, resultTypes, xbarWeights, operands);
|
||||
|
||||
// Add the body to the WComputeOp.
|
||||
Block* releasedBlock = block.release();
|
||||
wcomputeOp.getBody().push_back(releasedBlock);
|
||||
|
||||
// Add the `yieldOp` at the end, with the results.
|
||||
rewriter.setInsertionPointToEnd(releasedBlock);
|
||||
rewriter.create<spatial::SpatYieldOp>(loc, results);
|
||||
|
||||
return wcomputeOp;
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Remap the results to the WComputeOp results.
|
||||
*/
|
||||
void remapResults() {
|
||||
// Remap all the results to the WComputeOp results.
|
||||
assert(resultsToRemap.size() == wcomputeOp->getNumResults());
|
||||
for (size_t i = 0; i < resultsToRemap.size(); i++)
|
||||
*resultsToRemap[i] = wcomputeOp.getResult(i);
|
||||
}
|
||||
|
||||
void addRemappedOperands() {
|
||||
// Insert the remappableOperands (which were remapped in
|
||||
// `addRemappableOperand` of another Core)
|
||||
for (auto remappedValue : remappableOperands)
|
||||
wcomputeOp->insertOperands(wcomputeOp->getNumOperands(), *remappedValue);
|
||||
|
||||
// Update the wcomputeOp operandSegmentSize
|
||||
incrementWeightedComputeInputsSegmentSize(wcomputeOp, static_cast<int>(remappableOperands.size()));
|
||||
}
|
||||
|
||||
size_t addXbarWeight(Value weight) {
|
||||
assert(!isXbarsFull());
|
||||
xbarWeights.push_back(weight);
|
||||
return xbarWeights.size() - 1;
|
||||
}
|
||||
|
||||
bool isXbarsFull() {
|
||||
assert(xbarWeights.size() <= crossbarCountInCore);
|
||||
return xbarWeights.size() == crossbarCountInCore;
|
||||
}
|
||||
|
||||
bool isCoreEmpty() { return block->empty(); }
|
||||
|
||||
void dump() {
|
||||
// Print the coreId
|
||||
llvm::outs() << "Core " << coreId << ":\n";
|
||||
// Print the weights
|
||||
llvm::outs() << "Xbar Weights:\n";
|
||||
for (auto weight : xbarWeights)
|
||||
weight.dump();
|
||||
// Print the operands
|
||||
llvm::outs() << "Operands:\n";
|
||||
for (auto operand : operands)
|
||||
llvm::outs() << operand << "\n";
|
||||
|
||||
// Dump the body block
|
||||
for (auto& op : block->getOperations())
|
||||
op.dump();
|
||||
|
||||
// Print the results
|
||||
llvm::outs() << "Results:\n";
|
||||
for (auto result : results)
|
||||
llvm::outs() << result << "\n";
|
||||
}
|
||||
|
||||
const size_t coreId;
|
||||
|
||||
private:
|
||||
ConversionPatternRewriter& rewriter;
|
||||
|
||||
// Should these be set<Value> instead? But I need to keep the order
|
||||
vector<Value> operands;
|
||||
vector<std::shared_ptr<Value>> remappableOperands;
|
||||
|
||||
vector<Value> results;
|
||||
vector<std::shared_ptr<Value>> resultsToRemap;
|
||||
|
||||
// Maps from input tiles to the block operand
|
||||
IRMapping operandMap;
|
||||
|
||||
// Map from outputTileId to MVM operation producing it
|
||||
unordered_map<size_t, Value> outputTileToMVM;
|
||||
|
||||
vector<Value> xbarWeights;
|
||||
|
||||
unique_ptr<mlir::Block> block = make_unique<Block>();
|
||||
|
||||
spatial::SpatWeightedCompute wcomputeOp;
|
||||
};
|
||||
|
||||
struct ONNXConvOpTile : public OpConversionPattern<ONNXConvOp> {
|
||||
ONNXConvOpTile(MLIRContext* ctx)
|
||||
: OpConversionPattern(ctx) {}
|
||||
|
||||
struct Producer_t {
|
||||
Value value;
|
||||
shared_ptr<Core> core;
|
||||
};
|
||||
|
||||
LogicalResult
|
||||
matchAndRewrite(ONNXConvOp conv, ONNXConvOpAdaptor convAdaptor, ConversionPatternRewriter& rewriter) const final {
|
||||
ShapedType xShape = mlir::cast<ShapedType>(convAdaptor.getX().getType());
|
||||
ShapedType wShape = mlir::cast<ShapedType>(convAdaptor.getW().getType());
|
||||
ShapedType bShape = mlir::cast<ShapedType>(convAdaptor.getB().getType());
|
||||
ShapedType yShape = mlir::cast<ShapedType>(conv.getY().getType());
|
||||
|
||||
size_t stride_x, stride_y, dilation_x, dilation_y, pad_x, pad_y;
|
||||
unpackOptionalPairVector(conv.getStrides(), stride_x, stride_y);
|
||||
unpackOptionalPairVector(conv.getDilations(), dilation_x, dilation_y);
|
||||
|
||||
auto padUnpackError = unpackOptionalPadsVector(convAdaptor.getPads(), pad_x, pad_y);
|
||||
if (padUnpackError.has_value())
|
||||
return rewriter.notifyMatchFailure(conv, padUnpackError.value());
|
||||
|
||||
// TODO: Pad value at beginning and end of each dimension could be
|
||||
// different. We should handle this case.
|
||||
|
||||
// MapOperations mapOperation = MapOperations::None;
|
||||
//
|
||||
// // If we have just one user, and it is an activation funcion (or more in
|
||||
// // general a mapping operation) just inline it in the computeOps
|
||||
// auto firstUserOp = *conv->getUsers().begin();
|
||||
// if (conv->hasOneUse()) {
|
||||
// mapOperation = mlirOpToMapOperationEnum(firstUserOp);
|
||||
//
|
||||
// if (mapOperation == MapOperations::ONNXSoftmaxOp) {
|
||||
// return rewriter.notifyMatchFailure(
|
||||
// conv, "Softmax not supported as activation for convolutions.");
|
||||
// }
|
||||
// }
|
||||
|
||||
size_t input_h = GET_IMAGE_HEIGHT(xShape);
|
||||
size_t input_w = GET_IMAGE_WIDTH(xShape);
|
||||
size_t output_h = GET_IMAGE_HEIGHT(yShape);
|
||||
size_t output_w = GET_IMAGE_WIDTH(yShape);
|
||||
size_t krn_h = GET_KERNEL_HEIGHT(wShape);
|
||||
size_t krn_w = GET_KERNEL_WIDTH(wShape);
|
||||
|
||||
Location loc = conv.getLoc();
|
||||
|
||||
size_t inputTileCount = ceilIntegerDivide(GET_IMAGE_CHANNEL(xShape), crossbarSize.getValue());
|
||||
size_t inputTileRemainder = GET_IMAGE_CHANNEL(xShape) % crossbarSize;
|
||||
size_t outputTileCount = ceilIntegerDivide(GET_IMAGE_CHANNEL(yShape), crossbarSize.getValue());
|
||||
size_t outputTileRemainder = GET_IMAGE_CHANNEL(yShape) % crossbarSize;
|
||||
|
||||
// Tile the input tensor
|
||||
// Input tiles need to be indexed by:
|
||||
// a. Channel Tile
|
||||
// b. Pixel `x` position
|
||||
// c. Pixel `y` position
|
||||
// For example: inputTiles[channelTile][x][y]
|
||||
// Example complete input tensor: tensor<1x3x6x6xf32> (NxCxWxH)
|
||||
SmallVector<SmallVector<SmallVector<Value>>> inputTiles(
|
||||
inputTileCount, SmallVector<SmallVector<Value>>(input_w, SmallVector<Value>(input_h)));
|
||||
|
||||
auto resolveErrorOpt = resolveImgInputTiles(
|
||||
convAdaptor.getX(), inputTiles, inputTileCount, inputTileRemainder, input_h, input_h, rewriter);
|
||||
if (resolveErrorOpt.has_value())
|
||||
return rewriter.notifyMatchFailure(conv, *resolveErrorOpt);
|
||||
|
||||
SmallVector<OpFoldResult> strides = SmallVector<OpFoldResult>(4, rewriter.getIndexAttr(1));
|
||||
SmallVector<OpFoldResult> offsets = SmallVector<OpFoldResult>(4, rewriter.getIndexAttr(0));
|
||||
SmallVector<OpFoldResult> sizes = SmallVector<OpFoldResult> {rewriter.getIndexAttr(1),
|
||||
rewriter.getIndexAttr(crossbarSize),
|
||||
rewriter.getIndexAttr(1),
|
||||
rewriter.getIndexAttr(1)};
|
||||
|
||||
// Tile the weight tensor
|
||||
// Weight tiles need to be indexed by:
|
||||
// a. Filter Tile
|
||||
// b. Channel Tile
|
||||
// c. Kernel `x` position
|
||||
// d. Kernel `y` position
|
||||
// For example: weightTiles[filterTile][channelTile][x][y]
|
||||
// Example complete weight tensor: tensor<32x3x3x3xf32> (FxCxWxH)
|
||||
SmallVector<SmallVector<SmallVector<SmallVector<Value>>>> weightTiles(
|
||||
outputTileCount,
|
||||
SmallVector<SmallVector<SmallVector<Value>>>(inputTileCount,
|
||||
SmallVector<SmallVector<Value>>(krn_w, SmallVector<Value>(krn_h))));
|
||||
strides = SmallVector<OpFoldResult>(4, rewriter.getIndexAttr(1));
|
||||
offsets = SmallVector<OpFoldResult>(4, rewriter.getIndexAttr(0));
|
||||
sizes = {rewriter.getIndexAttr(crossbarSize),
|
||||
rewriter.getIndexAttr(crossbarSize),
|
||||
rewriter.getIndexAttr(1),
|
||||
rewriter.getIndexAttr(1)};
|
||||
for (size_t i = 0; i < outputTileCount; i++) {
|
||||
if (i == outputTileCount - 1 && outputTileRemainder != 0)
|
||||
sizes[0] = rewriter.getIndexAttr(outputTileRemainder);
|
||||
sizes[1] = rewriter.getIndexAttr(crossbarSize);
|
||||
offsets[0] = rewriter.getIndexAttr(i * crossbarSize);
|
||||
for (size_t j = 0; j < inputTileCount; j++) {
|
||||
if (j == inputTileCount - 1 && inputTileRemainder != 0)
|
||||
sizes[1] = rewriter.getIndexAttr(inputTileRemainder);
|
||||
for (size_t x = 0; x < krn_w; x++) {
|
||||
for (size_t y = 0; y < krn_h; y++) {
|
||||
offsets[1] = rewriter.getIndexAttr(j * crossbarSize);
|
||||
offsets[2] = rewriter.getIndexAttr(x);
|
||||
offsets[3] = rewriter.getIndexAttr(y);
|
||||
weightTiles[i][j][x][y] =
|
||||
rewriter.create<tensor::ExtractSliceOp>(loc, convAdaptor.getW(), offsets, sizes, strides);
|
||||
}
|
||||
}
|
||||
if (autoPad == "SAME_UPPER") {
|
||||
padHeightBegin = totalPadH / 2;
|
||||
padHeightEnd = totalPadH - padHeightBegin;
|
||||
padWidthBegin = totalPadW / 2;
|
||||
padWidthEnd = totalPadW - padWidthBegin;
|
||||
}
|
||||
else { // SAME_LOWER
|
||||
padHeightEnd = totalPadH / 2;
|
||||
padHeightBegin = totalPadH - padHeightEnd;
|
||||
padWidthEnd = totalPadW / 2;
|
||||
padWidthBegin = totalPadW - padWidthEnd;
|
||||
}
|
||||
}
|
||||
|
||||
/* Distribute the computation among many compute cores
|
||||
* Try to compute in-core the computation for each output tile, and reduce
|
||||
* over as few cores as possible
|
||||
*/
|
||||
|
||||
// Tile the output tensor
|
||||
// Output tiles need to be indexed by:
|
||||
// a. Filter Tile
|
||||
// b. Pixel `x` position
|
||||
// c. Pixel `y` position
|
||||
// For example: outputTiles[filterTile][x][y]
|
||||
// Example complete output tensor: tensor<1x32x3x3xf32> (NxFxWxH)
|
||||
SmallVector<SmallVector<SmallVector<shared_ptr<Value>>>> outputTiles(
|
||||
outputTileCount,
|
||||
SmallVector<SmallVector<shared_ptr<Value>>>(output_w, SmallVector<shared_ptr<Value>>(output_h, nullptr)));
|
||||
|
||||
size_t replicationFactor;
|
||||
if (!conv->hasAttr(REPLICATION_ATTR_NAME))
|
||||
replicationFactor = 1;
|
||||
else
|
||||
replicationFactor = conv->getAttrOfType<IntegerAttr>(REPLICATION_ATTR_NAME).getInt();
|
||||
// producers[outTile][out_x][out_y][producerIndex]
|
||||
vector<vector<vector<vector<Producer_t>>>> producers = vector<vector<vector<vector<Producer_t>>>>(
|
||||
outputTileCount,
|
||||
vector<vector<vector<Producer_t>>>(output_w, vector<vector<Producer_t>>(output_h, vector<Producer_t>())));
|
||||
|
||||
// Schedule in cores
|
||||
size_t coreId = 0;
|
||||
vector<shared_ptr<Core>> curCores(replicationFactor);
|
||||
for (size_t i = 0; i < replicationFactor; i++)
|
||||
curCores[i] = make_shared<Core>(coreId++, rewriter);
|
||||
|
||||
vector<shared_ptr<Core>> cores;
|
||||
|
||||
const size_t replicationSliceSize = ceilIntegerDivide(input_w, replicationFactor);
|
||||
|
||||
for (size_t krn_x = 0; krn_x < krn_h; krn_x++) {
|
||||
for (size_t krn_y = 0; krn_y < krn_w; krn_y++) {
|
||||
|
||||
RankedTensorType mvmOutType =
|
||||
RankedTensorType::get({1, static_cast<long>(crossbarSize), 1, 1}, bShape.getElementType());
|
||||
|
||||
for (size_t outTile = 0; outTile < outputTileCount; outTile++) {
|
||||
|
||||
if (outTile == outputTileCount - 1 && outputTileRemainder != 0)
|
||||
mvmOutType = mvmOutType.clone({1, static_cast<long>(outputTileRemainder), 1, 1});
|
||||
|
||||
for (size_t inTile = 0; inTile < inputTileCount; inTile++) {
|
||||
|
||||
vector<size_t> xbarIndexes(replicationFactor);
|
||||
for (size_t i = 0; i < replicationFactor; i++)
|
||||
xbarIndexes[i] = curCores[i]->addXbarWeight(weightTiles[outTile][inTile][krn_x][krn_y]);
|
||||
|
||||
size_t out_x = 0;
|
||||
for (size_t in_x = 0; in_x < input_w; in_x += stride_x) {
|
||||
size_t out_y = 0;
|
||||
|
||||
// I use `replicationFactor` cores. I divide the input_w into
|
||||
// `replicationFactor` slices, and each slice is distributed to a
|
||||
// core. `coreIndex` is the index of the core that will be used
|
||||
// for this slice
|
||||
size_t coreIndex = in_x / replicationSliceSize;
|
||||
assert(coreIndex < replicationFactor);
|
||||
|
||||
for (size_t in_y = 0; in_y < input_h; in_y += stride_y) {
|
||||
// Adjust the input based on the kernel
|
||||
int actual_in_x = in_x - ((int) krn_w / 2) + krn_x * dilation_x;
|
||||
int actual_in_y = in_y - ((int) krn_h / 2) + krn_y * dilation_y;
|
||||
|
||||
// Check if we are within the input image
|
||||
if (verifyWithinBoundsAndPaddings(input_w, input_h, actual_in_x, actual_in_y, pad_x, pad_y).failed()) {
|
||||
out_y++;
|
||||
continue;
|
||||
}
|
||||
|
||||
size_t outTileId = outTile * output_w * output_h + out_x * output_h + out_y;
|
||||
auto mvm = curCores[coreIndex]->addMVM(
|
||||
inputTiles[inTile][actual_in_x][actual_in_y], xbarIndexes[coreIndex], outTileId, mvmOutType);
|
||||
|
||||
producers[outTile][out_x][out_y].push_back({mvm, curCores[coreIndex]});
|
||||
|
||||
out_y++;
|
||||
}
|
||||
out_x++;
|
||||
}
|
||||
|
||||
// Computations for these crossbars are done, check if the cores
|
||||
// crossbars are fully used. If full, swap with new core
|
||||
for (size_t i = 0; i < replicationFactor; i++) {
|
||||
if (curCores[i]->isXbarsFull()) {
|
||||
cores.emplace_back(std::move(curCores[i]));
|
||||
curCores[i] = make_shared<Core>(coreId++, rewriter);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for (auto& curCore : curCores)
|
||||
if (curCore->isCoreEmpty() == false)
|
||||
cores.emplace_back(std::move(curCore));
|
||||
curCores.clear();
|
||||
// Now, do the reduction of each output pixel tile
|
||||
for (size_t outTile = 0; outTile < outputTileCount; outTile++) {
|
||||
for (size_t out_x = 0; out_x < output_w; out_x++) {
|
||||
for (size_t out_y = 0; out_y < output_h; out_y++) {
|
||||
// First, check if some producers are within the same core. If this is
|
||||
// true, `Core::addMVM` have already done the reduction within-core.
|
||||
// This means that we only need to consider the last producer for that
|
||||
// core.
|
||||
|
||||
std::unordered_map<size_t, Producer_t> withinCoreReducedProducers;
|
||||
for (auto producer : producers[outTile][out_x][out_y])
|
||||
withinCoreReducedProducers[producer.core->coreId] = producer;
|
||||
|
||||
// Now, we need to apply inter-core reduction
|
||||
|
||||
// Base case with one producer
|
||||
if (withinCoreReducedProducers.size() == 1) {
|
||||
// TODO: Add the bias and apply mapping (if present)
|
||||
|
||||
auto singleProducer = withinCoreReducedProducers.begin()->second;
|
||||
// Use last producer as the final result
|
||||
auto reducedValue = singleProducer.core->makeResultRemappable(singleProducer.value);
|
||||
outputTiles[outTile][out_x][out_y] = reducedValue;
|
||||
continue;
|
||||
}
|
||||
|
||||
// TODO: This is a linear reduction, not a tree reduction. We can do
|
||||
// better: a tree reduction would make more computations happen in
|
||||
// parallel.
|
||||
|
||||
Producer_t lastProducer = withinCoreReducedProducers.begin()->second;
|
||||
|
||||
auto it = withinCoreReducedProducers.begin();
|
||||
it++;
|
||||
while (it != withinCoreReducedProducers.end()) {
|
||||
|
||||
Producer_t curProducer = it->second;
|
||||
|
||||
shared_ptr<Core> core1;
|
||||
shared_ptr<Core> core2;
|
||||
Value core1Value;
|
||||
Value core2Value;
|
||||
|
||||
auto lastProducerCoreId = lastProducer.core->coreId;
|
||||
auto curProducerCoreId = curProducer.core->coreId;
|
||||
|
||||
assert(lastProducerCoreId != curProducerCoreId
|
||||
&& "We should have already applied within-core reduction, how "
|
||||
"could we have same cores here?");
|
||||
|
||||
// Sort the cores by coreId
|
||||
if (curProducerCoreId < lastProducerCoreId) {
|
||||
core1 = curProducer.core;
|
||||
core1Value = curProducer.value;
|
||||
core2 = lastProducer.core;
|
||||
core2Value = lastProducer.value;
|
||||
}
|
||||
else {
|
||||
core1 = lastProducer.core;
|
||||
core1Value = lastProducer.value;
|
||||
core2 = curProducer.core;
|
||||
core2Value = curProducer.value;
|
||||
}
|
||||
|
||||
auto newCoreRes = core1->makeResultRemappable(core1Value);
|
||||
auto secondCoreBlockArg = core2->addRemappableOperand(newCoreRes);
|
||||
|
||||
rewriter.setInsertionPointAfterValue(core2Value);
|
||||
Value vaddRes = rewriter.create<spatial::SpatVAddOp>(
|
||||
core2Value.getLoc(), core2Value.getType(), core2Value, secondCoreBlockArg);
|
||||
|
||||
lastProducer = {vaddRes, core2};
|
||||
|
||||
it++;
|
||||
}
|
||||
|
||||
// TODO: Add the bias and apply mapping (if present)
|
||||
|
||||
// Use last producer as the final result
|
||||
auto reducedValue = lastProducer.core->makeResultRemappable(lastProducer.value);
|
||||
outputTiles[outTile][out_x][out_y] = reducedValue;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Now, we need to turn the cores into a spatial::SpatWeightedCompute.
|
||||
rewriter.setInsertionPointAfter(conv);
|
||||
spatial::SpatWeightedCompute lastWComputeOp;
|
||||
for (auto& core : cores) {
|
||||
lastWComputeOp = core->createWComputeOp(loc);
|
||||
core->remapResults();
|
||||
rewriter.setInsertionPointAfter(lastWComputeOp);
|
||||
}
|
||||
|
||||
for (auto& core : cores)
|
||||
core->addRemappedOperands();
|
||||
|
||||
// Set the insertion point after the last WComputeOp.
|
||||
rewriter.setInsertionPointAfter(lastWComputeOp);
|
||||
SmallVector<Value> tilesToConcat;
|
||||
tilesToConcat.reserve(output_h * output_w * outputTileCount * crossbarSize);
|
||||
for (size_t outX = 0; outX < output_h; outX++)
|
||||
for (size_t outY = 0; outY < output_w; outY++)
|
||||
for (size_t outTile = 0; outTile < outputTileCount; outTile++)
|
||||
tilesToConcat.push_back(*outputTiles[outTile][outX][outY]);
|
||||
|
||||
Value outputImage = rewriter.create<spatial::SpatImgConcatOp>(loc, conv.getY().getType(), tilesToConcat);
|
||||
|
||||
// Value outputImage =
|
||||
// createImgConcatOp(outputTiles, rewriter, loc, Y.getType());
|
||||
|
||||
// If no mapping (activation) was applied, just replace ConvOp
|
||||
// if (mapOperation == MapOperations::None) {
|
||||
// rewriter.replaceOp(conv, outputImage);
|
||||
// } else {
|
||||
// // If mapping was applied, erase ConvOp and replace the mapping op
|
||||
// rewriter.eraseOp(conv);
|
||||
// rewriter.replaceOp(firstUserOp, outputImage);
|
||||
// }
|
||||
|
||||
return success();
|
||||
// "NOTSET" or "VALID" -> all pads stay 0
|
||||
}
|
||||
};
|
||||
|
||||
void populateTilingConvOpPattern(RewritePatternSet& patterns, MLIRContext* ctx) {
|
||||
patterns.insert<ONNXConvOpTile>(ctx);
|
||||
// im2col layout (flipped with respect to the standard, so filters sit in B = crossbar):
|
||||
// A (im2col): [numPatches, patchSize] -- one row per output spatial position
|
||||
// B (weights): [patchSize, cOut] -- W^T, stored in crossbar columns
|
||||
// Gemm output: [numPatches, cOut]
|
||||
const int64_t patchSize = numChannelsIn * wHeight * wWidth;
|
||||
const int64_t numPatches = outHeight * outWidth;
|
||||
|
||||
auto elemType = xType.getElementType();
|
||||
|
||||
// Pad input with zeros if needed:
|
||||
// [1, numChannelsIn, xHeight, xWidth] -> [1, numChannelsIn, xHeight+padHeight, xWidth+padWidth]
|
||||
if (padHeightBegin || padHeightEnd || padWidthBegin || padWidthEnd) {
|
||||
const int64_t paddedHeight = xHeight + padHeightBegin + padHeightEnd;
|
||||
const int64_t paddedWidth = xWidth + padWidthBegin + padWidthEnd;
|
||||
auto zero = arith::ConstantOp::create(rewriter, loc, elemType, rewriter.getFloatAttr(elemType, 0.0));
|
||||
auto paddedType = RankedTensorType::get({1, numChannelsIn, paddedHeight, paddedWidth}, elemType);
|
||||
SmallVector<OpFoldResult> lowPads = {rewriter.getIndexAttr(0),
|
||||
rewriter.getIndexAttr(0),
|
||||
rewriter.getIndexAttr(padHeightBegin),
|
||||
rewriter.getIndexAttr(padWidthBegin)};
|
||||
SmallVector<OpFoldResult> highPads = {rewriter.getIndexAttr(0),
|
||||
rewriter.getIndexAttr(0),
|
||||
rewriter.getIndexAttr(padHeightEnd),
|
||||
rewriter.getIndexAttr(padWidthEnd)};
|
||||
auto padOp = tensor::PadOp::create(rewriter, loc, paddedType, x, lowPads, highPads);
|
||||
auto* padBlock = new Block();
|
||||
for (int i = 0; i < 4; i++)
|
||||
padBlock->addArgument(rewriter.getIndexType(), loc);
|
||||
padOp.getRegion().push_back(padBlock);
|
||||
rewriter.setInsertionPointToStart(padBlock);
|
||||
tensor::YieldOp::create(rewriter, loc, zero.getResult());
|
||||
rewriter.setInsertionPointAfter(padOp);
|
||||
x = padOp.getResult();
|
||||
}
|
||||
|
||||
// Build im2col [numPatches, patchSize]:
|
||||
// For each output position (oh, ow), extract the patch from x
|
||||
auto rowType = RankedTensorType::get({1, patchSize}, elemType);
|
||||
SmallVector<Value> im2colRows;
|
||||
im2colRows.reserve(numPatches);
|
||||
|
||||
for (int64_t oh = 0; oh < outHeight; oh++) {
|
||||
for (int64_t ow = 0; ow < outWidth; ow++) {
|
||||
SmallVector<OpFoldResult> offsets = {rewriter.getIndexAttr(0),
|
||||
rewriter.getIndexAttr(0),
|
||||
rewriter.getIndexAttr(oh * strideHeight),
|
||||
rewriter.getIndexAttr(ow * strideWidth)};
|
||||
SmallVector<OpFoldResult> sizes = {rewriter.getIndexAttr(1),
|
||||
rewriter.getIndexAttr(numChannelsIn),
|
||||
rewriter.getIndexAttr(wHeight),
|
||||
rewriter.getIndexAttr(wWidth)};
|
||||
SmallVector<OpFoldResult> strides = {rewriter.getIndexAttr(1),
|
||||
rewriter.getIndexAttr(1),
|
||||
rewriter.getIndexAttr(dilationHeight),
|
||||
rewriter.getIndexAttr(dilationWidth)};
|
||||
auto patchType = RankedTensorType::get({1, numChannelsIn, wHeight, wWidth}, elemType);
|
||||
Value patch = tensor::ExtractSliceOp::create(rewriter, loc, patchType, x, offsets, sizes, strides);
|
||||
|
||||
// Flatten [1, numChannelsIn, wHeight, wWidth] -> [1, patchSize]
|
||||
Value row = tensor::CollapseShapeOp::create(rewriter,
|
||||
loc,
|
||||
rowType,
|
||||
patch,
|
||||
SmallVector<ReassociationIndices> {
|
||||
{0},
|
||||
{1, 2, 3}
|
||||
});
|
||||
im2colRows.push_back(row);
|
||||
}
|
||||
}
|
||||
|
||||
// Concatenate all rows: [numPatches, patchSize]
|
||||
Value im2col = tensor::ConcatOp::create(rewriter, loc, /*axis=*/0, im2colRows);
|
||||
|
||||
// Prepare weight matrix W for crossbar storage:
|
||||
// W: [numChannelsOut, numChannelsIn, wHeight, wWidth] -> [numChannelsOut, patchSize] -> [patchSize, numChannelsOut]
|
||||
auto wFlatType = RankedTensorType::get({numChannelsOut, patchSize}, wType.getElementType());
|
||||
auto wTransType = RankedTensorType::get({patchSize, numChannelsOut}, wType.getElementType());
|
||||
Value wFlat = tensor::CollapseShapeOp::create(rewriter,
|
||||
loc,
|
||||
wFlatType,
|
||||
w,
|
||||
SmallVector<ReassociationIndices> {
|
||||
{0},
|
||||
{1, 2, 3}
|
||||
});
|
||||
Value wTrans = ONNXTransposeOp::create(rewriter, loc, wTransType, wFlat, rewriter.getI64ArrayAttr({1, 0}));
|
||||
|
||||
// Reshape bias [numChannelsOut] -> [1, numChannelsOut] for Gemm C row-broadcasting, or use none
|
||||
bool hasB = !isa<ONNXNoneOp>(b.getDefiningOp());
|
||||
Value gemmC;
|
||||
if (hasB) {
|
||||
auto biasType = RankedTensorType::get({1, numChannelsOut}, cast<RankedTensorType>(b.getType()).getElementType());
|
||||
gemmC = tensor::ExpandShapeOp::create(rewriter,
|
||||
loc,
|
||||
biasType,
|
||||
b,
|
||||
SmallVector<ReassociationIndices> {
|
||||
{0, 1}
|
||||
});
|
||||
}
|
||||
else
|
||||
gemmC = ONNXNoneOp::create(rewriter, loc, rewriter.getNoneType());
|
||||
|
||||
// Gemm: A @ B + C = im2col @ W^T + b
|
||||
// [numPatches, patchSize] @ [patchSize, numChannelsOut] + [1, numChannelsOut] -> [numPatches, numChannelsOut]
|
||||
auto gemmOutType = RankedTensorType::get({numPatches, numChannelsOut}, outType.getElementType());
|
||||
auto gemmOp = ONNXGemmOp::create(rewriter,
|
||||
loc,
|
||||
gemmOutType,
|
||||
im2col,
|
||||
wTrans,
|
||||
gemmC,
|
||||
rewriter.getF32FloatAttr(1.0f),
|
||||
rewriter.getF32FloatAttr(1.0f),
|
||||
rewriter.getBoolAttr(false),
|
||||
rewriter.getBoolAttr(false));
|
||||
|
||||
Value gemmOut = gemmOp.getY();
|
||||
auto collectComputeOp =
|
||||
spatial::SpatWeightedCompute::create(rewriter, loc, convOp.getType(), SmallVector<Value>(), gemmOut);
|
||||
|
||||
auto* collectBlock = new Block();
|
||||
collectBlock->addArgument(gemmOut.getType(), loc);
|
||||
collectComputeOp.getBody().push_back(collectBlock);
|
||||
rewriter.setInsertionPointToStart(collectBlock);
|
||||
|
||||
auto gemmOutArg = collectBlock->getArguments().front();
|
||||
|
||||
// Restore to NCHW layout:
|
||||
// [numPatches, numChannelsOut]
|
||||
// -> [1, outHeight, outWidth, numChannelsOut]
|
||||
// -> [1, numChannelsOut, outHeight, outWidth]
|
||||
auto nhwcType = RankedTensorType::get({1, outHeight, outWidth, numChannelsOut}, outType.getElementType());
|
||||
Value nhwcOut = tensor::ExpandShapeOp::create(rewriter,
|
||||
loc,
|
||||
nhwcType,
|
||||
gemmOutArg,
|
||||
SmallVector<ReassociationIndices> {
|
||||
{0, 1, 2},
|
||||
{3}
|
||||
});
|
||||
Value nchwOut = ONNXTransposeOp::create(rewriter, loc, outType, nhwcOut, rewriter.getI64ArrayAttr({0, 3, 1, 2}));
|
||||
|
||||
spatial::SpatYieldOp::create(rewriter, loc, nchwOut);
|
||||
|
||||
rewriter.replaceOp(convOp, collectComputeOp);
|
||||
return success();
|
||||
}
|
||||
|
||||
/// Registers the im2col-based ConvToGemm lowering pattern into `patterns`.
void populateTilingConvOpPattern(RewritePatternSet& patterns, MLIRContext* ctx) {
  patterns.insert<ConvToGemm>(ctx);
}
|
||||
|
||||
} // namespace onnx_mlir
|
||||
|
||||
23
src/PIM/Conversion/ONNXToSpatial/Math/Conv.hpp
Normal file
23
src/PIM/Conversion/ONNXToSpatial/Math/Conv.hpp
Normal file
@@ -0,0 +1,23 @@
|
||||
#pragma once
|
||||
|
||||
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
||||
#include "mlir/Transforms/DialectConversion.h"
|
||||
|
||||
#include "llvm/Support/LogicalResult.h"
|
||||
|
||||
#include "src/Dialect/ONNX/ONNXOps.hpp"
|
||||
|
||||
namespace onnx_mlir {
|
||||
|
||||
/// Conversion pattern that lowers an `onnx.Conv` into an im2col-based
/// `onnx.Gemm`, whose result is collected inside a
/// `spatial.SpatWeightedCompute` region (see Conv.cpp for the rewrite).
struct ConvToGemm : mlir::OpConversionPattern<mlir::ONNXConvOp> {
  ConvToGemm(mlir::MLIRContext* ctx)
      : OpConversionPattern(ctx) {}

  /// Rewrites `convOp` using its adapted operands; fails the match when the
  /// convolution cannot be expressed as a single Gemm.
  mlir::LogicalResult matchAndRewrite(mlir::ONNXConvOp convOp,
      mlir::ONNXConvOpAdaptor convOpAdaptor,
      mlir::ConversionPatternRewriter& rewriter) const override;
};
|
||||
|
||||
/// Adds the conv-tiling conversion pattern(s) to `patterns`.
void populateTilingConvOpPattern(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);
|
||||
|
||||
} // namespace onnx_mlir
|
||||
583
src/PIM/Conversion/ONNXToSpatial/Math/ConvOld.cpp
Normal file
583
src/PIM/Conversion/ONNXToSpatial/Math/ConvOld.cpp
Normal file
@@ -0,0 +1,583 @@
|
||||
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
||||
#include "mlir/Dialect/Tosa/IR/TosaOps.h"
|
||||
#include "mlir/IR/Block.h"
|
||||
#include "mlir/IR/BuiltinAttributes.h"
|
||||
#include "mlir/IR/BuiltinTypeInterfaces.h"
|
||||
#include "mlir/IR/BuiltinTypes.h"
|
||||
#include "mlir/IR/IRMapping.h"
|
||||
#include "mlir/IR/Location.h"
|
||||
#include "mlir/IR/PatternMatch.h"
|
||||
#include "mlir/IR/Types.h"
|
||||
#include "mlir/IR/Value.h"
|
||||
#include "mlir/Support/LLVM.h"
|
||||
#include "mlir/Transforms/DialectConversion.h"
|
||||
|
||||
#include "llvm/ADT/SmallVector.h"
|
||||
#include "llvm/Support/LogicalResult.h"
|
||||
|
||||
#include <cstddef>
|
||||
#include <memory>
|
||||
#include <unordered_map>
|
||||
#include <vector>
|
||||
|
||||
#include "src/Accelerators/PIM/Common/PIMCommon.hpp"
|
||||
#include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp"
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ONNXToSpatialCommon.hpp"
|
||||
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
||||
#include "src/Dialect/ONNX/ONNXOps.hpp"
|
||||
|
||||
using namespace mlir;
|
||||
using namespace std;
|
||||
|
||||
namespace onnx_mlir {
|
||||
|
||||
// NOTE:
|
||||
// This might be useful to re-implement this considering for loops.
|
||||
// neededXbars = krn_h * krn_w * inputTileCount * outputTileCount;
|
||||
|
||||
/**
|
||||
* @brief A momentary representation of a core, to be used within the tiling of
|
||||
* a convolution operation.
|
||||
*/
|
||||
class Core {
|
||||
public:
|
||||
Core(const size_t coreId, ConversionPatternRewriter& rewriter)
|
||||
: coreId(coreId), rewriter(rewriter) {}
|
||||
|
||||
/**
|
||||
* @brief Add a MVM operation to the core.
|
||||
*
|
||||
* @param inputTile The input tile to the MVM operation.
|
||||
* @param xbarIndex The index of the crossbar weight to use.
|
||||
* @param outputTileId The id of the output tile.
|
||||
* @param mvmOutType The result's shape.
|
||||
* @return Value The result of the MVM operation.
|
||||
*/
|
||||
Value addMVM(Value inputTile, size_t xbarIndex, size_t outputTileId, Type mvmOutType) {
|
||||
// Use the inputTile as the reference location for the MVM operation.
|
||||
Location loc = inputTile.getLoc();
|
||||
|
||||
// Move the insertion point to the end of the block.
|
||||
rewriter.setInsertionPointToEnd(block.get());
|
||||
|
||||
// Add the inputTile to the block arguments, and to the operands.
|
||||
Value operand = operandMap.lookupOrNull(inputTile);
|
||||
if (not operand) {
|
||||
operand = block->addArgument(inputTile.getType(), loc);
|
||||
operands.push_back(inputTile);
|
||||
operandMap.map(inputTile, operand);
|
||||
}
|
||||
|
||||
// TODO: Compute the output type using the matrix, and check if `mvmOutType`
|
||||
// is correct.
|
||||
|
||||
// Construct the MVM operation
|
||||
Value result = rewriter.create<spatial::SpatWeightedMVMOp>(loc, mvmOutType, xbarIndex, operand);
|
||||
|
||||
// Since we are within the same core and no computation can happen in
|
||||
// paralllel, we can just apply a linear reduction in case we have multiple
|
||||
// MVM operations for the same outputTile.
|
||||
auto lastMVM = outputTileToMVM.find(outputTileId);
|
||||
|
||||
// If an entry for this outputTile already exists, apply reduction.
|
||||
if (lastMVM != outputTileToMVM.end()) {
|
||||
// MVM results should have the same type for reduction.
|
||||
assert(lastMVM->second.getType() == result.getType());
|
||||
result = rewriter.create<spatial::SpatVAddOp>(loc, mvmOutType, lastMVM->second, result);
|
||||
}
|
||||
|
||||
outputTileToMVM[outputTileId] = result;
|
||||
return result;
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Mark a result as remappable, and return a shared pointer to it.
|
||||
*
|
||||
* This function marks a result as remappable, and returns a shared pointer to
|
||||
* it. We need to keep track of these values to generate the YieldOp at a
|
||||
* later stage.
|
||||
*
|
||||
* @param result A result to track, for later remapping.
|
||||
* @return shared_ptr<Value> A shared pointer to the result.
|
||||
*/
|
||||
shared_ptr<Value> makeResultRemappable(Value result) {
|
||||
// Verify that the result is present in the block.
|
||||
assert(result.getDefiningOp()->getBlock() == block.get());
|
||||
|
||||
shared_ptr<mlir::Value> remappableResult = make_shared<Value>(result);
|
||||
|
||||
resultsToRemap.push_back(remappableResult);
|
||||
results.push_back(result);
|
||||
|
||||
return remappableResult;
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Add a remappable operand to the core, to merge partial results
|
||||
* inter-core.
|
||||
*
|
||||
* @param remappableOperand The operand to add.
|
||||
* @return Value The block argument representing the operand.
|
||||
*/
|
||||
Value addRemappableOperand(std::shared_ptr<Value> operand) {
|
||||
// Check that the operand is not already there.
|
||||
assert(not operandMap.contains(*operand));
|
||||
|
||||
Value argument = block->addArgument(operand->getType(), operand->getLoc());
|
||||
remappableOperands.push_back(operand);
|
||||
return argument;
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Generate a spatial::SpatWeightedCompute operation from the core.
|
||||
*
|
||||
* @param loc The location of the operation.
|
||||
* @return spatial::SpatWeightedCompute
|
||||
*/
|
||||
spatial::SpatWeightedCompute createWComputeOp(Location loc) {
|
||||
// Get the shape of the results.
|
||||
SmallVector<Type> resultTypes;
|
||||
for (const auto& value : results)
|
||||
resultTypes.push_back(value.getType());
|
||||
|
||||
// Create the WComputeOp, with non-remappable operands only.
|
||||
wcomputeOp = rewriter.create<spatial::SpatWeightedCompute>(loc, resultTypes, xbarWeights, operands);
|
||||
|
||||
// Add the body to the WComputeOp.
|
||||
Block* releasedBlock = block.release();
|
||||
wcomputeOp.getBody().push_back(releasedBlock);
|
||||
|
||||
// Add the `yieldOp` at the end, with the results.
|
||||
rewriter.setInsertionPointToEnd(releasedBlock);
|
||||
rewriter.create<spatial::SpatYieldOp>(loc, results);
|
||||
|
||||
return wcomputeOp;
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Remap the results to the WComputeOp results.
|
||||
*/
|
||||
void remapResults() {
|
||||
// Remap all the results to the WComputeOp results.
|
||||
assert(resultsToRemap.size() == wcomputeOp->getNumResults());
|
||||
for (size_t i = 0; i < resultsToRemap.size(); i++)
|
||||
*resultsToRemap[i] = wcomputeOp.getResult(i);
|
||||
}
|
||||
|
||||
void addRemappedOperands() {
|
||||
// Insert the remappableOperands (which were remapped in
|
||||
// `addRemappableOperand` of another Core)
|
||||
for (auto remappedValue : remappableOperands)
|
||||
wcomputeOp->insertOperands(wcomputeOp->getNumOperands(), *remappedValue);
|
||||
|
||||
// Update the wcomputeOp operandSegmentSize
|
||||
incrementWeightedComputeInputsSegmentSize(wcomputeOp, static_cast<int>(remappableOperands.size()));
|
||||
}
|
||||
|
||||
size_t addXbarWeight(Value weight) {
|
||||
assert(!isXbarsFull());
|
||||
xbarWeights.push_back(weight);
|
||||
return xbarWeights.size() - 1;
|
||||
}
|
||||
|
||||
bool isXbarsFull() {
|
||||
assert(xbarWeights.size() <= crossbarCountInCore);
|
||||
return xbarWeights.size() == crossbarCountInCore;
|
||||
}
|
||||
|
||||
bool isCoreEmpty() { return block->empty(); }
|
||||
|
||||
void dump() {
|
||||
// Print the coreId
|
||||
llvm::outs() << "Core " << coreId << ":\n";
|
||||
// Print the weights
|
||||
llvm::outs() << "Xbar Weights:\n";
|
||||
for (auto weight : xbarWeights)
|
||||
weight.dump();
|
||||
// Print the operands
|
||||
llvm::outs() << "Operands:\n";
|
||||
for (auto operand : operands)
|
||||
llvm::outs() << operand << "\n";
|
||||
|
||||
// Dump the body block
|
||||
for (auto& op : block->getOperations())
|
||||
op.dump();
|
||||
|
||||
// Print the results
|
||||
llvm::outs() << "Results:\n";
|
||||
for (auto result : results)
|
||||
llvm::outs() << result << "\n";
|
||||
}
|
||||
|
||||
const size_t coreId;
|
||||
|
||||
private:
|
||||
ConversionPatternRewriter& rewriter;
|
||||
|
||||
// Should these be set<Value> instead? But I need to keep the order
|
||||
vector<Value> operands;
|
||||
vector<std::shared_ptr<Value>> remappableOperands;
|
||||
|
||||
vector<Value> results;
|
||||
vector<std::shared_ptr<Value>> resultsToRemap;
|
||||
|
||||
// Maps from input tiles to the block operand
|
||||
IRMapping operandMap;
|
||||
|
||||
// Map from outputTileId to MVM operation producing it
|
||||
unordered_map<size_t, Value> outputTileToMVM;
|
||||
|
||||
vector<Value> xbarWeights;
|
||||
|
||||
unique_ptr<mlir::Block> block = make_unique<Block>();
|
||||
|
||||
spatial::SpatWeightedCompute wcomputeOp;
|
||||
};
|
||||
|
||||
/**
 * @brief Lowers an `onnx.Conv` into channel-tiled MVM operations distributed
 * across `Core`s, which are then materialized as
 * `spatial::SpatWeightedCompute` ops and concatenated back into the output
 * image.
 */
struct ConvToManyGemms : public OpConversionPattern<ONNXConvOp> {
  ConvToManyGemms(MLIRContext* ctx)
      : OpConversionPattern(ctx) {}

  /// A partial result for an output pixel tile, paired with the core that
  /// produced it.
  struct Producer_t {
    Value value;
    shared_ptr<Core> core;
  };

  LogicalResult
  matchAndRewrite(ONNXConvOp conv, ONNXConvOpAdaptor convAdaptor, ConversionPatternRewriter& rewriter) const final {
    ShapedType xShape = mlir::cast<ShapedType>(convAdaptor.getX().getType());
    ShapedType wShape = mlir::cast<ShapedType>(convAdaptor.getW().getType());
    ShapedType bShape = mlir::cast<ShapedType>(convAdaptor.getB().getType());
    ShapedType yShape = mlir::cast<ShapedType>(conv.getY().getType());

    size_t stride_x, stride_y, dilation_x, dilation_y, pad_x, pad_y;
    unpackOptionalPairVector(conv.getStrides(), stride_x, stride_y);
    unpackOptionalPairVector(conv.getDilations(), dilation_x, dilation_y);

    auto padUnpackError = unpackOptionalPadsVector(convAdaptor.getPads(), pad_x, pad_y);
    if (padUnpackError.has_value())
      return rewriter.notifyMatchFailure(conv, padUnpackError.value());

    // TODO: Pad value at beginning and end of each dimension could be
    // different. We should handle this case.

    // NOTE(review): inlining a single activation-function user into the
    // compute ops was previously sketched here (commented out); see
    // repository history if that feature is revived.

    size_t input_h = GET_IMAGE_HEIGHT(xShape);
    size_t input_w = GET_IMAGE_WIDTH(xShape);
    size_t output_h = GET_IMAGE_HEIGHT(yShape);
    size_t output_w = GET_IMAGE_WIDTH(yShape);
    size_t krn_h = GET_KERNEL_HEIGHT(wShape);
    size_t krn_w = GET_KERNEL_WIDTH(wShape);

    Location loc = conv.getLoc();

    // Channel tiling: each tile holds at most `crossbarSize` channels; the
    // remainder (possibly 0) is the size of the last, partial tile.
    size_t inputTileCount = ceilIntegerDivide(GET_IMAGE_CHANNEL(xShape), crossbarSize.getValue());
    size_t inputTileRemainder = GET_IMAGE_CHANNEL(xShape) % crossbarSize;
    size_t outputTileCount = ceilIntegerDivide(GET_IMAGE_CHANNEL(yShape), crossbarSize.getValue());
    size_t outputTileRemainder = GET_IMAGE_CHANNEL(yShape) % crossbarSize;

    // Tile the input tensor.
    // Input tiles are indexed as inputTiles[channelTile][x][y].
    // Example complete input tensor: tensor<1x3x6x6xf32> (NxCxWxH)
    SmallVector<SmallVector<SmallVector<Value>>> inputTiles(
        inputTileCount, SmallVector<SmallVector<Value>>(input_w, SmallVector<Value>(input_h)));

    // BUGFIX: `input_h` used to be passed twice; the tile container is
    // dimensioned (input_w, input_h), so pass width then height. (Assumes
    // resolveImgInputTiles takes (width, height) matching the container
    // layout -- TODO confirm against its declaration.)
    auto resolveErrorOpt = resolveImgInputTiles(
        convAdaptor.getX(), inputTiles, inputTileCount, inputTileRemainder, input_w, input_h, rewriter);
    if (resolveErrorOpt.has_value())
      return rewriter.notifyMatchFailure(conv, *resolveErrorOpt);

    // Tile the weight tensor.
    // Weight tiles are indexed as weightTiles[filterTile][channelTile][x][y].
    // Example complete weight tensor: tensor<32x3x3x3xf32> (FxCxWxH)
    SmallVector<SmallVector<SmallVector<SmallVector<Value>>>> weightTiles(
        outputTileCount,
        SmallVector<SmallVector<SmallVector<Value>>>(inputTileCount,
            SmallVector<SmallVector<Value>>(krn_w, SmallVector<Value>(krn_h))));

    // (A dead, immediately-overwritten initialization of these vectors was
    // removed; they are only used by the weight-slicing loop below.)
    SmallVector<OpFoldResult> strides(4, rewriter.getIndexAttr(1));
    SmallVector<OpFoldResult> offsets(4, rewriter.getIndexAttr(0));
    SmallVector<OpFoldResult> sizes = {rewriter.getIndexAttr(crossbarSize),
        rewriter.getIndexAttr(crossbarSize),
        rewriter.getIndexAttr(1),
        rewriter.getIndexAttr(1)};
    for (size_t i = 0; i < outputTileCount; i++) {
      if (i == outputTileCount - 1 && outputTileRemainder != 0)
        sizes[0] = rewriter.getIndexAttr(outputTileRemainder);
      // Reset the channel extent: the previous i-iteration may have shrunk it
      // for the last input tile.
      sizes[1] = rewriter.getIndexAttr(crossbarSize);
      offsets[0] = rewriter.getIndexAttr(i * crossbarSize);
      for (size_t j = 0; j < inputTileCount; j++) {
        if (j == inputTileCount - 1 && inputTileRemainder != 0)
          sizes[1] = rewriter.getIndexAttr(inputTileRemainder);
        for (size_t x = 0; x < krn_w; x++) {
          for (size_t y = 0; y < krn_h; y++) {
            offsets[1] = rewriter.getIndexAttr(j * crossbarSize);
            offsets[2] = rewriter.getIndexAttr(x);
            offsets[3] = rewriter.getIndexAttr(y);
            weightTiles[i][j][x][y] =
                rewriter.create<tensor::ExtractSliceOp>(loc, convAdaptor.getW(), offsets, sizes, strides);
          }
        }
      }
    }

    /* Distribute the computation among many compute cores.
     * Try to compute in-core the computation for each output tile, and reduce
     * over as few cores as possible.
     */

    // Output tiles are indexed as outputTiles[filterTile][x][y].
    // Example complete output tensor: tensor<1x32x3x3xf32> (NxFxWxH)
    SmallVector<SmallVector<SmallVector<shared_ptr<Value>>>> outputTiles(
        outputTileCount,
        SmallVector<SmallVector<shared_ptr<Value>>>(output_w, SmallVector<shared_ptr<Value>>(output_h, nullptr)));

    // Replication factor (defaults to 1 when the attribute is absent).
    size_t replicationFactor;
    if (!conv->hasAttr(REPLICATION_ATTR_NAME))
      replicationFactor = 1;
    else
      replicationFactor = conv->getAttrOfType<IntegerAttr>(REPLICATION_ATTR_NAME).getInt();

    // producers[outTile][out_x][out_y][producerIndex]
    vector<vector<vector<vector<Producer_t>>>> producers = vector<vector<vector<vector<Producer_t>>>>(
        outputTileCount,
        vector<vector<vector<Producer_t>>>(output_w, vector<vector<Producer_t>>(output_h, vector<Producer_t>())));

    // Schedule in cores.
    size_t coreId = 0;
    vector<shared_ptr<Core>> curCores(replicationFactor);
    for (size_t i = 0; i < replicationFactor; i++)
      curCores[i] = make_shared<Core>(coreId++, rewriter);

    vector<shared_ptr<Core>> cores;

    const size_t replicationSliceSize = ceilIntegerDivide(input_w, replicationFactor);

    for (size_t krn_x = 0; krn_x < krn_h; krn_x++) {
      for (size_t krn_y = 0; krn_y < krn_w; krn_y++) {

        RankedTensorType mvmOutType =
            RankedTensorType::get({1, static_cast<long>(crossbarSize), 1, 1}, bShape.getElementType());

        for (size_t outTile = 0; outTile < outputTileCount; outTile++) {

          // The last output tile may be partial.
          if (outTile == outputTileCount - 1 && outputTileRemainder != 0)
            mvmOutType = mvmOutType.clone({1, static_cast<long>(outputTileRemainder), 1, 1});

          for (size_t inTile = 0; inTile < inputTileCount; inTile++) {

            // Every replica core stores its own copy of this weight tile.
            vector<size_t> xbarIndexes(replicationFactor);
            for (size_t i = 0; i < replicationFactor; i++)
              xbarIndexes[i] = curCores[i]->addXbarWeight(weightTiles[outTile][inTile][krn_x][krn_y]);

            size_t out_x = 0;
            for (size_t in_x = 0; in_x < input_w; in_x += stride_x) {
              size_t out_y = 0;

              // input_w is split into `replicationFactor` slices, each
              // distributed to one replica core; `coreIndex` selects the core
              // that owns this slice.
              size_t coreIndex = in_x / replicationSliceSize;
              assert(coreIndex < replicationFactor);

              for (size_t in_y = 0; in_y < input_h; in_y += stride_y) {
                // Adjust the input position for the kernel offset/dilation.
                int actual_in_x = in_x - ((int) krn_w / 2) + krn_x * dilation_x;
                int actual_in_y = in_y - ((int) krn_h / 2) + krn_y * dilation_y;

                // Skip positions outside the (padded) input image.
                if (verifyWithinBoundsAndPaddings(input_w, input_h, actual_in_x, actual_in_y, pad_x, pad_y).failed()) {
                  out_y++;
                  continue;
                }

                // Linearized id of the output pixel tile.
                size_t outTileId = outTile * output_w * output_h + out_x * output_h + out_y;
                auto mvm = curCores[coreIndex]->addMVM(
                    inputTiles[inTile][actual_in_x][actual_in_y], xbarIndexes[coreIndex], outTileId, mvmOutType);

                producers[outTile][out_x][out_y].push_back({mvm, curCores[coreIndex]});

                out_y++;
              }
              out_x++;
            }

            // Computations for these crossbars are done; if a replica core's
            // crossbars are fully used, retire it and start a fresh one.
            for (size_t i = 0; i < replicationFactor; i++) {
              if (curCores[i]->isXbarsFull()) {
                cores.emplace_back(std::move(curCores[i]));
                curCores[i] = make_shared<Core>(coreId++, rewriter);
              }
            }
          }
        }
      }
    }

    // Retire the remaining, partially-filled cores.
    for (auto& curCore : curCores)
      if (curCore->isCoreEmpty() == false)
        cores.emplace_back(std::move(curCore));
    curCores.clear();

    // Now, do the reduction of each output pixel tile.
    for (size_t outTile = 0; outTile < outputTileCount; outTile++) {
      for (size_t out_x = 0; out_x < output_w; out_x++) {
        for (size_t out_y = 0; out_y < output_h; out_y++) {
          // Producers living on the same core were already reduced in-core by
          // `Core::addMVM`, so only the last producer per core matters.
          std::unordered_map<size_t, Producer_t> withinCoreReducedProducers;
          for (const auto& producer : producers[outTile][out_x][out_y])
            withinCoreReducedProducers[producer.core->coreId] = producer;

          // Inter-core reduction.

          // Base case with one producer.
          if (withinCoreReducedProducers.size() == 1) {
            // TODO: Add the bias and apply mapping (if present)

            auto singleProducer = withinCoreReducedProducers.begin()->second;
            auto reducedValue = singleProducer.core->makeResultRemappable(singleProducer.value);
            outputTiles[outTile][out_x][out_y] = reducedValue;
            continue;
          }

          // TODO: This is a linear reduction, not a tree reduction. A tree
          // reduction would expose more parallelism.

          Producer_t lastProducer = withinCoreReducedProducers.begin()->second;

          auto it = withinCoreReducedProducers.begin();
          it++;
          while (it != withinCoreReducedProducers.end()) {

            Producer_t curProducer = it->second;

            shared_ptr<Core> core1;
            shared_ptr<Core> core2;
            Value core1Value;
            Value core2Value;

            auto lastProducerCoreId = lastProducer.core->coreId;
            auto curProducerCoreId = curProducer.core->coreId;

            assert(lastProducerCoreId != curProducerCoreId
                && "We should have already applied within-core reduction, how "
                   "could we have same cores here?");

            // Sort the pair by coreId so the add always happens on the
            // higher-numbered core.
            if (curProducerCoreId < lastProducerCoreId) {
              core1 = curProducer.core;
              core1Value = curProducer.value;
              core2 = lastProducer.core;
              core2Value = lastProducer.value;
            }
            else {
              core1 = lastProducer.core;
              core1Value = lastProducer.value;
              core2 = curProducer.core;
              core2Value = curProducer.value;
            }

            // core1 exports its partial result; core2 imports it as a block
            // argument and adds it to its own partial result.
            auto newCoreRes = core1->makeResultRemappable(core1Value);
            auto secondCoreBlockArg = core2->addRemappableOperand(newCoreRes);

            rewriter.setInsertionPointAfterValue(core2Value);
            Value vaddRes = rewriter.create<spatial::SpatVAddOp>(
                core2Value.getLoc(), core2Value.getType(), core2Value, secondCoreBlockArg);

            lastProducer = {vaddRes, core2};

            it++;
          }

          // TODO: Add the bias and apply mapping (if present)

          // Use the last producer as the final result.
          auto reducedValue = lastProducer.core->makeResultRemappable(lastProducer.value);
          outputTiles[outTile][out_x][out_y] = reducedValue;
        }
      }
    }

    // Materialize each core as a spatial::SpatWeightedCompute.
    rewriter.setInsertionPointAfter(conv);
    spatial::SpatWeightedCompute lastWComputeOp;
    for (auto& core : cores) {
      lastWComputeOp = core->createWComputeOp(loc);
      core->remapResults();
      rewriter.setInsertionPointAfter(lastWComputeOp);
    }

    // Only after all ops exist can the inter-core operands be wired up.
    for (auto& core : cores)
      core->addRemappedOperands();

    // Set the insertion point after the last WComputeOp.
    rewriter.setInsertionPointAfter(lastWComputeOp);
    SmallVector<Value> tilesToConcat;
    // BUGFIX: reserve the number of tiles actually pushed (one Value per
    // output pixel tile); the old reserve over-allocated by a factor of
    // `crossbarSize`.
    tilesToConcat.reserve(output_h * output_w * outputTileCount);
    // BUGFIX: iterate x over output_w and y over output_h, matching the
    // outputTiles container layout (w, h) and every loop above; the previous
    // bounds were swapped and only worked for square outputs.
    for (size_t outX = 0; outX < output_w; outX++)
      for (size_t outY = 0; outY < output_h; outY++)
        for (size_t outTile = 0; outTile < outputTileCount; outTile++)
          tilesToConcat.push_back(*outputTiles[outTile][outX][outY]);

    Value outputImage = rewriter.create<spatial::SpatImgConcatOp>(loc, conv.getY().getType(), tilesToConcat);

    // BUGFIX: a conversion pattern must consume its root op. The replacement
    // was previously commented out, so the driver would observe `success()`
    // without `conv` being replaced and abort the conversion.
    rewriter.replaceOp(conv, outputImage);

    return success();
  }
};
|
||||
|
||||
/// Registers the core-tiling ConvToManyGemms lowering pattern into `patterns`.
void populateTilingConvOpPattern(RewritePatternSet& patterns, MLIRContext* ctx) { patterns.insert<ConvToManyGemms>(ctx); }
|
||||
|
||||
} // namespace onnx_mlir
|
||||
@@ -2,18 +2,16 @@
|
||||
#include "mlir/Dialect/Tosa/IR/TosaOps.h"
|
||||
#include "mlir/IR/BuiltinTypes.h"
|
||||
#include "mlir/IR/Location.h"
|
||||
#include "mlir/Support/LLVM.h"
|
||||
#include "mlir/Support/LogicalResult.h"
|
||||
#include "mlir/Transforms/DialectConversion.h"
|
||||
|
||||
#include "llvm/ADT/STLExtras.h"
|
||||
#include "llvm/ADT/SmallVector.h"
|
||||
#include "llvm/ADT/StringRef.h"
|
||||
|
||||
#include <cassert>
|
||||
|
||||
#include "Gemm.hpp"
|
||||
#include "src/Accelerators/PIM/Common/PIMCommon.hpp"
|
||||
#include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp"
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ONNXToSpatialCommon.hpp"
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Utils/SpatialReducer.hpp"
|
||||
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
||||
@@ -23,394 +21,368 @@ using namespace mlir;
|
||||
|
||||
namespace onnx_mlir {
|
||||
|
||||
const StringRef COMPUTE_HAS_SOFTMAX_DIVISOR_ATTRNAME = "computeWithSoftmaxDivisor";
|
||||
LogicalResult GemmToManyGemv::matchAndRewrite(ONNXGemmOp gemmOp,
|
||||
ONNXGemmOpAdaptor gemmOpAdaptor,
|
||||
ConversionPatternRewriter& rewriter) const {
|
||||
Location loc = gemmOp.getLoc();
|
||||
Value a = gemmOpAdaptor.getA();
|
||||
Value b = gemmOpAdaptor.getB();
|
||||
Value c = gemmOpAdaptor.getC();
|
||||
|
||||
struct GemmToManyGemv : OpConversionPattern<ONNXGemmOp> {
|
||||
GemmToManyGemv(MLIRContext* ctx)
|
||||
: OpConversionPattern(ctx, 2) {}
|
||||
assert("A should have been transposed already" && !gemmOpAdaptor.getTransA());
|
||||
|
||||
LogicalResult
|
||||
matchAndRewrite(ONNXGemmOp gemmOp, ONNXGemmOpAdaptor adaptor, ConversionPatternRewriter& rewriter) const final {
|
||||
Location loc = gemmOp.getLoc();
|
||||
Value a = adaptor.getA();
|
||||
Value b = adaptor.getB();
|
||||
Value c = adaptor.getC();
|
||||
bool hasC = !isa<ONNXNoneOp>(c.getDefiningOp());
|
||||
|
||||
assert("A should have been transposed already" && !adaptor.getTransA());
|
||||
auto aType = cast<RankedTensorType>(a.getType());
|
||||
auto outType = cast<RankedTensorType>(gemmOp.getY().getType());
|
||||
assert("Only support static shapes" && aType.hasStaticShape() && outType.hasStaticShape());
|
||||
|
||||
bool hasC = !isa<ONNXNoneOp>(c.getDefiningOp());
|
||||
const int64_t numOutRows = aType.getDimSize(0);
|
||||
|
||||
auto aType = cast<RankedTensorType>(a.getType());
|
||||
auto outType = cast<RankedTensorType>(gemmOp.getY().getType());
|
||||
assert("Only support static shapes" && aType.hasStaticShape() && outType.hasStaticShape());
|
||||
// Only decompose when there are multiple rows to split
|
||||
if (numOutRows <= 1)
|
||||
return failure();
|
||||
|
||||
const int64_t numOutRows = aType.getDimSize(0);
|
||||
RankedTensorType cType = nullptr;
|
||||
bool cHasNumOutRows = false;
|
||||
if (hasC) {
|
||||
cType = cast<RankedTensorType>(c.getType());
|
||||
assert("Only support rank 2 tensor for C" && cType.getRank() == 2);
|
||||
cHasNumOutRows = cType.getDimSize(0) == numOutRows;
|
||||
}
|
||||
|
||||
// Only decompose when there are multiple rows to split
|
||||
if (numOutRows <= 1)
|
||||
return failure();
|
||||
auto outRowType = RankedTensorType::get({1, outType.getDimSize(1)}, outType.getElementType());
|
||||
|
||||
RankedTensorType cType = nullptr;
|
||||
bool cHasNumOutRows = false;
|
||||
SmallVector<Value> gemvOps;
|
||||
gemvOps.reserve(numOutRows);
|
||||
for (int64_t rowIdx = 0; rowIdx < numOutRows; rowIdx++) {
|
||||
SmallVector<OpFoldResult> offsets = {rewriter.getIndexAttr(rowIdx), rewriter.getIndexAttr(0)};
|
||||
SmallVector<OpFoldResult> sizes = {rewriter.getIndexAttr(1), rewriter.getIndexAttr(aType.getDimSize(1))};
|
||||
SmallVector<OpFoldResult> strides = {rewriter.getIndexAttr(1), rewriter.getIndexAttr(1)};
|
||||
auto aSliceType = RankedTensorType::get({1, aType.getDimSize(1)}, aType.getElementType());
|
||||
auto aSlice = rewriter.create<tensor::ExtractSliceOp>(loc, aSliceType, a, offsets, sizes, strides).getResult();
|
||||
|
||||
Value cSlice = c;
|
||||
if (hasC) {
|
||||
cType = cast<RankedTensorType>(c.getType());
|
||||
assert("Only support rank 2 tensor for C" && cType.getRank() == 2);
|
||||
cHasNumOutRows = cType.getDimSize(0) == numOutRows;
|
||||
}
|
||||
|
||||
auto outRowType = RankedTensorType::get({1, outType.getDimSize(1)}, outType.getElementType());
|
||||
|
||||
SmallVector<Value> gemvOps;
|
||||
gemvOps.reserve(numOutRows);
|
||||
for (int64_t rowIdx = 0; rowIdx < numOutRows; rowIdx++) {
|
||||
SmallVector<OpFoldResult> offsets = {rewriter.getIndexAttr(rowIdx), rewriter.getIndexAttr(0)};
|
||||
SmallVector<OpFoldResult> sizes = {rewriter.getIndexAttr(1), rewriter.getIndexAttr(aType.getDimSize(1))};
|
||||
SmallVector<OpFoldResult> strides = {rewriter.getIndexAttr(1), rewriter.getIndexAttr(1)};
|
||||
auto aSliceType = RankedTensorType::get({1, aType.getDimSize(1)}, aType.getElementType());
|
||||
auto aSlice = rewriter.create<tensor::ExtractSliceOp>(loc, aSliceType, a, offsets, sizes, strides).getResult();
|
||||
|
||||
Value cSlice = c;
|
||||
if (hasC) {
|
||||
if (cHasNumOutRows) {
|
||||
SmallVector<OpFoldResult> offsets = {rewriter.getIndexAttr(rowIdx), rewriter.getIndexAttr(0)};
|
||||
SmallVector<OpFoldResult> sizes = {rewriter.getIndexAttr(1), rewriter.getIndexAttr(cType.getDimSize(1))};
|
||||
SmallVector<OpFoldResult> strides = {rewriter.getIndexAttr(1), rewriter.getIndexAttr(1)};
|
||||
auto cSliceType = RankedTensorType::get({1, cType.getDimSize(1)}, cType.getElementType());
|
||||
cSlice = rewriter.create<tensor::ExtractSliceOp>(loc, cSliceType, c, offsets, sizes, strides).getResult();
|
||||
}
|
||||
else
|
||||
assert("C should be a vector" && isVectorShape(getTensorShape(c)));
|
||||
if (cHasNumOutRows) {
|
||||
SmallVector<OpFoldResult> offsets = {rewriter.getIndexAttr(rowIdx), rewriter.getIndexAttr(0)};
|
||||
SmallVector<OpFoldResult> sizes = {rewriter.getIndexAttr(1), rewriter.getIndexAttr(cType.getDimSize(1))};
|
||||
SmallVector<OpFoldResult> strides = {rewriter.getIndexAttr(1), rewriter.getIndexAttr(1)};
|
||||
auto cSliceType = RankedTensorType::get({1, cType.getDimSize(1)}, cType.getElementType());
|
||||
cSlice = rewriter.create<tensor::ExtractSliceOp>(loc, cSliceType, c, offsets, sizes, strides).getResult();
|
||||
}
|
||||
|
||||
auto gemvOp = rewriter.create<ONNXGemmOp>(loc,
|
||||
outRowType,
|
||||
aSlice,
|
||||
b,
|
||||
cSlice,
|
||||
gemmOp.getAlphaAttr(),
|
||||
gemmOp.getBetaAttr(),
|
||||
gemmOp.getTransAAttr(),
|
||||
gemmOp.getTransBAttr());
|
||||
gemvOps.push_back(gemvOp.getY());
|
||||
else
|
||||
assert("C should be a vector" && isVectorShape(getTensorShape(c)));
|
||||
}
|
||||
|
||||
auto concatComputeOp =
|
||||
rewriter.create<spatial::SpatWeightedCompute>(loc, gemmOp.getType(), SmallVector<Value>(), gemvOps);
|
||||
|
||||
auto* concatBlock = new Block();
|
||||
for (auto gemvOp : gemvOps)
|
||||
concatBlock->addArgument(gemvOp.getType(), loc);
|
||||
concatComputeOp.getBody().push_back(concatBlock);
|
||||
rewriter.setInsertionPointToStart(concatBlock);
|
||||
|
||||
auto blockArgs = concatBlock->getArguments();
|
||||
auto concatOp = rewriter.create<tensor::ConcatOp>(loc, /*axis=*/0, blockArgs);
|
||||
rewriter.create<spatial::SpatYieldOp>(loc, concatOp.getResult());
|
||||
|
||||
rewriter.replaceOp(gemmOp, concatComputeOp);
|
||||
return success();
|
||||
auto gemvOp = rewriter.create<ONNXGemmOp>(loc,
|
||||
outRowType,
|
||||
aSlice,
|
||||
b,
|
||||
cSlice,
|
||||
gemmOp.getAlphaAttr(),
|
||||
gemmOp.getBetaAttr(),
|
||||
gemmOp.getTransAAttr(),
|
||||
gemmOp.getTransBAttr());
|
||||
gemvOps.push_back(gemvOp.getY());
|
||||
}
|
||||
};
|
||||
|
||||
struct GemvToSpatialCompute : OpConversionPattern<ONNXGemmOp> {
|
||||
GemvToSpatialCompute(MLIRContext* ctx)
|
||||
: OpConversionPattern(ctx, 1) {}
|
||||
auto concatComputeOp =
|
||||
rewriter.create<spatial::SpatWeightedCompute>(loc, gemmOp.getType(), SmallVector<Value>(), gemvOps);
|
||||
|
||||
LogicalResult
|
||||
matchAndRewrite(ONNXGemmOp gemmOp, ONNXGemmOpAdaptor adaptor, ConversionPatternRewriter& rewriter) const final {
|
||||
Location gemmLoc = gemmOp.getLoc();
|
||||
Value a = adaptor.getA();
|
||||
Value b = adaptor.getB();
|
||||
Value c = adaptor.getC();
|
||||
Value out = gemmOp.getY();
|
||||
auto* concatBlock = new Block();
|
||||
for (auto gemvOp : gemvOps)
|
||||
concatBlock->addArgument(gemvOp.getType(), loc);
|
||||
concatComputeOp.getBody().push_back(concatBlock);
|
||||
rewriter.setInsertionPointToStart(concatBlock);
|
||||
|
||||
float alpha = adaptor.getAlpha().convertToFloat();
|
||||
float beta = adaptor.getBeta().convertToFloat();
|
||||
bool transA = adaptor.getTransA();
|
||||
bool transB = adaptor.getTransB();
|
||||
auto blockArgs = concatBlock->getArguments();
|
||||
auto concatOp = rewriter.create<tensor::ConcatOp>(loc, /*axis=*/0, blockArgs);
|
||||
rewriter.create<spatial::SpatYieldOp>(loc, concatOp.getResult());
|
||||
|
||||
auto aType = cast<RankedTensorType>(a.getType());
|
||||
auto bType = cast<RankedTensorType>(b.getType());
|
||||
auto outType = cast<RankedTensorType>(out.getType());
|
||||
rewriter.replaceOp(gemmOp, concatComputeOp);
|
||||
return success();
|
||||
}
|
||||
|
||||
LogicalResult GemvToSpatialCompute::matchAndRewrite(ONNXGemmOp gemmOp,
|
||||
ONNXGemmOpAdaptor gemmOpAdaptor,
|
||||
ConversionPatternRewriter& rewriter) const {
|
||||
Location gemmLoc = gemmOp.getLoc();
|
||||
Value a = gemmOpAdaptor.getA();
|
||||
Value b = gemmOpAdaptor.getB();
|
||||
Value c = gemmOpAdaptor.getC();
|
||||
Value out = gemmOp.getY();
|
||||
|
||||
float alpha = gemmOpAdaptor.getAlpha().convertToFloat();
|
||||
float beta = gemmOpAdaptor.getBeta().convertToFloat();
|
||||
bool transA = gemmOpAdaptor.getTransA();
|
||||
bool transB = gemmOpAdaptor.getTransB();
|
||||
|
||||
auto aType = cast<RankedTensorType>(a.getType());
|
||||
auto bType = cast<RankedTensorType>(b.getType());
|
||||
auto outType = cast<RankedTensorType>(out.getType());
|
||||
|
||||
RankedTensorType cType = nullptr;
|
||||
bool hasC = !isa<ONNXNoneOp>(c.getDefiningOp());
|
||||
if (hasC) {
|
||||
cType = cast<RankedTensorType>(c.getType());
|
||||
assert("Only support rank 2 tensor for C" && cType.getRank() == 2);
|
||||
}
|
||||
|
||||
assert("Only support static shapes" && aType.hasStaticShape() && bType.hasStaticShape()
|
||||
&& (!hasC || cType.hasStaticShape()) && outType.hasStaticShape());
|
||||
|
||||
if (!isVectorShape(aType.getShape()) || !isVectorShape(cType.getShape()))
|
||||
// Not a gemv
|
||||
return failure();
|
||||
|
||||
if (transA) {
|
||||
auto aShape = aType.getShape();
|
||||
auto transposedType = aType.cloneWith(ArrayRef({aShape[1], aShape[0]}), aType.getElementType());
|
||||
a = rewriter.create<ONNXTransposeOp>(gemmLoc, transposedType, a, rewriter.getI64ArrayAttr({1, 0}));
|
||||
}
|
||||
if (transB) {
|
||||
auto bShape = bType.getShape();
|
||||
auto transposedType = bType.cloneWith(ArrayRef({bShape[1], bShape[0]}), bType.getElementType());
|
||||
b = rewriter.create<ONNXTransposeOp>(gemmLoc, transposedType, b, rewriter.getI64ArrayAttr({1, 0}));
|
||||
}
|
||||
|
||||
if (alpha != 1.0f) {
|
||||
auto alphaTensorType = RankedTensorType::get({1, 1}, cast<RankedTensorType>(a.getType()).getElementType());
|
||||
auto alphaTensorValue = DenseFPElementsAttr::get(alphaTensorType, {alpha});
|
||||
auto alphaTensor = rewriter.create<arith::ConstantOp>(gemmLoc, alphaTensorType, alphaTensorValue);
|
||||
a = rewriter.create<spatial::SpatVMulOp>(gemmLoc, a.getType(), a, alphaTensor);
|
||||
}
|
||||
if (hasC && beta != 1.0f) {
|
||||
auto betaTensorType = RankedTensorType::get({1, 1}, cast<RankedTensorType>(c.getType()).getElementType());
|
||||
auto betaTensorValue = DenseFPElementsAttr::get(betaTensorType, {beta});
|
||||
auto betaTensor = rewriter.create<arith::ConstantOp>(gemmLoc, betaTensorType, betaTensorValue);
|
||||
c = rewriter.create<spatial::SpatVMulOp>(gemmLoc, c.getType(), c, betaTensor);
|
||||
}
|
||||
|
||||
auto [aNumHSlices, aLastHSliceSize] = ceilIntegerDivideWithRemainder(aType.getDimSize(1), crossbarSize.getValue());
|
||||
auto [bNumHSlices, bLastHSliceSize] = ceilIntegerDivideWithRemainder(bType.getDimSize(1), crossbarSize.getValue());
|
||||
auto bNumVSlices = aNumHSlices;
|
||||
auto bLastVSliceSize = aLastHSliceSize;
|
||||
auto cNumHSlices = bNumHSlices;
|
||||
auto cLastHSliceSize = bLastHSliceSize;
|
||||
auto outNumHSlices = cNumHSlices;
|
||||
auto outLastHSliceSize = cLastHSliceSize;
|
||||
|
||||
const size_t coresPerVSlice = ceilIntegerDivide(bNumVSlices, crossbarCountInCore.getValue());
|
||||
|
||||
DenseMap<CoreId, SmallVector<Value>> aHSlices = sliceVectorPerCrossbarPerCore(a, rewriter, gemmLoc);
|
||||
|
||||
DenseMap<HSliceId, DenseMap<CoreId, SmallVector<Value>>> bTiles =
|
||||
tileMatrix(b, crossbarSize, crossbarSize, rewriter, gemmLoc);
|
||||
|
||||
SmallVector<Value> cHSlices;
|
||||
if (hasC && cType.getDimSize(0) == 1 && cType.getDimSize(1) == 1)
|
||||
c = broadcastToVector(c, bType.getDimSize(1), rewriter, gemmLoc);
|
||||
if (hasC)
|
||||
cHSlices = sliceVector(c, crossbarSize, rewriter, gemmLoc);
|
||||
|
||||
RankedTensorType outHSliceType =
|
||||
RankedTensorType::get({1, static_cast<long>(crossbarSize)}, outType.getElementType());
|
||||
RankedTensorType outLastHSliceType =
|
||||
RankedTensorType::get({1, static_cast<long>(bLastHSliceSize)}, outType.getElementType());
|
||||
|
||||
SmallVector<Value> outHSlices;
|
||||
outHSlices.reserve(outNumHSlices);
|
||||
for (size_t outSliceId = 0; outSliceId < outNumHSlices; outSliceId++) {
|
||||
RankedTensorType currOutHSliceType = outHSliceType;
|
||||
if (outSliceId == outNumHSlices - 1 && outLastHSliceSize != 0)
|
||||
currOutHSliceType = outLastHSliceType;
|
||||
|
||||
SmallVector<Value> partialResults;
|
||||
partialResults.reserve(coresPerVSlice);
|
||||
for (size_t coreId = 0; coreId < coresPerVSlice; coreId++) {
|
||||
SmallVector<Value> weights;
|
||||
weights.reserve(aHSlices[coreId].size());
|
||||
|
||||
for (size_t aSliceId = 0; aSliceId < aHSlices[coreId].size(); aSliceId++)
|
||||
weights.push_back(bTiles[outSliceId][coreId][aSliceId]);
|
||||
|
||||
auto computeOp =
|
||||
rewriter.create<spatial::SpatWeightedCompute>(gemmLoc, currOutHSliceType, weights, aHSlices[coreId]);
|
||||
|
||||
auto* computeBlock = new Block();
|
||||
for (auto aHSlice : aHSlices[coreId])
|
||||
computeBlock->addArgument(aHSlice.getType(), gemmLoc);
|
||||
computeOp.getBody().push_back(computeBlock);
|
||||
rewriter.setInsertionPointToStart(computeBlock);
|
||||
|
||||
auto computeArgs = computeBlock->getArguments();
|
||||
SmallVector<Value> vmmOutputs;
|
||||
vmmOutputs.reserve(computeArgs.size());
|
||||
for (size_t aHSliceId = 0; aHSliceId < aNumHSlices; aHSliceId++)
|
||||
vmmOutputs.push_back(
|
||||
rewriter.create<spatial::SpatWeightedVMMOp>(gemmLoc, currOutHSliceType, aHSliceId, computeArgs[aHSliceId]));
|
||||
assert(!vmmOutputs.empty() && "vmmOutputs must be non-empty");
|
||||
|
||||
Value partialVmmSum = sumTensors(vmmOutputs, rewriter);
|
||||
rewriter.create<spatial::SpatYieldOp>(gemmLoc, partialVmmSum);
|
||||
rewriter.setInsertionPointAfter(computeOp);
|
||||
|
||||
partialResults.push_back(computeOp.getResult(0));
|
||||
}
|
||||
|
||||
RankedTensorType cType = nullptr;
|
||||
bool hasC = !isa<ONNXNoneOp>(c.getDefiningOp());
|
||||
if (hasC) {
|
||||
cType = cast<RankedTensorType>(c.getType());
|
||||
assert("Only support rank 2 tensor for C" && cType.getRank() == 2);
|
||||
Value cHSlice = cHSlices[outSliceId];
|
||||
partialResults.push_back(cHSlice);
|
||||
}
|
||||
|
||||
assert("Only support static shapes" && aType.hasStaticShape() && bType.hasStaticShape()
|
||||
&& (!hasC || cType.hasStaticShape()) && outType.hasStaticShape());
|
||||
auto reduceComputeOp =
|
||||
rewriter.create<spatial::SpatWeightedCompute>(gemmLoc, currOutHSliceType, SmallVector<Value>(), partialResults);
|
||||
|
||||
if (!isVectorShape(aType.getShape()) || !isVectorShape(aType.getShape()))
|
||||
// Not a gemv
|
||||
return failure();
|
||||
auto* reduceBlock = new Block();
|
||||
for (auto partialResult : partialResults)
|
||||
reduceBlock->addArgument(partialResult.getType(), gemmLoc);
|
||||
reduceComputeOp.getBody().push_back(reduceBlock);
|
||||
rewriter.setInsertionPointToStart(reduceBlock);
|
||||
|
||||
if (transA) {
|
||||
auto aShape = aType.getShape();
|
||||
auto transposedType = aType.cloneWith(ArrayRef({aShape[1], aShape[0]}), aType.getElementType());
|
||||
a = rewriter.create<ONNXTransposeOp>(gemmLoc, transposedType, a, rewriter.getI64ArrayAttr({1, 0}));
|
||||
}
|
||||
if (transB) {
|
||||
auto bShape = bType.getShape();
|
||||
auto transposedType = bType.cloneWith(ArrayRef({bShape[1], bShape[0]}), bType.getElementType());
|
||||
b = rewriter.create<ONNXTransposeOp>(gemmLoc, transposedType, b, rewriter.getI64ArrayAttr({1, 0}));
|
||||
}
|
||||
auto blockArgs = reduceBlock->getArguments();
|
||||
Value outHSlice = sumTensors({blockArgs.begin(), blockArgs.end()}, rewriter);
|
||||
rewriter.create<spatial::SpatYieldOp>(gemmLoc, outHSlice);
|
||||
rewriter.setInsertionPointAfter(reduceComputeOp);
|
||||
|
||||
if (alpha != 1.0f) {
|
||||
auto alphaTensorType = RankedTensorType::get({1, 1}, cast<RankedTensorType>(a.getType()).getElementType());
|
||||
auto alphaTensorValue = DenseFPElementsAttr::get(alphaTensorType, {alpha});
|
||||
auto alphaTensor = rewriter.create<arith::ConstantOp>(gemmLoc, alphaTensorType, alphaTensorValue);
|
||||
a = rewriter.create<spatial::SpatVMulOp>(gemmLoc, a.getType(), a, alphaTensor);
|
||||
}
|
||||
if (hasC && beta != 1.0f) {
|
||||
auto betaTensorType = RankedTensorType::get({1, 1}, cast<RankedTensorType>(c.getType()).getElementType());
|
||||
auto betaTensorValue = DenseFPElementsAttr::get(betaTensorType, {beta});
|
||||
auto betaTensor = rewriter.create<arith::ConstantOp>(gemmLoc, betaTensorType, betaTensorValue);
|
||||
c = rewriter.create<spatial::SpatVMulOp>(gemmLoc, c.getType(), c, betaTensor);
|
||||
}
|
||||
|
||||
auto [aNumHSlices, aLastHSliceSize] = ceilIntegerDivideWithRemainder(aType.getDimSize(1), crossbarSize.getValue());
|
||||
auto [bNumHSlices, bLastHSliceSize] = ceilIntegerDivideWithRemainder(bType.getDimSize(1), crossbarSize.getValue());
|
||||
auto bNumVSlices = aNumHSlices;
|
||||
auto bLastVSliceSize = aLastHSliceSize;
|
||||
auto cNumHSlices = bNumHSlices;
|
||||
auto cLastHSliceSize = bLastHSliceSize;
|
||||
auto outNumHSlices = cNumHSlices;
|
||||
auto outLastHSliceSize = cLastHSliceSize;
|
||||
|
||||
const size_t coresPerVSlice = ceilIntegerDivide(bNumVSlices, crossbarCountInCore.getValue());
|
||||
|
||||
DenseMap<CoreId, SmallVector<Value>> aHSlices = sliceVectorPerCrossbarPerCore(a, rewriter, gemmLoc);
|
||||
|
||||
DenseMap<HSliceId, DenseMap<CoreId, SmallVector<Value>>> bTiles =
|
||||
tileMatrix(b, crossbarSize, crossbarSize, rewriter, gemmLoc);
|
||||
|
||||
SmallVector<Value> cHSlices;
|
||||
if (hasC && cType.getDimSize(0) == 1 && cType.getDimSize(1) == 1)
|
||||
c = broadcastToVector(c, bType.getDimSize(1), rewriter, gemmLoc);
|
||||
if (hasC)
|
||||
cHSlices = sliceVector(c, crossbarSize, rewriter, gemmLoc);
|
||||
|
||||
RankedTensorType outHSliceType =
|
||||
RankedTensorType::get({1, static_cast<long>(crossbarSize)}, outType.getElementType());
|
||||
RankedTensorType outLastHSliceType =
|
||||
RankedTensorType::get({1, static_cast<long>(bLastHSliceSize)}, outType.getElementType());
|
||||
|
||||
SmallVector<Value> outHSlices;
|
||||
outHSlices.reserve(outNumHSlices);
|
||||
for (size_t outSliceId = 0; outSliceId < outNumHSlices; outSliceId++) {
|
||||
RankedTensorType currOutHSliceType = outHSliceType;
|
||||
if (outSliceId == outNumHSlices - 1 && outLastHSliceSize != 0)
|
||||
currOutHSliceType = outLastHSliceType;
|
||||
|
||||
SmallVector<Value> partialResults;
|
||||
partialResults.reserve(coresPerVSlice);
|
||||
for (size_t coreId = 0; coreId < coresPerVSlice; coreId++) {
|
||||
SmallVector<Value> weights;
|
||||
weights.reserve(aHSlices[coreId].size());
|
||||
|
||||
for (size_t aSliceId = 0; aSliceId < aHSlices[coreId].size(); aSliceId++)
|
||||
weights.push_back(bTiles[outSliceId][coreId][aSliceId]);
|
||||
|
||||
auto computeOp =
|
||||
rewriter.create<spatial::SpatWeightedCompute>(gemmLoc, currOutHSliceType, weights, aHSlices[coreId]);
|
||||
|
||||
auto* computeBlock = new Block();
|
||||
for (auto aHSlice : aHSlices[coreId])
|
||||
computeBlock->addArgument(aHSlice.getType(), gemmLoc);
|
||||
computeOp.getBody().push_back(computeBlock);
|
||||
rewriter.setInsertionPointToStart(computeBlock);
|
||||
|
||||
auto computeArgs = computeBlock->getArguments();
|
||||
SmallVector<Value> vmmOutputs;
|
||||
vmmOutputs.reserve(computeArgs.size());
|
||||
for (size_t aHSliceId = 0; aHSliceId < aNumHSlices; aHSliceId++)
|
||||
vmmOutputs.push_back(
|
||||
rewriter.create<spatial::SpatWeightedVMMOp>(gemmLoc, currOutHSliceType, aHSliceId, computeArgs[aHSliceId]));
|
||||
assert(!vmmOutputs.empty() && "vmmOutputs must be non-empty");
|
||||
|
||||
Value partialVmmSum = sumTensors(vmmOutputs, rewriter);
|
||||
rewriter.create<spatial::SpatYieldOp>(gemmLoc, partialVmmSum);
|
||||
rewriter.setInsertionPointAfter(computeOp);
|
||||
|
||||
partialResults.push_back(computeOp.getResult(0));
|
||||
}
|
||||
|
||||
if (hasC) {
|
||||
Value cHSlice = cHSlices[outSliceId];
|
||||
partialResults.push_back(cHSlice);
|
||||
}
|
||||
|
||||
auto reduceComputeOp =
|
||||
rewriter.create<spatial::SpatWeightedCompute>(gemmLoc, currOutHSliceType, SmallVector<Value>(), partialResults);
|
||||
|
||||
auto* reduceBlock = new Block();
|
||||
for (auto partialResult : partialResults)
|
||||
reduceBlock->addArgument(partialResult.getType(), gemmLoc);
|
||||
reduceComputeOp.getBody().push_back(reduceBlock);
|
||||
rewriter.setInsertionPointToStart(reduceBlock);
|
||||
|
||||
auto blockArgs = reduceBlock->getArguments();
|
||||
Value outHSlice = sumTensors({blockArgs.begin(), blockArgs.end()}, rewriter);
|
||||
rewriter.create<spatial::SpatYieldOp>(gemmLoc, outHSlice);
|
||||
rewriter.setInsertionPointAfter(reduceComputeOp);
|
||||
|
||||
outHSlices.push_back(reduceComputeOp.getResult(0));
|
||||
}
|
||||
|
||||
auto concatComputeOp =
|
||||
rewriter.create<spatial::SpatWeightedCompute>(gemmLoc, gemmOp.getType(), SmallVector<Value>(), outHSlices);
|
||||
|
||||
auto* concatBlock = new Block();
|
||||
for (auto outHSlice : outHSlices)
|
||||
concatBlock->addArgument(outHSlice.getType(), gemmLoc);
|
||||
concatComputeOp.getBody().push_back(concatBlock);
|
||||
rewriter.setInsertionPointToStart(concatBlock);
|
||||
|
||||
auto blockArgs = concatBlock->getArguments();
|
||||
auto concatOp = rewriter.create<tensor::ConcatOp>(gemmLoc, /*axis=*/1, blockArgs);
|
||||
rewriter.create<spatial::SpatYieldOp>(gemmLoc, concatOp.getResult());
|
||||
|
||||
rewriter.replaceOp(gemmOp, concatComputeOp);
|
||||
return success();
|
||||
outHSlices.push_back(reduceComputeOp.getResult(0));
|
||||
}
|
||||
|
||||
private:
|
||||
/**
|
||||
* Resolves the ONNXExpOp from the use chain of the given start value.
|
||||
*
|
||||
* This function traverses the use chain of the start value until it finds an
|
||||
* ONNXExpOp. It returns the value of the ONNXExpOp.
|
||||
*
|
||||
* @param startValue The starting value of the use chain.
|
||||
* @return The value of the ONNXExpOp found in the use chain.
|
||||
*/
|
||||
static Value resolveONNXExpOpFromUseChain(Value startValue) {
|
||||
Value walker = startValue;
|
||||
auto concatComputeOp =
|
||||
rewriter.create<spatial::SpatWeightedCompute>(gemmLoc, gemmOp.getType(), SmallVector<Value>(), outHSlices);
|
||||
|
||||
while (!llvm::isa<ONNXExpOp>(walker.getDefiningOp())) {
|
||||
walker = walker.getDefiningOp()->getOperand(0);
|
||||
auto* concatBlock = new Block();
|
||||
for (auto outHSlice : outHSlices)
|
||||
concatBlock->addArgument(outHSlice.getType(), gemmLoc);
|
||||
concatComputeOp.getBody().push_back(concatBlock);
|
||||
rewriter.setInsertionPointToStart(concatBlock);
|
||||
|
||||
assert(walker && walker.getDefiningOp()
|
||||
&& "Unwinded the whole chain of operations while trying to "
|
||||
"find ONNXExpOp, but did not find it");
|
||||
}
|
||||
auto blockArgs = concatBlock->getArguments();
|
||||
auto concatOp = rewriter.create<tensor::ConcatOp>(gemmLoc, /*axis=*/1, blockArgs);
|
||||
rewriter.create<spatial::SpatYieldOp>(gemmLoc, concatOp.getResult());
|
||||
|
||||
// Make sure the dividend is actually produced by an ONNXExpOp
|
||||
assert(llvm::isa<ONNXExpOp>(walker.getDefiningOp())
|
||||
&& "Old output tile (softmax reducer) is not produced by an "
|
||||
"ONNXExpOp");
|
||||
rewriter.replaceOp(gemmOp, concatComputeOp);
|
||||
return success();
|
||||
}
|
||||
|
||||
return walker;
|
||||
Value GemvToSpatialCompute::resolveONNXExpOpFromUseChain(Value startValue) {
|
||||
Value walker = startValue;
|
||||
|
||||
while (!llvm::isa<ONNXExpOp>(walker.getDefiningOp())) {
|
||||
walker = walker.getDefiningOp()->getOperand(0);
|
||||
|
||||
assert(walker && walker.getDefiningOp()
|
||||
&& "Unwinded the whole chain of operations while trying to "
|
||||
"find ONNXExpOp, but did not find it");
|
||||
}
|
||||
|
||||
// Softmax is a special case, as it requires another reduction after the
|
||||
// first one. In the cores, `applyReducePattern` already applied
|
||||
// f(x) = exp(x) to each tile. This mean that now we just need to
|
||||
// reduce-sum these tiles, and then divide each tile by the reduced sum,
|
||||
// which is propagated back to the cores via a broadcast channel.
|
||||
LogicalResult softmaxReductionApplication(SmallVector<OpAndResNum>& outputOpsAndResNums,
|
||||
Value& softmaxChannel,
|
||||
ConversionPatternRewriter& rewriter,
|
||||
SpatialReducer& reducer,
|
||||
ONNXGemmOp& gemmOp,
|
||||
Location& loc) const {
|
||||
// Make sure the dividend is actually produced by an ONNXExpOp
|
||||
assert(llvm::isa<ONNXExpOp>(walker.getDefiningOp())
|
||||
&& "Old output tile (softmax reducer) is not produced by an "
|
||||
"ONNXExpOp");
|
||||
|
||||
// TODO: Check case with one compute op
|
||||
return walker;
|
||||
}
|
||||
|
||||
// Cast vector of Value into vector of ComputeOp
|
||||
SmallVector<ComputeAndResNum> softmaxOpsToReduce =
|
||||
llvm::to_vector(llvm::map_range(outputOpsAndResNums, [&](OpAndResNum computeAndResNum) {
|
||||
return std::make_pair(cast<spatial::SpatWeightedCompute>(computeAndResNum.first), computeAndResNum.second);
|
||||
}));
|
||||
LogicalResult GemvToSpatialCompute::softmaxReductionApplication(SmallVector<OpAndResNum>& outputOpsAndResNums,
|
||||
Value& softmaxChannel,
|
||||
ConversionPatternRewriter& rewriter,
|
||||
SpatialReducer& reducer,
|
||||
ONNXGemmOp& gemmOp,
|
||||
Location& loc) {
|
||||
// TODO: Check case with one compute op
|
||||
|
||||
RankedTensorType::Builder tensorTypeBuilder({1}, Float32Type::get(rewriter.getContext()), nullptr);
|
||||
const TensorType scalarTensorType = tensorTypeBuilder;
|
||||
// Cast vector of Value into vector of ComputeOp
|
||||
SmallVector<ComputeAndResNum> softmaxOpsToReduce =
|
||||
llvm::to_vector(llvm::map_range(outputOpsAndResNums, [&](OpAndResNum computeAndResNum) {
|
||||
return std::make_pair(cast<spatial::SpatWeightedCompute>(computeAndResNum.first), computeAndResNum.second);
|
||||
}));
|
||||
|
||||
reducer.applyReducePattern(
|
||||
softmaxOpsToReduce,
|
||||
[&](Value a, Value b) { return rewriter.create<spatial::SpatVAddOp>(loc, scalarTensorType, a, b); },
|
||||
/* preprocess = */
|
||||
[&](Value a) { return rewriter.create<spatial::SpatSumOp>(loc, scalarTensorType, a); },
|
||||
[&](Value softmaxDivisor) {
|
||||
// Signal that this is the compute with the softmax divisor
|
||||
auto computeOp = cast<spatial::SpatWeightedCompute>(softmaxDivisor.getDefiningOp()->getParentOp());
|
||||
computeOp->setAttr(COMPUTE_HAS_SOFTMAX_DIVISOR_ATTRNAME, rewriter.getUnitAttr());
|
||||
RankedTensorType::Builder tensorTypeBuilder({1}, Float32Type::get(rewriter.getContext()), nullptr);
|
||||
const TensorType scalarTensorType = tensorTypeBuilder;
|
||||
|
||||
// Broadcast the divisor to all the cores
|
||||
rewriter.setInsertionPointAfterValue(softmaxDivisor);
|
||||
rewriter.create<spatial::SpatChannelBroadcastSendOp>(loc, softmaxChannel, softmaxDivisor);
|
||||
reducer.applyReducePattern(
|
||||
softmaxOpsToReduce,
|
||||
[&](Value a, Value b) { return rewriter.create<spatial::SpatVAddOp>(loc, scalarTensorType, a, b); },
|
||||
/* preprocess = */
|
||||
[&](Value a) { return rewriter.create<spatial::SpatSumOp>(loc, scalarTensorType, a); },
|
||||
[&](Value softmaxDivisor) {
|
||||
// Signal that this is the compute with the softmax divisor
|
||||
auto computeOp = cast<spatial::SpatWeightedCompute>(softmaxDivisor.getDefiningOp()->getParentOp());
|
||||
computeOp->setAttr(COMPUTE_HAS_SOFTMAX_DIVISOR_ATTRNAME, rewriter.getUnitAttr());
|
||||
|
||||
/*
|
||||
* softmaxDividend = onnx.exp (...)
|
||||
* sum = spat.SumOp(softmaxDividend)
|
||||
* [following can be repeated N times, thus walk the use chain]
|
||||
* softmaxDivisor = spat.sadd(sum, ...)
|
||||
*/
|
||||
Value softmaxDividend = resolveONNXExpOpFromUseChain(softmaxDivisor.getDefiningOp()->getOperand(0));
|
||||
// Broadcast the divisor to all the cores
|
||||
rewriter.setInsertionPointAfterValue(softmaxDivisor);
|
||||
rewriter.create<spatial::SpatChannelBroadcastSendOp>(loc, softmaxChannel, softmaxDivisor);
|
||||
|
||||
// Make sure the dividend is actually produced by an ONNXExpOp
|
||||
assert(llvm::isa<ONNXExpOp>(softmaxDividend.getDefiningOp())
|
||||
&& "Dividend of softmax reduction is not an ONNXExpOp");
|
||||
/*
|
||||
* softmaxDividend = onnx.exp (...)
|
||||
* sum = spat.SumOp(softmaxDividend)
|
||||
* [following can be repeated N times, thus walk the use chain]
|
||||
* softmaxDivisor = spat.sadd(sum, ...)
|
||||
*/
|
||||
Value softmaxDividend = resolveONNXExpOpFromUseChain(softmaxDivisor.getDefiningOp()->getOperand(0));
|
||||
|
||||
// Do not divide here, divide after this
|
||||
return softmaxDivisor;
|
||||
});
|
||||
// Make sure the dividend is actually produced by an ONNXExpOp
|
||||
assert(llvm::isa<ONNXExpOp>(softmaxDividend.getDefiningOp())
|
||||
&& "Dividend of softmax reduction is not an ONNXExpOp");
|
||||
|
||||
// In all the cores, insert a ChannelRecvOp and divide the output tile by
|
||||
// the reduced denominator.
|
||||
outputOpsAndResNums.clear();
|
||||
outputOpsAndResNums.reserve(softmaxOpsToReduce.size());
|
||||
for (auto& computeToDivideOpAndResNum : softmaxOpsToReduce) {
|
||||
// Do not divide here, divide after this
|
||||
return softmaxDivisor;
|
||||
});
|
||||
|
||||
auto yieldOp = cast<spatial::SpatYieldOp>(computeToDivideOpAndResNum.first.getBody().front().getTerminator());
|
||||
// In all the cores, insert a ChannelRecvOp and divide the output tile by
|
||||
// the reduced denominator.
|
||||
outputOpsAndResNums.clear();
|
||||
outputOpsAndResNums.reserve(softmaxOpsToReduce.size());
|
||||
for (auto& computeToDivideOpAndResNum : softmaxOpsToReduce) {
|
||||
|
||||
Value divisor;
|
||||
auto yieldOp = cast<spatial::SpatYieldOp>(computeToDivideOpAndResNum.first.getBody().front().getTerminator());
|
||||
|
||||
// Check if this compute contains the softmax divisor: if so, find the
|
||||
// ChannelBroadcastSendOp, otherwise receive the value from the channel
|
||||
// using ChannelBroadcastReceiveOp
|
||||
if (computeToDivideOpAndResNum.first->hasAttr(COMPUTE_HAS_SOFTMAX_DIVISOR_ATTRNAME)) {
|
||||
Value divisor;
|
||||
|
||||
bool found = false;
|
||||
for (auto broadcastOp :
|
||||
computeToDivideOpAndResNum.first.getBody().front().getOps<spatial::SpatChannelBroadcastSendOp>()) {
|
||||
assert(found == false
|
||||
&& "More than one ChannelBroadcastSendOp in "
|
||||
"compute? How is this possible?");
|
||||
found = true;
|
||||
// Check if this compute contains the softmax divisor: if so, find the
|
||||
// ChannelBroadcastSendOp, otherwise receive the value from the channel
|
||||
// using ChannelBroadcastReceiveOp
|
||||
if (computeToDivideOpAndResNum.first->hasAttr(COMPUTE_HAS_SOFTMAX_DIVISOR_ATTRNAME)) {
|
||||
|
||||
divisor = broadcastOp.getData();
|
||||
}
|
||||
bool found = false;
|
||||
for (auto broadcastOp :
|
||||
computeToDivideOpAndResNum.first.getBody().front().getOps<spatial::SpatChannelBroadcastSendOp>()) {
|
||||
assert(found == false
|
||||
&& "More than one ChannelBroadcastSendOp in "
|
||||
"compute? How is this possible?");
|
||||
found = true;
|
||||
|
||||
assert(found
|
||||
&& "No ChannelBroadcastSendOp in compute where softmax "
|
||||
"divisor was specified to be?");
|
||||
}
|
||||
else {
|
||||
rewriter.setInsertionPoint(yieldOp);
|
||||
divisor = rewriter.create<spatial::SpatChannelBroadcastReceiveOp>(loc, scalarTensorType, softmaxChannel);
|
||||
divisor = broadcastOp.getData();
|
||||
}
|
||||
|
||||
// Walk the chain of operations until we find the ONNXExpOp: this is
|
||||
// needed because some some may have a different amount of `VAddOp`s due
|
||||
// to the tree reduction (e.g. some may have no VAddOp, some may have
|
||||
// multiples)
|
||||
Value oldOutputTile = resolveONNXExpOpFromUseChain(yieldOp->getOperand(computeToDivideOpAndResNum.second));
|
||||
|
||||
assert(found
|
||||
&& "No ChannelBroadcastSendOp in compute where softmax "
|
||||
"divisor was specified to be?");
|
||||
}
|
||||
else {
|
||||
rewriter.setInsertionPoint(yieldOp);
|
||||
Value newOutputTile = rewriter.create<spatial::SpatVSDivOp>(loc, oldOutputTile.getType(), oldOutputTile, divisor);
|
||||
auto yieldOperandNum = yieldOp->getNumOperands();
|
||||
yieldOp->insertOperands(yieldOperandNum, newOutputTile);
|
||||
|
||||
outputOpsAndResNums.push_back({computeToDivideOpAndResNum.first, yieldOperandNum});
|
||||
divisor = rewriter.create<spatial::SpatChannelBroadcastReceiveOp>(loc, scalarTensorType, softmaxChannel);
|
||||
}
|
||||
|
||||
return success();
|
||||
// Walk the chain of operations until we find the ONNXExpOp: this is
|
||||
// needed because some some may have a different amount of `VAddOp`s due
|
||||
// to the tree reduction (e.g. some may have no VAddOp, some may have
|
||||
// multiples)
|
||||
Value oldOutputTile = resolveONNXExpOpFromUseChain(yieldOp->getOperand(computeToDivideOpAndResNum.second));
|
||||
|
||||
rewriter.setInsertionPoint(yieldOp);
|
||||
Value newOutputTile = rewriter.create<spatial::SpatVSDivOp>(loc, oldOutputTile.getType(), oldOutputTile, divisor);
|
||||
auto yieldOperandNum = yieldOp->getNumOperands();
|
||||
yieldOp->insertOperands(yieldOperandNum, newOutputTile);
|
||||
|
||||
outputOpsAndResNums.push_back({computeToDivideOpAndResNum.first, yieldOperandNum});
|
||||
}
|
||||
};
|
||||
|
||||
return success();
|
||||
}
|
||||
|
||||
void populateOnnxGemmOpPatterns(RewritePatternSet& patterns, MLIRContext* ctx) {
|
||||
patterns.insert<GemmToManyGemv>(ctx);
|
||||
|
||||
54
src/PIM/Conversion/ONNXToSpatial/Math/Gemm.hpp
Normal file
54
src/PIM/Conversion/ONNXToSpatial/Math/Gemm.hpp
Normal file
@@ -0,0 +1,54 @@
|
||||
#pragma once
|
||||
|
||||
#include "Conversion/ONNXToSpatial/Utils/SpatialReducer.hpp"
|
||||
#include "src/Dialect/ONNX/ONNXOps.hpp"
|
||||
|
||||
namespace onnx_mlir {
|
||||
|
||||
constexpr mlir::StringRef COMPUTE_HAS_SOFTMAX_DIVISOR_ATTRNAME = "computeWithSoftmaxDivisor";
|
||||
|
||||
struct GemmToManyGemv : mlir::OpConversionPattern<mlir::ONNXGemmOp> {
|
||||
GemmToManyGemv(mlir::MLIRContext* ctx)
|
||||
: OpConversionPattern(ctx, 2) {}
|
||||
|
||||
mlir::LogicalResult matchAndRewrite(mlir::ONNXGemmOp gemmOp,
|
||||
mlir::ONNXGemmOpAdaptor gemmOpAdaptor,
|
||||
mlir::ConversionPatternRewriter& rewriter) const override;
|
||||
};
|
||||
|
||||
struct GemvToSpatialCompute : mlir::OpConversionPattern<mlir::ONNXGemmOp> {
|
||||
GemvToSpatialCompute(mlir::MLIRContext* ctx)
|
||||
: OpConversionPattern(ctx, 1) {}
|
||||
|
||||
llvm::LogicalResult matchAndRewrite(mlir::ONNXGemmOp gemmOp,
|
||||
mlir::ONNXGemmOpAdaptor gemmOpAdaptor,
|
||||
mlir::ConversionPatternRewriter& rewriter) const override;
|
||||
|
||||
private:
|
||||
/**
|
||||
* Resolves the ONNXExpOp from the use chain of the given start value.
|
||||
*
|
||||
* This function traverses the use chain of the start value until it finds an
|
||||
* ONNXExpOp. It returns the value of the ONNXExpOp.
|
||||
*
|
||||
* @param startValue The starting value of the use chain.
|
||||
* @return The value of the ONNXExpOp found in the use chain.
|
||||
*/
|
||||
static mlir::Value resolveONNXExpOpFromUseChain(mlir::Value startValue);
|
||||
|
||||
// Softmax is a special case, as it requires another reduction after the
|
||||
// first one. In the cores, `applyReducePattern` already applied
|
||||
// f(x) = exp(x) to each tile. This mean that now we just need to
|
||||
// reduce-sum these tiles, and then divide each tile by the reduced sum,
|
||||
// which is propagated back to the cores via a broadcast channel.
|
||||
static llvm::LogicalResult softmaxReductionApplication(llvm::SmallVector<OpAndResNum>& outputOpsAndResNums,
|
||||
Value& softmaxChannel,
|
||||
ConversionPatternRewriter& rewriter,
|
||||
SpatialReducer& reducer,
|
||||
ONNXGemmOp& gemmOp,
|
||||
Location& loc);
|
||||
};
|
||||
|
||||
void populateOnnxGemmOpPatterns(RewritePatternSet& patterns, MLIRContext* ctx);
|
||||
|
||||
} // namespace onnx_mlir
|
||||
@@ -10,6 +10,7 @@
|
||||
|
||||
#include "Common/PIMCommon.hpp"
|
||||
#include "Conversion/ONNXToSpatial/Utils/AnnotateReplication.hpp"
|
||||
#include "Math/Conv.hpp"
|
||||
#include "ONNXToSpatialPass.hpp"
|
||||
#include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp"
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ONNXToSpatialPatterns.hpp"
|
||||
|
||||
@@ -6,7 +6,6 @@ namespace onnx_mlir {
|
||||
void populateLoweringONNXMatMulOpToSpatialPattern(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);
|
||||
|
||||
void populateOnnxGemmOpPatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);
|
||||
void populateTilingConvOpPattern(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);
|
||||
|
||||
void populatePoolingTilingPattern(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);
|
||||
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
#include "mlir/Dialect/Func/IR/FuncOps.h"
|
||||
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
||||
#include "mlir/Dialect/Tosa/IR/TosaOps.h"
|
||||
#include "mlir/IR/BuiltinDialect.h"
|
||||
#include "mlir/IR/BuiltinTypeInterfaces.h"
|
||||
#include "mlir/IR/BuiltinTypes.h"
|
||||
#include "mlir/IR/PatternMatch.h"
|
||||
@@ -29,12 +30,12 @@ void SpatialToPIMPass::runOnOperation() {
|
||||
MLIRContext* ctx = moduleOp.getContext();
|
||||
|
||||
ConversionTarget target(*ctx);
|
||||
target.addLegalDialect<PimDialect, tensor::TensorDialect, arith::ArithDialect>();
|
||||
target.addLegalDialect<PimDialect, tensor::TensorDialect, arith::ArithDialect, func::FuncDialect, BuiltinDialect>();
|
||||
|
||||
RewritePatternSet patterns(ctx);
|
||||
populateWithGenerated(patterns);
|
||||
|
||||
if (failed(applyPartialConversion(moduleOp, target, std::move(patterns)))) {
|
||||
if (failed(applyFullConversion(moduleOp, target, std::move(patterns)))) {
|
||||
signalPassFailure();
|
||||
return;
|
||||
}
|
||||
|
||||
BIN
validation/operations/conv/batch_64/conv_batch_64.onnx
Normal file
BIN
validation/operations/conv/batch_64/conv_batch_64.onnx
Normal file
Binary file not shown.
BIN
validation/operations/conv/simple/conv.onnx
Normal file
BIN
validation/operations/conv/simple/conv.onnx
Normal file
Binary file not shown.
BIN
validation/operations/conv/with_constant/conv_with_constant.onnx
Normal file
BIN
validation/operations/conv/with_constant/conv_with_constant.onnx
Normal file
Binary file not shown.
Reference in New Issue
Block a user