
Commit dba9b53

added citation, results, authors

1 parent 4545b23 commit dba9b53

File tree

README.md (1 file changed: +33, −6 lines)
@@ -1,18 +1,20 @@
 # Neural-Backed Decision Trees
 
-[Project Page](http://nbdt.alvinwan.com)  //  [Paper]()  //  [No-code Web Demo](http://nbdt.alvinwan.com/demo/)  //  [Colab Notebook](https://colab.research.google.com/github/alvinwan/neural-backed-decision-trees/blob/master/examples/load_pretrained_nbdts.ipynb)
+[Project Page](http://nbdt.alvinwan.com)  //  [Paper](http://nbdt.alvinwan.com/paper/)  //  [No-code Web Demo](http://nbdt.alvinwan.com/demo/)  //  [Colab Notebook](https://colab.research.google.com/github/alvinwan/neural-backed-decision-trees/blob/master/examples/load_pretrained_nbdts.ipynb)
 
-Run decision trees that achieve state-of-the-art accuracy for explainable models on CIFAR10, CIFAR100, TinyImagenet200, and ImageNet. NBDTs achieve accuracies within 1% of the original neural network on CIFAR10, CIFAR100, and TinyImagenet200 with the recently state-of-the-art WideResNet; and within 2% of the original neural network on Imagenet, using recently state-of-the-art EfficientNet.
+*By Alvin Wan, \*Lisa Dunlap, \*Daniel Ho, Jihan Yin, Scott Lee, Henry Jin, Suzanne Petryk, Sarah Adel Bargal, Joseph E. Gonzalez*
+<sub>*denotes equal contribution</sub>
 
-<sub>**NBDT Accuracy per dataset**: CIFAR10 (97.57%), CIFAR100 (82.87%), TinyImagenet200 (66.66%), ImageNet (75.13%). [See more results](#results)</sub>
+Run decision trees that achieve state-of-the-art accuracy for explainable models on CIFAR10, CIFAR100, TinyImagenet200, and ImageNet. NBDTs achieve accuracies within 1% of the original neural network on CIFAR10, CIFAR100, and TinyImagenet200 with the recent state-of-the-art WideResNet, and within 2% of the original neural network on ImageNet with the recent state-of-the-art EfficientNet. We attain an ImageNet top-1 accuracy of 75.13%.
 
 **Table of Contents**
 
 - [Quickstart: Running and loading NBDTs](#quickstart)
 - [Convert your own neural network into a decision tree](#convert-neural-networks-to-decision-trees)
 - [Training and evaluation](#training-and-evaluation)
 - [Results](#results)
-- [Developing](#developing)
+- [Setup for Development](#setup-for-development)
+- [Citation](#citation)
 
 ![pipeline](https://user-images.githubusercontent.com/2068077/76384774-1ffb8480-631d-11ea-973f-7cac2a60bb10.jpg)
 
@@ -95,7 +97,7 @@ Note `torchvision.models.resnet18` only supports 224x224 input. However, `nbdt.m
 **To convert your neural network** into a neural-backed decision tree, perform the following 3 steps:
 
 1. **First**, if you haven't already, pip install the `nbdt` utility: `pip install nbdt`
-2. **Second, during training**, wrap your loss `criterion` with a custom NBDT loss. Below, we demonstrate the soft tree supervision loss on the CIFAR10 dataset. By default, we support CIFAR10, CIFAR100, TinyImagenet200, and Imagenet1000.
+2. **Second, during training**, wrap your loss `criterion` with a custom NBDT loss. Below, we demonstrate the soft tree supervision loss on the CIFAR10 dataset. By default, we support `CIFAR10`, `CIFAR100`, `TinyImagenet200`, and `Imagenet1000`.
 
 ```python
 from nbdt.loss import SoftTreeSupLoss
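
The hunk above cuts off just as the README's example begins. For context, a minimal sketch of the wrapped criterion, assuming the `SoftTreeSupLoss(dataset=..., criterion=...)` keyword arguments follow the repository's examples:

```python
import torch.nn as nn
from nbdt.loss import SoftTreeSupLoss

# Original classification loss for CIFAR10.
criterion = nn.CrossEntropyLoss()

# Step 2: wrap the loss with soft tree supervision. `dataset` selects one of
# the built-in hierarchies (CIFAR10, CIFAR100, TinyImagenet200, Imagenet1000);
# the keyword names here are assumptions based on the repository's examples.
criterion = SoftTreeSupLoss(dataset='CIFAR10', criterion=criterion)
```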
@@ -372,9 +374,19 @@ python main.py --dataset=CIFAR10 --arch=wrn28_10_cifar10 --hierarchy=induced-wrn
 
 # Results
 
+We compare against all previous decision-tree-based methods that report on CIFAR10, CIFAR100, and/or ImageNet, including methods that hinder interpretability by using impure leaves or a random forest. We report the highest-accuracy baseline among all of these methods: Deep Neural Decision Forest (DNDF, updated with ResNet18), Explainable Observer-Classifier (XOC), Deep Convolutional Decision Jungle (DCDJ), Network of Experts (NofE), Deep Decision Network (DDN), and Adaptive Neural Trees (ANT).
 
+|                      | CIFAR10 | CIFAR100 | TinyImagenet200 | ImageNet |
+|----------------------|---------|----------|-----------------|----------|
+| NBDT-S (Ours)        | 97.57%  | 82.87%   | 66.66%          | 75.13%   |
+| NBDT-H (Ours)        | 97.55%  | 82.21%   | 64.39%          | 74.79%   |
+| Best Pre-NBDT Acc    | 94.32%  | 76.24%   | 44.56%          | 61.29%   |
+| Best Pre-NBDT Method | DNDF    | NofE     | DNDF            | NofE     |
+| Our improvement      | 3.25%   | 6.63%    | 22.1%           | 13.84%   |
 
-# Developing
+As the last row shows, we outperform all previous decision-tree-based methods by anywhere from 3% (CIFAR10) to 13%+ (ImageNet). Note that accuracies in our pretrained checkpoints for small to medium datasets (CIFAR10, CIFAR100, and TinyImagenet200) may fluctuate by 0.1-0.2%, as we retrained all models with the current public version of this repository.
+
+# Setup for Development
 
 As discussed above, you can use the `nbdt` python library to integrate NBDT training into any existing training pipeline. However, if you wish to use the barebones training utilities here, refer to the following sections for adding custom models and datasets.
 
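The "Our improvement" row added above is plain subtraction: NBDT-S accuracy minus the best pre-NBDT baseline per dataset. A minimal sketch reproducing those numbers from the table (illustrative arithmetic only, not repository code):

```python
# Accuracy (%) from the results table above.
nbdt_s = {"CIFAR10": 97.57, "CIFAR100": 82.87, "TinyImagenet200": 66.66, "ImageNet": 75.13}
best_pre_nbdt = {"CIFAR10": 94.32, "CIFAR100": 76.24, "TinyImagenet200": 44.56, "ImageNet": 61.29}

# Improvement = NBDT-S accuracy minus the best pre-NBDT baseline.
for dataset in nbdt_s:
    print(f"{dataset}: +{nbdt_s[dataset] - best_pre_nbdt[dataset]:.2f}%")
# CIFAR10: +3.25%, CIFAR100: +6.63%, TinyImagenet200: +22.10%, ImageNet: +13.84%
```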
418430
> nbdt-wnids --dataset=YourData10
419431
> ```
420432
> , where `YourData` is your dataset name. If a provided class name from `YourData.classes` does not exist in the WordNet corpus, the script will generate a fake wnid. This does not affect training but subsequent analysis scripts will be unable to provide WordNet-imputed node meanings.
433+
434+
# Citation
435+
436+
If you find this work useful for your research, please cite our [paper](http://nbdt.alvinwan.com/paper/):
437+
438+
```
439+
@article{wan2020nbdt,
440+
title={NBDT: Neural-Backed Decision Trees},
441+
author={Alvin Wan and Lisa Dunlap and Daniel Ho and Jihan Yin and Scott Lee and Henry Jin and Suzanne Petryk and Sarah Adel Bargal and Joseph E. Gonzalez},
442+
year={2020},
443+
eprint={},
444+
archivePrefix={arXiv},
445+
primaryClass={cs.CV}
446+
}
447+
```
