Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Feature] Introduce new Credential Strategies for Agents #882

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
87 changes: 71 additions & 16 deletions databricks/sdk/credentials_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import platform
import subprocess
import sys
import threading
import time
from datetime import datetime
from typing import Callable, Dict, List, Optional, Tuple, Union
Expand Down Expand Up @@ -723,14 +724,17 @@ def inner() -> Dict[str, str]:
# This Code is derived from Mlflow DatabricksModelServingConfigProvider
# https://github.com/mlflow/mlflow/blob/1219e3ef1aac7d337a618a352cd859b336cf5c81/mlflow/legacy_databricks_cli/configure/provider.py#L332
class ModelServingAuthProvider():
USER_CREDENTIALS = "user_credentials"

_MODEL_DEPENDENCY_OAUTH_TOKEN_FILE_PATH = "/var/credentials-secret/model-dependencies-oauth-token"

def __init__(self):
def __init__(self, credential_type: Optional[str]):
self.expiry_time = -1
self.current_token = None
self.refresh_duration = 300 # 300 Seconds
self.credential_type = credential_type

def should_fetch_model_serving_environment_oauth(self) -> bool:
def should_fetch_model_serving_environment_oauth() -> bool:
"""
Check whether this is the model serving environment
Additionally check if the oauth token file path exists
Expand All @@ -739,15 +743,15 @@ def should_fetch_model_serving_environment_oauth(self) -> bool:
is_in_model_serving_env = (os.environ.get("IS_IN_DB_MODEL_SERVING_ENV")
or os.environ.get("IS_IN_DATABRICKS_MODEL_SERVING_ENV") or "false")
return (is_in_model_serving_env == "true"
and os.path.isfile(self._MODEL_DEPENDENCY_OAUTH_TOKEN_FILE_PATH))
and os.path.isfile(ModelServingAuthProvider._MODEL_DEPENDENCY_OAUTH_TOKEN_FILE_PATH))

def get_model_dependency_oauth_token(self, should_retry=True) -> str:
def _get_model_dependency_oauth_token(self, should_retry=True) -> str:
# Use Cached value if it is valid
if self.current_token is not None and self.expiry_time > time.time():
return self.current_token

try:
with open(self._MODEL_DEPENDENCY_OAUTH_TOKEN_FILE_PATH) as f:
with open(ModelServingAuthProvider._MODEL_DEPENDENCY_OAUTH_TOKEN_FILE_PATH) as f:
oauth_dict = json.load(f)
self.current_token = oauth_dict["OAUTH_TOKEN"][0]["oauthTokenValue"]
self.expiry_time = time.time() + self.refresh_duration
Expand All @@ -757,32 +761,43 @@ def get_model_dependency_oauth_token(self, should_retry=True) -> str:
logger.warning("Unable to read oauth token on first attmept in Model Serving Environment",
exc_info=e)
time.sleep(0.5)
return self.get_model_dependency_oauth_token(should_retry=False)
return self._get_model_dependency_oauth_token(should_retry=False)
else:
raise RuntimeError(
"Unable to read OAuth credentials from the file mounted in Databricks Model Serving"
) from e
return self.current_token

def _get_invokers_token(self):
current_thread = threading.current_thread()
thread_data = current_thread.__dict__
invokers_token = None
if "invokers_token" in thread_data:
invokers_token = thread_data["invokers_token"]

if invokers_token is None:
raise RuntimeError("Unable to read Invokers Token in Databricks Model Serving")

return invokers_token

def get_databricks_host_token(self) -> Optional[Tuple[str, str]]:
if not self.should_fetch_model_serving_environment_oauth():
if not ModelServingAuthProvider.should_fetch_model_serving_environment_oauth():
return None

# read from DB_MODEL_SERVING_HOST_ENV_VAR if available otherwise MODEL_SERVING_HOST_ENV_VAR
host = os.environ.get("DATABRICKS_MODEL_SERVING_HOST_URL") or os.environ.get(
"DB_MODEL_SERVING_HOST_URL")
token = self.get_model_dependency_oauth_token()

return (host, token)
if self.credential_type == ModelServingAuthProvider.USER_CREDENTIALS:
return (host, self._get_invokers_token())
else:
return (host, self._get_model_dependency_oauth_token())


@credentials_strategy('model-serving', [])
def model_serving_auth(cfg: 'Config') -> Optional[CredentialsProvider]:
def model_serving_auth_visitor(cfg: 'Config',
credential_type: Optional[str] = None) -> Optional[CredentialsProvider]:
try:
model_serving_auth_provider = ModelServingAuthProvider()
if not model_serving_auth_provider.should_fetch_model_serving_environment_oauth():
logger.debug("model-serving: Not in Databricks Model Serving, skipping")
return None
model_serving_auth_provider = ModelServingAuthProvider(credential_type)
host, token = model_serving_auth_provider.get_databricks_host_token()
if token is None:
raise ValueError(
Expand All @@ -793,7 +808,6 @@ def model_serving_auth(cfg: 'Config') -> Optional[CredentialsProvider]:
except Exception as e:
logger.warning("Unable to get auth from Databricks Model Serving Environment", exc_info=e)
return None

logger.info("Using Databricks Model Serving Authentication")

def inner() -> Dict[str, str]:
Expand All @@ -804,6 +818,15 @@ def inner() -> Dict[str, str]:
return inner


@credentials_strategy('model-serving', [])
def model_serving_auth(cfg: 'Config') -> Optional[CredentialsProvider]:
if not ModelServingAuthProvider.should_fetch_model_serving_environment_oauth():
logger.debug("model-serving: Not in Databricks Model Serving, skipping")
return None

return model_serving_auth_visitor(cfg)


class DefaultCredentials:
""" Select the first applicable credential provider from the chain """

Expand Down Expand Up @@ -846,3 +869,35 @@ def __call__(self, cfg: 'Config') -> CredentialsProvider:
raise ValueError(
f'cannot configure default credentials, please check {auth_flow_url} to configure credentials for your preferred authentication method.'
)


class ModelServingUserCredentials(CredentialsStrategy):
"""
This credential strategy is designed for authenticating the Databricks SDK in the model serving environment using user-specific rights.
In the model serving environment, the strategy retrieves a downscoped user token from the thread-local variable.
In any other environments, the class defaults to the DefaultCredentialStrategy.
To use this credential strategy, instantiate the WorkspaceClient with the ModelServingUserCredentials strategy as follows:

invokers_client = WorkspaceClient(credential_strategy = ModelServingUserCredentials())
"""

def __init__(self):
self.credential_type = ModelServingAuthProvider.USER_CREDENTIALS
self.default_credentials = DefaultCredentials()

def auth_type(self):
if ModelServingAuthProvider.should_fetch_model_serving_environment_oauth():
return "model_serving_" + self.credential_type
else:
return self.default_credentials.auth_type()

def __call__(self, cfg: 'Config') -> CredentialsProvider:
if ModelServingAuthProvider.should_fetch_model_serving_environment_oauth():
header_factory = model_serving_auth_visitor(cfg, self.credential_type)
if not header_factory:
raise ValueError(
f"Unable to authenticate using {self.credential_type} in Databricks Model Serving Environment"
)
return header_factory
else:
return self.default_credentials(cfg)
50 changes: 48 additions & 2 deletions tests/test_model_serving_auth.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
import threading
import time

import pytest

from databricks.sdk.core import Config
from databricks.sdk.credentials_provider import ModelServingUserCredentials

from .conftest import raises

Expand Down Expand Up @@ -39,7 +41,6 @@ def test_model_serving_auth(env_values, del_env_values, oauth_file_name, monkeyp
mocker.patch('databricks.sdk.config.Config._known_file_config_loader')

cfg = Config()

assert cfg.auth_type == 'model-serving'
headers = cfg.authenticate()
assert (cfg.host == 'x')
Expand Down Expand Up @@ -93,7 +94,6 @@ def test_model_serving_auth_refresh(monkeypatch, mocker):
assert (cfg.host == 'x')
assert headers.get(
"Authorization") == 'Bearer databricks_sdk_unit_test_token' # Token defined in the test file

# Simulate refreshing the token by patching to to a new file
monkeypatch.setattr(
"databricks.sdk.credentials_provider.ModelServingAuthProvider._MODEL_DEPENDENCY_OAUTH_TOKEN_FILE_PATH",
Expand All @@ -113,3 +113,49 @@ def test_model_serving_auth_refresh(monkeypatch, mocker):
assert (cfg.host == 'x')
# Read V2 now
assert headers.get("Authorization") == 'Bearer databricks_sdk_unit_test_token_v2'


def test_agent_user_credentials(monkeypatch, mocker):
monkeypatch.setenv('IS_IN_DB_MODEL_SERVING_ENV', 'true')
monkeypatch.setenv('DB_MODEL_SERVING_HOST_URL', 'x')
monkeypatch.setattr(
"databricks.sdk.credentials_provider.ModelServingAuthProvider._MODEL_DEPENDENCY_OAUTH_TOKEN_FILE_PATH",
"tests/testdata/model-serving-test-token")

invokers_token_val = "databricks_invokers_token"
current_thread = threading.current_thread()
thread_data = current_thread.__dict__
thread_data["invokers_token"] = invokers_token_val

cfg = Config(credentials_strategy=ModelServingUserCredentials())
assert cfg.auth_type == 'model_serving_user_credentials'

headers = cfg.authenticate()

assert (cfg.host == 'x')
assert headers.get("Authorization") == f'Bearer {invokers_token_val}'

# Test updates of invokers token
invokers_token_val = "databricks_invokers_token_v2"
current_thread = threading.current_thread()
thread_data = current_thread.__dict__
thread_data["invokers_token"] = invokers_token_val

headers = cfg.authenticate()
assert (cfg.host == 'x')
assert headers.get("Authorization") == f'Bearer {invokers_token_val}'


# If this credential strategy is being used in a non model serving environments then use default credential strategy instead
def test_agent_user_credentials_in_non_model_serving_environments(monkeypatch):

monkeypatch.setenv('DATABRICKS_HOST', 'x')
monkeypatch.setenv('DATABRICKS_TOKEN', 'token')

cfg = Config(credentials_strategy=ModelServingUserCredentials())
assert cfg.auth_type == 'pat' # Auth type is PAT as it is no longer in a model serving environment

headers = cfg.authenticate()

assert (cfg.host == 'https://x')
assert headers.get("Authorization") == f'Bearer token'
Loading