diff --git a/prompting/validators/criteria.py b/prompting/validators/criteria.py index 23afbc3..030c243 100644 --- a/prompting/validators/criteria.py +++ b/prompting/validators/criteria.py @@ -17,6 +17,9 @@ # DEALINGS IN THE SOFTWARE. import re import torch +import json +import random +import ast from dataclasses import dataclass from abc import ABC, abstractmethod from typing import List @@ -94,3 +97,72 @@ def evaluate(self, completions: List[str]) -> torch.FloatTensor: def compose_text(self) -> str: return self.text.format(target_length=self.target_length, unit=self.unit.value) + +class LayoutTypeEnum(Enum): + JSON = "json" + DICTIONARY = "python dictionary" + NUMBEREDLIST = "numbered list" + BULLETPOINTLIST = "bullet point list" + + def _select_random_attribute(self): + attribute_names = [attr for attr in vars(self) if not attr.startswith('_')] + random_attribute_name = random.choice(attribute_names) + random_attribute_value = getattr(self, random_attribute_name) + return random_attribute_value + +@dataclass +class MatchLayoutCriteria(TaskCriterion): + text: str = "Your response must be in the form of a {format_type}{fields}" + penalty: float = 0.1 + format_type: LayoutTypeEnum = LayoutTypeEnum.JSON + num_fields: int = 0 + fields : str = " with {num_fields} fields" + + def is_json(self, text): + try: + json.loads(text) + return True + except ValueError: + return False + + def is_dictionary(self, text): + try: + if type(ast.literal_eval(text)) == dict: + return True + else: + return False + except (ValueError, SyntaxError): + return False + + def is_numbered_list(self, input_string): + pattern = r'^\d.*\n?$' + lines = input_string.split('\n') + return all(re.match(pattern, line) for line in lines) + + def is_bullet_point_list(self, input_string): + pattern = r'^\s*[-*+]\s*.*\n?(\s*[-*+]\s.*\n?)*$' + return bool(re.match(pattern, input_string)) + + def _get_format_match(self, response : str) -> bool: + if self.format_type == LayoutTypeEnum.JSON: + return self.is_json(response) + elif self.format_type == LayoutTypeEnum.DICTIONARY: + return self.is_dictionary(response) + elif self.format_type == LayoutTypeEnum.NUMBEREDLIST: + return self.is_numbered_list(response) + elif self.format_type == LayoutTypeEnum.BULLETPOINTLIST: + return self.is_bullet_point_list(response) + else: + return False + + def evaluate(self, completions: list[str]) -> torch.FloatTensor: + penalties = torch.zeros(len(completions), dtype = torch.float32) + for idx, completion in enumerate(completions): + if not self._get_format_match(completion): + penalties[idx] = self.penalty + return penalties + + def compose_text(self) -> str: + if self.num_fields == 0: + return self.text.format(format_type = self.format_type.value, fields = "") + return self.text.format(format_type = self.format_type.value, fields = self.fields) diff --git a/prompting/validators/tasks.py b/prompting/validators/tasks.py index 2799169..d37d91c 100644 --- a/prompting/validators/tasks.py +++ b/prompting/validators/tasks.py @@ -25,6 +25,8 @@ TaskCriterion, MatchLengthCriteria, TextLengthUnitEnum, + MatchLayoutCriteria, + LayoutTypeEnum, ) @@ -169,8 +171,12 @@ def create_qa_task(base_text: str, index: int) -> QuestionAnswerTask: target_length=random.randint(4, 8), unit=TextLengthUnitEnum.SENTENCES, ) + match_layout_criteria = MatchLayoutCriteria( + penalty = 0.1, + target_layout = LayoutTypeEnum._select_random_attribute(LayoutTypeEnum), + ) - criteria = [match_words_criteria, match_length_criteria] + criteria = [match_words_criteria, match_length_criteria, match_layout_criteria] return QuestionAnswerTask( base_text=base_text,