diff --git a/pyhive/sqlalchemy_presto.py b/pyhive/sqlalchemy_presto.py
index a199ebe1..891fa0ab 100644
--- a/pyhive/sqlalchemy_presto.py
+++ b/pyhive/sqlalchemy_presto.py
@@ -8,15 +8,28 @@
 from __future__ import absolute_import
 from __future__ import unicode_literals
 
+from packaging import version
 import re
 from sqlalchemy import exc
 from sqlalchemy import types
 from sqlalchemy import util
+import sys
+
 # TODO shouldn't use mysql type
 from sqlalchemy.databases import mysql
 from sqlalchemy.engine import default
 from sqlalchemy.sql import compiler
 from sqlalchemy.sql.compiler import SQLCompiler
+try:
+    from sqlalchemy.sql.expression import (
+        Alias,
+        CTE,
+        Subquery,
+    )
+except ImportError:
+    from sqlalchemy.sql.expression import Alias
+    CTE = type(None)
+    Subquery = type(None)
 
 from pyhive import presto
 from pyhive.common import UniversalSet
@@ -46,6 +59,49 @@ class PrestoCompiler(SQLCompiler):
     def visit_char_length_func(self, fn, **kw):
         return 'length{}'.format(self.function_argspec(fn, **kw))
 
+    def visit_column(self, column, add_to_result_map=None, include_table=True, **kwargs):
+        """Render a column, qualified with the table's Presto catalog if one is set."""
+        sql = super(PrestoCompiler, self).visit_column(
+            column, add_to_result_map, include_table, **kwargs
+        )
+        if not include_table:
+            # A bare column name (e.g. in UPDATE ... SET) has no table path to qualify.
+            return sql
+        return self.__add_catalog(sql, column.table)
+
+    def visit_table(self, table, asfrom=False, iscrud=False, ashint=False,
+                    fromhints=None, use_schema=True, **kwargs):
+        """Render a table, qualified with its Presto catalog if one is set."""
+        sql = super(PrestoCompiler, self).visit_table(
+            table, asfrom, iscrud, ashint, fromhints, use_schema, **kwargs
+        )
+        if not use_schema:
+            # An unqualified table name must not gain a catalog prefix either.
+            return sql
+        return self.__add_catalog(sql, table)
+
+    def __add_catalog(self, sql, table):
+        """Prefix ``sql`` with the quoted ``presto_catalog`` option of ``table``, if set."""
+        # Nothing to qualify: no table, or the parent compiler rendered an empty string.
+        if table is None or not sql:
+            return sql
+
+        if isinstance(table, (Alias, CTE, Subquery)):
+            return sql
+
+        # Not every selectable exposes dialect_options (e.g. sqlalchemy.sql.table()).
+        dialect_options = getattr(table, "dialect_options", None)
+        if dialect_options is None or "presto" not in dialect_options:
+            return sql
+
+        presto_options = dialect_options["presto"]._non_defaults
+        if "catalog" not in presto_options:
+            return sql
+
+        # Quote through the preparer so embedded double quotes are escaped.
+        catalog = self.preparer.quote_identifier(presto_options["catalog"])
+        return "{catalog}.{sql}".format(catalog=catalog, sql=sql)
+
 
 class PrestoTypeCompiler(compiler.GenericTypeCompiler):
     def visit_CLOB(self, type_, **kw):
@@ -83,7 +139,10 @@ class PrestoDialect(default.DefaultDialect):
     returns_unicode_strings = True
     description_encoding = None
     supports_native_boolean = True
+    if version.parse(sys.modules['sqlalchemy'].__version__) >= version.parse('1.4.5'):
+        supports_statement_cache = False
     type_compiler = PrestoTypeCompiler
+    cte_follows_insert = True
 
     @classmethod
     def dbapi(cls):
diff --git a/pyhive/tests/test_sqlalchemy_presto.py b/pyhive/tests/test_sqlalchemy_presto.py
index a01e4a35..2b339fe7 100644
--- a/pyhive/tests/test_sqlalchemy_presto.py
+++ b/pyhive/tests/test_sqlalchemy_presto.py
@@ -1,6 +1,7 @@
 from __future__ import absolute_import
 from __future__ import unicode_literals
 from builtins import str
+from decimal import Decimal
 from pyhive.tests.sqlalchemy_test_case import SqlAlchemyTestCase
 from pyhive.tests.sqlalchemy_test_case import with_engine_connection
 from sqlalchemy import types
@@ -48,7 +49,7 @@ def test_reflect_select(self, engine, connection):
             {"1": 2, "3": 4},  # Presto converts all keys to strings so that they're valid JSON
             [1, 2],  # struct is returned as a list of elements
             # '{0:1}',
-            '0.1',
+            Decimal('0.1'),
         ])
 
         # TODO some of these types could be filled in better
@@ -85,3 +86,16 @@ def test_reserved_words(self, engine, connection):
         self.assertIn('"current_timestamp"', query)
         self.assertNotIn('`select`', query)
         self.assertNotIn('`current_timestamp`', query)
+
+    @with_engine_connection
+    def test_multiple_catalogs(self, engine, connection):
+        system_table = Table(
+            'tables',
+            MetaData(bind=engine),
+            autoload=True,
+            schema='information_schema',
+            presto_catalog='system'
+        )
+        query = str(system_table.select())
+        self.assertIn('"system"."information_schema"', query)
+        self.assertNotIn('"hive"', query)