From e44203ed5003c23d0fd27b72f5a0cf5671f0b445 Mon Sep 17 00:00:00 2001 From: chflame163 Date: Thu, 26 Sep 2024 10:51:09 +0800 Subject: [PATCH] rename py/BiRefNet to py/BiRefNet_v2, for avoid module name conflicts --- py/BiRefNet/.gitignore | 142 -- py/BiRefNet/LICENSE | 21 - py/BiRefNet/README.md | 316 ---- py/BiRefNet/__init__.py | 0 py/BiRefNet/config.py | 174 -- py/BiRefNet/dataset.py | 118 -- py/BiRefNet/eval_existingOnes.py | 146 -- py/BiRefNet/evaluation/metrics.py | 763 -------- py/BiRefNet/gen_best_ep.py | 86 - py/BiRefNet/image_proc.py | 119 -- py/BiRefNet/inference.py | 105 -- py/BiRefNet/loss.py | 277 --- py/BiRefNet/make_a_copy.sh | 18 - .../models/backbones/build_backbone.py | 44 - py/BiRefNet/models/backbones/pvt_v2.py | 435 ----- py/BiRefNet/models/backbones/swin_v1.py | 627 ------- py/BiRefNet/models/birefnet.py | 286 --- py/BiRefNet/models/modules/aspp.py | 120 -- py/BiRefNet/models/modules/decoder_blocks.py | 66 - py/BiRefNet/models/modules/deform_conv.py | 66 - py/BiRefNet/models/modules/lateral_blocks.py | 21 - py/BiRefNet/models/modules/mlp.py | 118 -- py/BiRefNet/models/modules/prompt_encoder.py | 222 --- py/BiRefNet/models/modules/utils.py | 54 - py/BiRefNet/models/refinement/refiner.py | 252 --- py/BiRefNet/models/refinement/stem_layer.py | 45 - py/BiRefNet/requirements.txt | 15 - py/BiRefNet/rm_cache.sh | 20 - py/BiRefNet/sub.sh | 17 - py/BiRefNet/test.sh | 29 - py/BiRefNet/train.py | 333 ---- py/BiRefNet/train.sh | 42 - py/BiRefNet/train_test.sh | 11 - .../tutorials/BiRefNet_inference.ipynb | 1575 ----------------- py/BiRefNet/tutorials/BiRefNet_pth2onnx.ipynb | 312 ---- py/BiRefNet/utils.py | 97 - py/BiRefNet_v2 | 1 + py/birefnet_ultra_v2.py | 7 +- pyproject.toml | 2 +- 39 files changed, 6 insertions(+), 7096 deletions(-) delete mode 100644 py/BiRefNet/.gitignore delete mode 100644 py/BiRefNet/LICENSE delete mode 100644 py/BiRefNet/README.md delete mode 100644 py/BiRefNet/__init__.py delete mode 100644 py/BiRefNet/config.py delete mode 100644 py/BiRefNet/dataset.py delete mode 100644 py/BiRefNet/eval_existingOnes.py delete mode 100644 py/BiRefNet/evaluation/metrics.py delete mode 100644 py/BiRefNet/gen_best_ep.py delete mode 100644 py/BiRefNet/image_proc.py delete mode 100644 py/BiRefNet/inference.py delete mode 100644 py/BiRefNet/loss.py delete mode 100644 py/BiRefNet/make_a_copy.sh delete mode 100644 py/BiRefNet/models/backbones/build_backbone.py delete mode 100644 py/BiRefNet/models/backbones/pvt_v2.py delete mode 100644 py/BiRefNet/models/backbones/swin_v1.py delete mode 100644 py/BiRefNet/models/birefnet.py delete mode 100644 py/BiRefNet/models/modules/aspp.py delete mode 100644 py/BiRefNet/models/modules/decoder_blocks.py delete mode 100644 py/BiRefNet/models/modules/deform_conv.py delete mode 100644 py/BiRefNet/models/modules/lateral_blocks.py delete mode 100644 py/BiRefNet/models/modules/mlp.py delete mode 100644 py/BiRefNet/models/modules/prompt_encoder.py delete mode 100644 py/BiRefNet/models/modules/utils.py delete mode 100644 py/BiRefNet/models/refinement/refiner.py delete mode 100644 py/BiRefNet/models/refinement/stem_layer.py delete mode 100644 py/BiRefNet/requirements.txt delete mode 100644 py/BiRefNet/rm_cache.sh delete mode 100644 py/BiRefNet/sub.sh delete mode 100644 py/BiRefNet/test.sh delete mode 100644 py/BiRefNet/train.py delete mode 100644 py/BiRefNet/train.sh delete mode 100644 py/BiRefNet/train_test.sh delete mode 100644 py/BiRefNet/tutorials/BiRefNet_inference.ipynb delete mode 100644 py/BiRefNet/tutorials/BiRefNet_pth2onnx.ipynb delete 
mode 100644 py/BiRefNet/utils.py create mode 160000 py/BiRefNet_v2 diff --git a/py/BiRefNet/.gitignore b/py/BiRefNet/.gitignore deleted file mode 100644 index af0deeca..00000000 --- a/py/BiRefNet/.gitignore +++ /dev/null @@ -1,142 +0,0 @@ -# Custom -e_* -.vscode -ckpt -preds -evaluation/eval-* -nohup.out* -tmp* -*.pth -core-*-python-* -.DS_Store -__MACOSX/ - -# Byte-compiled / optimized / DLL files -__pycache__/ -*.py[cod] -*$py.class - -# C extensions -*.so - -# Distribution / packaging -.Python -build/ -develop-eggs/ -dist/ -downloads/ -eggs/ -.eggs/ -lib/ -lib64/ -parts/ -sdist/ -var/ -wheels/ -pip-wheel-metadata/ -share/python-wheels/ -*.egg-info/ -.installed.cfg -*.egg -MANIFEST - -# PyInstaller -# Usually these files are written by a python script from a template -# before PyInstaller builds the exe, so as to inject date/other infos into it. -*.manifest -*.spec - -# Installer logs -pip-log.txt -pip-delete-this-directory.txt - -# Unit test / coverage reports -htmlcov/ -.tox/ -.nox/ -.coverage -.coverage.* -.cache -nosetests.xml -coverage.xml -*.cover -*.py,cover -.hypothesis/ -.pytest_cache/ - -# Translations -*.mo -*.pot - -# Django stuff: -*.log -local_settings.py -db.sqlite3 -db.sqlite3-journal - -# Flask stuff: -instance/ -.webassets-cache - -# Scrapy stuff: -.scrapy - -# Sphinx documentation -docs/_build/ - -# PyBuilder -target/ - -# Jupyter Notebook -.ipynb_checkpoints - -# IPython -profile_default/ -ipython_config.py - -# pyenv -.python-version - -# pipenv -# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. -# However, in case of collaboration, if having platform-specific dependencies or dependencies -# having no cross-platform support, pipenv may install dependencies that don't work, or not -# install all needed dependencies. -#Pipfile.lock - -# PEP 582; used by e.g. github.com/David-OConnor/pyflow -__pypackages__/ - -# Celery stuff -celerybeat-schedule -celerybeat.pid - -# SageMath parsed files -*.sage.py - -# Environments -.env -.venv -env/ -venv/ -ENV/ -env.bak/ -venv.bak/ - -# Spyder project settings -.spyderproject -.spyproject - -# Rope project settings -.ropeproject - -# mkdocs documentation -/site - -# mypy -.mypy_cache/ -.dmypy.json -dmypy.json - -# Pyre type checker -.pyre/ diff --git a/py/BiRefNet/LICENSE b/py/BiRefNet/LICENSE deleted file mode 100644 index 485921e8..00000000 --- a/py/BiRefNet/LICENSE +++ /dev/null @@ -1,21 +0,0 @@ -MIT License - -Copyright (c) 2024 ZhengPeng - -Permission is hereby granted, free of charge, to any person obtaining a copy -of this software and associated documentation files (the "Software"), to deal -in the Software without restriction, including without limitation the rights -to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -copies of the Software, and to permit persons to whom the Software is -furnished to do so, subject to the following conditions: - -The above copyright notice and this permission notice shall be included in all -copies or substantial portions of the Software. - -THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE -SOFTWARE. 
diff --git a/py/BiRefNet/README.md b/py/BiRefNet/README.md deleted file mode 100644 index 39f71d7f..00000000 --- a/py/BiRefNet/README.md +++ /dev/null @@ -1,316 +0,0 @@ -

Bilateral Reference for High-Resolution Dichotomous Image Segmentation

Peng Zheng 1,4,5,6, Dehong Gao 2, Deng-Ping Fan 1*, Li Liu 3, Jorma Laaksonen 4, Wanli Ouyang 5, Nicu Sebe 6
1 Nankai University, 2 Northwestern Polytechnical University, 3 National University of Defense Technology
4 Aalto University, 5 Shanghai AI Laboratory, 6 University of Trento
| *DIS-Sample_1* | *DIS-Sample_2* |
| :------------------------------: | :-------------------------------: |
|  |  |

This repo is the official implementation of "[**Bilateral Reference for High-Resolution Dichotomous Image Segmentation**](https://arxiv.org/pdf/2401.03407)" (___CAAI AIR 2024___).

> [!note]
> **We need more GPU resources** to push forward the performance of BiRefNet, especially on *matting* tasks, higher-resolution inference (*2K*), and more *efficient* model design. If you are happy to cooperate, please contact me at zhengpeng0108@gmail.com.
- - -## :rocket: Load BiRefNet in _ONE LINE_ by HuggingFace, check more: [![BiRefNet](https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-Models-blue)](https://huggingface.co/ZhengPeng7/birefnet) -```python -from transformers import AutoModelForImageSegmentation -birefnet = AutoModelForImageSegmentation.from_pretrained('zhengpeng7/BiRefNet', trust_remote_code=True) -``` -## :flight_arrival: Inference Partner: -We are really happy to collaborate with [FAL](https://fal.ai) to deploy the **inference API** of BiRefNet. You can access this service via the link below: -+ https://fal.ai/models/fal-ai/birefnet - -Our BiRefNet has achieved SOTA on many similar HR tasks: - -**DIS**: [![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/bilateral-reference-for-high-resolution/dichotomous-image-segmentation-on-dis-te1)](https://paperswithcode.com/sota/dichotomous-image-segmentation-on-dis-te1?p=bilateral-reference-for-high-resolution) [![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/bilateral-reference-for-high-resolution/dichotomous-image-segmentation-on-dis-te2)](https://paperswithcode.com/sota/dichotomous-image-segmentation-on-dis-te2?p=bilateral-reference-for-high-resolution) [![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/bilateral-reference-for-high-resolution/dichotomous-image-segmentation-on-dis-te3)](https://paperswithcode.com/sota/dichotomous-image-segmentation-on-dis-te3?p=bilateral-reference-for-high-resolution) [![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/bilateral-reference-for-high-resolution/dichotomous-image-segmentation-on-dis-te4)](https://paperswithcode.com/sota/dichotomous-image-segmentation-on-dis-te4?p=bilateral-reference-for-high-resolution) [![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/bilateral-reference-for-high-resolution/dichotomous-image-segmentation-on-dis-vd)](https://paperswithcode.com/sota/dichotomous-image-segmentation-on-dis-vd?p=bilateral-reference-for-high-resolution) - -
Figure of Comparison on DIS Papers with Codes (by the time of this work).
- -**COD**:[![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/bilateral-reference-for-high-resolution/camouflaged-object-segmentation-on-cod)](https://paperswithcode.com/sota/camouflaged-object-segmentation-on-cod?p=bilateral-reference-for-high-resolution) [![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/bilateral-reference-for-high-resolution/camouflaged-object-segmentation-on-nc4k)](https://paperswithcode.com/sota/camouflaged-object-segmentation-on-nc4k?p=bilateral-reference-for-high-resolution) [![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/bilateral-reference-for-high-resolution/camouflaged-object-segmentation-on-camo)](https://paperswithcode.com/sota/camouflaged-object-segmentation-on-camo?p=bilateral-reference-for-high-resolution) [![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/bilateral-reference-for-high-resolution/camouflaged-object-segmentation-on-chameleon)](https://paperswithcode.com/sota/camouflaged-object-segmentation-on-chameleon?p=bilateral-reference-for-high-resolution) - -
Figure of Comparison on COD Papers with Codes (by the time of this work).
- -**HRSOD**: [![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/bilateral-reference-for-high-resolution/rgb-salient-object-detection-on-davis-s)](https://paperswithcode.com/sota/rgb-salient-object-detection-on-davis-s?p=bilateral-reference-for-high-resolution) [![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/bilateral-reference-for-high-resolution/rgb-salient-object-detection-on-hrsod)](https://paperswithcode.com/sota/rgb-salient-object-detection-on-hrsod?p=bilateral-reference-for-high-resolution) [![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/bilateral-reference-for-high-resolution/rgb-salient-object-detection-on-uhrsd)](https://paperswithcode.com/sota/rgb-salient-object-detection-on-uhrsd?p=bilateral-reference-for-high-resolution) [![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/bilateral-reference-for-high-resolution/salient-object-detection-on-duts-te)](https://paperswithcode.com/sota/salient-object-detection-on-duts-te?p=bilateral-reference-for-high-resolution) [![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/bilateral-reference-for-high-resolution/salient-object-detection-on-dut-omron)](https://paperswithcode.com/sota/salient-object-detection-on-dut-omron?p=bilateral-reference-for-high-resolution) - -
Figure of Comparison on HRSOD Papers with Codes (by the time of this work).
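Complementing the one-line Hugging Face loading above, here is a minimal local-inference sketch. It is an illustration rather than the reference pipeline: the `1024x1024` input size and ImageNet normalization match the defaults used elsewhere in this repo (see `dataset.py`), while taking the last element of the model output and applying `sigmoid` follows common BiRefNet usage and should be verified against your installed version; the image path is hypothetical.

```python
import torch
from PIL import Image
from torchvision import transforms
from transformers import AutoModelForImageSegmentation

# One-line loading from the Hub, as shown above.
birefnet = AutoModelForImageSegmentation.from_pretrained('zhengpeng7/BiRefNet', trust_remote_code=True)
birefnet.eval()

# Default preprocessing: 1024x1024 input with ImageNet mean/std (same values as in dataset.py).
transform = transforms.Compose([
    transforms.Resize((1024, 1024)),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
])

image = Image.open('example.jpg').convert('RGB')  # hypothetical input image
inputs = transform(image).unsqueeze(0)

with torch.no_grad():
    # Assumption: the model returns a list of side outputs whose last element is the final logits map.
    pred = birefnet(inputs)[-1].sigmoid().cpu()[0]

# Convert the [0, 1] map back to a PIL mask at the original resolution.
mask = transforms.ToPILImage()(pred).resize(image.size)
mask.save('example_mask.png')
```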
- -#### Try our online demos for inference: - -+ **Inference and evaluation** of your given weights: [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1MaEiBfJ4xIaZZn0DqKrhydHB8X97hNXl) -+ **Online Inference with GUI** with adjustable resolutions: [![Hugging Face Spaces](https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-Spaces-blue)](https://huggingface.co/spaces/ZhengPeng7/BiRefNet_demo) -+ Online **Multiple Images Inference** on Colab: [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/14Dqg7oeBkFEtchaHLNpig2BcdkZEogba) - - - - - -## Model Zoo - -> For more general use of our BiRefNet, I extended the original academic one to more general ones for better real-life application. -> -> Datasets and datasets are suggested to be downloaded from official pages. But you can also download the packaged ones: [DIS](https://drive.google.com/drive/folders/1hZW6tAGPJwo9mPS7qGGGdpxuvuXiyoMJ), [HRSOD](https://drive.google.com/drive/folders/18_hAE3QM4cwAzEAKXuSNtKjmgFXTQXZN), [COD](https://drive.google.com/drive/folders/1EyHmKWsXfaCR9O0BiZEc3roZbRcs4ECO), [Backbones](https://drive.google.com/drive/folders/1cmce_emsS8A5ha5XT2c_CZiJzlLM81ms). -> -> Find performances (almost all metrics) of all models in the `exp-TASK_SETTINGS` folders in [[**stuff**](https://drive.google.com/drive/folders/1s2Xe0cjq-2ctnJBR24563yMSCOu4CcxM)]. - - - -
Models in the original paper, for comparison on benchmarks: - -| Task | Training Sets | Backbone | Download | -| :---: | :-------------------------: | :-----------: | :----------------------------------------------------------: | -| DIS | DIS5K-TR | swin_v1_large | [google-drive](https://drive.google.com/file/d/1J90LucvDQaS3R_-9E7QUh1mgJ8eQvccb/view) | -| COD | COD10K-TR, CAMO-TR | swin_v1_large | [google-drive](https://drive.google.com/file/d/1tM5M72k7a8aKF-dYy-QXaqvfEhbFaWkC/view) | -| HRSOD | DUTS-TR | swin_v1_large | [google-drive](https://drive.google.com/file/d/1f7L0Pb1Y3RkOMbqLCW_zO31dik9AiUFa/view) | -| HRSOD | HRSOD-TR | swin_v1_large | google-drive | -| HRSOD | UHRSD-TR | swin_v1_large | google-drive | -| HRSOD | DUTS-TR, HRSOD-TR | swin_v1_large | [google-drive](https://drive.google.com/file/d/1WJooyTkhoDLllaqwbpur_9Hle0XTHEs_/view) | -| HRSOD | DUTS-TR, UHRSD-TR | swin_v1_large | [google-drive](https://drive.google.com/file/d/1Pu1mv3ORobJatIuUoEuZaWDl2ylP3Gw7/view) | -| HRSOD | HRSOD-TR, UHRSD-TR | swin_v1_large | [google-drive](https://drive.google.com/file/d/1xEh7fsgWGaS5c3IffMswasv0_u-aVM9E/view) | -| HRSOD | DUTS-TR, HRSOD-TR, UHRSD-TR | swin_v1_large | [google-drive](https://drive.google.com/file/d/13FaxyyOwyCddfZn2vZo1xG1KNZ3cZ-6B/view) | - -
Models trained with customed data (general, matting), for general use in practical application: - -| Task | Training Sets | Backbone | Test Set | Metric (S, wF[, HCE]) | Download | -| :-----------------------: | :----------------------------------------------------------: | :-----------: | :-------: | :-------------------: | :----------------------------------------------------------: | -| **general use** | DIS5K-TR,DIS-TEs, DUTS-TR_TE,HRSOD-TR_TE,UHRSD-TR_TE, HRS10K-TR_TE, TR-P3M-10k, TE-P3M-500-NP, TE-P3M-500-P, TR-humans | swin_v1_large | DIS-VD | 0.911, 0.875, 1069 | [google-drive](https://drive.google.com/file/d/1_IfUnu8Fpfn-nerB89FzdNXQ7zk6FKxc/view) | -| **general use** | DIS5K-TR,DIS-TEs, DUTS-TR_TE,HRSOD-TR_TE,UHRSD-TR_TE, HRS10K-TR_TE, TR-P3M-10k, TE-P3M-500-NP, TE-P3M-500-P, TR-humans | swin_v1_tiny | DIS-VD | 0.882, 0.830, 1175 | [google-drive](https://drive.google.com/file/d/1fzInDWiE2n65tmjaHDSZpqhL0VME6-Yl/view) | -| **general use** | DIS5K-TR, DIS-TEs | swin_v1_large | DIS-VD | 0.907, 0.865, 1059 | [google-drive](https://drive.google.com/file/d/1P6NJzG3Jf1sl7js2q1CPC3yqvBn_O8UJ/view) | -| **matting segmentation** | [P3M-10k](https://github.com/JizhiziLi/P3M), [humans](https://huggingface.co/datasets/schirrmacher/humans) | swin_v1_large | P3M-500-P | 0.983, 0.989 | [google-drive](https://drive.google.com/file/d/1uUeXjEUoD2XF_6YjD_fsct-TJp7TFiqh) | - -
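If you prefer the `.pth` checkpoints listed above to the Hub model, a rough loading sketch is shown below. The `models.birefnet` import path and the `bb_pretrained` switch (to skip fetching backbone weights) are assumptions based on the released code and may differ in your checkout; the checkpoint filename is hypothetical and only follows the `BiRefNet-{TASK}-{EPOCH}.pth` naming described in the weights section below.

```python
import torch
from models.birefnet import BiRefNet  # assumed in-repo import path

# Build the network without fetching ImageNet backbone weights (assumed `bb_pretrained` flag),
# then load one of the released checkpoints, e.g. a general-use SwinL one (hypothetical filename).
model = BiRefNet(bb_pretrained=False)
state_dict = torch.load('BiRefNet-general-epoch_244.pth', map_location='cpu')
model.load_state_dict(state_dict)
model.eval()
```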
Segmentation with box guidance:

+ Given box guidance: [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1B6aKZ3ekcvKMkSBn0N5mCASLUYMp0whK)
Model efficiency:

> Screenshot from the original paper. All tests are conducted on a single A100 GPU.
ONNX conversion:

> We converted the `.pth` weights files to `.onnx` files.
> We referred a lot to [Kazuhito00/BiRefNet-ONNX-Sample](https://github.com/Kazuhito00/BiRefNet-ONNX-Sample); many thanks to @Kazuhito00.

+ Check our [Colab demo for ONNX conversion](https://colab.research.google.com/drive/1z6OruR52LOvDDpnp516F-N4EyPGrp5om) or the [notebook file for local running](https://drive.google.com/file/d/1cgL2qyvOO5q3ySfhytypX46swdQwZLrJ), where you can do the conversion/inference by yourself and find all relevant info.
+ As tested, BiRefNet with SwinL (the default backbone) costs `~90%` more time with the ONNX files (inference takes `~165ms` on an A100 GPU), while BiRefNet with SwinT (lightweight) costs `~75%` more time (`~93.8ms` on an A100 GPU). The default input resolution is `1024x1024`.
+ The results of the original pth files and the converted onnx files are slightly different, which is acceptable.
+ Pay attention to the compatibility among `onnxruntime-gpu`, CUDA, and cuDNN (we use `torch==2.0.1, cuda=11.8` here).
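For running the converted `.onnx` files outside the notebooks, the sketch below uses plain `onnxruntime`. It assumes a single `1x3x1024x1024` float input preprocessed the same way as for the `.pth` model and that the last graph output is the final logits map; both points, as well as the file name, should be checked against your exported model.

```python
import numpy as np
import onnxruntime as ort

# Fall back to CPU if the onnxruntime-gpu / CUDA / cuDNN versions don't match
# (see the compatibility note above).
session = ort.InferenceSession(
    'BiRefNet-general.onnx',  # hypothetical file name
    providers=['CUDAExecutionProvider', 'CPUExecutionProvider'],
)

input_name = session.get_inputs()[0].name
output_names = [output.name for output in session.get_outputs()]

# Dummy 1024x1024 RGB input; in practice, resize + ImageNet-normalize a real image first.
x = np.random.rand(1, 3, 1024, 1024).astype(np.float32)

outputs = session.run(output_names, {input_name: x})
logits = outputs[-1]                    # assumption: the last output is the final map
mask = 1.0 / (1.0 + np.exp(-logits))    # sigmoid -> [0, 1] foreground map
print(mask.shape)                       # expected: (1, 1, 1024, 1024)
```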
- -## Third-Party Creations - -> Concerning edge devices with less computing power, we provide a lightweight version with `swin_v1_tiny` as the backbone, which is x4+ faster and x5+ smaller. The details can be found in [this issue](https://github.com/ZhengPeng7/BiRefNet/issues/11#issuecomment-2041033576) and links there. - -We found there've been some 3rd party applications based on our BiRefNet. Many thanks for their contribution to the community! -Choose the one you like to try with clicks instead of codes: -1. **Applications**: - + Thanks [**lbq779660843/BiRefNet-Tensorrt**](https://github.com/lbq779660843/BiRefNet-Tensorrt) and [**yuanyang1991/birefnet_tensorrt**](https://github.com/yuanyang1991/birefnet_tensorrt): they both provided the project to convert BiRefNet to **TensorRT**, which is faster and better for deployment. Their repos offer solid local establishment (Win and Linux) and [colab demo](https://colab.research.google.com/drive/1r8GkFPyMMO0OkMX6ih5FjZnUCQrl2SHV?usp=sharing), respectively. And @yuanyang1991 kindly offered the comparison among the inference efficiency of naive PyTorch, ONNX, and TensorRT on an RTX 4080S: - -| Methods | [Pytorch](https://drive.google.com/file/d/1_IfUnu8Fpfn-nerB89FzdNXQ7zk6FKxc/view) | [ONNX](https://drive.google.com/drive/u/0/folders/1kZM55bwsRdS__bdnsXpkmH6QPyza-9-N) | TensorRT | -|:------------------------------------------------------------------------------------:|:--------------:|:--------------:|:--------------:| -|        First Inference Time       | 0.71s | 5.32s | **0.17s** | - -| Methods | [Pytorch](https://drive.google.com/file/d/1_IfUnu8Fpfn-nerB89FzdNXQ7zk6FKxc/view) | [ONNX](https://drive.google.com/drive/u/0/folders/1kZM55bwsRdS__bdnsXpkmH6QPyza-9-N) | TensorRT | -|:------------------------------------------------------------------------------------:|:--------------:|:--------------:|:--------------:| -| Avg Inf Time (excluding 1st) | 0.15s | 4.43s | **0.11s** | - - + Thanks [**dimitribarbot/sd-webui-birefnet**](https://github.com/dimitribarbot/sd-webui-birefnet): this project allows to add a BiRefNet section to the original **Stable Diffusion WebUI**'s Extras tab. -

- - + Thanks [**fal.ai/birefnet**](https://fal.ai/models/birefnet): this project on `fal.ai` encapsulates BiRefNet **online** with more useful options in **UI** and **API** to call the model. -

- - + Thanks [**ZHO-ZHO-ZHO/ComfyUI-BiRefNet-ZHO**](https://github.com/ZHO-ZHO-ZHO/ComfyUI-BiRefNet-ZHO): this project further improves the **UI** for BiRefNet in ComfyUI, especially for **video data**. -

  + Thanks [**viperyl/ComfyUI-BiRefNet**](https://github.com/viperyl/ComfyUI-BiRefNet): this project packs BiRefNet as **ComfyUI nodes**, and makes this SOTA model easier to use for everyone.

- - + Thanks [**Rishabh**](https://github.com/rishabh063) for offering a demo for the [easier multiple images inference on colab](https://colab.research.google.com/drive/14Dqg7oeBkFEtchaHLNpig2BcdkZEogba). - -2. **More Visual Comparisons** - + Thanks [**twitter.com/ZHOZHO672070**](https://twitter.com/ZHOZHO672070) for the comparison with more background-removal methods in images: - - - - + Thanks [**twitter.com/toyxyz3**](https://twitter.com/toyxyz3) for the comparison with more background-removal methods in videos: - - - - - - -## Usage - -#### Environment Setup - -```shell -# PyTorch==2.0.1 is used for faster training with compilation. -conda create -n birefnet python=3.9 -y && conda activate birefnet -pip install -r requirements.txt -``` - -#### Dataset Preparation - -Download combined training / test sets I have organized well from: [DIS](https://drive.google.com/drive/folders/1hZW6tAGPJwo9mPS7qGGGdpxuvuXiyoMJ)--[COD](https://drive.google.com/drive/folders/1EyHmKWsXfaCR9O0BiZEc3roZbRcs4ECO)--[HRSOD](https://drive.google.com/drive/folders/18_hAE3QM4cwAzEAKXuSNtKjmgFXTQXZN) or the single official ones in the `single_ones` folder, or their official pages. You can also find the same ones on my **BaiduDisk**: [DIS](https://pan.baidu.com/s/1O_pQIGAE4DKqL93xOxHpxw?pwd=PSWD)--[COD](https://pan.baidu.com/s/1RnxAzaHSTGBC1N6r_RfeqQ?pwd=PSWD)--[HRSOD](https://pan.baidu.com/s/1_Del53_0lBuG0DKJJAk4UA?pwd=PSWD). - -#### Weights Preparation - -Download backbone weights from [my google-drive folder](https://drive.google.com/drive/folders/1s2Xe0cjq-2ctnJBR24563yMSCOu4CcxM) or their official pages. - -## Run - -```shell -# Train & Test & Evaluation -./train_test.sh RUN_NAME GPU_NUMBERS_FOR_TRAINING GPU_NUMBERS_FOR_TEST -# Example: ./train_test.sh tmp-proj 0,1,2,3,4,5,6,7 0 - -# See train.sh / test.sh for only training / test-evaluation. -# After the evaluation, run `gen_best_ep.py` to select the best ckpt from a specific metric (you choose it from Sm, wFm, HCE (DIS only)). -``` - -#### Well-trained weights: - -Download the `BiRefNet-{TASK}-{EPOCH}.pth` from [[**stuff**](https://drive.google.com/drive/folders/1s2Xe0cjq-2ctnJBR24563yMSCOu4CcxM)]. Info of the corresponding (predicted\_maps/performance/training\_log) weights can be also found in folders like `exp-BiRefNet-{TASK_SETTINGS}` in the same directory. - -You can also download the weights from the release of this repo. - -The results might be a bit different from those in the original paper, you can see them in the `eval_results-BiRefNet-{TASK_SETTINGS}` folder in each `exp-xx`, we will update them in the following days. Due to the very high cost I used (A100-80G x 8) which many people cannot afford to (including myself....), I re-trained BiRefNet on a single A100-40G only and achieve the performance on the same level (even better). It means you can directly train the model on a single GPU with 36.5G+ memory. BTW, 5.5G GPU memory is needed for inference in 1024x1024. (I personally paid a lot for renting an A100-40G to re-train BiRefNet on the three tasks... T_T. Hope it can help you.) - -But if you have more and more powerful GPUs, you can set GPU IDs and increase the batch size in `config.py` to accelerate the training. We have made all this kind of things adaptive in scripts to seamlessly switch between single-card training and multi-card training. Enjoy it :) - -#### Some of my messages: - -This project was originally built for DIS only. But after the updates one by one, I made it larger and larger with many functions embedded together. 
Finally, you can **use it for any binary image segmentation tasks**, such as DIS/COD/SOD, medical image segmentation, anomaly segmentation, etc. You can easily turn the following on/off (usually in `config.py`):
+ Multi-GPU training: open/close with one variable.
+ Backbone choices: Swin_v1, PVT_v2, ConvNets, ...
+ Weighted losses: BCE, IoU, SSIM, MAE, Reg, ...
+ Adversarial loss for binary segmentation (proposed in my previous work [MCCL](https://arxiv.org/pdf/2302.14485)).
+ Training tricks: multi-scale supervision, freezing backbone, multi-scale input...
+ Data collator: loading all in memory, smooth combination of different datasets for combined training and test.
+ ...

I really hope you enjoy this project and use it in more works to achieve new SOTAs.

### Quantitative Results


### Qualitative Results


- - - -### Citation - -``` -@article{zheng2024birefnet, - title={Bilateral Reference for High-Resolution Dichotomous Image Segmentation}, - author={Zheng, Peng and Gao, Dehong and Fan, Deng-Ping and Liu, Li and Laaksonen, Jorma and Ouyang, Wanli and Sebe, Nicu}, - journal={CAAI Artificial Intelligence Research}, - volume = {3}, - pages = {9150038}, - year={2024} -} -``` - - - -## Contact - -Any questions, discussions, or even complaints, feel free to leave issues here or send me e-mails (zhengpeng0108@gmail.com). You can also join the Discord Group (https://discord.gg/d9NN5sgFrq) or QQ Group (https://qm.qq.com/q/y6WPy7WOIK) if you want to talk a lot publicly. - diff --git a/py/BiRefNet/__init__.py b/py/BiRefNet/__init__.py deleted file mode 100644 index e69de29b..00000000 diff --git a/py/BiRefNet/config.py b/py/BiRefNet/config.py deleted file mode 100644 index c333160a..00000000 --- a/py/BiRefNet/config.py +++ /dev/null @@ -1,174 +0,0 @@ -import os -import math - - -class Config(): - def __init__(self) -> None: - # PATH settings - # Make up your file system as: SYS_HOME_DIR/codes/dis/BiRefNet, SYS_HOME_DIR/datasets/dis/xx, SYS_HOME_DIR/weights/xx - if os.name == 'nt': - self.sys_home_dir = os.environ['USERPROFILE'] # For windows system - else: - self.sys_home_dir = os.environ['HOME'] # For Linux system - - # TASK settings - self.task = ['DIS5K', 'COD', 'HRSOD', 'General', 'Matting'][0] - self.training_set = { - 'DIS5K': ['DIS-TR', 'DIS-TR+DIS-TE1+DIS-TE2+DIS-TE3+DIS-TE4'][0], - 'COD': 'TR-COD10K+TR-CAMO', - 'HRSOD': ['TR-DUTS', 'TR-HRSOD', 'TR-UHRSD', 'TR-DUTS+TR-HRSOD', 'TR-DUTS+TR-UHRSD', 'TR-HRSOD+TR-UHRSD', 'TR-DUTS+TR-HRSOD+TR-UHRSD'][5], - 'General': 'DIS-TE1+DIS-TE2+DIS-TE3+DIS-TE4+DIS-TR+TR-HRSOD+TE-HRSOD+TR-HRS10K+TE-HRS10K+TR-UHRSD+TE-UHRSD+TR-P3M-10k+TE-P3M-500-NP+TE-P3M-500-P+TR-humans', # leave DIS-VD for evaluation. - 'Matting': 'TR-P3M-10k+TE-P3M-500-NP+TR-humans+TR-Distrinctions-646', - }[self.task] - self.prompt4loc = ['dense', 'sparse'][0] - - # Faster-Training settings - self.load_all = False # Turn it on/off by your case. It may consume a lot of CPU memory. And for multi-GPU (N), it would cost N times the CPU memory to load the data. - self.use_fp16 = False # It may cause nan in training. - self.compile = True and (not self.use_fp16) # 1. Trigger CPU memory leak in some extend, which is an inherent problem of PyTorch. - # Machines with > 70GB CPU memory can run the whole training on DIS5K with default setting. - # 2. Higher PyTorch version may fix it: https://github.com/pytorch/pytorch/issues/119607. - # 3. But compile in Pytorch > 2.0.1 seems to bring no acceleration for training. - self.precisionHigh = True - - # MODEL settings - self.ms_supervision = True - self.out_ref = self.ms_supervision and True - self.dec_ipt = True - self.dec_ipt_split = True - self.cxt_num = [0, 3][1] # multi-scale skip connections from encoder - self.mul_scl_ipt = ['', 'add', 'cat'][2] - self.dec_att = ['', 'ASPP', 'ASPPDeformable'][2] - self.squeeze_block = ['', 'BasicDecBlk_x1', 'ResBlk_x4', 'ASPP_x3', 'ASPPDeformable_x3'][1] - self.dec_blk = ['BasicDecBlk', 'ResBlk'][0] - - # TRAINING settings - self.batch_size = 4 - self.finetune_last_epochs = [ - ('IoU', 0), - { - 'DIS5K': ('IoU', -30), - 'COD': ('IoU', -20), - 'HRSOD': ('IoU', -20), - 'General': ('MAE', -10), - 'Matting': ('MAE', -10), - }[self.task] - ][1] # choose 0 to skip - self.lr = (1e-4 if 'DIS5K' in self.task else 1e-5) * math.sqrt(self.batch_size / 4) # DIS needs high lr to converge faster. 
Adapt the lr linearly - self.size = 1024 - self.num_workers = max(4, self.batch_size) # will be decrease to min(it, batch_size) at the initialization of the data_loader - - # Backbone settings - self.bb = [ - 'vgg16', 'vgg16bn', 'resnet50', # 0, 1, 2 - 'swin_v1_t', 'swin_v1_s', # 3, 4 - 'swin_v1_b', 'swin_v1_l', # 5-bs9, 6-bs4 - 'pvt_v2_b0', 'pvt_v2_b1', # 7, 8 - 'pvt_v2_b2', 'pvt_v2_b5', # 9-bs10, 10-bs5 - ][6] - self.lateral_channels_in_collection = { - 'vgg16': [512, 256, 128, 64], 'vgg16bn': [512, 256, 128, 64], 'resnet50': [1024, 512, 256, 64], - 'pvt_v2_b2': [512, 320, 128, 64], 'pvt_v2_b5': [512, 320, 128, 64], - 'swin_v1_b': [1024, 512, 256, 128], 'swin_v1_l': [1536, 768, 384, 192], - 'swin_v1_t': [768, 384, 192, 96], 'swin_v1_s': [768, 384, 192, 96], - 'pvt_v2_b0': [256, 160, 64, 32], 'pvt_v2_b1': [512, 320, 128, 64], - }[self.bb] - if self.mul_scl_ipt == 'cat': - self.lateral_channels_in_collection = [channel * 2 for channel in self.lateral_channels_in_collection] - self.cxt = self.lateral_channels_in_collection[1:][::-1][-self.cxt_num:] if self.cxt_num else [] - - # MODEL settings - inactive - self.lat_blk = ['BasicLatBlk'][0] - self.dec_channels_inter = ['fixed', 'adap'][0] - self.refine = ['', 'itself', 'RefUNet', 'Refiner', 'RefinerPVTInChannels4'][0] - self.progressive_ref = self.refine and True - self.ender = self.progressive_ref and False - self.scale = self.progressive_ref and 2 - self.auxiliary_classification = False # Only for DIS5K, where class labels are saved in `dataset.py`. - self.refine_iteration = 1 - self.freeze_bb = False - self.model = [ - 'BiRefNet', - ][0] - - # TRAINING settings - inactive - self.preproc_methods = ['flip', 'enhance', 'rotate', 'pepper', 'crop'][:4] - self.optimizer = ['Adam', 'AdamW'][1] - self.lr_decay_epochs = [1e5] # Set to negative N to decay the lr in the last N-th epoch. - self.lr_decay_rate = 0.5 - # Loss - if self.task not in ['Matting']: - self.lambdas_pix_last = { - # not 0 means opening this loss - # original rate -- 1 : 30 : 1.5 : 0.2, bce x 30 - 'bce': 30 * 1, # high performance - 'iou': 0.5 * 1, # 0 / 255 - 'iou_patch': 0.5 * 0, # 0 / 255, win_size = (64, 64) - 'mae': 30 * 0, - 'mse': 30 * 0, # can smooth the saliency map - 'triplet': 3 * 0, - 'reg': 100 * 0, - 'ssim': 10 * 1, # help contours, - 'cnt': 5 * 0, # help contours - 'structure': 5 * 0, # structure loss from codes of MVANet. A little improvement on DIS-TE[1,2,3], a bit more decrease on DIS-TE4. - } - else: - self.lambdas_pix_last = { - # not 0 means opening this loss - # original rate -- 1 : 30 : 1.5 : 0.2, bce x 30 - 'bce': 30 * 0, # high performance - 'iou': 0.5 * 0, # 0 / 255 - 'iou_patch': 0.5 * 0, # 0 / 255, win_size = (64, 64) - 'mae': 100 * 1, - 'mse': 30 * 0, # can smooth the saliency map - 'triplet': 3 * 0, - 'reg': 100 * 0, - 'ssim': 10 * 1, # help contours, - 'cnt': 5 * 0, # help contours - 'structure': 5 * 0, # structure loss from codes of MVANet. A little improvement on DIS-TE[1,2,3], a bit more decrease on DIS-TE4. - } - self.lambdas_cls = { - 'ce': 5.0 - } - # Adv - self.lambda_adv_g = 10. * 0 # turn to 0 to avoid adv training - self.lambda_adv_d = 3. 
* (self.lambda_adv_g > 0) - - # PATH settings - inactive - self.data_root_dir = os.path.join(self.sys_home_dir, 'datasets/dis') - self.weights_root_dir = os.path.join(self.sys_home_dir, 'weights') - self.weights = { - 'pvt_v2_b2': os.path.join(self.weights_root_dir, 'pvt_v2_b2.pth'), - 'pvt_v2_b5': os.path.join(self.weights_root_dir, ['pvt_v2_b5.pth', 'pvt_v2_b5_22k.pth'][0]), - 'swin_v1_b': os.path.join(self.weights_root_dir, ['swin_base_patch4_window12_384_22kto1k.pth', 'swin_base_patch4_window12_384_22k.pth'][0]), - 'swin_v1_l': os.path.join(self.weights_root_dir, ['swin_large_patch4_window12_384_22kto1k.pth', 'swin_large_patch4_window12_384_22k.pth'][0]), - 'swin_v1_t': os.path.join(self.weights_root_dir, ['swin_tiny_patch4_window7_224_22kto1k_finetune.pth'][0]), - 'swin_v1_s': os.path.join(self.weights_root_dir, ['swin_small_patch4_window7_224_22kto1k_finetune.pth'][0]), - 'pvt_v2_b0': os.path.join(self.weights_root_dir, ['pvt_v2_b0.pth'][0]), - 'pvt_v2_b1': os.path.join(self.weights_root_dir, ['pvt_v2_b1.pth'][0]), - } - - # Callbacks - inactive - self.verbose_eval = True - self.only_S_MAE = False - self.SDPA_enabled = False # Bugs. Slower and errors occur in multi-GPUs - - # others - self.device = [0, 'cpu'][0] # .to(0) == .to('cuda:0') - - self.batch_size_valid = 1 - self.rand_seed = 7 - run_sh_file = [f for f in os.listdir('.') if 'train.sh' == f] + [os.path.join('..', f) for f in os.listdir('..') if 'train.sh' == f] - if run_sh_file: - with open(run_sh_file[0], 'r') as f: - lines = f.readlines() - self.save_last = int([l.strip() for l in lines if '"{}")'.format(self.task) in l and 'val_last=' in l][0].split('val_last=')[-1].split()[0]) - - def print_task(self) -> None: - # Return task for choosing settings in shell scripts. - print(self.task) - -if __name__ == '__main__': - config = Config() - config.print_task() - diff --git a/py/BiRefNet/dataset.py b/py/BiRefNet/dataset.py deleted file mode 100644 index c332b351..00000000 --- a/py/BiRefNet/dataset.py +++ /dev/null @@ -1,118 +0,0 @@ -import os -import cv2 -from tqdm import tqdm -from PIL import Image -from torch.utils import data -from torchvision import transforms - -from BiRefNet.image_proc import preproc -from BiRefNet.config import Config -from BiRefNet.utils import path_to_image - - -Image.MAX_IMAGE_PIXELS = None # remove DecompressionBombWarning -config = Config() -_class_labels_TR_sorted = ( - 'Airplane, Ant, Antenna, Archery, Axe, BabyCarriage, Bag, BalanceBeam, Balcony, Balloon, Basket, BasketballHoop, Beatle, Bed, Bee, Bench, Bicycle, ' - 'BicycleFrame, BicycleStand, Boat, Bonsai, BoomLift, Bridge, BunkBed, Butterfly, Button, Cable, CableLift, Cage, Camcorder, Cannon, Canoe, Car, ' - 'CarParkDropArm, Carriage, Cart, Caterpillar, CeilingLamp, Centipede, Chair, Clip, Clock, Clothes, CoatHanger, Comb, ConcretePumpTruck, Crack, Crane, ' - 'Cup, DentalChair, Desk, DeskChair, Diagram, DishRack, DoorHandle, Dragonfish, Dragonfly, Drum, Earphone, Easel, ElectricIron, Excavator, Eyeglasses, ' - 'Fan, Fence, Fencing, FerrisWheel, FireExtinguisher, Fishing, Flag, FloorLamp, Forklift, GasStation, Gate, Gear, Goal, Golf, GymEquipment, Hammock, ' - 'Handcart, Handcraft, Handrail, HangGlider, Harp, Harvester, Headset, Helicopter, Helmet, Hook, HorizontalBar, Hydrovalve, IroningTable, Jewelry, Key, ' - 'KidsPlayground, Kitchenware, Kite, Knife, Ladder, LaundryRack, Lightning, Lobster, Locust, Machine, MachineGun, MagazineRack, Mantis, Medal, MemorialArchway, ' - 'Microphone, Missile, MobileHolder, Monitor, Mosquito, Motorcycle, 
MovingTrolley, Mower, MusicPlayer, MusicStand, ObservationTower, Octopus, OilWell, ' - 'OlympicLogo, OperatingTable, OutdoorFitnessEquipment, Parachute, Pavilion, Piano, Pipe, PlowHarrow, PoleVault, Punchbag, Rack, Racket, Rifle, Ring, Robot, ' - 'RockClimbing, Rope, Sailboat, Satellite, Scaffold, Scale, Scissor, Scooter, Sculpture, Seadragon, Seahorse, Seal, SewingMachine, Ship, Shoe, ShoppingCart, ' - 'ShoppingTrolley, Shower, Shrimp, Signboard, Skateboarding, Skeleton, Skiing, Spade, SpeedBoat, Spider, Spoon, Stair, Stand, Stationary, SteeringWheel, ' - 'Stethoscope, Stool, Stove, StreetLamp, SweetStand, Swing, Sword, TV, Table, TableChair, TableLamp, TableTennis, Tank, Tapeline, Teapot, Telescope, Tent, ' - 'TobaccoPipe, Toy, Tractor, TrafficLight, TrafficSign, Trampoline, TransmissionTower, Tree, Tricycle, TrimmerCover, Tripod, Trombone, Truck, Trumpet, Tuba, ' - 'UAV, Umbrella, UnevenBars, UtilityPole, VacuumCleaner, Violin, Wakesurfing, Watch, WaterTower, WateringPot, Well, WellLid, Wheel, Wheelchair, WindTurbine, Windmill, WineGlass, WireWhisk, Yacht' -) -class_labels_TR_sorted = _class_labels_TR_sorted.split(', ') - - -class MyData(data.Dataset): - def __init__(self, datasets, image_size, is_train=True): - self.size_train = image_size - self.size_test = image_size - self.keep_size = not config.size - self.data_size = (config.size, config.size) - self.is_train = is_train - self.load_all = config.load_all - self.device = config.device - valid_extensions = ['.png', '.jpg', '.PNG', '.JPG', '.JPEG'] - - if self.is_train and config.auxiliary_classification: - self.cls_name2id = {_name: _id for _id, _name in enumerate(class_labels_TR_sorted)} - self.transform_image = transforms.Compose([ - transforms.Resize(self.data_size), - transforms.ToTensor(), - transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]), - ][self.load_all or self.keep_size:]) - self.transform_label = transforms.Compose([ - transforms.Resize(self.data_size), - transforms.ToTensor(), - ][self.load_all or self.keep_size:]) - dataset_root = os.path.join(config.data_root_dir, config.task) - # datasets can be a list of different datasets for training on combined sets. 
- self.image_paths = [] - for dataset in datasets.split('+'): - image_root = os.path.join(dataset_root, dataset, 'im') - self.image_paths += [os.path.join(image_root, p) for p in os.listdir(image_root) if any(p.endswith(ext) for ext in valid_extensions)] - self.label_paths = [] - for p in self.image_paths: - for ext in valid_extensions: - ## 'im' and 'gt' may need modifying - p_gt = p.replace('/im/', '/gt/')[:-(len(p.split('.')[-1])+1)] + ext - file_exists = False - if os.path.exists(p_gt): - self.label_paths.append(p_gt) - file_exists = True - break - if not file_exists: - print('Not exists:', p_gt) - - if len(self.label_paths) != len(self.image_paths): - raise ValueError(f"There are different numbers of images ({len(self.label_paths)}) and labels ({len(self.image_paths)})") - - if self.load_all: - self.images_loaded, self.labels_loaded = [], [] - self.class_labels_loaded = [] - # for image_path, label_path in zip(self.image_paths, self.label_paths): - for image_path, label_path in tqdm(zip(self.image_paths, self.label_paths), total=len(self.image_paths)): - _image = path_to_image(image_path, size=(config.size, config.size), color_type='rgb') - _label = path_to_image(label_path, size=(config.size, config.size), color_type='gray') - self.images_loaded.append(_image) - self.labels_loaded.append(_label) - self.class_labels_loaded.append( - self.cls_name2id[label_path.split('/')[-1].split('#')[3]] if self.is_train and config.auxiliary_classification else -1 - ) - - def __getitem__(self, index): - - if self.load_all: - image = self.images_loaded[index] - label = self.labels_loaded[index] - class_label = self.class_labels_loaded[index] if self.is_train and config.auxiliary_classification else -1 - else: - image = path_to_image(self.image_paths[index], size=(config.size, config.size), color_type='rgb') - label = path_to_image(self.label_paths[index], size=(config.size, config.size), color_type='gray') - class_label = self.cls_name2id[self.label_paths[index].split('/')[-1].split('#')[3]] if self.is_train and config.auxiliary_classification else -1 - - # loading image and label - if self.is_train: - image, label = preproc(image, label, preproc_methods=config.preproc_methods) - # else: - # if _label.shape[0] > 2048 or _label.shape[1] > 2048: - # _image = cv2.resize(_image, (2048, 2048), interpolation=cv2.INTER_LINEAR) - # _label = cv2.resize(_label, (2048, 2048), interpolation=cv2.INTER_LINEAR) - - image, label = self.transform_image(image), self.transform_label(label) - - if self.is_train: - return image, label, class_label - else: - return image, label, self.label_paths[index] - - def __len__(self): - return len(self.image_paths) diff --git a/py/BiRefNet/eval_existingOnes.py b/py/BiRefNet/eval_existingOnes.py deleted file mode 100644 index f8d02c28..00000000 --- a/py/BiRefNet/eval_existingOnes.py +++ /dev/null @@ -1,146 +0,0 @@ -import os -import argparse -from glob import glob -import prettytable as pt - -from BiRefNet.evaluation.evaluate import evaluator -from BiRefNet.config import Config - - -config = Config() - - -def do_eval(args): - # evaluation for whole dataset - # dataset first in evaluation - for _data_name in args.data_lst.split('+'): - pred_data_dir = sorted(glob(os.path.join(args.pred_root, args.model_lst[0], _data_name))) - if not pred_data_dir: - print('Skip dataset {}.'.format(_data_name)) - continue - gt_src = os.path.join(args.gt_root, _data_name) - gt_paths = sorted(glob(os.path.join(gt_src, 'gt', '*'))) - print('#' * 20, _data_name, '#' * 20) - filename = 
os.path.join(args.save_dir, '{}_eval.txt'.format(_data_name)) - tb = pt.PrettyTable() - tb.vertical_char = '&' - if config.task == 'DIS5K': - tb.field_names = ["Dataset", "Method", "maxFm", "wFmeasure", 'MAE', "Smeasure", "meanEm", "HCE", "maxEm", "meanFm", "adpEm", "adpFm", 'mBA', 'maxBIoU', 'meanBIoU'] - elif config.task == 'COD': - tb.field_names = ["Dataset", "Method", "Smeasure", "wFmeasure", "meanFm", "meanEm", "maxEm", 'MAE', "maxFm", "adpEm", "adpFm", "HCE", 'mBA', 'maxBIoU', 'meanBIoU'] - elif config.task == 'HRSOD': - tb.field_names = ["Dataset", "Method", "Smeasure", "maxFm", "meanEm", 'MAE', "maxEm", "meanFm", "wFmeasure", "adpEm", "adpFm", "HCE", 'mBA', 'maxBIoU', 'meanBIoU'] - elif config.task == 'General': - tb.field_names = ["Dataset", "Method", "maxFm", "wFmeasure", 'MAE', "Smeasure", "meanEm", "HCE", "maxEm", "meanFm", "adpEm", "adpFm", 'mBA', 'maxBIoU', 'meanBIoU'] - elif config.task == 'Matting': - tb.field_names = ["Dataset", "Method", "Smeasure", "maxFm", "meanEm", 'MSE', "maxEm", "meanFm", "wFmeasure", "adpEm", "adpFm", "HCE", 'mBA', 'maxBIoU', 'meanBIoU'] - else: - tb.field_names = ["Dataset", "Method", "Smeasure", 'MAE', "maxEm", "meanEm", "maxFm", "meanFm", "wFmeasure", "adpEm", "adpFm", "HCE", 'mBA', 'maxBIoU', 'meanBIoU'] - for _model_name in args.model_lst[:]: - print('\t', 'Evaluating model: {}...'.format(_model_name)) - pred_paths = [p.replace(args.gt_root, os.path.join(args.pred_root, _model_name)).replace('/gt/', '/') for p in gt_paths] - # print(pred_paths[:1], gt_paths[:1]) - em, sm, fm, mae, wfm, hce, mba, biou = evaluator( - gt_paths=gt_paths, - pred_paths=pred_paths, - metrics=args.metrics.split('+'), - verbose=config.verbose_eval - ) - if config.task == 'DIS5K': - scores = [ - fm['curve'].max().round(3), wfm.round(3), mae.round(3), sm.round(3), em['curve'].mean().round(3), int(hce.round()), - em['curve'].max().round(3), fm['curve'].mean().round(3), em['adp'].round(3), fm['adp'].round(3), - mba.round(3), biou['curve'].max().round(3), biou['curve'].mean().round(3), - ] - elif config.task == 'COD': - scores = [ - sm.round(3), wfm.round(3), fm['curve'].mean().round(3), em['curve'].mean().round(3), em['curve'].max().round(3), mae.round(3), - fm['curve'].max().round(3), em['adp'].round(3), fm['adp'].round(3), int(hce.round()), - mba.round(3), biou['curve'].max().round(3), biou['curve'].mean().round(3), - ] - elif config.task == 'HRSOD': - scores = [ - sm.round(3), fm['curve'].max().round(3), em['curve'].mean().round(3), mae.round(3), - em['curve'].max().round(3), fm['curve'].mean().round(3), wfm.round(3), em['adp'].round(3), fm['adp'].round(3), int(hce.round()), - mba.round(3), biou['curve'].max().round(3), biou['curve'].mean().round(3), - ] - elif config.task == 'General': - scores = [ - fm['curve'].max().round(3), wfm.round(3), mae.round(3), sm.round(3), em['curve'].mean().round(3), int(hce.round()), - em['curve'].max().round(3), fm['curve'].mean().round(3), em['adp'].round(3), fm['adp'].round(3), - mba.round(3), biou['curve'].max().round(3), biou['curve'].mean().round(3), - ] - elif config.task == 'Matting': - scores = [ - sm.round(3), fm['curve'].max().round(3), em['curve'].mean().round(3), mse.round(3), - em['curve'].max().round(3), fm['curve'].mean().round(3), wfm.round(3), em['adp'].round(3), fm['adp'].round(3), int(hce.round()), - mba.round(3), biou['curve'].max().round(3), biou['curve'].mean().round(3), - ] - else: - scores = [ - sm.round(3), mae.round(3), em['curve'].max().round(3), em['curve'].mean().round(3), - fm['curve'].max().round(3), 
fm['curve'].mean().round(3), wfm.round(3), - em['adp'].round(3), fm['adp'].round(3), int(hce.round()), - mba.round(3), biou['curve'].max().round(3), biou['curve'].mean().round(3), - ] - - for idx_score, score in enumerate(scores): - scores[idx_score] = '.' + format(score, '.3f').split('.')[-1] if score <= 1 else format(score, '<4') - records = [_data_name, _model_name] + scores - tb.add_row(records) - # Write results after every check. - with open(filename, 'w+') as file_to_write: - file_to_write.write(str(tb)+'\n') - print(tb) - - -if __name__ == '__main__': - # set parameters - parser = argparse.ArgumentParser() - parser.add_argument( - '--gt_root', type=str, help='ground-truth root', - default=os.path.join(config.data_root_dir, config.task)) - parser.add_argument( - '--pred_root', type=str, help='prediction root', - default='./e_preds') - parser.add_argument( - '--data_lst', type=str, help='test dataset', - default={ - 'DIS5K': '+'.join(['DIS-VD', 'DIS-TE1', 'DIS-TE2', 'DIS-TE3', 'DIS-TE4'][:]), - 'COD': '+'.join(['TE-COD10K', 'NC4K', 'TE-CAMO', 'CHAMELEON'][:]), - 'HRSOD': '+'.join(['DAVIS-S', 'TE-HRSOD', 'TE-UHRSD', 'TE-DUTS', 'DUT-OMRON'][:]), - 'General': '+'.join(['DIS-VD'][:]), - 'Matting': '+'.join(['TE-P3M-500-P'][:]), - }[config.task]) - parser.add_argument( - '--save_dir', type=str, help='candidate competitors', - default='e_results') - parser.add_argument( - '--check_integrity', type=bool, help='whether to check the file integrity', - default=False) - parser.add_argument( - '--metrics', type=str, help='candidate competitors', - default='+'.join(['S', 'MAE', 'E', 'F', 'WF', 'MBA', 'BIoU', 'HCE'][:100 if 'DIS5K' in config.task else -1])) - args = parser.parse_args() - args.metrics = '+'.join(['S', 'MAE', 'E', 'F', 'WF', 'MBA', 'BIoU', 'HCE'][:100 if sum(['DIS-' in _data for _data in args.data_lst.split('+')]) else -1]) - - os.makedirs(args.save_dir, exist_ok=True) - try: - args.model_lst = [m for m in sorted(os.listdir(args.pred_root), key=lambda x: int(x.split('epoch_')[-1]), reverse=True) if int(m.split('epoch_')[-1]) % 1 == 0] - except: - args.model_lst = [m for m in sorted(os.listdir(args.pred_root))] - - # check the integrity of each candidates - if args.check_integrity: - for _data_name in args.data_lst.split('+'): - for _model_name in args.model_lst: - gt_pth = os.path.join(args.gt_root, _data_name) - pred_pth = os.path.join(args.pred_root, _model_name, _data_name) - if not sorted(os.listdir(gt_pth)) == sorted(os.listdir(pred_pth)): - print(len(sorted(os.listdir(gt_pth))), len(sorted(os.listdir(pred_pth)))) - print('The {} Dataset of {} Model is not matching to the ground-truth'.format(_data_name, _model_name)) - else: - print('>>> skip check the integrity of each candidates') - - # start engine - do_eval(args) diff --git a/py/BiRefNet/evaluation/metrics.py b/py/BiRefNet/evaluation/metrics.py deleted file mode 100644 index 76ebc45e..00000000 --- a/py/BiRefNet/evaluation/metrics.py +++ /dev/null @@ -1,763 +0,0 @@ -import os -from tqdm import tqdm -import cv2 -import numpy as np -from scipy.ndimage import convolve, distance_transform_edt as bwdist -from skimage.morphology import skeletonize -from skimage.morphology import disk -from skimage.measure import label - - -_EPS = np.spacing(1) -_TYPE = np.float64 - - -def evaluator(gt_paths, pred_paths, metrics=['S', 'MAE', 'E', 'F', 'WF', 'MBA', 'BIoU', 'HCE'], verbose=False): - # define measures - if 'E' in metrics: - EM = EMeasure() - if 'S' in metrics: - SM = SMeasure() - if 'F' in metrics: - FM = FMeasure() - if 'MAE' in 
metrics: - MAE = MAEMeasure() - if 'WF' in metrics: - WFM = WeightedFMeasure() - if 'HCE' in metrics: - HCE = HCEMeasure() - if 'MBA' in metrics: - MBA = MBAMeasure() - if 'BIoU' in metrics: - BIoU = BIoUMeasure() - - if isinstance(gt_paths, list) and isinstance(pred_paths, list): - # print(len(gt_paths), len(pred_paths)) - assert len(gt_paths) == len(pred_paths) - - for idx_sample in tqdm(range(len(gt_paths)), total=len(gt_paths)) if verbose else range(len(gt_paths)): - gt = gt_paths[idx_sample] - pred = pred_paths[idx_sample] - - pred = pred[:-4] + '.png' - valid_extensions = ['.png', '.jpg', '.PNG', '.JPG', '.JPEG'] - file_exists = False - for ext in valid_extensions: - if os.path.exists(pred[:-4] + ext): - pred = pred[:-4] + ext - file_exists = True - break - if file_exists: - pred_ary = cv2.imread(pred, cv2.IMREAD_GRAYSCALE) - else: - print('Not exists:', pred) - - gt_ary = cv2.imread(gt, cv2.IMREAD_GRAYSCALE) - pred_ary = cv2.resize(pred_ary, (gt_ary.shape[1], gt_ary.shape[0])) - - if 'E' in metrics: - EM.step(pred=pred_ary, gt=gt_ary) - if 'S' in metrics: - SM.step(pred=pred_ary, gt=gt_ary) - if 'F' in metrics: - FM.step(pred=pred_ary, gt=gt_ary) - if 'MAE' in metrics: - MAE.step(pred=pred_ary, gt=gt_ary) - if 'WF' in metrics: - WFM.step(pred=pred_ary, gt=gt_ary) - if 'HCE' in metrics: - ske_path = gt.replace('/gt/', '/ske/') - if os.path.exists(ske_path): - ske_ary = cv2.imread(ske_path, cv2.IMREAD_GRAYSCALE) - ske_ary = ske_ary > 128 - else: - ske_ary = skeletonize(gt_ary > 128) - ske_save_dir = os.path.join(*ske_path.split(os.sep)[:-1]) - if ske_path[0] == os.sep: - ske_save_dir = os.sep + ske_save_dir - os.makedirs(ske_save_dir, exist_ok=True) - cv2.imwrite(ske_path, ske_ary.astype(np.uint8) * 255) - HCE.step(pred=pred_ary, gt=gt_ary, gt_ske=ske_ary) - if 'MBA' in metrics: - MBA.step(pred=pred_ary, gt=gt_ary) - if 'BIoU' in metrics: - BIoU.step(pred=pred_ary, gt=gt_ary) - - if 'E' in metrics: - em = EM.get_results()['em'] - else: - em = {'curve': np.array([np.float64(-1)]), 'adp': np.float64(-1)} - if 'S' in metrics: - sm = SM.get_results()['sm'] - else: - sm = np.float64(-1) - if 'F' in metrics: - fm = FM.get_results()['fm'] - else: - fm = {'curve': np.array([np.float64(-1)]), 'adp': np.float64(-1)} - if 'MAE' in metrics: - mae = MAE.get_results()['mae'] - else: - mae = np.float64(-1) - if 'WF' in metrics: - wfm = WFM.get_results()['wfm'] - else: - wfm = np.float64(-1) - if 'HCE' in metrics: - hce = HCE.get_results()['hce'] - else: - hce = np.float64(-1) - if 'MBA' in metrics: - mba = MBA.get_results()['mba'] - else: - mba = np.float64(-1) - if 'BIoU' in metrics: - biou = BIoU.get_results()['biou'] - else: - biou = {'curve': np.array([np.float64(-1)])} - - return em, sm, fm, mae, wfm, hce, mba, biou - - -def _prepare_data(pred: np.ndarray, gt: np.ndarray) -> tuple: - gt = gt > 128 - pred = pred / 255 - if pred.max() != pred.min(): - pred = (pred - pred.min()) / (pred.max() - pred.min()) - return pred, gt - - -def _get_adaptive_threshold(matrix: np.ndarray, max_value: float = 1) -> float: - return min(2 * matrix.mean(), max_value) - - -class FMeasure(object): - def __init__(self, beta: float = 0.3): - self.beta = beta - self.precisions = [] - self.recalls = [] - self.adaptive_fms = [] - self.changeable_fms = [] - - def step(self, pred: np.ndarray, gt: np.ndarray): - pred, gt = _prepare_data(pred, gt) - - adaptive_fm = self.cal_adaptive_fm(pred=pred, gt=gt) - self.adaptive_fms.append(adaptive_fm) - - precisions, recalls, changeable_fms = self.cal_pr(pred=pred, gt=gt) - 
self.precisions.append(precisions) - self.recalls.append(recalls) - self.changeable_fms.append(changeable_fms) - - def cal_adaptive_fm(self, pred: np.ndarray, gt: np.ndarray) -> float: - adaptive_threshold = _get_adaptive_threshold(pred, max_value=1) - binary_predcition = pred >= adaptive_threshold - area_intersection = binary_predcition[gt].sum() - if area_intersection == 0: - adaptive_fm = 0 - else: - pre = area_intersection / np.count_nonzero(binary_predcition) - rec = area_intersection / np.count_nonzero(gt) - adaptive_fm = (1 + self.beta) * pre * rec / (self.beta * pre + rec) - return adaptive_fm - - def cal_pr(self, pred: np.ndarray, gt: np.ndarray) -> tuple: - pred = (pred * 255).astype(np.uint8) - bins = np.linspace(0, 256, 257) - fg_hist, _ = np.histogram(pred[gt], bins=bins) - bg_hist, _ = np.histogram(pred[~gt], bins=bins) - fg_w_thrs = np.cumsum(np.flip(fg_hist), axis=0) - bg_w_thrs = np.cumsum(np.flip(bg_hist), axis=0) - TPs = fg_w_thrs - Ps = fg_w_thrs + bg_w_thrs - Ps[Ps == 0] = 1 - T = max(np.count_nonzero(gt), 1) - precisions = TPs / Ps - recalls = TPs / T - numerator = (1 + self.beta) * precisions * recalls - denominator = np.where(numerator == 0, 1, self.beta * precisions + recalls) - changeable_fms = numerator / denominator - return precisions, recalls, changeable_fms - - def get_results(self) -> dict: - adaptive_fm = np.mean(np.array(self.adaptive_fms, _TYPE)) - changeable_fm = np.mean(np.array(self.changeable_fms, dtype=_TYPE), axis=0) - precision = np.mean(np.array(self.precisions, dtype=_TYPE), axis=0) # N, 256 - recall = np.mean(np.array(self.recalls, dtype=_TYPE), axis=0) # N, 256 - return dict(fm=dict(adp=adaptive_fm, curve=changeable_fm), - pr=dict(p=precision, r=recall)) - - -class MAEMeasure(object): - def __init__(self): - self.maes = [] - - def step(self, pred: np.ndarray, gt: np.ndarray): - pred, gt = _prepare_data(pred, gt) - - mae = self.cal_mae(pred, gt) - self.maes.append(mae) - - def cal_mae(self, pred: np.ndarray, gt: np.ndarray) -> float: - mae = np.mean(np.abs(pred - gt)) - return mae - - def get_results(self) -> dict: - mae = np.mean(np.array(self.maes, _TYPE)) - return dict(mae=mae) - - -class SMeasure(object): - def __init__(self, alpha: float = 0.5): - self.sms = [] - self.alpha = alpha - - def step(self, pred: np.ndarray, gt: np.ndarray): - pred, gt = _prepare_data(pred=pred, gt=gt) - - sm = self.cal_sm(pred, gt) - self.sms.append(sm) - - def cal_sm(self, pred: np.ndarray, gt: np.ndarray) -> float: - y = np.mean(gt) - if y == 0: - sm = 1 - np.mean(pred) - elif y == 1: - sm = np.mean(pred) - else: - sm = self.alpha * self.object(pred, gt) + (1 - self.alpha) * self.region(pred, gt) - sm = max(0, sm) - return sm - - def object(self, pred: np.ndarray, gt: np.ndarray) -> float: - fg = pred * gt - bg = (1 - pred) * (1 - gt) - u = np.mean(gt) - object_score = u * self.s_object(fg, gt) + (1 - u) * self.s_object(bg, 1 - gt) - return object_score - - def s_object(self, pred: np.ndarray, gt: np.ndarray) -> float: - x = np.mean(pred[gt == 1]) - sigma_x = np.std(pred[gt == 1], ddof=1) - score = 2 * x / (np.power(x, 2) + 1 + sigma_x + _EPS) - return score - - def region(self, pred: np.ndarray, gt: np.ndarray) -> float: - x, y = self.centroid(gt) - part_info = self.divide_with_xy(pred, gt, x, y) - w1, w2, w3, w4 = part_info['weight'] - pred1, pred2, pred3, pred4 = part_info['pred'] - gt1, gt2, gt3, gt4 = part_info['gt'] - score1 = self.ssim(pred1, gt1) - score2 = self.ssim(pred2, gt2) - score3 = self.ssim(pred3, gt3) - score4 = self.ssim(pred4, gt4) - - 
return w1 * score1 + w2 * score2 + w3 * score3 + w4 * score4 - - def centroid(self, matrix: np.ndarray) -> tuple: - h, w = matrix.shape - area_object = np.count_nonzero(matrix) - if area_object == 0: - x = np.round(w / 2) - y = np.round(h / 2) - else: - # More details can be found at: https://www.yuque.com/lart/blog/gpbigm - y, x = np.argwhere(matrix).mean(axis=0).round() - return int(x) + 1, int(y) + 1 - - def divide_with_xy(self, pred: np.ndarray, gt: np.ndarray, x, y) -> dict: - h, w = gt.shape - area = h * w - - gt_LT = gt[0:y, 0:x] - gt_RT = gt[0:y, x:w] - gt_LB = gt[y:h, 0:x] - gt_RB = gt[y:h, x:w] - - pred_LT = pred[0:y, 0:x] - pred_RT = pred[0:y, x:w] - pred_LB = pred[y:h, 0:x] - pred_RB = pred[y:h, x:w] - - w1 = x * y / area - w2 = y * (w - x) / area - w3 = (h - y) * x / area - w4 = 1 - w1 - w2 - w3 - - return dict(gt=(gt_LT, gt_RT, gt_LB, gt_RB), - pred=(pred_LT, pred_RT, pred_LB, pred_RB), - weight=(w1, w2, w3, w4)) - - def ssim(self, pred: np.ndarray, gt: np.ndarray) -> float: - h, w = pred.shape - N = h * w - - x = np.mean(pred) - y = np.mean(gt) - - sigma_x = np.sum((pred - x) ** 2) / (N - 1) - sigma_y = np.sum((gt - y) ** 2) / (N - 1) - sigma_xy = np.sum((pred - x) * (gt - y)) / (N - 1) - - alpha = 4 * x * y * sigma_xy - beta = (x ** 2 + y ** 2) * (sigma_x + sigma_y) - - if alpha != 0: - score = alpha / (beta + _EPS) - elif alpha == 0 and beta == 0: - score = 1 - else: - score = 0 - return score - - def get_results(self) -> dict: - sm = np.mean(np.array(self.sms, dtype=_TYPE)) - return dict(sm=sm) - - -class EMeasure(object): - def __init__(self): - self.adaptive_ems = [] - self.changeable_ems = [] - - def step(self, pred: np.ndarray, gt: np.ndarray): - pred, gt = _prepare_data(pred=pred, gt=gt) - self.gt_fg_numel = np.count_nonzero(gt) - self.gt_size = gt.shape[0] * gt.shape[1] - - changeable_ems = self.cal_changeable_em(pred, gt) - self.changeable_ems.append(changeable_ems) - adaptive_em = self.cal_adaptive_em(pred, gt) - self.adaptive_ems.append(adaptive_em) - - def cal_adaptive_em(self, pred: np.ndarray, gt: np.ndarray) -> float: - adaptive_threshold = _get_adaptive_threshold(pred, max_value=1) - adaptive_em = self.cal_em_with_threshold(pred, gt, threshold=adaptive_threshold) - return adaptive_em - - def cal_changeable_em(self, pred: np.ndarray, gt: np.ndarray) -> np.ndarray: - changeable_ems = self.cal_em_with_cumsumhistogram(pred, gt) - return changeable_ems - - def cal_em_with_threshold(self, pred: np.ndarray, gt: np.ndarray, threshold: float) -> float: - binarized_pred = pred >= threshold - fg_fg_numel = np.count_nonzero(binarized_pred & gt) - fg_bg_numel = np.count_nonzero(binarized_pred & ~gt) - - fg___numel = fg_fg_numel + fg_bg_numel - bg___numel = self.gt_size - fg___numel - - if self.gt_fg_numel == 0: - enhanced_matrix_sum = bg___numel - elif self.gt_fg_numel == self.gt_size: - enhanced_matrix_sum = fg___numel - else: - parts_numel, combinations = self.generate_parts_numel_combinations( - fg_fg_numel=fg_fg_numel, fg_bg_numel=fg_bg_numel, - pred_fg_numel=fg___numel, pred_bg_numel=bg___numel, - ) - - results_parts = [] - for i, (part_numel, combination) in enumerate(zip(parts_numel, combinations)): - align_matrix_value = 2 * (combination[0] * combination[1]) / \ - (combination[0] ** 2 + combination[1] ** 2 + _EPS) - enhanced_matrix_value = (align_matrix_value + 1) ** 2 / 4 - results_parts.append(enhanced_matrix_value * part_numel) - enhanced_matrix_sum = sum(results_parts) - - em = enhanced_matrix_sum / (self.gt_size - 1 + _EPS) - return em - - def 
cal_em_with_cumsumhistogram(self, pred: np.ndarray, gt: np.ndarray) -> np.ndarray: - pred = (pred * 255).astype(np.uint8) - bins = np.linspace(0, 256, 257) - fg_fg_hist, _ = np.histogram(pred[gt], bins=bins) - fg_bg_hist, _ = np.histogram(pred[~gt], bins=bins) - fg_fg_numel_w_thrs = np.cumsum(np.flip(fg_fg_hist), axis=0) - fg_bg_numel_w_thrs = np.cumsum(np.flip(fg_bg_hist), axis=0) - - fg___numel_w_thrs = fg_fg_numel_w_thrs + fg_bg_numel_w_thrs - bg___numel_w_thrs = self.gt_size - fg___numel_w_thrs - - if self.gt_fg_numel == 0: - enhanced_matrix_sum = bg___numel_w_thrs - elif self.gt_fg_numel == self.gt_size: - enhanced_matrix_sum = fg___numel_w_thrs - else: - parts_numel_w_thrs, combinations = self.generate_parts_numel_combinations( - fg_fg_numel=fg_fg_numel_w_thrs, fg_bg_numel=fg_bg_numel_w_thrs, - pred_fg_numel=fg___numel_w_thrs, pred_bg_numel=bg___numel_w_thrs, - ) - - results_parts = np.empty(shape=(4, 256), dtype=np.float64) - for i, (part_numel, combination) in enumerate(zip(parts_numel_w_thrs, combinations)): - align_matrix_value = 2 * (combination[0] * combination[1]) / \ - (combination[0] ** 2 + combination[1] ** 2 + _EPS) - enhanced_matrix_value = (align_matrix_value + 1) ** 2 / 4 - results_parts[i] = enhanced_matrix_value * part_numel - enhanced_matrix_sum = results_parts.sum(axis=0) - - em = enhanced_matrix_sum / (self.gt_size - 1 + _EPS) - return em - - def generate_parts_numel_combinations(self, fg_fg_numel, fg_bg_numel, pred_fg_numel, pred_bg_numel): - bg_fg_numel = self.gt_fg_numel - fg_fg_numel - bg_bg_numel = pred_bg_numel - bg_fg_numel - - parts_numel = [fg_fg_numel, fg_bg_numel, bg_fg_numel, bg_bg_numel] - - mean_pred_value = pred_fg_numel / self.gt_size - mean_gt_value = self.gt_fg_numel / self.gt_size - - demeaned_pred_fg_value = 1 - mean_pred_value - demeaned_pred_bg_value = 0 - mean_pred_value - demeaned_gt_fg_value = 1 - mean_gt_value - demeaned_gt_bg_value = 0 - mean_gt_value - - combinations = [ - (demeaned_pred_fg_value, demeaned_gt_fg_value), - (demeaned_pred_fg_value, demeaned_gt_bg_value), - (demeaned_pred_bg_value, demeaned_gt_fg_value), - (demeaned_pred_bg_value, demeaned_gt_bg_value) - ] - return parts_numel, combinations - - def get_results(self) -> dict: - adaptive_em = np.mean(np.array(self.adaptive_ems, dtype=_TYPE)) - changeable_em = np.mean(np.array(self.changeable_ems, dtype=_TYPE), axis=0) - return dict(em=dict(adp=adaptive_em, curve=changeable_em)) - - -class WeightedFMeasure(object): - def __init__(self, beta: float = 1): - self.beta = beta - self.weighted_fms = [] - - def step(self, pred: np.ndarray, gt: np.ndarray): - pred, gt = _prepare_data(pred=pred, gt=gt) - - if np.all(~gt): - wfm = 0 - else: - wfm = self.cal_wfm(pred, gt) - self.weighted_fms.append(wfm) - - def cal_wfm(self, pred: np.ndarray, gt: np.ndarray) -> float: - # [Dst,IDXT] = bwdist(dGT); - Dst, Idxt = bwdist(gt == 0, return_indices=True) - - # %Pixel dependency - # E = abs(FG-dGT); - E = np.abs(pred - gt) - Et = np.copy(E) - Et[gt == 0] = Et[Idxt[0][gt == 0], Idxt[1][gt == 0]] - - # K = fspecial('gaussian',7,5); - # EA = imfilter(Et,K); - K = self.matlab_style_gauss2D((7, 7), sigma=5) - EA = convolve(Et, weights=K, mode="constant", cval=0) - # MIN_E_EA = E; - # MIN_E_EA(GT & EA<E) = EA(GT & EA<E); - MIN_E_EA = np.where(gt & (EA < E), EA, E) - - # %Pixel importance - B = np.where(gt == 0, 2 - np.exp(np.log(0.5) / 5 * Dst), np.ones_like(gt)) - Ew = MIN_E_EA * B - - TPw = np.sum(gt) - np.sum(Ew[gt == 1]) - FPw = np.sum(Ew[gt == 0]) - - R = 1 - np.mean(Ew[gt == 1]) - P = TPw / (_EPS + TPw + FPw) - - # % Q = (1+Beta^2)*(R*P)./(eps+R+(Beta.*P)); - Q = (1 + self.beta) * R * P / (_EPS + R + self.beta * P) - - return Q - - def matlab_style_gauss2D(self, shape: tuple = (7, 7), sigma: int = 5) -> np.ndarray: - """ - 2D gaussian mask - should give the same result as MATLAB's - fspecial('gaussian',[shape],[sigma]) - """ - m, n = [(ss - 1) / 2 for ss in shape] - y, x = np.ogrid[-m: m + 1, -n: n + 1] - h = np.exp(-(x * x + y * y) / (2 * sigma * sigma)) - h[h < np.finfo(h.dtype).eps * h.max()] = 0 -
sumh = h.sum() - if sumh != 0: - h /= sumh - return h - - def get_results(self) -> dict: - weighted_fm = np.mean(np.array(self.weighted_fms, dtype=_TYPE)) - return dict(wfm=weighted_fm) - - -class HCEMeasure(object): - def __init__(self): - self.hces = [] - - def step(self, pred: np.ndarray, gt: np.ndarray, gt_ske): - # pred, gt = _prepare_data(pred, gt) - - hce = self.cal_hce(pred, gt, gt_ske) - self.hces.append(hce) - - def get_results(self) -> dict: - hce = np.mean(np.array(self.hces, _TYPE)) - return dict(hce=hce) - - - def cal_hce(self, pred: np.ndarray, gt: np.ndarray, gt_ske: np.ndarray, relax=5, epsilon=2.0) -> float: - # Binarize gt - if(len(gt.shape)>2): - gt = gt[:, :, 0] - - epsilon_gt = 128#(np.amin(gt)+np.amax(gt))/2.0 - gt = (gt>epsilon_gt).astype(np.uint8) - - # Binarize pred - if(len(pred.shape)>2): - pred = pred[:, :, 0] - epsilon_pred = 128#(np.amin(pred)+np.amax(pred))/2.0 - pred = (pred>epsilon_pred).astype(np.uint8) - - Union = np.logical_or(gt, pred) - TP = np.logical_and(gt, pred) - FP = pred - TP - FN = gt - TP - - # relax the Union of gt and pred - Union_erode = Union.copy() - Union_erode = cv2.erode(Union_erode.astype(np.uint8), disk(1), iterations=relax) - - # --- get the relaxed False Positive regions for computing the human efforts in correcting them --- - FP_ = np.logical_and(FP, Union_erode) # get the relaxed FP - for i in range(0, relax): - FP_ = cv2.dilate(FP_.astype(np.uint8), disk(1)) - FP_ = np.logical_and(FP_, 1-np.logical_or(TP, FN)) - FP_ = np.logical_and(FP, FP_) - - # --- get the relaxed False Negative regions for computing the human efforts in correcting them --- - FN_ = np.logical_and(FN, Union_erode) # preserve the structural components of FN - ## recover the FN, where pixels are not close to the TP borders - for i in range(0, relax): - FN_ = cv2.dilate(FN_.astype(np.uint8), disk(1)) - FN_ = np.logical_and(FN_, 1-np.logical_or(TP, FP)) - FN_ = np.logical_and(FN, FN_) - FN_ = np.logical_or(FN_, np.logical_xor(gt_ske, np.logical_and(TP, gt_ske))) # preserve the structural components of FN - - ## 2. 
=============Find exact polygon control points and independent regions============== - ## find contours from FP_ - ctrs_FP, hier_FP = cv2.findContours(FP_.astype(np.uint8), cv2.RETR_TREE, cv2.CHAIN_APPROX_NONE) - ## find control points and independent regions for human correction - bdies_FP, indep_cnt_FP = self.filter_bdy_cond(ctrs_FP, FP_, np.logical_or(TP,FN_)) - ## find contours from FN_ - ctrs_FN, hier_FN = cv2.findContours(FN_.astype(np.uint8), cv2.RETR_TREE, cv2.CHAIN_APPROX_NONE) - ## find control points and independent regions for human correction - bdies_FN, indep_cnt_FN = self.filter_bdy_cond(ctrs_FN, FN_, 1-np.logical_or(np.logical_or(TP, FP_), FN_)) - - poly_FP, poly_FP_len, poly_FP_point_cnt = self.approximate_RDP(bdies_FP, epsilon=epsilon) - poly_FN, poly_FN_len, poly_FN_point_cnt = self.approximate_RDP(bdies_FN, epsilon=epsilon) - - # FP_points+FP_indep+FN_points+FN_indep - return poly_FP_point_cnt+indep_cnt_FP+poly_FN_point_cnt+indep_cnt_FN - - def filter_bdy_cond(self, bdy_, mask, cond): - - cond = cv2.dilate(cond.astype(np.uint8), disk(1)) - labels = label(mask) # find the connected regions - lbls = np.unique(labels) # the indices of the connected regions - indep = np.ones(lbls.shape[0]) # the label of each connected regions - indep[0] = 0 # 0 indicate the background region - - boundaries = [] - h,w = cond.shape[0:2] - ind_map = np.zeros((h, w)) - indep_cnt = 0 - - for i in range(0, len(bdy_)): - tmp_bdies = [] - tmp_bdy = [] - for j in range(0, bdy_[i].shape[0]): - r, c = bdy_[i][j,0,1],bdy_[i][j,0,0] - - if(np.sum(cond[r, c])==0 or ind_map[r, c]!=0): - if(len(tmp_bdy)>0): - tmp_bdies.append(tmp_bdy) - tmp_bdy = [] - continue - tmp_bdy.append([c, r]) - ind_map[r, c] = ind_map[r, c] + 1 - indep[labels[r, c]] = 0 # indicates part of the boundary of this region needs human correction - if(len(tmp_bdy)>0): - tmp_bdies.append(tmp_bdy) - - # check if the first and the last boundaries are connected - # if yes, invert the first boundary and attach it after the last boundary - if(len(tmp_bdies)>1): - first_x, first_y = tmp_bdies[0][0] - last_x, last_y = tmp_bdies[-1][-1] - if((abs(first_x-last_x)==1 and first_y==last_y) or - (first_x==last_x and abs(first_y-last_y)==1) or - (abs(first_x-last_x)==1 and abs(first_y-last_y)==1) - ): - tmp_bdies[-1].extend(tmp_bdies[0][::-1]) - del tmp_bdies[0] - - for k in range(0, len(tmp_bdies)): - tmp_bdies[k] = np.array(tmp_bdies[k])[:, np.newaxis, :] - if(len(tmp_bdies)>0): - boundaries.extend(tmp_bdies) - - return boundaries, np.sum(indep) - - # this function approximate each boundary by DP algorithm - # https://en.wikipedia.org/wiki/Ramer%E2%80%93Douglas%E2%80%93Peucker_algorithm - def approximate_RDP(self, boundaries, epsilon=1.0): - - boundaries_ = [] - boundaries_len_ = [] - pixel_cnt_ = 0 - - # polygon approximate of each boundary - for i in range(0, len(boundaries)): - boundaries_.append(cv2.approxPolyDP(boundaries[i], epsilon, False)) - - # count the control points number of each boundary and the total control points number of all the boundaries - for i in range(0, len(boundaries_)): - boundaries_len_.append(len(boundaries_[i])) - pixel_cnt_ = pixel_cnt_ + len(boundaries_[i]) - - return boundaries_, boundaries_len_, pixel_cnt_ - - -class MBAMeasure(object): - def __init__(self): - self.bas = [] - self.all_h = 0 - self.all_w = 0 - self.all_max = 0 - - def step(self, pred: np.ndarray, gt: np.ndarray): - # pred, gt = _prepare_data(pred, gt) - - refined = gt.copy() - - rmin = cmin = 0 - rmax, cmax = gt.shape - - self.all_h += rmax - 
self.all_w += cmax - self.all_max += max(rmax, cmax) - - refined_h, refined_w = refined.shape - if refined_h != cmax: - refined = np.array(Image.fromarray(pred).resize((cmax, rmax), Image.BILINEAR)) - - if not(gt.sum() < 32*32): - if not((cmax==cmin) or (rmax==rmin)): - class_refined_prob = np.array(Image.fromarray(pred).resize((cmax-cmin, rmax-rmin), Image.BILINEAR)) - refined[rmin:rmax, cmin:cmax] = class_refined_prob - - pred = pred > 128 - gt = gt > 128 - - ba = self.cal_ba(pred, gt) - self.bas.append(ba) - - def get_disk_kernel(self, radius): - return cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (radius*2+1, radius*2+1)) - - def cal_ba(self, pred: np.ndarray, gt: np.ndarray) -> np.ndarray: - """ - Calculate the mean absolute error. - - :return: ba - """ - - gt = gt.astype(np.uint8) - pred = pred.astype(np.uint8) - - h, w = gt.shape - - min_radius = 1 - max_radius = (w+h)/300 - num_steps = 5 - - pred_acc = [None] * num_steps - - for i in range(num_steps): - curr_radius = min_radius + int((max_radius-min_radius)/num_steps*i) - - kernel = self.get_disk_kernel(curr_radius) - boundary_region = cv2.morphologyEx(gt, cv2.MORPH_GRADIENT, kernel) > 0 - - gt_in_bound = gt[boundary_region] - pred_in_bound = pred[boundary_region] - - num_edge_pixels = (boundary_region).sum() - num_pred_gd_pix = ((gt_in_bound) * (pred_in_bound) + (1-gt_in_bound) * (1-pred_in_bound)).sum() - - pred_acc[i] = num_pred_gd_pix / num_edge_pixels - - ba = sum(pred_acc)/num_steps - return ba - - def get_results(self) -> dict: - mba = np.mean(np.array(self.bas, _TYPE)) - return dict(mba=mba) - - -class BIoUMeasure(object): - def __init__(self, dilation_ratio=0.02): - self.bious = [] - self.dilation_ratio = dilation_ratio - - def mask_to_boundary(self, mask): - h, w = mask.shape - img_diag = np.sqrt(h ** 2 + w ** 2) - dilation = int(round(self.dilation_ratio * img_diag)) - if dilation < 1: - dilation = 1 - # Pad image so mask truncated by the image border is also considered as boundary. - new_mask = cv2.copyMakeBorder(mask, 1, 1, 1, 1, cv2.BORDER_CONSTANT, value=0) - kernel = np.ones((3, 3), dtype=np.uint8) - new_mask_erode = cv2.erode(new_mask, kernel, iterations=dilation) - mask_erode = new_mask_erode[1 : h + 1, 1 : w + 1] - # G_d intersects G in the paper. 
- return mask - mask_erode - - def step(self, pred: np.ndarray, gt: np.ndarray): - pred, gt = _prepare_data(pred, gt) - - bious = self.cal_biou(pred=pred, gt=gt) - self.bious.append(bious) - - def cal_biou(self, pred, gt): - pred = (pred * 255).astype(np.uint8) - pred = self.mask_to_boundary(pred) - gt = (gt * 255).astype(np.uint8) - gt = self.mask_to_boundary(gt) - gt = gt > 128 - - bins = np.linspace(0, 256, 257) - fg_hist, _ = np.histogram(pred[gt], bins=bins) # ture positive - bg_hist, _ = np.histogram(pred[~gt], bins=bins) # false positive - fg_w_thrs = np.cumsum(np.flip(fg_hist), axis=0) - bg_w_thrs = np.cumsum(np.flip(bg_hist), axis=0) - TPs = fg_w_thrs - Ps = fg_w_thrs + bg_w_thrs # positives - Ps[Ps == 0] = 1 - T = max(np.count_nonzero(gt), 1) - - ious = TPs / (T + bg_w_thrs) - return ious - - def get_results(self) -> dict: - biou = np.mean(np.array(self.bious, dtype=_TYPE), axis=0) - return dict(biou=dict(curve=biou)) diff --git a/py/BiRefNet/gen_best_ep.py b/py/BiRefNet/gen_best_ep.py deleted file mode 100644 index 6c2e4bde..00000000 --- a/py/BiRefNet/gen_best_ep.py +++ /dev/null @@ -1,86 +0,0 @@ -import os -from glob import glob -import numpy as np - -from BiRefNet.config import Config - - -config = Config() - -eval_txts = sorted(glob('e_results/*_eval.txt')) -print('eval_txts:', [_.split(os.sep)[-1] for _ in eval_txts]) -score_panel = {} -sep = '&' -metrics = ['sm', 'wfm', 'hce'] # we used HCE for DIS and wFm for others. -if 'DIS5K' not in config.task: - metrics.remove('hce') - -for metric in metrics: - print('Metric:', metric) - current_line_nums = [] - for idx_et, eval_txt in enumerate(eval_txts): - with open(eval_txt, 'r') as f: - lines = [l for l in f.readlines()[3:] if '.' in l] - current_line_nums.append(len(lines)) - for idx_et, eval_txt in enumerate(eval_txts): - with open(eval_txt, 'r') as f: - lines = [l for l in f.readlines()[3:] if '.' in l] - for idx_line, line in enumerate(lines[:min(current_line_nums)]): # Consist line numbers by the minimal result file. 
- properties = line.strip().strip(sep).split(sep) - dataset = properties[0].strip() - ckpt = properties[1].strip() - if int(ckpt.split('--epoch_')[-1].strip()) < 0: - continue - targe_idx = { - 'sm': [5, 2, 2, 5, 2], - 'wfm': [3, 3, 8, 3, 8], - 'hce': [7, -1, -1, 7, -1] - }[metric][['DIS5K', 'COD', 'HRSOD', 'General', 'Matting'].index(config.task)] - if metric != 'hce': - score_sm = float(properties[targe_idx].strip()) - else: - score_sm = int(properties[targe_idx].strip().strip('.')) - if idx_et == 0: - score_panel[ckpt] = [] - score_panel[ckpt].append(score_sm) - - metrics_min = ['hce', 'mae'] - max_or_min = min if metric in metrics_min else max - score_max = max_or_min(score_panel.values(), key=lambda x: np.sum(x)) - - good_models = [] - for k, v in score_panel.items(): - if (np.sum(v) <= np.sum(score_max)) if metric in metrics_min else (np.sum(v) >= np.sum(score_max)): - print(k, v) - good_models.append(k) - - # Write - with open(eval_txt, 'r') as f: - lines = f.readlines() - info4good_models = lines[:3] - metric_names = [m.strip() for m in lines[1].strip().strip('&').split('&')[2:]] - testset_mean_values = {metric_name: [] for metric_name in metric_names} - for good_model in good_models: - for idx_et, eval_txt in enumerate(eval_txts): - with open(eval_txt, 'r') as f: - lines = f.readlines() - for line in lines: - if set([good_model]) & set([_.strip() for _ in line.split(sep)]): - info4good_models.append(line) - metric_scores = [float(m.strip()) for m in line.strip().strip('&').split('&')[2:]] - for idx_score, metric_score in enumerate(metric_scores): - testset_mean_values[metric_names[idx_score]].append(metric_score) - - if 'DIS5K' in config.task: - testset_mean_values_lst = ['{:<4}'.format(int(np.mean(v_lst[:-1]).round())) if name == 'HCE' else '{:.3f}'.format(np.mean(v_lst[:-1])).lstrip('0') for name, v_lst in testset_mean_values.items()] # [:-1] to remove DIS-VD - sample_line_for_placing_mean_values = info4good_models[-2] - numbers_placed_well = sample_line_for_placing_mean_values.replace(sample_line_for_placing_mean_values.split('&')[1].strip(), 'DIS-TEs').strip().split('&')[3:] - for idx_number, (number_placed_well, testset_mean_value) in enumerate(zip(numbers_placed_well, testset_mean_values_lst)): - numbers_placed_well[idx_number] = number_placed_well.replace(number_placed_well.strip(), testset_mean_value) - testset_mean_line = '&'.join(sample_line_for_placing_mean_values.replace(sample_line_for_placing_mean_values.split('&')[1].strip(), 'DIS-TEs').split('&')[:3] + numbers_placed_well) + '\n' - info4good_models.append(testset_mean_line) - info4good_models.append(lines[-1]) - info = ''.join(info4good_models) - print(info) - with open(os.path.join('e_results', 'eval-{}_best_on_{}.txt'.format(config.task, metric)), 'w') as f: - f.write(info + '\n') diff --git a/py/BiRefNet/image_proc.py b/py/BiRefNet/image_proc.py deleted file mode 100644 index 2ebfbfaf..00000000 --- a/py/BiRefNet/image_proc.py +++ /dev/null @@ -1,119 +0,0 @@ -import random -from PIL import Image, ImageEnhance -import numpy as np -import cv2 - - -def refine_foreground(image, mask, r=90): - if mask.size != image.size: - mask = mask.resize(image.size) - image = np.array(image) / 255.0 - mask = np.array(mask) / 255.0 - estimated_foreground = FB_blur_fusion_foreground_estimator_2(image, mask, r=r) - image_masked = Image.fromarray((estimated_foreground * 255.0).astype(np.uint8)) - return image_masked - - -def FB_blur_fusion_foreground_estimator_2(image, alpha, r=90): - # Thanks to the source: 
https://github.com/Photoroom/fast-foreground-estimation - alpha = alpha[:, :, None] - F, blur_B = FB_blur_fusion_foreground_estimator( - image, image, image, alpha, r) - return FB_blur_fusion_foreground_estimator(image, F, blur_B, alpha, r=6)[0] - - -def FB_blur_fusion_foreground_estimator(image, F, B, alpha, r=90): - if isinstance(image, Image.Image): - image = np.array(image) / 255.0 - blurred_alpha = cv2.blur(alpha, (r, r))[:, :, None] - - blurred_FA = cv2.blur(F * alpha, (r, r)) - blurred_F = blurred_FA / (blurred_alpha + 1e-5) - - blurred_B1A = cv2.blur(B * (1 - alpha), (r, r)) - blurred_B = blurred_B1A / ((1 - blurred_alpha) + 1e-5) - F = blurred_F + alpha * \ - (image - alpha * blurred_F - (1 - alpha) * blurred_B) - F = np.clip(F, 0, 1) - return F, blurred_B - - -def preproc(image, label, preproc_methods=['flip']): - if 'flip' in preproc_methods: - image, label = cv_random_flip(image, label) - if 'crop' in preproc_methods: - image, label = random_crop(image, label) - if 'rotate' in preproc_methods: - image, label = random_rotate(image, label) - if 'enhance' in preproc_methods: - image = color_enhance(image) - if 'pepper' in preproc_methods: - label = random_pepper(label) - return image, label - - -def cv_random_flip(img, label): - if random.random() > 0.5: - img = img.transpose(Image.FLIP_LEFT_RIGHT) - label = label.transpose(Image.FLIP_LEFT_RIGHT) - return img, label - - -def random_crop(image, label): - border = 30 - image_width = image.size[0] - image_height = image.size[1] - border = int(min(image_width, image_height) * 0.1) - crop_win_width = np.random.randint(image_width - border, image_width) - crop_win_height = np.random.randint(image_height - border, image_height) - random_region = ( - (image_width - crop_win_width) >> 1, (image_height - crop_win_height) >> 1, (image_width + crop_win_width) >> 1, - (image_height + crop_win_height) >> 1) - return image.crop(random_region), label.crop(random_region) - - -def random_rotate(image, label, angle=15): - mode = Image.BICUBIC - if random.random() > 0.8: - random_angle = np.random.randint(-angle, angle) - image = image.rotate(random_angle, mode) - label = label.rotate(random_angle, mode) - return image, label - - -def color_enhance(image): - bright_intensity = random.randint(5, 15) / 10.0 - image = ImageEnhance.Brightness(image).enhance(bright_intensity) - contrast_intensity = random.randint(5, 15) / 10.0 - image = ImageEnhance.Contrast(image).enhance(contrast_intensity) - color_intensity = random.randint(0, 20) / 10.0 - image = ImageEnhance.Color(image).enhance(color_intensity) - sharp_intensity = random.randint(0, 30) / 10.0 - image = ImageEnhance.Sharpness(image).enhance(sharp_intensity) - return image - - -def random_gaussian(image, mean=0.1, sigma=0.35): - def gaussianNoisy(im, mean=mean, sigma=sigma): - for _i in range(len(im)): - im[_i] += random.gauss(mean, sigma) - return im - - img = np.asarray(image) - width, height = img.shape - img = gaussianNoisy(img[:].flatten(), mean, sigma) - img = img.reshape([width, height]) - return Image.fromarray(np.uint8(img)) - - -def random_pepper(img, N=0.0015): - img = np.array(img) - noiseNum = int(N * img.shape[0] * img.shape[1]) - for i in range(noiseNum): - randX = random.randint(0, img.shape[0] - 1) - randY = random.randint(0, img.shape[1] - 1) - if random.randint(0, 1) == 0: - img[randX, randY] = 0 - else: - img[randX, randY] = 255 - return Image.fromarray(img) diff --git a/py/BiRefNet/inference.py b/py/BiRefNet/inference.py deleted file mode 100644 index c56ae1a6..00000000 --- 
a/py/BiRefNet/inference.py +++ /dev/null @@ -1,105 +0,0 @@ -import os -import argparse -from glob import glob -from tqdm import tqdm -import cv2 -import torch - -from BiRefNet.dataset import MyData -from BiRefNet.models.birefnet import BiRefNet -from BiRefNet.utils import save_tensor_img, check_state_dict -from BiRefNet.config import Config - - -config = Config() - - -def inference(model, data_loader_test, pred_root, method, testset, device=0): - model_training = model.training - if model_training: - model.eval() - for batch in tqdm(data_loader_test, total=len(data_loader_test)) if 1 or config.verbose_eval else data_loader_test: - inputs = batch[0].to(device) - # gts = batch[1].to(device) - label_paths = batch[-1] - with torch.no_grad(): - scaled_preds = model(inputs)[-1].sigmoid() - - os.makedirs(os.path.join(pred_root, method, testset), exist_ok=True) - - for idx_sample in range(scaled_preds.shape[0]): - res = torch.nn.functional.interpolate( - scaled_preds[idx_sample].unsqueeze(0), - size=cv2.imread(label_paths[idx_sample], cv2.IMREAD_GRAYSCALE).shape[:2], - mode='bilinear', - align_corners=True - ) - save_tensor_img(res, os.path.join(os.path.join(pred_root, method, testset), label_paths[idx_sample].replace('\\', '/').split('/')[-1])) # test set dir + file name - if model_training: - model.train() - return None - - -def main(args): - # Init model - - device = config.device - if args.ckpt_folder: - print('Testing with models in {}'.format(args.ckpt_folder)) - else: - print('Testing with model {}'.format(args.ckpt)) - - if config.model == 'BiRefNet': - model = BiRefNet(bb_pretrained=False) - weights_lst = sorted( - glob(os.path.join(args.ckpt_folder, '*.pth')) if args.ckpt_folder else [args.ckpt], - key=lambda x: int(x.split('epoch_')[-1].split('.pth')[0]), - reverse=True - ) - for testset in args.testsets.split('+'): - print('>>>> Testset: {}...'.format(testset)) - data_loader_test = torch.utils.data.DataLoader( - dataset=MyData(testset, image_size=config.size, is_train=False), - batch_size=config.batch_size_valid, shuffle=False, num_workers=config.num_workers, pin_memory=True - ) - for weights in weights_lst: - if int(weights.strip('.pth').split('epoch_')[-1]) % 1 != 0: - continue - print('\tInferencing {}...'.format(weights)) - # model.load_state_dict(torch.load(weights, map_location='cpu')) - state_dict = torch.load(weights, map_location='cpu') - state_dict = check_state_dict(state_dict) - model.load_state_dict(state_dict) - model = model.to(device) - inference( - model, data_loader_test=data_loader_test, pred_root=args.pred_root, - method='--'.join([w.rstrip('.pth') for w in weights.split(os.sep)[-2:]]), - testset=testset, device=config.device - ) - - -if __name__ == '__main__': - # Parameter from command line - parser = argparse.ArgumentParser(description='') - parser.add_argument('--ckpt', type=str, help='model folder') - parser.add_argument('--ckpt_folder', default=sorted(glob(os.path.join('ckpt', '*')))[-1], type=str, help='model folder') - parser.add_argument('--pred_root', default='e_preds', type=str, help='Output folder') - parser.add_argument('--testsets', - default={ - 'DIS5K': 'DIS-VD+DIS-TE1+DIS-TE2+DIS-TE3+DIS-TE4', - 'COD': 'TE-COD10K+NC4K+TE-CAMO+CHAMELEON', - 'HRSOD': 'DAVIS-S+TE-HRSOD+TE-UHRSD+TE-DUTS+DUT-OMRON', - 'General': 'DIS-VD', - 'Matting': 'TE-P3M-500-P', - 'DIS5K-': 'DIS-VD', - 'COD-': 'TE-COD10K', - 'SOD-': 'DAVIS-S+TE-HRSOD+TE-UHRSD', - }[config.task + ''], - type=str, - help="Test all sets: , 'DIS-VD+DIS-TE1+DIS-TE2+DIS-TE3+DIS-TE4'") - - args = 
parser.parse_args() - - if config.precisionHigh: - torch.set_float32_matmul_precision('high') - main(args) diff --git a/py/BiRefNet/loss.py b/py/BiRefNet/loss.py deleted file mode 100644 index 02d0cd00..00000000 --- a/py/BiRefNet/loss.py +++ /dev/null @@ -1,277 +0,0 @@ -import torch -from torch import nn -import torch.nn.functional as F -from torch.autograd import Variable -from math import exp - -from BiRefNet.config import Config - - -class Discriminator(nn.Module): - def __init__(self, channels=1, img_size=256): - super(Discriminator, self).__init__() - - def discriminator_block(in_filters, out_filters, bn=Config().batch_size > 1): - block = [nn.Conv2d(in_filters, out_filters, 3, 2, 1), nn.LeakyReLU(0.2, inplace=True), nn.Dropout2d(0.25)] - if bn: - block.append(nn.BatchNorm2d(out_filters, 0.8)) - return block - - self.model = nn.Sequential( - *discriminator_block(channels, 16, bn=False), - *discriminator_block(16, 32), - *discriminator_block(32, 64), - *discriminator_block(64, 128), - ) - - # The height and width of downsampled image - ds_size = img_size // 2 ** 4 - self.adv_layer = nn.Sequential(nn.Linear(128 * ds_size ** 2, 1), nn.Sigmoid()) - - def forward(self, img): - out = self.model(img) - out = out.view(out.shape[0], -1) - validity = self.adv_layer(out) - - return validity - - -class ContourLoss(torch.nn.Module): - def __init__(self): - super(ContourLoss, self).__init__() - - def forward(self, pred, target, weight=10): - ''' - target, pred: tensor of shape (B, C, H, W), where target[:,:,region_in_contour] == 1, - target[:,:,region_out_contour] == 0. - weight: scalar, length term weight. - ''' - # length term - delta_r = pred[:,:,1:,:] - pred[:,:,:-1,:] # horizontal gradient (B, C, H-1, W) - delta_c = pred[:,:,:,1:] - pred[:,:,:,:-1] # vertical gradient (B, C, H, W-1) - - delta_r = delta_r[:,:,1:,:-2]**2 # (B, C, H-2, W-2) - delta_c = delta_c[:,:,:-2,1:]**2 # (B, C, H-2, W-2) - delta_pred = torch.abs(delta_r + delta_c) - - epsilon = 1e-8 # where is a parameter to avoid square root is zero in practice. - length = torch.mean(torch.sqrt(delta_pred + epsilon)) # eq.(11) in the paper, mean is used instead of sum. - - c_in = torch.ones_like(pred) - c_out = torch.zeros_like(pred) - - region_in = torch.mean( pred * (target - c_in )**2 ) # equ.(12) in the paper, mean is used instead of sum. 
- region_out = torch.mean( (1-pred) * (target - c_out)**2 ) - region = region_in + region_out - - loss = weight * length + region - - return loss - - -class IoULoss(torch.nn.Module): - def __init__(self): - super(IoULoss, self).__init__() - - def forward(self, pred, target): - b = pred.shape[0] - IoU = 0.0 - for i in range(0, b): - # compute the IoU of the foreground - Iand1 = torch.sum(target[i, :, :, :] * pred[i, :, :, :]) - Ior1 = torch.sum(target[i, :, :, :]) + torch.sum(pred[i, :, :, :]) - Iand1 - IoU1 = Iand1 / Ior1 - # IoU loss is (1-IoU1) - IoU = IoU + (1-IoU1) - # return IoU/b - return IoU - - -class StructureLoss(torch.nn.Module): - def __init__(self): - super(StructureLoss, self).__init__() - - def forward(self, pred, target): - weit = 1+5*torch.abs(F.avg_pool2d(target, kernel_size=31, stride=1, padding=15)-target) - wbce = F.binary_cross_entropy_with_logits(pred, target, reduction='none') - wbce = (weit*wbce).sum(dim=(2,3))/weit.sum(dim=(2,3)) - - pred = torch.sigmoid(pred) - inter = ((pred * target) * weit).sum(dim=(2, 3)) - union = ((pred + target) * weit).sum(dim=(2, 3)) - wiou = 1-(inter+1)/(union-inter+1) - - return (wbce+wiou).mean() - - -class PatchIoULoss(torch.nn.Module): - def __init__(self): - super(PatchIoULoss, self).__init__() - self.iou_loss = IoULoss() - - def forward(self, pred, target): - win_y, win_x = 64, 64 - iou_loss = 0. - for anchor_y in range(0, target.shape[0], win_y): - for anchor_x in range(0, target.shape[1], win_y): - patch_pred = pred[:, :, anchor_y:anchor_y+win_y, anchor_x:anchor_x+win_x] - patch_target = target[:, :, anchor_y:anchor_y+win_y, anchor_x:anchor_x+win_x] - patch_iou_loss = self.iou_loss(patch_pred, patch_target) - iou_loss += patch_iou_loss - return iou_loss - - -class ThrReg_loss(torch.nn.Module): - def __init__(self): - super(ThrReg_loss, self).__init__() - - def forward(self, pred, gt=None): - return torch.mean(1 - ((pred - 0) ** 2 + (pred - 1) ** 2)) - - -class ClsLoss(nn.Module): - """ - Auxiliary classification loss for each refined class output. - """ - def __init__(self): - super(ClsLoss, self).__init__() - self.config = Config() - self.lambdas_cls = self.config.lambdas_cls - - self.criterions_last = { - 'ce': nn.CrossEntropyLoss() - } - - def forward(self, preds, gt): - loss = 0. - for _, pred_lvl in enumerate(preds): - if pred_lvl is None: - continue - for criterion_name, criterion in self.criterions_last.items(): - loss += criterion(pred_lvl, gt) * self.lambdas_cls[criterion_name] - return loss - - -class PixLoss(nn.Module): - """ - Pixel loss for each refined map output. 
- """ - def __init__(self): - super(PixLoss, self).__init__() - self.config = Config() - self.lambdas_pix_last = self.config.lambdas_pix_last - - self.criterions_last = {} - if 'bce' in self.lambdas_pix_last and self.lambdas_pix_last['bce']: - self.criterions_last['bce'] = nn.BCELoss() if not self.config.use_fp16 else nn.BCEWithLogitsLoss() - if 'iou' in self.lambdas_pix_last and self.lambdas_pix_last['iou']: - self.criterions_last['iou'] = IoULoss() - if 'iou_patch' in self.lambdas_pix_last and self.lambdas_pix_last['iou_patch']: - self.criterions_last['iou_patch'] = PatchIoULoss() - if 'ssim' in self.lambdas_pix_last and self.lambdas_pix_last['ssim']: - self.criterions_last['ssim'] = SSIMLoss() - if 'mae' in self.lambdas_pix_last and self.lambdas_pix_last['mae']: - self.criterions_last['mae'] = nn.L1Loss() - if 'mse' in self.lambdas_pix_last and self.lambdas_pix_last['mse']: - self.criterions_last['mse'] = nn.MSELoss() - if 'reg' in self.lambdas_pix_last and self.lambdas_pix_last['reg']: - self.criterions_last['reg'] = ThrReg_loss() - if 'cnt' in self.lambdas_pix_last and self.lambdas_pix_last['cnt']: - self.criterions_last['cnt'] = ContourLoss() - if 'structure' in self.lambdas_pix_last and self.lambdas_pix_last['structure']: - self.criterions_last['structure'] = StructureLoss() - - def forward(self, scaled_preds, gt): - loss = 0. - criterions_embedded_with_sigmoid = ['structure', ] + ['bce'] if self.config.use_fp16 else [] - for _, pred_lvl in enumerate(scaled_preds): - if pred_lvl.shape != gt.shape: - pred_lvl = nn.functional.interpolate(pred_lvl, size=gt.shape[2:], mode='bilinear', align_corners=True) - for criterion_name, criterion in self.criterions_last.items(): - _loss = criterion(pred_lvl.sigmoid() if criterion_name not in criterions_embedded_with_sigmoid else pred_lvl, gt) * self.lambdas_pix_last[criterion_name] - loss += _loss - # print(criterion_name, _loss.item()) - return loss - - -class SSIMLoss(torch.nn.Module): - def __init__(self, window_size=11, size_average=True): - super(SSIMLoss, self).__init__() - self.window_size = window_size - self.size_average = size_average - self.channel = 1 - self.window = create_window(window_size, self.channel) - - def forward(self, img1, img2): - (_, channel, _, _) = img1.size() - if channel == self.channel and self.window.data.type() == img1.data.type(): - window = self.window - else: - window = create_window(self.window_size, channel) - if img1.is_cuda: - window = window.cuda(img1.get_device()) - window = window.type_as(img1) - self.window = window - self.channel = channel - return 1 - _ssim(img1, img2, window, self.window_size, channel, self.size_average) - - -def gaussian(window_size, sigma): - gauss = torch.Tensor([exp(-(x - window_size//2)**2/float(2*sigma**2)) for x in range(window_size)]) - return gauss/gauss.sum() - - -def create_window(window_size, channel): - _1D_window = gaussian(window_size, 1.5).unsqueeze(1) - _2D_window = _1D_window.mm(_1D_window.t()).float().unsqueeze(0).unsqueeze(0) - window = Variable(_2D_window.expand(channel, 1, window_size, window_size).contiguous()) - return window - - -def _ssim(img1, img2, window, window_size, channel, size_average=True): - mu1 = F.conv2d(img1, window, padding = window_size//2, groups=channel) - mu2 = F.conv2d(img2, window, padding = window_size//2, groups=channel) - - mu1_sq = mu1.pow(2) - mu2_sq = mu2.pow(2) - mu1_mu2 = mu1*mu2 - - sigma1_sq = F.conv2d(img1*img1, window, padding=window_size//2, groups=channel) - mu1_sq - sigma2_sq = F.conv2d(img2*img2, window, 
padding=window_size//2, groups=channel) - mu2_sq - sigma12 = F.conv2d(img1*img2, window, padding=window_size//2, groups=channel) - mu1_mu2 - - C1 = 0.01**2 - C2 = 0.03**2 - - ssim_map = ((2*mu1_mu2 + C1)*(2*sigma12 + C2))/((mu1_sq + mu2_sq + C1)*(sigma1_sq + sigma2_sq + C2)) - - if size_average: - return ssim_map.mean() - else: - return ssim_map.mean(1).mean(1).mean(1) - - -def SSIM(x, y): - C1 = 0.01 ** 2 - C2 = 0.03 ** 2 - - mu_x = nn.AvgPool2d(3, 1, 1)(x) - mu_y = nn.AvgPool2d(3, 1, 1)(y) - mu_x_mu_y = mu_x * mu_y - mu_x_sq = mu_x.pow(2) - mu_y_sq = mu_y.pow(2) - - sigma_x = nn.AvgPool2d(3, 1, 1)(x * x) - mu_x_sq - sigma_y = nn.AvgPool2d(3, 1, 1)(y * y) - mu_y_sq - sigma_xy = nn.AvgPool2d(3, 1, 1)(x * y) - mu_x_mu_y - - SSIM_n = (2 * mu_x_mu_y + C1) * (2 * sigma_xy + C2) - SSIM_d = (mu_x_sq + mu_y_sq + C1) * (sigma_x + sigma_y + C2) - SSIM = SSIM_n / SSIM_d - - return torch.clamp((1 - SSIM) / 2, 0, 1) - - -def saliency_structure_consistency(x, y): - ssim = torch.mean(SSIM(x,y)) - return ssim diff --git a/py/BiRefNet/make_a_copy.sh b/py/BiRefNet/make_a_copy.sh deleted file mode 100644 index 97a35fbe..00000000 --- a/py/BiRefNet/make_a_copy.sh +++ /dev/null @@ -1,18 +0,0 @@ -#!/bin/bash -# Set dst repo here. -repo=$1 -mkdir ../${repo} -mkdir ../${repo}/evaluation -mkdir ../${repo}/models -mkdir ../${repo}/models/backbones -mkdir ../${repo}/models/modules -mkdir ../${repo}/models/refinement - -cp ./*.sh ../${repo} -cp ./*.py ../${repo} -cp ./evaluation/*.py ../${repo}/evaluation -cp ./models/*.py ../${repo}/models -cp ./models/backbones/*.py ../${repo}/models/backbones -cp ./models/modules/*.py ../${repo}/models/modules -cp ./models/refinement/*.py ../${repo}/models/refinement -cp -r ./.git* ../${repo} diff --git a/py/BiRefNet/models/backbones/build_backbone.py b/py/BiRefNet/models/backbones/build_backbone.py deleted file mode 100644 index 699c6a87..00000000 --- a/py/BiRefNet/models/backbones/build_backbone.py +++ /dev/null @@ -1,44 +0,0 @@ -import torch -import torch.nn as nn -from collections import OrderedDict -from torchvision.models import vgg16, vgg16_bn, VGG16_Weights, VGG16_BN_Weights, resnet50, ResNet50_Weights -from BiRefNet.models.backbones.pvt_v2 import pvt_v2_b0, pvt_v2_b1, pvt_v2_b2, pvt_v2_b5 -from BiRefNet.models.backbones.swin_v1 import swin_v1_t, swin_v1_s, swin_v1_b, swin_v1_l -from BiRefNet.config import Config - - -config = Config() - -def build_backbone(bb_name, pretrained=True, params_settings=''): - if bb_name == 'vgg16': - bb_net = list(vgg16(pretrained=VGG16_Weights.DEFAULT if pretrained else None).children())[0] - bb = nn.Sequential(OrderedDict({'conv1': bb_net[:4], 'conv2': bb_net[4:9], 'conv3': bb_net[9:16], 'conv4': bb_net[16:23]})) - elif bb_name == 'vgg16bn': - bb_net = list(vgg16_bn(pretrained=VGG16_BN_Weights.DEFAULT if pretrained else None).children())[0] - bb = nn.Sequential(OrderedDict({'conv1': bb_net[:6], 'conv2': bb_net[6:13], 'conv3': bb_net[13:23], 'conv4': bb_net[23:33]})) - elif bb_name == 'resnet50': - bb_net = list(resnet50(pretrained=ResNet50_Weights.DEFAULT if pretrained else None).children()) - bb = nn.Sequential(OrderedDict({'conv1': nn.Sequential(*bb_net[0:3]), 'conv2': bb_net[4], 'conv3': bb_net[5], 'conv4': bb_net[6]})) - else: - bb = eval('{}({})'.format(bb_name, params_settings)) - if pretrained: - bb = load_weights(bb, bb_name) - return bb - -def load_weights(model, model_name): - save_model = torch.load(config.weights[model_name], map_location='cpu') - model_dict = model.state_dict() - state_dict = {k: v if v.size() == 
model_dict[k].size() else model_dict[k] for k, v in save_model.items() if k in model_dict.keys()} - # to ignore the weights with mismatched size when I modify the backbone itself. - if not state_dict: - save_model_keys = list(save_model.keys()) - sub_item = save_model_keys[0] if len(save_model_keys) == 1 else None - state_dict = {k: v if v.size() == model_dict[k].size() else model_dict[k] for k, v in save_model[sub_item].items() if k in model_dict.keys()} - if not state_dict or not sub_item: - print('Weights are not successully loaded. Check the state dict of weights file.') - return None - else: - print('Found correct weights in the "{}" item of loaded state_dict.'.format(sub_item)) - model_dict.update(state_dict) - model.load_state_dict(model_dict) - return model diff --git a/py/BiRefNet/models/backbones/pvt_v2.py b/py/BiRefNet/models/backbones/pvt_v2.py deleted file mode 100644 index ce6720ba..00000000 --- a/py/BiRefNet/models/backbones/pvt_v2.py +++ /dev/null @@ -1,435 +0,0 @@ -import torch -import torch.nn as nn -from functools import partial - -from timm.models.layers import DropPath, to_2tuple, trunc_normal_ -from timm.models.registry import register_model - -import math - -from BiRefNet.config import Config - -config = Config() - -class Mlp(nn.Module): - def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): - super().__init__() - out_features = out_features or in_features - hidden_features = hidden_features or in_features - self.fc1 = nn.Linear(in_features, hidden_features) - self.dwconv = DWConv(hidden_features) - self.act = act_layer() - self.fc2 = nn.Linear(hidden_features, out_features) - self.drop = nn.Dropout(drop) - - self.apply(self._init_weights) - - def _init_weights(self, m): - if isinstance(m, nn.Linear): - trunc_normal_(m.weight, std=.02) - if isinstance(m, nn.Linear) and m.bias is not None: - nn.init.constant_(m.bias, 0) - elif isinstance(m, nn.LayerNorm): - nn.init.constant_(m.bias, 0) - nn.init.constant_(m.weight, 1.0) - elif isinstance(m, nn.Conv2d): - fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels - fan_out //= m.groups - m.weight.data.normal_(0, math.sqrt(2.0 / fan_out)) - if m.bias is not None: - m.bias.data.zero_() - - def forward(self, x, H, W): - x = self.fc1(x) - x = self.dwconv(x, H, W) - x = self.act(x) - x = self.drop(x) - x = self.fc2(x) - x = self.drop(x) - return x - - -class Attention(nn.Module): - def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0., sr_ratio=1): - super().__init__() - assert dim % num_heads == 0, f"dim {dim} should be divided by num_heads {num_heads}." 
- - self.dim = dim - self.num_heads = num_heads - head_dim = dim // num_heads - self.scale = qk_scale or head_dim ** -0.5 - - self.q = nn.Linear(dim, dim, bias=qkv_bias) - self.kv = nn.Linear(dim, dim * 2, bias=qkv_bias) - self.attn_drop_prob = attn_drop - self.attn_drop = nn.Dropout(attn_drop) - self.proj = nn.Linear(dim, dim) - self.proj_drop = nn.Dropout(proj_drop) - - self.sr_ratio = sr_ratio - if sr_ratio > 1: - self.sr = nn.Conv2d(dim, dim, kernel_size=sr_ratio, stride=sr_ratio) - self.norm = nn.LayerNorm(dim) - - self.apply(self._init_weights) - - def _init_weights(self, m): - if isinstance(m, nn.Linear): - trunc_normal_(m.weight, std=.02) - if isinstance(m, nn.Linear) and m.bias is not None: - nn.init.constant_(m.bias, 0) - elif isinstance(m, nn.LayerNorm): - nn.init.constant_(m.bias, 0) - nn.init.constant_(m.weight, 1.0) - elif isinstance(m, nn.Conv2d): - fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels - fan_out //= m.groups - m.weight.data.normal_(0, math.sqrt(2.0 / fan_out)) - if m.bias is not None: - m.bias.data.zero_() - - def forward(self, x, H, W): - B, N, C = x.shape - q = self.q(x).reshape(B, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3) - - if self.sr_ratio > 1: - x_ = x.permute(0, 2, 1).reshape(B, C, H, W) - x_ = self.sr(x_).reshape(B, C, -1).permute(0, 2, 1) - x_ = self.norm(x_) - kv = self.kv(x_).reshape(B, -1, 2, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) - else: - kv = self.kv(x).reshape(B, -1, 2, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) - k, v = kv[0], kv[1] - - if config.SDPA_enabled: - x = torch.nn.functional.scaled_dot_product_attention( - q, k, v, - attn_mask=None, dropout_p=self.attn_drop_prob, is_causal=False - ).transpose(1, 2).reshape(B, N, C) - else: - attn = (q @ k.transpose(-2, -1)) * self.scale - attn = attn.softmax(dim=-1) - attn = self.attn_drop(attn) - - x = (attn @ v).transpose(1, 2).reshape(B, N, C) - x = self.proj(x) - x = self.proj_drop(x) - - return x - - -class Block(nn.Module): - - def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0., - drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, sr_ratio=1): - super().__init__() - self.norm1 = norm_layer(dim) - self.attn = Attention( - dim, - num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, - attn_drop=attn_drop, proj_drop=drop, sr_ratio=sr_ratio) - # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here - self.drop_path = DropPath(drop_path) if drop_path > 0. 
else nn.Identity() - self.norm2 = norm_layer(dim) - mlp_hidden_dim = int(dim * mlp_ratio) - self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) - - self.apply(self._init_weights) - - def _init_weights(self, m): - if isinstance(m, nn.Linear): - trunc_normal_(m.weight, std=.02) - if isinstance(m, nn.Linear) and m.bias is not None: - nn.init.constant_(m.bias, 0) - elif isinstance(m, nn.LayerNorm): - nn.init.constant_(m.bias, 0) - nn.init.constant_(m.weight, 1.0) - elif isinstance(m, nn.Conv2d): - fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels - fan_out //= m.groups - m.weight.data.normal_(0, math.sqrt(2.0 / fan_out)) - if m.bias is not None: - m.bias.data.zero_() - - def forward(self, x, H, W): - x = x + self.drop_path(self.attn(self.norm1(x), H, W)) - x = x + self.drop_path(self.mlp(self.norm2(x), H, W)) - - return x - - -class OverlapPatchEmbed(nn.Module): - """ Image to Patch Embedding - """ - - def __init__(self, img_size=224, patch_size=7, stride=4, in_channels=3, embed_dim=768): - super().__init__() - img_size = to_2tuple(img_size) - patch_size = to_2tuple(patch_size) - - self.img_size = img_size - self.patch_size = patch_size - self.H, self.W = img_size[0] // patch_size[0], img_size[1] // patch_size[1] - self.num_patches = self.H * self.W - self.proj = nn.Conv2d(in_channels, embed_dim, kernel_size=patch_size, stride=stride, - padding=(patch_size[0] // 2, patch_size[1] // 2)) - self.norm = nn.LayerNorm(embed_dim) - - self.apply(self._init_weights) - - def _init_weights(self, m): - if isinstance(m, nn.Linear): - trunc_normal_(m.weight, std=.02) - if isinstance(m, nn.Linear) and m.bias is not None: - nn.init.constant_(m.bias, 0) - elif isinstance(m, nn.LayerNorm): - nn.init.constant_(m.bias, 0) - nn.init.constant_(m.weight, 1.0) - elif isinstance(m, nn.Conv2d): - fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels - fan_out //= m.groups - m.weight.data.normal_(0, math.sqrt(2.0 / fan_out)) - if m.bias is not None: - m.bias.data.zero_() - - def forward(self, x): - x = self.proj(x) - _, _, H, W = x.shape - x = x.flatten(2).transpose(1, 2) - x = self.norm(x) - - return x, H, W - - -class PyramidVisionTransformerImpr(nn.Module): - def __init__(self, img_size=224, patch_size=16, in_channels=3, num_classes=1000, embed_dims=[64, 128, 256, 512], - num_heads=[1, 2, 4, 8], mlp_ratios=[4, 4, 4, 4], qkv_bias=False, qk_scale=None, drop_rate=0., - attn_drop_rate=0., drop_path_rate=0., norm_layer=nn.LayerNorm, - depths=[3, 4, 6, 3], sr_ratios=[8, 4, 2, 1]): - super().__init__() - self.num_classes = num_classes - self.depths = depths - - # patch_embed - self.patch_embed1 = OverlapPatchEmbed(img_size=img_size, patch_size=7, stride=4, in_channels=in_channels, - embed_dim=embed_dims[0]) - self.patch_embed2 = OverlapPatchEmbed(img_size=img_size // 4, patch_size=3, stride=2, in_channels=embed_dims[0], - embed_dim=embed_dims[1]) - self.patch_embed3 = OverlapPatchEmbed(img_size=img_size // 8, patch_size=3, stride=2, in_channels=embed_dims[1], - embed_dim=embed_dims[2]) - self.patch_embed4 = OverlapPatchEmbed(img_size=img_size // 16, patch_size=3, stride=2, in_channels=embed_dims[2], - embed_dim=embed_dims[3]) - - # transformer encoder - dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] # stochastic depth decay rule - cur = 0 - self.block1 = nn.ModuleList([Block( - dim=embed_dims[0], num_heads=num_heads[0], mlp_ratio=mlp_ratios[0], qkv_bias=qkv_bias, qk_scale=qk_scale, - drop=drop_rate, attn_drop=attn_drop_rate, 
drop_path=dpr[cur + i], norm_layer=norm_layer, - sr_ratio=sr_ratios[0]) - for i in range(depths[0])]) - self.norm1 = norm_layer(embed_dims[0]) - - cur += depths[0] - self.block2 = nn.ModuleList([Block( - dim=embed_dims[1], num_heads=num_heads[1], mlp_ratio=mlp_ratios[1], qkv_bias=qkv_bias, qk_scale=qk_scale, - drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[cur + i], norm_layer=norm_layer, - sr_ratio=sr_ratios[1]) - for i in range(depths[1])]) - self.norm2 = norm_layer(embed_dims[1]) - - cur += depths[1] - self.block3 = nn.ModuleList([Block( - dim=embed_dims[2], num_heads=num_heads[2], mlp_ratio=mlp_ratios[2], qkv_bias=qkv_bias, qk_scale=qk_scale, - drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[cur + i], norm_layer=norm_layer, - sr_ratio=sr_ratios[2]) - for i in range(depths[2])]) - self.norm3 = norm_layer(embed_dims[2]) - - cur += depths[2] - self.block4 = nn.ModuleList([Block( - dim=embed_dims[3], num_heads=num_heads[3], mlp_ratio=mlp_ratios[3], qkv_bias=qkv_bias, qk_scale=qk_scale, - drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[cur + i], norm_layer=norm_layer, - sr_ratio=sr_ratios[3]) - for i in range(depths[3])]) - self.norm4 = norm_layer(embed_dims[3]) - - # classification head - # self.head = nn.Linear(embed_dims[3], num_classes) if num_classes > 0 else nn.Identity() - - self.apply(self._init_weights) - - def _init_weights(self, m): - if isinstance(m, nn.Linear): - trunc_normal_(m.weight, std=.02) - if isinstance(m, nn.Linear) and m.bias is not None: - nn.init.constant_(m.bias, 0) - elif isinstance(m, nn.LayerNorm): - nn.init.constant_(m.bias, 0) - nn.init.constant_(m.weight, 1.0) - elif isinstance(m, nn.Conv2d): - fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels - fan_out //= m.groups - m.weight.data.normal_(0, math.sqrt(2.0 / fan_out)) - if m.bias is not None: - m.bias.data.zero_() - - def init_weights(self, pretrained=None): - if isinstance(pretrained, str): - logger = 1 - #load_checkpoint(self, pretrained, map_location='cpu', strict=False, logger=logger) - - def reset_drop_path(self, drop_path_rate): - dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(self.depths))] - cur = 0 - for i in range(self.depths[0]): - self.block1[i].drop_path.drop_prob = dpr[cur + i] - - cur += self.depths[0] - for i in range(self.depths[1]): - self.block2[i].drop_path.drop_prob = dpr[cur + i] - - cur += self.depths[1] - for i in range(self.depths[2]): - self.block3[i].drop_path.drop_prob = dpr[cur + i] - - cur += self.depths[2] - for i in range(self.depths[3]): - self.block4[i].drop_path.drop_prob = dpr[cur + i] - - def freeze_patch_emb(self): - self.patch_embed1.requires_grad = False - - @torch.jit.ignore - def no_weight_decay(self): - return {'pos_embed1', 'pos_embed2', 'pos_embed3', 'pos_embed4', 'cls_token'} # has pos_embed may be better - - def get_classifier(self): - return self.head - - def reset_classifier(self, num_classes, global_pool=''): - self.num_classes = num_classes - self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity() - - def forward_features(self, x): - B = x.shape[0] - outs = [] - - # stage 1 - x, H, W = self.patch_embed1(x) - for i, blk in enumerate(self.block1): - x = blk(x, H, W) - x = self.norm1(x) - x = x.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous() - outs.append(x) - - # stage 2 - x, H, W = self.patch_embed2(x) - for i, blk in enumerate(self.block2): - x = blk(x, H, W) - x = self.norm2(x) - x = x.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous() - outs.append(x) - - # stage 
3 - x, H, W = self.patch_embed3(x) - for i, blk in enumerate(self.block3): - x = blk(x, H, W) - x = self.norm3(x) - x = x.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous() - outs.append(x) - - # stage 4 - x, H, W = self.patch_embed4(x) - for i, blk in enumerate(self.block4): - x = blk(x, H, W) - x = self.norm4(x) - x = x.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous() - outs.append(x) - - return outs - - # return x.mean(dim=1) - - def forward(self, x): - x = self.forward_features(x) - # x = self.head(x) - - return x - - -class DWConv(nn.Module): - def __init__(self, dim=768): - super(DWConv, self).__init__() - self.dwconv = nn.Conv2d(dim, dim, 3, 1, 1, bias=True, groups=dim) - - def forward(self, x, H, W): - B, N, C = x.shape - x = x.transpose(1, 2).view(B, C, H, W).contiguous() - x = self.dwconv(x) - x = x.flatten(2).transpose(1, 2) - - return x - - -def _conv_filter(state_dict, patch_size=16): - """ convert patch embedding weight from manual patchify + linear proj to conv""" - out_dict = {} - for k, v in state_dict.items(): - if 'patch_embed.proj.weight' in k: - v = v.reshape((v.shape[0], 3, patch_size, patch_size)) - out_dict[k] = v - - return out_dict - - -## @register_model -class pvt_v2_b0(PyramidVisionTransformerImpr): - def __init__(self, **kwargs): - super(pvt_v2_b0, self).__init__( - patch_size=4, embed_dims=[32, 64, 160, 256], num_heads=[1, 2, 5, 8], mlp_ratios=[8, 8, 4, 4], - qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), depths=[2, 2, 2, 2], sr_ratios=[8, 4, 2, 1], - drop_rate=0.0, drop_path_rate=0.1) - - - -## @register_model -class pvt_v2_b1(PyramidVisionTransformerImpr): - def __init__(self, **kwargs): - super(pvt_v2_b1, self).__init__( - patch_size=4, embed_dims=[64, 128, 320, 512], num_heads=[1, 2, 5, 8], mlp_ratios=[8, 8, 4, 4], - qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), depths=[2, 2, 2, 2], sr_ratios=[8, 4, 2, 1], - drop_rate=0.0, drop_path_rate=0.1) - -## @register_model -class pvt_v2_b2(PyramidVisionTransformerImpr): - def __init__(self, in_channels=3, **kwargs): - super(pvt_v2_b2, self).__init__( - patch_size=4, embed_dims=[64, 128, 320, 512], num_heads=[1, 2, 5, 8], mlp_ratios=[8, 8, 4, 4], - qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), depths=[3, 4, 6, 3], sr_ratios=[8, 4, 2, 1], - drop_rate=0.0, drop_path_rate=0.1, in_channels=in_channels) - -## @register_model -class pvt_v2_b3(PyramidVisionTransformerImpr): - def __init__(self, **kwargs): - super(pvt_v2_b3, self).__init__( - patch_size=4, embed_dims=[64, 128, 320, 512], num_heads=[1, 2, 5, 8], mlp_ratios=[8, 8, 4, 4], - qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), depths=[3, 4, 18, 3], sr_ratios=[8, 4, 2, 1], - drop_rate=0.0, drop_path_rate=0.1) - -## @register_model -class pvt_v2_b4(PyramidVisionTransformerImpr): - def __init__(self, **kwargs): - super(pvt_v2_b4, self).__init__( - patch_size=4, embed_dims=[64, 128, 320, 512], num_heads=[1, 2, 5, 8], mlp_ratios=[8, 8, 4, 4], - qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), depths=[3, 8, 27, 3], sr_ratios=[8, 4, 2, 1], - drop_rate=0.0, drop_path_rate=0.1) - - -## @register_model -class pvt_v2_b5(PyramidVisionTransformerImpr): - def __init__(self, **kwargs): - super(pvt_v2_b5, self).__init__( - patch_size=4, embed_dims=[64, 128, 320, 512], num_heads=[1, 2, 5, 8], mlp_ratios=[4, 4, 4, 4], - qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), depths=[3, 6, 40, 3], sr_ratios=[8, 4, 2, 1], - drop_rate=0.0, drop_path_rate=0.1) diff --git a/py/BiRefNet/models/backbones/swin_v1.py 
b/py/BiRefNet/models/backbones/swin_v1.py deleted file mode 100644 index b3a93964..00000000 --- a/py/BiRefNet/models/backbones/swin_v1.py +++ /dev/null @@ -1,627 +0,0 @@ -# -------------------------------------------------------- -# Swin Transformer -# Copyright (c) 2021 Microsoft -# Licensed under The MIT License [see LICENSE for details] -# Written by Ze Liu, Yutong Lin, Yixuan Wei -# -------------------------------------------------------- - -import torch -import torch.nn as nn -import torch.nn.functional as F -import torch.utils.checkpoint as checkpoint -import numpy as np -from timm.models.layers import DropPath, to_2tuple, trunc_normal_ - -from BiRefNet.config import Config - - -config = Config() - -class Mlp(nn.Module): - """ Multilayer perceptron.""" - - def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): - super().__init__() - out_features = out_features or in_features - hidden_features = hidden_features or in_features - self.fc1 = nn.Linear(in_features, hidden_features) - self.act = act_layer() - self.fc2 = nn.Linear(hidden_features, out_features) - self.drop = nn.Dropout(drop) - - def forward(self, x): - x = self.fc1(x) - x = self.act(x) - x = self.drop(x) - x = self.fc2(x) - x = self.drop(x) - return x - - -def window_partition(x, window_size): - """ - Args: - x: (B, H, W, C) - window_size (int): window size - - Returns: - windows: (num_windows*B, window_size, window_size, C) - """ - B, H, W, C = x.shape - x = x.view(B, H // window_size, window_size, W // window_size, window_size, C) - windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C) - return windows - - -def window_reverse(windows, window_size, H, W): - """ - Args: - windows: (num_windows*B, window_size, window_size, C) - window_size (int): Window size - H (int): Height of image - W (int): Width of image - - Returns: - x: (B, H, W, C) - """ - B = int(windows.shape[0] / (H * W / window_size / window_size)) - x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1) - x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1) - return x - - -class WindowAttention(nn.Module): - """ Window based multi-head self attention (W-MSA) module with relative position bias. - It supports both of shifted and non-shifted window. - - Args: - dim (int): Number of input channels. - window_size (tuple[int]): The height and width of the window. - num_heads (int): Number of attention heads. - qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True - qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set - attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0 - proj_drop (float, optional): Dropout ratio of output. 
Default: 0.0 - """ - - def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.): - - super().__init__() - self.dim = dim - self.window_size = window_size # Wh, Ww - self.num_heads = num_heads - head_dim = dim // num_heads - self.scale = qk_scale or head_dim ** -0.5 - - # define a parameter table of relative position bias - self.relative_position_bias_table = nn.Parameter( - torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads)) # 2*Wh-1 * 2*Ww-1, nH - - # get pair-wise relative position index for each token inside the window - coords_h = torch.arange(self.window_size[0]) - coords_w = torch.arange(self.window_size[1]) - coords = torch.stack(torch.meshgrid([coords_h, coords_w], indexing='ij')) # 2, Wh, Ww - coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww - relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww - relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2 - relative_coords[:, :, 0] += self.window_size[0] - 1 # shift to start from 0 - relative_coords[:, :, 1] += self.window_size[1] - 1 - relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1 - relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww - self.register_buffer("relative_position_index", relative_position_index) - - self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) - self.attn_drop_prob = attn_drop - self.attn_drop = nn.Dropout(attn_drop) - self.proj = nn.Linear(dim, dim) - self.proj_drop = nn.Dropout(proj_drop) - - trunc_normal_(self.relative_position_bias_table, std=.02) - self.softmax = nn.Softmax(dim=-1) - - def forward(self, x, mask=None): - """ Forward function. - - Args: - x: input features with shape of (num_windows*B, N, C) - mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None - """ - B_, N, C = x.shape - qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) - q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple) - - q = q * self.scale - - if config.SDPA_enabled: - x = torch.nn.functional.scaled_dot_product_attention( - q, k, v, - attn_mask=None, dropout_p=self.attn_drop_prob, is_causal=False - ).transpose(1, 2).reshape(B_, N, C) - else: - attn = (q @ k.transpose(-2, -1)) - - relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view( - self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1) # Wh*Ww,Wh*Ww,nH - relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww - attn = attn + relative_position_bias.unsqueeze(0) - - if mask is not None: - nW = mask.shape[0] - attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0) - attn = attn.view(-1, self.num_heads, N, N) - attn = self.softmax(attn) - else: - attn = self.softmax(attn) - - attn = self.attn_drop(attn) - - x = (attn @ v).transpose(1, 2).reshape(B_, N, C) - x = self.proj(x) - x = self.proj_drop(x) - return x - - -class SwinTransformerBlock(nn.Module): - """ Swin Transformer Block. - - Args: - dim (int): Number of input channels. - num_heads (int): Number of attention heads. - window_size (int): Window size. - shift_size (int): Shift size for SW-MSA. - mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. - qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. 
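# Illustrative sketch (example values, not part of the sources above): how the relative
# position bias used by WindowAttention is indexed. For a (Wh, Ww) window, every ordered
# pair of tokens maps to one of (2*Wh-1)*(2*Ww-1) learned bias entries per head; the
# index tensor is built once and registered as a buffer, exactly as in the code above.
import torch

Wh, Ww, num_heads = 3, 3, 2
coords = torch.stack(torch.meshgrid([torch.arange(Wh), torch.arange(Ww)], indexing='ij'))  # 2, Wh, Ww
coords_flat = torch.flatten(coords, 1)                                       # 2, Wh*Ww
rel = (coords_flat[:, :, None] - coords_flat[:, None, :]).permute(1, 2, 0).contiguous()  # Wh*Ww, Wh*Ww, 2
rel[:, :, 0] += Wh - 1            # shift row offsets to start from 0
rel[:, :, 1] += Ww - 1            # shift column offsets to start from 0
rel[:, :, 0] *= 2 * Ww - 1        # flatten (row, col) offsets into one index
relative_position_index = rel.sum(-1)                                        # Wh*Ww, Wh*Ww

bias_table = torch.zeros((2 * Wh - 1) * (2 * Ww - 1), num_heads)
bias = bias_table[relative_position_index.view(-1)].view(Wh * Ww, Wh * Ww, -1)
bias = bias.permute(2, 0, 1).contiguous()                                    # nH, Wh*Ww, Wh*Ww

assert relative_position_index.max().item() == (2 * Wh - 1) * (2 * Ww - 1) - 1
assert bias.shape == (num_heads, Wh * Ww, Wh * Ww)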
Default: True - qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. - drop (float, optional): Dropout rate. Default: 0.0 - attn_drop (float, optional): Attention dropout rate. Default: 0.0 - drop_path (float, optional): Stochastic depth rate. Default: 0.0 - act_layer (nn.Module, optional): Activation layer. Default: nn.GELU - norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm - """ - - def __init__(self, dim, num_heads, window_size=7, shift_size=0, - mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., drop_path=0., - act_layer=nn.GELU, norm_layer=nn.LayerNorm): - super().__init__() - self.dim = dim - self.num_heads = num_heads - self.window_size = window_size - self.shift_size = shift_size - self.mlp_ratio = mlp_ratio - assert 0 <= self.shift_size < self.window_size, "shift_size must in 0-window_size" - - self.norm1 = norm_layer(dim) - self.attn = WindowAttention( - dim, window_size=to_2tuple(self.window_size), num_heads=num_heads, - qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop) - - self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() - self.norm2 = norm_layer(dim) - mlp_hidden_dim = int(dim * mlp_ratio) - self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) - - self.H = None - self.W = None - - def forward(self, x, mask_matrix): - """ Forward function. - - Args: - x: Input feature, tensor size (B, H*W, C). - H, W: Spatial resolution of the input feature. - mask_matrix: Attention mask for cyclic shift. - """ - B, L, C = x.shape - H, W = self.H, self.W - assert L == H * W, "input feature has wrong size" - - shortcut = x - x = self.norm1(x) - x = x.view(B, H, W, C) - - # pad feature maps to multiples of window size - pad_l = pad_t = 0 - pad_r = (self.window_size - W % self.window_size) % self.window_size - pad_b = (self.window_size - H % self.window_size) % self.window_size - x = F.pad(x, (0, 0, pad_l, pad_r, pad_t, pad_b)) - _, Hp, Wp, _ = x.shape - - # cyclic shift - if self.shift_size > 0: - shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2)) - attn_mask = mask_matrix - else: - shifted_x = x - attn_mask = None - - # partition windows - x_windows = window_partition(shifted_x, self.window_size) # nW*B, window_size, window_size, C - x_windows = x_windows.view(-1, self.window_size * self.window_size, C) # nW*B, window_size*window_size, C - - # W-MSA/SW-MSA - attn_windows = self.attn(x_windows, mask=attn_mask) # nW*B, window_size*window_size, C - - # merge windows - attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C) - shifted_x = window_reverse(attn_windows, self.window_size, Hp, Wp) # B H' W' C - - # reverse cyclic shift - if self.shift_size > 0: - x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2)) - else: - x = shifted_x - - if pad_r > 0 or pad_b > 0: - x = x[:, :H, :W, :].contiguous() - - x = x.view(B, H * W, C) - - # FFN - x = shortcut + self.drop_path(x) - x = x + self.drop_path(self.mlp(self.norm2(x))) - - return x - - -class PatchMerging(nn.Module): - """ Patch Merging Layer - - Args: - dim (int): Number of input channels. - norm_layer (nn.Module, optional): Normalization layer. 
Default: nn.LayerNorm - """ - def __init__(self, dim, norm_layer=nn.LayerNorm): - super().__init__() - self.dim = dim - self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False) - self.norm = norm_layer(4 * dim) - - def forward(self, x, H, W): - """ Forward function. - - Args: - x: Input feature, tensor size (B, H*W, C). - H, W: Spatial resolution of the input feature. - """ - B, L, C = x.shape - assert L == H * W, "input feature has wrong size" - - x = x.view(B, H, W, C) - - # padding - pad_input = (H % 2 == 1) or (W % 2 == 1) - if pad_input: - x = F.pad(x, (0, 0, 0, W % 2, 0, H % 2)) - - x0 = x[:, 0::2, 0::2, :] # B H/2 W/2 C - x1 = x[:, 1::2, 0::2, :] # B H/2 W/2 C - x2 = x[:, 0::2, 1::2, :] # B H/2 W/2 C - x3 = x[:, 1::2, 1::2, :] # B H/2 W/2 C - x = torch.cat([x0, x1, x2, x3], -1) # B H/2 W/2 4*C - x = x.view(B, -1, 4 * C) # B H/2*W/2 4*C - - x = self.norm(x) - x = self.reduction(x) - - return x - - -class BasicLayer(nn.Module): - """ A basic Swin Transformer layer for one stage. - - Args: - dim (int): Number of feature channels - depth (int): Depths of this stage. - num_heads (int): Number of attention head. - window_size (int): Local window size. Default: 7. - mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4. - qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True - qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. - drop (float, optional): Dropout rate. Default: 0.0 - attn_drop (float, optional): Attention dropout rate. Default: 0.0 - drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0 - norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm - downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None - use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False. - """ - - def __init__(self, - dim, - depth, - num_heads, - window_size=7, - mlp_ratio=4., - qkv_bias=True, - qk_scale=None, - drop=0., - attn_drop=0., - drop_path=0., - norm_layer=nn.LayerNorm, - downsample=None, - use_checkpoint=False): - super().__init__() - self.window_size = window_size - self.shift_size = window_size // 2 - self.depth = depth - self.use_checkpoint = use_checkpoint - - # build blocks - self.blocks = nn.ModuleList([ - SwinTransformerBlock( - dim=dim, - num_heads=num_heads, - window_size=window_size, - shift_size=0 if (i % 2 == 0) else window_size // 2, - mlp_ratio=mlp_ratio, - qkv_bias=qkv_bias, - qk_scale=qk_scale, - drop=drop, - attn_drop=attn_drop, - drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path, - norm_layer=norm_layer) - for i in range(depth)]) - - # patch merging layer - if downsample is not None: - self.downsample = downsample(dim=dim, norm_layer=norm_layer) - else: - self.downsample = None - - def forward(self, x, H, W): - """ Forward function. - - Args: - x: Input feature, tensor size (B, H*W, C). - H, W: Spatial resolution of the input feature. 
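# Illustrative sketch (example sizes): the PatchMerging layer above halves the spatial
# resolution by regrouping each 2x2 neighbourhood into the channel dimension (C -> 4C)
# and then projecting 4C -> 2C with a bias-free linear layer after LayerNorm.
import torch
import torch.nn as nn

B, H, W, C = 2, 8, 8, 96
x = torch.randn(B, H * W, C).view(B, H, W, C)     # tokens reshaped to (B, H, W, C)

x0 = x[:, 0::2, 0::2, :]                          # top-left of each 2x2 block
x1 = x[:, 1::2, 0::2, :]                          # bottom-left
x2 = x[:, 0::2, 1::2, :]                          # top-right
x3 = x[:, 1::2, 1::2, :]                          # bottom-right
merged = torch.cat([x0, x1, x2, x3], -1).view(B, -1, 4 * C)   # B, H/2*W/2, 4C

reduction = nn.Linear(4 * C, 2 * C, bias=False)
out = reduction(nn.LayerNorm(4 * C)(merged))
assert out.shape == (B, (H // 2) * (W // 2), 2 * C)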
- """ - - # calculate attention mask for SW-MSA - Hp = int(np.ceil(H / self.window_size)) * self.window_size - Wp = int(np.ceil(W / self.window_size)) * self.window_size - img_mask = torch.zeros((1, Hp, Wp, 1), device=x.device) # 1 Hp Wp 1 - h_slices = (slice(0, -self.window_size), - slice(-self.window_size, -self.shift_size), - slice(-self.shift_size, None)) - w_slices = (slice(0, -self.window_size), - slice(-self.window_size, -self.shift_size), - slice(-self.shift_size, None)) - cnt = 0 - for h in h_slices: - for w in w_slices: - img_mask[:, h, w, :] = cnt - cnt += 1 - - mask_windows = window_partition(img_mask, self.window_size) # nW, window_size, window_size, 1 - mask_windows = mask_windows.view(-1, self.window_size * self.window_size) - attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2) - attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0)) - - for blk in self.blocks: - blk.H, blk.W = H, W - if self.use_checkpoint: - x = checkpoint.checkpoint(blk, x, attn_mask) - else: - x = blk(x, attn_mask) - if self.downsample is not None: - x_down = self.downsample(x, H, W) - Wh, Ww = (H + 1) // 2, (W + 1) // 2 - return x, H, W, x_down, Wh, Ww - else: - return x, H, W, x, H, W - - -class PatchEmbed(nn.Module): - """ Image to Patch Embedding - - Args: - patch_size (int): Patch token size. Default: 4. - in_channels (int): Number of input image channels. Default: 3. - embed_dim (int): Number of linear projection output channels. Default: 96. - norm_layer (nn.Module, optional): Normalization layer. Default: None - """ - - def __init__(self, patch_size=4, in_channels=3, embed_dim=96, norm_layer=None): - super().__init__() - patch_size = to_2tuple(patch_size) - self.patch_size = patch_size - - self.in_channels = in_channels - self.embed_dim = embed_dim - - self.proj = nn.Conv2d(in_channels, embed_dim, kernel_size=patch_size, stride=patch_size) - if norm_layer is not None: - self.norm = norm_layer(embed_dim) - else: - self.norm = None - - def forward(self, x): - """Forward function.""" - # padding - _, _, H, W = x.size() - if W % self.patch_size[1] != 0: - x = F.pad(x, (0, self.patch_size[1] - W % self.patch_size[1])) - if H % self.patch_size[0] != 0: - x = F.pad(x, (0, 0, 0, self.patch_size[0] - H % self.patch_size[0])) - - x = self.proj(x) # B C Wh Ww - if self.norm is not None: - Wh, Ww = x.size(2), x.size(3) - x = x.flatten(2).transpose(1, 2) - x = self.norm(x) - x = x.transpose(1, 2).view(-1, self.embed_dim, Wh, Ww) - - return x - - -class SwinTransformer(nn.Module): - """ Swin Transformer backbone. - A PyTorch impl of : `Swin Transformer: Hierarchical Vision Transformer using Shifted Windows` - - https://arxiv.org/pdf/2103.14030 - - Args: - pretrain_img_size (int): Input image size for training the pretrained model, - used in absolute postion embedding. Default 224. - patch_size (int | tuple(int)): Patch size. Default: 4. - in_channels (int): Number of input image channels. Default: 3. - embed_dim (int): Number of linear projection output channels. Default: 96. - depths (tuple[int]): Depths of each Swin Transformer stage. - num_heads (tuple[int]): Number of attention head of each stage. - window_size (int): Window size. Default: 7. - mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4. - qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True - qk_scale (float): Override default qk scale of head_dim ** -0.5 if set. - drop_rate (float): Dropout rate. 
- attn_drop_rate (float): Attention dropout rate. Default: 0. - drop_path_rate (float): Stochastic depth rate. Default: 0.2. - norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm. - ape (bool): If True, add absolute position embedding to the patch embedding. Default: False. - patch_norm (bool): If True, add normalization after patch embedding. Default: True. - out_indices (Sequence[int]): Output from which stages. - frozen_stages (int): Stages to be frozen (stop grad and set eval mode). - -1 means not freezing any parameters. - use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False. - """ - - def __init__(self, - pretrain_img_size=224, - patch_size=4, - in_channels=3, - embed_dim=96, - depths=[2, 2, 6, 2], - num_heads=[3, 6, 12, 24], - window_size=7, - mlp_ratio=4., - qkv_bias=True, - qk_scale=None, - drop_rate=0., - attn_drop_rate=0., - drop_path_rate=0.2, - norm_layer=nn.LayerNorm, - ape=False, - patch_norm=True, - out_indices=(0, 1, 2, 3), - frozen_stages=-1, - use_checkpoint=False): - super().__init__() - - self.pretrain_img_size = pretrain_img_size - self.num_layers = len(depths) - self.embed_dim = embed_dim - self.ape = ape - self.patch_norm = patch_norm - self.out_indices = out_indices - self.frozen_stages = frozen_stages - - # split image into non-overlapping patches - self.patch_embed = PatchEmbed( - patch_size=patch_size, in_channels=in_channels, embed_dim=embed_dim, - norm_layer=norm_layer if self.patch_norm else None) - - # absolute position embedding - if self.ape: - pretrain_img_size = to_2tuple(pretrain_img_size) - patch_size = to_2tuple(patch_size) - patches_resolution = [pretrain_img_size[0] // patch_size[0], pretrain_img_size[1] // patch_size[1]] - - self.absolute_pos_embed = nn.Parameter(torch.zeros(1, embed_dim, patches_resolution[0], patches_resolution[1])) - trunc_normal_(self.absolute_pos_embed, std=.02) - - self.pos_drop = nn.Dropout(p=drop_rate) - - # stochastic depth - dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] # stochastic depth decay rule - - # build layers - self.layers = nn.ModuleList() - for i_layer in range(self.num_layers): - layer = BasicLayer( - dim=int(embed_dim * 2 ** i_layer), - depth=depths[i_layer], - num_heads=num_heads[i_layer], - window_size=window_size, - mlp_ratio=mlp_ratio, - qkv_bias=qkv_bias, - qk_scale=qk_scale, - drop=drop_rate, - attn_drop=attn_drop_rate, - drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])], - norm_layer=norm_layer, - downsample=PatchMerging if (i_layer < self.num_layers - 1) else None, - use_checkpoint=use_checkpoint) - self.layers.append(layer) - - num_features = [int(embed_dim * 2 ** i) for i in range(self.num_layers)] - self.num_features = num_features - - # add a norm layer for each output - for i_layer in out_indices: - layer = norm_layer(num_features[i_layer]) - layer_name = f'norm{i_layer}' - self.add_module(layer_name, layer) - - self._freeze_stages() - - def _freeze_stages(self): - if self.frozen_stages >= 0: - self.patch_embed.eval() - for param in self.patch_embed.parameters(): - param.requires_grad = False - - if self.frozen_stages >= 1 and self.ape: - self.absolute_pos_embed.requires_grad = False - - if self.frozen_stages >= 2: - self.pos_drop.eval() - for i in range(0, self.frozen_stages - 1): - m = self.layers[i] - m.eval() - for param in m.parameters(): - param.requires_grad = False - - - def forward(self, x): - """Forward function.""" - x = self.patch_embed(x) - - Wh, Ww = x.size(2), x.size(3) - if self.ape: - # 
interpolate the position embedding to the corresponding size - absolute_pos_embed = F.interpolate(self.absolute_pos_embed, size=(Wh, Ww), mode='bicubic') - x = (x + absolute_pos_embed) # B Wh*Ww C - - outs = []#x.contiguous()] - x = x.flatten(2).transpose(1, 2) - x = self.pos_drop(x) - for i in range(self.num_layers): - layer = self.layers[i] - x_out, H, W, x, Wh, Ww = layer(x, Wh, Ww) - - if i in self.out_indices: - norm_layer = getattr(self, f'norm{i}') - x_out = norm_layer(x_out) - - out = x_out.view(-1, H, W, self.num_features[i]).permute(0, 3, 1, 2).contiguous() - outs.append(out) - - return tuple(outs) - - def train(self, mode=True): - """Convert the model into training mode while keep layers freezed.""" - super(SwinTransformer, self).train(mode) - self._freeze_stages() - -def swin_v1_t(): - model = SwinTransformer(embed_dim=96, depths=[2, 2, 6, 2], num_heads=[3, 6, 12, 24], window_size=7) - return model - -def swin_v1_s(): - model = SwinTransformer(embed_dim=96, depths=[2, 2, 18, 2], num_heads=[3, 6, 12, 24], window_size=7) - return model - -def swin_v1_b(): - model = SwinTransformer(embed_dim=128, depths=[2, 2, 18, 2], num_heads=[4, 8, 16, 32], window_size=12) - return model - -def swin_v1_l(): - model = SwinTransformer(embed_dim=192, depths=[2, 2, 18, 2], num_heads=[6, 12, 24, 48], window_size=12) - return model diff --git a/py/BiRefNet/models/birefnet.py b/py/BiRefNet/models/birefnet.py deleted file mode 100644 index 72e6c405..00000000 --- a/py/BiRefNet/models/birefnet.py +++ /dev/null @@ -1,286 +0,0 @@ -import torch -import torch.nn as nn -import torch.nn.functional as F -from kornia.filters import laplacian -from huggingface_hub import PyTorchModelHubMixin - -from BiRefNet.config import Config -from BiRefNet.dataset import class_labels_TR_sorted -from BiRefNet.models.backbones.build_backbone import build_backbone -from BiRefNet.models.modules.decoder_blocks import BasicDecBlk, ResBlk -from BiRefNet.models.modules.lateral_blocks import BasicLatBlk -from BiRefNet.models.modules.aspp import ASPP, ASPPDeformable -from BiRefNet.models.refinement.refiner import Refiner, RefinerPVTInChannels4, RefUNet -from BiRefNet.models.refinement.stem_layer import StemLayer - - -class BiRefNet( - nn.Module, - PyTorchModelHubMixin, - library_name="birefnet", - repo_url="https://github.com/ZhengPeng7/BiRefNet", - tags=['Image Segmentation', 'Background Removal', 'Mask Generation', 'Dichotomous Image Segmentation', 'Camouflaged Object Detection', 'Salient Object Detection'] -): - def __init__(self, bb_pretrained=True): - super(BiRefNet, self).__init__() - self.config = Config() - self.epoch = 1 - self.bb = build_backbone(self.config.bb, pretrained=bb_pretrained) - - channels = self.config.lateral_channels_in_collection - - if self.config.auxiliary_classification: - self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) - self.cls_head = nn.Sequential( - nn.Linear(channels[0], len(class_labels_TR_sorted)) - ) - - if self.config.squeeze_block: - self.squeeze_module = nn.Sequential(*[ - eval(self.config.squeeze_block.split('_x')[0])(channels[0]+sum(self.config.cxt), channels[0]) - for _ in range(eval(self.config.squeeze_block.split('_x')[1])) - ]) - - self.decoder = Decoder(channels) - - if self.config.ender: - self.dec_end = nn.Sequential( - nn.Conv2d(1, 16, 3, 1, 1), - nn.Conv2d(16, 1, 3, 1, 1), - nn.ReLU(inplace=True), - ) - - # refine patch-level segmentation - if self.config.refine: - if self.config.refine == 'itself': - self.stem_layer = StemLayer(in_channels=3+1, inter_channels=48, out_channels=3, 
norm_layer='BN' if self.config.batch_size > 1 else 'LN') - else: - self.refiner = eval('{}({})'.format(self.config.refine, 'in_channels=3+1')) - - if self.config.freeze_bb: - # Freeze the backbone... - print(self.named_parameters()) - for key, value in self.named_parameters(): - if 'bb.' in key and 'refiner.' not in key: - value.requires_grad = False - - def forward_enc(self, x): - if self.config.bb in ['vgg16', 'vgg16bn', 'resnet50']: - x1 = self.bb.conv1(x); x2 = self.bb.conv2(x1); x3 = self.bb.conv3(x2); x4 = self.bb.conv4(x3) - else: - x1, x2, x3, x4 = self.bb(x) - if self.config.mul_scl_ipt == 'cat': - B, C, H, W = x.shape - x1_, x2_, x3_, x4_ = self.bb(F.interpolate(x, size=(H//2, W//2), mode='bilinear', align_corners=True)) - x1 = torch.cat([x1, F.interpolate(x1_, size=x1.shape[2:], mode='bilinear', align_corners=True)], dim=1) - x2 = torch.cat([x2, F.interpolate(x2_, size=x2.shape[2:], mode='bilinear', align_corners=True)], dim=1) - x3 = torch.cat([x3, F.interpolate(x3_, size=x3.shape[2:], mode='bilinear', align_corners=True)], dim=1) - x4 = torch.cat([x4, F.interpolate(x4_, size=x4.shape[2:], mode='bilinear', align_corners=True)], dim=1) - elif self.config.mul_scl_ipt == 'add': - B, C, H, W = x.shape - x1_, x2_, x3_, x4_ = self.bb(F.interpolate(x, size=(H//2, W//2), mode='bilinear', align_corners=True)) - x1 = x1 + F.interpolate(x1_, size=x1.shape[2:], mode='bilinear', align_corners=True) - x2 = x2 + F.interpolate(x2_, size=x2.shape[2:], mode='bilinear', align_corners=True) - x3 = x3 + F.interpolate(x3_, size=x3.shape[2:], mode='bilinear', align_corners=True) - x4 = x4 + F.interpolate(x4_, size=x4.shape[2:], mode='bilinear', align_corners=True) - class_preds = self.cls_head(self.avgpool(x4).view(x4.shape[0], -1)) if self.training and self.config.auxiliary_classification else None - if self.config.cxt: - x4 = torch.cat( - ( - *[ - F.interpolate(x1, size=x4.shape[2:], mode='bilinear', align_corners=True), - F.interpolate(x2, size=x4.shape[2:], mode='bilinear', align_corners=True), - F.interpolate(x3, size=x4.shape[2:], mode='bilinear', align_corners=True), - ][-len(self.config.cxt):], - x4 - ), - dim=1 - ) - return (x1, x2, x3, x4), class_preds - - def forward_ori(self, x): - ########## Encoder ########## - (x1, x2, x3, x4), class_preds = self.forward_enc(x) - if self.config.squeeze_block: - x4 = self.squeeze_module(x4) - ########## Decoder ########## - features = [x, x1, x2, x3, x4] - if self.training and self.config.out_ref: - features.append(laplacian(torch.mean(x, dim=1).unsqueeze(1), kernel_size=5)) - scaled_preds = self.decoder(features) - return scaled_preds, class_preds - - def forward(self, x): - scaled_preds, class_preds = self.forward_ori(x) - class_preds_lst = [class_preds] - return [scaled_preds, class_preds_lst] if self.training else scaled_preds - - -class Decoder(nn.Module): - def __init__(self, channels): - super(Decoder, self).__init__() - self.config = Config() - DecoderBlock = eval(self.config.dec_blk) - LateralBlock = eval(self.config.lat_blk) - - if self.config.dec_ipt: - self.split = self.config.dec_ipt_split - N_dec_ipt = 64 - DBlock = SimpleConvs - ic = 64 - ipt_cha_opt = 1 - self.ipt_blk5 = DBlock(2**10*3 if self.split else 3, [N_dec_ipt, channels[0]//8][ipt_cha_opt], inter_channels=ic) - self.ipt_blk4 = DBlock(2**8*3 if self.split else 3, [N_dec_ipt, channels[0]//8][ipt_cha_opt], inter_channels=ic) - self.ipt_blk3 = DBlock(2**6*3 if self.split else 3, [N_dec_ipt, channels[1]//8][ipt_cha_opt], inter_channels=ic) - self.ipt_blk2 = DBlock(2**4*3 if self.split 
else 3, [N_dec_ipt, channels[2]//8][ipt_cha_opt], inter_channels=ic) - self.ipt_blk1 = DBlock(2**0*3 if self.split else 3, [N_dec_ipt, channels[3]//8][ipt_cha_opt], inter_channels=ic) - else: - self.split = None - - self.decoder_block4 = DecoderBlock(channels[0]+([N_dec_ipt, channels[0]//8][ipt_cha_opt] if self.config.dec_ipt else 0), channels[1]) - self.decoder_block3 = DecoderBlock(channels[1]+([N_dec_ipt, channels[0]//8][ipt_cha_opt] if self.config.dec_ipt else 0), channels[2]) - self.decoder_block2 = DecoderBlock(channels[2]+([N_dec_ipt, channels[1]//8][ipt_cha_opt] if self.config.dec_ipt else 0), channels[3]) - self.decoder_block1 = DecoderBlock(channels[3]+([N_dec_ipt, channels[2]//8][ipt_cha_opt] if self.config.dec_ipt else 0), channels[3]//2) - self.conv_out1 = nn.Sequential(nn.Conv2d(channels[3]//2+([N_dec_ipt, channels[3]//8][ipt_cha_opt] if self.config.dec_ipt else 0), 1, 1, 1, 0)) - - self.lateral_block4 = LateralBlock(channels[1], channels[1]) - self.lateral_block3 = LateralBlock(channels[2], channels[2]) - self.lateral_block2 = LateralBlock(channels[3], channels[3]) - - if self.config.ms_supervision: - self.conv_ms_spvn_4 = nn.Conv2d(channels[1], 1, 1, 1, 0) - self.conv_ms_spvn_3 = nn.Conv2d(channels[2], 1, 1, 1, 0) - self.conv_ms_spvn_2 = nn.Conv2d(channels[3], 1, 1, 1, 0) - - if self.config.out_ref: - _N = 16 - self.gdt_convs_4 = nn.Sequential(nn.Conv2d(channels[1], _N, 3, 1, 1), nn.BatchNorm2d(_N) if self.config.batch_size > 1 else nn.Identity(), nn.ReLU(inplace=True)) - self.gdt_convs_3 = nn.Sequential(nn.Conv2d(channels[2], _N, 3, 1, 1), nn.BatchNorm2d(_N) if self.config.batch_size > 1 else nn.Identity(), nn.ReLU(inplace=True)) - self.gdt_convs_2 = nn.Sequential(nn.Conv2d(channels[3], _N, 3, 1, 1), nn.BatchNorm2d(_N) if self.config.batch_size > 1 else nn.Identity(), nn.ReLU(inplace=True)) - - self.gdt_convs_pred_4 = nn.Sequential(nn.Conv2d(_N, 1, 1, 1, 0)) - self.gdt_convs_pred_3 = nn.Sequential(nn.Conv2d(_N, 1, 1, 1, 0)) - self.gdt_convs_pred_2 = nn.Sequential(nn.Conv2d(_N, 1, 1, 1, 0)) - - self.gdt_convs_attn_4 = nn.Sequential(nn.Conv2d(_N, 1, 1, 1, 0)) - self.gdt_convs_attn_3 = nn.Sequential(nn.Conv2d(_N, 1, 1, 1, 0)) - self.gdt_convs_attn_2 = nn.Sequential(nn.Conv2d(_N, 1, 1, 1, 0)) - - def get_patches_batch(self, x, p): - _size_h, _size_w = p.shape[2:] - patches_batch = [] - for idx in range(x.shape[0]): - columns_x = torch.split(x[idx], split_size_or_sections=_size_w, dim=-1) - patches_x = [] - for column_x in columns_x: - patches_x += [p.unsqueeze(0) for p in torch.split(column_x, split_size_or_sections=_size_h, dim=-2)] - patch_sample = torch.cat(patches_x, dim=1) - patches_batch.append(patch_sample) - return torch.cat(patches_batch, dim=0) - - def forward(self, features): - if self.training and self.config.out_ref: - outs_gdt_pred = [] - outs_gdt_label = [] - x, x1, x2, x3, x4, gdt_gt = features - else: - x, x1, x2, x3, x4 = features - outs = [] - - if self.config.dec_ipt: - patches_batch = self.get_patches_batch(x, x4) if self.split else x - x4 = torch.cat((x4, self.ipt_blk5(F.interpolate(patches_batch, size=x4.shape[2:], mode='bilinear', align_corners=True))), 1) - p4 = self.decoder_block4(x4) - m4 = self.conv_ms_spvn_4(p4) if self.config.ms_supervision and self.training else None - if self.config.out_ref: - p4_gdt = self.gdt_convs_4(p4) - if self.training: - # >> GT: - m4_dia = m4 - gdt_label_main_4 = gdt_gt * F.interpolate(m4_dia, size=gdt_gt.shape[2:], mode='bilinear', align_corners=True) - outs_gdt_label.append(gdt_label_main_4) - # >> Pred: - gdt_pred_4 = 
self.gdt_convs_pred_4(p4_gdt) - outs_gdt_pred.append(gdt_pred_4) - gdt_attn_4 = self.gdt_convs_attn_4(p4_gdt).sigmoid() - # >> Finally: - p4 = p4 * gdt_attn_4 - _p4 = F.interpolate(p4, size=x3.shape[2:], mode='bilinear', align_corners=True) - _p3 = _p4 + self.lateral_block4(x3) - - if self.config.dec_ipt: - patches_batch = self.get_patches_batch(x, _p3) if self.split else x - _p3 = torch.cat((_p3, self.ipt_blk4(F.interpolate(patches_batch, size=x3.shape[2:], mode='bilinear', align_corners=True))), 1) - p3 = self.decoder_block3(_p3) - m3 = self.conv_ms_spvn_3(p3) if self.config.ms_supervision and self.training else None - if self.config.out_ref: - p3_gdt = self.gdt_convs_3(p3) - if self.training: - # >> GT: - # m3 --dilation--> m3_dia - # G_3^gt * m3_dia --> G_3^m, which is the label of gradient - m3_dia = m3 - gdt_label_main_3 = gdt_gt * F.interpolate(m3_dia, size=gdt_gt.shape[2:], mode='bilinear', align_corners=True) - outs_gdt_label.append(gdt_label_main_3) - # >> Pred: - # p3 --conv--BN--> F_3^G, where F_3^G predicts the \hat{G_3} with xx - # F_3^G --sigmoid--> A_3^G - gdt_pred_3 = self.gdt_convs_pred_3(p3_gdt) - outs_gdt_pred.append(gdt_pred_3) - gdt_attn_3 = self.gdt_convs_attn_3(p3_gdt).sigmoid() - # >> Finally: - # p3 = p3 * A_3^G - p3 = p3 * gdt_attn_3 - _p3 = F.interpolate(p3, size=x2.shape[2:], mode='bilinear', align_corners=True) - _p2 = _p3 + self.lateral_block3(x2) - - if self.config.dec_ipt: - patches_batch = self.get_patches_batch(x, _p2) if self.split else x - _p2 = torch.cat((_p2, self.ipt_blk3(F.interpolate(patches_batch, size=x2.shape[2:], mode='bilinear', align_corners=True))), 1) - p2 = self.decoder_block2(_p2) - m2 = self.conv_ms_spvn_2(p2) if self.config.ms_supervision and self.training else None - if self.config.out_ref: - p2_gdt = self.gdt_convs_2(p2) - if self.training: - # >> GT: - m2_dia = m2 - gdt_label_main_2 = gdt_gt * F.interpolate(m2_dia, size=gdt_gt.shape[2:], mode='bilinear', align_corners=True) - outs_gdt_label.append(gdt_label_main_2) - # >> Pred: - gdt_pred_2 = self.gdt_convs_pred_2(p2_gdt) - outs_gdt_pred.append(gdt_pred_2) - gdt_attn_2 = self.gdt_convs_attn_2(p2_gdt).sigmoid() - # >> Finally: - p2 = p2 * gdt_attn_2 - _p2 = F.interpolate(p2, size=x1.shape[2:], mode='bilinear', align_corners=True) - _p1 = _p2 + self.lateral_block2(x1) - - if self.config.dec_ipt: - patches_batch = self.get_patches_batch(x, _p1) if self.split else x - _p1 = torch.cat((_p1, self.ipt_blk2(F.interpolate(patches_batch, size=x1.shape[2:], mode='bilinear', align_corners=True))), 1) - _p1 = self.decoder_block1(_p1) - _p1 = F.interpolate(_p1, size=x.shape[2:], mode='bilinear', align_corners=True) - - if self.config.dec_ipt: - patches_batch = self.get_patches_batch(x, _p1) if self.split else x - _p1 = torch.cat((_p1, self.ipt_blk1(F.interpolate(patches_batch, size=x.shape[2:], mode='bilinear', align_corners=True))), 1) - p1_out = self.conv_out1(_p1) - - if self.config.ms_supervision and self.training: - outs.append(m4) - outs.append(m3) - outs.append(m2) - outs.append(p1_out) - return outs if not (self.config.out_ref and self.training) else ([outs_gdt_pred, outs_gdt_label], outs) - - -class SimpleConvs(nn.Module): - def __init__( - self, in_channels: int, out_channels: int, inter_channels=64 - ) -> None: - super().__init__() - self.conv1 = nn.Conv2d(in_channels, inter_channels, 3, 1, 1) - self.conv_out = nn.Conv2d(inter_channels, out_channels, 3, 1, 1) - - def forward(self, x): - return self.conv_out(self.conv1(x)) diff --git a/py/BiRefNet/models/modules/aspp.py 
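# Illustrative sketch: Decoder.get_patches_batch above tiles the full-resolution input
# into patches the size of a given feature map and stacks the tiles along the channel
# axis, which is why ipt_blk5 expects 2**10*3 input channels when dec_ipt_split is on
# (assuming the coarsest feature map is 1/32 of the input, i.e. a 32x32 grid of
# 3-channel tiles). Shapes below are small example values only.
import torch

def get_patches_batch(x, p):
    size_h, size_w = p.shape[2:]
    patches_batch = []
    for idx in range(x.shape[0]):
        columns = torch.split(x[idx], split_size_or_sections=size_w, dim=-1)
        patches = []
        for column in columns:
            patches += [t.unsqueeze(0) for t in torch.split(column, split_size_or_sections=size_h, dim=-2)]
        patches_batch.append(torch.cat(patches, dim=1))   # tiles stacked on channels
    return torch.cat(patches_batch, dim=0)

x = torch.randn(2, 3, 64, 64)        # input image batch
p = torch.randn(2, 512, 8, 8)        # stand-in for a 1/8-scale feature map
tiles = get_patches_batch(x, p)
assert tiles.shape == (2, 3 * (64 // 8) * (64 // 8), 8, 8)   # C * number of tiles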
b/py/BiRefNet/models/modules/aspp.py deleted file mode 100644 index 7686db96..00000000 --- a/py/BiRefNet/models/modules/aspp.py +++ /dev/null @@ -1,120 +0,0 @@ -import torch -import torch.nn as nn -import torch.nn.functional as F - -from BiRefNet.models.modules.deform_conv import DeformableConv2d -from BiRefNet.config import Config - - -config = Config() - - -class _ASPPModule(nn.Module): - def __init__(self, in_channels, planes, kernel_size, padding, dilation): - super(_ASPPModule, self).__init__() - self.atrous_conv = nn.Conv2d(in_channels, planes, kernel_size=kernel_size, - stride=1, padding=padding, dilation=dilation, bias=False) - self.bn = nn.BatchNorm2d(planes) if config.batch_size > 1 else nn.Identity() - self.relu = nn.ReLU(inplace=True) - - def forward(self, x): - x = self.atrous_conv(x) - x = self.bn(x) - - return self.relu(x) - - -class ASPP(nn.Module): - def __init__(self, in_channels=64, out_channels=None, output_stride=16): - super(ASPP, self).__init__() - self.down_scale = 1 - if out_channels is None: - out_channels = in_channels - self.in_channelster = 256 // self.down_scale - if output_stride == 16: - dilations = [1, 6, 12, 18] - elif output_stride == 8: - dilations = [1, 12, 24, 36] - else: - raise NotImplementedError - - self.aspp1 = _ASPPModule(in_channels, self.in_channelster, 1, padding=0, dilation=dilations[0]) - self.aspp2 = _ASPPModule(in_channels, self.in_channelster, 3, padding=dilations[1], dilation=dilations[1]) - self.aspp3 = _ASPPModule(in_channels, self.in_channelster, 3, padding=dilations[2], dilation=dilations[2]) - self.aspp4 = _ASPPModule(in_channels, self.in_channelster, 3, padding=dilations[3], dilation=dilations[3]) - - self.global_avg_pool = nn.Sequential(nn.AdaptiveAvgPool2d((1, 1)), - nn.Conv2d(in_channels, self.in_channelster, 1, stride=1, bias=False), - nn.BatchNorm2d(self.in_channelster) if config.batch_size > 1 else nn.Identity(), - nn.ReLU(inplace=True)) - self.conv1 = nn.Conv2d(self.in_channelster * 5, out_channels, 1, bias=False) - self.bn1 = nn.BatchNorm2d(out_channels) if config.batch_size > 1 else nn.Identity() - self.relu = nn.ReLU(inplace=True) - self.dropout = nn.Dropout(0.5) - - def forward(self, x): - x1 = self.aspp1(x) - x2 = self.aspp2(x) - x3 = self.aspp3(x) - x4 = self.aspp4(x) - x5 = self.global_avg_pool(x) - x5 = F.interpolate(x5, size=x1.size()[2:], mode='bilinear', align_corners=True) - x = torch.cat((x1, x2, x3, x4, x5), dim=1) - - x = self.conv1(x) - x = self.bn1(x) - x = self.relu(x) - - return self.dropout(x) - - -##################### Deformable -class _ASPPModuleDeformable(nn.Module): - def __init__(self, in_channels, planes, kernel_size, padding): - super(_ASPPModuleDeformable, self).__init__() - self.atrous_conv = DeformableConv2d(in_channels, planes, kernel_size=kernel_size, - stride=1, padding=padding, bias=False) - self.bn = nn.BatchNorm2d(planes) if config.batch_size > 1 else nn.Identity() - self.relu = nn.ReLU(inplace=True) - - def forward(self, x): - x = self.atrous_conv(x) - x = self.bn(x) - - return self.relu(x) - - -class ASPPDeformable(nn.Module): - def __init__(self, in_channels, out_channels=None, parallel_block_sizes=[1, 3, 7]): - super(ASPPDeformable, self).__init__() - self.down_scale = 1 - if out_channels is None: - out_channels = in_channels - self.in_channelster = 256 // self.down_scale - - self.aspp1 = _ASPPModuleDeformable(in_channels, self.in_channelster, 1, padding=0) - self.aspp_deforms = nn.ModuleList([ - _ASPPModuleDeformable(in_channels, self.in_channelster, conv_size, 
padding=int(conv_size//2)) for conv_size in parallel_block_sizes - ]) - - self.global_avg_pool = nn.Sequential(nn.AdaptiveAvgPool2d((1, 1)), - nn.Conv2d(in_channels, self.in_channelster, 1, stride=1, bias=False), - nn.BatchNorm2d(self.in_channelster) if config.batch_size > 1 else nn.Identity(), - nn.ReLU(inplace=True)) - self.conv1 = nn.Conv2d(self.in_channelster * (2 + len(self.aspp_deforms)), out_channels, 1, bias=False) - self.bn1 = nn.BatchNorm2d(out_channels) if config.batch_size > 1 else nn.Identity() - self.relu = nn.ReLU(inplace=True) - self.dropout = nn.Dropout(0.5) - - def forward(self, x): - x1 = self.aspp1(x) - x_aspp_deforms = [aspp_deform(x) for aspp_deform in self.aspp_deforms] - x5 = self.global_avg_pool(x) - x5 = F.interpolate(x5, size=x1.size()[2:], mode='bilinear', align_corners=True) - x = torch.cat((x1, *x_aspp_deforms, x5), dim=1) - - x = self.conv1(x) - x = self.bn1(x) - x = self.relu(x) - - return self.dropout(x) diff --git a/py/BiRefNet/models/modules/decoder_blocks.py b/py/BiRefNet/models/modules/decoder_blocks.py deleted file mode 100644 index 9bfbbd85..00000000 --- a/py/BiRefNet/models/modules/decoder_blocks.py +++ /dev/null @@ -1,66 +0,0 @@ -import torch -import torch.nn as nn - -from BiRefNet.models.modules.aspp import ASPP, ASPPDeformable -from BiRefNet.config import Config - - -config = Config() - - -class BasicDecBlk(nn.Module): - def __init__(self, in_channels=64, out_channels=64, inter_channels=64): - super(BasicDecBlk, self).__init__() - inter_channels = in_channels // 4 if config.dec_channels_inter == 'adap' else 64 - self.conv_in = nn.Conv2d(in_channels, inter_channels, 3, 1, padding=1) - self.relu_in = nn.ReLU(inplace=True) - if config.dec_att == 'ASPP': - self.dec_att = ASPP(in_channels=inter_channels) - elif config.dec_att == 'ASPPDeformable': - self.dec_att = ASPPDeformable(in_channels=inter_channels) - self.conv_out = nn.Conv2d(inter_channels, out_channels, 3, 1, padding=1) - self.bn_in = nn.BatchNorm2d(inter_channels) if config.batch_size > 1 else nn.Identity() - self.bn_out = nn.BatchNorm2d(out_channels) if config.batch_size > 1 else nn.Identity() - - def forward(self, x): - x = self.conv_in(x) - x = self.bn_in(x) - x = self.relu_in(x) - if hasattr(self, 'dec_att'): - x = self.dec_att(x) - x = self.conv_out(x) - x = self.bn_out(x) - return x - - -class ResBlk(nn.Module): - def __init__(self, in_channels=64, out_channels=None, inter_channels=64): - super(ResBlk, self).__init__() - if out_channels is None: - out_channels = in_channels - inter_channels = in_channels // 4 if config.dec_channels_inter == 'adap' else 64 - - self.conv_in = nn.Conv2d(in_channels, inter_channels, 3, 1, padding=1) - self.bn_in = nn.BatchNorm2d(inter_channels) if config.batch_size > 1 else nn.Identity() - self.relu_in = nn.ReLU(inplace=True) - - if config.dec_att == 'ASPP': - self.dec_att = ASPP(in_channels=inter_channels) - elif config.dec_att == 'ASPPDeformable': - self.dec_att = ASPPDeformable(in_channels=inter_channels) - - self.conv_out = nn.Conv2d(inter_channels, out_channels, 3, 1, padding=1) - self.bn_out = nn.BatchNorm2d(out_channels) if config.batch_size > 1 else nn.Identity() - - self.conv_resi = nn.Conv2d(in_channels, out_channels, 1, 1, 0) - - def forward(self, x): - _x = self.conv_resi(x) - x = self.conv_in(x) - x = self.bn_in(x) - x = self.relu_in(x) - if hasattr(self, 'dec_att'): - x = self.dec_att(x) - x = self.conv_out(x) - x = self.bn_out(x) - return x + _x diff --git a/py/BiRefNet/models/modules/deform_conv.py b/py/BiRefNet/models/modules/deform_conv.py 
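# Illustrative sketch: ResBlk above is a plain residual bottleneck; a 1x1 projection
# carries the input to the output width while the main path runs conv-norm-ReLU, an
# optional ASPP attention, and a final conv-norm. The config-driven pieces (ASPP /
# ASPPDeformable, the batch_size>1 BatchNorm switch) are omitted, so this is only the
# skeleton with example channel sizes.
import torch
import torch.nn as nn

class TinyResBlk(nn.Module):
    def __init__(self, in_channels=64, out_channels=64, inter_channels=64):
        super().__init__()
        self.conv_in = nn.Conv2d(in_channels, inter_channels, 3, 1, padding=1)
        self.relu_in = nn.ReLU(inplace=True)
        self.conv_out = nn.Conv2d(inter_channels, out_channels, 3, 1, padding=1)
        self.conv_resi = nn.Conv2d(in_channels, out_channels, 1, 1, 0)   # shortcut path

    def forward(self, x):
        residual = self.conv_resi(x)
        x = self.relu_in(self.conv_in(x))
        return self.conv_out(x) + residual

y = TinyResBlk(128, 64)(torch.randn(1, 128, 32, 32))
assert y.shape == (1, 64, 32, 32)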
deleted file mode 100644 index 43f5e57f..00000000 --- a/py/BiRefNet/models/modules/deform_conv.py +++ /dev/null @@ -1,66 +0,0 @@ -import torch -import torch.nn as nn -from torchvision.ops import deform_conv2d - - -class DeformableConv2d(nn.Module): - def __init__(self, - in_channels, - out_channels, - kernel_size=3, - stride=1, - padding=1, - bias=False): - - super(DeformableConv2d, self).__init__() - - assert type(kernel_size) == tuple or type(kernel_size) == int - - kernel_size = kernel_size if type(kernel_size) == tuple else (kernel_size, kernel_size) - self.stride = stride if type(stride) == tuple else (stride, stride) - self.padding = padding - - self.offset_conv = nn.Conv2d(in_channels, - 2 * kernel_size[0] * kernel_size[1], - kernel_size=kernel_size, - stride=stride, - padding=self.padding, - bias=True) - - nn.init.constant_(self.offset_conv.weight, 0.) - nn.init.constant_(self.offset_conv.bias, 0.) - - self.modulator_conv = nn.Conv2d(in_channels, - 1 * kernel_size[0] * kernel_size[1], - kernel_size=kernel_size, - stride=stride, - padding=self.padding, - bias=True) - - nn.init.constant_(self.modulator_conv.weight, 0.) - nn.init.constant_(self.modulator_conv.bias, 0.) - - self.regular_conv = nn.Conv2d(in_channels, - out_channels=out_channels, - kernel_size=kernel_size, - stride=stride, - padding=self.padding, - bias=bias) - - def forward(self, x): - #h, w = x.shape[2:] - #max_offset = max(h, w)/4. - - offset = self.offset_conv(x)#.clamp(-max_offset, max_offset) - modulator = 2. * torch.sigmoid(self.modulator_conv(x)) - - x = deform_conv2d( - input=x, - offset=offset, - weight=self.regular_conv.weight, - bias=self.regular_conv.bias, - padding=self.padding, - mask=modulator, - stride=self.stride, - ) - return x diff --git a/py/BiRefNet/models/modules/lateral_blocks.py b/py/BiRefNet/models/modules/lateral_blocks.py deleted file mode 100644 index abee527f..00000000 --- a/py/BiRefNet/models/modules/lateral_blocks.py +++ /dev/null @@ -1,21 +0,0 @@ -import numpy as np -import torch -import torch.nn as nn -import torch.nn.functional as F -from functools import partial - -from BiRefNet.config import Config - - -config = Config() - - -class BasicLatBlk(nn.Module): - def __init__(self, in_channels=64, out_channels=64, inter_channels=64): - super(BasicLatBlk, self).__init__() - inter_channels = in_channels // 4 if config.dec_channels_inter == 'adap' else 64 - self.conv = nn.Conv2d(in_channels, out_channels, 1, 1, 0) - - def forward(self, x): - x = self.conv(x) - return x diff --git a/py/BiRefNet/models/modules/mlp.py b/py/BiRefNet/models/modules/mlp.py deleted file mode 100644 index 39b35683..00000000 --- a/py/BiRefNet/models/modules/mlp.py +++ /dev/null @@ -1,118 +0,0 @@ -import torch -import torch.nn as nn -from functools import partial - -from timm.models.layers import DropPath, to_2tuple, trunc_normal_ -from timm.models.registry import register_model - -import math - - -class MLPLayer(nn.Module): - def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): - super().__init__() - out_features = out_features or in_features - hidden_features = hidden_features or in_features - self.fc1 = nn.Linear(in_features, hidden_features) - self.act = act_layer() - self.fc2 = nn.Linear(hidden_features, out_features) - self.drop = nn.Dropout(drop) - - def forward(self, x): - x = self.fc1(x) - x = self.act(x) - x = self.drop(x) - x = self.fc2(x) - x = self.drop(x) - return x - - -class Attention(nn.Module): - def __init__(self, dim, num_heads=8, qkv_bias=False, 
qk_scale=None, attn_drop=0., proj_drop=0., sr_ratio=1): - super().__init__() - assert dim % num_heads == 0, f"dim {dim} should be divided by num_heads {num_heads}." - - self.dim = dim - self.num_heads = num_heads - head_dim = dim // num_heads - self.scale = qk_scale or head_dim ** -0.5 - - self.q = nn.Linear(dim, dim, bias=qkv_bias) - self.kv = nn.Linear(dim, dim * 2, bias=qkv_bias) - self.attn_drop = nn.Dropout(attn_drop) - self.proj = nn.Linear(dim, dim) - self.proj_drop = nn.Dropout(proj_drop) - - self.sr_ratio = sr_ratio - if sr_ratio > 1: - self.sr = nn.Conv2d(dim, dim, kernel_size=sr_ratio, stride=sr_ratio) - self.norm = nn.LayerNorm(dim) - - def forward(self, x, H, W): - B, N, C = x.shape - q = self.q(x).reshape(B, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3) - - if self.sr_ratio > 1: - x_ = x.permute(0, 2, 1).reshape(B, C, H, W) - x_ = self.sr(x_).reshape(B, C, -1).permute(0, 2, 1) - x_ = self.norm(x_) - kv = self.kv(x_).reshape(B, -1, 2, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) - else: - kv = self.kv(x).reshape(B, -1, 2, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) - k, v = kv[0], kv[1] - - attn = (q @ k.transpose(-2, -1)) * self.scale - attn = attn.softmax(dim=-1) - attn = self.attn_drop(attn) - - x = (attn @ v).transpose(1, 2).reshape(B, N, C) - x = self.proj(x) - x = self.proj_drop(x) - return x - - -class Block(nn.Module): - def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0., - drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, sr_ratio=1): - super().__init__() - self.norm1 = norm_layer(dim) - self.attn = Attention( - dim, - num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, - attn_drop=attn_drop, proj_drop=drop, sr_ratio=sr_ratio) - # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here - self.drop_path = DropPath(drop_path) if drop_path > 0. 
else nn.Identity() - self.norm2 = norm_layer(dim) - mlp_hidden_dim = int(dim * mlp_ratio) - self.mlp = MLPLayer(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) - - def forward(self, x, H, W): - x = x + self.drop_path(self.attn(self.norm1(x), H, W)) - x = x + self.drop_path(self.mlp(self.norm2(x), H, W)) - return x - - -class OverlapPatchEmbed(nn.Module): - """ Image to Patch Embedding - """ - - def __init__(self, img_size=224, patch_size=7, stride=4, in_channels=3, embed_dim=768): - super().__init__() - img_size = to_2tuple(img_size) - patch_size = to_2tuple(patch_size) - - self.img_size = img_size - self.patch_size = patch_size - self.H, self.W = img_size[0] // patch_size[0], img_size[1] // patch_size[1] - self.num_patches = self.H * self.W - self.proj = nn.Conv2d(in_channels, embed_dim, kernel_size=patch_size, stride=stride, - padding=(patch_size[0] // 2, patch_size[1] // 2)) - self.norm = nn.LayerNorm(embed_dim) - - def forward(self, x): - x = self.proj(x) - _, _, H, W = x.shape - x = x.flatten(2).transpose(1, 2) - x = self.norm(x) - return x, H, W - diff --git a/py/BiRefNet/models/modules/prompt_encoder.py b/py/BiRefNet/models/modules/prompt_encoder.py deleted file mode 100644 index 23ce18c3..00000000 --- a/py/BiRefNet/models/modules/prompt_encoder.py +++ /dev/null @@ -1,222 +0,0 @@ -import numpy as np -import torch -import torch.nn as nn -from typing import Any, Optional, Tuple, Type - - -class PromptEncoder(nn.Module): - def __init__( - self, - embed_dim=256, - image_embedding_size=1024, - input_image_size=(1024, 1024), - mask_in_chans=16, - activation=nn.GELU - ) -> None: - super().__init__() - """ - Codes are partially from SAM: https://github.com/facebookresearch/segment-anything/blob/6fdee8f2727f4506cfbbe553e23b895e27956588/segment_anything/modeling/prompt_encoder.py. - - Arguments: - embed_dim (int): The prompts' embedding dimension - image_embedding_size (tuple(int, int)): The spatial size of the - image embedding, as (H, W). - input_image_size (int): The padded size of the image as input - to the image encoder, as (H, W). - mask_in_chans (int): The number of hidden channels used for - encoding input masks. - activation (nn.Module): The activation to use when encoding - input masks. - """ - super().__init__() - self.embed_dim = embed_dim - self.input_image_size = input_image_size - self.image_embedding_size = image_embedding_size - self.pe_layer = PositionEmbeddingRandom(embed_dim // 2) - - self.num_point_embeddings: int = 4 # pos/neg point + 2 box corners - point_embeddings = [nn.Embedding(1, embed_dim) for i in range(self.num_point_embeddings)] - self.point_embeddings = nn.ModuleList(point_embeddings) - self.not_a_point_embed = nn.Embedding(1, embed_dim) - - self.mask_input_size = (4 * image_embedding_size[0], 4 * image_embedding_size[1]) - self.mask_downscaling = nn.Sequential( - nn.Conv2d(1, mask_in_chans // 4, kernel_size=2, stride=2), - LayerNorm2d(mask_in_chans // 4), - activation(), - nn.Conv2d(mask_in_chans // 4, mask_in_chans, kernel_size=2, stride=2), - LayerNorm2d(mask_in_chans), - activation(), - nn.Conv2d(mask_in_chans, embed_dim, kernel_size=1), - ) - self.no_mask_embed = nn.Embedding(1, embed_dim) - - def get_dense_pe(self) -> torch.Tensor: - """ - Returns the positional encoding used to encode point prompts, - applied to a dense set of points the shape of the image encoding. 
- - Returns: - torch.Tensor: Positional encoding with shape - 1x(embed_dim)x(embedding_h)x(embedding_w) - """ - return self.pe_layer(self.image_embedding_size).unsqueeze(0) - - def _embed_points( - self, - points: torch.Tensor, - labels: torch.Tensor, - pad: bool, - ) -> torch.Tensor: - """Embeds point prompts.""" - points = points + 0.5 # Shift to center of pixel - if pad: - padding_point = torch.zeros((points.shape[0], 1, 2), device=points.device) - padding_label = -torch.ones((labels.shape[0], 1), device=labels.device) - points = torch.cat([points, padding_point], dim=1) - labels = torch.cat([labels, padding_label], dim=1) - point_embedding = self.pe_layer.forward_with_coords(points, self.input_image_size) - point_embedding[labels == -1] = 0.0 - point_embedding[labels == -1] += self.not_a_point_embed.weight - point_embedding[labels == 0] += self.point_embeddings[0].weight - point_embedding[labels == 1] += self.point_embeddings[1].weight - return point_embedding - - def _embed_boxes(self, boxes: torch.Tensor) -> torch.Tensor: - """Embeds box prompts.""" - boxes = boxes + 0.5 # Shift to center of pixel - coords = boxes.reshape(-1, 2, 2) - corner_embedding = self.pe_layer.forward_with_coords(coords, self.input_image_size) - corner_embedding[:, 0, :] += self.point_embeddings[2].weight - corner_embedding[:, 1, :] += self.point_embeddings[3].weight - return corner_embedding - - def _embed_masks(self, masks: torch.Tensor) -> torch.Tensor: - """Embeds mask inputs.""" - mask_embedding = self.mask_downscaling(masks) - return mask_embedding - - def _get_batch_size( - self, - points: Optional[Tuple[torch.Tensor, torch.Tensor]], - boxes: Optional[torch.Tensor], - masks: Optional[torch.Tensor], - ) -> int: - """ - Gets the batch size of the output given the batch size of the input prompts. - """ - if points is not None: - return points[0].shape[0] - elif boxes is not None: - return boxes.shape[0] - elif masks is not None: - return masks.shape[0] - else: - return 1 - - def _get_device(self) -> torch.device: - return self.point_embeddings[0].weight.device - - def forward( - self, - points: Optional[Tuple[torch.Tensor, torch.Tensor]], - boxes: Optional[torch.Tensor], - masks: Optional[torch.Tensor], - ) -> Tuple[torch.Tensor, torch.Tensor]: - """ - Embeds different types of prompts, returning both sparse and dense - embeddings. - - Arguments: - points (tuple(torch.Tensor, torch.Tensor) or none): point coordinates - and labels to embed. - boxes (torch.Tensor or none): boxes to embed - masks (torch.Tensor or none): masks to embed - - Returns: - torch.Tensor: sparse embeddings for the points and boxes, with shape - BxNx(embed_dim), where N is determined by the number of input points - and boxes. 
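# Illustrative sketch: PositionEmbeddingRandom further down encodes a position by
# mapping normalized (x, y) in [0, 1]^2 to [-1, 1], projecting through a fixed random
# Gaussian matrix, and taking sin/cos of the result (random Fourier features). The
# grid size and feature count below are example values.
import math
import torch

num_pos_feats, h, w = 64, 4, 6
gaussian = torch.randn(2, num_pos_feats)          # frozen random projection

grid = torch.ones(h, w)
y_embed = (grid.cumsum(dim=0) - 0.5) / h          # pixel-centre y in (0, 1)
x_embed = (grid.cumsum(dim=1) - 0.5) / w          # pixel-centre x in (0, 1)
coords = torch.stack([x_embed, y_embed], dim=-1)  # h, w, 2

coords = (2 * coords - 1) @ gaussian              # h, w, num_pos_feats
coords = 2 * math.pi * coords
pe = torch.cat([torch.sin(coords), torch.cos(coords)], dim=-1).permute(2, 0, 1)
assert pe.shape == (2 * num_pos_feats, h, w)      # C x H x W, C = embed_dim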
- torch.Tensor: dense embeddings for the masks, in the shape - Bx(embed_dim)x(embed_H)x(embed_W) - """ - bs = self._get_batch_size(points, boxes, masks) - sparse_embeddings = torch.empty((bs, 0, self.embed_dim), device=self._get_device()) - if points is not None: - coords, labels = points - point_embeddings = self._embed_points(coords, labels, pad=(boxes is None)) - sparse_embeddings = torch.cat([sparse_embeddings, point_embeddings], dim=1) - if boxes is not None: - box_embeddings = self._embed_boxes(boxes) - sparse_embeddings = torch.cat([sparse_embeddings, box_embeddings], dim=1) - - if masks is not None: - dense_embeddings = self._embed_masks(masks) - else: - dense_embeddings = self.no_mask_embed.weight.reshape(1, -1, 1, 1).expand( - bs, -1, self.image_embedding_size[0], self.image_embedding_size[1] - ) - - return sparse_embeddings, dense_embeddings - - -class PositionEmbeddingRandom(nn.Module): - """ - Positional encoding using random spatial frequencies. - """ - - def __init__(self, num_pos_feats: int = 64, scale: Optional[float] = None) -> None: - super().__init__() - if scale is None or scale <= 0.0: - scale = 1.0 - self.register_buffer( - "positional_encoding_gaussian_matrix", - scale * torch.randn((2, num_pos_feats)), - ) - - def _pe_encoding(self, coords: torch.Tensor) -> torch.Tensor: - """Positionally encode points that are normalized to [0,1].""" - # assuming coords are in [0, 1]^2 square and have d_1 x ... x d_n x 2 shape - coords = 2 * coords - 1 - coords = coords @ self.positional_encoding_gaussian_matrix - coords = 2 * np.pi * coords - # outputs d_1 x ... x d_n x C shape - return torch.cat([torch.sin(coords), torch.cos(coords)], dim=-1) - - def forward(self, size: Tuple[int, int]) -> torch.Tensor: - """Generate positional encoding for a grid of the specified size.""" - h, w = size - device: Any = self.positional_encoding_gaussian_matrix.device - grid = torch.ones((h, w), device=device, dtype=torch.float32) - y_embed = grid.cumsum(dim=0) - 0.5 - x_embed = grid.cumsum(dim=1) - 0.5 - y_embed = y_embed / h - x_embed = x_embed / w - - pe = self._pe_encoding(torch.stack([x_embed, y_embed], dim=-1)) - return pe.permute(2, 0, 1) # C x H x W - - def forward_with_coords( - self, coords_input: torch.Tensor, image_size: Tuple[int, int] - ) -> torch.Tensor: - """Positionally encode points that are not normalized to [0,1].""" - coords = coords_input.clone() - coords[:, :, 0] = coords[:, :, 0] / image_size[1] - coords[:, :, 1] = coords[:, :, 1] / image_size[0] - return self._pe_encoding(coords.to(torch.float)) # B x N x C - - -class LayerNorm2d(nn.Module): - def __init__(self, num_channels: int, eps: float = 1e-6) -> None: - super().__init__() - self.weight = nn.Parameter(torch.ones(num_channels)) - self.bias = nn.Parameter(torch.zeros(num_channels)) - self.eps = eps - - def forward(self, x: torch.Tensor) -> torch.Tensor: - u = x.mean(1, keepdim=True) - s = (x - u).pow(2).mean(1, keepdim=True) - x = (x - u) / torch.sqrt(s + self.eps) - x = self.weight[:, None, None] * x + self.bias[:, None, None] - return x - diff --git a/py/BiRefNet/models/modules/utils.py b/py/BiRefNet/models/modules/utils.py deleted file mode 100644 index 59bd9121..00000000 --- a/py/BiRefNet/models/modules/utils.py +++ /dev/null @@ -1,54 +0,0 @@ -import torch.nn as nn - - -def build_act_layer(act_layer): - if act_layer == 'ReLU': - return nn.ReLU(inplace=True) - elif act_layer == 'SiLU': - return nn.SiLU(inplace=True) - elif act_layer == 'GELU': - return nn.GELU() - - raise NotImplementedError(f'build_act_layer does 
not support {act_layer}') - - -def build_norm_layer(dim, - norm_layer, - in_format='channels_last', - out_format='channels_last', - eps=1e-6): - layers = [] - if norm_layer == 'BN': - if in_format == 'channels_last': - layers.append(to_channels_first()) - layers.append(nn.BatchNorm2d(dim)) - if out_format == 'channels_last': - layers.append(to_channels_last()) - elif norm_layer == 'LN': - if in_format == 'channels_first': - layers.append(to_channels_last()) - layers.append(nn.LayerNorm(dim, eps=eps)) - if out_format == 'channels_first': - layers.append(to_channels_first()) - else: - raise NotImplementedError( - f'build_norm_layer does not support {norm_layer}') - return nn.Sequential(*layers) - - -class to_channels_first(nn.Module): - - def __init__(self): - super().__init__() - - def forward(self, x): - return x.permute(0, 3, 1, 2) - - -class to_channels_last(nn.Module): - - def __init__(self): - super().__init__() - - def forward(self, x): - return x.permute(0, 2, 3, 1) diff --git a/py/BiRefNet/models/refinement/refiner.py b/py/BiRefNet/models/refinement/refiner.py deleted file mode 100644 index 19f696a2..00000000 --- a/py/BiRefNet/models/refinement/refiner.py +++ /dev/null @@ -1,252 +0,0 @@ -import torch -import torch.nn as nn -from collections import OrderedDict -import torch -import torch.nn as nn -import torch.nn.functional as F -from torchvision.models import vgg16, vgg16_bn -from torchvision.models import resnet50 - -from BiRefNet.config import Config -from BiRefNet.dataset import class_labels_TR_sorted -from BiRefNet.models.backbones.build_backbone import build_backbone -from BiRefNet.models.modules.decoder_blocks import BasicDecBlk -from BiRefNet.models.modules.lateral_blocks import BasicLatBlk -from BiRefNet.models.refinement.stem_layer import StemLayer - - -class RefinerPVTInChannels4(nn.Module): - def __init__(self, in_channels=3+1): - super(RefinerPVTInChannels4, self).__init__() - self.config = Config() - self.epoch = 1 - self.bb = build_backbone(self.config.bb, params_settings='in_channels=4') - - lateral_channels_in_collection = { - 'vgg16': [512, 256, 128, 64], 'vgg16bn': [512, 256, 128, 64], 'resnet50': [1024, 512, 256, 64], - 'pvt_v2_b2': [512, 320, 128, 64], 'pvt_v2_b5': [512, 320, 128, 64], - 'swin_v1_b': [1024, 512, 256, 128], 'swin_v1_l': [1536, 768, 384, 192], - } - channels = lateral_channels_in_collection[self.config.bb] - self.squeeze_module = BasicDecBlk(channels[0], channels[0]) - - self.decoder = Decoder(channels) - - if 0: - for key, value in self.named_parameters(): - if 'bb.' 
in key: - value.requires_grad = False - - def forward(self, x): - if isinstance(x, list): - x = torch.cat(x, dim=1) - ########## Encoder ########## - if self.config.bb in ['vgg16', 'vgg16bn', 'resnet50']: - x1 = self.bb.conv1(x) - x2 = self.bb.conv2(x1) - x3 = self.bb.conv3(x2) - x4 = self.bb.conv4(x3) - else: - x1, x2, x3, x4 = self.bb(x) - - x4 = self.squeeze_module(x4) - - ########## Decoder ########## - - features = [x, x1, x2, x3, x4] - scaled_preds = self.decoder(features) - - return scaled_preds - - -class Refiner(nn.Module): - def __init__(self, in_channels=3+1): - super(Refiner, self).__init__() - self.config = Config() - self.epoch = 1 - self.stem_layer = StemLayer(in_channels=in_channels, inter_channels=48, out_channels=3, norm_layer='BN' if self.config.batch_size > 1 else 'LN') - self.bb = build_backbone(self.config.bb) - - lateral_channels_in_collection = { - 'vgg16': [512, 256, 128, 64], 'vgg16bn': [512, 256, 128, 64], 'resnet50': [1024, 512, 256, 64], - 'pvt_v2_b2': [512, 320, 128, 64], 'pvt_v2_b5': [512, 320, 128, 64], - 'swin_v1_b': [1024, 512, 256, 128], 'swin_v1_l': [1536, 768, 384, 192], - } - channels = lateral_channels_in_collection[self.config.bb] - self.squeeze_module = BasicDecBlk(channels[0], channels[0]) - - self.decoder = Decoder(channels) - - if 0: - for key, value in self.named_parameters(): - if 'bb.' in key: - value.requires_grad = False - - def forward(self, x): - if isinstance(x, list): - x = torch.cat(x, dim=1) - x = self.stem_layer(x) - ########## Encoder ########## - if self.config.bb in ['vgg16', 'vgg16bn', 'resnet50']: - x1 = self.bb.conv1(x) - x2 = self.bb.conv2(x1) - x3 = self.bb.conv3(x2) - x4 = self.bb.conv4(x3) - else: - x1, x2, x3, x4 = self.bb(x) - - x4 = self.squeeze_module(x4) - - ########## Decoder ########## - - features = [x, x1, x2, x3, x4] - scaled_preds = self.decoder(features) - - return scaled_preds - - -class Decoder(nn.Module): - def __init__(self, channels): - super(Decoder, self).__init__() - self.config = Config() - DecoderBlock = eval('BasicDecBlk') - LateralBlock = eval('BasicLatBlk') - - self.decoder_block4 = DecoderBlock(channels[0], channels[1]) - self.decoder_block3 = DecoderBlock(channels[1], channels[2]) - self.decoder_block2 = DecoderBlock(channels[2], channels[3]) - self.decoder_block1 = DecoderBlock(channels[3], channels[3]//2) - - self.lateral_block4 = LateralBlock(channels[1], channels[1]) - self.lateral_block3 = LateralBlock(channels[2], channels[2]) - self.lateral_block2 = LateralBlock(channels[3], channels[3]) - - if self.config.ms_supervision: - self.conv_ms_spvn_4 = nn.Conv2d(channels[1], 1, 1, 1, 0) - self.conv_ms_spvn_3 = nn.Conv2d(channels[2], 1, 1, 1, 0) - self.conv_ms_spvn_2 = nn.Conv2d(channels[3], 1, 1, 1, 0) - self.conv_out1 = nn.Sequential(nn.Conv2d(channels[3]//2, 1, 1, 1, 0)) - - def forward(self, features): - x, x1, x2, x3, x4 = features - outs = [] - p4 = self.decoder_block4(x4) - _p4 = F.interpolate(p4, size=x3.shape[2:], mode='bilinear', align_corners=True) - _p3 = _p4 + self.lateral_block4(x3) - - p3 = self.decoder_block3(_p3) - _p3 = F.interpolate(p3, size=x2.shape[2:], mode='bilinear', align_corners=True) - _p2 = _p3 + self.lateral_block3(x2) - - p2 = self.decoder_block2(_p2) - _p2 = F.interpolate(p2, size=x1.shape[2:], mode='bilinear', align_corners=True) - _p1 = _p2 + self.lateral_block2(x1) - - _p1 = self.decoder_block1(_p1) - _p1 = F.interpolate(_p1, size=x.shape[2:], mode='bilinear', align_corners=True) - p1_out = self.conv_out1(_p1) - - if self.config.ms_supervision: - 
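# Illustrative sketch: the refiner Decoder above follows a standard top-down scheme --
# decode the coarsest feature, bilinearly upsample it to the next level's size, and add
# a 1x1 lateral projection of that level before decoding again. Channel widths here are
# example values; the real blocks are BasicDecBlk / BasicLatBlk.
import torch
import torch.nn as nn
import torch.nn.functional as F

dec = nn.Conv2d(512, 256, 3, padding=1)     # stand-in for decoder_block4
lat = nn.Conv2d(256, 256, 1)                # stand-in for lateral_block4

x4 = torch.randn(1, 512, 8, 8)              # coarse feature
x3 = torch.randn(1, 256, 16, 16)            # next-finer feature

p4 = dec(x4)
_p4 = F.interpolate(p4, size=x3.shape[2:], mode='bilinear', align_corners=True)
_p3 = _p4 + lat(x3)                         # fused map fed to the next decoder block
assert _p3.shape == (1, 256, 16, 16)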
outs.append(self.conv_ms_spvn_4(p4)) - outs.append(self.conv_ms_spvn_3(p3)) - outs.append(self.conv_ms_spvn_2(p2)) - outs.append(p1_out) - return outs - - -class RefUNet(nn.Module): - # Refinement - def __init__(self, in_channels=3+1): - super(RefUNet, self).__init__() - self.encoder_1 = nn.Sequential( - nn.Conv2d(in_channels, 64, 3, 1, 1), - nn.Conv2d(64, 64, 3, 1, 1), - nn.BatchNorm2d(64), - nn.ReLU(inplace=True) - ) - - self.encoder_2 = nn.Sequential( - nn.MaxPool2d(2, 2, ceil_mode=True), - nn.Conv2d(64, 64, 3, 1, 1), - nn.BatchNorm2d(64), - nn.ReLU(inplace=True) - ) - - self.encoder_3 = nn.Sequential( - nn.MaxPool2d(2, 2, ceil_mode=True), - nn.Conv2d(64, 64, 3, 1, 1), - nn.BatchNorm2d(64), - nn.ReLU(inplace=True) - ) - - self.encoder_4 = nn.Sequential( - nn.MaxPool2d(2, 2, ceil_mode=True), - nn.Conv2d(64, 64, 3, 1, 1), - nn.BatchNorm2d(64), - nn.ReLU(inplace=True) - ) - - self.pool4 = nn.MaxPool2d(2, 2, ceil_mode=True) - ##### - self.decoder_5 = nn.Sequential( - nn.Conv2d(64, 64, 3, 1, 1), - nn.BatchNorm2d(64), - nn.ReLU(inplace=True) - ) - ##### - self.decoder_4 = nn.Sequential( - nn.Conv2d(128, 64, 3, 1, 1), - nn.BatchNorm2d(64), - nn.ReLU(inplace=True) - ) - - self.decoder_3 = nn.Sequential( - nn.Conv2d(128, 64, 3, 1, 1), - nn.BatchNorm2d(64), - nn.ReLU(inplace=True) - ) - - self.decoder_2 = nn.Sequential( - nn.Conv2d(128, 64, 3, 1, 1), - nn.BatchNorm2d(64), - nn.ReLU(inplace=True) - ) - - self.decoder_1 = nn.Sequential( - nn.Conv2d(128, 64, 3, 1, 1), - nn.BatchNorm2d(64), - nn.ReLU(inplace=True) - ) - - self.conv_d0 = nn.Conv2d(64, 1, 3, 1, 1) - - self.upscore2 = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True) - - def forward(self, x): - outs = [] - if isinstance(x, list): - x = torch.cat(x, dim=1) - hx = x - - hx1 = self.encoder_1(hx) - hx2 = self.encoder_2(hx1) - hx3 = self.encoder_3(hx2) - hx4 = self.encoder_4(hx3) - - hx = self.decoder_5(self.pool4(hx4)) - hx = torch.cat((self.upscore2(hx), hx4), 1) - - d4 = self.decoder_4(hx) - hx = torch.cat((self.upscore2(d4), hx3), 1) - - d3 = self.decoder_3(hx) - hx = torch.cat((self.upscore2(d3), hx2), 1) - - d2 = self.decoder_2(hx) - hx = torch.cat((self.upscore2(d2), hx1), 1) - - d1 = self.decoder_1(hx) - - x = self.conv_d0(d1) - outs.append(x) - return outs diff --git a/py/BiRefNet/models/refinement/stem_layer.py b/py/BiRefNet/models/refinement/stem_layer.py deleted file mode 100644 index 50d0ac21..00000000 --- a/py/BiRefNet/models/refinement/stem_layer.py +++ /dev/null @@ -1,45 +0,0 @@ -import torch.nn as nn -from BiRefNet.models.modules.utils import build_act_layer, build_norm_layer - - -class StemLayer(nn.Module): - r""" Stem layer of InternImage - Args: - in_channels (int): number of input channels - out_channels (int): number of output channels - act_layer (str): activation layer - norm_layer (str): normalization layer - """ - - def __init__(self, - in_channels=3+1, - inter_channels=48, - out_channels=96, - act_layer='GELU', - norm_layer='BN'): - super().__init__() - self.conv1 = nn.Conv2d(in_channels, - inter_channels, - kernel_size=3, - stride=1, - padding=1) - self.norm1 = build_norm_layer( - inter_channels, norm_layer, 'channels_first', 'channels_first' - ) - self.act = build_act_layer(act_layer) - self.conv2 = nn.Conv2d(inter_channels, - out_channels, - kernel_size=3, - stride=1, - padding=1) - self.norm2 = build_norm_layer( - out_channels, norm_layer, 'channels_first', 'channels_first' - ) - - def forward(self, x): - x = self.conv1(x) - x = self.norm1(x) - x = self.act(x) - x = self.conv2(x) - x = 
self.norm2(x) - return x diff --git a/py/BiRefNet/requirements.txt b/py/BiRefNet/requirements.txt deleted file mode 100644 index 546ffa35..00000000 --- a/py/BiRefNet/requirements.txt +++ /dev/null @@ -1,15 +0,0 @@ ---extra-index-url https://download.pytorch.org/whl/cu118 -torch==2.0.1 ---extra-index-url https://download.pytorch.org/whl/cu118 -torchvision==0.15.2 -numpy<2 -opencv-python -timm -scipy -scikit-image -kornia - -tqdm -prettytable - -huggingface_hub diff --git a/py/BiRefNet/rm_cache.sh b/py/BiRefNet/rm_cache.sh deleted file mode 100644 index 5e75b92c..00000000 --- a/py/BiRefNet/rm_cache.sh +++ /dev/null @@ -1,20 +0,0 @@ -#!/bin/bash -rm -rf __pycache__ */__pycache__ - -# Val -rm -r tmp* - -# Train -rm slurm* -rm -r ckpt -rm nohup.out* - -# Eval -rm -r evaluation/eval-* -rm -r tmp* -rm -r e_logs/ - -# System -rm core-*-python-* - -clear diff --git a/py/BiRefNet/sub.sh b/py/BiRefNet/sub.sh deleted file mode 100644 index 9e216b98..00000000 --- a/py/BiRefNet/sub.sh +++ /dev/null @@ -1,17 +0,0 @@ -#!/bin/sh -# Example: ./sub.sh tmp_proj 0,1,2,3 3 --> Use 0,1,2,3 for training, release GPUs, use GPU:3 for inference. - -# module load gcc/11.2.0 cuda/11.8 cudnn/8.6.0_cu11x && cpu_core_num=6 -module load compilers/cuda/11.8 compilers/gcc/12.2.0 cudnn/8.4.0.27_cuda11.x && cpu_core_num=32 - -export PYTHONUNBUFFERED=1 - -method=${1:-"BSL"} -devices=${2:-0} -gpu_num=$(($(echo ${devices%%,} | grep -o "," | wc -l)+1)) - -sbatch --nodes=1 -p vip_gpu_ailab -A ai4bio \ - --gres=gpu:${gpu_num} --ntasks-per-node=1 --cpus-per-task=$((gpu_num*cpu_core_num)) \ - ./train_test.sh ${method} ${devices} - -hostname diff --git a/py/BiRefNet/test.sh b/py/BiRefNet/test.sh deleted file mode 100644 index 66a61490..00000000 --- a/py/BiRefNet/test.sh +++ /dev/null @@ -1,29 +0,0 @@ -devices=${1:-0} -pred_root=${2:-e_preds} - -# Inference - -CUDA_VISIBLE_DEVICES=${devices} python inference.py --pred_root ${pred_root} - -echo Inference finished at $(date) - -# Evaluation -log_dir=e_logs && mkdir ${log_dir} - -task=$(python3 config.py) -case "${task}" in - "DIS5K") testsets='DIS-VD,DIS-TE1,DIS-TE2,DIS-TE3,DIS-TE4' ;; - "COD") testsets='CHAMELEON,NC4K,TE-CAMO,TE-COD10K' ;; - "HRSOD") testsets='DAVIS-S,TE-HRSOD,TE-UHRSD,DUT-OMRON,TE-DUTS' ;; - "General") testsets='DIS-VD' ;; - "Matting") testsets='TE-P3M-500-P' ;; -esac -testsets=(`echo ${testsets} | tr ',' ' '`) && testsets=${testsets[@]} - -for testset in ${testsets}; do - python eval_existingOnes.py --pred_root ${pred_root} --data_lst ${testset} > ${log_dir}/eval_${testset}.out - # nohup python eval_existingOnes.py --pred_root ${pred_root} --data_lst ${testset} > ${log_dir}/eval_${testset}.out 2>&1 & -done - - -echo Evaluation started at $(date) diff --git a/py/BiRefNet/train.py b/py/BiRefNet/train.py deleted file mode 100644 index 20bd9094..00000000 --- a/py/BiRefNet/train.py +++ /dev/null @@ -1,333 +0,0 @@ -import os -import datetime -import argparse -import torch -import torch.nn as nn -import torch.optim as optim -from torch.autograd import Variable - -from BiRefNet.config import Config -from BiRefNet.loss import PixLoss, ClsLoss -from BiRefNet.dataset import MyData -from BiRefNet.models.birefnet import BiRefNet -from BiRefNet.utils import Logger, AverageMeter, set_seed, check_state_dict - -from torch.utils.data.distributed import DistributedSampler -from torch.nn.parallel import DistributedDataParallel as DDP -from torch.distributed import init_process_group, destroy_process_group, get_rank -from torch.cuda import amp - - -parser = 
argparse.ArgumentParser(description='') -parser.add_argument('--resume', default=None, type=str, help='path to latest checkpoint') -parser.add_argument('--epochs', default=120, type=int) -parser.add_argument('--trainset', default='DIS5K', type=str, help="Options: 'DIS5K'") -parser.add_argument('--ckpt_dir', default=None, help='Temporary folder') -parser.add_argument('--testsets', default='DIS-VD+DIS-TE1+DIS-TE2+DIS-TE3+DIS-TE4', type=str) -parser.add_argument('--dist', default=False, type=lambda x: x == 'True') -args = parser.parse_args() - - -config = Config() -if config.rand_seed: - set_seed(config.rand_seed) - -if config.use_fp16: - # Half Precision - scaler = amp.GradScaler(enabled=config.use_fp16) - -# DDP -to_be_distributed = args.dist -if to_be_distributed: - init_process_group(backend="nccl", timeout=datetime.timedelta(seconds=3600*10)) - device = int(os.environ["LOCAL_RANK"]) -else: - device = config.device - -epoch_st = 1 -# make dir for ckpt -os.makedirs(args.ckpt_dir, exist_ok=True) - -# Init log file -logger = Logger(os.path.join(args.ckpt_dir, "log.txt")) -logger_loss_idx = 1 - -# log model and optimizer params -# logger.info("Model details:"); logger.info(model) -logger.info("datasets: load_all={}, compile={}.".format(config.load_all, config.compile)) -logger.info("Other hyperparameters:"); logger.info(args) -print('batch size:', config.batch_size) - - -if os.path.exists(os.path.join(config.data_root_dir, config.task, args.testsets.strip('+').split('+')[0])): - args.testsets = args.testsets.strip('+').split('+') -else: - args.testsets = [] - -# Init model -def prepare_dataloader(dataset: torch.utils.data.Dataset, batch_size: int, to_be_distributed=False, is_train=True): - if to_be_distributed: - return torch.utils.data.DataLoader( - dataset=dataset, batch_size=batch_size, num_workers=min(config.num_workers, batch_size), pin_memory=True, - shuffle=False, sampler=DistributedSampler(dataset), drop_last=True - ) - else: - return torch.utils.data.DataLoader( - dataset=dataset, batch_size=batch_size, num_workers=min(config.num_workers, batch_size, 0), pin_memory=True, - shuffle=is_train, drop_last=True - ) - - -def init_data_loaders(to_be_distributed): - # Prepare dataset - train_loader = prepare_dataloader( - MyData(datasets=config.training_set, image_size=config.size, is_train=True), - config.batch_size, to_be_distributed=to_be_distributed, is_train=True - ) - print(len(train_loader), "batches of train dataloader {} have been created.".format(config.training_set)) - test_loaders = {} - for testset in args.testsets: - _data_loader_test = prepare_dataloader( - MyData(datasets=testset, image_size=config.size, is_train=False), - config.batch_size_valid, is_train=False - ) - print(len(_data_loader_test), "batches of valid dataloader {} have been created.".format(testset)) - test_loaders[testset] = _data_loader_test - return train_loader, test_loaders - - -def init_models_optimizers(epochs, to_be_distributed): - model = BiRefNet(bb_pretrained=True) - if args.resume: - if os.path.isfile(args.resume): - logger.info("=> loading checkpoint '{}'".format(args.resume)) - state_dict = torch.load(args.resume, map_location='cpu') - state_dict = check_state_dict(state_dict) - model.load_state_dict(state_dict) - global epoch_st - epoch_st = int(args.resume.rstrip('.pth').split('epoch_')[-1]) + 1 - else: - logger.info("=> no checkpoint found at '{}'".format(args.resume)) - if to_be_distributed: - model = model.to(device) - model = DDP(model, device_ids=[device]) - else: - model = model.to(device) 
- if config.compile: - model = torch.compile(model, mode=['default', 'reduce-overhead', 'max-autotune'][0]) - if config.precisionHigh: - torch.set_float32_matmul_precision('high') - - - # Setting optimizer - if config.optimizer == 'AdamW': - optimizer = optim.AdamW(params=model.parameters(), lr=config.lr, weight_decay=1e-2) - elif config.optimizer == 'Adam': - optimizer = optim.Adam(params=model.parameters(), lr=config.lr, weight_decay=0) - lr_scheduler = torch.optim.lr_scheduler.MultiStepLR( - optimizer, - milestones=[lde if lde > 0 else epochs + lde + 1 for lde in config.lr_decay_epochs], - gamma=config.lr_decay_rate - ) - logger.info("Optimizer details:"); logger.info(optimizer) - logger.info("Scheduler details:"); logger.info(lr_scheduler) - - return model, optimizer, lr_scheduler - - -class Trainer: - def __init__( - self, data_loaders, model_opt_lrsch, - ): - self.model, self.optimizer, self.lr_scheduler = model_opt_lrsch - self.train_loader, self.test_loaders = data_loaders - if config.out_ref: - self.criterion_gdt = nn.BCELoss() if not config.use_fp16 else nn.BCEWithLogitsLoss() - - # Setting Losses - self.pix_loss = PixLoss() - self.cls_loss = ClsLoss() - - # Others - self.loss_log = AverageMeter() - if config.lambda_adv_g: - self.optimizer_d, self.lr_scheduler_d, self.disc, self.adv_criterion = self._load_adv_components() - self.disc_update_for_odd = 0 - - def _load_adv_components(self): - # AIL - from loss import Discriminator - disc = Discriminator(channels=3, img_size=config.size) - if to_be_distributed: - disc = disc.to(device) - disc = DDP(disc, device_ids=[device], broadcast_buffers=False) - else: - disc = disc.to(device) - if config.compile: - disc = torch.compile(disc, mode=['default', 'reduce-overhead', 'max-autotune'][0]) - adv_criterion = nn.BCELoss() if not config.use_fp16 else nn.BCEWithLogitsLoss() - if config.optimizer == 'AdamW': - optimizer_d = optim.AdamW(params=disc.parameters(), lr=config.lr, weight_decay=1e-2) - elif config.optimizer == 'Adam': - optimizer_d = optim.Adam(params=disc.parameters(), lr=config.lr, weight_decay=0) - lr_scheduler_d = torch.optim.lr_scheduler.MultiStepLR( - optimizer_d, - milestones=[lde if lde > 0 else args.epochs + lde + 1 for lde in config.lr_decay_epochs], - gamma=config.lr_decay_rate - ) - return optimizer_d, lr_scheduler_d, disc, adv_criterion - - def _train_batch(self, batch): - inputs = batch[0].to(device) - gts = batch[1].to(device) - class_labels = batch[2].to(device) - if config.use_fp16: - with amp.autocast(enabled=config.use_fp16): - scaled_preds, class_preds_lst = self.model(inputs) - if config.out_ref: - (outs_gdt_pred, outs_gdt_label), scaled_preds = scaled_preds - for _idx, (_gdt_pred, _gdt_label) in enumerate(zip(outs_gdt_pred, outs_gdt_label)): - _gdt_pred = nn.functional.interpolate(_gdt_pred, size=_gdt_label.shape[2:], mode='bilinear', align_corners=True)#.sigmoid() - # _gdt_label = _gdt_label.sigmoid() - loss_gdt = self.criterion_gdt(_gdt_pred, _gdt_label) if _idx == 0 else self.criterion_gdt(_gdt_pred, _gdt_label) + loss_gdt - # self.loss_dict['loss_gdt'] = loss_gdt.item() - if None in class_preds_lst: - loss_cls = 0. 
- else: - loss_cls = self.cls_loss(class_preds_lst, class_labels) * 1.0 - self.loss_dict['loss_cls'] = loss_cls.item() - - # Loss - loss_pix = self.pix_loss(scaled_preds, torch.clamp(gts, 0, 1)) * 1.0 - self.loss_dict['loss_pix'] = loss_pix.item() - # since there may be several losses for sal, the lambdas for them (lambdas_pix) are inside the loss.py - loss = loss_pix + loss_cls - if config.out_ref: - loss = loss + loss_gdt * 1.0 - - if config.lambda_adv_g: - # gen - valid = Variable(torch.cuda.FloatTensor(scaled_preds[-1].shape[0], 1).fill_(1.0), requires_grad=False).to(device) - adv_loss_g = self.adv_criterion(self.disc(scaled_preds[-1] * inputs), valid) * config.lambda_adv_g - loss += adv_loss_g - self.loss_dict['loss_adv'] = adv_loss_g.item() - self.disc_update_for_odd += 1 - # self.loss_log.update(loss.item(), inputs.size(0)) - # self.optimizer.zero_grad() - # loss.backward() - # self.optimizer.step() - self.optimizer.zero_grad() - scaler.scale(loss).backward() - scaler.step(self.optimizer) - scaler.update() - - if config.lambda_adv_g and self.disc_update_for_odd % 2 == 0: - # disc - fake = Variable(torch.cuda.FloatTensor(scaled_preds[-1].shape[0], 1).fill_(0.0), requires_grad=False).to(device) - adv_loss_real = self.adv_criterion(self.disc(gts * inputs), valid) - adv_loss_fake = self.adv_criterion(self.disc(scaled_preds[-1].detach() * inputs.detach()), fake) - adv_loss_d = (adv_loss_real + adv_loss_fake) / 2 * config.lambda_adv_d - self.loss_dict['loss_adv_d'] = adv_loss_d.item() - # self.optimizer_d.zero_grad() - # adv_loss_d.backward() - # self.optimizer_d.step() - self.optimizer_d.zero_grad() - scaler.scale(adv_loss_d).backward() - scaler.step(self.optimizer_d) - scaler.update() - else: - scaled_preds, class_preds_lst = self.model(inputs) - if config.out_ref: - (outs_gdt_pred, outs_gdt_label), scaled_preds = scaled_preds - for _idx, (_gdt_pred, _gdt_label) in enumerate(zip(outs_gdt_pred, outs_gdt_label)): - _gdt_pred = nn.functional.interpolate(_gdt_pred, size=_gdt_label.shape[2:], mode='bilinear', align_corners=True).sigmoid() - _gdt_label = _gdt_label.sigmoid() - loss_gdt = self.criterion_gdt(_gdt_pred, _gdt_label) if _idx == 0 else self.criterion_gdt(_gdt_pred, _gdt_label) + loss_gdt - # self.loss_dict['loss_gdt'] = loss_gdt.item() - if None in class_preds_lst: - loss_cls = 0. 
- else: - loss_cls = self.cls_loss(class_preds_lst, class_labels) * 1.0 - self.loss_dict['loss_cls'] = loss_cls.item() - - # Loss - loss_pix = self.pix_loss(scaled_preds, torch.clamp(gts, 0, 1)) * 1.0 - self.loss_dict['loss_pix'] = loss_pix.item() - # since there may be several losses for sal, the lambdas for them (lambdas_pix) are inside the loss.py - loss = loss_pix + loss_cls - if config.out_ref: - loss = loss + loss_gdt * 1.0 - - if config.lambda_adv_g: - # gen - valid = Variable(torch.cuda.FloatTensor(scaled_preds[-1].shape[0], 1).fill_(1.0), requires_grad=False).to(device) - adv_loss_g = self.adv_criterion(self.disc(scaled_preds[-1] * inputs), valid) * config.lambda_adv_g - loss += adv_loss_g - self.loss_dict['loss_adv'] = adv_loss_g.item() - self.disc_update_for_odd += 1 - self.loss_log.update(loss.item(), inputs.size(0)) - self.optimizer.zero_grad() - loss.backward() - self.optimizer.step() - - if config.lambda_adv_g and self.disc_update_for_odd % 2 == 0: - # disc - fake = Variable(torch.cuda.FloatTensor(scaled_preds[-1].shape[0], 1).fill_(0.0), requires_grad=False).to(device) - adv_loss_real = self.adv_criterion(self.disc(gts * inputs), valid) - adv_loss_fake = self.adv_criterion(self.disc(scaled_preds[-1].detach() * inputs.detach()), fake) - adv_loss_d = (adv_loss_real + adv_loss_fake) / 2 * config.lambda_adv_d - self.loss_dict['loss_adv_d'] = adv_loss_d.item() - self.optimizer_d.zero_grad() - adv_loss_d.backward() - self.optimizer_d.step() - - def train_epoch(self, epoch): - global logger_loss_idx - self.model.train() - self.loss_dict = {} - if epoch > args.epochs + config.finetune_last_epochs[1]: - for k in self.pix_loss.lambdas_pix_last.keys(): - if k.lower() == config.finetune_last_epochs[0].lower(): - self.pix_loss.lambdas_pix_last[k] = config.lambdas_pix_last[k] * 0.5 - else: - self.pix_loss.lambdas_pix_last[k] = 0 - - for batch_idx, batch in enumerate(self.train_loader): - self._train_batch(batch) - # Logger - if batch_idx % 20 == 0: - info_progress = 'Epoch[{0}/{1}] Iter[{2}/{3}].'.format(epoch, args.epochs, batch_idx, len(self.train_loader)) - info_loss = 'Training Losses' - for loss_name, loss_value in self.loss_dict.items(): - info_loss += ', {}: {:.3f}'.format(loss_name, loss_value) - logger.info(' '.join((info_progress, info_loss))) - info_loss = '@==Final== Epoch[{0}/{1}] Training Loss: {loss.avg:.3f} '.format(epoch, args.epochs, loss=self.loss_log) - logger.info(info_loss) - - self.lr_scheduler.step() - if config.lambda_adv_g: - self.lr_scheduler_d.step() - return self.loss_log.avg - - -def main(): - - trainer = Trainer( - data_loaders=init_data_loaders(to_be_distributed), - model_opt_lrsch=init_models_optimizers(args.epochs, to_be_distributed) - ) - - for epoch in range(epoch_st, args.epochs+1): - train_loss = trainer.train_epoch(epoch) - # Save checkpoint - # DDP - if epoch >= args.epochs - config.save_last and epoch % config.save_step == 0: - torch.save( - trainer.model.module.state_dict() if to_be_distributed else trainer.model.state_dict(), - os.path.join(args.ckpt_dir, 'epoch_{}.pth'.format(epoch)) - ) - if to_be_distributed: - destroy_process_group() - -if __name__ == '__main__': - main() diff --git a/py/BiRefNet/train.sh b/py/BiRefNet/train.sh deleted file mode 100644 index 78421d8b..00000000 --- a/py/BiRefNet/train.sh +++ /dev/null @@ -1,42 +0,0 @@ -#!/bin/bash -# Run script -# Settings of training & test for different tasks. 
-method="$1" -task=$(python3 config.py) -case "${task}" in - "DIS5K") epochs=600 && val_last=50 && step=5 ;; - "COD") epochs=150 && val_last=50 && step=5 ;; - "HRSOD") epochs=150 && val_last=50 && step=5 ;; - "General") epochs=250 && val_last=20 && step=2 ;; - "Matting") epochs=100 && val_last=20 && step=2 ;; -esac -testsets=NO # Non-existing folder to skip. -# testsets=TE-COD10K # for COD - -# Train -devices=$2 -nproc_per_node=$(echo ${devices%%,} | grep -o "," | wc -l) - -to_be_distributed=`echo ${nproc_per_node} | awk '{if($e > 0) print "True"; else print "False";}'` - -echo Training started at $(date) -if [ ${to_be_distributed} == "True" ] -then - # Adapt the nproc_per_node by the number of GPUs. Give 8989 as the default value of master_port. - echo "Multi-GPU mode received..." - CUDA_VISIBLE_DEVICES=${devices} \ - torchrun --nproc_per_node $((nproc_per_node+1)) --master_port=${3:-8999} \ - train.py --ckpt_dir ckpt/${method} --epochs ${epochs} \ - --testsets ${testsets} \ - --dist ${to_be_distributed} \ - --resume xx/xx-epoch_244.pth -else - echo "Single-GPU mode received..." - CUDA_VISIBLE_DEVICES=${devices} \ - python train.py --ckpt_dir ckpt/${method} --epochs ${epochs} \ - --testsets ${testsets} \ - --dist ${to_be_distributed} \ - --resume xx/xx-epoch_244.pth -fi - -echo Training finished at $(date) diff --git a/py/BiRefNet/train_test.sh b/py/BiRefNet/train_test.sh deleted file mode 100644 index e9d3a265..00000000 --- a/py/BiRefNet/train_test.sh +++ /dev/null @@ -1,11 +0,0 @@ -#!/bin/sh - -method=${1:-"BSL"} -devices=${2:-"0,1,2,3,4,5,6,7"} - -bash train.sh ${method} ${devices} - -devices_test=${3:-0} -bash test.sh ${devices_test} - -hostname diff --git a/py/BiRefNet/tutorials/BiRefNet_inference.ipynb b/py/BiRefNet/tutorials/BiRefNet_inference.ipynb deleted file mode 100644 index 4173711f..00000000 --- a/py/BiRefNet/tutorials/BiRefNet_inference.ipynb +++ /dev/null @@ -1,1575 +0,0 @@ -{ - "cells": [ - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### Online Colab Demo: https://colab.research.google.com/drive/14Dqg7oeBkFEtchaHLNpig2BcdkZEogba\n", - "### Hugging Face Spaces Demo: https://huggingface.co/spaces/ZhengPeng7/BiRefNet_demo" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/", - "height": 391, - "referenced_widgets": [ - "7d19deaab4c845eea4705567bdc65d60", - "a8941bdff0984189be91fab5bfe1c52c", - "b6be81c6cc1e4608a88c785c448bfaa4", - "3b0ddb32ffa442aab3b02a22432bb233", - "688a14cd34704e4ea2e261f6619449ee", - "8a54f72e65a24d7a93852aca8ecae0a2", - "7745af46aa694f5b8ff0ec8fbd72025f", - "4e5fe4296291455f88ae0e5f257c398e", - "60d3c15d4b944546992d3517bb73912a", - "a124da94c5b143a29d9420eba3954859", - "96485cf384484a709e4c32b74e0223af", - "7bef3ce58df040c986a74eb5334f42c9", - "1848689bf6a14235b508647d1bdb7f69", - "055ec5d0d4da4d1b8df33ddb22d6b55b", - "3c9b1a3ed2f64ed58a29083280098292", - "b2b95dd8b75d4625ad1bd02a461f304d", - "044920317fb64cd088d49f334930b886", - "a8416e76428b46c092c7afbf7129b5f9", - "ebf0b02bc7734ebea9a846233a1a6ec4", - "fb3e80729a214bc5993b62ee04d7c58d", - "cf459ec049624ba29f54cceeb1469785", - "db64a13268ab452db1241d06f16d3d6a", - "80a53775d54e47b2922031cc4cd00548", - "e13013a9d20843bca6da8fe9f0fdb644", - "6f45e75849f749a1b5fad6e8f7879c8f", - "0ef4b41f8d0141859b7767d112596198", - "fecc8e75d8d643759123eaf2ff30fa2a", - "56eb4e5291404f3289196d526fef524f", - "7fb491d5b8f34a2e842518c1e4ea4906", - "e02fc794fb3645bda6b802277a3e5c1a", - "f9a719112f53400782f97d0683862a5f", - 
"4480d4d37c2648e7bb5635109b5d4715", - "67ce0f56b38941749bdb580a830598c2", - "6667788f1fa44603a4a04ba8aa5e1cc9", - "4d162f2efc364cf793709f0b9c9dc888", - "e078f8d23e8449e7a9b5771e342febeb", - "a330873af8cf45eca95780c50260dbb9", - "61746c0aaf96444391d1a86ea7223cb0", - "2ad78a6887eb4a369bac75e2436758b9", - "78d3ab060eb64941bcba51fa3b32f493", - "1f1a0c0dc2b74d56b68a779ef944416d", - "1546b1d7383c4a25b2ab3782ff204cad", - "6588cde28032444980c4e314d6b9a648", - "3e8d434d8e524c8c9f3fe3359df4b327" - ] - }, - "id": "7lFgKfPS8Icy", - "outputId": "2f00b063-86bf-4ba8-fa5e-38d2f5a66462" - }, - "outputs": [], - "source": [ - "# Imports\n", - "from PIL import Image\n", - "import torch\n", - "from torchvision import transforms\n", - "from IPython.display import display\n", - "\n", - "import sys\n", - "sys.path.insert(0, \"../\")\n", - "from models.birefnet import BiRefNet\n", - "\n", - "\n", - "# Load Model\n", - "# Option 2 and Option 3 is better for local running -- we can modify codes locally.\n", - "\n", - "# # # Option 1: loading BiRefNet with weights:\n", - "# from transformers import AutoModelForImageSegmentation\n", - "# birefnet = AutoModelForImageSegmentation.from_pretrained('zhengpeng7/BiRefNet', trust_remote_code=True)\n", - "\n", - "# Option-2: loading weights with BiReNet codes:\n", - "birefnet = BiRefNet.from_pretrained(\n", - " [\n", - " 'zhengpeng7/BiRefNet',\n", - " 'zhengpeng7/BiRefNet-portrait',\n", - " 'zhengpeng7/BiRefNet-legacy', 'zhengpeng7/BiRefNet-DIS5K-TR_TEs', 'zhengpeng7/BiRefNet-DIS5K', 'zhengpeng7/BiRefNet-HRSOD', 'zhengpeng7/BiRefNet-COD',\n", - " 'zhengpeng7/BiRefNet_lite', # Modify the `bb` in `config.py` to `swin_v1_tiny`.\n", - " ][0]\n", - ")\n", - "\n", - "# # Option-3: Loading model and weights from local disk:\n", - "# from utils import check_state_dict\n", - "\n", - "# birefnet = BiRefNet(bb_pretrained=False)\n", - "# state_dict = torch.load('../BiRefNet-general-epoch_244.pth', map_location='cpu')\n", - "# state_dict = check_state_dict(state_dict)\n", - "# birefnet.load_state_dict(state_dict)\n", - "\n", - "device = 'cuda' if torch.cuda.is_available() else 'cpu'\n", - "\n", - "torch.set_float32_matmul_precision(['high', 'highest'][0])\n", - "\n", - "birefnet.to(device)\n", - "birefnet.eval()\n", - "print('BiRefNet is ready to use.')\n", - "\n", - "# Input Data\n", - "transform_image = transforms.Compose([\n", - " transforms.Resize((1024, 1024)),\n", - " transforms.ToTensor(),\n", - " transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])\n", - "])" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/", - "height": 1000 - }, - "id": "PECYekO53hrR", - "outputId": "73f47406-9d92-48b1-fe74-abbb5b83c7a8" - }, - "outputs": [], - "source": [ - "import os\n", - "from glob import glob\n", - "from image_proc import refine_foreground\n", - "\n", - "src_dir = '../images_todo'\n", - "image_paths = glob(os.path.join(src_dir, '*'))\n", - "dst_dir = '../predictions'\n", - "os.makedirs(dst_dir, exist_ok=True)\n", - "for image_path in image_paths:\n", - " print('Processing {} ...'.format(image_path))\n", - " image = Image.open(image_path)\n", - " input_images = transform_image(image).unsqueeze(0).to(device)\n", - "\n", - " # Prediction\n", - " with torch.no_grad():\n", - " preds = birefnet(input_images)[-1].sigmoid().cpu()\n", - " pred = preds[0].squeeze()\n", - "\n", - " # Show Results\n", - " pred_pil = transforms.ToPILImage()(pred)\n", - " 
pred_pil.resize(image.size).save(image_path.replace(src_dir, dst_dir))\n",
-    "\n",
-    "    # Visualize the last sample:\n",
-    "    # Scale proportionally with max length to 1024 for faster showing\n",
-    "    scale_ratio = 1024 / max(image.size)\n",
-    "    scaled_size = (int(image.size[0] * scale_ratio), int(image.size[1] * scale_ratio))\n",
-    "\n",
-    "    image_masked = refine_foreground(image, pred_pil)\n",
-    "    image_masked.putalpha(pred_pil.resize(image.size))\n",
-    "\n",
-    "display(image.resize(scaled_size))\n",
-    "display(pred_pil.resize(scaled_size))\n",
-    "display(image_masked.resize(scaled_size))"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": null,
-   "metadata": {},
-   "outputs": [],
-   "source": []
-  }
- ],
- "metadata": {
-  "accelerator": "GPU",
-  "colab": {
-   "gpuType": "T4",
-   "provenance": []
-  },
-  "kernelspec": {
-   "display_name": "Python 3",
-   "name": "python3"
-  },
-  "language_info": {
-   "codemirror_mode": {
-    "name": "ipython",
-    "version": 3
-   },
-   "file_extension": ".py",
-   "mimetype": "text/x-python",
-   "name": "python",
-   "nbconvert_exporter": "python",
-   "pygments_lexer": "ipython3",
-   "version": "3.9.19"
-  }
- },
- "nbformat": 4,
- "nbformat_minor": 0
-}
diff --git a/py/BiRefNet/tutorials/BiRefNet_pth2onnx.ipynb b/py/BiRefNet/tutorials/BiRefNet_pth2onnx.ipynb
deleted file mode 100644
index c087b08e..00000000
--- a/py/BiRefNet/tutorials/BiRefNet_pth2onnx.ipynb
+++ /dev/null
@@ -1,312 +0,0 @@
-{
- "cells": [
-  {
-   "cell_type": "markdown",
-   "metadata": {
-    "id": "LTj2A0RUQFNo"
-   },
-   "source": [
-    "# Convert our BiRefNet weights to onnx format.\n",
-    "\n",
-    "> This colab file is modified from [Kazuhito00](https://github.com/Kazuhito00)'s nice work.\n",
-    "\n",
-    "> Repo: https://github.com/Kazuhito00/BiRefNet-ONNX-Sample \n",
-    "> Original Colab: https://colab.research.google.com/github/Kazuhito00/BiRefNet-ONNX-Sample/blob/main/Convert2ONNX.ipynb\n",
-    "\n",
-    "+ Currently, Colab with 12.7GB RAM / 15GB GPU Mem cannot hold the transformation of BiRefNet in default setting. So, I take BiRefNet with swin_v1_tiny backbone as an example."
-   ]
-  },
-  {
-   "cell_type": "markdown",
-   "metadata": {},
-   "source": [
-    "### Online Colab version: https://colab.research.google.com/drive/1z6OruR52LOvDDpnp516F-N4EyPGrp5om"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": null,
-   "metadata": {
-    "id": "781JHjLJmveh"
-   },
-   "outputs": [],
-   "source": [
-    "import torch\n",
-    "\n",
-    "\n",
-    "weights_file = 'BiRefNet-general-bb_swin_v1_tiny-epoch_232.pth' # https://github.com/ZhengPeng7/BiRefNet/releases/download/v1/BiRefNet-general-bb_swin_v1_tiny-epoch_232.pth\n",
-    "device = 'cuda' if torch.cuda.is_available() else 'cpu'"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": null,
-   "metadata": {},
-   "outputs": [],
-   "source": [
-    "with open('config.py') as fp:\n",
-    "    file_lines = fp.read()\n",
-    "if 'swin_v1_tiny' in weights_file:\n",
-    "    print('Set `swin_v1_tiny` as the backbone.')\n",
-    "    file_lines = file_lines.replace(\n",
-    "        '''\n",
-    "        'pvt_v2_b2', 'pvt_v2_b5', # 9-bs10, 10-bs5\n",
-    "    ][6]\n",
-    "        ''',\n",
-    "        '''\n",
-    "        'pvt_v2_b2', 'pvt_v2_b5', # 9-bs10, 10-bs5\n",
-    "    ][3]\n",
-    "        ''',\n",
-    "    )\n",
-    "    with open('config.py', mode=\"w\") as fp:\n",
-    "        fp.write(file_lines)\n",
-    "else:\n",
-    "    file_lines = file_lines.replace(\n",
-    "        '''\n",
-    "        'pvt_v2_b2', 'pvt_v2_b5', # 9-bs10, 10-bs5\n",
-    "    ][3]\n",
-    "        ''',\n",
-    "        '''\n",
-    "        'pvt_v2_b2', 'pvt_v2_b5', # 9-bs10, 10-bs5\n",
-    "    ][6]\n",
-    "        ''',\n",
-    "    )\n",
-    "    with open('config.py', mode=\"w\") as fp:\n",
-    "        fp.write(file_lines)"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": null,
-   "metadata": {
-    "id": "7lFgKfPS8Icy"
-   },
-   "outputs": [],
-   "source": [
-    "from utils import check_state_dict\n",
-    "from models.birefnet import BiRefNet\n",
-    "\n",
-    "\n",
-    "birefnet = BiRefNet(bb_pretrained=False)\n",
-    "state_dict = torch.load('./{}'.format(weights_file), map_location=device)\n",
-    "state_dict = check_state_dict(state_dict)\n",
-    "birefnet.load_state_dict(state_dict)\n",
-    "\n",
-    "torch.set_float32_matmul_precision(['high', 'highest'][0])\n",
-    "\n",
-    "birefnet.to(device)\n",
-    "_ = birefnet.eval()"
-   ]
-  },
-  {
-   "cell_type": "markdown",
-   "metadata": {
-    "id": "JVgJAdgxQVJW"
-   },
-   "source": [
-    "# Process deform_conv2d in the conversion to ONNX"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": null,
-   "metadata": {
-    "id": "vJiZv0L75kTe"
-   },
-   "outputs": [],
-   "source": [
-    "from torchvision.ops.deform_conv import DeformConv2d\n",
-    "import deform_conv2d_onnx_exporter\n",
-    "\n",
-    "# register deform_conv2d operator\n",
-    "deform_conv2d_onnx_exporter.register_deform_conv2d_onnx_op()\n",
-    "\n",
-    "def convert_to_onnx(net, file_name='output.onnx', input_shape=(1024, 1024), device=device):\n",
-    "    input = torch.randn(1, 3, input_shape[0], input_shape[1]).to(device)\n",
-    "\n",
-    "    input_layer_names = ['input_image']\n",
-    "    output_layer_names = ['output_image']\n",
-    "\n",
-    "    torch.onnx.export(\n",
-    "        net,\n",
-    "        input,\n",
-    "        file_name,\n",
-    "        verbose=False,\n",
-    "        opset_version=17,\n",
-    "        input_names=input_layer_names,\n",
-    "        output_names=output_layer_names,\n",
-    "    )\n",
-    "convert_to_onnx(birefnet, weights_file.replace('.pth', '.onnx'), input_shape=(1024, 1024), device=device)"
-   ]
-  },
-  {
-   "cell_type": "markdown",
-   "metadata": {
-    "id": "-eU-g40P1zS-"
-   },
-   "source": [
-    "# Load ONNX weights and do the inference."
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": null,
-   "metadata": {
-    "id": "LZ4HVqcoDvto"
-   },
-   "outputs": [],
-   "source": [
-    "from PIL import Image\n",
-    "from torchvision import transforms\n",
-    "\n",
-    "\n",
-    "transform_image = transforms.Compose([\n",
-    "    transforms.Resize((1024, 1024)),\n",
-    "    transforms.ToTensor(),\n",
-    "    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])\n",
-    "])\n",
-    "\n",
-    "imagepath = './Helicopter-HR.jpg'\n",
-    "image = Image.open(imagepath)\n",
-    "input_images = transform_image(image).unsqueeze(0).to(device)\n",
-    "input_images_numpy = input_images.cpu().numpy()"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": null,
-   "metadata": {
-    "id": "rwzdKX1EfYkd"
-   },
-   "outputs": [],
-   "source": [
-    "import onnxruntime\n",
-    "import matplotlib.pyplot as plt\n",
-    "\n",
-    "\n",
-    "providers = ['CPUExecutionProvider'] if device == 'cpu' else ['CUDAExecutionProvider']\n",
-    "onnx_session = onnxruntime.InferenceSession(\n",
-    "    weights_file.replace('.pth', '.onnx'),\n",
-    "    providers=providers\n",
-    ")\n",
-    "input_name = onnx_session.get_inputs()[0].name\n",
-    "print(onnxruntime.get_device(), onnx_session.get_providers())"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": null,
-   "metadata": {
-    "id": "DJVtxZUZum4-"
-   },
-   "outputs": [],
-   "source": [
-    "from time import time\n",
-    "import matplotlib.pyplot as plt\n",
-    "\n",
-    "time_st = time()\n",
-    "pred_onnx = torch.tensor(\n",
-    "    onnx_session.run(None, {input_name: input_images_numpy if device == 'cpu' else input_images_numpy})[-1]\n",
-    ").squeeze(0).sigmoid().cpu()\n",
-    "print(time() - time_st)\n",
-    "\n",
-    "plt.imshow(pred_onnx.squeeze(), cmap='gray'); plt.show()"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": null,
-   "metadata": {},
-   "outputs": [],
-   "source": [
-    "with torch.no_grad():\n",
-    "    preds = birefnet(input_images)[-1].sigmoid().cpu()\n",
-    "plt.imshow(preds.squeeze(), cmap='gray'); plt.show()"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": null,
-   "metadata": {},
-   "outputs": [],
-   "source": [
-    "diff = abs(preds - pred_onnx)\n",
-    "print('sum(diff):', diff.sum())\n",
-    "plt.imshow((diff).squeeze(), cmap='gray'); plt.show()"
-   ]
-  },
-  {
-   "cell_type": "markdown",
-   "metadata": {
-    "id": "qzYHflt92Bjd"
-   },
-   "source": [
-    "# Efficiency Comparison between .pth and .onnx"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": null,
-   "metadata": {
-    "colab": {
-     "base_uri": "https://localhost:8080/"
-    },
-    "id": "A5IYfT-uzphA",
-    "outputId": "2999e345-950e-41b3-ddd3-9f58a71a3f21"
-   },
-   "outputs": [],
-   "source": [
-    "%%timeit\n",
-    "with torch.no_grad():\n",
-    "    preds = birefnet(input_images)[-1].sigmoid().cpu()"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": null,
-   "metadata": {
-    "id": "G0Ul4rfNg1za"
-   },
-   "outputs": [],
-   "source": [
-    "%%timeit\n",
-    "pred_onnx = torch.tensor(\n",
-    "    onnx_session.run(None, {input_name: input_images_numpy})[-1]\n",
-    ").squeeze(0).sigmoid().cpu()"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": null,
-   "metadata": {},
-   "outputs": [],
-   "source": []
-  }
- ],
- "metadata": {
-  "accelerator": "GPU",
-  "colab": {
-   "gpuType": "T4",
-   "provenance": []
-  },
-  "kernelspec": {
-   "display_name": "Python 3 (ipykernel)",
-   "language": "python",
-   "name": "python3"
-  },
-  "language_info": {
-   "codemirror_mode": {
-    "name": "ipython",
-    "version": 3
-   },
-   "file_extension": ".py",
-   "mimetype": "text/x-python",
-   "name": "python",
-   "nbconvert_exporter": "python",
"pygments_lexer": "ipython3", - "version": "3.10.14" - } - }, - "nbformat": 4, - "nbformat_minor": 4 -} diff --git a/py/BiRefNet/utils.py b/py/BiRefNet/utils.py deleted file mode 100644 index 1b437546..00000000 --- a/py/BiRefNet/utils.py +++ /dev/null @@ -1,97 +0,0 @@ -import logging -import os -import torch -from torchvision import transforms -import numpy as np -import random -import cv2 -from PIL import Image - - -def path_to_image(path, size=(1024, 1024), color_type=['rgb', 'gray'][0]): - if color_type.lower() == 'rgb': - image = cv2.imread(path) - elif color_type.lower() == 'gray': - image = cv2.imread(path, cv2.IMREAD_GRAYSCALE) - else: - print('Select the color_type to return, either to RGB or gray image.') - return - if size: - image = cv2.resize(image, size, interpolation=cv2.INTER_LINEAR) - if color_type.lower() == 'rgb': - image = Image.fromarray(cv2.cvtColor(image, cv2.COLOR_BGR2RGB)).convert('RGB') - else: - image = Image.fromarray(image).convert('L') - return image - - - -def check_state_dict(state_dict, unwanted_prefix='_orig_mod.'): - for k, v in list(state_dict.items()): - if k.startswith(unwanted_prefix): - state_dict[k[len(unwanted_prefix):]] = state_dict.pop(k) - return state_dict - - -def generate_smoothed_gt(gts): - epsilon = 0.001 - new_gts = (1-epsilon)*gts+epsilon/2 - return new_gts - - -class Logger(): - def __init__(self, path="log.txt"): - self.logger = logging.getLogger('BiRefNet') - self.file_handler = logging.FileHandler(path, "w") - self.stdout_handler = logging.StreamHandler() - self.stdout_handler.setFormatter(logging.Formatter('%(asctime)s %(levelname)s %(message)s')) - self.file_handler.setFormatter(logging.Formatter('%(asctime)s %(levelname)s %(message)s')) - self.logger.addHandler(self.file_handler) - self.logger.addHandler(self.stdout_handler) - self.logger.setLevel(logging.INFO) - self.logger.propagate = False - - def info(self, txt): - self.logger.info(txt) - - def close(self): - self.file_handler.close() - self.stdout_handler.close() - - -class AverageMeter(object): - """Computes and stores the average and current value""" - def __init__(self): - self.reset() - - def reset(self): - self.val = 0.0 - self.avg = 0.0 - self.sum = 0.0 - self.count = 0.0 - - def update(self, val, n=1): - self.val = val - self.sum += val * n - self.count += n - self.avg = self.sum / self.count - - -def save_checkpoint(state, path, filename="latest.pth"): - torch.save(state, os.path.join(path, filename)) - - -def save_tensor_img(tenor_im, path): - im = tenor_im.cpu().clone() - im = im.squeeze(0) - tensor2pil = transforms.ToPILImage() - im = tensor2pil(im) - im.save(path) - - -def set_seed(seed): - torch.manual_seed(seed) - torch.cuda.manual_seed_all(seed) - np.random.seed(seed) - random.seed(seed) - torch.backends.cudnn.deterministic = True diff --git a/py/BiRefNet_v2 b/py/BiRefNet_v2 new file mode 160000 index 00000000..a3e781d6 --- /dev/null +++ b/py/BiRefNet_v2 @@ -0,0 +1 @@ +Subproject commit a3e781d6178021b22b6c8de0f66f18b8d33ca38e diff --git a/py/birefnet_ultra_v2.py b/py/birefnet_ultra_v2.py index 370ede13..68d7e330 100644 --- a/py/birefnet_ultra_v2.py +++ b/py/birefnet_ultra_v2.py @@ -5,8 +5,8 @@ import tqdm from .imagefunc import * from comfy.utils import ProgressBar -# sys.path.append(os.path.join(os.path.dirname(__file__), 'BiRefNet')) -from .BiRefNet.models.birefnet import BiRefNet +sys.path.append(os.path.join(os.path.dirname(__file__), 'BiRefNet_v2')) + def get_models(): model_path = os.path.join(folder_paths.models_dir, 'BiRefNet', 'pth') @@ -42,7 +42,8 @@ 
     CATEGORY = '😺dzNodes/LayerMask'
 
     def load_birefnet_model(self, model):
-        from .BiRefNet.utils import check_state_dict
+        from .BiRefNet_v2.models.birefnet import BiRefNet
+        from .BiRefNet_v2.utils import check_state_dict
         model_dict = get_models()
         self.birefnet = BiRefNet(bb_pretrained=False)
         self.state_dict = torch.load(model_dict[model], map_location='cpu', weights_only=True)
diff --git a/pyproject.toml b/pyproject.toml
index 15d6954f..4c88c00e 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,7 +1,7 @@
 [project]
 name = "comfyui_layerstyle"
 description = "A set of nodes for ComfyUI it generate image like Adobe Photoshop's Layer Style. the Drop Shadow is first completed node, and follow-up work is in progress."
-version = "1.0.65"
+version = "1.0.66"
 license = "MIT"
 dependencies = ["numpy", "pillow", "torch", "matplotlib", "Scipy", "scikit_image", "scikit_learn", "opencv-contrib-python", "pymatting", "segment_anything", "timm", "addict", "yapf", "colour-science", "wget", "mediapipe", "loguru", "typer_config", "fastapi", "rich", "google-generativeai", "diffusers", "omegaconf", "tqdm", "transformers", "kornia", "image-reward", "ultralytics", "blend_modes", "blind-watermark", "qrcode", "pyzbar", "transparent-background", "huggingface_hub", "accelerate", "bitsandbytes", "torchscale", "wandb", "hydra-core", "psd-tools", "inference-cli[yolo-world]", "inference-gpu[yolo-world]", "onnxruntime"]
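
The two py/birefnet_ultra_v2.py hunks above carry the functional change of this patch: the vendored BiRefNet sources now live in the py/BiRefNet_v2 submodule, that directory is appended to sys.path (presumably so the submodule's own top-level imports such as config and models.* keep resolving), and the node imports BiRefNet and check_state_dict through the BiRefNet_v2 package name, so the vendored code can no longer be shadowed by, or shadow, any other top-level BiRefNet package on sys.path. The sketch below is not part of the patch; it is a minimal, hypothetical illustration of that loading pattern written as a standalone script placed next to the submodule. The helper name load_birefnet and the checkpoint path are invented for the example; BiRefNet(bb_pretrained=False), torch.load(..., map_location='cpu', weights_only=True) and check_state_dict(...) are taken directly from the hunks.

# Illustrative sketch only; mirrors the import/loading pattern of birefnet_ultra_v2.py.
import os
import sys

import torch

node_dir = os.path.dirname(os.path.abspath(__file__))

# Put py/BiRefNet_v2 itself on sys.path, as the first hunk does, so the vendored
# repo's internal absolute imports (config, models.*, utils, ...) can resolve.
sys.path.append(os.path.join(node_dir, 'BiRefNet_v2'))

# Import through the renamed package, as the second hunk does (inside the node this
# is the package-relative form "from .BiRefNet_v2.models.birefnet import BiRefNet").
from BiRefNet_v2.models.birefnet import BiRefNet
from BiRefNet_v2.utils import check_state_dict


def load_birefnet(checkpoint_path):
    """Hypothetical helper mirroring load_birefnet_model in birefnet_ultra_v2.py."""
    model = BiRefNet(bb_pretrained=False)
    state_dict = torch.load(checkpoint_path, map_location='cpu', weights_only=True)
    state_dict = check_state_dict(state_dict)  # strip the '_orig_mod.' prefix, as in the utils.py shown above
    model.load_state_dict(state_dict)
    return model.eval()

# Hypothetical usage (checkpoint name is a placeholder):
# birefnet = load_birefnet(os.path.join(node_dir, 'BiRefNet-general-epoch_244.pth'))

Note also that the first hunk drops the eager module-level import of the model class; it is now imported lazily inside load_birefnet_model, which is why the second hunk adds that import next to check_state_dict.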