# Source code for mirai_core.updater

import asyncio
from typing import DefaultDict, Union, List, Callable, Any, Awaitable
from collections import defaultdict
from dataclasses import dataclass
import signal
from .log import create_logger, install_logger
from .bot import Bot
from .models.Event import BaseEvent, Events
from .exceptions import SessionException, NetworkException, AuthenticationException, ServerException


class Updater:
    """
    Dispatch events received by a Bot to registered coroutine handlers.

    Events arrive either through the bot's websocket (``use_websocket=True``)
    or by periodically calling ``Bot.fetch_message`` (polling mode).
    """

    def __init__(self, bot: 'Bot', use_websocket: bool = True):
        """
        Initialize Updater

        :param bot: the Bot object to use; its ``loop`` becomes the updater's loop
        :param use_websocket: bool. whether websocket (recommended) should be used
        """
        self.bot = bot
        self.loop = bot.loop
        self.logger = create_logger('Updater')
        # Keyed by event type *name* (e.g. 'GroupMessage'): add_handler stores
        # under ``e.__name__`` and event_caller looks up ``event.type``.
        self.event_handlers: DefaultDict[str, List[EventHandler]] = defaultdict(list)
        self.use_websocket = use_websocket

    async def run_task(self, shutdown_hook: Union[Callable[..., Awaitable[None]], None] = None) -> None:
        """
        Coroutine that runs the updater (must run in the same loop as the bot object).

        :param shutdown_hook: optional coroutine function; if running in main thread,
            this must be set. Once awaiting it returns, the bot is released, the
            remaining updater tasks are cancelled and this coroutine returns.
        """
        self.logger.debug('Run tasks')
        # asyncio.wait() requires tasks/futures; bare coroutines were removed in 3.11.
        tasks = [asyncio.ensure_future(self.handshake())]
        if not self.use_websocket:
            tasks.append(asyncio.ensure_future(self.message_polling()))
        if shutdown_hook is None:
            await asyncio.wait(tasks)
            return

        tasks.append(asyncio.ensure_future(self.raise_shutdown(shutdown_hook)))
        # raise_shutdown() always finishes by raising, so this returns on shutdown
        # even while message_polling() would otherwise loop forever.
        done, pending = await asyncio.wait(tasks, return_when=asyncio.FIRST_EXCEPTION)
        for task in pending:
            task.cancel()
        if pending:
            # give cancelled tasks a bounded chance to unwind
            await asyncio.wait(pending, timeout=5)
        for task in done:
            if task.cancelled():
                continue
            exc = task.exception()  # retrieve so asyncio does not report it as lost
            if exc is not None and not isinstance(exc, Shutdown):
                self.logger.error('Updater task failed', exc_info=exc)

    def add_handler(self, event: 'Union[Events, List[Events]]'):
        """
        Decorator for event listeners

        Catch all is not supported at this time.

        :param event: an event class from events.Events, or a list/tuple of them.
            ``Message`` registers the handler for GroupMessage, FriendMessage and
            TempMessage. Unsupported entries are skipped with a warning.
        :return: decorator that registers the function and returns it unchanged
        :raises TypeError: (from the decorator) if the function is not a coroutine function
        """
        events = list(event) if isinstance(event, (list, tuple)) else [event]

        def receiver_wrapper(func):
            if not asyncio.iscoroutinefunction(func):
                raise TypeError("event body must be a coroutine function.")
            # one wrapper object, shared by every event type it is registered for
            event_handler = EventHandler(func)
            for e in events:
                if e not in Events.__args__:
                    # previously dropped silently, so the handler was never called
                    self.logger.warning('%r is not a supported event type; %r will not be called for it', e, func)
                    continue
                if e.__name__ == 'Message':
                    for name in ('GroupMessage', 'FriendMessage', 'TempMessage'):
                        self.event_handlers[name].append(event_handler)
                else:
                    self.event_handlers[e.__name__].append(event_handler)
            return func

        return receiver_wrapper

    def run(self, log_to_stderr=True) -> None:
        """
        Start the Updater and block the thread until it shuts down.

        Where ``loop.add_signal_handler`` is available, SIGINT/SIGTERM trigger a
        shutdown: the bot is released, the event loop is stopped and this returns.

        :param log_to_stderr: if you are setting other loggers that capture the log from this Library, set to False
        """
        asyncio.set_event_loop(self.loop)
        self.loop.set_exception_handler(self.handle_exception)
        shutdown_event = asyncio.Event()

        def _signal_handler(*_: Any) -> None:
            shutdown_event.set()

        try:
            self.loop.add_signal_handler(signal.SIGTERM, _signal_handler)
            self.loop.add_signal_handler(signal.SIGINT, _signal_handler)
        except (AttributeError, NotImplementedError):
            # add_signal_handler is not supported by every loop/platform
            pass
        if log_to_stderr:
            install_logger()
        main_task = self.loop.create_task(self.run_task(shutdown_hook=shutdown_event.wait))
        # run_forever() never returns by itself; stop the loop once run_task ends.
        main_task.add_done_callback(lambda _task: self.loop.stop())
        self.loop.run_forever()

    async def handshake(self):
        """
        Internal use only, automatic handshake

        Called on launch; also passed to ``Bot.create_websocket`` as a callback
        (per the original docs, invoked when the websocket disconnects).
        Retries every 5 seconds until ``Bot.handshake`` succeeds.

        :return: True once the handshake succeeded
        """
        while True:
            try:
                await self.bot.handshake()
                if self.use_websocket:
                    self._schedule(self.bot.create_websocket(self.event_caller, self.handshake))
                return True
            except asyncio.CancelledError:
                # never retry through a cancellation (it is an Exception on 3.7)
                raise
            except NetworkException:
                self.logger.warning('Unable to communicate with Mirai console, retrying in 5 seconds')
            except Exception:
                self.logger.exception('retrying in 5 seconds')
            await asyncio.sleep(5)

    async def message_polling(self, count=5, interval=0.5) -> None:
        """
        Internal use only, polling message and fire events

        Any failure while fetching triggers a new handshake.

        :param count: maximum message count for each polling
        :param interval: minimum interval between two polling
        """
        while True:
            await asyncio.sleep(interval)
            try:
                results: List[BaseEvent] = await self.bot.fetch_message(count)
                if results:
                    self.logger.debug('Received messages:\n' + '\n'.join([str(result) for result in results]))
                for result in results:
                    self._schedule(self.event_caller(result))
            except asyncio.CancelledError:
                raise
            except Exception as e:
                self.logger.warning(f'{e}, new handshake initiated')
                await self.handshake()

    async def event_caller(self, event: 'BaseEvent') -> None:
        """
        Internal use only, call the event handlers sequentially

        A handler returning a truthy value stops the remaining handlers.

        :param event: the event; ``event.type`` selects the handler list
        """
        # .get() avoids inserting an empty list for every unhandled event type
        for handler in self.event_handlers.get(event.type, ()):
            if await handler.func(event):
                # if the function returns True, stop calling next event
                break

    async def raise_shutdown(self, shutdown_event: Callable[..., Awaitable[None]]) -> None:
        """
        Internal use only, shutdown

        Waits for ``shutdown_event``, releases the bot, then raises Shutdown.

        :param shutdown_event: coroutine function awaited until shutdown is requested
        """
        await shutdown_event()
        await self.bot.release()
        raise Shutdown()

    def handle_exception(self, loop, context):
        """
        Event-loop exception handler: log asyncio's message and exception (if any).

        :param loop: the event loop (unused)
        :param context: asyncio exception context dict
        """
        # context["message"] will always be there; but context["exception"] may not.
        # Passing a non-exception as exc_info would log "NoneType: None" instead.
        self.logger.error('Unhandled exception: %s', context['message'], exc_info=context.get('exception'))

    def _schedule(self, coro: Awaitable[Any]):
        """Schedule ``coro`` on ``self.loop`` and log any exception it raises."""
        future = asyncio.run_coroutine_threadsafe(coro, self.loop)
        future.add_done_callback(self._log_future_exception)
        return future

    def _log_future_exception(self, future) -> None:
        """Done-callback: log a failed coroutine instead of letting the error vanish."""
        if future.cancelled():
            return
        exc = future.exception()
        if exc is not None:
            self.logger.error('Unhandled exception in scheduled coroutine', exc_info=exc)
@dataclass()
class EventHandler:
    """
    Holder for a registered event callback.

    Updater.add_handler wraps each decorated coroutine function in one of these.
    """

    func: Callable  # coroutine function awaited with the incoming event
class Shutdown(Exception):
    """Internal use only: raised by Updater.raise_shutdown after the bot is released."""