agent_patch
runs/2026-02-28__13-38-10__gpt-5-1-codex-mini/flux-pr-1076/flux-pr-1076.1-of-1.2026-02-28__13-38-10__gpt-5-1-codex-mini/agent-logs/agent.patch
25981 bytes
diff --git a/tmp/agent-patch-flux-pr-1076.1-of-1.2026-02-28__13-38-10__gpt-5-1-codex-mini/app/v2/pkg/engine/resolve/resolve.go b/app/v2/pkg/engine/resolve/resolve.go
index 83a5a06..6278eee 100644
--- a/tmp/agent-patch-flux-pr-1076.1-of-1.2026-02-28__13-38-10__gpt-5-1-codex-mini/app/v2/pkg/engine/resolve/resolve.go
+++ b/app/v2/pkg/engine/resolve/resolve.go
@@ -7,11 +7,8 @@ import (
"context"
"fmt"
"io"
- "sync"
"time"
- "golang.org/x/sync/semaphore"
-
"github.com/buger/jsonparser"
"github.com/pkg/errors"
"go.uber.org/atomic"
@@ -50,13 +47,9 @@ type Resolver struct {
options ResolverOptions
maxConcurrency chan struct{}
- triggers map[uint64]*trigger
- heartbeatSubLock *sync.Mutex
- heartbeatSubscriptions map[*Context]*sub
- events chan subscriptionEvent
- triggerEventsSem *semaphore.Weighted
- triggerUpdatesSem *semaphore.Weighted
- triggerUpdateBuf *bytes.Buffer
+ triggers map[uint64]*trigger
+ events chan subscriptionEvent
+ triggerUpdateBuf *bytes.Buffer
allowedErrorExtensionFields map[string]struct{}
allowedErrorFields map[string]struct{}
@@ -66,9 +59,12 @@ type Resolver struct {
reporter Reporter
asyncErrorWriter AsyncErrorWriter
- propagateSubgraphErrors bool
- propagateSubgraphStatusCodes bool
- multipartSubHeartbeatInterval time.Duration
+ propagateSubgraphErrors bool
+ propagateSubgraphStatusCodes bool
+ // Multipart heartbeat interval
+ heartbeatInterval time.Duration
+ // maxSubscriptionFetchTimeout defines the maximum time a subscription fetch can take before it is considered timed out
+ maxSubscriptionFetchTimeout time.Duration
}
func (r *Resolver) SetAsyncErrorWriter(w AsyncErrorWriter) {
@@ -103,13 +99,6 @@ type ResolverOptions struct {
// This depends on the number of CPU cores available, the complexity of the operations, and the origin services
MaxConcurrency int
- // MaxSubscriptionWorkers limits the concurrency on how many subscription can be added / removed concurrently.
- // This does not include subscription updates, for that we have a separate semaphore MaxSubscriptionUpdates.
- MaxSubscriptionWorkers int
-
- // MaxSubscriptionUpdates limits the number of concurrent subscription updates that can be sent to the event loop.
- MaxSubscriptionUpdates int
-
Debug bool
Reporter Reporter
@@ -146,6 +135,8 @@ type ResolverOptions struct {
AllowedSubgraphErrorFields []string
// MultipartSubHeartbeatInterval defines the interval in which a heartbeat is sent to all multipart subscriptions
MultipartSubHeartbeatInterval time.Duration
+ // MaxSubscriptionFetchTimeout defines the maximum time a subscription fetch can take before it is considered timed out
+ MaxSubscriptionFetchTimeout time.Duration
}
// New returns a new Resolver, ctx.Done() is used to cancel all active subscriptions & streams
@@ -171,6 +162,10 @@ func New(ctx context.Context, options ResolverOptions) *Resolver {
"path": {},
}
+ if options.MaxSubscriptionFetchTimeout == 0 {
+ options.MaxSubscriptionFetchTimeout = 30 * time.Second
+ }
+
if !options.OmitSubgraphErrorExtensions {
allowedErrorFields["extensions"] = struct{}{}
}
@@ -184,36 +179,26 @@ func New(ctx context.Context, options ResolverOptions) *Resolver {
}
resolver := &Resolver{
- ctx: ctx,
- options: options,
- propagateSubgraphErrors: options.PropagateSubgraphErrors,
- propagateSubgraphStatusCodes: options.PropagateSubgraphStatusCodes,
- events: make(chan subscriptionEvent),
- triggers: make(map[uint64]*trigger),
- heartbeatSubLock: &sync.Mutex{},
- heartbeatSubscriptions: make(map[*Context]*sub),
- reporter: options.Reporter,
- asyncErrorWriter: options.AsyncErrorWriter,
- triggerUpdateBuf: bytes.NewBuffer(make([]byte, 0, 1024)),
- allowedErrorExtensionFields: allowedExtensionFields,
- allowedErrorFields: allowedErrorFields,
- multipartSubHeartbeatInterval: options.MultipartSubHeartbeatInterval,
+ ctx: ctx,
+ options: options,
+ propagateSubgraphErrors: options.PropagateSubgraphErrors,
+ propagateSubgraphStatusCodes: options.PropagateSubgraphStatusCodes,
+ events: make(chan subscriptionEvent),
+ triggers: make(map[uint64]*trigger),
+ reporter: options.Reporter,
+ asyncErrorWriter: options.AsyncErrorWriter,
+ triggerUpdateBuf: bytes.NewBuffer(make([]byte, 0, 1024)),
+ allowedErrorExtensionFields: allowedExtensionFields,
+ allowedErrorFields: allowedErrorFields,
+ heartbeatInterval: options.MultipartSubHeartbeatInterval,
+ maxSubscriptionFetchTimeout: options.MaxSubscriptionFetchTimeout,
}
resolver.maxConcurrency = make(chan struct{}, options.MaxConcurrency)
for i := 0; i < options.MaxConcurrency; i++ {
resolver.maxConcurrency <- struct{}{}
}
- if options.MaxSubscriptionWorkers == 0 {
- options.MaxSubscriptionWorkers = 1024
- }
- if options.MaxSubscriptionUpdates == 0 {
- options.MaxSubscriptionUpdates = 1024
- }
-
- resolver.triggerEventsSem = semaphore.NewWeighted(int64(options.MaxSubscriptionWorkers))
- resolver.triggerUpdatesSem = semaphore.NewWeighted(int64(options.MaxSubscriptionUpdates))
- go resolver.handleEvents()
+ go resolver.processEvents()
return resolver
}
@@ -285,31 +270,79 @@ type trigger struct {
}
+// sub holds the resolver's state for one subscriber of a trigger.
+// handleAddSubscription starts a dedicated worker goroutine for it (see startWorker).
type sub struct {
- mux sync.Mutex
resolve *GraphQLSubscription
+ // resolver is the owning Resolver; the worker reads its heartbeatInterval and calls handleHeartbeat on it.
+ resolver *Resolver
+ // ctx is the subscriber's *Context (add.ctx); handleHeartbeat skips writing once ctx.Context() has an error.
+ ctx *Context
writer SubscriptionResponseWriter
id SubscriptionIdentifier
- completed chan struct{}
+ // heartbeat is set from ExecutionOptions.SendHeartbeat and selects the worker loop that sends heartbeats.
+ heartbeat bool
+ // NOTE(review): lastWrite is no longer read or written in the code visible in this diff — confirm and remove.
lastWrite time.Time
+ // completed is closed by shutdownTriggerSubscriptions (only when non-nil); the worker loop returns once it is closed.
+ completed chan struct{}
// executor is an optional argument that allows us to "schedule" the execution of an update on another thread
// e.g. if we're using SSE/Multipart Fetch, we can run the execution on the goroutine of the http request
// this ensures that ctx cancellation works properly when a client disconnects
executor chan func()
+ // workChan is used to send work to the writer goroutine. All work is processed sequentially.
+ workChan chan func()
+}
+
+// startWorker runs in its own goroutine and executes the functions queued on
+// s.workChan one at a time, so work submitted through workChan never runs
+// concurrently with other work from workChan.
+// Heartbeats are only sent when the subscription requested them AND a positive
+// heartbeat interval is configured: time.NewTicker panics on a non-positive
+// duration, and an unrecovered panic on this goroutine would crash the process.
+// With a non-positive interval the subscription simply gets no heartbeats.
+func (s *sub) startWorker() {
+	if s.heartbeat && s.resolver.heartbeatInterval > 0 {
+		s.startWorkerWithHeartbeat()
+		return
+	}
+	s.startWorkerWithoutHeartbeat()
}
-func (r *Resolver) executeSubscriptionUpdate(ctx *Context, sub *sub, sharedInput []byte) {
+// startWorkerWithHeartbeat is the worker loop for subscriptions that send heartbeats.
+// Work from s.workChan runs on this goroutine. The heartbeat ticker is reset after each
+// item, so a heartbeat is only sent after a full interval without such work.
+// Updates for subscriptions with an executor are sent to s.executor instead of
+// s.workChan (see handleTriggerUpdate), so they do not reset the ticker.
+//
+// The loop returns when the subscription is completed, when the client's context is
+// done, or when the resolver's context is done. Watching s.completed alone is not
+// enough: shutdownTriggerSubscriptions only closes it when it is non-nil, and receiving
+// from a nil channel blocks forever, which would leak this goroutine and its ticker.
+func (s *sub) startWorkerWithHeartbeat() {
+	heartbeatTicker := time.NewTicker(s.resolver.heartbeatInterval)
+	defer heartbeatTicker.Stop()
+
+	resolverDone := s.resolver.ctx.Done()
+	clientDone := s.ctx.Context().Done()
+
+	for {
+		select {
+		case <-heartbeatTicker.C:
+			s.writeHeartbeat(resolverDone, clientDone)
+		case fn := <-s.workChan:
+			fn()
+			// Reset the heartbeat ticker after each write to avoid sending unnecessary heartbeats
+			heartbeatTicker.Reset(s.resolver.heartbeatInterval)
+		case <-s.completed: // Shutdown the writer when the subscription is completed
+			return
+		case <-clientDone: // The client's context is done; nothing left to write to
+			return
+		case <-resolverDone: // The resolver's context is done (all subscriptions are cancelled)
+			return
+		}
+	}
+}
+
+// writeHeartbeat sends a heartbeat from the goroutine that owns s.writer.
+// handleHeartbeat does no locking. Updates of subscriptions with an executor are
+// executed on the goroutine that reads s.executor (the http request goroutine), not
+// on this worker. Sending the heartbeat through the same channel serializes it with
+// those updates; writing it from here would race with them on s.writer.
+// Without an executor, this worker owns the writer and writes the heartbeat directly.
+func (s *sub) writeHeartbeat(resolverDone, clientDone <-chan struct{}) {
+	heartbeat := func() {
+		// The ticker and s.completed can be ready at the same time and select picks one
+		// at random, so re-check here to avoid writing after the subscription completed.
+		if s.isCompleted() {
+			return
+		}
+		s.resolver.handleHeartbeat(s, multipartHeartbeat)
+	}
+	if s.executor == nil {
+		heartbeat()
+		return
+	}
+	select {
+	case s.executor <- heartbeat:
+	case <-s.completed:
+	case <-clientDone:
+	case <-resolverDone:
+	}
+}
+
+// isCompleted reports whether s.completed has been closed. A nil channel is never
+// ready, so a subscription without a completed channel reports false.
+func (s *sub) isCompleted() bool {
+	select {
+	case <-s.completed:
+		return true
+	default:
+		return false
+	}
+}
+
+// startWorkerWithoutHeartbeat is the worker loop for subscriptions without heartbeats.
+// Work from s.workChan runs sequentially on this goroutine until the subscription is
+// completed, the client's context is done, or the resolver's context is done.
+// The context cases keep the goroutine from blocking forever when s.completed is nil
+// (shutdownTriggerSubscriptions only closes it when it is non-nil).
+func (s *sub) startWorkerWithoutHeartbeat() {
+	resolverDone := s.resolver.ctx.Done()
+	clientDone := s.ctx.Context().Done()
+
+	for {
+		select {
+		case fn := <-s.workChan:
+			fn()
+		case <-s.completed: // Shutdown the writer when the subscription is completed
+			return
+		case <-clientDone: // The client's context is done; nothing left to write to
+			return
+		case <-resolverDone: // The resolver's context is done (all subscriptions are cancelled)
+			return
+		}
+	}
+}
+
+func (r *Resolver) executeSubscriptionUpdate(resolveCtx *Context, sub *sub, sharedInput []byte) {
if r.options.Debug {
fmt.Printf("resolver:trigger:subscription:update:%d\n", sub.id.SubscriptionID)
}
- t := newTools(r.options, r.allowedErrorExtensionFields, r.allowedErrorFields)
+ ctx, cancel := context.WithTimeout(resolveCtx.ctx, r.maxSubscriptionFetchTimeout)
+ defer cancel()
+
+ resolveCtx = resolveCtx.WithContext(ctx)
+
+ // Copy the input.
input := make([]byte, len(sharedInput))
copy(input, sharedInput)
- if err := t.resolvable.InitSubscription(ctx, input, sub.resolve.Trigger.PostProcessing); err != nil {
- sub.mux.Lock()
- r.asyncErrorWriter.WriteError(ctx, err, sub.resolve.Response, sub.writer)
- sub.mux.Unlock()
+ t := newTools(r.options, r.allowedErrorExtensionFields, r.allowedErrorFields)
+
+ if err := t.resolvable.InitSubscription(resolveCtx, input, sub.resolve.Trigger.PostProcessing); err != nil {
+ r.asyncErrorWriter.WriteError(resolveCtx, err, sub.resolve.Response, sub.writer)
if r.options.Debug {
fmt.Printf("resolver:trigger:subscription:init:failed:%d\n", sub.id.SubscriptionID)
}
@@ -319,10 +352,8 @@ func (r *Resolver) executeSubscriptionUpdate(ctx *Context, sub *sub, sharedInput
return
}
- if err := t.loader.LoadGraphQLResponseData(ctx, sub.resolve.Response, t.resolvable); err != nil {
- sub.mux.Lock()
- r.asyncErrorWriter.WriteError(ctx, err, sub.resolve.Response, sub.writer)
- sub.mux.Unlock()
+ if err := t.loader.LoadGraphQLResponseData(resolveCtx, sub.resolve.Response, t.resolvable); err != nil {
+ r.asyncErrorWriter.WriteError(resolveCtx, err, sub.resolve.Response, sub.writer)
if r.options.Debug {
fmt.Printf("resolver:trigger:subscription:load:failed:%d\n", sub.id.SubscriptionID)
}
@@ -332,14 +363,8 @@ func (r *Resolver) executeSubscriptionUpdate(ctx *Context, sub *sub, sharedInput
return
}
- sub.mux.Lock()
- defer func() {
- sub.lastWrite = time.Now()
- sub.mux.Unlock()
- }()
-
- if err := t.resolvable.Resolve(ctx.ctx, sub.resolve.Response.Data, sub.resolve.Response.Fetches, sub.writer); err != nil {
- r.asyncErrorWriter.WriteError(ctx, err, sub.resolve.Response, sub.writer)
+ if err := t.resolvable.Resolve(resolveCtx.ctx, sub.resolve.Response.Data, sub.resolve.Response.Fetches, sub.writer); err != nil {
+ r.asyncErrorWriter.WriteError(resolveCtx, err, sub.resolve.Response, sub.writer)
if r.options.Debug {
fmt.Printf("resolver:trigger:subscription:resolve:failed:%d\n", sub.id.SubscriptionID)
}
@@ -348,30 +373,28 @@ func (r *Resolver) executeSubscriptionUpdate(ctx *Context, sub *sub, sharedInput
}
return
}
- err := sub.writer.Flush()
- if err != nil {
- // client disconnected
+
+ if err := sub.writer.Flush(); err != nil {
+ // If flush fails (e.g. client disconnected), remove the subscription.
_ = r.AsyncUnsubscribeSubscription(sub.id)
return
}
+
if r.options.Debug {
fmt.Printf("resolver:trigger:subscription:flushed:%d\n", sub.id.SubscriptionID)
}
if r.reporter != nil {
r.reporter.SubscriptionUpdateSent()
}
- if t.resolvable.WroteErrorsWithoutData() {
- if r.options.Debug {
- fmt.Printf("resolver:trigger:subscription:completing:errors_without_data:%d\n", sub.id.SubscriptionID)
- }
+
+ if t.resolvable.WroteErrorsWithoutData() && r.options.Debug {
+ fmt.Printf("resolver:trigger:subscription:completing:errors_without_data:%d\n", sub.id.SubscriptionID)
}
}
-// handleEvents maintains the single threaded event loop that processes all events
-func (r *Resolver) handleEvents() {
+// processEvents maintains the single threaded event loop that processes all events
+func (r *Resolver) processEvents() {
done := r.ctx.Done()
- heartbeat := time.NewTicker(r.multipartSubHeartbeatInterval)
- defer heartbeat.Stop()
for {
select {
case <-done:
@@ -379,8 +402,6 @@ func (r *Resolver) handleEvents() {
return
case event := <-r.events:
r.handleEvent(event)
- case <-heartbeat.C:
- r.handleHeartbeat(multipartHeartbeat)
}
}
}
@@ -410,51 +431,44 @@ func (r *Resolver) handleEvent(event subscriptionEvent) {
}
}
-func (r *Resolver) handleHeartbeat(data []byte) {
- r.heartbeatSubLock.Lock()
- defer r.heartbeatSubLock.Unlock()
-
+// handleHeartbeat writes data (a heartbeat) to sub.writer and flushes it.
+// It does no locking of its own, so it must run on the goroutine that owns sub.writer.
+// Nothing is written once the resolver's context or the subscriber's context has an error.
+// A Write error matching context.Canceled, or any Flush error, requests asynchronous
+// unsubscription; any other Write error is reported via asyncErrorWriter and the flush
+// is still attempted. Each successfully flushed heartbeat is reported to r.reporter.
+func (r *Resolver) handleHeartbeat(sub *sub, data []byte) {
if r.options.Debug {
- fmt.Printf("resolver:heartbeat:%d\n", len(r.heartbeatSubscriptions))
+ fmt.Printf("resolver:heartbeat\n")
}
- now := time.Now()
- for c, s := range r.heartbeatSubscriptions {
- // check if the last write to the subscription was more than heartbeat interval ago
- c, s := c, s
- s.mux.Lock()
- skipHeartbeat := now.Sub(s.lastWrite) < r.multipartSubHeartbeatInterval
- s.mux.Unlock()
- if skipHeartbeat || (c.Context().Err() != nil && errors.Is(c.Context().Err(), context.Canceled)) {
- continue
- }
- if r.options.Debug {
- fmt.Printf("resolver:heartbeat:subscription:%d\n", s.id.SubscriptionID)
- }
+ // Don't write once the resolver's context is done.
+ if r.ctx.Err() != nil {
+ return
+ }
- s.mux.Lock()
- if _, err := s.writer.Write(data); err != nil {
- if errors.Is(err, context.Canceled) {
- // client disconnected
- s.mux.Unlock()
- _ = r.AsyncUnsubscribeSubscription(s.id)
- return
- }
- r.asyncErrorWriter.WriteError(c, err, nil, s.writer)
- }
- err := s.writer.Flush()
- s.mux.Unlock()
- if err != nil {
- // client disconnected
- _ = r.AsyncUnsubscribeSubscription(s.id)
+ // Don't write to a subscriber whose context is already done.
+ if sub.ctx.Context().Err() != nil {
+ return
+ }
+
+ if r.options.Debug {
+ fmt.Printf("resolver:heartbeat:subscription:%d\n", sub.id.SubscriptionID)
+ }
+
+ if _, err := sub.writer.Write(data); err != nil {
+ if errors.Is(err, context.Canceled) {
+ // If Write fails (e.g. client disconnected), remove the subscription.
+ _ = r.AsyncUnsubscribeSubscription(sub.id)
return
}
- if r.options.Debug {
- fmt.Printf("resolver:heartbeat:subscription:flushed:%d\n", s.id.SubscriptionID)
- }
- if r.reporter != nil {
- r.reporter.SubscriptionUpdateSent()
- }
+ // Other write errors are reported through asyncErrorWriter; the flush below still runs.
+ r.asyncErrorWriter.WriteError(sub.ctx, err, nil, sub.writer)
+ }
+ err := sub.writer.Flush()
+ if err != nil {
+ // If flush fails (e.g. client disconnected), remove the subscription.
+ _ = r.AsyncUnsubscribeSubscription(sub.id)
+ return
+ }
+
+ if r.options.Debug {
+ fmt.Printf("resolver:heartbeat:subscription:flushed:%d\n", sub.id.SubscriptionID)
+ }
+ if r.reporter != nil {
+ // NOTE(review): heartbeats are counted via SubscriptionUpdateSent, same as data updates — confirm intended.
+ r.reporter.SubscriptionUpdateSent()
}
}
@@ -490,18 +504,24 @@ func (r *Resolver) handleAddSubscription(triggerID uint64, add *addSubscription)
fmt.Printf("resolver:trigger:subscription:add:%d:%d\n", triggerID, add.id.SubscriptionID)
}
s := &sub{
+ ctx: add.ctx,
resolve: add.resolve,
writer: add.writer,
id: add.id,
completed: add.completed,
- lastWrite: time.Now(),
executor: add.executor,
+ workChan: make(chan func(), 32),
+ resolver: r,
}
+
if add.ctx.ExecutionOptions.SendHeartbeat {
- r.heartbeatSubLock.Lock()
- r.heartbeatSubscriptions[add.ctx] = s
- r.heartbeatSubLock.Unlock()
+ s.heartbeat = true
}
+
+ // Start the dedicated worker goroutine where the subscription updates are processed
+ // and writes are written to the client in a single threaded manner
+ go s.startWorker()
+
trig, ok := r.triggers[triggerID]
if ok {
trig.subscriptions[add.ctx] = s
@@ -523,7 +543,6 @@ func (r *Resolver) handleAddSubscription(triggerID uint64, add *addSubscription)
triggerID: triggerID,
ch: r.events,
ctx: ctx,
- updateSem: r.triggerUpdatesSem,
}
cloneCtx := add.ctx.clone(ctx)
trig = &trigger{
@@ -580,11 +599,6 @@ func (r *Resolver) emitTriggerShutdown(triggerID uint64) error {
fmt.Printf("resolver:trigger:shutdown:%d\n", triggerID)
}
- if err := r.triggerEventsSem.Acquire(r.ctx, 1); err != nil {
- return err
- }
- defer r.triggerEventsSem.Release(1)
-
select {
case <-r.ctx.Done():
return r.ctx.Err()
@@ -602,11 +616,6 @@ func (r *Resolver) emitTriggerInitialized(triggerID uint64) error {
fmt.Printf("resolver:trigger:initialized:%d\n", triggerID)
}
- if err := r.triggerEventsSem.Acquire(r.ctx, 1); err != nil {
- return err
- }
- defer r.triggerEventsSem.Release(1)
-
select {
case <-r.ctx.Done():
return r.ctx.Err()
@@ -664,6 +673,12 @@ func (r *Resolver) handleTriggerUpdate(id uint64, data []byte) {
if r.options.Debug {
fmt.Printf("resolver:trigger:update:%d\n", id)
}
+
+ // collect all updates on the main thread
+ // then push them to the workChans or executors on a separate goroutine
+ // this ensures that we don't block the main thread when a client is slow to consume updates
+ updates := make([]func(), 0, len(trig.subscriptions))
+
for c, s := range trig.subscriptions {
c, s := c, s
if err := c.ctx.Err(); err != nil {
@@ -677,26 +692,40 @@ func (r *Resolver) handleTriggerUpdate(id uint64, data []byte) {
if skip {
continue
}
+
fn := func() {
r.executeSubscriptionUpdate(c, s, data)
}
- // Needs to be executed in a separate goroutine to prevent blocking the event loop.
- go func() {
-
- // Send the update to the executor channel to be executed on the main thread
- // Only relevant for SSE/Multipart subscriptions
- if s.executor != nil {
+ // Send the update to the executor channel to be executed on the main thread
+ // Only relevant for SSE/Multipart subscriptions
+ if s.executor != nil {
+ updates = append(updates, func() {
select {
case <-r.ctx.Done():
case <-c.ctx.Done():
case s.executor <- fn: // Run the update on the main thread and close subscription
}
- } else {
- fn()
+ })
+ continue
+ }
+
+ // Regular updates are sent to the workChan to be processed by the writer goroutine
+ updates = append(updates, func() {
+ select {
+ case <-r.ctx.Done():
+ case <-c.ctx.Done():
+ case s.workChan <- fn: // Channel is buffered, but it can still block in case of a slow writer
}
- }()
+ })
}
+
+ go func() {
+ // push the updates non-blocking to the workChan or executor by executing the queued functions
+ for _, fn := range updates {
+ fn()
+ }
+ }()
}
func (r *Resolver) shutdownTrigger(id uint64) {
@@ -733,19 +762,21 @@ func (r *Resolver) shutdownTriggerSubscriptions(id uint64, shutdownMatcher func(
if shutdownMatcher != nil && !shutdownMatcher(s.id) {
continue
}
+
if c.Context().Err() == nil {
s.writer.Complete()
}
+
if s.completed != nil {
close(s.completed)
}
- r.heartbeatSubLock.Lock()
- delete(r.heartbeatSubscriptions, c)
- r.heartbeatSubLock.Unlock()
+
delete(trig.subscriptions, c)
+
if r.options.Debug {
fmt.Printf("resolver:trigger:subscription:done:%d:%d\n", trig.id, s.id.SubscriptionID)
}
+
removed++
}
return removed
@@ -771,11 +802,6 @@ type SubscriptionIdentifier struct {
}
func (r *Resolver) AsyncUnsubscribeSubscription(id SubscriptionIdentifier) error {
- if err := r.triggerEventsSem.Acquire(r.ctx, 1); err != nil {
- return err
- }
- defer r.triggerEventsSem.Release(1)
-
select {
case <-r.ctx.Done():
return r.ctx.Err()
@@ -788,11 +814,6 @@ func (r *Resolver) AsyncUnsubscribeSubscription(id SubscriptionIdentifier) error
}
func (r *Resolver) AsyncUnsubscribeClient(connectionID int64) error {
- if err := r.triggerEventsSem.Acquire(r.ctx, 1); err != nil {
- return err
- }
- defer r.triggerEventsSem.Release(1)
-
select {
case <-r.ctx.Done():
return r.ctx.Err()
@@ -917,11 +938,6 @@ Loop: // execute fn on the main thread of the incoming request until ctx is done
}
func (r *Resolver) AsyncResolveGraphQLSubscription(ctx *Context, subscription *GraphQLSubscription, writer SubscriptionResponseWriter, id SubscriptionIdentifier) (err error) {
- if err := r.triggerEventsSem.Acquire(r.ctx, 1); err != nil {
- return err
- }
- defer r.triggerEventsSem.Release(1)
-
if subscription.Trigger.Source == nil {
return errors.New("no data source found")
}
@@ -1012,7 +1028,6 @@ type subscriptionUpdater struct {
triggerID uint64
ch chan subscriptionEvent
ctx context.Context
- updateSem *semaphore.Weighted
}
func (s *subscriptionUpdater) Update(data []byte) {
@@ -1020,11 +1035,6 @@ func (s *subscriptionUpdater) Update(data []byte) {
fmt.Printf("resolver:subscription_updater:update:%d\n", s.triggerID)
}
- if err := s.updateSem.Acquire(s.ctx, 1); err != nil {
- return
- }
- defer s.updateSem.Release(1)
-
select {
case <-s.ctx.Done():
return
@@ -1041,11 +1051,6 @@ func (s *subscriptionUpdater) Done() {
fmt.Printf("resolver:subscription_updater:done:%d\n", s.triggerID)
}
- if err := s.updateSem.Acquire(s.ctx, 1); err != nil {
- return
- }
- defer s.updateSem.Release(1)
-
select {
case <-s.ctx.Done():
return
diff --git a/tmp/agent-patch-flux-pr-1076.1-of-1.2026-02-28__13-38-10__gpt-5-1-codex-mini/app/v2/pkg/engine/resolve/resolve_test.go b/app/v2/pkg/engine/resolve/resolve_test.go
index c4a41eb..ef25c66 100644
--- a/tmp/agent-patch-flux-pr-1076.1-of-1.2026-02-28__13-38-10__gpt-5-1-codex-mini/app/v2/pkg/engine/resolve/resolve_test.go
+++ b/app/v2/pkg/engine/resolve/resolve_test.go
@@ -86,9 +86,7 @@ func (t *TestErrorWriter) WriteError(ctx *Context, err error, res *GraphQLRespon
}
}
-var multipartSubHeartbeatInterval = 15 * time.Millisecond
-
-const testMaxSubscriptionWorkers = 1
+var multipartSubHeartbeatInterval = 100 * time.Millisecond
func newResolver(ctx context.Context) *Resolver {
return New(ctx, ResolverOptions{
@@ -98,7 +96,6 @@ func newResolver(ctx context.Context) *Resolver {
PropagateSubgraphStatusCodes: true,
AsyncErrorWriter: &TestErrorWriter{},
MultipartSubHeartbeatInterval: multipartSubHeartbeatInterval,
- MaxSubscriptionWorkers: testMaxSubscriptionWorkers,
})
}
@@ -5185,7 +5182,7 @@ func TestResolver_ResolveGraphQLSubscription(t *testing.T) {
messages := recorder.Messages()
assert.Greater(t, len(messages), 2)
- time.Sleep(2 * resolver.multipartSubHeartbeatInterval)
+ time.Sleep(resolver.heartbeatInterval)
// Validate that despite the time, we don't see any heartbeats sent
assert.Contains(t, messages, `{"data":{"counter":0}}`)
assert.Contains(t, messages, `{"data":{"counter":1}}`)
@@ -5221,17 +5218,35 @@ func TestResolver_ResolveGraphQLSubscription(t *testing.T) {
}
recorder.AwaitComplete(t, defaultTimeout)
- time.Sleep(2 * resolver.multipartSubHeartbeatInterval)
- assert.Equal(t, 20, len(recorder.Messages()))
- // Validate that despite the time, we don't see any heartbeats sent
- assert.ElementsMatch(t, []string{
- `{"data":{"counter":0}}`, `{"data":{"counter":1}}`, `{"data":{"counter":0}}`, `{"data":{"counter":1}}`,
- `{"data":{"counter":0}}`, `{"data":{"counter":1}}`, `{"data":{"counter":0}}`, `{"data":{"counter":1}}`,
- `{"data":{"counter":0}}`, `{"data":{"counter":1}}`, `{"data":{"counter":0}}`, `{"data":{"counter":1}}`,
- `{"data":{"counter":0}}`, `{"data":{"counter":1}}`, `{"data":{"counter":0}}`, `{"data":{"counter":1}}`,
- `{"data":{"counter":0}}`, `{"data":{"counter":1}}`, `{"data":{"counter":0}}`, `{"data":{"counter":1}}`,
- }, recorder.Messages())
+ time.Sleep(resolver.heartbeatInterval)
+
+ assert.Len(t, recorder.Messages(), 31)
+
+ messages := recorder.Messages()
+
+ assert.Equal(t, `{"data":{"counter":0}}`, messages[0])
+ assert.Equal(t, `{"data":{"counter":1}}`, messages[1])
+ assert.Equal(t, `{"data":{"counter":0}}`, messages[2])
+ assert.Equal(t, `{"data":{"counter":1}}`, messages[3])
+ assert.Equal(t, `{"data":{"counter":0}}`, messages[4])
+ assert.Equal(t, `{"data":{"counter":1}}`, messages[5])
+ assert.Equal(t, `{"data":{"counter":0}}`, messages[6])
+ assert.Equal(t, `{"data":{"counter":1}}`, messages[7])
+ assert.Equal(t, `{"data":{"counter":0}}`, messages[8])
+ assert.Equal(t, `{"data":{"counter":1}}`, messages[9])
+ assert.Equal(t, `{"data":{"counter":0}}`, messages[10])
+ assert.Equal(t, `{"data":{"counter":1}}`, messages[11])
+ assert.Equal(t, `{"data":{"counter":0}}`, messages[12])
+ assert.Equal(t, `{"data":{"counter":1}}`, messages[13])
+ assert.Equal(t, `{"data":{"counter":0}}`, messages[14])
+ assert.Equal(t, `{"data":{"counter":1}}`, messages[15])
+ assert.Equal(t, `{"data":{"counter":0}}`, messages[16])
+ assert.Equal(t, `{"data":{"counter":1}}`, messages[17])
+ assert.Equal(t, `{"data":{"counter":0}}`, messages[18])
+ assert.Equal(t, `{"data":{"counter":1}}`, messages[19])
+
+ assert.Contains(t, messages, `{}`)
})
t.Run("should propagate extensions to stream", func(t *testing.T) {
@@ -5417,7 +5432,7 @@ func TestResolver_ResolveGraphQLSubscription(t *testing.T) {
},
}
- const numSubscriptions = testMaxSubscriptionWorkers + 1
+ const numSubscriptions = 2
var resolverCompleted atomic.Uint32
var recorderCompleted atomic.Uint32
for i := 0; i < numSubscriptions; i++ {
@@ -5456,11 +5471,7 @@ func TestResolver_ResolveGraphQLSubscription(t *testing.T) {
c, cancel := context.WithCancel(context.Background())
defer cancel()
- var started atomic.Bool
- var complete atomic.Bool
-
fakeStream := createFakeStream(func(counter int) (message string, done bool) {
- defer started.Store(true)
return fmt.Sprintf(`{"data":{"counter":%d}}`, counter), true
}, 100*time.Millisecond, func(input []byte) {
assert.Equal(t, `{"method":"POST","url":"http://localhost:4000","body":{"query":"subscription { counter }"}}`, string(input))
@@ -5480,23 +5491,12 @@ func TestResolver_ResolveGraphQLSubscription(t *testing.T) {
err := resolver.AsyncResolveGraphQLSubscription(&ctx, plan, recorder, id)
assert.NoError(t, err)
- assert.Eventually(t, func() bool {
- return started.Load()
- }, defaultTimeout, time.Millisecond*10)
-
- assert.Len(t, resolver.triggers, 1)
-
- var unsubscribeComplete atomic.Bool
- go func() {
- defer unsubscribeComplete.Store(true)
- err = resolver.AsyncUnsubscribeSubscription(id)
- assert.NoError(t, err)
- }()
+ recorder.AwaitAnyMessageCount(t, defaultTimeout)
- complete.Store(true)
- assert.Eventually(t, unsubscribeComplete.Load, defaultTimeout, time.Millisecond*100)
+ err = resolver.AsyncUnsubscribeSubscription(id)
+ assert.NoError(t, err)
recorder.AwaitComplete(t, defaultTimeout)
- assert.Len(t, resolver.triggers, 0)
+ fakeStream.AwaitIsDone(t, defaultTimeout)
})
}