From 3acc0637a23e23dbed3f7193f38fb85b1c0e45cc Mon Sep 17 00:00:00 2001 From: Mateus Devino Date: Mon, 10 Feb 2025 09:42:12 -0300 Subject: [PATCH] Remove common function to setup servers Signed-off-by: Mateus Devino --- src/config.rs | 22 +++++++++++++++++++ tests/common/mod.rs | 43 -------------------------------------- tests/detection_content.rs | 43 +++++++++++++++++++++++--------------- 3 files changed, 48 insertions(+), 60 deletions(-) diff --git a/src/config.rs b/src/config.rs index 828e5219..66fd9b76 100644 --- a/src/config.rs +++ b/src/config.rs @@ -355,6 +355,28 @@ impl OrchestratorConfig { .get(detector_id) .map(|detector_config| detector_config.chunker_id.clone()) } + + pub fn set_generation_port(&mut self, port: u16) { + self.generation + .as_mut() + .map(|generation| generation.service.port = Some(port)); + } + + pub fn set_chat_generation_port(&mut self, port: u16) { + self.chat_generation + .as_mut() + .map(|chat_generation| chat_generation.service.port = Some(port)); + } + + pub fn set_chunker_port(&mut self, name: &str, port: u16) { + self.chunkers + .as_mut() + .map(|chunkers| chunkers.get_mut(name).unwrap().service.port = Some(port)); + } + + pub fn set_detector_port(&mut self, name: &str, port: u16) { + self.detectors.get_mut(name).unwrap().service.port = Some(port); + } } /// Applies named TLS config to a service. diff --git a/tests/common/mod.rs b/tests/common/mod.rs index 59754e50..1b176839 100644 --- a/tests/common/mod.rs +++ b/tests/common/mod.rs @@ -15,14 +15,8 @@ */ -use std::sync::Arc; - -use fms_guardrails_orchestr8::config::OrchestratorConfig; -use fms_guardrails_orchestr8::orchestrator::Orchestrator; -use fms_guardrails_orchestr8::server::ServerState; use mocktail::generate_grpc_server; use mocktail::mock::MockSet; -use mocktail::server::HttpMockServer; use rustls::crypto::ring; generate_grpc_server!( @@ -40,40 +34,3 @@ pub const CONFIG_FILE_PATH: &str = "tests/test.config.yaml"; pub fn ensure_global_rustls_state() { let _ = ring::default_provider().install_default(); } - -/// Starts mock servers and adds them to orchestrator configuration. -pub async fn create_orchestrator_shared_state( - detectors: Vec, - chunkers: Vec<(&str, MockChunkersServiceServer)>, -) -> Result, mocktail::Error> { - let mut config = OrchestratorConfig::load(CONFIG_FILE_PATH).await.unwrap(); - - for detector_mock_server in detectors { - let _ = detector_mock_server.start().await?; - - // assign mock server port to detector config - config - .detectors - .get_mut(detector_mock_server.name()) - .unwrap() - .service - .port = Some(detector_mock_server.addr().port()); - } - - for (chunker_name, chunker_mock_server) in chunkers { - let _ = chunker_mock_server.start().await?; - - // assign mock server port to chunker config - config - .chunkers - .as_mut() - .unwrap() - .get_mut(chunker_name) - .unwrap() - .service - .port = Some(chunker_mock_server.addr().port()); - } - - let orchestrator = Orchestrator::new(config, false).await.unwrap(); - Ok(Arc::new(ServerState::new(orchestrator))) -} diff --git a/tests/detection_content.rs b/tests/detection_content.rs index bd1e6e9f..895ce569 100644 --- a/tests/detection_content.rs +++ b/tests/detection_content.rs @@ -15,24 +15,25 @@ */ -use std::collections::HashMap; +use std::{collections::HashMap, sync::Arc}; use axum_test::TestServer; use common::{ - create_orchestrator_shared_state, ensure_global_rustls_state, MockChunkersServiceServer, - CHUNKER_UNARY_ENDPOINT, + ensure_global_rustls_state, MockChunkersServiceServer, CHUNKER_UNARY_ENDPOINT, CONFIG_FILE_PATH, }; use fms_guardrails_orchestr8::{ clients::{ chunker::MODEL_ID_HEADER_NAME as CHUNKER_MODEL_ID_HEADER_NAME, detector::{ContentAnalysisRequest, ContentAnalysisResponse}, }, + config::OrchestratorConfig, models::{DetectorParams, TextContentDetectionHttpRequest, TextContentDetectionResult}, + orchestrator::Orchestrator, pb::{ caikit::runtime::chunkers::ChunkerTokenizationTaskRequest, caikit_data_model::nlp::{Token, TokenizationResults}, }, - server::get_app, + server::{get_app, ServerState}, }; use hyper::StatusCode; use mocktail::prelude::*; @@ -81,11 +82,15 @@ async fn test_single_detection_whole_doc() { ), ); - // Setup orchestrator and detector servers + // Start orchestrator and detector servers. let mock_detector_server = HttpMockServer::new(detector_name, mocks).unwrap(); - let shared_state = create_orchestrator_shared_state(vec![mock_detector_server], vec![]) - .await - .unwrap(); + let _ = mock_detector_server.start().await; + + let mut config = OrchestratorConfig::load(CONFIG_FILE_PATH).await.unwrap(); + config.set_detector_port(detector_name, mock_detector_server.addr().port()); + let orchestrator = Orchestrator::new(config, false).await.unwrap(); + let shared_state = Arc::new(ServerState::new(orchestrator)); + let server = TestServer::new(get_app(shared_state)).unwrap(); // Make orchestrator call @@ -156,8 +161,8 @@ async fn test_single_detection_sentence_chunker() { // Add detector mock let detector_name = DETECTOR_NAME_ANGLE_BRACKETS_SENTENCE; - let mut mocks = MockSet::new(); - mocks.insert( + let mut detector_mocks = MockSet::new(); + detector_mocks.insert( MockPath::new(Method::POST, ENDPOINT_DETECTOR), Mock::new( MockRequest::json(ContentAnalysisRequest { @@ -185,13 +190,17 @@ async fn test_single_detection_sentence_chunker() { // Start orchestrator, chunker and detector servers. let mock_chunker_server = MockChunkersServiceServer::new(chunker_mocks).unwrap(); - let mock_detector_server = HttpMockServer::new(detector_name, mocks).unwrap(); - let shared_state = create_orchestrator_shared_state( - vec![mock_detector_server], - vec![(chunker_id, mock_chunker_server)], - ) - .await - .unwrap(); + let _ = mock_chunker_server.start().await; + + let mock_detector_server = HttpMockServer::new(detector_name, detector_mocks).unwrap(); + let _ = mock_detector_server.start().await; + + let mut config = OrchestratorConfig::load(CONFIG_FILE_PATH).await.unwrap(); + config.set_chunker_port(chunker_id, mock_chunker_server.addr().port()); + config.set_detector_port(detector_name, mock_detector_server.addr().port()); + let orchestrator = Orchestrator::new(config, false).await.unwrap(); + let shared_state = Arc::new(ServerState::new(orchestrator)); + let server = TestServer::new(get_app(shared_state)).unwrap(); // Make orchestrator call