
Fixing a small bug and adding the attempt count to benchmarking #9

Draft · wants to merge 1 commit into base: main
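
The PR has no description beyond its title, so here is a minimal, illustrative sketch (not code from this repository) of the pattern the diff extends: each backend call already runs inside a retry loop, and the change records the 0-based index of the attempt that finally succeeded so it can be reported with the other benchmarking numbers. The names flaky_backend, call_with_retries and max_attempts below are placeholders.

import random
import time


def flaky_backend(prompt: str) -> str:
    """Stand-in for an LLM backend call that sometimes fails (illustration only)."""
    if random.random() < 0.5:
        raise RuntimeError("transient backend error")
    return f"response to: {prompt}"


def call_with_retries(prompt: str, max_attempts: int = 3):
    """Retry a flaky call and remember which attempt succeeded."""
    for attempt in range(max_attempts):
        try:
            response = flaky_backend(prompt)
            # Mirrors the attempts.append(attempt) lines added in this PR:
            # the attempt index travels alongside the generated response.
            return response, attempt
        except Exception as e:
            print(f"Attempt {attempt + 1} failed with error: {e}")
            time.sleep(1)
    return None, max_attempts - 1


if __name__ == "__main__":
    text, attempt = call_with_retries("hello")
    if text is not None:
        print(f"succeeded on attempt {attempt + 1}: {text!r}")
    else:
        print("all attempts failed")
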
27 changes: 16 additions & 11 deletions common/llm_calls.py
@@ -52,7 +52,7 @@ def extract_json_from_response(response):
return None if not code_string else code_string


-def call_ollama(model_config: Dict, system_prompt, user_prompts) -> Tuple[List[str | None], List[str | None], List[int], List[float]]:
+def call_ollama(model_config: Dict, system_prompt, user_prompts) -> Tuple[List[str | None], List[str | None], List[int], List[float], List[int]]:
try:
import ollama
except ImportError as e:
@@ -63,6 +63,7 @@ def call_ollama(model_config: Dict, system_prompt, user_prompts) -> Tuple[List[s
generated_responses = []
completion_tokens = []
exec_times = []
+attempts = []

options = model_config["options"]
options["additional_num_ctx"] += len(system_prompt) + len(max(user_prompts, key=len))
@@ -89,15 +90,16 @@ def call_ollama(model_config: Dict, system_prompt, user_prompts) -> Tuple[List[s
completion_tokens.append(response_cur["eval_count"])
generated_codes.append(generated_code)
generated_responses.append(generated_response)
+attempts.append(attempt)
break
except Exception as e:
print(f"Attempt {attempt + 1} failed with error: {e}")
time.sleep(1)

-return generated_responses, generated_codes, completion_tokens, exec_times
+return generated_responses, generated_codes, completion_tokens, exec_times, attempts


-def call_watsonxai(model_config: Dict, system_prompt, user_prompts: []) -> Tuple[List[str | None], List[str | None], List[int], List[float]]:
+def call_watsonxai(model_config: Dict, system_prompt, user_prompts: []) -> Tuple[List[str | None], List[str | None], List[int], List[float], List[int]]:
try:
from langchain_core.messages import SystemMessage, HumanMessage
from langchain_ibm import ChatWatsonx
@@ -120,6 +122,7 @@ def call_watsonxai(model_config: Dict, system_prompt, user_prompts: []) -> Tuple
generated_responses = []
number_tokens = []
exec_times = []
+attempts = []

for user_prompt in user_prompts:
if len(messages) == 1:
@@ -141,21 +144,22 @@ def call_watsonxai(model_config: Dict, system_prompt, user_prompts: []) -> Tuple
number_tokens.append(response_cur.response_metadata['token_usage']['completion_tokens'])
generated_codes.append(generated_code)
generated_responses.append(generated_response)
+attempts.append(attempt)
break
except Exception as e:
print(f"Attempt {attempt + 1} failed with error: {e}")
time.sleep(1)

-return generated_responses, generated_codes, number_tokens, exec_times
+return generated_responses, generated_codes, number_tokens, exec_times, attempts


-def call_api(llm_api, model_config, system_prompt, user_prompts) -> Tuple[List[str], List[str], List[int], List[float]]:
+def call_api(llm_api, model_config, system_prompt, user_prompts) -> Tuple[List[str], List[str], List[int], List[float], List[int]]:
if llm_api == LLM_API.OLLAMA:
return call_ollama(model_config, system_prompt, user_prompts)
elif llm_api == LLM_API.WATSONXAI:
return call_watsonxai(model_config, system_prompt, user_prompts)

return ["No supported LLM API type provided"], [""], [0], [0]
return ["No supported LLM API type provided"], [""], [0], [0], [0]


def call_llm(policy_description_file_path, csv_file, data_generator_or_columns: DataGenerator | List[str], llm_api: LLM_API, config_file_path, result_output_path):
@@ -170,7 +174,8 @@ def call_llm(policy_description_file_path, csv_file, data_generator_or_columns:

dataFull = pd.read_csv(csv_file, na_filter=True).fillna("")
dataFull = dataFull.sample(frac=1).reset_index(drop=True)

+dataFull = dataFull.head(1)
+dataFull.to_csv(f'{time.time()}_mixed_test_data.csv', index=False)
# Handle both DataGenerator and direct list input
if isinstance(data_generator_or_columns, list):
eval_column_names = data_generator_or_columns # Direct list of column names
@@ -189,14 +194,15 @@ def call_llm(policy_description_file_path, csv_file, data_generator_or_columns:
user_prompts.append(user_prompt_default.format(test_case=case.to_json()))

print("Calling LLM with the created user prompts...")
-generated_responses, generated_answers, numbers_tokens, exec_times = call_api(llm_api, model_config, system_prompt, user_prompts)
+generated_responses, generated_answers, numbers_tokens, exec_times, attempts = call_api(llm_api, model_config, system_prompt, user_prompts)
for idx in range(len(generated_answers)):
results[idx] = {
"test_case": dataFull.iloc[idx].to_dict(),
"generated_response": generated_responses[idx],
"generated_answer": json.loads(generated_answers[idx]) if generated_answers[idx] else None,
"number_tokens": numbers_tokens[idx],
"execution_time": exec_times[idx]
"execution_time": exec_times[idx],
"attempt": attempts[idx]
}

print("Saving the generation results...")
@@ -263,11 +269,10 @@ def load_data_generator_or_columns(value):
help="The API to be used for the LLM call (ollama/watsonx).")
parser.add_argument("--config_file", type=str, required=True, help="The model & api configuration file path.")
parser.add_argument("--output_file", type=str, required=True, help="The path for the benchmarking results output")
parser.add_argument("--output_file", type=str, required=True, help="The path for the benchmarking results output")
parser.add_argument("--column_mapping", type=str, required=False,
help="Optional JSON string for column name mapping for benchmark metrics calculation. "
"The way they are saved in reference csv dataset vs the way they are saved in the resulting output_file (generated by LLM) "
"(e.g., '{\"eligibility\": \"elig\"}').")
"(e.g., '{\"eligibility\": \"eligible\"}').")

args = parser.parse_args()

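
For context, a hedged sketch of how the extra value might be consumed downstream: call_api now returns five parallel lists, and each per-test-case result dict gains an "attempt" key next to "number_tokens" and "execution_time". The snippet below assumes the results are serialized as JSON at a placeholder path, results.json; the dict layout is taken from the call_llm change in this diff.

import json

# Illustrative consumer of the benchmarking output written by call_llm.
with open("results.json") as f:
    results = json.load(f)

for key, entry in results.items():
    # "attempt" is the 0-based retry index recorded by this PR.
    print(
        f"case {key}: {entry['number_tokens']} tokens, "
        f"{entry['execution_time']:.2f}s, succeeded on attempt {entry['attempt'] + 1}"
    )
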