diff --git a/examples/mnist-keras/bin/get_data b/examples/mnist-keras/bin/get_data deleted file mode 100755 index 4c449d03e..000000000 --- a/examples/mnist-keras/bin/get_data +++ /dev/null @@ -1,21 +0,0 @@ -#!./.mnist-keras/bin/python -import os - -import fire -import numpy as np -import tensorflow as tf - - -def get_data(out_dir='data'): - # Make dir if necessary - if not os.path.exists(out_dir): - os.mkdir(out_dir) - - # Download data - (x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data() - np.savez(f'{out_dir}/mnist.npz', x_train=x_train, - y_train=y_train, x_test=x_test, y_test=y_test) - - -if __name__ == '__main__': - fire.Fire(get_data) diff --git a/examples/mnist-keras/client/entrypoint b/examples/mnist-keras/client/entrypoint.py similarity index 99% rename from examples/mnist-keras/client/entrypoint rename to examples/mnist-keras/client/entrypoint.py index af0d0271b..37b66ebfc 100755 --- a/examples/mnist-keras/client/entrypoint +++ b/examples/mnist-keras/client/entrypoint.py @@ -1,4 +1,3 @@ -#!./.mnist-keras/bin/python import json import os diff --git a/examples/mnist-keras/client/fedn.yaml b/examples/mnist-keras/client/fedn.yaml index eac55b402..403adfb48 100644 --- a/examples/mnist-keras/client/fedn.yaml +++ b/examples/mnist-keras/client/fedn.yaml @@ -1,10 +1,12 @@ python_env: python_env.yaml entry_points: build: - command: python entrypoint init_seed + command: python entrypoint.py init_seed + startup: + command: python get_data.py train: - command: python entrypoint train $ENTRYPOINT_OPTS + command: python entrypoint.py train $ENTRYPOINT_OPTS validate: - command: python entrypoint validate $ENTRYPOINT_OPTS + command: python entrypoint.py validate $ENTRYPOINT_OPTS infer: - command: python entrypoint infer $ENTRYPOINT_OPTS + command: python entrypoint.py infer $ENTRYPOINT_OPTS diff --git a/examples/mnist-keras/bin/split_data b/examples/mnist-keras/client/get_data.py similarity index 72% rename from examples/mnist-keras/bin/split_data rename to examples/mnist-keras/client/get_data.py index bb583b6d7..28a12bd20 100755 --- a/examples/mnist-keras/bin/split_data +++ b/examples/mnist-keras/client/get_data.py @@ -1,9 +1,8 @@ -#!./.mnist-keras/bin/python import os from math import floor -import fire import numpy as np +import tensorflow as tf def splitset(dataset, parts): @@ -38,5 +37,17 @@ def split(dataset='data/mnist.npz', outdir='data', n_splits=2): y_test=data['y_test'][i]) +def get_data(out_dir='data'): + # Make dir if necessary + if not os.path.exists(out_dir): + os.mkdir(out_dir) + + # Download data + (x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data() + np.savez(f'{out_dir}/mnist.npz', x_train=x_train, + y_train=y_train, x_test=x_test, y_test=y_test) + + if __name__ == '__main__': - fire.Fire(split) + get_data() + split()