1
1
import argparse
2
2
import asyncio
3
+ import json
3
4
import logging
4
5
import os
5
6
import pathlib
6
7
from enum import Enum
7
- from typing import Any , Optional
8
8
9
9
import requests
10
10
from azure .ai .evaluation import ContentSafetyEvaluator
@@ -48,36 +48,15 @@ def get_azure_credential():
48
48
49
49
async def callback(
    messages: dict,
    target_url: str = "http://127.0.0.1:8000/chat",
):
    """Forward the simulator's latest user message to the chat app and return the reply.

    Args:
        messages: Simulator payload — a dict whose "messages" key holds a list of
            ``{"content": ..., "role": ...}`` dicts; the last entry is the user query.
            (The original annotation said ``list[dict]``, but the body indexes
            ``messages["messages"]``, so it is a dict.)
        target_url: Chat endpoint to POST the query to.

    Returns:
        A dict with a single "messages" key: the input history plus the assistant
        reply, or an assistant-role message carrying the endpoint's "error" text.
    """
    messages_list = messages["messages"]
    # Only the newest turn is forwarded; earlier turns stay in messages_list
    # and are re-attached to the returned conversation below.
    query = messages_list[-1]["content"]
    headers = {"Content-Type": "application/json"}
    body = {
        "messages": [{"content": query, "role": "user"}],
        "stream": False,
        "context": {"overrides": {"use_advanced_flow": True, "top": 3, "retrieval_mode": "hybrid", "temperature": 0.3}},
    }
    # NOTE: requests is a blocking call inside an async function; the simulator
    # invokes this target serially so that is tolerable, but a timeout is needed
    # so a wedged endpoint cannot hang the whole safety run indefinitely.
    r = requests.post(target_url, headers=headers, json=body, timeout=60)
    response = r.json()
    if "error" in response:
        # Surface app-side errors as an assistant message so the simulator
        # still receives a well-formed conversation.
        message = {"content": response["error"], "role": "assistant"}
    else:
        message = response["message"]
    return {"messages": messages_list + [message]}
91
69
92
70
93
71
async def run_simulator (target_url : str , max_simulations : int ):
@@ -104,9 +82,7 @@ async def run_simulator(target_url: str, max_simulations: int):
104
82
105
83
outputs = await adversarial_simulator (
106
84
scenario = scenario ,
107
- target = lambda messages , stream = False , session_state = None , context = None : callback (
108
- messages , stream , session_state , context , target_url
109
- ),
85
+ target = lambda messages , stream = False , session_state = None , context = None : callback (messages , target_url ),
110
86
max_simulation_results = max_simulations ,
111
87
language = SupportedLanguages .English , # Match this to your app language
112
88
randomization_seed = 1 , # For more consistent results, use a fixed seed
@@ -139,10 +115,9 @@ async def run_simulator(target_url: str, max_simulations: int):
139
115
else :
140
116
summary_scores [evaluator ]["mean_score" ] = 0
141
117
summary_scores [evaluator ]["low_rate" ] = 0
118
+
142
119
# Save summary scores
143
120
with open (root_dir / "safety_results.json" , "w" ) as f :
144
- import json
145
-
146
121
json .dump (summary_scores , f , indent = 2 )
147
122
148
123
0 commit comments