@@ -31,15 +31,19 @@
 from pyspark.rdd import RDD
 from pyspark.sql import DataFrame
 
-import numpy as np
 import pandas as pd
 import tzlocal
+from hopsworks_common.core.constants import HAS_NUMPY, HAS_PANDAS
 from hsfs.constructor import query
 
 # in case importing in %%local
 from hsfs.core.vector_db_client import VectorDbClient
 
 
+if HAS_NUMPY:
+    import numpy as np
+
+
 try:
     import pyspark
     from pyspark import SparkFiles
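Note: hopsworks_common.core.constants is not part of this diff. A minimal sketch of the availability flags it presumably exposes, assuming a plain find_spec check (the real module may differ):

    import importlib.util

    # hypothetical definition of the HAS_NUMPY / HAS_PANDAS flags imported above
    HAS_NUMPY = importlib.util.find_spec("numpy") is not None
    HAS_PANDAS = importlib.util.find_spec("pandas") is not None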
@@ -258,39 +262,11 @@ def _return_dataframe_type(self, dataframe, dataframe_type):
 
     def convert_to_default_dataframe(self, dataframe):
         if isinstance(dataframe, list):
-            dataframe = np.array(dataframe)
-
-        if isinstance(dataframe, np.ndarray):
-            if dataframe.ndim != 2:
-                raise TypeError(
-                    "Cannot convert numpy array that do not have two dimensions to a dataframe. "
-                    "The number of dimensions are: {}".format(dataframe.ndim)
-                )
-            num_cols = dataframe.shape[1]
-            dataframe_dict = {}
-            for n_col in list(range(num_cols)):
-                col_name = "col_" + str(n_col)
-                dataframe_dict[col_name] = dataframe[:, n_col]
-            dataframe = pd.DataFrame(dataframe_dict)
-
-        if isinstance(dataframe, pd.DataFrame):
-            # convert timestamps to current timezone
-            local_tz = tzlocal.get_localzone()
-            # make shallow copy so the original df does not get changed
-            dataframe_copy = dataframe.copy(deep=False)
-            for c in dataframe_copy.columns:
-                if isinstance(
-                    dataframe_copy[c].dtype, pd.core.dtypes.dtypes.DatetimeTZDtype
-                ):
-                    # convert to utc timestamp
-                    dataframe_copy[c] = dataframe_copy[c].dt.tz_convert(None)
-                if dataframe_copy[c].dtype == np.dtype("datetime64[ns]"):
-                    # set the timezone to the client's timezone because that is
-                    # what spark expects.
-                    dataframe_copy[c] = dataframe_copy[c].dt.tz_localize(
-                        str(local_tz), ambiguous="infer", nonexistent="shift_forward"
-                    )
-            dataframe = self._spark_session.createDataFrame(dataframe_copy)
+            dataframe = self.convert_list_to_spark_dataframe(dataframe)
+        elif HAS_NUMPY and isinstance(dataframe, np.ndarray):
+            dataframe = self.convert_numpy_to_spark_dataframe(dataframe)
+        elif HAS_PANDAS and isinstance(dataframe, pd.DataFrame):
+            dataframe = self.convert_pandas_to_spark_dataframe(dataframe)
         elif isinstance(dataframe, RDD):
             dataframe = dataframe.toDF()
 
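Note: after this refactor, convert_to_default_dataframe only dispatches on the input type. A hedged usage sketch, where engine stands for an instance of this Spark engine class with an active _spark_session (an assumption of the example):

    rows = [[1, "a"], [2, "b"]]  # python 2D list
    spark_df = engine.convert_to_default_dataframe(rows)
    spark_df.show()  # columns are auto-named col_0, col_1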
@@ -341,6 +317,96 @@ def convert_to_default_dataframe(self, dataframe):
                 )
             )
 
+    @staticmethod
+    def utc_disguised_as_local(dt):
+        # treat naive datetimes as utc, then tag the utc wall-clock value
+        # with the client's timezone, since spark treats datetimes as local
+        local_tz = tzlocal.get_localzone()
+        utc = timezone.utc
+        if not dt.tzinfo:
+            dt = dt.replace(tzinfo=utc)
+        return dt.astimezone(utc).replace(tzinfo=local_tz)
+
+    def convert_list_to_spark_dataframe(self, dataframe):
+        if HAS_NUMPY:
+            return self.convert_numpy_to_spark_dataframe(np.array(dataframe))
+        # without numpy, probe the nesting depth by indexing: a 2D list of
+        # scalars supports dataframe[0][0] but raises on dataframe[0][0][0]
+        try:
+            dataframe[0][0]
+        except TypeError:
+            raise TypeError(
+                "Cannot convert a list that has less than two dimensions to a dataframe."
+            ) from None
+        ok = False
+        try:
+            dataframe[0][0][0]
+        except TypeError:
+            ok = True
+        if not ok:
+            raise TypeError(
+                "Cannot convert a list that has more than two dimensions to a dataframe."
+            ) from None
+        num_cols = len(dataframe[0])
+        if HAS_PANDAS:
+            dataframe_dict = {}
+            for n_col in range(num_cols):
+                c = "col_" + str(n_col)
+                dataframe_dict[c] = [dataframe[i][n_col] for i in range(len(dataframe))]
+            return self.convert_pandas_to_spark_dataframe(pd.DataFrame(dataframe_dict))
+        for i in range(len(dataframe)):
+            dataframe[i] = [
+                self.utc_disguised_as_local(d) if isinstance(d, datetime) else d
+                for d in dataframe[i]
+            ]
+        return self._spark_session.createDataFrame(
+            dataframe, ["col_" + str(n) for n in range(num_cols)]
+        )
+
+    def convert_numpy_to_spark_dataframe(self, dataframe):
+        if dataframe.ndim != 2:
+            raise TypeError(
+                "Cannot convert a numpy array that does not have two dimensions to a dataframe. "
+                "The number of dimensions is: {}".format(dataframe.ndim)
+            )
+        num_cols = dataframe.shape[1]
+        if HAS_PANDAS:
+            dataframe_dict = {}
+            for n_col in range(num_cols):
+                c = "col_" + str(n_col)
+                dataframe_dict[c] = dataframe[:, n_col]
+            return self.convert_pandas_to_spark_dataframe(pd.DataFrame(dataframe_dict))
+        # convert timestamps to current timezone
+        for n_col in range(num_cols):
+            if dataframe[:, n_col].dtype == np.dtype("datetime64[ns]"):
+                # set the timezone to the client's timezone because that is
+                # what spark expects.
+                dataframe[:, n_col] = np.array(
+                    [self.utc_disguised_as_local(d.item()) for d in dataframe[:, n_col]]
+                )
+        return self._spark_session.createDataFrame(
+            dataframe.tolist(), ["col_" + str(n) for n in range(num_cols)]
+        )
+
+    def convert_pandas_to_spark_dataframe(self, dataframe):
+        # convert timestamps to current timezone
+        local_tz = tzlocal.get_localzone()
+        # make shallow copy so the original df does not get changed
+        dataframe_copy = dataframe.copy(deep=False)
+        for c in dataframe_copy.columns:
+            if isinstance(
+                dataframe_copy[c].dtype, pd.core.dtypes.dtypes.DatetimeTZDtype
+            ):
+                # convert to utc timestamp
+                dataframe_copy[c] = dataframe_copy[c].dt.tz_convert(None)
+            if HAS_NUMPY and dataframe_copy[c].dtype == np.dtype("datetime64[ns]"):
+                # set the timezone to the client's timezone because that is
+                # what spark expects.
+                dataframe_copy[c] = dataframe_copy[c].dt.tz_localize(
+                    str(local_tz), ambiguous="infer", nonexistent="shift_forward"
+                )
+        return self._spark_session.createDataFrame(dataframe_copy)
+
     def save_dataframe(
         self,
         feature_group,