New features
Building on 《简易协程-2》 (the previous post in this series), this version adds support for join-style waiting between coroutines and for IO timeouts.
A new class, JoinAction, provides the join support: yielding an instance of it puts the current coroutine into a waiting state until the target coroutine exits or the wait times out. Usage example below; a complete runnable sketch follows the snippet.
# create and schedule another coroutine c
c = cf1()
Scheduler.add(c)
t1 = time()
# wait for c to finish with a 0.5 second timeout; is_timeout is True if the wait timed out
is_timeout = yield JoinAction(c, 0.5)
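For context, here is a minimal end-to-end sketch. cf1 and waiter are made-up coroutines for illustration; the sketch assumes the Scheduler, Sleep and JoinAction classes from the full code below (plus time() from the time module) are in scope.
def cf1():
    # pretend to do 0.2 seconds of work
    yield Sleep(0.2)
def waiter():
    c = cf1()
    Scheduler.add(c)
    start = time()
    # False here: cf1 finishes well before the 0.5s timeout
    is_timeout = yield JoinAction(c, 0.5)
    print("is_timeout=%s, waited %.3fs" % (is_timeout, time() - start))
Scheduler.add(waiter())
Scheduler.run()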
IO timeouts are implemented by adding a timeout parameter (also in seconds) to SocketIO. If the requested event does not arrive within the given time, the scheduler raises an exception inside the coroutine. Example below; a sketch of catching the timeout follows the snippet.
# wait until the socket is writable; timeout is 5 seconds
yield SocketIO(sock.fileno(), read=False, timeout=5)
sock.send("data")
yield SocketIO(sock.fileno(), read=True, timeout=5)
data = sock.recv(1024)
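The exception thrown on timeout is socket.timeout (imported in the full code below as SocketTimeoutError), so a coroutine can catch it around the yield. A hedged sketch: recv_with_timeout is a hypothetical helper, sock is assumed to be a connected non-blocking socket, and the final yield hands the result back to a parent coroutine, the same convention async_urlopen below uses.
def recv_with_timeout(sock):
    try:
        # wait up to 5 seconds for the socket to become readable
        yield SocketIO(sock.fileno(), read=True, timeout=5)
        data = sock.recv(1024)
    except SocketTimeoutError:
        # nothing arrived in time: the scheduler threw socket.timeout into this coroutine
        data = ""
    # yielding a plain value ends the coroutine and passes the value to its parent
    yield data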
Full code
The complete code is listed below.
#!/usr/bin/env python
# coding: utf-8
from collections import deque
from errno import ETIMEDOUT
from heapq import heappop
from heapq import heappush
from itertools import chain
from select import select
from socket import timeout as SocketTimeoutError
from sys import exc_info
from sys import maxint
from time import sleep
from time import time
from types import GeneratorType
class Sleep(object):
__slots__ = ["seconds", ]
def __init__(self, seconds):
# type: (float) -> object
self.seconds = seconds
# assert seconds >= 0
class SocketIO(object):
__slots__ = ["sock_fd", "read", "timeout"]
def __init__(self, sock_fd, read=True, timeout=-1):
self.sock_fd = sock_fd
self.read = read
self.timeout = timeout
class JoinAction(object):
__slots__ = ["target_coroutine", "timeout"]
def __init__(self, target_coroutine, timeout=-1.0):
# type: (GeneratorType, float) -> JoinAction
"""
:param target_coroutine: target generator
:param timeout: seconds
"""
self.target_coroutine = target_coroutine
self.timeout = timeout
class Coroutine(object):
__slots__ = ["generator", "parent", "init_value", "exception_info", "name"]
def __init__(self, generator, parent=None, init_value=None, exception_info=(), name=""):
# type: (GeneratorType, Coroutine, object, tuple) -> Coroutine
self.generator = generator
self.parent = parent
self.init_value = init_value
self.exception_info = exception_info
self.name = name
if not name:
self.name = generator.gi_code.co_name
def __str__(self):
return "%s.%s" % (self.name, self.cid())
__repr__ = __str__
def cid(self):
return id(self.generator)
def reset_input(self, value=None, exception_info=()):
self.init_value = value
self.exception_info = exception_info
def run(self):
if self.exception_info:
value = self.generator.throw(*self.exception_info)
self.exception_info = ()
else:
value = self.generator.send(self.init_value)
self.init_value = value
return value
class CoroutineError(Exception):
pass
class FakeSocket(object):
__slots__ = ["data"]
def __init__(self):
self.data = ""
def fileno(self):
return id(self)
def send(self, data):
self.data = data
return len(data)
def recv(self, _):
return "HTTP/1.1 200 OK\r\nContent-Length:0\r\n\r\n"
from random import random
next_time = {}
def fake_select(rlist, wlist, xlist, timeout):
rxlist = list(rlist)
wxlist = list(wlist)
return rxlist, wxlist, []
WAIT_CANCELED = 0
WAIT_SOCKET = 1
WAIT_JOIN = 2
WAIT_SLEEP = 3
class _TimeoutItem(object):
def __init__(self, till, wait_type, arg):
# type: (int, int, object) -> _TimeoutItem
self.till = till
self.wait_type = wait_type
self.arg = arg
self.id = id(self)
class Scheduler(object):
_instance = None
def __init__(self, ignore_exception=True, debug=False):
"""
:param debug: output running detail
:param ignore_exception: ignore coroutine's uncaught exception
"""
self.ignore_exception = ignore_exception
self.debug = debug
# if true, append debug logs to _debug_logs; else, print them to stdout
self.collect_debug_logs = False
self._debug_logs = []
# use fake_select() to test performance or simulation, work with FakeSocket
self.use_fake_select = False
#
self.start_time = time()
# map coroutine_id => coroutine
self.cid2coroutine = {}
# running queue
self.queue = deque()
# map: sock_fd -> [coroutine, timeout_item]
self.sock_map = {}
self.io_read_queue = set()
self.io_write_queue = set()
# map: coroutine_id -> waiters {waiter_coroutine_id -> timeout_item, ...}, used by join waits
self.waiting_map = {}
# map: millisecond (int) -> dict(item_id -> timeout_item)
self.timer_slots_map = {}
# [ms1, ms2, ...]
self.millisecond_heap = []
#
self.alive_coroutine_num = 0
# current running coroutine
self.current = None
# whether _run() is currently executing
self.running = False
@classmethod
def get_instance(cls):
# type: () -> Scheduler
if not cls._instance:
cls._instance = cls()
return cls._instance
def _debug_output(self, msg, *args):
if self.debug:
if self.collect_debug_logs:
self._debug_logs.append(("%.6f" % time(), msg % args))
else:
print "%.6f" % time(), msg % args
else:
pass
def _add(self, generator):
co = Coroutine(generator)
cid = co.cid()
self.cid2coroutine[cid] = co
self.alive_coroutine_num += 1
self._debug_output("add new coroutine %d, alive_coroutine_num=%d",
cid, self.alive_coroutine_num)
self.queue.append(co)
return self
def _coroutine_exit(self, coroutine, is_error):
# type: (Coroutine, bool) -> object
cid = coroutine.cid()
assert cid in self.cid2coroutine
parent = coroutine.parent
if parent is None:
# wake up all waiters or cancel io wait timeout
waiters = self.waiting_map.pop(cid, None)
if waiters:
assert isinstance(waiters, dict)
# join wait
self._debug_output("%s wake up %d waiters", cid, len(waiters))
for wcid, timeout_item in waiters.iteritems():
waiter = self.cid2coroutine[wcid]
waiter.reset_input(False)
self.queue.append(waiter)
# invalidate the timeout_item so the join timeout never fires
self.timer_slots_map[timeout_item.till].pop(timeout_item.id)
del waiters
self.alive_coroutine_num -= 1
else:
if is_error:
parent.reset_input(None, exc_info())
else:
parent.reset_input(coroutine.init_value, ())
self.queue.append(parent)
self.cid2coroutine.pop(cid)
self._debug_output("coroutine %d exited, alive_coroutine_num=%d", cid, self.alive_coroutine_num)
def _current_coroutine(self):
# type: () -> Coroutine
return self.current
@classmethod
def current_id(cls):
return cls.get_instance()._current_coroutine().cid()
@classmethod
def current_name(cls):
return cls.get_instance()._current_coroutine().name
def _add_timeout(self, seconds, wait_type, arg):
# type: (float, int, object) -> _TimeoutItem
till = int(1000 * (time() - self.start_time + seconds + 0.0005)) if seconds >= 0 else maxint
self._debug_output('coroutine add a timeout task at %sms from start', till)
timeout_item = _TimeoutItem(till, wait_type, arg)
# insert new item
if till in self.timer_slots_map:
self.timer_slots_map[till][timeout_item.id] = timeout_item
else:
self.timer_slots_map[till] = {timeout_item.id: timeout_item}
heappush(self.millisecond_heap, till)
return timeout_item
def _do_coroutine_io(self, coroutine, event):
# type: (Coroutine, SocketIO) -> object
coroutine.reset_input()
sock_fd = event.sock_fd
if event.read:
self.io_read_queue.add(sock_fd)
else:
self.io_write_queue.add(sock_fd)
timeout_item = self._add_timeout(event.timeout, WAIT_SOCKET, sock_fd)
self.sock_map[sock_fd] = [coroutine, timeout_item]
def _do_coroutine_sleep(self, coroutine, seconds):
coroutine.reset_input()
timeout_item = self._add_timeout(seconds, WAIT_SLEEP, coroutine)
self._debug_output('coroutine go to sleep until %s', timeout_item.till)
def _do_coroutine_join(self, coroutine, event):
# type: (Coroutine, JoinAction) -> None
target_cid = id(event.target_coroutine)
timeout = event.timeout
cid = coroutine.cid()
if cid == target_cid:
try:
raise CoroutineError("can't join self")
except CoroutineError:
coroutine.reset_input(None, exc_info())
self.queue.append(coroutine)
elif target_cid not in self.cid2coroutine:
# target coroutine exited, join action ends
coroutine.reset_input(False)
self.queue.append(coroutine)
elif 0 <= timeout < 0.001:
# timeout too small, so just tell the coroutine it timed out right away
coroutine.reset_input(True)
self.queue.append(coroutine)
else:
self._debug_output("coroutine %s try to join %s, timeout=%f",
cid, target_cid, timeout)
timeout_item = self._add_timeout(timeout, WAIT_JOIN, (cid, target_cid))
if target_cid in self.waiting_map:
self.waiting_map[target_cid][cid] = timeout_item
else:
self.waiting_map[target_cid] = {cid: timeout_item}
# noinspection PyBroadException
def _process_running_queue(self):
old_queue = self.queue
self.queue = deque()
append = self.queue.append
for coroutine in old_queue:
self.current = coroutine
# assert isinstance(coroutine, Coroutine)
try:
value = coroutine.run()
except StopIteration:
self._coroutine_exit(coroutine, False)
continue
except:
self._coroutine_exit(coroutine, True)
if coroutine.parent is None and not self.ignore_exception:
self._debug_output("%s raise uncaught exception", coroutine.cid())
raise
else:
continue
if value is None:
# yield to other coroutines
append(coroutine)
elif isinstance(value, GeneratorType):
sub = Coroutine(value, coroutine)
append(sub)
self.cid2coroutine[sub.cid()] = sub
elif isinstance(value, SocketIO):
self._do_coroutine_io(coroutine, value)
elif isinstance(value, Sleep):
self._do_coroutine_sleep(coroutine, value.seconds)
elif isinstance(value, JoinAction):
self._do_coroutine_join(coroutine, value)
else:
# this coroutine exit
self._coroutine_exit(coroutine, False)
self.current = None
def _process_sleep_queue(self):
now = time()
from_start_ms = int(1000 * (now - self.start_time))
millisecond_heap = self.millisecond_heap
while millisecond_heap:
# check recent till millisecond time
till = heappop(millisecond_heap)
# get all timeout tasks in this millisecond
item_map = self.timer_slots_map.pop(till)
if till > from_start_ms:
# there are some tasks in this millisecond, so loop ends
if item_map:
self.timer_slots_map[till] = item_map
heappush(self.millisecond_heap, till)
return min(1.0, 0.001 * (till - from_start_ms))
else:
# no task, continue to next millisecond
continue
# do time out tasks
assert isinstance(item_map, dict)
for timeout_item in item_map.itervalues():
assert isinstance(timeout_item, _TimeoutItem)
wait_type = timeout_item.wait_type
if wait_type is WAIT_CANCELED:
continue
assert timeout_item.till == till
arg = timeout_item.arg
if wait_type is WAIT_JOIN:
# join time out
waiting_cid, target_cid = arg
waiters = self.waiting_map[target_cid]
assert isinstance(waiters, dict)
self._debug_output("coroutine %s join %s time out", waiting_cid, target_cid)
del waiters[waiting_cid]
if not waiters:
del self.waiting_map[target_cid]
# wake up this waiter
waiter = self.cid2coroutine[waiting_cid]
# true: really timeout
waiter.reset_input(True)
self.queue.append(waiter)
self._debug_output("%s timeout on join", waiter)
elif wait_type is WAIT_SOCKET:
# io time out
sock_fd = arg
self._debug_output("socket %s io timeout", sock_fd)
# sock_fd is no longer being watched for events
if sock_fd not in self.sock_map:
continue
# un-register event watch
self.io_read_queue.discard(sock_fd)
self.io_write_queue.discard(sock_fd)
# find the owner coroutine of this sock_fd
coroutine, timeout_item = self.sock_map[sock_fd]
assert isinstance(coroutine, Coroutine)
# owner maybe already exited
if coroutine.cid() not in self.cid2coroutine:
continue
try:
raise SocketTimeoutError(ETIMEDOUT, "timeout")
except SocketTimeoutError:
# raise exception to this coroutine
coroutine.reset_input(None, exc_info())
self.queue.append(coroutine)
self._debug_output("%s timeout on socket", coroutine)
else:
# sleep type, arg is sleeping coroutine. sleep is reached, so wake up this coroutine
assert wait_type is WAIT_SLEEP
assert isinstance(arg, Coroutine)
self.queue.append(arg)
self._debug_output("%s wake up from sleep", arg)
del item_map
return 0.0
def _process_io(self, sleep_seconds):
io_read_queue = self.io_read_queue
io_write_queue = self.io_write_queue
queue_append = self.queue.append
if self.use_fake_select:
rxlist, wxlist, exlist = fake_select(io_read_queue,
io_write_queue, [],
sleep_seconds)
else:
rxlist, wxlist, exlist = select(io_read_queue,
io_write_queue, [],
sleep_seconds)
# collect coroutines waiting for these sockets
io_read_queue -= set(rxlist)
io_write_queue -= set(wxlist)
if exlist:
exset = set(exlist)
io_read_queue -= exset
io_write_queue -= exset
# wake coroutines
for sock_fd in chain(rxlist, wxlist, exlist):
self._debug_output("socket %s become ready", sock_fd)
coroutine, timeout_item = self.sock_map[sock_fd]
cid = coroutine.cid()
assert cid in self.cid2coroutine
queue_append(coroutine)
# try to cancel io timeout item
assert timeout_item.wait_type is WAIT_SOCKET
self.timer_slots_map[timeout_item.till].pop(timeout_item.id)
def _run(self):
if self.running:
raise CoroutineError("already running")
self.running = True
self._debug_logs = []
io_read_queue = self.io_read_queue
io_write_queue = self.io_write_queue
# start to run all coroutines until all exited
while self.alive_coroutine_num > 0:
self._process_running_queue()
sleep_seconds = self._process_sleep_queue()
if io_read_queue or io_write_queue:
if self.queue:
# print "queue is not empty, io timeout set 0"
sleep_seconds = 0
elif sleep_seconds > 1:
sleep_seconds = 1
self._process_io(sleep_seconds)
elif sleep_seconds > 0 and not self.queue and self.millisecond_heap:
# sleep_seconds += 0.0003
self._debug_output("try to sleep for %.6fs", sleep_seconds)
sleep(sleep_seconds)
self._debug_output("wake up at %.6f", time())
# ended
self.running = False
assert not self.queue
assert not self.io_read_queue
assert not self.io_write_queue
assert not self.millisecond_heap
assert not self.timer_slots_map
assert not self.waiting_map
assert not self.cid2coroutine
assert self.current is None
self.sock_map.clear()
@classmethod
def add(cls, coroutine):
return cls.get_instance()._add(coroutine)
@classmethod
def add_many(cls, coroutine_list):
"""
add many coroutines to scheduler
:param coroutine_list: coroutine array
:return: scheduler
"""
for coroutine in coroutine_list:
cls.get_instance()._add(coroutine)
return cls.get_instance()
@classmethod
def run(cls):
return cls.get_instance()._run()
@classmethod
def set_debug(cls, debug=True, collect_logs=False):
cls.get_instance().debug = debug
cls.get_instance().collect_debug_logs = collect_logs
@classmethod
def get_debug_logs(cls):
return cls.get_instance()._debug_logs
@classmethod
def set_use_fake_select(cls, use_fake_select=True):
cls.get_instance().use_fake_select = use_fake_select
def async_urlopen(sock, url, method="GET", headers=(), data=""):
"""
async HTTP request
:param sock:
:param url:
:param method:
:param headers: (head, value) headers list
:param data:
:return response: (code, reason, headers, body)
"""
pieces = [method, ' ', url, ' HTTP/1.1\r\n', ]
for head, val in headers:
pieces.extend((head, ':', val, '\r\n'))
pieces.extend(('Content-Length:', str(len(data)), '\r\n'))
pieces.append('Connection: keep-alive\r\n\r\n')
pieces.append(data)
req_bin = ''.join(pieces)
while req_bin:
yield SocketIO(sock.fileno(), read=False)
sent = sock.send(req_bin)
req_bin = req_bin[sent:]
resp_bin = ""
resp_len = -1
code = 400
reason = "bad request"
while resp_len != len(resp_bin):
yield SocketIO(sock.fileno(), read=True)
data = sock.recv(32 << 10)
if resp_len > 0:
resp_bin += data
else:
resp_bin += data
parts = resp_bin.split('\r\n\r\n', 1)
if len(parts) != 2:
continue
head_bin, resp_bin = parts
lines = head_bin.split('\r\n')
status_line = lines[0]
version, code, reason = status_line.split(' ', 2)
code = int(code)
headers = [line.split(':', 1) for line in lines[1:-1]]
if method == 'HEAD':
break
resp_len = 0
for head, val in headers:
if head.lower() == 'content-length':
resp_len = int(val)
break
yield (code, reason, headers, resp_bin)
How timeouts are implemented
Timeouts serve three features: sleeping, IO timeouts, and JoinAction timeouts. The three are quite similar: each has to wait out a span of time and then handle the deadline in its own way. In other words, each needs a one-shot timer task. When the deadline arrives: for a sleep task, wake the coroutine and put it back on the run queue; for an IO task, wake the coroutine and raise a timeout exception inside it; for a JoinAction task, wake the waiting coroutine and hand it the timeout result. The latter two differ in one respect: their timer tasks may be canceled partway through. If the IO arrives in time, the IO timeout task must be canceled; if the target coroutine exits in time, the JoinAction timeout task must be canceled as well.
To keep the implementation simple, timer precision is limited to one millisecond, so times can be represented as millisecond integers.
Let's start with the two main data structures in the design: timer_slots_map and millisecond_heap.
As its name suggests, millisecond_heap is a min-heap of millisecond integers. Each value is a point in time expressed as milliseconds since the process started (current time minus start time), and means that timer tasks may exist within that millisecond. For example:
+++++++++++++
|100|200|280|
+++++++++++++
The heap holds three times, 100, 200 and 280, which means timer tasks may exist within the 99-100ms, 199-200ms and 279-280ms intervals.
With a heap we can quickly find the nearest time and quickly insert new times. Thanks to the efficiency of the heap structure itself and Python's C implementation of the heapq operations, insertion and removal stay cheap even when the heap grows very long.
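For reference, a tiny self-contained sketch of the heap operations used on millisecond_heap (heapq is the standard-library module the scheduler already imports):
from heapq import heappush, heappop
millisecond_heap = []
for ms in (280, 100, 200):
    heappush(millisecond_heap, ms)   # insert deadlines in any order
print(millisecond_heap[0])           # 100: the smallest (nearest) deadline is always at index 0
print(heappop(millisecond_heap))     # 100: pop removes and returns the nearest deadline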
Next comes timer_slots_map, a slightly more involved structure that stores every timer task. Its first level is a dict mapping a millisecond time to the timer tasks for that time. That collection of tasks is itself a dict: each timer task is represented by a timeout_item, and the mapping is id(timeout_item) -> timeout_item.
Here is an example of timer_slots_map.
{
100 => { id1 => timeout_item1 , id2 => timeout_item2 },
280 => { id3 => timeout_item3},
200 => {}
}
As shown above, three time points have timer tasks: 100 has two, while 200 has none. An empty slot is normal; it appears whenever a timer task is canceled.
Now for the timer-task operations that have to be implemented:
- add a timer task
- cancel a timer task
- get the timer tasks with the nearest deadline
1. Adding a timer task
The job is to put the timer task timeout_item into the structures. A timer task carries its type, arguments and so on; here we only care about the time.
First compute the deadline as a millisecond integer till. Check whether till is already in timer_slots_map: if it is, till must already be in millisecond_heap; otherwise push it onto millisecond_heap with heappush(), which maintains the heap invariant automatically. Finally, insert timeout_item into the dict timer_slots_map[till]:
timer_slots_map[till][id(timeout_item)] = timeout_item
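Put together, the add operation can be sketched as a small standalone helper. This mirrors _add_timeout() above, with TimeoutItem as a stripped-down stand-in for _TimeoutItem and module-level state for brevity (the real code additionally maps seconds < 0 to maxint, meaning "no timeout"):
from heapq import heappush
from time import time
timer_slots_map = {}    # till (ms) -> {id(item): item}
millisecond_heap = []   # min-heap of till values
start_time = time()
class TimeoutItem(object):
    def __init__(self, till):
        self.till = till
        self.id = id(self)
def add_timeout(seconds):
    # deadline in whole milliseconds since start_time
    till = int(1000 * (time() - start_time + seconds + 0.0005))
    item = TimeoutItem(till)
    if till in timer_slots_map:
        # the slot already exists, so till is already in the heap
        timer_slots_map[till][item.id] = item
    else:
        timer_slots_map[till] = {item.id: item}
        heappush(millisecond_heap, till)   # register the new time in the heap
    return item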
2. Canceling a timer task
The input is a timeout_item.
First take the deadline till from the timer task, then remove timeout_item from the dict timer_slots_map[till]. Since the item's id is used as the key, deleting by id(timeout_item) is enough. In effect this requires that the timeout_item being removed is the very same object that was used when the timer task was added.
del timer_slots_map[till][id(timeout_item)]
Note that adding a task may have pushed till onto millisecond_heap, yet canceling does not remove till from the heap. There is a reason for this: a heap is essentially an array, and deleting from the middle of an array is expensive. Leaving till where it is does little harm, and because times are kept at millisecond granularity the length of millisecond_heap stays bounded. If exact double-precision times were used instead, millisecond_heap would inevitably grow to an unmanageable length.
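Continuing the sketch from step 1, cancellation only removes the entry from its slot; the till value deliberately stays in millisecond_heap:
def cancel_timeout(item):
    # item must be the same object that add_timeout() returned
    timer_slots_map[item.till].pop(item.id, None)
    # item.till is left in millisecond_heap on purpose: deleting from the middle of
    # a heap is expensive, and an empty slot is simply skipped when its time comes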
3. Getting the nearest timer tasks
millisecond_heap is a min-heap, so its first element is the nearest time. heappop() conveniently pops that time off millisecond_heap, and the corresponding timer tasks are then looked up in timer_slots_map.
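Continuing the same sketch, fetching the nearest timer tasks looks roughly like this (the real _process_sleep_queue() above additionally checks whether the time has actually arrived and discards slots emptied by cancellation):
from heapq import heappop
def pop_nearest_timeouts():
    # assumes millisecond_heap is non-empty; see _process_sleep_queue() for the full checks
    till = heappop(millisecond_heap)        # smallest value = nearest deadline
    items = timer_slots_map.pop(till, {})   # may be empty if every task in the slot was canceled
    return till, list(items.values())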
Summary
With just a heap and a dict, timer tasks are implemented efficiently and concisely.
When there is a lot of IO, timer tasks can pile up quickly. To keep millisecond_heap short, deadlines can be rounded to coarser slots such as 10ms or even 100ms.
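For example, inside add_timeout() from the step-1 sketch, a hypothetical 10ms-granularity variant of the till computation could look like this (not part of the actual code above):
    raw_ms = int(1000 * (time() - start_time + seconds))
    till = ((raw_ms + 9) // 10) * 10   # round up to the next 10ms boundary so nearby deadlines share one slot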
How JoinAction works
The main data structure here is waiting_map, a dict keyed by coroutine id whose value is the set of coroutines waiting on that coroutine. That value is itself a dict, mapping each waiter's coroutine id to its timer task.
For example:
{
c1 => { waiter1 => timeout_item1, waiter2 => timeout_item2 },
c2 => { waiter3 => timeout_item3, waiter4 => timeout_item4 }
}
waiter1 and waiter2 are both waiting for coroutine c1, each with its own timeout task.
When coroutine c1 exits, the scheduler walks c1's waiter list, wakes every waiting coroutine, and removes their timeout tasks.
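That wake-up step can be sketched as a standalone function over a Scheduler instance; it mirrors the parent-is-None branch of _coroutine_exit() above, with exiting_cid being the id of the coroutine that just finished:
def wake_joiners(scheduler, exiting_cid):
    waiters = scheduler.waiting_map.pop(exiting_cid, None)
    if not waiters:
        return
    for waiter_cid, timeout_item in waiters.items():
        waiter = scheduler.cid2coroutine[waiter_cid]
        waiter.reset_input(False)       # False tells the waiter its join did not time out
        scheduler.queue.append(waiter)  # make the waiter runnable again
        # cancel the waiter's pending join timeout so it never fires
        scheduler.timer_slots_map.get(timeout_item.till, {}).pop(timeout_item.id, None)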