normalize affine arithmetic helpers
Validate Operations / validate-operations (push) Has been cancelled

This commit is contained in:
NiccoloN
2026-05-30 16:37:28 +02:00
parent 7c3943bd06
commit ab63498f3f
14 changed files with 340 additions and 278 deletions
+1
View File
@@ -1,4 +1,5 @@
add_pim_library(OMPimCommon add_pim_library(OMPimCommon
IR/AffineUtils.cpp
IR/AddressAnalysis.cpp IR/AddressAnalysis.cpp
IR/BatchCoreUtils.cpp IR/BatchCoreUtils.cpp
IR/ConstantUtils.cpp IR/ConstantUtils.cpp
+182
View File
@@ -0,0 +1,182 @@
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/IR/Matchers.h"
#include "AffineUtils.hpp"
#include "ConstantUtils.hpp"
using namespace mlir;
namespace onnx_mlir {
static FailureOr<int64_t> floorDivSigned(int64_t lhs, int64_t rhs) {
if (rhs <= 0)
return failure();
int64_t quotient = lhs / rhs;
int64_t remainder = lhs % rhs;
if (remainder != 0 && lhs < 0)
--quotient;
return quotient;
}
static FailureOr<int64_t> ceilDivSigned(int64_t lhs, int64_t rhs) {
if (rhs <= 0)
return failure();
int64_t quotient = lhs / rhs;
int64_t remainder = lhs % rhs;
if (remainder != 0 && lhs > 0)
++quotient;
return quotient;
}
Value createOrFoldAffineApply(
RewriterBase& rewriter, Location loc, AffineMap map, ValueRange operands, Operation* constantAnchor) {
assert(constantAnchor && "expected a valid constant anchor");
assert(map.getNumResults() == 1 && "affine.apply expects a single-result affine map");
SmallVector<Attribute> operandConstants;
operandConstants.reserve(operands.size());
for (Value operand : operands) {
std::optional<int64_t> constantValue = matchConstantIndexValue(operand);
if (!constantValue)
return affine::AffineApplyOp::create(rewriter, loc, map, operands).getResult();
operandConstants.push_back(rewriter.getIndexAttr(*constantValue));
}
SmallVector<Attribute> foldedResults;
if (succeeded(map.constantFold(operandConstants, foldedResults)) && foldedResults.size() == 1)
if (auto constantResult = dyn_cast<IntegerAttr>(foldedResults.front()))
return getOrCreateIndexConstant(rewriter, constantAnchor, constantResult.getInt());
return affine::AffineApplyOp::create(rewriter, loc, map, operands).getResult();
}
Value createOrFoldAffineApply(
RewriterBase& rewriter, Location loc, AffineExpr expr, ValueRange dims, Operation* constantAnchor) {
AffineMap map = AffineMap::get(/*dimCount=*/dims.size(), /*symbolCount=*/0, expr);
return createOrFoldAffineApply(rewriter, loc, map, dims, constantAnchor);
}
Value affineMulConst(RewriterBase& rewriter, Location loc, Value value, int64_t multiplier, Operation* constantAnchor) {
assert(constantAnchor && "expected a valid constant anchor");
if (multiplier == 0)
return getOrCreateIndexConstant(rewriter, constantAnchor, 0);
if (multiplier == 1)
return value;
AffineExpr d0 = getAffineDimExpr(0, rewriter.getContext());
return createOrFoldAffineApply(rewriter, loc, d0 * multiplier, ValueRange {value}, constantAnchor);
}
Value affineModConst(RewriterBase& rewriter, Location loc, Value value, int64_t divisor, Operation* constantAnchor) {
assert(constantAnchor && "expected a valid constant anchor");
assert(divisor > 0 && "expected a positive affine.mod divisor");
if (divisor == 1)
return getOrCreateIndexConstant(rewriter, constantAnchor, 0);
AffineExpr d0 = getAffineDimExpr(0, rewriter.getContext());
return createOrFoldAffineApply(rewriter, loc, d0 % divisor, ValueRange {value}, constantAnchor);
}
Value affineFloorDivConst(
RewriterBase& rewriter, Location loc, Value value, int64_t divisor, Operation* constantAnchor) {
assert(constantAnchor && "expected a valid constant anchor");
assert(divisor > 0 && "expected a positive affine.floor_div divisor");
if (divisor == 1)
return value;
AffineExpr d0 = getAffineDimExpr(0, rewriter.getContext());
return createOrFoldAffineApply(rewriter, loc, d0.floorDiv(divisor), ValueRange {value}, constantAnchor);
}
FailureOr<int64_t> evaluateAffineExpr(AffineExpr expr, ArrayRef<int64_t> dims, ArrayRef<int64_t> symbols) {
if (auto constant = dyn_cast<AffineConstantExpr>(expr))
return constant.getValue();
if (auto dim = dyn_cast<AffineDimExpr>(expr)) {
unsigned position = dim.getPosition();
if (position >= dims.size())
return failure();
return dims[position];
}
if (auto symbol = dyn_cast<AffineSymbolExpr>(expr)) {
unsigned position = symbol.getPosition();
if (position >= symbols.size())
return failure();
return symbols[position];
}
auto binary = dyn_cast<AffineBinaryOpExpr>(expr);
if (!binary)
return failure();
FailureOr<int64_t> lhs = evaluateAffineExpr(binary.getLHS(), dims, symbols);
FailureOr<int64_t> rhs = evaluateAffineExpr(binary.getRHS(), dims, symbols);
if (failed(lhs) || failed(rhs))
return failure();
switch (binary.getKind()) {
case AffineExprKind::Add: return *lhs + *rhs;
case AffineExprKind::Mul: return *lhs * *rhs;
case AffineExprKind::FloorDiv: return floorDivSigned(*lhs, *rhs);
case AffineExprKind::CeilDiv: return ceilDivSigned(*lhs, *rhs);
case AffineExprKind::Mod: {
FailureOr<int64_t> div = floorDivSigned(*lhs, *rhs);
if (failed(div))
return failure();
return *lhs - *div * *rhs;
}
default: return failure();
}
}
FailureOr<int64_t> evaluateSingleResultAffineMap(AffineMap map, ArrayRef<int64_t> operands) {
if (map.getNumResults() != 1 || operands.size() != map.getNumInputs())
return failure();
ArrayRef<int64_t> dims(operands.data(), map.getNumDims());
ArrayRef<int64_t> symbols(operands.data() + map.getNumDims(), map.getNumSymbols());
return evaluateAffineExpr(map.getResult(0), dims, symbols);
}
FailureOr<int64_t> evaluateAffineApply(affine::AffineApplyOp affineApply, IndexValueResolver resolver) {
SmallVector<int64_t, 4> operands;
operands.reserve(affineApply.getMapOperands().size());
for (Value operand : affineApply.getMapOperands()) {
FailureOr<int64_t> folded = resolver(operand);
if (failed(folded))
return failure();
operands.push_back(*folded);
}
return evaluateSingleResultAffineMap(affineApply.getAffineMap(), operands);
}
bool isSingleResultSymbolFreeAffineMap(AffineMap map) { return map.getNumResults() == 1 && map.getNumSymbols() == 0; }
bool isDimAndConstantAffineExpr(AffineExpr expr) {
switch (expr.getKind()) {
case AffineExprKind::Constant:
case AffineExprKind::DimId: return true;
case AffineExprKind::SymbolId: return false;
case AffineExprKind::Add: {
auto binaryExpr = cast<AffineBinaryOpExpr>(expr);
return isDimAndConstantAffineExpr(binaryExpr.getLHS()) && isDimAndConstantAffineExpr(binaryExpr.getRHS());
}
case AffineExprKind::Mul: {
auto binaryExpr = cast<AffineBinaryOpExpr>(expr);
return (isa<AffineConstantExpr>(binaryExpr.getLHS()) && isDimAndConstantAffineExpr(binaryExpr.getRHS()))
|| (isa<AffineConstantExpr>(binaryExpr.getRHS()) && isDimAndConstantAffineExpr(binaryExpr.getLHS()));
}
case AffineExprKind::FloorDiv:
case AffineExprKind::CeilDiv:
case AffineExprKind::Mod: {
auto binaryExpr = cast<AffineBinaryOpExpr>(expr);
return isa<AffineConstantExpr>(binaryExpr.getRHS()) && isDimAndConstantAffineExpr(binaryExpr.getLHS());
}
}
llvm_unreachable("unexpected affine expression kind");
}
} // namespace onnx_mlir
+55
View File
@@ -0,0 +1,55 @@
#pragma once
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Support/LogicalResult.h"
#include "llvm/ADT/FunctionExtras.h"
namespace onnx_mlir {
using IndexValueResolver = llvm::function_ref<llvm::FailureOr<int64_t>(mlir::Value)>;
mlir::Value createOrFoldAffineApply(mlir::RewriterBase& rewriter,
mlir::Location loc,
mlir::AffineMap map,
mlir::ValueRange operands,
mlir::Operation* constantAnchor);
mlir::Value createOrFoldAffineApply(mlir::RewriterBase& rewriter,
mlir::Location loc,
mlir::AffineExpr expr,
mlir::ValueRange dims,
mlir::Operation* constantAnchor);
mlir::Value affineMulConst(mlir::RewriterBase& rewriter,
mlir::Location loc,
mlir::Value value,
int64_t multiplier,
mlir::Operation* constantAnchor);
mlir::Value affineModConst(mlir::RewriterBase& rewriter,
mlir::Location loc,
mlir::Value value,
int64_t divisor,
mlir::Operation* constantAnchor);
mlir::Value affineFloorDivConst(mlir::RewriterBase& rewriter,
mlir::Location loc,
mlir::Value value,
int64_t divisor,
mlir::Operation* constantAnchor);
llvm::FailureOr<int64_t>
evaluateAffineExpr(mlir::AffineExpr expr, llvm::ArrayRef<int64_t> dims, llvm::ArrayRef<int64_t> symbols = {});
llvm::FailureOr<int64_t> evaluateSingleResultAffineMap(mlir::AffineMap map, llvm::ArrayRef<int64_t> operands);
llvm::FailureOr<int64_t> evaluateAffineApply(mlir::affine::AffineApplyOp affineApply, IndexValueResolver resolver);
bool isSingleResultSymbolFreeAffineMap(mlir::AffineMap map);
bool isDimAndConstantAffineExpr(mlir::AffineExpr expr);
} // namespace onnx_mlir
+19 -18
View File
@@ -1,10 +1,8 @@
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/IR/Builders.h" #include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinOps.h" #include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/Dialect.h" #include "mlir/IR/Dialect.h"
#include "mlir/IR/Matchers.h"
#include "ConstantUtils.hpp" #include "ConstantUtils.hpp"
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp" #include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
@@ -75,24 +73,27 @@ Value getOrCreateIndexConstant(RewriterBase& rewriter, Operation* anchorOp, int6
return getOrCreateConstant(rewriter, anchorOp, builder.getIndexAttr(value), builder.getIndexType()); return getOrCreateConstant(rewriter, anchorOp, builder.getIndexAttr(value), builder.getIndexType());
} }
Value createAffineApplyOrFoldedConstant( std::optional<int64_t> matchConstantIndexValue(Value value) {
RewriterBase& rewriter, Location loc, AffineMap map, ValueRange operands, Operation* anchorOp) { if (!value || !value.getType().isIndex())
SmallVector<Attribute> operandConstants; return std::nullopt;
operandConstants.reserve(operands.size());
for (Value operand : operands) { if (auto constant = value.getDefiningOp<arith::ConstantIndexOp>())
APInt constantValue; return constant.value();
if (!matchPattern(operand, m_ConstantInt(&constantValue)))
return affine::AffineApplyOp::create(rewriter, loc, map, operands).getResult(); if (auto constant = value.getDefiningOp<arith::ConstantOp>())
operandConstants.push_back(rewriter.getIndexAttr(constantValue.getSExtValue())); if (auto intAttr = dyn_cast<IntegerAttr>(constant.getValue()); intAttr && intAttr.getType().isIndex())
return intAttr.getInt();
return std::nullopt;
} }
SmallVector<Attribute> foldedResults; std::optional<int64_t> matchConstantIndexValue(OpFoldResult value) {
if (succeeded(map.constantFold(operandConstants, foldedResults))) { if (auto attr = dyn_cast<Attribute>(value))
if (auto constantResult = dyn_cast<IntegerAttr>(foldedResults.front())) if (auto intAttr = dyn_cast<IntegerAttr>(attr); intAttr && intAttr.getType().isIndex())
return getOrCreateIndexConstant(rewriter, anchorOp, constantResult.getInt()); return intAttr.getInt();
} if (auto operand = dyn_cast<Value>(value))
return matchConstantIndexValue(operand);
return affine::AffineApplyOp::create(rewriter, loc, map, operands).getResult(); return std::nullopt;
} }
} // namespace onnx_mlir } // namespace onnx_mlir
+5 -6
View File
@@ -1,11 +1,12 @@
#pragma once #pragma once
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/IR/PatternMatch.h" #include "mlir/IR/PatternMatch.h"
#include "mlir/IR/Value.h" #include "mlir/IR/Value.h"
#include "mlir/Transforms/FoldUtils.h" #include "mlir/Transforms/FoldUtils.h"
#include <optional>
namespace onnx_mlir { namespace onnx_mlir {
mlir::Block* getConstantInsertionBlock(mlir::Operation* anchorOp); mlir::Block* getConstantInsertionBlock(mlir::Operation* anchorOp);
@@ -22,10 +23,8 @@ mlir::Value getOrCreateIndexConstant(mlir::OperationFolder& folder, mlir::Operat
mlir::Value getOrCreateIndexConstant(mlir::RewriterBase& rewriter, mlir::Operation* anchorOp, int64_t value); mlir::Value getOrCreateIndexConstant(mlir::RewriterBase& rewriter, mlir::Operation* anchorOp, int64_t value);
mlir::Value createAffineApplyOrFoldedConstant(mlir::RewriterBase& rewriter, std::optional<int64_t> matchConstantIndexValue(mlir::Value value);
mlir::Location loc,
mlir::AffineMap map, std::optional<int64_t> matchConstantIndexValue(mlir::OpFoldResult value);
mlir::ValueRange operands,
mlir::Operation* anchorOp);
} // namespace onnx_mlir } // namespace onnx_mlir
@@ -1,11 +1,6 @@
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "llvm/ADT/APInt.h"
#include <algorithm> #include <algorithm>
#include "IndexingUtils.hpp" #include "IndexingUtils.hpp"
#include "src/Accelerators/PIM/Common/IR/ConstantUtils.hpp"
using namespace mlir; using namespace mlir;
@@ -47,46 +42,4 @@ FailureOr<SmallVector<int64_t>> normalizeAxesChecked(std::optional<ArrayAttr> ax
return normalizedAxes; return normalizedAxes;
} }
Value createAffineApplyOrFoldedConstant(PatternRewriter& rewriter, Location loc, AffineExpr expr, ValueRange operands) {
AffineMap map = AffineMap::get(/*dimCount=*/operands.size(), /*symbolCount=*/0, expr);
Operation* anchorOp = rewriter.getInsertionBlock()->getParentOp();
return createAffineApplyOrFoldedConstant(rewriter, loc, map, operands, anchorOp);
}
Value multiplyIndexByConstant(PatternRewriter& rewriter, Operation* anchorOp, Value value, int64_t multiplier) {
if (multiplier == 0)
return getOrCreateIndexConstant(rewriter, anchorOp, 0);
if (multiplier == 1)
return value;
MLIRContext* context = rewriter.getContext();
AffineExpr d0 = getAffineDimExpr(0, context);
return createAffineApplyOrFoldedConstant(rewriter, anchorOp->getLoc(), d0 * multiplier, ValueRange {value});
}
Value modIndexByConstant(PatternRewriter& rewriter, Location loc, Value value, int64_t divisor) {
if (divisor == 1)
return getOrCreateIndexConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), 0);
MLIRContext* context = rewriter.getContext();
AffineExpr d0 = getAffineDimExpr(0, context);
return createAffineApplyOrFoldedConstant(rewriter, loc, d0 % divisor, ValueRange {value});
}
Value floorDivIndexByConstant(PatternRewriter& rewriter, Location loc, Value value, int64_t divisor) {
if (divisor == 1)
return value;
MLIRContext* context = rewriter.getContext();
AffineExpr d0 = getAffineDimExpr(0, context);
return createAffineApplyOrFoldedConstant(rewriter, loc, d0.floorDiv(divisor), ValueRange {value});
}
Value getOrMaterializeIndexValue(PatternRewriter& rewriter, OpFoldResult value) {
if (auto attr = dyn_cast<Attribute>(value))
return getOrCreateIndexConstant(
rewriter, rewriter.getInsertionBlock()->getParentOp(), cast<IntegerAttr>(attr).getInt());
return cast<Value>(value);
}
} // namespace onnx_mlir } // namespace onnx_mlir
@@ -1,11 +1,7 @@
#pragma once #pragma once
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/BuiltinAttributes.h" #include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/Value.h"
#include "mlir/Interfaces/FoldInterfaces.h"
#include "mlir/Support/LogicalResult.h" #include "mlir/Support/LogicalResult.h"
#include "mlir/Transforms/DialectConversion.h"
#include "llvm/ADT/SmallVector.h" #include "llvm/ADT/SmallVector.h"
@@ -21,21 +17,4 @@ int64_t normalizeIndex(int64_t index, int64_t dimSize);
mlir::FailureOr<llvm::SmallVector<int64_t>> normalizeAxesChecked(std::optional<mlir::ArrayAttr> axesAttr, int64_t rank); mlir::FailureOr<llvm::SmallVector<int64_t>> normalizeAxesChecked(std::optional<mlir::ArrayAttr> axesAttr, int64_t rank);
mlir::Value createAffineApplyOrFoldedConstant(mlir::PatternRewriter& rewriter,
mlir::Location loc,
mlir::AffineExpr expr,
mlir::ValueRange operands);
mlir::Value multiplyIndexByConstant(mlir::PatternRewriter& rewriter,
mlir::Operation* anchorOp,
mlir::Value value,
int64_t multiplier);
mlir::Value modIndexByConstant(mlir::PatternRewriter& rewriter, mlir::Location loc, mlir::Value value, int64_t divisor);
mlir::Value
floorDivIndexByConstant(mlir::PatternRewriter& rewriter, mlir::Location loc, mlir::Value value, int64_t divisor);
mlir::Value getOrMaterializeIndexValue(mlir::PatternRewriter& rewriter, mlir::OpFoldResult value);
} // namespace onnx_mlir } // namespace onnx_mlir
@@ -11,6 +11,7 @@
#include <utility> #include <utility>
#include "src/Accelerators/PIM/Common/IR/ConstantUtils.hpp"
#include "src/Accelerators/PIM/Common/IR/ShapeUtils.hpp" #include "src/Accelerators/PIM/Common/IR/ShapeUtils.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/CompileTime.hpp" #include "src/Accelerators/PIM/Conversion/ONNXToSpatial/CompileTime.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp" #include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
@@ -27,8 +28,7 @@ static bool hasStaticUnitStrides(tensor::ExtractSliceOp extractSliceOp) {
} }
static bool hasConstantIndices(tensor::ExtractOp extractOp) { static bool hasConstantIndices(tensor::ExtractOp extractOp) {
return llvm::all_of(extractOp.getIndices(), return llvm::all_of(extractOp.getIndices(), [](Value index) { return matchConstantIndexValue(index).has_value(); });
[](Value index) { return isa_and_nonnull<arith::ConstantIndexOp>(index.getDefiningOp()); });
} }
static bool isStaticTensorResult(Operation* op) { static bool isStaticTensorResult(Operation* op) {
@@ -14,6 +14,7 @@
#include <utility> #include <utility>
#include "Common/IR/ConstantUtils.hpp" #include "Common/IR/ConstantUtils.hpp"
#include "src/Accelerators/PIM/Common/IR/AffineUtils.hpp"
#include "src/Accelerators/PIM/Common/IR/ShapeUtils.hpp" #include "src/Accelerators/PIM/Common/IR/ShapeUtils.hpp"
#include "src/Accelerators/PIM/Common/PimCommon.hpp" #include "src/Accelerators/PIM/Common/PimCommon.hpp"
#include "src/Accelerators/PIM/Common/Support/Diagnostics.hpp" #include "src/Accelerators/PIM/Common/Support/Diagnostics.hpp"
@@ -60,8 +61,11 @@ static Value createGemmBatchKOffset(
MLIRContext* context = rewriter.getContext(); MLIRContext* context = rewriter.getContext();
AffineExpr d0 = getAffineDimExpr(0, context); AffineExpr d0 = getAffineDimExpr(0, context);
return createAffineApplyOrFoldedConstant( return createOrFoldAffineApply(rewriter,
rewriter, loc, (d0.floorDiv(numOutRows) % numKSlices) * crossbarSize.getValue(), ValueRange {lane}); loc,
(d0.floorDiv(numOutRows) % numKSlices) * crossbarSize.getValue(),
ValueRange {lane},
rewriter.getInsertionBlock()->getParentOp());
} }
static Value createGemmBatchHOffset(Value lane, static Value createGemmBatchHOffset(Value lane,
@@ -75,8 +79,11 @@ static Value createGemmBatchHOffset(Value lane,
MLIRContext* context = rewriter.getContext(); MLIRContext* context = rewriter.getContext();
AffineExpr d0 = getAffineDimExpr(0, context); AffineExpr d0 = getAffineDimExpr(0, context);
return createAffineApplyOrFoldedConstant( return createOrFoldAffineApply(rewriter,
rewriter, loc, d0.floorDiv(numOutRows * numKSlices) * crossbarSize.getValue(), ValueRange {lane}); loc,
d0.floorDiv(numOutRows * numKSlices) * crossbarSize.getValue(),
ValueRange {lane},
rewriter.getInsertionBlock()->getParentOp());
} }
static Value static Value
@@ -259,7 +266,8 @@ static spatial::SpatComputeBatch createVmmBatch(Value a,
ValueRange {b}, ValueRange {b},
ValueRange {a}, ValueRange {a},
[&](detail::SpatComputeBatchBodyArgs args) { [&](detail::SpatComputeBatchBodyArgs args) {
Value row = onnx_mlir::modIndexByConstant(rewriter, loc, args.lane, numOutRows); Value row =
onnx_mlir::affineModConst(rewriter, loc, args.lane, numOutRows, rewriter.getInsertionBlock()->getParentOp());
Value kOffset = createGemmBatchKOffset(args.lane, numOutRows, numKSlices, rewriter, loc); Value kOffset = createGemmBatchKOffset(args.lane, numOutRows, numKSlices, rewriter, loc);
Value hOffset = createGemmBatchHOffset(args.lane, numOutRows, numKSlices, numOutHSlices, rewriter, loc); Value hOffset = createGemmBatchHOffset(args.lane, numOutRows, numKSlices, numOutHSlices, rewriter, loc);
@@ -297,7 +305,8 @@ createDynamicGemmBatchRow(Value lane, int64_t numOutCols, ConversionPatternRewri
MLIRContext* context = rewriter.getContext(); MLIRContext* context = rewriter.getContext();
AffineExpr d0 = getAffineDimExpr(0, context); AffineExpr d0 = getAffineDimExpr(0, context);
return createAffineApplyOrFoldedConstant(rewriter, loc, d0.floorDiv(numOutCols), ValueRange {lane}); return createOrFoldAffineApply(
rewriter, loc, d0.floorDiv(numOutCols), ValueRange {lane}, rewriter.getInsertionBlock()->getParentOp());
} }
static Value extractDynamicGemmBColumn( static Value extractDynamicGemmBColumn(
@@ -429,7 +438,8 @@ static spatial::SpatComputeBatch createVvdmulBatch(Value a,
ValueRange {a, b}, ValueRange {a, b},
[&](detail::SpatComputeBatchBodyArgs args) { [&](detail::SpatComputeBatchBodyArgs args) {
Value row = createDynamicGemmBatchRow(args.lane, numOutCols, rewriter, loc); Value row = createDynamicGemmBatchRow(args.lane, numOutCols, rewriter, loc);
Value column = onnx_mlir::modIndexByConstant(rewriter, loc, args.lane, numOutCols); Value column =
onnx_mlir::affineModConst(rewriter, loc, args.lane, numOutCols, rewriter.getInsertionBlock()->getParentOp());
auto vectorType = RankedTensorType::get({1, reductionSize}, aType.getElementType()); auto vectorType = RankedTensorType::get({1, reductionSize}, aType.getElementType());
auto scalarType = RankedTensorType::get({1, 1}, outType.getElementType()); auto scalarType = RankedTensorType::get({1, 1}, outType.getElementType());
@@ -477,7 +487,8 @@ static spatial::SpatCompute createDynamicGemmOutputCompute(Value scalarPieces,
Value lane = loop.getInductionVar(); Value lane = loop.getInductionVar();
Value outputAcc = loop.getRegionIterArgs().front(); Value outputAcc = loop.getRegionIterArgs().front();
Value row = createDynamicGemmBatchRow(lane, numOutCols, rewriter, loc); Value row = createDynamicGemmBatchRow(lane, numOutCols, rewriter, loc);
Value column = onnx_mlir::modIndexByConstant(rewriter, loc, lane, numOutCols); Value column =
onnx_mlir::affineModConst(rewriter, loc, lane, numOutCols, rewriter.getInsertionBlock()->getParentOp());
SmallVector<OpFoldResult> scalarOffsets {lane, rewriter.getIndexAttr(0)}; SmallVector<OpFoldResult> scalarOffsets {lane, rewriter.getIndexAttr(0)};
SmallVector<OpFoldResult> scalarSizes {rewriter.getIndexAttr(1), rewriter.getIndexAttr(1)}; SmallVector<OpFoldResult> scalarSizes {rewriter.getIndexAttr(1), rewriter.getIndexAttr(1)};
SmallVector<OpFoldResult> unitStrides {rewriter.getIndexAttr(1), rewriter.getIndexAttr(1)}; SmallVector<OpFoldResult> unitStrides {rewriter.getIndexAttr(1), rewriter.getIndexAttr(1)};
@@ -515,8 +526,11 @@ static Value createPartialGroupOffset(Value hSlice,
Location loc) { Location loc) {
MLIRContext* context = rewriter.getContext(); MLIRContext* context = rewriter.getContext();
AffineExpr d0 = getAffineDimExpr(0, context); AffineExpr d0 = getAffineDimExpr(0, context);
return createAffineApplyOrFoldedConstant( return createOrFoldAffineApply(rewriter,
rewriter, loc, d0 * (numKSlices * numOutRows) + kSlice * numOutRows, ValueRange {hSlice}); loc,
d0 * (numKSlices * numOutRows) + kSlice * numOutRows,
ValueRange {hSlice},
rewriter.getInsertionBlock()->getParentOp());
} }
static Value extractReductionPiece(Value partialPiecesArg, static Value extractReductionPiece(Value partialPiecesArg,
@@ -597,8 +611,8 @@ static spatial::SpatCompute createReductionCompute(Value partialPieces,
auto buildOutputSlice = [&](Value outputAcc, Value hSlice) -> Value { auto buildOutputSlice = [&](Value outputAcc, Value hSlice) -> Value {
Value reduced = Value reduced =
reducePartialPiecesForHSlice(partialPiecesArg, hSlice, pieceType, numKSlices, numOutRows, rewriter, loc); reducePartialPiecesForHSlice(partialPiecesArg, hSlice, pieceType, numKSlices, numOutRows, rewriter, loc);
Value hOffset = onnx_mlir::multiplyIndexByConstant( Value hOffset = onnx_mlir::affineMulConst(
rewriter, rewriter.getInsertionBlock()->getParentOp(), hSlice, crossbarSize.getValue()); rewriter, loc, hSlice, crossbarSize.getValue(), rewriter.getInsertionBlock()->getParentOp());
if (biasArg) { if (biasArg) {
SmallVector<OpFoldResult> biasOffsets {rewriter.getIndexAttr(0), hOffset}; SmallVector<OpFoldResult> biasOffsets {rewriter.getIndexAttr(0), hOffset};
Value biasSlice = Value biasSlice =
@@ -7,6 +7,7 @@
#include "llvm/ADT/STLExtras.h" #include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h" #include "llvm/ADT/SmallVector.h"
#include "src/Accelerators/PIM/Common/IR/AffineUtils.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/Common.hpp" #include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/Common.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/CompileTime.hpp" #include "src/Accelerators/PIM/Conversion/ONNXToSpatial/CompileTime.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Patterns.hpp" #include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Patterns.hpp"
@@ -276,7 +277,8 @@ static Value extractBatchedBTile(Value b,
static Value getBatchLaneIndex( static Value getBatchLaneIndex(
Value lane, int64_t numOutRows, int64_t numKSlices, int64_t numOutHSlices, PatternRewriter& rewriter, Location loc) { Value lane, int64_t numOutRows, int64_t numKSlices, int64_t numOutHSlices, PatternRewriter& rewriter, Location loc) {
return floorDivIndexByConstant(rewriter, loc, lane, numOutRows * numKSlices * numOutHSlices); return affineFloorDivConst(
rewriter, loc, lane, numOutRows * numKSlices * numOutHSlices, rewriter.getInsertionBlock()->getParentOp());
} }
static spatial::SpatComputeBatch createBatchedVmmBatch(Value a, static spatial::SpatComputeBatch createBatchedVmmBatch(Value a,
@@ -300,16 +302,15 @@ static spatial::SpatComputeBatch createBatchedVmmBatch(Value a,
ValueRange {b}, ValueRange {b},
ValueRange {a}, ValueRange {a},
[&](detail::SpatComputeBatchBodyArgs args) { [&](detail::SpatComputeBatchBodyArgs args) {
Value row = modIndexByConstant(rewriter, loc, args.lane, numOutRows); Operation* anchorOp = rewriter.getInsertionBlock()->getParentOp();
Value outerLane = floorDivIndexByConstant(rewriter, loc, args.lane, numOutRows); Value row = affineModConst(rewriter, loc, args.lane, numOutRows, anchorOp);
Value outerLane = affineFloorDivConst(rewriter, loc, args.lane, numOutRows, anchorOp);
Value batch = getBatchLaneIndex(args.lane, numOutRows, numKSlices, numOutHSlices, rewriter, loc); Value batch = getBatchLaneIndex(args.lane, numOutRows, numKSlices, numOutHSlices, rewriter, loc);
Value sliceLane = modIndexByConstant(rewriter, loc, outerLane, numKSlices * numOutHSlices); Value sliceLane = affineModConst(rewriter, loc, outerLane, numKSlices * numOutHSlices, anchorOp);
Value kSlice = modIndexByConstant(rewriter, loc, sliceLane, numKSlices); Value kSlice = affineModConst(rewriter, loc, sliceLane, numKSlices, anchorOp);
Value hSlice = floorDivIndexByConstant(rewriter, loc, sliceLane, numKSlices); Value hSlice = affineFloorDivConst(rewriter, loc, sliceLane, numKSlices, anchorOp);
Value kOffset = Value kOffset = affineMulConst(rewriter, loc, kSlice, crossbarSize.getValue(), anchorOp);
multiplyIndexByConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), kSlice, crossbarSize.getValue()); Value hOffset = affineMulConst(rewriter, loc, hSlice, crossbarSize.getValue(), anchorOp);
Value hOffset =
multiplyIndexByConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), hSlice, crossbarSize.getValue());
auto aTileType = auto aTileType =
RankedTensorType::get({1, static_cast<int64_t>(crossbarSize.getValue())}, aType.getElementType()); RankedTensorType::get({1, static_cast<int64_t>(crossbarSize.getValue())}, aType.getElementType());
@@ -445,10 +446,11 @@ static spatial::SpatComputeBatch createBatchedVvdmulBatch(Value a,
ValueRange {}, ValueRange {},
ValueRange {a, b}, ValueRange {a, b},
[&](detail::SpatComputeBatchBodyArgs args) { [&](detail::SpatComputeBatchBodyArgs args) {
Value batch = floorDivIndexByConstant(rewriter, loc, args.lane, numOutRows * numOutCols); Operation* anchorOp = rewriter.getInsertionBlock()->getParentOp();
Value batchLane = modIndexByConstant(rewriter, loc, args.lane, numOutRows * numOutCols); Value batch = affineFloorDivConst(rewriter, loc, args.lane, numOutRows * numOutCols, anchorOp);
Value row = floorDivIndexByConstant(rewriter, loc, batchLane, numOutCols); Value batchLane = affineModConst(rewriter, loc, args.lane, numOutRows * numOutCols, anchorOp);
Value column = modIndexByConstant(rewriter, loc, batchLane, numOutCols); Value row = affineFloorDivConst(rewriter, loc, batchLane, numOutCols, anchorOp);
Value column = affineModConst(rewriter, loc, batchLane, numOutCols, anchorOp);
auto vectorType = RankedTensorType::get({1, reductionSize}, aType.getElementType()); auto vectorType = RankedTensorType::get({1, reductionSize}, aType.getElementType());
auto scalarType = RankedTensorType::get({1, 1}, outType.getElementType()); auto scalarType = RankedTensorType::get({1, 1}, outType.getElementType());
@@ -491,10 +493,11 @@ static Value createBatchedDynamicOutputCompute(Value scalarPieces,
Value lane = loop.getInductionVar(); Value lane = loop.getInductionVar();
Value outputAcc = loop.getRegionIterArgs().front(); Value outputAcc = loop.getRegionIterArgs().front();
Value batch = floorDivIndexByConstant(rewriter, loc, lane, numOutRows * numOutCols); Operation* anchorOp = rewriter.getInsertionBlock()->getParentOp();
Value batchLane = modIndexByConstant(rewriter, loc, lane, numOutRows * numOutCols); Value batch = affineFloorDivConst(rewriter, loc, lane, numOutRows * numOutCols, anchorOp);
Value row = floorDivIndexByConstant(rewriter, loc, batchLane, numOutCols); Value batchLane = affineModConst(rewriter, loc, lane, numOutRows * numOutCols, anchorOp);
Value column = modIndexByConstant(rewriter, loc, batchLane, numOutCols); Value row = affineFloorDivConst(rewriter, loc, batchLane, numOutCols, anchorOp);
Value column = affineModConst(rewriter, loc, batchLane, numOutCols, anchorOp);
SmallVector<OpFoldResult> scalarOffsets {lane, rewriter.getIndexAttr(0)}; SmallVector<OpFoldResult> scalarOffsets {lane, rewriter.getIndexAttr(0)};
SmallVector<OpFoldResult> scalarSizes {rewriter.getIndexAttr(1), rewriter.getIndexAttr(1)}; SmallVector<OpFoldResult> scalarSizes {rewriter.getIndexAttr(1), rewriter.getIndexAttr(1)};
Value scalar = tensor::ExtractSliceOp::create( Value scalar = tensor::ExtractSliceOp::create(
@@ -542,10 +545,9 @@ static Value extractBatchedReductionPiece(Value partialPiecesArg,
int64_t numOutRows, int64_t numOutRows,
PatternRewriter& rewriter, PatternRewriter& rewriter,
Location loc) { Location loc) {
Value batchOffset = multiplyIndexByConstant( Operation* anchorOp = rewriter.getInsertionBlock()->getParentOp();
rewriter, rewriter.getInsertionBlock()->getParentOp(), batch, numOutRows * numKSlices * numOutHSlices); Value batchOffset = affineMulConst(rewriter, loc, batch, numOutRows * numKSlices * numOutHSlices, anchorOp);
Value hOffset = Value hOffset = affineMulConst(rewriter, loc, hSlice, numKSlices * numOutRows, anchorOp);
multiplyIndexByConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), hSlice, numKSlices * numOutRows);
Value kOffset = getOrCreateIndexConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), kSlice * numOutRows); Value kOffset = getOrCreateIndexConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), kSlice * numOutRows);
Value batchAndHSlice = arith::AddIOp::create(rewriter, loc, batchOffset, hOffset); Value batchAndHSlice = arith::AddIOp::create(rewriter, loc, batchOffset, hOffset);
Value pieceOffset = arith::AddIOp::create(rewriter, loc, batchAndHSlice, kOffset); Value pieceOffset = arith::AddIOp::create(rewriter, loc, batchAndHSlice, kOffset);
@@ -631,7 +633,7 @@ static Value createBatchedReductionCompute(Value partialPieces,
{2} {2}
}); });
Value hOffset = Value hOffset =
multiplyIndexByConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), hSlice, crossbarSize.getValue()); affineMulConst(rewriter, loc, hSlice, crossbarSize.getValue(), rewriter.getInsertionBlock()->getParentOp());
SmallVector<OpFoldResult> outputOffsets {batch, rewriter.getIndexAttr(0), hOffset}; SmallVector<OpFoldResult> outputOffsets {batch, rewriter.getIndexAttr(0), hOffset};
SmallVector<OpFoldResult> outputSizes { SmallVector<OpFoldResult> outputSizes {
rewriter.getIndexAttr(1), rewriter.getIndexAttr(numOutRows), rewriter.getIndexAttr(crossbarSize.getValue())}; rewriter.getIndexAttr(1), rewriter.getIndexAttr(numOutRows), rewriter.getIndexAttr(crossbarSize.getValue())};
@@ -7,6 +7,7 @@
#include <algorithm> #include <algorithm>
#include <numeric> #include <numeric>
#include "src/Accelerators/PIM/Common/IR/AffineUtils.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/Common.hpp" #include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/Common.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/CompileTime.hpp" #include "src/Accelerators/PIM/Conversion/ONNXToSpatial/CompileTime.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Patterns.hpp" #include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Patterns.hpp"
@@ -82,7 +83,7 @@ computeLaneIndex(Value lane, int64_t stride, int64_t dimSize, ConversionPatternR
expr = expr.floorDiv(stride); expr = expr.floorDiv(stride);
if (dimSize != 1) if (dimSize != 1)
expr = expr % dimSize; expr = expr % dimSize;
return createAffineApplyOrFoldedConstant(rewriter, loc, expr, ValueRange {lane}); return createOrFoldAffineApply(rewriter, loc, expr, ValueRange {lane}, rewriter.getInsertionBlock()->getParentOp());
} }
static FailureOr<Value> buildReduceMeanKeepdimsBatch(Value input, static FailureOr<Value> buildReduceMeanKeepdimsBatch(Value input,
+7 -34
View File
@@ -13,6 +13,8 @@
#include "llvm/ADT/DenseSet.h" #include "llvm/ADT/DenseSet.h"
#include "llvm/Support/LogicalResult.h" #include "llvm/Support/LogicalResult.h"
#include "src/Accelerators/PIM/Common/IR/AffineUtils.hpp"
#include "src/Accelerators/PIM/Common/IR/ConstantUtils.hpp"
#include "src/Accelerators/PIM/Common/PimCommon.hpp" #include "src/Accelerators/PIM/Common/PimCommon.hpp"
#include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp" #include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/CompileTime.hpp" #include "src/Accelerators/PIM/Conversion/ONNXToSpatial/CompileTime.hpp"
@@ -56,48 +58,19 @@ static LogicalResult verifyStaticWeights(ComputeOpTy computeOp, StringRef kind)
return success(); return success();
} }
static bool isConstantIndexLike(Value value) {
APInt constantValue;
return matchPattern(value, m_ConstantInt(&constantValue));
}
static bool isSupportedLaneAffineExpr(AffineExpr expr) {
switch (expr.getKind()) {
case AffineExprKind::Constant:
case AffineExprKind::DimId: return true;
case AffineExprKind::SymbolId: return false;
case AffineExprKind::Add: {
auto binaryExpr = cast<AffineBinaryOpExpr>(expr);
return isSupportedLaneAffineExpr(binaryExpr.getLHS()) && isSupportedLaneAffineExpr(binaryExpr.getRHS());
}
case AffineExprKind::Mul: {
auto binaryExpr = cast<AffineBinaryOpExpr>(expr);
return (isa<AffineConstantExpr>(binaryExpr.getLHS()) && isSupportedLaneAffineExpr(binaryExpr.getRHS()))
|| (isa<AffineConstantExpr>(binaryExpr.getRHS()) && isSupportedLaneAffineExpr(binaryExpr.getLHS()));
}
case AffineExprKind::FloorDiv:
case AffineExprKind::CeilDiv:
case AffineExprKind::Mod: {
auto binaryExpr = cast<AffineBinaryOpExpr>(expr);
return isa<AffineConstantExpr>(binaryExpr.getRHS()) && isSupportedLaneAffineExpr(binaryExpr.getLHS());
}
}
llvm_unreachable("unexpected affine expression kind");
}
static bool isSupportedLaneOffsetExpr(Value value, BlockArgument laneArg) { static bool isSupportedLaneOffsetExpr(Value value, BlockArgument laneArg) {
if (value == laneArg || isConstantIndexLike(value)) if (value == laneArg || matchConstantIndexValue(value))
return true; return true;
auto affineApply = value.getDefiningOp<affine::AffineApplyOp>(); auto affineApply = value.getDefiningOp<affine::AffineApplyOp>();
if (affineApply) { if (affineApply) {
if (affineApply.getAffineMap().getNumResults() != 1 || affineApply.getAffineMap().getNumSymbols() != 0) if (!isSingleResultSymbolFreeAffineMap(affineApply.getAffineMap()))
return false; return false;
if (!llvm::all_of(affineApply.getMapOperands(), if (!llvm::all_of(affineApply.getMapOperands(),
[&](Value operand) { return isSupportedLaneOffsetExpr(operand, laneArg); })) { [&](Value operand) { return isSupportedLaneOffsetExpr(operand, laneArg); })) {
return false; return false;
} }
return isSupportedLaneAffineExpr(affineApply.getAffineMap().getResult(0)); return isDimAndConstantAffineExpr(affineApply.getAffineMap().getResult(0));
} }
auto extractOp = value.getDefiningOp<tensor::ExtractOp>(); auto extractOp = value.getDefiningOp<tensor::ExtractOp>();
@@ -112,8 +85,8 @@ static bool isSupportedLaneOffsetExpr(Value value, BlockArgument laneArg) {
auto addOp = value.getDefiningOp<arith::AddIOp>(); auto addOp = value.getDefiningOp<arith::AddIOp>();
if (!addOp) if (!addOp)
return false; return false;
return (addOp.getLhs() == laneArg && isConstantIndexLike(addOp.getRhs())) return (addOp.getLhs() == laneArg && matchConstantIndexValue(addOp.getRhs()))
|| (addOp.getRhs() == laneArg && isConstantIndexLike(addOp.getLhs())); || (addOp.getRhs() == laneArg && matchConstantIndexValue(addOp.getLhs()));
} }
static LogicalResult static LogicalResult
@@ -21,6 +21,7 @@
#include "MaterializeMergeSchedule.hpp" #include "MaterializeMergeSchedule.hpp"
#include "Scheduling/ComputeInstanceUtils.hpp" #include "Scheduling/ComputeInstanceUtils.hpp"
#include "src/Accelerators/PIM/Common/IR/AffineUtils.hpp"
#include "src/Accelerators/PIM/Common/IR/ConstantUtils.hpp" #include "src/Accelerators/PIM/Common/IR/ConstantUtils.hpp"
#include "src/Accelerators/PIM/Common/PimCommon.hpp" #include "src/Accelerators/PIM/Common/PimCommon.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp" #include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
@@ -1053,7 +1054,7 @@ Value createAffineIndexValue(MaterializerState& state, const IndexedIndexPattern
} }
AffineMap map = AffineMap::get(/*dimCount=*/1, /*symbolCount=*/0, expr); AffineMap map = AffineMap::get(/*dimCount=*/1, /*symbolCount=*/0, expr);
return createAffineApplyOrFoldedConstant(state.rewriter, loc, map, ValueRange {index}, state.func); return createOrFoldAffineApply(state.rewriter, loc, map, ValueRange {index}, state.func);
} }
Value createIndexedIndexValue( Value createIndexedIndexValue(
@@ -1295,47 +1296,16 @@ std::optional<unsigned> getLaneProjectedDim(ArrayRef<OpFoldResult> offsets, Valu
} }
static FailureOr<int64_t> evaluateProjectedOffsetForLane(OpFoldResult value, Value laneArg, uint32_t lane) { static FailureOr<int64_t> evaluateProjectedOffsetForLane(OpFoldResult value, Value laneArg, uint32_t lane) {
if (auto attr = dyn_cast<Attribute>(value)) { if (std::optional<int64_t> constant = matchConstantIndexValue(value))
auto intAttr = dyn_cast<IntegerAttr>(attr); return *constant;
if (!intAttr)
return failure();
return intAttr.getInt();
}
Value current = cast<Value>(value); Value current = cast<Value>(value);
if (current == laneArg) if (current == laneArg)
return static_cast<int64_t>(lane); return static_cast<int64_t>(lane);
if (auto constant = current.getDefiningOp<arith::ConstantIndexOp>()) if (auto affineApply = current.getDefiningOp<affine::AffineApplyOp>())
return constant.value(); return evaluateAffineApply(affineApply,
[&](Value operand) { return evaluateProjectedOffsetForLane(operand, laneArg, lane); });
if (auto constant = current.getDefiningOp<arith::ConstantOp>())
if (auto intAttr = dyn_cast<IntegerAttr>(constant.getValue()))
return intAttr.getInt();
if (auto affineApply = current.getDefiningOp<affine::AffineApplyOp>()) {
AffineMap map = affineApply.getAffineMap();
if (map.getNumResults() != 1)
return failure();
SmallVector<Attribute, 4> operandConstants;
operandConstants.reserve(affineApply.getMapOperands().size());
for (Value operand : affineApply.getMapOperands()) {
FailureOr<int64_t> folded = evaluateProjectedOffsetForLane(operand, laneArg, lane);
if (failed(folded))
return failure();
operandConstants.push_back(IntegerAttr::get(IndexType::get(current.getContext()), *folded));
}
SmallVector<Attribute, 1> foldedResults;
if (failed(map.constantFold(operandConstants, foldedResults)) || foldedResults.size() != 1)
return failure();
auto constantResult = dyn_cast<IntegerAttr>(foldedResults.front());
if (!constantResult)
return failure();
return constantResult.getInt();
}
return failure(); return failure();
} }
@@ -3503,7 +3473,7 @@ Value createBatchRunFlatIndex(MaterializerState& state, MaterializedClass& targe
int64_t laneCount = static_cast<int64_t>(targetClass.cpus.size()); int64_t laneCount = static_cast<int64_t>(targetClass.cpus.size());
AffineMap map = AffineMap::get(/*dimCount=*/2, /*symbolCount=*/0, d0 * laneCount + d1); AffineMap map = AffineMap::get(/*dimCount=*/2, /*symbolCount=*/0, d0 * laneCount + d1);
return createAffineApplyOrFoldedConstant(state.rewriter, loc, map, ValueRange {slotIndex, *laneArg}, state.func); return createOrFoldAffineApply(state.rewriter, loc, map, ValueRange {slotIndex, *laneArg}, state.func);
} }
Value createBatchClassRunSourceLane(MaterializerState& state, Value createBatchClassRunSourceLane(MaterializerState& state,
@@ -21,6 +21,8 @@
#include "ComputeGraph.hpp" #include "ComputeGraph.hpp"
#include "ComputeInstanceUtils.hpp" #include "ComputeInstanceUtils.hpp"
#include "src/Accelerators/PIM/Common/IR/AffineUtils.hpp"
#include "src/Accelerators/PIM/Common/IR/ConstantUtils.hpp"
#include "src/Support/TypeUtilities.hpp" #include "src/Support/TypeUtilities.hpp"
namespace onnx_mlir { namespace onnx_mlir {
@@ -148,57 +150,6 @@ static CrossbarWeight getOpaqueCrossbarWeight(Value value, std::optional<uint32_
return weight; return weight;
} }
static FailureOr<int64_t> evaluateAffineExpr(AffineExpr expr, ArrayRef<int64_t> dims, ArrayRef<int64_t> symbols) {
if (auto constant = dyn_cast<AffineConstantExpr>(expr))
return constant.getValue();
if (auto dim = dyn_cast<AffineDimExpr>(expr)) {
unsigned position = dim.getPosition();
if (position >= dims.size())
return failure();
return dims[position];
}
if (auto symbol = dyn_cast<AffineSymbolExpr>(expr)) {
unsigned position = symbol.getPosition();
if (position >= symbols.size())
return failure();
return symbols[position];
}
auto binary = dyn_cast<AffineBinaryOpExpr>(expr);
if (!binary)
return failure();
FailureOr<int64_t> lhs = evaluateAffineExpr(binary.getLHS(), dims, symbols);
FailureOr<int64_t> rhs = evaluateAffineExpr(binary.getRHS(), dims, symbols);
if (failed(lhs) || failed(rhs))
return failure();
auto floorDiv = [](int64_t value, int64_t divisor) -> FailureOr<int64_t> {
if (divisor <= 0)
return failure();
if (value >= 0)
return value / divisor;
return -((-value + divisor - 1) / divisor);
};
switch (binary.getKind()) {
case AffineExprKind::Add: return *lhs + *rhs;
case AffineExprKind::Mul: return *lhs * *rhs;
case AffineExprKind::FloorDiv: return floorDiv(*lhs, *rhs);
case AffineExprKind::CeilDiv:
if (*rhs <= 0)
return failure();
return (*lhs + *rhs - 1) / *rhs;
case AffineExprKind::Mod: {
FailureOr<int64_t> div = floorDiv(*lhs, *rhs);
if (failed(div))
return failure();
return *lhs - *div * *rhs;
}
default: return failure();
}
}
static FailureOr<int64_t> static FailureOr<int64_t>
evaluateIndexLike(Value value, const DenseMap<Value, int64_t>& bindings, std::optional<uint32_t> lane, Value laneArg); evaluateIndexLike(Value value, const DenseMap<Value, int64_t>& bindings, std::optional<uint32_t> lane, Value laneArg);
@@ -222,12 +173,8 @@ evaluateIndexLike(Value value, const DenseMap<Value, int64_t>& bindings, std::op
if (auto it = bindings.find(value); it != bindings.end()) if (auto it = bindings.find(value); it != bindings.end())
return it->second; return it->second;
if (auto constant = value.getDefiningOp<arith::ConstantIndexOp>()) if (std::optional<int64_t> constant = matchConstantIndexValue(value))
return constant.value(); return *constant;
if (auto constant = value.getDefiningOp<arith::ConstantOp>())
if (auto intAttr = dyn_cast<IntegerAttr>(constant.getValue()))
return intAttr.getInt();
if (auto extract = value.getDefiningOp<tensor::ExtractOp>()) { if (auto extract = value.getDefiningOp<tensor::ExtractOp>()) {
auto constant = extract.getTensor().getDefiningOp<arith::ConstantOp>(); auto constant = extract.getTensor().getDefiningOp<arith::ConstantOp>();
@@ -245,26 +192,11 @@ evaluateIndexLike(Value value, const DenseMap<Value, int64_t>& bindings, std::op
return failure(); return failure();
} }
auto affineApply = value.getDefiningOp<affine::AffineApplyOp>(); if (auto affineApply = value.getDefiningOp<affine::AffineApplyOp>())
if (!affineApply) return evaluateAffineApply(affineApply,
return failure(); [&](Value operand) { return evaluateIndexLike(operand, bindings, lane, laneArg); });
AffineMap map = affineApply.getAffineMap();
if (map.getNumResults() != 1)
return failure(); return failure();
SmallVector<int64_t, 4> operands;
operands.reserve(affineApply.getMapOperands().size());
for (Value operand : affineApply.getMapOperands()) {
FailureOr<int64_t> folded = evaluateIndexLike(operand, bindings, lane, laneArg);
if (failed(folded))
return failure();
operands.push_back(*folded);
}
ArrayRef<int64_t> dims(operands.data(), map.getNumDims());
ArrayRef<int64_t> symbols(operands.data() + map.getNumDims(), map.getNumSymbols());
return evaluateAffineExpr(map.getResult(0), dims, symbols);
} }
static FailureOr<SmallVector<int64_t, 4>> evaluateIndexList(ArrayRef<OpFoldResult> values, static FailureOr<SmallVector<int64_t, 4>> evaluateIndexList(ArrayRef<OpFoldResult> values,