Skip to content

Commit

Permalink
Fixed losing connection to the HA websocket. Retries automatically.
Browse files Browse the repository at this point in the history
  • Loading branch information
morosanmihail committed Jan 9, 2025
1 parent 595732f commit d64c632
Show file tree
Hide file tree
Showing 2 changed files with 47 additions and 14 deletions.
27 changes: 22 additions & 5 deletions src/homeassistant.rs
Original file line number Diff line number Diff line change
@@ -1,11 +1,14 @@
use std::collections::HashMap;
use std::{collections::HashMap, sync::Arc, time::Duration};

use eyre::{OptionExt, Result};
use futures_util::{SinkExt, StreamExt};
use reqwest::Client;
use serde::Deserialize;
use serde_json::{json, Error, Value};
use tokio::sync::mpsc::{Receiver, Sender};
use tokio::sync::{
mpsc::{Receiver, Sender},
Mutex,
};
use tokio_tungstenite::{connect_async, tungstenite::protocol::Message};
use url::Url;

Expand Down Expand Up @@ -269,7 +272,7 @@ pub async fn listen_for_events(
access_token: String,
mut media_players: HashMap<String, MediaPlayerState>,
channels: HashMap<String, Sender<HAEvent>>,
mut mpris_rx: Receiver<(String, HAEvent)>,
mpris_rx: Arc<Mutex<Receiver<(String, HAEvent)>>>,
) -> Result<()> {
let (ws_stream, _) = connect_async(ha_url).await?;
let (mut write, mut read) = ws_stream.split();
Expand Down Expand Up @@ -304,7 +307,12 @@ pub async fn listen_for_events(
loop {
tokio::select! {
event = read.next() => {
let Some(Ok(Message::Text(text))) = event else { continue };
let text = match event {
Some(Ok(Message::Text(t))) => t,
Some(Ok(Message::Close(_))) => break Err(eyre::eyre!("Channel closed")),
None => break Err(eyre::eyre!("Restarting websocket channel due to unknown reason")),
_ => continue
};
let Ok(event): Result<serde_json::Value, Error> = serde_json::from_str(&text) else { continue };
let Some(entity_id) = event.get("event").and_then(|e| e.get("data")).and_then(|d| d.get("entity_id")).and_then(|e| e.as_str()) else { continue };
let Some(media_player) = media_players.get_mut(entity_id) else { continue };
Expand All @@ -329,7 +337,12 @@ pub async fn listen_for_events(
Err(e) => println!("Died during metadata update event with {e}"),
};
}
Some((entity_id, msg)) = mpris_rx.recv() => {

result = async {
let mut guard = mpris_rx.lock().await;
guard.recv().await
} => {
let Some((entity_id, msg)) = result else { continue };
let media = media_players.get_mut(&entity_id);
if let Some(mp) = media {
match msg {
Expand All @@ -345,6 +358,10 @@ pub async fn listen_for_events(
};
}
}

_ = tokio::time::sleep(Duration::from_secs(60)) => {
break Err(eyre::eyre!("Timed out"));
}
}
}
}
Expand Down
34 changes: 25 additions & 9 deletions src/main.rs
Original file line number Diff line number Diff line change
@@ -1,10 +1,13 @@
use std::{collections::HashMap, io::Write, path::PathBuf};
use std::{collections::HashMap, io::Write, path::PathBuf, sync::Arc, time::Duration};

use eyre::{OptionExt, Result};
use homeassistant::{get_media_players, listen_for_events, MediaPlayerState};
use mpris::new_mpris_player;
use serde::{Deserialize, Serialize};
use tokio::{sync::mpsc, task::JoinSet};
use tokio::{
sync::{mpsc, Mutex},
task::JoinSet,
};

mod homeassistant;
mod mpris;
Expand Down Expand Up @@ -54,6 +57,7 @@ async fn main() -> Result<()> {

// Channel to handle events from MPRIS to HA
let (mpris_tx, mpris_rx) = mpsc::channel(100);
let mpris_rx = Arc::new(Mutex::new(mpris_rx));
let mut set = JoinSet::new();

let mut media_player_states = HashMap::new();
Expand Down Expand Up @@ -81,13 +85,25 @@ async fn main() -> Result<()> {
}

println!("Connected to {}", websocket_url);
let _ha_task = set.spawn(listen_for_events(
websocket_url,
config.home_assistant_token.to_string(),
media_player_states,
channels,
mpris_rx,
));
let _ha_task = set.spawn(async move {
loop {
let websocket_url = websocket_url.clone();
let media_player_states = media_player_states.clone();
let channels = channels.clone();
if let Err(e) = listen_for_events(
websocket_url,
config.home_assistant_token.to_string(),
media_player_states,
channels,
mpris_rx.clone(),
)
.await
{
println!("WebSocket connection lost. {e}. Retrying...");
tokio::time::sleep(Duration::from_secs(5)).await;
}
}
});

while (set.join_next().await).is_some() {}
Ok(())
Expand Down

0 comments on commit d64c632

Please sign in to comment.