diff --git a/picas/actors.py b/picas/actors.py index e79b9a1..a15d71a 100644 --- a/picas/actors.py +++ b/picas/actors.py @@ -5,6 +5,7 @@ @author: Jan Bot, Joris Borgdorff """ +import ssl import logging import signal import subprocess @@ -48,6 +49,10 @@ def __init__(self, db, iterator=None, view='todo', token_reset_values=[0, 0], ** else: self.iterator = iterator + def reconnect(self): + self.db = self.db.copy() + self.iterator.reconnect(self.db) + def _run(self, task, timeout): """ Execution of the work on the iterator used in the run method. @@ -77,8 +82,21 @@ def _run(self, task, timeout): log.info(msg) new_task = self.db.get(task.id) task['_rev'] = new_task.rev + except ssl.SSLEOFError as ex: + # SSLEOFError can occur for long-lived connections, re-establish connection + msg = f"Warning: {type(ex)} occurred while saving task to database: " + \ + "Trying ro reconnect to database" + log.info(msg) + self.reconnect() + try: + self.db.save(task) + except Exception as ex: + msg = f"Error: {type(ex)} occurred while saving task to database: " + \ + "Not able to reconnect to database" + log.info(msg) + raise except Exception as ex: - # re-raise Exception + # re-raise unknown exception, this will terminate the iterator msg = f"Error: {type(ex)} occurred while saving task to database: {ex}" log.info(msg) raise diff --git a/picas/iterators.py b/picas/iterators.py index b13eb5b..36b1f78 100644 --- a/picas/iterators.py +++ b/picas/iterators.py @@ -57,6 +57,10 @@ def claim_task(self): """Get the first available task from a view.""" raise NotImplementedError("claim_task function not implemented.") + def reconnect(self, database): + """Reconnect to database""" + self.database = database + def _claim_task(database, view, allowed_failures=10, **view_params): for _ in range(allowed_failures): diff --git a/tests/test_actors.py b/tests/test_actors.py index a61a8b0..be05431 100644 --- a/tests/test_actors.py +++ b/tests/test_actors.py @@ -3,6 +3,7 @@ import subprocess import time import unittest +import ssl from test_mock import MockDB, MockEmptyDB from unittest.mock import patch @@ -72,6 +73,16 @@ def test_run_resourceconflict(self, mock_save): runner._run(task=Task({'_id': 'c', 'lock': None, 'done': None}), timeout=None) self.assertEqual(runner.tasks_processed, 1) + @patch('test_mock.MockDB.save') + def test_run_ssleoferror(self, mock_save): + """ + Test the _run function, in case the DB throws a an SSLEOFError + """ + with pytest.raises(ssl.SSLEOFError): + mock_save.side_effect = ssl.SSLEOFError + runner = ExampleRun(self._callback) + runner._run(task=Task({'_id': 'c', 'lock': None, 'done': None}), timeout=None) + @patch('test_mock.MockDB.save') def test_run_exception(self, mock_save): """ diff --git a/tests/test_mock.py b/tests/test_mock.py index 265ead6..5a97d81 100644 --- a/tests/test_mock.py +++ b/tests/test_mock.py @@ -41,6 +41,9 @@ def save(self, doc): return doc + def copy(self): + return self + class MockEmptyDB(MockDB): TASKS = []