In this article, you will learn how to implement JWT authentication and authorization in a Rust API using Actix-Web middleware. Here's the plan: we'll start by examining a JavaScript implementation and then gradually transition to Rust. This transition highlights how much more involved the Rust approach is compared to its JavaScript counterpart. Along the way, we'll also look at how Actix-Web manages middleware.

In addition to using extractors to secure private routes, I’ll provide you with unit tests that cover various scenarios. Although I won’t delve into the code for API error handling, utility functions, or database models within this article to ensure clarity and simplicity, you can view their implementations directly from the source code linked in this article at https://github.com/wpcodevo/complete-restful-api-in-rust.

Articles in This Series

  1. Building a Rust API with Unit Testing in Mind
  2. How to Add Swagger UI, Redoc and RapiDoc to a Rust API
  3. JWT Authentication and Authorization in a Rust API using Actix-Web
  4. How to Write Unit Tests for Your Rust API
  5. Dockerizing a Rust API Project: SQL Database and pgAdmin
  6. Deploy Rust App on VPS with GitHub Actions and Docker
JWT Authentication and Authorization in Rust API using Actix-Web

Middleware Example in TypeScript

Let’s start by looking at how simple it is to write middleware in Node.js, specifically using TypeScript and the Express.js framework. If you’re familiar with JavaScript, you probably know how middleware functions in Node.js. Keep the middleware structure in mind because this is exactly what we’re going to replicate in Rust.

Authentication Middleware Guard

Creating a JavaScript middleware to authenticate users with JSON Web Tokens (JWTs) involves writing just one function. Below is an example of middleware using JWTs to authenticate users. While the exact syntax might vary a bit among different Node.js frameworks, it generally looks something like this:


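import { NextFunction, Request, Response } from "express";
// AppError, verifyJwt and findUserById are project helpers (see the repository).

export const deserializeUser = async (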
  req: Request,
  res: Response,
  next: NextFunction
) => {
  try {
    // Extract the token
    let token: string | undefined;
    if (
      req.headers.authorization &&
      req.headers.authorization.startsWith("Bearer ")
    ) {
      token = req.headers.authorization.split(" ")[1];
    } else if (req.cookies?.token) {
      // req.cookies is populated by the cookie-parser middleware
      token = req.cookies.token;
    }

    // Check if token is missing
    if (!token) {
      return next(new AppError("You are not logged in", 401));
    }

    // Validate the token
    const decoded = verifyJwt<{ sub: string }>(token);

    // Check if token is invalid
    if (!decoded) {
      return next(new AppError(`Invalid token or user doesn't exist`, 401));
    }

    // Verify if user still exists

    const user = await findUserById(decoded.sub);
    if (!user) {
      return next(new AppError(`User with that token no longer exists`, 401));
    }

    // Add the user to the res.locals object
    res.locals.user = user;

    next();
  } catch (err: any) {
    next(err);
  }
};

router.use(deserializeUser);

Authorization Middleware Guard

Similarly, creating middleware that restricts access to specific roles is straightforward. You only need a function that takes the allowed roles and returns another function containing the access-check logic. Here's an example:


export const restrictTo =
  (...allowedRoles: string[]) =>
  (req: Request, res: Response, next: NextFunction) => {
    // Retrieve user information from previous middleware
    const user = res.locals.user;

    // Check if user's role is allowed to access the resource
    if (!allowedRoles.includes(user.role)) {
      return next(
        new AppError('You are not allowed to perform this action', 403)
      );
    }

    // Proceed to the next middleware if allowed
    next();
  };

router.use(deserializeUser);
router.get('/api/users', restrictTo('admin'), getAllUsersHandler);

These two examples show how little code Express middleware requires. Let's now look at how middleware works in Rust, specifically in the Actix-Web framework. If you're new to Rust, the reasoning behind Actix-Web's middleware design may take a bit more effort to grasp. Once you see the role each trait plays, it becomes much easier to follow.

Understanding How Middleware Works in Actix-Web

As with many things in Rust, creating Actix-Web middleware can be a bit tricky. It requires implementing the Transform and Service traits. You might also need an extractor that implements FromRequest to retrieve extra data stored in the request extensions, such as the user object that the TypeScript example above stores in res.locals.

Yes, we do need the Transform and Service traits, and possibly an extractor. However, most of the complexity comes from boilerplate code, and it's worth understanding what that boilerplate accomplishes. First, let's take a step back and get a broader picture of what a Service actually represents.

Exploring the Service Trait

The Service trait models a request/response life cycle: it accepts a request and returns a response. You can think of a service as an asynchronous function that takes a single argument (the request), does some work, and eventually produces a result. Because middleware and endpoint handlers all expose this same interface, a service can invoke another service. That next service may be another middleware in the stack or the endpoint handler.

Now, let’s examine the trait itself. The original source file contains numerous comments, but for the sake of conciseness, I’ve shortened them here.


use std::future::Future;
use std::task::{self, Poll};

/// Represents a service that processes requests and produces responses
/// (abridged from the actix-service crate).
pub trait Service<Req> {
    /// The response type produced by the service.
    type Response;

    /// The error type that can be produced by the service.
    type Error;

    /// The future type representing the asynchronous response.
    type Future: Future<Output = Result<Self::Response, Self::Error>>;

    /// Returns `Ready` when the service is able to process requests.
    fn poll_ready(&self, ctx: &mut task::Context<'_>) -> Poll<Result<(), Self::Error>>;

    /// Processes the request and returns a future that resolves to the response.
    fn call(&self, req: Req) -> Self::Future;
}

In Actix, the poll_ready function determines whether the service is ready to accept a request. This is useful when a service needs to limit how many calls it handles concurrently.

By implementing poll_ready, you can ensure that the service is in an appropriate state before requests reach it. This helps manage the service's concurrency and availability so that it isn't overwhelmed by incoming requests.
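To make the poll_ready contract more concrete, here is a minimal std-only sketch of a service that reports Pending once a concurrency limit is reached. LimitedService, start, finish and noop_waker are hypothetical names used only for illustration, and this is not actix-service code. A real implementation would also store the waker from the Context and wake it when a slot frees up; this sketch omits that.

```rust
use std::cell::Cell;
use std::sync::Arc;
use std::task::{Context, Poll, Wake, Waker};

/// Hypothetical service that refuses new work once `max_in_flight` calls are active.
pub struct LimitedService {
    in_flight: Cell<usize>,
    max_in_flight: usize,
}

impl LimitedService {
    pub fn new(max_in_flight: usize) -> Self {
        Self { in_flight: Cell::new(0), max_in_flight }
    }

    /// Same shape as `Service::poll_ready`: `Ready(Ok(()))` means a call may start,
    /// `Pending` means the caller should wait. (A real service would keep
    /// `cx.waker()` and wake it in `finish`; omitted here.)
    pub fn poll_ready(&self, _cx: &mut Context<'_>) -> Poll<Result<(), String>> {
        if self.in_flight.get() < self.max_in_flight {
            Poll::Ready(Ok(()))
        } else {
            Poll::Pending
        }
    }

    /// Marks the start of a call (stands in for `call`).
    pub fn start(&self) {
        self.in_flight.set(self.in_flight.get() + 1);
    }

    /// Marks the end of a call, freeing a slot.
    pub fn finish(&self) {
        self.in_flight.set(self.in_flight.get().saturating_sub(1));
    }
}

/// A waker that does nothing, used only to build a `Context` outside an executor.
pub struct NoopWaker;

impl Wake for NoopWaker {
    fn wake(self: Arc<Self>) {}
}

pub fn noop_waker() -> Waker {
    Waker::from(Arc::new(NoopWaker))
}
```

In practice, Actix-Web middleware rarely needs this. The AuthMiddleware later in this article simply forwards poll_ready to the service it wraps.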

The call function is where you put the core logic of your middleware or service. It’s responsible for processing incoming requests and generating appropriate responses. Much like the JavaScript example, you can inspect, modify, or work with the request and response objects as needed within this function.

Additionally, if your middleware or service needs to pass data down the stack to another middleware or the endpoint handler, the call function facilitates that as well. Because Rust is statically typed, we can't attach arbitrary properties to the request the way res.locals allows in Express. Instead, Actix provides request extensions: a store on the request where middleware can insert extra data for later retrieval. We'll get into that as we move forward.
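Conceptually, request extensions are a map keyed by type. A value is stored under its TypeId and retrieved later by asking for the same type. The std-only sketch below models that idea. Extensions and User here are hypothetical, simplified stand-ins, not actix-web's actual implementation or the article's User model.

```rust
use std::any::{Any, TypeId};
use std::collections::HashMap;

/// Simplified, hypothetical model of a type-keyed extension map.
#[derive(Default)]
pub struct Extensions {
    map: HashMap<TypeId, Box<dyn Any>>,
}

impl Extensions {
    /// Stores `value`, replacing any previous value of the same type.
    pub fn insert<T: 'static>(&mut self, value: T) {
        self.map.insert(TypeId::of::<T>(), Box::new(value));
    }

    /// Returns a reference to the stored value of type `T`, if there is one.
    pub fn get<T: 'static>(&self) -> Option<&T> {
        self.map
            .get(&TypeId::of::<T>())
            .and_then(|boxed| boxed.downcast_ref::<T>())
    }
}

/// Toy user type standing in for the article's `User` model.
#[derive(Debug, PartialEq)]
pub struct User {
    pub name: String,
}
```

This is why the middleware can call req.extensions_mut().insert::<User>(user), and downstream code can later fetch the user by its type, even though the request struct has no user field.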

Exploring the Transform Trait

Now that we understand how the Service trait works, let's look at the role of the Transform trait. In essence, it modifies or enhances the behaviour of an existing service by creating a new service with added capabilities. For our purposes, though, it's simpler to view it as a factory: it receives the next service in the chain and builds the middleware service that wraps it.


/// Represents a transformation applied to a service.
pub trait Transform<S, Req> {
    /// The response type produced by the service.
    type Response;

    /// The error type produced by the service.
    type Error;

    /// The `TransformService` value created by this factory.
    type Transform: Service<Req, Response = Self::Response, Error = Self::Error>;

    /// Errors produced while building a transform service.
    type InitError;

    /// The future that resolves to the transform service (or an InitError).
    type Future: Future<Output = Result<Self::Transform, Self::InitError>>;

    /// Creates the transform service by wrapping `service`, returning it
    /// asynchronously as `Self::Future`.
    fn new_transform(&self, service: S) -> Self::Future;
}

The new_transform method creates and returns a new instance of the middleware Service asynchronously. It takes the original service S as an argument and returns a future that resolves to the transformed service.
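As a mental model, here is a synchronous, std-only sketch of the factory pattern. SimpleService, SimpleTransform, RequireHeader, HeaderGuard, Hello, and the string-based Request and Response types are all hypothetical simplifications, with no futures and no poll_ready. The point is only that a transform receives the next service and returns a new service that wraps it, which can either delegate or short-circuit.

```rust
use std::collections::HashMap;

/// Hypothetical, simplified request: just a map of headers.
pub struct Request {
    pub headers: HashMap<String, String>,
}

/// Hypothetical, simplified response: a status code and a body.
#[derive(Debug, PartialEq)]
pub struct Response {
    pub status: u16,
    pub body: String,
}

/// Synchronous stand-in for `Service`: takes a request, returns a response.
pub trait SimpleService {
    fn call(&self, req: Request) -> Response;
}

/// Synchronous stand-in for `Transform`: builds a new service around `S`.
pub trait SimpleTransform<S> {
    type Service: SimpleService;
    fn new_transform(&self, service: S) -> Self::Service;
}

/// The "endpoint handler" at the end of the chain.
pub struct Hello;

impl SimpleService for Hello {
    fn call(&self, _req: Request) -> Response {
        Response { status: 200, body: "hello".to_string() }
    }
}

/// Factory (like RequireAuth): holds configuration and creates the middleware.
pub struct RequireHeader {
    pub name: String,
}

/// Middleware service (like AuthMiddleware): wraps the next service.
pub struct HeaderGuard<S> {
    name: String,
    next: S,
}

impl<S: SimpleService> SimpleTransform<S> for RequireHeader {
    type Service = HeaderGuard<S>;

    fn new_transform(&self, service: S) -> Self::Service {
        HeaderGuard { name: self.name.clone(), next: service }
    }
}

impl<S: SimpleService> SimpleService for HeaderGuard<S> {
    fn call(&self, req: Request) -> Response {
        if req.headers.contains_key(&self.name) {
            // Allowed: delegate to the wrapped service.
            self.next.call(req)
        } else {
            // Short-circuit without calling the wrapped service.
            Response { status: 401, body: "missing header".to_string() }
        }
    }
}
```

RequireHeader plays the role that RequireAuth plays later in this article, and HeaderGuard plays the role of AuthMiddleware.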

Signing and Verifying JSON Web Tokens (JWTs)

Moving on, let's look at the functions responsible for signing and verifying JWTs. Both the middleware and the unit tests use them, so it's worth seeing them first.

We have two functions: create_token, which generates a signed JWT and provides it within a Result, and decode_token, which verifies the JWT and returns its payload. Towards the end, I’ve included unit tests to cover various scenarios.


use chrono::{Duration, Utc};
use jsonwebtoken::{decode, encode, Algorithm, DecodingKey, EncodingKey, Header, Validation};
use serde::{Deserialize, Serialize};

use crate::error::{ErrorMessage, HttpError};

#[derive(Debug, Serialize, Deserialize)]
pub struct TokenClaims {
    pub sub: String,
    pub iat: usize,
    pub exp: usize,
}

pub fn create_token(
    user_id: &str,
    secret: &[u8],
    expires_in_minutes: i64,
) -> Result<String, jsonwebtoken::errors::Error> {
    if user_id.is_empty() {
        return Err(jsonwebtoken::errors::ErrorKind::InvalidSubject.into());
    }

    let now = Utc::now();
    let iat = now.timestamp() as usize;
    // The lifetime is expressed in minutes, so 60 means one hour.
    let exp = (now + Duration::minutes(expires_in_minutes)).timestamp() as usize;
    let claims: TokenClaims = TokenClaims {
        sub: user_id.to_string(),
        exp,
        iat,
    };

    encode(
        &Header::default(),
        &claims,
        &EncodingKey::from_secret(secret),
    )
}

pub fn decode_token<T: Into<String>>(token: T, secret: &[u8]) -> Result<String, HttpError> {
    let decoded = decode::<TokenClaims>(
        &token.into(),
        &DecodingKey::from_secret(secret),
        &Validation::new(Algorithm::HS256),
    );
    match decoded {
        Ok(token) => Ok(token.claims.sub),
        Err(_) => Err(HttpError::new(ErrorMessage::InvalidToken.to_string(), 401)),
    }
}

#[cfg(test)]
mod tests {

    use super::*;

    #[test]
    fn test_create_and_decode_valid_token() {
        let user_id = "user123";
        let secret = b"my-secret-key";

        let token = create_token(user_id, secret, 60).unwrap();
        let decoded_user_id = decode_token(&token, secret).unwrap();

        assert_eq!(decoded_user_id, user_id);
    }

    #[test]
    fn test_create_token_with_empty_user_id() {
        let user_id = "";
        let secret = b"my-secret-key";

        let result = create_token(user_id, secret, 60);

        assert!(result.is_err());
        assert_eq!(
            result.unwrap_err().into_kind(),
            jsonwebtoken::errors::ErrorKind::InvalidSubject
        )
    }

    #[test]
    fn test_decode_invalid_token() {
        let secret = b"my-secret-key";
        let invalid_token = "invalid-token";

        let result = decode_token(invalid_token, secret);

        assert!(result.is_err());
        assert_eq!(
            result.clone().unwrap_err().message,
            ErrorMessage::InvalidToken.to_string()
        );
        assert_eq!(result.unwrap_err().status, 401);
    }

    #[test]
    fn test_decode_expired_token() {
        let secret = b"my-secret-key";
        let expired_token = create_token("user123", secret, -60).unwrap();

        let result = decode_token(expired_token, secret);

        assert!(result.is_err());
        assert_eq!(
            result.clone().unwrap_err().message,
            ErrorMessage::InvalidToken.to_string()
        );
        assert_eq!(result.unwrap_err().status, 401);
    }
}
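Two details explain why test_decode_expired_token passes.

- create_token adds Duration::minutes(...) to the current time. Its third argument is therefore a number of minutes, and -60 yields a token that expired an hour ago.
- jsonwebtoken's Validation tolerates some clock skew when checking exp: a leeway of 60 seconds by default (check the documentation for your crate version). A token that expired only a few seconds earlier would still be accepted.

The std-only sketch below reproduces that arithmetic. claims_window and is_expired are hypothetical helpers, and the comparison is a simplified model of the exp check rather than the crate's code. It uses i64 instead of usize so negative lifetimes are easy to express.

```rust
/// Hypothetical helper: computes (iat, exp) in Unix seconds for a token that
/// lives for `expires_in_minutes` minutes, mirroring create_token's arithmetic.
pub fn claims_window(now_secs: i64, expires_in_minutes: i64) -> (i64, i64) {
    let iat = now_secs;
    let exp = now_secs + expires_in_minutes * 60;
    (iat, exp)
}

/// Simplified model of an `exp` check with leeway: the token counts as expired
/// only once `exp` is more than `leeway_secs` seconds in the past.
pub fn is_expired(exp: i64, now_secs: i64, leeway_secs: i64) -> bool {
    exp < now_secs - leeway_secs
}
```

With a 60-second leeway, a token created with -60 (minutes) is rejected. A token whose exp is only 30 seconds in the past is not.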

Creating a JWT Middleware using Actix-Web Extractor

At this point, we're ready to create the JWT authentication middleware. It performs four steps:

- It extracts the token from the request and validates it.
- It checks the database to confirm that the user associated with the token still exists.
- It inserts the user object into the request extensions so that later middleware and handlers can use it.

This logic mirrors the TypeScript implementation.

Creating the Middleware

First, let’s take a look at the middleware service structure. Disregard the boilerplate code and concentrate solely on the code within the call function. Below is the structure of the middleware that we need to implement.


/// Middleware responsible for handling authentication and user information extraction.
pub struct AuthMiddleware<S> {
    service: Rc<S>,
}

impl<S> Service<ServiceRequest> for AuthMiddleware<S>
where
    S: Service<
            ServiceRequest,
            Response = ServiceResponse<actix_web::body::BoxBody>,
            Error = actix_web::Error,
        > + 'static,
{
    type Response = ServiceResponse<actix_web::body::BoxBody>;
    type Error = actix_web::Error;
    type Future = LocalBoxFuture<'static, Result<Self::Response, actix_web::Error>>;

    /// Polls the readiness of the wrapped service.
    fn poll_ready(&self, ctx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        self.service.poll_ready(ctx)
    }

    /// Handles incoming requests.
    fn call(&self, req: ServiceRequest) -> Self::Future {
        // Attempt to extract token from cookie or authorization header

        // If token is missing, return unauthorized error

        // Decode token and handle errors

        // Handle user extraction and request processing
        async move {
            // `user` and `srv` come from the steps omitted above; see the full
            // implementation below.

            // Insert user information into request extensions
            req.extensions_mut().insert::<User>(user);

            // Call the wrapped service to handle the request
            let res = srv.call(req).await?;
            Ok(res)
        }
        .boxed_local()
    }
}

Moving forward, let’s delve into the actual implementation of the middleware. For the sake of simplicity, this article does not include the implementation details of the functions and structs that the middleware relies on. However, if you’re interested in exploring their implementation, you can find them in the source code available at https://github.com/wpcodevo/complete-restful-api-in-rust.


use actix_web::dev::{Service, ServiceRequest, ServiceResponse};
use actix_web::error::{ErrorInternalServerError, ErrorUnauthorized};
use actix_web::{http, web, HttpMessage};
use futures_util::future::{ready, LocalBoxFuture};
use futures_util::FutureExt;
use std::rc::Rc;
use std::task::{Context, Poll};

use crate::db::UserExt;
use crate::error::{ErrorMessage, ErrorResponse, HttpError};
use crate::models::User;
use crate::{utils, AppState};

/// Middleware responsible for handling authentication and user information extraction.
pub struct AuthMiddleware<S> {
    service: Rc<S>,
}

impl<S> Service<ServiceRequest> for AuthMiddleware<S>
where
    S: Service<
            ServiceRequest,
            Response = ServiceResponse<actix_web::body::BoxBody>,
            Error = actix_web::Error,
        > + 'static,
{
    type Response = ServiceResponse<actix_web::body::BoxBody>;
    type Error = actix_web::Error;
    type Future = LocalBoxFuture<'static, Result<Self::Response, actix_web::Error>>;

    /// Polls the readiness of the wrapped service.
    fn poll_ready(&self, ctx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        self.service.poll_ready(ctx)
    }

    /// Handles incoming requests.
    fn call(&self, req: ServiceRequest) -> Self::Future {
        // Attempt to extract token from cookie or authorization header
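        let token = req
            .cookie("token")
            .map(|c| c.value().to_string())
            .or_else(|| {
                // Only accept headers of the form "Bearer <token>"; malformed or
                // non-ASCII values are treated as missing instead of panicking.
                req.headers()
                    .get(http::header::AUTHORIZATION)
                    .and_then(|h| h.to_str().ok())
                    .and_then(|value| value.strip_prefix("Bearer "))
                    .map(|value| value.to_string())
            });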

        // If token is missing, return unauthorized error
        if token.is_none() {
            let json_error = ErrorResponse {
                status: "fail".to_string(),
                message: ErrorMessage::TokenNotProvided.to_string(),
            };
            return Box::pin(ready(Err(ErrorUnauthorized(json_error))));
        }

        let app_state = req.app_data::<web::Data<AppState>>().unwrap();

        // Decode token and handle errors
        let user_id = match utils::token::decode_token(
            &token.unwrap(),
            app_state.env.jwt_secret.as_bytes(),
        ) {
            Ok(id) => id,
            Err(e) => {
                return Box::pin(ready(Err(ErrorUnauthorized(ErrorResponse {
                    status: "fail".to_string(),
                    message: e.message,
                }))))
            }
        };

        let cloned_app_state = app_state.clone();
        let srv = Rc::clone(&self.service);

        // Handle user extraction and request processing
        async move {
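            // A malformed subject claim is treated as an invalid token, not a panic.
            let user_id = uuid::Uuid::parse_str(user_id.as_str()).map_err(|_| {
                ErrorUnauthorized(ErrorResponse {
                    status: "fail".to_string(),
                    message: ErrorMessage::InvalidToken.to_string(),
                })
            })?;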
            let result = cloned_app_state
                .db_client
                .get_user(Some(user_id.clone()), None, None)
                .await
                .map_err(|e| ErrorInternalServerError(HttpError::server_error(e.to_string())))?;

            let user = result.ok_or(ErrorUnauthorized(ErrorResponse {
                status: "fail".to_string(),
                message: ErrorMessage::UserNoLongerExist.to_string(),
            }))?;

            // Insert user information into request extensions
            req.extensions_mut().insert::<User>(user);

            // Call the wrapped service to handle the request
            let res = srv.call(req).await?;
            Ok(res)
        }
        .boxed_local()
    }
}

Quite a lot is happening in the code above, so let's break it down:

  1. Token Extraction: The middleware’s first task is to extract the authentication token. It does this by searching for the token in either a cookie or an authorization header.
  2. Unauthorized Handling: If the token is missing, the middleware promptly responds with an unauthorized error. This ensures that only authenticated users can proceed, enhancing the application’s security.
  3. Decoding Token: Following token extraction, the middleware decodes the token and handles any potential errors that may arise during this process. This step ensures the token’s validity and authenticity.
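  4. User Extraction: The middleware parses the user ID from the token's subject claim and looks the user up in the database through the db_client stored in the application state. If the query fails, it returns a 500 error; if the user no longer exists, it returns a 401 error.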
  5. Pass Down User Data: With user data in hand, the middleware inserts the user’s data into the request extension. This makes the user data available to subsequent components in the request pipeline.
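The extraction step deserves care. HeaderValue::to_str fails on values that aren't visible ASCII, and split_at(7) panics when the header is shorter than seven bytes. Unwrapping both means a malformed Authorization header triggers a panic.

The std-only sketch below keeps the same precedence (a token cookie first, then a Bearer header) but returns None instead of panicking. extract_token is a hypothetical helper that takes the raw cookie and header strings, so it doesn't depend on actix-web types. It also treats an empty value as missing, which the original code doesn't do.

```rust
/// Hypothetical helper: prefer a non-empty `token` cookie, otherwise fall back
/// to an `Authorization: Bearer <token>` header. Returns None if neither is usable.
pub fn extract_token(cookie: Option<&str>, authorization: Option<&str>) -> Option<String> {
    if let Some(value) = cookie.filter(|v| !v.is_empty()) {
        return Some(value.to_string());
    }
    authorization
        // Reject anything that isn't a Bearer credential (e.g. "Basic ...").
        .and_then(|h| h.strip_prefix("Bearer "))
        .map(str::trim)
        .filter(|t| !t.is_empty())
        .map(str::to_string)
}
```

Inside the middleware, the header string would come from req.headers().get(http::header::AUTHORIZATION).and_then(|h| h.to_str().ok()).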

Creating a Middleware Factory

Now that we've implemented the AuthMiddleware, the next step is to create a factory that implements the Transform trait. The factory is named RequireAuth. Its new_transform method instantiates the middleware service and wraps it in an already-completed future using the ready function.

Here’s the code that illustrates the creation of the RequireAuth factory:


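// Besides the imports shown earlier, the factory needs
// actix_web::dev::Transform and futures_util::future::Ready.

/// Middleware factory for requiring authentication.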
pub struct RequireAuth;

impl<S> Transform<S, ServiceRequest> for RequireAuth
where
    S: Service<
            ServiceRequest,
            Response = ServiceResponse<actix_web::body::BoxBody>,
            Error = actix_web::Error,
        > + 'static,
{
    type Response = ServiceResponse<actix_web::body::BoxBody>;
    type Error = actix_web::Error;
    type Transform = AuthMiddleware<S>;
    type InitError = ();
    type Future = Ready<Result<Self::Transform, Self::InitError>>;

    /// Creates and returns a new AuthMiddleware wrapped in a Result.
    fn new_transform(&self, service: S) -> Self::Future {
        // Wrap the AuthMiddleware instance in a Result and return it.
        ready(Ok(AuthMiddleware {
            service: Rc::new(service),
        }))
    }
}

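If you take a closer look at the RequireAuth implementation, you'll notice that its Response and Error types match those of the middleware service. The associated types specific to the factory are:

- Transform: the AuthMiddleware service it builds.
- InitError.
- Future: a Ready future, because constructing the middleware involves no asynchronous work.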

Using the Middleware Factory

With the RequireAuth factory in place, we can now utilize it to secure specific routes by using the wrap method. This method registers the RequireAuth factory as a middleware for the designated routes. Here’s an illustrative example:


/// Defines a scope for user-related endpoints.
pub fn users_scope() -> Scope {
    web::scope("/api/users")
        .route(
            "",
            web::get()
                .to(get_users)
                // 👇 Apply RequireAuth middleware to this route.
                .wrap(RequireAuth),
        )
        .route(
            "/me",
            web::get()
                .to(get_me)
                // 👇 Apply RequireAuth middleware to this route.
                .wrap(RequireAuth),
        )
}

Writing Unit Tests for the JWT Middleware

With that out of the way, let's write unit tests covering the main scenarios for the JWT middleware:

- a valid token in the Authorization header
- a valid token in a cookie
- a missing token
- an invalid token


#[cfg(test)]
mod tests {
    use actix_web::{cookie::Cookie, get, test, App, HttpResponse};
    use sqlx::{Pool, Postgres};

    use crate::{
        db::DBClient,
        utils::{password, test_utils::get_test_config, token},
    };

    use super::*;

    #[get("/", wrap = "RequireAuth")]
    async fn handler_with_requireauth() -> HttpResponse {
        HttpResponse::Ok().into()
    }

    #[get("/", wrap = "RequireAuth")]
    async fn handler_with_requireonlyadmin() -> HttpResponse {
        HttpResponse::Ok().into()
    }

    #[sqlx::test]
    async fn test_auth_middleware_valid_token(pool: Pool<Postgres>) {
        let db_client = DBClient::new(pool);
        let config = get_test_config();

        let hashed_password = password::hash("password123").unwrap();

        let user = db_client
            .save_user("John", "john@example.com", &hashed_password)
            .await
            .unwrap();

        let token =
            token::create_token(&user.id.to_string(), config.jwt_secret.as_bytes(), 60).unwrap();

        let app = test::init_service(
            App::new()
                .app_data(web::Data::new(AppState {
                    env: config.clone(),
                    db_client,
                }))
                .service(handler_with_requireauth),
        )
        .await;

        let req = test::TestRequest::default()
            .insert_header((http::header::AUTHORIZATION, format!("Bearer {}", token)))
            .to_request();

        let resp = test::call_service(&app, req).await;

        assert_eq!(resp.status(), http::StatusCode::OK);
    }

    #[sqlx::test]
    async fn test_auth_middleware_valid_token_with_cookie(pool: Pool<Postgres>) {
        let db_client = DBClient::new(pool);
        let config = get_test_config();

        let hashed_password = password::hash("password123").unwrap();

        let user = db_client
            .save_user("John", "john@example.com", &hashed_password)
            .await
            .unwrap();

        let app = test::init_service(
            App::new()
                .app_data(web::Data::new(AppState {
                    env: config.clone(),
                    db_client,
                }))
                .service(handler_with_requireauth),
        )
        .await;

        let token =
            token::create_token(&user.id.to_string(), config.jwt_secret.as_bytes(), 60).unwrap();

        let req = test::TestRequest::default()
            .cookie(Cookie::new("token", token))
            .to_request();

        let resp = test::call_service(&app, req).await;

        assert_eq!(resp.status(), http::StatusCode::OK);
    }

    #[sqlx::test]
    async fn test_auth_middleware_missing_token(pool: Pool<Postgres>) {
        let db_client = DBClient::new(pool);
        let config = get_test_config();

        let app = test::init_service(
            App::new()
                .app_data(web::Data::new(AppState {
                    env: config.clone(),
                    db_client,
                }))
                .service(handler_with_requireauth),
        )
        .await;

        let req = test::TestRequest::default().to_request();
        let result = test::try_call_service(&app, req).await.err();

        match result {
            Some(err) => {
                let expected_status = http::StatusCode::UNAUTHORIZED;
                let actual_status = err.as_response_error().status_code();

                assert_eq!(actual_status, expected_status);

                let err_response: ErrorResponse = serde_json::from_str(&err.to_string())
                    .expect("Failed to deserialize JSON string");
                let expected_message = ErrorMessage::TokenNotProvided.to_string();
                assert_eq!(err_response.message, expected_message);
            }
            None => {
                panic!("Service call succeeded, but an error was expected.");
            }
        }
    }

    #[sqlx::test]
    async fn test_auth_middleware_invalid_token(pool: Pool<Postgres>) {
        let db_client = DBClient::new(pool);
        let config = get_test_config();

        let app = test::init_service(
            App::new()
                .app_data(web::Data::new(AppState {
                    env: config.clone(),
                    db_client,
                }))
                .service(handler_with_requireauth),
        )
        .await;

        let req = test::TestRequest::default()
            .insert_header((
                http::header::AUTHORIZATION,
                format!("Bearer {}", "invalid_token"),
            ))
            .to_request();

        let result = test::try_call_service(&app, req).await.err();

        match result {
            Some(err) => {
                let expected_status = http::StatusCode::UNAUTHORIZED;
                let actual_status = err.as_response_error().status_code();

                assert_eq!(actual_status, expected_status);

                let err_response: ErrorResponse = serde_json::from_str(&err.to_string())
                    .expect("Failed to deserialize JSON string");
                let expected_message = ErrorMessage::InvalidToken.to_string();
                assert_eq!(err_response.message, expected_message);
            }
            None => {
                panic!("Service call succeeded, but an error was expected.");
            }
        }
    }
}

Adding JWT Authorization to the Actix-Web Middleware

Our next step is to add authorization to our middleware. This will enable us to grant access to routes based on the user’s role. Fortunately, this step isn’t as complicated as the authentication, which we have already addressed.

Modifying the Authentication Middleware

As mentioned, you can find the implementations of the remaining structs and functions used in this code in the GitHub repository linked to this article. The UserRole enum, however, is worth including here because we'll use it for user authorization.


use chrono::prelude::*;
use serde::{Deserialize, Serialize};

#[derive(Debug, Deserialize, Serialize, Clone, Copy, sqlx::Type, PartialEq)]
#[sqlx(type_name = "user_role", rename_all = "lowercase")]
pub enum UserRole {
    Admin,
    Moderator,
    User,
}

impl UserRole {
    pub fn to_str(&self) -> &str {
        match self {
            UserRole::Admin => "admin",
            UserRole::User => "user",
            UserRole::Moderator => "moderator",
        }
    }
}

Next, we need to incorporate the allowed roles into the AuthMiddleware. Inside the call function, after retrieving the user's information from the database, we check their role against the list of allowed roles. If the role isn't in that list, the middleware returns a 403 Forbidden error with a permission denied message.


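// In addition to the imports shown earlier, this version needs
// actix_web::error::ErrorForbidden and crate::models::UserRole.

/// Middleware responsible for authentication and authorization.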
pub struct AuthMiddleware<S> {
    service: Rc<S>,
    allowed_roles: Rc<Vec<UserRole>>,
}

impl<S> Service<ServiceRequest> for AuthMiddleware<S>
where
    S: Service<
            ServiceRequest,
            Response = ServiceResponse<actix_web::body::BoxBody>,
            Error = actix_web::Error,
        > + 'static,
{
    type Response = ServiceResponse<actix_web::body::BoxBody>;
    type Error = actix_web::Error;
    type Future = LocalBoxFuture<'static, Result<Self::Response, actix_web::Error>>;

    /// Poll the readiness of the wrapped service.
    fn poll_ready(&self, ctx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        self.service.poll_ready(ctx)
    }

    /// Handle the incoming service request.
    fn call(&self, req: ServiceRequest) -> Self::Future {
        // Extract token from the cookie or authorization header.
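        let token = req
            .cookie("token")
            .map(|c| c.value().to_string())
            .or_else(|| {
                // Only accept headers of the form "Bearer <token>"; malformed or
                // non-ASCII values are treated as missing instead of panicking.
                req.headers()
                    .get(http::header::AUTHORIZATION)
                    .and_then(|h| h.to_str().ok())
                    .and_then(|value| value.strip_prefix("Bearer "))
                    .map(|value| value.to_string())
            });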

        // If no token provided, return Unauthorized error.
        if token.is_none() {
            let json_error = ErrorResponse {
                status: "fail".to_string(),
                message: ErrorMessage::TokenNotProvided.to_string(),
            };
            return Box::pin(ready(Err(ErrorUnauthorized(json_error))));
        }

        // Get the app state and decode the token.
        let app_state = req.app_data::<web::Data<AppState>>().unwrap();
        let user_id = match utils::token::decode_token(
            &token.unwrap(),
            app_state.env.jwt_secret.as_bytes(),
        ) {
            Ok(id) => id,
            Err(e) => {
                return Box::pin(ready(Err(ErrorUnauthorized(ErrorResponse {
                    status: "fail".to_string(),
                    message: e.message,
                }))))
            }
        };

        // Clone app state and allowed roles for async closure.
        let cloned_app_state = app_state.clone();
        let allowed_roles = self.allowed_roles.clone();
        let srv = Rc::clone(&self.service);

        async move {
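            // A malformed subject claim is treated as an invalid token, not a panic.
            let user_id = uuid::Uuid::parse_str(user_id.as_str()).map_err(|_| {
                ErrorUnauthorized(ErrorResponse {
                    status: "fail".to_string(),
                    message: ErrorMessage::InvalidToken.to_string(),
                })
            })?;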
            let result = cloned_app_state
                .db_client
                .get_user(Some(user_id.clone()), None, None)
                .await
                .map_err(|e| ErrorInternalServerError(HttpError::server_error(e.to_string())))?;

            let user = result.ok_or(ErrorUnauthorized(ErrorResponse {
                status: "fail".to_string(),
                message: ErrorMessage::UserNoLongerExist.to_string(),
            }))?;

            // Check if user's role matches the required role
            if allowed_roles.contains(&user.role) {
                req.extensions_mut().insert::<User>(user);
                let res = srv.call(req).await?;
                Ok(res)
            } else {
                let json_error = ErrorResponse {
                    status: "fail".to_string(),
                    message: ErrorMessage::PermissionDenied.to_string(),
                };
                Err(ErrorForbidden(json_error))
            }
        }
        .boxed_local()
    }
}
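Stripped of the Actix plumbing, the authorization decision is a membership test on the user's role. The std-only sketch below uses a local copy of UserRole without the sqlx and serde derives. authorize is a hypothetical helper, and returning a bare status code as the error simplifies the ErrorForbidden response the middleware builds.

```rust
/// Local copy of the article's roles, without the sqlx/serde derives.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum UserRole {
    Admin,
    Moderator,
    User,
}

/// Hypothetical helper: Ok(()) if `role` is allowed, otherwise the HTTP status
/// the middleware would respond with (403 Forbidden).
pub fn authorize(role: UserRole, allowed_roles: &[UserRole]) -> Result<(), u16> {
    if allowed_roles.contains(&role) {
        Ok(())
    } else {
        Err(403)
    }
}
```

Note the split between status codes:

- A missing or invalid token, or a user who no longer exists, yields 401 Unauthorized.
- An authenticated user with the wrong role yields 403 Forbidden.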

Modifying the Middleware Factory

Now, let's modify the RequireAuth factory so that each route can specify its permitted roles when it uses the factory. The factory then passes those roles down to the AuthMiddleware. To do this, RequireAuth gets an associated function called allowed_roles, which takes the allowed roles as a vector and returns a configured factory. Here's how it should look:


/// Middleware requirement for authentication and authorization.
pub struct RequireAuth {
    pub allowed_roles: Rc<Vec<UserRole>>,
}

impl RequireAuth {
    /// Create a new instance of `RequireAuth` middleware.
    pub fn allowed_roles(allowed_roles: Vec<UserRole>) -> Self {
        RequireAuth {
            allowed_roles: Rc::new(allowed_roles),
        }
    }
}

impl<S> Transform<S, ServiceRequest> for RequireAuth
where
    S: Service<
            ServiceRequest,
            Response = ServiceResponse<actix_web::body::BoxBody>,
            Error = actix_web::Error,
        > + 'static,
{
    type Response = ServiceResponse<actix_web::body::BoxBody>;
    type Error = actix_web::Error;
    type Transform = AuthMiddleware<S>;
    type InitError = ();
    type Future = Ready<Result<Self::Transform, Self::InitError>>;

    /// Create a new `AuthMiddleware` using the provided service.
    fn new_transform(&self, service: S) -> Self::Future {
        ready(Ok(AuthMiddleware {
            service: Rc::new(service),
            allowed_roles: self.allowed_roles.clone(),
        }))
    }
}
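The Rc<Vec<UserRole>> field is what keeps new_transform cheap: cloning an Rc copies a pointer and increments a reference count rather than duplicating the vector. The sketch below assumes simplified, hypothetical versions of RequireAuth and AuthMiddleware that keep only the role list (the wrapped service is omitted) so the sharing can be observed with Rc::ptr_eq and Rc::strong_count. Rc rather than Arc should be enough here because actix-web builds the App, and therefore its middleware, separately on each worker thread.

```rust
use std::rc::Rc;

#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum UserRole {
    User,
    Moderator,
    Admin,
}

// Hypothetical, trimmed-down versions of the article's types: only the role list is kept.
struct RequireAuth {
    allowed_roles: Rc<Vec<UserRole>>,
}

struct AuthMiddleware {
    allowed_roles: Rc<Vec<UserRole>>,
}

impl RequireAuth {
    fn allowed_roles(allowed_roles: Vec<UserRole>) -> Self {
        RequireAuth { allowed_roles: Rc::new(allowed_roles) }
    }

    // Stands in for `Transform::new_transform`: the Rc is cloned, not the Vec.
    fn new_transform(&self) -> AuthMiddleware {
        AuthMiddleware { allowed_roles: Rc::clone(&self.allowed_roles) }
    }
}

fn main() {
    let factory = RequireAuth::allowed_roles(vec![UserRole::User, UserRole::Moderator, UserRole::Admin]);
    let first = factory.new_transform();
    let second = factory.new_transform();

    // Both middleware instances share the factory's allocation.
    assert!(Rc::ptr_eq(&first.allowed_roles, &second.allowed_roles));
    // One reference held by the factory plus one per middleware instance.
    assert_eq!(Rc::strong_count(&factory.allowed_roles), 3);
    assert!(first.allowed_roles.contains(&UserRole::Admin));
}
```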

Using the Middleware Factory

Now that RequireAuth expects a list of allowed roles, let’s see how to apply it to individual routes. The following example illustrates this usage:


/// Define a scope for user-related endpoints.
pub fn users_scope() -> Scope {
    web::scope("/api/users")
        .route(
            "",
            // Route to get a list of users, accessible to admins only.
            web::get()
                .to(get_users)
                .wrap(RequireAuth::allowed_roles(vec![UserRole::Admin])),
        )
        .route(
            "/me",
            // Route to get the current user's information, accessible to users, moderators, and admins.
            web::get()
                .to(get_me)
                .wrap(RequireAuth::allowed_roles(vec![
                    UserRole::User,
                    UserRole::Moderator,
                    UserRole::Admin,
                ])),
        )
}

Here, only admins can access the /api/users endpoint, while users, moderators, and admins can all access the /api/users/me endpoint.
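Viewed as data, the two wrap calls above define a small access matrix. The following sketch is illustrative only: it encodes that matrix in a hypothetical POLICY table and a can_access helper. It is not how actix-web routes requests, just a quick way to check which role reaches which path.

```rust
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum UserRole {
    User,
    Moderator,
    Admin,
}

// Hypothetical policy table mirroring `users_scope`: each path lists the roles
// passed to `RequireAuth::allowed_roles`.
const POLICY: &[(&str, &[UserRole])] = &[
    ("/api/users", &[UserRole::Admin]),
    ("/api/users/me", &[UserRole::User, UserRole::Moderator, UserRole::Admin]),
];

// Returns None for paths the table does not cover, otherwise whether the role is allowed.
fn can_access(path: &str, role: UserRole) -> Option<bool> {
    POLICY
        .iter()
        .find(|(p, _)| *p == path)
        .map(|(_, roles)| roles.contains(&role))
}

fn main() {
    assert_eq!(can_access("/api/users", UserRole::Admin), Some(true));
    assert_eq!(can_access("/api/users", UserRole::Moderator), Some(false));
    assert_eq!(can_access("/api/users/me", UserRole::User), Some(true));
    assert_eq!(can_access("/api/unknown", UserRole::Admin), None);
}
```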

Writing Unit Tests for the Authorization Logic

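Below are unit tests that focus on the authorization side of the JWT middleware. The first checks that a regular user receives a 403 Forbidden from an admin-only endpoint, and the second checks that an admin receives a 200 OK from the same endpoint: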


#[cfg(test)]
mod tests {
    use actix_web::{get, test, App, HttpResponse};
    use sqlx::{Pool, Postgres};

    use crate::{
        db::DBClient,
        utils::{password, test_utils::get_test_config, token},
    };

    use super::*;

    #[get(
        "/",
        wrap = "RequireAuth::allowed_roles(vec![UserRole::User, UserRole::Moderator, UserRole::Admin])"
    )]
    async fn handler_with_requireauth() -> HttpResponse {
        HttpResponse::Ok().into()
    }

    #[get("/", wrap = "RequireAuth::allowed_roles(vec![UserRole::Admin])")]
    async fn handler_with_requireonlyadmin() -> HttpResponse {
        HttpResponse::Ok().into()
    }

    #[sqlx::test]
    async fn test_auth_middleware_access_admin_only_endpoint_fail(pool: Pool<Postgres>) {
        let db_client = DBClient::new(pool);
        let config = get_test_config();

        let hashed_password = password::hash("password123").unwrap();

        let user = db_client
            .save_user("John", "john@example.com", &hashed_password)
            .await
            .unwrap();

        let token =
            token::create_token(&user.id.to_string(), config.jwt_secret.as_bytes(), 60).unwrap();

        let app = test::init_service(
            App::new()
                .app_data(web::Data::new(AppState {
                    env: config.clone(),
                    db_client,
                }))
                .service(handler_with_requireonlyadmin),
        )
        .await;

        let req = test::TestRequest::default()
            .insert_header((http::header::AUTHORIZATION, format!("Bearer {}", token)))
            .to_request();

        let result = test::try_call_service(&app, req).await.err();

        match result {
            Some(err) => {
                let expected_status = http::StatusCode::FORBIDDEN;
                let actual_status = err.as_response_error().status_code();

                assert_eq!(actual_status, expected_status);

                let err_response: ErrorResponse = serde_json::from_str(&err.to_string())
                    .expect("Failed to deserialize JSON string");
                let expected_message = ErrorMessage::PermissionDenied.to_string();
                assert_eq!(err_response.message, expected_message);
            }
            None => {
                panic!("Service call succeeded, but an error was expected.");
            }
        }
    }

    #[sqlx::test]
    async fn test_auth_middleware_access_admin_only_endpoint_success(pool: Pool<Postgres>) {
        let db_client = DBClient::new(pool.clone());
        let config = get_test_config();

        let hashed_password = password::hash("password123").unwrap();
        let user = db_client
            .save_admin_user("John Doe", "johndoe@gmail.com", &hashed_password)
            .await
            .unwrap();

        let token =
            token::create_token(&user.id.to_string(), config.jwt_secret.as_bytes(), 60).unwrap();

        let app = test::init_service(
            App::new()
                .app_data(web::Data::new(AppState {
                    env: config.clone(),
                    db_client,
                }))
                .service(handler_with_requireonlyadmin),
        )
        .await;

        let req = test::TestRequest::default()
            .insert_header((http::header::AUTHORIZATION, format!("Bearer {}", token)))
            .to_request();

        let resp = test::call_service(&app, req).await;

        assert_eq!(resp.status(), http::StatusCode::OK);
    }
}

Retrieving the User Information

With the JWT authentication and authorization middleware fully implemented, our next task is to retrieve, inside the protected route handler, the user data that the middleware inserted into the request extensions. We have two options: call get() on req.extensions() directly in the handler, or write a custom request extractor. The first option is straightforward, but the second removes repetitive lookup code from each handler, so we’ll cover both.

Using Only the Request Extension

Let’s start by calling req.extensions().get::<User>() directly inside the route handler. This returns an Option<&User>, so we pattern-match on it, filter out sensitive information, and return the result as a JSON object. The None branch should not occur when the route is wrapped by RequireAuth, because the middleware rejects the request before the handler runs if no user is found; handling it anyway avoids a panic if the handler is ever mounted without the middleware.


/// Handler for retrieving the information of the authenticated user.
async fn get_me(req: HttpRequest) -> Result<HttpResponse, HttpError> {
    // Retrieve the user information from the request extensions.
    match req.extensions().get::<User>() {
        Some(user) => {
            // Filter sensitive user data before sending the response.
            let filtered_user = FilterUserDto::filter_user(user);

            // Prepare the response data with the filtered user information.
            let response_data = UserResponseDto {
                status: "success".to_string(),
                data: UserData {
                    user: filtered_user,
                },
            };

            // Respond with the filtered user information in JSON format.
            Ok(HttpResponse::Ok().json(response_data))
        }
        None => {
            // Return an error response if user information is not found in request extensions.
            Err(HttpError::server_error("User not found"))
        }
    }
}
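The reason get::<User>() works is that request extensions behave like a map keyed by type: there is at most one value per type, and looking up a type returns Option<&T>. The sketch below is a simplified, std-only model of that idea built on TypeId and Box<dyn Any>. TypeMap is a hypothetical name, and actix-web's real Extensions type has its own API and implementation details.

```rust
use std::any::{Any, TypeId};
use std::collections::HashMap;

// Simplified model of a type-keyed map: at most one value per type.
#[derive(Default)]
struct TypeMap {
    map: HashMap<TypeId, Box<dyn Any>>,
}

impl TypeMap {
    // Stores `value` under its type, returning any previous value of that type.
    fn insert<T: 'static>(&mut self, value: T) -> Option<T> {
        self.map
            .insert(TypeId::of::<T>(), Box::new(value))
            .and_then(|old| old.downcast::<T>().ok().map(|boxed| *boxed))
    }

    // Looks up the value stored for type `T`, if any.
    fn get<T: 'static>(&self) -> Option<&T> {
        self.map
            .get(&TypeId::of::<T>())
            .and_then(|boxed| boxed.downcast_ref::<T>())
    }
}

#[derive(Debug)]
struct User {
    name: String,
}

fn main() {
    let mut extensions = TypeMap::default();

    // What the middleware does after a successful role check.
    extensions.insert(User { name: "John".to_string() });

    // What `get_me` does: the lookup yields Option<&User>.
    match extensions.get::<User>() {
        Some(user) => println!("found {}", user.name),
        None => println!("no user in extensions"),
    }

    // A type that was never inserted is simply absent.
    assert!(extensions.get::<String>().is_none());
}
```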

Using an Actix-Web Extractor

Let’s explore how to use a request extractor along with the req.extensions() method to extract the user object. Actix makes extensive use of request extractors, where types like Data, Path, and Json implement the FromRequest trait to expose information from the request. We can create our own implementation to simplify the process of extracting the User.


/// Wrapper struct representing an authenticated user.
pub struct Authenticated(User);

impl FromRequest for Authenticated {
    type Error = actix_web::Error;
    type Future = Ready<Result<Self, Self::Error>>;

    /// Implement the `from_request` method to extract and wrap the authenticated user.
    fn from_request(
        req: &actix_web::HttpRequest,
        _payload: &mut actix_web::dev::Payload,
    ) -> Self::Future {
        // Attempt to retrieve the user information from request extensions.
        let value = req.extensions().get::<User>().cloned();

        // Check if the user information was successfully retrieved.
        let result = match value {
            Some(user) => Ok(Authenticated(user)),
            None => Err(ErrorInternalServerError(HttpError::server_error(
                "Authentication Error",
            ))),
        };

        // Return a ready future with the result.
        ready(result)
    }
}

impl std::ops::Deref for Authenticated {
    type Target = User;

    /// Implement the deref method to access the inner User value of Authenticated.
    fn deref(&self) -> &Self::Target {
        &self.0
    }
}

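The from_request method looks up the User in req.extensions(), clones it out of the map, and wraps it in Authenticated. If no user is present, it returns a 500 Internal Server Error instead. Because the lookup is synchronous, the result is returned as an already-completed (ready) future. The Deref implementation lets handlers use an Authenticated value wherever a &User is expected.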
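The two pieces of the extractor, turning a missing user into an error and letting the wrapper stand in for &User via Deref, can be seen in isolation in this std-only sketch. Here User is reduced to two fields, the hypothetical extract function replaces from_request with an Option<User> standing in for the extensions lookup, and display_name is a hypothetical stand-in for a function such as FilterUserDto::filter_user that takes &User.

```rust
use std::ops::Deref;

#[derive(Debug, Clone)]
struct User {
    name: String,
    email: String,
}

// Newtype wrapper, as in the article.
struct Authenticated(User);

impl Deref for Authenticated {
    type Target = User;

    fn deref(&self) -> &User {
        &self.0
    }
}

// Hypothetical stand-in for a function that takes &User.
fn display_name(user: &User) -> String {
    format!("{} <{}>", user.name, user.email)
}

// Mirrors the extraction step: a missing value becomes an error rather than a panic.
fn extract(value: Option<User>) -> Result<Authenticated, &'static str> {
    value.map(Authenticated).ok_or("Authentication Error")
}

fn main() {
    let user = extract(Some(User {
        name: "John".to_string(),
        email: "john@example.com".to_string(),
    }))
    .unwrap();

    // Field access goes through Deref automatically.
    assert_eq!(user.name, "John");
    // &Authenticated coerces to &User at the call site.
    assert_eq!(display_name(&user), "John <john@example.com>");
    // No user in the "extensions" means an error.
    assert!(extract(None).is_err());
}
```

Deref coercion is what lets the handler below pass &user straight to FilterUserDto::filter_user without unwrapping the newtype by hand.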

Next, let’s see how we can use the extractor in our route handler to access the user data. Simply include it in the list of parameters accepted by the route handler function. Here’s an example demonstrating how to use it:


/// The `user` parameter is extracted automatically by the `Authenticated` extractor.
async fn get_me(user: Authenticated) -> Result<HttpResponse, HttpError> {
    // Filter sensitive user data, such as passwords, before sending the response.
    let filtered_user = FilterUserDto::filter_user(&user);

    // Build the response data with the filtered user information.
    let response_data = UserResponseDto {
        status: "success".to_string(),
        data: UserData {
            user: filtered_user,
        },
    };

    // Return a successful HTTP response with the user data.
    Ok(HttpResponse::Ok().json(response_data))
}

Conclusion

And that concludes our journey! Throughout this article, we’ve explored how middleware works in the Actix-Web framework, used it to authenticate and authorize users with JWTs, and taken an extra step to create an extractor that gives route handlers direct access to the authenticated user.
I hope you found the article helpful. If you have any questions or feedback, please feel free to leave a comment. Thank you for reading!