Skip to content

feat: Allow expressions as group_by keys #2325

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 124 commits into from
Apr 27, 2025
Merged
Show file tree
Hide file tree
Changes from 123 commits
Commits
Show all changes
124 commits
Select commit Hold shift + click to select a range
c3389d5
WIP: pandas and (almost) polars
FBruzzesi Mar 31, 2025
cbc2275
arrow
FBruzzesi Mar 31, 2025
3c03f42
lazyframes
FBruzzesi Mar 31, 2025
3ba79c0
WIP typing
FBruzzesi Mar 31, 2025
4ce0c15
almost there
FBruzzesi Mar 31, 2025
1b2f38a
better parsing: include multi-output exprs
FBruzzesi Mar 31, 2025
105623e
test multi-output
FBruzzesi Mar 31, 2025
73f85ea
allow lit and agg's
FBruzzesi Apr 1, 2025
b940b2c
some more typing and coverage
FBruzzesi Apr 1, 2025
c6cf243
deduplicate
FBruzzesi Apr 2, 2025
f189b26
fix tests
FBruzzesi Apr 2, 2025
698091d
fix(typing): Narrow `EagerGroupBy._keys` to `list[str]`
dangotbanned Apr 2, 2025
922a02e
fix(typing): Mostly resolve `Polars*`
dangotbanned Apr 2, 2025
6407f52
fix(typing): Narrow from `Series | Expr` to `Expr`
dangotbanned Apr 2, 2025
1868a4d
fix(typing): Un break `extract_native` overloads
dangotbanned Apr 2, 2025
733d9f4
refactor: Generalize `_parse_keys` -> `_evaluate_aliases`
dangotbanned Apr 2, 2025
94e3211
fix(typing): Un-confuse `mypy`
dangotbanned Apr 2, 2025
d38e466
revert: Undo more renaming
dangotbanned Apr 2, 2025
a38fa41
refactor(typing): Remove implicit `Self`
dangotbanned Apr 2, 2025
1f17e52
refactor: Use `CompliantExpr._evaluate_aliases` more
dangotbanned Apr 2, 2025
08ba27f
test: Skip `over` for `duckdb<1.3`
dangotbanned Apr 2, 2025
25307f4
chore: Remove dead code
dangotbanned Apr 2, 2025
72f5dbb
refactor: Move `with_columns` out of `__init__`
dangotbanned Apr 2, 2025
eac99b8
special paths for pandas-like and dask
FBruzzesi Apr 3, 2025
4d5a396
move parsing to group_by
FBruzzesi Apr 3, 2025
0e6a983
Merge remote-tracking branch 'upstream/main' into feat/allow-expr-in-…
dangotbanned Apr 5, 2025
660c2ab
revert(typing): Remove `IncompletePolarsExpr`
dangotbanned Apr 5, 2025
befcc48
chore(typing): Ignore existing issues
dangotbanned Apr 5, 2025
aaf027d
fix(typing): Make everyone compliant again
dangotbanned Apr 5, 2025
8408b01
fix(typing): Use `pyright` ignores
dangotbanned Apr 5, 2025
3f93217
refactor: Move inner function to `_expression_passing`
dangotbanned Apr 5, 2025
890d741
parametrize test
FBruzzesi Apr 6, 2025
8f6632c
no cover
FBruzzesi Apr 6, 2025
f075315
fixup tests
FBruzzesi Apr 6, 2025
3bb4968
rm shapeerror for lazyframe, fix duckdb context condition in test
FBruzzesi Apr 6, 2025
9ea5d11
flatten first
FBruzzesi Apr 6, 2025
3473c56
Merge branch 'main' into feat/allow-expr-in-group-by
FBruzzesi Apr 8, 2025
346f492
Merge branch 'main' into feat/allow-expr-in-group-by
FBruzzesi Apr 8, 2025
ec449ec
Merge branch 'main' into feat/allow-expr-in-group-by
FBruzzesi Apr 9, 2025
8431cee
Merge branch 'feat/allow-expr-in-group-by' of https://github.com/narw…
FBruzzesi Apr 9, 2025
f050486
Merge remote-tracking branch 'upstream/main' into feat/allow-expr-in-…
dangotbanned Apr 9, 2025
02c7ac5
solve conflicts
FBruzzesi Apr 10, 2025
74da552
solve conflicts
FBruzzesi Apr 10, 2025
bb7bda7
Merge branch 'main' into feat/allow-expr-in-group-by
MarcoGorelli Apr 10, 2025
54a7178
WIP
FBruzzesi Apr 11, 2025
386e38c
just missing ordering
FBruzzesi Apr 11, 2025
3faba54
compliant frames
FBruzzesi Apr 11, 2025
8ab528e
ok make stable v1 work
FBruzzesi Apr 11, 2025
49a77b7
extract_native(arg) -> arg.native
FBruzzesi Apr 12, 2025
7e86ac1
fix polars drop_null_keys case
FBruzzesi Apr 12, 2025
f26f6ed
rename parsing method, fix sqlframe
FBruzzesi Apr 12, 2025
326222f
fix eager __iter__
FBruzzesi Apr 12, 2025
78970ef
fix __all__ namespaces, arrow over
FBruzzesi Apr 12, 2025
c43512c
add index name in expected output
FBruzzesi Apr 12, 2025
fff4f41
xfail polars
FBruzzesi Apr 12, 2025
36fe9e3
missing drop_null_keys condition to xfail
FBruzzesi Apr 12, 2025
b5f5888
Merge branch 'main' into feat/allow-expr-in-group-by
FBruzzesi Apr 12, 2025
99987aa
deal with unnamed and key in ops in agg
FBruzzesi Apr 12, 2025
7862d28
Merge branch 'main' into feat/allow-expr-in-group-by
FBruzzesi Apr 12, 2025
a452116
old pandas, make mypy happy?
FBruzzesi Apr 12, 2025
dad4035
pandas?
FBruzzesi Apr 12, 2025
a077261
resolve conflicts
FBruzzesi Apr 12, 2025
9ac7a24
fix duckdb dashed names
FBruzzesi Apr 12, 2025
1b02590
docstrings and raise test
FBruzzesi Apr 12, 2025
b6d5190
merge main
FBruzzesi Apr 12, 2025
0393cc7
forgot about main errors for different exprkind's
FBruzzesi Apr 12, 2025
5bf2179
avoid using internal functions
FBruzzesi Apr 13, 2025
0bfaabc
ok unnamed?
FBruzzesi Apr 13, 2025
0c1d427
xfail polars drop_null_keys and multi output exprs
FBruzzesi Apr 13, 2025
925a295
pin down polars xfail
FBruzzesi Apr 13, 2025
f6dba0b
polars conditions
FBruzzesi Apr 13, 2025
4dec624
fast_path for keys being all strings
FBruzzesi Apr 14, 2025
7c42fb2
do not use init in protocol
FBruzzesi Apr 14, 2025
5c43b49
Merge branch 'main' into feat/allow-expr-in-group-by
FBruzzesi Apr 14, 2025
99313bf
merge main
FBruzzesi Apr 15, 2025
fa77614
Merge branch 'main' into feat/allow-expr-in-group-by
FBruzzesi Apr 16, 2025
cfaea70
Merge branch 'main' into feat/allow-expr-in-group-by
FBruzzesi Apr 18, 2025
3754321
type annotate 'keys'
FBruzzesi Apr 18, 2025
69edc84
do not cast
FBruzzesi Apr 18, 2025
2794362
Merge branch 'main' into feat/allow-expr-in-group-by
FBruzzesi Apr 18, 2025
e7b4790
Marco's feedback
FBruzzesi Apr 18, 2025
614b69d
merge main
FBruzzesi Apr 18, 2025
c5c0bad
merge main
FBruzzesi Apr 19, 2025
bb5bae8
WIP: Dan's feedback
FBruzzesi Apr 19, 2025
caaa092
fix(typing): Use an invariant `TypeVar` for frame
dangotbanned Apr 19, 2025
797cf74
Dan's feedback, pt2: it does look promising indeed
FBruzzesi Apr 19, 2025
3e968b5
refactor: Reduce protocol footprint
dangotbanned Apr 19, 2025
b28da50
fix(typing): Use a constrained `TypeVar`
dangotbanned Apr 19, 2025
8e66073
perf: Avoid converting to `Expr`
dangotbanned Apr 19, 2025
66a24f2
fix: Don't explode on angery `_bool_`
dangotbanned Apr 19, 2025
5d6dd61
try without `tupleify`?
dangotbanned Apr 19, 2025
e2184ea
refactor: Simplify `PolarsGroupBy`
dangotbanned Apr 19, 2025
c5379d7
remove duplicate line oops
dangotbanned Apr 19, 2025
0df5c3f
Merge branch 'main' into feat/allow-expr-in-group-by
FBruzzesi Apr 19, 2025
f079244
refactor: Align both `Polars` impls
dangotbanned Apr 19, 2025
61be623
fix?
FBruzzesi Apr 19, 2025
ae38ed7
feedback
FBruzzesi Apr 19, 2025
5f593c8
rollback pandas change
FBruzzesi Apr 19, 2025
5aef757
ok, pragmatism wins?
FBruzzesi Apr 19, 2025
eae7f40
pragmatism wins: test fix
FBruzzesi Apr 19, 2025
2a574c0
similar trick for pyarrow
FBruzzesi Apr 19, 2025
169d749
feedback
FBruzzesi Apr 20, 2025
62f5749
GroupBy positional only, parametrize exception test
FBruzzesi Apr 20, 2025
a16ec2a
raise for drop_null_keys with expr/series in keys
FBruzzesi Apr 20, 2025
bb22b1b
feat(typing): Add overloads for `drop_null_keys` string only
dangotbanned Apr 20, 2025
cb4a28e
also ignore `[call-overload]`
dangotbanned Apr 20, 2025
db9175f
split test to raise on drop_null_keys True and exprs
FBruzzesi Apr 20, 2025
1533cde
rm ExpansionKind.is_multi_output
FBruzzesi Apr 20, 2025
2b62198
Merge branch 'main' into feat/allow-expr-in-group-by
FBruzzesi Apr 21, 2025
a943409
Merge branch 'main' into feat/allow-expr-in-group-by
FBruzzesi Apr 21, 2025
b8ac267
perf: Avoid creating `2*len(exprs)` lists in `agg`
dangotbanned Apr 22, 2025
b08aabe
revert(typing): Un-add `EagerDataFrameT_co`
dangotbanned Apr 22, 2025
e3b5a7e
perf: evaluate exclude outside of loop
dangotbanned Apr 22, 2025
5e77645
revert: Remove duplicate method
dangotbanned Apr 22, 2025
3f22a44
perf: More outside of loop
dangotbanned Apr 22, 2025
d8a47ad
Merge remote-tracking branch 'upstream/main' into feat/allow-expr-in-…
dangotbanned Apr 23, 2025
7adffd0
chore: Apply suggestions
dangotbanned Apr 23, 2025
5440c4f
Dan's feedback
FBruzzesi Apr 24, 2025
a605700
Merge branch 'main' into feat/allow-expr-in-group-by
dangotbanned Apr 24, 2025
e839ba6
Merge branch 'main' into feat/allow-expr-in-group-by
dangotbanned Apr 26, 2025
8a9b313
fix when grouping by selector
MarcoGorelli Apr 27, 2025
9565b9e
simplify pyarrow, remove double-drop_nulls
MarcoGorelli Apr 27, 2025
1623a71
restore existing test
MarcoGorelli Apr 27, 2025
4995bc8
remove unnecessary xfail
MarcoGorelli Apr 27, 2025
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/downstream_tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -489,7 +489,7 @@ jobs:
run: |
cd validoopsie
# empty pytest.ini to avoid pytest using narwhals configs
touch pytest.ini
touch pytest.ini
touch tests/__init__.py
touch tests/utils/__init__.py
uv run pytest tests
Expand Down
4 changes: 3 additions & 1 deletion narwhals/_arrow/dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -381,7 +381,9 @@ def with_columns(self: ArrowDataFrame, *exprs: ArrowExpr) -> ArrowDataFrame:

return self._with_native(native_frame, validate_column_names=False)

def group_by(self, *keys: str, drop_null_keys: bool) -> ArrowGroupBy:
def group_by(
self, keys: Sequence[str] | Sequence[ArrowExpr], *, drop_null_keys: bool
) -> ArrowGroupBy:
from narwhals._arrow.group_by import ArrowGroupBy

return ArrowGroupBy(self, keys, drop_null_keys=drop_null_keys)
Expand Down
2 changes: 1 addition & 1 deletion narwhals/_arrow/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,7 +176,7 @@ def func(df: ArrowDataFrame) -> Sequence[ArrowSeries]:
)
raise NotImplementedError(msg)

tmp = df.group_by(*partition_by, drop_null_keys=False).agg(self)
tmp = df.group_by(partition_by, drop_null_keys=False).agg(self)
tmp = df.simple_select(*partition_by).join(
tmp,
how="left",
Expand Down
27 changes: 17 additions & 10 deletions narwhals/_arrow/group_by.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,28 +40,28 @@ class ArrowGroupBy(EagerGroupBy["ArrowDataFrame", "ArrowExpr"]):

def __init__(
self,
compliant_frame: ArrowDataFrame,
keys: Sequence[str],
df: ArrowDataFrame,
keys: Sequence[ArrowExpr] | Sequence[str],
/,
*,
drop_null_keys: bool,
) -> None:
if drop_null_keys:
self._compliant_frame = compliant_frame.drop_nulls(keys)
else:
self._compliant_frame = compliant_frame
self._keys: list[str] = list(keys)
self._df = df
frame, self._keys, self._output_key_names = self._parse_keys(df, keys=keys)
self._compliant_frame = frame.drop_nulls(self._keys) if drop_null_keys else frame
self._grouped = pa.TableGroupBy(self.compliant.native, self._keys)
self._drop_null_keys = drop_null_keys

def agg(self, *exprs: ArrowExpr) -> ArrowDataFrame:
self._ensure_all_simple(exprs)
aggs: list[tuple[str, str, Any]] = []
expected_pyarrow_column_names: list[str] = self._keys.copy()
new_column_names: list[str] = self._keys.copy()
exclude = (*self._keys, *self._output_key_names)

for expr in exprs:
output_names, aliases = evaluate_output_names_and_aliases(
expr, self.compliant, self._keys
expr, self.compliant, exclude
)

if expr._depth == 0:
Expand Down Expand Up @@ -120,7 +120,10 @@ def agg(self, *exprs: ArrowExpr) -> ArrowDataFrame:
result_simple = result_simple.select(
[*self._keys, *[col for col in columns if col not in self._keys]]
)
return self.compliant._with_native(result_simple)

return self.compliant._with_native(result_simple).rename(
dict(zip(self._keys, self._output_key_names))
)

def __iter__(self) -> Iterator[tuple[Any, ArrowDataFrame]]:
col_token = generate_temporary_column_name(
Expand All @@ -142,9 +145,13 @@ def __iter__(self) -> Iterator[tuple[Any, ArrowDataFrame]]:
null_replacement=null_token,
)
table = table.add_column(i=0, field_=col_token, column=key_values)

for v in pc.unique(key_values):
t = self.compliant._with_native(
table.filter(pc.equal(table[col_token], v)).drop([col_token])
)
row = t.simple_select(*self._keys).row(0)
yield tuple(extract_py_scalar(el) for el in row), t
yield (
tuple(extract_py_scalar(el) for el in row),
t.simple_select(*self._df.columns),
)
23 changes: 18 additions & 5 deletions narwhals/_compliant/dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@
from narwhals._compliant.typing import EagerSeriesT
from narwhals._compliant.typing import NativeFrameT
from narwhals._compliant.typing import NativeSeriesT
from narwhals._expression_parsing import evaluate_output_names_and_aliases
from narwhals._translate import ArrowConvertible
from narwhals._translate import DictConvertible
from narwhals._translate import FromNative
Expand Down Expand Up @@ -159,7 +158,10 @@ def filter(self, predicate: CompliantExprT_contra | Incomplete) -> Self: ...
def gather_every(self, n: int, offset: int) -> Self: ...
def get_column(self, name: str) -> CompliantSeriesT: ...
def group_by(
self, *keys: str, drop_null_keys: bool
self,
keys: Sequence[str] | Sequence[CompliantExprT_contra],
*,
drop_null_keys: bool,
) -> DataFrameGroupBy[Self, Any]: ...
def head(self, n: int) -> Self: ...
def item(self, row: int | None, column: int | str | None) -> Any: ...
Expand Down Expand Up @@ -250,6 +252,10 @@ def write_csv(self, file: str | Path | BytesIO) -> None: ...
def write_csv(self, file: str | Path | BytesIO | None) -> str | None: ...
def write_parquet(self, file: str | Path | BytesIO) -> None: ...

def _evaluate_aliases(self, *exprs: CompliantExprT_contra) -> list[str]:
it = (expr._evaluate_aliases(self) for expr in exprs)
return list(chain.from_iterable(it))


class CompliantLazyFrame(
_StoresNative[NativeFrameT],
Expand Down Expand Up @@ -302,8 +308,11 @@ def filter(self, predicate: CompliantExprT_contra | Incomplete) -> Self: ...
)
def gather_every(self, n: int, offset: int) -> Self: ...
def group_by(
self, *keys: str, drop_null_keys: bool
) -> CompliantGroupBy[Self, Any]: ...
self,
keys: Sequence[str] | Sequence[CompliantExprT_contra],
*,
drop_null_keys: bool,
) -> CompliantGroupBy[Self, CompliantExprT_contra]: ...
def head(self, n: int) -> Self: ...
def join(
self,
Expand Down Expand Up @@ -349,6 +358,10 @@ def _evaluate_expr(self, expr: CompliantExprT_contra, /) -> Any:
assert len(result) == 1 # debug assertion # noqa: S101
return result[0]

def _evaluate_aliases(self, *exprs: CompliantExprT_contra) -> list[str]:
it = (expr._evaluate_aliases(self) for expr in exprs)
return list(chain.from_iterable(it))


class EagerDataFrame(
CompliantDataFrame[EagerSeriesT, EagerExprT, NativeFrameT],
Expand Down Expand Up @@ -379,7 +392,7 @@ def _evaluate_into_expr(self, expr: EagerExprT, /) -> Sequence[EagerSeriesT]:

Note that for PySpark / DuckDB, we are less free to liberally set aliases whenever we want.
"""
_, aliases = evaluate_output_names_and_aliases(expr, self, [])
aliases = expr._evaluate_aliases(self)
result = expr(self)
if list(aliases) != (
result_aliases := [s.name for s in result]
Expand Down
35 changes: 20 additions & 15 deletions narwhals/_compliant/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,6 @@
from narwhals._compliant.typing import EagerSeriesT
from narwhals._compliant.typing import LazyExprT
from narwhals._compliant.typing import NativeExprT
from narwhals._expression_parsing import evaluate_output_names_and_aliases
from narwhals.dependencies import get_numpy
from narwhals.dependencies import is_numpy_array
from narwhals.dtypes import DType
Expand Down Expand Up @@ -195,19 +194,6 @@ def clip(
upper_bound: Self | NumericLiteral | TemporalLiteral | None,
) -> Self: ...

@property
def str(self) -> Any: ...
@property
def name(self) -> Any: ...
@property
def dt(self) -> Any: ...
@property
def cat(self) -> Any: ...
@property
def list(self) -> Any: ...
@property
def struct(self) -> Any: ...

def ewm_mean(
self,
*,
Expand Down Expand Up @@ -287,6 +273,25 @@ def _is_multi_output_unnamed(self) -> bool:
assert self._metadata is not None # noqa: S101
return self._metadata.expansion_kind.is_multi_unnamed()

def _evaluate_aliases(
self: CompliantExpr[CompliantFrameT, Any], frame: CompliantFrameT, /
) -> Sequence[str]:
names = self._evaluate_output_names(frame)
return alias(names) if (alias := self._alias_output_names) else names

@property
def str(self) -> Any: ...
@property
def name(self) -> Any: ...
@property
def dt(self) -> Any: ...
@property
def cat(self) -> Any: ...
@property
def list(self) -> Any: ...
@property
def struct(self) -> Any: ...


class DepthTrackingExpr(
CompliantExpr[CompliantFrameT, CompliantSeriesOrNativeExprT_co],
Expand Down Expand Up @@ -467,7 +472,7 @@ def _reuse_series_inner(
series._from_scalar(method(series)) if returns_scalar else method(series)
for series in self(df)
]
_, aliases = evaluate_output_names_and_aliases(self, df, [])
aliases = self._evaluate_aliases(df)
if [s.name for s in out] != list(aliases): # pragma: no cover
msg = (
f"Safety assertion failed, please report a bug to https://github.com/narwhals-dev/narwhals/issues\n"
Expand Down
85 changes: 74 additions & 11 deletions narwhals/_compliant/group_by.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,19 +13,27 @@
from typing import Sequence
from typing import TypeVar

from narwhals._compliant.typing import CompliantDataFrameAny
from narwhals._compliant.typing import CompliantDataFrameT
from narwhals._compliant.typing import CompliantDataFrameT_co
from narwhals._compliant.typing import CompliantExprT_contra
from narwhals._compliant.typing import CompliantFrameT
from narwhals._compliant.typing import CompliantFrameT_co
from narwhals._compliant.typing import CompliantLazyFrameT_co
from narwhals._compliant.typing import CompliantLazyFrameAny
from narwhals._compliant.typing import CompliantLazyFrameT
from narwhals._compliant.typing import DepthTrackingExprAny
from narwhals._compliant.typing import DepthTrackingExprT_contra
from narwhals._compliant.typing import EagerExprT_contra
from narwhals._compliant.typing import LazyExprT_contra
from narwhals._compliant.typing import NativeExprT_co
from narwhals._expression_parsing import is_multi_output
from narwhals.utils import is_sequence_of

if TYPE_CHECKING:
from typing_extensions import TypeAlias

_SameFrameT = TypeVar("_SameFrameT", CompliantDataFrameAny, CompliantLazyFrameAny)


if not TYPE_CHECKING: # pragma: no cover
if sys.version_info >= (3, 9):
Expand Down Expand Up @@ -58,7 +66,6 @@

class CompliantGroupBy(Protocol38[CompliantFrameT_co, CompliantExprT_contra]):
_compliant_frame: Any
_keys: Sequence[str]

@property
def compliant(self) -> CompliantFrameT_co:
Expand All @@ -67,7 +74,7 @@ def compliant(self) -> CompliantFrameT_co:
def __init__(
self,
compliant_frame: CompliantFrameT_co,
keys: Sequence[str],
keys: Sequence[CompliantExprT_contra] | Sequence[str],
/,
*,
drop_null_keys: bool,
Expand All @@ -83,9 +90,60 @@ class DataFrameGroupBy(
def __iter__(self) -> Iterator[tuple[Any, CompliantDataFrameT_co]]: ...


class ParseKeysGroupBy(
CompliantGroupBy[CompliantFrameT, CompliantExprT_contra],
Protocol38[CompliantFrameT, CompliantExprT_contra],
):
def _parse_keys(
self,
compliant_frame: CompliantFrameT,
keys: Sequence[CompliantExprT_contra] | Sequence[str],
) -> tuple[CompliantFrameT, list[str], list[str]]:
if is_sequence_of(keys, str):
keys_str = list(keys)
return compliant_frame, keys_str, keys_str.copy()
else:
return self._parse_expr_keys(compliant_frame, keys=keys)

@staticmethod
def _parse_expr_keys(
compliant_frame: _SameFrameT, keys: Sequence[CompliantExprT_contra]
) -> tuple[_SameFrameT, list[str], list[str]]:
"""Parses key expressions to set up `.agg` operation with correct information.

Since keys are expressions, it's possible to alias any such key to match
other dataframe column names.

In order to match polars behavior and not overwrite columns when evaluating keys:

- We evaluate what the output key names should be, in order to remap temporary column
names to the expected ones, and to exclude those from unnamed expressions in
`.agg(...)` context (see https://github.com/narwhals-dev/narwhals/pull/2325#issuecomment-2800004520)
- Create temporary names for evaluated key expressions that are guaranteed to have
no overlap with any existing column name.
- Add these temporary columns to the compliant dataframe.
"""
suffix_token = "_" * (max(len(str(c)) for c in compliant_frame.columns) + 1)
output_names = compliant_frame._evaluate_aliases(*keys)

safe_keys = [
# multi-output expression cannot have duplicate names, hence it's safe to suffix
key.name.suffix(suffix_token)
if key._metadata is not None and is_multi_output(key._metadata.expansion_kind)
# otherwise it's single named and we can use Expr.alias
else key.alias(f"{new_name}{suffix_token}")
for key, new_name in zip(keys, output_names)
]
return (
compliant_frame.with_columns(*safe_keys),
compliant_frame._evaluate_aliases(*safe_keys),
output_names,
)


class DepthTrackingGroupBy(
CompliantGroupBy[CompliantFrameT_co, DepthTrackingExprT_contra],
Protocol38[CompliantFrameT_co, DepthTrackingExprT_contra, NativeAggregationT_co],
ParseKeysGroupBy[CompliantFrameT, DepthTrackingExprT_contra],
Protocol38[CompliantFrameT, DepthTrackingExprT_contra, NativeAggregationT_co],
):
"""`CompliantGroupBy` variant, deals with `Eager` and other backends that utilize `CompliantExpr._depth`."""

Expand Down Expand Up @@ -138,16 +196,20 @@ def _leaf_name(cls, expr: DepthTrackingExprAny, /) -> NarwhalsAggregation | Any:


class EagerGroupBy(
DepthTrackingGroupBy[CompliantDataFrameT_co, EagerExprT_contra, str],
DataFrameGroupBy[CompliantDataFrameT_co, EagerExprT_contra],
Protocol38[CompliantDataFrameT_co, EagerExprT_contra],
DepthTrackingGroupBy[CompliantDataFrameT, EagerExprT_contra, str],
DataFrameGroupBy[CompliantDataFrameT, EagerExprT_contra],
Protocol38[CompliantDataFrameT, EagerExprT_contra],
): ...


class LazyGroupBy(
CompliantGroupBy[CompliantLazyFrameT_co, LazyExprT_contra],
Protocol38[CompliantLazyFrameT_co, LazyExprT_contra, NativeExprT_co],
ParseKeysGroupBy[CompliantLazyFrameT, LazyExprT_contra],
CompliantGroupBy[CompliantLazyFrameT, LazyExprT_contra],
Protocol38[CompliantLazyFrameT, LazyExprT_contra, NativeExprT_co],
):
_keys: list[str]
_output_key_names: list[str]

def _evaluate_expr(self, expr: LazyExprT_contra, /) -> Iterator[NativeExprT_co]:
output_names = expr._evaluate_output_names(self.compliant)
aliases = (
Expand All @@ -157,8 +219,9 @@ def _evaluate_expr(self, expr: LazyExprT_contra, /) -> Iterator[NativeExprT_co]:
)
native_exprs = expr(self.compliant)
if expr._is_multi_output_unnamed():
exclude = {*self._keys, *self._output_key_names}
for native_expr, name, alias in zip(native_exprs, output_names, aliases):
if name not in self._keys:
if name not in exclude:
yield native_expr.alias(alias)
else:
for native_expr, alias in zip(native_exprs, aliases):
Expand Down
2 changes: 1 addition & 1 deletion narwhals/_compliant/series.py
Original file line number Diff line number Diff line change
Expand Up @@ -307,7 +307,7 @@ def _with_native(
"""Return a new `CompliantSeries`, wrapping the native `series`.

In cases when operations are known to not affect whether a result should
be broadcast, we can pass `preverse_broadcast=True`.
be broadcast, we can pass `preserve_broadcast=True`.
Set this with care - it should only be set for unary expressions which don't
change length or order, such as `.alias` or `.fill_null`. If in doubt, don't
set it, you probably don't need it.
Expand Down
6 changes: 4 additions & 2 deletions narwhals/_dask/dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -406,10 +406,12 @@ def join_asof(
),
)

def group_by(self, *by: str, drop_null_keys: bool) -> DaskLazyGroupBy:
def group_by(
self, keys: Sequence[str] | Sequence[DaskExpr], *, drop_null_keys: bool
) -> DaskLazyGroupBy:
from narwhals._dask.group_by import DaskLazyGroupBy

return DaskLazyGroupBy(self, by, drop_null_keys=drop_null_keys)
return DaskLazyGroupBy(self, keys, drop_null_keys=drop_null_keys)

def tail(self, n: int) -> Self: # pragma: no cover
native_frame = self.native
Expand Down
Loading
Loading