mirror of
https://github.com/Threnklyn/esphome-dev.git
synced 2026-05-19 04:33:27 +02:00
Introduce new async-def coroutine syntax (#1657)
This commit is contained in:
+35
-131
@@ -1,24 +1,23 @@
|
||||
import functools
|
||||
import heapq
|
||||
import inspect
|
||||
import logging
|
||||
|
||||
import math
|
||||
import os
|
||||
import re
|
||||
|
||||
# pylint: disable=unused-import, wrong-import-order
|
||||
from typing import Any, Dict, List, Optional, Set, TYPE_CHECKING # noqa
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Set, Tuple
|
||||
|
||||
from esphome.const import (
|
||||
CONF_ARDUINO_VERSION,
|
||||
SOURCE_FILE_EXTENSIONS,
|
||||
CONF_COMMENT,
|
||||
CONF_ESPHOME,
|
||||
CONF_USE_ADDRESS,
|
||||
CONF_ETHERNET,
|
||||
CONF_WIFI,
|
||||
SOURCE_FILE_EXTENSIONS,
|
||||
)
|
||||
from esphome.coroutine import FakeAwaitable as _FakeAwaitable
|
||||
from esphome.coroutine import FakeEventLoop as _FakeEventLoop
|
||||
|
||||
# pylint: disable=unused-import
|
||||
from esphome.coroutine import coroutine, coroutine_with_priority # noqa
|
||||
from esphome.helpers import ensure_unique_string, is_hassio
|
||||
from esphome.util import OrderedDict
|
||||
|
||||
@@ -431,64 +430,6 @@ class Library:
|
||||
return NotImplemented
|
||||
|
||||
|
||||
def coroutine(func):
|
||||
return coroutine_with_priority(0.0)(func)
|
||||
|
||||
|
||||
def coroutine_with_priority(priority):
|
||||
def decorator(func):
|
||||
if getattr(func, "_esphome_coroutine", False):
|
||||
# If func is already a coroutine, do not re-wrap it (performance)
|
||||
return func
|
||||
|
||||
@functools.wraps(func)
|
||||
def _wrapper_generator(*args, **kwargs):
|
||||
instance_id = kwargs.pop("__esphome_coroutine_instance__")
|
||||
if not inspect.isgeneratorfunction(func):
|
||||
# If func is not a generator, return result immediately
|
||||
yield func(*args, **kwargs)
|
||||
# pylint: disable=protected-access
|
||||
CORE._remove_coroutine(instance_id)
|
||||
return
|
||||
gen = func(*args, **kwargs)
|
||||
var = None
|
||||
try:
|
||||
while True:
|
||||
var = gen.send(var)
|
||||
if inspect.isgenerator(var):
|
||||
# Yielded generator, equivalent to 'yield from'
|
||||
x = None
|
||||
for x in var:
|
||||
yield None
|
||||
# Last yield value is the result
|
||||
var = x
|
||||
else:
|
||||
yield var
|
||||
except StopIteration:
|
||||
# Stopping iteration
|
||||
yield var
|
||||
# pylint: disable=protected-access
|
||||
CORE._remove_coroutine(instance_id)
|
||||
|
||||
@functools.wraps(func)
|
||||
def wrapper(*args, **kwargs):
|
||||
import random
|
||||
|
||||
instance_id = random.randint(0, 2 ** 32)
|
||||
kwargs["__esphome_coroutine_instance__"] = instance_id
|
||||
gen = _wrapper_generator(*args, **kwargs)
|
||||
# pylint: disable=protected-access
|
||||
CORE._add_active_coroutine(instance_id, gen)
|
||||
return gen
|
||||
|
||||
# pylint: disable=protected-access
|
||||
wrapper._esphome_coroutine = True
|
||||
wrapper.priority = priority
|
||||
return wrapper
|
||||
|
||||
return decorator
|
||||
|
||||
|
||||
def find_source_files(file):
|
||||
files = set()
|
||||
directory = os.path.abspath(os.path.dirname(file))
|
||||
@@ -527,7 +468,7 @@ class EsphomeCore:
|
||||
# The pending tasks in the task queue (mostly for C++ generation)
|
||||
# This is a priority queue (with heapq)
|
||||
# Each item is a tuple of form: (-priority, unique number, task)
|
||||
self.pending_tasks = []
|
||||
self.event_loop = _FakeEventLoop()
|
||||
# Task counter for pending tasks
|
||||
self.task_counter = 0
|
||||
# The variable cache, for each ID this holds a MockObj of the variable obj
|
||||
@@ -542,9 +483,6 @@ class EsphomeCore:
|
||||
self.build_flags: Set[str] = set()
|
||||
# A set of defines to set for the compile process in esphome/core/defines.h
|
||||
self.defines: Set["Define"] = set()
|
||||
# A dictionary of started coroutines, used to warn when a coroutine was not
|
||||
# awaited.
|
||||
self.active_coroutines: Dict[int, Any] = {}
|
||||
# A set of strings of names of loaded integrations, used to find namespace ID conflicts
|
||||
self.loaded_integrations = set()
|
||||
# A set of component IDs to track what Component subclasses are declared
|
||||
@@ -561,7 +499,7 @@ class EsphomeCore:
|
||||
self.board = None
|
||||
self.raw_config = None
|
||||
self.config = None
|
||||
self.pending_tasks = []
|
||||
self.event_loop = _FakeEventLoop()
|
||||
self.task_counter = 0
|
||||
self.variables = {}
|
||||
self.main_statements = []
|
||||
@@ -569,7 +507,6 @@ class EsphomeCore:
|
||||
self.libraries = []
|
||||
self.build_flags = set()
|
||||
self.defines = set()
|
||||
self.active_coroutines = {}
|
||||
self.loaded_integrations = set()
|
||||
self.component_ids = set()
|
||||
|
||||
@@ -596,12 +533,6 @@ class EsphomeCore:
|
||||
|
||||
return None
|
||||
|
||||
def _add_active_coroutine(self, instance_id, obj):
|
||||
self.active_coroutines[instance_id] = obj
|
||||
|
||||
def _remove_coroutine(self, instance_id):
|
||||
self.active_coroutines.pop(instance_id)
|
||||
|
||||
@property
|
||||
def arduino_version(self) -> str:
|
||||
if self.config is None:
|
||||
@@ -657,50 +588,13 @@ class EsphomeCore:
|
||||
return self.esp_platform == "ESP32"
|
||||
|
||||
def add_job(self, func, *args, **kwargs):
|
||||
coro = coroutine(func)
|
||||
task = coro(*args, **kwargs)
|
||||
item = (-coro.priority, self.task_counter, task)
|
||||
self.task_counter += 1
|
||||
heapq.heappush(self.pending_tasks, item)
|
||||
return task
|
||||
self.event_loop.add_job(func, *args, **kwargs)
|
||||
|
||||
def flush_tasks(self):
|
||||
i = 0
|
||||
while self.pending_tasks:
|
||||
i += 1
|
||||
if i > 1000000:
|
||||
raise EsphomeError("Circular dependency detected!")
|
||||
|
||||
inv_priority, num, task = heapq.heappop(self.pending_tasks)
|
||||
priority = -inv_priority
|
||||
_LOGGER.debug("Running %s (num %s)", task, num)
|
||||
try:
|
||||
next(task)
|
||||
# Decrease priority over time, so that if this task is blocked
|
||||
# due to a dependency others will clear the dependency
|
||||
# This could be improved with a less naive approach
|
||||
priority -= 1
|
||||
item = (-priority, num, task)
|
||||
heapq.heappush(self.pending_tasks, item)
|
||||
except StopIteration:
|
||||
_LOGGER.debug(" -> finished")
|
||||
|
||||
# Print not-awaited coroutines
|
||||
for obj in self.active_coroutines.values():
|
||||
_LOGGER.warning(
|
||||
"Coroutine '%s' %s was never awaited with 'yield'.", obj.__name__, obj
|
||||
)
|
||||
_LOGGER.warning("Please file a bug report with your configuration.")
|
||||
if self.active_coroutines:
|
||||
raise EsphomeError()
|
||||
if self.component_ids:
|
||||
comps = ", ".join(f"'{x}'" for x in self.component_ids)
|
||||
_LOGGER.warning(
|
||||
"Components %s were never registered. Please create a bug report", comps
|
||||
)
|
||||
_LOGGER.warning("with your configuration.")
|
||||
raise EsphomeError()
|
||||
self.active_coroutines.clear()
|
||||
try:
|
||||
self.event_loop.flush_tasks()
|
||||
except RuntimeError as e:
|
||||
raise EsphomeError(str(e)) from e
|
||||
|
||||
def add(self, expression):
|
||||
from esphome.cpp_generator import Expression, Statement, statement
|
||||
@@ -779,25 +673,35 @@ class EsphomeCore:
|
||||
_LOGGER.debug("Adding define: %s", define)
|
||||
return define
|
||||
|
||||
def get_variable(self, id):
|
||||
def _get_variable_generator(self, id):
|
||||
while True:
|
||||
try:
|
||||
return self.variables[id]
|
||||
except KeyError:
|
||||
_LOGGER.debug("Waiting for variable %s (%r)", id, id)
|
||||
yield
|
||||
|
||||
async def get_variable(self, id) -> "MockObj":
|
||||
if not isinstance(id, ID):
|
||||
raise ValueError(f"ID {id!r} must be of type ID!")
|
||||
while True:
|
||||
if id in self.variables:
|
||||
yield self.variables[id]
|
||||
return
|
||||
_LOGGER.debug("Waiting for variable %s (%r)", id, id)
|
||||
yield None
|
||||
# Fast path, check if already registered without awaiting
|
||||
if id in self.variables:
|
||||
return self.variables[id]
|
||||
return await _FakeAwaitable(self._get_variable_generator(id))
|
||||
|
||||
def get_variable_with_full_id(self, id):
|
||||
def _get_variable_with_full_id_generator(self, id):
|
||||
while True:
|
||||
if id in self.variables:
|
||||
for k, v in self.variables.items():
|
||||
if k == id:
|
||||
yield (k, v)
|
||||
return
|
||||
return (k, v)
|
||||
_LOGGER.debug("Waiting for variable %s", id)
|
||||
yield None, None
|
||||
yield
|
||||
|
||||
async def get_variable_with_full_id(self, id: ID) -> Tuple[ID, "MockObj"]:
|
||||
if not isinstance(id, ID):
|
||||
raise ValueError(f"ID {id!r} must be of type ID!")
|
||||
return await _FakeAwaitable(self._get_variable_with_full_id_generator(id))
|
||||
|
||||
def register_variable(self, id, obj):
|
||||
if id in self.variables:
|
||||
|
||||
Reference in New Issue
Block a user