Skip to content

Commit 40cc288

Browse files
tomvdw authored and The TensorFlow Datasets Authors committed
Use epath instead of os to read files
PiperOrigin-RevId: 696497054
1 parent c3c62b8 commit 40cc288

File tree

1 file changed

+18
-19
lines changed

1 file changed

+18
-19
lines changed

tensorflow_datasets/datasets/smart_buildings/controller_reader.py

Lines changed: 18 additions & 19 deletions
Original file line number | Diff line number | Diff line change
@@ -22,6 +22,7 @@
2222
from typing import Callable, Mapping, Sequence, TypeVar, Union
2323

2424
from absl import logging
25+
from etils import epath
2526
import pandas as pd
2627
from tensorflow_datasets.datasets.smart_buildings import constants
2728
from tensorflow_datasets.datasets.smart_buildings import reader_lib
@@ -45,8 +46,8 @@ class ProtoReader(reader_lib.BaseReader):
4546
input_dir: directory path where the files are located
4647
"""
4748

48-
def __init__(self, input_dir):
49-
self._input_dir = input_dir
49+
def __init__(self, input_dir: epath.PathLike):
50+
self._input_dir = epath.Path(input_dir)
5051
logging.info('Reader lib input directory %s', self._input_dir)
5152

5253
def read_observation_responses(
@@ -97,7 +98,7 @@ def read_reward_responses( # pytype: disable=signature-mismatch # overriding-r
9798

9899
def read_zone_infos(self) -> Sequence[smart_control_building_pb2.ZoneInfo]:
99100
"""Reads the zone infos for the Building from .pbtxt."""
100-
filename = os.path.join(self._input_dir, constants.ZONE_INFO_PREFIX)
101+
filename = self._input_dir / constants.ZONE_INFO_PREFIX
101102
return self._read_streamed_protos(
102103
filename, smart_control_building_pb2.ZoneInfo.FromString
103104
)
@@ -107,7 +108,7 @@ def read_device_infos(
107108
) -> Sequence[smart_control_building_pb2.DeviceInfo]:
108109
"""Reads the device infos for the Building."""
109110

110-
filename = os.path.join(self._input_dir, constants.DEVICE_INFO_PREFIX)
111+
filename = self._input_dir / constants.DEVICE_INFO_PREFIX
111112
return self._read_streamed_protos(
112113
filename, smart_control_building_pb2.DeviceInfo.FromString
113114
)
@@ -141,28 +142,26 @@ def _read_messages(
141142
messages.extend(file_messages)
142143
return messages
143144

144-
def _read_shards(self, input_dir: str, file_prefix: str) -> Sequence[str]:
145+
def _read_shards(
146+
self, input_dir: epath.Path, file_prefix: str
147+
) -> Sequence[epath.Path]:
145148
"""Returns full paths in input_dir of files starting with file_prefix."""
146-
147-
shards = [
148-
os.path.join(input_dir, f)
149-
for f in os.listdir(input_dir)
150-
if f.startswith(file_prefix)
151-
]
152-
return shards
149+
return list(epath.Path(input_dir).glob(f'{file_prefix}*'))
153150

154151
def _select_shards(
155152
self,
156153
start_time: pd.Timestamp,
157154
end_time: pd.Timestamp,
158-
shards: Sequence[str],
159-
) -> Sequence[str]:
155+
shards: Sequence[epath.Path],
156+
) -> Sequence[epath.Path]:
160157
"""Returns the shards that fall inside the start and end times."""
161158

162-
def _read_timestamp(filepath: str) -> pd.Timestamp:
159+
def _read_timestamp(filepath: epath.Path) -> pd.Timestamp:
163160
"""Reads the timestamp from the filepath."""
164161
assert filepath
165-
ts = pd.Timestamp(re.findall(r'\d{4}\.\d{2}\.\d{2}\.\d{2}', filepath)[-1])
162+
ts = pd.Timestamp(
163+
re.findall(r'\d{4}\.\d{2}\.\d{2}\.\d{2}', os.fspath(filepath))[-1]
164+
)
166165
return ts
167166

168167
def _between(
@@ -179,13 +178,13 @@ def _between(
179178

180179
def _read_streamed_protos(
181180
self,
182-
full_path: str,
181+
full_path: epath.Path,
183182
from_string_func: Callable[[Union[bytearray, bytes, memoryview]], T],
184183
) -> Sequence[T]:
185184
"""Reads a proto which has byte size preceding the message."""
186185

187186
messages = []
188-
with open(full_path, 'rb') as f:
187+
with full_path.open('rb') as f:
189188
while True:
190189
# Read size as a varint
191190
size_bytes = f.read(4)
@@ -260,7 +259,7 @@ def get_episode_data(working_dir: str) -> pd.DataFrame:
260259
Returns:
261260
A dataframe with episode label, timestamps, number of updates.
262261
"""
263-
episode_dirs = os.listdir(working_dir)
262+
episode_dirs = list(epath.Path(working_dir).iterdir())
264263
date_extractor = operator.itemgetter(slice(-13, None))
265264

266265
execution_times = pd.to_datetime(

0 commit comments

Comments
 (0)