Commit bfbe7aa (1 parent: f06ee3f)

Attempt making numpy optional in convert_to_default_dataframe

1 file changed: python/hsfs/engine/spark.py (+81 -58 lines)
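
The commit leans on the HAS_NUMPY and HAS_PANDAS flags imported from hopsworks_common.core.constants. Presumably these follow the usual optional-import pattern, along these lines (a sketch for orientation, not the actual module):

    # Sketch of the optional-dependency guards this commit relies on;
    # the real flags live in hopsworks_common.core.constants.
    try:
        import numpy as np  # noqa: F401
        HAS_NUMPY = True
    except ImportError:
        HAS_NUMPY = False

    try:
        import pandas as pd  # noqa: F401
        HAS_PANDAS = True
    except ImportError:
        HAS_PANDAS = False
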
@@ -33,7 +33,7 @@
 
 import pandas as pd
 import tzlocal
-from hopsworks_common.core.constants import HAS_NUMPY
+from hopsworks_common.core.constants import HAS_NUMPY, HAS_PANDAS
 from hsfs.constructor import query
 
 # in case importing in %%local
@@ -262,63 +262,11 @@ def _return_dataframe_type(self, dataframe, dataframe_type):
 
     def convert_to_default_dataframe(self, dataframe):
         if isinstance(dataframe, list):
-            #################### TODO TODO TODO TODO TODO ####################
-            if HAS_NUMPY:
-                dataframe = np.array(dataframe)
-            else:
-                try:
-                    dataframe[0][0]
-                except TypeError:
-                    raise TypeError(
-                        "Cannot convert a list that has less than two dimensions to a dataframe."
-                    ) from None
-                ok = False
-                try:
-                    dataframe[0][0][0]
-                except TypeError:
-                    ok = True
-                if not ok:
-                    raise TypeError(
-                        "Cannot convert a list that has more than two dimensions to a dataframe."
-                    ) from None
-                num_cols = len(dataframe[0])
-                dataframe_dict = {}
-                for n_col in list(range(num_cols)):
-                    col_name = "col_" + str(n_col)
-                    dataframe_dict[col_name] = dataframe[:, n_col]
-                dataframe = pd.DataFrame(dataframe_dict)
-
-        if HAS_NUMPY and isinstance(dataframe, np.ndarray):
-            if dataframe.ndim != 2:
-                raise TypeError(
-                    "Cannot convert numpy array that do not have two dimensions to a dataframe. "
-                    "The number of dimensions are: {}".format(dataframe.ndim)
-                )
-            num_cols = dataframe.shape[1]
-            dataframe_dict = {}
-            for n_col in list(range(num_cols)):
-                col_name = "col_" + str(n_col)
-                dataframe_dict[col_name] = dataframe[:, n_col]
-            dataframe = pd.DataFrame(dataframe_dict)
-
-        if isinstance(dataframe, pd.DataFrame):
-            # convert timestamps to current timezone
-            local_tz = tzlocal.get_localzone()
-            # make shallow copy so the original df does not get changed
-            dataframe_copy = dataframe.copy(deep=False)
-            for c in dataframe_copy.columns:
-                if isinstance(
-                    dataframe_copy[c].dtype, pd.core.dtypes.dtypes.DatetimeTZDtype
-                ):
-                    # convert to utc timestamp
-                    dataframe_copy[c] = dataframe_copy[c].dt.tz_convert(None)
-                if HAS_NUMPY and dataframe_copy[c].dtype == np.dtype("datetime64[ns]"):
-                    # set the timezone to the client's timezone because that is
-                    # what spark expects.
-                    dataframe_copy[c] = dataframe_copy[c].dt.tz_localize(
-                        str(local_tz), ambiguous="infer", nonexistent="shift_forward"
-                    )
-            dataframe = self._spark_session.createDataFrame(dataframe_copy)
+            dataframe = self.convert_list_to_spark_dataframe(dataframe)
+        elif HAS_NUMPY and isinstance(dataframe, np.ndarray):
+            dataframe = self.convert_numpy_to_spark_dataframe(dataframe)
+        elif HAS_PANDAS and isinstance(dataframe, pd.DataFrame):
+            dataframe = self.convert_pandas_to_spark_dataframe(dataframe)
         elif isinstance(dataframe, RDD):
             dataframe = dataframe.toDF()
 
@@ -369,6 +317,81 @@ def convert_to_default_dataframe(self, dataframe):
             )
         )
 
+    def convert_list_to_spark_dataframe(self, dataframe):
+        if HAS_NUMPY:
+            return self.convert_numpy_to_spark_dataframe(np.array(dataframe))
+        try:
+            dataframe[0][0]
+        except TypeError:
+            raise TypeError(
+                "Cannot convert a list that has less than two dimensions to a dataframe."
+            ) from None
+        ok = False
+        try:
+            dataframe[0][0][0]
+        except TypeError:
+            ok = True
+        if not ok:
+            raise TypeError(
+                "Cannot convert a list that has more than two dimensions to a dataframe."
+            ) from None
+        num_cols = len(dataframe[0])
+        if HAS_PANDAS:
+            dataframe_dict = {}
+            for n_col in range(num_cols):
+                c = "col_" + str(n_col)
+                dataframe_dict[c] = [dataframe[i][n_col] for i in range(len(dataframe))]
+            return self.convert_pandas_to_spark_dataframe(pd.DataFrame(dataframe_dict))
+        # We have neither numpy nor pandas, so there is no need to transform timestamps
+        return self._spark_session.createDataFrame(
+            dataframe, ["col_" + str(n) for n in range(num_cols)]
+        )
+
+    def convert_numpy_to_spark_dataframe(self, dataframe):
+        if dataframe.ndim != 2:
+            raise TypeError(
+                "Cannot convert numpy array that do not have two dimensions to a dataframe. "
+                "The number of dimensions are: {}".format(dataframe.ndim)
+            )
+        num_cols = dataframe.shape[1]
+        if HAS_PANDAS:
+            dataframe_dict = {}
+            for n_col in range(num_cols):
+                c = "col_" + str(n_col)
+                dataframe_dict[c] = dataframe[:, n_col]
+            return self.convert_pandas_to_spark_dataframe(pd.DataFrame(dataframe_dict))
+        # convert timestamps to current timezone
+        local_tz = tzlocal.get_localzone()
+        for n_col in range(num_cols):
+            if dataframe[:, n_col].dtype == np.dtype("datetime64[ns]"):
+                # set the timezone to the client's timezone because that is
+                # what spark expects.
+                dataframe[:, n_col] = dataframe[:, n_col].map(
+                    lambda d: local_tz.fromutc(d.item().astimezone(local_tz))
+                )
+        return self._spark_session.createDataFrame(
+            dataframe.tolist(), ["col_" + str(n) for n in range(num_cols)]
+        )
+
+    def convert_pandas_to_spark_dataframe(self, dataframe):
+        # convert timestamps to current timezone
+        local_tz = tzlocal.get_localzone()
+        # make shallow copy so the original df does not get changed
+        dataframe_copy = dataframe.copy(deep=False)
+        for c in dataframe_copy.columns:
+            if isinstance(
+                dataframe_copy[c].dtype, pd.core.dtypes.dtypes.DatetimeTZDtype
+            ):
+                # convert to utc timestamp
+                dataframe_copy[c] = dataframe_copy[c].dt.tz_convert(None)
+            if HAS_NUMPY and dataframe_copy[c].dtype == np.dtype("datetime64[ns]"):
+                # set the timezone to the client's timezone because that is
+                # what spark expects.
+                dataframe_copy[c] = dataframe_copy[c].dt.tz_localize(
+                    str(local_tz), ambiguous="infer", nonexistent="shift_forward"
+                )
+        return self._spark_session.createDataFrame(dataframe_copy)
+
     def save_dataframe(
         self,
         feature_group,
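
The pure-Python branch of the new convert_list_to_spark_dataframe boils down to handing Spark the raw rows together with generated column names. A runnable sketch of that fallback, assuming only pyspark is installed (the local SparkSession here is for illustration):

    from pyspark.sql import SparkSession

    spark = SparkSession.builder.master("local[*]").getOrCreate()

    # What the new fallback does when neither numpy nor pandas is available:
    rows = [[1, "a"], [2, "b"]]  # a two-dimensional list
    num_cols = len(rows[0])
    df = spark.createDataFrame(rows, ["col_" + str(n) for n in range(num_cols)])
    df.show()  # two rows with auto-generated columns col_0, col_1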

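One caveat in the numpy-only branch of convert_numpy_to_spark_dataframe: numpy.ndarray has no .map method, so the datetime64[ns] conversion would raise AttributeError when pandas is missing. A minimal corrected sketch of that branch, keeping the commit's conversion lambda but working on an object array (numpy cannot store timezone-aware datetimes); the function name is hypothetical:

    import numpy as np
    import tzlocal

    def localize_numpy_rows(dataframe):
        # hypothetical fix-up for the pandas-less numpy branch
        local_tz = tzlocal.get_localzone()
        if dataframe.dtype == np.dtype("datetime64[ns]"):
            # elements become naive datetime.datetime after astype(object)
            dataframe = dataframe.astype(object)
            for n_col in range(dataframe.shape[1]):
                dataframe[:, n_col] = [
                    local_tz.fromutc(d.astimezone(local_tz))
                    for d in dataframe[:, n_col]
                ]
        return dataframe.tolist()  # rows ready for spark.createDataFrame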