Skip to content

Commit 7906f03

Browse files
authored
[DH-5595] Implement Pinecone package version to support serverless envs (#434)
1 parent 5160e8d commit 7906f03

File tree

3 files changed

+22
-16
lines changed

3 files changed

+22
-16
lines changed

.gitignore

+1
Original file line numberDiff line numberDiff line change
@@ -125,6 +125,7 @@ celerybeat.pid
125125

126126
# Environments
127127
.env
128+
.env.local
128129
.venv
129130
env/
130131
venv/

dataherald/vector_store/pinecone.py

+20-15
Original file line numberDiff line numberDiff line change
@@ -16,15 +16,15 @@
1616

1717

1818
class Pinecone(VectorStore):
19+
pinecone: None
20+
1921
def __init__(self, system: System):
2022
super().__init__(system)
2123
api_key = os.environ.get("PINECONE_API_KEY")
22-
environment = os.environ.get("PINECONE_ENVIRONMENT")
2324
if api_key is None:
2425
raise ValueError("PINECONE_API_KEY environment variable not set")
25-
if environment is None:
26-
raise ValueError("PINECONE_ENVIRONMENT environment variable not set")
27-
pinecone.init(api_key=api_key, environment=environment)
26+
27+
self.pinecone = pinecone.Pinecone(api_key=api_key)
2828

2929
@override
3030
def query(
@@ -34,7 +34,7 @@ def query(
3434
collection: str,
3535
num_results: int,
3636
) -> list:
37-
index = pinecone.Index(collection)
37+
index = self.pinecone.Index(name=collection)
3838
db_connection_repository = DatabaseConnectionRepository(
3939
self.system.instance(DB)
4040
)
@@ -44,18 +44,18 @@ def query(
4444
)
4545
xq = embedding.embed_query(query_texts[0])
4646
query_response = index.query(
47-
queries=[xq],
47+
vector=[xq],
4848
filter={
4949
"db_connection_id": {"$eq": db_connection_id},
5050
},
5151
top_k=num_results,
5252
include_metadata=True,
5353
)
54-
return query_response.to_dict()["results"][0]["matches"]
54+
return query_response.to_dict()["matches"]
5555

5656
@override
5757
def add_records(self, golden_sqls: List[GoldenSQL], collection: str):
58-
if collection not in pinecone.list_indexes():
58+
if collection not in self.pinecone.list_indexes().names():
5959
self.create_collection(collection)
6060
db_connection_repository = DatabaseConnectionRepository(
6161
self.system.instance(DB)
@@ -66,7 +66,7 @@ def add_records(self, golden_sqls: List[GoldenSQL], collection: str):
6666
embedding = OpenAIEmbeddings(
6767
openai_api_key=database_connection.decrypt_api_key(), model=EMBEDDING_MODEL
6868
)
69-
index = pinecone.Index(collection)
69+
index = self.pinecone.Index(name=collection)
7070
batch_limit = 100
7171
for limit_index in range(0, len(golden_sqls), batch_limit):
7272
golden_sql_batch = golden_sqls[limit_index : limit_index + batch_limit]
@@ -101,7 +101,7 @@ def add_record(
101101
metadata: Any,
102102
ids: List,
103103
):
104-
if collection not in pinecone.list_indexes():
104+
if collection not in self.pinecone.list_indexes().names():
105105
self.create_collection(collection)
106106
db_connection_repository = DatabaseConnectionRepository(
107107
self.system.instance(DB)
@@ -110,22 +110,27 @@ def add_record(
110110
embedding = OpenAIEmbeddings(
111111
openai_api_key=database_connection.decrypt_api_key(), model=EMBEDDING_MODEL
112112
)
113-
index = pinecone.Index(collection)
113+
index = self.pinecone.Index(name=collection)
114114
embeds = embedding.embed_documents([documents])
115115
record = [(ids[0], embeds, metadata[0])]
116116
index.upsert(vectors=record)
117117

118118
@override
119119
def delete_record(self, collection: str, id: str):
120-
if collection not in pinecone.list_indexes():
120+
if collection not in self.pinecone.list_indexes().names():
121121
self.create_collection(collection)
122-
index = pinecone.Index(collection)
122+
index = self.pinecone.Index(name=collection)
123123
index.delete(ids=[id])
124124

125125
@override
126126
def delete_collection(self, collection: str):
127-
return pinecone.delete_index(collection)
127+
return self.pinecone.delete_index(name=collection)
128128

129129
@override
130130
def create_collection(self, collection: str):
131-
pinecone.create_index(name=collection, dimension=1536, metric="cosine")
131+
self.pinecone.create_index(
132+
name=collection,
133+
dimension=1536,
134+
metric="cosine",
135+
spec=pinecone.ServerlessSpec(cloud="aws", region="us-west-2"),
136+
)

requirements.txt

+1-1
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ sqlalchemy-databricks==0.2.0
2929
sqlalchemy-bigquery==1.6.1
3030
chromadb==0.4.12
3131
pytest-dotenv==0.5.2
32-
pinecone-client==2.2.2
32+
pinecone-client==3.1.0
3333
cryptography==40.0.2
3434
sphinx==6.2.1
3535
sphinx-book-theme==1.0.1

0 commit comments

Comments
 (0)