Skip to content
This repository has been archived by the owner on Dec 18, 2023. It is now read-only.

Commit

Permalink
Small fixes on the tutorials (#1256)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: #1256

Changes in this diffs include:
- `tutorials.md`: fixed the URL to the sparse logistic regression tutorial
  - I also verified that the actual tutorial runs fine on Colab :).
- Linear Regression notebook:
  - `sklearn.model_selection` has to be imported explicitly
  - cast tensor to numpy before feeding it to `pandas` to avoid triggering: `TypeError: Index(...) must be called with a collection of some kind, tensor([4.]) was passed`
- hierarchical modeling, hierarchical regression, item response theory, zero inflated count: they share the same issue of assuming that the utils package exist on a local path. They are fixed by importing from `beanmachine.tutorials.utils` instead
- Hierarchical modeling tutorial has another issue with rendering svg using relative path. This is fixed by linking to the [svg image in our Github repo](https://raw.githubusercontent.com/facebookresearch/beanmachine/main/tutorials/assets/baseball/complete-pooling-dag.svg) instead
- Item response theory also has an arviz plotting error (with bokh backend): `ValueError: failed to validate Title(id='6813', ...).text: expected a value of type str, got p(,) of type RVIdentifier`. I fixed this by explicitly cast the keys to string.

allow-large-files

Reviewed By: jpchen, neerajprad

Differential Revision: D33052942

fbshipit-source-id: bcfae0fb3cbeaf1acae91d28d0fb3db3a9f4c73b
  • Loading branch information
horizon-blue authored and facebook-github-bot committed Dec 13, 2021
1 parent eae9921 commit b9b8319
Show file tree
Hide file tree
Showing 6 changed files with 13 additions and 16 deletions.
2 changes: 1 addition & 1 deletion docs/overview/tutorials/tutorials.md
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ This tutorial shows how to build a simple Bayesian model to deduce the line whic

### Sparse Logistic Regression

[Open in GitHub](https://github.com/facebookresearch/beanmachine/blob/main/tutorials/Tutorial_Implement_sparse_logistic_regression.ipynb)[Run in Google Colab](https://colab.research.google.com/github/facebookresearch/beanmachine/blob/main/tutorials/Tutorial_Implement_sparse_logistic_regression.ipynb)
[Open in GitHub](https://github.com/facebookresearch/beanmachine/blob/main/tutorials/Sparse_Logistic_Regression.ipynb)[Run in Google Colab](https://colab.research.google.com/github/facebookresearch/beanmachine/blob/main/tutorials/Sparse_Logistic_Regression.ipynb)

This tutorial demonstrates modeling and running inference on a sparse logistic regression model in Bean Machine. This tutorial showcases the inference techniques in Bean Machine, and applies the model to a public dataset to evaluate performance. This tutorial will also introduce the `@bm.functional` decorator, which can be used to deterministically transform random variables. This tutorial uses it for convenient post-processing.

Expand Down
8 changes: 4 additions & 4 deletions tutorials/Hierarchical_modeling.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -118,11 +118,11 @@
"import torch.distributions as dist\n",
"from beanmachine.ppl.inference import VerboseLevel\n",
"from beanmachine.ppl.model import RVIdentifier\n",
"from beanmachine.tutorials.utils import baseball\n",
"from bokeh.io import output_notebook\n",
"from bokeh.plotting import gridplot, show\n",
"from IPython.display import SVG\n",
"\n",
"from utils import baseball\n",
"smoke_test = ('SANDCASTLE_NEXUS' in os.environ or 'CI' in os.environ)"
]
},
Expand Down Expand Up @@ -1181,7 +1181,7 @@
}
],
"source": [
"SVG(\"assets/baseball/complete-pooling-dag.svg\")"
"SVG(\"https://raw.githubusercontent.com/facebookresearch/beanmachine/main/tutorials/assets/baseball/complete-pooling-dag.svg\")"
]
},
{
Expand Down Expand Up @@ -2269,7 +2269,7 @@
}
],
"source": [
"SVG(\"assets/baseball/no-pooling-dag.svg\")"
"SVG(\"https://raw.githubusercontent.com/facebookresearch/beanmachine/main/tutorials/assets/baseball/no-pooling-dag.svg\")"
]
},
{
Expand Down Expand Up @@ -3194,7 +3194,7 @@
}
],
"source": [
"SVG(\"assets/baseball/partial-pooling-dag.svg\")"
"SVG(\"https://raw.githubusercontent.com/facebookresearch/beanmachine/main/tutorials/assets/baseball/partial-pooling-dag.svg\")"
]
},
{
Expand Down
4 changes: 1 addition & 3 deletions tutorials/Hierarchical_regression.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -111,12 +111,10 @@
"import torch.distributions as dist\n",
"from beanmachine.ppl.inference import VerboseLevel\n",
"from beanmachine.ppl.model import RVIdentifier\n",
"from beanmachine.tutorials.utils import radon\n",
"from bokeh.io import output_notebook\n",
"from bokeh.plotting import gridplot, show\n",
"from torch import tensor\n",
"\n",
"# Local packages\n",
"from utils import radon\n",
"smoke_test = ('SANDCASTLE_NEXUS' in os.environ or 'CI' in os.environ)"
]
},
Expand Down
4 changes: 3 additions & 1 deletion tutorials/Item_Response_Theory.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,7 @@
"import beanmachine.ppl as bm\n",
"from beanmachine.ppl.inference import VerboseLevel\n",
"from beanmachine.ppl.model import RVIdentifier\n",
"from beanmachine.tutorials.utils import nba\n",
"\n",
"smoke_test = ('SANDCASTLE_NEXUS' in os.environ or 'CI' in os.environ)"
]
Expand Down Expand Up @@ -1471,7 +1472,8 @@
}
],
"source": [
"az.plot_trace(basic_trace, kind=\"rank_bars\");"
"str_trace = basic_trace.rename({basic_model.p(): str(basic_model.p())})\n",
"az.plot_trace(str_trace, kind=\"rank_bars\");"
]
},
{
Expand Down
5 changes: 2 additions & 3 deletions tutorials/Linear_Regression.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -70,11 +70,10 @@
"import matplotlib.pyplot as plt\n",
"import numpy as np\n",
"import pandas as pd\n",
"import sklearn\n",
"import sklearn.model_selection\n",
"import torch\n",
"import torch.distributions as dist\n",
"from torch import tensor\n",
"import sklearn\n",
"\n",
"\n",
"\n",
Expand Down Expand Up @@ -829,7 +828,7 @@
" [2.5, 50, 97.5],\n",
" axis=1,\n",
" ).T,\n",
" index=x.view(-1),\n",
" index=x.view(-1).numpy(),\n",
" columns=['2.5%', '50%', '97.5%'],\n",
" )"
]
Expand Down
6 changes: 2 additions & 4 deletions tutorials/Zero_inflated_count_data.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -119,11 +119,9 @@
"import torch.distributions as dist\n",
"from beanmachine.ppl.inference import VerboseLevel\n",
"from beanmachine.ppl.model import RVIdentifier\n",
"from beanmachine.tutorials.utils import hearts\n",
"from bokeh.io import output_notebook\n",
"from bokeh.plotting import gridplot, show\n",
"\n",
"# Local packages\n",
"from utils import hearts"
"from bokeh.plotting import gridplot, show\n"
]
},
{
Expand Down

0 comments on commit b9b8319

Please sign in to comment.