Skip to content

Commit c586900

Browse files
committed
adapted start scripts with parameter to download and use sample data instead of randomly generating data for fast demo purposes and increased tensorflow and pillow versions to address security issues
1 parent 25dd721 commit c586900

File tree

9 files changed

+162
-31
lines changed

9 files changed

+162
-31
lines changed

README.md

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,8 +21,16 @@ Artificial neural networks is a popular field of research in artificial intellig
2121
2. Create a neural network model and process it. An example of this process is given in `examples/process_mnist_model.py` on [MNIST](http://yann.lecun.com/exdb/mnist/) data.
2222
3. Start the visualization tool `start_tool.py` and select the neural network via `Load Processed Network` to render the representation of the neural network.
2323

24+
Or
25+
26+
1. Run `start_tool.py --demo` to download data of an already processed model and render it.
27+
2428
Multiple scripts are located in `examples`, which can be adapted to create and process neural networks. `examples/evaluation_plots.py` for example can be used to recreate the evaluation data and plots of my thesis.
2529

30+
### Sample Model Importance
31+
32+
A processed model can be found [here](https://drive.google.com/file/d/1LiVzBfB7LPrR95q_VO44wx4MyGNTj6vD/view?usp=sharing).
33+
2634
## Rendering Tool
2735
The visualization tool `start_tool.py` can be used to render and/or process neural networks. Instead of existing ones, you can also generate random networks and process them of various sizes. For neural networks the visualization results in a more structured view of a neural network in regards to their trained parameters compared to the most common ones.
2836

VR_TOOL.md

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,10 @@ This tool can be used to render a processed neural network in VR.
2121

2222
* `python start_tool_vr.py`
2323

24+
Or
25+
26+
* Run `start_tool.py --demo` to download data of an already processed model and render it.
27+
2428
## Controls
2529

2630
Using Oculus Quest 2 controller:
@@ -35,7 +39,7 @@ Using Oculus Quest 2 controller:
3539

3640

3741
### GUI
38-
See [README.md](./README.md)) for information on the desktop GUI
42+
See [README.md](./README.md) for information on the desktop GUI
3943

4044
## Used Systems
4145

configs/processing.json

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
"layer_distance": 0.4,
55
"layer_width": 1.0,
66
"node_bandwidth_reduction": 0.95,
7-
"prune_percentage": 0.0,
7+
"prune_percentage": 0.9,
88
"sampling_rate": 15.0,
99
"smoothing": true,
1010
"smoothing_iterations": 8

configs/window.json

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,8 @@
55
"monitor_id": 0,
66
"screen_height": 900,
77
"screen_width": 1600,
8-
"screen_x": 2765,
9-
"screen_y": 127,
8+
"screen_x": 1644,
9+
"screen_y": 237,
1010
"title": "NNVis Render",
1111
"width": 1600
1212
}

gui/ui_window.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -79,10 +79,11 @@ def save_processed_nn_file(self) -> None:
7979
def open_processed_nn_file(self) -> None:
8080
filename = filedialog.askopenfilename(initialdir=DATA_PATH, title='Select A File',
8181
filetypes=(('processed nn files', '*.pro.npz'),))
82-
data_loader: ProcessedNNHandler = ProcessedNNHandler(filename)
83-
self.settings['network_name'] = ntpath.basename(
84-
filename) + '_processed'
85-
self.update_layer(data_loader.layer_data, processed_nn=data_loader)
82+
if filename != '':
83+
data_loader: ProcessedNNHandler = ProcessedNNHandler(filename)
84+
self.settings['network_name'] = ntpath.basename(
85+
filename) + '_processed'
86+
self.update_layer(data_loader.layer_data, processed_nn=data_loader)
8687

8788
def open_importance_file(self) -> None:
8889
filename = filedialog.askopenfilename(initialdir=DATA_PATH, title='Select A File',

requirements.txt

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,10 @@ glfw==2.0.0
22
matplotlib==3.3.3
33
numpy==1.23.5
44
pandas==1.2.0
5-
Pillow==9.0.1
5+
Pillow==9.3.0
66
progressbar2==3.53.1
77
PyOpenGL==3.1.5
88
pyrr==0.10.3
99
scikit_learn==0.24.0
10-
tensorflow==2.9.2
10+
tensorflow==2.9.3
11+
wget==3.2

requirements_vr.txt

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,9 +3,10 @@ matplotlib==3.3.3
33
numpy==1.23.5
44
openvr
55
pandas==1.2.0
6-
Pillow==9.0.1
6+
Pillow==9.3.0
77
progressbar2==3.53.1
88
PyOpenGL==3.1.5
99
pyrr==0.10.3
1010
scikit_learn==0.24.0
11-
tensorflow==2.9.2
11+
tensorflow==2.9.3
12+
wget==3.2

start_tool.py

Lines changed: 67 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,17 @@
11
import logging
2+
import ntpath
3+
import os
24
import threading
35
import time
6+
import zipfile
7+
from argparse import ArgumentParser
48
from typing import Optional
59

10+
import wget
611
from OpenGL.GL import GL_MAJOR_VERSION, GL_MINOR_VERSION, glGetIntegerv
712

8-
from definitions import CameraPose
13+
from data.data_handler import ProcessedNNHandler
14+
from definitions import DATA_PATH, CameraPose
915
from gui.constants import StatisticLink
1016
from gui.ui_window import OptionGui
1117
from opengl_helper.screenshot import create_screenshot
@@ -15,13 +21,34 @@
1521
from utility.performance import track_time
1622
from utility.window import Window, WindowHandler
1723

18-
global options_gui
19-
options_gui = OptionGui()
20-
setup_logger('tool')
24+
25+
def download_and_unzip_sample() -> str:
26+
output_directory = DATA_PATH
27+
filename = wget.download(
28+
'https://drive.google.com/uc?export=download&id=1LiVzBfB7LPrR95q_VO44wx4MyGNTj6vD', out=output_directory)
29+
zip_filepath = os.path.join(output_directory, filename)
30+
with zipfile.ZipFile(zip_filepath, 'r') as zip_ref:
31+
zip_ref.extractall(DATA_PATH)
32+
return os.path.join(DATA_PATH, 'sample_model.npz')
33+
34+
35+
def open_processed_network(option_gui: OptionGui, filename: str) -> None:
36+
data_loader: ProcessedNNHandler = ProcessedNNHandler(filename)
37+
option_gui.processing_config['prune_percentage'] = 0.9
38+
option_gui.processing_setting.set()
39+
option_gui.settings['network_name'] = ntpath.basename(
40+
filename) + '_processed'
41+
option_gui.update_layer(data_loader.layer_data, processed_nn=data_loader)
2142

2243

2344
def compute_render(some_name: str) -> None:
2445
global options_gui
46+
global use_sample
47+
48+
if use_sample:
49+
global sample_filepath
50+
logging.info('Loading sample model...')
51+
open_processed_network(options_gui, sample_filepath)
2552

2653
width, height = 1920, 1200
2754

@@ -72,6 +99,7 @@ def frame() -> None:
7299
if not options_gui.settings['Closed']:
73100
print('Start building network: ' +
74101
str(options_gui.settings['current_layer_data']))
102+
options_gui.settings['update_model'] = False
75103
network_processor = NetworkProcessor(options_gui.settings['current_layer_data'],
76104
options_gui.processing_config,
77105
importance_data=options_gui.settings['importance_data'],
@@ -154,9 +182,39 @@ def frame() -> None:
154182
options_gui.destroy()
155183

156184

157-
compute_render_thread: threading.Thread = threading.Thread(
158-
target=compute_render, args=(1,))
159-
compute_render_thread.setDaemon(True)
160-
compute_render_thread.start()
185+
def parse_args() -> bool:
186+
parser = ArgumentParser(prog='Start nn_vis tool')
187+
parser.add_argument('--demo', action='store_true',
188+
help='Download sample of a processed model and render it with 90% pruned edges instead of generating a random model.')
189+
args = parser.parse_args()
190+
return args.demo
191+
161192

162-
options_gui.start()
193+
if __name__ == '__main__':
194+
global options_gui
195+
options_gui = OptionGui()
196+
197+
global sample_filepath
198+
sample_filepath = 'sample_model.npz'
199+
200+
global use_sample
201+
use_sample = parse_args()
202+
203+
setup_logger('tool')
204+
205+
if use_sample:
206+
expected_sample_path = os.path.join(DATA_PATH, sample_filepath)
207+
if not os.path.exists(expected_sample_path):
208+
logging.info(
209+
f'Downloading sample model to "{expected_sample_path}". This might take a minute ...')
210+
sample_filepath = download_and_unzip_sample()
211+
else:
212+
logging.info(
213+
f'Using sample model at "{expected_sample_path}"')
214+
sample_filepath = expected_sample_path
215+
compute_render_thread: threading.Thread = threading.Thread(
216+
target=compute_render, args=(1,))
217+
compute_render_thread.setDaemon(True)
218+
compute_render_thread.start()
219+
220+
options_gui.start()

start_tool_vr.py

Lines changed: 68 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,17 @@
11
import logging
2+
import ntpath
3+
import os
24
import threading
35
import time
6+
import zipfile
7+
from argparse import ArgumentParser
48
from typing import List, Optional, Tuple
59

10+
import wget
611
from OpenGL.GL import GL_MAJOR_VERSION, GL_MINOR_VERSION, glGetIntegerv
712

13+
from data.data_handler import ProcessedNNHandler
14+
from definitions import DATA_PATH
815
from gui.constants import StatisticLink
916
from gui.ui_window import OptionGui
1017
from processing.network_processing import NetworkProcessor
@@ -13,15 +20,36 @@
1320
from utility.performance import track_time
1421
from vr.vr_handler import VRHandler
1522

16-
global options_gui
17-
options_gui = OptionGui()
18-
setup_logger('tool')
19-
2023
RENDER_MODES: List[Tuple[int, int]] = [(3, 2), (4, 1), (1, 1), (2, 2)]
2124

2225

26+
def download_and_unzip_sample() -> str:
27+
output_directory = DATA_PATH
28+
filename = wget.download(
29+
'https://drive.google.com/uc?export=download&id=1LiVzBfB7LPrR95q_VO44wx4MyGNTj6vD', out=output_directory)
30+
zip_filepath = os.path.join(output_directory, filename)
31+
with zipfile.ZipFile(zip_filepath, 'r') as zip_ref:
32+
zip_ref.extractall(DATA_PATH)
33+
return os.path.join(DATA_PATH, 'sample_model.npz')
34+
35+
36+
def open_processed_network(option_gui: OptionGui, filename: str) -> None:
37+
data_loader: ProcessedNNHandler = ProcessedNNHandler(filename)
38+
option_gui.processing_config['prune_percentage'] = 0.9
39+
option_gui.processing_setting.set()
40+
option_gui.settings['network_name'] = ntpath.basename(
41+
filename) + '_processed'
42+
option_gui.update_layer(data_loader.layer_data, processed_nn=data_loader)
43+
44+
2345
def compute_render(_: str) -> None:
2446
global options_gui
47+
global use_sample
48+
49+
if use_sample:
50+
global sample_filepath
51+
logging.info('Loading sample model...')
52+
open_processed_network(options_gui, sample_filepath)
2553

2654
FileHandler().read_statistics()
2755

@@ -101,6 +129,7 @@ def frame() -> None:
101129
'Start building network: ' +
102130
str(options_gui.settings['current_layer_data'])
103131
)
132+
options_gui.settings['update_model'] = False
104133
network_processor = NetworkProcessor(
105134
options_gui.settings['current_layer_data'],
106135
options_gui.processing_config,
@@ -191,10 +220,39 @@ def frame() -> None:
191220
options_gui.destroy()
192221

193222

194-
compute_render_thread: threading.Thread = threading.Thread(
195-
target=compute_render, args=(1,)
196-
)
197-
compute_render_thread.setDaemon(True)
198-
compute_render_thread.start()
223+
def parse_args() -> bool:
224+
parser = ArgumentParser(prog='Start nn_vis tool for VR')
225+
parser.add_argument('--demo', action='store_true',
226+
help='Download sample of a processed model and render it with 90% pruned edges instead of generating a random model.')
227+
args = parser.parse_args()
228+
return args.demo
229+
230+
231+
if __name__ == '__main__':
232+
global options_gui
233+
options_gui = OptionGui()
234+
235+
global sample_filepath
236+
sample_filepath = 'sample_model.npz'
237+
238+
global use_sample
239+
use_sample = parse_args()
240+
241+
setup_logger('tool_vr')
242+
243+
if use_sample:
244+
expected_sample_path = os.path.join(DATA_PATH, sample_filepath)
245+
if not os.path.exists(expected_sample_path):
246+
logging.info(
247+
f'Downloading sample model to "{expected_sample_path}". This might take a minute ...')
248+
sample_filepath = download_and_unzip_sample()
249+
else:
250+
logging.info(
251+
f'Using sample model at "{expected_sample_path}"')
252+
sample_filepath = expected_sample_path
253+
compute_render_thread: threading.Thread = threading.Thread(
254+
target=compute_render, args=(1,))
255+
compute_render_thread.setDaemon(True)
256+
compute_render_thread.start()
199257

200-
options_gui.start()
258+
options_gui.start()

0 commit comments

Comments
 (0)