diff --git a/pynuodb/cursor.py b/pynuodb/cursor.py index 4e515f5..a7f792a 100644 --- a/pynuodb/cursor.py +++ b/pynuodb/cursor.py @@ -165,7 +165,7 @@ def _executeprepared(self, operation, parameters): :raises ProgrammingError: Incorrect number of parameters """ p_statement = self._statement_cache.get_prepared_statement(operation) - if p_statement.parameter_count != len(parameters): + if p_statement.parameter_count != len(parameters) and p_statement.handle != -1: raise ProgrammingError( "Incorrect number of parameters: expected %d, got %d" % (p_statement.parameter_count, len(parameters))) @@ -275,8 +275,7 @@ def get_prepared_statement(self, query): self._ps_key_queue.append(query) return stmt - stmt = self._session.create_prepared_statement(query) - + stmt = self._session.create_local_prepared_statement(query) while len(self._ps_cache) >= self._ps_cache_size: lru_statement_key = self._ps_key_queue.popleft() statement_to_remove = self._ps_cache[lru_statement_key] diff --git a/pynuodb/encodedsession.py b/pynuodb/encodedsession.py index b319811..cbdd0ed 100644 --- a/pynuodb/encodedsession.py +++ b/pynuodb/encodedsession.py @@ -359,9 +359,10 @@ def test_connection(self): self._putMessageId(protocol.CREATE) self._exchangeMessages() handle = self.getInt() + stmt = statement.Statement(handle) - # Use handle to query dual - self._setup_statement(handle, protocol.EXECUTEQUERY) + # Use statement to query dual + self._setup_statement(stmt, protocol.EXECUTEQUERY) self.putString('select 1 as one from dual') self._exchangeMessages() @@ -401,7 +402,8 @@ def execute_statement(self, stmt, query): :param query: Operation to be executed. :returns: The result of the operation execution. """ - self._setup_statement(stmt.handle, protocol.EXECUTE).putString(query) + stmt.query = query + self._setup_statement(stmt, protocol.EXECUTE).putString(query) self._exchangeMessages() result = self.getInt() @@ -422,6 +424,15 @@ def close_result_set(self, resultset): self._putMessageId(protocol.CLOSERESULTSET).putInt(resultset.handle) self._exchangeMessages(False) + def create_local_prepared_statement(self, query): + # type: (str) -> statement.PreparedStatement + """Create a local prepared statement for the given query.""" + if self.__sessionVersion >= protocol.PREPARE_AND_EXECUTE_TOGETHER: + stmt = statement.PreparedStatement(-1, -1) + stmt.query = query + return stmt + return self.create_prepared_statement(query) + def create_prepared_statement(self, query): # type: (str) -> statement.PreparedStatement """Create a prepared statement for the given query.""" @@ -445,7 +456,10 @@ def execute_prepared_statement( ): # type: (...) -> statement.ExecutionResult """Execute a prepared statement with the given parameters.""" - self._setup_statement(prepared_statement.handle, protocol.EXECUTEPREPAREDSTATEMENT) + if self.__sessionVersion >= protocol.PREPARE_AND_EXECUTE_TOGETHER and not self._isPrepared(prepared_statement): + self._setup_statement(prepared_statement, protocol.PREPAREANDEXECUTETOGETHER) + else: + self._setup_statement(prepared_statement, protocol.EXECUTEPREPAREDSTATEMENT) self.putInt(len(parameters)) for param in parameters: @@ -453,6 +467,11 @@ def execute_prepared_statement( self._exchangeMessages() + # Update handle and parameter count if needed + if self.__sessionVersion >= protocol.PREPARE_AND_EXECUTE_TOGETHER and not self._isPrepared(prepared_statement): + prepared_statement.handle = self.getInt() + prepared_statement.parameter_count = self.getInt() + result = self.getInt() rowcount = self.getInt() self.__execute_postfix() @@ -462,7 +481,7 @@ def execute_prepared_statement( def execute_batch_prepared_statement(self, prepared_statement, param_lists): # type: (statement.PreparedStatement, Collection[Collection[result_set.Value]]) -> List[int] """Batch the prepared statement with the given parameters.""" - self._setup_statement(prepared_statement.handle, protocol.EXECUTEBATCHPREPAREDSTATEMENT) + self._setup_statement(prepared_statement, protocol.EXECUTEBATCHPREPAREDSTATEMENT) for parameters in param_lists: plen = len(parameters) @@ -1317,14 +1336,26 @@ def _exchangeMessages(self, getResponse=True): # Protected utility routines - def _setup_statement(self, handle, msgId): - # type: (int, int) -> EncodedSession + def _setup_statement(self, prepared_statement, msgId): + # type: (statement.PreparedStatement, int) -> EncodedSession """Set up a new statement. - :type handle: int + :type prepared_statement: statement.PreparedStatement :type msgId: int """ + if msgId != protocol.PREPAREANDEXECUTETOGETHER and not self._isPrepared(prepared_statement): + statement = self.create_prepared_statement(prepared_statement.query) + prepared_statement.handle = statement.handle + prepared_statement.parameter_count = statement.parameter_count + self._putMessageId(msgId) + + if msgId == protocol.PREPAREANDEXECUTETOGETHER: + self.putInt(protocol.DEFAULT_EXECUTE_SUBTYPE) #executeSubtype + self.putInt(protocol.DEFAULT_PREPARE_SUBTYPE) #prepareSubtype + self.putString(prepared_statement.query) + self.putInt(protocol.DEFAULT_EXECUTE_TIMEOUT_MS) # timeout + self.putInt(protocol.DEFAULT_FETCH_SIZE) # fetchsize if self.__sessionVersion >= protocol.LAST_COMMIT_INFO: with EncodedSession.__dblock: self.putInt(len(self.__dbinfo)) @@ -1332,8 +1363,8 @@ def _setup_statement(self, handle, msgId): self.putInt(sid) self.putInt(tup[0]) self.putInt(tup[1]) - self.putInt(handle) - + if msgId != protocol.PREPAREANDEXECUTETOGETHER: + self.putInt(prepared_statement.handle) return self def _hasBytes(self, length): @@ -1368,3 +1399,9 @@ def _takeBytes(self, length): return self.__input[self.__inpos:self.__inpos + length] finally: self.__inpos += length + def _isPrepared(self,prepared_statement): + # type: (statement.PreparedStatement) -> bool + """Check if the prepared statement is valid.""" + if prepared_statement.handle == -1: + return False + return True diff --git a/pynuodb/protocol.py b/pynuodb/protocol.py index f559127..af8a4d4 100644 --- a/pynuodb/protocol.py +++ b/pynuodb/protocol.py @@ -88,6 +88,7 @@ EXECUTEPREPAREDSTATEMENT = 22 EXECUTEPREPAREDQUERY = 23 EXECUTEPREPAREDUPDATE = 24 +PREPAREANDEXECUTETOGETHER = 25 GETMETADATA = 26 NEXT = 27 CLOSERESULTSET = 28 @@ -343,6 +344,13 @@ def lookup_code(error_code): """Return a string-ified version of an error code.""" return stringifyError.get(error_code, '[UNKNOWN ERROR CODE]') +# +# Default Options +# +DEFAULT_EXECUTE_TIMEOUT_MS = 0 +DEFAULT_FETCH_SIZE = 0 +DEFAULT_EXECUTE_SUBTYPE = 5 +DEFAULT_PREPARE_SUBTYPE = 0 # # NuoDB Client-Server Features @@ -378,5 +386,5 @@ def lookup_code(error_code): # The newest feature this driver supports. # The server will negotiate the highest compatible version. CURRENT_PROTOCOL_MAJOR = 1 -CURRENT_PROTOCOL_VERSION = TIMESTAMP_WITHOUT_TZ +CURRENT_PROTOCOL_VERSION = PREPARE_AND_EXECUTE_TOGETHER AUTH_TEST_STR = 'Success!' diff --git a/pynuodb/statement.py b/pynuodb/statement.py index 17fee3e..0b155fc 100644 --- a/pynuodb/statement.py +++ b/pynuodb/statement.py @@ -37,6 +37,7 @@ def __init__(self, handle, parameter_count): super(PreparedStatement, self).__init__(handle) self.parameter_count = parameter_count self.description = None # type: Optional[List[List[Any]]] + self.query = None # type: Optional[str] class ExecutionResult(object): diff --git a/tests/nuodb_cursor_test.py b/tests/nuodb_cursor_test.py index 8a0badd..a8aeb57 100644 --- a/tests/nuodb_cursor_test.py +++ b/tests/nuodb_cursor_test.py @@ -7,7 +7,7 @@ import pytest -from pynuodb.exception import DataError, ProgrammingError, BatchError, OperationalError +from pynuodb.exception import DataError, ProgrammingError, BatchError, OperationalError, DatabaseError from . import nuodb_base @@ -45,17 +45,17 @@ def test_insufficient_parameters(self): con = self._connect() cursor = con.cursor() - with pytest.raises(ProgrammingError): + with pytest.raises(DatabaseError): cursor.execute("SELECT ?, ? FROM DUAL", [1]) def test_toomany_parameters(self): con = self._connect() cursor = con.cursor() - with pytest.raises(ProgrammingError): + with pytest.raises(DatabaseError): cursor.execute("SELECT 1 FROM DUAL", [1]) - with pytest.raises(ProgrammingError): + with pytest.raises(DatabaseError): cursor.execute("SELECT ? FROM DUAL", [1, 2]) def test_incorrect_parameters(self):