@@ -19,16 +19,22 @@ class ParallelProcessing: ...
1919from .utils .config import Settings
2020from .utils .algorithm import chunk_split
2121
22- from ._types import ThreadStatus , Data_In , Data_Out , Overflow_In , TargetFunction , HookFunction
22+ from ._types import (
23+ ThreadStatus , Data_In , Data_Out , Overflow_In ,
24+ TargetFunction , _Target_P , _Target_T ,
25+ DatasetFunction , _Dataset_T ,
26+ HookFunction
27+ )
28+ from typing_extensions import Generic , ParamSpec
2329from typing import (
24- Any , List ,
25- Callable , Optional ,
30+ List ,
31+ Callable , Optional , Union ,
2632 Mapping , Sequence , Tuple
2733)
2834
2935
3036Threads : set ['Thread' ] = set ()
31- class Thread (threading .Thread ):
37+ class Thread (threading .Thread , Generic [ _Target_P , _Target_T ] ):
3238 """
3339 Wraps python's `threading.Thread` class
3440 ---------------------------------------
@@ -38,7 +44,7 @@ class Thread(threading.Thread):
3844
3945 status : ThreadStatus
4046 hooks : List [HookFunction ]
41- returned_value : Data_Out
47+ _returned_value : Data_Out
4248
4349 errors : List [Exception ]
4450 ignore_errors : Sequence [type [Exception ]]
@@ -51,7 +57,7 @@ class Thread(threading.Thread):
5157
5258 def __init__ (
5359 self ,
54- target : TargetFunction ,
60+ target : TargetFunction [ _Target_P , _Target_T ] ,
5561 args : Sequence [Data_In ] = (),
5662 kwargs : Mapping [str , Data_In ] = {},
5763 ignore_errors : Sequence [type [Exception ]] = (),
@@ -80,7 +86,7 @@ def __init__(
8086 :param **: These are arguments parsed to `thread.Thread`
8187 """
8288 _target = self ._wrap_target (target )
83- self .returned_value = None
89+ self ._returned_value = None
8490 self .status = 'Idle'
8591 self .hooks = []
8692
@@ -100,17 +106,17 @@ def __init__(
100106 )
101107
102108
103- def _wrap_target (self , target : TargetFunction ) -> TargetFunction :
109+ def _wrap_target (self , target : TargetFunction [ _Target_P , _Target_T ] ) -> TargetFunction [ _Target_P , Union [ _Target_T , None ]] :
104110 """Wraps the target function"""
105111 @wraps (target )
106- def wrapper (* args : Any , ** kwargs : Any ) -> Any :
112+ def wrapper (* args : _Target_P . args , ** kwargs : _Target_P . kwargs ) -> Union [ _Target_T , None ] :
107113 self .status = 'Running'
108114
109115 global Threads
110116 Threads .add (self )
111117
112118 try :
113- self .returned_value = target (* args , ** kwargs )
119+ self ._returned_value = target (* args , ** kwargs )
114120 except Exception as e :
115121 if not any (isinstance (e , ignore ) for ignore in self .ignore_errors ):
116122 self .status = 'Errored'
@@ -129,7 +135,7 @@ def _invoke_hooks(self) -> None:
129135 errors : List [Tuple [Exception , str ]] = []
130136 for hook in self .hooks :
131137 try :
132- hook (self .returned_value )
138+ hook (self ._returned_value )
133139 except Exception as e :
134140 if not any (isinstance (e , ignore ) for ignore in self .ignore_errors ):
135141 errors .append ((
@@ -173,7 +179,7 @@ def _run_with_trace(self) -> None:
173179
174180
175181 @property
176- def result (self ) -> Data_Out :
182+ def result (self ) -> _Target_T :
177183 """
178184 The return value of the thread
179185
@@ -190,7 +196,7 @@ def result(self) -> Data_Out:
190196
191197 self ._handle_exceptions ()
192198 if self .status in ['Invoking hooks' , 'Completed' ]:
193- return self .returned_value
199+ return self ._returned_value
194200 else :
195201 raise exceptions .ThreadStillRunningError ()
196202
@@ -208,7 +214,7 @@ def is_alive(self) -> bool:
208214 return super ().is_alive ()
209215
210216
211- def add_hook (self , hook : HookFunction ) -> None :
217+ def add_hook (self , hook : HookFunction [ _Target_T ] ) -> None :
212218 """
213219 Adds a hook to the thread
214220 -------------------------
@@ -250,7 +256,7 @@ def join(self, timeout: Optional[float] = None) -> bool:
250256 return not self .is_alive ()
251257
252258
253- def get_return_value (self ) -> Data_Out :
259+ def get_return_value (self ) -> _Target_T :
254260 """
255261 Halts the current thread execution until the thread completes
256262
@@ -315,6 +321,7 @@ def start(self) -> None:
315321
316322
317323
324+ _P = ParamSpec ('_P' )
318325class _ThreadWorker :
319326 progress : float
320327 thread : Thread
@@ -323,7 +330,7 @@ def __init__(self, thread: Thread, progress: float = 0) -> None:
323330 self .thread = thread
324331 self .progress = progress
325332
326- class ParallelProcessing :
333+ class ParallelProcessing ( Generic [ _Target_P , _Target_T , _Dataset_T ]) :
327334 """
328335 Multi-Threaded Parallel Processing
329336 ---------------------------------------
@@ -335,7 +342,7 @@ class ParallelProcessing:
335342 _completed : int
336343
337344 status : ThreadStatus
338- function : Callable [..., List [ Data_Out ]]
345+ function : TargetFunction
339346 dataset : Sequence [Data_In ]
340347 max_threads : int
341348
@@ -344,8 +351,8 @@ class ParallelProcessing:
344351
345352 def __init__ (
346353 self ,
347- function : TargetFunction ,
348- dataset : Sequence [Data_In ],
354+ function : DatasetFunction [ _Dataset_T , _Target_T ] ,
355+ dataset : Sequence [_Dataset_T ],
349356 max_threads : int = 8 ,
350357
351358 * overflow_args : Overflow_In ,
@@ -386,9 +393,9 @@ def __init__(
386393 def _wrap_function (
387394 self ,
388395 function : TargetFunction
389- ) -> Callable [..., List [ Data_Out ]] :
396+ ) -> TargetFunction :
390397 @wraps (function )
391- def wrapper (index : int , data_chunk : Sequence [Data_In ], * args : Any , ** kwargs : Any ) -> List [Data_Out ]:
398+ def wrapper (index : int , data_chunk : Sequence [_Dataset_T ], * args : _Target_P . args , ** kwargs : _Target_P . kwargs ) -> List [_Target_T ]:
392399 computed : List [Data_Out ] = []
393400 for i , data_entry in enumerate (data_chunk ):
394401 v = function (data_entry , * args , ** kwargs )
@@ -404,7 +411,7 @@ def wrapper(index: int, data_chunk: Sequence[Data_In], *args: Any, **kwargs: Any
404411
405412
406413 @property
407- def results (self ) -> Data_Out :
414+ def results (self ) -> List [ _Dataset_T ] :
408415 """
409416 The return value of the threads if completed
410417
@@ -436,7 +443,7 @@ def is_alive(self) -> bool:
436443 return any (entry .thread .is_alive () for entry in self ._threads )
437444
438445
439- def get_return_values (self ) -> List [Data_Out ]:
446+ def get_return_values (self ) -> List [_Dataset_T ]:
440447 """
441448 Halts the current thread execution until the thread completes
442449
@@ -506,6 +513,8 @@ def start(self) -> None:
506513 name_format = self .overflow_kwargs .get ('name' ) and self .overflow_kwargs ['name' ] + '%s'
507514 self .overflow_kwargs = { i : v for i ,v in self .overflow_kwargs .items () if i != 'name' and i != 'args' }
508515
516+ print (parsed_args , self .overflow_args )
517+
509518 for i , data_chunk in enumerate (chunk_split (self .dataset , max_threads )):
510519 chunk_thread = Thread (
511520 target = self .function ,
0 commit comments