diff --git a/src/genesys/data.py b/src/genesys/data.py index f398fb8..c7244db 100644 --- a/src/genesys/data.py +++ b/src/genesys/data.py @@ -102,16 +102,25 @@ def _prepare_batch(self, batch: dict, dataset: str) -> tuple: batch_messages = [ [ {"role": "user", "content": b["prompt"]}, - {"role": "assistant", "content": "/n" + b["llm_response"]}, + {"role": "assistant", "content": "\n" + b["llm_response_first_time"]}, + {"role": "assistant", "content": ""}, # this message needs to be here so hf templating works, we're stripping it out again below ] for b in batch ] + batch_inputs = self.tokenizer.apply_chat_template( + batch_messages, + tokenize=False, + continue_final_message=True + ) + unwanted_suffix = "<|end▁of▁sentence|><|Assistant|><|end▁of▁sentence|>" # strip out last message + for i, inp in enumerate(batch_inputs): + if inp.endswith(unwanted_suffix): + batch_inputs[i] = inp[: -len(unwanted_suffix)] else: batch_messages = [ [{"role": "user", "content": b["prompt"]}, {"role": "assistant", "content": "/n"}] for b in batch ] - - batch_inputs = self.tokenizer.apply_chat_template(batch_messages, tokenize=False, continue_final_message=True) + batch_inputs = self.tokenizer.apply_chat_template(batch_messages, tokenize=False, continue_final_message=True) return batch_inputs, batch diff --git a/src/genesys/generate.py b/src/genesys/generate.py index 732e0cb..164e777 100644 --- a/src/genesys/generate.py +++ b/src/genesys/generate.py @@ -51,10 +51,13 @@ def main(config: GenerateConfig): # Initialize components log("[cyan] Configuring output path and gcp bucket...[/]") + if config.gcp_bucket is not None: + gcp_credentials = os.environ.get("GCP_CREDENTIALS_BASE64") + assert gcp_credentials is not None, "the GCP_CREDENTIALS_BASE64 environment variable is not set" if not os.path.exists(config.path_output): os.makedirs(config.path_output) gcp_bucket = ( - GcpBucket(config.gcp_bucket, os.environ.get("GCP_CREDENTIALS_BASE64")) + GcpBucket(config.gcp_bucket, gcp_credentials) if config.gcp_bucket is not None else None )