You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

313 lines
8.9KB

  1. # mysql/aiomysql.py
  2. # Copyright (C) 2005-2021 the SQLAlchemy authors and contributors <see AUTHORS
  3. # file>
  4. #
  5. # This module is part of SQLAlchemy and is released under
  6. # the MIT License: http://www.opensource.org/licenses/mit-license.php
r"""
.. dialect:: mysql+aiomysql
    :name: aiomysql
    :dbapi: aiomysql
    :connectstring: mysql+aiomysql://user:password@host:port/dbname[?key=value&key=value...]
    :url: https://github.com/aio-libs/aiomysql

The aiomysql dialect is SQLAlchemy's second Python asyncio dialect.

Using a special asyncio mediation layer, the aiomysql dialect is usable
as the backend for the :ref:`SQLAlchemy asyncio <asyncio_toplevel>`
extension package.

This dialect should normally be used only with the
:func:`_asyncio.create_async_engine` engine creation function::

    from sqlalchemy.ext.asyncio import create_async_engine
    engine = create_async_engine("mysql+aiomysql://user:pass@hostname/dbname")

Unicode
-------

Please see :ref:`mysql_unicode` for current recommendations on unicode
handling.

"""  # noqa
import collections

from .pymysql import MySQLDialect_pymysql
from ... import pool
from ... import util
from ...util.concurrency import asyncio
from ...util.concurrency import await_fallback
from ...util.concurrency import await_only
  32. class AsyncAdapt_aiomysql_cursor:
  33. server_side = False
  34. __slots__ = (
  35. "_adapt_connection",
  36. "_connection",
  37. "await_",
  38. "_cursor",
  39. "_rows",
  40. )
  41. def __init__(self, adapt_connection):
  42. self._adapt_connection = adapt_connection
  43. self._connection = adapt_connection._connection
  44. self.await_ = adapt_connection.await_
  45. cursor = self._connection.cursor()
  46. # see https://github.com/aio-libs/aiomysql/issues/543
  47. self._cursor = self.await_(cursor.__aenter__())
  48. self._rows = []
  49. @property
  50. def description(self):
  51. return self._cursor.description
  52. @property
  53. def rowcount(self):
  54. return self._cursor.rowcount
  55. @property
  56. def arraysize(self):
  57. return self._cursor.arraysize
  58. @arraysize.setter
  59. def arraysize(self, value):
  60. self._cursor.arraysize = value
  61. @property
  62. def lastrowid(self):
  63. return self._cursor.lastrowid
  64. def close(self):
  65. # note we aren't actually closing the cursor here,
  66. # we are just letting GC do it. to allow this to be async
  67. # we would need the Result to change how it does "Safe close cursor".
  68. # MySQL "cursors" don't actually have state to be "closed" besides
  69. # exhausting rows, which we already have done for sync cursor.
  70. # another option would be to emulate aiosqlite dialect and assign
  71. # cursor only if we are doing server side cursor operation.
  72. self._rows[:] = []
  73. def execute(self, operation, parameters=None):
  74. return self.await_(self._execute_async(operation, parameters))
  75. def executemany(self, operation, seq_of_parameters):
  76. return self.await_(
  77. self._executemany_async(operation, seq_of_parameters)
  78. )
  79. async def _execute_async(self, operation, parameters):
  80. async with self._adapt_connection._execute_mutex:
  81. if parameters is None:
  82. result = await self._cursor.execute(operation)
  83. else:
  84. result = await self._cursor.execute(operation, parameters)
  85. if not self.server_side:
  86. # aiomysql has a "fake" async result, so we have to pull it out
  87. # of that here since our default result is not async.
  88. # we could just as easily grab "_rows" here and be done with it
  89. # but this is safer.
  90. self._rows = list(await self._cursor.fetchall())
  91. return result
  92. async def _executemany_async(self, operation, seq_of_parameters):
  93. async with self._adapt_connection._execute_mutex:
  94. return await self._cursor.executemany(operation, seq_of_parameters)
  95. def setinputsizes(self, *inputsizes):
  96. pass
  97. def __iter__(self):
  98. while self._rows:
  99. yield self._rows.pop(0)
  100. def fetchone(self):
  101. if self._rows:
  102. return self._rows.pop(0)
  103. else:
  104. return None
  105. def fetchmany(self, size=None):
  106. if size is None:
  107. size = self.arraysize
  108. retval = self._rows[0:size]
  109. self._rows[:] = self._rows[size:]
  110. return retval
  111. def fetchall(self):
  112. retval = self._rows[:]
  113. self._rows[:] = []
  114. return retval
  115. class AsyncAdapt_aiomysql_ss_cursor(AsyncAdapt_aiomysql_cursor):
  116. __slots__ = ()
  117. server_side = True
  118. def __init__(self, adapt_connection):
  119. self._adapt_connection = adapt_connection
  120. self._connection = adapt_connection._connection
  121. self.await_ = adapt_connection.await_
  122. cursor = self._connection.cursor(
  123. adapt_connection.dbapi.aiomysql.SSCursor
  124. )
  125. self._cursor = self.await_(cursor.__aenter__())
  126. def close(self):
  127. if self._cursor is not None:
  128. self.await_(self._cursor.close())
  129. self._cursor = None
  130. def fetchone(self):
  131. return self.await_(self._cursor.fetchone())
  132. def fetchmany(self, size=None):
  133. return self.await_(self._cursor.fetchmany(size=size))
  134. def fetchall(self):
  135. return self.await_(self._cursor.fetchall())
  136. class AsyncAdapt_aiomysql_connection:
  137. await_ = staticmethod(await_only)
  138. __slots__ = ("dbapi", "_connection", "_execute_mutex")
  139. def __init__(self, dbapi, connection):
  140. self.dbapi = dbapi
  141. self._connection = connection
  142. self._execute_mutex = asyncio.Lock()
  143. def ping(self, reconnect):
  144. return self.await_(self._connection.ping(reconnect))
  145. def character_set_name(self):
  146. return self._connection.character_set_name()
  147. def autocommit(self, value):
  148. self.await_(self._connection.autocommit(value))
  149. def cursor(self, server_side=False):
  150. if server_side:
  151. return AsyncAdapt_aiomysql_ss_cursor(self)
  152. else:
  153. return AsyncAdapt_aiomysql_cursor(self)
  154. def rollback(self):
  155. self.await_(self._connection.rollback())
  156. def commit(self):
  157. self.await_(self._connection.commit())
  158. def close(self):
  159. # it's not awaitable.
  160. self._connection.close()
  161. class AsyncAdaptFallback_aiomysql_connection(AsyncAdapt_aiomysql_connection):
  162. __slots__ = ()
  163. await_ = staticmethod(await_fallback)
  164. class AsyncAdapt_aiomysql_dbapi:
  165. def __init__(self, aiomysql, pymysql):
  166. self.aiomysql = aiomysql
  167. self.pymysql = pymysql
  168. self.paramstyle = "format"
  169. self._init_dbapi_attributes()
  170. def _init_dbapi_attributes(self):
  171. for name in (
  172. "Warning",
  173. "Error",
  174. "InterfaceError",
  175. "DataError",
  176. "DatabaseError",
  177. "OperationalError",
  178. "InterfaceError",
  179. "IntegrityError",
  180. "ProgrammingError",
  181. "InternalError",
  182. "NotSupportedError",
  183. ):
  184. setattr(self, name, getattr(self.aiomysql, name))
  185. for name in (
  186. "NUMBER",
  187. "STRING",
  188. "DATETIME",
  189. "BINARY",
  190. "TIMESTAMP",
  191. "Binary",
  192. ):
  193. setattr(self, name, getattr(self.pymysql, name))
  194. def connect(self, *arg, **kw):
  195. async_fallback = kw.pop("async_fallback", False)
  196. if util.asbool(async_fallback):
  197. return AsyncAdaptFallback_aiomysql_connection(
  198. self,
  199. await_fallback(self.aiomysql.connect(*arg, **kw)),
  200. )
  201. else:
  202. return AsyncAdapt_aiomysql_connection(
  203. self,
  204. await_only(self.aiomysql.connect(*arg, **kw)),
  205. )
  206. class MySQLDialect_aiomysql(MySQLDialect_pymysql):
  207. driver = "aiomysql"
  208. supports_statement_cache = True
  209. supports_server_side_cursors = True
  210. _sscursor = AsyncAdapt_aiomysql_ss_cursor
  211. is_async = True
  212. @classmethod
  213. def dbapi(cls):
  214. return AsyncAdapt_aiomysql_dbapi(
  215. __import__("aiomysql"), __import__("pymysql")
  216. )
  217. @classmethod
  218. def get_pool_class(cls, url):
  219. async_fallback = url.query.get("async_fallback", False)
  220. if util.asbool(async_fallback):
  221. return pool.FallbackAsyncAdaptedQueuePool
  222. else:
  223. return pool.AsyncAdaptedQueuePool
  224. def create_connect_args(self, url):
  225. return super(MySQLDialect_aiomysql, self).create_connect_args(
  226. url, _translate_args=dict(username="user", database="db")
  227. )
  228. def is_disconnect(self, e, connection, cursor):
  229. if super(MySQLDialect_aiomysql, self).is_disconnect(
  230. e, connection, cursor
  231. ):
  232. return True
  233. else:
  234. str_e = str(e).lower()
  235. return "not connected" in str_e
  236. def _found_rows_client_flag(self):
  237. from pymysql.constants import CLIENT
  238. return CLIENT.FOUND_ROWS
# Module-level alias exporting the dialect class under the name ``dialect``.
# NOTE(review): presumably the attribute SQLAlchemy's dialect loader resolves
# for "mysql+aiomysql" URLs; confirm before renaming.
dialect = MySQLDialect_aiomysql