Skip to content

Commit

Permalink
Stress test traci connections.
Browse files Browse the repository at this point in the history
  • Loading branch information
Gamenot committed Jan 20, 2024
1 parent 71dda54 commit 5c41b8b
Show file tree
Hide file tree
Showing 5 changed files with 269 additions and 47 deletions.
1 change: 1 addition & 0 deletions .github/workflows/ci-base-tests-linux.yml
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@ jobs:
--ignore=./smarts/core/tests/test_smarts_memory_growth.py \
--ignore=./smarts/core/tests/test_env_frame_rate.py \
--ignore=./smarts/env/tests/test_benchmark.py \
--ignore=./smarts/core/utils/tests/test_traci_port_acquisition.py \
-k 'not test_long_determinism'
examples-rl:
Expand Down
3 changes: 2 additions & 1 deletion .github/workflows/ci-base-tests-mac.yml
Original file line number Diff line number Diff line change
Expand Up @@ -69,4 +69,5 @@ jobs:
--ignore=./smarts/core/tests/test_renderers.py \
--ignore=./smarts/core/tests/test_smarts.py \
--ignore=./smarts/core/tests/test_env_frame_rate.py \
--ignore=./smarts/core/tests/test_observations.py
--ignore=./smarts/core/tests/test_observations.py \
--ignore=./smarts/core/utils/tests/test_traci_port_acquisition.py
11 changes: 4 additions & 7 deletions smarts/core/sumo_traffic_simulation.py
Original file line number Diff line number Diff line change
Expand Up @@ -209,7 +209,7 @@ def _initialize_traci_conn(self, num_retries=5):
)

try:
while self._traci_conn.viable and not self._traci_conn.connected:
while not self._traci_conn.connected:
try:
self._traci_conn.connect(
timeout=5,
Expand All @@ -224,15 +224,10 @@ def _initialize_traci_conn(self, num_retries=5):
except traci.exceptions.TraCIException:
# SUMO process died... unsure why this is not a fatal traci error
current_retries += 1

self._traci_conn.close_traci_and_pipes()
continue
except ConnectionRefusedError:
# Some other process somehow owns the port... sumo needs to be restarted.
continue
except OSError:
# TraCI or SUMO version are not at the minimum required version.
raise
except KeyboardInterrupt:
self._log.debug("Keyboard interrupted TraCI connection.")
self._traci_conn.close_traci_and_pipes()
Expand Down Expand Up @@ -378,7 +373,7 @@ def _handle_traci_exception(
self._handling_error = True
if isinstance(error, traci.exceptions.TraCIException):
# XXX: Needs further investigation whenever this happens.
self._log.warning("TraCI has provided a warning %s", error)
self._log.debug("TraCI has provided a warning %s", error)
return
if isinstance(error, traci.exceptions.FatalTraCIError):
self._log.error(
Expand Down Expand Up @@ -435,6 +430,8 @@ def teardown(self):
self._remove_vehicles()
except traci.exceptions.FatalTraCIError:
pass
if self._traci_conn is not None:
self._traci_conn.close_traci_and_pipes()

self._cumulative_sim_seconds = 0
self._non_sumo_vehicle_ids = set()
Expand Down
186 changes: 147 additions & 39 deletions smarts/core/utils/sumo.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
import os
import subprocess
import sys
from typing import Any, List, Optional, Tuple
from typing import Any, List, Literal, Optional, Tuple

from smarts.core.utils import networking
from smarts.core.utils.core_logging import suppress_output
Expand Down Expand Up @@ -60,16 +60,20 @@
class DomainWrapper:
"""Wraps `traci.Domain` type for the `TraciConn` utility"""

def __init__(self, sumo_proc, domain: traci.domain.Domain) -> None:
def __init__(self, traci_conn, domain: traci.domain.Domain, attribute_name) -> None:
self._domain = domain
self._sumo_proc = sumo_proc
self._traci_conn = traci_conn
self._attribute_name = attribute_name

def __getattr__(self, name: str) -> Any:
attribute = getattr(self._domain, name)

if inspect.isbuiltin(attribute) or inspect.ismethod(attribute):
attribute = functools.partial(
_wrap_traci_method, method=attribute, sumo_process=self._sumo_proc
_wrap_traci_method,
method=attribute,
traci_conn=self._traci_conn,
attribute_name=self._attribute_name,
)

return attribute
Expand All @@ -82,15 +86,22 @@ def __init__(
self,
sumo_port: Optional[int],
base_params: List[str],
sumo_binary: str = "sumo", # Literal["sumo", "sumo-gui"]
sumo_binary: Literal[
"sumo", "sumo-gui"
] = "sumo", # Literal["sumo", "sumo-gui"]
host: str = "localhost",
name: str = "",
):
self._sumo_proc = None
self._traci_conn = None
self._sumo_port = None
self._sumo_version: Tuple[int, ...] = tuple()
self._host = host
self._name = name
# self._log = logging
self._log = logging.Logger(self.__class__.__name__)
# self._log.setLevel(logging.ERROR)
self._connected = False

if sumo_port is None:
sumo_port = networking.find_free_port()
Expand Down Expand Up @@ -127,6 +138,12 @@ def connect(
"""Attempt a connection with the SUMO process."""
traci_conn = None
try:
# See if the process is still alive before attempting a connection.
if self._sumo_proc.poll() is not None:
raise traci.exceptions.TraCIException(
"TraCI server already finished before connection!!!"
)

with suppress_output(stderr=not debug, stdout=True):
traci_conn = traci.connect(
self._sumo_port,
Expand All @@ -136,47 +153,78 @@ def connect(
waitBetweenRetries=0.05,
) # SUMO must be ready within timeout seconds
# We will retry since this is our first sumo command
except traci.exceptions.FatalTraCIError:
self._log.debug("TraCI could not connect in time.")
except traci.exceptions.FatalTraCIError as err:
self._log.error(
"[%s] TraCI could not connect in time to '%s:%s' [%s]",
self._name,
self._host,
self._sumo_port,
err,
)
# XXX: Actually not fatal...
raise
except traci.exceptions.TraCIException:
self._log.warning("SUMO process died.")
except traci.exceptions.TraCIException as err:
self._log.error(
"[%s] SUMO process died while trying to connect to '%s:%s' [%s]",
self._name,
self._host,
self._sumo_port,
err,
)
self.close_traci_and_pipes(kill=True)
raise
except ConnectionRefusedError:
self._log.warning(
"Connection refused. Tried to connect to an unpaired TraCI client."
self._log.error(
"[%s] Intended TraCI server '%s:%s' refused connection.",
self._name,
self._host,
self._sumo_port,
)
self.close_traci_and_pipes(kill=True)
raise

self._connected = True
self._traci_conn = traci_conn
try:
vers, vers_str = traci_conn.getVersion()
if vers < minimum_traci_version:
raise OSError(
raise ValueError(
f"TraCI API version must be >= {minimum_traci_version}. Got version ({vers})"
)
self._sumo_version = tuple(
int(v) for v in vers_str.partition(" ")[2].split(".")
) # e.g. "SUMO 1.11.0" -> (1, 11, 0)
if self._sumo_version < minimum_sumo_version:
raise OSError(f"SUMO version must be >= SUMO {minimum_sumo_version}")
except (traci.exceptions.FatalTraCIError, TypeError) as err:
logging.error(
"TraCI disconnected from '%s:%s', process may have died.",
raise ValueError(f"SUMO version must be >= SUMO {minimum_sumo_version}")
except (traci.exceptions.FatalTraCIError) as err:
self._log.error(
"[%s] TraCI disconnected for connection attempt '%s:%s': [%s]",
self._name,
self._host,
self._sumo_port,
err,
)
# XXX: the error type is changed to TraCIException to make it consistent with the
# process died case of `traci.connect`.
# process died case of `traci.connect`. Since TraCIException is fatal just in this case...
self.close_traci_and_pipes(kill=True)
raise traci.exceptions.TraCIException(err)
except OSError as err:
self._log.error(
"[%s] OS error occurred for TraCI connection attempt '%s:%s': [%s]",
self._name,
self._host,
self._sumo_port,
err,
)
self.close_traci_and_pipes(kill=True)
raise traci.exceptions.TraCIException(err)
except OSError:
except ValueError:
self.close_traci_and_pipes()
raise
self._traci_conn = traci_conn

@property
def connected(self) -> bool:
"""Check if the connection is still valid."""
return self._sumo_proc is not None and self._traci_conn is not None
return self._sumo_proc is not None and self._connected

@property
def viable(self) -> bool:
Expand All @@ -188,32 +236,51 @@ def sumo_version(self) -> Tuple[int, ...]:
"""Get the current SUMO version as a tuple."""
return self._sumo_version

@property
def port(self) -> Optional[int]:
"""Get the used TraCI port."""
return self._sumo_port

@property
def hostname(self) -> Optional[int]:
"""Get the used TraCI port."""
return self._host

def __getattr__(self, name: str) -> Any:
if not self.connected:
return None
raise traci.exceptions.FatalTraCIError("TraCI died.")

attribute = getattr(self._traci_conn, name)

if inspect.isbuiltin(attribute) or inspect.ismethod(attribute):
attribute = functools.partial(
_wrap_traci_method, method=attribute, sumo_process=self
_wrap_traci_method,
method=attribute,
attribute_name=name,
traci_conn=self,
)

if isinstance(attribute, traci.domain.Domain):
attribute = DomainWrapper(sumo_proc=self, domain=attribute)
elif isinstance(attribute, traci.domain.Domain):
attribute = DomainWrapper(
traci_conn=self, domain=attribute, attribute_name=name
)
else:
raise NotImplementedError()

return attribute

def must_reset(self):
"""If the version of sumo will have errors if just reloading such that it must be reset."""
return self._sumo_version > (1, 12, 0)

def close_traci_and_pipes(self):
def close_traci_and_pipes(self, wait: Optional[float] = 0, kill: bool = False):
"""Safely closes all connections. We should expect this method to always work without throwing"""
assert wait is None or isinstance(wait, (int, float))
if isinstance(wait, (int, float)):
wait = max(0.0, wait)

def __safe_close(conn):
def __safe_close(conn, **kwargs):
try:
conn.close()
conn.close(**kwargs)
except (subprocess.SubprocessError, multiprocessing.ProcessError):
# Subprocess or process failed
pass
Expand All @@ -223,33 +290,74 @@ def __safe_close(conn):
except AttributeError:
# Socket was destroyed internally, likely due to an error.
pass
except Exception as err:
self._log.error("Different error occurred: [%s]", err)

if self._traci_conn:
if self._connected:
self._log.debug("Closing TraCI connection to %s", self._sumo_port)
__safe_close(self._traci_conn)
__safe_close(self._traci_conn, wait=bool(wait))

if self._sumo_proc:
__safe_close(self._sumo_proc.stdin)
__safe_close(self._sumo_proc.stdout)
__safe_close(self._sumo_proc.stderr)
self._sumo_proc.kill()
if wait:
try:
self._sumo_proc.wait(timeout=wait)
except subprocess.TimeoutExpired as err:
kill = True
self._log.error(
"TraCI server process shutdown timed out '%s:%s' [%s]",
self._host,
self._sumo_port,
err,
)
if kill:
self._sumo_proc.kill()
self._sumo_proc = None
self._log.error(
"Killed TraCI server process '%s:%s", self._host, self._sumo_port
)

self._sumo_proc = None
self._traci_conn = None
self._connected = False

def teardown(self):
"""Clean up all resources."""
self.close_traci_and_pipes()


def _wrap_traci_method(*args, method, sumo_process: TraciConn, **kwargs):
def _wrap_traci_method(
*args, method, traci_conn: TraciConn, attribute_name: str, **kwargs
):
# Argument order must be `*args` first so `method` and `sumo_process` are keyword only arguments.
try:
return method(*args, **kwargs)
except traci.exceptions.FatalTraCIError:
except traci.exceptions.FatalTraCIError as err:
logging.error(
"[%s] TraCI '%s:%s' disconnected for call '%s', process may have died: [%s]",
traci_conn._name,
traci_conn.hostname,
traci_conn.port,
attribute_name,
err,
)
# TraCI cannot continue
sumo_process.close_traci_and_pipes()
raise
except traci.exceptions.TraCIException:
traci_conn.close_traci_and_pipes(kill=True)
raise traci.exceptions.FatalTraCIError("TraCI died.") from err
except OSError as err:
logging.error(
"[%s] OS error occurred for TraCI '%s:%s' call '%s': [%s]",
traci_conn._name,
traci_conn.hostname,
traci_conn.port,
attribute_name,
err,
)
traci_conn.close_traci_and_pipes(kill=True)
raise OSError("Connection dropped.") from err
except traci.exceptions.TraCIException as err:
# Case where TraCI/SUMO can theoretically continue
raise traci.exceptions.TraCIException("TraCI can continue.") from err
except KeyboardInterrupt:
traci_conn.close_traci_and_pipes(kill=True)
raise
Loading

0 comments on commit 5c41b8b

Please sign in to comment.