diff --git a/hijack.go b/hijack.go index 477dd992..2ac58421 100644 --- a/hijack.go +++ b/hijack.go @@ -222,7 +222,21 @@ func (h *Hijack) ContinueRequest(cq *proto.FetchContinueRequest) { // LoadResponse will send request to the real destination and load the response as default response to override. func (h *Hijack) LoadResponse(client *http.Client, loadBody bool) error { - res, err := client.Do(h.Request.req) + req := h.Request.req + if req.Body != nil && req.GetBody == nil { + bodyBytes, err := io.ReadAll(req.Body) + if err != nil { + return err + } + + req.Body.Close() + req.Body = io.NopCloser(bytes.NewBuffer(bodyBytes)) + + req.GetBody = func() (io.ReadCloser, error) { + return io.NopCloser(bytes.NewBuffer(bodyBytes)), nil + } + } + res, err := client.Do(req) if err != nil { return err } diff --git a/hijack_test.go b/hijack_test.go index a5b114ab..1bee0dd4 100644 --- a/hijack_test.go +++ b/hijack_test.go @@ -6,6 +6,7 @@ import ( "io" "mime" "net/http" + "strings" "sync" "testing" "time" @@ -380,3 +381,55 @@ func TestHandleAuth(t *testing.T) { wait2() page2.MustClose() } + +func TestHijackWithRedirectAndLoadResponseGetBody(t *testing.T) { + g := setup(t) + + redirectCount := 0 + s := g.Serve() + s.Mux.HandleFunc("/test", func(w http.ResponseWriter, r *http.Request) { + g.Eq(r.Method, http.MethodPost) + body, err := io.ReadAll(r.Body) + g.Eq(err, nil) + g.Eq("test", string(body)) + + if redirectCount < 3 { + redirectCount++ + w.Header().Set("Location", s.URL("/test")) + w.WriteHeader(http.StatusTemporaryRedirect) + return + } + + w.WriteHeader(http.StatusOK) + w.Write([]byte("OK")) + }) + + router := g.page.HijackRequests() + defer router.MustStop() + + router.MustAdd(s.URL("/test"), func(ctx *rod.Hijack) { + ctx.Request.Req().Body = io.NopCloser(strings.NewReader("test")) + ctx.Request.Req().Method = http.MethodPost + ctx.Request.SetBody([]byte("test")) + + err := ctx.LoadResponse(http.DefaultClient, true) + g.Eq(err, nil) + + body, err := ctx.Request.Req().GetBody() + g.Eq(err, nil) + bodyBytes, err := io.ReadAll(body) + + g.Eq(err, nil) + g.Eq("test", string(bodyBytes)) + + }) + + go router.Run() + + g.page.MustNavigate(s.URL("/test")) + + content := g.page.MustElement("body").MustText() + g.Eq(content, "OK") + g.Eq(redirectCount, 3) + +}