Skip to content

Commit

Permalink
Merge pull request #533 from opencybersecurityalliance/hotfix-postpro…
Browse files Browse the repository at this point in the history
…cess-entity-type

Upgrade Projection Entity
  • Loading branch information
subbyte authored Jul 3, 2024
2 parents f0b8e6a + e573adc commit 1251747
Show file tree
Hide file tree
Showing 17 changed files with 319 additions and 221 deletions.
26 changes: 26 additions & 0 deletions .github/workflows/unit-testing.yml
Original file line number Diff line number Diff line change
Expand Up @@ -65,3 +65,29 @@ jobs:
run: pip install .
- name: Unit testing
run: pytest -vv

test-kestrel-interface-sqlalchemy:
strategy:
matrix:
os: [ubuntu-latest, macos-latest]
python-version: ['3.8', '3.9', '3.10', '3.11', '3.12']
runs-on: ${{ matrix.os }}
defaults:
run:
shell: bash
working-directory: ./packages/kestrel_interface_sqlalchemy
steps:
- uses: actions/checkout@v4
- name: Set up Python
uses: actions/setup-python@v5
with:
python-version: ${{ matrix.python-version }}
- name: Install Python Tools
run: pip install --upgrade pip setuptools wheel pytest
- name: Install kestrel_core
working-directory: ./packages/kestrel_core
run: pip install .
- name: Install kestrel_interface_sqlalchemy
run: pip install .
- name: Unit testing
run: pytest -vv
1 change: 0 additions & 1 deletion packages/kestrel_core/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,6 @@ dependencies = [
"mashumaro>=3.10",
"networkx>=3.1", # networkx==3.2.1 only for Python>=3.9
"SQLAlchemy>=2.0.23",
"dpath>=2.1.6",
]

[project.optional-dependencies]
Expand Down
67 changes: 42 additions & 25 deletions packages/kestrel_core/src/kestrel/frontend/compile.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,8 @@
)
from kestrel.mapping.data_model import (
translate_comparison_to_ocsf,
translate_projection_to_ocsf,
translate_entity_projection_to_ocsf,
translate_attributes_projection_to_ocsf,
)
from kestrel.utils import unescape_quoted_string

Expand Down Expand Up @@ -100,10 +101,10 @@ def _create_comp(field: str, op_value: str, value) -> FComparison:

@typechecked
def _map_filter_exp(
native_entity_name: str,
mapped_entity_name: str,
native_projection_field: str,
ocsf_projection_field: str,
filter_exp: FExpression,
property_map: dict,
field_map: dict,
) -> FExpression:
if isinstance(
filter_exp,
Expand All @@ -114,21 +115,21 @@ def _map_filter_exp(
# init map_result from direct mapping from field
map_result = set(
translate_comparison_to_ocsf(
property_map, field, filter_exp.op, filter_exp.value
field_map, field, filter_exp.op, filter_exp.value
)
)
# there is a case that `field` omits the return entity (prefix)
# this is only alloed when it refers to the return entity
# add mapping for those cases
for full_field in (
f"{native_entity_name}:{field}",
f"{native_entity_name}.{field}",
f"{native_projection_field}:{field}",
f"{native_projection_field}.{field}",
):
map_result |= set(
filter(
lambda x: x[0].startswith(mapped_entity_name + "."),
lambda x: x[0].startswith(ocsf_projection_field + "."),
translate_comparison_to_ocsf(
property_map, full_field, filter_exp.op, filter_exp.value
field_map, full_field, filter_exp.op, filter_exp.value
),
)
)
Expand All @@ -154,11 +155,17 @@ def _map_filter_exp(
# recursively map boolean expressions
filter_exp = BoolExp(
_map_filter_exp(
native_entity_name, mapped_entity_name, filter_exp.lhs, property_map
native_projection_field,
ocsf_projection_field,
filter_exp.lhs,
field_map,
),
filter_exp.op,
_map_filter_exp(
native_entity_name, mapped_entity_name, filter_exp.rhs, property_map
native_projection_field,
ocsf_projection_field,
filter_exp.rhs,
field_map,
),
)
elif isinstance(filter_exp, MultiComp):
Expand All @@ -169,7 +176,9 @@ def _map_filter_exp(
filter_exp = MultiComp(
filter_exp.op,
[
_map_filter_exp(native_entity_name, mapped_entity_name, x, property_map)
_map_filter_exp(
native_projection_field, ocsf_projection_field, x, field_map
)
for x in filter_exp.comps
],
)
Expand All @@ -192,14 +201,14 @@ def __init__(
self,
default_sort_order=DEFAULT_SORT_ORDER,
token_prefix="",
entity_map={},
property_map={},
type_map={},
field_map={},
):
# token_prefix is the modification by Lark when using `merge_transformers()`
self.default_sort_order = default_sort_order
self.token_prefix = token_prefix
self.entity_map = entity_map
self.property_map = property_map # TODO: rename to data_model_map?
self.type_map = type_map
self.field_map = field_map
self.variable_map = {} # To cache var type info
super().__init__()

Expand Down Expand Up @@ -288,13 +297,18 @@ def variables(self, args):

def get(self, args):
graph = IRGraph()
entity_name = args[0].value
mapped_entity_name = self.entity_map.get(entity_name, entity_name)
native_projection_field = args[0].value
ocsf_projection_field = translate_entity_projection_to_ocsf(
self.field_map, native_projection_field
)

# prepare Filter node
filter_node = args[2]
filter_node.exp = _map_filter_exp(
args[0].value, mapped_entity_name, filter_node.exp, self.property_map
native_projection_field,
ocsf_projection_field,
filter_node.exp,
self.field_map,
)

# add basic Source and Filter nodes
Expand All @@ -305,7 +319,7 @@ def get(self, args):
_add_reference_branches_for_filter(graph, filter_node)

projection_node = graph.add_node(
ProjectEntity(mapped_entity_name, entity_name), filter_node
ProjectEntity(ocsf_projection_field, native_projection_field), filter_node
)
root = projection_node
if len(args) > 3:
Expand Down Expand Up @@ -467,8 +481,8 @@ def disp(self, args):
_logger.debug(
"Map %s attrs to OCSF %s in %s", native_type, entity_type, root
)
root.attrs = translate_projection_to_ocsf(
self.property_map, native_type, entity_type, root.attrs
root.attrs = translate_attributes_projection_to_ocsf(
self.field_map, native_type, entity_type, root.attrs
)
graph.add_node(Return(), root)
return graph
Expand All @@ -488,10 +502,13 @@ def _get_type_from_predecessors(self, graph: IRGraph, root: Instruction):
curr = stack.pop()
_logger.debug("_get_type: curr = %s", curr)
stack.extend(graph.predecessors(curr))
if isinstance(curr, Construct):
native_type = curr.entity_type
entity_type = self.entity_map.get(native_type, native_type)
if isinstance(curr, ProjectEntity):
native_type = curr.native_field
entity_type = self.type_map.get(curr.ocsf_field, curr.ocsf_field)
elif isinstance(curr, Variable):
native_type = curr.native_type
entity_type = curr.entity_type
elif isinstance(curr, Construct):
native_type = curr.entity_type
entity_type = self.type_map.get(native_type, native_type)
return entity_type, native_type
5 changes: 4 additions & 1 deletion packages/kestrel_core/src/kestrel/frontend/kestrel.lark
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ assignment: VARIABLE "=" expression

find: "FIND"i ENTITY_TYPE RELATION (REVERSED)? VARIABLE where_clause? timespan? limit_clause?

get: "GET"i ENTITY_TYPE ("FROM"i datasource)? where_clause timespan? limit_clause?
get: "GET"i PROJECT_FIELD ("FROM"i datasource)? where_clause timespan? limit_clause?

group: "GROUP"i VARIABLE BY grp_spec ("WITH"i agg_list)?

Expand Down Expand Up @@ -282,6 +282,9 @@ ATTRIBUTES: ATTRIBUTE (WS* "," WS* ATTRIBUTE)*
ECNAME: (LETTER|"_") (LETTER|DIGIT|"_"|"-")*
ECNAME_W_QUOTE: (LETTER|DIGIT|"_"|"-"|"'")+

// extend ECNAME with "." and ":"
PROJECT_FIELD: (LETTER|"_") (LETTER|DIGIT|"_"|"-"|"."|":")*

PATH_SIMPLE: (ECNAME "://")? (LETTER|DIGIT|"_"|"-"|"."|"/")+

PATH_ESCAPED: "\"" (ECNAME "://")? _STRING_ESC_INNER "\""
Expand Down
23 changes: 10 additions & 13 deletions packages/kestrel_core/src/kestrel/frontend/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,26 +14,25 @@
_logger = logging.getLogger(__name__)


MAPPING_MODULE = "kestrel.mapping"

# cache mapping in the module
frontend_mapping = {}


@typechecked
def get_frontend_mapping(mapping_type: str, mapping_pkg: str, submodule: str) -> dict:
def get_frontend_mapping(submodule: str, do_reverse_mapping: bool = False) -> dict:
global frontend_mapping
if mapping_type not in frontend_mapping:
if submodule not in frontend_mapping:
mapping = {}
for f in list_folder_files(mapping_pkg, submodule, extension="yaml"):
for f in list_folder_files(MAPPING_MODULE, submodule, extension="yaml"):
with open(f, "r") as fp:
mapping_ind = yaml.safe_load(fp)
if mapping_type == "property":
# New data model map is always OCSF->native
if do_reverse_mapping:
mapping_ind = reverse_mapping(mapping_ind)
# the entity mapping or reversed property mapping is flatten structure
# so just dict.update() will do
mapping.update(mapping_ind)
frontend_mapping[mapping_type] = mapping
return frontend_mapping[mapping_type]
frontend_mapping[submodule] = mapping
return frontend_mapping[submodule]


@typechecked
Expand All @@ -55,10 +54,8 @@ def get_keywords():
load_data_file("kestrel.frontend", "kestrel.lark"),
parser="lalr",
transformer=_KestrelT(
entity_map=get_frontend_mapping("entity", "kestrel.mapping", "entityname"),
property_map=get_frontend_mapping(
"property", "kestrel.mapping", "entityattribute"
),
type_map=get_frontend_mapping("types"),
field_map=get_frontend_mapping("fields", True),
),
)

Expand Down
2 changes: 1 addition & 1 deletion packages/kestrel_core/src/kestrel/interface/codegen/sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,7 @@ def add_ProjectAttrs(self, proj: ProjectAttrs) -> None:

def add_ProjectEntity(self, proj: ProjectEntity) -> None:
self.query = self.query.with_only_columns(
column(proj.entity_type)
column(proj.ocsf_field + ".*")
) # TODO: mapping?

def add_Limit(self, lim: Limit) -> None:
Expand Down
4 changes: 2 additions & 2 deletions packages/kestrel_core/src/kestrel/ir/instructions.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,8 +120,8 @@ def resolve_references(self, f: Callable[[ReferenceValue], Any]):

@dataclass(eq=False)
class ProjectEntity(SolePredecessorTransformingInstruction):
entity_type: str
native_type: str
ocsf_field: str
native_field: str


@dataclass(eq=False)
Expand Down
Loading

0 comments on commit 1251747

Please sign in to comment.