From 6cd1289b433a83c9f79130b524da11af841238d5 Mon Sep 17 00:00:00 2001 From: fushall <19303416+fushall@users.noreply.github.com> Date: Wed, 2 Jul 2025 08:59:22 +0800 Subject: [PATCH 1/3] feat: implement ExtraBody support for ChatCompletionRequest --- chat.go | 3 + chat_test.go | 423 +++++++++++++++++++++++++++++++++++++++++++++++++++ client.go | 42 +++++ 3 files changed, 468 insertions(+) diff --git a/chat.go b/chat.go index c8a3e81b3..18a0b4106 100644 --- a/chat.go +++ b/chat.go @@ -280,6 +280,8 @@ type ChatCompletionRequest struct { // Such as think mode for qwen3. "chat_template_kwargs": {"enable_thinking": false} // https://qwen.readthedocs.io/en/latest/deployment/vllm.html#thinking-non-thinking-modes ChatTemplateKwargs map[string]any `json:"chat_template_kwargs,omitempty"` + // Add additional JSON properties to the request + ExtraBody map[string]any `json:"extra_body,omitempty"` } type StreamOptions struct { @@ -425,6 +427,7 @@ func (c *Client) CreateChatCompletion( http.MethodPost, c.fullURL(urlSuffix, withModel(request.Model)), withBody(request), + withExtraBody(request.ExtraBody), ) if err != nil { return diff --git a/chat_test.go b/chat_test.go index 514706c96..28edcbdbe 100644 --- a/chat_test.go +++ b/chat_test.go @@ -916,6 +916,429 @@ func getChatCompletionBody(r *http.Request) (openai.ChatCompletionRequest, error return completion, nil } +// Helper functions for TestChatCompletionRequestExtraBody to reduce complexity and improve maintainability + +func createBaseChatRequest() openai.ChatCompletionRequest { + return openai.ChatCompletionRequest{ + Model: "gpt-4", + Messages: []openai.ChatCompletionMessage{ + { + Role: openai.ChatMessageRoleUser, + Content: "Hello!", + }, + }, + } +} + +func verifyJSONContainsFields(t *testing.T, jsonStr string, expectedFields map[string]string) { + t.Helper() + for _, expected := range expectedFields { + if !strings.Contains(jsonStr, expected) { + t.Errorf("Expected JSON to contain %s, got: %s", expected, jsonStr) + } + } +} + +func verifyExtraBodyExists(t *testing.T, extraBody map[string]any) { + t.Helper() + if extraBody == nil { + t.Fatal("ExtraBody should not be nil after unmarshaling") + } +} + +func verifyStringField(t *testing.T, extraBody map[string]any, fieldName, expected string) { + t.Helper() + value, exists := extraBody[fieldName] + if !exists { + t.Errorf("%s should exist in ExtraBody", fieldName) + return + } + if value != expected { + t.Errorf("Expected %s to be '%s', got %v", fieldName, expected, value) + } +} + +func verifyFloatField(t *testing.T, extraBody map[string]any, fieldName string, expected float64) { + t.Helper() + value, exists := extraBody[fieldName] + if !exists { + t.Errorf("%s should exist in ExtraBody", fieldName) + return + } + floatValue, ok := value.(float64) + if !ok { + t.Errorf("Expected %s to be float64, got type %T", fieldName, value) + return + } + if floatValue != expected { + t.Errorf("Expected %s to be %v, got %v", fieldName, expected, floatValue) + } +} + +func verifyIntField(t *testing.T, extraBody map[string]any, fieldName string, expected int) { + t.Helper() + value, exists := extraBody[fieldName] + if !exists { + t.Errorf("%s should exist in ExtraBody", fieldName) + return + } + floatValue, ok := value.(float64) + if !ok { + t.Errorf("Expected %s to be float64, got type %T", fieldName, value) + return + } + if int(floatValue) != expected { + t.Errorf("Expected %s to be %d, got %v", fieldName, expected, int(floatValue)) + } +} + +func verifyBoolField(t *testing.T, extraBody map[string]any, fieldName string, expected bool) { + t.Helper() + value, exists := extraBody[fieldName] + if !exists { + t.Errorf("%s should exist in ExtraBody", fieldName) + return + } + boolValue, ok := value.(bool) + if !ok { + t.Errorf("Expected %s to be bool, got type %T", fieldName, value) + return + } + if boolValue != expected { + t.Errorf("Expected %s to be %v, got %v", fieldName, expected, boolValue) + } +} + +func verifyArrayField(t *testing.T, extraBody map[string]any, fieldName string, expected []interface{}) { + t.Helper() + value, exists := extraBody[fieldName] + if !exists { + t.Errorf("%s should exist in ExtraBody", fieldName) + return + } + arrayValue, ok := value.([]interface{}) + if !ok { + t.Errorf("Expected %s to be []interface{}, got type %T", fieldName, value) + return + } + if len(arrayValue) != len(expected) { + t.Errorf("Expected %s to have %d elements, got %d", fieldName, len(expected), len(arrayValue)) + return + } + for i, expectedVal := range expected { + if arrayValue[i] != expectedVal { + t.Errorf("%s[%d]: expected %v, got %v", fieldName, i, expectedVal, arrayValue[i]) + } + } +} + +func verifyNestedObject(t *testing.T, extraBody map[string]any, fieldName, nestedKey, expectedValue string) { + t.Helper() + value, exists := extraBody[fieldName] + if !exists { + t.Errorf("%s should exist in ExtraBody", fieldName) + return + } + objectValue, ok := value.(map[string]interface{}) + if !ok { + t.Errorf("Expected %s to be map[string]interface{}, got type %T", fieldName, value) + return + } + nestedValue, nestedExists := objectValue[nestedKey] + if !nestedExists { + t.Errorf("%s should exist in %s", nestedKey, fieldName) + return + } + if nestedValue != expectedValue { + t.Errorf("Expected %s.%s to be '%s', got %v", fieldName, nestedKey, expectedValue, nestedValue) + } +} + +func verifyDeepNesting(t *testing.T, extraBody map[string]any) { + t.Helper() + deepNesting, ok := extraBody["deep_nesting"].(map[string]interface{}) + if !ok { + t.Error("deep_nesting should be map[string]interface{}") + return + } + level1, ok := deepNesting["level1"].(map[string]interface{}) + if !ok { + t.Error("level1 should be map[string]interface{}") + return + } + level2, ok := level1["level2"].(map[string]interface{}) + if !ok { + t.Error("level2 should be map[string]interface{}") + return + } + value, ok := level2["value"].(string) + if !ok { + t.Error("deep nested value should be string") + return + } + if value != "deep_value" { + t.Errorf("Expected deep nested value to be 'deep_value', got %v", value) + } +} + +func testExtraBodySerialization(t *testing.T) { + t.Helper() + req := createBaseChatRequest() + req.ExtraBody = map[string]any{ + "custom_param": "custom_value", + "numeric_param": 42, + "boolean_param": true, + "array_param": []string{"item1", "item2"}, + "object_param": map[string]any{ + "nested_key": "nested_value", + }, + } + + data, err := json.Marshal(req) + checks.NoError(t, err, "Failed to marshal request with ExtraBody") + + // Verify JSON serialization + expectedFields := map[string]string{ + "extra_body": `"extra_body"`, + "custom_param": `"custom_param":"custom_value"`, + "numeric_param": `"numeric_param":42`, + "boolean_param": `"boolean_param":true`, + } + verifyJSONContainsFields(t, string(data), expectedFields) + + // Verify deserialization + var unmarshaled openai.ChatCompletionRequest + err = json.Unmarshal(data, &unmarshaled) + checks.NoError(t, err, "Failed to unmarshal request with ExtraBody") + + verifyExtraBodyExists(t, unmarshaled.ExtraBody) + verifyStringField(t, unmarshaled.ExtraBody, "custom_param", "custom_value") + verifyIntField(t, unmarshaled.ExtraBody, "numeric_param", 42) + verifyBoolField(t, unmarshaled.ExtraBody, "boolean_param", true) + verifyArrayField(t, unmarshaled.ExtraBody, "array_param", []interface{}{"item1", "item2"}) + verifyNestedObject(t, unmarshaled.ExtraBody, "object_param", "nested_key", "nested_value") +} + +func testEmptyExtraBody(t *testing.T) { + t.Helper() + req := createBaseChatRequest() + req.ExtraBody = map[string]any{} + + data, err := json.Marshal(req) + checks.NoError(t, err, "Failed to marshal request with empty ExtraBody") + + if strings.Contains(string(data), `"extra_body"`) { + t.Error("Empty ExtraBody should be omitted from JSON") + } + + var unmarshaled openai.ChatCompletionRequest + err = json.Unmarshal(data, &unmarshaled) + checks.NoError(t, err, "Failed to unmarshal request with empty ExtraBody") + + if unmarshaled.ExtraBody != nil { + t.Error("ExtraBody should be nil when empty ExtraBody is omitted from JSON") + } +} + +func testNilExtraBody(t *testing.T) { + t.Helper() + req := createBaseChatRequest() + req.ExtraBody = nil + + data, err := json.Marshal(req) + checks.NoError(t, err, "Failed to marshal request with nil ExtraBody") + + if strings.Contains(string(data), `"extra_body"`) { + t.Error("Nil ExtraBody should be omitted from JSON") + } + + var unmarshaled openai.ChatCompletionRequest + err = json.Unmarshal(data, &unmarshaled) + checks.NoError(t, err, "Failed to unmarshal request with nil ExtraBody") + + if unmarshaled.ExtraBody != nil { + t.Error("ExtraBody should remain nil when not present in JSON") + } +} + +func testComplexDataTypes(t *testing.T) { + t.Helper() + req := createBaseChatRequest() + req.ExtraBody = map[string]any{ + "float_param": 3.14159, + "negative_int": -42, + "zero_value": 0, + "empty_string": "", + "unicode_text": "你好世界", + "special_chars": "!@#$%^&*()", + "nested_arrays": []any{[]string{"a", "b"}, []int{1, 2, 3}}, + "mixed_array": []any{"string", 42, true, nil}, + "deep_nesting": map[string]any{ + "level1": map[string]any{ + "level2": map[string]any{ + "value": "deep_value", + }, + }, + }, + } + + data, err := json.Marshal(req) + checks.NoError(t, err, "Failed to marshal request with complex ExtraBody") + + var unmarshaled openai.ChatCompletionRequest + err = json.Unmarshal(data, &unmarshaled) + checks.NoError(t, err, "Failed to unmarshal request with complex ExtraBody") + + verifyExtraBodyExists(t, unmarshaled.ExtraBody) + verifyFloatField(t, unmarshaled.ExtraBody, "float_param", 3.14159) + verifyIntField(t, unmarshaled.ExtraBody, "negative_int", -42) + verifyStringField(t, unmarshaled.ExtraBody, "unicode_text", "你好世界") + verifyArrayField(t, unmarshaled.ExtraBody, "mixed_array", []interface{}{"string", float64(42), true, nil}) + verifyDeepNesting(t, unmarshaled.ExtraBody) +} + +func testInvalidJSONHandling(t *testing.T) { + t.Helper() + invalidJSON := `{"model":"gpt-4","extra_body":{"invalid_json":}}` + var req openai.ChatCompletionRequest + err := json.Unmarshal([]byte(invalidJSON), &req) + if err == nil { + t.Error("Expected error when unmarshaling invalid JSON, but got nil") + } +} + +func testExtraBodyFieldConflicts(t *testing.T) { + t.Helper() + req := createBaseChatRequest() + req.MaxTokens = 100 + req.ExtraBody = map[string]any{ + "model": "should-not-override", + "max_tokens": 9999, + "custom_field": "custom_value", + } + + data, err := json.Marshal(req) + checks.NoError(t, err, "Failed to marshal request with field conflicts in ExtraBody") + + var jsonMap map[string]any + err = json.Unmarshal(data, &jsonMap) + checks.NoError(t, err, "Failed to unmarshal JSON to generic map") + + if jsonMap["model"] != "gpt-4" { + t.Errorf("Standard model field should be 'gpt-4', got %v", jsonMap["model"]) + } + + maxTokens, ok := jsonMap["max_tokens"].(float64) + if !ok || int(maxTokens) != 100 { + t.Errorf("Standard max_tokens field should be 100, got %v", jsonMap["max_tokens"]) + } + + extraBody, ok := jsonMap["extra_body"].(map[string]interface{}) + if !ok { + t.Error("ExtraBody should be present in JSON") + return + } + customField, ok := extraBody["custom_field"].(string) + if !ok || customField != "custom_value" { + t.Errorf("Expected custom_field to be 'custom_value', got %v", customField) + } +} + +func TestChatCompletionRequestExtraBody(t *testing.T) { + t.Run("ExtraBodySerialization", testExtraBodySerialization) + t.Run("EmptyExtraBody", testEmptyExtraBody) + t.Run("NilExtraBody", testNilExtraBody) + t.Run("ComplexDataTypes", testComplexDataTypes) + t.Run("InvalidJSONHandling", testInvalidJSONHandling) + t.Run("ExtraBodyFieldConflicts", testExtraBodyFieldConflicts) +} + +func TestChatCompletionWithExtraBody(t *testing.T) { + client, server, teardown := setupOpenAITestServer() + defer teardown() + + // Set up a handler that verifies ExtraBody fields are merged into the request body + server.RegisterHandler("/v1/chat/completions", func(w http.ResponseWriter, r *http.Request) { + var reqBody map[string]any + body, err := io.ReadAll(r.Body) + if err != nil { + http.Error(w, "Failed to read request body", http.StatusInternalServerError) + return + } + + err = json.Unmarshal(body, &reqBody) + if err != nil { + http.Error(w, "Failed to parse request body", http.StatusInternalServerError) + return + } + + // Verify that ExtraBody fields are merged at the top level + if reqBody["custom_parameter"] != "test_value" { + http.Error(w, "ExtraBody custom_parameter not found in request", http.StatusBadRequest) + return + } + if reqBody["additional_config"] != true { + http.Error(w, "ExtraBody additional_config not found in request", http.StatusBadRequest) + return + } + + // Verify standard fields are still present + if reqBody["model"] != "gpt-4" { + http.Error(w, "Standard model field not found", http.StatusBadRequest) + return + } + + // Return a mock response + res := openai.ChatCompletionResponse{ + ID: "test-id", + Object: "chat.completion", + Created: time.Now().Unix(), + Model: "gpt-4", + Choices: []openai.ChatCompletionChoice{ + { + Index: 0, + Message: openai.ChatCompletionMessage{ + Role: openai.ChatMessageRoleAssistant, + Content: "Hello! I received your message with extra parameters.", + }, + FinishReason: openai.FinishReasonStop, + }, + }, + Usage: openai.Usage{ + PromptTokens: 10, + CompletionTokens: 20, + TotalTokens: 30, + }, + } + + w.Header().Set("Content-Type", "application/json") + err = json.NewEncoder(w).Encode(res) + if err != nil { + http.Error(w, "Failed to encode response", http.StatusInternalServerError) + return + } + }) + + // Test ChatCompletion with ExtraBody + _, err := client.CreateChatCompletion(context.Background(), openai.ChatCompletionRequest{ + Model: "gpt-4", + Messages: []openai.ChatCompletionMessage{ + { + Role: openai.ChatMessageRoleUser, + Content: "Hello!", + }, + }, + ExtraBody: map[string]any{ + "custom_parameter": "test_value", + "additional_config": true, + "numeric_setting": 123, + "array_setting": []string{"option1", "option2"}, + }, + }) + + checks.NoError(t, err, "CreateChatCompletion with ExtraBody should not fail") +} + func TestFinishReason(t *testing.T) { c := &openai.ChatCompletionChoice{ FinishReason: openai.FinishReasonNull, diff --git a/client.go b/client.go index cef375348..edb66ed00 100644 --- a/client.go +++ b/client.go @@ -84,6 +84,48 @@ func withBody(body any) requestOption { } } +func withExtraBody(extraBody map[string]any) requestOption { + return func(args *requestOptions) { + if len(extraBody) == 0 { + return // No extra body to merge + } + + // Check if args.body is already a map[string]any + if bodyMap, ok := args.body.(map[string]any); ok { + // If it's already a map[string]any, directly add extraBody fields + for key, value := range extraBody { + bodyMap[key] = value + } + return + } + + // If args.body is a struct, convert it to map[string]any first + if args.body != nil { + var err error + var jsonBytes []byte + // Marshal the struct to JSON bytes + jsonBytes, err = json.Marshal(args.body) + if err != nil { + return // If marshaling fails, skip merging ExtraBody + } + + // Unmarshal JSON bytes to map[string]any + var bodyMap map[string]any + if err = json.Unmarshal(jsonBytes, &bodyMap); err != nil { + return // If unmarshaling fails, skip merging ExtraBody + } + + // Merge ExtraBody fields into the map + for key, value := range extraBody { + bodyMap[key] = value + } + + // Replace args.body with the merged map + args.body = bodyMap + } + } +} + func withContentType(contentType string) requestOption { return func(args *requestOptions) { args.header.Set("Content-Type", contentType) From 7441e24b06fb9db3005610187885d4c6dcbeb7cc Mon Sep 17 00:00:00 2001 From: fushall Date: Wed, 2 Jul 2025 10:38:21 +0800 Subject: [PATCH 2/3] Revert "feat: implement ExtraBody support for ChatCompletionRequest" This reverts commit 6cd1289b433a83c9f79130b524da11af841238d5. --- chat.go | 3 - chat_test.go | 423 --------------------------------------------------- client.go | 42 ----- 3 files changed, 468 deletions(-) diff --git a/chat.go b/chat.go index 18a0b4106..c8a3e81b3 100644 --- a/chat.go +++ b/chat.go @@ -280,8 +280,6 @@ type ChatCompletionRequest struct { // Such as think mode for qwen3. "chat_template_kwargs": {"enable_thinking": false} // https://qwen.readthedocs.io/en/latest/deployment/vllm.html#thinking-non-thinking-modes ChatTemplateKwargs map[string]any `json:"chat_template_kwargs,omitempty"` - // Add additional JSON properties to the request - ExtraBody map[string]any `json:"extra_body,omitempty"` } type StreamOptions struct { @@ -427,7 +425,6 @@ func (c *Client) CreateChatCompletion( http.MethodPost, c.fullURL(urlSuffix, withModel(request.Model)), withBody(request), - withExtraBody(request.ExtraBody), ) if err != nil { return diff --git a/chat_test.go b/chat_test.go index 28edcbdbe..514706c96 100644 --- a/chat_test.go +++ b/chat_test.go @@ -916,429 +916,6 @@ func getChatCompletionBody(r *http.Request) (openai.ChatCompletionRequest, error return completion, nil } -// Helper functions for TestChatCompletionRequestExtraBody to reduce complexity and improve maintainability - -func createBaseChatRequest() openai.ChatCompletionRequest { - return openai.ChatCompletionRequest{ - Model: "gpt-4", - Messages: []openai.ChatCompletionMessage{ - { - Role: openai.ChatMessageRoleUser, - Content: "Hello!", - }, - }, - } -} - -func verifyJSONContainsFields(t *testing.T, jsonStr string, expectedFields map[string]string) { - t.Helper() - for _, expected := range expectedFields { - if !strings.Contains(jsonStr, expected) { - t.Errorf("Expected JSON to contain %s, got: %s", expected, jsonStr) - } - } -} - -func verifyExtraBodyExists(t *testing.T, extraBody map[string]any) { - t.Helper() - if extraBody == nil { - t.Fatal("ExtraBody should not be nil after unmarshaling") - } -} - -func verifyStringField(t *testing.T, extraBody map[string]any, fieldName, expected string) { - t.Helper() - value, exists := extraBody[fieldName] - if !exists { - t.Errorf("%s should exist in ExtraBody", fieldName) - return - } - if value != expected { - t.Errorf("Expected %s to be '%s', got %v", fieldName, expected, value) - } -} - -func verifyFloatField(t *testing.T, extraBody map[string]any, fieldName string, expected float64) { - t.Helper() - value, exists := extraBody[fieldName] - if !exists { - t.Errorf("%s should exist in ExtraBody", fieldName) - return - } - floatValue, ok := value.(float64) - if !ok { - t.Errorf("Expected %s to be float64, got type %T", fieldName, value) - return - } - if floatValue != expected { - t.Errorf("Expected %s to be %v, got %v", fieldName, expected, floatValue) - } -} - -func verifyIntField(t *testing.T, extraBody map[string]any, fieldName string, expected int) { - t.Helper() - value, exists := extraBody[fieldName] - if !exists { - t.Errorf("%s should exist in ExtraBody", fieldName) - return - } - floatValue, ok := value.(float64) - if !ok { - t.Errorf("Expected %s to be float64, got type %T", fieldName, value) - return - } - if int(floatValue) != expected { - t.Errorf("Expected %s to be %d, got %v", fieldName, expected, int(floatValue)) - } -} - -func verifyBoolField(t *testing.T, extraBody map[string]any, fieldName string, expected bool) { - t.Helper() - value, exists := extraBody[fieldName] - if !exists { - t.Errorf("%s should exist in ExtraBody", fieldName) - return - } - boolValue, ok := value.(bool) - if !ok { - t.Errorf("Expected %s to be bool, got type %T", fieldName, value) - return - } - if boolValue != expected { - t.Errorf("Expected %s to be %v, got %v", fieldName, expected, boolValue) - } -} - -func verifyArrayField(t *testing.T, extraBody map[string]any, fieldName string, expected []interface{}) { - t.Helper() - value, exists := extraBody[fieldName] - if !exists { - t.Errorf("%s should exist in ExtraBody", fieldName) - return - } - arrayValue, ok := value.([]interface{}) - if !ok { - t.Errorf("Expected %s to be []interface{}, got type %T", fieldName, value) - return - } - if len(arrayValue) != len(expected) { - t.Errorf("Expected %s to have %d elements, got %d", fieldName, len(expected), len(arrayValue)) - return - } - for i, expectedVal := range expected { - if arrayValue[i] != expectedVal { - t.Errorf("%s[%d]: expected %v, got %v", fieldName, i, expectedVal, arrayValue[i]) - } - } -} - -func verifyNestedObject(t *testing.T, extraBody map[string]any, fieldName, nestedKey, expectedValue string) { - t.Helper() - value, exists := extraBody[fieldName] - if !exists { - t.Errorf("%s should exist in ExtraBody", fieldName) - return - } - objectValue, ok := value.(map[string]interface{}) - if !ok { - t.Errorf("Expected %s to be map[string]interface{}, got type %T", fieldName, value) - return - } - nestedValue, nestedExists := objectValue[nestedKey] - if !nestedExists { - t.Errorf("%s should exist in %s", nestedKey, fieldName) - return - } - if nestedValue != expectedValue { - t.Errorf("Expected %s.%s to be '%s', got %v", fieldName, nestedKey, expectedValue, nestedValue) - } -} - -func verifyDeepNesting(t *testing.T, extraBody map[string]any) { - t.Helper() - deepNesting, ok := extraBody["deep_nesting"].(map[string]interface{}) - if !ok { - t.Error("deep_nesting should be map[string]interface{}") - return - } - level1, ok := deepNesting["level1"].(map[string]interface{}) - if !ok { - t.Error("level1 should be map[string]interface{}") - return - } - level2, ok := level1["level2"].(map[string]interface{}) - if !ok { - t.Error("level2 should be map[string]interface{}") - return - } - value, ok := level2["value"].(string) - if !ok { - t.Error("deep nested value should be string") - return - } - if value != "deep_value" { - t.Errorf("Expected deep nested value to be 'deep_value', got %v", value) - } -} - -func testExtraBodySerialization(t *testing.T) { - t.Helper() - req := createBaseChatRequest() - req.ExtraBody = map[string]any{ - "custom_param": "custom_value", - "numeric_param": 42, - "boolean_param": true, - "array_param": []string{"item1", "item2"}, - "object_param": map[string]any{ - "nested_key": "nested_value", - }, - } - - data, err := json.Marshal(req) - checks.NoError(t, err, "Failed to marshal request with ExtraBody") - - // Verify JSON serialization - expectedFields := map[string]string{ - "extra_body": `"extra_body"`, - "custom_param": `"custom_param":"custom_value"`, - "numeric_param": `"numeric_param":42`, - "boolean_param": `"boolean_param":true`, - } - verifyJSONContainsFields(t, string(data), expectedFields) - - // Verify deserialization - var unmarshaled openai.ChatCompletionRequest - err = json.Unmarshal(data, &unmarshaled) - checks.NoError(t, err, "Failed to unmarshal request with ExtraBody") - - verifyExtraBodyExists(t, unmarshaled.ExtraBody) - verifyStringField(t, unmarshaled.ExtraBody, "custom_param", "custom_value") - verifyIntField(t, unmarshaled.ExtraBody, "numeric_param", 42) - verifyBoolField(t, unmarshaled.ExtraBody, "boolean_param", true) - verifyArrayField(t, unmarshaled.ExtraBody, "array_param", []interface{}{"item1", "item2"}) - verifyNestedObject(t, unmarshaled.ExtraBody, "object_param", "nested_key", "nested_value") -} - -func testEmptyExtraBody(t *testing.T) { - t.Helper() - req := createBaseChatRequest() - req.ExtraBody = map[string]any{} - - data, err := json.Marshal(req) - checks.NoError(t, err, "Failed to marshal request with empty ExtraBody") - - if strings.Contains(string(data), `"extra_body"`) { - t.Error("Empty ExtraBody should be omitted from JSON") - } - - var unmarshaled openai.ChatCompletionRequest - err = json.Unmarshal(data, &unmarshaled) - checks.NoError(t, err, "Failed to unmarshal request with empty ExtraBody") - - if unmarshaled.ExtraBody != nil { - t.Error("ExtraBody should be nil when empty ExtraBody is omitted from JSON") - } -} - -func testNilExtraBody(t *testing.T) { - t.Helper() - req := createBaseChatRequest() - req.ExtraBody = nil - - data, err := json.Marshal(req) - checks.NoError(t, err, "Failed to marshal request with nil ExtraBody") - - if strings.Contains(string(data), `"extra_body"`) { - t.Error("Nil ExtraBody should be omitted from JSON") - } - - var unmarshaled openai.ChatCompletionRequest - err = json.Unmarshal(data, &unmarshaled) - checks.NoError(t, err, "Failed to unmarshal request with nil ExtraBody") - - if unmarshaled.ExtraBody != nil { - t.Error("ExtraBody should remain nil when not present in JSON") - } -} - -func testComplexDataTypes(t *testing.T) { - t.Helper() - req := createBaseChatRequest() - req.ExtraBody = map[string]any{ - "float_param": 3.14159, - "negative_int": -42, - "zero_value": 0, - "empty_string": "", - "unicode_text": "你好世界", - "special_chars": "!@#$%^&*()", - "nested_arrays": []any{[]string{"a", "b"}, []int{1, 2, 3}}, - "mixed_array": []any{"string", 42, true, nil}, - "deep_nesting": map[string]any{ - "level1": map[string]any{ - "level2": map[string]any{ - "value": "deep_value", - }, - }, - }, - } - - data, err := json.Marshal(req) - checks.NoError(t, err, "Failed to marshal request with complex ExtraBody") - - var unmarshaled openai.ChatCompletionRequest - err = json.Unmarshal(data, &unmarshaled) - checks.NoError(t, err, "Failed to unmarshal request with complex ExtraBody") - - verifyExtraBodyExists(t, unmarshaled.ExtraBody) - verifyFloatField(t, unmarshaled.ExtraBody, "float_param", 3.14159) - verifyIntField(t, unmarshaled.ExtraBody, "negative_int", -42) - verifyStringField(t, unmarshaled.ExtraBody, "unicode_text", "你好世界") - verifyArrayField(t, unmarshaled.ExtraBody, "mixed_array", []interface{}{"string", float64(42), true, nil}) - verifyDeepNesting(t, unmarshaled.ExtraBody) -} - -func testInvalidJSONHandling(t *testing.T) { - t.Helper() - invalidJSON := `{"model":"gpt-4","extra_body":{"invalid_json":}}` - var req openai.ChatCompletionRequest - err := json.Unmarshal([]byte(invalidJSON), &req) - if err == nil { - t.Error("Expected error when unmarshaling invalid JSON, but got nil") - } -} - -func testExtraBodyFieldConflicts(t *testing.T) { - t.Helper() - req := createBaseChatRequest() - req.MaxTokens = 100 - req.ExtraBody = map[string]any{ - "model": "should-not-override", - "max_tokens": 9999, - "custom_field": "custom_value", - } - - data, err := json.Marshal(req) - checks.NoError(t, err, "Failed to marshal request with field conflicts in ExtraBody") - - var jsonMap map[string]any - err = json.Unmarshal(data, &jsonMap) - checks.NoError(t, err, "Failed to unmarshal JSON to generic map") - - if jsonMap["model"] != "gpt-4" { - t.Errorf("Standard model field should be 'gpt-4', got %v", jsonMap["model"]) - } - - maxTokens, ok := jsonMap["max_tokens"].(float64) - if !ok || int(maxTokens) != 100 { - t.Errorf("Standard max_tokens field should be 100, got %v", jsonMap["max_tokens"]) - } - - extraBody, ok := jsonMap["extra_body"].(map[string]interface{}) - if !ok { - t.Error("ExtraBody should be present in JSON") - return - } - customField, ok := extraBody["custom_field"].(string) - if !ok || customField != "custom_value" { - t.Errorf("Expected custom_field to be 'custom_value', got %v", customField) - } -} - -func TestChatCompletionRequestExtraBody(t *testing.T) { - t.Run("ExtraBodySerialization", testExtraBodySerialization) - t.Run("EmptyExtraBody", testEmptyExtraBody) - t.Run("NilExtraBody", testNilExtraBody) - t.Run("ComplexDataTypes", testComplexDataTypes) - t.Run("InvalidJSONHandling", testInvalidJSONHandling) - t.Run("ExtraBodyFieldConflicts", testExtraBodyFieldConflicts) -} - -func TestChatCompletionWithExtraBody(t *testing.T) { - client, server, teardown := setupOpenAITestServer() - defer teardown() - - // Set up a handler that verifies ExtraBody fields are merged into the request body - server.RegisterHandler("/v1/chat/completions", func(w http.ResponseWriter, r *http.Request) { - var reqBody map[string]any - body, err := io.ReadAll(r.Body) - if err != nil { - http.Error(w, "Failed to read request body", http.StatusInternalServerError) - return - } - - err = json.Unmarshal(body, &reqBody) - if err != nil { - http.Error(w, "Failed to parse request body", http.StatusInternalServerError) - return - } - - // Verify that ExtraBody fields are merged at the top level - if reqBody["custom_parameter"] != "test_value" { - http.Error(w, "ExtraBody custom_parameter not found in request", http.StatusBadRequest) - return - } - if reqBody["additional_config"] != true { - http.Error(w, "ExtraBody additional_config not found in request", http.StatusBadRequest) - return - } - - // Verify standard fields are still present - if reqBody["model"] != "gpt-4" { - http.Error(w, "Standard model field not found", http.StatusBadRequest) - return - } - - // Return a mock response - res := openai.ChatCompletionResponse{ - ID: "test-id", - Object: "chat.completion", - Created: time.Now().Unix(), - Model: "gpt-4", - Choices: []openai.ChatCompletionChoice{ - { - Index: 0, - Message: openai.ChatCompletionMessage{ - Role: openai.ChatMessageRoleAssistant, - Content: "Hello! I received your message with extra parameters.", - }, - FinishReason: openai.FinishReasonStop, - }, - }, - Usage: openai.Usage{ - PromptTokens: 10, - CompletionTokens: 20, - TotalTokens: 30, - }, - } - - w.Header().Set("Content-Type", "application/json") - err = json.NewEncoder(w).Encode(res) - if err != nil { - http.Error(w, "Failed to encode response", http.StatusInternalServerError) - return - } - }) - - // Test ChatCompletion with ExtraBody - _, err := client.CreateChatCompletion(context.Background(), openai.ChatCompletionRequest{ - Model: "gpt-4", - Messages: []openai.ChatCompletionMessage{ - { - Role: openai.ChatMessageRoleUser, - Content: "Hello!", - }, - }, - ExtraBody: map[string]any{ - "custom_parameter": "test_value", - "additional_config": true, - "numeric_setting": 123, - "array_setting": []string{"option1", "option2"}, - }, - }) - - checks.NoError(t, err, "CreateChatCompletion with ExtraBody should not fail") -} - func TestFinishReason(t *testing.T) { c := &openai.ChatCompletionChoice{ FinishReason: openai.FinishReasonNull, diff --git a/client.go b/client.go index edb66ed00..cef375348 100644 --- a/client.go +++ b/client.go @@ -84,48 +84,6 @@ func withBody(body any) requestOption { } } -func withExtraBody(extraBody map[string]any) requestOption { - return func(args *requestOptions) { - if len(extraBody) == 0 { - return // No extra body to merge - } - - // Check if args.body is already a map[string]any - if bodyMap, ok := args.body.(map[string]any); ok { - // If it's already a map[string]any, directly add extraBody fields - for key, value := range extraBody { - bodyMap[key] = value - } - return - } - - // If args.body is a struct, convert it to map[string]any first - if args.body != nil { - var err error - var jsonBytes []byte - // Marshal the struct to JSON bytes - jsonBytes, err = json.Marshal(args.body) - if err != nil { - return // If marshaling fails, skip merging ExtraBody - } - - // Unmarshal JSON bytes to map[string]any - var bodyMap map[string]any - if err = json.Unmarshal(jsonBytes, &bodyMap); err != nil { - return // If unmarshaling fails, skip merging ExtraBody - } - - // Merge ExtraBody fields into the map - for key, value := range extraBody { - bodyMap[key] = value - } - - // Replace args.body with the merged map - args.body = bodyMap - } - } -} - func withContentType(contentType string) requestOption { return func(args *requestOptions) { args.header.Set("Content-Type", contentType) From 9c5d0b6c7b4cf85ed941c456e91e3a1e4345b9b1 Mon Sep 17 00:00:00 2001 From: fushall <19303416+fushall@users.noreply.github.com> Date: Wed, 2 Jul 2025 10:45:31 +0800 Subject: [PATCH 3/3] feat: implement ExtraBody support for ChatCompletionRequest --- chat.go | 3 + chat_stream.go | 1 + chat_stream_test.go | 483 ++++++++++++++++++++++++++++++++++++++++++++ chat_test.go | 423 ++++++++++++++++++++++++++++++++++++++ client.go | 42 ++++ 5 files changed, 952 insertions(+) diff --git a/chat.go b/chat.go index c8a3e81b3..18a0b4106 100644 --- a/chat.go +++ b/chat.go @@ -280,6 +280,8 @@ type ChatCompletionRequest struct { // Such as think mode for qwen3. "chat_template_kwargs": {"enable_thinking": false} // https://qwen.readthedocs.io/en/latest/deployment/vllm.html#thinking-non-thinking-modes ChatTemplateKwargs map[string]any `json:"chat_template_kwargs,omitempty"` + // Add additional JSON properties to the request + ExtraBody map[string]any `json:"extra_body,omitempty"` } type StreamOptions struct { @@ -425,6 +427,7 @@ func (c *Client) CreateChatCompletion( http.MethodPost, c.fullURL(urlSuffix, withModel(request.Model)), withBody(request), + withExtraBody(request.ExtraBody), ) if err != nil { return diff --git a/chat_stream.go b/chat_stream.go index 80d16cc63..89c335d65 100644 --- a/chat_stream.go +++ b/chat_stream.go @@ -96,6 +96,7 @@ func (c *Client) CreateChatCompletionStream( http.MethodPost, c.fullURL(urlSuffix, withModel(request.Model)), withBody(request), + withExtraBody(request.ExtraBody), ) if err != nil { return nil, err diff --git a/chat_stream_test.go b/chat_stream_test.go index eabb0f3a2..c468d9452 100644 --- a/chat_stream_test.go +++ b/chat_stream_test.go @@ -1021,3 +1021,486 @@ func compareChatStreamResponseChoices(c1, c2 openai.ChatCompletionStreamChoice) } return true } + +// Helper functions for TestCreateChatCompletionStreamExtraBody to reduce complexity and improve maintainability + +func deepEqual(a, b interface{}) bool { + if a == nil && b == nil { + return true + } + if a == nil || b == nil { + return false + } + + // Use reflection for deep comparison to handle maps, slices, etc. + aJSON, aErr := json.Marshal(a) + bJSON, bErr := json.Marshal(b) + if aErr != nil || bErr != nil { + return false + } + return string(aJSON) == string(bJSON) +} + +func createBaseChatStreamRequest() openai.ChatCompletionRequest { + return openai.ChatCompletionRequest{ + Model: "gpt-4", + Messages: []openai.ChatCompletionMessage{ + { + Role: openai.ChatMessageRoleUser, + Content: "Hello!", + }, + }, + Stream: true, + } +} + +func validateExtraBodyFields(reqBody map[string]any, expectedExtraFields map[string]any) error { + for key, expectedValue := range expectedExtraFields { + actualValue, exists := reqBody[key] + if !exists { + return fmt.Errorf("ExtraBody field %s not found in request", key) + } + + // Handle complex types comparison safely + if !deepEqual(actualValue, expectedValue) { + return fmt.Errorf("ExtraBody field %s value mismatch: expected %v, got %v", + key, expectedValue, actualValue) + } + } + return nil +} + +func validateStandardFields(reqBody map[string]any) error { + if reqBody["model"] != "gpt-4" { + return fmt.Errorf("standard model field not found") + } + + if reqBody["stream"] != true { + return fmt.Errorf("stream field should be true") + } + return nil +} + +func parseRequestBody(r *http.Request) (map[string]any, error) { + var reqBody map[string]any + body, err := io.ReadAll(r.Body) + if err != nil { + return nil, fmt.Errorf("failed to read request body: %w", err) + } + + err = json.Unmarshal(body, &reqBody) + if err != nil { + return nil, fmt.Errorf("failed to parse request body: %w", err) + } + return reqBody, nil +} + +func writeStreamingResponse(t *testing.T, w http.ResponseWriter) { + t.Helper() + w.Header().Set("Content-Type", "text/event-stream") + + responses := []string{ + `{"id":"test-1","object":"chat.completion.chunk","created":1598069254,` + + `"model":"gpt-4","system_fingerprint":"fp_test",` + + `"choices":[{"index":0,"delta":{"content":"Hello"},"finish_reason":null}]}`, + `{"id":"test-2","object":"chat.completion.chunk","created":1598069255,` + + `"model":"gpt-4","system_fingerprint":"fp_test",` + + `"choices":[{"index":0,"delta":{},"finish_reason":"stop"}]}`, + } + + dataBytes := []byte{} + for _, response := range responses { + dataBytes = append(dataBytes, []byte("event: message\n")...) + dataBytes = append(dataBytes, []byte("data: "+response+"\n\n")...) + } + + dataBytes = append(dataBytes, []byte("event: done\n")...) + dataBytes = append(dataBytes, []byte("data: [DONE]\n\n")...) + + _, err := w.Write(dataBytes) + if err != nil { + t.Errorf("Failed to write response: %v", err) + } +} + +func createStreamHandler(t *testing.T, expectedExtraFields map[string]any) func( + w http.ResponseWriter, r *http.Request) { + t.Helper() + return func(w http.ResponseWriter, r *http.Request) { + if expectedExtraFields == nil { + writeStreamingResponse(t, w) + return + } + + reqBody, err := parseRequestBody(r) + if err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + + if validationErr := validateExtraBodyFields(reqBody, expectedExtraFields); validationErr != nil { + http.Error(w, validationErr.Error(), http.StatusBadRequest) + return + } + + if standardErr := validateStandardFields(reqBody); standardErr != nil { + http.Error(w, standardErr.Error(), http.StatusBadRequest) + return + } + + writeStreamingResponse(t, w) + } +} + +func verifyStreamResponse(t *testing.T, stream *openai.ChatCompletionStream, + expectedResponses []openai.ChatCompletionStreamResponse) { + t.Helper() + + if stream == nil { + t.Fatal("Stream is nil - cannot verify response") + return + } + + defer stream.Close() + + for ix, expectedResponse := range expectedResponses { + receivedResponse, streamErr := stream.Recv() + checks.NoError(t, streamErr, "stream.Recv() failed") + if !compareChatResponses(expectedResponse, receivedResponse) { + t.Errorf("Stream response %v is %v, expected %v", ix, receivedResponse, expectedResponse) + } + } + + _, streamErr := stream.Recv() + if !errors.Is(streamErr, io.EOF) { + t.Errorf("stream.Recv() did not return EOF in the end: %v", streamErr) + } +} + +func testStreamExtraBodyWithParameters(t *testing.T) { + t.Helper() + client, server, teardown := setupOpenAITestServer() + defer teardown() + + expectedExtraFields := map[string]any{ + "custom_parameter": "test_value", + "additional_config": true, + "numeric_setting": float64(123), // JSON unmarshaling converts numbers to float64 + "temperature": float64(0.7), + } + + server.RegisterHandler("/v1/chat/completions", createStreamHandler(t, expectedExtraFields)) + + req := createBaseChatStreamRequest() + req.ExtraBody = map[string]any{ + "custom_parameter": "test_value", + "additional_config": true, + "numeric_setting": 123, + "temperature": 0.7, + } + + stream, err := client.CreateChatCompletionStream(context.Background(), req) + checks.NoError(t, err, "CreateChatCompletionStream with ExtraBody should not fail") + + expectedResponses := []openai.ChatCompletionStreamResponse{ + { + ID: "test-1", + Object: "chat.completion.chunk", + Created: 1598069254, + Model: "gpt-4", + SystemFingerprint: "fp_test", + Choices: []openai.ChatCompletionStreamChoice{ + { + Index: 0, + Delta: openai.ChatCompletionStreamChoiceDelta{ + Content: "Hello", + }, + }, + }, + }, + { + ID: "test-2", + Object: "chat.completion.chunk", + Created: 1598069255, + Model: "gpt-4", + SystemFingerprint: "fp_test", + Choices: []openai.ChatCompletionStreamChoice{ + { + Index: 0, + Delta: openai.ChatCompletionStreamChoiceDelta{}, + FinishReason: "stop", + }, + }, + }, + } + + verifyStreamResponse(t, stream, expectedResponses) +} + +func testStreamExtraBodyComplexData(t *testing.T) { + t.Helper() + client, server, teardown := setupOpenAITestServer() + defer teardown() + + expectedExtraFields := map[string]any{ + "array_param": []interface{}{"item1", "item2"}, + "unicode_text": "你好世界", + "special_chars": "!@#$%^&*()", + "nested_config": map[string]interface{}{"enabled": true, "level": float64(5)}, + "mixed_array": []interface{}{"string", float64(42), true, nil}, + "float_param": 3.14159, + "negative_int": float64(-42), + "zero_value": float64(0), + } + + server.RegisterHandler("/v1/chat/completions", createStreamHandler(t, expectedExtraFields)) + + req := createBaseChatStreamRequest() + req.ExtraBody = map[string]any{ + "array_param": []string{"item1", "item2"}, + "unicode_text": "你好世界", + "special_chars": "!@#$%^&*()", + "nested_config": map[string]any{ + "enabled": true, + "level": 5, + }, + "mixed_array": []any{"string", 42, true, nil}, + "float_param": 3.14159, + "negative_int": -42, + "zero_value": 0, + } + + stream, err := client.CreateChatCompletionStream(context.Background(), req) + checks.NoError(t, err, "CreateChatCompletionStream with complex ExtraBody should not fail") + + expectedResponses := []openai.ChatCompletionStreamResponse{ + { + ID: "test-1", + Object: "chat.completion.chunk", + Created: 1598069254, + Model: "gpt-4", + SystemFingerprint: "fp_test", + Choices: []openai.ChatCompletionStreamChoice{ + { + Index: 0, + Delta: openai.ChatCompletionStreamChoiceDelta{ + Content: "Hello", + }, + }, + }, + }, + { + ID: "test-2", + Object: "chat.completion.chunk", + Created: 1598069255, + Model: "gpt-4", + SystemFingerprint: "fp_test", + Choices: []openai.ChatCompletionStreamChoice{ + { + Index: 0, + Delta: openai.ChatCompletionStreamChoiceDelta{}, + FinishReason: "stop", + }, + }, + }, + } + + verifyStreamResponse(t, stream, expectedResponses) +} + +func testStreamExtraBodyEmpty(t *testing.T) { + t.Helper() + client, server, teardown := setupOpenAITestServer() + defer teardown() + + server.RegisterHandler("/v1/chat/completions", createStreamHandler(t, nil)) + + req := createBaseChatStreamRequest() + req.ExtraBody = map[string]any{} + + stream, err := client.CreateChatCompletionStream(context.Background(), req) + checks.NoError(t, err, "CreateChatCompletionStream with empty ExtraBody should not fail") + + expectedResponses := []openai.ChatCompletionStreamResponse{ + { + ID: "test-1", + Object: "chat.completion.chunk", + Created: 1598069254, + Model: "gpt-4", + SystemFingerprint: "fp_test", + Choices: []openai.ChatCompletionStreamChoice{ + { + Index: 0, + Delta: openai.ChatCompletionStreamChoiceDelta{ + Content: "Hello", + }, + }, + }, + }, + { + ID: "test-2", + Object: "chat.completion.chunk", + Created: 1598069255, + Model: "gpt-4", + SystemFingerprint: "fp_test", + Choices: []openai.ChatCompletionStreamChoice{ + { + Index: 0, + Delta: openai.ChatCompletionStreamChoiceDelta{}, + FinishReason: "stop", + }, + }, + }, + } + + verifyStreamResponse(t, stream, expectedResponses) +} + +func testStreamExtraBodyNil(t *testing.T) { + t.Helper() + client, server, teardown := setupOpenAITestServer() + defer teardown() + + server.RegisterHandler("/v1/chat/completions", createStreamHandler(t, nil)) + + req := createBaseChatStreamRequest() + req.ExtraBody = nil + + stream, err := client.CreateChatCompletionStream(context.Background(), req) + checks.NoError(t, err, "CreateChatCompletionStream with nil ExtraBody should not fail") + + expectedResponses := []openai.ChatCompletionStreamResponse{ + { + ID: "test-1", + Object: "chat.completion.chunk", + Created: 1598069254, + Model: "gpt-4", + SystemFingerprint: "fp_test", + Choices: []openai.ChatCompletionStreamChoice{ + { + Index: 0, + Delta: openai.ChatCompletionStreamChoiceDelta{ + Content: "Hello", + }, + }, + }, + }, + { + ID: "test-2", + Object: "chat.completion.chunk", + Created: 1598069255, + Model: "gpt-4", + SystemFingerprint: "fp_test", + Choices: []openai.ChatCompletionStreamChoice{ + { + Index: 0, + Delta: openai.ChatCompletionStreamChoiceDelta{}, + FinishReason: "stop", + }, + }, + }, + } + + verifyStreamResponse(t, stream, expectedResponses) +} + +func testStreamExtraBodyFieldConflicts(t *testing.T) { + t.Helper() + client, server, teardown := setupOpenAITestServer() + defer teardown() + + // Handler that verifies ExtraBody fields override standard fields + server.RegisterHandler("/v1/chat/completions", func(w http.ResponseWriter, r *http.Request) { + var reqBody map[string]any + body, err := io.ReadAll(r.Body) + if err != nil { + http.Error(w, "Failed to read request body", http.StatusInternalServerError) + return + } + + err = json.Unmarshal(body, &reqBody) + if err != nil { + http.Error(w, "Failed to parse request body", http.StatusInternalServerError) + return + } + + // Verify ExtraBody fields override standard fields + if reqBody["model"] != "overridden-model" { + msg := fmt.Sprintf("Model field should be overridden to 'overridden-model', got %v", + reqBody["model"]) + http.Error(w, msg, http.StatusBadRequest) + return + } + + if reqBody["stream"] != true { + http.Error(w, fmt.Sprintf("Stream field should remain true, got %v", reqBody["stream"]), http.StatusBadRequest) + return + } + + maxTokens, ok := reqBody["max_tokens"].(float64) + if !ok || int(maxTokens) != 9999 { + msg := fmt.Sprintf("MaxTokens field should be overridden to 9999, got %v", + reqBody["max_tokens"]) + http.Error(w, msg, http.StatusBadRequest) + return + } + + // Verify custom field from ExtraBody is present at top level + if reqBody["custom_field"] != "custom_value" { + msg := fmt.Sprintf("Custom field from ExtraBody should be 'custom_value', got %v", + reqBody["custom_field"]) + http.Error(w, msg, http.StatusBadRequest) + return + } + + // Send streaming response using the overridden model name + w.Header().Set("Content-Type", "text/event-stream") + data := `{"id":"test-1","object":"chat.completion.chunk","created":1598069254,` + + `"model":"overridden-model","system_fingerprint":"fp_test",` + + `"choices":[{"index":0,"delta":{"content":"Response"},"finish_reason":"stop"}]}` + _, writeErr := w.Write([]byte("data: " + data + "\n\ndata: [DONE]\n\n")) + if writeErr != nil { + t.Errorf("Failed to write response: %v", writeErr) + } + }) + + req := createBaseChatStreamRequest() + req.MaxTokens = 100 + req.ExtraBody = map[string]any{ + "model": "overridden-model", // this should override the standard model field + "max_tokens": 9999, // this should override the standard max_tokens field + "custom_field": "custom_value", // this is a new field + } + + stream, err := client.CreateChatCompletionStream(context.Background(), req) + checks.NoError(t, err, "CreateChatCompletionStream with field overrides should not fail") + + expectedResponses := []openai.ChatCompletionStreamResponse{ + { + ID: "test-1", + Object: "chat.completion.chunk", + Created: 1598069254, + Model: "overridden-model", + SystemFingerprint: "fp_test", + Choices: []openai.ChatCompletionStreamChoice{ + { + Index: 0, + Delta: openai.ChatCompletionStreamChoiceDelta{ + Content: "Response", + }, + FinishReason: "stop", + }, + }, + }, + } + + verifyStreamResponse(t, stream, expectedResponses) +} + +func TestCreateChatCompletionStreamExtraBody(t *testing.T) { + t.Run("WithParameters", testStreamExtraBodyWithParameters) + t.Run("ComplexData", testStreamExtraBodyComplexData) + t.Run("EmptyExtraBody", testStreamExtraBodyEmpty) + t.Run("NilExtraBody", testStreamExtraBodyNil) + t.Run("FieldConflicts", testStreamExtraBodyFieldConflicts) +} diff --git a/chat_test.go b/chat_test.go index 514706c96..28edcbdbe 100644 --- a/chat_test.go +++ b/chat_test.go @@ -916,6 +916,429 @@ func getChatCompletionBody(r *http.Request) (openai.ChatCompletionRequest, error return completion, nil } +// Helper functions for TestChatCompletionRequestExtraBody to reduce complexity and improve maintainability + +func createBaseChatRequest() openai.ChatCompletionRequest { + return openai.ChatCompletionRequest{ + Model: "gpt-4", + Messages: []openai.ChatCompletionMessage{ + { + Role: openai.ChatMessageRoleUser, + Content: "Hello!", + }, + }, + } +} + +func verifyJSONContainsFields(t *testing.T, jsonStr string, expectedFields map[string]string) { + t.Helper() + for _, expected := range expectedFields { + if !strings.Contains(jsonStr, expected) { + t.Errorf("Expected JSON to contain %s, got: %s", expected, jsonStr) + } + } +} + +func verifyExtraBodyExists(t *testing.T, extraBody map[string]any) { + t.Helper() + if extraBody == nil { + t.Fatal("ExtraBody should not be nil after unmarshaling") + } +} + +func verifyStringField(t *testing.T, extraBody map[string]any, fieldName, expected string) { + t.Helper() + value, exists := extraBody[fieldName] + if !exists { + t.Errorf("%s should exist in ExtraBody", fieldName) + return + } + if value != expected { + t.Errorf("Expected %s to be '%s', got %v", fieldName, expected, value) + } +} + +func verifyFloatField(t *testing.T, extraBody map[string]any, fieldName string, expected float64) { + t.Helper() + value, exists := extraBody[fieldName] + if !exists { + t.Errorf("%s should exist in ExtraBody", fieldName) + return + } + floatValue, ok := value.(float64) + if !ok { + t.Errorf("Expected %s to be float64, got type %T", fieldName, value) + return + } + if floatValue != expected { + t.Errorf("Expected %s to be %v, got %v", fieldName, expected, floatValue) + } +} + +func verifyIntField(t *testing.T, extraBody map[string]any, fieldName string, expected int) { + t.Helper() + value, exists := extraBody[fieldName] + if !exists { + t.Errorf("%s should exist in ExtraBody", fieldName) + return + } + floatValue, ok := value.(float64) + if !ok { + t.Errorf("Expected %s to be float64, got type %T", fieldName, value) + return + } + if int(floatValue) != expected { + t.Errorf("Expected %s to be %d, got %v", fieldName, expected, int(floatValue)) + } +} + +func verifyBoolField(t *testing.T, extraBody map[string]any, fieldName string, expected bool) { + t.Helper() + value, exists := extraBody[fieldName] + if !exists { + t.Errorf("%s should exist in ExtraBody", fieldName) + return + } + boolValue, ok := value.(bool) + if !ok { + t.Errorf("Expected %s to be bool, got type %T", fieldName, value) + return + } + if boolValue != expected { + t.Errorf("Expected %s to be %v, got %v", fieldName, expected, boolValue) + } +} + +func verifyArrayField(t *testing.T, extraBody map[string]any, fieldName string, expected []interface{}) { + t.Helper() + value, exists := extraBody[fieldName] + if !exists { + t.Errorf("%s should exist in ExtraBody", fieldName) + return + } + arrayValue, ok := value.([]interface{}) + if !ok { + t.Errorf("Expected %s to be []interface{}, got type %T", fieldName, value) + return + } + if len(arrayValue) != len(expected) { + t.Errorf("Expected %s to have %d elements, got %d", fieldName, len(expected), len(arrayValue)) + return + } + for i, expectedVal := range expected { + if arrayValue[i] != expectedVal { + t.Errorf("%s[%d]: expected %v, got %v", fieldName, i, expectedVal, arrayValue[i]) + } + } +} + +func verifyNestedObject(t *testing.T, extraBody map[string]any, fieldName, nestedKey, expectedValue string) { + t.Helper() + value, exists := extraBody[fieldName] + if !exists { + t.Errorf("%s should exist in ExtraBody", fieldName) + return + } + objectValue, ok := value.(map[string]interface{}) + if !ok { + t.Errorf("Expected %s to be map[string]interface{}, got type %T", fieldName, value) + return + } + nestedValue, nestedExists := objectValue[nestedKey] + if !nestedExists { + t.Errorf("%s should exist in %s", nestedKey, fieldName) + return + } + if nestedValue != expectedValue { + t.Errorf("Expected %s.%s to be '%s', got %v", fieldName, nestedKey, expectedValue, nestedValue) + } +} + +func verifyDeepNesting(t *testing.T, extraBody map[string]any) { + t.Helper() + deepNesting, ok := extraBody["deep_nesting"].(map[string]interface{}) + if !ok { + t.Error("deep_nesting should be map[string]interface{}") + return + } + level1, ok := deepNesting["level1"].(map[string]interface{}) + if !ok { + t.Error("level1 should be map[string]interface{}") + return + } + level2, ok := level1["level2"].(map[string]interface{}) + if !ok { + t.Error("level2 should be map[string]interface{}") + return + } + value, ok := level2["value"].(string) + if !ok { + t.Error("deep nested value should be string") + return + } + if value != "deep_value" { + t.Errorf("Expected deep nested value to be 'deep_value', got %v", value) + } +} + +func testExtraBodySerialization(t *testing.T) { + t.Helper() + req := createBaseChatRequest() + req.ExtraBody = map[string]any{ + "custom_param": "custom_value", + "numeric_param": 42, + "boolean_param": true, + "array_param": []string{"item1", "item2"}, + "object_param": map[string]any{ + "nested_key": "nested_value", + }, + } + + data, err := json.Marshal(req) + checks.NoError(t, err, "Failed to marshal request with ExtraBody") + + // Verify JSON serialization + expectedFields := map[string]string{ + "extra_body": `"extra_body"`, + "custom_param": `"custom_param":"custom_value"`, + "numeric_param": `"numeric_param":42`, + "boolean_param": `"boolean_param":true`, + } + verifyJSONContainsFields(t, string(data), expectedFields) + + // Verify deserialization + var unmarshaled openai.ChatCompletionRequest + err = json.Unmarshal(data, &unmarshaled) + checks.NoError(t, err, "Failed to unmarshal request with ExtraBody") + + verifyExtraBodyExists(t, unmarshaled.ExtraBody) + verifyStringField(t, unmarshaled.ExtraBody, "custom_param", "custom_value") + verifyIntField(t, unmarshaled.ExtraBody, "numeric_param", 42) + verifyBoolField(t, unmarshaled.ExtraBody, "boolean_param", true) + verifyArrayField(t, unmarshaled.ExtraBody, "array_param", []interface{}{"item1", "item2"}) + verifyNestedObject(t, unmarshaled.ExtraBody, "object_param", "nested_key", "nested_value") +} + +func testEmptyExtraBody(t *testing.T) { + t.Helper() + req := createBaseChatRequest() + req.ExtraBody = map[string]any{} + + data, err := json.Marshal(req) + checks.NoError(t, err, "Failed to marshal request with empty ExtraBody") + + if strings.Contains(string(data), `"extra_body"`) { + t.Error("Empty ExtraBody should be omitted from JSON") + } + + var unmarshaled openai.ChatCompletionRequest + err = json.Unmarshal(data, &unmarshaled) + checks.NoError(t, err, "Failed to unmarshal request with empty ExtraBody") + + if unmarshaled.ExtraBody != nil { + t.Error("ExtraBody should be nil when empty ExtraBody is omitted from JSON") + } +} + +func testNilExtraBody(t *testing.T) { + t.Helper() + req := createBaseChatRequest() + req.ExtraBody = nil + + data, err := json.Marshal(req) + checks.NoError(t, err, "Failed to marshal request with nil ExtraBody") + + if strings.Contains(string(data), `"extra_body"`) { + t.Error("Nil ExtraBody should be omitted from JSON") + } + + var unmarshaled openai.ChatCompletionRequest + err = json.Unmarshal(data, &unmarshaled) + checks.NoError(t, err, "Failed to unmarshal request with nil ExtraBody") + + if unmarshaled.ExtraBody != nil { + t.Error("ExtraBody should remain nil when not present in JSON") + } +} + +func testComplexDataTypes(t *testing.T) { + t.Helper() + req := createBaseChatRequest() + req.ExtraBody = map[string]any{ + "float_param": 3.14159, + "negative_int": -42, + "zero_value": 0, + "empty_string": "", + "unicode_text": "你好世界", + "special_chars": "!@#$%^&*()", + "nested_arrays": []any{[]string{"a", "b"}, []int{1, 2, 3}}, + "mixed_array": []any{"string", 42, true, nil}, + "deep_nesting": map[string]any{ + "level1": map[string]any{ + "level2": map[string]any{ + "value": "deep_value", + }, + }, + }, + } + + data, err := json.Marshal(req) + checks.NoError(t, err, "Failed to marshal request with complex ExtraBody") + + var unmarshaled openai.ChatCompletionRequest + err = json.Unmarshal(data, &unmarshaled) + checks.NoError(t, err, "Failed to unmarshal request with complex ExtraBody") + + verifyExtraBodyExists(t, unmarshaled.ExtraBody) + verifyFloatField(t, unmarshaled.ExtraBody, "float_param", 3.14159) + verifyIntField(t, unmarshaled.ExtraBody, "negative_int", -42) + verifyStringField(t, unmarshaled.ExtraBody, "unicode_text", "你好世界") + verifyArrayField(t, unmarshaled.ExtraBody, "mixed_array", []interface{}{"string", float64(42), true, nil}) + verifyDeepNesting(t, unmarshaled.ExtraBody) +} + +func testInvalidJSONHandling(t *testing.T) { + t.Helper() + invalidJSON := `{"model":"gpt-4","extra_body":{"invalid_json":}}` + var req openai.ChatCompletionRequest + err := json.Unmarshal([]byte(invalidJSON), &req) + if err == nil { + t.Error("Expected error when unmarshaling invalid JSON, but got nil") + } +} + +func testExtraBodyFieldConflicts(t *testing.T) { + t.Helper() + req := createBaseChatRequest() + req.MaxTokens = 100 + req.ExtraBody = map[string]any{ + "model": "should-not-override", + "max_tokens": 9999, + "custom_field": "custom_value", + } + + data, err := json.Marshal(req) + checks.NoError(t, err, "Failed to marshal request with field conflicts in ExtraBody") + + var jsonMap map[string]any + err = json.Unmarshal(data, &jsonMap) + checks.NoError(t, err, "Failed to unmarshal JSON to generic map") + + if jsonMap["model"] != "gpt-4" { + t.Errorf("Standard model field should be 'gpt-4', got %v", jsonMap["model"]) + } + + maxTokens, ok := jsonMap["max_tokens"].(float64) + if !ok || int(maxTokens) != 100 { + t.Errorf("Standard max_tokens field should be 100, got %v", jsonMap["max_tokens"]) + } + + extraBody, ok := jsonMap["extra_body"].(map[string]interface{}) + if !ok { + t.Error("ExtraBody should be present in JSON") + return + } + customField, ok := extraBody["custom_field"].(string) + if !ok || customField != "custom_value" { + t.Errorf("Expected custom_field to be 'custom_value', got %v", customField) + } +} + +func TestChatCompletionRequestExtraBody(t *testing.T) { + t.Run("ExtraBodySerialization", testExtraBodySerialization) + t.Run("EmptyExtraBody", testEmptyExtraBody) + t.Run("NilExtraBody", testNilExtraBody) + t.Run("ComplexDataTypes", testComplexDataTypes) + t.Run("InvalidJSONHandling", testInvalidJSONHandling) + t.Run("ExtraBodyFieldConflicts", testExtraBodyFieldConflicts) +} + +func TestChatCompletionWithExtraBody(t *testing.T) { + client, server, teardown := setupOpenAITestServer() + defer teardown() + + // Set up a handler that verifies ExtraBody fields are merged into the request body + server.RegisterHandler("/v1/chat/completions", func(w http.ResponseWriter, r *http.Request) { + var reqBody map[string]any + body, err := io.ReadAll(r.Body) + if err != nil { + http.Error(w, "Failed to read request body", http.StatusInternalServerError) + return + } + + err = json.Unmarshal(body, &reqBody) + if err != nil { + http.Error(w, "Failed to parse request body", http.StatusInternalServerError) + return + } + + // Verify that ExtraBody fields are merged at the top level + if reqBody["custom_parameter"] != "test_value" { + http.Error(w, "ExtraBody custom_parameter not found in request", http.StatusBadRequest) + return + } + if reqBody["additional_config"] != true { + http.Error(w, "ExtraBody additional_config not found in request", http.StatusBadRequest) + return + } + + // Verify standard fields are still present + if reqBody["model"] != "gpt-4" { + http.Error(w, "Standard model field not found", http.StatusBadRequest) + return + } + + // Return a mock response + res := openai.ChatCompletionResponse{ + ID: "test-id", + Object: "chat.completion", + Created: time.Now().Unix(), + Model: "gpt-4", + Choices: []openai.ChatCompletionChoice{ + { + Index: 0, + Message: openai.ChatCompletionMessage{ + Role: openai.ChatMessageRoleAssistant, + Content: "Hello! I received your message with extra parameters.", + }, + FinishReason: openai.FinishReasonStop, + }, + }, + Usage: openai.Usage{ + PromptTokens: 10, + CompletionTokens: 20, + TotalTokens: 30, + }, + } + + w.Header().Set("Content-Type", "application/json") + err = json.NewEncoder(w).Encode(res) + if err != nil { + http.Error(w, "Failed to encode response", http.StatusInternalServerError) + return + } + }) + + // Test ChatCompletion with ExtraBody + _, err := client.CreateChatCompletion(context.Background(), openai.ChatCompletionRequest{ + Model: "gpt-4", + Messages: []openai.ChatCompletionMessage{ + { + Role: openai.ChatMessageRoleUser, + Content: "Hello!", + }, + }, + ExtraBody: map[string]any{ + "custom_parameter": "test_value", + "additional_config": true, + "numeric_setting": 123, + "array_setting": []string{"option1", "option2"}, + }, + }) + + checks.NoError(t, err, "CreateChatCompletion with ExtraBody should not fail") +} + func TestFinishReason(t *testing.T) { c := &openai.ChatCompletionChoice{ FinishReason: openai.FinishReasonNull, diff --git a/client.go b/client.go index cef375348..edb66ed00 100644 --- a/client.go +++ b/client.go @@ -84,6 +84,48 @@ func withBody(body any) requestOption { } } +func withExtraBody(extraBody map[string]any) requestOption { + return func(args *requestOptions) { + if len(extraBody) == 0 { + return // No extra body to merge + } + + // Check if args.body is already a map[string]any + if bodyMap, ok := args.body.(map[string]any); ok { + // If it's already a map[string]any, directly add extraBody fields + for key, value := range extraBody { + bodyMap[key] = value + } + return + } + + // If args.body is a struct, convert it to map[string]any first + if args.body != nil { + var err error + var jsonBytes []byte + // Marshal the struct to JSON bytes + jsonBytes, err = json.Marshal(args.body) + if err != nil { + return // If marshaling fails, skip merging ExtraBody + } + + // Unmarshal JSON bytes to map[string]any + var bodyMap map[string]any + if err = json.Unmarshal(jsonBytes, &bodyMap); err != nil { + return // If unmarshaling fails, skip merging ExtraBody + } + + // Merge ExtraBody fields into the map + for key, value := range extraBody { + bodyMap[key] = value + } + + // Replace args.body with the merged map + args.body = bodyMap + } + } +} + func withContentType(contentType string) requestOption { return func(args *requestOptions) { args.header.Set("Content-Type", contentType)