api: formalise expression expansion in group-by #2225

Open · 1 task

MarcoGorelli opened this issue Mar 16, 2025 · 10 comments

@MarcoGorelli
Member

MarcoGorelli commented Mar 16, 2025

We need to formalise the rules around

df.group_by(keys).agg(expr.sum())

when expr isn't a simple single-column expression.

The rules aren't totally clear in Polars either; see the Polars issue I opened, where I noted that agg(pl.col('a', 'b').sum()) differs from agg(pl.nth(0, 1).sum()).

It looks to me like the rule may be:

When expressions in agg are expanded out, group-by keys are excluded, unless they are selected explicitly by name.

For example:

  • group_by('a').agg(nw.all().sum()): column 'a' should be excluded from the all() expansion
  • group_by('a').agg(nw.selectors.by_type(nw.Float32).sum()): column 'a' should be excluded from the selector expansion
  • group_by('a').agg(nw.col('a', 'b').sum()): column 'a' should be included in the expansion, as it's selected explicitly by name
  • group_by('a').agg(nw.nth(0, 1).sum()): column 'a' should be excluded from the expansion, as it's selected by position rather than by name
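To make those concrete, here's a minimal sketch of what the rule would imply (hypothetical behaviour, since nothing is formalised yet; nw.selectors.numeric() stands in for the by_type example):

import narwhals as nw
import polars as pl

df = nw.from_native(
    pl.DataFrame({"a": [1, 1, 2], "b": [3.0, 4.0, 5.0], "c": [6.0, 7.0, 8.0]})
)

df.group_by("a").agg(nw.all().sum())                # 'a' excluded: aggregates 'b' and 'c'
df.group_by("a").agg(nw.selectors.numeric().sum())  # 'a' excluded: aggregates 'b' and 'c'
df.group_by("a").agg(nw.col("a", "b").sum())        # 'a' named explicitly: included in the expansion
df.group_by("a").agg(nw.nth(0, 1).sum())            # 'a' excluded: positional, not selected by name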

Regardless of what Polars does, we may want to decide what we think a good rule would look like, as there's no guarantee that Polars will remain stable here anyway.

@MarcoGorelli changed the title from "api: formalise expression in expansion in group-by" to "api: formalise expression expansion in group-by" on Mar 16, 2025
@dangotbanned
Member

@MarcoGorelli hope it's okay that I edited the description

This way we get the Listed in #XXXX thing:

[Screenshot: issue description showing the "Listed in" cross-reference]

@camriddell
Member

I think the last example would be surprising behavior if the column 'a' was excluded, as the user specifically requests 2 columns and would not receive 2 columns in the result. I think a set of applicable rules could be:

  • Selections with 1:1 input-output relationship should maintain that behavior (e.g. nw.col and nw.nth).
  • Selections with 1:M input-output relationship can have M adjusted post-hoc by removing the group keys (e.g. all selectors).

A separate thought (and divergence from the Polars API) would be to include a special selector to include/exclude the group_keys. The default behavior could be to exclude the group_by keys from the result set when using other selectors:

# df has columns a, b, c
df.group_by('a').agg(nw.all().sum()) # has b, c in result
df.group_by('a').agg((nw.all() | nw.group_keys()).sum()) # has a, b, c in result

This would help keep us away from a static argument in group_by to include/exclude the keys. If one uses this selector outside of a group_by context, it would be a no-op. This does mean that the selector itself would need access to the expansion context, which can be handled at the Python level.
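A sketch of how that could resolve at the Python level (ExpansionContext, resolve_all, and resolve_group_keys are all made-up names, just to illustrate the idea):

from __future__ import annotations

from dataclasses import dataclass


@dataclass(frozen=True)
class ExpansionContext:
    columns: tuple[str, ...]     # columns of the input frame
    group_keys: tuple[str, ...]  # keys of the enclosing group_by; empty outside one


def resolve_all(ctx: ExpansionContext) -> tuple[str, ...]:
    # default: expand to everything except the group-by keys
    return tuple(c for c in ctx.columns if c not in ctx.group_keys)


def resolve_group_keys(ctx: ExpansionContext) -> tuple[str, ...]:
    # the special selector: expands to the keys inside group_by, no-op outside
    return ctx.group_keys


ctx = ExpansionContext(columns=("a", "b", "c"), group_keys=("a",))
assert resolve_all(ctx) == ("b", "c")
assert resolve_all(ctx) + resolve_group_keys(ctx) == ("b", "c", "a")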

@MarcoGorelli
Member Author

I think the last example would be surprising behavior if the column 'a' was excluded, as the user specifically requests 2 columns and would not receive 2 columns in the result

I also find it surprising, and am kinda inclined to consider this a bug in Polars

@MarcoGorelli
Member Author

  • Selections with 1:1 input-output relationship should maintain that behavior (e.g. nw.col and nw.nth).
  • Selections with 1:M input-output relationship can have M adjusted post-hoc by removing the group keys (e.g. all selectors).

Another one to think about is M:1, e.g.

df.group_by('a').agg(nw.sum_horizontal(nw.all()).sum().name.suffix('_agg'))

Here Polars considers all columns (even the grouped-by keys) in the inner nw.all(), as do we in Narwhals, and I think that's fine

@camriddell
Member

camriddell commented Mar 18, 2025

I think having nw.all() expand differently in these cases leads to ambiguity as to when one would expect the grouping keys to be solely part of the result set, or part of the computation itself. For a comparison to other tools:

  • Polars inserts the grouping column as part of the result set (not the computation). This cannot be controlled by the user.
  • SQL doesn't report the grouping column unless explicitly included in the query
import polars as pl
import duckdb
print(
    f"{pl.__version__     = }", # 1.25.2
    f"{duckdb.__version__ = }", # 1.2.1
    sep='\n',
)

df = pl.DataFrame({
    'a': [*'xxyyz'],
    'b': [1,2,3,4,5],
    'c': [7,8,9,10,11],
})


## Group-by & 1:1 expressions
# polars deviates from traditional SQL group-by behavior because the grouped
#   key is inserted into the results by default
print( # 'a' is in output
    df
    .group_by('a')
    .agg(pl.col('b', 'c').sum()),
)

print( # 'a' is not in output, unless explicitly requested
    duckdb.sql('''
        select sum(b), sum(c)
        from df
        group by "a"
    ''')
)

## Group-by & selectors (1:M expressions)
# the grouping column can often not be treated in the same fashion as the value columns
#  typically due to a semantic difference in the columns or a strict data type difference
#  Polars avoids this issue by excluding the grouping column from the computation
#  then inserting it later on

print( # 'a' is in the output, but its sum was not computed
    df # this indicates that `pl.all()` was expanded to NOT include 'a'
    .group_by('a')
    .agg(pl.all().sum()),
)

# duckdb does not perform this exclusion and thus raises an error due to sum(VARCHAR)
# print(
#     duckdb.sql('''
#         select sum(columns(*)),
#         from df
#         group by "a"
#     ''')
# )

## Group-by & selectors (M:1 expressions)
# Polars fully respects the selector in this case, though this feels almost
#  inconsistent with the behavior above as `pl.all()` is being expanded differently
#  in these cases
print( # 'a' is in the output, AND is used in the computation
    # I was actually surprised type promotion happened here. This may be a bug?
    df
    .group_by('a')
    .agg(
        result=pl.sum_horizontal(pl.all())
    ),
)

# duckdb exhibits the same behavior, but raises instead of automatic casting
#   cannot create a list of types VARCHAR and BIGINT
# print(
#     duckdb.sql('''
#         select list_value(*columns(*)),
#         from df
#         group by "a"
#     ''')
# )

So it seems like the following rules exist for these tools for operations within
a group-by context. Below I use K for the group-by columns and V for the remaining
columns.

Selector | Polars Example                                    | Polars Result | K in computation? | DuckDB Example                                    | DuckDB Result | K in computation?
---------|---------------------------------------------------|---------------|-------------------|---------------------------------------------------|---------------|------------------
1:1      | .group_by(K).agg(pl.col(V).sum())                 | K|V           | No                | select sum(V) from … group by K                   | V             | No
1:M      | .group_by(K).agg(pl.all().sum())                  | K|V           | No                | select sum(columns(*)) from … group by K          | K|V           | Yes
M:1      | .group_by(K).agg(res=pl.sum_horizontal(pl.all())) | K|V           | Yes               | select list_value(*columns(*)) from df group by K | K|V           | Yes

I think the underlying cause of confusion here is that Polars includes the grouping column in the result set by default, thus ensuring the result set is always K|V. This leads to ambiguity in the behavior of 1:M selectors that could include the grouping keys (as K would be projected twice).

@MarcoGorelli
Member Author

Thinking about this more, I don't think that agg(nw.sum_horizontal(nw.all())) is problematic, because that's a single-output expression. It's well-defined.

The issue is what happens with multi-output expressions, and how they get expanded

In binary comparisons, we already only allow multi-output expressions on the left-hand side. So, this should be tractable.
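For reference, that restriction looks like this (a sketch; per the rule above, the second line should raise):

import narwhals as nw

nw.all() + nw.col("a")  # OK: multi-output expression on the left-hand side
nw.col("a") + nw.all()  # should raise: multi-output expressions banned on the right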

@MarcoGorelli
Member Author

A few more thoughts:

In ExprMetadata, we could track ExpansionKind (a code sketch follows the rules below), which could be one of:

  • Single. Always only produces a single column output. e.g. nw.col('a'), nw.sum_horizontal(...)
  • MultiAnonymous. Produces multiple outputs, whose names depend on the input dataframe. For example, nw.all(), nw.nth(0, 1), nw.selectors.datetime()
  • MultiNamed. Produces multiple outputs whose names are known in advance. For example, nw.col('a', 'b')

These can change according to the following rules:

  • binary operations (e.g. nw.all() + nw.col('a')). Multi-output expressions are banned from the right-hand side, and the ExpansionKind of the left-hand side is always preserved
  • horizontal reductions (e.g. nw.concat_str(nw.all(), nw.col('b'), 'c')). Regardless of the inputs, the result ExpansionKind is always Single
  • methods (e.g. nw.all().mode()). These always preserve the input ExpansionKind

By the time we get to group_by, the rule can be:

  • For MultiAnonymous expressions, the group-by keys are excluded from the result
  • For anything else, all aggregated expressions are included in the result
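A sketch of how this could fit together (illustrative names only; not narwhals' actual internals):

from enum import Enum, auto


class ExpansionKind(Enum):
    SINGLE = auto()           # e.g. nw.col('a'), nw.sum_horizontal(...)
    MULTI_ANONYMOUS = auto()  # e.g. nw.all(), nw.nth(0, 1), nw.selectors.datetime()
    MULTI_NAMED = auto()      # e.g. nw.col('a', 'b')


def after_binary_op(lhs: ExpansionKind) -> ExpansionKind:
    # multi-output banned on the right-hand side; the left-hand side's kind is preserved
    return lhs


def after_horizontal_reduction() -> ExpansionKind:
    # e.g. nw.concat_str(nw.all(), nw.col('b'), 'c'): always single-output
    return ExpansionKind.SINGLE


def group_by_excludes_keys(kind: ExpansionKind) -> bool:
    # in agg, only MultiAnonymous expansions drop the group-by keys
    return kind is ExpansionKind.MULTI_ANONYMOUS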

This has a few properties that I like:

  • it seems aligned with what Polars does (even though it may not have been formalised on their side)
  • for the vast majority of cases, I think this does what the user intends. Nobody had reported this to Polars before I made the issue, which suggests that what Polars currently does is good enough
  • on the implementation side, the rules seem simple, predictable, and teachable

@dangotbanned
Member

#2225 (comment)

@MarcoGorelli in relation to these rules - what subset of expressions do you think could be permitted in df.group_by(...)?

I know it isn't possible yet, but I thought it might be a nice exercise in dogfooding the categories

@MarcoGorelli
Member Author

Expressions to group_by would have to be ExprKind.TRANSFORM and would have to have had no windows in their history

For the latter part, we should probably also track n_closed_windows. Then:

  • WINDOW: increases n_open_windows by 1
  • OVER:
    • if applied to a WINDOW and order_by was specified, increases n_closed_windows by 1
    • else, both n_open_windows and n_closed_windows increase by 1

So then, in extract_compliant, instead of checking

if arg._metadata.n_open_windows > 0:

we check

if arg._metadata.n_open_windows > arg._metadata.n_closed_windows:

Then

  • in group_by, we check that for the arguments to group_by, n_open_windows is 0. This is to disallow df.group_by(nw.col('a').diff().over(order_by='b')).agg(nw.all().sum()) - you can't group by diff even if it was followed by over(order_by=...)
  • to address #2226 (api: support window functions in filter): in filter, determine whether to call filter or qualify based on whether n_open_windows > 0
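Expressed as plain functions, the bookkeeping above might look like this (a sketch, with illustrative names):

def after_window(n_open: int, n_closed: int) -> tuple[int, int]:
    # WINDOW (e.g. .diff()) opens a window
    return n_open + 1, n_closed


def after_over(
    n_open: int, n_closed: int, *, on_window: bool, has_order_by: bool
) -> tuple[int, int]:
    if on_window and has_order_by:
        # over(order_by=...) applied to a window closes it
        return n_open, n_closed + 1
    # otherwise, over both opens and closes a window
    return n_open + 1, n_closed + 1


# nw.col('a').diff().over(order_by='b'):
n_open, n_closed = after_window(0, 0)                                                # (1, 0)
n_open, n_closed = after_over(n_open, n_closed, on_window=True, has_order_by=True)   # (1, 1)
assert n_open == n_closed  # no open window, so fine in extract_compliant
assert n_open > 0          # but still banned as a group_by key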

Does this make sense?

Happy for you to do this n_open_windows / n_closed_windows refactor if it interests you (but only if it interests you, don't want to come across as demanding free work or anything)?

@dangotbanned
Member

#2225 (comment)

Really appreciate the detail, thanks @MarcoGorelli!

Does this make sense?

Mostly, the only bit I'm missing is the why for this part:

and would have to have had no windows in their history

but the how seems sound 🙂



Happy for you to do this n_open_windows / n_closed_windows refactor if it interests you (but only if it interests you, don't want to come across as demanding free work or anything)?

Oh I'm interested - don't be silly there's no demand here 😅

Your comment reminds me of (#2078 (comment)) and I think it could be helpful here.

The API change could look like this:

ExprMetadata.with_window_open()

metadata: ExprMetadata

# Current
metadata.with_extra_open_window()

# Proposed
metadata.with_window_open()

ExprMetadata.with_window_close()

# Current
n_open_windows = metadata.n_open_windows
if _order_by is not None and metadata.kind.is_window():
    n_open_windows -= 1
next_meta = ExprMetadata(
    kind,
    n_open_windows=n_open_windows,
    expansion_kind=metadata.expansion_kind,
)

# Proposed
metadata.with_window_close()

ExprMetadata.is_window_open()

# Current
if metadata.n_open_windows > 0:
if metadata.n_open_windows > metadata.n_closed_windows:

# Proposed
if metadata.is_window_open():
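Putting the three together, a minimal sketch of how they could wrap the counters (assuming a frozen dataclass; this isn't the real ExprMetadata):

from dataclasses import dataclass, replace


@dataclass(frozen=True)
class ExprMetadata:
    n_open_windows: int = 0
    n_closed_windows: int = 0

    def with_window_open(self) -> "ExprMetadata":
        return replace(self, n_open_windows=self.n_open_windows + 1)

    def with_window_close(self) -> "ExprMetadata":
        return replace(self, n_closed_windows=self.n_closed_windows + 1)

    def is_window_open(self) -> bool:
        return self.n_open_windows > self.n_closed_windows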

Those methods could still work internally with n_(open|closed)_windows - but you wouldn't need that detail in all of:

Where n_open_windows?

if arg._metadata.n_open_windows > 0:

narwhals/narwhals/expr.py

Lines 1590 to 1596 in b563ba7

n_open_windows = self._metadata.n_open_windows
if flat_order_by is not None and self._metadata.kind.is_window():
    n_open_windows -= 1
current_meta = self._metadata
next_meta = ExprMetadata(
    kind,
    n_open_windows=n_open_windows,

if arg._metadata.n_open_windows:
    result_n_open_windows += 1

Bringing it all the way back to:

and would have to have had no windows in their history

I'm not sure if that part is captured yet.
I imagine the check is if either n_(open|closed)_windows is non-zero?
