Skip to content

Commit f7ca6fe

Browse files
committed
Refactor and update source code.
- Refactor visualizations module
- Add search_grid ML utility function
- Clean up docstrings
1 parent 6a0e48d commit f7ca6fe

File tree

5 files changed

+250
-220
lines changed

5 files changed

+250
-220
lines changed

src/__init__.py

+1
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
"""Source code used in NYC Collisions Analysis."""

src/constants.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
""" Project constants, parameters, and settings."""
1+
"""Project constants, parameters, and settings."""
22

33
COORD_REF_SYSTEM = "EPSG:4326" # default geojson CRS
44

src/strings.py

+7-6
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,17 @@
1-
""" String formatting functions."""
1+
"""String formatting functions."""
22

33
from typing import Iterable
44
import numpy as np
55

66

77
def add_line_breaks(str_arr: Iterable) -> list:
    """Return list of strings with optimal line break inserted.

    The longest line that results from breaking any of the strings is used as
    the common break index, so every string is broken at or before that index.

    Args:
        str_arr: Strings to process. Any iterable is accepted, including
            one-shot iterators and generators.

    Returns:
        A new list with one processed string per input string; an empty list
        when ``str_arr`` yields nothing.
    """
    # Materialize once: the input is traversed twice below, which would
    # silently exhaust a generator after the first pass.
    strings = list(str_arr)
    if not strings:
        return []  # max() over an empty sequence would raise ValueError
    max_line_length = max(length_after_line_break(s) for s in strings)
    return [insert_line_break(s, max_line_length) for s in strings]
1111

1212

1313
def insert_line_break(s: str, idx: int):
14-
"""Returns new string with line break inserted at or before index, idx"""
14+
"""Return new string with line break inserted at or before index (idx)."""
1515
spaces = [x for x in get_space_indices(s) if x <= idx]
1616
if not spaces:
1717
return s
@@ -22,8 +22,9 @@ def insert_line_break(s: str, idx: int):
2222

2323

2424
def length_after_line_break(s: str):
25-
"""Returns length of longest line after replacing space with line break
26-
near string mid-point."""
25+
"""Return length of longest line after replacing space with line break
26+
near string mid-point.
27+
"""
2728
spaces = get_space_indices(s)
2829
if not spaces:
2930
return len(s)
@@ -33,7 +34,7 @@ def length_after_line_break(s: str):
3334

3435

3536
def get_space_indices(s: str):
    """Return the positions of every space character in a string.

    Only the plain space character (" ") is matched. When the argument is not
    a ``str`` the function returns None instead of a list.
    """
    if isinstance(s, str):
        return [position for position, char in enumerate(s) if char == " "]
    return None

src/utils.py

+59-30
Original file line numberDiff line numberDiff line change
@@ -9,46 +9,32 @@
99
import shapely
1010
from shapely.geometry import shape
1111
from shapely.strtree import STRtree
12+
from sklearn import model_selection
1213

1314

14-
def make_week_crosstab(df, divisor, values=None, aggfunc=None, day_of_week_map=None):
15-
"""Return an hour / day-of-week crosstab scaled by a divisor."""
16-
ct = pd.crosstab(
17-
index=df["datetime"].dt.dayofweek,
18-
columns=df["datetime"].dt.hour,
19-
values=values,
20-
aggfunc=aggfunc,
21-
)
22-
if day_of_week_map:
23-
ct.rename(index=day_of_week_map, inplace=True)
24-
ct /= divisor # scale crosstab by divisor
25-
return ct
26-
27-
28-
def get_crosstab_min_max(
29-
df, col, categories, divisor=None, values_col=None, aggfunc=None
15+
def min_max_across_crosstabs(
    categories, cat_series, idx_series, col_series, value_series=None, aggfunc=None
):
    """Return the min and max values of crosstabs across all categories.

    One crosstab is built per category from the rows of ``idx_series`` and
    ``col_series`` where ``cat_series`` equals that category. Used to ensure
    that different heatmaps have the same scale.

    Args:
        categories: Iterable of category values to evaluate.
        cat_series (pd.Series): Category label of each row.
        idx_series (pd.Series): Values used as the crosstab index.
        col_series (pd.Series): Values used as the crosstab columns.
        value_series (pd.Series, optional): Values to aggregate; must be given
            together with ``aggfunc``.
        aggfunc (optional): Aggregation passed to ``pd.crosstab``; must be
            given together with ``value_series``.

    Returns:
        tuple: ``(min_val, max_val)`` over every cell of every crosstab. If no
        category matches any row the untouched sentinels
        ``(float("inf"), float("-inf"))`` are returned.

    Raises:
        TypeError: If only one of ``value_series`` and ``aggfunc`` is supplied.
    """
    if value_series is not None and aggfunc is None:
        raise TypeError("'value_series' requires 'aggfunc' to be specified.")
    if aggfunc and value_series is None:
        # Fail up front with a clear message instead of an opaque
        # "'NoneType' object is not subscriptable" inside the loop.
        raise TypeError("'aggfunc' requires 'value_series' to be specified.")

    min_val = float("inf")
    max_val = float("-inf")
    for cat in categories:
        is_true = cat_series.isin([cat])
        if not is_true.any():
            # No rows for this category: the crosstab would be empty and has
            # no min / max to contribute.
            continue
        values = value_series[is_true] if aggfunc else None
        ct = pd.crosstab(
            index=idx_series[is_true],
            columns=col_series[is_true],
            values=values,
            aggfunc=aggfunc,
        )

        # ct.min() / ct.max() return a pd.Series (one entry per column); reduce
        # it with the pandas method so missing cells (NaN) are skipped rather
        # than making the comparison order-dependent.
        min_val = min(min_val, ct.min().min())
        max_val = max(max_val, ct.max().max())
    return min_val, max_val
5339

5440

@@ -65,7 +51,7 @@ def make_heatmap_labels(
6551
return ct_labels
6652

6753

68-
def date_to_season(dt: datetime.datetime):
54+
def date_to_season(dt: datetime.datetime | pd.Timestamp):
6955
"""Convert individual datetime or pd.Timestamp to season of year."""
7056
# day of year corresponding to following dates:
7157
# 1-Jan, 21-Mar, 21-Jun, 21-Sep, 21-Dec, 31-Dec
@@ -83,8 +69,7 @@ def date_to_season(dt: datetime.datetime):
8369

8470

8571
def read_geojson(shape_file_loc: str, property_name: str):
86-
"""
87-
Return list of geometry ids and list of geometries from geojson.
72+
"""Return list of geometry ids and list of geometries from geojson.
8873
8974
Assumes geojson conforms to 2016 geojson convention.
9075
"""
@@ -96,8 +81,7 @@ def read_geojson(shape_file_loc: str, property_name: str):
9681

9782

9883
def id_nearest_shape(geometry: shapely.Point, r_tree: shapely.STRtree, shape_ids: list):
99-
"""
100-
Return the id (from list of shape_ids) of the nearest shape to input geometry.
84+
"""Return the id (from list of shape_ids) of the nearest shape to input geometry.
10185
10286
Uses a Shapely STRtree (R-tree) to perform a faster lookup.
10387
"""
@@ -125,3 +109,48 @@ def add_location_feature(
125109
lambda x: id_nearest_shape(x.geometry, tree, geom_ids), axis=1
126110
)
127111
return gdf
112+
113+
114+
def search_grid(x, y, model, params, score, num_cv=5, low_score_best=True):
    """Perform grid search cross validation then print and return results.

    Every parameter combination in ``params`` is cross-validated with
    ``sklearn.model_selection.cross_validate``. One summary row is printed per
    combination, followed by the top-ranked score and its parameters.

    Args:
        x (pd.DataFrame, pd.Series, or np.ndarray): Model features.
        y (pd.Series, or np.ndarray): Target.
        model (callable): Estimator class (or factory) called as
            ``model(**param)`` to build a new estimator for each combination.
        params (dict): Key-value parameters to use in grid search. Key is model
            input name.
        score (str, callable, list, tuple or dict): Strategy to evaluate the
            performance of the cross-validated model on the test set. Results
            are read from the ``"test_score"`` key, which scikit-learn
            documents for single-metric scoring only.
        num_cv (int, cv generator or iterable): CV splitting strategy.
        low_score_best (bool): Sort order of the results. True ranks the
            highest mean score first; False ranks the lowest mean score first.

    Returns:
        list(tup): List of grid search cross-validation results, sorted so the
            top-ranked run is first, as tuples containing:
            1) mean test score
            2) run time in minutes
            3) parameters used

    """
    param_grid = model_selection.ParameterGrid(params)
    results = []
    print("Mean Score", "\tRun Time(min)", "\tParameters")
    for param in param_grid:
        parameterized_model = model(**param)
        cv_run = model_selection.cross_validate(
            parameterized_model, x, y, scoring=score, cv=num_cv
        )

        # Average over the folds actually run. Dividing by num_cv is wrong
        # whenever num_cv is a splitter object or an iterable of splits.
        test_scores = cv_run["test_score"]
        mean_score = sum(test_scores) / len(test_scores)
        minutes = (sum(cv_run["fit_time"]) + sum(cv_run["score_time"])) / 60
        results.append((mean_score, minutes, param))
        result_string = f"{mean_score:.4f}\t\t{minutes:.3f}\t\t{param}"
        print(result_string)

    if not results:
        # Empty grid: nothing to rank, and results[0] below would fail.
        return results

    # NOTE(review): low_score_best=True sorts in DESCENDING order, so the
    # highest mean score is reported as best. That suits scikit-learn's negated
    # error scorers (e.g. "neg_mean_squared_error"), but for a raw score where
    # lower is better the flag behaves inversely to its name. Behavior is kept
    # as-is for existing callers -- confirm the intended meaning.
    results.sort(key=lambda z: z[0], reverse=low_score_best)
    best_score = f"\nBest score: {results[0][0]}\n"
    best_params = f"Best parameters: {results[0][2]}\n"
    print(best_score + best_params)

    return results

0 commit comments

Comments
 (0)