This commit is contained in:
@@ -402,24 +402,37 @@ LogicalResult GemvToSpatialCompute::matchAndRewrite(ONNXGemmOp gemmOp,
|
||||
for (size_t aSliceId = 0; aSliceId < aHSlices[coreId].size(); aSliceId++)
|
||||
weights.push_back(bTiles[outSliceId][coreId][aSliceId]);
|
||||
|
||||
auto computeOp = createSpatCompute(
|
||||
rewriter, gemmLoc, currOutHSliceType, weights, aHSlices[coreId], [&](ValueRange aHSlicesArgs) -> LogicalResult {
|
||||
SmallVector<Value> vmmOutputs;
|
||||
vmmOutputs.reserve(aHSlicesArgs.size());
|
||||
for (auto [aHSliceId, computeArg] : llvm::enumerate(aHSlicesArgs))
|
||||
vmmOutputs.push_back(
|
||||
spatial::SpatVMMOp::create(rewriter, gemmLoc, currOutHSliceType, aHSliceId, computeArg));
|
||||
if (vmmOutputs.empty()) {
|
||||
gemmOp.emitOpError("requires at least one non-empty slice when lowering tiled Gemm to Spatial VMMs");
|
||||
return failure();
|
||||
}
|
||||
auto computeOp =
|
||||
spatial::SpatCompute::create(rewriter, gemmLoc, TypeRange {currOutHSliceType}, weights, aHSlices[coreId]);
|
||||
SmallVector<Type> blockArgTypes;
|
||||
SmallVector<Location> blockArgLocs;
|
||||
blockArgTypes.reserve(weights.size() + aHSlices[coreId].size());
|
||||
blockArgLocs.reserve(weights.size() + aHSlices[coreId].size());
|
||||
for (Value weight : weights) {
|
||||
blockArgTypes.push_back(weight.getType());
|
||||
blockArgLocs.push_back(gemmLoc);
|
||||
}
|
||||
for (Value input : aHSlices[coreId]) {
|
||||
blockArgTypes.push_back(input.getType());
|
||||
blockArgLocs.push_back(gemmLoc);
|
||||
}
|
||||
Block* body =
|
||||
rewriter.createBlock(&computeOp.getBody(), computeOp.getBody().end(), TypeRange(blockArgTypes), blockArgLocs);
|
||||
rewriter.setInsertionPointToEnd(body);
|
||||
|
||||
Value partialVmmSum = sumTensors(vmmOutputs, rewriter);
|
||||
spatial::SpatYieldOp::create(rewriter, gemmLoc, partialVmmSum);
|
||||
return success();
|
||||
});
|
||||
if (failed(computeOp))
|
||||
SmallVector<Value> vmmOutputs;
|
||||
vmmOutputs.reserve(aHSlices[coreId].size());
|
||||
for (auto aHSliceId : llvm::seq<size_t>(0, aHSlices[coreId].size()))
|
||||
vmmOutputs.push_back(spatial::SpatVMMOp::create(
|
||||
rewriter, gemmLoc, currOutHSliceType, computeOp.getWeightArgument(aHSliceId), computeOp.getInputArgument(aHSliceId)));
|
||||
if (vmmOutputs.empty()) {
|
||||
gemmOp.emitOpError("requires at least one non-empty slice when lowering tiled Gemm to Spatial VMMs");
|
||||
return failure();
|
||||
}
|
||||
|
||||
Value partialVmmSum = sumTensors(vmmOutputs, rewriter);
|
||||
spatial::SpatYieldOp::create(rewriter, gemmLoc, partialVmmSum);
|
||||
rewriter.setInsertionPointAfter(computeOp);
|
||||
|
||||
partialResults.push_back(computeOp->getResult(0));
|
||||
}
|
||||
@@ -530,37 +543,47 @@ LogicalResult GemmToSpatialComputeBatch::matchAndRewrite(ONNXGemmOp gemmOp,
|
||||
sharedBias = c;
|
||||
}
|
||||
|
||||
SmallVector<Value> aSlices = materializeBatchRowSlices(a, aType, rewriter, loc);
|
||||
auto aSliceType = cast<RankedTensorType>(aSlices.front().getType());
|
||||
|
||||
auto outRowType = RankedTensorType::get({1, outType.getDimSize(1)}, outType.getElementType());
|
||||
SmallVector<Type> resultTypes(static_cast<size_t>(numOutRows), outRowType);
|
||||
SmallVector<Value> weights(static_cast<size_t>(numOutRows), b);
|
||||
|
||||
auto aRowType = RankedTensorType::get({1, aType.getDimSize(1)}, aType.getElementType());
|
||||
auto batchOp = spatial::SpatComputeBatch::create(rewriter,
|
||||
loc,
|
||||
TypeRange(resultTypes),
|
||||
TypeRange {outType},
|
||||
rewriter.getI32IntegerAttr(static_cast<int32_t>(numOutRows)),
|
||||
ValueRange(weights),
|
||||
ValueRange(aSlices));
|
||||
ValueRange {b},
|
||||
ValueRange {a});
|
||||
|
||||
Block* body = rewriter.createBlock(
|
||||
&batchOp.getBody(), batchOp.getBody().end(), TypeRange {aSliceType}, SmallVector<Location>(1, loc));
|
||||
SmallVector<Type> blockArgTypes {rewriter.getIndexType(), bType, aType, outType};
|
||||
SmallVector<Location> blockArgLocs(4, loc);
|
||||
Block* body =
|
||||
rewriter.createBlock(&batchOp.getBody(), batchOp.getBody().end(), TypeRange(blockArgTypes), blockArgLocs);
|
||||
rewriter.setInsertionPointToEnd(body);
|
||||
|
||||
Value vmmResult = spatial::SpatVMMOp::create(rewriter, loc, outRowType, 0, body->getArgument(0)).getResult();
|
||||
Value lane = batchOp.getLaneArgument();
|
||||
Value weight = batchOp.getWeightArgument(0);
|
||||
Value packedInput = batchOp.getInputArgument(0);
|
||||
Value packedOutput = batchOp.getOutputArgument(0);
|
||||
|
||||
SmallVector<OpFoldResult> inputOffsets {lane, rewriter.getIndexAttr(0)};
|
||||
SmallVector<OpFoldResult> inputSizes {rewriter.getIndexAttr(1), rewriter.getIndexAttr(aType.getDimSize(1))};
|
||||
SmallVector<OpFoldResult> unitStrides {rewriter.getIndexAttr(1), rewriter.getIndexAttr(1)};
|
||||
Value row =
|
||||
tensor::ExtractSliceOp::create(rewriter, loc, aRowType, packedInput, inputOffsets, inputSizes, unitStrides)
|
||||
.getResult();
|
||||
|
||||
Value vmmResult = spatial::SpatVMMOp::create(rewriter, loc, outRowType, weight, row).getResult();
|
||||
Value laneResult = vmmResult;
|
||||
if (sharedBias)
|
||||
laneResult = spatial::SpatVAddOp::create(rewriter, loc, outRowType, vmmResult, sharedBias).getResult();
|
||||
spatial::SpatYieldOp::create(rewriter, loc, laneResult);
|
||||
|
||||
auto inParallelOp = spatial::SpatInParallelOp::create(rewriter, loc);
|
||||
rewriter.setInsertionPointToStart(&inParallelOp.getRegion().front());
|
||||
SmallVector<OpFoldResult> outputOffsets {lane, rewriter.getIndexAttr(0)};
|
||||
SmallVector<OpFoldResult> outputSizes {rewriter.getIndexAttr(1), rewriter.getIndexAttr(outType.getDimSize(1))};
|
||||
tensor::ParallelInsertSliceOp::create(rewriter, loc, laneResult, packedOutput, outputOffsets, outputSizes,
|
||||
unitStrides);
|
||||
rewriter.setInsertionPointAfter(batchOp);
|
||||
SmallVector<Value> laneResults(batchOp->result_begin(), batchOp->result_end());
|
||||
auto concatComputeOp = createSpatCompute(rewriter, loc, gemmOp.getType(), {}, laneResults, [&](ValueRange args) {
|
||||
spatial::SpatYieldOp::create(rewriter, loc, createSpatConcat(rewriter, loc, /*axis=*/0, args));
|
||||
});
|
||||
|
||||
rewriter.replaceOp(gemmOp, concatComputeOp);
|
||||
rewriter.replaceOp(gemmOp, batchOp.getResults());
|
||||
return success();
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user