Skip to content

Commit

Permalink
Adding dataset_name to API (#490)
Browse files Browse the repository at this point in the history
  • Loading branch information
nv-braf authored and debermudez committed Mar 13, 2024
1 parent 3254c72 commit b407451
Show file tree
Hide file tree
Showing 2 changed files with 29 additions and 21 deletions.
35 changes: 21 additions & 14 deletions src/c++/perf_analyzer/genai-pa/genai_pa/llm_inputs/llm_inputs.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,7 @@ def create_llm_inputs(
input_type: InputType,
input_format: InputFormat,
output_format: OutputFormat,
dataset_name: str = "",
model_name: str = "",
input_filename: str = "",
starting_index: int = DEFAULT_STARTING_INDEX,
Expand All @@ -91,6 +92,8 @@ def create_llm_inputs(
Optional Parameters
-------------------
dataset_name:
The name of the dataset
model_name:
The model name
starting_index:
Expand All @@ -103,12 +106,14 @@ def create_llm_inputs(
If true adds a steam field to each payload
"""

LlmInputs._check_for_valid_args(input_type, model_name, starting_index, length)
LlmInputs._check_for_valid_args(
input_type, dataset_name, starting_index, length
)

dataset = None
if input_type == InputType.URL:
dataset = LlmInputs._get_input_dataset_from_url(
model_name, starting_index, length
dataset_name, starting_index, length
)
else:
raise GenAiPAException(
Expand All @@ -128,32 +133,34 @@ def create_llm_inputs(

@classmethod
def _check_for_valid_args(
cls, input_type: InputType, model_name: str, starting_index: int, length: int
cls, input_type: InputType, dataset_name: str, starting_index: int, length: int
) -> None:
try:
LlmInputs._check_for_model_name_if_input_type_is_url(input_type, model_name)
LlmInputs._check_for_dataset_name_if_input_type_is_url(
input_type, dataset_name
)
LlmInputs._check_for_valid_starting_index(starting_index)
LlmInputs._check_for_valid_length(length)
except Exception as e:
raise GenAiPAException(e)

@classmethod
def _get_input_dataset_from_url(
cls, model_name: str, starting_index: int, length: int
cls, dataset_name: str, starting_index: int, length: int
) -> Response:
url = LlmInputs._resolve_url(model_name)
url = LlmInputs._resolve_url(dataset_name)
configured_url = LlmInputs._create_configured_url(url, starting_index, length)
dataset = LlmInputs._download_dataset(configured_url, starting_index, length)

return dataset

@classmethod
def _resolve_url(cls, model_name: str) -> str:
if model_name in LlmInputs.dataset_url_map:
return LlmInputs.dataset_url_map[model_name]
def _resolve_url(cls, dataset_name: str) -> str:
if dataset_name in LlmInputs.dataset_url_map:
return LlmInputs.dataset_url_map[dataset_name]
else:
raise GenAiPAException(
f"{model_name} does not have a corresponding URL in the dataset_url_map."
f"{dataset_name} does not have a corresponding URL in the dataset_url_map."
)

@classmethod
Expand Down Expand Up @@ -503,12 +510,12 @@ def _add_optional_tags_to_vllm_json(
return pa_json

@classmethod
def _check_for_model_name_if_input_type_is_url(
cls, input_type: InputType, model_name: str
def _check_for_dataset_name_if_input_type_is_url(
cls, input_type: InputType, dataset_name: str
) -> None:
if input_type == InputType.URL and not model_name:
if input_type == InputType.URL and not dataset_name:
raise GenAiPAException(
"Input type is URL, but model_name is not specified."
"Input type is URL, but dataset_name is not specified."
)

@classmethod
Expand Down
15 changes: 8 additions & 7 deletions src/c++/perf_analyzer/genai-pa/tests/test_llm_inputs.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,13 +39,13 @@ def default_configured_url(self):

# TODO: Add tests that verify json schemas

def test_input_type_url_no_model_name(self):
def test_input_type_url_no_dataset_name(self):
"""
Test for exception when input type is URL and no model name
Test for exception when input type is URL and no dataset name
"""
with pytest.raises(GenAiPAException):
_ = LlmInputs._check_for_model_name_if_input_type_is_url(
input_type=InputType.URL, model_name=""
_ = LlmInputs._check_for_dataset_name_if_input_type_is_url(
input_type=InputType.URL, dataset_name=""
)

def test_illegal_starting_index(self):
Expand Down Expand Up @@ -105,7 +105,7 @@ def test_llm_inputs_error_in_server_response(self):
input_type=InputType.URL,
input_format=InputFormat.OPENAI,
output_format=OutputFormat.OPENAI,
model_name=OPEN_ORCA,
dataset_name=OPEN_ORCA,
starting_index=LlmInputs.DEFAULT_STARTING_INDEX,
length=int(LlmInputs.DEFAULT_LENGTH * 100),
)
Expand Down Expand Up @@ -177,7 +177,7 @@ def test_create_openai_llm_inputs_cnn_dailymail(self):
input_type=InputType.URL,
input_format=InputFormat.OPENAI,
output_format=OutputFormat.OPENAI,
model_name=CNN_DAILY_MAIL,
dataset_name=CNN_DAILY_MAIL,
)

os.remove(DEFAULT_INPUT_DATA_JSON)
Expand All @@ -193,6 +193,7 @@ def test_write_to_file(self):
input_type=InputType.URL,
input_format=InputFormat.OPENAI,
output_format=OutputFormat.OPENAI,
dataset_name=OPEN_ORCA,
model_name=OPEN_ORCA,
add_model_name=True,
add_stream=True,
Expand All @@ -214,7 +215,7 @@ def test_create_openai_to_vllm(self):
input_type=InputType.URL,
input_format=InputFormat.OPENAI,
output_format=OutputFormat.VLLM,
model_name=OPEN_ORCA,
dataset_name=OPEN_ORCA,
add_model_name=False,
add_stream=True,
)
Expand Down

0 comments on commit b407451

Please sign in to comment.