// Copyright 2022-2024 The NATS Authors
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package test

import (
	"bytes"
	"context"
	"errors"
	"fmt"
	"sync"
	"testing"
	"time"

	"github.com/nats-io/nats.go"
	"github.com/nats-io/nats.go/jetstream"
)

// TestPullConsumerFetch covers Consumer.Fetch and Consumer.FetchNoWait on a
// single JetStream server: in-order delivery, consumer deletion while a fetch
// is still pending, fetching one message at a time, no-wait semantics, option
// validation (FetchMaxWait, FetchHeartbeat, FetchContext) and context
// cancellation.
func TestPullConsumerFetch(t *testing.T) {
	testSubject := "FOO.123"
	testMsgs := []string{"m1", "m2", "m3", "m4", "m5"}
	// publishTestMsgs publishes every entry of testMsgs, in order, on testSubject.
	publishTestMsgs := func(t *testing.T, js jetstream.JetStream) {
		for _, msg := range testMsgs {
			if _, err := js.Publish(context.Background(), testSubject, []byte(msg)); err != nil {
				t.Fatalf("Unexpected error during publish: %s", err)
			}
		}
	}

	t.Run("no options", func(t *testing.T) {
		srv := RunBasicJetStreamServer()
		defer shutdownJSServerAndRemoveStorage(t, srv)
		nc, err := nats.Connect(srv.ClientURL())
		if err != nil {
			t.Fatalf("Unexpected error: %v", err)
		}

		ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
		defer cancel()
		js, err := jetstream.New(nc)
		if err != nil {
			t.Fatalf("Unexpected error: %v", err)
		}
		defer nc.Close()

		s, err := js.CreateStream(ctx, jetstream.StreamConfig{Name: "foo", Subjects: []string{"FOO.*"}})
		if err != nil {
			t.Fatalf("Unexpected error: %v", err)
		}
		c, err := s.CreateOrUpdateConsumer(ctx, jetstream.ConsumerConfig{AckPolicy: jetstream.AckExplicitPolicy})
		if err != nil {
			t.Fatalf("Unexpected error: %v", err)
		}

		publishTestMsgs(t, js)
		msgs, err := c.Fetch(5)
		if err != nil {
			t.Fatalf("Unexpected error: %v", err)
		}

		// Messages must arrive in publish order.
		var i int
		for msg := range msgs.Messages() {
			if string(msg.Data()) != testMsgs[i] {
				t.Fatalf("Invalid msg on index %d; expected: %s; got: %s", i, testMsgs[i], string(msg.Data()))
			}
			i++
		}
		if len(testMsgs) != i {
			t.Fatalf("Invalid number of messages received; want: %d; got: %d", len(testMsgs), i)
		}
		if msgs.Error() != nil {
			t.Fatalf("Unexpected error during fetch: %v", msgs.Error())
		}
	})

	t.Run("delete consumer during fetch", func(t *testing.T) {
		srv := RunBasicJetStreamServer()
		defer shutdownJSServerAndRemoveStorage(t, srv)
		nc, err := nats.Connect(srv.ClientURL())
		if err != nil {
			t.Fatalf("Unexpected error: %v", err)
		}

		ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
		defer cancel()
		js, err := jetstream.New(nc)
		if err != nil {
			t.Fatalf("Unexpected error: %v", err)
		}
		defer nc.Close()

		s, err := js.CreateStream(ctx, jetstream.StreamConfig{Name: "foo", Subjects: []string{"FOO.*"}})
		if err != nil {
			t.Fatalf("Unexpected error: %v", err)
		}
		c, err := s.CreateOrUpdateConsumer(ctx, jetstream.ConsumerConfig{AckPolicy: jetstream.AckExplicitPolicy})
		if err != nil {
			t.Fatalf("Unexpected error: %v", err)
		}

		publishTestMsgs(t, js)
		// Only 5 messages exist, so a batch of 10 is still pending when the
		// consumer is deleted below.
		msgs, err := c.Fetch(10)
		if err != nil {
			t.Fatalf("Unexpected error: %v", err)
		}
		time.Sleep(100 * time.Millisecond)
		if err := s.DeleteConsumer(ctx, c.CachedInfo().Name); err != nil {
			t.Fatalf("Error deleting consumer: %s", err)
		}

		var i int
		for msg := range msgs.Messages() {
			if string(msg.Data()) != testMsgs[i] {
				t.Fatalf("Invalid msg on index %d; expected: %s; got: %s", i, testMsgs[i], string(msg.Data()))
			}
			i++
		}
		if len(testMsgs) != i {
			t.Fatalf("Invalid number of messages received; want: %d; got: %d", len(testMsgs), i)
		}
		if !errors.Is(msgs.Error(), jetstream.ErrConsumerDeleted) {
			t.Fatalf("Expected error: %v; got: %v", jetstream.ErrConsumerDeleted, msgs.Error())
		}
	})

	t.Run("no options, fetch single messages one by one", func(t *testing.T) {
		srv := RunBasicJetStreamServer()
		defer shutdownJSServerAndRemoveStorage(t, srv)
		nc, err := nats.Connect(srv.ClientURL())
		if err != nil {
			t.Fatalf("Unexpected error: %v", err)
		}

		ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
		defer cancel()
		js, err := jetstream.New(nc)
		if err != nil {
			t.Fatalf("Unexpected error: %v", err)
		}
		defer nc.Close()

		s, err := js.CreateStream(ctx, jetstream.StreamConfig{Name: "foo", Subjects: []string{"FOO.*"}})
		if err != nil {
			t.Fatalf("Unexpected error: %v", err)
		}
		c, err := s.CreateOrUpdateConsumer(ctx, jetstream.ConsumerConfig{AckPolicy: jetstream.AckExplicitPolicy})
		if err != nil {
			t.Fatalf("Unexpected error: %v", err)
		}

		// res is written only by the fetch goroutine; the test goroutine reads
		// it only after done is closed, which orders the accesses.
		res := make([]jetstream.Msg, 0)
		// errs is buffered so the fetch goroutine can always report and exit,
		// even after the test has stopped listening (e.g. on timeout).
		errs := make(chan error, 1)
		done := make(chan struct{})
		go func() {
			for {
				if len(res) == len(testMsgs) {
					close(done)
					return
				}
				msgs, err := c.Fetch(1)
				if err != nil {
					errs <- err
					return
				}

				msg := <-msgs.Messages()
				if msg != nil {
					res = append(res, msg)
				}
				if err := msgs.Error(); err != nil {
					errs <- err
					return
				}
			}
		}()

		// Let the first pull request reach the server before publishing.
		time.Sleep(10 * time.Millisecond)
		publishTestMsgs(t, js)
		select {
		case err := <-errs:
			t.Fatalf("Unexpected error: %v", err)
		case <-done:
			if len(res) != len(testMsgs) {
				t.Fatalf("Unexpected received message count; want %d; got %d", len(testMsgs), len(res))
			}
		case <-time.After(5 * time.Second):
			t.Fatalf("Timeout waiting for %d messages", len(testMsgs))
		}
		for i, msg := range res {
			if string(msg.Data()) != testMsgs[i] {
				t.Fatalf("Invalid msg on index %d; expected: %s; got: %s", i, testMsgs[i], string(msg.Data()))
			}
		}
	})

	t.Run("with no wait, no messages at the time of request", func(t *testing.T) {
		srv := RunBasicJetStreamServer()
		defer shutdownJSServerAndRemoveStorage(t, srv)
		nc, err := nats.Connect(srv.ClientURL())
		if err != nil {
			t.Fatalf("Unexpected error: %v", err)
		}

		ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
		defer cancel()
		js, err := jetstream.New(nc)
		if err != nil {
			t.Fatalf("Unexpected error: %v", err)
		}
		defer nc.Close()

		s, err := js.CreateStream(ctx, jetstream.StreamConfig{Name: "foo", Subjects: []string{"FOO.*"}})
		if err != nil {
			t.Fatalf("Unexpected error: %v", err)
		}
		c, err := s.CreateOrUpdateConsumer(ctx, jetstream.ConsumerConfig{AckPolicy: jetstream.AckExplicitPolicy})
		if err != nil {
			t.Fatalf("Unexpected error: %v", err)
		}

		// The stream is empty when the no-wait request is issued, so messages
		// published afterwards must not be delivered on this batch.
		msgs, err := c.FetchNoWait(5)
		if err != nil {
			t.Fatalf("Unexpected error: %v", err)
		}
		time.Sleep(100 * time.Millisecond)
		publishTestMsgs(t, js)

		msg := <-msgs.Messages()
		if msg != nil {
			t.Fatalf("Expected no messages; got: %s", string(msg.Data()))
		}
	})

	t.Run("with no wait, some messages available", func(t *testing.T) {
		srv := RunBasicJetStreamServer()
		defer shutdownJSServerAndRemoveStorage(t, srv)
		nc, err := nats.Connect(srv.ClientURL())
		if err != nil {
			t.Fatalf("Unexpected error: %v", err)
		}

		ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
		defer cancel()
		js, err := jetstream.New(nc)
		if err != nil {
			t.Fatalf("Unexpected error: %v", err)
		}
		defer nc.Close()

		s, err := js.CreateStream(ctx, jetstream.StreamConfig{Name: "foo", Subjects: []string{"FOO.*"}})
		if err != nil {
			t.Fatalf("Unexpected error: %v", err)
		}
		c, err := s.CreateOrUpdateConsumer(ctx, jetstream.ConsumerConfig{AckPolicy: jetstream.AckExplicitPolicy})
		if err != nil {
			t.Fatalf("Unexpected error: %v", err)
		}

		publishTestMsgs(t, js)
		time.Sleep(50 * time.Millisecond)
		msgs, err := c.FetchNoWait(10)
		if err != nil {
			t.Fatalf("Unexpected error: %v", err)
		}
		time.Sleep(100 * time.Millisecond)
		// This second batch is published after the no-wait request and must
		// not be counted.
		publishTestMsgs(t, js)

		var msgsNum int
		for range msgs.Messages() {
			msgsNum++
		}
		// Check the batch's own error; the FetchNoWait return value was
		// already verified above.
		if err := msgs.Error(); err != nil {
			t.Fatalf("Unexpected error during fetch: %v", err)
		}

		if msgsNum != len(testMsgs) {
			t.Fatalf("Expected %d messages, got: %d", len(testMsgs), msgsNum)
		}
	})

	t.Run("with timeout", func(t *testing.T) {
		srv := RunBasicJetStreamServer()
		defer shutdownJSServerAndRemoveStorage(t, srv)
		nc, err := nats.Connect(srv.ClientURL())
		if err != nil {
			t.Fatalf("Unexpected error: %v", err)
		}

		ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
		defer cancel()
		js, err := jetstream.New(nc)
		if err != nil {
			t.Fatalf("Unexpected error: %v", err)
		}
		defer nc.Close()

		s, err := js.CreateStream(ctx, jetstream.StreamConfig{Name: "foo", Subjects: []string{"FOO.*"}})
		if err != nil {
			t.Fatalf("Unexpected error: %v", err)
		}
		c, err := s.CreateOrUpdateConsumer(ctx, jetstream.ConsumerConfig{AckPolicy: jetstream.AckExplicitPolicy})
		if err != nil {
			t.Fatalf("Unexpected error: %v", err)
		}

		// Nothing is published, so the request expires without delivering.
		msgs, err := c.Fetch(5, jetstream.FetchMaxWait(50*time.Millisecond))
		if err != nil {
			t.Fatalf("Unexpected error: %v", err)
		}

		msg := <-msgs.Messages()
		if msg != nil {
			t.Fatalf("Expected no messages; got: %s", string(msg.Data()))
		}
	})

	t.Run("with invalid timeout value", func(t *testing.T) {
		srv := RunBasicJetStreamServer()
		defer shutdownJSServerAndRemoveStorage(t, srv)
		nc, err := nats.Connect(srv.ClientURL())
		if err != nil {
			t.Fatalf("Unexpected error: %v", err)
		}

		ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
		defer cancel()
		js, err := jetstream.New(nc)
		if err != nil {
			t.Fatalf("Unexpected error: %v", err)
		}
		defer nc.Close()

		s, err := js.CreateStream(ctx, jetstream.StreamConfig{Name: "foo", Subjects: []string{"FOO.*"}})
		if err != nil {
			t.Fatalf("Unexpected error: %v", err)
		}
		c, err := s.CreateOrUpdateConsumer(ctx, jetstream.ConsumerConfig{AckPolicy: jetstream.AckExplicitPolicy})
		if err != nil {
			t.Fatalf("Unexpected error: %v", err)
		}

		// A negative max wait must be rejected up front.
		_, err = c.Fetch(5, jetstream.FetchMaxWait(-50*time.Millisecond))
		if !errors.Is(err, jetstream.ErrInvalidOption) {
			t.Fatalf("Expected error: %v; got: %v", jetstream.ErrInvalidOption, err)
		}
	})

	t.Run("consumer does not exist", func(t *testing.T) {
		srv := RunBasicJetStreamServer()
		defer shutdownJSServerAndRemoveStorage(t, srv)
		nc, err := nats.Connect(srv.ClientURL())
		if err != nil {
			t.Fatalf("Unexpected error: %v", err)
		}

		ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
		defer cancel()
		js, err := jetstream.New(nc)
		if err != nil {
			t.Fatalf("Unexpected error: %v", err)
		}
		defer nc.Close()

		s, err := js.CreateStream(ctx, jetstream.StreamConfig{Name: "foo", Subjects: []string{"FOO.*"}})
		if err != nil {
			t.Fatalf("Unexpected error: %v", err)
		}
		c, err := s.CreateOrUpdateConsumer(ctx, jetstream.ConsumerConfig{AckPolicy: jetstream.AckExplicitPolicy})
		if err != nil {
			t.Fatalf("Unexpected error: %v", err)
		}
		publishTestMsgs(t, js)
		// fetch 5 messages, should return normally
		msgs, err := c.Fetch(5)
		if err != nil {
			t.Fatalf("Unexpected error: %v", err)
		}
		var i int
		for range msgs.Messages() {
			i++
		}
		if i != len(testMsgs) {
			t.Fatalf("Expected %d messages; got: %d", len(testMsgs), i)
		}
		if msgs.Error() != nil {
			t.Fatalf("Unexpected error during fetch: %v", msgs.Error())
		}

		// fetch again, should timeout without any error
		msgs, err = c.Fetch(5, jetstream.FetchMaxWait(200*time.Millisecond))
		if err != nil {
			t.Fatalf("Unexpected error: %v", err)
		}
		select {
		case _, ok := <-msgs.Messages():
			if ok {
				t.Fatalf("Expected channel to be closed")
			}
		case <-time.After(1 * time.Second):
			t.Fatalf("Expected channel to be closed")
		}
		if msgs.Error() != nil {
			t.Fatalf("Unexpected error during fetch: %v", msgs.Error())
		}

		// delete the consumer, at this point server should stop sending heartbeats for pull requests
		if err := s.DeleteConsumer(ctx, c.CachedInfo().Name); err != nil {
			t.Fatalf("Error deleting consumer: %s", err)
		}
		msgs, err = c.Fetch(5)
		if err != nil {
			t.Fatalf("Unexpected error: %v", err)
		}

		select {
		case _, ok := <-msgs.Messages():
			if ok {
				t.Fatalf("Expected channel to be closed")
			}
		case <-time.After(1 * time.Second):
			t.Fatalf("Expected channel to be closed")
		}
		// Report the batch error, not the (nil) Fetch return value.
		if !errors.Is(msgs.Error(), nats.ErrNoResponders) {
			t.Fatalf("Expected error: %v; got: %v", nats.ErrNoResponders, msgs.Error())
		}
	})

	t.Run("with invalid heartbeat value", func(t *testing.T) {
		srv := RunBasicJetStreamServer()
		defer shutdownJSServerAndRemoveStorage(t, srv)
		nc, err := nats.Connect(srv.ClientURL())
		if err != nil {
			t.Fatalf("Unexpected error: %v", err)
		}

		ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
		defer cancel()
		js, err := jetstream.New(nc)
		if err != nil {
			t.Fatalf("Unexpected error: %v", err)
		}
		defer nc.Close()

		s, err := js.CreateStream(ctx, jetstream.StreamConfig{Name: "foo", Subjects: []string{"FOO.*"}})
		if err != nil {
			t.Fatalf("Unexpected error: %v", err)
		}
		c, err := s.CreateOrUpdateConsumer(ctx, jetstream.ConsumerConfig{AckPolicy: jetstream.AckExplicitPolicy})
		if err != nil {
			t.Fatalf("Unexpected error: %v", err)
		}

		// default expiry (30s), hb too large
		_, err = c.Fetch(5, jetstream.FetchHeartbeat(20*time.Second))
		if !errors.Is(err, jetstream.ErrInvalidOption) {
			t.Fatalf("Expected error: %v; got: %v", jetstream.ErrInvalidOption, err)
		}

		// custom expiry, hb too large
		_, err = c.Fetch(5, jetstream.FetchHeartbeat(2*time.Second), jetstream.FetchMaxWait(3*time.Second))
		if !errors.Is(err, jetstream.ErrInvalidOption) {
			t.Fatalf("Expected error: %v; got: %v", jetstream.ErrInvalidOption, err)
		}

		// negative heartbeat
		_, err = c.Fetch(5, jetstream.FetchHeartbeat(-2*time.Second))
		if !errors.Is(err, jetstream.ErrInvalidOption) {
			t.Fatalf("Expected error: %v; got: %v", jetstream.ErrInvalidOption, err)
		}
	})

	t.Run("with context", func(t *testing.T) {
		srv := RunBasicJetStreamServer()
		defer shutdownJSServerAndRemoveStorage(t, srv)
		nc, err := nats.Connect(srv.ClientURL())
		if err != nil {
			t.Fatalf("Unexpected error: %v", err)
		}

		ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
		defer cancel()
		js, err := jetstream.New(nc)
		if err != nil {
			t.Fatalf("Unexpected error: %v", err)
		}
		defer nc.Close()

		s, err := js.CreateStream(ctx, jetstream.StreamConfig{Name: "foo", Subjects: []string{"FOO.*"}})
		if err != nil {
			t.Fatalf("Unexpected error: %v", err)
		}
		c, err := s.CreateOrUpdateConsumer(ctx, jetstream.ConsumerConfig{AckPolicy: jetstream.AckExplicitPolicy})
		if err != nil {
			t.Fatalf("Unexpected error: %v", err)
		}

		// pull request should expire before client timeout
		ctx, cancel = context.WithTimeout(context.Background(), time.Second)
		defer cancel()
		result, err := c.Fetch(1, jetstream.FetchContext(ctx))
		if err != nil {
			t.Fatalf("Unexpected error from Fetch: %v", err)
		}
		msg, ok := <-result.Messages()
		if ok {
			t.Fatalf("Expected no message, got: %v", msg)
		}
		if result.Error() != nil {
			t.Fatalf("Unexpected error during fetch: %v", result.Error())
		}

		// Test context cancellation; cancel is invoked by the goroutine below.
		ctx, cancel = context.WithCancel(context.Background())
		go func() {
			time.Sleep(50 * time.Millisecond)
			cancel()
		}()
		result, err = c.Fetch(1, jetstream.FetchContext(ctx))
		if err != nil {
			t.Fatalf("Unexpected error from Fetch: %v", err)
		}
		msg = <-result.Messages()
		if msg != nil {
			t.Fatalf("Expected no message, got: %v", msg)
		}
		err = result.Error()
		if !errors.Is(err, context.Canceled) {
			t.Fatalf("Expected context canceled error, got: %v", err)
		}

		// Test mutual exclusion with FetchMaxWait
		ctx, cancel = context.WithTimeout(context.Background(), time.Second)
		defer cancel()
		_, err = c.Fetch(1, jetstream.FetchContext(ctx), jetstream.FetchMaxWait(time.Second))
		if !errors.Is(err, jetstream.ErrInvalidOption) {
			t.Fatalf("Expected mutual exclusion error, got: %v", err)
		}

		// Test already expired context
		expiredCtx, cancel := context.WithTimeout(context.Background(), -time.Second)
		defer cancel()
		_, err = c.Fetch(1, jetstream.FetchContext(expiredCtx))
		if !errors.Is(err, jetstream.ErrInvalidOption) {
			t.Fatalf("Expected invalid option error, got: %v", err)
		}
	})
}

// TestPullConsumerMessagesConcurrentStopAndDrain releases two goroutines at
// (nearly) the same instant: one calls Stop on the iterator returned by
// Consumer.Messages, the other calls Drain on the underlying connection. It is
// primarily a data-race check (run with -race). It also fails if Drain reports
// an error or if the two calls do not both return within a bounded time.
func TestPullConsumerMessagesConcurrentStopAndDrain(t *testing.T) {
	srv := RunBasicJetStreamServer()
	defer shutdownJSServerAndRemoveStorage(t, srv)

	nc, err := nats.Connect(srv.ClientURL())
	if err != nil {
		t.Fatalf("connect: %v", err)
	}
	t.Cleanup(nc.Close)

	ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
	defer cancel()

	js, err := jetstream.New(nc)
	if err != nil {
		t.Fatalf("jetstream: %v", err)
	}

	_, err = js.CreateStream(ctx, jetstream.StreamConfig{
		Name:     "FOO",
		Subjects: []string{"FOO.>"},
	})
	if err != nil {
		t.Fatalf("create stream: %v", err)
	}

	consumer, err := js.CreateOrUpdateConsumer(ctx, "FOO", jetstream.ConsumerConfig{
		Durable:   "race-consumer",
		AckPolicy: jetstream.AckExplicitPolicy,
	})
	if err != nil {
		t.Fatalf("create consumer: %v", err)
	}

	messages, err := consumer.Messages()
	if err != nil {
		t.Fatalf("messages: %v", err)
	}

	start := make(chan struct{})
	var ready, done sync.WaitGroup
	ready.Add(2)
	done.Add(2)

	go func() {
		defer done.Done()
		ready.Done()
		<-start
		messages.Stop()
	}()

	// drainErr is written by the goroutine below and read only after
	// done.Wait() has returned, so no extra synchronization is needed.
	var drainErr error
	go func() {
		defer done.Done()
		ready.Done()
		<-start
		drainErr = nc.Drain()
	}()

	ready.Wait()
	// Give both goroutines a moment to park on start so that closing it
	// releases them as close together as possible.
	time.Sleep(2 * time.Millisecond)
	close(start)

	// Bound the wait so a deadlock between Stop and Drain fails the test
	// instead of hanging until the global test timeout.
	finished := make(chan struct{})
	go func() {
		done.Wait()
		close(finished)
	}()
	select {
	case <-finished:
	case <-time.After(5 * time.Second):
		t.Fatalf("timed out waiting for concurrent Stop and Drain to return")
	}
	if drainErr != nil {
		t.Fatalf("drain: %v", drainErr)
	}
}

// TestPullConsumerFetchRace reads MessageBatch.Error() in a tight loop while
// messages are being consumed and the consumer is deleted concurrently. Run it
// with -race to detect unsynchronized access to the batch state. It also
// checks that all 3 published messages arrive in order and that the batch
// finally reports ErrConsumerDeleted.
func TestPullConsumerFetchRace(t *testing.T) {
	srv := RunBasicJetStreamServer()
	defer shutdownJSServerAndRemoveStorage(t, srv)
	nc, err := nats.Connect(srv.ClientURL())
	if err != nil {
		t.Fatalf("Unexpected error: %v", err)
	}

	ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
	defer cancel()
	js, err := jetstream.New(nc)
	if err != nil {
		t.Fatalf("Unexpected error: %v", err)
	}
	defer nc.Close()

	s, err := js.CreateStream(ctx, jetstream.StreamConfig{Name: "foo", Subjects: []string{"FOO.*"}})
	if err != nil {
		t.Fatalf("Unexpected error: %v", err)
	}
	c, err := s.CreateOrUpdateConsumer(ctx, jetstream.ConsumerConfig{AckPolicy: jetstream.AckExplicitPolicy})
	if err != nil {
		t.Fatalf("Unexpected error: %v", err)
	}

	for i := 0; i < 3; i++ {
		if _, err := js.Publish(context.Background(), "FOO.123", []byte(fmt.Sprintf("msg-%d", i))); err != nil {
			t.Fatalf("Unexpected error during publish: %s", err)
		}
	}
	// Batch of 5 with only 3 messages available: the request stays pending
	// until the consumer is deleted.
	msgs, err := c.Fetch(5)
	if err != nil {
		t.Fatalf("Unexpected error: %v", err)
	}

	// errCh is buffered and stop is closed on return so the polling goroutine
	// can always exit: it neither blocks on a send nobody receives nor spins
	// forever if the error never shows up.
	errCh := make(chan error, 1)
	stop := make(chan struct{})
	defer close(stop)
	go func() {
		for {
			select {
			case <-stop:
				return
			default:
			}
			// Deliberately busy-poll Error() to overlap with message delivery.
			if err := msgs.Error(); err != nil {
				errCh <- err
				return
			}
		}
	}()
	deleteErrCh := make(chan error, 1)
	go func() {
		time.Sleep(100 * time.Millisecond)
		if err := s.DeleteConsumer(ctx, c.CachedInfo().Name); err != nil {
			deleteErrCh <- err
		}
		close(deleteErrCh)
	}()

	var i int
	for msg := range msgs.Messages() {
		if string(msg.Data()) != fmt.Sprintf("msg-%d", i) {
			t.Fatalf("Invalid msg on index %d; expected: %s; got: %s", i, fmt.Sprintf("msg-%d", i), string(msg.Data()))
		}
		i++
	}
	if i != 3 {
		t.Fatalf("Invalid number of messages received; want: %d; got: %d", 3, i)
	}
	select {
	case err := <-errCh:
		if !errors.Is(err, jetstream.ErrConsumerDeleted) {
			t.Fatalf("Expected error: %v; got: %v", jetstream.ErrConsumerDeleted, err)
		}
	case <-time.After(1 * time.Second):
		t.Fatalf("Expected error: %v; got: %v", jetstream.ErrConsumerDeleted, nil)
	}

	// wait until the consumer is deleted, otherwise we may close the connection
	// before the consumer delete response is received
	select {
	case ert, ok := <-deleteErrCh:
		if !ok {
			break
		}
		t.Fatalf("Error deleting consumer: %s", ert)
	case <-time.After(1 * time.Second):
		t.Fatalf("Expected done to be closed")
	}
}

// TestPullConsumerFetchBytes covers Consumer.FetchBytes: batches bounded by
// byte size (exact fit, last message not fitting, single message too large),
// expiry via FetchMaxWait, fetching from a deleted consumer, and heartbeat
// option validation.
func TestPullConsumerFetchBytes(t *testing.T) {
	testSubject := "FOO.123"
	msg := [10]byte{}
	// publishTestMsgs publishes count copies of the 10-byte zero payload.
	publishTestMsgs := func(t *testing.T, js jetstream.JetStream, count int) {
		for i := 0; i < count; i++ {
			if _, err := js.Publish(context.Background(), testSubject, msg[:]); err != nil {
				t.Fatalf("Unexpected error during publish: %s", err)
			}
		}
	}
	t.Run("no options, exact byte count received", func(t *testing.T) {
		srv := RunBasicJetStreamServer()
		defer shutdownJSServerAndRemoveStorage(t, srv)
		nc, err := nats.Connect(srv.ClientURL())
		if err != nil {
			t.Fatalf("Unexpected error: %v", err)
		}

		ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
		defer cancel()
		js, err := jetstream.New(nc)
		if err != nil {
			t.Fatalf("Unexpected error: %v", err)
		}
		defer nc.Close()

		s, err := js.CreateStream(ctx, jetstream.StreamConfig{Name: "foo", Subjects: []string{"FOO.*"}})
		if err != nil {
			t.Fatalf("Unexpected error: %v", err)
		}
		c, err := s.CreateOrUpdateConsumer(ctx, jetstream.ConsumerConfig{AckPolicy: jetstream.AckExplicitPolicy, Name: "con"})
		if err != nil {
			t.Fatalf("Unexpected error: %v", err)
		}

		publishTestMsgs(t, js, 5)
		// actual received msg size will be 60 (payload=10 + Subject=7 + Reply=43)
		msgs, err := c.FetchBytes(300)
		if err != nil {
			t.Fatalf("Unexpected error: %v", err)
		}

		var i int
		for msg := range msgs.Messages() {
			msg.Ack()
			i++
		}
		if i != 5 {
			t.Fatalf("Expected 5 messages; got: %d", i)
		}
		if msgs.Error() != nil {
			t.Fatalf("Unexpected error during fetch: %v", msgs.Error())
		}
	})

	t.Run("no options, last msg does not fit max bytes", func(t *testing.T) {
		srv := RunBasicJetStreamServer()
		defer shutdownJSServerAndRemoveStorage(t, srv)
		nc, err := nats.Connect(srv.ClientURL())
		if err != nil {
			t.Fatalf("Unexpected error: %v", err)
		}

		ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
		defer cancel()
		js, err := jetstream.New(nc)
		if err != nil {
			t.Fatalf("Unexpected error: %v", err)
		}
		defer nc.Close()

		s, err := js.CreateStream(ctx, jetstream.StreamConfig{Name: "foo", Subjects: []string{"FOO.*"}})
		if err != nil {
			t.Fatalf("Unexpected error: %v", err)
		}
		c, err := s.CreateOrUpdateConsumer(ctx, jetstream.ConsumerConfig{AckPolicy: jetstream.AckExplicitPolicy, Name: "con"})
		if err != nil {
			t.Fatalf("Unexpected error: %v", err)
		}

		publishTestMsgs(t, js, 5)
		// actual received msg size will be 60 (payload=10 + Subject=7 + Reply=43)
		msgs, err := c.FetchBytes(250)
		if err != nil {
			t.Fatalf("Unexpected error: %v", err)
		}

		var i int
		for msg := range msgs.Messages() {
			msg.Ack()
			i++
		}
		if i != 4 {
			t.Fatalf("Expected 4 messages; got: %d", i)
		}
		if msgs.Error() != nil {
			t.Fatalf("Unexpected error during fetch: %v", msgs.Error())
		}
	})
	t.Run("no options, single msg is too large", func(t *testing.T) {
		srv := RunBasicJetStreamServer()
		defer shutdownJSServerAndRemoveStorage(t, srv)
		nc, err := nats.Connect(srv.ClientURL())
		if err != nil {
			t.Fatalf("Unexpected error: %v", err)
		}

		ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
		defer cancel()
		js, err := jetstream.New(nc)
		if err != nil {
			t.Fatalf("Unexpected error: %v", err)
		}
		defer nc.Close()

		s, err := js.CreateStream(ctx, jetstream.StreamConfig{Name: "foo", Subjects: []string{"FOO.*"}})
		if err != nil {
			t.Fatalf("Unexpected error: %v", err)
		}
		c, err := s.CreateOrUpdateConsumer(ctx, jetstream.ConsumerConfig{AckPolicy: jetstream.AckExplicitPolicy, Name: "con"})
		if err != nil {
			t.Fatalf("Unexpected error: %v", err)
		}

		publishTestMsgs(t, js, 5)
		// actual received msg size will be 60 (payload=10 + Subject=7 + Reply=43)
		msgs, err := c.FetchBytes(30)
		if err != nil {
			t.Fatalf("Unexpected error: %v", err)
		}

		var i int
		for msg := range msgs.Messages() {
			msg.Ack()
			i++
		}
		if i != 0 {
			t.Fatalf("Expected 0 messages; got: %d", i)
		}
		if msgs.Error() != nil {
			t.Fatalf("Unexpected error during fetch: %v", msgs.Error())
		}
	})

	t.Run("timeout waiting for messages", func(t *testing.T) {
		srv := RunBasicJetStreamServer()
		defer shutdownJSServerAndRemoveStorage(t, srv)
		nc, err := nats.Connect(srv.ClientURL())
		if err != nil {
			t.Fatalf("Unexpected error: %v", err)
		}

		ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
		defer cancel()
		js, err := jetstream.New(nc)
		if err != nil {
			t.Fatalf("Unexpected error: %v", err)
		}
		defer nc.Close()

		s, err := js.CreateStream(ctx, jetstream.StreamConfig{Name: "foo", Subjects: []string{"FOO.*"}})
		if err != nil {
			t.Fatalf("Unexpected error: %v", err)
		}
		c, err := s.CreateOrUpdateConsumer(ctx, jetstream.ConsumerConfig{AckPolicy: jetstream.AckExplicitPolicy, Name: "con"})
		if err != nil {
			t.Fatalf("Unexpected error: %v", err)
		}

		publishTestMsgs(t, js, 5)
		// actual received msg size will be 60 (payload=10 + Subject=7 + Reply=43)
		msgs, err := c.FetchBytes(1000, jetstream.FetchMaxWait(50*time.Millisecond))
		if err != nil {
			t.Fatalf("Unexpected error: %v", err)
		}

		var i int
		for msg := range msgs.Messages() {
			msg.Ack()
			i++
		}
		if i != 5 {
			t.Fatalf("Expected 5 messages; got: %d", i)
		}
		if msgs.Error() != nil {
			t.Fatalf("Unexpected error during fetch: %v", msgs.Error())
		}
	})

	t.Run("consumer does not exist", func(t *testing.T) {
		srv := RunBasicJetStreamServer()
		defer shutdownJSServerAndRemoveStorage(t, srv)
		nc, err := nats.Connect(srv.ClientURL())
		if err != nil {
			t.Fatalf("Unexpected error: %v", err)
		}

		ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
		defer cancel()
		js, err := jetstream.New(nc)
		if err != nil {
			t.Fatalf("Unexpected error: %v", err)
		}
		defer nc.Close()

		s, err := js.CreateStream(ctx, jetstream.StreamConfig{Name: "foo", Subjects: []string{"FOO.*"}})
		if err != nil {
			t.Fatalf("Unexpected error: %v", err)
		}
		c, err := s.CreateOrUpdateConsumer(ctx, jetstream.ConsumerConfig{AckPolicy: jetstream.AckExplicitPolicy})
		if err != nil {
			t.Fatalf("Unexpected error: %v", err)
		}

		// nothing has been published yet, so this fetch should time out
		// without any error while the consumer still exists
		msgs, err := c.FetchBytes(5, jetstream.FetchMaxWait(200*time.Millisecond))
		if err != nil {
			t.Fatalf("Unexpected error: %v", err)
		}
		select {
		case _, ok := <-msgs.Messages():
			if ok {
				t.Fatalf("Expected channel to be closed")
			}
		case <-time.After(1 * time.Second):
			t.Fatalf("Expected channel to be closed")
		}
		if msgs.Error() != nil {
			t.Fatalf("Unexpected error during fetch: %v", msgs.Error())
		}

		// delete the consumer
		if err := s.DeleteConsumer(ctx, c.CachedInfo().Name); err != nil {
			t.Fatalf("Error deleting consumer: %s", err)
		}
		msgs, err = c.FetchBytes(5)
		if err != nil {
			t.Fatalf("Unexpected error: %v", err)
		}

		select {
		case _, ok := <-msgs.Messages():
			if ok {
				t.Fatalf("Expected channel to be closed")
			}
		case <-time.After(1 * time.Second):
			t.Fatalf("Expected channel to be closed")
		}
		// Report the batch error, not the (nil) FetchBytes return value.
		if !errors.Is(msgs.Error(), nats.ErrNoResponders) {
			t.Fatalf("Expected error: %v; got: %v", nats.ErrNoResponders, msgs.Error())
		}
	})

	t.Run("with invalid heartbeat value", func(t *testing.T) {
		srv := RunBasicJetStreamServer()
		defer shutdownJSServerAndRemoveStorage(t, srv)
		nc, err := nats.Connect(srv.ClientURL())
		if err != nil {
			t.Fatalf("Unexpected error: %v", err)
		}

		ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
		defer cancel()
		js, err := jetstream.New(nc)
		if err != nil {
			t.Fatalf("Unexpected error: %v", err)
		}
		defer nc.Close()

		s, err := js.CreateStream(ctx, jetstream.StreamConfig{Name: "foo", Subjects: []string{"FOO.*"}})
		if err != nil {
			t.Fatalf("Unexpected error: %v", err)
		}
		c, err := s.CreateOrUpdateConsumer(ctx, jetstream.ConsumerConfig{AckPolicy: jetstream.AckExplicitPolicy})
		if err != nil {
			t.Fatalf("Unexpected error: %v", err)
		}

		// default expiry (30s), hb too large
		_, err = c.FetchBytes(5, jetstream.FetchHeartbeat(20*time.Second))
		if !errors.Is(err, jetstream.ErrInvalidOption) {
			t.Fatalf("Expected error: %v; got: %v", jetstream.ErrInvalidOption, err)
		}

		// custom expiry, hb too large
		_, err = c.FetchBytes(5, jetstream.FetchHeartbeat(2*time.Second), jetstream.FetchMaxWait(3*time.Second))
		if !errors.Is(err, jetstream.ErrInvalidOption) {
			t.Fatalf("Expected error: %v; got: %v", jetstream.ErrInvalidOption, err)
		}

		// negative heartbeat
		_, err = c.FetchBytes(5, jetstream.FetchHeartbeat(-2*time.Second))
		if !errors.Is(err, jetstream.ErrInvalidOption) {
			t.Fatalf("Expected error: %v; got: %v", jetstream.ErrInvalidOption, err)
		}
	})
}

// TestPullConsumerFetch_WithCluster runs basic Fetch and FetchNoWait scenarios
// against a 3-node JetStream cluster hosting a single-replica stream, with the
// client connected to the first server.
func TestPullConsumerFetch_WithCluster(t *testing.T) {
	testSubject := "FOO.123"
	testMsgs := []string{"m1", "m2", "m3", "m4", "m5"}
	// publishTestMsgs publishes every entry of testMsgs, in order, on testSubject.
	publishTestMsgs := func(t *testing.T, js jetstream.JetStream) {
		for _, msg := range testMsgs {
			if _, err := js.Publish(context.Background(), testSubject, []byte(msg)); err != nil {
				t.Fatalf("Unexpected error during publish: %s", err)
			}
		}
	}

	name := "cluster"
	stream := jetstream.StreamConfig{
		Name:     name,
		Replicas: 1,
		Subjects: []string{"FOO.*"},
	}
	t.Run("no options", func(t *testing.T) {
		withJSClusterAndStream(t, name, 3, stream, func(t *testing.T, subject string, srvs ...*jsServer) {
			srv := srvs[0]
			nc, err := nats.Connect(srv.ClientURL())
			if err != nil {
				t.Fatalf("Unexpected error: %v", err)
			}

			ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
			defer cancel()
			js, err := jetstream.New(nc)
			if err != nil {
				t.Fatalf("Unexpected error: %v", err)
			}
			defer nc.Close()

			s, err := js.Stream(ctx, stream.Name)
			if err != nil {
				t.Fatalf("Unexpected error: %v", err)
			}

			c, err := s.CreateOrUpdateConsumer(ctx, jetstream.ConsumerConfig{AckPolicy: jetstream.AckExplicitPolicy})
			if err != nil {
				t.Fatalf("Unexpected error: %v", err)
			}

			publishTestMsgs(t, js)
			msgs, err := c.Fetch(5)
			if err != nil {
				t.Fatalf("Unexpected error: %v", err)
			}

			var i int
			for msg := range msgs.Messages() {
				if string(msg.Data()) != testMsgs[i] {
					t.Fatalf("Invalid msg on index %d; expected: %s; got: %s", i, testMsgs[i], string(msg.Data()))
				}
				i++
			}
			// Without this check a short batch would pass silently.
			if len(testMsgs) != i {
				t.Fatalf("Invalid number of messages received; want: %d; got: %d", len(testMsgs), i)
			}
			if msgs.Error() != nil {
				t.Fatalf("Unexpected error during fetch: %v", msgs.Error())
			}
		})
	})

	t.Run("with no wait, no messages at the time of request", func(t *testing.T) {
		withJSClusterAndStream(t, name, 3, stream, func(t *testing.T, subject string, srvs ...*jsServer) {
			nc, err := nats.Connect(srvs[0].ClientURL())
			if err != nil {
				t.Fatalf("Unexpected error: %v", err)
			}

			ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
			defer cancel()
			js, err := jetstream.New(nc)
			if err != nil {
				t.Fatalf("Unexpected error: %v", err)
			}
			defer nc.Close()

			s, err := js.Stream(ctx, stream.Name)
			if err != nil {
				t.Fatalf("Unexpected error: %v", err)
			}
			c, err := s.CreateOrUpdateConsumer(ctx, jetstream.ConsumerConfig{AckPolicy: jetstream.AckExplicitPolicy})
			if err != nil {
				t.Fatalf("Unexpected error: %v", err)
			}

			// The stream is empty when the no-wait request is issued, so
			// messages published afterwards must not be delivered on it.
			msgs, err := c.FetchNoWait(5)
			if err != nil {
				t.Fatalf("Unexpected error: %v", err)
			}
			time.Sleep(100 * time.Millisecond)
			publishTestMsgs(t, js)

			msg := <-msgs.Messages()
			if msg != nil {
				t.Fatalf("Expected no messages; got: %s", string(msg.Data()))
			}
		})
	})
}

// TestPullConsumerMessages exercises the pull consumer Messages() iterator:
// default and custom batch/byte limits, consumer deletion while iterating,
// StopAfter auto-unsubscribe (sequential and concurrent), re-creating an
// iterator after Stop, option validation, server restart, Stop/Drain
// shutdown, missing heartbeats, and the number/shape of pull requests sent.
func TestPullConsumerMessages(t *testing.T) {
	testSubject := "FOO.123"
	testMsgs := []string{"m1", "m2", "m3", "m4", "m5"}
	publishTestMsgs := func(t *testing.T, js jetstream.JetStream) {
		for _, msg := range testMsgs {
			if _, err := js.Publish(context.Background(), testSubject, []byte(msg)); err != nil {
				t.Fatalf("Unexpected error during publish: %s", err)
			}
		}
	}

	t.Run("no options", func(t *testing.T) {
		srv := RunBasicJetStreamServer()
		defer shutdownJSServerAndRemoveStorage(t, srv)
		nc, err := nats.Connect(srv.ClientURL())
		if err != nil {
			t.Fatalf("Unexpected error: %v", err)
		}

		js, err := jetstream.New(nc)
		if err != nil {
			t.Fatalf("Unexpected error: %v", err)
		}
		defer nc.Close()

		ctx, cancel := context.WithTimeout(context.Background(), 20*time.Second)
		defer cancel()
		s, err := js.CreateStream(ctx, jetstream.StreamConfig{Name: "foo", Subjects: []string{"FOO.*"}})
		if err != nil {
			t.Fatalf("Unexpected error: %v", err)
		}
		c, err := s.CreateOrUpdateConsumer(ctx, jetstream.ConsumerConfig{AckPolicy: jetstream.AckExplicitPolicy})
		if err != nil {
			t.Fatalf("Unexpected error: %v", err)
		}

		msgs := make([]jetstream.Msg, 0)
		it, err := c.Messages()
		if err != nil {
			t.Fatalf("Unexpected error: %v", err)
		}

		publishTestMsgs(t, js)
		for i := 0; i < len(testMsgs); i++ {
			msg, err := it.Next()
			if err != nil {
				t.Fatal(err)
			}
			if msg == nil {
				break
			}
			msg.Ack()
			msgs = append(msgs, msg)

		}
		it.Stop()

		// calling Stop() multiple times should have no effect
		it.Stop()
		if len(msgs) != len(testMsgs) {
			t.Fatalf("Unexpected received message count; want %d; got %d", len(testMsgs), len(msgs))
		}
		for i, msg := range msgs {
			if string(msg.Data()) != testMsgs[i] {
				t.Fatalf("Invalid msg on index %d; expected: %s; got: %s", i, testMsgs[i], string(msg.Data()))
			}
		}
		_, err = it.Next()
		if !errors.Is(err, jetstream.ErrMsgIteratorClosed) {
			t.Fatalf("Expected error: %v; got: %v", jetstream.ErrMsgIteratorClosed, err)
		}
	})

	t.Run("with custom batch size", func(t *testing.T) {
		srv := RunBasicJetStreamServer()
		defer shutdownJSServerAndRemoveStorage(t, srv)
		nc, err := nats.Connect(srv.ClientURL())
		if err != nil {
			t.Fatalf("Unexpected error: %v", err)
		}

		js, err := jetstream.New(nc)
		if err != nil {
			t.Fatalf("Unexpected error: %v", err)
		}
		defer nc.Close()

		ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
		defer cancel()
		s, err := js.CreateStream(ctx, jetstream.StreamConfig{Name: "foo", Subjects: []string{"FOO.*"}})
		if err != nil {
			t.Fatalf("Unexpected error: %v", err)
		}
		c, err := s.CreateOrUpdateConsumer(ctx, jetstream.ConsumerConfig{AckPolicy: jetstream.AckExplicitPolicy})
		if err != nil {
			t.Fatalf("Unexpected error: %v", err)
		}

		msgs := make([]jetstream.Msg, 0)
		it, err := c.Messages(jetstream.PullMaxMessages(3))
		if err != nil {
			t.Fatalf("Unexpected error: %v", err)
		}

		publishTestMsgs(t, js)
		for i := 0; i < len(testMsgs); i++ {
			msg, err := it.Next()
			if err != nil {
				t.Fatal(err)
			}
			if msg == nil {
				break
			}
			msg.Ack()
			msgs = append(msgs, msg)

		}
		it.Stop()
		time.Sleep(10 * time.Millisecond)
		if len(msgs) != len(testMsgs) {
			t.Fatalf("Unexpected received message count; want %d; got %d", len(testMsgs), len(msgs))
		}
		for i, msg := range msgs {
			if string(msg.Data()) != testMsgs[i] {
				t.Fatalf("Invalid msg on index %d; expected: %s; got: %s", i, testMsgs[i], string(msg.Data()))
			}
		}
	})

	t.Run("with max fitting 1 message", func(t *testing.T) {
		srv := RunBasicJetStreamServer()
		defer shutdownJSServerAndRemoveStorage(t, srv)
		nc, err := nats.Connect(srv.ClientURL())
		if err != nil {
			t.Fatalf("Unexpected error: %v", err)
		}

		js, err := jetstream.New(nc)
		if err != nil {
			t.Fatalf("Unexpected error: %v", err)
		}
		defer nc.Close()

		ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
		defer cancel()
		s, err := js.CreateStream(ctx, jetstream.StreamConfig{Name: "foo", Subjects: []string{"FOO.*"}})
		if err != nil {
			t.Fatalf("Unexpected error: %v", err)
		}
		c, err := s.CreateOrUpdateConsumer(ctx, jetstream.ConsumerConfig{AckPolicy: jetstream.AckExplicitPolicy})
		if err != nil {
			t.Fatalf("Unexpected error: %v", err)
		}
		// subscribe to next request subject to verify how many next requests were sent
		sub, err := nc.SubscribeSync(fmt.Sprintf("$JS.API.CONSUMER.MSG.NEXT.foo.%s", c.CachedInfo().Name))
		if err != nil {
			t.Fatalf("Error on subscribe: %v", err)
		}

		msgs := make([]jetstream.Msg, 0)
		it, err := c.Messages(jetstream.PullMaxBytes(60))
		if err != nil {
			t.Fatalf("Unexpected error: %v", err)
		}

		publishTestMsgs(t, js)
		for i := 0; i < len(testMsgs); i++ {
			msg, err := it.Next()
			if err != nil {
				t.Fatal(err)
			}
			if msg == nil {
				break
			}
			msg.Ack()
			msgs = append(msgs, msg)

		}
		it.Stop()
		time.Sleep(10 * time.Millisecond)
		requestsNum, _, err := sub.Pending()
		if err != nil {
			t.Fatalf("Unexpected error: %v", err)
		}
		// with max bytes small enough for only 1 message per pull request, and 5 messages
		// published on subject, there should be at least 5 requests sent
		if requestsNum < 5 {
			t.Fatalf("Unexpected number of requests sent; want at least 5; got %d", requestsNum)
		}

		if len(msgs) != len(testMsgs) {
			t.Fatalf("Unexpected received message count; want %d; got %d", len(testMsgs), len(msgs))
		}
		for i, msg := range msgs {
			if string(msg.Data()) != testMsgs[i] {
				t.Fatalf("Invalid msg on index %d; expected: %s; got: %s", i, testMsgs[i], string(msg.Data()))
			}
		}
	})

	t.Run("remove consumer when fetching messages", func(t *testing.T) {
		srv := RunBasicJetStreamServer()
		defer shutdownJSServerAndRemoveStorage(t, srv)
		nc, err := nats.Connect(srv.ClientURL())
		if err != nil {
			t.Fatalf("Unexpected error: %v", err)
		}

		js, err := jetstream.New(nc)
		if err != nil {
			t.Fatalf("Unexpected error: %v", err)
		}
		defer nc.Close()

		ctx, cancel := context.WithTimeout(context.Background(), 20*time.Second)
		defer cancel()
		s, err := js.CreateStream(ctx, jetstream.StreamConfig{Name: "foo", Subjects: []string{"FOO.*"}})
		if err != nil {
			t.Fatalf("Unexpected error: %v", err)
		}
		c, err := s.CreateOrUpdateConsumer(ctx, jetstream.ConsumerConfig{AckPolicy: jetstream.AckExplicitPolicy})
		if err != nil {
			t.Fatalf("Unexpected error: %v", err)
		}

		msgs := make([]jetstream.Msg, 0)
		it, err := c.Messages()
		if err != nil {
			t.Fatalf("Unexpected error: %v", err)
		}
		defer it.Stop()

		publishTestMsgs(t, js)
		for i := 0; i < len(testMsgs); i++ {
			msg, err := it.Next()
			if err != nil {
				t.Fatal(err)
			}
			if msg == nil {
				break
			}
			msg.Ack()
			msgs = append(msgs, msg)
		}

		if err := s.DeleteConsumer(ctx, c.CachedInfo().Name); err != nil {
			t.Fatalf("Error deleting consumer: %s", err)
		}
		// the first Next() after deletion reports the deletion...
		_, err = it.Next()
		if !errors.Is(err, jetstream.ErrConsumerDeleted) {
			t.Fatalf("Expected error: %v; got: %v", jetstream.ErrConsumerDeleted, err)
		}
		publishTestMsgs(t, js)
		time.Sleep(50 * time.Millisecond)
		// ...and subsequent calls see a closed iterator
		_, err = it.Next()
		if !errors.Is(err, jetstream.ErrMsgIteratorClosed) {
			t.Fatalf("Expected error: %v; got: %v", jetstream.ErrMsgIteratorClosed, err)
		}
		if len(msgs) != len(testMsgs) {
			t.Fatalf("Unexpected received message count; want %d; got %d", len(testMsgs), len(msgs))
		}
	})

	t.Run("with custom max bytes", func(t *testing.T) {
		srv := RunBasicJetStreamServer()
		defer shutdownJSServerAndRemoveStorage(t, srv)
		nc, err := nats.Connect(srv.ClientURL())
		if err != nil {
			t.Fatalf("Unexpected error: %v", err)
		}

		js, err := jetstream.New(nc)
		if err != nil {
			t.Fatalf("Unexpected error: %v", err)
		}
		defer nc.Close()

		ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
		defer cancel()
		s, err := js.CreateStream(ctx, jetstream.StreamConfig{Name: "foo", Subjects: []string{"FOO.*"}})
		if err != nil {
			t.Fatalf("Unexpected error: %v", err)
		}
		c, err := s.CreateOrUpdateConsumer(ctx, jetstream.ConsumerConfig{AckPolicy: jetstream.AckExplicitPolicy})
		if err != nil {
			t.Fatalf("Unexpected error: %v", err)
		}
		// subscribe to next request subject to verify how many next requests were sent
		sub, err := nc.SubscribeSync(fmt.Sprintf("$JS.API.CONSUMER.MSG.NEXT.foo.%s", c.CachedInfo().Name))
		if err != nil {
			t.Fatalf("Error on subscribe: %v", err)
		}

		msgs := make([]jetstream.Msg, 0)
		it, err := c.Messages(jetstream.PullMaxBytes(150))
		if err != nil {
			t.Fatalf("Unexpected error: %v", err)
		}

		publishTestMsgs(t, js)
		for i := 0; i < len(testMsgs); i++ {
			msg, err := it.Next()
			if err != nil {
				t.Fatal(err)
			}
			if msg == nil {
				break
			}
			msg.Ack()
			msgs = append(msgs, msg)

		}
		it.Stop()
		time.Sleep(10 * time.Millisecond)
		requestsNum, _, err := sub.Pending()
		if err != nil {
			t.Fatalf("Unexpected error: %v", err)
		}
		if requestsNum < 3 {
			t.Fatalf("Unexpected number of requests sent; want at least 3; got %d", requestsNum)
		}

		if len(msgs) != len(testMsgs) {
			t.Fatalf("Unexpected received message count; want %d; got %d", len(testMsgs), len(msgs))
		}
		for i, msg := range msgs {
			if string(msg.Data()) != testMsgs[i] {
				t.Fatalf("Invalid msg on index %d; expected: %s; got: %s", i, testMsgs[i], string(msg.Data()))
			}
		}
	})

	t.Run("with batch size set to 1", func(t *testing.T) {
		srv := RunBasicJetStreamServer()
		defer shutdownJSServerAndRemoveStorage(t, srv)
		nc, err := nats.Connect(srv.ClientURL())
		if err != nil {
			t.Fatalf("Unexpected error: %v", err)
		}

		js, err := jetstream.New(nc)
		if err != nil {
			t.Fatalf("Unexpected error: %v", err)
		}
		defer nc.Close()

		ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
		defer cancel()
		s, err := js.CreateStream(ctx, jetstream.StreamConfig{Name: "foo", Subjects: []string{"FOO.*"}})
		if err != nil {
			t.Fatalf("Unexpected error: %v", err)
		}
		c, err := s.CreateOrUpdateConsumer(ctx, jetstream.ConsumerConfig{AckPolicy: jetstream.AckExplicitPolicy})
		if err != nil {
			t.Fatalf("Unexpected error: %v", err)
		}
		// subscribe to next request subject to verify how many next requests were sent
		sub, err := nc.SubscribeSync(fmt.Sprintf("$JS.API.CONSUMER.MSG.NEXT.foo.%s", c.CachedInfo().Name))
		if err != nil {
			t.Fatalf("Error on subscribe: %v", err)
		}

		msgs := make([]jetstream.Msg, 0)
		it, err := c.Messages(jetstream.PullMaxMessages(1))
		if err != nil {
			t.Fatalf("Unexpected error: %v", err)
		}

		publishTestMsgs(t, js)
		for i := 0; i < len(testMsgs); i++ {
			msg, err := it.Next()
			if err != nil {
				t.Fatal(err)
			}
			if msg == nil {
				break
			}
			msg.Ack()
			msgs = append(msgs, msg)

		}
		it.Stop()
		time.Sleep(10 * time.Millisecond)
		requestsNum, _, err := sub.Pending()
		if err != nil {
			t.Fatalf("Unexpected error: %v", err)
		}
		// with batch size set to 1, and 5 messages published on subject, there should be a total of 5 requests sent
		if requestsNum != 5 {
			t.Fatalf("Unexpected number of requests sent; want 5; got %d", requestsNum)
		}

		if len(msgs) != len(testMsgs) {
			t.Fatalf("Unexpected received message count; want %d; got %d", len(testMsgs), len(msgs))
		}
		for i, msg := range msgs {
			if string(msg.Data()) != testMsgs[i] {
				t.Fatalf("Invalid msg on index %d; expected: %s; got: %s", i, testMsgs[i], string(msg.Data()))
			}
		}
	})

	t.Run("with auto unsubscribe", func(t *testing.T) {
		srv := RunBasicJetStreamServer()
		defer shutdownJSServerAndRemoveStorage(t, srv)
		nc, err := nats.Connect(srv.ClientURL())
		if err != nil {
			t.Fatalf("Unexpected error: %v", err)
		}

		js, err := jetstream.New(nc)
		if err != nil {
			t.Fatalf("Unexpected error: %v", err)
		}
		defer nc.Close()

		ctx, cancel := context.WithTimeout(context.Background(), 5*time.Minute)
		defer cancel()
		s, err := js.CreateStream(ctx, jetstream.StreamConfig{Name: "test", Subjects: []string{"FOO.*"}})
		if err != nil {
			t.Fatalf("Unexpected error: %v", err)
		}
		for i := 0; i < 100; i++ {
			if _, err := js.Publish(ctx, "FOO.A", []byte("msg")); err != nil {
				t.Fatalf("Unexpected error during publish: %s", err)
			}
		}
		c, err := s.CreateOrUpdateConsumer(ctx, jetstream.ConsumerConfig{AckPolicy: jetstream.AckExplicitPolicy})
		if err != nil {
			t.Fatalf("Unexpected error: %v", err)
		}

		// StopAfter(50) should close the iterator after 50 of the 100 messages,
		// leaving the remaining 50 pending on the consumer.
		msgs := make([]jetstream.Msg, 0)
		it, err := c.Messages(jetstream.StopAfter(50), jetstream.PullMaxMessages(40))
		if err != nil {
			t.Fatalf("Unexpected error: %v", err)
		}

		for i := 0; i < 50; i++ {
			msg, err := it.Next()
			if err != nil {
				t.Fatal(err)
			}
			if msg == nil {
				break
			}
			if err := msg.DoubleAck(ctx); err != nil {
				t.Fatalf("Unexpected error: %v", err)
			}
			msgs = append(msgs, msg)

		}
		if _, err := it.Next(); !errors.Is(err, jetstream.ErrMsgIteratorClosed) {
			t.Fatalf("Expected error: %v; got: %v", jetstream.ErrMsgIteratorClosed, err)
		}
		if len(msgs) != 50 {
			t.Fatalf("Unexpected received message count; want %d; got %d", 50, len(msgs))
		}
		ci, err := c.Info(ctx)
		if err != nil {
			t.Fatalf("Unexpected error: %v", err)
		}
		if ci.NumPending != 50 {
			t.Fatalf("Unexpected number of pending messages; want 50; got %d", ci.NumPending)
		}
		if ci.NumAckPending != 0 {
			t.Fatalf("Unexpected number of ack pending messages; want 0; got %d", ci.NumAckPending)
		}
		if ci.NumWaiting != 0 {
			t.Fatalf("Unexpected number of waiting pull requests; want 0; got %d", ci.NumWaiting)
		}
	})

	t.Run("with auto unsubscribe concurrent", func(t *testing.T) {
		srv := RunBasicJetStreamServer()
		defer shutdownJSServerAndRemoveStorage(t, srv)
		nc, err := nats.Connect(srv.ClientURL())
		if err != nil {
			t.Fatalf("Unexpected error: %v", err)
		}

		js, err := jetstream.New(nc)
		if err != nil {
			t.Fatalf("Unexpected error: %v", err)
		}
		defer nc.Close()

		ctx, cancel := context.WithTimeout(context.Background(), 5*time.Minute)
		defer cancel()
		s, err := js.CreateStream(ctx, jetstream.StreamConfig{Name: "test", Subjects: []string{"FOO.*"}})
		if err != nil {
			t.Fatalf("Unexpected error: %v", err)
		}
		c, err := s.CreateOrUpdateConsumer(ctx, jetstream.ConsumerConfig{AckPolicy: jetstream.AckExplicitPolicy})
		if err != nil {
			t.Fatalf("Unexpected error: %v", err)
		}

		it, err := c.Messages(jetstream.StopAfter(50), jetstream.PullMaxMessages(40))
		if err != nil {
			t.Fatalf("Unexpected error: %v", err)
		}

		for i := 0; i < 100; i++ {
			if _, err := js.Publish(ctx, "FOO.A", []byte("msg")); err != nil {
				t.Fatalf("Unexpected error during publish: %s", err)
			}
		}

		var mu sync.Mutex // Mutex to guard the msgs slice.
		msgs := make([]jetstream.Msg, 0)
		var wg sync.WaitGroup

		// 50 goroutines each call Next() exactly once, concurrently.
		wg.Add(50)
		for i := 0; i < 50; i++ {
			go func() {
				defer wg.Done()

				msg, err := it.Next()
				if err != nil {
					return
				}

				ctx, cancel := context.WithTimeout(ctx, 5*time.Second)
				defer cancel()
				if err := msg.DoubleAck(ctx); err == nil {
					// Only append the msg if ack is successful.
					mu.Lock()
					msgs = append(msgs, msg)
					mu.Unlock()
				}
			}()
		}

		wg.Wait()

		// Call Next in a goroutine so we can timeout if it doesn't return.
		errs := make(chan error)
		go func() {
			// This call should return the error ErrMsgIteratorClosed.
			_, err := it.Next()
			errs <- err
		}()

		timer := time.NewTimer(5 * time.Second)
		defer timer.Stop()

		select {
		case <-timer.C:
			t.Fatal("Timed out waiting for Next() to return")
		case err := <-errs:
			if !errors.Is(err, jetstream.ErrMsgIteratorClosed) {
				t.Fatalf("Unexpected error: %v", err)
			}
		}

		mu.Lock()
		wantLen, gotLen := 50, len(msgs)
		mu.Unlock()
		if wantLen != gotLen {
			t.Fatalf("Unexpected received message count; want %d; got %d", wantLen, gotLen)
		}

		ci, err := c.Info(ctx)
		if err != nil {
			t.Fatalf("Unexpected error: %v", err)
		}
		if ci.NumPending != 50 {
			t.Fatalf("Unexpected number of pending messages; want 50; got %d", ci.NumPending)
		}
		if ci.NumAckPending != 0 {
			t.Fatalf("Unexpected number of ack pending messages; want 0; got %d", ci.NumAckPending)
		}
		if ci.NumWaiting != 0 {
			t.Fatalf("Unexpected number of waiting pull requests; want 0; got %d", ci.NumWaiting)
		}
	})

	t.Run("create iterator, stop, then create again", func(t *testing.T) {
		srv := RunBasicJetStreamServer()
		defer shutdownJSServerAndRemoveStorage(t, srv)
		nc, err := nats.Connect(srv.ClientURL())
		if err != nil {
			t.Fatalf("Unexpected error: %v", err)
		}

		js, err := jetstream.New(nc)
		if err != nil {
			t.Fatalf("Unexpected error: %v", err)
		}
		defer nc.Close()

		ctx, cancel := context.WithTimeout(context.Background(), 20*time.Second)
		defer cancel()
		s, err := js.CreateStream(ctx, jetstream.StreamConfig{Name: "foo", Subjects: []string{"FOO.*"}})
		if err != nil {
			t.Fatalf("Unexpected error: %v", err)
		}
		c, err := s.CreateOrUpdateConsumer(ctx, jetstream.ConsumerConfig{AckPolicy: jetstream.AckExplicitPolicy})
		if err != nil {
			t.Fatalf("Unexpected error: %v", err)
		}

		msgs := make([]jetstream.Msg, 0)
		it, err := c.Messages()
		if err != nil {
			t.Fatalf("Unexpected error: %v", err)
		}

		publishTestMsgs(t, js)
		for i := 0; i < len(testMsgs); i++ {
			msg, err := it.Next()
			if err != nil {
				t.Fatal(err)
			}
			if msg == nil {
				break
			}
			msg.Ack()
			msgs = append(msgs, msg)

		}
		it.Stop()
		time.Sleep(10 * time.Millisecond)

		publishTestMsgs(t, js)
		it, err = c.Messages()
		if err != nil {
			t.Fatalf("Unexpected error: %v", err)
		}

		for i := 0; i < len(testMsgs); i++ {
			msg, err := it.Next()
			if err != nil {
				t.Fatal(err)
			}
			if msg == nil {
				break
			}
			msg.Ack()
			msgs = append(msgs, msg)

		}
		it.Stop()
		if len(msgs) != 2*len(testMsgs) {
			t.Fatalf("Unexpected received message count; want %d; got %d", 2*len(testMsgs), len(msgs))
		}
		// Build the expected sequence in a fresh slice so it never aliases testMsgs.
		expectedMsgs := make([]string, 0, 2*len(testMsgs))
		expectedMsgs = append(expectedMsgs, testMsgs...)
		expectedMsgs = append(expectedMsgs, testMsgs...)
		for i, msg := range msgs {
			if string(msg.Data()) != expectedMsgs[i] {
				// index expectedMsgs (length 10), not testMsgs (length 5), to avoid a panic here
				t.Fatalf("Invalid msg on index %d; expected: %s; got: %s", i, expectedMsgs[i], string(msg.Data()))
			}
		}
	})

	t.Run("with invalid batch size", func(t *testing.T) {
		srv := RunBasicJetStreamServer()
		defer shutdownJSServerAndRemoveStorage(t, srv)
		nc, err := nats.Connect(srv.ClientURL())
		if err != nil {
			t.Fatalf("Unexpected error: %v", err)
		}

		js, err := jetstream.New(nc)
		if err != nil {
			t.Fatalf("Unexpected error: %v", err)
		}
		defer nc.Close()

		ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
		defer cancel()
		s, err := js.CreateStream(ctx, jetstream.StreamConfig{Name: "foo", Subjects: []string{"FOO.*"}})
		if err != nil {
			t.Fatalf("Unexpected error: %v", err)
		}
		c, err := s.CreateOrUpdateConsumer(ctx, jetstream.ConsumerConfig{AckPolicy: jetstream.AckExplicitPolicy})
		if err != nil {
			t.Fatalf("Unexpected error: %v", err)
		}

		_, err = c.Messages(jetstream.PullMaxMessages(-1))
		if !errors.Is(err, jetstream.ErrInvalidOption) {
			t.Fatalf("Expected error: %v; got: %v", jetstream.ErrInvalidOption, err)
		}
	})

	t.Run("with server restart", func(t *testing.T) {
		srv := RunBasicJetStreamServer()
		nc, err := nats.Connect(srv.ClientURL())
		if err != nil {
			t.Fatalf("Unexpected error: %v", err)
		}

		js, err := jetstream.New(nc)
		if err != nil {
			t.Fatalf("Unexpected error: %v", err)
		}
		defer nc.Close()

		ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
		defer cancel()
		s, err := js.CreateStream(ctx, jetstream.StreamConfig{Name: "foo", Subjects: []string{"FOO.*"}})
		if err != nil {
			t.Fatalf("Unexpected error: %v", err)
		}
		c, err := s.CreateOrUpdateConsumer(ctx, jetstream.ConsumerConfig{AckPolicy: jetstream.AckExplicitPolicy})
		if err != nil {
			t.Fatalf("Unexpected error: %v", err)
		}
		msgs := make([]jetstream.Msg, 0)
		it, err := c.Messages()
		if err != nil {
			t.Fatalf("Unexpected error: %v", err)
		}
		defer it.Stop()

		done := make(chan struct{})
		errs := make(chan error)
		publishTestMsgs(t, js)
		// msgs is only read by the test goroutine after receiving from done,
		// which orders it after every append below.
		go func() {
			for i := 0; i < 2*len(testMsgs); i++ {
				msg, err := it.Next()
				if err != nil {
					errs <- err
					return
				}
				msg.Ack()
				msgs = append(msgs, msg)
			}
			done <- struct{}{}
		}()
		time.Sleep(10 * time.Millisecond)
		// restart the server
		srv = restartBasicJSServer(t, srv)
		defer shutdownJSServerAndRemoveStorage(t, srv)
		time.Sleep(10 * time.Millisecond)
		publishTestMsgs(t, js)

		select {
		case <-done:
			if len(msgs) != 2*len(testMsgs) {
				t.Fatalf("Unexpected received message count; want %d; got %d", 2*len(testMsgs), len(msgs))
			}
		case err := <-errs:
			t.Fatalf("Unexpected error: %s", err)
		}
	})

	t.Run("with graceful shutdown", func(t *testing.T) {
		cases := map[string]func(jetstream.MessagesContext){
			"stop":  func(mc jetstream.MessagesContext) { mc.Stop() },
			"drain": func(mc jetstream.MessagesContext) { mc.Drain() },
		}

		for name, unsubscribe := range cases {
			t.Run(name, func(t *testing.T) {
				srv := RunBasicJetStreamServer()
				defer shutdownJSServerAndRemoveStorage(t, srv)

				nc, err := nats.Connect(srv.ClientURL())
				if err != nil {
					t.Fatalf("Unexpected error: %v", err)
				}

				js, err := jetstream.New(nc)
				if err != nil {
					t.Fatalf("Unexpected error: %v", err)
				}
				defer nc.Close()

				ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
				defer cancel()
				s, err := js.CreateStream(ctx, jetstream.StreamConfig{Name: "foo", Subjects: []string{"FOO.*"}})
				if err != nil {
					t.Fatalf("Unexpected error: %v", err)
				}
				c, err := s.CreateOrUpdateConsumer(ctx, jetstream.ConsumerConfig{AckPolicy: jetstream.AckExplicitPolicy})
				if err != nil {
					t.Fatalf("Unexpected error: %v", err)
				}

				it, err := c.Messages()
				if err != nil {
					t.Fatalf("Unexpected error: %v", err)
				}

				publishTestMsgs(t, js)

				errs := make(chan error)
				msgs := make([]jetstream.Msg, 0)

				go func() {
					for {
						msg, err := it.Next()
						if err != nil {
							errs <- err
							return
						}
						msg.Ack()
						msgs = append(msgs, msg)
					}
				}()

				time.Sleep(10 * time.Millisecond)
				unsubscribe(it) // Next() should return ErrMsgIteratorClosed

				timer := time.NewTimer(5 * time.Second)
				defer timer.Stop()

				select {
				case <-timer.C:
					t.Fatal("Timed out waiting for Next() to return")
				case err := <-errs:
					if !errors.Is(err, jetstream.ErrMsgIteratorClosed) {
						t.Fatalf("Unexpected error: %v", err)
					}

					if len(msgs) != len(testMsgs) {
						t.Fatalf("Unexpected received message count; want %d; got %d", len(testMsgs), len(msgs))
					}
				}
			})
		}
	})

	t.Run("with idle heartbeat", func(t *testing.T) {
		srv := RunBasicJetStreamServer()
		defer shutdownJSServerAndRemoveStorage(t, srv)
		nc, err := nats.Connect(srv.ClientURL())
		if err != nil {
			t.Fatalf("Unexpected error: %v", err)
		}

		js, err := jetstream.New(nc)
		if err != nil {
			t.Fatalf("Unexpected error: %v", err)
		}
		defer nc.Close()

		ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
		defer cancel()
		s, err := js.CreateStream(ctx, jetstream.StreamConfig{Name: "foo", Subjects: []string{"FOO.*"}})
		if err != nil {
			t.Fatalf("Unexpected error: %v", err)
		}
		c, err := s.CreateOrUpdateConsumer(ctx, jetstream.ConsumerConfig{AckPolicy: jetstream.AckExplicitPolicy})
		if err != nil {
			t.Fatalf("Unexpected error: %v", err)
		}
		// remove consumer to force missing heartbeats
		if err := s.DeleteConsumer(ctx, c.CachedInfo().Name); err != nil {
			t.Fatalf("Error deleting consumer: %s", err)
		}

		it, err := c.Messages(jetstream.PullHeartbeat(500 * time.Millisecond))
		if err != nil {
			t.Fatalf("Unexpected error: %v", err)
		}
		defer it.Stop()
		now := time.Now()
		_, err = it.Next()
		elapsed := time.Since(now)
		if !errors.Is(err, jetstream.ErrNoHeartbeat) {
			t.Fatalf("Expected error: %v; got: %v", jetstream.ErrNoHeartbeat, err)
		}
		// we should get missing heartbeat error after approximately 2*heartbeat interval
		if elapsed < time.Second || elapsed > 1500*time.Millisecond {
			t.Fatalf("Unexpected elapsed time; want 1-1.5s; got %v", elapsed)
		}
	})

	t.Run("no messages received after stop", func(t *testing.T) {
		srv := RunBasicJetStreamServer()
		defer shutdownJSServerAndRemoveStorage(t, srv)
		nc, err := nats.Connect(srv.ClientURL())
		if err != nil {
			t.Fatalf("Unexpected error: %v", err)
		}

		js, err := jetstream.New(nc)
		if err != nil {
			t.Fatalf("Unexpected error: %v", err)
		}
		defer nc.Close()

		ctx, cancel := context.WithTimeout(context.Background(), 20*time.Second)
		defer cancel()
		s, err := js.CreateStream(ctx, jetstream.StreamConfig{Name: "foo", Subjects: []string{"FOO.*"}})
		if err != nil {
			t.Fatalf("Unexpected error: %v", err)
		}
		c, err := s.CreateOrUpdateConsumer(ctx, jetstream.ConsumerConfig{AckPolicy: jetstream.AckExplicitPolicy})
		if err != nil {
			t.Fatalf("Unexpected error: %v", err)
		}

		msgs := make([]jetstream.Msg, 0)
		it, err := c.Messages()
		if err != nil {
			t.Fatalf("Unexpected error: %v", err)
		}

		publishTestMsgs(t, js)
		// Stop fires while the second message is being processed (2 x 80ms sleeps),
		// so only 2 messages should be consumed before the iterator closes.
		go func() {
			time.Sleep(100 * time.Millisecond)
			it.Stop()
		}()
		for i := 0; i < 2; i++ {
			msg, err := it.Next()
			if err != nil {
				t.Fatal(err)
			}
			time.Sleep(80 * time.Millisecond)
			msg.Ack()
			msgs = append(msgs, msg)
		}
		_, err = it.Next()
		if !errors.Is(err, jetstream.ErrMsgIteratorClosed) {
			t.Fatalf("Expected error: %v; got: %v", jetstream.ErrMsgIteratorClosed, err)
		}

		if len(msgs) != 2 {
			t.Fatalf("Unexpected received message count after stop; want %d; got %d", 2, len(msgs))
		}
	})

	t.Run("drain mode", func(t *testing.T) {
		srv := RunBasicJetStreamServer()
		defer shutdownJSServerAndRemoveStorage(t, srv)
		nc, err := nats.Connect(srv.ClientURL())
		if err != nil {
			t.Fatalf("Unexpected error: %v", err)
		}

		js, err := jetstream.New(nc)
		if err != nil {
			t.Fatalf("Unexpected error: %v", err)
		}
		defer nc.Close()

		ctx, cancel := context.WithTimeout(context.Background(), 20*time.Second)
		defer cancel()
		s, err := js.CreateStream(ctx, jetstream.StreamConfig{Name: "foo", Subjects: []string{"FOO.*"}})
		if err != nil {
			t.Fatalf("Unexpected error: %v", err)
		}
		c, err := s.CreateOrUpdateConsumer(ctx, jetstream.ConsumerConfig{AckPolicy: jetstream.AckExplicitPolicy})
		if err != nil {
			t.Fatalf("Unexpected error: %v", err)
		}

		msgs := make([]jetstream.Msg, 0)
		it, err := c.Messages()
		if err != nil {
			t.Fatalf("Unexpected error: %v", err)
		}

		publishTestMsgs(t, js)
		// Drain is called mid-iteration; unlike Stop, all already-buffered
		// messages are still expected to be delivered.
		go func() {
			time.Sleep(100 * time.Millisecond)
			it.Drain()
		}()
		for i := 0; i < len(testMsgs); i++ {
			msg, err := it.Next()
			if err != nil {
				t.Fatal(err)
			}
			time.Sleep(50 * time.Millisecond)
			msg.Ack()
			msgs = append(msgs, msg)
		}
		_, err = it.Next()
		if !errors.Is(err, jetstream.ErrMsgIteratorClosed) {
			t.Fatalf("Expected error: %v; got: %v", jetstream.ErrMsgIteratorClosed, err)
		}

		if len(msgs) != len(testMsgs) {
			t.Fatalf("Unexpected received message count after drain; want %d; got %d", len(testMsgs), len(msgs))
		}
	})

	t.Run("with max messages and per fetch size limit", func(t *testing.T) {
		srv := RunBasicJetStreamServer()
		defer shutdownJSServerAndRemoveStorage(t, srv)
		nc, err := nats.Connect(srv.ClientURL())
		if err != nil {
			t.Fatalf("Unexpected error: %v", err)
		}

		js, err := jetstream.New(nc)
		if err != nil {
			t.Fatalf("Unexpected error: %v", err)
		}
		defer nc.Close()
		ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
		defer cancel()
		s, err := js.CreateStream(ctx, jetstream.StreamConfig{Name: "foo", Subjects: []string{"FOO.*"}})
		if err != nil {
			t.Fatalf("Unexpected error: %v", err)
		}
		c, err := s.CreateOrUpdateConsumer(ctx, jetstream.ConsumerConfig{AckPolicy: jetstream.AckExplicitPolicy})
		if err != nil {
			t.Fatalf("Unexpected error: %v", err)
		}

		// subscribe to next request subject to verify how many next requests were sent
		// and whether both thresholds work as expected
		sub, err := nc.SubscribeSync(fmt.Sprintf("$JS.API.CONSUMER.MSG.NEXT.foo.%s", c.CachedInfo().Name))
		if err != nil {
			t.Fatalf("Error on subscribe: %v", err)
		}
		defer sub.Unsubscribe()

		it, err := c.Messages(jetstream.PullMaxMessagesWithBytesLimit(10, 1024))
		if err != nil {
			t.Fatalf("Unexpected error: %v", err)
		}

		smallMsg := nats.Msg{
			Subject: "FOO.A",
			Data:    []byte("msg"),
		}
		// publish 10 small messages
		for i := 0; i < 10; i++ {
			if _, err := js.PublishMsg(ctx, &smallMsg); err != nil {
				t.Fatalf("Unexpected error during publish: %s", err)
			}
		}

		for i := 0; i < 10; i++ {
			msg, err := it.Next()
			if err != nil {
				t.Fatalf("Unexpected error: %v", err)
			}
			msg.Ack()
		}

		// we should get 2 pull requests
		for range 2 {
			fetchReq, err := sub.NextMsg(100 * time.Millisecond)
			if err != nil {
				t.Fatalf("Error on next msg: %v", err)
			}
			if !bytes.Contains(fetchReq.Data, []byte(`"max_bytes":1024`)) {
				t.Fatalf("Unexpected fetch request: %s", fetchReq.Data)
			}
		}
		// make sure no more requests were sent
		_, err = sub.NextMsg(100 * time.Millisecond)
		if !errors.Is(err, nats.ErrTimeout) {
			t.Fatalf("Expected timeout error; got: %v", err)
		}

		// now publish 10 large messages, almost hitting the limit
		// we need to account for the total message size (which includes js ack reply subject)
		largeMsg := nats.Msg{
			Subject: "FOO.B",
			Data:    make([]byte, 950),
		}
		for range 10 {
			if _, err := js.PublishMsg(ctx, &largeMsg); err != nil {
				t.Fatalf("Unexpected error during publish: %s", err)
			}
		}

		for i := 0; i < 10; i++ {
			msg, err := it.Next()
			if err != nil {
				t.Fatalf("Unexpected error: %v", err)
			}
			msg.Ack()
		}
		// we expect 9 additional pull requests for the 10 large messages
		// (NOTE(review): presumably the first one is served by a pull request
		// still outstanding from the previous phase — confirm)
		for range 9 {
			fetchReq, err := sub.NextMsg(100 * time.Millisecond)
			if err != nil {
				t.Fatalf("Error on next msg: %v", err)
			}
			if !bytes.Contains(fetchReq.Data, []byte(`"max_bytes":1024`)) {
				t.Fatalf("Unexpected fetch request: %s", fetchReq.Data)
			}
		}
		// make sure no more requests were sent
		_, err = sub.NextMsg(100 * time.Millisecond)
		if !errors.Is(err, nats.ErrTimeout) {
			t.Fatalf("Expected timeout error; got: %v", err)
		}

		it.Stop()
	})
}

// TestPullConsumerConsume exercises the callback-based Consume API of a pull
// consumer: delivery order, multiple Consume calls on one consumer, pull
// option handling (batch size, max bytes, expiry, heartbeat), error reporting
// via ConsumeErrHandler, and the Stop/Drain/Closed lifecycle.
//
// Each subtest starts its own JetStream server with a stream "foo" bound to
// "FOO.*" and an explicit-ack consumer.
func TestPullConsumerConsume(t *testing.T) {
	testSubject := "FOO.123"
	testMsgs := []string{"m1", "m2", "m3", "m4", "m5"}
	// publishTestMsgs publishes every entry of testMsgs, in order, to testSubject.
	publishTestMsgs := func(t *testing.T, js jetstream.JetStream) {
		for _, msg := range testMsgs {
			if _, err := js.Publish(context.Background(), testSubject, []byte(msg)); err != nil {
				t.Fatalf("Unexpected error during publish: %s", err)
			}
		}
	}

	t.Run("no options", func(t *testing.T) {
		srv := RunBasicJetStreamServer()
		defer shutdownJSServerAndRemoveStorage(t, srv)
		nc, err := nats.Connect(srv.ClientURL())
		if err != nil {
			t.Fatalf("Unexpected error: %v", err)
		}

		js, err := jetstream.New(nc)
		if err != nil {
			t.Fatalf("Unexpected error: %v", err)
		}
		defer nc.Close()

		ctx, cancel := context.WithTimeout(context.Background(), 20*time.Second)
		defer cancel()
		s, err := js.CreateStream(ctx, jetstream.StreamConfig{Name: "foo", Subjects: []string{"FOO.*"}})
		if err != nil {
			t.Fatalf("Unexpected error: %v", err)
		}
		c, err := s.CreateOrUpdateConsumer(ctx, jetstream.ConsumerConfig{AckPolicy: jetstream.AckExplicitPolicy})
		if err != nil {
			t.Fatalf("Unexpected error: %v", err)
		}

		msgs := make([]jetstream.Msg, 0)
		wg := &sync.WaitGroup{}
		wg.Add(len(testMsgs))
		l, err := c.Consume(func(msg jetstream.Msg) {
			msgs = append(msgs, msg)
			wg.Done()
		})
		if err != nil {
			t.Fatalf("Unexpected error: %v", err)
		}
		defer l.Stop()

		publishTestMsgs(t, js)
		wg.Wait()
		if len(msgs) != len(testMsgs) {
			t.Fatalf("Unexpected received message count; want %d; got %d", len(testMsgs), len(msgs))
		}
		for i, msg := range msgs {
			if string(msg.Data()) != testMsgs[i] {
				t.Fatalf("Invalid msg on index %d; expected: %s; got: %s", i, testMsgs[i], string(msg.Data()))
			}
		}
	})

	t.Run("subscribe twice on the same consumer", func(t *testing.T) {
		srv := RunBasicJetStreamServer()
		defer shutdownJSServerAndRemoveStorage(t, srv)
		nc, err := nats.Connect(srv.ClientURL())
		if err != nil {
			t.Fatalf("Unexpected error: %v", err)
		}

		js, err := jetstream.New(nc)
		if err != nil {
			t.Fatalf("Unexpected error: %v", err)
		}
		defer nc.Close()

		ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
		defer cancel()
		s, err := js.CreateStream(ctx, jetstream.StreamConfig{Name: "foo", Subjects: []string{"FOO.*"}})
		if err != nil {
			t.Fatalf("Unexpected error: %v", err)
		}
		c, err := s.CreateOrUpdateConsumer(ctx, jetstream.ConsumerConfig{AckPolicy: jetstream.AckExplicitPolicy})
		if err != nil {
			t.Fatalf("Unexpected error: %v", err)
		}

		wg := sync.WaitGroup{}
		// Each slice is only appended to by its own Consume callback.
		msgs1, msgs2 := make([]jetstream.Msg, 0), make([]jetstream.Msg, 0)
		l1, err := c.Consume(func(msg jetstream.Msg) {
			msgs1 = append(msgs1, msg)
			wg.Done()
			msg.Ack()
		})
		if err != nil {
			t.Fatalf("Unexpected error: %v", err)
		}
		defer l1.Stop()
		l2, err := c.Consume(func(msg jetstream.Msg) {
			msgs2 = append(msgs2, msg)
			wg.Done()
			msg.Ack()
		})
		if err != nil {
			t.Fatalf("Unexpected error: %v", err)
		}
		defer l2.Stop()

		wg.Add(len(testMsgs))
		publishTestMsgs(t, js)
		wg.Wait()

		if len(msgs1)+len(msgs2) != len(testMsgs) {
			t.Fatalf("Unexpected received message count; want %d; got %d", len(testMsgs), len(msgs1)+len(msgs2))
		}
		if len(msgs1) == 0 || len(msgs2) == 0 {
			t.Fatalf("Received no messages on one of the subscriptions")
		}
	})

	t.Run("subscribe, cancel subscription, then subscribe again", func(t *testing.T) {
		srv := RunBasicJetStreamServer()
		defer shutdownJSServerAndRemoveStorage(t, srv)
		nc, err := nats.Connect(srv.ClientURL())
		if err != nil {
			t.Fatalf("Unexpected error: %v", err)
		}

		js, err := jetstream.New(nc)
		if err != nil {
			t.Fatalf("Unexpected error: %v", err)
		}
		defer nc.Close()

		ctx, cancel := context.WithTimeout(context.Background(), 20*time.Second)
		defer cancel()
		s, err := js.CreateStream(ctx, jetstream.StreamConfig{Name: "foo", Subjects: []string{"FOO.*"}})
		if err != nil {
			t.Fatalf("Unexpected error: %v", err)
		}
		c, err := s.CreateOrUpdateConsumer(ctx, jetstream.ConsumerConfig{AckPolicy: jetstream.AckExplicitPolicy})
		if err != nil {
			t.Fatalf("Unexpected error: %v", err)
		}

		wg := sync.WaitGroup{}
		wg.Add(len(testMsgs))
		msgs := make([]jetstream.Msg, 0)
		// The callback does not run on the test goroutine, so it must not call
		// t.Fatalf (FailNow); it reports with t.Errorf and keeps going so that
		// wg.Done is still reached and the test cannot hang.
		l, err := c.Consume(func(msg jetstream.Msg) {
			if err := msg.Ack(); err != nil {
				t.Errorf("Unexpected error: %v", err)
			}
			msgs = append(msgs, msg)
			wg.Done()
		})
		if err != nil {
			t.Fatalf("Unexpected error: %v", err)
		}

		publishTestMsgs(t, js)
		wg.Wait()
		l.Stop()

		time.Sleep(10 * time.Millisecond)
		wg.Add(len(testMsgs))
		l, err = c.Consume(func(msg jetstream.Msg) {
			if err := msg.Ack(); err != nil {
				t.Errorf("Unexpected error: %v", err)
			}
			msgs = append(msgs, msg)
			wg.Done()
		})
		if err != nil {
			t.Fatalf("Unexpected error: %v", err)
		}
		defer l.Stop()
		publishTestMsgs(t, js)
		wg.Wait()
		if len(msgs) != 2*len(testMsgs) {
			t.Fatalf("Unexpected received message count; want %d; got %d", 2*len(testMsgs), len(msgs))
		}
		// Both rounds published the same sequence, so testMsgs is expected
		// twice in a row; index modulo its length to avoid going out of range.
		for i, msg := range msgs {
			want := testMsgs[i%len(testMsgs)]
			if string(msg.Data()) != want {
				t.Fatalf("Invalid msg on index %d; expected: %s; got: %s", i, want, string(msg.Data()))
			}
		}
	})

	t.Run("with custom batch size", func(t *testing.T) {
		srv := RunBasicJetStreamServer()
		defer shutdownJSServerAndRemoveStorage(t, srv)
		nc, err := nats.Connect(srv.ClientURL())
		if err != nil {
			t.Fatalf("Unexpected error: %v", err)
		}

		js, err := jetstream.New(nc)
		if err != nil {
			t.Fatalf("Unexpected error: %v", err)
		}
		defer nc.Close()

		ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
		defer cancel()
		s, err := js.CreateStream(ctx, jetstream.StreamConfig{Name: "foo", Subjects: []string{"FOO.*"}})
		if err != nil {
			t.Fatalf("Unexpected error: %v", err)
		}
		c, err := s.CreateOrUpdateConsumer(ctx, jetstream.ConsumerConfig{AckPolicy: jetstream.AckExplicitPolicy})
		if err != nil {
			t.Fatalf("Unexpected error: %v", err)
		}

		msgs := make([]jetstream.Msg, 0)
		wg := &sync.WaitGroup{}
		wg.Add(len(testMsgs))
		l, err := c.Consume(func(msg jetstream.Msg) {
			msgs = append(msgs, msg)
			wg.Done()
		}, jetstream.PullMaxMessages(4))
		if err != nil {
			t.Fatalf("Unexpected error: %v", err)
		}
		defer l.Stop()

		publishTestMsgs(t, js)
		wg.Wait()

		if len(msgs) != len(testMsgs) {
			t.Fatalf("Unexpected received message count; want %d; got %d", len(testMsgs), len(msgs))
		}
		for i, msg := range msgs {
			if string(msg.Data()) != testMsgs[i] {
				t.Fatalf("Invalid msg on index %d; expected: %s; got: %s", i, testMsgs[i], string(msg.Data()))
			}
		}
	})

	t.Run("fetch messages one by one", func(t *testing.T) {
		srv := RunBasicJetStreamServer()
		defer shutdownJSServerAndRemoveStorage(t, srv)
		nc, err := nats.Connect(srv.ClientURL())
		if err != nil {
			t.Fatalf("Unexpected error: %v", err)
		}

		js, err := jetstream.New(nc)
		if err != nil {
			t.Fatalf("Unexpected error: %v", err)
		}
		defer nc.Close()

		ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
		defer cancel()
		s, err := js.CreateStream(ctx, jetstream.StreamConfig{Name: "foo", Subjects: []string{"FOO.*"}})
		if err != nil {
			t.Fatalf("Unexpected error: %v", err)
		}
		c, err := s.CreateOrUpdateConsumer(ctx, jetstream.ConsumerConfig{AckPolicy: jetstream.AckExplicitPolicy})
		if err != nil {
			t.Fatalf("Unexpected error: %v", err)
		}

		msgs := make([]jetstream.Msg, 0)
		wg := &sync.WaitGroup{}
		wg.Add(len(testMsgs))
		l, err := c.Consume(func(msg jetstream.Msg) {
			msgs = append(msgs, msg)
			wg.Done()
		}, jetstream.PullMaxMessages(1))
		if err != nil {
			t.Fatalf("Unexpected error: %v", err)
		}
		defer l.Stop()

		publishTestMsgs(t, js)
		wg.Wait()

		if len(msgs) != len(testMsgs) {
			t.Fatalf("Unexpected received message count; want %d; got %d", len(testMsgs), len(msgs))
		}
		for i, msg := range msgs {
			if string(msg.Data()) != testMsgs[i] {
				t.Fatalf("Invalid msg on index %d; expected: %s; got: %s", i, testMsgs[i], string(msg.Data()))
			}
		}
	})

	t.Run("remove consumer during consume", func(t *testing.T) {
		srv := RunBasicJetStreamServer()
		defer shutdownJSServerAndRemoveStorage(t, srv)
		nc, err := nats.Connect(srv.ClientURL())
		if err != nil {
			t.Fatalf("Unexpected error: %v", err)
		}

		js, err := jetstream.New(nc)
		if err != nil {
			t.Fatalf("Unexpected error: %v", err)
		}
		defer nc.Close()

		ctx, cancel := context.WithTimeout(context.Background(), 20*time.Second)
		defer cancel()
		s, err := js.CreateStream(ctx, jetstream.StreamConfig{Name: "foo", Subjects: []string{"FOO.*"}})
		if err != nil {
			t.Fatalf("Unexpected error: %v", err)
		}
		c, err := s.CreateOrUpdateConsumer(ctx, jetstream.ConsumerConfig{AckPolicy: jetstream.AckExplicitPolicy})
		if err != nil {
			t.Fatalf("Unexpected error: %v", err)
		}

		errs := make(chan error, 10)
		msgs := make([]jetstream.Msg, 0)
		wg := &sync.WaitGroup{}
		wg.Add(len(testMsgs))
		l, err := c.Consume(func(msg jetstream.Msg) {
			msgs = append(msgs, msg)
			wg.Done()
		}, jetstream.ConsumeErrHandler(func(consumeCtx jetstream.ConsumeContext, err error) {
			errs <- err
		}))
		if err != nil {
			t.Fatalf("Unexpected error: %v", err)
		}
		defer l.Stop()

		publishTestMsgs(t, js)
		wg.Wait()
		if len(msgs) != len(testMsgs) {
			t.Fatalf("Unexpected received message count; want %d; got %d", len(testMsgs), len(msgs))
		}
		if err := s.DeleteConsumer(ctx, c.CachedInfo().Name); err != nil {
			t.Fatalf("Error deleting consumer: %s", err)
		}
		select {
		case err := <-errs:
			if !errors.Is(err, jetstream.ErrConsumerDeleted) {
				t.Fatalf("Expected error: %v; got: %v", jetstream.ErrConsumerDeleted, err)
			}
		case <-time.After(5 * time.Second):
			t.Fatalf("Timeout waiting for %v", jetstream.ErrConsumerDeleted)
		}
		// Messages published after the consumer is gone must not be delivered.
		publishTestMsgs(t, js)
		time.Sleep(50 * time.Millisecond)
		if len(msgs) != len(testMsgs) {
			t.Fatalf("Unexpected received message count; want %d; got %d", len(testMsgs), len(msgs))
		}
	})

	t.Run("with custom max bytes", func(t *testing.T) {
		srv := RunBasicJetStreamServer()
		defer shutdownJSServerAndRemoveStorage(t, srv)
		nc, err := nats.Connect(srv.ClientURL())
		if err != nil {
			t.Fatalf("Unexpected error: %v", err)
		}

		js, err := jetstream.New(nc)
		if err != nil {
			t.Fatalf("Unexpected error: %v", err)
		}
		defer nc.Close()

		ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
		defer cancel()
		s, err := js.CreateStream(ctx, jetstream.StreamConfig{Name: "foo", Subjects: []string{"FOO.*"}})
		if err != nil {
			t.Fatalf("Unexpected error: %v", err)
		}
		c, err := s.CreateOrUpdateConsumer(ctx, jetstream.ConsumerConfig{AckPolicy: jetstream.AckExplicitPolicy})
		if err != nil {
			t.Fatalf("Unexpected error: %v", err)
		}
		// subscribe to next request subject to verify how many next requests were sent
		sub, err := nc.SubscribeSync(fmt.Sprintf("$JS.API.CONSUMER.MSG.NEXT.foo.%s", c.CachedInfo().Name))
		if err != nil {
			t.Fatalf("Error on subscribe: %v", err)
		}
		defer sub.Unsubscribe()

		publishTestMsgs(t, js)
		msgs := make([]jetstream.Msg, 0)
		wg := &sync.WaitGroup{}
		wg.Add(len(testMsgs))
		l, err := c.Consume(func(msg jetstream.Msg) {
			msgs = append(msgs, msg)
			wg.Done()
		}, jetstream.PullMaxBytes(150))
		if err != nil {
			t.Fatalf("Unexpected error: %v", err)
		}
		defer l.Stop()

		wg.Wait()
		requestsNum, _, err := sub.Pending()
		if err != nil {
			t.Fatalf("Unexpected error: %v", err)
		}

		// msg size is 57, so a 150-byte pull request fits at most 2 messages;
		// delivering all 5 messages therefore takes at least 3 pull requests.
		if requestsNum < 3 {
			t.Fatalf("Unexpected number of requests sent; want at least 3; got %d", requestsNum)
		}

		if len(msgs) != len(testMsgs) {
			t.Fatalf("Unexpected received message count; want %d; got %d", len(testMsgs), len(msgs))
		}
		for i, msg := range msgs {
			if string(msg.Data()) != testMsgs[i] {
				t.Fatalf("Invalid msg on index %d; expected: %s; got: %s", i, testMsgs[i], string(msg.Data()))
			}
		}
	})

	t.Run("with auto unsubscribe", func(t *testing.T) {
		srv := RunBasicJetStreamServer()
		defer shutdownJSServerAndRemoveStorage(t, srv)
		nc, err := nats.Connect(srv.ClientURL())
		if err != nil {
			t.Fatalf("Unexpected error: %v", err)
		}

		js, err := jetstream.New(nc)
		if err != nil {
			t.Fatalf("Unexpected error: %v", err)
		}
		defer nc.Close()

		ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
		defer cancel()
		s, err := js.CreateStream(ctx, jetstream.StreamConfig{Name: "foo", Subjects: []string{"FOO.*"}})
		if err != nil {
			t.Fatalf("Unexpected error: %v", err)
		}
		c, err := s.CreateOrUpdateConsumer(ctx, jetstream.ConsumerConfig{AckPolicy: jetstream.AckExplicitPolicy})
		if err != nil {
			t.Fatalf("Unexpected error: %v", err)
		}

		for i := 0; i < 100; i++ {
			if _, err := js.Publish(ctx, "FOO.A", []byte("msg")); err != nil {
				t.Fatalf("Unexpected error during publish: %s", err)
			}
		}
		msgs := make([]jetstream.Msg, 0)
		wg := &sync.WaitGroup{}
		wg.Add(50)
		// The consume context is intentionally discarded: StopAfter(50) is
		// expected to stop consumption on its own, which the checks below verify.
		_, err = c.Consume(func(msg jetstream.Msg) {
			msgs = append(msgs, msg)
			msg.Ack()
			wg.Done()
		}, jetstream.StopAfter(50), jetstream.PullMaxMessages(40))
		if err != nil {
			t.Fatalf("Unexpected error: %v", err)
		}

		wg.Wait()
		time.Sleep(10 * time.Millisecond)
		ci, err := c.Info(ctx)
		if err != nil {
			t.Fatalf("Unexpected error: %v", err)
		}
		if ci.NumPending != 50 {
			t.Fatalf("Unexpected number of pending messages; want 50; got %d", ci.NumPending)
		}
		if ci.NumAckPending != 0 {
			t.Fatalf("Unexpected number of ack pending messages; want 0; got %d", ci.NumAckPending)
		}
		if ci.NumWaiting != 0 {
			t.Fatalf("Unexpected number of waiting pull requests; want 0; got %d", ci.NumWaiting)
		}
	})

	t.Run("with invalid batch size", func(t *testing.T) {
		srv := RunBasicJetStreamServer()
		defer shutdownJSServerAndRemoveStorage(t, srv)
		nc, err := nats.Connect(srv.ClientURL())
		if err != nil {
			t.Fatalf("Unexpected error: %v", err)
		}

		js, err := jetstream.New(nc)
		if err != nil {
			t.Fatalf("Unexpected error: %v", err)
		}
		defer nc.Close()

		ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
		defer cancel()
		s, err := js.CreateStream(ctx, jetstream.StreamConfig{Name: "foo", Subjects: []string{"FOO.*"}})
		if err != nil {
			t.Fatalf("Unexpected error: %v", err)
		}
		c, err := s.CreateOrUpdateConsumer(ctx, jetstream.ConsumerConfig{AckPolicy: jetstream.AckExplicitPolicy})
		if err != nil {
			t.Fatalf("Unexpected error: %v", err)
		}

		_, err = c.Consume(func(_ jetstream.Msg) {
		}, jetstream.PullMaxMessages(-1))
		if err == nil || !errors.Is(err, jetstream.ErrInvalidOption) {
			t.Fatalf("Expected error: %v; got: %v", jetstream.ErrInvalidOption, err)
		}
	})

	t.Run("with custom expiry", func(t *testing.T) {
		srv := RunBasicJetStreamServer()
		defer shutdownJSServerAndRemoveStorage(t, srv)
		nc, err := nats.Connect(srv.ClientURL())
		if err != nil {
			t.Fatalf("Unexpected error: %v", err)
		}

		js, err := jetstream.New(nc)
		if err != nil {
			t.Fatalf("Unexpected error: %v", err)
		}
		defer nc.Close()

		ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
		defer cancel()
		s, err := js.CreateStream(ctx, jetstream.StreamConfig{Name: "foo", Subjects: []string{"FOO.*"}})
		if err != nil {
			t.Fatalf("Unexpected error: %v", err)
		}
		c, err := s.CreateOrUpdateConsumer(ctx, jetstream.ConsumerConfig{AckPolicy: jetstream.AckExplicitPolicy})
		if err != nil {
			t.Fatalf("Unexpected error: %v", err)
		}

		msgs := make([]jetstream.Msg, 0)
		wg := &sync.WaitGroup{}
		wg.Add(len(testMsgs))
		l, err := c.Consume(func(msg jetstream.Msg) {
			msgs = append(msgs, msg)
			wg.Done()
		}, jetstream.PullExpiry(2*time.Second))
		if err != nil {
			t.Fatalf("Unexpected error: %v", err)
		}
		defer l.Stop()

		publishTestMsgs(t, js)
		wg.Wait()

		if len(msgs) != len(testMsgs) {
			t.Fatalf("Unexpected received message count; want %d; got %d", len(testMsgs), len(msgs))
		}
		for i, msg := range msgs {
			if string(msg.Data()) != testMsgs[i] {
				t.Fatalf("Invalid msg on index %d; expected: %s; got: %s", i, testMsgs[i], string(msg.Data()))
			}
		}
	})

	t.Run("with invalid expiry", func(t *testing.T) {
		srv := RunBasicJetStreamServer()
		defer shutdownJSServerAndRemoveStorage(t, srv)
		nc, err := nats.Connect(srv.ClientURL())
		if err != nil {
			t.Fatalf("Unexpected error: %v", err)
		}

		js, err := jetstream.New(nc)
		if err != nil {
			t.Fatalf("Unexpected error: %v", err)
		}
		defer nc.Close()

		ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
		defer cancel()
		s, err := js.CreateStream(ctx, jetstream.StreamConfig{Name: "foo", Subjects: []string{"FOO.*"}})
		if err != nil {
			t.Fatalf("Unexpected error: %v", err)
		}
		c, err := s.CreateOrUpdateConsumer(ctx, jetstream.ConsumerConfig{AckPolicy: jetstream.AckExplicitPolicy})
		if err != nil {
			t.Fatalf("Unexpected error: %v", err)
		}

		_, err = c.Consume(func(_ jetstream.Msg) {
		}, jetstream.PullExpiry(-1))
		if err == nil || !errors.Is(err, jetstream.ErrInvalidOption) {
			t.Fatalf("Expected error: %v; got: %v", jetstream.ErrInvalidOption, err)
		}
	})

	t.Run("with missing heartbeat", func(t *testing.T) {
		srv := RunBasicJetStreamServer()
		defer shutdownJSServerAndRemoveStorage(t, srv)
		nc, err := nats.Connect(srv.ClientURL())
		if err != nil {
			t.Fatalf("Unexpected error: %v", err)
		}

		js, err := jetstream.New(nc)
		if err != nil {
			t.Fatalf("Unexpected error: %v", err)
		}
		defer nc.Close()

		ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
		defer cancel()
		s, err := js.CreateStream(ctx, jetstream.StreamConfig{Name: "foo", Subjects: []string{"FOO.*"}})
		if err != nil {
			t.Fatalf("Unexpected error: %v", err)
		}
		c, err := s.CreateOrUpdateConsumer(ctx, jetstream.ConsumerConfig{AckPolicy: jetstream.AckExplicitPolicy})
		if err != nil {
			t.Fatalf("Unexpected error: %v", err)
		}

		// delete consumer to force missing heartbeat error
		if err := s.DeleteConsumer(ctx, c.CachedInfo().Name); err != nil {
			t.Fatalf("Error deleting consumer: %s", err)
		}

		errs := make(chan error, 1)
		now := time.Now()
		var elapsed time.Duration
		l, err := c.Consume(func(msg jetstream.Msg) {},
			jetstream.PullHeartbeat(500*time.Millisecond),
			jetstream.ConsumeErrHandler(func(consumeCtx jetstream.ConsumeContext, err error) {
				errs <- err
			}))
		if err != nil {
			t.Fatalf("Unexpected error: %v", err)
		}
		defer l.Stop()

		// if the consumer does not exist, server will return ErrNoResponders
		select {
		case err := <-errs:
			if !errors.Is(err, nats.ErrNoResponders) {
				t.Fatalf("Expected error: %v; got: %v", nats.ErrNoResponders, err)
			}
		case <-time.After(5 * time.Second):
			t.Fatalf("Timeout waiting for %v", nats.ErrNoResponders)
		}

		// after 2*heartbeat interval, we should get ErrNoHeartbeat
		select {
		case err := <-errs:
			if !errors.Is(err, jetstream.ErrNoHeartbeat) {
				t.Fatalf("Expected error: %v; got: %v", jetstream.ErrNoHeartbeat, err)
			}
			elapsed = time.Since(now)
			if elapsed < time.Second || elapsed > 1500*time.Millisecond {
				t.Fatalf("Unexpected elapsed time; want between 1s and 1.5s; got %v", elapsed)
			}
		case <-time.After(5 * time.Second):
			t.Fatalf("Timeout waiting for %v", jetstream.ErrNoHeartbeat)
		}
	})

	t.Run("with server restart", func(t *testing.T) {
		srv := RunBasicJetStreamServer()
		nc, err := nats.Connect(srv.ClientURL())
		if err != nil {
			t.Fatalf("Unexpected error: %v", err)
		}

		js, err := jetstream.New(nc)
		if err != nil {
			t.Fatalf("Unexpected error: %v", err)
		}
		defer nc.Close()

		ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
		defer cancel()
		s, err := js.CreateStream(ctx, jetstream.StreamConfig{Name: "foo", Subjects: []string{"FOO.*"}})
		if err != nil {
			t.Fatalf("Unexpected error: %v", err)
		}
		c, err := s.CreateOrUpdateConsumer(ctx, jetstream.ConsumerConfig{AckPolicy: jetstream.AckExplicitPolicy})
		if err != nil {
			t.Fatalf("Unexpected error: %v", err)
		}

		// Expect both the batch published before and the batch published after
		// the restart to be delivered.
		wg := &sync.WaitGroup{}
		wg.Add(2 * len(testMsgs))
		msgs := make([]jetstream.Msg, 0)
		publishTestMsgs(t, js)
		l, err := c.Consume(func(msg jetstream.Msg) {
			msgs = append(msgs, msg)
			wg.Done()
		})
		if err != nil {
			t.Fatalf("Unexpected error: %v", err)
		}
		defer l.Stop()
		time.Sleep(10 * time.Millisecond)
		// restart the server
		srv = restartBasicJSServer(t, srv)
		defer shutdownJSServerAndRemoveStorage(t, srv)
		time.Sleep(10 * time.Millisecond)
		publishTestMsgs(t, js)
		wg.Wait()
	})

	t.Run("no messages received after stop", func(t *testing.T) {
		srv := RunBasicJetStreamServer()
		defer shutdownJSServerAndRemoveStorage(t, srv)
		nc, err := nats.Connect(srv.ClientURL())
		if err != nil {
			t.Fatalf("Unexpected error: %v", err)
		}

		js, err := jetstream.New(nc)
		if err != nil {
			t.Fatalf("Unexpected error: %v", err)
		}
		defer nc.Close()

		ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
		defer cancel()
		s, err := js.CreateStream(ctx, jetstream.StreamConfig{Name: "foo", Subjects: []string{"FOO.*"}})
		if err != nil {
			t.Fatalf("Unexpected error: %v", err)
		}
		c, err := s.CreateOrUpdateConsumer(ctx, jetstream.ConsumerConfig{AckPolicy: jetstream.AckExplicitPolicy})
		if err != nil {
			t.Fatalf("Unexpected error: %v", err)
		}
		// With an 80ms handler and Stop after 100ms, exactly 2 messages are
		// expected to be handled before consumption stops.
		wg := &sync.WaitGroup{}
		wg.Add(2)
		publishTestMsgs(t, js)
		msgs := make([]jetstream.Msg, 0)
		cc, err := c.Consume(func(msg jetstream.Msg) {
			time.Sleep(80 * time.Millisecond)
			msg.Ack()
			msgs = append(msgs, msg)
			wg.Done()
		})
		if err != nil {
			t.Fatalf("Unexpected error: %v", err)
		}
		time.Sleep(100 * time.Millisecond)
		cc.Stop()
		wg.Wait()
		// wait for some time to make sure no new messages are received
		time.Sleep(100 * time.Millisecond)

		if len(msgs) != 2 {
			t.Fatalf("Unexpected received message count after stop; want 2; got %d", len(msgs))
		}
	})

	t.Run("drain mode", func(t *testing.T) {
		srv := RunBasicJetStreamServer()
		defer shutdownJSServerAndRemoveStorage(t, srv)
		nc, err := nats.Connect(srv.ClientURL())
		if err != nil {
			t.Fatalf("Unexpected error: %v", err)
		}

		js, err := jetstream.New(nc)
		if err != nil {
			t.Fatalf("Unexpected error: %v", err)
		}
		defer nc.Close()

		ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
		defer cancel()
		s, err := js.CreateStream(ctx, jetstream.StreamConfig{Name: "foo", Subjects: []string{"FOO.*"}})
		if err != nil {
			t.Fatalf("Unexpected error: %v", err)
		}
		c, err := s.CreateOrUpdateConsumer(ctx, jetstream.ConsumerConfig{AckPolicy: jetstream.AckExplicitPolicy})
		if err != nil {
			t.Fatalf("Unexpected error: %v", err)
		}
		// Drain (unlike Stop) is expected to let all 5 buffered messages be handled.
		wg := &sync.WaitGroup{}
		wg.Add(5)
		publishTestMsgs(t, js)
		cc, err := c.Consume(func(msg jetstream.Msg) {
			time.Sleep(50 * time.Millisecond)
			msg.Ack()
			wg.Done()
		})
		if err != nil {
			t.Fatalf("Unexpected error: %v", err)
		}
		time.Sleep(100 * time.Millisecond)
		cc.Drain()
		wg.Wait()
	})

	t.Run("wait for closed after drain", func(t *testing.T) {
		srv := RunBasicJetStreamServer()
		defer shutdownJSServerAndRemoveStorage(t, srv)
		nc, err := nats.Connect(srv.ClientURL())
		if err != nil {
			t.Fatalf("Unexpected error: %v", err)
		}

		js, err := jetstream.New(nc)
		if err != nil {
			t.Fatalf("Unexpected error: %v", err)
		}
		defer nc.Close()
		ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
		defer cancel()
		s, err := js.CreateStream(ctx, jetstream.StreamConfig{Name: "foo", Subjects: []string{"FOO.*"}})
		if err != nil {
			t.Fatalf("Unexpected error: %v", err)
		}
		c, err := s.CreateOrUpdateConsumer(ctx, jetstream.ConsumerConfig{AckPolicy: jetstream.AckExplicitPolicy})
		if err != nil {
			t.Fatalf("Unexpected error: %v", err)
		}
		msgs := make([]jetstream.Msg, 0)
		lock := sync.Mutex{}
		publishTestMsgs(t, js)
		cc, err := c.Consume(func(msg jetstream.Msg) {
			time.Sleep(50 * time.Millisecond)
			msg.Ack()
			lock.Lock()
			msgs = append(msgs, msg)
			lock.Unlock()
		})
		if err != nil {
			t.Fatalf("Unexpected error: %v", err)
		}
		closed := cc.Closed()
		time.Sleep(100 * time.Millisecond)

		cc.Drain()

		select {
		case <-closed:
		case <-time.After(5 * time.Second):
			t.Fatalf("Timeout waiting for consume to be closed")
		}

		// msgs is written under lock by the handler, so read it under lock too.
		lock.Lock()
		received := len(msgs)
		lock.Unlock()
		if received != len(testMsgs) {
			t.Fatalf("Unexpected received message count after consume closed; want %d; got %d", len(testMsgs), received)
		}
	})

	t.Run("wait for closed after stop", func(t *testing.T) {
		srv := RunBasicJetStreamServer()
		defer shutdownJSServerAndRemoveStorage(t, srv)
		nc, err := nats.Connect(srv.ClientURL())
		if err != nil {
			t.Fatalf("Unexpected error: %v", err)
		}

		js, err := jetstream.New(nc)
		if err != nil {
			t.Fatalf("Unexpected error: %v", err)
		}
		defer nc.Close()
		ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
		defer cancel()
		s, err := js.CreateStream(ctx, jetstream.StreamConfig{Name: "foo", Subjects: []string{"FOO.*"}})
		if err != nil {
			t.Fatalf("Unexpected error: %v", err)
		}
		c, err := s.CreateOrUpdateConsumer(ctx, jetstream.ConsumerConfig{AckPolicy: jetstream.AckExplicitPolicy})
		if err != nil {
			t.Fatalf("Unexpected error: %v", err)
		}
		msgs := make([]jetstream.Msg, 0)
		lock := sync.Mutex{}
		publishTestMsgs(t, js)
		cc, err := c.Consume(func(msg jetstream.Msg) {
			time.Sleep(50 * time.Millisecond)
			msg.Ack()
			lock.Lock()
			msgs = append(msgs, msg)
			lock.Unlock()
		})
		if err != nil {
			t.Fatalf("Unexpected error: %v", err)
		}
		time.Sleep(100 * time.Millisecond)
		closed := cc.Closed()

		cc.Stop()

		select {
		case <-closed:
		case <-time.After(5 * time.Second):
			t.Fatalf("Timeout waiting for consume to be closed")
		}

		// msgs is written under lock by the handler, so read it under lock too.
		lock.Lock()
		received := len(msgs)
		lock.Unlock()
		if received < 1 || received > 3 {
			t.Fatalf("Unexpected received message count after consume closed; want 1-3; got %d", received)
		}
	})

	t.Run("wait for closed on already closed consume", func(t *testing.T) {
		srv := RunBasicJetStreamServer()
		defer shutdownJSServerAndRemoveStorage(t, srv)
		nc, err := nats.Connect(srv.ClientURL())
		if err != nil {
			t.Fatalf("Unexpected error: %v", err)
		}

		js, err := jetstream.New(nc)
		if err != nil {
			t.Fatalf("Unexpected error: %v", err)
		}
		defer nc.Close()
		ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
		defer cancel()
		s, err := js.CreateStream(ctx, jetstream.StreamConfig{Name: "foo", Subjects: []string{"FOO.*"}})
		if err != nil {
			t.Fatalf("Unexpected error: %v", err)
		}
		c, err := s.CreateOrUpdateConsumer(ctx, jetstream.ConsumerConfig{AckPolicy: jetstream.AckExplicitPolicy})
		if err != nil {
			t.Fatalf("Unexpected error: %v", err)
		}
		publishTestMsgs(t, js)
		cc, err := c.Consume(func(msg jetstream.Msg) {
			time.Sleep(50 * time.Millisecond)
			msg.Ack()
		})
		if err != nil {
			t.Fatalf("Unexpected error: %v", err)
		}
		time.Sleep(100 * time.Millisecond)

		cc.Stop()

		time.Sleep(100 * time.Millisecond)

		// Closed() obtained after the consume already finished must still fire.
		select {
		case <-cc.Closed():
		case <-time.After(5 * time.Second):
			t.Fatalf("Timeout waiting for consume to be closed")
		}
	})

	t.Run("with max messages and per fetch size limit", func(t *testing.T) {
		srv := RunBasicJetStreamServer()
		defer shutdownJSServerAndRemoveStorage(t, srv)
		nc, err := nats.Connect(srv.ClientURL())
		if err != nil {
			t.Fatalf("Unexpected error: %v", err)
		}

		js, err := jetstream.New(nc)
		if err != nil {
			t.Fatalf("Unexpected error: %v", err)
		}
		defer nc.Close()
		ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
		defer cancel()
		s, err := js.CreateStream(ctx, jetstream.StreamConfig{Name: "foo", Subjects: []string{"FOO.*"}})
		if err != nil {
			t.Fatalf("Unexpected error: %v", err)
		}
		c, err := s.CreateOrUpdateConsumer(ctx, jetstream.ConsumerConfig{AckPolicy: jetstream.AckExplicitPolicy})
		if err != nil {
			t.Fatalf("Unexpected error: %v", err)
		}

		// subscribe to next request subject to verify how many next requests were sent
		// and whether both thresholds work as expected
		sub, err := nc.SubscribeSync(fmt.Sprintf("$JS.API.CONSUMER.MSG.NEXT.foo.%s", c.CachedInfo().Name))
		if err != nil {
			t.Fatalf("Error on subscribe: %v", err)
		}
		defer sub.Unsubscribe()

		wg := &sync.WaitGroup{}
		msgs := make([]jetstream.Msg, 0)
		cc, err := c.Consume(func(msg jetstream.Msg) {
			msg.Ack()
			msgs = append(msgs, msg)
			wg.Done()
		}, jetstream.PullMaxMessagesWithBytesLimit(10, 1024))
		if err != nil {
			t.Fatalf("Unexpected error: %v", err)
		}

		smallMsg := nats.Msg{
			Subject: "FOO.A",
			Data:    []byte("msg"),
		}
		wg.Add(10)
		// publish 10 small messages
		for i := 0; i < 10; i++ {
			if _, err := js.PublishMsg(ctx, &smallMsg); err != nil {
				t.Fatalf("Unexpected error during publish: %s", err)
			}
		}
		wg.Wait()

		// we should get 2 pull requests
		for range 2 {
			fetchReq, err := sub.NextMsg(100 * time.Millisecond)
			if err != nil {
				t.Fatalf("Error on next msg: %v", err)
			}
			if !bytes.Contains(fetchReq.Data, []byte(`"max_bytes":1024`)) {
				t.Fatalf("Unexpected fetch request: %s", fetchReq.Data)
			}
		}
		// make sure no more requests were sent
		_, err = sub.NextMsg(100 * time.Millisecond)
		if !errors.Is(err, nats.ErrTimeout) {
			t.Fatalf("Expected timeout error; got: %v", err)
		}

		// now publish 10 large messages, almost hitting the limit
		// we need to account for the total message size (which includes js ack reply subject)
		largeMsg := nats.Msg{
			Subject: "FOO.B",
			Data:    make([]byte, 950),
		}
		wg.Add(10)
		for range 10 {
			if _, err := js.PublishMsg(ctx, &largeMsg); err != nil {
				t.Fatalf("Unexpected error during publish: %s", err)
			}
		}
		wg.Wait()

		// we expect 10 pull requests
		for range 10 {
			fetchReq, err := sub.NextMsg(100 * time.Millisecond)
			if err != nil {
				t.Fatalf("Error on next msg: %v", err)
			}
			if !bytes.Contains(fetchReq.Data, []byte(`"max_bytes":1024`)) {
				t.Fatalf("Unexpected fetch request: %s", fetchReq.Data)
			}
		}
		_, err = sub.NextMsg(100 * time.Millisecond)
		if !errors.Is(err, nats.ErrTimeout) {
			t.Fatalf("Expected timeout error; got: %v", err)
		}

		cc.Stop()
	})

	t.Run("avoid stall on batch completed status", func(t *testing.T) {
		srv := RunBasicJetStreamServer()
		defer shutdownJSServerAndRemoveStorage(t, srv)
		nc, err := nats.Connect(srv.ClientURL())
		if err != nil {
			t.Fatalf("Unexpected error: %v", err)
		}

		js, err := jetstream.New(nc)
		if err != nil {
			t.Fatalf("Unexpected error: %v", err)
		}
		defer nc.Close()
		ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
		defer cancel()
		s, err := js.CreateStream(ctx, jetstream.StreamConfig{Name: "foo", Subjects: []string{"FOO.*"}})
		if err != nil {
			t.Fatalf("Unexpected error: %v", err)
		}
		c, err := s.CreateOrUpdateConsumer(ctx, jetstream.ConsumerConfig{AckPolicy: jetstream.AckExplicitPolicy})
		if err != nil {
			t.Fatalf("Unexpected error: %v", err)
		}

		wg := &sync.WaitGroup{}
		msgs := make([]jetstream.Msg, 0)
		// use consume with small max messages and large max bytes
		// to make sure we don't stall on batch completed status
		cc, err := c.Consume(func(msg jetstream.Msg) {
			msg.Ack()
			msgs = append(msgs, msg)
			wg.Done()
		}, jetstream.PullMaxMessagesWithBytesLimit(2, 1024))
		if err != nil {
			t.Fatalf("Unexpected error: %v", err)
		}

		wg.Add(10)
		for i := 0; i < 10; i++ {
			if _, err := js.Publish(ctx, "FOO.A", []byte("msg")); err != nil {
				t.Fatalf("Unexpected error during publish: %s", err)
			}
		}
		wg.Wait()
		cc.Stop()
	})

	t.Run("invalid heartbeat", func(t *testing.T) {
		srv := RunBasicJetStreamServer()
		defer shutdownJSServerAndRemoveStorage(t, srv)
		nc, err := nats.Connect(srv.ClientURL())
		if err != nil {
			t.Fatalf("Unexpected error: %v", err)
		}
		js, err := jetstream.New(nc)
		if err != nil {
			t.Fatalf("Unexpected error: %v", err)
		}
		defer nc.Close()
		ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
		defer cancel()
		s, err := js.CreateStream(ctx, jetstream.StreamConfig{Name: "foo", Subjects: []string{"FOO.*"}})
		if err != nil {
			t.Fatalf("Unexpected error: %v", err)
		}
		c, err := s.CreateOrUpdateConsumer(ctx, jetstream.ConsumerConfig{AckPolicy: jetstream.AckExplicitPolicy})
		if err != nil {
			t.Fatalf("Unexpected error: %v", err)
		}

		// default expiry is 30s, so max heartbeat should be 15s
		_, err = c.Consume(func(_ jetstream.Msg) {
		}, jetstream.PullHeartbeat(20*time.Second))
		if !errors.Is(err, jetstream.ErrInvalidOption) {
			t.Fatalf("Expected error: %v; got: %v", jetstream.ErrInvalidOption, err)
		}
	})
}

// TestPullConsumerConsume_WithCluster runs Consume against a 3-node cluster,
// once for an R1 stream and once for an R3 stream. It covers basic delivery,
// re-subscribing after Stop, and resuming consumption after two of the three
// servers are restarted.
func TestPullConsumerConsume_WithCluster(t *testing.T) {
	testSubject := "FOO.123"
	testMsgs := []string{"m1", "m2", "m3", "m4", "m5"}
	publishTestMsgs := func(t *testing.T, js jetstream.JetStream) {
		for _, msg := range testMsgs {
			if _, err := js.Publish(context.Background(), testSubject, []byte(msg)); err != nil {
				t.Fatalf("Unexpected error during publish: %s", err)
			}
		}
	}

	// testMsgsTwice is testMsgs published twice back to back. It is built in a
	// fresh slice instead of append(testMsgs, testMsgs...) so it can never
	// share (and clobber) testMsgs' backing array.
	testMsgsTwice := make([]string, 0, 2*len(testMsgs))
	testMsgsTwice = append(testMsgsTwice, testMsgs...)
	testMsgsTwice = append(testMsgsTwice, testMsgs...)

	// checkMsgs fails the test unless msgs carries exactly the payloads in
	// expected, in order. Messages index expected (not testMsgs) so that a
	// mismatch past the first batch is reported rather than panicking.
	checkMsgs := func(t *testing.T, msgs []jetstream.Msg, expected []string) {
		t.Helper()
		if len(msgs) != len(expected) {
			t.Fatalf("Unexpected received message count; want %d; got %d", len(expected), len(msgs))
		}
		for i, msg := range msgs {
			if string(msg.Data()) != expected[i] {
				t.Fatalf("Invalid msg on index %d; expected: %s; got: %s", i, expected[i], string(msg.Data()))
			}
		}
	}

	name := "cluster"
	singleStream := jetstream.StreamConfig{
		Name:     name,
		Replicas: 1,
		Subjects: []string{"FOO.*"},
	}

	streamWithReplicas := jetstream.StreamConfig{
		Name:     name,
		Replicas: 3,
		Subjects: []string{"FOO.*"},
	}

	for _, stream := range []jetstream.StreamConfig{singleStream, streamWithReplicas} {
		t.Run(fmt.Sprintf("num replicas: %d, no options", stream.Replicas), func(t *testing.T) {
			withJSClusterAndStream(t, name, 3, stream, func(t *testing.T, subject string, srvs ...*jsServer) {
				nc, err := nats.Connect(srvs[0].ClientURL())
				if err != nil {
					t.Fatalf("Unexpected error: %v", err)
				}
				defer nc.Close()

				js, err := jetstream.New(nc)
				if err != nil {
					t.Fatalf("Unexpected error: %v", err)
				}

				ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
				defer cancel()
				s, err := js.Stream(ctx, stream.Name)
				if err != nil {
					t.Fatalf("Unexpected error: %v", err)
				}
				c, err := s.CreateOrUpdateConsumer(ctx, jetstream.ConsumerConfig{AckPolicy: jetstream.AckExplicitPolicy})
				if err != nil {
					t.Fatalf("Unexpected error: %v", err)
				}

				msgs := make([]jetstream.Msg, 0)
				wg := &sync.WaitGroup{}
				wg.Add(len(testMsgs))
				l, err := c.Consume(func(msg jetstream.Msg) {
					msgs = append(msgs, msg)
					wg.Done()
				})
				if err != nil {
					t.Fatalf("Unexpected error: %v", err)
				}
				defer l.Stop()

				publishTestMsgs(t, js)
				wg.Wait()
				checkMsgs(t, msgs, testMsgs)
			})
		})

		t.Run(fmt.Sprintf("num replicas: %d, subscribe, cancel subscription, then subscribe again", stream.Replicas), func(t *testing.T) {
			withJSClusterAndStream(t, name, 3, stream, func(t *testing.T, subject string, srvs ...*jsServer) {
				nc, err := nats.Connect(srvs[0].ClientURL())
				if err != nil {
					t.Fatalf("Unexpected error: %v", err)
				}
				defer nc.Close()

				js, err := jetstream.New(nc)
				if err != nil {
					t.Fatalf("Unexpected error: %v", err)
				}

				ctx, cancel := context.WithTimeout(context.Background(), 20*time.Second)
				defer cancel()
				s, err := js.Stream(ctx, stream.Name)
				if err != nil {
					t.Fatalf("Unexpected error: %v", err)
				}
				c, err := s.CreateOrUpdateConsumer(ctx, jetstream.ConsumerConfig{AckPolicy: jetstream.AckExplicitPolicy})
				if err != nil {
					t.Fatalf("Unexpected error: %v", err)
				}

				wg := sync.WaitGroup{}
				wg.Add(len(testMsgs))
				msgs := make([]jetstream.Msg, 0)
				l, err := c.Consume(func(msg jetstream.Msg) {
					// Handlers run on a client goroutine: t.Fatalf there would only
					// exit that goroutine, skip wg.Done and hang wg.Wait, so report
					// with t.Errorf and keep going.
					if err := msg.Ack(); err != nil {
						t.Errorf("Unexpected error: %v", err)
					}
					msgs = append(msgs, msg)
					if len(msgs) == 5 {
						// Cancels the setup context once the first batch is in;
						// nothing later in this subtest uses ctx.
						cancel()
					}
					wg.Done()
				})
				if err != nil {
					t.Fatalf("Unexpected error: %v", err)
				}

				publishTestMsgs(t, js)
				wg.Wait()
				l.Stop()

				time.Sleep(10 * time.Millisecond)
				wg.Add(len(testMsgs))
				l, err = c.Consume(func(msg jetstream.Msg) {
					if err := msg.Ack(); err != nil {
						t.Errorf("Unexpected error: %v", err)
					}
					msgs = append(msgs, msg)
					wg.Done()
				})
				if err != nil {
					t.Fatalf("Unexpected error: %v", err)
				}
				defer l.Stop()
				publishTestMsgs(t, js)
				wg.Wait()
				checkMsgs(t, msgs, testMsgsTwice)
			})
		})

		t.Run(fmt.Sprintf("num replicas: %d, recover consume after server restart", stream.Replicas), func(t *testing.T) {
			withJSClusterAndStream(t, name, 3, stream, func(t *testing.T, subject string, srvs ...*jsServer) {
				nc, err := nats.Connect(srvs[0].ClientURL())
				if err != nil {
					t.Fatalf("Unexpected error: %v", err)
				}
				defer nc.Close()

				js, err := jetstream.New(nc)
				if err != nil {
					t.Fatalf("Unexpected error: %v", err)
				}

				ctx, cancel := context.WithTimeout(context.Background(), 20*time.Second)
				defer cancel()
				s, err := js.Stream(ctx, stream.Name)
				if err != nil {
					t.Fatalf("Unexpected error: %v", err)
				}
				c, err := s.CreateOrUpdateConsumer(ctx, jetstream.ConsumerConfig{AckPolicy: jetstream.AckExplicitPolicy, InactiveThreshold: 10 * time.Second})
				if err != nil {
					t.Fatalf("Unexpected error: %v", err)
				}

				wg := sync.WaitGroup{}
				wg.Add(len(testMsgs))
				msgs := make([]jetstream.Msg, 0)
				l, err := c.Consume(func(msg jetstream.Msg) {
					// See above: never t.Fatalf from the handler goroutine.
					if err := msg.Ack(); err != nil {
						t.Errorf("Unexpected error: %v", err)
					}
					msgs = append(msgs, msg)
					wg.Done()
				}, jetstream.PullExpiry(1*time.Second), jetstream.PullHeartbeat(500*time.Millisecond))
				if err != nil {
					t.Fatalf("Unexpected error: %v", err)
				}
				defer l.Stop()

				publishTestMsgs(t, js)
				wg.Wait()

				time.Sleep(10 * time.Millisecond)
				srvs[0].Shutdown()
				srvs[1].Shutdown()
				srvs[0].Restart()
				srvs[1].Restart()
				wg.Add(len(testMsgs))

				// Poll until the restarted cluster serves stream info again.
				var streamErr error
				for i := 0; i < 10; i++ {
					time.Sleep(500 * time.Millisecond)
					if _, streamErr = js.Stream(context.Background(), stream.Name); streamErr == nil {
						break
					}
				}
				if streamErr != nil {
					t.Fatal("JetStream not recovered: ", streamErr)
				}
				publishTestMsgs(t, js)
				wg.Wait()
				checkMsgs(t, msgs, testMsgsTwice)
			})
		})
	}
}

// TestPullConsumerNext covers Consumer.Next: fetching published messages one
// at a time, surfacing ErrConsumerDeleted when the consumer is deleted while
// Next is waiting, and timing out with a custom FetchMaxWait.
func TestPullConsumerNext(t *testing.T) {
	testSubject := "FOO.123"
	testMsgs := []string{"m1", "m2", "m3", "m4", "m5"}
	publishTestMsgs := func(t *testing.T, js jetstream.JetStream) {
		for _, msg := range testMsgs {
			if _, err := js.Publish(context.Background(), testSubject, []byte(msg)); err != nil {
				t.Fatalf("Unexpected error during publish: %s", err)
			}
		}
	}

	t.Run("no options", func(t *testing.T) {
		srv := RunBasicJetStreamServer()
		defer shutdownJSServerAndRemoveStorage(t, srv)
		nc, err := nats.Connect(srv.ClientURL())
		if err != nil {
			t.Fatalf("Unexpected error: %v", err)
		}
		defer nc.Close()

		ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
		defer cancel()
		js, err := jetstream.New(nc)
		if err != nil {
			t.Fatalf("Unexpected error: %v", err)
		}

		s, err := js.CreateStream(ctx, jetstream.StreamConfig{Name: "foo", Subjects: []string{"FOO.*"}})
		if err != nil {
			t.Fatalf("Unexpected error: %v", err)
		}
		c, err := s.CreateOrUpdateConsumer(ctx, jetstream.ConsumerConfig{AckPolicy: jetstream.AckExplicitPolicy})
		if err != nil {
			t.Fatalf("Unexpected error: %v", err)
		}

		publishTestMsgs(t, js)
		msgs := make([]jetstream.Msg, 0, len(testMsgs))
		for range testMsgs {
			msg, err := c.Next()
			if err != nil {
				t.Fatalf("Error fetching message: %s", err)
			}
			msgs = append(msgs, msg)
		}
		if len(testMsgs) != len(msgs) {
			t.Fatalf("Invalid number of messages received; want: %d; got: %d", len(testMsgs), len(msgs))
		}
		// Next should hand messages back in publish order.
		for i, msg := range msgs {
			if string(msg.Data()) != testMsgs[i] {
				t.Fatalf("Invalid msg on index %d; expected: %s; got: %s", i, testMsgs[i], string(msg.Data()))
			}
		}
	})

	t.Run("delete consumer while waiting for message", func(t *testing.T) {
		srv := RunBasicJetStreamServer()
		defer shutdownJSServerAndRemoveStorage(t, srv)
		nc, err := nats.Connect(srv.ClientURL())
		if err != nil {
			t.Fatalf("Unexpected error: %v", err)
		}
		defer nc.Close()

		ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
		defer cancel()
		js, err := jetstream.New(nc)
		if err != nil {
			t.Fatalf("Unexpected error: %v", err)
		}

		s, err := js.CreateStream(ctx, jetstream.StreamConfig{Name: "foo", Subjects: []string{"FOO.*"}})
		if err != nil {
			t.Fatalf("Unexpected error: %v", err)
		}
		c, err := s.CreateOrUpdateConsumer(ctx, jetstream.ConsumerConfig{AckPolicy: jetstream.AckExplicitPolicy})
		if err != nil {
			t.Fatalf("Unexpected error: %v", err)
		}

		// Delete the consumer while Next below is blocked waiting for a message.
		// deleted is closed when the AfterFunc goroutine finishes, so the test
		// waits for it instead of sleeping and never returns while it still runs.
		deleted := make(chan struct{})
		time.AfterFunc(100*time.Millisecond, func() {
			defer close(deleted)
			// This runs on its own goroutine, where t.Fatalf is not allowed.
			if err := s.DeleteConsumer(ctx, c.CachedInfo().Name); err != nil {
				t.Errorf("Error deleting consumer: %s", err)
			}
		})

		_, err = c.Next()
		<-deleted
		if !errors.Is(err, jetstream.ErrConsumerDeleted) {
			t.Fatalf("Expected error: %v; got: %v", jetstream.ErrConsumerDeleted, err)
		}
	})

	t.Run("with custom timeout", func(t *testing.T) {
		srv := RunBasicJetStreamServer()
		defer shutdownJSServerAndRemoveStorage(t, srv)
		nc, err := nats.Connect(srv.ClientURL())
		if err != nil {
			t.Fatalf("Unexpected error: %v", err)
		}
		defer nc.Close()

		ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
		defer cancel()
		js, err := jetstream.New(nc)
		if err != nil {
			t.Fatalf("Unexpected error: %v", err)
		}

		s, err := js.CreateStream(ctx, jetstream.StreamConfig{Name: "foo", Subjects: []string{"FOO.*"}})
		if err != nil {
			t.Fatalf("Unexpected error: %v", err)
		}
		c, err := s.CreateOrUpdateConsumer(ctx, jetstream.ConsumerConfig{AckPolicy: jetstream.AckExplicitPolicy})
		if err != nil {
			t.Fatalf("Unexpected error: %v", err)
		}

		// Nothing is published, so Next must give up after FetchMaxWait.
		if _, err := c.Next(jetstream.FetchMaxWait(50 * time.Millisecond)); !errors.Is(err, nats.ErrTimeout) {
			t.Fatalf("Expected timeout; got: %s", err)
		}
	})
}

// TestPullConsumerMessagesNextWithTimeout checks the per-call options of
// MessagesContext.Next: NextMaxWait, NextContext, and the rejection of both
// options supplied together with ErrInvalidOption.
func TestPullConsumerMessagesNextWithTimeout(t *testing.T) {
	t.Run("with timeout option", func(t *testing.T) {
		srv := RunBasicJetStreamServer()
		defer shutdownJSServerAndRemoveStorage(t, srv)
		nc, err := nats.Connect(srv.ClientURL())
		if err != nil {
			t.Fatalf("Unexpected error: %v", err)
		}
		defer nc.Close()

		ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
		defer cancel()
		js, err := jetstream.New(nc)
		if err != nil {
			t.Fatalf("Unexpected error: %v", err)
		}

		s, err := js.CreateStream(ctx, jetstream.StreamConfig{Name: "foo", Subjects: []string{"FOO.*"}})
		if err != nil {
			t.Fatalf("Unexpected error: %v", err)
		}
		c, err := s.CreateOrUpdateConsumer(ctx, jetstream.ConsumerConfig{AckPolicy: jetstream.AckExplicitPolicy})
		if err != nil {
			t.Fatalf("Unexpected error: %v", err)
		}

		msgs, err := c.Messages()
		if err != nil {
			t.Fatalf("Unexpected error: %v", err)
		}
		defer msgs.Stop()

		// no msgs yet, should timeout; the upper bound leaves 100ms of slack
		start := time.Now()
		_, err = msgs.Next(jetstream.NextMaxWait(100 * time.Millisecond))
		elapsed := time.Since(start)
		if !errors.Is(err, nats.ErrTimeout) {
			t.Fatalf("Expected timeout error; got: %v", err)
		}
		if elapsed < 100*time.Millisecond || elapsed > 200*time.Millisecond {
			t.Fatalf("Timeout not respected; elapsed: %v", elapsed)
		}

		// Publish a message and verify it can be fetched
		if _, err := js.Publish(ctx, "FOO.A", []byte("msg1")); err != nil {
			t.Fatalf("Unexpected error during publish: %s", err)
		}

		msg, err := msgs.Next(jetstream.NextMaxWait(1 * time.Second))
		if err != nil {
			t.Fatalf("Expected to receive message, got error: %v", err)
		}
		if string(msg.Data()) != "msg1" {
			t.Fatalf("Unexpected message data; got: %s", msg.Data())
		}
	})

	t.Run("with context option", func(t *testing.T) {
		srv := RunBasicJetStreamServer()
		defer shutdownJSServerAndRemoveStorage(t, srv)
		nc, err := nats.Connect(srv.ClientURL())
		if err != nil {
			t.Fatalf("Unexpected error: %v", err)
		}
		defer nc.Close()

		js, err := jetstream.New(nc)
		if err != nil {
			t.Fatalf("Unexpected error: %v", err)
		}

		s, err := js.CreateStream(context.Background(), jetstream.StreamConfig{Name: "foo", Subjects: []string{"FOO.*"}})
		if err != nil {
			t.Fatalf("Unexpected error: %v", err)
		}
		c, err := s.CreateOrUpdateConsumer(context.Background(), jetstream.ConsumerConfig{AckPolicy: jetstream.AckExplicitPolicy})
		if err != nil {
			t.Fatalf("Unexpected error: %v", err)
		}

		msgs, err := c.Messages()
		if err != nil {
			t.Fatalf("Unexpected error: %v", err)
		}
		defer msgs.Stop()

		// context timeout: Next must return once the context deadline passes
		timeoutCtx, timeoutCancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
		defer timeoutCancel()

		start := time.Now()
		_, err = msgs.Next(jetstream.NextContext(timeoutCtx))
		elapsed := time.Since(start)
		if !errors.Is(err, context.DeadlineExceeded) {
			t.Fatalf("Expected context deadline exceeded error; got: %v", err)
		}
		if elapsed < 100*time.Millisecond || elapsed > 200*time.Millisecond {
			t.Fatalf("Context timeout not respected; elapsed: %v", elapsed)
		}

		// cancel context before calling Next
		canceledCtx, cancelNow := context.WithCancel(context.Background())
		cancelNow()
		_, err = msgs.Next(jetstream.NextContext(canceledCtx))
		if !errors.Is(err, context.Canceled) {
			t.Fatalf("Expected context canceled error; got: %v", err)
		}

		// Publish a message and verify it can be fetched
		if _, err := js.Publish(context.Background(), "FOO.A", []byte("msg1")); err != nil {
			t.Fatalf("Unexpected error during publish: %s", err)
		}

		fetchCtx, fetchCancel := context.WithTimeout(context.Background(), time.Second)
		defer fetchCancel()
		msg, err := msgs.Next(jetstream.NextContext(fetchCtx))
		if err != nil {
			t.Fatalf("Expected to receive message, got error: %v", err)
		}
		if string(msg.Data()) != "msg1" {
			t.Fatalf("Unexpected message data; got: %s", msg.Data())
		}
	})

	t.Run("context and timeout provided", func(t *testing.T) {
		srv := RunBasicJetStreamServer()
		defer shutdownJSServerAndRemoveStorage(t, srv)
		nc, err := nats.Connect(srv.ClientURL())
		if err != nil {
			t.Fatalf("Unexpected error: %v", err)
		}
		defer nc.Close()

		ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
		defer cancel()
		js, err := jetstream.New(nc)
		if err != nil {
			t.Fatalf("Unexpected error: %v", err)
		}

		s, err := js.CreateStream(ctx, jetstream.StreamConfig{Name: "foo", Subjects: []string{"FOO.*"}})
		if err != nil {
			t.Fatalf("Unexpected error: %v", err)
		}
		c, err := s.CreateOrUpdateConsumer(ctx, jetstream.ConsumerConfig{AckPolicy: jetstream.AckExplicitPolicy})
		if err != nil {
			t.Fatalf("Unexpected error: %v", err)
		}

		msgs, err := c.Messages()
		if err != nil {
			t.Fatalf("Unexpected error: %v", err)
		}
		defer msgs.Stop()

		// Test that providing both NextMaxWait and NextContext returns an error
		testCtx, testCancel := context.WithTimeout(context.Background(), time.Second)
		defer testCancel()

		_, err = msgs.Next(jetstream.NextMaxWait(500*time.Millisecond), jetstream.NextContext(testCtx))
		if err == nil {
			t.Fatal("Expected error when providing both NextMaxWait and NextContext")
		}
		if !errors.Is(err, jetstream.ErrInvalidOption) {
			t.Fatalf("Expected ErrInvalidOption, got: %v", err)
		}
	})
}

// TestPullConsumerConnectionClosed verifies that explicitly closing the
// connection unblocks a pending MessagesContext.Next and stops Consume.
// Both are expected to surface ErrConnectionClosed rather than hang.
func TestPullConsumerConnectionClosed(t *testing.T) {
	t.Run("messages", func(t *testing.T) {
		srv := RunBasicJetStreamServer()
		defer shutdownJSServerAndRemoveStorage(t, srv)
		nc, err := nats.Connect(srv.ClientURL())
		if err != nil {
			t.Fatalf("Unexpected error: %v", err)
		}
		// Releases the connection if setup fails before the explicit nc.Close()
		// below. This relies on Close being harmless on an already-closed conn.
		defer nc.Close()

		ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
		defer cancel()
		js, err := jetstream.New(nc)
		if err != nil {
			t.Fatalf("Unexpected error: %v", err)
		}

		stream, err := js.CreateStream(ctx, jetstream.StreamConfig{
			Name:     "test-stream",
			Subjects: []string{"test.>"},
		})
		if err != nil {
			t.Fatalf("Unexpected error: %v", err)
		}

		consumer, err := stream.CreateConsumer(ctx, jetstream.ConsumerConfig{
			Name: "test-consumer",
		})
		if err != nil {
			t.Fatalf("Unexpected error: %v", err)
		}

		msgs, err := consumer.Messages()
		if err != nil {
			t.Fatalf("Unexpected error: %v", err)
		}

		// Buffered so the goroutine can always deliver its result and exit,
		// even if the test has already failed on the timeout branch.
		errC := make(chan error, 1)

		go func() {
			_, err := msgs.Next()
			errC <- err
		}()

		// Give Next time to block before pulling the connection out from under it.
		time.Sleep(100 * time.Millisecond)
		nc.Close()

		select {
		case err := <-errC:
			if !errors.Is(err, jetstream.ErrMsgIteratorClosed) {
				t.Fatalf("Expected error to contain ErrMsgIteratorClosed, got: %v", err)
			}
			if !errors.Is(err, jetstream.ErrConnectionClosed) {
				t.Fatalf("Expected error to contain ErrConnectionClosed, got: %v", err)
			}
		case <-time.After(10 * time.Second):
			t.Fatal("Next() hung indefinitely after connection closed")
		}
	})

	t.Run("consume", func(t *testing.T) {
		srv := RunBasicJetStreamServer()
		defer shutdownJSServerAndRemoveStorage(t, srv)
		nc, err := nats.Connect(srv.ClientURL())
		if err != nil {
			t.Fatalf("Unexpected error: %v", err)
		}
		// See the "messages" subtest: covers early-failure paths only.
		defer nc.Close()

		ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
		defer cancel()
		js, err := jetstream.New(nc)
		if err != nil {
			t.Fatalf("Unexpected error: %v", err)
		}

		stream, err := js.CreateStream(ctx, jetstream.StreamConfig{
			Name:     "test-stream",
			Subjects: []string{"test.>"},
		})
		if err != nil {
			t.Fatalf("Unexpected error: %v", err)
		}

		consumer, err := stream.CreateConsumer(ctx, jetstream.ConsumerConfig{
			Name: "test-consumer",
		})
		if err != nil {
			t.Fatalf("Unexpected error: %v", err)
		}

		errC := make(chan error, 1)

		consumeCtx, err := consumer.Consume(func(msg jetstream.Msg) {
		}, jetstream.ConsumeErrHandler(func(cc jetstream.ConsumeContext, err error) {
			errC <- err
		}))
		if err != nil {
			t.Fatalf("Unexpected error: %v", err)
		}
		defer consumeCtx.Stop()

		time.Sleep(100 * time.Millisecond)
		nc.Close()

		// The error handler must report ErrConnectionClosed, and the consume
		// context must then be closed.
		select {
		case err := <-errC:
			if !errors.Is(err, jetstream.ErrConnectionClosed) {
				t.Fatalf("Expected ErrConnectionClosed, got: %v", err)
			}
			select {
			case <-consumeCtx.Closed():
			case <-time.After(3 * time.Second):
				t.Fatal("Received error but Consume context was not closed")
			}
		case <-time.After(3 * time.Second):
			t.Fatal("Consume did not return error after connection closed")
		}
	})
}

// TestPullConsumerMaxReconnectsExceeded shuts the server down under a client
// limited to 3 reconnect attempts. It verifies that MessagesContext.Next and
// Consume terminate with ErrConnectionClosed once those attempts are exhausted.
func TestPullConsumerMaxReconnectsExceeded(t *testing.T) {
	t.Run("messages", func(t *testing.T) {
		srv := RunBasicJetStreamServer()
		// NOTE(review): the explicit shutdown below means this deferred call
		// runs a second time at test end; assumes the helper tolerates that.
		defer shutdownJSServerAndRemoveStorage(t, srv)

		nc, err := nats.Connect(srv.ClientURL(),
			nats.MaxReconnects(3),
			nats.ReconnectWait(100*time.Millisecond),
		)
		if err != nil {
			t.Fatalf("Unexpected error: %v", err)
		}
		// Releases the connection if setup fails before the server shutdown
		// below; once reconnects are exhausted the conn is already closed.
		defer nc.Close()

		ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
		defer cancel()
		js, err := jetstream.New(nc)
		if err != nil {
			t.Fatalf("Unexpected error: %v", err)
		}

		stream, err := js.CreateStream(ctx, jetstream.StreamConfig{
			Name:     "test-stream",
			Subjects: []string{"test.>"},
		})
		if err != nil {
			t.Fatalf("Unexpected error: %v", err)
		}

		consumer, err := stream.CreateConsumer(ctx, jetstream.ConsumerConfig{
			Name: "test-consumer",
		})
		if err != nil {
			t.Fatalf("Unexpected error: %v", err)
		}

		msgs, err := consumer.Messages()
		if err != nil {
			t.Fatalf("Unexpected error: %v", err)
		}

		// Buffered so the goroutine can always deliver its result and exit.
		errC := make(chan error, 1)

		go func() {
			_, err := msgs.Next()
			errC <- err
		}()

		time.Sleep(100 * time.Millisecond)
		shutdownJSServerAndRemoveStorage(t, srv)

		select {
		case err := <-errC:
			if !errors.Is(err, jetstream.ErrMsgIteratorClosed) {
				t.Fatalf("Expected error to contain ErrMsgIteratorClosed, got: %v", err)
			}
			if !errors.Is(err, jetstream.ErrConnectionClosed) {
				t.Fatalf("Expected error to contain ErrConnectionClosed, got: %v", err)
			}
		case <-time.After(15 * time.Second):
			t.Fatal("Next() hung after reconnection attempts exhausted")
		}
	})

	t.Run("consume", func(t *testing.T) {
		srv := RunBasicJetStreamServer()
		// NOTE(review): runs again after the explicit shutdown below; assumes
		// the helper tolerates being called twice.
		defer shutdownJSServerAndRemoveStorage(t, srv)

		nc, err := nats.Connect(srv.ClientURL(),
			nats.MaxReconnects(3),
			nats.ReconnectWait(100*time.Millisecond),
		)
		if err != nil {
			t.Fatalf("Unexpected error: %v", err)
		}
		// Covers early-failure paths only; see the "messages" subtest.
		defer nc.Close()

		ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
		defer cancel()
		js, err := jetstream.New(nc)
		if err != nil {
			t.Fatalf("Unexpected error: %v", err)
		}

		stream, err := js.CreateStream(ctx, jetstream.StreamConfig{
			Name:     "test-stream",
			Subjects: []string{"test.>"},
		})
		if err != nil {
			t.Fatalf("Unexpected error: %v", err)
		}

		consumer, err := stream.CreateConsumer(ctx, jetstream.ConsumerConfig{
			Name: "test-consumer",
		})
		if err != nil {
			t.Fatalf("Unexpected error: %v", err)
		}

		errC := make(chan error, 1)

		consumeCtx, err := consumer.Consume(func(msg jetstream.Msg) {
		}, jetstream.ConsumeErrHandler(func(cc jetstream.ConsumeContext, err error) {
			errC <- err
		}))
		if err != nil {
			t.Fatalf("Unexpected error: %v", err)
		}
		defer consumeCtx.Stop()

		time.Sleep(100 * time.Millisecond)
		shutdownJSServerAndRemoveStorage(t, srv)

		// first, we should receive Server Shutdown error from server
		select {
		case err := <-errC:
			if !errors.Is(err, jetstream.ErrServerShutdown) {
				t.Fatalf("Expected error to contain ErrServerShutdown, got: %v", err)
			}
			// consume context should not be closed yet because client tries to reconnect
			select {
			case <-consumeCtx.Closed():
				t.Fatalf("Consume context should not be closed after server shutdown error")
			case <-time.After(100 * time.Millisecond):
			}
		case <-time.After(3 * time.Second):
			t.Fatal("Consume did not return error after server shutdown")
		}

		// now we should receive connection closed error after all reconnection attempts exhausted
		// and consume context should be closed
		select {
		case err := <-errC:
			if !errors.Is(err, jetstream.ErrConnectionClosed) {
				t.Fatalf("Expected ErrConnectionClosed, got: %v", err)
			}
			select {
			case <-consumeCtx.Closed():
			case <-time.After(3 * time.Second):
				t.Fatal("Received error but Consume context was not closed")
			}
		case <-time.After(3 * time.Second):
			t.Fatal("Consume did not return error after connection closed")
		}
	})
}
