-
Notifications
You must be signed in to change notification settings - Fork 140
/
Copy pathexpr.py
214 lines (183 loc) · 8.18 KB
/
expr.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
from __future__ import annotations
from typing import TYPE_CHECKING
from typing import Any
from typing import Sequence
import pyarrow.compute as pc
from narwhals._arrow.series import ArrowSeries
from narwhals._compliant import EagerExpr
from narwhals._expression_parsing import evaluate_output_names_and_aliases
from narwhals._expression_parsing import is_scalar_like
from narwhals.exceptions import ColumnNotFoundError
from narwhals.utils import Implementation
from narwhals.utils import generate_temporary_column_name
from narwhals.utils import not_implemented
if TYPE_CHECKING:
from typing_extensions import Self
from narwhals._arrow.dataframe import ArrowDataFrame
from narwhals._arrow.namespace import ArrowNamespace
from narwhals._compliant.typing import AliasNames
from narwhals._compliant.typing import EvalNames
from narwhals._compliant.typing import EvalSeries
from narwhals._expression_parsing import ExprMetadata
from narwhals.typing import RankMethod
from narwhals.utils import Version
from narwhals.utils import _FullContext
class ArrowExpr(EagerExpr["ArrowDataFrame", ArrowSeries]):
    """Expression that is evaluated against an `ArrowDataFrame`.

    An instance stores a callable mapping a dataframe to a sequence of
    `ArrowSeries`, together with the bookkeeping (depth, function name,
    output-name resolvers, backend/API versions) used when deriving new
    expressions from it.
    """

    _implementation: Implementation = Implementation.PYARROW

    def __init__(
        self,
        call: EvalSeries[ArrowDataFrame, ArrowSeries],
        *,
        depth: int,
        function_name: str,
        evaluate_output_names: EvalNames[ArrowDataFrame],
        alias_output_names: AliasNames | None,
        backend_version: tuple[int, ...],
        version: Version,
        call_kwargs: dict[str, Any] | None = None,
        implementation: Implementation | None = None,
    ) -> None:
        """Store the evaluation callable and its associated metadata.

        Args:
            call: Callable producing the series for a given dataframe.
            depth: Stored as `_depth`; `over` passes `self._depth + 1` when
                it builds a derived expression.
            function_name: Stored as `_function_name`; `over` appends
                `"->over"` to it.
            evaluate_output_names: Callable resolving output names from a
                dataframe.
            alias_output_names: Optional callable stored as
                `_alias_output_names`.
            backend_version: Version tuple stored as `_backend_version`.
            version: Narwhals API version stored as `_version`.
            call_kwargs: Optional keyword arguments stored as `_call_kwargs`;
                `None` (or any falsy value) is stored as an empty dict.
            implementation: Accepted but not used by this constructor; the
                class-level `_implementation` is left untouched.
        """
        self._call = call
        self._depth = depth
        self._function_name = function_name
        self._evaluate_output_names = evaluate_output_names
        self._alias_output_names = alias_output_names
        self._backend_version = backend_version
        self._version = version
        self._call_kwargs = call_kwargs or {}
        # Starts unset; `over` asserts that it has been populated before use.
        self._metadata: ExprMetadata | None = None

    @classmethod
    def from_column_names(
        cls: type[Self],
        evaluate_column_names: EvalNames[ArrowDataFrame],
        /,
        *,
        context: _FullContext,
        function_name: str = "",
    ) -> Self:
        """Build an expression selecting columns by name.

        Args:
            evaluate_column_names: Callable returning the column names to
                select for a given dataframe.
            context: Object supplying `_backend_version` and `_version` for
                the new expression.
            function_name: Function name recorded on the new expression.

        Returns:
            A depth-0 expression whose evaluation yields one `ArrowSeries`
            per resolved column name.

        Raises:
            ColumnNotFoundError: At evaluation time, when looking up a
                resolved column name in `df.native` raises `KeyError`.
        """

        def func(df: ArrowDataFrame) -> list[ArrowSeries]:
            # Resolve the names once so the error path can reuse them.
            column_names = evaluate_column_names(df)
            try:
                return [
                    ArrowSeries(
                        df.native[column_name],
                        name=column_name,
                        backend_version=df._backend_version,
                        version=df._version,
                    )
                    for column_name in column_names
                ]
            except KeyError as e:
                missing_columns = [x for x in column_names if x not in df.columns]
                raise ColumnNotFoundError.from_missing_and_available_column_names(
                    missing_columns=missing_columns, available_columns=df.columns
                ) from e

        return cls(
            func,
            depth=0,
            function_name=function_name,
            evaluate_output_names=evaluate_column_names,
            alias_output_names=None,
            backend_version=context._backend_version,
            version=context._version,
        )

    @classmethod
    def from_column_indices(
        cls: type[Self], *column_indices: int, context: _FullContext
    ) -> Self:
        """Build an expression selecting columns by position.

        Args:
            *column_indices: Positions of the columns to select.
            context: Object supplying `_backend_version` and `_version` for
                the new expression.

        Returns:
            A depth-0 expression with function name `"nth"` whose evaluation
            yields one `ArrowSeries` per index, named after the native
            column at that position.
        """

        def func(df: ArrowDataFrame) -> list[ArrowSeries]:
            return [
                ArrowSeries(
                    df.native[column_index],
                    name=df.native.column_names[column_index],
                    backend_version=df._backend_version,
                    version=df._version,
                )
                for column_index in column_indices
            ]

        return cls(
            func,
            depth=0,
            function_name="nth",
            evaluate_output_names=lambda df: [df.columns[i] for i in column_indices],
            alias_output_names=None,
            backend_version=context._backend_version,
            version=context._version,
        )

    def __narwhals_namespace__(self) -> ArrowNamespace:
        """Return an `ArrowNamespace` with this expression's versions."""
        # Imported at call time: at module level this name is only imported
        # under `TYPE_CHECKING`.
        from narwhals._arrow.namespace import ArrowNamespace

        return ArrowNamespace(
            backend_version=self._backend_version, version=self._version
        )

    def __narwhals_expr__(self) -> None: ...

    def _reuse_series_extra_kwargs(
        self, *, returns_scalar: bool = False
    ) -> dict[str, Any]:
        """Return extra keyword arguments for scalar-returning operations.

        Args:
            returns_scalar: Whether the operation returns a scalar.

        Returns:
            `{"_return_py_scalar": False}` when `returns_scalar` is true,
            otherwise an empty dict.
        """
        # NOTE(review): presumably consumed by `_reuse_series`, which is
        # defined outside this class body - confirm against the base class.
        return {"_return_py_scalar": False} if returns_scalar else {}

    def cum_sum(self, *, reverse: bool) -> Self:
        """Delegate to the `cum_sum` series operation via `_reuse_series`."""
        return self._reuse_series("cum_sum", reverse=reverse)

    def shift(self, n: int) -> Self:
        """Delegate to the `shift` series operation via `_reuse_series`."""
        return self._reuse_series("shift", n=n)

    def over(self, partition_by: Sequence[str], order_by: Sequence[str] | None) -> Self:
        """Build an expression evaluating this one per partition and/or in order.

        `self._metadata` must have been set before calling this method; it
        is initialised to `None` in `__init__` and asserted non-`None` here.

        Args:
            partition_by: Column names to group by. When empty, only
                `order_by` is applied.
            order_by: Column names to sort by before evaluating. Must not be
                `None` when `partition_by` is empty.

        Returns:
            A new expression of the same class with `depth + 1` and
            `"->over"` appended to the function name.

        Raises:
            NotImplementedError: If `partition_by` is non-empty and the
                expression's kind is not scalar-like according to
                `is_scalar_like`; or, at evaluation time, if an output name
                also appears in `partition_by`.
        """
        assert self._metadata is not None  # noqa: S101
        if partition_by and not is_scalar_like(self._metadata.kind):
            msg = "Only aggregation or literal operations are supported in grouped `over` context for PyArrow."
            raise NotImplementedError(msg)

        if not partition_by:
            # e.g. `nw.col('a').cum_sum().order_by(key)`
            # which we can always easily support, as it doesn't require grouping.
            assert order_by is not None  # help type checkers  # noqa: S101

            def func(df: ArrowDataFrame) -> Sequence[ArrowSeries]:
                # Tag each row with a temporary index column, then sort.
                token = generate_temporary_column_name(8, df.columns)
                df = df.with_row_index(token).sort(
                    *order_by, descending=False, nulls_last=False
                )
                # Evaluate on the sorted frame without the temporary column.
                result = self(df.drop([token], strict=True))
                # TODO(marco): is there a way to do this efficiently without
                # doing 2 sorts? Here we're sorting the dataframe and then
                # again calling `sort_indices`. `ArrowSeries.scatter` would also sort.
                sorting_indices = pc.sort_indices(df.get_column(token).native)  # type: ignore[call-overload]
                # Reorder each result by the sort order of the row-index column.
                return [s._with_native(s.native.take(sorting_indices)) for s in result]
        else:

            def func(df: ArrowDataFrame) -> Sequence[ArrowSeries]:
                output_names, aliases = evaluate_output_names_and_aliases(self, df, [])
                if overlap := set(output_names).intersection(partition_by):
                    # E.g. `df.select(nw.all().sum().over('a'))`. This is well-defined,
                    # we just don't support it yet.
                    msg = (
                        f"Column names {overlap} appear in both expression output names and in `over` keys.\n"
                        "This is not yet supported."
                    )
                    raise NotImplementedError(msg)

                # Aggregate per group, then left-join the aggregates back onto
                # the partition-key columns of the original frame.
                tmp = df.group_by(partition_by, drop_null_keys=False).agg(self)
                tmp = df.simple_select(*partition_by).join(
                    tmp,
                    how="left",
                    left_on=partition_by,
                    right_on=partition_by,
                    suffix="_right",
                )
                return [tmp.get_column(alias) for alias in aliases]

        return self.__class__(
            func,
            depth=self._depth + 1,
            function_name=self._function_name + "->over",
            evaluate_output_names=self._evaluate_output_names,
            alias_output_names=self._alias_output_names,
            backend_version=self._backend_version,
            version=self._version,
        )

    def cum_count(self, *, reverse: bool) -> Self:
        """Delegate to the `cum_count` series operation via `_reuse_series`."""
        return self._reuse_series("cum_count", reverse=reverse)

    def cum_min(self, *, reverse: bool) -> Self:
        """Delegate to the `cum_min` series operation via `_reuse_series`."""
        return self._reuse_series("cum_min", reverse=reverse)

    def cum_max(self, *, reverse: bool) -> Self:
        """Delegate to the `cum_max` series operation via `_reuse_series`."""
        return self._reuse_series("cum_max", reverse=reverse)

    def cum_prod(self, *, reverse: bool) -> Self:
        """Delegate to the `cum_prod` series operation via `_reuse_series`."""
        return self._reuse_series("cum_prod", reverse=reverse)

    def rank(self, method: RankMethod, *, descending: bool) -> Self:
        """Delegate to the `rank` series operation via `_reuse_series`."""
        return self._reuse_series("rank", method=method, descending=descending)

    # Bound to the `not_implemented()` marker rather than a real method.
    ewm_mean = not_implemented()