diff --git a/dial.go b/dial.go index f5e4544b..17305e10 100644 --- a/dial.go +++ b/dial.go @@ -48,6 +48,18 @@ type DialOptions struct { // for CompressionContextTakeover. CompressionThreshold int + // MaxErrorResponseBodyBytes controls how many bytes of the HTTP response body + // are captured and made available via resp.Body when the WebSocket handshake + // fails (i.e. Dial returns a non-nil error after receiving an HTTP response). + // + // Semantics: + // 0 => preserve current behavior and capture up to 1024 bytes (default) + // >0 => capture up to that many bytes + // <0 => do not capture any bytes; resp.Body will remain nil on error + // + // Regardless of this setting, the original HTTP response body is always closed. + MaxErrorResponseBodyBytes int + // OnPingReceived is an optional callback invoked synchronously when a ping frame is received. // // The payload contains the application data of the ping frame. @@ -110,7 +122,8 @@ func (opts *DialOptions) cloneWithDefaults(ctx context.Context) (context.Context // You never need to close resp.Body yourself. // // If an error occurs, the returned response may be non nil. -// However, you can only read the first 1024 bytes of the body. +// By default, up to the first 1024 bytes of the body are available; this limit +// can be adjusted via DialOptions.MaxErrorResponseBodyBytes. // // This function requires at least Go 1.12 as it uses a new feature // in net/http to perform WebSocket handshakes. @@ -147,9 +160,20 @@ func dial(ctx context.Context, urls string, opts *DialOptions, rand io.Reader) ( respBody := resp.Body resp.Body = nil defer func() { - if err != nil { - // We read a bit of the body for easier debugging. - r := io.LimitReader(respBody, 1024) + if err != nil && respBody != nil { + // Capture a limited portion of the response body for easier debugging, + // following the limit configured by MaxErrorResponseBodyBytes. + limit := opts.MaxErrorResponseBodyBytes + if limit == 0 { + limit = 1024 + } + if limit < 0 { + // Do not capture any body bytes; ensure original body is closed. + respBody.Close() + return + } + + r := io.LimitReader(respBody, int64(limit)) timer := time.AfterFunc(time.Second*3, func() { respBody.Close() diff --git a/dial_test.go b/dial_test.go index 492ac6b3..7e7de795 100644 --- a/dial_test.go +++ b/dial_test.go @@ -416,3 +416,134 @@ func TestDialViaProxy(t *testing.T) { assertEcho(t, ctx, c) assertClose(t, c) } + +// Additional tests for error response body capture behavior. +// A tracking body to verify Close is called when capture is disabled. +type trackingBodyDialTest struct { + io.ReadCloser + closed *bool +} + +func (tb trackingBodyDialTest) Close() error { + *tb.closed = true + return tb.ReadCloser.Close() +} + +func TestDial_ErrorResponseBodyCapture_DefaultAndCustom(t *testing.T) { + t.Parallel() + + longBody := strings.Repeat("x", 4096) + + s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusTeapot) + io.WriteString(w, longBody) + })) + defer s.Close() + + ctx, cancel := context.WithTimeout(context.Background(), time.Second*5) + defer cancel() + + // Default behavior (zero value options): capture up to 1024 bytes + _, resp, err := websocket.Dial(ctx, s.URL, nil) + assert.Error(t, err) + if resp == nil { + t.Fatal("expected non-nil resp") + } + assert.Equal(t, "StatusCode", http.StatusTeapot, resp.StatusCode) + + b, rerr := io.ReadAll(resp.Body) + assert.Success(t, rerr) + if len(b) > 1024 { + t.Fatalf("expected captured body length <= 1024, got %d", len(b)) + } + if exp := longBody[:len(b)]; string(b) != exp { + t.Fatalf("unexpected body prefix: expected %d bytes prefix match", len(b)) + } + + // Custom limit (>0) + limit := 200 + _, resp, err = websocket.Dial(ctx, s.URL, &websocket.DialOptions{MaxErrorResponseBodyBytes: limit}) + assert.Error(t, err) + if resp == nil { + t.Fatal("expected non-nil resp") + } + assert.Equal(t, "StatusCode", http.StatusTeapot, resp.StatusCode) + + b, rerr = io.ReadAll(resp.Body) + assert.Success(t, rerr) + if len(b) > limit { + t.Fatalf("expected captured body length <= %d, got %d", limit, len(b)) + } + if exp := longBody[:len(b)]; string(b) != exp { + t.Fatalf("unexpected body prefix: expected %d bytes prefix match", len(b)) + } +} + +func TestDial_ErrorResponseBodyCapture_Disabled_NoBodyWithClose(t *testing.T) { + t.Parallel() + + closed := false + rt := func(r *http.Request) (*http.Response, error) { + // Return a long body and a non-101 status to trigger error path. + return &http.Response{ + StatusCode: http.StatusForbidden, + Body: trackingBodyDialTest{io.NopCloser(strings.NewReader(strings.Repeat("y", 4096))), &closed}, + }, nil + } + + ctx, cancel := context.WithTimeout(context.Background(), time.Second*5) + defer cancel() + + _, resp, err := websocket.Dial(ctx, "ws://example.com", &websocket.DialOptions{ + HTTPClient: mockHTTPClient(rt), + MaxErrorResponseBodyBytes: -1, + }) + assert.Error(t, err) + if resp == nil { + t.Fatal("expected non-nil resp") + } + assert.Equal(t, "StatusCode", http.StatusForbidden, resp.StatusCode) + if resp.Body != nil { + // If any body is present, ensure it's empty. + b, rerr := io.ReadAll(resp.Body) + assert.Success(t, rerr) + if len(b) != 0 { + t.Fatalf("expected no body bytes when capture disabled, got %d", len(b)) + } + } + if !closed { + t.Fatal("expected original body to be closed") + } +} + +func TestDial_ErrorResponseBodyCapture_NilBody(t *testing.T) { + t.Parallel() + + ctx, cancel := context.WithTimeout(context.Background(), time.Second*5) + defer cancel() + + rt := func(r *http.Request) (*http.Response, error) { + // No body returned; ensure Dial does not panic when attempting capture. + return &http.Response{ + StatusCode: http.StatusForbidden, + Body: nil, + }, nil + } + + _, resp, err := websocket.Dial(ctx, "ws://example.com", &websocket.DialOptions{ + HTTPClient: mockHTTPClient(rt), + }) + assert.Error(t, err) + if resp == nil { + t.Fatal("expected non-nil resp") + } + assert.Equal(t, "StatusCode", http.StatusForbidden, resp.StatusCode) + if resp.Body == nil { + return + } + b, rerr := io.ReadAll(resp.Body) + assert.Success(t, rerr) + if len(b) != 0 { + t.Fatalf("expected empty body when original body is nil, got %d bytes", len(b)) + } +}