diff --git a/packages/kestrel_core/pyproject.toml b/packages/kestrel_core/pyproject.toml index e57a5bca..404eee09 100644 --- a/packages/kestrel_core/pyproject.toml +++ b/packages/kestrel_core/pyproject.toml @@ -54,6 +54,10 @@ Homepage = "https://github.com/opencybersecurityalliance/kestrel-lang" Documentation = "https://kestrel.readthedocs.io/" Repository = "https://github.com/opencybersecurityalliance/kestrel-lang.git" +[project.scripts] +kestrel = "kestrel.cli:kestrel" +ikestrel = "kestrel.cli:ikestrel" + [tool.setuptools.packages.find] where = ["src"] diff --git a/packages/kestrel_core/src/kestrel/__future__.py b/packages/kestrel_core/src/kestrel/__future__.py index efe66a26..73c43393 100644 --- a/packages/kestrel_core/src/kestrel/__future__.py +++ b/packages/kestrel_core/src/kestrel/__future__.py @@ -1,6 +1,6 @@ import sys -from typeguard import typechecked +from typeguard import typechecked """Entrance to invoke any backward compatibility patch diff --git a/packages/kestrel_core/src/kestrel/analytics/__init__.py b/packages/kestrel_core/src/kestrel/analytics/__init__.py new file mode 100644 index 00000000..4ef76205 --- /dev/null +++ b/packages/kestrel_core/src/kestrel/analytics/__init__.py @@ -0,0 +1 @@ +from .interface import PythonAnalyticsInterface diff --git a/packages/kestrel_core/src/kestrel/analytics/config.py b/packages/kestrel_core/src/kestrel/analytics/config.py new file mode 100644 index 00000000..7d8bdbb1 --- /dev/null +++ b/packages/kestrel_core/src/kestrel/analytics/config.py @@ -0,0 +1,51 @@ +import logging + +from kestrel.config.utils import CONFIG_DIR_DEFAULT, load_user_config +from kestrel.exceptions import InvalidAnalytics + +PROFILE_PATH_DEFAULT = CONFIG_DIR_DEFAULT / "pythonanalytics.yaml" +PROFILE_PATH_ENV_VAR = "KESTREL_PYTHON_ANALYTICS_CONFIG" + +_logger = logging.getLogger(__name__) + + +def load_profiles(): + config = load_user_config(PROFILE_PATH_ENV_VAR, PROFILE_PATH_DEFAULT) + if config and "profiles" in config: + _logger.debug(f"python analytics profiles found in config file") + profiles = config["profiles"] + else: + _logger.info("no python analytics config with profiles found") + profiles = {} + _logger.debug(f"profiles loaded: {profiles}") + return profiles + + +def get_profile(profile_name, profiles): + if profile_name not in profiles: + raise InvalidAnalytics( + profile_name, + "python", + f"no {profile_name} configuration found", + ) + else: + profile = profiles[profile_name] + _logger.debug(f"profile to use: {profile}") + if "module" not in profile: + raise InvalidAnalytics( + profile_name, + "python", + f"no {profile_name} module defined", + ) + else: + module_name = profile["module"] + if "func" not in profile: + raise InvalidAnalytics( + profile_name, + "python", + f"no {profile_name} func defined", + ) + else: + func_name = profile["func"] + + return module_name, func_name diff --git a/packages/kestrel_core/src/kestrel/analytics/interface.py b/packages/kestrel_core/src/kestrel/analytics/interface.py new file mode 100644 index 00000000..7c65c4b6 --- /dev/null +++ b/packages/kestrel_core/src/kestrel/analytics/interface.py @@ -0,0 +1,400 @@ +"""Python analytics interface executes Python function as Kestrel analytics. + +Use a Python Analytics +---------------------- + +Create a profile for each analytics in the python analytics interface config +file (YAML): + +- Default path: ``~/.config/kestrel/pythonanalytics.yaml``. +- A customized path specified in the environment variable ``KESTREL_PYTHON_ANALYTICS_CONFIG``. + +Example of the python analytics interface config file: + +.. code-block:: yaml + + profiles: + analytics-name-1: # the analytics name to use in the APPLY command + module: /home/user/kestrel-analytics/analytics/piniponmap/analytics.py + func: analytics # the analytics function in the module to call + analytics-name-2: + module: /home/user/kestrel-analytics/analytics/suspiciousscoring/analytics.py + func: analytics + +Develop a Python Analytics +-------------------------- + +A Python analytics is a python function that follows the rules: + +#. The function takes in one or more Kestrel variable dumps in Pandas DataFrames. + +#. The return of the function is a tuple containing either or both: + + - Updated variables. The number of variables can be either 0, e.g., + visualization analytics, or the same number as input Kestrel variables. + The order of the updated variables should follow the same order as input + variables. + + - An object to display, which can be any of the following types: + + - Kestrel display object + + - HTML element as a string + + - Matplotlib figure (by default, Pandas DataFrame plots use this) + + The display object can be either before or after updated variables. In other + words, if the input variables are ``var1``, ``var2``, and ``var3``, the + return of the analytics can be either of the following: + + .. code-block:: python + + # the analytics enriches variables without returning a display object + return var1_updated, var3_updated, var3_updated + + # this is a visualization analytics and no variable updates + return display_obj + + # the analytics does both variable updates and visualization + return var1_updated, var3_updated, var3_updated, display_obj + + # the analytics does both variable updates and visualization + return display_obj, var1_updated, var3_updated, var3_updated + + +#. Parameters in the APPLY command are passed in as keyword args when possible, + otherwise as environment variables. The interface will inspect the signature + of the analytics function to determine which methods to use. For example, the + following analytics function would be called with keyword args taken exactly + from the ``WITH`` part of the Kestrel statement. + + .. code-block:: + + def my_analytic(df: pd.DataFrame, x: int = 0, y: float = 0.5) + + This function could be called as: + + .. code-block:: + + APPLY python://my_analytic ON var1 WITH x=1, y=0.7 + + The names of the environment variables are the exact parameter keys given in the + ``APPLY`` command. For example, the following command + + .. code-block:: + + APPLY python://a1 ON var1 WITH XPARAM=src_ref.value, YPARAM=number_observed + + creates environment variables ``$XPARAM`` with value ``src_ref.value`` and + ``$YPARAM`` with value ``number_observed`` to be used by the analytics + ``a1``. After the execution of the analytics, the environment variables will + be roll back to the original state. + +#. The Python function could spawn other processes or execute other binaries, + where the Python function just acts like a wrapper. Check our `domain name + lookup analytics`_ as an example. + +.. _domain name lookup analytics: https://github.com/opencybersecurityalliance/kestrel-analytics/tree/release/analytics/domainnamelookup + +""" + +import inspect +import json +import logging +import os +import pathlib +import sys +import traceback +from contextlib import AbstractContextManager +from importlib.util import module_from_spec, spec_from_file_location +from typing import Any, Iterable, Mapping, MutableMapping, Optional +from uuid import UUID + +from pandas import DataFrame + +from kestrel.analytics.config import get_profile, load_profiles +from kestrel.display import GraphletExplanation +from kestrel.exceptions import ( + AnalyticsError, + InvalidAnalytics, + InvalidAnalyticsArgumentCount, + InvalidAnalyticsInterfaceImplementation, + InvalidAnalyticsOutput, +) +from kestrel.interface import AbstractInterface +from kestrel.ir.graph import IRGraphEvaluable +from kestrel.ir.instructions import ( + Analytic, + Instruction, + TransformingInstruction, + Variable, +) + +_logger = logging.getLogger(__name__) + + +class PythonAnalyticsJob: + """Simple config class to hold all the bits necessary to call the external analytics""" + + def __init__( + self, + iid: UUID, + cache: MutableMapping[UUID, Any], + ): + self.cache = cache + self.input_iid = iid + self.output_iid: Optional[UUID] = None + self.analytic: str = "" + self.params: dict = {} + + def run(self, config: dict) -> DataFrame: + module_name, func_name = get_profile(self.analytic, config) + df = self.cache[self.input_iid] + with PythonAnalytics( + self.analytic, module_name, func_name, self.params + ) as func: + df = func(df) + _logger.debug("python analytics job result:\n%s", df) + return df + + +class PythonAnalyticsInterface(AbstractInterface): + def __init__( + self, + serialized_cache_catalog: Optional[str] = None, + session_id: Optional[UUID] = None, + ): + _logger.debug("PythonAnalyticsInterface: loading config") + super().__init__(serialized_cache_catalog, session_id) + self.config = load_profiles() + + @staticmethod + def schemes() -> Iterable[str]: + return ["python"] + + def store( + self, + instruction_id: UUID, + data: DataFrame, + ): + raise NotImplementedError("PythonAnalyticsInterface.store") # TEMP + + def explain_graph( + self, + graph: IRGraphEvaluable, + instructions_to_explain: Optional[Iterable[Instruction]] = None, + ) -> Mapping[UUID, GraphletExplanation]: + raise NotImplementedError("PythonAnalyticsInterface.explain_graph") # TEMP + + def evaluate_graph( + self, + graph: IRGraphEvaluable, + cache: MutableMapping[UUID, Any], + instructions_to_evaluate: Optional[Iterable[Instruction]] = None, + ) -> Mapping[UUID, DataFrame]: + _logger.debug( + "python: graph: %s", json.dumps(json.loads(graph.to_json()), indent=4) + ) + mapping = {} + if not instructions_to_evaluate: + instructions_to_evaluate = graph.get_sink_nodes() + for instruction in instructions_to_evaluate: + _logger.debug("python: inst to evaluate: %s", instruction) + job = self._evaluate_instruction_in_graph(graph, cache, instruction) + mapping[job.output_iid] = job.run(self.config) + return mapping + + def _evaluate_instruction_in_graph( + self, + graph: IRGraphEvaluable, + cache: MutableMapping[UUID, Any], + instruction: Instruction, + ) -> PythonAnalyticsJob: + _logger.debug("python analytics: instruction = %s", instruction) + + if isinstance(instruction, TransformingInstruction): + if instruction.id in cache: + iid = instruction.id + _logger.debug("python analytics: got iid = %s from cache", iid) + job = PythonAnalyticsJob(iid, cache) + else: + trunk, _r2n = graph.get_trunk_n_branches(instruction) + job = self._evaluate_instruction_in_graph(graph, cache, trunk) + + if isinstance(instruction, Analytic): + job.analytic = instruction.name + job.params = instruction.params + _logger.debug("python analytics: %s", job.analytic) + elif isinstance(instruction, Variable): + if not job.input_iid: + job.input_iid = instruction.id + _logger.debug("python analytics: input_iid = %s", job.input_iid) + else: + job.output_iid = instruction.id + _logger.debug("python analytics: output_iid = %s", job.output_iid) + + return job + + +class PythonAnalytics(AbstractContextManager): + def __init__( + self, analytic_name: str, module_name: str, func_name: str, parameters: dict + ): + self.name = analytic_name + self.func_name = func_name + self.parameters = parameters + self.module_path = pathlib.Path(module_name).expanduser().resolve() + self.module_path_dir_str = str(self.module_path.parent) + self.use_env = False + + def __enter__(self): + # accommodate any other Python modules to load in the dir + self.syspath = sys.path.copy() + sys.path.append(self.module_path_dir_str) + + # accommodate any other executables or data to load in the dir + self.cwd_original = os.getcwd() + os.chdir(self.module_path_dir_str) + + # time to load the analytics function + self.analytics_function = self._locate_analytics_func(self._load_module()) + + # inspect signature of function and skip env vars if it has kwargs + self.use_env = self._set_env_vars() + + # passing parameters as environment variables + if self.use_env: + self.environ_original = os.environ.copy() + if self.parameters: + if isinstance(self.parameters, Mapping): + parameters = {k: str(v) for k, v in self.parameters.items()} + _logger.debug(f"setting parameters as env vars: {parameters}") + os.environ.update(parameters) + else: + raise InvalidAnalyticsInterfaceImplementation( + "parameters should be passed in as a Mapping" + ) + + return self._execute + + def __exit__(self, exception_type, exception_value, _traceback): + sys.path = self.syspath + os.chdir(self.cwd_original) + if self.use_env: + os.environ = self.environ_original + + def _execute(self, dataframe: DataFrame): + """Execute the analytics + + Args: + dataframe (DataFrame): input variable to the analytics. + + Returns: + + DataFrame: the analytics output (i.e. "enriched" DataFrame) + + """ + input_dataframes = [dataframe] # TEMP + if len(input_dataframes) != self._get_var_count(): + raise InvalidAnalyticsArgumentCount( + self.name, len(input_dataframes), self._get_var_count() + ) + + try: + if self.use_env: + outputs = self.analytics_function(*input_dataframes) + else: + outputs = self.analytics_function(*input_dataframes, **self.parameters) + except Exception as e: + _logger.error('"%s" failed at execution: %s', self.name, e, exc_info=e) + raise AnalyticsError(f"{self.name} failed at execution") from e + + if not isinstance(outputs, tuple): + outputs = (outputs,) + + output_dfs, output_dsps = [], [] + for x in outputs: + x_class_str = type(x).__module__ + "." + type(x).__name__ + if isinstance(x, DataFrame): + output_dfs.append(x) + elif isinstance(x, str): + _logger.info( + f'analytics "{self.name}" yielded a string return. treat it as an HTML element.' + ) + output_dsps.append(x) # FIXME:DisplayHtml(x)) + elif x_class_str == "matplotlib.figure.Figure": + _logger.info(f'analytics "{self.name}" yielded a figure.') + output_dsps.append(x) # FIXME:DisplayFigure(x)) + else: + raise InvalidAnalyticsOutput(self.name, type(x)) + + if not outputs: + raise AnalyticsError(f'analytics "{self.name}" yield nothing') + if len(output_dsps) > 1: + raise AnalyticsError( + f'analytics "{self.name}" yielded more than one Kestrel Display object' + ) + if output_dfs: + if len(output_dfs) != len(input_dataframes): + raise AnalyticsError( + f'analytics "{self.name}" yielded less/more Kestrel variable(s) than given' + ) + return output_dfs[0] # TEMP + + display = output_dsps[0] if output_dsps else None + return display + + def _load_module(self): + spec = spec_from_file_location( + "kestrel_analytics_python.analytics.{profile_name}", str(self.module_path) + ) + + try: + module = module_from_spec(spec) + spec.loader.exec_module(module) + except ModuleNotFoundError as e: + raise AnalyticsError( + f"{self.name} misses dependent library: {e.name}", + "pip install the corresponding Python package", + ) + except Exception as e: + if isinstance(e, AttributeError) and e.args == ( + "'NoneType' object has no attribute 'loader'", + ): + raise AnalyticsError( + f"{self.name} is not found", + "please make sure the Python module and function specified in the profile (configuration) exist", + ) + else: + exc_type, exc_value, exc_traceback = sys.exc_info() + error = "".join( + traceback.format_exception(exc_type, exc_value, exc_traceback) + ) + raise AnalyticsError(f"{self.name} failed at importing:\n{error}") + + return module + + def _locate_analytics_func(self, module): + if hasattr(module, self.func_name): + return getattr(module, self.func_name) + raise InvalidAnalytics( + self.name, + "python", + f'function "{self.func_name}" not exist in module: {self.module_path}', + ) + + def _get_var_count(self): + """Determine number of vars/DataFrames the analytics func expects""" + sig = inspect.signature(self.analytics_function) + # Count of params with type DataFrame + # If there are no type annotations, then fall back to param count + df_count = sum(1 for i in sig.parameters.values() if i.annotation == DataFrame) + return df_count if df_count else len(sig.parameters) + + def _set_env_vars(self): + """Check if the analytics function DOES NOT accept any non-DataFrame parameters. If so, return True and use env to pass params.""" + sig = inspect.signature(self.analytics_function) + return ( + sum(1 for i in sig.parameters.values() if i.annotation != DataFrame) == 0 + or len(sig.parameters) == 1 + ) diff --git a/packages/kestrel_core/src/kestrel/cache/base.py b/packages/kestrel_core/src/kestrel/cache/base.py index 4d1a94bb..d7ee53fd 100644 --- a/packages/kestrel_core/src/kestrel/cache/base.py +++ b/packages/kestrel_core/src/kestrel/cache/base.py @@ -1,8 +1,10 @@ from __future__ import annotations -from pandas import DataFrame -from typing import MutableMapping -from uuid import UUID + from abc import abstractmethod +from typing import Iterable, MutableMapping +from uuid import UUID + +from pandas import DataFrame from kestrel.config.internal import CACHE_INTERFACE_IDENTIFIER from kestrel.interface import AbstractInterface diff --git a/packages/kestrel_core/src/kestrel/cache/inmemory.py b/packages/kestrel_core/src/kestrel/cache/inmemory.py index 87557222..f88f4580 100644 --- a/packages/kestrel_core/src/kestrel/cache/inmemory.py +++ b/packages/kestrel_core/src/kestrel/cache/inmemory.py @@ -1,30 +1,25 @@ from copy import copy +from typing import Any, Iterable, Mapping, MutableMapping, Optional +from uuid import UUID + from pandas import DataFrame from typeguard import typechecked -from uuid import UUID -from typing import ( - Mapping, - MutableMapping, - Optional, - Iterable, - Any, -) from kestrel.cache.base import AbstractCache -from kestrel.ir.graph import IRGraphEvaluable from kestrel.display import GraphletExplanation, NativeQuery +from kestrel.interface.codegen.dataframe import ( + evaluate_source_instruction, + evaluate_transforming_instruction, +) +from kestrel.ir.graph import IRGraphEvaluable from kestrel.ir.instructions import ( - Instruction, - Return, Explain, - Variable, Filter, + Instruction, + Return, SourceInstruction, TransformingInstruction, -) -from kestrel.interface.codegen.dataframe import ( - evaluate_source_instruction, - evaluate_transforming_instruction, + Variable, ) @@ -68,6 +63,7 @@ def get_virtual_copy(self) -> AbstractCache: def evaluate_graph( self, graph: IRGraphEvaluable, + cache: MutableMapping[UUID, Any], instructions_to_evaluate: Optional[Iterable[Instruction]] = None, ) -> Mapping[UUID, DataFrame]: mapping = {} diff --git a/packages/kestrel_core/src/kestrel/cache/sql.py b/packages/kestrel_core/src/kestrel/cache/sql.py index 4939723d..e0b5f313 100644 --- a/packages/kestrel_core/src/kestrel/cache/sql.py +++ b/packages/kestrel_core/src/kestrel/cache/sql.py @@ -1,6 +1,6 @@ import logging from copy import copy -from typing import Iterable, Mapping, Optional, Union, Any +from typing import Any, Iterable, Mapping, MutableMapping, Optional, Union from uuid import UUID import sqlalchemy @@ -9,19 +9,19 @@ from typeguard import typechecked from kestrel.cache.base import AbstractCache +from kestrel.display import GraphletExplanation, NativeQuery from kestrel.interface.codegen.sql import SqlTranslator from kestrel.ir.graph import IRGraphEvaluable -from kestrel.display import GraphletExplanation, NativeQuery from kestrel.ir.instructions import ( Construct, - Instruction, - Return, Explain, - Variable, Filter, + Instruction, + Return, + SolePredecessorTransformingInstruction, SourceInstruction, TransformingInstruction, - SolePredecessorTransformingInstruction, + Variable, ) _logger = logging.getLogger(__name__) @@ -90,6 +90,7 @@ def get_virtual_copy(self) -> AbstractCache: def evaluate_graph( self, graph: IRGraphEvaluable, + cache: MutableMapping[UUID, Any], instructions_to_evaluate: Optional[Iterable[Instruction]] = None, ) -> Mapping[UUID, DataFrame]: mapping = {} diff --git a/packages/kestrel_core/src/kestrel/cli.py b/packages/kestrel_core/src/kestrel/cli.py index e69de29b..8a5adc8a 100644 --- a/packages/kestrel_core/src/kestrel/cli.py +++ b/packages/kestrel_core/src/kestrel/cli.py @@ -0,0 +1,122 @@ +################################################################ +# +# Kestrel Command-line Utilities +# - kestrel +# - ikestrel +# +################################################################ + +import argparse +import cmd +import logging + +from kestrel.exceptions import KestrelError +from kestrel.session import Session + + +def add_logging_handler(handler, if_debug): + fmt = "%(asctime)s %(levelname)s %(name)s %(message)s" + datefmt = "%H:%M:%S" + formatter = logging.Formatter(fmt=fmt, datefmt=datefmt) + + handler.setFormatter(formatter) + + root_logger = logging.getLogger() + current_logging_level = root_logger.getEffectiveLevel() + root_logger.addHandler(handler) + root_logger.setLevel(logging.DEBUG if if_debug else logging.INFO) + + return handler, current_logging_level + + +def kestrel(): + parser = argparse.ArgumentParser(description="Kestrel Interpreter") + parser.add_argument("huntflow", help="huntflow in .hf file") + parser.add_argument( + "-v", "--verbose", help="print verbose log", action="store_true" + ) + parser.add_argument( + "--debug", help="debug level log (default is info level)", action="store_true" + ) + args = parser.parse_args() + + if args.verbose: + add_logging_handler(logging.StreamHandler(), args.debug) + + with Session() as session: + with open(args.huntflow, "r") as fp: + huntflow = fp.read() + outputs = session.execute(huntflow) + results = "\n\n".join([o.to_string() for o in outputs]) + print(results) + + +# TODO: fix #405 so we do not need this +CMDS = [ # command_no_result from kestrel.lark + "APPLY", + "DISP", + "INFO", + "SAVE", +] + + +def display_outputs(outputs): + for i in outputs: + print(i) + + +class IKestrel(cmd.Cmd): + prompt = "> " + + def __init__(self, session: Session): + self.session = session + self.buf = "" + super().__init__() + + def default(self, line: str): + try: + outputs = self.session.execute(line) + display_outputs(outputs) + except KestrelError as e: + print(e) + + def completenames(self, text, *ignored): + code, _start, _end = ignored + if code.isupper(): + # Probably a command? + results = [i for i in CMDS if i.startswith(code)] + else: + # Try all commands and vars + results = [i for i in CMDS if i.lower().startswith(code)] + results += [ + i for i in self.session.get_variable_names() if i.startswith(code) + ] + return results + + def completedefault(self, *ignored): + _, code, start, end = ignored + results = self.session.do_complete(code, end) + stub = code[start:] + return [stub + suffix for suffix in results] + + def do_EOF(self, _line: str): + print() + return True + + +def ikestrel(): + parser = argparse.ArgumentParser(description="Kestrel Interpreter") + parser.add_argument( + "-v", "--verbose", help="print verbose log", action="store_true" + ) + parser.add_argument( + "--debug", help="debug level log (default is info level)", action="store_true" + ) + args = parser.parse_args() + + if args.verbose: + add_logging_handler(logging.StreamHandler(), args.debug) + + with Session() as s: + ik = IKestrel(s) + ik.cmdloop() diff --git a/packages/kestrel_core/src/kestrel/config/utils.py b/packages/kestrel_core/src/kestrel/config/utils.py index 369ffc60..2d20b22a 100644 --- a/packages/kestrel_core/src/kestrel/config/utils.py +++ b/packages/kestrel_core/src/kestrel/config/utils.py @@ -1,12 +1,13 @@ +import logging import os -import yaml from pathlib import Path -import logging -from typeguard import typechecked from typing import Mapping, Union -from kestrel.utils import update_nested_dict, load_data_file -from kestrel.exceptions import InvalidYamlInConfig, InvalidKestrelConfig +import yaml +from typeguard import typechecked + +from kestrel.exceptions import InvalidKestrelConfig, InvalidYamlInConfig +from kestrel.utils import load_data_file, update_nested_dict CONFIG_DIR_DEFAULT = Path.home() / ".config" / "kestrel" CONFIG_PATH_DEFAULT = CONFIG_DIR_DEFAULT / "kestrel.yaml" diff --git a/packages/kestrel_core/src/kestrel/display.py b/packages/kestrel_core/src/kestrel/display.py index e6729f85..0cc5175a 100644 --- a/packages/kestrel_core/src/kestrel/display.py +++ b/packages/kestrel_core/src/kestrel/display.py @@ -1,5 +1,6 @@ -from typing import List, Union, Mapping from dataclasses import dataclass +from typing import List, Mapping, Union + from mashumaro.mixins.json import DataClassJSONMixin from pandas import DataFrame diff --git a/packages/kestrel_core/src/kestrel/exceptions.py b/packages/kestrel_core/src/kestrel/exceptions.py index c5caf011..bfcdd2ab 100644 --- a/packages/kestrel_core/src/kestrel/exceptions.py +++ b/packages/kestrel_core/src/kestrel/exceptions.py @@ -130,3 +130,23 @@ class UnsupportedOperatorError(KestrelError): class IncompleteDataMapping(KestrelError): pass + + +class InvalidAnalytics(KestrelError): + pass + + +class InvalidAnalyticsArgumentCount(KestrelError): + pass + + +class InvalidAnalyticsInterfaceImplementation(KestrelError): + pass + + +class InvalidAnalyticsOutput(KestrelError): + pass + + +class AnalyticsError(KestrelError): + pass diff --git a/packages/kestrel_core/src/kestrel/frontend/compile.py b/packages/kestrel_core/src/kestrel/frontend/compile.py index b5730d4d..9ab78b56 100644 --- a/packages/kestrel_core/src/kestrel/frontend/compile.py +++ b/packages/kestrel_core/src/kestrel/frontend/compile.py @@ -5,38 +5,34 @@ from functools import reduce from dateutil.parser import parse as to_datetime -from lark import Transformer, Token +from lark import Token, Transformer from typeguard import typechecked -from kestrel.mapping.data_model import ( - translate_comparison_to_ocsf, - translate_projection_to_ocsf, -) -from kestrel.utils import unescape_quoted_string +from kestrel.exceptions import IRGraphMissingNode from kestrel.ir.filter import ( - FExpression, + BoolExp, + ExpOp, FComparison, - IntComparison, + FExpression, FloatComparison, - StrComparison, + IntComparison, ListComparison, - RefComparison, - ReferenceValue, - MultiComp, ListOp, + MultiComp, NumCompOp, + RefComparison, + ReferenceValue, + StrComparison, StrCompOp, - ExpOp, - BoolExp, TimeRange, ) -from kestrel.ir.graph import ( - IRGraph, - compose, -) +from kestrel.ir.graph import IRGraph, compose from kestrel.ir.instructions import ( + Analytic, + AnalyticsInterface, Construct, DataSource, + Explain, Filter, Instruction, Limit, @@ -47,10 +43,12 @@ Return, Sort, Variable, - Explain, ) -from kestrel.exceptions import IRGraphMissingNode - +from kestrel.mapping.data_model import ( + translate_comparison_to_ocsf, + translate_projection_to_ocsf, +) +from kestrel.utils import unescape_quoted_string _logger = logging.getLogger(__name__) @@ -202,6 +200,7 @@ def __init__( self.token_prefix = token_prefix self.entity_map = entity_map self.property_map = property_map # TODO: rename to data_model_map? + self.variable_map = {} # To cache var type info super().__init__() def start(self, args): @@ -215,6 +214,7 @@ def assignment(self, args): graph, root = args[1] entity_type, native_type = self._get_type_from_predecessors(graph, root) variable_node = Variable(args[0].value, entity_type, native_type) + self.variable_map[args[0].value] = (entity_type, native_type) graph.add_node(variable_node, root) return graph @@ -283,6 +283,9 @@ def json_value(self, args): v = float(v) if "." in v else int(v) return v + def variables(self, args): + return [Reference(arg.value) for arg in args] + def get(self, args): graph = IRGraph() entity_name = args[0].value @@ -313,6 +316,23 @@ def get(self, args): root = graph.add_node(arg, projection_node) return graph, root + def apply(self, args): + scheme, analytic_name = args[0] + refvar = args[1][0] # TODO - this is a list of refs? + params = args[2] if len(args) > 2 else {} + vds = AnalyticsInterface(interface=scheme) + analytic = Analytic(name=analytic_name, params=params) + _logger.debug("apply: analytic: %s", analytic) + graph = IRGraph() + graph.add_node(refvar) + graph.add_node(analytic, refvar) + graph.add_node(vds) + graph.add_edge(vds, analytic) + entity_type, native_type = self.variable_map.get(refvar.name) + variable_node = Variable(refvar.name, entity_type, native_type) + graph.add_node(variable_node, analytic) + return graph + def where_clause(self, args): exp = args[0] return Filter(exp) @@ -340,6 +360,17 @@ def comparison_std(self, args): comp = _create_comp(field, op, value) return comp + def args(self, args): + return dict(args) + + def arg_kv_pair(self, args): + name = args[0].value + if isinstance(args[1], ReferenceValue): + value = args[1].reference + else: + value = args[1] # Should be int or float? + return (name, value) + def op(self, args): """Convert operator token to a plain string""" return " ".join([arg.upper() for arg in args]) @@ -377,6 +408,11 @@ def literal(self, args): def datasource(self, args): return DataSource(args[0].value) + def analytics_uri(self, args): + scheme, _, analytic = args[0].value.partition("://") + _logger.debug("analytics_uri: %s %s", scheme, analytic) + return scheme, analytic + # Timespans def timespan_relative(self, args): num = int(args[0]) diff --git a/packages/kestrel_core/src/kestrel/frontend/parser.py b/packages/kestrel_core/src/kestrel/frontend/parser.py index 015a1cc7..dcc553a4 100644 --- a/packages/kestrel_core/src/kestrel/frontend/parser.py +++ b/packages/kestrel_core/src/kestrel/frontend/parser.py @@ -3,13 +3,13 @@ import logging from itertools import chain -from kestrel.frontend.compile import _KestrelT -from kestrel.mapping.data_model import reverse_mapping -from kestrel.utils import load_data_file, list_folder_files +import yaml from lark import Lark from typeguard import typechecked -import yaml +from kestrel.frontend.compile import _KestrelT +from kestrel.mapping.data_model import reverse_mapping +from kestrel.utils import list_folder_files, load_data_file _logger = logging.getLogger(__name__) diff --git a/packages/kestrel_core/src/kestrel/interface/base.py b/packages/kestrel_core/src/kestrel/interface/base.py index 50f5601f..57c87d70 100644 --- a/packages/kestrel_core/src/kestrel/interface/base.py +++ b/packages/kestrel_core/src/kestrel/interface/base.py @@ -1,21 +1,14 @@ import json from abc import ABC, abstractmethod -from pandas import DataFrame +from typing import Any, Iterable, Mapping, MutableMapping, Optional from uuid import UUID -from typing import ( - Mapping, - MutableMapping, - Optional, - Iterable, -) + +from pandas import DataFrame from kestrel.display import GraphletExplanation -from kestrel.ir.instructions import Instruction +from kestrel.exceptions import InvalidSerializedDatasourceInterfaceCacheCatalog from kestrel.ir.graph import IRGraphEvaluable -from kestrel.exceptions import ( - InvalidSerializedDatasourceInterfaceCacheCatalog, -) - +from kestrel.ir.instructions import Instruction MODULE_PREFIX = "kestrel_interface_" @@ -93,6 +86,7 @@ def store( def evaluate_graph( self, graph: IRGraphEvaluable, + cache: MutableMapping[UUID, Any], instructions_to_evaluate: Optional[Iterable[Instruction]] = None, ) -> Mapping[UUID, DataFrame]: """Evaluate the IRGraph diff --git a/packages/kestrel_core/src/kestrel/interface/codegen/dataframe.py b/packages/kestrel_core/src/kestrel/interface/codegen/dataframe.py index 21ed706e..88ff0531 100644 --- a/packages/kestrel_core/src/kestrel/interface/codegen/dataframe.py +++ b/packages/kestrel_core/src/kestrel/interface/codegen/dataframe.py @@ -1,29 +1,30 @@ -import sys +import functools import inspect -import re import operator -import functools -from typeguard import typechecked -from pandas import DataFrame, Series +import re +import sys from typing import Callable +from pandas import DataFrame, Series +from typeguard import typechecked + +from kestrel.ir.filter import ( + BoolExp, + ExpOp, + FExpression, + ListOp, + MultiComp, + NumCompOp, + StrCompOp, +) from kestrel.ir.instructions import ( - SourceInstruction, - TransformingInstruction, Construct, + Filter, Limit, ProjectAttrs, ProjectEntity, - Filter, -) -from kestrel.ir.filter import ( - FExpression, - BoolExp, - MultiComp, - StrCompOp, - NumCompOp, - ExpOp, - ListOp, + SourceInstruction, + TransformingInstruction, ) diff --git a/packages/kestrel_core/src/kestrel/interface/codegen/sql.py b/packages/kestrel_core/src/kestrel/interface/codegen/sql.py index fd20943d..8496f8eb 100644 --- a/packages/kestrel_core/src/kestrel/interface/codegen/sql.py +++ b/packages/kestrel_core/src/kestrel/interface/codegen/sql.py @@ -2,7 +2,7 @@ from functools import reduce from typing import Callable -from sqlalchemy import and_, column, or_, select, FromClause, asc, desc +from sqlalchemy import FromClause, and_, asc, column, desc, or_, select from sqlalchemy.engine import Compiled, default from sqlalchemy.sql.elements import BinaryExpression, BooleanClauseList from sqlalchemy.sql.expression import ColumnClause, ColumnOperators @@ -30,7 +30,6 @@ SortDirection, ) - _logger = logging.getLogger(__name__) # SQLAlchemy comparison operator functions diff --git a/packages/kestrel_core/src/kestrel/interface/manager.py b/packages/kestrel_core/src/kestrel/interface/manager.py index b5fd0904..7775b326 100644 --- a/packages/kestrel_core/src/kestrel/interface/manager.py +++ b/packages/kestrel_core/src/kestrel/interface/manager.py @@ -1,23 +1,24 @@ from __future__ import annotations + import importlib -import pkgutil -import logging import inspect -import sys import itertools +import logging +import pkgutil +import sys from copy import copy +from typing import Iterable, Mapping, Type + from typeguard import typechecked -from typing import Mapping, Iterable, Type +from kestrel.config.internal import CACHE_INTERFACE_IDENTIFIER from kestrel.exceptions import ( + ConflictingInterfaceScheme, InterfaceNotConfigured, InterfaceNotFound, InvalidInterfaceImplementation, - ConflictingInterfaceScheme, ) from kestrel.interface.base import MODULE_PREFIX, AbstractInterface -from kestrel.config.internal import CACHE_INTERFACE_IDENTIFIER - _logger = logging.getLogger(__name__) diff --git a/packages/kestrel_core/src/kestrel/ir/filter.py b/packages/kestrel_core/src/kestrel/ir/filter.py index ebdd6856..b65c545a 100644 --- a/packages/kestrel_core/src/kestrel/ir/filter.py +++ b/packages/kestrel_core/src/kestrel/ir/filter.py @@ -1,12 +1,12 @@ from __future__ import annotations -from typeguard import typechecked from dataclasses import dataclass from datetime import datetime from enum import Enum -from typing import List, Optional, Union, Iterable, Any, Callable +from typing import Any, Callable, Iterable, List, Optional, Union from mashumaro.mixins.json import DataClassJSONMixin +from typeguard import typechecked class NumCompOp(str, Enum): diff --git a/packages/kestrel_core/src/kestrel/ir/graph.py b/packages/kestrel_core/src/kestrel/ir/graph.py index 725b90df..bcaa518e 100644 --- a/packages/kestrel_core/src/kestrel/ir/graph.py +++ b/packages/kestrel_core/src/kestrel/ir/graph.py @@ -1,46 +1,54 @@ from __future__ import annotations -from typeguard import typechecked -from typing import Any, Iterable, Tuple, Mapping, MutableMapping, Union, Optional + +import json +import logging from collections import defaultdict from itertools import combinations +from typing import Any, Iterable, Mapping, MutableMapping, Optional, Tuple, Union from uuid import UUID + import networkx -import json -from kestrel.ir.instructions import ( - Instruction, - TransformingInstruction, - SolePredecessorTransformingInstruction, - IntermediateInstruction, - SourceInstruction, - Variable, - DataSource, - Reference, - Return, - Filter, - ProjectAttrs, - instruction_from_dict, -) -from kestrel.ir.filter import ReferenceValue +from typeguard import typechecked + +from kestrel.config.internal import CACHE_INTERFACE_IDENTIFIER from kestrel.exceptions import ( - InstructionNotFound, - InvalidSeralizedGraph, - VariableNotFound, - ReferenceNotFound, + DanglingFilter, + DanglingReferenceInFilter, DataSourceNotFound, - DuplicatedVariable, - DuplicatedReference, DuplicatedDataSource, + DuplicatedReference, + DuplicatedReferenceInFilter, DuplicatedSingletonInstruction, - MultiInterfacesInGraph, - MultiSourcesInGraph, + DuplicatedVariable, InevaluableInstruction, + InstructionNotFound, + InvalidSeralizedGraph, LargerThanOneIndegreeInstruction, - DuplicatedReferenceInFilter, MissingReferenceInFilter, - DanglingReferenceInFilter, - DanglingFilter, + MultiInterfacesInGraph, + MultiSourcesInGraph, + ReferenceNotFound, + VariableNotFound, ) -from kestrel.config.internal import CACHE_INTERFACE_IDENTIFIER +from kestrel.ir.filter import ReferenceValue +from kestrel.ir.instructions import ( + Analytic, + AnalyticsInterface, + DataSource, + Filter, + Instruction, + IntermediateInstruction, + ProjectAttrs, + Reference, + Return, + SolePredecessorTransformingInstruction, + SourceInstruction, + TransformingInstruction, + Variable, + instruction_from_dict, +) + +_logger = logging.getLogger(__name__) @typechecked @@ -405,6 +413,8 @@ def get_trunk_n_branches( elif len(ps) > 1: raise DanglingReferenceInFilter(ps) return ps[0], r2n + elif isinstance(node, (Analytic, AnalyticsInterface)): + return ps[0], {} else: raise NotImplementedError(f"unknown instruction type: {node}") diff --git a/packages/kestrel_core/src/kestrel/ir/instructions.py b/packages/kestrel_core/src/kestrel/ir/instructions.py index 627b7ab0..2144f027 100644 --- a/packages/kestrel_core/src/kestrel/ir/instructions.py +++ b/packages/kestrel_core/src/kestrel/ir/instructions.py @@ -7,7 +7,10 @@ import uuid from dataclasses import InitVar, dataclass, field, fields from enum import Enum -from typing import Any, Iterable, List, Mapping, Optional, Type, Union +from typing import Any, Callable, Iterable, List, Mapping, Optional, Type, Union + +from mashumaro.mixins.json import DataClassJSONMixin +from typeguard import typechecked from kestrel.__future__ import is_python_older_than_minor_version from kestrel.config.internal import CACHE_INTERFACE_IDENTIFIER @@ -23,8 +26,6 @@ get_references_from_exp, resolve_reference_with_function, ) -from mashumaro.mixins.json import DataClassJSONMixin -from typeguard import typechecked # https://stackoverflow.com/questions/70400639/how-do-i-get-python-dataclass-initvar-fields-to-work-with-typing-get-type-hints if is_python_older_than_minor_version(11): @@ -154,6 +155,17 @@ def __post_init__(self, uri: Optional[str], default_interface: Optional[str]): pass +@dataclass(eq=False) +class AnalyticsInterface(SourceInstruction): + interface: str + + +@dataclass(eq=False) +class Analytic(TransformingInstruction): + name: str + params: Mapping[str, Union[str, int, float, bool]] + + @dataclass(eq=False) class Variable(SolePredecessorTransformingInstruction): name: str diff --git a/packages/kestrel_core/src/kestrel/mapping/data_model.py b/packages/kestrel_core/src/kestrel/mapping/data_model.py index 45e7b0a1..a4651117 100644 --- a/packages/kestrel_core/src/kestrel/mapping/data_model.py +++ b/packages/kestrel_core/src/kestrel/mapping/data_model.py @@ -9,12 +9,9 @@ from pandas import DataFrame from typeguard import typechecked -from kestrel.mapping.transformers import ( - run_transformer, - run_transformer_on_series, -) -from kestrel.utils import list_folder_files from kestrel.exceptions import IncompleteDataMapping +from kestrel.mapping.transformers import run_transformer, run_transformer_on_series +from kestrel.utils import list_folder_files _logger = logging.getLogger(__name__) diff --git a/packages/kestrel_core/src/kestrel/mapping/transformers.py b/packages/kestrel_core/src/kestrel/mapping/transformers.py index 3f8ea8c5..8dc55f10 100644 --- a/packages/kestrel_core/src/kestrel/mapping/transformers.py +++ b/packages/kestrel_core/src/kestrel/mapping/transformers.py @@ -7,7 +7,6 @@ from kestrel.mapping.path import Path - # Dict of "registered" transformers _transformers = {} diff --git a/packages/kestrel_core/src/kestrel/session.py b/packages/kestrel_core/src/kestrel/session.py index 9196e94a..bb39dcf3 100644 --- a/packages/kestrel_core/src/kestrel/session.py +++ b/packages/kestrel_core/src/kestrel/session.py @@ -1,18 +1,19 @@ import logging from contextlib import AbstractContextManager -from uuid import uuid4 from typing import Iterable +from uuid import uuid4 + from typeguard import typechecked -from kestrel.display import Display, GraphExplanation -from kestrel.ir.graph import IRGraph -from kestrel.ir.instructions import Instruction, Explain -from kestrel.frontend.parser import parse_kestrel +from kestrel.analytics import PythonAnalyticsInterface from kestrel.cache import SqlCache from kestrel.config.internal import CACHE_INTERFACE_IDENTIFIER -from kestrel.interface import InterfaceManager +from kestrel.display import Display, GraphExplanation from kestrel.exceptions import InstructionNotFound - +from kestrel.frontend.parser import parse_kestrel +from kestrel.interface import InterfaceManager +from kestrel.ir.graph import IRGraph +from kestrel.ir.instructions import Explain, Instruction _logger = logging.getLogger(__name__) @@ -27,7 +28,11 @@ def __init__(self): # load all interfaces; cache is a special interface cache = SqlCache() - self.interface_manager = InterfaceManager([cache]) + + # Python analytics are "built-in" + pyanalytics = PythonAnalyticsInterface() + + self.interface_manager = InterfaceManager([cache, pyanalytics]) def execute(self, huntflow_block: str) -> Iterable[Display]: """Execute a Kestrel huntflow block. @@ -90,16 +95,18 @@ def evaluate_instruction(self, ins: Instruction) -> Display: while True: for g in self.irgraph.find_dependent_subgraphs_of_node(ins, _cache): interface = _interface_manager[g.interface] + _logger.debug("eval: subgraph: %s", [i.instruction for i in g.nodes()]) + _logger.debug("eval: interface = %s", interface) for iid, _display in ( interface.explain_graph(g) if is_explain - else interface.evaluate_graph(g) + else interface.evaluate_graph(g, _cache) ).items(): if is_explain: display.graphlets.append(_display) + _cache[iid] = True # virtual cache; value type does not matter else: display = _display - if interface is not _cache: _cache[iid] = display if iid == ins.id: return display diff --git a/packages/kestrel_core/src/kestrel/utils.py b/packages/kestrel_core/src/kestrel/utils.py index 02cbb5b3..d834d5a0 100644 --- a/packages/kestrel_core/src/kestrel/utils.py +++ b/packages/kestrel_core/src/kestrel/utils.py @@ -1,11 +1,13 @@ import collections.abc -from importlib import resources -from kestrel.__future__ import is_python_older_than_minor_version import os +from importlib import resources from pathlib import Path from pkgutil import get_data +from typing import Iterable, Mapping, Optional + from typeguard import typechecked -from typing import Optional, Mapping, Iterable + +from kestrel.__future__ import is_python_older_than_minor_version @typechecked diff --git a/packages/kestrel_core/tests/conftest.py b/packages/kestrel_core/tests/conftest.py new file mode 100644 index 00000000..8b81dac9 --- /dev/null +++ b/packages/kestrel_core/tests/conftest.py @@ -0,0 +1,10 @@ +from pathlib import Path + +import pytest + + +@pytest.fixture(autouse=True) +def run_before_and_after_tests(tmpdir): + # Setup: remove any old DB + Path("cache.db").unlink(missing_ok=True) + yield # this is where the testing happens diff --git a/packages/kestrel_core/tests/test_analytic.py b/packages/kestrel_core/tests/test_analytic.py new file mode 100644 index 00000000..aa840572 --- /dev/null +++ b/packages/kestrel_core/tests/test_analytic.py @@ -0,0 +1,25 @@ +import logging +import os + +from pandas import DataFrame + +_logger = logging.getLogger(__name__) + + +def do_something(df: DataFrame, **kwargs): + _logger.debug("python analytics: run pseudo-analytic") + for k, v in kwargs.items(): + df[k] = v + return df + + +def do_something_no_annotations(df): + _logger.debug("python analytics: run pseudo-analytic with env vars") + name = os.environ.get("name", "new_column") + value = int(os.environ.get("value", 0)) + df[name] = value + return df + + +def do_something_env(df: DataFrame): + return do_something_no_annotations(df) diff --git a/packages/kestrel_core/tests/test_cache_inmemory.py b/packages/kestrel_core/tests/test_cache_inmemory.py index 7d0b84bc..0cc8d07f 100644 --- a/packages/kestrel_core/tests/test_cache_inmemory.py +++ b/packages/kestrel_core/tests/test_cache_inmemory.py @@ -40,7 +40,7 @@ def test_eval_new_filter_disp(): """ graph = IRGraphEvaluable(parse_kestrel(stmt)) c = InMemoryCache() - mapping = c.evaluate_graph(graph) + mapping = c.evaluate_graph(graph, c) # check the return is correct rets = graph.get_returns() @@ -77,7 +77,7 @@ def test_eval_filter_with_ref(): """ graph = IRGraphEvaluable(parse_kestrel(stmt)) c = InMemoryCache() - mapping = c.evaluate_graph(graph) + mapping = c.evaluate_graph(graph, c) # check the return is correct rets = graph.get_returns() @@ -96,7 +96,7 @@ def test_get_virtual_copy(): """ graph = IRGraphEvaluable(parse_kestrel(stmt)) c = InMemoryCache() - mapping = c.evaluate_graph(graph) + mapping = c.evaluate_graph(graph, c) v = c.get_virtual_copy() new_entry = uuid4() v[new_entry] = True diff --git a/packages/kestrel_core/tests/test_cache_sqlite.py b/packages/kestrel_core/tests/test_cache_sqlite.py index 9b134256..d9280d07 100644 --- a/packages/kestrel_core/tests/test_cache_sqlite.py +++ b/packages/kestrel_core/tests/test_cache_sqlite.py @@ -39,7 +39,7 @@ def test_eval_new_disp(): """ graph = IRGraphEvaluable(parse_kestrel(stmt)) c = SqlCache() - mapping = c.evaluate_graph(graph) + mapping = c.evaluate_graph(graph, c) # check the return is correct rets = graph.get_returns() @@ -64,7 +64,7 @@ def test_eval_new_filter_disp(): """ graph = IRGraphEvaluable(parse_kestrel(stmt)) c = SqlCache() - mapping = c.evaluate_graph(graph) + mapping = c.evaluate_graph(graph, c) # check the return is correct rets = graph.get_returns() @@ -93,7 +93,7 @@ def test_eval_two_returns(): # first DISP gs = graph.find_dependent_subgraphs_of_node(rets[0], c) assert len(gs) == 1 - mapping = c.evaluate_graph(gs[0]) + mapping = c.evaluate_graph(gs[0], c) df1 = DataFrame([ {"name": "explorer.exe", "pid": 99} , {"name": "firefox.exe", "pid": 201} , {"name": "chrome.exe", "pid": 205} @@ -104,7 +104,7 @@ def test_eval_two_returns(): # second DISP gs = graph.find_dependent_subgraphs_of_node(rets[1], c) assert len(gs) == 1 - mapping = c.evaluate_graph(gs[0]) + mapping = c.evaluate_graph(gs[0], c) df2 = DataFrame([ {"pid": 99} , {"pid": 201} , {"pid": 205} @@ -127,7 +127,7 @@ def test_issue_446(): """ graph = IRGraphEvaluable(parse_kestrel(stmt)) c = SqlCache() - _ = c.evaluate_graph(graph) + _ = c.evaluate_graph(graph, c) def test_eval_filter_with_ref(): @@ -144,7 +144,7 @@ def test_eval_filter_with_ref(): """ graph = IRGraphEvaluable(parse_kestrel(stmt)) c = SqlCache() - mapping = c.evaluate_graph(graph) + mapping = c.evaluate_graph(graph, c) # check the return is correct rets = graph.get_returns() @@ -163,7 +163,7 @@ def test_get_virtual_copy(): """ graph = IRGraphEvaluable(parse_kestrel(stmt)) c = SqlCache() - mapping = c.evaluate_graph(graph) + mapping = c.evaluate_graph(graph, c) v = c.get_virtual_copy() new_entry = uuid4() v[new_entry] = True diff --git a/packages/kestrel_core/tests/test_ir_graph.py b/packages/kestrel_core/tests/test_ir_graph.py index 27a23a06..98df8ef7 100644 --- a/packages/kestrel_core/tests/test_ir_graph.py +++ b/packages/kestrel_core/tests/test_ir_graph.py @@ -301,7 +301,7 @@ def test_find_dependent_subgraphs_of_node(): assert len(gs[1]) == 6 assert Counter(map(type, gs[1].nodes())) == Counter([Filter, Filter, Variable, Variable, ProjectEntity, DataSource]) - c.evaluate_graph(gs[0]) + c.evaluate_graph(gs[0], c) assert p1_projattr.id in c assert p1.id in c assert len(c) == 2 diff --git a/packages/kestrel_core/tests/test_session.py b/packages/kestrel_core/tests/test_session.py index 2448714f..c67bee74 100644 --- a/packages/kestrel_core/tests/test_session.py +++ b/packages/kestrel_core/tests/test_session.py @@ -259,34 +259,94 @@ def schemes(): assert isinstance(disp, GraphExplanation) assert len(disp.graphlets) == 4 + # DISP procs assert len(disp.graphlets[0].graph["nodes"]) == 5 query = disp.graphlets[0].query.statement.replace('"', '') procs = session.irgraph.get_variable("procs") c1 = next(session.irgraph.predecessors(procs)) assert query == f"WITH procs AS \n(SELECT * \nFROM {c1.id.hex}), \np2 AS \n(SELECT * \nFROM procs \nWHERE name IN ('firefox.exe', 'chrome.exe'))\n SELECT pid \nFROM p2" + # DISP nt assert len(disp.graphlets[1].graph["nodes"]) == 2 query = disp.graphlets[1].query.statement.replace('"', '') nt = session.irgraph.get_variable("nt") c2 = next(session.irgraph.predecessors(nt)) assert query == f"WITH nt AS \n(SELECT * \nFROM {c2.id.hex})\n SELECT * \nFROM nt" - # the current session.execute_to_generate() logic does not store - # in cache if evaluated by cache; the behavior may change in the future + # DISP domain assert len(disp.graphlets[2].graph["nodes"]) == 2 query = disp.graphlets[2].query.statement.replace('"', '') domain = session.irgraph.get_variable("domain") c3 = next(session.irgraph.predecessors(domain)) assert query == f"WITH domain AS \n(SELECT * \nFROM {c3.id.hex})\n SELECT * \nFROM domain" - assert len(disp.graphlets[3].graph["nodes"]) == 12 + # EXPLAIN d2 + assert len(disp.graphlets[3].graph["nodes"]) == 11 query = disp.graphlets[3].query.statement.replace('"', '') p2 = session.irgraph.get_variable("p2") p2pa = next(session.irgraph.successors(p2)) - assert query == f"WITH domain AS \n(SELECT * \nFROM {c3.id.hex}), \nntx AS \n(SELECT * \nFROM {nt.id.hex}v \nWHERE pid IN (SELECT * \nFROM {p2pa.id.hex}v)), \nd2 AS \n(SELECT * \nFROM domain \nWHERE ip IN (SELECT destination \nFROM ntx))\n SELECT * \nFROM d2" + assert query == f"WITH ntx AS \n(SELECT * \nFROM {nt.id.hex}v \nWHERE pid IN (SELECT * \nFROM {p2pa.id.hex}v)), \nd2 AS \n(SELECT * \nFROM {domain.id.hex}v \nWHERE ip IN (SELECT destination \nFROM ntx))\n SELECT * \nFROM d2" df_ref = DataFrame([{"ip": "1.1.1.2", "domain": "xyz.cloudflare.com"}]) assert df_ref.equals(df_res) for db_file in extra_db: os.remove(db_file) + + +def test_apply_on_construct(): + hf = """ +proclist = NEW process [ {"name": "cmd.exe", "pid": 123} + , {"name": "explorer.exe", "pid": 99} + , {"name": "firefox.exe", "pid": 201} + , {"name": "chrome.exe", "pid": 205} + ] +APPLY python://something ON proclist WITH foo=abc,bar=1,baz=1.5 +DISP proclist ATTR name, foo, bar, baz +""" + b1 = DataFrame([ {"name": "cmd.exe", "foo": "abc", "bar": 1, "baz": 1.5} + , {"name": "explorer.exe", "foo": "abc", "bar": 1, "baz": 1.5} + , {"name": "firefox.exe", "foo": "abc", "bar": 1, "baz": 1.5} + , {"name": "chrome.exe", "foo": "abc", "bar": 1, "baz": 1.5} + ]) + with Session() as session: + # Add test analytic + test_dir = os.path.dirname(os.path.abspath(__file__)) + session.interface_manager["python"].config["something"] = { + "module": os.path.join(test_dir, "test_analytic.py"), + "func": "do_something" + } + res = session.execute_to_generate(hf) + disp = next(res) + assert b1.equals(disp) + with pytest.raises(StopIteration): + next(res) + + +def test_apply_on_construct_use_env(): + hf = """ +proclist = NEW process [ {"name": "cmd.exe", "pid": 123} + , {"name": "explorer.exe", "pid": 99} + , {"name": "firefox.exe", "pid": 201} + , {"name": "chrome.exe", "pid": 205} + ] +APPLY python://something ON proclist WITH name=foo,value=1 +DISP proclist ATTR name, foo +""" + b1 = DataFrame([ {"name": "cmd.exe", "foo": 1} + , {"name": "explorer.exe", "foo": 1} + , {"name": "firefox.exe", "foo": 1} + , {"name": "chrome.exe", "foo": 1} + ]) + with Session() as session: + # Add test analytic + test_dir = os.path.dirname(os.path.abspath(__file__)) + session.interface_manager["python"].config["something"] = { + "module": os.path.join(test_dir, "test_analytic.py"), + "func": "do_something_env" + } + res = session.execute_to_generate(hf) + disp = next(res) + assert b1.equals(disp) + with pytest.raises(StopIteration): + next(res) diff --git a/packages/kestrel_interface_opensearch/src/kestrel_interface_opensearch/config.py b/packages/kestrel_interface_opensearch/src/kestrel_interface_opensearch/config.py index 185dd7aa..b0ce0732 100644 --- a/packages/kestrel_interface_opensearch/src/kestrel_interface_opensearch/config.py +++ b/packages/kestrel_interface_opensearch/src/kestrel_interface_opensearch/config.py @@ -6,16 +6,15 @@ from kestrel.config.utils import ( CONFIG_DIR_DEFAULT, - load_user_config, load_kestrel_config, + load_user_config, ) from kestrel.exceptions import InterfaceNotConfigured from kestrel.mapping.data_model import ( - load_default_mapping, check_entity_identifier_existence_in_mapping, + load_default_mapping, ) - PROFILE_PATH_DEFAULT = CONFIG_DIR_DEFAULT / "opensearch.yaml" PROFILE_PATH_ENV_VAR = "KESTREL_OPENSEARCH_CONFIG" diff --git a/packages/kestrel_interface_opensearch/src/kestrel_interface_opensearch/interface.py b/packages/kestrel_interface_opensearch/src/kestrel_interface_opensearch/interface.py index 332f00be..8b77d341 100644 --- a/packages/kestrel_interface_opensearch/src/kestrel_interface_opensearch/interface.py +++ b/packages/kestrel_interface_opensearch/src/kestrel_interface_opensearch/interface.py @@ -1,7 +1,9 @@ import logging -from typing import Iterable, Mapping, Optional, Tuple +from typing import Any, Iterable, Mapping, MutableMapping, Optional, Tuple from uuid import UUID +from kestrel_interface_opensearch.config import load_config +from kestrel_interface_opensearch.ossql import OpenSearchTranslator from opensearchpy import OpenSearch from pandas import DataFrame, concat @@ -11,20 +13,16 @@ from kestrel.ir.graph import IRGraphEvaluable from kestrel.ir.instructions import ( DataSource, + Filter, Instruction, Return, - Variable, - Filter, + SolePredecessorTransformingInstruction, SourceInstruction, TransformingInstruction, - SolePredecessorTransformingInstruction, + Variable, ) from kestrel.mapping.data_model import translate_dataframe -from kestrel_interface_opensearch.config import load_config -from kestrel_interface_opensearch.ossql import OpenSearchTranslator - - _logger = logging.getLogger(__name__) @@ -111,6 +109,7 @@ def store( def evaluate_graph( self, graph: IRGraphEvaluable, + cache: MutableMapping[UUID, Any], instructions_to_evaluate: Optional[Iterable[Instruction]] = None, ) -> Mapping[UUID, DataFrame]: mapping = {} diff --git a/packages/kestrel_interface_opensearch/src/kestrel_interface_opensearch/ossql.py b/packages/kestrel_interface_opensearch/src/kestrel_interface_opensearch/ossql.py index fa17e817..0589cdc3 100644 --- a/packages/kestrel_interface_opensearch/src/kestrel_interface_opensearch/ossql.py +++ b/packages/kestrel_interface_opensearch/src/kestrel_interface_opensearch/ossql.py @@ -30,7 +30,6 @@ translate_projection_to_native, ) - _logger = logging.getLogger(__name__) diff --git a/packages/kestrel_interface_sqlalchemy/src/kestrel_interface_sqlalchemy/config.py b/packages/kestrel_interface_sqlalchemy/src/kestrel_interface_sqlalchemy/config.py index 10d84228..1a6e0412 100644 --- a/packages/kestrel_interface_sqlalchemy/src/kestrel_interface_sqlalchemy/config.py +++ b/packages/kestrel_interface_sqlalchemy/src/kestrel_interface_sqlalchemy/config.py @@ -6,16 +6,15 @@ from kestrel.config.utils import ( CONFIG_DIR_DEFAULT, - load_user_config, load_kestrel_config, + load_user_config, ) from kestrel.exceptions import InterfaceNotConfigured from kestrel.mapping.data_model import ( - load_default_mapping, check_entity_identifier_existence_in_mapping, + load_default_mapping, ) - PROFILE_PATH_DEFAULT = CONFIG_DIR_DEFAULT / "sqlalchemy.yaml" PROFILE_PATH_ENV_VAR = "KESTREL_SQLALCHEMY_CONFIG" diff --git a/packages/kestrel_interface_sqlalchemy/src/kestrel_interface_sqlalchemy/interface.py b/packages/kestrel_interface_sqlalchemy/src/kestrel_interface_sqlalchemy/interface.py index 6463551a..eea85ee9 100644 --- a/packages/kestrel_interface_sqlalchemy/src/kestrel_interface_sqlalchemy/interface.py +++ b/packages/kestrel_interface_sqlalchemy/src/kestrel_interface_sqlalchemy/interface.py @@ -1,10 +1,11 @@ import logging from functools import reduce -from typing import Callable, Iterable, Mapping, Optional +from typing import Any, Callable, Iterable, Mapping, MutableMapping, Optional from uuid import UUID -from pandas import DataFrame, read_sql import sqlalchemy +from kestrel_interface_sqlalchemy.config import load_config +from pandas import DataFrame, read_sql from sqlalchemy import column, or_ from sqlalchemy.sql.expression import ColumnClause from typeguard import typechecked @@ -39,9 +40,6 @@ translate_projection_to_native, ) -from kestrel_interface_sqlalchemy.config import load_config - - _logger = logging.getLogger(__name__) @@ -191,6 +189,7 @@ def evaluate_graph( def explain_graph( self, graph: IRGraphEvaluable, + cache: MutableMapping[UUID, Any], instructions_to_explain: Optional[Iterable[Instruction]] = None, ) -> Mapping[UUID, GraphletExplanation]: mapping = {} diff --git a/packages/kestrel_interface_sqlalchemy/tests/test_translate.py b/packages/kestrel_interface_sqlalchemy/tests/test_translate.py new file mode 100644 index 00000000..396a7562 --- /dev/null +++ b/packages/kestrel_interface_sqlalchemy/tests/test_translate.py @@ -0,0 +1,118 @@ +from datetime import datetime +from dateutil import parser + +from kestrel_interface_sqlalchemy.interface import SQLAlchemyTranslator +from kestrel.exceptions import UnsupportedOperatorError +from kestrel.ir.filter import ( + ExpOp, + IntComparison, + ListOp, + ListComparison, + MultiComp, + NumCompOp, + StrCompOp, + StrComparison, + TimeRange +) +from kestrel.ir.instructions import ( + Filter, + Limit, + Offset, + ProjectAttrs, + ProjectEntity, + Sort +) + +# Use sqlite3 for testing +import sqlalchemy + +import pytest + + +ENGINE = sqlalchemy.create_engine("sqlite:///test.db") +DIALECT = ENGINE.dialect +TABLE = sqlalchemy.table("my_table") + + +TIMEFMT = '%Y-%m-%dT%H:%M:%S.%fZ' + + +def timefmt(dt: datetime): + return f"{dt}Z" + + +# A much-simplified test mapping +data_model_map = { + "process": { + "cmd_line": "CommandLine", + "file": { + "path": "Image", + # "name": [ + # { + # "native_field": "Image", + # "native_value": "basename", + # "ocsf_op": "LIKE", + # "ocsf_value": "endswith" + # } + # ] + }, + "pid": "ProcessId", + "parent_process": { + "pid": "ParentProcessId", + }, + }, +} + +def _dt(timestr: str) -> datetime: + return parser.parse(timestr) + + +def _remove_nl(s): + return s.replace('\n', '') + + +@pytest.mark.parametrize( + "iseq, sql", [ + # Try a simple filter + ([Filter(IntComparison('foo', NumCompOp.GE, 0))], + "SELECT {} FROM my_table WHERE foo >= ?"), + # Try a simple filter with sorting + ([Filter(IntComparison('foo', NumCompOp.GE, 0)), Sort('bar')], + "SELECT {} FROM my_table WHERE foo >= ? ORDER BY bar DESC"), + # Simple filter plus time range + ([Filter(IntComparison('foo', NumCompOp.GE, 0), timerange=TimeRange(_dt('2023-12-06T08:17:00Z'), _dt('2023-12-07T08:17:00Z')))], + "SELECT {} FROM my_table WHERE foo >= ? AND timestamp >= ? AND timestamp < ?"), + # Add a limit and projection + ([Limit(3), ProjectAttrs(['foo', 'bar', 'baz']), Filter(StrComparison('foo', StrCompOp.EQ, 'abc'))], + "SELECT foo AS foo, bar AS bar, baz AS baz FROM my_table WHERE foo = ? LIMIT ? OFFSET ?"), + # Same as above but reverse order + ([Filter(StrComparison('foo', StrCompOp.EQ, 'abc')), ProjectAttrs(['foo', 'bar', 'baz']), Limit(3)], + "SELECT foo AS foo, bar AS bar, baz AS baz FROM my_table WHERE foo = ? LIMIT ? OFFSET ?"), + ([Filter(ListComparison('foo', ListOp.NIN, ['abc', 'def']))], + "SELECT {} FROM my_table WHERE (foo NOT IN (__[POSTCOMPILE_foo_1]))"), + ([Filter(StrComparison('foo', StrCompOp.MATCHES, '.*abc.*'))], + "SELECT {} FROM my_table WHERE foo REGEXP ?"), + ([Filter(StrComparison('foo', StrCompOp.NMATCHES, '.*abc.*'))], + "SELECT {} FROM my_table WHERE foo NOT REGEXP ?"), + ([Filter(MultiComp(ExpOp.OR, [IntComparison('foo', NumCompOp.EQ, 1), IntComparison('bar', NumCompOp.EQ, 1)]))], + "SELECT {} FROM my_table WHERE foo = ? OR bar = ?"), + ([Filter(MultiComp(ExpOp.AND, [IntComparison('foo', NumCompOp.EQ, 1), IntComparison('bar', NumCompOp.EQ, 1)]))], + "SELECT {} FROM my_table WHERE foo = ? AND bar = ?"), + ([Limit(1000), Offset(2000)], + "SELECT {} FROM my_table LIMIT ? OFFSET ?"), + # Test entity projection + ([Limit(3), Filter(StrComparison('cmd_line', StrCompOp.EQ, 'foo bar')), ProjectEntity('process', 'process')], + "SELECT {} FROM my_table WHERE \"CommandLine\" = ? LIMIT ? OFFSET ?"), + ] +) +def test_sqlalchemy_translator(iseq, sql): + if ProjectEntity in {type(i) for i in iseq}: + cols = '"CommandLine" AS cmd_line, "Image" AS "file.path", "ProcessId" AS pid, "ParentProcessId" AS "parent_process.pid"' + else: + cols = '"CommandLine" AS "process.cmd_line", "Image" AS "process.file.path", "ProcessId" AS "process.pid", "ParentProcessId" AS "process.parent_process.pid"' + trans = SQLAlchemyTranslator(DIALECT, timefmt, "timestamp", TABLE, data_model_map) + for i in iseq: + trans.add_instruction(i) + #result = trans.result_w_literal_binds() + result = trans.result() + assert _remove_nl(str(result)) == sql.format(cols)