|
| 1 | +import time |
| 2 | +import random |
| 3 | +import numpy as np |
| 4 | +from pymilvus import ( |
| 5 | + connections, |
| 6 | + utility, |
| 7 | + FieldSchema, CollectionSchema, DataType, |
| 8 | + Collection, |
| 9 | + ) |
| 10 | + |
| 11 | + |
| 12 | +bin_index_types = ["BIN_FLAT", "BIN_IVF_FLAT"] |
| 13 | + |
| 14 | +default_bin_index_params = [{"nlist": 128}, {"nlist": 128}] |
| 15 | + |
| 16 | +def gen_binary_vectors(num, dim): |
| 17 | + raw_vectors = [] |
| 18 | + binary_vectors = [] |
| 19 | + for _ in range(num): |
| 20 | + raw_vector = [random.randint(0, 1) for _ in range(dim)] |
| 21 | + raw_vectors.append(raw_vector) |
| 22 | + # packs a binary-valued array into bits in a unit8 array, and bytes array_of_ints |
| 23 | + binary_vectors.append(bytes(np.packbits(raw_vector, axis=-1).tolist())) |
| 24 | + return raw_vectors, binary_vectors |
| 25 | + |
| 26 | + |
| 27 | +def binary_vector_search(): |
| 28 | + connections.connect() |
| 29 | + int64_field = FieldSchema(name="int64", dtype=DataType.INT64, is_primary=True, auto_id=True) |
| 30 | + dim = 128 |
| 31 | + nb = 3000 |
| 32 | + vector_field_name = "binary_vector" |
| 33 | + binary_vector = FieldSchema(name=vector_field_name, dtype=DataType.BINARY_VECTOR, dim=dim) |
| 34 | + schema = CollectionSchema(fields=[int64_field, binary_vector], enable_dynamic_field=True) |
| 35 | + |
| 36 | + has = utility.has_collection("hello_milvus") |
| 37 | + if has: |
| 38 | + hello_milvus = Collection("hello_milvus_bin") |
| 39 | + hello_milvus.drop() |
| 40 | + else: |
| 41 | + hello_milvus = Collection("hello_milvus_bin", schema) |
| 42 | + |
| 43 | + _, vectors = gen_binary_vectors(nb, dim) |
| 44 | + rows = [ |
| 45 | + {vector_field_name: vectors[0]}, |
| 46 | + {vector_field_name: vectors[1]}, |
| 47 | + {vector_field_name: vectors[2]}, |
| 48 | + {vector_field_name: vectors[3]}, |
| 49 | + {vector_field_name: vectors[4]}, |
| 50 | + {vector_field_name: vectors[5]}, |
| 51 | + ] |
| 52 | + |
| 53 | + hello_milvus.insert(rows) |
| 54 | + hello_milvus.flush() |
| 55 | + for i, index_type in enumerate(bin_index_types): |
| 56 | + index_params = default_bin_index_params[i] |
| 57 | + hello_milvus.create_index(vector_field_name, |
| 58 | + index_params={"index_type": index_type, "params": index_params, "metric_type": "HAMMING"}) |
| 59 | + hello_milvus.load() |
| 60 | + print("index_type = ", index_type) |
| 61 | + res = hello_milvus.search(vectors[:1], vector_field_name, {"metric_type": "HAMMING"}, limit=1) |
| 62 | + print("res = ", res) |
| 63 | + hello_milvus.release() |
| 64 | + hello_milvus.drop_index() |
| 65 | + hello_milvus.drop() |
| 66 | + |
| 67 | + |
| 68 | +if __name__ == "__main__": |
| 69 | + binary_vector_search() |
0 commit comments