-
Notifications
You must be signed in to change notification settings - Fork 4.5k
/
Copy pathmultinode_runner.py
127 lines (105 loc) · 4.19 KB
/
multinode_runner.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
import contextlib
import shlex
from multiprocessing import Pipe, Process
from multiprocessing import connection as mp_connection

import click
import fabric

from .hostinfo import HostInfo, HostInfoList
def run_on_host(
    hostinfo: HostInfo,
    workdir: str,
    recv_conn: mp_connection.Connection,
    send_conn: mp_connection.Connection,
    env: dict,
) -> None:
    """
    Use fabric connection to execute command on local or remote hosts.

    This is a worker loop. Each message received on ``recv_conn`` is run as a
    shell command inside ``workdir``, with ``env`` exported. After each
    command, ``"success"`` or ``"failure"`` is sent back on ``send_conn``.
    The message ``"exit"`` ends the loop: ``"finish"`` is sent and the fabric
    connection is closed.

    Args:
        hostinfo (HostInfo): host information
        workdir (str): the directory to execute the command
        recv_conn (multiprocessing.connection.Connection): receive messages from the master sender
        send_conn (multiprocessing.connection.Connection): send messages to the master receiver
        env (dict): a dictionary for environment variables
    """
    fab_conn = fabric.Connection(hostinfo.hostname, port=hostinfo.port)

    # Quote every value for the shell. Hand-written '{v}' quoting breaks (or
    # lets shell syntax through) as soon as a value contains a single quote;
    # shlex.quote handles arbitrary strings.
    env_msg = " ".join(f"{k}={shlex.quote(str(v))}" for k, v in env.items())

    try:
        # keep listening until the master asks us to exit
        while True:
            cmds = recv_conn.recv()
            if cmds == "exit":
                break

            # A bare `export` (empty env) would, in bash, list every exported
            # variable into the command output, so only prefix when needed.
            # fabric's prefix() is a one-shot context manager, so build it per
            # command.
            env_ctx = fab_conn.prefix(f"export {env_msg}") if env_msg else contextlib.nullcontext()
            try:
                # cd to the execution directory and propagate the runtime environment
                with fab_conn.cd(workdir), env_ctx:
                    if hostinfo.is_local_host:
                        # execute on the local machine
                        fab_conn.local(cmds, hide=False)
                    else:
                        # execute on the remote machine
                        fab_conn.run(cmds, hide=False)
                send_conn.send("success")
            except Exception as e:
                # report the failure to the master but keep serving further commands
                click.echo(
                    f"Error: failed to run {cmds} on {hostinfo.hostname}, is localhost: {hostinfo.is_local_host}, exception: {e}"
                )
                send_conn.send("failure")

        # shutdown acknowledgement for the master
        send_conn.send("finish")
    finally:
        # release the fabric connection even if recv()/send() raised
        fab_conn.close()
class MultiNodeRunner:
    """
    A runner to execute commands on an array of machines. This runner
    is inspired by Nezha (https://github.com/zhuzilin/NeZha).

    :meth:`connect` starts one ``run_on_host`` worker process per host. The
    runner then talks to each worker over a pair of pipes keyed by hostname.
    """

    def __init__(self):
        # hostname -> worker multiprocessing.Process
        self.processes = {}
        # hostname -> pipe end used to send commands to that host's worker
        self.master_send_conns = {}
        # hostname -> pipe end used to receive that worker's status replies
        self.master_recv_conns = {}

    def connect(self, host_info_list: HostInfoList, workdir: str, env: dict) -> None:
        """
        Establish connections to a list of hosts.

        Args:
            host_info_list (HostInfoList): a list of HostInfo objects
            workdir (str): the directory where command is executed
            env (dict): environment variables to propagate to hosts
        """
        for hostinfo in host_info_list:
            master_send_conn, worker_recv_conn = Pipe()
            master_recv_conn, worker_send_conn = Pipe()
            p = Process(target=run_on_host, args=(hostinfo, workdir, worker_recv_conn, worker_send_conn, env))
            p.start()
            # Once start() returns, the worker has its own copies of these ends.
            # Drop the parent's copies. While the parent still holds
            # worker_send_conn, recv() on master_recv_conn can never report EOF,
            # so a worker that died would leave recv_from_all() blocked forever.
            # Closing them here also stops later forked workers inheriting them.
            worker_recv_conn.close()
            worker_send_conn.close()
            # NOTE(review): entries are keyed by hostname, so two hosts sharing a
            # hostname overwrite each other here.
            self.processes[hostinfo.hostname] = p
            self.master_recv_conns[hostinfo.hostname] = master_recv_conn
            self.master_send_conns[hostinfo.hostname] = master_send_conn

    def send(self, hostinfo: HostInfo, cmd: str) -> None:
        """
        Send a command to a local/remote host.

        Args:
            hostinfo (HostInfo): host information
            cmd (str): the command to execute
        """
        assert hostinfo.hostname in self.master_send_conns, f"{hostinfo} is not found in the current connections"
        self.master_send_conns[hostinfo.hostname].send(cmd)

    def stop_all(self) -> None:
        """
        Stop connections to all hosts.

        Sends ``"exit"`` to every worker. Each worker replies ``"finish"``,
        which can be collected with :meth:`recv_from_all`. This method does
        not wait for or join the worker processes.
        """
        for conn in self.master_send_conns.values():
            conn.send("exit")

    def recv_from_all(self) -> dict:
        """
        Receive one message from every host.

        Blocks on each host in turn until that host has replied.

        Returns:
            msg_from_node (dict): a dictionary which contains messages from each node
        """
        return {hostname: conn.recv() for hostname, conn in self.master_recv_conns.items()}