217 lines
5.9 KiB
Python
217 lines
5.9 KiB
Python
from typing import Any
|
|
from typing import cast
|
|
from typing import Iterable
|
|
from typing import Iterator
|
|
from typing import List
|
|
from typing import Optional
|
|
from typing import overload
|
|
from typing import Tuple
|
|
from typing import Type as TypingType
|
|
from typing import TypeVar
|
|
from typing import Union
|
|
|
|
from mypy.nodes import CallExpr
|
|
from mypy.nodes import ClassDef
|
|
from mypy.nodes import CLASSDEF_NO_INFO
|
|
from mypy.nodes import Context
|
|
from mypy.nodes import IfStmt
|
|
from mypy.nodes import JsonDict
|
|
from mypy.nodes import NameExpr
|
|
from mypy.nodes import Statement
|
|
from mypy.nodes import SymbolTableNode
|
|
from mypy.nodes import TypeInfo
|
|
from mypy.plugin import ClassDefContext
|
|
from mypy.plugin import DynamicClassDefContext
|
|
from mypy.plugin import SemanticAnalyzerPluginInterface
|
|
from mypy.plugins.common import deserialize_and_fixup_type
|
|
from mypy.types import Instance
|
|
from mypy.types import NoneType
|
|
from mypy.types import ProperType
|
|
from mypy.types import Type
|
|
from mypy.types import UnboundType
|
|
from mypy.types import UnionType
|
|
|
|
|
|
# Type variable used by the typed overload of _get_callexpr_kwarg: it ties the
# classes passed as ``expr_types`` to the type of the returned expression.
# Bounded to the two expression node kinds CallExpr and NameExpr.
_TArgType = TypeVar("_TArgType", bound=Union[CallExpr, NameExpr])
|
|
|
|
|
|
class DeclClassApplied:
    """State recorded for a class the plugin has processed.

    Carries two boolean flags (``is_mapped``, ``has_table``), a list of
    ``(attribute name, type)`` pairs and a list of MRO entries.  Instances
    can be converted to and from a plain dictionary with :meth:`serialize`
    and :meth:`deserialize`.
    """

    def __init__(
        self,
        is_mapped: bool,
        has_table: bool,
        mapped_attr_names: Iterable[Tuple[str, ProperType]],
        mapped_mro: Iterable[Instance],
    ):
        self.is_mapped = is_mapped
        self.has_table = has_table
        # Both iterables are copied into lists, so the instance owns its
        # own sequences independent of whatever the caller passed in.
        self.mapped_attr_names = list(mapped_attr_names)
        self.mapped_mro = list(mapped_mro)

    def serialize(self) -> JsonDict:
        """Return a dictionary form of this object.

        Each contained type is replaced by the result of its own
        ``serialize()`` method; the flags are stored as-is.
        """
        attr_entries = [
            (attr_name, attr_type.serialize())
            for attr_name, attr_type in self.mapped_attr_names
        ]
        mro_entries = [entry.serialize() for entry in self.mapped_mro]

        return {
            "is_mapped": self.is_mapped,
            "has_table": self.has_table,
            "mapped_attr_names": attr_entries,
            "mapped_mro": mro_entries,
        }

    @classmethod
    def deserialize(
        cls, data: JsonDict, api: SemanticAnalyzerPluginInterface
    ) -> "DeclClassApplied":
        """Rebuild an instance from the dictionary made by :meth:`serialize`.

        :param data: dictionary with the keys ``is_mapped``, ``has_table``,
         ``mapped_attr_names`` and ``mapped_mro``.
        :param api: semantic analyzer interface, handed to
         ``deserialize_and_fixup_type`` for every stored type.
        """
        is_mapped = data["is_mapped"]
        has_table = data["has_table"]

        # deserialize_and_fixup_type has a broader return type than the
        # constructor parameters declare; the casts narrow it for the type
        # checker only and perform no runtime conversion.
        attr_names = cast(
            List[Tuple[str, ProperType]],
            [
                (attr_name, deserialize_and_fixup_type(raw_type, api))
                for attr_name, raw_type in data["mapped_attr_names"]
            ],
        )
        mro = cast(
            List[Instance],
            [
                deserialize_and_fixup_type(raw_type, api)
                for raw_type in data["mapped_mro"]
            ],
        )

        return DeclClassApplied(
            is_mapped=is_mapped,
            has_table=has_table,
            mapped_attr_names=attr_names,
            mapped_mro=mro,
        )
|
|
|
|
|
|
def fail(api: SemanticAnalyzerPluginInterface, msg: str, ctx: Context) -> None:
    """Report an error through ``api.fail`` with the plugin prefix added.

    :param api: object whose ``fail(msg, ctx)`` method receives the report.
    :param msg: the message text; it is prefixed before being passed on.
    :param ctx: passed through unchanged as the second argument.
    """
    prefixed = "[SQLAlchemy Mypy plugin] %s" % msg
    return api.fail(prefixed, ctx)
|
|
|
|
|
|
def add_global(
    ctx: Union[ClassDefContext, DynamicClassDefContext],
    module: str,
    symbol_name: str,
    asname: str,
) -> None:
    """Expose a symbol from another module in the current module's names.

    The entry ``symbol_name`` of ``ctx.api.modules[module].names`` is stored
    under ``asname`` in the names table of the module identified by
    ``ctx.api.cur_mod_id``.  If ``asname`` is already present there, nothing
    is looked up or changed.
    """
    current_names = ctx.api.modules[ctx.api.cur_mod_id].names

    if asname in current_names:
        return

    source_names = ctx.api.modules[module].names
    found: SymbolTableNode = source_names[symbol_name]
    current_names[asname] = found
|
|
|
|
|
|
@overload
def _get_callexpr_kwarg(
    callexpr: CallExpr, name: str, *, expr_types: None = ...
) -> Optional[Union[CallExpr, NameExpr]]:
    ...


@overload
def _get_callexpr_kwarg(
    callexpr: CallExpr,
    name: str,
    *,
    expr_types: Tuple[TypingType[_TArgType], ...]
) -> Optional[_TArgType]:
    ...


def _get_callexpr_kwarg(
    callexpr: CallExpr,
    name: str,
    *,
    expr_types: Optional[Tuple[TypingType[Any], ...]] = None
) -> Optional[Any]:
    """Return the argument passed under keyword ``name`` in a call.

    The position of ``name`` is found in ``callexpr.arg_names`` and the
    entry at the same position in ``callexpr.args`` is examined.

    :param callexpr: the call expression to inspect.
    :param name: keyword name to search for.
    :param expr_types: classes the argument must be an instance of; when
     omitted, ``NameExpr`` and ``CallExpr`` are accepted.
    :return: the argument expression, or ``None`` when the keyword is not
     present or its value is not an instance of an accepted class.
    """
    try:
        position = callexpr.arg_names.index(name)
    except ValueError:
        # keyword was not passed in this call
        return None

    value = callexpr.args[position]
    accepted = (NameExpr, CallExpr) if expr_types is None else expr_types
    return value if isinstance(value, accepted) else None
|
|
|
|
|
|
def _flatten_typechecking(stmts: Iterable[Statement]) -> Iterator[Statement]:
    """Iterate statements, unwrapping ``if TYPE_CHECKING:`` blocks.

    A statement is treated as such a block when it is an ``IfStmt`` whose
    first condition is a ``NameExpr`` with the full name
    ``typing.TYPE_CHECKING``.  For those, the statements of the first body
    are yielded in place of the ``if`` itself; only ``stmt.body[0]`` is
    read, so other branches are not yielded.  Every other statement is
    yielded unchanged.
    """
    for stmt in stmts:
        is_type_checking_block = (
            isinstance(stmt, IfStmt)
            and isinstance(stmt.expr[0], NameExpr)
            and stmt.expr[0].fullname == "typing.TYPE_CHECKING"
        )

        if not is_type_checking_block:
            yield stmt
            continue

        yield from stmt.body[0].body
|
|
|
|
|
|
def _unbound_to_instance(
    api: SemanticAnalyzerPluginInterface, typ: Type
) -> Type:
    """Convert an UnboundType into an Instance where that is possible.

    The UnboundType is what we seem to get as the ret_type from a FuncDef;
    the Instance/TypeInfo kind of structure produced here seems to work as
    the left-hand type of an AssignmentStatement.

    Types that are not UnboundType are returned as given, as are unbound
    types whose name does not resolve to a TypeInfo.

    """
    if not isinstance(typ, UnboundType):
        return typ

    # TODO: figure out a more robust way to check this.  The node is some
    # kind of _SpecialForm, there's a typing.Optional that's _SpecialForm,
    # but I can't figure out how to get them to match up
    if typ.name == "Optional":
        # convert from "Optional?" to the more familiar
        # UnionType[..., NoneType()]
        members = [_unbound_to_instance(api, inner) for inner in typ.args]
        members.append(NoneType())
        return _unbound_to_instance(api, UnionType(members))

    sym = api.lookup_qualified(typ.name, typ)

    resolves_to_class = isinstance(sym, SymbolTableNode) and isinstance(
        sym.node, TypeInfo
    )
    if not resolves_to_class:
        return typ

    # The recursive call hands back anything that is not an UnboundType
    # unchanged, so every type argument can be passed through it.
    return Instance(
        sym.node,
        [_unbound_to_instance(api, arg) for arg in typ.args],
    )
|
|
|
|
|
|
def _info_for_cls(
    cls: ClassDef, api: SemanticAnalyzerPluginInterface
) -> Optional[TypeInfo]:
    """Return the TypeInfo belonging to a class definition.

    ``cls.info`` is returned directly unless it is the ``CLASSDEF_NO_INFO``
    placeholder; in that case the class name is resolved through
    ``api.lookup_qualified`` and the node of the resulting symbol is used.

    :param cls: the class definition being inspected.
    :param api: semantic analyzer interface used for the fallback lookup.
    :return: the TypeInfo, or ``None`` when the placeholder is present and
     the lookup does not find a symbol.  Callers must handle ``None``.
    """
    if cls.info is not CLASSDEF_NO_INFO:
        return cls.info

    sym = api.lookup_qualified(cls.name, cls)
    if sym is None:
        return None

    # A symbol found under the class's own name is expected to refer to a
    # class; anything else indicates an internal inconsistency.
    assert isinstance(sym.node, TypeInfo)
    return sym.node
|