Introduce new async-def coroutine syntax (#1657)

This commit is contained in:
Otto Winter
2021-05-17 07:14:15 +02:00
committed by GitHub
parent 95ed3e9d46
commit d4686c0fb1
10 changed files with 391 additions and 238 deletions
+35 -131
View File
@@ -1,24 +1,23 @@
import functools
import heapq
import inspect
import logging
import math
import os
import re
# pylint: disable=unused-import, wrong-import-order
from typing import Any, Dict, List, Optional, Set, TYPE_CHECKING # noqa
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Set, Tuple
from esphome.const import (
CONF_ARDUINO_VERSION,
SOURCE_FILE_EXTENSIONS,
CONF_COMMENT,
CONF_ESPHOME,
CONF_USE_ADDRESS,
CONF_ETHERNET,
CONF_WIFI,
SOURCE_FILE_EXTENSIONS,
)
from esphome.coroutine import FakeAwaitable as _FakeAwaitable
from esphome.coroutine import FakeEventLoop as _FakeEventLoop
# pylint: disable=unused-import
from esphome.coroutine import coroutine, coroutine_with_priority # noqa
from esphome.helpers import ensure_unique_string, is_hassio
from esphome.util import OrderedDict
@@ -431,64 +430,6 @@ class Library:
return NotImplemented
def coroutine(func):
return coroutine_with_priority(0.0)(func)
def coroutine_with_priority(priority):
def decorator(func):
if getattr(func, "_esphome_coroutine", False):
# If func is already a coroutine, do not re-wrap it (performance)
return func
@functools.wraps(func)
def _wrapper_generator(*args, **kwargs):
instance_id = kwargs.pop("__esphome_coroutine_instance__")
if not inspect.isgeneratorfunction(func):
# If func is not a generator, return result immediately
yield func(*args, **kwargs)
# pylint: disable=protected-access
CORE._remove_coroutine(instance_id)
return
gen = func(*args, **kwargs)
var = None
try:
while True:
var = gen.send(var)
if inspect.isgenerator(var):
# Yielded generator, equivalent to 'yield from'
x = None
for x in var:
yield None
# Last yield value is the result
var = x
else:
yield var
except StopIteration:
# Stopping iteration
yield var
# pylint: disable=protected-access
CORE._remove_coroutine(instance_id)
@functools.wraps(func)
def wrapper(*args, **kwargs):
import random
instance_id = random.randint(0, 2 ** 32)
kwargs["__esphome_coroutine_instance__"] = instance_id
gen = _wrapper_generator(*args, **kwargs)
# pylint: disable=protected-access
CORE._add_active_coroutine(instance_id, gen)
return gen
# pylint: disable=protected-access
wrapper._esphome_coroutine = True
wrapper.priority = priority
return wrapper
return decorator
def find_source_files(file):
files = set()
directory = os.path.abspath(os.path.dirname(file))
@@ -527,7 +468,7 @@ class EsphomeCore:
# The pending tasks in the task queue (mostly for C++ generation)
# This is a priority queue (with heapq)
# Each item is a tuple of form: (-priority, unique number, task)
self.pending_tasks = []
self.event_loop = _FakeEventLoop()
# Task counter for pending tasks
self.task_counter = 0
# The variable cache, for each ID this holds a MockObj of the variable obj
@@ -542,9 +483,6 @@ class EsphomeCore:
self.build_flags: Set[str] = set()
# A set of defines to set for the compile process in esphome/core/defines.h
self.defines: Set["Define"] = set()
# A dictionary of started coroutines, used to warn when a coroutine was not
# awaited.
self.active_coroutines: Dict[int, Any] = {}
# A set of strings of names of loaded integrations, used to find namespace ID conflicts
self.loaded_integrations = set()
# A set of component IDs to track what Component subclasses are declared
@@ -561,7 +499,7 @@ class EsphomeCore:
self.board = None
self.raw_config = None
self.config = None
self.pending_tasks = []
self.event_loop = _FakeEventLoop()
self.task_counter = 0
self.variables = {}
self.main_statements = []
@@ -569,7 +507,6 @@ class EsphomeCore:
self.libraries = []
self.build_flags = set()
self.defines = set()
self.active_coroutines = {}
self.loaded_integrations = set()
self.component_ids = set()
@@ -596,12 +533,6 @@ class EsphomeCore:
return None
def _add_active_coroutine(self, instance_id, obj):
self.active_coroutines[instance_id] = obj
def _remove_coroutine(self, instance_id):
self.active_coroutines.pop(instance_id)
@property
def arduino_version(self) -> str:
if self.config is None:
@@ -657,50 +588,13 @@ class EsphomeCore:
return self.esp_platform == "ESP32"
def add_job(self, func, *args, **kwargs):
coro = coroutine(func)
task = coro(*args, **kwargs)
item = (-coro.priority, self.task_counter, task)
self.task_counter += 1
heapq.heappush(self.pending_tasks, item)
return task
self.event_loop.add_job(func, *args, **kwargs)
def flush_tasks(self):
i = 0
while self.pending_tasks:
i += 1
if i > 1000000:
raise EsphomeError("Circular dependency detected!")
inv_priority, num, task = heapq.heappop(self.pending_tasks)
priority = -inv_priority
_LOGGER.debug("Running %s (num %s)", task, num)
try:
next(task)
# Decrease priority over time, so that if this task is blocked
# due to a dependency others will clear the dependency
# This could be improved with a less naive approach
priority -= 1
item = (-priority, num, task)
heapq.heappush(self.pending_tasks, item)
except StopIteration:
_LOGGER.debug(" -> finished")
# Print not-awaited coroutines
for obj in self.active_coroutines.values():
_LOGGER.warning(
"Coroutine '%s' %s was never awaited with 'yield'.", obj.__name__, obj
)
_LOGGER.warning("Please file a bug report with your configuration.")
if self.active_coroutines:
raise EsphomeError()
if self.component_ids:
comps = ", ".join(f"'{x}'" for x in self.component_ids)
_LOGGER.warning(
"Components %s were never registered. Please create a bug report", comps
)
_LOGGER.warning("with your configuration.")
raise EsphomeError()
self.active_coroutines.clear()
try:
self.event_loop.flush_tasks()
except RuntimeError as e:
raise EsphomeError(str(e)) from e
def add(self, expression):
from esphome.cpp_generator import Expression, Statement, statement
@@ -779,25 +673,35 @@ class EsphomeCore:
_LOGGER.debug("Adding define: %s", define)
return define
def get_variable(self, id):
def _get_variable_generator(self, id):
while True:
try:
return self.variables[id]
except KeyError:
_LOGGER.debug("Waiting for variable %s (%r)", id, id)
yield
async def get_variable(self, id) -> "MockObj":
if not isinstance(id, ID):
raise ValueError(f"ID {id!r} must be of type ID!")
while True:
if id in self.variables:
yield self.variables[id]
return
_LOGGER.debug("Waiting for variable %s (%r)", id, id)
yield None
# Fast path, check if already registered without awaiting
if id in self.variables:
return self.variables[id]
return await _FakeAwaitable(self._get_variable_generator(id))
def get_variable_with_full_id(self, id):
def _get_variable_with_full_id_generator(self, id):
while True:
if id in self.variables:
for k, v in self.variables.items():
if k == id:
yield (k, v)
return
return (k, v)
_LOGGER.debug("Waiting for variable %s", id)
yield None, None
yield
async def get_variable_with_full_id(self, id: ID) -> Tuple[ID, "MockObj"]:
if not isinstance(id, ID):
raise ValueError(f"ID {id!r} must be of type ID!")
return await _FakeAwaitable(self._get_variable_with_full_id_generator(id))
def register_variable(self, id, obj):
if id in self.variables: