Skip to content

Commit

Permalink
update keras
Browse files Browse the repository at this point in the history
  • Loading branch information
Wrede committed Apr 9, 2024
1 parent 4e7583b commit 81110cb
Show file tree
Hide file tree
Showing 4 changed files with 20 additions and 29 deletions.
21 changes: 0 additions & 21 deletions examples/mnist-keras/bin/get_data

This file was deleted.

Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
#!./.mnist-keras/bin/python
import json
import os

Expand Down
10 changes: 6 additions & 4 deletions examples/mnist-keras/client/fedn.yaml
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
python_env: python_env.yaml
entry_points:
build:
command: python entrypoint init_seed
command: python entrypoint.py init_seed
startup:
command: python get_data.py
train:
command: python entrypoint train $ENTRYPOINT_OPTS
command: python entrypoint.py train $ENTRYPOINT_OPTS
validate:
command: python entrypoint validate $ENTRYPOINT_OPTS
command: python entrypoint.py validate $ENTRYPOINT_OPTS
infer:
command: python entrypoint infer $ENTRYPOINT_OPTS
command: python entrypoint.py infer $ENTRYPOINT_OPTS
Original file line number Diff line number Diff line change
@@ -1,9 +1,8 @@
#!./.mnist-keras/bin/python
import os
from math import floor

import fire
import numpy as np
import tensorflow as tf


def splitset(dataset, parts):
Expand Down Expand Up @@ -38,5 +37,17 @@ def split(dataset='data/mnist.npz', outdir='data', n_splits=2):
y_test=data['y_test'][i])


def get_data(out_dir='data'):
# Make dir if necessary
if not os.path.exists(out_dir):
os.mkdir(out_dir)

# Download data
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data()
np.savez(f'{out_dir}/mnist.npz', x_train=x_train,
y_train=y_train, x_test=x_test, y_test=y_test)


if __name__ == '__main__':
fire.Fire(split)
get_data()
split()

0 comments on commit 81110cb

Please sign in to comment.