Skip to content

Commit d93aa2e

Browse files
committed
bugfix: execution of all partitions
1 parent 0495e8a commit d93aa2e

File tree

1 file changed

+110
-23
lines changed

1 file changed

+110
-23
lines changed

src/query/mod.rs

Lines changed: 110 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -34,19 +34,24 @@ use datafusion::logical_expr::expr::Alias;
3434
use datafusion::logical_expr::{
3535
Aggregate, Explain, Filter, LogicalPlan, PlanType, Projection, ToStringifiedPlan,
3636
};
37-
use datafusion::physical_plan::ExecutionPlan;
37+
use datafusion::physical_plan::stream::RecordBatchStreamAdapter;
38+
use datafusion::physical_plan::{
39+
ExecutionPlan, ExecutionPlanProperties, collect_partitioned, execute_stream_partitioned,
40+
};
3841
use datafusion::prelude::*;
3942
use datafusion::sql::parser::DFParser;
4043
use datafusion::sql::resolve::resolve_table_references;
4144
use datafusion::sql::sqlparser::dialect::PostgreSqlDialect;
4245
use futures::Stream;
46+
use futures::stream::select_all;
4347
use itertools::Itertools;
4448
use once_cell::sync::Lazy;
4549
use serde::{Deserialize, Serialize};
4650
use serde_json::{Value, json};
4751
use std::ops::Bound;
4852
use std::pin::Pin;
4953
use std::sync::Arc;
54+
use std::sync::atomic::{AtomicUsize, Ordering};
5055
use std::task::{Context, Poll};
5156
use sysinfo::System;
5257
use tokio::runtime::Runtime;
@@ -85,7 +90,27 @@ pub async fn execute(
8590
is_streaming: bool,
8691
) -> Result<
8792
(
88-
Either<Vec<RecordBatch>, Pin<Box<MetricMonitorStream>>>,
93+
Either<
94+
Vec<RecordBatch>,
95+
Pin<
96+
Box<
97+
RecordBatchStreamAdapter<
98+
select_all::SelectAll<
99+
Pin<
100+
Box<
101+
dyn RecordBatchStream<
102+
Item = Result<
103+
RecordBatch,
104+
datafusion::error::DataFusionError,
105+
>,
106+
> + Send,
107+
>,
108+
>,
109+
>,
110+
>,
111+
>,
112+
>,
113+
>,
89114
Vec<String>,
90115
),
91116
ExecuteError,
@@ -186,7 +211,27 @@ impl Query {
186211
is_streaming: bool,
187212
) -> Result<
188213
(
189-
Either<Vec<RecordBatch>, Pin<Box<MetricMonitorStream>>>,
214+
Either<
215+
Vec<RecordBatch>,
216+
Pin<
217+
Box<
218+
RecordBatchStreamAdapter<
219+
select_all::SelectAll<
220+
Pin<
221+
Box<
222+
dyn RecordBatchStream<
223+
Item = Result<
224+
RecordBatch,
225+
datafusion::error::DataFusionError,
226+
>,
227+
> + Send,
228+
>,
229+
>,
230+
>,
231+
>,
232+
>,
233+
>,
234+
>,
190235
Vec<String>,
191236
),
192237
ExecuteError,
@@ -215,8 +260,11 @@ impl Query {
215260
let results = if !is_streaming {
216261
let task_ctx = QUERY_SESSION.task_ctx();
217262

218-
let stream = plan.execute(0, task_ctx)?;
219-
let batches = datafusion::physical_plan::common::collect(stream).await?;
263+
let batches = collect_partitioned(plan.clone(), task_ctx.clone())
264+
.await?
265+
.into_iter()
266+
.flatten()
267+
.collect();
220268

221269
let actual_io_bytes = get_total_bytes_scanned(&plan);
222270

@@ -228,11 +276,25 @@ impl Query {
228276
} else {
229277
let task_ctx = QUERY_SESSION.task_ctx();
230278

231-
let stream = plan.execute(0, task_ctx)?;
279+
let output_partitions = plan.output_partitioning().partition_count();
280+
281+
let monitor_state = Arc::new(MonitorState {
282+
plan: plan.clone(),
283+
active_streams: AtomicUsize::new(output_partitions),
284+
});
285+
286+
let streams = execute_stream_partitioned(plan.clone(), task_ctx.clone())?
287+
.into_iter()
288+
.map(|s| {
289+
let wrapped = PartitionedMetricMonitor::new(s, monitor_state.clone());
290+
Box::pin(wrapped) as SendableRecordBatchStream
291+
})
292+
.collect_vec();
232293

233-
let monitored_stream = MetricMonitorStream::new(stream, plan.clone());
294+
let merged_stream = futures::stream::select_all(streams);
234295

235-
Either::Right(Box::pin(monitored_stream))
296+
let final_stream = RecordBatchStreamAdapter::new(plan.schema(), merged_stream);
297+
Either::Right(Box::pin(final_stream))
236298
};
237299

238300
Ok((results, fields))
@@ -789,39 +851,52 @@ pub mod error {
789851
}
790852
}
791853

854+
/// Shared state across all partitions
855+
struct MonitorState {
856+
plan: Arc<dyn ExecutionPlan>,
857+
active_streams: AtomicUsize,
858+
}
859+
792860
/// A wrapper that monitors the ExecutionPlan and logs metrics when the stream finishes.
793-
pub struct MetricMonitorStream {
861+
pub struct PartitionedMetricMonitor {
794862
// The actual stream doing the work
795863
inner: SendableRecordBatchStream,
796-
// We hold the plan so we can read metrics after execution
797-
plan: Arc<dyn ExecutionPlan>,
864+
/// State of the streams
865+
state: Arc<MonitorState>,
866+
// Ensure we only emit metrics once even if polled after completion/error
867+
is_finished: bool,
798868
}
799869

800-
impl MetricMonitorStream {
801-
pub fn new(inner: SendableRecordBatchStream, plan: Arc<dyn ExecutionPlan>) -> Self {
802-
Self { inner, plan }
870+
impl PartitionedMetricMonitor {
871+
fn new(inner: SendableRecordBatchStream, state: Arc<MonitorState>) -> Self {
872+
Self {
873+
inner,
874+
state,
875+
is_finished: false,
876+
}
803877
}
804878
}
805879

806-
impl Stream for MetricMonitorStream {
880+
impl Stream for PartitionedMetricMonitor {
807881
type Item = datafusion::error::Result<RecordBatch>;
808882

809883
fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
884+
if self.is_finished {
885+
return Poll::Ready(None);
886+
}
887+
810888
let poll = self.inner.as_mut().poll_next(cx);
811889

812890
// Check if the stream just finished
813891
match &poll {
814892
Poll::Ready(None) => {
815-
// Stream is done. Now we can safely read the metrics.
816-
let bytes = get_total_bytes_scanned(&self.plan);
817-
let current_date = chrono::Utc::now().date_naive().to_string();
818-
increment_bytes_scanned_in_query_by_date(bytes, &current_date);
893+
self.is_finished = true;
894+
self.check_if_last_stream();
819895
}
820896
Poll::Ready(Some(Err(e))) => {
821-
let bytes = get_total_bytes_scanned(&self.plan);
822-
let current_date = chrono::Utc::now().date_naive().to_string();
823-
increment_bytes_scanned_in_query_by_date(bytes, &current_date);
824897
tracing::error!("Stream Failed with error: {}", e);
898+
self.is_finished = true;
899+
self.check_if_last_stream();
825900
}
826901
_ => {}
827902
}
@@ -834,12 +909,24 @@ impl Stream for MetricMonitorStream {
834909
}
835910
}
836911

837-
impl RecordBatchStream for MetricMonitorStream {
912+
impl RecordBatchStream for PartitionedMetricMonitor {
838913
fn schema(&self) -> SchemaRef {
839914
self.inner.schema()
840915
}
841916
}
842917

918+
impl PartitionedMetricMonitor {
919+
fn check_if_last_stream(&self) {
920+
let prev_count = self.state.active_streams.fetch_sub(1, Ordering::SeqCst);
921+
922+
if prev_count == 1 {
923+
let bytes = get_total_bytes_scanned(&self.state.plan);
924+
let current_date = chrono::Utc::now().date_naive().to_string();
925+
increment_bytes_scanned_in_query_by_date(bytes, &current_date);
926+
}
927+
}
928+
}
929+
843930
#[cfg(test)]
844931
mod tests {
845932
use serde_json::json;

0 commit comments

Comments
 (0)