diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 514991df8..11a82ee1a 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -57,6 +57,7 @@ repos: "types-requests", "sqlmodel", "types-Markdown", + "types-cachetools", types-tzlocal, ] args: ["--check-untyped-defs", "--ignore-missing-imports"] diff --git a/libs/ktem/ktem/pages/chat/__init__.py b/libs/ktem/ktem/pages/chat/__init__.py index 91d89612c..fd658adb5 100644 --- a/libs/ktem/ktem/pages/chat/__init__.py +++ b/libs/ktem/ktem/pages/chat/__init__.py @@ -26,6 +26,7 @@ from ...utils import SUPPORTED_LANGUAGE_MAP, get_file_names_regex, get_urls from ...utils.commands import WEB_SEARCH_COMMAND +from ...utils.hf_papers import get_recommended_papers from ...utils.rate_limit import check_rate_limit from .chat_panel import ChatPanel from .chat_suggestion import ChatSuggestion @@ -68,6 +69,36 @@ } """ +recommended_papers_js = """ +function() { + // Get all links and attach click event + var links = document.querySelectorAll("#related-papers a"); + + function submitPaper(event) { + event.preventDefault(); + var target = event.currentTarget; + var url = target.getAttribute("href"); + console.log("URL:", url); + + let newChatButton = document.querySelector("#new-conv-button"); + newChatButton.click(); + + setTimeout(() => { + let urlInput = document.querySelector("#quick-url-demo textarea"); + // Fill the URL input + urlInput.value = url; + urlInput.dispatchEvent(new Event("input", { bubbles: true })); + urlInput.dispatchEvent(new KeyboardEvent('keypress', {'key': 'Enter'})); + }, 500 + ); + } + + for (var i = 0; i < links.length; i++) { + links[i].onclick = submitPaper; + } +} +""" + clear_bot_message_selection_js = """ function() { var bot_messages = document.querySelectorAll( @@ -268,14 +299,17 @@ def on_building_ui(self): if not KH_DEMO_MODE: self.report_issue = ReportIssue(self._app) else: + with gr.Accordion(label="Related papers", open=False): + self.related_papers = gr.Markdown(elem_id="related-papers") + self.hint_page = HintPage(self._app) with gr.Column(scale=6, elem_id="chat-area"): - self.chat_panel = ChatPanel(self._app) - if KH_DEMO_MODE: self.paper_list = PaperListPage(self._app) + self.chat_panel = ChatPanel(self._app) + with gr.Accordion( label="Chat settings", elem_id="chat-settings-expand", @@ -360,6 +394,19 @@ def _json_to_plot(self, json_dict: dict | None): return plot def on_register_events(self): + # first index paper recommendation + if KH_DEMO_MODE and len(self._indices_input) > 0: + self._indices_input[1].change( + self.get_recommendations, + inputs=[self.first_selector_choices, self._indices_input[1]], + outputs=[self.related_papers], + ).then( + fn=None, + inputs=None, + outputs=None, + js=recommended_papers_js, + ) + chat_event = ( gr.on( triggers=[ @@ -916,6 +963,17 @@ def submit_msg( + [used_command] ) + def get_recommendations(self, first_selector_choices, file_ids): + first_selector_choices_map = { + item[1]: item[0] for item in first_selector_choices + } + file_names = [first_selector_choices_map[file_id] for file_id in file_ids] + if not file_names: + return "" + + first_file_name = file_names[0].split(".")[0].replace("_", " ") + return get_recommended_papers(first_file_name) + def toggle_delete(self, conv_id): if conv_id: return gr.update(visible=False), gr.update(visible=True) diff --git a/libs/ktem/ktem/pages/chat/paper_list.py b/libs/ktem/ktem/pages/chat/paper_list.py index f198c14d9..bddf4a432 100644 --- a/libs/ktem/ktem/pages/chat/paper_list.py +++ b/libs/ktem/ktem/pages/chat/paper_list.py @@ -13,8 +13,8 @@ def __init__(self, app): def on_building_ui(self): self.papers_state = gr.State(None) with gr.Accordion( - label="Browse daily top papers", - open=False, + label="Browse popular daily papers", + open=True, ) as self.accordion: self.examples = gr.DataFrame( value=[], diff --git a/libs/ktem/ktem/reasoning/simple.py b/libs/ktem/ktem/reasoning/simple.py index eed1b7871..1a87a00ea 100644 --- a/libs/ktem/ktem/reasoning/simple.py +++ b/libs/ktem/ktem/reasoning/simple.py @@ -435,7 +435,14 @@ def get_user_settings(cls) -> dict: }, "system_prompt": { "name": "System Prompt", - "value": "This is a question answering system", + "value": dedent( + """This is a question answering system. + Organize the answer in bullet points if applicable. + When asked for paper summary, provide a brief summary of the paper + with the following sections: + Background, Hypothesis, Method, Results, Conclusion & Future Work. + """ + ), }, "qa_prompt": { "name": "QA Prompt (contains {context}, {question}, {lang})", diff --git a/libs/ktem/ktem/utils/hf_papers.py b/libs/ktem/ktem/utils/hf_papers.py index afcd51e84..f755671bd 100644 --- a/libs/ktem/ktem/utils/hf_papers.py +++ b/libs/ktem/ktem/utils/hf_papers.py @@ -1,9 +1,15 @@ from datetime import datetime, timedelta import requests +from cachetools import TTLCache, cached HF_API_URL = "https://huggingface.co/api/daily_papers" ARXIV_URL = "https://arxiv.org/abs/{paper_id}" +SEMANTIC_SCHOLAR_QUERY_URL = "https://api.semanticscholar.org/graph/v1/paper/search/match?query={paper_name}" # noqa +SEMANTIC_SCHOLAR_RECOMMEND_URL = ( + "https://api.semanticscholar.org/recommendations/v1/papers/" # noqa +) +CACHE_TIME = 60 * 60 * 6 # 6 hours # Function to parse the date string @@ -11,6 +17,69 @@ def parse_date(date_str): return datetime.strptime(date_str, "%Y-%m-%dT%H:%M:%S.%fZ") +@cached(cache=TTLCache(maxsize=500, ttl=CACHE_TIME)) +def get_recommendations_from_semantic_scholar(semantic_scholar_id: str): + try: + r = requests.post( + SEMANTIC_SCHOLAR_RECOMMEND_URL, + json={ + "positivePaperIds": [semantic_scholar_id], + }, + params={"fields": "externalIds,title,year", "limit": 14}, # type: ignore + ) + return r.json()["recommendedPapers"] + except KeyError as e: + print(e) + return [] + + +def filter_recommendations(recommendations, max_paper_count=5): + # include only arxiv papers + arxiv_paper = [ + r for r in recommendations if r["externalIds"].get("ArXiv", None) is not None + ] + if len(arxiv_paper) > max_paper_count: + arxiv_paper = arxiv_paper[:max_paper_count] + return arxiv_paper + + +def format_recommendation_into_markdown(recommendations): + comment = "(recommended by the Semantic Scholar API)\n\n" + for r in recommendations: + hub_paper_url = f"https://arxiv.org/abs/{r['externalIds']['ArXiv']}" + comment += f"* [{r['title']}]({hub_paper_url}) ({r['year']})\n" + + return comment + + +def get_paper_id_from_name(paper_name): + try: + response = requests.get( + SEMANTIC_SCHOLAR_QUERY_URL.format(paper_name=paper_name) + ) + response.raise_for_status() + items = response.json() + paper_id = items.get("data", [])[0].get("paperId") + except Exception as e: + print(e) + return None + + return paper_id + + +def get_recommended_papers(paper_name): + paper_id = get_paper_id_from_name(paper_name) + recommended_content = "" + if paper_id is None: + return recommended_content + + recommended_papers = get_recommendations_from_semantic_scholar(paper_id) + filtered_recommendations = filter_recommendations(recommended_papers) + + recommended_content = format_recommendation_into_markdown(filtered_recommendations) + return recommended_content + + def fetch_papers(top_n=5): try: response = requests.get(f"{HF_API_URL}?limit=100")