@@ -43,13 +43,15 @@ def __getattr__(self, name):
43
43
def bbox (self , variable ):
44
44
data_array = self [variable ]
45
45
46
- coords = self ._get_xy_coords (data_array )
47
- key = ("bbox" , tuple (coords ))
46
+ keys , coords = self ._get_xy_coords (data_array )
47
+ key = ("bbox" , tuple (keys ), tuple ( coords ))
48
48
if key in self ._cache :
49
49
return self ._cache [key ]
50
50
51
- lons = self ._get_xy (data_array , "x" , flatten = False )
52
- lats = self ._get_xy (data_array , "y" , flatten = False )
51
+ # lons = self._get_xy(data_array, "x", flatten=False)
52
+ # lats = self._get_xy(data_array, "y", flatten=False)
53
+
54
+ lats , lons = self ._get_latlon (data_array , flatten = False )
53
55
54
56
north = np .amax (lats )
55
57
west = np .amin (lons )
@@ -59,143 +61,86 @@ def bbox(self, variable):
59
61
self ._cache [key ] = (north , west , south , east )
60
62
return self ._cache [key ]
61
63
62
- # dims = data_array.dims
63
-
64
- # lat = dims[-2]
65
- # lon = dims[-1]
66
-
67
- # key = ("bbox", lat, lon)
68
- # if key in self._cache:
69
- # return self._cache[key]
70
-
71
- # lats, lons = self.grid_points(variable)
72
- # north = np.amax(lats)
73
- # west = np.amin(lons)
74
- # south = np.amin(lats)
75
- # east = np.amax(lons)
76
-
77
- # self._cache[key] = (north, west, south, east)
78
- # return self._cache[key]
79
-
80
- # if (lat, lon) not in self._bbox:
81
- # dims = data_array.dims
82
-
83
- # latitude = data_array[lat]
84
- # longitude = data_array[lon]
85
-
86
- # self._bbox[(lat, lon)] = (
87
- # np.amax(latitude.data),
88
- # np.amin(longitude.data),
89
- # np.amin(latitude.data),
90
- # np.amax(longitude.data),
91
- # )
92
-
93
- # return self._bbox[(lat, lon)]
94
-
95
- # def grid_points(self, variable):
96
- # data_array = self[variable]
97
- # dims = data_array.dims
98
-
99
- # lat = dims[-2]
100
- # lon = dims[-1]
101
-
102
- # key = ("grid_points", lat, lon)
103
- # if key in self._cache:
104
- # return self._cache[key]
105
-
106
- # if "latitude" in self._ds and "longitude" in self._ds:
107
- # latitude = self._ds["latitude"]
108
- # longitude = self._ds["longitude"]
109
-
110
- # if latitude.dims == (lat, lon) and longitude.dims == (lat, lon):
111
- # latitude = latitude.data
112
- # longitude = longitude.data
113
- # return latitude.flatten(), longitude.flatten()
114
-
115
- # latitude = data_array[lat]
116
- # longitude = data_array[lon]
117
-
118
- # lat, lon = np.meshgrid(latitude.data, longitude.data)
119
-
120
- # self._cache[key] = lat.flatten(), lon.flatten()
121
- # return self._cache[key]
122
-
123
- # def grid_points_xy(self, variable):
124
- # data_array = self[variable]
125
- # dims = data_array.dims
126
-
127
- # lat = dims[-2]
128
- # lon = dims[-1]
129
-
130
- # latitude = data_array[lat].data
131
- # longitude = data_array[lon].data
132
-
133
- # print(latitude, longitude)
134
-
135
- # lat, lon = np.meshgrid(latitude, longitude)
136
-
137
- # return lat.flatten(), lon.flatten()
138
- # # return self._cache[key]
139
-
140
- def _get_xy (self , data_array , axis , flatten = False , dtype = None ):
141
- if axis not in ("x" , "y" ):
142
- raise ValueError (f"Invalid axis={ axis } " )
143
-
144
- coords = self ._get_xy_coords (data_array )
145
- key = ("grid_points" , tuple (coords ))
146
- if key in self ._cache :
147
- points = self ._cache [key ]
148
- else :
149
- points = dict ()
150
- keys = [x [0 ] for x in coords ]
151
- coords = tuple ([x [1 ] for x in coords ])
152
-
153
- if "latitude" in self ._ds and "longitude" in self ._ds :
154
- latitude = self ._ds ["latitude" ]
155
- longitude = self ._ds ["longitude" ]
156
-
157
- if latitude .dims == coords and longitude .dims == coords :
158
- latitude = latitude .data
159
- longitude = longitude .data
160
- points ["x" ] = longitude
161
- points ["y" ] = latitude
162
- if not points :
163
- v0 , v1 = data_array .coords [coords [0 ]], data_array .coords [coords [1 ]]
164
- points [keys [1 ]], points [keys [0 ]] = np .meshgrid (v1 , v0 )
165
- self ._cache [key ] = points
166
-
167
- if flatten :
168
- points [axis ] = points [axis ].reshape (- 1 )
169
- if dtype is not None :
170
- return points [axis ].astype (dtype )
171
- else :
172
- return points [axis ]
173
-
174
64
def _get_xy_coords (self , data_array ):
175
- c = []
176
-
177
65
if (
178
66
len (data_array .dims ) >= 2
179
67
and data_array .dims [- 1 ] in GEOGRAPHIC_COORDS ["x" ]
180
68
and data_array .dims [- 2 ] in GEOGRAPHIC_COORDS ["y" ]
181
69
):
182
- return [ ("y" , data_array .dims [- 2 ]), ( "x" , data_array .dims [- 1 ])]
70
+ return ("y" , "x" ), ( data_array .dims [- 2 ], data_array .dims [- 1 ])
183
71
72
+ keys = []
73
+ coords = []
184
74
axes = ("x" , "y" )
185
75
for dim in data_array .dims :
186
76
for ax in axes :
187
77
candidates = GEOGRAPHIC_COORDS .get (ax , [])
188
78
if dim in candidates :
189
- c .append ((ax , dim ))
79
+ keys .append (ax )
80
+ coords .append (dim )
190
81
else :
191
82
ax = data_array .coords [dim ].attrs .get ("axis" , "" ).lower ()
192
83
if ax in axes :
193
- c .append ([ax , dim ])
194
- if len (c ) == 2 :
195
- return c
84
+ keys .append (ax )
85
+ coords .append (dim )
86
+ if len (keys ) == 2 :
87
+ return tuple (keys ), tuple (coords )
196
88
197
89
for ax in axes :
198
- if ax not in [ x [ 0 ] for x in c ] :
90
+ if ax not in keys :
199
91
raise ValueError (f"No coordinate found with axis '{ ax } '" )
200
92
201
- return c
93
+ return keys , coords
94
+
95
+ def _get_xy (self , data_array , flatten = False , dtype = None ):
96
+ keys , coords = self ._get_xy_coords (data_array )
97
+ key = ("grid_points" , tuple (keys ), tuple (coords ))
98
+
99
+ if key in self ._cache :
100
+ points = self ._cache [key ]
101
+ else :
102
+ points = dict ()
103
+ v0 , v1 = data_array .coords [coords [0 ]], data_array .coords [coords [1 ]]
104
+ points [keys [1 ]], points [keys [0 ]] = np .meshgrid (v1 , v0 )
105
+ self ._cache [key ] = points
106
+
107
+ if flatten :
108
+ points ["x" ] = points ["x" ].reshape (- 1 )
109
+ points ["y" ] = points ["y" ].reshape (- 1 )
110
+
111
+ if dtype is not None :
112
+ return points ["x" ].astype (dtype ), points ["y" ].astype (dtype )
113
+ else :
114
+ return points ["x" ], points ["y" ]
115
+
116
+ def _get_latlon (self , data_array , flatten = False , dtype = None ):
117
+ keys , coords = self ._get_xy_coords (data_array )
118
+
119
+ points = dict ()
120
+ if "latitude" in self ._ds and "longitude" in self ._ds :
121
+ latitude = self ._ds ["latitude" ]
122
+ longitude = self ._ds ["longitude" ]
123
+ if latitude .dims == coords and longitude .dims == coords :
124
+ latitude = latitude .data
125
+ longitude = longitude .data
126
+ points ["y" ] = latitude
127
+ points ["x" ] = longitude
128
+
129
+ if not points :
130
+ key = ("grid_points" , tuple (keys ), tuple (coords ))
131
+
132
+ if key in self ._cache :
133
+ points = self ._cache [key ]
134
+ else :
135
+ v0 , v1 = data_array .coords [coords [0 ]], data_array .coords [coords [1 ]]
136
+ points [keys [1 ]], points [keys [0 ]] = np .meshgrid (v1 , v0 )
137
+ self ._cache [key ] = points
138
+
139
+ if flatten :
140
+ points ["x" ] = points ["x" ].reshape (- 1 )
141
+ points ["y" ] = points ["y" ].reshape (- 1 )
142
+
143
+ if dtype is not None :
144
+ return points ["y" ].astype (dtype ), points ["x" ].astype (dtype )
145
+ else :
146
+ return points ["y" ], points ["x" ]
0 commit comments