Skip to content

Commit

Permalink
feat: add paper recommendation
Browse files Browse the repository at this point in the history
  • Loading branch information
taprosoft committed Jan 17, 2025
1 parent f4153c5 commit ad914bb
Show file tree
Hide file tree
Showing 5 changed files with 140 additions and 5 deletions.
1 change: 1 addition & 0 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@ repos:
"types-requests",
"sqlmodel",
"types-Markdown",
"types-cachetools",
types-tzlocal,
]
args: ["--check-untyped-defs", "--ignore-missing-imports"]
Expand Down
62 changes: 60 additions & 2 deletions libs/ktem/ktem/pages/chat/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@

from ...utils import SUPPORTED_LANGUAGE_MAP, get_file_names_regex, get_urls
from ...utils.commands import WEB_SEARCH_COMMAND
from ...utils.hf_papers import get_recommended_papers
from ...utils.rate_limit import check_rate_limit
from .chat_panel import ChatPanel
from .chat_suggestion import ChatSuggestion
Expand Down Expand Up @@ -68,6 +69,36 @@
}
"""

# JavaScript snippet passed as the `js=` callback after the related-papers
# Markdown is refreshed. It attaches a click handler to every link under
# `#related-papers`: the handler cancels the normal navigation, clicks
# `#new-conv-button`, and after a 500 ms delay writes the link's href into the
# `#quick-url-demo textarea` input and dispatches `input` and Enter `keypress`
# events on it.
recommended_papers_js = """
function() {
// Get all links and attach click event
var links = document.querySelectorAll("#related-papers a");
function submitPaper(event) {
event.preventDefault();
var target = event.currentTarget;
var url = target.getAttribute("href");
console.log("URL:", url);
let newChatButton = document.querySelector("#new-conv-button");
newChatButton.click();
setTimeout(() => {
let urlInput = document.querySelector("#quick-url-demo textarea");
// Fill the URL input
urlInput.value = url;
urlInput.dispatchEvent(new Event("input", { bubbles: true }));
urlInput.dispatchEvent(new KeyboardEvent('keypress', {'key': 'Enter'}));
}, 500
);
}
for (var i = 0; i < links.length; i++) {
links[i].onclick = submitPaper;
}
}
"""

clear_bot_message_selection_js = """
function() {
var bot_messages = document.querySelectorAll(
Expand Down Expand Up @@ -268,14 +299,17 @@ def on_building_ui(self):
if not KH_DEMO_MODE:
self.report_issue = ReportIssue(self._app)
else:
with gr.Accordion(label="Related papers", open=False):
self.related_papers = gr.Markdown(elem_id="related-papers")

self.hint_page = HintPage(self._app)

with gr.Column(scale=6, elem_id="chat-area"):
self.chat_panel = ChatPanel(self._app)

if KH_DEMO_MODE:
self.paper_list = PaperListPage(self._app)

self.chat_panel = ChatPanel(self._app)

with gr.Accordion(
label="Chat settings",
elem_id="chat-settings-expand",
Expand Down Expand Up @@ -360,6 +394,19 @@ def _json_to_plot(self, json_dict: dict | None):
return plot

def on_register_events(self):
# first index paper recommendation
if KH_DEMO_MODE and len(self._indices_input) > 0:
self._indices_input[1].change(
self.get_recommendations,
inputs=[self.first_selector_choices, self._indices_input[1]],
outputs=[self.related_papers],
).then(
fn=None,
inputs=None,
outputs=None,
js=recommended_papers_js,
)

chat_event = (
gr.on(
triggers=[
Expand Down Expand Up @@ -916,6 +963,17 @@ def submit_msg(
+ [used_command]
)

def get_recommendations(self, first_selector_choices, file_ids):
    """Return Markdown listing papers related to the first selected file.

    Args:
        first_selector_choices: Sequence of items where ``item[0]`` is the
            file name and ``item[1]`` is the file id. ``None`` is treated
            as an empty sequence.
        file_ids: Ids of the currently selected files. ``None`` is treated
            as an empty selection.

    Returns:
        The result of ``get_recommended_papers`` for the first selected file
        whose id is present in ``first_selector_choices``, or ``""`` when no
        such file exists.
    """
    name_by_id = {item[1]: item[0] for item in first_selector_choices or []}

    # Ids that are no longer among the selector choices are skipped instead
    # of raising KeyError inside the UI event handler.
    first_file_name = next(
        (name_by_id[file_id] for file_id in file_ids or [] if file_id in name_by_id),
        None,
    )
    if not first_file_name:
        return ""

    # Strip only the trailing extension so that titles which themselves
    # contain dots (e.g. "GPT-4.5 report.pdf") are not cut at the first dot.
    paper_title = first_file_name.rsplit(".", 1)[0].replace("_", " ")
    return get_recommended_papers(paper_title)

def toggle_delete(self, conv_id):
if conv_id:
return gr.update(visible=False), gr.update(visible=True)
Expand Down
4 changes: 2 additions & 2 deletions libs/ktem/ktem/pages/chat/paper_list.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,8 @@ def __init__(self, app):
def on_building_ui(self):
self.papers_state = gr.State(None)
with gr.Accordion(
label="Browse daily top papers",
open=False,
label="Browse popular daily papers",
open=True,
) as self.accordion:
self.examples = gr.DataFrame(
value=[],
Expand Down
9 changes: 8 additions & 1 deletion libs/ktem/ktem/reasoning/simple.py
Original file line number Diff line number Diff line change
Expand Up @@ -435,7 +435,14 @@ def get_user_settings(cls) -> dict:
},
"system_prompt": {
"name": "System Prompt",
"value": "This is a question answering system",
"value": dedent(
"""This is a question answering system.
Organize the answer in bullet points if applicable.
When asked for paper summary, provide a brief summary of the paper
with the following sections:
Background, Hypothesis, Method, Results, Conclusion & Future Work.
"""
),
},
"qa_prompt": {
"name": "QA Prompt (contains {context}, {question}, {lang})",
Expand Down
69 changes: 69 additions & 0 deletions libs/ktem/ktem/utils/hf_papers.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,85 @@
from datetime import datetime, timedelta

import requests
from cachetools import TTLCache, cached

# Hugging Face endpoint that lists daily papers.
HF_API_URL = "https://huggingface.co/api/daily_papers"
# Template for an arXiv abstract page; fill in `paper_id`.
ARXIV_URL = "https://arxiv.org/abs/{paper_id}"
# Semantic Scholar title-match search; `paper_name` is substituted into the
# `query` parameter by get_paper_id_from_name.
SEMANTIC_SCHOLAR_QUERY_URL = "https://api.semanticscholar.org/graph/v1/paper/search/match?query={paper_name}" # noqa
# Semantic Scholar recommendations endpoint (POSTed to with positive paper ids).
SEMANTIC_SCHOLAR_RECOMMEND_URL = (
"https://api.semanticscholar.org/recommendations/v1/papers/" # noqa
)
# TTL, in seconds, of the recommendation cache.
CACHE_TIME = 60 * 60 * 6 # 6 hours


# Function to parse the date string
def parse_date(date_str):
    """Convert a ``YYYY-MM-DDTHH:MM:SS.ffffffZ`` string into a naive datetime.

    Raises:
        ValueError: If ``date_str`` does not match that exact layout.
    """
    timestamp_format = "%Y-%m-%dT%H:%M:%S.%fZ"
    parsed = datetime.strptime(date_str, timestamp_format)
    return parsed


@cached(cache=TTLCache(maxsize=500, ttl=CACHE_TIME))
def _fetch_semantic_scholar_recommendations(semantic_scholar_id: str):
    """Query the Semantic Scholar recommendation endpoint for one paper id.

    Only successful responses are cached: any exception raised here
    propagates to the caller and leaves the cache untouched, so a transient
    failure is retried on the next call instead of being remembered for
    ``CACHE_TIME`` seconds.

    Raises:
        requests.RequestException: On network errors, timeouts or a non-2xx
            status.
        KeyError: If the payload has no ``recommendedPapers`` entry.
        ValueError: If the response body is not valid JSON.
    """
    response = requests.post(
        SEMANTIC_SCHOLAR_RECOMMEND_URL,
        json={
            "positivePaperIds": [semantic_scholar_id],
        },
        params={"fields": "externalIds,title,year", "limit": 14},  # type: ignore
        # Without a timeout requests can block forever on a stalled server.
        timeout=10,
    )
    response.raise_for_status()
    return response.json()["recommendedPapers"]


def get_recommendations_from_semantic_scholar(semantic_scholar_id: str):
    """Return the papers Semantic Scholar recommends for the given paper id.

    Args:
        semantic_scholar_id: Semantic Scholar paper id used as the single
            positive example.

    Returns:
        The ``recommendedPapers`` list from the API response, or ``[]`` when
        the request fails or the response has an unexpected shape. Failures
        are printed and are not cached.
    """
    try:
        return _fetch_semantic_scholar_recommendations(semantic_scholar_id)
    except (requests.RequestException, KeyError, TypeError, ValueError) as e:
        print(e)
        return []


def filter_recommendations(recommendations, max_paper_count=5):
    """Keep only recommendations that have an arXiv id, capped in number.

    Args:
        recommendations: Iterable of paper dicts; each may carry an
            ``externalIds`` mapping with an ``ArXiv`` entry.
        max_paper_count: Maximum number of papers to return.

    Returns:
        A new list with at most ``max_paper_count`` papers, in their original
        order, whose ``externalIds["ArXiv"]`` is present and not ``None``.
    """
    arxiv_papers = [
        paper
        for paper in recommendations
        # `externalIds` may be absent or null; treat both as "no arXiv id"
        # rather than raising.
        if (paper.get("externalIds") or {}).get("ArXiv") is not None
    ]
    # Slicing already handles lists shorter than the cap.
    return arxiv_papers[:max_paper_count]


def format_recommendation_into_markdown(recommendations):
    """Render recommended papers as a Markdown bullet list of arXiv links.

    Args:
        recommendations: Iterable of paper dicts, each with a ``title`` and
            an ``externalIds["ArXiv"]`` id, and optionally a ``year``.

    Returns:
        An attribution line followed by one ``* [title](url) (year)`` bullet
        per paper. The ``(year)`` suffix is left out when the year is missing
        or ``None``, so the text never shows a literal "(None)".
    """
    header = "(recommended by the Semantic Scholar API)\n\n"
    bullets = []
    for paper in recommendations:
        paper_url = f"https://arxiv.org/abs/{paper['externalIds']['ArXiv']}"
        year = paper.get("year")
        year_suffix = f" ({year})" if year is not None else ""
        bullets.append(f"* [{paper['title']}]({paper_url}){year_suffix}\n")

    return header + "".join(bullets)


def get_paper_id_from_name(paper_name):
    """Look up the Semantic Scholar paper id that best matches a title.

    Args:
        paper_name: Paper title (or an approximation of it) to search for.

    Returns:
        The ``paperId`` of the first match, or ``None`` when the name is
        blank, the request fails, or the response contains no match.
    """
    from urllib.parse import quote

    # A blank title cannot match anything; skip the network round-trip.
    if not paper_name or not str(paper_name).strip():
        return None

    try:
        response = requests.get(
            # Percent-encode the title so characters such as "&", "#" or "?"
            # stay inside the `query` parameter instead of altering the URL.
            SEMANTIC_SCHOLAR_QUERY_URL.format(
                paper_name=quote(str(paper_name), safe="")
            ),
            # Without a timeout requests can block forever on a stalled server.
            timeout=10,
        )
        response.raise_for_status()
        payload = response.json()
    except (requests.RequestException, ValueError) as e:
        print(e)
        return None

    matches = payload.get("data") if isinstance(payload, dict) else None
    if not isinstance(matches, list) or not matches:
        return None

    best_match = matches[0]
    return best_match.get("paperId") if isinstance(best_match, dict) else None


def get_recommended_papers(paper_name):
    """Build the Markdown list of recommended papers for a paper title.

    The title is resolved to a Semantic Scholar paper id, recommendations
    for that id are fetched and filtered, and the result is formatted as
    Markdown. An empty string is returned when no paper id can be resolved.
    """
    semantic_scholar_id = get_paper_id_from_name(paper_name)
    if semantic_scholar_id is None:
        return ""

    candidates = get_recommendations_from_semantic_scholar(semantic_scholar_id)
    shortlisted = filter_recommendations(candidates)
    return format_recommendation_into_markdown(shortlisted)


def fetch_papers(top_n=5):
try:
response = requests.get(f"{HF_API_URL}?limit=100")
Expand Down

0 comments on commit ad914bb

Please sign in to comment.