Skip to content

Commit 3044e83

Browse files
authored
_restricted_classmethod: add wrapper, to allow inspection (#19332)
1 parent b1127e3 commit 3044e83

File tree

3 files changed

+23
-12
lines changed

3 files changed

+23
-12
lines changed

src/lightning/pytorch/CHANGELOG.md

+3
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
7777
- Fixed an issue with the ModelCheckpoint callback not saving relative symlinks with `ModelCheckpoint(save_last="link")` ([#19303](https://github.com/Lightning-AI/lightning/pull/19303))
7878

7979

80+
- Fixed issue where the `_restricted_classmethod_impl` would incorrectly raise a TypeError on inspection rather than on call ([#19332](https://github.com/Lightning-AI/lightning/pull/19332))
81+
82+
8083
## [2.1.3] - 2023-12-21
8184

8285
### Changed

src/lightning/pytorch/utilities/model_helpers.py

+15-10
Original file line numberDiff line numberDiff line change
@@ -11,10 +11,10 @@
1111
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
14+
import functools
1415
import inspect
1516
import logging
1617
import os
17-
from types import MethodType
1818
from typing import TYPE_CHECKING, Any, Callable, Dict, Generic, Optional, Type, TypeVar
1919

2020
from lightning_utilities.core.imports import RequirementCache
@@ -108,18 +108,23 @@ class _restricted_classmethod_impl(Generic[_T, _P, _R_co]):
108108
"""Drop-in replacement for @classmethod, but raises an exception when the decorated method is called on an instance
109109
instead of a class type."""
110110

111-
def __init__(self, method: Callable[Concatenate[_T, _P], _R_co]) -> None:
111+
def __init__(self, method: Callable[Concatenate[Type[_T], _P], _R_co]) -> None:
112112
self.method = method
113113

114114
def __get__(self, instance: Optional[_T], cls: Type[_T]) -> Callable[_P, _R_co]:
115-
# Workaround for https://github.com/pytorch/pytorch/issues/67146
116-
is_scripting = any(os.path.join("torch", "jit") in frameinfo.filename for frameinfo in inspect.stack())
117-
if instance is not None and not is_scripting:
118-
raise TypeError(
119-
f"The classmethod `{cls.__name__}.{self.method.__name__}` cannot be called on an instance."
120-
" Please call it on the class type and make sure the return value is used."
121-
)
122-
return MethodType(self.method, cls)
115+
# The wrapper ensures that the method can be inspected, but not called on an instance
116+
@functools.wraps(self.method)
117+
def wrapper(*args: Any, **kwargs: Any) -> _R_co:
118+
# Workaround for https://github.com/pytorch/pytorch/issues/67146
119+
is_scripting = any(os.path.join("torch", "jit") in frameinfo.filename for frameinfo in inspect.stack())
120+
if instance is not None and not is_scripting:
121+
raise TypeError(
122+
f"The classmethod `{cls.__name__}.{self.method.__name__}` cannot be called on an instance."
123+
" Please call it on the class type and make sure the return value is used."
124+
)
125+
return self.method(cls, *args, **kwargs)
126+
127+
return wrapper
123128

124129

125130
# trick static type checkers into thinking it's a @classmethod

tests/tests_pytorch/utilities/test_model_helpers.py

+5-2
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
14+
import inspect
1415
import logging
1516

1617
import pytest
@@ -66,10 +67,12 @@ def cmethod(cls):
6667

6768

6869
def test_restricted_classmethod():
70+
restricted_method = RestrictedClass().restricted_cmethod # no exception when getting restricted method
71+
6972
with pytest.raises(TypeError, match="cannot be called on an instance"):
70-
RestrictedClass().restricted_cmethod()
73+
restricted_method()
7174

72-
RestrictedClass.restricted_cmethod() # no exception
75+
_ = inspect.getmembers(RestrictedClass()) # no exception on inspecting instance
7376

7477

7578
def test_module_mode():

0 commit comments

Comments
 (0)