You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

273 lines
8.8KB

  1. # ext/mypy/apply.py
  2. # Copyright (C) 2021 the SQLAlchemy authors and contributors
  3. # <see AUTHORS file>
  4. #
  5. # This module is part of SQLAlchemy and is released under
  6. # the MIT License: http://www.opensource.org/licenses/mit-license.php
  7. from typing import Optional
  8. from typing import Union
  9. from mypy import nodes
  10. from mypy.nodes import ARG_NAMED_OPT
  11. from mypy.nodes import Argument
  12. from mypy.nodes import AssignmentStmt
  13. from mypy.nodes import CallExpr
  14. from mypy.nodes import ClassDef
  15. from mypy.nodes import MDEF
  16. from mypy.nodes import MemberExpr
  17. from mypy.nodes import NameExpr
  18. from mypy.nodes import StrExpr
  19. from mypy.nodes import SymbolTableNode
  20. from mypy.nodes import TempNode
  21. from mypy.nodes import TypeInfo
  22. from mypy.nodes import Var
  23. from mypy.plugin import SemanticAnalyzerPluginInterface
  24. from mypy.plugins.common import add_method_to_class
  25. from mypy.types import AnyType
  26. from mypy.types import get_proper_type
  27. from mypy.types import Instance
  28. from mypy.types import NoneTyp
  29. from mypy.types import ProperType
  30. from mypy.types import TypeOfAny
  31. from mypy.types import UnboundType
  32. from mypy.types import UnionType
  33. from . import infer
  34. from . import util
  35. def _apply_mypy_mapped_attr(
  36. cls: ClassDef,
  37. api: SemanticAnalyzerPluginInterface,
  38. item: Union[NameExpr, StrExpr],
  39. cls_metadata: util.DeclClassApplied,
  40. ) -> None:
  41. if isinstance(item, NameExpr):
  42. name = item.name
  43. elif isinstance(item, StrExpr):
  44. name = item.value
  45. else:
  46. return
  47. for stmt in cls.defs.body:
  48. if (
  49. isinstance(stmt, AssignmentStmt)
  50. and isinstance(stmt.lvalues[0], NameExpr)
  51. and stmt.lvalues[0].name == name
  52. ):
  53. break
  54. else:
  55. util.fail(api, "Can't find mapped attribute {}".format(name), cls)
  56. return
  57. if stmt.type is None:
  58. util.fail(
  59. api,
  60. "Statement linked from _mypy_mapped_attrs has no "
  61. "typing information",
  62. stmt,
  63. )
  64. return
  65. left_hand_explicit_type = get_proper_type(stmt.type)
  66. assert isinstance(
  67. left_hand_explicit_type, (Instance, UnionType, UnboundType)
  68. )
  69. cls_metadata.mapped_attr_names.append((name, left_hand_explicit_type))
  70. _apply_type_to_mapped_statement(
  71. api, stmt, stmt.lvalues[0], left_hand_explicit_type, None
  72. )
  73. def _re_apply_declarative_assignments(
  74. cls: ClassDef,
  75. api: SemanticAnalyzerPluginInterface,
  76. cls_metadata: util.DeclClassApplied,
  77. ) -> None:
  78. """For multiple class passes, re-apply our left-hand side types as mypy
  79. seems to reset them in place.
  80. """
  81. mapped_attr_lookup = {
  82. name: typ for name, typ in cls_metadata.mapped_attr_names
  83. }
  84. update_cls_metadata = False
  85. for stmt in cls.defs.body:
  86. # for a re-apply, all of our statements are AssignmentStmt;
  87. # @declared_attr calls will have been converted and this
  88. # currently seems to be preserved by mypy (but who knows if this
  89. # will change).
  90. if (
  91. isinstance(stmt, AssignmentStmt)
  92. and isinstance(stmt.lvalues[0], NameExpr)
  93. and stmt.lvalues[0].name in mapped_attr_lookup
  94. and isinstance(stmt.lvalues[0].node, Var)
  95. ):
  96. left_node = stmt.lvalues[0].node
  97. python_type_for_type = mapped_attr_lookup[stmt.lvalues[0].name]
  98. # if we have scanned an UnboundType and now there's a more
  99. # specific type than UnboundType, call the re-scan so we
  100. # can get that set up correctly
  101. if (
  102. isinstance(python_type_for_type, UnboundType)
  103. and not isinstance(left_node.type, UnboundType)
  104. and (
  105. isinstance(stmt.rvalue.callee, MemberExpr)
  106. and stmt.rvalue.callee.expr.node.fullname
  107. == "sqlalchemy.orm.attributes.Mapped"
  108. and stmt.rvalue.callee.name == "_empty_constructor"
  109. and isinstance(stmt.rvalue.args[0], CallExpr)
  110. )
  111. ):
  112. python_type_for_type = (
  113. infer._infer_type_from_right_hand_nameexpr(
  114. api,
  115. stmt,
  116. left_node,
  117. left_node.type,
  118. stmt.rvalue.args[0].callee,
  119. )
  120. )
  121. if python_type_for_type is None or isinstance(
  122. python_type_for_type, UnboundType
  123. ):
  124. continue
  125. # update the DeclClassApplied with the better information
  126. mapped_attr_lookup[stmt.lvalues[0].name] = python_type_for_type
  127. update_cls_metadata = True
  128. left_node.type = api.named_type(
  129. "__sa_Mapped", [python_type_for_type]
  130. )
  131. if update_cls_metadata:
  132. cls_metadata.mapped_attr_names[:] = [
  133. (k, v) for k, v in mapped_attr_lookup.items()
  134. ]
  135. def _apply_type_to_mapped_statement(
  136. api: SemanticAnalyzerPluginInterface,
  137. stmt: AssignmentStmt,
  138. lvalue: NameExpr,
  139. left_hand_explicit_type: Optional[ProperType],
  140. python_type_for_type: Optional[ProperType],
  141. ) -> None:
  142. """Apply the Mapped[<type>] annotation and right hand object to a
  143. declarative assignment statement.
  144. This converts a Python declarative class statement such as::
  145. class User(Base):
  146. # ...
  147. attrname = Column(Integer)
  148. To one that describes the final Python behavior to Mypy::
  149. class User(Base):
  150. # ...
  151. attrname : Mapped[Optional[int]] = <meaningless temp node>
  152. """
  153. left_node = lvalue.node
  154. assert isinstance(left_node, Var)
  155. if left_hand_explicit_type is not None:
  156. left_node.type = api.named_type(
  157. "__sa_Mapped", [left_hand_explicit_type]
  158. )
  159. else:
  160. lvalue.is_inferred_def = False
  161. left_node.type = api.named_type(
  162. "__sa_Mapped",
  163. [] if python_type_for_type is None else [python_type_for_type],
  164. )
  165. # so to have it skip the right side totally, we can do this:
  166. # stmt.rvalue = TempNode(AnyType(TypeOfAny.special_form))
  167. # however, if we instead manufacture a new node that uses the old
  168. # one, then we can still get type checking for the call itself,
  169. # e.g. the Column, relationship() call, etc.
  170. # rewrite the node as:
  171. # <attr> : Mapped[<typ>] =
  172. # _sa_Mapped._empty_constructor(<original CallExpr from rvalue>)
  173. # the original right-hand side is maintained so it gets type checked
  174. # internally
  175. column_descriptor = nodes.NameExpr("__sa_Mapped")
  176. column_descriptor.fullname = "sqlalchemy.orm.attributes.Mapped"
  177. mm = nodes.MemberExpr(column_descriptor, "_empty_constructor")
  178. orig_call_expr = stmt.rvalue
  179. stmt.rvalue = CallExpr(mm, [orig_call_expr], [nodes.ARG_POS], ["arg1"])
  180. def _add_additional_orm_attributes(
  181. cls: ClassDef,
  182. api: SemanticAnalyzerPluginInterface,
  183. cls_metadata: util.DeclClassApplied,
  184. ) -> None:
  185. """Apply __init__, __table__ and other attributes to the mapped class."""
  186. info = util._info_for_cls(cls, api)
  187. if "__init__" not in info.names and cls_metadata.is_mapped:
  188. mapped_attr_names = {n: t for n, t in cls_metadata.mapped_attr_names}
  189. for mapped_base in cls_metadata.mapped_mro:
  190. base_cls_metadata = util.DeclClassApplied.deserialize(
  191. mapped_base.type.metadata["_sa_decl_class_applied"], api
  192. )
  193. for n, t in base_cls_metadata.mapped_attr_names:
  194. mapped_attr_names.setdefault(n, t)
  195. arguments = []
  196. for name, typ in mapped_attr_names.items():
  197. if typ is None:
  198. typ = AnyType(TypeOfAny.special_form)
  199. arguments.append(
  200. Argument(
  201. variable=Var(name, typ),
  202. type_annotation=typ,
  203. initializer=TempNode(typ),
  204. kind=ARG_NAMED_OPT,
  205. )
  206. )
  207. add_method_to_class(api, cls, "__init__", arguments, NoneTyp())
  208. if "__table__" not in info.names and cls_metadata.has_table:
  209. _apply_placeholder_attr_to_class(
  210. api, cls, "sqlalchemy.sql.schema.Table", "__table__"
  211. )
  212. if cls_metadata.is_mapped:
  213. _apply_placeholder_attr_to_class(
  214. api, cls, "sqlalchemy.orm.mapper.Mapper", "__mapper__"
  215. )
  216. def _apply_placeholder_attr_to_class(
  217. api: SemanticAnalyzerPluginInterface,
  218. cls: ClassDef,
  219. qualified_name: str,
  220. attrname: str,
  221. ) -> None:
  222. sym = api.lookup_fully_qualified_or_none(qualified_name)
  223. if sym:
  224. assert isinstance(sym.node, TypeInfo)
  225. type_: ProperType = Instance(sym.node, [])
  226. else:
  227. type_ = AnyType(TypeOfAny.special_form)
  228. var = Var(attrname)
  229. var.info = cls.info
  230. var.type = type_
  231. cls.info.names[attrname] = SymbolTableNode(MDEF, var)