From 8253e5da975d5b8692b03a3346addd02ff1c0ed1 Mon Sep 17 00:00:00 2001 From: Ana Canizares Date: Tue, 5 Mar 2024 17:35:31 +0100 Subject: [PATCH 1/3] Make TypeEngine covariant --- lib/sqlalchemy/sql/type_api.py | 40 +++++++++++++++++----------------- 1 file changed, 20 insertions(+), 20 deletions(-) diff --git a/lib/sqlalchemy/sql/type_api.py b/lib/sqlalchemy/sql/type_api.py index a56911fb9a1..355c0260877 100644 --- a/lib/sqlalchemy/sql/type_api.py +++ b/lib/sqlalchemy/sql/type_api.py @@ -116,7 +116,7 @@ def __call__( ) -> TypeEngine.Comparator[_T]: ... -class TypeEngine(Visitable, Generic[_T]): +class TypeEngine(Visitable, Generic[_T_co]): """The ultimate base class for all SQL datatypes. Common subclasses of :class:`.TypeEngine` include @@ -359,7 +359,7 @@ def copy_value(self, value: Any) -> Any: def literal_processor( self, dialect: Dialect - ) -> Optional[_LiteralProcessorType[_T]]: + ) -> Optional[_LiteralProcessorType[_T_co]]: """Return a conversion function for processing literal values that are to be rendered directly without using binds. @@ -396,7 +396,7 @@ class explicitly. def bind_processor( self, dialect: Dialect - ) -> Optional[_BindProcessorType[_T]]: + ) -> Optional[_BindProcessorType[_T_co]]: """Return a conversion function for processing bind values. Returns a callable which will receive a bind parameter value @@ -432,7 +432,7 @@ class explicitly. def result_processor( self, dialect: Dialect, coltype: object - ) -> Optional[_ResultProcessorType[_T]]: + ) -> Optional[_ResultProcessorType[_T_co]]: """Return a conversion function for processing result row values. Returns a callable which will receive a result row column @@ -468,8 +468,8 @@ class explicitly. return None def column_expression( - self, colexpr: ColumnElement[_T] - ) -> Optional[ColumnElement[_T]]: + self, colexpr: ColumnElement[_T_co] + ) -> Optional[ColumnElement[_T_co]]: """Given a SELECT column expression, return a wrapping SQL expression. This is typically a SQL function that wraps a column expression @@ -527,8 +527,8 @@ def _has_column_expression(self) -> bool: ) def bind_expression( - self, bindvalue: BindParameter[_T] - ) -> Optional[ColumnElement[_T]]: + self, bindvalue: BindParameter[_T_co] + ) -> Optional[ColumnElement[_T_co]]: """Given a bind value (i.e. a :class:`.BindParameter` instance), return a SQL expression in its place. @@ -576,7 +576,7 @@ class explicitly. def _sentinel_value_resolver( self, dialect: Dialect - ) -> Optional[_SentinelProcessorType[_T]]: + ) -> Optional[_SentinelProcessorType[_T_co]]: """Return an optional callable that will match parameter values (post-bind processing) to result values (pre-result-processing), for use in the "sentinel" feature. @@ -768,7 +768,7 @@ def _resolve_for_python_type( return self @util.ro_memoized_property - def _type_affinity(self) -> Optional[Type[TypeEngine[_T]]]: + def _type_affinity(self) -> Optional[Type[TypeEngine[_T_co]]]: """Return a rudimental 'affinity' value expressing the general class of type.""" @@ -784,7 +784,7 @@ def _type_affinity(self) -> Optional[Type[TypeEngine[_T]]]: @util.ro_memoized_property def _generic_type_affinity( self, - ) -> Type[TypeEngine[_T]]: + ) -> Type[TypeEngine[_T_co]]: best_camelcase = None best_uppercase = None @@ -811,10 +811,10 @@ def _generic_type_affinity( return ( best_camelcase or best_uppercase - or cast("Type[TypeEngine[_T]]", NULLTYPE.__class__) + or cast("Type[TypeEngine[_T_co]]", NULLTYPE.__class__) ) - def as_generic(self, allow_nulltype: bool = False) -> TypeEngine[_T]: + def as_generic(self, allow_nulltype: bool = False) -> TypeEngine[_T_co]: """ Return an instance of the generic type corresponding to this type using heuristic rule. The method may be overridden if this @@ -854,7 +854,7 @@ def as_generic(self, allow_nulltype: bool = False) -> TypeEngine[_T]: return util.constructor_copy(self, self._generic_type_affinity) - def dialect_impl(self, dialect: Dialect) -> TypeEngine[_T]: + def dialect_impl(self, dialect: Dialect) -> TypeEngine[_T_co]: """Return a dialect-specific implementation for this :class:`.TypeEngine`. @@ -867,7 +867,7 @@ def dialect_impl(self, dialect: Dialect) -> TypeEngine[_T]: return tm["impl"] return self._dialect_info(dialect)["impl"] - def _unwrapped_dialect_impl(self, dialect: Dialect) -> TypeEngine[_T]: + def _unwrapped_dialect_impl(self, dialect: Dialect) -> TypeEngine[_T_co]: """Return the 'unwrapped' dialect impl for this type. For a type that applies wrapping logic (e.g. TypeDecorator), give @@ -883,7 +883,7 @@ def _unwrapped_dialect_impl(self, dialect: Dialect) -> TypeEngine[_T]: def _cached_literal_processor( self, dialect: Dialect - ) -> Optional[_LiteralProcessorType[_T]]: + ) -> Optional[_LiteralProcessorType[_T_co]]: """Return a dialect-specific literal processor for this type.""" try: @@ -899,7 +899,7 @@ def _cached_literal_processor( def _cached_bind_processor( self, dialect: Dialect - ) -> Optional[_BindProcessorType[_T]]: + ) -> Optional[_BindProcessorType[_T_co]]: """Return a dialect-specific bind processor for this type.""" try: @@ -915,7 +915,7 @@ def _cached_bind_processor( def _cached_result_processor( self, dialect: Dialect, coltype: Any - ) -> Optional[_ResultProcessorType[_T]]: + ) -> Optional[_ResultProcessorType[_T_co]]: """Return a dialect-specific result processor for this type.""" try: @@ -935,7 +935,7 @@ def _cached_result_processor( def _cached_sentinel_value_processor( self, dialect: Dialect - ) -> Optional[_SentinelProcessorType[_T]]: + ) -> Optional[_SentinelProcessorType[_T_co]]: try: return dialect._type_memos[self]["sentinel"] except KeyError: @@ -946,7 +946,7 @@ def _cached_sentinel_value_processor( return bp def _cached_custom_processor( - self, dialect: Dialect, key: str, fn: Callable[[TypeEngine[_T]], _O] + self, dialect: Dialect, key: str, fn: Callable[[TypeEngine[_T_co]], _O] ) -> _O: """return a dialect-specific processing object for custom purposes. From 16565dfaee3139986bbe20b89c1206357d3581f6 Mon Sep 17 00:00:00 2001 From: Ana Canizares Date: Mon, 11 Mar 2024 18:55:07 +0100 Subject: [PATCH 2/3] Return MappedColumn type matching the type engine, default value type and nullable parameter --- lib/sqlalchemy/orm/_orm_constructors.py | 133 ++++++++++++++++++++++++ 1 file changed, 133 insertions(+) diff --git a/lib/sqlalchemy/orm/_orm_constructors.py b/lib/sqlalchemy/orm/_orm_constructors.py index 6cf16507ba6..056388820e9 100644 --- a/lib/sqlalchemy/orm/_orm_constructors.py +++ b/lib/sqlalchemy/orm/_orm_constructors.py @@ -95,6 +95,139 @@ def contains_alias(alias: Union[Alias, Subquery]) -> AliasOption: return AliasOption(alias) +# nullable=True -> MappedColumn[Optional[_T]] +@overload +def mapped_column( + __name_pos: Optional[ + Union[str, _TypeEngineArgument[_T], SchemaEventTarget] + ] = None, + __type_pos: Optional[ + Union[_TypeEngineArgument[_T], SchemaEventTarget] + ] = None, + *args: SchemaEventTarget, + init: Union[_NoArg, bool] = _NoArg.NO_ARG, + repr: Union[_NoArg, bool] = _NoArg.NO_ARG, # noqa: A002 + default: Optional[Union[_NoArg, _T, Callable[..., _T]]] = _NoArg.NO_ARG, + default_factory: Union[_NoArg, Callable[[], _T]] = _NoArg.NO_ARG, + compare: Union[_NoArg, bool] = _NoArg.NO_ARG, + kw_only: Union[_NoArg, bool] = _NoArg.NO_ARG, + nullable: Literal[True], + primary_key: Optional[bool] = False, + deferred: Union[_NoArg, bool] = _NoArg.NO_ARG, + deferred_group: Optional[str] = None, + deferred_raiseload: Optional[bool] = None, + use_existing_column: bool = False, + name: Optional[str] = None, + type_: Optional[_TypeEngineArgument[Any]] = None, + autoincrement: _AutoIncrementType = "auto", + doc: Optional[str] = None, + key: Optional[str] = None, + index: Optional[bool] = None, + unique: Optional[bool] = None, + info: Optional[_InfoType] = None, + onupdate: Optional[Any] = None, + insert_default: Optional[Any] = _NoArg.NO_ARG, + server_default: Optional[_ServerDefaultArgument] = None, + server_onupdate: Optional[FetchedValue] = None, + active_history: bool = False, + quote: Optional[bool] = None, + system: bool = False, + comment: Optional[str] = None, + sort_order: Union[_NoArg, int] = _NoArg.NO_ARG, + **kw: Any, +) -> MappedColumn[Optional[_T]]: ... + + +# nullable=False -> MappedColumn[_T] +@overload +def mapped_column( + __name_pos: Optional[ + Union[str, _TypeEngineArgument[_T], SchemaEventTarget] + ] = None, + __type_pos: Optional[ + Union[_TypeEngineArgument[_T], SchemaEventTarget] + ] = None, + *args: SchemaEventTarget, + init: Union[_NoArg, bool] = _NoArg.NO_ARG, + repr: Union[_NoArg, bool] = _NoArg.NO_ARG, # noqa: A002 + default: Union[_NoArg, _T, Callable[..., _T]] = _NoArg.NO_ARG, + default_factory: Union[_NoArg, Callable[[], _T]] = _NoArg.NO_ARG, + compare: Union[_NoArg, bool] = _NoArg.NO_ARG, + kw_only: Union[_NoArg, bool] = _NoArg.NO_ARG, + nullable: Literal[False], + primary_key: Optional[bool] = False, + deferred: Union[_NoArg, bool] = _NoArg.NO_ARG, + deferred_group: Optional[str] = None, + deferred_raiseload: Optional[bool] = None, + use_existing_column: bool = False, + name: Optional[str] = None, + type_: Optional[_TypeEngineArgument[Any]] = None, + autoincrement: _AutoIncrementType = "auto", + doc: Optional[str] = None, + key: Optional[str] = None, + index: Optional[bool] = None, + unique: Optional[bool] = None, + info: Optional[_InfoType] = None, + onupdate: Optional[Any] = None, + insert_default: Optional[Any] = _NoArg.NO_ARG, + server_default: Optional[_ServerDefaultArgument] = None, + server_onupdate: Optional[FetchedValue] = None, + active_history: bool = False, + quote: Optional[bool] = None, + system: bool = False, + comment: Optional[str] = None, + sort_order: Union[_NoArg, int] = _NoArg.NO_ARG, + **kw: Any, +) -> MappedColumn[_T]: ... + + +# nullable unset or None -> MappedColumn[_T] +# TODO this would be only correct if the default unset nullable was False, +# which implies a larger change. +@overload +def mapped_column( + __name_pos: Optional[ + Union[str, _TypeEngineArgument[_T], SchemaEventTarget] + ] = None, + __type_pos: Optional[ + Union[_TypeEngineArgument[_T], SchemaEventTarget] + ] = None, + *args: SchemaEventTarget, + init: Union[_NoArg, bool] = _NoArg.NO_ARG, + repr: Union[_NoArg, bool] = _NoArg.NO_ARG, # noqa: A002 + default: Union[_NoArg, _T, Callable[..., _T]] = _NoArg.NO_ARG, + default_factory: Union[_NoArg, Callable[[], _T]] = _NoArg.NO_ARG, + compare: Union[_NoArg, bool] = _NoArg.NO_ARG, + kw_only: Union[_NoArg, bool] = _NoArg.NO_ARG, + nullable: Optional[ + Literal[SchemaConst.NULL_UNSPECIFIED] + ] = SchemaConst.NULL_UNSPECIFIED, + primary_key: Optional[bool] = False, + deferred: Union[_NoArg, bool] = _NoArg.NO_ARG, + deferred_group: Optional[str] = None, + deferred_raiseload: Optional[bool] = None, + use_existing_column: bool = False, + name: Optional[str] = None, + type_: Optional[_TypeEngineArgument[Any]] = None, + autoincrement: _AutoIncrementType = "auto", + doc: Optional[str] = None, + key: Optional[str] = None, + index: Optional[bool] = None, + unique: Optional[bool] = None, + info: Optional[_InfoType] = None, + onupdate: Optional[Any] = None, + insert_default: Optional[Any] = _NoArg.NO_ARG, + server_default: Optional[_ServerDefaultArgument] = None, + server_onupdate: Optional[FetchedValue] = None, + active_history: bool = False, + quote: Optional[bool] = None, + system: bool = False, + comment: Optional[str] = None, + sort_order: Union[_NoArg, int] = _NoArg.NO_ARG, + **kw: Any, +) -> MappedColumn[_T]: ... + + def mapped_column( __name_pos: Optional[ Union[str, _TypeEngineArgument[Any], SchemaEventTarget] From 99cac9a3a5a767cb5ac2da9a10b952b2ac74ce7d Mon Sep 17 00:00:00 2001 From: Ana Canizares Date: Mon, 11 Mar 2024 18:55:43 +0100 Subject: [PATCH 3/3] Add and update mapped_column typing tests --- test/typing/plain_files/orm/issue_9340.py | 3 +- test/typing/plain_files/orm/mapped_column.py | 85 ++++++++++++++++++-- test/typing/plain_files/orm/relationship.py | 5 -- 3 files changed, 81 insertions(+), 12 deletions(-) diff --git a/test/typing/plain_files/orm/issue_9340.py b/test/typing/plain_files/orm/issue_9340.py index 6ccd2eed314..3706ffe0600 100644 --- a/test/typing/plain_files/orm/issue_9340.py +++ b/test/typing/plain_files/orm/issue_9340.py @@ -1,3 +1,4 @@ +from typing import Optional from typing import Sequence from typing import TYPE_CHECKING @@ -28,7 +29,7 @@ class UserComment(Message): __mapper_args__ = { "polymorphic_identity": "user_comment", } - username: Mapped[str] = mapped_column(nullable=True) + username: Mapped[Optional[str]] = mapped_column(nullable=True) engine = create_engine("postgresql+psycopg2://scott:tiger@localhost/") diff --git a/test/typing/plain_files/orm/mapped_column.py b/test/typing/plain_files/orm/mapped_column.py index 26f5722a6fc..e3e20389c1b 100644 --- a/test/typing/plain_files/orm/mapped_column.py +++ b/test/typing/plain_files/orm/mapped_column.py @@ -1,3 +1,4 @@ +import typing from typing import Optional from sqlalchemy import ForeignKey @@ -44,7 +45,7 @@ class X(Base): b: Mapped[Optional[str]] = mapped_column() # this can't be detected because we don't know the type - c: Mapped[str] = mapped_column(nullable=True) + c: Mapped[Optional[str]] = mapped_column(nullable=True) d: Mapped[str] = mapped_column(nullable=False) e: Mapped[Optional[str]] = mapped_column(ForeignKey(c), nullable=True) @@ -58,8 +59,6 @@ class X(Base): # this probably is wrong. however at the moment it seems better to # decouple the right hand arguments from declaring things about the # left side since it mostly doesn't work in any case. - i: Mapped[str] = mapped_column(String, nullable=True) - j: Mapped[str] = mapped_column(String, nullable=False) k: Mapped[Optional[str]] = mapped_column(String, nullable=True) @@ -69,7 +68,6 @@ class X(Base): a_name: Mapped[str] = mapped_column("a_name") b_name: Mapped[Optional[str]] = mapped_column("b_name") - c_name: Mapped[str] = mapped_column("c_name", nullable=True) d_name: Mapped[str] = mapped_column("d_name", nullable=False) e_name: Mapped[Optional[str]] = mapped_column("e_name", nullable=True) @@ -79,8 +77,6 @@ class X(Base): g_name: Mapped[str] = mapped_column("g_name", String) h_name: Mapped[Optional[str]] = mapped_column("h_name", String) - i_name: Mapped[str] = mapped_column("i_name", String, nullable=True) - j_name: Mapped[str] = mapped_column("j_name", String, nullable=False) k_name: Mapped[Optional[str]] = mapped_column( @@ -94,3 +90,80 @@ class X(Base): ) __table_args__ = (UniqueConstraint(a, b, name="uq1"), Index("ix1", c, d)) + + +if typing.TYPE_CHECKING: + # EXPECTED_RE_TYPE: sqlalchemy.orm.properties.MappedColumn\[builtins.int\] + reveal_type(mapped_column(Integer)) + + # EXPECTED_RE_TYPE: sqlalchemy.orm.properties.MappedColumn\[Union\[builtins.int, None\]\] + reveal_type(mapped_column(Integer, nullable=True)) + + # EXPECTED_RE_TYPE: sqlalchemy.orm.properties.MappedColumn\[builtins.int\] + reveal_type(mapped_column(Integer, default=7)) + + # EXPECTED_MYPY_RE: Argument 1 to "mapped_column" has incompatible type.* + a_err: Mapped[str] = mapped_column(Integer) + + # EXPECTED_MYPY_RE: Argument 2 to "mapped_column" has incompatible type.* + a_err_name: Mapped[str] = mapped_column("a", Integer) + + # EXPECTED_MYPY_RE: Argument "default" to "mapped_column" has incompatible type "None".* + b_err: Mapped[int] = mapped_column(Integer, default=None) + + # EXPECTED_MYPY_RE: Argument "default" to "mapped_column" has incompatible type "None".* + b_err_name: Mapped[int] = mapped_column("b", Integer, default=None) + + # EXPECTED_MYPY_RE: Argument "default" to "mapped_column" has incompatible type "None".* + c_err: Mapped[int] = mapped_column(default=None) + + # EXPECTED_MYPY_RE: Argument "default" to "mapped_column" has incompatible type "None".* + c_err_name: Mapped[int] = mapped_column("c", default=None) + + # EXPECTED_MYPY_RE: Incompatible types in assignment.* + d_err: Mapped[int] = mapped_column(Integer, nullable=True) + + # EXPECTED_MYPY_RE: Incompatible types in assignment.* + d_err_name: Mapped[int] = mapped_column("d", Integer, nullable=True) + + # EXPECTED_MYPY_RE: Incompatible types in assignment.* + e_err: Mapped[int] = mapped_column(nullable=True) + + # EXPECTED_MYPY_RE: Incompatible types in assignment.* + e_err_name: Mapped[int] = mapped_column("e", nullable=True) + + # EXPECTED_MYPY_RE: Argument "default" to "mapped_column" has incompatible type.* + f_err: Mapped[int] = mapped_column(default="a") + + # EXPECTED_MYPY_RE: Argument "default" to "mapped_column" has incompatible type.* + f_err_name: Mapped[int] = mapped_column("f", default="a") + + # All of these are fine + x1: Mapped[str] = mapped_column(String, default="a", nullable=False) + x2: Mapped[str] = mapped_column(String, default="a") + x3: Mapped[str] = mapped_column(default="a", nullable=False) + x4: Mapped[str] = mapped_column(String, nullable=False) + x5: Mapped[str] = mapped_column(String) + x6: Mapped[str] = mapped_column(default="a") + x7: Mapped[str] = mapped_column(nullable=False) + x8: Mapped[str] = mapped_column() + + y1: Mapped[Optional[int]] = mapped_column( + Integer, default=None, nullable=True + ) + y2: Mapped[Optional[int]] = mapped_column(Integer, default=None) + y3: Mapped[Optional[int]] = mapped_column(Integer, nullable=True) + y4: Mapped[Optional[int]] = mapped_column(default=None, nullable=True) + y5: Mapped[Optional[int]] = mapped_column(default=None) + y6: Mapped[Optional[int]] = mapped_column(Integer) + y7: Mapped[Optional[int]] = mapped_column(nullable=True) + y8: Mapped[Optional[int]] = mapped_column() + + z1: Mapped[int] = mapped_column(Integer, default=7, nullable=False) + z2: Mapped[int] = mapped_column(Integer, default=7) + z3: Mapped[int] = mapped_column(default=7, nullable=False) + z4: Mapped[int] = mapped_column(Integer, nullable=False) + z5: Mapped[int] = mapped_column(Integer) + z6: Mapped[int] = mapped_column(default=7) + z7: Mapped[int] = mapped_column(nullable=False) + z8: Mapped[int] = mapped_column() diff --git a/test/typing/plain_files/orm/relationship.py b/test/typing/plain_files/orm/relationship.py index 6bfe19cc4e8..8757821e24c 100644 --- a/test/typing/plain_files/orm/relationship.py +++ b/test/typing/plain_files/orm/relationship.py @@ -35,11 +35,6 @@ class User(Base): id: Mapped[int] = mapped_column(primary_key=True) name: Mapped[str] = mapped_column() - # this currently doesnt generate an error. not sure how to get the - # overloads to hit this one, nor am i sure i really want to do that - # anyway - name_this_works_atm: Mapped[str] = mapped_column(nullable=True) - extra: Mapped[Optional[str]] = mapped_column() extra_name: Mapped[Optional[str]] = mapped_column("extra_name")