Skip to content

Commit b661c26

Browse files
committed
fix: [cherry-pick]restrict input/search type for vector fields (#2025)
See also: #2018, #2004, #2016 --------- Signed-off-by: yangxuan <xuan.yang@zilliz.com>
1 parent 0068119 commit b661c26

File tree

8 files changed

+221
-151
lines changed

8 files changed

+221
-151
lines changed

pymilvus/client/entity_helper.py

+40-13
Original file line numberDiff line numberDiff line change
@@ -291,7 +291,9 @@ def entity_to_array_arr(entity: List[Any], field_info: Any):
291291
return convert_to_array_arr(entity.get("values", []), field_info)
292292

293293

294-
def pack_field_value_to_field_data(field_value: Any, field_data: Any, field_info: Any):
294+
def pack_field_value_to_field_data(
295+
field_value: Any, field_data: schema_types.FieldData, field_info: Any
296+
):
295297
field_type = field_data.type
296298
if field_type == DataType.BOOL:
297299
field_data.scalars.bool_data.data.append(field_value)
@@ -304,26 +306,51 @@ def pack_field_value_to_field_data(field_value: Any, field_data: Any, field_info
304306
elif field_type == DataType.DOUBLE:
305307
field_data.scalars.double_data.data.append(field_value)
306308
elif field_type == DataType.FLOAT_VECTOR:
307-
field_data.vectors.dim = len(field_value)
308-
field_data.vectors.float_vector.data.extend(field_value)
309+
f_value = field_value
310+
if isinstance(field_value, np.ndarray):
311+
if field_value.dtype not in ("float32", "float64"):
312+
raise ParamError(
313+
message="invalid input for float32 vector, expect np.ndarray with dtype=float32"
314+
)
315+
f_value = field_value.view(np.float32).tolist()
316+
317+
field_data.vectors.dim = len(f_value)
318+
field_data.vectors.float_vector.data.extend(f_value)
319+
309320
elif field_type == DataType.BINARY_VECTOR:
310321
field_data.vectors.dim = len(field_value) * 8
311322
field_data.vectors.binary_vector += bytes(field_value)
323+
312324
elif field_type == DataType.FLOAT16_VECTOR:
313-
v_bytes = (
314-
bytes(field_value)
315-
if not isinstance(field_value, np.ndarray)
316-
else field_value.view(np.uint8).tobytes()
317-
)
325+
if isinstance(field_value, bytes):
326+
v_bytes = field_value
327+
elif isinstance(field_value, np.ndarray):
328+
if field_value.dtype != "float16":
329+
raise ParamError(
330+
message="invalid input for float16 vector, expect np.ndarray with dtype=float16"
331+
)
332+
v_bytes = field_value.view(np.uint8).tobytes()
333+
else:
334+
raise ParamError(
335+
message="invalid input type for float16 vector, expect np.ndarray with dtype=float16"
336+
)
318337

319338
field_data.vectors.dim = len(v_bytes) // 2
320339
field_data.vectors.float16_vector += v_bytes
340+
321341
elif field_type == DataType.BFLOAT16_VECTOR:
322-
v_bytes = (
323-
bytes(field_value)
324-
if not isinstance(field_value, np.ndarray)
325-
else field_value.view(np.uint8).tobytes()
326-
)
342+
if isinstance(field_value, bytes):
343+
v_bytes = field_value
344+
elif isinstance(field_value, np.ndarray):
345+
if field_value.dtype != "bfloat16":
346+
raise ParamError(
347+
message="invalid input for bfloat16 vector, expect np.ndarray with dtype=bfloat16"
348+
)
349+
v_bytes = field_value.view(np.uint8).tobytes()
350+
else:
351+
raise ParamError(
352+
message="invalid input type for bfloat16 vector, expect np.ndarray with dtype=bfloat16"
353+
)
327354

328355
field_data.vectors.dim = len(v_bytes) // 2
329356
field_data.vectors.bfloat16_vector += v_bytes

pymilvus/client/grpc_handler.py

+29-26
Original file line numberDiff line numberDiff line change
@@ -475,46 +475,49 @@ def _get_info(self, collection_name: str, timeout: Optional[float] = None, **kwa
475475

476476
return fields_info, enable_dynamic
477477

478-
def _prepare_row_insert_request(
478+
@retry_on_rpc_failure()
479+
def insert_rows(
479480
self,
480481
collection_name: str,
481-
entity_rows: List,
482+
entities: Union[Dict, List[Dict]],
482483
partition_name: Optional[str] = None,
484+
schema: Optional[dict] = None,
483485
timeout: Optional[float] = None,
484486
**kwargs,
485487
):
486-
if not isinstance(entity_rows, list):
487-
raise ParamError(message="None rows, please provide valid row data.")
488-
489-
fields_info, enable_dynamic = self._get_info(collection_name, timeout, **kwargs)
490-
return Prepare.row_insert_param(
491-
collection_name,
492-
entity_rows,
493-
partition_name,
494-
fields_info,
495-
enable_dynamic=enable_dynamic,
488+
request = self._prepare_row_insert_request(
489+
collection_name, entities, partition_name, timeout, **kwargs
496490
)
491+
resp = self._stub.Insert(request=request, timeout=timeout)
492+
check_status(resp.status)
493+
ts_utils.update_collection_ts(collection_name, resp.timestamp)
494+
return MutationResult(resp)
497495

498-
@retry_on_rpc_failure()
499-
def insert_rows(
496+
def _prepare_row_insert_request(
500497
self,
501498
collection_name: str,
502-
entities: List,
499+
entity_rows: Union[List[Dict], Dict],
503500
partition_name: Optional[str] = None,
501+
schema: Optional[dict] = None,
504502
timeout: Optional[float] = None,
505503
**kwargs,
506504
):
507-
if isinstance(entities, dict):
508-
entities = [entities]
509-
request = self._prepare_row_insert_request(
510-
collection_name, entities, partition_name, timeout, **kwargs
505+
if isinstance(entity_rows, dict):
506+
entity_rows = [entity_rows]
507+
508+
if not isinstance(schema, dict):
509+
schema = self.describe_collection(collection_name, timeout=timeout)
510+
511+
fields_info = schema.get("fields")
512+
enable_dynamic = schema.get("enable_dynamic_field", False)
513+
514+
return Prepare.row_insert_param(
515+
collection_name,
516+
entity_rows,
517+
partition_name,
518+
fields_info,
519+
enable_dynamic=enable_dynamic,
511520
)
512-
rf = self._stub.Insert.future(request, timeout=timeout)
513-
response = rf.result()
514-
check_status(response.status)
515-
m = MutationResult(response)
516-
ts_utils.update_collection_ts(collection_name, m.timestamp)
517-
return m
518521

519522
def _prepare_batch_insert_request(
520523
self,
@@ -1376,7 +1379,7 @@ def _wait_for_flushed(
13761379
end = time.time()
13771380
if timeout is not None and end - start > timeout:
13781381
raise MilvusException(
1379-
message=f"wait for flush timeout, collection: {collection_name}"
1382+
message=f"wait for flush timeout, collection: {collection_name}, flusht_ts: {flush_ts}"
13801383
)
13811384

13821385
if not flush_ret:

pymilvus/client/prepare.py

+22-16
Original file line numberDiff line numberDiff line change
@@ -364,12 +364,10 @@ def _parse_row_request(
364364
field["name"]: field for field in fields_info if not field.get("auto_id", False)
365365
}
366366

367-
meta_field = (
368-
schema_types.FieldData(is_dynamic=True, type=DataType.JSON) if enable_dynamic else None
369-
)
370-
if meta_field is not None:
371-
field_info_map[meta_field.field_name] = meta_field
372-
fields_data[meta_field.field_name] = meta_field
367+
if enable_dynamic:
368+
d_field = schema_types.FieldData(is_dynamic=True, type=DataType.JSON)
369+
fields_data[d_field.field_name] = d_field
370+
field_info_map[d_field.field_name] = d_field
373371

374372
try:
375373
for entity in entities:
@@ -390,7 +388,7 @@ def _parse_row_request(
390388

391389
if enable_dynamic:
392390
json_value = entity_helper.convert_to_json(json_dict)
393-
meta_field.scalars.json_data.data.append(json_value)
391+
d_field.scalars.json_data.data.append(json_value)
394392

395393
except (TypeError, ValueError) as e:
396394
raise DataNotMatchException(message=ExceptionsMessage.DataTypeInconsistent) from e
@@ -400,7 +398,7 @@ def _parse_row_request(
400398
)
401399

402400
if enable_dynamic:
403-
request.fields_data.append(meta_field)
401+
request.fields_data.append(d_field)
404402

405403
_, _, auto_id_loc = traverse_rows_info(fields_info, entities)
406404
if auto_id_loc is not None:
@@ -418,16 +416,18 @@ def row_insert_param(
418416
collection_name: str,
419417
entities: List,
420418
partition_name: str,
421-
fields_info: Any,
419+
fields_info: Dict,
422420
enable_dynamic: bool = False,
423421
):
424422
if not fields_info:
425423
raise ParamError(message="Missing collection meta to validate entities")
426424

427425
# insert_request.hash_keys won't be filled in client.
428-
tag = partition_name if isinstance(partition_name, str) else ""
426+
p_name = partition_name if isinstance(partition_name, str) else ""
429427
request = milvus_types.InsertRequest(
430-
collection_name=collection_name, partition_name=tag, num_rows=len(entities)
428+
collection_name=collection_name,
429+
partition_name=p_name,
430+
num_rows=len(entities),
431431
)
432432

433433
return cls._parse_row_request(request, fields_info, enable_dynamic, entities)
@@ -445,9 +445,11 @@ def row_upsert_param(
445445
raise ParamError(message="Missing collection meta to validate entities")
446446

447447
# upsert_request.hash_keys won't be filled in client.
448-
tag = partition_name if isinstance(partition_name, str) else ""
448+
p_name = partition_name if isinstance(partition_name, str) else ""
449449
request = milvus_types.UpsertRequest(
450-
collection_name=collection_name, partition_name=tag, num_rows=len(entities)
450+
collection_name=collection_name,
451+
partition_name=p_name,
452+
num_rows=len(entities),
451453
)
452454

453455
return cls._parse_row_request(request, fields_info, enable_dynamic, entities)
@@ -469,7 +471,7 @@ def _pre_batch_check(
469471
if not fields_info:
470472
raise ParamError(message="Missing collection meta to validate entities")
471473

472-
location, primary_key_loc, auto_id_loc = traverse_info(fields_info, entities)
474+
location, primary_key_loc, auto_id_loc = traverse_info(fields_info)
473475

474476
# though impossible from sdk
475477
if primary_key_loc is None:
@@ -583,16 +585,20 @@ def _prepare_placeholder_str(cls, data: Any):
583585

584586
elif isinstance(data[0], np.ndarray):
585587
dtype = data[0].dtype
586-
pl_values = (array.tobytes() for array in data)
587588

588589
if dtype == "bfloat16":
589590
pl_type = PlaceholderType.BFLOAT16_VECTOR
591+
pl_values = (array.tobytes() for array in data)
590592
elif dtype == "float16":
591593
pl_type = PlaceholderType.FLOAT16_VECTOR
592-
elif dtype == "float32":
594+
pl_values = (array.tobytes() for array in data)
595+
elif dtype in ("float32", "float64"):
593596
pl_type = PlaceholderType.FloatVector
597+
pl_values = (blob.vector_float_to_bytes(entity) for entity in data)
598+
594599
elif dtype == "byte":
595600
pl_type = PlaceholderType.BinaryVector
601+
pl_values = data
596602

597603
else:
598604
err_msg = f"unsupported data type: {dtype}"

pymilvus/client/utils.py

+2-53
Original file line numberDiff line numberDiff line change
@@ -250,7 +250,7 @@ def traverse_rows_info(fields_info: Any, entities: List):
250250
return location, primary_key_loc, auto_id_loc
251251

252252

253-
def traverse_info(fields_info: Any, entities: List):
253+
def traverse_info(fields_info: Any):
254254
location, primary_key_loc, auto_id_loc = {}, None, None
255255
for i, field in enumerate(fields_info):
256256
if field.get("is_primary", False):
@@ -259,58 +259,7 @@ def traverse_info(fields_info: Any, entities: List):
259259
if field.get("auto_id", False):
260260
auto_id_loc = i
261261
continue
262-
263-
match_flag = False
264-
field_name = field["name"]
265-
field_type = field["type"]
266-
267-
for entity in entities:
268-
entity_name, entity_type = entity["name"], entity["type"]
269-
270-
if field_name == entity_name:
271-
if field_type != entity_type:
272-
raise ParamError(
273-
message=f"Collection field type is {field_type}"
274-
f", but entities field type is {entity_type}"
275-
)
276-
277-
entity_dim, field_dim = 0, 0
278-
if entity_type in [
279-
DataType.FLOAT_VECTOR,
280-
DataType.BINARY_VECTOR,
281-
DataType.BFLOAT16_VECTOR,
282-
DataType.FLOAT16_VECTOR,
283-
]:
284-
field_dim = field["params"]["dim"]
285-
entity_dim = len(entity["values"][0])
286-
287-
if entity_type in [DataType.FLOAT_VECTOR] and entity_dim != field_dim:
288-
raise ParamError(
289-
message=f"Collection field dim is {field_dim}"
290-
f", but entities field dim is {entity_dim}"
291-
)
292-
293-
if entity_type in [DataType.BINARY_VECTOR] and entity_dim * 8 != field_dim:
294-
raise ParamError(
295-
message=f"Collection field dim is {field_dim}"
296-
f", but entities field dim is {entity_dim * 8}"
297-
)
298-
299-
if (
300-
entity_type in [DataType.BFLOAT16_VECTOR, DataType.FLOAT16_VECTOR]
301-
and int(entity_dim // 2) != field_dim
302-
):
303-
raise ParamError(
304-
message=f"Collection field dim is {field_dim}"
305-
f", but entities field dim is {int(entity_dim // 2)}"
306-
)
307-
308-
location[field["name"]] = i
309-
match_flag = True
310-
break
311-
312-
if not match_flag:
313-
raise ParamError(message=f"Field {field['name']} don't match in entities")
262+
location[field["name"]] = i
314263

315264
return location, primary_key_loc, auto_id_loc
316265

0 commit comments

Comments
 (0)