@@ -34,19 +34,24 @@ use datafusion::logical_expr::expr::Alias;
3434use datafusion:: logical_expr:: {
3535 Aggregate , Explain , Filter , LogicalPlan , PlanType , Projection , ToStringifiedPlan ,
3636} ;
37- use datafusion:: physical_plan:: ExecutionPlan ;
37+ use datafusion:: physical_plan:: stream:: RecordBatchStreamAdapter ;
38+ use datafusion:: physical_plan:: {
39+ ExecutionPlan , ExecutionPlanProperties , collect_partitioned, execute_stream_partitioned,
40+ } ;
3841use datafusion:: prelude:: * ;
3942use datafusion:: sql:: parser:: DFParser ;
4043use datafusion:: sql:: resolve:: resolve_table_references;
4144use datafusion:: sql:: sqlparser:: dialect:: PostgreSqlDialect ;
4245use futures:: Stream ;
46+ use futures:: stream:: select_all;
4347use itertools:: Itertools ;
4448use once_cell:: sync:: Lazy ;
4549use serde:: { Deserialize , Serialize } ;
4650use serde_json:: { Value , json} ;
4751use std:: ops:: Bound ;
4852use std:: pin:: Pin ;
4953use std:: sync:: Arc ;
54+ use std:: sync:: atomic:: { AtomicUsize , Ordering } ;
5055use std:: task:: { Context , Poll } ;
5156use sysinfo:: System ;
5257use tokio:: runtime:: Runtime ;
@@ -85,7 +90,27 @@ pub async fn execute(
8590 is_streaming : bool ,
8691) -> Result <
8792 (
88- Either < Vec < RecordBatch > , Pin < Box < MetricMonitorStream > > > ,
93+ Either <
94+ Vec < RecordBatch > ,
95+ Pin <
96+ Box <
97+ RecordBatchStreamAdapter <
98+ select_all:: SelectAll <
99+ Pin <
100+ Box <
101+ dyn RecordBatchStream <
102+ Item = Result <
103+ RecordBatch ,
104+ datafusion:: error:: DataFusionError ,
105+ > ,
106+ > + Send ,
107+ > ,
108+ > ,
109+ > ,
110+ > ,
111+ > ,
112+ > ,
113+ > ,
89114 Vec < String > ,
90115 ) ,
91116 ExecuteError ,
@@ -186,7 +211,27 @@ impl Query {
186211 is_streaming : bool ,
187212 ) -> Result <
188213 (
189- Either < Vec < RecordBatch > , Pin < Box < MetricMonitorStream > > > ,
214+ Either <
215+ Vec < RecordBatch > ,
216+ Pin <
217+ Box <
218+ RecordBatchStreamAdapter <
219+ select_all:: SelectAll <
220+ Pin <
221+ Box <
222+ dyn RecordBatchStream <
223+ Item = Result <
224+ RecordBatch ,
225+ datafusion:: error:: DataFusionError ,
226+ > ,
227+ > + Send ,
228+ > ,
229+ > ,
230+ > ,
231+ > ,
232+ > ,
233+ > ,
234+ > ,
190235 Vec < String > ,
191236 ) ,
192237 ExecuteError ,
@@ -215,8 +260,11 @@ impl Query {
215260 let results = if !is_streaming {
216261 let task_ctx = QUERY_SESSION . task_ctx ( ) ;
217262
218- let stream = plan. execute ( 0 , task_ctx) ?;
219- let batches = datafusion:: physical_plan:: common:: collect ( stream) . await ?;
263+ let batches = collect_partitioned ( plan. clone ( ) , task_ctx. clone ( ) )
264+ . await ?
265+ . into_iter ( )
266+ . flatten ( )
267+ . collect ( ) ;
220268
221269 let actual_io_bytes = get_total_bytes_scanned ( & plan) ;
222270
@@ -228,11 +276,25 @@ impl Query {
228276 } else {
229277 let task_ctx = QUERY_SESSION . task_ctx ( ) ;
230278
231- let stream = plan. execute ( 0 , task_ctx) ?;
279+ let output_partitions = plan. output_partitioning ( ) . partition_count ( ) ;
280+
281+ let monitor_state = Arc :: new ( MonitorState {
282+ plan : plan. clone ( ) ,
283+ active_streams : AtomicUsize :: new ( output_partitions) ,
284+ } ) ;
285+
286+ let streams = execute_stream_partitioned ( plan. clone ( ) , task_ctx. clone ( ) ) ?
287+ . into_iter ( )
288+ . map ( |s| {
289+ let wrapped = PartitionedMetricMonitor :: new ( s, monitor_state. clone ( ) ) ;
290+ Box :: pin ( wrapped) as SendableRecordBatchStream
291+ } )
292+ . collect_vec ( ) ;
232293
233- let monitored_stream = MetricMonitorStream :: new ( stream, plan . clone ( ) ) ;
294+ let merged_stream = futures :: stream:: select_all ( streams ) ;
234295
235- Either :: Right ( Box :: pin ( monitored_stream) )
296+ let final_stream = RecordBatchStreamAdapter :: new ( plan. schema ( ) , merged_stream) ;
297+ Either :: Right ( Box :: pin ( final_stream) )
236298 } ;
237299
238300 Ok ( ( results, fields) )
@@ -789,39 +851,52 @@ pub mod error {
789851 }
790852}
791853
854+ /// Shared state across all partitions
855+ struct MonitorState {
856+ plan : Arc < dyn ExecutionPlan > ,
857+ active_streams : AtomicUsize ,
858+ }
859+
792860/// A wrapper that monitors the ExecutionPlan and logs metrics when the stream finishes.
793- pub struct MetricMonitorStream {
861+ pub struct PartitionedMetricMonitor {
794862 // The actual stream doing the work
795863 inner : SendableRecordBatchStream ,
796- // We hold the plan so we can read metrics after execution
797- plan : Arc < dyn ExecutionPlan > ,
864+ /// State of the streams
865+ state : Arc < MonitorState > ,
866+ // Ensure we only emit metrics once even if polled after completion/error
867+ is_finished : bool ,
798868}
799869
800- impl MetricMonitorStream {
801- pub fn new ( inner : SendableRecordBatchStream , plan : Arc < dyn ExecutionPlan > ) -> Self {
802- Self { inner, plan }
870+ impl PartitionedMetricMonitor {
871+ fn new ( inner : SendableRecordBatchStream , state : Arc < MonitorState > ) -> Self {
872+ Self {
873+ inner,
874+ state,
875+ is_finished : false ,
876+ }
803877 }
804878}
805879
806- impl Stream for MetricMonitorStream {
880+ impl Stream for PartitionedMetricMonitor {
807881 type Item = datafusion:: error:: Result < RecordBatch > ;
808882
809883 fn poll_next ( mut self : Pin < & mut Self > , cx : & mut Context < ' _ > ) -> Poll < Option < Self :: Item > > {
884+ if self . is_finished {
885+ return Poll :: Ready ( None ) ;
886+ }
887+
810888 let poll = self . inner . as_mut ( ) . poll_next ( cx) ;
811889
812890 // Check if the stream just finished
813891 match & poll {
814892 Poll :: Ready ( None ) => {
815- // Stream is done. Now we can safely read the metrics.
816- let bytes = get_total_bytes_scanned ( & self . plan ) ;
817- let current_date = chrono:: Utc :: now ( ) . date_naive ( ) . to_string ( ) ;
818- increment_bytes_scanned_in_query_by_date ( bytes, & current_date) ;
893+ self . is_finished = true ;
894+ self . check_if_last_stream ( ) ;
819895 }
820896 Poll :: Ready ( Some ( Err ( e) ) ) => {
821- let bytes = get_total_bytes_scanned ( & self . plan ) ;
822- let current_date = chrono:: Utc :: now ( ) . date_naive ( ) . to_string ( ) ;
823- increment_bytes_scanned_in_query_by_date ( bytes, & current_date) ;
824897 tracing:: error!( "Stream Failed with error: {}" , e) ;
898+ self . is_finished = true ;
899+ self . check_if_last_stream ( ) ;
825900 }
826901 _ => { }
827902 }
@@ -834,12 +909,24 @@ impl Stream for MetricMonitorStream {
834909 }
835910}
836911
837- impl RecordBatchStream for MetricMonitorStream {
912+ impl RecordBatchStream for PartitionedMetricMonitor {
838913 fn schema ( & self ) -> SchemaRef {
839914 self . inner . schema ( )
840915 }
841916}
842917
918+ impl PartitionedMetricMonitor {
919+ fn check_if_last_stream ( & self ) {
920+ let prev_count = self . state . active_streams . fetch_sub ( 1 , Ordering :: SeqCst ) ;
921+
922+ if prev_count == 1 {
923+ let bytes = get_total_bytes_scanned ( & self . state . plan ) ;
924+ let current_date = chrono:: Utc :: now ( ) . date_naive ( ) . to_string ( ) ;
925+ increment_bytes_scanned_in_query_by_date ( bytes, & current_date) ;
926+ }
927+ }
928+ }
929+
843930#[ cfg( test) ]
844931mod tests {
845932 use serde_json:: json;
0 commit comments