Refactor ONNXToSpatial Common and diagnostics
This commit is contained in:
@@ -7,11 +7,10 @@
|
||||
#include "llvm/ADT/SmallVector.h"
|
||||
|
||||
#include <algorithm>
|
||||
#include <cassert>
|
||||
|
||||
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
|
||||
#include "src/Accelerators/PIM/Common/Support/Diagnostics.hpp"
|
||||
#include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp"
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common.hpp"
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/Common.hpp"
|
||||
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
||||
#include "src/Dialect/ONNX/ONNXOps.hpp"
|
||||
|
||||
@@ -370,11 +369,34 @@ LogicalResult ConvToGemm::matchAndRewrite(ONNXConvOp convOp,
|
||||
auto wType = cast<RankedTensorType>(w.getType());
|
||||
auto outType = cast<RankedTensorType>(convOp.getY().getType());
|
||||
|
||||
assert("Only support static shapes" && xType.hasStaticShape() && wType.hasStaticShape() && outType.hasStaticShape());
|
||||
assert("Only support 2D convolution" && xType.getRank() == 4);
|
||||
|
||||
// We need to understand what is group
|
||||
assert("Only support group=1" && convOp.getGroup() == 1);
|
||||
if (!xType.hasStaticShape()) {
|
||||
pim::emitUnsupportedStaticShapeDiagnostic(convOp, "conv input");
|
||||
return failure();
|
||||
}
|
||||
if (!wType.hasStaticShape()) {
|
||||
pim::emitUnsupportedStaticShapeDiagnostic(convOp, "conv weight");
|
||||
return failure();
|
||||
}
|
||||
if (!outType.hasStaticShape()) {
|
||||
pim::emitUnsupportedStaticShapeDiagnostic(convOp, "conv result");
|
||||
return failure();
|
||||
}
|
||||
if (xType.getRank() != 4) {
|
||||
pim::emitUnsupportedRankDiagnostic(convOp, "conv input", xType.getRank(), {4});
|
||||
return failure();
|
||||
}
|
||||
if (wType.getRank() != 4) {
|
||||
pim::emitUnsupportedRankDiagnostic(convOp, "conv weight", wType.getRank(), {4});
|
||||
return failure();
|
||||
}
|
||||
if (outType.getRank() != 4) {
|
||||
pim::emitUnsupportedRankDiagnostic(convOp, "conv result", outType.getRank(), {4});
|
||||
return failure();
|
||||
}
|
||||
if (convOp.getGroup() != 1) {
|
||||
convOp.emitOpError("only group=1 convolution is supported for Spatial lowering");
|
||||
return failure();
|
||||
}
|
||||
|
||||
const int64_t batchSize = xType.getDimSize(0);
|
||||
const int64_t numChannelsIn = xType.getDimSize(1);
|
||||
@@ -391,6 +413,19 @@ LogicalResult ConvToGemm::matchAndRewrite(ONNXConvOp convOp,
|
||||
const auto dilationsAttr = convOp.getDilations();
|
||||
const auto padsAttr = convOp.getPads();
|
||||
|
||||
if (stridesAttr && stridesAttr->size() != 2) {
|
||||
convOp.emitOpError("requires exactly two stride values for Spatial lowering");
|
||||
return failure();
|
||||
}
|
||||
if (dilationsAttr && dilationsAttr->size() != 2) {
|
||||
convOp.emitOpError("requires exactly two dilation values for Spatial lowering");
|
||||
return failure();
|
||||
}
|
||||
if (padsAttr && padsAttr->size() != 4) {
|
||||
convOp.emitOpError("requires exactly four pad values for 2D Spatial lowering");
|
||||
return failure();
|
||||
}
|
||||
|
||||
const int64_t strideHeight = stridesAttr ? getI64FromArrayAttr(*stridesAttr, 0) : 1;
|
||||
const int64_t strideWidth = stridesAttr ? getI64FromArrayAttr(*stridesAttr, 1) : 1;
|
||||
const int64_t dilationHeight = dilationsAttr ? getI64FromArrayAttr(*dilationsAttr, 0) : 1;
|
||||
@@ -431,6 +466,10 @@ LogicalResult ConvToGemm::matchAndRewrite(ONNXConvOp convOp,
|
||||
padWidthBegin = totalPadW - padWidthEnd;
|
||||
}
|
||||
}
|
||||
else if (autoPad != "NOTSET" && autoPad != "VALID") {
|
||||
convOp.emitOpError() << "unsupported auto_pad value `" << autoPad << "` for Spatial lowering";
|
||||
return failure();
|
||||
}
|
||||
// "NOTSET" or "VALID" -> all pads stay 0
|
||||
}
|
||||
|
||||
|
||||
@@ -5,7 +5,8 @@
|
||||
|
||||
#include "llvm/ADT/SmallVector.h"
|
||||
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common.hpp"
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/Common.hpp"
|
||||
#include "src/Accelerators/PIM/Common/IR/ShapeUtils.hpp"
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Patterns.hpp"
|
||||
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
||||
#include "src/Dialect/ONNX/ONNXOps.hpp"
|
||||
@@ -15,13 +16,6 @@ using namespace mlir;
|
||||
namespace onnx_mlir {
|
||||
namespace {
|
||||
|
||||
static SmallVector<int64_t> computeRowMajorStrides(ArrayRef<int64_t> shape) {
|
||||
SmallVector<int64_t> strides(shape.size(), 1);
|
||||
for (int64_t i = static_cast<int64_t>(shape.size()) - 2; i >= 0; --i)
|
||||
strides[i] = strides[i + 1] * shape[i + 1];
|
||||
return strides;
|
||||
}
|
||||
|
||||
static DenseElementsAttr getDenseConstantAttr(Value value) {
|
||||
if (auto constantOp = value.getDefiningOp<arith::ConstantOp>())
|
||||
return dyn_cast<DenseElementsAttr>(constantOp.getValue());
|
||||
|
||||
@@ -8,10 +8,9 @@
|
||||
|
||||
#include "llvm/ADT/SmallVector.h"
|
||||
|
||||
#include <cassert>
|
||||
|
||||
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common.hpp"
|
||||
#include "src/Accelerators/PIM/Common/Support/Diagnostics.hpp"
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/Common.hpp"
|
||||
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
||||
#include "src/Dialect/ONNX/ONNXOps.hpp"
|
||||
|
||||
@@ -136,13 +135,23 @@ LogicalResult GemmToManyGemv::matchAndRewrite(ONNXGemmOp gemmOp,
|
||||
Value b = gemmOpAdaptor.getB();
|
||||
Value c = gemmOpAdaptor.getC();
|
||||
|
||||
assert("A should have been transposed already" && !gemmOpAdaptor.getTransA());
|
||||
if (gemmOpAdaptor.getTransA()) {
|
||||
gemmOp.emitOpError("requires transA=false before Gemm row decomposition");
|
||||
return failure();
|
||||
}
|
||||
|
||||
bool hasC = !isa<ONNXNoneOp>(c.getDefiningOp());
|
||||
|
||||
auto aType = cast<RankedTensorType>(a.getType());
|
||||
auto outType = cast<RankedTensorType>(gemmOp.getY().getType());
|
||||
assert("Only support static shapes" && aType.hasStaticShape() && outType.hasStaticShape());
|
||||
if (!aType.hasStaticShape()) {
|
||||
pim::emitUnsupportedStaticShapeDiagnostic(gemmOp, "Gemm input A");
|
||||
return failure();
|
||||
}
|
||||
if (!outType.hasStaticShape()) {
|
||||
pim::emitUnsupportedStaticShapeDiagnostic(gemmOp, "Gemm result");
|
||||
return failure();
|
||||
}
|
||||
|
||||
const int64_t numOutRows = aType.getDimSize(0);
|
||||
|
||||
@@ -175,7 +184,14 @@ LogicalResult GemmToManyGemv::matchAndRewrite(ONNXGemmOp gemmOp,
|
||||
});
|
||||
cType = expandedType;
|
||||
}
|
||||
assert("Only support rank 2 tensor for C" && cType.getRank() == 2);
|
||||
if (!cType.hasStaticShape()) {
|
||||
pim::emitUnsupportedStaticShapeDiagnostic(gemmOp, "Gemm bias");
|
||||
return failure();
|
||||
}
|
||||
if (cType.getRank() != 2) {
|
||||
pim::emitUnsupportedRankDiagnostic(gemmOp, "Gemm bias", cType.getRank(), {1, 2});
|
||||
return failure();
|
||||
}
|
||||
cHasNumOutRows = cType.getDimSize(0) == numOutRows;
|
||||
}
|
||||
|
||||
@@ -199,8 +215,10 @@ LogicalResult GemmToManyGemv::matchAndRewrite(ONNXGemmOp gemmOp,
|
||||
auto cSliceType = RankedTensorType::get({1, cType.getDimSize(1)}, cType.getElementType());
|
||||
cSlice = tensor::ExtractSliceOp::create(rewriter, loc, cSliceType, c, offsets, sizes, strides).getResult();
|
||||
}
|
||||
else
|
||||
assert("C should be a vector" && isVectorShape(getTensorShape(c)));
|
||||
else if (!isVectorShape(getTensorShape(c))) {
|
||||
gemmOp.emitOpError("requires Gemm bias C to be vector-like when shared across decomposed rows");
|
||||
return failure();
|
||||
}
|
||||
}
|
||||
|
||||
auto gemvOp = ONNXGemmOp::create(rewriter,
|
||||
@@ -258,11 +276,28 @@ LogicalResult GemvToSpatialCompute::matchAndRewrite(ONNXGemmOp gemmOp,
|
||||
});
|
||||
cType = expandedType;
|
||||
}
|
||||
assert("Only support rank 2 tensor for C" && cType.getRank() == 2);
|
||||
if (!cType.hasStaticShape()) {
|
||||
pim::emitUnsupportedStaticShapeDiagnostic(gemmOp, "Gemm bias");
|
||||
return failure();
|
||||
}
|
||||
if (cType.getRank() != 2) {
|
||||
pim::emitUnsupportedRankDiagnostic(gemmOp, "Gemm bias", cType.getRank(), {1, 2});
|
||||
return failure();
|
||||
}
|
||||
}
|
||||
|
||||
assert("Only support static shapes" && aType.hasStaticShape() && bType.hasStaticShape()
|
||||
&& (!hasC || cType.hasStaticShape()) && outType.hasStaticShape());
|
||||
if (!aType.hasStaticShape()) {
|
||||
pim::emitUnsupportedStaticShapeDiagnostic(gemmOp, "Gemm input A");
|
||||
return failure();
|
||||
}
|
||||
if (!bType.hasStaticShape()) {
|
||||
pim::emitUnsupportedStaticShapeDiagnostic(gemmOp, "Gemm input B");
|
||||
return failure();
|
||||
}
|
||||
if (!outType.hasStaticShape()) {
|
||||
pim::emitUnsupportedStaticShapeDiagnostic(gemmOp, "Gemm result");
|
||||
return failure();
|
||||
}
|
||||
|
||||
if (!isVectorShape(aType.getShape()) || (hasC && !isVectorShape(cType.getShape())))
|
||||
// Not a gemv
|
||||
@@ -341,19 +376,25 @@ LogicalResult GemvToSpatialCompute::matchAndRewrite(ONNXGemmOp gemmOp,
|
||||
weights.push_back(bTiles[outSliceId][coreId][aSliceId]);
|
||||
|
||||
auto computeOp = createSpatCompute(
|
||||
rewriter, gemmLoc, currOutHSliceType, weights, aHSlices[coreId], [&](ValueRange aHSlicesArgs) {
|
||||
rewriter, gemmLoc, currOutHSliceType, weights, aHSlices[coreId], [&](ValueRange aHSlicesArgs) -> LogicalResult {
|
||||
SmallVector<Value> vmmOutputs;
|
||||
vmmOutputs.reserve(aHSlicesArgs.size());
|
||||
for (auto [aHSliceId, computeArg] : llvm::enumerate(aHSlicesArgs))
|
||||
vmmOutputs.push_back(
|
||||
spatial::SpatWeightedVMMOp::create(rewriter, gemmLoc, currOutHSliceType, aHSliceId, computeArg));
|
||||
assert(!vmmOutputs.empty() && "vmmOutputs must be non-empty");
|
||||
if (vmmOutputs.empty()) {
|
||||
gemmOp.emitOpError("requires at least one non-empty slice when lowering tiled Gemm to Spatial VMMs");
|
||||
return failure();
|
||||
}
|
||||
|
||||
Value partialVmmSum = sumTensors(vmmOutputs, rewriter);
|
||||
spatial::SpatYieldOp::create(rewriter, gemmLoc, partialVmmSum);
|
||||
return success();
|
||||
});
|
||||
if (failed(computeOp))
|
||||
return failure();
|
||||
|
||||
partialResults.push_back(computeOp.getResult(0));
|
||||
partialResults.push_back(computeOp->getResult(0));
|
||||
}
|
||||
|
||||
if (hasC) {
|
||||
@@ -388,14 +429,28 @@ LogicalResult GemmToSpatialComputeBatch::matchAndRewrite(ONNXGemmOp gemmOp,
|
||||
Value b = gemmOpAdaptor.getB();
|
||||
Value c = gemmOpAdaptor.getC();
|
||||
|
||||
assert("A should have been transposed already" && !gemmOpAdaptor.getTransA());
|
||||
if (gemmOpAdaptor.getTransA()) {
|
||||
gemmOp.emitOpError("requires transA=false before batch Gemm lowering");
|
||||
return failure();
|
||||
}
|
||||
|
||||
bool hasC = !isa<ONNXNoneOp>(c.getDefiningOp());
|
||||
|
||||
auto aType = cast<RankedTensorType>(a.getType());
|
||||
auto bType = cast<RankedTensorType>(b.getType());
|
||||
auto outType = cast<RankedTensorType>(gemmOp.getY().getType());
|
||||
assert("Only support static shapes" && aType.hasStaticShape() && bType.hasStaticShape() && outType.hasStaticShape());
|
||||
if (!aType.hasStaticShape()) {
|
||||
pim::emitUnsupportedStaticShapeDiagnostic(gemmOp, "Gemm input A");
|
||||
return failure();
|
||||
}
|
||||
if (!bType.hasStaticShape()) {
|
||||
pim::emitUnsupportedStaticShapeDiagnostic(gemmOp, "Gemm input B");
|
||||
return failure();
|
||||
}
|
||||
if (!outType.hasStaticShape()) {
|
||||
pim::emitUnsupportedStaticShapeDiagnostic(gemmOp, "Gemm result");
|
||||
return failure();
|
||||
}
|
||||
|
||||
const int64_t numOutRows = aType.getDimSize(0);
|
||||
if (numOutRows <= 1)
|
||||
@@ -438,7 +493,14 @@ LogicalResult GemmToSpatialComputeBatch::matchAndRewrite(ONNXGemmOp gemmOp,
|
||||
});
|
||||
cType = cast<RankedTensorType>(c.getType());
|
||||
}
|
||||
assert("Only support rank 2 tensor for C" && cType.getRank() == 2);
|
||||
if (!cType.hasStaticShape()) {
|
||||
pim::emitUnsupportedStaticShapeDiagnostic(gemmOp, "Gemm bias");
|
||||
return failure();
|
||||
}
|
||||
if (cType.getRank() != 2) {
|
||||
pim::emitUnsupportedRankDiagnostic(gemmOp, "Gemm bias", cType.getRank(), {1, 2});
|
||||
return failure();
|
||||
}
|
||||
// Row-specific bias can't share a single template body; fall through to GemmToManyGemv
|
||||
if (cType.getDimSize(0) == numOutRows && numOutRows > 1)
|
||||
return failure();
|
||||
|
||||
@@ -5,7 +5,7 @@
|
||||
#include "llvm/ADT/SmallPtrSet.h"
|
||||
#include "llvm/ADT/SmallVector.h"
|
||||
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common.hpp"
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/Common.hpp"
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Patterns.hpp"
|
||||
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
||||
#include "src/Dialect/ONNX/ONNXOps.hpp"
|
||||
|
||||
@@ -5,7 +5,7 @@
|
||||
|
||||
#include <algorithm>
|
||||
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common.hpp"
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/Common.hpp"
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Patterns.hpp"
|
||||
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
||||
#include "src/Dialect/ONNX/ONNXOps.hpp"
|
||||
|
||||
@@ -6,13 +6,12 @@
|
||||
#include "llvm/ADT/SmallVector.h"
|
||||
|
||||
#include <algorithm>
|
||||
#include <cassert>
|
||||
#include <optional>
|
||||
#include <type_traits>
|
||||
|
||||
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
|
||||
#include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp"
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common.hpp"
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/Common.hpp"
|
||||
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
||||
#include "src/Dialect/ONNX/ONNXOps.hpp"
|
||||
|
||||
@@ -31,8 +30,13 @@ static int64_t getOptionalI64(std::optional<ArrayAttrT> arrayAttr, size_t index,
|
||||
return arrayAttr ? getI64(*arrayAttr, index) : defaultValue;
|
||||
}
|
||||
|
||||
static Value concatAlongAxis(ConversionPatternRewriter& rewriter, Location loc, int64_t axis, ArrayRef<Value> values) {
|
||||
assert(!values.empty() && "Expected at least one value to concatenate.");
|
||||
template <typename PoolOp>
|
||||
static FailureOr<Value>
|
||||
concatAlongAxis(ConversionPatternRewriter& rewriter, Location loc, PoolOp poolOp, int64_t axis, ArrayRef<Value> values) {
|
||||
if (values.empty()) {
|
||||
poolOp.emitOpError("failed to build pooled output because an intermediate concatenation input list was empty");
|
||||
return failure();
|
||||
}
|
||||
return createSpatConcat(rewriter, loc, axis, values);
|
||||
}
|
||||
|
||||
@@ -51,8 +55,12 @@ static Value materializeContiguousTile(ConversionPatternRewriter& rewriter, Loca
|
||||
}
|
||||
|
||||
template <typename ReduceOp>
|
||||
static Value reduceWindowValues(ConversionPatternRewriter& rewriter, Location loc, ArrayRef<Value> windowValues) {
|
||||
assert(!windowValues.empty() && "Expected at least one pool window value.");
|
||||
static FailureOr<Value>
|
||||
reduceWindowValues(ConversionPatternRewriter& rewriter, Location loc, Operation* op, ArrayRef<Value> windowValues) {
|
||||
if (windowValues.empty()) {
|
||||
op->emitOpError("pool window resolved to zero valid elements");
|
||||
return failure();
|
||||
}
|
||||
|
||||
Value reduced = windowValues.front();
|
||||
for (Value value : windowValues.drop_front())
|
||||
@@ -60,9 +68,12 @@ static Value reduceWindowValues(ConversionPatternRewriter& rewriter, Location lo
|
||||
return reduced;
|
||||
}
|
||||
|
||||
static Value
|
||||
scaleAverageWindow(ConversionPatternRewriter& rewriter, Location loc, Value reducedWindow, int64_t divisor) {
|
||||
assert(divisor > 0 && "AveragePool divisor must be positive.");
|
||||
static FailureOr<Value>
|
||||
scaleAverageWindow(ConversionPatternRewriter& rewriter, Location loc, Operation* op, Value reducedWindow, int64_t divisor) {
|
||||
if (divisor <= 0) {
|
||||
op->emitOpError("AveragePool divisor must be positive");
|
||||
return failure();
|
||||
}
|
||||
if (divisor == 1)
|
||||
return reducedWindow;
|
||||
|
||||
@@ -70,7 +81,7 @@ scaleAverageWindow(ConversionPatternRewriter& rewriter, Location loc, Value redu
|
||||
double scale = 1.0 / static_cast<double>(divisor);
|
||||
auto scaleAttr = DenseElementsAttr::get(tileType, rewriter.getFloatAttr(tileType.getElementType(), scale));
|
||||
Value scaleTensor = arith::ConstantOp::create(rewriter, loc, tileType, scaleAttr);
|
||||
return spatial::SpatVMulOp::create(rewriter, loc, tileType, reducedWindow, scaleTensor);
|
||||
return spatial::SpatVMulOp::create(rewriter, loc, tileType, reducedWindow, scaleTensor).getResult();
|
||||
}
|
||||
|
||||
template <typename PoolOp>
|
||||
@@ -209,28 +220,45 @@ struct PoolToSpatialComputeBase : public OpConversionPattern<PoolOp> {
|
||||
if (windowValues.empty())
|
||||
return rewriter.notifyMatchFailure(poolOp, "pool window resolved to zero valid elements.");
|
||||
|
||||
Value reducedWindow = reduceWindowValues<ReduceOp>(rewriter, loc, windowValues);
|
||||
auto reducedWindow = reduceWindowValues<ReduceOp>(rewriter, loc, poolOp, windowValues);
|
||||
if (failed(reducedWindow))
|
||||
return failure();
|
||||
Value reducedWindowValue = *reducedWindow;
|
||||
if constexpr (std::is_same_v<PoolOp, ONNXAveragePoolOp>) {
|
||||
const bool countIncludePad = poolOp.getCountIncludePad() == 1;
|
||||
const int64_t divisor =
|
||||
countIncludePad ? kernelHeight * kernelWidth : static_cast<int64_t>(windowValues.size());
|
||||
reducedWindow = scaleAverageWindow(rewriter, loc, reducedWindow, divisor);
|
||||
auto scaledWindow = scaleAverageWindow(rewriter, loc, poolOp, reducedWindowValue, divisor);
|
||||
if (failed(scaledWindow))
|
||||
return failure();
|
||||
reducedWindowValue = *scaledWindow;
|
||||
}
|
||||
|
||||
outputChannelTiles.push_back(reducedWindow);
|
||||
outputChannelTiles.push_back(reducedWindowValue);
|
||||
}
|
||||
|
||||
rowPixels.push_back(concatAlongAxis(rewriter, loc, /*axis=*/1, outputChannelTiles));
|
||||
auto rowPixel = concatAlongAxis(rewriter, loc, poolOp, /*axis=*/1, outputChannelTiles);
|
||||
if (failed(rowPixel))
|
||||
return failure();
|
||||
rowPixels.push_back(*rowPixel);
|
||||
}
|
||||
|
||||
rows.push_back(concatAlongAxis(rewriter, loc, /*axis=*/3, rowPixels));
|
||||
auto row = concatAlongAxis(rewriter, loc, poolOp, /*axis=*/3, rowPixels);
|
||||
if (failed(row))
|
||||
return failure();
|
||||
rows.push_back(*row);
|
||||
}
|
||||
|
||||
batchResults.push_back(concatAlongAxis(rewriter, loc, /*axis=*/2, rows));
|
||||
auto batchResult = concatAlongAxis(rewriter, loc, poolOp, /*axis=*/2, rows);
|
||||
if (failed(batchResult))
|
||||
return failure();
|
||||
batchResults.push_back(*batchResult);
|
||||
}
|
||||
|
||||
Value pooledOutput = concatAlongAxis(rewriter, loc, /*axis=*/0, batchResults);
|
||||
spatial::SpatYieldOp::create(rewriter, loc, pooledOutput);
|
||||
auto pooledOutput = concatAlongAxis(rewriter, loc, poolOp, /*axis=*/0, batchResults);
|
||||
if (failed(pooledOutput))
|
||||
return failure();
|
||||
spatial::SpatYieldOp::create(rewriter, loc, *pooledOutput);
|
||||
return success();
|
||||
});
|
||||
if (failed(computeOp))
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
#include "mlir/Transforms/DialectConversion.h"
|
||||
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common.hpp"
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/Common.hpp"
|
||||
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
||||
#include "src/Dialect/ONNX/ONNXOps.hpp"
|
||||
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
#include "mlir/Transforms/DialectConversion.h"
|
||||
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common.hpp"
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/Common.hpp"
|
||||
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
||||
#include "src/Dialect/ONNX/ONNXOps.hpp"
|
||||
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
||||
#include "mlir/Transforms/DialectConversion.h"
|
||||
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common.hpp"
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/Common.hpp"
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Patterns.hpp"
|
||||
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
||||
#include "src/Dialect/ONNX/ONNXOps.hpp"
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
||||
#include "mlir/IR/PatternMatch.h"
|
||||
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common.hpp"
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/Common.hpp"
|
||||
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
||||
#include "src/Dialect/ONNX/ONNXOps.hpp"
|
||||
|
||||
|
||||
@@ -5,7 +5,7 @@
|
||||
|
||||
#include "llvm/ADT/SmallVector.h"
|
||||
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common.hpp"
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/Common.hpp"
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Patterns.hpp"
|
||||
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
||||
#include "src/Dialect/ONNX/ONNXOps.hpp"
|
||||
|
||||
@@ -5,7 +5,7 @@
|
||||
|
||||
#include <algorithm>
|
||||
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common.hpp"
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/Common.hpp"
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Patterns.hpp"
|
||||
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
||||
#include "src/Dialect/ONNX/ONNXOps.hpp"
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
||||
#include "mlir/Transforms/DialectConversion.h"
|
||||
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common.hpp"
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/Common.hpp"
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Patterns.hpp"
|
||||
#include "src/Dialect/ONNX/ONNXOps.hpp"
|
||||
|
||||
|
||||
Reference in New Issue
Block a user