Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add HandlersChain on handler #331

Open
wants to merge 22 commits into
base: master
Choose a base branch
from
116 changes: 78 additions & 38 deletions router.go
Original file line number Diff line number Diff line change
Expand Up @@ -81,13 +81,21 @@ import (
"net/http"
"strings"
"sync"

"fmt"
"io"
"os"
"reflect"
"runtime"
)

// Handle is a function that can be registered to a route to handle HTTP
// requests. Like http.HandlerFunc, but has a third parameter for the values of
// wildcards (path variables).
type Handle func(http.ResponseWriter, *http.Request, Params)

var DefaultWriter io.Writer = os.Stdout

// Param is a single URL parameter, consisting of a key and a value.
type Param struct {
Key string
Expand Down Expand Up @@ -231,54 +239,61 @@ func (r *Router) putParams(ps *Params) {
}
}

func (r *Router) saveMatchedRoutePath(path string, handle Handle) Handle {
return func(w http.ResponseWriter, req *http.Request, ps Params) {
if ps == nil {
psp := r.getParams()
ps = (*psp)[0:1]
ps[0] = Param{Key: MatchedRoutePathParam, Value: path}
handle(w, req, ps)
r.putParams(psp)
} else {
ps = append(ps, Param{Key: MatchedRoutePathParam, Value: path})
handle(w, req, ps)
// If enabled, adds the matched route path onto the http.Request context
// before invoking the handler.
// The matched route path is only added to handlers of routes that were
// registered when this option was enabled.
func (r *Router) saveMatchedRoutePath(path string, handles ...Handle) HandlersChain {
for i, handle := range handles {
handles[i] = func(w http.ResponseWriter, req *http.Request, ps Params) {
if ps == nil {
psp := r.getParams()
ps = (*psp)[0:1]
ps[0] = Param{Key: MatchedRoutePathParam, Value: path}
handle(w, req, ps)
r.putParams(psp)
} else {
ps = append(ps, Param{Key: MatchedRoutePathParam, Value: path})
handle(w, req, ps)
}
}
}
return handles;
}

// GET is a shortcut for router.Handle(http.MethodGet, path, handle)
func (r *Router) GET(path string, handle Handle) {
r.Handle(http.MethodGet, path, handle)
func (r *Router) GET(path string, handle ...Handle) {
r.Handle(http.MethodGet, path, handle...)
}

// HEAD is a shortcut for router.Handle(http.MethodHead, path, handle)
func (r *Router) HEAD(path string, handle Handle) {
r.Handle(http.MethodHead, path, handle)
func (r *Router) HEAD(path string, handle ...Handle) {
r.Handle(http.MethodHead, path, handle...)
}

// OPTIONS is a shortcut for router.Handle(http.MethodOptions, path, handle)
func (r *Router) OPTIONS(path string, handle Handle) {
r.Handle(http.MethodOptions, path, handle)
func (r *Router) OPTIONS(path string, handle ...Handle) {
r.Handle(http.MethodOptions, path, handle...)
}

// POST is a shortcut for router.Handle(http.MethodPost, path, handle)
func (r *Router) POST(path string, handle Handle) {
r.Handle(http.MethodPost, path, handle)
func (r *Router) POST(path string, handle ...Handle) {
r.Handle(http.MethodPost, path, handle...)
}

// PUT is a shortcut for router.Handle(http.MethodPut, path, handle)
func (r *Router) PUT(path string, handle Handle) {
r.Handle(http.MethodPut, path, handle)
func (r *Router) PUT(path string, handle ...Handle) {
r.Handle(http.MethodPut, path, handle...)
}

// PATCH is a shortcut for router.Handle(http.MethodPatch, path, handle)
func (r *Router) PATCH(path string, handle Handle) {
r.Handle(http.MethodPatch, path, handle)
func (r *Router) PATCH(path string, handle ...Handle) {
r.Handle(http.MethodPatch, path, handle...)
}

// DELETE is a shortcut for router.Handle(http.MethodDelete, path, handle)
func (r *Router) DELETE(path string, handle Handle) {
r.Handle(http.MethodDelete, path, handle)
func (r *Router) DELETE(path string, handle ...Handle) {
r.Handle(http.MethodDelete, path, handle...)
}

// Handle registers a new request handle with the given path and method.
Expand All @@ -289,7 +304,7 @@ func (r *Router) DELETE(path string, handle Handle) {
// This function is intended for bulk loading and to allow the usage of less
// frequently used, non-standardized or custom methods (e.g. for internal
// communication with a proxy).
func (r *Router) Handle(method, path string, handle Handle) {
func (r *Router) Handle(method, path string, handles ...Handle) {
varsCount := uint16(0)

if method == "" {
Expand All @@ -298,19 +313,29 @@ func (r *Router) Handle(method, path string, handle Handle) {
if len(path) < 1 || path[0] != '/' {
panic("path must begin with '/' in path '" + path + "'")
}
if handle == nil {
//HandlersChain is nil or array[0] is nil
if handles == nil || (len(handles)==1 && handles[0]==nil){
panic("handle must not be nil")
}

if r.SaveMatchedRoutePath {
varsCount++
handle = r.saveMatchedRoutePath(path, handle)
handles = r.saveMatchedRoutePath(path, handles...)
}

if r.trees == nil {
r.trees = make(map[string]*node)
}

handleNames := ""
for _, handle := range handles {
if handle != nil{
handleNames = handleNames+","+runtime.FuncForPC(reflect.ValueOf(handle).Pointer()).Name()
}
}
fmt.Fprintf(DefaultWriter, "[router-debug] %-6s %-40s --> [%s] (%d handlers)\n", method, path,handleNames,len(handles) )


root := r.trees[method]
if root == nil {
root = new(node)
Expand All @@ -319,7 +344,7 @@ func (r *Router) Handle(method, path string, handle Handle) {
r.globalAllowed = r.allowed("*", "")
}

root.addRoute(path, handle)
root.addRoute(path, handles)

// Update maxParams
if paramsCount := countParams(path); paramsCount+varsCount > r.maxParams {
Expand Down Expand Up @@ -391,17 +416,17 @@ func (r *Router) recv(w http.ResponseWriter, req *http.Request) {
// If the path was found, it returns the handle function and the path parameter
// values. Otherwise the third return value indicates whether a redirection to
// the same path with an extra / without the trailing slash should be performed.
func (r *Router) Lookup(method, path string) (Handle, Params, bool) {
func (r *Router) Lookup(method, path string) (HandlersChain, Params, bool) {
if root := r.trees[method]; root != nil {
handle, ps, tsr := root.getValue(path, r.getParams)
if handle == nil {
handles, ps, tsr := root.getValue(path, r.getParams)
if handles == nil {
r.putParams(ps)
return nil, nil, tsr
}
if ps == nil {
return handle, nil, tsr
return handles, nil, tsr
}
return handle, *ps, tsr
return handles, *ps, tsr
}
return nil, nil, false
}
Expand Down Expand Up @@ -464,14 +489,29 @@ func (r *Router) ServeHTTP(w http.ResponseWriter, req *http.Request) {
}

path := req.URL.Path
//输出请求的处理日志
reqLog := fmt.Sprintf("[http-router] %-6s req.uri=%-12s", req.Method,req.RequestURI)

if root := r.trees[req.Method]; root != nil {
if handle, ps, tsr := root.getValue(path, r.getParams); handle != nil {
if handles, ps, tsr := root.getValue(path, r.getParams); handles != nil {
if ps != nil {
handle(w, req, *ps)
r.putParams(ps)
for _,handle := range handles {
//输出请求的处理日志
fmt.Fprintf(DefaultWriter, reqLog+",req.handle=%-40s\n", runtime.FuncForPC(reflect.ValueOf(handle).Pointer()).Name())
handle(w, req, *ps)
r.putParams(ps)
}
} else {
handle(w, req, nil)
for _,handle := range handles {
if handle != nil{
//输出请求的处理日志
fmt.Fprintf(DefaultWriter, reqLog+",req.handle=%-40s\n", runtime.FuncForPC(reflect.ValueOf(handle).Pointer()).Name())
handle(w, req, nil)
}else{
//输出请求的处理日志
fmt.Fprintf(DefaultWriter, reqLog+",req.handle is nil,do nothing.\n")
}
}
}
return
} else if req.Method != http.MethodConnect && path != "/" {
Expand Down
64 changes: 38 additions & 26 deletions router_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -81,9 +81,14 @@ func TestRouterAPI(t *testing.T) {
httpHandler := handlerStruct{&handler}

router := New()
router.GET("/GET", func(w http.ResponseWriter, r *http.Request, _ Params) {
get = true
})

router.GET("/GET",
func(w http.ResponseWriter, r *http.Request, _ Params) {
get = true
},func(w http.ResponseWriter, r *http.Request, _ Params) {
get = true
},nil,
)
router.HEAD("/GET", func(w http.ResponseWriter, r *http.Request, _ Params) {
head = true
})
Expand Down Expand Up @@ -509,56 +514,63 @@ func TestRouterLookup(t *testing.T) {
router := New()

// try empty router first
handle, _, tsr := router.Lookup(http.MethodGet, "/nope")
if handle != nil {
t.Fatalf("Got handle for unregistered pattern: %v", handle)
handles, _, tsr := router.Lookup(http.MethodGet, "/nope")
if handles != nil {
t.Fatalf("Got handle for unregistered pattern: %v", handles)
}
if tsr {
t.Error("Got wrong TSR recommendation!")
}

// insert route and try again
router.GET("/user/:name", wantHandle)
handle, params, _ := router.Lookup(http.MethodGet, "/user/gopher")
if handle == nil {
t.Fatal("Got no handle!")
} else {
handle(nil, nil, nil)
if !routed {
t.Fatal("Routing failed!")
handles, params, _ := router.Lookup(http.MethodGet, "/user/gopher")

for _,handle := range handles {
if handle == nil {
t.Fatal("Got no handle!")
} else {
handle(nil, nil, nil)
if !routed {
t.Fatal("Routing failed!")
}
}
}

if !reflect.DeepEqual(params, wantParams) {
t.Fatalf("Wrong parameter values: want %v, got %v", wantParams, params)
}
routed = false

// route without param
router.GET("/user", wantHandle)
handle, params, _ = router.Lookup(http.MethodGet, "/user")
if handle == nil {
t.Fatal("Got no handle!")
} else {
handle(nil, nil, nil)
if !routed {
t.Fatal("Routing failed!")
handles, params, _ = router.Lookup(http.MethodGet, "/user")
for _,handle := range handles {
if handle == nil {
t.Fatal("Got no handle!")
} else {
handle(nil, nil, nil)
if !routed {
t.Fatal("Routing failed!")
}
}
}

if params != nil {
t.Fatalf("Wrong parameter values: want %v, got %v", nil, params)
}

handle, _, tsr = router.Lookup(http.MethodGet, "/user/gopher/")
if handle != nil {
t.Fatalf("Got handle for unregistered pattern: %v", handle)
handles, _, tsr = router.Lookup(http.MethodGet, "/user/gopher/")
if handles != nil {
t.Fatalf("Got handle for unregistered pattern: %v", handles)
}
if !tsr {
t.Error("Got no TSR recommendation!")
}

handle, _, tsr = router.Lookup(http.MethodGet, "/nope")
if handle != nil {
t.Fatalf("Got handle for unregistered pattern: %v", handle)
handles, _, tsr = router.Lookup(http.MethodGet, "/nope")
if handles != nil {
t.Fatalf("Got handle for unregistered pattern: %v", handles)
}
if tsr {
t.Error("Got wrong TSR recommendation!")
Expand Down
Loading