diff --git a/common/0rtt_utils.go b/common/0rtt_utils.go index f3b24db..dbed7e6 100644 --- a/common/0rtt_utils.go +++ b/common/0rtt_utils.go @@ -4,13 +4,12 @@ import ( "crypto/tls" "errors" "github.com/lucas-clemente/quic-go" - "net" - "time" ) // PingToGatherSessionTicketAndToken establishes a new QUIC connection. // As soon as the session ticket and the token is received, the connection is closed. // This function can be used to prepare for 0-RTT +// TODO add timeout func PingToGatherSessionTicketAndToken(addr string, tlsConf *tls.Config, config *quic.Config) error { if tlsConf.ClientSessionCache == nil { return errors.New("session cache is nil") @@ -18,53 +17,26 @@ func PingToGatherSessionTicketAndToken(addr string, tlsConf *tls.Config, config if config.TokenStore == nil { panic("session cache is nil") } - connection, err := quic.DialAddr(addr, tlsConf, config) + + singleSessionCache := NewSingleSessionCache() + singleTokenStore := NewSingleTokenStore() + + tmpTlsConf := tlsConf.Clone() + tmpTlsConf.ClientSessionCache = singleSessionCache + tmpConfig := config.Clone() + tmpConfig.TokenStore = singleTokenStore + + connection, err := quic.DialAddr(addr, tmpTlsConf, tmpConfig) if err != nil { return err } - sessionCacheKey := sessionCacheKey(connection.RemoteAddr(), tlsConf) - tokenStoreKey := tokenStoreKey(connection.RemoteAddr(), tlsConf) + tlsConf.ClientSessionCache.Put(singleSessionCache.Await()) + config.TokenStore.Put(singleTokenStore.Await()) - // await session ticket - for { - time.Sleep(time.Millisecond) - _, ok := tlsConf.ClientSessionCache.Get(sessionCacheKey) - if ok { - break - } - } - // await token - for { - time.Sleep(time.Millisecond) - token := config.TokenStore.Pop(tokenStoreKey) - if token != nil { - config.TokenStore.Put(tokenStoreKey, token) // put back again - break - } - } err = connection.CloseWithError(quic.ApplicationErrorCode(0), "cancel") if err != nil { return err } return nil } - -// inspired by qtls.clientSessionCacheKey implementation -// TODO avoid duplicate code -func sessionCacheKey(serverAddr net.Addr, tlsConf *tls.Config) string { - if len(tlsConf.ServerName) > 0 { - return "qtls-" + tlsConf.ServerName - } - return "qtls-" + serverAddr.String() -} - -// inspired by quic.newClientSession implementation -// TODO avoid duplicate code -func tokenStoreKey(serverAddr net.Addr, tlsConf *tls.Config) string { - if len(tlsConf.ServerName) > 0 { - return tlsConf.ServerName - } else { - return serverAddr.String() - } -} diff --git a/common/single_session_cache.go b/common/single_session_cache.go new file mode 100644 index 0000000..1f3bf05 --- /dev/null +++ b/common/single_session_cache.go @@ -0,0 +1,54 @@ +package common + +import ( + "context" + "crypto/tls" + "github.com/marten-seemann/qtls-go1-19" +) + +type SingleSessionCache struct { + emptyContext context.Context + emptyContextCancel context.CancelFunc + sessionKey *string + session *tls.ClientSessionState +} + +var _ qtls.ClientSessionCache = (*SingleSessionCache)(nil) + +func (s *SingleSessionCache) Get(sessionKey string) (session *tls.ClientSessionState, ok bool) { + select { + case <-s.emptyContext.Done(): + if sessionKey == *s.sessionKey { + return session, true + } + default: // do not wait + } + return nil, false +} + +func (s *SingleSessionCache) Put(sessionKey string, cs *tls.ClientSessionState) { + select { + case <-s.emptyContext.Done(): + return // already set + default: + //TODO make thread safe + s.sessionKey = &sessionKey + s.session = cs + s.emptyContextCancel() + } +} + +func (s *SingleSessionCache) Await() (string, *tls.ClientSessionState) { + <-s.emptyContext.Done() + return *s.sessionKey, s.session +} + +func NewSingleSessionCache() *SingleSessionCache { + emptyContext, emptyContextCancel := context.WithCancel(context.Background()) + return &SingleSessionCache{ + emptyContext, + emptyContextCancel, + nil, + nil, + } +} diff --git a/common/single_token_store.go b/common/single_token_store.go new file mode 100644 index 0000000..576d363 --- /dev/null +++ b/common/single_token_store.go @@ -0,0 +1,53 @@ +package common + +import ( + "context" + "github.com/lucas-clemente/quic-go" +) + +type SingleTokenStore struct { + emptyContext context.Context + emptyContextCancel context.CancelFunc + key *string + token *quic.ClientToken +} + +var _ quic.TokenStore = (*SingleTokenStore)(nil) + +// Pop does not remove the token +func (s *SingleTokenStore) Pop(key string) (token *quic.ClientToken) { + select { + case <-s.emptyContext.Done(): + if key == *s.key { + return s.token + } + default: // do not wait + } + return nil +} + +func (s *SingleTokenStore) Put(key string, token *quic.ClientToken) { + select { + case <-s.emptyContext.Done(): + return // already set + default: + //TODO make thread safe + s.key = &key + s.token = token + s.emptyContextCancel() + } +} + +func (s *SingleTokenStore) Await() (string, *quic.ClientToken) { + <-s.emptyContext.Done() + return *s.key, s.token +} +func NewSingleTokenStore() *SingleTokenStore { + emptyContext, emptyContextCancel := context.WithCancel(context.Background()) + return &SingleTokenStore{ + emptyContext, + emptyContextCancel, + nil, + nil, + } +}