FastAPI 的 WebSocket 同样没有提供 CBV（基于类的视图）方案，于是自己造个轮子。
WebSocketCBV 是在 Starlette 的 WebSocketEndpoint 基础上修改而来的。
WS 装饰器与前一篇中的装饰器相同，这里只提供了被动加载方案。
def WS(path: str, router: APIRouter):
    """Return a class decorator that registers a WebSocket CBV on ``router``.

    The decorated class must expose an ``endpoint`` method (``WebSocketCBV``
    provides one). The class and the route are patched by
    ``update_cbv_class_init`` and ``_update_endpoint_self_param``, which are
    defined elsewhere (see the previous article), before the route is
    appended to ``router.routes``.

    Args:
        path: URL path of the WebSocket route.
        router: Router whose ``routes`` list receives the new route.

    Returns:
        A decorator that registers the class and returns it unchanged.

    Raises:
        TypeError: (at decoration time) if the class has no ``endpoint``.
    """

    def decorator(cls: Type[T]) -> Type[T]:
        # Use a default so a missing method yields the intended message rather
        # than a bare AttributeError; raise explicitly because ``assert`` is
        # removed when Python runs with -O.
        endpoint = getattr(cls, "endpoint", None)
        if endpoint is None:
            raise TypeError("请配置endpoint方法")
        update_cbv_class_init(cls)
        ws = APIWebSocketRoute(path, endpoint)
        _update_endpoint_self_param(cls, ws)
        # The route object is built and patched above, then registered as-is.
        router.routes.append(ws)
        return cls

    return decorator
class WebSocketCBV:
    """Class-based WebSocket handler adapted from Starlette's ``WebSocketEndpoint``.

    ``endpoint`` drives one connection:
    1. It calls ``on_connect`` (which accepts the socket by default).
    2. It decodes every ``websocket.receive`` message according to
       ``encoding`` and hands the result to ``on_receive``.
    3. When the connection ends, it always calls ``on_disconnect`` with the
       final close code.

    Subclasses override the ``on_*`` hooks and may set ``encoding`` to
    ``"text"``, ``"bytes"``, ``"json"`` or leave it ``None`` (raw payload).
    """

    # How incoming payloads are decoded: "text", "bytes", "json" or None.
    encoding: typing.Optional[str] = None

    def __init__(self, websocket: WebSocket):
        self.websocket = websocket

    async def endpoint(self) -> None:
        """Run the receive loop until the client disconnects or an error occurs.

        Raises:
            RuntimeError: if ``self.websocket`` was never set (e.g. an
                overriding ``__init__`` forgot to assign it).
            Exception: any error from ``decode``/``on_receive`` is re-raised
                unchanged, after ``on_disconnect`` has run with
                ``WS_1011_INTERNAL_ERROR``.
        """
        # getattr with a default so a subclass __init__ that never assigns
        # self.websocket gets this message instead of an AttributeError; an
        # explicit raise survives ``python -O`` unlike ``assert``.
        if getattr(self, "websocket", None) is None:
            raise RuntimeError("请在__init__()中配置正确的websocket对象")
        await self.on_connect()

        close_code = status.WS_1000_NORMAL_CLOSURE
        try:
            while True:
                message = await self.websocket.receive()
                message_type = message["type"]
                if message_type == "websocket.receive":
                    data = await self.decode(message)
                    await self.on_receive(data)
                elif message_type == "websocket.disconnect":
                    # "code" may be absent or present-but-None; int(None)
                    # would raise TypeError, so fall back to normal closure.
                    close_code = int(
                        message.get("code") or status.WS_1000_NORMAL_CLOSURE
                    )
                    break
        except Exception:
            close_code = status.WS_1011_INTERNAL_ERROR
            # Bare ``raise`` keeps the traceback and the exception context;
            # ``raise exc from None`` used to hide e.g. the JSONDecodeError
            # behind "Malformed JSON data received.".
            raise
        finally:
            await self.on_disconnect(close_code)

    async def decode(self, message: Message) -> typing.Any:
        """Convert a ``websocket.receive`` message according to ``encoding``.

        Args:
            message: ASGI receive event carrying the payload under ``"text"``
                or ``"bytes"``; the other key may be absent or set to None.

        Returns:
            ``str`` for "text", ``bytes`` for "bytes", the parsed object for
            "json", and whichever payload is present when ``encoding`` is None.

        Raises:
            RuntimeError: the payload does not match ``encoding`` or is not
                valid UTF-8/JSON; the socket is closed with 1003 first.
        """
        text = message.get("text")
        payload_bytes = message.get("bytes")

        if self.encoding == "text":
            # Test the value, not the key: a "text" key holding None means
            # the frame was binary.
            if text is None:
                await self.websocket.close(code=status.WS_1003_UNSUPPORTED_DATA)
                raise RuntimeError("Expected text websocket messages, but got bytes")
            return text

        if self.encoding == "bytes":
            if payload_bytes is None:
                await self.websocket.close(code=status.WS_1003_UNSUPPORTED_DATA)
                raise RuntimeError("Expected bytes websocket messages, but got text")
            return payload_bytes

        if self.encoding == "json":
            try:
                if text is None:
                    text = payload_bytes.decode("utf-8")
                return json.loads(text)
            except (json.JSONDecodeError, UnicodeDecodeError) as exc:
                # Invalid UTF-8 is unsupported data just like malformed JSON.
                await self.websocket.close(code=status.WS_1003_UNSUPPORTED_DATA)
                raise RuntimeError("Malformed JSON data received.") from exc

        assert (
            self.encoding is None
        ), f"Unsupported 'encoding' attribute {self.encoding}"
        # Compare with None so an empty text frame ("") is returned as text
        # instead of falling through to a possibly missing "bytes" key.
        return text if text is not None else message["bytes"]

    async def on_connect(self) -> None:
        """Override to handle an incoming websocket connection"""
        await self.websocket.accept()

    async def on_receive(self, data: typing.Any) -> None:
        """Override to handle an incoming websocket message"""

    async def on_disconnect(self, close_code: int) -> None:
        """Override to handle a disconnecting websocket"""
只需继承 WebSocketCBV 并使用 WS 装饰器，即可实现基于类的 WebSocket 视图。
参考官方的 WebSocket 演示示例，下面给出 CBV 版本：
router = APIRouter()


@WS("/ws", router)
class WebSocketTest(WebSocketCBV):
    """CBV port of FastAPI's official WebSocket echo example."""

    # The official echo example reads frames with ``receive_text()``; decode
    # as text here too, so a binary frame is rejected (close code 1003) instead
    # of being echoed back as a Python ``b'...'`` repr.
    encoding = "text"

    async def on_receive(self, data: typing.Any) -> None:
        """Echo each received text message back to the client."""
        await self.websocket.send_text(f"Message text was: {data}")
注意：websocket 对象被整合到了 self 中（即 self.websocket）。如果想重写 `__init__()`，请记得保留 websocket 参数，并将其赋值给 self.websocket。
update_cbv_class_init(cls) 与 _update_endpoint_self_param(cls, ws) 的实现，请参考前篇。