From cf84c40105cf6486cdc05d9f1e16540593a90534 Mon Sep 17 00:00:00 2001 From: ALGW71 Date: Mon, 20 May 2024 17:52:47 +0100 Subject: [PATCH 1/2] Added a gitignore, and the new code to allow users to run TCRLang-paired easily in the existing architecture. --- .gitignore | 162 +++++++++ ablang2/load_model.py | 8 +- notebooks/pretrained_module_tcrlang.ipynb | 418 ++++++++++++++++++++++ 3 files changed, 584 insertions(+), 4 deletions(-) create mode 100644 .gitignore create mode 100644 notebooks/pretrained_module_tcrlang.ipynb diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..efa407c --- /dev/null +++ b/.gitignore @@ -0,0 +1,162 @@ +# Byte-compiled / optimized / DLL files +__pycache__/ +*.py[cod] +*$py.class + +# C extensions +*.so + +# Distribution / packaging +.Python +build/ +develop-eggs/ +dist/ +downloads/ +eggs/ +.eggs/ +lib/ +lib64/ +parts/ +sdist/ +var/ +wheels/ +share/python-wheels/ +*.egg-info/ +.installed.cfg +*.egg +MANIFEST + +# PyInstaller +# Usually these files are written by a python script from a template +# before PyInstaller builds the exe, so as to inject date/other infos into it. +*.manifest +*.spec + +# Installer logs +pip-log.txt +pip-delete-this-directory.txt + +# Unit test / coverage reports +htmlcov/ +.tox/ +.nox/ +.coverage +.coverage.* +.cache +nosetests.xml +coverage.xml +*.cover +*.py,cover +.hypothesis/ +.pytest_cache/ +cover/ + +# Translations +*.mo +*.pot + +# Django stuff: +*.log +local_settings.py +db.sqlite3 +db.sqlite3-journal + +# Flask stuff: +instance/ +.webassets-cache + +# Scrapy stuff: +.scrapy + +# Sphinx documentation +docs/_build/ + +# PyBuilder +.pybuilder/ +target/ + +# Jupyter Notebook +.ipynb_checkpoints + +# IPython +profile_default/ +ipython_config.py + +# pyenv +# For a library or package, you might want to ignore these files since the code is +# intended to run in multiple environments; otherwise, check them in: +# .python-version + +# pipenv +# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. +# However, in case of collaboration, if having platform-specific dependencies or dependencies +# having no cross-platform support, pipenv may install dependencies that don't work, or not +# install all needed dependencies. +#Pipfile.lock + +# poetry +# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. +# This is especially recommended for binary packages to ensure reproducibility, and is more +# commonly ignored for libraries. +# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control +#poetry.lock + +# pdm +# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. +#pdm.lock +# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it +# in version control. +# https://pdm.fming.dev/latest/usage/project/#working-with-version-control +.pdm.toml +.pdm-python +.pdm-build/ + +# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm +__pypackages__/ + +# Celery stuff +celerybeat-schedule +celerybeat.pid + +# SageMath parsed files +*.sage.py + +# Environments +.env +.venv +env/ +venv/ +ENV/ +env.bak/ +venv.bak/ + +# Spyder project settings +.spyderproject +.spyproject + +# Rope project settings +.ropeproject + +# mkdocs documentation +/site + +# mypy +.mypy_cache/ +.dmypy.json +dmypy.json + +# Pyre type checker +.pyre/ + +# pytype static type analyzer +.pytype/ + +# Cython debug symbols +cython_debug/ + +# PyCharm +# JetBrains specific template is maintained in a separate JetBrains.gitignore that can +# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore +# and can be added to the global gitignore or merged into this file. For a more nuclear +# option (not recommended) you can uncomment the following to ignore the entire idea folder. +#.idea/ \ No newline at end of file diff --git a/ablang2/load_model.py b/ablang2/load_model.py index ec8bfbd..2604c5e 100644 --- a/ablang2/load_model.py +++ b/ablang2/load_model.py @@ -4,10 +4,11 @@ list_of_models = { "ablang1-heavy":["https://opig.stats.ox.ac.uk/data/downloads/ablang-heavy.tar.gz", "amodel.pt"], "ablang1-light":["https://opig.stats.ox.ac.uk/data/downloads/ablang-light.tar.gz", "amodel.pt"], - "ablang2-paired":["https://zenodo.org/records/10185169/files/ablang2-weights.tar.gz", "model.pt"] + "ablang2-paired":["https://zenodo.org/records/10185169/files/ablang2-weights.tar.gz", "model.pt"], + "tcrlang-paired":["https://zenodo.org/records/11208211/files/tcrlang-weights.tar.gz", "model.pt"], } ablang1_models = ["ablang1-heavy", "ablang1-light"] -ablang2_models = ["ablang2-paired"] +ablang2_models = ["ablang2-paired", "tcrlang-paired"] def load_model(model_to_use = "ablang2-paired", random_init = False, device = 'cpu'): @@ -45,7 +46,7 @@ def download_model(model_to_use = "ablang2-paired"): local_model_folder = os.path.join(os.path.dirname(__file__), "model-weights-{}".format(model_to_use)) os.makedirs(local_model_folder, exist_ok = True) - file_w_weights, file_model = list_of_models[model_to_use] + file_w_weights, file_model = list_of_models[model_to_use] # modify list of models if not os.path.isfile(os.path.join(local_model_folder, file_model)): print("Downloading model ...") @@ -116,4 +117,3 @@ def fetch_ablang2(model_to_use, random_init=False, device='cpu'): tokenizer = tokenizers.ABtokenizer() return AbLang, tokenizer, hparams - diff --git a/notebooks/pretrained_module_tcrlang.ipynb b/notebooks/pretrained_module_tcrlang.ipynb new file mode 100644 index 0000000..727a05a --- /dev/null +++ b/notebooks/pretrained_module_tcrlang.ipynb @@ -0,0 +1,418 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "28ec7744", + "metadata": {}, + "source": [ + "# **Running TCRLang-paired.**\n", + "\n", + "This simply involves using the TCRLang paired weights. All that needs to be changed is the model_to_use when the model is first initialised!\n", + "\n", + "All functionality should be the same **except for the \"align = True\"** mode. We are currently working on this issue." + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "id": "7ae54cd0-6253-46dd-a316-4f20b12041e0", + "metadata": {}, + "outputs": [], + "source": [ + "import numpy as np\n", + "import torch\n", + "import ablang2" + ] + }, + { + "cell_type": "markdown", + "id": "10801511-770d-46ac-a15d-a02d4ef9ec87", + "metadata": {}, + "source": [ + "# **0. Sequence input and its format**\n", + "\n", + "This takes as input either the individual beta variable domain (TRB), alpha variable domain (TRA), or the paired TCR.\n", + "\n", + "Each record (antibody) needs to be a list with the TRB as the first element and the TRA as the second. If either the TRB or TRA is not known, leave an empty string.\n", + "\n", + "An asterisk (\\*) is used for masking. It is recommended to mask residues which you are interested in mutating.\n", + "\n", + "**NB:** It is important that the TRB and TRA sequence is ordered correctly." + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "id": "99192978-a008-4a32-a80e-bba238e0ec7c", + "metadata": {}, + "outputs": [], + "source": [ + "# Let's use the famous JM22 TCR to begin with\n", + "\n", + "seq1 = [\n", + " 'GGITQSPKYLFRKEGQNVTLSCEQNLNHDAMYWYRQDPGQGLRLIYYSQIVNDFQKGDIAEGYSVSREKKESFPLTVTSAQKNPTAFYLCASSIRSSYEQYFGPGTRLTVTEDLKN', # TRB sequence\n", + " 'QLLEQSPQFLSIQEGENLTVYCNSSSVFSSLQWYRQEPGEGPVLLVTVVTGGEVKKLKRLTFQFGDARKDSSLHITAAQPGDTGLYLCAGAGSQGNLIFGKGTKLSVKP' # TRA sequence\n", + "]\n", + "seq2 = [\n", + " 'GITQSPKYLFRKEGQNVTLSCEQNLNHDAMYWYRQDPGQGLRLIYYSQIVNDFQKGDIAEGYSVSREKKESFPLTVTSAQKNPTAFYLCASSIRSSYEQYFGPGTRLTVTEDLKN',\n", + " 'PVTLGQPASISCRSSQSLEASDTNIYLSWFQQRPGQSPRRLIYKISNRDSGVPDRFSGSGSGTHFTLRISRVEADDVAVYYCMQGTHWPPAFGQGTKVDIK'\n", + "]\n", + "seq3 = [\n", + " 'GGITQSPKYLFRKEGQNVTLSCEQNLNHDAMYWYRQDPGQGLRLIYYSQIVNDFQKGDIAEGYSVSREKKESFPLTVTSAQKNPTAFYLCASSIRSSYEQYFGPGTRLTVTEDLKN',\n", + " '' # The TRA sequence is not known, so an empty string is left instead. \n", + "]\n", + "seq4 = [\n", + " '',\n", + " 'QLLEQSPQFLSIQEGENLTVYCNSSSVFSSLQWYRQEPGEGPVLLVTVVTGGEVKKLKRLTFQFGDARKDSSLHITAAQPGDTGLYLCAGAGSQGNLIFGKGTKLSVKP'\n", + "]\n", + "seq5 = [\n", + " 'GITQSPKYLFRKEGQNVTLSCEQNLNHDAMYWYRQDPGQGLRLIYYSQIVNDFQKGDIAEGYSVSREKKESFPLTVTSAQKNPTAFYLCASSIRSS*EQYFGPGTRLTVTEDLKN', # (*) is used to mask certain residues\n", + " 'QLLEQSPQFLSIQEGENLTVYCNSSSVFSSLQWYRQEPGEGPVLLVTVVTGGEVKKLKRLTFQFGD*RKDSSLHITAAQPGDTGLYLCAG*GSQGNLIFGKGTKLSVKP'\n", + "]\n", + "\n", + "all_seqs = [seq1, seq2, seq3, seq4, seq5]\n", + "only_both_chains_seqs = [seq1, seq2, seq5]" + ] + }, + { + "cell_type": "markdown", + "id": "dffbacfa-8642-4d94-9572-2205a05c18f9", + "metadata": {}, + "source": [ + "# **1. How to use TCRLang-paired**\n", + "\n", + "TCRLang-paired can be downloaded and used in its raw form as seen below. For convenience, we have also developed different \"modes\" which can be used for specific use cases (see Section 2) " + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "0e7419e4-db22-49ea-8e12-6db2b3681545", + "metadata": {}, + "outputs": [], + "source": [ + "# Download and initialise the model\n", + "tcrlang = ablang2.pretrained(model_to_use='tcrlang-paired', # This is all that needs to be changed.\n", + " random_init=False, \n", + " ncpu=1, \n", + " device='cpu')\n", + "\n", + "# Tokenize input sequences\n", + "seq = f\"{seq1[0]}|{seq1[1]}\" # TRB first, TRA second, with | used to separated the two sequences \n", + "tokenized_seq = tcrlang.tokenizer([seq], pad=True, w_extra_tkns=False, device=\"cpu\")\n", + " \n", + "# Generate rescodings\n", + "with torch.no_grad():\n", + " rescoding = tcrlang.AbRep(tokenized_seq).last_hidden_states\n", + "\n", + "# Generate logits/likelihoods\n", + "with torch.no_grad():\n", + " likelihoods = tcrlang.AbLang(tokenized_seq)" + ] + }, + { + "cell_type": "markdown", + "id": "48562761-6ebe-4025-be97-918c9f9eff7e", + "metadata": {}, + "source": [ + "# **2. Different modes for specific usecases**\n", + "\n", + "ablang2 has already been implemented for a variety of different usecases. The benefit of these modes is that they handle extra tokens such as start, stop and separation tokens.\n", + "\n", + "1. seqcoding: Generates sequence representations for each sequence\n", + "2. rescoding: Generates residue representations for each residue in each sequence\n", + "3. likelihood: Generates likelihoods for each amino acid at each position in each sequence\n", + "4. probability: Generates probabilities for each amino acid at each position in each sequence\n", + "5. pseudo_log_likelihood: Returns the pseudo log likelihood for a sequence (based on masking each residue one at a time)\n", + "6. confidence: Returns a fast calculation of the log likelihood for a sequence (based on a single pass with no masking)\n", + "7. restore: Restores masked residues\n", + "\n", + "### **ablang2 can also align the resulting representations using ANARCI**\n", + "\n", + "This can be done for 'rescoding', 'likelihood', and 'probability'. This is done by setting the argument \"align=True\".\n", + "\n", + "**NB**: Align can only be used on input with the same format, i.e. either all beta, all alpha, or all both beta and alpha.\n", + "\n", + "### **The align argument can also be used to restore variable missing lengths**\n", + "\n", + "For this, use \"align=True\" with the 'restore' mode." + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "id": "ceae4a88-0679-4704-8bad-c06a4569c497", + "metadata": {}, + "outputs": [], + "source": [ + "tcrlang = ablang2.pretrained()\n", + "\n", + "valid_modes = [\n", + " 'seqcoding', 'rescoding', 'likelihood', 'probability',\n", + " 'pseudo_log_likelihood', 'confidence', 'restore' \n", + "]" + ] + }, + { + "cell_type": "markdown", + "id": "aa333732-7508-4826-92ec-3acdd54bc1bb", + "metadata": {}, + "source": [ + "## **seqcoding** \n", + "\n", + "The seqcodings represents each sequence as a 480 sized embedding. It is derived from averaging across each rescoding embedding for a given sequence, including extra tokens. \n", + "\n", + "**NB:** Seqcodings can also be derived in other ways like using the sum or averaging across only parts of the input such as the CDRs. For such cases please use and adapt the below rescoding." + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "id": "d22f4302-1262-4cc1-8a1c-a36daa8c710c", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "array([[-0.08125024, -0.01384698, -0.15913074, ..., 0.28860582,\n", + " -0.12494163, 0.04056989],\n", + " [-0.13111236, 0.03872783, -0.03324484, ..., 0.28661499,\n", + " -0.10136888, 0.00161366],\n", + " [-0.05338498, -0.06573963, -0.11988864, ..., 0.31983912,\n", + " -0.07272346, 0.06720409],\n", + " [-0.08249289, 0.08119008, -0.26181517, ..., 0.22905781,\n", + " -0.17356422, -0.01206601],\n", + " [-0.0763039 , -0.00637079, -0.16056736, ..., 0.29032119,\n", + " -0.13395997, 0.03960718]])" + ] + }, + "execution_count": 5, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "tcrlang(all_seqs, mode='seqcoding')" + ] + }, + { + "cell_type": "markdown", + "id": "ac98c317-e085-4d94-9729-bbae039f49ac", + "metadata": {}, + "source": [ + "## **rescoding / likelihood / probability**\n", + "\n", + "The rescodings represents each residue as a 480 sized embedding. The likelihoods represents each residue as the predicted logits for each character in the vocabulary. The probabilities represents the normalised likelihoods.\n", + "\n", + "**NB:** The output includes extra tokens (start, stop and separation tokens) in the format \"|\". The length of the output is therefore 5 longer than the TRB and TRA.\n", + "\n", + "**NB:** By default the representations are derived using a single forward pass. To prevent the predicted likelihood and probability to be affected by the input residue at each position, setting the \"stepwise_masking\" argument to True can be used. This will run a forward pass for each position with the residue at that position masked. This is much slower than running a single forward pass." + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "id": "6227f661-575f-4b1e-9646-cfba7b10c3b4", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "[array([[-0.51347226, -0.28576213, 0.10309762, ..., 0.5203557 ,\n", + " 0.0117211 , -0.08264501],\n", + " [-0.19843523, 0.1543197 , -0.5031867 , ..., -0.2870497 ,\n", + " -0.21901898, 0.36414275],\n", + " [-0.16083395, 0.20545605, -0.21582198, ..., -0.00274766,\n", + " 0.066175 , 0.27734903],\n", + " ...,\n", + " [ 0.19438688, 0.2340737 , 0.0407359 , ..., 0.03316897,\n", + " 0.18425798, 0.14009582],\n", + " [-0.09048033, -0.41594166, 0.3686235 , ..., 0.05291507,\n", + " -0.13554473, -0.09198374],\n", + " [-0.04215955, -0.4688292 , -0.04049325, ..., 0.05855337,\n", + " 0.08600137, 0.13561374]], dtype=float32),\n", + " array([[-0.33639285, 0.06262851, -0.09385429, ..., 0.29438573,\n", + " 0.09021386, -0.03847658],\n", + " [-0.13315135, 0.18713355, 0.07811087, ..., 0.5782139 ,\n", + " -0.22035252, 0.03181488],\n", + " [ 0.3239961 , -0.01685584, -0.5550718 , ..., 0.36060256,\n", + " 0.42027324, 0.03702496],\n", + " ...,\n", + " [-0.14616522, 0.15133138, -0.23368333, ..., -0.19600233,\n", + " -0.55732346, -0.15294474],\n", + " [ 0.11544223, 0.15623151, 0.19846602, ..., 0.19014375,\n", + " -0.4080012 , 0.2844629 ],\n", + " [-0.19953507, -0.27715072, 0.11643803, ..., 0.0261317 ,\n", + " 0.18988611, -0.16810946]], dtype=float32),\n", + " array([[-0.50845665, -0.31238076, 0.09663024, ..., 0.46686876,\n", + " 0.14437647, -0.07101524],\n", + " [-0.15447244, 0.16523995, -0.5090536 , ..., -0.2029969 ,\n", + " -0.16771431, 0.34480545],\n", + " [-0.11773082, 0.07923666, -0.25603563, ..., 0.10093339,\n", + " -0.01438329, 0.28142142],\n", + " ...,\n", + " [-0.04196927, 0.02127836, -0.44641826, ..., -0.08536111,\n", + " 0.0983487 , 0.40973493],\n", + " [-0.10732048, 0.03674107, -0.27424234, ..., 0.12897371,\n", + " 0.18343335, -0.05968828],\n", + " [-0.40199417, 0.08560478, 0.0094398 , ..., 0.0304165 ,\n", + " -0.07760128, -0.09373415]], dtype=float32),\n", + " array([[-0.44753143, -0.11013485, -0.05007413, ..., 0.41176277,\n", + " 0.01311761, 0.0878911 ],\n", + " [-0.2665582 , -0.4412719 , -0.2218026 , ..., 0.303545 ,\n", + " 0.077465 , -0.09151924],\n", + " [-0.07410756, 0.00668654, -0.3327679 , ..., 0.1349459 ,\n", + " -0.31039506, 0.2365871 ],\n", + " ...,\n", + " [ 0.20614003, 0.5740512 , -0.01496052, ..., 0.09354969,\n", + " 0.20616154, 0.16729088],\n", + " [-0.10038415, -0.3134321 , 0.27842972, ..., 0.02146813,\n", + " -0.21205966, -0.15864977],\n", + " [-0.01906172, -0.47228584, -0.05227043, ..., 0.10767294,\n", + " 0.07469795, 0.18540816]], dtype=float32),\n", + " array([[-0.42599586, 0.10091405, -0.20853448, ..., 0.22974445,\n", + " 0.09603838, 0.12763825],\n", + " [-0.2779723 , 0.26220116, -0.00395166, ..., 0.3220927 ,\n", + " -0.05186327, 0.02503206],\n", + " [ 0.24614441, -0.08099816, -0.41675895, ..., 0.29475564,\n", + " 0.27429575, 0.01983144],\n", + " ...,\n", + " [ 0.33051035, 0.24691616, -0.01975616, ..., -0.02933154,\n", + " 0.1508603 , 0.12146682],\n", + " [-0.04957892, -0.38600227, 0.3221674 , ..., 0.05534112,\n", + " -0.15177359, -0.10046635],\n", + " [-0.03469401, -0.43283713, -0.03586697, ..., 0.08389509,\n", + " 0.08420966, 0.12835406]], dtype=float32)]" + ] + }, + "execution_count": 6, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "tcrlang(all_seqs, mode='rescoding', stepwise_masking = False)" + ] + }, + { + "cell_type": "markdown", + "id": "8f0a71ec-e916-4330-90d0-13a4b1121a89", + "metadata": {}, + "source": [ + "## **Pseudo log likelihood and Confidence scores**\n", + "\n", + "The pseudo log likelihood and confidence represents two methods for calculating the uncertainty for the input sequence.\n", + "\n", + "- pseudo_log_likelihood: For each position, the pseudo log likelihood is calculated when predicting the masked residue. The final score is an average across the whole input. This is similar to the approach taken in the ESM-2 paper for calculating pseudo perplexity [(Lin et al., 2023)](https://doi.org/10.1126/science.ade2574).\n", + "\n", + "- confidence: For each position, the log likelihood is calculated without masking the residue. The final score is an average across the whole input. \n", + "\n", + "**NB:** The **confidence is fast** to compute, requiring only a single forward pass per input. **Pseudo log likelihood is slow** to calculate, requiring L forward passes per input, where L is the length of the input.\n", + "\n", + "**NB:** It is recommended to use **pseudo log likelihood for final results** and **confidence for exploratory work**." + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "id": "83f3064b-48a7-42fb-ba82-ec153ea946da", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "array([20.41889 , 6.8523793, 21.07971 , 20.479687 , 18.358845 ],\n", + " dtype=float32)" + ] + }, + "execution_count": 7, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "results = tcrlang(all_seqs, mode='pseudo_log_likelihood')\n", + "np.exp(-results) # convert to pseudo perplexity" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "id": "42cc8b34-5ae9-4857-93fe-a438a0f2a868", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "array([2.2753055, 1.528476 , 2.1577282, 2.5318768, 2.156193 ],\n", + " dtype=float32)" + ] + }, + "execution_count": 8, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "results = tcrlang(all_seqs, mode='confidence')\n", + "np.exp(-results)" + ] + }, + { + "cell_type": "markdown", + "id": "e0b63e48-b2a1-4a8e-8ecb-449748a2cb25", + "metadata": {}, + "source": [ + "## **restore**\n", + "\n", + "This mode can be used to restore masked residues. " + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "id": "2d5b725c-4eac-4a4b-9331-357c3ac140f7", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "array(['|',\n", + " '|',\n", + " '|'],\n", + " dtype=' Date: Mon, 20 May 2024 17:53:14 +0100 Subject: [PATCH 2/2] Update README. --- README.md | 2 ++ 1 file changed, 2 insertions(+) diff --git a/README.md b/README.md index 530bc9a..a87b531 100644 --- a/README.md +++ b/README.md @@ -16,6 +16,8 @@ **Availability and implementation:** AbLang2 is a python package available at https://github.com/oxpig/AbLang2.git. +**TCRLang-Paired:** The AbLang2 architecture can be initialised with model weights trained on paired TCR sequences. This model can be used in an identical way to AbLang2 on TCR sequences. The only missing functionality is the lack of the align command. The generation of sequence and residue encodings, as well as masking are all the same. For an example please see the notebook. + -----------