Skip to content

Commit

Permalink
Fixed error with category, added pip cache
Browse files Browse the repository at this point in the history
Signed-off-by: JonahSussman <sussmanjonah@gmail.com>
  • Loading branch information
JonahSussman committed Jul 25, 2024
1 parent 59c0738 commit d76fd3c
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 5 deletions.
4 changes: 4 additions & 0 deletions .github/workflows/test-code-on-pr.yml
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,10 @@ jobs:
with:
python-version: ${{ matrix.python-version }}
architecture: ${{ matrix.architecture }}
cache: pip
# After merging #225, use this as the cache instead.The cahce merely
# uses a hash of a specififed file to determine if it needs to update
# cache-dependency-path: requirements.in
- name: Install dependencies
run: |
python -m pip install --upgrade pip setuptools wheel coverage
Expand Down
23 changes: 18 additions & 5 deletions kai/service/incident_store/sql_types.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import datetime
from typing import Any, Optional
from enum import Enum
from typing import Any, Optional, Type

import sqlalchemy
from sqlalchemy import (
VARCHAR,
Column,
Expand Down Expand Up @@ -40,6 +40,21 @@ def process_result_value(self, value: str, dialect: Dialect):
return Solution.model_validate_json(value)


def SQLEnum(enum_type: Type[Enum]):
"""
The default behavior of the Enum type in SQLAlchemy is to store the enum's name,
but we want to store the enum's value. This class is a workaround for that.
"""

return Enum(
value=enum_type.__name__,
names=[(e.value, e.value) for e in enum_type],
)


SQLCategory: Type = SQLEnum(report_types.Category)


class SQLBase(DeclarativeBase):
type_annotation_map = {
dict[str, Any]: JSON()
Expand Down Expand Up @@ -100,9 +115,7 @@ class SQLViolation(SQLBase):
ForeignKey("rulesets.ruleset_name"), primary_key=True
)

category = Column(
"category", sqlalchemy.Enum(report_types.Category).values_callable
)
category: Mapped[SQLCategory] # type: ignore
labels: Mapped[list[str]]

ruleset: Mapped[SQLRuleset] = relationship(back_populates="violations")
Expand Down

0 comments on commit d76fd3c

Please sign in to comment.