refactor
Validate Operations / validate-operations (push) Has been cancelled

This commit is contained in:
NiccoloN
2026-06-29 14:00:10 +02:00
parent e8f09fd67f
commit f492400eda
37 changed files with 1407 additions and 1898 deletions
@@ -10,6 +10,7 @@ add_pim_library(OMONNXToSpatial
Patterns/Post.cpp
Patterns/GeneratedConversion.cpp
Patterns/Math/Conv.cpp
Patterns/Math/ConvGeometry.cpp
Patterns/Math/Elementwise.cpp
Patterns/Math/Gemm.cpp
Patterns/Math/MatMul.cpp
@@ -30,7 +31,7 @@ add_pim_library(OMONNXToSpatial
LowerSpatialPlansPass.cpp
Common/AttributeUtils.cpp
Common/ComputeRegionBuilder.cpp
Common/IndexingUtils.cpp
Common/MatrixProductLowering.cpp
Common/ShapeTilingUtils.cpp
Common/WeightMaterialization.cpp
@@ -2,7 +2,7 @@
#include "AttributeUtils.hpp"
#include "ComputeRegionBuilder.hpp"
#include "IndexingUtils.hpp"
#include "MatrixProductLowering.hpp"
#include "ShapeTilingUtils.hpp"
#include "WeightMaterialization.hpp"
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
@@ -1,45 +0,0 @@
#include <algorithm>
#include "IndexingUtils.hpp"
using namespace mlir;
namespace onnx_mlir {
int64_t normalizeAxis(int64_t axis, int64_t rank) { return axis >= 0 ? axis : rank + axis; }
FailureOr<int64_t> normalizeAxisChecked(int64_t axis, int64_t rank) {
int64_t normalizedAxis = normalizeAxis(axis, rank);
if (normalizedAxis < 0 || normalizedAxis >= rank)
return failure();
return normalizedAxis;
}
int64_t normalizeIndex(int64_t index, int64_t dimSize) { return index >= 0 ? index : dimSize + index; }
static SmallVector<int64_t> normalizeAxesImpl(std::optional<ArrayAttr> axesAttr, int64_t rank) {
SmallVector<int64_t> normalizedAxes;
if (!axesAttr) {
normalizedAxes.reserve(rank);
for (int64_t axis = 0; axis < rank; ++axis)
normalizedAxes.push_back(axis);
}
else {
normalizedAxes.reserve(axesAttr->size());
for (Attribute attr : *axesAttr)
normalizedAxes.push_back(normalizeAxis(cast<IntegerAttr>(attr).getInt(), rank));
llvm::sort(normalizedAxes);
normalizedAxes.erase(std::unique(normalizedAxes.begin(), normalizedAxes.end()), normalizedAxes.end());
}
return normalizedAxes;
}
FailureOr<SmallVector<int64_t>> normalizeAxesChecked(std::optional<ArrayAttr> axesAttr, int64_t rank) {
SmallVector<int64_t> normalizedAxes = normalizeAxesImpl(axesAttr, rank);
for (int64_t axis : normalizedAxes)
if (axis < 0 || axis >= rank)
return failure();
return normalizedAxes;
}
} // namespace onnx_mlir
@@ -1,20 +0,0 @@
#pragma once
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/Support/LogicalResult.h"
#include "llvm/ADT/SmallVector.h"
#include <optional>
namespace onnx_mlir {
int64_t normalizeAxis(int64_t axis, int64_t rank);
mlir::FailureOr<int64_t> normalizeAxisChecked(int64_t axis, int64_t rank);
int64_t normalizeIndex(int64_t index, int64_t dimSize);
mlir::FailureOr<llvm::SmallVector<int64_t>> normalizeAxesChecked(std::optional<mlir::ArrayAttr> axesAttr, int64_t rank);
} // namespace onnx_mlir
@@ -0,0 +1,48 @@
#include "MatrixProductLowering.hpp"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/Common.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
using namespace mlir;
namespace onnx_mlir {
Value createZeroPaddedTensor(Value value, RankedTensorType resultType, PatternRewriter& rewriter, Location loc) {
auto sourceType = cast<RankedTensorType>(value.getType());
SmallVector<OpFoldResult> lowPads(sourceType.getRank(), rewriter.getIndexAttr(0));
SmallVector<OpFoldResult> highPads;
highPads.reserve(sourceType.getRank());
for (auto [sourceDim, resultDim] : llvm::zip(sourceType.getShape(), resultType.getShape()))
highPads.push_back(rewriter.getIndexAttr(resultDim - sourceDim));
auto padOp = tensor::PadOp::create(rewriter, loc, resultType, value, lowPads, highPads);
auto* padBlock = new Block();
for (int64_t i = 0; i < sourceType.getRank(); ++i)
padBlock->addArgument(rewriter.getIndexType(), loc);
padOp.getRegion().push_back(padBlock);
rewriter.setInsertionPointToStart(padBlock);
auto zero = getOrCreateConstant(
rewriter, padOp.getOperation(), rewriter.getZeroAttr(sourceType.getElementType()), sourceType.getElementType());
tensor::YieldOp::create(rewriter, loc, zero);
rewriter.setInsertionPointAfter(padOp);
return padOp.getResult();
}
Value createPaddedInputCompute(Value input,
RankedTensorType paddedInputType,
PatternRewriter& rewriter,
Location loc) {
auto inputType = cast<RankedTensorType>(input.getType());
if (inputType == paddedInputType)
return input;
auto computeOp = createSpatCompute<1>(rewriter, loc, TypeRange {paddedInputType}, {}, input, [&](Value computeInput) {
Value paddedInput = createZeroPaddedTensor(computeInput, paddedInputType, rewriter, loc);
spatial::SpatYieldOp::create(rewriter, loc, paddedInput);
});
return computeOp.getResult(0);
}
} // namespace onnx_mlir
@@ -0,0 +1,20 @@
#pragma once
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Location.h"
#include "mlir/IR/Value.h"
#include "mlir/Transforms/DialectConversion.h"
namespace onnx_mlir {
mlir::Value createZeroPaddedTensor(mlir::Value value,
mlir::RankedTensorType resultType,
mlir::PatternRewriter& rewriter,
mlir::Location loc);
mlir::Value createPaddedInputCompute(mlir::Value input,
mlir::RankedTensorType paddedInputType,
mlir::PatternRewriter& rewriter,
mlir::Location loc);
} // namespace onnx_mlir
@@ -3,9 +3,6 @@
#include "llvm/ADT/SmallVector.h"
#include <functional>
#include "IndexingUtils.hpp"
#include "ShapeTilingUtils.hpp"
#include "src/Accelerators/PIM/Common/IR/ConstantUtils.hpp"
#include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp"
@@ -15,73 +12,6 @@ using namespace mlir;
namespace onnx_mlir {
bool hasStaticPositiveShape(ArrayRef<int64_t> shape) {
return llvm::all_of(shape, [](int64_t dim) { return dim > 0; });
}
bool hasStaticPositiveShape(RankedTensorType type) {
return type.hasStaticShape() && hasStaticPositiveShape(type.getShape());
}
int64_t getStaticShapeElementCount(ArrayRef<int64_t> shape) {
return std::accumulate(shape.begin(), shape.end(), int64_t {1}, std::multiplies<int64_t> {});
}
SmallVector<int64_t> permuteShape(ArrayRef<int64_t> shape, ArrayRef<int64_t> permutation) {
SmallVector<int64_t> permutedShape;
permutedShape.reserve(permutation.size());
for (int64_t axis : permutation)
permutedShape.push_back(shape[axis]);
return permutedShape;
}
SmallVector<int64_t> invertPermutation(ArrayRef<int64_t> permutation) {
SmallVector<int64_t> inversePermutation(permutation.size());
for (auto [newIndex, oldIndex] : llvm::enumerate(permutation))
inversePermutation[oldIndex] = static_cast<int64_t>(newIndex);
return inversePermutation;
}
FailureOr<SmallVector<int64_t>> getTransposePermutationChecked(std::optional<ArrayAttr> permAttr, int64_t rank) {
SmallVector<int64_t> permutation;
if (!permAttr) {
permutation.reserve(rank);
for (int64_t dim = rank - 1; dim >= 0; --dim)
permutation.push_back(dim);
return permutation;
}
if (static_cast<int64_t>(permAttr->size()) != rank)
return failure();
permutation.reserve(permAttr->size());
SmallVector<bool> seen(rank, false);
for (IntegerAttr attr : permAttr->getAsRange<IntegerAttr>()) {
int64_t axis = attr.getInt();
if (axis < 0 || axis >= rank || seen[axis])
return failure();
seen[axis] = true;
permutation.push_back(axis);
}
return permutation;
}
SmallVector<OpFoldResult> getUnitStrides(PatternRewriter& rewriter, int64_t rank) {
return SmallVector<OpFoldResult>(rank, rewriter.getIndexAttr(1));
}
SmallVector<OpFoldResult> getZeroOffsets(PatternRewriter& rewriter, int64_t rank) {
return SmallVector<OpFoldResult>(rank, rewriter.getIndexAttr(0));
}
SmallVector<OpFoldResult> getStaticSizes(PatternRewriter& rewriter, ArrayRef<int64_t> shape) {
SmallVector<OpFoldResult> sizes;
sizes.reserve(shape.size());
for (int64_t dim : shape)
sizes.push_back(rewriter.getIndexAttr(dim));
return sizes;
}
SmallVector<Value> sliceTensor(
const Value& tensorToSlice, size_t axis, int64_t sliceSize, PatternRewriter& rewriter, Location loc) {
ArrayRef<long> shape = getTensorShape(tensorToSlice);
@@ -1,89 +1,15 @@
#pragma once
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Value.h"
#include "mlir/IR/ValueRange.h"
#include "mlir/Transforms/DialectConversion.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/SmallVector.h"
#include <cassert>
#include <cstddef>
#include <optional>
#include <type_traits>
#include <utility>
#include "src/Accelerators/PIM/Common/IR/ShapeUtils.hpp"
namespace onnx_mlir {
using HSliceId = size_t;
using CoreId = size_t;
template <class A, class B, class C = std::common_type_t<A, B>>
constexpr C ceilIntegerDivide(A a, B b) {
static_assert(std::is_integral_v<A>, "A must be an integer type");
static_assert(std::is_integral_v<B>, "B must be an integer type");
C ac = static_cast<C>(a);
C bc = static_cast<C>(b);
return 1 + (ac - 1) / bc;
}
template <class A, class B, class C = std::common_type_t<A, B>>
constexpr std::pair<C, C> ceilIntegerDivideWithRemainder(A a, B b) {
static_assert(std::is_integral_v<A>, "A must be an integer type");
static_assert(std::is_integral_v<B>, "B must be an integer type");
C ac = static_cast<C>(a);
C bc = static_cast<C>(b);
return {ceilIntegerDivide(ac, bc), ac % bc};
}
template <class T>
bool isVectorShape(mlir::ArrayRef<T> shape) {
return shape.size() == 2 && (shape[0] == 1 || shape[1] == 1);
}
template <class T>
bool isMatrixShape(mlir::ArrayRef<T> shape) {
return shape.size() == 2;
}
template <class T>
bool isHVectorShape(mlir::ArrayRef<T> shape) {
return shape.size() == 2 && shape[0] == 1;
}
inline auto getTensorShape(mlir::Value tensor) {
return mlir::cast<mlir::RankedTensorType>(tensor.getType()).getShape();
}
inline bool haveSameStaticShape(mlir::Value lhs, mlir::Value rhs) {
auto lhsType = mlir::dyn_cast<mlir::RankedTensorType>(lhs.getType());
auto rhsType = mlir::dyn_cast<mlir::RankedTensorType>(rhs.getType());
return lhsType && rhsType && lhsType.hasStaticShape() && rhsType.hasStaticShape()
&& lhsType.getShape() == rhsType.getShape();
}
bool hasStaticPositiveShape(mlir::ArrayRef<int64_t> shape);
bool hasStaticPositiveShape(mlir::RankedTensorType type);
int64_t getStaticShapeElementCount(mlir::ArrayRef<int64_t> shape);
llvm::SmallVector<int64_t> permuteShape(mlir::ArrayRef<int64_t> shape, mlir::ArrayRef<int64_t> permutation);
llvm::SmallVector<int64_t> invertPermutation(mlir::ArrayRef<int64_t> permutation);
mlir::FailureOr<llvm::SmallVector<int64_t>> getTransposePermutationChecked(std::optional<mlir::ArrayAttr> permAttr,
int64_t rank);
llvm::SmallVector<mlir::OpFoldResult> getUnitStrides(mlir::PatternRewriter& rewriter, int64_t rank);
llvm::SmallVector<mlir::OpFoldResult> getZeroOffsets(mlir::PatternRewriter& rewriter, int64_t rank);
llvm::SmallVector<mlir::OpFoldResult> getStaticSizes(mlir::PatternRewriter& rewriter, mlir::ArrayRef<int64_t> shape);
/// Slices a statically shaped tensor along one axis into contiguous pieces of
/// at most `sliceSize` elements.
llvm::SmallVector<mlir::Value> sliceTensor(const mlir::Value& tensorToSlice,
@@ -26,6 +26,7 @@
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/Common.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/CompileTime.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/PlanLowering.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Patterns/Math/ConvGeometry.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
#include "src/Dialect/ONNX/ONNXOps.hpp"
@@ -42,59 +43,6 @@ struct ConvToGemm : OpConversionPattern<ONNXConvOp> {
ConversionPatternRewriter& rewriter) const override;
};
struct ConvLoweringState {
Value x;
Value w;
Value b;
RankedTensorType xType;
RankedTensorType wType;
RankedTensorType outType;
int64_t batchSize;
int64_t numChannelsIn;
int64_t xHeight;
int64_t xWidth;
int64_t numChannelsOut;
int64_t wHeight;
int64_t wWidth;
int64_t outHeight;
int64_t outWidth;
int64_t group;
int64_t numChannelsInPerGroup;
int64_t numChannelsOutPerGroup;
int64_t padHeightBegin;
int64_t padHeightEnd;
int64_t padWidthBegin;
int64_t padWidthEnd;
int64_t strideHeight;
int64_t strideWidth;
int64_t dilationHeight;
int64_t dilationWidth;
bool hasBias;
};
struct ConvGeometry {
int64_t batchSize;
int64_t numChannelsIn;
int64_t xHeight;
int64_t xWidth;
int64_t numChannelsOut;
int64_t wHeight;
int64_t wWidth;
int64_t outHeight;
int64_t outWidth;
int64_t group;
int64_t numChannelsInPerGroup;
int64_t numChannelsOutPerGroup;
int64_t k;
int64_t c;
int64_t p;
int64_t xbarSize;
int64_t pack;
uint64_t im2colElements;
bool hasBias;
bool isDepthwise;
};
struct ConvLoweringDecision {
PimConvLoweringType strategy;
std::string reason;
@@ -108,19 +56,6 @@ struct PreparedConvInput {
RankedTensorType type;
};
struct RowInterval {
int64_t begin = 0;
int64_t end = 0;
};
struct ConvRowDemand {
RowInterval outputRows;
RowInterval neededInputRows;
RowInterval acquiredInputRows;
int64_t topHaloRows = 0;
int64_t bottomHaloRows = 0;
};
struct ConvStrategyEstimate {
uint64_t estimatedMvmCount = 0;
uint64_t estimatedReductionVAddCount = 0;
@@ -291,9 +226,6 @@ static FailureOr<Value> createRowStripPackedRows(Value rows,
PatternRewriter& rewriter,
Location loc);
static bool
isDepthwiseConv(int64_t group, int64_t numChannelsIn, int64_t numChannelsOut, int64_t numChannelsInPerGroup);
static uint64_t chooseStreamChunkPositions(const ConvGeometry& geo, int64_t packFactor);
static FailureOr<ConvLoweringState> analyzeConvLoweringState(ONNXConvOp convOp, Value x, Value w, Value b);
static StringRef stringifyDistributedConvBarrierKind(DistributedConvBarrierKind kind) {
@@ -391,34 +323,6 @@ static ConvStrategyEstimate estimateConvStrategy(const ConvGeometry& geo,
return estimate;
}
static ConvGeometry buildConvGeometry(const ConvLoweringState& state) {
ConvGeometry geo {
state.batchSize,
state.numChannelsIn,
state.xHeight,
state.xWidth,
state.numChannelsOut,
state.wHeight,
state.wWidth,
state.outHeight,
state.outWidth,
state.group,
state.numChannelsInPerGroup,
state.numChannelsOutPerGroup,
state.numChannelsInPerGroup * state.wHeight * state.wWidth,
state.numChannelsOutPerGroup,
state.batchSize * state.outHeight * state.outWidth,
static_cast<int64_t>(crossbarSize.getValue()),
1,
0,
state.hasBias,
isDepthwiseConv(state.group, state.numChannelsIn, state.numChannelsOut, state.numChannelsInPerGroup),
};
geo.pack = std::max<int64_t>(1, geo.xbarSize / std::max<int64_t>(geo.k, geo.c));
geo.im2colElements = static_cast<uint64_t>(std::max<int64_t>(0, geo.p)) * static_cast<uint64_t>(std::max<int64_t>(0, geo.k));
return geo;
}
static std::string formatShape(ArrayRef<int64_t> dims) {
std::string text;
llvm::raw_string_ostream os(text);
@@ -563,36 +467,10 @@ classifyDistributedBinaryConsumer(Operation* user,
return std::nullopt;
}
static RowInterval computeConvInputRowsForOutputRows(RowInterval outputRows,
int64_t inputHeight,
int64_t kernelH,
int64_t strideH,
int64_t dilationH,
int64_t padTop) {
const int64_t rawBegin = outputRows.begin * strideH - padTop;
const int64_t rawEnd = (outputRows.end - 1) * strideH - padTop + dilationH * (kernelH - 1) + 1;
return {std::max<int64_t>(0, rawBegin), std::min<int64_t>(inputHeight, rawEnd)};
}
static bool covers(RowInterval acquired, RowInterval needed) {
return acquired.begin <= needed.begin && acquired.end >= needed.end;
}
static ConvRowDemand buildConvRowDemand(RowInterval outputRows, const ConvLoweringState& state) {
const int64_t rawBegin = outputRows.begin * state.strideHeight - state.padHeightBegin;
const int64_t rawEnd =
(outputRows.end - 1) * state.strideHeight - state.padHeightBegin + state.dilationHeight * (state.wHeight - 1) + 1;
RowInterval neededInputRows = computeConvInputRowsForOutputRows(
outputRows, state.xHeight, state.wHeight, state.strideHeight, state.dilationHeight, state.padHeightBegin);
ConvRowDemand demand;
demand.outputRows = outputRows;
demand.neededInputRows = neededInputRows;
demand.acquiredInputRows = neededInputRows;
demand.topHaloRows = std::max<int64_t>(0, -rawBegin);
demand.bottomHaloRows = std::max<int64_t>(0, rawEnd - state.xHeight);
return demand;
}
static bool canConsumeRowStripHwcInput(const ConvLoweringState& state, StringRef& failureReason) {
if (state.batchSize != 1) {
failureReason = "unsupported_batch";
@@ -1250,19 +1128,6 @@ static void reportConvLoweringDecision(ONNXConvOp convOp,
rewriteConvLoweringReport(reportEntries);
}
static uint64_t chooseStreamChunkPositions(const ConvGeometry& geo, int64_t packFactor) {
const uint64_t patchElements = static_cast<uint64_t>(std::max<int64_t>(1, geo.k));
uint64_t chunkPositions = std::max<uint64_t>(1, pimConvIm2colMaxElements / patchElements);
chunkPositions = std::min<uint64_t>(chunkPositions, static_cast<uint64_t>(std::max<int64_t>(1, geo.p)));
chunkPositions = std::min<uint64_t>(chunkPositions, std::max<uint64_t>(1, pimConvStreamChunkPositions));
if (packFactor > 1 && chunkPositions > static_cast<uint64_t>(packFactor)) {
chunkPositions -= chunkPositions % static_cast<uint64_t>(packFactor);
chunkPositions = std::max<uint64_t>(chunkPositions, static_cast<uint64_t>(packFactor));
}
return std::max<uint64_t>(1, chunkPositions);
}
static Value expandBiasIfNeeded(Value bias, PatternRewriter& rewriter, Location loc) {
auto biasType = cast<RankedTensorType>(bias.getType());
if (biasType.getRank() != 1)
@@ -1278,11 +1143,6 @@ static Value expandBiasIfNeeded(Value bias, PatternRewriter& rewriter, Location
});
}
static bool
isDepthwiseConv(int64_t group, int64_t numChannelsIn, int64_t numChannelsOut, int64_t numChannelsInPerGroup) {
return group == numChannelsIn && numChannelsInPerGroup == 1 && numChannelsOut % group == 0;
}
static int64_t findLargestDivisorAtMost(int64_t value, int64_t limit) {
assert(value > 0 && "expected positive value");
limit = std::min(value, limit);
@@ -0,0 +1,77 @@
#include "ConvGeometry.hpp"
#include <algorithm>
#include "src/Accelerators/PIM/Common/IR/ShapeUtils.hpp"
#include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp"
namespace onnx_mlir {
bool isDepthwiseConv(int64_t group, int64_t numChannelsIn, int64_t numChannelsOut, int64_t numChannelsInPerGroup) {
return group == numChannelsIn && numChannelsInPerGroup == 1 && numChannelsOut % group == 0;
}
ConvGeometry buildConvGeometry(const ConvLoweringState& state) {
ConvGeometry geo {
state.batchSize,
state.numChannelsIn,
state.xHeight,
state.xWidth,
state.numChannelsOut,
state.wHeight,
state.wWidth,
state.outHeight,
state.outWidth,
state.group,
state.numChannelsInPerGroup,
state.numChannelsOutPerGroup,
state.numChannelsInPerGroup * state.wHeight * state.wWidth,
state.numChannelsOutPerGroup,
state.batchSize * state.outHeight * state.outWidth,
static_cast<int64_t>(crossbarSize.getValue()),
1,
0,
state.hasBias,
isDepthwiseConv(state.group, state.numChannelsIn, state.numChannelsOut, state.numChannelsInPerGroup),
};
geo.pack = std::max<int64_t>(1, geo.xbarSize / std::max<int64_t>(geo.k, geo.c));
geo.im2colElements = static_cast<uint64_t>(std::max<int64_t>(0, geo.p)) * static_cast<uint64_t>(std::max<int64_t>(0, geo.k));
return geo;
}
uint64_t chooseStreamChunkPositions(const ConvGeometry& geo, int64_t packFactor) {
const uint64_t patchElements = static_cast<uint64_t>(std::max<int64_t>(1, geo.k));
uint64_t chunkPositions = std::max<uint64_t>(1, pimConvIm2colMaxElements / patchElements);
chunkPositions = std::min<uint64_t>(chunkPositions, static_cast<uint64_t>(std::max<int64_t>(1, geo.p)));
chunkPositions = std::min<uint64_t>(chunkPositions, std::max<uint64_t>(1, pimConvStreamChunkPositions));
if (packFactor > 1 && chunkPositions > static_cast<uint64_t>(packFactor)) {
chunkPositions -= chunkPositions % static_cast<uint64_t>(packFactor);
chunkPositions = std::max<uint64_t>(chunkPositions, static_cast<uint64_t>(packFactor));
}
return std::max<uint64_t>(1, chunkPositions);
}
RowInterval computeConvInputRowsForOutputRows(RowInterval outputRows, const ConvLoweringState& state) {
const int64_t rawBegin = outputRows.begin * state.strideHeight - state.padHeightBegin;
const int64_t rawEnd =
(outputRows.end - 1) * state.strideHeight - state.padHeightBegin + state.dilationHeight * (state.wHeight - 1) + 1;
return {std::max<int64_t>(0, rawBegin), std::min<int64_t>(state.xHeight, rawEnd)};
}
ConvRowDemand buildConvRowDemand(RowInterval outputRows, const ConvLoweringState& state) {
ConvRowDemand demand;
demand.outputRows = outputRows;
demand.neededInputRows = computeConvInputRowsForOutputRows(outputRows, state);
demand.acquiredInputRows = demand.neededInputRows;
const int64_t rawBegin = outputRows.begin * state.strideHeight - state.padHeightBegin;
const int64_t rawEnd =
(outputRows.end - 1) * state.strideHeight - state.padHeightBegin + state.dilationHeight * (state.wHeight - 1) + 1;
demand.topHaloRows = std::max<int64_t>(0, -rawBegin);
demand.bottomHaloRows = std::max<int64_t>(0, rawEnd - state.xHeight);
demand.acquiredInputRows = demand.neededInputRows;
return demand;
}
} // namespace onnx_mlir
@@ -0,0 +1,86 @@
#pragma once
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Value.h"
#include <cstdint>
namespace onnx_mlir {
struct ConvLoweringState {
mlir::Value x;
mlir::Value w;
mlir::Value b;
mlir::RankedTensorType xType;
mlir::RankedTensorType wType;
mlir::RankedTensorType outType;
int64_t batchSize;
int64_t numChannelsIn;
int64_t xHeight;
int64_t xWidth;
int64_t numChannelsOut;
int64_t wHeight;
int64_t wWidth;
int64_t outHeight;
int64_t outWidth;
int64_t group;
int64_t numChannelsInPerGroup;
int64_t numChannelsOutPerGroup;
int64_t padHeightBegin;
int64_t padHeightEnd;
int64_t padWidthBegin;
int64_t padWidthEnd;
int64_t strideHeight;
int64_t strideWidth;
int64_t dilationHeight;
int64_t dilationWidth;
bool hasBias;
};
struct ConvGeometry {
int64_t batchSize;
int64_t numChannelsIn;
int64_t xHeight;
int64_t xWidth;
int64_t numChannelsOut;
int64_t wHeight;
int64_t wWidth;
int64_t outHeight;
int64_t outWidth;
int64_t group;
int64_t numChannelsInPerGroup;
int64_t numChannelsOutPerGroup;
int64_t k;
int64_t c;
int64_t p;
int64_t xbarSize;
int64_t pack;
uint64_t im2colElements;
bool hasBias;
bool isDepthwise;
};
struct RowInterval {
int64_t begin = 0;
int64_t end = 0;
};
struct ConvRowDemand {
RowInterval outputRows;
RowInterval neededInputRows;
RowInterval acquiredInputRows;
int64_t topHaloRows = 0;
int64_t bottomHaloRows = 0;
};
bool isDepthwiseConv(int64_t group, int64_t numChannelsIn, int64_t numChannelsOut, int64_t numChannelsInPerGroup);
ConvGeometry buildConvGeometry(const ConvLoweringState& state);
uint64_t chooseStreamChunkPositions(const ConvGeometry& geo, int64_t packFactor);
RowInterval computeConvInputRowsForOutputRows(RowInterval outputRows, const ConvLoweringState& state);
ConvRowDemand buildConvRowDemand(RowInterval outputRows, const ConvLoweringState& state);
} // namespace onnx_mlir
@@ -87,28 +87,6 @@ static Value createGemmBatchHOffset(Value lane,
rewriter.getInsertionBlock()->getParentOp());
}
static Value
createZeroPaddedTensor(Value value, RankedTensorType resultType, ConversionPatternRewriter& rewriter, Location loc) {
auto sourceType = cast<RankedTensorType>(value.getType());
SmallVector<OpFoldResult> lowPads(sourceType.getRank(), rewriter.getIndexAttr(0));
SmallVector<OpFoldResult> highPads;
highPads.reserve(sourceType.getRank());
for (auto [sourceDim, resultDim] : llvm::zip(sourceType.getShape(), resultType.getShape()))
highPads.push_back(rewriter.getIndexAttr(resultDim - sourceDim));
auto padOp = tensor::PadOp::create(rewriter, loc, resultType, value, lowPads, highPads);
auto* padBlock = new Block();
for (int64_t i = 0; i < sourceType.getRank(); ++i)
padBlock->addArgument(rewriter.getIndexType(), loc);
padOp.getRegion().push_back(padBlock);
rewriter.setInsertionPointToStart(padBlock);
auto zero = getOrCreateConstant(
rewriter, padOp.getOperation(), rewriter.getZeroAttr(sourceType.getElementType()), sourceType.getElementType());
tensor::YieldOp::create(rewriter, loc, zero);
rewriter.setInsertionPointAfter(padOp);
return padOp.getResult();
}
static FailureOr<Value> materializePaddedConstantMatrix(Value value,
RankedTensorType resultType,
ConversionPatternRewriter& rewriter,
@@ -232,22 +210,6 @@ static Value extractATile(
return tensor::ExtractSliceOp::create(rewriter, loc, aTileType, a, offsets, sizes, strides).getResult();
}
static Value createPaddedInputCompute(Value input,
RankedTensorType paddedInputType,
ConversionPatternRewriter& rewriter,
Location loc) {
auto inputType = cast<RankedTensorType>(input.getType());
if (inputType == paddedInputType)
return input;
auto computeOp = createSpatCompute<1>(rewriter, loc, TypeRange {paddedInputType}, {}, input, [&](Value computeInput) {
Value paddedInput = createZeroPaddedTensor(computeInput, paddedInputType, rewriter, loc);
spatial::SpatYieldOp::create(rewriter, loc, paddedInput);
});
return computeOp.getResult(0);
}
static FailureOr<spatial::SpatComputeBatch> createVmmBatch(Value a,
Value b,
RankedTensorType aType,
@@ -255,42 +255,6 @@ static Value transposeLastTwoDims(Value value, PatternRewriter& rewriter, Locati
return createONNXTranspose(resultType, {0, 2, 1});
}
static Value createZeroPaddedTensor(Value value, RankedTensorType resultType, PatternRewriter& rewriter, Location loc) {
auto sourceType = cast<RankedTensorType>(value.getType());
SmallVector<OpFoldResult> lowPads(sourceType.getRank(), rewriter.getIndexAttr(0));
SmallVector<OpFoldResult> highPads;
highPads.reserve(sourceType.getRank());
for (auto [sourceDim, resultDim] : llvm::zip(sourceType.getShape(), resultType.getShape()))
highPads.push_back(rewriter.getIndexAttr(resultDim - sourceDim));
auto padOp = tensor::PadOp::create(rewriter, loc, resultType, value, lowPads, highPads);
auto* padBlock = new Block();
for (int64_t i = 0; i < sourceType.getRank(); ++i)
padBlock->addArgument(rewriter.getIndexType(), loc);
padOp.getRegion().push_back(padBlock);
rewriter.setInsertionPointToStart(padBlock);
auto zero = getOrCreateConstant(
rewriter, padOp.getOperation(), rewriter.getZeroAttr(sourceType.getElementType()), sourceType.getElementType());
tensor::YieldOp::create(rewriter, loc, zero);
rewriter.setInsertionPointAfter(padOp);
return padOp.getResult();
}
static Value createPaddedBatchedInputCompute(Value input,
RankedTensorType paddedInputType,
PatternRewriter& rewriter,
Location loc) {
auto inputType = cast<RankedTensorType>(input.getType());
if (inputType == paddedInputType)
return input;
auto computeOp = createSpatCompute<1>(rewriter, loc, TypeRange {paddedInputType}, {}, input, [&](Value computeInput) {
Value paddedInput = createZeroPaddedTensor(computeInput, paddedInputType, rewriter, loc);
spatial::SpatYieldOp::create(rewriter, loc, paddedInput);
});
return computeOp.getResult(0);
}
static FailureOr<Value> materializePaddedBatchedWeight(Value value,
ArrayRef<int64_t> sourceBatchShape,
ArrayRef<int64_t> targetBatchShape,
@@ -1055,7 +1019,7 @@ struct MatMulBatchedToSpatialComputes : OpRewritePattern<ONNXMatMulOp> {
auto paddedRhs =
materializePaddedBatchedWeight(plan.rhs, plan.rhsBatchShape, plan.outputBatchShape, paddedRhsType, rewriter);
if (succeeded(paddedRhs)) {
Value paddedLhs = createPaddedBatchedInputCompute(plan.lhs, paddedLhsType, rewriter, loc);
Value paddedLhs = createPaddedInputCompute(plan.lhs, paddedLhsType, rewriter, loc);
const int64_t laneCount = plan.batch * plan.m * numKSlices * numOutHSlices;
auto partialPiecesType = RankedTensorType::get({laneCount, static_cast<int64_t>(crossbarSize.getValue())},
shapeInfo->outType.getElementType());
@@ -29,100 +29,6 @@ static bool isUsedOnlyAsExplicitHostOperand(Value value) {
});
}
static bool isMaterializableExternalTensorOp(Operation* op) {
return isa<spatial::SpatChannelReceiveOp,
spatial::SpatExtractRowsOp,
tensor::ExtractSliceOp,
tensor::ExpandShapeOp,
tensor::CollapseShapeOp>(op);
}
//TODO REMOVE THIS UGLY FIX
//TODO: Remove this helper once compute_batch external tensor captures are
// fixed at the producer side.
//
// This function is a temporary SpatialToPim repair path. It clones selected
// external tensor producers, such as channel_receive and tensor view/slice ops,
// into the new pim.core_batch body when the old spat.compute_batch body refers
// to tensor values defined outside the batch.
//
// The real invariant should be stronger:
//
// A spat.compute_batch body must not capture external tensor values.
// Every tensor used inside the body must be either:
// - a compute_batch block argument,
// - defined inside the compute_batch body,
// - or a legal constant-like value.
//
// If this invariant is violated, the responsible producer, most likely merge
// schedule materialization, should emit verifier-clean Spatial IR instead of
// relying on SpatialToPim to clone external producer chains later.
//
// After that producer-side fix:
// 1. remove isMaterializableExternalTensorOp,
// 2. remove materializeExternalTensorValue,
// 3. make lowerComputeBatchOp emit a hard diagnostic for any unmapped external
// tensor operand,
// 4. keep/strengthen the Spatial verifier so the invalid capture is rejected
// before SpatialToPim.
//
// Be careful not to replace every external tensor capture with a normal
// compute_batch input blindly: host-backed tensors and explicit inter-core
// communication have different semantics. In particular, channel_receive-like
// values should be materialized through the communication model, not silently
// treated as host inputs.
static FailureOr<Value> materializeExternalTensorValue(IRRewriter& rewriter,
Location loc,
Block& oldBlock,
Value value,
IRMapping& mapper) {
if (mapper.contains(value))
return mapper.lookup(value);
if (!isa<TensorType>(value.getType()))
return value;
Operation* definingOp = value.getDefiningOp();
if (!definingOp || definingOp->hasTrait<OpTrait::ConstantLike>())
return failure();
if (definingOp->getBlock() == &oldBlock)
return failure();
if (!isMaterializableExternalTensorOp(definingOp))
return failure();
for (Value operand : definingOp->getOperands()) {
FailureOr<Value> materializedOperand = materializeExternalTensorValue(rewriter, loc, oldBlock, operand, mapper);
if (succeeded(materializedOperand))
mapper.map(operand, *materializedOperand);
}
Operation* cloned = rewriter.clone(*definingOp, mapper);
for (auto [originalResult, clonedResult] : llvm::zip(definingOp->getResults(), cloned->getResults()))
mapper.map(originalResult, clonedResult);
return mapper.lookup(value);
}
static FailureOr<SmallVector<int32_t>> getPimCoreIdsForBatchOp(spatial::SpatScheduledComputeBatch computeBatchOp,
size_t& fallbackCoreId) {
if (auto coreIdsAttr = computeBatchOp->getAttrOfType<DenseI32ArrayAttr>(onnx_mlir::kCoreIdsAttrName))
return SmallVector<int32_t>(coreIdsAttr.asArrayRef().begin(), coreIdsAttr.asArrayRef().end());
SmallVector<int32_t> coreIds;
coreIds.reserve(static_cast<size_t>(computeBatchOp.getLaneCount()));
for (uint32_t lane = 0; lane < computeBatchOp.getLaneCount(); ++lane) {
auto checkedCoreId =
pim::checkedI32(static_cast<uint64_t>(fallbackCoreId), computeBatchOp, "fallback spatial compute_batch core id");
if (failed(checkedCoreId))
return failure();
coreIds.push_back(*checkedCoreId);
++fallbackCoreId;
}
return coreIds;
}
static FailureOr<unsigned> getDirectReturnOperandIndex(OpResult result) {
if (!result.hasOneUse())
return failure();
@@ -386,7 +292,7 @@ LogicalResult raptor::SpatialToPimPass::lowerComputeBatchOp(spatial::SpatSchedul
"resultful compute_batch lowering currently requires a spat.in_parallel terminator");
}
auto coreIds = getPimCoreIdsForBatchOp(computeBatchOp, coreId);
auto coreIds = getRequiredScheduledBatchCoreIds(computeBatchOp, "spatial compute_batch core id");
if (failed(coreIds))
return failure();
SmallVector<Value> batchWeights(computeBatchOp.getWeights().begin(), computeBatchOp.getWeights().end());
@@ -638,9 +544,6 @@ LogicalResult raptor::SpatialToPimPass::lowerComputeBatchOp(spatial::SpatSchedul
if (definingOp && definingOp->hasTrait<OpTrait::ConstantLike>())
continue;
if (succeeded(materializeExternalTensorValue(rewriter, loc, oldBlock, operand, mapper)))
continue;
InFlightDiagnostic diagnostic =
computeBatchOp.emitOpError("expected external tensor communication to be materialized in Spatial before batch lowering");
diagnostic << " while cloning nested op '" << op.getName() << "' tensor operand #" << operandIndex;
@@ -9,6 +9,7 @@
#include "Conversion/ONNXToSpatial/Common/Common.hpp"
#include "Conversion/SpatialToPim/SpatialToPimPass.hpp"
#include "src/Accelerators/PIM/Common/IR/BatchCoreUtils.hpp"
#include "src/Accelerators/PIM/Common/IR/LoopUtils.hpp"
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
#include "src/Accelerators/PIM/Common/Support/CheckedArithmetic.hpp"
@@ -141,17 +142,6 @@ cloneMappedHelperOperands(Operation* op, IRMapping& mapping, IRRewriter& rewrite
}
}
static FailureOr<int32_t> getPimCoreIdForComputeOp(spatial::SpatScheduledCompute computeOp, size_t& fallbackCoreId) {
if (auto spatialCoreIdAttr = computeOp->getAttrOfType<IntegerAttr>(onnx_mlir::kCoreIdAttrName))
return pim::checkedI32(spatialCoreIdAttr.getInt(), computeOp, "spatial compute core id");
auto checkedCoreId =
pim::checkedI32(static_cast<uint64_t>(fallbackCoreId), computeOp, "fallback spatial compute core id");
if (failed(checkedCoreId))
return failure();
++fallbackCoreId;
return *checkedCoreId;
}
static LogicalResult collectHelperComputeChain(spatial::SpatScheduledCompute computeOp,
SmallVectorImpl<Operation*>& helperChain,
bool requireReturnUse = true) {
@@ -311,7 +301,7 @@ LogicalResult raptor::SpatialToPimPass::lowerComputeOp(spatial::SpatScheduledCom
if (!computeOp.getWeights().empty())
computeWeights.append(computeOp.getWeights().begin(), computeOp.getWeights().end());
rewriter.setInsertionPointAfter(computeOp);
auto checkedCoreId = getPimCoreIdForComputeOp(computeOp, coreId);
auto checkedCoreId = getRequiredScheduledCoreId(computeOp, "spatial compute core id");
if (failed(checkedCoreId))
return failure();
auto coreIdAttr = pim::getCheckedI32Attr(rewriter, computeOp, static_cast<int64_t>(*checkedCoreId), "pim core id");
@@ -44,121 +44,29 @@ using namespace pim;
namespace onnx_mlir {
static memref::GlobalOp getOrCreateZeroGlobal(IRRewriter& rewriter, Location loc, RankedTensorType tensorType) {
auto moduleOp = rewriter.getBlock()->getParentOp()->getParentOfType<ModuleOp>();
auto memRefType = MemRefType::get(tensorType.getShape(), tensorType.getElementType());
auto zeroAttr = DenseElementsAttr::get(tensorType, rewriter.getZeroAttr(tensorType.getElementType()));
for (auto globalOp : moduleOp.getOps<memref::GlobalOp>()) {
if (!globalOp.getConstant() || globalOp.getType() != memRefType || !globalOp.getInitialValue())
continue;
if (dyn_cast<DenseElementsAttr>(*globalOp.getInitialValue()) == zeroAttr)
return globalOp;
}
std::string nameStem;
llvm::raw_string_ostream nameStream(nameStem);
nameStream << "__pim_zero_" << tensorType.getRank() << "d_" << tensorType.getNumElements();
nameStream.flush();
std::string symbolName = nameStem;
unsigned suffix = 0;
while (SymbolTable::lookupSymbolIn(moduleOp, symbolName))
symbolName = (nameStem + "_" + Twine(suffix++)).str();
OpBuilder::InsertionGuard guard(rewriter);
rewriter.setInsertionPointToStart(moduleOp.getBody());
return memref::GlobalOp::create(rewriter,
loc,
rewriter.getStringAttr(symbolName),
rewriter.getStringAttr("private"),
TypeAttr::get(memRefType),
zeroAttr,
rewriter.getUnitAttr(),
IntegerAttr {});
}
static FailureOr<Value> createZeroedDeviceHVector(IRRewriter& rewriter,
Location loc,
RankedTensorType tensorType,
OperationFolder& constantFolder) {
auto outputBuffer = createEmptyTensorFromShaped(rewriter, loc, tensorType);
auto zeroGlobal = getOrCreateZeroGlobal(rewriter, loc, tensorType);
auto zeroValue = memref::GetGlobalOp::create(rewriter, loc, zeroGlobal.getType(), zeroGlobal.getName());
auto zeroIndex = getOrCreateIndexConstant(constantFolder, outputBuffer.getOperation(), 0);
auto byteSize =
pim::getCheckedShapedTypeSizeInBytes(tensorType, outputBuffer.getOperation(), "host-to-device zero copy byte size");
if (failed(byteSize))
return failure();
auto sizeAttr =
pim::getCheckedI32Attr(rewriter, outputBuffer.getOperation(), *byteSize, "host-to-device zero copy byte size");
if (failed(sizeAttr))
return failure();
return PimMemCopyHostToDevOp::create(
rewriter, loc, tensorType, zeroIndex, zeroIndex, outputBuffer, zeroValue, *sizeAttr)
.getOutput();
}
static bool isHostBackedMemRefValue(Value value) {
while (Operation* definingOp = value.getDefiningOp()) {
if (auto subviewOp = dyn_cast<memref::SubViewOp>(definingOp)) {
value = subviewOp.getSource();
continue;
}
if (auto castOp = dyn_cast<memref::CastOp>(definingOp)) {
value = castOp.getSource();
continue;
}
if (auto collapseOp = dyn_cast<memref::CollapseShapeOp>(definingOp)) {
value = collapseOp.getSrc();
continue;
}
if (auto expandOp = dyn_cast<memref::ExpandShapeOp>(definingOp)) {
value = expandOp.getSrc();
continue;
}
return isa<memref::GetGlobalOp>(definingOp);
}
return false;
}
static bool isHostBackedTensorValue(Value value) {
while (Operation* definingOp = value.getDefiningOp()) {
if (auto extractSliceOp = dyn_cast<tensor::ExtractSliceOp>(definingOp)) {
auto sourceType = dyn_cast<RankedTensorType>(extractSliceOp.getSource().getType());
auto resultType = dyn_cast<RankedTensorType>(extractSliceOp.getResult().getType());
if (!sourceType || !resultType || !sourceType.hasStaticShape() || !resultType.hasStaticShape())
return false;
if (!onnx_mlir::isContiguousSubviewWithDynamicOffsets(sourceType.getShape(),
extractSliceOp.getMixedOffsets(),
extractSliceOp.getStaticSizes(),
extractSliceOp.getStaticStrides())) {
return false;
}
value = extractSliceOp.getSource();
continue;
}
if (auto collapseOp = dyn_cast<tensor::CollapseShapeOp>(definingOp)) {
value = collapseOp.getSrc();
continue;
}
if (auto expandOp = dyn_cast<tensor::ExpandShapeOp>(definingOp)) {
value = expandOp.getSrc();
continue;
}
if (auto castOp = dyn_cast<tensor::CastOp>(definingOp)) {
value = castOp.getSource();
continue;
}
if (auto toTensorOp = dyn_cast<bufferization::ToTensorOp>(definingOp))
return isHostBackedMemRefValue(toTensorOp.getBuffer());
return false;
}
return false;
}
static FailureOr<Value>
padHVectorInputToCrossbarSize(IRRewriter& rewriter, Location loc, Value vector, OperationFolder& constantFolder) {
createZeroPaddedTensor(IRRewriter& rewriter, Location loc, Value value, RankedTensorType resultType) {
auto sourceType = cast<RankedTensorType>(value.getType());
SmallVector<OpFoldResult> lowPads(sourceType.getRank(), rewriter.getIndexAttr(0));
SmallVector<OpFoldResult> highPads;
highPads.reserve(sourceType.getRank());
for (auto [sourceDim, resultDim] : llvm::zip(sourceType.getShape(), resultType.getShape()))
highPads.push_back(rewriter.getIndexAttr(resultDim - sourceDim));
auto padOp = tensor::PadOp::create(rewriter, loc, resultType, value, lowPads, highPads);
auto* padBlock = new Block();
for (int64_t i = 0; i < sourceType.getRank(); ++i)
padBlock->addArgument(rewriter.getIndexType(), loc);
padOp.getRegion().push_back(padBlock);
rewriter.setInsertionPointToStart(padBlock);
auto zero = getOrCreateConstant(
rewriter, padOp.getOperation(), rewriter.getZeroAttr(sourceType.getElementType()), sourceType.getElementType());
tensor::YieldOp::create(rewriter, loc, zero);
rewriter.setInsertionPointAfter(padOp);
return padOp.getResult();
}
static FailureOr<Value> padHVectorInputToCrossbarSize(IRRewriter& rewriter, Location loc, Value vector) {
auto vectorType = cast<RankedTensorType>(vector.getType());
ArrayRef<int64_t> shape = vectorType.getShape();
assert(isHVectorShape(shape) && "expected a horizontal vector");
@@ -169,26 +77,10 @@ padHVectorInputToCrossbarSize(IRRewriter& rewriter, Location loc, Value vector,
auto paddedType = RankedTensorType::get(
{shape[0], static_cast<int64_t>(crossbarSize)}, vectorType.getElementType(), vectorType.getEncoding());
auto zeroed = createZeroedDeviceHVector(rewriter, loc, paddedType, constantFolder);
if (failed(zeroed))
return failure();
Value zeroIndex = getOrCreateIndexConstant(constantFolder, zeroed->getDefiningOp(), 0);
auto byteSize =
pim::getCheckedShapedTypeSizeInBytes(vectorType, zeroed->getDefiningOp(), "device padding copy byte size");
if (failed(byteSize))
return failure();
auto sizeAttr = pim::getCheckedI32Attr(rewriter, zeroed->getDefiningOp(), *byteSize, "device padding copy byte size");
if (failed(sizeAttr))
return failure();
if (isHostBackedTensorValue(vector)) {
return PimMemCopyHostToDevOp::create(rewriter, loc, paddedType, zeroIndex, zeroIndex, *zeroed, vector, *sizeAttr)
.getOutput();
}
return PimMemCopyOp::create(rewriter, loc, paddedType, zeroIndex, zeroIndex, *zeroed, vector, *sizeAttr).getOutput();
return createZeroPaddedTensor(rewriter, loc, vector, paddedType);
}
void onnx_mlir::raptor::SpatialToPimPass::runOnOperation() {
coreId = 0;
outputTensors.clear();
operationsToRemove.clear();
ModuleOp moduleOp = getOperation();
@@ -362,7 +254,6 @@ void onnx_mlir::raptor::SpatialToPimPass::runOnOperation() {
}
LogicalResult raptor::SpatialToPimPass::enlargeVMMOutTensorsToCrossbarSize(func::FuncOp funcOp, IRRewriter& rewriter) {
OperationFolder constantFolder(funcOp.getContext());
bool hasFailure = false;
funcOp.walk([&](PimVMMOp vmmOp) {
auto outputType = cast<RankedTensorType>(vmmOp.getOutput().getType());
@@ -371,7 +262,7 @@ LogicalResult raptor::SpatialToPimPass::enlargeVMMOutTensorsToCrossbarSize(func:
assert(outputShape[1] <= static_cast<int64_t>(crossbarSize) && "output width must fit in one crossbar");
rewriter.setInsertionPoint(vmmOp);
auto paddedInput = padHVectorInputToCrossbarSize(rewriter, vmmOp.getLoc(), vmmOp.getInput(), constantFolder);
auto paddedInput = padHVectorInputToCrossbarSize(rewriter, vmmOp.getLoc(), vmmOp.getInput());
if (failed(paddedInput)) {
hasFailure = true;
return WalkResult::interrupt();
@@ -36,7 +36,6 @@ private:
using OutputTensorFactory = std::function<mlir::Value(mlir::IRRewriter& rewriter, mlir::Location loc)>;
llvm::SmallVector<OutputTensorFactory> outputTensors;
size_t coreId = 0;
llvm::SmallVector<mlir::Operation*> operationsToRemove;
mlir::LogicalResult allocateAndInitializeCoreLocalVariables(mlir::func::FuncOp funcOp, mlir::IRRewriter& rewriter);