Skip to content

Commit 9bd5cfc

Browse files
glados-verma and copybara-github
authored and committed
Add from_dataframe method to Measurement to create multidim measurement from a DataFrame.
This also adds some symmetry with the existing to_dataframe method. PiperOrigin-RevId: 722726903
1 parent b8cfa7e commit 9bd5cfc

File tree

2 files changed

+131
-2
lines changed

2 files changed

+131
-2
lines changed

openhtf/core/measurements.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -480,6 +480,35 @@ def to_dataframe(self, columns: Any = None) -> Any:
480480

481481
return dataframe
482482

483+
def from_dataframe(self, dataframe: Any, metric_column: str) -> None:
    """Populate this multi-dim measurement from a pandas DataFrame.

    Each row of the DataFrame is written to the measured value, keyed by the
    row's values for this measurement's dimensions. Rows are written in
    DataFrame order. Columns other than the dimensions and `metric_column`
    are not read.

    Args:
      dataframe: A pandas DataFrame. Every dimension of this measurement must
        be present, either as a column or as an index level (multi-index is
        supported).
      metric_column: The column name of the metric to be measured.

    Raises:
      TypeError: If this measurement is not dimensioned.
      ValueError: If the DataFrame is missing a dimension or does not contain
        `metric_column`.
    """
    if not isinstance(self._measured_value, DimensionedMeasuredValue):
        raise TypeError(
            'Only a dimensioned measurement can be set from a DataFrame'
        )
    dimension_labels = [d.name for d in self.dimensions]
    # Flatten any existing index into columns so dimensions are found whether
    # the caller supplied them as plain columns or as index levels.
    flat_df = dataframe.reset_index()
    try:
        dimensioned_df = flat_df.set_index(dimension_labels)
    except KeyError as e:
        raise ValueError('DataFrame is missing dimensions') from e
    if metric_column not in dimensioned_df.columns:
        raise ValueError(
            f'DataFrame does not have a column named {metric_column}'
        )
    # Iterate the metric column alone instead of whole rows: iterrows() builds
    # a Series per row, which coerces every column to one common dtype and can
    # silently turn integer metrics into floats when an unrelated column is
    # float. It is also much slower on large frames.
    for row_dimensions, metric in dimensioned_df[metric_column].items():
        self.measured_value[row_dimensions] = metric
511+
483512

484513
@attr.s(slots=True)
485514
class MeasuredValue(object):

test/core/measurements_test.py

Lines changed: 102 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
from openhtf.core import measurements
2626
from examples import all_the_things
2727
from openhtf.util import test as htf_test
28+
import pandas
2829

2930
# Fields that are considered 'volatile' for record comparison.
3031
_VOLATILE_FIELDS = {
@@ -231,14 +232,19 @@ def test_to_dataframe__no_pandas(self):
231232
with self.assertRaises(RuntimeError):
232233
self.test_to_dataframe(units=True)
233234

234-
def test_to_dataframe(self, units=True):
235+
def _make_multidim_measurement(self, units=''):
    """Build the 'test_multidim' measurement shared by the DataFrame tests.

    The measurement is dimensioned by 'ms', 'assembly' and a 'my_zone'
    Dimension. Units are attached only when `units` is truthy.

    Args:
      units: Units passed to `with_units`; an empty value leaves units unset.

    Returns:
      The configured measurement.
    """
    multidim = htf.Measurement('test_multidim')
    multidim.with_dimensions('ms', 'assembly', htf.Dimension('my_zone'))
    if not units:
        return multidim
    multidim.with_units(units)
    return multidim
237241

242+
def test_to_dataframe(self, units=True):
238243
if units:
239-
measurement.with_units('°C')
244+
measurement = self._make_multidim_measurement('°C')
240245
measure_column_name = 'degree Celsius'
241246
else:
247+
measurement = self._make_multidim_measurement()
242248
measure_column_name = 'value'
243249

244250
for t in range(5):
@@ -260,6 +266,100 @@ def test_to_dataframe(self, units=True):
260266
def test_to_dataframe__no_units(self):
    """Re-run the to_dataframe test with units disabled."""
    with_units = False
    self.test_to_dataframe(units=with_units)
262268

269+
def _multidim_testdata(self):
270+
return {
271+
'ms': [1, 2, 3],
272+
'assembly': ['A', 'B', 'C'],
273+
'my_zone': ['X', 'Y', 'Z'],
274+
'degree_celsius': [10, 20, 30],
275+
}
276+
277+
def test_from_dataframe_raises_if_dimensions_missing_in_dataframe(self):
    """A DataFrame without the 'assembly' dimension raises ValueError."""
    multidim = self._make_multidim_measurement('°C')
    incomplete_data = self._multidim_testdata()
    incomplete_data.pop('assembly')
    with self.assertRaisesRegex(
        ValueError, 'DataFrame is missing dimensions'
    ) as raised:
        multidim.from_dataframe(
            pandas.DataFrame(incomplete_data),
            metric_column='degree_celsius',
        )
    # Re-raise the chained cause so its type and message are verified too.
    underlying = raised.exception.__cause__
    with self.assertRaisesRegex(
        KeyError, r"None of \['assembly'\] are in the columns"
    ):
        raise underlying
292+
293+
def test_from_dataframe_raises_if_metric_missing_in_dataframe(self):
    """A DataFrame without the requested metric column raises ValueError."""
    multidim = self._make_multidim_measurement('°C')
    dimensions_only = self._multidim_testdata()
    dimensions_only.pop('degree_celsius')
    expected_message = 'DataFrame does not have a column named degree_celsius'
    with self.assertRaisesRegex(ValueError, expected_message):
        multidim.from_dataframe(
            pandas.DataFrame(dimensions_only),
            metric_column='degree_celsius',
        )
304+
305+
def _assert_multidim_measurement_matches_testdata(self, measurement):
    """Assert that `measurement` holds exactly the shared test data.

    Checks each expected coordinate individually, then compares the
    measurement's DataFrame export (with columns renamed to the test-data
    names) against a DataFrame built from the test data.
    """
    expected_by_coordinates = (
        ((1, 'A', 'X'), 10),
        ((2, 'B', 'Y'), 20),
        ((3, 'C', 'Z'), 30),
    )
    for coordinates, expected in expected_by_coordinates:
        self.assertEqual(measurement.measured_value[coordinates], expected)

    column_renames = {
        'ms': 'ms',
        'assembly': 'assembly',
        'my_zone': 'my_zone',
        # The metric column name comes from the unit.
        'degree Celsius': 'degree_celsius',
    }
    exported = measurement.to_dataframe().rename(columns=column_renames)
    expected_frame = pandas.DataFrame(self._multidim_testdata())
    pandas.testing.assert_frame_equal(exported, expected_frame)
321+
322+
def test_from_flat_dataframe(self):
    """Dimensions supplied as ordinary columns are loaded correctly."""
    multidim = self._make_multidim_measurement('°C')
    flat_frame = pandas.DataFrame(self._multidim_testdata())
    multidim.from_dataframe(flat_frame, metric_column='degree_celsius')
    # Set an outcome before the measurement's contents are checked.
    multidim.outcome = measurements.Outcome.PASS
    self._assert_multidim_measurement_matches_testdata(multidim)
328+
329+
def test_from_dataframe_with_multiindex_dataframe(self):
    """Dimensions supplied as multi-index levels are loaded correctly."""
    multidim = self._make_multidim_measurement('°C')
    flat_frame = pandas.DataFrame(self._multidim_testdata())
    index_levels = ['ms', 'assembly', 'my_zone']
    indexed_frame = flat_frame.set_index(index_levels)
    multidim.from_dataframe(indexed_frame, metric_column='degree_celsius')
    multidim.outcome = measurements.Outcome.PASS
    self._assert_multidim_measurement_matches_testdata(multidim)
337+
338+
def test_from_dataframe_ignores_extra_columns(self):
    """Columns that are neither a dimension nor the metric are ignored."""
    multidim = self._make_multidim_measurement('°C')
    data_with_extra = {
        **self._multidim_testdata(),
        'degrees_fahrenheit': [11, 21, 31],
    }
    multidim.from_dataframe(
        pandas.DataFrame(data_with_extra), metric_column='degree_celsius'
    )
    multidim.outcome = measurements.Outcome.PASS
    self._assert_multidim_measurement_matches_testdata(multidim)
346+
347+
def test_from_dataframe_with_duplicate_dimensions_overwrites(self):
    """Verifies multi-dim measurement overwrite with duplicate dimensions."""
    multidim = self._make_multidim_measurement('°C')
    # The fourth row repeats the coordinates of the first with a new value.
    data_with_duplicate = {
        'ms': [1, 2, 3, 1],
        'assembly': ['A', 'B', 'C', 'A'],
        'my_zone': ['X', 'Y', 'Z', 'X'],
        'degree_celsius': [10, 20, 30, 11],
    }
    multidim.from_dataframe(
        pandas.DataFrame(data_with_duplicate), metric_column='degree_celsius'
    )
    multidim.outcome = measurements.Outcome.PASS
    expected_by_coordinates = (
        ((1, 'A', 'X'), 11),  # Overwritten value.
        ((2, 'B', 'Y'), 20),
        ((3, 'C', 'Z'), 30),
    )
    for coordinates, expected in expected_by_coordinates:
        self.assertEqual(multidim.measured_value[coordinates], expected)
362+
263363
def test_bad_validator(self):
264364
measurement = htf.Measurement('bad_measure')
265365
measurement.with_dimensions('a')

0 commit comments

Comments
 (0)