Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Ensure refreshed tokens can be accessed across processes #6817

Merged
merged 6 commits into from
May 9, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions .github/workflows/test.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -145,7 +145,7 @@ jobs:
with:
name: unit_test_suite
python-version: ${{ matrix.python-version }}
channels: pyviz/label/dev,numba,bokeh/label/dev,conda-forge,nodefaults
channels: pyviz/label/dev,numba,conda-forge,nodefaults
conda-update: true
nodejs: true
nodejs-version: "20.9" # https://github.com/bokeh/bokeh/pull/13851
Expand Down Expand Up @@ -233,7 +233,7 @@ jobs:
with:
name: ui_test_suite
python-version: 3.9
channels: pyviz/label/dev,bokeh/label/dev,conda-forge,nodefaults
channels: pyviz/label/dev,conda-forge,nodefaults
envs: "-o recommended -o tests -o build"
cache: ${{ github.event.inputs.cache || github.event.inputs.cache == '' }}
nodejs: true
Expand Down
72 changes: 51 additions & 21 deletions panel/auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
import tornado

from bokeh.server.auth_provider import AuthProvider
from bokeh.util.token import get_token_payload
from tornado.auth import OAuth2Mixin
from tornado.httpclient import HTTPError as HTTPClientError, HTTPRequest
from tornado.web import HTTPError, RequestHandler, decode_signed_value
Expand Down Expand Up @@ -413,7 +414,7 @@ def set_auth_cookies(handler, id_token, access_token, refresh_token=None, expire
type(handler).__name__, user_key)
raise HTTPError(401, "OAuth token payload missing user information")
handler.clear_cookie('is_guest')
handler.set_secure_cookie('user', user, expires_days=config.oauth_expiry)
handler.set_secure_cookie('user', user, expires_days=config.oauth_expiry, httponly=True)
else:
user = None

Expand All @@ -423,14 +424,14 @@ def set_auth_cookies(handler, id_token, access_token, refresh_token=None, expire
id_token = state.encryption.encrypt(id_token.encode('utf-8'))
if refresh_token:
refresh_token = state.encryption.encrypt(refresh_token.encode('utf-8'))
handler.set_secure_cookie('access_token', access_token, expires_days=config.oauth_expiry)
handler.set_secure_cookie('access_token', access_token, expires_days=config.oauth_expiry, httponly=True)
if id_token:
handler.set_secure_cookie('id_token', id_token, expires_days=config.oauth_expiry)
handler.set_secure_cookie('id_token', id_token, expires_days=config.oauth_expiry, httponly=True)
if expires_in:
now_ts = dt.datetime.now(dt.timezone.utc).timestamp()
handler.set_secure_cookie('oauth_expiry', str(int(now_ts + expires_in)), expires_days=config.oauth_expiry)
handler.set_secure_cookie('oauth_expiry', str(int(now_ts + expires_in)), expires_days=config.oauth_expiry, httponly=True)
if refresh_token:
handler.set_secure_cookie('refresh_token', refresh_token, expires_days=config.oauth_expiry)
handler.set_secure_cookie('refresh_token', refresh_token, expires_days=config.oauth_expiry, httponly=True)
if user and user in state._oauth_user_overrides:
state._oauth_user_overrides.pop(user, None)
return user
Expand Down Expand Up @@ -848,11 +849,11 @@ def set_current_user(self, user):
self.clear_cookie("user")
return
self.clear_cookie("is_guest")
self.set_secure_cookie("user", user, expires_days=config.oauth_expiry)
self.set_secure_cookie("user", user, expires_days=config.oauth_expiry, httponly=True)
id_token = base64url_encode(json.dumps({'user': user}))
if state.encryption:
id_token = state.encryption.encrypt(id_token.encode('utf-8'))
self.set_secure_cookie('id_token', id_token, expires_days=config.oauth_expiry)
self.set_secure_cookie('id_token', id_token, expires_days=config.oauth_expiry, httponly=True)


class LogoutHandler(tornado.web.RequestHandler):
Expand Down Expand Up @@ -987,6 +988,20 @@ async def get_user(handler):
if not config.oauth_refresh_tokens or user is None:
return user

# Try to obtain user oauth overrides from WS headers
# in case the HTTP handler refreshed tokens
is_ws = isinstance(handler, WebSocketHandler)
if is_ws and 'Sec-Websocket-Protocol' in handler.request.headers:
protocol_header = handler.request.headers['Sec-Websocket-Protocol']
_, token = protocol_header.split(', ')
payload = get_token_payload(token)
if 'user_data' in payload:
user_data = payload['user_data']
if state.encryption:
user_data = state.encryption.decrypt(user_data).decode('utf-8')
user_data = json.loads(user_data)
state._oauth_user_overrides[user] = user_data

now_ts = dt.datetime.now(dt.timezone.utc).timestamp()
expiry = None
if user in state._oauth_user_overrides:
Expand All @@ -1003,16 +1018,20 @@ async def get_user(handler):
return
access_token = state._decrypt_cookie(access_cookie)

# Try to get expiry directly from the token since that is
# the real source of truth
try:
access_json = decode_token(access_token)
expiry = access_json['exp']
except Exception:
pass

if expiry is None:
try:
access_json = decode_token(access_token)
expiry = access_json['exp']
except Exception:
expiry = handler.get_secure_cookie('oauth_expiry', max_age_days=config.oauth_expiry)
if expiry is None:
# Token does not have content and therefore does not expire
log.debug("access_token is not a valid JWT token. Expiry cannot be determined.")
return user
expiry = handler.get_secure_cookie('oauth_expiry', max_age_days=config.oauth_expiry)
if expiry is None:
# Token does not have content and therefore does not expire
log.debug("access_token is not a valid JWT token. Expiry cannot be determined.")
return user

if user in state._oauth_user_overrides:
refresh_token = state._oauth_user_overrides[user]['refresh_token']
Expand All @@ -1025,7 +1044,8 @@ async def get_user(handler):

if expiry > now_ts and refresh_token:
log.debug("Fully authenticated and tokens still valid.")
self._schedule_refresh(expiry, user, refresh_token, handler.application, handler.request)
if is_ws:
self._schedule_refresh(expiry, user, refresh_token, handler.application, handler.request)
expires_in = expiry - now_ts
OAuthLoginHandler.set_auth_cookies(
handler, None, access_token, refresh_token, expires_in
Expand All @@ -1047,8 +1067,13 @@ async def get_user(handler):

log.debug("access_token has expired, %s using refresh_token to obtain new tokens.", type(self).__name__)
access_token, refresh_token, expiry = await self._scheduled_refresh(
user, refresh_token, handler.application, handler.request
user, refresh_token, handler.application, handler.request,
reschedule=is_ws
)
# If user not in overrides refresh failed and we need to
# fully reauthenticate
if user not in state._oauth_user_overrides:
return
expires_in = expiry - now_ts
OAuthLoginHandler.set_auth_cookies(
handler, None, access_token, refresh_token, expires_in
Expand Down Expand Up @@ -1106,15 +1131,18 @@ def _schedule_refresh(self, expiry_ts, user, refresh_token, application, request
finally:
state.schedule_task(task, refresh_cb, at=expiry_date)

async def _scheduled_refresh(self, user, refresh_token, application, request):
async def _scheduled_refresh(self, user, refresh_token, application, request, reschedule=True):
await self._refresh_access_token(user, refresh_token, application, request)
if user not in state._oauth_user_overrides:
return None, None, None
user_state = state._oauth_user_overrides[user]
access_token, refresh_token = user_state['access_token'], user_state['refresh_token']
if user_state['expiry']:
expiry = user_state['expiry']
else:
expiry = decode_token(access_token)['exp']
self._schedule_refresh(expiry, user, refresh_token, application, request)
if reschedule:
self._schedule_refresh(expiry, user, refresh_token, application, request)
return access_token, refresh_token, expiry

async def _refresh_access_token(self, user, refresh_token, application, request):
Expand All @@ -1126,7 +1154,7 @@ async def _refresh_access_token(self, user, refresh_token, application, request)
return
else:
refresh_token = state._oauth_user_overrides[user]['refresh_token']
log.debug("%s refreshing token", type(self).__name__)
log.debug("%s refreshing tokens", type(self).__name__)
state._oauth_user_overrides[user] = {}
auth_handler = self.login_handler(application=application, request=request)
_, access_token, refresh_token, expires_in = await auth_handler._fetch_access_token(
Expand All @@ -1135,13 +1163,15 @@ async def _refresh_access_token(self, user, refresh_token, application, request)
refresh_token=refresh_token
)
if access_token:
log.debug("%s successfully refreshed access_token", type(self).__name__)
now_ts = dt.datetime.now(dt.timezone.utc).timestamp()
state._oauth_user_overrides[user] = {
'access_token': access_token,
'refresh_token': refresh_token,
'expiry': now_ts+expires_in if expires_in else None
}
else:
log.debug("%s failed to refresh access_token", type(self).__name__)
del state._oauth_user_overrides[user]


Expand Down
26 changes: 25 additions & 1 deletion panel/io/application.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,12 @@
"""
from __future__ import annotations

import json
import logging
import os

from functools import partial
from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, Any

import bokeh.command.util

Expand Down Expand Up @@ -84,6 +85,29 @@ def _log_session_destroyed(session_context):
doc.on_event('document_ready', partial(state._schedule_on_load, doc))
doc.on_session_destroyed(_log_session_destroyed)

def process_request(self, request) -> dict[str, Any]:
''' Processes incoming HTTP request returning a dictionary of
additional data to add to the session_context.
Args:
request: HTTP request
Returns:
A dictionary of JSON serializable data to be included on
the session context.
'''
request_data = super().process_request(request)
user = request.cookies.get('user')
if user:
from tornado.web import decode_signed_value
user = decode_signed_value(config.cookie_secret, 'user', user.value).decode('utf-8')
if user in state._oauth_user_overrides:
user_data = json.dumps(state._oauth_user_overrides[user])
if state.encryption:
user_data = state.encryption.encrypt(user_data.encode('utf-8'))
request_data['user_data'] = user_data
return request_data

bokeh.command.util.Application = Application # type: ignore


Expand Down
Loading