|
16 | 16 | from __future__ import annotations
|
17 | 17 |
|
18 | 18 | import json
|
| 19 | +import logging |
19 | 20 | import warnings
|
20 | 21 | from datetime import date, datetime
|
21 | 22 | from typing import Any, Dict, List, Optional, Tuple, TypeVar, Union
|
|
34 | 35 | from hsfs.feature import Feature
|
35 | 36 |
|
36 | 37 |
|
| 38 | +_logger = logging.getLogger(__name__) |
| 39 | + |
| 40 | + |
37 | 41 | if HAS_NUMPY:
|
38 | 42 | import numpy as np
|
39 | 43 |
|
@@ -138,6 +142,100 @@ def _prep_read(
|
138 | 142 |
|
139 | 143 | return sql_query, online_conn
|
140 | 144 |
|
def check_ambiguous_features(self) -> None:
    """
    Look for ambiguous features in the query and warn about them.

    The result of `get_ambiguous_features()` is stored on `self._ambiguous_features_in_query`.
    When that mapping is non-empty, one warning is logged through the module-level `_logger`,
    listing each feature group together with its ambiguous feature names.
    """
    self._ambiguous_features_in_query = self.get_ambiguous_features()
    found = self._ambiguous_features_in_query
    if not found:
        # Nothing is ambiguous, so there is nothing to report.
        return

    # One "`f1`, `f2` in feature group `fg`, " fragment per feature group, in mapping order.
    per_group = "".join(
        "`{}` in feature group `{}`, ".format("`, `".join(names), group_name)
        for group_name, names in found.items()
    )
    _logger.warning(
        "Ambiguous features found in the query. The feature(s) "
        + per_group
        + "is ambiguous. Automatically prefixing the features selected in these feature groups with the feature group name."
    )
| 161 | + |
def get_ambiguous_features_in_joins(
    self,
    joins: List[join.Join],
    selected_features: set[str],
    ambiguous_feature_feature_group_mapping: Dict[str, set[str]],
) -> tuple[Dict[str, set[str]], set[str]]:
    """
    Walk the given joins, and recursively their nested joins, recording every feature name that clashes with an already selected feature.

    Both `selected_features` and `ambiguous_feature_feature_group_mapping` are updated in place, and the same two objects are returned.

    # Arguments
        `joins`: Joins to inspect.
        `selected_features`: Names of the features selected so far. The feature names of every visited join (with the join prefix prepended when one is set) are added to it.
        `ambiguous_feature_feature_group_mapping`: Mapping from feature group name to the set of ambiguous feature names found so far for that feature group.

    # Returns
        `Dict[str, set[str]]`: Mapping from feature group name to the set of ambiguous feature names in that feature group.
        `set[str]`: Set of selected feature names, including those contributed by the visited joins.
    """
    mapping = ambiguous_feature_feature_group_mapping

    for current_join in joins:
        sub_query = current_join._query
        prefix = current_join.prefix

        # Names under which this join exposes its features (prefixed when a prefix is set).
        names_in_join = set()
        for feature in sub_query._left_features:
            names_in_join.add(prefix + feature.name if prefix else feature.name)

        # A name is ambiguous when something selected earlier already uses it.
        clashes = names_in_join.intersection(selected_features)
        if clashes:
            group_name = sub_query._left_feature_group.name
            mapping[group_name] = mapping.get(group_name, set()).union(clashes)

        selected_features.update(names_in_join)

        nested_joins = sub_query.joins
        if nested_joins:
            # The recursive call mutates and returns the very same mapping and set,
            # so its return value does not need to be merged back.
            self.get_ambiguous_features_in_joins(
                nested_joins, selected_features, mapping
            )

    return mapping, selected_features
| 219 | + |
def get_ambiguous_features(self: Query) -> Dict[str, set[str]]:
    """
    Collect the ambiguous features of this query, grouped by feature group name.

    The search starts from the names of the features in `self._left_features` and then inspects `self._joins` through `get_ambiguous_features_in_joins`.

    # Returns
        `Dict[str, set[str]]`: Mapping from feature group name to the set of ambiguous feature names in that feature group. Empty when nothing is ambiguous.
    """
    # Seed with the names selected on the left side of the query.
    left_names = set()
    for feature in self._left_features:
        left_names.add(feature.name)

    # Only the mapping is of interest here; the accumulated name set is discarded.
    mapping, _ = self.get_ambiguous_features_in_joins(self._joins, left_names, {})
    return mapping
| 238 | + |
141 | 239 | def read(
|
142 | 240 | self,
|
143 | 241 | online: bool = False,
|
@@ -302,6 +400,8 @@ def join(
|
302 | 400 | )
|
303 | 401 | )
|
304 | 402 |
|
| 403 | + self.check_ambiguous_features() |
| 404 | + |
305 | 405 | return self
|
306 | 406 |
|
307 | 407 | def as_of(
|
|
0 commit comments