
Commit f1d493f

Initial commit
Signed-off-by: pandyamarut <pandyamarut@gmail.com>
1 parent 7837407 commit f1d493f

File tree

2 files changed: +61 −34 lines


Dockerfile

+23 −20
@@ -1,29 +1,32 @@
-# Base image -> https://github.com/runpod/containers/blob/main/official-templates/base/Dockerfile
-# DockerHub -> https://hub.docker.com/r/runpod/base/tags
-FROM runpod/base:0.4.0-cuda11.8.0
+# Start with NVIDIA CUDA base image
+FROM nvidia/cuda:12.1.0-base-ubuntu22.04
 
-# The base image comes with many system dependencies pre-installed to help you get started quickly.
-# Please refer to the base image's Dockerfile for more information before adding additional dependencies.
-# IMPORTANT: The base image overrides the default huggingface cache location.
+# Avoid prompts from apt
+ENV DEBIAN_FRONTEND=noninteractive
 
+# Install system dependencies and Python
+RUN apt-get update -y && \
+    apt-get install -y python3-pip python3-dev git libopenmpi-dev && \
+    apt-get clean && \
+    rm -rf /var/lib/apt/lists/*
 
-# --- Optional: System dependencies ---
-# COPY builder/setup.sh /setup.sh
-# RUN /bin/bash /setup.sh && \
-#     rm /setup.sh
+# Clone TensorRT-LLM repository
+RUN git clone https://github.com/NVIDIA/TensorRT-LLM.git /app/TensorRT-LLM
 
+# Set working directory
+WORKDIR /app/TensorRT-LLM/examples/llm-api
 
-# Python dependencies
-COPY builder/requirements.txt /requirements.txt
-RUN python3.11 -m pip install --upgrade pip && \
-    python3.11 -m pip install --upgrade -r /requirements.txt --no-cache-dir && \
-    rm /requirements.txt
+# Install Python dependencies
+RUN pip3 install -r requirements.txt
 
-# NOTE: The base image comes with multiple Python versions pre-installed.
-# It is reccommended to specify the version of Python when running your code.
+# Install additional dependencies for the serverless worker
+RUN pip3 install runpod
 
+# Set the working directory to /app
+WORKDIR /app
 
-# Add src files (Worker Template)
-ADD src .
+# Copy the src directory containing handler.py
+COPY src /app/src
 
-CMD python3.11 -u /handler.py
+# Command to run the serverless worker
+CMD ["python3", "/app/src/handler.py"]

src/handler.py

+38 −14
@@ -1,18 +1,42 @@
-""" Example handler file. """
-
+import os
 import runpod
-
-# If your handler runs inference on a model, load the model here.
-# You will want models to be loaded into memory before starting serverless.
-
+from typing import List
+from tensorrt_llm import LLM, SamplingParams
+
+# Enable build caching
+os.environ["TLLM_HLAPI_BUILD_CACHE"] = "1"
+# Optionally, set a custom cache directory
+# os.environ["TLLM_HLAPI_BUILD_CACHE_ROOT"] = "/path/to/custom/cache"
+
+class TRTLLMWorker:
+    def __init__(self, model_path: str):
+        self.llm = LLM(model=model_path, enable_build_cache=True)
+
+    def generate(self, prompts: List[str], max_tokens: int = 100) -> List[str]:
+        sampling_params = SamplingParams(max_new_tokens=max_tokens)
+        outputs = self.llm.generate(prompts, sampling_params)
+
+        results = []
+        for output in outputs:
+            results.append(output.outputs[0].text)
+
+        return results
+
+# Initialize the worker outside the handler
+# This ensures the model is loaded only once when the serverless function starts
+worker = TRTLLMWorker("TinyLlama/TinyLlama-1.1B-Chat-v1.0")
 
 def handler(job):
-    """ Handler function that will be used to process jobs. """
+    """Handler function that will be used to process jobs."""
     job_input = job['input']
-
-    name = job_input.get('name', 'World')
-
-    return f"Hello, {name}!"
-
-
-runpod.serverless.start({"handler": handler})
+    prompts = job_input.get('prompts', ["Hello, how are you?"])
+    max_tokens = job_input.get('max_tokens', 100)
+
+    try:
+        results = worker.generate(prompts, max_tokens)
+        return {"status": "success", "output": results}
+    except Exception as e:
+        return {"status": "error", "message": str(e)}
+
+if __name__ == "__main__":
+    runpod.serverless.start({"handler": handler})
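The new handler reads 'prompts' (a list of strings) and 'max_tokens' from job['input'], so a quick local sanity check only needs a job dict with that shape. The sketch below is illustrative and not part of this commit: the sample_job payload and the direct handler() call are assumptions, and running it requires a GPU environment where tensorrt_llm is importable.

# Hypothetical local smoke test (not part of this commit).
# Assumes a GPU container where tensorrt_llm is installed and src/ is importable.
from src.handler import handler

# Job payload shaped like the input the handler expects:
# 'prompts' is a list of strings, 'max_tokens' an integer.
sample_job = {
    "input": {
        "prompts": ["Write one sentence about TensorRT-LLM."],
        "max_tokens": 64,
    }
}

result = handler(sample_job)
print(result)  # expected: {"status": "success", "output": [...]} or an error dict

Because worker = TRTLLMWorker(...) runs at import time, the engine is built (or reloaded via the TLLM_HLAPI_BUILD_CACHE) once per container start rather than once per job, so importing the module carries the model-loading cost up front, which is the intended trade-off for a serverless worker.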
