Effortless management for your ML models.
🚀 Quick start 📦 Examples 📚 Documentation 💬 Get help on Discord 📋 License: Apache 2.0
Lightning Models is a streamlined toolkit for effortlessly saving, loading, and managing your model checkpoints. Designed to simplify the entire model lifecycle—from training and inference to sharing, deployment, and cloud integration—Lightning Models supports any framework that produces model checkpoints, including but not limited to PyTorch Lightning.
✅ Seamless Model Saving & Loading ✅ Robust Checkpoint Management ✅ Cloud Integration Out of the Box ✅ Versatile Across Frameworks
Install Lightning Models via pip (more installation options below):
pip install -U litmodels
Or install directly from source:
pip install https://github.com/Lightning-AI/models/archive/refs/heads/main.zip
Lightning Models offers a simple API to manage your model checkpoints.
Train your model using your preferred framework (our fist examples show scikit-learn
) and then save your best checkpoint with a single function call.
from sklearn import datasets, model_selection, svm
from litmodels import upload_model
# Unique model identifier: <organization>/<teamspace>/<model-name>
MY_MODEL_NAME = "your_org/your_team/sklearn-svm-model"
# Load example dataset
iris = datasets.load_iris()
X, y = iris.data, iris.target
# Split dataset into training and test sets
X_train, X_test, y_train, y_test = model_selection.train_test_split(
X, y, test_size=0.2, random_state=42
)
# Train a simple SVC model
model = svm.SVC()
model.fit(X_train, y_train)
# Upload the saved model using litmodels
upload_model(model=model, name=MY_MODEL_NAME)
from litmodels import load_model
# Unique model identifier: <organization>/<teamspace>/<model-name>
MY_MODEL_NAME = "your_org/your_team/sklearn-svm-model"
# Download and load the model file from cloud storage
model = load_model(name=MY_MODEL_NAME, download_dir="my_models")
# Example: run inference with the loaded model
sample_input = [[5.1, 3.5, 1.4, 0.2]]
prediction = model.predict(sample_input)
print(f"Prediction: {prediction}")
Next examples demonstrate seamless PyTorch Lightning integration with Lightning Models.
from lightning import Trainer
from litmodels import upload_model
from litmodels.demos import BoringModel
# Define the model name - this should be unique to your model
MY_MODEL_NAME = "<organization>/<teamspace>/<model-name>"
class LitModel(BoringModel):
def training_step(self, batch, batch_idx: int):
loss = self.step(batch)
# logging the computed loss
self.log("train_loss", loss)
return {"loss": loss}
# Configure Lightning Trainer
trainer = Trainer(max_epochs=2)
# Define the model and train it
trainer.fit(LitModel())
# Upload the best model to cloud storage
checkpoint_path = getattr(trainer.checkpoint_callback, "best_model_path")
upload_model(model=checkpoint_path, name=MY_MODEL_NAME)
from lightning import Trainer
from litmodels import download_model
from litmodels.demos import BoringModel
# Define the model name - this should be unique to your model
MY_MODEL_NAME = "<organization>/<teamspace>/<model-name>:<model-version>"
class LitModel(BoringModel):
def training_step(self, batch, batch_idx: int):
loss = self.step(batch)
# logging the computed loss
self.log("train_loss", loss)
return {"loss": loss}
# Load the model from cloud storage
checkpoint_path = download_model(name=MY_MODEL_NAME, download_dir="my_models")
print(f"model: {checkpoint_path}")
# Train the model with extended training period
trainer = Trainer(max_epochs=4)
trainer.fit(LitModel(), ckpt_path=checkpoint_path)
Advanced Checkpointing Workflow
Enhance your training process with an automatic checkpointing callback that uploads the best model at the end of each epoch. While the example uses PyTorch Lightning callbacks, similar workflows can be implemented in any training loop that produces checkpoints.
import os
import torch.utils.data as data
import torchvision as tv
from lightning import Callback, Trainer
from litmodels import upload_model
from litmodels.demos import BoringModel
# Define the model name - this should be unique to your model
MY_MODEL_NAME = "<organization>/<teamspace>/<model-name>"
class LitModel(BoringModel):
def training_step(self, batch, batch_idx: int):
loss = self.step(batch)
# logging the computed loss
self.log("train_loss", loss)
return {"loss": loss}
class UploadModelCallback(Callback):
def on_train_epoch_end(self, trainer, pl_module):
# Get the best model path from the checkpoint callback
checkpoint_path = getattr(trainer.checkpoint_callback, "best_model_path")
if checkpoint_path and os.path.exists(checkpoint_path):
upload_model(model=checkpoint_path, name=MY_MODEL_NAME)
dataset = tv.datasets.MNIST(".", download=True, transform=tv.transforms.ToTensor())
train, val = data.random_split(dataset, [55000, 5000])
trainer = Trainer(
max_epochs=2,
callbacks=[UploadModelCallback()],
)
trainer.fit(
LitModel(),
data.DataLoader(train, batch_size=256),
data.DataLoader(val, batch_size=256),
)