add shared loop creation helpers
Validate Operations / validate-operations (push) Has been cancelled
Validate Operations / validate-operations (push) Has been cancelled
add shared checked arithmetic helpers refactor pim passes into Pim/Transforms more robust memory coalescing pass
This commit is contained in:
@@ -12,6 +12,7 @@
|
||||
#include <type_traits>
|
||||
#include <utility>
|
||||
|
||||
#include "src/Accelerators/PIM/Common/Support/CheckedArithmetic.hpp"
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/CompileTime.hpp"
|
||||
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
||||
|
||||
@@ -180,8 +181,11 @@ auto createSpatComputeBatch(RewriterT& rewriter,
|
||||
if (laneCount <= 0 || laneCount > std::numeric_limits<int32_t>::max())
|
||||
return mlir::FailureOr<spatial::SpatComputeBatch>(mlir::failure());
|
||||
|
||||
auto batchOp = spatial::SpatComputeBatch::create(
|
||||
rewriter, loc, resultTypes, rewriter.getI32IntegerAttr(static_cast<int32_t>(laneCount)), weights, inputs);
|
||||
auto laneCountAttr = pim::getCheckedI32Attr(rewriter, loc, laneCount, "spatial compute_batch lane count");
|
||||
if (mlir::failed(laneCountAttr))
|
||||
return mlir::FailureOr<spatial::SpatComputeBatch>(mlir::failure());
|
||||
|
||||
auto batchOp = spatial::SpatComputeBatch::create(rewriter, loc, resultTypes, *laneCountAttr, weights, inputs);
|
||||
|
||||
mlir::SmallVector<mlir::Type> blockArgTypes {rewriter.getIndexType()};
|
||||
mlir::SmallVector<mlir::Location> blockArgLocs {loc};
|
||||
|
||||
@@ -8,6 +8,7 @@
|
||||
|
||||
#include <algorithm>
|
||||
|
||||
#include "src/Accelerators/PIM/Common/IR/LoopUtils.hpp"
|
||||
#include "src/Accelerators/PIM/Common/Support/Diagnostics.hpp"
|
||||
#include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp"
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/Common.hpp"
|
||||
@@ -305,58 +306,67 @@ static Value createIm2colRowComputes(Value x,
|
||||
auto cStrideHeight = getOrCreateIndexConstant(rewriter, anchorOp, strideHeight);
|
||||
auto cStrideWidth = getOrCreateIndexConstant(rewriter, anchorOp, strideWidth);
|
||||
|
||||
auto im2colLoop = scf::ForOp::create(rewriter, loc, c0, cNumPatches, c1, ValueRange {im2colInit});
|
||||
rewriter.setInsertionPointToStart(im2colLoop.getBody());
|
||||
auto im2colLoop = buildNormalizedScfFor(
|
||||
rewriter,
|
||||
loc,
|
||||
c0,
|
||||
cNumPatches,
|
||||
c1,
|
||||
ValueRange {im2colInit},
|
||||
[&](OpBuilder&, Location nestedLoc, Value patchIndex, ValueRange iterArgs, SmallVectorImpl<Value>& yielded) {
|
||||
Value im2colAcc = iterArgs.front();
|
||||
Value batchIndex = arith::DivUIOp::create(rewriter, nestedLoc, patchIndex, cNumPatchesPerBatch);
|
||||
Value batchPatchIndex = arith::RemUIOp::create(rewriter, nestedLoc, patchIndex, cNumPatchesPerBatch);
|
||||
Value outHeightIndex = arith::DivUIOp::create(rewriter, nestedLoc, batchPatchIndex, cOutWidth);
|
||||
Value outWidthIndex = arith::RemUIOp::create(rewriter, nestedLoc, batchPatchIndex, cOutWidth);
|
||||
Value inputHeightOffset = arith::MulIOp::create(rewriter, nestedLoc, outHeightIndex, cStrideHeight);
|
||||
Value inputWidthOffset = arith::MulIOp::create(rewriter, nestedLoc, outWidthIndex, cStrideWidth);
|
||||
|
||||
Value patchIndex = im2colLoop.getInductionVar();
|
||||
Value im2colAcc = im2colLoop.getRegionIterArgs().front();
|
||||
SmallVector<OpFoldResult> offsets = {
|
||||
batchIndex, rewriter.getIndexAttr(0), inputHeightOffset, inputWidthOffset};
|
||||
SmallVector<OpFoldResult> sizes = {rewriter.getIndexAttr(1),
|
||||
rewriter.getIndexAttr(numChannelsIn),
|
||||
rewriter.getIndexAttr(wHeight),
|
||||
rewriter.getIndexAttr(wWidth)};
|
||||
SmallVector<OpFoldResult> strides = {rewriter.getIndexAttr(1),
|
||||
rewriter.getIndexAttr(1),
|
||||
rewriter.getIndexAttr(dilationHeight),
|
||||
rewriter.getIndexAttr(dilationWidth)};
|
||||
auto patchType = RankedTensorType::get({1, numChannelsIn, wHeight, wWidth}, elemType);
|
||||
Value patch =
|
||||
tensor::ExtractSliceOp::create(rewriter, nestedLoc, patchType, paddedInput, offsets, sizes, strides);
|
||||
|
||||
Value batchIndex = arith::DivUIOp::create(rewriter, loc, patchIndex, cNumPatchesPerBatch);
|
||||
Value batchPatchIndex = arith::RemUIOp::create(rewriter, loc, patchIndex, cNumPatchesPerBatch);
|
||||
Value outHeightIndex = arith::DivUIOp::create(rewriter, loc, batchPatchIndex, cOutWidth);
|
||||
Value outWidthIndex = arith::RemUIOp::create(rewriter, loc, batchPatchIndex, cOutWidth);
|
||||
Value inputHeightOffset = arith::MulIOp::create(rewriter, loc, outHeightIndex, cStrideHeight);
|
||||
Value inputWidthOffset = arith::MulIOp::create(rewriter, loc, outWidthIndex, cStrideWidth);
|
||||
Value row = tensor::CollapseShapeOp::create(rewriter,
|
||||
nestedLoc,
|
||||
im2colRowType,
|
||||
patch,
|
||||
SmallVector<ReassociationIndices> {
|
||||
{0},
|
||||
{1, 2, 3}
|
||||
});
|
||||
|
||||
SmallVector<OpFoldResult> offsets = {batchIndex, rewriter.getIndexAttr(0), inputHeightOffset, inputWidthOffset};
|
||||
SmallVector<OpFoldResult> sizes = {rewriter.getIndexAttr(1),
|
||||
rewriter.getIndexAttr(numChannelsIn),
|
||||
rewriter.getIndexAttr(wHeight),
|
||||
rewriter.getIndexAttr(wWidth)};
|
||||
SmallVector<OpFoldResult> strides = {rewriter.getIndexAttr(1),
|
||||
rewriter.getIndexAttr(1),
|
||||
rewriter.getIndexAttr(dilationHeight),
|
||||
rewriter.getIndexAttr(dilationWidth)};
|
||||
auto patchType = RankedTensorType::get({1, numChannelsIn, wHeight, wWidth}, elemType);
|
||||
Value patch = tensor::ExtractSliceOp::create(rewriter, loc, patchType, paddedInput, offsets, sizes, strides);
|
||||
|
||||
Value row = tensor::CollapseShapeOp::create(rewriter,
|
||||
loc,
|
||||
im2colRowType,
|
||||
patch,
|
||||
SmallVector<ReassociationIndices> {
|
||||
{0},
|
||||
{1, 2, 3}
|
||||
});
|
||||
|
||||
SmallVector<OpFoldResult> rowOffsets = {patchIndex, rewriter.getIndexAttr(0)};
|
||||
SmallVector<OpFoldResult> rowSizes = {rewriter.getIndexAttr(1), rewriter.getIndexAttr(patchSize)};
|
||||
SmallVector<OpFoldResult> rowStrides = {rewriter.getIndexAttr(1), rewriter.getIndexAttr(1)};
|
||||
Value updatedIm2col =
|
||||
tensor::InsertSliceOp::create(rewriter, loc, row, im2colAcc, rowOffsets, rowSizes, rowStrides);
|
||||
scf::YieldOp::create(rewriter, loc, updatedIm2col);
|
||||
|
||||
rewriter.setInsertionPointAfter(im2colLoop);
|
||||
Value im2col = im2colLoop.getResult(0);
|
||||
SmallVector<OpFoldResult> rowOffsets = {patchIndex, rewriter.getIndexAttr(0)};
|
||||
SmallVector<OpFoldResult> rowSizes = {rewriter.getIndexAttr(1), rewriter.getIndexAttr(patchSize)};
|
||||
SmallVector<OpFoldResult> rowStrides = {rewriter.getIndexAttr(1), rewriter.getIndexAttr(1)};
|
||||
Value updatedIm2col =
|
||||
tensor::InsertSliceOp::create(rewriter, nestedLoc, row, im2colAcc, rowOffsets, rowSizes, rowStrides);
|
||||
yielded.push_back(updatedIm2col);
|
||||
return success();
|
||||
});
|
||||
if (failed(im2colLoop))
|
||||
return failure();
|
||||
Value im2col = im2colLoop->results.front();
|
||||
|
||||
Value gemmInputRows = im2col;
|
||||
if (packFactor != 1)
|
||||
gemmInputRows = packRowsForParallelGemm(im2col, im2colType, packFactor, rewriter, loc);
|
||||
|
||||
spatial::SpatYieldOp::create(rewriter, loc, gemmInputRows);
|
||||
return success();
|
||||
});
|
||||
|
||||
return im2colComputeOp.getResult(0);
|
||||
assert(succeeded(im2colComputeOp) && "Conv im2col compute construction must succeed");
|
||||
return im2colComputeOp->getResult(0);
|
||||
}
|
||||
|
||||
static Value createCollectedConvOutput(ValueRange gemmRows,
|
||||
|
||||
@@ -15,6 +15,7 @@
|
||||
|
||||
#include "Common/IR/ConstantUtils.hpp"
|
||||
#include "src/Accelerators/PIM/Common/IR/AffineUtils.hpp"
|
||||
#include "src/Accelerators/PIM/Common/IR/LoopUtils.hpp"
|
||||
#include "src/Accelerators/PIM/Common/IR/ShapeUtils.hpp"
|
||||
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
|
||||
#include "src/Accelerators/PIM/Common/Support/Diagnostics.hpp"
|
||||
@@ -247,16 +248,16 @@ static Value createPaddedInputCompute(Value input,
|
||||
return computeOp.getResult(0);
|
||||
}
|
||||
|
||||
static spatial::SpatComputeBatch createVmmBatch(Value a,
|
||||
Value b,
|
||||
RankedTensorType aType,
|
||||
RankedTensorType paddedBType,
|
||||
RankedTensorType partialPiecesType,
|
||||
int64_t numOutRows,
|
||||
int64_t numKSlices,
|
||||
int64_t numOutHSlices,
|
||||
ConversionPatternRewriter& rewriter,
|
||||
Location loc) {
|
||||
static FailureOr<spatial::SpatComputeBatch> createVmmBatch(Value a,
|
||||
Value b,
|
||||
RankedTensorType aType,
|
||||
RankedTensorType paddedBType,
|
||||
RankedTensorType partialPiecesType,
|
||||
int64_t numOutRows,
|
||||
int64_t numKSlices,
|
||||
int64_t numOutHSlices,
|
||||
ConversionPatternRewriter& rewriter,
|
||||
Location loc) {
|
||||
const int64_t laneCount = partialPiecesType.getDimSize(0);
|
||||
auto batchOp = createSpatComputeBatch(
|
||||
rewriter,
|
||||
@@ -294,7 +295,8 @@ static spatial::SpatComputeBatch createVmmBatch(Value a,
|
||||
createParallelInsertSliceIntoBatchOutput(
|
||||
rewriter, loc, piece, args.outputs.front(), pieceOffsets, pieceSizes, unitStrides);
|
||||
});
|
||||
assert(succeeded(batchOp) && "expected Gemm VMM batch construction to succeed");
|
||||
if (failed(batchOp))
|
||||
return failure();
|
||||
return *batchOp;
|
||||
}
|
||||
|
||||
@@ -416,15 +418,15 @@ static Value createBroadcastedBiasScalar(Value bias,
|
||||
return tensor::SplatOp::create(rewriter, loc, scalarType, scalar).getResult();
|
||||
}
|
||||
|
||||
static spatial::SpatComputeBatch createVvdmulBatch(Value a,
|
||||
Value b,
|
||||
RankedTensorType aType,
|
||||
RankedTensorType bType,
|
||||
RankedTensorType scalarPiecesType,
|
||||
RankedTensorType outType,
|
||||
bool bAlreadyTransposed,
|
||||
ConversionPatternRewriter& rewriter,
|
||||
Location loc) {
|
||||
static FailureOr<spatial::SpatComputeBatch> createVvdmulBatch(Value a,
|
||||
Value b,
|
||||
RankedTensorType aType,
|
||||
RankedTensorType bType,
|
||||
RankedTensorType scalarPiecesType,
|
||||
RankedTensorType outType,
|
||||
bool bAlreadyTransposed,
|
||||
ConversionPatternRewriter& rewriter,
|
||||
Location loc) {
|
||||
const int64_t numOutRows = outType.getDimSize(0);
|
||||
const int64_t numOutCols = outType.getDimSize(1);
|
||||
const int64_t reductionSize = aType.getDimSize(1);
|
||||
@@ -454,26 +456,27 @@ static spatial::SpatComputeBatch createVvdmulBatch(Value a,
|
||||
createParallelInsertSliceIntoBatchOutput(
|
||||
rewriter, loc, scalar, args.outputs.front(), outputOffsets, scalarSizes, unitStrides);
|
||||
});
|
||||
assert(succeeded(batchOp) && "expected Gemm VVDMul batch construction to succeed");
|
||||
if (failed(batchOp))
|
||||
return failure();
|
||||
return *batchOp;
|
||||
}
|
||||
|
||||
static spatial::SpatCompute createDynamicGemmOutputCompute(Value scalarPieces,
|
||||
Value bias,
|
||||
RankedTensorType scalarPiecesType,
|
||||
RankedTensorType biasType,
|
||||
RankedTensorType outType,
|
||||
float alpha,
|
||||
float beta,
|
||||
ConversionPatternRewriter& rewriter,
|
||||
Location loc) {
|
||||
static FailureOr<spatial::SpatCompute> createDynamicGemmOutputCompute(Value scalarPieces,
|
||||
Value bias,
|
||||
RankedTensorType scalarPiecesType,
|
||||
RankedTensorType biasType,
|
||||
RankedTensorType outType,
|
||||
float alpha,
|
||||
float beta,
|
||||
ConversionPatternRewriter& rewriter,
|
||||
Location loc) {
|
||||
const int64_t laneCount = scalarPiecesType.getDimSize(0);
|
||||
const int64_t numOutCols = outType.getDimSize(1);
|
||||
SmallVector<Value> inputs {scalarPieces};
|
||||
if (bias)
|
||||
inputs.push_back(bias);
|
||||
|
||||
return createSpatCompute(rewriter, loc, TypeRange {outType}, {}, inputs, [&](ValueRange blockArgs) {
|
||||
return createSpatCompute(rewriter, loc, TypeRange {outType}, {}, inputs, [&](ValueRange blockArgs) -> LogicalResult {
|
||||
Value pieces = blockArgs[0];
|
||||
Value biasArg = bias ? blockArgs[1] : Value();
|
||||
auto scalarType = RankedTensorType::get({1, 1}, outType.getElementType());
|
||||
@@ -481,40 +484,50 @@ static spatial::SpatCompute createDynamicGemmOutputCompute(Value scalarPieces,
|
||||
Value c0 = getOrCreateIndexConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), 0);
|
||||
Value c1 = getOrCreateIndexConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), 1);
|
||||
Value cLaneCount = getOrCreateIndexConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), laneCount);
|
||||
auto loop = scf::ForOp::create(rewriter, loc, c0, cLaneCount, c1, ValueRange {outputInit});
|
||||
rewriter.setInsertionPointToStart(loop.getBody());
|
||||
auto loop = buildNormalizedScfFor(
|
||||
rewriter,
|
||||
loc,
|
||||
c0,
|
||||
cLaneCount,
|
||||
c1,
|
||||
ValueRange {outputInit},
|
||||
[&](OpBuilder&, Location nestedLoc, Value lane, ValueRange iterArgs, SmallVectorImpl<Value>& yielded) {
|
||||
Value outputAcc = iterArgs.front();
|
||||
Value row = createDynamicGemmBatchRow(lane, numOutCols, rewriter, nestedLoc);
|
||||
Value column =
|
||||
onnx_mlir::affineModConst(rewriter, nestedLoc, lane, numOutCols, rewriter.getInsertionBlock()->getParentOp());
|
||||
SmallVector<OpFoldResult> scalarOffsets {lane, rewriter.getIndexAttr(0)};
|
||||
SmallVector<OpFoldResult> scalarSizes {rewriter.getIndexAttr(1), rewriter.getIndexAttr(1)};
|
||||
SmallVector<OpFoldResult> unitStrides {rewriter.getIndexAttr(1), rewriter.getIndexAttr(1)};
|
||||
Value scalar = tensor::ExtractSliceOp::create(
|
||||
rewriter, nestedLoc, scalarType, pieces, scalarOffsets, scalarSizes, unitStrides)
|
||||
.getResult();
|
||||
if (alpha != 1.0f) {
|
||||
Value alphaTensor = createScalarTensorConstant(scalarType, alpha, rewriter, nestedLoc);
|
||||
scalar = spatial::SpatVMulOp::create(rewriter, nestedLoc, scalarType, scalar, alphaTensor).getResult();
|
||||
}
|
||||
if (biasArg) {
|
||||
Value biasScalar =
|
||||
createBroadcastedBiasScalar(biasArg, biasType, row, column, scalarType, rewriter, nestedLoc);
|
||||
if (beta != 1.0f) {
|
||||
Value betaTensor = createScalarTensorConstant(scalarType, beta, rewriter, nestedLoc);
|
||||
biasScalar =
|
||||
spatial::SpatVMulOp::create(rewriter, nestedLoc, scalarType, biasScalar, betaTensor).getResult();
|
||||
}
|
||||
scalar = spatial::SpatVAddOp::create(rewriter, nestedLoc, scalarType, scalar, biasScalar).getResult();
|
||||
}
|
||||
SmallVector<OpFoldResult> outputOffsets {row, column};
|
||||
Value outputNext =
|
||||
tensor::InsertSliceOp::create(rewriter, nestedLoc, scalar, outputAcc, outputOffsets, scalarSizes, unitStrides)
|
||||
.getResult();
|
||||
yielded.push_back(outputNext);
|
||||
return success();
|
||||
});
|
||||
if (failed(loop))
|
||||
return failure();
|
||||
|
||||
Value lane = loop.getInductionVar();
|
||||
Value outputAcc = loop.getRegionIterArgs().front();
|
||||
Value row = createDynamicGemmBatchRow(lane, numOutCols, rewriter, loc);
|
||||
Value column =
|
||||
onnx_mlir::affineModConst(rewriter, loc, lane, numOutCols, rewriter.getInsertionBlock()->getParentOp());
|
||||
SmallVector<OpFoldResult> scalarOffsets {lane, rewriter.getIndexAttr(0)};
|
||||
SmallVector<OpFoldResult> scalarSizes {rewriter.getIndexAttr(1), rewriter.getIndexAttr(1)};
|
||||
SmallVector<OpFoldResult> unitStrides {rewriter.getIndexAttr(1), rewriter.getIndexAttr(1)};
|
||||
Value scalar =
|
||||
tensor::ExtractSliceOp::create(rewriter, loc, scalarType, pieces, scalarOffsets, scalarSizes, unitStrides)
|
||||
.getResult();
|
||||
if (alpha != 1.0f) {
|
||||
Value alphaTensor = createScalarTensorConstant(scalarType, alpha, rewriter, loc);
|
||||
scalar = spatial::SpatVMulOp::create(rewriter, loc, scalarType, scalar, alphaTensor).getResult();
|
||||
}
|
||||
if (biasArg) {
|
||||
Value biasScalar = createBroadcastedBiasScalar(biasArg, biasType, row, column, scalarType, rewriter, loc);
|
||||
if (beta != 1.0f) {
|
||||
Value betaTensor = createScalarTensorConstant(scalarType, beta, rewriter, loc);
|
||||
biasScalar = spatial::SpatVMulOp::create(rewriter, loc, scalarType, biasScalar, betaTensor).getResult();
|
||||
}
|
||||
scalar = spatial::SpatVAddOp::create(rewriter, loc, scalarType, scalar, biasScalar).getResult();
|
||||
}
|
||||
SmallVector<OpFoldResult> outputOffsets {row, column};
|
||||
Value outputNext =
|
||||
tensor::InsertSliceOp::create(rewriter, loc, scalar, outputAcc, outputOffsets, scalarSizes, unitStrides)
|
||||
.getResult();
|
||||
scf::YieldOp::create(rewriter, loc, outputNext);
|
||||
|
||||
rewriter.setInsertionPointAfter(loop);
|
||||
spatial::SpatYieldOp::create(rewriter, loc, loop.getResult(0));
|
||||
spatial::SpatYieldOp::create(rewriter, loc, loop->results.front());
|
||||
return success();
|
||||
});
|
||||
}
|
||||
|
||||
@@ -579,85 +592,92 @@ static Value reducePartialPiecesForHSlice(Value partialPiecesArg,
|
||||
return activePieces.front();
|
||||
}
|
||||
|
||||
static spatial::SpatCompute createReductionCompute(Value partialPieces,
|
||||
Value bias,
|
||||
RankedTensorType partialPiecesType,
|
||||
RankedTensorType outType,
|
||||
RankedTensorType paddedOutType,
|
||||
int64_t numKSlices,
|
||||
ConversionPatternRewriter& rewriter,
|
||||
Location loc) {
|
||||
static FailureOr<spatial::SpatCompute> createReductionCompute(Value partialPieces,
|
||||
Value bias,
|
||||
RankedTensorType partialPiecesType,
|
||||
RankedTensorType outType,
|
||||
RankedTensorType paddedOutType,
|
||||
int64_t numKSlices,
|
||||
ConversionPatternRewriter& rewriter,
|
||||
Location loc) {
|
||||
SmallVector<Value> inputs {partialPieces};
|
||||
if (bias)
|
||||
inputs.push_back(bias);
|
||||
|
||||
auto computeOp = createSpatCompute(rewriter, loc, TypeRange {outType}, {}, inputs, [&](ValueRange blockArgs) {
|
||||
Value partialPiecesArg = blockArgs[0];
|
||||
Value biasArg = bias ? blockArgs[1] : Value();
|
||||
if (biasArg && cast<RankedTensorType>(biasArg.getType()) != paddedOutType)
|
||||
biasArg = createZeroPaddedTensor(biasArg, paddedOutType, rewriter, loc);
|
||||
auto computeOp =
|
||||
createSpatCompute(rewriter, loc, TypeRange {outType}, {}, inputs, [&](ValueRange blockArgs) -> LogicalResult {
|
||||
Value partialPiecesArg = blockArgs[0];
|
||||
Value biasArg = bias ? blockArgs[1] : Value();
|
||||
if (biasArg && cast<RankedTensorType>(biasArg.getType()) != paddedOutType)
|
||||
biasArg = createZeroPaddedTensor(biasArg, paddedOutType, rewriter, loc);
|
||||
|
||||
const int64_t numOutRows = outType.getDimSize(0);
|
||||
const int64_t numOutHSlices = ceilIntegerDivide(outType.getDimSize(1), crossbarSize.getValue());
|
||||
auto pieceType = RankedTensorType::get({numOutRows, static_cast<int64_t>(crossbarSize.getValue())},
|
||||
partialPiecesType.getElementType());
|
||||
const int64_t numOutRows = outType.getDimSize(0);
|
||||
const int64_t numOutHSlices = ceilIntegerDivide(outType.getDimSize(1), crossbarSize.getValue());
|
||||
auto pieceType = RankedTensorType::get({numOutRows, static_cast<int64_t>(crossbarSize.getValue())},
|
||||
partialPiecesType.getElementType());
|
||||
|
||||
Value outputInit =
|
||||
tensor::EmptyOp::create(rewriter, loc, paddedOutType.getShape(), paddedOutType.getElementType()).getResult();
|
||||
SmallVector<OpFoldResult> unitStrides {rewriter.getIndexAttr(1), rewriter.getIndexAttr(1)};
|
||||
SmallVector<OpFoldResult> pieceSizes {rewriter.getIndexAttr(numOutRows),
|
||||
rewriter.getIndexAttr(crossbarSize.getValue())};
|
||||
Value outputInit =
|
||||
tensor::EmptyOp::create(rewriter, loc, paddedOutType.getShape(), paddedOutType.getElementType()).getResult();
|
||||
SmallVector<OpFoldResult> unitStrides {rewriter.getIndexAttr(1), rewriter.getIndexAttr(1)};
|
||||
SmallVector<OpFoldResult> pieceSizes {rewriter.getIndexAttr(numOutRows),
|
||||
rewriter.getIndexAttr(crossbarSize.getValue())};
|
||||
|
||||
auto buildOutputSlice = [&](Value outputAcc, Value hSlice) -> Value {
|
||||
Value reduced =
|
||||
reducePartialPiecesForHSlice(partialPiecesArg, hSlice, pieceType, numKSlices, numOutRows, rewriter, loc);
|
||||
Value hOffset = onnx_mlir::affineMulConst(
|
||||
rewriter, loc, hSlice, crossbarSize.getValue(), rewriter.getInsertionBlock()->getParentOp());
|
||||
if (biasArg) {
|
||||
SmallVector<OpFoldResult> biasOffsets {rewriter.getIndexAttr(0), hOffset};
|
||||
Value biasSlice =
|
||||
tensor::ExtractSliceOp::create(rewriter, loc, pieceType, biasArg, biasOffsets, pieceSizes, unitStrides)
|
||||
.getResult();
|
||||
reduced = spatial::SpatVAddOp::create(rewriter, loc, pieceType, reduced, biasSlice).getResult();
|
||||
auto buildOutputSlice = [&](Value outputAcc, Value hSlice) -> Value {
|
||||
Value reduced =
|
||||
reducePartialPiecesForHSlice(partialPiecesArg, hSlice, pieceType, numKSlices, numOutRows, rewriter, loc);
|
||||
Value hOffset = onnx_mlir::affineMulConst(
|
||||
rewriter, loc, hSlice, crossbarSize.getValue(), rewriter.getInsertionBlock()->getParentOp());
|
||||
if (biasArg) {
|
||||
SmallVector<OpFoldResult> biasOffsets {rewriter.getIndexAttr(0), hOffset};
|
||||
Value biasSlice =
|
||||
tensor::ExtractSliceOp::create(rewriter, loc, pieceType, biasArg, biasOffsets, pieceSizes, unitStrides)
|
||||
.getResult();
|
||||
reduced = spatial::SpatVAddOp::create(rewriter, loc, pieceType, reduced, biasSlice).getResult();
|
||||
}
|
||||
|
||||
SmallVector<OpFoldResult> outputOffsets {rewriter.getIndexAttr(0), hOffset};
|
||||
return tensor::InsertSliceOp::create(rewriter, loc, reduced, outputAcc, outputOffsets, pieceSizes, unitStrides)
|
||||
.getResult();
|
||||
};
|
||||
|
||||
Value paddedOutput = outputInit;
|
||||
if (numOutHSlices == 1) {
|
||||
Value hSlice = getOrCreateIndexConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), 0);
|
||||
paddedOutput = buildOutputSlice(outputInit, hSlice);
|
||||
}
|
||||
else {
|
||||
Value c0 = getOrCreateIndexConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), 0);
|
||||
Value c1 = getOrCreateIndexConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), 1);
|
||||
Value cOutHSlices =
|
||||
getOrCreateIndexConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), numOutHSlices);
|
||||
auto hLoop = buildNormalizedScfFor(
|
||||
rewriter,
|
||||
loc,
|
||||
c0,
|
||||
cOutHSlices,
|
||||
c1,
|
||||
ValueRange {outputInit},
|
||||
[&](OpBuilder&, Location, Value hSlice, ValueRange iterArgs, SmallVectorImpl<Value>& yielded) {
|
||||
yielded.push_back(buildOutputSlice(iterArgs.front(), hSlice));
|
||||
return success();
|
||||
});
|
||||
if (failed(hLoop))
|
||||
return failure();
|
||||
paddedOutput = hLoop->results.front();
|
||||
}
|
||||
|
||||
SmallVector<OpFoldResult> outputOffsets {rewriter.getIndexAttr(0), hOffset};
|
||||
return tensor::InsertSliceOp::create(rewriter, loc, reduced, outputAcc, outputOffsets, pieceSizes, unitStrides)
|
||||
.getResult();
|
||||
};
|
||||
|
||||
Value paddedOutput = outputInit;
|
||||
if (numOutHSlices == 1) {
|
||||
Value hSlice = getOrCreateIndexConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), 0);
|
||||
paddedOutput = buildOutputSlice(outputInit, hSlice);
|
||||
}
|
||||
else {
|
||||
Value c0 = getOrCreateIndexConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), 0);
|
||||
Value c1 = getOrCreateIndexConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), 1);
|
||||
Value cOutHSlices =
|
||||
getOrCreateIndexConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), numOutHSlices);
|
||||
auto hLoop = scf::ForOp::create(rewriter, loc, c0, cOutHSlices, c1, ValueRange {outputInit});
|
||||
rewriter.setInsertionPointToStart(hLoop.getBody());
|
||||
|
||||
Value hSlice = hLoop.getInductionVar();
|
||||
Value outputAcc = hLoop.getRegionIterArgs().front();
|
||||
scf::YieldOp::create(rewriter, loc, buildOutputSlice(outputAcc, hSlice));
|
||||
|
||||
rewriter.setInsertionPointAfter(hLoop);
|
||||
paddedOutput = hLoop.getResult(0);
|
||||
}
|
||||
|
||||
Value result = paddedOutput;
|
||||
if (paddedOutType != outType) {
|
||||
SmallVector<OpFoldResult> outputOffsets {rewriter.getIndexAttr(0), rewriter.getIndexAttr(0)};
|
||||
SmallVector<OpFoldResult> outputSizes {rewriter.getIndexAttr(outType.getDimSize(0)),
|
||||
rewriter.getIndexAttr(outType.getDimSize(1))};
|
||||
result =
|
||||
tensor::ExtractSliceOp::create(rewriter, loc, outType, paddedOutput, outputOffsets, outputSizes, unitStrides)
|
||||
.getResult();
|
||||
}
|
||||
spatial::SpatYieldOp::create(rewriter, loc, result);
|
||||
});
|
||||
Value result = paddedOutput;
|
||||
if (paddedOutType != outType) {
|
||||
SmallVector<OpFoldResult> outputOffsets {rewriter.getIndexAttr(0), rewriter.getIndexAttr(0)};
|
||||
SmallVector<OpFoldResult> outputSizes {rewriter.getIndexAttr(outType.getDimSize(0)),
|
||||
rewriter.getIndexAttr(outType.getDimSize(1))};
|
||||
result =
|
||||
tensor::ExtractSliceOp::create(rewriter, loc, outType, paddedOutput, outputOffsets, outputSizes, unitStrides)
|
||||
.getResult();
|
||||
}
|
||||
spatial::SpatYieldOp::create(rewriter, loc, result);
|
||||
return success();
|
||||
});
|
||||
|
||||
return computeOp;
|
||||
}
|
||||
@@ -755,9 +775,13 @@ LogicalResult GemmToSpatialComputes::matchAndRewrite(ONNXGemmOp gemmOp,
|
||||
auto scalarPiecesType = RankedTensorType::get({laneCount64, 1}, outType.getElementType());
|
||||
auto batchOp =
|
||||
createVvdmulBatch(a, b, aType, bType, scalarPiecesType, outType, gemmOpAdaptor.getTransB(), rewriter, loc);
|
||||
if (failed(batchOp))
|
||||
return failure();
|
||||
auto outputCompute = createDynamicGemmOutputCompute(
|
||||
batchOp.getResult(0), hasC ? c : Value(), scalarPiecesType, biasType, outType, alpha, beta, rewriter, loc);
|
||||
rewriter.replaceOp(gemmOp, outputCompute.getResults());
|
||||
batchOp->getResult(0), hasC ? c : Value(), scalarPiecesType, biasType, outType, alpha, beta, rewriter, loc);
|
||||
if (failed(outputCompute))
|
||||
return failure();
|
||||
rewriter.replaceOp(gemmOp, outputCompute->getResults());
|
||||
return success();
|
||||
}
|
||||
|
||||
@@ -832,10 +856,14 @@ LogicalResult GemmToSpatialComputes::matchAndRewrite(ONNXGemmOp gemmOp,
|
||||
RankedTensorType::get({laneCount64, static_cast<int64_t>(crossbarSize.getValue())}, outType.getElementType());
|
||||
auto batchOp =
|
||||
createVmmBatch(a, b, aType, paddedBType, partialPiecesType, numOutRows, numKSlices, numOutHSlices, rewriter, loc);
|
||||
if (failed(batchOp))
|
||||
return failure();
|
||||
auto reductionCompute = createReductionCompute(
|
||||
batchOp.getResult(0), bias, partialPiecesType, outType, paddedOutType, numKSlices, rewriter, loc);
|
||||
batchOp->getResult(0), bias, partialPiecesType, outType, paddedOutType, numKSlices, rewriter, loc);
|
||||
if (failed(reductionCompute))
|
||||
return failure();
|
||||
|
||||
rewriter.replaceOp(gemmOp, reductionCompute.getResults());
|
||||
rewriter.replaceOp(gemmOp, reductionCompute->getResults());
|
||||
return success();
|
||||
}
|
||||
|
||||
|
||||
@@ -8,6 +8,7 @@
|
||||
#include "llvm/ADT/SmallVector.h"
|
||||
|
||||
#include "src/Accelerators/PIM/Common/IR/AffineUtils.hpp"
|
||||
#include "src/Accelerators/PIM/Common/IR/LoopUtils.hpp"
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/Common.hpp"
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/CompileTime.hpp"
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Patterns.hpp"
|
||||
@@ -281,18 +282,18 @@ static Value getBatchLaneIndex(
|
||||
rewriter, loc, lane, numOutRows * numKSlices * numOutHSlices, rewriter.getInsertionBlock()->getParentOp());
|
||||
}
|
||||
|
||||
static spatial::SpatComputeBatch createBatchedVmmBatch(Value a,
|
||||
Value b,
|
||||
RankedTensorType aType,
|
||||
int64_t aBatchCount,
|
||||
RankedTensorType bType,
|
||||
int64_t bBatchCount,
|
||||
RankedTensorType partialPiecesType,
|
||||
int64_t numOutRows,
|
||||
int64_t numKSlices,
|
||||
int64_t numOutHSlices,
|
||||
PatternRewriter& rewriter,
|
||||
Location loc) {
|
||||
static FailureOr<spatial::SpatComputeBatch> createBatchedVmmBatch(Value a,
|
||||
Value b,
|
||||
RankedTensorType aType,
|
||||
int64_t aBatchCount,
|
||||
RankedTensorType bType,
|
||||
int64_t bBatchCount,
|
||||
RankedTensorType partialPiecesType,
|
||||
int64_t numOutRows,
|
||||
int64_t numKSlices,
|
||||
int64_t numOutHSlices,
|
||||
PatternRewriter& rewriter,
|
||||
Location loc) {
|
||||
const int64_t laneCount = partialPiecesType.getDimSize(0);
|
||||
auto batchOp = createSpatComputeBatch(
|
||||
rewriter,
|
||||
@@ -331,7 +332,8 @@ static spatial::SpatComputeBatch createBatchedVmmBatch(Value a,
|
||||
createParallelInsertSliceIntoBatchOutput(
|
||||
rewriter, loc, piece, args.outputs.front(), pieceOffsets, pieceSizes, getUnitStrides(rewriter, 2));
|
||||
});
|
||||
assert(succeeded(batchOp) && "expected batched MatMul VMM construction to succeed");
|
||||
if (failed(batchOp))
|
||||
return failure();
|
||||
return *batchOp;
|
||||
}
|
||||
|
||||
@@ -422,17 +424,17 @@ static Value extractDynamicBatchedRowVector(Value matrix,
|
||||
});
|
||||
}
|
||||
|
||||
static spatial::SpatComputeBatch createBatchedVvdmulBatch(Value a,
|
||||
int64_t aBatchCount,
|
||||
Value b,
|
||||
int64_t bBatchCount,
|
||||
RankedTensorType aType,
|
||||
RankedTensorType bType,
|
||||
RankedTensorType scalarPiecesType,
|
||||
RankedTensorType outType,
|
||||
bool bAlreadyTransposed,
|
||||
PatternRewriter& rewriter,
|
||||
Location loc) {
|
||||
static FailureOr<spatial::SpatComputeBatch> createBatchedVvdmulBatch(Value a,
|
||||
int64_t aBatchCount,
|
||||
Value b,
|
||||
int64_t bBatchCount,
|
||||
RankedTensorType aType,
|
||||
RankedTensorType bType,
|
||||
RankedTensorType scalarPiecesType,
|
||||
RankedTensorType outType,
|
||||
bool bAlreadyTransposed,
|
||||
PatternRewriter& rewriter,
|
||||
Location loc) {
|
||||
const int64_t numBatches = outType.getDimSize(0);
|
||||
const int64_t numOutRows = outType.getDimSize(1);
|
||||
const int64_t numOutCols = outType.getDimSize(2);
|
||||
@@ -466,64 +468,73 @@ static spatial::SpatComputeBatch createBatchedVvdmulBatch(Value a,
|
||||
createParallelInsertSliceIntoBatchOutput(
|
||||
rewriter, loc, scalar, args.outputs.front(), outputOffsets, scalarSizes, getUnitStrides(rewriter, 2));
|
||||
});
|
||||
assert(succeeded(batchOp) && "expected batched MatMul VVDMul construction to succeed");
|
||||
if (failed(batchOp))
|
||||
return failure();
|
||||
return *batchOp;
|
||||
}
|
||||
|
||||
static Value createBatchedDynamicOutputCompute(Value scalarPieces,
|
||||
RankedTensorType scalarPiecesType,
|
||||
RankedTensorType outType,
|
||||
PatternRewriter& rewriter,
|
||||
Location loc) {
|
||||
static FailureOr<Value> createBatchedDynamicOutputCompute(Value scalarPieces,
|
||||
RankedTensorType scalarPiecesType,
|
||||
RankedTensorType outType,
|
||||
PatternRewriter& rewriter,
|
||||
Location loc) {
|
||||
const int64_t laneCount = scalarPiecesType.getDimSize(0);
|
||||
const int64_t numOutRows = outType.getDimSize(1);
|
||||
const int64_t numOutCols = outType.getDimSize(2);
|
||||
auto scalarType = RankedTensorType::get({1, 1}, outType.getElementType());
|
||||
auto outputScalarType = RankedTensorType::get({1, 1, 1}, outType.getElementType());
|
||||
|
||||
auto computeOp =
|
||||
createSpatCompute<1>(rewriter, loc, TypeRange {outType}, {}, ValueRange {scalarPieces}, [&](Value pieces) {
|
||||
auto computeOp = createSpatCompute<1>(
|
||||
rewriter, loc, TypeRange {outType}, {}, ValueRange {scalarPieces}, [&](Value pieces) -> LogicalResult {
|
||||
Value outputInit =
|
||||
tensor::EmptyOp::create(rewriter, loc, outType.getShape(), outType.getElementType()).getResult();
|
||||
Value c0 = getOrCreateIndexConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), 0);
|
||||
Value c1 = getOrCreateIndexConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), 1);
|
||||
Value cLaneCount = getOrCreateIndexConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), laneCount);
|
||||
auto loop = scf::ForOp::create(rewriter, loc, c0, cLaneCount, c1, ValueRange {outputInit});
|
||||
rewriter.setInsertionPointToStart(loop.getBody());
|
||||
|
||||
Value lane = loop.getInductionVar();
|
||||
Value outputAcc = loop.getRegionIterArgs().front();
|
||||
Operation* anchorOp = rewriter.getInsertionBlock()->getParentOp();
|
||||
Value batch = affineFloorDivConst(rewriter, loc, lane, numOutRows * numOutCols, anchorOp);
|
||||
Value batchLane = affineModConst(rewriter, loc, lane, numOutRows * numOutCols, anchorOp);
|
||||
Value row = affineFloorDivConst(rewriter, loc, batchLane, numOutCols, anchorOp);
|
||||
Value column = affineModConst(rewriter, loc, batchLane, numOutCols, anchorOp);
|
||||
SmallVector<OpFoldResult> scalarOffsets {lane, rewriter.getIndexAttr(0)};
|
||||
SmallVector<OpFoldResult> scalarSizes {rewriter.getIndexAttr(1), rewriter.getIndexAttr(1)};
|
||||
Value scalar = tensor::ExtractSliceOp::create(
|
||||
rewriter, loc, scalarType, pieces, scalarOffsets, scalarSizes, getUnitStrides(rewriter, 2));
|
||||
Value expanded = tensor::ExpandShapeOp::create(rewriter,
|
||||
loc,
|
||||
outputScalarType,
|
||||
scalar,
|
||||
SmallVector<ReassociationIndices> {
|
||||
{0},
|
||||
{1, 2}
|
||||
});
|
||||
SmallVector<OpFoldResult> outputOffsets {batch, row, column};
|
||||
SmallVector<OpFoldResult> outputSizes {
|
||||
rewriter.getIndexAttr(1), rewriter.getIndexAttr(1), rewriter.getIndexAttr(1)};
|
||||
scf::YieldOp::create(
|
||||
auto loop = buildNormalizedScfFor(
|
||||
rewriter,
|
||||
loc,
|
||||
tensor::InsertSliceOp::create(
|
||||
rewriter, loc, expanded, outputAcc, outputOffsets, outputSizes, getUnitStrides(rewriter, 3))
|
||||
.getResult());
|
||||
|
||||
rewriter.setInsertionPointAfter(loop);
|
||||
spatial::SpatYieldOp::create(rewriter, loc, loop.getResult(0));
|
||||
c0,
|
||||
cLaneCount,
|
||||
c1,
|
||||
ValueRange {outputInit},
|
||||
[&](OpBuilder&, Location nestedLoc, Value lane, ValueRange iterArgs, SmallVectorImpl<Value>& yielded) {
|
||||
Value outputAcc = iterArgs.front();
|
||||
Operation* anchorOp = rewriter.getInsertionBlock()->getParentOp();
|
||||
Value batch = affineFloorDivConst(rewriter, nestedLoc, lane, numOutRows * numOutCols, anchorOp);
|
||||
Value batchLane = affineModConst(rewriter, nestedLoc, lane, numOutRows * numOutCols, anchorOp);
|
||||
Value row = affineFloorDivConst(rewriter, nestedLoc, batchLane, numOutCols, anchorOp);
|
||||
Value column = affineModConst(rewriter, nestedLoc, batchLane, numOutCols, anchorOp);
|
||||
SmallVector<OpFoldResult> scalarOffsets {lane, rewriter.getIndexAttr(0)};
|
||||
SmallVector<OpFoldResult> scalarSizes {rewriter.getIndexAttr(1), rewriter.getIndexAttr(1)};
|
||||
Value scalar = tensor::ExtractSliceOp::create(
|
||||
rewriter, nestedLoc, scalarType, pieces, scalarOffsets, scalarSizes, getUnitStrides(rewriter, 2));
|
||||
Value expanded = tensor::ExpandShapeOp::create(rewriter,
|
||||
nestedLoc,
|
||||
outputScalarType,
|
||||
scalar,
|
||||
SmallVector<ReassociationIndices> {
|
||||
{0},
|
||||
{1, 2}
|
||||
});
|
||||
SmallVector<OpFoldResult> outputOffsets {batch, row, column};
|
||||
SmallVector<OpFoldResult> outputSizes = {
|
||||
rewriter.getIndexAttr(1), rewriter.getIndexAttr(1), rewriter.getIndexAttr(1)};
|
||||
Value next =
|
||||
tensor::InsertSliceOp::create(
|
||||
rewriter, nestedLoc, expanded, outputAcc, outputOffsets, outputSizes, getUnitStrides(rewriter, 3))
|
||||
.getResult();
|
||||
yielded.push_back(next);
|
||||
return success();
|
||||
});
|
||||
if (failed(loop))
|
||||
return failure();
|
||||
spatial::SpatYieldOp::create(rewriter, loc, loop->results.front());
|
||||
return success();
|
||||
});
|
||||
return computeOp.getResult(0);
|
||||
if (failed(computeOp))
|
||||
return failure();
|
||||
return computeOp->getResult(0);
|
||||
}
|
||||
|
||||
static Value transposeBatchedOutput(Value value, RankedTensorType outputType, PatternRewriter& rewriter, Location loc) {
|
||||
@@ -587,16 +598,16 @@ static Value reduceBatchedPartialPiecesForHSlice(Value partialPiecesArg,
|
||||
return activePieces.front();
|
||||
}
|
||||
|
||||
static Value createBatchedReductionCompute(Value partialPieces,
|
||||
RankedTensorType partialPiecesType,
|
||||
RankedTensorType outType,
|
||||
RankedTensorType paddedOutType,
|
||||
int64_t numBatches,
|
||||
int64_t numKSlices,
|
||||
PatternRewriter& rewriter,
|
||||
Location loc) {
|
||||
static FailureOr<Value> createBatchedReductionCompute(Value partialPieces,
|
||||
RankedTensorType partialPiecesType,
|
||||
RankedTensorType outType,
|
||||
RankedTensorType paddedOutType,
|
||||
int64_t numBatches,
|
||||
int64_t numKSlices,
|
||||
PatternRewriter& rewriter,
|
||||
Location loc) {
|
||||
auto computeOp = createSpatCompute<1>(
|
||||
rewriter, loc, TypeRange {outType}, {}, ValueRange {partialPieces}, [&](Value partialPiecesArg) {
|
||||
rewriter, loc, TypeRange {outType}, {}, ValueRange {partialPieces}, [&](Value partialPiecesArg) -> LogicalResult {
|
||||
const int64_t numOutRows = outType.getDimSize(1);
|
||||
const int64_t numOutHSlices = ceilIntegerDivide(outType.getDimSize(2), crossbarSize.getValue());
|
||||
auto pieceType = RankedTensorType::get({numOutRows, static_cast<int64_t>(crossbarSize.getValue())},
|
||||
@@ -612,43 +623,55 @@ static Value createBatchedReductionCompute(Value partialPieces,
|
||||
Value cNumOutHSlices =
|
||||
getOrCreateIndexConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), numOutHSlices);
|
||||
|
||||
auto batchLoop = scf::ForOp::create(rewriter, loc, c0, cNumBatches, c1, ValueRange {outputInit});
|
||||
rewriter.setInsertionPointToStart(batchLoop.getBody());
|
||||
Value batch = batchLoop.getInductionVar();
|
||||
Value batchAcc = batchLoop.getRegionIterArgs().front();
|
||||
|
||||
auto hLoop = scf::ForOp::create(rewriter, loc, c0, cNumOutHSlices, c1, ValueRange {batchAcc});
|
||||
rewriter.setInsertionPointToStart(hLoop.getBody());
|
||||
Value hSlice = hLoop.getInductionVar();
|
||||
Value outputAcc = hLoop.getRegionIterArgs().front();
|
||||
|
||||
Value reduced = reduceBatchedPartialPiecesForHSlice(
|
||||
partialPiecesArg, batch, hSlice, pieceType, numKSlices, numOutHSlices, numOutRows, rewriter, loc);
|
||||
Value expandedReduced = tensor::ExpandShapeOp::create(rewriter,
|
||||
loc,
|
||||
outputSliceType,
|
||||
reduced,
|
||||
SmallVector<ReassociationIndices> {
|
||||
{0, 1},
|
||||
{2}
|
||||
});
|
||||
Value hOffset =
|
||||
affineMulConst(rewriter, loc, hSlice, crossbarSize.getValue(), rewriter.getInsertionBlock()->getParentOp());
|
||||
SmallVector<OpFoldResult> outputOffsets {batch, rewriter.getIndexAttr(0), hOffset};
|
||||
SmallVector<OpFoldResult> outputSizes {
|
||||
rewriter.getIndexAttr(1), rewriter.getIndexAttr(numOutRows), rewriter.getIndexAttr(crossbarSize.getValue())};
|
||||
scf::YieldOp::create(
|
||||
auto batchLoop = buildNormalizedScfFor(
|
||||
rewriter,
|
||||
loc,
|
||||
tensor::InsertSliceOp::create(
|
||||
rewriter, loc, expandedReduced, outputAcc, outputOffsets, outputSizes, getUnitStrides(rewriter, 3))
|
||||
.getResult());
|
||||
|
||||
rewriter.setInsertionPointAfter(hLoop);
|
||||
scf::YieldOp::create(rewriter, loc, hLoop.getResult(0));
|
||||
|
||||
rewriter.setInsertionPointAfter(batchLoop);
|
||||
Value paddedOutput = batchLoop.getResult(0);
|
||||
c0,
|
||||
cNumBatches,
|
||||
c1,
|
||||
ValueRange {outputInit},
|
||||
[&](
|
||||
OpBuilder&, Location batchLoc, Value batch, ValueRange batchIterArgs, SmallVectorImpl<Value>& batchYielded) {
|
||||
auto hLoop = buildNormalizedScfFor(
|
||||
rewriter,
|
||||
batchLoc,
|
||||
c0,
|
||||
cNumOutHSlices,
|
||||
c1,
|
||||
ValueRange {batchIterArgs.front()},
|
||||
[&](OpBuilder&, Location hLoc, Value hSlice, ValueRange hIterArgs, SmallVectorImpl<Value>& hYielded) {
|
||||
Value outputAcc = hIterArgs.front();
|
||||
Value reduced = reduceBatchedPartialPiecesForHSlice(
|
||||
partialPiecesArg, batch, hSlice, pieceType, numKSlices, numOutHSlices, numOutRows, rewriter, hLoc);
|
||||
Value expandedReduced = tensor::ExpandShapeOp::create(rewriter,
|
||||
hLoc,
|
||||
outputSliceType,
|
||||
reduced,
|
||||
SmallVector<ReassociationIndices> {
|
||||
{0, 1},
|
||||
{2}
|
||||
});
|
||||
Value hOffset = affineMulConst(
|
||||
rewriter, hLoc, hSlice, crossbarSize.getValue(), rewriter.getInsertionBlock()->getParentOp());
|
||||
SmallVector<OpFoldResult> outputOffsets {batch, rewriter.getIndexAttr(0), hOffset};
|
||||
SmallVector<OpFoldResult> outputSizes {rewriter.getIndexAttr(1),
|
||||
rewriter.getIndexAttr(numOutRows),
|
||||
rewriter.getIndexAttr(crossbarSize.getValue())};
|
||||
Value next =
|
||||
tensor::InsertSliceOp::create(
|
||||
rewriter, hLoc, expandedReduced, outputAcc, outputOffsets, outputSizes, getUnitStrides(rewriter, 3))
|
||||
.getResult();
|
||||
hYielded.push_back(next);
|
||||
return success();
|
||||
});
|
||||
if (failed(hLoop))
|
||||
return failure();
|
||||
batchYielded.push_back(hLoop->results.front());
|
||||
return success();
|
||||
});
|
||||
if (failed(batchLoop))
|
||||
return failure();
|
||||
Value paddedOutput = batchLoop->results.front();
|
||||
Value result = paddedOutput;
|
||||
if (paddedOutType != outType) {
|
||||
SmallVector<OpFoldResult> outputOffsets {
|
||||
@@ -660,8 +683,11 @@ static Value createBatchedReductionCompute(Value partialPieces,
|
||||
rewriter, loc, outType, paddedOutput, outputOffsets, outputSizes, getUnitStrides(rewriter, 3));
|
||||
}
|
||||
spatial::SpatYieldOp::create(rewriter, loc, result);
|
||||
return success();
|
||||
});
|
||||
return computeOp.getResult(0);
|
||||
if (failed(computeOp))
|
||||
return failure();
|
||||
return computeOp->getResult(0);
|
||||
}
|
||||
|
||||
struct MatMulShapeInfo {
|
||||
@@ -841,22 +867,27 @@ struct MatMulBatchedToSpatialComputes : OpRewritePattern<ONNXMatMulOp> {
|
||||
numOutHSlices,
|
||||
rewriter,
|
||||
loc);
|
||||
Value result = createBatchedReductionCompute(batchOp.getResult(0),
|
||||
partialPiecesType,
|
||||
directOutType,
|
||||
paddedOutType,
|
||||
shapeInfo->batch,
|
||||
numKSlices,
|
||||
rewriter,
|
||||
loc);
|
||||
if (failed(batchOp))
|
||||
return failure();
|
||||
auto result = createBatchedReductionCompute(batchOp->getResult(0),
|
||||
partialPiecesType,
|
||||
directOutType,
|
||||
paddedOutType,
|
||||
shapeInfo->batch,
|
||||
numKSlices,
|
||||
rewriter,
|
||||
loc);
|
||||
if (failed(result))
|
||||
return failure();
|
||||
Value finalResult = *result;
|
||||
if (useTransposedForm)
|
||||
result = transposeBatchedOutput(
|
||||
result,
|
||||
finalResult = transposeBatchedOutput(
|
||||
finalResult,
|
||||
RankedTensorType::get({shapeInfo->batch, shapeInfo->m, shapeInfo->n}, shapeInfo->outType.getElementType()),
|
||||
rewriter,
|
||||
loc);
|
||||
result = expandBatchDims(result, shapeInfo->outType, shapeInfo->batchShape.size(), rewriter, loc);
|
||||
rewriter.replaceOp(matmulOp, result);
|
||||
finalResult = expandBatchDims(finalResult, shapeInfo->outType, shapeInfo->batchShape.size(), rewriter, loc);
|
||||
rewriter.replaceOp(matmulOp, finalResult);
|
||||
return success();
|
||||
}
|
||||
}
|
||||
@@ -873,16 +904,21 @@ struct MatMulBatchedToSpatialComputes : OpRewritePattern<ONNXMatMulOp> {
|
||||
false,
|
||||
rewriter,
|
||||
loc);
|
||||
Value result =
|
||||
createBatchedDynamicOutputCompute(batchOp.getResult(0), scalarPiecesType, directOutType, rewriter, loc);
|
||||
if (failed(batchOp))
|
||||
return failure();
|
||||
auto result =
|
||||
createBatchedDynamicOutputCompute(batchOp->getResult(0), scalarPiecesType, directOutType, rewriter, loc);
|
||||
if (failed(result))
|
||||
return failure();
|
||||
Value finalResult = *result;
|
||||
if (useTransposedForm)
|
||||
result = transposeBatchedOutput(
|
||||
result,
|
||||
finalResult = transposeBatchedOutput(
|
||||
finalResult,
|
||||
RankedTensorType::get({shapeInfo->batch, shapeInfo->m, shapeInfo->n}, shapeInfo->outType.getElementType()),
|
||||
rewriter,
|
||||
loc);
|
||||
result = expandBatchDims(result, shapeInfo->outType, shapeInfo->batchShape.size(), rewriter, loc);
|
||||
rewriter.replaceOp(matmulOp, result);
|
||||
finalResult = expandBatchDims(finalResult, shapeInfo->outType, shapeInfo->batchShape.size(), rewriter, loc);
|
||||
rewriter.replaceOp(matmulOp, finalResult);
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
@@ -12,6 +12,7 @@
|
||||
#include <optional>
|
||||
#include <type_traits>
|
||||
|
||||
#include "src/Accelerators/PIM/Common/IR/LoopUtils.hpp"
|
||||
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
|
||||
#include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp"
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/Common.hpp"
|
||||
@@ -275,86 +276,102 @@ struct PoolToSpatialComputeBase : public OpConversionPattern<PoolOp> {
|
||||
Value cStrideHeight = getOrCreateIndexConstant(rewriter, anchorOp, strideHeight);
|
||||
Value cStrideWidth = getOrCreateIndexConstant(rewriter, anchorOp, strideWidth);
|
||||
|
||||
auto outputLoop = scf::ForOp::create(rewriter, loc, c0, cOutputPatchCount, c1, ValueRange {pooledOutputInit});
|
||||
rewriter.setInsertionPointToStart(outputLoop.getBody());
|
||||
auto outputLoop = buildNormalizedScfFor(
|
||||
rewriter,
|
||||
loc,
|
||||
c0,
|
||||
cOutputPatchCount,
|
||||
c1,
|
||||
ValueRange {pooledOutputInit},
|
||||
[&](OpBuilder&,
|
||||
Location nestedLoc,
|
||||
Value outputPatchIndex,
|
||||
ValueRange iterArgs,
|
||||
SmallVectorImpl<Value>& yielded) {
|
||||
Value pooledOutputAcc = iterArgs.front();
|
||||
Value batchIndex = arith::DivUIOp::create(rewriter, nestedLoc, outputPatchIndex, cOutputPixelsPerBatch);
|
||||
Value batchPatchIndex =
|
||||
arith::RemUIOp::create(rewriter, nestedLoc, outputPatchIndex, cOutputPixelsPerBatch);
|
||||
Value outHeightIndex = arith::DivUIOp::create(rewriter, nestedLoc, batchPatchIndex, cOutputWidth);
|
||||
Value outWidthIndex = arith::RemUIOp::create(rewriter, nestedLoc, batchPatchIndex, cOutputWidth);
|
||||
Value windowBaseH = arith::MulIOp::create(rewriter, nestedLoc, outHeightIndex, cStrideHeight);
|
||||
Value windowBaseW = arith::MulIOp::create(rewriter, nestedLoc, outWidthIndex, cStrideWidth);
|
||||
|
||||
Value outputPatchIndex = outputLoop.getInductionVar();
|
||||
Value pooledOutputAcc = outputLoop.getRegionIterArgs().front();
|
||||
Value updatedOutput = pooledOutputAcc;
|
||||
for (int64_t channelTile = 0; channelTile < channelTileCount; ++channelTile) {
|
||||
const int64_t tileChannels = std::min<int64_t>(xbarSize, channels - channelTile * xbarSize);
|
||||
auto tileType = RankedTensorType::get({1, tileChannels, 1, 1}, outType.getElementType());
|
||||
Value reducedWindow =
|
||||
createPoolFillTensor(rewriter, nestedLoc, tileType, std::is_same_v<PoolOp, ONNXMaxPoolSingleOutOp>);
|
||||
|
||||
Value batchIndex = arith::DivUIOp::create(rewriter, loc, outputPatchIndex, cOutputPixelsPerBatch);
|
||||
Value batchPatchIndex = arith::RemUIOp::create(rewriter, loc, outputPatchIndex, cOutputPixelsPerBatch);
|
||||
Value outHeightIndex = arith::DivUIOp::create(rewriter, loc, batchPatchIndex, cOutputWidth);
|
||||
Value outWidthIndex = arith::RemUIOp::create(rewriter, loc, batchPatchIndex, cOutputWidth);
|
||||
Value windowBaseH = arith::MulIOp::create(rewriter, loc, outHeightIndex, cStrideHeight);
|
||||
Value windowBaseW = arith::MulIOp::create(rewriter, loc, outWidthIndex, cStrideWidth);
|
||||
for (int64_t kernelH = 0; kernelH < kernelHeight; ++kernelH) {
|
||||
Value paddedInH = windowBaseH;
|
||||
if (kernelH * dilationHeight != 0) {
|
||||
Value kernelHOffset = getOrCreateIndexConstant(rewriter, anchorOp, kernelH * dilationHeight);
|
||||
paddedInH = arith::AddIOp::create(rewriter, nestedLoc, paddedInH, kernelHOffset);
|
||||
}
|
||||
|
||||
Value updatedOutput = pooledOutputAcc;
|
||||
for (int64_t channelTile = 0; channelTile < channelTileCount; ++channelTile) {
|
||||
const int64_t tileChannels = std::min<int64_t>(xbarSize, channels - channelTile * xbarSize);
|
||||
auto tileType = RankedTensorType::get({1, tileChannels, 1, 1}, outType.getElementType());
|
||||
Value reducedWindow =
|
||||
createPoolFillTensor(rewriter, loc, tileType, std::is_same_v<PoolOp, ONNXMaxPoolSingleOutOp>);
|
||||
for (int64_t kernelW = 0; kernelW < kernelWidth; ++kernelW) {
|
||||
Value paddedInW = windowBaseW;
|
||||
if (kernelW * dilationWidth != 0) {
|
||||
Value kernelWOffset = getOrCreateIndexConstant(rewriter, anchorOp, kernelW * dilationWidth);
|
||||
paddedInW = arith::AddIOp::create(rewriter, nestedLoc, paddedInW, kernelWOffset);
|
||||
}
|
||||
|
||||
for (int64_t kernelH = 0; kernelH < kernelHeight; ++kernelH) {
|
||||
Value paddedInH = windowBaseH;
|
||||
if (kernelH * dilationHeight != 0) {
|
||||
Value kernelHOffset = getOrCreateIndexConstant(rewriter, anchorOp, kernelH * dilationHeight);
|
||||
paddedInH = arith::AddIOp::create(rewriter, loc, paddedInH, kernelHOffset);
|
||||
}
|
||||
|
||||
for (int64_t kernelW = 0; kernelW < kernelWidth; ++kernelW) {
|
||||
Value paddedInW = windowBaseW;
|
||||
if (kernelW * dilationWidth != 0) {
|
||||
Value kernelWOffset = getOrCreateIndexConstant(rewriter, anchorOp, kernelW * dilationWidth);
|
||||
paddedInW = arith::AddIOp::create(rewriter, loc, paddedInW, kernelWOffset);
|
||||
SmallVector<OpFoldResult> offsets = {
|
||||
batchIndex, rewriter.getIndexAttr(channelTile * xbarSize), paddedInH, paddedInW};
|
||||
SmallVector<OpFoldResult> sizes = {rewriter.getIndexAttr(1),
|
||||
rewriter.getIndexAttr(tileChannels),
|
||||
rewriter.getIndexAttr(1),
|
||||
rewriter.getIndexAttr(1)};
|
||||
SmallVector<OpFoldResult> strides = {rewriter.getIndexAttr(1),
|
||||
rewriter.getIndexAttr(1),
|
||||
rewriter.getIndexAttr(1),
|
||||
rewriter.getIndexAttr(1)};
|
||||
Value windowValue =
|
||||
tensor::ExtractSliceOp::create(rewriter, nestedLoc, tileType, paddedInput, offsets, sizes, strides);
|
||||
windowValue = materializeTileTensor(rewriter, nestedLoc, windowValue);
|
||||
reducedWindow = ReduceOp::create(rewriter, nestedLoc, tileType, reducedWindow, windowValue);
|
||||
}
|
||||
}
|
||||
|
||||
SmallVector<OpFoldResult> offsets = {
|
||||
batchIndex, rewriter.getIndexAttr(channelTile * xbarSize), paddedInH, paddedInW};
|
||||
SmallVector<OpFoldResult> sizes = {rewriter.getIndexAttr(1),
|
||||
rewriter.getIndexAttr(tileChannels),
|
||||
rewriter.getIndexAttr(1),
|
||||
rewriter.getIndexAttr(1)};
|
||||
SmallVector<OpFoldResult> strides = {
|
||||
if constexpr (std::is_same_v<PoolOp, ONNXAveragePoolOp>) {
|
||||
SmallVector<OpFoldResult> scaleOffsets = {rewriter.getIndexAttr(0),
|
||||
rewriter.getIndexAttr(channelTile * xbarSize),
|
||||
outHeightIndex,
|
||||
outWidthIndex};
|
||||
SmallVector<OpFoldResult> scaleSizes = {rewriter.getIndexAttr(1),
|
||||
rewriter.getIndexAttr(tileChannels),
|
||||
rewriter.getIndexAttr(1),
|
||||
rewriter.getIndexAttr(1)};
|
||||
SmallVector<OpFoldResult> scaleStrides = {rewriter.getIndexAttr(1),
|
||||
rewriter.getIndexAttr(1),
|
||||
rewriter.getIndexAttr(1),
|
||||
rewriter.getIndexAttr(1)};
|
||||
Value scaleSlice = tensor::ExtractSliceOp::create(
|
||||
rewriter, nestedLoc, tileType, averageScaleTensor, scaleOffsets, scaleSizes, scaleStrides);
|
||||
scaleSlice = materializeTileTensor(rewriter, nestedLoc, scaleSlice);
|
||||
reducedWindow = spatial::SpatVMulOp::create(rewriter, nestedLoc, tileType, reducedWindow, scaleSlice);
|
||||
}
|
||||
|
||||
SmallVector<OpFoldResult> outputOffsets = {
|
||||
batchIndex, rewriter.getIndexAttr(channelTile * xbarSize), outHeightIndex, outWidthIndex};
|
||||
SmallVector<OpFoldResult> outputSizes = {rewriter.getIndexAttr(1),
|
||||
rewriter.getIndexAttr(tileChannels),
|
||||
rewriter.getIndexAttr(1),
|
||||
rewriter.getIndexAttr(1)};
|
||||
SmallVector<OpFoldResult> outputStrides = {
|
||||
rewriter.getIndexAttr(1), rewriter.getIndexAttr(1), rewriter.getIndexAttr(1), rewriter.getIndexAttr(1)};
|
||||
Value windowValue =
|
||||
tensor::ExtractSliceOp::create(rewriter, loc, tileType, paddedInput, offsets, sizes, strides);
|
||||
windowValue = materializeTileTensor(rewriter, loc, windowValue);
|
||||
reducedWindow = ReduceOp::create(rewriter, loc, tileType, reducedWindow, windowValue);
|
||||
updatedOutput = tensor::InsertSliceOp::create(
|
||||
rewriter, nestedLoc, reducedWindow, updatedOutput, outputOffsets, outputSizes, outputStrides);
|
||||
}
|
||||
}
|
||||
yielded.push_back(updatedOutput);
|
||||
return success();
|
||||
});
|
||||
if (failed(outputLoop))
|
||||
return failure();
|
||||
|
||||
if constexpr (std::is_same_v<PoolOp, ONNXAveragePoolOp>) {
|
||||
SmallVector<OpFoldResult> scaleOffsets = {
|
||||
rewriter.getIndexAttr(0), rewriter.getIndexAttr(channelTile * xbarSize), outHeightIndex, outWidthIndex};
|
||||
SmallVector<OpFoldResult> scaleSizes = {rewriter.getIndexAttr(1),
|
||||
rewriter.getIndexAttr(tileChannels),
|
||||
rewriter.getIndexAttr(1),
|
||||
rewriter.getIndexAttr(1)};
|
||||
SmallVector<OpFoldResult> scaleStrides = {
|
||||
rewriter.getIndexAttr(1), rewriter.getIndexAttr(1), rewriter.getIndexAttr(1), rewriter.getIndexAttr(1)};
|
||||
Value scaleSlice = tensor::ExtractSliceOp::create(
|
||||
rewriter, loc, tileType, averageScaleTensor, scaleOffsets, scaleSizes, scaleStrides);
|
||||
scaleSlice = materializeTileTensor(rewriter, loc, scaleSlice);
|
||||
reducedWindow = spatial::SpatVMulOp::create(rewriter, loc, tileType, reducedWindow, scaleSlice);
|
||||
}
|
||||
|
||||
SmallVector<OpFoldResult> outputOffsets = {
|
||||
batchIndex, rewriter.getIndexAttr(channelTile * xbarSize), outHeightIndex, outWidthIndex};
|
||||
SmallVector<OpFoldResult> outputSizes = {rewriter.getIndexAttr(1),
|
||||
rewriter.getIndexAttr(tileChannels),
|
||||
rewriter.getIndexAttr(1),
|
||||
rewriter.getIndexAttr(1)};
|
||||
SmallVector<OpFoldResult> outputStrides = {
|
||||
rewriter.getIndexAttr(1), rewriter.getIndexAttr(1), rewriter.getIndexAttr(1), rewriter.getIndexAttr(1)};
|
||||
updatedOutput = tensor::InsertSliceOp::create(
|
||||
rewriter, loc, reducedWindow, updatedOutput, outputOffsets, outputSizes, outputStrides);
|
||||
}
|
||||
|
||||
scf::YieldOp::create(rewriter, loc, updatedOutput);
|
||||
|
||||
rewriter.setInsertionPointAfter(outputLoop);
|
||||
spatial::SpatYieldOp::create(rewriter, loc, outputLoop.getResult(0));
|
||||
spatial::SpatYieldOp::create(rewriter, loc, outputLoop->results.front());
|
||||
return success();
|
||||
});
|
||||
if (failed(computeOp))
|
||||
|
||||
@@ -3,6 +3,7 @@
|
||||
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
||||
#include "mlir/Transforms/DialectConversion.h"
|
||||
|
||||
#include "src/Accelerators/PIM/Common/IR/LoopUtils.hpp"
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/Common.hpp"
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Patterns.hpp"
|
||||
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
||||
@@ -42,13 +43,13 @@ static Value buildLoopSoftmaxSlice(Value input,
|
||||
return tensor::InsertSliceOp::create(rewriter, loc, softmaxSlice, accumulator, offsets, sizes, strides);
|
||||
}
|
||||
|
||||
static Value buildLoopSoftmaxNest(Value input,
|
||||
Value accumulator,
|
||||
RankedTensorType inputType,
|
||||
int64_t axis,
|
||||
SmallVectorImpl<Value>& outerIndices,
|
||||
ConversionPatternRewriter& rewriter,
|
||||
Location loc) {
|
||||
static FailureOr<Value> buildLoopSoftmaxNest(Value input,
|
||||
Value accumulator,
|
||||
RankedTensorType inputType,
|
||||
int64_t axis,
|
||||
SmallVectorImpl<Value>& outerIndices,
|
||||
ConversionPatternRewriter& rewriter,
|
||||
Location loc) {
|
||||
if (axis == inputType.getRank() - 1)
|
||||
return buildLoopSoftmaxSlice(input, accumulator, inputType, outerIndices, rewriter, loc);
|
||||
|
||||
@@ -57,38 +58,50 @@ static Value buildLoopSoftmaxNest(Value input,
|
||||
Value c1 = getOrCreateIndexConstant(rewriter, anchorOp, 1);
|
||||
Value cUpper = getOrCreateIndexConstant(rewriter, anchorOp, inputType.getDimSize(axis));
|
||||
|
||||
auto loop = scf::ForOp::create(rewriter, loc, c0, cUpper, c1, ValueRange {accumulator});
|
||||
rewriter.setInsertionPointToStart(loop.getBody());
|
||||
|
||||
Value loopIndex = loop.getInductionVar();
|
||||
Value loopAccumulator = loop.getRegionIterArgs().front();
|
||||
outerIndices.push_back(loopIndex);
|
||||
Value updatedAccumulator =
|
||||
buildLoopSoftmaxNest(input, loopAccumulator, inputType, axis + 1, outerIndices, rewriter, loc);
|
||||
outerIndices.pop_back();
|
||||
|
||||
scf::YieldOp::create(rewriter, loc, updatedAccumulator);
|
||||
rewriter.setInsertionPointAfter(loop);
|
||||
return loop.getResult(0);
|
||||
auto loop = buildNormalizedScfFor(
|
||||
rewriter,
|
||||
loc,
|
||||
c0,
|
||||
cUpper,
|
||||
c1,
|
||||
ValueRange {accumulator},
|
||||
[&](OpBuilder& builder, Location nestedLoc, Value loopIndex, ValueRange iterArgs, SmallVectorImpl<Value>& yielded) {
|
||||
outerIndices.push_back(loopIndex);
|
||||
auto updatedAccumulator =
|
||||
buildLoopSoftmaxNest(input, iterArgs.front(), inputType, axis + 1, outerIndices, rewriter, nestedLoc);
|
||||
outerIndices.pop_back();
|
||||
if (failed(updatedAccumulator))
|
||||
return failure();
|
||||
yielded.push_back(*updatedAccumulator);
|
||||
return success();
|
||||
});
|
||||
if (failed(loop))
|
||||
return failure();
|
||||
return loop->results.front();
|
||||
}
|
||||
|
||||
static Value createLoopSoftmaxCompute(Value input, ConversionPatternRewriter& rewriter, Location loc) {
|
||||
static FailureOr<Value> createLoopSoftmaxCompute(Value input, ConversionPatternRewriter& rewriter, Location loc) {
|
||||
auto inputType = cast<RankedTensorType>(input.getType());
|
||||
constexpr size_t numInputs = 1;
|
||||
auto computeOp =
|
||||
createSpatCompute<numInputs>(rewriter, loc, TypeRange {inputType}, {}, ValueRange {input}, [&](Value x) {
|
||||
auto computeOp = createSpatCompute<numInputs>(
|
||||
rewriter, loc, TypeRange {inputType}, {}, ValueRange {input}, [&](Value x) -> LogicalResult {
|
||||
if (inputType.getRank() == 1) {
|
||||
Value softmax = spatial::SpatSoftmaxOp::create(rewriter, loc, inputType, x).getResult();
|
||||
spatial::SpatYieldOp::create(rewriter, loc, softmax);
|
||||
return;
|
||||
return success();
|
||||
}
|
||||
|
||||
Value outputInit = tensor::EmptyOp::create(rewriter, loc, inputType.getShape(), inputType.getElementType());
|
||||
SmallVector<Value> outerIndices;
|
||||
Value result = buildLoopSoftmaxNest(x, outputInit, inputType, /*axis=*/0, outerIndices, rewriter, loc);
|
||||
spatial::SpatYieldOp::create(rewriter, loc, result);
|
||||
auto result = buildLoopSoftmaxNest(x, outputInit, inputType, /*axis=*/0, outerIndices, rewriter, loc);
|
||||
if (failed(result))
|
||||
return failure();
|
||||
spatial::SpatYieldOp::create(rewriter, loc, *result);
|
||||
return success();
|
||||
});
|
||||
return computeOp.getResult(0);
|
||||
if (failed(computeOp))
|
||||
return failure();
|
||||
return computeOp->getResult(0);
|
||||
}
|
||||
|
||||
struct SoftmaxToSpatialCompute : OpConversionPattern<ONNXSoftmaxOp> {
|
||||
@@ -108,7 +121,10 @@ struct SoftmaxToSpatialCompute : OpConversionPattern<ONNXSoftmaxOp> {
|
||||
Value input = adaptor.getInput();
|
||||
Value result;
|
||||
if (*axis == inputType.getRank() - 1) {
|
||||
result = createLoopSoftmaxCompute(input, rewriter, softmaxOp.getLoc());
|
||||
auto computed = createLoopSoftmaxCompute(input, rewriter, softmaxOp.getLoc());
|
||||
if (failed(computed))
|
||||
return failure();
|
||||
result = *computed;
|
||||
}
|
||||
else {
|
||||
SmallVector<int64_t> permutation;
|
||||
@@ -122,8 +138,10 @@ struct SoftmaxToSpatialCompute : OpConversionPattern<ONNXSoftmaxOp> {
|
||||
auto transposedType = RankedTensorType::get(
|
||||
permuteShape(inputType.getShape(), permutation), inputType.getElementType(), inputType.getEncoding());
|
||||
Value transposedInput = transposeMaybeInCompute(input, transposedType, permutation, rewriter, softmaxOp.getLoc());
|
||||
Value transposedResult = createLoopSoftmaxCompute(transposedInput, rewriter, softmaxOp.getLoc());
|
||||
result = transposeMaybeInCompute(transposedResult, inputType, inversePermutation, rewriter, softmaxOp.getLoc());
|
||||
auto transposedResult = createLoopSoftmaxCompute(transposedInput, rewriter, softmaxOp.getLoc());
|
||||
if (failed(transposedResult))
|
||||
return failure();
|
||||
result = transposeMaybeInCompute(*transposedResult, inputType, inversePermutation, rewriter, softmaxOp.getLoc());
|
||||
}
|
||||
|
||||
rewriter.replaceOp(softmaxOp, result);
|
||||
|
||||
@@ -9,6 +9,7 @@
|
||||
#include "llvm/ADT/SmallVector.h"
|
||||
|
||||
#include "src/Accelerators/PIM/Common/IR/WeightUtils.hpp"
|
||||
#include "src/Accelerators/PIM/Common/Support/CheckedArithmetic.hpp"
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/WeightMaterialization.hpp"
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Patterns.hpp"
|
||||
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
||||
@@ -192,13 +193,12 @@ struct PromoteWeightLikeComputeBatchInputsPattern : OpRewritePattern<spatial::Sp
|
||||
|
||||
rewriter.setInsertionPointAfter(compute);
|
||||
|
||||
auto newCompute =
|
||||
spatial::SpatComputeBatch::create(rewriter,
|
||||
compute.getLoc(),
|
||||
compute.getResultTypes(),
|
||||
rewriter.getI32IntegerAttr(static_cast<int32_t>(compute.getLaneCount())),
|
||||
promoted->newWeights,
|
||||
promoted->newInputs);
|
||||
auto laneCountAttr = pim::getCheckedI32Attr(
|
||||
rewriter, compute, static_cast<uint64_t>(compute.getLaneCount()), "promoted compute_batch lane count");
|
||||
if (failed(laneCountAttr))
|
||||
return failure();
|
||||
auto newCompute = spatial::SpatComputeBatch::create(
|
||||
rewriter, compute.getLoc(), compute.getResultTypes(), *laneCountAttr, promoted->newWeights, promoted->newInputs);
|
||||
auto laneArg = compute.getLaneArgument();
|
||||
if (!laneArg)
|
||||
return rewriter.notifyMatchFailure(compute, "missing compute_batch lane block argument");
|
||||
|
||||
@@ -5,6 +5,7 @@
|
||||
|
||||
#include "llvm/ADT/STLExtras.h"
|
||||
|
||||
#include "src/Accelerators/PIM/Common/IR/LoopUtils.hpp"
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/Common.hpp"
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Patterns.hpp"
|
||||
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
||||
@@ -26,11 +27,11 @@ static Value buildNearestAsymmetricIndex(
|
||||
return arith::MinUIOp::create(rewriter, loc, inputIndex, cInputDimLast);
|
||||
}
|
||||
|
||||
static Value buildNearestResizeLoop(Value input,
|
||||
RankedTensorType inputType,
|
||||
RankedTensorType resultType,
|
||||
ConversionPatternRewriter& rewriter,
|
||||
Location loc) {
|
||||
static FailureOr<Value> buildNearestResizeLoop(Value input,
|
||||
RankedTensorType inputType,
|
||||
RankedTensorType resultType,
|
||||
ConversionPatternRewriter& rewriter,
|
||||
Location loc) {
|
||||
auto elemType = resultType.getElementType();
|
||||
SmallVector<int64_t> unitShape(resultType.getRank(), 1);
|
||||
auto unitTensorType = RankedTensorType::get(unitShape, elemType);
|
||||
@@ -48,54 +49,94 @@ static Value buildNearestResizeLoop(Value input,
|
||||
|
||||
Value outputInit = tensor::EmptyOp::create(rewriter, loc, resultType.getShape(), elemType);
|
||||
|
||||
auto batchLoop = scf::ForOp::create(rewriter, loc, c0, cOutputN, c1, ValueRange {outputInit});
|
||||
rewriter.setInsertionPointToStart(batchLoop.getBody());
|
||||
auto batchLoop = buildNormalizedScfFor(
|
||||
rewriter,
|
||||
loc,
|
||||
c0,
|
||||
cOutputN,
|
||||
c1,
|
||||
ValueRange {outputInit},
|
||||
[&](OpBuilder&, Location nestedLoc, Value outputN, ValueRange batchIterArgs, SmallVectorImpl<Value>& batchYielded) {
|
||||
Value outputBatchAcc = batchIterArgs.front();
|
||||
Value inputN =
|
||||
buildNearestAsymmetricIndex(outputN, inputType.getDimSize(0), resultType.getDimSize(0), rewriter, nestedLoc);
|
||||
|
||||
Value outputN = batchLoop.getInductionVar();
|
||||
Value outputBatchAcc = batchLoop.getRegionIterArgs().front();
|
||||
Value inputN = buildNearestAsymmetricIndex(outputN, inputType.getDimSize(0), resultType.getDimSize(0), rewriter, loc);
|
||||
auto channelLoop = buildNormalizedScfFor(
|
||||
rewriter,
|
||||
nestedLoc,
|
||||
c0,
|
||||
cOutputC,
|
||||
c1,
|
||||
ValueRange {outputBatchAcc},
|
||||
[&](OpBuilder&,
|
||||
Location channelLoc,
|
||||
Value outputC,
|
||||
ValueRange channelIterArgs,
|
||||
SmallVectorImpl<Value>& channelYielded) {
|
||||
Value outputChannelAcc = channelIterArgs.front();
|
||||
Value inputC = buildNearestAsymmetricIndex(
|
||||
outputC, inputType.getDimSize(1), resultType.getDimSize(1), rewriter, channelLoc);
|
||||
|
||||
auto channelLoop = scf::ForOp::create(rewriter, loc, c0, cOutputC, c1, ValueRange {outputBatchAcc});
|
||||
rewriter.setInsertionPointToStart(channelLoop.getBody());
|
||||
auto heightLoop = buildNormalizedScfFor(
|
||||
rewriter,
|
||||
channelLoc,
|
||||
c0,
|
||||
cOutputH,
|
||||
c1,
|
||||
ValueRange {outputChannelAcc},
|
||||
[&](OpBuilder&,
|
||||
Location heightLoc,
|
||||
Value outputH,
|
||||
ValueRange heightIterArgs,
|
||||
SmallVectorImpl<Value>& heightYielded) {
|
||||
Value outputHeightAcc = heightIterArgs.front();
|
||||
Value inputH = buildNearestAsymmetricIndex(
|
||||
outputH, inputType.getDimSize(2), resultType.getDimSize(2), rewriter, heightLoc);
|
||||
|
||||
Value outputC = channelLoop.getInductionVar();
|
||||
Value outputChannelAcc = channelLoop.getRegionIterArgs().front();
|
||||
Value inputC = buildNearestAsymmetricIndex(outputC, inputType.getDimSize(1), resultType.getDimSize(1), rewriter, loc);
|
||||
auto widthLoop = buildNormalizedScfFor(
|
||||
rewriter,
|
||||
heightLoc,
|
||||
c0,
|
||||
cOutputW,
|
||||
c1,
|
||||
ValueRange {outputHeightAcc},
|
||||
[&](OpBuilder&,
|
||||
Location widthLoc,
|
||||
Value outputW,
|
||||
ValueRange widthIterArgs,
|
||||
SmallVectorImpl<Value>& widthYielded) {
|
||||
Value outputWidthAcc = widthIterArgs.front();
|
||||
Value inputW = buildNearestAsymmetricIndex(
|
||||
outputW, inputType.getDimSize(3), resultType.getDimSize(3), rewriter, widthLoc);
|
||||
|
||||
auto heightLoop = scf::ForOp::create(rewriter, loc, c0, cOutputH, c1, ValueRange {outputChannelAcc});
|
||||
rewriter.setInsertionPointToStart(heightLoop.getBody());
|
||||
SmallVector<OpFoldResult> inputOffsets = {inputN, inputC, inputH, inputW};
|
||||
Value inputSlice = tensor::ExtractSliceOp::create(
|
||||
rewriter, widthLoc, unitTensorType, input, inputOffsets, unitSizes, unitStrides);
|
||||
|
||||
Value outputH = heightLoop.getInductionVar();
|
||||
Value outputHeightAcc = heightLoop.getRegionIterArgs().front();
|
||||
Value inputH = buildNearestAsymmetricIndex(outputH, inputType.getDimSize(2), resultType.getDimSize(2), rewriter, loc);
|
||||
|
||||
auto widthLoop = scf::ForOp::create(rewriter, loc, c0, cOutputW, c1, ValueRange {outputHeightAcc});
|
||||
rewriter.setInsertionPointToStart(widthLoop.getBody());
|
||||
|
||||
Value outputW = widthLoop.getInductionVar();
|
||||
Value outputWidthAcc = widthLoop.getRegionIterArgs().front();
|
||||
Value inputW = buildNearestAsymmetricIndex(outputW, inputType.getDimSize(3), resultType.getDimSize(3), rewriter, loc);
|
||||
|
||||
SmallVector<OpFoldResult> inputOffsets = {inputN, inputC, inputH, inputW};
|
||||
Value inputSlice =
|
||||
tensor::ExtractSliceOp::create(rewriter, loc, unitTensorType, input, inputOffsets, unitSizes, unitStrides);
|
||||
|
||||
SmallVector<OpFoldResult> outputOffsets = {outputN, outputC, outputH, outputW};
|
||||
Value updatedOutput =
|
||||
tensor::InsertSliceOp::create(rewriter, loc, inputSlice, outputWidthAcc, outputOffsets, unitSizes, unitStrides);
|
||||
scf::YieldOp::create(rewriter, loc, updatedOutput);
|
||||
|
||||
rewriter.setInsertionPointAfter(widthLoop);
|
||||
scf::YieldOp::create(rewriter, loc, widthLoop.getResult(0));
|
||||
|
||||
rewriter.setInsertionPointAfter(heightLoop);
|
||||
scf::YieldOp::create(rewriter, loc, heightLoop.getResult(0));
|
||||
|
||||
rewriter.setInsertionPointAfter(channelLoop);
|
||||
scf::YieldOp::create(rewriter, loc, channelLoop.getResult(0));
|
||||
|
||||
rewriter.setInsertionPointAfter(batchLoop);
|
||||
return batchLoop.getResult(0);
|
||||
SmallVector<OpFoldResult> outputOffsets = {outputN, outputC, outputH, outputW};
|
||||
Value updatedOutput = tensor::InsertSliceOp::create(
|
||||
rewriter, widthLoc, inputSlice, outputWidthAcc, outputOffsets, unitSizes, unitStrides);
|
||||
widthYielded.push_back(updatedOutput);
|
||||
return success();
|
||||
});
|
||||
if (failed(widthLoop))
|
||||
return failure();
|
||||
heightYielded.push_back(widthLoop->results.front());
|
||||
return success();
|
||||
});
|
||||
if (failed(heightLoop))
|
||||
return failure();
|
||||
channelYielded.push_back(heightLoop->results.front());
|
||||
return success();
|
||||
});
|
||||
if (failed(channelLoop))
|
||||
return failure();
|
||||
batchYielded.push_back(channelLoop->results.front());
|
||||
return success();
|
||||
});
|
||||
if (failed(batchLoop))
|
||||
return failure();
|
||||
return batchLoop->results.front();
|
||||
}
|
||||
|
||||
struct Resize : OpConversionPattern<ONNXResizeOp> {
|
||||
@@ -120,12 +161,17 @@ struct Resize : OpConversionPattern<ONNXResizeOp> {
|
||||
|| llvm::any_of(resultType.getShape(), [](int64_t dim) { return dim <= 0; }))
|
||||
return rewriter.notifyMatchFailure(resizeOp, "resize lowering requires positive static dimensions.");
|
||||
|
||||
auto computeOp =
|
||||
createSpatCompute<1>(rewriter, resizeOp.getLoc(), TypeRange {resultType}, {}, adaptor.getX(), [&](Value x) {
|
||||
Value result = buildNearestResizeLoop(x, inputType, resultType, rewriter, resizeOp.getLoc());
|
||||
spatial::SpatYieldOp::create(rewriter, resizeOp.getLoc(), result);
|
||||
auto computeOp = createSpatCompute<1>(
|
||||
rewriter, resizeOp.getLoc(), TypeRange {resultType}, {}, adaptor.getX(), [&](Value x) -> LogicalResult {
|
||||
auto result = buildNearestResizeLoop(x, inputType, resultType, rewriter, resizeOp.getLoc());
|
||||
if (failed(result))
|
||||
return failure();
|
||||
spatial::SpatYieldOp::create(rewriter, resizeOp.getLoc(), *result);
|
||||
return success();
|
||||
});
|
||||
rewriter.replaceOp(resizeOp, computeOp.getResults());
|
||||
if (failed(computeOp))
|
||||
return failure();
|
||||
rewriter.replaceOp(resizeOp, computeOp->getResults());
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
@@ -10,6 +10,7 @@
|
||||
#include "Conversion/SpatialToPim/SpatialToPimPass.hpp"
|
||||
#include "src/Accelerators/PIM/Common/IR/BatchCoreUtils.hpp"
|
||||
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
|
||||
#include "src/Accelerators/PIM/Common/Support/CheckedArithmetic.hpp"
|
||||
#include "src/Accelerators/PIM/Conversion/SpatialToPim/Common.hpp"
|
||||
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
|
||||
|
||||
@@ -25,14 +26,21 @@ static bool isUsedOnlyAsExplicitHostOperand(Value value) {
|
||||
});
|
||||
}
|
||||
|
||||
static SmallVector<int32_t> getPimCoreIdsForBatchOp(spatial::SpatComputeBatch computeBatchOp, size_t& fallbackCoreId) {
|
||||
static FailureOr<SmallVector<int32_t>> getPimCoreIdsForBatchOp(spatial::SpatComputeBatch computeBatchOp,
|
||||
size_t& fallbackCoreId) {
|
||||
if (auto coreIdsAttr = computeBatchOp->getAttrOfType<DenseI32ArrayAttr>(onnx_mlir::kCoreIdsAttrName))
|
||||
return SmallVector<int32_t>(coreIdsAttr.asArrayRef().begin(), coreIdsAttr.asArrayRef().end());
|
||||
|
||||
SmallVector<int32_t> coreIds;
|
||||
coreIds.reserve(static_cast<size_t>(computeBatchOp.getLaneCount()));
|
||||
for (uint32_t lane = 0; lane < computeBatchOp.getLaneCount(); ++lane)
|
||||
coreIds.push_back(static_cast<int32_t>(fallbackCoreId++));
|
||||
for (uint32_t lane = 0; lane < computeBatchOp.getLaneCount(); ++lane) {
|
||||
auto checkedCoreId =
|
||||
pim::checkedI32(static_cast<uint64_t>(fallbackCoreId), computeBatchOp, "fallback spatial compute_batch core id");
|
||||
if (failed(checkedCoreId))
|
||||
return failure();
|
||||
coreIds.push_back(*checkedCoreId);
|
||||
++fallbackCoreId;
|
||||
}
|
||||
return coreIds;
|
||||
}
|
||||
|
||||
@@ -102,21 +110,24 @@ LogicalResult raptor::SpatialToPimPass::lowerComputeBatchOp(spatial::SpatCompute
|
||||
"resultful compute_batch lowering currently requires a spat.in_parallel terminator");
|
||||
}
|
||||
|
||||
SmallVector<int32_t> coreIds = getPimCoreIdsForBatchOp(computeBatchOp, coreId);
|
||||
auto coreIds = getPimCoreIdsForBatchOp(computeBatchOp, coreId);
|
||||
if (failed(coreIds))
|
||||
return failure();
|
||||
SmallVector<Value> batchWeights(computeBatchOp.getWeights().begin(), computeBatchOp.getWeights().end());
|
||||
SmallVector<Value> batchInputs;
|
||||
if (!computeBatchOp.getInputs().empty())
|
||||
batchInputs.append(computeBatchOp.getInputs().begin(), computeBatchOp.getInputs().end());
|
||||
|
||||
rewriter.setInsertionPointAfter(computeBatchOp);
|
||||
auto coreBatchOp = pim::PimCoreBatchOp::create(rewriter,
|
||||
loc,
|
||||
rewriter.getI32IntegerAttr(computeBatchOp.getLaneCount()),
|
||||
ValueRange(batchWeights),
|
||||
ValueRange(batchInputs));
|
||||
auto laneCountAttr = pim::getCheckedI32Attr(
|
||||
rewriter, computeBatchOp, static_cast<uint64_t>(computeBatchOp.getLaneCount()), "pim core_batch lane count");
|
||||
if (failed(laneCountAttr))
|
||||
return failure();
|
||||
auto coreBatchOp =
|
||||
pim::PimCoreBatchOp::create(rewriter, loc, *laneCountAttr, ValueRange(batchWeights), ValueRange(batchInputs));
|
||||
coreBatchOp.getProperties().setOperandSegmentSizes(
|
||||
{static_cast<int>(batchWeights.size()), static_cast<int>(batchInputs.size())});
|
||||
coreBatchOp->setAttr(onnx_mlir::kCoreIdsAttrName, rewriter.getDenseI32ArrayAttr(coreIds));
|
||||
coreBatchOp->setAttr(onnx_mlir::kCoreIdsAttrName, rewriter.getDenseI32ArrayAttr(*coreIds));
|
||||
|
||||
SmallVector<unsigned> returnOperandIndices;
|
||||
if (computeBatchOp.getNumResults() != 0) {
|
||||
@@ -160,14 +171,11 @@ LogicalResult raptor::SpatialToPimPass::lowerComputeBatchOp(spatial::SpatCompute
|
||||
auto newArgType = cast<ShapedType>(newArg.getType());
|
||||
auto outputBuffer = createEmptyTensorFromShaped(rewriter, loc, newArgType);
|
||||
Value zeroOffset = getOrCreateIndexConstant(rewriter, coreBatchOp.getOperation(), 0);
|
||||
auto copied = pim::PimMemCopyHostToDevOp::create(rewriter,
|
||||
loc,
|
||||
outputBuffer.getType(),
|
||||
zeroOffset,
|
||||
zeroOffset,
|
||||
outputBuffer,
|
||||
newArg,
|
||||
getTensorSizeInBytesAttr(rewriter, newArg))
|
||||
auto sizeAttr = getTensorSizeInBytesAttr(rewriter, coreBatchOp.getOperation(), newArg);
|
||||
if (failed(sizeAttr))
|
||||
return failure();
|
||||
auto copied = pim::PimMemCopyHostToDevOp::create(
|
||||
rewriter, loc, outputBuffer.getType(), zeroOffset, zeroOffset, outputBuffer, newArg, *sizeAttr)
|
||||
.getOutput();
|
||||
mapper.map(*oldArg, copied);
|
||||
}
|
||||
@@ -209,6 +217,9 @@ LogicalResult raptor::SpatialToPimPass::lowerComputeBatchOp(spatial::SpatCompute
|
||||
auto hostTargetType = cast<ShapedType>(hostTarget.getType());
|
||||
Value hostTargetOffset = createHostTargetOffset(rewriter, insertSlice, hostTargetType, mapper);
|
||||
Value zeroOffset = getOrCreateIndexConstant(rewriter, coreBatchOp.getOperation(), 0);
|
||||
auto sizeAttr = getTensorSizeInBytesAttr(rewriter, coreBatchOp.getOperation(), mappedSource);
|
||||
if (failed(sizeAttr))
|
||||
return failure();
|
||||
pim::PimMemCopyDevToHostOp::create(rewriter,
|
||||
insertSlice.getLoc(),
|
||||
hostTarget.getType(),
|
||||
@@ -216,7 +227,7 @@ LogicalResult raptor::SpatialToPimPass::lowerComputeBatchOp(spatial::SpatCompute
|
||||
zeroOffset,
|
||||
hostTarget,
|
||||
mappedSource,
|
||||
getTensorSizeInBytesAttr(rewriter, mappedSource));
|
||||
*sizeAttr);
|
||||
}
|
||||
continue;
|
||||
}
|
||||
@@ -232,15 +243,13 @@ LogicalResult raptor::SpatialToPimPass::lowerComputeBatchOp(spatial::SpatCompute
|
||||
auto clonedType = cast<ShapedType>(clonedTensor.getType());
|
||||
auto outputBuffer = createEmptyTensorFromShaped(rewriter, loc, clonedType);
|
||||
Value zeroOffset = getOrCreateIndexConstant(rewriter, coreBatchOp.getOperation(), 0);
|
||||
auto copied = pim::PimMemCopyHostToDevOp::create(rewriter,
|
||||
loc,
|
||||
outputBuffer.getType(),
|
||||
zeroOffset,
|
||||
zeroOffset,
|
||||
outputBuffer,
|
||||
clonedTensor,
|
||||
getTensorSizeInBytesAttr(rewriter, clonedTensor))
|
||||
.getOutput();
|
||||
auto sizeAttr = getTensorSizeInBytesAttr(rewriter, coreBatchOp.getOperation(), clonedTensor);
|
||||
if (failed(sizeAttr))
|
||||
return failure();
|
||||
auto copied =
|
||||
pim::PimMemCopyHostToDevOp::create(
|
||||
rewriter, loc, outputBuffer.getType(), zeroOffset, zeroOffset, outputBuffer, clonedTensor, *sizeAttr)
|
||||
.getOutput();
|
||||
mapper.map(toTensorOp.getResult(), copied);
|
||||
continue;
|
||||
}
|
||||
|
||||
@@ -5,14 +5,18 @@
|
||||
#include <cassert>
|
||||
|
||||
#include "Common.hpp"
|
||||
#include "src/Accelerators/PIM/Common/Support/CheckedArithmetic.hpp"
|
||||
|
||||
using namespace llvm;
|
||||
using namespace mlir;
|
||||
|
||||
namespace onnx_mlir {
|
||||
|
||||
IntegerAttr getTensorSizeInBytesAttr(Builder& builder, mlir::Value value) {
|
||||
return builder.getI32IntegerAttr(static_cast<int32_t>(getShapedTypeSizeInBytes(cast<ShapedType>(value.getType()))));
|
||||
FailureOr<IntegerAttr> getTensorSizeInBytesAttr(Builder& builder, Operation* anchor, mlir::Value value) {
|
||||
auto byteSize = pim::getCheckedShapedTypeSizeInBytes(cast<ShapedType>(value.getType()), anchor, "tensor byte size");
|
||||
if (failed(byteSize))
|
||||
return failure();
|
||||
return pim::getCheckedI32Attr(builder, anchor, *byteSize, "tensor byte size");
|
||||
}
|
||||
|
||||
Operation* getEarliestUserWithinBlock(mlir::Value value) {
|
||||
|
||||
@@ -1,12 +1,14 @@
|
||||
#pragma once
|
||||
|
||||
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
||||
#include "mlir/Support/LogicalResult.h"
|
||||
|
||||
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
|
||||
|
||||
namespace onnx_mlir {
|
||||
|
||||
mlir::IntegerAttr getTensorSizeInBytesAttr(mlir::Builder& builder, mlir::Value value);
|
||||
mlir::FailureOr<mlir::IntegerAttr>
|
||||
getTensorSizeInBytesAttr(mlir::Builder& builder, mlir::Operation* anchor, mlir::Value value);
|
||||
|
||||
template <class T>
|
||||
size_t rangeLength(const mlir::iterator_range<T> range) {
|
||||
|
||||
@@ -9,6 +9,7 @@
|
||||
#include "Conversion/ONNXToSpatial/Common/Common.hpp"
|
||||
#include "Conversion/SpatialToPim/SpatialToPimPass.hpp"
|
||||
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
|
||||
#include "src/Accelerators/PIM/Common/Support/CheckedArithmetic.hpp"
|
||||
#include "src/Accelerators/PIM/Conversion/SpatialToPim/Common.hpp"
|
||||
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
|
||||
#include "src/Dialect/ONNX/ONNXOps.hpp"
|
||||
@@ -54,10 +55,15 @@ cloneMappedHelperOperands(Operation* op, IRMapping& mapping, IRRewriter& rewrite
|
||||
}
|
||||
}
|
||||
|
||||
static int32_t getPimCoreIdForComputeOp(spatial::SpatCompute computeOp, size_t& fallbackCoreId) {
|
||||
static FailureOr<int32_t> getPimCoreIdForComputeOp(spatial::SpatCompute computeOp, size_t& fallbackCoreId) {
|
||||
if (auto spatialCoreIdAttr = computeOp->getAttrOfType<IntegerAttr>(onnx_mlir::kCoreIdAttrName))
|
||||
return static_cast<int32_t>(spatialCoreIdAttr.getInt());
|
||||
return static_cast<int32_t>(fallbackCoreId++);
|
||||
return pim::checkedI32(spatialCoreIdAttr.getInt(), computeOp, "spatial compute core id");
|
||||
auto checkedCoreId =
|
||||
pim::checkedI32(static_cast<uint64_t>(fallbackCoreId), computeOp, "fallback spatial compute core id");
|
||||
if (failed(checkedCoreId))
|
||||
return failure();
|
||||
++fallbackCoreId;
|
||||
return *checkedCoreId;
|
||||
}
|
||||
|
||||
static LogicalResult collectHelperComputeChain(spatial::SpatCompute computeOp,
|
||||
@@ -163,10 +169,12 @@ LogicalResult raptor::SpatialToPimPass::lowerComputeOp(spatial::SpatCompute comp
|
||||
rewriter.setInsertionPoint(getEarliestUserWithinBlock(*blockArg));
|
||||
auto outputType = cast<ShapedType>(blockArg->getType());
|
||||
auto outputBuffer = createEmptyTensorFromShaped(rewriter, receiveOp.getLoc(), outputType);
|
||||
auto sizeAttr = getTensorSizeInBytesAttr(rewriter, *blockArg);
|
||||
auto sizeAttr = getTensorSizeInBytesAttr(rewriter, computeOp.getOperation(), *blockArg);
|
||||
if (failed(sizeAttr))
|
||||
return failure();
|
||||
Value received =
|
||||
PimReceiveOp::create(
|
||||
rewriter, receiveOp.getLoc(), outputBuffer.getType(), outputBuffer, sizeAttr, receiveOp.getSourceCoreId())
|
||||
rewriter, receiveOp.getLoc(), outputBuffer.getType(), outputBuffer, *sizeAttr, receiveOp.getSourceCoreId())
|
||||
.getOutput();
|
||||
blockArg->replaceAllUsesWith(received);
|
||||
markOpToRemove(receiveOp);
|
||||
@@ -206,8 +214,13 @@ LogicalResult raptor::SpatialToPimPass::lowerComputeOp(spatial::SpatCompute comp
|
||||
if (!computeOp.getWeights().empty())
|
||||
computeWeights.append(computeOp.getWeights().begin(), computeOp.getWeights().end());
|
||||
rewriter.setInsertionPointAfter(computeOp);
|
||||
auto coreOp = PimCoreOp::create(
|
||||
rewriter, loc, ValueRange(computeWeights), rewriter.getI32IntegerAttr(getPimCoreIdForComputeOp(computeOp, coreId)));
|
||||
auto checkedCoreId = getPimCoreIdForComputeOp(computeOp, coreId);
|
||||
if (failed(checkedCoreId))
|
||||
return failure();
|
||||
auto coreIdAttr = pim::getCheckedI32Attr(rewriter, computeOp, static_cast<int64_t>(*checkedCoreId), "pim core id");
|
||||
if (failed(coreIdAttr))
|
||||
return failure();
|
||||
auto coreOp = PimCoreOp::create(rewriter, loc, ValueRange(computeWeights), *coreIdAttr);
|
||||
rewriter.setInsertionPointToStart(&block);
|
||||
auto& coreOpBlocks = coreOp.getBody().getBlocks();
|
||||
for (auto [inputIndex, input] : llvm::enumerate(computeOp.getInputs())) {
|
||||
@@ -226,6 +239,9 @@ LogicalResult raptor::SpatialToPimPass::lowerComputeOp(spatial::SpatCompute comp
|
||||
if (!inputType)
|
||||
return computeOp.emitOpError("expected shaped compute input during pim.core lowering");
|
||||
auto outputBuffer = createEmptyTensorFromShaped(rewriter, loc, inputType);
|
||||
auto sizeAttr = getTensorSizeInBytesAttr(rewriter, computeOp.getOperation(), input);
|
||||
if (failed(sizeAttr))
|
||||
return failure();
|
||||
auto copied =
|
||||
PimMemCopyHostToDevOp::create(rewriter,
|
||||
loc,
|
||||
@@ -234,7 +250,7 @@ LogicalResult raptor::SpatialToPimPass::lowerComputeOp(spatial::SpatCompute comp
|
||||
getOrCreateIndexConstant(constantFolder, outputBuffer.getOperation(), 0),
|
||||
outputBuffer,
|
||||
input,
|
||||
getTensorSizeInBytesAttr(rewriter, input))
|
||||
*sizeAttr)
|
||||
.getOutput();
|
||||
blockArg->replaceAllUsesWith(copied);
|
||||
}
|
||||
|
||||
@@ -14,8 +14,10 @@ struct ChannelSendLowering : OpRewritePattern<spatial::SpatChannelSendOp> {
|
||||
using OpRewritePattern::OpRewritePattern;
|
||||
|
||||
LogicalResult matchAndRewrite(spatial::SpatChannelSendOp op, PatternRewriter& rewriter) const override {
|
||||
pim::PimSendOp::create(
|
||||
rewriter, op.getLoc(), op.getInput(), getTensorSizeInBytesAttr(rewriter, op.getInput()), op.getTargetCoreId());
|
||||
auto sizeAttr = getTensorSizeInBytesAttr(rewriter, op.getOperation(), op.getInput());
|
||||
if (failed(sizeAttr))
|
||||
return failure();
|
||||
pim::PimSendOp::create(rewriter, op.getLoc(), op.getInput(), *sizeAttr, op.getTargetCoreId());
|
||||
rewriter.eraseOp(op);
|
||||
return success();
|
||||
}
|
||||
@@ -32,12 +34,11 @@ struct ChannelReceiveLowering : OpRewritePattern<spatial::SpatChannelReceiveOp>
|
||||
auto outputType = cast<ShapedType>(op.getResult().getType());
|
||||
Value outputBuffer =
|
||||
tensor::EmptyOp::create(rewriter, op.getLoc(), outputType.getShape(), outputType.getElementType()).getResult();
|
||||
Value received = pim::PimReceiveOp::create(rewriter,
|
||||
op.getLoc(),
|
||||
op.getResult().getType(),
|
||||
outputBuffer,
|
||||
getTensorSizeInBytesAttr(rewriter, op.getResult()),
|
||||
op.getSourceCoreId())
|
||||
auto sizeAttr = getTensorSizeInBytesAttr(rewriter, op.getOperation(), op.getResult());
|
||||
if (failed(sizeAttr))
|
||||
return failure();
|
||||
Value received = pim::PimReceiveOp::create(
|
||||
rewriter, op.getLoc(), op.getResult().getType(), outputBuffer, *sizeAttr, op.getSourceCoreId())
|
||||
.getOutput();
|
||||
rewriter.replaceOp(op, received);
|
||||
return success();
|
||||
|
||||
@@ -12,6 +12,7 @@
|
||||
#include "Conversion/ONNXToSpatial/Common/Common.hpp"
|
||||
#include "Conversion/SpatialToPim/SpatialToPimPass.hpp"
|
||||
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
|
||||
#include "src/Accelerators/PIM/Common/Support/CheckedArithmetic.hpp"
|
||||
#include "src/Accelerators/PIM/Conversion/SpatialToPim/Common.hpp"
|
||||
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
|
||||
#include "src/Dialect/ONNX/ONNXOps.hpp"
|
||||
@@ -71,6 +72,20 @@ static SmallVector<int64_t> expandFlatElementIndex(int64_t flatIndex, ArrayRef<i
|
||||
return indices;
|
||||
}
|
||||
|
||||
static FailureOr<int64_t>
|
||||
getCheckedByteOffset(int64_t elementOffset, size_t elementSize, Operation* anchor, StringRef fieldName) {
|
||||
if (elementOffset < 0) {
|
||||
anchor->emitOpError() << fieldName << " requires a nonnegative element offset";
|
||||
return failure();
|
||||
}
|
||||
|
||||
auto byteOffset =
|
||||
pim::checkedMul(static_cast<uint64_t>(elementOffset), static_cast<uint64_t>(elementSize), anchor, fieldName);
|
||||
if (failed(byteOffset))
|
||||
return failure();
|
||||
return pim::checkedCast<int64_t>(*byteOffset, anchor, fieldName);
|
||||
}
|
||||
|
||||
static LogicalResult collectHelperComputeChain(spatial::SpatCompute computeOp,
|
||||
SmallVectorImpl<Operation*>& helperChain) {
|
||||
if (computeOp.getInputs().size() != 1 || computeOp.getNumResults() != 1)
|
||||
@@ -360,18 +375,21 @@ static void cloneHelperChain(Value sourceValue,
|
||||
}
|
||||
}
|
||||
|
||||
static Value emitHostCopy(IRRewriter& rewriter,
|
||||
Location loc,
|
||||
Value outputTensor,
|
||||
Value sourceValue,
|
||||
int32_t hostTargetOffset,
|
||||
int32_t deviceSourceOffset,
|
||||
int32_t sizeInBytes,
|
||||
OperationFolder& constantFolder) {
|
||||
static FailureOr<Value> emitHostCopy(IRRewriter& rewriter,
|
||||
Location loc,
|
||||
Value outputTensor,
|
||||
Value sourceValue,
|
||||
int64_t hostTargetOffset,
|
||||
int64_t deviceSourceOffset,
|
||||
uint64_t sizeInBytes,
|
||||
OperationFolder& constantFolder) {
|
||||
Operation* anchorOp = sourceValue.getDefiningOp() ? sourceValue.getDefiningOp() : outputTensor.getDefiningOp();
|
||||
assert(anchorOp && "expected a concrete op anchor for return-path host copy constants");
|
||||
Value hostTargetOffsetValue = getOrCreateIndexConstant(constantFolder, anchorOp, hostTargetOffset);
|
||||
Value deviceSourceOffsetValue = getOrCreateIndexConstant(constantFolder, anchorOp, deviceSourceOffset);
|
||||
auto sizeAttr = pim::getCheckedI32Attr(rewriter, anchorOp, sizeInBytes, "return-path host copy byte size");
|
||||
if (failed(sizeAttr))
|
||||
return failure();
|
||||
return PimMemCopyDevToHostOp::create(rewriter,
|
||||
loc,
|
||||
outputTensor.getType(),
|
||||
@@ -379,7 +397,7 @@ static Value emitHostCopy(IRRewriter& rewriter,
|
||||
deviceSourceOffsetValue,
|
||||
outputTensor,
|
||||
sourceValue,
|
||||
rewriter.getI32IntegerAttr(sizeInBytes))
|
||||
*sizeAttr)
|
||||
.getOutput();
|
||||
}
|
||||
|
||||
@@ -433,18 +451,15 @@ raptor::SpatialToPimPass::ReturnPathLoweringResult raptor::SpatialToPimPass::low
|
||||
markOpToRemove(op);
|
||||
|
||||
auto storedType = cast<ShapedType>(currentStoredValue.getType());
|
||||
size_t elementSize = getElementTypeSizeInBytes(storedType.getElementType());
|
||||
auto byteSize = pim::getCheckedShapedTypeSizeInBytes(storedType, producerOp, "return-path host copy byte size");
|
||||
if (failed(byteSize))
|
||||
return ReturnPathLoweringResult::Failure;
|
||||
if (auto storedOp = currentStoredValue.getDefiningOp())
|
||||
rewriter.setInsertionPointAfter(storedOp);
|
||||
Value outputTensor = outputTensors[returnUse->returnIndex](rewriter, loc);
|
||||
emitHostCopy(rewriter,
|
||||
loc,
|
||||
outputTensor,
|
||||
currentStoredValue,
|
||||
0,
|
||||
0,
|
||||
static_cast<int32_t>(storedType.getNumElements() * elementSize),
|
||||
constantFolder);
|
||||
auto copied = emitHostCopy(rewriter, loc, outputTensor, currentStoredValue, 0, 0, *byteSize, constantFolder);
|
||||
if (failed(copied))
|
||||
return ReturnPathLoweringResult::Failure;
|
||||
return ReturnPathLoweringResult::Handled;
|
||||
}
|
||||
|
||||
@@ -455,23 +470,25 @@ raptor::SpatialToPimPass::ReturnPathLoweringResult raptor::SpatialToPimPass::low
|
||||
|
||||
if (isa<func::ReturnOp>(resultUser)) {
|
||||
size_t resultIndexInReturn = resultUse.getOperandNumber();
|
||||
size_t elementSize = getElementTypeSizeInBytes(storedTensorType.getElementType());
|
||||
auto byteSize =
|
||||
pim::getCheckedShapedTypeSizeInBytes(storedTensorType, producerOp, "return-path host copy byte size");
|
||||
if (failed(byteSize))
|
||||
return ReturnPathLoweringResult::Failure;
|
||||
rewriter.setInsertionPointAfterValue(storedValue);
|
||||
Value outputTensor = outputTensors[resultIndexInReturn](rewriter, loc);
|
||||
emitHostCopy(rewriter,
|
||||
loc,
|
||||
outputTensor,
|
||||
storedValue,
|
||||
0,
|
||||
0,
|
||||
static_cast<int32_t>(storedTensorType.getNumElements() * elementSize),
|
||||
constantFolder);
|
||||
auto copied = emitHostCopy(rewriter, loc, outputTensor, storedValue, 0, 0, *byteSize, constantFolder);
|
||||
if (failed(copied))
|
||||
return ReturnPathLoweringResult::Failure;
|
||||
return ReturnPathLoweringResult::Handled;
|
||||
}
|
||||
}
|
||||
|
||||
if (auto concatReturnUse = analyzeConcatReturnUse(producedValue)) {
|
||||
size_t elementSize = getElementTypeSizeInBytes(storedTensorType.getElementType());
|
||||
auto storedByteSize =
|
||||
pim::getCheckedShapedTypeSizeInBytes(storedTensorType, producerOp, "concat return-path copy byte size");
|
||||
if (failed(storedByteSize))
|
||||
return ReturnPathLoweringResult::Failure;
|
||||
for (Operation* concatOp : concatReturnUse->concatChain)
|
||||
markOpToRemove(concatOp);
|
||||
|
||||
@@ -480,14 +497,13 @@ raptor::SpatialToPimPass::ReturnPathLoweringResult raptor::SpatialToPimPass::low
|
||||
Value outputTensor = outputTensors[concatReturnUse->returnIndex](rewriter, loc);
|
||||
auto outputType = cast<ShapedType>(outputTensor.getType());
|
||||
int64_t flatOffset = computeFlatElementIndex(concatReturnUse->sliceOffsets, outputType.getShape());
|
||||
emitHostCopy(rewriter,
|
||||
loc,
|
||||
outputTensor,
|
||||
storedValue,
|
||||
static_cast<int32_t>(flatOffset * elementSize),
|
||||
0,
|
||||
static_cast<int32_t>(storedTensorType.getNumElements() * elementSize),
|
||||
constantFolder);
|
||||
auto hostOffset = getCheckedByteOffset(flatOffset, elementSize, producerOp, "concat return-path host offset");
|
||||
if (failed(hostOffset))
|
||||
return ReturnPathLoweringResult::Failure;
|
||||
auto copied =
|
||||
emitHostCopy(rewriter, loc, outputTensor, storedValue, *hostOffset, 0, *storedByteSize, constantFolder);
|
||||
if (failed(copied))
|
||||
return ReturnPathLoweringResult::Failure;
|
||||
return ReturnPathLoweringResult::Handled;
|
||||
}
|
||||
|
||||
@@ -531,14 +547,18 @@ raptor::SpatialToPimPass::ReturnPathLoweringResult raptor::SpatialToPimPass::low
|
||||
rewriter.setInsertionPointAfter(elementSlice);
|
||||
|
||||
int64_t destinationFlatOffset = computeFlatElementIndex(destinationIndices, outputType.getShape());
|
||||
outputTensor = emitHostCopy(rewriter,
|
||||
loc,
|
||||
outputTensor,
|
||||
elementSlice.getResult(),
|
||||
static_cast<int32_t>(destinationFlatOffset * elementSize),
|
||||
0,
|
||||
static_cast<int32_t>(elementSize),
|
||||
constantFolder);
|
||||
auto hostOffset =
|
||||
getCheckedByteOffset(destinationFlatOffset, elementSize, producerOp, "concat helper return-path host offset");
|
||||
if (failed(hostOffset))
|
||||
return ReturnPathLoweringResult::Failure;
|
||||
auto elementByteSize = pim::checkedCast<uint64_t>(elementSize, producerOp, "return-path scalar copy byte size");
|
||||
if (failed(elementByteSize))
|
||||
return ReturnPathLoweringResult::Failure;
|
||||
auto copied = emitHostCopy(
|
||||
rewriter, loc, outputTensor, elementSlice.getResult(), *hostOffset, 0, *elementByteSize, constantFolder);
|
||||
if (failed(copied))
|
||||
return ReturnPathLoweringResult::Failure;
|
||||
outputTensor = *copied;
|
||||
}
|
||||
return ReturnPathLoweringResult::Handled;
|
||||
}
|
||||
|
||||
@@ -25,8 +25,9 @@
|
||||
#include <cassert>
|
||||
#include <utility>
|
||||
|
||||
#include "Common/PimCommon.hpp"
|
||||
#include "Common/IR/ConstantUtils.hpp"
|
||||
#include "Common/PimCommon.hpp"
|
||||
#include "Common/Support/CheckedArithmetic.hpp"
|
||||
#include "Conversion/ONNXToSpatial/Common/Common.hpp"
|
||||
#include "Conversion/SpatialToPim/Common.hpp"
|
||||
#include "Conversion/SpatialToPim/Patterns.hpp"
|
||||
@@ -75,21 +76,28 @@ static memref::GlobalOp getOrCreateZeroGlobal(IRRewriter& rewriter, Location loc
|
||||
IntegerAttr {});
|
||||
}
|
||||
|
||||
static Value createZeroedDeviceHVector(IRRewriter& rewriter,
|
||||
Location loc,
|
||||
RankedTensorType tensorType,
|
||||
OperationFolder& constantFolder) {
|
||||
static FailureOr<Value> createZeroedDeviceHVector(IRRewriter& rewriter,
|
||||
Location loc,
|
||||
RankedTensorType tensorType,
|
||||
OperationFolder& constantFolder) {
|
||||
auto outputBuffer = createEmptyTensorFromShaped(rewriter, loc, tensorType);
|
||||
auto zeroGlobal = getOrCreateZeroGlobal(rewriter, loc, tensorType);
|
||||
auto zeroValue = memref::GetGlobalOp::create(rewriter, loc, zeroGlobal.getType(), zeroGlobal.getName());
|
||||
auto zeroIndex = getOrCreateIndexConstant(constantFolder, outputBuffer.getOperation(), 0);
|
||||
auto sizeAttr = rewriter.getI32IntegerAttr(static_cast<int32_t>(getShapedTypeSizeInBytes(tensorType)));
|
||||
auto byteSize =
|
||||
pim::getCheckedShapedTypeSizeInBytes(tensorType, outputBuffer.getOperation(), "host-to-device zero copy byte size");
|
||||
if (failed(byteSize))
|
||||
return failure();
|
||||
auto sizeAttr =
|
||||
pim::getCheckedI32Attr(rewriter, outputBuffer.getOperation(), *byteSize, "host-to-device zero copy byte size");
|
||||
if (failed(sizeAttr))
|
||||
return failure();
|
||||
return PimMemCopyHostToDevOp::create(
|
||||
rewriter, loc, tensorType, zeroIndex, zeroIndex, outputBuffer, zeroValue, sizeAttr)
|
||||
rewriter, loc, tensorType, zeroIndex, zeroIndex, outputBuffer, zeroValue, *sizeAttr)
|
||||
.getOutput();
|
||||
}
|
||||
|
||||
static Value
|
||||
static FailureOr<Value>
|
||||
padHVectorInputToCrossbarSize(IRRewriter& rewriter, Location loc, Value vector, OperationFolder& constantFolder) {
|
||||
auto vectorType = cast<RankedTensorType>(vector.getType());
|
||||
ArrayRef<int64_t> shape = vectorType.getShape();
|
||||
@@ -101,10 +109,18 @@ padHVectorInputToCrossbarSize(IRRewriter& rewriter, Location loc, Value vector,
|
||||
|
||||
auto paddedType = RankedTensorType::get(
|
||||
{shape[0], static_cast<int64_t>(crossbarSize)}, vectorType.getElementType(), vectorType.getEncoding());
|
||||
Value zeroed = createZeroedDeviceHVector(rewriter, loc, paddedType, constantFolder);
|
||||
Value zeroIndex = getOrCreateIndexConstant(constantFolder, zeroed.getDefiningOp(), 0);
|
||||
auto sizeAttr = rewriter.getI32IntegerAttr(static_cast<int32_t>(getShapedTypeSizeInBytes(vectorType)));
|
||||
return PimMemCopyOp::create(rewriter, loc, paddedType, zeroIndex, zeroIndex, zeroed, vector, sizeAttr).getOutput();
|
||||
auto zeroed = createZeroedDeviceHVector(rewriter, loc, paddedType, constantFolder);
|
||||
if (failed(zeroed))
|
||||
return failure();
|
||||
Value zeroIndex = getOrCreateIndexConstant(constantFolder, zeroed->getDefiningOp(), 0);
|
||||
auto byteSize =
|
||||
pim::getCheckedShapedTypeSizeInBytes(vectorType, zeroed->getDefiningOp(), "device padding copy byte size");
|
||||
if (failed(byteSize))
|
||||
return failure();
|
||||
auto sizeAttr = pim::getCheckedI32Attr(rewriter, zeroed->getDefiningOp(), *byteSize, "device padding copy byte size");
|
||||
if (failed(sizeAttr))
|
||||
return failure();
|
||||
return PimMemCopyOp::create(rewriter, loc, paddedType, zeroIndex, zeroIndex, *zeroed, vector, *sizeAttr).getOutput();
|
||||
}
|
||||
|
||||
void onnx_mlir::raptor::SpatialToPimPass::runOnOperation() {
|
||||
@@ -234,7 +250,11 @@ void onnx_mlir::raptor::SpatialToPimPass::runOnOperation() {
|
||||
}
|
||||
}
|
||||
|
||||
enlargeVMMOutTensorsToCrossbarSize(funcOp, rewriter);
|
||||
if (failed(enlargeVMMOutTensorsToCrossbarSize(funcOp, rewriter))) {
|
||||
funcOp.emitOpError("failed to enlarge VMM output tensors to crossbar size");
|
||||
signalPassFailure();
|
||||
return;
|
||||
}
|
||||
replaceReturnWithOutputBuffers(returnOp, rewriter);
|
||||
eraseOpsToRemove();
|
||||
|
||||
@@ -271,8 +291,9 @@ void onnx_mlir::raptor::SpatialToPimPass::runOnOperation() {
|
||||
dumpModule(moduleOp, "pim0");
|
||||
}
|
||||
|
||||
void raptor::SpatialToPimPass::enlargeVMMOutTensorsToCrossbarSize(func::FuncOp funcOp, IRRewriter& rewriter) {
|
||||
LogicalResult raptor::SpatialToPimPass::enlargeVMMOutTensorsToCrossbarSize(func::FuncOp funcOp, IRRewriter& rewriter) {
|
||||
OperationFolder constantFolder(funcOp.getContext());
|
||||
bool hasFailure = false;
|
||||
funcOp.walk([&](PimVMMOp vmmOp) {
|
||||
auto outputType = cast<RankedTensorType>(vmmOp.getOutput().getType());
|
||||
ArrayRef<int64_t> outputShape = outputType.getShape();
|
||||
@@ -280,19 +301,23 @@ void raptor::SpatialToPimPass::enlargeVMMOutTensorsToCrossbarSize(func::FuncOp f
|
||||
assert(outputShape[1] <= static_cast<int64_t>(crossbarSize) && "output width must fit in one crossbar");
|
||||
|
||||
rewriter.setInsertionPoint(vmmOp);
|
||||
Value paddedInput = padHVectorInputToCrossbarSize(rewriter, vmmOp.getLoc(), vmmOp.getInput(), constantFolder);
|
||||
auto paddedInput = padHVectorInputToCrossbarSize(rewriter, vmmOp.getLoc(), vmmOp.getInput(), constantFolder);
|
||||
if (failed(paddedInput)) {
|
||||
hasFailure = true;
|
||||
return WalkResult::interrupt();
|
||||
}
|
||||
auto paddedOutputType = RankedTensorType::get(
|
||||
{outputShape[0], static_cast<int64_t>(crossbarSize)}, outputType.getElementType(), outputType.getEncoding());
|
||||
Value paddedOutputBuffer = outputShape[1] == static_cast<int64_t>(crossbarSize)
|
||||
? vmmOp.getOutputBuffer()
|
||||
: createEmptyTensorFromShaped(rewriter, vmmOp.getLoc(), paddedOutputType).getResult();
|
||||
vmmOp.getInputMutable().assign(paddedInput);
|
||||
vmmOp.getInputMutable().assign(*paddedInput);
|
||||
vmmOp.getOutputBufferMutable().assign(paddedOutputBuffer);
|
||||
|
||||
vmmOp.getOutput().setType(paddedOutputType);
|
||||
|
||||
if (outputShape[1] == static_cast<int64_t>(crossbarSize))
|
||||
return;
|
||||
return WalkResult::advance();
|
||||
|
||||
SmallVector<OpFoldResult> offsets = {rewriter.getIndexAttr(0), rewriter.getIndexAttr(0)};
|
||||
SmallVector<OpFoldResult> sizes = {rewriter.getIndexAttr(outputShape[0]), rewriter.getIndexAttr(outputShape[1])};
|
||||
@@ -302,13 +327,16 @@ void raptor::SpatialToPimPass::enlargeVMMOutTensorsToCrossbarSize(func::FuncOp f
|
||||
tensor::ExtractSliceOp::create(rewriter, vmmOp.getLoc(), outputType, vmmOp.getOutput(), offsets, sizes, strides);
|
||||
SmallPtrSet<Operation*, 2> exceptions = {vmmOp, sliceOp};
|
||||
vmmOp.getOutput().replaceAllUsesExcept(sliceOp.getResult(), exceptions);
|
||||
return WalkResult::advance();
|
||||
});
|
||||
return success(!hasFailure);
|
||||
}
|
||||
|
||||
LogicalResult raptor::SpatialToPimPass::allocateAndInitializeCoreLocalVariables(func::FuncOp funcOp,
|
||||
IRRewriter& rewriter) {
|
||||
Location loc = funcOp.getLoc();
|
||||
OperationFolder constantFolder(funcOp.getContext());
|
||||
bool hasFailure = false;
|
||||
|
||||
auto insertMemCopyHostToDev = [&](Value inputTensor, int64_t elementsOffset) {
|
||||
auto tensorType = cast<ShapedType>(inputTensor.getType());
|
||||
@@ -319,17 +347,28 @@ LogicalResult raptor::SpatialToPimPass::allocateAndInitializeCoreLocalVariables(
|
||||
rewriter.setInsertionPointAfter(inputTensor.getDefiningOp());
|
||||
|
||||
auto deviceTensor = tensor::EmptyOp::create(rewriter, loc, tensorType.getShape(), elementType);
|
||||
auto offsetBytes = pim::checkedMul(
|
||||
static_cast<size_t>(elementsOffset), elementByteSize, deviceTensor.getOperation(), "host input byte offset");
|
||||
auto byteSize =
|
||||
pim::getCheckedShapedTypeSizeInBytes(tensorType, deviceTensor.getOperation(), "host input copy byte size");
|
||||
auto sizeAttr =
|
||||
succeeded(byteSize)
|
||||
? pim::getCheckedI32Attr(rewriter, deviceTensor.getOperation(), *byteSize, "host input copy byte size")
|
||||
: FailureOr<IntegerAttr>(failure());
|
||||
if (failed(offsetBytes) || failed(sizeAttr)) {
|
||||
hasFailure = true;
|
||||
return;
|
||||
}
|
||||
|
||||
auto memCopyHostToDevOp = PimMemCopyHostToDevOp::create(
|
||||
rewriter,
|
||||
loc,
|
||||
tensorType,
|
||||
getOrCreateIndexConstant(constantFolder, deviceTensor.getOperation(), 0),
|
||||
getOrCreateIndexConstant(
|
||||
constantFolder, deviceTensor.getOperation(), static_cast<int64_t>(elementsOffset * elementByteSize)),
|
||||
getOrCreateIndexConstant(constantFolder, deviceTensor.getOperation(), static_cast<int64_t>(*offsetBytes)),
|
||||
deviceTensor,
|
||||
inputTensor,
|
||||
rewriter.getI32IntegerAttr(static_cast<int32_t>(tensorType.getNumElements() * elementByteSize)));
|
||||
*sizeAttr);
|
||||
|
||||
rewriter.replaceAllUsesExcept(inputTensor, memCopyHostToDevOp.getResult(), {memCopyHostToDevOp});
|
||||
};
|
||||
@@ -347,7 +386,7 @@ LogicalResult raptor::SpatialToPimPass::allocateAndInitializeCoreLocalVariables(
|
||||
}
|
||||
}
|
||||
|
||||
return success();
|
||||
return success(!hasFailure);
|
||||
}
|
||||
|
||||
void raptor::SpatialToPimPass::markOpToRemove(Operation* op) {
|
||||
|
||||
@@ -64,7 +64,7 @@ private:
|
||||
void markOpToRemove(mlir::Operation* op);
|
||||
void eraseOpsToRemove();
|
||||
|
||||
void enlargeVMMOutTensorsToCrossbarSize(mlir::func::FuncOp funcOp, mlir::IRRewriter& rewriter);
|
||||
mlir::LogicalResult enlargeVMMOutTensorsToCrossbarSize(mlir::func::FuncOp funcOp, mlir::IRRewriter& rewriter);
|
||||
};
|
||||
|
||||
} // namespace raptor
|
||||
|
||||
Reference in New Issue
Block a user