diff --git a/FoldOptLib/input/input_data_processor.py b/FoldOptLib/input/input_data_processor.py index dcb4ded..b05a21c 100644 --- a/FoldOptLib/input/input_data_processor.py +++ b/FoldOptLib/input/input_data_processor.py @@ -1,19 +1,14 @@ import pandas as pd from typing import List, Optional, Dict -import numpy as np +import numpy from .input_data_checker import CheckInputData -from ..helper._helper import * from ..helper.utils import * - - -def _normalise_gradient(gradient: np.ndarray) -> np.ndarray: - """Normalise gradient vectors.""" - return gradient / np.linalg.norm(gradient, axis=1)[:, None] +from LoopStructural import BoundingBox class InputDataProcessor(CheckInputData): - def __init__(self, data: pd.DataFrame, bounding_box: np.ndarray, + def __init__(self, data: pd.DataFrame, bounding_box: BoundingBox, geological_knowledge: Dict = None) -> None: """ Constructs all the necessary attributes for the InputDataProcessor object. @@ -27,15 +22,13 @@ def __init__(self, data: pd.DataFrame, bounding_box: np.ndarray, geological_knowledge : Dict, optional geological knowledge dictionary. """ - super().__init__(data, bounding_box, geological_knowledge) # Assuming parent class requires this + super().__init__(data, bounding_box, geological_knowledge) self.data = data self.bounding_box = bounding_box self.knowledge_constraints = geological_knowledge def process_data(self): - check_data = CheckInputData(self.data, self.bounding_box, self.knowledge_constraints) - check_data.check_input_data() # check the input data is valid - + self.check_input_data() if 'strike' in self.data.columns and 'dip' in self.data.columns: strike = self.data['strike'].to_numpy() dip = self.data['dip'].to_numpy() @@ -45,7 +38,12 @@ def process_data(self): else: return None - gradient = _normalise_gradient(gradient) + gradient = InputDataProcessor.normalise(gradient) self.data['gx'], self.data['gy'], self.data['gz'] = gradient[:, 0], gradient[:, 1], gradient[:, 2] return self.data + + @staticmethod + def normalise(gradient: numpy.ndarray) -> numpy.ndarray: + """Normalise vectors.""" + return gradient / numpy.linalg.norm(gradient, axis=1)[:, None]