Skip to content

Commit 333e346

Browse files
committed
Merge branch 'develop' into cache-per-step
2 parents 95f7618 + d5c0484 commit 333e346

File tree

3 files changed

+14
-6
lines changed

3 files changed

+14
-6
lines changed

pyproject.toml

+2-2
Original file line numberDiff line numberDiff line change
@@ -84,7 +84,7 @@ llama-cpp = ["llama-cpp-python >= 0.2.0"]
8484
mistralai = ["mistralai >= 1.0.0"]
8585
ollama = ["ollama >= 0.1.7"]
8686
openai = ["openai >= 1.0.0"]
87-
outlines = ["outlines >= 0.0.40"]
87+
outlines = ["outlines >= 0.0.40", "numba >= 0.54.0"]
8888
ray = ["ray[default] >= 2.31.0"]
8989
vertexai = ["google-cloud-aiplatform >= 1.38.0"]
9090
vllm = [
@@ -99,7 +99,7 @@ faiss-gpu = ["faiss-gpu >= 1.7.2"]
9999
text-clustering = [
100100
"umap-learn >= 0.5.6",
101101
"scikit-learn >= 1.4.1",
102-
"matplotlib >= 3.8.3" # For the figure (even though it's optional)
102+
"matplotlib >= 3.8.3", # For the figure (even though it's optional)
103103
]
104104

105105
# minhash

scripts/install_dependencies.sh

+1-2
Original file line numberDiff line numberDiff line change
@@ -9,10 +9,9 @@ python -m pip install uv
99
uv pip install --system -e ".[anthropic,argilla,cohere,groq,hf-inference-endpoints,hf-transformers,litellm,llama-cpp,ollama,openai,outlines,vertexai,mistralai,instructor,sentence-transformers,faiss-cpu,minhash,text-clustering]"
1010

1111
if [ "${python_version}" != "(3, 12)" ]; then
12-
uv pip install --system -e .[ray]
12+
uv pip install --system -e .[ray]
1313
fi
1414

1515
./scripts/install_cpu_vllm.sh
16-
uv pip install --system git+https://github.com/argilla-io/LLM-Blender.git
1716

1817
uv pip install --system -e ".[dev,tests]"

src/distilabel/steps/tasks/argilla_labeller.py

+11-2
Original file line numberDiff line numberDiff line change
@@ -382,10 +382,13 @@ def format_input(
382382
"""Format the input into a chat message.
383383
384384
Args:
385-
input (Dict[str, Union[Dict[str, Any], Record, TextField, MultiLabelQuestion, LabelQuestion, RatingQuestion, TextQuestion]]): The input to format.
385+
input: The input to format.
386386
387387
Returns:
388-
ChatType: The formatted chat message.
388+
The formatted chat message.
389+
390+
Raises:
391+
ValueError: If question or fields are not provided.
389392
"""
390393
input_keys = list(self.inputs.keys())
391394
record = input[input_keys[0]]
@@ -394,6 +397,11 @@ def format_input(
394397
examples = input.get(input_keys[3], self.example_records)
395398
guidelines = input.get(input_keys[4], self.guidelines)
396399

400+
if question is None:
401+
raise ValueError("Question must be provided.")
402+
if fields is None or any(field is None for field in fields):
403+
raise ValueError("Fields must be provided.")
404+
397405
record = record.to_dict() if not isinstance(record, dict) else record
398406
question = question.serialize() if not isinstance(question, dict) else question
399407
fields = [
@@ -416,6 +424,7 @@ def format_input(
416424
if examples
417425
else False
418426
)
427+
419428
prompt = self._template.render(
420429
fields=formatted_fields,
421430
question=formatted_question,

0 commit comments

Comments
 (0)