Skip to content

Commit

Permalink
Feature/SK-852 | Add numpy raw binary helper (#608)
Browse files Browse the repository at this point in the history
  • Loading branch information
Wrede authored May 17, 2024
1 parent a327ddc commit 9d7f500
Show file tree
Hide file tree
Showing 4 changed files with 86 additions and 39 deletions.
21 changes: 10 additions & 11 deletions fedn/network/combiner/aggregators/fedavg.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
import traceback

from fedn.common.log_config import logger
from fedn.network.combiner.aggregators.aggregatorbase import AggregatorBase


class Aggregator(AggregatorBase):
""" Local SGD / Federated Averaging (FedAvg) aggregator. Computes a weighted mean
"""Local SGD / Federated Averaging (FedAvg) aggregator. Computes a weighted mean
of parameter updates.
:param id: A reference to id of :class: `fedn.network.combiner.Combiner`
Expand Down Expand Up @@ -48,8 +50,7 @@ def combine_models(self, helper=None, delete_models=True, parameters=None):
nr_aggregated_models = 0
total_examples = 0

logger.info(
"AGGREGATOR({}): Aggregating model updates... ".format(self.name))
logger.info("AGGREGATOR({}): Aggregating model updates... ".format(self.name))

while not self.model_updates.empty():
try:
Expand All @@ -61,28 +62,26 @@ def combine_models(self, helper=None, delete_models=True, parameters=None):
logger.info("AGGREGATOR({}): Loading model metadata {}.".format(self.name, model_update.model_update_id))
model_next, metadata = self.load_model_update(model_update, helper)

logger.info(
"AGGREGATOR({}): Processing model update {}, metadata: {} ".format(self.name, model_update.model_update_id, metadata))
logger.info("AGGREGATOR({}): Processing model update {}, metadata: {} ".format(self.name, model_update.model_update_id, metadata))

# Increment total number of examples
total_examples += metadata["num_examples"]

if nr_aggregated_models == 0:
model = model_next
else:
model = helper.increment_average(
model, model_next, metadata["num_examples"], total_examples)
model = helper.increment_average(model, model_next, metadata["num_examples"], total_examples)

nr_aggregated_models += 1
# Delete model from storage
if delete_models:
self.modelservice.temp_model_storage.delete(model_update.model_update_id)
logger.info(
"AGGREGATOR({}): Deleted model update {} from storage.".format(self.name, model_update.model_update_id))
logger.info("AGGREGATOR({}): Deleted model update {} from storage.".format(self.name, model_update.model_update_id))
self.model_updates.task_done()
except Exception as e:
logger.error(
"AGGREGATOR({}): Error encoutered while processing model update {}, skipping this update.".format(self.name, e))
tb = traceback.format_exc()
logger.error(f"AGGREGATOR({self.name}): Error encoutered while processing model update: {e}")
logger.error(tb)
self.model_updates.task_done()

data["nr_aggregated_models"] = nr_aggregated_models
Expand Down
11 changes: 0 additions & 11 deletions fedn/utils/helpers/helperbase.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
import os
import tempfile
from abc import ABC, abstractmethod


Expand Down Expand Up @@ -40,12 +38,3 @@ def load(self, fh):
:return: Weights in array-like format.
"""
pass

def get_tmp_path(self):
"""Return a temporary output path compatible with save_model, load_model.
:return: Path to file.
"""
fd, path = tempfile.mkstemp(suffix=".npz")
os.close(fd)
return path
16 changes: 16 additions & 0 deletions fedn/utils/helpers/plugins/binaryhelper.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
from fedn.utils.helpers.plugins.numpyhelper import Helper


class Helper(Helper):
    """FEDn helper for model weights/parameters that can be represented as
    numpy ndarrays, serialized in the raw binary format by default.

    NOTE(review): the class name shadows the imported numpy ``Helper`` base
    class — presumably so plugin discovery finds a class named ``Helper``;
    confirm against the plugin loader.
    """

    def __init__(self):
        """Initialize the helper and set its registry name."""
        super().__init__()
        self.name = "binaryhelper"

    def load(self, path, file_type="raw_binary"):
        """Load weights from *path*, defaulting to the raw binary format."""
        return super().load(path, file_type=file_type)

    def save(self, model, path=None, file_type="raw_binary"):
        """Serialize *model* to *path*, defaulting to the raw binary format."""
        return super().save(model, path=path, file_type=file_type)
77 changes: 60 additions & 17 deletions fedn/utils/helpers/plugins/numpyhelper.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,7 @@
import os
import tempfile
from io import BytesIO

import numpy as np

from fedn.utils.helpers.helperbase import HelperBase
Expand Down Expand Up @@ -137,33 +141,72 @@ def ones(self, m1, a):
res.append(np.ones(np.shape(x)) * a)
return res

def save(self, weights, path=None, file_type="npz"):
    """Serialize weights to file. The serialized model must be a single binary object.

    :param weights: List of weights in numpy format.
    :param path: Path to file. If None, a temporary path is created.
    :param file_type: File type to save to. Can be 'npz' or 'raw_binary'. Default is 'npz'.
    :return: Path to file.
    :raises ValueError: If *file_type* is not supported.
    """
    self.check_supported_file_type(file_type)

    if file_type == "npz":
        if not path:
            path = self.get_tmp_path()
        # Store each layer under its list index so load() can rebuild the list.
        weights_dict = {str(i): w for i, w in enumerate(weights)}
        # NOTE(review): np.savez_compressed appends ".npz" if the path lacks
        # it, in which case the returned path differs from the written file —
        # callers should pass a ".npz" path; confirm with call sites.
        np.savez_compressed(path, **weights_dict)
        return path
    else:
        if not path:
            path = self.get_tmp_path(suffix=".bin")
        # Flatten all layers into one contiguous array; assumes every entry
        # shares a dtype compatible with concatenation — TODO confirm.
        weights = np.concatenate(weights)
        weights.tofile(path)
        return path

def load(self, path, file_type="npz"):
    """Load weights from file or filelike.

    :param path: file path, filehandle, filelike.
    :param file_type: File type to load from. Can be 'npz' or 'raw_binary'. Default is 'npz'.
    :return: List of weights in numpy format.
    :raises ValueError: If *file_type* is not supported.
    """
    self.check_supported_file_type(file_type)
    weights = []
    if file_type == "npz":
        a = np.load(path)
        # Entries were saved keyed by their list index; restore in order.
        for i in range(len(a.files)):
            weights.append(a[str(i)])
    else:
        # Raw binary carries no shape/dtype metadata: assumes float64 and
        # returns a single flat array — TODO confirm against save().
        if isinstance(path, BytesIO):
            a = np.frombuffer(path.read(), dtype=np.float64)
        else:
            a = np.fromfile(path, dtype=np.float64)
        weights.append(a)
    return weights

np.savez_compressed(path, **weights_dict)
def get_tmp_path(self, suffix=".npz"):
    """Return a temporary output path compatible with save_model, load_model.

    :param suffix: File suffix.
    :return: Path to file.
    """
    # mkstemp opens the file; close the descriptor so only the path is kept.
    handle, tmp_path = tempfile.mkstemp(suffix=suffix)
    os.close(handle)
    return tmp_path

def check_supported_file_type(self, file_type):
    """Check if the file type is supported.

    :param file_type: File type to check.
    :type file_type: str
    :return: True if the file type is supported.
    :rtype: bool
    :raises ValueError: If *file_type* is not a supported type (the original
        docstring claimed "False otherwise", but the function raises instead).
    """
    supported_file_types = ["npz", "raw_binary"]
    if file_type not in supported_file_types:
        raise ValueError("File type not supported. Supported types are: {}".format(supported_file_types))
    return True

0 comments on commit 9d7f500

Please sign in to comment.