diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index f87e8d50..53da68a1 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -76,6 +76,6 @@ jobs: uses: actions-rs/cargo@v1 with: command: test - args: --all --release --features persistent-moe,nlopt # ,blas,linfa/intel-mkl-static # disabled till linfa 0.8 + args: --all --release --features persistent-moe,nlopt,blas,linfa/intel-mkl-static diff --git a/Cargo.lock b/Cargo.lock index 71333366..0be81951 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -64,19 +64,20 @@ dependencies = [ [[package]] name = "anstyle-wincon" -version = "3.0.6" +version = "3.0.7" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2109dbce0e72be3ec00bed26e6a7479ca384ad226efdd66db8fa2e3a38c83125" +checksum = "ca3534e77181a9cc07539ad51f2141fe32f6c3ffd4df76db8ad92346b003ae4e" dependencies = [ "anstyle", + "once_cell", "windows-sys 0.59.0", ] [[package]] name = "anyhow" -version = "1.0.94" +version = "1.0.95" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c1fd03a028ef38ba2276dce7e33fcd6369c158a1bca17946c4b1b701891c1ff7" +checksum = "34ac096ce696dc2fcabef30516bb13c0a68a11d30131d3df6f04711467681b04" [[package]] name = "approx" @@ -165,9 +166,9 @@ dependencies = [ [[package]] name = "bitflags" -version = "2.6.0" +version = "2.8.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b048fb63fd8b5923fc5aa7b340d8e156aec7ec02f0c78fa8a6ddc2613f6f71de" +checksum = "8f68f53c83ab957f72c32642f3868eec03eb974d1fb82e453128456482613d36" dependencies = [ "serde", ] @@ -222,9 +223,9 @@ dependencies = [ [[package]] name = "cc" -version = "1.2.5" +version = "1.2.9" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c31a0499c1dc64f458ad13872de75c0eb7e3fdb0e67964610c914b034fc5956e" +checksum = "c8293772165d9345bdaaa39b45b2109591e63fe5e6fbc23c6ff930a048aa310b" dependencies = [ "shlex", ] @@ -270,9 +271,9 @@ dependencies = [ [[package]] name = "clap" -version = "4.5.23" +version = "4.5.26" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3135e7ec2ef7b10c6ed8950f0f792ed96ee093fa088608f1c76e569722700c84" +checksum = "a8eb5e908ef3a6efbe1ed62520fb7287959888c88485abe072543190ecc66783" dependencies = [ "clap_builder", "clap_derive", @@ -280,9 +281,9 @@ dependencies = [ [[package]] name = "clap_builder" -version = "4.5.23" +version = "4.5.26" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "30582fc632330df2bd26877bde0c1f4470d57c582bbc070376afcd04d8cb4838" +checksum = "96b01801b5fc6a0a232407abc821660c9c6d25a1cafc0d4f85f29fb8d9afc121" dependencies = [ "anstream", "anstyle", @@ -292,14 +293,14 @@ dependencies = [ [[package]] name = "clap_derive" -version = "4.5.18" +version = "4.5.24" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4ac6a0c7b1a9e9a5186361f67dfa1b88213572f427fb9ab038efb2bd8c582dab" +checksum = "54b755194d6389280185988721fffba69495eed5ee9feeee9a599b53db80318c" dependencies = [ "heck", "proc-macro2", "quote", - "syn 2.0.90", + "syn 2.0.96", ] [[package]] @@ -539,7 +540,7 @@ checksum = "0d6ef0072f8a535281e4876be788938b528e9a1d43900b82c2569af7da799125" [[package]] name = "egobox" -version = "0.25.0" +version = "0.25.1" dependencies = [ "argmin_testfunctions", "ctrlc", @@ -562,7 +563,7 @@ dependencies = [ [[package]] name = "egobox-doe" -version = "0.25.0" +version = "0.25.1" dependencies = [ "approx", "criterion", @@ -578,7 +579,7 @@ dependencies = [ [[package]] name = "egobox-ego" -version = "0.25.0" +version = "0.25.1" dependencies = [ "anyhow", "approx", @@ -619,7 +620,7 @@ dependencies = [ [[package]] name = "egobox-gp" -version = "0.25.0" +version = "0.25.1" dependencies = [ "approx", "argmin_testfunctions", @@ -650,7 +651,7 @@ dependencies = [ [[package]] name = "egobox-moe" -version = "0.25.0" +version = "0.25.1" dependencies = [ "approx", "argmin_testfunctions", @@ -688,9 +689,9 @@ checksum = "60b1af1c220855b6ceac025d3f6ecdd2b7c4894bfe9cd9bda4fbb4bc7c0d4cf0" [[package]] name = "env_filter" -version = "0.1.2" +version = "0.1.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4f2c92ceda6ceec50f43169f9ee8424fe2db276791afde7b2cd8bc084cb376ab" +checksum = "186e05a59d4c50738528153b83b0b0194d3a29507dfec16eccd4b342903397d0" dependencies = [ "log", "regex", @@ -698,9 +699,9 @@ dependencies = [ [[package]] name = "env_logger" -version = "0.11.5" +version = "0.11.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e13fa619b91fb2381732789fc5de83b45675e882f66623b7d8cb4f643017018d" +checksum = "dcaee3d8e3cfc3fd92428d477bc97fc29ec8716d180c0d74c643bb26166660e0" dependencies = [ "anstream", "anstyle", @@ -897,9 +898,12 @@ dependencies = [ [[package]] name = "inventory" -version = "0.3.15" +version = "0.3.17" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f958d3d68f4167080a18141e10381e7634563984a537f2a49a30fd8e53ac5767" +checksum = "3b31349d02fe60f80bbbab1a9402364cad7460626d6030494b08ac4a2075bf81" +dependencies = [ + "rustversion", +] [[package]] name = "is-terminal" @@ -935,51 +939,52 @@ checksum = "d75a2a4b1b190afb6f5425f10f6a8f959d2ea0b9c2b1d79553551850539e4674" [[package]] name = "js-sys" -version = "0.3.76" +version = "0.3.77" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6717b6b5b077764fb5966237269cb3c64edddde4b14ce42647430a78ced9e7b7" +checksum = "1cfaf33c695fc6e08064efbc1f72ec937429614f25eef83af942d0e227c3a28f" dependencies = [ "once_cell", "wasm-bindgen", ] [[package]] -name = "kdtree" -version = "0.6.0" +name = "katexit" +version = "0.1.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "80ee359328fc9087e9e3fc0a4567c4dd27ec69a127d6a70e8d9dd22845b8b1a2" +checksum = "eb1304c448ce2c207c2298a34bc476ce7ae47f63c23fa2b498583b26be9bc88c" dependencies = [ - "num-traits", + "proc-macro2", + "quote", + "syn 1.0.109", ] [[package]] -name = "lapack" -version = "0.18.0" +name = "kdtree" +version = "0.6.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6a7f0050af10913bc5e4f6091df38870cb5d4e45094d3dd551c1aff9d1d59b26" +checksum = "80ee359328fc9087e9e3fc0a4567c4dd27ec69a127d6a70e8d9dd22845b8b1a2" dependencies = [ - "lapack-sys", - "libc", - "num-complex", + "num-traits", ] [[package]] name = "lapack-sys" -version = "0.12.1" +version = "0.14.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c1d3a8a9f07310243de6c6226f039f14bce8d2f4c96b5d30ddbcfa31eb4e94ad" +checksum = "447f56c85fb410a7a3d36701b2153c1018b1d2b908c5fbaf01c1b04fac33bcbe" dependencies = [ "libc", ] [[package]] name = "lax" -version = "0.15.0" +version = "0.16.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ccd3ec1cacffe7a44aee66f9e85d87e3ac69b472b546449884bd1fc1ca8ab359" +checksum = "1f96a229d9557112e574164f8024ce703625ad9f88a90964c1780809358e53da" dependencies = [ "cauchy", - "lapack", + "katexit", + "lapack-sys", "num-traits", "thiserror 1.0.69", ] @@ -1014,9 +1019,9 @@ dependencies = [ [[package]] name = "linfa" -version = "0.7.0" +version = "0.7.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1cab423110bc374e4cfa915da88952e2c6a4a5a6300ac0a0e68022bff2ace0b3" +checksum = "56f9097edc7c89d03d526efbacf6d90914e3a8fa53bd56c2d1489e3a90819370" dependencies = [ "approx", "ndarray", @@ -1030,9 +1035,9 @@ dependencies = [ [[package]] name = "linfa-clustering" -version = "0.7.0" +version = "0.7.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6aef28a433e8cf5301a878b8bb36d5de6d6ff9895245b358dc6039bf9cf4438b" +checksum = "be0bc52d5e4da397609cd0e6007efc6bd278158d1803673bd936c374f27513c5" dependencies = [ "linfa", "linfa-linalg", @@ -1061,9 +1066,9 @@ dependencies = [ [[package]] name = "linfa-nn" -version = "0.7.0" +version = "0.7.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c1945b1a1435e885b4743dad94f9676c0bf65d6cce102e96c753f93b4cfe68d0" +checksum = "b2fc281870379428baa56165c49ea6975c4db7fcf34d5a469738b001d7ade30b" dependencies = [ "kdtree", "linfa", @@ -1078,9 +1083,9 @@ dependencies = [ [[package]] name = "linfa-pls" -version = "0.7.0" +version = "0.7.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4bd9bc190667ad5c4b6a529e3fec5b47979dcd9a58f935a5c3d47ad099deb3b0" +checksum = "a3a706c8fa8952cbfb39e133417dc4c3a7e64acdc9430d605fa7c2398f1a1c48" dependencies = [ "linfa", "linfa-linalg", @@ -1105,9 +1110,9 @@ dependencies = [ [[package]] name = "log" -version = "0.4.22" +version = "0.4.25" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a7a70ba024b9dc04c27ea2f0c0548feb474ec5c54bba33a7f72f873a39d07b24" +checksum = "04cbf5b083de1c7e0222a7a51dbfdba1cbe1c6ab0b15e29fff3f6c077fd9cd9f" [[package]] name = "matrixmultiply" @@ -1136,9 +1141,9 @@ dependencies = [ [[package]] name = "miniz_oxide" -version = "0.8.2" +version = "0.8.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4ffbe83022cedc1d264172192511ae958937694cd57ce297164951b8b3568394" +checksum = "b8402cab7aefae129c6977bb0ff1b8fd9a04eb5b51efc50a70bea51cda0c7924" dependencies = [ "adler2", ] @@ -1175,11 +1180,12 @@ dependencies = [ [[package]] name = "ndarray-linalg" -version = "0.15.0" +version = "0.16.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8f87ff36428f228c6056204d0f5cb8c5165f0db0a065429faace3edbc2718f1f" +checksum = "0b0e8dda0c941b64a85c5deb2b3e0144aca87aced64678adfc23eacea6d2cc42" dependencies = [ "cauchy", + "katexit", "lax", "ndarray", "num-complex", @@ -1421,7 +1427,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "8b7cafe60d6cf8e62e1b9b2ea516a089c008945bb5a275416789e7db0bc199dc" dependencies = [ "memchr", - "thiserror 2.0.8", + "thiserror 2.0.11", "ucd-trie", ] @@ -1445,7 +1451,7 @@ dependencies = [ "pest_meta", "proc-macro2", "quote", - "syn 2.0.90", + "syn 2.0.96", ] [[package]] @@ -1461,9 +1467,9 @@ dependencies = [ [[package]] name = "pin-project-lite" -version = "0.2.15" +version = "0.2.16" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "915a1e146535de9163f3987b8944ed8cf49a18bb0056bcebcdcece385cece4ff" +checksum = "3b3cff922bd51709b605d9ead9aa71031d81447142d828eb4a6eba76fe619f9b" [[package]] name = "pin-utils" @@ -1522,9 +1528,9 @@ dependencies = [ [[package]] name = "proc-macro2" -version = "1.0.92" +version = "1.0.93" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "37d3544b3f2748c54e147655edb5025752e2303145b5aefb3c3ea2c78b973bb0" +checksum = "60946a68e5f9d28b0dc1c21bb8a97ee7d018a8b322fa57838ba31cc878e22d99" dependencies = [ "unicode-ident", ] @@ -1600,7 +1606,7 @@ dependencies = [ "proc-macro2", "pyo3-macros-backend", "quote", - "syn 2.0.90", + "syn 2.0.96", ] [[package]] @@ -1613,14 +1619,14 @@ dependencies = [ "proc-macro2", "pyo3-build-config", "quote", - "syn 2.0.90", + "syn 2.0.96", ] [[package]] name = "quote" -version = "1.0.37" +version = "1.0.38" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b5b9d34b8991d19d98081b46eacdd8eb58c6f2b201139f7c5f643cc155a633af" +checksum = "0e4dccaaaf89514f546c693ddc140f729f958c247918a13380cccc6078391acc" dependencies = [ "proc-macro2", ] @@ -1758,9 +1764,9 @@ checksum = "08d43f7aa6b08d49f382cde6a7982047c3426db949b1424bc4b7ec9ae12c6ce2" [[package]] name = "rustversion" -version = "1.0.18" +version = "1.0.19" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0e819f2bc632f285be6d7cd36e25940d45b2391dd6d9b939e79de557f7014248" +checksum = "f7c45b9784283f1b2e7fb61b42047c2fd678ef0960d4f6f1eba131594cc369d4" [[package]] name = "ryu" @@ -1779,9 +1785,9 @@ dependencies = [ [[package]] name = "scc" -version = "2.2.6" +version = "2.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "94b13f8ea6177672c49d12ed964cca44836f59621981b04a3e26b87e675181de" +checksum = "28e1c91382686d21b5ac7959341fcb9780fa7c03773646995a87c950fa7be640" dependencies = [ "sdd", ] @@ -1800,29 +1806,29 @@ checksum = "478f121bb72bbf63c52c93011ea1791dca40140dfe13f8336c4c5ac952c33aa9" [[package]] name = "serde" -version = "1.0.216" +version = "1.0.217" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0b9781016e935a97e8beecf0c933758c97a5520d32930e460142b4cd80c6338e" +checksum = "02fc4265df13d6fa1d00ecff087228cc0a2b5f3c0e87e258d8b94a156e984c70" dependencies = [ "serde_derive", ] [[package]] name = "serde_derive" -version = "1.0.216" +version = "1.0.217" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "46f859dbbf73865c6627ed570e78961cd3ac92407a2d117204c49232485da55e" +checksum = "5a9bf7cf98d04a2b28aead066b7496853d4779c9cc183c440dbac457641e19a0" dependencies = [ "proc-macro2", "quote", - "syn 2.0.90", + "syn 2.0.96", ] [[package]] name = "serde_json" -version = "1.0.133" +version = "1.0.135" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c7fceb2473b9166b2294ef05efcb65a3db80803f0b03ef86a5fc88a2b85ee377" +checksum = "2b0d7ba2887406110130a978386c4e1befb98c674b4fba677954e4db976630d9" dependencies = [ "itoa", "memchr", @@ -1852,7 +1858,7 @@ checksum = "5d69265a08751de7844521fd15003ae0a888e035773ba05695c5c759a6f89eef" dependencies = [ "proc-macro2", "quote", - "syn 2.0.90", + "syn 2.0.96", ] [[package]] @@ -1935,9 +1941,9 @@ checksum = "e990cc6cb89a82d70fe722cd7811dbce48a72bbfaebd623e58f142b6db28428f" [[package]] name = "sprs" -version = "0.11.2" +version = "0.11.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "704ef26d974e8a452313ed629828cd9d4e4fa34667ca1ad9d6b1fffa43c6e166" +checksum = "88bab60b0a18fb9b3e0c26e92796b3c3a278bf5fa4880f5ad5cc3bdfb843d0b1" dependencies = [ "ndarray", "num-complex", @@ -1964,9 +1970,9 @@ dependencies = [ [[package]] name = "syn" -version = "2.0.90" +version = "2.0.96" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "919d3b74a5dd0ccd15aeb8f93e7006bd9e14c295087c9896a110f490752bcf31" +checksum = "d5d0adab1ae378d7f53bdebc67a39f1f151407ef230f0ce2883572f5d8985c80" dependencies = [ "proc-macro2", "quote", @@ -2007,11 +2013,11 @@ dependencies = [ [[package]] name = "thiserror" -version = "2.0.8" +version = "2.0.11" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "08f5383f3e0071702bf93ab5ee99b52d26936be9dedd9413067cbdcddcb6141a" +checksum = "d452f284b73e6d76dd36758a0c8684b1d5be31f92b89d07fd5822175732206fc" dependencies = [ - "thiserror-impl 2.0.8", + "thiserror-impl 2.0.11", ] [[package]] @@ -2022,18 +2028,18 @@ checksum = "4fee6c4efc90059e10f81e6d42c60a18f76588c3d74cb83a0b242a2b6c7504c1" dependencies = [ "proc-macro2", "quote", - "syn 2.0.90", + "syn 2.0.96", ] [[package]] name = "thiserror-impl" -version = "2.0.8" +version = "2.0.11" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f2f357fcec90b3caef6623a099691be676d033b40a058ac95d2a6ade6fa0c943" +checksum = "26afc1baea8a989337eeb52b6e72a039780ce45c3edfcc9c5b9d112feeb173c2" dependencies = [ "proc-macro2", "quote", - "syn 2.0.90", + "syn 2.0.96", ] [[package]] @@ -2101,9 +2107,9 @@ checksum = "42ff0bf0c66b8238c6f3b578df37d0b7848e55df8577b3f74f92a69acceeb825" [[package]] name = "typetag" -version = "0.2.18" +version = "0.2.19" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "52ba3b6e86ffe0054b2c44f2d86407388b933b16cb0a70eea3929420db1d9bbe" +checksum = "044fc3365ddd307c297fe0fe7b2e70588cdab4d0f62dc52055ca0d11b174cf0e" dependencies = [ "erased-serde", "inventory", @@ -2114,13 +2120,13 @@ dependencies = [ [[package]] name = "typetag-impl" -version = "0.2.18" +version = "0.2.19" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "70b20a22c42c8f1cd23ce5e34f165d4d37038f5b663ad20fb6adbdf029172483" +checksum = "d9d30226ac9cbd2d1ff775f74e8febdab985dab14fb14aa2582c29a92d5555dc" dependencies = [ "proc-macro2", "quote", - "syn 2.0.90", + "syn 2.0.96", ] [[package]] @@ -2171,34 +2177,35 @@ checksum = "9c8d87e72b64a3b4db28d11ce29237c246188f4f51057d65a7eab63b7987e423" [[package]] name = "wasm-bindgen" -version = "0.2.99" +version = "0.2.100" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a474f6281d1d70c17ae7aa6a613c87fce69a127e2624002df63dcb39d6cf6396" +checksum = "1edc8929d7499fc4e8f0be2262a241556cfc54a0bea223790e71446f2aab1ef5" dependencies = [ "cfg-if", "once_cell", + "rustversion", "wasm-bindgen-macro", ] [[package]] name = "wasm-bindgen-backend" -version = "0.2.99" +version = "0.2.100" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5f89bb38646b4f81674e8f5c3fb81b562be1fd936d84320f3264486418519c79" +checksum = "2f0a0651a5c2bc21487bde11ee802ccaf4c51935d0d3d42a6101f98161700bc6" dependencies = [ "bumpalo", "log", "proc-macro2", "quote", - "syn 2.0.90", + "syn 2.0.96", "wasm-bindgen-shared", ] [[package]] name = "wasm-bindgen-macro" -version = "0.2.99" +version = "0.2.100" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2cc6181fd9a7492eef6fef1f33961e3695e4579b9872a6f7c83aee556666d4fe" +checksum = "7fe63fc6d09ed3792bd0897b314f53de8e16568c2b3f7982f468c0bf9bd0b407" dependencies = [ "quote", "wasm-bindgen-macro-support", @@ -2206,28 +2213,31 @@ dependencies = [ [[package]] name = "wasm-bindgen-macro-support" -version = "0.2.99" +version = "0.2.100" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "30d7a95b763d3c45903ed6c81f156801839e5ee968bb07e534c44df0fcd330c2" +checksum = "8ae87ea40c9f689fc23f209965b6fb8a99ad69aeeb0231408be24920604395de" dependencies = [ "proc-macro2", "quote", - "syn 2.0.90", + "syn 2.0.96", "wasm-bindgen-backend", "wasm-bindgen-shared", ] [[package]] name = "wasm-bindgen-shared" -version = "0.2.99" +version = "0.2.100" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "943aab3fdaaa029a6e0271b35ea10b72b943135afe9bffca82384098ad0e06a6" +checksum = "1a05d73b933a847d6cccdda8f838a22ff101ad9bf93e33684f39c1f5f0eece3d" +dependencies = [ + "unicode-ident", +] [[package]] name = "web-sys" -version = "0.3.76" +version = "0.3.77" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "04dd7223427d52553d3702c004d3b2fe07c148165faa56313cb00211e31c12bc" +checksum = "33b6dd2ef9186f1f2072e409e99cd22a975331a6b3591b12c764e0e55c60d5d2" dependencies = [ "js-sys", "wasm-bindgen", @@ -2374,7 +2384,7 @@ checksum = "fa4f8080344d4671fb4e831a13ad1e68092748387dfc4f55e356242fae12ce3e" dependencies = [ "proc-macro2", "quote", - "syn 2.0.90", + "syn 2.0.96", ] [[package]] diff --git a/crates/ego/Cargo.toml b/crates/ego/Cargo.toml index f7095e99..b2c88f67 100644 --- a/crates/ego/Cargo.toml +++ b/crates/ego/Cargo.toml @@ -30,7 +30,7 @@ linfa-pls = { version = "0.7", default-features = false } linfa-linalg = { version = "0.1", default-features = false } ndarray.workspace = true -ndarray-linalg = { version = "0.15", optional = true } +ndarray-linalg = { version = "0.16", optional = true } ndarray-stats.workspace = true ndarray-rand.workspace = true ndarray-npy.workspace = true diff --git a/crates/gp/Cargo.toml b/crates/gp/Cargo.toml index f3777eb6..7d79b41e 100644 --- a/crates/gp/Cargo.toml +++ b/crates/gp/Cargo.toml @@ -36,7 +36,7 @@ thiserror.workspace = true log.workspace = true rayon.workspace = true -ndarray-linalg = { version = "0.15", optional = true } +ndarray-linalg = { version = "0.16", optional = true } ndarray_einsum_beta = "0.7" ndarray-npy.workspace = true diff --git a/crates/gp/src/algorithm.rs b/crates/gp/src/algorithm.rs index 991b8d14..84887da2 100644 --- a/crates/gp/src/algorithm.rs +++ b/crates/gp/src/algorithm.rs @@ -10,6 +10,11 @@ use linfa::prelude::{Dataset, DatasetBase, Fit, Float, PredictInplace}; #[cfg(not(feature = "blas"))] use linfa_linalg::{cholesky::*, eigh::*, qr::*, svd::*, triangular::*}; +#[cfg(feature = "blas")] +use log::warn; +#[cfg(feature = "blas")] +use ndarray_linalg::{cholesky::*, eigh::*, qr::*, svd::*, triangular::*}; + use linfa_pls::PlsRegression; use ndarray::{Array, Array1, Array2, ArrayBase, Axis, Data, Ix1, Ix2, Zip}; @@ -583,20 +588,22 @@ impl, Corr: CorrelationModel> GaussianProc x: &ArrayBase, Ix1>, ) -> Array1 { let x = &(x.to_owned().insert_axis(Axis(0))); - let xnorm = (x - &self.xtrain.mean) / &self.xtrain.std; + let xnorm = (x - &self.xt_norm.mean) / &self.xt_norm.std; - let dx = pairwise_differences(&xnorm, &self.xtrain.data); + let dx = pairwise_differences(&xnorm, &self.xt_norm.data); let sigma2 = self.inner_params.sigma2; let r_chol = &self.inner_params.r_chol.to_owned().with_lapack(); let r = self + .params .corr .value(&dx, &self.theta, &self.w_star) .with_lapack(); let dr = self + .params .corr - .jacobian(&xnorm.row(0), &self.xtrain.data, &self.theta, &self.w_star) + .jacobian(&xnorm.row(0), &self.xt_norm.data, &self.theta, &self.w_star) .with_lapack(); let rho1 = r_chol @@ -611,8 +618,8 @@ impl, Corr: CorrelationModel> GaussianProc let p2 = inv_kr.t().dot(&dr); - let f_x = self.mean.value(x).t().to_owned(); - let f_mean = self.mean.value(&self.xtrain.data).with_lapack(); + let f_x = self.params.mean.value(x).t().to_owned(); + let f_mean = self.params.mean.value(&self.xt_norm.data).with_lapack(); let rho2 = r_chol .solve_triangular(UPLO::Lower, Diag::NonUnit, &f_mean) @@ -641,16 +648,16 @@ impl, Corr: CorrelationModel> GaussianProc } }; - let df = self.mean.jacobian(&xnorm.row(0)).with_lapack(); + let df = self.params.mean.jacobian(&xnorm.row(0)).with_lapack(); let d_a = df.t().to_owned() - dr.t().dot(&inv_kf); // let p3 = d_a.dot(&d_mat).t(); let p4 = d_mat.t().dot(&d_a.t()); let two = F::cast(2.); - let prime_t = (-p2 + p4).without_lapack().mapv(|v| two * v).t().to_owned(); + let prime_t = (p4 - p2).without_lapack().mapv(|v| two * v); - let x_std = &self.xtrain.std; + let x_std = &self.xt_norm.std; let dvar = (prime_t / x_std).mapv(|v| v * sigma2); dvar.row(0).into_owned() } @@ -1590,8 +1597,8 @@ mod tests { fn assert_rel_or_abs_error(y_deriv: f64, fdiff: f64) { println!("analytic deriv = {y_deriv}, fdiff = {fdiff}"); - if fdiff.abs() < 6e-1 { - let atol = 6e-1; + if fdiff.abs() < 1. { + let atol = 1.; println!("Check absolute error: abs({y_deriv}) should be < {atol}"); assert_abs_diff_eq!(y_deriv, 0.0, epsilon = atol); // check absolute when close to zero } else { diff --git a/crates/gp/src/optimization.rs b/crates/gp/src/optimization.rs index 8e69981c..a4c942d9 100644 --- a/crates/gp/src/optimization.rs +++ b/crates/gp/src/optimization.rs @@ -3,11 +3,7 @@ use ndarray::{arr1, s}; use ndarray_rand::rand::{Rng, SeedableRng}; use rand_xoshiro::Xoshiro256Plus; -#[cfg(feature = "blas")] -use log::warn; use ndarray::{Array, Array1, Array2, Zip}; -#[cfg(feature = "blas")] -use ndarray_linalg::{cholesky::*, eigh::*, qr::*, svd::*, triangular::*}; use linfa::prelude::Float; diff --git a/crates/moe/Cargo.toml b/crates/moe/Cargo.toml index 4cfa35b7..8681284b 100644 --- a/crates/moe/Cargo.toml +++ b/crates/moe/Cargo.toml @@ -39,7 +39,7 @@ linfa-pls = { version = "0.7", default-features = false } linfa-linalg = { version = "0.1", default-features = false } ndarray.workspace = true -ndarray-linalg = { version = "0.15", optional = true } +ndarray-linalg = { version = "0.16", optional = true } ndarray-stats.workspace = true ndarray-rand.workspace = true ndarray-npy.workspace = true diff --git a/crates/moe/src/clustering.rs b/crates/moe/src/clustering.rs index 115481cc..1a3f0a53 100644 --- a/crates/moe/src/clustering.rs +++ b/crates/moe/src/clustering.rs @@ -1,7 +1,7 @@ #![allow(dead_code)] use crate::parameters::GpMixtureParams; use crate::types::*; -use log::{debug, info}; +use log::debug; // , info}; use linfa::dataset::{Dataset, DatasetView}; use linfa::traits::{Fit, Predict}; @@ -9,7 +9,7 @@ use linfa::Float; use linfa_clustering::GaussianMixtureModel; use ndarray::{concatenate, Array1, Array2, ArrayBase, Axis, Data, Ix1, Ix2, Zip}; use ndarray_rand::rand::Rng; -use std::ops::Sub; +// use std::ops::Sub; fn mean(list: &[f64]) -> f64 { let sum: f64 = Iterator::sum(list.iter()); @@ -93,7 +93,7 @@ pub fn find_best_number_of_clusters( let use_median = true; - info!( + debug!( "Find best nb of clusters (max={}, dataset size={}x{})", max_nb_clusters - 1, x.nrows(), @@ -102,15 +102,20 @@ pub fn find_best_number_of_clusters( // Find error for each cluster while i < max_nb_clusters && !stop { - debug!("Try {} cluster(s)", i + 1); + debug!("###############################Try {} cluster(s)", i + 1); let mut h_errors: Vec = Vec::new(); let mut s_errors: Vec = Vec::new(); let mut ok = true; // Say if this number of cluster is possible let n_clusters = i + 1; - + // let test_dir = "target/tests"; if ok { + // let xy = concatenate( + // Axis(1), + // &[x.view(), y.to_owned().insert_axis(Axis(1)).view()], + // ) + // .unwrap(); let xydata = Dataset::from( concatenate( Axis(1), @@ -126,7 +131,15 @@ pub fn find_best_number_of_clusters( if let Some(gmm) = maybe_gmm { // Cross Validation + // let data_clustering = gmm.predict(&xy); + // ndarray_npy::write_npy( + // format!("{test_dir}/clustering_{}.npy", i + 1), + // &data_clustering.mapv(|v| v as f64), + // ) + // .expect("clustering saved"); + for (train, valid) in dataset.fold(5).into_iter() { + debug!("X: {}", Array1::from_iter(valid.records().iter().cloned())); if let Ok(mixture) = GpMixtureParams::default() .n_clusters(n_clusters) .regression_spec(regression_spec) @@ -158,12 +171,16 @@ pub fn find_best_number_of_clusters( ok = false; // something wrong => early exit 1.0 } else { - let denom = actual.mapv(|x| x * x).sum().sqrt(); - if denom > 100. * f64::EPSILON { - pred.sub(actual).mapv(|x| x * x).sum().sqrt() / denom - } else { - pred.sub(actual).mapv(|x| x * x).sum().sqrt() - } + let denom = actual.mapv(|x| x.abs()).sum(); + // if denom > 100. * f64::EPSILON { + // (pred - actual).mapv(|x| x * x).sum().sqrt() / denom + // } else { + // (pred - actual).mapv(|x| x * x).sum().sqrt() + // } + debug!("Diff: {}", &pred.to_owned() - actual); + let err = (pred - actual).mapv(|x| x.abs()).sum() / denom; + debug!("Err = {}", err); + err } } else { ok = false; @@ -180,12 +197,13 @@ pub fn find_best_number_of_clusters( ok = false; // something wrong => early exit 1.0 } else { - let denom = actual.mapv(|x| x * x).sum().sqrt(); - if denom > 100. * f64::EPSILON { - pred.sub(actual).mapv(|x| x * x).sum().sqrt() / denom - } else { - pred.sub(actual).mapv(|x| x * x).sum().sqrt() - } + // let denom = actual.mapv(|x| x * x).sum().sqrt(); + // if denom > 100. * f64::EPSILON { + // (pred - actual).mapv(|x| x * x).sum().sqrt() / denom + // } else { + // (pred - actual).mapv(|x| x * x).sum().sqrt() + // } + (pred - actual).mapv(|x| x.abs()).sum() } } else { ok = false; @@ -212,6 +230,9 @@ pub fn find_best_number_of_clusters( debug!("Prediction with {} clusters fails", n_clusters); } + debug!("hard errors : {:?}", h_errors); + debug!("soft errors : {:?}", s_errors); + // Stock median errors median_err_s.push(median(&s_errors)); median_err_h.push(median(&h_errors)); @@ -272,7 +293,7 @@ pub fn find_best_number_of_clusters( if nb_clusters_ok.is_empty() { // Selection fails even with one cluster // possibly because some predicitions give inf or nan values - info!("Selection of best number of clusters fails. Default to 1 cluster with Smooth(None) recombination"); + debug!("Selection of best number of clusters fails. Default to 1 cluster with Smooth(None) recombination"); return (1, Recombination::Smooth(None)); } @@ -363,7 +384,7 @@ mod tests { use ndarray::{array, Array1, Array2, Axis, Zip}; #[cfg(feature = "blas")] use ndarray_linalg::Norm; - use ndarray_npy::write_npy; + //use ndarray_npy::write_npy; use ndarray_rand::rand::SeedableRng; use rand_xoshiro::Xoshiro256Plus; @@ -387,37 +408,45 @@ mod tests { #[test] fn test_find_best_cluster_nb_1d() { + // let env = env_logger::Env::new().filter_or("EGOBOX_LOG", "info"); + // let mut builder = env_logger::Builder::from_env(env); + // let builder = builder.target(env_logger::Target::Stdout); + // builder.try_init().ok(); + + //let test_dir = "target/tests"; let rng = Xoshiro256Plus::seed_from_u64(42); - let doe = Lhs::new(&array![[0., 1.]]).with_rng(rng); - //write_npy("doe.npy", &doe); + let doe = Lhs::new(&array![[0., 1.]]).with_rng(rng.clone()); let xtrain = doe.sample(50); - //write_npy("xtrain.npy", &xtrain); + // write_npy(format!("{test_dir}/xtrain.npy"), &xtrain).expect("xt save"); let ytrain = function_test_1d(&xtrain); - //write_npy("ytrain.npy", &ytrain); - let rng = Xoshiro256Plus::seed_from_u64(42); - let (nb_clusters, recombination) = find_best_number_of_clusters( + // write_npy(format!("{test_dir}/ytrain.npy"), &ytrain).expect("yt save"); + let (nb_clusters, _recombination) = find_best_number_of_clusters( &xtrain, &ytrain, - 5, + 3, None, RegressionSpec::ALL, CorrelationSpec::ALL, rng.clone(), ); - let moe = GpMixture::params() - .n_clusters(nb_clusters) - .recombination(recombination) - .with_rng(rng) - .fit(&Dataset::new(xtrain, ytrain)) - .unwrap(); - let obs = Array1::linspace(0., 1., 100).insert_axis(Axis(1)); - let preds = moe.predict(&obs).unwrap(); - - let test_dir = "target/tests"; - std::fs::create_dir_all(test_dir).ok(); - write_npy(format!("{test_dir}/best_obs.npy"), &obs).expect("obs save"); - write_npy(format!("{test_dir}/best_preds.npy"), &preds).expect("preds save"); assert_eq!(3, nb_clusters); + + println!("Optimal number of clusters = {nb_clusters}"); + + // for i in 1..=3 { + // let moe = GpMixture::params() + // .n_clusters(i) + // .recombination(recombination) + // .with_rng(rng.clone()) + // .fit(&Dataset::new(xtrain.clone(), ytrain.clone())) + // .unwrap(); + // let obs = Array1::linspace(0., 1., 100).insert_axis(Axis(1)); + // let preds = moe.predict(&obs).unwrap(); + + // std::fs::create_dir_all(test_dir).ok(); + // write_npy(format!("{test_dir}/best_obs.npy"), &obs).expect("obs save"); + // write_npy(format!("{test_dir}/best_preds_{i}.npy"), &preds).expect("preds save"); + // } } #[test]