Skip to content

Commit

Permalink
Merge branch 'Dev' into main
Browse files Browse the repository at this point in the history
  • Loading branch information
troyyyyy committed Jan 9, 2024
2 parents dcfc535 + b6145e3 commit 12d8768
Show file tree
Hide file tree
Showing 19 changed files with 251 additions and 103 deletions.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@

</div>

# Abductive Learning (ABL) Kit
# ABL Kit: A Python Toolkit for Abductive Learning

**ABL Kit** is an efficient Python toolkit for **Abductive Learning (ABL)**.
ABL is a novel paradigm that integrates machine learning and
Expand Down
10 changes: 8 additions & 2 deletions ablkit/bridge/base_bridge.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,9 @@
"""
This module contains the base class for the Bridge part.
Copyright (c) 2024 LAMDA. All rights reserved.
"""

from abc import ABCMeta, abstractmethod
from typing import Any, List, Optional, Tuple, Union

Expand Down Expand Up @@ -31,11 +37,11 @@ class BaseBridge(metaclass=ABCMeta):
def __init__(self, model: ABLModel, reasoner: Reasoner) -> None:
if not isinstance(model, ABLModel):
raise TypeError(
"Expected an instance of ABLModel, but received type: {}".format(type(model))
f"Expected an instance of ABLModel, but received type: {type(model)}"
)
if not isinstance(reasoner, Reasoner):
raise TypeError(
"Expected an instance of Reasoner, but received type: {}".format(type(reasoner))
f"Expected an instance of Reasoner, but received type: {type(reasoner)}"
)

self.model = model
Expand Down
22 changes: 15 additions & 7 deletions ablkit/bridge/simple_bridge.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,9 @@
"""
This module contains a simple implementation of the Bridge part.
Copyright (c) 2024 LAMDA. All rights reserved.
"""

import os.path as osp
from typing import Any, List, Optional, Tuple, Union

Expand Down Expand Up @@ -221,7 +227,7 @@ def train(
Labeled data should be in the same format as ``train_data``. The only difference is
that the ``gt_pseudo_label`` in ``label_data`` should not be ``None`` and will be
utilized to train the model. Defaults to None.
val_data : Union[ListData, Tuple[List[List[Any]], Optional[List[List[Any]]], Optional[List[Any]]]], optional # noqa: E501
val_data : Union[ListData, Tuple[List[List[Any]], Optional[List[List[Any]]], Optional[List[Any]]]], optional # noqa: E501 pylint: disable=line-too-long
Validation data should be in the same format as ``train_data``. Both ``gt_pseudo_label``
and ``Y`` can be either None or not, which depends on the evaluation metircs in
``self.metric_list``. If ``val_data`` is None, ``train_data`` will be used to validate
Expand Down Expand Up @@ -327,10 +333,11 @@ def valid(
Parameters
----------
val_data : Union[ListData, Tuple[List[List[Any]], Optional[List[List[Any]]], Optional[List[Any]]]] # noqa: E501
Validation data should be in the form of ``(X, gt_pseudo_label, Y)`` or a ``ListData`` object
with ``X``, ``gt_pseudo_label`` and ``Y`` attributes. Both ``gt_pseudo_label`` and ``Y`` can be
either None or not, which depends on the evaluation metircs in ``self.metric_list``.
val_data : Union[ListData, Tuple[List[List[Any]], Optional[List[List[Any]]], Optional[List[Any]]]] # noqa: E501 pylint: disable=line-too-long
Validation data should be in the form of ``(X, gt_pseudo_label, Y)`` or a ``ListData``
object with ``X``, ``gt_pseudo_label`` and ``Y`` attributes. Both ``gt_pseudo_label``
and ``Y`` can be either None or not, which depends on the evaluation metircs in
``self.metric_list``.
"""
val_data_examples = self.data_preprocess("val", val_data)
self._valid(val_data_examples)
Expand All @@ -346,10 +353,11 @@ def test(
Parameters
----------
test_data : Union[ListData, Tuple[List[List[Any]], Optional[List[List[Any]]], Optional[List[Any]]]] # noqa: E501
test_data : Union[ListData, Tuple[List[List[Any]], Optional[List[List[Any]]], Optional[List[Any]]]] # noqa: E501 pylint: disable=line-too-long
Test data should be in the form of ``(X, gt_pseudo_label, Y)`` or a ``ListData`` object
with ``X``, ``gt_pseudo_label`` and ``Y`` attributes. Both ``gt_pseudo_label`` and ``Y``
can be either None or not, which depends on the evaluation metircs in ``self.metric_list``.
can be either None or not, which depends on the evaluation metircs in
``self.metric_list``.
"""
print_log("Test start:", logger="current")
test_data_examples = self.data_preprocess("test", test_data)
Expand Down
6 changes: 6 additions & 0 deletions ablkit/data/evaluation/base_metric.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,9 @@
"""
This module contains the base class used for evaluation.
Copyright (c) 2024 LAMDA. All rights reserved.
"""

import logging
from abc import ABCMeta, abstractmethod
from typing import Any, List, Optional
Expand Down
10 changes: 9 additions & 1 deletion ablkit/data/evaluation/reasoning_metric.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,10 @@
"""
This module contains the ReasoningMetric, which is used for evaluating the model performance
on tasks that need reasoning.
Copyright (c) 2024 LAMDA. All rights reserved.
"""

from typing import Optional

from ...reasoning import KBBase
Expand All @@ -7,7 +14,7 @@

class ReasoningMetric(BaseMetric):
"""
A metrics class for evaluating the model performance on tasks need reasoning.
A metrics class for evaluating the model performance on tasks that need reasoning.
This class is designed to calculate the accuracy of the reasoing results. Reasoning
results are generated by first using the learning part to predict pseudo-labels
Expand All @@ -34,6 +41,7 @@ def __init__(self, kb: KBBase, prefix: Optional[str] = None) -> None:
super().__init__(prefix)
self.kb = kb

# pylint: disable=protected-access
def process(self, data_examples: ListData) -> None:
"""
Process a batch of data examples.
Expand Down
9 changes: 5 additions & 4 deletions ablkit/data/evaluation/symbol_accuracy.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,8 @@
from typing import Optional
"""
This module contains the class SymbolAccuracy, which is used for evaluating symbol-level accuracy.
Copyright (c) 2024 LAMDA. All rights reserved.
"""

import numpy as np

Expand All @@ -20,9 +24,6 @@ class SymbolAccuracy(BaseMetric):
metrics of different tasks. Inherits from BaseMetric. Default to None.
"""

def __init__(self, prefix: Optional[str] = None) -> None:
super().__init__(prefix)

def process(self, data_examples: ListData) -> None:
"""
Processes a batch of data examples.
Expand Down
8 changes: 5 additions & 3 deletions ablkit/data/structures/base_data_element.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
# Copyright (c) OpenMMLab. All rights reserved.
# Modified from
# https://github.com/open-mmlab/mmengine/blob/main/mmengine/structures/base_data_element.py
"""
Copyright (c) OpenMMLab. All rights reserved.
Modified from
https://github.com/open-mmlab/mmengine/blob/main/mmengine/structures/base_data_element.py # noqa: E501 pylint: disable=line-too-long
"""

import copy
from typing import Any, Iterator, Optional, Tuple, Type, Union
Expand Down
12 changes: 7 additions & 5 deletions ablkit/data/structures/list_data.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
# Copyright (c) OpenMMLab. All rights reserved.
# Modified from
# https://github.com/open-mmlab/mmdetection/blob/master/mmdet/core/data_structures/instance_data.py # noqa
"""
Copyright (c) OpenMMLab. All rights reserved.
Modified from
https://github.com/open-mmlab/mmdetection/blob/master/mmdet/core/data_structures/instance_data.py # noqa: E501 pylint: disable=line-too-long
"""

from typing import List, Union

Expand Down Expand Up @@ -54,7 +56,7 @@ class ListData(BaseDataElement):
``torch.Tensor``, ``numpy.ndarray``, ``list``, ``str`` and ``tuple``.
This design is inspired by and extends the functionalities of the ``BaseDataElement``
class implemented in `MMEngine <https://github.com/open-mmlab/mmengine/blob/main/mmengine/structures/base_data_element.py>`_. # noqa: E501
class implemented in `MMEngine <https://github.com/open-mmlab/mmengine/blob/main/mmengine/structures/base_data_element.py>`_. # noqa: E501 pylint: disable=line-too-long
Examples:
>>> from ablkit.data.structures import ListData
Expand All @@ -72,7 +74,7 @@ class implemented in `MMEngine <https://github.com/open-mmlab/mmengine/blob/main
DATA FIELDS
Y: [1, 2, 3]
gt_pseudo_label: [[1, 2], [3, 4], [5, 6]]
X: [[tensor(1.1949), tensor(-0.9378)], [tensor(0.7414), tensor(0.7603)], [tensor(1.0587), tensor(1.9697)]] # noqa: E501
X: [[tensor(1.1949), tensor(-0.9378)], [tensor(0.7414), tensor(0.7603)], [tensor(1.0587), tensor(1.9697)]] # noqa: E501 pylint: disable=line-too-long
) at 0x7f3bbf1991c0>
>>> print(data_examples[:1])
<ListData(
Expand Down
34 changes: 20 additions & 14 deletions ablkit/learning/abl_model.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,10 @@
"""
This module contains the class ABLModel, which provides a unified interface for different
machine learning models.
Copyright (c) 2024 LAMDA. All rights reserved.
"""

import pickle
from typing import Any, Dict

Expand Down Expand Up @@ -99,21 +106,20 @@ def _model_operation(self, operation: str, *args, **kwargs):
method = getattr(model, operation)
method(*args, **kwargs)
else:
if f"{operation}_path" not in kwargs.keys():
if f"{operation}_path" not in kwargs:
raise ValueError(f"'{operation}_path' should not be None")
else:
try:
if operation == "save":
with open(kwargs["save_path"], "wb") as file:
pickle.dump(model, file, protocol=pickle.HIGHEST_PROTOCOL)
elif operation == "load":
with open(kwargs["load_path"], "rb") as file:
self.base_model = pickle.load(file)
except (OSError, pickle.PickleError):
raise NotImplementedError(
f"{type(model).__name__} object doesn't have the {operation} method \
and the default pickle-based {operation} method failed."
)
try:
if operation == "save":
with open(kwargs["save_path"], "wb") as file:
pickle.dump(model, file, protocol=pickle.HIGHEST_PROTOCOL)
elif operation == "load":
with open(kwargs["load_path"], "rb") as file:
self.base_model = pickle.load(file)
except (OSError, pickle.PickleError) as exc:
raise NotImplementedError(
f"{type(model).__name__} object doesn't have the {operation} method \
and the default pickle-based {operation} method failed."
) from exc

def save(self, *args, **kwargs) -> None:
"""
Expand Down
8 changes: 7 additions & 1 deletion ablkit/learning/basic_nn.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,9 @@
"""
This module contains the class BasicNN, which servers as a wrapper for PyTorch NN models.
Copyright (c) 2024 LAMDA. All rights reserved.
"""

from __future__ import annotations

import logging
Expand Down Expand Up @@ -474,7 +480,7 @@ def _data_loader(
raise ValueError("X should not be None.")
if y is None:
y = [0] * len(X)
if not (len(y) == len(X)):
if not len(y) == len(X):
raise ValueError("X and y should have equal length.")

dataset = ClassificationDataset(X, y, transform=self.train_transform)
Expand Down
6 changes: 6 additions & 0 deletions ablkit/learning/torch_dataset/classification_dataset.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,9 @@
"""
Implementation of PyTorch dataset class used for classification.
Copyright (c) 2024 LAMDA. All rights reserved.
"""

from typing import Any, Callable, List, Tuple, Optional

import torch
Expand Down
6 changes: 6 additions & 0 deletions ablkit/learning/torch_dataset/prediction_dataset.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,9 @@
"""
Implementation of PyTorch dataset class used for Prediction.
Copyright (c) 2024 LAMDA. All rights reserved.
"""

from typing import Any, Callable, List, Tuple, Optional

import torch
Expand Down
6 changes: 6 additions & 0 deletions ablkit/learning/torch_dataset/regression_dataset.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,9 @@
"""
Implementation of PyTorch dataset class used for regression.
Copyright (c) 2024 LAMDA. All rights reserved.
"""

from typing import Any, List, Tuple

from torch.utils.data import Dataset
Expand Down
27 changes: 19 additions & 8 deletions ablkit/reasoning/kb.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,10 @@
"""
This module contains the classes KBBase, GroundKB, and PrologKB, which provide wrappers
for different kinds of knowledge bases.
Copyright (c) 2024 LAMDA. All rights reserved.
"""

import bisect
import inspect
import logging
Expand Down Expand Up @@ -394,7 +401,7 @@ def abduce_candidates(
base. The second element is a list of reasoning results corresponding to each
candidate, i.e., the outcome of the ``logic_forward`` function.
"""
if self.GKB == {} or len(pseudo_label) not in self.GKB_len_list:
if not self.GKB or len(pseudo_label) not in self.GKB_len_list:
return [], []

all_candidates, all_reasoning_results = self._find_candidate_GKB(pseudo_label, y)
Expand Down Expand Up @@ -478,7 +485,7 @@ def __init__(self, pseudo_label_list: List[Any], pl_file: str):
super().__init__(pseudo_label_list)

try:
import pyswip
import pyswip # pylint: disable=import-outside-toplevel
except (IndexError, ImportError):
print(
"A Prolog-based knowledge base is in use. Please install SWI-Prolog using the"
Expand All @@ -493,7 +500,7 @@ def __init__(self, pseudo_label_list: List[Any], pl_file: str):
raise FileNotFoundError(f"The Prolog file {self.pl_file} does not exist.")
self.prolog.consult(self.pl_file)

def logic_forward(self, pseudo_label: List[Any]) -> Any:
def logic_forward(self, pseudo_label: List[Any], x: Optional[List[Any]] = None) -> Any:
"""
Consult prolog with the query ``logic_forward(pseudo_labels, Res).``, and set the
returned ``Res`` as the reasoning results. To use this default function, there must be
Expand All @@ -504,11 +511,15 @@ def logic_forward(self, pseudo_label: List[Any]) -> Any:
----------
pseudo_label : List[Any]
Pseudo-labels of an example.
x : List[Any]
The corresponding input example. If the information from the input
is not required in the reasoning process, then this parameter will not have
any effect.
"""
result = list(self.prolog.query("logic_forward(%s, Res)." % pseudo_label))[0]["Res"]
result = list(self.prolog.query(f"logic_forward({pseudo_label}, Res)."))[0]["Res"]
if result == "true":
return True
elif result == "false":
if result == "false":
return False
return result

Expand All @@ -517,7 +528,7 @@ def _revision_pseudo_label(
pseudo_label: List[Any],
revision_idx: List[int],
) -> List[Any]:
import re
import re # pylint: disable=import-outside-toplevel

revision_pseudo_label = pseudo_label.copy()
revision_pseudo_label = flatten(revision_pseudo_label)
Expand All @@ -533,7 +544,7 @@ def get_query_string(
self,
pseudo_label: List[Any],
y: Any,
x: List[Any],
x: List[Any], # pylint: disable=unused-argument
revision_idx: List[int],
) -> str:
"""
Expand Down Expand Up @@ -563,7 +574,7 @@ def get_query_string(
query_string = "logic_forward("
query_string += self._revision_pseudo_label(pseudo_label, revision_idx)
key_is_none_flag = y is None or (isinstance(y, list) and y[0] is None)
query_string += ",%s)." % y if not key_is_none_flag else ")."
query_string += f",{y})." if not key_is_none_flag else ")."
return query_string

def revise_at_idx(
Expand Down
Loading

0 comments on commit 12d8768

Please sign in to comment.