Skip to content

Commit

Permalink
fix: use proper eval default main eval metrics for text regression model
Browse files Browse the repository at this point in the history
  • Loading branch information
MattGPT-ai committed Jan 26, 2025
1 parent 30974f2 commit 3cc131a
Showing 1 changed file with 17 additions and 10 deletions.
27 changes: 17 additions & 10 deletions flair/models/text_regression_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,7 +137,7 @@ def evaluate(
out_path: Optional[Union[str, Path]] = None,
embedding_storage_mode: EmbeddingStorageMode = "none",
mini_batch_size: int = 32,
main_evaluation_metric: tuple[str, str] = ("micro avg", "f1-score"),
main_evaluation_metric: tuple[str, str] = ("correlation", "pearson"),
exclude_labels: Optional[list[str]] = None,
gold_label_dictionary: Optional[Dictionary] = None,
return_loss: bool = True,
Expand Down Expand Up @@ -195,16 +195,23 @@ def evaluate(
f"spearman: {metric.spearmanr():.4f}"
)

result: Result = Result(
main_score=metric.pearsonr(),
scores = {
"loss": eval_loss.item(),
"mse": metric.mean_squared_error(),
"mae": metric.mean_absolute_error(),
"pearson": metric.pearsonr(),
"spearman": metric.spearmanr(),
}

if main_evaluation_metric[0] in ("correlation", "other"):
main_score = scores[main_evaluation_metric[1]]
else:
main_score = scores["spearman"]

result = Result(
main_score=main_score,
detailed_results=detailed_result,
scores={
"loss": eval_loss.item(),
"mse": metric.mean_squared_error(),
"mae": metric.mean_absolute_error(),
"pearson": metric.pearsonr(),
"spearman": metric.spearmanr(),
},
scores=scores,
)

return result
Expand Down

0 comments on commit 3cc131a

Please sign in to comment.