@@ -78,54 +78,19 @@ def __init__(self):
78
78
"""
79
79
self ._alias = {}
80
80
self ._connected_alias = {}
81
- self ._env_uri = None
81
+ self ._connection_references = {}
82
+ self ._con_lock = threading .RLock ()
82
83
83
- if Config .MILVUS_URI != "" :
84
- address , parsed_uri = self .__parse_address_from_uri (Config .MILVUS_URI )
85
- self ._env_uri = (address , parsed_uri )
84
+ address , user , _ , db_name = self .__parse_info (Config .MILVUS_URI )
86
85
87
- default_conn_config = {
88
- "user" : parsed_uri .username if parsed_uri .username is not None else "" ,
89
- "address" : address ,
90
- }
91
- else :
92
- default_conn_config = {
93
- "user" : "" ,
94
- "address" : f"{ Config .DEFAULT_HOST } :{ Config .DEFAULT_PORT } " ,
95
- }
86
+ default_conn_config = {
87
+ "user" : user ,
88
+ "address" : address ,
89
+ "db_name" : db_name ,
90
+ }
96
91
97
92
self .add_connection (** {Config .MILVUS_CONN_ALIAS : default_conn_config })
98
93
99
- def __verify_host_port (self , host , port ):
100
- if not is_legal_host (host ):
101
- raise ConnectionConfigException (message = ExceptionsMessage .HostType )
102
- if not is_legal_port (port ):
103
- raise ConnectionConfigException (message = ExceptionsMessage .PortType )
104
- if not 0 <= int (port ) < 65535 :
105
- raise ConnectionConfigException (message = f"port number { port } out of range, valid range [0, 65535)" )
106
-
107
- def __parse_address_from_uri (self , uri : str ) -> (str , parse .ParseResult ):
108
- illegal_uri_msg = "Illegal uri: [{}], expected form 'https://user:pwd@example.com:12345'"
109
- try :
110
- parsed_uri = parse .urlparse (uri )
111
- except (Exception ) as e :
112
- raise ConnectionConfigException (
113
- message = f"{ illegal_uri_msg .format (uri )} : <{ type (e ).__name__ } , { e } >" ) from None
114
-
115
- if len (parsed_uri .netloc ) == 0 :
116
- raise ConnectionConfigException (message = f"{ illegal_uri_msg .format (uri )} " ) from None
117
-
118
- host = parsed_uri .hostname if parsed_uri .hostname is not None else Config .DEFAULT_HOST
119
- port = parsed_uri .port if parsed_uri .port is not None else Config .DEFAULT_PORT
120
- addr = f"{ host } :{ port } "
121
-
122
- self .__verify_host_port (host , port )
123
-
124
- if not is_legal_address (addr ):
125
- raise ConnectionConfigException (message = illegal_uri_msg .format (uri ))
126
-
127
- return addr , parsed_uri
128
-
129
94
def add_connection (self , ** kwargs ):
130
95
""" Configures a milvus connection.
131
96
@@ -157,41 +122,24 @@ def add_connection(self, **kwargs):
157
122
)
158
123
"""
159
124
for alias , config in kwargs .items ():
160
- addr , _ = self .__get_full_address (
161
- config .get ("address" , "" ),
162
- config .get ("uri" , "" ),
163
- config .get ("host" , "" ),
164
- config .get ("port" , "" ))
125
+ address , user , _ , db_name = self .__parse_info (** config )
165
126
166
127
if alias in self ._connected_alias :
167
- if self ._alias [alias ].get ("address" ) != addr :
128
+ if (
129
+ self ._alias [alias ].get ("address" ) != address
130
+ or self ._alias [alias ].get ("user" ) != user
131
+ or self ._alias [alias ].get ("db_name" ) != db_name
132
+ ):
168
133
raise ConnectionConfigException (message = ExceptionsMessage .ConnDiffConf % alias )
169
134
170
135
alias_config = {
171
- "address" : addr ,
172
- "user" : config .get ("user" , "" ),
136
+ "address" : address ,
137
+ "user" : user ,
138
+ "db_name" : db_name ,
173
139
}
174
140
175
141
self ._alias [alias ] = alias_config
176
142
177
- def __get_full_address (self , address : str = "" , uri : str = "" , host : str = "" , port : str = "" ) -> (
178
- str , parse .ParseResult ):
179
- if address != "" :
180
- if not is_legal_address (address ):
181
- raise ConnectionConfigException (
182
- message = f"Illegal address: { address } , should be in form 'localhost:19530'" )
183
- return address , None
184
-
185
- if uri != "" :
186
- address , parsed = self .__parse_address_from_uri (uri )
187
- return address , parsed
188
-
189
- host = host if host != "" else Config .DEFAULT_HOST
190
- port = port if port != "" else Config .DEFAULT_PORT
191
- self .__verify_host_port (host , port )
192
-
193
- return f"{ host } :{ port } " , None
194
-
195
143
def disconnect (self , alias : str ):
196
144
""" Disconnects connection from the registry.
197
145
@@ -201,8 +149,13 @@ def disconnect(self, alias: str):
201
149
if not isinstance (alias , str ):
202
150
raise ConnectionConfigException (message = ExceptionsMessage .AliasType % type (alias ))
203
151
204
- if alias in self ._connected_alias :
205
- self ._connected_alias .pop (alias ).close ()
152
+ with self ._con_lock :
153
+ if alias in self ._connected_alias :
154
+ gh = self ._connected_alias .pop (alias )
155
+ self ._connection_references [id (gh )] -= 1
156
+ if self ._connection_references [id (gh )] <= 0 :
157
+ gh .close ()
158
+ del self ._connection_references [id (gh )]
206
159
207
160
def remove_connection (self , alias : str ):
208
161
""" Removes connection from the registry.
@@ -266,96 +219,74 @@ def connect(self, alias=Config.MILVUS_CONN_ALIAS, user="", password="", db_name=
266
219
>>> from pymilvus import connections
267
220
>>> connections.connect("test", host="localhost", port="19530")
268
221
"""
222
+ # pylint: disable=too-many-statements
269
223
270
224
def connect_milvus (** kwargs ):
271
- gh = GrpcHandler (** kwargs )
225
+ with self ._con_lock :
226
+ gh = None
227
+ for key , connection_details in self ._alias .items ():
228
+
229
+ if (
230
+ key in self ._connected_alias
231
+ and connection_details ["address" ] == kwargs ["address" ]
232
+ and connection_details ["user" ] == kwargs ["user" ]
233
+ and connection_details ["db_name" ] == kwargs ["db_name" ]
234
+ ):
235
+ gh = self ._connected_alias [key ]
236
+ break
272
237
273
- t = kwargs .get ("timeout" )
274
- timeout = t if isinstance (t , (int , float )) else Config .MILVUS_CONN_TIMEOUT
238
+ if gh is None :
239
+ gh = GrpcHandler (** kwargs )
240
+ t = kwargs .get ("timeout" )
241
+ timeout = t if isinstance (t , (int , float )) else Config .MILVUS_CONN_TIMEOUT
242
+ gh ._wait_for_channel_ready (timeout = timeout )
275
243
276
- gh ._wait_for_channel_ready (timeout = timeout )
277
- kwargs .pop ('password' )
278
- kwargs .pop ('db_name' , None )
279
- kwargs .pop ('secure' , None )
280
- kwargs .pop ("db_name" , "" )
244
+ kwargs .pop ('password' , None )
245
+ kwargs .pop ('secure' , None )
281
246
282
- self ._connected_alias [alias ] = gh
283
- self ._alias [alias ] = copy .deepcopy (kwargs )
247
+ self ._connected_alias [alias ] = gh
284
248
285
- def with_config (config : Tuple ) -> bool :
286
- for c in config :
287
- if c != "" :
288
- return True
249
+ self ._alias [alias ] = copy .deepcopy (kwargs )
289
250
290
- return False
251
+ if id (gh ) not in self ._connection_references :
252
+ self ._connection_references [id (gh )] = 1
253
+ else :
254
+ self ._connection_references [id (gh )] += 1
291
255
292
256
if not isinstance (alias , str ):
293
257
raise ConnectionConfigException (message = ExceptionsMessage .AliasType % type (alias ))
294
258
295
- config = (
296
- kwargs .pop ("address" , "" ),
297
- kwargs .pop ("uri" , "" ),
298
- kwargs .pop ("host" , "" ),
299
- kwargs .pop ("port" , "" )
300
- )
301
-
302
- # Make sure passed in None doesnt break
303
- user = user or ""
304
- password = password or ""
305
- # Make sure passed in are Strings
306
- user = str (user )
307
- password = str (password )
308
-
309
- # 1st Priority: connection from params
310
- if with_config (config ):
311
- in_addr , parsed_uri = self .__get_full_address (* config )
312
- kwargs ["address" ] = in_addr
313
-
314
- if self .has_connection (alias ):
315
- if self ._alias [alias ].get ("address" ) != in_addr :
316
- raise ConnectionConfigException (message = ExceptionsMessage .ConnDiffConf % alias )
317
-
318
- # uri might take extra info
319
- if parsed_uri is not None :
320
- user = parsed_uri .username if parsed_uri .username is not None else user
321
- password = parsed_uri .password if parsed_uri .password is not None else password
322
-
323
- group = parsed_uri .path .split ("/" )
324
- db_name = "default"
325
- if len (group ) > 1 :
326
- db_name = group [1 ]
327
-
328
- # Set secure=True if username and password are provided
329
- if len (user ) > 0 and len (password ) > 0 :
330
- kwargs ["secure" ] = True
331
-
259
+ address = kwargs .pop ("address" , "" )
260
+ uri = kwargs .pop ("uri" , "" )
261
+ host = kwargs .pop ("host" , "" )
262
+ port = kwargs .pop ("port" , "" )
263
+ user = '' if user is None else str (user )
264
+ password = '' if password is None else str (password )
265
+ db_name = '' if db_name is None else str (db_name )
266
+
267
+ if set ([address , uri , host , port ]) != {'' }:
268
+ address , user , password , db_name = self .__parse_info (address , uri , host , port , db_name , user , password )
269
+ kwargs ["address" ] = address
270
+
271
+ elif alias in self ._alias :
272
+ kwargs = dict (self ._alias [alias ].items ())
273
+ # If user is passed in, use it, if not, use previous connections user.
274
+ prev_user = kwargs .pop ("user" )
275
+ user = user if user != "" else prev_user
276
+ # If db_name is passed in, use it, if not, use previous db_name.
277
+ prev_db_name = kwargs .pop ("db_name" )
278
+ db_name = db_name if db_name != "" else prev_db_name
332
279
333
- connect_milvus (** kwargs , user = user , password = password , db_name = db_name )
334
- return
335
-
336
- # 2nd Priority, connection configs from env
337
- if self ._env_uri is not None :
338
- addr , parsed_uri = self ._env_uri
339
- kwargs ["address" ] = addr
340
-
341
- user = parsed_uri .username if parsed_uri .username is not None else ""
342
- password = parsed_uri .password if parsed_uri .password is not None else ""
343
- # Set secure=True if uri provided user and password
344
- if len (user ) > 0 and len (password ) > 0 :
345
- kwargs ["secure" ] = True
280
+ # No params, env, and cached configs for the alias
281
+ else :
282
+ raise ConnectionConfigException (message = ExceptionsMessage .ConnLackConf % alias )
346
283
347
- connect_milvus (** kwargs , user = user , password = password , db_name = db_name )
348
- return
284
+ # Set secure=True if username and password are provided
285
+ if len (user ) > 0 and len (password ) > 0 :
286
+ kwargs ["secure" ] = True
349
287
350
- # 3rd Priority, connect to cached configs with provided user and password
351
- if alias in self ._alias :
352
- connect_alias = dict (self ._alias [alias ].items ())
353
- connect_alias ["user" ] = user
354
- connect_milvus (** connect_alias , password = password , db_name = db_name , ** kwargs )
355
- return
288
+ connect_milvus (** kwargs , user = user , password = password , db_name = db_name )
356
289
357
- # No params, env, and cached configs for the alias
358
- raise ConnectionConfigException (message = ExceptionsMessage .ConnLackConf % alias )
359
290
360
291
def list_connections (self ) -> list :
361
292
""" List names of all connections.
@@ -369,7 +300,8 @@ def list_connections(self) -> list:
369
300
>>> connections.list_connections()
370
301
// TODO [('default', None), ('test', <pymilvus.client.grpc_handler.GrpcHandler object at 0x7f05003f3e80>)]
371
302
"""
372
- return [(k , self ._connected_alias .get (k , None )) for k in self ._alias ]
303
+ with self ._con_lock :
304
+ return [(k , self ._connected_alias .get (k , None )) for k in self ._alias ]
373
305
374
306
def get_connection_addr (self , alias : str ):
375
307
"""
@@ -414,7 +346,99 @@ def has_connection(self, alias: str) -> bool:
414
346
"""
415
347
if not isinstance (alias , str ):
416
348
raise ConnectionConfigException (message = ExceptionsMessage .AliasType % type (alias ))
417
- return alias in self ._connected_alias
349
+ with self ._con_lock :
350
+ return alias in self ._connected_alias
351
+
352
+ def __parse_info (
353
+ self ,
354
+ address : str = "" ,
355
+ uri : str = "" ,
356
+ host : str = "" ,
357
+ port : str = "" ,
358
+ db_name : str = "" ,
359
+ user : str = "" ,
360
+ password : str = "" ,
361
+ ** kwargs ) -> dict :
362
+
363
+ passed_in_address = ""
364
+ passed_in_user = ""
365
+ passed_in_password = ""
366
+ passed_in_db_name = ""
367
+
368
+ # If uri
369
+ if uri != "" :
370
+ passed_in_address , passed_in_user , passed_in_password , passed_in_db_name = (
371
+ self .__parse_address_from_uri (uri )
372
+ )
373
+
374
+ elif address != "" :
375
+ if not is_legal_address (address ):
376
+ raise ConnectionConfigException (
377
+ message = f"Illegal address: { address } , should be in form 'localhost:19530'" )
378
+ passed_in_address = address
379
+
380
+ else :
381
+ if host == "" :
382
+ host = Config .DEFAULT_HOST
383
+ if port == "" :
384
+ port = Config .DEFAULT_PORT
385
+ self .__verify_host_port (host , port )
386
+ passed_in_address = f"{ host } :{ port } "
387
+
388
+ passed_in_user = user if passed_in_user == "" else str (passed_in_user )
389
+ passed_in_user = Config .MILVUS_USER if passed_in_user == "" else str (passed_in_user )
390
+
391
+ passed_in_password = password if passed_in_password == "" else str (passed_in_password )
392
+ passed_in_password = Config .MILVUS_PASSWORD if passed_in_password == "" else str (passed_in_password )
393
+
394
+ passed_in_db_name = db_name if passed_in_db_name == "" else str (passed_in_db_name )
395
+ passed_in_db_name = Config .MILVUS_DB_NAME if passed_in_db_name == "" else str (passed_in_db_name )
396
+
397
+ return passed_in_address , passed_in_user , passed_in_password , passed_in_db_name
398
+
399
+ def __verify_host_port (self , host , port ):
400
+ if not is_legal_host (host ):
401
+ raise ConnectionConfigException (message = ExceptionsMessage .HostType )
402
+ if not is_legal_port (port ):
403
+ raise ConnectionConfigException (message = ExceptionsMessage .PortType )
404
+ if not 0 <= int (port ) < 65535 :
405
+ raise ConnectionConfigException (message = f"port number { port } out of range, valid range [0, 65535)" )
406
+
407
+ def __parse_address_from_uri (self , uri : str ) -> Tuple [str , str , str , str ]:
408
+ illegal_uri_msg = "Illegal uri: [{}], expected form 'https://user:pwd@example.com:12345'"
409
+ try :
410
+ parsed_uri = parse .urlparse (uri )
411
+ except (Exception ) as e :
412
+ raise ConnectionConfigException (
413
+ message = f"{ illegal_uri_msg .format (uri )} : <{ type (e ).__name__ } , { e } >" ) from None
414
+
415
+ if len (parsed_uri .netloc ) == 0 :
416
+ raise ConnectionConfigException (message = f"{ illegal_uri_msg .format (uri )} " ) from None
417
+
418
+ group = parsed_uri .path .split ("/" )
419
+ if len (group ) > 1 :
420
+ db_name = group [1 ]
421
+ else :
422
+ db_name = ""
423
+
424
+ host = parsed_uri .hostname if parsed_uri .hostname is not None else ""
425
+ port = parsed_uri .port if parsed_uri .port is not None else ""
426
+ user = parsed_uri .username if parsed_uri .username is not None else ""
427
+ password = parsed_uri .password if parsed_uri .password is not None else ""
428
+
429
+ if host == "" :
430
+ raise ConnectionConfigException (message = f"Illegal uri: URI is missing host address: { uri } " )
431
+ if port == "" :
432
+ raise ConnectionConfigException (message = f"Illegal uri: URI is missing port: { uri } " )
433
+
434
+ self .__verify_host_port (host , port )
435
+ addr = f"{ host } :{ port } "
436
+
437
+ if not is_legal_address (addr ):
438
+ raise ConnectionConfigException (message = illegal_uri_msg .format (uri ))
439
+
440
+ return addr , user , password , db_name
441
+
418
442
419
443
def _fetch_handler (self , alias = Config .MILVUS_CONN_ALIAS ) -> GrpcHandler :
420
444
""" Retrieves a GrpcHandler by alias. """
0 commit comments