diff --git a/.github/workflows/code-checks.yaml b/.github/workflows/code-checks.yaml index 83d01d9ed..dbe378c63 100644 --- a/.github/workflows/code-checks.yaml +++ b/.github/workflows/code-checks.yaml @@ -21,6 +21,7 @@ jobs: --exclude-dir='.venv' --exclude-dir='.mnist-pytorch' --exclude-dir='.mnist-keras' + --exclude-dir='.huggingface' --exclude-dir='docs' --exclude-dir='flower-client' --exclude='tests.py' diff --git a/examples/huggingface/.dockerignore b/examples/huggingface/.dockerignore new file mode 100644 index 000000000..8ba9024ad --- /dev/null +++ b/examples/huggingface/.dockerignore @@ -0,0 +1,4 @@ +data +seed.npz +*.tgz +*.tar.gz \ No newline at end of file diff --git a/examples/huggingface/.gitignore b/examples/huggingface/.gitignore new file mode 100644 index 000000000..df80c4828 --- /dev/null +++ b/examples/huggingface/.gitignore @@ -0,0 +1,6 @@ +data +*.npz +*.tgz +*.tar.gz +.huggingface +client.yaml \ No newline at end of file diff --git a/examples/huggingface/README.rst b/examples/huggingface/README.rst new file mode 100644 index 000000000..3d5653b7b --- /dev/null +++ b/examples/huggingface/README.rst @@ -0,0 +1,147 @@ +Hugging Face Transformer Example +-------------------------------- + +This is an example project that demonstrates how one can make use of the Hugging Face Transformers library in FEDn. +In this example, a pre-trained BERT-tiny model from Hugging Face is fine-tuned to perform spam detection +on the Enron spam email dataset. + +Email communication often contains personal and sensitive information, and privacy regulations make it +impossible to collect the data to a central storage for model training. +Federated learning is a privacy preserving machine learning technique that enables the training of models on decentralized data sources. +Fine-tuning large language models (LLMs) on various data sources enhances both accuracy and generalizability. +In this example, the Enron email spam dataset is split among two clients. 
The BERT-tiny model is fine-tuned on the client data using +federated learning to predict whether an email is spam or not. +Execute the following steps to run the example: + +Prerequisites +------------- + +Using FEDn Studio: + +- `Python 3.8, 3.9, 3.10 or 3.11 `__ +- `A FEDn Studio account `__ + +If using pseudo-distributed mode with docker-compose: + +- `Docker `__ +- `Docker Compose `__ + +Creating the compute package and seed model +------------------------------------------- + +Install fedn: + +.. code-block:: + + pip install fedn + +Clone this repository, then navigate to this directory: + +.. code-block:: + + git clone https://github.com/scaleoutsystems/fedn.git + cd fedn/examples/huggingface + +Create the compute package: + +.. code-block:: + + fedn package create --path client + +This should create a file 'package.tgz' in the project folder. + +Next, generate a seed model (the first model in a global model trail): + +.. code-block:: + + fedn run build --path client + +This will create a seed model called 'seed.npz' in the root of the project. This step will take a few minutes, depending on hardware and internet connection (builds a virtualenv). + + + +Using FEDn Studio (recommended) +------------------------------- + +Follow the instructions to register for FEDn Studio and start a project (https://fedn.readthedocs.io/en/stable/studio.html). + +In your Studio project: + +- Go to the 'Sessions' menu, click on 'New session', and upload the compute package (package.tgz) and seed model (seed.npz). +- In the 'Clients' menu, click on 'Connect client' and download the client configuration file (client.yaml) +- Save the client configuration file to the huggingface example directory (fedn/examples/huggingface) + +To connect a client, run the following command in your terminal: + +.. code-block:: + + fedn client start -in client.yaml --secure=True --force-ssl + + +Alternatively, if you prefer to use Docker, run the following: + +.. 
code-block:: + + docker run \ + -v $PWD/client.yaml:/app/client.yaml \ + -e CLIENT_NUMBER=0 \ + -e FEDN_PACKAGE_EXTRACT_DIR=package \ + ghcr.io/scaleoutsystems/fedn/fedn:0.9.0 client start -in client.yaml --secure=True --force-ssl + + +Running the example +------------------- + +After everything is set up, go to 'Sessions' and click on 'New Session'. Click on 'Start run' and the example +will execute. You can follow the training progress on 'Events' and 'Models', where you can view the calculated metrics. + + + +Running FEDn in local development mode: +--------------------------------------- + +Create the compute package and seed model as explained above. Then run the following command: + + +.. code-block:: + + docker-compose \ + -f ../../docker-compose.yaml \ + -f docker-compose.override.yaml \ + up + + +This starts up local services for MongoDB, Minio, the API Server, one Combiner and two clients. You can verify the deployment using these urls: + +- API Server: http://localhost:8092/get_controller_status +- Minio: http://localhost:9000 +- Mongo Express: http://localhost:8081 + + +Upload the package and seed model to FEDn controller using the APIClient: + +.. code-block:: + + from fedn import APIClient + client = APIClient(host="localhost", port=8092) + client.set_active_package("package.tgz", helper="numpyhelper") + client.set_active_model("seed.npz") + + +You can now start a training session with 5 rounds (default) using the API client: + +.. code-block:: + + client.start_session() + +Clean up +-------- + +You can clean up by running + +.. 
def splitset(dataset, parts):
    """Partition a sequence into ``parts`` contiguous chunks.

    Every element of ``dataset`` ends up in exactly one chunk. When
    ``len(dataset)`` is not divisible by ``parts``, the first
    ``len(dataset) % parts`` chunks receive one extra element each, so no
    trailing samples are silently discarded. For evenly divisible input the
    chunks are all of equal length.

    :param dataset: The sequence to partition; it must support ``len()``
        and slicing.
    :type dataset: Sequence
    :param parts: The number of chunks to produce.
    :type parts: int
    :return: The chunks, in their original order.
    :rtype: list
    :raises ZeroDivisionError: If ``parts`` is 0.
    """
    base, extra = divmod(len(dataset), parts)
    chunks = []
    start = 0
    for i in range(parts):
        # The first ``extra`` chunks absorb the remainder, one element each.
        stop = start + base + (1 if i < extra else 0)
        chunks.append(dataset[start:stop])
        start = stop
    return chunks
def load_parameters(model_path):
    """Build a fresh model and populate it with parameters read from file.

    :param model_path: The path to load from.
    :type model_path: str
    :return: The loaded model.
    :rtype: torch.nn.Module
    """
    net = compile_model()
    weights = helper.load(model_path)

    # Pair every state-dict key with the array stored at the same position.
    restored = collections.OrderedDict()
    for name, array in zip(net.state_dict().keys(), weights):
        restored[name] = torch.tensor(array)

    net.load_state_dict(restored, strict=True)
    return net
def train(
    in_model_path, out_model_path, data_path=None, batch_size=16, epochs=1, lr=5e-5
):
    """Complete a model update.

    Load model parameters from in_model_path (managed by the FEDn client),
    perform a model update, and write updated parameters
    to out_model_path (picked up by the FEDn client).

    :param in_model_path: The path to the input model.
    :type in_model_path: str
    :param out_model_path: The path to save the output model to.
    :type out_model_path: str
    :param data_path: The path to the data file.
    :type data_path: str
    :param batch_size: The batch size to use.
    :type batch_size: int
    :param epochs: The number of epochs to train.
    :type epochs: int
    :param lr: The learning rate to use.
    :type lr: float
    """
    # Load data
    X_train, y_train = load_data(data_path, is_train=True)

    # Normalize the raw text before tokenization
    X_train = [preprocess(text) for text in X_train]

    # Encode: every sample is truncated/padded to 512 tokens
    tokenizer = AutoTokenizer.from_pretrained(MODEL)
    train_encodings = tokenizer(
        X_train, truncation=True, padding="max_length", max_length=512
    )
    train_dataset = SpamDataset(train_encodings, y_train)

    # Load parameters and initialize model
    model = load_parameters(in_model_path)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)
    model.train()

    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)

    # transformers.AdamW is deprecated in favour of torch.optim.AdamW.
    # eps and weight_decay are set explicitly to the values the transformers
    # implementation used by default, so the update rule stays comparable.
    optim = torch.optim.AdamW(
        model.parameters(), lr=lr, eps=1e-6, weight_decay=0.0
    )
    criterion = torch.nn.CrossEntropyLoss()

    for _ in range(epochs):
        for batch in train_loader:
            optim.zero_grad()
            input_ids = batch["input_ids"].to(device)
            attention_mask = batch["attention_mask"].to(device)
            labels = batch["labels"].to(device)

            # Pass the mask by keyword so it cannot be bound to another
            # positional argument of the model's forward().
            outputs = model(input_ids, attention_mask=attention_mask)

            loss = criterion(outputs.logits, labels)
            print("loss: ", loss.item())
            loss.backward()
            optim.step()

    # Metadata needed for aggregation server side
    metadata = {
        # num_examples are mandatory
        "num_examples": len(train_dataset),
        "batch_size": batch_size,
        "epochs": epochs,
        "lr": lr,
    }

    # Save JSON metadata file (mandatory)
    save_metadata(metadata, out_model_path)

    # Save model update (mandatory)
    save_parameters(model, out_model_path)
def validate(in_model_path, out_json_path, data_path=None):
    """Validate model.

    Evaluates the model stored at in_model_path on the local test split and
    writes the average cross-entropy loss and the accuracy to out_json_path.

    :param in_model_path: The path to the input model.
    :type in_model_path: str
    :param out_json_path: The path to save the output JSON to.
    :type out_json_path: str
    :param data_path: The path to the data file.
    :type data_path: str
    """
    # Only the test split is evaluated, so only the test split is loaded.
    X_test, y_test = load_data(data_path, is_train=False)

    # Normalize the raw text before tokenization
    X_test = [preprocess(text) for text in X_test]

    # test dataset
    tokenizer = AutoTokenizer.from_pretrained(MODEL)
    test_encodings = tokenizer(X_test, truncation=True, padding="max_length", max_length=512)
    test_dataset = SpamDataset(test_encodings, y_test)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Load model and place it on the same device the batches are moved to;
    # otherwise the forward pass fails on CUDA hosts with a device mismatch.
    model = load_parameters(in_model_path)
    model.to(device)
    model.eval()

    # Order does not affect the aggregated metrics, so no shuffling is needed.
    test_loader = DataLoader(test_dataset, batch_size=16, shuffle=False)

    criterion = torch.nn.CrossEntropyLoss()

    # test set validation
    correct = 0
    total_loss = 0
    total = 0
    with torch.no_grad():
        for batch in test_loader:
            input_ids = batch["input_ids"].to(device)
            attention_mask = batch["attention_mask"].to(device)
            labels = batch["labels"].to(device)

            outputs = model(input_ids, attention_mask=attention_mask)
            _, predicted = torch.max(outputs.logits, dim=1)  # index of the max logit

            total += labels.size(0)
            correct += (predicted == labels).sum().item()
            loss = criterion(outputs.logits, labels)
            # The criterion returns the batch mean; weight it by batch size
            # so the final value is a per-sample average.
            total_loss += loss.item() * labels.size(0)

    test_accuracy = correct / total
    print(f"Accuracy: {test_accuracy * 100:.2f}%")

    test_loss = total_loss / total
    print("test loss: ", test_loss)

    # JSON schema
    report = {
        "test_loss": test_loss,
        "test_accuracy": test_accuracy,
    }

    # Save JSON
    save_metrics(report, out_json_path)