Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 28 additions & 4 deletions dial.go
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,18 @@ type DialOptions struct {
// for CompressionContextTakeover.
CompressionThreshold int

// MaxErrorResponseBodyBytes controls how many bytes of the HTTP response body
// are captured and made available via resp.Body when the WebSocket handshake
// fails (i.e. Dial returns a non-nil error after receiving an HTTP response).
//
// Semantics:
// 0 => preserve current behavior and capture up to 1024 bytes (default)
// >0 => capture up to that many bytes
// <0 => do not capture any bytes; resp.Body will remain nil on error
//
// Regardless of this setting, the original HTTP response body is always closed.
MaxErrorResponseBodyBytes int

// OnPingReceived is an optional callback invoked synchronously when a ping frame is received.
//
// The payload contains the application data of the ping frame.
Expand Down Expand Up @@ -110,7 +122,8 @@ func (opts *DialOptions) cloneWithDefaults(ctx context.Context) (context.Context
// You never need to close resp.Body yourself.
//
// If an error occurs, the returned response may be non nil.
// However, you can only read the first 1024 bytes of the body.
// By default, up to the first 1024 bytes of the body are available; this limit
// can be adjusted via DialOptions.MaxErrorResponseBodyBytes.
//
// This function requires at least Go 1.12 as it uses a new feature
// in net/http to perform WebSocket handshakes.
Expand Down Expand Up @@ -147,9 +160,20 @@ func dial(ctx context.Context, urls string, opts *DialOptions, rand io.Reader) (
respBody := resp.Body
resp.Body = nil
defer func() {
if err != nil {
// We read a bit of the body for easier debugging.
r := io.LimitReader(respBody, 1024)
if err != nil && respBody != nil {
// Capture a limited portion of the response body for easier debugging,
// following the limit configured by MaxErrorResponseBodyBytes.
limit := opts.MaxErrorResponseBodyBytes
if limit == 0 {
limit = 1024
}
if limit < 0 {
// Do not capture any body bytes; ensure original body is closed.
respBody.Close()
return
}

r := io.LimitReader(respBody, int64(limit))

timer := time.AfterFunc(time.Second*3, func() {
respBody.Close()
Expand Down
131 changes: 131 additions & 0 deletions dial_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -416,3 +416,134 @@ func TestDialViaProxy(t *testing.T) {
assertEcho(t, ctx, c)
assertClose(t, c)
}

// Additional tests for error response body capture behavior.
// A tracking body to verify Close is called when capture is disabled.
type trackingBodyDialTest struct {
io.ReadCloser
closed *bool
}

func (tb trackingBodyDialTest) Close() error {
*tb.closed = true
return tb.ReadCloser.Close()
}

func TestDial_ErrorResponseBodyCapture_DefaultAndCustom(t *testing.T) {
t.Parallel()

longBody := strings.Repeat("x", 4096)

s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusTeapot)
io.WriteString(w, longBody)
}))
defer s.Close()

ctx, cancel := context.WithTimeout(context.Background(), time.Second*5)
defer cancel()

// Default behavior (zero value options): capture up to 1024 bytes
_, resp, err := websocket.Dial(ctx, s.URL, nil)
assert.Error(t, err)
if resp == nil {
t.Fatal("expected non-nil resp")
}
assert.Equal(t, "StatusCode", http.StatusTeapot, resp.StatusCode)

b, rerr := io.ReadAll(resp.Body)
assert.Success(t, rerr)
if len(b) > 1024 {
t.Fatalf("expected captured body length <= 1024, got %d", len(b))
}
if exp := longBody[:len(b)]; string(b) != exp {
t.Fatalf("unexpected body prefix: expected %d bytes prefix match", len(b))
}

// Custom limit (>0)
limit := 200
_, resp, err = websocket.Dial(ctx, s.URL, &websocket.DialOptions{MaxErrorResponseBodyBytes: limit})
assert.Error(t, err)
if resp == nil {
t.Fatal("expected non-nil resp")
}
assert.Equal(t, "StatusCode", http.StatusTeapot, resp.StatusCode)

b, rerr = io.ReadAll(resp.Body)
assert.Success(t, rerr)
if len(b) > limit {
t.Fatalf("expected captured body length <= %d, got %d", limit, len(b))
}
if exp := longBody[:len(b)]; string(b) != exp {
t.Fatalf("unexpected body prefix: expected %d bytes prefix match", len(b))
}
}

func TestDial_ErrorResponseBodyCapture_Disabled_NoBodyWithClose(t *testing.T) {
t.Parallel()

closed := false
rt := func(r *http.Request) (*http.Response, error) {
// Return a long body and a non-101 status to trigger error path.
return &http.Response{
StatusCode: http.StatusForbidden,
Body: trackingBodyDialTest{io.NopCloser(strings.NewReader(strings.Repeat("y", 4096))), &closed},
}, nil
}

ctx, cancel := context.WithTimeout(context.Background(), time.Second*5)
defer cancel()

_, resp, err := websocket.Dial(ctx, "ws://example.com", &websocket.DialOptions{
HTTPClient: mockHTTPClient(rt),
MaxErrorResponseBodyBytes: -1,
})
assert.Error(t, err)
if resp == nil {
t.Fatal("expected non-nil resp")
}
assert.Equal(t, "StatusCode", http.StatusForbidden, resp.StatusCode)
if resp.Body != nil {
// If any body is present, ensure it's empty.
b, rerr := io.ReadAll(resp.Body)
assert.Success(t, rerr)
if len(b) != 0 {
t.Fatalf("expected no body bytes when capture disabled, got %d", len(b))
}
}
if !closed {
t.Fatal("expected original body to be closed")
}
}

func TestDial_ErrorResponseBodyCapture_NilBody(t *testing.T) {
t.Parallel()

ctx, cancel := context.WithTimeout(context.Background(), time.Second*5)
defer cancel()

rt := func(r *http.Request) (*http.Response, error) {
// No body returned; ensure Dial does not panic when attempting capture.
return &http.Response{
StatusCode: http.StatusForbidden,
Body: nil,
}, nil
}

_, resp, err := websocket.Dial(ctx, "ws://example.com", &websocket.DialOptions{
HTTPClient: mockHTTPClient(rt),
})
assert.Error(t, err)
if resp == nil {
t.Fatal("expected non-nil resp")
}
assert.Equal(t, "StatusCode", http.StatusForbidden, resp.StatusCode)
if resp.Body == nil {
return
}
b, rerr := io.ReadAll(resp.Body)
assert.Success(t, rerr)
if len(b) != 0 {
t.Fatalf("expected empty body when original body is nil, got %d bytes", len(b))
}
}
Loading