-
Notifications
You must be signed in to change notification settings - Fork 7
/
Copy pathutils.py
269 lines (198 loc) · 7.24 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
# -*- coding: utf-8 -*-
import ast
from collections import OrderedDict
import logging
from pathlib import PurePath, Path
import re
import time
import inspect
import warnings
# The numpy module may disappear during interpreter shutdown
# so explicitly import ndarray
from numpy import ndarray
from daskms.testing import in_pytest
log = logging.getLogger(__name__)
class ChunkTransformer(ast.NodeTransformer):
    """Converts a parsed chunk-specification string into a plain dict.

    The parsed source must consist of a single dict literal, e.g.
    ``"{'row': 1000, 'chan': (16, 16)}"``. Keys and values may be
    constants, names or tuples thereof.
    """

    def visit_Module(self, node):
        # The module must wrap exactly one expression ...
        if len(node.body) != 1 or not isinstance(node.body[0], ast.Expr):
            raise ValueError("Module must contain a single expression")

        expr = node.body[0]

        # ... and that expression must be a dict literal
        if not isinstance(expr.value, ast.Dict):
            raise ValueError("Expression must contain a dictionary")

        return self.visit(expr).value

    def visit_Dict(self, node):
        keys = [self.visit(k) for k in node.keys]
        values = [self.visit(v) for v in node.values]
        return {k: v for k, v in zip(keys, values)}

    def visit_Name(self, node):
        # Bare names become their identifier string
        return node.id

    def visit_Tuple(self, node):
        return tuple(self.visit(v) for v in node.elts)

    def visit_Constant(self, node):
        # node.n is a deprecated alias of node.value
        # (deprecated since Python 3.8, scheduled for removal)
        return node.value
def parse_chunks_dict(chunks_str):
    """Parse a chunk-specification string (a dict literal) into a dict."""
    tree = ast.parse(chunks_str)
    return ChunkTransformer().visit(tree)
def natural_order(key):
    """Sort key ordering embedded digit runs numerically.

    Splits ``key`` on runs of digits so that, for example,
    "obs10" sorts after "obs2". Non-digit fragments compare
    case-insensitively.
    """
    fragments = re.split(r"(\d+)", str(key))
    return tuple(int(f) if f.isdigit() else f.lower() for f in fragments)
def arg_hasher(args):
    """Recursively hash data structures -- handles lists, dicts and ndarrays."""
    if isinstance(args, (tuple, list, set)):
        return hash(tuple(arg_hasher(v) for v in args))
    elif isinstance(args, dict):
        # Sort items so that insertion order does not affect the hash
        return hash(tuple((k, arg_hasher(v)) for k, v in sorted(args.items())))
    elif isinstance(args, ndarray):
        # NOTE(sjperkins)
        # https://stackoverflow.com/a/16592241/1611416
        # Slowish, but we shouldn't be passing
        # huge numpy arrays in the TableProxy constructor.
        # tobytes() replaces tostring(), which was removed in NumPy 2.0
        return hash(args.tobytes())
    else:
        return hash(args)
def freeze(arg):
    """Recursively generates a hashable object from arg.

    Sets become sorted tuples, lists/tuples become tuples, dicts become
    frozensets of frozen ``(key, value)`` pairs, and ndarrays are
    converted via ``tolist()`` (with a warning, since the element-wise
    copy is slow for anything but tiny arrays).
    """
    if isinstance(arg, set):
        return tuple(map(freeze, sorted(arg)))
    elif isinstance(arg, (tuple, list)):
        return tuple(map(freeze, arg))
    elif isinstance(arg, (dict, OrderedDict)):
        return frozenset((freeze(k), freeze(v)) for k, v in sorted(arg.items()))
    elif isinstance(arg, ndarray):
        if arg.nbytes > 10:
            # Fixed: the previous implicit f-string concatenation
            # produced a double space in this message
            warnings.warn(
                f"freezing ndarray of size {arg.nbytes} is probably inefficient"
            )
        return freeze(arg.tolist())
    else:
        return arg
def promote_columns(columns, default):
    """
    Promotes `columns` to a list of columns.

    - None returns `default`
    - single string returns a list containing that string
    - tuple of strings returns a list of string

    Parameters
    ----------
    columns : str or list of str or None
        Table columns
    default : list of str
        Default columns

    Returns
    -------
    list of str
        List of columns
    """
    if columns is None:
        if not isinstance(default, list):
            raise TypeError("'default' must be a list")

        return default

    if isinstance(columns, str):
        return [columns]

    if isinstance(columns, (tuple, list)):
        if not all(isinstance(c, str) for c in columns):
            raise TypeError("columns must be a list of strings")

        return list(columns)

    raise TypeError("'columns' must be a string or a list of strings")
def table_path_split(path):
    """Splits a table path into a (root, table, subtable) tuple"""
    if not isinstance(path, PurePath):
        path = Path(path)

    # "TABLE::SUBTABLE" denotes a subtable of TABLE;
    # partition yields an empty subtable when "::" is absent
    table_name, _, subtable = path.name.partition("::")
    return path.parent, table_name, subtable
def group_cols_str(group_cols):
    """Render grouping columns for use in a repr/log message."""
    return "group_cols={}".format(group_cols)
def index_cols_str(index_cols):
    """Render indexing columns for use in a repr/log message."""
    return "index_cols={}".format(index_cols)
def select_cols_str(select_cols):
    """Render selected columns for use in a repr/log message."""
    return "select_cols={}".format(select_cols)
def assert_liveness(table_proxies, executors, collect=True):
    """
    Asserts that the given number of TableProxy
    and Executor objects are alive.

    Parameters
    ----------
    table_proxies : int or None
        Expected number of live TableProxy objects,
        or None to skip this check
    executors : int or None
        Expected number of live Executor objects,
        or None to skip this check
    collect : bool
        If True, run a garbage collection pass before counting

    Raises
    ------
    ValueError
        If a cache size differs from the expected count. The message
        lists the objects still referring to each cache entry, to aid
        in tracking down the leaked reference.
    """
    from daskms.table_proxy import _table_cache
    from daskms.table_executor import _executor_cache
    import gc

    if collect:
        gc.collect()

    def _check_cache(cache, expected, cache_name):
        # Shared check for both caches -- raises with a referrer report
        if expected is None or len(cache) == expected:
            return

        lines = ["len(%s)[%d] != %d" % (cache_name, len(cache), expected)]

        for i, v in enumerate(cache.values()):
            lines.append("%d: %s is referred to by " "the following objects" % (i, v))

            for r in gc.get_referrers(v):
                lines.append(f"\t{str(r)}")

        raise ValueError("\n".join(lines))

    _check_cache(_table_cache, table_proxies, "_table_cache")
    _check_cache(_executor_cache, executors, "_executor_cache")
def log_call(fn):
    """Decorator logging entry, exit and exceptions of ``fn``.

    Exceptions are logged with traceback and re-raised.
    """

    def _wrapper(*args, **kwargs):
        # time.clock() was removed in Python 3.8;
        # perf_counter() is the documented replacement
        log.info("%s() start at %s", fn.__name__, time.perf_counter())

        try:
            return fn(*args, **kwargs)
        except Exception:
            log.exception("%s() exception", fn.__name__)
            raise
        finally:
            log.info("%s() done at %s", fn.__name__, time.perf_counter())

    return _wrapper
def requires(*args):
    """Decorator factory disabling a function when dependencies are missing.

    Positional arguments may be ImportError instances (missing optional
    dependencies) and/or strings (extra message lines); anything else is
    ignored. If any ImportError is present, the decorated function is
    replaced by a stub that skips under pytest or raises ImportError
    otherwise. With no ImportErrors, functions pass through untouched.
    """
    import_errors = [a for a in args if isinstance(a, ImportError)]
    msgs = [a for a in args if isinstance(a, str)]

    if not import_errors:
        # All dependencies present -- return the original function as is
        def decorator(fn):
            return fn

        return decorator

    # Required dependencies are missing
    def decorator(fn):
        lines = [
            f"Optional extras required by "
            f"{fn.__name__} are missing due to "
            f"the following ImportErrors:"
        ]

        lines.extend(f"{i}. {str(e)}" for i, e in enumerate(import_errors, 1))

        if msgs:
            lines.append("")
            lines.extend(msgs)

        msg = "\n".join(lines)

        def wrapper(*args, **kwargs):
            if in_pytest():
                import pytest

                pytest.skip(msg)
            else:
                raise ImportError(msg) from import_errors[0]

        return wrapper

    return decorator
def filter_kwargs(func, kwargs):
    """Filters unhandled kwargs and raises appropriate warnings.

    Mutates ``kwargs`` in place, removing every key that is not a
    named argument of ``func``, and warns about what was dropped.
    """
    accepted = inspect.getfullargspec(func).args
    dropped = [key for key in kwargs if key not in accepted]

    for key in dropped:
        del kwargs[key]

    if dropped:
        warnings.warn(
            f"The following kwargs will be ignored in {func.__name__}: "
            f"{dropped}.",
            UserWarning,
        )