diff --git a/internal/apprunner/app_runner.go b/internal/apprunner/app_runner.go index 691a162..16e9684 100644 --- a/internal/apprunner/app_runner.go +++ b/internal/apprunner/app_runner.go @@ -2,6 +2,7 @@ package apprunner import ( "context" + "fmt" "github.com/gofiber/fiber/v2" "github.com/gofiber/fiber/v2/log" "github.com/gofiber/fiber/v2/middleware/limiter" @@ -10,6 +11,7 @@ import ( "notes_service/internal/models" "notes_service/internal/ports" use_cases "notes_service/internal/usecases" + "notes_service/pkg/auth/password" "notes_service/pkg/storage/postgres" "notes_service/pkg/storage/redis" "os/user" @@ -47,12 +49,14 @@ func RunApp(mainConfig *config.Configs) error { notesUseCase := use_cases.NewNoteCRUDUseCase(notesRepo) notesHandler := http.NewNotesHandler(*notesUseCase) - usersRepo := ports.NewUsersRepoDB(db) + passwordManager := password.NewPasswordManagerBcrypt() + + usersRepo := ports.NewUsersRepoDB(db, passwordManager) usersUseCase := use_cases.NewUserCRUDUseCase(usersRepo) usersHandler := http.NewUsersHandler(*usersUseCase) - jwtHandler := http.NewJWTHandler(*usersUseCase, *mainConfig.JWT) + jwtHandler := http.NewJWTHandler(usersUseCase, *mainConfig.JWT) app := fiber.New() @@ -75,9 +79,12 @@ func RunApp(mainConfig *config.Configs) error { apiV1 := app.Group("/api/v1") - apiV1.Post("/users", jwtHandler.SignUpHandler) + apiV1.Post("/sign-up", jwtHandler.SignUpHandler) + apiV1.Post("/sign-in", jwtHandler.SignInHandler) + apiV1.Post("/refresh", jwtHandler.RefreshHandler) - protectedV1 := apiV1.Use("/protected", jwtHandler.JWTMiddleware()) + protectedV1 := apiV1.Group("/protected") + protectedV1.Use(jwtHandler.JWTMiddleware()) protectedV1.Get("/notes/by-user/:id", notesHandler.ListNotesHandler) protectedV1.Get("/notes/:id", notesHandler.GetNoteByIDHandler) @@ -86,11 +93,17 @@ func RunApp(mainConfig *config.Configs) error { protectedV1.Delete("/notes/:id", notesHandler.DeleteNoteHandler) protectedV1.Get("/notes/count-by-user/:id", notesHandler.CountNotesByUserHandler) + protectedV1.Get("/users/me", usersHandler.GetCurrentUserHandler) protectedV1.Get("/users/:id", usersHandler.GetUserByIDHandler) protectedV1.Get("/users/by-login/:login", usersHandler.GetUserByLoginHandler) protectedV1.Put("/users/:id", usersHandler.UpdateUserHandler) protectedV1.Delete("/users/:id", usersHandler.DeleteUserHandler) + routes := app.GetRoutes() + for _, route := range routes { + fmt.Printf("%s %s\n", route.Method, route.Path) + } + log.Info("Listening on port " + mainConfig.HTTP.Port) log.Info("Redis on " + mainConfig.Redis.URL) err = app.Listen(":" + mainConfig.HTTP.Port) diff --git a/internal/handler/http/auth.go b/internal/handler/http/auth.go index ffd75ea..efcb1ba 100644 --- a/internal/handler/http/auth.go +++ b/internal/handler/http/auth.go @@ -1,29 +1,37 @@ package http import ( + "errors" "github.com/gofiber/fiber/v2" "github.com/golang-jwt/jwt/v4" "github.com/google/uuid" "notes_service/config" "notes_service/internal/handler/schemas" "notes_service/internal/models" - "notes_service/internal/usecases" "notes_service/pkg/auth/jwtutils" "time" ) +// AuthUseCase Use case interface required for JWTHandler +type AuthUseCase interface { + GetUserByID(userID uuid.UUID) (models.User, bool, error) + CreateUser(user models.User) (models.User, error) + GetUserByLogin(login string) (models.User, bool, error) + GetUserByLoginAndPassword(login string, password string) (models.User, bool, error) +} + // JWTHandler is the auth header for fiber application. // It has some values for creating JWT's and an auth use case type JWTHandler struct { secretKey string accessTokenLifetime time.Duration refreshTokenLifetime time.Duration - useCase usecases.UserUseCase + useCase AuthUseCase } // NewJWTHandler creates and returns a new instance of NewJWTHandler. // It accepts the use case and a config.JWT to extract values from. -func NewJWTHandler(useCase usecases.UserUseCase, jwtConfig config.JWT) *JWTHandler { +func NewJWTHandler(useCase AuthUseCase, jwtConfig config.JWT) *JWTHandler { return &JWTHandler{ secretKey: jwtConfig.SecretKey, @@ -43,76 +51,159 @@ func (h *JWTHandler) JWTMiddleware() fiber.Handler { } tokenString := authHeader[len("Bearer "):] - token, err := jwtutils.ValidateToken(tokenString, h.secretKey) - if err != nil || !token.Valid { - return c.Status(fiber.StatusUnauthorized).JSON(fiber.Map{"error": "invalid token"}) + userID, _, err := h.runChecksForTokenString(tokenString, "access") + if err != nil { + return NotAuthenticatedError(c, err) } - claims, ok := token.Claims.(jwt.MapClaims) - if !ok { - return c.Status(fiber.StatusUnauthorized).JSON(fiber.Map{"error": "invalid token claims"}) - } + c.Locals("userID", userID) + return c.Next() + } +} - tokenType, ok := claims["type"].(string) - if !ok || tokenType != "access" { - return c.Status(fiber.StatusUnauthorized).JSON(fiber.Map{"error": "token type must be access"}) - } +func (h *JWTHandler) runChecksForTokenString(tokenString string, requiredTokenType string) (uuid.UUID, string, error) { + token, err := jwtutils.ValidateToken(tokenString, h.secretKey) + if err != nil || !token.Valid { + return uuid.UUID{}, "", errors.New("invalid token") + } - userIDString, ok := claims["sub"].(string) - userID, err := uuid.Parse(userIDString) - if !ok || err != nil { - return c.Status(fiber.StatusUnauthorized).JSON(fiber.Map{"error": "invalid user ID in token"}) - } + claims, ok := token.Claims.(jwt.MapClaims) + if !ok { + return uuid.UUID{}, "", errors.New("invalid token claims") + } - _, err = h.useCase.GetUserByID(userID) - if err != nil { - return c.Status(fiber.StatusUnauthorized).JSON(fiber.Map{"error": "can't find user with given user ID"}) - } + tokenType, ok := claims["type"].(string) + if !ok || tokenType != requiredTokenType { + return uuid.UUID{}, "", errors.New("token type must be " + requiredTokenType) + } - c.Locals("userID", userID) - return c.Next() + userIDString, ok := claims["sub"].(string) + userID, err := uuid.Parse(userIDString) + if !ok || err != nil { + return uuid.UUID{}, "", errors.New("invalid user ID in token") } + + userLogin := claims["username"].(string) + + // user, userFound, err := h.useCase.GetUserByID(userID) + // if !userFound || err != nil { + // return models.User{}, errors.New("can't find user with given user ID") + // } + + return userID, userLogin, nil } -func (h *JWTHandler) SignUpHandler(c *fiber.Ctx) error { +func (h *JWTHandler) parseUser(c *fiber.Ctx) (models.User, error) { var body schemas.UserBodySchema if err := c.BodyParser(&body); err != nil { - return BadRequest(c, "invalid request body") + return models.User{}, err } user := models.User{ Login: body.Login, Password: body.Password, } - createdUser, err := h.useCase.CreateUser(user) - if err != nil { - return InternalServerError(c, err) - } + return user, nil +} +func (h *JWTHandler) createTokensForUser(userID uuid.UUID, userLogin string) (fiber.Map, error) { accessToken, err := jwtutils.GenerateToken( - createdUser.ID, - createdUser.Login, + userID, + userLogin, "access", h.accessTokenLifetime, h.secretKey, ) if err != nil { - return InternalServerError(c, err) + return nil, err } refreshToken, err := jwtutils.GenerateToken( - createdUser.ID, - createdUser.Login, + userID, + userLogin, "refresh", h.refreshTokenLifetime, h.secretKey, ) if err != nil { - return InternalServerError(c, err) + return nil, err } - return c.Status(fiber.StatusOK).JSON(fiber.Map{ + return fiber.Map{ "accessToken": accessToken, "refreshToken": refreshToken, - }) + }, nil +} + +// SignUpHandler HTTP handler for creating a new user and retrieving a new token +func (h *JWTHandler) SignUpHandler(c *fiber.Ctx) error { + user, err := h.parseUser(c) + if err != nil { + return BadRequest(c, "invalid user data") + } + + _, found, err := h.useCase.GetUserByLogin(user.Login) + if err != nil { + return InternalServerError(c, err) + } + if found { + return BadRequest(c, "username already exists") + } + + createdUser, err := h.useCase.CreateUser(user) + if err != nil { + return InternalServerError(c, err) + } + + returnData, err := h.createTokensForUser(createdUser.ID, createdUser.Login) + if err != nil { + return InternalServerError(c, err) + } + + return c.Status(fiber.StatusOK).JSON(returnData) +} + +// SignInHandler HTTP Handler for retrieving a new token by login and password +func (h *JWTHandler) SignInHandler(c *fiber.Ctx) error { + user, err := h.parseUser(c) + if err != nil { + return BadRequest(c, "invalid user data") + } + + foundUser, userFound, err := h.useCase.GetUserByLoginAndPassword(user.Login, user.Password) + if err != nil { + return InternalServerError(c, err) + } + if !userFound { + return NotAuthenticatedError(c, errors.New("user not found")) + } + + returnData, err := h.createTokensForUser(foundUser.ID, foundUser.Login) + if err != nil { + return InternalServerError(c, err) + } + + return c.JSON(returnData) +} + +// RefreshHandler HTTP handler for refreshing JWT's +func (h *JWTHandler) RefreshHandler(c *fiber.Ctx) error { + var body schemas.RefreshTokenSchema + if err := c.BodyParser(&body); err != nil { + return BadRequest(c, "invalid refresh token data") + } + + tokenString := body.RefreshToken + + userID, userLogin, err := h.runChecksForTokenString(tokenString, "refresh") + if err != nil { + return NotAuthenticatedError(c, err) + } + + returnData, err := h.createTokensForUser(userID, userLogin) + if err != nil { + return InternalServerError(c, err) + } + + return c.Status(fiber.StatusOK).JSON(returnData) } diff --git a/internal/handler/http/helpers.go b/internal/handler/http/helpers.go index c468e76..752d7ab 100644 --- a/internal/handler/http/helpers.go +++ b/internal/handler/http/helpers.go @@ -7,18 +7,34 @@ import ( func BadRequest(c *fiber.Ctx, message string) error { return c.Status(fiber.StatusBadRequest).JSON(fiber.Map{ - "message": message, + "error": message, }) } func InternalServerError(c *fiber.Ctx, err error) error { return c.Status(fiber.StatusInternalServerError).JSON(fiber.Map{ - "message": err.Error(), + "error": err.Error(), }) } +func NotAuthenticatedError(c *fiber.Ctx, err error) error { + return c.Status(fiber.StatusUnauthorized).JSON(fiber.Map{ + "error": err.Error(), + }) +} + +func NotFoundError(c *fiber.Ctx, message string) error { + return c.Status(fiber.StatusNotFound).JSON(fiber.Map{ + "error": message, + }) +} + +func ForbiddenError(c *fiber.Ctx) error { + return c.Status(fiber.StatusForbidden).JSON(fiber.Map{}) +} + func ParseUUID(c *fiber.Ctx, fieldName string) (uuid.UUID, error) { - uuidField, err := uuid.Parse(c.Params("id")) + uuidField, err := uuid.Parse(c.Params(fieldName)) if err != nil { return uuid.UUID{}, err } diff --git a/internal/handler/http/notes.go b/internal/handler/http/notes.go index f592bd0..fb501a1 100644 --- a/internal/handler/http/notes.go +++ b/internal/handler/http/notes.go @@ -9,45 +9,71 @@ import ( "strconv" ) +// NotesHandler is a struct of HTTP handlers that relate to notes. Uses usecases.NoteUseCase. type NotesHandler struct { useCase usecases.NoteUseCase } +// NewNotesHandler creates and returns a new NotesHandler with given useCase func NewNotesHandler(useCase usecases.NoteUseCase) *NotesHandler { return &NotesHandler{useCase} } +func (h *NotesHandler) checkNoteBelongsToUser(c *fiber.Ctx, noteID uint, userID uuid.UUID) error { + existingNote, foundNote, err := h.useCase.GetNoteByID(noteID) + if err != nil { + return InternalServerError(c, err) + } + if !foundNote { + return NotFoundError(c, "couldn't find note with given id") + } + if existingNote.UserID != userID { + return ForbiddenError(c) + } + return nil +} + +// ListNotesHandler HTTP handler to list all notes by someone's UUID func (h *NotesHandler) ListNotesHandler(c *fiber.Ctx) error { - userID, err := strconv.Atoi(c.Params("id")) + userID, err := uuid.Parse(c.Params("id")) if err != nil { return BadRequest(c, "invalid id") } - result, err := h.useCase.GetNotesByUserID(uint(userID)) + result, err := h.useCase.GetNotesByUserID(userID) if err != nil { return InternalServerError(c, err) } return c.Status(fiber.StatusOK).JSON(result) } +// GetNoteByIDHandler HTTP handler to retrieve note by id. +// +// id param: id of the returned note +// +// Returns 404 if the note couldn't be found. func (h *NotesHandler) GetNoteByIDHandler(c *fiber.Ctx) error { noteID, err := strconv.Atoi(c.Params("id")) if err != nil { return BadRequest(c, "invalid note ID") } - note, err := h.useCase.GetNoteByID(uint(noteID)) + note, noteFound, err := h.useCase.GetNoteByID(uint(noteID)) + if !noteFound { + return NotFoundError(c, "couldn't find note with given id") + } if err != nil { return InternalServerError(c, err) } return c.Status(fiber.StatusOK).JSON(note) } +// CreateNoteHandler HTTP handler to create notes. Assigns the new record to current user. func (h *NotesHandler) CreateNoteHandler(c *fiber.Ctx) error { var body schemas.NoteBodySchema if err := c.BodyParser(&body); err != nil { return BadRequest(c, "invalid request body") } note := models.Note{ - UserID: body.UserID, + UserID: c.Locals("userID").(uuid.UUID), Title: body.Title, Content: body.Content, } @@ -58,6 +84,11 @@ func (h *NotesHandler) CreateNoteHandler(c *fiber.Ctx) error { return c.Status(fiber.StatusCreated).JSON(createdNote) } +// UpdateNoteHandler HTTP handler to update a note by id. +// +// id param: id of the note being affected. +// +// checks if the note belongs to the authenticated user func (h *NotesHandler) UpdateNoteHandler(c *fiber.Ctx) error { noteID, err := strconv.Atoi(c.Params("id")) if err != nil { @@ -69,29 +100,54 @@ func (h *NotesHandler) UpdateNoteHandler(c *fiber.Ctx) error { return BadRequest(c, "invalid request body") } updatedNote := models.Note{ - UserID: body.UserID, + UserID: c.Locals("userID").(uuid.UUID), Title: body.Title, Content: body.Content, } - result, err := h.useCase.Update(updatedNote, uint(noteID)) + noteIDUint := uint(noteID) + + err = h.checkNoteBelongsToUser(c, noteIDUint, updatedNote.UserID) + if err != nil { + return err + } + + result, err := h.useCase.Update(updatedNote, noteIDUint) if err != nil { return InternalServerError(c, err) } return c.Status(fiber.StatusOK).JSON(result) } +// DeleteNoteHandler HTTP handler to delete a note by id. +// +// id param: id of the note being deleted. +// +// checks if the note belongs to the authenticated user func (h *NotesHandler) DeleteNoteHandler(c *fiber.Ctx) error { noteID, err := strconv.Atoi(c.Params("id")) if err != nil { return BadRequest(c, "invalid note ID") } + + noteIDUint := uint(noteID) + + err = h.checkNoteBelongsToUser(c, noteIDUint, c.Locals("userID").(uuid.UUID)) + if err != nil { + return err + } + if err := h.useCase.DeleteNote(uint(noteID)); err != nil { return InternalServerError(c, err) } return c.Status(fiber.StatusNoContent).SendString("") } +// CountNotesByUserHandler HTTP handler that returns the amount (count) of someone's notes. +// +// id param: string uuid of the user whose notes the handler is going to count +// +// Returns 0 if user couldn't be found func (h *NotesHandler) CountNotesByUserHandler(c *fiber.Ctx) error { userIDStr := c.Params("id") if userIDStr == "" { @@ -103,7 +159,7 @@ func (h *NotesHandler) CountNotesByUserHandler(c *fiber.Ctx) error { return BadRequest(c, "invalid user ID") } - count, err := h.useCase.CountNotesByUser(userID) // Преобразование uint в UUID + count, err := h.useCase.CountNotesByUser(userID) if err != nil { return InternalServerError(c, err) } diff --git a/internal/handler/http/users.go b/internal/handler/http/users.go index 1ff0dd4..92a2d5f 100644 --- a/internal/handler/http/users.go +++ b/internal/handler/http/users.go @@ -2,6 +2,7 @@ package http import ( "github.com/gofiber/fiber/v2" + "github.com/google/uuid" "notes_service/internal/handler/schemas" "notes_service/internal/models" use_cases "notes_service/internal/usecases" @@ -15,24 +16,46 @@ func NewUsersHandler(useCase use_cases.UserUseCase) *UsersHandler { return &UsersHandler{useCase} } -func (h *UsersHandler) GetUserByIDHandler(c *fiber.Ctx) error { - userID, err := ParseUUID(c, "id") - if err != nil { - return BadRequest(c, "invalid user UUID") +func (h *UsersHandler) returnInfoAboutUserById(c *fiber.Ctx, userID uuid.UUID) error { + user, userFound, err := h.useCase.GetUserByID(userID) + if !userFound { + return NotFoundError(c, "couldn't find user with given user UUID") } - user, err := h.useCase.GetUserByID(userID) if err != nil { return InternalServerError(c, err) } return c.Status(fiber.StatusOK).JSON(user) } +func (h *UsersHandler) checkIfUserIdBelongsToCurrentUser(c *fiber.Ctx, userID uuid.UUID) error { + if c.Locals("userID").(uuid.UUID) != userID { + return BadRequest(c, "user ID is not owned by you") + } + return nil +} + +func (h *UsersHandler) GetUserByIDHandler(c *fiber.Ctx) error { + userID, err := ParseUUID(c, "id") + if err != nil { + return BadRequest(c, "invalid user UUID") + } + return h.returnInfoAboutUserById(c, userID) +} + +func (h *UsersHandler) GetCurrentUserHandler(c *fiber.Ctx) error { + userID := c.Locals("userID").(uuid.UUID) + return h.returnInfoAboutUserById(c, userID) +} + func (h *UsersHandler) GetUserByLoginHandler(c *fiber.Ctx) error { login := c.Params("login") if login == "" { return BadRequest(c, "login is required") } - user, err := h.useCase.GetUserByLogin(login) + user, userFound, err := h.useCase.GetUserByLogin(login) + if !userFound { + return NotFoundError(c, "couldn't find user with given login") + } if err != nil { return InternalServerError(c, err) } @@ -45,6 +68,11 @@ func (h *UsersHandler) UpdateUserHandler(c *fiber.Ctx) error { return BadRequest(c, "invalid user UUID") } + err = h.checkIfUserIdBelongsToCurrentUser(c, userID) + if err != nil { + return err + } + var body schemas.UserBodySchema if err := c.BodyParser(&body); err != nil { return BadRequest(c, "invalid request body") @@ -67,6 +95,11 @@ func (h *UsersHandler) DeleteUserHandler(c *fiber.Ctx) error { return BadRequest(c, "invalid user UUID") } + err = h.checkIfUserIdBelongsToCurrentUser(c, userID) + if err != nil { + return err + } + if err := h.useCase.DeleteUser(userID); err != nil { return InternalServerError(c, err) } diff --git a/internal/handler/schemas/schemas.go b/internal/handler/schemas/schemas.go index fa86e8a..66023f2 100644 --- a/internal/handler/schemas/schemas.go +++ b/internal/handler/schemas/schemas.go @@ -1,13 +1,9 @@ package schemas -import ( - "github.com/google/uuid" -) - type NoteBodySchema struct { - UserID uuid.UUID `json:"user_id"` - Title string `json:"title"` - Content string `json:"content"` + // UserID uuid.UUID `json:"user_id"` + Title string `json:"title"` + Content string `json:"content"` } type UserBodySchema struct { diff --git a/internal/ports/adapters.go b/internal/ports/adapters.go index 04f60e6..f55ae0b 100644 --- a/internal/ports/adapters.go +++ b/internal/ports/adapters.go @@ -1,11 +1,13 @@ package ports import ( + "errors" "github.com/go-redis/redis/v7" "github.com/gofiber/fiber/v2" "github.com/google/uuid" - "golang.org/x/crypto/bcrypt" + "gorm.io/gorm" "notes_service/internal/models" + "notes_service/pkg/auth/password" "notes_service/pkg/storage/postgres" "time" ) @@ -25,17 +27,23 @@ func NewNotesRepoDB(db *postgres.DBInstance) *NotesRepoDB { } // GetNotesByUserID retrieves all notes associated with a specific user ID. -func (r *NotesRepoDB) GetNotesByUserID(userID uint) ([]models.Note, error) { +func (r *NotesRepoDB) GetNotesByUserID(userID uuid.UUID) ([]models.Note, error) { var notesList []models.Note r.db.Db.Find(¬esList, "user_id = ?", userID) return notesList, nil } // GetNoteByID retrieves a single note by its unique ID. -func (r *NotesRepoDB) GetNoteByID(noteID uint) (models.Note, error) { +func (r *NotesRepoDB) GetNoteByID(noteID uint) (models.Note, bool, error) { var note models.Note - r.db.Db.First(¬e, noteID) - return note, nil + result := r.db.Db.First(¬e, noteID) + if result.Error != nil { + if errors.Is(result.Error, gorm.ErrRecordNotFound) { + return note, false, nil + } + return note, false, result.Error + } + return note, true, nil } // CreateNote creates a new note and saves it to the database. @@ -69,19 +77,23 @@ func (r *NotesRepoDB) CountNotesByUser(noteID uuid.UUID) (int64, error) { // UsersRepoDB represents a repository interface for interacting with the Users database. type UsersRepoDB struct { - db *postgres.DBInstance // Database instance for executing queries. + db *postgres.DBInstance // Database instance for executing queries. + passwordManager password.PasswordManager // Password utils instance for generating and checking passwords } var _ UsersRepo = (*UsersRepoDB)(nil) // NewUsersRepoDB initializes and returns a new UsersRepoDB instance. -func NewUsersRepoDB(db *postgres.DBInstance) *UsersRepoDB { - return &UsersRepoDB{db} +func NewUsersRepoDB(db *postgres.DBInstance, passwordManager password.PasswordManager) *UsersRepoDB { + return &UsersRepoDB{ + db, + passwordManager, + } } // SetPassword hashes the password for a user using bcrypt and sets it in the user object. -func SetPassword(user *models.User) error { - hashedPassword, err := bcrypt.GenerateFromPassword([]byte(user.Password), bcrypt.DefaultCost) +func (r *UsersRepoDB) SetPassword(user *models.User) error { + hashedPassword, err := r.passwordManager.GeneratePassword(user.Password) if err != nil { return err } @@ -90,32 +102,49 @@ func SetPassword(user *models.User) error { } // GetUserByID retrieves a user by their unique UUID. -func (r *UsersRepoDB) GetUserByID(userID uuid.UUID) (models.User, error) { +func (r *UsersRepoDB) GetUserByID(userID uuid.UUID) (models.User, bool, error) { var user models.User result := r.db.Db.First(&user, userID) - err := result.Error - return user, err + if result.Error != nil { + if errors.Is(result.Error, gorm.ErrRecordNotFound) { + return user, false, nil + } + return user, false, result.Error + } + return user, true, nil } // GetUserByLogin retrieves a user by their login name. -func (r *UsersRepoDB) GetUserByLogin(login string) (models.User, error) { +func (r *UsersRepoDB) GetUserByLogin(login string) (models.User, bool, error) { var user models.User result := r.db.Db.First(&user, "login = ?", login) - err := result.Error - return user, err + if result.Error != nil { + if errors.Is(result.Error, gorm.ErrRecordNotFound) { + return user, false, nil + } + return user, false, result.Error + } + return user, true, nil } // GetUserByLoginAndPassword retrieves a user by their login name and password. -func (r *UsersRepoDB) GetUserByLoginAndPassword(login string, password string) (models.User, error) { +func (r *UsersRepoDB) GetUserByLoginAndPassword(login string, password string) (models.User, bool, error) { var user models.User - result := r.db.Db.First(&user, "login = ? AND password = ?", login, password) - err := result.Error - return user, err + result := r.db.Db.First(&user, "login = ?", login) + if result.Error != nil { + if errors.Is(result.Error, gorm.ErrRecordNotFound) { + return user, false, nil + } + return user, false, result.Error + } else if passwordsEqual, err := r.passwordManager.CheckPassword(password, user.Password); !passwordsEqual || err != nil { + return user, false, nil + } + return user, true, nil } // CreateUser creates a new user and saves it to the database, hashing the password before saving. func (r *UsersRepoDB) CreateUser(user models.User) (models.User, error) { - err := SetPassword(&user) + err := r.SetPassword(&user) if err != nil { return models.User{}, err } @@ -125,7 +154,7 @@ func (r *UsersRepoDB) CreateUser(user models.User) (models.User, error) { // UpdateUser updates an existing user in the database based on their UUID, rehashing the password if changed. func (r *UsersRepoDB) UpdateUser(user models.User, userID uuid.UUID) (models.User, error) { - err := SetPassword(&user) + err := r.SetPassword(&user) if err != nil { return models.User{}, err } diff --git a/internal/ports/ports.go b/internal/ports/ports.go index 29b29b9..f067eb1 100644 --- a/internal/ports/ports.go +++ b/internal/ports/ports.go @@ -7,9 +7,9 @@ import ( // UsersRepo Port for users type UsersRepo interface { - GetUserByID(userID uuid.UUID) (models.User, error) - GetUserByLogin(login string) (models.User, error) - GetUserByLoginAndPassword(login string, password string) (models.User, error) + GetUserByID(userID uuid.UUID) (models.User, bool, error) + GetUserByLogin(login string) (models.User, bool, error) + GetUserByLoginAndPassword(login string, password string) (models.User, bool, error) CreateUser(user models.User) (models.User, error) UpdateUser(user models.User, userID uuid.UUID) (models.User, error) DeleteUser(userID uuid.UUID) error @@ -17,8 +17,8 @@ type UsersRepo interface { // NotesRepo Port for notes type NotesRepo interface { - GetNotesByUserID(userID uint) ([]models.Note, error) - GetNoteByID(noteID uint) (models.Note, error) + GetNotesByUserID(userID uuid.UUID) ([]models.Note, error) + GetNoteByID(noteID uint) (models.Note, bool, error) CreateNote(note models.Note) (models.Note, error) UpdateNote(note models.Note, id uint) (models.Note, error) DeleteNote(noteID uint) error diff --git a/internal/usecases/use_cases.go b/internal/usecases/use_cases.go index 0cfc5f6..65b9c35 100644 --- a/internal/usecases/use_cases.go +++ b/internal/usecases/use_cases.go @@ -20,12 +20,12 @@ func NewNoteCRUDUseCase(notesRepo ports.NotesRepo) *NoteUseCase { } // GetNotesByUserID returns a list of notes for the specified user by their ID. -func (u *NoteUseCase) GetNotesByUserID(userID uint) ([]models.Note, error) { +func (u *NoteUseCase) GetNotesByUserID(userID uuid.UUID) ([]models.Note, error) { return u.notesRepo.GetNotesByUserID(userID) } // GetNoteByID returns a note by its unique ID. -func (u *NoteUseCase) GetNoteByID(noteID uint) (models.Note, error) { +func (u *NoteUseCase) GetNoteByID(noteID uint) (models.Note, bool, error) { return u.notesRepo.GetNoteByID(noteID) } @@ -64,18 +64,18 @@ func NewUserCRUDUseCase(usersRepo ports.UsersRepo) *UserUseCase { } // GetUserByID returns a user by their unique UUID. -func (u *UserUseCase) GetUserByID(userID uuid.UUID) (models.User, error) { +func (u *UserUseCase) GetUserByID(userID uuid.UUID) (models.User, bool, error) { return u.usersRepo.GetUserByID(userID) } // GetUserByLogin returns a user by their login. -func (u *UserUseCase) GetUserByLogin(login string) (models.User, error) { +func (u *UserUseCase) GetUserByLogin(login string) (models.User, bool, error) { return u.usersRepo.GetUserByLogin(login) } // GetUserByLoginAndPassword searches for a user by login and password. // Returns the found user or an error if the user is not found. -func (u *UserUseCase) GetUserByLoginAndPassword(login string, password string) (models.User, error) { +func (u *UserUseCase) GetUserByLoginAndPassword(login string, password string) (models.User, bool, error) { return u.usersRepo.GetUserByLoginAndPassword(login, password) } diff --git a/pkg/auth/password/password_utils.go b/pkg/auth/password/password_utils.go new file mode 100644 index 0000000..d6c7763 --- /dev/null +++ b/pkg/auth/password/password_utils.go @@ -0,0 +1,35 @@ +package password + +import "golang.org/x/crypto/bcrypt" + +type ( + + // PasswordManager Interface that represents a password manager that generates and checks them + PasswordManager interface { + GeneratePassword(password string) (string, error) + CheckPassword(password string, hash string) (bool, error) + } + + // PasswordManagerBcrypt is a bcrypt implementation of PasswordUtils + PasswordManagerBcrypt struct{} +) + +func NewPasswordManagerBcrypt() *PasswordManagerBcrypt { + return &PasswordManagerBcrypt{} +} + +func (p *PasswordManagerBcrypt) GeneratePassword(password string) (string, error) { + generatedPassword, err := bcrypt.GenerateFromPassword([]byte(password), bcrypt.DefaultCost) + if err != nil { + return "", err + } + return string(generatedPassword), nil +} + +func (p *PasswordManagerBcrypt) CheckPassword(password string, hash string) (bool, error) { + err := bcrypt.CompareHashAndPassword([]byte(hash), []byte(password)) + if err != nil { + return false, err + } + return true, nil +}