# Neural-Backed Decision Trees

[Project Page](http://nbdt.alvinwan.com) // [Paper](http://nbdt.alvinwan.com/paper/) // [No-code Web Demo](http://nbdt.alvinwan.com/demo/) // [Colab Notebook](https://colab.research.google.com/github/alvinwan/neural-backed-decision-trees/blob/master/examples/load_pretrained_nbdts.ipynb)

*By Alvin Wan, \*Lisa Dunlap, \*Daniel Ho, Jihan Yin, Scott Lee, Henry Jin, Suzanne Petryk, Sarah Adel Bargal, Joseph E. Gonzalez*
<sub>*denotes equal contribution</sub>

Run decision trees that achieve state-of-the-art accuracy for explainable models on CIFAR10, CIFAR100, TinyImagenet200, and ImageNet. NBDTs achieve accuracies within 1% of the original neural network on CIFAR10, CIFAR100, and TinyImagenet200 with the recently state-of-the-art WideResNet, and within 2% of the original neural network on ImageNet with the recently state-of-the-art EfficientNet. We attain an ImageNet top-1 accuracy of 75.13%.

**Table of Contents**

- [Quickstart: Running and loading NBDTs](#quickstart)
- [Convert your own neural network into a decision tree](#convert-neural-networks-to-decision-trees)
- [Training and evaluation](#training-and-evaluation)
- [Results](#results)
- [Setup for Development](#setup-for-development)
- [Citation](#citation)

![pipeline](https://user-images.githubusercontent.com/2068077/76384774-1ffb8480-631d-11ea-973f-7cac2a60bb10.jpg)

**To convert your neural network** into a neural-backed decision tree, perform the following 3 steps:

1. **First**, if you haven't already, pip install the `nbdt` utility: `pip install nbdt`
2. **Second, during training**, wrap your loss `criterion` with a custom NBDT loss. Below, we demonstrate the soft tree supervision loss on the CIFAR10 dataset. By default, we support `CIFAR10`, `CIFAR100`, `TinyImagenet200`, and `Imagenet1000`.

```python
from nbdt.loss import SoftTreeSupLoss

# `criterion` stands in for your original loss function (e.g. nn.CrossEntropyLoss());
# wrapping it adds the tree supervision term to your existing loss.
criterion = SoftTreeSupLoss(dataset='CIFAR10', criterion=criterion)
```
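
3. **Third, during inference or validation**, wrap your `model` with the corresponding NBDT wrapper. The snippet below is a minimal sketch using the library's `SoftNBDT`, assuming `model` is your trained network:

```python
from nbdt.model import SoftNBDT

# Route predictions through the induced hierarchy instead of the bare softmax.
model = SoftNBDT(dataset='CIFAR10', model=model)
```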

# Results

We compare against all previous decision-tree-based methods that report on CIFAR10, CIFAR100, and/or ImageNet, including methods that hinder interpretability by using impure leaves or a random forest. We report the highest-accuracy baseline among these methods: Deep Neural Decision Forest (DNDF, updated with ResNet18), Explainable Observer-Classifier (XOC), Deep Convolutional Decision Jungle (DCDJ), Network of Experts (NofE), Deep Decision Network (DDN), and Adaptive Neural Trees (ANT).

|                      | CIFAR10 | CIFAR100 | TinyImagenet200 | ImageNet |
|----------------------|---------|----------|-----------------|----------|
| NBDT-S (Ours)        | 97.57%  | 82.87%   | 66.66%          | 75.13%   |
| NBDT-H (Ours)        | 97.55%  | 82.21%   | 64.39%          | 74.79%   |
| Best Pre-NBDT Acc    | 94.32%  | 76.24%   | 44.56%          | 61.29%   |
| Best Pre-NBDT Method | DNDF    | NofE     | DNDF            | NofE     |
| Our improvement      | 3.25%   | 6.63%    | 22.1%           | 13.84%   |

As the last row shows, we outperform all previous decision-tree-based methods by anywhere from 3% (CIFAR10) to 22% (TinyImagenet200), including a 13%+ improvement on ImageNet. Note that accuracies in our pretrained checkpoints for small to medium datasets (CIFAR10, CIFAR100, and TinyImagenet200) may fluctuate by 0.1-0.2%, as we retrained all models with the current public version of this repository.

# Setup for Development

As discussed above, you can use the `nbdt` python library to integrate NBDT training into any existing training pipeline. However, if you wish to use the barebones training utilities here, refer to the following sections for adding custom models and datasets.
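
For reference, a minimal sketch of that integration in a standard PyTorch training loop follows; everything except the `SoftTreeSupLoss` wrapping is ordinary PyTorch, and the architecture and hyperparameters below are illustrative assumptions, not prescriptions:

```python
import torch
import torch.nn as nn
from torchvision import datasets, transforms
from nbdt.loss import SoftTreeSupLoss
from nbdt.models import wrn28_10_cifar10

model = wrn28_10_cifar10()  # any classifier with a final linear layer works
criterion = SoftTreeSupLoss(dataset='CIFAR10', criterion=nn.CrossEntropyLoss())
optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)
loader = torch.utils.data.DataLoader(
    datasets.CIFAR10('./data', train=True, download=True,
                     transform=transforms.ToTensor()),
    batch_size=128, shuffle=True)

for images, labels in loader:  # one epoch of ordinary training
    optimizer.zero_grad()
    loss = criterion(model(images), labels)  # tree supervision is added here
    loss.backward()
    optimizer.step()
```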

> ```
> nbdt-wnids --dataset=YourData10
> ```
> Here, `YourData` is your dataset name. If a provided class name from `YourData.classes` does not exist in the WordNet corpus, the script will generate a fake wnid. This does not affect training, but subsequent analysis scripts will be unable to provide WordNet-imputed node meanings.
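
For illustration, the `nbdt-wnids` script only needs the dataset's `classes` attribute to look up WordNet IDs; the class and names below are hypothetical:

```python
# Hypothetical dataset stub: `nbdt-wnids` reads `classes` to map names to wnids.
class YourData10:
    classes = ['cat', 'dog', 'floop']  # 'floop' is not in WordNet -> fake wnid
```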

# Citation

If you find this work useful for your research, please cite our [paper](http://nbdt.alvinwan.com/paper/):

```
@article{wan2020nbdt,
  title={NBDT: Neural-Backed Decision Trees},
  author={Alvin Wan and Lisa Dunlap and Daniel Ho and Jihan Yin and Scott Lee and Henry Jin and Suzanne Petryk and Sarah Adel Bargal and Joseph E. Gonzalez},
  year={2020},
  eprint={},
  archivePrefix={arXiv},
  primaryClass={cs.CV}
}
```