diff --git a/runtime/python/server.py b/runtime/python/server.py index a4114ac..acd2505 100644 --- a/runtime/python/server.py +++ b/runtime/python/server.py @@ -12,7 +12,9 @@ code_file = os.getenv("MOD_NAME") arg_type = os.getenv("ARG_TYPE", "bytes") -spec = importlib.util.spec_from_file_location(name="function", location=f"./code/{code_file}.py") +spec = importlib.util.spec_from_file_location( + name="function", location=f"./code/{code_file}.py" +) mod = importlib.util.module_from_spec(spec) spec.loader.exec_module(mod) @@ -47,55 +49,69 @@ def read_all(self) -> bytes: def __iter__(self) -> Iterator[bytes]: while self._remaining_length > 0: - read_buffer = ( - self._chunk_size - if self._remaining_length >= self._chunk_size - else self._remaining_length - ) + read_buffer = min(self._chunk_size, self._remaining_length) self._remaining_length -= read_buffer yield self._rfile.read(read_buffer) class Handler(BaseHTTPRequestHandler): def log_request(self, *args): - # Don't log successful requests info. Unsuccessful logged by log_error(). + # Don't log successful requests info; log errors only. pass - def _set_headers(self): + def _set_headers(self, content_length=None): + """Set response headers""" self.send_response(200) self.send_header("Content-Type", "application/octet-stream") + if content_length is not None: + self.send_header("Content-Length", str(content_length)) self.end_headers() def do_PUT(self): - content_length = int(self.headers["Content-Length"]) - reader = StreamWrapper(self.rfile, content_length, CHUNK_SIZE) - if CHUNK_SIZE == 0: - result = transform(reader.read_all()) + """Handles PUT requests by applying a transformation function to the request body""" + try: + content_length = int(self.headers.get("Content-Length", 0)) + reader = StreamWrapper(self.rfile, content_length, CHUNK_SIZE) + + if CHUNK_SIZE == 0: + result = transform(reader.read_all()) + self._set_headers(content_length=len(result)) + self.wfile.write(result) + return + + # Streaming transform: writer is expected to write bytes to response self._set_headers() - self.wfile.write(result) - return + transform(reader, self.wfile) - # TODO: validate if transform takes writer as input - # NOTE: for streaming transforms the writer is expected to write bytes into response as stream. - self._set_headers() - transform(reader, self.wfile) + except Exception as e: + self.send_error(500, f"Error processing PUT request: {e}") def do_GET(self): - if self.path == "/health": - self._set_headers() - self.wfile.write(b"Running") - return - - query_path = host_target + self.path - - if arg_type == "url": - result = transform(query_path) - else: - input_bytes = requests.get(query_path).content - result = transform(input_bytes) - - self._set_headers() - self.wfile.write(result) + """Handles GET requests by fetching data, transforming it, and returning the response""" + try: + if self.path == "/health": + response = b"Running" + self._set_headers(content_length=len(response)) + self.wfile.write(response) + return + + query_path = host_target + self.path + + if arg_type == "url": + result = transform(query_path) + else: + response = requests.get(query_path) + response.raise_for_status() # Raise an error if request failed + input_bytes = response.content + result = transform(input_bytes) + + self._set_headers(content_length=len(result)) + self.wfile.write(result) + + except requests.exceptions.RequestException as e: + self.send_error(500, f"Failed to retrieve object: {e}") + except Exception as e: + self.send_error(500, f"Error processing GET request: {e}") class ThreadedHTTPServer(ThreadingMixIn, HTTPServer): diff --git a/transformers/echo/main.py b/transformers/echo/main.py index eee8827..2c5442f 100644 --- a/transformers/echo/main.py +++ b/transformers/echo/main.py @@ -1,14 +1,15 @@ """ -A simple echo transformation using FastAPI framework and Gunicorn and Uvicorn webserver. +A simple echo transformation using FastAPI framework with Gunicorn and Uvicorn web server. -Steps to run: +Steps to run: $ # with uvicorn -$ uvicorn main:app --reload +$ uvicorn main:app --reload $ # with multiple uvicorn processes managed by gunicorn -$ gunicorn main:app --workers 4 --worker-class uvicorn.workers.UvicornWorker --bind 0.0.0.0:8000 +$ gunicorn main:app --workers 4 --worker-class uvicorn.workers.UvicornWorker --bind 0.0.0.0:8000 -Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. +Copyright (c) 2023-2025, NVIDIA CORPORATION. All rights reserved. """ + # pylint: disable=missing-class-docstring, missing-function-docstring, missing-module-docstring, broad-exception-caught import os import urllib.parse @@ -59,18 +60,20 @@ async def get_handler( fetches the object from the AIS target based on the destination/name, transforms the bytes, and returns the modified bytes. """ - # Get destination/name of object from URL or from full_path variable - # Fetch object from AIS target based on the destination/name - # Transform the bytes - # Return the transformed bytes object_path = urllib.parse.quote(full_path, safe="@") object_url = f"{host_target}/{object_path}" resp = await client.get(object_url) if not resp or resp.status != 200: raise HTTPException( - status_code=500, detail="Error retreiving object ({full_path}) from target" + status_code=500, detail=f"Error retrieving object ({full_path}) from target" ) - return Response(content=await resp.read(), media_type="application/octet-stream") + + content = await resp.read() + return Response( + content=content, + media_type="application/octet-stream", + headers={"Content-Length": str(len(content))}, # Set Content-Length + ) @app.put("/") @@ -81,8 +84,10 @@ async def put_handler(request: Request): Reads bytes from the request, performs byte transformation, and returns the modified bytes. """ - # Read bytes from request (request.body) - # Transform the bytes - # Return the transformed bytes + content = await request.body() - return Response(content=await request.body(), media_type="application/octet-stream") + return Response( + content=content, + media_type="application/octet-stream", + headers={"Content-Length": str(len(content))}, # Set Content-Length + ) diff --git a/transformers/md5/server.py b/transformers/md5/server.py index 604ed11..ff0432c 100755 --- a/transformers/md5/server.py +++ b/transformers/md5/server.py @@ -7,42 +7,41 @@ from http.server import HTTPServer, BaseHTTPRequestHandler from socketserver import ThreadingMixIn -host_target = os.environ['AIS_TARGET_URL'] +host_target = os.environ["AIS_TARGET_URL"] class Handler(BaseHTTPRequestHandler): - def log_request(self, code='-', size='-'): - # Don't log successful requests info. Unsuccessful logged by log_error(). - pass + def log_request(self, code="-", size="-"): + pass # Disable request logging - def _set_headers(self): + def _set_headers(self, content_length): self.send_response(200) self.send_header("Content-Type", "text/plain") + self.send_header("Content-Length", str(content_length)) self.end_headers() def do_PUT(self): - content_length = int(self.headers['Content-Length']) + content_length = int(self.headers["Content-Length"]) post_data = self.rfile.read(content_length) - md5 = hashlib.md5() - md5.update(post_data) - self._set_headers() - self.wfile.write(md5.hexdigest().encode()) + digest = hashlib.md5(post_data).hexdigest().encode() + self._set_headers(len(digest)) + self.wfile.write(digest) def do_GET(self): if self.path == "/health": - self._set_headers() - self.wfile.write(b"Running") + response = b"Running" + self._set_headers(len(response)) + self.wfile.write(response) return - x = requests.get(host_target + self.path) - md5 = hashlib.md5() - md5.update(x.content) - self._set_headers() - self.wfile.write(md5.hexdigest().encode()) + content = requests.get(host_target + self.path).content + digest = hashlib.md5(content).hexdigest().encode() + self._set_headers(len(digest)) + self.wfile.write(digest) class ThreadedHTTPServer(ThreadingMixIn, HTTPServer): - """Handle requests in a separate thread.""" + """Handle requests in separate threads.""" def run(addr="localhost", port=8000): diff --git a/transformers/tests/test_echo.py b/transformers/tests/test_echo.py index b015fd6..efa5781 100644 --- a/transformers/tests/test_echo.py +++ b/transformers/tests/test_echo.py @@ -1,11 +1,12 @@ # -# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. +# Copyright (c) 2023-2025, NVIDIA CORPORATION. All rights reserved. # # pylint: disable=missing-class-docstring, missing-function-docstring, missing-module-docstring -from aistore.sdk.etl_const import ETL_COMM_HPULL, ETL_COMM_HPUSH, ETL_COMM_HREV -from aistore.sdk.etl_templates import ECHO +from aistore.sdk.etl.etl_const import ETL_COMM_HPULL, ETL_COMM_HPUSH, ETL_COMM_HREV +from aistore.sdk.etl.etl_templates import ECHO +from aistore.sdk.etl import ETLConfig from tests.base import TestBase from tests.utils import git_test_mode_format_image_tag_test @@ -33,7 +34,9 @@ def initialize_template(self, communication_type: str): def compare_transformed_data(self, filename: str, source: str): transformed_bytes = ( - self.test_bck.object(filename).get(etl_name=self.test_etl.name).read_all() + self.test_bck.object(filename) + .get(etl=ETLConfig(self.test_etl.name)) + .read_all() ) with open(source, "rb") as file: diff --git a/transformers/tests/test_go_echo.py b/transformers/tests/test_go_echo.py index baabc87..68877aa 100644 --- a/transformers/tests/test_go_echo.py +++ b/transformers/tests/test_go_echo.py @@ -1,13 +1,14 @@ # -# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. +# Copyright (c) 2023-2025, NVIDIA CORPORATION. All rights reserved. # # pylint: disable=missing-class-docstring, missing-function-docstring, missing-module-docstring from tests.base import TestBase from tests.utils import git_test_mode_format_image_tag_test -from aistore.sdk.etl_const import ETL_COMM_HPULL -from aistore.sdk.etl_templates import GO_ECHO +from aistore.sdk.etl.etl_const import ETL_COMM_HPULL +from aistore.sdk.etl.etl_templates import GO_ECHO +from aistore.sdk.etl import ETLConfig class TestGoEchoTransformer(TestBase): @@ -30,12 +31,12 @@ def test_go_echo(self): transformed_image_bytes = ( self.test_bck.object(self.test_image_filename) - .get(etl_name=self.test_etl.name) + .get(etl=ETLConfig(self.test_etl.name)) .read_all() ) transformed_text_bytes = ( self.test_bck.object(self.test_text_filename) - .get(etl_name=self.test_etl.name) + .get(etl=ETLConfig(self.test_etl.name)) .read_all() ) diff --git a/transformers/tests/test_md5.py b/transformers/tests/test_md5.py index ef133c3..766159a 100644 --- a/transformers/tests/test_md5.py +++ b/transformers/tests/test_md5.py @@ -5,8 +5,9 @@ import hashlib -from aistore.sdk.etl_const import ETL_COMM_HPULL, ETL_COMM_HPUSH, ETL_COMM_HREV -from aistore.sdk.etl_templates import MD5 +from aistore.sdk.etl.etl_const import ETL_COMM_HPULL, ETL_COMM_HPUSH, ETL_COMM_HREV +from aistore.sdk.etl.etl_templates import MD5 +from aistore.sdk.etl import ETLConfig from tests.utils import git_test_mode_format_image_tag_test from tests.base import TestBase @@ -29,7 +30,9 @@ def md5_hash_file(self, filepath): def compare_transformed_data_with_md5_hash(self, filename, original_filepath): transformed_data_bytes = ( - self.test_bck.object(filename).get(etl_name=self.test_etl.name).read_all() + self.test_bck.object(filename) + .get(etl=ETLConfig(self.test_etl.name)) + .read_all() ) original_file_hash = self.md5_hash_file(original_filepath) self.assertEqual(transformed_data_bytes.decode("utf-8"), original_file_hash)