diff --git a/packages/kestrel_core/src/kestrel/frontend/completor.py b/packages/kestrel_core/src/kestrel/frontend/completor.py index 019c6f45..6a2791b7 100644 --- a/packages/kestrel_core/src/kestrel/frontend/completor.py +++ b/packages/kestrel_core/src/kestrel/frontend/completor.py @@ -132,24 +132,27 @@ def do_complete( # handle optional components if ast: - stmt = ast.children[-1].children[0] - cmd = ( - stmt.children[1].data.value - if stmt.data.value == "assignment" - else stmt.data.value - ) - if cmd == "disp": - for clause in ("attr_clause", "limit_clause", "offset_clause"): - if not list(stmt.find_data(clause)): - suggestions.append("ATTR") - elif cmd in ("expression", "find") and not list( - stmt.find_data("where_clause") - ): - suggestions.append("WHERE") - elif cmd in ("get", "find") and not list(stmt.find_data("timerange")): - suggestions.append("START") - elif cmd == "apply" and not list(stmt.find_data("args")): - suggestions.append("WITH") + if ast.children: + stmt = ast.children[-1].children[0] + cmd = ( + stmt.children[1].data.value + if stmt.data.value == "assignment" + else stmt.data.value + ) + if cmd == "disp": + for clause in ("attr_clause", "limit_clause", "offset_clause"): + if not list(stmt.find_data(clause)): + suggestions.append("ATTR") + elif cmd in ("expression", "find") and not list( + stmt.find_data("where_clause") + ): + suggestions.append("WHERE") + elif cmd in ("get", "find") and not list(stmt.find_data("timerange")): + suggestions.append("START") + elif cmd == "apply" and not list(stmt.find_data("args")): + suggestions.append("WITH") + else: + suggestions = ["DISP", "APPLY", "EXPLAIN", "INFO", "SAVE", "DESCRIBE"] suggestions = [x for x in set(suggestions) if x] _p = last_word_prefix diff --git a/packages/kestrel_core/src/kestrel/interface/codegen/sql.py b/packages/kestrel_core/src/kestrel/interface/codegen/sql.py index 349224bf..f703dfef 100644 --- a/packages/kestrel_core/src/kestrel/interface/codegen/sql.py +++ 
b/packages/kestrel_core/src/kestrel/interface/codegen/sql.py @@ -133,7 +133,7 @@ def __init__( @typechecked def _map_identifier_field(self, field) -> ColumnElement: - if self.data_mapping and not self.is_subquery: + if self.data_mapping: comps = translate_comparison_to_native(self.data_mapping, field, "", None) if len(comps) > 1: raise InvalidMappingWithMultipleIdentifierFields(comps) @@ -155,10 +155,15 @@ def _render_comp(self, comp: FBasicComparison) -> BinaryExpression: *[self._map_identifier_field(field) for field in comp.fields] ) rendered_comp = comp2func[comp.op](col, comp.value) - elif self.data_mapping and not self.is_subquery: # translation needed + elif self.data_mapping: comps = translate_comparison_to_native( self.data_mapping, comp.field, comp.op, comp.value ) + if self.is_subquery: + # do not translate field + # only translate value + comps = [(comp.field, op, value) for (_, op, value) in comps] + translated_comps = ( ( ~comp2func[op](column(field), value) diff --git a/packages/kestrel_core/src/kestrel/mapping/data_model.py b/packages/kestrel_core/src/kestrel/mapping/data_model.py index a61755fe..f45c3f9c 100644 --- a/packages/kestrel_core/src/kestrel/mapping/data_model.py +++ b/packages/kestrel_core/src/kestrel/mapping/data_model.py @@ -447,7 +447,18 @@ def translate_dataframe(df: DataFrame, to_native_nested_map: dict) -> DataFrame: transformer_name = transformer_names.pop() if isinstance(transformer_name, dict): # Not actually a named function; it's a literal value map - df[col] = df[col].replace(transformer_name) + value_map = {} + for k, vl in transformer_name.items(): + if len(vl) > 1: + raise NotImplementedError( + "Multiple to OCSF value mapping" + ) + else: + value_map[k] = vl[0] + # use .apply instead of .replace to handle type correctly + df[col] = df[col].apply( + lambda x: value_map[x] if x in value_map else x + ) else: s = run_transformer_on_series( transformer_name, df[col].dropna()