Commit e813b32

Merge pull request oobabooga#6203 from oobabooga/dev
Merge dev branch
2 parents 3315d00 + aa653e3

File tree

2 files changed (+7, -2)

modules/chat.py (+1, -1)
@@ -577,7 +577,7 @@ def find_all_histories_with_first_prompts(state):
                 data = json.load(f)

             first_prompt = ""
-            if 'visible' in data and len(data['visible']) > 0:
+            if data and 'visible' in data and len(data['visible']) > 0:
                 if data['internal'][0][0] == '<|BEGIN-VISIBLE-CHAT|>':
                     if len(data['visible']) > 1:
                         first_prompt = html.unescape(data['visible'][1][0])
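The chat.py fix guards against json.load returning None: a history file whose entire contents are the JSON literal null deserializes to None, and the old membership test 'visible' in data then raises TypeError: argument of type 'NoneType' is not iterable. A minimal standalone sketch of the failure mode and the guard (illustrative only, not the repository's code):

import json

# A history file containing just the JSON literal "null" deserializes
# to None rather than a dict.
data = json.loads("null")

# Old check: `'visible' in data` would raise
# TypeError: argument of type 'NoneType' is not iterable.
# Patched check: the leading `data and` short-circuits on None first.
if data and 'visible' in data and len(data['visible']) > 0:
    print("history has visible messages")
else:
    print("empty or null history, skipped")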

modules/llama_cpp_python_hijack.py (+6, -1)
@@ -100,9 +100,11 @@ def eval_with_progress(self, tokens: Sequence[int]):


 def monkey_patch_llama_cpp_python(lib):
+    if getattr(lib.Llama, '_is_patched', False):
+        # If the patch is already applied, do nothing
+        return

     def my_generate(self, *args, **kwargs):
-
         if shared.args.streaming_llm:
             new_sequence = args[0]
             past_sequence = self._input_ids

@@ -116,3 +118,6 @@ def my_generate(self, *args, **kwargs):
     lib.Llama.eval = eval_with_progress
     lib.Llama.original_generate = lib.Llama.generate
     lib.Llama.generate = my_generate
+
+    # Set the flag to indicate that the patch has been applied
+    lib.Llama._is_patched = True
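The llama_cpp_python_hijack.py change makes the monkey patch idempotent: the function now checks a _is_patched class attribute on entry and sets it after swapping the methods, so a second call no longer copies the already-wrapped generate into original_generate (which would leave the wrapper calling itself). A self-contained sketch of the same guard pattern, with a hypothetical Dummy class standing in for lib.Llama:

def patch(cls):
    # Bail out if this class was already patched; without the guard, a
    # second call would save the wrapper as the "original" and wrap it again.
    if getattr(cls, '_is_patched', False):
        return

    original = cls.greet

    def wrapped(self):
        return "patched " + original(self)

    cls.greet = wrapped
    cls._is_patched = True


class Dummy:
    def greet(self):
        return "hello"


patch(Dummy)
patch(Dummy)  # second call is a no-op thanks to the flag
print(Dummy().greet())  # -> "patched hello" (wrapped exactly once)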
