simplify affine maps to constants where possible
Validate Operations / validate-operations (push) Has been cancelled
Validate Operations / validate-operations (push) Has been cancelled
This commit is contained in:
@@ -2,6 +2,7 @@
|
|||||||
#include "mlir/Dialect/Arith/IR/Arith.h"
|
#include "mlir/Dialect/Arith/IR/Arith.h"
|
||||||
#include "mlir/Dialect/SCF/IR/SCF.h"
|
#include "mlir/Dialect/SCF/IR/SCF.h"
|
||||||
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
||||||
|
#include "mlir/IR/Matchers.h"
|
||||||
#include "mlir/IR/BuiltinTypes.h"
|
#include "mlir/IR/BuiltinTypes.h"
|
||||||
#include "mlir/IR/Location.h"
|
#include "mlir/IR/Location.h"
|
||||||
#include "mlir/Support/LogicalResult.h"
|
#include "mlir/Support/LogicalResult.h"
|
||||||
@@ -72,9 +73,37 @@ static Value createIndexConstant(ConversionPatternRewriter& rewriter, int64_t va
|
|||||||
return getOrCreateHostIndexConstant(anchorOp, value, rewriter);
|
return getOrCreateHostIndexConstant(anchorOp, value, rewriter);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Returns the integer held by `value` when it is a recognisable constant,
/// or std::nullopt otherwise.
///
/// Two lookups are tried in order:
///   1. `value` is defined by an arith::ConstantIndexOp;
///   2. `value` matches the m_ConstantInt pattern, in which case the matched
///      APInt is sign-extended to int64_t.
static std::optional<int64_t> getConstantIndexValue(Value value) {
  // Direct hit: the defining op is an index constant.
  auto indexConstantOp = value.getDefiningOp<arith::ConstantIndexOp>();
  if (indexConstantOp)
    return indexConstantOp.value();

  // Otherwise fall back to the generic integer-constant matcher.
  APInt matchedInt;
  const bool isIntegerConstant = matchPattern(value, m_ConstantInt(&matchedInt));
  if (!isIntegerConstant)
    return std::nullopt;

  return matchedInt.getSExtValue();
}
|
||||||
|
|
||||||
/// Materialises `expr` applied to `operands` (each operand is bound to one
/// affine dimension; no symbols are used).
///
/// If every operand resolves to a constant through getConstantIndexValue and
/// the resulting map constant-folds to an IntegerAttr, the folded value is
/// returned through createIndexConstant instead of emitting an op. In every
/// other case an affine::AffineApplyOp is created at `loc` and its result is
/// returned.
static Value
createAffineApply(ConversionPatternRewriter& rewriter, Location loc, AffineExpr expr, ValueRange operands) {
  AffineMap map = AffineMap::get(/*dimCount=*/operands.size(), /*symbolCount=*/0, expr);

  // Single fallback used by every path that cannot produce a constant.
  auto emitApply = [&]() -> Value { return affine::AffineApplyOp::create(rewriter, loc, map, operands).getResult(); };

  // Gather operand constants; bail out on the first non-constant operand.
  SmallVector<Attribute> knownOperands;
  knownOperands.reserve(operands.size());
  for (Value operand : operands) {
    std::optional<int64_t> known = getConstantIndexValue(operand);
    if (!known)
      return emitApply();
    knownOperands.push_back(rewriter.getIndexAttr(*known));
  }

  // All operands are constant: try to evaluate the map right away.
  SmallVector<Attribute> folded;
  if (failed(map.constantFold(knownOperands, folded)))
    return emitApply();

  if (auto foldedInt = dyn_cast<IntegerAttr>(folded.front()))
    return createIndexConstant(rewriter, foldedInt.getInt());

  return emitApply();
}
|
||||||
|
|
||||||
@@ -115,7 +144,15 @@ static Value createGemmBatchKOffset(
|
|||||||
}
|
}
|
||||||
|
|
||||||
static Value createGemmBatchHOffset(
|
static Value createGemmBatchHOffset(
|
||||||
Value lane, int64_t numOutRows, int64_t numKSlices, ConversionPatternRewriter& rewriter, Location loc) {
|
Value lane,
|
||||||
|
int64_t numOutRows,
|
||||||
|
int64_t numKSlices,
|
||||||
|
int64_t numOutHSlices,
|
||||||
|
ConversionPatternRewriter& rewriter,
|
||||||
|
Location loc) {
|
||||||
|
if (numOutHSlices == 1)
|
||||||
|
return createIndexConstant(rewriter, 0);
|
||||||
|
|
||||||
MLIRContext* context = rewriter.getContext();
|
MLIRContext* context = rewriter.getContext();
|
||||||
AffineExpr d0 = getAffineDimExpr(0, context);
|
AffineExpr d0 = getAffineDimExpr(0, context);
|
||||||
return createAffineApply(
|
return createAffineApply(
|
||||||
@@ -290,6 +327,7 @@ static spatial::SpatComputeBatch createVmmBatch(Value a,
|
|||||||
RankedTensorType partialPiecesType,
|
RankedTensorType partialPiecesType,
|
||||||
int64_t numOutRows,
|
int64_t numOutRows,
|
||||||
int64_t numKSlices,
|
int64_t numKSlices,
|
||||||
|
int64_t numOutHSlices,
|
||||||
ConversionPatternRewriter& rewriter,
|
ConversionPatternRewriter& rewriter,
|
||||||
Location loc) {
|
Location loc) {
|
||||||
const int64_t laneCount = partialPiecesType.getDimSize(0);
|
const int64_t laneCount = partialPiecesType.getDimSize(0);
|
||||||
@@ -314,7 +352,7 @@ static spatial::SpatComputeBatch createVmmBatch(Value a,
|
|||||||
|
|
||||||
Value row = createGemmBatchRow(*lane, numOutRows, rewriter, loc);
|
Value row = createGemmBatchRow(*lane, numOutRows, rewriter, loc);
|
||||||
Value kOffset = createGemmBatchKOffset(*lane, numOutRows, numKSlices, rewriter, loc);
|
Value kOffset = createGemmBatchKOffset(*lane, numOutRows, numKSlices, rewriter, loc);
|
||||||
Value hOffset = createGemmBatchHOffset(*lane, numOutRows, numKSlices, rewriter, loc);
|
Value hOffset = createGemmBatchHOffset(*lane, numOutRows, numKSlices, numOutHSlices, rewriter, loc);
|
||||||
|
|
||||||
auto aTileType = RankedTensorType::get({1, static_cast<int64_t>(crossbarSize.getValue())}, aType.getElementType());
|
auto aTileType = RankedTensorType::get({1, static_cast<int64_t>(crossbarSize.getValue())}, aType.getElementType());
|
||||||
auto bTileType = RankedTensorType::get(
|
auto bTileType = RankedTensorType::get(
|
||||||
@@ -610,7 +648,8 @@ LogicalResult GemmToSpatialComputes::matchAndRewrite(ONNXGemmOp gemmOp,
|
|||||||
|
|
||||||
auto partialPiecesType =
|
auto partialPiecesType =
|
||||||
RankedTensorType::get({laneCount64, static_cast<int64_t>(crossbarSize.getValue())}, outType.getElementType());
|
RankedTensorType::get({laneCount64, static_cast<int64_t>(crossbarSize.getValue())}, outType.getElementType());
|
||||||
auto batchOp = createVmmBatch(a, b, aType, paddedBType, partialPiecesType, numOutRows, numKSlices, rewriter, loc);
|
auto batchOp =
|
||||||
|
createVmmBatch(a, b, aType, paddedBType, partialPiecesType, numOutRows, numKSlices, numOutHSlices, rewriter, loc);
|
||||||
auto reductionCompute = createReductionCompute(
|
auto reductionCompute = createReductionCompute(
|
||||||
batchOp.getResult(0), bias, partialPiecesType, outType, paddedOutType, numKSlices, rewriter, loc);
|
batchOp.getResult(0), bias, partialPiecesType, outType, paddedOutType, numKSlices, rewriter, loc);
|
||||||
|
|
||||||
|
|||||||
@@ -1019,6 +1019,27 @@ std::optional<IndexedIndexPattern> getIndexedIndexPattern(ArrayRef<int64_t> valu
|
|||||||
return std::nullopt;
|
return std::nullopt;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Applies `map` to `operands`, preferring a constant over an op.
///
/// When getConstantIntValue yields a value for every operand and `map`
/// constant-folds to an IntegerAttr, the folded integer is handed to
/// createIndexConstant together with `state` and `anchor`, and that value is
/// returned. Otherwise an affine::AffineApplyOp is built with `state.rewriter`
/// at `loc` and its result is returned.
Value createAffineApplyOrConstant(
    MaterializerState& state, Location loc, AffineMap map, ValueRange operands, Operation* anchor) {
  // Shared fallback for every non-foldable case.
  auto buildApply = [&]() -> Value {
    return affine::AffineApplyOp::create(state.rewriter, loc, map, operands).getResult();
  };

  // Collect operand constants, giving up at the first operand without one.
  SmallVector<Attribute> knownOperands;
  knownOperands.reserve(operands.size());
  for (Value operand : operands) {
    auto maybeConstant = getConstantIntValue(operand);
    if (!maybeConstant)
      return buildApply();
    knownOperands.push_back(state.rewriter.getIndexAttr(*maybeConstant));
  }

  // Every operand is known; attempt to evaluate the map.
  SmallVector<Attribute> folded;
  if (failed(map.constantFold(knownOperands, folded)))
    return buildApply();

  if (auto foldedInt = dyn_cast<IntegerAttr>(folded.front()))
    return createIndexConstant(state, anchor, foldedInt.getInt());

  return buildApply();
}
|
||||||
|
|
||||||
Value createAffineIndexValue(MaterializerState& state, const IndexedIndexPattern& pattern, Value index, Location loc) {
|
Value createAffineIndexValue(MaterializerState& state, const IndexedIndexPattern& pattern, Value index, Location loc) {
|
||||||
MLIRContext* context = state.func.getContext();
|
MLIRContext* context = state.func.getContext();
|
||||||
AffineExpr d0 = getAffineDimExpr(0, context);
|
AffineExpr d0 = getAffineDimExpr(0, context);
|
||||||
@@ -1033,7 +1054,7 @@ Value createAffineIndexValue(MaterializerState& state, const IndexedIndexPattern
|
|||||||
}
|
}
|
||||||
|
|
||||||
AffineMap map = AffineMap::get(/*dimCount=*/1, /*symbolCount=*/0, expr);
|
AffineMap map = AffineMap::get(/*dimCount=*/1, /*symbolCount=*/0, expr);
|
||||||
return affine::AffineApplyOp::create(state.rewriter, loc, map, ValueRange {index}).getResult();
|
return createAffineApplyOrConstant(state, loc, map, ValueRange {index}, state.func);
|
||||||
}
|
}
|
||||||
|
|
||||||
Value createIndexedIndexValue(
|
Value createIndexedIndexValue(
|
||||||
@@ -3346,7 +3367,7 @@ Value createBatchRunFlatIndex(MaterializerState& state, MaterializedClass& targe
|
|||||||
|
|
||||||
int64_t laneCount = static_cast<int64_t>(targetClass.cpus.size());
|
int64_t laneCount = static_cast<int64_t>(targetClass.cpus.size());
|
||||||
AffineMap map = AffineMap::get(/*dimCount=*/2, /*symbolCount=*/0, d0 * laneCount + d1);
|
AffineMap map = AffineMap::get(/*dimCount=*/2, /*symbolCount=*/0, d0 * laneCount + d1);
|
||||||
return affine::AffineApplyOp::create(state.rewriter, loc, map, ValueRange {slotIndex, *laneArg}).getResult();
|
return createAffineApplyOrConstant(state, loc, map, ValueRange {slotIndex, *laneArg}, state.func);
|
||||||
}
|
}
|
||||||
|
|
||||||
Value createBatchClassRunSourceLane(MaterializerState& state,
|
Value createBatchClassRunSourceLane(MaterializerState& state,
|
||||||
|
|||||||
Reference in New Issue
Block a user