From 93b2a303c93ca029966a8461ecc6f82deaca1435 Mon Sep 17 00:00:00 2001 From: Shinku <17696928+Shinku-Chen@users.noreply.github.com> Date: Mon, 8 Apr 2024 16:56:09 +0800 Subject: [PATCH] retry redirect to AlreadyVisitedUrl will loop error --- colly.go | 10 ++++++++-- colly_test.go | 22 ++++++++++++++++++++++ 2 files changed, 30 insertions(+), 2 deletions(-) diff --git a/colly.go b/colly.go index ae74b7c3..ffb15581 100644 --- a/colly.go +++ b/colly.go @@ -202,7 +202,10 @@ var collectorCounter uint32 type key int // ProxyURLKey is the context key for the request proxy address. -const ProxyURLKey key = iota +const ( + ProxyURLKey key = iota + CheckRevisitKey +) var ( // ErrForbiddenDomain is the error thrown if visiting @@ -650,6 +653,7 @@ func (c *Collector) scrape(u, method string, depth int, requestData io.Reader, c } // note: once 1.13 is minimum supported Go version, // replace this with http.NewRequestWithContext + c.Context = context.WithValue(c.Context, CheckRevisitKey, checkRevisit) req = req.WithContext(c.Context) if err := c.requestCheck(parsedURL, method, req.GetBody, depth, checkRevisit); err != nil { return err @@ -1382,7 +1386,9 @@ func (c *Collector) checkRedirectFunc() func(req *http.Request, via []*http.Requ return err } if visited { - return &AlreadyVisitedError{req.URL} + if checkRevisit, ok := req.Context().Value(CheckRevisitKey).(bool); !ok || checkRevisit { + return &AlreadyVisitedError{req.URL} + } } err = c.store.Visited(uHash) if err != nil { diff --git a/colly_test.go b/colly_test.go index e70d2774..dbd4f1b8 100644 --- a/colly_test.go +++ b/colly_test.go @@ -1814,3 +1814,25 @@ func TestCollectorPostRetryUnseekable(t *testing.T) { t.Error("OnResponse Retry was called but BodyUnseekable") } } + +func TestRedirectErrorRetry(t *testing.T) { + ts := newTestServer() + defer ts.Close() + c := NewCollector() + c.OnError(func(r *Response, err error) { + if r.Ctx.Get("notFirst") == "" { + r.Ctx.Put("notFirst", "first") + _ = r.Request.Retry() + return + } + if e := (&AlreadyVisitedError{}); errors.As(err, &e) { + t.Error("loop AlreadyVisitedError") + } + + }) + c.OnResponse(func(response *Response) { + //println(1) + }) + c.Visit(ts.URL + "/redirected/") + c.Visit(ts.URL + "/redirect") +}