
Commit a5c28e7

[FSTORE-1468] Make numpy optional (#338)
* Remove dependency on numpy except from convert_to_default_dataframe
* Ruff
* Update pyproject.toml extras
* Fix
* Fix
* Attempt making numpy optional in convert_to_default_dataframe
* Address Manu's review
1 parent 147af3e commit a5c28e7

18 files changed: +195 -69 lines

locust_benchmark/create_feature_group.py (-1)

@@ -1,7 +1,6 @@
 from common.hopsworks_client import HopsworksClient
 
 if __name__ == "__main__":
-
     hopsworks_client = HopsworksClient()
     fg = hopsworks_client.get_or_create_fg()
     hopsworks_client.insert_data(fg)

python/hopsworks_common/core/constants.py (+17)

@@ -20,6 +20,13 @@
 # Avro
 HAS_FAST_AVRO: bool = importlib.util.find_spec("fastavro") is not None
 HAS_AVRO: bool = importlib.util.find_spec("avro") is not None
+avro_not_installed_message = (
+    "Avro package not found. "
+    "If you want to use avro with Hopsworks you can install the corresponding extra via "
+    '`pip install "hopsworks[avro]"`. '
+    "You can also install avro directly in your environment with `pip install fastavro` or `pip install avro`. "
+    "You will need to restart your kernel if applicable."
+)
 
 # Confluent Kafka
 HAS_CONFLUENT_KAFKA: bool = importlib.util.find_spec("confluent_kafka") is not None

@@ -55,7 +62,17 @@
 )
 
 HAS_PANDAS: bool = importlib.util.find_spec("pandas") is not None
+
+# NumPy
 HAS_NUMPY: bool = importlib.util.find_spec("numpy") is not None
+numpy_not_installed_message = (
+    "Numpy package not found. "
+    "If you want to use numpy with Hopsworks you can install the corresponding extra via "
+    '`pip install "hopsworks[numpy]"`. '
+    "You can also install numpy directly in your environment with `pip install numpy`. "
+    "You will need to restart your kernel if applicable."
+)
+
 HAS_POLARS: bool = importlib.util.find_spec("polars") is not None
 polars_not_installed_message = (
     "Polars package not found. "

python/hsfs/builtin_transformations.py (+3 -2)

@@ -14,7 +14,8 @@
 # limitations under the License.
 #
 
-import numpy as np
+import math
+
 import pandas as pd
 from hsfs.hopsworks_udf import udf
 from hsfs.transformation_statistics import TransformationStatistics

@@ -49,7 +50,7 @@ def label_encoder(feature: pd.Series, statistics=feature_statistics) -> pd.Series:
     # Unknown categories not present in training dataset are encoded as -1.
     return pd.Series(
         [
-            value_to_index.get(data, -1) if not pd.isna(data) else np.nan
+            value_to_index.get(data, -1) if not pd.isna(data) else math.nan
            for data in feature
        ]
    )
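Replacing `np.nan` with `math.nan` drops the NumPy import without changing behaviour: both produce an IEEE 754 quiet NaN, and pandas detects them identically. A quick check (assumes only pandas is installed):

```python
import math

import pandas as pd

s = pd.Series([1.0, math.nan])
print(pd.isna(s).tolist())   # [False, True] -- math.nan is detected just like np.nan
print(math.nan != math.nan)  # True: standard NaN comparison semantics
```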

python/hsfs/constructor/query.py (+5 -1)

@@ -21,9 +21,9 @@
 from typing import Any, Dict, List, Optional, Tuple, TypeVar, Union
 
 import humps
-import numpy as np
 import pandas as pd
 from hopsworks_common.client.exceptions import FeatureStoreException
+from hopsworks_common.core.constants import HAS_NUMPY
 from hsfs import engine, storage_connector, util
 from hsfs import feature_group as fg_mod
 from hsfs.constructor import join

@@ -34,6 +34,10 @@
 from hsfs.feature import Feature
 
 
+if HAS_NUMPY:
+    import numpy as np
+
+
 @typechecked
 class Query:
     ERROR_MESSAGE_FEATURE_AMBIGUOUS = (
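This is the guarded-import pattern applied throughout the commit (here and in `feature_view_engine.py` below): the module imports `numpy` only when `HAS_NUMPY` is true, and every later use of `np` must sit behind the same flag. A minimal sketch of the idea; `to_python_scalar` is an illustrative helper, not a function from this file:

```python
from hopsworks_common.core.constants import HAS_NUMPY

if HAS_NUMPY:
    import numpy as np


def to_python_scalar(value):
    # Hypothetical helper: `np` is only referenced behind the HAS_NUMPY flag,
    # so this module still imports cleanly when numpy is not installed.
    if HAS_NUMPY and isinstance(value, np.generic):
        return value.item()
    return value
```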

python/hsfs/core/feature_view_engine.py (+5 -1)

@@ -19,10 +19,10 @@
 import warnings
 from typing import Any, Dict, List, Optional, TypeVar, Union
 
-import numpy as np
 import pandas as pd
 from hopsworks_common import client
 from hopsworks_common.client.exceptions import FeatureStoreException
+from hopsworks_common.core.constants import HAS_NUMPY
 from hsfs import (
     engine,
     feature_group,

@@ -45,6 +45,10 @@
 from hsfs.training_dataset_split import TrainingDatasetSplit
 
 
+if HAS_NUMPY:
+    import numpy as np
+
+
 class FeatureViewEngine:
     ENTITY_TYPE = "featureview"
     _TRAINING_DATA_API_PATH = "trainingdatasets"

python/hsfs/core/kafka_engine.py (+5 -2)

@@ -20,14 +20,17 @@
 from io import BytesIO
 from typing import TYPE_CHECKING, Any, Callable, Dict, Literal, Optional, Tuple, Union
 
-import numpy as np
 import pandas as pd
 from hopsworks_common import client
+from hopsworks_common.core.constants import HAS_NUMPY
 from hsfs.core import storage_connector_api
 from hsfs.core.constants import HAS_AVRO, HAS_CONFLUENT_KAFKA, HAS_FAST_AVRO
 from tqdm import tqdm
 
 
+if HAS_NUMPY:
+    import numpy as np
+
 if HAS_CONFLUENT_KAFKA:
     from confluent_kafka import Consumer, KafkaError, Producer, TopicPartition
 

@@ -202,7 +205,7 @@ def encode_row(complex_feature_writers, writer, row):
     if isinstance(row, dict):
         for k in row.keys():
             # for avro to be able to serialize them, they need to be python data types
-            if isinstance(row[k], np.ndarray):
+            if HAS_NUMPY and isinstance(row[k], np.ndarray):
                 row[k] = row[k].tolist()
             if isinstance(row[k], pd.Timestamp):
                 row[k] = row[k].to_pydatetime()
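The change to `encode_row` relies on short-circuit evaluation: when NumPy is missing, `HAS_NUMPY` is `False`, so the `isinstance(row[k], np.ndarray)` check, which would otherwise raise a `NameError` because `np` was never imported, is never reached. A standalone sketch of the same normalisation step (the function name is illustrative):

```python
import pandas as pd

from hopsworks_common.core.constants import HAS_NUMPY

if HAS_NUMPY:
    import numpy as np


def normalize_row_for_avro(row: dict) -> dict:
    # Avro writers need plain Python types: arrays become lists and pandas
    # timestamps become datetimes. The numpy branch is skipped entirely when
    # numpy is absent, so no NameError can occur.
    for k in row.keys():
        if HAS_NUMPY and isinstance(row[k], np.ndarray):
            row[k] = row[k].tolist()
        if isinstance(row[k], pd.Timestamp):
            row[k] = row[k].to_pydatetime()
    return row
```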

python/hsfs/core/vector_server.py (+15 -4)

@@ -23,14 +23,15 @@
 from io import BytesIO
 from typing import Any, Callable, Dict, List, Literal, Optional, Set, Tuple, Union
 
-import avro.io
-import avro.schema
-import numpy as np
 import pandas as pd
 from hopsworks_common import client
 from hopsworks_common.core.constants import (
+    HAS_AVRO,
     HAS_FAST_AVRO,
+    HAS_NUMPY,
     HAS_POLARS,
+    avro_not_installed_message,
+    numpy_not_installed_message,
     polars_not_installed_message,
 )
 from hsfs import (

@@ -52,9 +53,14 @@
 )
 
 
+if HAS_NUMPY:
+    import numpy as np
+
 if HAS_FAST_AVRO:
     from fastavro import schemaless_reader
-else:
+if HAS_AVRO:
+    import avro.io
+    import avro.schema
     from avro.io import BinaryDecoder
 
 if HAS_POLARS:

@@ -807,6 +813,8 @@ def handle_feature_vector_return_type(
             return feature_vectorz
         elif return_type.lower() == "numpy" and not inference_helper:
            _logger.debug("Returning feature vector as numpy array")
+            if not HAS_NUMPY:
+                raise ModuleNotFoundError(numpy_not_installed_message)
             return np.array(feature_vectorz)
         # Only inference helper can return dict
         elif return_type.lower() == "dict" and inference_helper:

@@ -1064,6 +1072,9 @@ def build_complex_feature_decoders(self) -> Dict[str, Callable]:
         - deserialization of complex features from the online feature store
         - conversion of string or int timestamps to datetime objects
         """
+        if not HAS_AVRO:
+            raise ModuleNotFoundError(avro_not_installed_message)
+
         complex_feature_schemas = {
             f.name: avro.io.DatumReader(
                 avro.schema.parse(
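Rather than failing at import time, `vector_server.py` now defers the error to the point where the optional package is actually needed: asking for a `"numpy"` return type without NumPy installed, or building Avro decoders without an Avro library, raises a `ModuleNotFoundError` carrying the install hint from `constants.py`. A condensed sketch of that behaviour (the function below is illustrative, not the class's real method):

```python
from hopsworks_common.core.constants import HAS_NUMPY, numpy_not_installed_message

if HAS_NUMPY:
    import numpy as np


def as_return_type(feature_vectors, return_type: str):
    # Only a request that genuinely needs numpy triggers the error; list-like
    # return types keep working in a numpy-free environment.
    if return_type.lower() == "numpy":
        if not HAS_NUMPY:
            raise ModuleNotFoundError(numpy_not_installed_message)
        return np.array(feature_vectors)
    return feature_vectors
```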

python/hsfs/engine/python.py (+12 -5)

@@ -50,7 +50,6 @@
 
 import boto3
 import hsfs
-import numpy as np
 import pandas as pd
 import pyarrow as pa
 from botocore.response import StreamingBody

@@ -83,6 +82,7 @@
 from hsfs.core.constants import (
     HAS_AIOMYSQL,
     HAS_GREAT_EXPECTATIONS,
+    HAS_NUMPY,
     HAS_PANDAS,
     HAS_PYARROW,
     HAS_SQLALCHEMY,

@@ -98,6 +98,9 @@
 if HAS_GREAT_EXPECTATIONS:
     import great_expectations
 
+if HAS_NUMPY:
+    import numpy as np
+
 if HAS_AIOMYSQL and HAS_SQLALCHEMY:
     from hsfs.core import util_sql
 

@@ -1464,11 +1467,13 @@ def _start_offline_materialization(offline_write_options: Dict[str, Any]) -> bool:
     def _convert_feature_log_to_df(feature_log, cols) -> pd.DataFrame:
         if feature_log is None and cols:
             return pd.DataFrame(columns=cols)
-        if not (
-            isinstance(feature_log, (list, np.ndarray, pd.DataFrame, pl.DataFrame))
+        if not (isinstance(feature_log, (list, pd.DataFrame, pl.DataFrame))) or (
+            HAS_NUMPY and isinstance(feature_log, np.ndarray)
         ):
             raise ValueError(f"Type '{type(feature_log)}' not accepted")
-        if isinstance(feature_log, list) or isinstance(feature_log, np.ndarray):
+        if isinstance(feature_log, list) or (
+            HAS_NUMPY and isinstance(feature_log, np.ndarray)
+        ):
             Engine._validate_logging_list(feature_log, cols)
             return pd.DataFrame(feature_log, columns=cols)
         else:

@@ -1479,7 +1484,9 @@ def _convert_feature_log_to_df(feature_log, cols) -> pd.DataFrame:
 
     @staticmethod
     def _validate_logging_list(feature_log, cols):
-        if isinstance(feature_log[0], list) or isinstance(feature_log[0], np.ndarray):
+        if isinstance(feature_log[0], list) or (
+            HAS_NUMPY and isinstance(feature_log[0], np.ndarray)
+        ):
             provided_len = len(feature_log[0])
         else:
             provided_len = 1
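The Python engine applies the same short-circuit rule to its feature-log handling: a value is only treated as a NumPy array when `HAS_NUMPY` is true, so logs can still be passed as plain lists or DataFrames in a numpy-free environment. A hedged sketch of such a type check (`is_accepted_feature_log` is an illustrative helper, not the engine's method; the polars case is omitted for brevity):

```python
import pandas as pd

from hopsworks_common.core.constants import HAS_NUMPY

if HAS_NUMPY:
    import numpy as np


def is_accepted_feature_log(feature_log) -> bool:
    # Lists and pandas DataFrames are always accepted; numpy arrays are only
    # recognised (and accepted) when numpy is importable.
    if isinstance(feature_log, (list, pd.DataFrame)):
        return True
    return HAS_NUMPY and isinstance(feature_log, np.ndarray)
```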

python/hsfs/engine/spark.py (+96 -34)

@@ -31,15 +31,19 @@
     from pyspark.rdd import RDD
     from pyspark.sql import DataFrame
 
-import numpy as np
 import pandas as pd
 import tzlocal
+from hopsworks_common.core.constants import HAS_NUMPY, HAS_PANDAS
 from hsfs.constructor import query
 
 # in case importing in %%local
 from hsfs.core.vector_db_client import VectorDbClient
 
 
+if HAS_NUMPY:
+    import numpy as np
+
+
 try:
     import pyspark
     from pyspark import SparkFiles

@@ -258,39 +262,11 @@ def _return_dataframe_type(self, dataframe, dataframe_type):
 
     def convert_to_default_dataframe(self, dataframe):
         if isinstance(dataframe, list):
-            dataframe = np.array(dataframe)
-
-        if isinstance(dataframe, np.ndarray):
-            if dataframe.ndim != 2:
-                raise TypeError(
-                    "Cannot convert numpy array that do not have two dimensions to a dataframe. "
-                    "The number of dimensions are: {}".format(dataframe.ndim)
-                )
-            num_cols = dataframe.shape[1]
-            dataframe_dict = {}
-            for n_col in list(range(num_cols)):
-                col_name = "col_" + str(n_col)
-                dataframe_dict[col_name] = dataframe[:, n_col]
-            dataframe = pd.DataFrame(dataframe_dict)
-
-        if isinstance(dataframe, pd.DataFrame):
-            # convert timestamps to current timezone
-            local_tz = tzlocal.get_localzone()
-            # make shallow copy so the original df does not get changed
-            dataframe_copy = dataframe.copy(deep=False)
-            for c in dataframe_copy.columns:
-                if isinstance(
-                    dataframe_copy[c].dtype, pd.core.dtypes.dtypes.DatetimeTZDtype
-                ):
-                    # convert to utc timestamp
-                    dataframe_copy[c] = dataframe_copy[c].dt.tz_convert(None)
-                if dataframe_copy[c].dtype == np.dtype("datetime64[ns]"):
-                    # set the timezone to the client's timezone because that is
-                    # what spark expects.
-                    dataframe_copy[c] = dataframe_copy[c].dt.tz_localize(
-                        str(local_tz), ambiguous="infer", nonexistent="shift_forward"
-                    )
-            dataframe = self._spark_session.createDataFrame(dataframe_copy)
+            dataframe = self.convert_list_to_spark_dataframe(dataframe)
+        elif HAS_NUMPY and isinstance(dataframe, np.ndarray):
+            dataframe = self.convert_numpy_to_spark_dataframe(dataframe)
+        elif HAS_PANDAS and isinstance(dataframe, pd.DataFrame):
+            dataframe = self.convert_pandas_to_spark_dataframe(dataframe)
         elif isinstance(dataframe, RDD):
             dataframe = dataframe.toDF()
 

@@ -341,6 +317,92 @@ def convert_to_default_dataframe(self, dataframe):
                 )
             )
 
+    @staticmethod
+    def utc_disguised_as_local(dt):
+        local_tz = tzlocal.get_localzone()
+        utc = timezone.utc
+        if not dt.tzinfo:
+            dt = dt.replace(tzinfo=utc)
+        return dt.astimezone(utc).replace(tzinfo=local_tz)
+
+    def convert_list_to_spark_dataframe(self, dataframe):
+        if HAS_NUMPY:
+            return self.convert_numpy_to_spark_dataframe(np.array(dataframe))
+        try:
+            dataframe[0][0]
+        except TypeError:
+            raise TypeError(
+                "Cannot convert a list that has less than two dimensions to a dataframe."
+            ) from None
+        ok = False
+        try:
+            dataframe[0][0][0]
+        except TypeError:
+            ok = True
+        if not ok:
+            raise TypeError(
+                "Cannot convert a list that has more than two dimensions to a dataframe."
+            ) from None
+        num_cols = len(dataframe[0])
+        if HAS_PANDAS:
+            dataframe_dict = {}
+            for n_col in range(num_cols):
+                c = "col_" + str(n_col)
+                dataframe_dict[c] = [dataframe[i][n_col] for i in range(len(dataframe))]
+            return self.convert_pandas_to_spark_dataframe(pd.DataFrame(dataframe_dict))
+        for i in range(len(dataframe)):
+            dataframe[i] = [
+                self.utc_disguised_as_local(d) if isinstance(d, datetime) else d
+                for d in dataframe[i]
+            ]
+        return self._spark_session.createDataFrame(
+            dataframe, ["col_" + str(n) for n in range(num_cols)]
+        )
+
+    def convert_numpy_to_spark_dataframe(self, dataframe):
+        if dataframe.ndim != 2:
+            raise TypeError(
+                "Cannot convert numpy array that do not have two dimensions to a dataframe. "
+                "The number of dimensions are: {}".format(dataframe.ndim)
+            )
+        num_cols = dataframe.shape[1]
+        if HAS_PANDAS:
+            dataframe_dict = {}
+            for n_col in range(num_cols):
+                c = "col_" + str(n_col)
+                dataframe_dict[c] = dataframe[:, n_col]
+            return self.convert_pandas_to_spark_dataframe(pd.DataFrame(dataframe_dict))
+        # convert timestamps to current timezone
+        for n_col in range(num_cols):
+            if dataframe[:, n_col].dtype == np.dtype("datetime64[ns]"):
+                # set the timezone to the client's timezone because that is
+                # what spark expects.
+                dataframe[:, n_col] = np.array(
+                    [self.utc_disguised_as_local(d.item()) for d in dataframe[:, n_col]]
+                )
+        return self._spark_session.createDataFrame(
+            dataframe.tolist(), ["col_" + str(n) for n in range(num_cols)]
+        )
+
+    def convert_pandas_to_spark_dataframe(self, dataframe):
+        # convert timestamps to current timezone
+        local_tz = tzlocal.get_localzone()
+        # make shallow copy so the original df does not get changed
+        dataframe_copy = dataframe.copy(deep=False)
+        for c in dataframe_copy.columns:
+            if isinstance(
+                dataframe_copy[c].dtype, pd.core.dtypes.dtypes.DatetimeTZDtype
+            ):
+                # convert to utc timestamp
+                dataframe_copy[c] = dataframe_copy[c].dt.tz_convert(None)
+            if HAS_NUMPY and dataframe_copy[c].dtype == np.dtype("datetime64[ns]"):
+                # set the timezone to the client's timezone because that is
+                # what spark expects.
+                dataframe_copy[c] = dataframe_copy[c].dt.tz_localize(
+                    str(local_tz), ambiguous="infer", nonexistent="shift_forward"
+                )
+        return self._spark_session.createDataFrame(dataframe_copy)
+
     def save_dataframe(
         self,
         feature_group,
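The refactor splits the old monolithic conversion into per-type helpers: lists go through `convert_list_to_spark_dataframe` (which falls back to a pure-Python path using `utc_disguised_as_local` when NumPy is absent), arrays through `convert_numpy_to_spark_dataframe`, and pandas frames through `convert_pandas_to_spark_dataframe`. A hedged usage sketch; constructing the Spark `Engine` directly like this is illustrative, since Hopsworks normally creates it internally when a Spark session is available:

```python
from hsfs.engine.spark import Engine  # assumes a Spark-enabled environment

engine = Engine()  # illustrative: normally created by the library, not by users

rows = [[1, "a"], [2, "b"]]                     # plain two-dimensional Python list
df = engine.convert_to_default_dataframe(rows)  # works with or without numpy installed
df.show()                                       # columns are auto-named col_0, col_1
```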
