From aed8c4ec32aa0e64f146389877d3f707d15de9e4 Mon Sep 17 00:00:00 2001 From: Miguel Grinberg Date: Fri, 8 Mar 2024 23:39:14 +0000 Subject: [PATCH] add Cohere support to the chatbot example (#199) --- example-apps/chatbot-rag-app/README.md | 10 +++ example-apps/chatbot-rag-app/api/chat.py | 5 +- .../chatbot-rag-app/api/llm_integrations.py | 11 ++- example-apps/chatbot-rag-app/env.example | 5 ++ example-apps/chatbot-rag-app/requirements.in | 3 + example-apps/chatbot-rag-app/requirements.txt | 90 ++++++++++++++++--- 6 files changed, 112 insertions(+), 12 deletions(-) diff --git a/example-apps/chatbot-rag-app/README.md b/example-apps/chatbot-rag-app/README.md index a3a1e4da..c15b8867 100644 --- a/example-apps/chatbot-rag-app/README.md +++ b/example-apps/chatbot-rag-app/README.md @@ -128,6 +128,16 @@ export MISTRAL_API_ENDPOINT=... # optional export MISTRAL_MODEL=... # optional ``` +### Cohere + +To use Cohere you need to set the following environment variables: + +``` +export LLM_TYPE=cohere +export COHERE_API_KEY=... +export COHERE_MODEL=... # optional +``` + ## Running the App Once you have indexed data into the Elasticsearch index, there are two ways to run the app: via Docker or locally. Docker is advised for testing & production use. Locally is advised for development. diff --git a/example-apps/chatbot-rag-app/api/chat.py b/example-apps/chatbot-rag-app/api/chat.py index 929dae5a..bc96020d 100644 --- a/example-apps/chatbot-rag-app/api/chat.py +++ b/example-apps/chatbot-rag-app/api/chat.py @@ -64,7 +64,10 @@ def ask_question(question, session_id): answer = "" for chunk in get_llm().stream(qa_prompt): - yield f"data: {chunk.content}\n\n" + content = chunk.content.replace( + "\n", " " + ) # the stream can get messed up with newlines + yield f"data: {content}\n\n" answer += chunk.content yield f"data: {DONE_TAG}\n\n" diff --git a/example-apps/chatbot-rag-app/api/llm_integrations.py b/example-apps/chatbot-rag-app/api/llm_integrations.py index d6f4bb46..b73e2cad 100644 --- a/example-apps/chatbot-rag-app/api/llm_integrations.py +++ b/example-apps/chatbot-rag-app/api/llm_integrations.py @@ -3,8 +3,8 @@ ChatVertexAI, AzureChatOpenAI, BedrockChat, + ChatCohere, ) -from langchain_core.messages import HumanMessage from langchain_mistralai.chat_models import ChatMistralAI import os import vertexai @@ -76,12 +76,21 @@ def init_mistral_chat(temperature): return ChatMistralAI(**kwargs) +def init_cohere_chat(temperature): + COHERE_API_KEY = os.getenv("COHERE_API_KEY") + COHERE_MODEL = os.getenv("COHERE_MODEL") + return ChatCohere( + cohere_api_key=COHERE_API_KEY, model=COHERE_MODEL, temperature=temperature + ) + + MAP_LLM_TYPE_TO_CHAT_MODEL = { "azure": init_azure_chat, "bedrock": init_bedrock, "openai": init_openai_chat, "vertex": init_vertex_chat, "mistral": init_mistral_chat, + "cohere": init_cohere_chat, } diff --git a/example-apps/chatbot-rag-app/env.example b/example-apps/chatbot-rag-app/env.example index cb8081c1..2dc651cc 100644 --- a/example-apps/chatbot-rag-app/env.example +++ b/example-apps/chatbot-rag-app/env.example @@ -40,3 +40,8 @@ ES_INDEX_CHAT_HISTORY=workplace-app-docs-chat-history # MISTRAL_API_KEY= # MISTRAL_API_ENDPOINT= # MISTRAL_MODEL= + +# Uncomment and complete if you want to use Cohere +# LLM_TYPE=cohere +# COHERE_API_KEY= +# COHERE_MODEL= diff --git a/example-apps/chatbot-rag-app/requirements.in b/example-apps/chatbot-rag-app/requirements.in index 840c43ae..3c532644 100644 --- a/example-apps/chatbot-rag-app/requirements.in +++ b/example-apps/chatbot-rag-app/requirements.in @@ -23,6 +23,9 @@ boto3 # Mistral dependencies langchain-mistralai +# Cohere dependencies +cohere + # TBD if these are still needed exceptiongroup importlib-metadata diff --git a/example-apps/chatbot-rag-app/requirements.txt b/example-apps/chatbot-rag-app/requirements.txt index d5129714..1dec6beb 100644 --- a/example-apps/chatbot-rag-app/requirements.txt +++ b/example-apps/chatbot-rag-app/requirements.txt @@ -6,7 +6,9 @@ # aiohttp==3.8.5 # via + # cohere # langchain + # langchain-community # openai aiosignal==1.3.1 # via aiohttp @@ -14,12 +16,14 @@ annotated-types==0.5.0 # via pydantic anyio==3.7.1 # via - # langchain + # httpx # langchain-core async-timeout==4.0.3 # via aiohttp attrs==23.1.0 # via aiohttp +backoff==2.2.1 + # via cohere blinker==1.6.2 # via flask boto3==1.28.61 @@ -35,6 +39,8 @@ cachetools==5.3.1 certifi==2023.7.22 # via # elastic-transport + # httpcore + # httpx # requests charset-normalizer==3.2.0 # via @@ -44,8 +50,12 @@ click==8.1.7 # via # flask # pip-tools +cohere==4.52 + # via -r requirements.in dataclasses-json==0.5.14 - # via langchain + # via + # langchain + # langchain-community elastic-transport==8.4.0 # via elasticsearch elasticsearch==8.12.1 @@ -54,6 +64,10 @@ elasticsearch==8.12.1 # langchain-elasticsearch exceptiongroup==1.2.0 # via -r requirements.in +fastavro==1.9.4 + # via cohere +filelock==3.13.1 + # via huggingface-hub flask==2.3.3 # via # -r requirements.in @@ -64,6 +78,8 @@ frozenlist==1.4.0 # via # aiohttp # aiosignal +fsspec==2024.2.0 + # via huggingface-hub google-api-core[grpc]==2.14.0 # via # google-cloud-aiplatform @@ -112,13 +128,24 @@ grpcio-status==1.59.3 # via # -r requirements.in # google-api-core +h11==0.14.0 + # via httpcore +httpcore==1.0.4 + # via httpx +httpx==0.25.2 + # via mistralai +huggingface-hub==0.21.4 + # via tokenizers idna==3.4 # via # anyio + # httpx # requests # yarl importlib-metadata==6.8.0 - # via -r requirements.in + # via + # -r requirements.in + # cohere itsdangerous==2.1.2 # via flask jinja2==3.1.2 @@ -135,13 +162,22 @@ jsonpointer==2.4 # via jsonpatch langchain==0.1.9 # via -r requirements.in -langchain-core==0.1.23 - # via langchain-elasticsearch +langchain-community==0.0.27 + # via langchain +langchain-core==0.1.30 + # via + # langchain + # langchain-community + # langchain-elasticsearch + # langchain-mistralai langchain-elasticsearch==0.1.0 # via -r requirements.in +langchain-mistralai==0.0.5 + # via -r requirements.in langsmith==0.1.10 # via # langchain + # langchain-community # langchain-core markupsafe==2.1.3 # via @@ -149,6 +185,8 @@ markupsafe==2.1.3 # werkzeug marshmallow==3.20.1 # via dataclasses-json +mistralai==0.1.3 + # via langchain-mistralai multidict==6.0.4 # via # aiohttp @@ -160,18 +198,28 @@ numexpr==2.8.5 numpy==1.25.2 # via # langchain + # langchain-community # langchain-elasticsearch # numexpr + # pandas + # pyarrow # shapely openai==0.27.9 # via -r requirements.in +orjson==3.9.15 + # via + # langsmith + # mistralai packaging==23.2 # via # build # google-cloud-aiplatform # google-cloud-bigquery + # huggingface-hub # langchain-core # marshmallow +pandas==2.2.1 + # via mistralai pip-tools==7.3.0 # via -r requirements.in proto-plus==1.22.3 @@ -189,6 +237,8 @@ protobuf==4.25.1 # grpc-google-iam-v1 # grpcio-status # proto-plus +pyarrow==15.0.1 + # via mistralai pyasn1==0.5.0 # via # pyasn1-modules @@ -200,6 +250,7 @@ pydantic==2.5.2 # langchain # langchain-core # langsmith + # mistralai pydantic-core==2.14.5 # via pydantic pyproject-hooks==1.0.0 @@ -208,20 +259,28 @@ python-dateutil==2.8.2 # via # botocore # google-cloud-bigquery + # pandas python-dotenv==1.0.0 # via -r requirements.in +pytz==2024.1 + # via pandas pyyaml==6.0.1 # via + # huggingface-hub # langchain + # langchain-community # langchain-core regex==2023.10.3 # via tiktoken requests==2.31.0 # via + # cohere # google-api-core # google-cloud-bigquery # google-cloud-storage + # huggingface-hub # langchain + # langchain-community # langchain-core # langsmith # openai @@ -235,28 +294,41 @@ shapely==2.0.2 six==1.16.0 # via python-dateutil sniffio==1.3.0 - # via anyio + # via + # anyio + # httpx sqlalchemy==2.0.20 - # via langchain + # via + # langchain + # langchain-community tenacity==8.2.3 # via # langchain + # langchain-community # langchain-core tiktoken==0.5.1 # via -r requirements.in +tokenizers==0.15.2 + # via langchain-mistralai tqdm==4.66.1 - # via openai + # via + # huggingface-hub + # openai typing-extensions==4.7.1 # via + # huggingface-hub # pydantic # pydantic-core # sqlalchemy # typing-inspect typing-inspect==0.9.0 # via dataclasses-json +tzdata==2024.1 + # via pandas urllib3==1.26.16 # via # botocore + # cohere # elastic-transport # requests werkzeug==2.3.7 @@ -268,8 +340,6 @@ yarl==1.9.2 zipp==3.17.0 # via importlib-metadata -langchain-mistralai==0.0.5 - # via -r requirements.in # The following packages are considered to be unsafe in a requirements file: # pip # setuptools