Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 2 additions & 3 deletions pynuodb/cursor.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,7 +165,7 @@ def _executeprepared(self, operation, parameters):
:raises ProgrammingError: Incorrect number of parameters
"""
p_statement = self._statement_cache.get_prepared_statement(operation)
if p_statement.parameter_count != len(parameters):
if p_statement.parameter_count != len(parameters) and p_statement.handle != -1:
raise ProgrammingError(
"Incorrect number of parameters: expected %d, got %d" %
(p_statement.parameter_count, len(parameters)))
Expand Down Expand Up @@ -275,8 +275,7 @@ def get_prepared_statement(self, query):
self._ps_key_queue.append(query)
return stmt

stmt = self._session.create_prepared_statement(query)

stmt = self._session.create_local_prepared_statement(query)
while len(self._ps_cache) >= self._ps_cache_size:
lru_statement_key = self._ps_key_queue.popleft()
statement_to_remove = self._ps_cache[lru_statement_key]
Expand Down
57 changes: 47 additions & 10 deletions pynuodb/encodedsession.py
Original file line number Diff line number Diff line change
Expand Up @@ -359,9 +359,10 @@ def test_connection(self):
self._putMessageId(protocol.CREATE)
self._exchangeMessages()
handle = self.getInt()
stmt = statement.Statement(handle)

# Use handle to query dual
self._setup_statement(handle, protocol.EXECUTEQUERY)
# Use statement to query dual
self._setup_statement(stmt, protocol.EXECUTEQUERY)
self.putString('select 1 as one from dual')
self._exchangeMessages()

Expand Down Expand Up @@ -401,7 +402,8 @@ def execute_statement(self, stmt, query):
:param query: Operation to be executed.
:returns: The result of the operation execution.
"""
self._setup_statement(stmt.handle, protocol.EXECUTE).putString(query)
stmt.query = query
self._setup_statement(stmt, protocol.EXECUTE).putString(query)
self._exchangeMessages()

result = self.getInt()
Expand All @@ -422,6 +424,15 @@ def close_result_set(self, resultset):
self._putMessageId(protocol.CLOSERESULTSET).putInt(resultset.handle)
self._exchangeMessages(False)

def create_local_prepared_statement(self, query):
# type: (str) -> statement.PreparedStatement
"""Create a local prepared statement for the given query."""
if self.__sessionVersion >= protocol.PREPARE_AND_EXECUTE_TOGETHER:
stmt = statement.PreparedStatement(-1, -1)
stmt.query = query
return stmt
return self.create_prepared_statement(query)

def create_prepared_statement(self, query):
# type: (str) -> statement.PreparedStatement
"""Create a prepared statement for the given query."""
Expand All @@ -445,14 +456,22 @@ def execute_prepared_statement(
):
# type: (...) -> statement.ExecutionResult
"""Execute a prepared statement with the given parameters."""
self._setup_statement(prepared_statement.handle, protocol.EXECUTEPREPAREDSTATEMENT)
if self.__sessionVersion >= protocol.PREPARE_AND_EXECUTE_TOGETHER and not self._isPrepared(prepared_statement):
self._setup_statement(prepared_statement, protocol.PREPAREANDEXECUTETOGETHER)
else:
self._setup_statement(prepared_statement, protocol.EXECUTEPREPAREDSTATEMENT)

self.putInt(len(parameters))
for param in parameters:
self.putValue(param)

self._exchangeMessages()

# Update handle and parameter count if needed
if self.__sessionVersion >= protocol.PREPARE_AND_EXECUTE_TOGETHER and not self._isPrepared(prepared_statement):
prepared_statement.handle = self.getInt()
prepared_statement.parameter_count = self.getInt()

result = self.getInt()
rowcount = self.getInt()
self.__execute_postfix()
Expand All @@ -462,7 +481,7 @@ def execute_prepared_statement(
def execute_batch_prepared_statement(self, prepared_statement, param_lists):
# type: (statement.PreparedStatement, Collection[Collection[result_set.Value]]) -> List[int]
"""Batch the prepared statement with the given parameters."""
self._setup_statement(prepared_statement.handle, protocol.EXECUTEBATCHPREPAREDSTATEMENT)
self._setup_statement(prepared_statement, protocol.EXECUTEBATCHPREPAREDSTATEMENT)

for parameters in param_lists:
plen = len(parameters)
Expand Down Expand Up @@ -1317,23 +1336,35 @@ def _exchangeMessages(self, getResponse=True):

# Protected utility routines

def _setup_statement(self, handle, msgId):
# type: (int, int) -> EncodedSession
def _setup_statement(self, prepared_statement, msgId):
# type: (statement.PreparedStatement, int) -> EncodedSession
"""Set up a new statement.

:type handle: int
:type prepared_statement: statement.PreparedStatement
:type msgId: int
"""
if msgId != protocol.PREPAREANDEXECUTETOGETHER and not self._isPrepared(prepared_statement):
statement = self.create_prepared_statement(prepared_statement.query)
prepared_statement.handle = statement.handle
prepared_statement.parameter_count = statement.parameter_count

self._putMessageId(msgId)

if msgId == protocol.PREPAREANDEXECUTETOGETHER:
self.putInt(protocol.DEFAULT_EXECUTE_SUBTYPE) #executeSubtype
self.putInt(protocol.DEFAULT_PREPARE_SUBTYPE) #prepareSubtype
self.putString(prepared_statement.query)
self.putInt(protocol.DEFAULT_EXECUTE_TIMEOUT_MS) # timeout
self.putInt(protocol.DEFAULT_FETCH_SIZE) # fetchsize
if self.__sessionVersion >= protocol.LAST_COMMIT_INFO:
with EncodedSession.__dblock:
self.putInt(len(self.__dbinfo))
for sid, tup in self.__dbinfo.items():
self.putInt(sid)
self.putInt(tup[0])
self.putInt(tup[1])
self.putInt(handle)

if msgId != protocol.PREPAREANDEXECUTETOGETHER:
self.putInt(prepared_statement.handle)
return self

def _hasBytes(self, length):
Expand Down Expand Up @@ -1368,3 +1399,9 @@ def _takeBytes(self, length):
return self.__input[self.__inpos:self.__inpos + length]
finally:
self.__inpos += length
def _isPrepared(self,prepared_statement):
# type: (statement.PreparedStatement) -> bool
"""Check if the prepared statement is valid."""
if prepared_statement.handle == -1:
return False
return True
10 changes: 9 additions & 1 deletion pynuodb/protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,7 @@
EXECUTEPREPAREDSTATEMENT = 22
EXECUTEPREPAREDQUERY = 23
EXECUTEPREPAREDUPDATE = 24
PREPAREANDEXECUTETOGETHER = 25
GETMETADATA = 26
NEXT = 27
CLOSERESULTSET = 28
Expand Down Expand Up @@ -343,6 +344,13 @@ def lookup_code(error_code):
"""Return a string-ified version of an error code."""
return stringifyError.get(error_code, '[UNKNOWN ERROR CODE]')

#
# Default Options
#
DEFAULT_EXECUTE_TIMEOUT_MS = 0
DEFAULT_FETCH_SIZE = 0
DEFAULT_EXECUTE_SUBTYPE = 5
DEFAULT_PREPARE_SUBTYPE = 0

#
# NuoDB Client-Server Features
Expand Down Expand Up @@ -378,5 +386,5 @@ def lookup_code(error_code):
# The newest feature this driver supports.
# The server will negotiate the highest compatible version.
CURRENT_PROTOCOL_MAJOR = 1
CURRENT_PROTOCOL_VERSION = TIMESTAMP_WITHOUT_TZ
CURRENT_PROTOCOL_VERSION = PREPARE_AND_EXECUTE_TOGETHER
AUTH_TEST_STR = 'Success!'
1 change: 1 addition & 0 deletions pynuodb/statement.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ def __init__(self, handle, parameter_count):
super(PreparedStatement, self).__init__(handle)
self.parameter_count = parameter_count
self.description = None # type: Optional[List[List[Any]]]
self.query = None # type: Optional[str]


class ExecutionResult(object):
Expand Down
8 changes: 4 additions & 4 deletions tests/nuodb_cursor_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

import pytest

from pynuodb.exception import DataError, ProgrammingError, BatchError, OperationalError
from pynuodb.exception import DataError, ProgrammingError, BatchError, OperationalError, DatabaseError

from . import nuodb_base

Expand Down Expand Up @@ -45,17 +45,17 @@ def test_insufficient_parameters(self):
con = self._connect()
cursor = con.cursor()

with pytest.raises(ProgrammingError):
with pytest.raises(DatabaseError):
cursor.execute("SELECT ?, ? FROM DUAL", [1])

def test_toomany_parameters(self):
con = self._connect()
cursor = con.cursor()

with pytest.raises(ProgrammingError):
with pytest.raises(DatabaseError):
cursor.execute("SELECT 1 FROM DUAL", [1])

with pytest.raises(ProgrammingError):
with pytest.raises(DatabaseError):
cursor.execute("SELECT ? FROM DUAL", [1, 2])

def test_incorrect_parameters(self):
Expand Down