@@ -1,16 +1,16 @@
 import datetime
 from dataclasses import dataclass
-from typing import TYPE_CHECKING, Iterator, List, Optional, Union
+from typing import Iterator, List, Optional, Sequence, Union
+
+import torch
 
 from outlines.generate.generator import sequence_generator
 from outlines.samplers import BeamSearchSampler, GreedySampler, MultinomialSampler
 
-if TYPE_CHECKING:
-    import torch
-
 FormattedOutput = Union[
     str, int, float, bool, datetime.date, datetime.time, datetime.datetime
 ]
+TotalCompletionsType = Optional[Union[List[str], str]]
 
 
 class SequenceGenerator:
@@ -461,6 +461,47 @@ def prepare_generation_parameters(
 
         return generation_params
 
+    def strip_completions(
+        self,
+        completions,
+        prompts: Union[str, List[str]],
+        aligned_prompts: Union[str, List[str]],
+    ):
+        """Remove characters generated through token alignment from the completions.
+
+        As token alignment makes the model re-generate some of the characters at
+        the end of the prompt, we want to remove those from the beginning of the
+        completions to only return the characters after the end of the user prompts.
+
+        Parameters
+        ----------
+        completions
+            Text generated by the model
+        prompts
+            The original prompts provided by the user
+        aligned_prompts
+            The prompts of the user after token alignment (what's given to the model)
+
+        Returns
+        -------
+        The stripped completions
+        """
+        if isinstance(prompts, str):
+            if isinstance(completions, str):
+                return completions[len(prompts) - len(aligned_prompts) :]
+
+            return [
+                self.strip_completions(completion, prompts, aligned_prompts)
+                for completion in completions
+            ]
+
+        return [
+            self.strip_completions(completion, prompt, aligned_prompt)
+            for completion, prompt, aligned_prompt in zip(
+                completions, prompts, aligned_prompts
+            )
+        ]
+
     def format_sequence(self, sequence: str) -> FormattedOutput:
         """Translate the generated sequence to another type.
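Note on the slicing above: token alignment trims characters from the end of the user prompt, and the model re-generates exactly those characters at the start of its output, so `len(prompts) - len(aligned_prompts)` is the number of leading completion characters to drop. A minimal standalone sketch of that arithmetic, with invented prompt and completion values:

```python
# Toy values (invented): alignment trimmed 4 characters off the prompt,
# so the model re-generates " wor" before producing new text.
prompt = "Hello, wor"            # what the user passed in
aligned_prompt = "Hello,"        # what is actually fed to the model
completion = " world!"           # model output: " wor" re-generated, then "ld!"

stripped = completion[len(prompt) - len(aligned_prompt):]
assert stripped == "ld!"         # only the text past the end of the user prompt
```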
@@ -500,15 +541,21 @@ def format(sequences):
             max_tokens, stop_at, seed
         )
 
+        aligned_prompts = self.logits_processor.align_prompts(prompts)
+
         completions = self.model.generate(
-            prompts,
+            aligned_prompts,
             generation_params,
             self.logits_processor,
             self.sampling_params,
             **model_specific_params,
         )
 
-        return format(completions)
+        stripped_completions = self.strip_completions(
+            completions, prompts, aligned_prompts
+        )
+
+        return format(stripped_completions)
 
     def stream(
         self,
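Since `generate` accepts a single prompt or a batch, and samplers can return several completions per prompt, `strip_completions` recurses through nested lists until it bottoms out on strings. A self-contained replica of that recursion on invented batch data (not the real Outlines API):

```python
# Standalone replica of the strip_completions recursion, on invented data.
def strip(completions, prompts, aligned):
    if isinstance(prompts, str):
        if isinstance(completions, str):
            return completions[len(prompts) - len(aligned):]
        return [strip(c, prompts, aligned) for c in completions]  # n samples, 1 prompt
    return [strip(c, p, a) for c, p, a in zip(completions, prompts, aligned)]

prompts = ["Hello, wor", "2 + 2 ="]
aligned = ["Hello, ", "2 + 2"]                  # alignment trimmed 3 and 2 chars
completions = [["world!", "world?"], [" = 4"]]  # two samples, then one

print(strip(completions, prompts, aligned))
# -> [['ld!', 'ld?'], [' 4']]
```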
@@ -519,13 +569,72 @@ def stream(
         **model_specific_params,
     ):
         """Return a text generator from a prompt or a list of prompts."""
+
+        def add_chunks_to_completions(
+            text_chunks: Union[str, List[str], List[List[str]], Sequence[str]],
+            total_completions: Optional[
+                Union[str, List[str], List[List[str]], Sequence[str]]
+            ],
+        ):
+            """Append each of the text chunks at the end of the corresponding completions"""
+            if isinstance(text_chunks, str):
+                if isinstance(total_completions, str):
+                    return total_completions + text_chunks
+                return text_chunks
+
+            if total_completions:
+                return [
+                    add_chunks_to_completions(text_chunk, total_completion)
+                    for text_chunk, total_completion in zip(
+                        text_chunks, total_completions
+                    )
+                ]
+
+            return [
+                add_chunks_to_completions(text_chunk, None)
+                for text_chunk in text_chunks
+            ]
+
+        def strip_text_chunks(
+            text_chunks: Union[str, List[str], List[List[str]], Sequence[str]],
+            stripped_completions: Union[str, List[str], List[List[str]], Sequence[str]],
+        ):
+            """Get the stripped text_chunks from the stripped_completions."""
+            if isinstance(text_chunks, str):
+                return (
+                    stripped_completions[-len(text_chunks) :]
+                    if len(text_chunks) > 0
+                    else ""
+                )
+
+            return [
+                strip_text_chunks(text_chunk, stripped_completion)
+                for text_chunk, stripped_completion in zip(
+                    text_chunks, stripped_completions
+                )
+            ]
+
         generation_params = self.prepare_generation_parameters(
             max_tokens, stop_at, seed
         )
-        return self.model.stream(
+
+        aligned_prompts = self.logits_processor.align_prompts(prompts)
+
+        total_completions: TotalCompletionsType = None
+
+        for text_chunks in self.model.stream(
             prompts,
             generation_params,
             self.logits_processor,
             self.sampling_params,
             **model_specific_params,
-        )
+        ):
+            total_completions = add_chunks_to_completions(
+                text_chunks, total_completions
+            )
+
+            stripped_completions = self.strip_completions(
+                total_completions, prompts, aligned_prompts
+            )
+
+            yield strip_text_chunks(text_chunks, stripped_completions)
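The streaming path performs the same stripping incrementally: each chunk is appended to a running completion, the alignment prefix is stripped from the accumulated text, and only the tail matching the current chunk is yielded (empty while the stream is still re-generating the aligned-away prefix). A single-prompt toy replay of that bookkeeping, with invented chunk boundaries:

```python
# Toy single-prompt replay of the stream bookkeeping; chunks are invented.
prompt, aligned_prompt = "Hello, wor", "Hello, "
chunks = ["wor", "ld", "!"]            # the model re-generates "wor" first

total = ""
for chunk in chunks:
    total += chunk                                        # add_chunks_to_completions
    stripped = total[len(prompt) - len(aligned_prompt):]  # strip_completions
    # strip_text_chunks: yield only the last len(chunk) chars of the
    # stripped completion; empty while still inside the re-generated prefix
    print(repr(stripped[-len(chunk):] if chunk else ""))
# -> '', 'ld', '!'
```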