Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add Layout Matching #55

Draft
wants to merge 5 commits into
base: features/penalty_rewards
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
72 changes: 72 additions & 0 deletions prompting/validators/criteria.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,9 @@
# DEALINGS IN THE SOFTWARE.
import re
import torch
import json
import random
import ast
from dataclasses import dataclass
from abc import ABC, abstractmethod
from typing import List
Expand Down Expand Up @@ -94,3 +97,72 @@ def evaluate(self, completions: List[str]) -> torch.FloatTensor:

def compose_text(self) -> str:
return self.text.format(target_length=self.target_length, unit=self.unit.value)

class LayoutTypeEnum(Enum):
JSON = "json"
DICTIONARY = "python dictionary"
NUMBEREDLIST = "numbered list"
BULLETPOINTLIST = "bullet point list"

def _select_random_attribute(self):
attribute_names = [attr for attr in vars(self) if not attr.startswith('_')]
random_attribute_name = random.choice(attribute_names)
random_attribute_value = getattr(self, random_attribute_name)
return random_attribute_value

@dataclass
class MatchLayoutCriteria(TaskCriterion):
text: str = "Your response must be in the form of a {format_type}{fields}"
penalty: float = 0.1
format_type: LayoutTypeEnum = LayoutTypeEnum.JSON
num_fields: int = 0
fields : str = " with {num_fields} fields"

def is_json(self, text):
try:
json.loads(text)
return True
except ValueError:
return False

def is_dictionary(self, text):
try:
if type(ast.literal_eval(text)) == dict:
return True
else:
return False
except (ValueError, SyntaxError):
return False

def is_numbered_list(self, input_string):
pattern = r'^\d.*\n?$'
lines = input_string.split('\n')
return all(re.match(pattern, line) for line in lines)

def is_bullet_point_list(self, input_string):
pattern = r'^\s*[-*+]\s*.*\n?(\s*[-*+]\s.*\n?)*$'
return bool(re.match(pattern, input_string))

def _get_format_match(self, response : str) -> bool:
if self.format_type == LayoutTypeEnum.JSON:
return self.is_json(response)
elif self.format_type == LayoutTypeEnum.DICTIONARY:
return self.is_dictionary(response)
elif self.format_type == LayoutTypeEnum.NUMBEREDLIST:
return self.is_numbered_list(response)
elif self.format_type == LayoutTypeEnum.BULLETPOINTLIST:
return self.is_bullet_point_list(response)
else:
return False

def evaluate(self, completions: list[str]) -> torch.FloatTensor:
penalties = torch.zeros(len(completions), dtype = torch.float32)
for idx, completion in enumerate(completions):
if not self._get_format_match(completion):
penalties[idx] = self.penalty
return penalties

def compose_text(self) -> str:
if self.num_fields == 0:
return self.text.format(format_type = self.format_type.value, fields = "")
return self.text.format(format_type = self.format_type.value, fields = self.fields)
8 changes: 7 additions & 1 deletion prompting/validators/tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,8 @@
TaskCriterion,
MatchLengthCriteria,
TextLengthUnitEnum,
MatchLayoutCriteria,
LayoutTypeEnum,
)


Expand Down Expand Up @@ -169,8 +171,12 @@ def create_qa_task(base_text: str, index: int) -> QuestionAnswerTask:
target_length=random.randint(4, 8),
unit=TextLengthUnitEnum.SENTENCES,
)
match_layout_criteria = MatchLayoutCriteria(
penalty = 0.1,
target_layout = LayoutTypeEnum._select_random_attribute(LayoutTypeEnum),
)

criteria = [match_words_criteria, match_length_criteria]
criteria = [match_words_criteria, match_length_criteria, match_layout_criteria]
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@steffencruz what's your thoughts on including the layout criteria in the 2.1.0 release? At first glance I would think that this could be challenging both to the miners and the current reward stack evaluation


return QuestionAnswerTask(
base_text=base_text,
Expand Down