From 77a9ad647575cd297b40fa34f5070ca7a4ce576d Mon Sep 17 00:00:00 2001 From: JT Date: Mon, 22 Nov 2021 18:06:23 +0100 Subject: [PATCH] Remove batch encoder --- .../semantic_clustering.py | 19 ++++++++++--------- 1 file changed, 10 insertions(+), 9 deletions(-) diff --git a/asreviewcontrib/semantic_clustering/semantic_clustering.py b/asreviewcontrib/semantic_clustering/semantic_clustering.py index c5a310e..e1c8a64 100644 --- a/asreviewcontrib/semantic_clustering/semantic_clustering.py +++ b/asreviewcontrib/semantic_clustering/semantic_clustering.py @@ -56,19 +56,20 @@ def run_clustering_steps( # tokenize abstracts and add to data print("Tokenizing abstracts...") - encoded = tokenizer.batch_encode_plus( - data['abstract'].tolist(), - add_special_tokens=False, - truncation=True, - max_length=200, - padding='max_length', - return_tensors='pt') + encoded = data['abstract'].progress_apply( + lambda x: tokenizer.encode_plus( + x, + add_special_tokens=False, + truncation=True, + max_length=512, + # padding='max_length', + return_tensors='pt')) # generate embeddings and format correctly print("Generating embeddings...") embeddings = [] - for x in tqdm(encoded.input_ids): - embeddings.append(model(x.unsqueeze(0), output_hidden_states=False)[-1].detach().numpy().squeeze()) # noqa: E501 + for x in tqdm(encoded): + embeddings.append(model(**x, output_hidden_states=False)[-1].detach().numpy().squeeze()) # noqa: E501 # from here on the data is not directly attached to the dataframe anymore, # as a result of legacy code. This will be fixed in a future PR.