
Add GradVac #249

Status: Open. Wants to merge 31 commits into base: main.
Conversation

@EmileAydar commented Feb 10, 2025

[EDIT: This PR is now about GradVac. See comments below]

This pull request introduces a new GradNorm loss-balancing wrapper to the TorchJD library. The key changes are as follows:

  • A new module (located in torchjd/aggregation/gradnorm_wrapper.py) implements the GradNorm loss-balancing mechanism as described in Chen et al.'s ICML 2018 paper. Unlike other aggregators, this wrapper operates on a list of task losses and computes adaptive loss weights via a learned parameter vector.

  • The wrapper has been updated to accept a device parameter so that it runs seamlessly on both CPU and CUDA. Corresponding tests now run on multiple devices.
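For context, the loss-balancing mechanism from Chen et al.'s GradNorm paper can be sketched in a few lines. The following is an illustrative numpy sketch, not the actual `gradnorm_wrapper.py` code; `gradnorm_step` and its argument names are hypothetical:

```python
import numpy as np

def gradnorm_step(weights, grad_norms, loss_ratios, alpha=1.5, lr=0.025):
    """One GradNorm-style update of the task weights (illustrative sketch).

    grad_norms:  per-task gradient norms at the shared layer
    loss_ratios: L_i(t) / L_i(0), the loss ratio of each task
    """
    r = loss_ratios / loss_ratios.mean()      # relative inverse training rates
    weighted = weights * grad_norms           # weighted gradient norms G_W^(i)
    target = weighted.mean() * r ** alpha     # target norms (held constant)
    # subgradient of L_grad = sum_i |w_i * g_i - target_i| w.r.t. w_i
    step = np.sign(weighted - target) * grad_norms
    new_weights = weights - lr * step
    # renormalize so the weights keep summing to the number of tasks
    return new_weights * len(new_weights) / new_weights.sum()
```

As expected, a task whose gradient norm is smaller than average receives a larger weight after the update.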

The tests in tests/unit/aggregation/test_gradnorm_wrapper.py have been developed to ensure high coverage, as requested. They:

  • Validate basic functionality (forward pass, error handling, reset behavior).
  • Verify that the internal loss-scale parameters update correctly (including convergence behavior).
  • Handle edge cases such as zero gradients.
  • Integrate with existing TorchJD aggregators (e.g., Sum and MGDA) by re-weighting losses and stacking them into a matrix.
  • Run on both CPU and CUDA (with necessary adjustments for deterministic algorithm settings on CUDA).
Overall, the tests now achieve ~96% coverage for this module.
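As a side note on the CUDA determinism point: deterministic runs in PyTorch generally require both a cuBLAS environment variable and the `torch.use_deterministic_algorithms(True)` flag. A minimal sketch of the environment setup (the exact settings used in these tests may differ):

```shell
# Deterministic cuBLAS kernels require this workspace config (CUDA >= 10.2)
export CUBLAS_WORKSPACE_CONFIG=:4096:8
```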

The documentation has been updated with detailed docstrings and a usage example showing how to use GradNormWrapper both standalone and in combination with an aggregator like MGDA.

Changelog and Version Update:

  • The CHANGELOG has been updated in the [Unreleased] section and a new version header ([0.5.1] - 2025-02-10) has been added. The version in pyproject.toml is updated to 0.5.1.

Please review these changes at your convenience. If any further adjustments are needed, feel free to let me know. Once approved, I will merge this PR and update the repository accordingly.

Thank you for your time and consideration!

Best regards,
Emile

codecov bot commented Feb 10, 2025

Codecov Report

All modified and coverable lines are covered by tests ✅

Files with missing lines | Coverage Δ
src/torchjd/aggregation/gradvac.py | 100.00% <100.00%> (ø)

@ValerianRey (Contributor) commented

Thanks for your interest and effort on this PR!

It seems there is a misunderstanding about the role of Aggregators in TorchJD. An Aggregator should be applied to a Jacobian matrix, not a matrix of losses. To properly combine GradNorm with Jacobian-based methods (such as MGDA, UPGrad, etc.), I think the process should be:

  1. Reweight the losses using GradNorm.
  2. Apply torchjd.backward or torchjd.mtl_backward to the reweighted losses. This will compute the Jacobian of these losses and aggregate it using some aggregator.

Do you think your intended use-case would work by following those two steps?
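Those two steps can be illustrated on a toy problem. The numpy sketch below uses analytic gradients and a plain row mean as a stand-in aggregator; TorchJD would instead compute the Jacobian via autodiff and aggregate it with an aggregator such as MGDA or UPGrad:

```python
import numpy as np

# Step 1: weights produced by a loss-balancing scheme such as GradNorm
loss_weights = np.array([0.3, 0.7])

# Gradients of two toy losses at x = (1, -2): ||x||^2 and x[0]
x = np.array([1.0, -2.0])
grads = np.stack([2 * x, np.array([1.0, 0.0])])

# Step 2: the Jacobian of the reweighted losses is aggregated into one update
# (a plain row mean stands in for an aggregator like MGDA or UPGrad)
jacobian = loss_weights[:, None] * grads
update = jacobian.mean(axis=0)
```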

TorchJD is specifically designed to handle step 2, whereas step 1 (reweighting the losses) is independent of the Jacobian descent algorithm. With GradNorm, some gradients (i.e., Jacobian rows) are used to balance the losses, but this balancing also depends on the loss values and some preserved states.

For these reasons, we won’t be able to merge this PR. Additionally, we don’t see TorchJD incorporating methods like GradNorm in the future, as maintaining them falls outside our intended scope and would require more effort than we’re able to commit.

@EmileAydar (Author) commented

Dear Valerian,

Thank you for your detailed response. I agree with your explanation regarding the role of Aggregators in TorchJD.
My initial proposal for GradNorm was motivated by the need to adaptively balance losses, particularly in scenarios where newly introduced objectives (such as fairness constraints in multiobjective optimization) can naturally be on a smaller scale compared to the primary loss. The idea was to use GradNorm as a pre-aggregation wrapper to reweight the losses and balance their contribution to ensure a fairer aggregation step.

I understand and respect your decision not to integrate GradNorm given that it is not an aggregation heuristic in itself. On a related note, I have been preparing another pull request that focuses on an aggregator implementation, which I believe is more aligned with TorchJD’s intended design.

Thank you again for your prompt and comprehensive feedback. I appreciate the work you and the team do on TorchJD, and I look forward to contributing further in the future.

Warm regards,
Emile

@EmileAydar (Author) commented

---Update---
Following my previous comment, I have committed the files for an aggregator implementation I've been working on for TorchJD: GradVac.

GradVac is an aggregator that modifies gradients when the observed cosine similarity between tasks falls below a desired target. This target can be set as a constant, or computed adaptively via an exponential moving average scheme when left as None.
In contrast, PCGrad addresses gradient conflicts simply by projecting gradients to remove interference. In fact, when GradVac's target is set to zero, its behavior essentially replicates PCGrad, which only removes negative dot products.
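For illustration, the core Gradient Vaccine adjustment for a single pair of gradients can be sketched as follows. This is a standalone numpy sketch with a hypothetical function name, following the paper's closed-form update; with `target=0.0` it reduces to the PCGrad projection:

```python
import numpy as np

def gradvac_adjust(g_i, g_j, target):
    """Nudge g_i toward g_j when their cosine similarity is below `target`.

    With target=0 this reduces to the PCGrad projection."""
    cos = g_i @ g_j / (np.linalg.norm(g_i) * np.linalg.norm(g_j))
    if cos >= target:
        return g_i  # similarity already meets the target: leave g_i untouched
    # closed-form coefficient that raises cos(g_i', g_j) exactly to `target`
    coef = (np.linalg.norm(g_i)
            * (target * np.sqrt(1 - cos ** 2) - cos * np.sqrt(1 - target ** 2))
            / (np.linalg.norm(g_j) * np.sqrt(1 - target ** 2)))
    return g_i + coef * g_j
```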

Would you prefer that I open a new pull request specifically for GradVac, or should I continue the discussion here to propose merging GradVac in place of GradNorm?

Again, thank you for your time and support!

Warm regards,
Emile

@PierreQuinton (Contributor) commented Feb 12, 2025

@EmileAydar Thank you for the proposition! We were unaware of the GradVac aggregator and might consider implementing it in the future; this will require some extra work on our side, for instance determining the properties it satisfies. In the meantime, since you seem to be facing conflicts that PCGrad cannot solve (unless there are only two objectives), you could try training with an aggregator that satisfies the non-conflicting property; you can find the list here. I would personally recommend UPGrad or DualProj, as the others may not find Pareto-optimal points. Another way to pick your aggregator is to check out the trajectories on pages 15-16 of Jacobian descent for multi-objective optimization and select the aggregator that feels most natural for your problem.

We will not have time to add GradVac to the list of aggregators in the short term, so you can also inherit from Aggregator locally and experiment with it. If you get encouraging results, let us know; in particular, if it compares well against all other aggregators on your task, that would be relevant information for us!

@EmileAydar (Author) commented Feb 15, 2025

@PierreQuinton Thank you for your detailed reply.

In my experiments, I noticed that:

  • MGDA is particularly sensitive to small gradients (like those from fairness considerations), unlike the other non-conflicting aggregators; this sometimes results in numerical instabilities, manifested as exploding or undefined Jacobian values. Similar issues were observed with PCGrad, while UPGrad demonstrated superior stability.
  • For one case, I added a small positive epsilon to the diagonal of NashMTL’s gramian. This ensured the matrix remained positive definite, maintained the hard positive-weight constraints, and surprisingly improved its capacity to find balanced Pareto tradeoffs.

I’d be glad to share more detailed, quantitative results if you think that would be helpful.
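The epsilon trick above is standard Tikhonov-style regularization of a Gram matrix; a minimal numpy sketch of the idea (the epsilon value is illustrative, not the one used in the experiments):

```python
import numpy as np

# A rank-deficient Gram matrix, as arises when task gradients are parallel
gramian = np.array([[1.0, 1.0],
                    [1.0, 1.0]])

eps = 1e-6
regularized = gramian + eps * np.eye(2)  # jitter on the diagonal

# The regularized matrix is strictly positive definite
assert np.all(np.linalg.eigvalsh(regularized) > 0.0)
```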

I am also exploring EPO aggregators that integrate a stakeholder preference vector, which I think could be of interest for TorchJD users, especially engineers in business contexts. For example, PMGDA shows potential, along with approaches like Exact Pareto Optimization and Pareto Multi-Task Learning.

Thanks also for considering GradVac for a future update. Although I understand it might be reimplemented on your side, I’ve developed an extension based on your PCGrad code and will be using it in my experiments. I’m happy to provide detailed feedback on its performance, and an acknowledgment in your documentation would be appreciated if my contribution proves valuable.

@ValerianRey (Contributor) commented

Hi @EmileAydar!
Thanks for your work on Gradient Vaccine. It's good to see a first implementation for this!

At the moment, TorchJD only supports stateless (immutable) aggregators. Gradient Vaccine is based on some exponential moving average of the cosine similarities between gradients, which is a state.

We are thinking about adding support for stateful methods in the near future, and Gradient Vaccine could definitely be a good candidate to test this.

So we will keep this PR open for now until we make progress on the stateful structure, and we will get back to it afterwards.
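Once stateful aggregators are supported, the state Gradient Vaccine needs could look like this hypothetical helper; it is purely illustrative (not the TorchJD API) and only tracks the exponential moving average of pairwise cosine similarities between Jacobian rows:

```python
import numpy as np

class CosineEMA:
    """Hypothetical state holder for a stateful GradVac-style aggregator:
    tracks an exponential moving average of pairwise cosine similarities
    between Jacobian rows (not part of the current TorchJD API)."""

    def __init__(self, n_tasks, beta=0.01):
        self.beta = beta
        self.ema = np.zeros((n_tasks, n_tasks))

    def update(self, jacobian):
        # cosine similarity matrix between the rows of the Jacobian
        norms = np.linalg.norm(jacobian, axis=1, keepdims=True)
        cosines = (jacobian @ jacobian.T) / (norms * norms.T)
        self.ema = (1 - self.beta) * self.ema + self.beta * cosines
        return self.ema
```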

@ValerianRey ValerianRey changed the title GradNorm Loss Balancing Integration in TorchJD Add GradVac Feb 21, 2025
@ValerianRey ValerianRey added feat New feature or request package: aggregation labels Feb 21, 2025
@ValerianRey (Contributor) commented

> I am also exploring EPO aggregators that integrate a stakeholder preference vector, that I think could be of interest for TorchJD users, especially for engineers in business contexts. For example, PMGDA shows potential, along with approaches like Exact Pareto Optimization and Pareto Multi-Task Learning.

Thanks for these references!

I think Pareto Multi-Task Learning can't be simply integrated into TorchJD (they don't have the same objective as us).

At first glance, it seems that PMGDA is based on a stateless aggregator, so it should be possible to integrate it into TorchJD. We would however need a more thorough understanding of its theoretical properties before even trying to integrate it.

As for Exact Pareto Optimization, I still have to read the paper. I'll update this comment when I find time for this.
