Skip to content

Commit 18fd399

Browse files
author
Filip Haltmayer
committed
Reuse GRPC channel for same connections
Signed-off-by: Filip Haltmayer <filip.haltmayer@zilliz.com>
1 parent 256a523 commit 18fd399

File tree

3 files changed

+212
-155
lines changed

3 files changed

+212
-155
lines changed

Diff for: pymilvus/orm/connections.py

+172-148
Original file line numberDiff line numberDiff line change
@@ -78,54 +78,19 @@ def __init__(self):
7878
"""
7979
self._alias = {}
8080
self._connected_alias = {}
81-
self._env_uri = None
81+
self._connection_references = {}
82+
self._con_lock = threading.RLock()
8283

83-
if Config.MILVUS_URI != "":
84-
address, parsed_uri = self.__parse_address_from_uri(Config.MILVUS_URI)
85-
self._env_uri = (address, parsed_uri)
84+
address, user, _, db_name = self.__parse_info(Config.MILVUS_URI)
8685

87-
default_conn_config = {
88-
"user": parsed_uri.username if parsed_uri.username is not None else "",
89-
"address": address,
90-
}
91-
else:
92-
default_conn_config = {
93-
"user": "",
94-
"address": f"{Config.DEFAULT_HOST}:{Config.DEFAULT_PORT}",
95-
}
86+
default_conn_config = {
87+
"user": user,
88+
"address": address,
89+
"db_name": db_name,
90+
}
9691

9792
self.add_connection(**{Config.MILVUS_CONN_ALIAS: default_conn_config})
9893

99-
def __verify_host_port(self, host, port):
100-
if not is_legal_host(host):
101-
raise ConnectionConfigException(message=ExceptionsMessage.HostType)
102-
if not is_legal_port(port):
103-
raise ConnectionConfigException(message=ExceptionsMessage.PortType)
104-
if not 0 <= int(port) < 65535:
105-
raise ConnectionConfigException(message=f"port number {port} out of range, valid range [0, 65535)")
106-
107-
def __parse_address_from_uri(self, uri: str) -> (str, parse.ParseResult):
108-
illegal_uri_msg = "Illegal uri: [{}], expected form 'https://user:pwd@example.com:12345'"
109-
try:
110-
parsed_uri = parse.urlparse(uri)
111-
except (Exception) as e:
112-
raise ConnectionConfigException(
113-
message=f"{illegal_uri_msg.format(uri)}: <{type(e).__name__}, {e}>") from None
114-
115-
if len(parsed_uri.netloc) == 0:
116-
raise ConnectionConfigException(message=f"{illegal_uri_msg.format(uri)}") from None
117-
118-
host = parsed_uri.hostname if parsed_uri.hostname is not None else Config.DEFAULT_HOST
119-
port = parsed_uri.port if parsed_uri.port is not None else Config.DEFAULT_PORT
120-
addr = f"{host}:{port}"
121-
122-
self.__verify_host_port(host, port)
123-
124-
if not is_legal_address(addr):
125-
raise ConnectionConfigException(message=illegal_uri_msg.format(uri))
126-
127-
return addr, parsed_uri
128-
12994
def add_connection(self, **kwargs):
13095
""" Configures a milvus connection.
13196
@@ -157,41 +122,24 @@ def add_connection(self, **kwargs):
157122
)
158123
"""
159124
for alias, config in kwargs.items():
160-
addr, _ = self.__get_full_address(
161-
config.get("address", ""),
162-
config.get("uri", ""),
163-
config.get("host", ""),
164-
config.get("port", ""))
125+
address, user, _, db_name = self.__parse_info(**config)
165126

166127
if alias in self._connected_alias:
167-
if self._alias[alias].get("address") != addr:
128+
if (
129+
self._alias[alias].get("address") != address
130+
or self._alias[alias].get("user") != user
131+
or self._alias[alias].get("db_name") != db_name
132+
):
168133
raise ConnectionConfigException(message=ExceptionsMessage.ConnDiffConf % alias)
169134

170135
alias_config = {
171-
"address": addr,
172-
"user": config.get("user", ""),
136+
"address": address,
137+
"user": user,
138+
"db_name": db_name,
173139
}
174140

175141
self._alias[alias] = alias_config
176142

177-
def __get_full_address(self, address: str = "", uri: str = "", host: str = "", port: str = "") -> (
178-
str, parse.ParseResult):
179-
if address != "":
180-
if not is_legal_address(address):
181-
raise ConnectionConfigException(
182-
message=f"Illegal address: {address}, should be in form 'localhost:19530'")
183-
return address, None
184-
185-
if uri != "":
186-
address, parsed = self.__parse_address_from_uri(uri)
187-
return address, parsed
188-
189-
host = host if host != "" else Config.DEFAULT_HOST
190-
port = port if port != "" else Config.DEFAULT_PORT
191-
self.__verify_host_port(host, port)
192-
193-
return f"{host}:{port}", None
194-
195143
def disconnect(self, alias: str):
196144
""" Disconnects connection from the registry.
197145
@@ -201,8 +149,13 @@ def disconnect(self, alias: str):
201149
if not isinstance(alias, str):
202150
raise ConnectionConfigException(message=ExceptionsMessage.AliasType % type(alias))
203151

204-
if alias in self._connected_alias:
205-
self._connected_alias.pop(alias).close()
152+
with self._con_lock:
153+
if alias in self._connected_alias:
154+
gh = self._connected_alias.pop(alias)
155+
self._connection_references[id(gh)] -= 1
156+
if self._connection_references[id(gh)] <= 0:
157+
gh.close()
158+
del self._connection_references[id(gh)]
206159

207160
def remove_connection(self, alias: str):
208161
""" Removes connection from the registry.
@@ -266,96 +219,74 @@ def connect(self, alias=Config.MILVUS_CONN_ALIAS, user="", password="", db_name=
266219
>>> from pymilvus import connections
267220
>>> connections.connect("test", host="localhost", port="19530")
268221
"""
222+
# pylint: disable=too-many-statements
269223

270224
def connect_milvus(**kwargs):
271-
gh = GrpcHandler(**kwargs)
225+
with self._con_lock:
226+
gh = None
227+
for key, connection_details in self._alias.items():
228+
229+
if (
230+
key in self._connected_alias
231+
and connection_details["address"] == kwargs["address"]
232+
and connection_details["user"] == kwargs["user"]
233+
and connection_details["db_name"] == kwargs["db_name"]
234+
):
235+
gh = self._connected_alias[key]
236+
break
272237

273-
t = kwargs.get("timeout")
274-
timeout = t if isinstance(t, (int, float)) else Config.MILVUS_CONN_TIMEOUT
238+
if gh is None:
239+
gh = GrpcHandler(**kwargs)
240+
t = kwargs.get("timeout")
241+
timeout = t if isinstance(t, (int, float)) else Config.MILVUS_CONN_TIMEOUT
242+
gh._wait_for_channel_ready(timeout=timeout)
275243

276-
gh._wait_for_channel_ready(timeout=timeout)
277-
kwargs.pop('password')
278-
kwargs.pop('db_name', None)
279-
kwargs.pop('secure', None)
280-
kwargs.pop("db_name", "")
244+
kwargs.pop('password', None)
245+
kwargs.pop('secure', None)
281246

282-
self._connected_alias[alias] = gh
283-
self._alias[alias] = copy.deepcopy(kwargs)
247+
self._connected_alias[alias] = gh
284248

285-
def with_config(config: Tuple) -> bool:
286-
for c in config:
287-
if c != "":
288-
return True
249+
self._alias[alias] = copy.deepcopy(kwargs)
289250

290-
return False
251+
if id(gh) not in self._connection_references:
252+
self._connection_references[id(gh)] = 1
253+
else:
254+
self._connection_references[id(gh)] += 1
291255

292256
if not isinstance(alias, str):
293257
raise ConnectionConfigException(message=ExceptionsMessage.AliasType % type(alias))
294258

295-
config = (
296-
kwargs.pop("address", ""),
297-
kwargs.pop("uri", ""),
298-
kwargs.pop("host", ""),
299-
kwargs.pop("port", "")
300-
)
301-
302-
# Make sure passed in None doesnt break
303-
user = user or ""
304-
password = password or ""
305-
# Make sure passed in are Strings
306-
user = str(user)
307-
password = str(password)
308-
309-
# 1st Priority: connection from params
310-
if with_config(config):
311-
in_addr, parsed_uri = self.__get_full_address(*config)
312-
kwargs["address"] = in_addr
313-
314-
if self.has_connection(alias):
315-
if self._alias[alias].get("address") != in_addr:
316-
raise ConnectionConfigException(message=ExceptionsMessage.ConnDiffConf % alias)
317-
318-
# uri might take extra info
319-
if parsed_uri is not None:
320-
user = parsed_uri.username if parsed_uri.username is not None else user
321-
password = parsed_uri.password if parsed_uri.password is not None else password
322-
323-
group = parsed_uri.path.split("/")
324-
db_name = "default"
325-
if len(group) > 1:
326-
db_name = group[1]
327-
328-
# Set secure=True if username and password are provided
329-
if len(user) > 0 and len(password) > 0:
330-
kwargs["secure"] = True
331-
259+
address = kwargs.pop("address", "")
260+
uri = kwargs.pop("uri", "")
261+
host = kwargs.pop("host", "")
262+
port = kwargs.pop("port", "")
263+
user = '' if user is None else str(user)
264+
password = '' if password is None else str(password)
265+
db_name = '' if db_name is None else str(db_name)
266+
267+
if set([address, uri, host, port]) != {''}:
268+
address, user, password, db_name = self.__parse_info(address, uri, host, port, db_name, user, password)
269+
kwargs["address"] = address
270+
271+
elif alias in self._alias:
272+
kwargs = dict(self._alias[alias].items())
273+
# If user is passed in, use it, if not, use previous connections user.
274+
prev_user = kwargs.pop("user")
275+
user = user if user != "" else prev_user
276+
# If db_name is passed in, use it, if not, use previous db_name.
277+
prev_db_name = kwargs.pop("db_name")
278+
db_name = db_name if db_name != "" else prev_db_name
332279

333-
connect_milvus(**kwargs, user=user, password=password, db_name=db_name)
334-
return
335-
336-
# 2nd Priority, connection configs from env
337-
if self._env_uri is not None:
338-
addr, parsed_uri = self._env_uri
339-
kwargs["address"] = addr
340-
341-
user = parsed_uri.username if parsed_uri.username is not None else ""
342-
password = parsed_uri.password if parsed_uri.password is not None else ""
343-
# Set secure=True if uri provided user and password
344-
if len(user) > 0 and len(password) > 0:
345-
kwargs["secure"] = True
280+
# No params, env, and cached configs for the alias
281+
else:
282+
raise ConnectionConfigException(message=ExceptionsMessage.ConnLackConf % alias)
346283

347-
connect_milvus(**kwargs, user=user, password=password, db_name=db_name)
348-
return
284+
# Set secure=True if username and password are provided
285+
if len(user) > 0 and len(password) > 0:
286+
kwargs["secure"] = True
349287

350-
# 3rd Priority, connect to cached configs with provided user and password
351-
if alias in self._alias:
352-
connect_alias = dict(self._alias[alias].items())
353-
connect_alias["user"] = user
354-
connect_milvus(**connect_alias, password=password, db_name=db_name, **kwargs)
355-
return
288+
connect_milvus(**kwargs, user=user, password=password, db_name=db_name)
356289

357-
# No params, env, and cached configs for the alias
358-
raise ConnectionConfigException(message=ExceptionsMessage.ConnLackConf % alias)
359290

360291
def list_connections(self) -> list:
361292
""" List names of all connections.
@@ -369,7 +300,8 @@ def list_connections(self) -> list:
369300
>>> connections.list_connections()
370301
// TODO [('default', None), ('test', <pymilvus.client.grpc_handler.GrpcHandler object at 0x7f05003f3e80>)]
371302
"""
372-
return [(k, self._connected_alias.get(k, None)) for k in self._alias]
303+
with self._con_lock:
304+
return [(k, self._connected_alias.get(k, None)) for k in self._alias]
373305

374306
def get_connection_addr(self, alias: str):
375307
"""
@@ -414,7 +346,99 @@ def has_connection(self, alias: str) -> bool:
414346
"""
415347
if not isinstance(alias, str):
416348
raise ConnectionConfigException(message=ExceptionsMessage.AliasType % type(alias))
417-
return alias in self._connected_alias
349+
with self._con_lock:
350+
return alias in self._connected_alias
351+
352+
def __parse_info(
353+
self,
354+
address: str = "",
355+
uri: str = "",
356+
host: str = "",
357+
port: str = "",
358+
db_name: str = "",
359+
user: str = "",
360+
password: str = "",
361+
**kwargs) -> dict:
362+
363+
passed_in_address = ""
364+
passed_in_user = ""
365+
passed_in_password = ""
366+
passed_in_db_name = ""
367+
368+
# If uri
369+
if uri != "":
370+
passed_in_address, passed_in_user, passed_in_password, passed_in_db_name = (
371+
self.__parse_address_from_uri(uri)
372+
)
373+
374+
elif address != "":
375+
if not is_legal_address(address):
376+
raise ConnectionConfigException(
377+
message=f"Illegal address: {address}, should be in form 'localhost:19530'")
378+
passed_in_address = address
379+
380+
else:
381+
if host == "":
382+
host = Config.DEFAULT_HOST
383+
if port == "":
384+
port = Config.DEFAULT_PORT
385+
self.__verify_host_port(host, port)
386+
passed_in_address = f"{host}:{port}"
387+
388+
passed_in_user = user if passed_in_user == "" else str(passed_in_user)
389+
passed_in_user = Config.MILVUS_USER if passed_in_user == "" else str(passed_in_user)
390+
391+
passed_in_password = password if passed_in_password == "" else str(passed_in_password)
392+
passed_in_password = Config.MILVUS_PASSWORD if passed_in_password == "" else str(passed_in_password)
393+
394+
passed_in_db_name = db_name if passed_in_db_name == "" else str(passed_in_db_name)
395+
passed_in_db_name = Config.MILVUS_DB_NAME if passed_in_db_name == "" else str(passed_in_db_name)
396+
397+
return passed_in_address, passed_in_user, passed_in_password, passed_in_db_name
398+
399+
def __verify_host_port(self, host, port):
400+
if not is_legal_host(host):
401+
raise ConnectionConfigException(message=ExceptionsMessage.HostType)
402+
if not is_legal_port(port):
403+
raise ConnectionConfigException(message=ExceptionsMessage.PortType)
404+
if not 0 <= int(port) < 65535:
405+
raise ConnectionConfigException(message=f"port number {port} out of range, valid range [0, 65535)")
406+
407+
def __parse_address_from_uri(self, uri: str) -> Tuple[str, str, str, str]:
408+
illegal_uri_msg = "Illegal uri: [{}], expected form 'https://user:pwd@example.com:12345'"
409+
try:
410+
parsed_uri = parse.urlparse(uri)
411+
except (Exception) as e:
412+
raise ConnectionConfigException(
413+
message=f"{illegal_uri_msg.format(uri)}: <{type(e).__name__}, {e}>") from None
414+
415+
if len(parsed_uri.netloc) == 0:
416+
raise ConnectionConfigException(message=f"{illegal_uri_msg.format(uri)}") from None
417+
418+
group = parsed_uri.path.split("/")
419+
if len(group) > 1:
420+
db_name = group[1]
421+
else:
422+
db_name = ""
423+
424+
host = parsed_uri.hostname if parsed_uri.hostname is not None else ""
425+
port = parsed_uri.port if parsed_uri.port is not None else ""
426+
user = parsed_uri.username if parsed_uri.username is not None else ""
427+
password = parsed_uri.password if parsed_uri.password is not None else ""
428+
429+
if host == "":
430+
raise ConnectionConfigException(message=f"Illegal uri: URI is missing host address: {uri}")
431+
if port == "":
432+
raise ConnectionConfigException(message=f"Illegal uri: URI is missing port: {uri}")
433+
434+
self.__verify_host_port(host, port)
435+
addr = f"{host}:{port}"
436+
437+
if not is_legal_address(addr):
438+
raise ConnectionConfigException(message=illegal_uri_msg.format(uri))
439+
440+
return addr, user, password, db_name
441+
418442

419443
def _fetch_handler(self, alias=Config.MILVUS_CONN_ALIAS) -> GrpcHandler:
420444
""" Retrieves a GrpcHandler by alias. """

0 commit comments

Comments
 (0)