@@ -1,16 +1,16 @@
 import datetime
 from dataclasses import dataclass
-from typing import TYPE_CHECKING, Iterator, List, Optional, Union
+from typing import Iterator, List, Optional, Sequence, Union
+
+import torch

 from outlines.generate.generator import sequence_generator
 from outlines.samplers import BeamSearchSampler, GreedySampler, MultinomialSampler

-if TYPE_CHECKING:
-    import torch
-
 FormattedOutput = Union[
     str, int, float, bool, datetime.date, datetime.time, datetime.datetime
 ]
+TotalCompletionsType = Optional[Union[List[str], str]]


 class SequenceGenerator:
@@ -461,6 +461,47 @@ def prepare_generation_parameters(

         return generation_params

+    def strip_completions(
+        self,
+        completions,
+        prompts: Union[str, List[str]],
+        aligned_prompts: Union[str, List[str]],
+    ):
+        """Remove characters generated through token alignment from the completions.
+
+        Token alignment makes the model re-generate some of the characters at
+        the end of the prompt, so we strip those from the beginning of the
+        completions and return only the text that follows the user prompts.
+
+        Parameters
+        ----------
+        completions
+            Text generated by the model
+        prompts
+            The original prompts provided by the user
+        aligned_prompts
+            The user prompts after token alignment (what is given to the model)
+
+        Returns
+        -------
+        The stripped completions
+        """
+        if isinstance(prompts, str):
+            if isinstance(completions, str):
+                return completions[len(prompts) - len(aligned_prompts) :]
+
+            return [
+                self.strip_completions(completion, prompts, aligned_prompts)
+                for completion in completions
+            ]
+
+        return [
+            self.strip_completions(completion, prompt, aligned_prompt)
+            for completion, prompt, aligned_prompt in zip(
+                completions, prompts, aligned_prompts
+            )
+        ]
+
     def format_sequence(self, sequence: str) -> FormattedOutput:
         """Translate the generated sequence to another type.

@@ -485,6 +526,7 @@ def __call__(
         max_tokens: Optional[int] = None,
         stop_at: Optional[Union[str, List[str]]] = None,
         seed: Optional[int] = None,
+        token_healing_enabled: bool = True,
         **model_specific_params,
     ):
         """Generate text from a prompt or a list of prompts."""
@@ -500,32 +542,106 @@ def format(sequences):
             max_tokens, stop_at, seed
         )

+        # If token healing is disabled or unavailable for the type of FSM used
+        # by the processor, the aligned prompts are just the original prompts.
+        aligned_prompts = self.logits_processor.setup_processor(
+            prompts, token_healing_enabled
+        )
+
         completions = self.model.generate(
-            prompts,
+            aligned_prompts,
             generation_params,
             self.logits_processor,
             self.sampling_params,
             **model_specific_params,
         )

-        return format(completions)
+        stripped_completions = self.strip_completions(
+            completions, prompts, aligned_prompts
+        )
+
+        return format(stripped_completions)

     def stream(
         self,
         prompts: Union[str, List[str]],
         max_tokens: Optional[int] = None,
         stop_at: Optional[Union[str, List[str]]] = None,
         seed: Optional[int] = None,
+        token_healing_enabled: bool = True,
         **model_specific_params,
     ):
         """Return a text generator from a prompt or a list of prompts."""
+
+        def add_chunks_to_completions(
+            text_chunks: Union[str, List[str], List[List[str]], Sequence[str]],
+            total_completions: Optional[
+                Union[str, List[str], List[List[str]], Sequence[str]]
+            ],
+        ):
+            """Append each text chunk to the end of the corresponding completion."""
+            if isinstance(text_chunks, str):
+                if isinstance(total_completions, str):
+                    return total_completions + text_chunks
+                return text_chunks
+
+            if total_completions:
+                return [
+                    add_chunks_to_completions(text_chunk, total_completion)
+                    for text_chunk, total_completion in zip(
+                        text_chunks, total_completions
+                    )
+                ]
+
+            return [
+                add_chunks_to_completions(text_chunk, None)
+                for text_chunk in text_chunks
+            ]
+
+        def strip_text_chunks(
+            text_chunks: Union[str, List[str], List[List[str]], Sequence[str]],
+            stripped_completions: Union[str, List[str], List[List[str]], Sequence[str]],
+        ):
+            """Get the stripped text chunks from the stripped completions."""
+            if isinstance(text_chunks, str):
+                return (
+                    stripped_completions[-len(text_chunks) :]
+                    if len(text_chunks) > 0
+                    else ""
+                )
+
+            return [
+                strip_text_chunks(text_chunk, stripped_completion)
+                for text_chunk, stripped_completion in zip(
+                    text_chunks, stripped_completions
+                )
+            ]
+
         generation_params = self.prepare_generation_parameters(
             max_tokens, stop_at, seed
         )
-        return self.model.stream(
-            prompts,
+
+        # If token healing is disabled or unavailable for the type of FSM used
+        # by the processor, the aligned prompts are just the original prompts.
+        aligned_prompts = self.logits_processor.setup_processor(
+            prompts, token_healing_enabled
+        )
+
+        total_completions: TotalCompletionsType = None
+
+        for text_chunks in self.model.stream(
+            aligned_prompts,
             generation_params,
             self.logits_processor,
             self.sampling_params,
             **model_specific_params,
-        )
+        ):
+            total_completions = add_chunks_to_completions(
+                text_chunks, total_completions
+            )
+
+            stripped_completions = self.strip_completions(
+                total_completions, prompts, aligned_prompts
+            )
+
+            yield strip_text_chunks(text_chunks, stripped_completions)
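
The stripping arithmetic above relies on the aligned prompt being a character-level prefix of the original prompt: token alignment truncates the prompt at a token boundary and lets the model re-generate the tail, so the first `len(prompts) - len(aligned_prompts)` characters of each completion duplicate the end of the user's prompt. A minimal standalone sketch of that single-prompt case (hypothetical names, not the library's API):

```python
def strip_completion(completion: str, prompt: str, aligned_prompt: str) -> str:
    # Offset = number of prompt characters the model re-generated.
    return completion[len(prompt) - len(aligned_prompt) :]


prompt = "The weather is su"        # what the user sent
aligned_prompt = "The weather is"   # what the model actually saw
completion = " sunny today."        # the model first re-generates " su"

stripped = strip_completion(completion, prompt, aligned_prompt)
assert stripped == "nny today."
# The user's prompt plus the stripped completion reconstructs the full text:
assert prompt + stripped == "The weather is sunny today."
```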
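The streaming path applies the same idea incrementally, which is what `add_chunks_to_completions` and `strip_text_chunks` do for nested batch/beam outputs: chunks are accumulated into a running completion, the running total is stripped once, and only the tail of each chunk that survives stripping is yielded, so alignment characters never reach the caller. A sketch of that flow for a single prompt, under the same prefix assumption (hypothetical helper, not the library's API):

```python
from typing import Iterator, List


def stream_stripped(chunks: List[str], prompt: str, aligned_prompt: str) -> Iterator[str]:
    offset = len(prompt) - len(aligned_prompt)
    total = ""
    for chunk in chunks:
        total += chunk
        stripped_total = total[offset:]
        # Slicing clamps, so this yields "" while the re-generated prompt tail
        # is still being consumed, then the visible part of each chunk.
        yield stripped_total[-len(chunk) :] if chunk else ""


out = list(stream_stripped([" s", "unny", " today."], "The weather is su", "The weather is"))
assert out == ["", "nny", " today."]
```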