
Bnb/training with obs #252

Closed
wants to merge 32 commits

32 commits
cee4058
dual sampler, queue, and batch handler with obs. modifying Sup3rDatas…
bnb32 Dec 20, 2024
61f6206
training with obs test
bnb32 Dec 20, 2024
e6b8818
split up interface and abstract model
bnb32 Dec 21, 2024
ee79c13
made dual batch queue flexible enough to account for additional obs m…
bnb32 Dec 22, 2024
1b61a66
tensorboard mixin moved to model utilities. dual queue completely abs…
bnb32 Dec 22, 2024
df698e1
integrated dual sampler with obs into base dual sampler.
bnb32 Dec 23, 2024
0cb253c
examples added to DataHandler doc string. Some instructions on sup3rw…
bnb32 Dec 23, 2024
13664d1
removed namedtuple from Sup3rDataset to make Sup3rDataset picklable.
bnb32 Dec 26, 2024
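
The pickling fix in 13664d1 replaces a namedtuple class built at runtime inside ``Sup3rDataset`` with a small module-level container (later named ``DsetTuple`` in 89c6cde). A minimal sketch of why the change helps; the class below is illustrative and not sup3r's actual implementation:

import pickle
from collections import namedtuple


def make_runtime_namedtuple(**dsets):
    # Mimics the old approach: the namedtuple type is created on the fly,
    # so it is not importable and its instances cannot be pickled.
    cls = namedtuple('Sup3rDatasetTuple', list(dsets))
    return cls(**dsets)


class DsetTuple:
    """Picklable, namedtuple-like container (illustrative only)."""

    def __init__(self, **dsets):
        self.dsets = dict(dsets)

    def __iter__(self):
        return iter(self.dsets.values())

    def __getitem__(self, key):
        if isinstance(key, int):
            key = list(self.dsets)[key]
        return self.dsets[key]

    def __getattr__(self, name):
        dsets = self.__dict__.get('dsets', {})
        if name in dsets:
            return dsets[name]
        raise AttributeError(name)

    def __len__(self):
        return len(self.dsets)


try:
    pickle.dumps(make_runtime_namedtuple(low_res='lr', high_res='hr'))
except pickle.PicklingError:
    print('runtime namedtuple is not picklable')

# The plain class round-trips because it is defined at module level.
pair = DsetTuple(low_res='lr', high_res='hr')
assert pickle.loads(pickle.dumps(pair)).high_res == 'hr'
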
c31ed72
parallel batch queue test added.
bnb32 Dec 27, 2024
89c6cde
namedtuple -> DsetTuple missing attr fix
bnb32 Dec 27, 2024
5976aa2
gust added to era download variables. len dunder added to ``Container…
bnb32 Dec 27, 2024
ec8b739
computing before reshaping is 2x faster.
bnb32 Dec 28, 2024
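
The 2x claim in ec8b739 is about sup3r's own batching code, but the underlying pattern is general dask usage: call ``.compute()`` first and reshape the resulting numpy array in memory, rather than adding the reshape to the dask graph. A rough, hypothetical illustration:

import dask.array as da

# Toy array standing in for sampled data; shapes and chunks are made up.
arr = da.random.random((100, 100, 240, 2), chunks=(100, 100, 60, 2))

# reshape as a dask operation, then compute
out_lazy = arr.reshape((-1, 240, 2)).compute()

# compute first, then reshape the in-memory numpy result
out_eager = arr.compute().reshape((-1, 240, 2))

assert out_lazy.shape == out_eager.shape == (100 * 100, 240, 2)
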
e265cd5
obs_index fix - sampler needs to use hr_out_features for the obs member.
bnb32 Dec 28, 2024
d4c009d
split up ``calc_loss`` and ``calc_loss_obs``
bnb32 Dec 29, 2024
ae564ac
Optional run_qa flag in ``DualRasterizer``. Queue shape fix for queue…
bnb32 Dec 29, 2024
4085a07
``run_qa=True`` default for ``DualRasterizer``
bnb32 Dec 29, 2024
dc933f9
better tracking of batch counting. (this can be tricky for parallel q…
bnb32 Dec 29, 2024
f26f304
missed compute call for slow batching. this was hidden by queueing an…
bnb32 Dec 29, 2024
b43e807
Included convert to tensor in ``sample_batch``. Test for training wit…
bnb32 Dec 30, 2024
9b542be
cc batch handler test fix
bnb32 Dec 31, 2024
10dbc9c
added test for new disc with "valid" padding
bnb32 Dec 31, 2024
2e6ed14
parallel batch sampling test.
bnb32 Jan 1, 2025
8f2218d
removed workers tests. max_workers > 1 still not consistently faster.…
bnb32 Jan 2, 2025
02a9ecc
``Sup3rGanWithObs`` model subclass. Other misc model refactoring.
bnb32 Jan 3, 2025
dd70a57
moved ``_run`` method to bias correction interface ``AbstractBiasCorr…
bnb32 Jan 5, 2025
d93a6b1
moved ``_run`` method to bias correction interface ``AbstractBiasCorr…
bnb32 Jan 5, 2025
5afa9ed
fix: tensorboard issue with loss obs details
bnb32 Jan 8, 2025
32edc38
Adding obs loss to logging of loss gen
bnb32 Jan 19, 2025
dffdee2
Adding ``loss_obs`` to ``loss_gen`` so the total loss shows in log ou…
bnb32 Jan 21, 2025
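
The two ``loss_obs`` commits fold the observation term into the generator loss so the logged value reflects the total being optimized. A hedged sketch of that idea; ``calc_loss_obs`` and the NaN masking below are illustrative and may not match the actual ``Sup3rGanWithObs`` methods:

import tensorflow as tf


def calc_loss_obs(hi_res_gen, hi_res_obs):
    # Mean absolute error against sparse observations, ignoring NaNs that
    # mark grid points with no observation (illustrative only).
    mask = tf.math.logical_not(tf.math.is_nan(hi_res_obs))
    diff = tf.boolean_mask(hi_res_gen, mask) - tf.boolean_mask(hi_res_obs, mask)
    return tf.reduce_mean(tf.abs(diff))


def total_gen_loss(loss_gen, hi_res_gen, hi_res_obs, obs_weight=1.0):
    # Adding loss_obs into loss_gen means the logged / tensorboard value is
    # the total loss actually driving the generator update.
    loss_obs = calc_loss_obs(hi_res_gen, hi_res_obs)
    return loss_gen + obs_weight * loss_obs, {'loss_obs': loss_obs}
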
ea1d3fd
generalized min pad width for padding slices so that this can accomod…
bnb32 Jan 11, 2025
06221d7
min padding depends on the ``.paddings`` attribute of the ``FlexibleP…
bnb32 Jan 15, 2025
8980735
`max_paddings` method in `interface` instead of in `strategy.py`.
bnb32 Jan 21, 2025

parallel batch sampling test.
bnb32 committed Jan 19, 2025
commit 2e6ed144f378c86a14f5f65946d357b7bcea80bb
40 changes: 22 additions & 18 deletions sup3r/preprocessing/batch_queues/abstract.py
@@ -247,24 +247,21 @@ def running(self):
             and not self.queue.is_closed()
         )

-    def _enqueue_batches(self, n_batches) -> None:
-        """Sample N batches and enqueue them as they are sampled."""
+    def sample_batches(self, n_batches) -> None:
+        """Sample given number of batches either in serial or with thread
+        pool."""
         if n_batches == 1 or self.max_workers == 1:
-            for _ in range(n_batches):
-                self.queue.enqueue(self.sample_batch())
-
-        else:
-            tasks = [
-                self._thread_pool.submit(self.sample_batch)
-                for _ in range(n_batches)
-            ]
-            logger.debug(
-                'Added %s sample_batch futures to %s queue.',
-                n_batches,
-                self._thread_name,
-            )
-            for batch in as_completed(tasks):
-                self.queue.enqueue(batch.result())
+            return [self.sample_batch() for _ in range(n_batches)]
+        tasks = [
+            self._thread_pool.submit(self.sample_batch)
+            for _ in range(n_batches)
+        ]
+        logger.debug(
+            'Added %s sample_batch futures to %s queue.',
+            n_batches,
+            self._thread_name,
+        )
+        return tasks

     def enqueue_batches(self) -> None:
         """Callback function for queue thread. While training, the queue is
@@ -273,8 +270,15 @@ def enqueue_batches(self) -> None:
         log_time = time.time()
         while self.running:
             needed = max(self.queue_cap - self.queue_len, 0)
+            needed = min(self.max_workers, needed)
             if needed > 0:
-                self._enqueue_batches(n_batches=needed)
+                batches = self.sample_batches(n_batches=needed)
+                if needed > 1 and self.max_workers > 1:
+                    for batch in as_completed(batches):
+                        self.queue.enqueue(batch.result())
+                else:
+                    for batch in batches:
+                        self.queue.enqueue(batch)

             if time.time() > log_time + 10:
                 logger.debug(self.log_queue_info())
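
For readers outside the codebase, the change above splits sampling from enqueueing: ``sample_batches`` returns ready batches in the serial case or a list of futures in the threaded case, and ``enqueue_batches`` drains whichever it receives. A self-contained sketch of that pattern with hypothetical names (``make_batch``, ``BatchQueue``), not sup3r's classes:

import queue
import random
import time
from concurrent.futures import ThreadPoolExecutor, as_completed


def make_batch():
    time.sleep(0.01)  # stand-in for the work done by sample_batch()
    return [random.random() for _ in range(4)]


class BatchQueue:
    def __init__(self, max_workers=4, queue_cap=8):
        self.max_workers = max_workers
        self.queue_cap = queue_cap
        self.queue = queue.Queue(maxsize=queue_cap)
        self._pool = ThreadPoolExecutor(max_workers=max_workers)

    def sample_batches(self, n_batches):
        # serial path returns batches, threaded path returns futures
        if n_batches == 1 or self.max_workers == 1:
            return [make_batch() for _ in range(n_batches)]
        return [self._pool.submit(make_batch) for _ in range(n_batches)]

    def fill_once(self):
        # one pass of the enqueue loop shown in the diff
        needed = max(self.queue_cap - self.queue.qsize(), 0)
        needed = min(self.max_workers, needed)
        if needed > 0:
            batches = self.sample_batches(n_batches=needed)
            if needed > 1 and self.max_workers > 1:
                for fut in as_completed(batches):
                    self.queue.put(fut.result())
            else:
                for batch in batches:
                    self.queue.put(batch)


bq = BatchQueue()
bq.fill_once()
print(bq.queue.qsize())  # up to max_workers batches are now queued
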
84 changes: 61 additions & 23 deletions tests/batch_handlers/test_bh_general.py
@@ -2,11 +2,8 @@

 import copy

-import dask.array as da
 import numpy as np
-import pandas as pd
 import pytest
-import xarray as xr
 from scipy.ndimage import gaussian_filter

 from sup3r.preprocessing import (
@@ -33,13 +30,12 @@
 BatchHandlerTester = BatchHandlerTesterFactory(BatchHandler, SamplerTester)


-def test_batch_handler_workers():
-    """Check that it is faster to get batches with max_workers > 1 than with
-    max_workers = 1."""
+def test_batch_sampling_workers():
+    """Check that it is faster to sample batches with max_workers > 1 than with
+    max_workers = 1. This does not include enqueueing and dequeueing."""

     timer = Timer()
-    n_lats = 200
-    n_lons = 200
+    ds = DummyData((200, 200, 2000), ['u_100m', 'v_100m'])
     sample_shape = (20, 20, 30)
     chunk_shape = (
         2 * sample_shape[0],
@@ -51,23 +47,65 @@ def test_batch_handler_workers():
     n_batches = 10
     n_epochs = 3

-    lons, lats = np.meshgrid(
-        np.linspace(0, 180, n_lats), np.linspace(40, 60, n_lons)
-    )
-    time = pd.date_range('2023-01-01', '2023-05-01', freq='h')
-    u_arr = da.random.random((*lats.shape, len(time))).astype('float32')
-    v_arr = da.random.random((*lats.shape, len(time))).astype('float32')
-    ds = xr.Dataset(
-        coords={
-            'latitude': (('south_north', 'west_east'), lats),
-            'longitude': (('south_north', 'west_east'), lons),
-            'time': time,
-        },
-        data_vars={
-            'u_100m': (('south_north', 'west_east', 'time'), u_arr),
-            'v_100m': (('south_north', 'west_east', 'time'), v_arr),
-        },
-    )
+    ds = ds.chunk(dict(zip(['south_north', 'west_east', 'time'], chunk_shape)))
+
+    batcher = BatchHandler(
+        [ds],
+        n_batches=n_batches,
+        batch_size=n_obs,
+        sample_shape=sample_shape,
+        max_workers=max_workers,
+        means={'u_100m': 0, 'v_100m': 0},
+        stds={'u_100m': 1, 'v_100m': 1},
+    )
+    timer.start()
+    for _ in range(n_epochs):
+        batches = batcher.sample_batches(n_batches)
+        _ = [batch.result() for batch in batches]
+    timer.stop()
+    parallel_time = timer.elapsed / (n_batches * n_epochs)
+    batcher.stop()
+
+    batcher = BatchHandler(
+        [ds],
+        n_batches=n_batches,
+        batch_size=n_obs,
+        sample_shape=sample_shape,
+        max_workers=1,
+        means={'u_100m': 0, 'v_100m': 0},
+        stds={'u_100m': 1, 'v_100m': 1},
+    )
+    timer.start()
+    for _ in range(n_epochs):
+        _ = batcher.sample_batches(n_batches)
+    timer.stop()
+    serial_time = timer.elapsed / (n_batches * n_epochs)
+    batcher.stop()
+
+    print(
+        'Elapsed (serial / parallel): {} / {}'.format(
+            serial_time, parallel_time
+        )
+    )
+    assert serial_time > parallel_time
+
+
+def test_batch_queue_workers():
+    """Check that it is faster to queue batches with max_workers > 1 than with
+    max_workers = 1."""
+
+    timer = Timer()
+    ds = DummyData((200, 200, 2000), ['u_100m', 'v_100m'])
+    sample_shape = (20, 20, 30)
+    chunk_shape = (
+        2 * sample_shape[0],
+        2 * sample_shape[1],
+        2 * sample_shape[-1],
+    )
+    n_obs = 10
+    max_workers = 10
+    n_batches = 10
+    n_epochs = 3
     ds = ds.chunk(dict(zip(['south_north', 'west_east', 'time'], chunk_shape)))

     batcher = BatchHandler(
9 changes: 4 additions & 5 deletions tests/training/test_train_gan.py
@@ -43,7 +43,6 @@ def _get_handlers():
     ['fp_gen', 'fp_disc', 's_enhance', 't_enhance', 'sample_shape'],
     [
         (pytest.ST_FP_GEN, pytest.ST_FP_DISC, 3, 4, (12, 12, 16)),
-        (pytest.ST_FP_GEN, pytest.ST_FP_DISC_PROD, 3, 4, (12, 12, 16)),
         (pytest.S_FP_GEN, pytest.S_FP_DISC, 2, 1, (10, 10, 1)),
     ],
 )
@@ -54,7 +53,8 @@ def test_train(fp_gen, fp_disc, s_enhance, t_enhance, sample_shape, n_epoch=8):
     lr = 5e-5
     Sup3rGan.seed()
     model = Sup3rGan(
-        fp_gen, fp_disc, learning_rate=lr, loss='MeanAbsoluteError')
+        fp_gen, fp_disc, learning_rate=lr, loss='MeanAbsoluteError'
+    )

     train_handler, val_handler = _get_handlers()

@@ -169,7 +169,7 @@ def test_train(fp_gen, fp_disc, s_enhance, t_enhance, sample_shape, n_epoch=8):
     batch_handler.stop()


-def test_train_workers(n_epoch=20):
+def test_train_workers(n_epoch=10):
     """Test that model training with max_workers > 1 for the batch queue is
     faster than for max_workers = 1."""

@@ -188,7 +188,6 @@ def test_train_workers(n_epoch=20):
     )

     with tempfile.TemporaryDirectory() as td:
-
         batch_handler = BatchHandler(
             train_containers=[train_handler],
             val_containers=[val_handler],
@@ -252,7 +251,7 @@ def test_train_st_weight_update(n_epoch=2):
         pytest.ST_FP_GEN,
         pytest.ST_FP_DISC,
         learning_rate=1e-4,
-        learning_rate_disc=4e-4
+        learning_rate_disc=4e-4,
     )

     train_handler, val_handler = _get_handlers()