In this article, you’ll learn how to write unit tests for a Rust API built with the Actix-web framework and the SQLx toolkit. Note that these tests are tailored to the Rust API we’ve been building throughout this tutorial series.

Since we’re using SQLx, we can mark database tests with the #[sqlx::test] attribute, which lets the test functions be asynchronous and sets up and manages a dedicated test database for each test. Similarly, Actix-web provides a test attribute and testing utilities that streamline unit testing our API handlers, so we don’t need any external libraries or crates.

Before we proceed, a quick note: while we’ll cover a significant number of test cases in the project, there are too many to explain each one individually without making the article excessively long. Once you grasp how one test works, though, you’ll be able to apply that knowledge to the rest, since they follow the same structure.

Articles in This Series

  1. Building a Rust API with Unit Testing in Mind
  2. How to Add Swagger UI, Redoc and RapiDoc to a Rust API
  3. JWT Authentication and Authorization in a Rust API using Actix-Web
  4. How to Write Unit Tests for Your Rust API
  5. Dockerizing a Rust API Project: SQL Database and pgAdmin
  6. Deploy Rust App on VPS with GitHub Actions and Docker

Running the Unit Tests on Your Machine

To run the Rust project on your local machine and execute the accompanying unit tests, follow the steps outlined below:

  • Download or clone the project from its GitHub repository at https://github.com/wpcodevo/complete-restful-api-in-rust and open the source code in your preferred code editor.
  • Start the Postgres and pgAdmin Docker containers by running the command docker-compose -f docker-compose.no_api.yml up -d.
  • Apply the database migrations to the PostgreSQL database by running sqlx migrate run. If you don’t have the SQLx CLI installed, install it with cargo install sqlx-cli --no-default-features --features postgres.
  • From the project’s root directory, run cargo test, which builds the project and runs the 49 unit tests that come with it.
  • Once the tests pass, run cargo run to start the Actix-web development server.
  • You can access the Swagger documentation at http://localhost:8000/. Alternatively, you can use the Redoc UI at http://localhost:8000/redoc or the RapiDoc UI at http://localhost:8000/rapidoc. These interfaces make it easy to explore the API endpoints and interact with them.

What are Unit Tests?

Unit tests are a fundamental part of software testing that focus on testing individual units or components of code in isolation. These units could be functions, methods, or small groups of tightly interconnected functions. Because unit tests are inexpensive to set up and cheap to rerun, they are one of the easiest ways to catch bugs early.

Let’s demonstrate unit testing with an Actix-Web handler that returns a simple JSON response for a health checker endpoint:


use actix_web::{get, HttpResponse, Responder};

#[get("/api/healthchecker")]
async fn health_checker_handler() -> impl Responder {
    const MESSAGE: &str = "Complete Restful API in Rust";

    HttpResponse::Ok().json(serde_json::json!({"status": "success", "message": MESSAGE}))
}

Now, let’s create a unit test for this handler using Actix-Web’s built-in testing framework. We’ll test whether the handler returns the expected JSON response:


#[cfg(test)]
mod tests {
    use super::*;
    use actix_web::{http::StatusCode, test, App};

    #[actix_web::test]
    async fn test_health_checker_handler() {
        // Initialize the test service with the handler.
        let app = test::init_service(App::new().service(health_checker_handler)).await;

        // Create a test request for the health checker endpoint.
        let req = test::TestRequest::get()
            .uri("/api/healthchecker")
            .to_request();

        // Send the request to the handler and get the response.
        let resp = test::call_service(&app, req).await;

        // Assert that the response status code is OK (200).
        assert_eq!(resp.status(), StatusCode::OK);

        // Read the raw response body bytes.
        let body = test::read_body(resp).await;
        let expected_json =
            serde_json::json!({"status": "success", "message": "Complete Restful API in Rust"});

        // Assert that the response body matches the expected JSON.
        assert_eq!(body, serde_json::to_string(&expected_json).unwrap());
    }
}

We’ve placed the test inside a tests module annotated with #[cfg(test)]. This attribute tells the compiler to build the enclosed code only when running tests (for example, with cargo test) and to leave it out of regular debug and release builds.

Let’s delve into the components of the test and understand the structure we’ll be following throughout our testing process:

  • Asynchronous Test Function: We mark the test function with the #[actix_web::test] attribute, which runs the async function on an Actix runtime so we can use .await inside it.
  • Test Service Initialization: Using the testing framework’s test::init_service function, we turn an App that registers only the handler under test into a service we can call directly.
  • Creating a Test Request: We build a GET request for the health checker endpoint with TestRequest::get().
  • Invoking the Handler: With test::call_service(), we send the request to the service and obtain the resulting response.
  • Assertion of Response Status: Using the assert_eq! macro, we verify that the response status code matches StatusCode::OK (200).
  • Reading the Response Body: Using test::read_body, we read the raw response body as bytes.
  • Expected JSON Comparison: We define the expected response in the expected_json variable, serialize it with serde_json::to_string, and use another assert_eq! to compare it with the body. Because this compares serialized strings, it relies on both sides emitting the JSON keys in the same order, which holds here because both are built with serde_json::json! from the same keys.

Why Should We Perform Unit Testing on Our API?

Testing is an integral part of software engineering. However, for beginners, writing test cases to ensure that the code we create performs exactly as expected can be daunting and time-consuming. While testing might be forgone in small projects, as applications scale, the risk of encountering issues after pushing new features to production grows significantly.

Experienced developers strongly advocate writing extensive unit tests to catch problems before they reach production. By doing so, you can reap the following benefits:

  1. Early Bug Detection: Unit tests catch bugs and defects during the early stages of development, reducing the effort and cost required to rectify them later.
  2. Isolation: Unit tests focus on testing individual components or units in isolation. This helps identify issues within specific parts of the codebase.
  3. Verification: Unit tests check that each code unit functions as intended, helping confirm that the code behaves according to its specifications.
  4. Refactoring Confidence: Unit tests provide confidence when refactoring code. They allow you to modify or enhance code while ensuring that existing functionality remains intact.
  5. Documentation: Unit tests act as live documentation, showcasing how components are meant to be utilized and offering examples of their expected behaviour.
  6. Maintainability: Well-written unit tests make it easier to maintain and extend the codebase over time. Tests act as a safety net when making modifications.
  7. Collaboration: Unit tests facilitate collaboration among team members. New developers can use tests to understand existing code and contribute without fear of breaking things.

Writing Unit Tests for the JWT Utility Functions

In this section, we’ll explore how to write unit tests for the JWT utility functions in the src/utils/token.rs module. These functions are responsible for signing and decoding JSON Web Tokens (JWTs) using the jsonwebtoken crate. JWTs are widely used for authentication and authorization purposes in web applications.

Let’s take a closer look at the utility functions and the corresponding unit tests:

Utility Functions (src/utils/token.rs):


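use chrono::{Duration, Utc};
use jsonwebtoken::{decode, encode, Algorithm, DecodingKey, EncodingKey, Header, Validation};
use serde::{Deserialize, Serialize};

use crate::error::{ErrorMessage, HttpError};

#[derive(Debug, Serialize, Deserialize)]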
pub struct TokenClaims {
    pub sub: String,
    pub iat: usize,
    pub exp: usize,
}

pub fn create_token(
    user_id: &str,
    secret: &[u8],
    expires_in_minutes: i64,
) -> Result<String, jsonwebtoken::errors::Error> {
    if user_id.is_empty() {
        return Err(jsonwebtoken::errors::ErrorKind::InvalidSubject.into());
    }

    let now = Utc::now();
    let iat = now.timestamp() as usize;
    // The token lifetime is measured in minutes (e.g. 60 = one hour).
    let exp = (now + Duration::minutes(expires_in_minutes)).timestamp() as usize;
    let claims: TokenClaims = TokenClaims {
        sub: user_id.to_string(),
        exp,
        iat,
    };

    encode(
        &Header::default(),
        &claims,
        &EncodingKey::from_secret(secret),
    )
}

pub fn decode_token<T: Into<String>>(token: T, secret: &[u8]) -> Result<String, HttpError> {
    let decoded = decode::<TokenClaims>(
        &token.into(),
        &DecodingKey::from_secret(secret),
        &Validation::new(Algorithm::HS256),
    );
    match decoded {
        Ok(token) => Ok(token.claims.sub),
        Err(_) => Err(HttpError::new(ErrorMessage::InvalidToken.to_string(), 401)),
    }
}

Unit Tests (src/utils/token.rs):


#[cfg(test)]
mod tests {

    use super::*;

    #[test]
    fn test_create_and_decode_valid_token() {
        let user_id = "user123";
        let secret = b"my-secret-key";

        let token = create_token(user_id, secret, 60).unwrap();
        let decoded_user_id = decode_token(&token, secret).unwrap();

        assert_eq!(decoded_user_id, user_id);
    }

    #[test]
    fn test_create_token_with_empty_user_id() {
        let user_id = "";
        let secret = b"my-secret-key";

        let result = create_token(user_id, secret, 60);

        assert!(result.is_err());
        assert_eq!(
            result.unwrap_err().into_kind(),
            jsonwebtoken::errors::ErrorKind::InvalidSubject
        )
    }

    #[test]
    fn test_decode_invalid_token() {
        let secret = b"my-secret-key";
        let invalid_token = "invalid-token";

        let result = decode_token(invalid_token, secret);

        assert!(result.is_err());
        assert_eq!(
            result.clone().unwrap_err().message,
            ErrorMessage::InvalidToken.to_string()
        );
        assert_eq!(result.unwrap_err().status, 401);
    }

    #[test]
    fn test_decode_expired_token() {
        let secret = b"my-secret-key";
        let expired_token = create_token("user123", secret, -60).unwrap();

        let result = decode_token(expired_token, secret);

        assert!(result.is_err());
        assert_eq!(
            result.clone().unwrap_err().message,
            ErrorMessage::InvalidToken.to_string()
        );
        assert_eq!(result.unwrap_err().status, 401);
    }
}

In the unit tests, we cover different scenarios to ensure the correctness and reliability of the JWT utility functions. Here’s what each test case accomplishes:

  1. test_create_and_decode_valid_token: This test case creates a valid JWT token using the create_token function and then decodes it using the decode_token function. It asserts that the decoded user ID matches the original user ID, ensuring that token creation and decoding work correctly.
  2. test_create_token_with_empty_user_id: This test case validates the behaviour of the create_token function when an empty user ID is provided. It checks that the function returns an error whose kind is ErrorKind::InvalidSubject.
  3. test_decode_invalid_token: This test case simulates decoding an invalid token using the decode_token function. It asserts that the function returns an error with the expected message and status code.
  4. test_decode_expired_token: This test case verifies how the decode_token function handles an expired token. Passing -60 to create_token produces a token whose exp claim lies 60 minutes in the past, well beyond jsonwebtoken’s default 60-second leeway, so decoding must fail with the expected message and status code.
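The expired-token test works because create_token treats its third argument as minutes: passing -60 sets the exp claim an hour in the past. That margin matters, because jsonwebtoken’s default Validation allows 60 seconds of leeway for clock skew, so a token that expired only moments ago would still decode. The std-only sketch below models that check under those assumptions; is_expired and exp_from_minutes are hypothetical helpers, not functions from the project or the crate.

```rust
// Assumption: mirrors the default leeway of jsonwebtoken's Validation (60 s).
const DEFAULT_LEEWAY_SECS: u64 = 60;

/// Hypothetical helper: returns true if a token whose `exp` claim is
/// `exp` (Unix seconds) should be rejected at time `now` (Unix seconds).
/// A token only counts as expired once `exp < now - leeway`.
fn is_expired(exp: u64, now: u64, leeway: u64) -> bool {
    // saturating_sub keeps the arithmetic safe for very small `now` values.
    exp < now.saturating_sub(leeway)
}

/// Hypothetical helper: computes `exp` the way create_token does,
/// from a lifetime expressed in minutes (negative = already in the past).
fn exp_from_minutes(now: u64, minutes: i64) -> u64 {
    (now as i64 + minutes * 60) as u64
}
```

Modeling the rule this way makes it clear why the negative lifetime in an expiry test has to exceed the leeway.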

Writing Unit Tests for the Password Utility Functions

Moving on, let’s write unit tests for the password utility functions in the src/utils/password.rs module. These functions hash and compare passwords using the argon2 crate, a Rust implementation of the Argon2 password-hashing algorithm.

Let’s delve into the details of the password utility functions and the corresponding unit tests:

Utility Functions (src/utils/password.rs):


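use argon2::{
    password_hash::{rand_core::OsRng, PasswordHash, PasswordHasher, PasswordVerifier, SaltString},
    Argon2,
};

use crate::error::ErrorMessage;

const MAX_PASSWORD_LENGTH: usize = 64;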

pub fn hash(password: impl Into<String>) -> Result<String, ErrorMessage> {
    let password = password.into();

    if password.is_empty() {
        return Err(ErrorMessage::EmptyPassword);
    }

    if password.len() > MAX_PASSWORD_LENGTH {
        return Err(ErrorMessage::ExceededMaxPasswordLength(MAX_PASSWORD_LENGTH));
    }

    let salt = SaltString::generate(&mut OsRng);
    let hashed_password = Argon2::default()
        .hash_password(password.as_bytes(), &salt)
        .map_err(|_| ErrorMessage::HashingError)?
        .to_string();

    Ok(hashed_password)
}

pub fn compare(password: &str, hashed_password: &str) -> Result<bool, ErrorMessage> {
    if password.is_empty() {
        return Err(ErrorMessage::EmptyPassword);
    }

    if password.len() > MAX_PASSWORD_LENGTH {
        return Err(ErrorMessage::ExceededMaxPasswordLength(MAX_PASSWORD_LENGTH));
    }

    let parsed_hash =
        PasswordHash::new(hashed_password).map_err(|_| ErrorMessage::InvalidHashFormat)?;

    let password_matches = Argon2::default()
        .verify_password(password.as_bytes(), &parsed_hash)
        .is_ok();

    Ok(password_matches)
}

Unit Tests (src/utils/password.rs):


#[cfg(test)]
mod tests {
    use super::*;
    use crate::error::ErrorMessage;

    fn setup_test() -> (String, String) {
        let password = "password123";
        let hashed_password = hash(password).unwrap();
        (password.to_string(), hashed_password)
    }

    #[test]
    fn test_compare_hashed_passwords_should_return_true() {
        let (password, hashed_password) = setup_test();

        assert!(compare(&password, &hashed_password).unwrap());
    }

    #[test]
    fn test_compare_hashed_passwords_should_return_false() {
        let (_, hashed_password) = setup_test();

        assert!(!compare("wrongpassword", &hashed_password).unwrap());
    }

    #[test]
    fn test_compare_empty_password_should_return_fail() {
        let (_, hashed_password) = setup_test();

        assert_eq!(
            compare("", &hashed_password).unwrap_err(),
            ErrorMessage::EmptyPassword
        )
    }

    #[test]
    fn test_compare_long_password_should_return_fail() {
        let (_, hashed_password) = setup_test();

        let long_password = "a".repeat(1000);
        assert_eq!(
            compare(&long_password, &hashed_password).unwrap_err(),
            ErrorMessage::ExceededMaxPasswordLength(MAX_PASSWORD_LENGTH)
        );
    }

    #[test]
    fn test_compare_invalid_hash_should_fail() {
        let invalid_hash = "invalid-hash";

        assert_eq!(
            compare("password123", invalid_hash).unwrap_err(),
            ErrorMessage::InvalidHashFormat
        )
    }

    #[test]
    fn test_hash_empty_password_should_fail() {
        let result = hash("");

        assert!(result.is_err());
        assert_eq!(result.unwrap_err(), ErrorMessage::EmptyPassword)
    }

    #[test]
    fn test_hash_long_password_should_fail() {
        let result = hash("a".repeat(1000));

        assert!(result.is_err());
        assert_eq!(
            result.unwrap_err(),
            ErrorMessage::ExceededMaxPasswordLength(MAX_PASSWORD_LENGTH)
        );
    }
}

In the unit tests, we ensure that the password utility functions work as expected in various scenarios. Here’s what each test case covers:

  1. test_compare_hashed_passwords_should_return_true: This test case verifies that comparing the correct password with its hashed counterpart using the compare function should return true.
  2. test_compare_hashed_passwords_should_return_false: This test case checks that comparing an incorrect password with the hashed password using the compare function should return false.
  3. test_compare_empty_password_should_return_fail: This test case checks the behaviour of the compare function when given an empty password, confirming that it returns ErrorMessage::EmptyPassword.
  4. test_compare_long_password_should_return_fail: This test case checks that the compare function rejects a password longer than MAX_PASSWORD_LENGTH, ensuring it returns ErrorMessage::ExceededMaxPasswordLength.
  5. test_compare_invalid_hash_should_fail: This test case validates the behaviour of the compare function when given an invalid hash, making sure it returns the correct error.
  6. test_hash_empty_password_should_fail: This test case assesses the hash function’s response when hashing an empty password, verifying that it returns the appropriate error.
  7. test_hash_long_password_should_fail: This test case examines the hash function’s behaviour when hashing a long password, ensuring it returns the expected error.
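One detail these tests don’t exercise: hash and compare enforce MAX_PASSWORD_LENGTH with password.len(), which counts UTF-8 bytes rather than characters. A password made of non-ASCII characters therefore hits the limit sooner than its visible length suggests. The std-only sketch below uses hypothetical helper names to show the difference between the two measurements:

```rust
const MAX_PASSWORD_LENGTH: usize = 64;

/// Hypothetical helper: measures length in UTF-8 bytes, like `password.len()`
/// in the project's hash and compare functions.
fn exceeds_limit_in_bytes(password: &str, max: usize) -> bool {
    password.len() > max
}

/// Hypothetical helper: measures length in Unicode scalar values instead.
fn exceeds_limit_in_chars(password: &str, max: usize) -> bool {
    password.chars().count() > max
}
```

Whether to count bytes or characters is a design choice: a byte limit bounds exactly how much input reaches the hasher, while a character limit matches what users see.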

Creating Stubs

This article isn’t primarily a theoretical overview of testing, so I’m assuming you’re already familiar with the concept of stubs. However, if you need a quick refresher, I’ve got you covered.

In software development, stubs serve as placeholders or mock implementations for specific functionalities or components. Stubs, particularly in unit testing, offer the advantage of isolating and scrutinizing specific segments of your codebase without relying on external dependencies.
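To make the idea concrete, here is a minimal std-only sketch of a stub in Rust. The UserStore trait, StubUserStore, and greeting_for are hypothetical names used only for illustration: the code under test depends on a trait, and the stub answers from an in-memory map instead of a database.

```rust
use std::collections::HashMap;

// Hypothetical trait standing in for a data-access layer.
trait UserStore {
    fn find_name_by_email(&self, email: &str) -> Option<String>;
}

// A stub: returns canned data from memory instead of querying a database.
struct StubUserStore {
    users: HashMap<String, String>, // email -> name
}

impl UserStore for StubUserStore {
    fn find_name_by_email(&self, email: &str) -> Option<String> {
        self.users.get(email).cloned()
    }
}

// Code under test depends only on the trait, so any implementation works.
fn greeting_for<S: UserStore>(store: &S, email: &str) -> String {
    match store.find_name_by_email(email) {
        Some(name) => format!("Hello, {name}!"),
        None => "User not found".to_string(),
    }
}
```

Because greeting_for is generic over the trait, a test can hand it the stub while production code passes a real database-backed implementation.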

With this understanding, let’s delve into the process of creating a couple of stubs using the code provided in src/utils/test_utils.rs. This module contains helper functions to initialize test users and configurations for unit testing purposes.

src/utils/test_utils.rs


use sqlx::{Pool, Postgres};
use uuid::Uuid;

use crate::{
    config::Config,
    db::{DBClient, UserExt},
};

#[allow(dead_code)]
pub struct TestUser {
    name: &'static str,
    email: &'static str,
    password: &'static str,
}

#[allow(dead_code)]
pub async fn init_test_users(pool: &Pool<Postgres>) -> (Uuid, Uuid, Uuid) {
    let db_client = DBClient::new(pool.clone());

    let users: Vec<TestUser> = vec![
        TestUser {
            name: "John Doe",
            email: "johndoe@gmail.com",
            password: "password1234",
        },
        TestUser {
            name: "Nico Smith",
            email: "nicosmith@gmail.com",
            password: "123justgetit",
        },
        TestUser {
            name: "Michelle Like",
            email: "michellelike@gmail.com",
            password: "mostsecurepass",
        },
    ];

    let mut user_ids = vec![];

    for user_data in users {
        let user = db_client
            .save_user(user_data.name, user_data.email, user_data.password)
            .await
            .unwrap();
        user_ids.push(user.id);
    }

    (user_ids[0], user_ids[1], user_ids[2])
}

#[allow(dead_code)]
pub fn get_test_config() -> Config {
    Config {
        database_url: "".to_string(),
        jwt_secret: "my-jwt-secret".to_string(),
        jwt_maxage: 60,
        port: 8000,
    }
}

The provided code demonstrates the creation of stubs for initializing test users and retrieving test configurations. Let’s break down what these stubs do:

  1. init_test_users: This function inserts three test users into the database and returns their IDs. Rather than mocking the database, it builds a real DBClient from the pool passed in by #[sqlx::test], which points at a temporary test database, and saves each TestUser through save_user, collecting the returned IDs. The passwords are stored as-is (unhashed), which is fine for these database tests. This gives every test the same known starting data without touching your development database.
  2. get_test_config: This function returns a test configuration. It returns a Config struct with dummy values. Stubs like these are beneficial when you want to provide a predefined configuration for testing without using actual configuration files or values.

Writing Unit Tests for the Database Access Layer

At this point, we are ready to create functions or services that will interface directly with the database. By extracting the database access code into distinct modules, we can maintain clean route controllers while ensuring they remain insulated from direct database communication.

Our next step involves not only creating these functions but also writing unit tests for the database access layer to ensure its accuracy and reliability.

Let’s dive into the details of the provided code in src/db.rs, which includes the structure and traits used to interact with the database:


#[derive(Debug, Clone)]
pub struct DBClient {
    pool: Pool<Postgres>,
}

impl DBClient {
    pub fn new(pool: Pool<Postgres>) -> Self {
        DBClient { pool }
    }
}

#[async_trait]
pub trait UserExt {
    async fn get_user(
        &self,
        user_id: Option<Uuid>,
        name: Option<&str>,
        email: Option<&str>,
    ) -> Result<Option<User>, sqlx::Error>;
    async fn get_users(&self, page: u32, limit: usize) -> Result<Vec<User>, sqlx::Error>;
    async fn save_user<T: Into<String> + Send>(
        &self,
        name: T,
        email: T,
        password: T,
    ) -> Result<User, sqlx::Error>;
    async fn save_admin_user<T: Into<String> + Send>(
        &self,
        name: T,
        email: T,
        password: T,
    ) -> Result<User, sqlx::Error>;
}

#[async_trait]
impl UserExt for DBClient {
    async fn get_user(
        &self,
        user_id: Option<uuid::Uuid>,
        name: Option<&str>,
        email: Option<&str>,
    ) -> Result<Option<User>, sqlx::Error> {
        let mut user: Option<User> = None;

        if let Some(user_id) = user_id {
            user = sqlx::query_as!(User, r#"SELECT id,name, email, password, photo,verified,created_at,updated_at,role as "role: UserRole" FROM users WHERE id = $1"#, user_id)
                .fetch_optional(&self.pool)
                .await?;
        } else if let Some(name) = name {
            user = sqlx::query_as!(User, r#"SELECT id,name, email, password, photo,verified,created_at,updated_at,role as "role: UserRole" FROM users WHERE name = $1"#, name)
                .fetch_optional(&self.pool)
                .await?;
        } else if let Some(email) = email {
            user = sqlx::query_as!(User, r#"SELECT id,name, email, password, photo,verified,created_at,updated_at,role as "role: UserRole" FROM users WHERE email = $1"#, email)
                .fetch_optional(&self.pool)
                .await?;
        }

        Ok(user)
    }

    async fn get_users(&self, page: u32, limit: usize) -> Result<Vec<User>, sqlx::Error> {
        // saturating_sub keeps page 0 from underflowing (which would panic in debug builds).
        let offset = page.saturating_sub(1) * limit as u32;

        let users = sqlx::query_as!(
            User,
            r#"SELECT id,name, email, password, photo,verified,created_at,updated_at,role as "role: UserRole" FROM users
            LIMIT $1 OFFSET $2"#,
            limit as i64,
            offset as i64
        )
        .fetch_all(&self.pool)
        .await?;

        Ok(users)
    }

    async fn save_user<T: Into<String> + Send>(
        &self,
        name: T,
        email: T,
        password: T,
    ) -> Result<User, sqlx::Error> {
        let user = sqlx::query_as!(
            User,
            r#"INSERT INTO users (name, email, password) VALUES ($1, $2, $3) RETURNING id,name, email, password, photo,verified,created_at,updated_at,role as "role: UserRole""#,
            name.into(),
            email.into(),
            password.into()
        )
        .fetch_one(&self.pool)
        .await?;

        Ok(user)
    }

    async fn save_admin_user<T: Into<String> + Send>(
        &self,
        name: T,
        email: T,
        password: T,
    ) -> Result<User, sqlx::Error> {
        let user = sqlx::query_as!(
            User,
            r#"INSERT INTO users (name, email, password, role) VALUES ($1, $2, $3, $4) RETURNING id,name, email, password, photo,verified,created_at,updated_at,role as "role: UserRole""#,
            name.into(),
            email.into(),
            password.into(),
            UserRole::Admin as UserRole
        )
        .fetch_one(&self.pool)
        .await?;

        Ok(user)
    }
}
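get_users turns a 1-based page number into an OFFSET of (page - 1) * limit. With plain u32 subtraction, a page of 0 would underflow and panic in debug builds. Pulling the calculation into a small pure function makes edge cases like that easy to unit test without a database. A sketch, where page_offset is a hypothetical helper rather than part of the project:

```rust
/// Hypothetical helper: converts a 1-based page number and a page size
/// into a SQL OFFSET. saturating_sub treats page 0 like page 1
/// instead of underflowing.
fn page_offset(page: u32, limit: usize) -> i64 {
    i64::from(page.saturating_sub(1)) * limit as i64
}
```

Computing the offset as i64 also matches the type passed to the LIMIT and OFFSET bind parameters in the query.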

Accompanying these functionalities are the unit tests that focus on the database access layer, residing within the tests module in the same file:


#[cfg(test)]
mod tests {
    use super::*;
    use crate::utils::test_utils::init_test_users;

    #[sqlx::test]
    async fn test_get_user_by_id(pool: Pool<Postgres>) {
        let (id_one, _, _) = init_test_users(&pool).await;
        let db_client = DBClient::new(pool);

        let user = db_client
            .get_user(Some(id_one), None, None)
            .await
            .unwrap_or_else(|err| panic!("Failed to get user by id: {}", err))
            .expect("User not found");

        assert_eq!(user.id, id_one);
    }

    #[sqlx::test]
    async fn test_get_user_by_name(pool: Pool<Postgres>) {
        init_test_users(&pool).await;
        let db_client = DBClient::new(pool);

        let name_to_find = "Nico Smith";

        let user = db_client
            .get_user(None, Some(name_to_find), None)
            .await
            .unwrap_or_else(|err| panic!("Failed to get user by name: {}", err))
            .expect("User not found");

        assert_eq!(user.name, name_to_find);
    }

    #[sqlx::test]
    async fn test_get_user_by_nonexistent_name(pool: Pool<Postgres>) {
        init_test_users(&pool).await;
        let db_client = DBClient::new(pool);

        let name = "Nonexistent Name";

        let user = db_client
            .get_user(None, Some(name), None)
            .await
            .expect("Failed to get user by name");

        assert!(user.is_none(), "Expected user to be None");
    }

    #[sqlx::test]
    async fn test_get_user_by_email(pool: Pool<Postgres>) {
        init_test_users(&pool).await;
        let db_client = DBClient::new(pool);

        let email = "johndoe@gmail.com";

        let user = db_client
            .get_user(None, None, Some(email))
            .await
            .expect("Failed to get user by email")
            .expect("User not found");

        assert_eq!(user.email, email);
    }

    #[sqlx::test]
    async fn test_get_user_by_nonexistent_email(pool: Pool<Postgres>) {
        init_test_users(&pool).await;
        let db_client = DBClient::new(pool);

        let email = "nonexistent@example.com";

        let user = db_client.get_user(None, None, Some(email)).await.unwrap();

        assert!(user.is_none());
    }

    #[sqlx::test]
    async fn test_get_users(pool: Pool<Postgres>) {
        init_test_users(&pool).await;
        let db_client = DBClient::new(pool);

        let users = db_client.get_users(1, 10).await.unwrap();

        assert_eq!(users.len(), 3);
    }

    #[sqlx::test]
    async fn test_save_user(pool: Pool<Postgres>) {
        init_test_users(&pool).await;
        let db_client = DBClient::new(pool);
        let name = "Peace Jocy";
        let email = "peacejocy@hotmail.com";
        let password = "newPassword";

        db_client.save_user(name, email, password).await.unwrap();

        let user = db_client
            .get_user(None, Some(name), None)
            .await
            .unwrap()
            .unwrap();

        assert_eq!(user.email, email);
        assert_eq!(user.name, name);
    }

    #[sqlx::test]
    async fn test_save_user_but_email_is_taken(pool: Pool<Postgres>) {
        init_test_users(&pool).await;
        let db_client = DBClient::new(pool);

        let name = "John Doe";
        let email = "johndoe@gmail.com";
        let password = "randompass123";

        let saved_result = db_client.save_user(name, email, password).await;

        match saved_result {
            Err(sqlx::Error::Database(db_err)) if db_err.is_unique_violation() => {
                // Unique constraint violation detected, test passes
            }
            _ => panic!("Expected unique constraint violation error"),
        }
    }

    #[sqlx::test]
    async fn test_save_user_with_long_name_fails(pool: Pool<Postgres>) {
        init_test_users(&pool).await;
        let db_client = DBClient::new(pool);

        let long_name = "a".repeat(150);
        let email = "email@example.com";
        let password = "newPassword";

        let saved_result = db_client
            .save_user(long_name.as_str(), email, password)
            .await;

        assert!(saved_result.is_err(), "Expected save to fail");
    }
}

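In these tests, the #[sqlx::test] attribute does the heavy lifting: for each test it creates a fresh, isolated test database using the DATABASE_URL connection, applies the project’s migrations, and passes a Pool<Postgres> connected to that database into the test function. Because every test starts from a clean database, the tests can’t interfere with one another.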

These tests cover scenarios such as retrieving users by different criteria (user ID, name, email), saving new users, and handling edge cases like duplicate email addresses and exceeding maximum name lengths. The init_test_users function from the test_utils module helps set up the initial test users for these unit tests.
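The duplicate-email test uses a match with a guard to check for a unique-constraint violation. If you prefer a single assertion, the matches! macro expresses the same check. The std-only sketch below substitutes a hypothetical DbError type for sqlx::Error so it runs without a database, and assumes PostgreSQL’s SQLSTATE 23505 (unique_violation) as the error code.

```rust
// Hypothetical error type standing in for sqlx::Error in this sketch.
#[derive(Debug)]
enum DbError {
    Database { code: &'static str },
    RowNotFound,
}

impl DbError {
    // PostgreSQL reports unique-constraint violations with SQLSTATE 23505.
    fn is_unique_violation(&self) -> bool {
        matches!(self, DbError::Database { code } if *code == "23505")
    }
}

// Same shape as the match in test_save_user_but_email_is_taken,
// collapsed into a single boolean with matches! and a guard.
fn is_unique_violation_result<T>(result: &Result<T, DbError>) -> bool {
    matches!(result, Err(err) if err.is_unique_violation())
}
```

With SQLx itself, the equivalent assertion would read roughly assert!(matches!(&saved_result, Err(sqlx::Error::Database(db_err)) if db_err.is_unique_violation()));.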

Writing Unit Tests for the Middleware Guard

Security is paramount in any API, and ours is no exception. To protect sensitive data and ensure that only authorized users reach protected resources, our application combines Actix-Web extractors and middleware with JSON Web Tokens (JWTs) to implement user authentication and authorization.

The auth.rs module contains the middleware guard implementation responsible for handling user authentication and authorization. It includes structs like Authenticated and RequireAuth, along with their associated implementations. The middleware is responsible for enforcing authentication and checking roles before granting access to protected endpoints.
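The middleware below looks for a token cookie first and falls back to the Authorization header. The std-only sketch that follows models that lookup with plain string inputs so it can be unit tested in isolation; extract_token is a hypothetical helper, not part of the project. Using strip_prefix("Bearer ") means a short or non-Bearer header yields None instead of panicking.

```rust
/// Hypothetical helper: take the token from the cookie value if present,
/// otherwise from an `Authorization: Bearer <token>` header value.
/// Malformed, short, or non-Bearer headers return None rather than panicking.
fn extract_token(cookie: Option<&str>, authorization: Option<&str>) -> Option<String> {
    cookie.map(str::to_string).or_else(|| {
        authorization
            .and_then(|value| value.strip_prefix("Bearer "))
            .filter(|token| !token.is_empty())
            .map(str::to_string)
    })
}
```

Keeping the parsing in a pure function like this lets you cover edge cases such as missing or malformed headers without building a full request.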

src/extractors/auth.rs


pub struct Authenticated(User);

impl FromRequest for Authenticated {
    type Error = actix_web::Error;
    type Future = Ready<Result<Self, Self::Error>>;

    fn from_request(
        req: &actix_web::HttpRequest,
        _payload: &mut actix_web::dev::Payload,
    ) -> Self::Future {
        let value = req.extensions().get::<User>().cloned();
        let result = match value {
            Some(user) => Ok(Authenticated(user)),
            None => Err(ErrorInternalServerError(HttpError::server_error(
                "Authentication Error",
            ))),
        };
        ready(result)
    }
}

impl std::ops::Deref for Authenticated {
    type Target = User;

    fn deref(&self) -> &Self::Target {
        &self.0
    }
}

pub struct RequireAuth {
    pub allowed_roles: Rc<Vec<UserRole>>,
}

impl RequireAuth {
    pub fn allowed_roles(allowed_roles: Vec<UserRole>) -> Self {
        RequireAuth {
            allowed_roles: Rc::new(allowed_roles),
        }
    }
}

impl<S> Transform<S, ServiceRequest> for RequireAuth
where
    S: Service<
            ServiceRequest,
            Response = ServiceResponse<actix_web::body::BoxBody>,
            Error = actix_web::Error,
        > + 'static,
{
    type Response = ServiceResponse<actix_web::body::BoxBody>;
    type Error = actix_web::Error;
    type Transform = AuthMiddleware<S>;
    type InitError = ();
    type Future = Ready<Result<Self::Transform, Self::InitError>>;

    fn new_transform(&self, service: S) -> Self::Future {
        ready(Ok(AuthMiddleware {
            service: Rc::new(service),
            allowed_roles: self.allowed_roles.clone(),
        }))
    }
}

pub struct AuthMiddleware<S> {
    service: Rc<S>,
    allowed_roles: Rc<Vec<UserRole>>,
}

impl<S> Service<ServiceRequest> for AuthMiddleware<S>
where
    S: Service<
            ServiceRequest,
            Response = ServiceResponse<actix_web::body::BoxBody>,
            Error = actix_web::Error,
        > + 'static,
{
    type Response = ServiceResponse<actix_web::body::BoxBody>;
    type Error = actix_web::Error;
    type Future = LocalBoxFuture<'static, Result<Self::Response, actix_web::Error>>;

    fn poll_ready(&self, ctx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        self.service.poll_ready(ctx)
    }

    fn call(&self, req: ServiceRequest) -> Self::Future {
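        // Prefer the `token` cookie; fall back to an `Authorization: Bearer <token>` header.
        // `to_str().ok()` and `strip_prefix` avoid panicking on non-ASCII or short headers.
        let token = req
            .cookie("token")
            .map(|c| c.value().to_string())
            .or_else(|| {
                req.headers()
                    .get(http::header::AUTHORIZATION)
                    .and_then(|h| h.to_str().ok())
                    .and_then(|h| h.strip_prefix("Bearer "))
                    .map(|t| t.to_string())
            });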

        if token.is_none() {
            let json_error = ErrorResponse {
                status: "fail".to_string(),
                message: ErrorMessage::TokenNotProvided.to_string(),
            };
            return Box::pin(ready(Err(ErrorUnauthorized(json_error))));
        }

        let app_state = req.app_data::<web::Data<AppState>>().unwrap();
        let user_id = match utils::token::decode_token(
            &token.unwrap(),
            app_state.env.jwt_secret.as_bytes(),
        ) {
            Ok(id) => id,
            Err(e) => {
                return Box::pin(ready(Err(ErrorUnauthorized(ErrorResponse {
                    status: "fail".to_string(),
                    message: e.message,
                }))))
            }
        };

        let cloned_app_state = app_state.clone();
        let allowed_roles = self.allowed_roles.clone();
        let srv = Rc::clone(&self.service);

        async move {
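            // Reject tokens whose subject is not a valid UUID instead of panicking.
            let user_id = uuid::Uuid::parse_str(user_id.as_str()).map_err(|_| {
                ErrorUnauthorized(ErrorResponse {
                    status: "fail".to_string(),
                    message: ErrorMessage::InvalidToken.to_string(),
                })
            })?;
            let result = cloned_app_state
                .db_client
                .get_user(Some(user_id), None, None)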
                .await
                .map_err(|e| ErrorInternalServerError(HttpError::server_error(e.to_string())))?;

            let user = result.ok_or(ErrorUnauthorized(ErrorResponse {
                status: "fail".to_string(),
                message: ErrorMessage::UserNoLongerExist.to_string(),
            }))?;

            // Check if user's role matches the required role
            if allowed_roles.contains(&user.role) {
                req.extensions_mut().insert::<User>(user);
                let res = srv.call(req).await?;
                Ok(res)
            } else {
                let json_error = ErrorResponse {
                    status: "fail".to_string(),
                    message: ErrorMessage::PermissionDenied.to_string(),
                };
                Err(ErrorForbidden(json_error))
            }
        }
        .boxed_local()
    }
}
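The token lookup at the top of `call` is the part of the middleware most exposed to unusual client input, so it can help to reason about it in isolation. The sketch below is a framework-free model of that rule, assuming the same precedence as `call` (the `token` cookie first, then an `Authorization: Bearer <token>` header). `extract_token` is a hypothetical helper, not part of the project, and it treats a header without the `Bearer ` prefix as no token at all.

```rust
/// Hypothetical, framework-free model of the token lookup in
/// `AuthMiddleware::call`: the value of the `token` cookie wins; otherwise the
/// token is taken from an `Authorization: Bearer <token>` header.
pub fn extract_token(cookie: Option<&str>, authorization: Option<&str>) -> Option<String> {
    cookie.map(str::to_string).or_else(|| {
        authorization
            // Headers without the "Bearer " scheme prefix are treated as no token.
            .and_then(|value| value.strip_prefix("Bearer "))
            .map(str::to_string)
    })
}

#[cfg(test)]
mod tests {
    use super::extract_token;

    #[test]
    fn cookie_takes_precedence_over_header() {
        let token = extract_token(Some("from-cookie"), Some("Bearer from-header"));
        assert_eq!(token.as_deref(), Some("from-cookie"));
    }

    #[test]
    fn short_or_non_bearer_headers_are_ignored() {
        assert_eq!(extract_token(None, Some("abc")), None);
        assert_eq!(extract_token(None, Some("Basic dXNlcjpwYXNz")), None);
    }
}
```

Because the helper takes plain `Option<&str>` values, these cases run with `cargo test` in milliseconds and need neither a database nor an Actix test service.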

Now, let’s take a look at how we can create unit tests to ensure that the authentication and authorization logic of our API works as expected.


#[cfg(test)]
mod tests {
    use actix_web::{cookie::Cookie, get, test, App, HttpResponse};
    use sqlx::{Pool, Postgres};

    use crate::{
        db::DBClient,
        extractors::auth::RequireAuth,
        utils::{password, test_utils::get_test_config, token},
    };

    use super::*;

    #[get(
        "/",
        wrap = "RequireAuth::allowed_roles(vec![UserRole::User, UserRole::Moderator, UserRole::Admin])"
    )]
    async fn handler_with_requireauth() -> HttpResponse {
        HttpResponse::Ok().into()
    }

    #[get("/", wrap = "RequireAuth::allowed_roles(vec![UserRole::Admin])")]
    async fn handler_with_requireonlyadmin() -> HttpResponse {
        HttpResponse::Ok().into()
    }

    #[sqlx::test]
    async fn test_auth_middleware_valid_token(pool: Pool<Postgres>) {
        let db_client = DBClient::new(pool);
        let config = get_test_config();

        let hashed_password = password::hash("password123").unwrap();

        let user = db_client
            .save_user("John", "john@example.com", &hashed_password)
            .await
            .unwrap();

        let token =
            token::create_token(&user.id.to_string(), config.jwt_secret.as_bytes(), 60).unwrap();

        let app = test::init_service(
            App::new()
                .app_data(web::Data::new(AppState {
                    env: config.clone(),
                    db_client,
                }))
                .service(handler_with_requireauth),
        )
        .await;

        let req = test::TestRequest::default()
            .insert_header((http::header::AUTHORIZATION, format!("Bearer {}", token)))
            .to_request();

        let resp = test::call_service(&app, req).await;

        assert_eq!(resp.status(), http::StatusCode::OK);
    }

    #[sqlx::test]
    async fn test_auth_middleware_valid_token_with_cookie(pool: Pool<Postgres>) {
        let db_client = DBClient::new(pool);
        let config = get_test_config();

        let hashed_password = password::hash("password123").unwrap();

        let user = db_client
            .save_user("John", "john@example.com", &hashed_password)
            .await
            .unwrap();

        let app = test::init_service(
            App::new()
                .app_data(web::Data::new(AppState {
                    env: config.clone(),
                    db_client,
                }))
                .service(handler_with_requireauth),
        )
        .await;

        let token =
            token::create_token(&user.id.to_string(), config.jwt_secret.as_bytes(), 60).unwrap();

        let req = test::TestRequest::default()
            .cookie(Cookie::new("token", token))
            .to_request();

        let resp = test::call_service(&app, req).await;

        assert_eq!(resp.status(), http::StatusCode::OK);
    }

    #[sqlx::test]
    async fn test_auth_middleware_missing_token(pool: Pool<Postgres>) {
        let db_client = DBClient::new(pool);
        let config = get_test_config();

        let app = test::init_service(
            App::new()
                .app_data(web::Data::new(AppState {
                    env: config.clone(),
                    db_client,
                }))
                .service(handler_with_requireauth),
        )
        .await;

        let req = test::TestRequest::default().to_request();
        let result = test::try_call_service(&app, req).await.err();

        match result {
            Some(err) => {
                let expected_status = http::StatusCode::UNAUTHORIZED;
                let actual_status = err.as_response_error().status_code();

                assert_eq!(actual_status, expected_status);

                let err_response: ErrorResponse = serde_json::from_str(&err.to_string())
                    .expect("Failed to deserialize JSON string");
                let expected_message = ErrorMessage::TokenNotProvided.to_string();
                assert_eq!(err_response.message, expected_message);
            }
            None => {
                panic!("Service call succeeded, but an error was expected.");
            }
        }
    }

    #[sqlx::test]
    async fn test_auth_middleware_invalid_token(pool: Pool<Postgres>) {
        let db_client = DBClient::new(pool);
        let config = get_test_config();

        let app = test::init_service(
            App::new()
                .app_data(web::Data::new(AppState {
                    env: config.clone(),
                    db_client,
                }))
                .service(handler_with_requireauth),
        )
        .await;

        let req = test::TestRequest::default()
            .insert_header((
                http::header::AUTHORIZATION,
                format!("Bearer {}", "invalid_token"),
            ))
            .to_request();

        let result = test::try_call_service(&app, req).await.err();

        match result {
            Some(err) => {
                let expected_status = http::StatusCode::UNAUTHORIZED;
                let actual_status = err.as_response_error().status_code();

                assert_eq!(actual_status, expected_status);

                let err_response: ErrorResponse = serde_json::from_str(&err.to_string())
                    .expect("Failed to deserialize JSON string");
                let expected_message = ErrorMessage::InvalidToken.to_string();
                assert_eq!(err_response.message, expected_message);
            }
            None => {
                panic!("Service call succeeded, but an error was expected.");
            }
        }
    }

    #[sqlx::test]
    async fn test_auth_middleware_access_admin_only_endpoint_fail(pool: Pool<Postgres>) {
        let db_client = DBClient::new(pool);
        let config = get_test_config();

        let hashed_password = password::hash("password123").unwrap();

        let user = db_client
            .save_user("John", "john@example.com", &hashed_password)
            .await
            .unwrap();

        let token =
            token::create_token(&user.id.to_string(), config.jwt_secret.as_bytes(), 60).unwrap();

        let app = test::init_service(
            App::new()
                .app_data(web::Data::new(AppState {
                    env: config.clone(),
                    db_client,
                }))
                .service(handler_with_requireonlyadmin),
        )
        .await;

        let req = test::TestRequest::default()
            .insert_header((http::header::AUTHORIZATION, format!("Bearer {}", token)))
            .to_request();

        let result = test::try_call_service(&app, req).await.err();

        match result {
            Some(err) => {
                let expected_status = http::StatusCode::FORBIDDEN;
                let actual_status = err.as_response_error().status_code();

                assert_eq!(actual_status, expected_status);

                let err_response: ErrorResponse = serde_json::from_str(&err.to_string())
                    .expect("Failed to deserialize JSON string");
                let expected_message = ErrorMessage::PermissionDenied.to_string();
                assert_eq!(err_response.message, expected_message);
            }
            None => {
                panic!("Service call succeeded, but an error was expected.");
            }
        }
    }

    #[sqlx::test]
    async fn test_auth_middleware_access_admin_only_endpoint_success(pool: Pool<Postgres>) {
        let db_client = DBClient::new(pool.clone());
        let config = get_test_config();

        let hashed_password = password::hash("password123").unwrap();
        let user = db_client
            .save_admin_user("John Doe", "johndoe@gmail.com", &hashed_password)
            .await
            .unwrap();

        let token =
            token::create_token(&user.id.to_string(), config.jwt_secret.as_bytes(), 60).unwrap();

        let app = test::init_service(
            App::new()
                .app_data(web::Data::new(AppState {
                    env: config.clone(),
                    db_client,
                }))
                .service(handler_with_requireonlyadmin),
        )
        .await;

        let req = test::TestRequest::default()
            .insert_header((http::header::AUTHORIZATION, format!("Bearer {}", token)))
            .to_request();

        let resp = test::call_service(&app, req).await;

        assert_eq!(resp.status(), http::StatusCode::OK);
    }
}

These unit tests cover the following authentication and authorization scenarios:

  1. Valid Token Provided, Access Allowed: A valid JWT for an existing user is sent as a Bearer token in the Authorization header to a route that allows the User, Moderator and Admin roles. The request should succeed with 200 OK.
  2. Valid Token Provided via Cookie, Access Allowed: The same scenario, except that the token is sent in a cookie named token instead of the Authorization header.
  3. Token Missing, Access Denied: No token is sent at all, so the middleware should reject the request with 401 Unauthorized and the TokenNotProvided message.
  4. Invalid Token Provided, Access Denied: A malformed token is sent, which should fail verification and produce 401 Unauthorized with the InvalidToken message.
  5. Access to Admin-Only Endpoint Denied: A regular user with a valid token calls an admin-only route and should receive 403 Forbidden with the PermissionDenied message.
  6. Access to Admin-Only Endpoint Allowed: Conversely, a user with the Admin role calling the same admin-only route should receive 200 OK.
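The six scenarios above boil down to a small decision table: no verified user yields 401, a verified user whose role is not in the route’s allowed list yields 403, and anything else reaches the handler. The std-only sketch below models that table with a stand-in `Role` enum and a hypothetical `access_status` function (neither is part of the project), which makes it easy to add further role combinations as table-driven test cases.

```rust
/// Stand-in for the project's `UserRole` enum (variants as used in the tests above).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Role {
    User,
    Moderator,
    Admin,
}

pub const ALL_ROLES: &[Role] = &[Role::User, Role::Moderator, Role::Admin];
pub const ADMIN_ONLY: &[Role] = &[Role::Admin];

/// Hypothetical model of the middleware's decision, returned as an HTTP status.
/// `None` means no verified user: the token was missing, invalid or expired,
/// or the account no longer exists.
pub fn access_status(allowed_roles: &[Role], user_role: Option<Role>) -> u16 {
    match user_role {
        // 401 Unauthorized: nobody is authenticated.
        None => 401,
        // 403 Forbidden: authenticated, but this route does not allow the role.
        Some(role) if !allowed_roles.contains(&role) => 403,
        // 200 OK: the wrapped handler runs.
        Some(_) => 200,
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn access_matrix() {
        // (allowed roles, caller's role, expected status)
        let cases = [
            (ALL_ROLES, Some(Role::User), 200),
            (ALL_ROLES, Some(Role::Admin), 200),
            (ALL_ROLES, None, 401),
            (ADMIN_ONLY, Some(Role::User), 403),
            (ADMIN_ONLY, Some(Role::Moderator), 403),
            (ADMIN_ONLY, Some(Role::Admin), 200),
        ];
        for (allowed, role, expected) in cases {
            assert_eq!(access_status(allowed, role), expected, "allowed={allowed:?}, role={role:?}");
        }
    }
}
```

A failing row reports its inputs through the assertion message, so adding a new role combination costs one line instead of a whole `#[sqlx::test]` function.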

Writing Unit Tests for the Authentication Handlers

Next, let’s write unit tests for the authentication handlers (register, login and logout) defined in the src/scopes/auth.rs module. These tests verify that each handler returns the expected status code, response body and cookies for both valid and invalid input. Before we write them, here is the code for the src/scopes/auth.rs module:

src/scopes/auth.rs


pub async fn register(
    app_state: web::Data<AppState>,
    body: web::Json<RegisterUserDto>,
) -> Result<HttpResponse, HttpError> {
    body.validate()
        .map_err(|e| HttpError::bad_request(e.to_string()))?;

    let hashed_password =
        password::hash(&body.password).map_err(|e| HttpError::server_error(e.to_string()))?;

    let result = app_state
        .db_client
        .save_user(&body.name, &body.email, &hashed_password)
        .await;

    match result {
        Ok(user) => Ok(HttpResponse::Created().json(UserResponseDto {
            status: "success".to_string(),
            data: UserData {
                user: FilterUserDto::filter_user(&user),
            },
        })),
        Err(sqlx::Error::Database(db_err)) => {
            if db_err.is_unique_violation() {
                Err(HttpError::unique_constraint_voilation(
                    ErrorMessage::EmailExist,
                ))
            } else {
                Err(HttpError::server_error(db_err.to_string()))
            }
        }
        Err(e) => Err(HttpError::server_error(e.to_string())),
    }
}

pub async fn login(
    app_state: web::Data<AppState>,
    body: web::Json<LoginUserDto>,
) -> Result<HttpResponse, HttpError> {
    body.validate()
        .map_err(|e| HttpError::bad_request(e.to_string()))?;

    let result = app_state
        .db_client
        .get_user(None, None, Some(&body.email))
        .await
        .map_err(|e| HttpError::server_error(e.to_string()))?;

    let user = result.ok_or(HttpError::unauthorized(ErrorMessage::WrongCredentials))?;

    let password_matches = password::compare(&body.password, &user.password)
        .map_err(|_| HttpError::unauthorized(ErrorMessage::WrongCredentials))?;

    if password_matches {
        let token = token::create_token(
            &user.id.to_string(),
            app_state.env.jwt_secret.as_bytes(),
            app_state.env.jwt_maxage,
        )
        .map_err(|e| HttpError::server_error(e.to_string()))?;
        let cookie = Cookie::build("token", token.to_owned())
            .path("/")
            // jwt_maxage is in minutes; the cookie's Max-Age is in seconds.
            .max_age(ActixWebDuration::new(60 * app_state.env.jwt_maxage, 0))
            .http_only(true)
            .finish();

        Ok(HttpResponse::Ok()
            .cookie(cookie)
            .json(UserLoginResponseDto {
                status: "success".to_string(),
                token,
            }))
    } else {
        Err(HttpError::unauthorized(ErrorMessage::WrongCredentials))
    }
}

pub async fn logout() -> impl Responder {
    let cookie = Cookie::build("token", "")
        .path("/")
        .max_age(ActixWebDuration::new(-1, 0))
        .http_only(true)
        .finish();

    HttpResponse::Ok()
        .cookie(cookie)
        .json(json!({"status": "success"}))
}
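The login handler deliberately returns the same WrongCredentials error whether the email is unknown or the password is wrong, and on success it sets a token cookie whose Max-Age is 60 * jwt_maxage seconds. The std-only sketch below models just those two rules. `StoredUser`, `LoginOutcome` and `login_outcome` are hypothetical names, and the plain string comparison stands in for the project’s `password::compare` hash check purely for illustration.

```rust
/// Hypothetical stand-in for a user row. The real handler compares the
/// submitted password against a stored hash with `password::compare`; plain
/// string equality is used here only to keep the sketch dependency-free.
pub struct StoredUser {
    pub email: String,
    pub password: String,
}

/// Hypothetical summary of what `login` returns.
#[derive(Debug, PartialEq)]
pub enum LoginOutcome {
    /// 200 OK with a `token` cookie; Max-Age is expressed in seconds.
    Success { cookie_max_age_secs: i64 },
    /// 401 Unauthorized with the single, generic WrongCredentials message.
    WrongCredentials,
}

pub fn login_outcome(
    users: &[StoredUser],
    email: &str,
    password: &str,
    jwt_maxage_minutes: i64,
) -> LoginOutcome {
    match users.iter().find(|user| user.email == email) {
        // Known email and matching password: issue the cookie.
        Some(user) if user.password == password => LoginOutcome::Success {
            // Mirrors `60 * jwt_maxage` in the handler (minutes to seconds).
            cookie_max_age_secs: 60 * jwt_maxage_minutes,
        },
        // Unknown email *or* wrong password: identical response.
        _ => LoginOutcome::WrongCredentials,
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn unknown_email_and_wrong_password_look_the_same() {
        let users = vec![StoredUser {
            email: "john@example.com".to_string(),
            password: "password123".to_string(),
        }];
        assert_eq!(
            login_outcome(&users, "wrongemail@example.com", "password123", 60),
            login_outcome(&users, "john@example.com", "wrongpassword", 60)
        );
    }
}
```

Keeping both failure paths identical means a client cannot tell from the response body alone whether an email address is registered, which is also why the wrong-email and wrong-password tests below assert the same message.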

The unit tests for these handlers live in a tests module at the bottom of the same file:

src/scopes/auth.rs


#[cfg(test)]
mod tests {
    use actix_web::{http, test, App};
    use sqlx::{Pool, Postgres};

    use crate::{db::DBClient, error::ErrorResponse, utils::test_utils::get_test_config};

    use super::*;

    #[sqlx::test]
    async fn test_register_valid_user(pool: Pool<Postgres>) {
        let db_client = DBClient::new(pool.clone());
        let config = get_test_config();

        let app = test::init_service(
            App::new()
                .app_data(web::Data::new(AppState {
                    env: config.clone(),
                    db_client,
                }))
                .service(web::scope("/api/auth").route("/register", web::post().to(register))),
        )
        .await;

        let name = "John Doe".to_string();
        let email = "john@example.com".to_string();
        let password = "password123".to_string();
        let password_confirm = "password123".to_string();
        let req = test::TestRequest::post()
            .uri("/api/auth/register")
            .set_json(RegisterUserDto {
                name: name.clone(),
                email: email.clone(),
                password: password.clone(),
                password_confirm: password_confirm.clone(),
            })
            .to_request();

        let resp = test::call_service(&app, req).await;

        assert_eq!(resp.status(), http::StatusCode::CREATED);

        let body = test::read_body(resp).await;
        let user_response: UserResponseDto =
            serde_json::from_slice(&body).expect("Failed to deserialize user response from JSON");
        let user = &user_response.data.user;

        assert_eq!(user.name, name);
        assert_eq!(user.email, email);
    }

    #[sqlx::test]
    async fn test_register_duplicate_email(pool: Pool<Postgres>) {
        let db_client = DBClient::new(pool.clone());
        let config = get_test_config();

        db_client
            .save_user("John", "john@example.com", "password123")
            .await
            .unwrap();

        let app = test::init_service(
            App::new()
                .app_data(web::Data::new(AppState {
                    env: config.clone(),
                    db_client,
                }))
                .service(web::scope("/api/auth").route("/register", web::post().to(register))),
        )
        .await;

        let req = test::TestRequest::post()
            .uri("/api/auth/register")
            .set_json(RegisterUserDto {
                name: "John Doe".to_string(),
                email: "john@example.com".to_string(),
                password: "password123".to_string(),
                password_confirm: "password123".to_string(),
            })
            .to_request();

        let resp = test::call_service(&app, req).await;

        assert_eq!(resp.status(), http::StatusCode::CONFLICT);

        let body = test::read_body(resp).await;
        let expected_message = ErrorMessage::EmailExist.to_string();

        let body_json: serde_json::Value = serde_json::from_slice(&body).unwrap();
        let actual_message = body_json["message"].as_str().unwrap();

        assert_eq!(actual_message, expected_message);
    }

    #[sqlx::test]
    async fn test_login_valid_credentials(pool: Pool<Postgres>) {
        let db_client = DBClient::new(pool.clone());
        let config = get_test_config();

        let hashed_password = password::hash("password123").unwrap();

        db_client
            .save_user("John", "john@example.com", &hashed_password)
            .await
            .unwrap();

        let app = test::init_service(
            App::new()
                .app_data(web::Data::new(AppState {
                    env: config.clone(),
                    db_client,
                }))
                .service(web::scope("/api/auth").route("/login", web::post().to(login))),
        )
        .await;

        let req = test::TestRequest::post()
            .uri("/api/auth/login")
            .set_json(LoginUserDto {
                email: "john@example.com".to_string(),
                password: "password123".to_string(),
            })
            .to_request();

        let resp = test::call_service(&app, req).await;

        assert_eq!(resp.status(), http::StatusCode::OK);

        let body = test::read_body(resp).await;

        let body_json: UserLoginResponseDto = serde_json::from_slice(&body).unwrap();

        assert!(!body_json.token.is_empty());
    }

    #[sqlx::test]
    async fn test_login_valid_credentials_receive_cookie(pool: Pool<Postgres>) {
        let db_client = DBClient::new(pool.clone());
        let config = get_test_config();

        let hashed_password = password::hash("password123").unwrap();

        db_client
            .save_user("John", "john@example.com", &hashed_password)
            .await
            .unwrap();

        let app = test::init_service(
            App::new()
                .app_data(web::Data::new(AppState {
                    env: config.clone(),
                    db_client,
                }))
                .service(web::scope("/api/auth").route("/login", web::post().to(login))),
        )
        .await;

        let req = test::TestRequest::post()
            .uri("/api/auth/login")
            .set_json(LoginUserDto {
                email: "john@example.com".to_string(),
                password: "password123".to_string(),
            })
            .to_request();

        let resp = test::call_service(&app, req).await;

        let token_cookie = resp
            .response()
            .cookies()
            .find(|cookie| cookie.name() == "token");

        assert!(token_cookie.is_some());
    }

    #[sqlx::test]
    async fn test_login_with_nonexistent_user_credentials(pool: Pool<Postgres>) {
        let db_client = DBClient::new(pool.clone());
        let config = get_test_config();

        let app = test::init_service(
            App::new()
                .app_data(web::Data::new(AppState {
                    env: config.clone(),
                    db_client,
                }))
                .service(web::scope("/api/auth").route("/login", web::post().to(login))),
        )
        .await;

        let req = test::TestRequest::post()
            .uri("/api/auth/login")
            .set_json(LoginUserDto {
                email: "john@example.com".to_string(),
                password: "password123".to_string(),
            })
            .to_request();

        let resp = test::call_service(&app, req).await;

        assert_eq!(resp.status(), http::StatusCode::UNAUTHORIZED);

        let body = test::read_body(resp).await;
        let expected_message = ErrorMessage::WrongCredentials.to_string();

        let body_json: serde_json::Value = serde_json::from_slice(&body).unwrap();
        let actual_message = body_json["message"].as_str().unwrap();

        assert_eq!(actual_message, expected_message);
    }

    #[sqlx::test]
    async fn test_login_with_wrong_email(pool: Pool<Postgres>) {
        let db_client = DBClient::new(pool.clone());
        let config = get_test_config();

        let hashed_password = password::hash("password123").unwrap();
        db_client
            .save_user("John", "john@example.com", &hashed_password)
            .await
            .unwrap();

        let app = test::init_service(
            App::new()
                .app_data(web::Data::new(AppState {
                    env: config.clone(),
                    db_client,
                }))
                .service(web::scope("/api/auth").route("/login", web::post().to(login))),
        )
        .await;

        let req = test::TestRequest::post()
            .uri("/api/auth/login")
            .set_json(LoginUserDto {
                email: "wrongemail@example.com".to_string(),
                password: "password123".to_string(),
            })
            .to_request();

        let resp = test::call_service(&app, req).await;

        assert_eq!(resp.status(), http::StatusCode::UNAUTHORIZED);

        let body = test::read_body(resp).await;
        let expected_message = ErrorMessage::WrongCredentials.to_string();

        let body_json: serde_json::Value = serde_json::from_slice(&body).unwrap();
        let actual_message = body_json["message"].as_str().unwrap();

        assert_eq!(actual_message, expected_message);
    }

    #[sqlx::test]
    async fn test_login_with_wrong_password(pool: Pool<Postgres>) {
        let db_client = DBClient::new(pool.clone());
        let config = get_test_config();

        let hashed_password = password::hash("password123").unwrap();

        db_client
            .save_user("John", "john@example.com", &hashed_password)
            .await
            .unwrap();

        let app = test::init_service(
            App::new()
                .app_data(web::Data::new(AppState {
                    env: config.clone(),
                    db_client,
                }))
                .service(web::scope("/api/auth").route("/login", web::post().to(login))),
        )
        .await;

        let req = test::TestRequest::post()
            .uri("/api/auth/login")
            .set_json(LoginUserDto {
                email: "john@example.com".to_string(),
                password: "wrongpassword".to_string(),
            })
            .to_request();

        let resp = test::call_service(&app, req).await;

        assert_eq!(resp.status(), http::StatusCode::UNAUTHORIZED);

        let body = test::read_body(resp).await;
        let expected_message = ErrorMessage::WrongCredentials.to_string();

        let body_json: serde_json::Value = serde_json::from_slice(&body).unwrap();
        let actual_message = body_json["message"].as_str().unwrap();

        assert_eq!(actual_message, expected_message);
    }

    #[sqlx::test]
    async fn test_login_with_no_data(pool: Pool<Postgres>) {
        let db_client = DBClient::new(pool.clone());
        let config = get_test_config();

        let app = test::init_service(
            App::new()
                .app_data(web::Data::new(AppState {
                    env: config.clone(),
                    db_client,
                }))
                .service(web::scope("/api/auth").route("/login", web::post().to(login))),
        )
        .await;

        let req = test::TestRequest::post()
            .uri("/api/auth/login")
            .to_request();

        let resp = test::call_service(&app, req).await;

        assert_eq!(resp.status(), http::StatusCode::BAD_REQUEST);

        let body = test::read_body(resp).await;
        let body_str = String::from_utf8_lossy(&body);

        let expected_message = "Content type error";

        assert!(body_str.contains(expected_message));
    }

    #[sqlx::test]
    async fn test_login_with_empty_json_object(pool: Pool<Postgres>) {
        let db_client = DBClient::new(pool.clone());
        let config = get_test_config();

        let app = test::init_service(
            App::new()
                .app_data(web::Data::new(AppState {
                    env: config.clone(),
                    db_client,
                }))
                .service(web::scope("/api/auth").route("/login", web::post().to(login))),
        )
        .await;

        let req = test::TestRequest::post()
            .set_json(json!({}))
            .uri("/api/auth/login")
            .to_request();

        let resp = test::call_service(&app, req).await;

        assert_eq!(resp.status(), http::StatusCode::BAD_REQUEST);

        let body = test::read_body(resp).await;
        let expected_message = "Json deserialize error: missing field";

        let body_str = String::from_utf8_lossy(&body);

        assert!(body_str.contains(expected_message));
    }

    #[sqlx::test]
    async fn test_logout_with_valid_token(pool: Pool<Postgres>) {
        let db_client = DBClient::new(pool.clone());
        let config = get_test_config();
        let hashed_password = password::hash("password123").unwrap();
        let user = db_client
            .save_user("John", "john@example.com", &hashed_password)
            .await
            .unwrap();

        let token =
            token::create_token(&user.id.to_string(), config.jwt_secret.as_bytes(), 60).unwrap();

        let app = test::init_service(
            App::new()
                .app_data(web::Data::new(AppState {
                    env: config.clone(),
                    db_client,
                }))
                .service(web::scope("/api/auth").route(
                    "/logout",
                    web::post().to(logout).wrap(RequireAuth::allowed_roles(vec![
                        UserRole::User,
                        UserRole::Moderator,
                        UserRole::Admin,
                    ])),
                )),
        )
        .await;

        let req = test::TestRequest::post()
            .insert_header((http::header::AUTHORIZATION, format!("Bearer {}", token)))
            .uri("/api/auth/logout")
            .to_request();

        let resp = test::call_service(&app, req).await;

        assert_eq!(resp.status(), http::StatusCode::OK);

        let body = test::read_body(resp).await;

        let body_json: serde_json::Value = serde_json::from_slice(&body).unwrap();
        let status_field_value = body_json["status"].as_str().unwrap();

        assert_eq!(status_field_value, "success");
    }

    #[sqlx::test]
    async fn test_logout_with_invalid_token(pool: Pool<Postgres>) {
        let db_client = DBClient::new(pool.clone());
        let config = get_test_config();

        let app = test::init_service(
            App::new()
                .app_data(web::Data::new(AppState {
                    env: config.clone(),
                    db_client,
                }))
                .service(web::scope("/api/auth").route(
                    "/logout",
                    web::post().to(logout).wrap(RequireAuth::allowed_roles(vec![
                        UserRole::User,
                        UserRole::Moderator,
                        UserRole::Admin,
                    ])),
                )),
        )
        .await;

        let req = test::TestRequest::post()
            .uri("/api/auth/logout")
            .insert_header((
                http::header::AUTHORIZATION,
                format!("Bearer {}", "invalid_token"),
            ))
            .to_request();

        let result = test::try_call_service(&app, req).await.err();

        match result {
            Some(err) => {
                let expected_status = http::StatusCode::UNAUTHORIZED;
                let actual_status = err.as_response_error().status_code();

                assert_eq!(actual_status, expected_status);

                let err_response: ErrorResponse = serde_json::from_str(&err.to_string())
                    .expect("Failed to deserialize JSON string");
                let expected_message = ErrorMessage::InvalidToken.to_string();
                assert_eq!(err_response.message, expected_message);
            }
            None => {
                panic!("Service call succeeded, but an error was expected.");
            }
        }
    }

    #[sqlx::test]
    async fn test_logout_with_missing_token(pool: Pool<Postgres>) {
        let db_client = DBClient::new(pool.clone());
        let config = get_test_config();

        let app = test::init_service(
            App::new()
                .app_data(web::Data::new(AppState {
                    env: config.clone(),
                    db_client,
                }))
                .service(web::scope("/api/auth").route(
                    "/logout",
                    web::post().to(logout).wrap(RequireAuth::allowed_roles(vec![
                        UserRole::User,
                        UserRole::Moderator,
                        UserRole::Admin,
                    ])),
                )),
        )
        .await;

        let req = test::TestRequest::post()
            .uri("/api/auth/logout")
            .to_request();

        let result = test::try_call_service(&app, req).await.err();

        match result {
            Some(err) => {
                let expected_status = http::StatusCode::UNAUTHORIZED;
                let actual_status = err.as_response_error().status_code();

                assert_eq!(actual_status, expected_status);

                let err_response: ErrorResponse = serde_json::from_str(&err.to_string())
                    .expect("Failed to deserialize JSON string");
                let expected_message = ErrorMessage::TokenNotProvided.to_string();
                assert_eq!(err_response.message, expected_message);
            }
            None => {
                panic!("Service call succeeded, but an error was expected.");
            }
        }
    }

    #[sqlx::test]
    async fn test_logout_with_expired_token(pool: Pool<Postgres>) {
        let db_client = DBClient::new(pool.clone());
        let config = get_test_config();

        let user_id = uuid::Uuid::new_v4();
        // A negative lifetime yields a token whose expiry is already in the past.
        let expired_token =
            token::create_token(&user_id.to_string(), config.jwt_secret.as_bytes(), -60).unwrap();

        let app = test::init_service(
            App::new()
                .app_data(web::Data::new(AppState {
                    env: config.clone(),
                    db_client,
                }))
                .service(web::scope("/api/auth").route(
                    "/logout",
                    web::post().to(logout).wrap(RequireAuth::allowed_roles(vec![
                        UserRole::User,
                        UserRole::Moderator,
                        UserRole::Admin,
                    ])),
                )),
        )
        .await;

        let req = test::TestRequest::post()
            .uri("/api/auth/logout")
            .insert_header((
                http::header::AUTHORIZATION,
                format!("Bearer {}", expired_token),
            ))
            .to_request();

        let result = test::try_call_service(&app, req).await.err();

        match result {
            Some(err) => {
                let expected_status = http::StatusCode::UNAUTHORIZED;
                let actual_status = err.as_response_error().status_code();

                assert_eq!(actual_status, expected_status);

                let err_response: ErrorResponse = serde_json::from_str(&err.to_string())
                    .expect("Failed to deserialize JSON string");
                let expected_message = ErrorMessage::InvalidToken.to_string();
                assert_eq!(err_response.message, expected_message);
            }
            None => {
                panic!("Service call succeeded, but an error was expected.");
            }
        }
    }
}

These unit tests cover a range of registration, login, and logout scenarios. Because there are so many of them, I won’t explain each one in detail. They follow the same structure as the earlier tests, including those for the middleware guards: build the app with test::init_service, send a request built with test::TestRequest, and then assert on the status code and response body or, for requests that should fail, on the error returned by test::try_call_service.
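Several of these tests, such as test_logout_with_expired_token, get an expired token by passing a negative lifetime (-60) to token::create_token, and they expect the middleware to answer with ErrorMessage::InvalidToken. The std-only sketch below illustrates why that works. It assumes, as JWT libraries commonly do, that the exp claim is the issue time plus the lifetime and that validation rejects a token whose exp lies in the past beyond a small clock-skew leeway. The helpers compute_exp and is_expired are hypothetical, and treating the lifetime as minutes is an assumption; the real logic lives in the project’s utils/token.rs.

```rust
// Hypothetical, std-only sketch of how a JWT `exp` claim is typically checked.
// Assumptions: the lifetime is in minutes and validation allows a leeway in seconds;
// the real logic lives in the project's utils/token.rs and the JWT library it uses.

/// Computes an `exp` claim (Unix seconds) from the issue time and a signed lifetime.
fn compute_exp(issued_at: i64, lifetime_minutes: i64) -> i64 {
    issued_at + lifetime_minutes * 60
}

/// Returns true when `exp` lies further in the past than the allowed leeway.
fn is_expired(exp: i64, now: i64, leeway_secs: i64) -> bool {
    exp < now - leeway_secs
}

fn main() {
    // A fixed "current time" keeps the example deterministic.
    let now: i64 = 1_700_000_000;
    let leeway_secs = 60;

    // A positive lifetime (as in the "valid token" tests) puts `exp` in the future.
    let valid_exp = compute_exp(now, 60);
    // A negative lifetime (as in the "expired token" tests) puts `exp` an hour in the past.
    let expired_exp = compute_exp(now, -60);

    println!("valid token expired? {}", is_expired(valid_exp, now, leeway_secs));
    println!("expired token expired? {}", is_expired(expired_exp, now, leeway_secs));
}
```

Under these assumptions, an expiry an hour in the past is far outside any typical leeway, so the expired-token tests don’t depend on timing. A lifetime of only a few seconds in the past could still pass validation and make such a test flaky.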

Writing Unit Tests for the User-Related Handlers

Next, let’s write unit tests for the user-related route handlers: get_me, which returns the authenticated user’s profile, and get_users, which returns a paginated list of users. Here is their implementation:

src/scopes/users.rs


async fn get_me(user: Authenticated) -> Result<HttpResponse, HttpError> {
    let filtered_user = FilterUserDto::filter_user(&user);

    let response_data = UserResponseDto {
        status: "success".to_string(),
        data: UserData {
            user: filtered_user,
        },
    };

    Ok(HttpResponse::Ok().json(response_data))
}

pub async fn get_users(
    query: web::Query<RequestQueryDto>,
    app_state: web::Data<AppState>,
) -> Result<HttpResponse, HttpError> {
    let query_params: RequestQueryDto = query.into_inner();

    query_params
        .validate()
        .map_err(|e| HttpError::bad_request(e.to_string()))?;

    let page = query_params.page.unwrap_or(1);
    let limit = query_params.limit.unwrap_or(10);

    let users = app_state
        .db_client
        .get_users(page as u32, limit)
        .await
        .map_err(|e| HttpError::server_error(e.to_string()))?;

    Ok(HttpResponse::Ok().json(UserListResponseDto {
        status: "success".to_string(),
        users: FilterUserDto::filter_users(&users),
        results: users.len(),
    }))
}
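When the query string omits them, get_users falls back to page 1 and a limit of 10. This article doesn’t show DBClient::get_users, so the std-only sketch below assumes the common offset-based scheme, where page n skips (n - 1) * limit rows. The helpers resolve_pagination and paginate are hypothetical, and an in-memory slice stands in for the users table.

```rust
// Hypothetical, std-only sketch of offset-based pagination. This article does not
// show DBClient::get_users, so the (page - 1) * limit offset is an assumption.

/// Applies the same defaults as the `get_users` handler: page 1, limit 10.
fn resolve_pagination(page: Option<usize>, limit: Option<usize>) -> (usize, usize) {
    (page.unwrap_or(1), limit.unwrap_or(10))
}

/// Returns the rows on a 1-based `page` of size `limit`, like LIMIT/OFFSET in SQL.
fn paginate<T>(rows: &[T], page: usize, limit: usize) -> &[T] {
    // Saturating arithmetic avoids underflow for page 0 and overflow for huge values;
    // the handler's validate() call presumably rejects such inputs before this point.
    let offset = page.saturating_sub(1).saturating_mul(limit);
    let start = offset.min(rows.len());
    let end = start.saturating_add(limit).min(rows.len());
    &rows[start..end]
}

fn main() {
    // Four stand-in users, matching the count the admin-user test expects.
    let users = ["user_a", "user_b", "user_c", "vivian"];

    // No query string: page 1, limit 10, so all four users come back.
    let (page, limit) = resolve_pagination(None, None);
    println!("{:?}", paginate(&users, page, limit));

    // ?page=1&limit=2, as in the page-one/limit-two test: only the first two users.
    let (page, limit) = resolve_pagination(Some(1), Some(2));
    println!("{:?}", paginate(&users, page, limit));
}
```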

The following unit tests exercise the user-related handlers. Like the earlier tests, they use Actix Web’s testing utilities to build the app and send requests, and #[sqlx::test] to provide a test database for each test. They cover valid, invalid, missing, and expired tokens, role-based access (only admins may list users), and the page and limit query parameters.

src/scopes/users.rs


#[cfg(test)]
mod tests {
    use actix_web::{http, test, App};
    use sqlx::{Pool, Postgres};

    use crate::{
        db::DBClient,
        error::{ErrorMessage, ErrorResponse},
        utils::{
            password,
            test_utils::{get_test_config, init_test_users},
            token,
        },
    };

    use super::*;

    #[sqlx::test]
    async fn test_get_me_with_valid_token(pool: Pool<Postgres>) {
        let (user_id, _, _) = init_test_users(&pool).await;
        let db_client = DBClient::new(pool.clone());
        let config = get_test_config();

        let token =
            token::create_token(&user_id.to_string(), config.jwt_secret.as_bytes(), 60).unwrap();

        let app = test::init_service(
            App::new()
                .app_data(web::Data::new(AppState {
                    env: config.clone(),
                    db_client,
                }))
                .service(web::scope("/api/users").route(
                    "/me",
                    web::get().to(get_me).wrap(RequireAuth::allowed_roles(vec![
                        UserRole::User,
                        UserRole::Moderator,
                        UserRole::Admin,
                    ])),
                )),
        )
        .await;

        let req = test::TestRequest::get()
            .insert_header((http::header::AUTHORIZATION, format!("Bearer {}", token)))
            .uri("/api/users/me")
            .to_request();

        let resp = test::call_service(&app, req).await;

        assert_eq!(resp.status(), http::StatusCode::OK);

        let body = test::read_body(resp).await;

        let user_response: UserResponseDto =
            serde_json::from_slice(&body).expect("Failed to deserialize user response from JSON");
        let user = user_response.data.user;

        assert_eq!(user_id.to_string(), user.id);
    }

    #[sqlx::test]
    async fn test_get_me_with_invalid_token(pool: Pool<Postgres>) {
        let db_client = DBClient::new(pool.clone());
        let config = get_test_config();

        let app = test::init_service(
            App::new()
                .app_data(web::Data::new(AppState {
                    env: config.clone(),
                    db_client,
                }))
                .service(web::scope("/api/users").route(
                    "/me",
                    web::get().to(get_me).wrap(RequireAuth::allowed_roles(vec![
                        UserRole::User,
                        UserRole::Moderator,
                        UserRole::Admin,
                    ])),
                )),
        )
        .await;

        let req = test::TestRequest::get()
            .insert_header((
                http::header::AUTHORIZATION,
                format!("Bearer {}", "invalid_token"),
            ))
            .uri("/api/users/me")
            .to_request();

        let result = test::try_call_service(&app, req).await.err();

        match result {
            Some(err) => {
                let expected_status = http::StatusCode::UNAUTHORIZED;
                let actual_status = err.as_response_error().status_code();

                assert_eq!(actual_status, expected_status);

                let err_response: ErrorResponse = serde_json::from_str(&err.to_string())
                    .expect("Failed to deserialize JSON string");
                let expected_message = ErrorMessage::InvalidToken.to_string();
                assert_eq!(err_response.message, expected_message);
            }
            None => {
                panic!("Service call succeeded, but an error was expected.");
            }
        }
    }

    #[sqlx::test]
    async fn test_get_me_with_missing_token(pool: Pool<Postgres>) {
        let db_client = DBClient::new(pool.clone());
        let config = get_test_config();

        let app = test::init_service(
            App::new()
                .app_data(web::Data::new(AppState {
                    env: config.clone(),
                    db_client,
                }))
                .service(web::scope("/api/users").route(
                    "/me",
                    web::get().to(get_me).wrap(RequireAuth::allowed_roles(vec![
                        UserRole::User,
                        UserRole::Moderator,
                        UserRole::Admin,
                    ])),
                )),
        )
        .await;

        let req = test::TestRequest::get().uri("/api/users/me").to_request();

        let result = test::try_call_service(&app, req).await.err();

        match result {
            Some(err) => {
                let expected_status = http::StatusCode::UNAUTHORIZED;
                let actual_status = err.as_response_error().status_code();

                assert_eq!(actual_status, expected_status);

                let err_response: ErrorResponse = serde_json::from_str(&err.to_string())
                    .expect("Failed to deserialize JSON string");
                let expected_message = ErrorMessage::TokenNotProvided.to_string();
                assert_eq!(err_response.message, expected_message);
            }
            None => {
                panic!("Service call succeeded, but an error was expected.");
            }
        }
    }

    #[sqlx::test]
    async fn test_get_me_with_expired_token(pool: Pool<Postgres>) {
        let (user_id, _, _) = init_test_users(&pool).await;
        let db_client = DBClient::new(pool.clone());
        let config = get_test_config();

        // A negative lifetime yields a token whose expiry is already in the past.
        let expired_token =
            token::create_token(&user_id.to_string(), config.jwt_secret.as_bytes(), -60).unwrap();

        let app = test::init_service(
            App::new()
                .app_data(web::Data::new(AppState {
                    env: config.clone(),
                    db_client,
                }))
                .service(web::scope("/api/users").route(
                    "/me",
                    web::get().to(get_me).wrap(RequireAuth::allowed_roles(vec![
                        UserRole::User,
                        UserRole::Moderator,
                        UserRole::Admin,
                    ])),
                )),
        )
        .await;

        let req = test::TestRequest::get()
            .uri("/api/users/me")
            .insert_header((
                http::header::AUTHORIZATION,
                format!("Bearer {}", expired_token),
            ))
            .to_request();

        let result = test::try_call_service(&app, req).await.err();

        match result {
            Some(err) => {
                let expected_status = http::StatusCode::UNAUTHORIZED;
                let actual_status = err.as_response_error().status_code();

                assert_eq!(actual_status, expected_status);

                let err_response: ErrorResponse = serde_json::from_str(&err.to_string())
                    .expect("Failed to deserialize JSON string");
                let expected_message = ErrorMessage::InvalidToken.to_string();
                assert_eq!(err_response.message, expected_message);
            }
            None => {
                panic!("Service call succeeded, but an error was expected.");
            }
        }
    }

    #[sqlx::test]
    async fn test_all_users_with_valid_token_with_admin_user(pool: Pool<Postgres>) {
        init_test_users(&pool).await;
        let db_client = DBClient::new(pool.clone());
        let config = get_test_config();

        let hashed_password = password::hash("password123").unwrap();
        let user = db_client
            .save_admin_user("Vivian", "vivian@example.com", &hashed_password)
            .await
            .unwrap();

        let token =
            token::create_token(&user.id.to_string(), config.jwt_secret.as_bytes(), 60).unwrap();

        let app = test::init_service(
            App::new()
                .app_data(web::Data::new(AppState {
                    env: config.clone(),
                    db_client,
                }))
                .service(
                    web::scope("/api/users").route(
                        "",
                        web::get()
                            .to(get_users)
                            .wrap(RequireAuth::allowed_roles(vec![UserRole::Admin])),
                    ),
                ),
        )
        .await;

        let req = test::TestRequest::get()
            .insert_header((http::header::AUTHORIZATION, format!("Bearer {}", token)))
            .uri("/api/users")
            .to_request();

        let resp = test::call_service(&app, req).await;

        assert_eq!(resp.status(), http::StatusCode::OK);

        let body = test::read_body(resp).await;

        let user_list_response: UserListResponseDto =
            serde_json::from_slice(&body).expect("Failed to deserialize users response from JSON");

        // The users seeded by init_test_users plus the admin created above.
        assert_eq!(user_list_response.users.len(), 4);
    }

    #[sqlx::test]
    async fn test_all_users_with_page_one_and_limit_two_query_parameters(pool: Pool<Postgres>) {
        init_test_users(&pool).await;
        let db_client = DBClient::new(pool.clone());
        let config = get_test_config();

        let hashed_password = password::hash("password123").unwrap();
        let user = db_client
            .save_admin_user("Vivian", "vivian@example.com", &hashed_password)
            .await
            .unwrap();

        let token =
            token::create_token(&user.id.to_string(), config.jwt_secret.as_bytes(), 60).unwrap();

        let app = test::init_service(
            App::new()
                .app_data(web::Data::new(AppState {
                    env: config.clone(),
                    db_client,
                }))
                .service(
                    web::scope("/api/users").route(
                        "",
                        web::get()
                            .to(get_users)
                            .wrap(RequireAuth::allowed_roles(vec![UserRole::Admin])),
                    ),
                ),
        )
        .await;

        let req = test::TestRequest::get()
            .insert_header((http::header::AUTHORIZATION, format!("Bearer {}", token)))
            .uri("/api/users?page=1&limit=2")
            .to_request();

        let resp = test::call_service(&app, req).await;

        assert_eq!(resp.status(), http::StatusCode::OK);

        let body = test::read_body(resp).await;

        let user_list_response: UserListResponseDto =
            serde_json::from_slice(&body).expect("Failed to deserialize users response from JSON");

        // limit=2 caps the first page at two users.
        assert_eq!(user_list_response.users.len(), 2);
    }

    #[sqlx::test]
    async fn test_all_users_with_valid_token_by_regular_user(pool: Pool<Postgres>) {
        let (user_id, _, _) = init_test_users(&pool).await;
        let db_client = DBClient::new(pool.clone());
        let config = get_test_config();

        let token =
            token::create_token(&user_id.to_string(), config.jwt_secret.as_bytes(), 60).unwrap();

        let app = test::init_service(
            App::new()
                .app_data(web::Data::new(AppState {
                    env: config.clone(),
                    db_client,
                }))
                .service(
                    web::scope("/api/users").route(
                        "",
                        web::get()
                            .to(get_users)
                            .wrap(RequireAuth::allowed_roles(vec![UserRole::Admin])),
                    ),
                ),
        )
        .await;

        let req = test::TestRequest::get()
            .insert_header((http::header::AUTHORIZATION, format!("Bearer {}", token)))
            .uri("/api/users")
            .to_request();

        let result = test::try_call_service(&app, req).await.err();

        match result {
            Some(err) => {
                let expected_status = http::StatusCode::FORBIDDEN;
                let actual_status = err.as_response_error().status_code();

                assert_eq!(actual_status, expected_status);

                let err_response: ErrorResponse = serde_json::from_str(&err.to_string())
                    .expect("Failed to deserialize JSON string");
                let expected_message = ErrorMessage::PermissionDenied.to_string();
                assert_eq!(err_response.message, expected_message);
            }
            None => {
                panic!("Service call succeeded, but an error was expected.");
            }
        }
    }

    #[sqlx::test]
    async fn test_all_users_with_invalid_token(pool: Pool<Postgres>) {
        let db_client = DBClient::new(pool.clone());
        let config = get_test_config();

        let app = test::init_service(
            App::new()
                .app_data(web::Data::new(AppState {
                    env: config.clone(),
                    db_client,
                }))
                .service(
                    web::scope("/api/users").route(
                        "",
                        web::get()
                            .to(get_users)
                            .wrap(RequireAuth::allowed_roles(vec![UserRole::Admin])),
                    ),
                ),
        )
        .await;

        let req = test::TestRequest::get()
            .insert_header((
                http::header::AUTHORIZATION,
                format!("Bearer {}", "invalid_token"),
            ))
            .uri("/api/users")
            .to_request();

        let result = test::try_call_service(&app, req).await.err();

        match result {
            Some(err) => {
                let expected_status = http::StatusCode::UNAUTHORIZED;
                let actual_status = err.as_response_error().status_code();

                assert_eq!(actual_status, expected_status);

                let err_response: ErrorResponse = serde_json::from_str(&err.to_string())
                    .expect("Failed to deserialize JSON string");
                let expected_message = ErrorMessage::InvalidToken.to_string();
                assert_eq!(err_response.message, expected_message);
            }
            None => {
                panic!("Service call succeeded, but an error was expected.");
            }
        }
    }

    #[sqlx::test]
    async fn test_all_users_with_missing_token(pool: Pool<Postgres>) {
        let db_client = DBClient::new(pool.clone());
        let config = get_test_config();

        let app = test::init_service(
            App::new()
                .app_data(web::Data::new(AppState {
                    env: config.clone(),
                    db_client,
                }))
                .service(
                    web::scope("/api/users").route(
                        "",
                        web::get()
                            .to(get_users)
                            .wrap(RequireAuth::allowed_roles(vec![UserRole::Admin])),
                    ),
                ),
        )
        .await;

        let req = test::TestRequest::get().uri("/api/users").to_request();

        let result = test::try_call_service(&app, req).await.err();

        match result {
            Some(err) => {
                let expected_status = http::StatusCode::UNAUTHORIZED;
                let actual_status = err.as_response_error().status_code();

                assert_eq!(actual_status, expected_status);

                let err_response: ErrorResponse = serde_json::from_str(&err.to_string())
                    .expect("Failed to deserialize JSON string");
                let expected_message = ErrorMessage::TokenNotProvided.to_string();
                assert_eq!(err_response.message, expected_message);
            }
            None => {
                panic!("Service call succeeded, but an error was expected.");
            }
        }
    }
}
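Taken together, the tests above pin down how the RequireAuth guard is expected to respond: a missing Authorization header yields 401 with TokenNotProvided, a malformed or expired token yields 401 with InvalidToken, and a valid token whose role isn’t in allowed_roles yields 403 with PermissionDenied. The std-only sketch below encodes that decision table so it can be read at a glance. It is not the middleware’s implementation; the Role, TokenState, and authorize names are hypothetical, and plain status codes and strings stand in for actix-web’s types and the project’s ErrorMessage variants.

```rust
// Hypothetical, std-only model of the outcomes the tests expect from RequireAuth.

#[derive(Debug, Clone, Copy, PartialEq)]
enum Role {
    User,
    Moderator,
    Admin,
}

/// What the guard finds in the request, reduced to the cases the tests exercise.
enum TokenState {
    Missing,
    Invalid, // malformed, badly signed, or expired
    Valid(Role),
}

/// Returns the HTTP status code and message a request should produce.
fn authorize(token: TokenState, allowed: &[Role]) -> (u16, &'static str) {
    match token {
        TokenState::Missing => (401, "TokenNotProvided"),
        TokenState::Invalid => (401, "InvalidToken"),
        // In the real app, the handler runs at this point.
        TokenState::Valid(role) if allowed.contains(&role) => (200, "OK"),
        TokenState::Valid(_) => (403, "PermissionDenied"),
    }
}

fn main() {
    let me_roles = [Role::User, Role::Moderator, Role::Admin]; // /api/users/me
    let list_roles = [Role::Admin]; // /api/users

    println!("{:?}", authorize(TokenState::Valid(Role::User), &me_roles));
    println!("{:?}", authorize(TokenState::Valid(Role::User), &list_roles));
    println!("{:?}", authorize(TokenState::Invalid, &list_roles));
    println!("{:?}", authorize(TokenState::Missing, &list_roles));
}
```

Checking the token before the role is what makes the invalid-token and missing-token tests on the admin-only /api/users route expect 401 rather than 403.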

Conclusion

Congratulations on making it to the end! In this article, we looked at why unit tests matter and wrote tests for our API project that cover the middleware guards and the auth and user handlers, in both success and failure scenarios. I hope you found it enjoyable and helpful. If you have any questions or feedback, please don’t hesitate to leave them in the comments below.