forked from kakaoenterprise/JORLDY
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathprocess.py
96 lines (90 loc) · 2.89 KB
/
process.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
import traceback
import time
from threading import Thread
# Interact (for async distributed train)
def interact_process(
    DistributedManager,
    distributed_manager_config,
    trans_queue,
    sync_queue,
    run_step,
    update_period,
):
    """Collect transitions from distributed workers and forward them.

    Args:
        DistributedManager: class instantiated as
            ``DistributedManager(*distributed_manager_config)``; must expose
            ``num_workers``, ``run``, ``sync`` and ``terminate``.
        distributed_manager_config: positional arguments for the manager.
        trans_queue: receives ``(int(step), transitions)`` tuples. This loop
            sleeps while the queue is full, so a slow consumer throttles
            collection.
        sync_queue: checked after every batch; when it is full, its item is
            taken and passed to ``distributed_manager.sync``.
        run_step: collection stops once the accumulated step count reaches
            this value.
        update_period: forwarded to ``distributed_manager.run`` each iteration.

    Any exception raised inside the collection loop is printed and swallowed.
    The manager is always terminated on the way out.
    """
    distributed_manager = DistributedManager(*distributed_manager_config)
    num_workers = distributed_manager.num_workers
    step = 0
    try:
        while step < run_step:
            transitions = distributed_manager.run(update_period)
            # A batch of N transitions from W workers advances the step
            # counter by N / W. The counter stays a float and is truncated
            # only when it is published.
            delta_t = len(transitions) / num_workers
            step += delta_t
            trans_queue.put((int(step), transitions))
            # Pick up pending synchronization data, if any.
            if sync_queue.full():
                distributed_manager.sync(sync_queue.get())
            # Back-pressure: wait for the consumer to drain the transition queue.
            while trans_queue.full():
                time.sleep(0.1)
    except Exception:
        traceback.print_exc()
    finally:
        distributed_manager.terminate()
# Manage
def manage_process(
    Agent,
    agent_config,
    result_queue,
    sync_queue,
    path_queue,
    run_step,
    print_period,
    MetricManager,
    EvalManager,
    eval_manager_config,
    LogManager,
    log_manager_config,
    config_manager,
):
    """Aggregate training results and periodically evaluate/log the agent.

    Args:
        Agent: class instantiated as ``Agent(**agent_config)``.
        agent_config: keyword arguments for ``Agent``.
        result_queue: yields ``(step, result)`` tuples. Each ``result`` is
            fed to ``metric_manager.append``.
        sync_queue: yields keyword-argument dicts for ``agent.sync_in``. One
            item is consumed before each evaluation.
        path_queue: receives ``log_manager.path`` once at startup.
        run_step: the loop ends once the reported step reaches this value.
        print_period: number of steps between evaluation attempts.
        MetricManager: class instantiated with no arguments.
        EvalManager: class instantiated as
            ``EvalManager(*eval_manager_config)``.
        eval_manager_config: positional arguments for ``EvalManager``.
        LogManager: class instantiated as
            ``LogManager(*log_manager_config)``.
        log_manager_config: positional arguments for ``LogManager``.
        config_manager: object whose ``dump`` is called with the log path.

    Evaluations run in a background ``Thread`` so that result aggregation is
    not blocked. Exceptions in the main loop are printed and swallowed. Any
    running evaluation thread is always joined before returning.
    """
    agent = Agent(**agent_config)
    eval_manager = EvalManager(*eval_manager_config)
    metric_manager = MetricManager()
    log_manager = LogManager(*log_manager_config)
    path_queue.put(log_manager.path)
    config_manager.dump(log_manager.path)

    step, print_stamp, eval_thread = 0, 0, None
    try:
        while step < run_step:
            # Block for at least one result, then drain whatever is already
            # queued so the metrics stay current.
            wait = True
            while wait or not result_queue.empty():
                _step, result = result_queue.get()
                metric_manager.append(result)
                wait = False
            print_stamp += _step - step
            step = _step
            if print_stamp >= print_period or step >= run_step:
                # Start a new evaluation only if the previous one finished.
                # On the final step, wait for it instead, so the last
                # evaluation is never skipped.
                if (
                    eval_thread is None
                    or not eval_thread.is_alive()
                    or step >= run_step
                ):
                    if eval_thread is not None:
                        eval_thread.join()
                    agent.sync_in(**sync_queue.get())
                    statistics = metric_manager.get_statistics()
                    eval_thread = Thread(
                        target=evaluate_thread,
                        args=(agent, step, statistics, eval_manager, log_manager),
                    )
                    eval_thread.start()
                    # Reset only once an evaluation has actually started, so
                    # a busy evaluator causes a retry on the next result.
                    print_stamp = 0
    except Exception:
        # threading.Thread has no terminate(); calling it here used to raise
        # AttributeError from inside the handler. Threads cannot be killed,
        # so the finally clause below simply waits for it.
        traceback.print_exc()
    finally:
        if eval_thread is not None:
            eval_thread.join()
# Evaluate
def evaluate_thread(agent, step, statistics, eval_manager, log_manager):
    """Evaluate ``agent`` at ``step``, then print and log the statistics.

    ``statistics`` is updated in place with a ``"score"`` entry taken from
    ``eval_manager.evaluate``. It is then printed and handed, together with
    the returned frames, to ``log_manager.write``. This function is used as
    the ``Thread`` target in ``manage_process``.
    """
    evaluation = eval_manager.evaluate(agent, step)
    eval_score, eval_frames = evaluation
    statistics["score"] = eval_score
    summary = f"Step : {step} / {statistics}"
    print(summary)
    log_manager.write(statistics, eval_frames, step)