From 98a58ee57490b91dba7380577203c0e502bb6357 Mon Sep 17 00:00:00 2001 From: Valentin Belyaev Date: Wed, 17 Jan 2024 06:23:14 +0300 Subject: [PATCH] Add GradientBoostingClassifier, add test cases for boosting --- .../__pycache__/boosting.cpython-310.pyc | Bin 3702 -> 6433 bytes src/ensemble/boosting.py | 87 ++++++++++++++++-- tests/base/__pycache__/config.cpython-310.pyc | Bin 1838 -> 1855 bytes tests/base/config.py | 1 + ...test_boosting.cpython-310-pytest-7.4.4.pyc | Bin 604 -> 890 bytes tests/ensemble/test_boosting.py | 12 ++- 6 files changed, 87 insertions(+), 13 deletions(-) diff --git a/src/ensemble/__pycache__/boosting.cpython-310.pyc b/src/ensemble/__pycache__/boosting.cpython-310.pyc index 526d2d1f6fd895e32a5a5c650e93967388fe1feb..1a2bf45ae704df205bc69968aa0504f6f4bc5b4e 100644 GIT binary patch delta 2611 zcmb7F&2Jl35Z||6f9!hgI>k7)lfw3cBrDR;rVVYQ(uxL=a%kGBMWt3%v+;XzHnn%% zw`*E&9g5Tli9-+JAtVk+F_#K~kRTEl{sk@wIIvvW6GGy`g)7WkeTl4ZuVrlme6$+j081%+?HQR0r{%)vg+}OxR<2 znB5Df*s7Y=yZOeiH z$E*S3@Dfx|N>oq=NuqZIV)er5frMDlWjYRf=mb|fC_h>(V!Ko<26xiKe0~tA9}?Nz z6ik*tg6m^~gfh}-EHBe3c0Ilid4SlfH%JW2|3(c;R^B!bB@AdPIEnYH357ULfons4 zz8lGI<<3!yV@%Ya%nDsaARP=K(?k+m^)*zIs0AfG;j z{4jPKH3BulQScRyz~jY=d5iebrHXC1FCS;$sk!kFc*Ho;sfc&->F%+wqtCHyBGb1t zs7I+A(0zj@4AqTc*k~e z;m=sE5+rhob7=SD<1k1Neokc=;kH{sTxg+TVT9MJ=zZFA?SgR&z-)8sJ4U4re_UHz zBrv5-9GTRP>KZjHsKGwh1}rmBwq3HlSa&%y2g@8oo<5FI1;q=9r-)&7UfOoW@7!)( zY@-_@cWXM@iI?xwQ?Tz7kO1U#)I0F>2u*dOx`mK=YWCuz!Jj@Mx2qqcTVX8XP%E?= zYK9jM+{i> zHMn6%7Cn}7qZ_fN)ItzQO*0<%!?tLAI_6V5L zX4RuGVna&$eo$jU;?QZny0z8Q(=hiXNSntaZ5~%dZ(^5NY(Zi{SGVyyJT3`qY3XM+bI4kyll^uwbuF_uM{n5qPR=y5+}ES7@q q?uV^$>bFN8<*dAs&0 zh5RUAh5Bgmu`TI3N%-jXljyw`Q_=uQ8k7U4JjWTgWMy>TtxA#F9YEzVl~s;?LNw>Ait}{Ti!qW#QMt46Nes!hKtEJ<&r9 z25^Ex3?RmrVh&RLEOr7qgovyYb<<@1%g_3ICWmTpZH&U(OOFy)E$F9OG_a&0bDB-L zztwDBE}&|Lj|^1shac6J_4S#qzS+nn45ryNR~a_(vRtdZrA&Hebr1J?0~O= np """ pass + @abstractmethod + def _calculate_predictions(self, x: np.ndarray) -> np.ndarray: + """ + Calculate predictions for input data. + :param x: Input data. + :return: Predictions. + """ + def predict(self, x: np.ndarray) -> np.ndarray: """ Predict target feature using pretrained boosting trees. :param x: Test data. :return: Test predictions. """ - n_samples, _ = x.shape - - predictions = np.ones(shape=(n_samples,)) * self.constant_prediction - - for tree in self.trees: - predictions = predictions + self.learning_rate * tree.predict(x) - + predictions = self._calculate_predictions(x) return predictions @@ -108,11 +111,75 @@ def _calculate_initial_prediction(self, y: np.ndarray) -> np.ndarray: def _calculate_loss_gradient(self, y: np.ndarray, predictions: np.ndarray) -> np.ndarray: """ - Find mean value for the targets. + Calculate gradient of mean-squared error loss. + :param predictions: Target predictions. + :param y: Targets. + :return: Gradient of loss function with respect to predictions. + """ + return y - predictions + + def _calculate_predictions(self, x: np.ndarray) -> np.ndarray: + n_samples, _ = x.shape + + predictions = np.ones(n_samples) * self.constant_prediction + for tree in self.trees: + predictions = predictions + self.learning_rate * tree.predict(x) + + return predictions + + +class GradientBoostingClassifier(_GradientBoosting): + """ + Gradient Boosting for the classification. + Uses cross-entropy as loss. + """ + + def _calculate_initial_prediction(self, y: np.ndarray) -> np.ndarray: + """ + Find natural logarithm of odds. :param y: Targets. :return: Initial predictions. """ - return predictions - y + return np.zeros_like(y, dtype=np.float64) + + def _calculate_loss_gradient(self, y: np.ndarray, predictions: np.ndarray) -> np.ndarray: + """ + Calculate cross-entropy gradient. + :param y: Targets. + :return: Gradient of loss function with respect to predictions. + """ + return y - GradientBoostingClassifier.sigmoid(predictions) + @staticmethod + def sigmoid(x: np.ndarray) -> np.ndarray: + """ + Makes input values to be in (0, 1) range. + :param x: Input array. + :return: Output array of the same shape as an input array. + """ + return 1 / (1 + np.exp(-x)) + + def _calculate_predictions(self, x: np.ndarray) -> np.ndarray: + """ + Calculate targets using prediction probability. + :param x: Input array. + :return: Predictions. + """ + predictions_proba = self.predict_proba(x) + predictions = np.where(predictions_proba >= 0.5, 1, 0) + return predictions + + def predict_proba(self, x): + """ + Predict label using sigmoid function. + :param x: Input array. + :return: Predictions. + """ + n_samples, _ = x.shape + + predictions = np.ones(n_samples) * self.constant_prediction + for tree in self.trees: + predictions = predictions + self.learning_rate * tree.predict(x) + return GradientBoostingClassifier.sigmoid(predictions) diff --git a/tests/base/__pycache__/config.cpython-310.pyc b/tests/base/__pycache__/config.cpython-310.pyc index ea5d8433f847c74d390ab15e81447d0c121963e8..cd81cb23b3740054c7444cf29168af028c06be04 100644 GIT binary patch delta 157 zcmZ3-x1WzUpO=@50SLCaE>E@G$ZNuC=LF={Ff3rGVXR?V$ixU@GuALJWUOT_VX9$n zW@Kb2VXk2=VaZ}`W-1oHWST4u>D_JX3+l+5Ik xTdV~|nR%0g*pdZ6cJXj=F>){pFo`f0Nlre@rYs-^6fEKc5dt7Wa`HDe3jo2TB_IF* delta 140 zcmdnbw~miDpO=@50SFeXSd^;1k=KOP%nHb>VOYRW!&t+#kckn*W~^ac$XLr;*-sDVfQW6WNmaIDn?{ haB(pTFo`f0Nld=NrYs-|6e{8a5dt7WVlq3s1prxCAJqT= diff --git a/tests/base/config.py b/tests/base/config.py index 9648eba..01ec3b8 100644 --- a/tests/base/config.py +++ b/tests/base/config.py @@ -55,6 +55,7 @@ def check_fit_predict(model, x: np.ndarray, y: np.ndarray): # Fit the model on the mock dataset model.fit(x, y) preds = model.predict(x) + print(preds) assert isinstance(preds, np.ndarray) assert preds.shape == y.shape diff --git a/tests/ensemble/__pycache__/test_boosting.cpython-310-pytest-7.4.4.pyc b/tests/ensemble/__pycache__/test_boosting.cpython-310-pytest-7.4.4.pyc index c51f55986f59ebf98afaa46de3d53ccaf4aaa30a..56cf065478c911aca1f165f0dd2f3befb105b988 100644 GIT binary patch delta 474 zcmYjN%SyvQ6rGz%(==`B15r>niXu@c#gz+Ds7QBR=q?B~cBUbahfGHBk*;*>B6Ouo zH?=?EhXjAXUvTA}sL)xQd%2Hu&Ye&F+h``XZ4rELrxU)f<;`_C%lCE;BGWa2Fv4h9 zi!9fQ>TW%1xDBl7tnx%$n;EQ%vB}_^IOdn)=s8+`zt$~QyTW4uk7XxGQW?blN$90% zFbueGXtBsV>7`uu1@{HVdJ-4Q|Fyy2fM zgJ@to|LC;o%t}S&RRyMkngaD_zfi}l*dAZQ%GQi|JW4fErD}qqsBKK0x$y*smqO9I z3zDZYZC&AKv<68$4E)laDD~80>89M9(JPfpQS%TFxC!|Y;B2bA?UE%hX~XydC|qz0 delta 232 zcmeyxc87&8pO=@50SE*h%uh9DoX97`=rmDVU4xM!g&~DGha;CWiZhohiVMhRNnvea zh~iFR3}(<|dkNCxr^z_6OIlcy>6TDRVo73gYDs)iYI;#>aq+}A29p(;G!&VDs)|^E zgeFT7GmyH)SaFN7@)m1ueoAW2N`@ki$<2(_+-!^jj71<