diff --git a/pandasql/sqldf.py b/pandasql/sqldf.py index e25398a..4c6087f 100644 --- a/pandasql/sqldf.py +++ b/pandasql/sqldf.py @@ -59,15 +59,25 @@ def __call__(self, query, env=None): continue self.loaded_tables.add(table_name) write_table(env[table_name], table_name, conn) + result = self.__actually_execute_query(query,conn) + # But if result is None, this might also mean that an UDPATE or DELETE clause has been executed! + if not result: + table_to_query = re.findall(r'(?:FROM|JOIN|UPDATE)\s+(\w+(?:\s*,\s*\w+)*)', query, re.IGNORECASE)[0] + result = self.__actually_execute_query(f"SELECT * FROM {table_to_query}", conn) + return result - try: - result = read_sql(query, conn) - except DatabaseError as ex: - raise PandaSQLException(ex) - except ResourceClosedError: - # query returns nothing - result = None - + def __actually_execute_query(self, query, conn): + """ + Actually executes the SQL query + :return the query result + """ + try: + result = read_sql(query, conn) + except DatabaseError as ex: + raise PandaSQLException(ex) + except ResourceClosedError: + # query returns nothing. + result = None return result @property @@ -110,7 +120,7 @@ def get_outer_frame_variables(): def extract_table_names(query): """ Extract table names from an SQL query. """ # a good old fashioned regex. turns out this worked better than actually parsing the code - tables_blocks = re.findall(r'(?:FROM|JOIN)\s+(\w+(?:\s*,\s*\w+)*)', query, re.IGNORECASE) + tables_blocks = re.findall(r'(?:FROM|JOIN|UPDATE)\s+(\w+(?:\s*,\s*\w+)*)', query, re.IGNORECASE) tables = [tbl for block in tables_blocks for tbl in re.findall(r'\w+', block)]