20 changes: 20 additions & 0 deletions README.md
@@ -312,6 +312,26 @@ will spawn a number of threads in order to optimize performance for
iRODS server versions 4.2.9+ and file sizes larger than a default
threshold value of 32 Megabytes.

Because multithreaded processes under Unix-like operating systems sometimes
need special handling, it is recommended that any put or get of a large file
be guarded against the possibility of a terminating signal aborting the
transfer:

```python
from signal import signal, SIGTERM

from irods.parallel import abort_parallel_transfers

def handler(*arguments):
    abort_parallel_transfers()

# SIGINT already surfaces as KeyboardInterrupt; register a handler for SIGTERM explicitly.
signal(SIGTERM, handler)

try:
    # A multi-1247 put or get can leave non-daemon threads running if not treated with care.
    session.data_objects.put( ... )
except KeyboardInterrupt:
    abort_parallel_transfers()
    raise
```
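
`abort_parallel_transfers` also accepts a `dry_run` parameter: with `dry_run=True`
nothing is aborted, and the call simply returns a snapshot of the currently
registered transfer managers. A minimal sketch of polling for in-flight transfers:

```python
from irods.parallel import abort_parallel_transfers

# Peek at in-flight parallel transfers without aborting them.
in_flight = abort_parallel_transfers(dry_run=True)
print(f"{len(in_flight)} parallel transfer manager(s) currently registered")
```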

Progress bars
-------------

3 changes: 2 additions & 1 deletion irods/manager/data_object_manager.py
@@ -131,12 +131,13 @@ def __init__(self, *a, **kwd):
self._iRODS_session = kwd.pop("_session", None)
super(ManagedBufferedRandom, self).__init__(*a, **kwd)
import irods.session
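        # When True, __del__ skips close(); _Multipart_close_manager.quit() sets this
        # on aborted transfers so no further traffic is issued on the connection.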
self.no_close = False

with irods.session._fds_lock:
irods.session._fds[self] = None

def __del__(self):
if not self.closed:
if not self.no_close and not self.closed:
self.close()
call___del__if_exists(super(ManagedBufferedRandom, self))

102 changes: 81 additions & 21 deletions irods/parallel.py
@@ -9,14 +9,24 @@
import concurrent.futures
import threading
import multiprocessing
from typing import List, Union
from typing import List, Union, Any
import weakref

from irods.data_object import iRODSDataObject
from irods.exception import DataObjectDoesNotExist
import irods.keywords as kw
from queue import Queue, Full, Empty


transfer_managers: weakref.WeakKeyDictionary["_Multipart_close_manager", Any] = weakref.WeakKeyDictionary()
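# The registry above holds weak references to active _Multipart_close_manager
# instances, so abort_parallel_transfers() can reach in-flight transfers while
# completed managers drop out automatically when garbage-collected.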

def abort_parallel_transfers(dry_run=False):
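    """Ask every registered transfer manager to quit.

    When dry_run is True, nothing is aborted; in either case a snapshot of the
    registry of transfer managers is returned.
    """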
if not dry_run:
for mgr in transfer_managers:
mgr.quit()
return dict(transfer_managers)


logger = logging.getLogger(__name__)
_nullh = logging.NullHandler()
logger.addHandler(_nullh)
@@ -91,9 +101,11 @@ def __init__(
for future in self._futures:
future.add_done_callback(self)
else:
self.__invoke_done_callback()
self.__invoke_futures_done_logic()
return

self.progress = [0, 0]

if (progress_Queue) and (total is not None):
self.progress[1] = total

@@ -112,7 +124,7 @@ def _progress(Q, this): # - thread to update progress indicator

self._progress_fn = _progress
self._progress_thread = threading.Thread(
target=self._progress_fn, args=(progress_Queue, self)
target=self._progress_fn, args=(progress_Queue, self), daemon=True
)
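        # daemon=True keeps the progress thread from blocking interpreter exit (#722).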
self._progress_thread.start()

@@ -153,11 +165,13 @@ def __call__(
with self._lock:
self._futures_done[future] = future.result()
if len(self._futures) == len(self._futures_done):
self.__invoke_done_callback()
self.__invoke_futures_done_logic(
skip_user_callback=(None in self._futures_done.values())
)

def __invoke_done_callback(self):
def __invoke_futures_done_logic(self, skip_user_callback=False):
try:
if callable(self.done_callback):
if not skip_user_callback and callable(self.done_callback):
self.done_callback(self)
finally:
self.keep.pop("mgr", None)
@@ -240,6 +254,10 @@ def _copy_part(src, dst, length, queueObject, debug_info, mgr, updatables=()):
bytecount = 0
accum = 0
    while bytecount < length:
        if mgr._quit:
            # Abort requested: a None bytecount tells the caller this part did not complete.
            bytecount = None
            break
buf = src.read(min(COPY_BUF_SIZE, length - bytecount))
buf_len = len(buf)
if 0 == buf_len:
@@ -274,11 +292,39 @@ class _Multipart_close_manager:

"""

def __init__(self, initial_io_, exit_barrier_):
    def __init__(self, initial_io_, exit_barrier_, executor=None):
self._quit = False
self.exit_barrier = exit_barrier_
self.initial_io = initial_io_
self.__lock = threading.Lock()
self.aux = []
self.futures = set()
self.executor = executor

    def add_future(self, future):
        self.futures.add(future)

@property
def active_futures(self):
return tuple(_ for _ in self.futures if not _.done())

def shutdown(self):
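        # cancel_futures (Python 3.9+) prevents queued-but-unstarted parts from running.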
if self.executor:
            self.executor.shutdown(cancel_futures=True)

def quit(self):
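        """Abort this transfer.

        Detach its file descriptors from auto-close handling, break the exit
        barrier so worker threads stop waiting on it, and shut down the
        executor.  Returns the futures still active at the time of the call.
        """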
from irods.manager.data_object_manager import ManagedBufferedRandom
# remove all descriptors from consideration for auto_close.
import irods.session
with irods.session._fds_lock:
for fd in self.aux + [self.initial_io]:
irods.session._fds.pop(fd, ())
if type(fd) is ManagedBufferedRandom:
fd.no_close = True
# abort threads.
self._quit = True
self.exit_barrier.abort()
self.shutdown()
return self.active_futures

def __contains__(self, Io):
with self.__lock:
@@ -297,15 +343,20 @@ def add_io(self, Io):
# synchronizes all of the parallel threads just before exit, so that we know
# exactly when to perform a finalizing close on the data object


def remove_io(self, Io):
is_initial = True
with self.__lock:
if Io is not self.initial_io:
Io.close()
self.aux.remove(Io)
is_initial = False
self.exit_barrier.wait()
if is_initial:
broken = False
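        # quit() breaks the barrier to release waiting threads; a broken barrier
        # (or an explicit quit) means the data object must not be finalized.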
try:
self.exit_barrier.wait()
except threading.BrokenBarrierError:
broken = True
if is_initial and not (broken or self._quit):
self.finalize()

def finalize(self):
@@ -393,7 +444,7 @@ def bytes_range_for_thread(i, num_threads, total_bytes, chunk):
futures = []
executor = concurrent.futures.ThreadPoolExecutor(max_workers=num_threads)
num_threads = min(num_threads, len(ranges))
mgr = _Multipart_close_manager(Io, Barrier(num_threads))
mgr = _Multipart_close_manager(Io, Barrier(num_threads), executor)
counter = 1
gen_file_handle = lambda: open(
fname, Operation.disk_file_mode(initial_open=(counter == 1))
@@ -425,7 +476,7 @@ def bytes_range_for_thread(i, num_threads, total_bytes, chunk):
if File is None:
File = gen_file_handle()
futures.append(
executor.submit(
f := executor.submit(
_io_part,
Io,
byte_range,
Expand All @@ -436,17 +487,26 @@ def bytes_range_for_thread(i, num_threads, total_bytes, chunk):
**thread_opts
)
)
mgr.add_future(f)
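        # Registering the future lets quit() enumerate and report in-flight parts.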
counter += 1
Io = File = None

if Operation.isNonBlocking():
if queueLength:
return futures, queueObject, mgr
else:
return futures
return futures, queueObject, mgr
else:
bytecounts = [f.result() for f in futures]
return sum(bytecounts), total_size
bytes_transferred = 0
try:
transfer_managers[mgr] = 1
bytecounts = [f.result() for f in futures]
if None not in bytecounts:
bytes_transferred = sum(bytecounts)
        except KeyboardInterrupt:
            raise
return bytes_transferred, total_size


def io_main(session, Data, opr_, fname, R="", **kwopt):
@@ -559,10 +619,10 @@ def io_main(session, Data, opr_, fname, R="", **kwopt):

if Operation.isNonBlocking():

if queueLength > 0:
(futures, chunk_notify_queue, mgr) = retval
else:
futures = retval
(futures, chunk_notify_queue, mgr) = retval
transfer_managers[mgr] = None

if queueLength <= 0:
chunk_notify_queue = total_bytes = None

return AsyncNotify(
8 changes: 8 additions & 0 deletions irods/test/data_obj_test.py
@@ -3320,6 +3320,14 @@ def test_access_time__issue_700(self):
# Test that access_time is there, and of the right type.
self.assertIs(type(data.access_time), datetime)

def test_handling_of_termination_signals_during_multithread_get__issue_722(self):
from irods.test.modules.test_signal_handling_in_multithread_get import (
test as test__issue_722,
)

test__issue_722(self)


if __name__ == "__main__":
# let the tests find the parent irods lib
sys.path.insert(0, os.path.abspath("../.."))
135 changes: 135 additions & 0 deletions irods/test/modules/test_signal_handling_in_multithread_get.py
@@ -0,0 +1,135 @@
import os
import re
import signal
import subprocess
import sys
import tempfile
import time

import irods
import irods.helpers
Contributor comment on lines +9 to +10: It appears that we only actually use irods.helpers. Could we remove the import irods?

from irods.test import modules as test_modules
from irods.parallel import abort_parallel_transfers

OBJECT_SIZE = 2 * 1024**3
OBJECT_NAME = "data_get_issue__722"
LOCAL_TEMPFILE_NAME = "data_object_for_issue_722.dat"
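# OBJECT_SIZE of 2 GiB is far above the 32 MB parallel-transfer threshold, so the
# get() below is guaranteed to spawn auxiliary data transfer threads and to run
# long enough for the parent process to signal it mid-flight.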


_clock_polling_interval = max(0.01, time.clock_getres(time.CLOCK_BOOTTIME))
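# (CLOCK_BOOTTIME is monotonic and unaffected by wall-clock adjustments, which keeps
# the polling interval and the timeout arithmetic below stable.)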


def wait_till_true(function, timeout=None):
Contributor comment: Just in case... can we set the default timeout to some high-ish value that is abundantly more than enough time to complete whatever transfer we are waiting for, but that will eventually fail if it is stuck? We can still support None as "no timeout", but making it the default makes me squirm. The value I had in mind was like... 10 minutes?

start_time = time.clock_gettime_ns(time.CLOCK_BOOTTIME)
while not (truth_value := function()):
if (
timeout is not None
and (time.clock_gettime_ns(time.CLOCK_BOOTTIME) - start_time) * 1e-9
> timeout
):
break
time.sleep(_clock_polling_interval)
return truth_value


def test(test_case, signal_names=("SIGTERM", "SIGINT")):
"""Creates a child process executing a long get() and ensures the process can be
terminated using SIGINT or SIGTERM.
"""
program = os.path.join(test_modules.__path__[0], os.path.basename(__file__))

for signal_name in signal_names:

with test_case.subTest(f"Testing with signal {signal_name}"):

# Call into this same module as a command. This will initiate another Python process that
# performs a lengthy data object "get" operation (see the main body of the script, below.)
process = subprocess.Popen(
[sys.executable, program],
stderr=subprocess.PIPE,
stdout=subprocess.PIPE,
text=True,
)

# Wait for download process to reach the point of spawning data transfer threads. In Python 3.9+ versions
# of the concurrent.futures module, these are nondaemon threads and will block the exit of the main thread
# unless measures are taken (#722).
localfile = process.stdout.readline().strip()
test_case.assertTrue(
wait_till_true(
lambda: os.path.exists(localfile)
and os.stat(localfile).st_size > OBJECT_SIZE // 2
),
"Parallel download from data_objects.get() probably experienced a fatal error before spawning auxiliary data transfer threads.",
)

sig = getattr(signal, signal_name)

translate_return_code = lambda s: 128 - s if s < 0 else s
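            # A child killed by signal N reports Popen returncode -N; the handler's
            # exit(128 + N) path reports 128 + N.  The lambda above normalizes both
            # conventions to 128 + N for comparison.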

# Interrupt the subprocess with the given signal.
process.send_signal(sig)

            # Assert that this signal is what killed the subprocess, rather than a timed-out process "wait" or a natural
            # exit due to improper or incomplete handling of the signal.
try:
test_case.assertEqual(
translate_return_code(process.wait(timeout=15)),
128 + sig,
"Unexpected subprocess return code.",
Contributor comment: Can we put the expected value and the actual value in the msg?
)
            except subprocess.TimeoutExpired:
                test_case.fail(
                    "Subprocess timed out before terminating. "
                    "Non-daemon thread(s) probably prevented the subprocess's main thread from exiting."
                )
# Assert that in the case of SIGINT, the process registered a KeyboardInterrupt.
if sig == signal.SIGINT:
test_case.assertTrue(
re.search("KeyboardInterrupt", process.stderr.read()),
"Did not find expected string 'KeyboardInterrupt' in log output.",
)


if __name__ == "__main__":
# These lines are run only if the module is launched as a process.
session = irods.helpers.make_session()
hc = irods.helpers.home_collection(session)
TESTFILE_FILL = b"_" * (1024 * 1024)
object_path = f"{hc}/{OBJECT_NAME}"

# Create the object to be downloaded.
with session.data_objects.open(object_path, "w") as f:
for y in range(OBJECT_SIZE // len(TESTFILE_FILL)):
f.write(TESTFILE_FILL)
local_path = None
# Establish where (ie absolute path) to place the downloaded file, i.e. the get() target.
try:
with tempfile.NamedTemporaryFile(
prefix="local_file_issue_722.dat", delete=True
) as t:
local_path = t.name

        # Tell the parent process the name of the local file being downloaded ("got") from iRODS.
print(local_path)
sys.stdout.flush()

        def handler(sig, *_):
            abort_parallel_transfers()
            sys.exit(128 + sig)

signal.signal(signal.SIGTERM, handler)

try:
# download the object
session.data_objects.get(object_path, local_path)
except KeyboardInterrupt:
abort_parallel_transfers()
raise

finally:
# Clean up, whether or not the download succeeded.
if local_path is not None and os.path.exists(local_path):
os.unlink(local_path)
if session.data_objects.exists(object_path):
session.data_objects.unlink(object_path, force=True)