Skip to content

Commit 316cb58

Browse files
Merge pull request #32 from Kalgoc/fix/expenses
Fix/expenses
2 parents fef195d + 2cf6490 commit 316cb58

17 files changed

+158
-119
lines changed

.flake8

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
[flake8]
22
exclude = .git,__pycache__,old,build,dist,.venv, */migrations/*
3-
ignore = E226,E302,E41,F401,F403,F405
3+
ignore = E226,E302,E41,F401,F403,F405,W503
44
max-line-length = 120
55
max-complexity = 10

AI/AI.py

+23-6
Original file line numberDiff line numberDiff line change
@@ -3,17 +3,28 @@
33
import os
44
from textwrap import dedent
55
from dotenv import load_dotenv
6+
from categories.models import Category
67

78

8-
def setup_api_key():
9-
env_path = os.path.join(os.path.dirname(os.path.dirname(__file__)), ".env")
10-
load_dotenv(dotenv_path=env_path)
11-
openai.api_key = os.getenv("OPENAI_KEY")
9+
GPT_MODEL = "gpt-4-turbo"
10+
11+
12+
def get_category_name_from_description(description):
13+
classified_category = classify_text(description)
14+
cleaned_category = category_matched_with_id(classified_category)
15+
16+
if not Category.objects.filter(name=cleaned_category).exists():
17+
raise ValueError("Categoría no encontrada.")
18+
return cleaned_category
1219

1320

1421
def classify_text(texto):
22+
setup_api_key()
23+
if not openai.api_key:
24+
raise ValueError("OpenAI API key is not set. Please check your .env file.")
25+
1526
response = openai.ChatCompletion.create(
16-
model="gpt-4-turbo",
27+
model=GPT_MODEL,
1728
messages=[
1829
{
1930
"role": "system",
@@ -43,7 +54,13 @@ def classify_text(texto):
4354
return category
4455

4556

46-
def clean_category(category):
57+
def setup_api_key():
58+
env_path = os.path.join(os.path.dirname(os.path.dirname(__file__)), ".env")
59+
load_dotenv(dotenv_path=env_path)
60+
openai.api_key = os.getenv("OPENAI_KEY")
61+
62+
63+
def category_matched_with_id(category):
4764
if category.lower().startswith("categoría: "):
4865
category = category[11:]
4966
if category.endswith("."):

AI/__init__.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,2 @@
11
# flake8: noqa
2-
from .AI import setup_api_key, classify_text, clean_category
2+
from .AI import setup_api_key, classify_text, category_matched_with_id

AI/tests.py

+5-6
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,11 @@
11
# flake8: noqa
22
import unittest
33
from unittest.mock import patch
4-
from AI import classify_text, clean_category
4+
from AI import classify_text, category_matched_with_id
55
from textwrap import dedent
66

77

88
class TestGastosClassifier(unittest.TestCase):
9-
109
@patch("openai.ChatCompletion.create")
1110
def test_classify_text(self, mock_create):
1211
# Configurar el mock para devolver una respuesta simulada
@@ -50,13 +49,13 @@ def test_classify_text(self, mock_create):
5049

5150
def test_clean_category(self):
5251
# Pruebas con diferentes casos
53-
self.assertEqual(clean_category("Categoría: Comida."), "Comida")
54-
self.assertEqual(clean_category("categoría: transporte"), "Transporte")
55-
self.assertEqual(clean_category("vivienda."), "Vivienda")
52+
self.assertEqual(category_matched_with_id("Categoría: Comida."), "Comida")
53+
self.assertEqual(category_matched_with_id("categoría: transporte"), "Transporte")
54+
self.assertEqual(category_matched_with_id("vivienda."), "Vivienda")
5655

5756
# Prueba de una categoría no permitida
5857
with self.assertRaises(ValueError):
59-
clean_category("Categoría: Viajes.")
58+
category_matched_with_id("Categoría: Viajes.")
6059

6160

6261
if __name__ == "__main__":

authentication/services/cognito_service.py

+5
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
import hashlib
55
import base64
66
from django.contrib.auth import get_user_model
7+
from user_expense_type.models import UserExpenseType
78

89

910
class CognitoService:
@@ -38,6 +39,10 @@ def register_user(self, name, phone, email, password):
3839
User.objects.create_user(
3940
username=email, email=email, password=password, first_name=name, phone=phone, user_id=cognito_uuid
4041
)
42+
user_expense_type = UserExpenseType.objects.create(
43+
username=cognito_uuid, set_by_user=False, name="Personal"
44+
)
45+
user_expense_type.save()
4146
return response
4247
except ClientError as e:
4348
raise e

categories/signals.py

-2
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,3 @@
1-
# categories/signals.py
2-
31
from django.db.models.signals import post_migrate
42
from django.dispatch import receiver
53
from .models import Category

docker-compose.yml

+1
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@ services:
3232
POSTGRES_PASSWORD: ${DB_PASSWORD}
3333
POSTGRES_PORT: ${DB_PORT}
3434
POSTGRES_HOST: ${DB_HOST}
35+
OPENAI_API_KEY: ${OPENAI_API_KEY}
3536
DJANGO_SETTINGS_MODULE: piggywallet.settings.dev
3637

3738
networks:
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
# Generated by Django 5.0.6 on 2024-06-25 02:31
2+
3+
from django.db import migrations, models
4+
5+
6+
class Migration(migrations.Migration):
7+
dependencies = [
8+
("expenses", "0001_initial"),
9+
]
10+
11+
operations = [
12+
migrations.AddField(
13+
model_name="expense",
14+
name="description",
15+
field=models.CharField(blank=True, max_length=255, null=True),
16+
),
17+
]
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
# Generated by Django 5.0.6 on 2024-07-01 23:17
2+
3+
from django.db import migrations
4+
5+
6+
class Migration(migrations.Migration):
7+
dependencies = [
8+
("expenses", "0002_alter_expense_bankcard_id_alter_expense_username"),
9+
("expenses", "0002_expense_description"),
10+
]
11+
12+
operations = []

expenses/models.py

+1
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
class Expense(models.Model):
1010
id = models.AutoField(primary_key=True)
1111
username = models.ForeignKey(settings.AUTH_USER_MODEL, on_delete=models.CASCADE, to_field="user_id")
12+
description = models.CharField(max_length=255, blank=True, null=True)
1213
user_expense_type = models.ForeignKey(UserExpenseType, on_delete=models.CASCADE)
1314
category = models.ForeignKey(Category, on_delete=models.CASCADE)
1415
bankcard_id = models.ForeignKey(BankCard, on_delete=models.CASCADE)

expenses/serializers.py

+28-1
Original file line numberDiff line numberDiff line change
@@ -9,15 +9,42 @@
99
class ExpenseSerializer(serializers.Serializer):
1010
id = serializers.IntegerField(read_only=True)
1111
username = serializers.PrimaryKeyRelatedField(queryset=get_user_model().objects.all())
12-
user_expense_type = serializers.PrimaryKeyRelatedField(queryset=UserExpenseType.objects.all())
12+
user_expense_type = serializers.PrimaryKeyRelatedField(queryset=UserExpenseType.objects.all(), required=False)
1313
category = serializers.PrimaryKeyRelatedField(queryset=Category.objects.all())
1414
bankcard_id = serializers.PrimaryKeyRelatedField(queryset=BankCard.objects.all())
1515
amount = serializers.IntegerField()
16+
description = serializers.CharField(max_length=255, allow_blank=True, allow_null=True)
17+
18+
def validate(self, data):
19+
user_expense_type = data.get("user_expense_type")
20+
username = data.get("username")
21+
if (
22+
user_expense_type
23+
and not UserExpenseType.objects.filter(id=user_expense_type.id, username=username).exists()
24+
):
25+
raise serializers.ValidationError(
26+
{"user_expense_type": "This user_expense_type does not belong to the specified user."}
27+
)
28+
29+
return data
1630

1731
def create(self, validated_data):
32+
user_id = validated_data.get("username")
33+
if "user_expense_type" not in validated_data:
34+
default_expense_type = UserExpenseType.objects.filter(
35+
username=user_id, name="Personal", set_by_user=False
36+
).first()
37+
if not default_expense_type:
38+
raise serializers.ValidationError("Default personal expense type not found for the user.")
39+
validated_data["user_expense_type"] = default_expense_type
40+
1841
return Expense.objects.create(**validated_data)
1942

2043
def update(self, instance, validated_data):
2144
instance.amount = validated_data.get("amount", instance.amount)
45+
instance.description = validated_data.get("description", instance.description)
46+
instance.bankcard_id = validated_data.get("bankcard_id", instance.bankcard_id)
47+
instance.user_expense_type = validated_data.get("user_expense_type", instance.user_expense_type)
48+
instance.category = validated_data.get("category", instance.category)
2249
instance.save()
2350
return instance
+17
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
from rest_framework.response import Response
2+
from rest_framework import status
3+
from categories.models import Category
4+
from AI.AI import get_category_name_from_description
5+
6+
7+
def categorize_expense_description(data):
8+
if "description" in data:
9+
try:
10+
category_name = get_category_name_from_description(data["description"])
11+
category_obj = Category.objects.get(name=category_name)
12+
data["category"] = category_obj.id
13+
return data, None
14+
except ValueError as e:
15+
return None, Response({"error": str(e)}, status=status.HTTP_400_BAD_REQUEST)
16+
else:
17+
return None, Response({"error": "Description is required"}, status=status.HTTP_400_BAD_REQUEST)

expenses/views.py

+18-53
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,5 @@
11
from rest_framework import viewsets, status
22
from rest_framework.response import Response
3-
4-
# from .models import Expense
53
from .models import Expense
64
from user_expense_type.models import UserExpenseType
75
from categories.models import Category
@@ -10,6 +8,8 @@
108
import jwt
119
from django.db.models import Sum
1210
from django.utils.timezone import now
11+
from authentication.utils import get_user_id_from_token
12+
from expenses.services.categorize_expense import categorize_expense_description
1313

1414

1515
class ExpenseViewSet(viewsets.ViewSet):
@@ -35,11 +35,9 @@ def get_user_id_from_token(self, request):
3535
@cognito_authenticated
3636
def list(self, request):
3737
try:
38-
username = self.get_user_id_from_token(request)
38+
username = get_user_id_from_token(request)
3939
expenses = Expense.objects.filter(username=username)
4040
serializer = ExpenseSerializer(expenses, many=True)
41-
for ser in serializer.data:
42-
ser.pop("username")
4341
return Response(serializer.data)
4442
except Expense.DoesNotExist:
4543
return Response({"error": "Expenses not found"}, status=status.HTTP_404_NOT_FOUND)
@@ -49,16 +47,10 @@ def list(self, request):
4947
@cognito_authenticated
5048
def retrieve(self, request, pk=None):
5149
try:
52-
username = self.get_user_id_from_token(request)
50+
username = get_user_id_from_token(request)
5351
expense = Expense.objects.get(id=pk, username=username)
5452
serializer = ExpenseSerializer(expense)
55-
response = {
56-
"user_expense_type": serializer.data["user_expense_type"],
57-
"category": serializer.data["category"],
58-
"bankcard_id": serializer.data["bankcard_id"],
59-
"amount": serializer.data["amount"],
60-
}
61-
return Response(data=response)
53+
return Response(serializer.data)
6254
except Expense.DoesNotExist:
6355
return Response({"error": "Expense not found"}, status=status.HTTP_404_NOT_FOUND)
6456
except Exception as e:
@@ -67,28 +59,26 @@ def retrieve(self, request, pk=None):
6759
@cognito_authenticated
6860
def create(self, request):
6961
try:
70-
username = self.get_user_id_from_token(request)
62+
username = get_user_id_from_token(request)
7163
data = request.data.copy()
7264
data["username"] = username
7365

66+
data, error_response = categorize_expense_description(data)
67+
if error_response:
68+
return error_response
69+
7470
serializer = ExpenseSerializer(data=data)
7571
if serializer.is_valid():
7672
serializer.save()
77-
response = {
78-
"user_expense_type": serializer.data["user_expense_type"],
79-
"category": serializer.data["category"],
80-
"bankcard_id": serializer.data["bankcard_id"],
81-
"amount": serializer.data["amount"],
82-
}
83-
return Response(data=response, status=status.HTTP_201_CREATED)
73+
return Response(serializer.data, status=status.HTTP_201_CREATED)
8474
return Response(serializer.errors, status=status.HTTP_400_BAD_REQUEST)
8575
except Exception as e:
8676
return Response({"error": str(e)}, status=status.HTTP_500_INTERNAL_SERVER_ERROR)
8777

8878
@cognito_authenticated
8979
def destroy(self, request, pk=None):
9080
try:
91-
username = self.get_user_id_from_token(request)
81+
username = get_user_id_from_token(request)
9282
expense = Expense.objects.get(id=pk, username=username)
9383
expense.delete()
9484
return Response(status=status.HTTP_204_NO_CONTENT)
@@ -100,49 +90,24 @@ def destroy(self, request, pk=None):
10090
@cognito_authenticated
10191
def partial_update(self, request, pk=None):
10292
try:
103-
username = self.get_user_id_from_token(request)
93+
username = get_user_id_from_token(request)
10494
expense = Expense.objects.get(id=pk, username=username)
10595
serializer = ExpenseSerializer(expense, data=request.data, partial=True)
106-
if serializer.is_valid():
107-
serializer.save()
108-
response = {
109-
"user_expense_type": serializer.data["user_expense_type"],
110-
"category": serializer.data["category"],
111-
"bankcard_id": serializer.data["bankcard_id"],
112-
"amount": serializer.data["amount"],
113-
}
114-
return Response(data=response)
115-
return Response(serializer.errors, status=status.HTTP_400_BAD_REQUEST)
96+
serializer.is_valid(raise_exception=True)
97+
serializer.save()
98+
return Response(serializer.data, status=status.HTTP_200_OK)
99+
116100
except Expense.DoesNotExist:
117101
return Response({"error": "Expense not found"}, status=status.HTTP_404_NOT_FOUND)
118102
except Exception as e:
119103
return Response({"error": str(e)}, status=status.HTTP_500_INTERNAL_SERVER_ERROR)
120104

121105

122106
class ExpenseGroupedByTypeAndCategoryViewSet(viewsets.ViewSet):
123-
def get_user_id_from_token(self, request):
124-
try:
125-
authorization_header = request.headers.get("Authorization")
126-
if not authorization_header:
127-
raise Exception("Authorization header not found")
128-
129-
token = authorization_header.split()[1]
130-
decoded_token = jwt.decode(token, options={"verify_signature": False})
131-
username = decoded_token.get("username")
132-
if not username:
133-
raise Exception("User ID not found in token")
134-
return username
135-
except jwt.DecodeError:
136-
raise Exception("Invalid token")
137-
except jwt.ExpiredSignatureError:
138-
raise Exception("Expired token")
139-
except Exception as e:
140-
raise Exception(f"Error decoding token: {e}")
141-
142107
@cognito_authenticated
143108
def list(self, request):
144109
try:
145-
username = self.get_user_id_from_token(request)
110+
username = get_user_id_from_token(request)
146111
expenses_grouped = {}
147112
today = now()
148113
expenses = (
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
# Generated by Django 5.0.6 on 2024-06-25 04:06
2+
3+
from django.db import migrations, models
4+
5+
6+
class Migration(migrations.Migration):
7+
dependencies = [
8+
("user_expense_type", "0001_initial"),
9+
]
10+
11+
operations = [
12+
migrations.AlterField(
13+
model_name="userexpensetype",
14+
name="set_by_user",
15+
field=models.BooleanField(default=True),
16+
),
17+
]

user_expense_type/models.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,6 @@ class UserExpenseType(models.Model):
77
username = models.UUIDField()
88
name = models.CharField(max_length=70, default="Personal")
99
description = models.CharField(max_length=255, blank=True, null=True)
10-
set_by_user = models.BooleanField(default=False)
10+
set_by_user = models.BooleanField(default=True)
1111
created_at = models.DateTimeField(auto_now_add=True)
1212
updated_at = models.DateTimeField(auto_now=True)

0 commit comments

Comments
 (0)