Skip to content

Commit

Permalink
feat: check verification result for contradictions
Browse files Browse the repository at this point in the history
Sometimes LLMs would incorrectly fail the verification and provide an
explanation that contradicts itself.

For example, given the following verification
"square A is positioned to the left of square B",
the LLM sometimes fails it but provide weird explanation:
"The visual representation shows square A to the right of square B,
indicating that square B is not positioned to the left of square A."

This commit compensates for such behavior by double-checking the
expected result and verification explanation for contradictions.
  • Loading branch information
p0deje committed Nov 19, 2024
1 parent 80dcbc6 commit b890ed7
Show file tree
Hide file tree
Showing 6 changed files with 52 additions and 5 deletions.
1 change: 0 additions & 1 deletion .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@ concurrency:
cancel-in-progress: true

env:
ALUMNIUM_DEBUG: 1
ALUMNIUM_MODEL: azure_openai
AZURE_OPENAI_API_KEY: ${{ secrets.AZURE_OPENAI_API_KEY }}
AZURE_OPENAI_API_VERSION: ${{ secrets.AZURE_OPENAI_API_VERSION }}
Expand Down
1 change: 1 addition & 0 deletions alumnium/agents/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from .actor_agent import ActorAgent
from .contradiction_checker_agent import ContradictionCheckerAgent
from .loading_detector_agent import LoadingDetectorAgent
from .verifier_agent import VerifierAgent
41 changes: 41 additions & 0 deletions alumnium/agents/contradiction_checker_agent.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
import logging
from pathlib import Path

from langchain_core.language_models import BaseChatModel
from pydantic import BaseModel, Field


logger = logging.getLogger(__name__)


class Response(BaseModel):
result: bool = Field(description="True if contradiction is detected, False otherwise.")


class ContradictionCheckerAgent:
with open(Path(__file__).parent / "contradiction_checker_prompts/user.md") as f:
USER_MESSAGE = f.read()

def __init__(self, llm: BaseChatModel):
self.chain = llm.with_structured_output(Response, include_raw=True)

def invoke(self, statement: str, verification_explanation: str):
logger.info(f"Starting contradiction checking:")

message = self.chain.invoke(
[
(
"human",
self.USER_MESSAGE.format(
statement=statement,
verification_explanation=verification_explanation,
),
),
]
)

result = message["parsed"]
logger.info(f" <- Result: {result.result}")
logger.info(f' <- Usage: {message["raw"].usage_metadata}')

return result.result
4 changes: 4 additions & 0 deletions alumnium/agents/contradiction_checker_prompts/user.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
Does the following two statements contradict each other? 

1. {statement}
2. {verification_explanation}
6 changes: 4 additions & 2 deletions alumnium/agents/verifier_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from pydantic import BaseModel, Field

from alumnium.drivers import SeleniumDriver
from . import LoadingDetectorAgent
from . import ContradictionCheckerAgent, LoadingDetectorAgent

logger = logging.getLogger(__name__)

Expand All @@ -28,6 +28,7 @@ def __init__(self, driver: SeleniumDriver, llm: BaseChatModel):
self.driver = driver
self.chain = llm.with_structured_output(Verification, include_raw=True)

self.contradiction_checker_agent = ContradictionCheckerAgent(llm)
self.loading_detector_agent = LoadingDetectorAgent(llm)
self.retry_count = LoadingDetectorAgent.timeout / LoadingDetectorAgent.delay

Expand Down Expand Up @@ -79,4 +80,5 @@ def invoke(self, statement: str, vision: bool = False):
self.retry_count -= 1
return self.invoke(statement, vision)
else:
raise e
if self.contradiction_checker_agent.invoke(statement, verification.explanation):
raise e
4 changes: 2 additions & 2 deletions examples/pytest/drag_and_drop_test.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
def test_drag_and_drop(al, driver):
driver.get("https://the-internet.herokuapp.com/drag_and_drop")
al.check("square A is positioned to the left from square B", vision=True)
al.check("square A is positioned to the left of square B", vision=True)
al.do("move square A to square B")
al.check("square B is positioned to the left from square A", vision=True)
al.check("square B is positioned to the left of square A", vision=True)

0 comments on commit b890ed7

Please sign in to comment.