From c19e8331128a3aa27f42d3db2ef108ce2e28f444 Mon Sep 17 00:00:00 2001 From: Geoff Pleiss <824157+gpleiss@users.noreply.github.com> Date: Mon, 1 Jul 2024 14:49:35 -0700 Subject: [PATCH] Include jaxtyping to allow for Tensor/LinearOperator typehints with sizes. Using the same trick in LinearOperator, sized Tensor/LinearOperator typehints are automatically included in the documentation. --- .conda/meta.yaml | 1 + docs/source/conf.py | 112 ++++++++++++++---------- docs/source/linear_operator_objects.inv | Bin 0 -> 2005 bytes setup.cfg | 2 +- setup.py | 1 + 5 files changed, 70 insertions(+), 46 deletions(-) create mode 100644 docs/source/linear_operator_objects.inv diff --git a/.conda/meta.yaml b/.conda/meta.yaml index 79c4f9714..10b797008 100644 --- a/.conda/meta.yaml +++ b/.conda/meta.yaml @@ -19,6 +19,7 @@ requirements: - python>=3.8 - pytorch>=1.11 - scikit-learn + - jaxtyping>=0.2.9 - linear_operator>=0.5.2 test: diff --git a/docs/source/conf.py b/docs/source/conf.py index 0b872c98a..517a36486 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -21,6 +21,8 @@ import warnings from typing import ForwardRef +import jaxtyping + def read(*names, **kwargs): with io.open( @@ -112,7 +114,8 @@ def find_version(*file_paths): intersphinx_mapping = { "python": ("https://docs.python.org/3/", None), "torch": ("https://pytorch.org/docs/stable/", None), - "linear_operator": ("https://linear-operator.readthedocs.io/en/stable/", None), + "linear_operator": ("https://linear-operator.readthedocs.io/en/stable/", "linear_operator_objects.inv"), + # The local mapping here is temporary until we get a new release of linear_operator } # Disable docstring inheritance @@ -237,41 +240,79 @@ def find_version(*file_paths): ] -# -- Function to format typehints ---------------------------------------------- +# -- Functions to format typehints ---------------------------------------------- # Adapted from # https://github.com/cornellius-gp/linear_operator/blob/2b33b9f83b45f0cb8cb3490fc5f254cc59393c25/docs/source/conf.py + + +# Helper function +# Convert any class (i.e. torch.Tensor, LinearOperator, etc.) into appropriate strings +# For external classes, the format will be e.g. "torch.Tensor" +# For any internal class, the format will be e.g. "~linear_operator.operators.TriangularLinearOperator" +def _convert_internal_and_external_class_to_strings(annotation): + module = annotation.__module__ + "." + if module.split(".")[0] == "gpytorch": + module = "~" + module + elif module == "linear_operator.operators._linear_operator.": + module = "~linear_operator." + elif module == "builtins.": + module = "" + res = f"{module}{annotation.__name__}" + return res + + +# Convert jaxtyping dimensions into strings +def _dim_to_str(dim): + if isinstance(dim, jaxtyping.array_types._NamedVariadicDim): + return "..." + elif isinstance(dim, jaxtyping.array_types._FixedDim): + res = str(dim.size) + if dim.broadcastable: + res = "#" + res + return res + elif isinstance(dim, jaxtyping.array_types._SymbolicDim): + expr = code_deparse(dim.expr).text.strip().split("return ")[1] + return f"({expr})" + elif "jaxtyping" not in str(dim.__class__): # Probably the case that we have an ellipsis + return "..." + else: + res = str(dim.name) + if dim.broadcastable: + res = "#" + res + return res + + +# Function to format type hints def _process(annotation, config): """ A function to convert a type/rtype typehint annotation into a :type:/:rtype: string. This function is a bit hacky, and specific to the type annotations we use most frequently. + This function is recursive. """ # Simple/base case: any string annotation is ready to go if type(annotation) == str: return annotation + # Jaxtyping: shaped tensors or linear operator + elif hasattr(annotation, "__module__") and "jaxtyping" == annotation.__module__: + cls_annotation = _convert_internal_and_external_class_to_strings(annotation.array_type) + shape = " x ".join([_dim_to_str(dim) for dim in annotation.dims]) + return f"{cls_annotation} ({shape})" + # Convert Ellipsis into "..." elif annotation == Ellipsis: return "..." - # Convert any class (i.e. torch.Tensor, LinearOperator, gpytorch, etc.) into appropriate strings - # For external classes, the format will be e.g. "torch.Tensor" - # For any linear_operator class, the format will be e.g. "~linear_operator.operators.TriangularLinearOperator" - # For any internal class, the format will be e.g. "~gpytorch.kernels.RBFKernel" + # Convert any class (i.e. torch.Tensor, LinearOperator, etc.) into appropriate strings elif hasattr(annotation, "__name__"): - module = annotation.__module__ + "." - if module.split(".")[0] == "linear_operator": - if annotation.__name__.endswith("LinearOperator"): - module = "~linear_operator." - elif annotation.__name__.endswith("LinearOperator"): - module = "~linear_operator.operators." - else: - module = "~" + module - elif module.split(".")[0] == "gpytorch": - module = "~" + module - elif module == "builtins.": - module = "" - res = f"{module}{annotation.__name__}" + res = _convert_internal_and_external_class_to_strings(annotation) + + elif str(annotation).startswith("typing.Callable"): + if len(annotation.__args__) == 2: + res = f"Callable[{_process(annotation.__args__[0], config)} -> {_process(annotation.__args__[1], config)}]" + else: + res = "Callable" # Convert any Union[*A*, *B*, *C*] into "*A* or *B* or *C*" # Also, convert any Optional[*A*] into "*A*, optional" @@ -291,33 +332,14 @@ def _process(annotation, config): args = list(annotation.__args__) res = "(" + ", ".join(_process(arg, config) for arg in args) + ")" - # Convert any List[*A*] into "list(*A*)" - elif str(annotation).startswith("typing.List"): - arg = annotation.__args__[0] - res = "list(" + _process(arg, config) + ")" - - # Convert any List[*A*] into "list(*A*)" - elif str(annotation).startswith("typing.Dict"): - res = str(annotation) - - # Convert any Iterable[*A*] into "iterable(*A*)" - elif str(annotation).startswith("typing.Iterable"): - arg = annotation.__args__[0] - res = "iterable(" + _process(arg, config) + ")" - - # Handle "Callable" - elif str(annotation).startswith("typing.Callable"): - res = "callable" - - # Handle "Any" - elif str(annotation).startswith("typing.Any"): - res = "" + # Convert any List[*A*] or Iterable[*A*] into "[*A*, ...]" + elif str(annotation).startswith("typing.Iterable") or str(annotation).startswith("typing.List"): + arg = list(annotation.__args__)[0] + res = f"[{_process(arg, config)}, ...]" - # Special cases for forward references. - # This is brittle, as it only contains case for a select few forward refs - # All others that aren't caught by this are handled by the default case - elif isinstance(annotation, ForwardRef): - res = str(annotation.__forward_arg__) + # Callable typing annotation + elif str(annotation).startswith("typing."): + return str(annotation)[7:] # For everything we didn't catch: use the simplist string representation else: diff --git a/docs/source/linear_operator_objects.inv b/docs/source/linear_operator_objects.inv new file mode 100644 index 0000000000000000000000000000000000000000..2de1dfa8b92314154970f9b7af8ec90e048e9d79 GIT binary patch literal 2005 zcmV;`2P*g@AX9K?X>NERX>N99Zgg*Qc_4OWa&u{KZXhxWBOp+6Z)#;@bUGkxX>Mg< za$j$7WpZJ3Z*mGFAXa5^b7^mGIv_AEH7+wQWMy_SI4fsiFkv}kG%^Y!AXI2&AaZ4G zVQFq;WpW^IW*~HEX>%ZEX>4U6X>%ZBZ*6dLWpi_7WFU2OX>MmAdTeQ8E(& z+cprr`&S5PuiiG-T#BYafgo9A+ufjOF9a>oHe-sENy;1VukTQ@y!tY2QNumimN;)7 zN5dg!JQ7?os);I?rdH{sIKIzSRtlE6Q_z&%ZEY@os;@79PL>PNmAd;<{r=hl zNd+|~d6)D@R!G+&?w%Y-q1|nJoPk5LteO+ETb0ejAsMF|C8J# z0N`0T>qxVLAq5J6(#$Yv@piQ?B(a(?B;0+;0hw$xgb*Z1v!w+RJ;YV7m9ivbsmcpw zINpu|?4TqXCVh#N&BIc`jcjp?Yd9W(f za0Ip79wYe5Ckf}HLJ4YF7T4K7S9RD3Q~0cvWa*CS2d%O)wegQG*EJ3n+#9JKTQ6^< zdf@Ic^u~Rz_tflSFN6E4m=IaVn=9}Y7cmq*nC!94<4)>trEG-o1iU{`n{MNhe^RUv z+`dGZz|*;`b{X7Q4Xd-%Y!2WcXyTyu&nGyjW07_e@_$UL2v3XL+WE4)P>8uA@6%0P z<)EtA3|L$(DlBhZ2h$4|KFTXsg&3!NwLr1-KBnoF40hUUp5|K#DqV7!sY3G$L5}26ede zdBeF(m`7eESOfC@ybFQyr~*GxJGNo?lQ$HiBMQ7k$5_UtIEi0jGGfYnz|o)3ACGPHiGHn<*c`I3FLs>wagO1iWgM)_}Wt@I-q+kMGdv*1za!3b5=KG zc{STz4ide9+}X#j`Uky#+m|keT5be$$ZM=WZrWeV<1J1B7^D46({K#4mdlOlyhhSo zy!sZ&d0@IB7R;fkOj)N6(n0b=ob1`^QJ900(;uz@s&2Ii0pk01vqk_y##ZGfSQbo> zbIFb6sXOGXf@!R&vBddES=S=ba3GgFz{qiLrWTC9ay_Xtl?b(|=8p3QK&(khr+tymg!j!row65zG+4b#1XK7UzDj@XW%wObR2$#1_zN@UA8^dfIXRoLjzOy?Kq_kXg z4{6nvt0%O2)T34Dmb$aVm%rNfZeiReJl_9%6HH3eE$er3-X?j9U#>j3(cr)%Eul97 zop!fiGoIv|PjXsR#2E!XgDBNk(;(q`j|h>hGoxTb)Z%xmk*?ke>R~V6u_j4dYn8(j z&UgyD9+eq!?<>xk-06d`BMSnlw}PVezh)rBUWW-ctB#2Mg$2CqVBN6vmkYX5rGM*)TX-X6#1=CFqI?2i9Dz@33v5v47C}x*)oYOrSU}n;W&%eI1zUshoG{wzBTFo&0%PVP?_7x}xOa z2wRx!>M}q!-F5}*YqiW{egA=O*y3D1Y0>H&VqkRrYjLiTb;foIXUSvpGaEei()89q nb)gpfR^97WuD-6mv($oMb#iD6o=0.2.9", "mpmath>=0.19,<=1.3", # avoid incompatibiltiy with torch+sympy with mpmath 1.4 "scikit-learn", "scipy",