diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index e9ec4d2..73cc4ac 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -17,12 +17,14 @@ jobs: uses: packetcoders/action-setup-cache-python-poetry@main with: python-version: '3.9' - install-args: --with dev --with backends + install-args: --with dev --with torch - name: Install groco run: poetry install - name: Generate test coverage + env: + KERAS_BACKEND: torch run: | poetry run pytest --cov poetry run coverage xml diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 5ee96c0..b7e1508 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -21,7 +21,7 @@ jobs: uses: packetcoders/action-setup-cache-python-poetry@main with: python-version: ${{ matrix.python-version }} - install-args: --with dev --with backends + install-args: --with dev --with ${{ matrix.backend }} - name: Install groco run: poetry install diff --git a/pyproject.toml b/pyproject.toml index fcff417..fc2a260 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -21,13 +21,23 @@ isort = "^5.13.2" pre-commit = "^3.6.2" Sphinx = "^7.2.6" -[tool.poetry.group.backends] +[tool.poetry.group.tensorflow] optional = true -[tool.poetry.group.backends.dependencies] +[tool.poetry.group.pytorch] +optional = true + +[tool.poetry.group.jax] +optional = true + +[tool.poetry.group.tensorflow.dependencies] tensorflow = "^2.16.1" -jax = "^0.4.25" + +[tool.poetry.group.torch.dependencies] torch = "^2.2.1" + +[tool.poetry.group.jax.dependencies] +jax = "^0.4.25" jaxlib = "^0.4.25" [build-system]