Skip to content

Commit

Permalink
Remove common function to setup servers
Browse files Browse the repository at this point in the history
Signed-off-by: Mateus Devino <mdevino@ibm.com>
  • Loading branch information
mdevino committed Feb 10, 2025
1 parent e0bb62c commit 3acc063
Show file tree
Hide file tree
Showing 3 changed files with 48 additions and 60 deletions.
22 changes: 22 additions & 0 deletions src/config.rs
Original file line number Diff line number Diff line change
Expand Up @@ -355,6 +355,28 @@ impl OrchestratorConfig {
.get(detector_id)
.map(|detector_config| detector_config.chunker_id.clone())
}

pub fn set_generation_port(&mut self, port: u16) {
self.generation
.as_mut()
.map(|generation| generation.service.port = Some(port));
}

pub fn set_chat_generation_port(&mut self, port: u16) {
self.chat_generation
.as_mut()
.map(|chat_generation| chat_generation.service.port = Some(port));
}

pub fn set_chunker_port(&mut self, name: &str, port: u16) {
self.chunkers
.as_mut()
.map(|chunkers| chunkers.get_mut(name).unwrap().service.port = Some(port));
}

pub fn set_detector_port(&mut self, name: &str, port: u16) {
self.detectors.get_mut(name).unwrap().service.port = Some(port);
}
}

/// Applies named TLS config to a service.
Expand Down
43 changes: 0 additions & 43 deletions tests/common/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,14 +15,8 @@
*/

use std::sync::Arc;

use fms_guardrails_orchestr8::config::OrchestratorConfig;
use fms_guardrails_orchestr8::orchestrator::Orchestrator;
use fms_guardrails_orchestr8::server::ServerState;
use mocktail::generate_grpc_server;
use mocktail::mock::MockSet;
use mocktail::server::HttpMockServer;
use rustls::crypto::ring;

generate_grpc_server!(
Expand All @@ -40,40 +34,3 @@ pub const CONFIG_FILE_PATH: &str = "tests/test.config.yaml";
pub fn ensure_global_rustls_state() {
let _ = ring::default_provider().install_default();
}

/// Starts mock servers and adds them to orchestrator configuration.
pub async fn create_orchestrator_shared_state(
detectors: Vec<HttpMockServer>,
chunkers: Vec<(&str, MockChunkersServiceServer)>,
) -> Result<Arc<ServerState>, mocktail::Error> {
let mut config = OrchestratorConfig::load(CONFIG_FILE_PATH).await.unwrap();

for detector_mock_server in detectors {
let _ = detector_mock_server.start().await?;

// assign mock server port to detector config
config
.detectors
.get_mut(detector_mock_server.name())
.unwrap()
.service
.port = Some(detector_mock_server.addr().port());
}

for (chunker_name, chunker_mock_server) in chunkers {
let _ = chunker_mock_server.start().await?;

// assign mock server port to chunker config
config
.chunkers
.as_mut()
.unwrap()
.get_mut(chunker_name)
.unwrap()
.service
.port = Some(chunker_mock_server.addr().port());
}

let orchestrator = Orchestrator::new(config, false).await.unwrap();
Ok(Arc::new(ServerState::new(orchestrator)))
}
43 changes: 26 additions & 17 deletions tests/detection_content.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,24 +15,25 @@
*/

use std::collections::HashMap;
use std::{collections::HashMap, sync::Arc};

use axum_test::TestServer;
use common::{
create_orchestrator_shared_state, ensure_global_rustls_state, MockChunkersServiceServer,
CHUNKER_UNARY_ENDPOINT,
ensure_global_rustls_state, MockChunkersServiceServer, CHUNKER_UNARY_ENDPOINT, CONFIG_FILE_PATH,
};
use fms_guardrails_orchestr8::{
clients::{
chunker::MODEL_ID_HEADER_NAME as CHUNKER_MODEL_ID_HEADER_NAME,
detector::{ContentAnalysisRequest, ContentAnalysisResponse},
},
config::OrchestratorConfig,
models::{DetectorParams, TextContentDetectionHttpRequest, TextContentDetectionResult},
orchestrator::Orchestrator,
pb::{
caikit::runtime::chunkers::ChunkerTokenizationTaskRequest,
caikit_data_model::nlp::{Token, TokenizationResults},
},
server::get_app,
server::{get_app, ServerState},
};
use hyper::StatusCode;
use mocktail::prelude::*;
Expand Down Expand Up @@ -81,11 +82,15 @@ async fn test_single_detection_whole_doc() {
),
);

// Setup orchestrator and detector servers
// Start orchestrator and detector servers.
let mock_detector_server = HttpMockServer::new(detector_name, mocks).unwrap();
let shared_state = create_orchestrator_shared_state(vec![mock_detector_server], vec![])
.await
.unwrap();
let _ = mock_detector_server.start().await;

let mut config = OrchestratorConfig::load(CONFIG_FILE_PATH).await.unwrap();
config.set_detector_port(detector_name, mock_detector_server.addr().port());
let orchestrator = Orchestrator::new(config, false).await.unwrap();
let shared_state = Arc::new(ServerState::new(orchestrator));

let server = TestServer::new(get_app(shared_state)).unwrap();

// Make orchestrator call
Expand Down Expand Up @@ -156,8 +161,8 @@ async fn test_single_detection_sentence_chunker() {

// Add detector mock
let detector_name = DETECTOR_NAME_ANGLE_BRACKETS_SENTENCE;
let mut mocks = MockSet::new();
mocks.insert(
let mut detector_mocks = MockSet::new();
detector_mocks.insert(
MockPath::new(Method::POST, ENDPOINT_DETECTOR),
Mock::new(
MockRequest::json(ContentAnalysisRequest {
Expand Down Expand Up @@ -185,13 +190,17 @@ async fn test_single_detection_sentence_chunker() {

// Start orchestrator, chunker and detector servers.
let mock_chunker_server = MockChunkersServiceServer::new(chunker_mocks).unwrap();
let mock_detector_server = HttpMockServer::new(detector_name, mocks).unwrap();
let shared_state = create_orchestrator_shared_state(
vec![mock_detector_server],
vec![(chunker_id, mock_chunker_server)],
)
.await
.unwrap();
let _ = mock_chunker_server.start().await;

let mock_detector_server = HttpMockServer::new(detector_name, detector_mocks).unwrap();
let _ = mock_detector_server.start().await;

let mut config = OrchestratorConfig::load(CONFIG_FILE_PATH).await.unwrap();
config.set_chunker_port(chunker_id, mock_chunker_server.addr().port());
config.set_detector_port(detector_name, mock_detector_server.addr().port());
let orchestrator = Orchestrator::new(config, false).await.unwrap();
let shared_state = Arc::new(ServerState::new(orchestrator));

let server = TestServer::new(get_app(shared_state)).unwrap();

// Make orchestrator call
Expand Down

0 comments on commit 3acc063

Please sign in to comment.