Skip to content

Commit

Permalink
make db models mutable -SupportedDBs, DataConfig
Browse files Browse the repository at this point in the history
  • Loading branch information
gargimaheshwari committed Jun 6, 2024
1 parent 4166f5a commit b22fbcd
Show file tree
Hide file tree
Showing 10 changed files with 142 additions and 38 deletions.
5 changes: 2 additions & 3 deletions build/lib/lyzr/data_analyzr/analyzr.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@
)
from lyzr.base.prompt import LyzrPromptFactory
from lyzr.base.llm import LyzrLLMFactory, LiteLLM
from lyzr.data_analyzr.db_models import SupportedDBs, DataConfig


class DataAnalyzr:
Expand Down Expand Up @@ -228,16 +227,16 @@ def get_data(
vector_store_config=vector_store_config
)
"""
from pydantic import TypeAdapter
from lyzr.data_analyzr.file_utils import get_db_details
from lyzr.data_analyzr.db_models import DataConfig, SupportedDBs

if not isinstance(db_config, dict):
raise ValueError("data_config must be a dictionary.")
db_config["db_type"] = SupportedDBs(db_type.lower().strip())
self.database_connector, self.df_dict, self.vector_store = get_db_details(
analysis_type=self.analysis_type,
db_type=db_config["db_type"],
db_config=TypeAdapter(DataConfig).validate_python(db_config),
db_config=DataConfig.validate(db_config),
vector_store_config=VectorStoreConfig(**vector_store_config),
logger=self.logger,
)
Expand Down
72 changes: 58 additions & 14 deletions build/lib/lyzr/data_analyzr/db_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,12 @@

# standart-library imports
import warnings
from enum import Enum
from typing import Annotated, Union
from aenum import Enum
from typing import Union

# third-party imports
import pandas as pd
from pydantic import BaseModel, Field, Discriminator, Tag, AliasChoices, ConfigDict
from pydantic import BaseModel, Field, AliasChoices, ConfigDict


warnings.filterwarnings("ignore", category=UserWarning)
Expand Down Expand Up @@ -62,17 +62,60 @@ class SQLiteConfig(BaseModel):
db_path: str


DataConfig = Annotated[
Union[
Annotated[FilesConfig, Tag(SupportedDBs.files)],
Annotated[RedshiftConfig, Tag(SupportedDBs.redshift)],
Annotated[PostgresConfig, Tag(SupportedDBs.postgres)],
Annotated[SQLiteConfig, Tag(SupportedDBs.sqlite)],
],
Discriminator(lambda x: x["db_type"]),
]
class DynamicConfigUnion:
"""
A class to manage dynamic configuration types.
It allows for the registration of configuration types and
creation of instances based on the discriminator key.
To be used as a mutable Union type for validation with Pydantic models.
Procedure:
1. Create an instance of the `DynamicConfigUnion` class.
2. Register configuration types using the `add_config_type` method.
3. Validate the data using the `validate` method.
4. The `validate` method returns an instance of the appropriate configuration type.
Usage:
config = DynamicConfigUnion()
config.add_config_type("db_type", DBConfig)
config.validate({"db_type": "db_type", "key": "value"})
"""

def __init__(self, key_name: str = None):
self._config_types: dict[str, BaseModel] = {}
self._discriminator_key = key_name or "db_type"

def add_config_type(self, name, config_type: BaseModel):
if name not in self._config_types:
self._config_types[name] = config_type
elif config_type != self._config_types[name]:
raise ValueError(
f"Config type with name '{name}' is already registered with {self._config_types[name]}"
)

def _create_instance(self, name, **kwargs) -> BaseModel:
if name not in self._config_types:
raise ValueError(f"Config type with name '{name}' is not registered")
config_type = self._config_types[name]
return config_type(**kwargs)

def validate(self, data: dict) -> BaseModel:
if self._discriminator_key not in data:
raise ValueError(
f"Data must contain the discriminator key '{self._discriminator_key}'"
)
name = data[self._discriminator_key]
return self._create_instance(name, **data)


DataConfig = DynamicConfigUnion()
DataConfig.add_config_type(SupportedDBs.files, FilesConfig)
DataConfig.add_config_type(SupportedDBs.redshift, RedshiftConfig)
DataConfig.add_config_type(SupportedDBs.postgres, PostgresConfig)
DataConfig.add_config_type(SupportedDBs.sqlite, SQLiteConfig)

"""
Union type for database configurations of supported types.
Mutable union type for database configurations of supported types.
This type is used to validate and discriminate between different
database configurations based on the `db_type` field:
Expand All @@ -82,5 +125,6 @@ class SQLiteConfig(BaseModel):
- SQLiteConfig: Configuration model for SQLite databases.
Usage:
TypeAdapter(DataConfig).validate_python(db_config)
config = DataConfig.validate({"db_type": "files", "datasets": [...], "db_path": "path/to/db"})
assert isinstance(config, FilesConfig)
"""
12 changes: 10 additions & 2 deletions build/lib/lyzr/data_analyzr/file_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
RedshiftConfig,
PostgresConfig,
SQLiteConfig,
DataConfig,
)
from lyzr.data_analyzr.db_connector import (
DatabaseConnector,
Expand Down Expand Up @@ -77,9 +78,16 @@ def get_db_details(
+ "\n"
)
else:
accepted_db_types = tuple(
[
elem
for elem in DataConfig._config_types.values()
if elem is not FilesConfig
]
)
assert isinstance(
db_config, (RedshiftConfig, PostgresConfig, SQLiteConfig)
), f"Expected RedshiftConfig, PostgresConfig or SQLiteConfig, got {type(db_config)}"
db_config, accepted_db_types
), f"Expected one of {accepted_db_types}, got {type(db_config)}"
connector = DatabaseConnector.get_connector(db_type)(**db_config.model_dump())
df_dict, connector = ensure_correct_data_format(
analysis_type=analysis_type,
Expand Down
Binary file modified dist/lyzr-0.1.39-py3-none-any.whl
Binary file not shown.
Binary file modified dist/lyzr-0.1.39.tar.gz
Binary file not shown.
1 change: 1 addition & 0 deletions lyzr.egg-info/requires.txt
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ python-dotenv>=1.0.0
weaviate-client==3.25.3

[data-analyzr]
aenum
chromadb==0.4.22
matplotlib==3.8.2
mysql-connector-python==8.2.0
Expand Down
5 changes: 2 additions & 3 deletions lyzr/data_analyzr/analyzr.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@
)
from lyzr.base.prompt import LyzrPromptFactory
from lyzr.base.llm import LyzrLLMFactory, LiteLLM
from lyzr.data_analyzr.db_models import SupportedDBs, DataConfig


class DataAnalyzr:
Expand Down Expand Up @@ -228,16 +227,16 @@ def get_data(
vector_store_config=vector_store_config
)
"""
from pydantic import TypeAdapter
from lyzr.data_analyzr.file_utils import get_db_details
from lyzr.data_analyzr.db_models import DataConfig, SupportedDBs

if not isinstance(db_config, dict):
raise ValueError("data_config must be a dictionary.")
db_config["db_type"] = SupportedDBs(db_type.lower().strip())
self.database_connector, self.df_dict, self.vector_store = get_db_details(
analysis_type=self.analysis_type,
db_type=db_config["db_type"],
db_config=TypeAdapter(DataConfig).validate_python(db_config),
db_config=DataConfig.validate(db_config),
vector_store_config=VectorStoreConfig(**vector_store_config),
logger=self.logger,
)
Expand Down
72 changes: 58 additions & 14 deletions lyzr/data_analyzr/db_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,12 @@

# standart-library imports
import warnings
from enum import Enum
from typing import Annotated, Union
from aenum import Enum
from typing import Union

# third-party imports
import pandas as pd
from pydantic import BaseModel, Field, Discriminator, Tag, AliasChoices, ConfigDict
from pydantic import BaseModel, Field, AliasChoices, ConfigDict


warnings.filterwarnings("ignore", category=UserWarning)
Expand Down Expand Up @@ -62,17 +62,60 @@ class SQLiteConfig(BaseModel):
db_path: str


DataConfig = Annotated[
Union[
Annotated[FilesConfig, Tag(SupportedDBs.files)],
Annotated[RedshiftConfig, Tag(SupportedDBs.redshift)],
Annotated[PostgresConfig, Tag(SupportedDBs.postgres)],
Annotated[SQLiteConfig, Tag(SupportedDBs.sqlite)],
],
Discriminator(lambda x: x["db_type"]),
]
class DynamicConfigUnion:
"""
A class to manage dynamic configuration types.
It allows for the registration of configuration types and
creation of instances based on the discriminator key.
To be used as a mutable Union type for validation with Pydantic models.
Procedure:
1. Create an instance of the `DynamicConfigUnion` class.
2. Register configuration types using the `add_config_type` method.
3. Validate the data using the `validate` method.
4. The `validate` method returns an instance of the appropriate configuration type.
Usage:
config = DynamicConfigUnion()
config.add_config_type("db_type", DBConfig)
config.validate({"db_type": "db_type", "key": "value"})
"""

def __init__(self, key_name: str = None):
self._config_types: dict[str, BaseModel] = {}
self._discriminator_key = key_name or "db_type"

def add_config_type(self, name, config_type: BaseModel):
if name not in self._config_types:
self._config_types[name] = config_type
elif config_type != self._config_types[name]:
raise ValueError(
f"Config type with name '{name}' is already registered with {self._config_types[name]}"
)

def _create_instance(self, name, **kwargs) -> BaseModel:
if name not in self._config_types:
raise ValueError(f"Config type with name '{name}' is not registered")
config_type = self._config_types[name]
return config_type(**kwargs)

def validate(self, data: dict) -> BaseModel:
if self._discriminator_key not in data:
raise ValueError(
f"Data must contain the discriminator key '{self._discriminator_key}'"
)
name = data[self._discriminator_key]
return self._create_instance(name, **data)


DataConfig = DynamicConfigUnion()
DataConfig.add_config_type(SupportedDBs.files, FilesConfig)
DataConfig.add_config_type(SupportedDBs.redshift, RedshiftConfig)
DataConfig.add_config_type(SupportedDBs.postgres, PostgresConfig)
DataConfig.add_config_type(SupportedDBs.sqlite, SQLiteConfig)

"""
Union type for database configurations of supported types.
Mutable union type for database configurations of supported types.
This type is used to validate and discriminate between different
database configurations based on the `db_type` field:
Expand All @@ -82,5 +125,6 @@ class SQLiteConfig(BaseModel):
- SQLiteConfig: Configuration model for SQLite databases.
Usage:
TypeAdapter(DataConfig).validate_python(db_config)
config = DataConfig.validate({"db_type": "files", "datasets": [...], "db_path": "path/to/db"})
assert isinstance(config, FilesConfig)
"""
12 changes: 10 additions & 2 deletions lyzr/data_analyzr/file_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
RedshiftConfig,
PostgresConfig,
SQLiteConfig,
DataConfig,
)
from lyzr.data_analyzr.db_connector import (
DatabaseConnector,
Expand Down Expand Up @@ -77,9 +78,16 @@ def get_db_details(
+ "\n"
)
else:
accepted_db_types = tuple(
[
elem
for elem in DataConfig._config_types.values()
if elem is not FilesConfig
]
)
assert isinstance(
db_config, (RedshiftConfig, PostgresConfig, SQLiteConfig)
), f"Expected RedshiftConfig, PostgresConfig or SQLiteConfig, got {type(db_config)}"
db_config, accepted_db_types
), f"Expected one of {accepted_db_types}, got {type(db_config)}"
connector = DatabaseConnector.get_connector(db_type)(**db_config.model_dump())
df_dict, connector = ensure_correct_data_format(
analysis_type=analysis_type,
Expand Down
1 change: 1 addition & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
],
extras_require={
"data-analyzr": [
"aenum",
"matplotlib==3.8.2",
"seaborn==0.13.2",
"scikit-learn==1.4.0",
Expand Down

0 comments on commit b22fbcd

Please sign in to comment.