Skip to content

Commit

Permalink
reuse response schemas as well as body schemas
Browse files Browse the repository at this point in the history
  • Loading branch information
hawkw committed Dec 3, 2024
1 parent c86d368 commit de6933c
Show file tree
Hide file tree
Showing 14 changed files with 443 additions and 1,169 deletions.
165 changes: 118 additions & 47 deletions dropshot/src/api_description.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
use crate::extractor::RequestExtractor;
use crate::handler::HttpHandlerFunc;
use crate::handler::HttpResponse;
use crate::handler::HttpResponseContent;
use crate::handler::HttpResponseError;
use crate::handler::HttpRouteHandler;
use crate::handler::RouteHandler;
Expand Down Expand Up @@ -53,7 +52,7 @@ pub struct ApiEndpoint<Context: ServerContext> {
pub parameters: Vec<ApiEndpointParameter>,
pub body_content_type: ApiEndpointBodyContentType,
pub response: ApiEndpointResponse,
pub error: ApiEndpointErrorResponse,
pub error: Option<ApiEndpointErrorResponse>,
pub summary: Option<String>,
pub description: Option<String>,
pub tags: Vec<String>,
Expand Down Expand Up @@ -82,9 +81,7 @@ impl<'a, Context: ServerContext> ApiEndpoint<Context> {
.expect("unsupported mime type");
let func_parameters = FuncParams::metadata(body_content_type.clone());
let response = ResponseType::response_metadata();
let error = ApiEndpointErrorResponse {
schema: <HandlerType::Error>::content_metadata(),
};
let error = ApiEndpointErrorResponse::for_type::<HandlerType::Error>();
ApiEndpoint {
operation_id,
handler: HttpRouteHandler::new(handler),
Expand Down Expand Up @@ -185,9 +182,7 @@ impl<'a> ApiEndpoint<StubContext> {
.expect("unsupported mime type");
let func_parameters = FuncParams::metadata(body_content_type.clone());
let response = <ResultType::Response>::response_metadata();
let error = ApiEndpointErrorResponse {
schema: <ResultType::Error>::content_metadata(),
};
let error = ApiEndpointErrorResponse::for_type::<ResultType::Error>();
let handler = StubRouteHandler::new_with_name(&operation_id);
ApiEndpoint {
operation_id,
Expand Down Expand Up @@ -346,9 +341,17 @@ pub struct ApiEndpointResponse {
}

/// Metadata for an API endpoint's error response type.
#[derive(Debug, Default)]
#[derive(Debug)]
pub struct ApiEndpointErrorResponse {
pub(crate) schema: Option<ApiSchemaGenerator>,
schema: ApiSchemaGenerator,
type_name: &'static str,
}

impl ApiEndpointErrorResponse {
fn for_type<T: HttpResponseError>() -> Option<Self> {
let schema = T::content_metadata()?;
Some(Self { schema, type_name: std::any::type_name::<T>() })
}
}

/// Wrapper for both dynamically generated and pre-generated schemas.
Expand Down Expand Up @@ -685,6 +688,28 @@ impl<Context: ServerContext> ApiDescription<Context> {
let mut definitions =
indexmap::IndexMap::<String, schemars::schema::Schema>::new();

// A response object generated for an endpoint's error response. These
// are emitted in the top-level `components.responses` map in the
// OpenAPI document, so that multiple endpoints that return the same
// Rust error type can share error response schemas.
struct ErrorResponse {
response: openapiv3::Response,
name: String,
reference: openapiv3::ReferenceOr<openapiv3::Response>,
}
let mut error_responses =
indexmap::IndexMap::<&str, ErrorResponse>::new();
// In the event that there are multiple Rust error types with the same
// name (e.g., both named 'Error'), we must disambiguate the name of the
// response by appending the number of occurances of that name. This is
// the mechanism as how `schemars` will disambiguate colliding schema
// names. However, we must implement our own version of this for
// response schemas, as we cannot simply use the response body's schema
// name: the body's schema may be a static, non-referenceable schema, so
// there isn't guaranteed to be a schema name we can reuse for the
// response object.
let mut error_response_names = HashMap::<&str, usize>::new();

for (path, method, endpoint) in self.router.endpoints(Some(version)) {
if !endpoint.visible {
continue;
Expand Down Expand Up @@ -939,43 +964,81 @@ impl<Context: ServerContext> ApiDescription<Context> {

// If the endpoint defines an error type, emit that for
// the 4xx and 5xx responses.
if let Some(ref schema) = endpoint.error.schema {
let error_schema = match schema {
ApiSchemaGenerator::Gen { ref name, ref schema } => {
j2oas_schema(Some(&name()), &schema(&mut generator))
}
ApiSchemaGenerator::Static {
ref schema,
ref dependencies,
} => {
definitions.extend(dependencies.clone());
j2oas_schema(None, &schema)
}
};
let mut content = indexmap::IndexMap::new();
content.insert(
CONTENT_TYPE_JSON.to_string(),
openapiv3::MediaType {
schema: Some(error_schema),
..Default::default()
},
);
operation.responses.responses.insert(
openapiv3::StatusCode::Range(4),
openapiv3::ReferenceOr::Item(openapiv3::Response {
description: "client error".to_string(),
content: content.clone(),
..Default::default()
}),
);
operation.responses.responses.insert(
openapiv3::StatusCode::Range(5),
openapiv3::ReferenceOr::Item(openapiv3::Response {
description: "server error".to_string(),
content: content.clone(),
..Default::default()
}),
);
if let Some(ApiEndpointErrorResponse { ref schema, type_name }) =
endpoint.error
{
let ErrorResponse { reference, .. } =
// If a response object for this error type has already been
// generated, use that; otherwise, we'll generate it now.
error_responses.entry(type_name).or_insert_with(|| {
let error_schema = match schema {
ApiSchemaGenerator::Gen {
ref name,
ref schema,
} => j2oas_schema(
Some(&name()),
&schema(&mut generator),
),
ApiSchemaGenerator::Static {
ref schema,
ref dependencies,
} => {
definitions.extend(dependencies.clone());
j2oas_schema(None, &schema)
}
};

// If multiple distinct Rust error types with the same
// type name occur in the API, disambigate the response
// schemas by appending a number to each distinct error
// type.
//
// This is a bit ugly, but fortunately, it won't happen
// *too* often, and at least it's consistent with
// schemars' name disambiguation.
let type_name = type_name
.split("::")
.last()
.expect("type name must not be an empty string");
let name = {
let num = error_response_names
.entry(type_name)
.and_modify(|num| *num += 1)
.or_insert(1);
if *num <= 1 {
type_name.to_string()
} else {
format!("{type_name}{num}")
}
};

let mut content = indexmap::IndexMap::new();
content.insert(
CONTENT_TYPE_JSON.to_string(),
openapiv3::MediaType {
schema: Some(error_schema),
..Default::default()
},
);
let response = openapiv3::Response {
description: type_name.to_string(),
content: content.clone(),
..Default::default()
};
let reference = openapiv3::ReferenceOr::Reference {
reference: format!("#/components/responses/{name}"),
};

ErrorResponse { name, reference, response }
});
operation
.responses
.responses
.insert(openapiv3::StatusCode::Range(4), reference.clone());
operation
.responses
.responses
.insert(openapiv3::StatusCode::Range(5), reference.clone());
}

if let Some(code) = &endpoint.response.success {
Expand Down Expand Up @@ -1014,6 +1077,14 @@ impl<Context: ServerContext> ApiDescription<Context> {
}
});

// Generate error responses.
let responses = &mut components.responses;
for (_, ErrorResponse { name, response, .. }) in error_responses {
let prev =
responses.insert(name, openapiv3::ReferenceOr::Item(response));
assert_eq!(prev, None, "error response names must be unique")
}

openapi
}

Expand Down
3 changes: 1 addition & 2 deletions dropshot/src/router.rs
Original file line number Diff line number Diff line change
Expand Up @@ -828,7 +828,6 @@ mod test {
use super::HttpRouter;
use super::PathSegment;
use crate::api_description::ApiEndpointBodyContentType;
use crate::api_description::ApiEndpointErrorResponse;
use crate::api_description::ApiEndpointVersions;
use crate::from_map::from_map;
use crate::router::VariableValue;
Expand Down Expand Up @@ -879,7 +878,7 @@ mod test {
parameters: vec![],
body_content_type: ApiEndpointBodyContentType::default(),
response: ApiEndpointResponse::default(),
error: ApiEndpointErrorResponse::default(),
error: None,
summary: None,
description: None,
tags: vec![],
Expand Down
Loading

0 comments on commit de6933c

Please sign in to comment.