Skip to content

Commit

Permalink
fixed build break
Browse files Browse the repository at this point in the history
  • Loading branch information
qiangxue committed Aug 10, 2016
1 parent 10380ab commit a060fa9
Show file tree
Hide file tree
Showing 4 changed files with 31 additions and 9 deletions.
2 changes: 1 addition & 1 deletion context_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ func TestContextGetSet(t *testing.T) {
assert.Nil(t, c.Value("abcd"))
assert.Nil(t, c.Value(123))
deadline, ok := c.Deadline()
assert.Nil(t, deadline)
assert.Zero(t, deadline)
assert.False(t, ok)
assert.Nil(t, c.Done())
assert.Nil(t, c.Err())
Expand Down
5 changes: 0 additions & 5 deletions fault/error.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,6 @@ import (
"github.com/go-ozzo/ozzo-routing"
)

type (
// ConvertErrorFunc converts an error into a different format so that it is more appropriate for rendering purpose.
ConvertErrorFunc func(*routing.Context, error) error
)

// ErrorHandler returns a handler that handles errors returned by the handlers following this one.
// If the error implements routing.HTTPError, the handler will set the HTTP status code accordingly.
// Otherwise the HTTP status is set as http.StatusInternalServerError. The handler will also write the error
Expand Down
12 changes: 9 additions & 3 deletions fault/recovery.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,8 @@ type (
// LogFunc should be thread safe.
LogFunc func(format string, a ...interface{})

// HandleErrorFunc is called whenever a panic or error is captured by the middleware.
HandleErrorFunc func(c *routing.Context, err error, log LogFunc)
// ConvertErrorFunc converts an error into a different format so that it is more appropriate for rendering purpose.
ConvertErrorFunc func(*routing.Context, error) error
)

// Recovery returns a handler that handles both panics and errors occurred while servicing an HTTP request.
Expand All @@ -27,6 +27,9 @@ type (
//
// A log function can be provided to log a message whenever an error is handled. If nil, no message will be logged.
//
// An optional error conversion function can also be provided to convert an error into a normalized one
// before sending it to the response.
//
// import (
// "log"
// "github.com/go-ozzo/ozzo-routing"
Expand All @@ -35,13 +38,16 @@ type (
//
// r := routing.New()
// r.Use(fault.Recovery(log.Printf))
func Recovery(logf LogFunc) routing.Handler {
func Recovery(logf LogFunc, errorf ...ConvertErrorFunc) routing.Handler {
handlePanic := PanicHandler(logf)
return func(c *routing.Context) error {
if err := handlePanic(c); err != nil {
if logf != nil {
logf("%v", err)
}
if len(errorf) > 0 {
err = errorf[0](c, err)
}
writeError(c, err)
c.Abort()
}
Expand Down
21 changes: 21 additions & 0 deletions fault/recovery_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,27 @@ func TestRecovery(t *testing.T) {
assert.Equal(t, "123", res.Body.String())
assert.Contains(t, buf.String(), "recovery_test.go")
assert.Contains(t, buf.String(), "123")

buf.Reset()
h = Recovery(getLogger(&buf), convertError)
res = httptest.NewRecorder()
req, _ = http.NewRequest("GET", "/users/", nil)
c = routing.NewContext(res, req, h, handler3, handler2)
assert.Nil(t, c.Next())
assert.Equal(t, http.StatusInternalServerError, res.Code)
assert.Equal(t, "123", res.Body.String())
assert.Contains(t, buf.String(), "recovery_test.go")
assert.Contains(t, buf.String(), "xyz")

buf.Reset()
h = Recovery(getLogger(&buf), convertError)
res = httptest.NewRecorder()
req, _ = http.NewRequest("GET", "/users/", nil)
c = routing.NewContext(res, req, h, handler1, handler2)
assert.Nil(t, c.Next())
assert.Equal(t, http.StatusInternalServerError, res.Code)
assert.Equal(t, "123", res.Body.String())
assert.Equal(t, "abc", buf.String())
}

func getLogger(buf *bytes.Buffer) LogFunc {
Expand Down

0 comments on commit a060fa9

Please sign in to comment.