diff --git a/pymilvus/orm/types.py b/pymilvus/orm/types.py index eebbb50bc..fae883fa8 100644 --- a/pymilvus/orm/types.py +++ b/pymilvus/orm/types.py @@ -71,6 +71,14 @@ def is_numeric_datatype(data_type: DataType): return is_float_datatype(data_type) or is_integer_datatype(data_type) +def is_varchar_datatype(data_type: DataType): + return data_type in (DataType.VARCHAR,) + + +def is_bool_datatype(data_type: DataType): + return data_type in (DataType.BOOL,) + + # pylint: disable=too-many-return-statements def infer_dtype_by_scalar_data(data: Any): if isinstance(data, float): @@ -105,7 +113,7 @@ def infer_dtype_by_scalar_data(data: Any): return DataType.UNKNOWN -def infer_dtype_bydata(data: Any): +def infer_dtype_bydata(data: Any, **kwargs): d_type = DataType.UNKNOWN if is_scalar(data): return infer_dtype_by_scalar_data(data) @@ -121,7 +129,16 @@ def infer_dtype_bydata(data: Any): failed = True if not failed: d_type = dtype_str_map.get(type_str, DataType.UNKNOWN) - return DataType.FLOAT_VECTOR if is_numeric_datatype(d_type) else DataType.UNKNOWN + if is_varchar_datatype(d_type) or is_bool_datatype(d_type): + return DataType.ARRAY + if ( + kwargs is None + or len(kwargs) == 0 + or (kwargs["type"] is not None and kwargs["type"] == "vector") + ): + return DataType.FLOAT_VECTOR if is_numeric_datatype(d_type) else DataType.UNKNOWN + else: + return DataType.ARRAY if d_type == DataType.UNKNOWN: try: