Skip to content

Commit

Permalink
update the select source columns
Browse files Browse the repository at this point in the history
  • Loading branch information
unnir committed Feb 8, 2025
1 parent 0005788 commit 9138bdf
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 9 deletions.
30 changes: 22 additions & 8 deletions augini/data_engineer.py
Original file line number Diff line number Diff line change
Expand Up @@ -253,13 +253,17 @@ def _create_feature_prompt(self, df: pd.DataFrame, spec: FeatureSpec) -> str:
Returns:
Formatted prompt string
"""
# Get relevant columns
relevant_columns = spec.source_columns if spec.source_columns else df.columns
relevant_df = df[relevant_columns]

prompt = [
f"Generate a {spec.output_type} feature named '{spec.new_feature_name}'",
f"Description: {spec.new_feature_description}",
"\nDataset Information:",
f"- Shape: {df.shape[0]} rows, {df.shape[1]} columns",
f"- Columns: {', '.join(df.columns)}",
f"- Data Types: {df.dtypes.to_dict()}\n",
"\nRelevant Dataset Information:",
f"- Shape: {df.shape[0]} rows, {len(relevant_columns)} columns",
f"- Columns: {', '.join(relevant_columns)}",
f"- Data Types: {relevant_df.dtypes.to_dict()}\n",
]

if spec.constraints:
Expand Down Expand Up @@ -331,6 +335,16 @@ def _create_multi_feature_prompt(self, df: pd.DataFrame, specs: List[FeatureSpec
Returns:
Formatted prompt string
"""
# Collect all unique source columns across all features
all_source_columns = set()
for spec in specs:
if spec.source_columns:
all_source_columns.update(spec.source_columns)

# If no source columns specified for any feature, use all columns
relevant_columns = list(all_source_columns) if all_source_columns else df.columns
relevant_df = df[relevant_columns]

prompt = [
"Generate multiple features with these specifications:",
"\nFeatures to generate:"
Expand All @@ -348,10 +362,10 @@ def _create_multi_feature_prompt(self, df: pd.DataFrame, specs: List[FeatureSpec
prompt.append(f"- Source columns: {', '.join(spec.source_columns)}")

prompt.extend([
"\nDataset Information:",
f"- Shape: {df.shape[0]} rows, {df.shape[1]} columns",
f"- Columns: {', '.join(df.columns)}",
f"- Data Types: {df.dtypes.to_dict()}\n"
"\nRelevant Dataset Information:",
f"- Shape: {df.shape[0]} rows, {len(relevant_columns)} columns",
f"- Columns: {', '.join(relevant_columns)}",
f"- Data Types: {relevant_df.dtypes.to_dict()}\n"
])

return "\n".join(prompt)
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[project]
name = "augini"
version = "0.3.2"
version = "0.3.3"
authors = [
{ name = "Vadim Borisov", email = "vadim@tabularis.ai" },
]
Expand Down

0 comments on commit 9138bdf

Please sign in to comment.