You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

587 lines
16KB

  1. # postgresql/pg8000.py
  2. # Copyright (C) 2005-2021 the SQLAlchemy authors and contributors <see AUTHORS
  3. # file>
  4. #
  5. # This module is part of SQLAlchemy and is released under
  6. # the MIT License: http://www.opensource.org/licenses/mit-license.php
  7. r"""
  8. .. dialect:: postgresql+pg8000
  9. :name: pg8000
  10. :dbapi: pg8000
  11. :connectstring: postgresql+pg8000://user:password@host:port/dbname[?key=value&key=value...]
  12. :url: https://pypi.org/project/pg8000/
  13. .. versionchanged:: 1.4 The pg8000 dialect has been updated for version
  14. 1.16.6 and higher, and is again part of SQLAlchemy's continuous integration
  15. with full feature support.
  16. .. _pg8000_unicode:
  17. Unicode
  18. -------
  19. pg8000 will encode / decode string values between it and the server using the
  20. PostgreSQL ``client_encoding`` parameter; by default this is the value in
  21. the ``postgresql.conf`` file, which often defaults to ``SQL_ASCII``.
  22. Typically, this can be changed to ``utf-8``, as a more useful default::
  23. #client_encoding = sql_ascii # actually, defaults to database
  24. # encoding
  25. client_encoding = utf8
  26. The ``client_encoding`` can be overridden for a session by executing the SQL:
  27. SET CLIENT_ENCODING TO 'utf8';
  28. SQLAlchemy will execute this SQL on all new connections based on the value
  29. passed to :func:`_sa.create_engine` using the ``client_encoding`` parameter::
  30. engine = create_engine(
  31. "postgresql+pg8000://user:pass@host/dbname", client_encoding='utf8')
  32. .. _pg8000_ssl:
  33. SSL Connections
  34. ---------------
  35. pg8000 accepts a Python ``SSLContext`` object which may be specified using the
  36. :paramref:`_sa.create_engine.connect_args` dictionary::
  37. import ssl
  38. ssl_context = ssl.create_default_context()
  39. engine = sa.create_engine(
  40. "postgresql+pg8000://scott:tiger@192.168.0.199/test",
  41. connect_args={"ssl_context": ssl_context},
  42. )
  43. If the server uses an automatically-generated certificate that is self-signed
  44. or does not match the host name (as seen from the client), it may also be
  45. necessary to disable hostname checking::
  46. import ssl
  47. ssl_context = ssl.create_default_context()
  48. ssl_context.check_hostname = False
  49. ssl_context.verify_mode = ssl.CERT_NONE
  50. engine = sa.create_engine(
  51. "postgresql+pg8000://scott:tiger@192.168.0.199/test",
  52. connect_args={"ssl_context": ssl_context},
  53. )
  54. .. _pg8000_isolation_level:
  55. pg8000 Transaction Isolation Level
  56. -------------------------------------
  57. The pg8000 dialect offers the same isolation level settings as that
  58. of the :ref:`psycopg2 <psycopg2_isolation_level>` dialect:
  59. * ``READ COMMITTED``
  60. * ``READ UNCOMMITTED``
  61. * ``REPEATABLE READ``
  62. * ``SERIALIZABLE``
  63. * ``AUTOCOMMIT``
  64. .. seealso::
  65. :ref:`postgresql_isolation_level`
  66. :ref:`psycopg2_isolation_level`
  67. """ # noqa
  68. import decimal
  69. import re
  70. from uuid import UUID as _python_UUID
  71. from .base import _DECIMAL_TYPES
  72. from .base import _FLOAT_TYPES
  73. from .base import _INT_TYPES
  74. from .base import ENUM
  75. from .base import INTERVAL
  76. from .base import PGCompiler
  77. from .base import PGDialect
  78. from .base import PGExecutionContext
  79. from .base import PGIdentifierPreparer
  80. from .base import UUID
  81. from .json import JSON
  82. from .json import JSONB
  83. from .json import JSONPathType
  84. from ... import exc
  85. from ... import processors
  86. from ... import types as sqltypes
  87. from ... import util
  88. from ...sql.elements import quoted_name
  89. class _PGNumeric(sqltypes.Numeric):
  90. def result_processor(self, dialect, coltype):
  91. if self.asdecimal:
  92. if coltype in _FLOAT_TYPES:
  93. return processors.to_decimal_processor_factory(
  94. decimal.Decimal, self._effective_decimal_return_scale
  95. )
  96. elif coltype in _DECIMAL_TYPES or coltype in _INT_TYPES:
  97. # pg8000 returns Decimal natively for 1700
  98. return None
  99. else:
  100. raise exc.InvalidRequestError(
  101. "Unknown PG numeric type: %d" % coltype
  102. )
  103. else:
  104. if coltype in _FLOAT_TYPES:
  105. # pg8000 returns float natively for 701
  106. return None
  107. elif coltype in _DECIMAL_TYPES or coltype in _INT_TYPES:
  108. return processors.to_float
  109. else:
  110. raise exc.InvalidRequestError(
  111. "Unknown PG numeric type: %d" % coltype
  112. )
  113. class _PGNumericNoBind(_PGNumeric):
  114. def bind_processor(self, dialect):
  115. return None
  116. class _PGJSON(JSON):
  117. def result_processor(self, dialect, coltype):
  118. return None
  119. def get_dbapi_type(self, dbapi):
  120. return dbapi.JSON
  121. class _PGJSONB(JSONB):
  122. def result_processor(self, dialect, coltype):
  123. return None
  124. def get_dbapi_type(self, dbapi):
  125. return dbapi.JSONB
  126. class _PGJSONIndexType(sqltypes.JSON.JSONIndexType):
  127. def get_dbapi_type(self, dbapi):
  128. raise NotImplementedError("should not be here")
  129. class _PGJSONIntIndexType(sqltypes.JSON.JSONIntIndexType):
  130. def get_dbapi_type(self, dbapi):
  131. return dbapi.INTEGER
  132. class _PGJSONStrIndexType(sqltypes.JSON.JSONStrIndexType):
  133. def get_dbapi_type(self, dbapi):
  134. return dbapi.STRING
  135. class _PGJSONPathType(JSONPathType):
  136. def get_dbapi_type(self, dbapi):
  137. return 1009
  138. class _PGUUID(UUID):
  139. def bind_processor(self, dialect):
  140. if not self.as_uuid:
  141. def process(value):
  142. if value is not None:
  143. value = _python_UUID(value)
  144. return value
  145. return process
  146. def result_processor(self, dialect, coltype):
  147. if not self.as_uuid:
  148. def process(value):
  149. if value is not None:
  150. value = str(value)
  151. return value
  152. return process
  153. class _PGEnum(ENUM):
  154. def get_dbapi_type(self, dbapi):
  155. return dbapi.UNKNOWN
  156. class _PGInterval(INTERVAL):
  157. def get_dbapi_type(self, dbapi):
  158. return dbapi.INTERVAL
  159. @classmethod
  160. def adapt_emulated_to_native(cls, interval, **kw):
  161. return _PGInterval(precision=interval.second_precision)
  162. class _PGTimeStamp(sqltypes.DateTime):
  163. def get_dbapi_type(self, dbapi):
  164. if self.timezone:
  165. # TIMESTAMPTZOID
  166. return 1184
  167. else:
  168. # TIMESTAMPOID
  169. return 1114
  170. class _PGTime(sqltypes.Time):
  171. def get_dbapi_type(self, dbapi):
  172. return dbapi.TIME
  173. class _PGInteger(sqltypes.Integer):
  174. def get_dbapi_type(self, dbapi):
  175. return dbapi.INTEGER
  176. class _PGSmallInteger(sqltypes.SmallInteger):
  177. def get_dbapi_type(self, dbapi):
  178. return dbapi.INTEGER
  179. class _PGNullType(sqltypes.NullType):
  180. def get_dbapi_type(self, dbapi):
  181. return dbapi.NULLTYPE
  182. class _PGBigInteger(sqltypes.BigInteger):
  183. def get_dbapi_type(self, dbapi):
  184. return dbapi.BIGINTEGER
  185. class _PGBoolean(sqltypes.Boolean):
  186. def get_dbapi_type(self, dbapi):
  187. return dbapi.BOOLEAN
# Module-level id source for server-side cursor names: each call yields an
# integer that create_server_side_cursor hex-encodes into the cursor ident
# (presumably an incrementing counter, per util.counter -- verify).
_server_side_id = util.counter()
  189. class PGExecutionContext_pg8000(PGExecutionContext):
  190. def create_server_side_cursor(self):
  191. ident = "c_%s_%s" % (hex(id(self))[2:], hex(_server_side_id())[2:])
  192. return ServerSideCursor(self._dbapi_connection.cursor(), ident)
  193. def pre_exec(self):
  194. if not self.compiled:
  195. return
  196. class ServerSideCursor:
  197. server_side = True
  198. def __init__(self, cursor, ident):
  199. self.ident = ident
  200. self.cursor = cursor
  201. @property
  202. def connection(self):
  203. return self.cursor.connection
  204. @property
  205. def rowcount(self):
  206. return self.cursor.rowcount
  207. @property
  208. def description(self):
  209. return self.cursor.description
  210. def execute(self, operation, args=(), stream=None):
  211. op = "DECLARE " + self.ident + " NO SCROLL CURSOR FOR " + operation
  212. self.cursor.execute(op, args, stream=stream)
  213. return self
  214. def executemany(self, operation, param_sets):
  215. self.cursor.executemany(operation, param_sets)
  216. return self
  217. def fetchone(self):
  218. self.cursor.execute("FETCH FORWARD 1 FROM " + self.ident)
  219. return self.cursor.fetchone()
  220. def fetchmany(self, num=None):
  221. if num is None:
  222. return self.fetchall()
  223. else:
  224. self.cursor.execute(
  225. "FETCH FORWARD " + str(int(num)) + " FROM " + self.ident
  226. )
  227. return self.cursor.fetchall()
  228. def fetchall(self):
  229. self.cursor.execute("FETCH FORWARD ALL FROM " + self.ident)
  230. return self.cursor.fetchall()
  231. def close(self):
  232. self.cursor.execute("CLOSE " + self.ident)
  233. self.cursor.close()
  234. def setinputsizes(self, *sizes):
  235. self.cursor.setinputsizes(*sizes)
  236. def setoutputsize(self, size, column=None):
  237. pass
  238. class PGCompiler_pg8000(PGCompiler):
  239. def visit_mod_binary(self, binary, operator, **kw):
  240. return (
  241. self.process(binary.left, **kw)
  242. + " %% "
  243. + self.process(binary.right, **kw)
  244. )
  245. class PGIdentifierPreparer_pg8000(PGIdentifierPreparer):
  246. def __init__(self, *args, **kwargs):
  247. PGIdentifierPreparer.__init__(self, *args, **kwargs)
  248. self._double_percents = False
  249. class PGDialect_pg8000(PGDialect):
  250. driver = "pg8000"
  251. supports_statement_cache = True
  252. supports_unicode_statements = True
  253. supports_unicode_binds = True
  254. default_paramstyle = "format"
  255. supports_sane_multi_rowcount = True
  256. execution_ctx_cls = PGExecutionContext_pg8000
  257. statement_compiler = PGCompiler_pg8000
  258. preparer = PGIdentifierPreparer_pg8000
  259. supports_server_side_cursors = True
  260. use_setinputsizes = True
  261. # reversed as of pg8000 1.16.6. 1.16.5 and lower
  262. # are no longer compatible
  263. description_encoding = None
  264. # description_encoding = "use_encoding"
  265. colspecs = util.update_copy(
  266. PGDialect.colspecs,
  267. {
  268. sqltypes.Numeric: _PGNumericNoBind,
  269. sqltypes.Float: _PGNumeric,
  270. sqltypes.JSON: _PGJSON,
  271. sqltypes.Boolean: _PGBoolean,
  272. sqltypes.NullType: _PGNullType,
  273. JSONB: _PGJSONB,
  274. sqltypes.JSON.JSONPathType: _PGJSONPathType,
  275. sqltypes.JSON.JSONIndexType: _PGJSONIndexType,
  276. sqltypes.JSON.JSONIntIndexType: _PGJSONIntIndexType,
  277. sqltypes.JSON.JSONStrIndexType: _PGJSONStrIndexType,
  278. UUID: _PGUUID,
  279. sqltypes.Interval: _PGInterval,
  280. INTERVAL: _PGInterval,
  281. sqltypes.DateTime: _PGTimeStamp,
  282. sqltypes.Time: _PGTime,
  283. sqltypes.Integer: _PGInteger,
  284. sqltypes.SmallInteger: _PGSmallInteger,
  285. sqltypes.BigInteger: _PGBigInteger,
  286. sqltypes.Enum: _PGEnum,
  287. },
  288. )
  289. def __init__(self, client_encoding=None, **kwargs):
  290. PGDialect.__init__(self, **kwargs)
  291. self.client_encoding = client_encoding
  292. if self._dbapi_version < (1, 16, 6):
  293. raise NotImplementedError("pg8000 1.16.6 or greater is required")
  294. @util.memoized_property
  295. def _dbapi_version(self):
  296. if self.dbapi and hasattr(self.dbapi, "__version__"):
  297. return tuple(
  298. [
  299. int(x)
  300. for x in re.findall(
  301. r"(\d+)(?:[-\.]?|$)", self.dbapi.__version__
  302. )
  303. ]
  304. )
  305. else:
  306. return (99, 99, 99)
  307. @classmethod
  308. def dbapi(cls):
  309. return __import__("pg8000")
  310. def create_connect_args(self, url):
  311. opts = url.translate_connect_args(username="user")
  312. if "port" in opts:
  313. opts["port"] = int(opts["port"])
  314. opts.update(url.query)
  315. return ([], opts)
  316. def is_disconnect(self, e, connection, cursor):
  317. if isinstance(e, self.dbapi.InterfaceError) and "network error" in str(
  318. e
  319. ):
  320. # new as of pg8000 1.19.0 for broken connections
  321. return True
  322. # connection was closed normally
  323. return "connection is closed" in str(e)
  324. def set_isolation_level(self, connection, level):
  325. level = level.replace("_", " ")
  326. # adjust for ConnectionFairy possibly being present
  327. if hasattr(connection, "connection"):
  328. connection = connection.connection
  329. if level == "AUTOCOMMIT":
  330. connection.autocommit = True
  331. elif level in self._isolation_lookup:
  332. connection.autocommit = False
  333. cursor = connection.cursor()
  334. cursor.execute(
  335. "SET SESSION CHARACTERISTICS AS TRANSACTION "
  336. "ISOLATION LEVEL %s" % level
  337. )
  338. cursor.execute("COMMIT")
  339. cursor.close()
  340. else:
  341. raise exc.ArgumentError(
  342. "Invalid value '%s' for isolation_level. "
  343. "Valid isolation levels for %s are %s or AUTOCOMMIT"
  344. % (level, self.name, ", ".join(self._isolation_lookup))
  345. )
  346. def set_readonly(self, connection, value):
  347. cursor = connection.cursor()
  348. try:
  349. cursor.execute(
  350. "SET SESSION CHARACTERISTICS AS TRANSACTION %s"
  351. % ("READ ONLY" if value else "READ WRITE")
  352. )
  353. cursor.execute("COMMIT")
  354. finally:
  355. cursor.close()
  356. def get_readonly(self, connection):
  357. cursor = connection.cursor()
  358. try:
  359. cursor.execute("show transaction_read_only")
  360. val = cursor.fetchone()[0]
  361. finally:
  362. cursor.close()
  363. return val == "on"
  364. def set_deferrable(self, connection, value):
  365. cursor = connection.cursor()
  366. try:
  367. cursor.execute(
  368. "SET SESSION CHARACTERISTICS AS TRANSACTION %s"
  369. % ("DEFERRABLE" if value else "NOT DEFERRABLE")
  370. )
  371. cursor.execute("COMMIT")
  372. finally:
  373. cursor.close()
  374. def get_deferrable(self, connection):
  375. cursor = connection.cursor()
  376. try:
  377. cursor.execute("show transaction_deferrable")
  378. val = cursor.fetchone()[0]
  379. finally:
  380. cursor.close()
  381. return val == "on"
  382. def set_client_encoding(self, connection, client_encoding):
  383. # adjust for ConnectionFairy possibly being present
  384. if hasattr(connection, "connection"):
  385. connection = connection.connection
  386. cursor = connection.cursor()
  387. cursor.execute("SET CLIENT_ENCODING TO '" + client_encoding + "'")
  388. cursor.execute("COMMIT")
  389. cursor.close()
  390. def do_set_input_sizes(self, cursor, list_of_tuples, context):
  391. if self.positional:
  392. cursor.setinputsizes(
  393. *[dbtype for key, dbtype, sqltype in list_of_tuples]
  394. )
  395. else:
  396. cursor.setinputsizes(
  397. **{
  398. key: dbtype
  399. for key, dbtype, sqltype in list_of_tuples
  400. if dbtype
  401. }
  402. )
  403. def do_begin_twophase(self, connection, xid):
  404. connection.connection.tpc_begin((0, xid, ""))
  405. def do_prepare_twophase(self, connection, xid):
  406. connection.connection.tpc_prepare()
  407. def do_rollback_twophase(
  408. self, connection, xid, is_prepared=True, recover=False
  409. ):
  410. connection.connection.tpc_rollback((0, xid, ""))
  411. def do_commit_twophase(
  412. self, connection, xid, is_prepared=True, recover=False
  413. ):
  414. connection.connection.tpc_commit((0, xid, ""))
  415. def do_recover_twophase(self, connection):
  416. return [row[1] for row in connection.connection.tpc_recover()]
  417. def on_connect(self):
  418. fns = []
  419. def on_connect(conn):
  420. conn.py_types[quoted_name] = conn.py_types[util.text_type]
  421. fns.append(on_connect)
  422. if self.client_encoding is not None:
  423. def on_connect(conn):
  424. self.set_client_encoding(conn, self.client_encoding)
  425. fns.append(on_connect)
  426. if self.isolation_level is not None:
  427. def on_connect(conn):
  428. self.set_isolation_level(conn, self.isolation_level)
  429. fns.append(on_connect)
  430. if self._json_deserializer:
  431. def on_connect(conn):
  432. # json
  433. conn.register_in_adapter(114, self._json_deserializer)
  434. # jsonb
  435. conn.register_in_adapter(3802, self._json_deserializer)
  436. fns.append(on_connect)
  437. if len(fns) > 0:
  438. def on_connect(conn):
  439. for fn in fns:
  440. fn(conn)
  441. return on_connect
  442. else:
  443. return None
# Module-level alias for the dialect class; presumably the name SQLAlchemy's
# dialect loader looks up in this module -- verify against the registry.
dialect = PGDialect_pg8000