
Model Exporting (.pt2) #2732


Open
isaaccorley opened this issue Apr 15, 2025 · 5 comments
Labels
models Models and pretrained weights

Comments

@isaaccorley
Collaborator

isaaccorley commented Apr 15, 2025

Summary

It's common in production environments to export a torch model to a file that can be loaded without the model code (only the checkpoint is needed). We should consider storing additional versions of our weights on HuggingFace, exported to the .pt2 archive format. If I understand correctly, this avoids the security issues with pickling, and the .pt2 archive can also store metadata such as transforms and other hyperparameters.

A common example of this is:

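import torch
from torchgeo.models import Unet_Weights, unet

# Build the U-Net with the pretrained FTW weights
weights = Unet_Weights.SENTINEL2_3CLASS_FTW
model = unet(weights=weights)
model.eval()  # export inference behavior, not training-mode BatchNorm/Dropout

# Trace the model with an example input (8 bands, 256x256 pixels) and save it
args = (torch.randn(1, 8, 256, 256),)
exported = torch.export.export(model, args=args)
torch.export.save(exported, 'model.pt2')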

The model can then be loaded with only torch as a dependency (no model code or other packages needed!) like so:

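import torch

# Only torch is needed: the graph and weights are read from the archive
model = torch.export.load("model.pt2").module()
# Without dynamic_shapes at export time, use the same shape as the example input
x = torch.randn(1, 8, 256, 256)
print(model(x).shape)  # torch.Size([1, 3, 256, 256])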

CC @rbavery @ljstrnadiii @jiayuasu @calebrob6 @adamjstewart

@isaaccorley isaaccorley self-assigned this Apr 15, 2025
@adamjstewart
Collaborator

I really like this in general. I am marginally concerned about this though:

[Image: screenshot of the torch.export documentation callout warning about breaking changes]

If this is something Wherobots needs, I'm inclined to move forward with it. However, if the main motivation is to get around the security issues with pickling, note that torch.hub.load_state_dict_from_url has the following option, which mitigates some (but not all) of those issues:

weights_only (bool, optional) – If True, only weights will be loaded and no complex pickled objects. Recommended for untrusted sources. See load() for more details.
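For comparison, a minimal sketch of the state-dict route with `weights_only=True`, using a local `torch.save`/`torch.load` round trip in place of a download (`load_state_dict_from_url` takes the same flag, per the docs quoted above). The point it illustrates is that the model class must still be importable to rebuild the module, which is the dependency .pt2 avoids; `TinySegmenter` is a hypothetical stand-in:

```python
import os
import tempfile

import torch
from torch import nn


class TinySegmenter(nn.Module):
    """Hypothetical stand-in model; this class must exist at load time."""

    def __init__(self) -> None:
        super().__init__()
        self.conv = nn.Conv2d(8, 3, kernel_size=3, padding=1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.conv(x)


path = os.path.join(tempfile.mkdtemp(), "weights.pt")
torch.save(TinySegmenter().state_dict(), path)

# weights_only=True restricts unpickling to tensors, primitive types, and
# plain containers, so arbitrary pickled objects in the file are rejected.
state_dict = torch.load(path, map_location="cpu", weights_only=True)

# Unlike .pt2, the model code is still needed to rebuild the module.
model = TinySegmenter()
model.load_state_dict(state_dict)
model.eval()
```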

@adamjstewart adamjstewart added the models Models and pretrained weights label Apr 16, 2025
@isaaccorley
Collaborator Author

This is mainly needed for our production environments. Note that I can simply write a script that loops through all TorchGeo pretrained models and exports each one to .pt2 format. We don't need to replace the .pt files on HuggingFace, only offer .pt2 as an additional option; I can take on the work of managing them on HuggingFace.

@rbavery

rbavery commented Apr 16, 2025

Ditto what Isaac said! Some additional reasons .pt2 is really helpful for us and others cataloguing models and using them for inference:

  • We can store GeoAI-specific metadata in a predictable way. We'd like to define a consistent way to store STAC MLM metadata within a .pt2 archive, probably in the archive's extra/ directory. We could specify that MLM JSON items be stored in a JSON file with a .mlm extension for discoverability, or require a name like mlm_item_<model_name>.json.
  • We can store metadata about data provenance and other domain-specific information, in addition to metadata about model transforms and hyperparameters.
  • We'd like to support a single archive that packages weights, inference-optimized model artifacts, and metadata in a standardized way.
  • There are other features of the .pt2 spec that could be useful to us later, such as storing model server request and response objects and sample data inputs for reproducibility and documentation.

On the breaking-changes callout: I expect the spec and APIs for torch.export to evolve, partly based on feedback from users like us and TorchGeo. But I think we could adapt our internal use and implementation of torch.export without user-facing changes in TorchGeo.

@adamjstewart
Collaborator

Alright, let's move forward with it. We can still rehost all weights on TorchGeo's HuggingFace. Note that we've been following the naming convention below, and I would like to keep it:

check_hash (bool, optional) – If True, the filename part of the URL should follow the naming convention filename-<sha256>.ext where <sha256> is the first eight or more digits of the SHA256 hash of the contents of the file. The hash is used to ensure unique names and to verify the contents of the file. Default: False
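A minimal sketch of applying that convention to exported archives with only the standard library. `sha256_prefix`, `rename_with_hash`, and `hash_prefix_from_name` are hypothetical helpers; the regex in the last one mirrors my understanding of how `torch.hub` extracts the prefix from a filename, rather than calling torch.hub itself. Since `torch.export.load` has no hash check of its own, a consumer could use the same logic to verify a downloaded .pt2 file:

```python
import hashlib
import re
from pathlib import Path
from typing import Optional


def sha256_prefix(path: Path, n: int = 8) -> str:
    """Return the first `n` hex digits of the file's SHA256 digest."""
    digest = hashlib.sha256()
    with path.open("rb") as f:
        # Read in 1 MiB chunks so large checkpoints are not loaded at once.
        for chunk in iter(lambda: f.read(1 << 20), b""):
            digest.update(chunk)
    return digest.hexdigest()[:n]


def rename_with_hash(path: Path, n: int = 8) -> Path:
    """Rename model.pt2 -> model-<sha256 prefix>.pt2 and return the new path."""
    new_path = path.with_name(f"{path.stem}-{sha256_prefix(path, n)}{path.suffix}")
    return path.rename(new_path)


def hash_prefix_from_name(name: str) -> Optional[str]:
    """Extract the '-<hex>.' hash prefix from a filename, if present."""
    match = re.search(r"-([a-f0-9]*)\.", name)
    return match.group(1) if match else None
```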

@rbavery

rbavery commented Apr 22, 2025

I spoke with Angela Yi from the PyTorch team on the PyTorch Slack about improving support for storing nn.Module transforms from Kornia in the same .pt2 archive as the model. This would allow loading a model and its inference-only transforms (or any kind of transforms) together as a single callable, making it easier to use the model immediately without figuring out how to run the correct transforms. She said she'll look into it and that it seems like a generally useful feature for PyTorch. cc @isaaccorley

For now, I talked with Isaac, who is working on creating the .pt2 archives, and we agreed to skip storing transforms in .pt2 and add them back if and when that feature becomes available upstream.
