22
22
from typing import Callable , Mapping , Sequence , TypeVar , Union
23
23
24
24
from absl import logging
25
+ from etils import epath
25
26
import pandas as pd
26
27
from tensorflow_datasets .datasets .smart_buildings import constants
27
28
from tensorflow_datasets .datasets .smart_buildings import reader_lib
@@ -45,8 +46,8 @@ class ProtoReader(reader_lib.BaseReader):
45
46
input_dir: directory path where the files are located
46
47
"""
47
48
48
- def __init__ (self , input_dir ):
49
- self ._input_dir = input_dir
49
+ def __init__ (self , input_dir : epath . PathLike ):
50
+ self ._input_dir = epath . Path ( input_dir )
50
51
logging .info ('Reader lib input directory %s' , self ._input_dir )
51
52
52
53
def read_observation_responses (
@@ -97,7 +98,7 @@ def read_reward_responses( # pytype: disable=signature-mismatch # overriding-r
97
98
98
99
def read_zone_infos (self ) -> Sequence [smart_control_building_pb2 .ZoneInfo ]:
99
100
"""Reads the zone infos for the Building from .pbtxt."""
100
- filename = os . path . join ( self ._input_dir , constants .ZONE_INFO_PREFIX )
101
+ filename = self ._input_dir / constants .ZONE_INFO_PREFIX
101
102
return self ._read_streamed_protos (
102
103
filename , smart_control_building_pb2 .ZoneInfo .FromString
103
104
)
@@ -107,7 +108,7 @@ def read_device_infos(
107
108
) -> Sequence [smart_control_building_pb2 .DeviceInfo ]:
108
109
"""Reads the device infos for the Building."""
109
110
110
- filename = os . path . join ( self ._input_dir , constants .DEVICE_INFO_PREFIX )
111
+ filename = self ._input_dir / constants .DEVICE_INFO_PREFIX
111
112
return self ._read_streamed_protos (
112
113
filename , smart_control_building_pb2 .DeviceInfo .FromString
113
114
)
@@ -141,28 +142,26 @@ def _read_messages(
141
142
messages .extend (file_messages )
142
143
return messages
143
144
144
- def _read_shards (self , input_dir : str , file_prefix : str ) -> Sequence [str ]:
145
+ def _read_shards (
146
+ self , input_dir : epath .Path , file_prefix : str
147
+ ) -> Sequence [epath .Path ]:
145
148
"""Returns full paths in input_dir of files starting with file_prefix."""
146
-
147
- shards = [
148
- os .path .join (input_dir , f )
149
- for f in os .listdir (input_dir )
150
- if f .startswith (file_prefix )
151
- ]
152
- return shards
149
+ return list (epath .Path (input_dir ).glob (f'{ file_prefix } *' ))
153
150
154
151
def _select_shards (
155
152
self ,
156
153
start_time : pd .Timestamp ,
157
154
end_time : pd .Timestamp ,
158
- shards : Sequence [str ],
159
- ) -> Sequence [str ]:
155
+ shards : Sequence [epath . Path ],
156
+ ) -> Sequence [epath . Path ]:
160
157
"""Returns the shards that fall inside the start and end times."""
161
158
162
- def _read_timestamp (filepath : str ) -> pd .Timestamp :
159
+ def _read_timestamp (filepath : epath . Path ) -> pd .Timestamp :
163
160
"""Reads the timestamp from the filepath."""
164
161
assert filepath
165
- ts = pd .Timestamp (re .findall (r'\d{4}\.\d{2}\.\d{2}\.\d{2}' , filepath )[- 1 ])
162
+ ts = pd .Timestamp (
163
+ re .findall (r'\d{4}\.\d{2}\.\d{2}\.\d{2}' , os .fspath (filepath ))[- 1 ]
164
+ )
166
165
return ts
167
166
168
167
def _between (
@@ -179,13 +178,13 @@ def _between(
179
178
180
179
def _read_streamed_protos (
181
180
self ,
182
- full_path : str ,
181
+ full_path : epath . Path ,
183
182
from_string_func : Callable [[Union [bytearray , bytes , memoryview ]], T ],
184
183
) -> Sequence [T ]:
185
184
"""Reads a proto which has byte size preceding the message."""
186
185
187
186
messages = []
188
- with open (full_path , 'rb' ) as f :
187
+ with full_path . open ('rb' ) as f :
189
188
while True :
190
189
# Read size as a varint
191
190
size_bytes = f .read (4 )
@@ -260,7 +259,7 @@ def get_episode_data(working_dir: str) -> pd.DataFrame:
260
259
Returns:
261
260
A dataframe with episode label, timestamps, number of updates.
262
261
"""
263
- episode_dirs = os . listdir (working_dir )
262
+ episode_dirs = list ( epath . Path (working_dir ). iterdir () )
264
263
date_extractor = operator .itemgetter (slice (- 13 , None ))
265
264
266
265
execution_times = pd .to_datetime (
0 commit comments