Skip to content

Commit 1a7f8f3

Browse files
committed
Add PySpark support
1 parent 90e9830 commit 1a7f8f3

File tree

8 files changed

+222
-27
lines changed

8 files changed

+222
-27
lines changed

bmsdna/table_rendering/excel.py

Lines changed: 22 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
from xlsxwriter.worksheet import Worksheet
2828
from xlsxwriter.workbook import Workbook
2929
import polars as pl
30+
from pyspark.sql import DataFrame as SparkDataFrame
3031

3132

3233
class SheetOptions(TypedDict):
@@ -36,7 +37,7 @@ class SheetOptions(TypedDict):
3637

3738
def render_into_sheet(
3839
configs: Sequence[ColumnConfig],
39-
data: "Iterable[dict] | pl.DataFrame",
40+
data: "Iterable[dict] | pl.DataFrame| SparkDataFrame",
4041
ws: "Worksheet",
4142
wb: "Workbook",
4243
sheet_options: SheetOptions = {},
@@ -46,16 +47,28 @@ def render_into_sheet(
4647
autofit=True,
4748
table_name: str | None = None,
4849
) -> "Worksheet":
49-
try:
50-
import polars as pl
50+
data_iter: Iterable[dict] | None = None
51+
if data is None:
52+
data_iter = []
53+
if data_iter is None:
54+
try:
55+
from pyspark.sql import DataFrame as SparkDataFrame
5156

52-
if isinstance(data, pl.DataFrame):
53-
data_iter: Iterable[dict] = data.iter_rows(named=True)
54-
else:
55-
data_iter: Iterable[dict] = data
56-
except ImportError:
57-
data_iter: Iterable[dict] = data # type: ignore
57+
if isinstance(data, SparkDataFrame):
58+
data_iter = (d.asDict(True) for d in data.collect())
59+
except ImportError:
60+
pass
61+
if data_iter is None:
62+
try:
63+
import polars as pl
5864

65+
if isinstance(data, pl.DataFrame):
66+
data_iter = data.iter_rows(named=True)
67+
else:
68+
data_iter = cast(list[dict], data)
69+
except ImportError:
70+
data_iter = data # type: ignore
71+
assert data_iter is not None, f"Unknown data type for data: {type(data)}"
5972
import xlsxwriter
6073

6174
ws.write_row(

bmsdna/table_rendering/html.py

Lines changed: 24 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,31 +1,45 @@
1-
from typing import Sequence, TYPE_CHECKING, Iterable, Callable
1+
from typing import Sequence, TYPE_CHECKING, Iterable, Callable, cast
22
import json
33
from bmsdna.table_rendering.config import ColumnConfig, ValueContext, format_value
44

55
if TYPE_CHECKING:
66
import polars as pl
7+
from pyspark.sql import DataFrame as SparkDataFrame
78

89

910
def render_html(
1011
configs: Sequence[ColumnConfig],
11-
data: "Iterable[dict] | pl.DataFrame",
12+
data: "Iterable[dict] | pl.DataFrame | SparkDataFrame",
1213
*,
1314
translator: Callable[[str, str], str] | None = None,
1415
add_classes: Sequence[str] | None = None,
1516
styles: str | dict[str, str] = "",
1617
tr_styles: str | dict[str, str] = "",
1718
td_styles: str | dict[str, str] = "",
1819
):
19-
try:
20-
import polars as pl
20+
data_iter: Iterable[dict] | None = None
21+
if data is None:
22+
data_iter = []
23+
if data_iter is None:
24+
try:
25+
from pyspark.sql import DataFrame as SparkDataFrame
2126

22-
if isinstance(data, pl.DataFrame):
23-
data_iter: Iterable[dict] = data.iter_rows(named=True)
24-
else:
25-
data_iter: Iterable[dict] = data
26-
except ImportError:
27-
data_iter: Iterable[dict] = data # type: ignore
27+
if isinstance(data, SparkDataFrame):
28+
data_iter = (d.asDict(True) for d in data.collect())
29+
except ImportError:
30+
pass
31+
if data_iter is None:
32+
try:
33+
import polars as pl
2834

35+
if isinstance(data, pl.DataFrame):
36+
data_iter = data.iter_rows(named=True)
37+
else:
38+
data_iter = cast(list[dict], data)
39+
except ImportError:
40+
data_iter = data # type: ignore
41+
42+
assert data_iter is not None, f"Unknown data type for data: {type(data)}"
2943
from dominate.tags import table, thead, tr, td, th, a, tbody
3044

3145
tbl = table()

bmsdna/table_rendering/spark.py

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,48 @@
1+
from typing import TYPE_CHECKING
2+
from bmsdna.table_rendering.config import ColumnConfig
3+
4+
if TYPE_CHECKING:
5+
from pyspark.sql import DataFrame
6+
7+
8+
def configs_from_pyspark(df: "DataFrame"):
    """Derive ``ColumnConfig`` entries from a PySpark DataFrame's schema.

    Columns whose names start with ``_`` or ``mail_`` are skipped. The render
    format is inferred from the Spark data type: date/timestamp map to
    "date"/"datetime", integral types to "int", decimals to "int" or "float"
    depending on their scale, and float/double to "float"; any other type gets
    no format.
    """
    # pyspark is an optional dependency, so the type imports stay local.
    from pyspark.sql.types import (
        DateType,
        TimestampType,
        IntegerType,
        FloatType,
        LongType,
        DoubleType,
        DecimalType,
    )

    result: list = []
    for schema_field in df.schema.fields:
        col_name = schema_field.name
        if col_name.startswith("_") or col_name.startswith("mail_"):
            # Internal / mail-merge helper columns are not rendered.
            continue

        col_type = schema_field.dataType
        nr_decimals = None
        if isinstance(col_type, DateType):
            fmt = "date"
        elif isinstance(col_type, TimestampType):
            fmt = "datetime"
        elif isinstance(col_type, (IntegerType, LongType)):
            fmt = "int"
        elif isinstance(col_type, DecimalType):
            # A scale of 0 means the decimal holds whole numbers only.
            fmt = "int" if col_type.scale == 0 else "float"
            nr_decimals = col_type.scale
        elif isinstance(col_type, (FloatType, DoubleType)):
            fmt = "float"
        else:
            fmt = None

        result.append(
            ColumnConfig(
                header=col_name,
                field=col_name,
                format=fmt,
                format_nr_decimals=nr_decimals,
            )
        )
    return result

bmsdna/table_rendering/table_rendering.py

Lines changed: 17 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
from xlsxwriter.worksheet import Worksheet
2929
from xlsxwriter.workbook import Workbook
3030
import polars as pl
31+
from pyspark.sql import DataFrame as SparkDataFrame
3132

3233

3334
class TableRenderer:
@@ -39,6 +40,16 @@ def __init__(
3940
self.configs = configs
4041
self.translator = translator
4142

43+
    @classmethod
    def from_spark(
        cls,
        data: "SparkDataFrame",
        translator: Callable[[str, str], str] | None = None,
    ) -> "TableRenderer":
        """Build a :class:`TableRenderer` whose column configs are inferred
        from the schema of a PySpark DataFrame.

        Args:
            data: The PySpark DataFrame whose schema drives the column configs.
            translator: Optional ``(key, default) -> str`` callable used to
                translate column headers.

        Returns:
            A new ``TableRenderer`` instance.
        """
        # Local import keeps pyspark an optional dependency of the package.
        from .spark import configs_from_pyspark

        return cls(configs_from_pyspark(data), translator=translator)
4253
@classmethod
4354
def from_df(
4455
cls, data: "pl.DataFrame", translator: Callable[[str, str], str] | None = None
@@ -108,7 +119,7 @@ def with_translator(self, translator: Callable[[str, str], str]):
108119

109120
def render_html(
110121
self,
111-
data: "Iterable[dict] | pl.DataFrame",
122+
data: "Iterable[dict] | pl.DataFrame | SparkDataFrame",
112123
*,
113124
add_classes: Sequence[str] | None = None,
114125
styles: str | dict[str, str] = "",
@@ -131,7 +142,7 @@ def render_into_sheet(
131142
self,
132143
ws: "Worksheet",
133144
wb: "Workbook",
134-
data: "Iterable[dict] | pl.DataFrame",
145+
data: "Iterable[dict] | pl.DataFrame | SparkDataFrame",
135146
sheet_options: SheetOptions = {},
136147
*,
137148
offset_rows: int = 0,
@@ -151,7 +162,7 @@ def render_into_sheet(
151162

152163
@overload
153164
def create_excel(
154-
sheets: "Mapping[str, tuple[TableRenderer, list[dict]| pl.DataFrame]|tuple[TableRenderer, list[dict]| pl.DataFrame, SheetOptions]]",
165+
sheets: "Mapping[str, tuple[TableRenderer, list[dict]| pl.DataFrame | SparkDataFrame]|tuple[TableRenderer, list[dict]| pl.DataFrame| SparkDataFrame, SheetOptions]]",
155166
excel: Path | None,
156167
*,
157168
workbook_options: dict | None = None,
@@ -160,21 +171,21 @@ def create_excel(
160171

161172
@overload
162173
def create_excel(
163-
sheets: "Mapping[str, tuple[TableRenderer, list[dict]| pl.DataFrame]|tuple[TableRenderer, list[dict]| pl.DataFrame, SheetOptions]]",
174+
sheets: "Mapping[str, tuple[TableRenderer, list[dict]| pl.DataFrame| SparkDataFrame]|tuple[TableRenderer, list[dict]| pl.DataFrame| SparkDataFrame, SheetOptions]]",
164175
*,
165176
workbook_options: dict | None = None,
166177
) -> Path: ...
167178

168179

169180
@overload
170181
def create_excel(
171-
sheets: "Mapping[str, tuple[TableRenderer, list[dict]| pl.DataFrame]|tuple[TableRenderer, list[dict]| pl.DataFrame, SheetOptions]]",
182+
sheets: "Mapping[str, tuple[TableRenderer, list[dict]| pl.DataFrame| SparkDataFrame]|tuple[TableRenderer, list[dict]| pl.DataFrame| SparkDataFrame, SheetOptions]]",
172183
excel: "Workbook",
173184
) -> None: ...
174185

175186

176187
def create_excel(
177-
sheets: "Mapping[str, tuple[TableRenderer, list[dict]| pl.DataFrame]|tuple[TableRenderer, list[dict]| pl.DataFrame, SheetOptions]]",
188+
sheets: "Mapping[str, tuple[TableRenderer, list[dict]| pl.DataFrame| SparkDataFrame]|tuple[TableRenderer, list[dict]| pl.DataFrame| SparkDataFrame, SheetOptions]]",
178189
excel: "Path | Workbook | None" = None,
179190
*,
180191
workbook_options: dict | None = None,

pyproject.toml

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[project]
22
name = "bmsdna-table-rendering"
3-
version = "0.4.0"
3+
version = "0.5.0"
44
description = ""
55
authors = [{ name = "Adrian Ehrsam", email = "adrian.ehrsam@bmsuisse.ch" }]
66
dependencies = [
@@ -14,6 +14,9 @@ readme = "README.md"
1414

1515
[project.scripts]
1616

17+
[project.optional-dependencies]
18+
spark = ["pyspark>=3.5.5"]
19+
1720
[build-system]
1821
requires = ["hatchling"]
1922
build-backend = "hatchling.build"

tests/conftest.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
import pytest
2+
import os
3+
from pathlib import Path
4+
5+
6+
@pytest.fixture(scope="session")
def spark_session():
    """Session-scoped Spark fixture.

    Returns ``None`` when Spark is disabled via ``NO_SPARK=1`` or when
    ``ODBCLAKE_TEST_CONFIGURATION`` selects a non-spark backend; otherwise
    returns a (possibly reused) SparkSession configured with the test jars.
    """
    if os.getenv("NO_SPARK", "0") == "1":
        return None
    if os.getenv("ODBCLAKE_TEST_CONFIGURATION", "spark").lower() != "spark":
        return None

    from pyspark.sql import SparkSession

    jar_classpath = str(Path("tests/jar").absolute())
    session_builder = SparkSession.builder.appName("test_spark")  # type: ignore
    session_builder = session_builder.config(
        "spark.driver.extraClassPath", jar_classpath
    )
    session_builder = session_builder.config(
        "spark.executor.extraClassPath", jar_classpath
    )
    session_builder = session_builder.config("spark.memory.fraction", 0.5)

    return session_builder.getOrCreate()

tests/test_formats_spark.py

Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,58 @@
1+
from datetime import date, datetime
2+
import os
3+
import polars as pl
4+
5+
from bmsdna.table_rendering.table_rendering import create_excel
6+
from pathlib import Path
7+
from typing import TYPE_CHECKING
8+
9+
if TYPE_CHECKING:
10+
from pyspark.sql import SparkSession
11+
12+
13+
def test_excel(spark_session: "SparkSession"):
    """End-to-end test: render a Spark DataFrame into an Excel workbook and
    verify the currency number format applied to the ``chf2`` column."""
    import pytest

    if spark_session is None:
        # The conftest fixture returns None when Spark is disabled
        # (NO_SPARK=1 or a non-spark ODBCLAKE_TEST_CONFIGURATION) —
        # skip instead of crashing on the None fixture.
        pytest.skip("Spark is not available in this test configuration")

    fake_data = spark_session.createDataFrame(
        [
            {
                "a": 1,
                "b": 2.0,
                "chf": 234,
                "chf2": 234.67,
                "date": date.fromisoformat("2022-01-01"),
                "datetime": datetime.fromisoformat("2022-01-01 12:00:00"),
            },
            {
                "a": 2,
                "b": 3.0,
                "chf": 2345,
                "chf2": 2343.67,
                "date": date.fromisoformat("2025-01-01"),
                "datetime": datetime.fromisoformat("2022-01-01 18:00:00"),
            },
        ]
    )
    from bmsdna.table_rendering import TableRenderer

    rend = TableRenderer.from_spark(fake_data).with_overwritten_configs(
        {
            "a": {"format": "int"},
            "b": {"format": "float", "header_title": "B is a great col"},
            "chf": {"format": "currency:chf"},
            "chf2": {"format": "currency:chf"},
        }
    )
    os.makedirs("tests/_data", exist_ok=True)
    create_excel({"sheet1": (rend, fake_data)}, Path("tests/_data/test_excel.xlsx"))
    import openpyxl

    workbook = openpyxl.load_workbook("tests/_data/test_excel.xlsx")

    # Get the first sheet
    sheet = workbook.active
    assert sheet is not None
    # "chf2" is the fourth rendered column, i.e. Excel column D.
    column_letter = "D"
    column = sheet[column_letter]
    assert column[0].value == "chf2"
    assert column[1].value == 234.67
    assert "CHF" in column[1].number_format

uv.lock

Lines changed: 25 additions & 1 deletion
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

0 commit comments

Comments
 (0)