Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions docs/changelog/139797.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
pr: 139797
summary: Fix aggregation on null value
area: ES|QL
type: bug
issues:
- 137544
Original file line number Diff line number Diff line change
Expand Up @@ -492,7 +492,7 @@ public enum Type {
DOUBLE_RANGE(s -> s == null ? null : Arrays.stream(s.split("-")).map(Double::parseDouble).toArray(), double[].class),
DATE_RANGE(s -> s == null ? null : Arrays.stream(s.split("-")).map(BytesRef::new).toArray(), BytesRef[].class),
VERSION(v -> new org.elasticsearch.xpack.versionfield.Version(v).toBytesRef(), BytesRef.class),
NULL(s -> null, Void.class),
NULL(s -> s, Void.class),
DATETIME(
x -> x == null ? null : DateFormatters.from(UTC_DATE_TIME_FORMATTER.parse(x)).toInstant().toEpochMilli(),
(l, r) -> l instanceof Long maybeIP ? maybeIP.compareTo((Long) r) : l.toString().compareTo(r.toString()),
Expand Down
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It'd be lovely if we could add a couple more tests, as we're adding something pretty new here:

  • tests with different agg functions, esp. the special ones from ReplaceStatsFilteredAggWithEval#mapNullToValue
  • tests with more than 1 agg function where 1 or more get null literals, and some where another agg function does not get a null value.
  • tests with BY
  • tests with INLINE STATS
  • tests with per-agg WHERE clauses

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

tests with different agg functions, esp. the special ones from ReplaceStatsFilteredAggWithEval#mapNullToValue

There are already tests for count/count_distinct:

  • stats.countNull
  • stats.countDistinctNull

Those were both working because those two functions are not nullable and they were escaping null-folding because of that.

I'll definitely add tests for all other points you bring up!

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Also, since this closes #137544, let's add the repro queries from the issue to the spec tests?

Original file line number Diff line number Diff line change
Expand Up @@ -3599,3 +3599,36 @@ from airports
a:double | b:double | c:long
6.0 | 6.0 | 8
;

fixStatsValuesOnReferenceToNullUncasted
required_capability: fix_agg_on_null_by_replacing_with_eval

ROW x = null
| STATS VALUES(x)
;

VALUES(x):null
null
;

fixStatsValuesOnReferenceToNullCasted
required_capability: fix_agg_on_null_by_replacing_with_eval

ROW x = null::long
| STATS VALUES(x)
;

VALUES(x):long
null
;

fixStatsValuesOnLiteralNull
required_capability: fix_agg_on_null_by_replacing_with_eval

ROW x = 1
| STATS y = VALUES(null)
;

y:null
null
;
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
import org.elasticsearch.compute.lucene.read.ValuesSourceReaderOperator;
import org.elasticsearch.features.NodeFeature;
import org.elasticsearch.rest.action.admin.cluster.RestNodesCapabilitiesAction;
import org.elasticsearch.xpack.esql.optimizer.rules.logical.ReplaceStatsFilteredOrNullAggWithEval;
import org.elasticsearch.xpack.esql.plugin.EsqlFeatures;

import java.util.ArrayList;
Expand Down Expand Up @@ -1790,7 +1791,7 @@ public enum Cap {
FIX_INLINE_STATS_INCORRECT_PRUNNING(INLINE_STATS.enabled),

/**
* {@link org.elasticsearch.xpack.esql.optimizer.rules.logical.ReplaceStatsFilteredAggWithEval} replaced a stats
* {@link ReplaceStatsFilteredOrNullAggWithEval} replaced a stats
* with false filter with null with {@link org.elasticsearch.xpack.esql.expression.function.aggregate.Present} or
* {@link org.elasticsearch.xpack.esql.expression.function.aggregate.Absent}
*/
Expand All @@ -1806,6 +1807,14 @@ public enum Cap {
*/
ENABLE_REDUCE_NODE_LATE_MATERIALIZATION(Build.current().isSnapshot()),

/**
* {@link ReplaceStatsFilteredOrNullAggWithEval} now replaces an
* {@link org.elasticsearch.xpack.esql.expression.function.aggregate.AggregateFunction} with null value with an
* {@link org.elasticsearch.xpack.esql.plan.logical.Eval}.
* https://github.com/elastic/elasticsearch/issues/137544
*/
FIX_AGG_ON_NULL_BY_REPLACING_WITH_EVAL,

// Last capability should still have a comma for fewer merge conflicts when adding new ones :)
// This comment prevents the semicolon from being on the previous capability when Spotless formats the file.
;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@
import org.elasticsearch.xpack.esql.optimizer.rules.logical.ReplaceOrderByExpressionWithEval;
import org.elasticsearch.xpack.esql.optimizer.rules.logical.ReplaceRegexMatch;
import org.elasticsearch.xpack.esql.optimizer.rules.logical.ReplaceRowAsLocalRelation;
import org.elasticsearch.xpack.esql.optimizer.rules.logical.ReplaceStatsFilteredAggWithEval;
import org.elasticsearch.xpack.esql.optimizer.rules.logical.ReplaceStatsFilteredOrNullAggWithEval;
import org.elasticsearch.xpack.esql.optimizer.rules.logical.ReplaceStringCasingWithInsensitiveEquals;
import org.elasticsearch.xpack.esql.optimizer.rules.logical.ReplaceTrivialTypeConversions;
import org.elasticsearch.xpack.esql.optimizer.rules.logical.SetAsOptimized;
Expand Down Expand Up @@ -206,7 +206,7 @@ protected static Batch<LogicalPlan> operators() {
// TODO: bifunction can now (since we now have just one data types set) be pushed into the rule
new SimplifyComparisonsArithmetics(DataType::areCompatible),
new ReplaceStringCasingWithInsensitiveEquals(),
new ReplaceStatsFilteredAggWithEval(),
new ReplaceStatsFilteredOrNullAggWithEval(),
new ExtractAggregateCommonFilter(),
// prune/elimination
new PruneFilters(),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,9 @@ public Expression rule(Expression e, LogicalOptimizerContext ctx) {
// Non-evaluatable functions stay as a STATS grouping (It isn't moved to an early EVAL like other groupings),
// so folding it to null would currently break the plan, as we don't create an attribute/channel for that null value.
&& e instanceof GroupingFunction.NonEvaluatableGroupingFunction == false
// We cannot fold aggregate functions until we resolve https://github.com/elastic/elasticsearch/issues/100634.
// AggregateMapper cannot handle aggregate functions with literal values.
&& e instanceof AggregateFunction == false
&& Expressions.anyMatch(e.children(), Expressions::isGuaranteedNull)) {
return Literal.of(e, null);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
import org.elasticsearch.xpack.esql.core.expression.Literal;
import org.elasticsearch.xpack.esql.core.expression.NamedExpression;
import org.elasticsearch.xpack.esql.core.tree.Source;
import org.elasticsearch.xpack.esql.core.type.DataType;
import org.elasticsearch.xpack.esql.core.util.Holder;
import org.elasticsearch.xpack.esql.expression.function.aggregate.Absent;
import org.elasticsearch.xpack.esql.expression.function.aggregate.AggregateFunction;
Expand All @@ -34,19 +35,30 @@
import java.util.List;

/**
* Replaces an aggregation function having a false/null filter with an EVAL node.
* Replaces an aggregation function with an EVAL node under 2 conditions.
*
* First, having a false/null filter
* <pre>
* ... | STATS/INLINE STATS x = someAgg(y) WHERE FALSE {BY z} | ...
* =>
* ... | STATS/INLINE STATS x = someAgg(y) {BY z} > | EVAL x = NULL | KEEP x{, z} | ...
* ... | EVAL x = NULL | KEEP x{, z} | ...
* </pre>
*
* Second, having an agg on a null value
* <pre>
* ... | STATS/INLINE STATS x = someAgg(null) {BY z} | ...
* =>
* ... | EVAL x = NULL | KEEP x{, z} | ...
* </pre>
*
* This rule is applied to both STATS' {@link Aggregate} and {@link InlineJoin} right-hand side {@link Aggregate} plans.
* The logic is common for both, but the handling of the {@link InlineJoin} is slightly different when it comes to pruning
* its right-hand side {@link Aggregate}.
* Skipped in local optimizer: once a fragment contains an Agg, this can no longer be pruned, which the rule can do
*/
public class ReplaceStatsFilteredAggWithEval extends OptimizerRules.OptimizerRule<LogicalPlan> implements OptimizerRules.CoordinatorOnly {
public class ReplaceStatsFilteredOrNullAggWithEval extends OptimizerRules.OptimizerRule<LogicalPlan>
implements
OptimizerRules.CoordinatorOnly {
@Override
protected LogicalPlan rule(LogicalPlan plan) {
Aggregate aggregate;
Expand All @@ -69,11 +81,7 @@ protected LogicalPlan rule(LogicalPlan plan) {
List<NamedExpression> newProjections = new ArrayList<>(oldAggSize);

for (var ne : aggregate.aggregates()) {
if (ne instanceof Alias alias
&& alias.child() instanceof AggregateFunction aggFunction
&& aggFunction.hasFilter()
&& aggFunction.filter() instanceof Literal literal
&& Boolean.FALSE.equals(literal.value())) {
if (ne instanceof Alias alias && alias.child() instanceof AggregateFunction aggFunction && shouldReplace(aggFunction)) {

Object value = mapNullToValue(aggFunction);
Alias newAlias = alias.replaceChild(Literal.of(aggFunction, value));
Expand Down Expand Up @@ -119,6 +127,14 @@ protected LogicalPlan rule(LogicalPlan plan) {
return plan;
}

private static boolean shouldReplace(AggregateFunction aggFunction) {
return hasFalseFilter(aggFunction) || DataType.isNull(aggFunction.field().dataType());
}

private static boolean hasFalseFilter(AggregateFunction aggFunction) {
return aggFunction.hasFilter() && aggFunction.filter() instanceof Literal literal && Boolean.FALSE.equals(literal.value());
}

private static Object mapNullToValue(AggregateFunction aggFunction) {
return switch (aggFunction) {
case Count ignored -> 0L;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
import org.elasticsearch.core.Tuple;
import org.elasticsearch.dissect.DissectParser;
import org.elasticsearch.index.IndexMode;
import org.elasticsearch.test.junit.annotations.TestLogging;
import org.elasticsearch.xpack.esql.EsqlTestUtils;
import org.elasticsearch.xpack.esql.VerificationException;
import org.elasticsearch.xpack.esql.action.EsqlCapabilities;
Expand Down Expand Up @@ -231,7 +232,7 @@
import static org.hamcrest.Matchers.not;
import static org.hamcrest.Matchers.startsWith;

//@TestLogging(value = "org.elasticsearch.xpack.esql:TRACE", reason = "debug")
@TestLogging(value = "org.elasticsearch.xpack.esql:TRACE", reason = "debug")
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Leftover?

public class LogicalPlanOptimizerTests extends AbstractLogicalPlanOptimizerTests {
private static final LiteralsOnTheRight LITERALS_ON_THE_RIGHT = new LiteralsOnTheRight();

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,8 @@
import org.elasticsearch.xpack.esql.expression.function.aggregate.MedianAbsoluteDeviation;
import org.elasticsearch.xpack.esql.expression.function.aggregate.Min;
import org.elasticsearch.xpack.esql.expression.function.aggregate.Percentile;
import org.elasticsearch.xpack.esql.expression.function.aggregate.SpatialCentroid;
import org.elasticsearch.xpack.esql.expression.function.aggregate.Sum;
import org.elasticsearch.xpack.esql.expression.function.aggregate.Values;
import org.elasticsearch.xpack.esql.expression.function.grouping.Bucket;
import org.elasticsearch.xpack.esql.expression.function.grouping.Categorize;
import org.elasticsearch.xpack.esql.expression.function.scalar.convert.ToString;
Expand Down Expand Up @@ -158,8 +158,6 @@ public void testGenericNullableExpression() {
assertNullLiteral(foldNull(new Cos(EMPTY, NULL)));
// string functions
assertNullLiteral(foldNull(new LTrim(EMPTY, NULL)));
// spatial
assertNullLiteral(foldNull(new SpatialCentroid(EMPTY, NULL)));
// ip
assertNullLiteral(foldNull(new CIDRMatch(EMPTY, NULL, List.of(NULL))));
// conversion
Expand All @@ -180,50 +178,34 @@ public void testNullFoldingDoesNotApplyOnLogicalExpressions() {

@SuppressWarnings("unchecked")
public void testNullFoldingDoesNotApplyOnAggregate() throws Exception {
List<Class<? extends AggregateFunction>> items = List.of(Max.class, Min.class);
List<Class<? extends AggregateFunction>> items = List.of(
Avg.class,
Count.class,
Max.class,
Median.class,
MedianAbsoluteDeviation.class,
Min.class,
Sum.class,
Values.class
);
for (Class<? extends AggregateFunction> clazz : items) {
Constructor<? extends AggregateFunction> ctor = clazz.getConstructor(Source.class, Expression.class);
AggregateFunction conditionalFunction = ctor.newInstance(EMPTY, getFieldAttribute("a"));
assertEquals(conditionalFunction, foldNull(conditionalFunction));

conditionalFunction = ctor.newInstance(EMPTY, NULL);
assertEquals(NULL, foldNull(conditionalFunction));
assertEquals(conditionalFunction, foldNull(conditionalFunction));
}

Avg avg = new Avg(EMPTY, getFieldAttribute("a"));
assertEquals(avg, foldNull(avg));
avg = new Avg(EMPTY, NULL);
assertEquals(new Literal(EMPTY, null, DOUBLE), foldNull(avg));

Count count = new Count(EMPTY, getFieldAttribute("a"));
assertEquals(count, foldNull(count));
count = new Count(EMPTY, NULL);
assertEquals(count, foldNull(count));

CountDistinct countd = new CountDistinct(EMPTY, getFieldAttribute("a"), getFieldAttribute("a"));
assertEquals(countd, foldNull(countd));
countd = new CountDistinct(EMPTY, NULL, NULL);
assertEquals(new Literal(EMPTY, null, LONG), foldNull(countd));

Median median = new Median(EMPTY, getFieldAttribute("a"));
assertEquals(median, foldNull(median));
median = new Median(EMPTY, NULL);
assertEquals(new Literal(EMPTY, null, DOUBLE), foldNull(median));

MedianAbsoluteDeviation medianad = new MedianAbsoluteDeviation(EMPTY, getFieldAttribute("a"));
assertEquals(medianad, foldNull(medianad));
medianad = new MedianAbsoluteDeviation(EMPTY, NULL);
assertEquals(new Literal(EMPTY, null, DOUBLE), foldNull(medianad));
assertEquals(countd, foldNull(countd));

Percentile percentile = new Percentile(EMPTY, getFieldAttribute("a"), getFieldAttribute("a"));
assertEquals(percentile, foldNull(percentile));
percentile = new Percentile(EMPTY, NULL, NULL);
assertEquals(new Literal(EMPTY, null, DOUBLE), foldNull(percentile));

Sum sum = new Sum(EMPTY, getFieldAttribute("a"));
assertEquals(sum, foldNull(sum));
sum = new Sum(EMPTY, NULL);
assertEquals(new Literal(EMPTY, null, DOUBLE), foldNull(sum));
assertEquals(percentile, foldNull(percentile));
}

public void testNullFoldableDoesNotApplyToIsNullAndNotNull() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,14 +32,15 @@
import static org.elasticsearch.xpack.esql.EsqlTestUtils.as;
import static org.elasticsearch.xpack.esql.core.type.DataType.INTEGER;
import static org.elasticsearch.xpack.esql.core.type.DataType.LONG;
import static org.elasticsearch.xpack.esql.core.type.DataType.NULL;
import static org.elasticsearch.xpack.esql.optimizer.LogicalPlanOptimizerTests.releaseBuildForInlineStats;
import static org.hamcrest.Matchers.contains;
import static org.hamcrest.Matchers.containsString;
import static org.hamcrest.Matchers.is;
import static org.hamcrest.Matchers.nullValue;
import static org.hamcrest.Matchers.startsWith;

public class ReplaceStatsFilteredAggWithEvalTests extends AbstractLogicalPlanOptimizerTests {
public class ReplaceStatsFilteredOrNullAggWithEvalTests extends AbstractLogicalPlanOptimizerTests {

/**
* <pre>{@code
Expand Down Expand Up @@ -800,4 +801,85 @@ public void testReplaceTwoConsecutiveInlineStats_WithFalseFilters() {
assertThat(aliasCc.child().fold(FoldContext.small()), is(0L));
as(eval.child(), EsRelation.class);
}

/**
* <pre>{@code
* Limit[1000[INTEGER],false,false]
* \_LocalRelation[[max(x){r}#6],Page{blocks=[ConstantNullBlock[positions=1]]}]
* }</pre>
*/
public void testReplaceStatsMaxOnNullReferenceWithEvalSingleAgg() {
var plan = plan("""
row x = null
| stats max(x)
""");

var project = as(plan, Limit.class);
var source = as(project.child(), LocalRelation.class);
assertThat(Expressions.names(source.output()), contains("max(x)"));
Page page = source.supplier().get();
assertThat(page.getBlockCount(), is(1));
assertThat(page.getBlock(0).getPositionCount(), is(1));
assertTrue(page.getBlock(0).areAllValuesNull());
}

/**
* <pre>{@code
* Project[[y{r}#6]]
* \_Eval[[null[NULL] AS y#6]]
* \_Limit[1000[INTEGER],false,false]
* \_LocalRelation[[{e}#7],Page{blocks=[ConstantNullBlock[positions=1]]}]
* }</pre>
*/
public void testReplaceStatsMaxOnNullLiteralWithEvalSingleAgg() {
var plan = plan("""
row x = 3
| stats y = max(null)
""");

var project = as(plan, Project.class);
assertThat(Expressions.names(project.projections()), contains("y"));
var eval = as(project.child(), Eval.class);
assertThat(eval.fields().size(), is(1));

var alias = as(eval.fields().getFirst(), Alias.class);
assertTrue(alias.child().foldable());
assertThat(alias.child().fold(FoldContext.small()), nullValue());
assertThat(alias.child().dataType(), is(NULL));

var limit = as(eval.child(), Limit.class);
var source = as(limit.child(), LocalRelation.class);
}

/**
* <pre>{@code
* Project[[max(x){r}#9, sum(y){r}#11, x{r}#4]]
* \_Eval[[null[NULL] AS max(x)#9]]
* \_Limit[1000[INTEGER],false,false]
* \_Aggregate[[x{r}#4],[SUM(y{r}#6,true[BOOLEAN],PT0S[TIME_DURATION],compensated[KEYWORD]) AS sum(y)#11, x{r}#4]]
* \_LocalRelation[[x{r}#4, y{r}#6],Page{blocks=[ConstantNullBlock[positions=1], IntVectorBlock[vector=..]]}]
* }</pre>
*/
public void testReplaceStatsMaxOnNullWithEvalAndAgg() {
var plan = plan("""
row x = null, y = 1
| stats max(x),
sum(y)
by x
""");

var project = as(plan, Project.class);
assertThat(Expressions.names(project.projections()), contains("max(x)", "sum(y)", "x"));
var eval = as(project.child(), Eval.class);
assertThat(eval.fields().size(), is(1));

var alias = as(eval.fields().getFirst(), Alias.class);
assertTrue(alias.child().foldable());
assertThat(alias.child().fold(FoldContext.small()), nullValue());
assertThat(alias.child().dataType(), is(NULL));

var limit = as(eval.child(), Limit.class);
var aggregate = as(limit.child(), Aggregate.class);
var source = as(aggregate.child(), LocalRelation.class);
}
}