From 97feba1db0a673bf63ac1a55eb4cfa4700b812a6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Naz=C4=B1m=20=C3=96nder=20Orhan?= Date: Mon, 18 Nov 2024 17:08:30 +0300 Subject: [PATCH] Readme tests removed. --- tests/scripts/test_scripts.py | 18 +----------------- 1 file changed, 1 insertion(+), 17 deletions(-) diff --git a/tests/scripts/test_scripts.py b/tests/scripts/test_scripts.py index bb928b5b..fedc1822 100644 --- a/tests/scripts/test_scripts.py +++ b/tests/scripts/test_scripts.py @@ -192,22 +192,6 @@ def test_composite_1_extend_from_inputs(): assert_results_equal(grads_1, grads_2) -def test_readme_model_3(): - import mithril as ml - - # Build a simple linear model - model = Linear(256) - # Generate a PyTorch backend with a (2,) device mesh - backend = ml.TorchBackend(device_mesh=(2, 1)) - # Compile the model - pm = ml.compile(model, backend, jit=False, data_keys={"input"}) - # Generate sharded data and parameters - params = {"w": backend.ones([128, 256]), "b": backend.ones([256])} - input = {"input": backend.ones(256, 128, device_mesh=(2, 1))} - # Run the compiled model - output = pm.evaluate(params, input) # noqa - - def test_primitive_model_with_context(): model = Buffer() context = TrainModel(model) @@ -7175,4 +7159,4 @@ def __call__( # type: ignore[override] trainable_keys = {"input": input} outputs = pm.evaluate(trainable_keys) ref_outputs = {"output": backend.ones(7) * 6} - assert_results_equal(outputs, ref_outputs) \ No newline at end of file + assert_results_equal(outputs, ref_outputs)