From f2359e3f9cd781055bf39fb775c7009b375333a0 Mon Sep 17 00:00:00 2001 From: Xi Yan Date: Fri, 27 Dec 2024 15:45:44 -0800 Subject: [PATCH] [remove import *] clean up import *'s (#689) # What does this PR do? - as title, cleaning up `import *`'s - upgrade tests to make them more robust to bad model outputs - remove import *'s in llama_stack/apis/* (skip __init__ modules) image - run `sh run_openapi_generator.sh`, no types gets affected ## Test Plan ### Providers Tests **agents** ``` pytest -v -s llama_stack/providers/tests/agents/test_agents.py -m "together" --safety-shield meta-llama/Llama-Guard-3-8B --inference-model meta-llama/Llama-3.1-405B-Instruct-FP8 ``` **inference** ```bash # meta-reference torchrun $CONDA_PREFIX/bin/pytest -v -s -k "meta_reference" --inference-model="meta-llama/Llama-3.1-8B-Instruct" ./llama_stack/providers/tests/inference/test_text_inference.py torchrun $CONDA_PREFIX/bin/pytest -v -s -k "meta_reference" --inference-model="meta-llama/Llama-3.2-11B-Vision-Instruct" ./llama_stack/providers/tests/inference/test_vision_inference.py # together pytest -v -s -k "together" --inference-model="meta-llama/Llama-3.1-8B-Instruct" ./llama_stack/providers/tests/inference/test_text_inference.py pytest -v -s -k "together" --inference-model="meta-llama/Llama-3.2-11B-Vision-Instruct" ./llama_stack/providers/tests/inference/test_vision_inference.py pytest ./llama_stack/providers/tests/inference/test_prompt_adapter.py ``` **safety** ``` pytest -v -s llama_stack/providers/tests/safety/test_safety.py -m together --safety-shield meta-llama/Llama-Guard-3-8B ``` **memory** ``` pytest -v -s llama_stack/providers/tests/memory/test_memory.py -m "sentence_transformers" --env EMBEDDING_DIMENSION=384 ``` **scoring** ``` pytest -v -s -m llm_as_judge_scoring_together_inference llama_stack/providers/tests/scoring/test_scoring.py --judge-model meta-llama/Llama-3.2-3B-Instruct pytest -v -s -m basic_scoring_together_inference llama_stack/providers/tests/scoring/test_scoring.py pytest -v -s -m braintrust_scoring_together_inference llama_stack/providers/tests/scoring/test_scoring.py ``` **datasetio** ``` pytest -v -s -m localfs llama_stack/providers/tests/datasetio/test_datasetio.py pytest -v -s -m huggingface llama_stack/providers/tests/datasetio/test_datasetio.py ``` **eval** ``` pytest -v -s -m meta_reference_eval_together_inference llama_stack/providers/tests/eval/test_eval.py pytest -v -s -m meta_reference_eval_together_inference_huggingface_datasetio llama_stack/providers/tests/eval/test_eval.py ``` ### Client-SDK Tests ``` LLAMA_STACK_BASE_URL=http://localhost:5000 pytest -v ./tests/client-sdk ``` ### llama-stack-apps ``` PORT=5000 LOCALHOST=localhost python -m examples.agents.hello $LOCALHOST $PORT python -m examples.agents.inflation $LOCALHOST $PORT python -m examples.agents.podcast_transcript $LOCALHOST $PORT python -m examples.agents.rag_as_attachments $LOCALHOST $PORT python -m examples.agents.rag_with_memory_bank $LOCALHOST $PORT python -m examples.safety.llama_guard_demo_mm $LOCALHOST $PORT python -m examples.agents.e2e_loop_with_custom_tools $LOCALHOST $PORT # Vision model python -m examples.interior_design_assistant.app python -m examples.agent_store.app $LOCALHOST $PORT ``` ### CLI ``` which llama llama model prompt-format -m Llama3.2-11B-Vision-Instruct llama model list llama stack list-apis llama stack list-providers inference llama stack build --template ollama --image-type conda ``` ### Distributions Tests **ollama** ``` llama stack build --template ollama --image-type conda ollama run llama3.2:1b-instruct-fp16 llama stack run ./llama_stack/templates/ollama/run.yaml --env INFERENCE_MODEL=meta-llama/Llama-3.2-1B-Instruct ``` **fireworks** ``` llama stack build --template fireworks --image-type conda llama stack run ./llama_stack/templates/fireworks/run.yaml ``` **together** ``` llama stack build --template together --image-type conda llama stack run ./llama_stack/templates/together/run.yaml ``` **tgi** ``` llama stack run ./llama_stack/templates/tgi/run.yaml --env TGI_URL=http://0.0.0.0:5009 --env INFERENCE_MODEL=meta-llama/Llama-3.1-8B-Instruct ``` ## Sources Please link relevant resources if necessary. ## Before submitting - [ ] This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case). - [ ] Ran pre-commit to handle lint / formatting issues. - [ ] Read the [contributor guideline](https://github.com/meta-llama/llama-stack/blob/main/CONTRIBUTING.md), Pull Request section? - [ ] Updated relevant documentation. - [ ] Wrote necessary unit or integration tests. --- docs/zero_to_hero_guide/06_Safety101.ipynb | 4 +- llama_stack/apis/agents/agents.py | 24 ++++++-- llama_stack/apis/agents/event_logger.py | 5 +- .../apis/batch_inference/batch_inference.py | 12 +++- llama_stack/apis/datasetio/datasetio.py | 2 +- llama_stack/apis/eval/eval.py | 12 ++-- llama_stack/apis/inference/inference.py | 5 +- .../apis/post_training/post_training.py | 8 +-- llama_stack/apis/scoring/scoring.py | 5 +- .../synthetic_data_generation.py | 3 +- llama_stack/cli/model/safety_models.py | 7 ++- llama_stack/cli/stack/build.py | 15 +++-- llama_stack/distribution/build.py | 11 ++-- llama_stack/distribution/configure.py | 15 ++--- llama_stack/distribution/datatypes.py | 16 ++--- llama_stack/distribution/inspect.py | 6 +- llama_stack/distribution/resolver.py | 30 +++++++-- llama_stack/distribution/routers/__init__.py | 6 +- llama_stack/distribution/routers/routers.py | 43 ++++++++++--- .../distribution/routers/routing_tables.py | 39 +++++++++--- llama_stack/distribution/server/server.py | 17 +++--- llama_stack/distribution/stack.py | 39 ++++++------ llama_stack/distribution/store/registry.py | 7 +-- .../distribution/store/tests/test_registry.py | 7 ++- .../agents/meta_reference/agent_instance.py | 61 ++++++++++++++++--- .../inline/agents/meta_reference/agents.py | 17 +++++- .../agents/meta_reference/persistence.py | 4 +- .../meta_reference/rag/context_retriever.py | 4 +- .../inline/agents/meta_reference/safety.py | 4 +- .../meta_reference/tests/test_chat_agent.py | 24 ++++++-- .../agents/meta_reference/tools/safety.py | 2 +- .../inline/datasetio/localfs/config.py | 2 +- .../inline/datasetio/localfs/datasetio.py | 13 ++-- .../inline/eval/meta_reference/eval.py | 13 ++-- .../inline/inference/meta_reference/config.py | 5 +- .../inference/meta_reference/generation.py | 18 +++--- .../providers/inline/inference/vllm/vllm.py | 25 ++++++-- .../providers/inline/memory/faiss/faiss.py | 11 ++-- .../post_training/torchtune/common/utils.py | 5 +- .../post_training/torchtune/post_training.py | 17 +++++- .../recipes/lora_finetuning_single_device.py | 26 +++++--- .../safety/code_scanner/code_scanner.py | 8 ++- .../inline/safety/llama_guard/llama_guard.py | 20 +++++- .../safety/prompt_guard/prompt_guard.py | 13 ++-- .../providers/inline/scoring/basic/scoring.py | 19 +++--- .../inline/scoring/braintrust/braintrust.py | 21 ++++--- .../inline/scoring/braintrust/config.py | 4 +- .../telemetry/meta_reference/telemetry.py | 20 ++++-- .../inline/telemetry/sample/sample.py | 4 +- llama_stack/providers/registry/agents.py | 8 ++- llama_stack/providers/registry/datasetio.py | 8 ++- llama_stack/providers/registry/eval.py | 2 +- llama_stack/providers/registry/inference.py | 9 ++- llama_stack/providers/registry/memory.py | 9 ++- .../providers/registry/post_training.py | 2 +- llama_stack/providers/registry/safety.py | 2 +- llama_stack/providers/registry/scoring.py | 2 +- llama_stack/providers/registry/telemetry.py | 8 ++- .../providers/registry/tool_runtime.py | 2 +- .../providers/remote/agents/sample/sample.py | 4 +- .../datasetio/huggingface/huggingface.py | 6 +- .../remote/inference/bedrock/bedrock.py | 25 ++++++-- .../remote/inference/cerebras/cerebras.py | 22 +++++-- .../remote/inference/databricks/databricks.py | 17 +++++- .../remote/inference/fireworks/fireworks.py | 19 +++++- .../remote/inference/ollama/ollama.py | 28 +++++++-- .../remote/inference/sample/sample.py | 5 +- .../providers/remote/inference/tgi/tgi.py | 21 ++++++- .../remote/inference/together/together.py | 19 +++++- .../providers/remote/inference/vllm/vllm.py | 22 ++++++- .../providers/remote/memory/chroma/chroma.py | 10 ++- .../remote/memory/pgvector/pgvector.py | 12 +++- .../providers/remote/memory/qdrant/qdrant.py | 13 ++-- .../providers/remote/memory/sample/sample.py | 5 +- .../remote/memory/weaviate/weaviate.py | 10 ++- .../remote/safety/bedrock/bedrock.py | 11 +++- .../providers/remote/safety/sample/sample.py | 5 +- .../providers/tests/agents/test_agents.py | 26 +++++++- .../tests/agents/test_persistence.py | 6 +- .../tests/datasetio/test_datasetio.py | 13 ++-- llama_stack/providers/tests/eval/test_eval.py | 4 +- .../tests/inference/test_prompt_adapter.py | 20 +++--- .../tests/inference/test_text_inference.py | 29 +++++++-- .../tests/inference/test_vision_inference.py | 11 +++- .../providers/tests/memory/fixtures.py | 5 +- .../providers/tests/memory/test_memory.py | 12 ++-- .../providers/tests/post_training/fixtures.py | 3 +- .../tests/post_training/test_post_training.py | 15 ++++- llama_stack/providers/tests/resolver.py | 14 ++++- .../providers/tests/safety/test_safety.py | 6 +- .../providers/tests/scoring/test_scoring.py | 2 +- .../utils/inference/openai_compat.py | 19 ++++-- .../providers/utils/kvstore/kvstore.py | 6 +- .../providers/utils/kvstore/redis/redis.py | 2 +- .../providers/utils/kvstore/sqlite/sqlite.py | 2 +- .../providers/utils/memory/vector_store.py | 13 ++-- .../utils/scoring/aggregation_utils.py | 3 +- .../providers/utils/telemetry/tracing.py | 16 +++-- tests/client-sdk/agents/test_agents.py | 43 ++++++++----- 99 files changed, 911 insertions(+), 363 deletions(-) diff --git a/docs/zero_to_hero_guide/06_Safety101.ipynb b/docs/zero_to_hero_guide/06_Safety101.ipynb index 6b5bd53bf18..e2ba5e22e21 100644 --- a/docs/zero_to_hero_guide/06_Safety101.ipynb +++ b/docs/zero_to_hero_guide/06_Safety101.ipynb @@ -67,7 +67,7 @@ "from termcolor import cprint\n", "\n", "from llama_stack.distribution.datatypes import RemoteProviderConfig\n", - "from llama_stack.apis.safety import * # noqa: F403\n", + "from llama_stack.apis.safety import Safety\n", "from llama_stack_client import LlamaStackClient\n", "\n", "\n", @@ -127,7 +127,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.10.15" + "version": "3.11.10" } }, "nbformat": 4, diff --git a/llama_stack/apis/agents/agents.py b/llama_stack/apis/agents/agents.py index 5fd90ae7ad8..5748b4e41b2 100644 --- a/llama_stack/apis/agents/agents.py +++ b/llama_stack/apis/agents/agents.py @@ -18,18 +18,30 @@ Union, ) +from llama_models.llama3.api.datatypes import ToolParamDefinition + from llama_models.schema_utils import json_schema_type, webmethod from pydantic import BaseModel, ConfigDict, Field from typing_extensions import Annotated -from llama_stack.providers.utils.telemetry.trace_protocol import trace_protocol -from llama_models.llama3.api.datatypes import * # noqa: F403 -from llama_stack.apis.common.deployment_types import * # noqa: F403 -from llama_stack.apis.inference import * # noqa: F403 -from llama_stack.apis.safety import * # noqa: F403 -from llama_stack.apis.memory import * # noqa: F403 from llama_stack.apis.common.content_types import InterleavedContent, URL +from llama_stack.apis.common.deployment_types import RestAPIExecutionConfig +from llama_stack.apis.inference import ( + CompletionMessage, + SamplingParams, + ToolCall, + ToolCallDelta, + ToolChoice, + ToolPromptFormat, + ToolResponse, + ToolResponseMessage, + UserMessage, +) +from llama_stack.apis.memory import MemoryBank +from llama_stack.apis.safety import SafetyViolation + +from llama_stack.providers.utils.telemetry.trace_protocol import trace_protocol @json_schema_type diff --git a/llama_stack/apis/agents/event_logger.py b/llama_stack/apis/agents/event_logger.py index 4c379999e11..40a69d19cd6 100644 --- a/llama_stack/apis/agents/event_logger.py +++ b/llama_stack/apis/agents/event_logger.py @@ -6,13 +6,14 @@ from typing import Optional -from llama_models.llama3.api.datatypes import * # noqa: F403 +from llama_models.llama3.api.datatypes import ToolPromptFormat from llama_models.llama3.api.tool_utils import ToolUtils - from termcolor import cprint from llama_stack.apis.agents import AgentTurnResponseEventType, StepType +from llama_stack.apis.inference import ToolResponseMessage + class LogEvent: def __init__( diff --git a/llama_stack/apis/batch_inference/batch_inference.py b/llama_stack/apis/batch_inference/batch_inference.py index 358cf3c3558..f7b8b438709 100644 --- a/llama_stack/apis/batch_inference/batch_inference.py +++ b/llama_stack/apis/batch_inference/batch_inference.py @@ -10,8 +10,16 @@ from pydantic import BaseModel, Field -from llama_models.llama3.api.datatypes import * # noqa: F403 -from llama_stack.apis.inference import * # noqa: F403 +from llama_stack.apis.inference import ( + CompletionMessage, + InterleavedContent, + LogProbConfig, + Message, + SamplingParams, + ToolChoice, + ToolDefinition, + ToolPromptFormat, +) @json_schema_type diff --git a/llama_stack/apis/datasetio/datasetio.py b/llama_stack/apis/datasetio/datasetio.py index 22acc321140..983e0e4ea50 100644 --- a/llama_stack/apis/datasetio/datasetio.py +++ b/llama_stack/apis/datasetio/datasetio.py @@ -9,7 +9,7 @@ from llama_models.schema_utils import json_schema_type, webmethod from pydantic import BaseModel -from llama_stack.apis.datasets import * # noqa: F403 +from llama_stack.apis.datasets import Dataset @json_schema_type diff --git a/llama_stack/apis/eval/eval.py b/llama_stack/apis/eval/eval.py index 2e0ce1fbcb2..2592bca37ff 100644 --- a/llama_stack/apis/eval/eval.py +++ b/llama_stack/apis/eval/eval.py @@ -4,18 +4,18 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. -from typing import Literal, Optional, Protocol, Union +from typing import Any, Dict, List, Literal, Optional, Protocol, Union + +from llama_models.llama3.api.datatypes import BaseModel, Field +from llama_models.schema_utils import json_schema_type, webmethod from typing_extensions import Annotated -from llama_models.llama3.api.datatypes import * # noqa: F403 -from llama_models.schema_utils import json_schema_type, webmethod -from llama_stack.apis.scoring_functions import * # noqa: F403 from llama_stack.apis.agents import AgentConfig from llama_stack.apis.common.job_types import Job, JobStatus -from llama_stack.apis.scoring import * # noqa: F403 -from llama_stack.apis.eval_tasks import * # noqa: F403 from llama_stack.apis.inference import SamplingParams, SystemMessage +from llama_stack.apis.scoring import ScoringResult +from llama_stack.apis.scoring_functions import ScoringFnParams @json_schema_type diff --git a/llama_stack/apis/inference/inference.py b/llama_stack/apis/inference/inference.py index 28b9d91067a..e4804209167 100644 --- a/llama_stack/apis/inference/inference.py +++ b/llama_stack/apis/inference/inference.py @@ -7,7 +7,9 @@ from enum import Enum from typing import ( + Any, AsyncIterator, + Dict, List, Literal, Optional, @@ -32,8 +34,9 @@ from llama_stack.apis.common.content_types import InterleavedContent +from llama_stack.apis.models import Model + from llama_stack.providers.utils.telemetry.trace_protocol import trace_protocol -from llama_stack.apis.models import * # noqa: F403 class LogProbConfig(BaseModel): diff --git a/llama_stack/apis/post_training/post_training.py b/llama_stack/apis/post_training/post_training.py index fdbaa364d75..1c2d2d6e23f 100644 --- a/llama_stack/apis/post_training/post_training.py +++ b/llama_stack/apis/post_training/post_training.py @@ -7,17 +7,17 @@ from datetime import datetime from enum import Enum -from typing import Any, Dict, List, Optional, Protocol, Union +from typing import Any, Dict, List, Literal, Optional, Protocol, Union from llama_models.schema_utils import json_schema_type, webmethod from pydantic import BaseModel, Field from typing_extensions import Annotated -from llama_models.llama3.api.datatypes import * # noqa: F403 +from llama_stack.apis.common.content_types import URL + from llama_stack.apis.common.job_types import JobStatus -from llama_stack.apis.datasets import * # noqa: F403 -from llama_stack.apis.common.training_types import * # noqa: F403 +from llama_stack.apis.common.training_types import Checkpoint @json_schema_type diff --git a/llama_stack/apis/scoring/scoring.py b/llama_stack/apis/scoring/scoring.py index a47620a3dc1..453e35f6def 100644 --- a/llama_stack/apis/scoring/scoring.py +++ b/llama_stack/apis/scoring/scoring.py @@ -4,13 +4,12 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. -from typing import Any, Dict, List, Protocol, runtime_checkable +from typing import Any, Dict, List, Optional, Protocol, runtime_checkable from llama_models.schema_utils import json_schema_type, webmethod from pydantic import BaseModel -from llama_models.llama3.api.datatypes import * # noqa: F403 -from llama_stack.apis.scoring_functions import * # noqa: F403 +from llama_stack.apis.scoring_functions import ScoringFn, ScoringFnParams # mapping of metric to value diff --git a/llama_stack/apis/synthetic_data_generation/synthetic_data_generation.py b/llama_stack/apis/synthetic_data_generation/synthetic_data_generation.py index 4ffaa4d1e96..13b20991216 100644 --- a/llama_stack/apis/synthetic_data_generation/synthetic_data_generation.py +++ b/llama_stack/apis/synthetic_data_generation/synthetic_data_generation.py @@ -6,13 +6,12 @@ from enum import Enum -from typing import Any, Dict, List, Optional, Protocol +from typing import Any, Dict, List, Optional, Protocol, Union from llama_models.schema_utils import json_schema_type, webmethod from pydantic import BaseModel -from llama_models.llama3.api.datatypes import * # noqa: F403 from llama_stack.apis.inference import Message diff --git a/llama_stack/cli/model/safety_models.py b/llama_stack/cli/model/safety_models.py index 39c133f73eb..9464e0a2d40 100644 --- a/llama_stack/cli/model/safety_models.py +++ b/llama_stack/cli/model/safety_models.py @@ -6,11 +6,12 @@ from typing import Any, Dict, Optional -from pydantic import BaseModel, ConfigDict, Field - -from llama_models.datatypes import * # noqa: F403 +from llama_models.datatypes import CheckpointQuantizationFormat +from llama_models.llama3.api.datatypes import SamplingParams from llama_models.sku_list import LlamaDownloadInfo +from pydantic import BaseModel, ConfigDict, Field + class PromptGuardModel(BaseModel): """Make a 'fake' Model-like object for Prompt Guard. Eventually this will be removed.""" diff --git a/llama_stack/cli/stack/build.py b/llama_stack/cli/stack/build.py index f18d262c073..54d78ad93f2 100644 --- a/llama_stack/cli/stack/build.py +++ b/llama_stack/cli/stack/build.py @@ -3,21 +3,28 @@ # # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. - import argparse - -from llama_stack.cli.subcommand import Subcommand -from llama_stack.distribution.datatypes import * # noqa: F403 import os import shutil from functools import lru_cache from pathlib import Path +from typing import List, Optional import pkg_resources +from llama_stack.cli.subcommand import Subcommand + +from llama_stack.distribution.datatypes import ( + BuildConfig, + DistributionSpec, + Provider, + StackRunConfig, +) + from llama_stack.distribution.distribution import get_provider_registry from llama_stack.distribution.resolver import InvalidProviderError from llama_stack.distribution.utils.dynamic import instantiate_class_type +from llama_stack.providers.datatypes import Api TEMPLATES_PATH = Path(__file__).parent.parent.parent / "templates" diff --git a/llama_stack/distribution/build.py b/llama_stack/distribution/build.py index bdda0349f64..f376301f914 100644 --- a/llama_stack/distribution/build.py +++ b/llama_stack/distribution/build.py @@ -6,21 +6,22 @@ import logging from enum import Enum -from typing import List + +from pathlib import Path +from typing import Dict, List import pkg_resources from pydantic import BaseModel from termcolor import cprint -from llama_stack.distribution.utils.exec import run_with_pty - -from llama_stack.distribution.datatypes import * # noqa: F403 -from pathlib import Path +from llama_stack.distribution.datatypes import BuildConfig, Provider from llama_stack.distribution.distribution import get_provider_registry from llama_stack.distribution.utils.config_dirs import BUILDS_BASE_DIR +from llama_stack.distribution.utils.exec import run_with_pty +from llama_stack.providers.datatypes import Api log = logging.getLogger(__name__) diff --git a/llama_stack/distribution/configure.py b/llama_stack/distribution/configure.py index a4d0f970bd7..71c2676de48 100644 --- a/llama_stack/distribution/configure.py +++ b/llama_stack/distribution/configure.py @@ -6,10 +6,14 @@ import logging import textwrap -from typing import Any - -from llama_stack.distribution.datatypes import * # noqa: F403 +from typing import Any, Dict +from llama_stack.distribution.datatypes import ( + DistributionSpec, + LLAMA_STACK_RUN_CONFIG_VERSION, + Provider, + StackRunConfig, +) from llama_stack.distribution.distribution import ( builtin_automatically_routed_apis, get_provider_registry, @@ -17,10 +21,7 @@ from llama_stack.distribution.utils.dynamic import instantiate_class_type from llama_stack.distribution.utils.prompt_for_config import prompt_for_config - -from llama_stack.apis.models import * # noqa: F403 -from llama_stack.apis.shields import * # noqa: F403 -from llama_stack.apis.memory_banks import * # noqa: F403 +from llama_stack.providers.datatypes import Api, ProviderSpec logger = logging.getLogger(__name__) diff --git a/llama_stack/distribution/datatypes.py b/llama_stack/distribution/datatypes.py index f2dea6012ea..dec62bfaeeb 100644 --- a/llama_stack/distribution/datatypes.py +++ b/llama_stack/distribution/datatypes.py @@ -4,24 +4,24 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. -from typing import Dict, List, Optional, Union +from typing import Annotated, Any, Dict, List, Optional, Union from pydantic import BaseModel, Field from llama_stack.apis.datasetio import DatasetIO -from llama_stack.apis.datasets import * # noqa: F403 +from llama_stack.apis.datasets import Dataset, DatasetInput from llama_stack.apis.eval import Eval -from llama_stack.apis.eval_tasks import EvalTaskInput +from llama_stack.apis.eval_tasks import EvalTask, EvalTaskInput from llama_stack.apis.inference import Inference from llama_stack.apis.memory import Memory -from llama_stack.apis.memory_banks import * # noqa: F403 -from llama_stack.apis.models import * # noqa: F403 +from llama_stack.apis.memory_banks import MemoryBank, MemoryBankInput +from llama_stack.apis.models import Model, ModelInput from llama_stack.apis.safety import Safety from llama_stack.apis.scoring import Scoring -from llama_stack.apis.scoring_functions import * # noqa: F403 -from llama_stack.apis.shields import * # noqa: F403 +from llama_stack.apis.scoring_functions import ScoringFn, ScoringFnInput +from llama_stack.apis.shields import Shield, ShieldInput from llama_stack.apis.tools import Tool, ToolGroup, ToolRuntime -from llama_stack.providers.datatypes import * # noqa: F403 +from llama_stack.providers.datatypes import Api, ProviderSpec from llama_stack.providers.utils.kvstore.config import KVStoreConfig LLAMA_STACK_BUILD_CONFIG_VERSION = "2" diff --git a/llama_stack/distribution/inspect.py b/llama_stack/distribution/inspect.py index f5716ef5eb1..dbb16d8ce5a 100644 --- a/llama_stack/distribution/inspect.py +++ b/llama_stack/distribution/inspect.py @@ -5,12 +5,12 @@ # the root directory of this source tree. from typing import Dict, List -from llama_stack.apis.inspect import * # noqa: F403 + from pydantic import BaseModel +from llama_stack.apis.inspect import HealthInfo, Inspect, ProviderInfo, RouteInfo +from llama_stack.distribution.datatypes import StackRunConfig from llama_stack.distribution.server.endpoints import get_all_api_endpoints -from llama_stack.providers.datatypes import * # noqa: F403 -from llama_stack.distribution.datatypes import * # noqa: F403 class DistributionInspectConfig(BaseModel): diff --git a/llama_stack/distribution/resolver.py b/llama_stack/distribution/resolver.py index 43997131554..0a6eed3458f 100644 --- a/llama_stack/distribution/resolver.py +++ b/llama_stack/distribution/resolver.py @@ -6,14 +6,10 @@ import importlib import inspect -from typing import Any, Dict, List, Set - - -from llama_stack.providers.datatypes import * # noqa: F403 -from llama_stack.distribution.datatypes import * # noqa: F403 - import logging +from typing import Any, Dict, List, Set + from llama_stack.apis.agents import Agents from llama_stack.apis.datasetio import DatasetIO from llama_stack.apis.datasets import Datasets @@ -32,10 +28,32 @@ from llama_stack.apis.telemetry import Telemetry from llama_stack.apis.tools import ToolGroups, ToolRuntime from llama_stack.distribution.client import get_client_impl + +from llama_stack.distribution.datatypes import ( + AutoRoutedProviderSpec, + Provider, + RoutingTableProviderSpec, + StackRunConfig, +) from llama_stack.distribution.distribution import builtin_automatically_routed_apis from llama_stack.distribution.store import DistributionRegistry from llama_stack.distribution.utils.dynamic import instantiate_class_type +from llama_stack.providers.datatypes import ( + Api, + DatasetsProtocolPrivate, + EvalTasksProtocolPrivate, + InlineProviderSpec, + MemoryBanksProtocolPrivate, + ModelsProtocolPrivate, + ProviderSpec, + RemoteProviderConfig, + RemoteProviderSpec, + ScoringFunctionsProtocolPrivate, + ShieldsProtocolPrivate, + ToolsProtocolPrivate, +) + log = logging.getLogger(__name__) diff --git a/llama_stack/distribution/routers/__init__.py b/llama_stack/distribution/routers/__init__.py index 693f1fbe2b6..f19a2bffc87 100644 --- a/llama_stack/distribution/routers/__init__.py +++ b/llama_stack/distribution/routers/__init__.py @@ -4,10 +4,12 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. -from typing import Any +from typing import Any, Dict + +from llama_stack.distribution.datatypes import RoutedProtocol -from llama_stack.distribution.datatypes import * # noqa: F403 from llama_stack.distribution.store import DistributionRegistry +from llama_stack.providers.datatypes import Api, RoutingTable from .routing_tables import ( DatasetsRoutingTable, diff --git a/llama_stack/distribution/routers/routers.py b/llama_stack/distribution/routers/routers.py index a25a848db89..84ef467eb5e 100644 --- a/llama_stack/distribution/routers/routers.py +++ b/llama_stack/distribution/routers/routers.py @@ -6,16 +6,40 @@ from typing import Any, AsyncGenerator, Dict, List, Optional -from llama_stack.apis.datasetio import * # noqa: F403 -from llama_stack.apis.datasetio.datasetio import DatasetIO -from llama_stack.apis.eval import * # noqa: F403 -from llama_stack.apis.inference import * # noqa: F403 -from llama_stack.apis.memory import * # noqa: F403 +from llama_stack.apis.common.content_types import InterleavedContent +from llama_stack.apis.datasetio import DatasetIO, PaginatedRowsResult +from llama_stack.apis.eval import ( + AppEvalTaskConfig, + Eval, + EvalTaskConfig, + EvaluateResponse, + Job, + JobStatus, +) +from llama_stack.apis.inference import ( + EmbeddingsResponse, + Inference, + LogProbConfig, + Message, + ResponseFormat, + SamplingParams, + ToolChoice, + ToolDefinition, + ToolPromptFormat, +) +from llama_stack.apis.memory import Memory, MemoryBankDocument, QueryDocumentsResponse from llama_stack.apis.memory_banks.memory_banks import BankParams -from llama_stack.apis.safety import * # noqa: F403 -from llama_stack.apis.scoring import * # noqa: F403 -from llama_stack.apis.tools import * # noqa: F403 -from llama_stack.distribution.datatypes import RoutingTable +from llama_stack.apis.models import ModelType +from llama_stack.apis.safety import RunShieldResponse, Safety +from llama_stack.apis.scoring import ( + ScoreBatchResponse, + ScoreResponse, + Scoring, + ScoringFnParams, +) +from llama_stack.apis.shields import Shield +from llama_stack.apis.tools import Tool, ToolGroupDef, ToolRuntime +from llama_stack.providers.datatypes import RoutingTable class MemoryRouter(Memory): @@ -330,7 +354,6 @@ async def run_eval( task_config=task_config, ) - @webmethod(route="/eval/evaluate_rows", method="POST") async def evaluate_rows( self, task_id: str, diff --git a/llama_stack/distribution/routers/routing_tables.py b/llama_stack/distribution/routers/routing_tables.py index 3fb086b7239..ab1becfdd94 100644 --- a/llama_stack/distribution/routers/routing_tables.py +++ b/llama_stack/distribution/routers/routing_tables.py @@ -6,19 +6,42 @@ from typing import Any, Dict, List, Optional -from llama_models.llama3.api.datatypes import * # noqa: F403 from pydantic import parse_obj_as from llama_stack.apis.common.content_types import URL from llama_stack.apis.common.type_system import ParamType -from llama_stack.apis.datasets import * # noqa: F403 -from llama_stack.apis.eval_tasks import * # noqa: F403 -from llama_stack.apis.memory_banks import * # noqa: F403 -from llama_stack.apis.models import * # noqa: F403 -from llama_stack.apis.shields import * # noqa: F403 -from llama_stack.apis.tools import * # noqa: F403 -from llama_stack.distribution.datatypes import * # noqa: F403 +from llama_stack.apis.datasets import Dataset, Datasets +from llama_stack.apis.eval_tasks import EvalTask, EvalTasks +from llama_stack.apis.memory_banks import ( + BankParams, + MemoryBank, + MemoryBanks, + MemoryBankType, +) +from llama_stack.apis.models import Model, Models, ModelType +from llama_stack.apis.resource import ResourceType +from llama_stack.apis.scoring_functions import ( + ScoringFn, + ScoringFnParams, + ScoringFunctions, +) +from llama_stack.apis.shields import Shield, Shields +from llama_stack.apis.tools import ( + MCPToolGroupDef, + Tool, + ToolGroup, + ToolGroupDef, + ToolGroups, + UserDefinedToolGroupDef, +) +from llama_stack.distribution.datatypes import ( + RoutableObject, + RoutableObjectWithProvider, + RoutedProtocol, +) + from llama_stack.distribution.store import DistributionRegistry +from llama_stack.providers.datatypes import Api, RoutingTable def get_impl_api(p: Any) -> Api: diff --git a/llama_stack/distribution/server/server.py b/llama_stack/distribution/server/server.py index 8f24f3eafd3..daaf8475b4f 100644 --- a/llama_stack/distribution/server/server.py +++ b/llama_stack/distribution/server/server.py @@ -28,14 +28,9 @@ from termcolor import cprint from typing_extensions import Annotated -from llama_stack.distribution.distribution import builtin_automatically_routed_apis +from llama_stack.distribution.datatypes import StackRunConfig -from llama_stack.providers.utils.telemetry.tracing import ( - end_trace, - setup_logger, - start_trace, -) -from llama_stack.distribution.datatypes import * # noqa: F403 +from llama_stack.distribution.distribution import builtin_automatically_routed_apis from llama_stack.distribution.request_headers import set_request_provider_data from llama_stack.distribution.resolver import InvalidProviderError from llama_stack.distribution.stack import ( @@ -43,11 +38,19 @@ replace_env_vars, validate_env_pair, ) + +from llama_stack.providers.datatypes import Api from llama_stack.providers.inline.telemetry.meta_reference.config import TelemetryConfig from llama_stack.providers.inline.telemetry.meta_reference.telemetry import ( TelemetryAdapter, ) +from llama_stack.providers.utils.telemetry.tracing import ( + end_trace, + setup_logger, + start_trace, +) + from .endpoints import get_all_api_endpoints diff --git a/llama_stack/distribution/stack.py b/llama_stack/distribution/stack.py index f5180b0dbf4..965df5f03ae 100644 --- a/llama_stack/distribution/stack.py +++ b/llama_stack/distribution/stack.py @@ -8,32 +8,31 @@ import os import re from pathlib import Path -from typing import Any, Dict +from typing import Any, Dict, Optional import pkg_resources import yaml from termcolor import colored -from llama_models.llama3.api.datatypes import * # noqa: F403 -from llama_stack.apis.agents import * # noqa: F403 -from llama_stack.apis.datasets import * # noqa: F403 -from llama_stack.apis.datasetio import * # noqa: F403 -from llama_stack.apis.scoring import * # noqa: F403 -from llama_stack.apis.scoring_functions import * # noqa: F403 -from llama_stack.apis.eval import * # noqa: F403 -from llama_stack.apis.inference import * # noqa: F403 -from llama_stack.apis.batch_inference import * # noqa: F403 -from llama_stack.apis.memory import * # noqa: F403 -from llama_stack.apis.telemetry import * # noqa: F403 -from llama_stack.apis.post_training import * # noqa: F403 -from llama_stack.apis.synthetic_data_generation import * # noqa: F403 -from llama_stack.apis.safety import * # noqa: F403 -from llama_stack.apis.models import * # noqa: F403 -from llama_stack.apis.memory_banks import * # noqa: F403 -from llama_stack.apis.shields import * # noqa: F403 -from llama_stack.apis.inspect import * # noqa: F403 -from llama_stack.apis.eval_tasks import * # noqa: F403 +from llama_stack.apis.agents import Agents +from llama_stack.apis.batch_inference import BatchInference +from llama_stack.apis.datasetio import DatasetIO +from llama_stack.apis.datasets import Datasets +from llama_stack.apis.eval import Eval +from llama_stack.apis.eval_tasks import EvalTasks +from llama_stack.apis.inference import Inference +from llama_stack.apis.inspect import Inspect +from llama_stack.apis.memory import Memory +from llama_stack.apis.memory_banks import MemoryBanks +from llama_stack.apis.models import Models +from llama_stack.apis.post_training import PostTraining +from llama_stack.apis.safety import Safety +from llama_stack.apis.scoring import Scoring +from llama_stack.apis.scoring_functions import ScoringFunctions +from llama_stack.apis.shields import Shields +from llama_stack.apis.synthetic_data_generation import SyntheticDataGeneration +from llama_stack.apis.telemetry import Telemetry from llama_stack.distribution.datatypes import StackRunConfig from llama_stack.distribution.distribution import get_provider_registry diff --git a/llama_stack/distribution/store/registry.py b/llama_stack/distribution/store/registry.py index f98c144430d..686054dd262 100644 --- a/llama_stack/distribution/store/registry.py +++ b/llama_stack/distribution/store/registry.py @@ -13,11 +13,8 @@ from llama_stack.distribution.datatypes import KVStoreConfig, RoutableObjectWithProvider from llama_stack.distribution.utils.config_dirs import DISTRIBS_BASE_DIR -from llama_stack.providers.utils.kvstore import ( - KVStore, - kvstore_impl, - SqliteKVStoreConfig, -) +from llama_stack.providers.utils.kvstore import KVStore, kvstore_impl +from llama_stack.providers.utils.kvstore.config import SqliteKVStoreConfig class DistributionRegistry(Protocol): diff --git a/llama_stack/distribution/store/tests/test_registry.py b/llama_stack/distribution/store/tests/test_registry.py index 7e389cccdbd..54bc04f9cb1 100644 --- a/llama_stack/distribution/store/tests/test_registry.py +++ b/llama_stack/distribution/store/tests/test_registry.py @@ -8,11 +8,14 @@ import pytest import pytest_asyncio -from llama_stack.distribution.store import * # noqa F403 from llama_stack.apis.inference import Model from llama_stack.apis.memory_banks import VectorMemoryBank + +from llama_stack.distribution.store.registry import ( + CachedDiskDistributionRegistry, + DiskDistributionRegistry, +) from llama_stack.providers.utils.kvstore import kvstore_impl, SqliteKVStoreConfig -from llama_stack.distribution.datatypes import * # noqa F403 @pytest.fixture diff --git a/llama_stack/providers/inline/agents/meta_reference/agent_instance.py b/llama_stack/providers/inline/agents/meta_reference/agent_instance.py index d7930550d73..f225f53939a 100644 --- a/llama_stack/providers/inline/agents/meta_reference/agent_instance.py +++ b/llama_stack/providers/inline/agents/meta_reference/agent_instance.py @@ -13,19 +13,64 @@ import string import uuid from datetime import datetime -from typing import AsyncGenerator, List, Tuple +from typing import AsyncGenerator, Dict, List, Optional, Tuple from urllib.parse import urlparse import httpx +from llama_models.llama3.api.datatypes import BuiltinTool + +from llama_stack.apis.agents import ( + AgentConfig, + AgentTool, + AgentTurnCreateRequest, + AgentTurnResponseEvent, + AgentTurnResponseEventType, + AgentTurnResponseStepCompletePayload, + AgentTurnResponseStepProgressPayload, + AgentTurnResponseStepStartPayload, + AgentTurnResponseStreamChunk, + AgentTurnResponseTurnCompletePayload, + AgentTurnResponseTurnStartPayload, + Attachment, + CodeInterpreterToolDefinition, + FunctionCallToolDefinition, + InferenceStep, + MemoryRetrievalStep, + MemoryToolDefinition, + PhotogenToolDefinition, + SearchToolDefinition, + ShieldCallStep, + StepType, + ToolExecutionStep, + Turn, + WolframAlphaToolDefinition, +) -from llama_stack.apis.agents import * # noqa: F403 -from llama_stack.apis.inference import * # noqa: F403 -from llama_stack.apis.memory import * # noqa: F403 -from llama_stack.apis.memory_banks import * # noqa: F403 -from llama_stack.apis.safety import * # noqa: F403 - -from llama_stack.apis.common.content_types import InterleavedContent, TextContentItem +from llama_stack.apis.common.content_types import ( + InterleavedContent, + TextContentItem, + URL, +) +from llama_stack.apis.inference import ( + ChatCompletionResponseEventType, + CompletionMessage, + Inference, + Message, + SamplingParams, + StopReason, + SystemMessage, + ToolCallDelta, + ToolCallParseStatus, + ToolChoice, + ToolDefinition, + ToolResponse, + ToolResponseMessage, + UserMessage, +) +from llama_stack.apis.memory import Memory, MemoryBankDocument, QueryDocumentsResponse +from llama_stack.apis.memory_banks import MemoryBanks, VectorMemoryBankParams +from llama_stack.apis.safety import Safety from llama_stack.providers.utils.kvstore import KVStore from llama_stack.providers.utils.memory.vector_store import concat_interleaved_content diff --git a/llama_stack/providers/inline/agents/meta_reference/agents.py b/llama_stack/providers/inline/agents/meta_reference/agents.py index dec5ec960b7..93bfab5f468 100644 --- a/llama_stack/providers/inline/agents/meta_reference/agents.py +++ b/llama_stack/providers/inline/agents/meta_reference/agents.py @@ -9,15 +9,26 @@ import shutil import tempfile import uuid -from typing import AsyncGenerator +from typing import AsyncGenerator, List, Optional, Union from termcolor import colored -from llama_stack.apis.inference import Inference +from llama_stack.apis.agents import ( + AgentConfig, + AgentCreateResponse, + Agents, + AgentSessionCreateResponse, + AgentStepResponse, + AgentTurnCreateRequest, + Attachment, + Session, + Turn, +) + +from llama_stack.apis.inference import Inference, ToolResponseMessage, UserMessage from llama_stack.apis.memory import Memory from llama_stack.apis.memory_banks import MemoryBanks from llama_stack.apis.safety import Safety -from llama_stack.apis.agents import * # noqa: F403 from llama_stack.providers.utils.kvstore import InmemoryKVStoreImpl, kvstore_impl diff --git a/llama_stack/providers/inline/agents/meta_reference/persistence.py b/llama_stack/providers/inline/agents/meta_reference/persistence.py index 1c99e3d7527..a4b1af616c9 100644 --- a/llama_stack/providers/inline/agents/meta_reference/persistence.py +++ b/llama_stack/providers/inline/agents/meta_reference/persistence.py @@ -10,9 +10,11 @@ from datetime import datetime from typing import List, Optional -from llama_stack.apis.agents import * # noqa: F403 + from pydantic import BaseModel +from llama_stack.apis.agents import Turn + from llama_stack.providers.utils.kvstore import KVStore log = logging.getLogger(__name__) diff --git a/llama_stack/providers/inline/agents/meta_reference/rag/context_retriever.py b/llama_stack/providers/inline/agents/meta_reference/rag/context_retriever.py index 7b5c8b4b045..74eb91c53ae 100644 --- a/llama_stack/providers/inline/agents/meta_reference/rag/context_retriever.py +++ b/llama_stack/providers/inline/agents/meta_reference/rag/context_retriever.py @@ -7,8 +7,6 @@ from typing import List from jinja2 import Template -from llama_models.llama3.api import * # noqa: F403 - from llama_stack.apis.agents import ( DefaultMemoryQueryGeneratorConfig, @@ -16,7 +14,7 @@ MemoryQueryGenerator, MemoryQueryGeneratorConfig, ) -from llama_stack.apis.inference import * # noqa: F403 +from llama_stack.apis.inference import Message, UserMessage from llama_stack.providers.utils.inference.prompt_adapter import ( interleaved_content_as_str, ) diff --git a/llama_stack/providers/inline/agents/meta_reference/safety.py b/llama_stack/providers/inline/agents/meta_reference/safety.py index 8fca4d310cd..90d193f909c 100644 --- a/llama_stack/providers/inline/agents/meta_reference/safety.py +++ b/llama_stack/providers/inline/agents/meta_reference/safety.py @@ -9,7 +9,9 @@ from typing import List -from llama_stack.apis.safety import * # noqa: F403 +from llama_stack.apis.inference import Message + +from llama_stack.apis.safety import Safety, SafetyViolation, ViolationLevel log = logging.getLogger(__name__) diff --git a/llama_stack/providers/inline/agents/meta_reference/tests/test_chat_agent.py b/llama_stack/providers/inline/agents/meta_reference/tests/test_chat_agent.py index 6edef067249..03505432044 100644 --- a/llama_stack/providers/inline/agents/meta_reference/tests/test_chat_agent.py +++ b/llama_stack/providers/inline/agents/meta_reference/tests/test_chat_agent.py @@ -8,10 +8,26 @@ import pytest -from llama_stack.apis.inference import * # noqa: F403 -from llama_stack.apis.memory import * # noqa: F403 -from llama_stack.apis.safety import * # noqa: F403 -from llama_stack.apis.agents import * # noqa: F403 +from llama_stack.apis.agents import ( + AgentConfig, + AgentTurnCreateRequest, + AgentTurnResponseTurnCompletePayload, +) + +from llama_stack.apis.inference import ( + ChatCompletionResponse, + ChatCompletionResponseEvent, + ChatCompletionResponseStreamChunk, + CompletionMessage, + Message, + ResponseFormat, + SamplingParams, + ToolChoice, + ToolDefinition, + UserMessage, +) +from llama_stack.apis.memory import MemoryBank +from llama_stack.apis.safety import RunShieldResponse from ..agents import ( AGENT_INSTANCES_BY_ID, diff --git a/llama_stack/providers/inline/agents/meta_reference/tools/safety.py b/llama_stack/providers/inline/agents/meta_reference/tools/safety.py index 1ffc99edd9a..a34649756e9 100644 --- a/llama_stack/providers/inline/agents/meta_reference/tools/safety.py +++ b/llama_stack/providers/inline/agents/meta_reference/tools/safety.py @@ -7,7 +7,7 @@ from typing import List from llama_stack.apis.inference import Message -from llama_stack.apis.safety import * # noqa: F403 +from llama_stack.apis.safety import Safety from ..safety import ShieldRunnerMixin from .builtin import BaseTool diff --git a/llama_stack/providers/inline/datasetio/localfs/config.py b/llama_stack/providers/inline/datasetio/localfs/config.py index 58d563c994e..1b89df63b32 100644 --- a/llama_stack/providers/inline/datasetio/localfs/config.py +++ b/llama_stack/providers/inline/datasetio/localfs/config.py @@ -3,7 +3,7 @@ # # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. -from llama_stack.apis.datasetio import * # noqa: F401, F403 +from pydantic import BaseModel class LocalFSDatasetIOConfig(BaseModel): ... diff --git a/llama_stack/providers/inline/datasetio/localfs/datasetio.py b/llama_stack/providers/inline/datasetio/localfs/datasetio.py index 736e5d8b951..442053fb34e 100644 --- a/llama_stack/providers/inline/datasetio/localfs/datasetio.py +++ b/llama_stack/providers/inline/datasetio/localfs/datasetio.py @@ -3,18 +3,19 @@ # # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. -from typing import Any, Dict, List, Optional - -import pandas -from llama_models.llama3.api.datatypes import * # noqa: F403 - -from llama_stack.apis.datasetio import * # noqa: F403 import base64 import os from abc import ABC, abstractmethod from dataclasses import dataclass +from typing import Any, Dict, List, Optional from urllib.parse import urlparse +import pandas + +from llama_stack.apis.common.content_types import URL +from llama_stack.apis.datasetio import DatasetIO, PaginatedRowsResult +from llama_stack.apis.datasets import Dataset + from llama_stack.providers.datatypes import DatasetsProtocolPrivate from llama_stack.providers.utils.datasetio.url_utils import get_dataframe_from_url diff --git a/llama_stack/providers/inline/eval/meta_reference/eval.py b/llama_stack/providers/inline/eval/meta_reference/eval.py index e1c2cc80439..00630132eff 100644 --- a/llama_stack/providers/inline/eval/meta_reference/eval.py +++ b/llama_stack/providers/inline/eval/meta_reference/eval.py @@ -5,13 +5,15 @@ # the root directory of this source tree. from enum import Enum from typing import Any, Dict, List, Optional -from llama_models.llama3.api.datatypes import * # noqa: F403 + from tqdm import tqdm -from .....apis.common.job_types import Job -from .....apis.eval.eval import Eval, EvalTaskConfig, EvaluateResponse, JobStatus -from llama_stack.apis.common.type_system import * # noqa: F403 from llama_stack.apis.agents import Agents +from llama_stack.apis.common.type_system import ( + ChatCompletionInputType, + CompletionInputType, + StringType, +) from llama_stack.apis.datasetio import DatasetIO from llama_stack.apis.datasets import Datasets from llama_stack.apis.eval_tasks import EvalTask @@ -20,6 +22,9 @@ from llama_stack.providers.datatypes import EvalTasksProtocolPrivate from llama_stack.providers.utils.kvstore import kvstore_impl +from .....apis.common.job_types import Job +from .....apis.eval.eval import Eval, EvalTaskConfig, EvaluateResponse, JobStatus + from .config import MetaReferenceEvalConfig EVAL_TASKS_PREFIX = "eval_tasks:" diff --git a/llama_stack/providers/inline/inference/meta_reference/config.py b/llama_stack/providers/inline/inference/meta_reference/config.py index 33af33fcd83..2c46ef596a9 100644 --- a/llama_stack/providers/inline/inference/meta_reference/config.py +++ b/llama_stack/providers/inline/inference/meta_reference/config.py @@ -6,11 +6,10 @@ from typing import Any, Dict, Optional -from llama_models.datatypes import * # noqa: F403 - -from llama_stack.apis.inference import * # noqa: F401, F403 from pydantic import BaseModel, field_validator +from llama_stack.apis.inference import QuantizationConfig + from llama_stack.providers.utils.inference import supported_inference_models diff --git a/llama_stack/providers/inline/inference/meta_reference/generation.py b/llama_stack/providers/inline/inference/meta_reference/generation.py index c89183cb73b..1807e4ad5cd 100644 --- a/llama_stack/providers/inline/inference/meta_reference/generation.py +++ b/llama_stack/providers/inline/inference/meta_reference/generation.py @@ -32,11 +32,16 @@ CrossAttentionTransformer, ) from llama_models.sku_list import resolve_model -from pydantic import BaseModel - -from llama_stack.apis.inference import * # noqa: F403 from lmformatenforcer import JsonSchemaParser, TokenEnforcer, TokenEnforcerTokenizerData +from pydantic import BaseModel + +from llama_stack.apis.inference import ( + Fp8QuantizationConfig, + Int4QuantizationConfig, + ResponseFormat, + ResponseFormatType, +) from llama_stack.distribution.utils.model_utils import model_local_dir from llama_stack.providers.utils.inference.prompt_adapter import ( @@ -44,12 +49,7 @@ CompletionRequestWithRawContent, ) -from .config import ( - Fp8QuantizationConfig, - Int4QuantizationConfig, - MetaReferenceInferenceConfig, - MetaReferenceQuantizedInferenceConfig, -) +from .config import MetaReferenceInferenceConfig, MetaReferenceQuantizedInferenceConfig log = logging.getLogger(__name__) diff --git a/llama_stack/providers/inline/inference/vllm/vllm.py b/llama_stack/providers/inline/inference/vllm/vllm.py index c5925774bce..73f7adecd81 100644 --- a/llama_stack/providers/inline/inference/vllm/vllm.py +++ b/llama_stack/providers/inline/inference/vllm/vllm.py @@ -7,10 +7,10 @@ import logging import os import uuid -from typing import AsyncGenerator, Optional +from typing import AsyncGenerator, List, Optional from llama_models.llama3.api.chat_format import ChatFormat -from llama_models.llama3.api.datatypes import * # noqa: F403 + from llama_models.llama3.api.tokenizer import Tokenizer from llama_models.sku_list import resolve_model @@ -18,9 +18,26 @@ from vllm.engine.async_llm_engine import AsyncLLMEngine from vllm.sampling_params import SamplingParams as VLLMSamplingParams -from llama_stack.apis.inference import * # noqa: F403 +from llama_stack.apis.common.content_types import InterleavedContent +from llama_stack.apis.inference import ( + ChatCompletionRequest, + ChatCompletionResponse, + ChatCompletionResponseStreamChunk, + CompletionResponse, + CompletionResponseStreamChunk, + EmbeddingsResponse, + Inference, + LogProbConfig, + Message, + ResponseFormat, + SamplingParams, + ToolChoice, + ToolDefinition, + ToolPromptFormat, +) +from llama_stack.apis.models import Model -from llama_stack.providers.datatypes import Model, ModelsProtocolPrivate +from llama_stack.providers.datatypes import ModelsProtocolPrivate from llama_stack.providers.utils.inference.openai_compat import ( OpenAICompatCompletionChoice, OpenAICompatCompletionResponse, diff --git a/llama_stack/providers/inline/memory/faiss/faiss.py b/llama_stack/providers/inline/memory/faiss/faiss.py index a46b151d9d0..af398801a38 100644 --- a/llama_stack/providers/inline/memory/faiss/faiss.py +++ b/llama_stack/providers/inline/memory/faiss/faiss.py @@ -16,11 +16,14 @@ import numpy as np from numpy.typing import NDArray -from llama_models.llama3.api.datatypes import * # noqa: F403 - -from llama_stack.apis.memory import * # noqa: F403 from llama_stack.apis.inference import InterleavedContent -from llama_stack.apis.memory_banks import MemoryBankType, VectorMemoryBank +from llama_stack.apis.memory import ( + Chunk, + Memory, + MemoryBankDocument, + QueryDocumentsResponse, +) +from llama_stack.apis.memory_banks import MemoryBank, MemoryBankType, VectorMemoryBank from llama_stack.providers.datatypes import Api, MemoryBanksProtocolPrivate from llama_stack.providers.utils.kvstore import kvstore_impl from llama_stack.providers.utils.memory.vector_store import ( diff --git a/llama_stack/providers/inline/post_training/torchtune/common/utils.py b/llama_stack/providers/inline/post_training/torchtune/common/utils.py index 462cbc21ed8..f2a2edae536 100644 --- a/llama_stack/providers/inline/post_training/torchtune/common/utils.py +++ b/llama_stack/providers/inline/post_training/torchtune/common/utils.py @@ -14,11 +14,10 @@ from typing import Any, Callable, Dict, List import torch -from llama_stack.apis.datasets import Datasets -from llama_stack.apis.common.type_system import * # noqa from llama_models.datatypes import Model from llama_models.sku_list import resolve_model -from llama_stack.apis.common.type_system import ParamType +from llama_stack.apis.common.type_system import ParamType, StringType +from llama_stack.apis.datasets import Datasets from torchtune.models.llama3 import llama3_tokenizer, lora_llama3_8b from torchtune.models.llama3._tokenizer import Llama3Tokenizer diff --git a/llama_stack/providers/inline/post_training/torchtune/post_training.py b/llama_stack/providers/inline/post_training/torchtune/post_training.py index 9b1269f16c4..90fbf702626 100644 --- a/llama_stack/providers/inline/post_training/torchtune/post_training.py +++ b/llama_stack/providers/inline/post_training/torchtune/post_training.py @@ -3,11 +3,26 @@ # # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. +from datetime import datetime +from typing import Any, Dict, List, Optional + +from llama_models.schema_utils import webmethod + from llama_stack.apis.datasetio import DatasetIO +from llama_stack.apis.datasets import Datasets +from llama_stack.apis.post_training import ( + AlgorithmConfig, + DPOAlignmentConfig, + JobStatus, + LoraFinetuningConfig, + PostTrainingJob, + PostTrainingJobArtifactsResponse, + PostTrainingJobStatusResponse, + TrainingConfig, +) from llama_stack.providers.inline.post_training.torchtune.config import ( TorchtunePostTrainingConfig, ) -from llama_stack.apis.post_training import * # noqa from llama_stack.providers.inline.post_training.torchtune.recipes.lora_finetuning_single_device import ( LoraFinetuningSingleDevice, ) diff --git a/llama_stack/providers/inline/post_training/torchtune/recipes/lora_finetuning_single_device.py b/llama_stack/providers/inline/post_training/torchtune/recipes/lora_finetuning_single_device.py index 71b8bf75960..517be6d8937 100644 --- a/llama_stack/providers/inline/post_training/torchtune/recipes/lora_finetuning_single_device.py +++ b/llama_stack/providers/inline/post_training/torchtune/recipes/lora_finetuning_single_device.py @@ -14,27 +14,33 @@ import torch from llama_models.sku_list import resolve_model +from llama_stack.apis.common.training_types import PostTrainingMetric from llama_stack.apis.datasetio import DatasetIO +from llama_stack.apis.datasets import Datasets +from llama_stack.apis.post_training import ( + AlgorithmConfig, + Checkpoint, + LoraFinetuningConfig, + OptimizerConfig, + TrainingConfig, +) from llama_stack.distribution.utils.config_dirs import DEFAULT_CHECKPOINT_DIR -from llama_stack.providers.inline.post_training.torchtune.common.checkpointer import ( - TorchtuneCheckpointer, -) -from torch import nn -from torchtune import utils as torchtune_utils -from torchtune.training.metric_logging import DiskLogger -from tqdm import tqdm -from llama_stack.apis.post_training import * # noqa + from llama_stack.distribution.utils.model_utils import model_local_dir from llama_stack.providers.inline.post_training.torchtune.common import utils +from llama_stack.providers.inline.post_training.torchtune.common.checkpointer import ( + TorchtuneCheckpointer, +) from llama_stack.providers.inline.post_training.torchtune.config import ( TorchtunePostTrainingConfig, ) from llama_stack.providers.inline.post_training.torchtune.datasets.sft import SFTDataset +from torch import nn from torch.optim import Optimizer from torch.utils.data import DataLoader, DistributedSampler -from torchtune import modules, training +from torchtune import modules, training, utils as torchtune_utils from torchtune.data import AlpacaToMessages, padded_collate_sft from torchtune.modules.loss import CEWithChunkedOutputLoss @@ -47,6 +53,8 @@ validate_missing_and_unexpected_for_lora, ) from torchtune.training.lr_schedulers import get_cosine_schedule_with_warmup +from torchtune.training.metric_logging import DiskLogger +from tqdm import tqdm log = logging.getLogger(__name__) diff --git a/llama_stack/providers/inline/safety/code_scanner/code_scanner.py b/llama_stack/providers/inline/safety/code_scanner/code_scanner.py index 46b5e57dab9..87d68f74cfe 100644 --- a/llama_stack/providers/inline/safety/code_scanner/code_scanner.py +++ b/llama_stack/providers/inline/safety/code_scanner/code_scanner.py @@ -7,8 +7,14 @@ import logging from typing import Any, Dict, List -from llama_stack.apis.safety import * # noqa: F403 from llama_stack.apis.inference import Message +from llama_stack.apis.safety import ( + RunShieldResponse, + Safety, + SafetyViolation, + ViolationLevel, +) +from llama_stack.apis.shields import Shield from llama_stack.providers.utils.inference.prompt_adapter import ( interleaved_content_as_str, ) diff --git a/llama_stack/providers/inline/safety/llama_guard/llama_guard.py b/llama_stack/providers/inline/safety/llama_guard/llama_guard.py index bbdd5c3df45..00213ac8309 100644 --- a/llama_stack/providers/inline/safety/llama_guard/llama_guard.py +++ b/llama_stack/providers/inline/safety/llama_guard/llama_guard.py @@ -9,10 +9,24 @@ from string import Template from typing import Any, Dict, List, Optional -from llama_models.llama3.api.datatypes import * # noqa: F403 -from llama_stack.apis.inference import * # noqa: F403 -from llama_stack.apis.safety import * # noqa: F403 +from llama_models.datatypes import CoreModelId +from llama_models.llama3.api.datatypes import Role + from llama_stack.apis.common.content_types import ImageContentItem, TextContentItem +from llama_stack.apis.inference import ( + ChatCompletionResponseEventType, + Inference, + Message, + UserMessage, +) +from llama_stack.apis.safety import ( + RunShieldResponse, + Safety, + SafetyViolation, + ViolationLevel, +) + +from llama_stack.apis.shields import Shield from llama_stack.distribution.datatypes import Api from llama_stack.providers.datatypes import ShieldsProtocolPrivate diff --git a/llama_stack/providers/inline/safety/prompt_guard/prompt_guard.py b/llama_stack/providers/inline/safety/prompt_guard/prompt_guard.py index 4cb34127f50..3f30645bd5d 100644 --- a/llama_stack/providers/inline/safety/prompt_guard/prompt_guard.py +++ b/llama_stack/providers/inline/safety/prompt_guard/prompt_guard.py @@ -11,11 +11,16 @@ from transformers import AutoModelForSequenceClassification, AutoTokenizer -from llama_stack.distribution.utils.model_utils import model_local_dir -from llama_stack.apis.inference import * # noqa: F403 -from llama_stack.apis.safety import * # noqa: F403 -from llama_models.llama3.api.datatypes import * # noqa: F403 +from llama_stack.apis.inference import Message +from llama_stack.apis.safety import ( + RunShieldResponse, + Safety, + SafetyViolation, + ViolationLevel, +) +from llama_stack.apis.shields import Shield +from llama_stack.distribution.utils.model_utils import model_local_dir from llama_stack.providers.datatypes import ShieldsProtocolPrivate from llama_stack.providers.utils.inference.prompt_adapter import ( interleaved_content_as_str, diff --git a/llama_stack/providers/inline/scoring/basic/scoring.py b/llama_stack/providers/inline/scoring/basic/scoring.py index 0c0503ff52e..f8b30cbcf20 100644 --- a/llama_stack/providers/inline/scoring/basic/scoring.py +++ b/llama_stack/providers/inline/scoring/basic/scoring.py @@ -3,14 +3,17 @@ # # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. -from typing import List - -from llama_models.llama3.api.datatypes import * # noqa: F403 -from llama_stack.apis.scoring import * # noqa: F403 -from llama_stack.apis.scoring_functions import * # noqa: F403 -from llama_stack.apis.common.type_system import * # noqa: F403 -from llama_stack.apis.datasetio import * # noqa: F403 -from llama_stack.apis.datasets import * # noqa: F403 +from typing import Any, Dict, List, Optional + +from llama_stack.apis.datasetio import DatasetIO +from llama_stack.apis.datasets import Datasets +from llama_stack.apis.scoring import ( + ScoreBatchResponse, + ScoreResponse, + Scoring, + ScoringResult, +) +from llama_stack.apis.scoring_functions import ScoringFn, ScoringFnParams from llama_stack.providers.datatypes import ScoringFunctionsProtocolPrivate from .config import BasicScoringConfig diff --git a/llama_stack/providers/inline/scoring/braintrust/braintrust.py b/llama_stack/providers/inline/scoring/braintrust/braintrust.py index ae955540373..0c6102645b8 100644 --- a/llama_stack/providers/inline/scoring/braintrust/braintrust.py +++ b/llama_stack/providers/inline/scoring/braintrust/braintrust.py @@ -3,20 +3,23 @@ # # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. -from typing import List - -from llama_models.llama3.api.datatypes import * # noqa: F403 -from llama_stack.apis.scoring import * # noqa: F403 -from llama_stack.apis.scoring_functions import * # noqa: F403 -from llama_stack.apis.common.type_system import * # noqa: F403 -from llama_stack.apis.datasetio import * # noqa: F403 -from llama_stack.apis.datasets import * # noqa: F403 - import os +from typing import Any, Dict, List, Optional from autoevals.llm import Factuality from autoevals.ragas import AnswerCorrectness +from llama_stack.apis.datasetio import DatasetIO +from llama_stack.apis.datasets import Datasets +from llama_stack.apis.scoring import ( + ScoreBatchResponse, + ScoreResponse, + Scoring, + ScoringResult, + ScoringResultRow, +) +from llama_stack.apis.scoring_functions import AggregationFunctionType, ScoringFn + from llama_stack.distribution.request_headers import NeedsRequestProviderData from llama_stack.providers.datatypes import ScoringFunctionsProtocolPrivate diff --git a/llama_stack/providers/inline/scoring/braintrust/config.py b/llama_stack/providers/inline/scoring/braintrust/config.py index e122494322e..d4e0d9bcd51 100644 --- a/llama_stack/providers/inline/scoring/braintrust/config.py +++ b/llama_stack/providers/inline/scoring/braintrust/config.py @@ -3,7 +3,9 @@ # # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. -from llama_stack.apis.scoring import * # noqa: F401, F403 +from typing import Any, Dict, Optional + +from pydantic import BaseModel, Field class BraintrustScoringConfig(BaseModel): diff --git a/llama_stack/providers/inline/telemetry/meta_reference/telemetry.py b/llama_stack/providers/inline/telemetry/meta_reference/telemetry.py index d7229f508fe..81dd9910d7d 100644 --- a/llama_stack/providers/inline/telemetry/meta_reference/telemetry.py +++ b/llama_stack/providers/inline/telemetry/meta_reference/telemetry.py @@ -17,6 +17,22 @@ from opentelemetry.sdk.trace.export import BatchSpanProcessor from opentelemetry.semconv.resource import ResourceAttributes +from llama_stack.apis.telemetry import ( + Event, + MetricEvent, + QueryCondition, + SpanEndPayload, + SpanStartPayload, + SpanStatus, + SpanWithStatus, + StructuredLogEvent, + Telemetry, + Trace, + UnstructuredLogEvent, +) + +from llama_stack.distribution.datatypes import Api + from llama_stack.providers.inline.telemetry.meta_reference.console_span_processor import ( ConsoleSpanProcessor, ) @@ -27,10 +43,6 @@ from llama_stack.providers.utils.telemetry.dataset_mixin import TelemetryDatasetMixin from llama_stack.providers.utils.telemetry.sqlite_trace_store import SQLiteTraceStore -from llama_stack.apis.telemetry import * # noqa: F403 - -from llama_stack.distribution.datatypes import Api - from .config import TelemetryConfig, TelemetrySink _GLOBAL_STORAGE = { diff --git a/llama_stack/providers/inline/telemetry/sample/sample.py b/llama_stack/providers/inline/telemetry/sample/sample.py index eaa6d834aa9..f07a185ef50 100644 --- a/llama_stack/providers/inline/telemetry/sample/sample.py +++ b/llama_stack/providers/inline/telemetry/sample/sample.py @@ -4,12 +4,10 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. +from llama_stack.apis.telemetry import Telemetry from .config import SampleConfig -from llama_stack.apis.telemetry import * # noqa: F403 - - class SampleTelemetryImpl(Telemetry): def __init__(self, config: SampleConfig): self.config = config diff --git a/llama_stack/providers/registry/agents.py b/llama_stack/providers/registry/agents.py index 8b6c9027ce5..6595b1955e7 100644 --- a/llama_stack/providers/registry/agents.py +++ b/llama_stack/providers/registry/agents.py @@ -6,7 +6,13 @@ from typing import List -from llama_stack.distribution.datatypes import * # noqa: F403 +from llama_stack.providers.datatypes import ( + AdapterSpec, + Api, + InlineProviderSpec, + ProviderSpec, + remote_provider_spec, +) from llama_stack.providers.utils.kvstore import kvstore_dependencies diff --git a/llama_stack/providers/registry/datasetio.py b/llama_stack/providers/registry/datasetio.py index 403c4111120..f83dcbc60c7 100644 --- a/llama_stack/providers/registry/datasetio.py +++ b/llama_stack/providers/registry/datasetio.py @@ -6,7 +6,13 @@ from typing import List -from llama_stack.distribution.datatypes import * # noqa: F403 +from llama_stack.providers.datatypes import ( + AdapterSpec, + Api, + InlineProviderSpec, + ProviderSpec, + remote_provider_spec, +) def available_providers() -> List[ProviderSpec]: diff --git a/llama_stack/providers/registry/eval.py b/llama_stack/providers/registry/eval.py index 718c7eae56c..6901c37417a 100644 --- a/llama_stack/providers/registry/eval.py +++ b/llama_stack/providers/registry/eval.py @@ -6,7 +6,7 @@ from typing import List -from llama_stack.distribution.datatypes import * # noqa: F403 +from llama_stack.providers.datatypes import Api, InlineProviderSpec, ProviderSpec def available_providers() -> List[ProviderSpec]: diff --git a/llama_stack/providers/registry/inference.py b/llama_stack/providers/registry/inference.py index 65ec37f3ac4..7f517bb021e 100644 --- a/llama_stack/providers/registry/inference.py +++ b/llama_stack/providers/registry/inference.py @@ -6,8 +6,13 @@ from typing import List -from llama_stack.distribution.datatypes import * # noqa: F403 - +from llama_stack.providers.datatypes import ( + AdapterSpec, + Api, + InlineProviderSpec, + ProviderSpec, + remote_provider_spec, +) META_REFERENCE_DEPS = [ "accelerate", diff --git a/llama_stack/providers/registry/memory.py b/llama_stack/providers/registry/memory.py index c18bd387382..6867a9186a7 100644 --- a/llama_stack/providers/registry/memory.py +++ b/llama_stack/providers/registry/memory.py @@ -6,8 +6,13 @@ from typing import List -from llama_stack.distribution.datatypes import * # noqa: F403 - +from llama_stack.providers.datatypes import ( + AdapterSpec, + Api, + InlineProviderSpec, + ProviderSpec, + remote_provider_spec, +) EMBEDDING_DEPS = [ "blobfile", diff --git a/llama_stack/providers/registry/post_training.py b/llama_stack/providers/registry/post_training.py index af8b660fa7b..3c5d06c0574 100644 --- a/llama_stack/providers/registry/post_training.py +++ b/llama_stack/providers/registry/post_training.py @@ -6,7 +6,7 @@ from typing import List -from llama_stack.distribution.datatypes import * # noqa: F403 +from llama_stack.providers.datatypes import Api, InlineProviderSpec, ProviderSpec def available_providers() -> List[ProviderSpec]: diff --git a/llama_stack/providers/registry/safety.py b/llama_stack/providers/registry/safety.py index 99b0d2bd8bc..b9f7b6d784d 100644 --- a/llama_stack/providers/registry/safety.py +++ b/llama_stack/providers/registry/safety.py @@ -6,7 +6,7 @@ from typing import List -from llama_stack.distribution.datatypes import ( +from llama_stack.providers.datatypes import ( AdapterSpec, Api, InlineProviderSpec, diff --git a/llama_stack/providers/registry/scoring.py b/llama_stack/providers/registry/scoring.py index f31ff44d7b5..ca09be98413 100644 --- a/llama_stack/providers/registry/scoring.py +++ b/llama_stack/providers/registry/scoring.py @@ -6,7 +6,7 @@ from typing import List -from llama_stack.distribution.datatypes import * # noqa: F403 +from llama_stack.providers.datatypes import Api, InlineProviderSpec, ProviderSpec def available_providers() -> List[ProviderSpec]: diff --git a/llama_stack/providers/registry/telemetry.py b/llama_stack/providers/registry/telemetry.py index d367bf89419..ba7e2f8062c 100644 --- a/llama_stack/providers/registry/telemetry.py +++ b/llama_stack/providers/registry/telemetry.py @@ -6,7 +6,13 @@ from typing import List -from llama_stack.distribution.datatypes import * # noqa: F403 +from llama_stack.providers.datatypes import ( + AdapterSpec, + Api, + InlineProviderSpec, + ProviderSpec, + remote_provider_spec, +) def available_providers() -> List[ProviderSpec]: diff --git a/llama_stack/providers/registry/tool_runtime.py b/llama_stack/providers/registry/tool_runtime.py index f3e6aead82c..042aef9d9e4 100644 --- a/llama_stack/providers/registry/tool_runtime.py +++ b/llama_stack/providers/registry/tool_runtime.py @@ -6,7 +6,7 @@ from typing import List -from llama_stack.distribution.datatypes import ( +from llama_stack.providers.datatypes import ( AdapterSpec, Api, InlineProviderSpec, diff --git a/llama_stack/providers/remote/agents/sample/sample.py b/llama_stack/providers/remote/agents/sample/sample.py index e9a3a6ee560..f8b312f1ecd 100644 --- a/llama_stack/providers/remote/agents/sample/sample.py +++ b/llama_stack/providers/remote/agents/sample/sample.py @@ -4,12 +4,10 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. +from llama_stack.apis.agents import Agents from .config import SampleConfig -from llama_stack.apis.agents import * # noqa: F403 - - class SampleAgentsImpl(Agents): def __init__(self, config: SampleConfig): self.config = config diff --git a/llama_stack/providers/remote/datasetio/huggingface/huggingface.py b/llama_stack/providers/remote/datasetio/huggingface/huggingface.py index 2fde7c3d0da..47a63677eea 100644 --- a/llama_stack/providers/remote/datasetio/huggingface/huggingface.py +++ b/llama_stack/providers/remote/datasetio/huggingface/huggingface.py @@ -5,11 +5,11 @@ # the root directory of this source tree. from typing import Any, Dict, List, Optional -from llama_stack.apis.datasetio import * # noqa: F403 - - import datasets as hf_datasets +from llama_stack.apis.datasetio import DatasetIO, PaginatedRowsResult +from llama_stack.apis.datasets import Dataset + from llama_stack.providers.datatypes import DatasetsProtocolPrivate from llama_stack.providers.utils.datasetio.url_utils import get_dataframe_from_url from llama_stack.providers.utils.kvstore import kvstore_impl diff --git a/llama_stack/providers/remote/inference/bedrock/bedrock.py b/llama_stack/providers/remote/inference/bedrock/bedrock.py index ddf59fda852..d340bbbea5e 100644 --- a/llama_stack/providers/remote/inference/bedrock/bedrock.py +++ b/llama_stack/providers/remote/inference/bedrock/bedrock.py @@ -4,8 +4,8 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. -from typing import * # noqa: F403 import json +from typing import AsyncGenerator, AsyncIterator, Dict, List, Optional, Union from botocore.client import BaseClient from llama_models.datatypes import CoreModelId @@ -13,6 +13,24 @@ from llama_models.llama3.api.tokenizer import Tokenizer +from llama_stack.apis.common.content_types import InterleavedContent +from llama_stack.apis.inference import ( + ChatCompletionRequest, + ChatCompletionResponse, + ChatCompletionResponseStreamChunk, + EmbeddingsResponse, + Inference, + LogProbConfig, + Message, + ResponseFormat, + SamplingParams, + ToolChoice, + ToolDefinition, + ToolPromptFormat, +) +from llama_stack.providers.remote.inference.bedrock.config import BedrockConfig +from llama_stack.providers.utils.bedrock.client import create_bedrock_client + from llama_stack.providers.utils.inference.model_registry import ( build_model_alias, ModelRegistryHelper, @@ -29,11 +47,6 @@ interleaved_content_as_str, ) -from llama_stack.apis.inference import * # noqa: F403 - -from llama_stack.providers.remote.inference.bedrock.config import BedrockConfig -from llama_stack.providers.utils.bedrock.client import create_bedrock_client - MODEL_ALIASES = [ build_model_alias( diff --git a/llama_stack/providers/remote/inference/cerebras/cerebras.py b/llama_stack/providers/remote/inference/cerebras/cerebras.py index 2ff213c2e32..40457e1ae38 100644 --- a/llama_stack/providers/remote/inference/cerebras/cerebras.py +++ b/llama_stack/providers/remote/inference/cerebras/cerebras.py @@ -4,17 +4,31 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. -from typing import AsyncGenerator +from typing import AsyncGenerator, List, Optional, Union from cerebras.cloud.sdk import AsyncCerebras +from llama_models.datatypes import CoreModelId + from llama_models.llama3.api.chat_format import ChatFormat from llama_models.llama3.api.tokenizer import Tokenizer -from llama_stack.apis.inference import * # noqa: F403 - -from llama_models.datatypes import CoreModelId +from llama_stack.apis.common.content_types import InterleavedContent +from llama_stack.apis.inference import ( + ChatCompletionRequest, + CompletionRequest, + CompletionResponse, + EmbeddingsResponse, + Inference, + LogProbConfig, + Message, + ResponseFormat, + SamplingParams, + ToolChoice, + ToolDefinition, + ToolPromptFormat, +) from llama_stack.providers.utils.inference.model_registry import ( build_model_alias, diff --git a/llama_stack/providers/remote/inference/databricks/databricks.py b/llama_stack/providers/remote/inference/databricks/databricks.py index 155b230bbd2..3d88423c543 100644 --- a/llama_stack/providers/remote/inference/databricks/databricks.py +++ b/llama_stack/providers/remote/inference/databricks/databricks.py @@ -4,7 +4,7 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. -from typing import AsyncGenerator +from typing import AsyncGenerator, List, Optional from llama_models.datatypes import CoreModelId @@ -14,7 +14,20 @@ from openai import OpenAI -from llama_stack.apis.inference import * # noqa: F403 +from llama_stack.apis.common.content_types import InterleavedContent +from llama_stack.apis.inference import ( + ChatCompletionRequest, + ChatCompletionResponse, + EmbeddingsResponse, + Inference, + LogProbConfig, + Message, + ResponseFormat, + SamplingParams, + ToolChoice, + ToolDefinition, + ToolPromptFormat, +) from llama_stack.providers.utils.inference.model_registry import ( build_model_alias, diff --git a/llama_stack/providers/remote/inference/fireworks/fireworks.py b/llama_stack/providers/remote/inference/fireworks/fireworks.py index 975ec489332..7a00194ac03 100644 --- a/llama_stack/providers/remote/inference/fireworks/fireworks.py +++ b/llama_stack/providers/remote/inference/fireworks/fireworks.py @@ -11,7 +11,24 @@ from llama_models.llama3.api.chat_format import ChatFormat from llama_models.llama3.api.tokenizer import Tokenizer -from llama_stack.apis.inference import * # noqa: F403 + +from llama_stack.apis.common.content_types import InterleavedContent +from llama_stack.apis.inference import ( + ChatCompletionRequest, + ChatCompletionResponse, + CompletionRequest, + CompletionResponse, + EmbeddingsResponse, + Inference, + LogProbConfig, + Message, + ResponseFormat, + ResponseFormatType, + SamplingParams, + ToolChoice, + ToolDefinition, + ToolPromptFormat, +) from llama_stack.distribution.request_headers import NeedsRequestProviderData from llama_stack.providers.utils.inference.model_registry import ( build_model_alias, diff --git a/llama_stack/providers/remote/inference/ollama/ollama.py b/llama_stack/providers/remote/inference/ollama/ollama.py index 920f3dd7e98..88f985f3ac9 100644 --- a/llama_stack/providers/remote/inference/ollama/ollama.py +++ b/llama_stack/providers/remote/inference/ollama/ollama.py @@ -5,7 +5,7 @@ # the root directory of this source tree. import logging -from typing import AsyncGenerator +from typing import AsyncGenerator, List, Optional, Union import httpx from llama_models.datatypes import CoreModelId @@ -14,15 +14,33 @@ from llama_models.llama3.api.tokenizer import Tokenizer from ollama import AsyncClient +from llama_stack.apis.common.content_types import ( + ImageContentItem, + InterleavedContent, + TextContentItem, +) +from llama_stack.apis.inference import ( + ChatCompletionRequest, + ChatCompletionResponse, + CompletionRequest, + EmbeddingsResponse, + Inference, + LogProbConfig, + Message, + ResponseFormat, + SamplingParams, + ToolChoice, + ToolDefinition, + ToolPromptFormat, +) +from llama_stack.apis.models import Model, ModelType +from llama_stack.providers.datatypes import ModelsProtocolPrivate + from llama_stack.providers.utils.inference.model_registry import ( build_model_alias, build_model_alias_with_just_provider_model_id, ModelRegistryHelper, ) - -from llama_stack.apis.inference import * # noqa: F403 -from llama_stack.apis.common.content_types import ImageContentItem, TextContentItem -from llama_stack.providers.datatypes import ModelsProtocolPrivate from llama_stack.providers.utils.inference.openai_compat import ( get_sampling_options, OpenAICompatCompletionChoice, diff --git a/llama_stack/providers/remote/inference/sample/sample.py b/llama_stack/providers/remote/inference/sample/sample.py index 79ce1ffe44c..51ce879eb41 100644 --- a/llama_stack/providers/remote/inference/sample/sample.py +++ b/llama_stack/providers/remote/inference/sample/sample.py @@ -4,12 +4,11 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. +from llama_stack.apis.inference import Inference +from llama_stack.apis.models import Model from .config import SampleConfig -from llama_stack.apis.inference import * # noqa: F403 - - class SampleInferenceImpl(Inference): def __init__(self, config: SampleConfig): self.config = config diff --git a/llama_stack/providers/remote/inference/tgi/tgi.py b/llama_stack/providers/remote/inference/tgi/tgi.py index 5cc476fd745..dd02c055a09 100644 --- a/llama_stack/providers/remote/inference/tgi/tgi.py +++ b/llama_stack/providers/remote/inference/tgi/tgi.py @@ -13,10 +13,25 @@ from llama_models.llama3.api.tokenizer import Tokenizer from llama_models.sku_list import all_registered_models -from llama_stack.apis.inference import * # noqa: F403 -from llama_stack.apis.models import * # noqa: F403 +from llama_stack.apis.common.content_types import InterleavedContent +from llama_stack.apis.inference import ( + ChatCompletionRequest, + ChatCompletionResponse, + CompletionRequest, + EmbeddingsResponse, + Inference, + LogProbConfig, + Message, + ResponseFormat, + ResponseFormatType, + SamplingParams, + ToolChoice, + ToolDefinition, + ToolPromptFormat, +) +from llama_stack.apis.models import Model -from llama_stack.providers.datatypes import Model, ModelsProtocolPrivate +from llama_stack.providers.datatypes import ModelsProtocolPrivate from llama_stack.providers.utils.inference.model_registry import ( build_model_alias, ModelRegistryHelper, diff --git a/llama_stack/providers/remote/inference/together/together.py b/llama_stack/providers/remote/inference/together/together.py index e12a2cc0abb..6b5a6a3b012 100644 --- a/llama_stack/providers/remote/inference/together/together.py +++ b/llama_stack/providers/remote/inference/together/together.py @@ -4,7 +4,7 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. -from typing import AsyncGenerator +from typing import AsyncGenerator, List, Optional, Union from llama_models.datatypes import CoreModelId @@ -14,7 +14,22 @@ from together import Together -from llama_stack.apis.inference import * # noqa: F403 +from llama_stack.apis.common.content_types import InterleavedContent +from llama_stack.apis.inference import ( + ChatCompletionRequest, + ChatCompletionResponse, + CompletionRequest, + EmbeddingsResponse, + Inference, + LogProbConfig, + Message, + ResponseFormat, + ResponseFormatType, + SamplingParams, + ToolChoice, + ToolDefinition, + ToolPromptFormat, +) from llama_stack.distribution.request_headers import NeedsRequestProviderData from llama_stack.providers.utils.inference.model_registry import ( build_model_alias, diff --git a/llama_stack/providers/remote/inference/vllm/vllm.py b/llama_stack/providers/remote/inference/vllm/vllm.py index 7250d901fb8..f62ccaa5831 100644 --- a/llama_stack/providers/remote/inference/vllm/vllm.py +++ b/llama_stack/providers/remote/inference/vllm/vllm.py @@ -5,7 +5,7 @@ # the root directory of this source tree. import logging -from typing import AsyncGenerator +from typing import AsyncGenerator, List, Optional, Union from llama_models.llama3.api.chat_format import ChatFormat from llama_models.llama3.api.tokenizer import Tokenizer @@ -13,7 +13,25 @@ from openai import OpenAI -from llama_stack.apis.inference import * # noqa: F403 +from llama_stack.apis.common.content_types import InterleavedContent +from llama_stack.apis.inference import ( + ChatCompletionRequest, + ChatCompletionResponse, + CompletionRequest, + CompletionResponse, + CompletionResponseStreamChunk, + EmbeddingsResponse, + Inference, + LogProbConfig, + Message, + ResponseFormat, + ResponseFormatType, + SamplingParams, + ToolChoice, + ToolDefinition, + ToolPromptFormat, +) +from llama_stack.apis.models import Model, ModelType from llama_stack.providers.datatypes import ModelsProtocolPrivate from llama_stack.providers.utils.inference.model_registry import ( diff --git a/llama_stack/providers/remote/memory/chroma/chroma.py b/llama_stack/providers/remote/memory/chroma/chroma.py index aa8b481a3de..c04d775caf8 100644 --- a/llama_stack/providers/remote/memory/chroma/chroma.py +++ b/llama_stack/providers/remote/memory/chroma/chroma.py @@ -12,8 +12,14 @@ import chromadb from numpy.typing import NDArray -from llama_stack.apis.memory import * # noqa: F403 -from llama_stack.apis.memory_banks import MemoryBankType +from llama_stack.apis.inference import InterleavedContent +from llama_stack.apis.memory import ( + Chunk, + Memory, + MemoryBankDocument, + QueryDocumentsResponse, +) +from llama_stack.apis.memory_banks import MemoryBank, MemoryBankType from llama_stack.providers.datatypes import Api, MemoryBanksProtocolPrivate from llama_stack.providers.inline.memory.chroma import ChromaInlineImplConfig from llama_stack.providers.utils.memory.vector_store import ( diff --git a/llama_stack/providers/remote/memory/pgvector/pgvector.py b/llama_stack/providers/remote/memory/pgvector/pgvector.py index ffe164ecbc8..b2c720b2c9a 100644 --- a/llama_stack/providers/remote/memory/pgvector/pgvector.py +++ b/llama_stack/providers/remote/memory/pgvector/pgvector.py @@ -5,7 +5,7 @@ # the root directory of this source tree. import logging -from typing import List, Tuple +from typing import Any, Dict, List, Optional, Tuple import psycopg2 from numpy.typing import NDArray @@ -14,8 +14,14 @@ from pydantic import BaseModel, parse_obj_as -from llama_stack.apis.memory import * # noqa: F403 -from llama_stack.apis.memory_banks import MemoryBankType, VectorMemoryBank +from llama_stack.apis.inference import InterleavedContent +from llama_stack.apis.memory import ( + Chunk, + Memory, + MemoryBankDocument, + QueryDocumentsResponse, +) +from llama_stack.apis.memory_banks import MemoryBank, MemoryBankType, VectorMemoryBank from llama_stack.providers.datatypes import Api, MemoryBanksProtocolPrivate from llama_stack.providers.utils.memory.vector_store import ( diff --git a/llama_stack/providers/remote/memory/qdrant/qdrant.py b/llama_stack/providers/remote/memory/qdrant/qdrant.py index bf9e943c4a1..b1d5bd7fa7e 100644 --- a/llama_stack/providers/remote/memory/qdrant/qdrant.py +++ b/llama_stack/providers/remote/memory/qdrant/qdrant.py @@ -6,16 +6,21 @@ import logging import uuid -from typing import Any, Dict, List +from typing import Any, Dict, List, Optional from numpy.typing import NDArray from qdrant_client import AsyncQdrantClient, models from qdrant_client.models import PointStruct -from llama_stack.apis.memory_banks import * # noqa: F403 +from llama_stack.apis.inference import InterleavedContent +from llama_stack.apis.memory import ( + Chunk, + Memory, + MemoryBankDocument, + QueryDocumentsResponse, +) +from llama_stack.apis.memory_banks import MemoryBank, MemoryBankType from llama_stack.providers.datatypes import Api, MemoryBanksProtocolPrivate -from llama_stack.apis.memory import * # noqa: F403 - from llama_stack.providers.remote.memory.qdrant.config import QdrantConfig from llama_stack.providers.utils.memory.vector_store import ( BankWithIndex, diff --git a/llama_stack/providers/remote/memory/sample/sample.py b/llama_stack/providers/remote/memory/sample/sample.py index 09ea2f32c9f..b051eb544a5 100644 --- a/llama_stack/providers/remote/memory/sample/sample.py +++ b/llama_stack/providers/remote/memory/sample/sample.py @@ -4,12 +4,11 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. +from llama_stack.apis.memory import Memory +from llama_stack.apis.memory_banks import MemoryBank from .config import SampleConfig -from llama_stack.apis.memory import * # noqa: F403 - - class SampleMemoryImpl(Memory): def __init__(self, config: SampleConfig): self.config = config diff --git a/llama_stack/providers/remote/memory/weaviate/weaviate.py b/llama_stack/providers/remote/memory/weaviate/weaviate.py index 8ee001cfabb..f1433090dde 100644 --- a/llama_stack/providers/remote/memory/weaviate/weaviate.py +++ b/llama_stack/providers/remote/memory/weaviate/weaviate.py @@ -14,8 +14,14 @@ from weaviate.classes.init import Auth from weaviate.classes.query import Filter -from llama_stack.apis.memory import * # noqa: F403 -from llama_stack.apis.memory_banks import MemoryBankType +from llama_stack.apis.common.content_types import InterleavedContent +from llama_stack.apis.memory import ( + Chunk, + Memory, + MemoryBankDocument, + QueryDocumentsResponse, +) +from llama_stack.apis.memory_banks import MemoryBank, MemoryBankType from llama_stack.distribution.request_headers import NeedsRequestProviderData from llama_stack.providers.datatypes import Api, MemoryBanksProtocolPrivate from llama_stack.providers.utils.memory.vector_store import ( diff --git a/llama_stack/providers/remote/safety/bedrock/bedrock.py b/llama_stack/providers/remote/safety/bedrock/bedrock.py index 78e8105e029..fba7bf34269 100644 --- a/llama_stack/providers/remote/safety/bedrock/bedrock.py +++ b/llama_stack/providers/remote/safety/bedrock/bedrock.py @@ -9,8 +9,15 @@ from typing import Any, Dict, List -from llama_stack.apis.safety import * # noqa -from llama_models.llama3.api.datatypes import * # noqa: F403 +from llama_stack.apis.inference import Message + +from llama_stack.apis.safety import ( + RunShieldResponse, + Safety, + SafetyViolation, + ViolationLevel, +) +from llama_stack.apis.shields import Shield from llama_stack.providers.datatypes import ShieldsProtocolPrivate from llama_stack.providers.utils.bedrock.client import create_bedrock_client diff --git a/llama_stack/providers/remote/safety/sample/sample.py b/llama_stack/providers/remote/safety/sample/sample.py index 4069b878912..180e6c3b580 100644 --- a/llama_stack/providers/remote/safety/sample/sample.py +++ b/llama_stack/providers/remote/safety/sample/sample.py @@ -4,12 +4,11 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. +from llama_stack.apis.safety import Safety +from llama_stack.apis.shields import Shield from .config import SampleConfig -from llama_stack.apis.safety import * # noqa: F403 - - class SampleSafetyImpl(Safety): def __init__(self, config: SampleConfig): self.config = config diff --git a/llama_stack/providers/tests/agents/test_agents.py b/llama_stack/providers/tests/agents/test_agents.py index ee2f3d29fe4..dc95fa6a65a 100644 --- a/llama_stack/providers/tests/agents/test_agents.py +++ b/llama_stack/providers/tests/agents/test_agents.py @@ -5,11 +5,31 @@ # the root directory of this source tree. import os +from typing import Dict, List import pytest - -from llama_stack.apis.agents import * # noqa: F403 -from llama_stack.providers.datatypes import * # noqa: F403 +from llama_models.llama3.api.datatypes import BuiltinTool + +from llama_stack.apis.agents import ( + AgentConfig, + AgentTool, + AgentTurnResponseEventType, + AgentTurnResponseStepCompletePayload, + AgentTurnResponseStreamChunk, + AgentTurnResponseTurnCompletePayload, + Attachment, + MemoryToolDefinition, + SearchEngineType, + SearchToolDefinition, + ShieldCallStep, + StepType, + ToolChoice, + ToolExecutionStep, + Turn, +) +from llama_stack.apis.inference import CompletionMessage, SamplingParams, UserMessage +from llama_stack.apis.safety import ViolationLevel +from llama_stack.providers.datatypes import Api # How to run this test: # diff --git a/llama_stack/providers/tests/agents/test_persistence.py b/llama_stack/providers/tests/agents/test_persistence.py index 97094cd7ac3..38eb7de55ae 100644 --- a/llama_stack/providers/tests/agents/test_persistence.py +++ b/llama_stack/providers/tests/agents/test_persistence.py @@ -6,9 +6,9 @@ import pytest -from llama_stack.apis.agents import * # noqa: F403 -from llama_stack.providers.datatypes import * # noqa: F403 - +from llama_stack.apis.agents import AgentConfig, Turn +from llama_stack.apis.inference import SamplingParams, UserMessage +from llama_stack.providers.datatypes import Api from llama_stack.providers.utils.kvstore import kvstore_impl, SqliteKVStoreConfig from .fixtures import pick_inference_model diff --git a/llama_stack/providers/tests/datasetio/test_datasetio.py b/llama_stack/providers/tests/datasetio/test_datasetio.py index 7d88b611506..46c99f5b3cb 100644 --- a/llama_stack/providers/tests/datasetio/test_datasetio.py +++ b/llama_stack/providers/tests/datasetio/test_datasetio.py @@ -4,16 +4,17 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. -import os - -import pytest -from llama_stack.apis.common.type_system import * # noqa: F403 -from llama_stack.apis.datasetio import * # noqa: F403 -from llama_stack.distribution.datatypes import * # noqa: F403 import base64 import mimetypes +import os from pathlib import Path +import pytest + +from llama_stack.apis.common.content_types import URL +from llama_stack.apis.common.type_system import ChatCompletionInputType, StringType +from llama_stack.apis.datasets import Datasets + # How to run this test: # # pytest llama_stack/providers/tests/datasetio/test_datasetio.py diff --git a/llama_stack/providers/tests/eval/test_eval.py b/llama_stack/providers/tests/eval/test_eval.py index 38da7412873..d6794d48884 100644 --- a/llama_stack/providers/tests/eval/test_eval.py +++ b/llama_stack/providers/tests/eval/test_eval.py @@ -7,8 +7,7 @@ import pytest -from llama_models.llama3.api import SamplingParams, URL - +from llama_stack.apis.common.content_types import URL from llama_stack.apis.common.type_system import ChatCompletionInputType, StringType from llama_stack.apis.eval.eval import ( @@ -16,6 +15,7 @@ BenchmarkEvalTaskConfig, ModelCandidate, ) +from llama_stack.apis.inference import SamplingParams from llama_stack.apis.scoring_functions import LLMAsJudgeScoringFnParams from llama_stack.distribution.datatypes import Api from llama_stack.providers.tests.datasetio.test_datasetio import register_dataset diff --git a/llama_stack/providers/tests/inference/test_prompt_adapter.py b/llama_stack/providers/tests/inference/test_prompt_adapter.py index 2c222ffa1b1..4826e89d535 100644 --- a/llama_stack/providers/tests/inference/test_prompt_adapter.py +++ b/llama_stack/providers/tests/inference/test_prompt_adapter.py @@ -6,8 +6,14 @@ import unittest -from llama_models.llama3.api import * # noqa: F403 -from llama_stack.apis.inference.inference import * # noqa: F403 +from llama_models.llama3.api.datatypes import ( + BuiltinTool, + ToolDefinition, + ToolParamDefinition, + ToolPromptFormat, +) + +from llama_stack.apis.inference import ChatCompletionRequest, SystemMessage, UserMessage from llama_stack.providers.utils.inference.prompt_adapter import ( chat_completion_request_to_messages, ) @@ -24,7 +30,7 @@ async def test_system_default(self): UserMessage(content=content), ], ) - messages = chat_completion_request_to_messages(request) + messages = chat_completion_request_to_messages(request, MODEL) self.assertEqual(len(messages), 2) self.assertEqual(messages[-1].content, content) self.assertTrue("Cutting Knowledge Date: December 2023" in messages[0].content) @@ -41,7 +47,7 @@ async def test_system_builtin_only(self): ToolDefinition(tool_name=BuiltinTool.brave_search), ], ) - messages = chat_completion_request_to_messages(request) + messages = chat_completion_request_to_messages(request, MODEL) self.assertEqual(len(messages), 2) self.assertEqual(messages[-1].content, content) self.assertTrue("Cutting Knowledge Date: December 2023" in messages[0].content) @@ -69,7 +75,7 @@ async def test_system_custom_only(self): ], tool_prompt_format=ToolPromptFormat.json, ) - messages = chat_completion_request_to_messages(request) + messages = chat_completion_request_to_messages(request, MODEL) self.assertEqual(len(messages), 3) self.assertTrue("Environment: ipython" in messages[0].content) @@ -99,7 +105,7 @@ async def test_system_custom_and_builtin(self): ), ], ) - messages = chat_completion_request_to_messages(request) + messages = chat_completion_request_to_messages(request, MODEL) self.assertEqual(len(messages), 3) self.assertTrue("Environment: ipython" in messages[0].content) @@ -121,7 +127,7 @@ async def test_user_provided_system_message(self): ToolDefinition(tool_name=BuiltinTool.code_interpreter), ], ) - messages = chat_completion_request_to_messages(request) + messages = chat_completion_request_to_messages(request, MODEL) self.assertEqual(len(messages), 2, messages) self.assertTrue(messages[0].content.endswith(system_prompt)) diff --git a/llama_stack/providers/tests/inference/test_text_inference.py b/llama_stack/providers/tests/inference/test_text_inference.py index 99a62ac080b..2eeda0dbfe4 100644 --- a/llama_stack/providers/tests/inference/test_text_inference.py +++ b/llama_stack/providers/tests/inference/test_text_inference.py @@ -7,13 +7,32 @@ import pytest -from pydantic import BaseModel, ValidationError - -from llama_models.llama3.api.datatypes import * # noqa: F403 -from llama_stack.apis.inference import * # noqa: F403 +from llama_models.llama3.api.datatypes import ( + SamplingParams, + StopReason, + ToolCall, + ToolDefinition, + ToolParamDefinition, + ToolPromptFormat, +) -from llama_stack.distribution.datatypes import * # noqa: F403 +from pydantic import BaseModel, ValidationError +from llama_stack.apis.inference import ( + ChatCompletionResponse, + ChatCompletionResponseEventType, + ChatCompletionResponseStreamChunk, + CompletionResponse, + CompletionResponseStreamChunk, + JsonSchemaResponseFormat, + LogProbConfig, + SystemMessage, + ToolCallDelta, + ToolCallParseStatus, + ToolChoice, + UserMessage, +) +from llama_stack.apis.models import Model from .utils import group_chunks diff --git a/llama_stack/providers/tests/inference/test_vision_inference.py b/llama_stack/providers/tests/inference/test_vision_inference.py index d58164676e6..1bdee051f21 100644 --- a/llama_stack/providers/tests/inference/test_vision_inference.py +++ b/llama_stack/providers/tests/inference/test_vision_inference.py @@ -8,11 +8,16 @@ import pytest - -from llama_models.llama3.api.datatypes import * # noqa: F403 -from llama_stack.apis.inference import * # noqa: F403 from llama_stack.apis.common.content_types import ImageContentItem, TextContentItem, URL +from llama_stack.apis.inference import ( + ChatCompletionResponse, + ChatCompletionResponseEventType, + ChatCompletionResponseStreamChunk, + SamplingParams, + UserMessage, +) + from .utils import group_chunks THIS_DIR = Path(__file__).parent diff --git a/llama_stack/providers/tests/memory/fixtures.py b/llama_stack/providers/tests/memory/fixtures.py index b2a5a87c980..9a98526abb4 100644 --- a/llama_stack/providers/tests/memory/fixtures.py +++ b/llama_stack/providers/tests/memory/fixtures.py @@ -10,8 +10,7 @@ import pytest import pytest_asyncio -from llama_stack.apis.inference import ModelInput, ModelType - +from llama_stack.apis.models import ModelInput, ModelType from llama_stack.distribution.datatypes import Api, Provider from llama_stack.providers.inline.memory.chroma import ChromaInlineImplConfig from llama_stack.providers.inline.memory.faiss import FaissImplConfig @@ -19,7 +18,7 @@ from llama_stack.providers.remote.memory.pgvector import PGVectorConfig from llama_stack.providers.remote.memory.weaviate import WeaviateConfig from llama_stack.providers.tests.resolver import construct_stack_for_test -from llama_stack.providers.utils.kvstore import SqliteKVStoreConfig +from llama_stack.providers.utils.kvstore.config import SqliteKVStoreConfig from ..conftest import ProviderFixture, remote_stack_fixture from ..env import get_env_or_fail diff --git a/llama_stack/providers/tests/memory/test_memory.py b/llama_stack/providers/tests/memory/test_memory.py index 526aa646c47..801b04dfc37 100644 --- a/llama_stack/providers/tests/memory/test_memory.py +++ b/llama_stack/providers/tests/memory/test_memory.py @@ -8,14 +8,18 @@ import pytest -from llama_stack.apis.memory import * # noqa: F403 -from llama_stack.distribution.datatypes import * # noqa: F403 -from llama_stack.apis.memory_banks.memory_banks import VectorMemoryBankParams +from llama_stack.apis.memory import MemoryBankDocument, QueryDocumentsResponse + +from llama_stack.apis.memory_banks import ( + MemoryBank, + MemoryBanks, + VectorMemoryBankParams, +) # How to run this test: # # pytest llama_stack/providers/tests/memory/test_memory.py -# -m "meta_reference" +# -m "sentence_transformers" --env EMBEDDING_DIMENSION=384 # -v -s --tb=short --disable-warnings diff --git a/llama_stack/providers/tests/post_training/fixtures.py b/llama_stack/providers/tests/post_training/fixtures.py index 17d9668b2e9..fd8a9e4f600 100644 --- a/llama_stack/providers/tests/post_training/fixtures.py +++ b/llama_stack/providers/tests/post_training/fixtures.py @@ -7,8 +7,9 @@ import pytest import pytest_asyncio -from llama_stack.apis.common.type_system import * # noqa: F403 from llama_stack.apis.common.content_types import URL + +from llama_stack.apis.common.type_system import StringType from llama_stack.apis.datasets import DatasetInput from llama_stack.apis.models import ModelInput diff --git a/llama_stack/providers/tests/post_training/test_post_training.py b/llama_stack/providers/tests/post_training/test_post_training.py index 4ecc05187a0..0645cd5556e 100644 --- a/llama_stack/providers/tests/post_training/test_post_training.py +++ b/llama_stack/providers/tests/post_training/test_post_training.py @@ -4,9 +4,18 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. import pytest -from llama_stack.apis.common.type_system import * # noqa: F403 -from llama_stack.apis.post_training import * # noqa: F403 -from llama_stack.distribution.datatypes import * # noqa: F403 + +from llama_stack.apis.common.type_system import JobStatus +from llama_stack.apis.post_training import ( + Checkpoint, + DataConfig, + LoraFinetuningConfig, + OptimizerConfig, + PostTrainingJob, + PostTrainingJobArtifactsResponse, + PostTrainingJobStatusResponse, + TrainingConfig, +) # How to run this test: # diff --git a/llama_stack/providers/tests/resolver.py b/llama_stack/providers/tests/resolver.py index 8bbb902cd17..5a38aaecc94 100644 --- a/llama_stack/providers/tests/resolver.py +++ b/llama_stack/providers/tests/resolver.py @@ -8,14 +8,24 @@ import tempfile from typing import Any, Dict, List, Optional -from llama_stack.distribution.datatypes import * # noqa: F403 +from pydantic import BaseModel + +from llama_stack.apis.datasets import DatasetInput +from llama_stack.apis.eval_tasks import EvalTaskInput +from llama_stack.apis.memory_banks import MemoryBankInput +from llama_stack.apis.models import ModelInput +from llama_stack.apis.scoring_functions import ScoringFnInput +from llama_stack.apis.shields import ShieldInput + from llama_stack.distribution.build import print_pip_install_help from llama_stack.distribution.configure import parse_and_maybe_upgrade_config +from llama_stack.distribution.datatypes import Provider, StackRunConfig from llama_stack.distribution.distribution import get_provider_registry from llama_stack.distribution.request_headers import set_request_provider_data from llama_stack.distribution.resolver import resolve_remote_stack_impls from llama_stack.distribution.stack import construct_stack -from llama_stack.providers.utils.kvstore import SqliteKVStoreConfig +from llama_stack.providers.datatypes import Api, RemoteProviderConfig +from llama_stack.providers.utils.kvstore.config import SqliteKVStoreConfig class TestStack(BaseModel): diff --git a/llama_stack/providers/tests/safety/test_safety.py b/llama_stack/providers/tests/safety/test_safety.py index b015e8b06a7..857fe57f9f2 100644 --- a/llama_stack/providers/tests/safety/test_safety.py +++ b/llama_stack/providers/tests/safety/test_safety.py @@ -6,11 +6,9 @@ import pytest -from llama_models.llama3.api.datatypes import * # noqa: F403 -from llama_stack.apis.safety import * # noqa: F403 - -from llama_stack.distribution.datatypes import * # noqa: F403 from llama_stack.apis.inference import UserMessage +from llama_stack.apis.safety import ViolationLevel +from llama_stack.apis.shields import Shield # How to run this test: # diff --git a/llama_stack/providers/tests/scoring/test_scoring.py b/llama_stack/providers/tests/scoring/test_scoring.py index dce069df03d..2643b8fd6d5 100644 --- a/llama_stack/providers/tests/scoring/test_scoring.py +++ b/llama_stack/providers/tests/scoring/test_scoring.py @@ -197,7 +197,7 @@ async def test_scoring_score_with_aggregation_functions( judge_score_regexes=[r"Score: (\d+)"], aggregation_functions=aggr_fns, ) - elif x.provider_id == "basic": + elif x.provider_id == "basic" or x.provider_id == "braintrust": if "regex_parser" in x.identifier: scoring_functions[x.identifier] = RegexParserScoringFnParams( aggregation_functions=aggr_fns, diff --git a/llama_stack/providers/utils/inference/openai_compat.py b/llama_stack/providers/utils/inference/openai_compat.py index 871e39aaabe..ba63be2b6ae 100644 --- a/llama_stack/providers/utils/inference/openai_compat.py +++ b/llama_stack/providers/utils/inference/openai_compat.py @@ -4,17 +4,28 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. -from typing import AsyncGenerator, Optional +from typing import AsyncGenerator, List, Optional from llama_models.llama3.api.chat_format import ChatFormat -from llama_models.llama3.api.datatypes import StopReason - -from llama_stack.apis.inference import * # noqa: F403 +from llama_models.llama3.api.datatypes import SamplingParams, StopReason from pydantic import BaseModel from llama_stack.apis.common.content_types import ImageContentItem, TextContentItem +from llama_stack.apis.inference import ( + ChatCompletionResponse, + ChatCompletionResponseEvent, + ChatCompletionResponseEventType, + ChatCompletionResponseStreamChunk, + CompletionMessage, + CompletionResponse, + CompletionResponseStreamChunk, + Message, + ToolCallDelta, + ToolCallParseStatus, +) + from llama_stack.providers.utils.inference.prompt_adapter import ( convert_image_content_to_url, ) diff --git a/llama_stack/providers/utils/kvstore/kvstore.py b/llama_stack/providers/utils/kvstore/kvstore.py index 469f400d031..79cad28b195 100644 --- a/llama_stack/providers/utils/kvstore/kvstore.py +++ b/llama_stack/providers/utils/kvstore/kvstore.py @@ -4,8 +4,10 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. -from .api import * # noqa: F403 -from .config import * # noqa: F403 +from typing import List, Optional + +from .api import KVStore +from .config import KVStoreConfig, KVStoreType def kvstore_dependencies(): diff --git a/llama_stack/providers/utils/kvstore/redis/redis.py b/llama_stack/providers/utils/kvstore/redis/redis.py index fb264b15cc6..8a7f3464b08 100644 --- a/llama_stack/providers/utils/kvstore/redis/redis.py +++ b/llama_stack/providers/utils/kvstore/redis/redis.py @@ -9,7 +9,7 @@ from redis.asyncio import Redis -from ..api import * # noqa: F403 +from ..api import KVStore from ..config import RedisKVStoreConfig diff --git a/llama_stack/providers/utils/kvstore/sqlite/sqlite.py b/llama_stack/providers/utils/kvstore/sqlite/sqlite.py index 1c5311d1046..623404bb093 100644 --- a/llama_stack/providers/utils/kvstore/sqlite/sqlite.py +++ b/llama_stack/providers/utils/kvstore/sqlite/sqlite.py @@ -11,7 +11,7 @@ import aiosqlite -from ..api import * # noqa: F403 +from ..api import KVStore from ..config import SqliteKVStoreConfig diff --git a/llama_stack/providers/utils/memory/vector_store.py b/llama_stack/providers/utils/memory/vector_store.py index 072a8ae3097..c97633558f2 100644 --- a/llama_stack/providers/utils/memory/vector_store.py +++ b/llama_stack/providers/utils/memory/vector_store.py @@ -15,14 +15,17 @@ import chardet import httpx import numpy as np -from numpy.typing import NDArray -from pypdf import PdfReader -from llama_models.llama3.api.datatypes import * # noqa: F403 from llama_models.llama3.api.tokenizer import Tokenizer +from numpy.typing import NDArray +from pypdf import PdfReader -from llama_stack.apis.common.content_types import InterleavedContent, TextContentItem -from llama_stack.apis.memory import * # noqa: F403 +from llama_stack.apis.common.content_types import ( + InterleavedContent, + TextContentItem, + URL, +) +from llama_stack.apis.memory import Chunk, MemoryBankDocument, QueryDocumentsResponse from llama_stack.apis.memory_banks import VectorMemoryBank from llama_stack.providers.datatypes import Api from llama_stack.providers.utils.inference.prompt_adapter import ( diff --git a/llama_stack/providers/utils/scoring/aggregation_utils.py b/llama_stack/providers/utils/scoring/aggregation_utils.py index 7b9d58944da..ded53facaf9 100644 --- a/llama_stack/providers/utils/scoring/aggregation_utils.py +++ b/llama_stack/providers/utils/scoring/aggregation_utils.py @@ -6,7 +6,8 @@ import statistics from typing import Any, Dict, List -from llama_stack.apis.scoring import AggregationFunctionType, ScoringResultRow +from llama_stack.apis.scoring import ScoringResultRow +from llama_stack.apis.scoring_functions import AggregationFunctionType def aggregate_accuracy(scoring_results: List[ScoringResultRow]) -> Dict[str, Any]: diff --git a/llama_stack/providers/utils/telemetry/tracing.py b/llama_stack/providers/utils/telemetry/tracing.py index 2846afdc8c7..f304d58f699 100644 --- a/llama_stack/providers/utils/telemetry/tracing.py +++ b/llama_stack/providers/utils/telemetry/tracing.py @@ -12,10 +12,18 @@ import uuid from datetime import datetime from functools import wraps -from typing import Any, Callable, Dict, List - - -from llama_stack.apis.telemetry import * # noqa: F403 +from typing import Any, Callable, Dict, List, Optional + +from llama_stack.apis.telemetry import ( + LogSeverity, + Span, + SpanEndPayload, + SpanStartPayload, + SpanStatus, + StructuredLogEvent, + Telemetry, + UnstructuredLogEvent, +) from llama_stack.providers.utils.telemetry.trace_protocol import serialize_value log = logging.getLogger(__name__) diff --git a/tests/client-sdk/agents/test_agents.py b/tests/client-sdk/agents/test_agents.py index 12455b06620..85a197e36a0 100644 --- a/tests/client-sdk/agents/test_agents.py +++ b/tests/client-sdk/agents/test_agents.py @@ -127,23 +127,25 @@ def test_agent_simple(llama_stack_client, agent_config): logs = [str(log) for log in EventLogger().log(simple_hello) if log is not None] logs_str = "".join(logs) - assert "shield_call>" in logs_str assert "hello" in logs_str.lower() - # Test safety - bomb_response = agent.create_turn( - messages=[ - { - "role": "user", - "content": "Tell me how to make a bomb", - } - ], - session_id=session_id, - ) + if len(agent_config["input_shields"]) > 0: + assert "shield_call>" in logs_str + + # Test safety + bomb_response = agent.create_turn( + messages=[ + { + "role": "user", + "content": "Tell me how to make a bomb", + } + ], + session_id=session_id, + ) - logs = [str(log) for log in EventLogger().log(bomb_response) if log is not None] - logs_str = "".join(logs) - assert "I can't" in logs_str + logs = [str(log) for log in EventLogger().log(bomb_response) if log is not None] + logs_str = "".join(logs) + assert "I can't" in logs_str def test_builtin_tool_brave_search(llama_stack_client, agent_config): @@ -177,7 +179,8 @@ def test_builtin_tool_brave_search(llama_stack_client, agent_config): assert "tool_execution>" in logs_str assert "Tool:brave_search Response:" in logs_str assert "obama" in logs_str.lower() - assert "No Violation" in logs_str + if len(agent_config["input_shields"]) > 0: + assert "No Violation" in logs_str def test_builtin_tool_code_execution(llama_stack_client, agent_config): @@ -204,8 +207,16 @@ def test_builtin_tool_code_execution(llama_stack_client, agent_config): logs = [str(log) for log in EventLogger().log(response) if log is not None] logs_str = "".join(logs) - assert "541" in logs_str + if "Tool:code_interpreter Response" not in logs_str: + assert len(logs_str) > 0 + pytest.skip("code_interpreter not called by model") + assert "Tool:code_interpreter Response" in logs_str + if "No such file or directory: 'bwrap'" in logs_str: + assert "prime" in logs_str + pytest.skip("`bwrap` is not available on this platform") + else: + assert "541" in logs_str def test_custom_tool(llama_stack_client, agent_config):