|
| 1 | +"""Interactive fitting widget using PyQt6.""" |
| 2 | + |
| 3 | +import warnings |
| 4 | +import numpy as np |
| 5 | +from typing import Dict, Any, Callable |
| 6 | +import sys |
| 7 | + |
| 8 | +with warnings.catch_warnings(): |
| 9 | + # ipywidgets produces deprecation warnings through use of internal APIs :( |
| 10 | + warnings.simplefilter("ignore") |
| 11 | + try: |
| 12 | + from PyQt6 import QtCore, QtGui, QtWidgets |
| 13 | + from matplotlib.figure import Figure |
| 14 | + from matplotlib.backends.backend_qt5agg import FigureCanvasQTAgg |
| 15 | + from matplotlib import pyplot as plt |
| 16 | + except ModuleNotFoundError as e: |
| 17 | + e.msg += ( |
| 18 | + "\n\nPlease install PyQt6, and matplotlib to enable interactive " |
| 19 | + "outside of Jupyter notebooks." |
| 20 | + ) |
| 21 | + raise |
| 22 | + |
| 23 | + |
| 24 | +def make_widget( |
| 25 | + minuit: Any, |
| 26 | + plot: Callable[..., None], |
| 27 | + kwargs: Dict[str, Any], |
| 28 | + raise_on_exception: bool, |
| 29 | +): |
| 30 | + """Make interactive fitting widget.""" |
| 31 | + original_values = minuit.values[:] |
| 32 | + original_limits = minuit.limits[:] |
| 33 | + |
| 34 | + def plot_with_frame(from_fit, report_success): |
| 35 | + trans = plt.gca().transAxes |
| 36 | + try: |
| 37 | + with warnings.catch_warnings(): |
| 38 | + minuit.visualize(plot, **kwargs) |
| 39 | + except Exception: |
| 40 | + if raise_on_exception: |
| 41 | + raise |
| 42 | + |
| 43 | + import traceback |
| 44 | + |
| 45 | + plt.figtext( |
| 46 | + 0, |
| 47 | + 0.5, |
| 48 | + traceback.format_exc(limit=-1), |
| 49 | + fontdict={"family": "monospace", "size": "x-small"}, |
| 50 | + va="center", |
| 51 | + color="r", |
| 52 | + backgroundcolor="w", |
| 53 | + wrap=True, |
| 54 | + ) |
| 55 | + return |
| 56 | + |
| 57 | + fval = minuit.fmin.fval if from_fit else minuit._fcn(minuit.values) |
| 58 | + plt.text( |
| 59 | + 0.05, |
| 60 | + 1.05, |
| 61 | + f"FCN = {fval:.3f}", |
| 62 | + transform=trans, |
| 63 | + fontsize="x-large", |
| 64 | + ) |
| 65 | + if from_fit and report_success: |
| 66 | + plt.text( |
| 67 | + 0.95, |
| 68 | + 1.05, |
| 69 | + f"{'success' if minuit.valid and minuit.accurate else 'FAILURE'}", |
| 70 | + transform=trans, |
| 71 | + fontsize="x-large", |
| 72 | + ha="right", |
| 73 | + ) |
| 74 | + |
| 75 | + def fit(): |
| 76 | + if algo_choice.value == "Migrad": |
| 77 | + minuit.migrad() |
| 78 | + elif algo_choice.value == "Scipy": |
| 79 | + minuit.scipy() |
| 80 | + elif algo_choice.value == "Simplex": |
| 81 | + minuit.simplex() |
| 82 | + return False |
| 83 | + else: |
| 84 | + assert False # pragma: no cover, should never happen |
| 85 | + return True |
| 86 | + |
| 87 | + def do_fit(change): |
| 88 | + report_success = fit() |
| 89 | + for i, x in enumerate(parameters): |
| 90 | + x.reset(minuit.values[i]) |
| 91 | + if change is None: |
| 92 | + return report_success |
| 93 | + OnParameterChange()({"from_fit": True, "report_success": report_success}) |
| 94 | + |
| 95 | + def on_update_button_clicked(change): |
| 96 | + for x in parameters: |
| 97 | + x.slider.continuous_update = not x.slider.continuous_update |
| 98 | + |
| 99 | + def on_reset_button_clicked(change): |
| 100 | + minuit.reset() |
| 101 | + minuit.values = original_values |
| 102 | + minuit.limits = original_limits |
| 103 | + for i, x in enumerate(parameters): |
| 104 | + x.reset(minuit.values[i], minuit.limits[i]) |
| 105 | + OnParameterChange()() |
| 106 | + |
| 107 | + def on_parameter_change(value): |
| 108 | + pass |
| 109 | + |
| 110 | + |
| 111 | + class FloatSlider(QtWidgets.QSlider): |
| 112 | + floatValueChanged = QtCore.pyqtSignal(float) |
| 113 | + |
| 114 | + def __init__(self, label): |
| 115 | + super().__init__(QtCore.Qt.Orientation.Horizontal) |
| 116 | + super().setMinimum(0) |
| 117 | + super().setMaximum(1000) |
| 118 | + super().setValue(500) |
| 119 | + self._min = 0.0 |
| 120 | + self._max = 1.0 |
| 121 | + self._label = label |
| 122 | + self.valueChanged.connect(self._emit_float_value_changed) |
| 123 | + |
| 124 | + def _emit_float_value_changed(self, value): |
| 125 | + float_value = self._int_to_float(value) |
| 126 | + self._label.setText(str(float_value)) |
| 127 | + self.floatValueChanged.emit(float_value) |
| 128 | + |
| 129 | + def _int_to_float(self, value): |
| 130 | + return self._min + (value / 1000) * (self._max - self._min) |
| 131 | + |
| 132 | + def _float_to_int(self, value): |
| 133 | + return int((value - self._min) / (self._max - self._min) * 1000) |
| 134 | + |
| 135 | + def setMinimum(self, min_value): |
| 136 | + self._min = min_value |
| 137 | + |
| 138 | + def setMaximum(self, max_value): |
| 139 | + self._max = max_value |
| 140 | + |
| 141 | + def setValue(self, value): |
| 142 | + super().setValue(self._float_to_int(value)) |
| 143 | + |
| 144 | + def value(self): |
| 145 | + return self._int_to_float(super().value()) |
| 146 | + |
| 147 | + def setSliderPosition(self, value): |
| 148 | + super().setSliderPosition(self._float_to_int(value)) |
| 149 | + |
| 150 | + |
| 151 | + class Parameter(QtWidgets.QGroupBox): |
| 152 | + def __init__(self, minuit, par) -> None: |
| 153 | + super().__init__(par) |
| 154 | + self.par = par |
| 155 | + # Set up the Qt Widget |
| 156 | + layout = QtWidgets.QGridLayout() |
| 157 | + self.setLayout(layout) |
| 158 | + # Add line edit to display slider value |
| 159 | + self.value_label = QtWidgets.QLabel() |
| 160 | + # Add value slider |
| 161 | + self.slider = FloatSlider(line_edit=self.value_label) |
| 162 | + self.slider.floatValueChanged.connect() |
| 163 | + # Add line edit for changing the limits |
| 164 | + self.vmin = QtWidgets.QLineEdit() |
| 165 | + self.vmin.returnPressed.connect(self.on_limit_changed) |
| 166 | + self.vmax = QtWidgets.QLineEdit() |
| 167 | + self.vmax.returnPressed.connect(self.on_limit_changed) |
| 168 | + # Add buttons |
| 169 | + self.fix = QtWidgets.QPushButton("Fix") |
| 170 | + self.fix.setCheckable(True) |
| 171 | + self.fix.setChecked(minuit.fixed[par]) |
| 172 | + self.fix.clicked.connect(self.on_fix_toggled) |
| 173 | + self.fit = QtWidgets.QPushButton("Fit") |
| 174 | + self.fit.setCheckable(True) |
| 175 | + self.fit.setChecked(False) |
| 176 | + self.fit.clicked.connect(self.on_fit_toggled) |
| 177 | + # Add widgets to the layout |
| 178 | + layout.addWidget(self.slider, 0, 0) |
| 179 | + layout.addWidget(self.value_label, 0, 1) |
| 180 | + layout.addWidget(self.vmin, 1, 0) |
| 181 | + layout.addWidget(self.vmax, 1, 1) |
| 182 | + layout.addWidget(self.fix, 2, 0) |
| 183 | + layout.addWidget(self.fit, 2, 1) |
| 184 | + # Add tooltips |
| 185 | + self.slider.setToolTip("Parameter Value") |
| 186 | + self.value_label.setToolTip("Parameter Value") |
| 187 | + self.vmin.setToolTip("Lower Limit") |
| 188 | + self.vmax.setToolTip("Upper Limit") |
| 189 | + self.fix.setToolTip("Fix Parameter") |
| 190 | + self.fit.setToolTip("Fit Parameter") |
| 191 | + # Set initial value and limits |
| 192 | + val = minuit.values[par] |
| 193 | + vmin, vmax = minuit.limits[par] |
| 194 | + step = _guess_initial_step(val, vmin, vmax) |
| 195 | + vmin2 = vmin if np.isfinite(vmin) else val - 100 * step |
| 196 | + vmax2 = vmax if np.isfinite(vmax) else val + 100 * step |
| 197 | + self.slider.setMinimum(vmin2) |
| 198 | + self.slider.setMaximum(vmax2) |
| 199 | + self.slider.setValue(val) |
| 200 | + self.value_label.setText(f"{val:.1g}") |
| 201 | + self.vmin.setText(f"{vmin2:.1g}") |
| 202 | + self.vmax.setText(f"{vmax2:.1g}") |
| 203 | + |
| 204 | + def on_val_changed(self, val): |
| 205 | + self.minuit.values[self.par] = val |
| 206 | + self.value_label.setText(f"{val:.1g}") |
| 207 | + on_parameter_change() |
| 208 | + |
| 209 | + def on_limit_changed(self): |
| 210 | + vmin = float(self.vmin.text()) |
| 211 | + vmax = float(self.vmax.text()) |
| 212 | + self.minuit.limits[self.par] = (vmin, vmax) |
| 213 | + self.slider.setMinimum(vmin) |
| 214 | + self.slider.setMaximum(vmax) |
| 215 | + # Update the slider position |
| 216 | + current_value = self.slider.value() |
| 217 | + if current_value < vmin: |
| 218 | + self.slider.setValue(vmin) |
| 219 | + self.vmin.setText(f"{vmin:.1g}") |
| 220 | + on_parameter_change() |
| 221 | + elif current_value > vmax: |
| 222 | + self.slider.setValue(vmax) |
| 223 | + self.editValue.setText(f"{vmax:.1g}") |
| 224 | + on_parameter_change() |
| 225 | + else: |
| 226 | + self.slider.blockSignals(True) |
| 227 | + self.slider.setValue(vmin) |
| 228 | + self.slider.setValue(current_value) |
| 229 | + self.slider.blockSignals(False) |
| 230 | + |
| 231 | + def on_fix_toggled(self): |
| 232 | + self.minuit.fixed[self.par] = self.fix.isChecked() |
| 233 | + if self.fix.isChecked(): |
| 234 | + self.fit.setChecked(False) |
| 235 | + |
| 236 | + def on_fit_toggled(self): |
| 237 | + self.slider.setEnabled(not self.fit.isChecked()) |
| 238 | + if self.fit.isChecked(): |
| 239 | + self.fix.setChecked(False) |
| 240 | + on_parameter_change() |
| 241 | + |
| 242 | + # Set up the main window |
| 243 | + main_window = QtWidgets.QMainWindow() |
| 244 | + main_window.resize(1600, 1000) |
| 245 | + # Set the global font |
| 246 | + font = QtGui.QFont() |
| 247 | + font.setPointSize(12) |
| 248 | + main_window.setFont(font) |
| 249 | + # Create the central widget |
| 250 | + centralwidget = QtWidgets.QWidget(parent=main_window) |
| 251 | + main_window.setCentralWidget(centralwidget) |
| 252 | + central_layout = QtWidgets.QVBoxLayout(centralwidget) |
| 253 | + # Add tabs for interactive and results |
| 254 | + tab = QtWidgets.QTabWidget(parent=centralwidget) |
| 255 | + interactive_tab = QtWidgets.QWidget() |
| 256 | + tab.addTab(interactive_tab, "") |
| 257 | + results_tab = QtWidgets.QWidget() |
| 258 | + tab.addTab(results_tab, "") |
| 259 | + central_layout.addWidget(tab) |
| 260 | + # Interactive tab |
| 261 | + interactive_layout = QtWidgets.QGridLayout(interactive_tab) |
| 262 | + # Add the plot |
| 263 | + plot_group = QtWidgets.QGroupBox("", parent=interactive_tab) |
| 264 | + sizePolicy = QtWidgets.QSizePolicy(QtWidgets.QSizePolicy.Policy.Expanding, |
| 265 | + QtWidgets.QSizePolicy.Policy.Expanding) |
| 266 | + sizePolicy.setHeightForWidth(plot_group.sizePolicy().hasHeightForWidth()) |
| 267 | + plot_group.setSizePolicy(sizePolicy) |
| 268 | + plot_layout = QtWidgets.QVBoxLayout(plot_group) |
| 269 | + canvas = FigureCanvasQTAgg(Figure()) |
| 270 | + ax = canvas.figure.add_subplot(111) |
| 271 | + plot_layout.addWidget(canvas) |
| 272 | + interactive_layout.addWidget(plot_group, 0, 0, 2, 1) |
| 273 | + # Add buttons |
| 274 | + button_group = QtWidgets.QGroupBox("", parent=interactive_tab) |
| 275 | + sizePolicy = QtWidgets.QSizePolicy(QtWidgets.QSizePolicy.Policy.Expanding, |
| 276 | + QtWidgets.QSizePolicy.Policy.Fixed) |
| 277 | + sizePolicy.setHeightForWidth(button_group.sizePolicy().hasHeightForWidth()) |
| 278 | + button_group.setSizePolicy(sizePolicy) |
| 279 | + button_layout = QtWidgets.QHBoxLayout(button_group) |
| 280 | + fit_button = QtWidgets.QPushButton(parent=button_group) |
| 281 | + fit_button.clicked.connect(do_fit) |
| 282 | + button_layout.addWidget(fit_button) |
| 283 | + update_button = QtWidgets.QPushButton(parent=button_group) |
| 284 | + update_button.clicked.connect(on_update_button_clicked) |
| 285 | + button_layout.addWidget(update_button) |
| 286 | + reset_button = QtWidgets.QPushButton(parent=button_group) |
| 287 | + reset_button.clicked.connect(on_reset_button_clicked) |
| 288 | + button_layout.addWidget(reset_button) |
| 289 | + algo_choice = QtWidgets.QComboBox(parent=button_group) |
| 290 | + algo_choice.setStyleSheet("QComboBox { text-align: center; }") |
| 291 | + algo_choice.addItems(["Migrad", "Scipy", "Simplex"]) |
| 292 | + button_layout.addWidget(algo_choice) |
| 293 | + interactive_layout.addWidget(button_group, 0, 1, 1, 1) |
| 294 | + # Add the parameters |
| 295 | + parameter_group = QtWidgets.QGroupBox("", parent=interactive_tab) |
| 296 | + sizePolicy = QtWidgets.QSizePolicy(QtWidgets.QSizePolicy.Policy.Preferred, |
| 297 | + QtWidgets.QSizePolicy.Policy.Expanding) |
| 298 | + sizePolicy.setHeightForWidth( |
| 299 | + parameter_group.sizePolicy().hasHeightForWidth()) |
| 300 | + parameter_group.setSizePolicy(sizePolicy) |
| 301 | + parameter_group_layout = QtWidgets.QVBoxLayout(parameter_group) |
| 302 | + scroll_area = QtWidgets.QScrollArea(parent=parameter_group) |
| 303 | + scroll_area.setWidgetResizable(True) |
| 304 | + scroll_area_widget_contents = QtWidgets.QWidget() |
| 305 | + scroll_area_widget_contents.setGeometry(QtCore.QRect(0, 0, 751, 830)) |
| 306 | + parameter_layout = QtWidgets.QVBoxLayout(scroll_area_widget_contents) |
| 307 | + scroll_area.setWidget(scroll_area_widget_contents) |
| 308 | + parameter_group_layout.addWidget(scroll_area) |
| 309 | + interactive_layout.addWidget(parameter_group, 1, 1, 1, 1) |
| 310 | + # Results tab |
| 311 | + results_layout = QtWidgets.QVBoxLayout(results_tab) |
| 312 | + results_text = QtWidgets.QPlainTextEdit(parent=results_tab) |
| 313 | + font = QtGui.QFont() |
| 314 | + font.setFamily("FreeMono") |
| 315 | + results_text.setFont(font) |
| 316 | + results_text.setReadOnly(True) |
| 317 | + results_layout.addWidget(results_text) |
| 318 | + |
| 319 | + parameters = [Parameter(minuit, par) for par in minuit.parameters] |
| 320 | + |
| 321 | + |
| 322 | +def _make_finite(x: float) -> float: |
| 323 | + sign = -1 if x < 0 else 1 |
| 324 | + if abs(x) == np.inf: |
| 325 | + return sign * sys.float_info.max |
| 326 | + return x |
| 327 | + |
| 328 | + |
| 329 | +def _guess_initial_step(val: float, vmin: float, vmax: float) -> float: |
| 330 | + if np.isfinite(vmin) and np.isfinite(vmax): |
| 331 | + return 1e-2 * (vmax - vmin) |
| 332 | + return 1e-2 |
| 333 | + |
| 334 | + |
| 335 | +def _round(x: float) -> float: |
| 336 | + return float(f"{x:.1g}") |
0 commit comments