Skip to content

Commit

Permalink
Merge pull request #544 from opencybersecurityalliance/hotfix-538-temp-table
Browse files Browse the repository at this point in the history

implement temp table to fix #538
  • Loading branch information
subbyte authored Jul 17, 2024
2 parents cf822f8 + df2aa90 commit d8b3d72
Show file tree
Hide file tree
Showing 4 changed files with 122 additions and 48 deletions.
23 changes: 22 additions & 1 deletion packages/kestrel_core/src/kestrel/interface/codegen/sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,8 +34,10 @@
translate_comparison_to_native,
translate_projection_to_native,
)
from pandas import DataFrame
from pandas.io.sql import SQLTable, pandasSQL_builder
from sqlalchemy import and_, asc, column, desc, or_, select, tuple_
from sqlalchemy.engine import Compiled, default
from sqlalchemy.engine import Compiled, Connection, default
from sqlalchemy.sql.elements import BinaryExpression, BooleanClauseList
from sqlalchemy.sql.expression import CTE, ColumnElement, ColumnOperators
from sqlalchemy.sql.selectable import Select
Expand All @@ -62,6 +64,25 @@
}


@typechecked
class _TemporaryTable(SQLTable):
    """pandas ``SQLTable`` variant that emits ``CREATE TEMPORARY TABLE``.

    A temporary table is scoped to the DB connection that creates it and is
    dropped automatically when that connection closes, so ingested
    DataFrames leave no persistent tables behind in the data lake.

    NOTE(review): this subclasses pandas-internal machinery
    (``pandas.io.sql.SQLTable``) and touches private attributes
    (``pd_sql``, ``run_transaction``, ``table._prefixes``) — behavior may
    change across pandas versions; confirm against the pinned release.
    """

    def _execute_create(self):
        # Override of SQLTable._execute_create: re-bind the table definition
        # to this connection's MetaData, then inject the TEMPORARY prefix
        # before the CREATE TABLE statement is emitted.
        self.table = self.table.to_metadata(self.pd_sql.meta)
        self.table._prefixes.append("TEMPORARY")
        with self.pd_sql.run_transaction():
            self.table.create(bind=self.pd_sql.con)


@typechecked
def ingest_dataframe_to_temp_table(conn: Connection, df: DataFrame, table_name: str):
    """Write *df* into a connection-scoped temporary table *table_name*.

    The empty table is first created through :class:`_TemporaryTable` so it
    carries the ``TEMPORARY`` prefix; the rows are then bulk-appended with a
    plain ``DataFrame.to_sql``. The table exists only for the lifetime of
    *conn*.
    """
    # Step 1: create the (empty) temporary table via pandas' SQL layer.
    with pandasSQL_builder(conn) as sql_engine:
        tmp_table = _TemporaryTable(
            table_name,
            sql_engine,
            frame=df,
            if_exists="replace",
            index=False,
        )
        tmp_table.create()
    # Step 2: append the DataFrame contents into the freshly created table.
    df.to_sql(table_name, con=conn, index=False, if_exists="append")


@typechecked
class SqlTranslator:
def __init__(
Expand Down
3 changes: 3 additions & 0 deletions packages/kestrel_core/tests/test_cache_sqlite.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from uuid import uuid4

import pytest
import sqlalchemy
from pandas import DataFrame, read_csv

from kestrel.cache import SqlCache
Expand Down Expand Up @@ -35,7 +36,9 @@ def test_sql_cache_set_get_del():
idx = uuid4()
df = DataFrame({'foo': [1, 2, 3]})
c[idx] = df

assert df.equals(c[idx])

del c[idx]
assert idx not in c

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from kestrel.display import GraphletExplanation, NativeQuery
from kestrel.exceptions import SourceNotFound
from kestrel.interface import AbstractInterface
from kestrel.interface.codegen.sql import ingest_dataframe_to_temp_table
from kestrel.ir.graph import IRGraphEvaluable
from kestrel.ir.instructions import (
DataSource,
Expand Down Expand Up @@ -64,6 +65,10 @@ def store(
):
raise NotImplementedError("SQLAlchemyInterface.store") # TEMP

def __del__(self):
    """Best-effort cleanup: close all DB connections on garbage collection.

    ``getattr`` guards against a partially-initialized instance (e.g.
    ``__init__`` raised before ``self.conns`` was assigned), which would
    otherwise turn finalization into an ``AttributeError``. Each close is
    wrapped individually so one failing connection does not prevent the
    remaining connections from being released; exceptions escaping
    ``__del__`` are unraisable and only produce noise at interpreter
    shutdown.
    """
    for conn in getattr(self, "conns", {}).values():
        try:
            conn.close()
        except Exception:
            # Deliberate swallow: finalizers must not raise.
            pass

def evaluate_graph(
self,
graph: IRGraphEvaluable,
Expand Down Expand Up @@ -138,55 +143,58 @@ def _evaluate_instruction_in_graph(
subquery_memory = {}

if instruction.id in cache:
# 1. get the datasource assocaited with the cached node
ds = None
for node in iter_argument_from_function_in_callstack(
"_evaluate_instruction_in_graph", "instruction"
):
try:
ds = graph.find_datasource_of_node(node)
except SourceNotFound:
continue
else:
break
if not ds:
_logger.error(
"backed tracked entire stack but still do not find source"
)
raise SourceNotFound(instruction)

# 2. check the datasource config to see if the datalake supports write
ds_config = self.config.datasources[ds.datasource]
conn_config = self.config.connections[ds_config.connection]

# 3. perform table creation or in-memory cache
if conn_config.table_creation_permission:
table_name = "kestrel_temp_" + instruction.id.hex

# create a new table for the cached DataFrame
cache[instruction.id].to_sql(
table_name,
con=self.conns[ds_config.connection],
if_exists="replace",
index=False,
)
if instruction.id in subquery_memory:
translator = subquery_memory[instruction.id]
else:
# 1. get the datasource assocaited with the cached node
ds = None
for node in iter_argument_from_function_in_callstack(
"_evaluate_instruction_in_graph", "instruction"
):
try:
ds = graph.find_datasource_of_node(node)
except SourceNotFound:
continue
else:
break
if not ds:
_logger.error(
"backed tracked entire stack but still do not find source"
)
raise SourceNotFound(instruction)

# SELECT * from the new table
translator = SQLAlchemyTranslator(
NativeTable(
self.engines[ds_config.connection].dialect,
# 2. check the datasource config to see if the datalake supports write
ds_config = self.config.datasources[ds.datasource]
conn_config = self.config.connections[ds_config.connection]

# 3. perform table creation or in-memory cache
if conn_config.table_creation_permission:
table_name = instruction.id.hex

# write to temp table
ingest_dataframe_to_temp_table(
self.conns[ds_config.connection],
cache[instruction.id],
table_name,
ds_config,
list(cache[instruction.id]),
None,
None,
None,
)
)

else:
raise NotImplementedError("Read-only data lake not handled")
# list(cache[instruction.id].itertuples(index=False, name=None))
# SELECT * from the new table
translator = SQLAlchemyTranslator(
NativeTable(
self.engines[ds_config.connection].dialect,
table_name,
ds_config,
list(cache[instruction.id]),
None,
None,
None,
)
)
subquery_memory[instruction.id] = translator

else:
raise NotImplementedError("Read-only data lake not handled")
# list(cache[instruction.id].itertuples(index=False, name=None))

if isinstance(instruction, SourceInstruction):
if isinstance(instruction, DataSource):
Expand Down
46 changes: 44 additions & 2 deletions packages/kestrel_interface_sqlalchemy/tests/test_interface.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,14 @@
import os
from uuid import uuid4
import sqlite3
from collections import Counter

import pytest
from pandas import DataFrame
import yaml
from kestrel_interface_sqlalchemy.config import PROFILE_PATH_ENV_VAR
from pandas import read_csv

from kestrel.interface.codegen.sql import ingest_dataframe_to_temp_table
from pandas import read_csv, DataFrame, read_sql
from kestrel import Session
from kestrel.ir.filter import MultiComp
from kestrel.ir.instructions import DataSource, Filter, ProjectEntity, Variable
Expand Down Expand Up @@ -45,6 +47,21 @@ def setup_sqlite_ecs_process_creation(tmp_path):
del os.environ[PROFILE_PATH_ENV_VAR]


def test_write_to_temp_table(setup_sqlite_ecs_process_creation):
    """Round-trip a DataFrame through a temp table and verify its scope.

    The data must be readable on the connection that created the temp
    table, and must be gone once that connection is closed (TEMPORARY
    tables are connection-scoped).
    """
    with Session() as session:
        datalake = session.interface_manager["sqlalchemy"]
        idx = uuid4().hex
        df = DataFrame({'foo': [1, 2, 3]})
        conn_name = list(datalake.conns.keys())[0]
        conn = datalake.conns[conn_name]
        ingest_dataframe_to_temp_table(conn, df, idx)
        # Same connection: ingested rows round-trip intact.
        assert read_sql(f'SELECT * FROM "{idx}"', conn).equals(df)
        conn.close()
        # Fresh connection: the temp table's data must no longer be visible.
        conn = datalake.engines[conn_name].connect()
        try:
            assert read_sql(f'SELECT * FROM "{idx}"', conn).empty
        finally:
            # Fix: close the verification connection instead of leaking it.
            conn.close()



@pytest.mark.parametrize(
"where, ocsf_field", [
("name = 'bash'", "process.name"),
Expand Down Expand Up @@ -175,6 +192,19 @@ def test_find_entity_to_event(setup_sqlite_ecs_process_creation):
assert e2.shape[1] == 74 # full event: refer to test_get_sinple_event() for number


def test_find_entity_to_event_2(setup_sqlite_ecs_process_creation):
    # FIND event ORIGINATED BY on a variable built from a WHERE on an
    # entity attribute (os.name); per this commit, evaluating the cached
    # `procs` variable goes through the new temp-table ingestion path.
    with Session() as session:
        huntflow = """
procs = GET process FROM sqlalchemy://events WHERE os.name = "Linux"
e2 = FIND event ORIGINATED BY procs
DISP e2
"""
        e2 = session.execute(huntflow)[0]
        # Expected row count/names come from the fixture dataset loaded by
        # setup_sqlite_ecs_process_creation — confirm against that CSV.
        assert e2.shape[0] == 4
        assert list(e2["process.name"]) == ["uname", "cat", "ping", "curl"]
        assert e2.shape[1] == 74  # full event: refer to test_get_sinple_event() for number


def test_find_entity_to_entity(setup_sqlite_ecs_process_creation):
with Session() as session:
huntflow = """
Expand All @@ -195,3 +225,15 @@ def test_find_entity_to_entity(setup_sqlite_ecs_process_creation):

assert parents.shape[0] == 2
assert list(parents) == ['endpoint.uid', 'file.endpoint.uid', 'user.endpoint.uid', 'endpoint.name', 'file.endpoint.name', 'user.endpoint.name', 'endpoint.os', 'file.endpoint.os', 'user.endpoint.os', 'cmd_line', 'name', 'pid', 'uid']


def test_find_entity_to_entity_2(setup_sqlite_ecs_process_creation):
    # FIND process CREATED (parent lookup) starting from a WHERE on an
    # entity attribute (os.name); exercises the cached-variable temp-table
    # path like test_find_entity_to_event_2 but with an entity projection.
    with Session() as session:
        huntflow = """
procs = GET process FROM sqlalchemy://events WHERE os.name = "Linux"
parents = FIND process CREATED procs
DISP parents
"""
        parents = session.execute(huntflow)[0]
        # Two distinct parent processes expected from the fixture dataset;
        # the column list is the OCSF-mapped process entity projection.
        assert parents.shape[0] == 2
        assert list(parents) == ['endpoint.uid', 'file.endpoint.uid', 'user.endpoint.uid', 'endpoint.name', 'file.endpoint.name', 'user.endpoint.name', 'endpoint.os', 'file.endpoint.os', 'user.endpoint.os', 'cmd_line', 'name', 'pid', 'uid']

0 comments on commit d8b3d72

Please sign in to comment.