diff --git a/src/google/adk/sessions/database_session_service.py b/src/google/adk/sessions/database_session_service.py index a352918211..fbd3b1986e 100644 --- a/src/google/adk/sessions/database_session_service.py +++ b/src/google/adk/sessions/database_session_service.py @@ -489,6 +489,21 @@ async def _ensure_tables_created(self): await conn.run_sync(Base.metadata.create_all) self._tables_created = True + async def _commit_or_rollback( + self, sql_session: DatabaseSessionFactory + ) -> None: + """Commit the transaction, rolling back on error. + + This ensures that database connections are not left in a failed + 'pending rollback' state (see issue #3328). Any exception from + commit is re-raised after rollback so callers still see failures. + """ + try: + await sql_session.commit() + except Exception: + await sql_session.rollback() + raise + @override async def create_session( self, @@ -505,58 +520,62 @@ async def create_session( # 5. Return the session await self._ensure_tables_created() async with self.database_session_factory() as sql_session: - - if session_id and await sql_session.get( - StorageSession, (app_name, user_id, session_id) - ): - raise AlreadyExistsError( - f"Session with id {session_id} already exists." + try: + if session_id and await sql_session.get( + StorageSession, (app_name, user_id, session_id) + ): + raise AlreadyExistsError( + f"Session with id {session_id} already exists." + ) + # Fetch app and user states from storage + storage_app_state = await sql_session.get(StorageAppState, (app_name)) + storage_user_state = await sql_session.get( + StorageUserState, (app_name, user_id) ) - # Fetch app and user states from storage - storage_app_state = await sql_session.get(StorageAppState, (app_name)) - storage_user_state = await sql_session.get( - StorageUserState, (app_name, user_id) - ) - # Create state tables if not exist - if not storage_app_state: - storage_app_state = StorageAppState(app_name=app_name, state={}) - sql_session.add(storage_app_state) - if not storage_user_state: - storage_user_state = StorageUserState( - app_name=app_name, user_id=user_id, state={} + # Create state tables if not exist + if not storage_app_state: + storage_app_state = StorageAppState(app_name=app_name, state={}) + sql_session.add(storage_app_state) + if not storage_user_state: + storage_user_state = StorageUserState( + app_name=app_name, user_id=user_id, state={} + ) + sql_session.add(storage_user_state) + + # Extract state deltas + state_deltas = _session_util.extract_state_delta(state) + app_state_delta = state_deltas["app"] + user_state_delta = state_deltas["user"] + session_state = state_deltas["session"] + + # Apply state delta + if app_state_delta: + storage_app_state.state = storage_app_state.state | app_state_delta + if user_state_delta: + storage_user_state.state = storage_user_state.state | user_state_delta + + # Store the session + storage_session = StorageSession( + app_name=app_name, + user_id=user_id, + id=session_id, + state=session_state, ) - sql_session.add(storage_user_state) - - # Extract state deltas - state_deltas = _session_util.extract_state_delta(state) - app_state_delta = state_deltas["app"] - user_state_delta = state_deltas["user"] - session_state = state_deltas["session"] - - # Apply state delta - if app_state_delta: - storage_app_state.state = storage_app_state.state | app_state_delta - if user_state_delta: - storage_user_state.state = storage_user_state.state | user_state_delta - - # Store the session - storage_session = StorageSession( - app_name=app_name, - user_id=user_id, - id=session_id, - state=session_state, - ) - sql_session.add(storage_session) - await sql_session.commit() + sql_session.add(storage_session) + await self._commit_or_rollback(sql_session) - await sql_session.refresh(storage_session) + await sql_session.refresh(storage_session) - # Merge states for response - merged_state = _merge_state( - storage_app_state.state, storage_user_state.state, session_state - ) - session = storage_session.to_session(state=merged_state) + # Merge states for response + merged_state = _merge_state( + storage_app_state.state, storage_user_state.state, session_state + ) + session = storage_session.to_session(state=merged_state) + except Exception: + if sql_session.in_transaction(): + await sql_session.rollback() + raise return session @override @@ -573,47 +592,52 @@ async def get_session( # 2. Get all the events based on session id and filtering config # 3. Convert and return the session async with self.database_session_factory() as sql_session: - storage_session = await sql_session.get( - StorageSession, (app_name, user_id, session_id) - ) - if storage_session is None: - return None - - stmt = ( - select(StorageEvent) - .filter(StorageEvent.app_name == app_name) - .filter(StorageEvent.session_id == storage_session.id) - .filter(StorageEvent.user_id == user_id) - ) + try: + storage_session = await sql_session.get( + StorageSession, (app_name, user_id, session_id) + ) + if storage_session is None: + return None + + stmt = ( + select(StorageEvent) + .filter(StorageEvent.app_name == app_name) + .filter(StorageEvent.session_id == storage_session.id) + .filter(StorageEvent.user_id == user_id) + ) - if config and config.after_timestamp: - after_dt = datetime.fromtimestamp(config.after_timestamp) - stmt = stmt.filter(StorageEvent.timestamp >= after_dt) + if config and config.after_timestamp: + after_dt = datetime.fromtimestamp(config.after_timestamp) + stmt = stmt.filter(StorageEvent.timestamp >= after_dt) - stmt = stmt.order_by(StorageEvent.timestamp.desc()) + stmt = stmt.order_by(StorageEvent.timestamp.desc()) - if config and config.num_recent_events: - stmt = stmt.limit(config.num_recent_events) + if config and config.num_recent_events: + stmt = stmt.limit(config.num_recent_events) - result = await sql_session.execute(stmt) - storage_events = result.scalars().all() + result = await sql_session.execute(stmt) + storage_events = result.scalars().all() - # Fetch states from storage - storage_app_state = await sql_session.get(StorageAppState, (app_name)) - storage_user_state = await sql_session.get( - StorageUserState, (app_name, user_id) - ) + # Fetch states from storage + storage_app_state = await sql_session.get(StorageAppState, (app_name)) + storage_user_state = await sql_session.get( + StorageUserState, (app_name, user_id) + ) - app_state = storage_app_state.state if storage_app_state else {} - user_state = storage_user_state.state if storage_user_state else {} - session_state = storage_session.state + app_state = storage_app_state.state if storage_app_state else {} + user_state = storage_user_state.state if storage_user_state else {} + session_state = storage_session.state - # Merge states - merged_state = _merge_state(app_state, user_state, session_state) + # Merge states + merged_state = _merge_state(app_state, user_state, session_state) - # Convert storage session to session - events = [e.to_event() for e in reversed(storage_events)] - session = storage_session.to_session(state=merged_state, events=events) + # Convert storage session to session + events = [e.to_event() for e in reversed(storage_events)] + session = storage_session.to_session(state=merged_state, events=events) + except Exception: + if sql_session.in_transaction(): + await sql_session.rollback() + raise return session @override @@ -622,41 +646,50 @@ async def list_sessions( ) -> ListSessionsResponse: await self._ensure_tables_created() async with self.database_session_factory() as sql_session: - stmt = select(StorageSession).filter(StorageSession.app_name == app_name) - if user_id is not None: - stmt = stmt.filter(StorageSession.user_id == user_id) - - result = await sql_session.execute(stmt) - results = result.scalars().all() - - # Fetch app state from storage - storage_app_state = await sql_session.get(StorageAppState, (app_name)) - app_state = storage_app_state.state if storage_app_state else {} - - # Fetch user state(s) from storage - user_states_map = {} - if user_id is not None: - storage_user_state = await sql_session.get( - StorageUserState, (app_name, user_id) - ) - if storage_user_state: - user_states_map[user_id] = storage_user_state.state - else: - user_state_stmt = select(StorageUserState).filter( - StorageUserState.app_name == app_name + try: + stmt = select(StorageSession).filter( + StorageSession.app_name == app_name ) - user_state_result = await sql_session.execute(user_state_stmt) - all_user_states_for_app = user_state_result.scalars().all() - for storage_user_state in all_user_states_for_app: - user_states_map[storage_user_state.user_id] = storage_user_state.state - - sessions = [] - for storage_session in results: - session_state = storage_session.state - user_state = user_states_map.get(storage_session.user_id, {}) - merged_state = _merge_state(app_state, user_state, session_state) - sessions.append(storage_session.to_session(state=merged_state)) - return ListSessionsResponse(sessions=sessions) + if user_id is not None: + stmt = stmt.filter(StorageSession.user_id == user_id) + + result = await sql_session.execute(stmt) + results = result.scalars().all() + + # Fetch app state from storage + storage_app_state = await sql_session.get(StorageAppState, (app_name)) + app_state = storage_app_state.state if storage_app_state else {} + + # Fetch user state(s) from storage + user_states_map = {} + if user_id is not None: + storage_user_state = await sql_session.get( + StorageUserState, (app_name, user_id) + ) + if storage_user_state: + user_states_map[user_id] = storage_user_state.state + else: + user_state_stmt = select(StorageUserState).filter( + StorageUserState.app_name == app_name + ) + user_state_result = await sql_session.execute(user_state_stmt) + all_user_states_for_app = user_state_result.scalars().all() + for storage_user_state in all_user_states_for_app: + user_states_map[storage_user_state.user_id] = ( + storage_user_state.state + ) + + sessions = [] + for storage_session in results: + session_state = storage_session.state + user_state = user_states_map.get(storage_session.user_id, {}) + merged_state = _merge_state(app_state, user_state, session_state) + sessions.append(storage_session.to_session(state=merged_state)) + return ListSessionsResponse(sessions=sessions) + except Exception: + if sql_session.in_transaction(): + await sql_session.rollback() + raise @override async def delete_session( @@ -664,13 +697,18 @@ async def delete_session( ) -> None: await self._ensure_tables_created() async with self.database_session_factory() as sql_session: - stmt = delete(StorageSession).where( - StorageSession.app_name == app_name, - StorageSession.user_id == user_id, - StorageSession.id == session_id, - ) - await sql_session.execute(stmt) - await sql_session.commit() + try: + stmt = delete(StorageSession).where( + StorageSession.app_name == app_name, + StorageSession.user_id == user_id, + StorageSession.id == session_id, + ) + await sql_session.execute(stmt) + await self._commit_or_rollback(sql_session) + except Exception: + if sql_session.in_transaction(): + await sql_session.rollback() + raise @override async def append_event(self, session: Session, event: Event) -> Event: @@ -685,43 +723,11 @@ async def append_event(self, session: Session, event: Event) -> Event: # 2. Update session attributes based on event config # 3. Store event to table async with self.database_session_factory() as sql_session: - storage_session = await sql_session.get( - StorageSession, (session.app_name, session.user_id, session.id) - ) - - if storage_session.update_timestamp_tz > session.last_update_time: - raise ValueError( - "The last_update_time provided in the session object" - f" {datetime.fromtimestamp(session.last_update_time):'%Y-%m-%d %H:%M:%S'} is" - " earlier than the update_time in the storage_session" - f" {datetime.fromtimestamp(storage_session.update_timestamp_tz):'%Y-%m-%d %H:%M:%S'}." - " Please check if it is a stale session." + try: + storage_session = await sql_session.get( + StorageSession, (session.app_name, session.user_id, session.id) ) - # Fetch states from storage - storage_app_state = await sql_session.get( - StorageAppState, (session.app_name) - ) - storage_user_state = await sql_session.get( - StorageUserState, (session.app_name, session.user_id) - ) - - # Extract state delta - if event.actions and event.actions.state_delta: - state_deltas = _session_util.extract_state_delta( - event.actions.state_delta - ) - app_state_delta = state_deltas["app"] - user_state_delta = state_deltas["user"] - session_state_delta = state_deltas["session"] - # Merge state and update storage - if app_state_delta: - storage_app_state.state = storage_app_state.state | app_state_delta - if user_state_delta: - storage_user_state.state = storage_user_state.state | user_state_delta - if session_state_delta: - storage_session.state = storage_session.state | session_state_delta - if storage_session._dialect_name == "sqlite": update_time = datetime.utcfromtimestamp(event.timestamp) else: @@ -729,11 +735,50 @@ async def append_event(self, session: Session, event: Event) -> Event: storage_session.update_time = update_time sql_session.add(StorageEvent.from_event(session, event)) - await sql_session.commit() - await sql_session.refresh(storage_session) + # Fetch states from storage + storage_app_state = await sql_session.get( + StorageAppState, (session.app_name) + ) + storage_user_state = await sql_session.get( + StorageUserState, (session.app_name, session.user_id) + ) - # Update timestamp with commit time - session.last_update_time = storage_session.update_timestamp_tz + # Extract state delta + if event.actions and event.actions.state_delta: + state_deltas = _session_util.extract_state_delta( + event.actions.state_delta + ) + app_state_delta = state_deltas["app"] + user_state_delta = state_deltas["user"] + session_state_delta = state_deltas["session"] + # Merge state and update storage + if app_state_delta: + storage_app_state.state = storage_app_state.state | app_state_delta + if user_state_delta: + storage_user_state.state = ( + storage_user_state.state | user_state_delta + ) + if session_state_delta: + storage_session.state = storage_session.state | session_state_delta + + if storage_session._dialect_name == "sqlite": + update_time = datetime.fromtimestamp( + event.timestamp, timezone.utc + ).replace(tzinfo=None) + else: + update_time = datetime.fromtimestamp(event.timestamp) + storage_session.update_time = update_time + sql_session.add(StorageEvent.from_event(session, event)) + + await self._commit_or_rollback(sql_session) + await sql_session.refresh(storage_session) + + # Update timestamp with commit time + session.last_update_time = storage_session.update_timestamp_tz + except Exception: + if sql_session.in_transaction(): + await sql_session.rollback() + raise # Also update the in-memory session await super().append_event(session=session, event=event) diff --git a/tests/unittests/sessions/test_session_service.py b/tests/unittests/sessions/test_session_service.py index 45aa3feede..6a786a747a 100644 --- a/tests/unittests/sessions/test_session_service.py +++ b/tests/unittests/sessions/test_session_service.py @@ -699,3 +699,30 @@ async def test_partial_events_are_not_persisted(service_type, tmp_path): app_name=app_name, user_id=user_id, session_id=session.id ) assert len(session_got.events) == 0 + + +@pytest.mark.asyncio +async def test_database_session_service_commit_rollback_on_error(monkeypatch): + """DatabaseSessionService should rollback if commit fails.""" + session_service = get_session_service(SessionServiceType.DATABASE) + + async with session_service.database_session_factory() as sql_session: + rollback_called = False + + async def failing_commit(): + raise RuntimeError('commit failed') + + async def recording_rollback(): + nonlocal rollback_called + rollback_called = True + + # Make this particular session's commit fail and ensure rollback is called. + monkeypatch.setattr(sql_session, 'commit', failing_commit) + monkeypatch.setattr(sql_session, 'rollback', recording_rollback) + + # _commit_or_rollback should re-raise the commit error... + with pytest.raises(RuntimeError): + await session_service._commit_or_rollback(sql_session) + + # ...and it must rollback the transaction. + assert rollback_called is True