diff --git a/outlines/fsm/guide.py b/outlines/fsm/guide.py
index 2e4415148..1099d5c2b 100644
--- a/outlines/fsm/guide.py
+++ b/outlines/fsm/guide.py
@@ -182,10 +182,38 @@ def __init__(self, regex_string: str, tokenizer: "Tokenizer"):
             self.empty_token_ids,
             fsm_finals,
         ) = create_states_mapping(regex_string, tokenizer)
+
+        # token alignment
+        crossing_tokens = find_crossing_tokens(self.states_to_token_maps, tokenizer.vocabulary)
+        highest_state = max(
+            max(self.states_to_token_maps.keys()),
+            max(max(items.values()) for items in self.states_to_token_maps.values()),
+        )
+        prefixes_map = create_prefixes_map({x[1] for x in crossing_tokens}, highest_state + 1, tokenizer.vocabulary)
+        for item in crossing_tokens:
+            prefix_map = prefixes_map[item[1]]
+            self.states_to_token_maps.update(prefix_map)
+            prefix_map_starting_state = min(prefix_map.keys())
+            self.states_to_token_maps[prefix_map_starting_state][item[0]] = item[3]
+        self.crossing_tokens_prefixes_map = {key: min(value.keys()) for key, value in prefixes_map.items()}
+
         self.eos_token_id = tokenizer.eos_token_id
         self.final_states = fsm_finals | {-1}
         self._cache_state_to_token_tensor()
 
+    def get_starting_states(self, prompts: List[str]) -> List[Tuple[int, str]]:
+        """Get the starting state and the character sequence that should be removed from each prompt."""
+        results = []
+        for prompt in prompts:
+            longest_prefix = ""
+            target_state = self.initial_state
+            for prefix, starting_state in self.crossing_tokens_prefixes_map.items():
+                if prompt.endswith(prefix) and len(prefix) > len(longest_prefix):
+                    longest_prefix = prefix
+                    target_state = starting_state
+            results.append((target_state, longest_prefix))
+        return results
+
     def get_next_instruction(self, state: int) -> Instruction:
         """Return the next instruction for guided generation.
 
@@ -475,3 +503,78 @@ def is_final_state(self, state: int) -> bool:
     def copy(self) -> "CFGGuide":
         """Create a copy of the FSM."""
         return CFGGuide(self.cfg_string, self.tokenizer)
+
+
+### token alignment functions ###
+
+def find_crossing_tokens(states_to_token_maps: dict, vocabulary: dict) -> List[Tuple[int, str, str, int]]:
+    """Find the crossing tokens for a given states_to_token_maps.
+    Crossing tokens are tokens that can be decomposed into a prefix and a postfix,
+    such that the postfix is a valid sequence of characters for the states_to_token_maps.
+    Returns a list of tuples, where each tuple contains the token id, the prefix, the postfix and the target state.
+    """
+
+    def get_target_state(vocabulary: dict, states_to_token_map: dict, char_seq: str):
+        """Get the target state in the states_to_token_map for a sequence of characters.
+        Return None if the sequence is not valid.
+ """ + state = 0 + for char in char_seq: + char_token = vocabulary.get(char) + try: + state = states_to_token_map[state][char_token] + except KeyError: + return None + return state + + crossing_tokens = [] + invalid_postfixes = set() + valid_postfixes = {} + + for char_seq, token_id in vocabulary.items(): + if len(char_seq) == 1: + continue + # we want to look at all possible "crossing positions" of the token (between char 1 and 2, 2 and 3, etc) + for i in range(1, len(char_seq)): + prefix = char_seq[:i] + postfix = char_seq[i:] + if postfix in invalid_postfixes: + continue + if postfix in valid_postfixes.keys(): + crossing_tokens.append([token_id, prefix, postfix, valid_postfixes[postfix]]) + continue + target_state = get_target_state(vocabulary, states_to_token_maps, postfix) + if target_state is None: + invalid_postfixes.add(postfix) + else: + valid_postfixes[postfix] = target_state + crossing_tokens.append([token_id, prefix, postfix, target_state]) + + return crossing_tokens + + +def create_prefixes_map(prefixes: List[str], starting_state: int, vocabulary: dict) -> dict: + """Create a state to token map for each prefix. + The starting state is the first available state number in the existing FSM. + Return a dictionary where each key is a prefix and the value is the associated states_to_token_map. + """ + + def get_states_to_token_map(char_seq: str, starting_state: int, states_to_token_map: dict, vocabulary: dict): + """Create the states_to_token_map representing all ways of generating the sequence of characters.""" + for i in range(1, len(char_seq) + 1): + if char_seq[:i] in vocabulary.keys(): + if starting_state not in states_to_token_map: + states_to_token_map[starting_state] = {} + if i == len(char_seq): + states_to_token_map[starting_state][vocabulary[char_seq[:i]]] = 0 + else: + states_to_token_map[starting_state][vocabulary[char_seq[:i]]] = starting_state + i + get_states_to_token_map(char_seq[i:], starting_state + i, states_to_token_map, vocabulary) + return states_to_token_map + + prefixes_map = {} + for prefix in prefixes: + prefixes_map[prefix] = get_states_to_token_map(prefix, starting_state, {}, vocabulary) + starting_state += len(prefix) + + return prefixes_map diff --git a/outlines/generate/api.py b/outlines/generate/api.py index 4104e3080..fd4bd8f91 100644 --- a/outlines/generate/api.py +++ b/outlines/generate/api.py @@ -500,6 +500,8 @@ def format(sequences): max_tokens, stop_at, seed ) + removed_chars_from_prompts = self.logits_processor.get_removed_chars_from_prompts(prompts) + completions = self.model.generate( prompts, generation_params, @@ -508,7 +510,13 @@ def format(sequences): **model_specific_params, ) - return format(completions) + trimmed_completions = [ + completion[len(removed_chars):] + for completion, removed_chars in zip(completions, removed_chars_from_prompts) + ] + + return format(trimmed_completions) + def stream( self, diff --git a/outlines/processors/structured.py b/outlines/processors/structured.py index d037c679f..d66c83682 100644 --- a/outlines/processors/structured.py +++ b/outlines/processors/structured.py @@ -61,10 +61,20 @@ def __init__(self, tokenizer: "Tokenizer", fsm: Guide): The finite state machine which is used to bias the logits. 
""" self.tokenizer = tokenizer - self._fsm_states: Dict[int, int] = {} + self._fsm_states: List[Dict[int, int]] = [] self.fsm: Guide = fsm self._is_first_token = True self._seq_start_idx: Optional[int] = None + self.default_starting_state = 0 + self.token_alignment_starting_states = [] + + def get_removed_chars_from_prompts(self, prompts: List[str]) -> List[str]: + """For each prompt, get the postfix to be removed and the resulting starting state. + Update the token_alignment_starting_states attribute and return the postfixes to be removed. + """ + starting_states_and_prefixes = self.fsm.get_starting_states(prompts) + self.token_alignment_starting_states = [starting_state for starting_state, _ in starting_states_and_prefixes] + return [prefix for _, prefix in starting_states_and_prefixes] def process_logits( self, input_ids: List[List[int]], logits: torch.Tensor @@ -89,18 +99,18 @@ def process_logits( self._is_first_token = False self._seq_start_idx = len(input_ids[0]) - self._fsm_states = {hash(tuple([])): 0} - sequence_states = [0] * len(input_ids) + sequence_states = self.token_alignment_starting_states if self.token_alignment_starting_states else [self.default_starting_state] * len(input_ids) + self._fsm_states = [{hash(tuple([])): sequence_states[i]} for i in range(len(input_ids))] else: - for seq_ids in input_ids: + for i, seq_ids in enumerate(input_ids): prev_state_key = hash(tuple(seq_ids[self._seq_start_idx : -1])) - prev_state = self._fsm_states[prev_state_key] + prev_state = self._fsm_states[i][prev_state_key] curr_state_key = hash(tuple(seq_ids[self._seq_start_idx :])) curr_state = self.fsm.get_next_state(prev_state, seq_ids[-1]) - self._fsm_states[curr_state_key] = curr_state + self._fsm_states[i][curr_state_key] = curr_state sequence_states.append(curr_state) mask = torch.full_like(logits, -math.inf)