Skip to content

Commit

Permalink
Accepts Simulation object in custom writer methods
Browse files Browse the repository at this point in the history
  • Loading branch information
craabreu committed Apr 10, 2024
1 parent b79d9e4 commit ff7d478
Show file tree
Hide file tree
Showing 4 changed files with 13 additions and 13 deletions.
12 changes: 5 additions & 7 deletions cvpack/reporting/custom_writer.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@

import typing as t

import openmm as mm
from openmm import app as mmapp


@t.runtime_checkable
Expand All @@ -18,15 +18,15 @@ class CustomWriter(t.Protocol):
An abstract class for StateDataReporter writers
"""

def initialize(self, context: mm.Context) -> None:
def initialize(self, simulation: mmapp.Simulation) -> None:
"""
Initializes the writer. This method is called before the first report and
can be used to perform any necessary setup.
Parameters
----------
context
The context object.
simulation
The simulation object.
"""

def getHeaders(self) -> t.List[str]:
Expand All @@ -35,15 +35,13 @@ def getHeaders(self) -> t.List[str]:
"""
raise NotImplementedError("Method 'getHeaders' not implemented")

def getValues(self, context: mm.Context) -> t.List[float]:
def getValues(self, simulation: mmapp.Simulation) -> t.List[float]:
"""
Gets a list of floats containing the values to be added to the report.
Parameters
----------
simulation
The simulation object.
state
The state object.
"""
raise NotImplementedError("Method 'getValues' not implemented")
3 changes: 2 additions & 1 deletion cvpack/reporting/cv_writer.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,8 @@ def getHeaders(self) -> t.List[str]:
)
return headers

def getValues(self, context: mm.Context) -> t.List[float]:
def getValues(self, simulation: mm.app.Simulation) -> t.List[float]:
context = simulation.context
values = []
if self._value:
values.append(self._cv.getValue(context) / self._cv.getUnit())
Expand Down
3 changes: 2 additions & 1 deletion cvpack/reporting/meta_cv_writer.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,8 @@ def add_header(name: str, unit: mmunit.Unit) -> None:
)
return headers

def getValues(self, context: mm.Context) -> t.List[float]:
def getValues(self, simulation: mm.app.Simulation) -> t.List[float]:
context = simulation.context
values = []
if self._values:
inner_values = self._meta_cv.getInnerValues(context)
Expand Down
8 changes: 4 additions & 4 deletions cvpack/reporting/state_data_reporter.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,15 +39,15 @@ def getHeaders(self) -> List[str]:
.. code-block::
def getValues(self, context: openmm.Context) -> List[float]:
def getValues(self, simulation: openmm.app.Simulation) -> List[float]:
pass
3. **initialize** (optional): performs any necessary setup before the first report.
If present, it must have the following signature:
.. code-block::
def initialize(self, context: openmm.Context) -> None:
def initialize(self, simulation: openmm.app.Simulation) -> None:
pass
Parameters
Expand Down Expand Up @@ -145,7 +145,7 @@ def _initializeConstants(self, simulation: mmapp.Simulation) -> None:
super()._initializeConstants(simulation)
for writer in self._writers:
if hasattr(writer, "initialize"):
writer.initialize(simulation.context)
writer.initialize(simulation)

def _constructHeaders(self) -> t.List[str]:
return self._expand(
Expand All @@ -158,5 +158,5 @@ def _constructReportValues(
) -> t.List[float]:
return self._expand(
super()._constructReportValues(simulation, state),
(w.getValues(simulation.context) for w in self._writers),
(w.getValues(simulation) for w in self._writers),
)

0 comments on commit ff7d478

Please sign in to comment.