Initial commit

This commit is contained in:
2025-10-14 14:17:21 +08:00
commit ac715a8b88
35011 changed files with 3834178 additions and 0 deletions

View File

View File

@@ -0,0 +1,49 @@
from typing import Any
import toml # type: ignore
def load_api_poetry_configs() -> dict[str, Any]:
    """Parse api/pyproject.toml and return its [tool.poetry] table."""
    parsed = toml.load("api/pyproject.toml")
    return parsed["tool"]["poetry"]
def load_all_dependency_groups() -> dict[str, dict[str, dict[str, Any]]]:
    """Collect the dependency table of every poetry group, keyed by group name.

    The implicit "main" group lives at the top level of [tool.poetry];
    named groups live under [tool.poetry.group.<name>].
    """
    poetry = load_api_poetry_configs()
    sections: dict[str, Any] = {"main": poetry}
    sections.update({name: poetry["group"][name] for name in poetry["group"]})
    return {name: section["dependencies"] for name, section in sections.items()}
def test_group_dependencies_sorted():
    """Every poetry dependency group must list its packages in sorted order."""
    for group_name, dependencies in load_all_dependency_groups().items():
        actual = list(dependencies)
        expected = sorted(set(actual))
        section = f"tool.poetry.group.{group_name}.dependencies" if group_name else "tool.poetry.dependencies"
        assert expected == actual, (
            f"Dependencies in group {group_name} are not sorted. "
            f"Check and fix [{section}] section in pyproject.toml file"
        )
def test_group_dependencies_version_operator():
    """Reject caret ('^') version constraints in every dependency group."""
    for group_name, dependencies in load_all_dependency_groups().items():
        for dependency_name, specification in dependencies.items():
            # A spec is either a bare version string or a table with a "version" key.
            if isinstance(specification, str):
                version_spec = specification
            else:
                version_spec = specification["version"]
            assert not version_spec.startswith("^"), (
                f"Please replace '{dependency_name} = {version_spec}' with '{dependency_name} = ~{version_spec[1:]}' "
                f"'^' operator is too wide and not allowed in the version specification."
            )
def test_duplicated_dependency_crossing_groups() -> None:
    """A package name must not appear in more than one dependency group."""
    seen: list[str] = []
    for dependencies in load_all_dependency_groups().values():
        seen += list(dependencies)
    # If any name repeats, the deduplicated view is shorter and the sort differs.
    assert sorted(set(seen)) == sorted(seen), "Duplicated dependencies crossing groups are found"

View File

@@ -0,0 +1,101 @@
# OpenAI API Key
OPENAI_API_KEY=
# Azure OpenAI API Base Endpoint & API Key
AZURE_OPENAI_API_BASE=
AZURE_OPENAI_API_KEY=
# Anthropic API Key
ANTHROPIC_API_KEY=
# Replicate API Key
REPLICATE_API_KEY=
# Hugging Face API Key
HUGGINGFACE_API_KEY=
HUGGINGFACE_TEXT_GEN_ENDPOINT_URL=
HUGGINGFACE_TEXT2TEXT_GEN_ENDPOINT_URL=
HUGGINGFACE_EMBEDDINGS_ENDPOINT_URL=
# Minimax Credentials
MINIMAX_API_KEY=
MINIMAX_GROUP_ID=
# Spark Credentials
SPARK_APP_ID=
SPARK_API_KEY=
SPARK_API_SECRET=
# Tongyi Credentials
TONGYI_DASHSCOPE_API_KEY=
# Wenxin Credentials
WENXIN_API_KEY=
WENXIN_SECRET_KEY=
# ZhipuAI Credentials
ZHIPUAI_API_KEY=
# Baichuan Credentials
BAICHUAN_API_KEY=
BAICHUAN_SECRET_KEY=
# ChatGLM Credentials
CHATGLM_API_BASE=
# Xinference Credentials
XINFERENCE_SERVER_URL=
XINFERENCE_GENERATION_MODEL_UID=
XINFERENCE_CHAT_MODEL_UID=
XINFERENCE_EMBEDDINGS_MODEL_UID=
XINFERENCE_RERANK_MODEL_UID=
# OpenLLM Credentials
OPENLLM_SERVER_URL=
# LocalAI Credentials
LOCALAI_SERVER_URL=
# Cohere Credentials
COHERE_API_KEY=
# Jina Credentials
JINA_API_KEY=
# Ollama Credentials
OLLAMA_BASE_URL=
# Together API Key
TOGETHER_API_KEY=
# Mock Switch
MOCK_SWITCH=false
# CODE EXECUTION CONFIGURATION
CODE_EXECUTION_ENDPOINT=
CODE_EXECUTION_API_KEY=
# Volcengine MaaS Credentials
VOLC_API_KEY=
VOLC_SECRET_KEY=
VOLC_MODEL_ENDPOINT_ID=
VOLC_EMBEDDING_ENDPOINT_ID=
# 360 AI Credentials
ZHINAO_API_KEY=
# VESSL AI Credentials
VESSL_AI_MODEL_NAME=
VESSL_AI_API_KEY=
VESSL_AI_ENDPOINT_URL=
# GPUStack Credentials
GPUSTACK_SERVER_URL=
GPUSTACK_API_KEY=
# Gitee AI Credentials
GITEE_AI_API_KEY=
# xAI Credentials
XAI_API_KEY=
XAI_API_BASE=

View File

@@ -0,0 +1 @@
.env.test

View File

@@ -0,0 +1,19 @@
import os
# Getting the absolute path of the current file's directory
ABS_PATH = os.path.dirname(os.path.abspath(__file__))
# Getting the absolute path of the project's root directory
PROJECT_DIR = os.path.abspath(os.path.join(ABS_PATH, os.pardir, os.pardir))
# Loading the .env file if it exists
def _load_env() -> None:
    """Load tests/integration_tests/.env into the environment when present."""
    env_file = os.path.join(PROJECT_DIR, "tests", "integration_tests", ".env")
    if not os.path.exists(env_file):
        return
    # Imported lazily so python-dotenv is only required when a .env file exists.
    from dotenv import load_dotenv

    load_dotenv(env_file)


_load_env()

View File

@@ -0,0 +1,25 @@
import pytest
from app_factory import create_app
from configs import dify_config
# Minimal stand-in for a logged-in account: a throwaway class whose attributes
# mimic the fields the console API reads from flask_login's current user.
# NOTE(review): "get_id" is a plain string here; flask_login normally expects a
# callable get_id() — confirm the patched code path never calls it.
mock_user = type(
    "MockUser",
    (object,),
    {
        "is_authenticated": True,
        "id": "123",
        "is_editor": True,
        "is_dataset_editor": True,
        "status": "active",
        "get_id": "123",
        "current_tenant_id": "9d2074fc-6f86-45a9-b09d-6ecc63b9056b",
    },
)


@pytest.fixture
def app():
    """Flask app with login checks disabled so route tests run without auth.

    NOTE(review): LOGIN_DISABLED is set *after* create_app(); presumably the
    config object is live and read per-request — confirm ordering is safe.
    """
    app = create_app()
    dify_config.LOGIN_DISABLED = True
    return app

View File

@@ -0,0 +1,9 @@
from unittest.mock import patch
from app_fixture import mock_user # type: ignore
def test_post_requires_login(app):
    """Hit the data-source integrates endpoint as a patched-in user and expect 200.

    NOTE(review): despite the name, this performs a GET, not a POST.
    """
    # Patching flask_login's internal user resolver makes every request appear
    # authenticated as mock_user.
    with app.test_client() as client, patch("flask_login.utils._get_user", mock_user):
        response = client.get("/console/api/data-source/integrates")
        assert response.status_code == 200

View File

@@ -0,0 +1,98 @@
import os
from collections.abc import Iterable
from typing import Any, Literal, Union
import anthropic
import pytest
from _pytest.monkeypatch import MonkeyPatch
from anthropic import Stream
from anthropic.resources import Messages
from anthropic.types import (
ContentBlock,
ContentBlockDeltaEvent,
Message,
MessageDeltaEvent,
MessageDeltaUsage,
MessageParam,
MessageStartEvent,
MessageStopEvent,
MessageStreamEvent,
TextDelta,
Usage,
)
from anthropic.types.message_delta_event import Delta
# Mocking is enabled via the MOCK_SWITCH env var. Compare case-insensitively,
# consistent with every other __mock module in this package (which all use
# .lower()); previously "True"/"TRUE" silently disabled the mock here.
MOCK = os.getenv("MOCK_SWITCH", "false").lower() == "true"
class MockAnthropicClass:
    """Drop-in replacements for anthropic Messages.create used when mocking."""

    @staticmethod
    def mocked_anthropic_chat_create_sync(model: str) -> Message:
        """Return a single canned assistant Message for non-streaming calls."""
        return Message(
            id="msg-123",
            type="message",
            role="assistant",
            content=[ContentBlock(text="hello, I'm a chatbot from anthropic", type="text")],
            model=model,
            stop_reason="stop_sequence",
            usage=Usage(input_tokens=1, output_tokens=1),
        )

    @staticmethod
    def mocked_anthropic_chat_create_stream(model: str) -> Stream[MessageStreamEvent]:
        """Yield a start event, one content delta per character, then delta/stop events."""
        full_response_text = "hello, I'm a chatbot from anthropic"
        yield MessageStartEvent(
            type="message_start",
            message=Message(
                id="msg-123",
                content=[],
                role="assistant",
                model=model,
                stop_reason=None,
                type="message",
                usage=Usage(input_tokens=1, output_tokens=1),
            ),
        )
        index = 0
        # One character per chunk, mirroring how the real API streams text deltas.
        for i in range(0, len(full_response_text)):
            yield ContentBlockDeltaEvent(
                type="content_block_delta", delta=TextDelta(text=full_response_text[i], type="text_delta"), index=index
            )
            index += 1
        yield MessageDeltaEvent(
            type="message_delta", delta=Delta(stop_reason="stop_sequence"), usage=MessageDeltaUsage(output_tokens=1)
        )
        yield MessageStopEvent(type="message_stop")

    def mocked_anthropic(
        self: Messages,
        *,
        max_tokens: int,
        messages: Iterable[MessageParam],
        model: str,
        stream: Literal[True],
        **kwargs: Any,
    ) -> Union[Message, Stream[MessageStreamEvent]]:
        """Replacement for Messages.create: validate the key, dispatch on stream."""
        # Short keys are treated as invalid (real keys are longer than 18 chars).
        # NOTE(review): AuthenticationError is constructed with a bare message;
        # newer anthropic SDKs require response/body arguments — confirm against
        # the pinned SDK version.
        if len(self._client.api_key) < 18:
            raise anthropic.AuthenticationError("Invalid API key")
        if stream:
            return MockAnthropicClass.mocked_anthropic_chat_create_stream(model=model)
        else:
            return MockAnthropicClass.mocked_anthropic_chat_create_sync(model=model)
@pytest.fixture
def setup_anthropic_mock(request, monkeypatch: MonkeyPatch):
    """Patch anthropic Messages.create with the canned mock while MOCK_SWITCH is on."""
    if MOCK:
        monkeypatch.setattr(Messages, "create", MockAnthropicClass.mocked_anthropic)
    yield
    # Restore the real client after the test.
    if MOCK:
        monkeypatch.undo()

View File

@@ -0,0 +1,82 @@
import os
from collections.abc import Callable
from typing import Literal
import httpx
import pytest
from _pytest.monkeypatch import MonkeyPatch
def mock_get(*args, **kwargs):
    """Fake httpx.get for the fishaudio model-list endpoint.

    Raises 401 HTTPStatusError unless the canonical test bearer token is sent.
    """
    auth_header = kwargs.get("headers", {}).get("Authorization")
    if auth_header != "Bearer test":
        raise httpx.HTTPStatusError(
            "Invalid API key",
            request=httpx.Request("GET", ""),
            response=httpx.Response(401),
        )
    payload = {
        "items": [
            {"title": "Model 1", "_id": "model1"},
            {"title": "Model 2", "_id": "model2"},
        ]
    }
    return httpx.Response(200, json=payload, request=httpx.Request("GET", ""))
def mock_stream(*args, **kwargs):
    """Fake httpx.stream: a context manager yielding one canned audio chunk."""

    class _FakeStreamResponse:
        status_code = 200

        def __enter__(self):
            return self

        def __exit__(self, exc_type, exc_val, exc_tb):
            return None

        def iter_bytes(self):
            yield b"Mocked audio data"

    return _FakeStreamResponse()
def mock_fishaudio(
    monkeypatch: MonkeyPatch,
    methods: list[Literal["list-models", "tts"]],
) -> Callable[[], None]:
    """
    mock fishaudio module

    :param monkeypatch: pytest monkeypatch fixture
    :param methods: which httpx entry points to patch
        ("list-models" -> httpx.get, "tts" -> httpx.stream)
    :return: unpatch function
    """

    def unpatch() -> None:
        monkeypatch.undo()

    if "list-models" in methods:
        monkeypatch.setattr(httpx, "get", mock_get)
    if "tts" in methods:
        monkeypatch.setattr(httpx, "stream", mock_stream)
    return unpatch
MOCK = os.getenv("MOCK_SWITCH", "false").lower() == "true"
@pytest.fixture
def setup_fishaudio_mock(request, monkeypatch):
    """Indirectly-parametrized fixture: request.param lists the methods to mock."""
    methods = request.param if hasattr(request, "param") else []
    if MOCK:
        unpatch = mock_fishaudio(monkeypatch, methods=methods)
    yield
    if MOCK:
        unpatch()

View File

@@ -0,0 +1,115 @@
from unittest.mock import MagicMock
import google.generativeai.types.generation_types as generation_config_types # type: ignore
import pytest
from _pytest.monkeypatch import MonkeyPatch
from google.ai import generativelanguage as glm
from google.ai.generativelanguage_v1beta.types import content as gag_content
from google.generativeai import GenerativeModel
from google.generativeai.types import GenerateContentResponse, content_types, safety_types
from google.generativeai.types.generation_types import BaseGenerateContentResponse
from extensions import ext_redis
class MockGoogleResponseClass:
    """Iterable standing in for a streamed Gemini response."""

    # Flipped to True once the final chunk has been produced.
    _done = False

    def __iter__(self):
        full_response_text = "it's google!"
        # One chunk per character position; the extra final iteration marks done.
        for i in range(0, len(full_response_text) + 1, 1):
            if i == len(full_response_text):
                self._done = True
                yield GenerateContentResponse(
                    done=True, iterator=None, result=glm.GenerateContentResponse({}), chunks=[]
                )
            else:
                yield GenerateContentResponse(
                    done=False, iterator=None, result=glm.GenerateContentResponse({}), chunks=[]
                )
class MockGoogleResponseCandidateClass:
    """Single canned response candidate with a fixed finish reason."""

    finish_reason = "stop"

    @property
    def content(self) -> gag_content.Content:
        # One text part carrying the canned reply.
        return gag_content.Content(parts=[gag_content.Part(text="it's google!")])
class MockGoogleClass:
    """Replacement entry points patched onto google.generativeai for tests."""

    @staticmethod
    def generate_content_sync() -> GenerateContentResponse:
        """Canned non-streaming response."""
        return GenerateContentResponse(done=True, iterator=None, result=glm.GenerateContentResponse({}), chunks=[])

    @staticmethod
    def generate_content_stream() -> MockGoogleResponseClass:
        """Canned streaming response (iterable of chunks)."""
        return MockGoogleResponseClass()

    def generate_content(
        self: GenerativeModel,
        contents: content_types.ContentsType,
        *,
        generation_config: generation_config_types.GenerationConfigType | None = None,
        safety_settings: safety_types.SafetySettingOptions | None = None,
        stream: bool = False,
        **kwargs,
    ) -> GenerateContentResponse:
        """Replacement for GenerativeModel.generate_content; dispatches on stream."""
        if stream:
            return MockGoogleClass.generate_content_stream()
        return MockGoogleClass.generate_content_sync()

    @property
    def generative_response_text(self) -> str:
        # Patched over BaseGenerateContentResponse.text.
        return "it's google!"

    @property
    def generative_response_candidates(self) -> list[MockGoogleResponseCandidateClass]:
        # Patched over BaseGenerateContentResponse.candidates.
        return [MockGoogleResponseCandidateClass()]
def mock_configure(api_key: str):
    """Stand-in for google.generativeai.configure; only checks key length."""
    if len(api_key) >= 16:
        return
    raise Exception("Invalid API key")
class MockFileState:
    """Processing state of a fake uploaded file; always reports FINISHED."""

    def __init__(self):
        self.name = "FINISHED"


class MockGoogleFile:
    """Fake google.generativeai file handle with a name and a finished state."""

    def __init__(self, name: str = "mock_file_name"):
        self.name = name
        self.state = MockFileState()


def mock_get_file(name: str) -> MockGoogleFile:
    """Stand-in for google.generativeai.get_file."""
    return MockGoogleFile(name)


def mock_upload_file(path: str, mime_type: str) -> MockGoogleFile:
    """Stand-in for google.generativeai.upload_file; ignores its arguments."""
    return MockGoogleFile()
@pytest.fixture
def setup_google_mock(request, monkeypatch: MonkeyPatch):
    """Patch the google.generativeai surface used by the runtime with canned mocks."""
    monkeypatch.setattr(BaseGenerateContentResponse, "text", MockGoogleClass.generative_response_text)
    monkeypatch.setattr(BaseGenerateContentResponse, "candidates", MockGoogleClass.generative_response_candidates)
    monkeypatch.setattr(GenerativeModel, "generate_content", MockGoogleClass.generate_content)
    monkeypatch.setattr("google.generativeai.configure", mock_configure)
    monkeypatch.setattr("google.generativeai.get_file", mock_get_file)
    monkeypatch.setattr("google.generativeai.upload_file", mock_upload_file)
    yield
    monkeypatch.undo()
@pytest.fixture
def setup_mock_redis() -> None:
    """Stub the shared redis client: cache misses, no-op writes, keys 'exist'."""
    # get -> None simulates a cache miss; setex is a no-op; exists is always true.
    ext_redis.redis_client.get = MagicMock(return_value=None)
    ext_redis.redis_client.setex = MagicMock(return_value=None)
    ext_redis.redis_client.exists = MagicMock(return_value=True)

View File

@@ -0,0 +1,20 @@
import os
import pytest
from _pytest.monkeypatch import MonkeyPatch
from huggingface_hub import InferenceClient # type: ignore
from tests.integration_tests.model_runtime.__mock.huggingface_chat import MockHuggingfaceChatClass
MOCK = os.getenv("MOCK_SWITCH", "false").lower() == "true"
@pytest.fixture
def setup_huggingface_mock(request, monkeypatch: MonkeyPatch):
    """Patch InferenceClient.text_generation with the mock while MOCK_SWITCH is on."""
    if MOCK:
        monkeypatch.setattr(InferenceClient, "text_generation", MockHuggingfaceChatClass.text_generation)
    yield
    if MOCK:
        monkeypatch.undo()

View File

@@ -0,0 +1,56 @@
import re
from collections.abc import Generator
from typing import Any, Literal, Optional, Union
from _pytest.monkeypatch import MonkeyPatch
from huggingface_hub import InferenceClient # type: ignore
from huggingface_hub.inference._text_generation import ( # type: ignore
Details,
StreamDetails,
TextGenerationResponse,
TextGenerationStreamResponse,
Token,
)
from huggingface_hub.utils import BadRequestError # type: ignore
class MockHuggingfaceChatClass:
    """Canned replacements for huggingface_hub InferenceClient.text_generation."""

    @staticmethod
    def generate_create_sync(model: str) -> TextGenerationResponse:
        """Single canned non-streaming generation with 6 dummy tokens."""
        response = TextGenerationResponse(
            generated_text="You can call me Miku Miku o~e~o~",
            details=Details(
                finish_reason="length",
                generated_tokens=6,
                tokens=[Token(id=0, text="You", logprob=0.0, special=False) for i in range(0, 6)],
            ),
        )
        return response

    @staticmethod
    def generate_create_stream(model: str) -> Generator[TextGenerationStreamResponse, None, None]:
        """Stream the canned reply one character per chunk."""
        full_text = "You can call me Miku Miku o~e~o~"
        for i in range(0, len(full_text)):
            response = TextGenerationStreamResponse(
                token=Token(id=i, text=full_text[i], logprob=0.0, special=False),
            )
            response.generated_text = full_text[i]
            response.details = StreamDetails(finish_reason="stop_sequence", generated_tokens=1)
            yield response

    def text_generation(
        self: InferenceClient, prompt: str, *, stream: Literal[False] = ..., model: Optional[str] = None, **kwargs: Any
    ) -> Union[TextGenerationResponse, Generator[TextGenerationStreamResponse, None, None]]:
        """Replacement for InferenceClient.text_generation with key/model checks."""
        # check if key is valid
        # NOTE(review): the pattern accepts "hf-..." tokens; real Hugging Face
        # tokens use an underscore ("hf_..."). Confirm which format the test
        # credentials use before changing this regex.
        if not re.match(r"Bearer\shf\-[a-zA-Z0-9]{16,}", self.headers["authorization"]):
            raise BadRequestError("Invalid API key")
        if model is None:
            raise BadRequestError("Invalid model")
        if stream:
            return MockHuggingfaceChatClass.generate_create_stream(model)
        return MockHuggingfaceChatClass.generate_create_sync(model)

View File

@@ -0,0 +1,94 @@
from core.model_runtime.model_providers.huggingface_tei.tei_helper import TeiModelExtraParameter
class MockTEIClass:
@staticmethod
def get_tei_extra_parameter(server_url: str, model_name: str) -> TeiModelExtraParameter:
# During mock, we don't have a real server to query, so we just return a dummy value
if "rerank" in model_name:
model_type = "reranker"
else:
model_type = "embedding"
return TeiModelExtraParameter(model_type=model_type, max_input_length=512, max_client_batch_size=1)
@staticmethod
def invoke_tokenize(server_url: str, texts: list[str]) -> list[list[dict]]:
# Use space as token separator, and split the text into tokens
tokenized_texts = []
for text in texts:
tokens = text.split(" ")
current_index = 0
tokenized_text = []
for idx, token in enumerate(tokens):
s_token = {
"id": idx,
"text": token,
"special": False,
"start": current_index,
"stop": current_index + len(token),
}
current_index += len(token) + 1
tokenized_text.append(s_token)
tokenized_texts.append(tokenized_text)
return tokenized_texts
@staticmethod
def invoke_embeddings(server_url: str, texts: list[str]) -> dict:
# {
# "object": "list",
# "data": [
# {
# "object": "embedding",
# "embedding": [...],
# "index": 0
# }
# ],
# "model": "MODEL_NAME",
# "usage": {
# "prompt_tokens": 3,
# "total_tokens": 3
# }
# }
embeddings = []
for idx in range(len(texts)):
embedding = [0.1] * 768
embeddings.append(
{
"object": "embedding",
"embedding": embedding,
"index": idx,
}
)
return {
"object": "list",
"data": embeddings,
"model": "MODEL_NAME",
"usage": {
"prompt_tokens": sum(len(text.split(" ")) for text in texts),
"total_tokens": sum(len(text.split(" ")) for text in texts),
},
}
@staticmethod
def invoke_rerank(server_url: str, query: str, texts: list[str]) -> list[dict]:
# Example response:
# [
# {
# "index": 0,
# "text": "Deep Learning is ...",
# "score": 0.9950755
# }
# ]
reranked_docs = []
for idx, text in enumerate(texts):
reranked_docs.append(
{
"index": idx,
"text": text,
"score": 0.9,
}
)
# For mock, only return the first document
break
return reranked_docs

View File

@@ -0,0 +1,59 @@
import os
from collections.abc import Callable
from typing import Any, Literal
import pytest
# import monkeypatch
from _pytest.monkeypatch import MonkeyPatch
from nomic import embed # type: ignore
def create_embedding(texts: list[str], model: str, **kwargs: Any) -> dict:
    """Return a nomic-style embedding payload: one constant 768-dim vector per text."""
    count = len(texts)
    filler = 0.123456
    return {
        "embeddings": [[filler] * 768 for _ in range(count)],
        "usage": {"prompt_tokens": count, "total_tokens": count},
        "model": model,
        "inference_mode": "remote",
    }
def mock_nomic(
    monkeypatch: MonkeyPatch,
    methods: list[Literal["text_embedding"]],
) -> Callable[[], None]:
    """
    mock nomic module

    :param monkeypatch: pytest monkeypatch fixture
    :param methods: which entry points to patch ("text_embedding" -> nomic.embed.text)
    :return: unpatch function
    """

    def unpatch() -> None:
        monkeypatch.undo()

    if "text_embedding" in methods:
        monkeypatch.setattr(embed, "text", create_embedding)
    return unpatch
MOCK = os.getenv("MOCK_SWITCH", "false").lower() == "true"
@pytest.fixture
def setup_nomic_mock(request, monkeypatch):
    """Indirectly-parametrized fixture: request.param lists the methods to mock."""
    methods = request.param if hasattr(request, "param") else []
    if MOCK:
        unpatch = mock_nomic(monkeypatch, methods=methods)
    yield
    if MOCK:
        unpatch()

View File

@@ -0,0 +1,71 @@
import os
from collections.abc import Callable
from typing import Literal
import pytest
# import monkeypatch
from _pytest.monkeypatch import MonkeyPatch
from openai.resources.audio.transcriptions import Transcriptions
from openai.resources.chat import Completions as ChatCompletions
from openai.resources.completions import Completions
from openai.resources.embeddings import Embeddings
from openai.resources.models import Models
from openai.resources.moderations import Moderations
from tests.integration_tests.model_runtime.__mock.openai_chat import MockChatClass
from tests.integration_tests.model_runtime.__mock.openai_completion import MockCompletionsClass
from tests.integration_tests.model_runtime.__mock.openai_embeddings import MockEmbeddingsClass
from tests.integration_tests.model_runtime.__mock.openai_moderation import MockModerationClass
from tests.integration_tests.model_runtime.__mock.openai_remote import MockModelClass
from tests.integration_tests.model_runtime.__mock.openai_speech2text import MockSpeech2TextClass
def mock_openai(
    monkeypatch: MonkeyPatch,
    methods: list[Literal["completion", "chat", "remote", "moderation", "speech2text", "text_embedding"]],
) -> Callable[[], None]:
    """
    mock openai module

    :param monkeypatch: pytest monkeypatch fixture
    :param methods: which OpenAI resource entry points to replace with mocks
    :return: unpatch function
    """

    def unpatch() -> None:
        monkeypatch.undo()

    # Patch only the requested resources; each maps to one canned mock class.
    if "completion" in methods:
        monkeypatch.setattr(Completions, "create", MockCompletionsClass.completion_create)
    if "chat" in methods:
        monkeypatch.setattr(ChatCompletions, "create", MockChatClass.chat_create)
    if "remote" in methods:
        monkeypatch.setattr(Models, "list", MockModelClass.list)
    if "moderation" in methods:
        monkeypatch.setattr(Moderations, "create", MockModerationClass.moderation_create)
    if "speech2text" in methods:
        monkeypatch.setattr(Transcriptions, "create", MockSpeech2TextClass.speech2text_create)
    if "text_embedding" in methods:
        monkeypatch.setattr(Embeddings, "create", MockEmbeddingsClass.create_embeddings)
    return unpatch
MOCK = os.getenv("MOCK_SWITCH", "false").lower() == "true"
@pytest.fixture
def setup_openai_mock(request, monkeypatch):
    """Indirectly-parametrized fixture: request.param lists the methods to mock."""
    methods = request.param if hasattr(request, "param") else []
    if MOCK:
        unpatch = mock_openai(monkeypatch, methods=methods)
    yield
    if MOCK:
        unpatch()

View File

@@ -0,0 +1,267 @@
import re
from collections.abc import Generator
from json import dumps
from time import time
# import monkeypatch
from typing import Any, Literal, Optional, Union
from openai import AzureOpenAI, OpenAI
from openai._types import NOT_GIVEN, NotGiven
from openai.resources.chat.completions import Completions
from openai.types import Completion as CompletionMessage
from openai.types.chat import (
ChatCompletionChunk,
ChatCompletionMessageParam,
ChatCompletionMessageToolCall,
ChatCompletionToolParam,
completion_create_params,
)
from openai.types.chat.chat_completion import ChatCompletion as _ChatCompletion
from openai.types.chat.chat_completion import Choice as _ChatCompletionChoice
from openai.types.chat.chat_completion_chunk import (
Choice,
ChoiceDelta,
ChoiceDeltaFunctionCall,
ChoiceDeltaToolCall,
ChoiceDeltaToolCallFunction,
)
from openai.types.chat.chat_completion_message import ChatCompletionMessage, FunctionCall
from openai.types.chat.chat_completion_message_tool_call import Function
from openai.types.completion_usage import CompletionUsage
from core.model_runtime.errors.invoke import InvokeAuthorizationError
class MockChatClass:
    """Mocked OpenAI/Azure chat completions used by the model-runtime tests."""

    @staticmethod
    def generate_function_call(
        functions: list[completion_create_params.Function] | NotGiven = NOT_GIVEN,
    ) -> Optional[FunctionCall]:
        """Build a deterministic FunctionCall for the first declared function.

        Returns None when no functions are declared or the first one does not
        use an object-typed parameter schema. Required parameters are filled
        with fixed dummy values per JSON-schema type.
        """
        if not functions or len(functions) == 0:
            return None
        function: completion_create_params.Function = functions[0]
        function_name = function["name"]
        function_parameters = function["parameters"]
        if function_parameters["type"] != "object":
            return None
        properties = function_parameters["properties"]
        required = function_parameters["required"]
        parameters = {}
        for parameter_name, parameter in properties.items():
            # Only required parameters get a value.
            if parameter_name not in required:
                continue
            parameter_type = parameter["type"]
            if parameter_type == "string":
                if "enum" in parameter:
                    if len(parameter["enum"]) == 0:
                        continue
                    parameters[parameter_name] = parameter["enum"][0]
                else:
                    parameters[parameter_name] = "kawaii"
            elif parameter_type == "integer":
                parameters[parameter_name] = 114514
            elif parameter_type == "number":
                parameters[parameter_name] = 1919810.0
            elif parameter_type == "boolean":
                parameters[parameter_name] = True
        return FunctionCall(name=function_name, arguments=dumps(parameters))

    @staticmethod
    def generate_tool_calls(tools=NOT_GIVEN) -> Optional[list[ChatCompletionMessageToolCall]]:
        """Build a single deterministic tool call from the first declared tool."""
        list_tool_calls = []
        if not tools or len(tools) == 0:
            return None
        tool = tools[0]
        # Bug fix: previously this checked `"type" in tools` / `tools["type"]`
        # against the *list* of tools, so non-function tools were never
        # filtered out (and indexing the list by string would raise).
        if "type" in tool and tool["type"] != "function":
            return None
        function = tool["function"]
        function_call = MockChatClass.generate_function_call(functions=[function])
        if function_call is None:
            return None
        list_tool_calls.append(
            ChatCompletionMessageToolCall(
                id="sakurajima-mai",
                function=Function(
                    name=function_call.name,
                    arguments=function_call.arguments,
                ),
                type="function",
            )
        )
        return list_tool_calls

    @staticmethod
    def mocked_openai_chat_create_sync(
        model: str,
        functions: list[completion_create_params.Function] | NotGiven = NOT_GIVEN,
        tools: list[ChatCompletionToolParam] | NotGiven = NOT_GIVEN,
    ) -> CompletionMessage:
        """Canned non-streaming chat completion, with an optional function/tool
        call derived from the declared functions/tools."""
        tool_calls = []
        # Legacy function_call takes precedence; fall back to tool calls.
        function_call = MockChatClass.generate_function_call(functions=functions)
        if not function_call:
            tool_calls = MockChatClass.generate_tool_calls(tools=tools)
        return _ChatCompletion(
            id="cmpl-3QJQa5jXJ5Z5X",
            choices=[
                _ChatCompletionChoice(
                    finish_reason="content_filter",
                    index=0,
                    message=ChatCompletionMessage(
                        content="elaina", role="assistant", function_call=function_call, tool_calls=tool_calls
                    ),
                )
            ],
            created=int(time()),
            model=model,
            object="chat.completion",
            system_fingerprint="",
            usage=CompletionUsage(
                prompt_tokens=2,
                completion_tokens=1,
                total_tokens=3,
            ),
        )

    @staticmethod
    def mocked_openai_chat_create_stream(
        model: str,
        functions: list[completion_create_params.Function] | NotGiven = NOT_GIVEN,
        tools: list[ChatCompletionToolParam] | NotGiven = NOT_GIVEN,
    ) -> Generator[ChatCompletionChunk, None, None]:
        """Stream the canned reply one character per chunk; the final chunk
        carries the function/tool call (if any) and the usage totals."""
        tool_calls = []
        function_call = MockChatClass.generate_function_call(functions=functions)
        if not function_call:
            tool_calls = MockChatClass.generate_tool_calls(tools=tools)
        full_text = "Hello, world!\n\n```python\nprint('Hello, world!')\n```"
        for i in range(0, len(full_text) + 1):
            if i == len(full_text):
                # Terminal chunk: empty content, call payloads, usage.
                yield ChatCompletionChunk(
                    id="cmpl-3QJQa5jXJ5Z5X",
                    choices=[
                        Choice(
                            delta=ChoiceDelta(
                                content="",
                                function_call=ChoiceDeltaFunctionCall(
                                    name=function_call.name,
                                    arguments=function_call.arguments,
                                )
                                if function_call
                                else None,
                                role="assistant",
                                tool_calls=[
                                    ChoiceDeltaToolCall(
                                        index=0,
                                        id="misaka-mikoto",
                                        function=ChoiceDeltaToolCallFunction(
                                            name=tool_calls[0].function.name,
                                            arguments=tool_calls[0].function.arguments,
                                        ),
                                        type="function",
                                    )
                                ]
                                if tool_calls and len(tool_calls) > 0
                                else None,
                            ),
                            finish_reason="function_call",
                            index=0,
                        )
                    ],
                    created=int(time()),
                    model=model,
                    object="chat.completion.chunk",
                    system_fingerprint="",
                    usage=CompletionUsage(
                        prompt_tokens=2,
                        completion_tokens=17,
                        total_tokens=19,
                    ),
                )
            else:
                yield ChatCompletionChunk(
                    id="cmpl-3QJQa5jXJ5Z5X",
                    choices=[
                        Choice(
                            delta=ChoiceDelta(
                                content=full_text[i],
                                role="assistant",
                            ),
                            finish_reason="content_filter",
                            index=0,
                        )
                    ],
                    created=int(time()),
                    model=model,
                    object="chat.completion.chunk",
                    system_fingerprint="",
                )

    def chat_create(
        self: Completions,
        *,
        messages: list[ChatCompletionMessageParam],
        model: Union[
            str,
            Literal[
                "gpt-4-1106-preview",
                "gpt-4-vision-preview",
                "gpt-4",
                "gpt-4-0314",
                "gpt-4-0613",
                "gpt-4-32k",
                "gpt-4-32k-0314",
                "gpt-4-32k-0613",
                "gpt-3.5-turbo-1106",
                "gpt-3.5-turbo",
                "gpt-3.5-turbo-16k",
                "gpt-3.5-turbo-0301",
                "gpt-3.5-turbo-0613",
                "gpt-3.5-turbo-16k-0613",
            ],
        ],
        functions: list[completion_create_params.Function] | NotGiven = NOT_GIVEN,
        response_format: completion_create_params.ResponseFormat | NotGiven = NOT_GIVEN,
        stream: Optional[Literal[False]] | NotGiven = NOT_GIVEN,
        tools: list[ChatCompletionToolParam] | NotGiven = NOT_GIVEN,
        **kwargs: Any,
    ):
        """Replacement for Completions.create: validate base url and (for known
        official models) the api key, then dispatch to sync or stream mock."""
        openai_models = [
            "gpt-4-1106-preview",
            "gpt-4-vision-preview",
            "gpt-4",
            "gpt-4-0314",
            "gpt-4-0613",
            "gpt-4-32k",
            "gpt-4-32k-0314",
            "gpt-4-32k-0613",
            "gpt-3.5-turbo-1106",
            "gpt-3.5-turbo",
            "gpt-3.5-turbo-16k",
            "gpt-3.5-turbo-0301",
            "gpt-3.5-turbo-0613",
            "gpt-3.5-turbo-16k-0613",
        ]
        azure_openai_models = ["gpt35", "gpt-4v", "gpt-35-turbo"]
        if not re.match(r"^(https?):\/\/[^\s\/$.?#].[^\s]*$", str(self._client.base_url)):
            raise InvokeAuthorizationError("Invalid base url")
        if model in openai_models + azure_openai_models:
            if not re.match(r"sk-[a-zA-Z0-9]{24,}$", self._client.api_key) and type(self._client) == OpenAI:
                # sometime, provider use OpenAI compatible API will not have api key or have different api key format
                # so we only check if model is in openai_models
                raise InvokeAuthorizationError("Invalid api key")
            if len(self._client.api_key) < 18 and type(self._client) == AzureOpenAI:
                raise InvokeAuthorizationError("Invalid api key")
        if stream:
            return MockChatClass.mocked_openai_chat_create_stream(model=model, functions=functions, tools=tools)
        return MockChatClass.mocked_openai_chat_create_sync(model=model, functions=functions, tools=tools)

View File

@@ -0,0 +1,130 @@
import re
from collections.abc import Generator
from time import time
# import monkeypatch
from typing import Any, Literal, Optional, Union
from openai import AzureOpenAI, BadRequestError, OpenAI
from openai._types import NOT_GIVEN, NotGiven
from openai.resources.completions import Completions
from openai.types import Completion as CompletionMessage
from openai.types.completion import CompletionChoice
from openai.types.completion_usage import CompletionUsage
from core.model_runtime.errors.invoke import InvokeAuthorizationError
class MockCompletionsClass:
    """Mocked OpenAI/Azure legacy text completions used by the runtime tests."""

    @staticmethod
    def mocked_openai_completion_create_sync(model: str) -> CompletionMessage:
        """Single canned non-streaming completion ("mock")."""
        return CompletionMessage(
            id="cmpl-3QJQa5jXJ5Z5X",
            object="text_completion",
            created=int(time()),
            model=model,
            system_fingerprint="",
            choices=[
                CompletionChoice(
                    text="mock",
                    index=0,
                    logprobs=None,
                    finish_reason="stop",
                )
            ],
            usage=CompletionUsage(
                prompt_tokens=2,
                completion_tokens=1,
                total_tokens=3,
            ),
        )

    @staticmethod
    def mocked_openai_completion_create_stream(model: str) -> Generator[CompletionMessage, None, None]:
        """Stream the canned reply one character per chunk; the final chunk has
        empty text, finish_reason "stop", and the usage totals."""
        full_text = "Hello, world!\n\n```python\nprint('Hello, world!')\n```"
        for i in range(0, len(full_text) + 1):
            if i == len(full_text):
                yield CompletionMessage(
                    id="cmpl-3QJQa5jXJ5Z5X",
                    object="text_completion",
                    created=int(time()),
                    model=model,
                    system_fingerprint="",
                    choices=[
                        CompletionChoice(
                            text="",
                            index=0,
                            logprobs=None,
                            finish_reason="stop",
                        )
                    ],
                    usage=CompletionUsage(
                        prompt_tokens=2,
                        completion_tokens=17,
                        total_tokens=19,
                    ),
                )
            else:
                yield CompletionMessage(
                    id="cmpl-3QJQa5jXJ5Z5X",
                    object="text_completion",
                    created=int(time()),
                    model=model,
                    system_fingerprint="",
                    choices=[
                        CompletionChoice(text=full_text[i], index=0, logprobs=None, finish_reason="content_filter")
                    ],
                )

    def completion_create(
        self: Completions,
        *,
        model: Union[
            str,
            Literal[
                "babbage-002",
                "davinci-002",
                "gpt-3.5-turbo-instruct",
                "text-davinci-003",
                "text-davinci-002",
                "text-davinci-001",
                "code-davinci-002",
                "text-curie-001",
                "text-babbage-001",
                "text-ada-001",
            ],
        ],
        prompt: Union[str, list[str], list[int], list[list[int]], None],
        stream: Optional[Literal[False]] | NotGiven = NOT_GIVEN,
        **kwargs: Any,
    ):
        """Replacement for Completions.create: validate base url, api key (for
        known official models) and prompt, then dispatch to sync or stream mock."""
        openai_models = [
            "babbage-002",
            "davinci-002",
            "gpt-3.5-turbo-instruct",
            "text-davinci-003",
            "text-davinci-002",
            "text-davinci-001",
            "code-davinci-002",
            "text-curie-001",
            "text-babbage-001",
            "text-ada-001",
        ]
        azure_openai_models = ["gpt-35-turbo-instruct"]
        if not re.match(r"^(https?):\/\/[^\s\/$.?#].[^\s]*$", str(self._client.base_url)):
            raise InvokeAuthorizationError("Invalid base url")
        if model in openai_models + azure_openai_models:
            if not re.match(r"sk-[a-zA-Z0-9]{24,}$", self._client.api_key) and type(self._client) == OpenAI:
                # sometime, provider use OpenAI compatible API will not have api key or have different api key format
                # so we only check if model is in openai_models
                raise InvokeAuthorizationError("Invalid api key")
            if len(self._client.api_key) < 18 and type(self._client) == AzureOpenAI:
                raise InvokeAuthorizationError("Invalid api key")
        # NOTE(review): BadRequestError here is constructed with a bare message;
        # the real SDK signature takes response/body — confirm against the
        # pinned openai version.
        if not prompt:
            raise BadRequestError("Invalid prompt")
        if stream:
            return MockCompletionsClass.mocked_openai_completion_create_stream(model=model)
        return MockCompletionsClass.mocked_openai_completion_create_sync(model=model)

File diff suppressed because one or more lines are too long

View File

@@ -0,0 +1,140 @@
import re
from typing import Any, Literal, Union
from openai._types import NOT_GIVEN, NotGiven
from openai.resources.moderations import Moderations
from openai.types import ModerationCreateResponse
from openai.types.moderation import Categories, CategoryScores, Moderation
from core.model_runtime.errors.invoke import InvokeAuthorizationError
class MockModerationClass:
    """Mock replacement for ``openai.resources.moderations.Moderations.create``."""

    @staticmethod
    def _build_moderation(flagged: bool) -> Moderation:
        """Build one canned Moderation entry.

        Category flags are always reported False (mirrors the original mock,
        even for flagged input); only the top-level ``flagged`` field and the
        scores (1.0 when flagged, else 0.0) differ.
        """
        category_keys = [
            "harassment",
            "harassment/threatening",
            "hate",
            "hate/threatening",
            "self-harm",
            "self-harm/instructions",
            "self-harm/intent",
            "sexual",
            "sexual/minors",
            "violence",
            "violence/graphic",
            "illicit",
            "illicit/violent",
        ]
        moderation_categories = {key: False for key in category_keys}
        score = 1.0 if flagged else 0.0
        moderation_categories_scores = {key: score for key in category_keys}
        category_applied_input_types = {
            "sexual": ["text", "image"],
            "hate": ["text"],
            "harassment": ["text"],
            "self-harm": ["text", "image"],
            "sexual/minors": ["text"],
            "hate/threatening": ["text"],
            "violence/graphic": ["text", "image"],
            "self-harm/intent": ["text", "image"],
            "self-harm/instructions": ["text", "image"],
            "harassment/threatening": ["text"],
            "violence": ["text", "image"],
            "illicit": ["text"],
            "illicit/violent": ["text"],
        }
        return Moderation(
            flagged=flagged,
            categories=Categories(**moderation_categories),
            category_scores=CategoryScores(**moderation_categories_scores),
            category_applied_input_types=category_applied_input_types,
        )

    def moderation_create(
        self: Moderations,
        *,
        input: Union[str, list[str]],
        model: Union[str, Literal["text-moderation-latest", "text-moderation-stable"]] | NotGiven = NOT_GIVEN,
        **kwargs: Any,
    ) -> ModerationCreateResponse:
        """Validate the client credentials and moderate each input text.

        A text containing "kill" is flagged; everything else is clean.

        :raises InvokeAuthorizationError: malformed base url or api key
        """
        if isinstance(input, str):
            input = [input]
        if not re.match(r"^(https?):\/\/[^\s\/$.?#].[^\s]*$", str(self._client.base_url)):
            raise InvokeAuthorizationError("Invalid base url")
        if len(self._client.api_key) < 18:
            raise InvokeAuthorizationError("Invalid API key")
        # BUG FIX: the original reset `result = []` on every loop iteration,
        # so the response only ever carried the LAST input's moderation entry.
        # The real API returns one result per input item.
        result = [MockModerationClass._build_moderation("kill" in text) for text in input]
        return ModerationCreateResponse(id="shiroii kuloko", model=model, results=result)

View File

@@ -0,0 +1,22 @@
from time import time
from openai.types.model import Model
class MockModelClass:
    """Mock for ``openai.resources.models.Models``."""

    def list(
        self,
        **kwargs,
    ) -> list[Model]:
        """Return a single canned fine-tuned model entry."""
        mocked_model = Model(
            id="ft:gpt-3.5-turbo-0613:personal::8GYJLPDQ",
            created=int(time()),
            object="model",
            owned_by="organization:org-123",
        )
        return [mocked_model]

View File

@@ -0,0 +1,29 @@
import re
from typing import Any, Literal, Union
from openai._types import NOT_GIVEN, FileTypes, NotGiven
from openai.resources.audio.transcriptions import Transcriptions
from openai.types.audio.transcription import Transcription
from core.model_runtime.errors.invoke import InvokeAuthorizationError
class MockSpeech2TextClass:
    """Mock replacement for OpenAI audio transcriptions."""

    def speech2text_create(
        self: Transcriptions,
        *,
        file: FileTypes,
        model: Union[str, Literal["whisper-1"]],
        language: str | NotGiven = NOT_GIVEN,
        prompt: str | NotGiven = NOT_GIVEN,
        response_format: Literal["json", "text", "srt", "verbose_json", "vtt"] | NotGiven = NOT_GIVEN,
        temperature: float | NotGiven = NOT_GIVEN,
        **kwargs: Any,
    ) -> Transcription:
        """Validate the client credentials and return a fixed transcription.

        :raises InvokeAuthorizationError: malformed base url or api key
        """
        base_url_ok = re.match(r"^(https?):\/\/[^\s\/$.?#].[^\s]*$", str(self._client.base_url))
        if base_url_ok is None:
            raise InvokeAuthorizationError("Invalid base url")
        if len(self._client.api_key) < 18:
            raise InvokeAuthorizationError("Invalid API key")
        return Transcription(text="1, 2, 3, 4, 5, 6, 7, 8, 9, 10")

View File

@@ -0,0 +1,169 @@
import os
import re
from typing import Union
import pytest
from _pytest.monkeypatch import MonkeyPatch
from requests import Response
from requests.sessions import Session
from xinference_client.client.restful.restful_client import ( # type: ignore
Client,
RESTfulChatModelHandle,
RESTfulEmbeddingModelHandle,
RESTfulGenerateModelHandle,
RESTfulRerankModelHandle,
)
from xinference_client.types import Embedding, EmbeddingData, EmbeddingUsage # type: ignore
class MockXinferenceClass:
    # Mock implementations that are monkeypatched over the xinference client
    # and requests.Session so the integration tests run without a live server.

    def get_chat_model(self: Client, model_uid: str) -> Union[RESTfulGenerateModelHandle, RESTfulChatModelHandle]:
        # Dispatch on the well-known mock uids ("generate", "chat",
        # "embedding", "rerank"); anything else behaves like a missing model.
        if not re.match(r"https?:\/\/[^\s\/$.?#].[^\s]*$", self.base_url):
            raise RuntimeError("404 Not Found")
        if model_uid == "generate":
            return RESTfulGenerateModelHandle(model_uid, base_url=self.base_url, auth_headers={})
        if model_uid == "chat":
            return RESTfulChatModelHandle(model_uid, base_url=self.base_url, auth_headers={})
        if model_uid == "embedding":
            return RESTfulEmbeddingModelHandle(model_uid, base_url=self.base_url, auth_headers={})
        if model_uid == "rerank":
            return RESTfulRerankModelHandle(model_uid, base_url=self.base_url, auth_headers={})
        raise RuntimeError("404 Not Found")

    def get(self: Session, url: str, **kwargs):
        # Mock of requests.Session.get covering the two endpoints the tests
        # hit: /v1/models/<uid> (model metadata) and /v1/cluster/auth.
        response = Response()
        if "v1/models/" in url:
            # get model uid
            model_uid = url.split("/")[-1] or ""
            # A uid must be either a uuid4-looking string or one of the
            # well-known mock names; otherwise reply 404 like the server.
            if not re.match(
                r"[a-f0-9]{8}-[a-f0-9]{4}-[a-f0-9]{4}-[a-f0-9]{4}-[a-f0-9]{12}", model_uid
            ) and model_uid not in {"generate", "chat", "embedding", "rerank"}:
                response.status_code = 404
                response._content = b"{}"
                return response
            # check if url is valid
            if not re.match(r"^(https?):\/\/[^\s\/$.?#].[^\s]*$", url):
                response.status_code = 404
                response._content = b"{}"
                return response
            if model_uid in {"generate", "chat"}:
                response.status_code = 200
                # Canned LLM metadata payload.
                response._content = b"""{
    "model_type": "LLM",
    "address": "127.0.0.1:43877",
    "accelerators": [
        "0",
        "1"
    ],
    "model_name": "chatglm3-6b",
    "model_lang": [
        "en"
    ],
    "model_ability": [
        "generate",
        "chat"
    ],
    "model_description": "latest chatglm3",
    "model_format": "pytorch",
    "model_size_in_billions": 7,
    "quantization": "none",
    "model_hub": "huggingface",
    "revision": null,
    "context_length": 2048,
    "replica": 1
}"""
                return response
            elif model_uid == "embedding":
                response.status_code = 200
                # Canned embedding-model metadata payload.
                response._content = b"""{
    "model_type": "embedding",
    "address": "127.0.0.1:43877",
    "accelerators": [
        "0",
        "1"
    ],
    "model_name": "bge",
    "model_lang": [
        "en"
    ],
    "revision": null,
    "max_tokens": 512
}"""
                return response
        elif "v1/cluster/auth" in url:
            response.status_code = 200
            response._content = b"""{
    "auth": true
}"""
            return response
        # NOTE(review): a "rerank" uid (and any other unmatched URL) falls
        # through here and implicitly returns None -- presumably unreachable
        # in these tests; confirm before reusing this mock elsewhere.

    def _check_cluster_authenticated(self):
        # Pretend the cluster auth handshake already succeeded.
        self._cluster_authed = True

    def rerank(
        self: RESTfulRerankModelHandle, documents: list[str], query: str, top_n: int, return_documents: bool
    ) -> dict:
        # check if self._model_uid is a valid uuid
        if (
            not re.match(r"[a-f0-9]{8}-[a-f0-9]{4}-[a-f0-9]{4}-[a-f0-9]{4}-[a-f0-9]{12}", self._model_uid)
            and self._model_uid != "rerank"
        ):
            raise RuntimeError("404 Not Found")
        if not re.match(r"^(https?):\/\/[^\s\/$.?#].[^\s]*$", self._base_url):
            raise RuntimeError("404 Not Found")
        if top_n is None:
            top_n = 1
        # Fixed relevance score; the first top_n documents are returned in
        # their input order (no actual ranking happens).
        return {
            "results": [
                {"index": i, "document": doc, "relevance_score": 0.9} for i, doc in enumerate(documents[:top_n])
            ]
        }

    def create_embedding(self: RESTfulGenerateModelHandle, input: Union[str, list[str]], **kwargs) -> dict:
        # check if self._model_uid is a valid uuid
        if (
            not re.match(r"[a-f0-9]{8}-[a-f0-9]{4}-[a-f0-9]{4}-[a-f0-9]{4}-[a-f0-9]{12}", self._model_uid)
            and self._model_uid != "embedding"
        ):
            raise RuntimeError("404 Not Found")
        if isinstance(input, str):
            input = [input]
        ipt_len = len(input)
        # One constant-valued 768-dim vector per input; token usage equals the
        # number of inputs. (Returns an Embedding object despite the `-> dict`
        # annotation -- kept as-is to preserve the original interface.)
        embedding = Embedding(
            object="list",
            model=self._model_uid,
            data=[
                EmbeddingData(index=i, object="embedding", embedding=[1919.810 for _ in range(768)])
                for i in range(ipt_len)
            ],
            usage=EmbeddingUsage(prompt_tokens=ipt_len, total_tokens=ipt_len),
        )
        return embedding
MOCK = os.getenv("MOCK_SWITCH", "false").lower() == "true"


@pytest.fixture
def setup_xinference_mock(request, monkeypatch: MonkeyPatch):
    """Patch the xinference client with mocks while MOCK_SWITCH is enabled."""
    if MOCK:
        replacements = [
            (Client, "get_model", MockXinferenceClass.get_chat_model),
            (Client, "_check_cluster_authenticated", MockXinferenceClass._check_cluster_authenticated),
            (Session, "get", MockXinferenceClass.get),
            (RESTfulEmbeddingModelHandle, "create_embedding", MockXinferenceClass.create_embedding),
            (RESTfulRerankModelHandle, "rerank", MockXinferenceClass.rerank),
        ]
        for target, attribute, replacement in replacements:
            monkeypatch.setattr(target, attribute, replacement)
    yield
    if MOCK:
        monkeypatch.undo()

View File

@@ -0,0 +1,92 @@
import os
from collections.abc import Generator
import pytest
from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta
from core.model_runtime.entities.message_entities import AssistantPromptMessage, SystemPromptMessage, UserPromptMessage
from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.model_providers.anthropic.llm.llm import AnthropicLargeLanguageModel
from tests.integration_tests.model_runtime.__mock.anthropic import setup_anthropic_mock
@pytest.mark.parametrize("setup_anthropic_mock", [["none"]], indirect=True)
def test_validate_credentials(setup_anthropic_mock):
    """A bogus key must fail validation; the real key must pass."""
    model = AnthropicLargeLanguageModel()
    with pytest.raises(CredentialsValidateFailedError):
        model.validate_credentials(model="claude-instant-1.2", credentials={"anthropic_api_key": "invalid_key"})
    valid_credentials = {"anthropic_api_key": os.environ.get("ANTHROPIC_API_KEY")}
    model.validate_credentials(model="claude-instant-1.2", credentials=valid_credentials)
@pytest.mark.parametrize("setup_anthropic_mock", [["none"]], indirect=True)
def test_invoke_model(setup_anthropic_mock):
    """A blocking invoke should yield a non-empty LLMResult."""
    model = AnthropicLargeLanguageModel()
    credentials = {
        "anthropic_api_key": os.environ.get("ANTHROPIC_API_KEY"),
        "anthropic_api_url": os.environ.get("ANTHROPIC_API_URL"),
    }
    prompt_messages = [
        SystemPromptMessage(content="You are a helpful AI assistant."),
        UserPromptMessage(content="Hello World!"),
    ]
    response = model.invoke(
        model="claude-instant-1.2",
        credentials=credentials,
        prompt_messages=prompt_messages,
        model_parameters={"temperature": 0.0, "top_p": 1.0, "max_tokens": 10},
        stop=["How"],
        stream=False,
        user="abc-123",
    )
    assert isinstance(response, LLMResult)
    assert len(response.message.content) > 0
@pytest.mark.parametrize("setup_anthropic_mock", [["none"]], indirect=True)
def test_invoke_stream_model(setup_anthropic_mock):
    """A streaming invoke should yield well-formed, non-empty chunks."""
    model = AnthropicLargeLanguageModel()
    prompt_messages = [
        SystemPromptMessage(content="You are a helpful AI assistant."),
        UserPromptMessage(content="Hello World!"),
    ]
    response = model.invoke(
        model="claude-instant-1.2",
        credentials={"anthropic_api_key": os.environ.get("ANTHROPIC_API_KEY")},
        prompt_messages=prompt_messages,
        model_parameters={"temperature": 0.0, "max_tokens": 100},
        stream=True,
        user="abc-123",
    )
    assert isinstance(response, Generator)
    for chunk in response:
        assert isinstance(chunk, LLMResultChunk)
        delta = chunk.delta
        assert isinstance(delta, LLMResultChunkDelta)
        assert isinstance(delta.message, AssistantPromptMessage)
        if delta.finish_reason is None:
            assert len(delta.message.content) > 0
def test_get_num_tokens():
    """Token counting for a fixed two-message prompt is deterministic."""
    model = AnthropicLargeLanguageModel()
    prompt_messages = [
        SystemPromptMessage(content="You are a helpful AI assistant."),
        UserPromptMessage(content="Hello World!"),
    ]
    num_tokens = model.get_num_tokens(
        model="claude-instant-1.2",
        credentials={"anthropic_api_key": os.environ.get("ANTHROPIC_API_KEY")},
        prompt_messages=prompt_messages,
    )
    assert num_tokens == 18

View File

@@ -0,0 +1,17 @@
import os
import pytest
from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.model_providers.anthropic.anthropic import AnthropicProvider
from tests.integration_tests.model_runtime.__mock.anthropic import setup_anthropic_mock
@pytest.mark.parametrize("setup_anthropic_mock", [["none"]], indirect=True)
def test_validate_provider_credentials(setup_anthropic_mock):
    """Empty credentials must fail; the real key must pass."""
    provider = AnthropicProvider()
    with pytest.raises(CredentialsValidateFailedError):
        provider.validate_provider_credentials(credentials={})
    valid_credentials = {"anthropic_api_key": os.environ.get("ANTHROPIC_API_KEY")}
    provider.validate_provider_credentials(credentials=valid_credentials)

View File

@@ -0,0 +1,109 @@
import os
from collections.abc import Generator
import pytest
from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta
from core.model_runtime.entities.message_entities import (
AssistantPromptMessage,
SystemPromptMessage,
UserPromptMessage,
)
from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.model_providers.azure_ai_studio.llm.llm import AzureAIStudioLargeLanguageModel
@pytest.mark.parametrize("setup_azure_ai_studio_mock", [["chat"]], indirect=True)
def test_validate_credentials(setup_azure_ai_studio_mock):
    """An invalid api key must fail validation; real credentials must pass."""
    model = AzureAIStudioLargeLanguageModel()
    api_base = os.getenv("AZURE_AI_STUDIO_API_BASE")
    with pytest.raises(CredentialsValidateFailedError):
        model.validate_credentials(
            model="gpt-35-turbo",
            credentials={"api_key": "invalid_key", "api_base": api_base},
        )
    model.validate_credentials(
        model="gpt-35-turbo",
        credentials={"api_key": os.getenv("AZURE_AI_STUDIO_API_KEY"), "api_base": api_base},
    )
@pytest.mark.parametrize("setup_azure_ai_studio_mock", [["chat"]], indirect=True)
def test_invoke_model(setup_azure_ai_studio_mock):
    """A blocking invoke should yield a non-empty LLMResult."""
    model = AzureAIStudioLargeLanguageModel()
    credentials = {
        "api_key": os.getenv("AZURE_AI_STUDIO_API_KEY"),
        "api_base": os.getenv("AZURE_AI_STUDIO_API_BASE"),
    }
    prompt_messages = [
        SystemPromptMessage(content="You are a helpful AI assistant."),
        UserPromptMessage(content="Hello World!"),
    ]
    result = model.invoke(
        model="gpt-35-turbo",
        credentials=credentials,
        prompt_messages=prompt_messages,
        model_parameters={"temperature": 0.0, "max_tokens": 100},
        stream=False,
        user="abc-123",
    )
    assert isinstance(result, LLMResult)
    assert len(result.message.content) > 0
@pytest.mark.parametrize("setup_azure_ai_studio_mock", [["chat"]], indirect=True)
def test_invoke_stream_model(setup_azure_ai_studio_mock):
    """A streaming invoke should yield chunks with usage on the final one."""
    model = AzureAIStudioLargeLanguageModel()
    credentials = {
        "api_key": os.getenv("AZURE_AI_STUDIO_API_KEY"),
        "api_base": os.getenv("AZURE_AI_STUDIO_API_BASE"),
    }
    prompt_messages = [
        SystemPromptMessage(content="You are a helpful AI assistant."),
        UserPromptMessage(content="Hello World!"),
    ]
    result = model.invoke(
        model="gpt-35-turbo",
        credentials=credentials,
        prompt_messages=prompt_messages,
        model_parameters={"temperature": 0.0, "max_tokens": 100},
        stream=True,
        user="abc-123",
    )
    assert isinstance(result, Generator)
    for chunk in result:
        assert isinstance(chunk, LLMResultChunk)
        delta = chunk.delta
        assert isinstance(delta, LLMResultChunkDelta)
        assert isinstance(delta.message, AssistantPromptMessage)
        if delta.finish_reason is not None:
            assert delta.usage is not None
            assert delta.usage.completion_tokens > 0
def test_get_num_tokens():
    """Token counting for a fixed two-message prompt is deterministic."""
    model = AzureAIStudioLargeLanguageModel()
    credentials = {
        "api_key": os.getenv("AZURE_AI_STUDIO_API_KEY"),
        "api_base": os.getenv("AZURE_AI_STUDIO_API_BASE"),
    }
    prompt_messages = [
        SystemPromptMessage(content="You are a helpful AI assistant."),
        UserPromptMessage(content="Hello World!"),
    ]
    num_tokens = model.get_num_tokens(
        model="gpt-35-turbo",
        credentials=credentials,
        prompt_messages=prompt_messages,
    )
    assert num_tokens == 21

View File

@@ -0,0 +1,17 @@
import os
import pytest
from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.model_providers.azure_ai_studio.azure_ai_studio import AzureAIStudioProvider
def test_validate_provider_credentials():
    """Empty credentials must fail; real credentials must pass."""
    provider = AzureAIStudioProvider()
    with pytest.raises(CredentialsValidateFailedError):
        provider.validate_provider_credentials(credentials={})
    valid_credentials = {
        "api_key": os.getenv("AZURE_AI_STUDIO_API_KEY"),
        "api_base": os.getenv("AZURE_AI_STUDIO_API_BASE"),
    }
    provider.validate_provider_credentials(credentials=valid_credentials)

View File

@@ -0,0 +1,42 @@
import os
import pytest
from core.model_runtime.entities.rerank_entities import RerankResult
from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.model_providers.azure_ai_studio.rerank.rerank import AzureRerankModel
def test_validate_credentials():
    """An invalid api key must fail rerank credential validation."""
    model = AzureRerankModel()
    bad_credentials = {"api_key": "invalid_key", "api_base": os.getenv("AZURE_AI_STUDIO_API_BASE")}
    with pytest.raises(CredentialsValidateFailedError):
        model.validate_credentials(model="azure-ai-studio-rerank-v1", credentials=bad_credentials)
def test_invoke_model():
    """Reranking with a 0.8 threshold keeps only the relevant document."""
    model = AzureRerankModel()
    credentials = {
        "api_key": os.getenv("AZURE_AI_STUDIO_JWT_TOKEN"),
        "api_base": os.getenv("AZURE_AI_STUDIO_API_BASE"),
    }
    docs = [
        "Carson City is the capital city of the American state of Nevada. At the 2010 United States "
        "Census, Carson City had a population of 55,274.",
        "The Commonwealth of the Northern Mariana Islands is a group of islands in the Pacific Ocean that "
        "are a political division controlled by the United States. Its capital is Saipan.",
    ]
    result = model.invoke(
        model="azure-ai-studio-rerank-v1",
        credentials=credentials,
        query="What is the capital of the United States?",
        docs=docs,
        score_threshold=0.8,
    )
    assert isinstance(result, RerankResult)
    assert len(result.docs) == 1
    assert result.docs[0].index == 1
    assert result.docs[0].score >= 0.8

File diff suppressed because one or more lines are too long

View File

@@ -0,0 +1,62 @@
import os
import pytest
from core.model_runtime.entities.text_embedding_entities import TextEmbeddingResult
from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.model_providers.azure_openai.text_embedding.text_embedding import AzureOpenAITextEmbeddingModel
from tests.integration_tests.model_runtime.__mock.openai import setup_openai_mock
@pytest.mark.parametrize("setup_openai_mock", [["text_embedding"]], indirect=True)
def test_validate_credentials(setup_openai_mock):
    """An invalid api key must fail validation; real credentials must pass."""
    model = AzureOpenAITextEmbeddingModel()
    api_base = os.environ.get("AZURE_OPENAI_API_BASE")
    with pytest.raises(CredentialsValidateFailedError):
        model.validate_credentials(
            model="embedding",
            credentials={
                "openai_api_base": api_base,
                "openai_api_key": "invalid_key",
                "base_model_name": "text-embedding-ada-002",
            },
        )
    model.validate_credentials(
        model="embedding",
        credentials={
            "openai_api_base": api_base,
            "openai_api_key": os.environ.get("AZURE_OPENAI_API_KEY"),
            "base_model_name": "text-embedding-ada-002",
        },
    )
@pytest.mark.parametrize("setup_openai_mock", [["text_embedding"]], indirect=True)
def test_invoke_model(setup_openai_mock):
    """Embedding two texts returns two vectors and two tokens of usage."""
    model = AzureOpenAITextEmbeddingModel()
    credentials = {
        "openai_api_base": os.environ.get("AZURE_OPENAI_API_BASE"),
        "openai_api_key": os.environ.get("AZURE_OPENAI_API_KEY"),
        "base_model_name": "text-embedding-ada-002",
    }
    result = model.invoke(
        model="embedding",
        credentials=credentials,
        texts=["hello", "world"],
        user="abc-123",
    )
    assert isinstance(result, TextEmbeddingResult)
    assert len(result.embeddings) == 2
    assert result.usage.total_tokens == 2
def test_get_num_tokens():
    """Two single-word texts count as two tokens."""
    model = AzureOpenAITextEmbeddingModel()
    credentials = {"base_model_name": "text-embedding-ada-002"}
    num_tokens = model.get_num_tokens(model="embedding", credentials=credentials, texts=["hello", "world"])
    assert num_tokens == 2

View File

@@ -0,0 +1,172 @@
import os
from collections.abc import Generator
from time import sleep
import pytest
from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta
from core.model_runtime.entities.message_entities import AssistantPromptMessage, SystemPromptMessage, UserPromptMessage
from core.model_runtime.entities.model_entities import AIModelEntity
from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.model_providers.baichuan.llm.llm import BaichuanLanguageModel
def test_predefined_models():
    """The provider must ship at least one predefined model schema."""
    schemas = BaichuanLanguageModel().predefined_models()
    assert len(schemas) >= 1
    assert isinstance(schemas[0], AIModelEntity)
def test_validate_credentials_for_chat_model():
    """Invalid keys must fail validation; real keys must pass."""
    sleep(3)  # throttle to avoid rate limits between live calls
    model = BaichuanLanguageModel()
    with pytest.raises(CredentialsValidateFailedError):
        model.validate_credentials(
            model="baichuan2-turbo", credentials={"api_key": "invalid_key", "secret_key": "invalid_key"}
        )
    valid_credentials = {
        "api_key": os.environ.get("BAICHUAN_API_KEY"),
        "secret_key": os.environ.get("BAICHUAN_SECRET_KEY"),
    }
    model.validate_credentials(model="baichuan2-turbo", credentials=valid_credentials)
def test_invoke_model():
    """A blocking invoke should yield a non-empty result with usage."""
    sleep(3)  # throttle to avoid rate limits between live calls
    model = BaichuanLanguageModel()
    credentials = {
        "api_key": os.environ.get("BAICHUAN_API_KEY"),
        "secret_key": os.environ.get("BAICHUAN_SECRET_KEY"),
    }
    response = model.invoke(
        model="baichuan2-turbo",
        credentials=credentials,
        prompt_messages=[UserPromptMessage(content="Hello World!")],
        model_parameters={"temperature": 0.7, "top_p": 1.0, "top_k": 1},
        stop=["you"],
        user="abc-123",
        stream=False,
    )
    assert isinstance(response, LLMResult)
    assert len(response.message.content) > 0
    assert response.usage.total_tokens > 0
def test_invoke_model_with_system_message():
    """A system message plus user message should produce a non-empty reply."""
    sleep(3)  # throttle to avoid rate limits between live calls
    model = BaichuanLanguageModel()
    credentials = {
        "api_key": os.environ.get("BAICHUAN_API_KEY"),
        "secret_key": os.environ.get("BAICHUAN_SECRET_KEY"),
    }
    prompt_messages = [
        SystemPromptMessage(content="请记住你是Kasumi。"),
        UserPromptMessage(content="现在告诉我你是谁?"),
    ]
    response = model.invoke(
        model="baichuan2-turbo",
        credentials=credentials,
        prompt_messages=prompt_messages,
        model_parameters={"temperature": 0.7, "top_p": 1.0, "top_k": 1},
        stop=["you"],
        user="abc-123",
        stream=False,
    )
    assert isinstance(response, LLMResult)
    assert len(response.message.content) > 0
    assert response.usage.total_tokens > 0
def test_invoke_stream_model():
    """A streaming invoke should yield well-formed, non-empty chunks."""
    sleep(3)  # throttle to avoid rate limits between live calls
    model = BaichuanLanguageModel()
    credentials = {
        "api_key": os.environ.get("BAICHUAN_API_KEY"),
        "secret_key": os.environ.get("BAICHUAN_SECRET_KEY"),
    }
    response = model.invoke(
        model="baichuan2-turbo",
        credentials=credentials,
        prompt_messages=[UserPromptMessage(content="Hello World!")],
        model_parameters={"temperature": 0.7, "top_p": 1.0, "top_k": 1},
        stop=["you"],
        stream=True,
        user="abc-123",
    )
    assert isinstance(response, Generator)
    for chunk in response:
        assert isinstance(chunk, LLMResultChunk)
        delta = chunk.delta
        assert isinstance(delta, LLMResultChunkDelta)
        assert isinstance(delta.message, AssistantPromptMessage)
        if delta.finish_reason is None:
            assert len(delta.message.content) > 0
def test_invoke_with_search():
    """Search-enhanced streaming should produce a complete answer."""
    sleep(3)  # throttle to avoid rate limits between live calls
    model = BaichuanLanguageModel()
    credentials = {
        "api_key": os.environ.get("BAICHUAN_API_KEY"),
        "secret_key": os.environ.get("BAICHUAN_SECRET_KEY"),
    }
    response = model.invoke(
        model="baichuan2-turbo",
        credentials=credentials,
        prompt_messages=[UserPromptMessage(content="北京今天的天气怎么样")],
        model_parameters={
            "temperature": 0.7,
            "top_p": 1.0,
            "top_k": 1,
            "with_search_enhance": True,
        },
        stop=["you"],
        stream=True,
        user="abc-123",
    )
    assert isinstance(response, Generator)
    total_message = ""
    for chunk in response:
        assert isinstance(chunk, LLMResultChunk)
        delta = chunk.delta
        assert isinstance(delta, LLMResultChunkDelta)
        assert isinstance(delta.message, AssistantPromptMessage)
        if not delta.finish_reason:
            assert len(delta.message.content) > 0
        total_message += delta.message.content
    # NOTE(review): `"" not in s` is False for every string, so this assertion
    # can never pass as written; a non-ASCII character was likely lost when
    # this file was extracted -- confirm the intended substring.
    assert "" not in total_message
def test_get_num_tokens():
    """Token counting for a fixed single-message prompt is deterministic."""
    sleep(3)  # throttle to avoid rate limits between live calls
    model = BaichuanLanguageModel()
    credentials = {
        "api_key": os.environ.get("BAICHUAN_API_KEY"),
        "secret_key": os.environ.get("BAICHUAN_SECRET_KEY"),
    }
    response = model.get_num_tokens(
        model="baichuan2-turbo",
        credentials=credentials,
        prompt_messages=[UserPromptMessage(content="Hello World!")],
        tools=[],
    )
    assert isinstance(response, int)
    assert response == 9

View File

@@ -0,0 +1,15 @@
import os
import pytest
from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.model_providers.baichuan.baichuan import BaichuanProvider
def test_validate_provider_credentials():
    """A garbage key must fail; the real key must pass."""
    provider = BaichuanProvider()
    with pytest.raises(CredentialsValidateFailedError):
        provider.validate_provider_credentials(credentials={"api_key": "hahahaha"})
    valid_credentials = {"api_key": os.environ.get("BAICHUAN_API_KEY")}
    provider.validate_provider_credentials(credentials=valid_credentials)

View File

@@ -0,0 +1,87 @@
import os
import pytest
from core.model_runtime.entities.text_embedding_entities import TextEmbeddingResult
from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.model_providers.baichuan.text_embedding.text_embedding import BaichuanTextEmbeddingModel
def test_validate_credentials():
    """An invalid api key must fail validation; the real key must pass."""
    model = BaichuanTextEmbeddingModel()
    with pytest.raises(CredentialsValidateFailedError):
        model.validate_credentials(model="baichuan-text-embedding", credentials={"api_key": "invalid_key"})
    valid_credentials = {"api_key": os.environ.get("BAICHUAN_API_KEY")}
    model.validate_credentials(model="baichuan-text-embedding", credentials=valid_credentials)
def test_invoke_model():
    """Embedding two texts returns two vectors and six tokens of usage."""
    model = BaichuanTextEmbeddingModel()
    credentials = {"api_key": os.environ.get("BAICHUAN_API_KEY")}
    result = model.invoke(
        model="baichuan-text-embedding",
        credentials=credentials,
        texts=["hello", "world"],
        user="abc-123",
    )
    assert isinstance(result, TextEmbeddingResult)
    assert len(result.embeddings) == 2
    assert result.usage.total_tokens == 6
def test_get_num_tokens():
    """Two single-word texts count as two tokens."""
    model = BaichuanTextEmbeddingModel()
    credentials = {"api_key": os.environ.get("BAICHUAN_API_KEY")}
    num_tokens = model.get_num_tokens(
        model="baichuan-text-embedding",
        credentials=credentials,
        texts=["hello", "world"],
    )
    assert num_tokens == 2
def test_max_chunks():
    """Invoking with more texts than one chunk holds still embeds them all.

    22 texts exceed the per-request chunk size, so the model must split the
    batch internally and return one embedding per input.
    """
    model = BaichuanTextEmbeddingModel()
    # Same 22 texts as before, expressed as a repetition instead of 22
    # hand-written literal lines.
    texts = ["hello", "world"] * 11
    result = model.invoke(
        model="baichuan-text-embedding",
        credentials={
            "api_key": os.environ.get("BAICHUAN_API_KEY"),
        },
        texts=texts,
    )
    assert isinstance(result, TextEmbeddingResult)
    assert len(result.embeddings) == 22

View File

@@ -0,0 +1,103 @@
import os
from collections.abc import Generator
import pytest
from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta
from core.model_runtime.entities.message_entities import AssistantPromptMessage, SystemPromptMessage, UserPromptMessage
from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.model_providers.bedrock.llm.llm import BedrockLargeLanguageModel
def test_validate_credentials():
    """Bad credentials must fail validation; real AWS credentials must pass."""
    model = BedrockLargeLanguageModel()
    with pytest.raises(CredentialsValidateFailedError):
        model.validate_credentials(model="meta.llama2-13b-chat-v1", credentials={"anthropic_api_key": "invalid_key"})
    aws_credentials = {
        "aws_region": os.getenv("AWS_REGION"),
        "aws_access_key": os.getenv("AWS_ACCESS_KEY"),
        "aws_secret_access_key": os.getenv("AWS_SECRET_ACCESS_KEY"),
    }
    model.validate_credentials(model="meta.llama2-13b-chat-v1", credentials=aws_credentials)
def test_invoke_model():
    """A blocking invoke should yield a non-empty LLMResult."""
    model = BedrockLargeLanguageModel()
    aws_credentials = {
        "aws_region": os.getenv("AWS_REGION"),
        "aws_access_key": os.getenv("AWS_ACCESS_KEY"),
        "aws_secret_access_key": os.getenv("AWS_SECRET_ACCESS_KEY"),
    }
    prompt_messages = [
        SystemPromptMessage(content="You are a helpful AI assistant."),
        UserPromptMessage(content="Hello World!"),
    ]
    response = model.invoke(
        model="meta.llama2-13b-chat-v1",
        credentials=aws_credentials,
        prompt_messages=prompt_messages,
        model_parameters={"temperature": 0.0, "top_p": 1.0, "max_tokens_to_sample": 10},
        stop=["How"],
        stream=False,
        user="abc-123",
    )
    assert isinstance(response, LLMResult)
    assert len(response.message.content) > 0
def test_invoke_stream_model():
    """A streaming invoke should yield well-formed, non-empty chunks."""
    model = BedrockLargeLanguageModel()
    response = model.invoke(
        model="meta.llama2-13b-chat-v1",
        credentials={
            "aws_region": os.getenv("AWS_REGION"),
            "aws_access_key": os.getenv("AWS_ACCESS_KEY"),
            "aws_secret_access_key": os.getenv("AWS_SECRET_ACCESS_KEY"),
        },
        prompt_messages=[
            SystemPromptMessage(
                content="You are a helpful AI assistant.",
            ),
            UserPromptMessage(content="Hello World!"),
        ],
        model_parameters={"temperature": 0.0, "max_tokens_to_sample": 100},
        stream=True,
        user="abc-123",
    )
    assert isinstance(response, Generator)
    for chunk in response:
        # (removed a leftover debug `print(chunk)` that spammed test output)
        assert isinstance(chunk, LLMResultChunk)
        assert isinstance(chunk.delta, LLMResultChunkDelta)
        assert isinstance(chunk.delta.message, AssistantPromptMessage)
        assert len(chunk.delta.message.content) > 0 if chunk.delta.finish_reason is None else True
def test_get_num_tokens():
    """Token counting for a fixed two-message prompt is deterministic."""
    model = BedrockLargeLanguageModel()
    aws_credentials = {
        "aws_region": os.getenv("AWS_REGION"),
        "aws_access_key": os.getenv("AWS_ACCESS_KEY"),
        "aws_secret_access_key": os.getenv("AWS_SECRET_ACCESS_KEY"),
    }
    # NOTE(review): sibling provider tests pass `prompt_messages=`; confirm
    # BedrockLargeLanguageModel.get_num_tokens really accepts `messages=`.
    num_tokens = model.get_num_tokens(
        model="meta.llama2-13b-chat-v1",
        credentials=aws_credentials,
        messages=[
            SystemPromptMessage(content="You are a helpful AI assistant."),
            UserPromptMessage(content="Hello World!"),
        ],
    )
    assert num_tokens == 18

View File

@@ -0,0 +1,21 @@
import os
import pytest
from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.model_providers.bedrock.bedrock import BedrockProvider
def test_validate_provider_credentials():
    """Empty credentials must fail; real AWS credentials must pass."""
    provider = BedrockProvider()
    with pytest.raises(CredentialsValidateFailedError):
        provider.validate_provider_credentials(credentials={})
    aws_credentials = {
        "aws_region": os.getenv("AWS_REGION"),
        "aws_access_key": os.getenv("AWS_ACCESS_KEY"),
        "aws_secret_access_key": os.getenv("AWS_SECRET_ACCESS_KEY"),
    }
    provider.validate_provider_credentials(credentials=aws_credentials)

View File

@@ -0,0 +1,229 @@
import os
from collections.abc import Generator
import pytest
from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta
from core.model_runtime.entities.message_entities import (
AssistantPromptMessage,
PromptMessageTool,
SystemPromptMessage,
UserPromptMessage,
)
from core.model_runtime.entities.model_entities import AIModelEntity
from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.model_providers.chatglm.llm.llm import ChatGLMLargeLanguageModel
from tests.integration_tests.model_runtime.__mock.openai import setup_openai_mock
def test_predefined_models():
    """The provider must ship at least one predefined model schema."""
    schemas = ChatGLMLargeLanguageModel().predefined_models()
    assert len(schemas) >= 1
    assert isinstance(schemas[0], AIModelEntity)
@pytest.mark.parametrize("setup_openai_mock", [["chat"]], indirect=True)
def test_validate_credentials_for_chat_model(setup_openai_mock):
    """A bogus api base must fail validation; the real one must pass."""
    model = ChatGLMLargeLanguageModel()
    with pytest.raises(CredentialsValidateFailedError):
        model.validate_credentials(model="chatglm2-6b", credentials={"api_base": "invalid_key"})
    valid_credentials = {"api_base": os.environ.get("CHATGLM_API_BASE")}
    model.validate_credentials(model="chatglm2-6b", credentials=valid_credentials)
@pytest.mark.parametrize("setup_openai_mock", [["chat"]], indirect=True)
def test_invoke_model(setup_openai_mock):
    """A blocking invoke should yield a non-empty result with usage."""
    model = ChatGLMLargeLanguageModel()
    prompt_messages = [
        SystemPromptMessage(content="You are a helpful AI assistant."),
        UserPromptMessage(content="Hello World!"),
    ]
    response = model.invoke(
        model="chatglm2-6b",
        credentials={"api_base": os.environ.get("CHATGLM_API_BASE")},
        prompt_messages=prompt_messages,
        model_parameters={"temperature": 0.7, "top_p": 1.0},
        stop=["you"],
        user="abc-123",
        stream=False,
    )
    assert isinstance(response, LLMResult)
    assert len(response.message.content) > 0
    assert response.usage.total_tokens > 0
@pytest.mark.parametrize("setup_openai_mock", [["chat"]], indirect=True)
def test_invoke_stream_model(setup_openai_mock):
    """A streaming invoke should yield well-formed, non-empty chunks."""
    model = ChatGLMLargeLanguageModel()
    prompt_messages = [
        SystemPromptMessage(content="You are a helpful AI assistant."),
        UserPromptMessage(content="Hello World!"),
    ]
    response = model.invoke(
        model="chatglm2-6b",
        credentials={"api_base": os.environ.get("CHATGLM_API_BASE")},
        prompt_messages=prompt_messages,
        model_parameters={"temperature": 0.7, "top_p": 1.0},
        stop=["you"],
        stream=True,
        user="abc-123",
    )
    assert isinstance(response, Generator)
    for chunk in response:
        assert isinstance(chunk, LLMResultChunk)
        delta = chunk.delta
        assert isinstance(delta, LLMResultChunkDelta)
        assert isinstance(delta.message, AssistantPromptMessage)
        if delta.finish_reason is None:
            assert len(delta.message.content) > 0
@pytest.mark.parametrize("setup_openai_mock", [["chat"]], indirect=True)
def test_invoke_stream_model_with_functions(setup_openai_mock):
model = ChatGLMLargeLanguageModel()
response = model.invoke(
model="chatglm3-6b",
credentials={"api_base": os.environ.get("CHATGLM_API_BASE")},
prompt_messages=[
SystemPromptMessage(
content="你是一个天气机器人,你不知道今天的天气怎么样,你需要通过调用一个函数来获取天气信息。"
),
UserPromptMessage(content="波士顿天气如何?"),
],
model_parameters={
"temperature": 0,
"top_p": 1.0,
},
stop=["you"],
user="abc-123",
stream=True,
tools=[
PromptMessageTool(
name="get_current_weather",
description="Get the current weather in a given location",
parameters={
"type": "object",
"properties": {
"location": {"type": "string", "description": "The city and state e.g. San Francisco, CA"},
"unit": {"type": "string", "enum": ["celsius", "fahrenheit"]},
},
"required": ["location"],
},
)
],
)
assert isinstance(response, Generator)
call: LLMResultChunk = None
chunks = []
for chunk in response:
chunks.append(chunk)
assert isinstance(chunk, LLMResultChunk)
assert isinstance(chunk.delta, LLMResultChunkDelta)
assert isinstance(chunk.delta.message, AssistantPromptMessage)
assert len(chunk.delta.message.content) > 0 if chunk.delta.finish_reason is None else True
if chunk.delta.message.tool_calls and len(chunk.delta.message.tool_calls) > 0:
call = chunk
break
assert call is not None
assert call.delta.message.tool_calls[0].function.name == "get_current_weather"
@pytest.mark.parametrize("setup_openai_mock", [["chat"]], indirect=True)
def test_invoke_model_with_functions(setup_openai_mock):
model = ChatGLMLargeLanguageModel()
response = model.invoke(
model="chatglm3-6b",
credentials={"api_base": os.environ.get("CHATGLM_API_BASE")},
prompt_messages=[UserPromptMessage(content="What is the weather like in San Francisco?")],
model_parameters={
"temperature": 0.7,
"top_p": 1.0,
},
stop=["you"],
user="abc-123",
stream=False,
tools=[
PromptMessageTool(
name="get_current_weather",
description="Get the current weather in a given location",
parameters={
"type": "object",
"properties": {
"location": {"type": "string", "description": "The city and state e.g. San Francisco, CA"},
"unit": {"type": "string", "enum": ["c", "f"]},
},
"required": ["location"],
},
)
],
)
assert isinstance(response, LLMResult)
assert len(response.message.content) > 0
assert response.usage.total_tokens > 0
assert response.message.tool_calls[0].function.name == "get_current_weather"
def test_get_num_tokens():
    """Token counting: with tools expects 77 tokens, without tools 21.

    NOTE(review): the expected counts are pinned to the current tokenizer
    behavior — update them if the counting implementation changes.
    """
    model = ChatGLMLargeLanguageModel()

    num_tokens = model.get_num_tokens(
        model="chatglm2-6b",
        credentials={"api_base": os.environ.get("CHATGLM_API_BASE")},
        prompt_messages=[
            SystemPromptMessage(
                content="You are a helpful AI assistant.",
            ),
            UserPromptMessage(content="Hello World!"),
        ],
        tools=[
            PromptMessageTool(
                name="get_current_weather",
                description="Get the current weather in a given location",
                parameters={
                    "type": "object",
                    "properties": {
                        "location": {"type": "string", "description": "The city and state e.g. San Francisco, CA"},
                        "unit": {"type": "string", "enum": ["c", "f"]},
                    },
                    "required": ["location"],
                },
            )
        ],
    )

    assert isinstance(num_tokens, int)
    assert num_tokens == 77

    num_tokens = model.get_num_tokens(
        model="chatglm2-6b",
        credentials={"api_base": os.environ.get("CHATGLM_API_BASE")},
        prompt_messages=[
            SystemPromptMessage(
                content="You are a helpful AI assistant.",
            ),
            UserPromptMessage(content="Hello World!"),
        ],
    )

    assert isinstance(num_tokens, int)
    assert num_tokens == 21

View File

@@ -0,0 +1,17 @@
import os
import pytest
from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.model_providers.chatglm.chatglm import ChatGLMProvider
from tests.integration_tests.model_runtime.__mock.openai import setup_openai_mock
@pytest.mark.parametrize("setup_openai_mock", [["chat"]], indirect=True)
def test_validate_provider_credentials(setup_openai_mock):
provider = ChatGLMProvider()
with pytest.raises(CredentialsValidateFailedError):
provider.validate_provider_credentials(credentials={"api_base": "hahahaha"})
provider.validate_provider_credentials(credentials={"api_base": os.environ.get("CHATGLM_API_BASE")})

View File

@@ -0,0 +1,191 @@
import os
from collections.abc import Generator
import pytest
from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta
from core.model_runtime.entities.message_entities import AssistantPromptMessage, SystemPromptMessage, UserPromptMessage
from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.model_providers.cohere.llm.llm import CohereLargeLanguageModel
def test_validate_credentials_for_chat_model():
    """Chat-mode Cohere model: an invalid key raises, the env key validates."""
    model = CohereLargeLanguageModel()

    bad_credentials = {"api_key": "invalid_key"}
    with pytest.raises(CredentialsValidateFailedError):
        model.validate_credentials(model="command-light-chat", credentials=bad_credentials)

    good_credentials = {"api_key": os.environ.get("COHERE_API_KEY")}
    model.validate_credentials(model="command-light-chat", credentials=good_credentials)


def test_validate_credentials_for_completion_model():
    """Completion-mode Cohere model: an invalid key raises, the env key validates."""
    model = CohereLargeLanguageModel()

    bad_credentials = {"api_key": "invalid_key"}
    with pytest.raises(CredentialsValidateFailedError):
        model.validate_credentials(model="command-light", credentials=bad_credentials)

    good_credentials = {"api_key": os.environ.get("COHERE_API_KEY")}
    model.validate_credentials(model="command-light", credentials=good_credentials)
def test_invoke_completion_model():
    """max_tokens=1 must yield a completion of exactly one token."""
    model = CohereLargeLanguageModel()

    credentials = {"api_key": os.environ.get("COHERE_API_KEY")}

    result = model.invoke(
        model="command-light",
        credentials=credentials,
        prompt_messages=[UserPromptMessage(content="Hello World!")],
        model_parameters={"temperature": 0.0, "max_tokens": 1},
        stream=False,
        user="abc-123",
    )

    assert isinstance(result, LLMResult)
    assert len(result.message.content) > 0
    # The single-token budget must be respected exactly.
    assert model._num_tokens_from_string("command-light", credentials, result.message.content) == 1
def test_invoke_stream_completion_model():
    """Streaming completion yields well-formed chunks; only the final one may be empty."""
    model = CohereLargeLanguageModel()

    result = model.invoke(
        model="command-light",
        credentials={"api_key": os.environ.get("COHERE_API_KEY")},
        prompt_messages=[UserPromptMessage(content="Hello World!")],
        model_parameters={"temperature": 0.0, "max_tokens": 100},
        stream=True,
        user="abc-123",
    )

    assert isinstance(result, Generator)

    for chunk in result:
        assert isinstance(chunk, LLMResultChunk)
        assert isinstance(chunk.delta, LLMResultChunkDelta)
        assert isinstance(chunk.delta.message, AssistantPromptMessage)
        # A chunk carrying finish_reason is allowed to have empty content.
        assert len(chunk.delta.message.content) > 0 if chunk.delta.finish_reason is None else True
def test_invoke_chat_model():
    """Blocking chat invoke returns a non-empty LLMResult."""
    model = CohereLargeLanguageModel()

    result = model.invoke(
        model="command-light-chat",
        credentials={"api_key": os.environ.get("COHERE_API_KEY")},
        prompt_messages=[
            SystemPromptMessage(
                content="You are a helpful AI assistant.",
            ),
            UserPromptMessage(content="Hello World!"),
        ],
        model_parameters={
            "temperature": 0.0,
            "p": 0.99,
            "presence_penalty": 0.0,
            "frequency_penalty": 0.0,
            "max_tokens": 10,
        },
        stop=["How"],
        stream=False,
        user="abc-123",
    )

    assert isinstance(result, LLMResult)
    assert len(result.message.content) > 0
def test_invoke_stream_chat_model():
    """Streaming chat invoke yields valid chunks and usage on the final chunk."""
    model = CohereLargeLanguageModel()

    result = model.invoke(
        model="command-light-chat",
        credentials={"api_key": os.environ.get("COHERE_API_KEY")},
        prompt_messages=[
            SystemPromptMessage(
                content="You are a helpful AI assistant.",
            ),
            UserPromptMessage(content="Hello World!"),
        ],
        model_parameters={"temperature": 0.0, "max_tokens": 100},
        stream=True,
        user="abc-123",
    )

    assert isinstance(result, Generator)

    for chunk in result:
        assert isinstance(chunk, LLMResultChunk)
        assert isinstance(chunk.delta, LLMResultChunkDelta)
        assert isinstance(chunk.delta.message, AssistantPromptMessage)
        # A chunk carrying finish_reason is allowed to have empty content.
        assert len(chunk.delta.message.content) > 0 if chunk.delta.finish_reason is None else True
        if chunk.delta.finish_reason is not None:
            # The terminating chunk must report token usage.
            assert chunk.delta.usage is not None
            assert chunk.delta.usage.completion_tokens > 0
def test_get_num_tokens():
    """Token counting for completion (3) and chat (15) prompt shapes.

    NOTE(review): counts are pinned to Cohere's current tokenizer behavior.
    """
    model = CohereLargeLanguageModel()

    num_tokens = model.get_num_tokens(
        model="command-light",
        credentials={"api_key": os.environ.get("COHERE_API_KEY")},
        prompt_messages=[UserPromptMessage(content="Hello World!")],
    )

    assert num_tokens == 3

    num_tokens = model.get_num_tokens(
        model="command-light-chat",
        credentials={"api_key": os.environ.get("COHERE_API_KEY")},
        prompt_messages=[
            SystemPromptMessage(
                content="You are a helpful AI assistant.",
            ),
            UserPromptMessage(content="Hello World!"),
        ],
    )

    assert num_tokens == 15
def test_fine_tuned_model():
    """Invoke a fine-tuned model id in completion mode and expect an LLMResult."""
    model = CohereLargeLanguageModel()

    # test invoke
    result = model.invoke(
        model="85ec47be-6139-4f75-a4be-0f0ec1ef115c-ft",
        credentials={"api_key": os.environ.get("COHERE_API_KEY"), "mode": "completion"},
        prompt_messages=[
            SystemPromptMessage(
                content="You are a helpful AI assistant.",
            ),
            UserPromptMessage(content="Hello World!"),
        ],
        model_parameters={"temperature": 0.0, "max_tokens": 100},
        stream=False,
        user="abc-123",
    )

    assert isinstance(result, LLMResult)


def test_fine_tuned_chat_model():
    """Invoke a fine-tuned model id in chat mode and expect an LLMResult."""
    model = CohereLargeLanguageModel()

    # test invoke
    result = model.invoke(
        model="94f2d55a-4c79-4c00-bde4-23962e74b170-ft",
        credentials={"api_key": os.environ.get("COHERE_API_KEY"), "mode": "chat"},
        prompt_messages=[
            SystemPromptMessage(
                content="You are a helpful AI assistant.",
            ),
            UserPromptMessage(content="Hello World!"),
        ],
        model_parameters={"temperature": 0.0, "max_tokens": 100},
        stream=False,
        user="abc-123",
    )

    assert isinstance(result, LLMResult)

View File

@@ -0,0 +1,15 @@
import os
import pytest
from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.model_providers.cohere.cohere import CohereProvider
def test_validate_provider_credentials():
    """Empty credentials are rejected; the env-configured API key validates."""
    provider = CohereProvider()

    with pytest.raises(CredentialsValidateFailedError):
        provider.validate_provider_credentials(credentials={})

    provider.validate_provider_credentials(credentials={"api_key": os.environ.get("COHERE_API_KEY")})

View File

@@ -0,0 +1,40 @@
import os
import pytest
from core.model_runtime.entities.rerank_entities import RerankResult
from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.model_providers.cohere.rerank.rerank import CohereRerankModel
def test_validate_credentials():
    """Rerank model credential validation: invalid key raises, env key passes."""
    model = CohereRerankModel()

    with pytest.raises(CredentialsValidateFailedError):
        model.validate_credentials(model="rerank-english-v2.0", credentials={"api_key": "invalid_key"})

    model.validate_credentials(model="rerank-english-v2.0", credentials={"api_key": os.environ.get("COHERE_API_KEY")})


def test_invoke_model():
    """With score_threshold=0.8, only the D.C. document should survive reranking."""
    model = CohereRerankModel()

    result = model.invoke(
        model="rerank-english-v2.0",
        credentials={"api_key": os.environ.get("COHERE_API_KEY")},
        query="What is the capital of the United States?",
        docs=[
            "Carson City is the capital city of the American state of Nevada. At the 2010 United States "
            "Census, Carson City had a population of 55,274.",
            "Washington, D.C. (also known as simply Washington or D.C., and officially as the District of Columbia) "
            "is the capital of the United States. It is a federal district. The President of the USA and many major "
            "national government offices are in the territory. This makes it the political center of the United "
            "States of America.",
        ],
        score_threshold=0.8,
    )

    assert isinstance(result, RerankResult)
    assert len(result.docs) == 1
    # Index 1 is the Washington, D.C. document.
    assert result.docs[0].index == 1
    assert result.docs[0].score >= 0.8

View File

@@ -0,0 +1,45 @@
import os
import pytest
from core.model_runtime.entities.text_embedding_entities import TextEmbeddingResult
from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.model_providers.cohere.text_embedding.text_embedding import CohereTextEmbeddingModel
def test_validate_credentials():
    """Embedding model credential validation: invalid key raises, env key passes."""
    model = CohereTextEmbeddingModel()

    with pytest.raises(CredentialsValidateFailedError):
        model.validate_credentials(model="embed-multilingual-v3.0", credentials={"api_key": "invalid_key"})

    model.validate_credentials(
        model="embed-multilingual-v3.0", credentials={"api_key": os.environ.get("COHERE_API_KEY")}
    )


def test_invoke_model():
    """Embed four texts (two long) and pin the reported total token usage."""
    model = CohereTextEmbeddingModel()

    result = model.invoke(
        model="embed-multilingual-v3.0",
        credentials={"api_key": os.environ.get("COHERE_API_KEY")},
        texts=["hello", "world", " ".join(["long_text"] * 100), " ".join(["another_long_text"] * 100)],
        user="abc-123",
    )

    assert isinstance(result, TextEmbeddingResult)
    assert len(result.embeddings) == 4
    # NOTE(review): pinned to the current tokenizer; update if it changes.
    assert result.usage.total_tokens == 811


def test_get_num_tokens():
    """Two single-word texts count as 3 tokens under the current tokenizer."""
    model = CohereTextEmbeddingModel()

    num_tokens = model.get_num_tokens(
        model="embed-multilingual-v3.0",
        credentials={"api_key": os.environ.get("COHERE_API_KEY")},
        texts=["hello", "world"],
    )

    assert num_tokens == 3

View File

@@ -0,0 +1,186 @@
import os
from collections.abc import Generator
import pytest
from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta
from core.model_runtime.entities.message_entities import (
AssistantPromptMessage,
PromptMessageTool,
SystemPromptMessage,
UserPromptMessage,
)
from core.model_runtime.entities.model_entities import AIModelEntity
from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.model_providers.fireworks.llm.llm import FireworksLargeLanguageModel
"""FOR MOCK FIXTURES, DO NOT REMOVE"""
from tests.integration_tests.model_runtime.__mock.openai import setup_openai_mock
def test_predefined_models():
    """The provider must expose at least one predefined AIModelEntity."""
    model = FireworksLargeLanguageModel()
    model_schemas = model.predefined_models()

    assert len(model_schemas) >= 1
    assert isinstance(model_schemas[0], AIModelEntity)


@pytest.mark.parametrize("setup_openai_mock", [["chat"]], indirect=True)
def test_validate_credentials_for_chat_model(setup_openai_mock):
    """An invalid key must fail validation; the env-configured key must pass."""
    model = FireworksLargeLanguageModel()

    with pytest.raises(CredentialsValidateFailedError):
        # model name to gpt-3.5-turbo because of mocking
        model.validate_credentials(model="gpt-3.5-turbo", credentials={"fireworks_api_key": "invalid_key"})

    model.validate_credentials(
        model="accounts/fireworks/models/llama-v3p1-8b-instruct",
        credentials={"fireworks_api_key": os.environ.get("FIREWORKS_API_KEY")},
    )
@pytest.mark.parametrize("setup_openai_mock", [["chat"]], indirect=True)
def test_invoke_chat_model(setup_openai_mock):
model = FireworksLargeLanguageModel()
result = model.invoke(
model="accounts/fireworks/models/llama-v3p1-8b-instruct",
credentials={"fireworks_api_key": os.environ.get("FIREWORKS_API_KEY")},
prompt_messages=[
SystemPromptMessage(
content="You are a helpful AI assistant.",
),
UserPromptMessage(content="Hello World!"),
],
model_parameters={
"temperature": 0.0,
"top_p": 1.0,
"presence_penalty": 0.0,
"frequency_penalty": 0.0,
"max_tokens": 10,
},
stop=["How"],
stream=False,
user="foo",
)
assert isinstance(result, LLMResult)
assert len(result.message.content) > 0
@pytest.mark.parametrize("setup_openai_mock", [["chat"]], indirect=True)
def test_invoke_chat_model_with_tools(setup_openai_mock):
model = FireworksLargeLanguageModel()
result = model.invoke(
model="accounts/fireworks/models/llama-v3p1-8b-instruct",
credentials={"fireworks_api_key": os.environ.get("FIREWORKS_API_KEY")},
prompt_messages=[
SystemPromptMessage(
content="You are a helpful AI assistant.",
),
UserPromptMessage(
content="what's the weather today in London?",
),
],
model_parameters={"temperature": 0.0, "max_tokens": 100},
tools=[
PromptMessageTool(
name="get_weather",
description="Determine weather in my location",
parameters={
"type": "object",
"properties": {
"location": {"type": "string", "description": "The city and state e.g. San Francisco, CA"},
"unit": {"type": "string", "enum": ["c", "f"]},
},
"required": ["location"],
},
),
PromptMessageTool(
name="get_stock_price",
description="Get the current stock price",
parameters={
"type": "object",
"properties": {"symbol": {"type": "string", "description": "The stock symbol"}},
"required": ["symbol"],
},
),
],
stream=False,
user="foo",
)
assert isinstance(result, LLMResult)
assert isinstance(result.message, AssistantPromptMessage)
assert len(result.message.tool_calls) > 0
@pytest.mark.parametrize("setup_openai_mock", [["chat"]], indirect=True)
def test_invoke_stream_chat_model(setup_openai_mock):
model = FireworksLargeLanguageModel()
result = model.invoke(
model="accounts/fireworks/models/llama-v3p1-8b-instruct",
credentials={"fireworks_api_key": os.environ.get("FIREWORKS_API_KEY")},
prompt_messages=[
SystemPromptMessage(
content="You are a helpful AI assistant.",
),
UserPromptMessage(content="Hello World!"),
],
model_parameters={"temperature": 0.0, "max_tokens": 100},
stream=True,
user="foo",
)
assert isinstance(result, Generator)
for chunk in result:
assert isinstance(chunk, LLMResultChunk)
assert isinstance(chunk.delta, LLMResultChunkDelta)
assert isinstance(chunk.delta.message, AssistantPromptMessage)
assert len(chunk.delta.message.content) > 0 if chunk.delta.finish_reason is None else True
if chunk.delta.finish_reason is not None:
assert chunk.delta.usage is not None
assert chunk.delta.usage.completion_tokens > 0
def test_get_num_tokens():
    """Token counting: plain prompt is 10 tokens, system+user+tool prompt is 77.

    NOTE(review): counts are pinned to the current tokenizer behavior.
    """
    model = FireworksLargeLanguageModel()

    num_tokens = model.get_num_tokens(
        model="accounts/fireworks/models/llama-v3p1-8b-instruct",
        credentials={"fireworks_api_key": os.environ.get("FIREWORKS_API_KEY")},
        prompt_messages=[UserPromptMessage(content="Hello World!")],
    )

    assert num_tokens == 10

    num_tokens = model.get_num_tokens(
        model="accounts/fireworks/models/llama-v3p1-8b-instruct",
        credentials={"fireworks_api_key": os.environ.get("FIREWORKS_API_KEY")},
        prompt_messages=[
            SystemPromptMessage(
                content="You are a helpful AI assistant.",
            ),
            UserPromptMessage(content="Hello World!"),
        ],
        tools=[
            PromptMessageTool(
                name="get_weather",
                description="Determine weather in my location",
                parameters={
                    "type": "object",
                    "properties": {
                        "location": {"type": "string", "description": "The city and state e.g. San Francisco, CA"},
                        "unit": {"type": "string", "enum": ["c", "f"]},
                    },
                    "required": ["location"],
                },
            ),
        ],
    )

    assert num_tokens == 77

View File

@@ -0,0 +1,17 @@
import os
import pytest
from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.model_providers.fireworks.fireworks import FireworksProvider
from tests.integration_tests.model_runtime.__mock.openai import setup_openai_mock
@pytest.mark.parametrize("setup_openai_mock", [["chat"]], indirect=True)
def test_validate_provider_credentials(setup_openai_mock):
provider = FireworksProvider()
with pytest.raises(CredentialsValidateFailedError):
provider.validate_provider_credentials(credentials={})
provider.validate_provider_credentials(credentials={"fireworks_api_key": os.environ.get("FIREWORKS_API_KEY")})

View File

@@ -0,0 +1,54 @@
import os
import pytest
from core.model_runtime.entities.text_embedding_entities import TextEmbeddingResult
from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.model_providers.fireworks.text_embedding.text_embedding import FireworksTextEmbeddingModel
from tests.integration_tests.model_runtime.__mock.openai import setup_openai_mock
@pytest.mark.parametrize("setup_openai_mock", [["text_embedding"]], indirect=True)
def test_validate_credentials(setup_openai_mock):
model = FireworksTextEmbeddingModel()
with pytest.raises(CredentialsValidateFailedError):
model.validate_credentials(
model="nomic-ai/nomic-embed-text-v1.5", credentials={"fireworks_api_key": "invalid_key"}
)
model.validate_credentials(
model="nomic-ai/nomic-embed-text-v1.5", credentials={"fireworks_api_key": os.environ.get("FIREWORKS_API_KEY")}
)
@pytest.mark.parametrize("setup_openai_mock", [["text_embedding"]], indirect=True)
def test_invoke_model(setup_openai_mock):
model = FireworksTextEmbeddingModel()
result = model.invoke(
model="nomic-ai/nomic-embed-text-v1.5",
credentials={
"fireworks_api_key": os.environ.get("FIREWORKS_API_KEY"),
},
texts=["hello", "world", " ".join(["long_text"] * 100), " ".join(["another_long_text"] * 100)],
user="foo",
)
assert isinstance(result, TextEmbeddingResult)
assert len(result.embeddings) == 4
assert result.usage.total_tokens == 2
def test_get_num_tokens():
model = FireworksTextEmbeddingModel()
num_tokens = model.get_num_tokens(
model="nomic-ai/nomic-embed-text-v1.5",
credentials={
"fireworks_api_key": os.environ.get("FIREWORKS_API_KEY"),
},
texts=["hello", "world"],
)
assert num_tokens == 2

View File

@@ -0,0 +1,33 @@
import os
import httpx
import pytest
from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.model_providers.fishaudio.fishaudio import FishAudioProvider
from tests.integration_tests.model_runtime.__mock.fishaudio import setup_fishaudio_mock
@pytest.mark.parametrize("setup_fishaudio_mock", [["list-models"]], indirect=True)
def test_validate_provider_credentials(setup_fishaudio_mock):
print("-----", httpx.get)
provider = FishAudioProvider()
with pytest.raises(CredentialsValidateFailedError):
provider.validate_provider_credentials(
credentials={
"api_key": "bad_api_key",
"api_base": os.environ.get("FISH_AUDIO_API_BASE", "https://api.fish.audio"),
"use_public_models": "false",
"latency": "normal",
}
)
provider.validate_provider_credentials(
credentials={
"api_key": os.environ.get("FISH_AUDIO_API_KEY", "test"),
"api_base": os.environ.get("FISH_AUDIO_API_BASE", "https://api.fish.audio"),
"use_public_models": "false",
"latency": "normal",
}
)

View File

@@ -0,0 +1,32 @@
import os
import pytest
from core.model_runtime.model_providers.fishaudio.tts.tts import (
FishAudioText2SpeechModel,
)
from tests.integration_tests.model_runtime.__mock.fishaudio import setup_fishaudio_mock
@pytest.mark.parametrize("setup_fishaudio_mock", [["tts"]], indirect=True)
def test_invoke_model(setup_fishaudio_mock):
model = FishAudioText2SpeechModel()
result = model.invoke(
model="tts-default",
tenant_id="test",
credentials={
"api_key": os.environ.get("FISH_AUDIO_API_KEY", "test"),
"api_base": os.environ.get("FISH_AUDIO_API_BASE", "https://api.fish.audio"),
"use_public_models": "false",
"latency": "normal",
},
content_text="Hello, world!",
voice="03397b4c4be74759b72533b663fbd001",
)
content = b""
for chunk in result:
content += chunk
assert content != b""

View File

@@ -0,0 +1,132 @@
import os
from collections.abc import Generator
import pytest
from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta
from core.model_runtime.entities.message_entities import (
AssistantPromptMessage,
PromptMessageTool,
SystemPromptMessage,
UserPromptMessage,
)
from core.model_runtime.entities.model_entities import AIModelEntity
from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.model_providers.gitee_ai.llm.llm import GiteeAILargeLanguageModel
def test_predefined_models():
    """The provider must expose at least one predefined AIModelEntity."""
    model = GiteeAILargeLanguageModel()
    model_schemas = model.predefined_models()

    assert len(model_schemas) >= 1
    assert isinstance(model_schemas[0], AIModelEntity)


def test_validate_credentials_for_chat_model():
    """An invalid key must fail validation; the env-configured key must pass."""
    model = GiteeAILargeLanguageModel()

    with pytest.raises(CredentialsValidateFailedError):
        # model name to gpt-3.5-turbo because of mocking
        model.validate_credentials(model="gpt-3.5-turbo", credentials={"api_key": "invalid_key"})

    model.validate_credentials(
        model="Qwen2-7B-Instruct",
        credentials={"api_key": os.environ.get("GITEE_AI_API_KEY")},
    )
def test_invoke_chat_model():
    """Blocking chat invoke returns a non-empty LLMResult."""
    model = GiteeAILargeLanguageModel()

    result = model.invoke(
        model="Qwen2-7B-Instruct",
        credentials={"api_key": os.environ.get("GITEE_AI_API_KEY")},
        prompt_messages=[
            SystemPromptMessage(
                content="You are a helpful AI assistant.",
            ),
            UserPromptMessage(content="Hello World!"),
        ],
        model_parameters={
            "temperature": 0.0,
            "top_p": 1.0,
            "presence_penalty": 0.0,
            "frequency_penalty": 0.0,
            "max_tokens": 10,
            "stream": False,
        },
        stop=["How"],
        stream=False,
        user="foo",
    )

    assert isinstance(result, LLMResult)
    assert len(result.message.content) > 0


def test_invoke_stream_chat_model():
    """Streaming chat invoke yields valid chunks and usage on the final chunk."""
    model = GiteeAILargeLanguageModel()

    result = model.invoke(
        model="Qwen2-7B-Instruct",
        credentials={"api_key": os.environ.get("GITEE_AI_API_KEY")},
        prompt_messages=[
            SystemPromptMessage(
                content="You are a helpful AI assistant.",
            ),
            UserPromptMessage(content="Hello World!"),
        ],
        model_parameters={"temperature": 0.0, "max_tokens": 100, "stream": False},
        stream=True,
        user="foo",
    )

    assert isinstance(result, Generator)

    for chunk in result:
        assert isinstance(chunk, LLMResultChunk)
        assert isinstance(chunk.delta, LLMResultChunkDelta)
        assert isinstance(chunk.delta.message, AssistantPromptMessage)
        # A chunk carrying finish_reason is allowed to have empty content.
        assert len(chunk.delta.message.content) > 0 if chunk.delta.finish_reason is None else True
        if chunk.delta.finish_reason is not None:
            assert chunk.delta.usage is not None
def test_get_num_tokens():
    """Token counting: plain prompt is 10 tokens, system+user+tool prompt is 77.

    NOTE(review): counts are pinned to the current tokenizer behavior.
    """
    model = GiteeAILargeLanguageModel()

    num_tokens = model.get_num_tokens(
        model="Qwen2-7B-Instruct",
        credentials={"api_key": os.environ.get("GITEE_AI_API_KEY")},
        prompt_messages=[UserPromptMessage(content="Hello World!")],
    )

    assert num_tokens == 10

    num_tokens = model.get_num_tokens(
        model="Qwen2-7B-Instruct",
        credentials={"api_key": os.environ.get("GITEE_AI_API_KEY")},
        prompt_messages=[
            SystemPromptMessage(
                content="You are a helpful AI assistant.",
            ),
            UserPromptMessage(content="Hello World!"),
        ],
        tools=[
            PromptMessageTool(
                name="get_weather",
                description="Determine weather in my location",
                parameters={
                    "type": "object",
                    "properties": {
                        "location": {"type": "string", "description": "The city and state e.g. San Francisco, CA"},
                        "unit": {"type": "string", "enum": ["c", "f"]},
                    },
                    "required": ["location"],
                },
            ),
        ],
    )

    assert num_tokens == 77

View File

@@ -0,0 +1,15 @@
import os
import pytest
from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.model_providers.gitee_ai.gitee_ai import GiteeAIProvider
def test_validate_provider_credentials():
    """An invalid key is rejected; the env-configured Gitee AI key validates."""
    provider = GiteeAIProvider()

    with pytest.raises(CredentialsValidateFailedError):
        provider.validate_provider_credentials(credentials={"api_key": "invalid_key"})

    provider.validate_provider_credentials(credentials={"api_key": os.environ.get("GITEE_AI_API_KEY")})

View File

@@ -0,0 +1,47 @@
import os
import pytest
from core.model_runtime.entities.rerank_entities import RerankResult
from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.model_providers.gitee_ai.rerank.rerank import GiteeAIRerankModel
def test_validate_credentials():
    """Rerank credential validation: invalid key raises, env key passes."""
    model = GiteeAIRerankModel()

    with pytest.raises(CredentialsValidateFailedError):
        model.validate_credentials(
            model="bge-reranker-v2-m3",
            credentials={"api_key": "invalid_key"},
        )

    model.validate_credentials(
        model="bge-reranker-v2-m3",
        credentials={
            "api_key": os.environ.get("GITEE_AI_API_KEY"),
        },
    )


def test_invoke_model():
    """top_n=1 with a low threshold returns exactly one scored document."""
    model = GiteeAIRerankModel()

    result = model.invoke(
        model="bge-reranker-v2-m3",
        credentials={
            "api_key": os.environ.get("GITEE_AI_API_KEY"),
        },
        query="What is the capital of the United States?",
        docs=[
            "Carson City is the capital city of the American state of Nevada. At the 2010 United States "
            "Census, Carson City had a population of 55,274.",
            "The Commonwealth of the Northern Mariana Islands is a group of islands in the Pacific Ocean that "
            "are a political division controlled by the United States. Its capital is Saipan.",
        ],
        top_n=1,
        score_threshold=0.01,
    )

    assert isinstance(result, RerankResult)
    assert len(result.docs) == 1
    assert result.docs[0].score >= 0.01

View File

@@ -0,0 +1,45 @@
import os
import pytest
from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.model_providers.gitee_ai.speech2text.speech2text import GiteeAISpeech2TextModel
def test_validate_credentials():
    """Speech-to-text credential validation: invalid key raises, env key passes."""
    model = GiteeAISpeech2TextModel()

    with pytest.raises(CredentialsValidateFailedError):
        model.validate_credentials(
            model="whisper-base",
            credentials={"api_key": "invalid_key"},
        )

    model.validate_credentials(
        model="whisper-base",
        credentials={"api_key": os.environ.get("GITEE_AI_API_KEY")},
    )


def test_invoke_model():
    """Transcribe the bundled assets/audio.mp3 fixture and pin the transcript."""
    model = GiteeAISpeech2TextModel()

    # Get the directory of the current file
    current_dir = os.path.dirname(os.path.abspath(__file__))

    # Get assets directory
    assets_dir = os.path.join(os.path.dirname(current_dir), "assets")

    # Construct the path to the audio file
    audio_file_path = os.path.join(assets_dir, "audio.mp3")

    # Open the file and get the file object
    with open(audio_file_path, "rb") as audio_file:
        file = audio_file

        result = model.invoke(
            model="whisper-base", credentials={"api_key": os.environ.get("GITEE_AI_API_KEY")}, file=file
        )

        assert isinstance(result, str)
        assert result == "1 2 3 4 5 6 7 8 9 10"

View File

@@ -0,0 +1,46 @@
import os
import pytest
from core.model_runtime.entities.text_embedding_entities import TextEmbeddingResult
from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.model_providers.gitee_ai.text_embedding.text_embedding import GiteeAIEmbeddingModel
def test_validate_credentials():
    """Embedding credential validation: invalid key raises, env key passes."""
    model = GiteeAIEmbeddingModel()

    with pytest.raises(CredentialsValidateFailedError):
        model.validate_credentials(model="bge-large-zh-v1.5", credentials={"api_key": "invalid_key"})

    model.validate_credentials(model="bge-large-zh-v1.5", credentials={"api_key": os.environ.get("GITEE_AI_API_KEY")})


def test_invoke_model():
    """Embedding two texts yields two embedding vectors."""
    model = GiteeAIEmbeddingModel()

    result = model.invoke(
        model="bge-large-zh-v1.5",
        credentials={
            "api_key": os.environ.get("GITEE_AI_API_KEY"),
        },
        texts=["hello", "world"],
        user="user",
    )

    assert isinstance(result, TextEmbeddingResult)
    assert len(result.embeddings) == 2


def test_get_num_tokens():
    """Two single-word texts count as 2 tokens."""
    model = GiteeAIEmbeddingModel()

    num_tokens = model.get_num_tokens(
        model="bge-large-zh-v1.5",
        credentials={
            "api_key": os.environ.get("GITEE_AI_API_KEY"),
        },
        texts=["hello", "world"],
    )

    assert num_tokens == 2

View File

@@ -0,0 +1,23 @@
import os
from core.model_runtime.model_providers.gitee_ai.tts.tts import GiteeAIText2SpeechModel
def test_invoke_model():
    """TTS invoke streams non-empty audio bytes (empty voice = provider default)."""
    model = GiteeAIText2SpeechModel()

    result = model.invoke(
        model="speecht5_tts",
        tenant_id="test",
        credentials={
            "api_key": os.environ.get("GITEE_AI_API_KEY"),
        },
        content_text="Hello, world!",
        voice="",
    )

    content = b""
    for chunk in result:
        content += chunk

    assert content != b""

File diff suppressed because one or more lines are too long

View File

@@ -0,0 +1,17 @@
import os
import pytest
from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.model_providers.google.google import GoogleProvider
from tests.integration_tests.model_runtime.__mock.google import setup_google_mock
@pytest.mark.parametrize("setup_google_mock", [["none"]], indirect=True)
def test_validate_provider_credentials(setup_google_mock):
provider = GoogleProvider()
with pytest.raises(CredentialsValidateFailedError):
provider.validate_provider_credentials(credentials={})
provider.validate_provider_credentials(credentials={"google_api_key": os.environ.get("GOOGLE_API_KEY")})

View File

@@ -0,0 +1,49 @@
import os
import pytest
from core.model_runtime.entities.text_embedding_entities import TextEmbeddingResult
from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.model_providers.gpustack.text_embedding.text_embedding import (
GPUStackTextEmbeddingModel,
)
def test_validate_credentials():
    """GPUStack embedding validation: invalid endpoint/key raises, env values pass."""
    model = GPUStackTextEmbeddingModel()

    with pytest.raises(CredentialsValidateFailedError):
        model.validate_credentials(
            model="bge-m3",
            credentials={
                "endpoint_url": "invalid_url",
                "api_key": "invalid_api_key",
            },
        )

    model.validate_credentials(
        model="bge-m3",
        credentials={
            "endpoint_url": os.environ.get("GPUSTACK_SERVER_URL"),
            "api_key": os.environ.get("GPUSTACK_API_KEY"),
        },
    )


def test_invoke_model():
    """Embed two texts against a GPUStack server and pin the token usage."""
    model = GPUStackTextEmbeddingModel()

    result = model.invoke(
        model="bge-m3",
        credentials={
            "endpoint_url": os.environ.get("GPUSTACK_SERVER_URL"),
            "api_key": os.environ.get("GPUSTACK_API_KEY"),
            "context_size": 8192,
        },
        texts=["hello", "world"],
        user="abc-123",
    )

    assert isinstance(result, TextEmbeddingResult)
    assert len(result.embeddings) == 2
    # NOTE(review): pinned to the current server tokenizer; update if it changes.
    assert result.usage.total_tokens == 7

View File

@@ -0,0 +1,162 @@
import os
from collections.abc import Generator
import pytest
from core.model_runtime.entities.llm_entities import (
LLMResult,
LLMResultChunk,
LLMResultChunkDelta,
)
from core.model_runtime.entities.message_entities import (
AssistantPromptMessage,
PromptMessageTool,
SystemPromptMessage,
UserPromptMessage,
)
from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.model_providers.gpustack.llm.llm import GPUStackLanguageModel
def test_validate_credentials_for_chat_model():
    """Credential validation fails for a bad endpoint/key and passes with env-provided ones."""
    model = GPUStackLanguageModel()
    with pytest.raises(CredentialsValidateFailedError):
        model.validate_credentials(
            model="llama-3.2-1b-instruct",
            credentials={
                "endpoint_url": "invalid_url",
                "api_key": "invalid_api_key",
                "mode": "chat",
            },
        )
    model.validate_credentials(
        model="llama-3.2-1b-instruct",
        credentials={
            "endpoint_url": os.environ.get("GPUSTACK_SERVER_URL"),
            "api_key": os.environ.get("GPUSTACK_API_KEY"),
            "mode": "chat",
        },
    )
def test_invoke_completion_model():
    """Non-streaming completion-mode invoke returns a non-empty LLMResult with usage."""
    model = GPUStackLanguageModel()
    response = model.invoke(
        model="llama-3.2-1b-instruct",
        credentials={
            "endpoint_url": os.environ.get("GPUSTACK_SERVER_URL"),
            "api_key": os.environ.get("GPUSTACK_API_KEY"),
            "mode": "completion",
        },
        prompt_messages=[UserPromptMessage(content="ping")],
        model_parameters={"temperature": 0.7, "top_p": 1.0, "max_tokens": 10},
        stop=[],
        user="abc-123",
        stream=False,
    )
    assert isinstance(response, LLMResult)
    assert len(response.message.content) > 0
    assert response.usage.total_tokens > 0
def test_invoke_chat_model():
    """Non-streaming chat-mode invoke returns a non-empty LLMResult with usage."""
    model = GPUStackLanguageModel()
    response = model.invoke(
        model="llama-3.2-1b-instruct",
        credentials={
            "endpoint_url": os.environ.get("GPUSTACK_SERVER_URL"),
            "api_key": os.environ.get("GPUSTACK_API_KEY"),
            "mode": "chat",
        },
        prompt_messages=[UserPromptMessage(content="ping")],
        model_parameters={"temperature": 0.7, "top_p": 1.0, "max_tokens": 10},
        stop=[],
        user="abc-123",
        stream=False,
    )
    assert isinstance(response, LLMResult)
    assert len(response.message.content) > 0
    assert response.usage.total_tokens > 0
def test_invoke_stream_chat_model():
    """Streaming chat-mode invoke yields LLMResultChunk items with assistant content."""
    model = GPUStackLanguageModel()
    response = model.invoke(
        model="llama-3.2-1b-instruct",
        credentials={
            "endpoint_url": os.environ.get("GPUSTACK_SERVER_URL"),
            "api_key": os.environ.get("GPUSTACK_API_KEY"),
            "mode": "chat",
        },
        prompt_messages=[UserPromptMessage(content="Hello World!")],
        model_parameters={"temperature": 0.7, "top_p": 1.0, "max_tokens": 10},
        stop=["you"],
        stream=True,
        user="abc-123",
    )
    assert isinstance(response, Generator)
    for chunk in response:
        assert isinstance(chunk, LLMResultChunk)
        assert isinstance(chunk.delta, LLMResultChunkDelta)
        assert isinstance(chunk.delta.message, AssistantPromptMessage)
        # non-final chunks must carry content; the final chunk may be empty
        assert len(chunk.delta.message.content) > 0 if chunk.delta.finish_reason is None else True
def test_get_num_tokens():
    """Token counting with and without tools yields the expected fixed counts."""
    model = GPUStackLanguageModel()
    num_tokens = model.get_num_tokens(
        model="????",  # NOTE(review): looks like a placeholder model name — confirm counting ignores it
        credentials={
            "endpoint_url": os.environ.get("GPUSTACK_SERVER_URL"),
            "api_key": os.environ.get("GPUSTACK_API_KEY"),
            "mode": "chat",
        },
        prompt_messages=[
            SystemPromptMessage(
                content="You are a helpful AI assistant.",
            ),
            UserPromptMessage(content="Hello World!"),
        ],
        tools=[
            PromptMessageTool(
                name="get_current_weather",
                description="Get the current weather in a given location",
                parameters={
                    "type": "object",
                    "properties": {
                        "location": {
                            "type": "string",
                            "description": "The city and state e.g. San Francisco, CA",
                        },
                        "unit": {"type": "string", "enum": ["c", "f"]},
                    },
                    "required": ["location"],
                },
            )
        ],
    )
    assert isinstance(num_tokens, int)
    assert num_tokens == 80
    num_tokens = model.get_num_tokens(
        model="????",
        credentials={
            "endpoint_url": os.environ.get("GPUSTACK_SERVER_URL"),
            "api_key": os.environ.get("GPUSTACK_API_KEY"),
            "mode": "chat",
        },
        prompt_messages=[UserPromptMessage(content="Hello World!")],
    )
    assert isinstance(num_tokens, int)
    assert num_tokens == 10

View File

@@ -0,0 +1,107 @@
import os
import pytest
from core.model_runtime.entities.rerank_entities import RerankDocument, RerankResult
from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.model_providers.gpustack.rerank.rerank import (
GPUStackRerankModel,
)
def test_validate_credentials_for_rerank_model():
    """Credential validation fails for a bad endpoint/key and passes with env-provided ones."""
    model = GPUStackRerankModel()
    with pytest.raises(CredentialsValidateFailedError):
        model.validate_credentials(
            model="bge-reranker-v2-m3",
            credentials={
                "endpoint_url": "invalid_url",
                "api_key": "invalid_api_key",
            },
        )
    model.validate_credentials(
        model="bge-reranker-v2-m3",
        credentials={
            "endpoint_url": os.environ.get("GPUSTACK_SERVER_URL"),
            "api_key": os.environ.get("GPUSTACK_API_KEY"),
        },
    )
def test_invoke_rerank_model():
    """Rerank nine documents via the public invoke and expect the top_n=3 cut."""
    model = GPUStackRerankModel()
    response = model.invoke(
        model="bge-reranker-v2-m3",
        credentials={
            "endpoint_url": os.environ.get("GPUSTACK_SERVER_URL"),
            "api_key": os.environ.get("GPUSTACK_API_KEY"),
        },
        query="Organic skincare products for sensitive skin",
        docs=[
            "Eco-friendly kitchenware for modern homes",
            "Biodegradable cleaning supplies for eco-conscious consumers",
            "Organic cotton baby clothes for sensitive skin",
            "Natural organic skincare range for sensitive skin",
            "Tech gadgets for smart homes: 2024 edition",
            "Sustainable gardening tools and compost solutions",
            "Sensitive skin-friendly facial cleansers and toners",
            "Organic food wraps and storage solutions",
            "Yoga mats made from recycled materials",
        ],
        top_n=3,
        score_threshold=-0.75,  # negative threshold so no document is filtered out — TODO confirm intent
        user="abc-123",
    )
    assert isinstance(response, RerankResult)
    assert len(response.docs) == 3
def test__invoke():
    """Exercise the private _invoke path: empty docs yields no results; full docs honors top_n."""
    model = GPUStackRerankModel()
    # Test case 1: Empty docs
    result = model._invoke(
        model="bge-reranker-v2-m3",
        credentials={
            "endpoint_url": os.environ.get("GPUSTACK_SERVER_URL"),
            "api_key": os.environ.get("GPUSTACK_API_KEY"),
        },
        query="Organic skincare products for sensitive skin",
        docs=[],
        top_n=3,
        score_threshold=0.75,
        user="abc-123",
    )
    assert isinstance(result, RerankResult)
    assert len(result.docs) == 0
    # Test case 2: Expected docs
    result = model._invoke(
        model="bge-reranker-v2-m3",
        credentials={
            "endpoint_url": os.environ.get("GPUSTACK_SERVER_URL"),
            "api_key": os.environ.get("GPUSTACK_API_KEY"),
        },
        query="Organic skincare products for sensitive skin",
        docs=[
            "Eco-friendly kitchenware for modern homes",
            "Biodegradable cleaning supplies for eco-conscious consumers",
            "Organic cotton baby clothes for sensitive skin",
            "Natural organic skincare range for sensitive skin",
            "Tech gadgets for smart homes: 2024 edition",
            "Sustainable gardening tools and compost solutions",
            "Sensitive skin-friendly facial cleansers and toners",
            "Organic food wraps and storage solutions",
            "Yoga mats made from recycled materials",
        ],
        top_n=3,
        score_threshold=-0.75,
        user="abc-123",
    )
    assert isinstance(result, RerankResult)
    assert len(result.docs) == 3
    assert all(isinstance(doc, RerankDocument) for doc in result.docs)

View File

@@ -0,0 +1,55 @@
import os
from pathlib import Path
import pytest
from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.model_providers.gpustack.speech2text.speech2text import GPUStackSpeech2TextModel
def test_validate_credentials():
    """Credential validation fails for a bad endpoint/key and passes with env-provided ones."""
    model = GPUStackSpeech2TextModel()
    with pytest.raises(CredentialsValidateFailedError):
        model.validate_credentials(
            model="faster-whisper-medium",
            credentials={
                "endpoint_url": "invalid_url",
                "api_key": "invalid_api_key",
            },
        )
    model.validate_credentials(
        model="faster-whisper-medium",
        credentials={
            "endpoint_url": os.environ.get("GPUSTACK_SERVER_URL"),
            "api_key": os.environ.get("GPUSTACK_API_KEY"),
        },
    )
def test_invoke_model():
    """Transcribe the bundled assets/audio.mp3 and check the exact expected transcript."""
    model = GPUStackSpeech2TextModel()
    # Get the directory of the current file
    current_dir = os.path.dirname(os.path.abspath(__file__))
    # Get assets directory
    assets_dir = os.path.join(os.path.dirname(current_dir), "assets")
    # Construct the path to the audio file
    audio_file_path = os.path.join(assets_dir, "audio.mp3")
    file = Path(audio_file_path).read_bytes()
    result = model.invoke(
        model="faster-whisper-medium",
        credentials={
            "endpoint_url": os.environ.get("GPUSTACK_SERVER_URL"),
            "api_key": os.environ.get("GPUSTACK_API_KEY"),
        },
        file=file,
    )
    assert isinstance(result, str)
    assert result == "1, 2, 3, 4, 5, 6, 7, 8, 9, 10"

View File

@@ -0,0 +1,24 @@
import os
from core.model_runtime.model_providers.gpustack.tts.tts import GPUStackText2SpeechModel
def test_invoke_model():
    """Synthesize speech from a short text and check non-empty audio is streamed back."""
    tts_model = GPUStackText2SpeechModel()
    stream = tts_model.invoke(
        model="cosyvoice-300m-sft",
        tenant_id="test",
        credentials={
            "endpoint_url": os.environ.get("GPUSTACK_SERVER_URL"),
            "api_key": os.environ.get("GPUSTACK_API_KEY"),
        },
        content_text="Hello world",
        voice="Chinese Female",
    )
    audio = b"".join(stream)
    assert audio != b""

View File

@@ -0,0 +1,278 @@
import os
from collections.abc import Generator
import pytest
from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta
from core.model_runtime.entities.message_entities import AssistantPromptMessage, UserPromptMessage
from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.model_providers.huggingface_hub.llm.llm import HuggingfaceHubLargeLanguageModel
from tests.integration_tests.model_runtime.__mock.huggingface import setup_huggingface_mock
@pytest.mark.skip  # NOTE(review): permanently skipped — presumably the hosted endpoint is flaky; confirm
@pytest.mark.parametrize("setup_huggingface_mock", [["none"]], indirect=True)
def test_hosted_inference_api_validate_credentials(setup_huggingface_mock):
    """Hosted-inference-API validation rejects a bad token or bad model, accepts env token."""
    model = HuggingfaceHubLargeLanguageModel()
    with pytest.raises(CredentialsValidateFailedError):
        model.validate_credentials(
            model="HuggingFaceH4/zephyr-7b-beta",
            credentials={"huggingfacehub_api_type": "hosted_inference_api", "huggingfacehub_api_token": "invalid_key"},
        )
    with pytest.raises(CredentialsValidateFailedError):
        model.validate_credentials(
            model="fake-model",
            credentials={"huggingfacehub_api_type": "hosted_inference_api", "huggingfacehub_api_token": "invalid_key"},
        )
    model.validate_credentials(
        model="HuggingFaceH4/zephyr-7b-beta",
        credentials={
            "huggingfacehub_api_type": "hosted_inference_api",
            "huggingfacehub_api_token": os.environ.get("HUGGINGFACE_API_KEY"),
        },
    )
@pytest.mark.parametrize("setup_huggingface_mock", [["none"]], indirect=True)
def test_hosted_inference_api_invoke_model(setup_huggingface_mock):
    """Non-streaming invoke through the hosted inference API returns non-empty content."""
    model = HuggingfaceHubLargeLanguageModel()
    response = model.invoke(
        model="HuggingFaceH4/zephyr-7b-beta",
        credentials={
            "huggingfacehub_api_type": "hosted_inference_api",
            "huggingfacehub_api_token": os.environ.get("HUGGINGFACE_API_KEY"),
        },
        prompt_messages=[UserPromptMessage(content="Who are you?")],
        model_parameters={
            "temperature": 1.0,
            "top_k": 2,
            "top_p": 0.5,
        },
        stop=["How"],
        stream=False,
        user="abc-123",
    )
    assert isinstance(response, LLMResult)
    assert len(response.message.content) > 0
@pytest.mark.parametrize("setup_huggingface_mock", [["none"]], indirect=True)
def test_hosted_inference_api_invoke_stream_model(setup_huggingface_mock):
    """Streaming invoke through the hosted inference API yields well-formed chunks."""
    model = HuggingfaceHubLargeLanguageModel()
    response = model.invoke(
        model="HuggingFaceH4/zephyr-7b-beta",
        credentials={
            "huggingfacehub_api_type": "hosted_inference_api",
            "huggingfacehub_api_token": os.environ.get("HUGGINGFACE_API_KEY"),
        },
        prompt_messages=[UserPromptMessage(content="Who are you?")],
        model_parameters={
            "temperature": 1.0,
            "top_k": 2,
            "top_p": 0.5,
        },
        stop=["How"],
        stream=True,
        user="abc-123",
    )
    assert isinstance(response, Generator)
    for chunk in response:
        assert isinstance(chunk, LLMResultChunk)
        assert isinstance(chunk.delta, LLMResultChunkDelta)
        assert isinstance(chunk.delta.message, AssistantPromptMessage)
        # non-final chunks must carry content; the final chunk may be empty
        assert len(chunk.delta.message.content) > 0 if chunk.delta.finish_reason is None else True
@pytest.mark.parametrize("setup_huggingface_mock", [["none"]], indirect=True)
def test_inference_endpoints_text_generation_validate_credentials(setup_huggingface_mock):
    """Inference-endpoint (text-generation) validation rejects a bad token, accepts env token."""
    model = HuggingfaceHubLargeLanguageModel()
    with pytest.raises(CredentialsValidateFailedError):
        model.validate_credentials(
            model="openchat/openchat_3.5",
            credentials={
                "huggingfacehub_api_type": "inference_endpoints",
                "huggingfacehub_api_token": "invalid_key",
                "huggingfacehub_endpoint_url": os.environ.get("HUGGINGFACE_TEXT_GEN_ENDPOINT_URL"),
                "task_type": "text-generation",
            },
        )
    model.validate_credentials(
        model="openchat/openchat_3.5",
        credentials={
            "huggingfacehub_api_type": "inference_endpoints",
            "huggingfacehub_api_token": os.environ.get("HUGGINGFACE_API_KEY"),
            "huggingfacehub_endpoint_url": os.environ.get("HUGGINGFACE_TEXT_GEN_ENDPOINT_URL"),
            "task_type": "text-generation",
        },
    )
@pytest.mark.parametrize("setup_huggingface_mock", [["none"]], indirect=True)
def test_inference_endpoints_text_generation_invoke_model(setup_huggingface_mock):
    """Non-streaming invoke against a text-generation inference endpoint returns content."""
    model = HuggingfaceHubLargeLanguageModel()
    response = model.invoke(
        model="openchat/openchat_3.5",
        credentials={
            "huggingfacehub_api_type": "inference_endpoints",
            "huggingfacehub_api_token": os.environ.get("HUGGINGFACE_API_KEY"),
            "huggingfacehub_endpoint_url": os.environ.get("HUGGINGFACE_TEXT_GEN_ENDPOINT_URL"),
            "task_type": "text-generation",
        },
        prompt_messages=[UserPromptMessage(content="Who are you?")],
        model_parameters={
            "temperature": 1.0,
            "top_k": 2,
            "top_p": 0.5,
        },
        stop=["How"],
        stream=False,
        user="abc-123",
    )
    assert isinstance(response, LLMResult)
    assert len(response.message.content) > 0
@pytest.mark.parametrize("setup_huggingface_mock", [["none"]], indirect=True)
def test_inference_endpoints_text_generation_invoke_stream_model(setup_huggingface_mock):
    """Streaming invoke against a text-generation inference endpoint yields chunks."""
    model = HuggingfaceHubLargeLanguageModel()
    response = model.invoke(
        model="openchat/openchat_3.5",
        credentials={
            "huggingfacehub_api_type": "inference_endpoints",
            "huggingfacehub_api_token": os.environ.get("HUGGINGFACE_API_KEY"),
            "huggingfacehub_endpoint_url": os.environ.get("HUGGINGFACE_TEXT_GEN_ENDPOINT_URL"),
            "task_type": "text-generation",
        },
        prompt_messages=[UserPromptMessage(content="Who are you?")],
        model_parameters={
            "temperature": 1.0,
            "top_k": 2,
            "top_p": 0.5,
        },
        stop=["How"],
        stream=True,
        user="abc-123",
    )
    assert isinstance(response, Generator)
    for chunk in response:
        assert isinstance(chunk, LLMResultChunk)
        assert isinstance(chunk.delta, LLMResultChunkDelta)
        assert isinstance(chunk.delta.message, AssistantPromptMessage)
        assert len(chunk.delta.message.content) > 0 if chunk.delta.finish_reason is None else True
@pytest.mark.parametrize("setup_huggingface_mock", [["none"]], indirect=True)
def test_inference_endpoints_text2text_generation_validate_credentials(setup_huggingface_mock):
    """Inference-endpoint (text2text-generation) validation rejects a bad token, accepts env token."""
    model = HuggingfaceHubLargeLanguageModel()
    with pytest.raises(CredentialsValidateFailedError):
        model.validate_credentials(
            model="google/mt5-base",
            credentials={
                "huggingfacehub_api_type": "inference_endpoints",
                "huggingfacehub_api_token": "invalid_key",
                "huggingfacehub_endpoint_url": os.environ.get("HUGGINGFACE_TEXT2TEXT_GEN_ENDPOINT_URL"),
                "task_type": "text2text-generation",
            },
        )
    model.validate_credentials(
        model="google/mt5-base",
        credentials={
            "huggingfacehub_api_type": "inference_endpoints",
            "huggingfacehub_api_token": os.environ.get("HUGGINGFACE_API_KEY"),
            "huggingfacehub_endpoint_url": os.environ.get("HUGGINGFACE_TEXT2TEXT_GEN_ENDPOINT_URL"),
            "task_type": "text2text-generation",
        },
    )
@pytest.mark.parametrize("setup_huggingface_mock", [["none"]], indirect=True)
def test_inference_endpoints_text2text_generation_invoke_model(setup_huggingface_mock):
    """Non-streaming invoke against a text2text-generation inference endpoint returns content."""
    model = HuggingfaceHubLargeLanguageModel()
    response = model.invoke(
        model="google/mt5-base",
        credentials={
            "huggingfacehub_api_type": "inference_endpoints",
            "huggingfacehub_api_token": os.environ.get("HUGGINGFACE_API_KEY"),
            "huggingfacehub_endpoint_url": os.environ.get("HUGGINGFACE_TEXT2TEXT_GEN_ENDPOINT_URL"),
            "task_type": "text2text-generation",
        },
        prompt_messages=[UserPromptMessage(content="Who are you?")],
        model_parameters={
            "temperature": 1.0,
            "top_k": 2,
            "top_p": 0.5,
        },
        stop=["How"],
        stream=False,
        user="abc-123",
    )
    assert isinstance(response, LLMResult)
    assert len(response.message.content) > 0
@pytest.mark.parametrize("setup_huggingface_mock", [["none"]], indirect=True)
def test_inference_endpoints_text2text_generation_invoke_stream_model(setup_huggingface_mock):
    """Streaming invoke against a text2text-generation inference endpoint yields chunks."""
    model = HuggingfaceHubLargeLanguageModel()
    response = model.invoke(
        model="google/mt5-base",
        credentials={
            "huggingfacehub_api_type": "inference_endpoints",
            "huggingfacehub_api_token": os.environ.get("HUGGINGFACE_API_KEY"),
            "huggingfacehub_endpoint_url": os.environ.get("HUGGINGFACE_TEXT2TEXT_GEN_ENDPOINT_URL"),
            "task_type": "text2text-generation",
        },
        prompt_messages=[UserPromptMessage(content="Who are you?")],
        model_parameters={
            "temperature": 1.0,
            "top_k": 2,
            "top_p": 0.5,
        },
        stop=["How"],
        stream=True,
        user="abc-123",
    )
    assert isinstance(response, Generator)
    for chunk in response:
        assert isinstance(chunk, LLMResultChunk)
        assert isinstance(chunk.delta, LLMResultChunkDelta)
        assert isinstance(chunk.delta.message, AssistantPromptMessage)
        assert len(chunk.delta.message.content) > 0 if chunk.delta.finish_reason is None else True
def test_get_num_tokens():
    """Token counting for a single user message yields the expected fixed count."""
    model = HuggingfaceHubLargeLanguageModel()
    num_tokens = model.get_num_tokens(
        model="google/mt5-base",
        credentials={
            "huggingfacehub_api_type": "inference_endpoints",
            "huggingfacehub_api_token": os.environ.get("HUGGINGFACE_API_KEY"),
            "huggingfacehub_endpoint_url": os.environ.get("HUGGINGFACE_TEXT2TEXT_GEN_ENDPOINT_URL"),
            "task_type": "text2text-generation",
        },
        prompt_messages=[UserPromptMessage(content="Hello World!")],
    )
    assert num_tokens == 7

View File

@@ -0,0 +1,112 @@
import os
import pytest
from core.model_runtime.entities.text_embedding_entities import TextEmbeddingResult
from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.model_providers.huggingface_hub.text_embedding.text_embedding import (
HuggingfaceHubTextEmbeddingModel,
)
def test_hosted_inference_api_validate_credentials():
    """Hosted-inference-API validation rejects a bad token, accepts the env token."""
    model = HuggingfaceHubTextEmbeddingModel()
    with pytest.raises(CredentialsValidateFailedError):
        model.validate_credentials(
            model="facebook/bart-base",
            credentials={
                "huggingfacehub_api_type": "hosted_inference_api",
                "huggingfacehub_api_token": "invalid_key",
            },
        )
    model.validate_credentials(
        model="facebook/bart-base",
        credentials={
            "huggingfacehub_api_type": "hosted_inference_api",
            "huggingfacehub_api_token": os.environ.get("HUGGINGFACE_API_KEY"),
        },
    )
def test_hosted_inference_api_invoke_model():
    """Embed two texts via the hosted inference API and check count and usage."""
    model = HuggingfaceHubTextEmbeddingModel()
    result = model.invoke(
        model="facebook/bart-base",
        credentials={
            "huggingfacehub_api_type": "hosted_inference_api",
            "huggingfacehub_api_token": os.environ.get("HUGGINGFACE_API_KEY"),
        },
        texts=["hello", "world"],
    )
    assert isinstance(result, TextEmbeddingResult)
    assert len(result.embeddings) == 2
    assert result.usage.total_tokens == 2
def test_inference_endpoints_validate_credentials():
    """Inference-endpoint validation rejects a bad token, accepts the env token."""
    model = HuggingfaceHubTextEmbeddingModel()
    with pytest.raises(CredentialsValidateFailedError):
        model.validate_credentials(
            model="all-MiniLM-L6-v2",
            credentials={
                "huggingfacehub_api_type": "inference_endpoints",
                "huggingfacehub_api_token": "invalid_key",
                "huggingface_namespace": "Dify-AI",
                "huggingfacehub_endpoint_url": os.environ.get("HUGGINGFACE_EMBEDDINGS_ENDPOINT_URL"),
                "task_type": "feature-extraction",
            },
        )
    model.validate_credentials(
        model="all-MiniLM-L6-v2",
        credentials={
            "huggingfacehub_api_type": "inference_endpoints",
            "huggingfacehub_api_token": os.environ.get("HUGGINGFACE_API_KEY"),
            "huggingface_namespace": "Dify-AI",
            "huggingfacehub_endpoint_url": os.environ.get("HUGGINGFACE_EMBEDDINGS_ENDPOINT_URL"),
            "task_type": "feature-extraction",
        },
    )
def test_inference_endpoints_invoke_model():
    """Embed two texts via an inference endpoint; this path reports zero token usage."""
    model = HuggingfaceHubTextEmbeddingModel()
    result = model.invoke(
        model="all-MiniLM-L6-v2",
        credentials={
            "huggingfacehub_api_type": "inference_endpoints",
            "huggingfacehub_api_token": os.environ.get("HUGGINGFACE_API_KEY"),
            "huggingface_namespace": "Dify-AI",
            "huggingfacehub_endpoint_url": os.environ.get("HUGGINGFACE_EMBEDDINGS_ENDPOINT_URL"),
            "task_type": "feature-extraction",
        },
        texts=["hello", "world"],
    )
    assert isinstance(result, TextEmbeddingResult)
    assert len(result.embeddings) == 2
    assert result.usage.total_tokens == 0
def test_get_num_tokens():
    """Token counting for two one-word texts returns 2."""
    model = HuggingfaceHubTextEmbeddingModel()
    num_tokens = model.get_num_tokens(
        model="all-MiniLM-L6-v2",
        credentials={
            "huggingfacehub_api_type": "inference_endpoints",
            "huggingfacehub_api_token": os.environ.get("HUGGINGFACE_API_KEY"),
            "huggingface_namespace": "Dify-AI",
            "huggingfacehub_endpoint_url": os.environ.get("HUGGINGFACE_EMBEDDINGS_ENDPOINT_URL"),
            "task_type": "feature-extraction",
        },
        texts=["hello", "world"],
    )
    assert num_tokens == 2

View File

@@ -0,0 +1,73 @@
import os
import pytest
from core.model_runtime.entities.text_embedding_entities import TextEmbeddingResult
from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.model_providers.huggingface_tei.text_embedding.text_embedding import (
    HuggingfaceTeiTextEmbeddingModel,
    TeiHelper,
)
from tests.integration_tests.model_runtime.__mock.huggingface_tei import MockTEIClass
# When MOCK_SWITCH=true, TeiHelper calls are monkeypatched with MockTEIClass below.
MOCK = os.getenv("MOCK_SWITCH", "false").lower() == "true"
@pytest.fixture
def setup_tei_mock(request, monkeypatch: pytest.MonkeyPatch):
    """Patch the TeiHelper endpoints with mocks while MOCK is enabled; undo afterwards."""
    if MOCK:
        monkeypatch.setattr(TeiHelper, "get_tei_extra_parameter", MockTEIClass.get_tei_extra_parameter)
        monkeypatch.setattr(TeiHelper, "invoke_tokenize", MockTEIClass.invoke_tokenize)
        monkeypatch.setattr(TeiHelper, "invoke_embeddings", MockTEIClass.invoke_embeddings)
        monkeypatch.setattr(TeiHelper, "invoke_rerank", MockTEIClass.invoke_rerank)
    yield
    if MOCK:
        monkeypatch.undo()
@pytest.mark.parametrize("setup_tei_mock", [["none"]], indirect=True)
def test_validate_credentials(setup_tei_mock):
    """Validation rejects a wrong model type (mock only) and accepts the embedding model."""
    model = HuggingfaceTeiTextEmbeddingModel()
    # model name is only used in mock
    model_name = "embedding"
    if MOCK:
        # TEI Provider will check model type by API endpoint, at real server, the model type is correct.
        # So we dont need to check model type here. Only check in mock
        with pytest.raises(CredentialsValidateFailedError):
            model.validate_credentials(
                model="reranker",
                credentials={
                    "server_url": os.environ.get("TEI_EMBEDDING_SERVER_URL", ""),
                    "api_key": os.environ.get("TEI_API_KEY", ""),
                },
            )
    model.validate_credentials(
        model=model_name,
        credentials={
            "server_url": os.environ.get("TEI_EMBEDDING_SERVER_URL", ""),
            "api_key": os.environ.get("TEI_API_KEY", ""),
        },
    )
@pytest.mark.parametrize("setup_tei_mock", [["none"]], indirect=True)
def test_invoke_model(setup_tei_mock):
    """Embed two texts (mocked or live) and check embedding count and token usage."""
    model = HuggingfaceTeiTextEmbeddingModel()
    model_name = "embedding"
    result = model.invoke(
        model=model_name,
        credentials={
            "server_url": os.environ.get("TEI_EMBEDDING_SERVER_URL", ""),
            "api_key": os.environ.get("TEI_API_KEY", ""),
        },
        texts=["hello", "world"],
        user="abc-123",
    )
    assert isinstance(result, TextEmbeddingResult)
    assert len(result.embeddings) == 2
    assert result.usage.total_tokens > 0

View File

@@ -0,0 +1,80 @@
import os
import pytest
from core.model_runtime.entities.rerank_entities import RerankResult
from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.model_providers.huggingface_tei.rerank.rerank import (
    HuggingfaceTeiRerankModel,
)
from core.model_runtime.model_providers.huggingface_tei.text_embedding.text_embedding import TeiHelper
from tests.integration_tests.model_runtime.__mock.huggingface_tei import MockTEIClass
# When MOCK_SWITCH=true, TeiHelper calls are monkeypatched with MockTEIClass by the fixture below.
MOCK = os.getenv("MOCK_SWITCH", "false").lower() == "true"
@pytest.fixture
def setup_tei_mock(request, monkeypatch: pytest.MonkeyPatch):
    """Patch the TeiHelper endpoints with mocks while MOCK is enabled; undo afterwards."""
    if MOCK:
        monkeypatch.setattr(TeiHelper, "get_tei_extra_parameter", MockTEIClass.get_tei_extra_parameter)
        monkeypatch.setattr(TeiHelper, "invoke_tokenize", MockTEIClass.invoke_tokenize)
        monkeypatch.setattr(TeiHelper, "invoke_embeddings", MockTEIClass.invoke_embeddings)
        monkeypatch.setattr(TeiHelper, "invoke_rerank", MockTEIClass.invoke_rerank)
    yield
    if MOCK:
        monkeypatch.undo()
@pytest.mark.parametrize("setup_tei_mock", [["none"]], indirect=True)
def test_validate_credentials(setup_tei_mock):
    """Validation rejects a wrong model type (mock only) and accepts the reranker model.

    Fix: supply an empty-string default for TEI_RERANK_SERVER_URL, matching the
    sibling TEI embedding tests, so an unset env var passes "" instead of None.
    """
    model = HuggingfaceTeiRerankModel()
    # model name is only used in mock
    model_name = "reranker"
    if MOCK:
        # TEI Provider will check model type by API endpoint, at real server, the model type is correct.
        # So we dont need to check model type here. Only check in mock
        with pytest.raises(CredentialsValidateFailedError):
            model.validate_credentials(
                model="embedding",
                credentials={
                    "server_url": os.environ.get("TEI_RERANK_SERVER_URL", ""),
                    "api_key": os.environ.get("TEI_API_KEY", ""),
                },
            )
    model.validate_credentials(
        model=model_name,
        credentials={
            "server_url": os.environ.get("TEI_RERANK_SERVER_URL", ""),
            "api_key": os.environ.get("TEI_API_KEY", ""),
        },
    )
@pytest.mark.parametrize("setup_tei_mock", [["none"]], indirect=True)
def test_invoke_model(setup_tei_mock):
    """Rerank three docs and expect exactly one above the 0.8 score threshold.

    Fix: supply an empty-string default for TEI_RERANK_SERVER_URL, matching the
    sibling TEI embedding tests, so an unset env var passes "" instead of None.
    """
    model = HuggingfaceTeiRerankModel()
    # model name is only used in mock
    model_name = "reranker"
    result = model.invoke(
        model=model_name,
        credentials={
            "server_url": os.environ.get("TEI_RERANK_SERVER_URL", ""),
            "api_key": os.environ.get("TEI_API_KEY", ""),
        },
        query="Who is Kasumi?",
        docs=[
            'Kasumi is a girl\'s name of Japanese origin meaning "mist".',
            "Her music is a kawaii bass, a mix of future bass, pop, and kawaii music ",
            "and she leads a team named PopiParty.",
        ],
        score_threshold=0.8,
    )
    assert isinstance(result, RerankResult)
    assert len(result.docs) == 1
    assert result.docs[0].index == 0
    assert result.docs[0].score >= 0.8

View File

@@ -0,0 +1,90 @@
import os
from collections.abc import Generator
import pytest
from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta
from core.model_runtime.entities.message_entities import AssistantPromptMessage, SystemPromptMessage, UserPromptMessage
from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.model_providers.hunyuan.llm.llm import HunyuanLargeLanguageModel
def test_validate_credentials():
    """Credential validation fails for bad secrets and passes with env-provided ones."""
    model = HunyuanLargeLanguageModel()
    with pytest.raises(CredentialsValidateFailedError):
        model.validate_credentials(
            model="hunyuan-standard", credentials={"secret_id": "invalid_key", "secret_key": "invalid_key"}
        )
    model.validate_credentials(
        model="hunyuan-standard",
        credentials={
            "secret_id": os.environ.get("HUNYUAN_SECRET_ID"),
            "secret_key": os.environ.get("HUNYUAN_SECRET_KEY"),
        },
    )
def test_invoke_model():
    """Non-streaming invoke returns a non-empty LLMResult."""
    model = HunyuanLargeLanguageModel()
    response = model.invoke(
        model="hunyuan-standard",
        credentials={
            "secret_id": os.environ.get("HUNYUAN_SECRET_ID"),
            "secret_key": os.environ.get("HUNYUAN_SECRET_KEY"),
        },
        prompt_messages=[UserPromptMessage(content="Hi")],
        model_parameters={"temperature": 0.5, "max_tokens": 10},
        stop=["How"],
        stream=False,
        user="abc-123",
    )
    assert isinstance(response, LLMResult)
    assert len(response.message.content) > 0
def test_invoke_stream_model():
    """Streaming invoke yields well-formed LLMResultChunk items."""
    model = HunyuanLargeLanguageModel()
    response = model.invoke(
        model="hunyuan-standard",
        credentials={
            "secret_id": os.environ.get("HUNYUAN_SECRET_ID"),
            "secret_key": os.environ.get("HUNYUAN_SECRET_KEY"),
        },
        prompt_messages=[UserPromptMessage(content="Hi")],
        model_parameters={"temperature": 0.5, "max_tokens": 100, "seed": 1234},
        stream=True,
        user="abc-123",
    )
    assert isinstance(response, Generator)
    for chunk in response:
        assert isinstance(chunk, LLMResultChunk)
        assert isinstance(chunk.delta, LLMResultChunkDelta)
        assert isinstance(chunk.delta.message, AssistantPromptMessage)
        # non-final chunks must carry content; the final chunk may be empty
        assert len(chunk.delta.message.content) > 0 if chunk.delta.finish_reason is None else True
def test_get_num_tokens():
    """Token counting for a system + user message pair yields the expected count."""
    model = HunyuanLargeLanguageModel()
    num_tokens = model.get_num_tokens(
        model="hunyuan-standard",
        credentials={
            "secret_id": os.environ.get("HUNYUAN_SECRET_ID"),
            "secret_key": os.environ.get("HUNYUAN_SECRET_KEY"),
        },
        prompt_messages=[
            SystemPromptMessage(
                content="You are a helpful AI assistant.",
            ),
            UserPromptMessage(content="Hello World!"),
        ],
    )
    assert num_tokens == 14

View File

@@ -0,0 +1,20 @@
import os
import pytest
from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.model_providers.hunyuan.hunyuan import HunyuanProvider
def test_validate_provider_credentials():
    """Reject bogus Hunyuan secrets, then accept the pair from the environment."""
    provider = HunyuanProvider()
    bogus = {"secret_id": "invalid_key", "secret_key": "invalid_key"}
    with pytest.raises(CredentialsValidateFailedError):
        provider.validate_provider_credentials(credentials=bogus)
    env_credentials = {
        "secret_id": os.environ.get("HUNYUAN_SECRET_ID"),
        "secret_key": os.environ.get("HUNYUAN_SECRET_KEY"),
    }
    provider.validate_provider_credentials(credentials=env_credentials)

View File

@@ -0,0 +1,96 @@
import os
import pytest
from core.model_runtime.entities.text_embedding_entities import TextEmbeddingResult
from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.model_providers.hunyuan.text_embedding.text_embedding import HunyuanTextEmbeddingModel
def test_validate_credentials():
    """Credential validation fails for bad secrets and passes with env-provided ones."""
    model = HunyuanTextEmbeddingModel()
    with pytest.raises(CredentialsValidateFailedError):
        model.validate_credentials(
            model="hunyuan-embedding", credentials={"secret_id": "invalid_key", "secret_key": "invalid_key"}
        )
    model.validate_credentials(
        model="hunyuan-embedding",
        credentials={
            "secret_id": os.environ.get("HUNYUAN_SECRET_ID"),
            "secret_key": os.environ.get("HUNYUAN_SECRET_KEY"),
        },
    )
def test_invoke_model():
    """Embed two texts and check embedding count and token usage."""
    model = HunyuanTextEmbeddingModel()
    result = model.invoke(
        model="hunyuan-embedding",
        credentials={
            "secret_id": os.environ.get("HUNYUAN_SECRET_ID"),
            "secret_key": os.environ.get("HUNYUAN_SECRET_KEY"),
        },
        texts=["hello", "world"],
        user="abc-123",
    )
    assert isinstance(result, TextEmbeddingResult)
    assert len(result.embeddings) == 2
    assert result.usage.total_tokens == 6
def test_get_num_tokens():
    """Token counting for two one-word texts returns 2."""
    model = HunyuanTextEmbeddingModel()
    num_tokens = model.get_num_tokens(
        model="hunyuan-embedding",
        credentials={
            "secret_id": os.environ.get("HUNYUAN_SECRET_ID"),
            "secret_key": os.environ.get("HUNYUAN_SECRET_KEY"),
        },
        texts=["hello", "world"],
    )
    assert num_tokens == 2
def test_max_chunks():
    """22 inputs exceed a single batch — presumably max_chunks is 20, forcing a split; confirm."""
    model = HunyuanTextEmbeddingModel()
    result = model.invoke(
        model="hunyuan-embedding",
        credentials={
            "secret_id": os.environ.get("HUNYUAN_SECRET_ID"),
            "secret_key": os.environ.get("HUNYUAN_SECRET_KEY"),
        },
        texts=[
            "hello",
            "world",
            "hello",
            "world",
            "hello",
            "world",
            "hello",
            "world",
            "hello",
            "world",
            "hello",
            "world",
            "hello",
            "world",
            "hello",
            "world",
            "hello",
            "world",
            "hello",
            "world",
            "hello",
            "world",
        ],
    )
    assert isinstance(result, TextEmbeddingResult)
    assert len(result.embeddings) == 22

View File

@@ -0,0 +1,15 @@
import os
import pytest
from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.model_providers.jina.jina import JinaProvider
def test_validate_provider_credentials():
    """Reject a bogus Jina API key, then accept the key from the environment."""
    provider = JinaProvider()
    bogus_credentials = {"api_key": "hahahaha"}
    with pytest.raises(CredentialsValidateFailedError):
        provider.validate_provider_credentials(credentials=bogus_credentials)
    env_credentials = {"api_key": os.environ.get("JINA_API_KEY")}
    provider.validate_provider_credentials(credentials=env_credentials)

View File

@@ -0,0 +1,49 @@
import os
import pytest
from core.model_runtime.entities.text_embedding_entities import TextEmbeddingResult
from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.model_providers.jina.text_embedding.text_embedding import JinaTextEmbeddingModel
def test_validate_credentials():
    """Reject an invalid Jina API key, then accept the key from the environment."""
    embedding_model = JinaTextEmbeddingModel()
    with pytest.raises(CredentialsValidateFailedError):
        embedding_model.validate_credentials(model="jina-embeddings-v2-base-en", credentials={"api_key": "invalid_key"})
    env_credentials = {"api_key": os.environ.get("JINA_API_KEY")}
    embedding_model.validate_credentials(model="jina-embeddings-v2-base-en", credentials=env_credentials)
def test_invoke_model():
    """Embed two texts and check embedding count and reported token usage."""
    embedding_model = JinaTextEmbeddingModel()
    env_credentials = {
        "api_key": os.environ.get("JINA_API_KEY"),
    }
    result = embedding_model.invoke(
        model="jina-embeddings-v2-base-en",
        credentials=env_credentials,
        texts=["hello", "world"],
        user="abc-123",
    )
    assert isinstance(result, TextEmbeddingResult)
    assert len(result.embeddings) == 2
    assert result.usage.total_tokens == 6
def test_get_num_tokens():
    """Token counting for two single-word texts should total 6."""
    embedding_model = JinaTextEmbeddingModel()

    token_count = embedding_model.get_num_tokens(
        model="jina-embeddings-v2-base-en",
        credentials={"api_key": os.environ.get("JINA_API_KEY")},
        texts=["hello", "world"],
    )

    assert token_count == 6

View File

@@ -0,0 +1,4 @@
"""
LocalAI Embedding Interface is temporarily unavailable due to
we could not find a way to test it for now.
"""

View File

@@ -0,0 +1,172 @@
import os
from collections.abc import Generator
import pytest
from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta
from core.model_runtime.entities.message_entities import (
AssistantPromptMessage,
PromptMessageTool,
SystemPromptMessage,
UserPromptMessage,
)
from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.model_providers.localai.llm.llm import LocalAILanguageModel
def test_validate_credentials_for_chat_model():
    """Credential check: an unreachable server URL fails, the env-provided one passes."""
    language_model = LocalAILanguageModel()

    bad_credentials = {
        "server_url": "hahahaha",
        "completion_type": "completion",
    }
    with pytest.raises(CredentialsValidateFailedError):
        language_model.validate_credentials(model="chinese-llama-2-7b", credentials=bad_credentials)

    good_credentials = {
        "server_url": os.environ.get("LOCALAI_SERVER_URL"),
        "completion_type": "completion",
    }
    language_model.validate_credentials(model="chinese-llama-2-7b", credentials=good_credentials)
def test_invoke_completion_model():
    """Non-streaming completion returns an LLMResult with content and token usage."""
    language_model = LocalAILanguageModel()
    credentials = {
        "server_url": os.environ.get("LOCALAI_SERVER_URL"),
        "completion_type": "completion",
    }

    result = language_model.invoke(
        model="chinese-llama-2-7b",
        credentials=credentials,
        prompt_messages=[UserPromptMessage(content="ping")],
        model_parameters={"temperature": 0.7, "top_p": 1.0, "max_tokens": 10},
        stop=[],
        user="abc-123",
        stream=False,
    )

    assert isinstance(result, LLMResult)
    assert len(result.message.content) > 0
    assert result.usage.total_tokens > 0
def test_invoke_chat_model():
    """Non-streaming chat completion returns an LLMResult with content and usage."""
    language_model = LocalAILanguageModel()
    credentials = {
        "server_url": os.environ.get("LOCALAI_SERVER_URL"),
        "completion_type": "chat_completion",
    }

    result = language_model.invoke(
        model="chinese-llama-2-7b",
        credentials=credentials,
        prompt_messages=[UserPromptMessage(content="ping")],
        model_parameters={"temperature": 0.7, "top_p": 1.0, "max_tokens": 10},
        stop=[],
        user="abc-123",
        stream=False,
    )

    assert isinstance(result, LLMResult)
    assert len(result.message.content) > 0
    assert result.usage.total_tokens > 0
def test_invoke_stream_completion_model():
    """Streaming completion yields chunks whose deltas carry assistant content."""
    language_model = LocalAILanguageModel()
    credentials = {
        "server_url": os.environ.get("LOCALAI_SERVER_URL"),
        "completion_type": "completion",
    }

    stream = language_model.invoke(
        model="chinese-llama-2-7b",
        credentials=credentials,
        prompt_messages=[UserPromptMessage(content="Hello World!")],
        model_parameters={"temperature": 0.7, "top_p": 1.0, "max_tokens": 10},
        stop=["you"],
        stream=True,
        user="abc-123",
    )

    assert isinstance(stream, Generator)
    for chunk in stream:
        assert isinstance(chunk, LLMResultChunk)
        assert isinstance(chunk.delta, LLMResultChunkDelta)
        assert isinstance(chunk.delta.message, AssistantPromptMessage)
        # The final chunk (with a finish_reason) may legitimately have empty content.
        assert len(chunk.delta.message.content) > 0 if chunk.delta.finish_reason is None else True
def test_invoke_stream_chat_model():
    """Streaming chat completion yields chunks whose deltas carry assistant content."""
    language_model = LocalAILanguageModel()
    credentials = {
        "server_url": os.environ.get("LOCALAI_SERVER_URL"),
        "completion_type": "chat_completion",
    }

    stream = language_model.invoke(
        model="chinese-llama-2-7b",
        credentials=credentials,
        prompt_messages=[UserPromptMessage(content="Hello World!")],
        model_parameters={"temperature": 0.7, "top_p": 1.0, "max_tokens": 10},
        stop=["you"],
        stream=True,
        user="abc-123",
    )

    assert isinstance(stream, Generator)
    for chunk in stream:
        assert isinstance(chunk, LLMResultChunk)
        assert isinstance(chunk.delta, LLMResultChunkDelta)
        assert isinstance(chunk.delta.message, AssistantPromptMessage)
        # The final chunk (with a finish_reason) may legitimately have empty content.
        assert len(chunk.delta.message.content) > 0 if chunk.delta.finish_reason is None else True
def test_get_num_tokens():
    """Token counting with and without tool definitions (model name is not used)."""
    language_model = LocalAILanguageModel()
    credentials = {
        "server_url": os.environ.get("LOCALAI_SERVER_URL"),
        "completion_type": "chat_completion",
    }

    weather_tool = PromptMessageTool(
        name="get_current_weather",
        description="Get the current weather in a given location",
        parameters={
            "type": "object",
            "properties": {
                "location": {"type": "string", "description": "The city and state e.g. San Francisco, CA"},
                "unit": {"type": "string", "enum": ["c", "f"]},
            },
            "required": ["location"],
        },
    )

    tokens_with_tools = language_model.get_num_tokens(
        model="????",
        credentials=credentials,
        prompt_messages=[
            SystemPromptMessage(
                content="You are a helpful AI assistant.",
            ),
            UserPromptMessage(content="Hello World!"),
        ],
        tools=[weather_tool],
    )
    assert isinstance(tokens_with_tools, int)
    assert tokens_with_tools == 77

    tokens_without_tools = language_model.get_num_tokens(
        model="????",
        credentials=credentials,
        prompt_messages=[UserPromptMessage(content="Hello World!")],
    )
    assert isinstance(tokens_without_tools, int)
    assert tokens_without_tools == 10

View File

@@ -0,0 +1,96 @@
import os
import pytest
from core.model_runtime.entities.rerank_entities import RerankDocument, RerankResult
from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.model_providers.localai.rerank.rerank import LocalaiRerankModel
def test_validate_credentials_for_chat_model():
    """Rerank credential check: an unreachable server fails, the real one passes."""
    rerank_model = LocalaiRerankModel()

    bad_credentials = {
        "server_url": "hahahaha",
        "completion_type": "completion",
    }
    with pytest.raises(CredentialsValidateFailedError):
        rerank_model.validate_credentials(model="bge-reranker-v2-m3", credentials=bad_credentials)

    good_credentials = {
        "server_url": os.environ.get("LOCALAI_SERVER_URL"),
        "completion_type": "completion",
    }
    rerank_model.validate_credentials(model="bge-reranker-base", credentials=good_credentials)
def test_invoke_rerank_model():
    """Live rerank over nine candidate documents keeps the top three."""
    rerank_model = LocalaiRerankModel()
    candidate_docs = [
        "Eco-friendly kitchenware for modern homes",
        "Biodegradable cleaning supplies for eco-conscious consumers",
        "Organic cotton baby clothes for sensitive skin",
        "Natural organic skincare range for sensitive skin",
        "Tech gadgets for smart homes: 2024 edition",
        "Sustainable gardening tools and compost solutions",
        "Sensitive skin-friendly facial cleansers and toners",
        "Organic food wraps and storage solutions",
        "Yoga mats made from recycled materials",
    ]

    rerank_result = rerank_model.invoke(
        model="bge-reranker-base",
        credentials={"server_url": os.environ.get("LOCALAI_SERVER_URL")},
        query="Organic skincare products for sensitive skin",
        docs=candidate_docs,
        top_n=3,
        score_threshold=0.75,
        user="abc-123",
    )

    assert isinstance(rerank_result, RerankResult)
    assert len(rerank_result.docs) == 3
def test__invoke():
    """_invoke: empty docs give an empty result; a real doc list returns top_n docs."""
    rerank_model = LocalaiRerankModel()
    credentials = {"server_url": "https://example.com", "api_key": "1234567890"}
    query = "Organic skincare products for sensitive skin"

    # Empty candidate list -> empty RerankResult (presumably short-circuits
    # before any network call, given the placeholder server URL).
    empty_result = rerank_model._invoke(
        model="bge-reranker-base",
        credentials=credentials,
        query=query,
        docs=[],
        top_n=3,
        score_threshold=0.75,
        user="abc-123",
    )
    assert isinstance(empty_result, RerankResult)
    assert len(empty_result.docs) == 0

    # Regular invocation over nine candidates keeps the best three.
    candidate_docs = [
        "Eco-friendly kitchenware for modern homes",
        "Biodegradable cleaning supplies for eco-conscious consumers",
        "Organic cotton baby clothes for sensitive skin",
        "Natural organic skincare range for sensitive skin",
        "Tech gadgets for smart homes: 2024 edition",
        "Sustainable gardening tools and compost solutions",
        "Sensitive skin-friendly facial cleansers and toners",
        "Organic food wraps and storage solutions",
        "Yoga mats made from recycled materials",
    ]
    full_result = rerank_model._invoke(
        model="bge-reranker-base",
        credentials=credentials,
        query=query,
        docs=candidate_docs,
        top_n=3,
        score_threshold=0.75,
        user="abc-123",
    )
    assert isinstance(full_result, RerankResult)
    assert len(full_result.docs) == 3
    assert all(isinstance(doc, RerankDocument) for doc in full_result.docs)

View File

@@ -0,0 +1,42 @@
import os
import pytest
from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.model_providers.localai.speech2text.speech2text import LocalAISpeech2text
def test_validate_credentials():
    """whisper-1 credential check: a bad URL raises, the env-provided URL passes."""
    speech_model = LocalAISpeech2text()

    with pytest.raises(CredentialsValidateFailedError):
        speech_model.validate_credentials(model="whisper-1", credentials={"server_url": "invalid_url"})

    speech_model.validate_credentials(model="whisper-1", credentials={"server_url": os.environ.get("LOCALAI_SERVER_URL")})
def test_invoke_model():
    """Transcribing the bundled audio fixture yields the known digit string."""
    speech_model = LocalAISpeech2text()

    # The fixture lives in ../assets/audio.mp3 relative to this test file.
    current_dir = os.path.dirname(os.path.abspath(__file__))
    audio_file_path = os.path.join(os.path.dirname(current_dir), "assets", "audio.mp3")

    # Keep the file handle open for the duration of the invoke call.
    with open(audio_file_path, "rb") as audio_file:
        transcript = speech_model.invoke(
            model="whisper-1",
            credentials={"server_url": os.environ.get("LOCALAI_SERVER_URL")},
            file=audio_file,
            user="abc-123",
        )
        assert isinstance(transcript, str)
        assert transcript == "1, 2, 3, 4, 5, 6, 7, 8, 9, 10"

View File

@@ -0,0 +1,58 @@
import os
import pytest
from core.model_runtime.entities.text_embedding_entities import TextEmbeddingResult
from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.model_providers.minimax.text_embedding.text_embedding import MinimaxTextEmbeddingModel
def test_validate_credentials():
    """embo-01 credential check: a bad API key raises, env credentials pass."""
    embedding_model = MinimaxTextEmbeddingModel()

    with pytest.raises(CredentialsValidateFailedError):
        embedding_model.validate_credentials(
            model="embo-01",
            credentials={"minimax_api_key": "invalid_key", "minimax_group_id": os.environ.get("MINIMAX_GROUP_ID")},
        )

    valid_credentials = {
        "minimax_api_key": os.environ.get("MINIMAX_API_KEY"),
        "minimax_group_id": os.environ.get("MINIMAX_GROUP_ID"),
    }
    embedding_model.validate_credentials(model="embo-01", credentials=valid_credentials)
def test_invoke_model():
    """Embedding two texts returns two vectors and a 16-token usage total."""
    embedding_model = MinimaxTextEmbeddingModel()
    credentials = {
        "minimax_api_key": os.environ.get("MINIMAX_API_KEY"),
        "minimax_group_id": os.environ.get("MINIMAX_GROUP_ID"),
    }

    embedding_result = embedding_model.invoke(
        model="embo-01",
        credentials=credentials,
        texts=["hello", "world"],
        user="abc-123",
    )

    assert isinstance(embedding_result, TextEmbeddingResult)
    assert len(embedding_result.embeddings) == 2
    assert embedding_result.usage.total_tokens == 16
def test_get_num_tokens():
    """Two single-word texts should count as 2 tokens."""
    embedding_model = MinimaxTextEmbeddingModel()
    credentials = {
        "minimax_api_key": os.environ.get("MINIMAX_API_KEY"),
        "minimax_group_id": os.environ.get("MINIMAX_GROUP_ID"),
    }

    token_count = embedding_model.get_num_tokens(model="embo-01", credentials=credentials, texts=["hello", "world"])

    assert token_count == 2

View File

@@ -0,0 +1,143 @@
import os
from collections.abc import Generator
from time import sleep
import pytest
from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta
from core.model_runtime.entities.message_entities import AssistantPromptMessage, UserPromptMessage
from core.model_runtime.entities.model_entities import AIModelEntity
from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.model_providers.minimax.llm.llm import MinimaxLargeLanguageModel
def test_predefined_models():
    """The Minimax provider ships at least one predefined model schema."""
    language_model = MinimaxLargeLanguageModel()
    schemas = language_model.predefined_models()
    assert len(schemas) >= 1
    assert isinstance(schemas[0], AIModelEntity)
def test_validate_credentials_for_chat_model():
    """abab5.5-chat credential check; the sleep presumably spaces out API calls."""
    sleep(3)
    chat_model = MinimaxLargeLanguageModel()

    with pytest.raises(CredentialsValidateFailedError):
        chat_model.validate_credentials(
            model="abab5.5-chat", credentials={"minimax_api_key": "invalid_key", "minimax_group_id": "invalid_key"}
        )

    valid_credentials = {
        "minimax_api_key": os.environ.get("MINIMAX_API_KEY"),
        "minimax_group_id": os.environ.get("MINIMAX_GROUP_ID"),
    }
    chat_model.validate_credentials(model="abab5.5-chat", credentials=valid_credentials)
def test_invoke_model():
    """Non-streaming abab5-chat invocation returns content and token usage."""
    sleep(3)
    chat_model = MinimaxLargeLanguageModel()
    credentials = {
        "minimax_api_key": os.environ.get("MINIMAX_API_KEY"),
        "minimax_group_id": os.environ.get("MINIMAX_GROUP_ID"),
    }

    result = chat_model.invoke(
        model="abab5-chat",
        credentials=credentials,
        prompt_messages=[UserPromptMessage(content="Hello World!")],
        model_parameters={
            "temperature": 0.7,
            "top_p": 1.0,
            "top_k": 1,
        },
        stop=["you"],
        user="abc-123",
        stream=False,
    )

    assert isinstance(result, LLMResult)
    assert len(result.message.content) > 0
    assert result.usage.total_tokens > 0
def test_invoke_stream_model():
    """Streaming abab5.5-chat invocation yields well-formed chunks."""
    sleep(3)
    chat_model = MinimaxLargeLanguageModel()
    credentials = {
        "minimax_api_key": os.environ.get("MINIMAX_API_KEY"),
        "minimax_group_id": os.environ.get("MINIMAX_GROUP_ID"),
    }

    stream = chat_model.invoke(
        model="abab5.5-chat",
        credentials=credentials,
        prompt_messages=[UserPromptMessage(content="Hello World!")],
        model_parameters={
            "temperature": 0.7,
            "top_p": 1.0,
            "top_k": 1,
        },
        stop=["you"],
        stream=True,
        user="abc-123",
    )

    assert isinstance(stream, Generator)
    for chunk in stream:
        assert isinstance(chunk, LLMResultChunk)
        assert isinstance(chunk.delta, LLMResultChunkDelta)
        assert isinstance(chunk.delta.message, AssistantPromptMessage)
        # The final chunk (with a finish_reason) may legitimately have empty content.
        assert len(chunk.delta.message.content) > 0 if chunk.delta.finish_reason is None else True
def test_invoke_with_search():
    """With the web-search plugin enabled, the streamed answer cites sources."""
    sleep(3)
    chat_model = MinimaxLargeLanguageModel()
    credentials = {
        "minimax_api_key": os.environ.get("MINIMAX_API_KEY"),
        "minimax_group_id": os.environ.get("MINIMAX_GROUP_ID"),
    }

    stream = chat_model.invoke(
        model="abab5.5-chat",
        credentials=credentials,
        prompt_messages=[UserPromptMessage(content="北京今天的天气怎么样")],
        model_parameters={
            "temperature": 0.7,
            "top_p": 1.0,
            "top_k": 1,
            "plugin_web_search": True,
        },
        stop=["you"],
        stream=True,
        user="abc-123",
    )

    assert isinstance(stream, Generator)
    full_text = ""
    for chunk in stream:
        assert isinstance(chunk, LLMResultChunk)
        assert isinstance(chunk.delta, LLMResultChunkDelta)
        assert isinstance(chunk.delta.message, AssistantPromptMessage)
        full_text += chunk.delta.message.content
        assert len(chunk.delta.message.content) > 0 if not chunk.delta.finish_reason else True
    # A search-augmented answer is expected to include a references section.
    assert "参考资料" in full_text
def test_get_num_tokens():
    """Token counting for a single short user message (with an empty tool list)."""
    sleep(3)
    chat_model = MinimaxLargeLanguageModel()
    credentials = {
        "minimax_api_key": os.environ.get("MINIMAX_API_KEY"),
        "minimax_group_id": os.environ.get("MINIMAX_GROUP_ID"),
    }

    token_count = chat_model.get_num_tokens(
        model="abab5.5-chat",
        credentials=credentials,
        prompt_messages=[UserPromptMessage(content="Hello World!")],
        tools=[],
    )

    assert isinstance(token_count, int)
    assert token_count == 30

View File

@@ -0,0 +1,25 @@
import os
import pytest
from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.model_providers.minimax.minimax import MinimaxProvider
def test_validate_provider_credentials():
    """Provider credential check: bogus values raise, env credentials pass."""
    minimax_provider = MinimaxProvider()

    bad_credentials = {
        "minimax_api_key": "hahahaha",
        "minimax_group_id": "123",
    }
    with pytest.raises(CredentialsValidateFailedError):
        minimax_provider.validate_provider_credentials(credentials=bad_credentials)

    good_credentials = {
        "minimax_api_key": os.environ.get("MINIMAX_API_KEY"),
        "minimax_group_id": os.environ.get("MINIMAX_GROUP_ID"),
    }
    minimax_provider.validate_provider_credentials(credentials=good_credentials)

View File

@@ -0,0 +1,28 @@
import os
from unittest.mock import Mock, patch
import pytest
from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.model_providers.mixedbread.mixedbread import MixedBreadProvider
def test_validate_provider_credentials():
    """A bogus key raises; a mocked embedding response lets the real key validate."""
    mixedbread_provider = MixedBreadProvider()

    with pytest.raises(CredentialsValidateFailedError):
        mixedbread_provider.validate_provider_credentials(credentials={"api_key": "hahahaha"})

    # Canned embedding payload so validation succeeds without hitting the API.
    embedding_payload = {
        "usage": {"prompt_tokens": 3, "total_tokens": 3},
        "model": "mixedbread-ai/mxbai-embed-large-v1",
        "data": [{"embedding": [0.23333 for _ in range(1024)], "index": 0, "object": "embedding"}],
        "object": "list",
        "normalized": "true",
        "encoding_format": "float",
        "dimensions": 1024,
    }
    with patch("requests.post") as mock_post:
        mock_response = Mock()
        mock_response.json.return_value = embedding_payload
        mock_response.status_code = 200
        mock_post.return_value = mock_response
        mixedbread_provider.validate_provider_credentials(credentials={"api_key": os.environ.get("MIXEDBREAD_API_KEY")})

View File

@@ -0,0 +1,100 @@
import os
from unittest.mock import Mock, patch
import pytest
from core.model_runtime.entities.rerank_entities import RerankResult
from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.model_providers.mixedbread.rerank.rerank import MixedBreadRerankModel
def test_validate_credentials():
    """An invalid key raises; a mocked rerank response lets the real key validate."""
    rerank_model = MixedBreadRerankModel()

    with pytest.raises(CredentialsValidateFailedError):
        rerank_model.validate_credentials(
            model="mxbai-rerank-large-v1",
            credentials={"api_key": "invalid_key"},
        )

    # Canned rerank payload so validation succeeds without hitting the API.
    rerank_payload = {
        "usage": {"prompt_tokens": 86, "total_tokens": 86},
        "model": "mixedbread-ai/mxbai-rerank-large-v1",
        "data": [
            {
                "index": 0,
                "score": 0.06762695,
                "input": "Carson City is the capital city of the American state of Nevada. At the 2010 United "
                "States Census, Carson City had a population of 55,274.",
                "object": "text_document",
            },
            {
                "index": 1,
                "score": 0.057403564,
                "input": "The Commonwealth of the Northern Mariana Islands is a group of islands in the Pacific "
                "Ocean that are a political division controlled by the United States. Its capital is "
                "Saipan.",
                "object": "text_document",
            },
        ],
        "object": "list",
        "top_k": 2,
        "return_input": True,
    }
    with patch("httpx.post") as mock_post:
        mock_response = Mock()
        mock_response.json.return_value = rerank_payload
        mock_response.status_code = 200
        mock_post.return_value = mock_response
        rerank_model.validate_credentials(
            model="mxbai-rerank-large-v1",
            credentials={
                "api_key": os.environ.get("MIXEDBREAD_API_KEY"),
            },
        )
def test_invoke_model():
    """Against a mocked rerank response, only docs above the score threshold survive."""
    rerank_model = MixedBreadRerankModel()

    # Canned payload: doc 0 scores above 0.5, doc 1 well below.
    rerank_payload = {
        "usage": {"prompt_tokens": 56, "total_tokens": 56},
        "model": "mixedbread-ai/mxbai-rerank-large-v1",
        "data": [
            {
                "index": 0,
                "score": 0.6044922,
                "input": "Kasumi is a girl name of Japanese origin meaning mist.",
                "object": "text_document",
            },
            {
                "index": 1,
                "score": 0.0703125,
                "input": "Her music is a kawaii bass, a mix of future bass, pop, and kawaii music and she leads a "
                "team named PopiParty.",
                "object": "text_document",
            },
        ],
        "object": "list",
        "top_k": 2,
        "return_input": "true",
    }
    with patch("httpx.post") as mock_post:
        mock_response = Mock()
        mock_response.json.return_value = rerank_payload
        mock_response.status_code = 200
        mock_post.return_value = mock_response
        rerank_result = rerank_model.invoke(
            model="mxbai-rerank-large-v1",
            credentials={
                "api_key": os.environ.get("MIXEDBREAD_API_KEY"),
            },
            query="Who is Kasumi?",
            docs=[
                "Kasumi is a girl name of Japanese origin meaning mist.",
                "Her music is a kawaii bass, a mix of future bass, pop, and kawaii music and she leads a team named "
                "PopiParty.",
            ],
            score_threshold=0.5,
        )

        assert isinstance(rerank_result, RerankResult)
        assert len(rerank_result.docs) == 1
        assert rerank_result.docs[0].index == 0
        assert rerank_result.docs[0].score >= 0.5

View File

@@ -0,0 +1,78 @@
import os
from unittest.mock import Mock, patch
import pytest
from core.model_runtime.entities.text_embedding_entities import TextEmbeddingResult
from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.model_providers.mixedbread.text_embedding.text_embedding import MixedBreadTextEmbeddingModel
def test_validate_credentials():
    """An invalid key raises; a mocked embedding response lets the real key validate."""
    embedding_model = MixedBreadTextEmbeddingModel()

    with pytest.raises(CredentialsValidateFailedError):
        embedding_model.validate_credentials(model="mxbai-embed-large-v1", credentials={"api_key": "invalid_key"})

    # Canned single-embedding payload so validation succeeds without hitting the API.
    embedding_payload = {
        "usage": {"prompt_tokens": 3, "total_tokens": 3},
        "model": "mixedbread-ai/mxbai-embed-large-v1",
        "data": [{"embedding": [0.23333 for _ in range(1024)], "index": 0, "object": "embedding"}],
        "object": "list",
        "normalized": "true",
        "encoding_format": "float",
        "dimensions": 1024,
    }
    with patch("requests.post") as mock_post:
        mock_response = Mock()
        mock_response.json.return_value = embedding_payload
        mock_response.status_code = 200
        mock_post.return_value = mock_response
        embedding_model.validate_credentials(
            model="mxbai-embed-large-v1", credentials={"api_key": os.environ.get("MIXEDBREAD_API_KEY")}
        )
def test_invoke_model():
    """Against a mocked response, embedding two texts yields two vectors and 6 tokens."""
    embedding_model = MixedBreadTextEmbeddingModel()

    # Canned two-embedding payload matching the two input texts.
    embedding_payload = {
        "usage": {"prompt_tokens": 6, "total_tokens": 6},
        "model": "mixedbread-ai/mxbai-embed-large-v1",
        "data": [
            {"embedding": [0.23333 for _ in range(1024)], "index": 0, "object": "embedding"},
            {"embedding": [0.23333 for _ in range(1024)], "index": 1, "object": "embedding"},
        ],
        "object": "list",
        "normalized": "true",
        "encoding_format": "float",
        "dimensions": 1024,
    }
    with patch("requests.post") as mock_post:
        mock_response = Mock()
        mock_response.json.return_value = embedding_payload
        mock_response.status_code = 200
        mock_post.return_value = mock_response
        embedding_result = embedding_model.invoke(
            model="mxbai-embed-large-v1",
            credentials={
                "api_key": os.environ.get("MIXEDBREAD_API_KEY"),
            },
            texts=["hello", "world"],
            user="abc-123",
        )

        assert isinstance(embedding_result, TextEmbeddingResult)
        assert len(embedding_result.embeddings) == 2
        assert embedding_result.usage.total_tokens == 6
def test_get_num_tokens():
    """A single one-word text should count as 1 token."""
    embedding_model = MixedBreadTextEmbeddingModel()

    token_count = embedding_model.get_num_tokens(
        model="mxbai-embed-large-v1",
        credentials={
            "api_key": os.environ.get("MIXEDBREAD_API_KEY"),
        },
        texts=["ping"],
    )

    assert token_count == 1

Some files were not shown because too many files have changed in this diff Show More