Skip to content

Commit

Permalink
load func -> custom target output folder
Browse files Browse the repository at this point in the history
Signed-off-by: Savitha Raghunathan <saveetha13@gmail.com>
  • Loading branch information
savitharaghunathan committed Feb 19, 2024
1 parent a0aa7e0 commit 355356e
Show file tree
Hide file tree
Showing 3 changed files with 41 additions and 37 deletions.
4 changes: 2 additions & 2 deletions kai/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,13 +80,13 @@ def generate(


@app.command()
def load(folder_path: str):
def load(folder_path: str, output_dir: str):
"""
Load the incident store with the given applications
write the cached_violations to a file for later use
"""
incident_store = IncidentStore()
incident_store.load_incident_store(folder_path)
incident_store.load_incident_store(folder_path, output_dir)


@app.command()
Expand Down
40 changes: 19 additions & 21 deletions kai/incident_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -234,30 +234,24 @@ def find_common_violations(self, ruleset_name, violation_name):
break
return common_entries

def cleanup(self):
def cleanup(self, output_directory):
"""
Cleanup the incident store
"""
# delete cached_violations.yaml if it exists
if os.path.exists(
"samples/generated_output/incident_store/cached_violations.yaml"
):
os.remove("samples/generated_output/incident_store/cached_violations.yaml")
if os.path.exists(f"{output_directory}/cached_violations.yaml"):
os.remove(f"{output_directory}/cached_violations.yaml")
# delete solved_incidents.yaml if it exists
if os.path.exists(
"samples/generated_output/incident_store/solved_incidents.yaml"
):
os.remove("samples/generated_output/incident_store/solved_incidents.yaml")
if os.path.exists(f"{output_directory}/solved_incidents.yaml"):
os.remove(f"{output_directory}/solved_incidents.yaml")
# delete missing_incidents.yaml if it exists
if os.path.exists(
"samples/generated_output/incident_store/missing_incidents.yaml"
):
os.remove("samples/generated_output/incident_store/missing_incidents.yaml")
if os.path.exists(f"{output_directory}/missing_incidents.yaml"):
os.remove(f"{output_directory}/missing_incidents.yaml")
# clear cached_violations if it is not None
if self.cached_violations is not None:
self.cached_violations = {}

def load_incident_store(self, folder_path):
def load_incident_store(self, folder_path, output_directory):
# check if the folder exists
if not os.path.exists(folder_path):
print(f"Error: {folder_path} does not exist.")
Expand All @@ -270,7 +264,7 @@ def load_incident_store(self, folder_path):
print(f"Loading incident store with applications: {apps}\n")
if len(apps) != 0:
# cleanup incident store
self.cleanup()
self.cleanup(output_directory)

for app in apps:
# if app is a directory then check if there is a folder called initial
Expand All @@ -294,10 +288,16 @@ def load_incident_store(self, folder_path):
print("finding missing incidents")
self.update_incident_store(app)

self.write_cached_violations(self.cached_violations, "cached_violations.yaml")
self.write_cached_violations(
self.cached_violations, "cached_violations.yaml", output_directory
)
# write missing incidents to the a new file
self.write_cached_violations(self.missing_violations, "missing_incidents.yaml")
self.write_cached_violations(self.solved_violations, "solved_incidents.yaml")
self.write_cached_violations(
self.missing_violations, "missing_incidents.yaml", output_directory
)
self.write_cached_violations(
self.solved_violations, "solved_incidents.yaml", output_directory
)

def load_app_cached_violation(self, app, folder):
"""
Expand Down Expand Up @@ -338,12 +338,10 @@ def fetch_output_yaml(self, app_name, folder="solved"):

return output_yaml_path

def write_cached_violations(self, cached_violations, file_name):
def write_cached_violations(self, cached_violations, file_name, output_directory):
"""
Write the cached_violations to a file for later use
"""

output_directory = "samples/generated_output/incident_store"
dir_path = os.path.dirname(os.path.realpath(__file__))
parent_dir = os.path.dirname(dir_path)
os.path.join(parent_dir, output_directory)
Expand Down
34 changes: 20 additions & 14 deletions tests/test_incident_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,14 +119,15 @@ def test_update_cached_violations_with_incidents(self):
def test_load_incidentstore_cached_violation(self):
# Test when the specified app folder and output.yaml exist
folder_path = "tests/test_data/sample"
output_dir = "tests/test_data/incident_store"
i = IncidentStore()

i.load_incident_store(folder_path)
i.load_incident_store(folder_path, output_dir)
self.assertIsNotNone(i.cached_violations)
self.assertIsInstance(i.cached_violations, dict)
self.assertEqual(len(i.cached_violations), 3)

i.cleanup()
i.cleanup(output_dir)

def test_write_cached_violations(self):
test_cached_violations = {
Expand Down Expand Up @@ -205,18 +206,17 @@ def test_write_cached_violations(self):
},
}

output_file_path = (
"samples/generated_output/incident_store/test_cached_violations.yaml"
)
output_file_path = "tests/test_data/incident_store/test_cached_violations.yaml"
i = IncidentStore()
try:
# Call the function under test
i.write_cached_violations(
test_cached_violations, "test_cached_violations.yaml"
test_cached_violations,
"test_cached_violations.yaml",
"tests/test_data/incident_store",
)

# Check if the file was created
print(output_file_path)
self.assertTrue(os.path.exists(output_file_path))

# Check if the written datza matches the expected data
Expand All @@ -231,33 +231,39 @@ def test_write_cached_violations(self):

def test_find_solved_issues(self):
i = IncidentStore()
i.load_incident_store("tests/test_data/sample")
i.load_incident_store(
"tests/test_data/sample", "tests/test_data/incident_store"
)
patches = i.get_solved_issue(
"quarkus/springboot", "javaee-pom-to-quarkus-00010"
)
self.assertIsNotNone(patches)
self.assertEquals(len(patches), 1)
i.cleanup()
i.cleanup("tests/test_data/incident_store")

def test_find_solved_issues_no_solved_issues(self):
i = IncidentStore()
i.load_incident_store("tests/test_data/sample")
i.load_incident_store(
"tests/test_data/sample", "tests/test_data/incident_store"
)
patches = i.get_solved_issue(
"quarkus/springboot", "javaee-pom-to-quarkus-01010"
)
self.assertListEqual(patches, [])
self.assertEquals(len(patches), 0)
i.cleanup()
i.cleanup("tests/test_data/incident_store")

def test_find_common_violations(self):
i = IncidentStore()
i.load_incident_store("tests/test_data/sample")
i.load_incident_store(
"tests/test_data/sample", "tests/test_data/incident_store"
)
violations = i.find_common_violations(
"quarkus/springboot", "javaee-pom-to-quarkus-00010"
)
self.assertIsNotNone(violations)
self.assertEquals(len(violations), 1)
i.cleanup()
self.assertEqual(len(violations), 1)
i.cleanup("tests/test_data/incident_store")

if __name__ == "__main__":
unittest.main()

0 comments on commit 355356e

Please sign in to comment.