Skip to content

Commit

Permalink
reformatted code to better git epymorph standards
Browse files Browse the repository at this point in the history
  • Loading branch information
IzMo2000 committed Mar 20, 2024
1 parent 33d13bc commit dc90991
Show file tree
Hide file tree
Showing 3 changed files with 57 additions and 53 deletions.
8 changes: 2 additions & 6 deletions .vscode/settings.json
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,10 @@
"**/__pycache__": true,
".pytest_cache": true
},
"vim.leader": "<Space>",
"files.insertFinalNewline": true,
"files.trimFinalNewlines": true,

"editor.lineNumbers": "relative",
"editor.formatOnSave": true,
"editor.codeActionsOnSave": {
"source.organizeImports": "explicit"
Expand All @@ -20,15 +21,12 @@
"notebook.codeActionsOnSave": {
"source.organizeImports": "explicit"
},

"autopep8.importStrategy": "fromEnvironment",
"isort.importStrategy": "fromEnvironment",

"python.formatting.provider": "none",
"python.analysis.autoImportCompletions": true,
"python.analysis.typeCheckingMode": "basic",
"python.analysis.diagnosticMode": "workspace",

"python.testing.pytestEnabled": false,
"python.testing.unittestEnabled": true,
"python.testing.unittestArgs": [
Expand All @@ -38,14 +36,12 @@
"-p",
"*_test.py"
],

"editor.defaultFormatter": null,
"[python]": {
"editor.detectIndentation": false,
"editor.insertSpaces": true,
"editor.tabSize": 4,
"editor.defaultFormatter": "ms-python.autopep8",
},

"jupyter.notebookFileRoot": "${workspaceFolder}",
}
19 changes: 11 additions & 8 deletions epymorph/test/viz_test.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,13 @@
import unittest
import epymorph.viz
from epymorph.compartment_model import CompartmentModel
from epymorph import ipm_library
from os import path, remove, rmdir

from graphviz import Digraph
from sympy import Symbol
from os import path, remove, rmdir

import epymorph.viz
from epymorph import ipm_library
from epymorph.compartment_model import CompartmentModel


class VizTest(unittest.TestCase):
test_ipm = ipm_library['sirs']()
Expand All @@ -25,7 +28,8 @@ def test_build_model(self):
self.assertEqual(graph_attrs['edge_attr']['minlen'], '2.0')

for edge_index in range(len(graph_attrs['body'])):
self.assertTrue(graph_attrs['body'][edge_index].startswith(edge_prefixes[edge_index]))
self.assertTrue(graph_attrs['body'][edge_index].startswith(
edge_prefixes[edge_index]))

def test_edge_tracker(self):
test_tracker = epymorph.viz.EdgeTracker()
Expand All @@ -41,8 +45,8 @@ def test_edge_tracker(self):
self.assertEqual(d + d, test_tracker.get_edge_label('a', 'c'))

def test_pngtolabel(self):
self.assertEqual(epymorph.viz.png_to_label('test_path'),
'<<TABLE border="0"><TR><TD><IMG SRC="test_path"/></TD></TR></TABLE>>')
self.assertEqual(epymorph.viz.png_to_label('test_path'),
'<<TABLE border="0"><TR><TD><IMG SRC="test_path"/></TD></TR></TABLE>>')

def test_save(self):
self.assertFalse(epymorph.viz.save_model(self.test_graph, ''))
Expand All @@ -54,4 +58,3 @@ def test_save(self):
remove('model_pngs/test_model.png')

rmdir('model_pngs')

83 changes: 44 additions & 39 deletions epymorph/viz.py
Original file line number Diff line number Diff line change
@@ -1,21 +1,24 @@
from graphviz import Digraph
from sympy import Expr, Symbol, preview
from typing import List, Union
from IPython import display
from tempfile import NamedTemporaryFile
from os import path, makedirs
from io import BytesIO
from matplotlib.image import imread
from os import makedirs, path
from tempfile import NamedTemporaryFile

import matplotlib.pyplot as plt
from graphviz import Digraph, Graph
from IPython import display
from matplotlib.image import imread
from sympy import Expr, Symbol, preview


class ipmDraw():
"""class for functions related to drawing a graphviz ipm graph"""

@staticmethod
def jupyter(graph: Digraph):
"""draws the graph in a jupyter notebook"""

display.display_png(graph)

@staticmethod
def console(graph: Digraph):
"""draws graph to console"""

Expand All @@ -35,21 +38,22 @@ def console(graph: Digraph):
plt.axis('off')

# show the model png
plt.show()
plt.show()


class EdgeTracker():
""" class for keeping track of the edges added to the visualization """

def __init__(self):
self.edge_dict = {}
"""
dictionary for tracking edges, key = (head, tail)
dictionary for tracking edges, key = (head, tail)
value = edge label
"""

def track_edge(self, head: str, tail: str, label: Expr) -> None:
"""
given a head, tail, and label for an edge, tracks it and updates the
"""
given a head, tail, and label for an edge, tracks it and updates the
edge label (a sympy expr) if necessary
"""

Expand All @@ -65,8 +69,7 @@ def track_edge(self, head: str, tail: str, label: Expr) -> None:
# add edge w/ label to edge dict
self.edge_dict[(head, tail)] = label


def get_edge_label(self, head:str, tail:str) -> str:
def get_edge_label(self, head: str, tail: str) -> str:
""" given a head and tail for an edge, return its label """

# ensure label exists
Expand All @@ -76,60 +79,61 @@ def get_edge_label(self, head:str, tail:str) -> str:
return self.edge_dict[(head, tail)]

# not in list, return empty expression
return None
return ""


def build_ipm_graph(ipm) -> Digraph:
"""
primary function for creating a model visualization, given an ipm label
primary function for creating a model visualization, given an ipm label
that exists within the ipm library
"""
# init a tracker to be used for tacking edges and edge labels
tracker = EdgeTracker()

# fetch ipm event data
ipm_events = ipm.events

# init graph for model visualization to save to png, strict flag makes
# it so repeated edges are merged
model_viz = Digraph(format = 'png', strict=True,
graph_attr = {'rankdir': 'LR'},
node_attr = {'shape': 'square',
'width': '.9',
'height': '.8'},
edge_attr = {'minlen': '2.0'})
model_viz = Digraph(format='png', strict=True,
graph_attr={'rankdir': 'LR'},
node_attr={'shape': 'square',
'width': '.9',
'height': '.8'},
edge_attr={'minlen': '2.0'})

# render edges
for event in ipm_events:

# get the current head and tail of the edge
curr_head, curr_tail = str(event.compartment_from), \
str(event.compartment_to)
str(event.compartment_to)

# add edge to tracker, using the rate as the label
tracker.track_edge(curr_head, curr_tail, event.rate)

# get santized edge label from newly tracked edge
label_expr = tracker.get_edge_label(curr_head, curr_tail)

# create a temporary png file to render LaTeX edge label
with NamedTemporaryFile(suffix='.png',
delete=False) as temp_png:
with NamedTemporaryFile(suffix='.png',
delete=False) as temp_png:

# load label as LaTeX png into temp file
preview(label_expr, viewer='file', filename=temp_png.name,
euler=False)
preview(label_expr, viewer='file', filename=temp_png.name,
euler=False)

# render edge
model_viz.edge(curr_head, curr_tail,
label=png_to_label(temp_png.name))
model_viz.edge(curr_head, curr_tail,
label=png_to_label(temp_png.name))

# return created visualization graph
return model_viz


def render(ipm, save: bool = False, filename: str = "",
console: bool = False) \
-> None:
console: bool = False) \
-> None:
"""
main function for converting an ipm into a visual model
ipm: the model to be converted
Expand All @@ -141,17 +145,18 @@ def render(ipm, save: bool = False, filename: str = "",

if console:
ipmDraw.console(ipm_graph)

else:
ipmDraw.jupyter(ipm_graph)

if save:
save_model(ipm_graph, filename)


def save_model(ipm_graph: Digraph, filename: str) -> bool:
"""
function that saves a given graphviz ipm digraph to a png in the
'model_pngs' folder with the given file name. Creates the folder if it
'model_pngs' folder with the given file name. Creates the folder if it
does not exist
"""

Expand All @@ -165,16 +170,16 @@ def save_model(ipm_graph: Digraph, filename: str) -> bool:
makedirs('model_pngs')

# render and save png
ipm_graph.render(filename, directory = 'model_pngs', cleanup=True)
ipm_graph.render(filename, directory='model_pngs', cleanup=True)

return True


# file name is empty, print err message
print("ERR: no file name provided, could not save model")

return False


def png_to_label(png_filepath: str) -> str:
"""
helper function for displaying an image label using graphvz, requires the
Expand All @@ -183,5 +188,5 @@ def png_to_label(png_filepath: str) -> str:

return (
f'<<TABLE border="0"><TR><TD><IMG SRC="{png_filepath}"/>' +
'</TD></TR></TABLE>>'
'</TD></TR></TABLE>>'
)

0 comments on commit dc90991

Please sign in to comment.