Skip to content

Commit

Permalink
start error handling
Browse files Browse the repository at this point in the history
  • Loading branch information
jbesraa committed Apr 7, 2024
1 parent 176a7ec commit c926828
Show file tree
Hide file tree
Showing 6 changed files with 194 additions and 250 deletions.
232 changes: 130 additions & 102 deletions lightning-payjoin/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
pub mod scheduler;
use lightning::ln::ChannelId;
pub use scheduler::FundingTxParams;

use bitcoin::absolute::LockTime;
use bitcoin::psbt::Psbt;
use bitcoin::secp256k1::PublicKey;
use bitcoin::{base64, ScriptBuf};
Expand All @@ -12,41 +12,66 @@ use hyper::server::conn::http1;
use hyper::service::service_fn;
use hyper::{HeaderMap, Request};
use hyper_util::rt::TokioIo;
use scheduler::ScheduledChannel;
use std::collections::HashMap;
use std::string::FromUtf8Error;
use std::sync::Arc;
use tokio::net::TcpStream;
use tokio::sync::Mutex;
use tokio::task::JoinError;

pub trait PayjoinLNReceiver: Send + Sync + 'static + Clone {
#[derive(Debug)]
pub enum Error {
InvalidRequest(hyper::Error),
InvalidRequestBody(FromUtf8Error),
PsbtParseError(bitcoin::psbt::Error),
TokioJoinError(JoinError),
FundingTxParamsNotFound,
NoAvailableChannel,
}

impl std::fmt::Display for Error {
fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
match *self {
Self::InvalidRequest(ref e) => write!(f, "Invalid request: {}", e),
Self::InvalidRequestBody(ref e) => write!(f, "Invalid request body: {}", e),
Self::PsbtParseError(ref e) => write!(f, "Psbt parse error: {}", e),
Self::TokioJoinError(ref e) => write!(f, "Tokio join error: {}", e),
Self::FundingTxParamsNotFound => write!(f, "Funding tx params not found"),
Self::NoAvailableChannel => write!(f, "No available channel"),
}
}
}

impl std::error::Error for Error {}

pub trait Receiver: Send + Sync + 'static + Clone {
fn is_mine(&self, script: &ScriptBuf) -> Result<bool, Box<dyn std::error::Error>>;
fn notify_funding_generated(
&self, temporary_channel_id: [u8; 32], counterparty_node_id: PublicKey,
fn funding_transaction_generated(
&self, temporary_channel_id: &ChannelId, counterparty_node_id: PublicKey,
funding_tx: bitcoin::Transaction,
) -> Result<(), Box<dyn std::error::Error>>;
}

#[derive(Clone)]
pub struct PayjoinService<P: PayjoinLNReceiver + Send + Sync + 'static + Clone> {
pub struct LightningPayjoin<P: Receiver + Send + Sync + 'static + Clone> {
receiver_handler: P,
scheduler: Arc<Mutex<scheduler::ChannelScheduler>>,
}

impl<P> PayjoinService<P>
impl<P> LightningPayjoin<P>
where
P: PayjoinLNReceiver + Send + Sync + 'static + Clone,
P: Receiver + Send + Sync + 'static + Clone,
{
pub fn new(receiver_handler: P, scheduler: Arc<Mutex<scheduler::ChannelScheduler>>) -> Self {
Self { receiver_handler, scheduler }
}

pub async fn serve_incoming_payjoin_requests(
&self, stream: TcpStream,
) -> Result<(), JoinError> {
pub async fn serve_incoming_http_request(&self, stream: TcpStream) -> Result<(), Error> {
let io = TokioIo::new(stream);
let receiver = self.receiver_handler.clone();
let scheduler = self.scheduler.clone();
let payjoin_lightning = Arc::new(Mutex::new(PayjoinService::new(receiver, scheduler)));
let payjoin_lightning = Arc::new(Mutex::new(LightningPayjoin::new(receiver, scheduler)));
tokio::task::spawn(async move {
if let Err(err) = http1::Builder::new()
.serve_connection(
Expand All @@ -61,51 +86,86 @@ where
}
})
.await
.map_err(Error::TokioJoinError)
}

async fn convert_payjoin_request_to_funding_tx(
async fn get_next_channel(
&self, request: Request<Incoming>,
) -> Result<String, Box<dyn std::error::Error>> {
let is_output_mine =
|script: &ScriptBuf| self.receiver_handler.is_mine(script).map_err(|e| e.into());
let (psbt, amount_to_us) = extract_psbt_from_http_request(request, is_output_mine).await?;
let channel = match self.scheduler.lock().await.get_next_channel(amount_to_us) {
Some(channel) => channel,
None => {
panic!("No channel available for payjoin");
},
};
assert!(channel.is_channel_accepted());
let locktime = match channel.locktime() {
Some(locktime) => locktime,
None => unreachable!(),
};
let output_script = match channel.output_script() {
Some(output_script) => output_script,
None => unreachable!(),
};
let temporary_channel_id = match channel.temporary_channel_id() {
Some(temporary_channel_id) => temporary_channel_id,
None => unreachable!(),
};
let psbt = from_original_psbt_to_funding_psbt(
output_script,
channel.channel_value_satoshi(),
psbt,
locktime,
is_output_mine,
);
let funding_tx = psbt.clone().extract_tx();
) -> Result<Option<(Psbt, scheduler::ScheduledChannel)>, Error> {
let headers = request.headers().clone();
let body = request.into_body().collect().await.map_err(Error::InvalidRequest)?;
let body =
String::from_utf8(body.to_bytes().to_vec()).map_err(Error::InvalidRequestBody)?;
let mut psbt = body_to_psbt(headers.clone(), body.as_bytes());
let is_mine = |script: &ScriptBuf| self.receiver_handler.is_mine(script).map_err(|e| e);
let amount_to_us = psbt.unsigned_tx.output.iter().fold(0, |acc, output| {
if let Ok(is_mine) = is_mine(&output.script_pubkey) {
if is_mine {
acc + output.value
} else {
acc
}
} else {
acc
}
});
if let Some(channel) = self.scheduler.lock().await.get_next_channel(amount_to_us) {
let is_mine = |script: &ScriptBuf| self.receiver_handler.is_mine(script).map_err(|e| e);
let funding_tx = {
let funding_tx_params = match channel.funding_tx_params() {
Some(funding_tx_params) => funding_tx_params,
None => return Err(Error::FundingTxParamsNotFound),
};
let channel_value_sat = channel.channel_value_satoshi();
debug_assert_eq!(channel_value_sat, amount_to_us);
let output_script = funding_tx_params.output_script();
psbt.unsigned_tx.lock_time = funding_tx_params.locktime();
psbt.unsigned_tx.output.push(bitcoin::TxOut {
value: channel_value_sat,
script_pubkey: output_script.clone(),
});
psbt.unsigned_tx.output.retain(|output| {
let is_mine = match is_mine(&output.script_pubkey) {
Ok(is_mine) => is_mine,
Err(e) => panic!("{:?}", e),
};
!is_mine || output.script_pubkey == output_script
});
Psbt::from_unsigned_tx(psbt.unsigned_tx.clone()).map_err(Error::PsbtParseError)?
};
Ok(Some((funding_tx, channel)))
} else {
Ok(None)
}
}

#[allow(unused)]
async fn accept_normal_payjoin_transaction() -> Result<String, Error> {
unimplemented!()
}

async fn open_channel_from_payjoin_transaction(
&self, channel_funding_tx: Psbt, channel: ScheduledChannel,
) -> Result<String, Error> {
let funding_tx = channel_funding_tx.clone().extract_tx();
// tell the scheduler that the funding tx has been created
self.scheduler
.lock()
.await
.mark_as_funding_tx_created(channel.user_channel_id(), funding_tx.clone());
let counterparty_node_id = channel.node_id();
let _ = self.receiver_handler.notify_funding_generated(
temporary_channel_id.0,
counterparty_node_id,
funding_tx.clone(),
)?;
let temporary_channel_id =
channel.temporary_channel_id().expect("Temporary channel id should exist");
// tell the counterparty node that the funding tx has been created
let _ = self
.receiver_handler
.funding_transaction_generated(
&temporary_channel_id,
counterparty_node_id,
funding_tx.clone(),
)
.unwrap();
// wait for the counterparty node to return FundingSigned message
let res = tokio::time::timeout(tokio::time::Duration::from_secs(3), async move {
let txid = funding_tx.clone().txid();
loop {
Expand All @@ -117,40 +177,41 @@ where
.await;
if res.is_err() {
panic!("Funding tx not signed");
// broadcast original tx
}
Ok(psbt.to_string())
// return the funding psbt to the payjoin sender
// so they can sign and broadcast it
Ok(channel_funding_tx.to_string())
}

async fn handle_incoming_payjoin_request(
&self, request: Request<Incoming>,
) -> Result<String, Error> {
if let Some((channel_funding_tx, channel)) = self.get_next_channel(request).await? {
self.open_channel_from_payjoin_transaction(channel_funding_tx, channel).await
} else {
// accept_normal_payjoin_transaction()
Err(Error::NoAvailableChannel)
}
}

async fn http_router(
http_request: Request<Incoming>, payjoin_lightning: Arc<Mutex<PayjoinService<P>>>,
http_request: Request<Incoming>, payjoin_lightning: Arc<Mutex<LightningPayjoin<P>>>,
) -> Result<hyper::Response<Full<bytes::Bytes>>, hyper::Error> {
match (http_request.method(), http_request.uri().path()) {
(&hyper::Method::POST, "/payjoin") => {
let payjoin_lightning = payjoin_lightning.lock().await;
let payjoin_proposal = payjoin_lightning
.convert_payjoin_request_to_funding_tx(http_request)
.await
.unwrap();
return http_response(payjoin_proposal);
let payjoin_proposal =
payjoin_lightning.handle_incoming_payjoin_request(http_request).await.unwrap();
Ok(hyper::Response::builder()
.body(Full::new(bytes::Bytes::from(payjoin_proposal)))
.unwrap())
},
_ => http_response("404".into()),
_ => Ok(hyper::Response::builder().body(Full::new(bytes::Bytes::from("404"))).unwrap()),
}
}
}

pub async fn extract_psbt_from_http_request(
request: hyper::Request<Incoming>,
is_mine: impl Fn(&ScriptBuf) -> Result<bool, Box<dyn std::error::Error>>,
) -> Result<(Psbt, u64), Box<dyn std::error::Error>> {
let headers = request.headers().clone();
let body = request.into_body().collect().await?;
let body = String::from_utf8(body.to_bytes().to_vec()).unwrap();
let psbt = body_to_psbt(headers.clone(), body.as_bytes());
let amount_to_us = amount_directed_to_us_sat(psbt.clone(), is_mine);
Ok((psbt, amount_to_us))
}
pub fn body_to_psbt(headers: HeaderMap<HeaderValue>, mut body: impl std::io::Read) -> Psbt {
pub(crate) fn body_to_psbt(headers: HeaderMap<HeaderValue>, mut body: impl std::io::Read) -> Psbt {
let content_length =
headers.get("content-length").unwrap().to_str().unwrap().parse::<u64>().unwrap();
let mut buf = vec![0; content_length as usize]; // 4_000_000 * 4 / 3 fits in u32
Expand All @@ -160,39 +221,6 @@ pub fn body_to_psbt(headers: HeaderMap<HeaderValue>, mut body: impl std::io::Rea
psbt
}

pub fn from_original_psbt_to_funding_psbt(
output_script: ScriptBuf, channel_value_sat: u64, mut psbt: Psbt, locktime: LockTime,
is_mine: impl Fn(&ScriptBuf) -> Result<bool, Box<dyn std::error::Error>>,
) -> Psbt {
let multisig_script = output_script;
psbt.unsigned_tx.lock_time = locktime;
psbt.unsigned_tx
.output
.push(bitcoin::TxOut { value: channel_value_sat, script_pubkey: multisig_script.clone() });
psbt.unsigned_tx.output.retain(|output| {
let is_mine = is_mine(&output.script_pubkey).unwrap();
!is_mine || output.script_pubkey == multisig_script
});
let psbt = Psbt::from_unsigned_tx(psbt.unsigned_tx).unwrap();
psbt
}

fn amount_directed_to_us_sat(
psbt: Psbt, is_mine: impl Fn(&ScriptBuf) -> Result<bool, Box<dyn std::error::Error>>,
) -> u64 {
let mut ret = 0;
psbt.unsigned_tx.output.iter().for_each(|output| {
let is_mine = is_mine(&output.script_pubkey).unwrap();
if is_mine {
ret += output.value;
}
});
ret
}
pub fn http_response(s: String) -> Result<hyper::Response<Full<bytes::Bytes>>, hyper::Error> {
Ok(hyper::Response::builder().body(Full::new(bytes::Bytes::from(s))).unwrap())
}

struct RequestHeaders(HashMap<String, String>);

impl payjoin::receive::Headers for RequestHeaders {
Expand Down
22 changes: 22 additions & 0 deletions lightning-payjoin/src/scheduler.rs
Original file line number Diff line number Diff line change
Expand Up @@ -144,6 +144,15 @@ impl FundingTxParams {
) -> Self {
Self { output_script, locktime, temporary_channel_id }
}
pub fn output_script(&self) -> ScriptBuf {
self.output_script.clone()
}
pub fn locktime(&self) -> LockTime {
self.locktime
}
pub fn temporary_channel_id(&self) -> ChannelId {
self.temporary_channel_id
}
}

#[derive(Clone, Debug)]
Expand Down Expand Up @@ -299,6 +308,19 @@ impl ScheduledChannel {
}
}

pub fn funding_tx_params(&self) -> Option<FundingTxParams> {
match self.state.clone() {
ScheduledChannelState::ChannelAccepted(_, funding_tx_params) => Some(funding_tx_params),
ScheduledChannelState::FundingTxCreated(_, funding_tx_params, _) => {
Some(funding_tx_params)
},
ScheduledChannelState::FundingTxSigned(_, funding_tx_params, _, _) => {
Some(funding_tx_params)
},
_ => None,
}
}

pub fn temporary_channel_id(&self) -> Option<ChannelId> {
match self.state.clone() {
ScheduledChannelState::ChannelAccepted(_, funding_tx_params) => {
Expand Down
14 changes: 7 additions & 7 deletions src/builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ use crate::io::sqlite_store::SqliteStore;
use crate::liquidity::LiquiditySource;
use crate::logger::{log_error, FilesystemLogger, Logger};
use crate::message_handler::NodeCustomMessageHandler;
use crate::payjoin_handler::PayjoinChannelManager;
use crate::payjoin_handler::PayjoinManager;
use crate::payment_store::PaymentStore;
use crate::peer_store::PeerStore;
use crate::sweep::OutputSweeper;
Expand All @@ -20,7 +20,6 @@ use crate::types::{
OnionMessenger, PeerManager,
};
use crate::wallet::Wallet;
use crate::LDKPayjoin;
use crate::{LogLevel, Node};

use lightning::chain::{chainmonitor, BestBlock, Watch};
Expand Down Expand Up @@ -951,10 +950,11 @@ fn build_with_store_internal<K: KVStore + Sync + Send + 'static>(
};

let (stop_sender, _) = tokio::sync::watch::channel(());
let payjoin_channels_handler =
PayjoinChannelManager::new(Arc::clone(&wallet), Arc::clone(&channel_manager));
let payjoin =
Arc::new(LDKPayjoin::new(payjoin_channels_handler, Arc::clone(&channel_scheduler)));
let payjoin_manager = Arc::new(PayjoinManager::new(
Arc::clone(&wallet),
Arc::clone(&channel_manager),
Arc::clone(&channel_scheduler),
));

let is_listening = Arc::new(AtomicBool::new(false));
let latest_wallet_sync_timestamp = Arc::new(RwLock::new(None));
Expand All @@ -975,7 +975,7 @@ fn build_with_store_internal<K: KVStore + Sync + Send + 'static>(
channel_manager,
chain_monitor,
output_sweeper,
payjoin,
payjoin_manager,
channel_scheduler,
peer_manager,
keys_manager,
Expand Down
Loading

0 comments on commit c926828

Please sign in to comment.