From d3c68cb1f37252d204cee3f13b133429e3dc1d88 Mon Sep 17 00:00:00 2001 From: JonahSussman Date: Thu, 25 Jul 2024 15:32:13 -0400 Subject: [PATCH] Fixed error with category, added pip cache Signed-off-by: JonahSussman --- .github/workflows/test-code-on-pr.yml | 9 +++++++++ kai/service/incident_store/sql_types.py | 23 ++++++++++++++++++----- 2 files changed, 27 insertions(+), 5 deletions(-) diff --git a/.github/workflows/test-code-on-pr.yml b/.github/workflows/test-code-on-pr.yml index f75ce45a..acd31efb 100644 --- a/.github/workflows/test-code-on-pr.yml +++ b/.github/workflows/test-code-on-pr.yml @@ -29,10 +29,19 @@ jobs: with: python-version: ${{ matrix.python-version }} architecture: ${{ matrix.architecture }} + cache: pip + # After merging #225, use this as the cache instead.The cahce merely + # uses a hash of a specififed file to determine if it needs to update + # cache-dependency-path: requirements.in - name: Install dependencies run: | python -m pip install --upgrade pip setuptools wheel coverage pip install -r requirements.txt + - name: Download prerequisites + run: | + cd example + ./fetch.sh + cd .. - name: Test with unittest run: | python -m coverage run --branch -m unittest discover diff --git a/kai/service/incident_store/sql_types.py b/kai/service/incident_store/sql_types.py index e8ab67e3..c04d76cf 100644 --- a/kai/service/incident_store/sql_types.py +++ b/kai/service/incident_store/sql_types.py @@ -1,7 +1,7 @@ import datetime -from typing import Any, Optional +from enum import Enum +from typing import Any, Optional, Type -import sqlalchemy from sqlalchemy import ( VARCHAR, Column, @@ -40,6 +40,21 @@ def process_result_value(self, value: str, dialect: Dialect): return Solution.model_validate_json(value) +def SQLEnum(enum_type: Type[Enum]): + """ + The default behavior of the Enum type in SQLAlchemy is to store the enum's name, + but we want to store the enum's value. This class is a workaround for that. + """ + + return Enum( + value=enum_type.__name__, + names=[(e.value, e.value) for e in enum_type], + ) + + +SQLCategory: Type = SQLEnum(report_types.Category) + + class SQLBase(DeclarativeBase): type_annotation_map = { dict[str, Any]: JSON() @@ -100,9 +115,7 @@ class SQLViolation(SQLBase): ForeignKey("rulesets.ruleset_name"), primary_key=True ) - category = Column( - "category", sqlalchemy.Enum(report_types.Category).values_callable - ) + category: Mapped[SQLCategory] # type: ignore labels: Mapped[list[str]] ruleset: Mapped[SQLRuleset] = relationship(back_populates="violations")