5
5
__all__ = ["ShapeletClassifierVisualizer" , "ShapeletTransformerVisualizer" ]
6
6
7
7
import copy
8
+ import warnings
8
9
9
10
import numpy as np
10
11
from sklearn .ensemble ._forest import BaseForest
@@ -681,13 +682,35 @@ def _get_shp_importance(self, class_id):
681
682
if isinstance (classifier , Pipeline ):
682
683
classifier = classifier [- 1 ]
683
684
684
- # This suppose that the higher the coef linked to each feature, the most
685
- # impact this feature makes on classification for the given class_id
685
+ # This supposes that the higher (with the exception of distance features)
686
+ # the coef linked to each feature, the most impact this feature makes on
687
+ # classification for the given class_id
686
688
if isinstance (classifier , LinearClassifierMixin ):
687
689
coefs = classifier .coef_
688
690
n_classes = coefs .shape [0 ]
689
691
if n_classes == 1 :
690
- coefs = np .append (- coefs , coefs , axis = 0 )
692
+ if isinstance (self .estimator , RDSTClassifier ):
693
+ class_0_coefs = np .copy (coefs )
694
+ class_1_coefs = np .copy (coefs )
695
+
696
+ mask = np .ones (class_0_coefs .shape [1 ], dtype = bool )
697
+ mask [::3 ] = False
698
+ class_0_coefs [:, mask ] = - class_0_coefs [:, mask ]
699
+ class_1_coefs [:, ::3 ] = - class_1_coefs [:, ::3 ]
700
+
701
+ # Append the two modified coefs arrays along axis 0
702
+ coefs = np .append (class_0_coefs , class_1_coefs , axis = 0 )
703
+ warnings .warn (
704
+ "Shapelet importance ranking may be unreliable "
705
+ "when using linear classifiers with RDST. "
706
+ "This is due to the interaction between argmin "
707
+ "and shapelet occurrence features, which can distort "
708
+ "the rankings. Consider evaluating the results carefully "
709
+ "or using an alternative method." ,
710
+ stacklevel = 1 ,
711
+ )
712
+ else :
713
+ coefs = np .append (coefs , - coefs , axis = 0 )
691
714
coefs = coefs [class_id ]
692
715
693
716
elif isinstance (classifier , (BaseForest , BaseDecisionTree )):
@@ -699,7 +722,7 @@ def _get_shp_importance(self, class_id):
699
722
"classifier inheriting from LinearClassifierMixin, BaseForest or "
700
723
f"BaseDecisionTree but got { type (classifier )} "
701
724
)
702
- # coefs = coefs[idx]
725
+
703
726
if isinstance (self .estimator , RDSTClassifier ):
704
727
# As each shapelet generate 3 features, divide feature id by 3 so all
705
728
# features generated by one shapelet share the same ID
0 commit comments