
Commit 72c92b1

VertexAI Google Cloud Palm2 Support (#226)
* feat(bard): added
* docs(readme): update
* chore(print): removed
1 parent f432766 commit 72c92b1

File tree

11 files changed, +36 -16 lines changed


.backend_env.example

+3 -1

@@ -3,4 +3,6 @@ SUPABASE_SERVICE_KEY=eyXXXXX
 OPENAI_API_KEY=sk-XXXXXX
 ANTHROPIC_API_KEY=XXXXXX
 JWT_SECRET_KEY=Found in Supabase settings in the API tab
-AUTHENTICATE="true"
+AUTHENTICATE=true
+GOOGLE_APPLICATION_CREDENTIALS=/code/application_default_credentials.json
+GOOGLE_CLOUD_PROJECT=XXXXX to be changed with your GCP id
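The two new variables follow Google's Application Default Credentials (ADC) convention rather than a plain API key. A minimal sketch of how they get resolved, assuming `google-auth` is available (it is pulled in by `google_cloud_aiplatform`); the values shown here are placeholders, not real ones:

import os

import google.auth

# Placeholders mirroring .backend_env.example; substitute your own values.
os.environ.setdefault("GOOGLE_APPLICATION_CREDENTIALS", "/code/application_default_credentials.json")
os.environ.setdefault("GOOGLE_CLOUD_PROJECT", "my-gcp-project-id")  # hypothetical project id

# google.auth.default() reads GOOGLE_APPLICATION_CREDENTIALS for the key file
# and GOOGLE_CLOUD_PROJECT for the project id.
credentials, project_id = google.auth.default()
print("Resolved GCP project:", project_id)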

.gitignore

+1

@@ -50,3 +50,4 @@ streamlit-demo/.streamlit/secrets.toml
 .frontend_env
 backend/pandoc-*
 **/.pandoc-*
+backend/application_default_credentials.json

README.md

+7

@@ -88,8 +88,15 @@ cp .frontend_env.example frontend/.env
 - **Step 3**: Update the `backend/.env` and `frontend/.env` file
 
 > _Your `supabase_service_key` can be found in your Supabase dashboard under Project Settings -> API. Use the `anon` `public` key found in the `Project API keys` section._
+
+
 > _Your `JWT_SECRET_KEY` can be found in your Supabase settings under Project Settings -> JWT Settings -> JWT Secret_
 
+> _To activate Vertex AI with PaLM from GCP, follow the instructions [here](https://python.langchain.com/en/latest/modules/models/llms/integrations/google_vertex_ai_palm.html) and update `backend/.env`. This is an advanced feature; please be familiar with GCP before trying to use it._
+
+- [ ] Change variables in `backend/.env`
+- [ ] Change variables in `frontend/.env`
+
 - **Step 4**: Run the following migration scripts on the Supabase database via the web interface (SQL Editor -> `New query`)
 
 [Migration Script 1](scripts/supabase_new_store_documents.sql)
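The new README note links to LangChain's Vertex AI PaLM instructions. As a quick smoke test that `backend/.env` and GCP are wired up (a sketch only, assuming valid credentials and project), the `VertexAI` wrapper that `backend/llm/qa.py` now imports can be called directly:

from langchain.llms import VertexAI

# Authenticates via Application Default Credentials and GOOGLE_CLOUD_PROJECT;
# by default this wrapper targets Google's PaLM 2 text model on Vertex AI.
llm = VertexAI()
print(llm("Reply with one short sentence if you can hear me."))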

backend/Dockerfile

+5 -2

@@ -1,4 +1,7 @@
-FROM python:3.11
+FROM python:3.11-buster
+
+# Install GEOS library
+RUN apt-get update && apt-get install -y libgeos-dev
 
 WORKDIR /code
 
@@ -8,4 +11,4 @@ RUN pip install --no-cache-dir --upgrade -r /code/requirements.txt --timeout 100
 
 COPY . /code/
 
-CMD ["uvicorn", "api:app", "--reload", "--host", "0.0.0.0", "--port", "5050"]
+CMD ["uvicorn", "api:app", "--reload", "--host", "0.0.0.0", "--port", "5050"]

backend/api.py

+5 -4

@@ -1,12 +1,14 @@
 import os
 import shutil
+import time
 from tempfile import SpooledTemporaryFile
 from typing import Annotated, List, Tuple
 
 import pypandoc
 from auth_bearer import JWTBearer
 from crawl.crawler import CrawlWebsite
-from fastapi import Depends, FastAPI, File, Header, HTTPException, UploadFile
+from fastapi import (Depends, FastAPI, File, Header, HTTPException, Request,
+                     UploadFile)
 from fastapi.middleware.cors import CORSMiddleware
 from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
 from llm.qa import get_qa_llm
@@ -25,13 +27,11 @@
 from parsers.powerpoint import process_powerpoint
 from parsers.txt import process_txt
 from pydantic import BaseModel
-from utils import ChatMessage, CommonsDep, similarity_search
-
 from supabase import Client
+from utils import ChatMessage, CommonsDep, similarity_search
 
 logger = get_logger(__name__)
 
-
 app = FastAPI()
 
 origins = [
@@ -49,6 +49,7 @@
 )
 
 
+
 @app.on_event("startup")
 async def startup_event():
     pypandoc.download_pandoc()

backend/auth_handler.py

-1

@@ -23,7 +23,6 @@ def decode_access_token(token: str):
         payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM], options={"verify_aud": False})
         return payload
     except JWTError as e:
-        print(f"JWTError: {str(e)}")
         return None
 
 def get_user_email_from_token(token: str):

backend/llm/qa.py

+5 -1

@@ -2,10 +2,11 @@
 from typing import Any, List
 
 from langchain.chains import ConversationalRetrievalChain
-from langchain.chat_models import ChatOpenAI
+from langchain.chat_models import ChatOpenAI, ChatVertexAI
 from langchain.chat_models.anthropic import ChatAnthropic
 from langchain.docstore.document import Document
 from langchain.embeddings.openai import OpenAIEmbeddings
+from langchain.llms import VertexAI
 from langchain.memory import ConversationBufferMemory
 from langchain.vectorstores import SupabaseVectorStore
 from llm import LANGUAGE_PROMPT
@@ -94,6 +95,9 @@ def get_qa_llm(chat_message: ChatMessage, user_id: str):
             temperature=chat_message.temperature, max_tokens=chat_message.max_tokens),
             vector_store.as_retriever(), memory=memory, verbose=True,
             max_tokens_limit=1024)
+    elif chat_message.model.startswith("vertex"):
+        qa = ConversationalRetrievalChain.from_llm(
+            ChatVertexAI(), vector_store.as_retriever(), memory=memory, verbose=False, max_tokens_limit=1024)
     elif anthropic_api_key and chat_message.model.startswith("claude"):
         qa = ConversationalRetrievalChain.from_llm(
             ChatAnthropic(
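The new branch routes any model name starting with `vertex` to `ChatVertexAI`. Unlike the OpenAI and Anthropic branches it passes no API key: authentication comes from Application Default Credentials. A minimal sketch of the chat model on its own, outside the retrieval chain, assuming the GCP setup above:

from langchain.chat_models import ChatVertexAI
from langchain.schema import HumanMessage

# No key argument needed; ChatVertexAI authenticates via ADC and uses
# the PaLM 2 chat model on Vertex AI.
chat = ChatVertexAI()
response = chat([HumanMessage(content="Summarize what a retrieval chain does.")])
print(response.content)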

backend/parsers/audio.py

+5 -5

@@ -1,15 +1,16 @@
 import os
-from tempfile import NamedTemporaryFile
 import tempfile
-from io import BytesIO
 import time
+from io import BytesIO
+from tempfile import NamedTemporaryFile
+
 import openai
+from fastapi import UploadFile
 from langchain.document_loaders import TextLoader
 from langchain.embeddings.openai import OpenAIEmbeddings
+from langchain.schema import Document
 from langchain.text_splitter import RecursiveCharacterTextSplitter
 from utils import compute_sha1_from_content, documents_vector_store
-from langchain.schema import Document
-from fastapi import UploadFile
 
 
 # # Create a function to transcribe audio using Whisper
 # def _transcribe_audio(api_key, audio_file, stats_db):
@@ -52,7 +53,6 @@ async def process_audio(upload_file: UploadFile, stats_db):
 
     file_sha = compute_sha1_from_content(transcript.text.encode("utf-8"))
     file_size = len(transcript.text.encode("utf-8"))
-    print(file_size)
 
     # Load chunk size and overlap from sidebar
     chunk_size = 500

backend/requirements.txt

+3 -2

@@ -1,4 +1,4 @@
-langchain==0.0.166
+langchain==0.0.187
 Markdown==3.4.3
 openai==0.27.6
 pdf2image==1.16.3
@@ -15,4 +15,5 @@ uvicorn==0.22.0
 pypandoc==1.11
 docx2txt==0.8
 guidance==0.0.53
-python-jose==3.3.0
+python-jose==3.3.0
+google_cloud_aiplatform==1.25.0
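The `langchain` bump to 0.0.187 is what provides the `ChatVertexAI` and `VertexAI` wrappers used in `backend/llm/qa.py`, and `google_cloud_aiplatform` is the SDK they call into. A small sanity-check sketch to run inside the backend environment:

import langchain
from google.cloud import aiplatform

# The Vertex AI integrations imported by backend/llm/qa.py
from langchain.chat_models import ChatVertexAI
from langchain.llms import VertexAI

print(langchain.__version__)   # expected: 0.0.187
print(aiplatform.__version__)  # expected: 1.25.0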

docker-compose.yml

+1

@@ -25,5 +25,6 @@ services:
     restart: always
     volumes:
       - ./backend/:/code/
+      - ~/.config/gcloud:/root/.config/gcloud
    ports:
      - 5050:5050
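The added volume mounts the host's gcloud configuration into the container so Application Default Credentials can resolve without baking a key file into the image. A small sketch to run inside the backend container to confirm that one of the expected credential locations is present (paths taken from this compose file and `.backend_env.example`):

import os

candidates = [
    os.environ.get("GOOGLE_APPLICATION_CREDENTIALS", ""),         # e.g. /code/application_default_credentials.json
    "/root/.config/gcloud/application_default_credentials.json",  # provided by the ~/.config/gcloud mount
]
for path in candidates:
    if path and os.path.exists(path):
        print("Found credentials at:", path)
        break
else:
    print("No Application Default Credentials found in the container.")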

frontend/app/chat/page.tsx

+1

@@ -159,6 +159,7 @@ export default function ChatPage() {
             >
               <option value="gpt-3.5-turbo">gpt-3.5-turbo</option>
               <option value="gpt-4">gpt-4</option>
+              <option value="vertexai">vertexai</option>
             </select>
           </fieldset>
           <fieldset className="w-full flex">
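The option value `vertexai` is what the backend dispatch in `backend/llm/qa.py` keys on; the only contract between frontend and backend is the prefix check:

# Mirrors the branch added in backend/llm/qa.py: any model value starting with
# "vertex" is routed to ChatVertexAI.
assert "vertexai".startswith("vertex")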
