diff --git a/weed/messaging/msgclient/sub_chan.go b/weed/messaging/msgclient/sub_chan.go index 3eabc6210..213ff4666 100644 --- a/weed/messaging/msgclient/sub_chan.go +++ b/weed/messaging/msgclient/sub_chan.go @@ -1,6 +1,7 @@ package msgclient import ( + "context" "crypto/md5" "hash" "io" @@ -15,6 +16,7 @@ type SubChannel struct { ch chan []byte stream messaging_pb.SeaweedMessaging_SubscribeClient md5hash hash.Hash + cancel context.CancelFunc } func (mc *MessagingClient) NewSubChannel(subscriberId, chanName string) (*SubChannel, error) { @@ -27,7 +29,8 @@ func (mc *MessagingClient) NewSubChannel(subscriberId, chanName string) (*SubCha if err != nil { return nil, err } - sc, err := setupSubscriberClient(grpcConnection, tp, subscriberId, time.Unix(0, 0)) + ctx, cancel := context.WithCancel(context.Background()) + sc, err := setupSubscriberClient(ctx, grpcConnection, tp, subscriberId, time.Unix(0, 0)) if err != nil { return nil, err } @@ -36,6 +39,7 @@ func (mc *MessagingClient) NewSubChannel(subscriberId, chanName string) (*SubCha ch: make(chan []byte), stream: sc, md5hash: md5.New(), + cancel: cancel, } go func() { @@ -57,6 +61,7 @@ func (mc *MessagingClient) NewSubChannel(subscriberId, chanName string) (*SubCha IsClose: true, }) close(t.ch) + cancel() return } t.ch <- resp.Data.Value @@ -74,3 +79,7 @@ func (sc *SubChannel) Channel() chan []byte { func (sc *SubChannel) Md5() []byte { return sc.md5hash.Sum(nil) } + +func (sc *SubChannel) Cancel() { + sc.cancel() +} diff --git a/weed/messaging/msgclient/subscriber.go b/weed/messaging/msgclient/subscriber.go index f96bba2ec..926e193dd 100644 --- a/weed/messaging/msgclient/subscriber.go +++ b/weed/messaging/msgclient/subscriber.go @@ -12,6 +12,7 @@ import ( type Subscriber struct { subscriberClients []messaging_pb.SeaweedMessaging_SubscribeClient + subscriberCancels []context.CancelFunc subscriberId string } @@ -21,6 +22,7 @@ func (mc *MessagingClient) NewSubscriber(subscriberId, namespace, topic string, PartitionCount: 4, } subscriberClients := make([]messaging_pb.SeaweedMessaging_SubscribeClient, topicConfiguration.PartitionCount) + subscriberCancels := make([]context.CancelFunc, topicConfiguration.PartitionCount) for i := 0; i < int(topicConfiguration.PartitionCount); i++ { if partitionId>=0 && i != partitionId { @@ -35,21 +37,24 @@ func (mc *MessagingClient) NewSubscriber(subscriberId, namespace, topic string, if err != nil { return nil, err } - client, err := setupSubscriberClient(grpcClientConn, tp, subscriberId, startTime) + ctx, cancel := context.WithCancel(context.Background()) + client, err := setupSubscriberClient(ctx, grpcClientConn, tp, subscriberId, startTime) if err != nil { return nil, err } subscriberClients[i] = client + subscriberCancels[i] = cancel } return &Subscriber{ subscriberClients: subscriberClients, + subscriberCancels: subscriberCancels, subscriberId: subscriberId, }, nil } -func setupSubscriberClient(grpcConnection *grpc.ClientConn, tp broker.TopicPartition, subscriberId string, startTime time.Time) (stream messaging_pb.SeaweedMessaging_SubscribeClient, err error) { - stream, err = messaging_pb.NewSeaweedMessagingClient(grpcConnection).Subscribe(context.Background()) +func setupSubscriberClient(ctx context.Context, grpcConnection *grpc.ClientConn, tp broker.TopicPartition, subscriberId string, startTime time.Time) (stream messaging_pb.SeaweedMessaging_SubscribeClient, err error) { + stream, err = messaging_pb.NewSeaweedMessagingClient(grpcConnection).Subscribe(ctx) if err != nil { return } @@ -98,3 +103,11 @@ func (s *Subscriber) Subscribe(processFn func(m *messaging_pb.Message)) { } } } + +func (s *Subscriber) Shutdown() { + for i := 0; i < len(s.subscriberClients); i++ { + if s.subscriberCancels[i] != nil { + s.subscriberCancels[i]() + } + } +}