Skip to content

Commit

Permalink
[Feature] Introduce new Credential Strategies for Agents (#882)
Browse files Browse the repository at this point in the history
## What changes are proposed in this pull request?

This PR introduces two new credential strategies for Agents,
(AgentEmbeddedCredentials, AgentUserCredentials).

Agents currently use the databricks.sdk in order to interact with
databricks resources. However the authentication method for these
resources is a little unique where we store the token for the
authentication in a Credential File on the Kubernetes Container.
Therefore in the past we added the Model Serving Credential Strategy to
the defaultCredentials list to read this file.

Now we want to introduce a new authentication where the user's token is
instead stored in a thread local variable. Agent users will initialize
clients as follows:

```
from databricks.sdk.credentials_provider import ModelServingUserCredentials

invokers_client = WorkspaceClient(credential_strategy = ModelServingUserCredentials())
definers_client = WorkspaceClient()

```

Then the users can use the invoker_client to interact with resources
with the invokers token or the definers_client to interact with
resources using the old method of authentication.

Additionally as the users will be using these clients to test their code
locally in Databricks Notebooks, if the code is not being run on model
serving environments, users need to be able to authenticate using the
DefaultCredential strategies.

More details:
https://docs.google.com/document/d/14qLVjyxIAk581w287TWElstIeh8-DR30ab9Z6B_Vydg/edit?usp=sharing

## How is this tested?

Added unit tests

---------

Signed-off-by: aravind-segu <[email protected]>
  • Loading branch information
aravind-segu authored Feb 13, 2025
1 parent 3c391a0 commit 41f5f4b
Show file tree
Hide file tree
Showing 2 changed files with 119 additions and 18 deletions.
87 changes: 71 additions & 16 deletions databricks/sdk/credentials_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import platform
import subprocess
import sys
import threading
import time
from datetime import datetime
from typing import Callable, Dict, List, Optional, Tuple, Union
Expand Down Expand Up @@ -723,14 +724,17 @@ def inner() -> Dict[str, str]:
# This Code is derived from Mlflow DatabricksModelServingConfigProvider
# https://github.com/mlflow/mlflow/blob/1219e3ef1aac7d337a618a352cd859b336cf5c81/mlflow/legacy_databricks_cli/configure/provider.py#L332
class ModelServingAuthProvider():
USER_CREDENTIALS = "user_credentials"

_MODEL_DEPENDENCY_OAUTH_TOKEN_FILE_PATH = "/var/credentials-secret/model-dependencies-oauth-token"

def __init__(self):
def __init__(self, credential_type: Optional[str]):
self.expiry_time = -1
self.current_token = None
self.refresh_duration = 300 # 300 Seconds
self.credential_type = credential_type

def should_fetch_model_serving_environment_oauth(self) -> bool:
def should_fetch_model_serving_environment_oauth() -> bool:
"""
Check whether this is the model serving environment
Additionally check if the oauth token file path exists
Expand All @@ -739,15 +743,15 @@ def should_fetch_model_serving_environment_oauth(self) -> bool:
is_in_model_serving_env = (os.environ.get("IS_IN_DB_MODEL_SERVING_ENV")
or os.environ.get("IS_IN_DATABRICKS_MODEL_SERVING_ENV") or "false")
return (is_in_model_serving_env == "true"
and os.path.isfile(self._MODEL_DEPENDENCY_OAUTH_TOKEN_FILE_PATH))
and os.path.isfile(ModelServingAuthProvider._MODEL_DEPENDENCY_OAUTH_TOKEN_FILE_PATH))

def get_model_dependency_oauth_token(self, should_retry=True) -> str:
def _get_model_dependency_oauth_token(self, should_retry=True) -> str:
# Use Cached value if it is valid
if self.current_token is not None and self.expiry_time > time.time():
return self.current_token

try:
with open(self._MODEL_DEPENDENCY_OAUTH_TOKEN_FILE_PATH) as f:
with open(ModelServingAuthProvider._MODEL_DEPENDENCY_OAUTH_TOKEN_FILE_PATH) as f:
oauth_dict = json.load(f)
self.current_token = oauth_dict["OAUTH_TOKEN"][0]["oauthTokenValue"]
self.expiry_time = time.time() + self.refresh_duration
Expand All @@ -757,32 +761,43 @@ def get_model_dependency_oauth_token(self, should_retry=True) -> str:
logger.warning("Unable to read oauth token on first attmept in Model Serving Environment",
exc_info=e)
time.sleep(0.5)
return self.get_model_dependency_oauth_token(should_retry=False)
return self._get_model_dependency_oauth_token(should_retry=False)
else:
raise RuntimeError(
"Unable to read OAuth credentials from the file mounted in Databricks Model Serving"
) from e
return self.current_token

def _get_invokers_token(self):
current_thread = threading.current_thread()
thread_data = current_thread.__dict__
invokers_token = None
if "invokers_token" in thread_data:
invokers_token = thread_data["invokers_token"]

if invokers_token is None:
raise RuntimeError("Unable to read Invokers Token in Databricks Model Serving")

return invokers_token

def get_databricks_host_token(self) -> Optional[Tuple[str, str]]:
if not self.should_fetch_model_serving_environment_oauth():
if not ModelServingAuthProvider.should_fetch_model_serving_environment_oauth():
return None

# read from DB_MODEL_SERVING_HOST_ENV_VAR if available otherwise MODEL_SERVING_HOST_ENV_VAR
host = os.environ.get("DATABRICKS_MODEL_SERVING_HOST_URL") or os.environ.get(
"DB_MODEL_SERVING_HOST_URL")
token = self.get_model_dependency_oauth_token()

return (host, token)
if self.credential_type == ModelServingAuthProvider.USER_CREDENTIALS:
return (host, self._get_invokers_token())
else:
return (host, self._get_model_dependency_oauth_token())


@credentials_strategy('model-serving', [])
def model_serving_auth(cfg: 'Config') -> Optional[CredentialsProvider]:
def model_serving_auth_visitor(cfg: 'Config',
credential_type: Optional[str] = None) -> Optional[CredentialsProvider]:
try:
model_serving_auth_provider = ModelServingAuthProvider()
if not model_serving_auth_provider.should_fetch_model_serving_environment_oauth():
logger.debug("model-serving: Not in Databricks Model Serving, skipping")
return None
model_serving_auth_provider = ModelServingAuthProvider(credential_type)
host, token = model_serving_auth_provider.get_databricks_host_token()
if token is None:
raise ValueError(
Expand All @@ -793,7 +808,6 @@ def model_serving_auth(cfg: 'Config') -> Optional[CredentialsProvider]:
except Exception as e:
logger.warning("Unable to get auth from Databricks Model Serving Environment", exc_info=e)
return None

logger.info("Using Databricks Model Serving Authentication")

def inner() -> Dict[str, str]:
Expand All @@ -804,6 +818,15 @@ def inner() -> Dict[str, str]:
return inner


@credentials_strategy('model-serving', [])
def model_serving_auth(cfg: 'Config') -> Optional[CredentialsProvider]:
if not ModelServingAuthProvider.should_fetch_model_serving_environment_oauth():
logger.debug("model-serving: Not in Databricks Model Serving, skipping")
return None

return model_serving_auth_visitor(cfg)


class DefaultCredentials:
""" Select the first applicable credential provider from the chain """

Expand Down Expand Up @@ -846,3 +869,35 @@ def __call__(self, cfg: 'Config') -> CredentialsProvider:
raise ValueError(
f'cannot configure default credentials, please check {auth_flow_url} to configure credentials for your preferred authentication method.'
)


class ModelServingUserCredentials(CredentialsStrategy):
"""
This credential strategy is designed for authenticating the Databricks SDK in the model serving environment using user-specific rights.
In the model serving environment, the strategy retrieves a downscoped user token from the thread-local variable.
In any other environments, the class defaults to the DefaultCredentialStrategy.
To use this credential strategy, instantiate the WorkspaceClient with the ModelServingUserCredentials strategy as follows:
invokers_client = WorkspaceClient(credential_strategy = ModelServingUserCredentials())
"""

def __init__(self):
self.credential_type = ModelServingAuthProvider.USER_CREDENTIALS
self.default_credentials = DefaultCredentials()

def auth_type(self):
if ModelServingAuthProvider.should_fetch_model_serving_environment_oauth():
return "model_serving_" + self.credential_type
else:
return self.default_credentials.auth_type()

def __call__(self, cfg: 'Config') -> CredentialsProvider:
if ModelServingAuthProvider.should_fetch_model_serving_environment_oauth():
header_factory = model_serving_auth_visitor(cfg, self.credential_type)
if not header_factory:
raise ValueError(
f"Unable to authenticate using {self.credential_type} in Databricks Model Serving Environment"
)
return header_factory
else:
return self.default_credentials(cfg)
50 changes: 48 additions & 2 deletions tests/test_model_serving_auth.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
import threading
import time

import pytest

from databricks.sdk.core import Config
from databricks.sdk.credentials_provider import ModelServingUserCredentials

from .conftest import raises

Expand Down Expand Up @@ -39,7 +41,6 @@ def test_model_serving_auth(env_values, del_env_values, oauth_file_name, monkeyp
mocker.patch('databricks.sdk.config.Config._known_file_config_loader')

cfg = Config()

assert cfg.auth_type == 'model-serving'
headers = cfg.authenticate()
assert (cfg.host == 'x')
Expand Down Expand Up @@ -93,7 +94,6 @@ def test_model_serving_auth_refresh(monkeypatch, mocker):
assert (cfg.host == 'x')
assert headers.get(
"Authorization") == 'Bearer databricks_sdk_unit_test_token' # Token defined in the test file

# Simulate refreshing the token by patching to to a new file
monkeypatch.setattr(
"databricks.sdk.credentials_provider.ModelServingAuthProvider._MODEL_DEPENDENCY_OAUTH_TOKEN_FILE_PATH",
Expand All @@ -113,3 +113,49 @@ def test_model_serving_auth_refresh(monkeypatch, mocker):
assert (cfg.host == 'x')
# Read V2 now
assert headers.get("Authorization") == 'Bearer databricks_sdk_unit_test_token_v2'


def test_agent_user_credentials(monkeypatch, mocker):
monkeypatch.setenv('IS_IN_DB_MODEL_SERVING_ENV', 'true')
monkeypatch.setenv('DB_MODEL_SERVING_HOST_URL', 'x')
monkeypatch.setattr(
"databricks.sdk.credentials_provider.ModelServingAuthProvider._MODEL_DEPENDENCY_OAUTH_TOKEN_FILE_PATH",
"tests/testdata/model-serving-test-token")

invokers_token_val = "databricks_invokers_token"
current_thread = threading.current_thread()
thread_data = current_thread.__dict__
thread_data["invokers_token"] = invokers_token_val

cfg = Config(credentials_strategy=ModelServingUserCredentials())
assert cfg.auth_type == 'model_serving_user_credentials'

headers = cfg.authenticate()

assert (cfg.host == 'x')
assert headers.get("Authorization") == f'Bearer {invokers_token_val}'

# Test updates of invokers token
invokers_token_val = "databricks_invokers_token_v2"
current_thread = threading.current_thread()
thread_data = current_thread.__dict__
thread_data["invokers_token"] = invokers_token_val

headers = cfg.authenticate()
assert (cfg.host == 'x')
assert headers.get("Authorization") == f'Bearer {invokers_token_val}'


# If this credential strategy is being used in a non model serving environments then use default credential strategy instead
def test_agent_user_credentials_in_non_model_serving_environments(monkeypatch):

monkeypatch.setenv('DATABRICKS_HOST', 'x')
monkeypatch.setenv('DATABRICKS_TOKEN', 'token')

cfg = Config(credentials_strategy=ModelServingUserCredentials())
assert cfg.auth_type == 'pat' # Auth type is PAT as it is no longer in a model serving environment

headers = cfg.authenticate()

assert (cfg.host == 'https://x')
assert headers.get("Authorization") == f'Bearer token'

0 comments on commit 41f5f4b

Please sign in to comment.