Skip to content

Commit 020d289

Browse files
authored
[OAI Serving] Validate greedy generation args (#13113)
* validate greedy generation args

Signed-off-by: jenchen13 <jennifchen@nvidia.com>
1 parent 11741e5 commit 020d289

File tree

1 file changed

+9
-1
lines changed

1 file changed

+9
-1
lines changed

nemo/deploy/service/fastapi_interface_to_pytriton.py

+9-1
Original file line number | Diff line number | Diff line change
@@ -14,7 +14,7 @@
1414
import numpy as np
1515
import requests
1616
from fastapi import FastAPI, HTTPException
17-
from pydantic import BaseModel
17+
from pydantic import BaseModel, model_validator
1818
from pydantic_settings import BaseSettings
1919

2020
from nemo.deploy.nlp import NemoQueryLLMPyTorch
@@ -81,6 +81,14 @@ class CompletionRequest(BaseModel):
8181
top_k: int = 0
8282
logprobs: int = None
8383

84+
@model_validator(mode='after')
def set_greedy_params(self):
    """Force greedy sampling when temperature and top_p are both zero.

    Registered as a pydantic ``after`` model validator, so it runs on the
    already-built request object. If ``temperature`` and ``top_p`` both
    compare equal to 0, a warning is logged and ``top_k`` is overwritten
    with 1; in every other case the request is left untouched.

    Returns:
        The same request instance, as pydantic expects from an ``after``
        validator.
    """
    wants_greedy = self.temperature == 0 and self.top_p == 0
    if not wants_greedy:
        # Nothing to adjust: at least one sampling knob is non-zero.
        return self

    logging.warning("Both temperature and top_p are 0. Setting top_k to 1 to ensure greedy sampling.")
    self.top_k = 1
    return self
91+
8492

8593
@app.get("/v1/health")
8694
def health_check():

0 commit comments

Comments
 (0)