diff --git a/earthaccess/store.py b/earthaccess/store.py index 7bf54c60..35f369e0 100644 --- a/earthaccess/store.py +++ b/earthaccess/store.py @@ -1,5 +1,6 @@ import datetime import logging +import threading import traceback from functools import lru_cache from itertools import chain @@ -17,7 +18,7 @@ import earthaccess -from .auth import Auth +from .auth import Auth, SessionWithHeaderRedirection from .daac import DAAC_TEST_URLS, find_provider from .results import DataGranule from .search import DataCollections @@ -118,6 +119,7 @@ def __init__(self, auth: Any, pre_authorize: bool = False) -> None: Parameters: auth: Auth instance to download and access data. """ + self.thread_locals = threading.local() if auth.authenticated is True: self.auth = auth self._s3_credentials: Dict[ @@ -126,7 +128,7 @@ def __init__(self, auth: Any, pre_authorize: bool = False) -> None: oauth_profile = f"https://{auth.system.edl_hostname}/profile" # sets the initial URS cookie self._requests_cookies: Dict[str, Any] = {} - self.set_requests_session(oauth_profile) + self.set_requests_session(oauth_profile, bearer_token=True) if pre_authorize: # collect cookies from other DAACs for url in DAAC_TEST_URLS: @@ -182,7 +184,7 @@ def _running_in_us_west_2(self) -> bool: return False def set_requests_session( - self, url: str, method: str = "get", bearer_token: bool = False + self, url: str, method: str = "get", bearer_token: bool = True ) -> None: """Sets up a `requests` session with bearer tokens that are used by CMR. @@ -323,19 +325,19 @@ def get_fsspec_session(self) -> fsspec.AbstractFileSystem: session = fsspec.filesystem("https", client_kwargs=client_kwargs) return session - def get_requests_session(self, bearer_token: bool = True) -> requests.Session: + def get_requests_session(self) -> SessionWithHeaderRedirection: """Returns a requests HTTPS session with bearer tokens that are used by CMR. This HTTPS session can be used to download granules if we want to use a direct, lower level API. - Parameters: - bearer_token: if true, will be used for authenticated queries on CMR - Returns: requests Session """ - return self.auth.get_session() + if hasattr(self, "_http_session"): + return self._http_session + else: + raise AttributeError("The requests session hasn't been set up yet.") def open( self, @@ -651,6 +653,27 @@ def _get_granules( data_links, local_path, pqdm_kwargs=pqdm_kwargs ) + def _clone_session_in_local_thread( + self, original_session: SessionWithHeaderRedirection + ) -> None: + """Clone the original session and store it in the local thread context. + + This method creates a new session that replicates the headers, cookies, and authentication settings + from the provided original session. The new session is stored in a thread-local storage. + + Parameters: + original_session (SessionWithHeaderRedirection): The session to be cloned. + + Returns: + None + """ + if not hasattr(self.thread_locals, "local_thread_session"): + local_thread_session = SessionWithHeaderRedirection() + local_thread_session.headers.update(original_session.headers) + local_thread_session.cookies.update(original_session.cookies) + local_thread_session.auth = original_session.auth + self.thread_locals.local_thread_session = local_thread_session + def _download_file(self, url: str, directory: Path) -> str: """Download a single file from an on-prem location, a DAAC data center. @@ -668,7 +691,11 @@ def _download_file(self, url: str, directory: Path) -> str: path = directory / Path(local_filename) if not path.exists(): try: - session = self.auth.get_session() + original_session = self.get_requests_session() + # This reuses the auth cookie, we make sure we only authenticate N threads instead + # of one per file, see #913 + self._clone_session_in_local_thread(original_session) + session = self.thread_locals.local_thread_session with session.get(url, stream=True, allow_redirects=True) as r: r.raise_for_status() with open(path, "wb") as f: diff --git a/tests/unit/test_store.py b/tests/unit/test_store.py index f2be362b..069cb030 100644 --- a/tests/unit/test_store.py +++ b/tests/unit/test_store.py @@ -1,13 +1,18 @@ # package imports import os +import threading import unittest +from pathlib import Path +from unittest.mock import MagicMock, patch import fsspec import pytest import responses import s3fs from earthaccess import Auth, Store +from earthaccess.auth import SessionWithHeaderRedirection from earthaccess.store import EarthAccessFile +from pqdm.threads import pqdm class TestStoreSessions(unittest.TestCase): @@ -128,6 +133,81 @@ def test_store_can_create_s3_fsspec_session(self): return None + @responses.activate + def test_session_reuses_token_download(self): + mock_creds = { + "accessKeyId": "sure", + "secretAccessKey": "correct", + "sessionToken": "whynot", + } + test_cases = [ + (2, 500), # 2 threads, 500 files + (4, 400), # 4 threads, 400 files + (8, 5000), # 8 threads, 5k files + ] + for n_threads, n_files in test_cases: + with self.subTest(n_threads=n_threads, n_files=n_files): + urls = [f"https://example.com/file{i}" for i in range(1, n_files + 1)] + for i, url in enumerate(urls): + responses.add( + responses.GET, url, body=f"Content of file {i + 1}", status=200 + ) + + mock_auth = MagicMock() + mock_auth.authenticated = True + mock_auth.system.edl_hostname = "urs.earthdata.nasa.gov" + responses.add( + responses.GET, + "https://urs.earthdata.nasa.gov/profile", + json=mock_creds, + status=200, + ) + + original_session = SessionWithHeaderRedirection() + original_session.cookies.set("sessionid", "mocked-session-cookie") + mock_auth.get_session.return_value = original_session + + store = Store(auth=mock_auth) + store.thread_locals = threading.local() # Use real thread-local storage + + # Track cloned sessions + cloned_sessions = set() + + def mock_clone_session_in_local_thread(original_session): + """Mock session cloning to track cloned sessions.""" + if not hasattr(store.thread_locals, "local_thread_session"): + session = SessionWithHeaderRedirection() + session.cookies.update(original_session.cookies) + cloned_sessions.add(id(session)) + store.thread_locals.local_thread_session = session + + with patch.object( + store, + "_clone_session_in_local_thread", + side_effect=mock_clone_session_in_local_thread, + ): + mock_directory = Path("/mock/directory") + downloaded_files = [] + + def mock_download_file(url): + """Mock file download to track downloaded files.""" + # Ensure session cloning happens before downloading + store._clone_session_in_local_thread(original_session) + downloaded_files.append(url) + return mock_directory / f"{url.split('/')[-1]}" + + with patch.object( + store, "_download_file", side_effect=mock_download_file + ): + # Test multi-threaded download + pqdm(urls, store._download_file, n_jobs=n_threads) # type: ignore + + # We make sure we reuse the token up to N threads + self.assertTrue(len(cloned_sessions) <= n_threads) + + self.assertEqual(len(downloaded_files), n_files) # 10 files downloaded + self.assertCountEqual(downloaded_files, urls) # All files accounted for + @pytest.mark.xfail( reason="Expected failure: Reproduces a bug (#610) that has not yet been fixed." @@ -135,7 +215,7 @@ def test_store_can_create_s3_fsspec_session(self): def test_earthaccess_file_getattr(): fs = fsspec.filesystem("memory") with fs.open("/foo", "wb") as f: - earthaccess_file = EarthAccessFile(f, granule="foo") + earthaccess_file = EarthAccessFile(f, granule="foo") # type: ignore assert f.tell() == earthaccess_file.tell() # cleanup fs.store.clear()