compact spatial IR through dedicated new operations and syntax
fast spatial node merging with batch operations
@@ -1,6 +1,7 @@
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
 #include "mlir/Dialect/Tosa/IR/TosaOps.h"
 #include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/IRMapping.h"
 #include "mlir/IR/Location.h"
 #include "mlir/Support/LogicalResult.h"
 #include "mlir/Transforms/DialectConversion.h"
@@ -65,6 +66,66 @@ struct GemvToSpatialCompute : OpConversionPattern<ONNXGemmOp> {
       ConversionPatternRewriter& rewriter) const override;
 };
 
+struct GemmToSpatialComputeBatch : OpConversionPattern<ONNXGemmOp> {
+  using OpConversionPattern::OpConversionPattern;
+
+  LogicalResult matchAndRewrite(ONNXGemmOp gemmOp,
+      ONNXGemmOpAdaptor gemmOpAdaptor,
+      ConversionPatternRewriter& rewriter) const override;
+};
+
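+// Splits `matrix` into per-row 1xN slices via a SpatExtractRowsOp. When the
+// matrix is produced by a chain of shape-only ops rooted at an existing
+// SpatCompute result, the chain is cloned into a fresh SpatCompute so the
+// row extraction happens next to its producer instead of in the host IR.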
+static SmallVector<Value> materializeBatchRowSlices(Value matrix,
+    RankedTensorType matrixType,
+    ConversionPatternRewriter& rewriter,
+    Location loc) {
+  const int64_t numRows = matrixType.getDimSize(0);
+  auto rowType = RankedTensorType::get({1, matrixType.getDimSize(1)}, matrixType.getElementType());
+  SmallVector<Type> resultTypes(static_cast<size_t>(numRows), rowType);
+
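+  // Emits a single SpatExtractRowsOp whose results are the 1xN row slices.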
+  auto buildRowSlices = [&](Value matrixArg) {
+    auto extractRowsOp = spatial::SpatExtractRowsOp::create(rewriter, loc, TypeRange(resultTypes), matrixArg);
+    return SmallVector<Value>(extractRowsOp->result_begin(), extractRowsOp->result_end());
+  };
+
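+  // Clones the recorded op chain into a new SpatCompute, rebasing it on the
+  // block argument, then extracts row slices from the rebased matrix.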
+  auto cloneBatchInputChainIntoSliceCompute =
+      [&](Value rootInput, SmallVector<Operation*> chainOps, Value rootValue) -> SmallVector<Value> {
+    auto sliceCompute =
+        createSpatCompute<1>(rewriter, loc, TypeRange(resultTypes), {}, ValueRange {rootInput}, [&](Value input) {
+          Value transformedMatrix = input;
+          if (!chainOps.empty()) {
+            IRMapping mapper;
+            mapper.map(rootValue, input);
+            for (Operation* chainOp : chainOps)
+              rewriter.clone(*chainOp, mapper);
+            transformedMatrix = mapper.lookup(matrix);
+          }
+          spatial::SpatYieldOp::create(rewriter, loc, buildRowSlices(transformedMatrix));
+        });
+    return SmallVector<Value>(sliceCompute->result_begin(), sliceCompute->result_end());
+  };
+
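+  // Walk backwards through single-operand shape ops; if the chain bottoms out
+  // at a SpatCompute result, rebuild it inside a dedicated slice compute.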
+  SmallVector<Operation*> chainOps;
+  Value rootValue = matrix;
+  while (Operation* definingOp = rootValue.getDefiningOp()) {
+    if (auto rootCompute = dyn_cast<spatial::SpatCompute>(definingOp)) {
+      SmallVector<Operation*> reversedChainOps(chainOps.rbegin(), chainOps.rend());
+      return cloneBatchInputChainIntoSliceCompute(
+          rootCompute.getResult(cast<OpResult>(rootValue).getResultNumber()), reversedChainOps, rootValue);
+    }
+
+    if (definingOp->getNumOperands() != 1)
+      break;
+    if (!isa<tensor::ExtractSliceOp, tensor::ExpandShapeOp, tensor::CollapseShapeOp, ONNXTransposeOp>(definingOp))
+      break;
+
+    chainOps.push_back(definingOp);
+    rootValue = definingOp->getOperand(0);
+  }
+
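+  // No SpatCompute root found: slice the original matrix directly.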
+  return buildRowSlices(matrix);
+}
+
 } // namespace
 
 LogicalResult GemmToManyGemv::matchAndRewrite(ONNXGemmOp gemmOp,
@@ -156,8 +217,7 @@ LogicalResult GemmToManyGemv::matchAndRewrite(ONNXGemmOp gemmOp,
   }
 
   auto concatComputeOp = createSpatCompute(rewriter, loc, gemmOp.getType(), {}, gemvOps, [&](ValueRange gemvOpsArgs) {
-    auto concatOp = tensor::ConcatOp::create(rewriter, loc, /*axis=*/0, gemvOpsArgs);
-    spatial::SpatYieldOp::create(rewriter, loc, concatOp.getResult());
+    spatial::SpatYieldOp::create(rewriter, loc, createSpatConcat(rewriter, loc, /*axis=*/0, gemvOpsArgs));
   });
 
   rewriter.replaceOp(gemmOp, concatComputeOp);
@@ -313,15 +373,116 @@ LogicalResult GemvToSpatialCompute::matchAndRewrite(ONNXGemmOp gemmOp,
 
   auto concatComputeOp =
       createSpatCompute(rewriter, gemmLoc, gemmOp.getType(), {}, outHSlices, [&](ValueRange blockArgs) {
-        auto concatOp = tensor::ConcatOp::create(rewriter, gemmLoc, /*axis=*/1, blockArgs);
-        spatial::SpatYieldOp::create(rewriter, gemmLoc, concatOp.getResult());
+        spatial::SpatYieldOp::create(rewriter, gemmLoc, createSpatConcat(rewriter, gemmLoc, /*axis=*/1, blockArgs));
       });
 
   rewriter.replaceOp(gemmOp, concatComputeOp);
   return success();
 }
 
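+// Single-tile batch lowering: an MxK * KxN GEMM maps onto one
+// SpatComputeBatch whose M lanes share the weight tile B and each consume
+// one row slice of A.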
+LogicalResult GemmToSpatialComputeBatch::matchAndRewrite(ONNXGemmOp gemmOp,
+    ONNXGemmOpAdaptor gemmOpAdaptor,
+    ConversionPatternRewriter& rewriter) const {
+  Location loc = gemmOp.getLoc();
+  Value a = gemmOpAdaptor.getA();
+  Value b = gemmOpAdaptor.getB();
+  Value c = gemmOpAdaptor.getC();
+
+  assert("A should have been transposed already" && !gemmOpAdaptor.getTransA());
+
+  // C may be a block argument, in which case it has no defining op.
+  bool hasC = !isa_and_nonnull<ONNXNoneOp>(c.getDefiningOp());
+
+  auto aType = cast<RankedTensorType>(a.getType());
+  auto bType = cast<RankedTensorType>(b.getType());
+  auto outType = cast<RankedTensorType>(gemmOp.getY().getType());
+  assert("Only support static shapes" && aType.hasStaticShape() && bType.hasStaticShape() && outType.hasStaticShape());
+
+  const int64_t numOutRows = aType.getDimSize(0);
+  if (numOutRows <= 1)
+    return failure();
+
+  // Only handle the single-tile case: K <= crossbarSize and N <= crossbarSize
+  if (aType.getDimSize(1) > static_cast<int64_t>(crossbarSize.getValue())
+      || outType.getDimSize(1) > static_cast<int64_t>(crossbarSize.getValue()))
+    return failure();
+
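+  // Fold alpha into B up front; if the scaled constant cannot be
+  // materialized, leave the GEMM for the other patterns.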
+  auto scaledB = materializeScaledConstantTensor(b, gemmOpAdaptor.getAlpha().convertToFloat(), rewriter, loc);
+  if (failed(scaledB))
+    return failure();
+  b = *scaledB;
+  bType = cast<RankedTensorType>(b.getType());
+
+  if (gemmOpAdaptor.getTransB()) {
+    auto bShape = bType.getShape();
+    auto transposedType = bType.cloneWith(ArrayRef({bShape[1], bShape[0]}), bType.getElementType());
+    b = ONNXTransposeOp::create(rewriter, loc, transposedType, b, rewriter.getI64ArrayAttr({1, 0}));
+    bType = cast<RankedTensorType>(b.getType());
+  }
+  (void) bType;
+
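+  // Fold beta into C and normalize it to a single 1xN bias row that every
+  // lane can share; row-specific bias needs per-row bodies and bails out.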
+  Value sharedBias;
+  if (hasC) {
+    auto scaledC = materializeScaledConstantTensor(c, gemmOpAdaptor.getBeta().convertToFloat(), rewriter, loc);
+    if (failed(scaledC))
+      return failure();
+    c = *scaledC;
+    auto cType = cast<RankedTensorType>(c.getType());
+    if (cType.getRank() == 1) {
+      auto expandedType = RankedTensorType::get({1, cType.getDimSize(0)}, cType.getElementType());
+      c = tensor::ExpandShapeOp::create(
+          rewriter, loc, expandedType, c, SmallVector<ReassociationIndices> {{0, 1}});
+      cType = cast<RankedTensorType>(c.getType());
+    }
+    assert("Only support rank 2 tensor for C" && cType.getRank() == 2);
+    // Row-specific bias can't share a single template body; fall through to GemmToManyGemv
+    if (cType.getDimSize(0) == numOutRows && numOutRows > 1)
+      return failure();
+    if (cType.getDimSize(0) == 1 && cType.getDimSize(1) == 1)
+      c = broadcastToVector(c, outType.getDimSize(1), rewriter, loc);
+    sharedBias = c;
+  }
+
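+  // One 1xK slice of A per output row, one lane per slice.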
+  SmallVector<Value> aSlices = materializeBatchRowSlices(a, aType, rewriter, loc);
+  auto aSliceType = cast<RankedTensorType>(aSlices.front().getType());
+
+  auto outRowType = RankedTensorType::get({1, outType.getDimSize(1)}, outType.getElementType());
+  SmallVector<Type> resultTypes(static_cast<size_t>(numOutRows), outRowType);
+  SmallVector<Value> weights(static_cast<size_t>(numOutRows), b);
+
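+  // One SpatComputeBatch with numOutRows lanes: every lane receives the
+  // shared weight tile B plus its own row slice of A.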
+  auto batchOp = spatial::SpatComputeBatch::create(rewriter,
+      loc,
+      TypeRange(resultTypes),
+      rewriter.getI32IntegerAttr(static_cast<int32_t>(numOutRows)),
+      ValueRange(weights),
+      ValueRange(aSlices));
+
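+  // Build the single template body shared by all lanes: a weighted VMM on
+  // the lane's row slice, plus the shared bias when present.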
+  Block* body = rewriter.createBlock(
+      &batchOp.getBody(), batchOp.getBody().end(), TypeRange {aSliceType}, SmallVector<Location>(1, loc));
+  rewriter.setInsertionPointToEnd(body);
+
+  Value vmmResult = spatial::SpatWeightedVMMOp::create(rewriter, loc, outRowType, 0, body->getArgument(0)).getResult();
+  Value laneResult = vmmResult;
+  if (sharedBias)
+    laneResult = spatial::SpatVAddOp::create(rewriter, loc, outRowType, vmmResult, sharedBias).getResult();
+  spatial::SpatYieldOp::create(rewriter, loc, laneResult);
+
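+  // Stitch the per-lane 1xN rows back into the full MxN output.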
+  rewriter.setInsertionPointAfter(batchOp);
+  SmallVector<Value> laneResults(batchOp->result_begin(), batchOp->result_end());
+  auto concatComputeOp = createSpatCompute(rewriter, loc, gemmOp.getType(), {}, laneResults, [&](ValueRange args) {
+    spatial::SpatYieldOp::create(rewriter, loc, createSpatConcat(rewriter, loc, /*axis=*/0, args));
+  });
+
+  rewriter.replaceOp(gemmOp, concatComputeOp);
+  return success();
+}
 
 void populateGemmPatterns(RewritePatternSet& patterns, MLIRContext* ctx) {
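+  // Higher benefit: try the batch lowering before the per-row fallbacks.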
+  patterns.insert<GemmToSpatialComputeBatch>(ctx, PatternBenefit(2));
   patterns.insert<GemmToManyGemv>(ctx);
   patterns.insert<GemvToSpatialCompute>(ctx);
 }