This commit is contained in:
@@ -13,7 +13,7 @@
|
||||
#include <limits>
|
||||
#include <utility>
|
||||
|
||||
#include "src/Accelerators/PIM/Common/IR/ConstantUtils.hpp"
|
||||
#include "Common/IR/ConstantUtils.hpp"
|
||||
#include "src/Accelerators/PIM/Common/IR/ShapeUtils.hpp"
|
||||
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
|
||||
#include "src/Accelerators/PIM/Common/Support/Diagnostics.hpp"
|
||||
@@ -58,47 +58,16 @@ static Value transposeForSpatial(Value value,
|
||||
ArrayRef<int64_t> permutation,
|
||||
ConversionPatternRewriter& rewriter,
|
||||
Location loc) {
|
||||
if (isCompileTimeComputable(value))
|
||||
return ONNXTransposeOp::create(rewriter, loc, resultType, value, rewriter.getI64ArrayAttr(permutation));
|
||||
|
||||
auto computeOp = createSpatCompute<1>(rewriter, loc, TypeRange {resultType}, {}, value, [&](Value input) {
|
||||
Value transposed = ONNXTransposeOp::create(rewriter, loc, resultType, input, rewriter.getI64ArrayAttr(permutation));
|
||||
spatial::SpatYieldOp::create(rewriter, loc, transposed);
|
||||
});
|
||||
return computeOp.getResult(0);
|
||||
}
|
||||
|
||||
static Value createIndexConstant(ConversionPatternRewriter& rewriter, int64_t value) {
|
||||
Operation* anchorOp = rewriter.getInsertionBlock()->getParentOp();
|
||||
return getOrCreateHostIndexConstant(anchorOp, value, rewriter);
|
||||
}
|
||||
|
||||
static Value
|
||||
createAffineApply(ConversionPatternRewriter& rewriter, Location loc, AffineExpr expr, ValueRange operands) {
|
||||
AffineMap map = AffineMap::get(/*dimCount=*/operands.size(), /*symbolCount=*/0, expr);
|
||||
Operation* anchorOp = rewriter.getInsertionBlock()->getParentOp();
|
||||
return createAffineApplyOrFoldedConstant(rewriter, loc, map, operands, anchorOp);
|
||||
return transposeMaybeInCompute(value, resultType, permutation, rewriter, loc);
|
||||
}
|
||||
|
||||
static Value
|
||||
multiplyIndexByConstant(Value value, int64_t multiplier, ConversionPatternRewriter& rewriter, Location loc) {
|
||||
if (multiplier == 0)
|
||||
return createIndexConstant(rewriter, 0);
|
||||
if (multiplier == 1)
|
||||
return value;
|
||||
|
||||
MLIRContext* context = rewriter.getContext();
|
||||
AffineExpr d0 = getAffineDimExpr(0, context);
|
||||
return createAffineApply(rewriter, loc, d0 * multiplier, ValueRange {value});
|
||||
return onnx_mlir::multiplyIndexByConstant(rewriter, value.getDefiningOp(), value, multiplier);
|
||||
}
|
||||
|
||||
static Value modIndexByConstant(Value value, int64_t divisor, ConversionPatternRewriter& rewriter, Location loc) {
|
||||
if (divisor == 1)
|
||||
return createIndexConstant(rewriter, 0);
|
||||
|
||||
MLIRContext* context = rewriter.getContext();
|
||||
AffineExpr d0 = getAffineDimExpr(0, context);
|
||||
return createAffineApply(rewriter, loc, d0 % divisor, ValueRange {value});
|
||||
return onnx_mlir::modIndexByConstant(rewriter, loc, value, divisor);
|
||||
}
|
||||
|
||||
static Value createGemmBatchRow(Value lane, int64_t numOutRows, ConversionPatternRewriter& rewriter, Location loc) {
|
||||
@@ -108,11 +77,11 @@ static Value createGemmBatchRow(Value lane, int64_t numOutRows, ConversionPatter
|
||||
static Value createGemmBatchKOffset(
|
||||
Value lane, int64_t numOutRows, int64_t numKSlices, ConversionPatternRewriter& rewriter, Location loc) {
|
||||
if (numKSlices == 1)
|
||||
return createIndexConstant(rewriter, 0);
|
||||
return getOrCreateHostIndexConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), 0);
|
||||
|
||||
MLIRContext* context = rewriter.getContext();
|
||||
AffineExpr d0 = getAffineDimExpr(0, context);
|
||||
return createAffineApply(
|
||||
return createAffineApplyOrConstant(
|
||||
rewriter, loc, (d0.floorDiv(numOutRows) % numKSlices) * crossbarSize.getValue(), ValueRange {lane});
|
||||
}
|
||||
|
||||
@@ -123,11 +92,11 @@ static Value createGemmBatchHOffset(Value lane,
|
||||
ConversionPatternRewriter& rewriter,
|
||||
Location loc) {
|
||||
if (numOutHSlices == 1)
|
||||
return createIndexConstant(rewriter, 0);
|
||||
return getOrCreateHostIndexConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), 0);
|
||||
|
||||
MLIRContext* context = rewriter.getContext();
|
||||
AffineExpr d0 = getAffineDimExpr(0, context);
|
||||
return createAffineApply(
|
||||
return createAffineApplyOrConstant(
|
||||
rewriter, loc, d0.floorDiv(numOutRows * numKSlices) * crossbarSize.getValue(), ValueRange {lane});
|
||||
}
|
||||
|
||||
@@ -303,53 +272,37 @@ static spatial::SpatComputeBatch createVmmBatch(Value a,
|
||||
ConversionPatternRewriter& rewriter,
|
||||
Location loc) {
|
||||
const int64_t laneCount = partialPiecesType.getDimSize(0);
|
||||
auto batchOp = spatial::SpatComputeBatch::create(rewriter,
|
||||
loc,
|
||||
TypeRange {partialPiecesType},
|
||||
rewriter.getI32IntegerAttr(static_cast<int32_t>(laneCount)),
|
||||
ValueRange {b},
|
||||
ValueRange {a});
|
||||
auto batchOp = createSpatComputeBatch(
|
||||
rewriter, loc, TypeRange {partialPiecesType}, laneCount, ValueRange {b}, ValueRange {a}, [&](detail::SpatComputeBatchBodyArgs args) {
|
||||
Value row = createGemmBatchRow(args.lane, numOutRows, rewriter, loc);
|
||||
Value kOffset = createGemmBatchKOffset(args.lane, numOutRows, numKSlices, rewriter, loc);
|
||||
Value hOffset = createGemmBatchHOffset(args.lane, numOutRows, numKSlices, numOutHSlices, rewriter, loc);
|
||||
|
||||
SmallVector<Type> blockArgTypes {rewriter.getIndexType(), paddedBType, aType, partialPiecesType};
|
||||
SmallVector<Location> blockArgLocs(blockArgTypes.size(), loc);
|
||||
Block* body =
|
||||
rewriter.createBlock(&batchOp.getBody(), batchOp.getBody().end(), TypeRange(blockArgTypes), blockArgLocs);
|
||||
rewriter.setInsertionPointToEnd(body);
|
||||
auto aTileType =
|
||||
RankedTensorType::get({1, static_cast<int64_t>(crossbarSize.getValue())}, aType.getElementType());
|
||||
auto bTileType = RankedTensorType::get(
|
||||
{static_cast<int64_t>(crossbarSize.getValue()), static_cast<int64_t>(crossbarSize.getValue())},
|
||||
paddedBType.getElementType());
|
||||
auto pieceType =
|
||||
RankedTensorType::get({1, static_cast<int64_t>(crossbarSize.getValue())}, partialPiecesType.getElementType());
|
||||
Value aTile = extractATile(args.inputs.front(), row, kOffset, aTileType, rewriter, loc);
|
||||
|
||||
auto lane = batchOp.getLaneArgument();
|
||||
auto weight = batchOp.getWeightArgument(0);
|
||||
auto input = batchOp.getInputArgument(0);
|
||||
auto output = batchOp.getOutputArgument(0);
|
||||
assert(lane && weight && input && output && "malformed Gemm compute_batch body");
|
||||
SmallVector<OpFoldResult> bOffsets {kOffset, hOffset};
|
||||
SmallVector<OpFoldResult> bSizes {rewriter.getIndexAttr(crossbarSize.getValue()),
|
||||
rewriter.getIndexAttr(crossbarSize.getValue())};
|
||||
SmallVector<OpFoldResult> unitStrides = getUnitStrides(rewriter, 2);
|
||||
Value bTile =
|
||||
tensor::ExtractSliceOp::create(rewriter, loc, bTileType, args.weights.front(), bOffsets, bSizes, unitStrides)
|
||||
.getResult();
|
||||
Value piece = spatial::SpatVMMOp::create(rewriter, loc, pieceType, bTile, aTile).getResult();
|
||||
|
||||
Value row = createGemmBatchRow(*lane, numOutRows, rewriter, loc);
|
||||
Value kOffset = createGemmBatchKOffset(*lane, numOutRows, numKSlices, rewriter, loc);
|
||||
Value hOffset = createGemmBatchHOffset(*lane, numOutRows, numKSlices, numOutHSlices, rewriter, loc);
|
||||
|
||||
auto aTileType = RankedTensorType::get({1, static_cast<int64_t>(crossbarSize.getValue())}, aType.getElementType());
|
||||
auto bTileType = RankedTensorType::get(
|
||||
{static_cast<int64_t>(crossbarSize.getValue()), static_cast<int64_t>(crossbarSize.getValue())},
|
||||
paddedBType.getElementType());
|
||||
auto pieceType =
|
||||
RankedTensorType::get({1, static_cast<int64_t>(crossbarSize.getValue())}, partialPiecesType.getElementType());
|
||||
Value aTile = extractATile(*input, row, kOffset, aTileType, rewriter, loc);
|
||||
|
||||
SmallVector<OpFoldResult> bOffsets {kOffset, hOffset};
|
||||
SmallVector<OpFoldResult> bSizes {rewriter.getIndexAttr(crossbarSize.getValue()),
|
||||
rewriter.getIndexAttr(crossbarSize.getValue())};
|
||||
SmallVector<OpFoldResult> unitStrides {rewriter.getIndexAttr(1), rewriter.getIndexAttr(1)};
|
||||
Value bTile =
|
||||
tensor::ExtractSliceOp::create(rewriter, loc, bTileType, *weight, bOffsets, bSizes, unitStrides).getResult();
|
||||
Value piece = spatial::SpatVMMOp::create(rewriter, loc, pieceType, bTile, aTile).getResult();
|
||||
|
||||
auto inParallelOp = spatial::SpatInParallelOp::create(rewriter, loc);
|
||||
rewriter.setInsertionPointToStart(&inParallelOp.getRegion().front());
|
||||
SmallVector<OpFoldResult> pieceOffsets {*lane, rewriter.getIndexAttr(0)};
|
||||
SmallVector<OpFoldResult> pieceSizes {rewriter.getIndexAttr(1), rewriter.getIndexAttr(crossbarSize.getValue())};
|
||||
tensor::ParallelInsertSliceOp::create(rewriter, loc, piece, *output, pieceOffsets, pieceSizes, unitStrides);
|
||||
|
||||
rewriter.setInsertionPointAfter(batchOp);
|
||||
return batchOp;
|
||||
SmallVector<OpFoldResult> pieceOffsets {args.lane, rewriter.getIndexAttr(0)};
|
||||
SmallVector<OpFoldResult> pieceSizes {rewriter.getIndexAttr(1), rewriter.getIndexAttr(crossbarSize.getValue())};
|
||||
createParallelInsertSliceIntoBatchOutput(
|
||||
rewriter, loc, piece, args.outputs.front(), pieceOffsets, pieceSizes, unitStrides);
|
||||
});
|
||||
assert(succeeded(batchOp) && "expected Gemm VMM batch construction to succeed");
|
||||
return *batchOp;
|
||||
}
|
||||
|
||||
static Value createDynamicGemmBatchRow(
|
||||
@@ -359,7 +312,7 @@ static Value createDynamicGemmBatchRow(
|
||||
|
||||
MLIRContext* context = rewriter.getContext();
|
||||
AffineExpr d0 = getAffineDimExpr(0, context);
|
||||
return createAffineApply(rewriter, loc, d0.floorDiv(numOutCols), ValueRange {lane});
|
||||
return createAffineApplyOrConstant(rewriter, loc, d0.floorDiv(numOutCols), ValueRange {lane});
|
||||
}
|
||||
|
||||
static Value createDynamicGemmBatchColumn(
|
||||
@@ -479,45 +432,27 @@ static spatial::SpatComputeBatch createVvdmulBatch(Value a,
|
||||
const int64_t numOutCols = outType.getDimSize(1);
|
||||
const int64_t reductionSize = aType.getDimSize(1);
|
||||
const int64_t laneCount = numOutRows * numOutCols;
|
||||
auto batchOp = spatial::SpatComputeBatch::create(rewriter,
|
||||
loc,
|
||||
TypeRange {scalarPiecesType},
|
||||
rewriter.getI32IntegerAttr(static_cast<int32_t>(laneCount)),
|
||||
ValueRange {},
|
||||
ValueRange {a, b});
|
||||
auto batchOp = createSpatComputeBatch(
|
||||
rewriter, loc, TypeRange {scalarPiecesType}, laneCount, ValueRange {}, ValueRange {a, b}, [&](detail::SpatComputeBatchBodyArgs args) {
|
||||
Value row = createDynamicGemmBatchRow(args.lane, numOutCols, rewriter, loc);
|
||||
Value column = createDynamicGemmBatchColumn(args.lane, numOutCols, rewriter, loc);
|
||||
|
||||
SmallVector<Type> blockArgTypes {rewriter.getIndexType(), aType, bType, scalarPiecesType};
|
||||
SmallVector<Location> blockArgLocs(blockArgTypes.size(), loc);
|
||||
Block* body =
|
||||
rewriter.createBlock(&batchOp.getBody(), batchOp.getBody().end(), TypeRange(blockArgTypes), blockArgLocs);
|
||||
rewriter.setInsertionPointToEnd(body);
|
||||
auto vectorType = RankedTensorType::get({1, reductionSize}, aType.getElementType());
|
||||
auto scalarType = RankedTensorType::get({1, 1}, outType.getElementType());
|
||||
Value aVector = extractDynamicGemmRowVector(args.inputs[0], row, vectorType, rewriter, loc);
|
||||
Value bVector = bAlreadyTransposed
|
||||
? extractTransposedBRow(args.inputs[1], column, vectorType, rewriter, loc)
|
||||
: extractDynamicGemmBColumn(args.inputs[1], column, vectorType, rewriter, loc);
|
||||
Value scalar = spatial::SpatVVDMulOp::create(rewriter, loc, scalarType, aVector, bVector).getResult();
|
||||
|
||||
auto lane = batchOp.getLaneArgument();
|
||||
auto inputA = batchOp.getInputArgument(0);
|
||||
auto inputB = batchOp.getInputArgument(1);
|
||||
auto output = batchOp.getOutputArgument(0);
|
||||
assert(lane && inputA && inputB && output && "malformed dynamic Gemm compute_batch body");
|
||||
|
||||
Value row = createDynamicGemmBatchRow(*lane, numOutCols, rewriter, loc);
|
||||
Value column = createDynamicGemmBatchColumn(*lane, numOutCols, rewriter, loc);
|
||||
|
||||
auto vectorType = RankedTensorType::get({1, reductionSize}, aType.getElementType());
|
||||
auto scalarType = RankedTensorType::get({1, 1}, outType.getElementType());
|
||||
Value aVector = extractDynamicGemmRowVector(*inputA, row, vectorType, rewriter, loc);
|
||||
Value bVector = bAlreadyTransposed
|
||||
? extractTransposedBRow(*inputB, column, vectorType, rewriter, loc)
|
||||
: extractDynamicGemmBColumn(*inputB, column, vectorType, rewriter, loc);
|
||||
Value scalar = spatial::SpatVVDMulOp::create(rewriter, loc, scalarType, aVector, bVector).getResult();
|
||||
|
||||
auto inParallelOp = spatial::SpatInParallelOp::create(rewriter, loc);
|
||||
rewriter.setInsertionPointToStart(&inParallelOp.getRegion().front());
|
||||
SmallVector<OpFoldResult> outputOffsets {*lane, rewriter.getIndexAttr(0)};
|
||||
SmallVector<OpFoldResult> scalarSizes {rewriter.getIndexAttr(1), rewriter.getIndexAttr(1)};
|
||||
SmallVector<OpFoldResult> unitStrides {rewriter.getIndexAttr(1), rewriter.getIndexAttr(1)};
|
||||
tensor::ParallelInsertSliceOp::create(rewriter, loc, scalar, *output, outputOffsets, scalarSizes, unitStrides);
|
||||
|
||||
rewriter.setInsertionPointAfter(batchOp);
|
||||
return batchOp;
|
||||
SmallVector<OpFoldResult> outputOffsets {args.lane, rewriter.getIndexAttr(0)};
|
||||
SmallVector<OpFoldResult> scalarSizes {rewriter.getIndexAttr(1), rewriter.getIndexAttr(1)};
|
||||
SmallVector<OpFoldResult> unitStrides = getUnitStrides(rewriter, 2);
|
||||
createParallelInsertSliceIntoBatchOutput(
|
||||
rewriter, loc, scalar, args.outputs.front(), outputOffsets, scalarSizes, unitStrides);
|
||||
});
|
||||
assert(succeeded(batchOp) && "expected Gemm VVDMul batch construction to succeed");
|
||||
return *batchOp;
|
||||
}
|
||||
|
||||
static spatial::SpatCompute createDynamicGemmOutputCompute(Value scalarPieces,
|
||||
@@ -540,9 +475,9 @@ static spatial::SpatCompute createDynamicGemmOutputCompute(Value scalarPieces,
|
||||
Value biasArg = bias ? blockArgs[1] : Value();
|
||||
auto scalarType = RankedTensorType::get({1, 1}, outType.getElementType());
|
||||
Value outputInit = tensor::EmptyOp::create(rewriter, loc, outType.getShape(), outType.getElementType()).getResult();
|
||||
Value c0 = createIndexConstant(rewriter, 0);
|
||||
Value c1 = createIndexConstant(rewriter, 1);
|
||||
Value cLaneCount = createIndexConstant(rewriter, laneCount);
|
||||
Value c0 = getOrCreateHostIndexConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), 0);
|
||||
Value c1 = getOrCreateHostIndexConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), 1);
|
||||
Value cLaneCount = getOrCreateHostIndexConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), laneCount);
|
||||
auto loop = scf::ForOp::create(rewriter, loc, c0, cLaneCount, c1, ValueRange {outputInit});
|
||||
rewriter.setInsertionPointToStart(loop.getBody());
|
||||
|
||||
@@ -587,7 +522,8 @@ static Value createPartialGroupOffset(Value hSlice,
|
||||
Location loc) {
|
||||
MLIRContext* context = rewriter.getContext();
|
||||
AffineExpr d0 = getAffineDimExpr(0, context);
|
||||
return createAffineApply(rewriter, loc, d0 * (numKSlices * numOutRows) + kSlice * numOutRows, ValueRange {hSlice});
|
||||
return createAffineApplyOrConstant(
|
||||
rewriter, loc, d0 * (numKSlices * numOutRows) + kSlice * numOutRows, ValueRange {hSlice});
|
||||
}
|
||||
|
||||
static Value extractReductionPiece(Value partialPiecesArg,
|
||||
@@ -684,13 +620,13 @@ static spatial::SpatCompute createReductionCompute(Value partialPieces,
|
||||
|
||||
Value paddedOutput = outputInit;
|
||||
if (numOutHSlices == 1) {
|
||||
Value hSlice = createIndexConstant(rewriter, 0);
|
||||
Value hSlice = getOrCreateHostIndexConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), 0);
|
||||
paddedOutput = buildOutputSlice(outputInit, hSlice);
|
||||
}
|
||||
else {
|
||||
Value c0 = createIndexConstant(rewriter, 0);
|
||||
Value c1 = createIndexConstant(rewriter, 1);
|
||||
Value cOutHSlices = createIndexConstant(rewriter, numOutHSlices);
|
||||
Value c0 = getOrCreateHostIndexConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), 0);
|
||||
Value c1 = getOrCreateHostIndexConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), 1);
|
||||
Value cOutHSlices = getOrCreateHostIndexConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), numOutHSlices);
|
||||
auto hLoop = scf::ForOp::create(rewriter, loc, c0, cOutHSlices, c1, ValueRange {outputInit});
|
||||
rewriter.setInsertionPointToStart(hLoop.getBody());
|
||||
|
||||
|
||||
Reference in New Issue
Block a user