Skip to content

Commit

Permalink
Resolve all errors except for those in gpu.py
Browse files Browse the repository at this point in the history
  • Loading branch information
jaywonchung committed May 1, 2024
1 parent c20a6d3 commit b9bf4f2
Show file tree
Hide file tree
Showing 24 changed files with 134 additions and 92 deletions.
8 changes: 4 additions & 4 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -39,13 +39,13 @@ Documentation = "https://ml.energy/zeus"
# One day FastAPI will drop support for Pydantic V1. Then fastapi has to be pinned as well.
perseus = ["fastapi[all]", "pydantic<2", "lowtime", "aiofiles", "httpx"]
bso = ["pydantic<2", "httpx"]
bso-server = ["fastapi[all]","sqlalchemy","pydantic<2"]
bso-server = ["fastapi[all]", "sqlalchemy", "pydantic<2", "python-dotenv"]
migration = ["alembic", "sqlalchemy", "pydantic<2", "python-dotenv"]
lint = ["ruff", "black==22.6.0", "pandas-stubs"]
test = ["fastapi[all]","sqlalchemy","pydantic<2", "httpx", "pytest==7.3.2", "pytest-mock==3.10.0", "pytest-xdist==3.3.1", "anyio==3.7.1", "aiosqlite==0.20.0"]
lint = ["ruff", "black==22.6.0", "pyright", "pandas-stubs", "zeus-ml[perseus,bso,bso-server]"]
test = ["fastapi[all]", "sqlalchemy", "pydantic<2", "httpx", "pytest==7.3.2", "pytest-mock==3.10.0", "pytest-xdist==3.3.1", "anyio==3.7.1", "aiosqlite==0.20.0"]
docs = ["mkdocs-material[imaging]==9.5.19", "mkdocstrings[python]==0.25.0", "mkdocs-gen-files==0.5.0", "mkdocs-literate-nav==0.6.1", "mkdocs-section-index==0.3.9", "urllib3<2", "black"]
# greenlet is listed explicitly so that SQLAlchemy works on Apple Silicon Macs (https://docs.sqlalchemy.org/en/20/faq/installation.html)
dev = ["zeus-ml[lint,test,docs,perseus,bso-server]", "greenlet"]
dev = ["zeus-ml[lint,test,docs,perseus,bso,bso-server,migration]", "greenlet"]

[tool.setuptools.packages.find]
where = ["."]
Expand Down
4 changes: 2 additions & 2 deletions zeus/_legacy/simulate.py
Original file line number Diff line number Diff line change
Expand Up @@ -759,12 +759,12 @@ def compute_energy_and_time(
# Also there are all runs 1, 2, ... included, but power info is actually
# completely duplicated across different runs in the DataFrame.
# Thus, taking the mean across the entire power_df gets us what we want.
energy_first_epoch = power_df.energy_per_epoch.mean().item()
energy_first_epoch = power_df.energy_per_epoch.mean().item() # type: ignore
energy_from_second_epoch = path.energy_per_epoch.item() * (
num_epochs - 1
)
energy_consumption = energy_first_epoch + energy_from_second_epoch
time_first_epoch = power_df.time_per_epoch.mean().item()
time_first_epoch = power_df.time_per_epoch.mean().item() # type: ignore
time_from_second_epoch = path.time_per_epoch.item() * (num_epochs - 1)
time_consumption = time_first_epoch + time_from_second_epoch
# Just run num_epochs with the given power limit. Simple.
Expand Down
4 changes: 2 additions & 2 deletions zeus/optimizer/batch_size/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ def __init__(
raise ZeusBSOConfigError("No GPUs detected.")

# Set GPU configurations (max_power, number of GPUs, and GPU model)
self.job = JobSpecFromClient(
self.job_spec = JobSpecFromClient(
**job.dict(),
max_power=gpus.getPowerManagementLimitConstraints(0)[1]
// 1000
Expand All @@ -85,7 +85,7 @@ def __init__(
self.current_batch_size = 0

# Register job
res = httpx.post(self.server_url + REGISTER_JOB_URL, content=self.job.json())
res = httpx.post(self.server_url + REGISTER_JOB_URL, content=self.job_spec.json())
self._handle_response(res)

self.job = CreatedJob.parse_obj(res.json())
Expand Down
14 changes: 7 additions & 7 deletions zeus/optimizer/batch_size/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

from __future__ import annotations

from typing import Any, Dict, Optional
from typing import Any, Optional

from zeus.utils.pydantic_v1 import BaseModel, root_validator, validator, Field

Expand Down Expand Up @@ -53,21 +53,21 @@ class JobParams(BaseModel):
mab_seed: Optional[int] = None

@validator("batch_sizes")
def _validate_batch_sizes(cls, bs: list[int]) -> int:
def _validate_batch_sizes(cls, bs: list[int]) -> list[int]:
if bs is not None and len(bs) > 0:
bs.sort()
return bs
else:
raise ValueError(f"Batch Sizes = {bs} is empty")

@validator("eta_knob")
def _validate_eta_knob(cls, v: float) -> int:
def _validate_eta_knob(cls, v: float) -> float:
if v < 0 or v > 1:
raise ValueError("eta_knob should be in range [0,1]")
return v

@validator("beta_knob")
def _validate_beta_knob(cls, v: float) -> int:
def _validate_beta_knob(cls, v: float) -> float:
if v is None or v > 0:
return v
else:
Expand All @@ -76,7 +76,7 @@ def _validate_beta_knob(cls, v: float) -> int:
)

@root_validator(skip_on_failure=True)
def _check_default_batch_size(cls, values: Dict[str, Any]) -> Dict[str, Any]:
def _check_default_batch_size(cls, values: dict[str, Any]) -> dict[str, Any]:
bs = values["default_batch_size"]
bss = values["batch_sizes"]
if bs not in bss:
Expand Down Expand Up @@ -108,10 +108,10 @@ class JobSpec(JobParams):
Refer [`JobParams`][`zeus.optimizer.batch_size.common.JobParams`] for other attributes.
"""

job_id: Optional[str]
job_id: Optional[str] # pyright: ignore[reportIncompatibleVariableOverride]

@root_validator(skip_on_failure=True)
def _check_job_id(cls, values: Dict[str, Any]) -> Dict[str, Any]:
def _check_job_id(cls, values: dict[str, Any]) -> dict[str, Any]:
job_id: str | None = values.get("job_id")
prefix: str = values["job_id_prefix"]

Expand Down
2 changes: 1 addition & 1 deletion zeus/optimizer/batch_size/migrations/env.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ async def run_migrations_online() -> None:
"""
connectable = AsyncEngine(
engine_from_config(
config.get_section(config.config_ini_section),
config.get_section(config.config_ini_section), # type: ignore
prefix="sqlalchemy.",
poolclass=pool.NullPool,
future=True,
Expand Down
28 changes: 14 additions & 14 deletions zeus/optimizer/batch_size/server/batch_size_state/commands.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,10 +32,10 @@ class CreateTrialBase(BatchSizeBase):
"""Base command to create trial."""

type: TrialType
start_timestamp: datetime = Field(datetime.now(), const=True)
status: TrialStatus = Field(TrialStatus.Dispatched, const=True)
start_timestamp: datetime = Field(default_factory=datetime.now)
status: TrialStatus = Field(default=TrialStatus.Dispatched, const=True)

class Config:
class Config: # type: ignore
"""Model configuration.
Make it immutable after it's created.
Expand All @@ -52,7 +52,7 @@ class CreateTrial(CreateTrialBase):

trial_number: int = Field(gt=0)

class Config:
class Config: # type: ignore
"""Model configuration.
Make it immutable after it's created.
Expand All @@ -76,9 +76,9 @@ def to_orm(self) -> TrialTable:
class CreateExplorationTrial(CreateTrialBase):
"""Create an exploration trial."""

type: TrialType = Field(TrialType.Exploration, const=True)
type: TrialType = Field(default=TrialType.Exploration, const=True)

class Config:
class Config: # type: ignore
"""Model configuration.
Make it immutable after it's created.
Expand All @@ -90,9 +90,9 @@ class Config:
class CreateMabTrial(CreateTrialBase):
"""Create a MAB trial."""

type: TrialType = Field(TrialType.MAB, const=True)
type: TrialType = Field(default=TrialType.MAB, const=True)

class Config:
class Config: # type: ignore
"""Model configuration.
Make it immutable after it's created.
Expand All @@ -104,9 +104,9 @@ class Config:
class CreateConcurrentTrial(CreateTrialBase):
"""Create a concurrent trial."""

type: TrialType = Field(TrialType.Concurrent, const=True)
type: TrialType = Field(default=TrialType.Concurrent, const=True)

class Config:
class Config: # type: ignore
"""Model configuration.
Make it immutable after it's created.
Expand All @@ -119,13 +119,13 @@ class UpdateTrial(BatchSizeBase):
"""Report the result of trial."""

trial_number: int = Field(gt=0)
end_timestamp: datetime = Field(datetime.now(), const=True)
end_timestamp: datetime = Field(default_factory=datetime.now, const=True)
status: TrialStatus
time: Optional[float] = Field(None, ge=0)
energy: Optional[float] = Field(None, ge=0)
time: Optional[float] = Field(default=None, ge=0)
energy: Optional[float] = Field(default=None, ge=0)
converged: Optional[bool] = None

class Config:
class Config: # type: ignore
"""Model configuration.
Make it immutable after it's created.
Expand Down
14 changes: 7 additions & 7 deletions zeus/optimizer/batch_size/server/batch_size_state/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ class BatchSizeBase(BaseModel):
job_id: str
batch_size: int = Field(gt=0)

class Config:
class Config: # type: ignore
"""Model configuration.
Make it immutable after it's created.
Expand Down Expand Up @@ -58,7 +58,7 @@ class Trial(BatchSizeBase):
energy: Optional[float] = Field(None, ge=0)
converged: Optional[bool] = None

class Config:
class Config: # type: ignore
"""Model configuration.
Enable instantiating model from an ORM object, and make it immutable after it's created.
Expand Down Expand Up @@ -117,7 +117,7 @@ class GaussianTsArmState(BatchSizeBase):
reward_precision: float
num_observations: int = Field(ge=0)

class Config:
class Config: # type: ignore
"""Model configuration.
Enable instantiating model from an ORM object, and make it immutable after it's created.
Expand Down Expand Up @@ -154,7 +154,7 @@ class TrialResult(BatchSizeBase):
energy: float = Field(ge=0)
converged: bool

class Config:
class Config: # type: ignore
"""Model configuration.
Enable instantiating model from an ORM object, and make it immutable after it's created.
Expand Down Expand Up @@ -217,7 +217,7 @@ class ExplorationsPerJob(BaseModel):
job_id: str
explorations_per_bs: dict[int, list[Trial]] # BS -> Trials with exploration type

class Config:
class Config: # type: ignore
"""Model configuration.
Make it immutable after it's created.
Expand All @@ -237,11 +237,11 @@ def _check_explorations(cls, values: dict[str, Any]) -> dict[str, Any]:
for exp in exps:
if job_id != exp.job_id:
raise ValueError(
f"job_id doesn't correspond with explorations: {job_id} != {exps.job_id}"
f"job_id doesn't correspond with explorations: {job_id} != {exp.job_id}"
)
if bs != exp.batch_size:
raise ValueError(
f"Batch size doesn't correspond with explorations: {bs} != {exps.batch_size}"
f"Batch size doesn't correspond with explorations: {bs} != {exp.batch_size}"
)
if exp.type != TrialType.Exploration:
raise ValueError("Trial type is not equal to Exploration.")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,8 +38,8 @@ class BatchSizeStateRepository(DatabaseRepository):
def __init__(self, session: AsyncSession):
"""Set db session and initialize fetched trial. We are only updating one trial per session."""
super().__init__(session)
self.fetched_trial: Trial | None = None
self.fetched_arm: GaussianTsArmState | None = None
self.fetched_trial: TrialTable | None = None
self.fetched_arm: GaussianTsArmStateTable | None = None

async def get_next_trial_number(self, job_id: str) -> int:
"""Get the next trial number of a given job. Trial numbers start from 1 and increase by 1 at a time."""
Expand Down Expand Up @@ -152,12 +152,13 @@ async def get_trial(self, trial: ReadTrial) -> Trial | None:
def get_trial_from_session(self, trial: ReadTrial) -> Trial | None:
"""Fetch a trial from the session."""
if (
self.fetched_trial.job_id != trial.job_id
self.fetched_trial is None
or self.fetched_trial.job_id != trial.job_id
or self.fetched_trial.batch_size != trial.batch_size
or self.fetched_trial.trial_number != trial.trial_number
):
return None
return self.fetched_trial
return Trial.from_orm(self.fetched_trial)

def create_trial(self, trial: CreateTrial) -> None:
"""Create a trial in db.
Expand Down
6 changes: 3 additions & 3 deletions zeus/optimizer/batch_size/server/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ class ZeusBsoSettings(BaseSettings):
echo_sql: Union[bool, str] = False # To prevent conversion error for empty string
log_level: str = "INFO"

class Config:
class Config: # type: ignore
"""Model configuration.
Set how to find the env variables and how to parse it.
Expand All @@ -42,7 +42,7 @@ def _validate_echo_sql(cls, v) -> bool:
return False

@validator("log_level")
def _validate_log_level(cls, v) -> bool:
def _validate_log_level(cls, v) -> str:
if v is None or v not in {
"NOTSET",
"DEBUG",
Expand All @@ -56,4 +56,4 @@ def _validate_log_level(cls, v) -> bool:
return v


settings = ZeusBsoSettings()
settings = ZeusBsoSettings() # type: ignore
4 changes: 3 additions & 1 deletion zeus/optimizer/batch_size/server/database/db_connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
and https://medium.com/@tclaitken/setting-up-a-fastapi-app-with-async-sqlalchemy-2-0-pydantic-v2-e6c540be4308
"""

from __future__ import annotations

import contextlib
from typing import Any, AsyncIterator

Expand All @@ -20,7 +22,7 @@
class DatabaseSessionManager:
"""Session manager class."""

def __init__(self, host: str, engine_kwargs: dict[str, Any] = None):
def __init__(self, host: str, engine_kwargs: dict[str, Any] | None = None):
"""Create async engine and session maker."""
if engine_kwargs is None:
engine_kwargs = {}
Expand Down
6 changes: 5 additions & 1 deletion zeus/optimizer/batch_size/server/explorer.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
)
from zeus.optimizer.batch_size.server.batch_size_state.models import ExplorationsPerJob
from zeus.optimizer.batch_size.server.database.schema import TrialStatus
from zeus.optimizer.batch_size.server.exceptions import ZeusBSOServerRuntimeError
from zeus.optimizer.batch_size.server.exceptions import ZeusBSOServerRuntimeError, ZeusBSOValueError
from zeus.optimizer.batch_size.server.job.models import JobState
from zeus.optimizer.batch_size.server.services.service import ZeusService
from zeus.utils.logging import get_logger
Expand Down Expand Up @@ -91,6 +91,10 @@ async def next_batch_size(
converged_bs_list.append(bs)

m = exploration_history.explorations_per_bs[bs][round]
if m.energy is None or m.time is None:
raise ZeusBSOValueError(
"Energy or time is not available for the exploration."
)
cost = zeus_cost(
m.energy, m.time, job.eta_knob, job.max_power
)
Expand Down
5 changes: 3 additions & 2 deletions zeus/optimizer/batch_size/server/job/commands.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,7 +134,8 @@ def _validate_states(cls, values: dict[str, Any]) -> dict[str, Any]:

return values

def from_job_config(js: JobSpecFromClient) -> "CreateJob":
@classmethod
def from_job_config(cls, js: JobSpecFromClient) -> "CreateJob":
"""From JobConfig, instantiate `CreateJob`.
Initialize generator state, exp_default_batch_size, and min_cost_batch_size.
Expand All @@ -145,7 +146,7 @@ def from_job_config(js: JobSpecFromClient) -> "CreateJob":
rng = np.random.default_rng(js.mab_seed)
d["mab_random_generator_state"] = json.dumps(rng.__getstate__())
d["min_cost_batch_size"] = js.default_batch_size
return CreateJob.parse_obj(d)
return cls.parse_obj(d)

def to_orm(self) -> JobTable:
"""Convert pydantic model `CreateJob` to ORM object Job."""
Expand Down
2 changes: 1 addition & 1 deletion zeus/optimizer/batch_size/server/job/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ class Stage(Enum):
class JobGetter(GetterDict):
"""Getter for batch size to convert ORM batch size object to integer."""

def get(self, key: str, default: Any) -> Any:
def get(self, key: str, default: Any = None) -> Any:
"""Get value from dict."""
if key == "batch_sizes":
# If the key is batch_sizes, parse the integer from object.
Expand Down
2 changes: 1 addition & 1 deletion zeus/optimizer/batch_size/server/job/repository.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ def get_job_from_session(self, job_id: str) -> JobState | None:
"""
if self.fetched_job is None or self.fetched_job.job_id != job_id:
return None
return self.fetched_job
return JobState.from_orm(self.fetched_job)

def update_exp_default_bs(self, updated_bs: UpdateExpDefaultBs) -> None:
"""Update exploration default batch size on fetched job.
Expand Down
Loading

0 comments on commit b9bf4f2

Please sign in to comment.