slightly better bufferization
minor fixes
This commit is contained in:
@@ -381,7 +381,7 @@ LogicalResult GemvToSpatialCompute::matchAndRewrite(ONNXGemmOp gemmOp,
|
||||
vmmOutputs.reserve(aHSlicesArgs.size());
|
||||
for (auto [aHSliceId, computeArg] : llvm::enumerate(aHSlicesArgs))
|
||||
vmmOutputs.push_back(
|
||||
spatial::SpatWeightedVMMOp::create(rewriter, gemmLoc, currOutHSliceType, aHSliceId, computeArg));
|
||||
spatial::SpatVMMOp::create(rewriter, gemmLoc, currOutHSliceType, aHSliceId, computeArg));
|
||||
if (vmmOutputs.empty()) {
|
||||
gemmOp.emitOpError("requires at least one non-empty slice when lowering tiled Gemm to Spatial VMMs");
|
||||
return failure();
|
||||
@@ -527,7 +527,7 @@ LogicalResult GemmToSpatialComputeBatch::matchAndRewrite(ONNXGemmOp gemmOp,
|
||||
&batchOp.getBody(), batchOp.getBody().end(), TypeRange {aSliceType}, SmallVector<Location>(1, loc));
|
||||
rewriter.setInsertionPointToEnd(body);
|
||||
|
||||
Value vmmResult = spatial::SpatWeightedVMMOp::create(rewriter, loc, outRowType, 0, body->getArgument(0)).getResult();
|
||||
Value vmmResult = spatial::SpatVMMOp::create(rewriter, loc, outRowType, 0, body->getArgument(0)).getResult();
|
||||
Value laneResult = vmmResult;
|
||||
if (sharedBias)
|
||||
laneResult = spatial::SpatVAddOp::create(rewriter, loc, outRowType, vmmResult, sharedBias).getResult();
|
||||
|
||||
@@ -95,7 +95,7 @@ bool hasLaterUserInBlock(mlir::Value value, Operation* operation) {
|
||||
return false;
|
||||
}
|
||||
|
||||
mlir::Value getBestOutputTensorFromOperandsOrAllocate(PatternRewriter& rewriter, Operation* operation) {
|
||||
mlir::Value getBestOutputTensorFromOperandsOrAllocate(RewriterBase& rewriter, Operation* operation) {
|
||||
assert("Only support operations with a single result" && operation->getNumResults() == 1);
|
||||
mlir::Value result = operation->getResult(0);
|
||||
auto resultType = result.getType();
|
||||
|
||||
@@ -41,7 +41,7 @@ mlir::Operation* getEarliestUserWithinBlock(mlir::Value value);
|
||||
|
||||
mlir::SmallVector<mlir::Value> getOpOperandsSortedByUses(mlir::Operation* operation);
|
||||
|
||||
mlir::Value getBestOutputTensorFromOperandsOrAllocate(mlir::PatternRewriter& rewriter, mlir::Operation* operation);
|
||||
mlir::Value getBestOutputTensorFromOperandsOrAllocate(mlir::RewriterBase& rewriter, mlir::Operation* operation);
|
||||
|
||||
inline mlir::tensor::EmptyOp
|
||||
createEmptyTensorFromShaped(mlir::IRRewriter& rewriter, mlir::Location loc, mlir::ShapedType shapedType) {
|
||||
|
||||
@@ -16,13 +16,13 @@ def onnxToPimTranspose : Pat<
|
||||
>;
|
||||
|
||||
def spatToPimVMM : Pat<
|
||||
(SpatWeightedVMMOp:$srcOpRes $weightIndex, $vector),
|
||||
(SpatVMMOp:$srcOpRes $weightIndex, $vector),
|
||||
(PimVMMOp $weightIndex, $vector,
|
||||
(NativeCodeCall<"onnx_mlir::getBestOutputTensorFromOperandsOrAllocate($_builder, $0.getDefiningOp())"> $srcOpRes))
|
||||
>;
|
||||
|
||||
def spatToPimMVM : Pat<
|
||||
(SpatWeightedMVMOp:$srcOpRes $weightIndex, $vector),
|
||||
(SpatMVMOp:$srcOpRes $weightIndex, $vector),
|
||||
(PimMVMOp $weightIndex, $vector,
|
||||
(NativeCodeCall<"onnx_mlir::getBestOutputTensorFromOperandsOrAllocate($_builder, $0.getDefiningOp())"> $srcOpRes))
|
||||
>;
|
||||
|
||||
@@ -252,25 +252,6 @@ static void lowerConcat(spatial::SpatConcatOp concatOp, IRRewriter& rewriter) {
|
||||
rewriter.replaceOp(concatOp, concatenated);
|
||||
}
|
||||
|
||||
static void lowerRemainingSpatialMathOps(func::FuncOp funcOp, IRRewriter& rewriter) {
|
||||
SmallVector<spatial::SpatWeightedVMMOp> wvmmOps;
|
||||
funcOp.walk([&](spatial::SpatWeightedVMMOp wvmmOp) {
|
||||
if (wvmmOp->getParentOfType<pim::PimCoreOp>() || wvmmOp->getParentOfType<pim::PimCoreBatchOp>())
|
||||
wvmmOps.push_back(wvmmOp);
|
||||
});
|
||||
|
||||
for (auto wvmmOp : wvmmOps) {
|
||||
rewriter.setInsertionPoint(wvmmOp);
|
||||
auto outputType = cast<ShapedType>(wvmmOp.getOutput().getType());
|
||||
Value outputBuffer = createEmptyTensorFromShaped(rewriter, wvmmOp.getLoc(), outputType).getResult();
|
||||
rewriter.replaceOpWithNewOp<pim::PimVMMOp>(wvmmOp,
|
||||
wvmmOp.getOutput().getType(),
|
||||
rewriter.getI32IntegerAttr(wvmmOp.getWeightIndex()),
|
||||
wvmmOp.getInput(),
|
||||
outputBuffer);
|
||||
}
|
||||
}
|
||||
|
||||
static void lowerMapOps(func::FuncOp funcOp, IRRewriter& rewriter) {
|
||||
SmallVector<spatial::SpatMapOp> mapOps;
|
||||
funcOp.walk([&](spatial::SpatMapOp mapOp) {
|
||||
@@ -736,7 +717,7 @@ void SpatialToPimPass::runOnOperation() {
|
||||
SmallVector<pim::PimCoreOp> coreOps;
|
||||
funcOp.walk([&](pim::PimCoreOp coreOp) { coreOps.push_back(coreOp); });
|
||||
for (auto coreOp : coreOps) {
|
||||
if (failed(applyPatternsGreedily(coreOp.getOperation(), frozenCoreBodyPatterns))) {
|
||||
if (failed(applyPartialConversion(coreOp.getOperation(), target, frozenCoreBodyPatterns))) {
|
||||
signalPassFailure();
|
||||
return;
|
||||
}
|
||||
@@ -745,15 +726,13 @@ void SpatialToPimPass::runOnOperation() {
|
||||
SmallVector<pim::PimCoreBatchOp> coreBatchOps;
|
||||
funcOp.walk([&](pim::PimCoreBatchOp coreBatchOp) { coreBatchOps.push_back(coreBatchOp); });
|
||||
for (auto coreBatchOp : coreBatchOps) {
|
||||
if (failed(applyPatternsGreedily(coreBatchOp.getOperation(), frozenCoreBodyPatterns))) {
|
||||
if (failed(applyPartialConversion(coreBatchOp.getOperation(), target, frozenCoreBodyPatterns))) {
|
||||
signalPassFailure();
|
||||
return;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
lowerRemainingSpatialMathOps(funcOp, rewriter);
|
||||
|
||||
RewritePatternSet channelPatterns(ctx);
|
||||
populateWithGenerated(channelPatterns);
|
||||
if (failed(applyPatternsGreedily(funcOp, std::move(channelPatterns)))) {
|
||||
|
||||
Reference in New Issue
Block a user