Skip to content

Commit

Permalink
Include jaxtyping to allow for Tensor/LinearOperator typehints with sizes.
Browse files Browse the repository at this point in the history

Using the same trick in LinearOperator, sized Tensor/LinearOperator
typehints are automatically included in the documentation.
  • Loading branch information
gpleiss committed Jul 2, 2024
1 parent 07fa68e commit c19e833
Show file tree
Hide file tree
Showing 5 changed files with 70 additions and 46 deletions.
1 change: 1 addition & 0 deletions .conda/meta.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ requirements:
- python>=3.8
- pytorch>=1.11
- scikit-learn
- jaxtyping>=0.2.9
- linear_operator>=0.5.2

test:
Expand Down
112 changes: 67 additions & 45 deletions docs/source/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@
import warnings
from typing import ForwardRef

import jaxtyping


def read(*names, **kwargs):
with io.open(
Expand Down Expand Up @@ -112,7 +114,8 @@ def find_version(*file_paths):
intersphinx_mapping = {
"python": ("https://docs.python.org/3/", None),
"torch": ("https://pytorch.org/docs/stable/", None),
"linear_operator": ("https://linear-operator.readthedocs.io/en/stable/", None),
"linear_operator": ("https://linear-operator.readthedocs.io/en/stable/", "linear_operator_objects.inv"),
# The local mapping here is temporary until we get a new release of linear_operator
}

# Disable docstring inheritance
Expand Down Expand Up @@ -237,41 +240,79 @@ def find_version(*file_paths):
]


# -- Function to format typehints ----------------------------------------------
# -- Functions to format typehints ----------------------------------------------
# Adapted from
# https://github.com/cornellius-gp/linear_operator/blob/2b33b9f83b45f0cb8cb3490fc5f254cc59393c25/docs/source/conf.py


# Helper function
# Convert any class (i.e. torch.Tensor, LinearOperator, etc.) into appropriate strings
# For external classes, the format will be e.g. "torch.Tensor"
# For any internal class, the format will be e.g. "~linear_operator.operators.TriangularLinearOperator"
def _convert_internal_and_external_class_to_strings(annotation):
module = annotation.__module__ + "."
if module.split(".")[0] == "gpytorch":
module = "~" + module
elif module == "linear_operator.operators._linear_operator.":
module = "~linear_operator."
elif module == "builtins.":
module = ""
res = f"{module}{annotation.__name__}"
return res


# Convert jaxtyping dimensions into strings
def _dim_to_str(dim):
    """Render one jaxtyping dimension object as a short doc-friendly string.

    Variadic/ellipsis dims become ``"..."``; fixed dims render their size,
    symbolic dims render their deparsed expression in parentheses, and named
    dims render their name. Broadcastable dims are prefixed with ``#``.

    NOTE(review): this relies on jaxtyping's private ``array_types`` module —
    it may break on a jaxtyping upgrade; confirm against the pinned version.
    """
    if isinstance(dim, jaxtyping.array_types._NamedVariadicDim):
        return "..."
    if isinstance(dim, jaxtyping.array_types._FixedDim):
        marker = "#" if dim.broadcastable else ""
        return marker + str(dim.size)
    if isinstance(dim, jaxtyping.array_types._SymbolicDim):
        # Recover the source of the symbolic expression from its compiled
        # code object; keep only the text after "return ".
        expression = code_deparse(dim.expr).text.strip().split("return ")[1]
        return f"({expression})"
    if "jaxtyping" not in str(dim.__class__):
        # Not a jaxtyping dim at all — probably a bare ellipsis.
        return "..."
    marker = "#" if dim.broadcastable else ""
    return marker + str(dim.name)


# Function to format type hints
def _process(annotation, config):
"""
A function to convert a type/rtype typehint annotation into a :type:/:rtype: string.
This function is a bit hacky, and specific to the type annotations we use most frequently.
This function is recursive.
"""
# Simple/base case: any string annotation is ready to go
if type(annotation) == str:
return annotation

# Jaxtyping: shaped tensors or linear operator
elif hasattr(annotation, "__module__") and "jaxtyping" == annotation.__module__:
cls_annotation = _convert_internal_and_external_class_to_strings(annotation.array_type)
shape = " x ".join([_dim_to_str(dim) for dim in annotation.dims])
return f"{cls_annotation} ({shape})"

# Convert Ellipsis into "..."
elif annotation == Ellipsis:
return "..."

# Convert any class (i.e. torch.Tensor, LinearOperator, gpytorch, etc.) into appropriate strings
# For external classes, the format will be e.g. "torch.Tensor"
# For any linear_operator class, the format will be e.g. "~linear_operator.operators.TriangularLinearOperator"
# For any internal class, the format will be e.g. "~gpytorch.kernels.RBFKernel"
# Convert any class (i.e. torch.Tensor, LinearOperator, etc.) into appropriate strings
elif hasattr(annotation, "__name__"):
module = annotation.__module__ + "."
if module.split(".")[0] == "linear_operator":
if annotation.__name__.endswith("LinearOperator"):
module = "~linear_operator."
elif annotation.__name__.endswith("LinearOperator"):
module = "~linear_operator.operators."
else:
module = "~" + module
elif module.split(".")[0] == "gpytorch":
module = "~" + module
elif module == "builtins.":
module = ""
res = f"{module}{annotation.__name__}"
res = _convert_internal_and_external_class_to_strings(annotation)

elif str(annotation).startswith("typing.Callable"):
if len(annotation.__args__) == 2:
res = f"Callable[{_process(annotation.__args__[0], config)} -> {_process(annotation.__args__[1], config)}]"
else:
res = "Callable"

# Convert any Union[*A*, *B*, *C*] into "*A* or *B* or *C*"
# Also, convert any Optional[*A*] into "*A*, optional"
Expand All @@ -291,33 +332,14 @@ def _process(annotation, config):
args = list(annotation.__args__)
res = "(" + ", ".join(_process(arg, config) for arg in args) + ")"

# Convert any List[*A*] into "list(*A*)"
elif str(annotation).startswith("typing.List"):
arg = annotation.__args__[0]
res = "list(" + _process(arg, config) + ")"

# Convert any List[*A*] into "list(*A*)"
elif str(annotation).startswith("typing.Dict"):
res = str(annotation)

# Convert any Iterable[*A*] into "iterable(*A*)"
elif str(annotation).startswith("typing.Iterable"):
arg = annotation.__args__[0]
res = "iterable(" + _process(arg, config) + ")"

# Handle "Callable"
elif str(annotation).startswith("typing.Callable"):
res = "callable"

# Handle "Any"
elif str(annotation).startswith("typing.Any"):
res = ""
# Convert any List[*A*] or Iterable[*A*] into "[*A*, ...]"
elif str(annotation).startswith("typing.Iterable") or str(annotation).startswith("typing.List"):
arg = list(annotation.__args__)[0]
res = f"[{_process(arg, config)}, ...]"

# Special cases for forward references.
# This is brittle, as it only contains case for a select few forward refs
# All others that aren't caught by this are handled by the default case
elif isinstance(annotation, ForwardRef):
res = str(annotation.__forward_arg__)
# Callable typing annotation
elif str(annotation).startswith("typing."):
return str(annotation)[7:]

    # For everything we didn't catch: use the simplest string representation
else:
Expand Down
Binary file added docs/source/linear_operator_objects.inv
Binary file not shown.
2 changes: 1 addition & 1 deletion setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ max-line-length = 120

[flake8]
max-line-length = 120
ignore = E203, F403, F405, E731, E741, W503, W605
ignore = E203, E731, E741, F403, F405, F722, W503, W605
exclude =
build,examples

Expand Down
1 change: 1 addition & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ def find_version(*file_paths):

torch_min = "1.11"
install_requires = [
"jaxtyping>=0.2.9",
"mpmath>=0.19,<=1.3", # avoid incompatibiltiy with torch+sympy with mpmath 1.4
"scikit-learn",
"scipy",
Expand Down

0 comments on commit c19e833

Please sign in to comment.