A Simple Implementation of Concurrent TCP

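In the earlier code, a single loop accepted connections, and that code could serve only one connection at a time. After accepting a connection it went straight into the communication loop for that socket, and only after that conversation ended did it return to the outer loop and block again while waiting for the next connection.
Clearly this cannot handle concurrent clients, yet real-world servers on the Internet all serve many clients at once. In Python, a convenient way to make a TCP server concurrent is the standard-library socketserver module.
Let's first write a simple concurrent server and then look at how it works. The idea is easy to guess: whenever a connection arrives, instantiate an object that handles the communication for it; when another connection arrives, instantiate another one.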
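The same idea can be sketched by hand with just the socket and threading modules. One loop accepts connections, and each accepted connection gets its own thread running the communication loop. This is a minimal illustration under stated assumptions, not socketserver's code: the helper names handle_conn, serve and demo are hypothetical, the protocol is the upper-casing echo used in this article, and serve stops after a fixed number of connections so the demo can finish.

```python
import socket
import threading


def handle_conn(conn, addr):
    """Communication loop for one connection: echo data back in upper case."""
    with conn:
        while True:
            data = conn.recv(1024)
            if not data:  # b"" means the client closed the connection
                break
            conn.sendall(data.upper())


def serve(listener, max_conns):
    """Connection loop: accept, then give each connection its own thread."""
    for _ in range(max_conns):
        conn, addr = listener.accept()
        threading.Thread(target=handle_conn, args=(conn, addr), daemon=True).start()


def demo():
    listener = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    listener.bind(("127.0.0.1", 0))  # port 0: let the OS pick a free port
    listener.listen(5)
    threading.Thread(target=serve, args=(listener, 2), daemon=True).start()
    addr = listener.getsockname()
    with socket.create_connection(addr, timeout=5) as c1, \
         socket.create_connection(addr, timeout=5) as c2:
        c2.sendall(b"second")  # c1 is connected but idle; c2 is still served
        r2 = c2.recv(1024)
        c1.sendall(b"first")
        r1 = c1.recv(1024)
    listener.close()
    return r1, r2


print(demo())  # expected: (b'FIRST', b'SECOND')
```

With a single-threaded accept loop like the earlier code, c2.recv() would block until c1 disconnected; here both clients are answered while both connections are open.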

# server
import socketserver

buffer_size = 1024
ip_port = ('127.0.0.1', 8080,)


class MyServer(socketserver.BaseRequestHandler):
    def handle(self):  # the communication loop for receiving and sending; a connection loop is still needed
        print('conn is:', self.request)  # conn
        print('address is:', self.client_address)  # address

        while True:
            try:
                # receive a message
                data = self.request.recv(buffer_size)
                if not data:
                    break
                print('received:', data.decode('utf-8'))
                # send a reply
                self.request.sendall(data.upper())
            except Exception as e:
                print(e)
                break

if __name__ == '__main__':
    s = socketserver.ThreadingTCPServer(ip_port, MyServer)  # the server object
    s.serve_forever()   # the server object loops forever; this is the connection loop
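To check that this server really serves several clients at once without writing a separate client program, the following sketch runs the server in a background thread and talks to it from two clients in the same script. It assumes port 0 (the OS picks a free port) instead of 8080, and a hypothetical UpperHandler that repeats MyServer's loop without the prints; using the server in a with statement relies on Python 3.6 or later.

```python
import socket
import socketserver
import threading


class UpperHandler(socketserver.BaseRequestHandler):
    """Same logic as MyServer, minus the prints."""
    def handle(self):
        while True:
            data = self.request.recv(1024)
            if not data:
                break
            self.request.sendall(data.upper())


def run_check():
    # Port 0 asks the OS for a free port instead of the fixed 8080.
    with socketserver.ThreadingTCPServer(("127.0.0.1", 0), UpperHandler) as s:
        t = threading.Thread(target=s.serve_forever, daemon=True)
        t.start()
        c1 = socket.create_connection(s.server_address, timeout=5)
        c2 = socket.create_connection(s.server_address, timeout=5)
        c2.sendall(b"hello")          # c1 is still connected but idle
        reply2 = c2.recv(1024)
        c1.sendall(b"world")
        reply1 = c1.recv(1024)
        c1.close()
        c2.close()
        s.shutdown()                  # stop serve_forever from another thread
    return reply1, reply2


print(run_check())  # expected: (b'WORLD', b'HELLO')
```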

A server like this can accept multiple TCP connections at the same time. So how does socketserver work? Let's walk through its workflow:

socketserver source analysis: the TCP concurrency part

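When you want to understand how a module is structured, you can Ctrl-click the module name in PyCharm to open its source code, then use PyCharm's UML class diagram feature to see how the file is organized.
To open it: right-click in the current file -> Diagrams -> Show Diagrams -> Python Class Diagrams. Click the "m" icon at the top of the diagram view to show methods, and the "f" icon to show fields.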
The class diagram of socketserver is shown below:
[Figure: socketserver class diagram]
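The diagram shows the overall design:
There is a BaseServer, from which TCPServer is derived. UDPServer in turn inherits from TCPServer, but it can be treated as a main branch of its own.
Then come four concrete servers, ThreadingTCPServer, ThreadingUDPServer, ThreadingUnixStreamServer and ThreadingUnixDatagramServer. As the names suggest, they are multithreaded TCP and UDP network servers, plus multithreaded stream and datagram servers over Unix domain sockets. Every multithreaded class (the ones with Threading in the name) inherits from the ThreadingMixIn class.
There are also ForkingTCPServer and ForkingUDPServer, which are multiprocess servers and inherit from the ForkingMixIn class.
All of these server classes ultimately inherit from the TCP and UDP server classes. The relationships are not very complicated: once you know which kind of server you need, you instantiate the corresponding class.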
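The hierarchy can also be confirmed from Python itself: a class's __mro__ lists the classes in the order attribute lookup searches them. A small sketch, where mro_names is a hypothetical helper; the Unix-domain variants are only checked when the platform defines them.

```python
import socketserver as ss


def mro_names(cls):
    """Class names in the order Python searches them for attributes."""
    return [c.__name__ for c in cls.__mro__]


print(mro_names(ss.ThreadingTCPServer))
# ['ThreadingTCPServer', 'ThreadingMixIn', 'TCPServer', 'BaseServer', 'object']
print(mro_names(ss.ThreadingUDPServer))
# ['ThreadingUDPServer', 'ThreadingMixIn', 'UDPServer', 'TCPServer', 'BaseServer', 'object']
print(issubclass(ss.UDPServer, ss.TCPServer))  # True

# The Unix-domain variants are only defined where AF_UNIX is available.
if hasattr(ss, "ThreadingUnixStreamServer"):
    print(mro_names(ss.ThreadingUnixStreamServer)[:4])
    # ['ThreadingUnixStreamServer', 'ThreadingMixIn', 'UnixStreamServer', 'TCPServer']
```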

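These server classes correspond to the part of our hand-written server that produces connections (the connection loop). The part that actually handles the communication consists of one more class and two subclasses of it.
DatagramRequestHandler and StreamRequestHandler, as their names suggest, are for UDP and TCP communication respectively, and both inherit from BaseRequestHandler.
Now let's analyze our code.
MyServer(socketserver.BaseRequestHandler) shows that MyServer inherits from BaseRequestHandler. Here is that class's code: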

class BaseRequestHandler:

    """Base class for request handler classes.

    This class is instantiated for each request to be handled.  The
    constructor sets the instance variables request, client_address
    and server, and then calls the handle() method.  To implement a
    specific service, all you need to do is to derive a class which
    defines a handle() method.

    The handle() method can find the request as self.request, the
    client address as self.client_address, and the server (in case it
    needs access to per-server information) as self.server.  Since a
    separate instance is created for each request, the handle() method
    can define other arbitrary instance variables.

    """

    def __init__(self, request, client_address, server):
        self.request = request
        self.client_address = client_address
        self.server = server
        self.setup()
        try:
            self.handle()
        finally:
            self.finish()

    def setup(self):
        pass

    def handle(self):
        pass

    def finish(self):
        pass

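The class diagram shows two classes inheriting from this one. Opening them reveals that neither implements handle() (they only override setup() and finish()), which is why we have to write handle() ourselves. So what exactly are the three arguments request, client_address and server that are passed in when BaseRequestHandler is initialized?
Judging from the receive and send statements we wrote ourselves, request is very likely a socket object.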
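Both points can be checked without any network at all. Because BaseRequestHandler.__init__ merely stores its three arguments and then calls setup(), handle() and finish(), placeholder strings are enough to drive it. The sketch below uses a hypothetical TracingHandler whose handle() raises on purpose, to show that finish() still runs thanks to try/finally, and then checks which hooks the two built-in subclasses define.

```python
import socketserver

calls = []


class TracingHandler(socketserver.BaseRequestHandler):
    """Hypothetical handler that records which hooks run and with what."""

    def setup(self):
        calls.append("setup")

    def handle(self):
        calls.append(("handle", self.request, self.client_address, self.server))
        raise RuntimeError("simulated failure inside handle()")

    def finish(self):
        calls.append("finish")


# BaseRequestHandler only stores its arguments, so placeholders are enough here.
try:
    TracingHandler("fake-request", ("127.0.0.1", 50000), "fake-server")
except RuntimeError:
    pass

print(calls)
# ['setup', ('handle', 'fake-request', ('127.0.0.1', 50000), 'fake-server'), 'finish']

# Neither built-in subclass supplies handle(); both customise setup()/finish().
for cls in (socketserver.StreamRequestHandler, socketserver.DatagramRequestHandler):
    print(cls.__name__, "handle" in vars(cls), "setup" in vars(cls))
# StreamRequestHandler False True
# DatagramRequestHandler False True
```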

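Next, look at the line s = socketserver.ThreadingTCPServer(ip_port, MyServer).
As the name suggests, it instantiates a multithreaded TCP server object, passing ip_port and MyServer, the message-handling class we just defined. Since this object deals with connections, a reasonable guess is that it instantiates MyServer to handle the communication over the socket produced by each new connection.
Looking at the ThreadingTCPServer class, it contains only one line: class ThreadingTCPServer(ThreadingMixIn, TCPServer): pass. The body is empty, so everything is inherited from ThreadingMixIn and TCPServer. Searching the two parent classes for an __init__ method, ThreadingMixIn looks like this: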

class ThreadingMixIn:
    """Mix-in class to handle each request in a new thread."""

    # Decides how threads will act upon termination of the
    # main process
    daemon_threads = False

    def process_request_thread(self, request, client_address):
        """Same as in BaseServer but as a thread.

        In addition, exception handling is done here.

        """
        try:
            self.finish_request(request, client_address)
        except Exception:
            self.handle_error(request, client_address)
        finally:
            self.shutdown_request(request)

    def process_request(self, request, client_address):
        """Start a new thread to process the request."""
        t = threading.Thread(target = self.process_request_thread,
                             args = (request, client_address))
        t.daemon = self.daemon_threads
        t.start()

ThreadingMixIn has no __init__. Moving on to TCPServer, we find its __init__ method, and TCPServer inherits from BaseServer:

class TCPServer(BaseServer):
    address_family = socket.AF_INET  # address family

    socket_type = socket.SOCK_STREAM # socket type: TCP

    request_queue_size = 5 # backlog passed to listen()

    allow_reuse_address = False # whether SO_REUSEADDR is set before bind()

    def __init__(self, server_address, RequestHandlerClass, bind_and_activate=True):
        """Constructor.  May be extended, do not override."""
        BaseServer.__init__(self, server_address, RequestHandlerClass)
        self.socket = socket.socket(self.address_family,
                                    self.socket_type)
        if bind_and_activate:
            try:
                self.server_bind()
                self.server_activate()
            except:
                self.server_close()
                raise

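The class attributes contain some familiar things. The socketserver module imports socket at the top, so these are names we have used before: the address family is AF_INET, the socket type is SOCK_STREAM (TCP), request_queue_size is the listen() backlog, and allow_reuse_address controls the address-reuse option we met earlier.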
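Since these are class attributes read by server_bind() and server_activate(), a subclass can change them without touching __init__. A minimal sketch with a hypothetical ReusableTCPServer; the socket family, type and SO_REUSEADDR flag are read back with standard socket methods. The exact value getsockopt() returns is platform dependent, so only whether it is nonzero matters here.

```python
import socket
import socketserver


class ReusableTCPServer(socketserver.ThreadingTCPServer):
    """Hypothetical subclass: change class attributes, leave __init__ alone."""
    allow_reuse_address = True   # read by server_bind() before bind()
    request_queue_size = 16      # read by server_activate() for listen()


def inspect_server():
    with ReusableTCPServer(("127.0.0.1", 0), socketserver.BaseRequestHandler) as srv:
        sock = srv.socket
        return (sock.family == socket.AF_INET,
                sock.type == socket.SOCK_STREAM,
                sock.getsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR))


print(inspect_server())
```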
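__init__ receives, in order, the instance, the server address, the RequestHandlerClass, and a parameter with a default value, bind_and_activate=True.
The instance is simply s itself. server_address receives the ip_port tuple, an (IP, port) tuple just like in our earlier programs. MyServer is our subclass of BaseRequestHandler, which fits what the RequestHandlerClass parameter expects (a request handler class). This constructor in turn calls BaseServer.__init__, so let's follow it: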

class BaseServer:
    def __init__(self, server_address, RequestHandlerClass):
        """Constructor.  May be extended, do not override."""
        self.server_address = server_address
        self.RequestHandlerClass = RequestHandlerClass
        self.__is_shut_down = threading.Event()
        self.__shutdown_request = False

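Now it is clear. A few more attributes appear, this time instance attributes: server_address is the (IP, port) we passed in, and RequestHandlerClass refers to the class we passed in. Ignore the rest for now and jump back to the next line of TCPServer.__init__, where something familiar appears:
self.socket = socket.socket(self.address_family,self.socket_type)
The instance attribute socket is a newly created socket object, just like the ones we created ourselves.
Right after that, __init__ calls self.server_bind() and self.server_activate().
Following them into TCPServer: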

class TCPServer(BaseServer):
    def server_bind(self):
        """Called by constructor to bind the socket.

        May be overridden.

        """
        if self.allow_reuse_address:
            self.socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
        self.socket.bind(self.server_address)
        self.server_address = self.socket.getsockname()

    def server_activate(self):
        """Called by constructor to activate the server.

        May be overridden.

        """
        self.socket.listen(self.request_queue_size)

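Familiar statements again. server_bind() does with our ip_port exactly what our own code did with socket.bind(), then updates server_address with the result of getsockname(). In other words, this step binds the IP and port; the update matters when port 0 is requested, because getsockname() reports the port the OS actually picked.
server_activate() then does something else we have written ourselves: listen.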
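The bind_and_activate parameter lets these two steps be run by hand, which makes the getsockname() update visible. The sketch below assumes port 0, so the OS assigns the port during bind(); bind_step_by_step is a hypothetical helper name.

```python
import socketserver


def bind_step_by_step():
    # bind_and_activate=False: __init__ creates the socket but does not bind/listen.
    srv = socketserver.TCPServer(("127.0.0.1", 0), socketserver.BaseRequestHandler,
                                 bind_and_activate=False)
    before = srv.server_address          # still exactly what we passed in
    srv.allow_reuse_address = True       # instance attribute shadows the class one
    try:
        srv.server_bind()                # setsockopt (if enabled), bind, getsockname
        after = srv.server_address       # now holds the port the OS assigned
        srv.server_activate()            # listen(self.request_queue_size)
    finally:
        srv.server_close()
    return before, after


before, after = bind_step_by_step()
print(before)         # ('127.0.0.1', 0)
print(after[1] != 0)  # True: port 0 was replaced by a real port
```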
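Now look again at s = socketserver.ThreadingTCPServer(ip_port, MyServer).
After this statement, s has several attributes. One of them, socket, is a socket object built from our ip_port argument and already in the listening state (listen() itself does not block; the blocking call is accept()). But what role does the MyServer class we passed in play?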
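On to the next statement:
s.serve_forever(). To find the serve_forever method, we look in ThreadingTCPServer (not there), then ThreadingMixIn (still not there), then TCPServer (not there either), and finally find it in BaseServer: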

# s.serve_forever()
class BaseServer:
    def serve_forever(self, poll_interval=0.5):
        """Handle one request at a time until shutdown.

        Polls for shutdown every poll_interval seconds. Ignores
        self.timeout. If you need to do periodic tasks, do them in
        another thread.
        """
        self.__is_shut_down.clear()
        try:
            # XXX: Consider using another file descriptor or connecting to the
            # socket to wake this up instead of polling. Polling reduces our
            # responsiveness to a shutdown request and wastes cpu at all other
            # times.
            with _ServerSelector() as selector:
                selector.register(self, selectors.EVENT_READ)

                while not self.__shutdown_request:
                    ready = selector.select(poll_interval)
                    if ready:
                        self._handle_request_noblock()

                    self.service_actions()
        finally:
            self.__shutdown_request = False
            self.__is_shut_down.set()
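Setting the shutdown machinery aside, this loop has a simple shape: wait up to poll_interval for the listening socket to become readable, accept only then, and re-check the stop flag. The following is a simplified sketch of that shape, not CPython's code. It uses selectors.DefaultSelector where socketserver uses its internal _ServerSelector, a threading.Event instead of the private __shutdown_request flag, and the hypothetical names accept_loop, demo and record.

```python
import selectors
import socket
import threading


def accept_loop(listener, stop, on_conn, poll_interval=0.1):
    """Simplified shape of BaseServer.serve_forever()."""
    with selectors.DefaultSelector() as sel:
        sel.register(listener, selectors.EVENT_READ)
        while not stop.is_set():                 # like `while not self.__shutdown_request`
            if sel.select(poll_interval):        # readable => a connection is waiting
                conn, addr = listener.accept()   # will not block now
                on_conn(conn, addr)


def demo():
    listener = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    listener.bind(("127.0.0.1", 0))
    listener.listen(5)
    stop = threading.Event()
    seen = []

    def record(conn, addr):                      # hypothetical per-connection callback
        seen.append(addr)
        conn.close()

    loop = threading.Thread(target=accept_loop, args=(listener, stop, record))
    loop.start()
    with socket.create_connection(listener.getsockname(), timeout=5) as client:
        client_addr = client.getsockname()
        client.recv(1)                           # returns b"" once the server closes
    stop.set()                                   # noticed within poll_interval
    loop.join()
    listener.close()
    return seen, client_addr


seen, client_addr = demo()
print(seen == [client_addr])  # True
```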

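Setting aside the selector-related statements for now, look at the while loop. As long as no shutdown has been requested, it keeps looping: it waits up to poll_interval seconds for the listening socket to become readable, and whenever it is, calls _handle_request_noblock(). Searching for that method, neither ThreadingMixIn nor TCPServer has it; once again it is in BaseServer: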

# _handle_request_noblock()
    def _handle_request_noblock(self):
        """Handle one request, without blocking.

        I assume that selector.select() has returned that the socket is
        readable before this function was called, so there should be no risk of
        blocking in get_request().
        """
        try:
            request, client_address = self.get_request()
        except OSError:
            return
        if self.verify_request(request, client_address):
            try:
                self.process_request(request, client_address)
            except Exception:
                self.handle_error(request, client_address)
                self.shutdown_request(request)
            except:
                self.shutdown_request(request)
                raise
        else:
            self.shutdown_request(request)
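Every one of these searches follows ThreadingTCPServer.__mro__. A hypothetical helper, defining_class, can report which class in the MRO actually defines a method; the list covers the methods this walk-through has looked up so far as well as the ones it looks up next.

```python
import socketserver as ss


def defining_class(cls, name):
    """Name of the first class in cls.__mro__ whose own body defines `name`."""
    for klass in cls.__mro__:
        if name in vars(klass):
            return klass.__name__
    return None


for meth in ("__init__", "serve_forever", "_handle_request_noblock",
             "get_request", "process_request", "finish_request"):
    print(f"{meth} -> {defining_class(ss.ThreadingTCPServer, meth)}")
```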

Its first statement is request, client_address = self.get_request(). Looking for get_request(), this time we find it in TCPServer:

    def get_request(self):
        """Get the request and client address from the socket.

        May be overridden.

        """
        return self.socket.accept()

So this is just socket.accept(). The request obtained by request, client_address = self.get_request() is the connected socket produced after the three-way handshake, and client_address is, literally, the client's address.
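The same pair can be obtained with plain sockets. This sketch (accept_once is a hypothetical name) connects a client to a listening socket, calls accept() exactly as get_request() does, and checks that request is a connected socket object and client_address is the client's own (IP, port).

```python
import socket


def accept_once():
    listener = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    listener.bind(("127.0.0.1", 0))
    listener.listen(5)
    client = socket.create_connection(listener.getsockname(), timeout=5)
    request, client_address = listener.accept()   # what get_request() returns
    request.sendall(b"hi")                         # request is a connected socket
    reply = client.recv(2)
    result = (isinstance(request, socket.socket),
              client_address == client.getsockname(),
              reply)
    for s in (request, client, listener):
        s.close()
    return result


print(accept_once())  # expected: (True, True, b'hi')
```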

Back in _handle_request_noblock(), we now know what request contains.
Skipping the self.verify_request check, look at the line self.process_request(request, client_address). Searching for this method, we finally find it in ThreadingMixIn:

class ThreadingMixIn:
    daemon_threads = False

    def process_request_thread(self, request, client_address):
        """Same as in BaseServer but as a thread.

        In addition, exception handling is done here.

        """
        try:
            self.finish_request(request, client_address)
        except Exception:
            self.handle_error(request, client_address)
        finally:
            self.shutdown_request(request)


    def process_request(self, request, client_address):
        """Start a new thread to process the request."""
        t = threading.Thread(target = self.process_request_thread,
                             args = (request, client_address))
        t.daemon = self.daemon_threads
        t.start()
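process_request starts one threading.Thread per accepted request, so handler code runs outside the thread that runs serve_forever. A sketch to observe this, with the hypothetical names WhichThread and DaemonThreadingTCPServer: it records threading.get_ident() inside handle() and compares it with the main thread and the serve_forever thread. daemon_threads is the class attribute shown in the ThreadingMixIn code.

```python
import socket
import socketserver
import threading

handler_idents = []


class WhichThread(socketserver.BaseRequestHandler):
    """Hypothetical handler that records the thread it runs in."""
    def handle(self):
        handler_idents.append(threading.get_ident())
        self.request.sendall(b"ok")


class DaemonThreadingTCPServer(socketserver.ThreadingTCPServer):
    daemon_threads = True  # handler threads will not keep the process alive


def run():
    with DaemonThreadingTCPServer(("127.0.0.1", 0), WhichThread) as srv:
        loop = threading.Thread(target=srv.serve_forever)
        loop.start()
        for _ in range(2):
            with socket.create_connection(srv.server_address, timeout=5) as c:
                c.recv(2)  # wait for the handler's reply
        srv.shutdown()
        loop.join()
    return loop.ident


loop_ident = run()
main_ident = threading.get_ident()
print(all(i not in (main_ident, loop_ident) for i in handler_idents))  # True
```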

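This is the first key step. Even though we haven't studied threading yet, we can see that ThreadingMixIn's process_request method uses threading.Thread to start a new thread that runs self.process_request_thread (also in ThreadingMixIn), passing it request (the TCP connection established by the three-way handshake) and client_address. Inside process_request_thread(self, request, client_address), which from here on runs in the new thread, there is a call to self.finish_request(request, client_address). Searching again, we find it in BaseServer: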

class BaseServer:
    def finish_request(self, request, client_address):
        """Finish one request by instantiating RequestHandlerClass."""
        self.RequestHandlerClass(request, client_address, self)

After all that searching, here is the second key step: it uses RequestHandlerClass, i.e. the custom MyServer class we passed in. finish_request does something very simple: it instantiates MyServer.
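Because finish_request only instantiates the handler class, it can even be called directly, with no network involved. In this sketch Recorder is a hypothetical handler, the request arguments are placeholder strings, and the server is created with bind_and_activate=False so no port is bound.

```python
import socketserver

instances = []


class Recorder(socketserver.BaseRequestHandler):
    """Hypothetical handler that just remembers itself."""
    def handle(self):
        instances.append(self)


# bind_and_activate=False: a server object without binding any port.
srv = socketserver.TCPServer(("127.0.0.1", 0), Recorder, bind_and_activate=False)
srv.finish_request("request-1", ("10.0.0.1", 1111))   # placeholder arguments
srv.finish_request("request-2", ("10.0.0.2", 2222))
srv.server_close()

first, second = instances
print(first.request, first.client_address, first.server is srv)
# request-1 ('10.0.0.1', 1111) True
print(first is not second)  # True: one handler instance per request
```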
Back to our MyServer class. Remember the attributes that BaseServer's __init__ set up?

class BaseServer:
    def __init__(self, server_address, RequestHandlerClass):
        """Constructor.  May be extended, do not override."""
        self.server_address = server_address
        self.RequestHandlerClass = RequestHandlerClass
        self.__is_shut_down = threading.Event()
        self.__shutdown_request = False

Here self.RequestHandlerClass is the MyServer we passed in, and it is instantiated with the arguments request, client_address, self, in that order. MyServer inherits from BaseRequestHandler, so let's look at BaseRequestHandler's __init__:

class BaseRequestHandler:

    def __init__(self, request, client_address, server):
        self.request = request
        self.client_address = client_address
        self.server = server
        self.setup()
        try:
            self.handle()
        finally:
            self.finish()

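Now everything is clear: request is a TCP connection object that has completed the three-way handshake, client_address is that connection's client address, and server is the s object.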
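Since self.server is the server object itself, handlers can use it for per-server information, as the BaseRequestHandler docstring suggests. A minimal sketch, assuming a hypothetical CountingServer that extends __init__ through super() (as the "May be extended, do not override" note allows) with a counter and a lock shared by all handler threads; CountHandler and run are likewise hypothetical names.

```python
import socket
import socketserver
import threading


class CountingServer(socketserver.ThreadingTCPServer):
    """Hypothetical server subclass holding state shared by all handlers."""
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)  # extend, do not replace, __init__
        self.lock = threading.Lock()
        self.count = 0


class CountHandler(socketserver.BaseRequestHandler):
    def handle(self):
        with self.server.lock:             # self.server is the server object (s)
            self.server.count += 1
            n = self.server.count
        self.request.sendall(str(n).encode())


def run(clients=3):
    replies = []
    with CountingServer(("127.0.0.1", 0), CountHandler) as s:
        loop = threading.Thread(target=s.serve_forever)
        loop.start()
        for _ in range(clients):           # sequential, so the order is fixed
            with socket.create_connection(s.server_address, timeout=5) as c:
                replies.append(c.recv(16))
        s.shutdown()
        loop.join()
    return replies


print(run())  # expected: [b'1', b'2', b'3']
```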
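This shows how TCP concurrency is achieved. The waiting loop itself is unchanged: once s is created, it owns the single listening socket s.socket, and serve_forever keeps checking whether a new TCP connection is ready. As soon as a connection has completed the handshake and been accepted, process_request starts a new thread running process_request_thread and hands the connection to it; inside that thread, finish_request instantiates our handler class, which does the communication. Meanwhile the main loop goes straight back to waiting in serve_forever.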
Why do we define handle? Because __init__ calls handle right away, and the communication logic we want has to be written in our own handle method.
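One more note: if we create the server with ForkingTCPServer (multiprocess) instead of ThreadingTCPServer (multithreaded), the service still works, except on Windows, where the os module has no fork and the Forking classes cannot be used. Threads are cheaper than processes, though; we will see why when we study concurrency later. The class hierarchy also shows that the key to the multiprocess version is the ForkingMixIn class.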
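One way to use the multiprocess server where it exists and fall back to threads elsewhere is to check for os.fork. This is a hypothetical pattern, not something socketserver prescribes; it relies on ForkingTCPServer being available wherever os.fork is, which holds for CPython's socketserver.

```python
import os
import socketserver

# Hypothetical fallback: prefer processes where fork() exists, else threads.
ServerClass = (socketserver.ForkingTCPServer if hasattr(os, "fork")
               else socketserver.ThreadingTCPServer)

print(ServerClass.__name__)
print([c.__name__ for c in ServerClass.__mro__][:3])
```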