@@ -33,7 +33,7 @@

 import pandas as pd
 import tzlocal
-from hopsworks_common.core.constants import HAS_NUMPY
+from hopsworks_common.core.constants import HAS_NUMPY, HAS_PANDAS
 from hsfs.constructor import query

 # in case importing in %%local
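The added HAS_PANDAS flag mirrors HAS_NUMPY: each marks whether an optional dependency is importable, so the engine can degrade gracefully when numpy or pandas is missing. As a rough sketch of how such availability flags are typically defined (the real definitions live in hopsworks_common.core.constants and may differ):

import importlib.util

# True when the optional dependency is installed in this environment.
HAS_NUMPY: bool = importlib.util.find_spec("numpy") is not None
HAS_PANDAS: bool = importlib.util.find_spec("pandas") is not None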
@@ -262,63 +262,11 @@ def _return_dataframe_type(self, dataframe, dataframe_type):

     def convert_to_default_dataframe(self, dataframe):
         if isinstance(dataframe, list):
-            #################### TODO TODO TODO TODO TODO ####################
-            if HAS_NUMPY:
-                dataframe = np.array(dataframe)
-            else:
-                try:
-                    dataframe[0][0]
-                except TypeError:
-                    raise TypeError(
-                        "Cannot convert a list that has less than two dimensions to a dataframe."
-                    ) from None
-                ok = False
-                try:
-                    dataframe[0][0][0]
-                except TypeError:
-                    ok = True
-                if not ok:
-                    raise TypeError(
-                        "Cannot convert a list that has more than two dimensions to a dataframe."
-                    ) from None
-                num_cols = len(dataframe[0])
-                dataframe_dict = {}
-                for n_col in list(range(num_cols)):
-                    col_name = "col_" + str(n_col)
-                    dataframe_dict[col_name] = dataframe[:, n_col]
-                dataframe = pd.DataFrame(dataframe_dict)
-
-        if HAS_NUMPY and isinstance(dataframe, np.ndarray):
-            if dataframe.ndim != 2:
-                raise TypeError(
-                    "Cannot convert numpy array that do not have two dimensions to a dataframe. "
-                    "The number of dimensions are: {}".format(dataframe.ndim)
-                )
-            num_cols = dataframe.shape[1]
-            dataframe_dict = {}
-            for n_col in list(range(num_cols)):
-                col_name = "col_" + str(n_col)
-                dataframe_dict[col_name] = dataframe[:, n_col]
-            dataframe = pd.DataFrame(dataframe_dict)
-
-        if isinstance(dataframe, pd.DataFrame):
-            # convert timestamps to current timezone
-            local_tz = tzlocal.get_localzone()
-            # make shallow copy so the original df does not get changed
-            dataframe_copy = dataframe.copy(deep=False)
-            for c in dataframe_copy.columns:
-                if isinstance(
-                    dataframe_copy[c].dtype, pd.core.dtypes.dtypes.DatetimeTZDtype
-                ):
-                    # convert to utc timestamp
-                    dataframe_copy[c] = dataframe_copy[c].dt.tz_convert(None)
-                if HAS_NUMPY and dataframe_copy[c].dtype == np.dtype("datetime64[ns]"):
-                    # set the timezone to the client's timezone because that is
-                    # what spark expects.
-                    dataframe_copy[c] = dataframe_copy[c].dt.tz_localize(
-                        str(local_tz), ambiguous="infer", nonexistent="shift_forward"
-                    )
-            dataframe = self._spark_session.createDataFrame(dataframe_copy)
+            dataframe = self.convert_list_to_spark_dataframe(dataframe)
+        elif HAS_NUMPY and isinstance(dataframe, np.ndarray):
+            dataframe = self.convert_numpy_to_spark_dataframe(dataframe)
+        elif HAS_PANDAS and isinstance(dataframe, pd.DataFrame):
+            dataframe = self.convert_pandas_to_spark_dataframe(dataframe)
         elif isinstance(dataframe, RDD):
             dataframe = dataframe.toDF()

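With the inlined logic factored out into dedicated helpers, convert_to_default_dataframe reduces to type-based dispatch. A hedged usage sketch (the converter and class names come from this diff and the hsfs package layout; the inputs are illustrative and an active Spark session is assumed):

import pandas as pd
from hsfs.engine.spark import Engine

engine = Engine()  # assumes Spark is available in this environment

# Each supported input type is routed to its dedicated converter:
spark_df_from_list = engine.convert_to_default_dataframe([[1, "a"], [2, "b"]])
spark_df_from_pandas = engine.convert_to_default_dataframe(
    pd.DataFrame({"id": [1, 2], "label": ["a", "b"]})
)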
@@ -369,6 +317,81 @@ def convert_to_default_dataframe(self, dataframe):

                 )
             )

+    def convert_list_to_spark_dataframe(self, dataframe):
+        if HAS_NUMPY:
+            return self.convert_numpy_to_spark_dataframe(np.array(dataframe))
+        try:
+            dataframe[0][0]
+        except TypeError:
+            raise TypeError(
+                "Cannot convert a list that has less than two dimensions to a dataframe."
+            ) from None
+        ok = False
+        try:
+            dataframe[0][0][0]
+        except TypeError:
+            ok = True
+        if not ok:
+            raise TypeError(
+                "Cannot convert a list that has more than two dimensions to a dataframe."
+            ) from None
+        num_cols = len(dataframe[0])
+        if HAS_PANDAS:
+            dataframe_dict = {}
+            for n_col in range(num_cols):
+                c = "col_" + str(n_col)
+                dataframe_dict[c] = [dataframe[i][n_col] for i in range(len(dataframe))]
+            return self.convert_pandas_to_spark_dataframe(pd.DataFrame(dataframe_dict))
+        # We have neither numpy nor pandas, so there is no need to transform timestamps
+        return self._spark_session.createDataFrame(
+            dataframe, ["col_" + str(n) for n in range(num_cols)]
+        )
+
+    def convert_numpy_to_spark_dataframe(self, dataframe):
+        if dataframe.ndim != 2:
+            raise TypeError(
+                "Cannot convert a numpy array that does not have two dimensions to a dataframe. "
+                "The number of dimensions is: {}".format(dataframe.ndim)
+            )
+        num_cols = dataframe.shape[1]
+        if HAS_PANDAS:
+            dataframe_dict = {}
+            for n_col in range(num_cols):
+                c = "col_" + str(n_col)
+                dataframe_dict[c] = dataframe[:, n_col]
+            return self.convert_pandas_to_spark_dataframe(pd.DataFrame(dataframe_dict))
+        # convert timestamps to current timezone
+        local_tz = tzlocal.get_localzone()
+        # a 2D ndarray has a single dtype, so the datetime check applies to
+        # the whole array rather than per column
+        if dataframe.dtype == np.dtype("datetime64[ns]"):
+            # set the timezone to the client's timezone because that is
+            # what spark expects. Cast to microsecond precision so .item()
+            # yields datetime.datetime (at nanosecond precision it yields a
+            # raw integer), and rebuild as an object array, since ndarrays
+            # can neither store timezone-aware datetimes nor be .map()-ed.
+            dataframe = np.array(
+                [
+                    [
+                        local_tz.fromutc(
+                            d.astype("datetime64[us]").item().replace(tzinfo=local_tz)
+                        )
+                        for d in row
+                    ]
+                    for row in dataframe
+                ],
+                dtype=object,
+            )
+        return self._spark_session.createDataFrame(
+            dataframe.tolist(), ["col_" + str(n) for n in range(num_cols)]
+        )
+
+    def convert_pandas_to_spark_dataframe(self, dataframe):
+        # convert timestamps to current timezone
+        local_tz = tzlocal.get_localzone()
+        # make shallow copy so the original df does not get changed
+        dataframe_copy = dataframe.copy(deep=False)
+        for c in dataframe_copy.columns:
+            if isinstance(
+                dataframe_copy[c].dtype, pd.core.dtypes.dtypes.DatetimeTZDtype
+            ):
+                # convert to utc timestamp
+                dataframe_copy[c] = dataframe_copy[c].dt.tz_convert(None)
+            if HAS_NUMPY and dataframe_copy[c].dtype == np.dtype("datetime64[ns]"):
+                # set the timezone to the client's timezone because that is
+                # what spark expects.
+                dataframe_copy[c] = dataframe_copy[c].dt.tz_localize(
+                    str(local_tz), ambiguous="infer", nonexistent="shift_forward"
+                )
+        return self._spark_session.createDataFrame(dataframe_copy)
+
     def save_dataframe(
         self,
         feature_group,
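The timestamp normalization in convert_pandas_to_spark_dataframe can be reproduced in isolation. A small pure-pandas illustration of the two branches (column values are made up):

import pandas as pd
import tzlocal

local_tz = tzlocal.get_localzone()

# tz-aware column: the first branch strips the timezone, yielding naive UTC.
aware = pd.Series(pd.to_datetime(["2024-01-01 12:00"]).tz_localize("UTC"))
print(aware.dt.tz_convert(None))  # 2024-01-01 12:00:00, dtype datetime64[ns]

# naive column: the second branch stamps it with the client's timezone,
# which is what Spark expects when serializing timestamps.
naive = pd.Series(pd.to_datetime(["2024-01-01 12:00"]))
print(
    naive.dt.tz_localize(
        str(local_tz), ambiguous="infer", nonexistent="shift_forward"
    )
)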