|
11 | 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 | 12 | # See the License for the specific language governing permissions and
|
13 | 13 | # limitations under the License.
|
| 14 | +import functools |
14 | 15 | import inspect
|
15 | 16 | import logging
|
16 | 17 | import os
|
17 |
| -from types import MethodType |
18 | 18 | from typing import TYPE_CHECKING, Any, Callable, Dict, Generic, Optional, Type, TypeVar
|
19 | 19 |
|
20 | 20 | from lightning_utilities.core.imports import RequirementCache
|
@@ -108,18 +108,23 @@ class _restricted_classmethod_impl(Generic[_T, _P, _R_co]):
|
108 | 108 | """Drop-in replacement for @classmethod, but raises an exception when the decorated method is called on an instance
|
109 | 109 | instead of a class type."""
|
110 | 110 |
|
111 |
| - def __init__(self, method: Callable[Concatenate[_T, _P], _R_co]) -> None: |
| 111 | + def __init__(self, method: Callable[Concatenate[Type[_T], _P], _R_co]) -> None: |
112 | 112 | self.method = method
|
113 | 113 |
|
114 | 114 | def __get__(self, instance: Optional[_T], cls: Type[_T]) -> Callable[_P, _R_co]:
|
115 |
| - # Workaround for https://github.com/pytorch/pytorch/issues/67146 |
116 |
| - is_scripting = any(os.path.join("torch", "jit") in frameinfo.filename for frameinfo in inspect.stack()) |
117 |
| - if instance is not None and not is_scripting: |
118 |
| - raise TypeError( |
119 |
| - f"The classmethod `{cls.__name__}.{self.method.__name__}` cannot be called on an instance." |
120 |
| - " Please call it on the class type and make sure the return value is used." |
121 |
| - ) |
122 |
| - return MethodType(self.method, cls) |
| 115 | + # The wrapper ensures that the method can be inspected, but not called on an instance |
| 116 | + @functools.wraps(self.method) |
| 117 | + def wrapper(*args: Any, **kwargs: Any) -> _R_co: |
| 118 | + # Workaround for https://github.com/pytorch/pytorch/issues/67146 |
| 119 | + is_scripting = any(os.path.join("torch", "jit") in frameinfo.filename for frameinfo in inspect.stack()) |
| 120 | + if instance is not None and not is_scripting: |
| 121 | + raise TypeError( |
| 122 | + f"The classmethod `{cls.__name__}.{self.method.__name__}` cannot be called on an instance." |
| 123 | + " Please call it on the class type and make sure the return value is used." |
| 124 | + ) |
| 125 | + return self.method(cls, *args, **kwargs) |
| 126 | + |
| 127 | + return wrapper |
123 | 128 |
|
124 | 129 |
|
125 | 130 | # trick static type checkers into thinking it's a @classmethod
|
|
0 commit comments