Skip to content

Commit 0d56a3b

Browse files
Changes to support string transform in add_field. (#1936)
Closes #1882 # Rationale for this change Feature request: Ability to pass transform name as string in `add_fields` # Are these changes tested? Yes # Are there any user-facing changes? Yes. Users will be able to pass transform names as string while calling add_fields method of update_spec.
1 parent 068ee5d commit 0d56a3b

File tree

3 files changed

+36
-34
lines changed

3 files changed

+36
-34
lines changed

pyiceberg/table/update/spec.py

Lines changed: 5 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -16,15 +16,7 @@
1616
# under the License.
1717
from __future__ import annotations
1818

19-
from typing import (
20-
TYPE_CHECKING,
21-
Any,
22-
Dict,
23-
List,
24-
Optional,
25-
Set,
26-
Tuple,
27-
)
19+
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Set, Tuple, Union
2820

2921
from pyiceberg.expressions import (
3022
Reference,
@@ -47,7 +39,7 @@
4739
UpdatesAndRequirements,
4840
UpdateTableMetadata,
4941
)
50-
from pyiceberg.transforms import IdentityTransform, TimeTransform, Transform, VoidTransform
42+
from pyiceberg.transforms import IdentityTransform, TimeTransform, Transform, VoidTransform, parse_transform
5143

5244
if TYPE_CHECKING:
5345
from pyiceberg.table import Transaction
@@ -85,11 +77,13 @@ def __init__(self, transaction: Transaction, case_sensitive: bool = True) -> Non
8577
def add_field(
8678
self,
8779
source_column_name: str,
88-
transform: Transform[Any, Any],
80+
transform: Union[str, Transform[Any, Any]],
8981
partition_field_name: Optional[str] = None,
9082
) -> UpdateSpec:
9183
ref = Reference(source_column_name)
9284
bound_ref = ref.bind(self._transaction.table_metadata.schema(), self._case_sensitive)
85+
if isinstance(transform, str):
86+
transform = parse_transform(transform)
9387
# verify transform can actually bind it
9488
output_type = bound_ref.field.field_type
9589
if not transform.can_transform(output_type):

pyiceberg/transforms.py

Lines changed: 23 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -111,29 +111,6 @@ def _transform_literal(func: Callable[[L], L], lit: Literal[L]) -> Literal[L]:
111111
return literal(func(lit.value))
112112

113113

114-
def parse_transform(v: Any) -> Any:
115-
if isinstance(v, str):
116-
if v == IDENTITY:
117-
return IdentityTransform()
118-
elif v == VOID:
119-
return VoidTransform()
120-
elif v.startswith(BUCKET):
121-
return BucketTransform(num_buckets=BUCKET_PARSER.match(v))
122-
elif v.startswith(TRUNCATE):
123-
return TruncateTransform(width=TRUNCATE_PARSER.match(v))
124-
elif v == YEAR:
125-
return YearTransform()
126-
elif v == MONTH:
127-
return MonthTransform()
128-
elif v == DAY:
129-
return DayTransform()
130-
elif v == HOUR:
131-
return HourTransform()
132-
else:
133-
return UnknownTransform(transform=v)
134-
return v
135-
136-
137114
class Transform(IcebergRootModel[str], ABC, Generic[S, T]):
138115
"""Transform base class for concrete transforms.
139116
@@ -220,6 +197,29 @@ def _transform(array: "ArrayLike") -> "ArrayLike":
220197
return _transform
221198

222199

200+
def parse_transform(v: Any) -> Transform[Any, Any]:
201+
if isinstance(v, str):
202+
if v == IDENTITY:
203+
return IdentityTransform()
204+
elif v == VOID:
205+
return VoidTransform()
206+
elif v.startswith(BUCKET):
207+
return BucketTransform(num_buckets=BUCKET_PARSER.match(v))
208+
elif v.startswith(TRUNCATE):
209+
return TruncateTransform(width=TRUNCATE_PARSER.match(v))
210+
elif v == YEAR:
211+
return YearTransform()
212+
elif v == MONTH:
213+
return MonthTransform()
214+
elif v == DAY:
215+
return DayTransform()
216+
elif v == HOUR:
217+
return HourTransform()
218+
else:
219+
return UnknownTransform(transform=v)
220+
return v
221+
222+
223223
class BucketTransform(Transform[S, int]):
224224
"""Base Transform class to transform a value into a bucket partition value.
225225

tests/integration/test_partition_evolution.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -140,6 +140,14 @@ def test_add_hour(catalog: Catalog) -> None:
140140
_validate_new_partition_fields(table, 1000, 1, 1000, PartitionField(2, 1000, HourTransform(), "hour_transform"))
141141

142142

143+
@pytest.mark.integration
144+
@pytest.mark.parametrize("catalog", [pytest.lazy_fixture("session_catalog_hive"), pytest.lazy_fixture("session_catalog")])
145+
def test_add_hour_string_transform(catalog: Catalog) -> None:
146+
table = _table(catalog)
147+
table.update_spec().add_field("event_ts", "hour", "str_hour_transform").commit()
148+
_validate_new_partition_fields(table, 1000, 1, 1000, PartitionField(2, 1000, HourTransform(), "str_hour_transform"))
149+
150+
143151
@pytest.mark.integration
144152
@pytest.mark.parametrize("catalog", [pytest.lazy_fixture("session_catalog_hive"), pytest.lazy_fixture("session_catalog")])
145153
def test_add_hour_generates_default_name(catalog: Catalog) -> None:

0 commit comments

Comments
 (0)