Unify AEPsych Model Mixin (#627)
Summary:
Pull Request resolved: #627

Everything now uses the same model mixin, AEPsychModelMixin, in place of the old AEPsychMixin.

Reviewed By: crasanders

Differential Revision: D69194015

fbshipit-source-id: b5a8a1260f4e785760fd605deb853a64864434ff
JasonKChow authored and facebook-github-bot committed Feb 6, 2025
1 parent 0693d50 commit de4af54
Showing 11 changed files with 164 additions and 187 deletions.
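
In practice the change is a rename at every reference site: imports and annotations that pointed at AEPsychMixin in aepsych.models.base now point at AEPsychModelMixin. The snippet below is a minimal sketch of the resulting annotation style, assuming a hypothetical describe_model helper that is not part of the repository; only the import path comes from the diffs.

# Illustrative sketch only; describe_model is not part of AEPsych.
from typing import Optional

from aepsych.models.base import AEPsychModelMixin  # unified mixin; previously code imported AEPsychMixin


def describe_model(model: Optional[AEPsychModelMixin] = None) -> str:
    """Toy helper showing the post-commit type hint for code that accepts an AEPsych model."""
    if model is None:
        return "no model"
    return type(model).__name__
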
4 changes: 2 additions & 2 deletions aepsych/generators/base.py
@@ -11,7 +11,7 @@

import torch
from aepsych.config import Config, ConfigurableMixin
-from aepsych.models.base import AEPsychMixin
+from aepsych.models.base import AEPsychModelMixin
from botorch.acquisition import (
    AcquisitionFunction,
    LogNoisyExpectedImprovement,
@@ -23,7 +23,7 @@

from ..models.model_protocol import ModelProtocol

-AEPsychModelType = TypeVar("AEPsychModelType", bound=AEPsychMixin)
+AEPsychModelType = TypeVar("AEPsychModelType", bound=AEPsychModelMixin)


@runtime_checkable
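
The bound on AEPsychModelType above is what keeps generator classes generic over a concrete model type while still requiring that type to derive from the unified mixin. A minimal sketch of that pattern follows; GenericGenerator is a hypothetical stand-in (the repository's AEPsychGenerator presumably plays this role, but its full definition is not shown in this diff), and only the TypeVar line mirrors the change above.

# Illustrative sketch of how a TypeVar bound to AEPsychModelMixin can be used.
from typing import Generic, Optional, TypeVar

import torch
from aepsych.models.base import AEPsychModelMixin

# Mirrors the post-commit definition in aepsych/generators/base.py:
AEPsychModelType = TypeVar("AEPsychModelType", bound=AEPsychModelMixin)


class GenericGenerator(Generic[AEPsychModelType]):
    """Hypothetical base class: subclasses pin the concrete model type they expect,
    and type checkers enforce that it subclasses AEPsychModelMixin."""

    def gen(
        self, num_points: int = 1, model: Optional[AEPsychModelType] = None
    ) -> torch.Tensor:
        raise NotImplementedError
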
6 changes: 3 additions & 3 deletions aepsych/generators/manual_generator.py
@@ -11,7 +11,7 @@
import torch
from aepsych.config import Config
from aepsych.generators.base import AEPsychGenerator
-from aepsych.models.base import AEPsychMixin
+from aepsych.models.base import AEPsychModelMixin
from aepsych.utils import _process_bounds
from torch.quasirandom import SobolEngine

@@ -53,14 +53,14 @@ def __init__(
    def gen(
        self,
        num_points: int = 1,
-        model: Optional[AEPsychMixin] = None,  # included for API compatibility
+        model: Optional[AEPsychModelMixin] = None,  # included for API compatibility
        fixed_features: Optional[Dict[int, float]] = None,
        **kwargs,  # Ignored
    ) -> torch.Tensor:
        """Query next point(s) to run by quasi-randomly sampling the parameter space.
        Args:
            num_points (int): Number of points to query. Defaults to 1.
-            model (AEPsychMixin, optional): Model to use for generating points. Not used in this generator. Defaults to None.
+            model (AEPsychModelMixin, optional): Model to use for generating points. Not used in this generator. Defaults to None.
            fixed_features (Dict[int, float], optional): Ignored, kept for consistent
                API.
            **kwargs: Ignored, API compatibility
6 changes: 3 additions & 3 deletions aepsych/generators/random_generator.py
@@ -10,7 +10,7 @@
import torch
from aepsych.config import Config
from aepsych.generators.base import AEPsychGenerator
-from aepsych.models.base import AEPsychMixin
+from aepsych.models.base import AEPsychModelMixin
from aepsych.utils import _process_bounds


@@ -38,14 +38,14 @@ def __init__(
    def gen(
        self,
        num_points: int = 1,
-        model: Optional[AEPsychMixin] = None,  # included for API compatibility.
+        model: Optional[AEPsychModelMixin] = None,  # included for API compatibility.
        fixed_features: Optional[Dict[int, float]] = None,
        **kwargs,
    ) -> torch.Tensor:
        """Query next point(s) to run by randomly sampling the parameter space.
        Args:
            num_points (int): Number of points to query. Currently, only 1 point can be queried at a time.
-            model (AEPsychMixin, optional): Model to use for generating points. Not used in this generator.
+            model (AEPsychModelMixin, optional): Model to use for generating points. Not used in this generator.
            fixed_features: (Dict[int, float], optional): Parameters that are fixed to specific values.
            **kwargs: Ignored, API compatibility
6 changes: 3 additions & 3 deletions aepsych/generators/sobol_generator.py
@@ -11,7 +11,7 @@
import torch
from aepsych.config import Config
from aepsych.generators.base import AEPsychGenerator
-from aepsych.models.base import AEPsychMixin
+from aepsych.models.base import AEPsychModelMixin
from aepsych.utils import _process_bounds
from torch.quasirandom import SobolEngine

@@ -49,14 +49,14 @@ def __init__(
    def gen(
        self,
        num_points: int = 1,
-        model: Optional[AEPsychMixin] = None,  # included for API compatibility
+        model: Optional[AEPsychModelMixin] = None,  # included for API compatibility
        fixed_features: Optional[Dict[int, float]] = None,
        **kwargs,
    ) -> torch.Tensor:
        """Query next point(s) to run by quasi-randomly sampling the parameter space.
        Args:
            num_points (int): Number of points to query. Defaults to 1.
-            moodel (AEPsychMixin, optional): Model to use for generating points. Not used in this generator. Defaults to None.
+            moodel (AEPsychModelMixin, optional): Model to use for generating points. Not used in this generator. Defaults to None.
            fixed_features: (Dict[int, float], optional): Parameters that are fixed to specific values.
            **kwargs: Ignored, API compatibility
        Returns:
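
To tie the generator diffs together, the sketch below mirrors the updated gen signatures and shows that the model argument is optional and ignored by these quasi-random generators. DemoSobolLikeGenerator is a hypothetical stand-in, not the repository's SobolGenerator; only the gen signature follows the diffs above.

# Hypothetical generator mirroring the updated signatures; not AEPsych's SobolGenerator.
from typing import Dict, Optional

import torch
from aepsych.models.base import AEPsychModelMixin
from torch.quasirandom import SobolEngine


class DemoSobolLikeGenerator:
    def __init__(self, lb: torch.Tensor, ub: torch.Tensor, seed: Optional[int] = None) -> None:
        self.lb, self.ub = lb, ub
        self.engine = SobolEngine(dimension=lb.shape[0], scramble=True, seed=seed)

    def gen(
        self,
        num_points: int = 1,
        model: Optional[AEPsychModelMixin] = None,  # kept for API compatibility, unused
        fixed_features: Optional[Dict[int, float]] = None,
        **kwargs,
    ) -> torch.Tensor:
        # Quasi-random points scaled into [lb, ub]; the model argument is ignored,
        # matching the manual, random, and Sobol generators changed by this commit.
        points = self.lb + self.engine.draw(num_points) * (self.ub - self.lb)
        if fixed_features:
            for idx, value in fixed_features.items():
                points[:, idx] = value
        return points


# Usage sketch: the model argument can simply be omitted or passed as None.
gen = DemoSobolLikeGenerator(lb=torch.zeros(2), ub=torch.ones(2), seed=0)
print(gen.gen(num_points=3, model=None))
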
(Diffs for the remaining 7 changed files are not shown here.)
