Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Ensure curdoc lookups work in async context #3776

Merged
merged 6 commits into from
Aug 25, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 2 additions & 3 deletions panel/io/document.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
from bokeh.document.events import DocumentChangedEvent, ModelChangedEvent

from .model import monkeypatch_events
from .state import curdoc_locked, set_curdoc, state
from .state import curdoc_locked, state

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -72,8 +72,7 @@ def init_doc(doc: Optional[Document]) -> Document:

thread = threading.current_thread()
if thread:
with set_curdoc(curdoc):
state._thread_id = thread.ident
state._thread_id_[curdoc] = thread.ident

session_id = curdoc.session_context.id
sessions = state.session_info['sessions']
Expand Down
6 changes: 5 additions & 1 deletion panel/io/jupyter_server_extension.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@
from .resources import (
DIST_DIR, ERROR_TEMPLATE, Resources, _env,
)
from .server import server_html_page_for_session
from .server import _add_task_factory, server_html_page_for_session
from .state import set_curdoc, state

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -218,6 +218,10 @@ def _create_server_session(self) -> ServerSession:
app.initialize_document(doc)

loop = tornado.ioloop.IOLoop.current()
try:
_add_task_factory(loop.asyncio_loop) # type: ignore
except Exception:
pass
session = ServerSession(self.session_id, doc, io_loop=loop, token=self.token)
session_context._set_session(session)
return session
Expand Down
2 changes: 2 additions & 0 deletions panel/io/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,8 @@ def diff(
return None

patch_events = [event for event in events if isinstance(event, DocumentPatchedEvent)]
if not patch_events:
return
monkeypatch_events(events)
msg_type: Literal["PATCH-DOC"] = "PATCH-DOC"
msg = Protocol().create(msg_type, patch_events, use_buffers=binary)
Expand Down
50 changes: 42 additions & 8 deletions panel/io/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@
)
from bokeh.embed.util import RenderItem
from bokeh.io import curdoc
from bokeh.server.server import Server
from bokeh.server.server import Server as BokehServer
from bokeh.server.urls import per_app_patterns
from bokeh.server.views.autoload_js_handler import (
AutoloadJsHandler as BkAutoloadJsHandler,
Expand Down Expand Up @@ -205,7 +205,6 @@ def autoload_js_script(doc, resources, token, element_id, app_path, absolute_url

return AUTOLOAD_JS.render(bundle=bundle, elementid=element_id)


def destroy_document(self, session):
"""
Override for Document.destroy() without calling gc.collect directly.
Expand Down Expand Up @@ -240,6 +239,18 @@ def destroy_document(self, session):
state.schedule_task('gc.collect', gc.collect, at=at)


# Patch Srrver to attach task factory to asyncio loop
class Server(BokehServer):

def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
try:
_add_task_factory(self.io_loop.asyncio_loop) # type: ignore
except Exception:
pass

bokeh.server.server.Server = Server

# Patch Application to handle session callbacks
class Application(BkApplication):

Expand Down Expand Up @@ -339,15 +350,12 @@ async def get(self, *args, **kwargs) -> None:

with self._session_prefix():
session = await self.get_session()
state.curdoc = session.document
try:
with set_curdoc(session.document):
resources = Resources.from_bokeh(self.application.resources(server_url))
js = autoload_js_script(
session.document, resources, session.token, element_id,
app_path, absolute_url
)
finally:
state.curdoc = None

self.set_header("Content-Type", 'application/javascript')
self.write(js)
Expand Down Expand Up @@ -557,6 +565,10 @@ def create_static_handler(prefix, key, app):

bokeh.server.tornado.create_static_handler = create_static_handler

#---------------------------------------------------------------------
# Async patches
#---------------------------------------------------------------------

# Bokeh 2.4.x patches the asyncio event loop policy but Tornado 6.1
# support the WindowsProactorEventLoopPolicy so we restore it,
# unless we detect we are running on jupyter_server.
Expand All @@ -571,6 +583,28 @@ def create_static_handler(prefix, key, app):
):
asyncio.set_event_loop_policy(asyncio.WindowsProactorEventLoopPolicy())

def _add_task_factory(loop):
"""
Adds a Task factory to the asyncio IOLoop that ensures child tasks
have access to their parent.
"""
if getattr(loop, '_has_panel_task_factory', False):
return
existing_factory = loop.get_task_factory()
def task_factory(loop, coro):
try:
parent_task = asyncio.current_task()
except RuntimeError:
parent_task = None
if existing_factory:
task = existing_factory(loop, coro)
else:
task = asyncio.Task(coro, loop=loop)
task.parent_task = parent_task
return task
loop.set_task_factory(task_factory)
loop._has_panel_task_factory = True

#---------------------------------------------------------------------
# Public API
#---------------------------------------------------------------------
Expand All @@ -581,7 +615,7 @@ def serve(
loop: Optional[IOLoop] = None, show: bool = True, start: bool = True,
title: Optional[str] = None, verbose: bool = True, location: bool = True,
threaded: bool = False, **kwargs
) -> threading.Thread | 'Server':
) -> threading.Thread | Server:
"""
Allows serving one or more panel objects on a single server.
The panels argument should be either a Panel object or a function
Expand Down Expand Up @@ -756,7 +790,7 @@ def get_server(

Returns
-------
server : bokeh.server.server.Server
server : panel.io.server.Server
Bokeh Server instance running this panel
"""
from ..config import config
Expand Down
61 changes: 48 additions & 13 deletions panel/io/state.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@
@contextmanager
def set_curdoc(doc: Document):
orig_doc = state._curdoc
state.curdoc = doc
state._curdoc = doc
yield
state._curdoc = orig_doc

Expand Down Expand Up @@ -113,7 +113,7 @@ class _state(param.Parameterized):
A dictionary used by the cache decorator.""")

# Holds temporary curdoc overrides per thread
_curdoc_ = {}
_curdoc_ = defaultdict(WeakKeyDictionary)

# Whether to hold comm events
_hold: ClassVar[bool] = False
Expand Down Expand Up @@ -325,10 +325,12 @@ def _on_load(self, doc: Optional[Document] = None) -> None:
from .profile import profile_ctx
with set_curdoc(doc):
if (doc and doc in self._launching) or not config.profiler:
for cb in callbacks: cb()
for cb in callbacks:
self.execute(cb, schedule=False)
return
with profile_ctx(config.profiler) as sessions:
for cb in callbacks: cb()
for cb in callbacks:
self.execute(cb, schedule=False)
path = doc.session_context.request.path
self._profiles[(path+':on_load', config.profiler)] += sessions
self.param.trigger('_profiles')
Expand Down Expand Up @@ -565,15 +567,20 @@ def log(self, msg: str, level: str = 'info') -> None:
msg = LOG_USER_MSG.format(msg=msg)
getattr(_state_logger, level.lower())(msg, *args)

def onload(self, callback):
def onload(self, callback: Callable[[], None] | Coroutine[Any, Any, None]):
"""
Callback that is triggered when a session has been served.

Arguments
---------
callback: Callable[[], None] | Coroutine[Any, Any, None]
Callback that is executed when the application is loaded
"""
if self.curdoc is None:
if self._thread_pool:
self._thread_pool.submit(callback)
self._thread_pool.submit(partial(self.execute, callback, schedule=False))
else:
callback()
self.execute(callback, schedule=False)
return
if self.curdoc not in self._onload:
self._onload[self.curdoc] = []
Expand Down Expand Up @@ -802,23 +809,51 @@ def curdoc(self, doc: Document) -> None:
def _curdoc(self) -> Document | None:
"""
Required to make overrides to curdoc (e.g. using the
set_curdoc context manager) thread-safe. Otherwise two threads
may independently override the curdoc and end up in a confused
final state.
set_curdoc context manager) thread-safe and asyncio task
local. Otherwise two threads may independently override the
curdoc and end up in a confused final state.
"""
thread = threading.current_thread()
thread_id = thread.ident if thread else None
return self._curdoc_.get(thread_id)
if thread_id not in self._curdoc_:
return None
curdocs = self._curdoc_[thread_id]
try:
task = asyncio.current_task()
except Exception:
task = None
while True:
if task in curdocs:
return curdocs[task or self]
elif task is None:
break
try:
task = task.parent_task
except Exception:
task = None
return curdocs[self] if self in curdocs else None

@_curdoc.setter
def _curdoc(self, doc: Document | None) -> None:
thread = threading.current_thread()
thread_id = thread.ident if thread else None
if thread_id not in self._curdoc_ and doc is None:
return None
curdocs = self._curdoc_[thread_id]
try:
task = asyncio.current_task()
except Exception:
task = None
key = task or self
if doc is None:
if thread_id in self._curdoc_:
# Do not clean up curdocs for tasks since they may have
# children that are still running
if key in curdocs and task is None:
del curdocs[key]
if not len(curdocs):
del self._curdoc_[thread_id]
else:
self._curdoc_[thread_id] = doc
curdocs[key] = doc

@property
def cookies(self) -> Dict[str, str]:
Expand Down
1 change: 1 addition & 0 deletions panel/tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -313,6 +313,7 @@ def server_cleanup():
state._curdoc = None
state.cache.clear()
state._scheduled.clear()
state._curdoc_.clear()
if state._thread_pool is not None:
state._thread_pool.shutdown(wait=False)
state._thread_pool = None
Expand Down
Loading