From 8b4df027b15ca70a8f445855a7aa7c252bb154f9 Mon Sep 17 00:00:00 2001 From: Mateus Devino Date: Fri, 3 Jan 2025 11:27:34 -0300 Subject: [PATCH 001/117] Refactor common integration test code Signed-off-by: Mateus Devino --- Cargo.lock | 736 ++++++++++++++++++++++++------------------- tests/canary_test.rs | 27 +- tests/common/mod.rs | 23 ++ 3 files changed, 434 insertions(+), 352 deletions(-) create mode 100644 tests/common/mod.rs diff --git a/Cargo.lock b/Cargo.lock index f247ccd1..c3973b1f 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -4,9 +4,9 @@ version = 4 [[package]] name = "addr2line" -version = "0.24.1" +version = "0.24.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f5fb1d8e4442bd405fdfd1dacb42792696b0cf9cb15882e5d097b742a676d375" +checksum = "dfbe277e56a376000877090da837660b4427aad530e3028d44e0bffe4f89a1c1" dependencies = [ "gimli", ] @@ -28,9 +28,9 @@ dependencies = [ [[package]] name = "anstream" -version = "0.6.15" +version = "0.6.18" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "64e15c1ab1f89faffbf04a634d5e1962e9074f2741eef6d97f3c4e322426d526" +checksum = "8acc5369981196006228e28809f761875c0327210a891e941f4c683b3a99529b" dependencies = [ "anstyle", "anstyle-parse", @@ -43,43 +43,44 @@ dependencies = [ [[package]] name = "anstyle" -version = "1.0.8" +version = "1.0.10" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1bec1de6f59aedf83baf9ff929c98f2ad654b97c9510f4e70cf6f661d49fd5b1" +checksum = "55cc3b69f167a1ef2e161439aa98aed94e6028e5f9a59be9a6ffb47aef1651f9" [[package]] name = "anstyle-parse" -version = "0.2.5" +version = "0.2.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "eb47de1e80c2b463c735db5b217a0ddc39d612e7ac9e2e96a5aed1f57616c1cb" +checksum = "3b2d16507662817a6a20a9ea92df6652ee4f94f914589377d69f3b21bc5798a9" dependencies = [ "utf8parse", ] [[package]] name = "anstyle-query" -version = "1.1.1" +version = "1.1.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6d36fc52c7f6c869915e99412912f22093507da8d9e942ceaf66fe4b7c14422a" +checksum = "79947af37f4177cfead1110013d678905c37501914fba0efea834c3fe9a8d60c" dependencies = [ - "windows-sys 0.52.0", + "windows-sys 0.59.0", ] [[package]] name = "anstyle-wincon" -version = "3.0.4" +version = "3.0.7" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5bf74e1b6e971609db8ca7a9ce79fd5768ab6ae46441c572e46cf596f59e57f8" +checksum = "ca3534e77181a9cc07539ad51f2141fe32f6c3ffd4df76db8ad92346b003ae4e" dependencies = [ "anstyle", - "windows-sys 0.52.0", + "once_cell", + "windows-sys 0.59.0", ] [[package]] name = "anyhow" -version = "1.0.95" +version = "1.0.96" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "34ac096ce696dc2fcabef30516bb13c0a68a11d30131d3df6f04711467681b04" +checksum = "6b964d184e89d9b6b67dd2715bc8e74cf3107fb2b529990c90cf517326150bf4" [[package]] name = "assert-json-diff" @@ -93,9 +94,9 @@ dependencies = [ [[package]] name = "async-stream" -version = "0.3.5" +version = "0.3.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "cd56dd203fef61ac097dd65721a419ddccb106b2d2b70ba60a6b529f03961a51" +checksum = "0b5a71a6f37880a80d1d7f19efd781e4b5de42c88f0722cc13bcb6cc2cfe8476" dependencies = [ "async-stream-impl", "futures-core", @@ -104,9 +105,9 @@ dependencies = [ [[package]] name = "async-stream-impl" -version = "0.3.5" +version = "0.3.6" source = 
"registry+https://github.com/rust-lang/crates.io-index" -checksum = "16e62a023e7c117e27523144c5d2459f4397fcc3cab0085af8e2224f643a0193" +checksum = "c7c24de15d275a1ecfd47a380fb4d5ec9bfe0933f309ed5e705b775596a3574d" dependencies = [ "proc-macro2", "quote", @@ -115,9 +116,9 @@ dependencies = [ [[package]] name = "async-trait" -version = "0.1.85" +version = "0.1.86" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3f934833b4b7233644e5848f235df3f57ed8c80f1528a26c3dfa13d2147fa056" +checksum = "644dd749086bf3771a2fbc5f256fdb982d53f011c7d5d560304eafeecebce79d" dependencies = [ "proc-macro2", "quote", @@ -144,9 +145,9 @@ checksum = "ace50bade8e6234aa140d9a2f552bbee1db4d353f69b8217bc503490fc1a9f26" [[package]] name = "aws-lc-rs" -version = "1.12.1" +version = "1.12.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1ea835662a0af02443aa1396d39be523bbf8f11ee6fad20329607c480bea48c3" +checksum = "4cd755adf9707cf671e31d944a189be3deaaeee11c8bc1d669bb8022ac90fbd0" dependencies = [ "aws-lc-sys", "paste", @@ -155,9 +156,9 @@ dependencies = [ [[package]] name = "aws-lc-sys" -version = "0.25.0" +version = "0.26.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "71b2ddd3ada61a305e1d8bb6c005d1eaa7d14d903681edfc400406d523a9b491" +checksum = "0f9dd2e03ee80ca2822dd6ea431163d2ef259f2066a4d6ccaca6d9dcb386aa43" dependencies = [ "bindgen", "cc", @@ -169,15 +170,15 @@ dependencies = [ [[package]] name = "axum" -version = "0.7.7" +version = "0.7.9" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "504e3947307ac8326a5437504c517c4b56716c9d98fac0028c2acc7ca47d70ae" +checksum = "edca88bc138befd0323b20752846e6587272d3b03b0343c8ea28a6f819e6e71f" dependencies = [ "async-trait", "axum-core 0.4.5", "bytes", "futures-util", - "http 1.2.0", + "http", "http-body", "http-body-util", "itoa", @@ -204,7 +205,7 @@ dependencies = [ "bytes", "form_urlencoded", "futures-util", - "http 1.2.0", + "http", "http-body", "http-body-util", "hyper", @@ -237,7 +238,7 @@ dependencies = [ "async-trait", "bytes", "futures-util", - "http 1.2.0", + "http", "http-body", "http-body-util", "mime", @@ -256,7 +257,7 @@ checksum = "df1362f362fd16024ae199c1970ce98f9661bf5ef94b9808fee734bc3698b733" dependencies = [ "bytes", "futures-util", - "http 1.2.0", + "http", "http-body", "http-body-util", "mime", @@ -278,7 +279,7 @@ dependencies = [ "axum-core 0.5.0", "bytes", "futures-util", - "http 1.2.0", + "http", "http-body", "http-body-util", "mime", @@ -291,9 +292,9 @@ dependencies = [ [[package]] name = "axum-test" -version = "17.1.0" +version = "17.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "375ec4f6db373ce6d696839249203c57049aefe1213cfa77bb2e12e10ed43d08" +checksum = "317c1f4ecc1e68e0ad5decb78478421055c963ce215e736ed97463fa609cd196" dependencies = [ "anyhow", "assert-json-diff", @@ -302,7 +303,7 @@ dependencies = [ "bytes", "bytesize", "cookie", - "http 1.2.0", + "http", "http-body-util", "hyper", "hyper-util", @@ -365,15 +366,15 @@ dependencies = [ [[package]] name = "bitflags" -version = "2.6.0" +version = "2.8.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b048fb63fd8b5923fc5aa7b340d8e156aec7ec02f0c78fa8a6ddc2613f6f71de" +checksum = "8f68f53c83ab957f72c32642f3868eec03eb974d1fb82e453128456482613d36" [[package]] name = "bumpalo" -version = "3.16.0" +version = "3.17.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = 
"79296716171880943b8470b5f8d03aa55eb2e645a4874bdbb28adb49162e012c" +checksum = "1628fb46dfa0b37568d12e5edd512553eccf6a22a78e8bde00bb4aed84d5bdbf" [[package]] name = "byteorder" @@ -383,21 +384,21 @@ checksum = "1fd0f2584146f6f2ef48085050886acf353beff7305ebd1ae69500e27c67f64b" [[package]] name = "bytes" -version = "1.9.0" +version = "1.10.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "325918d6fe32f23b19878fe4b34794ae41fc19ddbe53b10571a4874d44ffd39b" +checksum = "f61dac84819c6588b558454b194026eb1f09c293b9036ae9b159e74e73ab6cf9" [[package]] name = "bytesize" -version = "1.3.0" +version = "1.3.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a3e368af43e418a04d52505cf3dbc23dda4e3407ae2fa99fd0e4f308ce546acc" +checksum = "2d2c12f985c78475a6b8d629afd0c360260ef34cfef52efccdcfd31972f81c2e" [[package]] name = "cc" -version = "1.1.23" +version = "1.2.15" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3bbb537bb4a30b90362caddba8f360c0a56bc13d3a5570028e7197204cb54a17" +checksum = "c736e259eea577f443d5c86c304f9f4ae0295c43f3ba05c21f1d66b5f06001af" dependencies = [ "jobserver", "libc", @@ -438,9 +439,9 @@ dependencies = [ [[package]] name = "clap" -version = "4.5.26" +version = "4.5.31" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a8eb5e908ef3a6efbe1ed62520fb7287959888c88485abe072543190ecc66783" +checksum = "027bb0d98429ae334a8698531da7077bdf906419543a35a55c2cb1b66437d767" dependencies = [ "clap_builder", "clap_derive", @@ -448,9 +449,9 @@ dependencies = [ [[package]] name = "clap_builder" -version = "4.5.26" +version = "4.5.31" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "96b01801b5fc6a0a232407abc821660c9c6d25a1cafc0d4f85f29fb8d9afc121" +checksum = "5589e0cba072e0f3d23791efac0fd8627b49c829c196a492e88168e6a669d863" dependencies = [ "anstream", "anstyle", @@ -460,9 +461,9 @@ dependencies = [ [[package]] name = "clap_derive" -version = "4.5.24" +version = "4.5.28" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "54b755194d6389280185988721fffba69495eed5ee9feeee9a599b53db80318c" +checksum = "bf4ced95c6f4a675af3da73304b9ac4ed991640c36374e4b46795c49e17cf1ed" dependencies = [ "heck", "proc-macro2", @@ -478,18 +479,18 @@ checksum = "f46ad14479a25103f283c0f10005961cf086d8dc42205bb44c46ac563475dca6" [[package]] name = "cmake" -version = "0.1.51" +version = "0.1.54" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fb1e43aa7fd152b1f968787f7dbcdeb306d1867ff373c69955211876c053f91a" +checksum = "e7caa3f9de89ddbe2c607f4101924c5abec803763ae9534e4f4d7d8f84aa81f0" dependencies = [ "cc", ] [[package]] name = "colorchoice" -version = "1.0.2" +version = "1.0.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d3fd119d74b830634cea2a0f58bbd0d54540518a14397557951e79340abc28c0" +checksum = "5b63caa9aa9397e2d9480a9b13673856c78d8ac123288526c37d7839f2a86990" [[package]] name = "cookie" @@ -511,6 +512,16 @@ dependencies = [ "libc", ] +[[package]] +name = "core-foundation" +version = "0.10.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b55271e5c8c478ad3f38ad24ef34923091e0548492a266d19b3c0b4d82574c63" +dependencies = [ + "core-foundation-sys", + "libc", +] + [[package]] name = "core-foundation-sys" version = "0.8.7" @@ -554,9 +565,9 @@ dependencies = [ [[package]] name = "data-encoding" -version = "2.6.0" +version = "2.8.0" source = 
"registry+https://github.com/rust-lang/crates.io-index" -checksum = "e8566979429cf69b49a5c740c60791108e86440e8be149bbea4fe54d2c32d6e2" +checksum = "575f75dfd25738df5b91b8e43e14d44bda14637a58fae779fd2b064f8bf3e010" [[package]] name = "deranged" @@ -592,9 +603,9 @@ checksum = "92773504d58c093f6de2459af4af33faa518c13451eb8f2b5698ed3d36e7c813" [[package]] name = "either" -version = "1.13.0" +version = "1.14.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "60b1af1c220855b6ceac025d3f6ecdd2b7c4894bfe9cd9bda4fbb4bc7c0d4cf0" +checksum = "b7914353092ddf589ad78f25c5c1c21b7f80b0ff8621e7c814c3485b5306da9d" [[package]] name = "encoding_rs" @@ -619,18 +630,18 @@ dependencies = [ [[package]] name = "equivalent" -version = "1.0.1" +version = "1.0.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5443807d6dff69373d433ab9ef5378ad8df50ca6298caf15de6e52e24aaf54d5" +checksum = "877a4ace8713b0bcf2a4e7eec82529c029f1d0619886d18145fea96c3ffe5c0f" [[package]] name = "errno" -version = "0.3.9" +version = "0.3.10" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "534c5cf6194dfab3db3242765c03bbe257cf92f22b38f6bc0c58d59108a820ba" +checksum = "33d852cb9b869c2a9b3df2f71a3074817f01e1844f839a144f5fcef059a4eb5d" dependencies = [ "libc", - "windows-sys 0.52.0", + "windows-sys 0.59.0", ] [[package]] @@ -646,9 +657,9 @@ dependencies = [ [[package]] name = "fastrand" -version = "2.1.1" +version = "2.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e8c02a5121d4ea3eb16a80748c74f5549a5665e4c21333c6098f283870fbdea6" +checksum = "37909eebbb50d72f9059c3b6d82c0463f2ff062c9e95845c43a6c9c0355411be" [[package]] name = "faux" @@ -675,9 +686,9 @@ dependencies = [ [[package]] name = "fixedbitset" -version = "0.4.2" +version = "0.5.7" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0ce7134b9999ecaf8bcd65542e436736ef32ddca1b3e06094cb6ec5755203b80" +checksum = "1d674e81391d1e1ab681a28d99df07927c6d4aa5b027d7da16ba32d1d21ecd99" [[package]] name = "fms-guardrails-orchestr8" @@ -694,7 +705,7 @@ dependencies = [ "futures", "futures-util", "ginepro", - "http 1.2.0", + "http", "http-body", "http-body-util", "http-serde", @@ -861,17 +872,31 @@ name = "getrandom" version = "0.2.15" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "c4567c8db10ae91089c99af84c68c38da3ec2f087c3f82960bcdbf3656b6f4d7" +dependencies = [ + "cfg-if", + "js-sys", + "libc", + "wasi 0.11.0+wasi-snapshot-preview1", + "wasm-bindgen", +] + +[[package]] +name = "getrandom" +version = "0.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "43a49c392881ce6d5c3b8cb70f98717b7c07aabbdff06687b9030dbfbe2725f8" dependencies = [ "cfg-if", "libc", - "wasi", + "wasi 0.13.3+wasi-0.2.2", + "windows-targets 0.52.6", ] [[package]] name = "gimli" -version = "0.31.0" +version = "0.31.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "32085ea23f3234fc7846555e85283ba4de91e21016dc0455a16286d87a292d64" +checksum = "07e28edb80900c19c28f1072f2e8aeca7fa06b23cd4169cefe1af5aa3260783f" [[package]] name = "ginepro" @@ -882,8 +907,8 @@ dependencies = [ "anyhow", "async-trait", "hickory-resolver", - "http 1.2.0", - "thiserror 1.0.64", + "http", + "thiserror 1.0.69", "tokio", "tonic", "tower 0.4.13", @@ -892,23 +917,23 @@ dependencies = [ [[package]] name = "glob" -version = "0.3.1" +version = "0.3.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = 
"d2fabcfbdc87f4758337ca535fb41a6d701b65693ce38287d856d1674551ec9b" +checksum = "a8d1add55171497b4705a648c6b583acafb01d58050a51727785f0b2c8e0a2b2" [[package]] name = "h2" -version = "0.4.6" +version = "0.4.8" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "524e8ac6999421f49a846c2d4411f337e53497d8ec55d67753beffa43c5d9205" +checksum = "5017294ff4bb30944501348f6f8e42e6ad28f42c8bbef7a74029aff064a4e3c2" dependencies = [ "atomic-waker", "bytes", "fnv", "futures-core", "futures-sink", - "http 1.2.0", - "indexmap 2.5.0", + "http", + "indexmap 2.7.1", "slab", "tokio", "tokio-util", @@ -923,9 +948,9 @@ checksum = "8a9ee70c43aaf417c914396645a0fa852624801b24ebb7ae78fe8272889ac888" [[package]] name = "hashbrown" -version = "0.14.5" +version = "0.15.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e5274423e17b7c9fc20b6e7e208532f9b19825d82dfd615708b70edd83df41f1" +checksum = "bf151400ff0baff5465007dd2f3e717f3fe502074ca563069ce3a6629d07b289" [[package]] name = "heck" @@ -933,17 +958,11 @@ version = "0.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "2304e00983f87ffb38b55b444b5e3b60a884b5d30c0fca7d82fe33449bbe55ea" -[[package]] -name = "hermit-abi" -version = "0.3.9" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d231dfb89cfffdbc30e7fc41579ed6066ad03abda9e567ccafae602b97ec5024" - [[package]] name = "hickory-proto" -version = "0.24.3" +version = "0.24.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2ad3d6d98c648ed628df039541a5577bee1a7c83e9e16fe3dbedeea4cdfeb971" +checksum = "92652067c9ce6f66ce53cc38d1169daa36e6e7eb7dd3b63b5103bd9d97117248" dependencies = [ "async-trait", "cfg-if", @@ -955,8 +974,8 @@ dependencies = [ "idna", "ipnet", "once_cell", - "rand", - "thiserror 1.0.64", + "rand 0.8.5", + "thiserror 1.0.69", "tinyvec", "tokio", "tracing", @@ -965,9 +984,9 @@ dependencies = [ [[package]] name = "hickory-resolver" -version = "0.24.2" +version = "0.24.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0a2e2aba9c389ce5267d31cf1e4dace82390ae276b0b364ea55630b1fa1b44b4" +checksum = "cbb117a1ca520e111743ab2f6688eddee69db4e0ea242545a604dce8a66fd22e" dependencies = [ "cfg-if", "futures-util", @@ -976,21 +995,21 @@ dependencies = [ "lru-cache", "once_cell", "parking_lot", - "rand", + "rand 0.8.5", "resolv-conf", "smallvec", - "thiserror 1.0.64", + "thiserror 1.0.69", "tokio", "tracing", ] [[package]] name = "home" -version = "0.5.9" +version = "0.5.11" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e3d1354bf6b7235cb4a0576c2619fd4ed18183f689b12b006a0ee7329eeff9a5" +checksum = "589533453244b0995c858700322199b2becb13b627df2851f64a2775d024abcf" dependencies = [ - "windows-sys 0.52.0", + "windows-sys 0.59.0", ] [[package]] @@ -1004,17 +1023,6 @@ dependencies = [ "winapi", ] -[[package]] -name = "http" -version = "0.2.12" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "601cbb57e577e2f5ef5be8e7b83f0f63994f25aa94d673e54a92d5c516d101f1" -dependencies = [ - "bytes", - "fnv", - "itoa", -] - [[package]] name = "http" version = "1.2.0" @@ -1033,7 +1041,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1efedce1fb8e6913f23e0c92de8e62cd5b772a67e7b3946df930a62566c93184" dependencies = [ "bytes", - "http 1.2.0", + "http", ] [[package]] @@ -1044,7 +1052,7 @@ checksum = "793429d76616a256bcb62c2a2ec2bed781c8307e797e2598c50010f2bee2544f" dependencies = [ 
"bytes", "futures-util", - "http 1.2.0", + "http", "http-body", "pin-project-lite", ] @@ -1055,15 +1063,15 @@ version = "2.1.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "0f056c8559e3757392c8d091e796416e4649d8e49e88b8d76df6c002f05027fd" dependencies = [ - "http 1.2.0", + "http", "serde", ] [[package]] name = "httparse" -version = "1.9.5" +version = "1.10.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7d71d3574edd2771538b901e6549113b4006ece66150fb69c0fb6d9a2adae946" +checksum = "f2d708df4e7140240a16cd6ab0ab65c972d7433ab77819ea693fde9c43811e2a" [[package]] name = "httpdate" @@ -1073,15 +1081,15 @@ checksum = "df3b46402a9d5adb4c86a0cf463f42e19994e3ee891101b1841f30a545cb49a9" [[package]] name = "hyper" -version = "1.5.2" +version = "1.6.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "256fb8d4bd6413123cc9d91832d78325c48ff41677595be797d90f42969beae0" +checksum = "cc2b571658e38e0c01b1fdca3bbbe93c00d3d71693ff2770043f8c29bc7d6f80" dependencies = [ "bytes", "futures-channel", "futures-util", "h2", - "http 1.2.0", + "http", "http-body", "httparse", "httpdate", @@ -1099,7 +1107,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "2d191583f3da1305256f22463b9bb0471acad48a4e534a5218b9963e9c1f59b2" dependencies = [ "futures-util", - "http 1.2.0", + "http", "hyper", "hyper-util", "log", @@ -1150,7 +1158,7 @@ dependencies = [ "bytes", "futures-channel", "futures-util", - "http 1.2.0", + "http", "http-body", "hyper", "pin-project-lite", @@ -1317,12 +1325,12 @@ dependencies = [ [[package]] name = "indexmap" -version = "2.5.0" +version = "2.7.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "68b900aa2f7301e21c36462b170ee99994de34dff39a4a6a528e80e7376d07e5" +checksum = "8c9c992b02b5b4c94ea26e32fe5bccb7aa7d9f390ab5c1221ff895bc7ea8b652" dependencies = [ "equivalent", - "hashbrown 0.14.5", + "hashbrown 0.15.2", ] [[package]] @@ -1339,9 +1347,9 @@ dependencies = [ [[package]] name = "ipnet" -version = "2.10.0" +version = "2.11.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "187674a687eed5fe42285b40c6291f9a01517d415fad1c3cbc6a9f778af7fcd4" +checksum = "469fb0b9cefa57e3ef31275ee7cacb78f2fdca44e4765491884a2b119d4eb130" [[package]] name = "is_terminal_polyfill" @@ -1360,18 +1368,18 @@ dependencies = [ [[package]] name = "itertools" -version = "0.13.0" +version = "0.14.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "413ee7dfc52ee1a4949ceeb7dbc8a33f2d6c088194d9f922fb8318faf1f01186" +checksum = "2b192c782037fadd9cfa75548310488aabdbf3d2da73885b31bd0abd03351285" dependencies = [ "either", ] [[package]] name = "itoa" -version = "1.0.11" +version = "1.0.14" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "49f1f14873335454500d59611f1cf4a4b0f786f9ac11f4312a78e4cf2566695b" +checksum = "d75a2a4b1b190afb6f5425f10f6a8f959d2ea0b9c2b1d79553551850539e4674" [[package]] name = "jobserver" @@ -1384,10 +1392,11 @@ dependencies = [ [[package]] name = "js-sys" -version = "0.3.70" +version = "0.3.77" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1868808506b929d7b0cfa8f75951347aa71bb21144b7791bae35d9bccfcfe37a" +checksum = "1cfaf33c695fc6e08064efbc1f72ec937429614f25eef83af942d0e227c3a28f" dependencies = [ + "once_cell", "wasm-bindgen", ] @@ -1405,18 +1414,18 @@ checksum = "830d08ce1d1d941e6b30645f1a0eb5643013d835ce3779a5fc208261dbe10f55" [[package]] name = 
"libc" -version = "0.2.169" +version = "0.2.170" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b5aba8db14291edd000dfcc4d620c7ebfb122c613afb886ca8803fa4e128a20a" +checksum = "875b3680cb2f8f71bdcf9a30f38d48282f5d3c95cbf9b3fa57269bb5d5c06828" [[package]] name = "libloading" -version = "0.8.5" +version = "0.8.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4979f22fdb869068da03c9f7528f8297c6fd2606bc3a4affe42e6a823fdb8da4" +checksum = "fc2f4eb4bc735547cfed7c0a4922cbd04a4655978c09b54f1f7b228750664c34" dependencies = [ "cfg-if", - "windows-targets 0.48.5", + "windows-targets 0.52.6", ] [[package]] @@ -1437,9 +1446,9 @@ checksum = "0717cef1bc8b636c6e1c1bbdefc09e6322da8a9321966e8928ef80d20f7f770f" [[package]] name = "linux-raw-sys" -version = "0.4.14" +version = "0.4.15" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "78b3ae25bc7c8c38cec158d1f2757ee79e9b3740fbc7ccf0e59e4b08d793fa89" +checksum = "d26c52dbd32dccf2d10cac7725f8eae5296885fb5703b261f7d0a0739ec807ab" [[package]] name = "litemap" @@ -1459,9 +1468,9 @@ dependencies = [ [[package]] name = "log" -version = "0.4.22" +version = "0.4.26" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a7a70ba024b9dc04c27ea2f0c0548feb474ec5c54bba33a7f72f873a39d07b24" +checksum = "30bde2b3dc3671ae49d8e2e9f044c7c005836e7a023ee57cffa25ab82764bb9e" [[package]] name = "lru-cache" @@ -1529,22 +1538,21 @@ checksum = "68354c5c6bd36d73ff3feceb05efa59b6acb7626617f4962be322a825e61f79a" [[package]] name = "miniz_oxide" -version = "0.8.0" +version = "0.8.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e2d80299ef12ff69b16a84bb182e3b9df68b5a91574d3d4fa6e41b65deec4df1" +checksum = "8e3e04debbb59698c15bacbb6d93584a8c0ca9cc3213cb423d31f760d8843ce5" dependencies = [ "adler2", ] [[package]] name = "mio" -version = "1.0.2" +version = "1.0.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "80e04d1dcff3aae0704555fe5fee3bcfaf3d1fdf8a7e521d5b9d2b42acb52cec" +checksum = "2886843bf800fba2e3377cff24abf6379b4c4d5c6681eaf9ea5b0d15090450bd" dependencies = [ - "hermit-abi", "libc", - "wasi", + "wasi 0.11.0+wasi-snapshot-preview1", "windows-sys 0.52.0", ] @@ -1556,9 +1564,9 @@ checksum = "defc4c55412d89136f966bbb339008b474350e5e6e78d2714439c386b3137a03" [[package]] name = "native-tls" -version = "0.2.12" +version = "0.2.14" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a8614eb2c83d59d1c8cc974dd3f920198647674a0a035e1af1fa58707e317466" +checksum = "87de3442987e9dbec73158d5c715e7ad9072fda936bb03d19d7fa10e00520f0e" dependencies = [ "libc", "log", @@ -1566,7 +1574,7 @@ dependencies = [ "openssl-probe", "openssl-sys", "schannel", - "security-framework", + "security-framework 2.11.1", "security-framework-sys", "tempfile", ] @@ -1599,27 +1607,24 @@ checksum = "51d515d32fb182ee37cda2ccdcb92950d6a3c2893aa280e540671c2cd0f3b1d9" [[package]] name = "object" -version = "0.36.4" +version = "0.36.7" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "084f1a5821ac4c651660a94a7153d27ac9d8a53736203f58b31945ded098070a" +checksum = "62948e14d923ea95ea2c7c86c71013138b66525b86bdc08d2dcc262bdb497b87" dependencies = [ "memchr", ] [[package]] name = "once_cell" -version = "1.20.1" +version = "1.20.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "82881c4be219ab5faaf2ad5e5e5ecdff8c66bd7402ca3160975c93b24961afd1" -dependencies = [ - "portable-atomic", 
-] +checksum = "945462a4b81e43c4e3ba96bd7b49d834c6f61198356aa858733bc4acf3cbe62e" [[package]] name = "openssl" -version = "0.10.70" +version = "0.10.71" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "61cfb4e166a8bb8c9b55c500bc2308550148ece889be90f609377e58140f42c6" +checksum = "5e14130c6a98cd258fdcb0fb6d744152343ff729cbfcb28c656a9d12b999fbcd" dependencies = [ "bitflags", "cfg-if", @@ -1643,15 +1648,15 @@ dependencies = [ [[package]] name = "openssl-probe" -version = "0.1.5" +version = "0.1.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ff011a302c396a5197692431fc1948019154afc178baf7d8e37367442a4601cf" +checksum = "d05e27ee213611ffe7d6348b942e8f942b37114c00cc03cec254295a4a17852e" [[package]] name = "openssl-sys" -version = "0.9.105" +version = "0.9.106" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8b22d5b84be05a8d6947c7cb71f7c849aa0f112acd4bf51c2a7c1c988ac0a9dc" +checksum = "8bb61ea9811cc39e3c2069f40b8b8e2e70d8569b361f879786cc7ed48b777cdd" dependencies = [ "cc", "libc", @@ -1669,7 +1674,7 @@ dependencies = [ "futures-sink", "js-sys", "pin-project-lite", - "thiserror 1.0.64", + "thiserror 1.0.69", "tracing", ] @@ -1681,7 +1686,7 @@ checksum = "10a8a7f5f6ba7c1b286c2fbca0454eaba116f63bbe69ed250b642d36fbb04d80" dependencies = [ "async-trait", "bytes", - "http 1.2.0", + "http", "opentelemetry", "reqwest", ] @@ -1694,13 +1699,13 @@ checksum = "91cf61a1868dacc576bf2b2a1c3e9ab150af7272909e80085c3173384fe11f76" dependencies = [ "async-trait", "futures-core", - "http 1.2.0", + "http", "opentelemetry", "opentelemetry-http", "opentelemetry-proto", "opentelemetry_sdk", "prost", - "thiserror 1.0.64", + "thiserror 1.0.69", "tokio", "tonic", "tracing", @@ -1731,9 +1736,9 @@ dependencies = [ "glob", "opentelemetry", "percent-encoding", - "rand", + "rand 0.8.5", "serde_json", - "thiserror 1.0.64", + "thiserror 1.0.69", "tokio", "tokio-stream", "tracing", @@ -1782,28 +1787,28 @@ checksum = "e3148f5046208a5d56bcfc03053e3ca6334e51da8dfb19b6cdc8b306fae3283e" [[package]] name = "petgraph" -version = "0.6.5" +version = "0.7.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b4c5cc86750666a3ed20bdaf5ca2a0344f9c67674cae0515bec2da16fbaa47db" +checksum = "3672b37090dbd86368a4145bc067582552b29c27377cad4e0a306c97f9bd7772" dependencies = [ "fixedbitset", - "indexmap 2.5.0", + "indexmap 2.7.1", ] [[package]] name = "pin-project" -version = "1.1.5" +version = "1.1.9" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b6bf43b791c5b9e34c3d182969b4abb522f9343702850a2e57f460d00d09b4b3" +checksum = "dfe2e71e1471fe07709406bf725f710b02927c9c54b2b5b2ec0e8087d97c327d" dependencies = [ "pin-project-internal", ] [[package]] name = "pin-project-internal" -version = "1.1.5" +version = "1.1.9" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2f38a4412a78282e09a2cf38d195ea5420d15ba0602cb375210efbc877243965" +checksum = "f6e859e6e5bd50440ab63c47e3ebabc90f26251f7c73c3d3e837b74a1cc3fa67" dependencies = [ "proc-macro2", "quote", @@ -1828,12 +1833,6 @@ version = "0.3.31" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "953ec861398dccce10c670dfeaf3ec4911ca479e9c02154b3a215178c5f566f2" -[[package]] -name = "portable-atomic" -version = "1.9.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "cc9c68a3f6da06753e9335d63e27f6b9754dd1920d941135b7ea8224f141adb2" - [[package]] name = "powerfmt" version = "0.2.0" @@ 
-1846,7 +1845,7 @@ version = "0.2.20" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "77957b295656769bb8ad2b6a6b09d897d94f05c41b069aede1fcdaa675eaea04" dependencies = [ - "zerocopy", + "zerocopy 0.7.35", ] [[package]] @@ -1861,9 +1860,9 @@ dependencies = [ [[package]] name = "prettyplease" -version = "0.2.22" +version = "0.2.29" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "479cf940fbbb3426c32c5d5176f62ad57549a0bb84773423ba8be9d089f5faba" +checksum = "6924ced06e1f7dfe3fa48d57b9f74f55d8915f5036121bef647ef4b204895fac" dependencies = [ "proc-macro2", "syn", @@ -1871,18 +1870,18 @@ dependencies = [ [[package]] name = "proc-macro2" -version = "1.0.86" +version = "1.0.93" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5e719e8df665df0d1c8fbfd238015744736151d4445ec0836b8e628aae103b77" +checksum = "60946a68e5f9d28b0dc1c21bb8a97ee7d018a8b322fa57838ba31cc878e22d99" dependencies = [ "unicode-ident", ] [[package]] name = "prost" -version = "0.13.4" +version = "0.13.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2c0fef6c4230e4ccf618a35c59d7ede15dea37de8427500f50aff708806e42ec" +checksum = "2796faa41db3ec313a31f7624d9286acf277b52de526150b7e69f3debf891ee5" dependencies = [ "bytes", "prost-derive", @@ -1890,13 +1889,12 @@ dependencies = [ [[package]] name = "prost-build" -version = "0.13.3" +version = "0.13.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0c1318b19085f08681016926435853bbf7858f9c082d0999b80550ff5d9abe15" +checksum = "be769465445e8c1474e9c5dac2018218498557af32d9ed057325ec9a41ae81bf" dependencies = [ - "bytes", "heck", - "itertools 0.13.0", + "itertools 0.14.0", "log", "multimap", "once_cell", @@ -1911,12 +1909,12 @@ dependencies = [ [[package]] name = "prost-derive" -version = "0.13.4" +version = "0.13.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "157c5a9d7ea5c2ed2d9fb8f495b64759f7816c7eaea54ba3978f0d63000162e3" +checksum = "8a56d757972c98b346a9b766e3f02746cde6dd1cd1d1d563472929fdd74bec4d" dependencies = [ "anyhow", - "itertools 0.13.0", + "itertools 0.14.0", "proc-macro2", "quote", "syn", @@ -1924,9 +1922,9 @@ dependencies = [ [[package]] name = "prost-types" -version = "0.13.3" +version = "0.13.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4759aa0d3a6232fb8dbdb97b61de2c20047c68aca932c7ed76da9d788508d670" +checksum = "52c2c1bf36ddb1a1c396b3601a3cec27c2462e45f07c386894ec3ccf5332bd16" dependencies = [ "prost", ] @@ -1939,44 +1937,47 @@ checksum = "a1d01941d82fa2ab50be1e79e6714289dd7cde78eba4c074bc5a4374f650dfe0" [[package]] name = "quinn" -version = "0.11.5" +version = "0.11.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8c7c5fdde3cdae7203427dc4f0a68fe0ed09833edc525a03456b153b79828684" +checksum = "62e96808277ec6f97351a2380e6c25114bc9e67037775464979f3037c92d05ef" dependencies = [ "bytes", "pin-project-lite", "quinn-proto", "quinn-udp", - "rustc-hash 2.0.0", + "rustc-hash 2.1.1", "rustls", "socket2", - "thiserror 1.0.64", + "thiserror 2.0.11", "tokio", "tracing", ] [[package]] name = "quinn-proto" -version = "0.11.8" +version = "0.11.9" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fadfaed2cd7f389d0161bb73eeb07b7b78f8691047a6f3e73caaeae55310a4a6" +checksum = "a2fe5ef3495d7d2e377ff17b1a8ce2ee2ec2a18cde8b6ad6619d65d0701c135d" dependencies = [ "bytes", - "rand", + "getrandom 0.2.15", + "rand 0.8.5", "ring", 
- "rustc-hash 2.0.0", + "rustc-hash 2.1.1", "rustls", + "rustls-pki-types", "slab", - "thiserror 1.0.64", + "thiserror 2.0.11", "tinyvec", "tracing", + "web-time", ] [[package]] name = "quinn-udp" -version = "0.5.7" +version = "0.5.10" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7d5a626c6807713b15cac82a6acaccd6043c9a5408c24baae07611fec3f243da" +checksum = "e46f3055866785f6b92bc6164b76be02ca8f2eb4b002c0354b28cf4c119e5944" dependencies = [ "cfg_aliases", "libc", @@ -1988,9 +1989,9 @@ dependencies = [ [[package]] name = "quote" -version = "1.0.37" +version = "1.0.38" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b5b9d34b8991d19d98081b46eacdd8eb58c6f2b201139f7c5f643cc155a633af" +checksum = "0e4dccaaaf89514f546c693ddc140f729f958c247918a13380cccc6078391acc" dependencies = [ "proc-macro2", ] @@ -2002,8 +2003,19 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "34af8d1a0e25924bc5b7c43c079c942339d8f0a8b57c39049bef581b46327404" dependencies = [ "libc", - "rand_chacha", - "rand_core", + "rand_chacha 0.3.1", + "rand_core 0.6.4", +] + +[[package]] +name = "rand" +version = "0.9.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3779b94aeb87e8bd4e834cee3650289ee9e0d5677f976ecdb6d219e5f4f6cd94" +dependencies = [ + "rand_chacha 0.9.0", + "rand_core 0.9.2", + "zerocopy 0.8.20", ] [[package]] @@ -2013,7 +2025,17 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e6c10a63a0fa32252be49d21e7709d4d4baf8d231c2dbce1eaa8141b9b127d88" dependencies = [ "ppv-lite86", - "rand_core", + "rand_core 0.6.4", +] + +[[package]] +name = "rand_chacha" +version = "0.9.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d3022b5f1df60f26e1ffddd6c66e8aa15de382ae63b3a0c1bfc0e4d3e3f325cb" +dependencies = [ + "ppv-lite86", + "rand_core 0.9.2", ] [[package]] @@ -2022,27 +2044,37 @@ version = "0.6.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ec0be4795e2f6a28069bec0b5ff3e2ac9bafc99e6a9a7dc3547996c5c816922c" dependencies = [ - "getrandom", + "getrandom 0.2.15", +] + +[[package]] +name = "rand_core" +version = "0.9.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7a509b1a2ffbe92afab0e55c8fd99dea1c280e8171bd2d88682bb20bc41cbc2c" +dependencies = [ + "getrandom 0.3.1", + "zerocopy 0.8.20", ] [[package]] name = "redox_syscall" -version = "0.5.7" +version = "0.5.9" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9b6dfecf2c74bce2466cabf93f6664d6998a69eb21e39f4207930065b27b771f" +checksum = "82b568323e98e49e2a0899dcee453dd679fae22d69adf9b11dd508d1549b7e2f" dependencies = [ "bitflags", ] [[package]] name = "regex" -version = "1.11.0" +version = "1.11.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "38200e5ee88914975b69f657f0801b6f6dccafd44fd9326302a4aaeecfacb1d8" +checksum = "b544ef1b4eac5dc2db33ea63606ae9ffcfac26c1416a2806ae0bf5f56b201191" dependencies = [ "aho-corasick", "memchr", - "regex-automata 0.4.8", + "regex-automata 0.4.9", "regex-syntax 0.8.5", ] @@ -2057,9 +2089,9 @@ dependencies = [ [[package]] name = "regex-automata" -version = "0.4.8" +version = "0.4.9" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "368758f23274712b504848e9d5a6f010445cc8b87a7cdb4d7cbee666c1288da3" +checksum = "809e8dc61f6de73b46c85f4c96486310fe304c434cfa43669d7b40f711150908" dependencies = [ "aho-corasick", "memchr", @@ 
-2091,7 +2123,7 @@ dependencies = [ "futures-core", "futures-util", "h2", - "http 1.2.0", + "http", "http-body", "http-body-util", "hyper", @@ -2130,12 +2162,12 @@ dependencies = [ [[package]] name = "reserve-port" -version = "2.0.1" +version = "2.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9838134a2bfaa8e1f40738fcc972ac799de6e0e06b5157acb95fc2b05a0ea283" +checksum = "359fc315ed556eb0e42ce74e76f4b1cd807b50fa6307f3de4e51f92dbe86e2d5" dependencies = [ "lazy_static", - "thiserror 1.0.64", + "thiserror 2.0.11", ] [[package]] @@ -2150,33 +2182,32 @@ dependencies = [ [[package]] name = "ring" -version = "0.17.8" +version = "0.17.11" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c17fa4cb658e3583423e915b9f3acc01cceaee1860e33d59ebae66adc3a2dc0d" +checksum = "da5349ae27d3887ca812fb375b45a4fbb36d8d12d2df394968cd86e35683fe73" dependencies = [ "cc", "cfg-if", - "getrandom", + "getrandom 0.2.15", "libc", - "spin", "untrusted", "windows-sys 0.52.0", ] [[package]] name = "rust-multipart-rfc7578_2" -version = "0.6.1" +version = "0.7.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "03b748410c0afdef2ebbe3685a6a862e2ee937127cdaae623336a459451c8d57" +checksum = "bc4bb9e7c9abe5fa5f30c2d8f8fefb9e0080a2c1e3c2e567318d2907054b35d3" dependencies = [ "bytes", "futures-core", "futures-util", - "http 0.2.12", + "http", "mime", "mime_guess", - "rand", - "thiserror 1.0.64", + "rand 0.9.0", + "thiserror 2.0.11", ] [[package]] @@ -2193,28 +2224,28 @@ checksum = "08d43f7aa6b08d49f382cde6a7982047c3426db949b1424bc4b7ec9ae12c6ce2" [[package]] name = "rustc-hash" -version = "2.0.0" +version = "2.1.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "583034fd73374156e66797ed8e5b0d5690409c9226b22d87cb7f19821c05d152" +checksum = "357703d41365b4b27c590e3ed91eabb1b663f07c4c084095e60cbed4362dff0d" [[package]] name = "rustix" -version = "0.38.37" +version = "0.38.44" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8acb788b847c24f28525660c4d7758620a7210875711f79e7f663cc152726811" +checksum = "fdb5bc1ae2baa591800df16c9ca78619bf65c0488b41b96ccec5d11220d8c154" dependencies = [ "bitflags", "errno", "libc", "linux-raw-sys", - "windows-sys 0.52.0", + "windows-sys 0.59.0", ] [[package]] name = "rustls" -version = "0.23.21" +version = "0.23.23" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8f287924602bf649d949c63dc8ac8b235fa5387d394020705b80c4eb597ce5b8" +checksum = "47796c98c480fce5406ef69d1c76378375492c3b0a0de587be0c1d9feb12f395" dependencies = [ "aws-lc-rs", "log", @@ -2228,15 +2259,14 @@ dependencies = [ [[package]] name = "rustls-native-certs" -version = "0.8.0" +version = "0.8.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fcaf18a4f2be7326cd874a5fa579fae794320a0f388d365dca7e480e55f83f8a" +checksum = "7fcff2dd52b58a8d98a70243663a0d234c4e2b79235637849d15913394a247d3" dependencies = [ "openssl-probe", - "rustls-pemfile", "rustls-pki-types", "schannel", - "security-framework", + "security-framework 3.2.0", ] [[package]] @@ -2250,9 +2280,12 @@ dependencies = [ [[package]] name = "rustls-pki-types" -version = "1.10.1" +version = "1.11.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d2bf47e6ff922db3825eb750c4e2ff784c6ff8fb9e13046ef6a1d1c5401b0b37" +checksum = "917ce264624a4b4db1c364dcc35bfca9ded014d0a958cd47ad3e960e988ea51c" +dependencies = [ + "web-time", +] [[package]] name = 
"rustls-webpki" @@ -2268,21 +2301,21 @@ dependencies = [ [[package]] name = "rustversion" -version = "1.0.17" +version = "1.0.19" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "955d28af4278de8121b7ebeb796b6a45735dc01436d898801014aced2773a3d6" +checksum = "f7c45b9784283f1b2e7fb61b42047c2fd678ef0960d4f6f1eba131594cc369d4" [[package]] name = "ryu" -version = "1.0.18" +version = "1.0.19" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f3cb5ba0dc43242ce17de99c180e96db90b235b8a9fdc9543c96d2209116bd9f" +checksum = "6ea1a2d0a644769cc99faa24c3ad26b379b786fe7c36fd3c546254801650e6dd" [[package]] name = "schannel" -version = "0.1.24" +version = "0.1.27" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e9aaafd5a2b6e3d657ff009d82fbd630b6bd54dd4eb06f21693925cdf80f9b8b" +checksum = "1f29ebaa345f945cec9fbbc532eb307f0fdad8161f281b6369539c8d84876b3d" dependencies = [ "windows-sys 0.59.0", ] @@ -2300,7 +2333,20 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "897b2245f0b511c87893af39b033e5ca9cce68824c4d7e7630b5a1d339658d02" dependencies = [ "bitflags", - "core-foundation", + "core-foundation 0.9.4", + "core-foundation-sys", + "libc", + "security-framework-sys", +] + +[[package]] +name = "security-framework" +version = "3.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "271720403f46ca04f7ba6f55d438f8bd878d6b8ca0a1046e8228c4145bcbb316" +dependencies = [ + "bitflags", + "core-foundation 0.10.0", "core-foundation-sys", "libc", "security-framework-sys", @@ -2308,9 +2354,9 @@ dependencies = [ [[package]] name = "security-framework-sys" -version = "2.12.0" +version = "2.14.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ea4a292869320c0272d7bc55a5a6aafaff59b4f63404a003887b679a2e05b4b6" +checksum = "49db231d56a190491cb4aeda9527f1ad45345af50b0851622a7adb8c03b01c32" dependencies = [ "core-foundation-sys", "libc", @@ -2318,18 +2364,18 @@ dependencies = [ [[package]] name = "serde" -version = "1.0.217" +version = "1.0.218" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "02fc4265df13d6fa1d00ecff087228cc0a2b5f3c0e87e258d8b94a156e984c70" +checksum = "e8dfc9d19bdbf6d17e22319da49161d5d0108e4188e8b680aef6299eed22df60" dependencies = [ "serde_derive", ] [[package]] name = "serde_derive" -version = "1.0.217" +version = "1.0.218" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5a9bf7cf98d04a2b28aead066b7496853d4779c9cc183c440dbac457641e19a0" +checksum = "f09503e191f4e797cb8aac08e9a4a4695c5edf6a2e70e376d961ddd5c969f82b" dependencies = [ "proc-macro2", "quote", @@ -2338,9 +2384,9 @@ dependencies = [ [[package]] name = "serde_json" -version = "1.0.135" +version = "1.0.139" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2b0d7ba2887406110130a978386c4e1befb98c674b4fba677954e4db976630d9" +checksum = "44f86c3acccc9c65b153fe1b85a3be07fe5515274ec9f0653b4a0875731c72a6" dependencies = [ "itoa", "memchr", @@ -2376,7 +2422,7 @@ version = "0.0.12" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "59e2dd588bf1597a252c3b920e0143eb99b0f76e4e082f4c92ce34fbc9e71ddd" dependencies = [ - "indexmap 2.5.0", + "indexmap 2.7.1", "itoa", "libyml", "memchr", @@ -2420,26 +2466,20 @@ dependencies = [ [[package]] name = "smallvec" -version = "1.13.2" +version = "1.14.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = 
"3c5e1a9a646d36c3599cd173a41282daf47c44583ad367b8e6837255952e5c67" +checksum = "7fcf8323ef1faaee30a44a340193b1ac6814fd9b7b4e88e9d4519a3e4abe1cfd" [[package]] name = "socket2" -version = "0.5.7" +version = "0.5.8" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ce305eb0b4296696835b71df73eb912e0f1ffd2556a501fcede6e0c50349191c" +checksum = "c970269d99b64e60ec3bd6ad27270092a5394c4e309314b18ae3fe575695fbe8" dependencies = [ "libc", "windows-sys 0.52.0", ] -[[package]] -name = "spin" -version = "0.9.8" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6980e8d7511241f8acf4aebddbb1ff938df5eebe98691418c4468d0b72a96a67" - [[package]] name = "stable_deref_trait" version = "1.2.0" @@ -2460,9 +2500,9 @@ checksum = "13c2bddecc57b384dee18652358fb23172facb8a2c51ccc10d74c157bdea3292" [[package]] name = "syn" -version = "2.0.87" +version = "2.0.98" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "25aa4ce346d03a6dcd68dd8b4010bcb74e54e62c90c573f394c46eae99aba32d" +checksum = "36147f1a48ae0ec2b5b3bc5b537d267457555a10dc06f3dbc8cb11ba3006d3b1" dependencies = [ "proc-macro2", "quote", @@ -2471,9 +2511,9 @@ dependencies = [ [[package]] name = "sync_wrapper" -version = "1.0.1" +version = "1.0.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a7065abeca94b6a8a577f9bd45aa0867a2238b74e8eb67cf10d492bc39351394" +checksum = "0bf256ce5efdfa370213c1dabab5935a12e49f2c58d15e9eac2870d3b4f27263" dependencies = [ "futures-core", ] @@ -2496,7 +2536,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "3c879d448e9d986b661742763247d3693ed13609438cf3d006f51f5368a5ba6b" dependencies = [ "bitflags", - "core-foundation", + "core-foundation 0.9.4", "system-configuration-sys", ] @@ -2512,12 +2552,13 @@ dependencies = [ [[package]] name = "tempfile" -version = "3.13.0" +version = "3.17.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f0f2c9fc62d0beef6951ccffd757e241266a2c833136efbe35af6cd2567dca5b" +checksum = "22e5a0acb1f3f55f65cc4a866c361b2fb2a0ff6366785ae6fbb5f85df07ba230" dependencies = [ "cfg-if", "fastrand", + "getrandom 0.3.1", "once_cell", "rustix", "windows-sys 0.59.0", @@ -2525,11 +2566,11 @@ dependencies = [ [[package]] name = "thiserror" -version = "1.0.64" +version = "1.0.69" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d50af8abc119fb8bb6dbabcfa89656f46f84aa0ac7688088608076ad2b459a84" +checksum = "b6aaf5339b578ea85b50e080feb250a3e8ae8cfcdff9a461c9ec2904bc923f52" dependencies = [ - "thiserror-impl 1.0.64", + "thiserror-impl 1.0.69", ] [[package]] @@ -2543,9 +2584,9 @@ dependencies = [ [[package]] name = "thiserror-impl" -version = "1.0.64" +version = "1.0.69" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "08904e7672f5eb876eaaf87e0ce17857500934f4981c4a0ab2b4aa98baac7fc3" +checksum = "4fee6c4efc90059e10f81e6d42c60a18f76588c3d74cb83a0b242a2b6c7504c1" dependencies = [ "proc-macro2", "quote", @@ -2575,9 +2616,9 @@ dependencies = [ [[package]] name = "time" -version = "0.3.36" +version = "0.3.37" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5dfd88e563464686c916c7e46e623e520ddc6d79fa6641390f2e3fa86e83e885" +checksum = "35e7868883861bd0e56d9ac6efcaaca0d6d5d82a2a7ec8209ff492c07cf37b21" dependencies = [ "deranged", "itoa", @@ -2596,9 +2637,9 @@ checksum = "ef927ca75afb808a4d64dd374f00a2adf8d0fcff8e7b184af886c3c87ec4a3f3" [[package]] name = "time-macros" -version = 
"0.2.18" +version = "0.2.19" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3f252a68540fde3a3877aeea552b832b40ab9a69e318efd078774a01ddee1ccf" +checksum = "2834e6017e3e5e4b9834939793b282bc03b37a3336245fa820e35e233e2a85de" dependencies = [ "num-conv", "time-core", @@ -2616,9 +2657,9 @@ dependencies = [ [[package]] name = "tinyvec" -version = "1.8.0" +version = "1.8.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "445e881f4f6d382d5f27c034e25eb92edd7c784ceab92a0937db7f2e9471b938" +checksum = "022db8904dfa342efe721985167e9fcd16c29b226db4397ed752a761cfce81e8" dependencies = [ "tinyvec_macros", ] @@ -2692,9 +2733,9 @@ dependencies = [ [[package]] name = "tokio-util" -version = "0.7.12" +version = "0.7.13" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "61e7c3654c13bcd040d4a03abee2c75b1d14a37b423cf5a813ceae1cc903ec6a" +checksum = "d7fcaa8d55a2bdd6b83ace262b016eca0d79ee02818c5c1bcdf0305114081078" dependencies = [ "bytes", "futures-core", @@ -2711,11 +2752,11 @@ checksum = "877c5b330756d856ffcc4553ab34a5684481ade925ecc54bcd1bf02b1d0d4d52" dependencies = [ "async-stream", "async-trait", - "axum 0.7.7", + "axum 0.7.9", "base64", "bytes", "h2", - "http 1.2.0", + "http", "http-body", "http-body-util", "hyper", @@ -2762,7 +2803,7 @@ dependencies = [ "indexmap 1.9.3", "pin-project", "pin-project-lite", - "rand", + "rand 0.8.5", "slab", "tokio", "tokio-util", @@ -2795,7 +2836,7 @@ checksum = "403fa3b783d4b626a8ad51d766ab03cb6d2dbfc46b1c5d4448395e6628dc9697" dependencies = [ "bitflags", "bytes", - "http 1.2.0", + "http", "http-body", "pin-project-lite", "tower-layer", @@ -2937,18 +2978,15 @@ checksum = "e421abadd41a4225275504ea4d6566923418b7f05506fbc9c0fe86ba7396114b" [[package]] name = "unicase" -version = "2.7.0" +version = "2.8.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f7d2d4dafb69621809a81864c9c1b864479e1235c0dd4e199924b9742439ed89" -dependencies = [ - "version_check", -] +checksum = "75b844d17643ee918803943289730bec8aac480150456169e647ed0b576ba539" [[package]] name = "unicode-ident" -version = "1.0.13" +version = "1.0.17" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e91b56cd4cadaeb79bbf1a5645f6b4f8dc5bde8834ad5894a8db35fda9efa1fe" +checksum = "00e2473a93778eb0bad35909dff6a10d28e63f792f16ed15e404fca9d5eeedbe" [[package]] name = "untrusted" @@ -2987,18 +3025,18 @@ checksum = "06abde3611657adf66d383f00b093d7faecc7fa57071cce2578660c9f1010821" [[package]] name = "uuid" -version = "1.12.1" +version = "1.14.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b3758f5e68192bb96cc8f9b7e2c2cfdabb435499a28499a42f8f984092adad4b" +checksum = "93d59ca99a559661b96bf898d8fce28ed87935fd2bea9f05983c1464dd6c71b1" dependencies = [ - "getrandom", + "getrandom 0.3.1", ] [[package]] name = "valuable" -version = "0.1.0" +version = "0.1.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "830b7e5d4d90034032940e4ace0d9a9a057e7a45cd94e6c007832e39edb82f6d" +checksum = "ba73ea9cf16a25df0c8caa16c51acb937d5712a8429db78a3ee29d5dcacd3a65" [[package]] name = "vcpkg" @@ -3027,26 +3065,35 @@ version = "0.11.0+wasi-snapshot-preview1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9c8d87e72b64a3b4db28d11ce29237c246188f4f51057d65a7eab63b7987e423" +[[package]] +name = "wasi" +version = "0.13.3+wasi-0.2.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = 
"26816d2e1a4a36a2940b96c5296ce403917633dff8f3440e9b236ed6f6bacad2" +dependencies = [ + "wit-bindgen-rt", +] + [[package]] name = "wasm-bindgen" -version = "0.2.93" +version = "0.2.100" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a82edfc16a6c469f5f44dc7b571814045d60404b55a0ee849f9bcfa2e63dd9b5" +checksum = "1edc8929d7499fc4e8f0be2262a241556cfc54a0bea223790e71446f2aab1ef5" dependencies = [ "cfg-if", "once_cell", + "rustversion", "wasm-bindgen-macro", ] [[package]] name = "wasm-bindgen-backend" -version = "0.2.93" +version = "0.2.100" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9de396da306523044d3302746f1208fa71d7532227f15e347e2d93e4145dd77b" +checksum = "2f0a0651a5c2bc21487bde11ee802ccaf4c51935d0d3d42a6101f98161700bc6" dependencies = [ "bumpalo", "log", - "once_cell", "proc-macro2", "quote", "syn", @@ -3055,21 +3102,22 @@ dependencies = [ [[package]] name = "wasm-bindgen-futures" -version = "0.4.43" +version = "0.4.50" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "61e9300f63a621e96ed275155c108eb6f843b6a26d053f122ab69724559dc8ed" +checksum = "555d470ec0bc3bb57890405e5d4322cc9ea83cebb085523ced7be4144dac1e61" dependencies = [ "cfg-if", "js-sys", + "once_cell", "wasm-bindgen", "web-sys", ] [[package]] name = "wasm-bindgen-macro" -version = "0.2.93" +version = "0.2.100" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "585c4c91a46b072c92e908d99cb1dcdf95c5218eeb6f3bf1efa991ee7a68cccf" +checksum = "7fe63fc6d09ed3792bd0897b314f53de8e16568c2b3f7982f468c0bf9bd0b407" dependencies = [ "quote", "wasm-bindgen-macro-support", @@ -3077,9 +3125,9 @@ dependencies = [ [[package]] name = "wasm-bindgen-macro-support" -version = "0.2.93" +version = "0.2.100" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "afc340c74d9005395cf9dd098506f7f44e38f2b4a21c6aaacf9a105ea5e1e836" +checksum = "8ae87ea40c9f689fc23f209965b6fb8a99ad69aeeb0231408be24920604395de" dependencies = [ "proc-macro2", "quote", @@ -3090,15 +3138,18 @@ dependencies = [ [[package]] name = "wasm-bindgen-shared" -version = "0.2.93" +version = "0.2.100" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c62a0a307cb4a311d3a07867860911ca130c3494e8c2719593806c08bc5d0484" +checksum = "1a05d73b933a847d6cccdda8f838a22ff101ad9bf93e33684f39c1f5f0eece3d" +dependencies = [ + "unicode-ident", +] [[package]] name = "web-sys" -version = "0.3.70" +version = "0.3.77" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "26fdeaafd9bd129f65e7c031593c24d62186301e0c72c8978fa1678be7d532c0" +checksum = "33b6dd2ef9186f1f2072e409e99cd22a975331a6b3591b12c764e0e55c60d5d2" dependencies = [ "js-sys", "wasm-bindgen", @@ -3116,9 +3167,9 @@ dependencies = [ [[package]] name = "webpki-roots" -version = "0.26.6" +version = "0.26.8" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "841c67bff177718f1d4dfefde8d8f0e78f9b6589319ba88312f567fc5841a958" +checksum = "2210b291f7ea53617fbafcc4939f10914214ec15aace5ba62293a668f322c5c9" dependencies = [ "rustls-pki-types", ] @@ -3351,6 +3402,15 @@ dependencies = [ "windows-sys 0.48.0", ] +[[package]] +name = "wit-bindgen-rt" +version = "0.33.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3268f3d866458b787f390cf61f4bbb563b922d091359f9608842999eaee3943c" +dependencies = [ + "bitflags", +] + [[package]] name = "write16" version = "1.0.0" @@ -3400,7 +3460,16 @@ source = 
"registry+https://github.com/rust-lang/crates.io-index" checksum = "1b9b4fd18abc82b8136838da5d50bae7bdea537c574d8dc1a34ed098d6c166f0" dependencies = [ "byteorder", - "zerocopy-derive", + "zerocopy-derive 0.7.35", +] + +[[package]] +name = "zerocopy" +version = "0.8.20" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "dde3bb8c68a8f3f1ed4ac9221aad6b10cece3e60a8e2ea54a6a2dec806d0084c" +dependencies = [ + "zerocopy-derive 0.8.20", ] [[package]] @@ -3414,6 +3483,17 @@ dependencies = [ "syn", ] +[[package]] +name = "zerocopy-derive" +version = "0.8.20" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "eea57037071898bf96a6da35fd626f4f27e9cee3ead2a6c703cf09d472b2e700" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + [[package]] name = "zerofrom" version = "0.1.5" diff --git a/tests/canary_test.rs b/tests/canary_test.rs index 653c0e84..19c8a129 100644 --- a/tests/canary_test.rs +++ b/tests/canary_test.rs @@ -1,33 +1,12 @@ -use std::sync::Arc; - use axum_test::TestServer; -use fms_guardrails_orchestr8::{ - config::OrchestratorConfig, - orchestrator::Orchestrator, - server::{get_health_app, ServerState}, -}; +use common::{ensure_global_rustls_state, shared_state, ONCE}; +use fms_guardrails_orchestr8::server::get_health_app; use hyper::StatusCode; -use rustls::crypto::ring; use serde_json::Value; -use tokio::sync::OnceCell; use tracing::debug; use tracing_test::traced_test; -/// Async lazy initialization of shared state using tokio::sync::OnceCell -static ONCE: OnceCell> = OnceCell::const_new(); - -/// The actual async function that initializes the shared state if not already initialized -async fn shared_state() -> Arc { - let config = OrchestratorConfig::load("tests/test.config.yaml") - .await - .unwrap(); - let orchestrator = Orchestrator::new(config, false).await.unwrap(); - Arc::new(ServerState::new(orchestrator)) -} - -fn ensure_global_rustls_state() { - let _ = ring::default_provider().install_default(); -} +mod common; /// Checks if the health endpoint is working /// NOTE: We do not currently mock client services yet, so this test is diff --git a/tests/common/mod.rs b/tests/common/mod.rs new file mode 100644 index 00000000..971cf599 --- /dev/null +++ b/tests/common/mod.rs @@ -0,0 +1,23 @@ +use std::sync::Arc; + +use fms_guardrails_orchestr8::{ + config::OrchestratorConfig, orchestrator::Orchestrator, server::ServerState, +}; +use rustls::crypto::ring; +use tokio::sync::OnceCell; + +/// Async lazy initialization of shared state using tokio::sync::OnceCell +pub static ONCE: OnceCell> = OnceCell::const_new(); + +/// The actual async function that initializes the shared state if not already initialized +pub async fn shared_state() -> Arc { + let config = OrchestratorConfig::load("tests/test.config.yaml") + .await + .unwrap(); + let orchestrator = Orchestrator::new(config, false).await.unwrap(); + Arc::new(ServerState::new(orchestrator)) +} + +pub fn ensure_global_rustls_state() { + let _ = ring::default_provider().install_default(); +} From 798b730ad87427c72a0748404fc5282585cbc6bd Mon Sep 17 00:00:00 2001 From: Mateus Devino Date: Mon, 6 Jan 2025 14:21:17 -0300 Subject: [PATCH 002/117] /detection/content base test case Signed-off-by: Mateus Devino --- Cargo.lock | 39 ++++++++++++++ Cargo.toml | 1 + src/server.rs | 106 ++++++++++++++++++------------------- tests/detection_content.rs | 97 +++++++++++++++++++++++++++++++++ tests/test.config.yaml | 7 +++ 5 files changed, 197 insertions(+), 53 deletions(-) create mode 100644 
tests/detection_content.rs diff --git a/Cargo.lock b/Cargo.lock index c3973b1f..3c23bb71 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -713,6 +713,7 @@ dependencies = [ "hyper-rustls", "hyper-timeout", "hyper-util", + "mocktail", "opentelemetry", "opentelemetry-http", "opentelemetry-otlp", @@ -1556,6 +1557,29 @@ dependencies = [ "windows-sys 0.52.0", ] +[[package]] +name = "mocktail" +version = "0.1.0-alpha" +dependencies = [ + "bytes", + "futures", + "http", + "http-body", + "http-body-util", + "hyper", + "hyper-util", + "prost", + "rand 0.8.5", + "reqwest", + "serde", + "serde_json", + "thiserror 2.0.11", + "tokio", + "tonic", + "tracing", + "url", +] + [[package]] name = "multimap" version = "0.10.0" @@ -2150,11 +2174,13 @@ dependencies = [ "tokio", "tokio-native-tls", "tokio-rustls", + "tokio-util", "tower 0.5.2", "tower-service", "url", "wasm-bindgen", "wasm-bindgen-futures", + "wasm-streams", "web-sys", "webpki-roots", "windows-registry", @@ -3145,6 +3171,19 @@ dependencies = [ "unicode-ident", ] +[[package]] +name = "wasm-streams" +version = "0.4.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "15053d8d85c7eccdbefef60f06769760a563c7f0a9d6902a13d35c7800b0ad65" +dependencies = [ + "futures-util", + "js-sys", + "wasm-bindgen", + "wasm-bindgen-futures", + "web-sys", +] + [[package]] name = "web-sys" version = "0.3.77" diff --git a/Cargo.toml b/Cargo.toml index 69d10a05..7b8df9f6 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -63,6 +63,7 @@ tonic-build = "0.12.3" [dev-dependencies] axum-test = "17.1.0" faux = "0.1.12" +mocktail = { path = "../mocktail/mocktail" } tracing-test = "0.2.5" [profile.release] diff --git a/src/server.rs b/src/server.rs index 28d9ad93..d9bbe93e 100644 --- a/src/server.rs +++ b/src/server.rs @@ -166,59 +166,7 @@ pub async fn run( } // (2b) Add main guardrails server routes - let mut router = Router::new() - .route( - &format!("{}/classification-with-text-generation", API_PREFIX), - post(classification_with_gen), - ) - .route( - &format!("{}/detection/stream-content", TEXT_API_PREFIX), - post(stream_content_detection), - ) - .route( - &format!( - "{}/server-streaming-classification-with-text-generation", - API_PREFIX - ), - post(stream_classification_with_gen), - ) - .route( - &format!("{}/generation-detection", TEXT_API_PREFIX), - post(generation_with_detection), - ) - .route( - &format!("{}/detection/content", TEXT_API_PREFIX), - post(detection_content), - ) - .route( - &format!("{}/detection/chat", TEXT_API_PREFIX), - post(detect_chat), - ) - .route( - &format!("{}/detection/context", TEXT_API_PREFIX), - post(detect_context_documents), - ) - .route( - &format!("{}/detection/generated", TEXT_API_PREFIX), - post(detect_generated), - ); - - // If chat generation is configured, enable the chat completions detection endpoint. 
- if shared_state.orchestrator.config().chat_generation.is_some() { - info!("Enabling chat completions detection endpoint"); - router = router.route( - "/api/v2/chat/completions-detection", - post(chat_completions_detection), - ); - } - - let app = router.with_state(shared_state).layer( - TraceLayer::new_for_http() - .make_span_with(utils::trace::incoming_request_span) - .on_request(utils::trace::on_incoming_request) - .on_response(utils::trace::on_outgoing_response) - .on_eos(utils::trace::on_outgoing_eos), - ); + let app = get_app(shared_state); // (2c) Generate main guardrails server handle based on whether TLS is needed let listener: TcpListener = TcpListener::bind(&http_addr) @@ -323,6 +271,58 @@ pub fn get_health_app(state: Arc) -> Router { .with_state(state) } +pub fn get_app(state: Arc) -> Router { + let mut router = Router::new() + .route( + &format!("{}/classification-with-text-generation", API_PREFIX), + post(classification_with_gen), + ) + .route( + &format!( + "{}/server-streaming-classification-with-text-generation", + API_PREFIX + ), + post(stream_classification_with_gen), + ) + .route( + &format!("{}/generation-detection", TEXT_API_PREFIX), + post(generation_with_detection), + ) + .route( + &format!("{}/detection/content", TEXT_API_PREFIX), + post(detection_content), + ) + .route( + &format!("{}/detection/chat", TEXT_API_PREFIX), + post(detect_chat), + ) + .route( + &format!("{}/detection/context", TEXT_API_PREFIX), + post(detect_context_documents), + ) + .route( + &format!("{}/detection/generated", TEXT_API_PREFIX), + post(detect_generated), + ); + + // If chat generation is configured, enable the chat completions detection endpoint. + if state.orchestrator.config().chat_generation.is_some() { + info!("Enabling chat completions detection endpoint"); + router = router.route( + "/api/v2/chat/completions-detection", + post(chat_completions_detection), + ); + } + + router.with_state(state).layer( + TraceLayer::new_for_http() + .make_span_with(utils::trace::incoming_request_span) + .on_request(utils::trace::on_incoming_request) + .on_response(utils::trace::on_outgoing_response) + .on_eos(utils::trace::on_outgoing_eos), + ) +} + async fn health() -> Result { // NOTE: we are only adding the package information in the `health` endpoint to have this endpoint // provide a non empty 200 response. If we need to add more information regarding dependencies version diff --git a/tests/detection_content.rs b/tests/detection_content.rs new file mode 100644 index 00000000..d060cb05 --- /dev/null +++ b/tests/detection_content.rs @@ -0,0 +1,97 @@ +/* + Copyright FMS Guardrails Orchestrator Authors + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. 
+ +*/ + +use axum_test::TestServer; +use common::{ensure_global_rustls_state, shared_state, ONCE}; +use fms_guardrails_orchestr8::{ + clients::detector::{ContentAnalysisRequest, ContentAnalysisResponse}, + models::{DetectorParams, TextContentDetectionResult}, + server::get_app, +}; +use hyper::StatusCode; +use mocktail::mock::MockSet; +use mocktail::prelude::*; +use serde_json::json; +use tracing::debug; +use tracing_test::traced_test; + +mod common; + +/// Asserts a scenario with a single detection works as expected. +/// +/// This test mocks a detector that detects the word "word" in a given input. +#[traced_test] +#[tokio::test] +async fn test_single_detection() { + ensure_global_rustls_state(); + let shared_state = ONCE.get_or_init(shared_state).await.clone(); + let server = TestServer::new(get_app(shared_state)).unwrap(); + let detector_name = "content_detector"; + + // Add detector mock + let mut mocks = MockSet::new(); + mocks.insert( + MockPath::new(Method::POST, "/api/v1/text/contents"), + Mock::new( + MockRequest::json(json!({ + "contents": ["This sentence has a detection on the last word."], + "detector_params": {}, + })), + MockResponse::json(vec![vec![ContentAnalysisResponse { + start: 42, + end: 46, + text: "word".to_string(), + detection: "word".to_string(), + detection_type: "word_detection".to_string(), + score: 1.0, + evidence: None, + }]]), + ), + ); + + let mock_detector_server = + HttpMockServer::new_with_port("content_detector", mocks, 8001).unwrap(); + let _ = mock_detector_server.start().await; + + let response = server + .post("/api/v2/text/detection/content") + .json(&json!({ + "content": "This sentence has a detection on the last word.", + "detectors": { + detector_name: {} + } + })) + .await; + + debug!("{:#?}", response); + + response.assert_status(StatusCode::OK); + response.assert_json(&json!( + { + "detections": [ + { + "start": 42, + "end": 46, + "text": "word", + "detection": "word", + "detection_type": "word_detection", + "score": 1.0, + } + ] + } + )); +} diff --git a/tests/test.config.yaml b/tests/test.config.yaml index 91749356..13072e51 100644 --- a/tests/test.config.yaml +++ b/tests/test.config.yaml @@ -17,3 +17,10 @@ detectors: port: 8000 chunker_id: test_chunker default_threshold: 0.5 + content_detector: + type: text_contents + service: + hostname: 0.0.0.0 + port: 8001 + chunker_id: whole_doc_chunker + default_threshold: 0.5 From c02c65377459f0bbf6c2290f8707c9a645047a52 Mon Sep 17 00:00:00 2001 From: Mateus Devino Date: Wed, 8 Jan 2025 11:11:01 -0300 Subject: [PATCH 003/117] Replace json macros with strong types Signed-off-by: Mateus Devino --- tests/detection_content.rs | 54 ++++++++++++++++++-------------------- 1 file changed, 25 insertions(+), 29 deletions(-) diff --git a/tests/detection_content.rs b/tests/detection_content.rs index d060cb05..1f911ddb 100644 --- a/tests/detection_content.rs +++ b/tests/detection_content.rs @@ -15,24 +15,25 @@ */ +use std::collections::HashMap; + use axum_test::TestServer; use common::{ensure_global_rustls_state, shared_state, ONCE}; use fms_guardrails_orchestr8::{ clients::detector::{ContentAnalysisRequest, ContentAnalysisResponse}, - models::{DetectorParams, TextContentDetectionResult}, + models::{DetectorParams, TextContentDetectionHttpRequest, TextContentDetectionResult}, server::get_app, }; use hyper::StatusCode; use mocktail::mock::MockSet; use mocktail::prelude::*; -use serde_json::json; use tracing::debug; use tracing_test::traced_test; mod common; /// Asserts a scenario with a single detection works as 
expected. -/// +/// /// This test mocks a detector that detects the word "word" in a given input. #[traced_test] #[tokio::test] @@ -40,17 +41,17 @@ async fn test_single_detection() { ensure_global_rustls_state(); let shared_state = ONCE.get_or_init(shared_state).await.clone(); let server = TestServer::new(get_app(shared_state)).unwrap(); - let detector_name = "content_detector"; + let detector_name = "content_detector".to_string(); // Add detector mock let mut mocks = MockSet::new(); mocks.insert( MockPath::new(Method::POST, "/api/v1/text/contents"), Mock::new( - MockRequest::json(json!({ - "contents": ["This sentence has a detection on the last word."], - "detector_params": {}, - })), + MockRequest::json(ContentAnalysisRequest { + contents: vec!["This sentence has a detection on the last word.".to_string()], + detector_params: DetectorParams::new(), + }), MockResponse::json(vec![vec![ContentAnalysisResponse { start: 42, end: 46, @@ -69,29 +70,24 @@ async fn test_single_detection() { let response = server .post("/api/v2/text/detection/content") - .json(&json!({ - "content": "This sentence has a detection on the last word.", - "detectors": { - detector_name: {} - } - })) + .json(&TextContentDetectionHttpRequest { + content: "This sentence has a detection on the last word.".to_string(), + detectors: HashMap::from([(detector_name, DetectorParams::new())]), + }) .await; - debug!("{:#?}", response); + debug!(?response); response.assert_status(StatusCode::OK); - response.assert_json(&json!( - { - "detections": [ - { - "start": 42, - "end": 46, - "text": "word", - "detection": "word", - "detection_type": "word_detection", - "score": 1.0, - } - ] - } - )); + response.assert_json(&TextContentDetectionResult { + detections: vec![ContentAnalysisResponse { + start: 42, + end: 46, + text: "word".to_string(), + detection: "word".to_string(), + detection_type: "word_detection".to_string(), + score: 1.0, + evidence: None, + }], + }); } From 2c80ddc492a889a16ed23b5e14874b2357944a6b Mon Sep 17 00:00:00 2001 From: Mateus Devino Date: Thu, 9 Jan 2025 10:45:09 -0300 Subject: [PATCH 004/117] Update test detector name to mention whole_doc_chunker Signed-off-by: Mateus Devino --- tests/detection_content.rs | 4 ++-- tests/test.config.yaml | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/tests/detection_content.rs b/tests/detection_content.rs index 1f911ddb..b6b66bbf 100644 --- a/tests/detection_content.rs +++ b/tests/detection_content.rs @@ -32,7 +32,7 @@ use tracing_test::traced_test; mod common; -/// Asserts a scenario with a single detection works as expected. +/// Asserts a scenario with a single detection works as expected (assumes a detector configured with whole_doc_chunker). /// /// This test mocks a detector that detects the word "word" in a given input. 
#[traced_test] @@ -41,7 +41,7 @@ async fn test_single_detection() { ensure_global_rustls_state(); let shared_state = ONCE.get_or_init(shared_state).await.clone(); let server = TestServer::new(get_app(shared_state)).unwrap(); - let detector_name = "content_detector".to_string(); + let detector_name = "content_detector_whole_doc".to_string(); // Add detector mock let mut mocks = MockSet::new(); diff --git a/tests/test.config.yaml b/tests/test.config.yaml index 13072e51..e5033a81 100644 --- a/tests/test.config.yaml +++ b/tests/test.config.yaml @@ -17,10 +17,10 @@ detectors: port: 8000 chunker_id: test_chunker default_threshold: 0.5 - content_detector: + content_detector_whole_doc: type: text_contents service: - hostname: 0.0.0.0 + hostname: localhost port: 8001 chunker_id: whole_doc_chunker default_threshold: 0.5 From 20375338ad67e6b1c757a15a6584ea5234ef7203 Mon Sep 17 00:00:00 2001 From: Mateus Devino Date: Thu, 9 Jan 2025 10:48:45 -0300 Subject: [PATCH 005/117] Testing mock for chunker gRPC call Signed-off-by: Mateus Devino --- tests/chunker.rs | 107 +++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 107 insertions(+) create mode 100644 tests/chunker.rs diff --git a/tests/chunker.rs b/tests/chunker.rs new file mode 100644 index 00000000..6758ffae --- /dev/null +++ b/tests/chunker.rs @@ -0,0 +1,107 @@ +/* + Copyright FMS Guardrails Orchestrator Authors + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + +*/ + +use fms_guardrails_orchestr8::{ + clients::chunker::{ChunkerClient, MODEL_ID_HEADER_NAME}, + config::ServiceConfig, + pb::{ + caikit::runtime::chunkers::ChunkerTokenizationTaskRequest, + caikit_data_model::nlp::{Token, TokenizationResults}, + }, +}; +use mocktail::mock::MockSet; +use mocktail::prelude::*; +use tracing_test::traced_test; + +mod common; + +generate_grpc_server!( + "caikit.runtime.Chunkers.ChunkersService", + MockChunkersServiceServer +); + +#[traced_test] +#[tokio::test] +async fn test_isolated_chunker_call() -> Result<(), anyhow::Error> { + // Add detector mock + let chunker_id = "sentence_chunker"; + let mut chunker_headers = HeaderMap::new(); + chunker_headers.insert(MODEL_ID_HEADER_NAME, chunker_id.parse().unwrap()); + + let expected_response = TokenizationResults { + results: vec![ + Token { + start: 0, + end: 9, + text: "Hi there!".to_string(), + }, + Token { + start: 0, + end: 9, + text: "how are you?".to_string(), + }, + Token { + start: 0, + end: 9, + text: "I am great!".to_string(), + }, + ], + token_count: 0, + }; + + let mut mocks = MockSet::new(); + mocks.insert( + MockPath::new( + Method::POST, + "/caikit.runtime.Chunkers.ChunkersService/ChunkerTokenizationTaskPredict", + ), + Mock::new( + MockRequest::pb(ChunkerTokenizationTaskRequest { + text: "Hi there! how are you? 
I am great!".to_string(), + }) + .with_headers(chunker_headers), + MockResponse::pb(expected_response.clone()), + ), + ); + + let mock_chunker_server = MockChunkersServiceServer::new(mocks).unwrap(); + let _ = mock_chunker_server.start().await; + + let client = ChunkerClient::new(&ServiceConfig { + hostname: "localhost".to_string(), + port: Some(mock_chunker_server.addr().port()), + request_timeout: None, + tls: None, + }) + .await; + + let response = client + .tokenization_task_predict( + chunker_id, + ChunkerTokenizationTaskRequest { + text: "Hi there! how are you? I am great!".to_string(), + }, + ) + .await; + + dbg!(&response); + + assert!(response.is_ok()); + assert!(response.unwrap() == expected_response); + + Ok(()) +} From 64ef86af37511010eb54121a08765227e76ac3c9 Mon Sep 17 00:00:00 2001 From: Mateus Devino Date: Wed, 5 Feb 2025 12:42:35 -0300 Subject: [PATCH 006/117] Update detection_content.rs::test_single_detection() to include detector_id Signed-off-by: Mateus Devino --- tests/detection_content.rs | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/tests/detection_content.rs b/tests/detection_content.rs index b6b66bbf..7f1708cd 100644 --- a/tests/detection_content.rs +++ b/tests/detection_content.rs @@ -58,6 +58,7 @@ async fn test_single_detection() { text: "word".to_string(), detection: "word".to_string(), detection_type: "word_detection".to_string(), + detector_id: Some(detector_name.to_string()), score: 1.0, evidence: None, }]]), @@ -72,7 +73,7 @@ async fn test_single_detection() { .post("/api/v2/text/detection/content") .json(&TextContentDetectionHttpRequest { content: "This sentence has a detection on the last word.".to_string(), - detectors: HashMap::from([(detector_name, DetectorParams::new())]), + detectors: HashMap::from([(detector_name.to_string(), DetectorParams::new())]), }) .await; @@ -86,6 +87,7 @@ async fn test_single_detection() { text: "word".to_string(), detection: "word".to_string(), detection_type: "word_detection".to_string(), + detector_id: Some(detector_name.to_string()), score: 1.0, evidence: None, }], From a208f8058b70c20d2dae0dc3dbdce2263ba3f9ec Mon Sep 17 00:00:00 2001 From: Mateus Devino Date: Wed, 5 Feb 2025 14:25:12 -0300 Subject: [PATCH 007/117] Update detection_content.rs::test_single_detection_whole_doc() mocks to use random ports Signed-off-by: Mateus Devino --- tests/common/mod.rs | 2 ++ tests/detection_content.rs | 34 +++++++++++++++++++++++++--------- tests/test.config.yaml | 1 - 3 files changed, 27 insertions(+), 10 deletions(-) diff --git a/tests/common/mod.rs b/tests/common/mod.rs index 971cf599..2483870b 100644 --- a/tests/common/mod.rs +++ b/tests/common/mod.rs @@ -9,6 +9,8 @@ use tokio::sync::OnceCell; /// Async lazy initialization of shared state using tokio::sync::OnceCell pub static ONCE: OnceCell> = OnceCell::const_new(); +pub const CONFIG_FILE_PATH: &str = "tests/test.config.yaml"; + /// The actual async function that initializes the shared state if not already initialized pub async fn shared_state() -> Arc { let config = OrchestratorConfig::load("tests/test.config.yaml") diff --git a/tests/detection_content.rs b/tests/detection_content.rs index 7f1708cd..0a4c4ad4 100644 --- a/tests/detection_content.rs +++ b/tests/detection_content.rs @@ -15,14 +15,16 @@ */ -use std::collections::HashMap; +use std::{collections::HashMap, sync::Arc}; use axum_test::TestServer; -use common::{ensure_global_rustls_state, shared_state, ONCE}; +use common::{ensure_global_rustls_state, CONFIG_FILE_PATH}; use fms_guardrails_orchestr8::{ 
clients::detector::{ContentAnalysisRequest, ContentAnalysisResponse}, + config::OrchestratorConfig, models::{DetectorParams, TextContentDetectionHttpRequest, TextContentDetectionResult}, - server::get_app, + orchestrator::Orchestrator, + server::{get_app, ServerState}, }; use hyper::StatusCode; use mocktail::mock::MockSet; @@ -37,11 +39,9 @@ mod common; /// This test mocks a detector that detects the word "word" in a given input. #[traced_test] #[tokio::test] -async fn test_single_detection() { +async fn test_single_detection_whole_doc() { ensure_global_rustls_state(); - let shared_state = ONCE.get_or_init(shared_state).await.clone(); - let server = TestServer::new(get_app(shared_state)).unwrap(); - let detector_name = "content_detector_whole_doc".to_string(); + let detector_name = "content_detector_whole_doc"; // Add detector mock let mut mocks = MockSet::new(); @@ -65,10 +65,25 @@ async fn test_single_detection() { ), ); - let mock_detector_server = - HttpMockServer::new_with_port("content_detector", mocks, 8001).unwrap(); + // Start mock server + let mock_detector_server = HttpMockServer::new(detector_name, mocks).unwrap(); let _ = mock_detector_server.start().await; + let mut config = OrchestratorConfig::load(CONFIG_FILE_PATH).await.unwrap(); + + // assign mock server port to detector config + config + .detectors + .get_mut(detector_name) + .unwrap() + .service + .port = Some(mock_detector_server.addr().port()); + + let orchestrator = Orchestrator::new(config, false).await.unwrap(); + let shared_state = Arc::new(ServerState::new(orchestrator)); + let server = TestServer::new(get_app(shared_state)).unwrap(); + + // Make orchestrator call let response = server .post("/api/v2/text/detection/content") .json(&TextContentDetectionHttpRequest { @@ -79,6 +94,7 @@ async fn test_single_detection() { debug!(?response); + // assertions response.assert_status(StatusCode::OK); response.assert_json(&TextContentDetectionResult { detections: vec![ContentAnalysisResponse { diff --git a/tests/test.config.yaml b/tests/test.config.yaml index e5033a81..cb176e64 100644 --- a/tests/test.config.yaml +++ b/tests/test.config.yaml @@ -21,6 +21,5 @@ detectors: type: text_contents service: hostname: localhost - port: 8001 chunker_id: whole_doc_chunker default_threshold: 0.5 From e30a2df13ea4c84251869b389a0a7727ee3ba89d Mon Sep 17 00:00:00 2001 From: Mateus Devino Date: Wed, 5 Feb 2025 14:33:36 -0300 Subject: [PATCH 008/117] Make test_single_detection_whole_doc() more meaningful Signed-off-by: Mateus Devino --- tests/detection_content.rs | 28 ++++++++++++++-------------- 1 file changed, 14 insertions(+), 14 deletions(-) diff --git a/tests/detection_content.rs b/tests/detection_content.rs index 0a4c4ad4..58ed1ec9 100644 --- a/tests/detection_content.rs +++ b/tests/detection_content.rs @@ -36,12 +36,12 @@ mod common; /// Asserts a scenario with a single detection works as expected (assumes a detector configured with whole_doc_chunker). /// -/// This test mocks a detector that detects the word "word" in a given input. +/// This test mocks a detector that detects text between . 
#[traced_test] #[tokio::test] async fn test_single_detection_whole_doc() { ensure_global_rustls_state(); - let detector_name = "content_detector_whole_doc"; + let detector_name = "angle_brackets_detector_whole_doc"; // Add detector mock let mut mocks = MockSet::new(); @@ -49,15 +49,15 @@ async fn test_single_detection_whole_doc() { MockPath::new(Method::POST, "/api/v1/text/contents"), Mock::new( MockRequest::json(ContentAnalysisRequest { - contents: vec!["This sentence has a detection on the last word.".to_string()], + contents: vec!["This sentence has <a detection here>.".to_string()], detector_params: DetectorParams::new(), }), MockResponse::json(vec![vec![ContentAnalysisResponse { - start: 42, - end: 46, - text: "word".to_string(), - detection: "word".to_string(), - detection_type: "word_detection".to_string(), + start: 18, + end: 35, + text: "a detection here".to_string(), + detection: "has_angle_brackets".to_string(), + detection_type: "angle_brackets".to_string(), detector_id: Some(detector_name.to_string()), score: 1.0, evidence: None, @@ -87,7 +87,7 @@ async fn test_single_detection_whole_doc() { let response = server .post("/api/v2/text/detection/content") .json(&TextContentDetectionHttpRequest { - content: "This sentence has a detection on the last word.".to_string(), + content: "This sentence has <a detection here>.".to_string(), detectors: HashMap::from([(detector_name.to_string(), DetectorParams::new())]), }) .await; @@ -98,11 +98,11 @@ async fn test_single_detection_whole_doc() { response.assert_status(StatusCode::OK); response.assert_json(&TextContentDetectionResult { detections: vec![ContentAnalysisResponse { - start: 42, - end: 46, - text: "word".to_string(), - detection: "word".to_string(), - detection_type: "word_detection".to_string(), + start: 18, + end: 35, + text: "a detection here".to_string(), + detection: "has_angle_brackets".to_string(), + detection_type: "angle_brackets".to_string(), detector_id: Some(detector_name.to_string()), score: 1.0, evidence: None, From 7b4f0ac31e84077347582986d6e1a6fad9a82c9a Mon Sep 17 00:00:00 2001 From: Mateus Devino Date: Thu, 6 Feb 2025 13:35:39 -0300 Subject: [PATCH 009/117] Add test case: detection_content.rs::test_single_detection_sentence_chunker Signed-off-by: Mateus Devino --- tests/detection_content.rs | 148 ++++++++++++++++++++++++++++++++++++- tests/test.config.yaml | 12 ++- 2 files changed, 157 insertions(+), 3 deletions(-) diff --git a/tests/detection_content.rs b/tests/detection_content.rs index 58ed1ec9..243ca74c 100644 --- a/tests/detection_content.rs +++ b/tests/detection_content.rs @@ -20,10 +20,17 @@ use std::{collections::HashMap, sync::Arc}; use axum_test::TestServer; use common::{ensure_global_rustls_state, CONFIG_FILE_PATH}; use fms_guardrails_orchestr8::{ - clients::detector::{ContentAnalysisRequest, ContentAnalysisResponse}, + clients::{ + chunker::MODEL_ID_HEADER_NAME as CHUNKER_MODEL_ID_HEADER_NAME, + detector::{ContentAnalysisRequest, ContentAnalysisResponse}, + }, config::OrchestratorConfig, models::{DetectorParams, TextContentDetectionHttpRequest, TextContentDetectionResult}, orchestrator::Orchestrator, + pb::{ + caikit::runtime::chunkers::ChunkerTokenizationTaskRequest, + caikit_data_model::nlp::{Token, TokenizationResults}, + }, server::{get_app, ServerState}, }; use hyper::StatusCode; use mocktail::prelude::*; @@ -34,6 +41,13 @@ use tracing_test::traced_test; mod common; +generate_grpc_server!( + "caikit.runtime.Chunkers.ChunkersService", + MockChunkersServiceServer +); + +const ORCHESTRATOR_DETECTION_CONTENT_ENDPOINT: &str = "/api/v2/text/detection/content"; + 
/// Asserts a scenario with a single detection works as expected (assumes a detector configured with whole_doc_chunker). /// /// This test mocks a detector that detects text between . @@ -85,7 +99,7 @@ async fn test_single_detection_whole_doc() { // Make orchestrator call let response = server - .post("/api/v2/text/detection/content") + .post(ORCHESTRATOR_DETECTION_CONTENT_ENDPOINT) .json(&TextContentDetectionHttpRequest { content: "This sentence has .".to_string(), detectors: HashMap::from([(detector_name.to_string(), DetectorParams::new())]), @@ -109,3 +123,133 @@ async fn test_single_detection_whole_doc() { }], }); } + +/// Asserts a scenario with a single detection works as expected (with sentence chunker). +/// +/// This test mocks a detector that detects text between . +#[traced_test] +#[tokio::test] +async fn test_single_detection_sentence_chunker() { + ensure_global_rustls_state(); + + // Add chunker mock + let chunker_id = "sentence_chunker"; + let mut chunker_headers = HeaderMap::new(); + chunker_headers.insert(CHUNKER_MODEL_ID_HEADER_NAME, chunker_id.parse().unwrap()); + + let expected_chunker_response = TokenizationResults { + results: vec![ + Token { + start: 0, + end: 40, + text: "This sentence does not have a detection.".to_string(), + }, + Token { + start: 41, + end: 61, + text: "But .".to_string(), + }, + ], + token_count: 0, + }; + let mut mocks = MockSet::new(); + mocks.insert( + MockPath::new( + Method::POST, + "/caikit.runtime.Chunkers.ChunkersService/ChunkerTokenizationTaskPredict", + ), + Mock::new( + MockRequest::pb(ChunkerTokenizationTaskRequest { + text: "This sentence does not have a detection. But .".to_string(), + }) + .with_headers(chunker_headers), + MockResponse::pb(expected_chunker_response.clone()), + ), + ); + + let mock_chunker_server = MockChunkersServiceServer::new(mocks).unwrap(); + let _ = mock_chunker_server.start().await; + + // Add detector mock + let detector_name = "angle_brackets_detector_sentence"; + let mut mocks = MockSet::new(); + mocks.insert( + MockPath::new(Method::POST, "/api/v1/text/contents"), + Mock::new( + MockRequest::json(ContentAnalysisRequest { + contents: vec![ + "This sentence does not have a detection.".to_string(), + "But .".to_string(), + ], + detector_params: DetectorParams::new(), + }), + MockResponse::json(vec![ + vec![], + vec![ContentAnalysisResponse { + start: 4, + end: 18, + text: "this one does".to_string(), + detection: "has_angle_brackets".to_string(), + detection_type: "angle_brackets".to_string(), + detector_id: Some(detector_name.to_string()), + score: 1.0, + evidence: None, + }], + ]), + ), + ); + + // Start detector mock server + let mock_detector_server = HttpMockServer::new(detector_name, mocks).unwrap(); + let _ = mock_detector_server.start().await; + + let mut config = OrchestratorConfig::load(CONFIG_FILE_PATH).await.unwrap(); + + // assign mock server port to detector config + config + .detectors + .get_mut(detector_name) + .unwrap() + .service + .port = Some(mock_detector_server.addr().port()); + + // assign mock server port to chunker config + config + .chunkers + .as_mut() + .unwrap() + .get_mut(chunker_id) + .unwrap() + .service + .port = Some(mock_chunker_server.addr().port()); + + let orchestrator = Orchestrator::new(config, false).await.unwrap(); + let shared_state = Arc::new(ServerState::new(orchestrator)); + let server = TestServer::new(get_app(shared_state)).unwrap(); + + // Make orchestrator call + let response = server + .post(ORCHESTRATOR_DETECTION_CONTENT_ENDPOINT) + 
.json(&TextContentDetectionHttpRequest { + content: "This sentence does not have a detection. But .".to_string(), + detectors: HashMap::from([(detector_name.to_string(), DetectorParams::new())]), + }) + .await; + + debug!(?response); + + // assertions + response.assert_status(StatusCode::OK); + response.assert_json(&TextContentDetectionResult { + detections: vec![ContentAnalysisResponse { + start: 45, + end: 59, + text: "this one does".to_string(), + detection: "has_angle_brackets".to_string(), + detection_type: "angle_brackets".to_string(), + detector_id: Some(detector_name.to_string()), + score: 1.0, + evidence: None, + }], + }); +} diff --git a/tests/test.config.yaml b/tests/test.config.yaml index cb176e64..e6149d9f 100644 --- a/tests/test.config.yaml +++ b/tests/test.config.yaml @@ -9,6 +9,10 @@ chunkers: service: hostname: localhost port: 8085 + sentence_chunker: + service: + hostname: localhost + type: sentence detectors: test_detector: type: text_contents @@ -17,7 +21,13 @@ detectors: port: 8000 chunker_id: test_chunker default_threshold: 0.5 - content_detector_whole_doc: + angle_brackets_detector_sentence: + type: text_contents + service: + hostname: localhost + chunker_id: sentence_chunker + default_threshold: 0.5 + angle_brackets_detector_whole_doc: type: text_contents service: hostname: localhost From 01aa131277424bdac36ad8211bc4af4af19d323e Mon Sep 17 00:00:00 2001 From: Mateus Devino Date: Thu, 6 Feb 2025 14:09:08 -0300 Subject: [PATCH 010/117] Move specific code back to canary_test.rs Signed-off-by: Mateus Devino --- tests/canary_test.rs | 18 ++++++++++++++++-- tests/common/mod.rs | 18 ------------------ 2 files changed, 16 insertions(+), 20 deletions(-) diff --git a/tests/canary_test.rs b/tests/canary_test.rs index 19c8a129..ff34023f 100644 --- a/tests/canary_test.rs +++ b/tests/canary_test.rs @@ -1,13 +1,27 @@ use axum_test::TestServer; -use common::{ensure_global_rustls_state, shared_state, ONCE}; -use fms_guardrails_orchestr8::server::get_health_app; +use common::{ensure_global_rustls_state, CONFIG_FILE_PATH}; +use fms_guardrails_orchestr8::server::{get_health_app, ServerState}; use hyper::StatusCode; use serde_json::Value; +use std::sync::Arc; use tracing::debug; use tracing_test::traced_test; +use fms_guardrails_orchestr8::{config::OrchestratorConfig, orchestrator::Orchestrator}; +use tokio::sync::OnceCell; + mod common; +/// Async lazy initialization of shared state using tokio::sync::OnceCell +pub static ONCE: OnceCell> = OnceCell::const_new(); + +/// The actual async function that initializes the shared state if not already initialized +pub async fn shared_state() -> Arc { + let config = OrchestratorConfig::load(CONFIG_FILE_PATH).await.unwrap(); + let orchestrator = Orchestrator::new(config, false).await.unwrap(); + Arc::new(ServerState::new(orchestrator)) +} + /// Checks if the health endpoint is working /// NOTE: We do not currently mock client services yet, so this test is /// superficially testing the client health endpoints on the orchestrator is accessible diff --git a/tests/common/mod.rs b/tests/common/mod.rs index 2483870b..f4947db5 100644 --- a/tests/common/mod.rs +++ b/tests/common/mod.rs @@ -1,25 +1,7 @@ -use std::sync::Arc; - -use fms_guardrails_orchestr8::{ - config::OrchestratorConfig, orchestrator::Orchestrator, server::ServerState, -}; use rustls::crypto::ring; -use tokio::sync::OnceCell; - -/// Async lazy initialization of shared state using tokio::sync::OnceCell -pub static ONCE: OnceCell> = OnceCell::const_new(); pub const CONFIG_FILE_PATH: &str = 
"tests/test.config.yaml"; -/// The actual async function that initializes the shared state if not already initialized -pub async fn shared_state() -> Arc { - let config = OrchestratorConfig::load("tests/test.config.yaml") - .await - .unwrap(); - let orchestrator = Orchestrator::new(config, false).await.unwrap(); - Arc::new(ServerState::new(orchestrator)) -} - pub fn ensure_global_rustls_state() { let _ = ring::default_provider().install_default(); } From 6af1ac4fd912fa02e744aef609b228007dfb91ee Mon Sep 17 00:00:00 2001 From: Mateus Devino Date: Thu, 6 Feb 2025 14:22:05 -0300 Subject: [PATCH 011/117] refactor: extract constants Signed-off-by: Mateus Devino --- tests/detection_content.rs | 23 +++++++++++++++-------- 1 file changed, 15 insertions(+), 8 deletions(-) diff --git a/tests/detection_content.rs b/tests/detection_content.rs index 243ca74c..e79b84cc 100644 --- a/tests/detection_content.rs +++ b/tests/detection_content.rs @@ -46,7 +46,14 @@ generate_grpc_server!( MockChunkersServiceServer ); -const ORCHESTRATOR_DETECTION_CONTENT_ENDPOINT: &str = "/api/v2/text/detection/content"; +// Constants +const ENDPOINT_ORCHESTRATOR: &str = "/api/v2/text/detection/content"; +const ENDPOINT_DETECTOR: &str = "/api/v1/text/contents"; + +const DETECTOR_NAME_ANGLE_BRACKETS_WHOLE_DOC: &str = "angle_brackets_detector_whole_doc"; +const DETECTOR_NAME_ANGLE_BRACKETS_SENTENCE: &str = "angle_brackets_detector_sentence"; + +const CHUNKER_NAME_SENTENCE: &str = "sentence_chunker"; /// Asserts a scenario with a single detection works as expected (assumes a detector configured with whole_doc_chunker). /// @@ -55,12 +62,12 @@ const ORCHESTRATOR_DETECTION_CONTENT_ENDPOINT: &str = "/api/v2/text/detection/co #[tokio::test] async fn test_single_detection_whole_doc() { ensure_global_rustls_state(); - let detector_name = "angle_brackets_detector_whole_doc"; + let detector_name = DETECTOR_NAME_ANGLE_BRACKETS_WHOLE_DOC; // Add detector mock let mut mocks = MockSet::new(); mocks.insert( - MockPath::new(Method::POST, "/api/v1/text/contents"), + MockPath::new(Method::POST, ENDPOINT_DETECTOR), Mock::new( MockRequest::json(ContentAnalysisRequest { contents: vec!["This sentence has .".to_string()], @@ -99,7 +106,7 @@ async fn test_single_detection_whole_doc() { // Make orchestrator call let response = server - .post(ORCHESTRATOR_DETECTION_CONTENT_ENDPOINT) + .post(ENDPOINT_ORCHESTRATOR) .json(&TextContentDetectionHttpRequest { content: "This sentence has .".to_string(), detectors: HashMap::from([(detector_name.to_string(), DetectorParams::new())]), @@ -133,7 +140,7 @@ async fn test_single_detection_sentence_chunker() { ensure_global_rustls_state(); // Add chunker mock - let chunker_id = "sentence_chunker"; + let chunker_id = CHUNKER_NAME_SENTENCE; let mut chunker_headers = HeaderMap::new(); chunker_headers.insert(CHUNKER_MODEL_ID_HEADER_NAME, chunker_id.parse().unwrap()); @@ -171,10 +178,10 @@ async fn test_single_detection_sentence_chunker() { let _ = mock_chunker_server.start().await; // Add detector mock - let detector_name = "angle_brackets_detector_sentence"; + let detector_name = DETECTOR_NAME_ANGLE_BRACKETS_SENTENCE; let mut mocks = MockSet::new(); mocks.insert( - MockPath::new(Method::POST, "/api/v1/text/contents"), + MockPath::new(Method::POST, ENDPOINT_DETECTOR), Mock::new( MockRequest::json(ContentAnalysisRequest { contents: vec![ @@ -229,7 +236,7 @@ async fn test_single_detection_sentence_chunker() { // Make orchestrator call let response = server - .post(ORCHESTRATOR_DETECTION_CONTENT_ENDPOINT) + 
.post(ENDPOINT_ORCHESTRATOR) .json(&TextContentDetectionHttpRequest { content: "This sentence does not have a detection. But .".to_string(), detectors: HashMap::from([(detector_name.to_string(), DetectorParams::new())]), From 1cd39d0d89074c83d9f25e9a411500c63016b81e Mon Sep 17 00:00:00 2001 From: Mateus Devino Date: Thu, 6 Feb 2025 14:29:24 -0300 Subject: [PATCH 012/117] refactor: move grpc server macros to tests common module Signed-off-by: Mateus Devino --- tests/chunker.rs | 9 ++------- tests/common/mod.rs | 7 +++++++ tests/detection_content.rs | 8 +------- 3 files changed, 10 insertions(+), 14 deletions(-) diff --git a/tests/chunker.rs b/tests/chunker.rs index 6758ffae..5f679800 100644 --- a/tests/chunker.rs +++ b/tests/chunker.rs @@ -15,6 +15,7 @@ */ +use common::MockChunkersServiceServer; use fms_guardrails_orchestr8::{ clients::chunker::{ChunkerClient, MODEL_ID_HEADER_NAME}, config::ServiceConfig, @@ -23,20 +24,14 @@ use fms_guardrails_orchestr8::{ caikit_data_model::nlp::{Token, TokenizationResults}, }, }; -use mocktail::mock::MockSet; use mocktail::prelude::*; use tracing_test::traced_test; mod common; -generate_grpc_server!( - "caikit.runtime.Chunkers.ChunkersService", - MockChunkersServiceServer -); - #[traced_test] #[tokio::test] -async fn test_isolated_chunker_call() -> Result<(), anyhow::Error> { +async fn test_isolated_chunker_unary_call() -> Result<(), anyhow::Error> { // Add detector mock let chunker_id = "sentence_chunker"; let mut chunker_headers = HeaderMap::new(); diff --git a/tests/common/mod.rs b/tests/common/mod.rs index f4947db5..492691dd 100644 --- a/tests/common/mod.rs +++ b/tests/common/mod.rs @@ -1,5 +1,12 @@ +use mocktail::generate_grpc_server; +use mocktail::mock::MockSet; use rustls::crypto::ring; +generate_grpc_server!( + "caikit.runtime.Chunkers.ChunkersService", + MockChunkersServiceServer +); + pub const CONFIG_FILE_PATH: &str = "tests/test.config.yaml"; pub fn ensure_global_rustls_state() { diff --git a/tests/detection_content.rs b/tests/detection_content.rs index e79b84cc..3418f2f9 100644 --- a/tests/detection_content.rs +++ b/tests/detection_content.rs @@ -18,7 +18,7 @@ use std::{collections::HashMap, sync::Arc}; use axum_test::TestServer; -use common::{ensure_global_rustls_state, CONFIG_FILE_PATH}; +use common::{ensure_global_rustls_state, MockChunkersServiceServer, CONFIG_FILE_PATH}; use fms_guardrails_orchestr8::{ clients::{ chunker::MODEL_ID_HEADER_NAME as CHUNKER_MODEL_ID_HEADER_NAME, @@ -34,18 +34,12 @@ use fms_guardrails_orchestr8::{ server::{get_app, ServerState}, }; use hyper::StatusCode; -use mocktail::mock::MockSet; use mocktail::prelude::*; use tracing::debug; use tracing_test::traced_test; mod common; -generate_grpc_server!( - "caikit.runtime.Chunkers.ChunkersService", - MockChunkersServiceServer -); - // Constants const ENDPOINT_ORCHESTRATOR: &str = "/api/v2/text/detection/content"; const ENDPOINT_DETECTOR: &str = "/api/v1/text/contents"; From 33839b1786d941e42c8b5895b50f80525c3bfc6d Mon Sep 17 00:00:00 2001 From: Mateus Devino Date: Thu, 6 Feb 2025 14:33:41 -0300 Subject: [PATCH 013/117] Add copyright notice to common test module Signed-off-by: Mateus Devino --- tests/common/mod.rs | 17 +++++++++++++++++ 1 file changed, 17 insertions(+) diff --git a/tests/common/mod.rs b/tests/common/mod.rs index 492691dd..b83c0a83 100644 --- a/tests/common/mod.rs +++ b/tests/common/mod.rs @@ -1,3 +1,20 @@ +/* + Copyright FMS Guardrails Orchestrator Authors + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file 
except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + +*/ + use mocktail::generate_grpc_server; use mocktail::mock::MockSet; use rustls::crypto::ring; From 6eaf9feeccfe01f92463c6ec1a8392e352bc7a2b Mon Sep 17 00:00:00 2001 From: Mateus Devino Date: Thu, 6 Feb 2025 14:54:12 -0300 Subject: [PATCH 014/117] refactor: create function for orchestrator config Signed-off-by: Mateus Devino --- tests/common/mod.rs | 28 ++++++++++++++++++++++++++++ tests/detection_content.rs | 20 +++++--------------- 2 files changed, 33 insertions(+), 15 deletions(-) diff --git a/tests/common/mod.rs b/tests/common/mod.rs index b83c0a83..9eb4749e 100644 --- a/tests/common/mod.rs +++ b/tests/common/mod.rs @@ -15,8 +15,14 @@ */ +use std::sync::Arc; + +use fms_guardrails_orchestr8::config::OrchestratorConfig; +use fms_guardrails_orchestr8::orchestrator::Orchestrator; +use fms_guardrails_orchestr8::server::ServerState; use mocktail::generate_grpc_server; use mocktail::mock::MockSet; +use mocktail::server::HttpMockServer; use rustls::crypto::ring; generate_grpc_server!( @@ -24,8 +30,30 @@ generate_grpc_server!( MockChunkersServiceServer ); +/// Default orchestrator configuration file for integration tests. pub const CONFIG_FILE_PATH: &str = "tests/test.config.yaml"; +/// pub fn ensure_global_rustls_state() { let _ = ring::default_provider().install_default(); } + +/// Creates an orchestrator shared state based off from the default test configuration file and given detector mocks. 
+pub async fn create_orchestrator_shared_state(detectors: Vec) -> Arc { + let mut config = OrchestratorConfig::load(CONFIG_FILE_PATH).await.unwrap(); + + for detector_mock_server in detectors { + let _ = detector_mock_server.start().await; + + // assign mock server port to detector config + config + .detectors + .get_mut(detector_mock_server.name()) + .unwrap() + .service + .port = Some(detector_mock_server.addr().port()); + } + + let orchestrator = Orchestrator::new(config, false).await.unwrap(); + Arc::new(ServerState::new(orchestrator)) +} diff --git a/tests/detection_content.rs b/tests/detection_content.rs index 3418f2f9..dad212f2 100644 --- a/tests/detection_content.rs +++ b/tests/detection_content.rs @@ -18,7 +18,10 @@ use std::{collections::HashMap, sync::Arc}; use axum_test::TestServer; -use common::{ensure_global_rustls_state, MockChunkersServiceServer, CONFIG_FILE_PATH}; +use common::{ + create_orchestrator_shared_state, ensure_global_rustls_state, MockChunkersServiceServer, + CONFIG_FILE_PATH, +}; use fms_guardrails_orchestr8::{ clients::{ chunker::MODEL_ID_HEADER_NAME as CHUNKER_MODEL_ID_HEADER_NAME, @@ -80,22 +83,9 @@ async fn test_single_detection_whole_doc() { ), ); - // Start mock server let mock_detector_server = HttpMockServer::new(detector_name, mocks).unwrap(); - let _ = mock_detector_server.start().await; - - let mut config = OrchestratorConfig::load(CONFIG_FILE_PATH).await.unwrap(); - - // assign mock server port to detector config - config - .detectors - .get_mut(detector_name) - .unwrap() - .service - .port = Some(mock_detector_server.addr().port()); - let orchestrator = Orchestrator::new(config, false).await.unwrap(); - let shared_state = Arc::new(ServerState::new(orchestrator)); + let shared_state = create_orchestrator_shared_state(vec![mock_detector_server]).await; let server = TestServer::new(get_app(shared_state)).unwrap(); // Make orchestrator call From 17435b0bb5d1efe979ca74b7be2151546894df70 Mon Sep 17 00:00:00 2001 From: Mateus Devino Date: Thu, 6 Feb 2025 15:29:03 -0300 Subject: [PATCH 015/117] Add chunkers logic to create_orchestrator_shared_state() Signed-off-by: Mateus Devino --- tests/common/mod.rs | 21 +++++++++- tests/detection_content.rs | 81 +++++++++++++------------------------- 2 files changed, 47 insertions(+), 55 deletions(-) diff --git a/tests/common/mod.rs b/tests/common/mod.rs index 9eb4749e..5dccb313 100644 --- a/tests/common/mod.rs +++ b/tests/common/mod.rs @@ -38,8 +38,11 @@ pub fn ensure_global_rustls_state() { let _ = ring::default_provider().install_default(); } -/// Creates an orchestrator shared state based off from the default test configuration file and given detector mocks. -pub async fn create_orchestrator_shared_state(detectors: Vec) -> Arc { +/// Starts mock servers and adds them to orchestrator configuration. 
+pub async fn create_orchestrator_shared_state( + detectors: Vec, + chunkers: Vec<(&str, MockChunkersServiceServer)>, +) -> Arc { let mut config = OrchestratorConfig::load(CONFIG_FILE_PATH).await.unwrap(); for detector_mock_server in detectors { @@ -54,6 +57,20 @@ pub async fn create_orchestrator_shared_state(detectors: Vec) -> .port = Some(detector_mock_server.addr().port()); } + for (chunker_name, chunker_mock_server) in chunkers { + let _ = chunker_mock_server.start().await; + + // assign mock server port to chunker config + config + .chunkers + .as_mut() + .unwrap() + .get_mut(chunker_name) + .unwrap() + .service + .port = Some(chunker_mock_server.addr().port()); + } + let orchestrator = Orchestrator::new(config, false).await.unwrap(); Arc::new(ServerState::new(orchestrator)) } diff --git a/tests/detection_content.rs b/tests/detection_content.rs index dad212f2..3acd6003 100644 --- a/tests/detection_content.rs +++ b/tests/detection_content.rs @@ -15,26 +15,23 @@ */ -use std::{collections::HashMap, sync::Arc}; +use std::collections::HashMap; use axum_test::TestServer; use common::{ create_orchestrator_shared_state, ensure_global_rustls_state, MockChunkersServiceServer, - CONFIG_FILE_PATH, }; use fms_guardrails_orchestr8::{ clients::{ chunker::MODEL_ID_HEADER_NAME as CHUNKER_MODEL_ID_HEADER_NAME, detector::{ContentAnalysisRequest, ContentAnalysisResponse}, }, - config::OrchestratorConfig, models::{DetectorParams, TextContentDetectionHttpRequest, TextContentDetectionResult}, - orchestrator::Orchestrator, pb::{ caikit::runtime::chunkers::ChunkerTokenizationTaskRequest, caikit_data_model::nlp::{Token, TokenizationResults}, }, - server::{get_app, ServerState}, + server::get_app, }; use hyper::StatusCode; use mocktail::prelude::*; @@ -83,9 +80,9 @@ async fn test_single_detection_whole_doc() { ), ); + // Setup orchestrator and detector servers let mock_detector_server = HttpMockServer::new(detector_name, mocks).unwrap(); - - let shared_state = create_orchestrator_shared_state(vec![mock_detector_server]).await; + let shared_state = create_orchestrator_shared_state(vec![mock_detector_server], vec![]).await; let server = TestServer::new(get_app(shared_state)).unwrap(); // Make orchestrator call @@ -128,23 +125,8 @@ async fn test_single_detection_sentence_chunker() { let mut chunker_headers = HeaderMap::new(); chunker_headers.insert(CHUNKER_MODEL_ID_HEADER_NAME, chunker_id.parse().unwrap()); - let expected_chunker_response = TokenizationResults { - results: vec![ - Token { - start: 0, - end: 40, - text: "This sentence does not have a detection.".to_string(), - }, - Token { - start: 41, - end: 61, - text: "But .".to_string(), - }, - ], - token_count: 0, - }; - let mut mocks = MockSet::new(); - mocks.insert( + let mut chunker_mocks = MockSet::new(); + chunker_mocks.insert( MockPath::new( Method::POST, "/caikit.runtime.Chunkers.ChunkersService/ChunkerTokenizationTaskPredict", @@ -154,13 +136,24 @@ async fn test_single_detection_sentence_chunker() { text: "This sentence does not have a detection. 
But .".to_string(), }) .with_headers(chunker_headers), - MockResponse::pb(expected_chunker_response.clone()), + MockResponse::pb(TokenizationResults { + results: vec![ + Token { + start: 0, + end: 40, + text: "This sentence does not have a detection.".to_string(), + }, + Token { + start: 41, + end: 61, + text: "But .".to_string(), + }, + ], + token_count: 0, + }), ), ); - let mock_chunker_server = MockChunkersServiceServer::new(mocks).unwrap(); - let _ = mock_chunker_server.start().await; - // Add detector mock let detector_name = DETECTOR_NAME_ANGLE_BRACKETS_SENTENCE; let mut mocks = MockSet::new(); @@ -190,32 +183,14 @@ async fn test_single_detection_sentence_chunker() { ), ); - // Start detector mock server + // Start orchestrator, chunker and detector servers. + let mock_chunker_server = MockChunkersServiceServer::new(chunker_mocks).unwrap(); let mock_detector_server = HttpMockServer::new(detector_name, mocks).unwrap(); - let _ = mock_detector_server.start().await; - - let mut config = OrchestratorConfig::load(CONFIG_FILE_PATH).await.unwrap(); - - // assign mock server port to detector config - config - .detectors - .get_mut(detector_name) - .unwrap() - .service - .port = Some(mock_detector_server.addr().port()); - - // assign mock server port to chunker config - config - .chunkers - .as_mut() - .unwrap() - .get_mut(chunker_id) - .unwrap() - .service - .port = Some(mock_chunker_server.addr().port()); - - let orchestrator = Orchestrator::new(config, false).await.unwrap(); - let shared_state = Arc::new(ServerState::new(orchestrator)); + let shared_state = create_orchestrator_shared_state( + vec![mock_detector_server], + vec![(chunker_id, mock_chunker_server)], + ) + .await; let server = TestServer::new(get_app(shared_state)).unwrap(); // Make orchestrator call From b3706ac3b60eead9239987dcbb0a81ac2f6f6b9a Mon Sep 17 00:00:00 2001 From: Mateus Devino Date: Fri, 7 Feb 2025 10:35:24 -0300 Subject: [PATCH 016/117] Change mocktail source to IBM git repo Signed-off-by: Mateus Devino --- Cargo.lock | 1 + Cargo.toml | 2 +- 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/Cargo.lock b/Cargo.lock index 3c23bb71..f0f2428a 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1560,6 +1560,7 @@ dependencies = [ [[package]] name = "mocktail" version = "0.1.0-alpha" +source = "git+https://github.com/IBM/mocktail#f635426acdbfa42e58067319d76e489bec64f825" dependencies = [ "bytes", "futures", diff --git a/Cargo.toml b/Cargo.toml index 7b8df9f6..c57a8899 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -63,7 +63,7 @@ tonic-build = "0.12.3" [dev-dependencies] axum-test = "17.1.0" faux = "0.1.12" -mocktail = { path = "../mocktail/mocktail" } +mocktail = { git = "https://github.com/IBM/mocktail", version = "0.1.0-alpha" } tracing-test = "0.2.5" [profile.release] From f6e540879e60d2105a64833910c8f86179ad146b Mon Sep 17 00:00:00 2001 From: Mateus Devino Date: Fri, 7 Feb 2025 18:45:56 -0300 Subject: [PATCH 017/117] Refactor chunker unary endpoint as constant Signed-off-by: Mateus Devino --- tests/chunker.rs | 7 ++----- tests/common/mod.rs | 3 +++ tests/detection_content.rs | 6 ++---- 3 files changed, 7 insertions(+), 9 deletions(-) diff --git a/tests/chunker.rs b/tests/chunker.rs index 5f679800..9ec9d38f 100644 --- a/tests/chunker.rs +++ b/tests/chunker.rs @@ -15,7 +15,7 @@ */ -use common::MockChunkersServiceServer; +use common::{MockChunkersServiceServer, CHUNKER_UNARY_ENDPOINT}; use fms_guardrails_orchestr8::{ clients::chunker::{ChunkerClient, MODEL_ID_HEADER_NAME}, config::ServiceConfig, @@ -60,10 +60,7 @@ 
async fn test_isolated_chunker_unary_call() -> Result<(), anyhow::Error> { let mut mocks = MockSet::new(); mocks.insert( - MockPath::new( - Method::POST, - "/caikit.runtime.Chunkers.ChunkersService/ChunkerTokenizationTaskPredict", - ), + MockPath::new(Method::POST, CHUNKER_UNARY_ENDPOINT), Mock::new( MockRequest::pb(ChunkerTokenizationTaskRequest { text: "Hi there! how are you? I am great!".to_string(), diff --git a/tests/common/mod.rs b/tests/common/mod.rs index 5dccb313..40989869 100644 --- a/tests/common/mod.rs +++ b/tests/common/mod.rs @@ -30,6 +30,9 @@ generate_grpc_server!( MockChunkersServiceServer ); +pub const CHUNKER_UNARY_ENDPOINT: &str = + "/caikit.runtime.Chunkers.ChunkersService/ChunkerTokenizationTaskPredict"; + /// Default orchestrator configuration file for integration tests. pub const CONFIG_FILE_PATH: &str = "tests/test.config.yaml"; diff --git a/tests/detection_content.rs b/tests/detection_content.rs index 3acd6003..63345756 100644 --- a/tests/detection_content.rs +++ b/tests/detection_content.rs @@ -20,6 +20,7 @@ use std::collections::HashMap; use axum_test::TestServer; use common::{ create_orchestrator_shared_state, ensure_global_rustls_state, MockChunkersServiceServer, + CHUNKER_UNARY_ENDPOINT, }; use fms_guardrails_orchestr8::{ clients::{ @@ -127,10 +128,7 @@ async fn test_single_detection_sentence_chunker() { let mut chunker_mocks = MockSet::new(); chunker_mocks.insert( - MockPath::new( - Method::POST, - "/caikit.runtime.Chunkers.ChunkersService/ChunkerTokenizationTaskPredict", - ), + MockPath::new(Method::POST, CHUNKER_UNARY_ENDPOINT), Mock::new( MockRequest::pb(ChunkerTokenizationTaskRequest { text: "This sentence does not have a detection. But .".to_string(), From 14325e81e4e200902cb781c48c26f2cd16d64590 Mon Sep 17 00:00:00 2001 From: Mateus Devino Date: Mon, 10 Feb 2025 09:57:37 -0300 Subject: [PATCH 018/117] Update create_orchestrator_shared_state() to return Result Signed-off-by: Mateus Devino --- tests/common/mod.rs | 8 ++++---- tests/detection_content.rs | 7 +++++-- 2 files changed, 9 insertions(+), 6 deletions(-) diff --git a/tests/common/mod.rs b/tests/common/mod.rs index 40989869..59754e50 100644 --- a/tests/common/mod.rs +++ b/tests/common/mod.rs @@ -45,11 +45,11 @@ pub fn ensure_global_rustls_state() { pub async fn create_orchestrator_shared_state( detectors: Vec, chunkers: Vec<(&str, MockChunkersServiceServer)>, -) -> Arc { +) -> Result, mocktail::Error> { let mut config = OrchestratorConfig::load(CONFIG_FILE_PATH).await.unwrap(); for detector_mock_server in detectors { - let _ = detector_mock_server.start().await; + let _ = detector_mock_server.start().await?; // assign mock server port to detector config config @@ -61,7 +61,7 @@ pub async fn create_orchestrator_shared_state( } for (chunker_name, chunker_mock_server) in chunkers { - let _ = chunker_mock_server.start().await; + let _ = chunker_mock_server.start().await?; // assign mock server port to chunker config config @@ -75,5 +75,5 @@ pub async fn create_orchestrator_shared_state( } let orchestrator = Orchestrator::new(config, false).await.unwrap(); - Arc::new(ServerState::new(orchestrator)) + Ok(Arc::new(ServerState::new(orchestrator))) } diff --git a/tests/detection_content.rs b/tests/detection_content.rs index 63345756..bd1e6e9f 100644 --- a/tests/detection_content.rs +++ b/tests/detection_content.rs @@ -83,7 +83,9 @@ async fn test_single_detection_whole_doc() { // Setup orchestrator and detector servers let mock_detector_server = HttpMockServer::new(detector_name, mocks).unwrap(); - let 
shared_state = create_orchestrator_shared_state(vec![mock_detector_server], vec![]).await; + let shared_state = create_orchestrator_shared_state(vec![mock_detector_server], vec![]) + .await + .unwrap(); let server = TestServer::new(get_app(shared_state)).unwrap(); // Make orchestrator call @@ -188,7 +190,8 @@ async fn test_single_detection_sentence_chunker() { vec![mock_detector_server], vec![(chunker_id, mock_chunker_server)], ) - .await; + .await + .unwrap(); let server = TestServer::new(get_app(shared_state)).unwrap(); // Make orchestrator call From 6ab815b873b2d39b1cee02e316e4dc14ee223def Mon Sep 17 00:00:00 2001 From: Mateus Devino Date: Mon, 10 Feb 2025 17:12:57 -0300 Subject: [PATCH 019/117] Remove unwrap() calls Signed-off-by: Mateus Devino --- tests/chunker.rs | 4 ++-- tests/detection_content.rs | 27 ++++++++++++++------------- 2 files changed, 16 insertions(+), 15 deletions(-) diff --git a/tests/chunker.rs b/tests/chunker.rs index 9ec9d38f..bca60d98 100644 --- a/tests/chunker.rs +++ b/tests/chunker.rs @@ -70,7 +70,7 @@ async fn test_isolated_chunker_unary_call() -> Result<(), anyhow::Error> { ), ); - let mock_chunker_server = MockChunkersServiceServer::new(mocks).unwrap(); + let mock_chunker_server = MockChunkersServiceServer::new(mocks)?; let _ = mock_chunker_server.start().await; let client = ChunkerClient::new(&ServiceConfig { @@ -93,7 +93,7 @@ async fn test_isolated_chunker_unary_call() -> Result<(), anyhow::Error> { dbg!(&response); assert!(response.is_ok()); - assert!(response.unwrap() == expected_response); + assert!(response? == expected_response); Ok(()) } diff --git a/tests/detection_content.rs b/tests/detection_content.rs index bd1e6e9f..4afd9ecf 100644 --- a/tests/detection_content.rs +++ b/tests/detection_content.rs @@ -55,7 +55,7 @@ const CHUNKER_NAME_SENTENCE: &str = "sentence_chunker"; /// This test mocks a detector that detects text between . #[traced_test] #[tokio::test] -async fn test_single_detection_whole_doc() { +async fn test_single_detection_whole_doc() -> Result<(), anyhow::Error> { ensure_global_rustls_state(); let detector_name = DETECTOR_NAME_ANGLE_BRACKETS_WHOLE_DOC; @@ -82,11 +82,9 @@ async fn test_single_detection_whole_doc() { ); // Setup orchestrator and detector servers - let mock_detector_server = HttpMockServer::new(detector_name, mocks).unwrap(); - let shared_state = create_orchestrator_shared_state(vec![mock_detector_server], vec![]) - .await - .unwrap(); - let server = TestServer::new(get_app(shared_state)).unwrap(); + let mock_detector_server = HttpMockServer::new(detector_name, mocks)?; + let shared_state = create_orchestrator_shared_state(vec![mock_detector_server], vec![]).await?; + let server = TestServer::new(get_app(shared_state))?; // Make orchestrator call let response = server @@ -113,6 +111,8 @@ async fn test_single_detection_whole_doc() { evidence: None, }], }); + + Ok(()) } /// Asserts a scenario with a single detection works as expected (with sentence chunker). @@ -120,13 +120,13 @@ async fn test_single_detection_whole_doc() { /// This test mocks a detector that detects text between . 
#[traced_test] #[tokio::test] -async fn test_single_detection_sentence_chunker() { +async fn test_single_detection_sentence_chunker() -> Result<(), anyhow::Error> { ensure_global_rustls_state(); // Add chunker mock let chunker_id = CHUNKER_NAME_SENTENCE; let mut chunker_headers = HeaderMap::new(); - chunker_headers.insert(CHUNKER_MODEL_ID_HEADER_NAME, chunker_id.parse().unwrap()); + chunker_headers.insert(CHUNKER_MODEL_ID_HEADER_NAME, chunker_id.parse()?); let mut chunker_mocks = MockSet::new(); chunker_mocks.insert( @@ -184,15 +184,14 @@ async fn test_single_detection_sentence_chunker() { ); // Start orchestrator, chunker and detector servers. - let mock_chunker_server = MockChunkersServiceServer::new(chunker_mocks).unwrap(); - let mock_detector_server = HttpMockServer::new(detector_name, mocks).unwrap(); + let mock_chunker_server = MockChunkersServiceServer::new(chunker_mocks)?; + let mock_detector_server = HttpMockServer::new(detector_name, mocks)?; let shared_state = create_orchestrator_shared_state( vec![mock_detector_server], vec![(chunker_id, mock_chunker_server)], ) - .await - .unwrap(); - let server = TestServer::new(get_app(shared_state)).unwrap(); + .await?; + let server = TestServer::new(get_app(shared_state))?; // Make orchestrator call let response = server @@ -219,4 +218,6 @@ async fn test_single_detection_sentence_chunker() { evidence: None, }], }); + + Ok(()) } From b51eb12df30a74b86d01cac0943e997f93b9a12a Mon Sep 17 00:00:00 2001 From: Mateus Devino Date: Mon, 10 Feb 2025 17:37:49 -0300 Subject: [PATCH 020/117] Allow dead code on test common file Signed-off-by: Mateus Devino --- tests/common/mod.rs | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/tests/common/mod.rs b/tests/common/mod.rs index 59754e50..1b571957 100644 --- a/tests/common/mod.rs +++ b/tests/common/mod.rs @@ -15,6 +15,11 @@ */ +// This is needed because integration test files are compiled as separate crates. +// If any of the code in this file is not used by any of the test files, a warning about unused code is generated. 
+// For more: https://github.com/rust-lang/rust/issues/46379 +#![allow(dead_code)] + use std::sync::Arc; use fms_guardrails_orchestr8::config::OrchestratorConfig; From 4cd10d6a2a29896fd55ad6eec90fb8cc6dcc68de Mon Sep 17 00:00:00 2001 From: Mateus Devino Date: Mon, 10 Feb 2025 17:45:34 -0300 Subject: [PATCH 021/117] Remove unneeded changes in canary_test.rs Signed-off-by: Mateus Devino --- tests/canary_test.rs | 23 ++++++++++++++--------- 1 file changed, 14 insertions(+), 9 deletions(-) diff --git a/tests/canary_test.rs b/tests/canary_test.rs index ff34023f..224a759d 100644 --- a/tests/canary_test.rs +++ b/tests/canary_test.rs @@ -1,23 +1,28 @@ +use std::sync::Arc; + use axum_test::TestServer; -use common::{ensure_global_rustls_state, CONFIG_FILE_PATH}; -use fms_guardrails_orchestr8::server::{get_health_app, ServerState}; +use common::ensure_global_rustls_state; +use fms_guardrails_orchestr8::{ + config::OrchestratorConfig, + orchestrator::Orchestrator, + server::{get_health_app, ServerState}, +}; use hyper::StatusCode; use serde_json::Value; -use std::sync::Arc; +use tokio::sync::OnceCell; use tracing::debug; use tracing_test::traced_test; -use fms_guardrails_orchestr8::{config::OrchestratorConfig, orchestrator::Orchestrator}; -use tokio::sync::OnceCell; - mod common; /// Async lazy initialization of shared state using tokio::sync::OnceCell -pub static ONCE: OnceCell> = OnceCell::const_new(); +static ONCE: OnceCell> = OnceCell::const_new(); /// The actual async function that initializes the shared state if not already initialized -pub async fn shared_state() -> Arc { - let config = OrchestratorConfig::load(CONFIG_FILE_PATH).await.unwrap(); +async fn shared_state() -> Arc { + let config = OrchestratorConfig::load("tests/test.config.yaml") + .await + .unwrap(); let orchestrator = Orchestrator::new(config, false).await.unwrap(); Arc::new(ServerState::new(orchestrator)) } From 61b17575d4e342b0ebc4fffcc58905624f7c75d7 Mon Sep 17 00:00:00 2001 From: Mateus Devino Date: Mon, 10 Feb 2025 17:51:45 -0300 Subject: [PATCH 022/117] Add copyright text to canary_test.rs Signed-off-by: Mateus Devino --- tests/canary_test.rs | 21 +++++++++++++++++++++ tests/common/mod.rs | 1 + 2 files changed, 22 insertions(+) diff --git a/tests/canary_test.rs b/tests/canary_test.rs index 224a759d..72e8deda 100644 --- a/tests/canary_test.rs +++ b/tests/canary_test.rs @@ -1,3 +1,24 @@ +/* + Copyright FMS Guardrails Orchestrator Authors + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + +*/ + +// This is needed because integration test files are compiled as separate crates. +// If any of the code in this file is not used by any of the test files, a warning about unused code is generated. +// For more: https://github.com/rust-lang/rust/issues/46379 + use std::sync::Arc; use axum_test::TestServer; diff --git a/tests/common/mod.rs b/tests/common/mod.rs index 1b571957..6eec0a2d 100644 --- a/tests/common/mod.rs +++ b/tests/common/mod.rs @@ -18,6 +18,7 @@ // This is needed because integration test files are compiled as separate crates. 
// If any of the code in this file is not used by any of the test files, a warning about unused code is generated. // For more: https://github.com/rust-lang/rust/issues/46379 + #![allow(dead_code)] use std::sync::Arc; From 4c21b76d098ca5a50c5babbea260a1ba1aea6500 Mon Sep 17 00:00:00 2001 From: Mateus Devino Date: Tue, 11 Feb 2025 10:56:25 -0300 Subject: [PATCH 023/117] Remove allow macro from tests/common/mod.rs Signed-off-by: Mateus Devino --- tests/canary_test.rs | 2 +- tests/chunker.rs | 2 +- tests/common/mod.rs | 2 -- tests/detection_content.rs | 2 +- 4 files changed, 3 insertions(+), 5 deletions(-) diff --git a/tests/canary_test.rs b/tests/canary_test.rs index 72e8deda..ebd23933 100644 --- a/tests/canary_test.rs +++ b/tests/canary_test.rs @@ -34,7 +34,7 @@ use tokio::sync::OnceCell; use tracing::debug; use tracing_test::traced_test; -mod common; +pub mod common; /// Async lazy initialization of shared state using tokio::sync::OnceCell static ONCE: OnceCell> = OnceCell::const_new(); diff --git a/tests/chunker.rs b/tests/chunker.rs index bca60d98..fd9fb085 100644 --- a/tests/chunker.rs +++ b/tests/chunker.rs @@ -27,7 +27,7 @@ use fms_guardrails_orchestr8::{ use mocktail::prelude::*; use tracing_test::traced_test; -mod common; +pub mod common; #[traced_test] #[tokio::test] diff --git a/tests/common/mod.rs b/tests/common/mod.rs index 6eec0a2d..166eab60 100644 --- a/tests/common/mod.rs +++ b/tests/common/mod.rs @@ -19,8 +19,6 @@ // If any of the code in this file is not used by any of the test files, a warning about unused code is generated. // For more: https://github.com/rust-lang/rust/issues/46379 -#![allow(dead_code)] - use std::sync::Arc; use fms_guardrails_orchestr8::config::OrchestratorConfig; diff --git a/tests/detection_content.rs b/tests/detection_content.rs index 4afd9ecf..c3ca5421 100644 --- a/tests/detection_content.rs +++ b/tests/detection_content.rs @@ -39,7 +39,7 @@ use mocktail::prelude::*; use tracing::debug; use tracing_test::traced_test; -mod common; +pub mod common; // Constants const ENDPOINT_ORCHESTRATOR: &str = "/api/v2/text/detection/content"; From d6c917295b0813cd0e09a321585503179c762522 Mon Sep 17 00:00:00 2001 From: Mateus Devino Date: Tue, 11 Feb 2025 13:28:03 -0300 Subject: [PATCH 024/117] wip: test isolated NLP streaming call Signed-off-by: Mateus Devino --- src/clients/nlp.rs | 2 +- tests/common/mod.rs | 5 ++ tests/generation_nlp.rs | 101 ++++++++++++++++++++++++++++++++++++++++ 3 files changed, 107 insertions(+), 1 deletion(-) create mode 100644 tests/generation_nlp.rs diff --git a/src/clients/nlp.rs b/src/clients/nlp.rs index e2651851..6bb60096 100644 --- a/src/clients/nlp.rs +++ b/src/clients/nlp.rs @@ -44,7 +44,7 @@ use crate::{ }; const DEFAULT_PORT: u16 = 8085; -const MODEL_ID_HEADER_NAME: &str = "mm-model-id"; +pub const MODEL_ID_HEADER_NAME: &str = "mm-model-id"; #[cfg_attr(test, faux::create)] #[derive(Clone)] diff --git a/tests/common/mod.rs b/tests/common/mod.rs index 166eab60..f48e925f 100644 --- a/tests/common/mod.rs +++ b/tests/common/mod.rs @@ -34,9 +34,14 @@ generate_grpc_server!( MockChunkersServiceServer ); +generate_grpc_server!("caikit.runtime.Nlp.NlpService", MockNlpServiceServer); + pub const CHUNKER_UNARY_ENDPOINT: &str = "/caikit.runtime.Chunkers.ChunkersService/ChunkerTokenizationTaskPredict"; +pub const GENERATION_NLP_STREAMING_ENDPOINT: &str = + "/caikit.runtime.Nlp.NlpService/ServerStreamingTextGenerationTaskPredict"; + /// Default orchestrator configuration file for integration tests. 
pub const CONFIG_FILE_PATH: &str = "tests/test.config.yaml"; diff --git a/tests/generation_nlp.rs b/tests/generation_nlp.rs new file mode 100644 index 00000000..217c2744 --- /dev/null +++ b/tests/generation_nlp.rs @@ -0,0 +1,101 @@ +/* + Copyright FMS Guardrails Orchestrator Authors + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + +*/ + +use common::{MockNlpServiceServer, GENERATION_NLP_STREAMING_ENDPOINT}; +use fms_guardrails_orchestr8::{ + clients::{nlp::MODEL_ID_HEADER_NAME, NlpClient}, + config::ServiceConfig, + pb::{ + caikit::runtime::nlp::ServerStreamingTextGenerationTaskRequest, + caikit_data_model::nlp::GeneratedTextStreamResult, + }, +}; +use futures::StreamExt; +use mocktail::prelude::*; +use tracing_test::traced_test; + +pub mod common; + +#[traced_test] +#[tokio::test] +async fn test_nlp_streaming_call() -> Result<(), anyhow::Error> { + // Add detector mock + let model_id = "my-super-model-8B"; + let mut headers = HeaderMap::new(); + headers.insert(MODEL_ID_HEADER_NAME, model_id.parse().unwrap()); + + let expected_response = vec![ + GeneratedTextStreamResult { + generated_text: "I".to_string(), + ..Default::default() + }, + GeneratedTextStreamResult { + generated_text: " am".to_string(), + ..Default::default() + }, + GeneratedTextStreamResult { + generated_text: " great!".to_string(), + ..Default::default() + }, + ]; + + let mut mocks = MockSet::new(); + mocks.insert( + MockPath::new(Method::POST, GENERATION_NLP_STREAMING_ENDPOINT), + Mock::new( + MockRequest::pb(ServerStreamingTextGenerationTaskRequest { + text: "Hi there! how are you?".to_string(), + ..Default::default() + }) + .with_headers(headers.clone()), + MockResponse::pb_stream(expected_response.clone()), + ), + ); + + let generation_nlp_server = MockNlpServiceServer::new(mocks)?; + let _ = generation_nlp_server.start().await; + + let client = NlpClient::new(&ServiceConfig { + hostname: "localhost".to_string(), + port: Some(generation_nlp_server.addr().port()), + request_timeout: None, + tls: None, + }) + .await; + + let mut response = client + .server_streaming_text_generation_task_predict( + model_id, + ServerStreamingTextGenerationTaskRequest { + text: "Hi there! 
How are you?".to_string(), + ..Default::default() + }, + headers, + ) + .await?; + + // assert!(response.is_ok()); + // assert!(response == expected_response); + + // let stream = response.into_stream().into_inner(); + + while let Some(Ok(message)) = response.next().await { + tracing::debug!("recv: {message:?}"); + } + + Ok(()) +} From 769df321ceecaca963973ce5ea8bb0fd75ba471b Mon Sep 17 00:00:00 2001 From: Mateus Devino Date: Wed, 12 Feb 2025 11:45:04 -0300 Subject: [PATCH 025/117] Break common test module into multiple files Signed-off-by: Mateus Devino --- tests/canary_test.rs | 2 +- tests/chunker.rs | 2 +- tests/common/chunker.rs | 26 +++++++++++++ tests/common/generation.rs | 23 ++++++++++++ tests/common/mod.rs | 75 ++------------------------------------ tests/common/util.rs | 70 +++++++++++++++++++++++++++++++++++ tests/detection_content.rs | 4 +- 7 files changed, 126 insertions(+), 76 deletions(-) create mode 100644 tests/common/chunker.rs create mode 100644 tests/common/generation.rs create mode 100644 tests/common/util.rs diff --git a/tests/canary_test.rs b/tests/canary_test.rs index ebd23933..ed290192 100644 --- a/tests/canary_test.rs +++ b/tests/canary_test.rs @@ -22,7 +22,7 @@ use std::sync::Arc; use axum_test::TestServer; -use common::ensure_global_rustls_state; +use common::util::ensure_global_rustls_state; use fms_guardrails_orchestr8::{ config::OrchestratorConfig, orchestrator::Orchestrator, diff --git a/tests/chunker.rs b/tests/chunker.rs index fd9fb085..c0bc4a5d 100644 --- a/tests/chunker.rs +++ b/tests/chunker.rs @@ -15,7 +15,7 @@ */ -use common::{MockChunkersServiceServer, CHUNKER_UNARY_ENDPOINT}; +use common::chunker::{MockChunkersServiceServer, CHUNKER_UNARY_ENDPOINT}; use fms_guardrails_orchestr8::{ clients::chunker::{ChunkerClient, MODEL_ID_HEADER_NAME}, config::ServiceConfig, diff --git a/tests/common/chunker.rs b/tests/common/chunker.rs new file mode 100644 index 00000000..4b0c9e8b --- /dev/null +++ b/tests/common/chunker.rs @@ -0,0 +1,26 @@ +/* + Copyright FMS Guardrails Orchestrator Authors + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + +*/ +use mocktail::generate_grpc_server; +use mocktail::mock::MockSet; + +generate_grpc_server!( + "caikit.runtime.Chunkers.ChunkersService", + MockChunkersServiceServer +); + +pub const CHUNKER_UNARY_ENDPOINT: &str = + "/caikit.runtime.Chunkers.ChunkersService/ChunkerTokenizationTaskPredict"; diff --git a/tests/common/generation.rs b/tests/common/generation.rs new file mode 100644 index 00000000..2bbceecb --- /dev/null +++ b/tests/common/generation.rs @@ -0,0 +1,23 @@ +/* + Copyright FMS Guardrails Orchestrator Authors + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. 
+ You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + +*/ +use mocktail::generate_grpc_server; +use mocktail::mock::MockSet; + +generate_grpc_server!("caikit.runtime.Nlp.NlpService", MockNlpServiceServer); + +pub const GENERATION_NLP_STREAMING_ENDPOINT: &str = + "/caikit.runtime.Nlp.NlpService/ServerStreamingTextGenerationTaskPredict"; diff --git a/tests/common/mod.rs b/tests/common/mod.rs index f48e925f..e7a6f973 100644 --- a/tests/common/mod.rs +++ b/tests/common/mod.rs @@ -14,75 +14,6 @@ limitations under the License. */ - -// This is needed because integration test files are compiled as separate crates. -// If any of the code in this file is not used by any of the test files, a warning about unused code is generated. -// For more: https://github.com/rust-lang/rust/issues/46379 - -use std::sync::Arc; - -use fms_guardrails_orchestr8::config::OrchestratorConfig; -use fms_guardrails_orchestr8::orchestrator::Orchestrator; -use fms_guardrails_orchestr8::server::ServerState; -use mocktail::generate_grpc_server; -use mocktail::mock::MockSet; -use mocktail::server::HttpMockServer; -use rustls::crypto::ring; - -generate_grpc_server!( - "caikit.runtime.Chunkers.ChunkersService", - MockChunkersServiceServer -); - -generate_grpc_server!("caikit.runtime.Nlp.NlpService", MockNlpServiceServer); - -pub const CHUNKER_UNARY_ENDPOINT: &str = - "/caikit.runtime.Chunkers.ChunkersService/ChunkerTokenizationTaskPredict"; - -pub const GENERATION_NLP_STREAMING_ENDPOINT: &str = - "/caikit.runtime.Nlp.NlpService/ServerStreamingTextGenerationTaskPredict"; - -/// Default orchestrator configuration file for integration tests. -pub const CONFIG_FILE_PATH: &str = "tests/test.config.yaml"; - -/// -pub fn ensure_global_rustls_state() { - let _ = ring::default_provider().install_default(); -} - -/// Starts mock servers and adds them to orchestrator configuration. -pub async fn create_orchestrator_shared_state( - detectors: Vec, - chunkers: Vec<(&str, MockChunkersServiceServer)>, -) -> Result, mocktail::Error> { - let mut config = OrchestratorConfig::load(CONFIG_FILE_PATH).await.unwrap(); - - for detector_mock_server in detectors { - let _ = detector_mock_server.start().await?; - - // assign mock server port to detector config - config - .detectors - .get_mut(detector_mock_server.name()) - .unwrap() - .service - .port = Some(detector_mock_server.addr().port()); - } - - for (chunker_name, chunker_mock_server) in chunkers { - let _ = chunker_mock_server.start().await?; - - // assign mock server port to chunker config - config - .chunkers - .as_mut() - .unwrap() - .get_mut(chunker_name) - .unwrap() - .service - .port = Some(chunker_mock_server.addr().port()); - } - - let orchestrator = Orchestrator::new(config, false).await.unwrap(); - Ok(Arc::new(ServerState::new(orchestrator))) -} +pub mod chunker; +pub mod generation; +pub mod util; diff --git a/tests/common/util.rs b/tests/common/util.rs new file mode 100644 index 00000000..fd998bde --- /dev/null +++ b/tests/common/util.rs @@ -0,0 +1,70 @@ +/* + Copyright FMS Guardrails Orchestrator Authors + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. 
+ You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + +*/ + +use std::sync::Arc; + +use fms_guardrails_orchestr8::{ + config::OrchestratorConfig, orchestrator::Orchestrator, server::ServerState, +}; +use mocktail::server::HttpMockServer; +use rustls::crypto::ring; + +use super::chunker::MockChunkersServiceServer; + +/// Default orchestrator configuration file for integration tests. +pub const CONFIG_FILE_PATH: &str = "tests/test.config.yaml"; + +pub fn ensure_global_rustls_state() { + let _ = ring::default_provider().install_default(); +} + +/// Starts mock servers and adds them to orchestrator configuration. +pub async fn create_orchestrator_shared_state( + detectors: Vec, + chunkers: Vec<(&str, MockChunkersServiceServer)>, +) -> Result, mocktail::Error> { + let mut config = OrchestratorConfig::load(CONFIG_FILE_PATH).await.unwrap(); + + for detector_mock_server in detectors { + let _ = detector_mock_server.start().await?; + + // assign mock server port to detector config + config + .detectors + .get_mut(detector_mock_server.name()) + .unwrap() + .service + .port = Some(detector_mock_server.addr().port()); + } + + for (chunker_name, chunker_mock_server) in chunkers { + let _ = chunker_mock_server.start().await?; + + // assign mock server port to chunker config + config + .chunkers + .as_mut() + .unwrap() + .get_mut(chunker_name) + .unwrap() + .service + .port = Some(chunker_mock_server.addr().port()); + } + + let orchestrator = Orchestrator::new(config, false).await.unwrap(); + Ok(Arc::new(ServerState::new(orchestrator))) +} diff --git a/tests/detection_content.rs b/tests/detection_content.rs index c3ca5421..86be7335 100644 --- a/tests/detection_content.rs +++ b/tests/detection_content.rs @@ -19,8 +19,8 @@ use std::collections::HashMap; use axum_test::TestServer; use common::{ - create_orchestrator_shared_state, ensure_global_rustls_state, MockChunkersServiceServer, - CHUNKER_UNARY_ENDPOINT, + chunker::{MockChunkersServiceServer, CHUNKER_UNARY_ENDPOINT}, + util::{create_orchestrator_shared_state, ensure_global_rustls_state}, }; use fms_guardrails_orchestr8::{ clients::{ From 99e4544b21c0c036f944e460598958c9af3ceaf4 Mon Sep 17 00:00:00 2001 From: Mateus Devino Date: Wed, 12 Feb 2025 14:46:02 -0300 Subject: [PATCH 026/117] Test NLP generation stremaing call Signed-off-by: Mateus Devino --- tests/generation_nlp.rs | 18 +++++++----------- 1 file changed, 7 insertions(+), 11 deletions(-) diff --git a/tests/generation_nlp.rs b/tests/generation_nlp.rs index 217c2744..2030f765 100644 --- a/tests/generation_nlp.rs +++ b/tests/generation_nlp.rs @@ -15,7 +15,7 @@ */ -use common::{MockNlpServiceServer, GENERATION_NLP_STREAMING_ENDPOINT}; +use common::generation::{MockNlpServiceServer, GENERATION_NLP_STREAMING_ENDPOINT}; use fms_guardrails_orchestr8::{ clients::{nlp::MODEL_ID_HEADER_NAME, NlpClient}, config::ServiceConfig, @@ -58,7 +58,7 @@ async fn test_nlp_streaming_call() -> Result<(), anyhow::Error> { MockPath::new(Method::POST, GENERATION_NLP_STREAMING_ENDPOINT), Mock::new( MockRequest::pb(ServerStreamingTextGenerationTaskRequest { - text: "Hi there! how are you?".to_string(), + text: "Hi there! 
How are you?".to_string(), ..Default::default() }) .with_headers(headers.clone()), @@ -67,7 +67,7 @@ async fn test_nlp_streaming_call() -> Result<(), anyhow::Error> { ); let generation_nlp_server = MockNlpServiceServer::new(mocks)?; - let _ = generation_nlp_server.start().await; + generation_nlp_server.start().await?; let client = NlpClient::new(&ServiceConfig { hostname: "localhost".to_string(), @@ -77,7 +77,7 @@ async fn test_nlp_streaming_call() -> Result<(), anyhow::Error> { }) .await; - let mut response = client + let response = client .server_streaming_text_generation_task_predict( model_id, ServerStreamingTextGenerationTaskRequest { @@ -88,14 +88,10 @@ async fn test_nlp_streaming_call() -> Result<(), anyhow::Error> { ) .await?; - // assert!(response.is_ok()); - // assert!(response == expected_response); - - // let stream = response.into_stream().into_inner(); + // Collect stream results as array for assertion + let result = response.map(Result::unwrap).collect::>().await; - while let Some(Ok(message)) = response.next().await { - tracing::debug!("recv: {message:?}"); - } + assert!(result == expected_response); Ok(()) } From f55765b3152e8e8f5c4bc880536dd5314a1df2d3 Mon Sep 17 00:00:00 2001 From: Mateus Devino Date: Wed, 12 Feb 2025 17:50:43 -0300 Subject: [PATCH 027/117] test case: streaming no detectors Signed-off-by: Mateus Devino --- tests/common/util.rs | 12 +++- tests/detection_content.rs | 4 +- tests/streaming.rs | 116 +++++++++++++++++++++++++++++++++++++ 3 files changed, 128 insertions(+), 4 deletions(-) create mode 100644 tests/streaming.rs diff --git a/tests/common/util.rs b/tests/common/util.rs index fd998bde..8bbe4493 100644 --- a/tests/common/util.rs +++ b/tests/common/util.rs @@ -23,7 +23,7 @@ use fms_guardrails_orchestr8::{ use mocktail::server::HttpMockServer; use rustls::crypto::ring; -use super::chunker::MockChunkersServiceServer; +use super::{chunker::MockChunkersServiceServer, generation::MockNlpServiceServer}; /// Default orchestrator configuration file for integration tests. pub const CONFIG_FILE_PATH: &str = "tests/test.config.yaml"; @@ -34,13 +34,19 @@ pub fn ensure_global_rustls_state() { /// Starts mock servers and adds them to orchestrator configuration. 
pub async fn create_orchestrator_shared_state( + generation_server: Option<&MockNlpServiceServer>, detectors: Vec, chunkers: Vec<(&str, MockChunkersServiceServer)>, ) -> Result, mocktail::Error> { let mut config = OrchestratorConfig::load(CONFIG_FILE_PATH).await.unwrap(); + if let Some(generation_server) = generation_server { + generation_server.start().await?; + config.generation.as_mut().unwrap().service.port = Some(generation_server.addr().port()); + } + for detector_mock_server in detectors { - let _ = detector_mock_server.start().await?; + detector_mock_server.start().await?; // assign mock server port to detector config config @@ -52,7 +58,7 @@ pub async fn create_orchestrator_shared_state( } for (chunker_name, chunker_mock_server) in chunkers { - let _ = chunker_mock_server.start().await?; + chunker_mock_server.start().await?; // assign mock server port to chunker config config diff --git a/tests/detection_content.rs b/tests/detection_content.rs index 86be7335..0842dac4 100644 --- a/tests/detection_content.rs +++ b/tests/detection_content.rs @@ -83,7 +83,8 @@ async fn test_single_detection_whole_doc() -> Result<(), anyhow::Error> { // Setup orchestrator and detector servers let mock_detector_server = HttpMockServer::new(detector_name, mocks)?; - let shared_state = create_orchestrator_shared_state(vec![mock_detector_server], vec![]).await?; + let shared_state = + create_orchestrator_shared_state(None, vec![mock_detector_server], vec![]).await?; let server = TestServer::new(get_app(shared_state))?; // Make orchestrator call @@ -187,6 +188,7 @@ async fn test_single_detection_sentence_chunker() -> Result<(), anyhow::Error> { let mock_chunker_server = MockChunkersServiceServer::new(chunker_mocks)?; let mock_detector_server = HttpMockServer::new(detector_name, mocks)?; let shared_state = create_orchestrator_shared_state( + None, vec![mock_detector_server], vec![(chunker_id, mock_chunker_server)], ) diff --git a/tests/streaming.rs b/tests/streaming.rs new file mode 100644 index 00000000..5bc5b167 --- /dev/null +++ b/tests/streaming.rs @@ -0,0 +1,116 @@ +/* + Copyright FMS Guardrails Orchestrator Authors + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. 
+ +*/ + +use axum_test::TestServer; +use common::{ + generation::{MockNlpServiceServer, GENERATION_NLP_STREAMING_ENDPOINT}, + util::{create_orchestrator_shared_state, ensure_global_rustls_state}, +}; +use fms_guardrails_orchestr8::{ + clients::nlp::MODEL_ID_HEADER_NAME, + models::{ClassifiedGeneratedTextStreamResult, GuardrailsHttpRequest}, + pb::{ + caikit::runtime::nlp::ServerStreamingTextGenerationTaskRequest, + caikit_data_model::nlp::GeneratedTextStreamResult, + }, + server::get_app, +}; +use mocktail::prelude::*; +use tracing::debug; +use tracing_test::traced_test; + +pub mod common; + +const ENDPOINT_ORCHESTRATOR: &str = + "/api/v1/task/server-streaming-classification-with-text-generation"; + +#[traced_test] +#[tokio::test] +async fn test_no_detectors() -> Result<(), anyhow::Error> { + ensure_global_rustls_state(); + + // Add generation mock + let model_id = "my-super-model-8B"; + let mut headers = HeaderMap::new(); + headers.insert(MODEL_ID_HEADER_NAME, model_id.parse().unwrap()); + + let expected_response = vec![ + GeneratedTextStreamResult { + generated_text: "I".to_string(), + ..Default::default() + }, + GeneratedTextStreamResult { + generated_text: " am".to_string(), + ..Default::default() + }, + GeneratedTextStreamResult { + generated_text: " great!".to_string(), + ..Default::default() + }, + ]; + + let mut mocks = MockSet::new(); + mocks.insert( + MockPath::new(Method::POST, GENERATION_NLP_STREAMING_ENDPOINT), + Mock::new( + MockRequest::pb(ServerStreamingTextGenerationTaskRequest { + text: "Hi there! How are you?".to_string(), + ..Default::default() + }) + .with_headers(headers.clone()), + MockResponse::pb_stream(expected_response.clone()), + ), + ); + + // Setup servers + let generation_nlp_server = MockNlpServiceServer::new(mocks)?; + let shared_state = + create_orchestrator_shared_state(Some(&generation_nlp_server), vec![], vec![]).await?; + let server = TestServer::new(get_app(shared_state))?; + + // Make orchestrator call + let response = server + .post(ENDPOINT_ORCHESTRATOR) + .json(&GuardrailsHttpRequest { + model_id: model_id.to_string(), + inputs: "Hi there! 
How are you?".to_string(), + guardrail_config: None, + text_gen_parameters: None, + }) + .await; + + // convert SSE events back into Rust structs + let text = response.text(); + let results: Vec<_> = text + .split("\n\n") + .filter(|line| !line.is_empty()) + .map(|line| { + serde_json::from_str::(&line.replace("data: ", "")) + .unwrap() + }) + .collect(); + + // assertions + assert!(results.len() == 3); + assert!(results[0].generated_text == Some("I".into())); + assert!(results[1].generated_text == Some(" am".into())); + assert!(results[2].generated_text == Some(" great!".into())); + + debug!("{:#?}", results); + + Ok(()) +} From a8cafe536b804ec527fb6f29da9291be07646b59 Mon Sep 17 00:00:00 2001 From: declark1 <44146800+declark1@users.noreply.github.com> Date: Wed, 12 Feb 2025 16:14:08 -0800 Subject: [PATCH 028/117] Add MockOrchestratorServer to replace usage of TestServer, add EventSource example, add custom SseStream wrapper (wip) Signed-off-by: declark1 <44146800+declark1@users.noreply.github.com> --- Cargo.lock | 1 + Cargo.toml | 43 ++++++++--- tests/common/util.rs | 166 ++++++++++++++++++++++++++++++++++++++++++- tests/streaming.rs | 88 ++++++++++++++--------- 4 files changed, 257 insertions(+), 41 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index f0f2428a..a5c3944c 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -699,6 +699,7 @@ dependencies = [ "axum 0.8.1", "axum-extra", "axum-test", + "bytes", "clap", "eventsource-stream", "faux", diff --git a/Cargo.toml b/Cargo.toml index c57a8899..6c5737fe 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -18,6 +18,7 @@ anyhow = "1.0.95" async-trait = "0.1.85" axum = { version = "0.8.1", features = ["json"] } axum-extra = "0.10.0" +bytes = "1.10.0" clap = { version = "4.5.26", features = ["derive", "env"] } eventsource-stream = "0.2.3" futures = "0.3.31" @@ -28,27 +29,53 @@ http-body = "1.0" http-body-util = "0.1.2" http-serde = "2.1.1" hyper = { version = "1.5.2", features = ["http1", "http2", "server"] } -hyper-rustls = { version = "0.27.5", features = ["ring"]} +hyper-rustls = { version = "0.27.5", features = ["ring"] } hyper-timeout = "0.5.2" -hyper-util = { version = "0.1.10", features = ["server-auto", "server-graceful", "tokio"] } +hyper-util = { version = "0.1.10", features = [ + "server-auto", + "server-graceful", + "tokio", +] } opentelemetry = { version = "0.27.1", features = ["metrics", "trace"] } opentelemetry-http = { version = "0.27.0", features = ["reqwest"] } -opentelemetry-otlp = { version = "0.27.0", features = ["grpc-tonic", "http-proto"] } +opentelemetry-otlp = { version = "0.27.0", features = [ + "grpc-tonic", + "http-proto", +] } opentelemetry_sdk = { version = "0.27.1", features = ["rt-tokio", "metrics"] } pin-project-lite = "0.2.16" prost = "0.13.4" -reqwest = { version = "0.12.12", features = ["blocking", "rustls-tls", "json"] } -rustls = {version = "0.23.21", default-features = false, features = ["ring", "std"]} +reqwest = { version = "0.12.12", features = [ + "blocking", + "rustls-tls", + "json", + "stream", +] } +rustls = { version = "0.23.21", default-features = false, features = [ + "ring", + "std", +] } rustls-pemfile = "2.2.0" rustls-webpki = "0.102.8" serde = { version = "1.0.217", features = ["derive"] } serde_json = "1.0.135" serde_yml = "0.0.12" thiserror = "2.0.11" -tokio = { version = "1.43.0", features = ["rt", "rt-multi-thread", "parking_lot", "signal", "sync", "fs"] } -tokio-rustls = { version = "0.26.1", features = ["ring"]} +tokio = { version = "1.43.0", features = [ + "rt", + "rt-multi-thread", + 
"parking_lot", + "signal", + "sync", + "fs", +] } +tokio-rustls = { version = "0.26.1", features = ["ring"] } tokio-stream = { version = "0.1.17", features = ["sync"] } -tonic = { version = "0.12.3", features = ["tls", "tls-roots", "tls-webpki-roots"] } +tonic = { version = "0.12.3", features = [ + "tls", + "tls-roots", + "tls-webpki-roots", +] } tower = { version = "0.5.2", features = ["timeout"] } tower-http = { version = "0.6.2", features = ["trace"] } tracing = "0.1.41" diff --git a/tests/common/util.rs b/tests/common/util.rs index 8bbe4493..70914f3f 100644 --- a/tests/common/util.rs +++ b/tests/common/util.rs @@ -15,13 +15,27 @@ */ -use std::sync::Arc; +use std::{ + marker::PhantomData, + net::{IpAddr, Ipv4Addr, SocketAddr}, + path::Path, + pin::Pin, + sync::Arc, + task::{Context, Poll}, + time::Duration, +}; +use bytes::Bytes; +use eventsource_stream::{EventStream, Eventsource}; use fms_guardrails_orchestr8::{ config::OrchestratorConfig, orchestrator::Orchestrator, server::ServerState, }; +use futures::{stream::BoxStream, Stream, StreamExt}; use mocktail::server::HttpMockServer; use rustls::crypto::ring; +use serde::de::DeserializeOwned; +use tokio::task::JoinHandle; +use url::Url; use super::{chunker::MockChunkersServiceServer, generation::MockNlpServiceServer}; @@ -74,3 +88,153 @@ pub async fn create_orchestrator_shared_state( let orchestrator = Orchestrator::new(config, false).await.unwrap(); Ok(Arc::new(ServerState::new(orchestrator))) } + +pub struct TestOrchestratorServer { + base_url: Url, + health_url: Url, + client: reqwest::Client, + _handle: JoinHandle>, +} + +impl TestOrchestratorServer { + /// Configures and runs an orchestrator server. + pub async fn run( + config_path: impl AsRef, + port: u16, + health_port: u16, + generation_server: Option, + chat_generation_server: Option, + detector_servers: Option>, + chunker_servers: Option>, + ) -> Result { + // Load orchestrator config + let mut config = OrchestratorConfig::load(config_path).await?; + + // Start & configure mock servers + // Generation server + if let Some(generation_server) = generation_server { + generation_server.start().await?; + config.generation.as_mut().unwrap().service.port = + Some(generation_server.addr().port()); + } + // Chat generation server + if let Some(chat_generation_server) = chat_generation_server { + chat_generation_server.start().await?; + config.chat_generation.as_mut().unwrap().service.port = + Some(chat_generation_server.addr().port()); + } + // Detector servers + if let Some(detector_servers) = detector_servers { + for detector_server in detector_servers { + detector_server.start().await?; + config + .detectors + .get_mut(detector_server.name()) + .unwrap() + .service + .port = Some(detector_server.addr().port()); + } + } + // Chunker servers + if let Some(chunker_servers) = chunker_servers { + for (name, chunker_server) in chunker_servers { + chunker_server.start().await?; + config + .chunkers + .as_mut() + .unwrap() + .get_mut(&name) + .unwrap() + .service + .port = Some(chunker_server.addr().port()); + } + } + + // Run orchestrator server + let orchestrator = Orchestrator::new(config, false).await?; + let http_addr = SocketAddr::new(IpAddr::V4(Ipv4Addr::new(0, 0, 0, 0)), port); + let health_http_addr = SocketAddr::new(IpAddr::V4(Ipv4Addr::new(0, 0, 0, 0)), health_port); + + let _handle = tokio::spawn(async move { + fms_guardrails_orchestr8::server::run( + http_addr, + health_http_addr, + None, + None, + None, + orchestrator, + ) + .await?; + Ok::<(), anyhow::Error>(()) + }); + // 
Allow the server time to become ready + tokio::time::sleep(Duration::from_millis(10)).await; + + let base_url = Url::parse(&format!("http://0.0.0.0:{port}")).unwrap(); + let health_url = Url::parse(&format!("http://0.0.0.0:{health_port}/health")).unwrap(); + let client = reqwest::Client::builder().build().unwrap(); + Ok(Self { + base_url, + health_url, + client, + _handle, + }) + } + + pub fn server_url(&self, path: &str) -> Url { + self.base_url.join(path).unwrap() + } + + pub fn health_url(&self) -> Url { + self.health_url.clone() + } + + pub fn get(&self, path: &str) -> reqwest::RequestBuilder { + let url = self.server_url(path); + self.client.get(url) + } + + pub fn post(&self, path: &str) -> reqwest::RequestBuilder { + let url = self.server_url(path); + self.client.post(url) + } +} + +pub struct SseStream<'a, T> { + stream: EventStream>>, + phantom: PhantomData<&'a T>, +} + +impl SseStream<'_, T> { + pub fn new(stream: impl Stream> + Send + 'static) -> Self { + let stream = stream.boxed().eventsource(); + Self { + stream, + phantom: PhantomData, + } + } +} + +impl Stream for SseStream<'_, T> +where + T: DeserializeOwned, +{ + type Item = Result; + + fn poll_next(self: Pin<&mut Self>, cx: &mut Context) -> Poll> { + match Pin::new(&mut self.get_mut().stream).poll_next(cx) { + Poll::Ready(Some(Ok(event))) => { + if event.data == "[DONE]" { + return Poll::Ready(None); + } + match serde_json::from_str::(&event.data) { + Ok(msg) => Poll::Ready(Some(Ok(msg))), + Err(error) => Poll::Ready(Some(Err(error.into()))), + } + } + Poll::Ready(Some(Err(error))) => Poll::Ready(Some(Err(error.into()))), + Poll::Ready(None) => Poll::Ready(None), + Poll::Pending => Poll::Pending, + } + } +} diff --git a/tests/streaming.rs b/tests/streaming.rs index 5bc5b167..a80559af 100644 --- a/tests/streaming.rs +++ b/tests/streaming.rs @@ -15,10 +15,9 @@ */ -use axum_test::TestServer; use common::{ generation::{MockNlpServiceServer, GENERATION_NLP_STREAMING_ENDPOINT}, - util::{create_orchestrator_shared_state, ensure_global_rustls_state}, + util::{ensure_global_rustls_state, SseStream, TestOrchestratorServer}, }; use fms_guardrails_orchestr8::{ clients::nlp::MODEL_ID_HEADER_NAME, @@ -27,15 +26,14 @@ use fms_guardrails_orchestr8::{ caikit::runtime::nlp::ServerStreamingTextGenerationTaskRequest, caikit_data_model::nlp::GeneratedTextStreamResult, }, - server::get_app, }; +use futures::StreamExt; use mocktail::prelude::*; -use tracing::debug; use tracing_test::traced_test; pub mod common; -const ENDPOINT_ORCHESTRATOR: &str = +const STREAMING_CLASSIFICATION_WITH_GEN_ENDPOINT: &str = "/api/v1/task/server-streaming-classification-with-text-generation"; #[traced_test] @@ -76,41 +74,67 @@ async fn test_no_detectors() -> Result<(), anyhow::Error> { ), ); - // Setup servers - let generation_nlp_server = MockNlpServiceServer::new(mocks)?; - let shared_state = - create_orchestrator_shared_state(Some(&generation_nlp_server), vec![], vec![]).await?; - let server = TestServer::new(get_app(shared_state))?; - - // Make orchestrator call - let response = server - .post(ENDPOINT_ORCHESTRATOR) + // Configure mock servers + let generation_server = MockNlpServiceServer::new(mocks)?; + + // Run test orchestrator server + let orchestrator_server = TestOrchestratorServer::run( + "tests/test.config.yaml", + 8080, + 8081, + Some(generation_server), + None, + None, + None, + ) + .await?; + + // Example orchestrator request with streaming response + let response = orchestrator_server + .post(STREAMING_CLASSIFICATION_WITH_GEN_ENDPOINT) 
.json(&GuardrailsHttpRequest { model_id: model_id.to_string(), inputs: "Hi there! How are you?".to_string(), guardrail_config: None, text_gen_parameters: None, }) - .await; - - // convert SSE events back into Rust structs - let text = response.text(); - let results: Vec<_> = text - .split("\n\n") - .filter(|line| !line.is_empty()) - .map(|line| { - serde_json::from_str::(&line.replace("data: ", "")) - .unwrap() - }) - .collect(); + .send() + .await?; + + // Example showing how to create an event stream from a bytes stream. + // let mut events = Vec::new(); + // let mut event_stream = response.bytes_stream().eventsource(); + // while let Some(event) = event_stream.next().await { + // match event { + // Ok(event) => { + // if event.data == "[DONE]" { + // break; + // } + // println!("recv: {event:?}"); + // events.push(event.data); + // } + // Err(_) => { + // panic!("received error from event stream"); + // } + // } + // } + // println!("{events:?}"); + + // Test custom SseStream wrapper + let sse_stream: SseStream = + SseStream::new(response.bytes_stream()); + let messages = sse_stream + .collect::>() + .await + .into_iter() + .collect::, anyhow::Error>>()?; + println!("{messages:?}"); // assertions - assert!(results.len() == 3); - assert!(results[0].generated_text == Some("I".into())); - assert!(results[1].generated_text == Some(" am".into())); - assert!(results[2].generated_text == Some(" great!".into())); - - debug!("{:#?}", results); + assert!(messages.len() == 3); + assert!(messages[0].generated_text == Some("I".into())); + assert!(messages[1].generated_text == Some(" am".into())); + assert!(messages[2].generated_text == Some(" great!".into())); Ok(()) } From fc657009a88cd3736e0b95975fddd69e6d0c1cb2 Mon Sep 17 00:00:00 2001 From: Mateus Devino Date: Thu, 13 Feb 2025 10:52:12 -0300 Subject: [PATCH 029/117] Refactor test util.rs module into orchestrator.rs Signed-off-by: Mateus Devino --- tests/canary_test.rs | 2 +- tests/common/mod.rs | 2 +- tests/common/{util.rs => orchestrator.rs} | 161 ++++++++++++++-------- tests/detection_content.rs | 2 +- tests/streaming.rs | 2 +- 5 files changed, 105 insertions(+), 64 deletions(-) rename tests/common/{util.rs => orchestrator.rs} (64%) diff --git a/tests/canary_test.rs b/tests/canary_test.rs index ed290192..a2fc0d28 100644 --- a/tests/canary_test.rs +++ b/tests/canary_test.rs @@ -22,7 +22,7 @@ use std::sync::Arc; use axum_test::TestServer; -use common::util::ensure_global_rustls_state; +use common::orchestrator::ensure_global_rustls_state; use fms_guardrails_orchestr8::{ config::OrchestratorConfig, orchestrator::Orchestrator, diff --git a/tests/common/mod.rs b/tests/common/mod.rs index e7a6f973..e11ecedf 100644 --- a/tests/common/mod.rs +++ b/tests/common/mod.rs @@ -16,4 +16,4 @@ */ pub mod chunker; pub mod generation; -pub mod util; +pub mod orchestrator; diff --git a/tests/common/util.rs b/tests/common/orchestrator.rs similarity index 64% rename from tests/common/util.rs rename to tests/common/orchestrator.rs index 70914f3f..f7930044 100644 --- a/tests/common/util.rs +++ b/tests/common/orchestrator.rs @@ -111,68 +111,15 @@ impl TestOrchestratorServer { let mut config = OrchestratorConfig::load(config_path).await?; // Start & configure mock servers - // Generation server - if let Some(generation_server) = generation_server { - generation_server.start().await?; - config.generation.as_mut().unwrap().service.port = - Some(generation_server.addr().port()); - } - // Chat generation server - if let Some(chat_generation_server) = chat_generation_server 
{ - chat_generation_server.start().await?; - config.chat_generation.as_mut().unwrap().service.port = - Some(chat_generation_server.addr().port()); - } - // Detector servers - if let Some(detector_servers) = detector_servers { - for detector_server in detector_servers { - detector_server.start().await?; - config - .detectors - .get_mut(detector_server.name()) - .unwrap() - .service - .port = Some(detector_server.addr().port()); - } - } - // Chunker servers - if let Some(chunker_servers) = chunker_servers { - for (name, chunker_server) in chunker_servers { - chunker_server.start().await?; - config - .chunkers - .as_mut() - .unwrap() - .get_mut(&name) - .unwrap() - .service - .port = Some(chunker_server.addr().port()); - } - } + initialize_generation_server(generation_server, &mut config).await?; + initialize_chat_generation_server(chat_generation_server, &mut config).await?; + initialize_detectors(detector_servers, &mut config).await?; + initialize_chunkers(chunker_servers, &mut config).await?; // Run orchestrator server - let orchestrator = Orchestrator::new(config, false).await?; - let http_addr = SocketAddr::new(IpAddr::V4(Ipv4Addr::new(0, 0, 0, 0)), port); - let health_http_addr = SocketAddr::new(IpAddr::V4(Ipv4Addr::new(0, 0, 0, 0)), health_port); - - let _handle = tokio::spawn(async move { - fms_guardrails_orchestr8::server::run( - http_addr, - health_http_addr, - None, - None, - None, - orchestrator, - ) - .await?; - Ok::<(), anyhow::Error>(()) - }); - // Allow the server time to become ready - tokio::time::sleep(Duration::from_millis(10)).await; - - let base_url = Url::parse(&format!("http://0.0.0.0:{port}")).unwrap(); - let health_url = Url::parse(&format!("http://0.0.0.0:{health_port}/health")).unwrap(); - let client = reqwest::Client::builder().build().unwrap(); + let (_handle, base_url, health_url, client) = + initialize_orchestrator_server(port, health_port, config).await?; + Ok(Self { base_url, health_url, @@ -200,6 +147,100 @@ impl TestOrchestratorServer { } } +async fn initialize_orchestrator_server( + port: u16, + health_port: u16, + config: OrchestratorConfig, +) -> Result< + ( + JoinHandle>, + Url, + Url, + reqwest::Client, + ), + anyhow::Error, +> { + let orchestrator = Orchestrator::new(config, false).await?; + let http_addr = SocketAddr::new(IpAddr::V4(Ipv4Addr::new(0, 0, 0, 0)), port); + let health_http_addr = SocketAddr::new(IpAddr::V4(Ipv4Addr::new(0, 0, 0, 0)), health_port); + let _handle = tokio::spawn(async move { + fms_guardrails_orchestr8::server::run( + http_addr, + health_http_addr, + None, + None, + None, + orchestrator, + ) + .await?; + Ok::<(), anyhow::Error>(()) + }); + tokio::time::sleep(Duration::from_millis(10)).await; + let base_url = Url::parse(&format!("http://0.0.0.0:{port}")).unwrap(); + let health_url = Url::parse(&format!("http://0.0.0.0:{health_port}/health")).unwrap(); + let client = reqwest::Client::builder().build().unwrap(); + Ok((_handle, base_url, health_url, client)) +} + +async fn initialize_generation_server( + generation_server: Option, + config: &mut OrchestratorConfig, +) -> Result<(), anyhow::Error> { + Ok(if let Some(generation_server) = generation_server { + generation_server.start().await?; + config.generation.as_mut().unwrap().service.port = Some(generation_server.addr().port()); + }) +} + +async fn initialize_chat_generation_server( + chat_generation_server: Option, + config: &mut OrchestratorConfig, +) -> Result<(), anyhow::Error> { + Ok( + if let Some(chat_generation_server) = chat_generation_server { + 
chat_generation_server.start().await?; + config.chat_generation.as_mut().unwrap().service.port = + Some(chat_generation_server.addr().port()); + }, + ) +} + +async fn initialize_detectors( + detector_servers: Option>, + config: &mut OrchestratorConfig, +) -> Result<(), anyhow::Error> { + Ok(if let Some(detector_servers) = detector_servers { + for detector_server in detector_servers { + detector_server.start().await?; + config + .detectors + .get_mut(detector_server.name()) + .unwrap() + .service + .port = Some(detector_server.addr().port()); + } + }) +} + +async fn initialize_chunkers( + chunker_servers: Option>, + config: &mut OrchestratorConfig, +) -> Result<(), anyhow::Error> { + Ok(if let Some(chunker_servers) = chunker_servers { + for (name, chunker_server) in chunker_servers { + chunker_server.start().await?; + config + .chunkers + .as_mut() + .unwrap() + .get_mut(&name) + .unwrap() + .service + .port = Some(chunker_server.addr().port()); + } + }) +} + pub struct SseStream<'a, T> { stream: EventStream>>, phantom: PhantomData<&'a T>, diff --git a/tests/detection_content.rs b/tests/detection_content.rs index 0842dac4..77ec28ef 100644 --- a/tests/detection_content.rs +++ b/tests/detection_content.rs @@ -20,7 +20,7 @@ use std::collections::HashMap; use axum_test::TestServer; use common::{ chunker::{MockChunkersServiceServer, CHUNKER_UNARY_ENDPOINT}, - util::{create_orchestrator_shared_state, ensure_global_rustls_state}, + orchestrator::{create_orchestrator_shared_state, ensure_global_rustls_state}, }; use fms_guardrails_orchestr8::{ clients::{ diff --git a/tests/streaming.rs b/tests/streaming.rs index a80559af..a718c174 100644 --- a/tests/streaming.rs +++ b/tests/streaming.rs @@ -17,7 +17,7 @@ use common::{ generation::{MockNlpServiceServer, GENERATION_NLP_STREAMING_ENDPOINT}, - util::{ensure_global_rustls_state, SseStream, TestOrchestratorServer}, + orchestrator::{ensure_global_rustls_state, SseStream, TestOrchestratorServer}, }; use fms_guardrails_orchestr8::{ clients::nlp::MODEL_ID_HEADER_NAME, From 2b3fed6657b04e918ef0adf13a94f6df883b15e1 Mon Sep 17 00:00:00 2001 From: Mateus Devino Date: Thu, 13 Feb 2025 11:15:05 -0300 Subject: [PATCH 030/117] Remove old create_orchestrator_shared_state() Signed-off-by: Mateus Devino --- tests/common/orchestrator.rs | 48 +---------------- tests/detection_content.rs | 100 ++++++++++++++++++++--------------- 2 files changed, 59 insertions(+), 89 deletions(-) diff --git a/tests/common/orchestrator.rs b/tests/common/orchestrator.rs index f7930044..0873f37b 100644 --- a/tests/common/orchestrator.rs +++ b/tests/common/orchestrator.rs @@ -20,16 +20,13 @@ use std::{ net::{IpAddr, Ipv4Addr, SocketAddr}, path::Path, pin::Pin, - sync::Arc, task::{Context, Poll}, time::Duration, }; use bytes::Bytes; use eventsource_stream::{EventStream, Eventsource}; -use fms_guardrails_orchestr8::{ - config::OrchestratorConfig, orchestrator::Orchestrator, server::ServerState, -}; +use fms_guardrails_orchestr8::{config::OrchestratorConfig, orchestrator::Orchestrator}; use futures::{stream::BoxStream, Stream, StreamExt}; use mocktail::server::HttpMockServer; use rustls::crypto::ring; @@ -46,49 +43,6 @@ pub fn ensure_global_rustls_state() { let _ = ring::default_provider().install_default(); } -/// Starts mock servers and adds them to orchestrator configuration. 
-pub async fn create_orchestrator_shared_state( - generation_server: Option<&MockNlpServiceServer>, - detectors: Vec, - chunkers: Vec<(&str, MockChunkersServiceServer)>, -) -> Result, mocktail::Error> { - let mut config = OrchestratorConfig::load(CONFIG_FILE_PATH).await.unwrap(); - - if let Some(generation_server) = generation_server { - generation_server.start().await?; - config.generation.as_mut().unwrap().service.port = Some(generation_server.addr().port()); - } - - for detector_mock_server in detectors { - detector_mock_server.start().await?; - - // assign mock server port to detector config - config - .detectors - .get_mut(detector_mock_server.name()) - .unwrap() - .service - .port = Some(detector_mock_server.addr().port()); - } - - for (chunker_name, chunker_mock_server) in chunkers { - chunker_mock_server.start().await?; - - // assign mock server port to chunker config - config - .chunkers - .as_mut() - .unwrap() - .get_mut(chunker_name) - .unwrap() - .service - .port = Some(chunker_mock_server.addr().port()); - } - - let orchestrator = Orchestrator::new(config, false).await.unwrap(); - Ok(Arc::new(ServerState::new(orchestrator))) -} - pub struct TestOrchestratorServer { base_url: Url, health_url: Url, diff --git a/tests/detection_content.rs b/tests/detection_content.rs index 77ec28ef..32e723b8 100644 --- a/tests/detection_content.rs +++ b/tests/detection_content.rs @@ -17,10 +17,9 @@ use std::collections::HashMap; -use axum_test::TestServer; use common::{ chunker::{MockChunkersServiceServer, CHUNKER_UNARY_ENDPOINT}, - orchestrator::{create_orchestrator_shared_state, ensure_global_rustls_state}, + orchestrator::{ensure_global_rustls_state, TestOrchestratorServer, CONFIG_FILE_PATH}, }; use fms_guardrails_orchestr8::{ clients::{ @@ -32,7 +31,6 @@ use fms_guardrails_orchestr8::{ caikit::runtime::chunkers::ChunkerTokenizationTaskRequest, caikit_data_model::nlp::{Token, TokenizationResults}, }, - server::get_app, }; use hyper::StatusCode; use mocktail::prelude::*; @@ -81,37 +79,48 @@ async fn test_single_detection_whole_doc() -> Result<(), anyhow::Error> { ), ); - // Setup orchestrator and detector servers + // Start orchestrator server and its dependencies let mock_detector_server = HttpMockServer::new(detector_name, mocks)?; - let shared_state = - create_orchestrator_shared_state(None, vec![mock_detector_server], vec![]).await?; - let server = TestServer::new(get_app(shared_state))?; + let orchestrator_server = TestOrchestratorServer::run( + CONFIG_FILE_PATH, + 8080, + 8081, + None, + None, + Some(vec![mock_detector_server]), + None, + ) + .await?; // Make orchestrator call - let response = server + let response = orchestrator_server .post(ENDPOINT_ORCHESTRATOR) .json(&TextContentDetectionHttpRequest { content: "This sentence has .".to_string(), detectors: HashMap::from([(detector_name.to_string(), DetectorParams::new())]), }) - .await; + .send() + .await?; debug!(?response); // assertions - response.assert_status(StatusCode::OK); - response.assert_json(&TextContentDetectionResult { - detections: vec![ContentAnalysisResponse { - start: 18, - end: 35, - text: "a detection here".to_string(), - detection: "has_angle_brackets".to_string(), - detection_type: "angle_brackets".to_string(), - detector_id: Some(detector_name.to_string()), - score: 1.0, - evidence: None, - }], - }); + assert!(response.status() == StatusCode::OK); + assert!( + response.json::().await? 
+ == TextContentDetectionResult { + detections: vec![ContentAnalysisResponse { + start: 18, + end: 35, + text: "a detection here".to_string(), + detection: "has_angle_brackets".to_string(), + detection_type: "angle_brackets".to_string(), + detector_id: Some(detector_name.to_string()), + score: 1.0, + evidence: None, + }], + } + ); Ok(()) } @@ -184,42 +193,49 @@ async fn test_single_detection_sentence_chunker() -> Result<(), anyhow::Error> { ), ); - // Start orchestrator, chunker and detector servers. + // Start orchestrator server and its dependencies let mock_chunker_server = MockChunkersServiceServer::new(chunker_mocks)?; let mock_detector_server = HttpMockServer::new(detector_name, mocks)?; - let shared_state = create_orchestrator_shared_state( + let orchestrator_server = TestOrchestratorServer::run( + CONFIG_FILE_PATH, + 8080, + 8081, + None, None, - vec![mock_detector_server], - vec![(chunker_id, mock_chunker_server)], + Some(vec![mock_detector_server]), + Some(vec![(chunker_id.into(), mock_chunker_server)]), ) .await?; - let server = TestServer::new(get_app(shared_state))?; // Make orchestrator call - let response = server + let response = orchestrator_server .post(ENDPOINT_ORCHESTRATOR) .json(&TextContentDetectionHttpRequest { content: "This sentence does not have a detection. But .".to_string(), detectors: HashMap::from([(detector_name.to_string(), DetectorParams::new())]), }) - .await; + .send() + .await?; debug!(?response); // assertions - response.assert_status(StatusCode::OK); - response.assert_json(&TextContentDetectionResult { - detections: vec![ContentAnalysisResponse { - start: 45, - end: 59, - text: "this one does".to_string(), - detection: "has_angle_brackets".to_string(), - detection_type: "angle_brackets".to_string(), - detector_id: Some(detector_name.to_string()), - score: 1.0, - evidence: None, - }], - }); + assert!(response.status() == StatusCode::OK); + assert!( + response.json::().await? 
+ == TextContentDetectionResult { + detections: vec![ContentAnalysisResponse { + start: 45, + end: 59, + text: "this one does".to_string(), + detection: "has_angle_brackets".to_string(), + detection_type: "angle_brackets".to_string(), + detector_id: Some(detector_name.to_string()), + score: 1.0, + evidence: None, + }], + } + ); Ok(()) } From 45ec27389855f5b7ce0d9e39bde50256a4b580cf Mon Sep 17 00:00:00 2001 From: Mateus Devino Date: Thu, 13 Feb 2025 11:34:09 -0300 Subject: [PATCH 031/117] Move orchestrator.rs functions into TestOrchestratorServer impl Signed-off-by: Mateus Devino --- tests/common/orchestrator.rs | 192 +++++++++++++++++------------------ 1 file changed, 91 insertions(+), 101 deletions(-) diff --git a/tests/common/orchestrator.rs b/tests/common/orchestrator.rs index 0873f37b..9869fa3c 100644 --- a/tests/common/orchestrator.rs +++ b/tests/common/orchestrator.rs @@ -65,15 +65,99 @@ impl TestOrchestratorServer { let mut config = OrchestratorConfig::load(config_path).await?; // Start & configure mock servers - initialize_generation_server(generation_server, &mut config).await?; - initialize_chat_generation_server(chat_generation_server, &mut config).await?; - initialize_detectors(detector_servers, &mut config).await?; - initialize_chunkers(chunker_servers, &mut config).await?; + Self::initialize_generation_server(generation_server, &mut config).await?; + Self::initialize_chat_generation_server(chat_generation_server, &mut config).await?; + Self::initialize_detectors(detector_servers, &mut config).await?; + Self::initialize_chunkers(chunker_servers, &mut config).await?; - // Run orchestrator server - let (_handle, base_url, health_url, client) = - initialize_orchestrator_server(port, health_port, config).await?; + // Run orchestrator server and returns it + Self::initialize_orchestrator_server(port, health_port, config).await + } + + async fn initialize_generation_server( + generation_server: Option, + config: &mut OrchestratorConfig, + ) -> Result<(), anyhow::Error> { + Ok(if let Some(generation_server) = generation_server { + generation_server.start().await?; + config.generation.as_mut().unwrap().service.port = + Some(generation_server.addr().port()); + }) + } + + async fn initialize_chat_generation_server( + chat_generation_server: Option, + config: &mut OrchestratorConfig, + ) -> Result<(), anyhow::Error> { + Ok( + if let Some(chat_generation_server) = chat_generation_server { + chat_generation_server.start().await?; + config.chat_generation.as_mut().unwrap().service.port = + Some(chat_generation_server.addr().port()); + }, + ) + } + async fn initialize_detectors( + detector_servers: Option>, + config: &mut OrchestratorConfig, + ) -> Result<(), anyhow::Error> { + Ok(if let Some(detector_servers) = detector_servers { + for detector_server in detector_servers { + detector_server.start().await?; + config + .detectors + .get_mut(detector_server.name()) + .unwrap() + .service + .port = Some(detector_server.addr().port()); + } + }) + } + + async fn initialize_chunkers( + chunker_servers: Option>, + config: &mut OrchestratorConfig, + ) -> Result<(), anyhow::Error> { + Ok(if let Some(chunker_servers) = chunker_servers { + for (name, chunker_server) in chunker_servers { + chunker_server.start().await?; + config + .chunkers + .as_mut() + .unwrap() + .get_mut(&name) + .unwrap() + .service + .port = Some(chunker_server.addr().port()); + } + }) + } + + async fn initialize_orchestrator_server( + port: u16, + health_port: u16, + config: OrchestratorConfig, + ) -> Result { + let orchestrator 
= Orchestrator::new(config, false).await?; + let http_addr = SocketAddr::new(IpAddr::V4(Ipv4Addr::new(0, 0, 0, 0)), port); + let health_http_addr = SocketAddr::new(IpAddr::V4(Ipv4Addr::new(0, 0, 0, 0)), health_port); + let _handle = tokio::spawn(async move { + fms_guardrails_orchestr8::server::run( + http_addr, + health_http_addr, + None, + None, + None, + orchestrator, + ) + .await?; + Ok::<(), anyhow::Error>(()) + }); + tokio::time::sleep(Duration::from_millis(10)).await; + let base_url = Url::parse(&format!("http://0.0.0.0:{port}")).unwrap(); + let health_url = Url::parse(&format!("http://0.0.0.0:{health_port}/health")).unwrap(); + let client = reqwest::Client::builder().build().unwrap(); Ok(Self { base_url, health_url, @@ -101,100 +185,6 @@ impl TestOrchestratorServer { } } -async fn initialize_orchestrator_server( - port: u16, - health_port: u16, - config: OrchestratorConfig, -) -> Result< - ( - JoinHandle>, - Url, - Url, - reqwest::Client, - ), - anyhow::Error, -> { - let orchestrator = Orchestrator::new(config, false).await?; - let http_addr = SocketAddr::new(IpAddr::V4(Ipv4Addr::new(0, 0, 0, 0)), port); - let health_http_addr = SocketAddr::new(IpAddr::V4(Ipv4Addr::new(0, 0, 0, 0)), health_port); - let _handle = tokio::spawn(async move { - fms_guardrails_orchestr8::server::run( - http_addr, - health_http_addr, - None, - None, - None, - orchestrator, - ) - .await?; - Ok::<(), anyhow::Error>(()) - }); - tokio::time::sleep(Duration::from_millis(10)).await; - let base_url = Url::parse(&format!("http://0.0.0.0:{port}")).unwrap(); - let health_url = Url::parse(&format!("http://0.0.0.0:{health_port}/health")).unwrap(); - let client = reqwest::Client::builder().build().unwrap(); - Ok((_handle, base_url, health_url, client)) -} - -async fn initialize_generation_server( - generation_server: Option, - config: &mut OrchestratorConfig, -) -> Result<(), anyhow::Error> { - Ok(if let Some(generation_server) = generation_server { - generation_server.start().await?; - config.generation.as_mut().unwrap().service.port = Some(generation_server.addr().port()); - }) -} - -async fn initialize_chat_generation_server( - chat_generation_server: Option, - config: &mut OrchestratorConfig, -) -> Result<(), anyhow::Error> { - Ok( - if let Some(chat_generation_server) = chat_generation_server { - chat_generation_server.start().await?; - config.chat_generation.as_mut().unwrap().service.port = - Some(chat_generation_server.addr().port()); - }, - ) -} - -async fn initialize_detectors( - detector_servers: Option>, - config: &mut OrchestratorConfig, -) -> Result<(), anyhow::Error> { - Ok(if let Some(detector_servers) = detector_servers { - for detector_server in detector_servers { - detector_server.start().await?; - config - .detectors - .get_mut(detector_server.name()) - .unwrap() - .service - .port = Some(detector_server.addr().port()); - } - }) -} - -async fn initialize_chunkers( - chunker_servers: Option>, - config: &mut OrchestratorConfig, -) -> Result<(), anyhow::Error> { - Ok(if let Some(chunker_servers) = chunker_servers { - for (name, chunker_server) in chunker_servers { - chunker_server.start().await?; - config - .chunkers - .as_mut() - .unwrap() - .get_mut(&name) - .unwrap() - .service - .port = Some(chunker_server.addr().port()); - } - }) -} - pub struct SseStream<'a, T> { stream: EventStream>>, phantom: PhantomData<&'a T>, From ba43415650c5c4d716ad55e06fefbdb38c4df3dc Mon Sep 17 00:00:00 2001 From: Mateus Devino Date: Thu, 13 Feb 2025 11:55:31 -0300 Subject: [PATCH 032/117] Refactor test constants 
Signed-off-by: Mateus Devino --- tests/common/detectors.rs | 20 ++++++++++++++++++++ tests/common/mod.rs | 1 + tests/common/orchestrator.rs | 2 ++ tests/detection_content.rs | 23 ++++++++++++----------- tests/streaming.rs | 8 ++++++-- 5 files changed, 41 insertions(+), 13 deletions(-) create mode 100644 tests/common/detectors.rs diff --git a/tests/common/detectors.rs b/tests/common/detectors.rs new file mode 100644 index 00000000..75c1247d --- /dev/null +++ b/tests/common/detectors.rs @@ -0,0 +1,20 @@ +/* + Copyright FMS Guardrails Orchestrator Authors + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + +*/ +pub const TEXT_CONTENTS_DETECTOR_ENDPOINT: &str = "/api/v1/text/contents"; + +pub const DETECTOR_NAME_ANGLE_BRACKETS_WHOLE_DOC: &str = "angle_brackets_detector_whole_doc"; +pub const DETECTOR_NAME_ANGLE_BRACKETS_SENTENCE: &str = "angle_brackets_detector_sentence"; diff --git a/tests/common/mod.rs b/tests/common/mod.rs index e11ecedf..b50f826f 100644 --- a/tests/common/mod.rs +++ b/tests/common/mod.rs @@ -15,5 +15,6 @@ */ pub mod chunker; +pub mod detectors; pub mod generation; pub mod orchestrator; diff --git a/tests/common/orchestrator.rs b/tests/common/orchestrator.rs index 9869fa3c..bbefd02f 100644 --- a/tests/common/orchestrator.rs +++ b/tests/common/orchestrator.rs @@ -39,6 +39,8 @@ use super::{chunker::MockChunkersServiceServer, generation::MockNlpServiceServer /// Default orchestrator configuration file for integration tests. pub const CONFIG_FILE_PATH: &str = "tests/test.config.yaml"; +pub const CONTENT_DETECTION_ORCHESTRATOR_ENDPOINT: &str = "/api/v2/text/detection/content"; + pub fn ensure_global_rustls_state() { let _ = ring::default_provider().install_default(); } diff --git a/tests/detection_content.rs b/tests/detection_content.rs index 32e723b8..f706bd9a 100644 --- a/tests/detection_content.rs +++ b/tests/detection_content.rs @@ -19,7 +19,14 @@ use std::collections::HashMap; use common::{ chunker::{MockChunkersServiceServer, CHUNKER_UNARY_ENDPOINT}, - orchestrator::{ensure_global_rustls_state, TestOrchestratorServer, CONFIG_FILE_PATH}, + detectors::{ + DETECTOR_NAME_ANGLE_BRACKETS_SENTENCE, DETECTOR_NAME_ANGLE_BRACKETS_WHOLE_DOC, + TEXT_CONTENTS_DETECTOR_ENDPOINT, + }, + orchestrator::{ + ensure_global_rustls_state, TestOrchestratorServer, CONFIG_FILE_PATH, + CONTENT_DETECTION_ORCHESTRATOR_ENDPOINT, + }, }; use fms_guardrails_orchestr8::{ clients::{ @@ -40,12 +47,6 @@ use tracing_test::traced_test; pub mod common; // Constants -const ENDPOINT_ORCHESTRATOR: &str = "/api/v2/text/detection/content"; -const ENDPOINT_DETECTOR: &str = "/api/v1/text/contents"; - -const DETECTOR_NAME_ANGLE_BRACKETS_WHOLE_DOC: &str = "angle_brackets_detector_whole_doc"; -const DETECTOR_NAME_ANGLE_BRACKETS_SENTENCE: &str = "angle_brackets_detector_sentence"; - const CHUNKER_NAME_SENTENCE: &str = "sentence_chunker"; /// Asserts a scenario with a single detection works as expected (assumes a detector configured with whole_doc_chunker). 
@@ -60,7 +61,7 @@ async fn test_single_detection_whole_doc() -> Result<(), anyhow::Error> { // Add detector mock let mut mocks = MockSet::new(); mocks.insert( - MockPath::new(Method::POST, ENDPOINT_DETECTOR), + MockPath::new(Method::POST, TEXT_CONTENTS_DETECTOR_ENDPOINT), Mock::new( MockRequest::json(ContentAnalysisRequest { contents: vec!["This sentence has .".to_string()], @@ -94,7 +95,7 @@ async fn test_single_detection_whole_doc() -> Result<(), anyhow::Error> { // Make orchestrator call let response = orchestrator_server - .post(ENDPOINT_ORCHESTRATOR) + .post(CONTENT_DETECTION_ORCHESTRATOR_ENDPOINT) .json(&TextContentDetectionHttpRequest { content: "This sentence has .".to_string(), detectors: HashMap::from([(detector_name.to_string(), DetectorParams::new())]), @@ -168,7 +169,7 @@ async fn test_single_detection_sentence_chunker() -> Result<(), anyhow::Error> { let detector_name = DETECTOR_NAME_ANGLE_BRACKETS_SENTENCE; let mut mocks = MockSet::new(); mocks.insert( - MockPath::new(Method::POST, ENDPOINT_DETECTOR), + MockPath::new(Method::POST, TEXT_CONTENTS_DETECTOR_ENDPOINT), Mock::new( MockRequest::json(ContentAnalysisRequest { contents: vec![ @@ -209,7 +210,7 @@ async fn test_single_detection_sentence_chunker() -> Result<(), anyhow::Error> { // Make orchestrator call let response = orchestrator_server - .post(ENDPOINT_ORCHESTRATOR) + .post(CONTENT_DETECTION_ORCHESTRATOR_ENDPOINT) .json(&TextContentDetectionHttpRequest { content: "This sentence does not have a detection. But .".to_string(), detectors: HashMap::from([(detector_name.to_string(), DetectorParams::new())]), diff --git a/tests/streaming.rs b/tests/streaming.rs index a718c174..5d50996e 100644 --- a/tests/streaming.rs +++ b/tests/streaming.rs @@ -16,12 +16,16 @@ */ use common::{ + detectors::{DETECTOR_NAME_ANGLE_BRACKETS_WHOLE_DOC, TEXT_CONTENTS_DETECTOR_ENDPOINT}, generation::{MockNlpServiceServer, GENERATION_NLP_STREAMING_ENDPOINT}, orchestrator::{ensure_global_rustls_state, SseStream, TestOrchestratorServer}, }; use fms_guardrails_orchestr8::{ - clients::nlp::MODEL_ID_HEADER_NAME, - models::{ClassifiedGeneratedTextStreamResult, GuardrailsHttpRequest}, + clients::{ + detector::{ContentAnalysisRequest, ContentAnalysisResponse}, + nlp::MODEL_ID_HEADER_NAME, + }, + models::{ClassifiedGeneratedTextStreamResult, DetectorParams, GuardrailsHttpRequest}, pb::{ caikit::runtime::nlp::ServerStreamingTextGenerationTaskRequest, caikit_data_model::nlp::GeneratedTextStreamResult, From 6cf3c17658903f871311fd6e9e4467db541506ae Mon Sep 17 00:00:00 2001 From: Mateus Devino Date: Thu, 13 Feb 2025 13:12:00 -0300 Subject: [PATCH 033/117] Test case: streaming.rs::test_input_detector_whole_doc_no_detections() Signed-off-by: Mateus Devino --- tests/streaming.rs | 110 ++++++++++++++++++++++++++++++++++++++++++++- 1 file changed, 109 insertions(+), 1 deletion(-) diff --git a/tests/streaming.rs b/tests/streaming.rs index 5d50996e..c6f427b0 100644 --- a/tests/streaming.rs +++ b/tests/streaming.rs @@ -15,6 +15,8 @@ */ +use std::collections::HashMap; + use common::{ detectors::{DETECTOR_NAME_ANGLE_BRACKETS_WHOLE_DOC, TEXT_CONTENTS_DETECTOR_ENDPOINT}, generation::{MockNlpServiceServer, GENERATION_NLP_STREAMING_ENDPOINT}, @@ -25,7 +27,10 @@ use fms_guardrails_orchestr8::{ detector::{ContentAnalysisRequest, ContentAnalysisResponse}, nlp::MODEL_ID_HEADER_NAME, }, - models::{ClassifiedGeneratedTextStreamResult, DetectorParams, GuardrailsHttpRequest}, + models::{ + ClassifiedGeneratedTextStreamResult, DetectorParams, GuardrailsConfig, + GuardrailsConfigInput, 
GuardrailsHttpRequest, + }, pb::{ caikit::runtime::nlp::ServerStreamingTextGenerationTaskRequest, caikit_data_model::nlp::GeneratedTextStreamResult, @@ -142,3 +147,106 @@ async fn test_no_detectors() -> Result<(), anyhow::Error> { Ok(()) } + +#[traced_test] +#[tokio::test] +async fn test_input_detector_whole_doc_no_detections() -> Result<(), anyhow::Error> { + ensure_global_rustls_state(); + let detector_name = DETECTOR_NAME_ANGLE_BRACKETS_WHOLE_DOC; + + // Add input detection mock + let mut detection_mocks = MockSet::new(); + detection_mocks.insert( + MockPath::new(Method::POST, TEXT_CONTENTS_DETECTOR_ENDPOINT), + Mock::new( + MockRequest::json(ContentAnalysisRequest { + contents: vec!["Hi there! How are you?".to_string()], + detector_params: DetectorParams::new(), + }), + MockResponse::json([Vec::::new()]), + ), + ); + + // Add generation mock + let model_id = "my-super-model-8B"; + let mut headers = HeaderMap::new(); + headers.insert(MODEL_ID_HEADER_NAME, model_id.parse().unwrap()); + + let expected_response = vec![ + GeneratedTextStreamResult { + generated_text: "I".to_string(), + ..Default::default() + }, + GeneratedTextStreamResult { + generated_text: " am".to_string(), + ..Default::default() + }, + GeneratedTextStreamResult { + generated_text: " great!".to_string(), + ..Default::default() + }, + ]; + + let mut generation_mocks = MockSet::new(); + generation_mocks.insert( + MockPath::new(Method::POST, GENERATION_NLP_STREAMING_ENDPOINT), + Mock::new( + MockRequest::pb(ServerStreamingTextGenerationTaskRequest { + text: "Hi there! How are you?".to_string(), + ..Default::default() + }) + .with_headers(headers.clone()), + MockResponse::pb_stream(expected_response.clone()), + ), + ); + + // Start orchestrator server and its dependencies + let mock_detector_server = HttpMockServer::new(detector_name, detection_mocks)?; + let generation_server = MockNlpServiceServer::new(generation_mocks)?; + let orchestrator_server = TestOrchestratorServer::run( + "tests/test.config.yaml", + 8080, + 8081, + Some(generation_server), + None, + Some(vec![mock_detector_server]), + None, + ) + .await?; + + // Example orchestrator request with streaming response + let response = orchestrator_server + .post(STREAMING_CLASSIFICATION_WITH_GEN_ENDPOINT) + .json(&GuardrailsHttpRequest { + model_id: model_id.to_string(), + inputs: "Hi there! 
How are you?".to_string(), + guardrail_config: Some(GuardrailsConfig { + input: Some(GuardrailsConfigInput { + models: HashMap::from([(detector_name.into(), DetectorParams::new())]), + masks: None, + }), + output: None, + }), + text_gen_parameters: None, + }) + .send() + .await?; + + // Test custom SseStream wrapper + let sse_stream: SseStream = + SseStream::new(response.bytes_stream()); + let messages = sse_stream + .collect::>() + .await + .into_iter() + .collect::, anyhow::Error>>()?; + println!("{messages:?}"); + + // assertions + assert!(messages.len() == 3); + assert!(messages[0].generated_text == Some("I".into())); + assert!(messages[1].generated_text == Some(" am".into())); + assert!(messages[2].generated_text == Some(" great!".into())); + + Ok(()) +} From b3b215fae31ff294853b9b75a14307826f4eae7b Mon Sep 17 00:00:00 2001 From: Mateus Devino Date: Thu, 13 Feb 2025 15:48:40 -0300 Subject: [PATCH 034/117] Refactoring Signed-off-by: Mateus Devino --- tests/chunker.rs | 12 +++--- tests/common/chunker.rs | 2 + tests/common/orchestrator.rs | 2 +- tests/detection_content.rs | 73 ++++++++++++++++++------------------ tests/generation_nlp.rs | 12 +++--- tests/streaming.rs | 51 +++++++++++++------------ 6 files changed, 78 insertions(+), 74 deletions(-) diff --git a/tests/chunker.rs b/tests/chunker.rs index c0bc4a5d..4c791efd 100644 --- a/tests/chunker.rs +++ b/tests/chunker.rs @@ -42,17 +42,17 @@ async fn test_isolated_chunker_unary_call() -> Result<(), anyhow::Error> { Token { start: 0, end: 9, - text: "Hi there!".to_string(), + text: "Hi there!".into(), }, Token { start: 0, end: 9, - text: "how are you?".to_string(), + text: "how are you?".into(), }, Token { start: 0, end: 9, - text: "I am great!".to_string(), + text: "I am great!".into(), }, ], token_count: 0, @@ -63,7 +63,7 @@ async fn test_isolated_chunker_unary_call() -> Result<(), anyhow::Error> { MockPath::new(Method::POST, CHUNKER_UNARY_ENDPOINT), Mock::new( MockRequest::pb(ChunkerTokenizationTaskRequest { - text: "Hi there! how are you? I am great!".to_string(), + text: "Hi there! how are you? I am great!".into(), }) .with_headers(chunker_headers), MockResponse::pb(expected_response.clone()), @@ -74,7 +74,7 @@ async fn test_isolated_chunker_unary_call() -> Result<(), anyhow::Error> { let _ = mock_chunker_server.start().await; let client = ChunkerClient::new(&ServiceConfig { - hostname: "localhost".to_string(), + hostname: "localhost".into(), port: Some(mock_chunker_server.addr().port()), request_timeout: None, tls: None, @@ -85,7 +85,7 @@ async fn test_isolated_chunker_unary_call() -> Result<(), anyhow::Error> { .tokenization_task_predict( chunker_id, ChunkerTokenizationTaskRequest { - text: "Hi there! how are you? I am great!".to_string(), + text: "Hi there! how are you? 
I am great!".into(), }, ) .await; diff --git a/tests/common/chunker.rs b/tests/common/chunker.rs index 4b0c9e8b..778f7309 100644 --- a/tests/common/chunker.rs +++ b/tests/common/chunker.rs @@ -22,5 +22,7 @@ generate_grpc_server!( MockChunkersServiceServer ); +pub const CHUNKER_NAME_SENTENCE: &str = "sentence_chunker"; + pub const CHUNKER_UNARY_ENDPOINT: &str = "/caikit.runtime.Chunkers.ChunkersService/ChunkerTokenizationTaskPredict"; diff --git a/tests/common/orchestrator.rs b/tests/common/orchestrator.rs index bbefd02f..74cd8080 100644 --- a/tests/common/orchestrator.rs +++ b/tests/common/orchestrator.rs @@ -37,7 +37,7 @@ use url::Url; use super::{chunker::MockChunkersServiceServer, generation::MockNlpServiceServer}; /// Default orchestrator configuration file for integration tests. -pub const CONFIG_FILE_PATH: &str = "tests/test.config.yaml"; +pub const ORCHESTRATOR_CONFIG_FILE_PATH: &str = "tests/test.config.yaml"; pub const CONTENT_DETECTION_ORCHESTRATOR_ENDPOINT: &str = "/api/v2/text/detection/content"; diff --git a/tests/detection_content.rs b/tests/detection_content.rs index f706bd9a..03f79606 100644 --- a/tests/detection_content.rs +++ b/tests/detection_content.rs @@ -18,14 +18,14 @@ use std::collections::HashMap; use common::{ - chunker::{MockChunkersServiceServer, CHUNKER_UNARY_ENDPOINT}, + chunker::{MockChunkersServiceServer, CHUNKER_NAME_SENTENCE, CHUNKER_UNARY_ENDPOINT}, detectors::{ DETECTOR_NAME_ANGLE_BRACKETS_SENTENCE, DETECTOR_NAME_ANGLE_BRACKETS_WHOLE_DOC, TEXT_CONTENTS_DETECTOR_ENDPOINT, }, orchestrator::{ - ensure_global_rustls_state, TestOrchestratorServer, CONFIG_FILE_PATH, - CONTENT_DETECTION_ORCHESTRATOR_ENDPOINT, + ensure_global_rustls_state, TestOrchestratorServer, + CONTENT_DETECTION_ORCHESTRATOR_ENDPOINT, ORCHESTRATOR_CONFIG_FILE_PATH, }, }; use fms_guardrails_orchestr8::{ @@ -40,14 +40,13 @@ use fms_guardrails_orchestr8::{ }, }; use hyper::StatusCode; -use mocktail::prelude::*; +use mocktail::{prelude::*, utils::find_available_port}; use tracing::debug; use tracing_test::traced_test; pub mod common; // Constants -const CHUNKER_NAME_SENTENCE: &str = "sentence_chunker"; /// Asserts a scenario with a single detection works as expected (assumes a detector configured with whole_doc_chunker). 
/// @@ -64,16 +63,16 @@ async fn test_single_detection_whole_doc() -> Result<(), anyhow::Error> { MockPath::new(Method::POST, TEXT_CONTENTS_DETECTOR_ENDPOINT), Mock::new( MockRequest::json(ContentAnalysisRequest { - contents: vec!["This sentence has .".to_string()], + contents: vec!["This sentence has .".into()], detector_params: DetectorParams::new(), }), MockResponse::json(vec![vec![ContentAnalysisResponse { start: 18, end: 35, - text: "a detection here".to_string(), - detection: "has_angle_brackets".to_string(), - detection_type: "angle_brackets".to_string(), - detector_id: Some(detector_name.to_string()), + text: "a detection here".into(), + detection: "has_angle_brackets".into(), + detection_type: "angle_brackets".into(), + detector_id: Some(detector_name.into()), score: 1.0, evidence: None, }]]), @@ -83,9 +82,9 @@ async fn test_single_detection_whole_doc() -> Result<(), anyhow::Error> { // Start orchestrator server and its dependencies let mock_detector_server = HttpMockServer::new(detector_name, mocks)?; let orchestrator_server = TestOrchestratorServer::run( - CONFIG_FILE_PATH, - 8080, - 8081, + ORCHESTRATOR_CONFIG_FILE_PATH, + find_available_port().unwrap(), + find_available_port().unwrap(), None, None, Some(vec![mock_detector_server]), @@ -97,8 +96,8 @@ async fn test_single_detection_whole_doc() -> Result<(), anyhow::Error> { let response = orchestrator_server .post(CONTENT_DETECTION_ORCHESTRATOR_ENDPOINT) .json(&TextContentDetectionHttpRequest { - content: "This sentence has .".to_string(), - detectors: HashMap::from([(detector_name.to_string(), DetectorParams::new())]), + content: "This sentence has .".into(), + detectors: HashMap::from([(detector_name.into(), DetectorParams::new())]), }) .send() .await?; @@ -113,10 +112,10 @@ async fn test_single_detection_whole_doc() -> Result<(), anyhow::Error> { detections: vec![ContentAnalysisResponse { start: 18, end: 35, - text: "a detection here".to_string(), - detection: "has_angle_brackets".to_string(), - detection_type: "angle_brackets".to_string(), - detector_id: Some(detector_name.to_string()), + text: "a detection here".into(), + detection: "has_angle_brackets".into(), + detection_type: "angle_brackets".into(), + detector_id: Some(detector_name.into()), score: 1.0, evidence: None, }], @@ -144,7 +143,7 @@ async fn test_single_detection_sentence_chunker() -> Result<(), anyhow::Error> { MockPath::new(Method::POST, CHUNKER_UNARY_ENDPOINT), Mock::new( MockRequest::pb(ChunkerTokenizationTaskRequest { - text: "This sentence does not have a detection. But .".to_string(), + text: "This sentence does not have a detection. 
But .".into(), }) .with_headers(chunker_headers), MockResponse::pb(TokenizationResults { @@ -152,12 +151,12 @@ async fn test_single_detection_sentence_chunker() -> Result<(), anyhow::Error> { Token { start: 0, end: 40, - text: "This sentence does not have a detection.".to_string(), + text: "This sentence does not have a detection.".into(), }, Token { start: 41, end: 61, - text: "But .".to_string(), + text: "But .".into(), }, ], token_count: 0, @@ -173,8 +172,8 @@ async fn test_single_detection_sentence_chunker() -> Result<(), anyhow::Error> { Mock::new( MockRequest::json(ContentAnalysisRequest { contents: vec![ - "This sentence does not have a detection.".to_string(), - "But .".to_string(), + "This sentence does not have a detection.".into(), + "But .".into(), ], detector_params: DetectorParams::new(), }), @@ -183,10 +182,10 @@ async fn test_single_detection_sentence_chunker() -> Result<(), anyhow::Error> { vec![ContentAnalysisResponse { start: 4, end: 18, - text: "this one does".to_string(), - detection: "has_angle_brackets".to_string(), - detection_type: "angle_brackets".to_string(), - detector_id: Some(detector_name.to_string()), + text: "this one does".into(), + detection: "has_angle_brackets".into(), + detection_type: "angle_brackets".into(), + detector_id: Some(detector_name.into()), score: 1.0, evidence: None, }], @@ -198,9 +197,9 @@ async fn test_single_detection_sentence_chunker() -> Result<(), anyhow::Error> { let mock_chunker_server = MockChunkersServiceServer::new(chunker_mocks)?; let mock_detector_server = HttpMockServer::new(detector_name, mocks)?; let orchestrator_server = TestOrchestratorServer::run( - CONFIG_FILE_PATH, - 8080, - 8081, + ORCHESTRATOR_CONFIG_FILE_PATH, + find_available_port().unwrap(), + find_available_port().unwrap(), None, None, Some(vec![mock_detector_server]), @@ -212,8 +211,8 @@ async fn test_single_detection_sentence_chunker() -> Result<(), anyhow::Error> { let response = orchestrator_server .post(CONTENT_DETECTION_ORCHESTRATOR_ENDPOINT) .json(&TextContentDetectionHttpRequest { - content: "This sentence does not have a detection. But .".to_string(), - detectors: HashMap::from([(detector_name.to_string(), DetectorParams::new())]), + content: "This sentence does not have a detection. 
But .".into(), + detectors: HashMap::from([(detector_name.into(), DetectorParams::new())]), }) .send() .await?; @@ -228,10 +227,10 @@ async fn test_single_detection_sentence_chunker() -> Result<(), anyhow::Error> { detections: vec![ContentAnalysisResponse { start: 45, end: 59, - text: "this one does".to_string(), - detection: "has_angle_brackets".to_string(), - detection_type: "angle_brackets".to_string(), - detector_id: Some(detector_name.to_string()), + text: "this one does".into(), + detection: "has_angle_brackets".into(), + detection_type: "angle_brackets".into(), + detector_id: Some(detector_name.into()), score: 1.0, evidence: None, }], diff --git a/tests/generation_nlp.rs b/tests/generation_nlp.rs index 2030f765..b334a9b4 100644 --- a/tests/generation_nlp.rs +++ b/tests/generation_nlp.rs @@ -40,15 +40,15 @@ async fn test_nlp_streaming_call() -> Result<(), anyhow::Error> { let expected_response = vec![ GeneratedTextStreamResult { - generated_text: "I".to_string(), + generated_text: "I".into(), ..Default::default() }, GeneratedTextStreamResult { - generated_text: " am".to_string(), + generated_text: " am".into(), ..Default::default() }, GeneratedTextStreamResult { - generated_text: " great!".to_string(), + generated_text: " great!".into(), ..Default::default() }, ]; @@ -58,7 +58,7 @@ async fn test_nlp_streaming_call() -> Result<(), anyhow::Error> { MockPath::new(Method::POST, GENERATION_NLP_STREAMING_ENDPOINT), Mock::new( MockRequest::pb(ServerStreamingTextGenerationTaskRequest { - text: "Hi there! How are you?".to_string(), + text: "Hi there! How are you?".into(), ..Default::default() }) .with_headers(headers.clone()), @@ -70,7 +70,7 @@ async fn test_nlp_streaming_call() -> Result<(), anyhow::Error> { generation_nlp_server.start().await?; let client = NlpClient::new(&ServiceConfig { - hostname: "localhost".to_string(), + hostname: "localhost".into(), port: Some(generation_nlp_server.addr().port()), request_timeout: None, tls: None, @@ -81,7 +81,7 @@ async fn test_nlp_streaming_call() -> Result<(), anyhow::Error> { .server_streaming_text_generation_task_predict( model_id, ServerStreamingTextGenerationTaskRequest { - text: "Hi there! How are you?".to_string(), + text: "Hi there! 
How are you?".into(), ..Default::default() }, headers, diff --git a/tests/streaming.rs b/tests/streaming.rs index c6f427b0..35656133 100644 --- a/tests/streaming.rs +++ b/tests/streaming.rs @@ -20,12 +20,15 @@ use std::collections::HashMap; use common::{ detectors::{DETECTOR_NAME_ANGLE_BRACKETS_WHOLE_DOC, TEXT_CONTENTS_DETECTOR_ENDPOINT}, generation::{MockNlpServiceServer, GENERATION_NLP_STREAMING_ENDPOINT}, - orchestrator::{ensure_global_rustls_state, SseStream, TestOrchestratorServer}, + orchestrator::{ + ensure_global_rustls_state, SseStream, TestOrchestratorServer, + ORCHESTRATOR_CONFIG_FILE_PATH, + }, }; use fms_guardrails_orchestr8::{ clients::{ detector::{ContentAnalysisRequest, ContentAnalysisResponse}, - nlp::MODEL_ID_HEADER_NAME, + nlp::MODEL_ID_HEADER_NAME as NLP_MODEL_ID_HEADER_NAME, }, models::{ ClassifiedGeneratedTextStreamResult, DetectorParams, GuardrailsConfig, @@ -37,7 +40,7 @@ use fms_guardrails_orchestr8::{ }, }; use futures::StreamExt; -use mocktail::prelude::*; +use mocktail::{prelude::*, utils::find_available_port}; use tracing_test::traced_test; pub mod common; @@ -53,19 +56,19 @@ async fn test_no_detectors() -> Result<(), anyhow::Error> { // Add generation mock let model_id = "my-super-model-8B"; let mut headers = HeaderMap::new(); - headers.insert(MODEL_ID_HEADER_NAME, model_id.parse().unwrap()); + headers.insert(NLP_MODEL_ID_HEADER_NAME, model_id.parse().unwrap()); let expected_response = vec![ GeneratedTextStreamResult { - generated_text: "I".to_string(), + generated_text: "I".into(), ..Default::default() }, GeneratedTextStreamResult { - generated_text: " am".to_string(), + generated_text: " am".into(), ..Default::default() }, GeneratedTextStreamResult { - generated_text: " great!".to_string(), + generated_text: " great!".into(), ..Default::default() }, ]; @@ -75,7 +78,7 @@ async fn test_no_detectors() -> Result<(), anyhow::Error> { MockPath::new(Method::POST, GENERATION_NLP_STREAMING_ENDPOINT), Mock::new( MockRequest::pb(ServerStreamingTextGenerationTaskRequest { - text: "Hi there! How are you?".to_string(), + text: "Hi there! How are you?".into(), ..Default::default() }) .with_headers(headers.clone()), @@ -88,9 +91,9 @@ async fn test_no_detectors() -> Result<(), anyhow::Error> { // Run test orchestrator server let orchestrator_server = TestOrchestratorServer::run( - "tests/test.config.yaml", - 8080, - 8081, + ORCHESTRATOR_CONFIG_FILE_PATH, + find_available_port().unwrap(), + find_available_port().unwrap(), Some(generation_server), None, None, @@ -102,8 +105,8 @@ async fn test_no_detectors() -> Result<(), anyhow::Error> { let response = orchestrator_server .post(STREAMING_CLASSIFICATION_WITH_GEN_ENDPOINT) .json(&GuardrailsHttpRequest { - model_id: model_id.to_string(), - inputs: "Hi there! How are you?".to_string(), + model_id: model_id.into(), + inputs: "Hi there! How are you?".into(), guardrail_config: None, text_gen_parameters: None, }) @@ -160,7 +163,7 @@ async fn test_input_detector_whole_doc_no_detections() -> Result<(), anyhow::Err MockPath::new(Method::POST, TEXT_CONTENTS_DETECTOR_ENDPOINT), Mock::new( MockRequest::json(ContentAnalysisRequest { - contents: vec!["Hi there! How are you?".to_string()], + contents: vec!["Hi there! 
How are you?".into()], detector_params: DetectorParams::new(), }), MockResponse::json([Vec::::new()]), @@ -170,19 +173,19 @@ async fn test_input_detector_whole_doc_no_detections() -> Result<(), anyhow::Err // Add generation mock let model_id = "my-super-model-8B"; let mut headers = HeaderMap::new(); - headers.insert(MODEL_ID_HEADER_NAME, model_id.parse().unwrap()); + headers.insert(NLP_MODEL_ID_HEADER_NAME, model_id.parse().unwrap()); let expected_response = vec![ GeneratedTextStreamResult { - generated_text: "I".to_string(), + generated_text: "I".into(), ..Default::default() }, GeneratedTextStreamResult { - generated_text: " am".to_string(), + generated_text: " am".into(), ..Default::default() }, GeneratedTextStreamResult { - generated_text: " great!".to_string(), + generated_text: " great!".into(), ..Default::default() }, ]; @@ -192,7 +195,7 @@ async fn test_input_detector_whole_doc_no_detections() -> Result<(), anyhow::Err MockPath::new(Method::POST, GENERATION_NLP_STREAMING_ENDPOINT), Mock::new( MockRequest::pb(ServerStreamingTextGenerationTaskRequest { - text: "Hi there! How are you?".to_string(), + text: "Hi there! How are you?".into(), ..Default::default() }) .with_headers(headers.clone()), @@ -204,9 +207,9 @@ async fn test_input_detector_whole_doc_no_detections() -> Result<(), anyhow::Err let mock_detector_server = HttpMockServer::new(detector_name, detection_mocks)?; let generation_server = MockNlpServiceServer::new(generation_mocks)?; let orchestrator_server = TestOrchestratorServer::run( - "tests/test.config.yaml", - 8080, - 8081, + ORCHESTRATOR_CONFIG_FILE_PATH, + find_available_port().unwrap(), + find_available_port().unwrap(), Some(generation_server), None, Some(vec![mock_detector_server]), @@ -218,8 +221,8 @@ async fn test_input_detector_whole_doc_no_detections() -> Result<(), anyhow::Err let response = orchestrator_server .post(STREAMING_CLASSIFICATION_WITH_GEN_ENDPOINT) .json(&GuardrailsHttpRequest { - model_id: model_id.to_string(), - inputs: "Hi there! How are you?".to_string(), + model_id: model_id.into(), + inputs: "Hi there! 
How are you?".into(), guardrail_config: Some(GuardrailsConfig { input: Some(GuardrailsConfigInput { models: HashMap::from([(detector_name.into(), DetectorParams::new())]), From 60284c32d55d1eb6675946480e0a8ccc578a91eb Mon Sep 17 00:00:00 2001 From: Mateus Devino Date: Tue, 18 Feb 2025 11:41:32 -0300 Subject: [PATCH 035/117] Replace tracing-test crate with test-log Signed-off-by: Mateus Devino --- Cargo.lock | 66 +++++++++++++++++++++++++------------- Cargo.toml | 2 +- tests/canary_test.rs | 5 ++- tests/chunker.rs | 5 ++- tests/detection_content.rs | 8 ++--- tests/generation_nlp.rs | 5 ++- tests/streaming.rs | 13 ++++---- 7 files changed, 60 insertions(+), 44 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index a5c3944c..8aa27eab 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -628,6 +628,27 @@ dependencies = [ "syn", ] +[[package]] +name = "env_filter" +version = "0.1.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "186e05a59d4c50738528153b83b0b0194d3a29507dfec16eccd4b342903397d0" +dependencies = [ + "log", +] + +[[package]] +name = "env_logger" +version = "0.11.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "dcaee3d8e3cfc3fd92428d477bc97fc29ec8716d180c0d74c643bb26166660e0" +dependencies = [ + "anstream", + "anstyle", + "env_filter", + "log", +] + [[package]] name = "equivalent" version = "1.0.2" @@ -728,6 +749,7 @@ dependencies = [ "serde", "serde_json", "serde_yml", + "test-log", "thiserror 2.0.11", "tokio", "tokio-rustls", @@ -739,7 +761,6 @@ dependencies = [ "tracing", "tracing-opentelemetry", "tracing-subscriber", - "tracing-test", "url", "uuid", ] @@ -2592,6 +2613,28 @@ dependencies = [ "windows-sys 0.59.0", ] +[[package]] +name = "test-log" +version = "0.2.17" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e7f46083d221181166e5b6f6b1e5f1d499f3a76888826e6cb1d057554157cd0f" +dependencies = [ + "env_logger", + "test-log-macros", + "tracing-subscriber", +] + +[[package]] +name = "test-log-macros" +version = "0.2.17" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "888d0c3c6db53c0fdab160d2ed5e12ba745383d3e85813f2ea0f2b1475ab553f" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + [[package]] name = "thiserror" version = "1.0.69" @@ -2977,27 +3020,6 @@ dependencies = [ "tracing-serde", ] -[[package]] -name = "tracing-test" -version = "0.2.5" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "557b891436fe0d5e0e363427fc7f217abf9ccd510d5136549847bdcbcd011d68" -dependencies = [ - "tracing-core", - "tracing-subscriber", - "tracing-test-macro", -] - -[[package]] -name = "tracing-test-macro" -version = "0.2.5" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "04659ddb06c87d233c566112c1c9c5b9e98256d9af50ec3bc9c8327f873a7568" -dependencies = [ - "quote", - "syn", -] - [[package]] name = "try-lock" version = "0.2.5" diff --git a/Cargo.toml b/Cargo.toml index 6c5737fe..71ce9223 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -91,7 +91,7 @@ tonic-build = "0.12.3" axum-test = "17.1.0" faux = "0.1.12" mocktail = { git = "https://github.com/IBM/mocktail", version = "0.1.0-alpha" } -tracing-test = "0.2.5" +test-log = "0.2.17" [profile.release] debug = false diff --git a/tests/canary_test.rs b/tests/canary_test.rs index a2fc0d28..f7412c45 100644 --- a/tests/canary_test.rs +++ b/tests/canary_test.rs @@ -20,6 +20,7 @@ // For more: https://github.com/rust-lang/rust/issues/46379 use std::sync::Arc; +use test_log::test; use 
axum_test::TestServer; use common::orchestrator::ensure_global_rustls_state; @@ -32,7 +33,6 @@ use hyper::StatusCode; use serde_json::Value; use tokio::sync::OnceCell; use tracing::debug; -use tracing_test::traced_test; pub mod common; @@ -53,8 +53,7 @@ async fn shared_state() -> Arc { /// superficially testing the client health endpoints on the orchestrator is accessible /// and when the orchestrator is running (healthy) all the health endpoints return 200 OK. /// This will happen even if the client services or their health endpoints are not found. -#[traced_test] -#[tokio::test] +#[test(tokio::test)] async fn test_health() { ensure_global_rustls_state(); let shared_state = ONCE.get_or_init(shared_state).await.clone(); diff --git a/tests/chunker.rs b/tests/chunker.rs index 4c791efd..7fb2e812 100644 --- a/tests/chunker.rs +++ b/tests/chunker.rs @@ -25,12 +25,11 @@ use fms_guardrails_orchestr8::{ }, }; use mocktail::prelude::*; -use tracing_test::traced_test; +use test_log::test; pub mod common; -#[traced_test] -#[tokio::test] +#[test(tokio::test)] async fn test_isolated_chunker_unary_call() -> Result<(), anyhow::Error> { // Add detector mock let chunker_id = "sentence_chunker"; diff --git a/tests/detection_content.rs b/tests/detection_content.rs index 03f79606..330ca427 100644 --- a/tests/detection_content.rs +++ b/tests/detection_content.rs @@ -16,6 +16,7 @@ */ use std::collections::HashMap; +use test_log::test; use common::{ chunker::{MockChunkersServiceServer, CHUNKER_NAME_SENTENCE, CHUNKER_UNARY_ENDPOINT}, @@ -42,7 +43,6 @@ use fms_guardrails_orchestr8::{ use hyper::StatusCode; use mocktail::{prelude::*, utils::find_available_port}; use tracing::debug; -use tracing_test::traced_test; pub mod common; @@ -51,8 +51,7 @@ pub mod common; /// Asserts a scenario with a single detection works as expected (assumes a detector configured with whole_doc_chunker). /// /// This test mocks a detector that detects text between . -#[traced_test] -#[tokio::test] +#[test(tokio::test)] async fn test_single_detection_whole_doc() -> Result<(), anyhow::Error> { ensure_global_rustls_state(); let detector_name = DETECTOR_NAME_ANGLE_BRACKETS_WHOLE_DOC; @@ -128,8 +127,7 @@ async fn test_single_detection_whole_doc() -> Result<(), anyhow::Error> { /// Asserts a scenario with a single detection works as expected (with sentence chunker). /// /// This test mocks a detector that detects text between . 
-#[traced_test] -#[tokio::test] +#[test(tokio::test)] async fn test_single_detection_sentence_chunker() -> Result<(), anyhow::Error> { ensure_global_rustls_state(); diff --git a/tests/generation_nlp.rs b/tests/generation_nlp.rs index b334a9b4..493f651b 100644 --- a/tests/generation_nlp.rs +++ b/tests/generation_nlp.rs @@ -26,12 +26,11 @@ use fms_guardrails_orchestr8::{ }; use futures::StreamExt; use mocktail::prelude::*; -use tracing_test::traced_test; +use test_log::test; pub mod common; -#[traced_test] -#[tokio::test] +#[test(tokio::test)] async fn test_nlp_streaming_call() -> Result<(), anyhow::Error> { // Add detector mock let model_id = "my-super-model-8B"; diff --git a/tests/streaming.rs b/tests/streaming.rs index 35656133..7f8dd0f8 100644 --- a/tests/streaming.rs +++ b/tests/streaming.rs @@ -16,6 +16,7 @@ */ use std::collections::HashMap; +use test_log::test; use common::{ detectors::{DETECTOR_NAME_ANGLE_BRACKETS_WHOLE_DOC, TEXT_CONTENTS_DETECTOR_ENDPOINT}, @@ -41,15 +42,14 @@ use fms_guardrails_orchestr8::{ }; use futures::StreamExt; use mocktail::{prelude::*, utils::find_available_port}; -use tracing_test::traced_test; +use tracing::debug; pub mod common; const STREAMING_CLASSIFICATION_WITH_GEN_ENDPOINT: &str = "/api/v1/task/server-streaming-classification-with-text-generation"; -#[traced_test] -#[tokio::test] +#[test(tokio::test)] async fn test_no_detectors() -> Result<(), anyhow::Error> { ensure_global_rustls_state(); @@ -140,7 +140,7 @@ async fn test_no_detectors() -> Result<(), anyhow::Error> { .await .into_iter() .collect::, anyhow::Error>>()?; - println!("{messages:?}"); + debug!(?messages); // assertions assert!(messages.len() == 3); @@ -151,8 +151,7 @@ async fn test_no_detectors() -> Result<(), anyhow::Error> { Ok(()) } -#[traced_test] -#[tokio::test] +#[test(tokio::test)] async fn test_input_detector_whole_doc_no_detections() -> Result<(), anyhow::Error> { ensure_global_rustls_state(); let detector_name = DETECTOR_NAME_ANGLE_BRACKETS_WHOLE_DOC; @@ -243,7 +242,7 @@ async fn test_input_detector_whole_doc_no_detections() -> Result<(), anyhow::Err .await .into_iter() .collect::, anyhow::Error>>()?; - println!("{messages:?}"); + println!("{messages:#?}"); // assertions assert!(messages.len() == 3); From 14e32d4facdb1bad45d79a5e79e870606cbe142c Mon Sep 17 00:00:00 2001 From: Mateus Devino Date: Tue, 18 Feb 2025 11:46:47 -0300 Subject: [PATCH 036/117] test case: test_input_detector_sentence_chunker_no_detections Signed-off-by: Mateus Devino --- tests/streaming.rs | 152 ++++++++++++++++++++++++++++++++++++++++++++- 1 file changed, 149 insertions(+), 3 deletions(-) diff --git a/tests/streaming.rs b/tests/streaming.rs index 7f8dd0f8..14eb9b04 100644 --- a/tests/streaming.rs +++ b/tests/streaming.rs @@ -19,7 +19,11 @@ use std::collections::HashMap; use test_log::test; use common::{ - detectors::{DETECTOR_NAME_ANGLE_BRACKETS_WHOLE_DOC, TEXT_CONTENTS_DETECTOR_ENDPOINT}, + chunker::{MockChunkersServiceServer, CHUNKER_NAME_SENTENCE, CHUNKER_UNARY_ENDPOINT}, + detectors::{ + DETECTOR_NAME_ANGLE_BRACKETS_SENTENCE, DETECTOR_NAME_ANGLE_BRACKETS_WHOLE_DOC, + TEXT_CONTENTS_DETECTOR_ENDPOINT, + }, generation::{MockNlpServiceServer, GENERATION_NLP_STREAMING_ENDPOINT}, orchestrator::{ ensure_global_rustls_state, SseStream, TestOrchestratorServer, @@ -28,6 +32,7 @@ use common::{ }; use fms_guardrails_orchestr8::{ clients::{ + chunker::MODEL_ID_HEADER_NAME as CHUNKER_MODEL_ID_HEADER_NAME, detector::{ContentAnalysisRequest, ContentAnalysisResponse}, nlp::MODEL_ID_HEADER_NAME as 
NLP_MODEL_ID_HEADER_NAME, }, @@ -36,8 +41,10 @@ use fms_guardrails_orchestr8::{ GuardrailsConfigInput, GuardrailsHttpRequest, }, pb::{ - caikit::runtime::nlp::ServerStreamingTextGenerationTaskRequest, - caikit_data_model::nlp::GeneratedTextStreamResult, + caikit::runtime::{ + chunkers::ChunkerTokenizationTaskRequest, nlp::ServerStreamingTextGenerationTaskRequest, + }, + caikit_data_model::nlp::{GeneratedTextStreamResult, Token, TokenizationResults}, }, }; use futures::StreamExt; @@ -252,3 +259,142 @@ async fn test_input_detector_whole_doc_no_detections() -> Result<(), anyhow::Err Ok(()) } + +#[test(tokio::test)] +async fn test_input_detector_sentence_chunker_no_detections() -> Result<(), anyhow::Error> { + ensure_global_rustls_state(); + let detector_name = DETECTOR_NAME_ANGLE_BRACKETS_SENTENCE; + + // Add input chunker mock + let chunker_id = CHUNKER_NAME_SENTENCE; + let mut chunker_headers = HeaderMap::new(); + chunker_headers.insert(CHUNKER_MODEL_ID_HEADER_NAME, chunker_id.parse()?); + + let mut chunker_mocks = MockSet::new(); + chunker_mocks.insert( + MockPath::new(Method::POST, CHUNKER_UNARY_ENDPOINT), + Mock::new( + MockRequest::pb(ChunkerTokenizationTaskRequest { + text: "Hi there! How are you?".into(), + }) + .with_headers(chunker_headers), + MockResponse::pb(TokenizationResults { + results: vec![ + Token { + start: 0, + end: 9, + text: "Hi there!".into(), + }, + Token { + start: 10, + end: 22, + text: " How are you?".into(), + }, + ], + token_count: 0, + }), + ), + ); + + // Add input detection mock + let mut detection_mocks = MockSet::new(); + detection_mocks.insert( + MockPath::new(Method::POST, TEXT_CONTENTS_DETECTOR_ENDPOINT), + Mock::new( + MockRequest::json(ContentAnalysisRequest { + contents: vec!["Hi there!".into(), " How are you?".into()], + detector_params: DetectorParams::new(), + }), + MockResponse::json([ + Vec::::new(), + Vec::::new(), + ]), + ), + ); + + // Add generation mock + let model_id = "my-super-model-8B"; + let mut headers = HeaderMap::new(); + headers.insert(NLP_MODEL_ID_HEADER_NAME, model_id.parse().unwrap()); + + let expected_response = vec![ + GeneratedTextStreamResult { + generated_text: "I".into(), + ..Default::default() + }, + GeneratedTextStreamResult { + generated_text: " am".into(), + ..Default::default() + }, + GeneratedTextStreamResult { + generated_text: " great!".into(), + ..Default::default() + }, + ]; + + let mut generation_mocks = MockSet::new(); + generation_mocks.insert( + MockPath::new(Method::POST, GENERATION_NLP_STREAMING_ENDPOINT), + Mock::new( + MockRequest::pb(ServerStreamingTextGenerationTaskRequest { + text: "Hi there! 
How are you?".into(), + ..Default::default() + }) + .with_headers(headers.clone()), + MockResponse::pb_stream(expected_response.clone()), + ), + ); + + // Start orchestrator server and its dependencies + let mock_chunker_server = MockChunkersServiceServer::new(chunker_mocks)?; + let mock_detector_server = HttpMockServer::new(detector_name, detection_mocks)?; + let generation_server = MockNlpServiceServer::new(generation_mocks)?; + let orchestrator_server = TestOrchestratorServer::run( + ORCHESTRATOR_CONFIG_FILE_PATH, + find_available_port().unwrap(), + find_available_port().unwrap(), + Some(generation_server), + None, + Some(vec![mock_detector_server]), + Some(vec![(chunker_id.into(), mock_chunker_server)]), + ) + .await?; + + // Example orchestrator request with streaming response + let response = orchestrator_server + .post(STREAMING_CLASSIFICATION_WITH_GEN_ENDPOINT) + .json(&GuardrailsHttpRequest { + model_id: model_id.into(), + inputs: "Hi there! How are you?".into(), + guardrail_config: Some(GuardrailsConfig { + input: Some(GuardrailsConfigInput { + models: HashMap::from([(detector_name.into(), DetectorParams::new())]), + masks: None, + }), + output: None, + }), + text_gen_parameters: None, + }) + .send() + .await?; + + println!("{response:#?}"); + + // Test custom SseStream wrapper + let sse_stream: SseStream = + SseStream::new(response.bytes_stream()); + let messages = sse_stream + .collect::>() + .await + .into_iter() + .collect::, anyhow::Error>>()?; + println!("{messages:#?}"); + + // assertions + assert!(messages.len() == 3); + assert!(messages[0].generated_text == Some("I".into())); + assert!(messages[1].generated_text == Some(" am".into())); + assert!(messages[2].generated_text == Some(" great!".into())); + + Ok(()) +} From a702fbf707e692b76c40bd9166bd28b4bc66972c Mon Sep 17 00:00:00 2001 From: Mateus Devino Date: Tue, 18 Feb 2025 14:52:20 -0300 Subject: [PATCH 037/117] test case: test_input_detector_returns_404 Signed-off-by: Mateus Devino --- tests/common/errors.rs | 29 ++++++++++++++ tests/common/mod.rs | 1 + tests/streaming.rs | 86 ++++++++++++++++++++++++++++++++++++++++-- 3 files changed, 113 insertions(+), 3 deletions(-) create mode 100644 tests/common/errors.rs diff --git a/tests/common/errors.rs b/tests/common/errors.rs new file mode 100644 index 00000000..97f0da9b --- /dev/null +++ b/tests/common/errors.rs @@ -0,0 +1,29 @@ +/* + Copyright FMS Guardrails Orchestrator Authors + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. 
+ +*/ +use serde::{Deserialize, Serialize}; + +#[derive(Serialize, Deserialize, Debug)] +pub struct DetectorError { + pub code: u16, + pub message: String, +} + +#[derive(Serialize, Deserialize, Debug)] +pub struct OrchestratorError { + pub code: u16, + pub details: String, +} diff --git a/tests/common/mod.rs b/tests/common/mod.rs index b50f826f..de269335 100644 --- a/tests/common/mod.rs +++ b/tests/common/mod.rs @@ -16,5 +16,6 @@ */ pub mod chunker; pub mod detectors; +pub mod errors; pub mod generation; pub mod orchestrator; diff --git a/tests/streaming.rs b/tests/streaming.rs index 14eb9b04..3592a8b3 100644 --- a/tests/streaming.rs +++ b/tests/streaming.rs @@ -24,6 +24,7 @@ use common::{ DETECTOR_NAME_ANGLE_BRACKETS_SENTENCE, DETECTOR_NAME_ANGLE_BRACKETS_WHOLE_DOC, TEXT_CONTENTS_DETECTOR_ENDPOINT, }, + errors::{DetectorError, OrchestratorError}, generation::{MockNlpServiceServer, GENERATION_NLP_STREAMING_ENDPOINT}, orchestrator::{ ensure_global_rustls_state, SseStream, TestOrchestratorServer, @@ -249,7 +250,7 @@ async fn test_input_detector_whole_doc_no_detections() -> Result<(), anyhow::Err .await .into_iter() .collect::, anyhow::Error>>()?; - println!("{messages:#?}"); + debug!("{messages:#?}"); // assertions assert!(messages.len() == 3); @@ -378,7 +379,7 @@ async fn test_input_detector_sentence_chunker_no_detections() -> Result<(), anyh .send() .await?; - println!("{response:#?}"); + debug!("{response:#?}"); // Test custom SseStream wrapper let sse_stream: SseStream = @@ -388,7 +389,7 @@ async fn test_input_detector_sentence_chunker_no_detections() -> Result<(), anyh .await .into_iter() .collect::, anyhow::Error>>()?; - println!("{messages:#?}"); + debug!("{messages:#?}"); // assertions assert!(messages.len() == 3); @@ -398,3 +399,82 @@ async fn test_input_detector_sentence_chunker_no_detections() -> Result<(), anyh Ok(()) } + +#[test(tokio::test)] +async fn test_input_detector_returns_404() -> Result<(), anyhow::Error> { + ensure_global_rustls_state(); + let detector_name = DETECTOR_NAME_ANGLE_BRACKETS_WHOLE_DOC; + let model_id = "my-super-model-8B"; + let expected_detector_error = DetectorError { + code: 404, + message: "Not found.".into(), + }; + + // Add input detection mock + let mut detection_mocks = MockSet::new(); + detection_mocks.insert( + MockPath::new(Method::POST, TEXT_CONTENTS_DETECTOR_ENDPOINT), + Mock::new( + MockRequest::json(ContentAnalysisRequest { + contents: vec!["This should return a 404".into()], + detector_params: DetectorParams::new(), + }), + MockResponse::json(&expected_detector_error).with_code(StatusCode::NOT_FOUND), + ), + ); + + // Start orchestrator server and its dependencies + let mock_detector_server = HttpMockServer::new(detector_name, detection_mocks)?; + let orchestrator_server = TestOrchestratorServer::run( + ORCHESTRATOR_CONFIG_FILE_PATH, + find_available_port().unwrap(), + find_available_port().unwrap(), + None, + None, + Some(vec![mock_detector_server]), + None, + ) + .await?; + + // Example orchestrator request with streaming response + let response = orchestrator_server + .post(STREAMING_CLASSIFICATION_WITH_GEN_ENDPOINT) + .json(&GuardrailsHttpRequest { + model_id: model_id.into(), + inputs: "This should return a 404".into(), + guardrail_config: Some(GuardrailsConfig { + input: Some(GuardrailsConfigInput { + models: HashMap::from([(detector_name.into(), DetectorParams::new())]), + masks: None, + }), + output: None, + }), + text_gen_parameters: None, + }) + .send() + .await?; + + debug!(?response, "RESPONSE RECEIVED FROM ORCHESTRATOR"); + + 
// Test custom SseStream wrapper + let sse_stream: SseStream = SseStream::new(response.bytes_stream()); + let messages = sse_stream + .collect::>() + .await + .into_iter() + .collect::, anyhow::Error>>()?; + debug!("{messages:#?}"); + + // assertions + assert!(messages.len() == 1); + assert!(messages[0].code == 404); + assert!( + messages[0].details + == format!( + "detector request failed for `{}`: {}", + detector_name, expected_detector_error.message + ) + ); + + Ok(()) +} From eab68a1e1dcac3ee9e17a9e4a132b8af373d252b Mon Sep 17 00:00:00 2001 From: Mateus Devino Date: Tue, 18 Feb 2025 15:07:14 -0300 Subject: [PATCH 038/117] test case: test_input_detector_returns_503 Signed-off-by: Mateus Devino --- tests/streaming.rs | 79 ++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 79 insertions(+) diff --git a/tests/streaming.rs b/tests/streaming.rs index 3592a8b3..4ed55c61 100644 --- a/tests/streaming.rs +++ b/tests/streaming.rs @@ -400,6 +400,85 @@ async fn test_input_detector_sentence_chunker_no_detections() -> Result<(), anyh Ok(()) } +#[test(tokio::test)] +async fn test_input_detector_returns_503() -> Result<(), anyhow::Error> { + ensure_global_rustls_state(); + let detector_name = DETECTOR_NAME_ANGLE_BRACKETS_WHOLE_DOC; + let model_id = "my-super-model-8B"; + let expected_detector_error = DetectorError { + code: 503, + message: "The detector service is overloaded.".into(), + }; + + // Add input detection mock + let mut detection_mocks = MockSet::new(); + detection_mocks.insert( + MockPath::new(Method::POST, TEXT_CONTENTS_DETECTOR_ENDPOINT), + Mock::new( + MockRequest::json(ContentAnalysisRequest { + contents: vec!["This should return a 503".into()], + detector_params: DetectorParams::new(), + }), + MockResponse::json(&expected_detector_error).with_code(StatusCode::NOT_FOUND), + ), + ); + + // Start orchestrator server and its dependencies + let mock_detector_server = HttpMockServer::new(detector_name, detection_mocks)?; + let orchestrator_server = TestOrchestratorServer::run( + ORCHESTRATOR_CONFIG_FILE_PATH, + find_available_port().unwrap(), + find_available_port().unwrap(), + None, + None, + Some(vec![mock_detector_server]), + None, + ) + .await?; + + // Example orchestrator request with streaming response + let response = orchestrator_server + .post(STREAMING_CLASSIFICATION_WITH_GEN_ENDPOINT) + .json(&GuardrailsHttpRequest { + model_id: model_id.into(), + inputs: "This should return a 503".into(), + guardrail_config: Some(GuardrailsConfig { + input: Some(GuardrailsConfigInput { + models: HashMap::from([(detector_name.into(), DetectorParams::new())]), + masks: None, + }), + output: None, + }), + text_gen_parameters: None, + }) + .send() + .await?; + + debug!(?response, "RESPONSE RECEIVED FROM ORCHESTRATOR"); + + // Test custom SseStream wrapper + let sse_stream: SseStream = SseStream::new(response.bytes_stream()); + let messages = sse_stream + .collect::>() + .await + .into_iter() + .collect::, anyhow::Error>>()?; + debug!("{messages:#?}"); + + // assertions + assert!(messages.len() == 1); + assert!(messages[0].code == 503); + assert!( + messages[0].details + == format!( + "detector request failed for `{}`: {}", + detector_name, expected_detector_error.message + ) + ); + + Ok(()) +} + #[test(tokio::test)] async fn test_input_detector_returns_404() -> Result<(), anyhow::Error> { ensure_global_rustls_state(); From 0a35b7cf1670578bdc9d40f0931feb3e2704c520 Mon Sep 17 00:00:00 2001 From: Mateus Devino Date: Tue, 18 Feb 2025 15:20:18 -0300 Subject: [PATCH 039/117] test case: 
test_input_detector_returns_non_compliant_message Signed-off-by: Mateus Devino --- tests/streaming.rs | 147 +++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 147 insertions(+) diff --git a/tests/streaming.rs b/tests/streaming.rs index 4ed55c61..18da189e 100644 --- a/tests/streaming.rs +++ b/tests/streaming.rs @@ -557,3 +557,150 @@ async fn test_input_detector_returns_404() -> Result<(), anyhow::Error> { Ok(()) } + +#[test(tokio::test)] +async fn test_input_detector_returns_500() -> Result<(), anyhow::Error> { + ensure_global_rustls_state(); + let detector_name = DETECTOR_NAME_ANGLE_BRACKETS_WHOLE_DOC; + let model_id = "my-super-model-8B"; + let expected_detector_error = DetectorError { + code: 500, + message: "Internal detector error.".into(), + }; + + // Add input detection mock + let mut detection_mocks = MockSet::new(); + detection_mocks.insert( + MockPath::new(Method::POST, TEXT_CONTENTS_DETECTOR_ENDPOINT), + Mock::new( + MockRequest::json(ContentAnalysisRequest { + contents: vec!["This should return a 500".into()], + detector_params: DetectorParams::new(), + }), + MockResponse::json(&expected_detector_error).with_code(StatusCode::NOT_FOUND), + ), + ); + + // Start orchestrator server and its dependencies + let mock_detector_server = HttpMockServer::new(detector_name, detection_mocks)?; + let orchestrator_server = TestOrchestratorServer::run( + ORCHESTRATOR_CONFIG_FILE_PATH, + find_available_port().unwrap(), + find_available_port().unwrap(), + None, + None, + Some(vec![mock_detector_server]), + None, + ) + .await?; + + // Example orchestrator request with streaming response + let response = orchestrator_server + .post(STREAMING_CLASSIFICATION_WITH_GEN_ENDPOINT) + .json(&GuardrailsHttpRequest { + model_id: model_id.into(), + inputs: "This should return a 500".into(), + guardrail_config: Some(GuardrailsConfig { + input: Some(GuardrailsConfigInput { + models: HashMap::from([(detector_name.into(), DetectorParams::new())]), + masks: None, + }), + output: None, + }), + text_gen_parameters: None, + }) + .send() + .await?; + + debug!(?response, "RESPONSE RECEIVED FROM ORCHESTRATOR"); + + // Test custom SseStream wrapper + let sse_stream: SseStream = SseStream::new(response.bytes_stream()); + let messages = sse_stream + .collect::>() + .await + .into_iter() + .collect::, anyhow::Error>>()?; + debug!("{messages:#?}"); + + // assertions + assert!(messages.len() == 1); + assert!(messages[0].code == 500); + assert!(messages[0].details == "unexpected error occurred while processing request"); + + Ok(()) +} + +#[test(tokio::test)] +async fn test_input_detector_returns_non_compliant_message() -> Result<(), anyhow::Error> { + ensure_global_rustls_state(); + let detector_name = DETECTOR_NAME_ANGLE_BRACKETS_WHOLE_DOC; + let model_id = "my-super-model-8B"; + let non_compliant_detector_response = serde_json::json!({ + "detections": true, + }); + + // Add input detection mock + let mut detection_mocks = MockSet::new(); + detection_mocks.insert( + MockPath::new(Method::POST, TEXT_CONTENTS_DETECTOR_ENDPOINT), + Mock::new( + MockRequest::json(ContentAnalysisRequest { + contents: vec![ + "The detector will return a message non compliant with the API".into(), + ], + detector_params: DetectorParams::new(), + }), + MockResponse::json(&non_compliant_detector_response).with_code(StatusCode::OK), + ), + ); + + // Start orchestrator server and its dependencies + let mock_detector_server = HttpMockServer::new(detector_name, detection_mocks)?; + let orchestrator_server = TestOrchestratorServer::run( + 
ORCHESTRATOR_CONFIG_FILE_PATH, + find_available_port().unwrap(), + find_available_port().unwrap(), + None, + None, + Some(vec![mock_detector_server]), + None, + ) + .await?; + + // Example orchestrator request with streaming response + let response = orchestrator_server + .post(STREAMING_CLASSIFICATION_WITH_GEN_ENDPOINT) + .json(&GuardrailsHttpRequest { + model_id: model_id.into(), + inputs: "The detector will return a message non compliant with the API".into(), + guardrail_config: Some(GuardrailsConfig { + input: Some(GuardrailsConfigInput { + models: HashMap::from([(detector_name.into(), DetectorParams::new())]), + masks: None, + }), + output: None, + }), + text_gen_parameters: None, + }) + .send() + .await?; + + debug!(?response, "RESPONSE RECEIVED FROM ORCHESTRATOR"); + + // Test custom SseStream wrapper + let sse_stream: SseStream = SseStream::new(response.bytes_stream()); + let messages = sse_stream + .collect::>() + .await + .into_iter() + .collect::, anyhow::Error>>()?; + debug!("{messages:#?}"); + + // assertions + assert!(messages.len() == 1); + assert!(messages[0].code == 500); + assert!(messages[0].details == "unexpected error occurred while processing request"); + + Ok(()) +} From 76fe749daf10c217a72ff8be5b8598bd203c4eb6 Mon Sep 17 00:00:00 2001 From: Mateus Devino Date: Tue, 18 Feb 2025 15:34:36 -0300 Subject: [PATCH 040/117] refactor: move ensure_global_rustls_state() call into TestOrchestratorServer::run() Signed-off-by: Mateus Devino --- tests/common/orchestrator.rs | 3 +++ tests/detection_content.rs | 7 ++----- tests/streaming.rs | 14 ++------------ 3 files changed, 7 insertions(+), 17 deletions(-) diff --git a/tests/common/orchestrator.rs b/tests/common/orchestrator.rs index 74cd8080..d1bb75d4 100644 --- a/tests/common/orchestrator.rs +++ b/tests/common/orchestrator.rs @@ -63,6 +63,9 @@ impl TestOrchestratorServer { detector_servers: Option>, chunker_servers: Option>, ) -> Result { + // Set default crypto provider + ensure_global_rustls_state(); + // Load orchestrator config let mut config = OrchestratorConfig::load(config_path).await?; diff --git a/tests/detection_content.rs b/tests/detection_content.rs index 330ca427..4c8272a1 100644 --- a/tests/detection_content.rs +++ b/tests/detection_content.rs @@ -25,8 +25,8 @@ use common::{ TEXT_CONTENTS_DETECTOR_ENDPOINT, }, orchestrator::{ - ensure_global_rustls_state, TestOrchestratorServer, - CONTENT_DETECTION_ORCHESTRATOR_ENDPOINT, ORCHESTRATOR_CONFIG_FILE_PATH, + TestOrchestratorServer, CONTENT_DETECTION_ORCHESTRATOR_ENDPOINT, + ORCHESTRATOR_CONFIG_FILE_PATH, }, }; use fms_guardrails_orchestr8::{ @@ -53,7 +53,6 @@ pub mod common; /// This test mocks a detector that detects text between . #[test(tokio::test)] async fn test_single_detection_whole_doc() -> Result<(), anyhow::Error> { - ensure_global_rustls_state(); let detector_name = DETECTOR_NAME_ANGLE_BRACKETS_WHOLE_DOC; // Add detector mock @@ -129,8 +128,6 @@ async fn test_single_detection_whole_doc() -> Result<(), anyhow::Error> { /// This test mocks a detector that detects text between . 
#[test(tokio::test)] async fn test_single_detection_sentence_chunker() -> Result<(), anyhow::Error> { - ensure_global_rustls_state(); - // Add chunker mock let chunker_id = CHUNKER_NAME_SENTENCE; let mut chunker_headers = HeaderMap::new(); diff --git a/tests/streaming.rs b/tests/streaming.rs index 18da189e..44ee9657 100644 --- a/tests/streaming.rs +++ b/tests/streaming.rs @@ -26,10 +26,7 @@ use common::{ }, errors::{DetectorError, OrchestratorError}, generation::{MockNlpServiceServer, GENERATION_NLP_STREAMING_ENDPOINT}, - orchestrator::{ - ensure_global_rustls_state, SseStream, TestOrchestratorServer, - ORCHESTRATOR_CONFIG_FILE_PATH, - }, + orchestrator::{SseStream, TestOrchestratorServer, ORCHESTRATOR_CONFIG_FILE_PATH}, }; use fms_guardrails_orchestr8::{ clients::{ @@ -59,8 +56,6 @@ const STREAMING_CLASSIFICATION_WITH_GEN_ENDPOINT: &str = #[test(tokio::test)] async fn test_no_detectors() -> Result<(), anyhow::Error> { - ensure_global_rustls_state(); - // Add generation mock let model_id = "my-super-model-8B"; let mut headers = HeaderMap::new(); @@ -161,7 +156,6 @@ async fn test_no_detectors() -> Result<(), anyhow::Error> { #[test(tokio::test)] async fn test_input_detector_whole_doc_no_detections() -> Result<(), anyhow::Error> { - ensure_global_rustls_state(); let detector_name = DETECTOR_NAME_ANGLE_BRACKETS_WHOLE_DOC; // Add input detection mock @@ -263,7 +257,6 @@ async fn test_input_detector_whole_doc_no_detections() -> Result<(), anyhow::Err #[test(tokio::test)] async fn test_input_detector_sentence_chunker_no_detections() -> Result<(), anyhow::Error> { - ensure_global_rustls_state(); let detector_name = DETECTOR_NAME_ANGLE_BRACKETS_SENTENCE; // Add input chunker mock @@ -402,7 +395,6 @@ async fn test_input_detector_sentence_chunker_no_detections() -> Result<(), anyh #[test(tokio::test)] async fn test_input_detector_returns_503() -> Result<(), anyhow::Error> { - ensure_global_rustls_state(); let detector_name = DETECTOR_NAME_ANGLE_BRACKETS_WHOLE_DOC; let model_id = "my-super-model-8B"; let expected_detector_error = DetectorError { @@ -481,7 +473,6 @@ async fn test_input_detector_returns_503() -> Result<(), anyhow::Error> { #[test(tokio::test)] async fn test_input_detector_returns_404() -> Result<(), anyhow::Error> { - ensure_global_rustls_state(); let detector_name = DETECTOR_NAME_ANGLE_BRACKETS_WHOLE_DOC; let model_id = "my-super-model-8B"; let expected_detector_error = DetectorError { @@ -560,7 +551,6 @@ async fn test_input_detector_returns_404() -> Result<(), anyhow::Error> { #[test(tokio::test)] async fn test_input_detector_returns_500() -> Result<(), anyhow::Error> { - ensure_global_rustls_state(); let detector_name = DETECTOR_NAME_ANGLE_BRACKETS_WHOLE_DOC; let model_id = "my-super-model-8B"; let expected_detector_error = DetectorError { @@ -633,7 +623,7 @@ async fn test_input_detector_returns_500() -> Result<(), anyhow::Error> { #[test(tokio::test)] async fn test_input_detector_returns_non_compliant_message() -> Result<(), anyhow::Error> { - ensure_global_rustls_state(); + // ensure_global_rustls_state(); let detector_name = DETECTOR_NAME_ANGLE_BRACKETS_WHOLE_DOC; let model_id = "my-super-model-8B"; let non_compliant_detector_response = serde_json::json!({ From 0ab0d0175b585fdddaa19551d03127aa2444ba2f Mon Sep 17 00:00:00 2001 From: Mateus Devino Date: Tue, 18 Feb 2025 17:12:09 -0300 Subject: [PATCH 041/117] test case: test_input_detector_whole_doc_with_detections() Signed-off-by: Mateus Devino --- tests/common/generation.rs | 3 + tests/streaming.rs | 132 
+++++++++++++++++++++++++++++++++++-- 2 files changed, 130 insertions(+), 5 deletions(-) diff --git a/tests/common/generation.rs b/tests/common/generation.rs index 2bbceecb..90f34863 100644 --- a/tests/common/generation.rs +++ b/tests/common/generation.rs @@ -21,3 +21,6 @@ generate_grpc_server!("caikit.runtime.Nlp.NlpService", MockNlpServiceServer); pub const GENERATION_NLP_STREAMING_ENDPOINT: &str = "/caikit.runtime.Nlp.NlpService/ServerStreamingTextGenerationTaskPredict"; + +pub const GENERATION_NLP_TOKENIZATION_ENDPOINT: &str = + "/caikit.runtime.Nlp.NlpService/TokenizationTaskPredict"; diff --git a/tests/streaming.rs b/tests/streaming.rs index 44ee9657..af878d7d 100644 --- a/tests/streaming.rs +++ b/tests/streaming.rs @@ -25,7 +25,10 @@ use common::{ TEXT_CONTENTS_DETECTOR_ENDPOINT, }, errors::{DetectorError, OrchestratorError}, - generation::{MockNlpServiceServer, GENERATION_NLP_STREAMING_ENDPOINT}, + generation::{ + MockNlpServiceServer, GENERATION_NLP_STREAMING_ENDPOINT, + GENERATION_NLP_TOKENIZATION_ENDPOINT, + }, orchestrator::{SseStream, TestOrchestratorServer, ORCHESTRATOR_CONFIG_FILE_PATH}, }; use fms_guardrails_orchestr8::{ @@ -35,12 +38,14 @@ use fms_guardrails_orchestr8::{ nlp::MODEL_ID_HEADER_NAME as NLP_MODEL_ID_HEADER_NAME, }, models::{ - ClassifiedGeneratedTextStreamResult, DetectorParams, GuardrailsConfig, - GuardrailsConfigInput, GuardrailsHttpRequest, + ClassifiedGeneratedTextStreamResult, DetectionWarning, DetectorParams, GuardrailsConfig, + GuardrailsConfigInput, GuardrailsHttpRequest, TextGenTokenClassificationResults, + TokenClassificationResult, }, pb::{ caikit::runtime::{ - chunkers::ChunkerTokenizationTaskRequest, nlp::ServerStreamingTextGenerationTaskRequest, + chunkers::ChunkerTokenizationTaskRequest, + nlp::{ServerStreamingTextGenerationTaskRequest, TokenizationTaskRequest}, }, caikit_data_model::nlp::{GeneratedTextStreamResult, Token, TokenizationResults}, }, @@ -116,7 +121,7 @@ async fn test_no_detectors() -> Result<(), anyhow::Error> { .send() .await?; - // Example showing how to create an event stream from a bytes stream. + // // Example showing how to create an event stream from a bytes stream. // let mut events = Vec::new(); // let mut event_stream = response.bytes_stream().eventsource(); // while let Some(event) = event_stream.next().await { @@ -255,6 +260,123 @@ async fn test_input_detector_whole_doc_no_detections() -> Result<(), anyhow::Err Ok(()) } +#[test(tokio::test)] +async fn test_input_detector_whole_doc_with_detections() -> Result<(), anyhow::Error> { + let detector_name = DETECTOR_NAME_ANGLE_BRACKETS_WHOLE_DOC; + let mock_detection_response = ContentAnalysisResponse { + start: 46, + end: 59, + text: "this one does".into(), + detection: "has_angle_brackets".into(), + detection_type: "angle_brackets".into(), + detector_id: Some(detector_name.into()), + score: 1.0, + evidence: None, + }; + + // Add input detection mock + let mut detection_mocks = MockSet::new(); + detection_mocks.insert( + MockPath::new(Method::POST, TEXT_CONTENTS_DETECTOR_ENDPOINT), + Mock::new( + MockRequest::json(ContentAnalysisRequest { + contents: vec![ + "This sentence does not have a detection. 
But .".into(), + ], + detector_params: DetectorParams::new(), + }), + MockResponse::json([vec![mock_detection_response.clone()]]), + ), + ); + + // Add generation mock for input token count + let model_id = "my-super-model-8B"; + let mock_tokenization_response = TokenizationResults { + results: Vec::new(), + token_count: 61, + }; + let mut headers = HeaderMap::new(); + headers.insert(NLP_MODEL_ID_HEADER_NAME, model_id.parse().unwrap()); + let mut generation_mocks = MockSet::new(); + generation_mocks.insert( + MockPath::new(Method::POST, GENERATION_NLP_TOKENIZATION_ENDPOINT), + Mock::new( + MockRequest::pb(TokenizationTaskRequest { + text: "This sentence does not have a detection. But .".into(), + ..Default::default() + }) + .with_headers(headers.clone()), + MockResponse::pb(mock_tokenization_response.clone()), + ), + ); + + // Start orchestrator server and its dependencies + let mock_detector_server = HttpMockServer::new(detector_name, detection_mocks)?; + let generation_server = MockNlpServiceServer::new(generation_mocks)?; + let orchestrator_server = TestOrchestratorServer::run( + ORCHESTRATOR_CONFIG_FILE_PATH, + find_available_port().unwrap(), + find_available_port().unwrap(), + Some(generation_server), + None, + Some(vec![mock_detector_server]), + None, + ) + .await?; + + // Example orchestrator request with streaming response + let response = orchestrator_server + .post(STREAMING_CLASSIFICATION_WITH_GEN_ENDPOINT) + .json(&GuardrailsHttpRequest { + model_id: model_id.into(), + inputs: "This sentence does not have a detection. But .".into(), + guardrail_config: Some(GuardrailsConfig { + input: Some(GuardrailsConfigInput { + models: HashMap::from([(detector_name.into(), DetectorParams::new())]), + masks: None, + }), + output: None, + }), + text_gen_parameters: None, + }) + .send() + .await?; + + // Test custom SseStream wrapper + let sse_stream: SseStream = + SseStream::new(response.bytes_stream()); + let messages = sse_stream + .collect::>() + .await + .into_iter() + .collect::, anyhow::Error>>()?; + debug!("{messages:#?}"); + + // assertions + assert!(messages.len() == 1); + assert!(messages[0].generated_text == None); + assert!( + messages[0].token_classification_results + == TextGenTokenClassificationResults { + input: Some(vec![TokenClassificationResult { + start: mock_detection_response.start as u32, + end: mock_detection_response.end as u32, + word: mock_detection_response.text, + entity: mock_detection_response.detection, + entity_group: mock_detection_response.detection_type, + detector_id: mock_detection_response.detector_id, + score: mock_detection_response.score, + token_count: None + }]), + output: None + } + ); + assert!(messages[0].input_token_count == mock_tokenization_response.token_count as u32); + assert!(messages[0].warnings == Some(vec![DetectionWarning{ id: Some(fms_guardrails_orchestr8::models::DetectionWarningReason::UnsuitableInput), message: Some("Unsuitable input detected. 
Please check the detected entities on your input and try again with the unsuitable input removed.".into()) }])); + + Ok(()) +} + #[test(tokio::test)] async fn test_input_detector_sentence_chunker_no_detections() -> Result<(), anyhow::Error> { let detector_name = DETECTOR_NAME_ANGLE_BRACKETS_SENTENCE; From 14abbfc20bf555a5183336bbf38deacaf3da5f57 Mon Sep 17 00:00:00 2001 From: Mateus Devino Date: Wed, 19 Feb 2025 09:47:27 -0300 Subject: [PATCH 042/117] test case: test_input_detector_sentence_chunker_no_detections() Signed-off-by: Mateus Devino --- tests/streaming.rs | 138 +++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 138 insertions(+) diff --git a/tests/streaming.rs b/tests/streaming.rs index af878d7d..d4db554f 100644 --- a/tests/streaming.rs +++ b/tests/streaming.rs @@ -260,6 +260,144 @@ async fn test_input_detector_whole_doc_no_detections() -> Result<(), anyhow::Err Ok(()) } +#[test(tokio::test)] +async fn test_input_detector_sentence_chunker_no_detections() -> Result<(), anyhow::Error> { + let detector_name = DETECTOR_NAME_ANGLE_BRACKETS_SENTENCE; + + // Add input chunker mock + let chunker_id = CHUNKER_NAME_SENTENCE; + let mut chunker_headers = HeaderMap::new(); + chunker_headers.insert(CHUNKER_MODEL_ID_HEADER_NAME, chunker_id.parse()?); + + let mut chunker_mocks = MockSet::new(); + chunker_mocks.insert( + MockPath::new(Method::POST, CHUNKER_UNARY_ENDPOINT), + Mock::new( + MockRequest::pb(ChunkerTokenizationTaskRequest { + text: "Hi there! How are you?".into(), + }) + .with_headers(chunker_headers), + MockResponse::pb(TokenizationResults { + results: vec![ + Token { + start: 0, + end: 9, + text: "Hi there!".into(), + }, + Token { + start: 10, + end: 22, + text: " How are you?".into(), + }, + ], + token_count: 0, + }), + ), + ); + + // Add input detection mock + let mut detection_mocks = MockSet::new(); + detection_mocks.insert( + MockPath::new(Method::POST, TEXT_CONTENTS_DETECTOR_ENDPOINT), + Mock::new( + MockRequest::json(ContentAnalysisRequest { + contents: vec!["Hi there!".into(), " How are you?".into()], + detector_params: DetectorParams::new(), + }), + MockResponse::json([ + Vec::::new(), + Vec::::new(), + ]), + ), + ); + + // Add generation mock + let model_id = "my-super-model-8B"; + let mut headers = HeaderMap::new(); + headers.insert(NLP_MODEL_ID_HEADER_NAME, model_id.parse().unwrap()); + + let expected_response = vec![ + GeneratedTextStreamResult { + generated_text: "I".into(), + ..Default::default() + }, + GeneratedTextStreamResult { + generated_text: " am".into(), + ..Default::default() + }, + GeneratedTextStreamResult { + generated_text: " great!".into(), + ..Default::default() + }, + ]; + + let mut generation_mocks = MockSet::new(); + generation_mocks.insert( + MockPath::new(Method::POST, GENERATION_NLP_STREAMING_ENDPOINT), + Mock::new( + MockRequest::pb(ServerStreamingTextGenerationTaskRequest { + text: "Hi there! 
How are you?".into(), + ..Default::default() + }) + .with_headers(headers.clone()), + MockResponse::pb_stream(expected_response.clone()), + ), + ); + + // Start orchestrator server and its dependencies + let mock_chunker_server = MockChunkersServiceServer::new(chunker_mocks)?; + let mock_detector_server = HttpMockServer::new(detector_name, detection_mocks)?; + let generation_server = MockNlpServiceServer::new(generation_mocks)?; + let orchestrator_server = TestOrchestratorServer::run( + ORCHESTRATOR_CONFIG_FILE_PATH, + find_available_port().unwrap(), + find_available_port().unwrap(), + Some(generation_server), + None, + Some(vec![mock_detector_server]), + Some(vec![(chunker_id.into(), mock_chunker_server)]), + ) + .await?; + + // Example orchestrator request with streaming response + let response = orchestrator_server + .post(STREAMING_CLASSIFICATION_WITH_GEN_ENDPOINT) + .json(&GuardrailsHttpRequest { + model_id: model_id.into(), + inputs: "Hi there! How are you?".into(), + guardrail_config: Some(GuardrailsConfig { + input: Some(GuardrailsConfigInput { + models: HashMap::from([(detector_name.into(), DetectorParams::new())]), + masks: None, + }), + output: None, + }), + text_gen_parameters: None, + }) + .send() + .await?; + + debug!("{response:#?}"); + + // Test custom SseStream wrapper + let sse_stream: SseStream = + SseStream::new(response.bytes_stream()); + let messages = sse_stream + .collect::>() + .await + .into_iter() + .collect::, anyhow::Error>>()?; + debug!("{messages:#?}"); + + // assertions + assert!(messages.len() == 3); + assert!(messages[0].generated_text == Some("I".into())); + assert!(messages[1].generated_text == Some(" am".into())); + assert!(messages[2].generated_text == Some(" great!".into())); + + Ok(()) +} + #[test(tokio::test)] async fn test_input_detector_whole_doc_with_detections() -> Result<(), anyhow::Error> { let detector_name = DETECTOR_NAME_ANGLE_BRACKETS_WHOLE_DOC; From 00d586491ad0abd8fdbe28b152f49118873579a2 Mon Sep 17 00:00:00 2001 From: Mateus Devino Date: Wed, 19 Feb 2025 09:48:54 -0300 Subject: [PATCH 043/117] test case: test_input_detector_sentence_chunker_with_detections() Signed-off-by: Mateus Devino --- tests/streaming.rs | 97 ++++++++++++++++++++++++++-------------------- 1 file changed, 54 insertions(+), 43 deletions(-) diff --git a/tests/streaming.rs b/tests/streaming.rs index d4db554f..659aa05c 100644 --- a/tests/streaming.rs +++ b/tests/streaming.rs @@ -516,10 +516,8 @@ async fn test_input_detector_whole_doc_with_detections() -> Result<(), anyhow::E } #[test(tokio::test)] -async fn test_input_detector_sentence_chunker_no_detections() -> Result<(), anyhow::Error> { - let detector_name = DETECTOR_NAME_ANGLE_BRACKETS_SENTENCE; - - // Add input chunker mock +async fn test_input_detector_sentence_chunker_with_detections() -> Result<(), anyhow::Error> { + // Add chunker mock let chunker_id = CHUNKER_NAME_SENTENCE; let mut chunker_headers = HeaderMap::new(); chunker_headers.insert(CHUNKER_MODEL_ID_HEADER_NAME, chunker_id.parse()?); @@ -529,20 +527,20 @@ async fn test_input_detector_sentence_chunker_no_detections() -> Result<(), anyh MockPath::new(Method::POST, CHUNKER_UNARY_ENDPOINT), Mock::new( MockRequest::pb(ChunkerTokenizationTaskRequest { - text: "Hi there! How are you?".into(), + text: "This sentence does not have a detection. 
But .".into(), }) .with_headers(chunker_headers), MockResponse::pb(TokenizationResults { results: vec![ Token { start: 0, - end: 9, - text: "Hi there!".into(), + end: 40, + text: "This sentence does not have a detection.".into(), }, Token { - start: 10, - end: 22, - text: " How are you?".into(), + start: 41, + end: 61, + text: "But .".into(), }, ], token_count: 0, @@ -551,51 +549,50 @@ async fn test_input_detector_sentence_chunker_no_detections() -> Result<(), anyh ); // Add input detection mock + let detector_name = DETECTOR_NAME_ANGLE_BRACKETS_SENTENCE; + let mock_detection_response = ContentAnalysisResponse { + start: 5, + end: 18, + text: "this one does".into(), + detection: "has_angle_brackets".into(), + detection_type: "angle_brackets".into(), + detector_id: Some(detector_name.into()), + score: 1.0, + evidence: None, + }; let mut detection_mocks = MockSet::new(); detection_mocks.insert( MockPath::new(Method::POST, TEXT_CONTENTS_DETECTOR_ENDPOINT), Mock::new( MockRequest::json(ContentAnalysisRequest { - contents: vec!["Hi there!".into(), " How are you?".into()], + contents: vec![ + "This sentence does not have a detection.".into(), + "But .".into(), + ], detector_params: DetectorParams::new(), }), - MockResponse::json([ - Vec::::new(), - Vec::::new(), - ]), + MockResponse::json(vec![vec![], vec![mock_detection_response.clone()]]), ), ); - // Add generation mock + // Add generation mock for input token count let model_id = "my-super-model-8B"; + let mock_tokenization_response = TokenizationResults { + results: Vec::new(), + token_count: 61, + }; let mut headers = HeaderMap::new(); headers.insert(NLP_MODEL_ID_HEADER_NAME, model_id.parse().unwrap()); - - let expected_response = vec![ - GeneratedTextStreamResult { - generated_text: "I".into(), - ..Default::default() - }, - GeneratedTextStreamResult { - generated_text: " am".into(), - ..Default::default() - }, - GeneratedTextStreamResult { - generated_text: " great!".into(), - ..Default::default() - }, - ]; - let mut generation_mocks = MockSet::new(); generation_mocks.insert( - MockPath::new(Method::POST, GENERATION_NLP_STREAMING_ENDPOINT), + MockPath::new(Method::POST, GENERATION_NLP_TOKENIZATION_ENDPOINT), Mock::new( - MockRequest::pb(ServerStreamingTextGenerationTaskRequest { - text: "Hi there! How are you?".into(), + MockRequest::pb(TokenizationTaskRequest { + text: "This sentence does not have a detection. But .".into(), ..Default::default() }) .with_headers(headers.clone()), - MockResponse::pb_stream(expected_response.clone()), + MockResponse::pb(mock_tokenization_response.clone()), ), ); @@ -619,7 +616,7 @@ async fn test_input_detector_sentence_chunker_no_detections() -> Result<(), anyh .post(STREAMING_CLASSIFICATION_WITH_GEN_ENDPOINT) .json(&GuardrailsHttpRequest { model_id: model_id.into(), - inputs: "Hi there! How are you?".into(), + inputs: "This sentence does not have a detection. 
But .".into(), guardrail_config: Some(GuardrailsConfig { input: Some(GuardrailsConfigInput { models: HashMap::from([(detector_name.into(), DetectorParams::new())]), @@ -632,8 +629,6 @@ async fn test_input_detector_sentence_chunker_no_detections() -> Result<(), anyh .send() .await?; - debug!("{response:#?}"); - // Test custom SseStream wrapper let sse_stream: SseStream = SseStream::new(response.bytes_stream()); @@ -645,10 +640,26 @@ async fn test_input_detector_sentence_chunker_no_detections() -> Result<(), anyh debug!("{messages:#?}"); // assertions - assert!(messages.len() == 3); - assert!(messages[0].generated_text == Some("I".into())); - assert!(messages[1].generated_text == Some(" am".into())); - assert!(messages[2].generated_text == Some(" great!".into())); + assert!(messages.len() == 1); + assert!(messages[0].generated_text == None); + assert!( + messages[0].token_classification_results + == TextGenTokenClassificationResults { + input: Some(vec![TokenClassificationResult { + start: 46 as u32, // index of first token of detected text, relative to the `inputs` string sent in the orchestrator request. + end: 59 as u32, // index of last token (+1) of detected text, relative to the `inputs` string sent in the orchestrator request. + word: mock_detection_response.text, + entity: mock_detection_response.detection, + entity_group: mock_detection_response.detection_type, + detector_id: mock_detection_response.detector_id, + score: mock_detection_response.score, + token_count: None + }]), + output: None + } + ); + assert!(messages[0].input_token_count == mock_tokenization_response.token_count as u32); + assert!(messages[0].warnings == Some(vec![DetectionWarning{ id: Some(fms_guardrails_orchestr8::models::DetectionWarningReason::UnsuitableInput), message: Some("Unsuitable input detected. Please check the detected entities on your input and try again with the unsuitable input removed.".into()) }])); Ok(()) } From d383098f6e32f61abf7503a63a54b598f6779a47 Mon Sep 17 00:00:00 2001 From: Mateus Devino Date: Wed, 19 Feb 2025 10:20:06 -0300 Subject: [PATCH 044/117] refactor: rename/move integration testing constants Signed-off-by: Mateus Devino --- tests/common/chunker.rs | 2 ++ tests/common/detectors.rs | 5 ++++- tests/common/errors.rs | 2 ++ tests/common/generation.rs | 2 +- tests/common/orchestrator.rs | 5 ++++- tests/detection_content.rs | 10 ++++------ tests/streaming.rs | 26 +++++++++++++------------- 7 files changed, 30 insertions(+), 22 deletions(-) diff --git a/tests/common/chunker.rs b/tests/common/chunker.rs index 778f7309..8678196d 100644 --- a/tests/common/chunker.rs +++ b/tests/common/chunker.rs @@ -22,7 +22,9 @@ generate_grpc_server!( MockChunkersServiceServer ); +// Chunker names pub const CHUNKER_NAME_SENTENCE: &str = "sentence_chunker"; +// Chunker endpoints pub const CHUNKER_UNARY_ENDPOINT: &str = "/caikit.runtime.Chunkers.ChunkersService/ChunkerTokenizationTaskPredict"; diff --git a/tests/common/detectors.rs b/tests/common/detectors.rs index 75c1247d..c882bdeb 100644 --- a/tests/common/detectors.rs +++ b/tests/common/detectors.rs @@ -14,7 +14,10 @@ limitations under the License. 
*/ -pub const TEXT_CONTENTS_DETECTOR_ENDPOINT: &str = "/api/v1/text/contents"; +// Detector names pub const DETECTOR_NAME_ANGLE_BRACKETS_WHOLE_DOC: &str = "angle_brackets_detector_whole_doc"; pub const DETECTOR_NAME_ANGLE_BRACKETS_SENTENCE: &str = "angle_brackets_detector_sentence"; + +// Detector endpoints +pub const TEXT_CONTENTS_DETECTOR_ENDPOINT: &str = "/api/v1/text/contents"; diff --git a/tests/common/errors.rs b/tests/common/errors.rs index 97f0da9b..0cc5d31f 100644 --- a/tests/common/errors.rs +++ b/tests/common/errors.rs @@ -16,12 +16,14 @@ */ use serde::{Deserialize, Serialize}; +/// Errors returned by detector endpoints. #[derive(Serialize, Deserialize, Debug)] pub struct DetectorError { pub code: u16, pub message: String, } +/// Errors returned by orchestrator endpoints. #[derive(Serialize, Deserialize, Debug)] pub struct OrchestratorError { pub code: u16, diff --git a/tests/common/generation.rs b/tests/common/generation.rs index 90f34863..e6365fbd 100644 --- a/tests/common/generation.rs +++ b/tests/common/generation.rs @@ -19,8 +19,8 @@ use mocktail::mock::MockSet; generate_grpc_server!("caikit.runtime.Nlp.NlpService", MockNlpServiceServer); +// NLP generation server endpoints pub const GENERATION_NLP_STREAMING_ENDPOINT: &str = "/caikit.runtime.Nlp.NlpService/ServerStreamingTextGenerationTaskPredict"; - pub const GENERATION_NLP_TOKENIZATION_ENDPOINT: &str = "/caikit.runtime.Nlp.NlpService/TokenizationTaskPredict"; diff --git a/tests/common/orchestrator.rs b/tests/common/orchestrator.rs index d1bb75d4..c4bd9754 100644 --- a/tests/common/orchestrator.rs +++ b/tests/common/orchestrator.rs @@ -39,7 +39,10 @@ use super::{chunker::MockChunkersServiceServer, generation::MockNlpServiceServer /// Default orchestrator configuration file for integration tests. pub const ORCHESTRATOR_CONFIG_FILE_PATH: &str = "tests/test.config.yaml"; -pub const CONTENT_DETECTION_ORCHESTRATOR_ENDPOINT: &str = "/api/v2/text/detection/content"; +// Endpoints +pub const ORCHESTRATOR_STREAMING_ENDPOINT: &str = + "/api/v1/task/server-streaming-classification-with-text-generation"; +pub const ORCHESTRATOR_CONTENT_DETECTION_ENDPOINT: &str = "/api/v2/text/detection/content"; pub fn ensure_global_rustls_state() { let _ = ring::default_provider().install_default(); diff --git a/tests/detection_content.rs b/tests/detection_content.rs index 4c8272a1..8da3980b 100644 --- a/tests/detection_content.rs +++ b/tests/detection_content.rs @@ -25,8 +25,8 @@ use common::{ TEXT_CONTENTS_DETECTOR_ENDPOINT, }, orchestrator::{ - TestOrchestratorServer, CONTENT_DETECTION_ORCHESTRATOR_ENDPOINT, - ORCHESTRATOR_CONFIG_FILE_PATH, + TestOrchestratorServer, ORCHESTRATOR_CONFIG_FILE_PATH, + ORCHESTRATOR_CONTENT_DETECTION_ENDPOINT, }, }; use fms_guardrails_orchestr8::{ @@ -46,8 +46,6 @@ use tracing::debug; pub mod common; -// Constants - /// Asserts a scenario with a single detection works as expected (assumes a detector configured with whole_doc_chunker). /// /// This test mocks a detector that detects text between . 
@@ -92,7 +90,7 @@ async fn test_single_detection_whole_doc() -> Result<(), anyhow::Error> { // Make orchestrator call let response = orchestrator_server - .post(CONTENT_DETECTION_ORCHESTRATOR_ENDPOINT) + .post(ORCHESTRATOR_CONTENT_DETECTION_ENDPOINT) .json(&TextContentDetectionHttpRequest { content: "This sentence has .".into(), detectors: HashMap::from([(detector_name.into(), DetectorParams::new())]), @@ -204,7 +202,7 @@ async fn test_single_detection_sentence_chunker() -> Result<(), anyhow::Error> { // Make orchestrator call let response = orchestrator_server - .post(CONTENT_DETECTION_ORCHESTRATOR_ENDPOINT) + .post(ORCHESTRATOR_CONTENT_DETECTION_ENDPOINT) .json(&TextContentDetectionHttpRequest { content: "This sentence does not have a detection. But .".into(), detectors: HashMap::from([(detector_name.into(), DetectorParams::new())]), diff --git a/tests/streaming.rs b/tests/streaming.rs index 659aa05c..1be2b919 100644 --- a/tests/streaming.rs +++ b/tests/streaming.rs @@ -29,7 +29,10 @@ use common::{ MockNlpServiceServer, GENERATION_NLP_STREAMING_ENDPOINT, GENERATION_NLP_TOKENIZATION_ENDPOINT, }, - orchestrator::{SseStream, TestOrchestratorServer, ORCHESTRATOR_CONFIG_FILE_PATH}, + orchestrator::{ + SseStream, TestOrchestratorServer, ORCHESTRATOR_CONFIG_FILE_PATH, + ORCHESTRATOR_STREAMING_ENDPOINT, + }, }; use fms_guardrails_orchestr8::{ clients::{ @@ -56,9 +59,6 @@ use tracing::debug; pub mod common; -const STREAMING_CLASSIFICATION_WITH_GEN_ENDPOINT: &str = - "/api/v1/task/server-streaming-classification-with-text-generation"; - #[test(tokio::test)] async fn test_no_detectors() -> Result<(), anyhow::Error> { // Add generation mock @@ -111,7 +111,7 @@ async fn test_no_detectors() -> Result<(), anyhow::Error> { // Example orchestrator request with streaming response let response = orchestrator_server - .post(STREAMING_CLASSIFICATION_WITH_GEN_ENDPOINT) + .post(ORCHESTRATOR_STREAMING_ENDPOINT) .json(&GuardrailsHttpRequest { model_id: model_id.into(), inputs: "Hi there! How are you?".into(), @@ -225,7 +225,7 @@ async fn test_input_detector_whole_doc_no_detections() -> Result<(), anyhow::Err // Example orchestrator request with streaming response let response = orchestrator_server - .post(STREAMING_CLASSIFICATION_WITH_GEN_ENDPOINT) + .post(ORCHESTRATOR_STREAMING_ENDPOINT) .json(&GuardrailsHttpRequest { model_id: model_id.into(), inputs: "Hi there! How are you?".into(), @@ -361,7 +361,7 @@ async fn test_input_detector_sentence_chunker_no_detections() -> Result<(), anyh // Example orchestrator request with streaming response let response = orchestrator_server - .post(STREAMING_CLASSIFICATION_WITH_GEN_ENDPOINT) + .post(ORCHESTRATOR_STREAMING_ENDPOINT) .json(&GuardrailsHttpRequest { model_id: model_id.into(), inputs: "Hi there! How are you?".into(), @@ -464,7 +464,7 @@ async fn test_input_detector_whole_doc_with_detections() -> Result<(), anyhow::E // Example orchestrator request with streaming response let response = orchestrator_server - .post(STREAMING_CLASSIFICATION_WITH_GEN_ENDPOINT) + .post(ORCHESTRATOR_STREAMING_ENDPOINT) .json(&GuardrailsHttpRequest { model_id: model_id.into(), inputs: "This sentence does not have a detection. 
But .".into(), @@ -613,7 +613,7 @@ async fn test_input_detector_sentence_chunker_with_detections() -> Result<(), an // Example orchestrator request with streaming response let response = orchestrator_server - .post(STREAMING_CLASSIFICATION_WITH_GEN_ENDPOINT) + .post(ORCHESTRATOR_STREAMING_ENDPOINT) .json(&GuardrailsHttpRequest { model_id: model_id.into(), inputs: "This sentence does not have a detection. But .".into(), @@ -701,7 +701,7 @@ async fn test_input_detector_returns_503() -> Result<(), anyhow::Error> { // Example orchestrator request with streaming response let response = orchestrator_server - .post(STREAMING_CLASSIFICATION_WITH_GEN_ENDPOINT) + .post(ORCHESTRATOR_STREAMING_ENDPOINT) .json(&GuardrailsHttpRequest { model_id: model_id.into(), inputs: "This should return a 503".into(), @@ -779,7 +779,7 @@ async fn test_input_detector_returns_404() -> Result<(), anyhow::Error> { // Example orchestrator request with streaming response let response = orchestrator_server - .post(STREAMING_CLASSIFICATION_WITH_GEN_ENDPOINT) + .post(ORCHESTRATOR_STREAMING_ENDPOINT) .json(&GuardrailsHttpRequest { model_id: model_id.into(), inputs: "This should return a 404".into(), @@ -857,7 +857,7 @@ async fn test_input_detector_returns_500() -> Result<(), anyhow::Error> { // Example orchestrator request with streaming response let response = orchestrator_server - .post(STREAMING_CLASSIFICATION_WITH_GEN_ENDPOINT) + .post(ORCHESTRATOR_STREAMING_ENDPOINT) .json(&GuardrailsHttpRequest { model_id: model_id.into(), inputs: "This should return a 500".into(), @@ -931,7 +931,7 @@ async fn test_input_detector_returns_non_compliant_message() -> Result<(), anyho // Example orchestrator request with streaming response let response = orchestrator_server - .post(STREAMING_CLASSIFICATION_WITH_GEN_ENDPOINT) + .post(ORCHESTRATOR_STREAMING_ENDPOINT) .json(&GuardrailsHttpRequest { model_id: model_id.into(), inputs: "The detector will return a message non compliant with the API".into(), From 2e1d9275191b9c5ba31f347a5d20b4cae6a936e5 Mon Sep 17 00:00:00 2001 From: Mateus Devino Date: Wed, 19 Feb 2025 14:20:53 -0300 Subject: [PATCH 045/117] test case: test_input_detector_returns_an_error() Signed-off-by: Mateus Devino --- tests/common/orchestrator.rs | 6 ++- tests/streaming.rs | 78 ++++++++++++++++++++++++++++++++++-- 2 files changed, 80 insertions(+), 4 deletions(-) diff --git a/tests/common/orchestrator.rs b/tests/common/orchestrator.rs index c4bd9754..432c888b 100644 --- a/tests/common/orchestrator.rs +++ b/tests/common/orchestrator.rs @@ -36,7 +36,7 @@ use url::Url; use super::{chunker::MockChunkersServiceServer, generation::MockNlpServiceServer}; -/// Default orchestrator configuration file for integration tests. +// Default orchestrator configuration file for integration tests. 
pub const ORCHESTRATOR_CONFIG_FILE_PATH: &str = "tests/test.config.yaml"; // Endpoints @@ -44,6 +44,10 @@ pub const ORCHESTRATOR_STREAMING_ENDPOINT: &str = "/api/v1/task/server-streaming-classification-with-text-generation"; pub const ORCHESTRATOR_CONTENT_DETECTION_ENDPOINT: &str = "/api/v2/text/detection/content"; +// Messages +pub const ORCHESTRATOR_INTERNAL_SERVER_ERROR_MESSAGE: &str = + "unexpected error occurred while processing request"; + pub fn ensure_global_rustls_state() { let _ = ring::default_provider().install_default(); } diff --git a/tests/streaming.rs b/tests/streaming.rs index 1be2b919..f18895c2 100644 --- a/tests/streaming.rs +++ b/tests/streaming.rs @@ -31,7 +31,7 @@ use common::{ }, orchestrator::{ SseStream, TestOrchestratorServer, ORCHESTRATOR_CONFIG_FILE_PATH, - ORCHESTRATOR_STREAMING_ENDPOINT, + ORCHESTRATOR_INTERNAL_SERVER_ERROR_MESSAGE, ORCHESTRATOR_STREAMING_ENDPOINT, }, }; use fms_guardrails_orchestr8::{ @@ -887,7 +887,7 @@ async fn test_input_detector_returns_500() -> Result<(), anyhow::Error> { // assertions assert!(messages.len() == 1); assert!(messages[0].code == 500); - assert!(messages[0].details == "unexpected error occurred while processing request"); + assert!(messages[0].details == ORCHESTRATOR_INTERNAL_SERVER_ERROR_MESSAGE); Ok(()) } @@ -961,7 +961,79 @@ async fn test_input_detector_returns_non_compliant_message() -> Result<(), anyho // assertions assert!(messages.len() == 1); assert!(messages[0].code == 500); - assert!(messages[0].details == "unexpected error occurred while processing request"); + assert!(messages[0].details == ORCHESTRATOR_INTERNAL_SERVER_ERROR_MESSAGE); + + Ok(()) +} + +#[test(tokio::test)] +async fn test_input_chunker_returns_an_error() -> Result<(), anyhow::Error> { + let detector_name = DETECTOR_NAME_ANGLE_BRACKETS_SENTENCE; + let model_id = "my-super-model-8B"; + + // Add input chunker mock + let chunker_id = CHUNKER_NAME_SENTENCE; + let mut chunker_headers = HeaderMap::new(); + chunker_headers.insert(CHUNKER_MODEL_ID_HEADER_NAME, chunker_id.parse()?); + + let mut chunker_mocks = MockSet::new(); + chunker_mocks.insert( + MockPath::new(Method::POST, CHUNKER_UNARY_ENDPOINT), + Mock::new( + MockRequest::pb(ChunkerTokenizationTaskRequest { + text: "Hi there! How are you?".into(), + }) + .with_headers(chunker_headers), + MockResponse::empty().with_code(StatusCode::INTERNAL_SERVER_ERROR), + ), + ); + + // Start orchestrator server and its dependencies + let mock_chunker_server = MockChunkersServiceServer::new(chunker_mocks)?; + let orchestrator_server = TestOrchestratorServer::run( + ORCHESTRATOR_CONFIG_FILE_PATH, + find_available_port().unwrap(), + find_available_port().unwrap(), + None, + None, + None, + Some(vec![(chunker_id.into(), mock_chunker_server)]), + ) + .await?; + + // Example orchestrator request with streaming response + let response = orchestrator_server + .post(ORCHESTRATOR_STREAMING_ENDPOINT) + .json(&GuardrailsHttpRequest { + model_id: model_id.into(), + inputs: "Hi there! 
How are you?".into(), + guardrail_config: Some(GuardrailsConfig { + input: Some(GuardrailsConfigInput { + models: HashMap::from([(detector_name.into(), DetectorParams::new())]), + masks: None, + }), + output: None, + }), + text_gen_parameters: None, + }) + .send() + .await?; + + debug!("{response:#?}"); + + // Test custom SseStream wrapper + let sse_stream: SseStream = SseStream::new(response.bytes_stream()); + let messages = sse_stream + .collect::>() + .await + .into_iter() + .collect::, anyhow::Error>>()?; + debug!("{messages:#?}"); + + // assertions + assert!(messages.len() == 1); + assert!(messages[0].code == StatusCode::INTERNAL_SERVER_ERROR); + assert!(messages[0].details == ORCHESTRATOR_INTERNAL_SERVER_ERROR_MESSAGE); Ok(()) } From a7ced7ee015d733a36df3f312fef887faa79c8cb Mon Sep 17 00:00:00 2001 From: Mateus Devino Date: Wed, 19 Feb 2025 14:31:56 -0300 Subject: [PATCH 046/117] test case: test_generation_server_returns_an_error() Signed-off-by: Mateus Devino --- tests/streaming.rs | 84 ++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 84 insertions(+) diff --git a/tests/streaming.rs b/tests/streaming.rs index f18895c2..3a465705 100644 --- a/tests/streaming.rs +++ b/tests/streaming.rs @@ -1037,3 +1037,87 @@ async fn test_input_chunker_returns_an_error() -> Result<(), anyhow::Error> { Ok(()) } + +#[test(tokio::test)] +async fn test_generation_server_returns_an_error() -> Result<(), anyhow::Error> { + let detector_name = DETECTOR_NAME_ANGLE_BRACKETS_WHOLE_DOC; + + // Add input detection mock + let mut detection_mocks = MockSet::new(); + detection_mocks.insert( + MockPath::new(Method::POST, TEXT_CONTENTS_DETECTOR_ENDPOINT), + Mock::new( + MockRequest::json(ContentAnalysisRequest { + contents: vec!["Hi there! How are you?".into()], + detector_params: DetectorParams::new(), + }), + MockResponse::json([Vec::::new()]), + ), + ); + + // Add generation mock + let model_id = "my-super-model-8B"; + let mut headers = HeaderMap::new(); + headers.insert(NLP_MODEL_ID_HEADER_NAME, model_id.parse().unwrap()); + + let mut generation_mocks = MockSet::new(); + generation_mocks.insert( + MockPath::new(Method::POST, GENERATION_NLP_STREAMING_ENDPOINT), + Mock::new( + MockRequest::pb(ServerStreamingTextGenerationTaskRequest { + text: "Hi there! How are you?".into(), + ..Default::default() + }) + .with_headers(headers.clone()), + MockResponse::empty().with_code(StatusCode::INTERNAL_SERVER_ERROR), + ), + ); + + // Start orchestrator server and its dependencies + let mock_detector_server = HttpMockServer::new(detector_name, detection_mocks)?; + let generation_server = MockNlpServiceServer::new(generation_mocks)?; + let orchestrator_server = TestOrchestratorServer::run( + ORCHESTRATOR_CONFIG_FILE_PATH, + find_available_port().unwrap(), + find_available_port().unwrap(), + Some(generation_server), + None, + Some(vec![mock_detector_server]), + None, + ) + .await?; + + // Example orchestrator request with streaming response + let response = orchestrator_server + .post(ORCHESTRATOR_STREAMING_ENDPOINT) + .json(&GuardrailsHttpRequest { + model_id: model_id.into(), + inputs: "Hi there! 
How are you?".into(), + guardrail_config: Some(GuardrailsConfig { + input: Some(GuardrailsConfigInput { + models: HashMap::from([(detector_name.into(), DetectorParams::new())]), + masks: None, + }), + output: None, + }), + text_gen_parameters: None, + }) + .send() + .await?; + + // Test custom SseStream wrapper + let sse_stream: SseStream = SseStream::new(response.bytes_stream()); + let messages = sse_stream + .collect::>() + .await + .into_iter() + .collect::, anyhow::Error>>()?; + debug!("{messages:#?}"); + + // assertions + assert!(messages.len() == 1); + assert!(messages[0].code == StatusCode::INTERNAL_SERVER_ERROR); + assert!(messages[0].details == ORCHESTRATOR_INTERNAL_SERVER_ERROR_MESSAGE); + + Ok(()) +} From f3a7c940bc0e9c50a0b48843509a19724a2283d1 Mon Sep 17 00:00:00 2001 From: Mateus Devino Date: Wed, 19 Feb 2025 14:56:03 -0300 Subject: [PATCH 047/117] test case: test_orchestrator_receives_a_non_compliant_request() Signed-off-by: Mateus Devino --- tests/streaming.rs | 45 +++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 45 insertions(+) diff --git a/tests/streaming.rs b/tests/streaming.rs index 3a465705..125725f5 100644 --- a/tests/streaming.rs +++ b/tests/streaming.rs @@ -1121,3 +1121,48 @@ async fn test_generation_server_returns_an_error() -> Result<(), anyhow::Error> Ok(()) } + +#[test(tokio::test)] +async fn test_orchestrator_receives_a_non_compliant_request() -> Result<(), anyhow::Error> { + let model_id = "my-super-model-8B"; + + // Run test orchestrator server + let orchestrator_server = TestOrchestratorServer::run( + ORCHESTRATOR_CONFIG_FILE_PATH, + find_available_port().unwrap(), + find_available_port().unwrap(), + None, + None, + None, + None, + ) + .await?; + + // Example orchestrator request with streaming response + let response = orchestrator_server + .post(ORCHESTRATOR_STREAMING_ENDPOINT) + .json(&serde_json::json!({ + "model_id": model_id, + "inputs": "This request does not comply with the orchestrator API", + "guardrail_config": { + "inputs": {}, + "outputs": {} + }, + "non_existing_field": "random value" + })) + .send() + .await?; + + debug!(?response); + + // assertions + assert!(response.status() == StatusCode::UNPROCESSABLE_ENTITY); + + let response_body = response.json::().await?; + assert!(response_body.code == StatusCode::UNPROCESSABLE_ENTITY); + assert!(response_body + .details + .starts_with("non_existing_field: unknown field `non_existing_field`")); + + Ok(()) +} From a3736e9b5b1b03cff437e9ebf35a63d189ee36ab Mon Sep 17 00:00:00 2001 From: Mateus Devino Date: Mon, 24 Feb 2025 18:43:03 -0300 Subject: [PATCH 048/117] Add grpc_dns_probe_interval tests Signed-off-by: Mateus Devino --- tests/chunker.rs | 1 + tests/generation_nlp.rs | 1 + 2 files changed, 2 insertions(+) diff --git a/tests/chunker.rs b/tests/chunker.rs index 7fb2e812..75d7404f 100644 --- a/tests/chunker.rs +++ b/tests/chunker.rs @@ -77,6 +77,7 @@ async fn test_isolated_chunker_unary_call() -> Result<(), anyhow::Error> { port: Some(mock_chunker_server.addr().port()), request_timeout: None, tls: None, + grpc_dns_probe_interval: None, }) .await; diff --git a/tests/generation_nlp.rs b/tests/generation_nlp.rs index 493f651b..4f6145b7 100644 --- a/tests/generation_nlp.rs +++ b/tests/generation_nlp.rs @@ -73,6 +73,7 @@ async fn test_nlp_streaming_call() -> Result<(), anyhow::Error> { port: Some(generation_nlp_server.addr().port()), request_timeout: None, tls: None, + grpc_dns_probe_interval: None, }) .await; From 7e499d2415ddc6437169b59064db4c715750a0ef Mon Sep 17 00:00:00 2001 From: 
Mateus Devino Date: Mon, 24 Feb 2025 18:47:18 -0300 Subject: [PATCH 049/117] Rename streaming tests file Signed-off-by: Mateus Devino --- tests/{streaming.rs => streaming_classification_with_gen_nlp.rs} | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename tests/{streaming.rs => streaming_classification_with_gen_nlp.rs} (100%) diff --git a/tests/streaming.rs b/tests/streaming_classification_with_gen_nlp.rs similarity index 100% rename from tests/streaming.rs rename to tests/streaming_classification_with_gen_nlp.rs From 824611685f2ff2a199c97d59e59b7b5cb7c7bd50 Mon Sep 17 00:00:00 2001 From: Mateus Devino Date: Mon, 24 Feb 2025 18:57:38 -0300 Subject: [PATCH 050/117] Create test-specific header name constants Signed-off-by: Mateus Devino --- src/clients/nlp.rs | 2 +- tests/common/chunker.rs | 2 + tests/common/generation.rs | 2 + tests/generation_nlp.rs | 11 +++-- .../streaming_classification_with_gen_nlp.rs | 45 +++++++++++++------ 5 files changed, 44 insertions(+), 18 deletions(-) diff --git a/src/clients/nlp.rs b/src/clients/nlp.rs index 6bb60096..e2651851 100644 --- a/src/clients/nlp.rs +++ b/src/clients/nlp.rs @@ -44,7 +44,7 @@ use crate::{ }; const DEFAULT_PORT: u16 = 8085; -pub const MODEL_ID_HEADER_NAME: &str = "mm-model-id"; +const MODEL_ID_HEADER_NAME: &str = "mm-model-id"; #[cfg_attr(test, faux::create)] #[derive(Clone)] diff --git a/tests/common/chunker.rs b/tests/common/chunker.rs index 8678196d..05042055 100644 --- a/tests/common/chunker.rs +++ b/tests/common/chunker.rs @@ -28,3 +28,5 @@ pub const CHUNKER_NAME_SENTENCE: &str = "sentence_chunker"; // Chunker endpoints pub const CHUNKER_UNARY_ENDPOINT: &str = "/caikit.runtime.Chunkers.ChunkersService/ChunkerTokenizationTaskPredict"; + +pub const CHUNKER_MODEL_ID_HEADER_NAME: &str = "mm-model-id"; diff --git a/tests/common/generation.rs b/tests/common/generation.rs index e6365fbd..a68a127c 100644 --- a/tests/common/generation.rs +++ b/tests/common/generation.rs @@ -24,3 +24,5 @@ pub const GENERATION_NLP_STREAMING_ENDPOINT: &str = "/caikit.runtime.Nlp.NlpService/ServerStreamingTextGenerationTaskPredict"; pub const GENERATION_NLP_TOKENIZATION_ENDPOINT: &str = "/caikit.runtime.Nlp.NlpService/TokenizationTaskPredict"; + +pub const GENERATION_NLP_MODEL_ID_HEADER_NAME: &str = "mm-model-id"; diff --git a/tests/generation_nlp.rs b/tests/generation_nlp.rs index 4f6145b7..2e4234b5 100644 --- a/tests/generation_nlp.rs +++ b/tests/generation_nlp.rs @@ -15,9 +15,11 @@ */ -use common::generation::{MockNlpServiceServer, GENERATION_NLP_STREAMING_ENDPOINT}; +use common::generation::{ + MockNlpServiceServer, GENERATION_NLP_MODEL_ID_HEADER_NAME, GENERATION_NLP_STREAMING_ENDPOINT, +}; use fms_guardrails_orchestr8::{ - clients::{nlp::MODEL_ID_HEADER_NAME, NlpClient}, + clients::NlpClient, config::ServiceConfig, pb::{ caikit::runtime::nlp::ServerStreamingTextGenerationTaskRequest, @@ -35,7 +37,10 @@ async fn test_nlp_streaming_call() -> Result<(), anyhow::Error> { // Add detector mock let model_id = "my-super-model-8B"; let mut headers = HeaderMap::new(); - headers.insert(MODEL_ID_HEADER_NAME, model_id.parse().unwrap()); + headers.insert( + GENERATION_NLP_MODEL_ID_HEADER_NAME, + model_id.parse().unwrap(), + ); let expected_response = vec![ GeneratedTextStreamResult { diff --git a/tests/streaming_classification_with_gen_nlp.rs b/tests/streaming_classification_with_gen_nlp.rs index 125725f5..3f217c42 100644 --- a/tests/streaming_classification_with_gen_nlp.rs +++ b/tests/streaming_classification_with_gen_nlp.rs @@ -19,15 +19,18 @@ use 
std::collections::HashMap; use test_log::test; use common::{ - chunker::{MockChunkersServiceServer, CHUNKER_NAME_SENTENCE, CHUNKER_UNARY_ENDPOINT}, + chunker::{ + MockChunkersServiceServer, CHUNKER_MODEL_ID_HEADER_NAME, CHUNKER_NAME_SENTENCE, + CHUNKER_UNARY_ENDPOINT, + }, detectors::{ DETECTOR_NAME_ANGLE_BRACKETS_SENTENCE, DETECTOR_NAME_ANGLE_BRACKETS_WHOLE_DOC, TEXT_CONTENTS_DETECTOR_ENDPOINT, }, errors::{DetectorError, OrchestratorError}, generation::{ - MockNlpServiceServer, GENERATION_NLP_STREAMING_ENDPOINT, - GENERATION_NLP_TOKENIZATION_ENDPOINT, + MockNlpServiceServer, GENERATION_NLP_MODEL_ID_HEADER_NAME, + GENERATION_NLP_STREAMING_ENDPOINT, GENERATION_NLP_TOKENIZATION_ENDPOINT, }, orchestrator::{ SseStream, TestOrchestratorServer, ORCHESTRATOR_CONFIG_FILE_PATH, @@ -35,11 +38,7 @@ use common::{ }, }; use fms_guardrails_orchestr8::{ - clients::{ - chunker::MODEL_ID_HEADER_NAME as CHUNKER_MODEL_ID_HEADER_NAME, - detector::{ContentAnalysisRequest, ContentAnalysisResponse}, - nlp::MODEL_ID_HEADER_NAME as NLP_MODEL_ID_HEADER_NAME, - }, + clients::detector::{ContentAnalysisRequest, ContentAnalysisResponse}, models::{ ClassifiedGeneratedTextStreamResult, DetectionWarning, DetectorParams, GuardrailsConfig, GuardrailsConfigInput, GuardrailsHttpRequest, TextGenTokenClassificationResults, @@ -64,7 +63,10 @@ async fn test_no_detectors() -> Result<(), anyhow::Error> { // Add generation mock let model_id = "my-super-model-8B"; let mut headers = HeaderMap::new(); - headers.insert(NLP_MODEL_ID_HEADER_NAME, model_id.parse().unwrap()); + headers.insert( + GENERATION_NLP_MODEL_ID_HEADER_NAME, + model_id.parse().unwrap(), + ); let expected_response = vec![ GeneratedTextStreamResult { @@ -179,7 +181,10 @@ async fn test_input_detector_whole_doc_no_detections() -> Result<(), anyhow::Err // Add generation mock let model_id = "my-super-model-8B"; let mut headers = HeaderMap::new(); - headers.insert(NLP_MODEL_ID_HEADER_NAME, model_id.parse().unwrap()); + headers.insert( + GENERATION_NLP_MODEL_ID_HEADER_NAME, + model_id.parse().unwrap(), + ); let expected_response = vec![ GeneratedTextStreamResult { @@ -314,7 +319,10 @@ async fn test_input_detector_sentence_chunker_no_detections() -> Result<(), anyh // Add generation mock let model_id = "my-super-model-8B"; let mut headers = HeaderMap::new(); - headers.insert(NLP_MODEL_ID_HEADER_NAME, model_id.parse().unwrap()); + headers.insert( + GENERATION_NLP_MODEL_ID_HEADER_NAME, + model_id.parse().unwrap(), + ); let expected_response = vec![ GeneratedTextStreamResult { @@ -434,7 +442,10 @@ async fn test_input_detector_whole_doc_with_detections() -> Result<(), anyhow::E token_count: 61, }; let mut headers = HeaderMap::new(); - headers.insert(NLP_MODEL_ID_HEADER_NAME, model_id.parse().unwrap()); + headers.insert( + GENERATION_NLP_MODEL_ID_HEADER_NAME, + model_id.parse().unwrap(), + ); let mut generation_mocks = MockSet::new(); generation_mocks.insert( MockPath::new(Method::POST, GENERATION_NLP_TOKENIZATION_ENDPOINT), @@ -582,7 +593,10 @@ async fn test_input_detector_sentence_chunker_with_detections() -> Result<(), an token_count: 61, }; let mut headers = HeaderMap::new(); - headers.insert(NLP_MODEL_ID_HEADER_NAME, model_id.parse().unwrap()); + headers.insert( + GENERATION_NLP_MODEL_ID_HEADER_NAME, + model_id.parse().unwrap(), + ); let mut generation_mocks = MockSet::new(); generation_mocks.insert( MockPath::new(Method::POST, GENERATION_NLP_TOKENIZATION_ENDPOINT), @@ -1058,7 +1072,10 @@ async fn test_generation_server_returns_an_error() -> Result<(), anyhow::Error> // Add 
generation mock let model_id = "my-super-model-8B"; let mut headers = HeaderMap::new(); - headers.insert(NLP_MODEL_ID_HEADER_NAME, model_id.parse().unwrap()); + headers.insert( + GENERATION_NLP_MODEL_ID_HEADER_NAME, + model_id.parse().unwrap(), + ); let mut generation_mocks = MockSet::new(); generation_mocks.insert( From 6f7a0d6febe526a6a02f4bee945e77dec400b21a Mon Sep 17 00:00:00 2001 From: Mateus Devino Date: Mon, 24 Feb 2025 19:01:56 -0300 Subject: [PATCH 051/117] Undo src/server.rs changes Signed-off-by: Mateus Devino --- src/server.rs | 106 +++++++++++++++++++++++++------------------------- 1 file changed, 53 insertions(+), 53 deletions(-) diff --git a/src/server.rs b/src/server.rs index d9bbe93e..28d9ad93 100644 --- a/src/server.rs +++ b/src/server.rs @@ -166,7 +166,59 @@ pub async fn run( } // (2b) Add main guardrails server routes - let app = get_app(shared_state); + let mut router = Router::new() + .route( + &format!("{}/classification-with-text-generation", API_PREFIX), + post(classification_with_gen), + ) + .route( + &format!("{}/detection/stream-content", TEXT_API_PREFIX), + post(stream_content_detection), + ) + .route( + &format!( + "{}/server-streaming-classification-with-text-generation", + API_PREFIX + ), + post(stream_classification_with_gen), + ) + .route( + &format!("{}/generation-detection", TEXT_API_PREFIX), + post(generation_with_detection), + ) + .route( + &format!("{}/detection/content", TEXT_API_PREFIX), + post(detection_content), + ) + .route( + &format!("{}/detection/chat", TEXT_API_PREFIX), + post(detect_chat), + ) + .route( + &format!("{}/detection/context", TEXT_API_PREFIX), + post(detect_context_documents), + ) + .route( + &format!("{}/detection/generated", TEXT_API_PREFIX), + post(detect_generated), + ); + + // If chat generation is configured, enable the chat completions detection endpoint. + if shared_state.orchestrator.config().chat_generation.is_some() { + info!("Enabling chat completions detection endpoint"); + router = router.route( + "/api/v2/chat/completions-detection", + post(chat_completions_detection), + ); + } + + let app = router.with_state(shared_state).layer( + TraceLayer::new_for_http() + .make_span_with(utils::trace::incoming_request_span) + .on_request(utils::trace::on_incoming_request) + .on_response(utils::trace::on_outgoing_response) + .on_eos(utils::trace::on_outgoing_eos), + ); // (2c) Generate main guardrails server handle based on whether TLS is needed let listener: TcpListener = TcpListener::bind(&http_addr) @@ -271,58 +323,6 @@ pub fn get_health_app(state: Arc) -> Router { .with_state(state) } -pub fn get_app(state: Arc) -> Router { - let mut router = Router::new() - .route( - &format!("{}/classification-with-text-generation", API_PREFIX), - post(classification_with_gen), - ) - .route( - &format!( - "{}/server-streaming-classification-with-text-generation", - API_PREFIX - ), - post(stream_classification_with_gen), - ) - .route( - &format!("{}/generation-detection", TEXT_API_PREFIX), - post(generation_with_detection), - ) - .route( - &format!("{}/detection/content", TEXT_API_PREFIX), - post(detection_content), - ) - .route( - &format!("{}/detection/chat", TEXT_API_PREFIX), - post(detect_chat), - ) - .route( - &format!("{}/detection/context", TEXT_API_PREFIX), - post(detect_context_documents), - ) - .route( - &format!("{}/detection/generated", TEXT_API_PREFIX), - post(detect_generated), - ); - - // If chat generation is configured, enable the chat completions detection endpoint. 
- if state.orchestrator.config().chat_generation.is_some() { - info!("Enabling chat completions detection endpoint"); - router = router.route( - "/api/v2/chat/completions-detection", - post(chat_completions_detection), - ); - } - - router.with_state(state).layer( - TraceLayer::new_for_http() - .make_span_with(utils::trace::incoming_request_span) - .on_request(utils::trace::on_incoming_request) - .on_response(utils::trace::on_outgoing_response) - .on_eos(utils::trace::on_outgoing_eos), - ) -} - async fn health() -> Result { // NOTE: we are only adding the package information in the `health` endpoint to have this endpoint // provide a non empty 200 response. If we need to add more information regarding dependencies version From 49c3fea905eb5f866d30b8a5e1d58894b32cf76c Mon Sep 17 00:00:00 2001 From: Mateus Devino Date: Mon, 24 Feb 2025 19:09:53 -0300 Subject: [PATCH 052/117] Fix test_input_detector_returns_503() detector mock response status Signed-off-by: Mateus Devino --- tests/streaming_classification_with_gen_nlp.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/streaming_classification_with_gen_nlp.rs b/tests/streaming_classification_with_gen_nlp.rs index 3f217c42..64716df9 100644 --- a/tests/streaming_classification_with_gen_nlp.rs +++ b/tests/streaming_classification_with_gen_nlp.rs @@ -696,7 +696,7 @@ async fn test_input_detector_returns_503() -> Result<(), anyhow::Error> { contents: vec!["This should return a 503".into()], detector_params: DetectorParams::new(), }), - MockResponse::json(&expected_detector_error).with_code(StatusCode::NOT_FOUND), + MockResponse::json(&expected_detector_error).with_code(StatusCode::SERVICE_UNAVAILABLE), ), ); From 6c5d5123e686ebfce45227f3a37f205af67836e7 Mon Sep 17 00:00:00 2001 From: Mateus Devino Date: Tue, 25 Feb 2025 13:57:26 -0300 Subject: [PATCH 053/117] refactor: move input to variable in tests/chunker.rs Signed-off-by: Mateus Devino --- tests/chunker.rs | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/tests/chunker.rs b/tests/chunker.rs index 75d7404f..197da90a 100644 --- a/tests/chunker.rs +++ b/tests/chunker.rs @@ -33,6 +33,7 @@ pub mod common; async fn test_isolated_chunker_unary_call() -> Result<(), anyhow::Error> { // Add detector mock let chunker_id = "sentence_chunker"; + let input_test = "Hi there! how are you? I am great!"; let mut chunker_headers = HeaderMap::new(); chunker_headers.insert(MODEL_ID_HEADER_NAME, chunker_id.parse().unwrap()); @@ -62,7 +63,7 @@ async fn test_isolated_chunker_unary_call() -> Result<(), anyhow::Error> { MockPath::new(Method::POST, CHUNKER_UNARY_ENDPOINT), Mock::new( MockRequest::pb(ChunkerTokenizationTaskRequest { - text: "Hi there! how are you? I am great!".into(), + text: input_test.into(), }) .with_headers(chunker_headers), MockResponse::pb(expected_response.clone()), @@ -85,7 +86,7 @@ async fn test_isolated_chunker_unary_call() -> Result<(), anyhow::Error> { .tokenization_task_predict( chunker_id, ChunkerTokenizationTaskRequest { - text: "Hi there! how are you? 
I am great!".into(), + text: input_test.into(), }, ) .await; From 817945c7119383a5ad6b84cf6e1d50d83f5f53cb Mon Sep 17 00:00:00 2001 From: Mateus Devino Date: Thu, 27 Feb 2025 09:59:18 -0300 Subject: [PATCH 054/117] Document chunker::test_isolated_chunker_unary_call() Signed-off-by: Mateus Devino --- tests/chunker.rs | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/tests/chunker.rs b/tests/chunker.rs index 197da90a..ec8fcdd2 100644 --- a/tests/chunker.rs +++ b/tests/chunker.rs @@ -29,11 +29,13 @@ use test_log::test; pub mod common; +/// Asserts that the chunker client correctly invokes the chunker unary +/// endpoint. #[test(tokio::test)] async fn test_isolated_chunker_unary_call() -> Result<(), anyhow::Error> { // Add detector mock let chunker_id = "sentence_chunker"; - let input_test = "Hi there! how are you? I am great!"; + let input_text = "Hi there! how are you? I am great!"; let mut chunker_headers = HeaderMap::new(); chunker_headers.insert(MODEL_ID_HEADER_NAME, chunker_id.parse().unwrap()); @@ -63,7 +65,7 @@ async fn test_isolated_chunker_unary_call() -> Result<(), anyhow::Error> { MockPath::new(Method::POST, CHUNKER_UNARY_ENDPOINT), Mock::new( MockRequest::pb(ChunkerTokenizationTaskRequest { - text: input_test.into(), + text: input_text.into(), }) .with_headers(chunker_headers), MockResponse::pb(expected_response.clone()), @@ -86,7 +88,7 @@ async fn test_isolated_chunker_unary_call() -> Result<(), anyhow::Error> { .tokenization_task_predict( chunker_id, ChunkerTokenizationTaskRequest { - text: input_test.into(), + text: input_text.into(), }, ) .await; From 69d3363a1439add7ca624b800fd5f42502057699 Mon Sep 17 00:00:00 2001 From: Mateus Devino Date: Thu, 27 Feb 2025 10:02:32 -0300 Subject: [PATCH 055/117] Rename tests/test.config.yaml Signed-off-by: Mateus Devino --- tests/canary_test.rs | 2 +- tests/common/orchestrator.rs | 2 +- tests/{test.config.yaml => test_config.yaml} | 0 3 files changed, 2 insertions(+), 2 deletions(-) rename tests/{test.config.yaml => test_config.yaml} (100%) diff --git a/tests/canary_test.rs b/tests/canary_test.rs index f7412c45..19c37b3c 100644 --- a/tests/canary_test.rs +++ b/tests/canary_test.rs @@ -41,7 +41,7 @@ static ONCE: OnceCell> = OnceCell::const_new(); /// The actual async function that initializes the shared state if not already initialized async fn shared_state() -> Arc { - let config = OrchestratorConfig::load("tests/test.config.yaml") + let config = OrchestratorConfig::load("tests/test_config.yaml") .await .unwrap(); let orchestrator = Orchestrator::new(config, false).await.unwrap(); diff --git a/tests/common/orchestrator.rs b/tests/common/orchestrator.rs index 432c888b..6251e2f7 100644 --- a/tests/common/orchestrator.rs +++ b/tests/common/orchestrator.rs @@ -37,7 +37,7 @@ use url::Url; use super::{chunker::MockChunkersServiceServer, generation::MockNlpServiceServer}; // Default orchestrator configuration file for integration tests. 
-pub const ORCHESTRATOR_CONFIG_FILE_PATH: &str = "tests/test.config.yaml"; +pub const ORCHESTRATOR_CONFIG_FILE_PATH: &str = "tests/test_config.yaml"; // Endpoints pub const ORCHESTRATOR_STREAMING_ENDPOINT: &str = diff --git a/tests/test.config.yaml b/tests/test_config.yaml similarity index 100% rename from tests/test.config.yaml rename to tests/test_config.yaml From 85fd2816428212d092995dc283f5c83a642d1874 Mon Sep 17 00:00:00 2001 From: Mateus Devino Date: Thu, 27 Feb 2025 10:23:25 -0300 Subject: [PATCH 056/117] Document test cases on comments Signed-off-by: Mateus Devino --- tests/generation_nlp.rs | 1 + .../streaming_classification_with_gen_nlp.rs | 64 ++++++++++++------- 2 files changed, 43 insertions(+), 22 deletions(-) diff --git a/tests/generation_nlp.rs b/tests/generation_nlp.rs index 2e4234b5..26c40065 100644 --- a/tests/generation_nlp.rs +++ b/tests/generation_nlp.rs @@ -32,6 +32,7 @@ use test_log::test; pub mod common; +/// Asserts that the NlpClient correctly invokes the streaming endpoint. #[test(tokio::test)] async fn test_nlp_streaming_call() -> Result<(), anyhow::Error> { // Add detector mock diff --git a/tests/streaming_classification_with_gen_nlp.rs b/tests/streaming_classification_with_gen_nlp.rs index 64716df9..b0bf83c0 100644 --- a/tests/streaming_classification_with_gen_nlp.rs +++ b/tests/streaming_classification_with_gen_nlp.rs @@ -58,6 +58,29 @@ use tracing::debug; pub mod common; +// To troubleshoot tests with response deserialization errors, the following code +// snippet is recommended: +// // Example showing how to create an event stream from a bytes stream. +// let mut events = Vec::new(); +// let mut event_stream = response.bytes_stream().eventsource(); +// while let Some(event) = event_stream.next().await { +// match event { +// Ok(event) => { +// if event.data == "[DONE]" { +// break; +// } +// println!("recv: {event:?}"); +// events.push(event.data); +// } +// Err(_) => { +// panic!("received error from event stream"); +// } +// } +// } +// println!("{events:?}"); + +/// Asserts that given a request with no detectors configured returns the text generated +/// by the model. #[test(tokio::test)] async fn test_no_detectors() -> Result<(), anyhow::Error> { // Add generation mock @@ -123,26 +146,7 @@ async fn test_no_detectors() -> Result<(), anyhow::Error> { .send() .await?; - // // Example showing how to create an event stream from a bytes stream. - // let mut events = Vec::new(); - // let mut event_stream = response.bytes_stream().eventsource(); - // while let Some(event) = event_stream.next().await { - // match event { - // Ok(event) => { - // if event.data == "[DONE]" { - // break; - // } - // println!("recv: {event:?}"); - // events.push(event.data); - // } - // Err(_) => { - // panic!("received error from event stream"); - // } - // } - // } - // println!("{events:?}"); - - // Test custom SseStream wrapper + // Collects stream results let sse_stream: SseStream = SseStream::new(response.bytes_stream()); let messages = sse_stream @@ -161,6 +165,8 @@ async fn test_no_detectors() -> Result<(), anyhow::Error> { Ok(()) } +/// Asserts that the generated text is returned when an input detector configured +/// with the whole_doc_chunker finds no detections. 
#[test(tokio::test)] async fn test_input_detector_whole_doc_no_detections() -> Result<(), anyhow::Error> { let detector_name = DETECTOR_NAME_ANGLE_BRACKETS_WHOLE_DOC; @@ -265,6 +271,8 @@ async fn test_input_detector_whole_doc_no_detections() -> Result<(), anyhow::Err Ok(()) } +/// Asserts that the generated text is returned when an input detector configured +/// with a sentence chunker finds no detections. #[test(tokio::test)] async fn test_input_detector_sentence_chunker_no_detections() -> Result<(), anyhow::Error> { let detector_name = DETECTOR_NAME_ANGLE_BRACKETS_SENTENCE; @@ -406,6 +414,8 @@ async fn test_input_detector_sentence_chunker_no_detections() -> Result<(), anyh Ok(()) } +/// Asserts that detections found by an input detector configured with the whole_doc_chunker +/// are returned. #[test(tokio::test)] async fn test_input_detector_whole_doc_with_detections() -> Result<(), anyhow::Error> { let detector_name = DETECTOR_NAME_ANGLE_BRACKETS_WHOLE_DOC; @@ -526,6 +536,8 @@ async fn test_input_detector_whole_doc_with_detections() -> Result<(), anyhow::E Ok(()) } +/// Asserts that detections found by an input detector configured with a sentence chunker +/// are returned. #[test(tokio::test)] async fn test_input_detector_sentence_chunker_with_detections() -> Result<(), anyhow::Error> { // Add chunker mock @@ -678,6 +690,7 @@ async fn test_input_detector_sentence_chunker_with_detections() -> Result<(), an Ok(()) } +/// Asserts that 503 errors returned from detectors are correctly propagated. #[test(tokio::test)] async fn test_input_detector_returns_503() -> Result<(), anyhow::Error> { let detector_name = DETECTOR_NAME_ANGLE_BRACKETS_WHOLE_DOC; @@ -756,6 +769,7 @@ async fn test_input_detector_returns_503() -> Result<(), anyhow::Error> { Ok(()) } +/// Asserts that 404 errors returned from detectors are correctly propagated. #[test(tokio::test)] async fn test_input_detector_returns_404() -> Result<(), anyhow::Error> { let detector_name = DETECTOR_NAME_ANGLE_BRACKETS_WHOLE_DOC; @@ -834,6 +848,7 @@ async fn test_input_detector_returns_404() -> Result<(), anyhow::Error> { Ok(()) } +/// Asserts that 500 errors returned from detectors are correctly propagated. #[test(tokio::test)] async fn test_input_detector_returns_500() -> Result<(), anyhow::Error> { let detector_name = DETECTOR_NAME_ANGLE_BRACKETS_WHOLE_DOC; @@ -906,8 +921,10 @@ async fn test_input_detector_returns_500() -> Result<(), anyhow::Error> { Ok(()) } +/// Asserts error 500 is returned when a detector returns a message that does not comply +/// with the detector API. #[test(tokio::test)] -async fn test_input_detector_returns_non_compliant_message() -> Result<(), anyhow::Error> { +async fn test_input_detector_returns_invalid_message() -> Result<(), anyhow::Error> { // ensure_global_rustls_state(); let detector_name = DETECTOR_NAME_ANGLE_BRACKETS_WHOLE_DOC; let model_id = "my-super-model-8B"; @@ -980,6 +997,7 @@ async fn test_input_detector_returns_non_compliant_message() -> Result<(), anyho Ok(()) } +/// Asserts error 500 is returned when an input chunker returns an error. #[test(tokio::test)] async fn test_input_chunker_returns_an_error() -> Result<(), anyhow::Error> { let detector_name = DETECTOR_NAME_ANGLE_BRACKETS_SENTENCE; @@ -1052,6 +1070,7 @@ async fn test_input_chunker_returns_an_error() -> Result<(), anyhow::Error> { Ok(()) } +// Asserts error 500 is returned when generation server returns an error. 
#[test(tokio::test)] async fn test_generation_server_returns_an_error() -> Result<(), anyhow::Error> { let detector_name = DETECTOR_NAME_ANGLE_BRACKETS_WHOLE_DOC; @@ -1139,8 +1158,9 @@ async fn test_generation_server_returns_an_error() -> Result<(), anyhow::Error> Ok(()) } +/// Asserts error 422 is returned when the orchestrator request has extra fields. #[test(tokio::test)] -async fn test_orchestrator_receives_a_non_compliant_request() -> Result<(), anyhow::Error> { +async fn test_request_with_extra_fields_returns_422() -> Result<(), anyhow::Error> { let model_id = "my-super-model-8B"; // Run test orchestrator server From e5390d2a05f00b25b2ac3b19e9ceae0adaa07233 Mon Sep 17 00:00:00 2001 From: Mateus Devino Date: Thu, 27 Feb 2025 10:37:35 -0300 Subject: [PATCH 057/117] Add assertions to make sure no detections are returned Signed-off-by: Mateus Devino --- tests/streaming_classification_with_gen_nlp.rs | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/tests/streaming_classification_with_gen_nlp.rs b/tests/streaming_classification_with_gen_nlp.rs index b0bf83c0..3c2ef7a0 100644 --- a/tests/streaming_classification_with_gen_nlp.rs +++ b/tests/streaming_classification_with_gen_nlp.rs @@ -265,8 +265,13 @@ async fn test_input_detector_whole_doc_no_detections() -> Result<(), anyhow::Err // assertions assert!(messages.len() == 3); assert!(messages[0].generated_text == Some("I".into())); + assert!(messages[0].token_classification_results.input == None); + assert!(messages[1].generated_text == Some(" am".into())); + assert!(messages[1].token_classification_results.input == None); + assert!(messages[2].generated_text == Some(" great!".into())); + assert!(messages[2].token_classification_results.input == None); Ok(()) } @@ -408,8 +413,13 @@ async fn test_input_detector_sentence_chunker_no_detections() -> Result<(), anyh // assertions assert!(messages.len() == 3); assert!(messages[0].generated_text == Some("I".into())); + assert!(messages[0].token_classification_results.input == None); + assert!(messages[1].generated_text == Some(" am".into())); + assert!(messages[1].token_classification_results.input == None); + assert!(messages[2].generated_text == Some(" great!".into())); + assert!(messages[2].token_classification_results.input == None); Ok(()) } From 6c46d6cd724f3f8e7f4f9f933e8b9228be7ce0ad Mon Sep 17 00:00:00 2001 From: Mateus Devino Date: Fri, 28 Feb 2025 10:17:23 -0300 Subject: [PATCH 058/117] Update mocktail to 0.1.2-alpha Signed-off-by: Mateus Devino --- Cargo.lock | 10 +-- Cargo.toml | 2 +- tests/chunker.rs | 6 +- tests/common/chunker.rs | 7 -- tests/common/generation.rs | 4 -- tests/common/orchestrator.rs | 16 ++--- tests/detection_content.rs | 12 ++-- tests/generation_nlp.rs | 8 +-- .../streaming_classification_with_gen_nlp.rs | 69 +++++++++---------- 9 files changed, 59 insertions(+), 75 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 8aa27eab..2168958c 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1448,7 +1448,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "fc2f4eb4bc735547cfed7c0a4922cbd04a4655978c09b54f1f7b228750664c34" dependencies = [ "cfg-if", - "windows-targets 0.52.6", + "windows-targets 0.48.5", ] [[package]] @@ -1581,23 +1581,25 @@ dependencies = [ [[package]] name = "mocktail" -version = "0.1.0-alpha" -source = "git+https://github.com/IBM/mocktail#f635426acdbfa42e58067319d76e489bec64f825" +version = "0.1.2-alpha" +source = "git+https://github.com/IBM/mocktail#9c6e3502579ca7d5c56e9cf1b3c99d9ea01e297d" dependencies = [ "bytes", 
"futures", + "h2", "http", "http-body", "http-body-util", "hyper", "hyper-util", "prost", - "rand 0.8.5", + "rand 0.9.0", "reqwest", "serde", "serde_json", "thiserror 2.0.11", "tokio", + "tokio-stream", "tonic", "tracing", "url", diff --git a/Cargo.toml b/Cargo.toml index 71ce9223..7873c10d 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -90,7 +90,7 @@ tonic-build = "0.12.3" [dev-dependencies] axum-test = "17.1.0" faux = "0.1.12" -mocktail = { git = "https://github.com/IBM/mocktail", version = "0.1.0-alpha" } +mocktail = { git = "https://github.com/IBM/mocktail", version = "0.1.2-alpha" } test-log = "0.2.17" [profile.release] diff --git a/tests/chunker.rs b/tests/chunker.rs index ec8fcdd2..080863bd 100644 --- a/tests/chunker.rs +++ b/tests/chunker.rs @@ -15,7 +15,7 @@ */ -use common::chunker::{MockChunkersServiceServer, CHUNKER_UNARY_ENDPOINT}; +use common::chunker::CHUNKER_UNARY_ENDPOINT; use fms_guardrails_orchestr8::{ clients::chunker::{ChunkerClient, MODEL_ID_HEADER_NAME}, config::ServiceConfig, @@ -62,7 +62,7 @@ async fn test_isolated_chunker_unary_call() -> Result<(), anyhow::Error> { let mut mocks = MockSet::new(); mocks.insert( - MockPath::new(Method::POST, CHUNKER_UNARY_ENDPOINT), + MockPath::post(CHUNKER_UNARY_ENDPOINT), Mock::new( MockRequest::pb(ChunkerTokenizationTaskRequest { text: input_text.into(), @@ -72,7 +72,7 @@ async fn test_isolated_chunker_unary_call() -> Result<(), anyhow::Error> { ), ); - let mock_chunker_server = MockChunkersServiceServer::new(mocks)?; + let mock_chunker_server = GrpcMockServer::new(chunker_id, mocks)?; let _ = mock_chunker_server.start().await; let client = ChunkerClient::new(&ServiceConfig { diff --git a/tests/common/chunker.rs b/tests/common/chunker.rs index 05042055..948788f0 100644 --- a/tests/common/chunker.rs +++ b/tests/common/chunker.rs @@ -14,13 +14,6 @@ limitations under the License. */ -use mocktail::generate_grpc_server; -use mocktail::mock::MockSet; - -generate_grpc_server!( - "caikit.runtime.Chunkers.ChunkersService", - MockChunkersServiceServer -); // Chunker names pub const CHUNKER_NAME_SENTENCE: &str = "sentence_chunker"; diff --git a/tests/common/generation.rs b/tests/common/generation.rs index a68a127c..5a442c68 100644 --- a/tests/common/generation.rs +++ b/tests/common/generation.rs @@ -14,10 +14,6 @@ limitations under the License. */ -use mocktail::generate_grpc_server; -use mocktail::mock::MockSet; - -generate_grpc_server!("caikit.runtime.Nlp.NlpService", MockNlpServiceServer); // NLP generation server endpoints pub const GENERATION_NLP_STREAMING_ENDPOINT: &str = diff --git a/tests/common/orchestrator.rs b/tests/common/orchestrator.rs index 6251e2f7..9a4659b3 100644 --- a/tests/common/orchestrator.rs +++ b/tests/common/orchestrator.rs @@ -28,14 +28,12 @@ use bytes::Bytes; use eventsource_stream::{EventStream, Eventsource}; use fms_guardrails_orchestr8::{config::OrchestratorConfig, orchestrator::Orchestrator}; use futures::{stream::BoxStream, Stream, StreamExt}; -use mocktail::server::HttpMockServer; +use mocktail::server::{GrpcMockServer, HttpMockServer}; use rustls::crypto::ring; use serde::de::DeserializeOwned; use tokio::task::JoinHandle; use url::Url; -use super::{chunker::MockChunkersServiceServer, generation::MockNlpServiceServer}; - // Default orchestrator configuration file for integration tests. 
pub const ORCHESTRATOR_CONFIG_FILE_PATH: &str = "tests/test_config.yaml"; @@ -65,10 +63,10 @@ impl TestOrchestratorServer { config_path: impl AsRef, port: u16, health_port: u16, - generation_server: Option, + generation_server: Option, chat_generation_server: Option, detector_servers: Option>, - chunker_servers: Option>, + chunker_servers: Option>, ) -> Result { // Set default crypto provider ensure_global_rustls_state(); @@ -87,7 +85,7 @@ impl TestOrchestratorServer { } async fn initialize_generation_server( - generation_server: Option, + generation_server: Option, config: &mut OrchestratorConfig, ) -> Result<(), anyhow::Error> { Ok(if let Some(generation_server) = generation_server { @@ -128,17 +126,17 @@ impl TestOrchestratorServer { } async fn initialize_chunkers( - chunker_servers: Option>, + chunker_servers: Option>, config: &mut OrchestratorConfig, ) -> Result<(), anyhow::Error> { Ok(if let Some(chunker_servers) = chunker_servers { - for (name, chunker_server) in chunker_servers { + for chunker_server in chunker_servers { chunker_server.start().await?; config .chunkers .as_mut() .unwrap() - .get_mut(&name) + .get_mut(chunker_server.name()) .unwrap() .service .port = Some(chunker_server.addr().port()); diff --git a/tests/detection_content.rs b/tests/detection_content.rs index 8da3980b..49557853 100644 --- a/tests/detection_content.rs +++ b/tests/detection_content.rs @@ -19,7 +19,7 @@ use std::collections::HashMap; use test_log::test; use common::{ - chunker::{MockChunkersServiceServer, CHUNKER_NAME_SENTENCE, CHUNKER_UNARY_ENDPOINT}, + chunker::{CHUNKER_NAME_SENTENCE, CHUNKER_UNARY_ENDPOINT}, detectors::{ DETECTOR_NAME_ANGLE_BRACKETS_SENTENCE, DETECTOR_NAME_ANGLE_BRACKETS_WHOLE_DOC, TEXT_CONTENTS_DETECTOR_ENDPOINT, @@ -56,7 +56,7 @@ async fn test_single_detection_whole_doc() -> Result<(), anyhow::Error> { // Add detector mock let mut mocks = MockSet::new(); mocks.insert( - MockPath::new(Method::POST, TEXT_CONTENTS_DETECTOR_ENDPOINT), + MockPath::post(TEXT_CONTENTS_DETECTOR_ENDPOINT), Mock::new( MockRequest::json(ContentAnalysisRequest { contents: vec!["This sentence has .".into()], @@ -133,7 +133,7 @@ async fn test_single_detection_sentence_chunker() -> Result<(), anyhow::Error> { let mut chunker_mocks = MockSet::new(); chunker_mocks.insert( - MockPath::new(Method::POST, CHUNKER_UNARY_ENDPOINT), + MockPath::post(CHUNKER_UNARY_ENDPOINT), Mock::new( MockRequest::pb(ChunkerTokenizationTaskRequest { text: "This sentence does not have a detection. 
But .".into(), @@ -161,7 +161,7 @@ async fn test_single_detection_sentence_chunker() -> Result<(), anyhow::Error> { let detector_name = DETECTOR_NAME_ANGLE_BRACKETS_SENTENCE; let mut mocks = MockSet::new(); mocks.insert( - MockPath::new(Method::POST, TEXT_CONTENTS_DETECTOR_ENDPOINT), + MockPath::post(TEXT_CONTENTS_DETECTOR_ENDPOINT), Mock::new( MockRequest::json(ContentAnalysisRequest { contents: vec![ @@ -187,7 +187,7 @@ async fn test_single_detection_sentence_chunker() -> Result<(), anyhow::Error> { ); // Start orchestrator server and its dependencies - let mock_chunker_server = MockChunkersServiceServer::new(chunker_mocks)?; + let mock_chunker_server = GrpcMockServer::new(chunker_id, chunker_mocks)?; let mock_detector_server = HttpMockServer::new(detector_name, mocks)?; let orchestrator_server = TestOrchestratorServer::run( ORCHESTRATOR_CONFIG_FILE_PATH, @@ -196,7 +196,7 @@ async fn test_single_detection_sentence_chunker() -> Result<(), anyhow::Error> { None, None, Some(vec![mock_detector_server]), - Some(vec![(chunker_id.into(), mock_chunker_server)]), + Some(vec![mock_chunker_server]), ) .await?; diff --git a/tests/generation_nlp.rs b/tests/generation_nlp.rs index 26c40065..c9bffa1a 100644 --- a/tests/generation_nlp.rs +++ b/tests/generation_nlp.rs @@ -15,9 +15,7 @@ */ -use common::generation::{ - MockNlpServiceServer, GENERATION_NLP_MODEL_ID_HEADER_NAME, GENERATION_NLP_STREAMING_ENDPOINT, -}; +use common::generation::{GENERATION_NLP_MODEL_ID_HEADER_NAME, GENERATION_NLP_STREAMING_ENDPOINT}; use fms_guardrails_orchestr8::{ clients::NlpClient, config::ServiceConfig, @@ -60,7 +58,7 @@ async fn test_nlp_streaming_call() -> Result<(), anyhow::Error> { let mut mocks = MockSet::new(); mocks.insert( - MockPath::new(Method::POST, GENERATION_NLP_STREAMING_ENDPOINT), + MockPath::post(GENERATION_NLP_STREAMING_ENDPOINT), Mock::new( MockRequest::pb(ServerStreamingTextGenerationTaskRequest { text: "Hi there! 
How are you?".into(), @@ -71,7 +69,7 @@ async fn test_nlp_streaming_call() -> Result<(), anyhow::Error> { ), ); - let generation_nlp_server = MockNlpServiceServer::new(mocks)?; + let generation_nlp_server = GrpcMockServer::new("nlp", mocks)?; generation_nlp_server.start().await?; let client = NlpClient::new(&ServiceConfig { diff --git a/tests/streaming_classification_with_gen_nlp.rs b/tests/streaming_classification_with_gen_nlp.rs index 3c2ef7a0..98dcf40b 100644 --- a/tests/streaming_classification_with_gen_nlp.rs +++ b/tests/streaming_classification_with_gen_nlp.rs @@ -19,18 +19,15 @@ use std::collections::HashMap; use test_log::test; use common::{ - chunker::{ - MockChunkersServiceServer, CHUNKER_MODEL_ID_HEADER_NAME, CHUNKER_NAME_SENTENCE, - CHUNKER_UNARY_ENDPOINT, - }, + chunker::{CHUNKER_MODEL_ID_HEADER_NAME, CHUNKER_NAME_SENTENCE, CHUNKER_UNARY_ENDPOINT}, detectors::{ DETECTOR_NAME_ANGLE_BRACKETS_SENTENCE, DETECTOR_NAME_ANGLE_BRACKETS_WHOLE_DOC, TEXT_CONTENTS_DETECTOR_ENDPOINT, }, errors::{DetectorError, OrchestratorError}, generation::{ - MockNlpServiceServer, GENERATION_NLP_MODEL_ID_HEADER_NAME, - GENERATION_NLP_STREAMING_ENDPOINT, GENERATION_NLP_TOKENIZATION_ENDPOINT, + GENERATION_NLP_MODEL_ID_HEADER_NAME, GENERATION_NLP_STREAMING_ENDPOINT, + GENERATION_NLP_TOKENIZATION_ENDPOINT, }, orchestrator::{ SseStream, TestOrchestratorServer, ORCHESTRATOR_CONFIG_FILE_PATH, @@ -108,7 +105,7 @@ async fn test_no_detectors() -> Result<(), anyhow::Error> { let mut mocks = MockSet::new(); mocks.insert( - MockPath::new(Method::POST, GENERATION_NLP_STREAMING_ENDPOINT), + MockPath::post(GENERATION_NLP_STREAMING_ENDPOINT), Mock::new( MockRequest::pb(ServerStreamingTextGenerationTaskRequest { text: "Hi there! How are you?".into(), @@ -120,7 +117,7 @@ async fn test_no_detectors() -> Result<(), anyhow::Error> { ); // Configure mock servers - let generation_server = MockNlpServiceServer::new(mocks)?; + let generation_server = GrpcMockServer::new("nlp", mocks)?; // Run test orchestrator server let orchestrator_server = TestOrchestratorServer::run( @@ -174,7 +171,7 @@ async fn test_input_detector_whole_doc_no_detections() -> Result<(), anyhow::Err // Add input detection mock let mut detection_mocks = MockSet::new(); detection_mocks.insert( - MockPath::new(Method::POST, TEXT_CONTENTS_DETECTOR_ENDPOINT), + MockPath::post(TEXT_CONTENTS_DETECTOR_ENDPOINT), Mock::new( MockRequest::json(ContentAnalysisRequest { contents: vec!["Hi there! How are you?".into()], @@ -209,7 +206,7 @@ async fn test_input_detector_whole_doc_no_detections() -> Result<(), anyhow::Err let mut generation_mocks = MockSet::new(); generation_mocks.insert( - MockPath::new(Method::POST, GENERATION_NLP_STREAMING_ENDPOINT), + MockPath::post(GENERATION_NLP_STREAMING_ENDPOINT), Mock::new( MockRequest::pb(ServerStreamingTextGenerationTaskRequest { text: "Hi there! 
How are you?".into(), @@ -222,7 +219,7 @@ async fn test_input_detector_whole_doc_no_detections() -> Result<(), anyhow::Err // Start orchestrator server and its dependencies let mock_detector_server = HttpMockServer::new(detector_name, detection_mocks)?; - let generation_server = MockNlpServiceServer::new(generation_mocks)?; + let generation_server = GrpcMockServer::new("nlp", generation_mocks)?; let orchestrator_server = TestOrchestratorServer::run( ORCHESTRATOR_CONFIG_FILE_PATH, find_available_port().unwrap(), @@ -289,7 +286,7 @@ async fn test_input_detector_sentence_chunker_no_detections() -> Result<(), anyh let mut chunker_mocks = MockSet::new(); chunker_mocks.insert( - MockPath::new(Method::POST, CHUNKER_UNARY_ENDPOINT), + MockPath::post(CHUNKER_UNARY_ENDPOINT), Mock::new( MockRequest::pb(ChunkerTokenizationTaskRequest { text: "Hi there! How are you?".into(), @@ -316,7 +313,7 @@ async fn test_input_detector_sentence_chunker_no_detections() -> Result<(), anyh // Add input detection mock let mut detection_mocks = MockSet::new(); detection_mocks.insert( - MockPath::new(Method::POST, TEXT_CONTENTS_DETECTOR_ENDPOINT), + MockPath::post(TEXT_CONTENTS_DETECTOR_ENDPOINT), Mock::new( MockRequest::json(ContentAnalysisRequest { contents: vec!["Hi there!".into(), " How are you?".into()], @@ -354,7 +351,7 @@ async fn test_input_detector_sentence_chunker_no_detections() -> Result<(), anyh let mut generation_mocks = MockSet::new(); generation_mocks.insert( - MockPath::new(Method::POST, GENERATION_NLP_STREAMING_ENDPOINT), + MockPath::post(GENERATION_NLP_STREAMING_ENDPOINT), Mock::new( MockRequest::pb(ServerStreamingTextGenerationTaskRequest { text: "Hi there! How are you?".into(), @@ -366,9 +363,9 @@ async fn test_input_detector_sentence_chunker_no_detections() -> Result<(), anyh ); // Start orchestrator server and its dependencies - let mock_chunker_server = MockChunkersServiceServer::new(chunker_mocks)?; + let mock_chunker_server = GrpcMockServer::new(chunker_id, chunker_mocks)?; let mock_detector_server = HttpMockServer::new(detector_name, detection_mocks)?; - let generation_server = MockNlpServiceServer::new(generation_mocks)?; + let generation_server = GrpcMockServer::new("nlp", generation_mocks)?; let orchestrator_server = TestOrchestratorServer::run( ORCHESTRATOR_CONFIG_FILE_PATH, find_available_port().unwrap(), @@ -376,7 +373,7 @@ async fn test_input_detector_sentence_chunker_no_detections() -> Result<(), anyh Some(generation_server), None, Some(vec![mock_detector_server]), - Some(vec![(chunker_id.into(), mock_chunker_server)]), + Some(vec![mock_chunker_server]), ) .await?; @@ -443,7 +440,7 @@ async fn test_input_detector_whole_doc_with_detections() -> Result<(), anyhow::E // Add input detection mock let mut detection_mocks = MockSet::new(); detection_mocks.insert( - MockPath::new(Method::POST, TEXT_CONTENTS_DETECTOR_ENDPOINT), + MockPath::post(TEXT_CONTENTS_DETECTOR_ENDPOINT), Mock::new( MockRequest::json(ContentAnalysisRequest { contents: vec![ @@ -468,7 +465,7 @@ async fn test_input_detector_whole_doc_with_detections() -> Result<(), anyhow::E ); let mut generation_mocks = MockSet::new(); generation_mocks.insert( - MockPath::new(Method::POST, GENERATION_NLP_TOKENIZATION_ENDPOINT), + MockPath::post(GENERATION_NLP_TOKENIZATION_ENDPOINT), Mock::new( MockRequest::pb(TokenizationTaskRequest { text: "This sentence does not have a detection. 
But .".into(), @@ -481,7 +478,7 @@ async fn test_input_detector_whole_doc_with_detections() -> Result<(), anyhow::E // Start orchestrator server and its dependencies let mock_detector_server = HttpMockServer::new(detector_name, detection_mocks)?; - let generation_server = MockNlpServiceServer::new(generation_mocks)?; + let generation_server = GrpcMockServer::new("nlp", generation_mocks)?; let orchestrator_server = TestOrchestratorServer::run( ORCHESTRATOR_CONFIG_FILE_PATH, find_available_port().unwrap(), @@ -557,7 +554,7 @@ async fn test_input_detector_sentence_chunker_with_detections() -> Result<(), an let mut chunker_mocks = MockSet::new(); chunker_mocks.insert( - MockPath::new(Method::POST, CHUNKER_UNARY_ENDPOINT), + MockPath::post(CHUNKER_UNARY_ENDPOINT), Mock::new( MockRequest::pb(ChunkerTokenizationTaskRequest { text: "This sentence does not have a detection. But .".into(), @@ -595,7 +592,7 @@ async fn test_input_detector_sentence_chunker_with_detections() -> Result<(), an }; let mut detection_mocks = MockSet::new(); detection_mocks.insert( - MockPath::new(Method::POST, TEXT_CONTENTS_DETECTOR_ENDPOINT), + MockPath::post(TEXT_CONTENTS_DETECTOR_ENDPOINT), Mock::new( MockRequest::json(ContentAnalysisRequest { contents: vec![ @@ -621,7 +618,7 @@ async fn test_input_detector_sentence_chunker_with_detections() -> Result<(), an ); let mut generation_mocks = MockSet::new(); generation_mocks.insert( - MockPath::new(Method::POST, GENERATION_NLP_TOKENIZATION_ENDPOINT), + MockPath::post(GENERATION_NLP_TOKENIZATION_ENDPOINT), Mock::new( MockRequest::pb(TokenizationTaskRequest { text: "This sentence does not have a detection. But .".into(), @@ -633,9 +630,9 @@ async fn test_input_detector_sentence_chunker_with_detections() -> Result<(), an ); // Start orchestrator server and its dependencies - let mock_chunker_server = MockChunkersServiceServer::new(chunker_mocks)?; + let mock_chunker_server = GrpcMockServer::new(chunker_id, chunker_mocks)?; let mock_detector_server = HttpMockServer::new(detector_name, detection_mocks)?; - let generation_server = MockNlpServiceServer::new(generation_mocks)?; + let generation_server = GrpcMockServer::new("nlp", generation_mocks)?; let orchestrator_server = TestOrchestratorServer::run( ORCHESTRATOR_CONFIG_FILE_PATH, find_available_port().unwrap(), @@ -643,7 +640,7 @@ async fn test_input_detector_sentence_chunker_with_detections() -> Result<(), an Some(generation_server), None, Some(vec![mock_detector_server]), - Some(vec![(chunker_id.into(), mock_chunker_server)]), + Some(vec![mock_chunker_server]), ) .await?; @@ -713,7 +710,7 @@ async fn test_input_detector_returns_503() -> Result<(), anyhow::Error> { // Add input detection mock let mut detection_mocks = MockSet::new(); detection_mocks.insert( - MockPath::new(Method::POST, TEXT_CONTENTS_DETECTOR_ENDPOINT), + MockPath::post(TEXT_CONTENTS_DETECTOR_ENDPOINT), Mock::new( MockRequest::json(ContentAnalysisRequest { contents: vec!["This should return a 503".into()], @@ -792,7 +789,7 @@ async fn test_input_detector_returns_404() -> Result<(), anyhow::Error> { // Add input detection mock let mut detection_mocks = MockSet::new(); detection_mocks.insert( - MockPath::new(Method::POST, TEXT_CONTENTS_DETECTOR_ENDPOINT), + MockPath::post(TEXT_CONTENTS_DETECTOR_ENDPOINT), Mock::new( MockRequest::json(ContentAnalysisRequest { contents: vec!["This should return a 404".into()], @@ -871,7 +868,7 @@ async fn test_input_detector_returns_500() -> Result<(), anyhow::Error> { // Add input detection mock let mut detection_mocks = 
MockSet::new(); detection_mocks.insert( - MockPath::new(Method::POST, TEXT_CONTENTS_DETECTOR_ENDPOINT), + MockPath::post(TEXT_CONTENTS_DETECTOR_ENDPOINT), Mock::new( MockRequest::json(ContentAnalysisRequest { contents: vec!["This should return a 500".into()], @@ -945,7 +942,7 @@ async fn test_input_detector_returns_invalid_message() -> Result<(), anyhow::Err // Add input detection mock let mut detection_mocks = MockSet::new(); detection_mocks.insert( - MockPath::new(Method::POST, TEXT_CONTENTS_DETECTOR_ENDPOINT), + MockPath::post(TEXT_CONTENTS_DETECTOR_ENDPOINT), Mock::new( MockRequest::json(ContentAnalysisRequest { contents: vec![ @@ -1020,7 +1017,7 @@ async fn test_input_chunker_returns_an_error() -> Result<(), anyhow::Error> { let mut chunker_mocks = MockSet::new(); chunker_mocks.insert( - MockPath::new(Method::POST, CHUNKER_UNARY_ENDPOINT), + MockPath::post(CHUNKER_UNARY_ENDPOINT), Mock::new( MockRequest::pb(ChunkerTokenizationTaskRequest { text: "Hi there! How are you?".into(), @@ -1031,7 +1028,7 @@ async fn test_input_chunker_returns_an_error() -> Result<(), anyhow::Error> { ); // Start orchestrator server and its dependencies - let mock_chunker_server = MockChunkersServiceServer::new(chunker_mocks)?; + let mock_chunker_server = GrpcMockServer::new(chunker_id, chunker_mocks)?; let orchestrator_server = TestOrchestratorServer::run( ORCHESTRATOR_CONFIG_FILE_PATH, find_available_port().unwrap(), @@ -1039,7 +1036,7 @@ async fn test_input_chunker_returns_an_error() -> Result<(), anyhow::Error> { None, None, None, - Some(vec![(chunker_id.into(), mock_chunker_server)]), + Some(vec![mock_chunker_server]), ) .await?; @@ -1088,7 +1085,7 @@ async fn test_generation_server_returns_an_error() -> Result<(), anyhow::Error> // Add input detection mock let mut detection_mocks = MockSet::new(); detection_mocks.insert( - MockPath::new(Method::POST, TEXT_CONTENTS_DETECTOR_ENDPOINT), + MockPath::post(TEXT_CONTENTS_DETECTOR_ENDPOINT), Mock::new( MockRequest::json(ContentAnalysisRequest { contents: vec!["Hi there! How are you?".into()], @@ -1108,7 +1105,7 @@ async fn test_generation_server_returns_an_error() -> Result<(), anyhow::Error> let mut generation_mocks = MockSet::new(); generation_mocks.insert( - MockPath::new(Method::POST, GENERATION_NLP_STREAMING_ENDPOINT), + MockPath::post(GENERATION_NLP_STREAMING_ENDPOINT), Mock::new( MockRequest::pb(ServerStreamingTextGenerationTaskRequest { text: "Hi there! 
How are you?".into(), @@ -1121,7 +1118,7 @@ async fn test_generation_server_returns_an_error() -> Result<(), anyhow::Error> // Start orchestrator server and its dependencies let mock_detector_server = HttpMockServer::new(detector_name, detection_mocks)?; - let generation_server = MockNlpServiceServer::new(generation_mocks)?; + let generation_server = GrpcMockServer::new("nlp", generation_mocks)?; let orchestrator_server = TestOrchestratorServer::run( ORCHESTRATOR_CONFIG_FILE_PATH, find_available_port().unwrap(), From 517e9804dc74e73db42807bd5fe23bd35a24338a Mon Sep 17 00:00:00 2001 From: Mateus Devino Date: Mon, 3 Mar 2025 09:02:07 -0300 Subject: [PATCH 059/117] Fix comments Signed-off-by: Mateus Devino --- tests/streaming_classification_with_gen_nlp.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/streaming_classification_with_gen_nlp.rs b/tests/streaming_classification_with_gen_nlp.rs index 98dcf40b..3350cee3 100644 --- a/tests/streaming_classification_with_gen_nlp.rs +++ b/tests/streaming_classification_with_gen_nlp.rs @@ -1077,7 +1077,7 @@ async fn test_input_chunker_returns_an_error() -> Result<(), anyhow::Error> { Ok(()) } -// Asserts error 500 is returned when generation server returns an error. +/// Asserts error 500 is returned when generation server returns an error. #[test(tokio::test)] async fn test_generation_server_returns_an_error() -> Result<(), anyhow::Error> { let detector_name = DETECTOR_NAME_ANGLE_BRACKETS_WHOLE_DOC; From 6adcac9a288dc9d5a0d849a1410f7578524ddff0 Mon Sep 17 00:00:00 2001 From: Mateus Devino Date: Fri, 3 Jan 2025 11:27:34 -0300 Subject: [PATCH 060/117] Refactor common integration test code Signed-off-by: Mateus Devino --- tests/canary_test.rs | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/canary_test.rs b/tests/canary_test.rs index 19c37b3c..36f1de83 100644 --- a/tests/canary_test.rs +++ b/tests/canary_test.rs @@ -31,7 +31,6 @@ use fms_guardrails_orchestr8::{ }; use hyper::StatusCode; use serde_json::Value; -use tokio::sync::OnceCell; use tracing::debug; pub mod common; From 0c07875940f2a7bac3e953818b05f1b36f488f42 Mon Sep 17 00:00:00 2001 From: Mateus Devino Date: Mon, 6 Jan 2025 14:21:17 -0300 Subject: [PATCH 061/117] /detection/content base test case Signed-off-by: Mateus Devino --- src/server.rs | 106 +++++++++++++++++++++++++------------------------- 1 file changed, 53 insertions(+), 53 deletions(-) diff --git a/src/server.rs b/src/server.rs index 28d9ad93..d9bbe93e 100644 --- a/src/server.rs +++ b/src/server.rs @@ -166,59 +166,7 @@ pub async fn run( } // (2b) Add main guardrails server routes - let mut router = Router::new() - .route( - &format!("{}/classification-with-text-generation", API_PREFIX), - post(classification_with_gen), - ) - .route( - &format!("{}/detection/stream-content", TEXT_API_PREFIX), - post(stream_content_detection), - ) - .route( - &format!( - "{}/server-streaming-classification-with-text-generation", - API_PREFIX - ), - post(stream_classification_with_gen), - ) - .route( - &format!("{}/generation-detection", TEXT_API_PREFIX), - post(generation_with_detection), - ) - .route( - &format!("{}/detection/content", TEXT_API_PREFIX), - post(detection_content), - ) - .route( - &format!("{}/detection/chat", TEXT_API_PREFIX), - post(detect_chat), - ) - .route( - &format!("{}/detection/context", TEXT_API_PREFIX), - post(detect_context_documents), - ) - .route( - &format!("{}/detection/generated", TEXT_API_PREFIX), - post(detect_generated), - ); - - // If chat generation is configured, enable the chat 
completions detection endpoint. - if shared_state.orchestrator.config().chat_generation.is_some() { - info!("Enabling chat completions detection endpoint"); - router = router.route( - "/api/v2/chat/completions-detection", - post(chat_completions_detection), - ); - } - - let app = router.with_state(shared_state).layer( - TraceLayer::new_for_http() - .make_span_with(utils::trace::incoming_request_span) - .on_request(utils::trace::on_incoming_request) - .on_response(utils::trace::on_outgoing_response) - .on_eos(utils::trace::on_outgoing_eos), - ); + let app = get_app(shared_state); // (2c) Generate main guardrails server handle based on whether TLS is needed let listener: TcpListener = TcpListener::bind(&http_addr) @@ -323,6 +271,58 @@ pub fn get_health_app(state: Arc) -> Router { .with_state(state) } +pub fn get_app(state: Arc) -> Router { + let mut router = Router::new() + .route( + &format!("{}/classification-with-text-generation", API_PREFIX), + post(classification_with_gen), + ) + .route( + &format!( + "{}/server-streaming-classification-with-text-generation", + API_PREFIX + ), + post(stream_classification_with_gen), + ) + .route( + &format!("{}/generation-detection", TEXT_API_PREFIX), + post(generation_with_detection), + ) + .route( + &format!("{}/detection/content", TEXT_API_PREFIX), + post(detection_content), + ) + .route( + &format!("{}/detection/chat", TEXT_API_PREFIX), + post(detect_chat), + ) + .route( + &format!("{}/detection/context", TEXT_API_PREFIX), + post(detect_context_documents), + ) + .route( + &format!("{}/detection/generated", TEXT_API_PREFIX), + post(detect_generated), + ); + + // If chat generation is configured, enable the chat completions detection endpoint. + if state.orchestrator.config().chat_generation.is_some() { + info!("Enabling chat completions detection endpoint"); + router = router.route( + "/api/v2/chat/completions-detection", + post(chat_completions_detection), + ); + } + + router.with_state(state).layer( + TraceLayer::new_for_http() + .make_span_with(utils::trace::incoming_request_span) + .on_request(utils::trace::on_incoming_request) + .on_response(utils::trace::on_outgoing_response) + .on_eos(utils::trace::on_outgoing_eos), + ) +} + async fn health() -> Result { // NOTE: we are only adding the package information in the `health` endpoint to have this endpoint // provide a non empty 200 response. 
If we need to add more information regarding dependencies version From 4762954d5120819b93483a550a0849b620b603a7 Mon Sep 17 00:00:00 2001 From: Mateus Devino Date: Mon, 24 Feb 2025 14:06:40 -0300 Subject: [PATCH 062/117] Rename tests/detection_content.rs Signed-off-by: Mateus Devino --- tests/{detection_content.rs => text_content_detection.rs} | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename tests/{detection_content.rs => text_content_detection.rs} (100%) diff --git a/tests/detection_content.rs b/tests/text_content_detection.rs similarity index 100% rename from tests/detection_content.rs rename to tests/text_content_detection.rs From 09b2510cf6ce061b70028974ffc84d76a2e63a2c Mon Sep 17 00:00:00 2001 From: Mateus Devino Date: Mon, 24 Feb 2025 14:18:16 -0300 Subject: [PATCH 063/117] test case: test_no_detection_whole_doc() Signed-off-by: Mateus Devino --- tests/text_content_detection.rs | 52 +++++++++++++++++++++++++++++++++ 1 file changed, 52 insertions(+) diff --git a/tests/text_content_detection.rs b/tests/text_content_detection.rs index 49557853..356298c1 100644 --- a/tests/text_content_detection.rs +++ b/tests/text_content_detection.rs @@ -46,6 +46,58 @@ use tracing::debug; pub mod common; +#[test(tokio::test)] +async fn test_no_detection_whole_doc() -> Result<(), anyhow::Error> { + let detector_name = DETECTOR_NAME_ANGLE_BRACKETS_WHOLE_DOC; + + // Add detector mock + let mut mocks = MockSet::new(); + mocks.insert( + MockPath::new(Method::POST, TEXT_CONTENTS_DETECTOR_ENDPOINT), + Mock::new( + MockRequest::json(ContentAnalysisRequest { + contents: vec!["This sentence has no detections.".into()], + detector_params: DetectorParams::new(), + }), + MockResponse::json(vec![Vec::::new()]), + ), + ); + + // Start orchestrator server and its dependencies + let mock_detector_server = HttpMockServer::new(detector_name, mocks)?; + let orchestrator_server = TestOrchestratorServer::run( + ORCHESTRATOR_CONFIG_FILE_PATH, + find_available_port().unwrap(), + find_available_port().unwrap(), + None, + None, + Some(vec![mock_detector_server]), + None, + ) + .await?; + + // Make orchestrator call + let response = orchestrator_server + .post(ORCHESTRATOR_CONTENT_DETECTION_ENDPOINT) + .json(&TextContentDetectionHttpRequest { + content: "This sentence has no detections.".into(), + detectors: HashMap::from([(detector_name.into(), DetectorParams::new())]), + }) + .send() + .await?; + + debug!(?response); + + // assertions + assert!(response.status() == StatusCode::OK); + assert!( + response.json::().await? + == TextContentDetectionResult::default() + ); + + Ok(()) +} + /// Asserts a scenario with a single detection works as expected (assumes a detector configured with whole_doc_chunker). /// /// This test mocks a detector that detects text between . 
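Note on the test wiring: every test case added to tests/text_content_detection.rs in the following patches repeats the pattern shown in test_no_detection_whole_doc() above — build a mocktail MockSet for the detector, wrap it in an HttpMockServer, hand it to TestOrchestratorServer::run(), then post to the orchestrator's content-detection endpoint and assert on the response. The condensed, comment-annotated sketch below is not part of any patch; it assumes the helpers and imports from tests/common and the test file above, uses the pre-0.1.2 MockPath::new(Method::POST, ...) form that these patches still use (later switched to MockPath::post in PATCH 075), and stands in an untyped json!([[]]) body for the empty detection list the real tests build with Vec::new().

// Condensed sketch of the shared harness wiring (illustration only, not a patch).
#[test(tokio::test)]
async fn sketch_content_detection_happy_path() -> Result<(), anyhow::Error> {
    let detector_name = DETECTOR_NAME_ANGLE_BRACKETS_WHOLE_DOC;

    // 1. Detector mock: pair the expected JSON request with a canned JSON response.
    //    An empty inner list means "no detections" for the single submitted content.
    let mut mocks = MockSet::new();
    mocks.insert(
        MockPath::new(Method::POST, TEXT_CONTENTS_DETECTOR_ENDPOINT),
        Mock::new(
            MockRequest::json(ContentAnalysisRequest {
                contents: vec!["This sentence has no detections.".into()],
                detector_params: DetectorParams::new(),
            }),
            MockResponse::json(&json!([[]])),
        ),
    );

    // 2. Start the orchestrator under test; only the detector slot is populated here.
    let mock_detector_server = HttpMockServer::new(detector_name, mocks)?;
    let orchestrator_server = TestOrchestratorServer::run(
        ORCHESTRATOR_CONFIG_FILE_PATH,    // config_path (tests/test_config.yaml)
        find_available_port().unwrap(),   // port for the guardrails API
        find_available_port().unwrap(),   // health_port
        None,                             // generation_server (gRPC NLP mock)
        None,                             // chat_generation_server
        Some(vec![mock_detector_server]), // detector_servers (HTTP mocks)
        None,                             // chunker_servers (gRPC mocks)
    )
    .await?;

    // 3. Call the orchestrator's text content detection endpoint and assert on the
    //    typed response: no detections means the default (empty) result.
    let response = orchestrator_server
        .post(ORCHESTRATOR_CONTENT_DETECTION_ENDPOINT)
        .json(&TextContentDetectionHttpRequest {
            content: "This sentence has no detections.".into(),
            detectors: HashMap::from([(detector_name.into(), DetectorParams::new())]),
        })
        .send()
        .await?;

    assert!(response.status() == StatusCode::OK);
    assert!(
        response.json::<TextContentDetectionResult>().await?
            == TextContentDetectionResult::default()
    );
    Ok(())
}

The same call shape carries through the chunker-backed and error-path tests that follow; only the mocks handed to run() and the final assertions change.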
From aba1e1c94682477c31d569750e0fee592b1d68e9 Mon Sep 17 00:00:00 2001 From: Mateus Devino Date: Mon, 24 Feb 2025 14:38:22 -0300 Subject: [PATCH 064/117] test case: test_no_detection_sentence_chunker() Signed-off-by: Mateus Devino --- tests/text_content_detection.rs | 88 +++++++++++++++++++++++++++++++++ 1 file changed, 88 insertions(+) diff --git a/tests/text_content_detection.rs b/tests/text_content_detection.rs index 356298c1..6a7788ff 100644 --- a/tests/text_content_detection.rs +++ b/tests/text_content_detection.rs @@ -98,6 +98,94 @@ async fn test_no_detection_whole_doc() -> Result<(), anyhow::Error> { Ok(()) } +#[test(tokio::test)] +async fn test_no_detection_sentence_chunker() -> Result<(), anyhow::Error> { + // Add chunker mock + let chunker_id = CHUNKER_NAME_SENTENCE; + let mut chunker_headers = HeaderMap::new(); + chunker_headers.insert(CHUNKER_MODEL_ID_HEADER_NAME, chunker_id.parse()?); + + let mut chunker_mocks = MockSet::new(); + chunker_mocks.insert( + MockPath::new(Method::POST, CHUNKER_UNARY_ENDPOINT), + Mock::new( + MockRequest::pb(ChunkerTokenizationTaskRequest { + text: "This sentence does not have a detection. Neither does this one.".into(), + }) + .with_headers(chunker_headers), + MockResponse::pb(TokenizationResults { + results: vec![ + Token { + start: 0, + end: 40, + text: "This sentence does not have a detection.".into(), + }, + Token { + start: 41, + end: 64, + text: "Neither does this one.".into(), + }, + ], + token_count: 0, + }), + ), + ); + + // Add detector mock + let detector_name = DETECTOR_NAME_ANGLE_BRACKETS_SENTENCE; + let mut mocks = MockSet::new(); + mocks.insert( + MockPath::new(Method::POST, TEXT_CONTENTS_DETECTOR_ENDPOINT), + Mock::new( + MockRequest::json(ContentAnalysisRequest { + contents: vec![ + "This sentence does not have a detection.".into(), + "Neither does this one.".into(), + ], + detector_params: DetectorParams::new(), + }), + MockResponse::json(vec![ + Vec::::new(), + Vec::::new(), + ]), + ), + ); + + // Start orchestrator server and its dependencies + let mock_chunker_server = MockChunkersServiceServer::new(chunker_mocks)?; + let mock_detector_server = HttpMockServer::new(detector_name, mocks)?; + let orchestrator_server = TestOrchestratorServer::run( + ORCHESTRATOR_CONFIG_FILE_PATH, + find_available_port().unwrap(), + find_available_port().unwrap(), + None, + None, + Some(vec![mock_detector_server]), + Some(vec![(chunker_id.into(), mock_chunker_server)]), + ) + .await?; + + // Make orchestrator call + let response = orchestrator_server + .post(ORCHESTRATOR_CONTENT_DETECTION_ENDPOINT) + .json(&TextContentDetectionHttpRequest { + content: "This sentence does not have a detection. Neither does this one.".into(), + detectors: HashMap::from([(detector_name.into(), DetectorParams::new())]), + }) + .send() + .await?; + + debug!(?response); + + assert!(response.status() == StatusCode::OK); + assert!( + response.json::().await? + == TextContentDetectionResult::default() + ); + + Ok(()) +} + /// Asserts a scenario with a single detection works as expected (assumes a detector configured with whole_doc_chunker). /// /// This test mocks a detector that detects text between . 
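The error-path and validation tests added in the next several patches all build the detector's reply from the DetectorError helper and deserialize the orchestrator's error body into the OrchestratorError helper, both imported from tests/common/errors.rs, whose source is not included in this series. The sketch below reconstructs those helpers from how the tests use them; the field names come from the assertions, while the derives and integer width are assumptions.

// Sketch of the helpers in tests/common/errors.rs, inferred from their use in the
// tests below; the real definitions may carry different derives or extra fields.
use serde::{Deserialize, Serialize};

/// Error body served by the mocked detectors (serialized into MockResponse::json).
#[derive(Debug, Serialize, Deserialize)]
pub struct DetectorError {
    pub code: u16,
    pub message: String,
}

/// Error body returned by the orchestrator and asserted on in the tests.
#[derive(Debug, Serialize, Deserialize)]
pub struct OrchestratorError {
    pub code: u16,
    pub details: String,
}

As the assertions in the following patches show, detector 404 and 503 replies are propagated with their original message in details ("detector request failed for `{detector_name}`: {message}"), while detector 500s, non-compliant detector bodies, and chunker failures are masked behind the generic ORCHESTRATOR_INTERNAL_SERVER_ERROR_MESSAGE.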
From 70be320cc3a5709387001cce5269dc4e3e00c8fe Mon Sep 17 00:00:00 2001 From: Mateus Devino Date: Mon, 24 Feb 2025 14:50:10 -0300 Subject: [PATCH 065/117] test case: test_detector_returns_503() Signed-off-by: Mateus Devino --- tests/text_content_detection.rs | 63 +++++++++++++++++++++++++++++++++ 1 file changed, 63 insertions(+) diff --git a/tests/text_content_detection.rs b/tests/text_content_detection.rs index 6a7788ff..569488c4 100644 --- a/tests/text_content_detection.rs +++ b/tests/text_content_detection.rs @@ -24,6 +24,7 @@ use common::{ DETECTOR_NAME_ANGLE_BRACKETS_SENTENCE, DETECTOR_NAME_ANGLE_BRACKETS_WHOLE_DOC, TEXT_CONTENTS_DETECTOR_ENDPOINT, }, + errors::{DetectorError, OrchestratorError}, orchestrator::{ TestOrchestratorServer, ORCHESTRATOR_CONFIG_FILE_PATH, ORCHESTRATOR_CONTENT_DETECTION_ENDPOINT, @@ -372,3 +373,65 @@ async fn test_single_detection_sentence_chunker() -> Result<(), anyhow::Error> { Ok(()) } + +#[test(tokio::test)] +async fn test_detector_returns_503() -> Result<(), anyhow::Error> { + let detector_name = DETECTOR_NAME_ANGLE_BRACKETS_WHOLE_DOC; + let expected_detector_error = DetectorError { + code: 503, + message: "The detector service is overloaded.".into(), + }; + + // Add input detection mock + let mut detection_mocks = MockSet::new(); + detection_mocks.insert( + MockPath::new(Method::POST, TEXT_CONTENTS_DETECTOR_ENDPOINT), + Mock::new( + MockRequest::json(ContentAnalysisRequest { + contents: vec!["This should return a 503".into()], + detector_params: DetectorParams::new(), + }), + MockResponse::json(&expected_detector_error).with_code(StatusCode::SERVICE_UNAVAILABLE), + ), + ); + + // Start orchestrator server and its dependencies + let mock_detector_server = HttpMockServer::new(detector_name, detection_mocks)?; + let orchestrator_server = TestOrchestratorServer::run( + ORCHESTRATOR_CONFIG_FILE_PATH, + find_available_port().unwrap(), + find_available_port().unwrap(), + None, + None, + Some(vec![mock_detector_server]), + None, + ) + .await?; + + // Example orchestrator request with streaming response + let response = orchestrator_server + .post(ORCHESTRATOR_CONTENT_DETECTION_ENDPOINT) + .json(&TextContentDetectionHttpRequest { + content: "This should return a 503".into(), + detectors: HashMap::from([(detector_name.into(), DetectorParams::new())]), + }) + .send() + .await?; + + debug!(?response, "RESPONSE RECEIVED FROM ORCHESTRATOR"); + + // assertions + assert!(response.status() == StatusCode::SERVICE_UNAVAILABLE); + + let response: OrchestratorError = response.json().await?; + assert!(response.code == 503); + assert!( + response.details + == format!( + "detector request failed for `{}`: {}", + detector_name, expected_detector_error.message + ) + ); + + Ok(()) +} From 9f1100facc9b87eb70c6669494ff6115a27e3a76 Mon Sep 17 00:00:00 2001 From: Mateus Devino Date: Mon, 24 Feb 2025 14:53:21 -0300 Subject: [PATCH 066/117] test case: test_detector_returns_404() Signed-off-by: Mateus Devino --- tests/text_content_detection.rs | 62 +++++++++++++++++++++++++++++++++ 1 file changed, 62 insertions(+) diff --git a/tests/text_content_detection.rs b/tests/text_content_detection.rs index 569488c4..cf05e41c 100644 --- a/tests/text_content_detection.rs +++ b/tests/text_content_detection.rs @@ -435,3 +435,65 @@ async fn test_detector_returns_503() -> Result<(), anyhow::Error> { Ok(()) } + +#[test(tokio::test)] +async fn test_detector_returns_404() -> Result<(), anyhow::Error> { + let detector_name = DETECTOR_NAME_ANGLE_BRACKETS_WHOLE_DOC; + let expected_detector_error = 
DetectorError { + code: 404, + message: "The detector service was not found.".into(), + }; + + // Add input detection mock + let mut detection_mocks = MockSet::new(); + detection_mocks.insert( + MockPath::new(Method::POST, TEXT_CONTENTS_DETECTOR_ENDPOINT), + Mock::new( + MockRequest::json(ContentAnalysisRequest { + contents: vec!["This should return a 404".into()], + detector_params: DetectorParams::new(), + }), + MockResponse::json(&expected_detector_error).with_code(StatusCode::NOT_FOUND), + ), + ); + + // Start orchestrator server and its dependencies + let mock_detector_server = HttpMockServer::new(detector_name, detection_mocks)?; + let orchestrator_server = TestOrchestratorServer::run( + ORCHESTRATOR_CONFIG_FILE_PATH, + find_available_port().unwrap(), + find_available_port().unwrap(), + None, + None, + Some(vec![mock_detector_server]), + None, + ) + .await?; + + // Example orchestrator request with streaming response + let response = orchestrator_server + .post(ORCHESTRATOR_CONTENT_DETECTION_ENDPOINT) + .json(&TextContentDetectionHttpRequest { + content: "This should return a 404".into(), + detectors: HashMap::from([(detector_name.into(), DetectorParams::new())]), + }) + .send() + .await?; + + debug!(?response, "RESPONSE RECEIVED FROM ORCHESTRATOR"); + + // assertions + assert!(response.status() == StatusCode::NOT_FOUND); + + let response: OrchestratorError = response.json().await?; + assert!(response.code == 404); + assert!( + response.details + == format!( + "detector request failed for `{}`: {}", + detector_name, expected_detector_error.message + ) + ); + + Ok(()) +} From 203bcebe24b7f39a2050f0e2b8176a68049d9a1e Mon Sep 17 00:00:00 2001 From: Mateus Devino Date: Mon, 24 Feb 2025 14:59:23 -0300 Subject: [PATCH 067/117] test case: test_detector_returns_500() Signed-off-by: Mateus Devino --- tests/text_content_detection.rs | 59 ++++++++++++++++++++++++++++++++- 1 file changed, 58 insertions(+), 1 deletion(-) diff --git a/tests/text_content_detection.rs b/tests/text_content_detection.rs index cf05e41c..27b7d283 100644 --- a/tests/text_content_detection.rs +++ b/tests/text_content_detection.rs @@ -27,7 +27,7 @@ use common::{ errors::{DetectorError, OrchestratorError}, orchestrator::{ TestOrchestratorServer, ORCHESTRATOR_CONFIG_FILE_PATH, - ORCHESTRATOR_CONTENT_DETECTION_ENDPOINT, + ORCHESTRATOR_CONTENT_DETECTION_ENDPOINT, ORCHESTRATOR_INTERNAL_SERVER_ERROR_MESSAGE, }, }; use fms_guardrails_orchestr8::{ @@ -497,3 +497,60 @@ async fn test_detector_returns_404() -> Result<(), anyhow::Error> { Ok(()) } + +#[test(tokio::test)] +async fn test_detector_returns_500() -> Result<(), anyhow::Error> { + let detector_name = DETECTOR_NAME_ANGLE_BRACKETS_WHOLE_DOC; + let expected_detector_error = DetectorError { + code: 500, + message: "Internal error on detector call.".into(), + }; + + // Add input detection mock + let mut detection_mocks = MockSet::new(); + detection_mocks.insert( + MockPath::new(Method::POST, TEXT_CONTENTS_DETECTOR_ENDPOINT), + Mock::new( + MockRequest::json(ContentAnalysisRequest { + contents: vec!["This should return a 500".into()], + detector_params: DetectorParams::new(), + }), + MockResponse::json(&expected_detector_error) + .with_code(StatusCode::INTERNAL_SERVER_ERROR), + ), + ); + + // Start orchestrator server and its dependencies + let mock_detector_server = HttpMockServer::new(detector_name, detection_mocks)?; + let orchestrator_server = TestOrchestratorServer::run( + ORCHESTRATOR_CONFIG_FILE_PATH, + find_available_port().unwrap(), + find_available_port().unwrap(), + 
None, + None, + Some(vec![mock_detector_server]), + None, + ) + .await?; + + // Example orchestrator request with streaming response + let response = orchestrator_server + .post(ORCHESTRATOR_CONTENT_DETECTION_ENDPOINT) + .json(&TextContentDetectionHttpRequest { + content: "This should return a 500".into(), + detectors: HashMap::from([(detector_name.into(), DetectorParams::new())]), + }) + .send() + .await?; + + debug!(?response, "RESPONSE RECEIVED FROM ORCHESTRATOR"); + + // assertions + assert!(response.status() == StatusCode::INTERNAL_SERVER_ERROR); + + let response: OrchestratorError = response.json().await?; + assert!(response.code == 500); + assert!(response.details == ORCHESTRATOR_INTERNAL_SERVER_ERROR_MESSAGE); + + Ok(()) +} From 412471441245abbadf1eb9136dc0b10869bac30b Mon Sep 17 00:00:00 2001 From: Mateus Devino Date: Mon, 24 Feb 2025 15:07:53 -0300 Subject: [PATCH 068/117] test case: test_detector_returns_non_compliant_message() Signed-off-by: Mateus Devino --- tests/text_content_detection.rs | 55 +++++++++++++++++++++++++++++++++ 1 file changed, 55 insertions(+) diff --git a/tests/text_content_detection.rs b/tests/text_content_detection.rs index 27b7d283..9d893eca 100644 --- a/tests/text_content_detection.rs +++ b/tests/text_content_detection.rs @@ -15,6 +15,7 @@ */ +use serde_json::json; use std::collections::HashMap; use test_log::test; @@ -554,3 +555,57 @@ async fn test_detector_returns_500() -> Result<(), anyhow::Error> { Ok(()) } + +#[test(tokio::test)] +async fn test_detector_returns_non_compliant_message() -> Result<(), anyhow::Error> { + let detector_name = DETECTOR_NAME_ANGLE_BRACKETS_WHOLE_DOC; + + // Add input detection mock + let mut detection_mocks = MockSet::new(); + detection_mocks.insert( + MockPath::new(Method::POST, TEXT_CONTENTS_DETECTOR_ENDPOINT), + Mock::new( + MockRequest::json(ContentAnalysisRequest { + contents: vec!["This should return a non-compliant message".into()], + detector_params: DetectorParams::new(), + }), + MockResponse::json(&json!({ + "my_detection": "This message does not comply with the expected API" + })), + ), + ); + + // Start orchestrator server and its dependencies + let mock_detector_server = HttpMockServer::new(detector_name, detection_mocks)?; + let orchestrator_server = TestOrchestratorServer::run( + ORCHESTRATOR_CONFIG_FILE_PATH, + find_available_port().unwrap(), + find_available_port().unwrap(), + None, + None, + Some(vec![mock_detector_server]), + None, + ) + .await?; + + // Example orchestrator request with streaming response + let response = orchestrator_server + .post(ORCHESTRATOR_CONTENT_DETECTION_ENDPOINT) + .json(&TextContentDetectionHttpRequest { + content: "This should return a non-compliant message".into(), + detectors: HashMap::from([(detector_name.into(), DetectorParams::new())]), + }) + .send() + .await?; + + debug!(?response, "RESPONSE RECEIVED FROM ORCHESTRATOR"); + + // assertions + assert!(response.status() == StatusCode::INTERNAL_SERVER_ERROR); + + let response: OrchestratorError = response.json().await?; + assert!(response.code == 500); + assert!(response.details == ORCHESTRATOR_INTERNAL_SERVER_ERROR_MESSAGE); + + Ok(()) +} From f997f1d94f6d1a19e0ce3b442866bfcedbf97d9d Mon Sep 17 00:00:00 2001 From: Mateus Devino Date: Mon, 24 Feb 2025 15:41:48 -0300 Subject: [PATCH 069/117] test case: test_chunker_returns_an_error() Signed-off-by: Mateus Devino --- tests/text_content_detection.rs | 55 +++++++++++++++++++++++++++++++++ 1 file changed, 55 insertions(+) diff --git a/tests/text_content_detection.rs 
b/tests/text_content_detection.rs index 9d893eca..8ca2d6b9 100644 --- a/tests/text_content_detection.rs +++ b/tests/text_content_detection.rs @@ -609,3 +609,58 @@ async fn test_detector_returns_non_compliant_message() -> Result<(), anyhow::Err Ok(()) } + +#[test(tokio::test)] +async fn test_chunker_returns_an_error() -> Result<(), anyhow::Error> { + let detector_name = DETECTOR_NAME_ANGLE_BRACKETS_SENTENCE; + + // Add input chunker mock + let chunker_id = CHUNKER_NAME_SENTENCE; + let mut chunker_headers = HeaderMap::new(); + chunker_headers.insert(CHUNKER_MODEL_ID_HEADER_NAME, chunker_id.parse()?); + + let mut chunker_mocks = MockSet::new(); + chunker_mocks.insert( + MockPath::new(Method::POST, CHUNKER_UNARY_ENDPOINT), + Mock::new( + MockRequest::pb(ChunkerTokenizationTaskRequest { + text: "This should return a 500".into(), + }) + .with_headers(chunker_headers), + MockResponse::empty().with_code(StatusCode::INTERNAL_SERVER_ERROR), + ), + ); + + // Start orchestrator server and its dependencies + let mock_chunker_server = MockChunkersServiceServer::new(chunker_mocks)?; + let orchestrator_server = TestOrchestratorServer::run( + ORCHESTRATOR_CONFIG_FILE_PATH, + find_available_port().unwrap(), + find_available_port().unwrap(), + None, + None, + None, + Some(vec![(chunker_id.into(), mock_chunker_server)]), + ) + .await?; + + let response = orchestrator_server + .post(ORCHESTRATOR_CONTENT_DETECTION_ENDPOINT) + .json(&TextContentDetectionHttpRequest { + content: "This should return a 500".into(), + detectors: HashMap::from([(detector_name.into(), DetectorParams::new())]), + }) + .send() + .await?; + + debug!(?response, "RESPONSE RECEIVED FROM ORCHESTRATOR"); + + // assertions + assert!(response.status() == StatusCode::INTERNAL_SERVER_ERROR); + + let response: OrchestratorError = response.json().await?; + assert!(response.code == 500); + assert!(response.details == ORCHESTRATOR_INTERNAL_SERVER_ERROR_MESSAGE); + + Ok(()) +} From 5aa899a685673ea34fa5b6e393ee6655cad07a36 Mon Sep 17 00:00:00 2001 From: Mateus Devino Date: Mon, 24 Feb 2025 15:54:09 -0300 Subject: [PATCH 070/117] test case: test_request_with_extra_fields_returns_422() Signed-off-by: Mateus Devino --- tests/text_content_detection.rs | 41 +++++++++++++++++++++++++++++++++ 1 file changed, 41 insertions(+) diff --git a/tests/text_content_detection.rs b/tests/text_content_detection.rs index 8ca2d6b9..67345caf 100644 --- a/tests/text_content_detection.rs +++ b/tests/text_content_detection.rs @@ -664,3 +664,44 @@ async fn test_chunker_returns_an_error() -> Result<(), anyhow::Error> { Ok(()) } + +#[test(tokio::test)] +async fn test_request_with_extra_fields_returns_422() -> Result<(), anyhow::Error> { + let detector_name = DETECTOR_NAME_ANGLE_BRACKETS_WHOLE_DOC; + + // Start orchestrator server and its dependencies + let orchestrator_server = TestOrchestratorServer::run( + ORCHESTRATOR_CONFIG_FILE_PATH, + find_available_port().unwrap(), + find_available_port().unwrap(), + None, + None, + None, + None, + ) + .await?; + + // Make orchestrator call + let response = orchestrator_server + .post(ORCHESTRATOR_CONTENT_DETECTION_ENDPOINT) + .json(&json!({ + "content": "This sentence has no detections.", + "detectors": {detector_name: {}}, + "extra_args": true + })) + .send() + .await?; + + debug!("{response:#?}"); + + // assertions + assert!(response.status() == StatusCode::UNPROCESSABLE_ENTITY); + + let response: OrchestratorError = response.json().await?; + debug!("orchestrator json response body:\n{response:#?}"); + + assert!(response.code == 422); + 
assert!(response.details.contains("unknown field `extra_args`")); + + Ok(()) +} From cf3e8da1ab2ef73276886f143aeb7c2eb2cb6658 Mon Sep 17 00:00:00 2001 From: Mateus Devino Date: Mon, 24 Feb 2025 16:09:04 -0300 Subject: [PATCH 071/117] test case: test_request_missing_detectors_field_returns_422() Signed-off-by: Mateus Devino --- tests/text_content_detection.rs | 37 +++++++++++++++++++++++++++++++++ 1 file changed, 37 insertions(+) diff --git a/tests/text_content_detection.rs b/tests/text_content_detection.rs index 67345caf..d3fb9360 100644 --- a/tests/text_content_detection.rs +++ b/tests/text_content_detection.rs @@ -705,3 +705,40 @@ async fn test_request_with_extra_fields_returns_422() -> Result<(), anyhow::Erro Ok(()) } + +#[test(tokio::test)] +async fn test_request_missing_detectors_field_returns_422() -> Result<(), anyhow::Error> { + // Start orchestrator server and its dependencies + let orchestrator_server = TestOrchestratorServer::run( + ORCHESTRATOR_CONFIG_FILE_PATH, + find_available_port().unwrap(), + find_available_port().unwrap(), + None, + None, + None, + None, + ) + .await?; + + // Make orchestrator call + let response = orchestrator_server + .post(ORCHESTRATOR_CONTENT_DETECTION_ENDPOINT) + .json(&json!({ + "content": "This sentence has no detections.", + })) + .send() + .await?; + + debug!("{response:#?}"); + + // assertions + assert!(response.status() == StatusCode::UNPROCESSABLE_ENTITY); + + let response: OrchestratorError = response.json().await?; + debug!("orchestrator json response body:\n{response:#?}"); + + assert!(response.code == 422); + assert!(response.details.starts_with("missing field `detectors`")); + + Ok(()) +} From fb5b5da1d7365546431ca7c8a67eafbdd40bd3d1 Mon Sep 17 00:00:00 2001 From: Mateus Devino Date: Mon, 24 Feb 2025 16:11:43 -0300 Subject: [PATCH 072/117] test case: test_request_missing_content_field_returns_422() Signed-off-by: Mateus Devino --- tests/text_content_detection.rs | 38 +++++++++++++++++++++++++++++++++ 1 file changed, 38 insertions(+) diff --git a/tests/text_content_detection.rs b/tests/text_content_detection.rs index d3fb9360..90689818 100644 --- a/tests/text_content_detection.rs +++ b/tests/text_content_detection.rs @@ -742,3 +742,41 @@ async fn test_request_missing_detectors_field_returns_422() -> Result<(), anyhow Ok(()) } + +#[test(tokio::test)] +async fn test_request_missing_content_field_returns_422() -> Result<(), anyhow::Error> { + let detector_name = DETECTOR_NAME_ANGLE_BRACKETS_WHOLE_DOC; + // Start orchestrator server and its dependencies + let orchestrator_server = TestOrchestratorServer::run( + ORCHESTRATOR_CONFIG_FILE_PATH, + find_available_port().unwrap(), + find_available_port().unwrap(), + None, + None, + None, + None, + ) + .await?; + + // Make orchestrator call + let response = orchestrator_server + .post(ORCHESTRATOR_CONTENT_DETECTION_ENDPOINT) + .json(&json!({ + "detectors": {detector_name: {}}, + })) + .send() + .await?; + + debug!("{response:#?}"); + + // assertions + assert!(response.status() == StatusCode::UNPROCESSABLE_ENTITY); + + let response: OrchestratorError = response.json().await?; + debug!("orchestrator json response body:\n{response:#?}"); + + assert!(response.code == 422); + assert!(response.details.starts_with("missing field `content`")); + + Ok(()) +} From d3215bc98f63f9a7ddf998a3fc50f7b76d35f6a9 Mon Sep 17 00:00:00 2001 From: Mateus Devino Date: Mon, 24 Feb 2025 16:15:21 -0300 Subject: [PATCH 073/117] test case: test_request_with_empty_detectors_field_returns_422() Signed-off-by: Mateus Devino --- 
tests/text_content_detection.rs | 38 +++++++++++++++++++++++++++++++++ 1 file changed, 38 insertions(+) diff --git a/tests/text_content_detection.rs b/tests/text_content_detection.rs index 90689818..1f093ad1 100644 --- a/tests/text_content_detection.rs +++ b/tests/text_content_detection.rs @@ -780,3 +780,41 @@ async fn test_request_missing_content_field_returns_422() -> Result<(), anyhow:: Ok(()) } + +#[test(tokio::test)] +async fn test_request_with_empty_detectors_field_returns_422() -> Result<(), anyhow::Error> { + // Start orchestrator server and its dependencies + let orchestrator_server = TestOrchestratorServer::run( + ORCHESTRATOR_CONFIG_FILE_PATH, + find_available_port().unwrap(), + find_available_port().unwrap(), + None, + None, + None, + None, + ) + .await?; + + // Make orchestrator call + let response = orchestrator_server + .post(ORCHESTRATOR_CONTENT_DETECTION_ENDPOINT) + .json(&json!({ + "content": "This sentence has no detections.", + "detectors": {}, + })) + .send() + .await?; + + debug!("{response:#?}"); + + // assertions + assert!(response.status() == StatusCode::UNPROCESSABLE_ENTITY); + + let response: OrchestratorError = response.json().await?; + debug!("orchestrator json response body:\n{response:#?}"); + + assert!(response.code == 422); + assert!(response.details == "`detectors` is required"); + + Ok(()) +} From adb7b50e347896bb1819d758122fdb479c8ce364 Mon Sep 17 00:00:00 2001 From: Mateus Devino Date: Tue, 25 Feb 2025 11:40:00 -0300 Subject: [PATCH 074/117] Add missing import to canary_test.rs Signed-off-by: Mateus Devino --- tests/canary_test.rs | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/canary_test.rs b/tests/canary_test.rs index 36f1de83..19c37b3c 100644 --- a/tests/canary_test.rs +++ b/tests/canary_test.rs @@ -31,6 +31,7 @@ use fms_guardrails_orchestr8::{ }; use hyper::StatusCode; use serde_json::Value; +use tokio::sync::OnceCell; use tracing::debug; pub mod common; From 2ce212cf2843dbcff3493da427ac178aea720976 Mon Sep 17 00:00:00 2001 From: Mateus Devino Date: Fri, 28 Feb 2025 10:23:56 -0300 Subject: [PATCH 075/117] Update tests to use mocktail 0.1.2-alpha Signed-off-by: Mateus Devino --- tests/text_content_detection.rs | 24 ++++++++++++------------ 1 file changed, 12 insertions(+), 12 deletions(-) diff --git a/tests/text_content_detection.rs b/tests/text_content_detection.rs index 1f093ad1..2f87ac22 100644 --- a/tests/text_content_detection.rs +++ b/tests/text_content_detection.rs @@ -55,7 +55,7 @@ async fn test_no_detection_whole_doc() -> Result<(), anyhow::Error> { // Add detector mock let mut mocks = MockSet::new(); mocks.insert( - MockPath::new(Method::POST, TEXT_CONTENTS_DETECTOR_ENDPOINT), + MockPath::post(TEXT_CONTENTS_DETECTOR_ENDPOINT), Mock::new( MockRequest::json(ContentAnalysisRequest { contents: vec!["This sentence has no detections.".into()], @@ -109,7 +109,7 @@ async fn test_no_detection_sentence_chunker() -> Result<(), anyhow::Error> { let mut chunker_mocks = MockSet::new(); chunker_mocks.insert( - MockPath::new(Method::POST, CHUNKER_UNARY_ENDPOINT), + MockPath::post(CHUNKER_UNARY_ENDPOINT), Mock::new( MockRequest::pb(ChunkerTokenizationTaskRequest { text: "This sentence does not have a detection. 
Neither does this one.".into(), @@ -137,7 +137,7 @@ async fn test_no_detection_sentence_chunker() -> Result<(), anyhow::Error> { let detector_name = DETECTOR_NAME_ANGLE_BRACKETS_SENTENCE; let mut mocks = MockSet::new(); mocks.insert( - MockPath::new(Method::POST, TEXT_CONTENTS_DETECTOR_ENDPOINT), + MockPath::post(TEXT_CONTENTS_DETECTOR_ENDPOINT), Mock::new( MockRequest::json(ContentAnalysisRequest { contents: vec![ @@ -154,7 +154,7 @@ async fn test_no_detection_sentence_chunker() -> Result<(), anyhow::Error> { ); // Start orchestrator server and its dependencies - let mock_chunker_server = MockChunkersServiceServer::new(chunker_mocks)?; + let mock_chunker_server = GrpcMockServer::new(chunker_id, chunker_mocks)?; let mock_detector_server = HttpMockServer::new(detector_name, mocks)?; let orchestrator_server = TestOrchestratorServer::run( ORCHESTRATOR_CONFIG_FILE_PATH, @@ -163,7 +163,7 @@ async fn test_no_detection_sentence_chunker() -> Result<(), anyhow::Error> { None, None, Some(vec![mock_detector_server]), - Some(vec![(chunker_id.into(), mock_chunker_server)]), + Some(vec![mock_chunker_server]), ) .await?; @@ -386,7 +386,7 @@ async fn test_detector_returns_503() -> Result<(), anyhow::Error> { // Add input detection mock let mut detection_mocks = MockSet::new(); detection_mocks.insert( - MockPath::new(Method::POST, TEXT_CONTENTS_DETECTOR_ENDPOINT), + MockPath::post(TEXT_CONTENTS_DETECTOR_ENDPOINT), Mock::new( MockRequest::json(ContentAnalysisRequest { contents: vec!["This should return a 503".into()], @@ -448,7 +448,7 @@ async fn test_detector_returns_404() -> Result<(), anyhow::Error> { // Add input detection mock let mut detection_mocks = MockSet::new(); detection_mocks.insert( - MockPath::new(Method::POST, TEXT_CONTENTS_DETECTOR_ENDPOINT), + MockPath::post(TEXT_CONTENTS_DETECTOR_ENDPOINT), Mock::new( MockRequest::json(ContentAnalysisRequest { contents: vec!["This should return a 404".into()], @@ -510,7 +510,7 @@ async fn test_detector_returns_500() -> Result<(), anyhow::Error> { // Add input detection mock let mut detection_mocks = MockSet::new(); detection_mocks.insert( - MockPath::new(Method::POST, TEXT_CONTENTS_DETECTOR_ENDPOINT), + MockPath::post(TEXT_CONTENTS_DETECTOR_ENDPOINT), Mock::new( MockRequest::json(ContentAnalysisRequest { contents: vec!["This should return a 500".into()], @@ -563,7 +563,7 @@ async fn test_detector_returns_non_compliant_message() -> Result<(), anyhow::Err // Add input detection mock let mut detection_mocks = MockSet::new(); detection_mocks.insert( - MockPath::new(Method::POST, TEXT_CONTENTS_DETECTOR_ENDPOINT), + MockPath::post(TEXT_CONTENTS_DETECTOR_ENDPOINT), Mock::new( MockRequest::json(ContentAnalysisRequest { contents: vec!["This should return a non-compliant message".into()], @@ -621,7 +621,7 @@ async fn test_chunker_returns_an_error() -> Result<(), anyhow::Error> { let mut chunker_mocks = MockSet::new(); chunker_mocks.insert( - MockPath::new(Method::POST, CHUNKER_UNARY_ENDPOINT), + MockPath::post(CHUNKER_UNARY_ENDPOINT), Mock::new( MockRequest::pb(ChunkerTokenizationTaskRequest { text: "This should return a 500".into(), @@ -632,7 +632,7 @@ async fn test_chunker_returns_an_error() -> Result<(), anyhow::Error> { ); // Start orchestrator server and its dependencies - let mock_chunker_server = MockChunkersServiceServer::new(chunker_mocks)?; + let mock_chunker_server = GrpcMockServer::new(chunker_id, chunker_mocks)?; let orchestrator_server = TestOrchestratorServer::run( ORCHESTRATOR_CONFIG_FILE_PATH, find_available_port().unwrap(), @@ -640,7 +640,7 @@ 
async fn test_chunker_returns_an_error() -> Result<(), anyhow::Error> { None, None, None, - Some(vec![(chunker_id.into(), mock_chunker_server)]), + Some(vec![mock_chunker_server]), ) .await?; From bd4e07a7d1bc3dd5e4b841683eaf1cf3ee9732ba Mon Sep 17 00:00:00 2001 From: Mateus Devino Date: Fri, 28 Feb 2025 12:13:03 -0300 Subject: [PATCH 076/117] Add comments to text_content tests Signed-off-by: Mateus Devino --- tests/text_content_detection.rs | 31 ++++++++++++++++++++----------- 1 file changed, 20 insertions(+), 11 deletions(-) diff --git a/tests/text_content_detection.rs b/tests/text_content_detection.rs index 2f87ac22..4b11b4d5 100644 --- a/tests/text_content_detection.rs +++ b/tests/text_content_detection.rs @@ -48,6 +48,7 @@ use tracing::debug; pub mod common; +/// Asserts that generated text with no detections is returned (detector configured with whole_doc_chunker). #[test(tokio::test)] async fn test_no_detection_whole_doc() -> Result<(), anyhow::Error> { let detector_name = DETECTOR_NAME_ANGLE_BRACKETS_WHOLE_DOC; @@ -88,7 +89,7 @@ async fn test_no_detection_whole_doc() -> Result<(), anyhow::Error> { .send() .await?; - debug!(?response); + debug!("{response:#?}"); // assertions assert!(response.status() == StatusCode::OK); @@ -100,6 +101,7 @@ async fn test_no_detection_whole_doc() -> Result<(), anyhow::Error> { Ok(()) } +/// Asserts that generated text with no detections is returned (detector configured with a sentence chunker). #[test(tokio::test)] async fn test_no_detection_sentence_chunker() -> Result<(), anyhow::Error> { // Add chunker mock @@ -177,7 +179,7 @@ async fn test_no_detection_sentence_chunker() -> Result<(), anyhow::Error> { .send() .await?; - debug!(?response); + debug!("{response:#?}"); assert!(response.status() == StatusCode::OK); assert!( @@ -188,9 +190,7 @@ async fn test_no_detection_sentence_chunker() -> Result<(), anyhow::Error> { Ok(()) } -/// Asserts a scenario with a single detection works as expected (assumes a detector configured with whole_doc_chunker). -/// -/// This test mocks a detector that detects text between . +/// Asserts that detections are returned (detector configured with whole_doc_chunker). #[test(tokio::test)] async fn test_single_detection_whole_doc() -> Result<(), anyhow::Error> { let detector_name = DETECTOR_NAME_ANGLE_BRACKETS_WHOLE_DOC; @@ -240,7 +240,7 @@ async fn test_single_detection_whole_doc() -> Result<(), anyhow::Error> { .send() .await?; - debug!(?response); + debug!("{response:#?}"); // assertions assert!(response.status() == StatusCode::OK); @@ -263,9 +263,7 @@ async fn test_single_detection_whole_doc() -> Result<(), anyhow::Error> { Ok(()) } -/// Asserts a scenario with a single detection works as expected (with sentence chunker). -/// -/// This test mocks a detector that detects text between . +/// Asserts that detections are returned (detector configured with a sentence chunker). #[test(tokio::test)] async fn test_single_detection_sentence_chunker() -> Result<(), anyhow::Error> { // Add chunker mock @@ -352,7 +350,7 @@ async fn test_single_detection_sentence_chunker() -> Result<(), anyhow::Error> { .send() .await?; - debug!(?response); + debug!("{response:#?}"); // assertions assert!(response.status() == StatusCode::OK); @@ -375,6 +373,7 @@ async fn test_single_detection_sentence_chunker() -> Result<(), anyhow::Error> { Ok(()) } +/// Asserts that 503 errors returned by detectors are correctly propagated. 
 #[test(tokio::test)]
 async fn test_detector_returns_503() -> Result<(), anyhow::Error> {
     let detector_name = DETECTOR_NAME_ANGLE_BRACKETS_WHOLE_DOC;
@@ -437,6 +436,7 @@ async fn test_detector_returns_503() -> Result<(), anyhow::Error> {
     Ok(())
 }

+/// Asserts that 404 errors returned by detectors are correctly propagated.
 #[test(tokio::test)]
 async fn test_detector_returns_404() -> Result<(), anyhow::Error> {
     let detector_name = DETECTOR_NAME_ANGLE_BRACKETS_WHOLE_DOC;
@@ -499,6 +499,7 @@ async fn test_detector_returns_404() -> Result<(), anyhow::Error> {
     Ok(())
 }

+/// Asserts that 500 errors returned by detectors are correctly propagated.
 #[test(tokio::test)]
 async fn test_detector_returns_500() -> Result<(), anyhow::Error> {
     let detector_name = DETECTOR_NAME_ANGLE_BRACKETS_WHOLE_DOC;
@@ -556,8 +557,9 @@ async fn test_detector_returns_500() -> Result<(), anyhow::Error> {
     Ok(())
 }

+/// Asserts that error 500 is returned when a detector returns an invalid message.
 #[test(tokio::test)]
-async fn test_detector_returns_non_compliant_message() -> Result<(), anyhow::Error> {
+async fn test_detector_returns_invalid_message() -> Result<(), anyhow::Error> {
     let detector_name = DETECTOR_NAME_ANGLE_BRACKETS_WHOLE_DOC;

     // Add input detection mock
@@ -610,6 +612,7 @@ async fn test_detector_returns_non_compliant_message() -> Result<(), anyhow::Err
     Ok(())
 }

+/// Asserts that error 500 is returned upon chunker failure.
 #[test(tokio::test)]
 async fn test_chunker_returns_an_error() -> Result<(), anyhow::Error> {
     let detector_name = DETECTOR_NAME_ANGLE_BRACKETS_SENTENCE;
@@ -665,6 +668,7 @@ async fn test_chunker_returns_an_error() -> Result<(), anyhow::Error> {
     Ok(())
 }

+/// Asserts error 422 is returned when orchestrator request contains extra fields.
 #[test(tokio::test)]
 async fn test_request_with_extra_fields_returns_422() -> Result<(), anyhow::Error> {
     let detector_name = DETECTOR_NAME_ANGLE_BRACKETS_WHOLE_DOC;
@@ -706,6 +710,8 @@ async fn test_request_with_extra_fields_returns_422() -> Result<(), anyhow::Erro
     Ok(())
 }

+/// Asserts error 422 is returned when orchestrator request does not contain `detectors`
+/// field.
 #[test(tokio::test)]
 async fn test_request_missing_detectors_field_returns_422() -> Result<(), anyhow::Error> {
     // Start orchestrator server and its dependencies
@@ -743,6 +749,8 @@ async fn test_request_missing_detectors_field_returns_422() -> Result<(), anyhow
     Ok(())
 }

+/// Asserts error 422 is returned when orchestrator request does not contain `content`
+/// field.
 #[test(tokio::test)]
 async fn test_request_missing_content_field_returns_422() -> Result<(), anyhow::Error> {
     let detector_name = DETECTOR_NAME_ANGLE_BRACKETS_WHOLE_DOC;
@@ -781,6 +789,7 @@ async fn test_request_missing_content_field_returns_422() -> Result<(), anyhow::
     Ok(())
 }

+/// Asserts error 422 is returned when `detectors` is empty in orchestrator request.
#[test(tokio::test)] async fn test_request_with_empty_detectors_field_returns_422() -> Result<(), anyhow::Error> { // Start orchestrator server and its dependencies From f7ece50ff5f4ebcc0dd6787cd570529b562ec1a9 Mon Sep 17 00:00:00 2001 From: Mateus Devino Date: Tue, 25 Feb 2025 15:36:08 -0300 Subject: [PATCH 077/117] test case: detection_on_generation::test_detection_above_default_threshold() Signed-off-by: Mateus Devino --- tests/common/detectors.rs | 2 + tests/common/orchestrator.rs | 1 + tests/detection_on_generation.rs | 105 +++++++++++++++++++++++++++++++ tests/test_config.yaml | 6 ++ 4 files changed, 114 insertions(+) create mode 100644 tests/detection_on_generation.rs diff --git a/tests/common/detectors.rs b/tests/common/detectors.rs index c882bdeb..04ec7633 100644 --- a/tests/common/detectors.rs +++ b/tests/common/detectors.rs @@ -18,6 +18,8 @@ // Detector names pub const DETECTOR_NAME_ANGLE_BRACKETS_WHOLE_DOC: &str = "angle_brackets_detector_whole_doc"; pub const DETECTOR_NAME_ANGLE_BRACKETS_SENTENCE: &str = "angle_brackets_detector_sentence"; +pub const ANSWER_RELEVANCE_DETECTOR: &str = "answer_relevance_detector"; // Detector endpoints pub const TEXT_CONTENTS_DETECTOR_ENDPOINT: &str = "/api/v1/text/contents"; +pub const DETECTION_ON_GENERATION_DETECTOR_ENDPOINT: &str = "/api/v1/text/generation"; diff --git a/tests/common/orchestrator.rs b/tests/common/orchestrator.rs index 9a4659b3..14404cf7 100644 --- a/tests/common/orchestrator.rs +++ b/tests/common/orchestrator.rs @@ -41,6 +41,7 @@ pub const ORCHESTRATOR_CONFIG_FILE_PATH: &str = "tests/test_config.yaml"; pub const ORCHESTRATOR_STREAMING_ENDPOINT: &str = "/api/v1/task/server-streaming-classification-with-text-generation"; pub const ORCHESTRATOR_CONTENT_DETECTION_ENDPOINT: &str = "/api/v2/text/detection/content"; +pub const ORCHESTRATOR_DETECTION_ON_GENERATION_ENDPOINT: &str = "/api/v2/text/detection/generated"; // Messages pub const ORCHESTRATOR_INTERNAL_SERVER_ERROR_MESSAGE: &str = diff --git a/tests/detection_on_generation.rs b/tests/detection_on_generation.rs new file mode 100644 index 00000000..a7afde29 --- /dev/null +++ b/tests/detection_on_generation.rs @@ -0,0 +1,105 @@ +/* + Copyright FMS Guardrails Orchestrator Authors + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. 
+
+*/
+
+use std::collections::HashMap;
+use test_log::test;
+
+use common::{
+    detectors::{ANSWER_RELEVANCE_DETECTOR, DETECTION_ON_GENERATION_DETECTOR_ENDPOINT},
+    orchestrator::{
+        TestOrchestratorServer, ORCHESTRATOR_CONFIG_FILE_PATH,
+        ORCHESTRATOR_DETECTION_ON_GENERATION_ENDPOINT,
+    },
+};
+use fms_guardrails_orchestr8::{
+    clients::detector::GenerationDetectionRequest,
+    models::{
+        DetectionOnGeneratedHttpRequest, DetectionOnGenerationResult, DetectionResult,
+        DetectorParams,
+    },
+};
+use hyper::StatusCode;
+use mocktail::{prelude::*, utils::find_available_port};
+use tracing::debug;
+
+pub mod common;
+
+#[test(tokio::test)]
+async fn test_detection_above_default_threshold() -> Result<(), anyhow::Error> {
+    let detector_name = ANSWER_RELEVANCE_DETECTOR;
+    let prompt = "In 2014, what was the average height of men who were born in 1996?";
+    let generated_text =
+        "The average height of men who were born in 1996 was 171cm (or 5'7.5'') in 2014.";
+    let detection = DetectionResult {
+        detection_type: "relevance".into(),
+        detection: "is_relevant".into(),
+        detector_id: Some(detector_name.into()),
+        score: 0.89,
+        evidence: None,
+    };
+
+    // Add detector mock
+    let mut mocks = MockSet::new();
+    mocks.insert(
+        MockPath::new(Method::POST, DETECTION_ON_GENERATION_DETECTOR_ENDPOINT),
+        Mock::new(
+            MockRequest::json(GenerationDetectionRequest {
+                prompt: prompt.into(),
+                generated_text: generated_text.into(),
+                detector_params: DetectorParams::new(),
+            }),
+            MockResponse::json(vec![detection.clone()]),
+        ),
+    );
+
+    // Start orchestrator server and its dependencies
+    let mock_detector_server = HttpMockServer::new(detector_name, mocks)?;
+    let orchestrator_server = TestOrchestratorServer::run(
+        ORCHESTRATOR_CONFIG_FILE_PATH,
+        find_available_port().unwrap(),
+        find_available_port().unwrap(),
+        None,
+        None,
+        Some(vec![mock_detector_server]),
+        None,
+    )
+    .await?;
+
+    // Make orchestrator call
+    let response = orchestrator_server
+        .post(ORCHESTRATOR_DETECTION_ON_GENERATION_ENDPOINT)
+        .json(&DetectionOnGeneratedHttpRequest {
+            prompt: prompt.into(),
+            generated_text: generated_text.into(),
+            detectors: HashMap::from([(detector_name.into(), DetectorParams::new())]),
+        })
+        .send()
+        .await?;
+
+    debug!(?response);
+
+    // assertions
+    assert!(response.status() == StatusCode::OK);
+    assert!(
+        response.json::<DetectionOnGenerationResult>().await?
+ == DetectionOnGenerationResult { + detections: vec![detection.clone()] + } + ); + + Ok(()) +} diff --git a/tests/test_config.yaml b/tests/test_config.yaml index e6149d9f..4194f9ba 100644 --- a/tests/test_config.yaml +++ b/tests/test_config.yaml @@ -33,3 +33,9 @@ detectors: hostname: localhost chunker_id: whole_doc_chunker default_threshold: 0.5 + answer_relevance_detector: + type: text_generation + service: + hostname: localhost + chunker_id: whole_doc_chunker + default_threshold: 0.5 From cdeaeea76457ec960ed713ae507fca5f6c514ef2 Mon Sep 17 00:00:00 2001 From: Mateus Devino Date: Tue, 25 Feb 2025 17:04:51 -0300 Subject: [PATCH 078/117] test case: detection_on_generation::test_detection_below_default_threshold_is_not_returned() Signed-off-by: Mateus Devino --- tests/detection_on_generation.rs | 65 +++++++++++++++++++++++++++++++- 1 file changed, 64 insertions(+), 1 deletion(-) diff --git a/tests/detection_on_generation.rs b/tests/detection_on_generation.rs index a7afde29..4f404b74 100644 --- a/tests/detection_on_generation.rs +++ b/tests/detection_on_generation.rs @@ -39,7 +39,70 @@ use tracing::debug; pub mod common; #[test(tokio::test)] -async fn test_detection_above_default_threshold() -> Result<(), anyhow::Error> { +async fn test_detection_below_default_threshold_is_not_returned() -> Result<(), anyhow::Error> { + let detector_name = ANSWER_RELEVANCE_DETECTOR; + let prompt = "In 2014, what was the average height of men who were born in 1996?"; + let generated_text = "The average height of women is 159cm (or 5'3'')."; + let detection = DetectionResult { + detection_type: "relevance".into(), + detection: "is_relevant".into(), + detector_id: Some(detector_name.into()), + score: 0.49, + evidence: None, + }; + + // Add detector mock + let mut mocks = MockSet::new(); + mocks.insert( + MockPath::new(Method::POST, DETECTION_ON_GENERATION_DETECTOR_ENDPOINT), + Mock::new( + MockRequest::json(GenerationDetectionRequest { + prompt: prompt.into(), + generated_text: generated_text.into(), + detector_params: DetectorParams::new(), + }), + MockResponse::json(vec![detection.clone()]), + ), + ); + + // Start orchestrator server and its dependencies + let mock_detector_server = HttpMockServer::new(detector_name, mocks)?; + let orchestrator_server = TestOrchestratorServer::run( + ORCHESTRATOR_CONFIG_FILE_PATH, + find_available_port().unwrap(), + find_available_port().unwrap(), + None, + None, + Some(vec![mock_detector_server]), + None, + ) + .await?; + + // Make orchestrator call + let response = orchestrator_server + .post(ORCHESTRATOR_DETECTION_ON_GENERATION_ENDPOINT) + .json(&DetectionOnGeneratedHttpRequest { + prompt: prompt.into(), + generated_text: generated_text.into(), + detectors: HashMap::from([(detector_name.into(), DetectorParams::new())]), + }) + .send() + .await?; + + debug!(?response); + + // assertions + assert!(response.status() == StatusCode::OK); + assert!( + response.json::().await? 
+            == DetectionOnGenerationResult { detections: vec![] }
+    );
+
+    Ok(())
+}
+
+#[test(tokio::test)]
+async fn test_detection_above_default_threshold_is_returned() -> Result<(), anyhow::Error> {
     let detector_name = ANSWER_RELEVANCE_DETECTOR;
     let prompt = "In 2014, what was the average height of men who were born in 1996?";
     let generated_text =

From 9e3b67b04caff8924d4d18c063c9ec5e9c3fa3cc Mon Sep 17 00:00:00 2001
From: Mateus Devino
Date: Tue, 25 Feb 2025 17:20:32 -0300
Subject: [PATCH 079/117] test case: detection_on_generation::test_detector_returns_503()

Signed-off-by: Mateus Devino
---
 tests/detection_on_generation.rs | 67 ++++++++++++++++++++++++++++++++
 1 file changed, 67 insertions(+)

diff --git a/tests/detection_on_generation.rs b/tests/detection_on_generation.rs
index 4f404b74..fd463364 100644
--- a/tests/detection_on_generation.rs
+++ b/tests/detection_on_generation.rs
@@ -20,6 +20,7 @@ use test_log::test;
 use common::{
     detectors::{ANSWER_RELEVANCE_DETECTOR, DETECTION_ON_GENERATION_DETECTOR_ENDPOINT},
+    errors::{DetectorError, OrchestratorError},
     orchestrator::{
         TestOrchestratorServer, ORCHESTRATOR_CONFIG_FILE_PATH,
         ORCHESTRATOR_DETECTION_ON_GENERATION_ENDPOINT,
     },
 };
@@ -166,3 +167,69 @@ async fn test_detection_above_default_threshold_is_returned() -> Result<(), anyh
     Ok(())
 }
+
+#[test(tokio::test)]
+async fn test_detector_returns_503() -> Result<(), anyhow::Error> {
+    let detector_name = ANSWER_RELEVANCE_DETECTOR;
+    let prompt = "In 2014, what was the average height of men who were born in 1996?";
+    let generated_text =
+        "The average height of men who were born in 1996 was 171cm (or 5'7.5'') in 2014.";
+    let detector_error = DetectorError {
+        code: 503,
+        message: "The detector is overloaded.".into(),
+    };
+
+    // Add detector mock
+    let mut mocks = MockSet::new();
+    mocks.insert(
+        MockPath::new(Method::POST, DETECTION_ON_GENERATION_DETECTOR_ENDPOINT),
+        Mock::new(
+            MockRequest::json(GenerationDetectionRequest {
+                prompt: prompt.into(),
+                generated_text: generated_text.into(),
+                detector_params: DetectorParams::new(),
+            }),
+            MockResponse::json(&detector_error).with_code(StatusCode::SERVICE_UNAVAILABLE),
+        ),
+    );
+
+    // Start orchestrator server and its dependencies
+    let mock_detector_server = HttpMockServer::new(detector_name, mocks)?;
+    let orchestrator_server = TestOrchestratorServer::run(
+        ORCHESTRATOR_CONFIG_FILE_PATH,
+        find_available_port().unwrap(),
+        find_available_port().unwrap(),
+        None,
+        None,
+        Some(vec![mock_detector_server]),
+        None,
+    )
+    .await?;
+
+    // Make orchestrator call
+    let response = orchestrator_server
+        .post(ORCHESTRATOR_DETECTION_ON_GENERATION_ENDPOINT)
+        .json(&DetectionOnGeneratedHttpRequest {
+            prompt: prompt.into(),
+            generated_text: generated_text.into(),
+            detectors: HashMap::from([(detector_name.into(), DetectorParams::new())]),
+        })
+        .send()
+        .await?;
+
+    debug!(?response);
+
+    // assertions
+    assert!(response.status() == StatusCode::SERVICE_UNAVAILABLE);
+    let response = response.json::<OrchestratorError>().await?;
+    assert!(response.code == detector_error.code);
+    assert!(
+        response.details
+            == format!(
+                "detector request failed for `{}`: {}",
+                detector_name, detector_error.message
+            )
+    );
+
+    Ok(())
+}

From 830f9a165013793daa32ad84059d11e53ec5f78b Mon Sep 17 00:00:00 2001
From: Mateus Devino
Date: Tue, 25 Feb 2025 17:22:19 -0300
Subject: [PATCH 080/117] test case: detection_on_generation::test_detector_returns_404()

Signed-off-by: Mateus Devino
---
 tests/detection_on_generation.rs | 66 ++++++++++++++++++++++++++++++++
 1 file changed, 66 insertions(+)
diff --git a/tests/detection_on_generation.rs b/tests/detection_on_generation.rs index fd463364..4642cc97 100644 --- a/tests/detection_on_generation.rs +++ b/tests/detection_on_generation.rs @@ -233,3 +233,69 @@ async fn test_detector_returns_503() -> Result<(), anyhow::Error> { Ok(()) } + +#[test(tokio::test)] +async fn test_detector_returns_404() -> Result<(), anyhow::Error> { + let detector_name = ANSWER_RELEVANCE_DETECTOR; + let prompt = "In 2014, what was the average height of men who were born in 1996?"; + let generated_text = + "The average height of men who were born in 1996 was 171cm (or 5'7.5'') in 2014."; + let detector_error = DetectorError { + code: 404, + message: "The detector is overloaded.".into(), + }; + + // Add detector mock + let mut mocks = MockSet::new(); + mocks.insert( + MockPath::new(Method::POST, DETECTION_ON_GENERATION_DETECTOR_ENDPOINT), + Mock::new( + MockRequest::json(GenerationDetectionRequest { + prompt: prompt.into(), + generated_text: generated_text.into(), + detector_params: DetectorParams::new(), + }), + MockResponse::json(&detector_error).with_code(StatusCode::NOT_FOUND), + ), + ); + + // Start orchestrator server and its dependencies + let mock_detector_server = HttpMockServer::new(detector_name, mocks)?; + let orchestrator_server = TestOrchestratorServer::run( + ORCHESTRATOR_CONFIG_FILE_PATH, + find_available_port().unwrap(), + find_available_port().unwrap(), + None, + None, + Some(vec![mock_detector_server]), + None, + ) + .await?; + + // Make orchestrator call + let response = orchestrator_server + .post(ORCHESTRATOR_DETECTION_ON_GENERATION_ENDPOINT) + .json(&DetectionOnGeneratedHttpRequest { + prompt: prompt.into(), + generated_text: generated_text.into(), + detectors: HashMap::from([(detector_name.into(), DetectorParams::new())]), + }) + .send() + .await?; + + debug!(?response); + + // assertions + assert!(response.status() == StatusCode::NOT_FOUND); + let response = response.json::().await?; + assert!(response.code == detector_error.code); + assert!( + response.details + == format!( + "detector request failed for `{}`: {}", + detector_name, detector_error.message + ) + ); + + Ok(()) +} From bce22ca15b662dbdbeefd29f1d10765e7e659c0b Mon Sep 17 00:00:00 2001 From: Mateus Devino Date: Tue, 25 Feb 2025 17:24:35 -0300 Subject: [PATCH 081/117] test case: detection_on_generation::test_detector_returns_500() Signed-off-by: Mateus Devino --- tests/detection_on_generation.rs | 62 +++++++++++++++++++++++++++++++- 1 file changed, 61 insertions(+), 1 deletion(-) diff --git a/tests/detection_on_generation.rs b/tests/detection_on_generation.rs index 4642cc97..89abd090 100644 --- a/tests/detection_on_generation.rs +++ b/tests/detection_on_generation.rs @@ -23,7 +23,7 @@ use common::{ errors::{DetectorError, OrchestratorError}, orchestrator::{ TestOrchestratorServer, ORCHESTRATOR_CONFIG_FILE_PATH, - ORCHESTRATOR_DETECTION_ON_GENERATION_ENDPOINT, + ORCHESTRATOR_DETECTION_ON_GENERATION_ENDPOINT, ORCHESTRATOR_INTERNAL_SERVER_ERROR_MESSAGE, }, }; use fms_guardrails_orchestr8::{ @@ -299,3 +299,63 @@ async fn test_detector_returns_404() -> Result<(), anyhow::Error> { Ok(()) } + +#[test(tokio::test)] +async fn test_detector_returns_500() -> Result<(), anyhow::Error> { + let detector_name = ANSWER_RELEVANCE_DETECTOR; + let prompt = "In 2014, what was the average height of men who were born in 1996?"; + let generated_text = + "The average height of men who were born in 1996 was 171cm (or 5'7.5'') in 2014."; + let detector_error = DetectorError { + code: 500, + message: 
"The detector is overloaded.".into(), + }; + + // Add detector mock + let mut mocks = MockSet::new(); + mocks.insert( + MockPath::new(Method::POST, DETECTION_ON_GENERATION_DETECTOR_ENDPOINT), + Mock::new( + MockRequest::json(GenerationDetectionRequest { + prompt: prompt.into(), + generated_text: generated_text.into(), + detector_params: DetectorParams::new(), + }), + MockResponse::json(&detector_error).with_code(StatusCode::INTERNAL_SERVER_ERROR), + ), + ); + + // Start orchestrator server and its dependencies + let mock_detector_server = HttpMockServer::new(detector_name, mocks)?; + let orchestrator_server = TestOrchestratorServer::run( + ORCHESTRATOR_CONFIG_FILE_PATH, + find_available_port().unwrap(), + find_available_port().unwrap(), + None, + None, + Some(vec![mock_detector_server]), + None, + ) + .await?; + + // Make orchestrator call + let response = orchestrator_server + .post(ORCHESTRATOR_DETECTION_ON_GENERATION_ENDPOINT) + .json(&DetectionOnGeneratedHttpRequest { + prompt: prompt.into(), + generated_text: generated_text.into(), + detectors: HashMap::from([(detector_name.into(), DetectorParams::new())]), + }) + .send() + .await?; + + debug!(?response); + + // assertions + assert!(response.status() == StatusCode::INTERNAL_SERVER_ERROR); + let response = response.json::().await?; + assert!(response.code == detector_error.code); + assert!(response.details == ORCHESTRATOR_INTERNAL_SERVER_ERROR_MESSAGE); + + Ok(()) +} From f329e020341dea38a49686660a43e537d9f25b66 Mon Sep 17 00:00:00 2001 From: Mateus Devino Date: Tue, 25 Feb 2025 17:30:21 -0300 Subject: [PATCH 082/117] test case: detection_on_generation::test_detector_returns_non_compliant_message() Signed-off-by: Mateus Devino --- tests/detection_on_generation.rs | 59 ++++++++++++++++++++++++++++++++ 1 file changed, 59 insertions(+) diff --git a/tests/detection_on_generation.rs b/tests/detection_on_generation.rs index 89abd090..91b6dff6 100644 --- a/tests/detection_on_generation.rs +++ b/tests/detection_on_generation.rs @@ -15,6 +15,7 @@ */ +use serde_json::json; use std::collections::HashMap; use test_log::test; @@ -359,3 +360,61 @@ async fn test_detector_returns_500() -> Result<(), anyhow::Error> { Ok(()) } + +#[test(tokio::test)] +async fn test_detector_returns_non_compliant_message() -> Result<(), anyhow::Error> { + let detector_name = ANSWER_RELEVANCE_DETECTOR; + let prompt = "In 2014, what was the average height of men who were born in 1996?"; + let generated_text = + "The average height of men who were born in 1996 was 171cm (or 5'7.5'') in 2014."; + + // Add detector mock + let mut mocks = MockSet::new(); + mocks.insert( + MockPath::new(Method::POST, DETECTION_ON_GENERATION_DETECTOR_ENDPOINT), + Mock::new( + MockRequest::json(GenerationDetectionRequest { + prompt: prompt.into(), + generated_text: generated_text.into(), + detector_params: DetectorParams::new(), + }), + MockResponse::json(&json!({ + "my_detection": "This message does not comply with the expected API" + })), + ), + ); + + // Start orchestrator server and its dependencies + let mock_detector_server = HttpMockServer::new(detector_name, mocks)?; + let orchestrator_server = TestOrchestratorServer::run( + ORCHESTRATOR_CONFIG_FILE_PATH, + find_available_port().unwrap(), + find_available_port().unwrap(), + None, + None, + Some(vec![mock_detector_server]), + None, + ) + .await?; + + // Make orchestrator call + let response = orchestrator_server + .post(ORCHESTRATOR_DETECTION_ON_GENERATION_ENDPOINT) + .json(&DetectionOnGeneratedHttpRequest { + prompt: prompt.into(), + 
generated_text: generated_text.into(), + detectors: HashMap::from([(detector_name.into(), DetectorParams::new())]), + }) + .send() + .await?; + + debug!(?response); + + // assertions + assert!(response.status() == StatusCode::INTERNAL_SERVER_ERROR); + let response = response.json::().await?; + assert!(response.code == 500); + assert!(response.details == ORCHESTRATOR_INTERNAL_SERVER_ERROR_MESSAGE); + + Ok(()) +} From d894740ba734fe61a3c45c58dbd4d93abc5ee4f7 Mon Sep 17 00:00:00 2001 From: Mateus Devino Date: Tue, 25 Feb 2025 17:45:27 -0300 Subject: [PATCH 083/117] test case: detection_on_generation::test_request_with_extra_fields_returns_422() Signed-off-by: Mateus Devino --- tests/detection_on_generation.rs | 44 ++++++++++++++++++++++++++++++++ 1 file changed, 44 insertions(+) diff --git a/tests/detection_on_generation.rs b/tests/detection_on_generation.rs index 91b6dff6..b93afa8e 100644 --- a/tests/detection_on_generation.rs +++ b/tests/detection_on_generation.rs @@ -418,3 +418,47 @@ async fn test_detector_returns_non_compliant_message() -> Result<(), anyhow::Err Ok(()) } + +#[test(tokio::test)] +async fn test_request_with_extra_fields_returns_422() -> Result<(), anyhow::Error> { + let detector_name = ANSWER_RELEVANCE_DETECTOR; + let prompt = "In 2014, what was the average height of men who were born in 1996?"; + let generated_text = + "The average height of men who were born in 1996 was 171cm (or 5'7.5'') in 2014."; + + // Start orchestrator server and its dependencies + let mock_detector_server = HttpMockServer::new(detector_name, MockSet::new())?; + let orchestrator_server = TestOrchestratorServer::run( + ORCHESTRATOR_CONFIG_FILE_PATH, + find_available_port().unwrap(), + find_available_port().unwrap(), + None, + None, + Some(vec![mock_detector_server]), + None, + ) + .await?; + + // Make orchestrator call + let response = orchestrator_server + .post(ORCHESTRATOR_DETECTION_ON_GENERATION_ENDPOINT) + .json(&json!({ + "prompt": prompt, + "generated_text": generated_text, + "detectors": {detector_name: {}}, + "extra_args": true + })) + .send() + .await?; + + debug!("{response:#?}"); + + // assertions + assert!(response.status() == StatusCode::UNPROCESSABLE_ENTITY); + let response = response.json::().await?; + debug!("{response:#?}"); + assert!(response.code == 422); + assert!(response.details.contains("unknown field `extra_args`")); + + Ok(()) +} From 8c1022147b036331a1b215d8d571964c1abde297 Mon Sep 17 00:00:00 2001 From: Mateus Devino Date: Tue, 25 Feb 2025 17:46:09 -0300 Subject: [PATCH 084/117] Change response logs to pretty Signed-off-by: Mateus Devino --- tests/detection_on_generation.rs | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/tests/detection_on_generation.rs b/tests/detection_on_generation.rs index b93afa8e..fdde5a77 100644 --- a/tests/detection_on_generation.rs +++ b/tests/detection_on_generation.rs @@ -91,7 +91,7 @@ async fn test_detection_below_default_threshold_is_not_returned() -> Result<(), .send() .await?; - debug!(?response); + debug!("{response:#?}"); // assertions assert!(response.status() == StatusCode::OK); @@ -155,7 +155,7 @@ async fn test_detection_above_default_threshold_is_returned() -> Result<(), anyh .send() .await?; - debug!(?response); + debug!("{response:#?}"); // assertions assert!(response.status() == StatusCode::OK); @@ -218,7 +218,7 @@ async fn test_detector_returns_503() -> Result<(), anyhow::Error> { .send() .await?; - debug!(?response); + debug!("{response:#?}"); // assertions assert!(response.status() == 
StatusCode::SERVICE_UNAVAILABLE); @@ -284,7 +284,7 @@ async fn test_detector_returns_404() -> Result<(), anyhow::Error> { .send() .await?; - debug!(?response); + debug!("{response:#?}"); // assertions assert!(response.status() == StatusCode::NOT_FOUND); @@ -350,7 +350,7 @@ async fn test_detector_returns_500() -> Result<(), anyhow::Error> { .send() .await?; - debug!(?response); + debug!("{response:#?}"); // assertions assert!(response.status() == StatusCode::INTERNAL_SERVER_ERROR); @@ -408,7 +408,7 @@ async fn test_detector_returns_non_compliant_message() -> Result<(), anyhow::Err .send() .await?; - debug!(?response); + debug!("{response:#?}"); // assertions assert!(response.status() == StatusCode::INTERNAL_SERVER_ERROR); From 55a12f000fd28e84598321bfa901a2fefc90587d Mon Sep 17 00:00:00 2001 From: Mateus Devino Date: Tue, 25 Feb 2025 17:48:30 -0300 Subject: [PATCH 085/117] test case: detection_on_generation::test_request_missing_prompt_returns_422() Signed-off-by: Mateus Devino --- tests/detection_on_generation.rs | 41 ++++++++++++++++++++++++++++++++ 1 file changed, 41 insertions(+) diff --git a/tests/detection_on_generation.rs b/tests/detection_on_generation.rs index fdde5a77..74fe1c03 100644 --- a/tests/detection_on_generation.rs +++ b/tests/detection_on_generation.rs @@ -462,3 +462,44 @@ async fn test_request_with_extra_fields_returns_422() -> Result<(), anyhow::Erro Ok(()) } + +#[test(tokio::test)] +async fn test_request_missing_prompt_returns_422() -> Result<(), anyhow::Error> { + let detector_name = ANSWER_RELEVANCE_DETECTOR; + let generated_text = + "The average height of men who were born in 1996 was 171cm (or 5'7.5'') in 2014."; + + // Start orchestrator server and its dependencies + let mock_detector_server = HttpMockServer::new(detector_name, MockSet::new())?; + let orchestrator_server = TestOrchestratorServer::run( + ORCHESTRATOR_CONFIG_FILE_PATH, + find_available_port().unwrap(), + find_available_port().unwrap(), + None, + None, + Some(vec![mock_detector_server]), + None, + ) + .await?; + + // Make orchestrator call + let response = orchestrator_server + .post(ORCHESTRATOR_DETECTION_ON_GENERATION_ENDPOINT) + .json(&json!({ + "generated_text": generated_text, + "detectors": {detector_name: {}}, + })) + .send() + .await?; + + debug!("{response:#?}"); + + // assertions + assert!(response.status() == StatusCode::UNPROCESSABLE_ENTITY); + let response = response.json::().await?; + debug!("{response:#?}"); + assert!(response.code == 422); + assert!(response.details.contains("missing field `prompt`")); + + Ok(()) +} From 68d680da3b9fc592940a0343ad5ce32a86ef1211 Mon Sep 17 00:00:00 2001 From: Mateus Devino Date: Tue, 25 Feb 2025 17:53:31 -0300 Subject: [PATCH 086/117] test case: detection_on_generation::test_request_missing_generated_text_returns_422() Signed-off-by: Mateus Devino --- tests/detection_on_generation.rs | 40 ++++++++++++++++++++++++++++++++ 1 file changed, 40 insertions(+) diff --git a/tests/detection_on_generation.rs b/tests/detection_on_generation.rs index 74fe1c03..76be6244 100644 --- a/tests/detection_on_generation.rs +++ b/tests/detection_on_generation.rs @@ -503,3 +503,43 @@ async fn test_request_missing_prompt_returns_422() -> Result<(), anyhow::Error> Ok(()) } + +#[test(tokio::test)] +async fn test_request_missing_generated_text_returns_422() -> Result<(), anyhow::Error> { + let detector_name = ANSWER_RELEVANCE_DETECTOR; + let prompt = "In 2014, what was the average height of men who were born in 1996?"; + + // Start orchestrator server and its dependencies + let 
mock_detector_server = HttpMockServer::new(detector_name, MockSet::new())?; + let orchestrator_server = TestOrchestratorServer::run( + ORCHESTRATOR_CONFIG_FILE_PATH, + find_available_port().unwrap(), + find_available_port().unwrap(), + None, + None, + Some(vec![mock_detector_server]), + None, + ) + .await?; + + // Make orchestrator call + let response = orchestrator_server + .post(ORCHESTRATOR_DETECTION_ON_GENERATION_ENDPOINT) + .json(&json!({ + "prompt": prompt, + "detectors": {detector_name: {}}, + })) + .send() + .await?; + + debug!("{response:#?}"); + + // assertions + assert!(response.status() == StatusCode::UNPROCESSABLE_ENTITY); + let response = response.json::().await?; + debug!("{response:#?}"); + assert!(response.code == 422); + assert!(response.details.contains("missing field `generated_text`")); + + Ok(()) +} From 0151e99f28611b781c2a19a867d1e158f6780e6d Mon Sep 17 00:00:00 2001 From: Mateus Devino Date: Tue, 25 Feb 2025 17:54:57 -0300 Subject: [PATCH 087/117] test case: detection_on_generation::test_request_missing_detectors_returns_422() Signed-off-by: Mateus Devino --- tests/detection_on_generation.rs | 42 ++++++++++++++++++++++++++++++++ 1 file changed, 42 insertions(+) diff --git a/tests/detection_on_generation.rs b/tests/detection_on_generation.rs index 76be6244..fc6de174 100644 --- a/tests/detection_on_generation.rs +++ b/tests/detection_on_generation.rs @@ -543,3 +543,45 @@ async fn test_request_missing_generated_text_returns_422() -> Result<(), anyhow: Ok(()) } + +#[test(tokio::test)] +async fn test_request_missing_detectors_returns_422() -> Result<(), anyhow::Error> { + let detector_name = ANSWER_RELEVANCE_DETECTOR; + let prompt = "In 2014, what was the average height of men who were born in 1996?"; + let generated_text = + "The average height of men who were born in 1996 was 171cm (or 5'7.5'') in 2014."; + + // Start orchestrator server and its dependencies + let mock_detector_server = HttpMockServer::new(detector_name, MockSet::new())?; + let orchestrator_server = TestOrchestratorServer::run( + ORCHESTRATOR_CONFIG_FILE_PATH, + find_available_port().unwrap(), + find_available_port().unwrap(), + None, + None, + Some(vec![mock_detector_server]), + None, + ) + .await?; + + // Make orchestrator call + let response = orchestrator_server + .post(ORCHESTRATOR_DETECTION_ON_GENERATION_ENDPOINT) + .json(&json!({ + "prompt": prompt, + "generated_text": generated_text, + })) + .send() + .await?; + + debug!("{response:#?}"); + + // assertions + assert!(response.status() == StatusCode::UNPROCESSABLE_ENTITY); + let response = response.json::().await?; + debug!("{response:#?}"); + assert!(response.code == 422); + assert!(response.details.contains("missing field `detectors`")); + + Ok(()) +} From e1f4a8bb28fb1c62fdc85e00939c5e78d2c458e8 Mon Sep 17 00:00:00 2001 From: Mateus Devino Date: Tue, 25 Feb 2025 17:59:22 -0300 Subject: [PATCH 088/117] test case: detection_on_generation::test_request_with_empty_detectors_returns_422() Signed-off-by: Mateus Devino --- tests/detection_on_generation.rs | 43 ++++++++++++++++++++++++++++++++ 1 file changed, 43 insertions(+) diff --git a/tests/detection_on_generation.rs b/tests/detection_on_generation.rs index fc6de174..81b24e8b 100644 --- a/tests/detection_on_generation.rs +++ b/tests/detection_on_generation.rs @@ -585,3 +585,46 @@ async fn test_request_missing_detectors_returns_422() -> Result<(), anyhow::Erro Ok(()) } + +#[test(tokio::test)] +async fn test_request_with_empty_detectors_returns_422() -> Result<(), anyhow::Error> { + let 
detector_name = ANSWER_RELEVANCE_DETECTOR; + let prompt = "In 2014, what was the average height of men who were born in 1996?"; + let generated_text = + "The average height of men who were born in 1996 was 171cm (or 5'7.5'') in 2014."; + + // Start orchestrator server and its dependencies + let mock_detector_server = HttpMockServer::new(detector_name, MockSet::new())?; + let orchestrator_server = TestOrchestratorServer::run( + ORCHESTRATOR_CONFIG_FILE_PATH, + find_available_port().unwrap(), + find_available_port().unwrap(), + None, + None, + Some(vec![mock_detector_server]), + None, + ) + .await?; + + // Make orchestrator call + let response = orchestrator_server + .post(ORCHESTRATOR_DETECTION_ON_GENERATION_ENDPOINT) + .json(&DetectionOnGeneratedHttpRequest { + prompt: prompt.into(), + generated_text: generated_text.into(), + detectors: HashMap::new(), + }) + .send() + .await?; + + debug!("{response:#?}"); + + // assertions + assert!(response.status() == StatusCode::UNPROCESSABLE_ENTITY); + let response = response.json::().await?; + debug!("{response:#?}"); + assert!(response.code == 422); + assert!(response.details == "`detectors` is required"); + + Ok(()) +} From 42d7caaee0e7e595413314e90d66bd9644ee6a98 Mon Sep 17 00:00:00 2001 From: Mateus Devino Date: Mon, 3 Mar 2025 08:49:30 -0300 Subject: [PATCH 089/117] Add comments to detection_on_generation tests Signed-off-by: Mateus Devino --- tests/detection_on_generation.rs | 26 +++++++++++++++++++------- 1 file changed, 19 insertions(+), 7 deletions(-) diff --git a/tests/detection_on_generation.rs b/tests/detection_on_generation.rs index 81b24e8b..a31dd89a 100644 --- a/tests/detection_on_generation.rs +++ b/tests/detection_on_generation.rs @@ -40,6 +40,7 @@ use tracing::debug; pub mod common; +// Asserts detections below the default threshold are not returned. #[test(tokio::test)] async fn test_detection_below_default_threshold_is_not_returned() -> Result<(), anyhow::Error> { let detector_name = ANSWER_RELEVANCE_DETECTOR; @@ -56,7 +57,7 @@ async fn test_detection_below_default_threshold_is_not_returned() -> Result<(), // Add detector mock let mut mocks = MockSet::new(); mocks.insert( - MockPath::new(Method::POST, DETECTION_ON_GENERATION_DETECTOR_ENDPOINT), + MockPath::post(DETECTION_ON_GENERATION_DETECTOR_ENDPOINT), Mock::new( MockRequest::json(GenerationDetectionRequest { prompt: prompt.into(), @@ -103,6 +104,7 @@ async fn test_detection_below_default_threshold_is_not_returned() -> Result<(), Ok(()) } +// Asserts detections above the default threshold are returned. #[test(tokio::test)] async fn test_detection_above_default_threshold_is_returned() -> Result<(), anyhow::Error> { let detector_name = ANSWER_RELEVANCE_DETECTOR; @@ -120,7 +122,7 @@ async fn test_detection_above_default_threshold_is_returned() -> Result<(), anyh // Add detector mock let mut mocks = MockSet::new(); mocks.insert( - MockPath::new(Method::POST, DETECTION_ON_GENERATION_DETECTOR_ENDPOINT), + MockPath::post(DETECTION_ON_GENERATION_DETECTOR_ENDPOINT), Mock::new( MockRequest::json(GenerationDetectionRequest { prompt: prompt.into(), @@ -169,6 +171,7 @@ async fn test_detection_above_default_threshold_is_returned() -> Result<(), anyh Ok(()) } +// Asserts error 503 from detectors is propagated. 
#[test(tokio::test)] async fn test_detector_returns_503() -> Result<(), anyhow::Error> { let detector_name = ANSWER_RELEVANCE_DETECTOR; @@ -183,7 +186,7 @@ async fn test_detector_returns_503() -> Result<(), anyhow::Error> { // Add detector mock let mut mocks = MockSet::new(); mocks.insert( - MockPath::new(Method::POST, DETECTION_ON_GENERATION_DETECTOR_ENDPOINT), + MockPath::post(DETECTION_ON_GENERATION_DETECTOR_ENDPOINT), Mock::new( MockRequest::json(GenerationDetectionRequest { prompt: prompt.into(), @@ -235,6 +238,7 @@ async fn test_detector_returns_503() -> Result<(), anyhow::Error> { Ok(()) } +// Asserts error 404 from detectors is propagated. #[test(tokio::test)] async fn test_detector_returns_404() -> Result<(), anyhow::Error> { let detector_name = ANSWER_RELEVANCE_DETECTOR; @@ -249,7 +253,7 @@ async fn test_detector_returns_404() -> Result<(), anyhow::Error> { // Add detector mock let mut mocks = MockSet::new(); mocks.insert( - MockPath::new(Method::POST, DETECTION_ON_GENERATION_DETECTOR_ENDPOINT), + MockPath::post(DETECTION_ON_GENERATION_DETECTOR_ENDPOINT), Mock::new( MockRequest::json(GenerationDetectionRequest { prompt: prompt.into(), @@ -301,6 +305,7 @@ async fn test_detector_returns_404() -> Result<(), anyhow::Error> { Ok(()) } +// Asserts error 500 from detectors is propagated with generic message. #[test(tokio::test)] async fn test_detector_returns_500() -> Result<(), anyhow::Error> { let detector_name = ANSWER_RELEVANCE_DETECTOR; @@ -315,7 +320,7 @@ async fn test_detector_returns_500() -> Result<(), anyhow::Error> { // Add detector mock let mut mocks = MockSet::new(); mocks.insert( - MockPath::new(Method::POST, DETECTION_ON_GENERATION_DETECTOR_ENDPOINT), + MockPath::post(DETECTION_ON_GENERATION_DETECTOR_ENDPOINT), Mock::new( MockRequest::json(GenerationDetectionRequest { prompt: prompt.into(), @@ -361,8 +366,10 @@ async fn test_detector_returns_500() -> Result<(), anyhow::Error> { Ok(()) } +/// Asserts error 500 is returned with a generic message when detectors return a response +/// that does not comply with the detector API. #[test(tokio::test)] -async fn test_detector_returns_non_compliant_message() -> Result<(), anyhow::Error> { +async fn test_detector_returns_invalid_message() -> Result<(), anyhow::Error> { let detector_name = ANSWER_RELEVANCE_DETECTOR; let prompt = "In 2014, what was the average height of men who were born in 1996?"; let generated_text = @@ -371,7 +378,7 @@ async fn test_detector_returns_non_compliant_message() -> Result<(), anyhow::Err // Add detector mock let mut mocks = MockSet::new(); mocks.insert( - MockPath::new(Method::POST, DETECTION_ON_GENERATION_DETECTOR_ENDPOINT), + MockPath::post(DETECTION_ON_GENERATION_DETECTOR_ENDPOINT), Mock::new( MockRequest::json(GenerationDetectionRequest { prompt: prompt.into(), @@ -419,6 +426,7 @@ async fn test_detector_returns_non_compliant_message() -> Result<(), anyhow::Err Ok(()) } +/// Asserts requests with extra fields return 422. #[test(tokio::test)] async fn test_request_with_extra_fields_returns_422() -> Result<(), anyhow::Error> { let detector_name = ANSWER_RELEVANCE_DETECTOR; @@ -463,6 +471,7 @@ async fn test_request_with_extra_fields_returns_422() -> Result<(), anyhow::Erro Ok(()) } +/// Asserts requests missing `prompt` return 422. 
#[test(tokio::test)] async fn test_request_missing_prompt_returns_422() -> Result<(), anyhow::Error> { let detector_name = ANSWER_RELEVANCE_DETECTOR; @@ -504,6 +513,7 @@ async fn test_request_missing_prompt_returns_422() -> Result<(), anyhow::Error> Ok(()) } +/// Asserts requests missing `generated_text` return 422. #[test(tokio::test)] async fn test_request_missing_generated_text_returns_422() -> Result<(), anyhow::Error> { let detector_name = ANSWER_RELEVANCE_DETECTOR; @@ -544,6 +554,7 @@ async fn test_request_missing_generated_text_returns_422() -> Result<(), anyhow: Ok(()) } +/// Asserts requests missing `detectors` return 422. #[test(tokio::test)] async fn test_request_missing_detectors_returns_422() -> Result<(), anyhow::Error> { let detector_name = ANSWER_RELEVANCE_DETECTOR; @@ -586,6 +597,7 @@ async fn test_request_missing_detectors_returns_422() -> Result<(), anyhow::Erro Ok(()) } +/// Asserts requests with empty `detectors` return 422. #[test(tokio::test)] async fn test_request_with_empty_detectors_returns_422() -> Result<(), anyhow::Error> { let detector_name = ANSWER_RELEVANCE_DETECTOR; From 69777f4ed31c1ef9668853247d846a4d1eddf78b Mon Sep 17 00:00:00 2001 From: Mateus Devino Date: Mon, 3 Mar 2025 09:04:19 -0300 Subject: [PATCH 090/117] Fix comments Signed-off-by: Mateus Devino --- tests/detection_on_generation.rs | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/tests/detection_on_generation.rs b/tests/detection_on_generation.rs index a31dd89a..966a6b17 100644 --- a/tests/detection_on_generation.rs +++ b/tests/detection_on_generation.rs @@ -40,7 +40,7 @@ use tracing::debug; pub mod common; -// Asserts detections below the default threshold are not returned. +/// Asserts detections below the default threshold are not returned. #[test(tokio::test)] async fn test_detection_below_default_threshold_is_not_returned() -> Result<(), anyhow::Error> { let detector_name = ANSWER_RELEVANCE_DETECTOR; @@ -104,7 +104,7 @@ async fn test_detection_below_default_threshold_is_not_returned() -> Result<(), Ok(()) } -// Asserts detections above the default threshold are returned. +/// Asserts detections above the default threshold are returned. #[test(tokio::test)] async fn test_detection_above_default_threshold_is_returned() -> Result<(), anyhow::Error> { let detector_name = ANSWER_RELEVANCE_DETECTOR; @@ -171,7 +171,7 @@ async fn test_detection_above_default_threshold_is_returned() -> Result<(), anyh Ok(()) } -// Asserts error 503 from detectors is propagated. +/// Asserts error 503 from detectors is propagated. #[test(tokio::test)] async fn test_detector_returns_503() -> Result<(), anyhow::Error> { let detector_name = ANSWER_RELEVANCE_DETECTOR; @@ -238,7 +238,7 @@ async fn test_detector_returns_503() -> Result<(), anyhow::Error> { Ok(()) } -// Asserts error 404 from detectors is propagated. +/// Asserts error 404 from detectors is propagated. #[test(tokio::test)] async fn test_detector_returns_404() -> Result<(), anyhow::Error> { let detector_name = ANSWER_RELEVANCE_DETECTOR; @@ -305,7 +305,7 @@ async fn test_detector_returns_404() -> Result<(), anyhow::Error> { Ok(()) } -// Asserts error 500 from detectors is propagated with generic message. +/// Asserts error 500 from detectors is propagated with generic message. 
#[test(tokio::test)] async fn test_detector_returns_500() -> Result<(), anyhow::Error> { let detector_name = ANSWER_RELEVANCE_DETECTOR; From 89d6adc68f1a79ea6950ab2e6bbeac8f1a7c55ac Mon Sep 17 00:00:00 2001 From: Mateus Devino Date: Wed, 26 Feb 2025 11:39:16 -0300 Subject: [PATCH 091/117] test case: context_docs_detection::test_detection_below_default_threshold_is_not_returned() Signed-off-by: Mateus Devino --- tests/common/detectors.rs | 2 + tests/common/orchestrator.rs | 1 + tests/context_docs_detection.rs | 99 +++++++++++++++++++++++++++++++++ tests/test_config.yaml | 6 ++ 4 files changed, 108 insertions(+) create mode 100644 tests/context_docs_detection.rs diff --git a/tests/common/detectors.rs b/tests/common/detectors.rs index 04ec7633..45f4dc33 100644 --- a/tests/common/detectors.rs +++ b/tests/common/detectors.rs @@ -19,7 +19,9 @@ pub const DETECTOR_NAME_ANGLE_BRACKETS_WHOLE_DOC: &str = "angle_brackets_detector_whole_doc"; pub const DETECTOR_NAME_ANGLE_BRACKETS_SENTENCE: &str = "angle_brackets_detector_sentence"; pub const ANSWER_RELEVANCE_DETECTOR: &str = "answer_relevance_detector"; +pub const FACT_CHECKING_DETECTOR: &str = "fact_checking_detector"; // Detector endpoints pub const TEXT_CONTENTS_DETECTOR_ENDPOINT: &str = "/api/v1/text/contents"; pub const DETECTION_ON_GENERATION_DETECTOR_ENDPOINT: &str = "/api/v1/text/generation"; +pub const CONTEXT_DOC_DETECTOR_ENDPOINT: &str = "/api/v1/text/context/doc"; diff --git a/tests/common/orchestrator.rs b/tests/common/orchestrator.rs index 14404cf7..66689a6d 100644 --- a/tests/common/orchestrator.rs +++ b/tests/common/orchestrator.rs @@ -42,6 +42,7 @@ pub const ORCHESTRATOR_STREAMING_ENDPOINT: &str = "/api/v1/task/server-streaming-classification-with-text-generation"; pub const ORCHESTRATOR_CONTENT_DETECTION_ENDPOINT: &str = "/api/v2/text/detection/content"; pub const ORCHESTRATOR_DETECTION_ON_GENERATION_ENDPOINT: &str = "/api/v2/text/detection/generated"; +pub const ORCHESTRATOR_CONTEXT_DOCS_DETECTION_ENDPOINT: &str = "/api/v2/text/detection/context"; // Messages pub const ORCHESTRATOR_INTERNAL_SERVER_ERROR_MESSAGE: &str = diff --git a/tests/context_docs_detection.rs b/tests/context_docs_detection.rs new file mode 100644 index 00000000..caaa5f5e --- /dev/null +++ b/tests/context_docs_detection.rs @@ -0,0 +1,99 @@ +/* + Copyright FMS Guardrails Orchestrator Authors + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. 
+ +*/ +use std::collections::HashMap; +use test_log::test; + +use common::{ + detectors::{CONTEXT_DOC_DETECTOR_ENDPOINT, FACT_CHECKING_DETECTOR}, + orchestrator::{ + TestOrchestratorServer, ORCHESTRATOR_CONFIG_FILE_PATH, + ORCHESTRATOR_CONTEXT_DOCS_DETECTION_ENDPOINT, + }, +}; +use fms_guardrails_orchestr8::{ + clients::detector::{ContextDocsDetectionRequest, ContextType}, + models::{ContextDocsHttpRequest, ContextDocsResult, DetectionResult, DetectorParams}, +}; +use hyper::StatusCode; +use mocktail::{prelude::*, utils::find_available_port}; +use tracing::debug; + +pub mod common; + +#[test(tokio::test)] +async fn test_detection_below_default_threshold_is_not_returned() -> Result<(), anyhow::Error> { + let detector_name = FACT_CHECKING_DETECTOR; + let content = "The average human height has decreased in the past century."; + let context = vec!["https://ourworldindata.org/human-height".to_string()]; + let detection = DetectionResult { + detection_type: "fact_check".into(), + detection: "is_accurate".into(), + detector_id: Some(detector_name.into()), + score: 0.23, + evidence: None, + }; + + // Add detector mock + let mut mocks = MockSet::new(); + mocks.insert( + MockPath::new(Method::POST, CONTEXT_DOC_DETECTOR_ENDPOINT), + Mock::new( + MockRequest::json(ContextDocsDetectionRequest { + detector_params: DetectorParams::new(), + content: content.into(), + context_type: ContextType::Url, + context: context.clone(), + }), + MockResponse::json(vec![detection.clone()]), + ), + ); + + // Start orchestrator server and its dependencies + let mock_detector_server = HttpMockServer::new(detector_name, mocks)?; + let orchestrator_server = TestOrchestratorServer::run( + ORCHESTRATOR_CONFIG_FILE_PATH, + find_available_port().unwrap(), + find_available_port().unwrap(), + None, + None, + Some(vec![mock_detector_server]), + None, + ) + .await?; + + // Make orchestrator call + let response = orchestrator_server + .post(ORCHESTRATOR_CONTEXT_DOCS_DETECTION_ENDPOINT) + .json(&ContextDocsHttpRequest { + detectors: HashMap::from([(detector_name.into(), DetectorParams::new())]), + content: content.into(), + context_type: ContextType::Url, + context, + }) + .send() + .await?; + + debug!("{response:#?}"); + + // assertions + assert!(response.status() == StatusCode::OK); + assert!( + response.json::().await? 
== ContextDocsResult { detections: vec![] } + ); + + Ok(()) +} diff --git a/tests/test_config.yaml b/tests/test_config.yaml index 4194f9ba..ac2d1354 100644 --- a/tests/test_config.yaml +++ b/tests/test_config.yaml @@ -39,3 +39,9 @@ detectors: hostname: localhost chunker_id: whole_doc_chunker default_threshold: 0.5 + fact_checking_detector: + type: text_context_doc + service: + hostname: localhost + chunker_id: whole_doc_chunker + default_threshold: 0.5 From 10a03efa061c657de4e189c9640bfa670cd5c70b Mon Sep 17 00:00:00 2001 From: Mateus Devino Date: Wed, 26 Feb 2025 11:43:17 -0300 Subject: [PATCH 092/117] test case: context_docs_detection::test_detection_above_default_threshold_is_returned() Signed-off-by: Mateus Devino --- tests/context_docs_detection.rs | 67 +++++++++++++++++++++++++++++++++ 1 file changed, 67 insertions(+) diff --git a/tests/context_docs_detection.rs b/tests/context_docs_detection.rs index caaa5f5e..789bd1d3 100644 --- a/tests/context_docs_detection.rs +++ b/tests/context_docs_detection.rs @@ -97,3 +97,70 @@ async fn test_detection_below_default_threshold_is_not_returned() -> Result<(), Ok(()) } + +#[test(tokio::test)] +async fn test_detection_above_default_threshold_is_returned() -> Result<(), anyhow::Error> { + let detector_name = FACT_CHECKING_DETECTOR; + let content = "The average human height has increased in the past century."; + let context = vec!["https://ourworldindata.org/human-height".to_string()]; + let detection = DetectionResult { + detection_type: "fact_check".into(), + detection: "is_accurate".into(), + detector_id: Some(detector_name.into()), + score: 0.91, + evidence: None, + }; + + // Add detector mock + let mut mocks = MockSet::new(); + mocks.insert( + MockPath::new(Method::POST, CONTEXT_DOC_DETECTOR_ENDPOINT), + Mock::new( + MockRequest::json(ContextDocsDetectionRequest { + detector_params: DetectorParams::new(), + content: content.into(), + context_type: ContextType::Url, + context: context.clone(), + }), + MockResponse::json(vec![detection.clone()]), + ), + ); + + // Start orchestrator server and its dependencies + let mock_detector_server = HttpMockServer::new(detector_name, mocks)?; + let orchestrator_server = TestOrchestratorServer::run( + ORCHESTRATOR_CONFIG_FILE_PATH, + find_available_port().unwrap(), + find_available_port().unwrap(), + None, + None, + Some(vec![mock_detector_server]), + None, + ) + .await?; + + // Make orchestrator call + let response = orchestrator_server + .post(ORCHESTRATOR_CONTEXT_DOCS_DETECTION_ENDPOINT) + .json(&ContextDocsHttpRequest { + detectors: HashMap::from([(detector_name.into(), DetectorParams::new())]), + content: content.into(), + context_type: ContextType::Url, + context, + }) + .send() + .await?; + + debug!("{response:#?}"); + + // assertions + assert!(response.status() == StatusCode::OK); + assert!( + response.json::().await? 
+ == ContextDocsResult { + detections: vec![detection] + } + ); + + Ok(()) +} From 4577327d881cc35fe741225157370813fc025430 Mon Sep 17 00:00:00 2001 From: Mateus Devino Date: Wed, 26 Feb 2025 11:53:03 -0300 Subject: [PATCH 093/117] test case: context_docs_detection::test_detectior_returns_503() Signed-off-by: Mateus Devino --- tests/context_docs_detection.rs | 68 +++++++++++++++++++++++++++++++++ 1 file changed, 68 insertions(+) diff --git a/tests/context_docs_detection.rs b/tests/context_docs_detection.rs index 789bd1d3..5eea7d46 100644 --- a/tests/context_docs_detection.rs +++ b/tests/context_docs_detection.rs @@ -19,6 +19,7 @@ use test_log::test; use common::{ detectors::{CONTEXT_DOC_DETECTOR_ENDPOINT, FACT_CHECKING_DETECTOR}, + errors::{DetectorError, OrchestratorError}, orchestrator::{ TestOrchestratorServer, ORCHESTRATOR_CONFIG_FILE_PATH, ORCHESTRATOR_CONTEXT_DOCS_DETECTION_ENDPOINT, @@ -164,3 +165,70 @@ async fn test_detection_above_default_threshold_is_returned() -> Result<(), anyh Ok(()) } + +#[test(tokio::test)] +async fn test_detector_returns_503() -> Result<(), anyhow::Error> { + let detector_name = FACT_CHECKING_DETECTOR; + let content = "The average human height has increased in the past century."; + let context = vec!["https://ourworldindata.org/human-height".to_string()]; + let detector_error = DetectorError { + code: 503, + message: "The detector is overloaded".into(), + }; + + // Add detector mock + let mut mocks = MockSet::new(); + mocks.insert( + MockPath::new(Method::POST, CONTEXT_DOC_DETECTOR_ENDPOINT), + Mock::new( + MockRequest::json(ContextDocsDetectionRequest { + detector_params: DetectorParams::new(), + content: content.into(), + context_type: ContextType::Url, + context: context.clone(), + }), + MockResponse::json(&detector_error).with_code(StatusCode::SERVICE_UNAVAILABLE), + ), + ); + + // Start orchestrator server and its dependencies + let mock_detector_server = HttpMockServer::new(detector_name, mocks)?; + let orchestrator_server = TestOrchestratorServer::run( + ORCHESTRATOR_CONFIG_FILE_PATH, + find_available_port().unwrap(), + find_available_port().unwrap(), + None, + None, + Some(vec![mock_detector_server]), + None, + ) + .await?; + + // Make orchestrator call + let response = orchestrator_server + .post(ORCHESTRATOR_CONTEXT_DOCS_DETECTION_ENDPOINT) + .json(&ContextDocsHttpRequest { + detectors: HashMap::from([(detector_name.into(), DetectorParams::new())]), + content: content.into(), + context_type: ContextType::Url, + context, + }) + .send() + .await?; + + debug!("{response:#?}"); + + // assertions + assert!(response.status() == StatusCode::SERVICE_UNAVAILABLE); + let response = response.json::().await?; + assert!(response.code == detector_error.code); + assert!( + response.details + == format!( + "detector request failed for `{}`: {}", + detector_name, detector_error.message + ) + ); + + Ok(()) +} From 5d438957bd7d94910cda1a24853a458bbe09ed8f Mon Sep 17 00:00:00 2001 From: Mateus Devino Date: Wed, 26 Feb 2025 11:57:04 -0300 Subject: [PATCH 094/117] test case: context_docs_detection::test_detectior_returns_404() Signed-off-by: Mateus Devino --- tests/context_docs_detection.rs | 67 +++++++++++++++++++++++++++++++++ 1 file changed, 67 insertions(+) diff --git a/tests/context_docs_detection.rs b/tests/context_docs_detection.rs index 5eea7d46..17b133c1 100644 --- a/tests/context_docs_detection.rs +++ b/tests/context_docs_detection.rs @@ -232,3 +232,70 @@ async fn test_detector_returns_503() -> Result<(), anyhow::Error> { Ok(()) } + +#[test(tokio::test)] 
+async fn test_detector_returns_404() -> Result<(), anyhow::Error> { + let detector_name = FACT_CHECKING_DETECTOR; + let content = "The average human height has increased in the past century."; + let context = vec!["https://ourworldindata.org/human-height".to_string()]; + let detector_error = DetectorError { + code: 404, + message: "The detector is overloaded".into(), + }; + + // Add detector mock + let mut mocks = MockSet::new(); + mocks.insert( + MockPath::new(Method::POST, CONTEXT_DOC_DETECTOR_ENDPOINT), + Mock::new( + MockRequest::json(ContextDocsDetectionRequest { + detector_params: DetectorParams::new(), + content: content.into(), + context_type: ContextType::Url, + context: context.clone(), + }), + MockResponse::json(&detector_error).with_code(StatusCode::NOT_FOUND), + ), + ); + + // Start orchestrator server and its dependencies + let mock_detector_server = HttpMockServer::new(detector_name, mocks)?; + let orchestrator_server = TestOrchestratorServer::run( + ORCHESTRATOR_CONFIG_FILE_PATH, + find_available_port().unwrap(), + find_available_port().unwrap(), + None, + None, + Some(vec![mock_detector_server]), + None, + ) + .await?; + + // Make orchestrator call + let response = orchestrator_server + .post(ORCHESTRATOR_CONTEXT_DOCS_DETECTION_ENDPOINT) + .json(&ContextDocsHttpRequest { + detectors: HashMap::from([(detector_name.into(), DetectorParams::new())]), + content: content.into(), + context_type: ContextType::Url, + context, + }) + .send() + .await?; + + debug!("{response:#?}"); + + // assertions + assert!(response.status() == StatusCode::NOT_FOUND); + let response = response.json::().await?; + assert!(response.code == detector_error.code); + assert!( + response.details + == format!( + "detector request failed for `{}`: {}", + detector_name, detector_error.message + ) + ); + + Ok(()) +} From a0e22e8b2ccfee74a2442ea45c770aa661920550 Mon Sep 17 00:00:00 2001 From: Mateus Devino Date: Wed, 26 Feb 2025 12:12:45 -0300 Subject: [PATCH 095/117] test case: context_docs_detection::test_orchestrator_receives_a_request_with_extra_fields() Signed-off-by: Mateus Devino --- tests/context_docs_detection.rs | 167 +++++++++++++++++++++++++++++++- 1 file changed, 166 insertions(+), 1 deletion(-) diff --git a/tests/context_docs_detection.rs b/tests/context_docs_detection.rs index 17b133c1..cdbbbb0d 100644 --- a/tests/context_docs_detection.rs +++ b/tests/context_docs_detection.rs @@ -14,6 +14,7 @@ limitations under the License. 
*/ +use serde_json::json; use std::collections::HashMap; use test_log::test; @@ -22,7 +23,7 @@ use common::{ errors::{DetectorError, OrchestratorError}, orchestrator::{ TestOrchestratorServer, ORCHESTRATOR_CONFIG_FILE_PATH, - ORCHESTRATOR_CONTEXT_DOCS_DETECTION_ENDPOINT, + ORCHESTRATOR_CONTEXT_DOCS_DETECTION_ENDPOINT, ORCHESTRATOR_INTERNAL_SERVER_ERROR_MESSAGE, }, }; use fms_guardrails_orchestr8::{ @@ -299,3 +300,167 @@ async fn test_detector_returns_404() -> Result<(), anyhow::Error> { Ok(()) } + +#[test(tokio::test)] +async fn test_detector_returns_500() -> Result<(), anyhow::Error> { + let detector_name = FACT_CHECKING_DETECTOR; + let content = "The average human height has increased in the past century."; + let context = vec!["https://ourworldindata.org/human-height".to_string()]; + let detector_error = DetectorError { + code: 500, + message: "The detector is overloaded".into(), + }; + + // Add detector mock + let mut mocks = MockSet::new(); + mocks.insert( + MockPath::new(Method::POST, CONTEXT_DOC_DETECTOR_ENDPOINT), + Mock::new( + MockRequest::json(ContextDocsDetectionRequest { + detector_params: DetectorParams::new(), + content: content.into(), + context_type: ContextType::Url, + context: context.clone(), + }), + MockResponse::json(&detector_error).with_code(StatusCode::INTERNAL_SERVER_ERROR), + ), + ); + + // Start orchestrator server and its dependencies + let mock_detector_server = HttpMockServer::new(detector_name, mocks)?; + let orchestrator_server = TestOrchestratorServer::run( + ORCHESTRATOR_CONFIG_FILE_PATH, + find_available_port().unwrap(), + find_available_port().unwrap(), + None, + None, + Some(vec![mock_detector_server]), + None, + ) + .await?; + + // Make orchestrator call + let response = orchestrator_server + .post(ORCHESTRATOR_CONTEXT_DOCS_DETECTION_ENDPOINT) + .json(&ContextDocsHttpRequest { + detectors: HashMap::from([(detector_name.into(), DetectorParams::new())]), + content: content.into(), + context_type: ContextType::Url, + context, + }) + .send() + .await?; + + debug!("{response:#?}"); + + // assertions + assert!(response.status() == StatusCode::INTERNAL_SERVER_ERROR); + let response = response.json::().await?; + assert!(response.code == detector_error.code); + assert!(response.details == ORCHESTRATOR_INTERNAL_SERVER_ERROR_MESSAGE); + + Ok(()) +} + +#[test(tokio::test)] +async fn test_detector_returns_invalid_response() -> Result<(), anyhow::Error> { + let detector_name = FACT_CHECKING_DETECTOR; + let content = "The average human height has increased in the past century."; + let context = vec!["https://ourworldindata.org/human-height".to_string()]; + + // Add detector mock + let mut mocks = MockSet::new(); + mocks.insert( + MockPath::new(Method::POST, CONTEXT_DOC_DETECTOR_ENDPOINT), + Mock::new( + MockRequest::json(ContextDocsDetectionRequest { + detector_params: DetectorParams::new(), + content: content.into(), + context_type: ContextType::Url, + context: context.clone(), + }), + MockResponse::json( + &json!({"message": "This response does not comply with the Detector API"}), + ) + .with_code(StatusCode::OK), + ), + ); + + // Start orchestrator server and its dependencies + let mock_detector_server = HttpMockServer::new(detector_name, mocks)?; + let orchestrator_server = TestOrchestratorServer::run( + ORCHESTRATOR_CONFIG_FILE_PATH, + find_available_port().unwrap(), + find_available_port().unwrap(), + None, + None, + Some(vec![mock_detector_server]), + None, + ) + .await?; + + // Make orchestrator call + let response = orchestrator_server + 
.post(ORCHESTRATOR_CONTEXT_DOCS_DETECTION_ENDPOINT) + .json(&ContextDocsHttpRequest { + detectors: HashMap::from([(detector_name.into(), DetectorParams::new())]), + content: content.into(), + context_type: ContextType::Url, + context, + }) + .send() + .await?; + + debug!("{response:#?}"); + + // assertions + assert!(response.status() == StatusCode::INTERNAL_SERVER_ERROR); + let response = response.json::().await?; + assert!(response.code == 500); + assert!(response.details == ORCHESTRATOR_INTERNAL_SERVER_ERROR_MESSAGE); + + Ok(()) +} + +#[test(tokio::test)] +async fn test_orchestrator_receives_a_request_with_extra_fields() -> Result<(), anyhow::Error> { + let detector_name = FACT_CHECKING_DETECTOR; + let content = "The average human height has increased in the past century."; + let context = vec!["https://ourworldindata.org/human-height".to_string()]; + + // Start orchestrator server and its dependencies + let mock_detector_server = HttpMockServer::new(detector_name, MockSet::new())?; + let orchestrator_server = TestOrchestratorServer::run( + ORCHESTRATOR_CONFIG_FILE_PATH, + find_available_port().unwrap(), + find_available_port().unwrap(), + None, + None, + Some(vec![mock_detector_server]), + None, + ) + .await?; + + // Make orchestrator call + let response = orchestrator_server + .post(ORCHESTRATOR_CONTEXT_DOCS_DETECTION_ENDPOINT) + .json(&json!({ + "detectors": {detector_name: {}}, + "content": content, + "context_type": "url", + "context": context, + "extra_args": true + })) + .send() + .await?; + + debug!("{response:#?}"); + + // assertions + assert!(response.status() == StatusCode::UNPROCESSABLE_ENTITY); + let response = response.json::().await?; + assert!(response.code == 422); + assert!(response.details.contains("unknown field `extra_args`")); + + Ok(()) +} From 07c9044f28fdc0c50ac56952cfa33ddf060c6c35 Mon Sep 17 00:00:00 2001 From: Mateus Devino Date: Wed, 26 Feb 2025 12:14:44 -0300 Subject: [PATCH 096/117] test case: context_docs_detection::test_orchestrator_receives_a_request_missing_content() Signed-off-by: Mateus Devino --- tests/context_docs_detection.rs | 40 +++++++++++++++++++++++++++++++++ 1 file changed, 40 insertions(+) diff --git a/tests/context_docs_detection.rs b/tests/context_docs_detection.rs index cdbbbb0d..cedcde85 100644 --- a/tests/context_docs_detection.rs +++ b/tests/context_docs_detection.rs @@ -464,3 +464,43 @@ async fn test_orchestrator_receives_a_request_with_extra_fields() -> Result<(), Ok(()) } + +#[test(tokio::test)] +async fn test_orchestrator_receives_a_request_missing_content() -> Result<(), anyhow::Error> { + let detector_name = FACT_CHECKING_DETECTOR; + let context = vec!["https://ourworldindata.org/human-height".to_string()]; + + // Start orchestrator server and its dependencies + let mock_detector_server = HttpMockServer::new(detector_name, MockSet::new())?; + let orchestrator_server = TestOrchestratorServer::run( + ORCHESTRATOR_CONFIG_FILE_PATH, + find_available_port().unwrap(), + find_available_port().unwrap(), + None, + None, + Some(vec![mock_detector_server]), + None, + ) + .await?; + + // Make orchestrator call + let response = orchestrator_server + .post(ORCHESTRATOR_CONTEXT_DOCS_DETECTION_ENDPOINT) + .json(&json!({ + "detectors": {detector_name: {}}, + "context_type": "url", + "context": context, + })) + .send() + .await?; + + debug!("{response:#?}"); + + // assertions + assert!(response.status() == StatusCode::UNPROCESSABLE_ENTITY); + let response = response.json::().await?; + assert!(response.code == 422); + 
assert!(response.details.contains("missing field `content`")); + + Ok(()) +} From 2fcd89fb74559c22885d4a55eff91ece1781ebf2 Mon Sep 17 00:00:00 2001 From: Mateus Devino Date: Wed, 26 Feb 2025 12:20:23 -0300 Subject: [PATCH 097/117] test case: context_docs_detection::test_orchestrator_receives_a_request_missing_context() Signed-off-by: Mateus Devino --- tests/context_docs_detection.rs | 40 +++++++++++++++++++++++++++++++++ 1 file changed, 40 insertions(+) diff --git a/tests/context_docs_detection.rs b/tests/context_docs_detection.rs index cedcde85..8074b419 100644 --- a/tests/context_docs_detection.rs +++ b/tests/context_docs_detection.rs @@ -504,3 +504,43 @@ async fn test_orchestrator_receives_a_request_missing_content() -> Result<(), an Ok(()) } + +#[test(tokio::test)] +async fn test_orchestrator_receives_a_request_missing_context() -> Result<(), anyhow::Error> { + let detector_name = FACT_CHECKING_DETECTOR; + let content = "The average human height has increased in the past century."; + + // Start orchestrator server and its dependencies + let mock_detector_server = HttpMockServer::new(detector_name, MockSet::new())?; + let orchestrator_server = TestOrchestratorServer::run( + ORCHESTRATOR_CONFIG_FILE_PATH, + find_available_port().unwrap(), + find_available_port().unwrap(), + None, + None, + Some(vec![mock_detector_server]), + None, + ) + .await?; + + // Make orchestrator call + let response = orchestrator_server + .post(ORCHESTRATOR_CONTEXT_DOCS_DETECTION_ENDPOINT) + .json(&json!({ + "detectors": {detector_name: {}}, + "context_type": "url", + "content": content, + })) + .send() + .await?; + + debug!("{response:#?}"); + + // assertions + assert!(response.status() == StatusCode::UNPROCESSABLE_ENTITY); + let response = response.json::().await?; + assert!(response.code == 422); + assert!(response.details.contains("missing field `context`")); + + Ok(()) +} From b1c132ea79be1942939ed8151f2b1b82bc23f581 Mon Sep 17 00:00:00 2001 From: Mateus Devino Date: Wed, 26 Feb 2025 12:22:26 -0300 Subject: [PATCH 098/117] test case: context_docs_detection::test_orchestrator_receives_a_request_missing_context_type() Signed-off-by: Mateus Devino --- tests/context_docs_detection.rs | 41 +++++++++++++++++++++++++++++++++ 1 file changed, 41 insertions(+) diff --git a/tests/context_docs_detection.rs b/tests/context_docs_detection.rs index 8074b419..7bfc7b22 100644 --- a/tests/context_docs_detection.rs +++ b/tests/context_docs_detection.rs @@ -544,3 +544,44 @@ async fn test_orchestrator_receives_a_request_missing_context() -> Result<(), an Ok(()) } + +#[test(tokio::test)] +async fn test_orchestrator_receives_a_request_missing_context_type() -> Result<(), anyhow::Error> { + let detector_name = FACT_CHECKING_DETECTOR; + let content = "The average human height has increased in the past century."; + let context = vec!["https://ourworldindata.org/human-height".to_string()]; + + // Start orchestrator server and its dependencies + let mock_detector_server = HttpMockServer::new(detector_name, MockSet::new())?; + let orchestrator_server = TestOrchestratorServer::run( + ORCHESTRATOR_CONFIG_FILE_PATH, + find_available_port().unwrap(), + find_available_port().unwrap(), + None, + None, + Some(vec![mock_detector_server]), + None, + ) + .await?; + + // Make orchestrator call + let response = orchestrator_server + .post(ORCHESTRATOR_CONTEXT_DOCS_DETECTION_ENDPOINT) + .json(&json!({ + "detectors": {detector_name: {}}, + "context": context, + "content": content, + })) + .send() + .await?; + + debug!("{response:#?}"); + + // 
assertions + assert!(response.status() == StatusCode::UNPROCESSABLE_ENTITY); + let response = response.json::().await?; + assert!(response.code == 422); + assert!(response.details.contains("missing field `context_type`")); + + Ok(()) +} From e5e38e21cbfb70bfe706d38bbb98169fdabd47ec Mon Sep 17 00:00:00 2001 From: Mateus Devino Date: Wed, 26 Feb 2025 12:26:10 -0300 Subject: [PATCH 099/117] test case: context_docs_detection::test_orchestrator_receives_a_request_with_invalid_context_type() Signed-off-by: Mateus Devino --- tests/context_docs_detection.rs | 46 +++++++++++++++++++++++++++++++++ 1 file changed, 46 insertions(+) diff --git a/tests/context_docs_detection.rs b/tests/context_docs_detection.rs index 7bfc7b22..1fd54ffe 100644 --- a/tests/context_docs_detection.rs +++ b/tests/context_docs_detection.rs @@ -585,3 +585,49 @@ async fn test_orchestrator_receives_a_request_missing_context_type() -> Result<( Ok(()) } + +#[test(tokio::test)] +async fn test_orchestrator_receives_a_request_with_invalid_context_type( +) -> Result<(), anyhow::Error> { + let detector_name = FACT_CHECKING_DETECTOR; + let content = "The average human height has increased in the past century."; + let context = vec!["https://ourworldindata.org/human-height".to_string()]; + + // Start orchestrator server and its dependencies + let mock_detector_server = HttpMockServer::new(detector_name, MockSet::new())?; + let orchestrator_server = TestOrchestratorServer::run( + ORCHESTRATOR_CONFIG_FILE_PATH, + find_available_port().unwrap(), + find_available_port().unwrap(), + None, + None, + Some(vec![mock_detector_server]), + None, + ) + .await?; + + // Make orchestrator call + let response = orchestrator_server + .post(ORCHESTRATOR_CONTEXT_DOCS_DETECTION_ENDPOINT) + .json(&json!({ + "detectors": {detector_name: {}}, + "content": content, + "context": context, + "context_type": "thoughts" + })) + .send() + .await?; + + debug!("{response:#?}"); + + // assertions + assert!(response.status() == StatusCode::UNPROCESSABLE_ENTITY); + let response = response.json::().await?; + debug!("{response:#?}"); + assert!(response.code == 422); + assert!(response + .details + .starts_with("context_type: unknown variant `thoughts`, expected `docs` or `url`")); + + Ok(()) +} From 08e4e294b1c244d2d7a81bb321998691d9951c4c Mon Sep 17 00:00:00 2001 From: Mateus Devino Date: Wed, 26 Feb 2025 12:30:38 -0300 Subject: [PATCH 100/117] test case: context_docs_detection::test_orchestrator_receives_a_request_missing_detectors() Signed-off-by: Mateus Devino --- tests/context_docs_detection.rs | 42 +++++++++++++++++++++++++++++++++ 1 file changed, 42 insertions(+) diff --git a/tests/context_docs_detection.rs b/tests/context_docs_detection.rs index 1fd54ffe..0b43e092 100644 --- a/tests/context_docs_detection.rs +++ b/tests/context_docs_detection.rs @@ -631,3 +631,45 @@ async fn test_orchestrator_receives_a_request_with_invalid_context_type( Ok(()) } + +#[test(tokio::test)] +async fn test_orchestrator_receives_a_request_missing_detectors() -> Result<(), anyhow::Error> { + let detector_name = FACT_CHECKING_DETECTOR; + let content = "The average human height has increased in the past century."; + let context = vec!["https://ourworldindata.org/human-height".to_string()]; + + // Start orchestrator server and its dependencies + let mock_detector_server = HttpMockServer::new(detector_name, MockSet::new())?; + let orchestrator_server = TestOrchestratorServer::run( + ORCHESTRATOR_CONFIG_FILE_PATH, + find_available_port().unwrap(), + find_available_port().unwrap(), + None, + 
None, + Some(vec![mock_detector_server]), + None, + ) + .await?; + + // Make orchestrator call + let response = orchestrator_server + .post(ORCHESTRATOR_CONTEXT_DOCS_DETECTION_ENDPOINT) + .json(&json!({ + "content": content, + "context": context, + "context_type": "docs" + })) + .send() + .await?; + + debug!("{response:#?}"); + + // assertions + assert!(response.status() == StatusCode::UNPROCESSABLE_ENTITY); + let response = response.json::().await?; + debug!("{response:#?}"); + assert!(response.code == 422); + assert!(response.details.starts_with("missing field `detectors`")); + + Ok(()) +} From d807372960cdc7d1ceeb79072617f854fa28ff6e Mon Sep 17 00:00:00 2001 From: Mateus Devino Date: Wed, 26 Feb 2025 12:34:55 -0300 Subject: [PATCH 101/117] test case: context_docs_detection::test_orchestrator_receives_a_request_with_invalid_detectors() Signed-off-by: Mateus Devino --- tests/context_docs_detection.rs | 42 +++++++++++++++++++++++++++++++++ 1 file changed, 42 insertions(+) diff --git a/tests/context_docs_detection.rs b/tests/context_docs_detection.rs index 0b43e092..b1e8eb89 100644 --- a/tests/context_docs_detection.rs +++ b/tests/context_docs_detection.rs @@ -673,3 +673,45 @@ async fn test_orchestrator_receives_a_request_missing_detectors() -> Result<(), Ok(()) } +#[test(tokio::test)] +async fn test_orchestrator_receives_a_request_with_invalid_detectors() -> Result<(), anyhow::Error> +{ + let detector_name = FACT_CHECKING_DETECTOR; + let content = "The average human height has increased in the past century."; + let context = vec!["https://ourworldindata.org/human-height".to_string()]; + + // Start orchestrator server and its dependencies + let mock_detector_server = HttpMockServer::new(detector_name, MockSet::new())?; + let orchestrator_server = TestOrchestratorServer::run( + ORCHESTRATOR_CONFIG_FILE_PATH, + find_available_port().unwrap(), + find_available_port().unwrap(), + None, + None, + Some(vec![mock_detector_server]), + None, + ) + .await?; + + // Make orchestrator call + let response = orchestrator_server + .post(ORCHESTRATOR_CONTEXT_DOCS_DETECTION_ENDPOINT) + .json(&json!({ + "content": content, + "context": context, + "context_type": "docs" + })) + .send() + .await?; + + debug!("{response:#?}"); + + // assertions + assert!(response.status() == StatusCode::UNPROCESSABLE_ENTITY); + let response = response.json::().await?; + debug!("{response:#?}"); + assert!(response.code == 422); + assert!(response.details.starts_with("missing field `detectors`")); + + Ok(()) +} From 5bfc573f93b88846f5f3348e2a34b0bde7479b5d Mon Sep 17 00:00:00 2001 From: Mateus Devino Date: Fri, 28 Feb 2025 10:32:38 -0300 Subject: [PATCH 102/117] Update tests to use mocktail 0.1.2-alpha Signed-off-by: Mateus Devino --- tests/context_docs_detection.rs | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/tests/context_docs_detection.rs b/tests/context_docs_detection.rs index b1e8eb89..b5bb0f99 100644 --- a/tests/context_docs_detection.rs +++ b/tests/context_docs_detection.rs @@ -52,7 +52,7 @@ async fn test_detection_below_default_threshold_is_not_returned() -> Result<(), // Add detector mock let mut mocks = MockSet::new(); mocks.insert( - MockPath::new(Method::POST, CONTEXT_DOC_DETECTOR_ENDPOINT), + MockPath::post(CONTEXT_DOC_DETECTOR_ENDPOINT), Mock::new( MockRequest::json(ContextDocsDetectionRequest { detector_params: DetectorParams::new(), @@ -116,7 +116,7 @@ async fn test_detection_above_default_threshold_is_returned() -> Result<(), anyh // Add detector mock let mut mocks = 
MockSet::new(); mocks.insert( - MockPath::new(Method::POST, CONTEXT_DOC_DETECTOR_ENDPOINT), + MockPath::post(CONTEXT_DOC_DETECTOR_ENDPOINT), Mock::new( MockRequest::json(ContextDocsDetectionRequest { detector_params: DetectorParams::new(), @@ -180,7 +180,7 @@ async fn test_detector_returns_503() -> Result<(), anyhow::Error> { // Add detector mock let mut mocks = MockSet::new(); mocks.insert( - MockPath::new(Method::POST, CONTEXT_DOC_DETECTOR_ENDPOINT), + MockPath::post(CONTEXT_DOC_DETECTOR_ENDPOINT), Mock::new( MockRequest::json(ContextDocsDetectionRequest { detector_params: DetectorParams::new(), @@ -247,7 +247,7 @@ async fn test_detector_returns_404() -> Result<(), anyhow::Error> { // Add detector mock let mut mocks = MockSet::new(); mocks.insert( - MockPath::new(Method::POST, CONTEXT_DOC_DETECTOR_ENDPOINT), + MockPath::post(CONTEXT_DOC_DETECTOR_ENDPOINT), Mock::new( MockRequest::json(ContextDocsDetectionRequest { detector_params: DetectorParams::new(), @@ -314,7 +314,7 @@ async fn test_detector_returns_500() -> Result<(), anyhow::Error> { // Add detector mock let mut mocks = MockSet::new(); mocks.insert( - MockPath::new(Method::POST, CONTEXT_DOC_DETECTOR_ENDPOINT), + MockPath::post(CONTEXT_DOC_DETECTOR_ENDPOINT), Mock::new( MockRequest::json(ContextDocsDetectionRequest { detector_params: DetectorParams::new(), @@ -371,7 +371,7 @@ async fn test_detector_returns_invalid_response() -> Result<(), anyhow::Error> { // Add detector mock let mut mocks = MockSet::new(); mocks.insert( - MockPath::new(Method::POST, CONTEXT_DOC_DETECTOR_ENDPOINT), + MockPath::post(CONTEXT_DOC_DETECTOR_ENDPOINT), Mock::new( MockRequest::json(ContextDocsDetectionRequest { detector_params: DetectorParams::new(), From cb817bc18015f80409fc402c1d63dcbd001dbb55 Mon Sep 17 00:00:00 2001 From: Mateus Devino Date: Mon, 3 Mar 2025 08:55:41 -0300 Subject: [PATCH 103/117] Add comments to context_docs tests Signed-off-by: Mateus Devino --- tests/context_docs_detection.rs | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/tests/context_docs_detection.rs b/tests/context_docs_detection.rs index b5bb0f99..5fb957de 100644 --- a/tests/context_docs_detection.rs +++ b/tests/context_docs_detection.rs @@ -36,6 +36,7 @@ use tracing::debug; pub mod common; +// Asserts detections below the default threshold are not returned. #[test(tokio::test)] async fn test_detection_below_default_threshold_is_not_returned() -> Result<(), anyhow::Error> { let detector_name = FACT_CHECKING_DETECTOR; @@ -100,6 +101,7 @@ async fn test_detection_below_default_threshold_is_not_returned() -> Result<(), Ok(()) } +// Asserts detections above the default threshold are returned. #[test(tokio::test)] async fn test_detection_above_default_threshold_is_returned() -> Result<(), anyhow::Error> { let detector_name = FACT_CHECKING_DETECTOR; @@ -167,6 +169,7 @@ async fn test_detection_above_default_threshold_is_returned() -> Result<(), anyh Ok(()) } +// Asserts error 503 from detectors is propagated. #[test(tokio::test)] async fn test_detector_returns_503() -> Result<(), anyhow::Error> { let detector_name = FACT_CHECKING_DETECTOR; @@ -234,6 +237,7 @@ async fn test_detector_returns_503() -> Result<(), anyhow::Error> { Ok(()) } +// Asserts error 404 from detectors is propagated. 
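
// The 503 test above and the 404 test below both expect the detector's error to be
// propagated with the same detail format. A minimal sketch of that contract, assuming
// `DetectorError::message` is a `String` as its construction in these tests suggests;
// the helper name `expected_propagated_details` is illustrative, not part of the suite.
fn expected_propagated_details(detector_name: &str, detector_error: &DetectorError) -> String {
    // Mirrors the format string asserted against in the 503/404 tests.
    format!(
        "detector request failed for `{}`: {}",
        detector_name, detector_error.message
    )
}
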
#[test(tokio::test)] async fn test_detector_returns_404() -> Result<(), anyhow::Error> { let detector_name = FACT_CHECKING_DETECTOR; @@ -301,6 +305,7 @@ async fn test_detector_returns_404() -> Result<(), anyhow::Error> { Ok(()) } +// Asserts error 500 from detectors is propagated with generic message. #[test(tokio::test)] async fn test_detector_returns_500() -> Result<(), anyhow::Error> { let detector_name = FACT_CHECKING_DETECTOR; @@ -362,6 +367,8 @@ async fn test_detector_returns_500() -> Result<(), anyhow::Error> { Ok(()) } +/// Asserts error 500 is returned with a generic message when detectors return a response +/// that does not comply with the detector API. #[test(tokio::test)] async fn test_detector_returns_invalid_response() -> Result<(), anyhow::Error> { let detector_name = FACT_CHECKING_DETECTOR; @@ -422,6 +429,7 @@ async fn test_detector_returns_invalid_response() -> Result<(), anyhow::Error> { Ok(()) } +/// Asserts requests with extra fields return 422. #[test(tokio::test)] async fn test_orchestrator_receives_a_request_with_extra_fields() -> Result<(), anyhow::Error> { let detector_name = FACT_CHECKING_DETECTOR; @@ -465,6 +473,7 @@ async fn test_orchestrator_receives_a_request_with_extra_fields() -> Result<(), Ok(()) } +/// Asserts requests missing `content` return 422. #[test(tokio::test)] async fn test_orchestrator_receives_a_request_missing_content() -> Result<(), anyhow::Error> { let detector_name = FACT_CHECKING_DETECTOR; @@ -505,6 +514,7 @@ async fn test_orchestrator_receives_a_request_missing_content() -> Result<(), an Ok(()) } +/// Asserts requests missing `context` return 422. #[test(tokio::test)] async fn test_orchestrator_receives_a_request_missing_context() -> Result<(), anyhow::Error> { let detector_name = FACT_CHECKING_DETECTOR; @@ -545,6 +555,7 @@ async fn test_orchestrator_receives_a_request_missing_context() -> Result<(), an Ok(()) } +/// Asserts requests missing `context_type` return 422. #[test(tokio::test)] async fn test_orchestrator_receives_a_request_missing_context_type() -> Result<(), anyhow::Error> { let detector_name = FACT_CHECKING_DETECTOR; @@ -586,6 +597,7 @@ async fn test_orchestrator_receives_a_request_missing_context_type() -> Result<( Ok(()) } +/// Asserts requests with invalid `context_type` return 422. #[test(tokio::test)] async fn test_orchestrator_receives_a_request_with_invalid_context_type( ) -> Result<(), anyhow::Error> { @@ -632,6 +644,7 @@ async fn test_orchestrator_receives_a_request_with_invalid_context_type( Ok(()) } +/// Asserts requests missing `detectors` return 422. #[test(tokio::test)] async fn test_orchestrator_receives_a_request_missing_detectors() -> Result<(), anyhow::Error> { let detector_name = FACT_CHECKING_DETECTOR; @@ -673,6 +686,8 @@ async fn test_orchestrator_receives_a_request_missing_detectors() -> Result<(), Ok(()) } + +/// Asserts requests with empty `detectors` return 422. 
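
// A sketch of the request body that exercises the empty-`detectors` case described by
// the comment above, assuming the orchestrator rejects an empty detector map the same
// way the chat detection endpoint does later in this series ("`detectors` is required");
// the helper name `empty_detectors_body` is hypothetical.
fn empty_detectors_body() -> serde_json::Value {
    json!({
        "content": "The average human height has increased in the past century.",
        "context": ["https://ourworldindata.org/human-height"],
        "context_type": "url",
        "detectors": {}
    })
}
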
#[test(tokio::test)] async fn test_orchestrator_receives_a_request_with_invalid_detectors() -> Result<(), anyhow::Error> { From a0da2c7f87abf60d768521c55afffd535a87ce81 Mon Sep 17 00:00:00 2001 From: Mateus Devino Date: Mon, 3 Mar 2025 09:05:46 -0300 Subject: [PATCH 104/117] Fix comments Signed-off-by: Mateus Devino --- tests/context_docs_detection.rs | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/tests/context_docs_detection.rs b/tests/context_docs_detection.rs index 5fb957de..b26f68ac 100644 --- a/tests/context_docs_detection.rs +++ b/tests/context_docs_detection.rs @@ -36,7 +36,7 @@ use tracing::debug; pub mod common; -// Asserts detections below the default threshold are not returned. +/// Asserts detections below the default threshold are not returned. #[test(tokio::test)] async fn test_detection_below_default_threshold_is_not_returned() -> Result<(), anyhow::Error> { let detector_name = FACT_CHECKING_DETECTOR; @@ -101,7 +101,7 @@ async fn test_detection_below_default_threshold_is_not_returned() -> Result<(), Ok(()) } -// Asserts detections above the default threshold are returned. +/// Asserts detections above the default threshold are returned. #[test(tokio::test)] async fn test_detection_above_default_threshold_is_returned() -> Result<(), anyhow::Error> { let detector_name = FACT_CHECKING_DETECTOR; @@ -169,7 +169,7 @@ async fn test_detection_above_default_threshold_is_returned() -> Result<(), anyh Ok(()) } -// Asserts error 503 from detectors is propagated. +/// Asserts error 503 from detectors is propagated. #[test(tokio::test)] async fn test_detector_returns_503() -> Result<(), anyhow::Error> { let detector_name = FACT_CHECKING_DETECTOR; @@ -237,7 +237,7 @@ async fn test_detector_returns_503() -> Result<(), anyhow::Error> { Ok(()) } -// Asserts error 404 from detectors is propagated. +/// Asserts error 404 from detectors is propagated. #[test(tokio::test)] async fn test_detector_returns_404() -> Result<(), anyhow::Error> { let detector_name = FACT_CHECKING_DETECTOR; @@ -305,7 +305,7 @@ async fn test_detector_returns_404() -> Result<(), anyhow::Error> { Ok(()) } -// Asserts error 500 from detectors is propagated with generic message. +/// Asserts error 500 from detectors is propagated with generic message. 
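
// Every test in this file repeats the same bootstrap: build one HttpMockServer for the
// detector, then start the orchestrator against it. A possible shared helper, shown only
// as a sketch: it assumes `TestOrchestratorServer::run` returns the server value these
// tests bind, and the name `orchestrator_with_detector` is hypothetical.
async fn orchestrator_with_detector(
    detector_name: &str,
    mocks: MockSet,
) -> Result<TestOrchestratorServer, anyhow::Error> {
    let mock_detector_server = HttpMockServer::new(detector_name, mocks)?;
    // Same arguments every test in this file currently passes inline.
    let orchestrator_server = TestOrchestratorServer::run(
        ORCHESTRATOR_CONFIG_FILE_PATH,
        find_available_port().unwrap(),
        find_available_port().unwrap(),
        None,
        None,
        Some(vec![mock_detector_server]),
        None,
    )
    .await?;
    Ok(orchestrator_server)
}
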
#[test(tokio::test)] async fn test_detector_returns_500() -> Result<(), anyhow::Error> { let detector_name = FACT_CHECKING_DETECTOR; From ad2053b3e27a7a6ca2d40bdfd41ca094b8533550 Mon Sep 17 00:00:00 2001 From: Mateus Devino Date: Mon, 24 Feb 2025 14:06:40 -0300 Subject: [PATCH 105/117] Rename tests/detection_content.rs Signed-off-by: Mateus Devino --- Cargo.lock | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Cargo.lock b/Cargo.lock index 2168958c..d535c15f 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1448,7 +1448,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "fc2f4eb4bc735547cfed7c0a4922cbd04a4655978c09b54f1f7b228750664c34" dependencies = [ "cfg-if", - "windows-targets 0.48.5", + "windows-targets 0.52.6", ] [[package]] From 9ebbe92e3912695c4dbfbcd430e965a8201dcb4d Mon Sep 17 00:00:00 2001 From: Mateus Devino Date: Wed, 26 Feb 2025 14:23:09 -0300 Subject: [PATCH 106/117] test case: chat_detection::test_detection_below_default_threshold_is_not_returned() Signed-off-by: Mateus Devino --- tests/chat_detection.rs | 107 +++++++++++++++++++++++++++++++++++ tests/common/detectors.rs | 2 + tests/common/orchestrator.rs | 1 + tests/test_config.yaml | 6 ++ 4 files changed, 116 insertions(+) create mode 100644 tests/chat_detection.rs diff --git a/tests/chat_detection.rs b/tests/chat_detection.rs new file mode 100644 index 00000000..2f7a0274 --- /dev/null +++ b/tests/chat_detection.rs @@ -0,0 +1,107 @@ +/* + Copyright FMS Guardrails Orchestrator Authors + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. 
+ +*/ +use std::collections::HashMap; +use test_log::test; + +use common::{ + detectors::{CHAT_DETECTOR_ENDPOINT, PII_DETECTOR}, + orchestrator::{ + TestOrchestratorServer, ORCHESTRATOR_CHAT_DETECTION_ENDPOINT, ORCHESTRATOR_CONFIG_FILE_PATH, + }, +}; +use fms_guardrails_orchestr8::{ + clients::{ + detector::ChatDetectionRequest, + openai::{Content, Message, Role}, + }, + models::{ChatDetectionHttpRequest, ChatDetectionResult, DetectionResult, DetectorParams}, +}; +use hyper::StatusCode; +use mocktail::{prelude::*, utils::find_available_port}; +use tracing::debug; + +pub mod common; + +#[test(tokio::test)] +async fn test_detection_below_default_threshold_is_not_returned() -> Result<(), anyhow::Error> { + let detector_name = PII_DETECTOR; + let messages = vec![ + Message { + role: Role::User, + content: Some(Content::Text("Hi there!".into())), + ..Default::default() + }, + Message { + role: Role::Assistant, + content: Some(Content::Text("Hello!".into())), + ..Default::default() + }, + ]; + let detection = DetectionResult { + detection_type: "pii".into(), + detection: "is_pii".into(), + detector_id: Some(detector_name.into()), + score: 0.01, + evidence: None, + }; + + // Add detector mock + let mut mocks = MockSet::new(); + mocks.insert( + MockPath::new(Method::POST, CHAT_DETECTOR_ENDPOINT), + Mock::new( + MockRequest::json(ChatDetectionRequest { + messages: messages.clone(), + detector_params: DetectorParams::new(), + }), + MockResponse::json(vec![detection.clone()]), + ), + ); + + // Start orchestrator server and its dependencies + let mock_detector_server = HttpMockServer::new(detector_name, mocks)?; + let orchestrator_server = TestOrchestratorServer::run( + ORCHESTRATOR_CONFIG_FILE_PATH, + find_available_port().unwrap(), + find_available_port().unwrap(), + None, + None, + Some(vec![mock_detector_server]), + None, + ) + .await?; + + // Make orchestrator call + let response = orchestrator_server + .post(ORCHESTRATOR_CHAT_DETECTION_ENDPOINT) + .json(&ChatDetectionHttpRequest { + detectors: HashMap::from([(detector_name.into(), DetectorParams::new())]), + messages, + }) + .send() + .await?; + + debug!("{response:#?}"); + + // assertions + assert!(response.status() == StatusCode::OK); + assert!( + response.json::().await? 
== ChatDetectionResult { detections: vec![] } + ); + + Ok(()) +} diff --git a/tests/common/detectors.rs b/tests/common/detectors.rs index 45f4dc33..51cf95f0 100644 --- a/tests/common/detectors.rs +++ b/tests/common/detectors.rs @@ -20,8 +20,10 @@ pub const DETECTOR_NAME_ANGLE_BRACKETS_WHOLE_DOC: &str = "angle_brackets_detecto pub const DETECTOR_NAME_ANGLE_BRACKETS_SENTENCE: &str = "angle_brackets_detector_sentence"; pub const ANSWER_RELEVANCE_DETECTOR: &str = "answer_relevance_detector"; pub const FACT_CHECKING_DETECTOR: &str = "fact_checking_detector"; +pub const PII_DETECTOR: &str = "pii_detector"; // Detector endpoints pub const TEXT_CONTENTS_DETECTOR_ENDPOINT: &str = "/api/v1/text/contents"; pub const DETECTION_ON_GENERATION_DETECTOR_ENDPOINT: &str = "/api/v1/text/generation"; pub const CONTEXT_DOC_DETECTOR_ENDPOINT: &str = "/api/v1/text/context/doc"; +pub const CHAT_DETECTOR_ENDPOINT: &str = "/api/v1/text/chat"; diff --git a/tests/common/orchestrator.rs b/tests/common/orchestrator.rs index 66689a6d..c0cbddcf 100644 --- a/tests/common/orchestrator.rs +++ b/tests/common/orchestrator.rs @@ -43,6 +43,7 @@ pub const ORCHESTRATOR_STREAMING_ENDPOINT: &str = pub const ORCHESTRATOR_CONTENT_DETECTION_ENDPOINT: &str = "/api/v2/text/detection/content"; pub const ORCHESTRATOR_DETECTION_ON_GENERATION_ENDPOINT: &str = "/api/v2/text/detection/generated"; pub const ORCHESTRATOR_CONTEXT_DOCS_DETECTION_ENDPOINT: &str = "/api/v2/text/detection/context"; +pub const ORCHESTRATOR_CHAT_DETECTION_ENDPOINT: &str = "/api/v2/text/detection/chat"; // Messages pub const ORCHESTRATOR_INTERNAL_SERVER_ERROR_MESSAGE: &str = diff --git a/tests/test_config.yaml b/tests/test_config.yaml index ac2d1354..468ebf37 100644 --- a/tests/test_config.yaml +++ b/tests/test_config.yaml @@ -45,3 +45,9 @@ detectors: hostname: localhost chunker_id: whole_doc_chunker default_threshold: 0.5 + pii_detector: + type: text_chat + service: + hostname: localhost + chunker_id: whole_doc_chunker + default_threshold: 0.5 From cdf0509b6f8f6505d7df67b7fc9d0caad0050049 Mon Sep 17 00:00:00 2001 From: Mateus Devino Date: Wed, 26 Feb 2025 14:28:21 -0300 Subject: [PATCH 107/117] test case: chat_detection::test_detection_above_default_threshold_is_returned() Signed-off-by: Mateus Devino --- tests/chat_detection.rs | 73 +++++++++++++++++++++++++++++++++++++++++ 1 file changed, 73 insertions(+) diff --git a/tests/chat_detection.rs b/tests/chat_detection.rs index 2f7a0274..6cd778ed 100644 --- a/tests/chat_detection.rs +++ b/tests/chat_detection.rs @@ -105,3 +105,76 @@ async fn test_detection_below_default_threshold_is_not_returned() -> Result<(), Ok(()) } + +#[test(tokio::test)] +async fn test_detection_above_default_threshold_is_returned() -> Result<(), anyhow::Error> { + let detector_name = PII_DETECTOR; + let messages = vec![ + Message { + role: Role::User, + content: Some(Content::Text("What is his cellphone?".into())), + ..Default::default() + }, + Message { + role: Role::Assistant, + content: Some(Content::Text("It's +1 (123) 123-4567.".into())), + ..Default::default() + }, + ]; + let detection = DetectionResult { + detection_type: "pii".into(), + detection: "is_pii".into(), + detector_id: Some(detector_name.into()), + score: 0.97, + evidence: None, + }; + + // Add detector mock + let mut mocks = MockSet::new(); + mocks.insert( + MockPath::new(Method::POST, CHAT_DETECTOR_ENDPOINT), + Mock::new( + MockRequest::json(ChatDetectionRequest { + messages: messages.clone(), + detector_params: DetectorParams::new(), + }), + 
MockResponse::json(vec![detection.clone()]), + ), + ); + + // Start orchestrator server and its dependencies + let mock_detector_server = HttpMockServer::new(detector_name, mocks)?; + let orchestrator_server = TestOrchestratorServer::run( + ORCHESTRATOR_CONFIG_FILE_PATH, + find_available_port().unwrap(), + find_available_port().unwrap(), + None, + None, + Some(vec![mock_detector_server]), + None, + ) + .await?; + + // Make orchestrator call + let response = orchestrator_server + .post(ORCHESTRATOR_CHAT_DETECTION_ENDPOINT) + .json(&ChatDetectionHttpRequest { + detectors: HashMap::from([(detector_name.into(), DetectorParams::new())]), + messages, + }) + .send() + .await?; + + debug!("{response:#?}"); + + // assertions + assert!(response.status() == StatusCode::OK); + assert!( + response.json::().await? + == ChatDetectionResult { + detections: vec![detection] + } + ); + + Ok(()) +} From d189011fe7bbd9b9c4d70f6e4e22de742d7f7513 Mon Sep 17 00:00:00 2001 From: Mateus Devino Date: Wed, 26 Feb 2025 14:40:38 -0300 Subject: [PATCH 108/117] test case: chat_detection::test_detector_returns_503() Signed-off-by: Mateus Devino --- tests/chat_detection.rs | 75 +++++++++++++++++++++++++++++++++++++++++ 1 file changed, 75 insertions(+) diff --git a/tests/chat_detection.rs b/tests/chat_detection.rs index 6cd778ed..0a5c49d6 100644 --- a/tests/chat_detection.rs +++ b/tests/chat_detection.rs @@ -19,6 +19,7 @@ use test_log::test; use common::{ detectors::{CHAT_DETECTOR_ENDPOINT, PII_DETECTOR}, + errors::{DetectorError, OrchestratorError}, orchestrator::{ TestOrchestratorServer, ORCHESTRATOR_CHAT_DETECTION_ENDPOINT, ORCHESTRATOR_CONFIG_FILE_PATH, }, @@ -178,3 +179,77 @@ async fn test_detection_above_default_threshold_is_returned() -> Result<(), anyh Ok(()) } + +#[test(tokio::test)] +async fn test_detector_returns_503() -> Result<(), anyhow::Error> { + let detector_name = PII_DETECTOR; + let messages = vec![ + Message { + role: Role::User, + content: Some(Content::Text("Why is orchestrator returning 503?".into())), + ..Default::default() + }, + Message { + role: Role::Assistant, + content: Some(Content::Text("Because the detector returned 503.".into())), + ..Default::default() + }, + ]; + let detector_error = DetectorError { + code: 503, + message: "The detector is overloaded".into(), + }; + + // Add detector mock + let mut mocks = MockSet::new(); + mocks.insert( + MockPath::new(Method::POST, CHAT_DETECTOR_ENDPOINT), + Mock::new( + MockRequest::json(ChatDetectionRequest { + messages: messages.clone(), + detector_params: DetectorParams::new(), + }), + MockResponse::json(&detector_error).with_code(StatusCode::SERVICE_UNAVAILABLE), + ), + ); + + // Start orchestrator server and its dependencies + let mock_detector_server = HttpMockServer::new(detector_name, mocks)?; + let orchestrator_server = TestOrchestratorServer::run( + ORCHESTRATOR_CONFIG_FILE_PATH, + find_available_port().unwrap(), + find_available_port().unwrap(), + None, + None, + Some(vec![mock_detector_server]), + None, + ) + .await?; + + // Make orchestrator call + let response = orchestrator_server + .post(ORCHESTRATOR_CHAT_DETECTION_ENDPOINT) + .json(&ChatDetectionHttpRequest { + detectors: HashMap::from([(detector_name.into(), DetectorParams::new())]), + messages, + }) + .send() + .await?; + + debug!("{response:#?}"); + + // assertions + assert!(response.status() == StatusCode::SERVICE_UNAVAILABLE); + let response = response.json::().await?; + debug!("{response:#?}"); + assert!(response.code == 503); + assert!( + response.details + == format!( + 
"detector request failed for `{}`: {}", + detector_name, detector_error.message + ) + ); + + Ok(()) +} From 8e325d7e945eae43df0b4482fed704ff95c7570d Mon Sep 17 00:00:00 2001 From: Mateus Devino Date: Wed, 26 Feb 2025 14:42:17 -0300 Subject: [PATCH 109/117] test case: chat_detection::test_detector_returns_404() Signed-off-by: Mateus Devino --- tests/chat_detection.rs | 74 +++++++++++++++++++++++++++++++++++++++++ 1 file changed, 74 insertions(+) diff --git a/tests/chat_detection.rs b/tests/chat_detection.rs index 0a5c49d6..ad47d2e6 100644 --- a/tests/chat_detection.rs +++ b/tests/chat_detection.rs @@ -253,3 +253,77 @@ async fn test_detector_returns_503() -> Result<(), anyhow::Error> { Ok(()) } + +#[test(tokio::test)] +async fn test_detector_returns_404() -> Result<(), anyhow::Error> { + let detector_name = PII_DETECTOR; + let messages = vec![ + Message { + role: Role::User, + content: Some(Content::Text("Why is orchestrator returning 404?".into())), + ..Default::default() + }, + Message { + role: Role::Assistant, + content: Some(Content::Text("Because the detector returned 404.".into())), + ..Default::default() + }, + ]; + let detector_error = DetectorError { + code: 404, + message: "The detector is overloaded".into(), + }; + + // Add detector mock + let mut mocks = MockSet::new(); + mocks.insert( + MockPath::new(Method::POST, CHAT_DETECTOR_ENDPOINT), + Mock::new( + MockRequest::json(ChatDetectionRequest { + messages: messages.clone(), + detector_params: DetectorParams::new(), + }), + MockResponse::json(&detector_error).with_code(StatusCode::NOT_FOUND), + ), + ); + + // Start orchestrator server and its dependencies + let mock_detector_server = HttpMockServer::new(detector_name, mocks)?; + let orchestrator_server = TestOrchestratorServer::run( + ORCHESTRATOR_CONFIG_FILE_PATH, + find_available_port().unwrap(), + find_available_port().unwrap(), + None, + None, + Some(vec![mock_detector_server]), + None, + ) + .await?; + + // Make orchestrator call + let response = orchestrator_server + .post(ORCHESTRATOR_CHAT_DETECTION_ENDPOINT) + .json(&ChatDetectionHttpRequest { + detectors: HashMap::from([(detector_name.into(), DetectorParams::new())]), + messages, + }) + .send() + .await?; + + debug!("{response:#?}"); + + // assertions + assert!(response.status() == StatusCode::NOT_FOUND); + let response = response.json::().await?; + debug!("{response:#?}"); + assert!(response.code == 404); + assert!( + response.details + == format!( + "detector request failed for `{}`: {}", + detector_name, detector_error.message + ) + ); + + Ok(()) +} From c92c23d4bcc949827357d7deda6208d3711e2d7e Mon Sep 17 00:00:00 2001 From: Mateus Devino Date: Wed, 26 Feb 2025 14:44:47 -0300 Subject: [PATCH 110/117] test case: chat_detection::test_detector_returns_500() Signed-off-by: Mateus Devino --- tests/chat_detection.rs | 71 ++++++++++++++++++++++++++++++++++++++++- 1 file changed, 70 insertions(+), 1 deletion(-) diff --git a/tests/chat_detection.rs b/tests/chat_detection.rs index ad47d2e6..391e8228 100644 --- a/tests/chat_detection.rs +++ b/tests/chat_detection.rs @@ -21,7 +21,8 @@ use common::{ detectors::{CHAT_DETECTOR_ENDPOINT, PII_DETECTOR}, errors::{DetectorError, OrchestratorError}, orchestrator::{ - TestOrchestratorServer, ORCHESTRATOR_CHAT_DETECTION_ENDPOINT, ORCHESTRATOR_CONFIG_FILE_PATH, + TestOrchestratorServer, ORCHESTRATOR_CHAT_DETECTION_ENDPOINT, + ORCHESTRATOR_CONFIG_FILE_PATH, ORCHESTRATOR_INTERNAL_SERVER_ERROR_MESSAGE, }, }; use fms_guardrails_orchestr8::{ @@ -327,3 +328,71 @@ async fn 
test_detector_returns_404() -> Result<(), anyhow::Error> {
 
     Ok(())
 }
+
+#[test(tokio::test)]
+async fn test_detector_returns_500() -> Result<(), anyhow::Error> {
+    let detector_name = PII_DETECTOR;
+    let messages = vec![
+        Message {
+            role: Role::User,
+            content: Some(Content::Text("Why is orchestrator returning 500?".into())),
+            ..Default::default()
+        },
+        Message {
+            role: Role::Assistant,
+            content: Some(Content::Text("Because the detector returned 500.".into())),
+            ..Default::default()
+        },
+    ];
+    let detector_error = DetectorError {
+        code: 500,
+        message: "The detector is overloaded".into(),
+    };
+
+    // Add detector mock
+    let mut mocks = MockSet::new();
+    mocks.insert(
+        MockPath::new(Method::POST, CHAT_DETECTOR_ENDPOINT),
+        Mock::new(
+            MockRequest::json(ChatDetectionRequest {
+                messages: messages.clone(),
+                detector_params: DetectorParams::new(),
+            }),
+            MockResponse::json(&detector_error).with_code(StatusCode::INTERNAL_SERVER_ERROR),
+        ),
+    );
+
+    // Start orchestrator server and its dependencies
+    let mock_detector_server = HttpMockServer::new(detector_name, mocks)?;
+    let orchestrator_server = TestOrchestratorServer::run(
+        ORCHESTRATOR_CONFIG_FILE_PATH,
+        find_available_port().unwrap(),
+        find_available_port().unwrap(),
+        None,
+        None,
+        Some(vec![mock_detector_server]),
+        None,
+    )
+    .await?;
+
+    // Make orchestrator call
+    let response = orchestrator_server
+        .post(ORCHESTRATOR_CHAT_DETECTION_ENDPOINT)
+        .json(&ChatDetectionHttpRequest {
+            detectors: HashMap::from([(detector_name.into(), DetectorParams::new())]),
+            messages,
+        })
+        .send()
+        .await?;
+
+    debug!("{response:#?}");
+
+    // assertions
+    assert!(response.status() == StatusCode::INTERNAL_SERVER_ERROR);
+    let response = response.json::().await?;
+    debug!("{response:#?}");
+    assert!(response.code == 500);
+    assert!(response.details == ORCHESTRATOR_INTERNAL_SERVER_ERROR_MESSAGE);
+
+    Ok(())
+}

From d0d71d5e1816bdb3d0c94b455d5f21e32937d488 Mon Sep 17 00:00:00 2001
From: Mateus Devino
Date: Wed, 26 Feb 2025 14:50:14 -0300
Subject: [PATCH 111/117] Change detector messages

Signed-off-by: Mateus Devino
---
 tests/chat_detection.rs | 74 +++++++++++++++++++++++++++++++++++++++--
 1 file changed, 72 insertions(+), 2 deletions(-)

diff --git a/tests/chat_detection.rs b/tests/chat_detection.rs
index 391e8228..73c87f69 100644
--- a/tests/chat_detection.rs
+++ b/tests/chat_detection.rs
@@ -14,6 +14,7 @@ limitations under the License.
 
 */
 
+use serde_json::json;
 use std::collections::HashMap;
 
 use test_log::test;
@@ -272,7 +273,7 @@ async fn test_detector_returns_404() -> Result<(), anyhow::Error> {
     ];
     let detector_error = DetectorError {
         code: 404,
-        message: "The detector is overloaded".into(),
+        message: "The detector was not found".into(),
     };
 
     // Add detector mock
@@ -346,7 +347,7 @@ async fn test_detector_returns_500() -> Result<(), anyhow::Error> {
     ];
     let detector_error = DetectorError {
         code: 500,
-        message: "The detector is overloaded".into(),
+        message: "The detector had an error".into(),
     };
 
     // Add detector mock
@@ -396,3 +397,72 @@ async fn test_detector_returns_500() -> Result<(), anyhow::Error> {
 
     Ok(())
 }
+
+#[test(tokio::test)]
+async fn test_detector_returns_invalid_message() -> Result<(), anyhow::Error> {
+    let detector_name = PII_DETECTOR;
+    let messages = vec![
+        Message {
+            role: Role::User,
+            content: Some(Content::Text("Why is orchestrator returning 500?".into())),
+            ..Default::default()
+        },
+        Message {
+            role: Role::Assistant,
+            content: Some(Content::Text(
+                "Because something went wrong. Sorry, I can't give more details.".into(),
+            )),
+            ..Default::default()
+        },
+    ];
+
+    // Add detector mock
+    let mut mocks = MockSet::new();
+    mocks.insert(
+        MockPath::new(Method::POST, CHAT_DETECTOR_ENDPOINT),
+        Mock::new(
+            MockRequest::json(ChatDetectionRequest {
+                messages: messages.clone(),
+                detector_params: DetectorParams::new(),
+            }),
+            MockResponse::json(&json!({
+                "message": "I won't comply with the detector API."
+            }))
+            .with_code(StatusCode::OK),
+        ),
+    );
+
+    // Start orchestrator server and its dependencies
+    let mock_detector_server = HttpMockServer::new(detector_name, mocks)?;
+    let orchestrator_server = TestOrchestratorServer::run(
+        ORCHESTRATOR_CONFIG_FILE_PATH,
+        find_available_port().unwrap(),
+        find_available_port().unwrap(),
+        None,
+        None,
+        Some(vec![mock_detector_server]),
+        None,
+    )
+    .await?;
+
+    // Make orchestrator call
+    let response = orchestrator_server
+        .post(ORCHESTRATOR_CHAT_DETECTION_ENDPOINT)
+        .json(&ChatDetectionHttpRequest {
+            detectors: HashMap::from([(detector_name.into(), DetectorParams::new())]),
+            messages,
+        })
+        .send()
+        .await?;
+
+    debug!("{response:#?}");
+
+    // assertions
+    assert!(response.status() == StatusCode::INTERNAL_SERVER_ERROR);
+    let response = response.json::().await?;
+    debug!("{response:#?}");
+    assert!(response.code == 500);
+    assert!(response.details == ORCHESTRATOR_INTERNAL_SERVER_ERROR_MESSAGE);
+
+    Ok(())
+}

From 08aa87242941165cd66539d14ded3695f7798aeb Mon Sep 17 00:00:00 2001
From: Mateus Devino
Date: Wed, 26 Feb 2025 14:59:04 -0300
Subject: [PATCH 112/117] test case: chat_detection::test_request_contains_extra_fields()

Signed-off-by: Mateus Devino
---
 tests/chat_detection.rs | 49 +++++++++++++++++++++++++++++++++++++++++
 1 file changed, 49 insertions(+)

diff --git a/tests/chat_detection.rs b/tests/chat_detection.rs
index 73c87f69..b8aad9e3 100644
--- a/tests/chat_detection.rs
+++ b/tests/chat_detection.rs
@@ -466,3 +466,52 @@ async fn test_detector_returns_invalid_message() -> Result<(), anyhow::Error> {
 
     Ok(())
 }
+
+#[test(tokio::test)]
+async fn test_request_contains_extra_fields() -> Result<(), anyhow::Error> {
+    let detector_name = PII_DETECTOR;
+
+    // Start orchestrator server and its dependencies
+    let mock_detector_server = HttpMockServer::new(detector_name, MockSet::new())?;
+    let orchestrator_server = TestOrchestratorServer::run(
+        ORCHESTRATOR_CONFIG_FILE_PATH,
+        find_available_port().unwrap(),
+        find_available_port().unwrap(),
+        None,
+        None,
+        Some(vec![mock_detector_server]),
+        None,
+    )
+    .await?;
+
+    // Make orchestrator call
+    let response = orchestrator_server
+        .post(ORCHESTRATOR_CHAT_DETECTION_ENDPOINT)
+        .json(&json!({
+            "detectors": {detector_name: {}},
+            "messages": [
+                {
+                    "content": "What is this test asserting?",
+                    "role": "user",
+                },
+                {
+                    "content": "It's making sure requests with extra fields are not accepted.",
+                    "role": "assistant",
+                }
+            ],
+            "extra_args": true
+        }))
+        .send()
+        .await?;
+
+    debug!("{response:#?}");
+
+    // assertions
+    assert!(response.status() == StatusCode::UNPROCESSABLE_ENTITY);
+    let response = response.json::().await?;
+    debug!("{response:#?}");
+    assert!(response.code == 422);
+    assert!(response.details.contains("unknown field `extra_args`"));
+
+    Ok(())
+}

From 8c5e31807fb46bee9ea5bd6669f75c280a122815 Mon Sep 17 00:00:00 2001
From: Mateus Devino
Date: Wed, 26 Feb 2025 15:06:06 -0300
Subject: [PATCH 113/117] test case: chat_detection::test_request_missing_messages()

Signed-off-by: Mateus Devino
---
 tests/chat_detection.rs | 38 ++++++++++++++++++++++++++++++++++++++
 1 file changed, 38 insertions(+)

diff --git a/tests/chat_detection.rs b/tests/chat_detection.rs
index b8aad9e3..7f932c8f 100644
--- a/tests/chat_detection.rs
+++ b/tests/chat_detection.rs
@@ -515,3 +515,41 @@ async fn test_request_contains_extra_fields() -> Result<(), anyhow::Error> {
 
     Ok(())
 }
+
+#[test(tokio::test)]
+async fn test_request_missing_messages() -> Result<(), anyhow::Error> {
+    let detector_name = PII_DETECTOR;
+
+    // Start orchestrator server and its dependencies
+    let mock_detector_server = HttpMockServer::new(detector_name, MockSet::new())?;
+    let orchestrator_server = TestOrchestratorServer::run(
+        ORCHESTRATOR_CONFIG_FILE_PATH,
+        find_available_port().unwrap(),
+        find_available_port().unwrap(),
+        None,
+        None,
+        Some(vec![mock_detector_server]),
+        None,
+    )
+    .await?;
+
+    // Make orchestrator call
+    let response = orchestrator_server
+        .post(ORCHESTRATOR_CHAT_DETECTION_ENDPOINT)
+        .json(&json!({
+            "detectors": {detector_name: {}}
+        }))
+        .send()
+        .await?;
+
+    debug!("{response:#?}");
+
+    // assertions
+    assert!(response.status() == StatusCode::UNPROCESSABLE_ENTITY);
+    let response = response.json::().await?;
+    debug!("{response:#?}");
+    assert!(response.code == 422);
+    assert!(response.details.contains("missing field `messages`"));
+
+    Ok(())
+}

From 8af0a67c47be17dbf19b25f539b2daf32b9bc7b4 Mon Sep 17 00:00:00 2001
From: Mateus Devino
Date: Wed, 26 Feb 2025 15:11:24 -0300
Subject: [PATCH 114/117] test case: chat_detection::test_request_missing_detector()

Signed-off-by: Mateus Devino
---
 tests/chat_detection.rs | 47 +++++++++++++++++++++++++++++++++++++++++
 1 file changed, 47 insertions(+)

diff --git a/tests/chat_detection.rs b/tests/chat_detection.rs
index 7f932c8f..d9e9c89d 100644
--- a/tests/chat_detection.rs
+++ b/tests/chat_detection.rs
@@ -553,3 +553,50 @@ async fn test_request_missing_messages() -> Result<(), anyhow::Error> {
 
     Ok(())
 }
+
+#[test(tokio::test)]
+async fn test_request_missing_detector() -> Result<(), anyhow::Error> {
+    let detector_name = PII_DETECTOR;
+
+    // Start orchestrator server and its dependencies
+    let mock_detector_server = HttpMockServer::new(detector_name, MockSet::new())?;
+    let orchestrator_server = TestOrchestratorServer::run(
+        ORCHESTRATOR_CONFIG_FILE_PATH,
+        find_available_port().unwrap(),
+        find_available_port().unwrap(),
+        None,
+        None,
+        Some(vec![mock_detector_server]),
+        None,
+    )
+    .await?;
+
+    // Make orchestrator call
+    let response = orchestrator_server
+        .post(ORCHESTRATOR_CHAT_DETECTION_ENDPOINT)
+        .json(&json!({
+            "messages": [
+                {
+                    "content": "What is this test asserting?",
+                    "role": "user",
+                },
+                {
+                    "content": "It's making sure requests with extra fields are not accepted.",
+                    "role": "assistant",
+                }
+            ],
+        }))
+        .send()
+        .await?;
+
+    debug!("{response:#?}");
+
+    // assertions
+    assert!(response.status() == StatusCode::UNPROCESSABLE_ENTITY);
+    let response = response.json::().await?;
+    debug!("{response:#?}");
+    assert!(response.code == 422);
+    assert!(response.details.contains("missing field `detectors`"));
+
+    Ok(())
+}

From 0978f87624d69d4fbdc13711f444784fee7b9b6c Mon Sep 17 00:00:00 2001
From: Mateus Devino
Date: Wed, 26 Feb 2025 15:13:53 -0300
Subject: [PATCH 115/117] test case: chat_detection::test_request_with_invalid_detector()

Signed-off-by: Mateus Devino
---
 tests/chat_detection.rs | 48 +++++++++++++++++++++++++++++++++++++++++
 1 file changed, 48 insertions(+)

diff --git a/tests/chat_detection.rs b/tests/chat_detection.rs
index d9e9c89d..0e2019d6 100644
--- a/tests/chat_detection.rs
+++ b/tests/chat_detection.rs
@@ -600,3 +600,51 @@ async fn test_request_missing_detector() -> Result<(), anyhow::Error> {
 
     Ok(())
 }
+
+#[test(tokio::test)]
+async fn test_request_with_invalid_detector() -> Result<(), anyhow::Error> {
+    let detector_name = PII_DETECTOR;
+
+    // Start orchestrator server and its dependencies
+    let mock_detector_server = HttpMockServer::new(detector_name, MockSet::new())?;
+    let orchestrator_server = TestOrchestratorServer::run(
+        ORCHESTRATOR_CONFIG_FILE_PATH,
+        find_available_port().unwrap(),
+        find_available_port().unwrap(),
+        None,
+        None,
+        Some(vec![mock_detector_server]),
+        None,
+    )
+    .await?;
+
+    // Make orchestrator call
+    let response = orchestrator_server
+        .post(ORCHESTRATOR_CHAT_DETECTION_ENDPOINT)
+        .json(&json!({
+            "messages": [
+                {
+                    "content": "What is this test asserting?",
+                    "role": "user",
+                },
+                {
+                    "content": "It's making sure requests with extra fields are not accepted.",
+                    "role": "assistant",
+                }
+            ],
+            "detectors": {}
+        }))
+        .send()
+        .await?;
+
+    debug!("{response:#?}");
+
+    // assertions
+    assert!(response.status() == StatusCode::UNPROCESSABLE_ENTITY);
+    let response = response.json::().await?;
+    debug!("{response:#?}");
+    assert!(response.code == 422);
+    assert!(response.details.contains("`detectors` is required"));
+
+    Ok(())
+}

From d328a6d5fb27988d9e09e06fe55ca0319c3fa2f8 Mon Sep 17 00:00:00 2001
From: Mateus Devino
Date: Thu, 27 Feb 2025 17:36:34 -0300
Subject: [PATCH 116/117] Update mocktail to 0.1.2-alpha

Signed-off-by: Mateus Devino
---
 tests/chat_detection.rs | 12 ++++++------
 1 file changed, 6 insertions(+), 6 deletions(-)

diff --git a/tests/chat_detection.rs b/tests/chat_detection.rs
index 0e2019d6..0a286fc1 100644
--- a/tests/chat_detection.rs
+++ b/tests/chat_detection.rs
@@ -65,7 +65,7 @@ async fn test_detection_below_default_threshold_is_not_returned() -> Result<(),
     // Add detector mock
     let mut mocks = MockSet::new();
     mocks.insert(
-        MockPath::new(Method::POST, CHAT_DETECTOR_ENDPOINT),
+        MockPath::post(CHAT_DETECTOR_ENDPOINT),
         Mock::new(
             MockRequest::json(ChatDetectionRequest {
                 messages: messages.clone(),
@@ -135,7 +135,7 @@ async fn test_detection_above_default_threshold_is_returned() -> Result<(), anyh
     // Add detector mock
     let mut mocks = MockSet::new();
     mocks.insert(
-        MockPath::new(Method::POST, CHAT_DETECTOR_ENDPOINT),
+        MockPath::post(CHAT_DETECTOR_ENDPOINT),
         Mock::new(
             MockRequest::json(ChatDetectionRequest {
                 messages: messages.clone(),
@@ -205,7 +205,7 @@ async fn test_detector_returns_503() -> Result<(), anyhow::Error> {
     // Add detector mock
     let mut mocks = MockSet::new();
     mocks.insert(
-        MockPath::new(Method::POST, CHAT_DETECTOR_ENDPOINT),
+        MockPath::post(CHAT_DETECTOR_ENDPOINT),
         Mock::new(
             MockRequest::json(ChatDetectionRequest {
                 messages: messages.clone(),
@@ -279,7 +279,7 @@ async fn test_detector_returns_404() -> Result<(), anyhow::Error> {
     // Add detector mock
     let mut mocks = MockSet::new();
     mocks.insert(
-        MockPath::new(Method::POST, CHAT_DETECTOR_ENDPOINT),
+        MockPath::post(CHAT_DETECTOR_ENDPOINT),
         Mock::new(
             MockRequest::json(ChatDetectionRequest {
                 messages: messages.clone(),
@@ -353,7 +353,7 @@ async fn test_detector_returns_500() -> Result<(), anyhow::Error> {
     // Add detector mock
     let mut mocks = MockSet::new();
     mocks.insert(
-        MockPath::new(Method::POST, CHAT_DETECTOR_ENDPOINT),
+        MockPath::post(CHAT_DETECTOR_ENDPOINT),
         Mock::new(
             MockRequest::json(ChatDetectionRequest {
                 messages: messages.clone(),
@@ -419,7 +419,7 @@ async fn test_detector_returns_invalid_message() -> Result<(), anyhow::Error> {
     // Add detector mock
     let mut mocks = MockSet::new();
     mocks.insert(
-        MockPath::new(Method::POST, CHAT_DETECTOR_ENDPOINT),
+        MockPath::post(CHAT_DETECTOR_ENDPOINT),
         Mock::new(
             MockRequest::json(ChatDetectionRequest {
                 messages: messages.clone(),

From 20d2afc6ee477ff17e1db15ec8cf69dd1e7cbc13 Mon Sep 17 00:00:00 2001
From: Mateus Devino
Date: Mon, 3 Mar 2025 09:00:43 -0300
Subject: [PATCH 117/117] Add comments to chat_detection tests

Signed-off-by: Mateus Devino
---
 tests/chat_detection.rs | 10 ++++++++++
 1 file changed, 10 insertions(+)

diff --git a/tests/chat_detection.rs b/tests/chat_detection.rs
index 0a286fc1..659f3b58 100644
--- a/tests/chat_detection.rs
+++ b/tests/chat_detection.rs
@@ -39,6 +39,7 @@ use tracing::debug;
 
 pub mod common;
 
+/// Asserts detections below the default threshold are not returned.
 #[test(tokio::test)]
 async fn test_detection_below_default_threshold_is_not_returned() -> Result<(), anyhow::Error> {
     let detector_name = PII_DETECTOR;
@@ -109,6 +110,7 @@ async fn test_detection_below_default_threshold_is_not_returned() -> Result<(),
     Ok(())
 }
 
+/// Asserts detections above the default threshold are returned.
 #[test(tokio::test)]
 async fn test_detection_above_default_threshold_is_returned() -> Result<(), anyhow::Error> {
     let detector_name = PII_DETECTOR;
@@ -182,6 +184,7 @@ async fn test_detection_above_default_threshold_is_returned() -> Result<(), anyh
     Ok(())
 }
 
+/// Asserts error 503 from detectors is propagated.
 #[test(tokio::test)]
 async fn test_detector_returns_503() -> Result<(), anyhow::Error> {
     let detector_name = PII_DETECTOR;
@@ -256,6 +259,7 @@ async fn test_detector_returns_503() -> Result<(), anyhow::Error> {
     Ok(())
 }
 
+/// Asserts error 404 from detectors is propagated.
 #[test(tokio::test)]
 async fn test_detector_returns_404() -> Result<(), anyhow::Error> {
     let detector_name = PII_DETECTOR;
@@ -330,6 +334,7 @@ async fn test_detector_returns_404() -> Result<(), anyhow::Error> {
     Ok(())
 }
 
+/// Asserts error 500 from detectors is propagated.
 #[test(tokio::test)]
 async fn test_detector_returns_500() -> Result<(), anyhow::Error> {
     let detector_name = PII_DETECTOR;
@@ -398,6 +403,7 @@ async fn test_detector_returns_500() -> Result<(), anyhow::Error> {
     Ok(())
 }
 
+/// Asserts invalid response from detectors returns 500.
 #[test(tokio::test)]
 async fn test_detector_returns_invalid_message() -> Result<(), anyhow::Error> {
     let detector_name = PII_DETECTOR;
@@ -467,6 +473,7 @@ async fn test_detector_returns_invalid_message() -> Result<(), anyhow::Error> {
     Ok(())
 }
 
+/// Asserts requests with extra fields return 422.
 #[test(tokio::test)]
 async fn test_request_contains_extra_fields() -> Result<(), anyhow::Error> {
     let detector_name = PII_DETECTOR;
@@ -516,6 +523,7 @@ async fn test_request_contains_extra_fields() -> Result<(), anyhow::Error> {
     Ok(())
 }
 
+/// Asserts requests missing `messages` return 422.
 #[test(tokio::test)]
 async fn test_request_missing_messages() -> Result<(), anyhow::Error> {
     let detector_name = PII_DETECTOR;
@@ -554,6 +562,7 @@ async fn test_request_missing_messages() -> Result<(), anyhow::Error> {
     Ok(())
 }
 
+/// Asserts requests missing `detectors` return 422.
 #[test(tokio::test)]
 async fn test_request_missing_detector() -> Result<(), anyhow::Error> {
     let detector_name = PII_DETECTOR;
@@ -601,6 +610,7 @@ async fn test_request_missing_detector() -> Result<(), anyhow::Error> {
     Ok(())
 }
 
+/// Asserts requests with empty `detectors` return 422.
 #[test(tokio::test)]
 async fn test_request_with_invalid_detector() -> Result<(), anyhow::Error> {
     let detector_name = PII_DETECTOR;