Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion go/core/action.go
Original file line number Diff line number Diff line change
Expand Up @@ -260,7 +260,7 @@ func (a *ActionDef[In, Out, Stream]) RunJSON(ctx context.Context, input json.Raw
return r.Result, nil
}

// RunJSON runs the action with a JSON input, and returns a JSON result along with telemetry info.
// RunJSONWithTelemetry runs the action with a JSON input, and returns a JSON result along with telemetry info.
func (a *ActionDef[In, Out, Stream]) RunJSONWithTelemetry(ctx context.Context, input json.RawMessage, cb StreamCallback[json.RawMessage]) (*api.ActionRunResult[json.RawMessage], error) {
i, err := base.UnmarshalAndNormalize[In](input, a.desc.InputSchema)
if err != nil {
Expand Down
2 changes: 1 addition & 1 deletion go/core/flow.go
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,7 @@ func (f *Flow[In, Out, Stream]) RunJSON(ctx context.Context, input json.RawMessa
return (*ActionDef[In, Out, Stream])(f).RunJSON(ctx, input, cb)
}

// RunJSON runs the flow with JSON input and streaming callback and returns the output as JSON.
// RunJSONWithTelemetry runs the flow with JSON input and streaming callback and returns the output as JSON along with telemetry info.
func (f *Flow[In, Out, Stream]) RunJSONWithTelemetry(ctx context.Context, input json.RawMessage, cb StreamCallback[json.RawMessage]) (*api.ActionRunResult[json.RawMessage], error) {
return (*ActionDef[In, Out, Stream])(f).RunJSONWithTelemetry(ctx, input, cb)
}
Expand Down
4 changes: 4 additions & 0 deletions go/core/schemas.config
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,10 @@ SpanStatus omit
TimeEvent omit
TimeEventAnnotation omit
TraceData omit
SpanStartEvent omit
SpanEndEvent omit
SpanEventBase omit
TraceEvent omit

GenerationCommonConfig.maxOutputTokens type int
GenerationCommonConfig.topK type int
Expand Down
22 changes: 22 additions & 0 deletions go/core/tracing/tracing.go
Original file line number Diff line number Diff line change
Expand Up @@ -173,6 +173,8 @@ type SpanMetadata struct {

// RunInNewSpan runs f on input in a new span with the provided metadata.
// The metadata contains all span configuration including name, type, labels, etc.
// If a telemetry callback was set on the context via WithTelemetryCallback,
// it will be called with the trace ID and span ID as soon as the span is created.
func RunInNewSpan[I, O any](
ctx context.Context,
metadata *SpanMetadata,
Expand Down Expand Up @@ -239,6 +241,12 @@ func RunInNewSpan[I, O any](
TraceID: span.SpanContext().TraceID().String(),
SpanID: span.SpanContext().SpanID().String(),
}

// Fire telemetry callback immediately if one was set on the context
if cb := telemetryCallback(ctx); cb != nil {
cb(sm.TraceInfo.TraceID, sm.TraceInfo.SpanID)
}

defer span.End()
defer func() { span.SetAttributes(sm.attributes()...) }()
ctx = spanMetaKey.NewContext(ctx, sm)
Expand Down Expand Up @@ -371,6 +379,20 @@ func (sm *spanMetadata) attributes() []attribute.KeyValue {
// spanMetaKey is for storing spanMetadatas in a context.
var spanMetaKey = base.NewContextKey[*spanMetadata]()

// telemetryCbKey is the context key for telemetry callbacks.
var telemetryCbKey = base.NewContextKey[func(traceID, spanID string)]()

// WithTelemetryCallback returns a context with the telemetry callback attached.
// Used by the reflection server to pass callbacks to actions.
func WithTelemetryCallback(ctx context.Context, cb func(traceID, spanID string)) context.Context {
return telemetryCbKey.NewContext(ctx, cb)
}

// telemetryCallback retrieves the telemetry callback from context, or nil if not set.
func telemetryCallback(ctx context.Context) func(traceID, spanID string) {
return telemetryCbKey.FromContext(ctx)
}

// SpanPath returns the path as recorded in the current span metadata.
func SpanPath(ctx context.Context) string {
return spanMetaKey.FromContext(ctx).Path
Expand Down
216 changes: 195 additions & 21 deletions go/genkit/reflection.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ package genkit
import (
"context"
"encoding/json"
"errors"
"fmt"
"log/slog"
"net"
Expand All @@ -28,6 +29,7 @@ import (
"sort"
"strconv"
"strings"
"sync"
"time"

"github.com/firebase/genkit/go/core"
Expand All @@ -52,7 +54,46 @@ type runtimeFileData struct {
// reflectionServer encapsulates everything needed to serve the Reflection API.
type reflectionServer struct {
*http.Server
RuntimeFilePath string // Path to the runtime file that was written at startup.
RuntimeFilePath string // Path to the runtime file that was written at startup.
activeActions *activeActionsMap // Tracks active actions for cancellation support.
}

// activeAction represents an in-flight action that can be cancelled.
type activeAction struct {
cancel context.CancelFunc
startTime time.Time
traceID string
}

// activeActionsMap safely manages active actions.
type activeActionsMap struct {
mu sync.RWMutex
actions map[string]*activeAction
}

func newActiveActionsMap() *activeActionsMap {
return &activeActionsMap{
actions: make(map[string]*activeAction),
}
}

func (m *activeActionsMap) Set(traceID string, action *activeAction) {
m.mu.Lock()
defer m.mu.Unlock()
m.actions[traceID] = action
}

func (m *activeActionsMap) Get(traceID string) (*activeAction, bool) {
m.mu.RLock()
defer m.mu.RUnlock()
action, ok := m.actions[traceID]
return action, ok
}

func (m *activeActionsMap) Delete(traceID string) {
m.mu.Lock()
defer m.mu.Unlock()
delete(m.actions, traceID)
}

func (s *reflectionServer) runtimeID() string {
Expand Down Expand Up @@ -102,6 +143,7 @@ func startReflectionServer(ctx context.Context, g *Genkit, errCh chan<- error, s
Server: &http.Server{
Addr: addr,
},
activeActions: newActiveActionsMap(),
}
s.Handler = serveMux(g, s)

Expand Down Expand Up @@ -258,8 +300,9 @@ func serveMux(g *Genkit, s *reflectionServer) *http.ServeMux {
w.WriteHeader(http.StatusOK)
})
mux.HandleFunc("GET /api/actions", wrapReflectionHandler(handleListActions(g)))
mux.HandleFunc("POST /api/runAction", wrapReflectionHandler(handleRunAction(g)))
mux.HandleFunc("POST /api/runAction", wrapReflectionHandler(handleRunAction(g, s.activeActions)))
mux.HandleFunc("POST /api/notify", wrapReflectionHandler(handleNotify()))
mux.HandleFunc("POST /api/cancelAction", wrapReflectionHandler(handleCancelAction(s.activeActions)))
return mux
}

Expand Down Expand Up @@ -290,7 +333,7 @@ func wrapReflectionHandler(h func(w http.ResponseWriter, r *http.Request) error)

// handleRunAction looks up an action by name in the registry, runs it with the
// provided JSON input, and writes back the JSON-marshaled request.
func handleRunAction(g *Genkit) func(w http.ResponseWriter, r *http.Request) error {
func handleRunAction(g *Genkit, activeActions *activeActionsMap) func(w http.ResponseWriter, r *http.Request) error {
return func(w http.ResponseWriter, r *http.Request) error {
ctx := r.Context()

Expand All @@ -312,11 +355,54 @@ func handleRunAction(g *Genkit) func(w http.ResponseWriter, r *http.Request) err

logger.FromContext(ctx).Debug("running action", "key", body.Key, "stream", stream)

// Create cancellable context for this action
actionCtx, cancel := context.WithCancel(ctx)
defer cancel()

// Track whether headers have been sent
headersSent := false
var callbackTraceID string // Trace ID captured from telemetry callback for early header sending
var mu sync.Mutex

// Set up telemetry callback to capture and send trace ID early
// This is used for BOTH streaming and non-streaming to match JS behavior
telemetryCb := func(tid string, sid string) {
mu.Lock()
defer mu.Unlock()

if !headersSent {
callbackTraceID = tid

// Track active action for cancellation
activeActions.Set(callbackTraceID, &activeAction{
cancel: cancel,
startTime: time.Now(),
traceID: callbackTraceID,
})

// Send headers immediately with trace ID
w.Header().Set("X-Genkit-Trace-Id", callbackTraceID)
w.Header().Set("X-Genkit-Span-Id", sid)
w.Header().Set("X-Genkit-Version", "go/"+internal.Version)

if stream {
w.Header().Set("Content-Type", "text/plain")
w.Header().Set("Transfer-Encoding", "chunked")
} else {
w.Header().Set("Content-Type", "application/json")
}

w.WriteHeader(http.StatusOK)
if f, ok := w.(http.Flusher); ok {
f.Flush()
}
headersSent = true
}
}

// Set up streaming callback if needed
var cb streamingCallback[json.RawMessage]
if stream {
w.Header().Set("Content-Type", "text/plain")
w.Header().Set("Transfer-Encoding", "chunked")
// Stream results are newline-separated JSON.
cb = func(ctx context.Context, msg json.RawMessage) error {
_, err := fmt.Fprintf(w, "%s\n", msg)
if err != nil {
Expand All @@ -334,35 +420,119 @@ func handleRunAction(g *Genkit) func(w http.ResponseWriter, r *http.Request) err
json.Unmarshal(body.Context, &contextMap)
}

resp, err := runAction(ctx, g, body.Key, body.Input, body.TelemetryLabels, cb, contextMap)
// Attach telemetry callback to context so action can invoke it when span is created
actionCtx = tracing.WithTelemetryCallback(actionCtx, telemetryCb)
resp, err := runAction(actionCtx, g, body.Key, body.Input, body.TelemetryLabels, cb, contextMap)

// Clean up active action using the trace ID from response
if resp != nil && resp.Telemetry.TraceID != "" {
activeActions.Delete(resp.Telemetry.TraceID)
}

if err != nil {
if stream {
refErr := core.ToReflectionError(err)
refErr.Details.TraceID = &resp.Telemetry.TraceID
reflectErr, err := json.Marshal(refErr)
if err != nil {
return err
// Check if context was cancelled
if errors.Is(err, context.Canceled) {
// Use gRPC CANCELLED code (1) in JSON body to match TypeScript behavior
var traceIDPtr *string
if resp != nil && resp.Telemetry.TraceID != "" {
traceIDPtr = &resp.Telemetry.TraceID
}
errResp := errorResponse{
Error: core.ReflectionError{
Code: core.CodeCancelled, // gRPC CANCELLED = 1
Message: "Action was cancelled",
Details: &core.ReflectionErrorDetails{
TraceID: traceIDPtr,
},
},
}

_, err = fmt.Fprintf(w, "{\"error\": %s }", reflectErr)
if err != nil {
return err
if stream {
// For streaming, write error as final chunk
json.NewEncoder(w).Encode(errResp)
} else {
// For non-streaming, return error response
if !headersSent {
w.WriteHeader(http.StatusOK) // Match TS: response.status(200).json(...)
}
json.NewEncoder(w).Encode(errResp)
}
return nil
}

if f, ok := w.(http.Flusher); ok {
f.Flush()
// Handle other errors
if stream {
refErr := core.ToReflectionError(err)
if resp != nil && resp.Telemetry.TraceID != "" {
refErr.Details.TraceID = &resp.Telemetry.TraceID
}

json.NewEncoder(w).Encode(errorResponse{Error: refErr})
return nil
}

// Non-streaming error
errorResponse := core.ToReflectionError(err)
if resp != nil {
if resp != nil && resp.Telemetry.TraceID != "" {
errorResponse.Details.TraceID = &resp.Telemetry.TraceID
}
w.WriteHeader(errorResponse.Code)

if !headersSent {
w.WriteHeader(errorResponse.Code)
}
return writeJSON(ctx, w, errorResponse)
}

return writeJSON(ctx, w, resp)
// Success case
if stream {
// For streaming, write the final chunk with result and telemetry
// This matches JS: response.write(JSON.stringify({result, telemetry}))
finalResponse := runActionResponse{
Result: resp.Result,
Telemetry: telemetry{TraceID: resp.Telemetry.TraceID},
}
json.NewEncoder(w).Encode(finalResponse)
} else {
// For non-streaming, headers were already sent via telemetry callback
// Response already includes telemetry.traceId in body
return writeJSON(ctx, w, resp)
}

return nil
}
}

// handleCancelAction cancels an in-flight action by trace ID.
func handleCancelAction(activeActions *activeActionsMap) func(w http.ResponseWriter, r *http.Request) error {
return func(w http.ResponseWriter, r *http.Request) error {
var body struct {
TraceID string `json:"traceId"`
}

defer r.Body.Close()
if err := json.NewDecoder(r.Body).Decode(&body); err != nil {
return core.NewError(core.INVALID_ARGUMENT, err.Error())
}

if body.TraceID == "" {
return core.NewError(core.INVALID_ARGUMENT, "traceId is required")
}

action, exists := activeActions.Get(body.TraceID)
if !exists {
w.WriteHeader(http.StatusNotFound)
return writeJSON(r.Context(), w, map[string]string{
"error": "Action not found or already completed",
})
}

// Cancel the action's context
action.cancel()
activeActions.Delete(body.TraceID)

return writeJSON(r.Context(), w, map[string]string{
"message": "Action cancelled",
})
}
}

Expand Down Expand Up @@ -462,6 +632,10 @@ type telemetry struct {
TraceID string `json:"traceId"`
}

type errorResponse struct {
Error core.ReflectionError `json:"error"`
}

func runAction(ctx context.Context, g *Genkit, key string, input json.RawMessage, telemetryLabels json.RawMessage, cb streamingCallback[json.RawMessage], runtimeContext map[string]any) (*runActionResponse, error) {
action := g.reg.ResolveAction(key)
if action == nil {
Expand Down
Loading
Loading