Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Move message handlers to info logging level #644

Open
wants to merge 4 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion aepsych/benchmark/pathos_benchmark.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@

ctx._force_start_method("spawn") # fixes problems with CUDA and fork

logger = utils_logging.getLogger(logging.INFO)
logger = utils_logging.getLogger()


class PathosBenchmark(Benchmark):
Expand Down
230 changes: 106 additions & 124 deletions aepsych/database/db.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
# LICENSE file in the root directory of this source tree.

import datetime
import io
import json
import logging
import os
Expand All @@ -19,7 +20,7 @@
from aepsych.config import Config
from aepsych.strategy import Strategy
from sqlalchemy import create_engine
from sqlalchemy.orm import sessionmaker
from sqlalchemy.orm import scoped_session, sessionmaker
from sqlalchemy.orm.session import close_all_sessions

logger = logging.getLogger()
Expand All @@ -45,33 +46,22 @@ def __init__(self, db_path: Optional[str] = None, update: bool = True) -> None:
else:
logger.info(f"No DB found at {db_path}, creating a new DB!")

self._engine = self.get_engine()
self._full_db_path = Path(self._db_dir)
self._full_db_path.mkdir(parents=True, exist_ok=True)
self._full_db_path = self._full_db_path.joinpath(self._db_name)

if update and self.is_update_required():
self.perform_updates()

def get_engine(self) -> sessionmaker:
"""Get the engine for the database.

Returns:
sessionmaker: The sessionmaker object for the database.
"""
if not hasattr(self, "_engine") or self._engine is None:
self._full_db_path = Path(self._db_dir)
self._full_db_path.mkdir(parents=True, exist_ok=True)
self._full_db_path = self._full_db_path.joinpath(self._db_name)

self._engine = create_engine(f"sqlite:///{self._full_db_path.as_posix()}")
self._engine = create_engine(f"sqlite:///{self._full_db_path.as_posix()}")

# create the table metadata and tables
tables.Base.metadata.create_all(self._engine)
# create the table metadata and tables
tables.Base.metadata.create_all(self._engine)

# create an ongoing session to be used. Provides a conduit
# to the db so the instantiated objects work properly.
Session = sessionmaker(bind=self.get_engine())
self._session = Session()
# Create a session to be started and closed on each use
self.session = scoped_session(
sessionmaker(bind=self._engine, expire_on_commit=False)
)

return self._engine
if update and self.is_update_required():
self.perform_updates()

def delete_db(self) -> None:
"""Delete the database."""
Expand Down Expand Up @@ -106,21 +96,6 @@ def perform_updates(self) -> None:
tables.DbParamTable.update(self._engine)
tables.DbOutcomeTable.update(self._engine)

@contextmanager
def session_scope(self):
"""Provide a transactional scope around a series of operations."""
Session = sessionmaker(bind=self.get_engine())
session = Session()
try:
yield session
session.commit()
except Exception as err:
logger.error(f"db session use failed: {err}")
session.rollback()
raise
finally:
session.close()

# @retry(stop_max_attempt_number=8, wait_exponential_multiplier=1.8)
def execute_sql_query(self, query: str, vals: Dict[str, str]) -> List[Any]:
"""Execute an arbitrary query written in sql.
Expand All @@ -132,7 +107,7 @@ def execute_sql_query(self, query: str, vals: Dict[str, str]) -> List[Any]:
Returns:
List[Any]: The results of the query.
"""
with self.session_scope() as session:
with self.session() as session:
return session.execute(query, vals).all()

def get_master_records(self) -> List[tables.DBMasterTable]:
Expand All @@ -141,7 +116,8 @@ def get_master_records(self) -> List[tables.DBMasterTable]:
Returns:
List[tables.DBMasterTable]: The list of master records.
"""
records = self._session.query(tables.DBMasterTable).all()
with self.session() as session:
records = session.query(tables.DBMasterTable).all()
return records

def get_master_record(self, master_id: int) -> Optional[tables.DBMasterTable]:
Expand All @@ -153,11 +129,12 @@ def get_master_record(self, master_id: int) -> Optional[tables.DBMasterTable]:
Returns:
tables.DBMasterTable or None: The master record or None if it doesn't exist.
"""
records = (
self._session.query(tables.DBMasterTable)
.filter(tables.DBMasterTable.unique_id == master_id)
.all()
)
with self.session() as session:
records = (
session.query(tables.DBMasterTable)
.filter(tables.DBMasterTable.unique_id == master_id)
.all()
)

if 0 < len(records):
return records[0]
Expand Down Expand Up @@ -259,11 +236,7 @@ def get_params_for(self, master_id: int) -> List[List[tables.DbParamTable]]:
raw_record = self.get_raw_for(master_id)

if raw_record is not None:
return [
rec.children_param
for rec in self.get_raw_for(master_id)
if rec is not None
]
return [raw.children_param for raw in raw_record]

return []

Expand All @@ -282,14 +255,19 @@ def get_outcomes_for(self, master_id: int) -> List[List[tables.DbParamTable]]:
raw_record = self.get_raw_for(master_id)

if raw_record is not None:
return [
rec.children_outcome
for rec in self.get_raw_for(master_id)
if rec is not None
]
return [raw.children_outcome for raw in raw_record]

return []

@staticmethod
def _add_commit(session, obj):
# Helps guarantee duplicated objects across sessions can still be written
merged = session.merge(obj)
session.add(merged)
session.commit()
session.refresh(merged)
return merged

def record_setup(
self,
description: str = None,
Expand All @@ -312,34 +290,36 @@ def record_setup(
Returns:
str: The experiment id.
"""
self.get_engine()

master_table = tables.DBMasterTable()
master_table.experiment_description = description
master_table.experiment_name = name
master_table.experiment_id = exp_id if exp_id is not None else str(uuid.uuid4())
master_table.participant_id = (
par_id if par_id is not None else str(uuid.uuid4())
)
master_table.extra_metadata = extra_metadata
self._session.add(master_table)
with self.session() as session:
master_table = tables.DBMasterTable()
master_table.experiment_description = description
master_table.experiment_name = name
master_table.experiment_id = (
exp_id if exp_id is not None else str(uuid.uuid4())
)
master_table.participant_id = (
par_id if par_id is not None else str(uuid.uuid4())
)
master_table.extra_metadata = extra_metadata

master_table = self._add_commit(session, master_table)

logger.debug(f"record_setup = [{master_table}]")
logger.debug(f"record_setup = [{master_table}]")

record = tables.DbReplayTable()
record.message_type = "setup"
record.message_contents = request
record = tables.DbReplayTable()
record.message_type = "setup"
record.message_contents = request

if request is not None and "extra_info" in request:
record.extra_info = request["extra_info"]
if request is not None and "extra_info" in request:
record.extra_info = request["extra_info"]

record.timestamp = datetime.datetime.now()
record.parent = master_table
logger.debug(f"record_setup = [{record}]")
record.timestamp = datetime.datetime.now()
record.parent = master_table
logger.debug(f"replay_record_setup = [{record}]")

self._session.add(record)
self._session.commit()
self._add_commit(session, record)

master_table
# return the master table if it has a link to the list of child rows
# this needs to be passed into all future calls to link properly
return master_table
Expand All @@ -354,19 +334,19 @@ def record_message(
type (str): The type of the message.
request (Dict[str, Any]): The request.
"""
# create a linked setup table
record = tables.DbReplayTable()
record.message_type = type
record.message_contents = request
with self.session() as session:
# create a linked setup table
record = tables.DbReplayTable()
record.message_type = type
record.message_contents = request

if "extra_info" in request:
record.extra_info = request["extra_info"]
if "extra_info" in request:
record.extra_info = request["extra_info"]

record.timestamp = datetime.datetime.now()
record.parent = master_table
record.timestamp = datetime.datetime.now()
record.parent = master_table

self._session.add(record)
self._session.commit()
self._add_commit(session, record)

def record_raw(
self,
Expand All @@ -386,19 +366,19 @@ def record_raw(
Returns:
tables.DbRawTable: The raw entry.
"""
raw_entry = tables.DbRawTable()
raw_entry.model_data = model_data
with self.session() as session:
raw_entry = tables.DbRawTable()
raw_entry.model_data = model_data

if timestamp is None:
raw_entry.timestamp = datetime.datetime.now()
else:
raw_entry.timestamp = timestamp
raw_entry.parent = master_table
if timestamp is None:
raw_entry.timestamp = datetime.datetime.now()
else:
raw_entry.timestamp = timestamp
raw_entry.parent = master_table

raw_entry.extra_data = json.dumps(extra_data)
raw_entry.extra_data = json.dumps(extra_data)

self._session.add(raw_entry)
self._session.commit()
raw_entry = self._add_commit(session, raw_entry)

return raw_entry

Expand All @@ -412,14 +392,14 @@ def record_param(
param_name (str): The parameter name.
param_value (str): The parameter value.
"""
param_entry = tables.DbParamTable()
param_entry.param_name = param_name
param_entry.param_value = param_value
with self.session() as session:
param_entry = tables.DbParamTable()
param_entry.param_name = param_name
param_entry.param_value = param_value

param_entry.parent = raw_table
param_entry.parent = raw_table

self._session.add(param_entry)
self._session.commit()
self._add_commit(session, param_entry)

def record_outcome(
self, raw_table: tables.DbRawTable, outcome_name: str, outcome_value: float
Expand All @@ -431,29 +411,31 @@ def record_outcome(
outcome_name (str): The outcome name.
outcome_value (float): The outcome value.
"""
outcome_entry = tables.DbOutcomeTable()
outcome_entry.outcome_name = outcome_name
outcome_entry.outcome_value = outcome_value
with self.session() as session:
outcome_entry = tables.DbOutcomeTable()
outcome_entry.outcome_name = outcome_name
outcome_entry.outcome_value = outcome_value

outcome_entry.parent = raw_table
outcome_entry.parent = raw_table

self._session.add(outcome_entry)
self._session.commit()
self._add_commit(session, outcome_entry)

def record_strat(self, master_table: tables.DBMasterTable, strat: Strategy) -> None:
def record_strat(
self, master_table: tables.DBMasterTable, strat: io.BytesIO
) -> None:
"""Record a strategy in the database.

Args:
master_table (tables.DBMasterTable): The master table.
strat (Strategy): The strategy.
strat (BytesIO): The strategy in buffer form.
"""
strat_entry = tables.DbStratTable()
strat_entry.strat = strat
strat_entry.timestamp = datetime.datetime.now()
strat_entry.parent = master_table
with self.session() as session:
strat_entry = tables.DbStratTable()
strat_entry.strat = strat
strat_entry.timestamp = datetime.datetime.now()
strat_entry.parent = master_table

self._session.add(strat_entry)
self._session.commit()
self._add_commit(session, strat_entry)

def record_config(self, master_table: tables.DBMasterTable, config: Config) -> None:
"""Record a config in the database.
Expand All @@ -462,13 +444,13 @@ def record_config(self, master_table: tables.DBMasterTable, config: Config) -> N
master_table (tables.DBMasterTable): The master table.
config (Config): The config.
"""
config_entry = tables.DbConfigTable()
config_entry.config = config
config_entry.timestamp = datetime.datetime.now()
config_entry.parent = master_table
with self.session() as session:
config_entry = tables.DbConfigTable()
config_entry.config = config
config_entry.timestamp = datetime.datetime.now()
config_entry.parent = master_table

self._session.add(config_entry)
self._session.commit()
self._add_commit(session, config_entry)

def summarize_experiments(self) -> pd.DataFrame:
"""Provides a summary of the experiments contained in the database as a pandas dataframe.
Expand Down
Loading