16
16
17
17
18
18
class Pinecone (VectorStore ):
19
+ pinecone : None
20
+
19
21
def __init__(self, system: System):
    """Initialise the store and create the Pinecone client.

    The API key is taken from the ``PINECONE_API_KEY`` environment
    variable and the resulting client is kept on ``self.pinecone``.

    Args:
        system: Passed through unchanged to the parent initialiser.

    Raises:
        ValueError: If ``PINECONE_API_KEY`` is not set.
    """
    super().__init__(system)

    env_var = "PINECONE_API_KEY"
    if env_var not in os.environ:
        raise ValueError("PINECONE_API_KEY environment variable not set")

    # Only the API key is needed to build the client here.
    self.pinecone = pinecone.Pinecone(api_key=os.environ[env_var])
28
28
29
29
@override
30
30
def query (
@@ -34,7 +34,7 @@ def query(
34
34
collection : str ,
35
35
num_results : int ,
36
36
) -> list :
37
- index = pinecone .Index (collection )
37
+ index = self . pinecone .Index (name = collection )
38
38
db_connection_repository = DatabaseConnectionRepository (
39
39
self .system .instance (DB )
40
40
)
@@ -44,18 +44,18 @@ def query(
44
44
)
45
45
xq = embedding .embed_query (query_texts [0 ])
46
46
query_response = index .query (
47
- queries = [xq ],
47
+ vector = [xq ],
48
48
filter = {
49
49
"db_connection_id" : {"$eq" : db_connection_id },
50
50
},
51
51
top_k = num_results ,
52
52
include_metadata = True ,
53
53
)
54
- return query_response .to_dict ()["results" ][ 0 ][ " matches" ]
54
+ return query_response .to_dict ()["matches" ]
55
55
56
56
@override
57
57
def add_records (self , golden_sqls : List [GoldenSQL ], collection : str ):
58
- if collection not in pinecone .list_indexes ():
58
+ if collection not in self . pinecone .list_indexes (). names ():
59
59
self .create_collection (collection )
60
60
db_connection_repository = DatabaseConnectionRepository (
61
61
self .system .instance (DB )
@@ -66,7 +66,7 @@ def add_records(self, golden_sqls: List[GoldenSQL], collection: str):
66
66
embedding = OpenAIEmbeddings (
67
67
openai_api_key = database_connection .decrypt_api_key (), model = EMBEDDING_MODEL
68
68
)
69
- index = pinecone .Index (collection )
69
+ index = self . pinecone .Index (name = collection )
70
70
batch_limit = 100
71
71
for limit_index in range (0 , len (golden_sqls ), batch_limit ):
72
72
golden_sql_batch = golden_sqls [limit_index : limit_index + batch_limit ]
@@ -101,7 +101,7 @@ def add_record(
101
101
metadata : Any ,
102
102
ids : List ,
103
103
):
104
- if collection not in pinecone .list_indexes ():
104
+ if collection not in self . pinecone .list_indexes (). names ():
105
105
self .create_collection (collection )
106
106
db_connection_repository = DatabaseConnectionRepository (
107
107
self .system .instance (DB )
@@ -110,22 +110,27 @@ def add_record(
110
110
embedding = OpenAIEmbeddings (
111
111
openai_api_key = database_connection .decrypt_api_key (), model = EMBEDDING_MODEL
112
112
)
113
- index = pinecone .Index (collection )
113
+ index = self . pinecone .Index (name = collection )
114
114
embeds = embedding .embed_documents ([documents ])
115
115
record = [(ids [0 ], embeds , metadata [0 ])]
116
116
index .upsert (vectors = record )
117
117
118
118
@override
def delete_record(self, collection: str, id: str):
    """Delete one vector, identified by ``id``, from a Pinecone index.

    If no index named ``collection`` exists there is nothing to delete, so
    the call is a no-op. (Previously a missing index was created first and
    the delete was then issued against the new, empty index.)

    Args:
        collection: Name of the Pinecone index to delete from.
        id: Identifier of the vector to remove.
    """
    if collection not in self.pinecone.list_indexes().names():
        # Deleting must not provision a new index as a side effect.
        return
    self.pinecone.Index(name=collection).delete(ids=[id])
124
124
125
125
@override
def delete_collection(self, collection: str):
    """Delete the Pinecone index named ``collection``.

    Args:
        collection: Name of the index to remove.

    Returns:
        Whatever the client's ``delete_index`` call returns.
    """
    outcome = self.pinecone.delete_index(name=collection)
    return outcome
128
128
129
129
@override
def create_collection(
    self,
    collection: str,
    *,
    dimension: int = 1536,
    metric: str = "cosine",
    cloud: str = "aws",
    region: str = "us-west-2",
):
    """Create a serverless Pinecone index named ``collection``.

    The keyword-only parameters default to the values that were previously
    hard-coded, so existing ``create_collection(name)`` calls behave the
    same as before.

    Args:
        collection: Name of the index to create.
        dimension: Vector dimension of the index.
            NOTE(review): 1536 presumably matches the embedding model's
            output size -- confirm against EMBEDDING_MODEL.
        metric: Similarity metric for the index.
        cloud: Cloud provider passed to ``ServerlessSpec``.
        region: Cloud region passed to ``ServerlessSpec``.
    """
    serverless_spec = pinecone.ServerlessSpec(cloud=cloud, region=region)
    self.pinecone.create_index(
        name=collection,
        dimension=dimension,
        metric=metric,
        spec=serverless_spec,
    )
0 commit comments