Skip to content

Commit

Permalink
add the token middleware
Browse files Browse the repository at this point in the history
  • Loading branch information
huangcheng committed Dec 16, 2023
1 parent ace3e2c commit 8c53437
Show file tree
Hide file tree
Showing 8 changed files with 117 additions and 80 deletions.
76 changes: 2 additions & 74 deletions src/middlewares.rs
Original file line number Diff line number Diff line change
@@ -1,74 +1,2 @@
use jsonwebtoken::{decode, Algorithm, DecodingKey, Validation};
use rocket::http::Status;
use rocket::request::{FromRequest, Outcome};
use rocket_db_pools::deadpool_redis::redis::AsyncCommands;

use crate::config::Config;
use crate::{Claims, RedisDb};

pub struct JwtMiddleware {
pub username: String,
}

#[derive(Debug)]
pub enum JwtError {
ConfigError,
CacheError,
MissingToken,
InvalidToken,
ExpiredToken,
}

#[rocket::async_trait]
impl<'r> FromRequest<'r> for JwtMiddleware {
type Error = JwtError;

async fn from_request(request: &'r rocket::Request<'_>) -> Outcome<Self, Self::Error> {
let config = match request.rocket().figment().extract::<Config>() {
Ok(config) => config,
Err(_) => return Outcome::Error((Status::InternalServerError, JwtError::ConfigError)),
};

let token = match request.headers().get_one("Authorization") {
Some(token) => token,
None => return Outcome::Error((Status::Unauthorized, JwtError::MissingToken)),
};

let token = match token.strip_prefix("Bearer ") {
Some(token) => token,
None => return Outcome::Error((Status::Unauthorized, JwtError::MissingToken)),
};

let is_in_white_list: &Option<bool> = request
.local_cache_async(async {
let redis = request.guard::<&RedisDb>().await.succeeded()?;
let mut connection = redis.get().await.ok()?;

let result = connection.exists(token).await.ok()?;

Some(result)
})
.await;

if is_in_white_list.is_none() {
return Outcome::Error((Status::Unauthorized, JwtError::CacheError));
}

if *is_in_white_list == Some(false) {
return Outcome::Error((Status::Unauthorized, JwtError::InvalidToken));
}

let token_data = match decode::<Claims>(
token,
&DecodingKey::from_secret(config.jwt.secret.as_bytes()),
&Validation::new(Algorithm::HS256),
) {
Ok(token) => token,
Err(_) => return Outcome::Error((Status::Unauthorized, JwtError::InvalidToken)),
};

let username = token_data.claims.sub.clone();

Outcome::Success(JwtMiddleware { username })
}
}
pub mod jwt;
pub mod token;
74 changes: 74 additions & 0 deletions src/middlewares/jwt.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
use jsonwebtoken::{decode, Algorithm, DecodingKey, Validation};
use rocket::http::Status;
use rocket::request::{FromRequest, Outcome};
use rocket_db_pools::deadpool_redis::redis::AsyncCommands;

use crate::config::Config;
use crate::{Claims, RedisDb};

pub struct JwtMiddleware {
pub username: String,
}

#[derive(Debug)]
pub enum JwtError {
ConfigError,
CacheError,
MissingToken,
InvalidToken,
ExpiredToken,
}

#[rocket::async_trait]
impl<'r> FromRequest<'r> for JwtMiddleware {
type Error = JwtError;

async fn from_request(request: &'r rocket::Request<'_>) -> Outcome<Self, Self::Error> {
let config = match request.rocket().figment().extract::<Config>() {
Ok(config) => config,
Err(_) => return Outcome::Error((Status::InternalServerError, JwtError::ConfigError)),
};

let token = match request.headers().get_one("Authorization") {
Some(token) => token,
None => return Outcome::Error((Status::Unauthorized, JwtError::MissingToken)),
};

let token = match token.strip_prefix("Bearer ") {
Some(token) => token,
None => return Outcome::Error((Status::Unauthorized, JwtError::MissingToken)),
};

let is_in_white_list: &Option<bool> = request
.local_cache_async(async {
let redis = request.guard::<&RedisDb>().await.succeeded()?;
let mut connection = redis.get().await.ok()?;

let result = connection.exists(token).await.ok()?;

Some(result)
})
.await;

if is_in_white_list.is_none() {
return Outcome::Error((Status::Unauthorized, JwtError::CacheError));
}

if *is_in_white_list == Some(false) {
return Outcome::Error((Status::Unauthorized, JwtError::InvalidToken));
}

let token_data = match decode::<Claims>(
token,
&DecodingKey::from_secret(config.jwt.secret.as_bytes()),
&Validation::new(Algorithm::HS256),
) {
Ok(token) => token,
Err(_) => return Outcome::Error((Status::Unauthorized, JwtError::InvalidToken)),
};

let username = token_data.claims.sub.clone();

Outcome::Success(JwtMiddleware { username })
}
}
33 changes: 33 additions & 0 deletions src/middlewares/token.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
use rocket::http::Status;
use rocket::request::{FromRequest, Outcome};

pub struct TokenMiddleware {
pub token: String,
}

#[derive(Debug)]
pub enum TokenError {
MissingToken,
InvalidToken,
}

#[rocket::async_trait]
impl<'r> FromRequest<'r> for TokenMiddleware {
type Error = TokenError;

async fn from_request(request: &'r rocket::Request<'_>) -> Outcome<Self, Self::Error> {
let token = match request.headers().get_one("Authorization") {
Some(token) => token,
None => return Outcome::Error((Status::Unauthorized, TokenError::MissingToken)),
};

let token = match token.strip_prefix("Bearer ") {
Some(token) => token,
None => return Outcome::Error((Status::Unauthorized, TokenError::InvalidToken)),
};

Outcome::Success(Self {
token: token.to_string(),
})
}
}
2 changes: 1 addition & 1 deletion src/routes/auth.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ use rocket_db_pools::deadpool_redis::redis::AsyncCommands;
use rocket_db_pools::Connection;

use crate::config::Config;
use crate::middlewares::JwtMiddleware;
use crate::middlewares::jwt::JwtMiddleware;
use crate::response::auth::Logout;
use crate::state::AppState;
use crate::{handlers, request, response, MySQLDb, RedisDb};
Expand Down
2 changes: 1 addition & 1 deletion src/routes/category.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ use crate::handlers::category::{
self, add_category, delete_category, get_categories, sort_categories, sort_category_sites,
update_category,
};
use crate::middlewares::JwtMiddleware;
use crate::middlewares::jwt::JwtMiddleware;
use crate::request::category::{CreateCategory, SortCategory, UpdateCategory};
use crate::response::category::Category;
use crate::response::site::Site;
Expand Down
2 changes: 1 addition & 1 deletion src/routes/site.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ use rocket_db_pools::Connection;
use crate::config::Config;
use crate::handlers::site;
use crate::handlers::site::get_sites;
use crate::middlewares::JwtMiddleware;
use crate::middlewares::jwt::JwtMiddleware;
use crate::request::site::{CreateSite, UpdateSite};
use crate::response::site::SiteWithCategory;
use crate::response::WithTotal;
Expand Down
2 changes: 1 addition & 1 deletion src/routes/upload.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ use rocket::{post, FromForm, State};

use crate::config::Config;
use crate::handlers;
use crate::middlewares::JwtMiddleware;
use crate::middlewares::jwt::JwtMiddleware;

#[derive(FromForm)]
pub struct Upload<'r> {
Expand Down
6 changes: 4 additions & 2 deletions src/routes/user.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,8 @@ use rocket_db_pools::Connection;

use crate::config::Config;
use crate::handlers::user::{get_user, update_user, update_user_password};
use crate::middlewares::JwtMiddleware;
use crate::middlewares::jwt::JwtMiddleware;
use crate::middlewares::token::TokenMiddleware;
use crate::request::user::{UpdatePassword, UpdateUser};
use crate::response::auth::Logout;
use crate::response::user::User;
Expand Down Expand Up @@ -67,6 +68,7 @@ pub async fn update_password<'r>(
password: Json<UpdatePassword<'r>>,
mut db: Connection<MySQLDb>,
mut cache: Connection<RedisDb>,
token: TokenMiddleware,
_jwt: JwtMiddleware,
) -> Result<Logout, Status> {
update_user_password(username, password.deref(), &mut db)
Expand All @@ -77,7 +79,7 @@ pub async fn update_password<'r>(
e.status()
})?;

cache.del(String::from(username)).await.map_err(|e| {
cache.del(token.token).await.map_err(|e| {
error!("{}", e);

Status::InternalServerError
Expand Down

0 comments on commit 8c53437

Please sign in to comment.