diff --git a/failover.go b/failover.go
new file mode 100644
index 0000000..86a39c3
--- /dev/null
+++ b/failover.go
@@ -0,0 +1,147 @@
+package failover
+
+import (
+	"errors"
+	"fmt"
+	"net/url"
+	"sync"
+	"sync/atomic"
+	"time"
+)
+
+type Failover interface {
+	AddUrl(url *url.URL) error      // adds one url
+	AddUrls(urls ...*url.URL) error // adds multiple urls
+
+	Request(requestFunc func(url *url.URL) error, localOpts ...func(*Options)) error // sends a request to a url picked by the round robin, using your requestFunc
+
+	MustParseURL(url string) *url.URL      // panics if url is invalid
+	ParseURL(url string) (*url.URL, error) // returns error if url is invalid
+}
+
+type onErrType uint8
+
+const (
+	ReqOnErrRemoveAndReconnect onErrType = iota // move url to bad urls and retry with the next one (default attempts: 3)
+	ReqOnErrIgnore                              // ignore error (returns nil)
+	ReqOnErrReturnErr                           // return error
+	ReqOnErrReconnectNext                       // retry with the next url
+	ReqOnErrReconnectCurrent                    // retry with the current url
+)
+
+type Options struct {
+	CheckConn            func(url *url.URL) error // function for checking a url
+	CheckUrlBeforeAdding bool                     // check url before adding (using CheckConn)
+	CheckUrlDelay        time.Duration            // delay between background url checks (default 30s)
+	ReqOnErr             onErrType                // what to do on request error
+	MaxAttempts          uint16                   // max attempts (default 3)
+}
+
+type storage struct {
+	activeUrls    *atomic.Pointer[[]*url.URL]
+	badUrls       *atomic.Pointer[[]*url.URL]
+	roundRobin    roundRobin
+	options       *Options
+	mu            *sync.Mutex
+	startCronChan chan bool
+}
+
+func New(checkConnection func(url *url.URL) error, options ...func(*Options)) Failover {
+	activeUrls := &atomic.Pointer[[]*url.URL]{}
+	badUrls := &atomic.Pointer[[]*url.URL]{}
+
+	opts := &Options{ // default options
+		CheckUrlBeforeAdding: true,
+		CheckUrlDelay:        time.Second * 30,
+		MaxAttempts:          3,
+		CheckConn:            checkConnection,
+	}
+
+	for _, opt := range options {
+		opt(opts)
+	}
+
+	s := &storage{roundRobin: newRoundRobin(activeUrls), badUrls: badUrls, activeUrls: activeUrls, options: opts, mu: &sync.Mutex{}, startCronChan: make(chan bool)}
+
+	// go printUrls(s.activeUrls, s.badUrls)
+	go s.runCronUrlCheck()
+
+	return s
+}
+
+func (s *storage) AddUrls(urls ...*url.URL) error {
+	for _, url := range urls {
+		err := s.AddUrl(url)
+		if err != nil {
+			return err
+		}
+	}
+	return nil
+}
+
+func (s *storage) AddUrl(url__ *url.URL) error {
+	s.mu.Lock()
+	defer s.mu.Unlock()
+
+	if url__ == nil {
+		return errors.New("url is nil")
+	}
+
+	if s.options.CheckUrlBeforeAdding {
+		err := s.options.CheckConn(url__)
+		if err != nil {
+			return fmt.Errorf("check connection: %w", err)
+		}
+	}
+
+	urlsPtr := s.activeUrls.Load()
+	if urlsPtr == nil {
+		s.activeUrls.Store(&[]*url.URL{url__})
+		return nil
+	}
+	urls := *urlsPtr
+
+	urls = removeUrl(urls, url__) // remove url if it already exists
+	urls = append(urls, url__)    // append unique url
+
+	s.activeUrls.Store(&urls)
+
+	return nil
+}
+
+func (s *storage) Request(requestFunc func(url *url.URL) error, localOpts ...func(*Options)) error {
+	s.mu.Lock()
+	defer s.mu.Unlock()
+
+	urlsPtr := s.activeUrls.Load()
+	if urlsPtr == nil {
+		return errors.New("no urls found")
+	}
+
+	urls := *urlsPtr
+
+	url, found := s.roundRobin.Next()
+	if !found {
+		return errors.New("round robin: no urls found")
+	}
+
+	optsCopy := *s.options // copy opts so per-request options don't mutate the shared defaults
+
+	for _, opt := range localOpts {
+		opt(&optsCopy)
+	}
+
+	return s.request(requestFunc, &optsCopy, urls, url, 1)
+}
+
+func (s *storage) MustParseURL(url__ string) *url.URL {
+	u, err := url.Parse(url__)
+	if err != nil {
+		panic(err)
+	}
+	return u
+}
+
+func (s *storage) ParseURL(url__ string) (*url.URL, error) {
+	return url.Parse(url__)
+}
diff --git a/failover_test.go b/failover_test.go
new file mode 100644
index 0000000..6aba56c
--- /dev/null
+++ b/failover_test.go
@@ -0,0 +1,78 @@
+package failover
+
+import (
+	"fmt"
+	"net/http"
+	"net/url"
+	"testing"
+	"time"
+)
+
+func TestFailover(t *testing.T) {
+	f := New(CheckConnection,
+		OptCheckUrlBeforeAdding(true),
+		OptCheckUrlDelay(10*time.Second),
+		OptMaxAttempts(10),
+	)
+
+	tests := []*url.URL{
+		// {Host: "google.com", Scheme: "https"},
+		{Host: "0.0.0.0:8080", Scheme: "http"},
+		{Host: "0.0.0.0:8081", Scheme: "http"},
+		{Host: "0.0.0.0:8081", Scheme: "http"},
+		{Host: "0.0.0.0:8081", Scheme: "http"},
+		{Host: "0.0.0.0:8082", Scheme: "http"},
+		{Host: "0.0.0.0:8082", Scheme: "http"},
+		{Host: "0.0.0.0:8083", Scheme: "http"},
+		{Host: "0.0.0.0:8084", Scheme: "http"},
+		{Host: "0.0.0.0:8085", Scheme: "http"},
+		{Host: "0.0.0.0:8086", Scheme: "http"},
+		{Host: "0.0.0.0:8087", Scheme: "http"},
+	}
+
+	err := f.AddUrl(tests[0])
+	if err != nil {
+		t.Fatal(err)
+	}
+
+	err = f.AddUrl(tests[0]) // duplicate
+	if err != nil {
+		t.Fatal(err)
+	}
+
+	err = f.AddUrl(tests[1])
+	if err != nil {
+		t.Fatal(err)
+	}
+
+	err = f.AddUrls(tests...)
+	if err != nil {
+		t.Fatal(err)
+	}
+
+	time.Sleep(5 * time.Second)
+	// fmt.Println("STATUS", f.Request(Request, OptReqOnErr(ReqOnErrReconnectNext), OptMaxAttempts(4)))
+	fmt.Println("STATUS", f.Request(Request, OptReqOnErr(ReqOnErrReconnectNext), OptMaxAttempts(2)))
+	// f.Request(Request)
+	// f.Request(Request)
+	// f.Request(Request)
+}
+
+func CheckConnection(url *url.URL) error {
+	resp, err := http.Get(url.String())
+	if err != nil {
+		fmt.Println(err)
+		return err
+	}
+	resp.Body.Close()
+	return nil
+}
+
+func Request(url *url.URL) error {
+	resp, err := http.Get("http://" + url.Host)
+	if err != nil {
+		return err
+	}
+	resp.Body.Close()
+	return nil
+}
diff --git a/go.mod b/go.mod
new file mode 100644
index 0000000..ae98f6f
--- /dev/null
+++ b/go.mod
@@ -0,0 +1,3 @@
+module failover
+
+go 1.22.5
diff --git a/helpers.go b/helpers.go
new file mode 100644
index 0000000..9cd47f6
--- /dev/null
+++ b/helpers.go
@@ -0,0 +1,211 @@
+package failover
+
+import (
+	"errors"
+	"net/url"
+	"sync"
+	"sync/atomic"
+	"time"
+)
+
+// addToBadUrls moves url__ from the active list to the bad list.
+func addToBadUrls(activeUrls__ *atomic.Pointer[[]*url.URL], badUrls__ *atomic.Pointer[[]*url.URL], url__ *url.URL) {
+	activeUrlsPtr := activeUrls__.Load()
+	if activeUrlsPtr == nil {
+		return
+	}
+
+	activeUrls := *activeUrlsPtr
+
+	activeUrls = removeUrl(activeUrls, url__)
+	activeUrls__.Store(&activeUrls)
+
+	badUrlsPtr := badUrls__.Load()
+	if badUrlsPtr == nil {
+		badUrls__.Store(&[]*url.URL{url__})
+		return
+	}
+
+	badUrls := *badUrlsPtr
+	badUrls = removeUrl(badUrls, url__)
+	badUrls = append(badUrls, url__)
+	badUrls__.Store(&badUrls)
+}
+
+// removeFromBadUrls moves url__ from the bad list back to the active list.
+func removeFromBadUrls(activeUrls__ *atomic.Pointer[[]*url.URL], badUrls__ *atomic.Pointer[[]*url.URL], url__ *url.URL) {
+	badUrlsPtr := badUrls__.Load()
+	if badUrlsPtr == nil {
+		return
+	}
+
+	badUrls := *badUrlsPtr
+
+	badUrls = removeUrl(badUrls, url__)
+	badUrls__.Store(&badUrls)
+
+	activeUrlsPtr := activeUrls__.Load()
+	if activeUrlsPtr == nil {
+		activeUrls__.Store(&[]*url.URL{url__})
+		return
+	}
+
+	activeUrls := *activeUrlsPtr
+	activeUrls = removeUrl(activeUrls, url__)
+	activeUrls = append(activeUrls, url__)
+	activeUrls__.Store(&activeUrls)
+}
+
+// removeUrl returns a copy of slice without any url equal to element.
+func removeUrl(slice []*url.URL, element *url.URL) []*url.URL {
+	newSlice := make([]*url.URL, 0, len(slice))
+
+	for _, i := range slice {
+		if i.String() != element.String() {
+			newSlice = append(newSlice, i)
+		}
+	}
+
+	return newSlice
+}
+
+func (s *storage) request(requestFunc func(url *url.URL) error, optsCopy *Options, urls []*url.URL, url *url.URL, attempts uint16) error {
+	err := requestFunc(url)
+	if err == nil {
+		return nil
+	}
+
+	if attempts >= optsCopy.MaxAttempts {
+		return err
+	}
+
+	switch optsCopy.ReqOnErr {
+	case ReqOnErrIgnore:
+		return nil
+	case ReqOnErrReturnErr:
+		return err
+	case ReqOnErrReconnectCurrent:
+		return s.request(requestFunc, optsCopy, urls, url, attempts+1)
+	case ReqOnErrReconnectNext:
+		url, found := s.roundRobin.Next()
+		if !found {
+			return errors.New("round robin: no urls found")
+		}
+		return s.request(requestFunc, optsCopy, urls, url, attempts+1)
+	case ReqOnErrRemoveAndReconnect:
+		addToBadUrls(s.activeUrls, s.badUrls, url)
+		// printUrlsOnce(s.activeUrls, s.badUrls)
+
+		url, found := s.roundRobin.Next()
+		if !found {
+			return errors.New("round robin: no urls found")
+		}
+		return s.request(requestFunc, optsCopy, urls, url, attempts+1)
+	}
+
+	return err
+}
+
+func (s *storage) runCronUrlCheck() {
+	ticker := time.NewTicker(s.options.CheckUrlDelay)
+	defer ticker.Stop()
+
+	for range ticker.C {
+		if s.activeUrls.Load() == nil && s.badUrls.Load() == nil {
+			continue // nothing to check yet
+		}
+
+		if cronUrlCheck(s.activeUrls, s.badUrls, s.options) != nil {
+			return
+		}
+	}
+}
+
+func cronUrlCheck(urls__ *atomic.Pointer[[]*url.URL], badUrls__ *atomic.Pointer[[]*url.URL], options *Options) error {
+	var wg sync.WaitGroup
+
+	wg.Add(2) // bad urls and active urls
+
+	go func() { // check active urls
+		defer wg.Done()
+
+		urlsPtr := urls__.Load()
+		if urlsPtr == nil {
+			return
+		}
+
+		urls := *urlsPtr
+
+		for _, url := range urls {
+			err := options.CheckConn(url)
+			if err != nil {
+				addToBadUrls(urls__, badUrls__, url)
+				continue
+			}
+		}
+	}()
+
+	go func() { // recheck bad urls
+		defer wg.Done()
+
+		badUrlsPtr := badUrls__.Load()
+		if badUrlsPtr == nil {
+			return
+		}
+
+		badUrls := *badUrlsPtr
+		for _, url := range badUrls {
+			err := options.CheckConn(url)
+			if err == nil {
+				removeFromBadUrls(urls__, badUrls__, url)
+				continue
+			}
+		}
+	}()
+
+	wg.Wait()
+
+	return nil
+}
+
+// func printUrls(activeUrls__ *atomic.Pointer[[]*url.URL], badUrls__ *atomic.Pointer[[]*url.URL]) {
+// 	for {
+// 		activeUrlsPtr := activeUrls__.Load()
+// 		if activeUrlsPtr != nil {
+// 			fmt.Printf("active urls: %v\n", *activeUrlsPtr)
+// 		}
+
+// 		badUrlsPtr := badUrls__.Load()
+// 		if badUrlsPtr != nil {
+// 			fmt.Printf("bad urls: %v\n", *badUrlsPtr)
+// 		}
+// 		time.Sleep(1 * time.Second)
+// 	}
+// }
+
+// func printUrlsOnce(activeUrls__ *atomic.Pointer[[]*url.URL], badUrls__ *atomic.Pointer[[]*url.URL]) {
+// 	activeUrlsPtr := activeUrls__.Load()
+// 	if activeUrlsPtr != nil {
+// 		fmt.Printf("active urls: %v\n", *activeUrlsPtr)
+// 	}
+
+// 	badUrlsPtr := badUrls__.Load()
+// 	if badUrlsPtr != nil {
+// 		fmt.Printf("bad urls: %v\n", *badUrlsPtr)
+// 	}
+// }
diff --git a/opts.go b/opts.go
new file mode 100644
index 0000000..8287818
--- /dev/null
+++ b/opts.go
@@ -0,0 +1,27 @@
+package failover
+
+import "time"
+
+func OptCheckUrlBeforeAdding(check bool) func(*Options) {
+	return func(options *Options) {
+		options.CheckUrlBeforeAdding = check
+	}
+}
+
+func OptCheckUrlDelay(delay time.Duration) func(*Options) {
+	return func(options *Options) {
+		options.CheckUrlDelay = delay
+	}
+}
+
+func OptReqOnErr(onErr onErrType) func(*Options) {
+	return func(options *Options) {
+		options.ReqOnErr = onErr
+	}
+}
+
+func OptMaxAttempts(maxAttempts uint16) func(*Options) {
+	return func(options *Options) {
+		options.MaxAttempts = maxAttempts
+	}
+}
diff --git a/round-robin.go b/round-robin.go
new file mode 100644
index 0000000..7e2baaa
--- /dev/null
+++ b/round-robin.go
@@ -0,0 +1,48 @@
+package failover
+
+import (
+	"net/url"
+	"sync"
+	"sync/atomic"
+)
+
+type roundRobin interface {
+	Next() (*url.URL, bool)
+}
+
+type rr struct {
+	mu      *sync.Mutex
+	index   *atomic.Uint32
+	servers *atomic.Pointer[[]*url.URL]
+}
+
+func newRoundRobin(servers *atomic.Pointer[[]*url.URL]) roundRobin {
+	return rr{
+		index:   new(atomic.Uint32),
+		servers: servers,
+		mu:      &sync.Mutex{},
+	}
+}
+
+func (rr rr) Next() (*url.URL, bool) {
+	rr.mu.Lock()
+	defer rr.mu.Unlock()
+
+	serversPtr := rr.servers.Load()
+	if serversPtr == nil {
+		return nil, false
+	}
+	servers := *serversPtr
+
+	if len(servers) == 0 {
+		return nil, false
+	}
+
+	n := rr.index.Add(1)
+	conn := servers[(int(n)-1)%len(servers)]
+
+	if conn.String() != "" {
+		return conn, true
+	}
+
+	return nil, false
+}
diff --git a/round-robin_test.go b/round-robin_test.go
new file mode 100644
index 0000000..a4a0a5d
--- /dev/null
+++ b/round-robin_test.go
@@ -0,0 +1,49 @@
+package failover
+
+import (
+	"fmt"
+	"net/url"
+	"sync"
+	"sync/atomic"
+	"testing"
+)
+
+func TestRoundRobin(t *testing.T) {
+	tests := []*url.URL{
+		{Host: "127.0.0.1"},
+		{Host: "127.0.0.2"},
+		{Host: "127.0.0.3"},
+		{Host: "127.0.0.4"},
+		{Host: "127.0.0.5"},
+		{Host: "127.0.0.6"},
+		{Host: "127.0.0.7"},
+		{Host: "127.0.0.8"},
+		{Host: "127.0.0.9"},
+	}
+
+	var x atomic.Pointer[[]*url.URL]
+	x.Store(&[]*url.URL{})
+
+	arr := *x.Load()
+	arr = append(arr, tests...)
+
+	x.Store(&arr)
+	rr := newRoundRobin(&x)
+
+	var wg sync.WaitGroup // wait for the goroutines, otherwise the test returns before they run
+	for range 10 {
+		wg.Add(1)
+		go func() {
+			defer wg.Done()
+
+			c, ok := rr.Next()
+			fmt.Printf("c: %v, ok: %v\n", c, ok)
+
+			c, ok = rr.Next()
+			fmt.Printf("c: %v, ok: %v\n", c, ok)
+
+			c, ok = rr.Next()
+			fmt.Printf("c: %v, ok: %v\n", c, ok)
+
+			c, ok = rr.Next()
+			fmt.Printf("c: %v, ok: %v\n", c, ok)
+		}()
+	}
+	wg.Wait()
+}
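
Reviewer note: below is a minimal usage sketch of the API this diff adds, not part of the change itself. The backend addresses, the checkConn/request callbacks, and the cmd/demo location are placeholder assumptions; only New, the Opt* constructors, ReqOnErrReconnectNext, MustParseURL, AddUrls, and Request come from the diff.

// Assumed to live in a subdirectory of the module (e.g. cmd/demo), so "failover"
// resolves to the root package of this module; adjust the import path as needed.
package main

import (
	"fmt"
	"log"
	"net/http"
	"net/url"
	"time"

	"failover"
)

func main() {
	// Health check used when adding urls (CheckUrlBeforeAdding defaults to true)
	// and by the background cron check; placeholder implementation.
	checkConn := func(u *url.URL) error {
		resp, err := http.Get(u.String())
		if err != nil {
			return err
		}
		resp.Body.Close()
		return nil
	}

	f := failover.New(checkConn,
		failover.OptCheckUrlDelay(15*time.Second),            // re-check urls every 15s
		failover.OptMaxAttempts(3),                           // give up after 3 attempts
		failover.OptReqOnErr(failover.ReqOnErrReconnectNext), // on error, retry the next url
	)

	// Placeholder backends; they must be reachable while CheckUrlBeforeAdding is enabled.
	err := f.AddUrls(
		f.MustParseURL("http://127.0.0.1:8080"),
		f.MustParseURL("http://127.0.0.1:8081"),
	)
	if err != nil {
		log.Fatal(err)
	}

	// The callback receives whichever url the round robin picked for this attempt.
	err = f.Request(func(u *url.URL) error {
		resp, err := http.Get(u.String())
		if err != nil {
			return err
		}
		defer resp.Body.Close()
		fmt.Println("status:", resp.Status)
		return nil
	})
	if err != nil {
		log.Fatal(err)
	}
}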