|
- from typing import Any
- from typing import cast
- from typing import Iterable
- from typing import Iterator
- from typing import List
- from typing import Optional
- from typing import overload
- from typing import Tuple
- from typing import Type as TypingType
- from typing import TypeVar
- from typing import Union
-
- from mypy.nodes import CallExpr
- from mypy.nodes import ClassDef
- from mypy.nodes import CLASSDEF_NO_INFO
- from mypy.nodes import Context
- from mypy.nodes import IfStmt
- from mypy.nodes import JsonDict
- from mypy.nodes import NameExpr
- from mypy.nodes import Statement
- from mypy.nodes import SymbolTableNode
- from mypy.nodes import TypeInfo
- from mypy.plugin import ClassDefContext
- from mypy.plugin import DynamicClassDefContext
- from mypy.plugin import SemanticAnalyzerPluginInterface
- from mypy.plugins.common import deserialize_and_fixup_type
- from mypy.types import Instance
- from mypy.types import NoneType
- from mypy.types import ProperType
- from mypy.types import Type
- from mypy.types import UnboundType
- from mypy.types import UnionType
-
-
- _TArgType = TypeVar("_TArgType", bound=Union[CallExpr, NameExpr])
-
-
class DeclClassApplied:
    """Serializable record of flags and types collected for one class.

    Holds two boolean flags plus two lists of mypy types, and can convert
    itself to and from a JSON-compatible dictionary.
    """

    def __init__(
        self,
        is_mapped: bool,
        has_table: bool,
        mapped_attr_names: Iterable[Tuple[str, ProperType]],
        mapped_mro: Iterable[Instance],
    ):
        """Store the flags and copy both iterables into lists.

        :param is_mapped: flag stored unchanged on the instance.
        :param has_table: flag stored unchanged on the instance.
        :param mapped_attr_names: ``(name, type)`` pairs.
        :param mapped_mro: ``Instance`` types.
        """
        self.is_mapped = is_mapped
        self.has_table = has_table
        # Materialize as lists so one-shot iterables (generators) can be
        # traversed again later, e.g. by serialize().
        self.mapped_attr_names = list(mapped_attr_names)
        self.mapped_mro = list(mapped_mro)

    def serialize(self) -> JsonDict:
        """Return the state as a dictionary, serializing each held type.

        Every type in ``mapped_attr_names`` and ``mapped_mro`` is replaced
        by the result of its own ``serialize()`` method.
        """
        return {
            "is_mapped": self.is_mapped,
            "has_table": self.has_table,
            "mapped_attr_names": [
                (name, type_.serialize())
                for name, type_ in self.mapped_attr_names
            ],
            "mapped_mro": [type_.serialize() for type_ in self.mapped_mro],
        }

    @classmethod
    def deserialize(
        cls, data: JsonDict, api: SemanticAnalyzerPluginInterface
    ) -> "DeclClassApplied":
        """Rebuild an instance from the dictionary made by ``serialize``.

        :param data: dictionary with the keys ``is_mapped``, ``has_table``,
         ``mapped_attr_names`` and ``mapped_mro``.
        :param api: passed to ``deserialize_and_fixup_type`` for every
         serialized type.
        :return: a new instance of ``cls``, so that a subclass calling this
         method receives an instance of itself.
        """

        # The cast targets are string forward references: a type checker
        # reads them the same way, but no generic alias is built at
        # runtime on each call.
        return cls(
            is_mapped=data["is_mapped"],
            has_table=data["has_table"],
            mapped_attr_names=cast(
                "List[Tuple[str, ProperType]]",
                [
                    (name, deserialize_and_fixup_type(type_, api))
                    for name, type_ in data["mapped_attr_names"]
                ],
            ),
            mapped_mro=cast(
                "List[Instance]",
                [
                    deserialize_and_fixup_type(type_, api)
                    for type_ in data["mapped_mro"]
                ],
            ),
        )
-
-
def fail(api: SemanticAnalyzerPluginInterface, msg: str, ctx: Context) -> None:
    """Report ``msg`` via ``api.fail``, prefixed with the plugin's tag.

    :param api: object whose ``fail(msg, ctx)`` method receives the report.
    :param msg: message text; it is interpolated after the tag.
    :param ctx: context forwarded unchanged to ``api.fail``.
    """
    tagged = "[SQLAlchemy Mypy plugin] %s" % msg
    return api.fail(tagged, ctx)
-
-
def add_global(
    ctx: Union[ClassDefContext, DynamicClassDefContext],
    module: str,
    symbol_name: str,
    asname: str,
) -> None:
    """Expose ``module.symbol_name`` in the current module as ``asname``.

    The symbol table entry is copied from ``ctx.api.modules[module]`` into
    the names of the module identified by ``ctx.api.cur_mod_id``.  When
    ``asname`` is already present there, nothing is looked up or changed.
    """
    all_modules = ctx.api.modules
    current_names = all_modules[ctx.api.cur_mod_id].names

    # An existing entry wins; the source module is not even consulted.
    if asname in current_names:
        return

    current_names[asname] = all_modules[module].names[symbol_name]
-
-
@overload
def _get_callexpr_kwarg(
    callexpr: CallExpr, name: str, *, expr_types: None = ...
) -> Optional[Union[CallExpr, NameExpr]]:
    ...


@overload
def _get_callexpr_kwarg(
    callexpr: CallExpr,
    name: str,
    *,
    expr_types: Tuple[TypingType[_TArgType], ...]
) -> Optional[_TArgType]:
    ...


def _get_callexpr_kwarg(
    callexpr: CallExpr,
    name: str,
    *,
    expr_types: Optional[Tuple[TypingType[Any], ...]] = None
) -> Optional[Any]:
    """Return the argument passed as ``name`` in ``callexpr``, if suitable.

    The argument is located by the position of ``name`` within
    ``callexpr.arg_names``.  It is returned only when it is an instance of
    one of ``expr_types``; when ``expr_types`` is omitted, ``NameExpr`` and
    ``CallExpr`` are the accepted types.  ``None`` is returned when the
    name is absent or the argument has another type.
    """
    try:
        position = callexpr.arg_names.index(name)
    except ValueError:
        # no argument was passed under this name
        return None

    accepted = (NameExpr, CallExpr) if expr_types is None else expr_types
    candidate = callexpr.args[position]
    return candidate if isinstance(candidate, accepted) else None
-
-
def _flatten_typechecking(stmts: Iterable[Statement]) -> Iterator[Statement]:
    """Yield statements, inlining the body of ``if TYPE_CHECKING:`` blocks.

    An ``IfStmt`` whose first condition is a ``NameExpr`` with the full
    name ``typing.TYPE_CHECKING`` is replaced by the statements of its
    first body.  Every other statement is yielded unchanged.
    """
    for statement in stmts:
        if not isinstance(statement, IfStmt):
            yield statement
            continue

        condition = statement.expr[0]
        if (
            isinstance(condition, NameExpr)
            and condition.fullname == "typing.TYPE_CHECKING"
        ):
            yield from statement.body[0].body
        else:
            yield statement
-
-
def _unbound_to_instance(
    api: SemanticAnalyzerPluginInterface, typ: Type
) -> Type:
    """Convert an ``UnboundType`` into an ``Instance`` where possible.

    Anything that is not an ``UnboundType`` is returned untouched.  An
    unbound ``Optional`` becomes a ``UnionType`` of its converted
    arguments plus ``NoneType()``.  Otherwise the type's name is looked up
    through ``api.lookup_qualified``; when that yields a
    ``SymbolTableNode`` wrapping a ``TypeInfo``, an ``Instance`` of it is
    built with converted type arguments.  If the lookup yields anything
    else, the original unbound type is returned.

    """

    if not isinstance(typ, UnboundType):
        return typ

    # TODO: figure out a more robust way to check this.  The node is some
    # kind of _SpecialForm, and there is a typing.Optional that is a
    # _SpecialForm, but it is unclear how to get them to match up
    if typ.name == "Optional":
        # convert from "Optional?" to the more familiar
        # UnionType[..., NoneType()]
        members = [_unbound_to_instance(api, member) for member in typ.args]
        members.append(NoneType())
        return _unbound_to_instance(api, UnionType(members))

    symbol = api.lookup_qualified(typ.name, typ)

    resolved = isinstance(symbol, SymbolTableNode) and isinstance(
        symbol.node, TypeInfo
    )
    if not resolved:
        return typ

    # Arguments that are not UnboundType pass through the recursive call
    # unchanged, so no separate isinstance test is needed per argument.
    return Instance(
        symbol.node,
        [_unbound_to_instance(api, arg) for arg in typ.args],
    )
-
-
def _info_for_cls(
    cls: ClassDef, api: SemanticAnalyzerPluginInterface
) -> Optional[TypeInfo]:
    """Return the ``TypeInfo`` for ``cls``, looking it up if not attached.

    When ``cls.info`` is the ``CLASSDEF_NO_INFO`` placeholder, the class
    name is resolved through ``api.lookup_qualified`` instead.

    :return: the ``TypeInfo``, or ``None`` when the placeholder is present
     and the lookup finds no symbol.  The return annotation is ``Optional``
     for this reason; callers must handle ``None``.
    """
    if cls.info is CLASSDEF_NO_INFO:
        sym = api.lookup_qualified(cls.name, cls)
        if sym is None:
            return None
        # Narrows sym.node to TypeInfo for the type checker.
        assert sym and isinstance(sym.node, TypeInfo)
        return sym.node

    return cls.info
|