Add GradVac
#249
base: main
Conversation
Codecov Report: All modified and coverable lines are covered by tests ✅
Thanks for your interest and effort on this PR! It seems there is a misunderstanding about the role of Aggregators in TorchJD.

Do you think your intended use-case would work by following those two steps? TorchJD is specifically designed to handle step 2, whereas step 1 (reweighting the losses) is independent of the Jacobian descent algorithm. With GradNorm, some gradients (i.e., Jacobian rows) are used to balance the losses, but this balancing also depends on the loss values and some preserved state. For these reasons, we won't be able to merge this PR. Additionally, we don't see TorchJD incorporating methods like GradNorm in the future, as maintaining them falls outside our intended scope and would require more effort than we're able to commit.
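To illustrate the separation between the two steps, here is a minimal sketch of reweighting losses outside TorchJD and then aggregating with one of its aggregators. The weights, model, and data are made up, and the `torchjd.backward` call is assumed to follow the signature shown in recent documentation examples; check the docs of your installed version.

```python
import torch
import torch.nn.functional as F
from torchjd import backward
from torchjd.aggregation import UPGrad

model = torch.nn.Linear(10, 2)           # toy shared model with two task outputs
x = torch.randn(8, 10)
y1, y2 = torch.randn(8), torch.randn(8)

out = model(x)
loss1 = F.mse_loss(out[:, 0], y1)
loss2 = F.mse_loss(out[:, 1], y2)

# Step 1: reweight the losses with any external balancing scheme
# (a GradNorm-style weighting would live here); constants for illustration.
w1, w2 = 0.7, 1.3

# Step 2: let TorchJD aggregate the Jacobian of the (reweighted) losses.
backward([w1 * loss1, w2 * loss2], UPGrad())
```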
Dear Valerian,

Thank you for your detailed response. I agree with your explanation regarding the role of Aggregators in TorchJD, and I understand and respect your decision not to integrate GradNorm, given that it is not an aggregation heuristic in itself. On a related note, I have been preparing another pull request that focuses on an aggregator implementation, which I believe is more aligned with TorchJD's intended design.

Thank you again for your prompt and comprehensive feedback. I appreciate the work you and the team do on TorchJD, and I look forward to contributing further in the future.

Warm regards,
--- Update ---

GradVac is an aggregator that modifies gradients when the observed cosine similarity between tasks falls below a desired target. This target can be set as a constant, or it can be computed adaptively with an exponential moving average scheme when the target is left as None. Would you prefer that I open a new pull request specifically for GradVac, or should I continue the discussion here and propose merging GradVac in place of GradNorm?

Again, thank you for your time and support!

Warm regards,
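For concreteness, the pairwise adjustment at the heart of GradVac looks roughly like the following standalone sketch (my reading of Wang et al., 2021; the function and argument names are illustrative, and the actual code in this PR may differ in details):

```python
import torch

def gradvac_adjust(g_i: torch.Tensor, g_j: torch.Tensor,
                   target: float, eps: float = 1e-8) -> torch.Tensor:
    """Nudge g_i so that its cosine similarity with g_j reaches `target`
    when the observed similarity falls below it."""
    cos = torch.dot(g_i, g_j) / (g_i.norm() * g_j.norm() + eps)
    if cos < target:
        # Closed-form coefficient that lifts cos(g_i', g_j) up to `target`.
        sin = (1 - cos ** 2).clamp_min(0.0).sqrt()
        sin_t = (1 - target ** 2) ** 0.5
        coef = g_i.norm() * (target * sin - cos * sin_t) / (g_j.norm() * sin_t + eps)
        return g_i + coef * g_j
    return g_i

# When no fixed target is given, the per-pair target is tracked with an EMA:
#     target_ij <- (1 - beta) * target_ij + beta * cos(g_i, g_j)
```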
@EmileAydar Thank you for the proposition! We were unaware of the GradVac aggregator and might consider implementing it in the future. This will require some extra work on our side, for instance determining the properties it satisfies.

In the meantime, since you seem to be facing conflicts that PCGrad cannot resolve (unless there are only two objectives), you could try training with an aggregator that satisfies the non-conflicting property; you can find the list here. I would personally recommend UPGrad or DualProj, as the others may not find Pareto optimal points. Another way to pick your aggregator is to check out the trajectories on pages 15-16 of Jacobian descent for multi-objective optimization and select the aggregator whose behavior feels most natural for your problem.

We will not have time to add GradVac to the list of aggregators in the short term, so you can also inherit from the Aggregator base class and implement it yourself.
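For reference, subclassing could look roughly like this. It is a minimal sketch that assumes the base class is exposed as `torchjd.aggregation.Aggregator` and that the method to override is `forward`, mapping the m×n Jacobian to an n-dimensional aggregated vector; check the Aggregator documentation of your installed version.

```python
import torch
from torchjd.aggregation import Aggregator

class RowMeanAggregator(Aggregator):
    """Toy custom aggregator that simply averages the rows of the Jacobian."""

    def forward(self, matrix: torch.Tensor) -> torch.Tensor:
        return matrix.mean(dim=0)
```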
@PierreQuinton Thank you for your detailed reply. In my experiments, I noticed that:
I am also exploring EPO aggregators that integrate a stakeholder preference vector, which I think could be of interest for TorchJD users, especially for engineers in business contexts. Thanks also for considering GradVac for a future update. Although I understand it might be reimplemented on your side, I've developed an extension based on your PCGrad code and will be using it in my experiments. I'm happy to provide detailed feedback on its performance, and an acknowledgment in your documentation would be appreciated if my contribution proves valuable.
Hi @EmileAydar! At the moment, TorchJD only supports stateless (immutable) aggregators. Gradient Vaccine is based on an exponential moving average of the cosine similarities between gradients, which is a state. We are thinking about adding support for stateful methods in the near future, and Gradient Vaccine could definitely be a good candidate to test this. So we will keep this PR open for now until we make progress on the stateful structure, and we will get back to it afterwards.
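To make the statefulness concrete, here is a tiny illustrative sketch (class and attribute names are made up) of the quantity that would have to persist between aggregation calls:

```python
class EmaSimilarityTarget:
    """Running target similarity for one pair of tasks (illustration only)."""

    def __init__(self, beta: float = 0.01):
        self.beta = beta
        self.value = 0.0  # carried over between steps -> this is the mutable state

    def update(self, cos_sim: float) -> float:
        self.value = (1.0 - self.beta) * self.value + self.beta * cos_sim
        return self.value
```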
Thanks for these references! I think Pareto Multi-Task Learning can't simply be integrated into TorchJD (they don't have the same objective as us). At first glance, it seems that PMGDA is based on a stateless aggregator, so it should be possible to integrate it into TorchJD. We would, however, need a more thorough understanding of its theoretical properties before even trying to integrate it. As for Exact Pareto Optimization, I still have to read the paper. I'll update this comment when I find time for this.
[EDIT: This PR is now about GradVac. See comments below]
This pull request introduces a new GradNorm loss-balancing wrapper to the TorchJD library. The key changes are as follows:
A new module (located in torchjd/aggregation/gradnorm_wrapper.py) implements the GradNorm loss-balancing mechanism as described in Chen et al.'s ICML 2018 paper. Unlike other aggregators, this wrapper operates on a list of task losses and computes adaptive loss weights via a learned parameter vector (see the illustrative sketch after this list).
The wrapper has been updated to accept a device parameter so that it runs seamlessly on both CPU and CUDA. Corresponding tests now run on multiple devices.
The tests in tests/unit/aggregation/test_gradnorm_wrapper.py have been developed to ensure high coverage, as requested.
The documentation has been updated with detailed docstrings and a usage example showing how to use GradNormWrapper both standalone and in combination with an aggregator like MGDA.
Changelog and Version Update:
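For readers unfamiliar with GradNorm, here is a minimal standalone sketch of the weight-update step that the wrapper described in the first item is based on. Function and argument names are illustrative; this is a sketch of Chen et al. (2018), not the code in torchjd/aggregation/gradnorm_wrapper.py.

```python
import torch

def gradnorm_weight_step(losses, init_losses, weights, shared_params,
                         alpha: float = 1.5, lr: float = 0.025):
    """One GradNorm weight update: `weights` is a learnable vector with one
    entry per task, `shared_params` are the shared parameters whose gradient
    norms are being balanced."""
    # Gradient norm of each weighted task loss w.r.t. the shared parameters.
    norms = []
    for w_i, loss_i in zip(weights, losses):
        grads = torch.autograd.grad(w_i * loss_i, shared_params,
                                    retain_graph=True, create_graph=True)
        norms.append(torch.cat([g.flatten() for g in grads]).norm())
    norms = torch.stack(norms)

    # Inverse training rates relative to the initial losses, and target norms.
    ratios = torch.stack([l.detach() / l0 for l, l0 in zip(losses, init_losses)])
    inverse_rates = ratios / ratios.mean()
    target = (norms.mean() * inverse_rates ** alpha).detach()

    # Pull each gradient norm towards its target, then renormalize the weights
    # so that they keep summing to the number of tasks.
    balance_loss = (norms - target).abs().sum()
    grad_w = torch.autograd.grad(balance_loss, weights)[0]
    with torch.no_grad():
        weights -= lr * grad_w
        weights *= len(losses) / weights.sum()
    return weights
```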
Please review these changes at your convenience. If any further adjustments are needed, feel free to let me know. Once approved, I will merge this PR and update the repository accordingly.
Thank you for your time and consideration!
Best regards,
Emile