Skip to content

Commit

Permalink
Separate Get and Post endpoint for Recommendation and Train
Browse files Browse the repository at this point in the history
  • Loading branch information
noeun-kim committed Jan 14, 2024
1 parent df97ae5 commit de241b2
Show file tree
Hide file tree
Showing 2 changed files with 75 additions and 32 deletions.
28 changes: 11 additions & 17 deletions new/recsys_serving/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ poetry install
```bash
# .env
# Default Baseline for DeepCoNN
MODEL_PATH=./data/src/model_versions/deepconn_model.pt
MODEL_PATH=./data/src/model_versions/model.pt
PYTHONPATH=
```
실행
Expand All @@ -27,27 +27,21 @@ poetry run python main.py

## Usage

### Predict
### Train & Recommendation

```bash
curl -X POST "http://0.0.0.0:8000/predict" -H "Content-Type: application/json" -d '{"features": [5.1, 3.5, 1.4, 0.2]}'

{"id":3,"result":0}
# train 값 설정을 통해 학습 유무를 결정할 수 있습니다
# Default model.pt 은 DeepCoNN 모델에 대한 파일입니다.
curl -X POST "http://0.0.0.0:8000/scoring" -H "Content-Type: application/json" -d '{"model": "DeepCoNN", "train": true, "vector_create": true}'
{"id":11676,"isbn": "0000000000", "rating": 0.0, "model":"FM"} # example
```

### Get all predictions

```bash
curl "http://0.0.0.0:8000/predict"

[{"id":1,"result":0},{"id":2,"result":0},{"id":3,"result":0}]
```

### Get a prediction

### Get a Predictions
```bash
curl "http://0.0.0.0:8000/predict/1"
{"id":1,"result":0}
# 전체 결과 검색
curl -X GET "http://0.0.0.0:8000/scoring"
# 특정 결과 검색
curl -X GET "http://0.0.0.0:8000/scoring/{user_id}"
```

## Build
Expand Down
79 changes: 64 additions & 15 deletions new/recsys_serving/api.py
Original file line number Diff line number Diff line change
@@ -1,48 +1,57 @@
import os
from pydantic import BaseModel
from fastapi import APIRouter, HTTPException
from sqlmodel import Session
from database import PredictionResult, engine
from pydantic import BaseModel, ValidationError
from fastapi import APIRouter, Body, HTTPException, status

from model import ModelOptions

router = APIRouter()

model_path = os.environ.get("MODEL_PATH")
data_path = os.environ.get("DATA_PATH")
model_path = os.path.join(os.path.curdir, "data/src/model_versions")
data_path = os.path.join(os.path.curdir, "data/src/")


class PredictionResponse(BaseModel):
user_id: int
isbn: str
rating: float
model: str


# 여러 모델에 대한 옵션, 학습 옵션을 추가해 볼 수 있습니다
# DeepCoNN 은 첫 실행 시 vector_create True 설정 필요
# 요청 예시: 0.0.0.0:8000/scoring/context?model_type=wdn
@router.get("/scoring/context")
def predict(model_type: str = "FM") -> PredictionResponse:
model_name = model_type.lower()
model_options = ModelOptions(model_name, data_path, model_path, accelerator="cpu")
@router.post("/scoring")
def train_model(
input_data: dict = Body(...),
) -> PredictionResponse:
model_name = input_data.get("model", "").lower()
train = input_data.get("train", True)
vector_create = input_data.get("vector_create", False)
model_builder = ModelOptions(model_name, data_path, model_path, accelerator="cpu")

if model_path is None or data_path is None:
raise ValueError("MODEL_PATH or DATA_PATH is not defined.")

try:
embeddings = model_options.get_embedding()
model_options.load_model(embeddings=embeddings)
model = model_options.get_model()
# DeepCoNN 의 경우 첫 실행 시 vector_create true 설정이 필요합니다
# 다른 모델에 대한 옵션 값 설정도 구현해 봅시다
embeddings = model_builder.get_embedding(vector_create)
model_builder.load_model(embeddings=embeddings)
model = model_builder.get_model()

trainer = model_options.get_trainer(
model=model,
)
trainer.train(embeddings)
trainer = model_builder.get_trainer(model=model)
trainer.train(embeddings) if train else None
scores = trainer.test(embeddings)

# sample
response = PredictionResponse(
user_id=11676,
isbn='0000000000',
rating=scores[0]
rating=scores[0],
model=model_name
)

except RuntimeError:
Expand All @@ -51,5 +60,45 @@ def predict(model_type: str = "FM") -> PredictionResponse:
raise HTTPException(status_code=400, detail="Input is not valid")
except Exception as e:
raise HTTPException(status_code=500, detail=f"Something went wrong: {e}")
# 결과를 데이터베이스에 저장

prediction_result = PredictionResult(result=response)
with Session(engine) as session:
session.add(prediction_result)
session.commit()
session.refresh(prediction_result)

return response


@router.get("/scoring/{user_id}")
def get_result(user_id: int) -> PredictionResponse:
# 데이터베이스에서 특정 결과 가져옴
with Session(engine) as session:
prediction_result = session.get(PredictionResult, user_id)
if not prediction_result:
raise HTTPException(
detail="Not found", status_code=status.HTTP_404_NOT_FOUND
)
return PredictionResponse(
user_id=prediction_result.user_id,
isbn=prediction_result.isbn,
rating=prediction_result.rating,
model=prediction_result.model
)


@router.get("/scoring")
def get_results() -> list[PredictionResponse]:
# 데이터베이스에서 결과 가져옴
with Session(engine) as session:
prediction_results = session.query(PredictionResult).all()
return [
PredictionResponse(
user_id=prediction_result.user_id,
isbn=prediction_result.isbn,
rating=prediction_result.rating,
model=prediction_result.model,
)
for prediction_result in prediction_results
]

0 comments on commit de241b2

Please sign in to comment.