Skip to content

enhance: Support AddCollectionField API #2722

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 22 additions & 1 deletion pymilvus/client/grpc_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,14 @@
from pymilvus.settings import Config

from . import entity_helper, interceptor, ts_utils, utils
from .abstract import AnnSearchRequest, BaseRanker, CollectionSchema, MutationResult, SearchResult
from .abstract import (
AnnSearchRequest,
BaseRanker,
CollectionSchema,
FieldSchema,
MutationResult,
SearchResult,
)
from .asynch import (
CreateIndexFuture,
FlushFuture,
Expand Down Expand Up @@ -324,6 +331,20 @@ def drop_collection(self, collection_name: str, timeout: Optional[float] = None)
status = rf.result()
check_status(status)

@retry_on_rpc_failure()
def add_collection_field(
    self,
    collection_name: str,
    field_schema: FieldSchema,
    timeout: Optional[float] = None,
    **kwargs,
):
    """Issue the AddCollectionField RPC for ``collection_name``.

    Args:
        collection_name: Name of the collection the field is added to.
        field_schema: Schema of the field to add; handed to
            ``Prepare.add_collection_field_request`` to build the request.
        timeout: Passed both to the parameter check and to the gRPC call.
        **kwargs: Accepted by the signature; not read inside this method body.

    The status returned by the RPC is passed to ``check_status``.
    """
    check_pass_param(collection_name=collection_name, timeout=timeout)
    add_field_request = Prepare.add_collection_field_request(collection_name, field_schema)
    response_future = self._stub.AddCollectionField.future(add_field_request, timeout=timeout)
    # Block on the future, then let check_status inspect the returned status.
    check_status(response_future.result())

@retry_on_rpc_failure()
def alter_collection_properties(
self, collection_name: str, properties: List, timeout: Optional[float] = None, **kwargs
Expand Down
14 changes: 13 additions & 1 deletion pymilvus/client/prepare.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from pymilvus.grpc_gen import common_pb2 as common_types
from pymilvus.grpc_gen import milvus_pb2 as milvus_types
from pymilvus.grpc_gen import schema_pb2 as schema_types
from pymilvus.orm.schema import CollectionSchema
from pymilvus.orm.schema import CollectionSchema, FieldSchema
from pymilvus.orm.types import infer_dtype_by_scalar_data

from . import __version__, blob, check, entity_helper, ts_utils, utils
Expand Down Expand Up @@ -272,6 +272,18 @@ def get_schema(
def drop_collection_request(cls, collection_name: str) -> milvus_types.DropCollectionRequest:
return milvus_types.DropCollectionRequest(collection_name=collection_name)

@classmethod
def add_collection_field_request(
    cls,
    collection_name: str,
    field_schema: FieldSchema,
) -> milvus_types.AddCollectionFieldRequest:
    """Build the ``AddCollectionFieldRequest`` for one new field.

    The field schema is converted with ``to_dict()``, turned into a proto by
    ``cls.get_field_schema`` and placed in the request's ``schema`` field in
    serialized (bytes) form.

    Args:
        collection_name: Name of the target collection.
        field_schema: Schema of the field to add.

    Returns:
        The populated ``milvus_types.AddCollectionFieldRequest``.
    """
    field_as_dict = field_schema.to_dict()
    # get_field_schema yields three values; only the first (the proto) is used here.
    schema_proto, _, _ = cls.get_field_schema(field=field_as_dict)
    serialized_schema = bytes(schema_proto.SerializeToString())
    return milvus_types.AddCollectionFieldRequest(
        collection_name=collection_name,
        schema=serialized_schema,
    )

@classmethod
def describe_collection_request(
cls,
Expand Down
592 changes: 296 additions & 296 deletions pymilvus/grpc_gen/milvus_pb2.py

Large diffs are not rendered by default.

12 changes: 8 additions & 4 deletions pymilvus/grpc_gen/milvus_pb2.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -819,7 +819,7 @@ class SubSearchRequest(_message.Message):
def __init__(self, dsl: _Optional[str] = ..., placeholder_group: _Optional[bytes] = ..., dsl_type: _Optional[_Union[_common_pb2.DslType, str]] = ..., search_params: _Optional[_Iterable[_Union[_common_pb2.KeyValuePair, _Mapping]]] = ..., nq: _Optional[int] = ..., expr_template_values: _Optional[_Mapping[str, _schema_pb2.TemplateValue]] = ...) -> None: ...

class SearchRequest(_message.Message):
__slots__ = ("base", "db_name", "collection_name", "partition_names", "dsl", "placeholder_group", "dsl_type", "output_fields", "search_params", "travel_timestamp", "guarantee_timestamp", "nq", "not_return_all_meta", "consistency_level", "use_default_consistency", "search_by_primary_keys", "sub_reqs", "expr_template_values")
__slots__ = ("base", "db_name", "collection_name", "partition_names", "dsl", "placeholder_group", "dsl_type", "output_fields", "search_params", "travel_timestamp", "guarantee_timestamp", "nq", "not_return_all_meta", "consistency_level", "use_default_consistency", "search_by_primary_keys", "sub_reqs", "expr_template_values", "function_score")
class ExprTemplateValuesEntry(_message.Message):
__slots__ = ("key", "value")
KEY_FIELD_NUMBER: _ClassVar[int]
Expand All @@ -845,6 +845,7 @@ class SearchRequest(_message.Message):
SEARCH_BY_PRIMARY_KEYS_FIELD_NUMBER: _ClassVar[int]
SUB_REQS_FIELD_NUMBER: _ClassVar[int]
EXPR_TEMPLATE_VALUES_FIELD_NUMBER: _ClassVar[int]
FUNCTION_SCORE_FIELD_NUMBER: _ClassVar[int]
base: _common_pb2.MsgBase
db_name: str
collection_name: str
Expand All @@ -863,7 +864,8 @@ class SearchRequest(_message.Message):
search_by_primary_keys: bool
sub_reqs: _containers.RepeatedCompositeFieldContainer[SubSearchRequest]
expr_template_values: _containers.MessageMap[str, _schema_pb2.TemplateValue]
def __init__(self, base: _Optional[_Union[_common_pb2.MsgBase, _Mapping]] = ..., db_name: _Optional[str] = ..., collection_name: _Optional[str] = ..., partition_names: _Optional[_Iterable[str]] = ..., dsl: _Optional[str] = ..., placeholder_group: _Optional[bytes] = ..., dsl_type: _Optional[_Union[_common_pb2.DslType, str]] = ..., output_fields: _Optional[_Iterable[str]] = ..., search_params: _Optional[_Iterable[_Union[_common_pb2.KeyValuePair, _Mapping]]] = ..., travel_timestamp: _Optional[int] = ..., guarantee_timestamp: _Optional[int] = ..., nq: _Optional[int] = ..., not_return_all_meta: bool = ..., consistency_level: _Optional[_Union[_common_pb2.ConsistencyLevel, str]] = ..., use_default_consistency: bool = ..., search_by_primary_keys: bool = ..., sub_reqs: _Optional[_Iterable[_Union[SubSearchRequest, _Mapping]]] = ..., expr_template_values: _Optional[_Mapping[str, _schema_pb2.TemplateValue]] = ...) -> None: ...
function_score: _schema_pb2.FunctionScore
def __init__(self, base: _Optional[_Union[_common_pb2.MsgBase, _Mapping]] = ..., db_name: _Optional[str] = ..., collection_name: _Optional[str] = ..., partition_names: _Optional[_Iterable[str]] = ..., dsl: _Optional[str] = ..., placeholder_group: _Optional[bytes] = ..., dsl_type: _Optional[_Union[_common_pb2.DslType, str]] = ..., output_fields: _Optional[_Iterable[str]] = ..., search_params: _Optional[_Iterable[_Union[_common_pb2.KeyValuePair, _Mapping]]] = ..., travel_timestamp: _Optional[int] = ..., guarantee_timestamp: _Optional[int] = ..., nq: _Optional[int] = ..., not_return_all_meta: bool = ..., consistency_level: _Optional[_Union[_common_pb2.ConsistencyLevel, str]] = ..., use_default_consistency: bool = ..., search_by_primary_keys: bool = ..., sub_reqs: _Optional[_Iterable[_Union[SubSearchRequest, _Mapping]]] = ..., expr_template_values: _Optional[_Mapping[str, _schema_pb2.TemplateValue]] = ..., function_score: _Optional[_Union[_schema_pb2.FunctionScore, _Mapping]] = ...) -> None: ...

class Hits(_message.Message):
__slots__ = ("IDs", "row_data", "scores")
Expand All @@ -888,7 +890,7 @@ class SearchResults(_message.Message):
def __init__(self, status: _Optional[_Union[_common_pb2.Status, _Mapping]] = ..., results: _Optional[_Union[_schema_pb2.SearchResultData, _Mapping]] = ..., collection_name: _Optional[str] = ..., session_ts: _Optional[int] = ...) -> None: ...

class HybridSearchRequest(_message.Message):
__slots__ = ("base", "db_name", "collection_name", "partition_names", "requests", "rank_params", "travel_timestamp", "guarantee_timestamp", "not_return_all_meta", "output_fields", "consistency_level", "use_default_consistency")
__slots__ = ("base", "db_name", "collection_name", "partition_names", "requests", "rank_params", "travel_timestamp", "guarantee_timestamp", "not_return_all_meta", "output_fields", "consistency_level", "use_default_consistency", "function_score")
BASE_FIELD_NUMBER: _ClassVar[int]
DB_NAME_FIELD_NUMBER: _ClassVar[int]
COLLECTION_NAME_FIELD_NUMBER: _ClassVar[int]
Expand All @@ -901,6 +903,7 @@ class HybridSearchRequest(_message.Message):
OUTPUT_FIELDS_FIELD_NUMBER: _ClassVar[int]
CONSISTENCY_LEVEL_FIELD_NUMBER: _ClassVar[int]
USE_DEFAULT_CONSISTENCY_FIELD_NUMBER: _ClassVar[int]
FUNCTION_SCORE_FIELD_NUMBER: _ClassVar[int]
base: _common_pb2.MsgBase
db_name: str
collection_name: str
Expand All @@ -913,7 +916,8 @@ class HybridSearchRequest(_message.Message):
output_fields: _containers.RepeatedScalarFieldContainer[str]
consistency_level: _common_pb2.ConsistencyLevel
use_default_consistency: bool
def __init__(self, base: _Optional[_Union[_common_pb2.MsgBase, _Mapping]] = ..., db_name: _Optional[str] = ..., collection_name: _Optional[str] = ..., partition_names: _Optional[_Iterable[str]] = ..., requests: _Optional[_Iterable[_Union[SearchRequest, _Mapping]]] = ..., rank_params: _Optional[_Iterable[_Union[_common_pb2.KeyValuePair, _Mapping]]] = ..., travel_timestamp: _Optional[int] = ..., guarantee_timestamp: _Optional[int] = ..., not_return_all_meta: bool = ..., output_fields: _Optional[_Iterable[str]] = ..., consistency_level: _Optional[_Union[_common_pb2.ConsistencyLevel, str]] = ..., use_default_consistency: bool = ...) -> None: ...
function_score: _schema_pb2.FunctionScore
def __init__(self, base: _Optional[_Union[_common_pb2.MsgBase, _Mapping]] = ..., db_name: _Optional[str] = ..., collection_name: _Optional[str] = ..., partition_names: _Optional[_Iterable[str]] = ..., requests: _Optional[_Iterable[_Union[SearchRequest, _Mapping]]] = ..., rank_params: _Optional[_Iterable[_Union[_common_pb2.KeyValuePair, _Mapping]]] = ..., travel_timestamp: _Optional[int] = ..., guarantee_timestamp: _Optional[int] = ..., not_return_all_meta: bool = ..., output_fields: _Optional[_Iterable[str]] = ..., consistency_level: _Optional[_Union[_common_pb2.ConsistencyLevel, str]] = ..., use_default_consistency: bool = ..., function_score: _Optional[_Union[_schema_pb2.FunctionScore, _Mapping]] = ...) -> None: ...

class FlushRequest(_message.Message):
__slots__ = ("base", "db_name", "collection_names")
Expand Down
116 changes: 59 additions & 57 deletions pymilvus/grpc_gen/schema_pb2.py

Large diffs are not rendered by default.

10 changes: 10 additions & 0 deletions pymilvus/grpc_gen/schema_pb2.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ class FunctionType(int, metaclass=_enum_type_wrapper.EnumTypeWrapper):
Unknown: _ClassVar[FunctionType]
BM25: _ClassVar[FunctionType]
TextEmbedding: _ClassVar[FunctionType]
Rerank: _ClassVar[FunctionType]

class FieldState(int, metaclass=_enum_type_wrapper.EnumTypeWrapper):
__slots__ = ()
Expand Down Expand Up @@ -66,6 +67,7 @@ Int8Vector: DataType
Unknown: FunctionType
BM25: FunctionType
TextEmbedding: FunctionType
Rerank: FunctionType
FieldCreated: FieldState
FieldCreating: FieldState
FieldDropping: FieldState
Expand Down Expand Up @@ -129,6 +131,14 @@ class FunctionSchema(_message.Message):
params: _containers.RepeatedCompositeFieldContainer[_common_pb2.KeyValuePair]
def __init__(self, name: _Optional[str] = ..., id: _Optional[int] = ..., description: _Optional[str] = ..., type: _Optional[_Union[FunctionType, str]] = ..., input_field_names: _Optional[_Iterable[str]] = ..., input_field_ids: _Optional[_Iterable[int]] = ..., output_field_names: _Optional[_Iterable[str]] = ..., output_field_ids: _Optional[_Iterable[int]] = ..., params: _Optional[_Iterable[_Union[_common_pb2.KeyValuePair, _Mapping]]] = ...) -> None: ...

# Type stub for the ``FunctionScore`` message: declares two repeated fields,
# ``functions`` (FunctionSchema entries) and ``params`` (KeyValuePair entries).
class FunctionScore(_message.Message):
__slots__ = ("functions", "params")
# Class-level integer constants, one per message field.
FUNCTIONS_FIELD_NUMBER: _ClassVar[int]
PARAMS_FIELD_NUMBER: _ClassVar[int]
# Repeated composite containers holding the message's field values.
functions: _containers.RepeatedCompositeFieldContainer[FunctionSchema]
params: _containers.RepeatedCompositeFieldContainer[_common_pb2.KeyValuePair]
# Both constructor arguments are optional; each entry may be given as a message or a mapping.
def __init__(self, functions: _Optional[_Iterable[_Union[FunctionSchema, _Mapping]]] = ..., params: _Optional[_Iterable[_Union[_common_pb2.KeyValuePair, _Mapping]]] = ...) -> None: ...

class CollectionSchema(_message.Message):
__slots__ = ("name", "description", "autoID", "fields", "enable_dynamic_field", "properties", "functions", "dbName")
NAME_FIELD_NUMBER: _ClassVar[int]
Expand Down
47 changes: 46 additions & 1 deletion pymilvus/milvus_client/milvus_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
ServerVersionIncompatibleException,
)
from pymilvus.orm import utility
from pymilvus.orm.collection import CollectionSchema
from pymilvus.orm.collection import CollectionSchema, FieldSchema
from pymilvus.orm.connections import connections
from pymilvus.orm.constants import FIELDS, METRIC_TYPE, TYPE, UNLIMITED
from pymilvus.orm.iterator import QueryIterator, SearchIterator
Expand Down Expand Up @@ -873,6 +873,23 @@ def create_schema(cls, **kwargs):
kwargs["check_fields"] = False # do not check fields for now
return CollectionSchema([], **kwargs)

@classmethod
def create_field_schema(
    cls, name: str, data_type: DataType, desc: str = "", **kwargs
) -> FieldSchema:
    """Create a field schema. Wrapping orm.FieldSchema.

    Args:
        name (str): The name of the field.
        data_type (DataType): The data type of the field.
        desc (str): The description of the field. Defaults to an empty string.
        **kwargs: Additional keyword arguments, forwarded unchanged to
            ``FieldSchema``.

    Returns:
        FieldSchema: the FieldSchema created.
    """
    # Arguments are passed positionally in the order (name, data_type, desc).
    return FieldSchema(name, data_type, desc, **kwargs)

@classmethod
def prepare_index_params(cls, field_name: str = "", **kwargs) -> IndexParams:
index_params = IndexParams()
Expand Down Expand Up @@ -1103,6 +1120,34 @@ def alter_collection_field(
**kwargs,
)

def add_collection_field(
    self,
    collection_name: str,
    field_schema: FieldSchema,
    timeout: Optional[float] = None,
    **kwargs,
):
    """Append a new field to an existing collection.

    Delegates to ``add_collection_field`` on the connection returned by
    ``self._get_connection()``.

    Args:
        collection_name (``string``): Name of the collection to modify.
        field_schema (``FieldSchema``): Schema of the field being added.
        timeout (``float``, optional): Duration in seconds allowed for the RPC.
            When None, the client keeps waiting until the server responds or
            an error occurs.
        **kwargs (``dict``): Optional field params, forwarded to the connection.

    Raises:
        MilvusException: If anything goes wrong
    """
    connection = self._get_connection()
    connection.add_collection_field(collection_name, field_schema, timeout=timeout, **kwargs)

def create_partition(
self, collection_name: str, partition_name: str, timeout: Optional[float] = None, **kwargs
):
Expand Down