Commit 66d5f97

Merge pull request #2 from runpod-workers/up-init
Up-init
2 parents 7837407 + d609499

File tree: Dockerfile, src/handler.py

2 files changed: +85 -30 lines changed

Diff for: Dockerfile

+23 -20

@@ -1,29 +1,32 @@
-# Base image -> https://github.com/runpod/containers/blob/main/official-templates/base/Dockerfile
-# DockerHub -> https://hub.docker.com/r/runpod/base/tags
-FROM runpod/base:0.4.0-cuda11.8.0
+# Start with NVIDIA CUDA base image
+FROM nvidia/cuda:12.1.0-base-ubuntu22.04
 
-# The base image comes with many system dependencies pre-installed to help you get started quickly.
-# Please refer to the base image's Dockerfile for more information before adding additional dependencies.
-# IMPORTANT: The base image overrides the default huggingface cache location.
+# Avoid prompts from apt
+ENV DEBIAN_FRONTEND=noninteractive
 
+# Install system dependencies and Python
+RUN apt-get update -y && \
+    apt-get install -y python3-pip python3-dev git libopenmpi-dev && \
+    apt-get clean && \
+    rm -rf /var/lib/apt/lists/*
 
-# --- Optional: System dependencies ---
-# COPY builder/setup.sh /setup.sh
-# RUN /bin/bash /setup.sh && \
-#     rm /setup.sh
+# Clone TensorRT-LLM repository
+RUN git clone https://github.com/NVIDIA/TensorRT-LLM.git /app/TensorRT-LLM
 
+# Set working directory
+WORKDIR /app/TensorRT-LLM/examples/llm-api
 
-# Python dependencies
-COPY builder/requirements.txt /requirements.txt
-RUN python3.11 -m pip install --upgrade pip && \
-    python3.11 -m pip install --upgrade -r /requirements.txt --no-cache-dir && \
-    rm /requirements.txt
+# Install Python dependencies
+RUN pip3 install -r requirements.txt
 
-# NOTE: The base image comes with multiple Python versions pre-installed.
-# It is reccommended to specify the version of Python when running your code.
+# Install additional dependencies for the serverless worker
+RUN pip3 install --upgrade runpod transformers
 
+# Set the working directory to /app
+WORKDIR /app
 
-# Add src files (Worker Template)
-ADD src .
+# Copy the src directory containing handler.py
+COPY src /app/src
 
-CMD python3.11 -u /handler.py
+# Command to run the serverless worker
+CMD ["python3", "/app/src/handler.py"]

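The rebuilt image no longer inherits RunPod's worker defaults from runpod/base, so the two environment variables read by src/handler.py (MODEL_PATH and HF_TOKEN, see the diff below) must be supplied when the container starts. What follows is a minimal, hypothetical smoke-test sketch using the docker Python SDK (docker-py), assuming Docker and the NVIDIA container toolkit are available locally; the image tag and model id are placeholders, not part of this commit.

# Hypothetical local smoke test: build the image from this Dockerfile and start a
# container with the environment variables that src/handler.py reads at startup.
# Requires `pip install docker`, a running Docker daemon, and GPU passthrough.
import os

import docker

client = docker.from_env()

# Build the image from the repository root (where the Dockerfile lives).
image, _build_logs = client.images.build(path=".", tag="worker-tensorrt-llm:dev")

# Start the worker with GPU access and the env vars the handler expects.
container = client.containers.run(
    image.id,
    detach=True,
    environment={
        "MODEL_PATH": "meta-llama/Meta-Llama-3.1-8B-Instruct",  # placeholder HF model id
        "HF_TOKEN": os.environ.get("HF_TOKEN", ""),             # token for model downloads
    },
    device_requests=[docker.types.DeviceRequest(count=-1, capabilities=[["gpu"]])],
)
print("started container:", container.short_id)
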
Diff for: src/handler.py

+62 -10

@@ -1,18 +1,70 @@
-""" Example handler file. """
-
+import os
 import runpod
+from typing import List, AsyncGenerator, Dict, Union
+from tensorrt_llm import LLM, SamplingParams
+from huggingface_hub import login
+from tensorrt_llm.hlapi import BuildConfig, KvCacheConfig
 
-# If your handler runs inference on a model, load the model here.
-# You will want models to be loaded into memory before starting serverless.
+# Enable build caching
+os.environ["TLLM_HLAPI_BUILD_CACHE"] = "1"
+# Optionally, set a custom cache directory
+# os.environ["TLLM_HLAPI_BUILD_CACHE_ROOT"] = "/path/to/custom/cache"
+# HF_TOKEN is required for downloading models
 
 
-def handler(job):
-    """ Handler function that will be used to process jobs. """
-    job_input = job['input']
 
-    name = job_input.get('name', 'World')
+hf_token = os.environ["HF_TOKEN"]
+login(token=hf_token)
+
+
 
-    return f"Hello, {name}!"
+class TRTLLMWorker:
+    def __init__(self, model_path: str):
+        self.llm = LLM(model=model_path, enable_build_cache=True, kv_cache_config=KvCacheConfig(), build_config=BuildConfig())
+
+    def generate(self, prompts: List[str], max_tokens: int = 100) -> List[str]:
+        sampling_params = SamplingParams(max_new_tokens=max_tokens)
+        outputs = self.llm.generate(prompts, sampling_params)
+
+        results = []
+        for output in outputs:
+            results.append(output.outputs[0].text)
+        return results
 
+
+    async def generate_async(self, prompts: List[str], max_tokens: int = 100) -> AsyncGenerator[str, None]:
+        sampling_params = SamplingParams(max_new_tokens=max_tokens)
+
+        async for output in self.llm.generate_async(prompts, sampling_params):
+            for request_output in output.outputs:
+                if request_output.text:
+                    yield request_output.text
+
+# Initialize the worker outside the handler.
+# This ensures the model is loaded only once when the serverless function starts.
+# MODEL_PATH is a Hugging Face model id "<org_name>/<model_name>", e.g. meta-llama/Meta-Llama-3.1-8B-Instruct
+model_path = os.environ["MODEL_PATH"]
+worker = TRTLLMWorker(model_path)
+
+
+async def handler(job: Dict) -> AsyncGenerator[Dict[str, Union[str, List[str]]], None]:
+    """Handler function that will be used to process jobs."""
+    job_input = job['input']
+    prompts = job_input.get('prompts', ["Hello, how are you?"])
+    max_tokens = job_input.get('max_tokens', 100)
+    streaming = job_input.get('streaming', False)
+
+    try:
+        if streaming:
+            results = []
+            async for chunk in worker.generate_async(prompts, max_tokens):
+                results.append(chunk)
+                yield {"status": "streaming", "chunk": chunk}
+            yield {"status": "success", "output": results}
+        else:
+            results = worker.generate(prompts, max_tokens)
+            yield {"status": "success", "output": results}
+    except Exception as e:
+        yield {"status": "error", "message": str(e)}
 
-runpod.serverless.start({"handler": handler})
+runpod.serverless.start({"handler": handler, "return_aggregate_stream": True})

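For reference, once this worker is deployed as a RunPod serverless endpoint, the handler expects an "input" object with a "prompts" list plus optional "max_tokens" and "streaming" fields; because return_aggregate_stream is enabled, a synchronous run returns the aggregated list of dicts yielded by the handler. Below is a rough client sketch using requests; the endpoint id and API key are placeholders, and the URL pattern assumes RunPod's public serverless REST API.

# Hypothetical client call against a deployed endpoint; ENDPOINT_ID and RUNPOD_API_KEY
# are placeholders. The payload mirrors the fields handler.py reads from job["input"].
import os

import requests

ENDPOINT_ID = "your-endpoint-id"  # placeholder
url = f"https://api.runpod.ai/v2/{ENDPOINT_ID}/runsync"
headers = {"Authorization": f"Bearer {os.environ['RUNPOD_API_KEY']}"}

payload = {
    "input": {
        "prompts": ["Explain KV cache reuse in one sentence."],
        "max_tokens": 128,
        "streaming": False,
    }
}

resp = requests.post(url, json=payload, headers=headers, timeout=300)
resp.raise_for_status()
# With return_aggregate_stream=True, the generator handler's yields are aggregated,
# so "output" is expected to be a list of status dicts like {"status": "success", ...}.
print(resp.json().get("output"))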