Skip to content

Commit

Permalink
Add tests for auth routes
Browse files Browse the repository at this point in the history
  • Loading branch information
bdmendes committed Nov 29, 2022
1 parent 550d5a4 commit cdf62b6
Show file tree
Hide file tree
Showing 8 changed files with 220 additions and 27 deletions.
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
package pt.up.fe.ni.website.backend.controller

import org.springframework.security.access.prepost.PreAuthorize
import org.springframework.security.oauth2.jwt.JwtDecoder
import org.springframework.web.bind.annotation.GetMapping
import org.springframework.web.bind.annotation.PostMapping
import org.springframework.web.bind.annotation.RequestBody
Expand All @@ -21,18 +20,18 @@ data class TokenDto(

@RestController
@RequestMapping("/auth")
class AuthController(val authService: AuthService, val jwtDecoder: JwtDecoder) {
class AuthController(val authService: AuthService) {
@PostMapping("/new")
fun getNewToken(@RequestBody loginDto: LoginDto): Map<String, String> {
val authentication = authService.authenticate(loginDto.email, loginDto.password)
val accessToken = authService.generateAccessToken(authentication)
val refreshToken = authService.generateRefreshToken(authentication)
val account = authService.authenticate(loginDto.email, loginDto.password)
val accessToken = authService.generateAccessToken(account)
val refreshToken = authService.generateRefreshToken(account)
return mapOf("access_token" to accessToken, "refresh_token" to refreshToken)
}

@PostMapping("/refresh")
fun refreshAccessToken(@RequestBody tokenDto: TokenDto): Map<String, String> {
val accessToken = authService.refreshToken(tokenDto.token)
val accessToken = authService.refreshAccessToken(tokenDto.token)
return mapOf("access_token" to accessToken)
}

Expand Down
46 changes: 27 additions & 19 deletions src/main/kotlin/pt/up/fe/ni/website/backend/service/AuthService.kt
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@ package pt.up.fe.ni.website.backend.service

import org.springframework.http.HttpStatus
import org.springframework.security.authentication.UsernamePasswordAuthenticationToken
import org.springframework.security.core.Authentication
import org.springframework.security.core.GrantedAuthority
import org.springframework.security.core.authority.SimpleGrantedAuthority
import org.springframework.security.core.context.SecurityContextHolder
Expand All @@ -27,42 +26,45 @@ class AuthService(
val jwtDecoder: JwtDecoder,
private val passwordEncoder: PasswordEncoder
) {
fun authenticate(email: String, password: String): Authentication {
fun authenticate(email: String, password: String): Account {
val account = accountService.getAccountByEmail(email)
if (!passwordEncoder.matches(password, account.password)) {
throw ResponseStatusException(HttpStatus.UNPROCESSABLE_ENTITY, "Invalid credentials")
throw ResponseStatusException(HttpStatus.UNPROCESSABLE_ENTITY, "invalid credentials")
}
val authorities = listOf("BOARD", "MEMBER").stream() // TODO: get roles from account
.map { role -> SimpleGrantedAuthority(role) }
.collect(Collectors.toList())
return UsernamePasswordAuthenticationToken(email, password, authorities)
val authentication = UsernamePasswordAuthenticationToken(email, password, getAuthorities(account))
SecurityContextHolder.getContext().authentication = authentication
return account
}

fun generateAccessToken(authentication: Authentication): String {
return generateToken(authentication, Duration.ofMinutes(authConfigProperties.jwtAccessExpirationMinutes))
fun generateAccessToken(account: Account): String {
return generateToken(account, Duration.ofMinutes(authConfigProperties.jwtAccessExpirationMinutes))
}

fun generateRefreshToken(authentication: Authentication): String {
return generateToken(authentication, Duration.ofDays(authConfigProperties.jwtRefreshExpirationDays), true)
fun generateRefreshToken(account: Account): String {
return generateToken(account, Duration.ofDays(authConfigProperties.jwtRefreshExpirationDays), true)
}

fun refreshToken(refreshToken: String): String {
val jwt = jwtDecoder.decode(refreshToken)
fun refreshAccessToken(refreshToken: String): String {
val jwt =
try {
jwtDecoder.decode(refreshToken)
} catch (e: Exception) {
throw ResponseStatusException(HttpStatus.UNAUTHORIZED, "invalid refresh token")
}
if (jwt.expiresAt?.isBefore(Instant.now()) != false) {
throw ResponseStatusException(HttpStatus.UNAUTHORIZED, "Refresh token has expired")
throw ResponseStatusException(HttpStatus.UNAUTHORIZED, "refresh token has expired")
}
val account = accountService.getAccountByEmail(jwt.subject)
val authentication = authenticate(account.email, account.password)
return generateAccessToken(authentication)
return generateAccessToken(account)
}

fun getAuthenticatedAccount(): Account {
val authentication = SecurityContextHolder.getContext().authentication
return accountService.getAccountByEmail(authentication.name)
}

private fun generateToken(authentication: Authentication, expiration: Duration, isRefresh: Boolean = false): String {
val roles = if (isRefresh) emptyList<GrantedAuthority>() else authentication.authorities
private fun generateToken(account: Account, expiration: Duration, isRefresh: Boolean = false): String {
val roles = if (isRefresh) emptyList() else getAuthorities(account)
val scope = roles
.stream()
.map(GrantedAuthority::getAuthority)
Expand All @@ -73,9 +75,15 @@ class AuthService(
.issuer("self")
.issuedAt(currentInstant)
.expiresAt(currentInstant.plus(expiration))
.subject(authentication.name)
.subject(account.email)
.claim("scope", scope)
.build()
return jwtEncoder.encode(JwtEncoderParameters.from(claims)).tokenValue
}

private fun getAuthorities(account: Account): List<GrantedAuthority> {
return listOf("BOARD", "MEMBER").stream() // TODO: get roles from account
.map { role -> SimpleGrantedAuthority(role) }
.collect(Collectors.toList())
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ import pt.up.fe.ni.website.backend.model.Account
import pt.up.fe.ni.website.backend.model.CustomWebsite
import pt.up.fe.ni.website.backend.repository.AccountRepository
import pt.up.fe.ni.website.backend.utils.TestUtils
import pt.up.fe.ni.website.backend.utils.ValidationTester
import java.util.Calendar
import java.util.Date
import pt.up.fe.ni.website.backend.model.constants.AccountConstants as Constants
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,183 @@
package pt.up.fe.ni.website.backend.controller

import com.fasterxml.jackson.databind.ObjectMapper
import org.junit.jupiter.api.BeforeAll
import org.junit.jupiter.api.DisplayName
import org.junit.jupiter.api.Nested
import org.junit.jupiter.api.Test
import org.junit.jupiter.api.TestInstance
import org.springframework.beans.factory.annotation.Autowired
import org.springframework.boot.test.autoconfigure.jdbc.AutoConfigureTestDatabase
import org.springframework.boot.test.autoconfigure.web.servlet.AutoConfigureMockMvc
import org.springframework.boot.test.context.SpringBootTest
import org.springframework.http.MediaType
import org.springframework.security.crypto.password.PasswordEncoder
import org.springframework.test.annotation.DirtiesContext
import org.springframework.test.web.servlet.MockMvc
import org.springframework.test.web.servlet.get
import org.springframework.test.web.servlet.post
import pt.up.fe.ni.website.backend.model.Account
import pt.up.fe.ni.website.backend.model.CustomWebsite
import pt.up.fe.ni.website.backend.repository.AccountRepository
import pt.up.fe.ni.website.backend.utils.TestUtils
import java.util.Calendar

@SpringBootTest
@AutoConfigureMockMvc
@AutoConfigureTestDatabase
@DirtiesContext(classMode = DirtiesContext.ClassMode.AFTER_CLASS)
class AuthControllerTest @Autowired constructor(
val repository: AccountRepository,
val mockMvc: MockMvc,
val objectMapper: ObjectMapper,
passwordEncoder: PasswordEncoder
) {
final val testPassword = "testPassword"

// TODO: Make sure to add "MEMBER" role to the account
val testAccount = Account(
"Test Account",
"test_account@test.com",
passwordEncoder.encode(testPassword),
"This is a test account",
TestUtils.createDate(2001, Calendar.JULY, 28),
"https://test-photo.com",
"https://linkedin.com",
"https://github.com",
listOf(
CustomWebsite("https://test-website.com", "https://test-website.com/logo.png")
)
)

@Nested
@DisplayName("POST /auth/new")
@TestInstance(TestInstance.Lifecycle.PER_CLASS)
inner class GetNewToken {
@BeforeAll
fun setup() {
repository.save(testAccount)
}

@Test
fun `should fail when email is invalid`() {
mockMvc.post("/auth/new") {
contentType = MediaType.APPLICATION_JSON
content = objectMapper.writeValueAsString(
mapOf(
"email" to "president@niaefeup.pt",
"password" to testPassword
)
)
}.andExpect {
status { isNotFound() }
jsonPath("$.errors[0].message") { value("account not found with email president@niaefeup.pt") }
}
}

@Test
fun `should fail when password is incorrect`() {
mockMvc.post("/auth/new") {
contentType = MediaType.APPLICATION_JSON
content = objectMapper.writeValueAsString(LoginDto(testAccount.email, "wrong_password"))
}.andExpect {
status { isUnprocessableEntity() }
jsonPath("$.errors[0].message") { value("invalid credentials") }
}
}

@Test
fun `should return access and refresh tokens`() {
mockMvc.post("/auth/new") {
contentType = MediaType.APPLICATION_JSON
content = objectMapper.writeValueAsString(LoginDto(testAccount.email, testPassword))
}.andExpect {
status { isOk() }
jsonPath("$.access_token") { exists() }
jsonPath("$.refresh_token") { exists() }
}
}
}

@Nested
@DisplayName("POST /auth/refresh")
@TestInstance(TestInstance.Lifecycle.PER_CLASS)
inner class RefreshToken {
@BeforeAll
fun setup() {
repository.save(testAccount)
}

@Test
fun `should fail when refresh token is invalid`() {
mockMvc.post("/auth/refresh") {
contentType = MediaType.APPLICATION_JSON
content = objectMapper.writeValueAsString(TokenDto("invalid_refresh_token"))
}.andExpect {
status { isUnauthorized() }
jsonPath("$.errors[0].message") { value("invalid refresh token") }
}
}

@Test
fun `should return new access token`() {
mockMvc.post("/auth/new") {
contentType = MediaType.APPLICATION_JSON
content = objectMapper.writeValueAsString(LoginDto(testAccount.email, testPassword))
}.andReturn().response.let { response ->
val refreshToken = objectMapper.readTree(response.contentAsString)["refresh_token"].asText()
mockMvc.post("/auth/refresh") {
contentType = MediaType.APPLICATION_JSON
content = objectMapper.writeValueAsString(TokenDto(refreshToken))
}.andExpect {
status { isOk() }
jsonPath("$.access_token") { exists() }
}
}
}
}

@Nested
@DisplayName("GET /auth/check")
@TestInstance(TestInstance.Lifecycle.PER_CLASS)
inner class CheckToken {
@BeforeAll
fun setup() {
repository.save(testAccount)
}

@Test
fun `should fail when no access token is provided`() {
mockMvc.get("/auth").andExpect {
status { isUnauthorized() }
jsonPath("$.errors[0].message") { value("Access is denied") }
}
}

@Test
fun `should fail when access token is invalid`() {
mockMvc.get("/auth") {
header("Authorization", "Bearer invalid_access_token")
}.andExpect {
status { isUnauthorized() }
}
}

@Test
fun `should return authenticated user`() {
mockMvc.post("/auth/new") {
contentType = MediaType.APPLICATION_JSON
content = objectMapper.writeValueAsString(LoginDto(testAccount.email, testPassword))
}.andReturn().response.let { response ->
val accessToken = objectMapper.readTree(response.contentAsString)["access_token"].asText()
mockMvc.get("/auth") {
header("Authorization", "Bearer $accessToken")
}.andExpect {
status { isOk() }
jsonPath("$.authenticated_user") { value(testAccount.email) }
}
}
}

// TODO: Add tests for role access when implemented
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ import org.springframework.test.web.servlet.post
import pt.up.fe.ni.website.backend.model.Event
import pt.up.fe.ni.website.backend.repository.EventRepository
import pt.up.fe.ni.website.backend.utils.TestUtils
import pt.up.fe.ni.website.backend.utils.ValidationTester
import java.util.Calendar
import pt.up.fe.ni.website.backend.model.constants.EventConstants as Constants

Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
package pt.up.fe.ni.website.backend.controller

import com.fasterxml.jackson.databind.ObjectMapper
import org.hamcrest.Matchers.not
import org.junit.jupiter.api.Assertions.assertEquals
import org.junit.jupiter.api.Assertions.assertNotEquals
import org.junit.jupiter.api.BeforeAll
Expand All @@ -23,6 +22,7 @@ import org.springframework.test.web.servlet.post
import org.springframework.test.web.servlet.put
import pt.up.fe.ni.website.backend.model.Post
import pt.up.fe.ni.website.backend.repository.PostRepository
import pt.up.fe.ni.website.backend.utils.ValidationTester
import java.text.SimpleDateFormat
import java.util.Date
import pt.up.fe.ni.website.backend.model.constants.PostConstants as Constants
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ import org.springframework.test.web.servlet.post
import org.springframework.test.web.servlet.put
import pt.up.fe.ni.website.backend.model.Project
import pt.up.fe.ni.website.backend.repository.ProjectRepository
import pt.up.fe.ni.website.backend.utils.ValidationTester
import pt.up.fe.ni.website.backend.model.constants.ProjectConstants as Constants

@SpringBootTest
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
package pt.up.fe.ni.website.backend.controller
package pt.up.fe.ni.website.backend.utils

import org.springframework.http.MediaType
import org.springframework.test.web.servlet.ResultActionsDsl
Expand Down

0 comments on commit cdf62b6

Please sign in to comment.