diff --git a/app/services/twitter/oauth2.py b/app/services/twitter/oauth2.py index 95f9b81..b343417 100644 --- a/app/services/twitter/oauth2.py +++ b/app/services/twitter/oauth2.py @@ -1,5 +1,7 @@ """Twitter OAuth2 authentication module.""" +from urllib.parse import urlencode + from fastapi import APIRouter, Depends from pydantic import BaseModel from requests.auth import HTTPBasicAuth @@ -30,11 +32,17 @@ def __init__(self, *, client_id, redirect_uri, scope, client_secret=None): self._client.create_code_verifier(128), "S256" ) - def get_authorization_url(self, agent_id: str): - """Get the authorization URL to redirect the user to""" + def get_authorization_url(self, agent_id: str, result_uri: str): + """Get the authorization URL to redirect the user to + + Args: + agent_id: ID of the agent to authenticate + result_uri: URI to redirect to after authorization + """ + state_params = {"agent_id": agent_id, "result_uri": result_uri} authorization_url, _ = self.authorization_url( "https://twitter.com/i/oauth2/authorize", - state=agent_id, + state=urlencode(state_params), code_challenge=self.code_challenge, code_challenge_method="S256", ) @@ -93,26 +101,28 @@ class TwitterAuthResponse(BaseModel): response_model=TwitterAuthResponse, dependencies=[Depends(verify_jwt)], ) -async def get_twitter_auth_url(agent_id: str) -> TwitterAuthResponse: +async def get_twitter_auth_url(agent_id: str, result_uri: str) -> TwitterAuthResponse: """Get Twitter OAuth2 authorization URL. Args: agent_id: ID of the agent to authenticate + result_uri: URI to redirect to after authorization Returns: Object containing agent_id and authorization URL """ - url = oauth2_user_handler.get_authorization_url(agent_id) + url = oauth2_user_handler.get_authorization_url(agent_id, result_uri) return TwitterAuthResponse(agent_id=agent_id, url=url) -def get_authorization_url(agent_id: str) -> str: +def get_authorization_url(agent_id: str, result_uri: str) -> str: """Get Twitter OAuth2 authorization URL. Args: agent_id: ID of the agent to authenticate + result_uri: URI to redirect to after authorization Returns: Authorization URL with agent_id as state parameter """ - return oauth2_user_handler.get_authorization_url(agent_id) + return oauth2_user_handler.get_authorization_url(agent_id, result_uri) diff --git a/app/services/twitter/oauth2_callback.py b/app/services/twitter/oauth2_callback.py index 8baab6c..5eb156f 100644 --- a/app/services/twitter/oauth2_callback.py +++ b/app/services/twitter/oauth2_callback.py @@ -1,10 +1,11 @@ """Twitter OAuth2 callback handler.""" from datetime import datetime, timezone +from urllib.parse import parse_qs, urlencode, urlparse import tweepy from fastapi import APIRouter, HTTPException -from starlette.responses import JSONResponse +from starlette.responses import JSONResponse, RedirectResponse from app.config.config import config from app.services.twitter.oauth2 import oauth2_user_handler @@ -13,6 +14,22 @@ router = APIRouter(prefix="/callback/auth", tags=["Callback"]) +def is_valid_url(url: str) -> bool: + """Check if a URL is valid. + + Args: + url: URL to validate + + Returns: + bool: True if URL is valid, False otherwise + """ + try: + result = urlparse(url) + return all([result.scheme, result.netloc]) + except (ValueError, AttributeError, TypeError): + return False + + @router.get("/twitter") async def twitter_oauth_callback( state: str, @@ -25,11 +42,11 @@ async def twitter_oauth_callback( them in the database. Args: - state: Agent ID from authorization request + state: URL-encoded state containing agent_id and result_uri code: Authorization code from Twitter Returns: - JSONResponse with success message + JSONResponse or RedirectResponse depending on result_uri Raises: HTTPException: If state/code is missing or token exchange fails @@ -38,7 +55,16 @@ async def twitter_oauth_callback( raise HTTPException(status_code=400, detail="Missing state or code parameter") try: - agent_id = state + # Parse state parameter + state_params = parse_qs(state) + agent_id = state_params.get("agent_id", [""])[0] + result_uri = state_params.get("result_uri", [""])[0] + + if not agent_id: + raise HTTPException( + status_code=400, detail="Missing agent_id in state parameter" + ) + agent = await Agent.get(agent_id) if not agent: raise HTTPException(status_code=404, detail=f"Agent {agent_id} not found") @@ -64,21 +90,48 @@ async def twitter_oauth_callback( client = tweepy.Client(bearer_token=token["access_token"], return_type=dict) me = client.get_me(user_auth=False) + username = None if me and "data" in me: agent_data.twitter_id = me.get("data").get("id") - agent_data.twitter_username = me.get("data").get("username") + username = me.get("data").get("username") + agent_data.twitter_username = username agent_data.twitter_name = me.get("data").get("name") # Commit changes await agent_data.save() - return JSONResponse( - content={"message": "Authentication successful, you can close this window"}, - status_code=200, - ) + # Handle response based on result_uri + if result_uri and is_valid_url(result_uri): + params = {"twitter_auth": "success", "username": username} + redirect_url = ( + f"{result_uri}{'&' if '?' in result_uri else '?'}{urlencode(params)}" + ) + return RedirectResponse(url=redirect_url) + else: + return JSONResponse( + content={ + "message": "Authentication successful, you can close this window", + "username": username, + }, + status_code=200, + ) except HTTPException as http_exc: + # Handle error response + if result_uri and is_valid_url(result_uri): + params = {"twitter_auth": "failed", "error": str(http_exc.detail)} + redirect_url = ( + f"{result_uri}{'&' if '?' in result_uri else '?'}{urlencode(params)}" + ) + return RedirectResponse(url=redirect_url) # Re-raise HTTP exceptions to preserve their status codes raise http_exc except Exception as e: + # Handle error response for unexpected errors + if result_uri and is_valid_url(result_uri): + params = {"twitter_auth": "failed", "error": str(e)} + redirect_url = ( + f"{result_uri}{'&' if '?' in result_uri else '?'}{urlencode(params)}" + ) + return RedirectResponse(url=redirect_url) # For unexpected errors, use 500 status code raise HTTPException(status_code=500, detail=str(e)) diff --git a/skills/acolyt/ask.py b/skills/acolyt/ask.py index e3def55..e6bcc07 100644 --- a/skills/acolyt/ask.py +++ b/skills/acolyt/ask.py @@ -67,13 +67,11 @@ class AcolytAskGpt(AcolytBaseTool): """ name: str = "acolyt_ask_gpt" - description: str = ( - """ + description: str = """ This tool allows users to ask questions which are then sent to the Acolyt API. this should be run if the user requests to ask Acolyt explicitly. The API response is processed and summarized before being returned to the user. """ - ) args_schema: Type[BaseModel] = AcolytAskGptInput def _run(self, question: str) -> AcolytAskGptOutput: