Skip to content

Commit

Permalink
refactor: add missing 'Content-Length' header for proper response han…
Browse files Browse the repository at this point in the history
…dling

Signed-off-by: Abhishek Gaikwad <gaikwadabhishek1997@gmail.com>
  • Loading branch information
gaikwadabhishek committed Feb 14, 2025
1 parent 5e5e6d7 commit 914de20
Show file tree
Hide file tree
Showing 6 changed files with 105 additions and 78 deletions.
82 changes: 49 additions & 33 deletions runtime/python/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,9 @@
code_file = os.getenv("MOD_NAME")
arg_type = os.getenv("ARG_TYPE", "bytes")

spec = importlib.util.spec_from_file_location(name="function", location=f"./code/{code_file}.py")
spec = importlib.util.spec_from_file_location(
name="function", location=f"./code/{code_file}.py"
)
mod = importlib.util.module_from_spec(spec)
spec.loader.exec_module(mod)

Expand Down Expand Up @@ -47,55 +49,69 @@ def read_all(self) -> bytes:

def __iter__(self) -> Iterator[bytes]:
while self._remaining_length > 0:
read_buffer = (
self._chunk_size
if self._remaining_length >= self._chunk_size
else self._remaining_length
)
read_buffer = min(self._chunk_size, self._remaining_length)
self._remaining_length -= read_buffer
yield self._rfile.read(read_buffer)


class Handler(BaseHTTPRequestHandler):
def log_request(self, *args):
# Don't log successful requests info. Unsuccessful logged by log_error().
# Don't log successful requests info; log errors only.
pass

def _set_headers(self):
def _set_headers(self, content_length=None):
"""Set response headers"""
self.send_response(200)
self.send_header("Content-Type", "application/octet-stream")
if content_length is not None:
self.send_header("Content-Length", str(content_length))
self.end_headers()

def do_PUT(self):
content_length = int(self.headers["Content-Length"])
reader = StreamWrapper(self.rfile, content_length, CHUNK_SIZE)
if CHUNK_SIZE == 0:
result = transform(reader.read_all())
"""Handles PUT requests by applying a transformation function to the request body"""
try:
content_length = int(self.headers.get("Content-Length", 0))
reader = StreamWrapper(self.rfile, content_length, CHUNK_SIZE)

if CHUNK_SIZE == 0:
result = transform(reader.read_all())
self._set_headers(content_length=len(result))
self.wfile.write(result)
return

# Streaming transform: writer is expected to write bytes to response
self._set_headers()
self.wfile.write(result)
return
transform(reader, self.wfile)

# TODO: validate if transform takes writer as input
# NOTE: for streaming transforms the writer is expected to write bytes into response as stream.
self._set_headers()
transform(reader, self.wfile)
except Exception as e:
self.send_error(500, f"Error processing PUT request: {e}")

def do_GET(self):
if self.path == "/health":
self._set_headers()
self.wfile.write(b"Running")
return

query_path = host_target + self.path

if arg_type == "url":
result = transform(query_path)
else:
input_bytes = requests.get(query_path).content
result = transform(input_bytes)

self._set_headers()
self.wfile.write(result)
"""Handles GET requests by fetching data, transforming it, and returning the response"""
try:
if self.path == "/health":
response = b"Running"
self._set_headers(content_length=len(response))
self.wfile.write(response)
return

query_path = host_target + self.path

if arg_type == "url":
result = transform(query_path)
else:
response = requests.get(query_path)
response.raise_for_status() # Raise an error if request failed
input_bytes = response.content
result = transform(input_bytes)

self._set_headers(content_length=len(result))
self.wfile.write(result)

except requests.exceptions.RequestException as e:
self.send_error(500, f"Failed to retrieve object: {e}")
except Exception as e:
self.send_error(500, f"Error processing GET request: {e}")


class ThreadedHTTPServer(ThreadingMixIn, HTTPServer):
Expand Down
35 changes: 20 additions & 15 deletions transformers/echo/main.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,15 @@
"""
A simple echo transformation using FastAPI framework and Gunicorn and Uvicorn webserver.
A simple echo transformation using FastAPI framework with Gunicorn and Uvicorn web server.
Steps to run:
Steps to run:
$ # with uvicorn
$ uvicorn main:app --reload
$ uvicorn main:app --reload
$ # with multiple uvicorn processes managed by gunicorn
$ gunicorn main:app --workers 4 --worker-class uvicorn.workers.UvicornWorker --bind 0.0.0.0:8000
$ gunicorn main:app --workers 4 --worker-class uvicorn.workers.UvicornWorker --bind 0.0.0.0:8000
Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
Copyright (c) 2023-2025, NVIDIA CORPORATION. All rights reserved.
"""

# pylint: disable=missing-class-docstring, missing-function-docstring, missing-module-docstring, broad-exception-caught
import os
import urllib.parse
Expand Down Expand Up @@ -59,18 +60,20 @@ async def get_handler(
fetches the object from the AIS target based on the destination/name,
transforms the bytes, and returns the modified bytes.
"""
# Get destination/name of object from URL or from full_path variable
# Fetch object from AIS target based on the destination/name
# Transform the bytes
# Return the transformed bytes
object_path = urllib.parse.quote(full_path, safe="@")
object_url = f"{host_target}/{object_path}"
resp = await client.get(object_url)
if not resp or resp.status != 200:
raise HTTPException(
status_code=500, detail="Error retreiving object ({full_path}) from target"
status_code=500, detail=f"Error retrieving object ({full_path}) from target"
)
return Response(content=await resp.read(), media_type="application/octet-stream")

content = await resp.read()
return Response(
content=content,
media_type="application/octet-stream",
headers={"Content-Length": str(len(content))}, # Set Content-Length
)


@app.put("/")
Expand All @@ -81,8 +84,10 @@ async def put_handler(request: Request):
Reads bytes from the request, performs byte transformation,
and returns the modified bytes.
"""
# Read bytes from request (request.body)
# Transform the bytes
# Return the transformed bytes
content = await request.body()

return Response(content=await request.body(), media_type="application/octet-stream")
return Response(
content=content,
media_type="application/octet-stream",
headers={"Content-Length": str(len(content))}, # Set Content-Length
)
35 changes: 17 additions & 18 deletions transformers/md5/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,42 +7,41 @@
from http.server import HTTPServer, BaseHTTPRequestHandler
from socketserver import ThreadingMixIn

host_target = os.environ['AIS_TARGET_URL']
host_target = os.environ["AIS_TARGET_URL"]


class Handler(BaseHTTPRequestHandler):
def log_request(self, code='-', size='-'):
# Don't log successful requests info. Unsuccessful logged by log_error().
pass
def log_request(self, code="-", size="-"):
pass # Disable request logging

def _set_headers(self):
def _set_headers(self, content_length):
self.send_response(200)
self.send_header("Content-Type", "text/plain")
self.send_header("Content-Length", str(content_length))
self.end_headers()

def do_PUT(self):
content_length = int(self.headers['Content-Length'])
content_length = int(self.headers["Content-Length"])
post_data = self.rfile.read(content_length)
md5 = hashlib.md5()
md5.update(post_data)
self._set_headers()
self.wfile.write(md5.hexdigest().encode())
digest = hashlib.md5(post_data).hexdigest().encode()
self._set_headers(len(digest))
self.wfile.write(digest)

def do_GET(self):
if self.path == "/health":
self._set_headers()
self.wfile.write(b"Running")
response = b"Running"
self._set_headers(len(response))
self.wfile.write(response)
return

x = requests.get(host_target + self.path)
md5 = hashlib.md5()
md5.update(x.content)
self._set_headers()
self.wfile.write(md5.hexdigest().encode())
content = requests.get(host_target + self.path).content
digest = hashlib.md5(content).hexdigest().encode()
self._set_headers(len(digest))
self.wfile.write(digest)


class ThreadedHTTPServer(ThreadingMixIn, HTTPServer):
"""Handle requests in a separate thread."""
"""Handle requests in separate threads."""


def run(addr="localhost", port=8000):
Expand Down
11 changes: 7 additions & 4 deletions transformers/tests/test_echo.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
#
# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
# Copyright (c) 2023-2025, NVIDIA CORPORATION. All rights reserved.
#

# pylint: disable=missing-class-docstring, missing-function-docstring, missing-module-docstring

from aistore.sdk.etl_const import ETL_COMM_HPULL, ETL_COMM_HPUSH, ETL_COMM_HREV
from aistore.sdk.etl_templates import ECHO
from aistore.sdk.etl.etl_const import ETL_COMM_HPULL, ETL_COMM_HPUSH, ETL_COMM_HREV
from aistore.sdk.etl.etl_templates import ECHO
from aistore.sdk.etl import ETLConfig

from tests.base import TestBase
from tests.utils import git_test_mode_format_image_tag_test
Expand Down Expand Up @@ -33,7 +34,9 @@ def initialize_template(self, communication_type: str):

def compare_transformed_data(self, filename: str, source: str):
transformed_bytes = (
self.test_bck.object(filename).get(etl_name=self.test_etl.name).read_all()
self.test_bck.object(filename)
.get(etl=ETLConfig(self.test_etl.name))
.read_all()
)

with open(source, "rb") as file:
Expand Down
11 changes: 6 additions & 5 deletions transformers/tests/test_go_echo.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,14 @@
#
# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
# Copyright (c) 2023-2025, NVIDIA CORPORATION. All rights reserved.
#
# pylint: disable=missing-class-docstring, missing-function-docstring, missing-module-docstring

from tests.base import TestBase
from tests.utils import git_test_mode_format_image_tag_test

from aistore.sdk.etl_const import ETL_COMM_HPULL
from aistore.sdk.etl_templates import GO_ECHO
from aistore.sdk.etl.etl_const import ETL_COMM_HPULL
from aistore.sdk.etl.etl_templates import GO_ECHO
from aistore.sdk.etl import ETLConfig


class TestGoEchoTransformer(TestBase):
Expand All @@ -30,12 +31,12 @@ def test_go_echo(self):

transformed_image_bytes = (
self.test_bck.object(self.test_image_filename)
.get(etl_name=self.test_etl.name)
.get(etl=ETLConfig(self.test_etl.name))
.read_all()
)
transformed_text_bytes = (
self.test_bck.object(self.test_text_filename)
.get(etl_name=self.test_etl.name)
.get(etl=ETLConfig(self.test_etl.name))
.read_all()
)

Expand Down
9 changes: 6 additions & 3 deletions transformers/tests/test_md5.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,9 @@

import hashlib

from aistore.sdk.etl_const import ETL_COMM_HPULL, ETL_COMM_HPUSH, ETL_COMM_HREV
from aistore.sdk.etl_templates import MD5
from aistore.sdk.etl.etl_const import ETL_COMM_HPULL, ETL_COMM_HPUSH, ETL_COMM_HREV
from aistore.sdk.etl.etl_templates import MD5
from aistore.sdk.etl import ETLConfig

from tests.utils import git_test_mode_format_image_tag_test
from tests.base import TestBase
Expand All @@ -29,7 +30,9 @@ def md5_hash_file(self, filepath):

def compare_transformed_data_with_md5_hash(self, filename, original_filepath):
transformed_data_bytes = (
self.test_bck.object(filename).get(etl_name=self.test_etl.name).read_all()
self.test_bck.object(filename)
.get(etl=ETLConfig(self.test_etl.name))
.read_all()
)
original_file_hash = self.md5_hash_file(original_filepath)
self.assertEqual(transformed_data_bytes.decode("utf-8"), original_file_hash)
Expand Down

0 comments on commit 914de20

Please sign in to comment.