Add support for mlflow #77

Merged
merged 330 commits into main from feat/mlflow
Jan 27, 2025
Commits (330)
6fff3fc
install py39 on cirun runner
leifdenby Jun 3, 2024
74b4a10
cleanup: boundary_mask, zarr-opening, utils
sadamov Jun 4, 2024
0a041d1
Merge remote-tracking branch 'origin/main' into feature_dataset_yaml
sadamov Jun 4, 2024
8054e9e
change ami image to gpu
leifdenby Jun 4, 2024
39fbf3a
Merge remote-tracking branch 'upstream/main' into maint/deps-in-pypro…
leifdenby Jun 4, 2024
97aeb2e
use cheaper gpu instance
leifdenby Jun 4, 2024
425123c
adapted tests for zarr-analysis data
sadamov Jun 4, 2024
4dcf671
Readme adapted for yaml zarr analysis workflow
sadamov Jun 4, 2024
6d384f0
samller bugfixes and improvements
sadamov Jun 4, 2024
12ff4f2
Added fixed data config file for testing on Danra
sadamov Jun 4, 2024
03f7769
reducing runtime of tests with smaller sample
sadamov Jun 4, 2024
26f069c
download danra data for test and example (streaming not possible)
sadamov Jun 6, 2024
1f1cbcc
bugfixes after real-life testcase
sadamov Jun 6, 2024
b369306
Merge remote-tracking branch 'origin/main' into feature_dataset_yaml
sadamov Jun 6, 2024
0cdc361
organize .zarr in /data
sadamov Jun 6, 2024
23ca7b3
cleanup
sadamov Jun 6, 2024
81422f1
linter
sadamov Jun 6, 2024
124541b
static dataset doesn't have time dim
sadamov Jun 7, 2024
6140fdb
making two complex functions more modular
sadamov Jun 7, 2024
db6a912
chunk dataset by time
sadamov Jun 8, 2024
1aaa8dc
create list first for performance
sadamov Jun 8, 2024
81856b2
converting to_array is very slow
sadamov Jun 8, 2024
b3da818
allow for forcings to not be normalized
sadamov Jun 8, 2024
7ee5398
allow non_normalized_vars to be null
sadamov Jun 8, 2024
4782103
fixed coastlines using new xy_extent function
sadamov Jun 8, 2024
e0ffc5b
Some projections return inverted axes (rotatedPole)
sadamov Jun 9, 2024
c1f43b7
Docstrings added
sadamov Jun 13, 2024
21fd929
wip
leifdenby Jun 26, 2024
c52f98e
npy mllam nearly done
leifdenby Jul 6, 2024
80f3639
minor adjustment
leifdenby Jul 7, 2024
048f8c6
Merge branch 'main' of https://github.com/mllam/neural-lam into maint…
leifdenby Jul 11, 2024
5aaa239
add pooch and tweak pip cicd testing
leifdenby Jul 11, 2024
66c3b03
combine cicd tests with caching
leifdenby Jul 11, 2024
8566b8f
linting
leifdenby Jul 11, 2024
29bd9e5
add pyg dep
leifdenby Jul 11, 2024
bc7f028
set cirun aws region to frankfurt
leifdenby Jul 11, 2024
2070166
adapt image
leifdenby Jul 11, 2024
e4e86e5
set image
leifdenby Jul 11, 2024
1fba8fe
try different image
leifdenby Jul 11, 2024
02b77cf
add pooch to cicd
leifdenby Jul 11, 2024
b481929
add pdm gpu test
leifdenby Jul 16, 2024
bcec472
start work on readme
leifdenby Jul 16, 2024
c5beec9
Merge branch 'maint/deps-in-pyproject-toml' into datastore
leifdenby Jul 16, 2024
e89facc
Merge branch 'main' into maint/refactor-as-package
leifdenby Jul 16, 2024
0b5687a
Merge branch 'main' of https://github.com/mllam/neural-lam into maint…
leifdenby Jul 16, 2024
095fdbc
turn meps testdata download into pytest fixture
leifdenby Jul 16, 2024
49e9bfe
adapt README for package
leifdenby Jul 16, 2024
12cc02b
remove pdm cicd test (will be in separate PR)
leifdenby Jul 16, 2024
b47f50b
remove pdm in gitignore
leifdenby Jul 16, 2024
90d99ca
remove pdm and pyproject files (will be sep PR)
leifdenby Jul 16, 2024
a91eaaa
add pyproject.toml from main
leifdenby Jul 16, 2024
5508cea
clean out tests
leifdenby Jul 16, 2024
5c623c3
fix linting
leifdenby Jul 16, 2024
08ec168
add cli entrypoints import test
leifdenby Jul 16, 2024
d9cf7ba
Merge branch 'maint/refactor-as-package' into datastore
leifdenby Jul 16, 2024
3954f04
tweak cicd pytest execution
leifdenby Jul 16, 2024
f99fdce
Merge branch 'maint/refactor-as-package' into datastore
leifdenby Jul 16, 2024
db9d96f
Update tests/test_mllam_dataset.py
leifdenby Jul 17, 2024
3c864b2
grid-shape ok
leifdenby Jul 17, 2024
1f54b0e
get_vars_names and units
leifdenby Jul 17, 2024
9b88160
get_vars_names and units 2
leifdenby Jul 17, 2024
a9fdad5
test for stats
leifdenby Jul 23, 2024
555154f
get_dataarray test
leifdenby Jul 24, 2024
8b8a77e
get_dataarray test
leifdenby Jul 24, 2024
41f11cd
boundary_mask
leifdenby Jul 24, 2024
a17de0f
get_xy
leifdenby Jul 24, 2024
0a38a7d
remove TrainingSample dataclass
leifdenby Jul 24, 2024
f65f6b5
test for WeatherDataset.__getitem__
leifdenby Jul 24, 2024
a35100e
test for graph creation
leifdenby Jul 24, 2024
cfb0618
more graph creation tests
leifdenby Jul 24, 2024
8698719
check for consistency of num features across splits
leifdenby Jul 24, 2024
3381404
test for single batch from mllam through model
leifdenby Jul 24, 2024
2a6796c
Add init files to expose classes in editable package
joeloskarsson Jul 24, 2024
8f4e0e0
Linting
joeloskarsson Jul 24, 2024
e657abb
working training_step with datastores!
Jul 25, 2024
effc99b
remove superfluous tests
Jul 25, 2024
a047026
fix for dataset length
Jul 25, 2024
d2c62ed
step length should be int
Jul 25, 2024
58f5d99
step length should be int
Jul 25, 2024
64d43a6
training working with mllam datastore!
Jul 25, 2024
07444f8
adapt neural_lam.train_model for datastores
Jul 25, 2024
d1b6fc1
fixes for npy
Jul 25, 2024
6fe19ac
npyfiles datastore complete
leifdenby Jul 26, 2024
fe65a4d
cleanup for datastore examples
leifdenby Jul 26, 2024
e533794
training on ohm with danra!
Jul 26, 2024
640ac05
use mllam-data-prep v0.2.0
Aug 5, 2024
0f16f13
remove py3.12 from pre-commit
Aug 5, 2024
724548e
cleanup
Aug 8, 2024
a1b2037
all tests passing!
Aug 12, 2024
e35958f
use mllam-data-prep v0.3.0
Aug 12, 2024
8b92318
delete requirements.txt
Aug 13, 2024
658836a
remove .DS_Store
Aug 13, 2024
421efed
use tmate in gpu pdm cicd
Aug 13, 2024
05f1e9f
remove requirements
Aug 13, 2024
3afe0e4
update pdm gpu cicd setup to pdm venv on nvme drive
Aug 13, 2024
f3d028b
don't try to use pdm venv in-project
Aug 13, 2024
2c35662
remove tmate
Aug 13, 2024
5f30255
update README with install instructions
Aug 14, 2024
b2b5631
changelog
Aug 14, 2024
c8ae829
update ci/cd badges to include gpu + gpu
Aug 14, 2024
e7cf2c0
Merge pull request #1 from mllam/package_inits
leifdenby Aug 14, 2024
0b72e9d
add pyproject-flake8 to precommit config
Aug 14, 2024
190d1de
use Flake8-pyproject instead
Aug 14, 2024
791af0a
update README
Aug 14, 2024
58fab84
Merge branch 'maint/deps-in-pyproject-toml' into feat/datastores
Aug 14, 2024
dbe2e6d
Merge branch 'maint/refactor-as-package' into maint/deps-in-pyproject…
Aug 14, 2024
eac6e35
Merge branch 'maint/deps-in-pyproject-toml' into feat/datastores
Aug 14, 2024
799d55e
linting fixes
Aug 14, 2024
57bbb81
train only 1 epoch in cicd and print to stdout
Aug 14, 2024
a955cee
log datastore config
Aug 14, 2024
0a79c74
cleanup doctrings
Aug 15, 2024
9f3c014
Merge branch 'maint/refactor-as-package' into datastore
leifdenby Aug 19, 2024
41364a8
Merge branch 'main' of https://github.com/mllam/neural-lam into maint…
leifdenby Aug 19, 2024
3422298
update changelog
leifdenby Aug 19, 2024
689ef69
move dev deps optional dependencies group
leifdenby Aug 20, 2024
9a0d538
update cicd tests to install dev deps
leifdenby Aug 20, 2024
bddfcaf
update readme with new dev deps group
leifdenby Aug 20, 2024
b96cfdc
quote the skip step the install readme
leifdenby Aug 20, 2024
2600dee
remove unused files
leifdenby Aug 20, 2024
65a8074
Merge branch 'feat/datastores' of https://github.com/leifdenby/neural…
leifdenby Aug 20, 2024
6adf6cc
revert to line length of 80
leifdenby Aug 20, 2024
46b37f8
revert docstring formatting changes
leifdenby Aug 20, 2024
3cd0f8b
pin numpy to <2.0.0
leifdenby Aug 20, 2024
826270a
Merge branch 'maint/deps-in-pyproject-toml' into feat/datastores
leifdenby Aug 20, 2024
4ba22ea
Merge branch 'main' into feat/datastores
leifdenby Aug 20, 2024
1f661c6
fix flake8 linting errors
leifdenby Aug 20, 2024
4838872
Update neural_lam/weather_dataset.py
leifdenby Sep 8, 2024
b59e7e5
Update neural_lam/datastore/multizarr/create_normalization_stats.py
leifdenby Sep 8, 2024
75b1fe7
Update neural_lam/datastore/npyfiles/store.py
leifdenby Sep 8, 2024
7e736cb
Update neural_lam/datastore/npyfiles/store.py
leifdenby Sep 8, 2024
613a7e2
Update neural_lam/datastore/npyfiles/store.py
leifdenby Sep 8, 2024
65e199b
Update tests/test_training.py
leifdenby Sep 8, 2024
4435e26
Update tests/test_datasets.py
leifdenby Sep 8, 2024
4693408
Update README.md
leifdenby Sep 8, 2024
2dfed2c
update README
leifdenby Sep 10, 2024
c3d033d
Merge branch 'main' of https://github.com/mllam/neural-lam into feat/…
leifdenby Sep 10, 2024
4a70268
Merge branch 'feat/datastores' of https://github.com/leifdenby/neural…
leifdenby Sep 10, 2024
66c663f
column_water -> open_water_fraction
leifdenby Sep 10, 2024
11a7978
fix linting
leifdenby Sep 10, 2024
a41c314
static data same for all splits
leifdenby Sep 10, 2024
6f1efd6
forcing_window_size from args
leifdenby Sep 10, 2024
bacb9ec
Update neural_lam/datastore/base.py
leifdenby Sep 10, 2024
4a9db4e
only use first ensemble member in datastores
leifdenby Sep 10, 2024
4fc2448
Merge branch 'feat/datastores' of https://github.com/leifdenby/neural…
leifdenby Sep 10, 2024
bcaa919
Update neural_lam/datastore/base.py
leifdenby Sep 10, 2024
90bc594
Update neural_lam/datastore/base.py
leifdenby Sep 10, 2024
5bda935
Update neural_lam/datastore/base.py
leifdenby Sep 10, 2024
8e7931d
remove all multizarr functionality
leifdenby Sep 10, 2024
6998683
cleanup and test fixes for recent changes
leifdenby Sep 10, 2024
c415008
Merge branch 'feat/datastores' of https://github.com/leifdenby/neural…
leifdenby Sep 10, 2024
735d324
fix linting
leifdenby Sep 10, 2024
5f2d919
remove multizar example files
leifdenby Sep 10, 2024
5263d2c
normalization -> standardization
leifdenby Sep 10, 2024
ba1bec3
fix import for tests
leifdenby Sep 10, 2024
d04d15e
Update neural_lam/datastore/base.py
leifdenby Sep 10, 2024
743d7a1
fix coord issues and add datastore example plotting cli
leifdenby Sep 12, 2024
ac10d7d
add lru_cache to get_xy_extent
leifdenby Sep 12, 2024
bf8172a
MLLAMDatastore -> MDPDatastore
leifdenby Sep 12, 2024
90ca400
missed renames for MDPDatastore
leifdenby Sep 12, 2024
154139d
update graph plot for datastores
leifdenby Sep 12, 2024
50ee0b0
use relative import
leifdenby Sep 12, 2024
7dfd570
add long_names and refactor npyfiles create weights
leifdenby Sep 12, 2024
2b45b5a
Update neural_lam/weather_dataset.py
leifdenby Sep 23, 2024
aee0b1c
Update neural_lam/weather_dataset.py
leifdenby Sep 23, 2024
8453c2b
Update neural_lam/models/ar_model.py
leifdenby Sep 27, 2024
7f32557
Update neural_lam/weather_dataset.py
leifdenby Sep 27, 2024
67998b8
read projection from datastore config extra section
leifdenby Sep 27, 2024
ac7e46a
NpyFilesDatastore -> NpyFilesDatastoreMEPS
leifdenby Sep 27, 2024
b7bf506
revert tp training with 1 AR step by default
leifdenby Sep 27, 2024
5df2ecf
add missing kwarg to BaseHiGraphModel.__init__
leifdenby Sep 27, 2024
d4d438f
add missing kwarg to HiLAM.__init__
leifdenby Sep 27, 2024
1889771
add missing kwarg to HiLAMParallel
leifdenby Sep 27, 2024
2c3bbde
check that for enough forecast steps given ar_steps
leifdenby Sep 27, 2024
f0a151b
remove numpy<2.0.0 version cap
leifdenby Sep 27, 2024
f3566b0
tweak print statement working in mdp
Oct 1, 2024
dba94b3
fix missed removed argument from cli
Oct 1, 2024
bca1482
remove wandb config log comment, we log now
Oct 1, 2024
fc973c4
ensure loading from checkpoint during train possible
Oct 1, 2024
9fcf06e
get step_length from datastore in plot_error_map
leifdenby Oct 1, 2024
2bbe666
remove step_legnth attr in ARModel
leifdenby Oct 1, 2024
b41ed2f
remove unused obs_mask arg for vis.plot_prediction
leifdenby Oct 1, 2024
7e46194
ensure no reference to multizarr "data_config"
leifdenby Oct 1, 2024
b57bc7a
introduce neural-lam config
leifdenby Oct 2, 2024
2b30715
include meps neural-lam config example
leifdenby Oct 2, 2024
8e7b2e6
fix extra space typo in BaseDatastore
leifdenby Oct 2, 2024
e0300fb
add check and print of train/test/val split in MDPDatastore
leifdenby Oct 2, 2024
a921e35
add experimental mlflow server support
leifdenby Oct 2, 2024
0f30259
more fixes for mlflow logging support
leifdenby Oct 3, 2024
3fbe2d0
Make wandb work again with pytorch_lightning.logger
khintz Oct 3, 2024
e0284a8
upload of artifact to mlflow works, but instantiates a new experiment
khintz Oct 4, 2024
7eed79b
make mlflow use same experiment run id as pl.logger.MLFlowLogger
khintz Oct 7, 2024
27408f2
logger artifact working for both wandb and mlflow
khintz Oct 7, 2024
e61a9e7
support mlflow system metrics logging
khintz Oct 7, 2024
b53bab5
support model logging for mlflow
khintz Oct 7, 2024
de27e9a
log model
khintz Oct 7, 2024
89d8cde
test system metrics
khintz Nov 13, 2024
54c7ca7
make mlflow work also for eval mode
khintz Nov 15, 2024
a47de0c
dummy prints to identify workflow
khintz Nov 21, 2024
10a4494
update mlflow on eval mode
khintz Nov 21, 2024
427a4b1
Merge branch 'main' into feat/mlflow
khintz Nov 21, 2024
78e874d
inspect plot routines
khintz Nov 25, 2024
5904cbe
identified issue, cleanup next
leifdenby Nov 25, 2024
efe0302
use xarray plot only
leifdenby Nov 26, 2024
a489c2e
don't reraise
leifdenby Nov 26, 2024
242d08b
remove debug plot
leifdenby Nov 26, 2024
c1f706c
remove extent calc used in diagnosing issue
leifdenby Nov 26, 2024
88ec9dc
Test order of dimension in eval plots
khintz Nov 28, 2024
d367cdb
Merge branch 'fix/eval-vis-plots' into feat/mlflow
khintz Nov 28, 2024
90f8918
fix tensors on cpu and plot time index
khintz Nov 28, 2024
53f0ea4
restore tests/test_datasets.py
khintz Nov 29, 2024
cfc249f
cleaning up with focus on linting
khintz Nov 29, 2024
b218c8b
update tests
khintz Nov 29, 2024
1f1aed8
use correct data module for input example
khintz Dec 4, 2024
f3abd47
Merge branch 'main' into feat/mlflow
khintz Dec 4, 2024
47932b5
clean log model function
khintz Dec 9, 2024
98dc5c4
Merge branch 'main' into feat/mlflow
khintz Dec 9, 2024
64971ae
revert bad merge
khintz Dec 9, 2024
010f716
remove unused init for datastore
khintz Dec 9, 2024
2620bd1
set logger url
khintz Dec 9, 2024
75a39e6
change type of default logger_url in config
khintz Dec 9, 2024
9d27a4c
linting
khintz Dec 9, 2024
8f42cd1
fix log_image issue in tests
khintz Dec 9, 2024
b5ebe6f
add entry to changelog
khintz Dec 9, 2024
ae69f3f
remove artifacts from earlier merging/rebase
khintz Dec 11, 2024
6e16035
catch error when aws credentials not set
khintz Dec 11, 2024
821443a
remove functions log_model and create_input_example
khintz Jan 7, 2025
d503048
Elaborate warning when no logger is set
khintz Jan 7, 2025
c80d36f
add pynvml to pyproject to allow mlflow to log gpustats
khintz Jan 21, 2025
538c26d
move logger from config to command line arguments
khintz Jan 21, 2025
30cb31a
change logger_project command line argument
khintz Jan 21, 2025
d165fcd
adjust tests after moving loggers to cli
khintz Jan 21, 2025
e1dac04
add choices to logger argument
khintz Jan 21, 2025
5fe9f84
remove unused command line arguments
khintz Jan 21, 2025
97e7bb3
Merge branch 'main' into feat/mlflow
khintz Jan 21, 2025
7a90c57
correct changelog after merging in main
khintz Jan 21, 2025
2ca3959
restore defaults for pl.Trainer
khintz Jan 21, 2025
757b502
use MLFLOW_TRACKING_URI environment variable instead of command-line
khintz Jan 21, 2025
05e1be0
Update README on MLFlow
khintz Jan 21, 2025
f7f90c4
init logger on rank0 and use run_name for mlflow
khintz Jan 21, 2025
3187dc3
move custom_logger to its own file
khintz Jan 22, 2025
d88c23c
correct typo in custom_loggers
khintz Jan 22, 2025
d54e3d8
Merge branch 'main' into feat/mlflow
khintz Jan 23, 2025
90b48ff
improve code to pass linting
khintz Jan 23, 2025
d5c5dcc
Apply suggestions from code review
khintz Jan 27, 2025
1c0d121
Satisfy linting after review
khintz Jan 27, 2025
42dafa1
elaborate docstring for save_dir
khintz Jan 27, 2025
1b9b7f0
remove unnecesary change
khintz Jan 27, 2025
f743890
add inline comment for logger handling
khintz Jan 27, 2025
2d2aa23
warn if logger does not support image logging
khintz Jan 27, 2025
d2ffefc
expand on inline comment on logger key
khintz Jan 27, 2025
4 changes: 4 additions & 0 deletions CHANGELOG.md
@@ -8,6 +8,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
 ## [unreleased](https://github.com/mllam/neural-lam/compare/v0.3.0...HEAD)

 ### Added
+
+- Add support for MLFlow logging and metrics tracking. [\#77](https://github.com/mllam/neural-lam/pull/77)
+  @khintz
+
 - Add support for multi-node training.
   [\#103](https://github.com/mllam/neural-lam/pull/103) @simonkamuk @sadamov
11 changes: 10 additions & 1 deletion README.md
@@ -382,7 +382,9 @@ The graphs used for the different models in the [paper](#graph-based-neural-weat

 The graph-related files are stored in a directory called `graphs`.

-## Weights & Biases Integration
+## Logging your experiments
+
+### Weights & Biases Integration
 The project is fully integrated with [Weights & Biases](https://www.wandb.ai/) (W&B) for logging and visualization, but can just as easily be used without it.
 When W&B is used, training configuration, training/test statistics and plots are sent to the W&B servers and made available in an interactive web interface.
 If W&B is turned off, logging instead saves everything locally to a directory like `wandb/dryrun...`.
@@ -398,6 +400,13 @@ If you would like to turn off W&B and just log things locally, run:
 wandb off
 ```

+### MLFlow Integration
+The project is also integrated with [MLFlow](https://mlflow.org/) for logging and storing artefacts.
+
+MLFlow is not used by default, but can be enabled by passing `--logger mlflow` to the training command. With MLFlow enabled, training configuration, training/test statistics and plots are logged to the MLFlow server. MLFlow is self-hosted and can be run locally or on a server. See the [MLFlow documentation](https://mlflow.org/docs/latest/index.html) for details.
+
+Use the environment variable `MLFLOW_TRACKING_URI` to set the URI of the MLFlow server; if it is not set, MLFlow logging cannot be used. For example, to point at a local server and start training: `MLFLOW_TRACKING_URI=http://localhost:5000 python -m neural_lam.train_model --config_path <config_path> --logger mlflow`.
+
 ## Train Models
 Models can be trained using `python -m neural_lam.train_model --config_path <config_path>`.
 Run `python -m neural_lam.train_model --help` for a full list of training options.
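Putting the two pieces above together, a minimal end-to-end invocation might look like the following (a sketch assuming `mlflow` is installed; the server command and port here are illustrative, matching the localhost example in the README text):

```bash
# Start a local MLFlow tracking server
mlflow server --host 127.0.0.1 --port 5000

# In another shell: point neural-lam at the server and train with MLFlow logging
export MLFLOW_TRACKING_URI=http://localhost:5000
python -m neural_lam.train_model --config_path <config_path> --logger mlflow
```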
68 changes: 68 additions & 0 deletions neural_lam/custom_loggers.py
@@ -0,0 +1,68 @@
# Standard library
import sys

# Third-party
import mlflow
import mlflow.pytorch
import pytorch_lightning as pl
from loguru import logger


class CustomMLFlowLogger(pl.loggers.MLFlowLogger):
    """
    Custom MLFlow logger that adds the `log_image()` functionality not
    present in the default implementation from pytorch-lightning as
    of version `2.0.3` at least.
    """

    def __init__(self, experiment_name, tracking_uri, run_name):
        super().__init__(
            experiment_name=experiment_name, tracking_uri=tracking_uri
        )

        mlflow.start_run(run_id=self.run_id, log_system_metrics=True)
        mlflow.set_tag("mlflow.runName", run_name)
        mlflow.log_param("run_id", self.run_id)

    @property
    def save_dir(self):
        """
        Returns the directory where the MLFlow artifacts are saved.
        Used to define the path to save output when using the logger.

        Returns
        -------
        str
            Path to the directory where the artifacts are saved.
        """
        return "mlruns"

    def log_image(self, key, images, step=None):
        """
        Log a matplotlib figure as an image to MLFlow

        key: str
            Key to log the image under
        images: list
            List of matplotlib figures to log
        step: Union[int, None]
            Step to log the image under. If None, logs under the key directly
        """
        # Third-party
        import botocore
        from PIL import Image

        if step is not None:
            key = f"{key}_{step}"

        # Need to save the image to a temporary file, then log that file.
        # mlflow.log_image should do this automatically, but is buggy.
        temporary_image = f"{key}.png"
        images[0].savefig(temporary_image)

        img = Image.open(temporary_image)
        try:
            mlflow.log_image(img, f"{key}.png")
        except botocore.exceptions.NoCredentialsError:
            logger.error("Error logging image\nSet AWS credentials")
            sys.exit(1)
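As an aside, a minimal usage sketch of this class (the experiment and run names below are hypothetical, and an MLFlow tracking server is assumed to be reachable at the given URI):

```python
# Illustrative only: construct the logger and hand it to a Trainer.
# Assumes an MLFlow tracking server at http://localhost:5000 and that
# MLFLOW_TRACKING_URI is also exported, since the class additionally
# calls the global mlflow API (mlflow.start_run etc.).
import pytorch_lightning as pl

from neural_lam.custom_loggers import CustomMLFlowLogger

training_logger = CustomMLFlowLogger(
    experiment_name="neural_lam",  # hypothetical experiment name
    tracking_uri="http://localhost:5000",
    run_name="graph_lam-example-run",  # shown as the mlflow.runName tag
)

trainer = pl.Trainer(max_epochs=1, logger=training_logger)
```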
77 changes: 58 additions & 19 deletions neural_lam/models/ar_model.py
@@ -1,13 +1,13 @@
 # Standard library
 import os
+import warnings
 from typing import List, Union

 # Third-party
 import matplotlib.pyplot as plt
 import numpy as np
 import pytorch_lightning as pl
 import torch
-import wandb
 import xarray as xr

 # Local
@@ -539,14 +539,26 @@ def plot_examples(self, batch, n_examples, split, prediction=None):

             example_i = self.plotted_examples

-            wandb.log(
-                {
-                    f"{var_name}_example_{example_i}": wandb.Image(fig)
-                    for var_name, fig in zip(
-                        self._datastore.get_vars_names("state"), var_figs
-                    )
-                }
-            )
+            for var_name, fig in zip(
+                self._datastore.get_vars_names("state"), var_figs
+            ):
+                # We need to treat logging images differently for
+                # different loggers. W&B can log multiple images to the
+                # same key, while other loggers, such as MLFlow, need
+                # unique keys for each image.
+                if isinstance(self.logger, pl.loggers.WandbLogger):
+                    key = f"{var_name}_example_{example_i}"
+                else:
+                    key = f"{var_name}_example"
+
+                if hasattr(self.logger, "log_image"):
+                    self.logger.log_image(key=key, images=[fig], step=t_i)
+                else:
+                    warnings.warn(
+                        f"{self.logger} does not support image logging."
+                    )

             plt.close(
                 "all"
             )  # Close all figs for this time step, saves memory
@@ -555,13 +567,15 @@ def plot_examples(self, batch, n_examples, split, prediction=None):
             torch.save(
                 pred_slice.cpu(),
                 os.path.join(
-                    wandb.run.dir, f"example_pred_{self.plotted_examples}.pt"
+                    self.logger.save_dir,
+                    f"example_pred_{self.plotted_examples}.pt",
                 ),
             )
             torch.save(
                 target_slice.cpu(),
                 os.path.join(
-                    wandb.run.dir, f"example_target_{self.plotted_examples}.pt"
+                    self.logger.save_dir,
+                    f"example_target_{self.plotted_examples}.pt",
                 ),
             )

@@ -582,16 +596,16 @@ def create_metric_log_dict(self, metric_tensor, prefix, metric_name):
             datastore=self._datastore,
         )
         full_log_name = f"{prefix}_{metric_name}"
-        log_dict[full_log_name] = wandb.Image(metric_fig)
+        log_dict[full_log_name] = metric_fig

         if prefix == "test":
             # Save pdf
             metric_fig.savefig(
-                os.path.join(wandb.run.dir, f"{full_log_name}.pdf")
+                os.path.join(self.logger.save_dir, f"{full_log_name}.pdf")
             )
             # Save errors also as csv
             np.savetxt(
-                os.path.join(wandb.run.dir, f"{full_log_name}.csv"),
+                os.path.join(self.logger.save_dir, f"{full_log_name}.csv"),
                 metric_tensor.cpu().numpy(),
                 delimiter=",",
             )
@@ -639,8 +653,27 @@ def aggregate_and_plot_metrics(self, metrics_dict, prefix):
                 )
             )

+        # Ensure that log_dict has structure for
+        # logging as dict(str, plt.Figure)
+        assert all(
+            isinstance(key, str) and isinstance(value, plt.Figure)
+            for key, value in log_dict.items()
+        )
+
         if self.trainer.is_global_zero and not self.trainer.sanity_checking:
-            wandb.log(log_dict)  # Log all
+
+            current_epoch = self.trainer.current_epoch
+
+            for key, figure in log_dict.items():
+                # For loggers other than wandb, add the epoch to the key.
+                # Wandb can log multiple images to the same key, while other
+                # loggers, such as MLFlow, need unique keys for each image.
+                if not isinstance(self.logger, pl.loggers.WandbLogger):
+                    key = f"{key}-{current_epoch}"
+
+                if hasattr(self.logger, "log_image"):
+                    self.logger.log_image(key=key, images=[figure])

         plt.close("all")  # Close all figs

     def on_test_epoch_end(self):
@@ -672,9 +705,13 @@ def on_test_epoch_end(self):
             )
         ]

-        # log all to same wandb key, sequentially
-        for fig in loss_map_figs:
-            wandb.log({"test_loss": wandb.Image(fig)})
+        # log all to the same key, sequentially
+        for i, fig in enumerate(loss_map_figs):
+            key = "test_loss"
+            if not isinstance(self.logger, pl.loggers.WandbLogger):
+                key = f"{key}_{i}"
+            if hasattr(self.logger, "log_image"):
+                self.logger.log_image(key=key, images=[fig])

         # also make without title and save as pdf
         pdf_loss_map_figs = [
@@ -683,14 +720,16 @@ def on_test_epoch_end(self):
             )
             for loss_map in mean_spatial_loss
         ]
-        pdf_loss_maps_dir = os.path.join(wandb.run.dir, "spatial_loss_maps")
+        pdf_loss_maps_dir = os.path.join(
+            self.logger.save_dir, "spatial_loss_maps"
+        )
         os.makedirs(pdf_loss_maps_dir, exist_ok=True)
         for t_i, fig in zip(self.args.val_steps_to_log, pdf_loss_map_figs):
             fig.savefig(os.path.join(pdf_loss_maps_dir, f"loss_t{t_i}.pdf"))
         # save mean spatial loss as .pt file also
         torch.save(
             mean_spatial_loss.cpu(),
-            os.path.join(wandb.run.dir, "mean_spatial_loss.pt"),
+            os.path.join(self.logger.save_dir, "mean_spatial_loss.pt"),
         )

         self.spatial_loss_maps.clear()
36 changes: 24 additions & 12 deletions neural_lam/train_model.py
@@ -5,6 +5,7 @@
 from argparse import ArgumentParser

 # Third-party
+# for logging the model:
 import pytorch_lightning as pl
 import torch
 from lightning_fabric.utilities import seed
@@ -182,10 +183,17 @@

     # Logger Settings
     parser.add_argument(
-        "--wandb_project",
+        "--logger",
         type=str,
+        default="wandb",
+        choices=["wandb", "mlflow"],
+        help="Logger to use for training (wandb/mlflow) (default: wandb)",
+    )
+    parser.add_argument(
+        "--logger-project",
+        type=str,
         default="neural_lam",
-        help="Wandb project name (default: neural_lam)",
+        help="Logger project name, e.g. for Wandb (default: neural_lam)",
     )
     parser.add_argument(
         "--val_steps_to_log",
@@ -286,26 +294,26 @@
         f"{prefix}{args.model}-{args.processor_layers}x{args.hidden_dim}-"
         f"{time.strftime('%m_%d_%H')}-{random_run_id:04d}"
     )
+
+    training_logger = utils.setup_training_logger(
+        datastore=datastore, args=args, run_name=run_name
+    )
+
     checkpoint_callback = pl.callbacks.ModelCheckpoint(
         dirpath=f"saved_models/{run_name}",
         filename="min_val_loss",
         monitor="val_mean_loss",
         mode="min",
         save_last=True,
     )
-    logger = pl.loggers.WandbLogger(
-        project=args.wandb_project,
-        name=run_name,
-        config=dict(training=vars(args), datastore=datastore._config),
-    )
     trainer = pl.Trainer(
         max_epochs=args.epochs,
         deterministic=True,
         strategy="ddp",
         accelerator=device_name,
         num_nodes=args.num_nodes,
         devices=devices,
-        logger=logger,
+        logger=training_logger,
         log_every_n_steps=1,
         callbacks=[checkpoint_callback],
         check_val_every_n_epoch=args.val_interval,
@@ -314,11 +322,15 @@

     # Only init once, on rank 0 only
     if trainer.global_rank == 0:
-        utils.init_wandb_metrics(
-            logger, val_steps=args.val_steps_to_log
-        )  # Do after wandb.init
+        utils.init_training_logger_metrics(
+            training_logger, val_steps=args.val_steps_to_log
+        )  # Do after initializing logger
     if args.eval:
-        trainer.test(model=model, datamodule=data_module, ckpt_path=args.load)
+        trainer.test(
+            model=model,
+            datamodule=data_module,
+            ckpt_path=args.load,
+        )
     else:
         trainer.fit(model=model, datamodule=data_module, ckpt_path=args.load)
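The `utils.setup_training_logger` helper called above is not part of the files shown in this diff; a plausible sketch of its dispatch logic, inferred from the CLI arguments, the README, and `CustomMLFlowLogger` in this PR (illustrative only, not the actual implementation):

```python
# Sketch only -- the real helper lives in neural_lam/utils.py and may differ.
import os

import pytorch_lightning as pl

from .custom_loggers import CustomMLFlowLogger


def setup_training_logger(datastore, args, run_name):
    """Create the experiment logger selected via --logger."""
    if args.logger == "wandb":
        logger = pl.loggers.WandbLogger(
            project=args.logger_project,
            name=run_name,
            config=dict(training=vars(args), datastore=datastore._config),
        )
    elif args.logger == "mlflow":
        # MLFLOW_TRACKING_URI must point at a running tracking server
        url = os.environ.get("MLFLOW_TRACKING_URI")
        if url is None:
            raise ValueError("MLFLOW_TRACKING_URI is not set")
        logger = CustomMLFlowLogger(
            experiment_name=args.logger_project,
            tracking_uri=url,
            run_name=run_name,
        )
    return logger
```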