Skip to content

Commit

Permalink
updated check for model repo to align with checks for other options
Browse files Browse the repository at this point in the history
  • Loading branch information
nnshah1 authored Jan 22, 2024
1 parent f0ff1e5 commit 6e91ff3
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 6 deletions.
8 changes: 4 additions & 4 deletions python/test/test_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -376,6 +376,10 @@ def test_stop(self):

server.stop()

def test_model_repository_not_specified(self):
with self.assertRaises(tritonserver.InvalidArgumentError):
tritonserver.Server(model_repository=None).start()


class InferenceTests(unittest.TestCase):
def setup_method(self, method):
Expand Down Expand Up @@ -452,7 +456,3 @@ def test_basic_inference(self):
):
fp16_output = numpy.from_dlpack(response.outputs["fp16_output"])
numpy.testing.assert_array_equal(fp16_input, fp16_output)

def test_model_repository_not_specified(self):
with self.assertRaises(tritonserver.InvalidArgumentError):
tritonserver.Server(model_repository=None)
4 changes: 2 additions & 2 deletions python/tritonserver/_api/_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -343,6 +343,8 @@ def _create_tritonserver_server_options(

options.set_server_id(self.server_id)

if self.model_repository is None:
raise InvalidArgumentError("Model repository must be specified.")
if not isinstance(self.model_repository, list):
self.model_repository = [self.model_repository]
for model_repository_path in self.model_repository:
Expand Down Expand Up @@ -526,8 +528,6 @@ def __init__(
if options is None:
options = Options(**kwargs)
self.options: Options = options
if self.options.model_repository is None:
raise InvalidArgumentError("Model repository must be specified.")
self._server = Server._UnstartedServer()

def start(
Expand Down

0 comments on commit 6e91ff3

Please sign in to comment.