You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

217 lines
5.9KB

  1. from typing import Any
  2. from typing import cast
  3. from typing import Iterable
  4. from typing import Iterator
  5. from typing import List
  6. from typing import Optional
  7. from typing import overload
  8. from typing import Tuple
  9. from typing import Type as TypingType
  10. from typing import TypeVar
  11. from typing import Union
  12. from mypy.nodes import CallExpr
  13. from mypy.nodes import ClassDef
  14. from mypy.nodes import CLASSDEF_NO_INFO
  15. from mypy.nodes import Context
  16. from mypy.nodes import IfStmt
  17. from mypy.nodes import JsonDict
  18. from mypy.nodes import NameExpr
  19. from mypy.nodes import Statement
  20. from mypy.nodes import SymbolTableNode
  21. from mypy.nodes import TypeInfo
  22. from mypy.plugin import ClassDefContext
  23. from mypy.plugin import DynamicClassDefContext
  24. from mypy.plugin import SemanticAnalyzerPluginInterface
  25. from mypy.plugins.common import deserialize_and_fixup_type
  26. from mypy.types import Instance
  27. from mypy.types import NoneType
  28. from mypy.types import ProperType
  29. from mypy.types import Type
  30. from mypy.types import UnboundType
  31. from mypy.types import UnionType
  32. _TArgType = TypeVar("_TArgType", bound=Union[CallExpr, NameExpr])
  33. class DeclClassApplied:
  34. def __init__(
  35. self,
  36. is_mapped: bool,
  37. has_table: bool,
  38. mapped_attr_names: Iterable[Tuple[str, ProperType]],
  39. mapped_mro: Iterable[Instance],
  40. ):
  41. self.is_mapped = is_mapped
  42. self.has_table = has_table
  43. self.mapped_attr_names = list(mapped_attr_names)
  44. self.mapped_mro = list(mapped_mro)
  45. def serialize(self) -> JsonDict:
  46. return {
  47. "is_mapped": self.is_mapped,
  48. "has_table": self.has_table,
  49. "mapped_attr_names": [
  50. (name, type_.serialize())
  51. for name, type_ in self.mapped_attr_names
  52. ],
  53. "mapped_mro": [type_.serialize() for type_ in self.mapped_mro],
  54. }
  55. @classmethod
  56. def deserialize(
  57. cls, data: JsonDict, api: SemanticAnalyzerPluginInterface
  58. ) -> "DeclClassApplied":
  59. return DeclClassApplied(
  60. is_mapped=data["is_mapped"],
  61. has_table=data["has_table"],
  62. mapped_attr_names=cast(
  63. List[Tuple[str, ProperType]],
  64. [
  65. (name, deserialize_and_fixup_type(type_, api))
  66. for name, type_ in data["mapped_attr_names"]
  67. ],
  68. ),
  69. mapped_mro=cast(
  70. List[Instance],
  71. [
  72. deserialize_and_fixup_type(type_, api)
  73. for type_ in data["mapped_mro"]
  74. ],
  75. ),
  76. )
  77. def fail(api: SemanticAnalyzerPluginInterface, msg: str, ctx: Context) -> None:
  78. msg = "[SQLAlchemy Mypy plugin] %s" % msg
  79. return api.fail(msg, ctx)
  80. def add_global(
  81. ctx: Union[ClassDefContext, DynamicClassDefContext],
  82. module: str,
  83. symbol_name: str,
  84. asname: str,
  85. ) -> None:
  86. module_globals = ctx.api.modules[ctx.api.cur_mod_id].names
  87. if asname not in module_globals:
  88. lookup_sym: SymbolTableNode = ctx.api.modules[module].names[
  89. symbol_name
  90. ]
  91. module_globals[asname] = lookup_sym
  92. @overload
  93. def _get_callexpr_kwarg(
  94. callexpr: CallExpr, name: str, *, expr_types: None = ...
  95. ) -> Optional[Union[CallExpr, NameExpr]]:
  96. ...
  97. @overload
  98. def _get_callexpr_kwarg(
  99. callexpr: CallExpr,
  100. name: str,
  101. *,
  102. expr_types: Tuple[TypingType[_TArgType], ...]
  103. ) -> Optional[_TArgType]:
  104. ...
  105. def _get_callexpr_kwarg(
  106. callexpr: CallExpr,
  107. name: str,
  108. *,
  109. expr_types: Optional[Tuple[TypingType[Any], ...]] = None
  110. ) -> Optional[Any]:
  111. try:
  112. arg_idx = callexpr.arg_names.index(name)
  113. except ValueError:
  114. return None
  115. kwarg = callexpr.args[arg_idx]
  116. if isinstance(
  117. kwarg, expr_types if expr_types is not None else (NameExpr, CallExpr)
  118. ):
  119. return kwarg
  120. return None
  121. def _flatten_typechecking(stmts: Iterable[Statement]) -> Iterator[Statement]:
  122. for stmt in stmts:
  123. if (
  124. isinstance(stmt, IfStmt)
  125. and isinstance(stmt.expr[0], NameExpr)
  126. and stmt.expr[0].fullname == "typing.TYPE_CHECKING"
  127. ):
  128. for substmt in stmt.body[0].body:
  129. yield substmt
  130. else:
  131. yield stmt
  132. def _unbound_to_instance(
  133. api: SemanticAnalyzerPluginInterface, typ: Type
  134. ) -> Type:
  135. """Take the UnboundType that we seem to get as the ret_type from a FuncDef
  136. and convert it into an Instance/TypeInfo kind of structure that seems
  137. to work as the left-hand type of an AssignmentStatement.
  138. """
  139. if not isinstance(typ, UnboundType):
  140. return typ
  141. # TODO: figure out a more robust way to check this. The node is some
  142. # kind of _SpecialForm, there's a typing.Optional that's _SpecialForm,
  143. # but I cant figure out how to get them to match up
  144. if typ.name == "Optional":
  145. # convert from "Optional?" to the more familiar
  146. # UnionType[..., NoneType()]
  147. return _unbound_to_instance(
  148. api,
  149. UnionType(
  150. [_unbound_to_instance(api, typ_arg) for typ_arg in typ.args]
  151. + [NoneType()]
  152. ),
  153. )
  154. node = api.lookup_qualified(typ.name, typ)
  155. if (
  156. node is not None
  157. and isinstance(node, SymbolTableNode)
  158. and isinstance(node.node, TypeInfo)
  159. ):
  160. bound_type = node.node
  161. return Instance(
  162. bound_type,
  163. [
  164. _unbound_to_instance(api, arg)
  165. if isinstance(arg, UnboundType)
  166. else arg
  167. for arg in typ.args
  168. ],
  169. )
  170. else:
  171. return typ
  172. def _info_for_cls(
  173. cls: ClassDef, api: SemanticAnalyzerPluginInterface
  174. ) -> TypeInfo:
  175. if cls.info is CLASSDEF_NO_INFO:
  176. sym = api.lookup_qualified(cls.name, cls)
  177. if sym is None:
  178. return None
  179. assert sym and isinstance(sym.node, TypeInfo)
  180. return sym.node
  181. return cls.info