diff --git a/src/PIM/Conversion/ONNXToSpatial/ONNXToSpatialPass.cpp b/src/PIM/Conversion/ONNXToSpatial/ONNXToSpatialPass.cpp index b503837..f6e56c6 100644 --- a/src/PIM/Conversion/ONNXToSpatial/ONNXToSpatialPass.cpp +++ b/src/PIM/Conversion/ONNXToSpatial/ONNXToSpatialPass.cpp @@ -239,8 +239,9 @@ void ONNXToSpatialPass::mergeTriviallyConnectedComputes(func::FuncOp funcOp) { for (auto compute : funcOp.getOps()) if (compute->hasOneUse()) { - auto user = *compute->getUsers().begin(); - if (llvm::isa(user) && user->getNumOperands() == 1) + auto user = dyn_cast(*compute->getUsers().begin()); + + if (user && user.getInputs().size() == 1) trivialComputes.push_back(compute); } @@ -255,19 +256,48 @@ void ONNXToSpatialPass::mergeTriviallyConnectedComputes(func::FuncOp funcOp) { auto child = cast(*compute->getUsers().begin()); rewriter.setInsertionPointAfter(compute.getOperation()); + auto newCompute = spatial::SpatWeightedCompute::create(rewriter, loc, child.getResultTypes(), compute.getOperands()); newCompute.getProperties().setOperandSegmentSizes( {static_cast(compute.getWeights().size()), static_cast(compute.getInputs().size())}); IRMapping mapper; + auto weightMutableIter = newCompute.getWeightsMutable(); + for (auto weight : child.getWeights()) { + auto founded = llvm::find(newCompute.getWeights(), weight); + if (founded == newCompute.getWeights().end()) { + weightMutableIter.append(weight); + auto last = weightMutableIter.end(); + last = std::prev(last, 1); + mapper.map(weight, last->get()); + } + else { + mapper.map(weight, *founded); + } + } + compute.getBodyRegion().cloneInto(&newCompute.getBodyRegion(), mapper); auto newTerminator = newCompute.getBody().front().getTerminator(); mapper.map(*child.getBody().front().getArguments().begin(), newTerminator->getOperand(0)); newTerminator->erase(); rewriter.setInsertionPoint(&newCompute.getBody().front(), newCompute.getBody().front().end()); - for (auto& op : child.getBody().front()) - rewriter.clone(op, mapper); + for (auto& op : child.getBody().front()) { + auto newInst = rewriter.clone(op, mapper); + + if (auto vmOp = llvm::dyn_cast(newInst)) { + auto oldIndex = vmOp.getWeightIndex(); + auto newWeight = mapper.lookup(*std::next(child.getWeights().begin(), oldIndex)); + auto newIndex = std::distance(newCompute.getWeights().begin(), llvm::find(newCompute.getWeights(), newWeight)); + vmOp.setWeightIndex(newIndex); + } + if (auto vmOp = llvm::dyn_cast(newInst)) { + auto oldIndex = vmOp.getWeightIndex(); + auto newWeight = mapper.lookup(*std::next(child.getWeights().begin(), oldIndex)); + auto newIndex = std::distance(newCompute.getWeights().begin(), llvm::find(newCompute.getWeights(), newWeight)); + vmOp.setWeightIndex(newIndex); + } + } child.replaceAllUsesWith(newCompute); toErase.insert(child); @@ -277,8 +307,8 @@ void ONNXToSpatialPass::mergeTriviallyConnectedComputes(func::FuncOp funcOp) { toErase.insert(compute); if (newCompute->hasOneUse()) { - auto user = *newCompute->getUsers().begin(); - if (llvm::isa(user) && user->getNumOperands() == 1) + auto user = dyn_cast(*newCompute->getUsers().begin()); + if (user && user.getInputs().size() == 1) trivialComputes.push_back(newCompute); } } diff --git a/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/MergeComputeNodesPass.cpp b/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/MergeComputeNodesPass.cpp index 83dce8a..075e498 100644 --- a/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/MergeComputeNodesPass.cpp +++ b/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/MergeComputeNodesPass.cpp @@ -219,7 +219,6 @@ private: auto weightMutableIter = toCompute.getWeightsMutable(); for (auto weight : fromCompute.getWeights()) { - // TODO non clonare weight gia' presenti e poi vanno rimappate le nuove OP con i nuovi weight auto founded = llvm::find(toCompute.getWeights(), weight); if (founded == toCompute.getWeights().end()) { size_t sizeW = toCompute.getWeights().size();