From 3a818590b7ed1117056a6684f1d026e27fc0884d Mon Sep 17 00:00:00 2001
From: "mariia.berdnyk"
Date: Fri, 14 Mar 2025 16:07:12 +0100
Subject: [PATCH] fixing the small bug, adding the attempts count into benchmarking

---
 common/llm_calls.py | 27 ++++++++++++++++-----------
 1 file changed, 16 insertions(+), 11 deletions(-)

diff --git a/common/llm_calls.py b/common/llm_calls.py
index 736723f..649682f 100644
--- a/common/llm_calls.py
+++ b/common/llm_calls.py
@@ -52,7 +52,7 @@ def extract_json_from_response(response):
     return None if not code_string else code_string
 
 
-def call_ollama(model_config: Dict, system_prompt, user_prompts) -> Tuple[List[str | None], List[str | None], List[int], List[float]]:
+def call_ollama(model_config: Dict, system_prompt, user_prompts) -> Tuple[List[str | None], List[str | None], List[int], List[float], List[int]]:
     try:
         import ollama
     except ImportError as e:
@@ -63,6 +63,7 @@ def call_ollama(model_config: Dict, system_prompt, user_prompts) -> Tuple[List[s
     generated_responses = []
     completion_tokens = []
     exec_times = []
+    attempts = []
 
     options = model_config["options"]
     options["additional_num_ctx"] += len(system_prompt) + len(max(user_prompts, key=len))
@@ -89,15 +90,16 @@ def call_ollama(model_config: Dict, system_prompt, user_prompts) -> Tuple[List[s
                 completion_tokens.append(response_cur["eval_count"])
                 generated_codes.append(generated_code)
                 generated_responses.append(generated_response)
+                attempts.append(attempt)
                 break
             except Exception as e:
                 print(f"Attempt {attempt + 1} failed with error: {e}")
                 time.sleep(1)
 
-    return generated_responses, generated_codes, completion_tokens, exec_times
+    return generated_responses, generated_codes, completion_tokens, exec_times, attempts
 
 
-def call_watsonxai(model_config: Dict, system_prompt, user_prompts: []) -> Tuple[List[str | None], List[str | None], List[int], List[float]]:
+def call_watsonxai(model_config: Dict, system_prompt, user_prompts: []) -> Tuple[List[str | None], List[str | None], List[int], List[float], List[int]]:
     try:
         from langchain_core.messages import SystemMessage, HumanMessage
         from langchain_ibm import ChatWatsonx
@@ -120,6 +122,7 @@ def call_watsonxai(model_config: Dict, system_prompt, user_prompts: []) -> Tuple
     generated_responses = []
     number_tokens = []
     exec_times = []
+    attempts = []
 
     for user_prompt in user_prompts:
         if len(messages) == 1:
@@ -141,21 +144,22 @@ def call_watsonxai(model_config: Dict, system_prompt, user_prompts: []) -> Tuple
                 number_tokens.append(response_cur.response_metadata['token_usage']['completion_tokens'])
                 generated_codes.append(generated_code)
                 generated_responses.append(generated_response)
+                attempts.append(attempt)
                 break
             except Exception as e:
                 print(f"Attempt {attempt + 1} failed with error: {e}")
                 time.sleep(1)
 
 
-    return generated_responses, generated_codes, number_tokens, exec_times
+    return generated_responses, generated_codes, number_tokens, exec_times, attempts
 
 
-def call_api(llm_api, model_config, system_prompt, user_prompts) -> Tuple[List[str], List[str], List[int], List[float]]:
+def call_api(llm_api, model_config, system_prompt, user_prompts) -> Tuple[List[str], List[str], List[int], List[float], List[int]]:
     if llm_api == LLM_API.OLLAMA:
         return call_ollama(model_config, system_prompt, user_prompts)
     elif llm_api == LLM_API.WATSONXAI:
         return call_watsonxai(model_config, system_prompt, user_prompts)
-    return ["No supported LLM API type provided"], [""], [0], [0]
+    return ["No supported LLM API type provided"], [""], [0], [0], [0]
 
 
 def call_llm(policy_description_file_path, csv_file, data_generator_or_columns: DataGenerator | List[str], llm_api: LLM_API, config_file_path, result_output_path):
@@ -170,7 +174,8 @@ def call_llm(policy_description_file_path, csv_file, data_generator_or_columns:
 
     dataFull = pd.read_csv(csv_file, na_filter=True).fillna("")
     dataFull = dataFull.sample(frac=1).reset_index(drop=True)
-
+    dataFull = dataFull.head(1)
+    dataFull.to_csv(f'{time.time()}_mixed_test_data.csv', index=False)
     # Handle both DataGenerator and direct list input
     if isinstance(data_generator_or_columns, list):
         eval_column_names = data_generator_or_columns  # Direct list of column names
@@ -189,14 +194,15 @@ def call_llm(policy_description_file_path, csv_file, data_generator_or_columns:
         user_prompts.append(user_prompt_default.format(test_case=case.to_json()))
 
     print("Calling LLM with the created user prompts...")
-    generated_responses, generated_answers, numbers_tokens, exec_times = call_api(llm_api, model_config, system_prompt, user_prompts)
+    generated_responses, generated_answers, numbers_tokens, exec_times, attempts = call_api(llm_api, model_config, system_prompt, user_prompts)
     for idx in range(len(generated_answers)):
         results[idx] = {
             "test_case": dataFull.iloc[idx].to_dict(),
            "generated_response": generated_responses[idx],
             "generated_answer": json.loads(generated_answers[idx]) if generated_answers[idx] else None,
             "number_tokens": numbers_tokens[idx],
-            "execution_time": exec_times[idx]
+            "execution_time": exec_times[idx],
+            "attempt": attempts[idx]
         }
 
     print("Saving the generation results...")
@@ -263,11 +269,10 @@ def load_data_generator_or_columns(value):
                         help="The API to be used for the LLM call (ollama/watsonx).")
     parser.add_argument("--config_file", type=str, required=True, help="The model & api configuration file path.")
     parser.add_argument("--output_file", type=str, required=True, help="The path for the benchmarking results output")
-    parser.add_argument("--output_file", type=str, required=True, help="The path for the benchmarking results output")
     parser.add_argument("--column_mapping", type=str, required=False,
                         help="Optional JSON string for column name mapping for benchmark metrics calculation. "
                              "The way they are saved in reference csv dataset vs the way they are saved in the resulting output_file (generated by LLM) "
-                             "(e.g., '{\"eligibility\": \"elig\"}').")
+                             "(e.g., '{\"eligibility\": \"eligible\"}').")
 
     args = parser.parse_args()
 
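
Note on the new "attempt" field: below is a minimal post-processing sketch, not part of the patch. It assumes the output_file written by call_llm is a JSON object mapping a row index to the record built above (test_case, generated_response, ..., attempt), and it uses a placeholder file name; the "attempt" value is the zero-based retry index at which the call succeeded, matching the attempts.append(attempt) lines in the diff.

import json
from collections import Counter

# Placeholder path standing in for whatever was passed as --output_file.
with open("benchmark_results.json") as f:
    results = json.load(f)

# Count how many responses succeeded on the first try, second try, etc.
attempt_counts = Counter(record["attempt"] for record in results.values())
for attempt, count in sorted(attempt_counts.items()):
    print(f"{count} response(s) succeeded on attempt {attempt + 1}")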