Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Thread-local Session Management and Cookie Reuse to Address EDL DSE issue #909

Merged
merged 13 commits into from
Jan 18, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
45 changes: 36 additions & 9 deletions earthaccess/store.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import datetime
import logging
import threading
import traceback
from functools import lru_cache
from itertools import chain
Expand All @@ -17,7 +18,7 @@

import earthaccess

from .auth import Auth
from .auth import Auth, SessionWithHeaderRedirection
from .daac import DAAC_TEST_URLS, find_provider
from .results import DataGranule
from .search import DataCollections
Expand Down Expand Up @@ -118,6 +119,7 @@ def __init__(self, auth: Any, pre_authorize: bool = False) -> None:
Parameters:
auth: Auth instance to download and access data.
"""
self.thread_locals = threading.local()
if auth.authenticated is True:
self.auth = auth
self._s3_credentials: Dict[
Expand All @@ -126,7 +128,7 @@ def __init__(self, auth: Any, pre_authorize: bool = False) -> None:
oauth_profile = f"https://{auth.system.edl_hostname}/profile"
# sets the initial URS cookie
self._requests_cookies: Dict[str, Any] = {}
self.set_requests_session(oauth_profile)
self.set_requests_session(oauth_profile, bearer_token=True)
if pre_authorize:
# collect cookies from other DAACs
for url in DAAC_TEST_URLS:
Expand Down Expand Up @@ -182,7 +184,7 @@ def _running_in_us_west_2(self) -> bool:
return False

def set_requests_session(
self, url: str, method: str = "get", bearer_token: bool = False
self, url: str, method: str = "get", bearer_token: bool = True
) -> None:
"""Sets up a `requests` session with bearer tokens that are used by CMR.

Expand Down Expand Up @@ -323,19 +325,19 @@ def get_fsspec_session(self) -> fsspec.AbstractFileSystem:
session = fsspec.filesystem("https", client_kwargs=client_kwargs)
return session

def get_requests_session(self, bearer_token: bool = True) -> requests.Session:
def get_requests_session(self) -> SessionWithHeaderRedirection:
"""Returns a requests HTTPS session with bearer tokens that are used by CMR.

This HTTPS session can be used to download granules if we want to use a direct,
lower level API.

Parameters:
bearer_token: if true, will be used for authenticated queries on CMR

Returns:
requests Session
"""
return self.auth.get_session()
if hasattr(self, "_http_session"):
return self._http_session
else:
raise AttributeError("The requests session hasn't been set up yet.")

def open(
self,
Expand Down Expand Up @@ -651,6 +653,27 @@ def _get_granules(
data_links, local_path, pqdm_kwargs=pqdm_kwargs
)

def _clone_session_in_local_thread(
self, original_session: SessionWithHeaderRedirection
) -> None:
"""Clone the original session and store it in the local thread context.

This method creates a new session that replicates the headers, cookies, and authentication settings
from the provided original session. The new session is stored in a thread-local storage.

Parameters:
original_session (SessionWithHeaderRedirection): The session to be cloned.

Returns:
None
"""
if not hasattr(self.thread_locals, "local_thread_session"):
local_thread_session = SessionWithHeaderRedirection()
local_thread_session.headers.update(original_session.headers)
local_thread_session.cookies.update(original_session.cookies)
local_thread_session.auth = original_session.auth
self.thread_locals.local_thread_session = local_thread_session

def _download_file(self, url: str, directory: Path) -> str:
"""Download a single file from an on-prem location, a DAAC data center.

Expand All @@ -668,7 +691,11 @@ def _download_file(self, url: str, directory: Path) -> str:
path = directory / Path(local_filename)
if not path.exists():
try:
session = self.auth.get_session()
original_session = self.get_requests_session()
# This reuses the auth cookie, we make sure we only authenticate N threads instead
# of one per file, see #913
self._clone_session_in_local_thread(original_session)
session = self.thread_locals.local_thread_session
with session.get(url, stream=True, allow_redirects=True) as r:
r.raise_for_status()
with open(path, "wb") as f:
Expand Down
82 changes: 81 additions & 1 deletion tests/unit/test_store.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,18 @@
# package imports
import os
import threading
import unittest
from pathlib import Path
from unittest.mock import MagicMock, patch

import fsspec
import pytest
import responses
import s3fs
from earthaccess import Auth, Store
from earthaccess.auth import SessionWithHeaderRedirection
from earthaccess.store import EarthAccessFile
from pqdm.threads import pqdm


class TestStoreSessions(unittest.TestCase):
Expand Down Expand Up @@ -128,14 +133,89 @@ def test_store_can_create_s3_fsspec_session(self):

return None

@responses.activate
def test_session_reuses_token_download(self):
mock_creds = {
"accessKeyId": "sure",
"secretAccessKey": "correct",
"sessionToken": "whynot",
}
test_cases = [
(2, 500), # 2 threads, 500 files
(4, 400), # 4 threads, 400 files
(8, 5000), # 8 threads, 5k files
]
for n_threads, n_files in test_cases:
with self.subTest(n_threads=n_threads, n_files=n_files):
urls = [f"https://example.com/file{i}" for i in range(1, n_files + 1)]
for i, url in enumerate(urls):
responses.add(
responses.GET, url, body=f"Content of file {i + 1}", status=200
)

mock_auth = MagicMock()
mock_auth.authenticated = True
mock_auth.system.edl_hostname = "urs.earthdata.nasa.gov"
responses.add(
responses.GET,
"https://urs.earthdata.nasa.gov/profile",
json=mock_creds,
status=200,
)

original_session = SessionWithHeaderRedirection()
original_session.cookies.set("sessionid", "mocked-session-cookie")
mock_auth.get_session.return_value = original_session

store = Store(auth=mock_auth)
store.thread_locals = threading.local() # Use real thread-local storage

# Track cloned sessions
cloned_sessions = set()

def mock_clone_session_in_local_thread(original_session):
"""Mock session cloning to track cloned sessions."""
if not hasattr(store.thread_locals, "local_thread_session"):
session = SessionWithHeaderRedirection()
session.cookies.update(original_session.cookies)
cloned_sessions.add(id(session))
store.thread_locals.local_thread_session = session

with patch.object(
store,
"_clone_session_in_local_thread",
side_effect=mock_clone_session_in_local_thread,
):
mock_directory = Path("/mock/directory")
downloaded_files = []

def mock_download_file(url):
"""Mock file download to track downloaded files."""
# Ensure session cloning happens before downloading
store._clone_session_in_local_thread(original_session)
downloaded_files.append(url)
return mock_directory / f"{url.split('/')[-1]}"

with patch.object(
store, "_download_file", side_effect=mock_download_file
):
# Test multi-threaded download
pqdm(urls, store._download_file, n_jobs=n_threads) # type: ignore

# We make sure we reuse the token up to N threads
self.assertTrue(len(cloned_sessions) <= n_threads)

self.assertEqual(len(downloaded_files), n_files) # 10 files downloaded
self.assertCountEqual(downloaded_files, urls) # All files accounted for


@pytest.mark.xfail(
reason="Expected failure: Reproduces a bug (#610) that has not yet been fixed."
)
def test_earthaccess_file_getattr():
fs = fsspec.filesystem("memory")
with fs.open("/foo", "wb") as f:
earthaccess_file = EarthAccessFile(f, granule="foo")
earthaccess_file = EarthAccessFile(f, granule="foo") # type: ignore
assert f.tell() == earthaccess_file.tell()
# cleanup
fs.store.clear()
Loading