You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

1133 lines
40KB

  1. # -*- coding: utf-8 -*-
  2. from __future__ import absolute_import
  3. import functools
  4. import os
  5. import sys
  6. import time
  7. import warnings
  8. from math import ceil
  9. from operator import itemgetter
  10. from threading import Lock
  11. import sqlalchemy
  12. from flask import _app_ctx_stack, abort, current_app, request
  13. from flask.signals import Namespace
  14. from sqlalchemy import event, inspect, orm
  15. from sqlalchemy.engine.url import make_url
  16. from sqlalchemy.orm.exc import UnmappedClassError
  17. from sqlalchemy.orm.session import Session as SessionBase
  18. from ._compat import itervalues, string_types, xrange
  19. from .model import DefaultMeta
  20. from .model import Model
  21. from . import utils
  22. try:
  23. from sqlalchemy.orm import declarative_base
  24. from sqlalchemy.orm import DeclarativeMeta
  25. except ImportError:
  26. # SQLAlchemy <= 1.3
  27. from sqlalchemy.ext.declarative import declarative_base
  28. from sqlalchemy.ext.declarative import DeclarativeMeta
  29. # Scope the session to the current greenlet if greenlet is available,
  30. # otherwise fall back to the current thread.
  31. try:
  32. from greenlet import getcurrent as _ident_func
  33. except ImportError:
  34. try:
  35. from threading import get_ident as _ident_func
  36. except ImportError:
  37. # Python 2.7
  38. from thread import get_ident as _ident_func
  39. __version__ = "2.5.1"
  40. # the best timer function for the platform
  41. if sys.platform == 'win32':
  42. if sys.version_info >= (3, 3):
  43. _timer = time.perf_counter
  44. else:
  45. _timer = time.clock
  46. else:
  47. _timer = time.time
  48. _signals = Namespace()
  49. models_committed = _signals.signal('models-committed')
  50. before_models_committed = _signals.signal('before-models-committed')
  51. def _sa_url_set(url, **kwargs):
  52. try:
  53. url = url.set(**kwargs)
  54. except AttributeError:
  55. # SQLAlchemy <= 1.3
  56. for key, value in kwargs.items():
  57. setattr(url, key, value)
  58. return url
  59. def _sa_url_query_setdefault(url, **kwargs):
  60. query = dict(url.query)
  61. for key, value in kwargs.items():
  62. query.setdefault(key, value)
  63. return _sa_url_set(url, query=query)
  64. def _make_table(db):
  65. def _make_table(*args, **kwargs):
  66. if len(args) > 1 and isinstance(args[1], db.Column):
  67. args = (args[0], db.metadata) + args[1:]
  68. info = kwargs.pop('info', None) or {}
  69. info.setdefault('bind_key', None)
  70. kwargs['info'] = info
  71. return sqlalchemy.Table(*args, **kwargs)
  72. return _make_table
  73. def _set_default_query_class(d, cls):
  74. if 'query_class' not in d:
  75. d['query_class'] = cls
  76. def _wrap_with_default_query_class(fn, cls):
  77. @functools.wraps(fn)
  78. def newfn(*args, **kwargs):
  79. _set_default_query_class(kwargs, cls)
  80. if "backref" in kwargs:
  81. backref = kwargs['backref']
  82. if isinstance(backref, string_types):
  83. backref = (backref, {})
  84. _set_default_query_class(backref[1], cls)
  85. return fn(*args, **kwargs)
  86. return newfn
  87. def _include_sqlalchemy(obj, cls):
  88. for module in sqlalchemy, sqlalchemy.orm:
  89. for key in module.__all__:
  90. if not hasattr(obj, key):
  91. setattr(obj, key, getattr(module, key))
  92. # Note: obj.Table does not attempt to be a SQLAlchemy Table class.
  93. obj.Table = _make_table(obj)
  94. obj.relationship = _wrap_with_default_query_class(obj.relationship, cls)
  95. obj.relation = _wrap_with_default_query_class(obj.relation, cls)
  96. obj.dynamic_loader = _wrap_with_default_query_class(obj.dynamic_loader, cls)
  97. obj.event = event
  98. class _DebugQueryTuple(tuple):
  99. statement = property(itemgetter(0))
  100. parameters = property(itemgetter(1))
  101. start_time = property(itemgetter(2))
  102. end_time = property(itemgetter(3))
  103. context = property(itemgetter(4))
  104. @property
  105. def duration(self):
  106. return self.end_time - self.start_time
  107. def __repr__(self):
  108. return '<query statement="%s" parameters=%r duration=%.03f>' % (
  109. self.statement,
  110. self.parameters,
  111. self.duration
  112. )
  113. def _calling_context(app_path):
  114. frm = sys._getframe(1)
  115. while frm.f_back is not None:
  116. name = frm.f_globals.get('__name__')
  117. if name and (name == app_path or name.startswith(app_path + '.')):
  118. funcname = frm.f_code.co_name
  119. return '%s:%s (%s)' % (
  120. frm.f_code.co_filename,
  121. frm.f_lineno,
  122. funcname
  123. )
  124. frm = frm.f_back
  125. return '<unknown>'
  126. class SignallingSession(SessionBase):
  127. """The signalling session is the default session that Flask-SQLAlchemy
  128. uses. It extends the default session system with bind selection and
  129. modification tracking.
  130. If you want to use a different session you can override the
  131. :meth:`SQLAlchemy.create_session` function.
  132. .. versionadded:: 2.0
  133. .. versionadded:: 2.1
  134. The `binds` option was added, which allows a session to be joined
  135. to an external transaction.
  136. """
  137. def __init__(self, db, autocommit=False, autoflush=True, **options):
  138. #: The application that this session belongs to.
  139. self.app = app = db.get_app()
  140. track_modifications = app.config['SQLALCHEMY_TRACK_MODIFICATIONS']
  141. bind = options.pop('bind', None) or db.engine
  142. binds = options.pop('binds', db.get_binds(app))
  143. if track_modifications is None or track_modifications:
  144. _SessionSignalEvents.register(self)
  145. SessionBase.__init__(
  146. self, autocommit=autocommit, autoflush=autoflush,
  147. bind=bind, binds=binds, **options
  148. )
  149. def get_bind(self, mapper=None, clause=None):
  150. """Return the engine or connection for a given model or
  151. table, using the ``__bind_key__`` if it is set.
  152. """
  153. # mapper is None if someone tries to just get a connection
  154. if mapper is not None:
  155. try:
  156. # SA >= 1.3
  157. persist_selectable = mapper.persist_selectable
  158. except AttributeError:
  159. # SA < 1.3
  160. persist_selectable = mapper.mapped_table
  161. info = getattr(persist_selectable, 'info', {})
  162. bind_key = info.get('bind_key')
  163. if bind_key is not None:
  164. state = get_state(self.app)
  165. return state.db.get_engine(self.app, bind=bind_key)
  166. return SessionBase.get_bind(self, mapper, clause)
  167. class _SessionSignalEvents(object):
  168. @classmethod
  169. def register(cls, session):
  170. if not hasattr(session, '_model_changes'):
  171. session._model_changes = {}
  172. event.listen(session, 'before_flush', cls.record_ops)
  173. event.listen(session, 'before_commit', cls.record_ops)
  174. event.listen(session, 'before_commit', cls.before_commit)
  175. event.listen(session, 'after_commit', cls.after_commit)
  176. event.listen(session, 'after_rollback', cls.after_rollback)
  177. @classmethod
  178. def unregister(cls, session):
  179. if hasattr(session, '_model_changes'):
  180. del session._model_changes
  181. event.remove(session, 'before_flush', cls.record_ops)
  182. event.remove(session, 'before_commit', cls.record_ops)
  183. event.remove(session, 'before_commit', cls.before_commit)
  184. event.remove(session, 'after_commit', cls.after_commit)
  185. event.remove(session, 'after_rollback', cls.after_rollback)
  186. @staticmethod
  187. def record_ops(session, flush_context=None, instances=None):
  188. try:
  189. d = session._model_changes
  190. except AttributeError:
  191. return
  192. for targets, operation in ((session.new, 'insert'), (session.dirty, 'update'), (session.deleted, 'delete')):
  193. for target in targets:
  194. state = inspect(target)
  195. key = state.identity_key if state.has_identity else id(target)
  196. d[key] = (target, operation)
  197. @staticmethod
  198. def before_commit(session):
  199. try:
  200. d = session._model_changes
  201. except AttributeError:
  202. return
  203. if d:
  204. before_models_committed.send(session.app, changes=list(d.values()))
  205. @staticmethod
  206. def after_commit(session):
  207. try:
  208. d = session._model_changes
  209. except AttributeError:
  210. return
  211. if d:
  212. models_committed.send(session.app, changes=list(d.values()))
  213. d.clear()
  214. @staticmethod
  215. def after_rollback(session):
  216. try:
  217. d = session._model_changes
  218. except AttributeError:
  219. return
  220. d.clear()
  221. class _EngineDebuggingSignalEvents(object):
  222. """Sets up handlers for two events that let us track the execution time of
  223. queries."""
  224. def __init__(self, engine, import_name):
  225. self.engine = engine
  226. self.app_package = import_name
  227. def register(self):
  228. event.listen(
  229. self.engine, 'before_cursor_execute', self.before_cursor_execute
  230. )
  231. event.listen(
  232. self.engine, 'after_cursor_execute', self.after_cursor_execute
  233. )
  234. def before_cursor_execute(
  235. self, conn, cursor, statement, parameters, context, executemany
  236. ):
  237. if current_app:
  238. context._query_start_time = _timer()
  239. def after_cursor_execute(
  240. self, conn, cursor, statement, parameters, context, executemany
  241. ):
  242. if current_app:
  243. try:
  244. queries = _app_ctx_stack.top.sqlalchemy_queries
  245. except AttributeError:
  246. queries = _app_ctx_stack.top.sqlalchemy_queries = []
  247. queries.append(_DebugQueryTuple((
  248. statement, parameters, context._query_start_time, _timer(),
  249. _calling_context(self.app_package)
  250. )))
  251. def get_debug_queries():
  252. """In debug mode Flask-SQLAlchemy will log all the SQL queries sent
  253. to the database. This information is available until the end of request
  254. which makes it possible to easily ensure that the SQL generated is the
  255. one expected on errors or in unittesting. If you don't want to enable
  256. the DEBUG mode for your unittests you can also enable the query
  257. recording by setting the ``'SQLALCHEMY_RECORD_QUERIES'`` config variable
  258. to `True`. This is automatically enabled if Flask is in testing mode.
  259. The value returned will be a list of named tuples with the following
  260. attributes:
  261. `statement`
  262. The SQL statement issued
  263. `parameters`
  264. The parameters for the SQL statement
  265. `start_time` / `end_time`
  266. Time the query started / the results arrived. Please keep in mind
  267. that the timer function used depends on your platform. These
  268. values are only useful for sorting or comparing. They do not
  269. necessarily represent an absolute timestamp.
  270. `duration`
  271. Time the query took in seconds
  272. `context`
  273. A string giving a rough estimation of where in your application
  274. query was issued. The exact format is undefined so don't try
  275. to reconstruct filename or function name.
  276. """
  277. return getattr(_app_ctx_stack.top, 'sqlalchemy_queries', [])
  278. class Pagination(object):
  279. """Internal helper class returned by :meth:`BaseQuery.paginate`. You
  280. can also construct it from any other SQLAlchemy query object if you are
  281. working with other libraries. Additionally it is possible to pass `None`
  282. as query object in which case the :meth:`prev` and :meth:`next` will
  283. no longer work.
  284. """
  285. def __init__(self, query, page, per_page, total, items):
  286. #: the unlimited query object that was used to create this
  287. #: pagination object.
  288. self.query = query
  289. #: the current page number (1 indexed)
  290. self.page = page
  291. #: the number of items to be displayed on a page.
  292. self.per_page = per_page
  293. #: the total number of items matching the query
  294. self.total = total
  295. #: the items for the current page
  296. self.items = items
  297. @property
  298. def pages(self):
  299. """The total number of pages"""
  300. if self.per_page == 0:
  301. pages = 0
  302. else:
  303. pages = int(ceil(self.total / float(self.per_page)))
  304. return pages
  305. def prev(self, error_out=False):
  306. """Returns a :class:`Pagination` object for the previous page."""
  307. assert self.query is not None, 'a query object is required ' \
  308. 'for this method to work'
  309. return self.query.paginate(self.page - 1, self.per_page, error_out)
  310. @property
  311. def prev_num(self):
  312. """Number of the previous page."""
  313. if not self.has_prev:
  314. return None
  315. return self.page - 1
  316. @property
  317. def has_prev(self):
  318. """True if a previous page exists"""
  319. return self.page > 1
  320. def next(self, error_out=False):
  321. """Returns a :class:`Pagination` object for the next page."""
  322. assert self.query is not None, 'a query object is required ' \
  323. 'for this method to work'
  324. return self.query.paginate(self.page + 1, self.per_page, error_out)
  325. @property
  326. def has_next(self):
  327. """True if a next page exists."""
  328. return self.page < self.pages
  329. @property
  330. def next_num(self):
  331. """Number of the next page"""
  332. if not self.has_next:
  333. return None
  334. return self.page + 1
  335. def iter_pages(self, left_edge=2, left_current=2,
  336. right_current=5, right_edge=2):
  337. """Iterates over the page numbers in the pagination. The four
  338. parameters control the thresholds how many numbers should be produced
  339. from the sides. Skipped page numbers are represented as `None`.
  340. This is how you could render such a pagination in the templates:
  341. .. sourcecode:: html+jinja
  342. {% macro render_pagination(pagination, endpoint) %}
  343. <div class=pagination>
  344. {%- for page in pagination.iter_pages() %}
  345. {% if page %}
  346. {% if page != pagination.page %}
  347. <a href="{{ url_for(endpoint, page=page) }}">{{ page }}</a>
  348. {% else %}
  349. <strong>{{ page }}</strong>
  350. {% endif %}
  351. {% else %}
  352. <span class=ellipsis>…</span>
  353. {% endif %}
  354. {%- endfor %}
  355. </div>
  356. {% endmacro %}
  357. """
  358. last = 0
  359. for num in xrange(1, self.pages + 1):
  360. if num <= left_edge or \
  361. (num > self.page - left_current - 1 and
  362. num < self.page + right_current) or \
  363. num > self.pages - right_edge:
  364. if last + 1 != num:
  365. yield None
  366. yield num
  367. last = num
  368. class BaseQuery(orm.Query):
  369. """SQLAlchemy :class:`~sqlalchemy.orm.query.Query` subclass with convenience methods for querying in a web application.
  370. This is the default :attr:`~Model.query` object used for models, and exposed as :attr:`~SQLAlchemy.Query`.
  371. Override the query class for an individual model by subclassing this and setting :attr:`~Model.query_class`.
  372. """
  373. def get_or_404(self, ident, description=None):
  374. """Like :meth:`get` but aborts with 404 if not found instead of returning ``None``."""
  375. rv = self.get(ident)
  376. if rv is None:
  377. abort(404, description=description)
  378. return rv
  379. def first_or_404(self, description=None):
  380. """Like :meth:`first` but aborts with 404 if not found instead of returning ``None``."""
  381. rv = self.first()
  382. if rv is None:
  383. abort(404, description=description)
  384. return rv
  385. def paginate(self, page=None, per_page=None, error_out=True, max_per_page=None):
  386. """Returns ``per_page`` items from page ``page``.
  387. If ``page`` or ``per_page`` are ``None``, they will be retrieved from
  388. the request query. If ``max_per_page`` is specified, ``per_page`` will
  389. be limited to that value. If there is no request or they aren't in the
  390. query, they default to 1 and 20 respectively.
  391. When ``error_out`` is ``True`` (default), the following rules will
  392. cause a 404 response:
  393. * No items are found and ``page`` is not 1.
  394. * ``page`` is less than 1, or ``per_page`` is negative.
  395. * ``page`` or ``per_page`` are not ints.
  396. When ``error_out`` is ``False``, ``page`` and ``per_page`` default to
  397. 1 and 20 respectively.
  398. Returns a :class:`Pagination` object.
  399. """
  400. if request:
  401. if page is None:
  402. try:
  403. page = int(request.args.get('page', 1))
  404. except (TypeError, ValueError):
  405. if error_out:
  406. abort(404)
  407. page = 1
  408. if per_page is None:
  409. try:
  410. per_page = int(request.args.get('per_page', 20))
  411. except (TypeError, ValueError):
  412. if error_out:
  413. abort(404)
  414. per_page = 20
  415. else:
  416. if page is None:
  417. page = 1
  418. if per_page is None:
  419. per_page = 20
  420. if max_per_page is not None:
  421. per_page = min(per_page, max_per_page)
  422. if page < 1:
  423. if error_out:
  424. abort(404)
  425. else:
  426. page = 1
  427. if per_page < 0:
  428. if error_out:
  429. abort(404)
  430. else:
  431. per_page = 20
  432. items = self.limit(per_page).offset((page - 1) * per_page).all()
  433. if not items and page != 1 and error_out:
  434. abort(404)
  435. total = self.order_by(None).count()
  436. return Pagination(self, page, per_page, total, items)
  437. class _QueryProperty(object):
  438. def __init__(self, sa):
  439. self.sa = sa
  440. def __get__(self, obj, type):
  441. try:
  442. mapper = orm.class_mapper(type)
  443. if mapper:
  444. return type.query_class(mapper, session=self.sa.session())
  445. except UnmappedClassError:
  446. return None
  447. def _record_queries(app):
  448. if app.debug:
  449. return True
  450. rq = app.config['SQLALCHEMY_RECORD_QUERIES']
  451. if rq is not None:
  452. return rq
  453. return bool(app.config.get('TESTING'))
  454. class _EngineConnector(object):
  455. def __init__(self, sa, app, bind=None):
  456. self._sa = sa
  457. self._app = app
  458. self._engine = None
  459. self._connected_for = None
  460. self._bind = bind
  461. self._lock = Lock()
  462. def get_uri(self):
  463. if self._bind is None:
  464. return self._app.config['SQLALCHEMY_DATABASE_URI']
  465. binds = self._app.config.get('SQLALCHEMY_BINDS') or ()
  466. assert self._bind in binds, \
  467. 'Bind %r is not specified. Set it in the SQLALCHEMY_BINDS ' \
  468. 'configuration variable' % self._bind
  469. return binds[self._bind]
  470. def get_engine(self):
  471. with self._lock:
  472. uri = self.get_uri()
  473. echo = self._app.config['SQLALCHEMY_ECHO']
  474. if (uri, echo) == self._connected_for:
  475. return self._engine
  476. sa_url = make_url(uri)
  477. sa_url, options = self.get_options(sa_url, echo)
  478. self._engine = rv = self._sa.create_engine(sa_url, options)
  479. if _record_queries(self._app):
  480. _EngineDebuggingSignalEvents(self._engine,
  481. self._app.import_name).register()
  482. self._connected_for = (uri, echo)
  483. return rv
  484. def get_options(self, sa_url, echo):
  485. options = {}
  486. options = self._sa.apply_pool_defaults(self._app, options)
  487. sa_url, options = self._sa.apply_driver_hacks(self._app, sa_url, options)
  488. if echo:
  489. options['echo'] = echo
  490. # Give the config options set by a developer explicitly priority
  491. # over decisions FSA makes.
  492. options.update(self._app.config['SQLALCHEMY_ENGINE_OPTIONS'])
  493. # Give options set in SQLAlchemy.__init__() ultimate priority
  494. options.update(self._sa._engine_options)
  495. return sa_url, options
  496. def get_state(app):
  497. """Gets the state for the application"""
  498. assert 'sqlalchemy' in app.extensions, \
  499. 'The sqlalchemy extension was not registered to the current ' \
  500. 'application. Please make sure to call init_app() first.'
  501. return app.extensions['sqlalchemy']
  502. class _SQLAlchemyState(object):
  503. """Remembers configuration for the (db, app) tuple."""
  504. def __init__(self, db):
  505. self.db = db
  506. self.connectors = {}
  507. class SQLAlchemy(object):
  508. """This class is used to control the SQLAlchemy integration to one
  509. or more Flask applications. Depending on how you initialize the
  510. object it is usable right away or will attach as needed to a
  511. Flask application.
  512. There are two usage modes which work very similarly. One is binding
  513. the instance to a very specific Flask application::
  514. app = Flask(__name__)
  515. db = SQLAlchemy(app)
  516. The second possibility is to create the object once and configure the
  517. application later to support it::
  518. db = SQLAlchemy()
  519. def create_app():
  520. app = Flask(__name__)
  521. db.init_app(app)
  522. return app
  523. The difference between the two is that in the first case methods like
  524. :meth:`create_all` and :meth:`drop_all` will work all the time but in
  525. the second case a :meth:`flask.Flask.app_context` has to exist.
  526. By default Flask-SQLAlchemy will apply some backend-specific settings
  527. to improve your experience with them.
  528. As of SQLAlchemy 0.6 SQLAlchemy
  529. will probe the library for native unicode support. If it detects
  530. unicode it will let the library handle that, otherwise do that itself.
  531. Sometimes this detection can fail in which case you might want to set
  532. ``use_native_unicode`` (or the ``SQLALCHEMY_NATIVE_UNICODE`` configuration
  533. key) to ``False``. Note that the configuration key overrides the
  534. value you pass to the constructor. Direct support for ``use_native_unicode``
  535. and SQLALCHEMY_NATIVE_UNICODE are deprecated as of v2.4 and will be removed
  536. in v3.0. ``engine_options`` and ``SQLALCHEMY_ENGINE_OPTIONS`` may be used
  537. instead.
  538. This class also provides access to all the SQLAlchemy functions and classes
  539. from the :mod:`sqlalchemy` and :mod:`sqlalchemy.orm` modules. So you can
  540. declare models like this::
  541. class User(db.Model):
  542. username = db.Column(db.String(80), unique=True)
  543. pw_hash = db.Column(db.String(80))
  544. You can still use :mod:`sqlalchemy` and :mod:`sqlalchemy.orm` directly, but
  545. note that Flask-SQLAlchemy customizations are available only through an
  546. instance of this :class:`SQLAlchemy` class. Query classes default to
  547. :class:`BaseQuery` for `db.Query`, `db.Model.query_class`, and the default
  548. query_class for `db.relationship` and `db.backref`. If you use these
  549. interfaces through :mod:`sqlalchemy` and :mod:`sqlalchemy.orm` directly,
  550. the default query class will be that of :mod:`sqlalchemy`.
  551. .. admonition:: Check types carefully
  552. Don't perform type or `isinstance` checks against `db.Table`, which
  553. emulates `Table` behavior but is not a class. `db.Table` exposes the
  554. `Table` interface, but is a function which allows omission of metadata.
  555. The ``session_options`` parameter, if provided, is a dict of parameters
  556. to be passed to the session constructor. See :class:`~sqlalchemy.orm.session.Session`
  557. for the standard options.
  558. The ``engine_options`` parameter, if provided, is a dict of parameters
  559. to be passed to create engine. See :func:`~sqlalchemy.create_engine`
  560. for the standard options. The values given here will be merged with and
  561. override anything set in the ``'SQLALCHEMY_ENGINE_OPTIONS'`` config
  562. variable or othewise set by this library.
  563. .. versionadded:: 0.10
  564. The `session_options` parameter was added.
  565. .. versionadded:: 0.16
  566. `scopefunc` is now accepted on `session_options`. It allows specifying
  567. a custom function which will define the SQLAlchemy session's scoping.
  568. .. versionadded:: 2.1
  569. The `metadata` parameter was added. This allows for setting custom
  570. naming conventions among other, non-trivial things.
  571. The `query_class` parameter was added, to allow customisation
  572. of the query class, in place of the default of :class:`BaseQuery`.
  573. The `model_class` parameter was added, which allows a custom model
  574. class to be used in place of :class:`Model`.
  575. .. versionchanged:: 2.1
  576. Utilise the same query class across `session`, `Model.query` and `Query`.
  577. .. versionadded:: 2.4
  578. The `engine_options` parameter was added.
  579. .. versionchanged:: 2.4
  580. The `use_native_unicode` parameter was deprecated.
  581. .. versionchanged:: 2.4.3
  582. ``COMMIT_ON_TEARDOWN`` is deprecated and will be removed in
  583. version 3.1. Call ``db.session.commit()`` directly instead.
  584. """
  585. #: Default query class used by :attr:`Model.query` and other queries.
  586. #: Customize this by passing ``query_class`` to :func:`SQLAlchemy`.
  587. #: Defaults to :class:`BaseQuery`.
  588. Query = None
  589. def __init__(self, app=None, use_native_unicode=True, session_options=None,
  590. metadata=None, query_class=BaseQuery, model_class=Model,
  591. engine_options=None):
  592. self.use_native_unicode = use_native_unicode
  593. self.Query = query_class
  594. self.session = self.create_scoped_session(session_options)
  595. self.Model = self.make_declarative_base(model_class, metadata)
  596. self._engine_lock = Lock()
  597. self.app = app
  598. self._engine_options = engine_options or {}
  599. _include_sqlalchemy(self, query_class)
  600. if app is not None:
  601. self.init_app(app)
  602. @property
  603. def metadata(self):
  604. """The metadata associated with ``db.Model``."""
  605. return self.Model.metadata
  606. def create_scoped_session(self, options=None):
  607. """Create a :class:`~sqlalchemy.orm.scoping.scoped_session`
  608. on the factory from :meth:`create_session`.
  609. An extra key ``'scopefunc'`` can be set on the ``options`` dict to
  610. specify a custom scope function. If it's not provided, Flask's app
  611. context stack identity is used. This will ensure that sessions are
  612. created and removed with the request/response cycle, and should be fine
  613. in most cases.
  614. :param options: dict of keyword arguments passed to session class in
  615. ``create_session``
  616. """
  617. if options is None:
  618. options = {}
  619. scopefunc = options.pop('scopefunc', _ident_func)
  620. options.setdefault('query_cls', self.Query)
  621. return orm.scoped_session(
  622. self.create_session(options), scopefunc=scopefunc
  623. )
  624. def create_session(self, options):
  625. """Create the session factory used by :meth:`create_scoped_session`.
  626. The factory **must** return an object that SQLAlchemy recognizes as a session,
  627. or registering session events may raise an exception.
  628. Valid factories include a :class:`~sqlalchemy.orm.session.Session`
  629. class or a :class:`~sqlalchemy.orm.session.sessionmaker`.
  630. The default implementation creates a ``sessionmaker`` for :class:`SignallingSession`.
  631. :param options: dict of keyword arguments passed to session class
  632. """
  633. return orm.sessionmaker(class_=SignallingSession, db=self, **options)
  634. def make_declarative_base(self, model, metadata=None):
  635. """Creates the declarative base that all models will inherit from.
  636. :param model: base model class (or a tuple of base classes) to pass
  637. to :func:`~sqlalchemy.ext.declarative.declarative_base`. Or a class
  638. returned from ``declarative_base``, in which case a new base class
  639. is not created.
  640. :param metadata: :class:`~sqlalchemy.MetaData` instance to use, or
  641. none to use SQLAlchemy's default.
  642. .. versionchanged 2.3.0::
  643. ``model`` can be an existing declarative base in order to support
  644. complex customization such as changing the metaclass.
  645. """
  646. if not isinstance(model, DeclarativeMeta):
  647. model = declarative_base(
  648. cls=model,
  649. name='Model',
  650. metadata=metadata,
  651. metaclass=DefaultMeta
  652. )
  653. # if user passed in a declarative base and a metaclass for some reason,
  654. # make sure the base uses the metaclass
  655. if metadata is not None and model.metadata is not metadata:
  656. model.metadata = metadata
  657. if not getattr(model, 'query_class', None):
  658. model.query_class = self.Query
  659. model.query = _QueryProperty(self)
  660. return model
  661. def init_app(self, app):
  662. """This callback can be used to initialize an application for the
  663. use with this database setup. Never use a database in the context
  664. of an application not initialized that way or connections will
  665. leak.
  666. """
  667. if (
  668. 'SQLALCHEMY_DATABASE_URI' not in app.config and
  669. 'SQLALCHEMY_BINDS' not in app.config
  670. ):
  671. warnings.warn(
  672. 'Neither SQLALCHEMY_DATABASE_URI nor SQLALCHEMY_BINDS is set. '
  673. 'Defaulting SQLALCHEMY_DATABASE_URI to "sqlite:///:memory:".'
  674. )
  675. app.config.setdefault('SQLALCHEMY_DATABASE_URI', 'sqlite:///:memory:')
  676. app.config.setdefault('SQLALCHEMY_BINDS', None)
  677. app.config.setdefault('SQLALCHEMY_NATIVE_UNICODE', None)
  678. app.config.setdefault('SQLALCHEMY_ECHO', False)
  679. app.config.setdefault('SQLALCHEMY_RECORD_QUERIES', None)
  680. app.config.setdefault('SQLALCHEMY_POOL_SIZE', None)
  681. app.config.setdefault('SQLALCHEMY_POOL_TIMEOUT', None)
  682. app.config.setdefault('SQLALCHEMY_POOL_RECYCLE', None)
  683. app.config.setdefault('SQLALCHEMY_MAX_OVERFLOW', None)
  684. app.config.setdefault('SQLALCHEMY_COMMIT_ON_TEARDOWN', False)
  685. track_modifications = app.config.setdefault(
  686. 'SQLALCHEMY_TRACK_MODIFICATIONS', None
  687. )
  688. app.config.setdefault('SQLALCHEMY_ENGINE_OPTIONS', {})
  689. if track_modifications is None:
  690. warnings.warn(FSADeprecationWarning(
  691. 'SQLALCHEMY_TRACK_MODIFICATIONS adds significant overhead and '
  692. 'will be disabled by default in the future. Set it to True '
  693. 'or False to suppress this warning.'
  694. ))
  695. # Deprecation warnings for config keys that should be replaced by SQLALCHEMY_ENGINE_OPTIONS.
  696. utils.engine_config_warning(app.config, '3.0', 'SQLALCHEMY_POOL_SIZE', 'pool_size')
  697. utils.engine_config_warning(app.config, '3.0', 'SQLALCHEMY_POOL_TIMEOUT', 'pool_timeout')
  698. utils.engine_config_warning(app.config, '3.0', 'SQLALCHEMY_POOL_RECYCLE', 'pool_recycle')
  699. utils.engine_config_warning(app.config, '3.0', 'SQLALCHEMY_MAX_OVERFLOW', 'max_overflow')
  700. app.extensions['sqlalchemy'] = _SQLAlchemyState(self)
  701. @app.teardown_appcontext
  702. def shutdown_session(response_or_exc):
  703. if app.config['SQLALCHEMY_COMMIT_ON_TEARDOWN']:
  704. warnings.warn(
  705. "'COMMIT_ON_TEARDOWN' is deprecated and will be"
  706. " removed in version 3.1. Call"
  707. " 'db.session.commit()'` directly instead.",
  708. DeprecationWarning,
  709. )
  710. if response_or_exc is None:
  711. self.session.commit()
  712. self.session.remove()
  713. return response_or_exc
  714. def apply_pool_defaults(self, app, options):
  715. """
  716. .. versionchanged:: 2.5
  717. Returns the ``options`` dict, for consistency with
  718. :meth:`apply_driver_hacks`.
  719. """
  720. def _setdefault(optionkey, configkey):
  721. value = app.config[configkey]
  722. if value is not None:
  723. options[optionkey] = value
  724. _setdefault('pool_size', 'SQLALCHEMY_POOL_SIZE')
  725. _setdefault('pool_timeout', 'SQLALCHEMY_POOL_TIMEOUT')
  726. _setdefault('pool_recycle', 'SQLALCHEMY_POOL_RECYCLE')
  727. _setdefault('max_overflow', 'SQLALCHEMY_MAX_OVERFLOW')
  728. return options
  729. def apply_driver_hacks(self, app, sa_url, options):
  730. """This method is called before engine creation and used to inject
  731. driver specific hacks into the options. The `options` parameter is
  732. a dictionary of keyword arguments that will then be used to call
  733. the :func:`sqlalchemy.create_engine` function.
  734. The default implementation provides some saner defaults for things
  735. like pool sizes for MySQL and sqlite. Also it injects the setting of
  736. `SQLALCHEMY_NATIVE_UNICODE`.
  737. .. versionchanged:: 2.5
  738. Returns ``(sa_url, options)``. SQLAlchemy 1.4 made the URL
  739. immutable, so any changes to it must now be passed back up
  740. to the original caller.
  741. """
  742. if sa_url.drivername.startswith('mysql'):
  743. sa_url = _sa_url_query_setdefault(sa_url, charset="utf8")
  744. if sa_url.drivername != 'mysql+gaerdbms':
  745. options.setdefault('pool_size', 10)
  746. options.setdefault('pool_recycle', 7200)
  747. elif sa_url.drivername == 'sqlite':
  748. pool_size = options.get('pool_size')
  749. detected_in_memory = False
  750. if sa_url.database in (None, '', ':memory:'):
  751. detected_in_memory = True
  752. from sqlalchemy.pool import StaticPool
  753. options['poolclass'] = StaticPool
  754. if 'connect_args' not in options:
  755. options['connect_args'] = {}
  756. options['connect_args']['check_same_thread'] = False
  757. # we go to memory and the pool size was explicitly set
  758. # to 0 which is fail. Let the user know that
  759. if pool_size == 0:
  760. raise RuntimeError('SQLite in memory database with an '
  761. 'empty queue not possible due to data '
  762. 'loss.')
  763. # if pool size is None or explicitly set to 0 we assume the
  764. # user did not want a queue for this sqlite connection and
  765. # hook in the null pool.
  766. elif not pool_size:
  767. from sqlalchemy.pool import NullPool
  768. options['poolclass'] = NullPool
  769. # if it's not an in memory database we make the path absolute.
  770. if not detected_in_memory:
  771. sa_url = _sa_url_set(
  772. sa_url, database=os.path.join(app.root_path, sa_url.database)
  773. )
  774. unu = app.config['SQLALCHEMY_NATIVE_UNICODE']
  775. if unu is None:
  776. unu = self.use_native_unicode
  777. if not unu:
  778. options['use_native_unicode'] = False
  779. if app.config['SQLALCHEMY_NATIVE_UNICODE'] is not None:
  780. warnings.warn(
  781. "The 'SQLALCHEMY_NATIVE_UNICODE' config option is deprecated and will be removed in"
  782. " v3.0. Use 'SQLALCHEMY_ENGINE_OPTIONS' instead.",
  783. DeprecationWarning
  784. )
  785. if not self.use_native_unicode:
  786. warnings.warn(
  787. "'use_native_unicode' is deprecated and will be removed in v3.0."
  788. " Use the 'engine_options' parameter instead.",
  789. DeprecationWarning
  790. )
  791. return sa_url, options
  792. @property
  793. def engine(self):
  794. """Gives access to the engine. If the database configuration is bound
  795. to a specific application (initialized with an application) this will
  796. always return a database connection. If however the current application
  797. is used this might raise a :exc:`RuntimeError` if no application is
  798. active at the moment.
  799. """
  800. return self.get_engine()
  801. def make_connector(self, app=None, bind=None):
  802. """Creates the connector for a given state and bind."""
  803. return _EngineConnector(self, self.get_app(app), bind)
  804. def get_engine(self, app=None, bind=None):
  805. """Returns a specific engine."""
  806. app = self.get_app(app)
  807. state = get_state(app)
  808. with self._engine_lock:
  809. connector = state.connectors.get(bind)
  810. if connector is None:
  811. connector = self.make_connector(app, bind)
  812. state.connectors[bind] = connector
  813. return connector.get_engine()
  814. def create_engine(self, sa_url, engine_opts):
  815. """
  816. Override this method to have final say over how the SQLAlchemy engine
  817. is created.
  818. In most cases, you will want to use ``'SQLALCHEMY_ENGINE_OPTIONS'``
  819. config variable or set ``engine_options`` for :func:`SQLAlchemy`.
  820. """
  821. return sqlalchemy.create_engine(sa_url, **engine_opts)
  822. def get_app(self, reference_app=None):
  823. """Helper method that implements the logic to look up an
  824. application."""
  825. if reference_app is not None:
  826. return reference_app
  827. if current_app:
  828. return current_app._get_current_object()
  829. if self.app is not None:
  830. return self.app
  831. raise RuntimeError(
  832. 'No application found. Either work inside a view function or push'
  833. ' an application context. See'
  834. ' http://flask-sqlalchemy.pocoo.org/contexts/.'
  835. )
  836. def get_tables_for_bind(self, bind=None):
  837. """Returns a list of all tables relevant for a bind."""
  838. result = []
  839. for table in itervalues(self.Model.metadata.tables):
  840. if table.info.get('bind_key') == bind:
  841. result.append(table)
  842. return result
  843. def get_binds(self, app=None):
  844. """Returns a dictionary with a table->engine mapping.
  845. This is suitable for use of sessionmaker(binds=db.get_binds(app)).
  846. """
  847. app = self.get_app(app)
  848. binds = [None] + list(app.config.get('SQLALCHEMY_BINDS') or ())
  849. retval = {}
  850. for bind in binds:
  851. engine = self.get_engine(app, bind)
  852. tables = self.get_tables_for_bind(bind)
  853. retval.update(dict((table, engine) for table in tables))
  854. return retval
  855. def _execute_for_all_tables(self, app, bind, operation, skip_tables=False):
  856. app = self.get_app(app)
  857. if bind == '__all__':
  858. binds = [None] + list(app.config.get('SQLALCHEMY_BINDS') or ())
  859. elif isinstance(bind, string_types) or bind is None:
  860. binds = [bind]
  861. else:
  862. binds = bind
  863. for bind in binds:
  864. extra = {}
  865. if not skip_tables:
  866. tables = self.get_tables_for_bind(bind)
  867. extra['tables'] = tables
  868. op = getattr(self.Model.metadata, operation)
  869. op(bind=self.get_engine(app, bind), **extra)
  870. def create_all(self, bind='__all__', app=None):
  871. """Creates all tables.
  872. .. versionchanged:: 0.12
  873. Parameters were added
  874. """
  875. self._execute_for_all_tables(app, bind, 'create_all')
  876. def drop_all(self, bind='__all__', app=None):
  877. """Drops all tables.
  878. .. versionchanged:: 0.12
  879. Parameters were added
  880. """
  881. self._execute_for_all_tables(app, bind, 'drop_all')
  882. def reflect(self, bind='__all__', app=None):
  883. """Reflects tables from the database.
  884. .. versionchanged:: 0.12
  885. Parameters were added
  886. """
  887. self._execute_for_all_tables(app, bind, 'reflect', skip_tables=True)
  888. def __repr__(self):
  889. return '<%s engine=%r>' % (
  890. self.__class__.__name__,
  891. self.engine.url if self.app or current_app else None
  892. )
  893. class _BoundDeclarativeMeta(DefaultMeta):
  894. def __init__(cls, name, bases, d):
  895. warnings.warn(FSADeprecationWarning(
  896. '"_BoundDeclarativeMeta" has been renamed to "DefaultMeta". The'
  897. ' old name will be removed in 3.0.'
  898. ), stacklevel=3)
  899. super(_BoundDeclarativeMeta, cls).__init__(name, bases, d)
  900. class FSADeprecationWarning(DeprecationWarning):
  901. pass
  902. warnings.simplefilter('always', FSADeprecationWarning)