Initial commit

commit ac715a8b88
2025-10-14 14:17:21 +08:00
35011 changed files with 3834178 additions and 0 deletions


@@ -0,0 +1 @@
.env.test


@@ -0,0 +1,102 @@
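# Unit tests for DifyConfig: loading a dotenv file through pydantic-settings,
# checking constant/default values, and dumping the config into a Flask app.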
import os
from textwrap import dedent
import pytest
from flask import Flask
from yarl import URL
from configs.app_config import DifyConfig
EXAMPLE_ENV_FILENAME = ".env"
@pytest.fixture
def example_env_file(tmp_path, monkeypatch) -> str:
monkeypatch.chdir(tmp_path)
file_path = tmp_path.joinpath(EXAMPLE_ENV_FILENAME)
file_path.write_text(
dedent(
"""
CONSOLE_API_URL=https://example.com
CONSOLE_WEB_URL=https://example.com
HTTP_REQUEST_MAX_WRITE_TIMEOUT=30
"""
)
)
return str(file_path)
def test_dify_config_undefined_entry(example_env_file):
# NOTE: See https://github.com/microsoft/pylance-release/issues/6099 for more details about this type error.
# load dotenv file with pydantic-settings
config = DifyConfig(_env_file=example_env_file)
# entries not defined in app settings
with pytest.raises(TypeError):
        # TypeError: 'DifyConfig' object is not subscriptable
assert config["LOG_LEVEL"] == "INFO"
# NOTE: If there is a `.env` file in your Workspace, this test might not succeed as expected.
# This is due to `pymilvus` loading all the variables from the `.env` file into `os.environ`.
def test_dify_config(example_env_file):
# clear system environment variables
os.environ.clear()
# load dotenv file with pydantic-settings
config = DifyConfig(_env_file=example_env_file)
# constant values
assert config.COMMIT_SHA == ""
# default values
assert config.EDITION == "SELF_HOSTED"
assert config.API_COMPRESSION_ENABLED is False
assert config.SENTRY_TRACES_SAMPLE_RATE == 1.0
# annotated field with default value
assert config.HTTP_REQUEST_MAX_READ_TIMEOUT == 60
# annotated field with configured value
assert config.HTTP_REQUEST_MAX_WRITE_TIMEOUT == 30
assert config.WORKFLOW_PARALLEL_DEPTH_LIMIT == 3
# NOTE: If there is a `.env` file in your Workspace, this test might not succeed as expected.
# This is due to `pymilvus` loading all the variables from the `.env` file into `os.environ`.
def test_flask_configs(example_env_file):
flask_app = Flask("app")
# clear system environment variables
os.environ.clear()
flask_app.config.from_mapping(DifyConfig(_env_file=example_env_file).model_dump()) # pyright: ignore
config = flask_app.config
# configs read from pydantic-settings
assert config["LOG_LEVEL"] == "INFO"
assert config["COMMIT_SHA"] == ""
assert config["EDITION"] == "SELF_HOSTED"
assert config["API_COMPRESSION_ENABLED"] is False
assert config["SENTRY_TRACES_SAMPLE_RATE"] == 1.0
# value from env file
assert config["CONSOLE_API_URL"] == "https://example.com"
# fallback to alias choices value as CONSOLE_API_URL
assert config["FILES_URL"] == "https://example.com"
assert config["SQLALCHEMY_DATABASE_URI"] == "postgresql://postgres:@localhost:5432/dify"
assert config["SQLALCHEMY_ENGINE_OPTIONS"] == {
"connect_args": {
"options": "-c timezone=UTC",
},
"max_overflow": 10,
"pool_pre_ping": False,
"pool_recycle": 3600,
"pool_size": 30,
}
assert config["CONSOLE_WEB_URL"] == "https://example.com"
assert config["CONSOLE_CORS_ALLOW_ORIGINS"] == ["https://example.com"]
assert config["WEB_API_CORS_ALLOW_ORIGINS"] == ["*"]
assert str(config["CODE_EXECUTION_ENDPOINT"]) == "http://sandbox:8194/"
assert str(URL(str(config["CODE_EXECUTION_ENDPOINT"])) / "v1") == "http://sandbox:8194/v1"


@@ -0,0 +1,23 @@
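# Shared pytest fixtures: a module-level cached Flask app plus an autouse
# fixture that wraps every test in an application context.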
import os
import pytest
from flask import Flask
# Getting the absolute path of the current file's directory
ABS_PATH = os.path.dirname(os.path.abspath(__file__))
# Getting the absolute path of the project's root directory
PROJECT_DIR = os.path.abspath(os.path.join(ABS_PATH, os.pardir, os.pardir))
CACHED_APP = Flask(__name__)
@pytest.fixture
def app() -> Flask:
return CACHED_APP
@pytest.fixture(autouse=True)
def _provide_app_context(app: Flask):
with app.app_context():
yield


@@ -0,0 +1,24 @@
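# Parametrized tests for _has_new_version, covering patch/minor/major bumps
# and pre-release suffixes (alpha, beta, rc, dev).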
import pytest
from controllers.console.version import _has_new_version
@pytest.mark.parametrize(
("latest_version", "current_version", "expected"),
[
("1.0.1", "1.0.0", True),
("1.1.0", "1.0.0", True),
("2.0.0", "1.9.9", True),
("1.0.0", "1.0.0", False),
("1.0.0", "1.0.1", False),
("1.0.0", "2.0.0", False),
("1.0.1", "1.0.0-beta", True),
("1.0.0", "1.0.0-alpha", True),
("1.0.0-beta", "1.0.0-alpha", True),
("1.0.0", "1.0.0-rc1", True),
("1.0.0", "0.9.9", True),
("1.0.0", "1.0.0-dev", True),
],
)
def test_has_new_version(latest_version, current_version, expected):
assert _has_new_version(latest_version=latest_version, current_version=current_version) == expected


@@ -0,0 +1,61 @@
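# Tests for FileUploadConfigManager: converting app config with and without
# vision support, and validate_and_set_defaults behavior.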
from core.app.app_config.features.file_upload.manager import FileUploadConfigManager
from core.file.models import FileTransferMethod, FileUploadConfig, ImageConfig
from core.model_runtime.entities.message_entities import ImagePromptMessageContent
def test_convert_with_vision():
config = {
"file_upload": {
"enabled": True,
"number_limits": 5,
"allowed_file_upload_methods": [FileTransferMethod.REMOTE_URL],
"image": {"detail": "high"},
}
}
result = FileUploadConfigManager.convert(config, is_vision=True)
expected = FileUploadConfig(
image_config=ImageConfig(
number_limits=5,
transfer_methods=[FileTransferMethod.REMOTE_URL],
detail=ImagePromptMessageContent.DETAIL.HIGH,
)
)
assert result == expected
def test_convert_without_vision():
config = {
"file_upload": {
"enabled": True,
"number_limits": 5,
"allowed_file_upload_methods": [FileTransferMethod.REMOTE_URL],
}
}
result = FileUploadConfigManager.convert(config, is_vision=False)
expected = FileUploadConfig(
image_config=ImageConfig(number_limits=5, transfer_methods=[FileTransferMethod.REMOTE_URL])
)
assert result == expected
def test_validate_and_set_defaults():
config = {}
result, keys = FileUploadConfigManager.validate_and_set_defaults(config)
assert "file_upload" in result
assert keys == ["file_upload"]
def test_validate_and_set_defaults_with_existing_config():
config = {
"file_upload": {
"enabled": True,
"number_limits": 5,
"allowed_file_upload_methods": [FileTransferMethod.REMOTE_URL],
}
}
result, keys = FileUploadConfigManager.validate_and_set_defaults(config)
assert "file_upload" in result
assert keys == ["file_upload"]
assert result["file_upload"]["enabled"] is True
assert result["file_upload"]["number_limits"] == 5
assert result["file_upload"]["allowed_file_upload_methods"] == [FileTransferMethod.REMOTE_URL]


@@ -0,0 +1,52 @@
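# Tests for BaseAppGenerator._validate_inputs: accepting 0 and "0" for a
# required number variable, and raising when a required input is None.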
import pytest
from core.app.app_config.entities import VariableEntity, VariableEntityType
from core.app.apps.base_app_generator import BaseAppGenerator
def test_validate_inputs_with_zero():
base_app_generator = BaseAppGenerator()
var = VariableEntity(
variable="test_var",
label="test_var",
type=VariableEntityType.NUMBER,
required=True,
)
# Test with input 0
result = base_app_generator._validate_inputs(
variable_entity=var,
value=0,
)
assert result == 0
# Test with input "0" (string)
result = base_app_generator._validate_inputs(
variable_entity=var,
value="0",
)
assert result == 0
def test_validate_input_with_none_for_required_variable():
base_app_generator = BaseAppGenerator()
for var_type in VariableEntityType:
var = VariableEntity(
variable="test_var",
label="test_var",
type=var_type,
required=True,
)
# Test with input None
with pytest.raises(ValueError) as exc_info:
base_app_generator._validate_inputs(
variable_entity=var,
value=None,
)
assert str(exc_info.value) == "test_var is required in input form"


@@ -0,0 +1,165 @@
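# Tests for variable_factory: building typed variables/segments from mappings,
# rejecting unknown value types, and enforcing the 200 KB size limit.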
from uuid import uuid4
import pytest
from core.variables import (
ArrayNumberVariable,
ArrayObjectVariable,
ArrayStringVariable,
FloatVariable,
IntegerVariable,
ObjectSegment,
SecretVariable,
StringVariable,
)
from core.variables.exc import VariableError
from core.variables.segments import ArrayAnySegment
from factories import variable_factory
def test_string_variable():
test_data = {"value_type": "string", "name": "test_text", "value": "Hello, World!"}
result = variable_factory.build_conversation_variable_from_mapping(test_data)
assert isinstance(result, StringVariable)
def test_integer_variable():
test_data = {"value_type": "number", "name": "test_int", "value": 42}
result = variable_factory.build_conversation_variable_from_mapping(test_data)
assert isinstance(result, IntegerVariable)
def test_float_variable():
test_data = {"value_type": "number", "name": "test_float", "value": 3.14}
result = variable_factory.build_conversation_variable_from_mapping(test_data)
assert isinstance(result, FloatVariable)
def test_secret_variable():
test_data = {"value_type": "secret", "name": "test_secret", "value": "secret_value"}
result = variable_factory.build_conversation_variable_from_mapping(test_data)
assert isinstance(result, SecretVariable)
def test_invalid_value_type():
test_data = {"value_type": "unknown", "name": "test_invalid", "value": "value"}
with pytest.raises(VariableError):
variable_factory.build_conversation_variable_from_mapping(test_data)
def test_build_a_blank_string():
result = variable_factory.build_conversation_variable_from_mapping(
{
"value_type": "string",
"name": "blank",
"value": "",
}
)
assert isinstance(result, StringVariable)
assert result.value == ""
def test_build_a_object_variable_with_none_value():
var = variable_factory.build_segment(
{
"key1": None,
}
)
assert isinstance(var, ObjectSegment)
assert var.value["key1"] is None
def test_object_variable():
mapping = {
"id": str(uuid4()),
"value_type": "object",
"name": "test_object",
"description": "Description of the variable.",
"value": {
"key1": "text",
"key2": 2,
},
}
variable = variable_factory.build_conversation_variable_from_mapping(mapping)
assert isinstance(variable, ObjectSegment)
assert isinstance(variable.value["key1"], str)
assert isinstance(variable.value["key2"], int)
def test_array_string_variable():
mapping = {
"id": str(uuid4()),
"value_type": "array[string]",
"name": "test_array",
"description": "Description of the variable.",
"value": [
"text",
"text",
],
}
variable = variable_factory.build_conversation_variable_from_mapping(mapping)
assert isinstance(variable, ArrayStringVariable)
assert isinstance(variable.value[0], str)
assert isinstance(variable.value[1], str)
def test_array_number_variable():
mapping = {
"id": str(uuid4()),
"value_type": "array[number]",
"name": "test_array",
"description": "Description of the variable.",
"value": [
1,
2.0,
],
}
variable = variable_factory.build_conversation_variable_from_mapping(mapping)
assert isinstance(variable, ArrayNumberVariable)
assert isinstance(variable.value[0], int)
assert isinstance(variable.value[1], float)
def test_array_object_variable():
mapping = {
"id": str(uuid4()),
"value_type": "array[object]",
"name": "test_array",
"description": "Description of the variable.",
"value": [
{
"key1": "text",
"key2": 1,
},
{
"key1": "text",
"key2": 1,
},
],
}
variable = variable_factory.build_conversation_variable_from_mapping(mapping)
assert isinstance(variable, ArrayObjectVariable)
assert isinstance(variable.value[0], dict)
assert isinstance(variable.value[1], dict)
assert isinstance(variable.value[0]["key1"], str)
assert isinstance(variable.value[0]["key2"], int)
assert isinstance(variable.value[1]["key1"], str)
assert isinstance(variable.value[1]["key2"], int)
def test_variable_cannot_be_larger_than_200_kb():
with pytest.raises(VariableError):
variable_factory.build_conversation_variable_from_mapping(
{
"id": str(uuid4()),
"value_type": "string",
"name": "test_text",
"value": "a" * 1024 * 201,
}
)
def test_array_none_variable():
var = variable_factory.build_segment([None, None, None, None])
assert isinstance(var, ArrayAnySegment)
assert var.value == [None, None, None, None]


@@ -0,0 +1,58 @@
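# Tests for VariablePool.convert_template: rendering system/node/environment
# variables and obfuscating secret values in the log output.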
from core.helper import encrypter
from core.variables import SecretVariable, StringVariable
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.enums import SystemVariableKey
def test_segment_group_to_text():
variable_pool = VariablePool(
system_variables={
SystemVariableKey("user_id"): "fake-user-id",
},
user_inputs={},
environment_variables=[
SecretVariable(name="secret_key", value="fake-secret-key"),
],
conversation_variables=[],
)
variable_pool.add(("node_id", "custom_query"), "fake-user-query")
template = (
"Hello, {{#sys.user_id#}}! Your query is {{#node_id.custom_query#}}. And your key is {{#env.secret_key#}}."
)
segments_group = variable_pool.convert_template(template)
assert segments_group.text == "Hello, fake-user-id! Your query is fake-user-query. And your key is fake-secret-key."
assert segments_group.log == (
f"Hello, fake-user-id! Your query is fake-user-query."
f" And your key is {encrypter.obfuscated_token('fake-secret-key')}."
)
def test_convert_constant_to_segment_group():
variable_pool = VariablePool(
system_variables={},
user_inputs={},
environment_variables=[],
conversation_variables=[],
)
template = "Hello, world!"
segments_group = variable_pool.convert_template(template)
assert segments_group.text == "Hello, world!"
assert segments_group.log == "Hello, world!"
def test_convert_variable_to_segment_group():
variable_pool = VariablePool(
system_variables={
SystemVariableKey("user_id"): "fake-user-id",
},
user_inputs={},
environment_variables=[],
conversation_variables=[],
)
template = "{{#sys.user_id#}}"
segments_group = variable_pool.convert_template(template)
assert segments_group.text == "fake-user-id"
assert segments_group.log == "fake-user-id"
assert isinstance(segments_group.value[0], StringVariable)
assert segments_group.value[0].value == "fake-user-id"


@@ -0,0 +1,90 @@
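# Tests for variable segments: frozen (immutable) values, fixed value_type,
# and to_object conversions for scalar and object variables.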
import pytest
from pydantic import ValidationError
from core.variables import (
ArrayFileVariable,
ArrayVariable,
FloatVariable,
IntegerVariable,
ObjectVariable,
SecretVariable,
SegmentType,
StringVariable,
)
def test_frozen_variables():
var = StringVariable(name="text", value="text")
with pytest.raises(ValidationError):
var.value = "new value"
int_var = IntegerVariable(name="integer", value=42)
with pytest.raises(ValidationError):
int_var.value = 100
float_var = FloatVariable(name="float", value=3.14)
with pytest.raises(ValidationError):
float_var.value = 2.718
secret_var = SecretVariable(name="secret", value="secret_value")
with pytest.raises(ValidationError):
secret_var.value = "new_secret_value"
def test_variable_value_type_immutable():
with pytest.raises(ValidationError):
StringVariable(value_type=SegmentType.ARRAY_ANY, name="text", value="text")
with pytest.raises(ValidationError):
StringVariable.model_validate({"value_type": "not text", "name": "text", "value": "text"})
var = IntegerVariable(name="integer", value=42)
with pytest.raises(ValidationError):
IntegerVariable(value_type=SegmentType.ARRAY_ANY, name=var.name, value=var.value)
var = FloatVariable(name="float", value=3.14)
with pytest.raises(ValidationError):
FloatVariable(value_type=SegmentType.ARRAY_ANY, name=var.name, value=var.value)
var = SecretVariable(name="secret", value="secret_value")
with pytest.raises(ValidationError):
SecretVariable(value_type=SegmentType.ARRAY_ANY, name=var.name, value=var.value)
def test_object_variable_to_object():
var = ObjectVariable(
name="object",
value={
"key1": {
"key2": "value2",
},
"key2": ["value5_1", 42, {}],
},
)
assert var.to_object() == {
"key1": {
"key2": "value2",
},
"key2": [
"value5_1",
42,
{},
],
}
def test_variable_to_object():
var = StringVariable(name="text", value="text")
assert var.to_object() == "text"
var = IntegerVariable(name="integer", value=42)
assert var.to_object() == 42
var = FloatVariable(name="float", value=3.14)
assert var.to_object() == 3.14
var = SecretVariable(name="secret", value="secret_value")
assert var.to_object() == "secret_value"
def test_array_file_variable_is_array_variable():
var = ArrayFileVariable(name="files", value=[])
assert isinstance(var, ArrayVariable)


@@ -0,0 +1,7 @@
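# Smoke test for download_plugin_pkg; note this calls the live marketplace API.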
from core.helper.marketplace import download_plugin_pkg
def test_download_plugin_pkg():
pkg = download_plugin_pkg("langgenius/bing:0.0.1@e58735424d2104f208c2bd683c5142e0332045b425927067acf432b26f3d970b")
assert pkg is not None
assert len(pkg) > 0


@@ -0,0 +1,52 @@
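# Tests for ssrf_proxy.make_request: a successful request, retry exhaustion,
# and eventual success after a run of retryable status codes.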
import random
from unittest.mock import MagicMock, patch
import pytest
from core.helper.ssrf_proxy import SSRF_DEFAULT_MAX_RETRIES, STATUS_FORCELIST, make_request
@patch("httpx.Client.request")
def test_successful_request(mock_request):
mock_response = MagicMock()
mock_response.status_code = 200
mock_request.return_value = mock_response
response = make_request("GET", "http://example.com")
assert response.status_code == 200
@patch("httpx.Client.request")
def test_retry_exceed_max_retries(mock_request):
mock_response = MagicMock()
mock_response.status_code = 500
side_effects = [mock_response] * SSRF_DEFAULT_MAX_RETRIES
mock_request.side_effect = side_effects
with pytest.raises(Exception) as e:
make_request("GET", "http://example.com", max_retries=SSRF_DEFAULT_MAX_RETRIES - 1)
assert str(e.value) == f"Reached maximum retries ({SSRF_DEFAULT_MAX_RETRIES - 1}) for URL http://example.com"
@patch("httpx.Client.request")
def test_retry_logic_success(mock_request):
side_effects = []
for _ in range(SSRF_DEFAULT_MAX_RETRIES):
status_code = random.choice(STATUS_FORCELIST)
mock_response = MagicMock()
mock_response.status_code = status_code
side_effects.append(mock_response)
mock_response_200 = MagicMock()
mock_response_200.status_code = 200
side_effects.append(mock_response_200)
mock_request.side_effect = side_effects
response = make_request("GET", "http://example.com", max_retries=SSRF_DEFAULT_MAX_RETRIES)
assert response.status_code == 200
assert mock_request.call_count == SSRF_DEFAULT_MAX_RETRIES + 1
assert mock_request.call_args_list[0][1].get("method") == "GET"


@@ -0,0 +1,190 @@
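# Tests for AdvancedPromptTransform: completion- and chat-model prompt
# assembly with memory, context variables, and image file inputs.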
from unittest.mock import MagicMock, patch
import pytest
from configs import dify_config
from core.app.app_config.entities import ModelConfigEntity
from core.file import File, FileTransferMethod, FileType
from core.memory.token_buffer_memory import TokenBufferMemory
from core.model_runtime.entities.message_entities import (
AssistantPromptMessage,
ImagePromptMessageContent,
PromptMessageRole,
UserPromptMessage,
)
from core.prompt.advanced_prompt_transform import AdvancedPromptTransform
from core.prompt.entities.advanced_prompt_entities import ChatModelMessage, CompletionModelPromptTemplate, MemoryConfig
from core.prompt.utils.prompt_template_parser import PromptTemplateParser
from models.model import Conversation
def test__get_completion_model_prompt_messages():
model_config_mock = MagicMock(spec=ModelConfigEntity)
model_config_mock.provider = "openai"
model_config_mock.model = "gpt-3.5-turbo-instruct"
prompt_template = "Context:\n{{#context#}}\n\nHistories:\n{{#histories#}}\n\nyou are {{name}}."
prompt_template_config = CompletionModelPromptTemplate(text=prompt_template)
memory_config = MemoryConfig(
role_prefix=MemoryConfig.RolePrefix(user="Human", assistant="Assistant"),
window=MemoryConfig.WindowConfig(enabled=False),
)
inputs = {"name": "John"}
files = []
context = "I am superman."
memory = TokenBufferMemory(conversation=Conversation(), model_instance=model_config_mock)
history_prompt_messages = [UserPromptMessage(content="Hi"), AssistantPromptMessage(content="Hello")]
memory.get_history_prompt_messages = MagicMock(return_value=history_prompt_messages)
prompt_transform = AdvancedPromptTransform()
prompt_transform._calculate_rest_token = MagicMock(return_value=2000)
prompt_messages = prompt_transform._get_completion_model_prompt_messages(
prompt_template=prompt_template_config,
inputs=inputs,
query=None,
files=files,
context=context,
memory_config=memory_config,
memory=memory,
model_config=model_config_mock,
)
assert len(prompt_messages) == 1
assert prompt_messages[0].content == PromptTemplateParser(template=prompt_template).format(
{
"#context#": context,
"#histories#": "\n".join(
[
f"{'Human' if prompt.role.value == 'user' else 'Assistant'}: {prompt.content}"
for prompt in history_prompt_messages
]
),
**inputs,
}
)
def test__get_chat_model_prompt_messages(get_chat_model_args):
model_config_mock, memory_config, messages, inputs, context = get_chat_model_args
files = []
query = "Hi2."
memory = TokenBufferMemory(conversation=Conversation(), model_instance=model_config_mock)
history_prompt_messages = [UserPromptMessage(content="Hi1."), AssistantPromptMessage(content="Hello1!")]
memory.get_history_prompt_messages = MagicMock(return_value=history_prompt_messages)
prompt_transform = AdvancedPromptTransform()
prompt_transform._calculate_rest_token = MagicMock(return_value=2000)
prompt_messages = prompt_transform._get_chat_model_prompt_messages(
prompt_template=messages,
inputs=inputs,
query=query,
files=files,
context=context,
memory_config=memory_config,
memory=memory,
model_config=model_config_mock,
)
assert len(prompt_messages) == 6
assert prompt_messages[0].role == PromptMessageRole.SYSTEM
assert prompt_messages[0].content == PromptTemplateParser(template=messages[0].text).format(
{**inputs, "#context#": context}
)
assert prompt_messages[5].content == query
def test__get_chat_model_prompt_messages_no_memory(get_chat_model_args):
model_config_mock, _, messages, inputs, context = get_chat_model_args
files = []
prompt_transform = AdvancedPromptTransform()
prompt_transform._calculate_rest_token = MagicMock(return_value=2000)
prompt_messages = prompt_transform._get_chat_model_prompt_messages(
prompt_template=messages,
inputs=inputs,
query=None,
files=files,
context=context,
memory_config=None,
memory=None,
model_config=model_config_mock,
)
assert len(prompt_messages) == 3
assert prompt_messages[0].role == PromptMessageRole.SYSTEM
assert prompt_messages[0].content == PromptTemplateParser(template=messages[0].text).format(
{**inputs, "#context#": context}
)
def test__get_chat_model_prompt_messages_with_files_no_memory(get_chat_model_args):
model_config_mock, _, messages, inputs, context = get_chat_model_args
dify_config.MULTIMODAL_SEND_FORMAT = "url"
files = [
File(
id="file1",
tenant_id="tenant1",
type=FileType.IMAGE,
transfer_method=FileTransferMethod.REMOTE_URL,
remote_url="https://example.com/image1.jpg",
storage_key="",
)
]
prompt_transform = AdvancedPromptTransform()
prompt_transform._calculate_rest_token = MagicMock(return_value=2000)
with patch("core.file.file_manager.to_prompt_message_content") as mock_get_encoded_string:
mock_get_encoded_string.return_value = ImagePromptMessageContent(
url=str(files[0].remote_url), format="jpg", mime_type="image/jpg"
)
prompt_messages = prompt_transform._get_chat_model_prompt_messages(
prompt_template=messages,
inputs=inputs,
query=None,
files=files,
context=context,
memory_config=None,
memory=None,
model_config=model_config_mock,
)
assert len(prompt_messages) == 4
assert prompt_messages[0].role == PromptMessageRole.SYSTEM
assert prompt_messages[0].content == PromptTemplateParser(template=messages[0].text).format(
{**inputs, "#context#": context}
)
assert isinstance(prompt_messages[3].content, list)
assert len(prompt_messages[3].content) == 2
assert prompt_messages[3].content[1].data == files[0].remote_url
@pytest.fixture
def get_chat_model_args():
model_config_mock = MagicMock(spec=ModelConfigEntity)
model_config_mock.provider = "openai"
model_config_mock.model = "gpt-4"
memory_config = MemoryConfig(window=MemoryConfig.WindowConfig(enabled=False))
prompt_messages = [
ChatModelMessage(
text="You are a helpful assistant named {{name}}.\n\nContext:\n{{#context#}}", role=PromptMessageRole.SYSTEM
),
ChatModelMessage(text="Hi.", role=PromptMessageRole.USER),
ChatModelMessage(text="Hello!", role=PromptMessageRole.ASSISTANT),
]
inputs = {"name": "John"}
context = "I am superman."
return model_config_mock, memory_config, prompt_messages, inputs, context


@@ -0,0 +1,75 @@
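# Test for AgentHistoryPromptTransform.get_prompt, using message count as a
# stand-in for token count via a mocked get_num_tokens.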
from unittest.mock import MagicMock
from core.app.entities.app_invoke_entities import (
ModelConfigWithCredentialsEntity,
)
from core.entities.provider_configuration import ProviderModelBundle
from core.memory.token_buffer_memory import TokenBufferMemory
from core.model_runtime.entities.message_entities import (
AssistantPromptMessage,
SystemPromptMessage,
ToolPromptMessage,
UserPromptMessage,
)
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
from core.prompt.agent_history_prompt_transform import AgentHistoryPromptTransform
from models.model import Conversation
def test_get_prompt():
prompt_messages = [
SystemPromptMessage(content="System Template"),
UserPromptMessage(content="User Query"),
]
history_messages = [
SystemPromptMessage(content="System Prompt 1"),
UserPromptMessage(content="User Prompt 1"),
AssistantPromptMessage(content="Assistant Thought 1"),
ToolPromptMessage(content="Tool 1-1", name="Tool 1-1", tool_call_id="1"),
ToolPromptMessage(content="Tool 1-2", name="Tool 1-2", tool_call_id="2"),
SystemPromptMessage(content="System Prompt 2"),
UserPromptMessage(content="User Prompt 2"),
AssistantPromptMessage(content="Assistant Thought 2"),
ToolPromptMessage(content="Tool 2-1", name="Tool 2-1", tool_call_id="3"),
ToolPromptMessage(content="Tool 2-2", name="Tool 2-2", tool_call_id="4"),
UserPromptMessage(content="User Prompt 3"),
AssistantPromptMessage(content="Assistant Thought 3"),
]
    # use message count instead of token count for testing
def side_effect_get_num_tokens(*args):
return len(args[2])
large_language_model_mock = MagicMock(spec=LargeLanguageModel)
large_language_model_mock.get_num_tokens = MagicMock(side_effect=side_effect_get_num_tokens)
provider_model_bundle_mock = MagicMock(spec=ProviderModelBundle)
provider_model_bundle_mock.model_type_instance = large_language_model_mock
model_config_mock = MagicMock(spec=ModelConfigWithCredentialsEntity)
model_config_mock.model = "openai"
model_config_mock.credentials = {}
model_config_mock.provider_model_bundle = provider_model_bundle_mock
memory = TokenBufferMemory(conversation=Conversation(), model_instance=model_config_mock)
transform = AgentHistoryPromptTransform(
model_config=model_config_mock,
prompt_messages=prompt_messages,
history_messages=history_messages,
memory=memory,
)
max_token_limit = 5
transform._calculate_rest_token = MagicMock(return_value=max_token_limit)
result = transform.get_prompt()
assert len(result) <= max_token_limit
assert len(result) == 4
max_token_limit = 20
transform._calculate_rest_token = MagicMock(return_value=max_token_limit)
result = transform.get_prompt()
assert len(result) <= max_token_limit
assert len(result) == 12


@@ -0,0 +1,91 @@
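# Tests for extract_thread_messages across single, linear, branched,
# partially loaded, legacy (UUID_NIL), and mixed message threads.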
from uuid import uuid4
from constants import UUID_NIL
from core.prompt.utils.extract_thread_messages import extract_thread_messages
class TestMessage:
def __init__(self, id, parent_message_id):
self.id = id
self.parent_message_id = parent_message_id
def __getitem__(self, item):
return getattr(self, item)
def test_extract_thread_messages_single_message():
messages = [TestMessage(str(uuid4()), UUID_NIL)]
result = extract_thread_messages(messages)
assert len(result) == 1
assert result[0] == messages[0]
def test_extract_thread_messages_linear_thread():
id1, id2, id3, id4, id5 = str(uuid4()), str(uuid4()), str(uuid4()), str(uuid4()), str(uuid4())
messages = [
TestMessage(id5, id4),
TestMessage(id4, id3),
TestMessage(id3, id2),
TestMessage(id2, id1),
TestMessage(id1, UUID_NIL),
]
result = extract_thread_messages(messages)
assert len(result) == 5
assert [msg["id"] for msg in result] == [id5, id4, id3, id2, id1]
def test_extract_thread_messages_branched_thread():
id1, id2, id3, id4 = str(uuid4()), str(uuid4()), str(uuid4()), str(uuid4())
messages = [
TestMessage(id4, id2),
TestMessage(id3, id2),
TestMessage(id2, id1),
TestMessage(id1, UUID_NIL),
]
result = extract_thread_messages(messages)
assert len(result) == 3
assert [msg["id"] for msg in result] == [id4, id2, id1]
def test_extract_thread_messages_empty_list():
messages = []
result = extract_thread_messages(messages)
assert len(result) == 0
def test_extract_thread_messages_partially_loaded():
id0, id1, id2, id3 = str(uuid4()), str(uuid4()), str(uuid4()), str(uuid4())
messages = [
TestMessage(id3, id2),
TestMessage(id2, id1),
TestMessage(id1, id0),
]
result = extract_thread_messages(messages)
assert len(result) == 3
assert [msg["id"] for msg in result] == [id3, id2, id1]
def test_extract_thread_messages_legacy_messages():
id1, id2, id3 = str(uuid4()), str(uuid4()), str(uuid4())
messages = [
TestMessage(id3, UUID_NIL),
TestMessage(id2, UUID_NIL),
TestMessage(id1, UUID_NIL),
]
result = extract_thread_messages(messages)
assert len(result) == 3
assert [msg["id"] for msg in result] == [id3, id2, id1]
def test_extract_thread_messages_mixed_with_legacy_messages():
id1, id2, id3, id4, id5 = str(uuid4()), str(uuid4()), str(uuid4()), str(uuid4()), str(uuid4())
messages = [
TestMessage(id5, id4),
TestMessage(id4, id2),
TestMessage(id3, id2),
TestMessage(id2, UUID_NIL),
TestMessage(id1, UUID_NIL),
]
result = extract_thread_messages(messages)
assert len(result) == 4
assert [msg["id"] for msg in result] == [id5, id4, id2, id1]


@@ -0,0 +1,52 @@
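# NOTE: this entire module is commented out in this commit; it exercised
# PromptTransform._calculate_rest_token.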
# from unittest.mock import MagicMock
# from core.app.app_config.entities import ModelConfigEntity
# from core.entities.provider_configuration import ProviderConfiguration, ProviderModelBundle
# from core.model_runtime.entities.message_entities import UserPromptMessage
# from core.model_runtime.entities.model_entities import AIModelEntity, ModelPropertyKey, ParameterRule
# from core.model_runtime.entities.provider_entities import ProviderEntity
# from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
# from core.prompt.prompt_transform import PromptTransform
# def test__calculate_rest_token():
# model_schema_mock = MagicMock(spec=AIModelEntity)
# parameter_rule_mock = MagicMock(spec=ParameterRule)
# parameter_rule_mock.name = "max_tokens"
# model_schema_mock.parameter_rules = [parameter_rule_mock]
# model_schema_mock.model_properties = {ModelPropertyKey.CONTEXT_SIZE: 62}
# large_language_model_mock = MagicMock(spec=LargeLanguageModel)
# large_language_model_mock.get_num_tokens.return_value = 6
# provider_mock = MagicMock(spec=ProviderEntity)
# provider_mock.provider = "openai"
# provider_configuration_mock = MagicMock(spec=ProviderConfiguration)
# provider_configuration_mock.provider = provider_mock
# provider_configuration_mock.model_settings = None
# provider_model_bundle_mock = MagicMock(spec=ProviderModelBundle)
# provider_model_bundle_mock.model_type_instance = large_language_model_mock
# provider_model_bundle_mock.configuration = provider_configuration_mock
# model_config_mock = MagicMock(spec=ModelConfigEntity)
# model_config_mock.model = "gpt-4"
# model_config_mock.credentials = {}
# model_config_mock.parameters = {"max_tokens": 50}
# model_config_mock.model_schema = model_schema_mock
# model_config_mock.provider_model_bundle = provider_model_bundle_mock
# prompt_transform = PromptTransform()
# prompt_messages = [UserPromptMessage(content="Hello, how are you?")]
# rest_tokens = prompt_transform._calculate_rest_token(prompt_messages, model_config_mock)
# # Validate based on the mock configuration and expected logic
# expected_rest_tokens = (
# model_schema_mock.model_properties[ModelPropertyKey.CONTEXT_SIZE]
# - model_config_mock.parameters["max_tokens"]
# - large_language_model_mock.get_num_tokens.return_value
# )
# assert rest_tokens == expected_rest_tokens
# assert rest_tokens == 6


@@ -0,0 +1,247 @@
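# Tests for SimplePromptTransform: prompt template selection per app mode,
# provider, and feature flags, plus chat/completion prompt message assembly.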
from unittest.mock import MagicMock
from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
from core.memory.token_buffer_memory import TokenBufferMemory
from core.model_runtime.entities.message_entities import AssistantPromptMessage, UserPromptMessage
from core.prompt.simple_prompt_transform import SimplePromptTransform
from models.model import AppMode, Conversation
def test_get_common_chat_app_prompt_template_with_pcqm():
prompt_transform = SimplePromptTransform()
pre_prompt = "You are a helpful assistant."
prompt_template = prompt_transform.get_prompt_template(
app_mode=AppMode.CHAT,
provider="openai",
model="gpt-4",
pre_prompt=pre_prompt,
has_context=True,
query_in_prompt=True,
with_memory_prompt=True,
)
prompt_rules = prompt_template["prompt_rules"]
assert prompt_template["prompt_template"].template == (
prompt_rules["context_prompt"]
+ pre_prompt
+ "\n"
+ prompt_rules["histories_prompt"]
+ prompt_rules["query_prompt"]
)
assert prompt_template["special_variable_keys"] == ["#context#", "#histories#", "#query#"]
def test_get_baichuan_chat_app_prompt_template_with_pcqm():
prompt_transform = SimplePromptTransform()
pre_prompt = "You are a helpful assistant."
prompt_template = prompt_transform.get_prompt_template(
app_mode=AppMode.CHAT,
provider="baichuan",
model="Baichuan2-53B",
pre_prompt=pre_prompt,
has_context=True,
query_in_prompt=True,
with_memory_prompt=True,
)
prompt_rules = prompt_template["prompt_rules"]
assert prompt_template["prompt_template"].template == (
prompt_rules["context_prompt"]
+ pre_prompt
+ "\n"
+ prompt_rules["histories_prompt"]
+ prompt_rules["query_prompt"]
)
assert prompt_template["special_variable_keys"] == ["#context#", "#histories#", "#query#"]
def test_get_common_completion_app_prompt_template_with_pcq():
prompt_transform = SimplePromptTransform()
pre_prompt = "You are a helpful assistant."
prompt_template = prompt_transform.get_prompt_template(
app_mode=AppMode.WORKFLOW,
provider="openai",
model="gpt-4",
pre_prompt=pre_prompt,
has_context=True,
query_in_prompt=True,
with_memory_prompt=False,
)
prompt_rules = prompt_template["prompt_rules"]
assert prompt_template["prompt_template"].template == (
prompt_rules["context_prompt"] + pre_prompt + "\n" + prompt_rules["query_prompt"]
)
assert prompt_template["special_variable_keys"] == ["#context#", "#query#"]
def test_get_baichuan_completion_app_prompt_template_with_pcq():
prompt_transform = SimplePromptTransform()
pre_prompt = "You are a helpful assistant."
prompt_template = prompt_transform.get_prompt_template(
app_mode=AppMode.WORKFLOW,
provider="baichuan",
model="Baichuan2-53B",
pre_prompt=pre_prompt,
has_context=True,
query_in_prompt=True,
with_memory_prompt=False,
)
prompt_rules = prompt_template["prompt_rules"]
assert prompt_template["prompt_template"].template == (
prompt_rules["context_prompt"] + pre_prompt + "\n" + prompt_rules["query_prompt"]
)
assert prompt_template["special_variable_keys"] == ["#context#", "#query#"]
def test_get_common_chat_app_prompt_template_with_q():
prompt_transform = SimplePromptTransform()
pre_prompt = ""
prompt_template = prompt_transform.get_prompt_template(
app_mode=AppMode.CHAT,
provider="openai",
model="gpt-4",
pre_prompt=pre_prompt,
has_context=False,
query_in_prompt=True,
with_memory_prompt=False,
)
prompt_rules = prompt_template["prompt_rules"]
assert prompt_template["prompt_template"].template == prompt_rules["query_prompt"]
assert prompt_template["special_variable_keys"] == ["#query#"]
def test_get_common_chat_app_prompt_template_with_cq():
prompt_transform = SimplePromptTransform()
pre_prompt = ""
prompt_template = prompt_transform.get_prompt_template(
app_mode=AppMode.CHAT,
provider="openai",
model="gpt-4",
pre_prompt=pre_prompt,
has_context=True,
query_in_prompt=True,
with_memory_prompt=False,
)
prompt_rules = prompt_template["prompt_rules"]
assert prompt_template["prompt_template"].template == (
prompt_rules["context_prompt"] + prompt_rules["query_prompt"]
)
assert prompt_template["special_variable_keys"] == ["#context#", "#query#"]
def test_get_common_chat_app_prompt_template_with_p():
prompt_transform = SimplePromptTransform()
pre_prompt = "you are {{name}}"
prompt_template = prompt_transform.get_prompt_template(
app_mode=AppMode.CHAT,
provider="openai",
model="gpt-4",
pre_prompt=pre_prompt,
has_context=False,
query_in_prompt=False,
with_memory_prompt=False,
)
assert prompt_template["prompt_template"].template == pre_prompt + "\n"
assert prompt_template["custom_variable_keys"] == ["name"]
assert prompt_template["special_variable_keys"] == []
def test__get_chat_model_prompt_messages():
model_config_mock = MagicMock(spec=ModelConfigWithCredentialsEntity)
model_config_mock.provider = "openai"
model_config_mock.model = "gpt-4"
memory_mock = MagicMock(spec=TokenBufferMemory)
history_prompt_messages = [UserPromptMessage(content="Hi"), AssistantPromptMessage(content="Hello")]
memory_mock.get_history_prompt_messages.return_value = history_prompt_messages
prompt_transform = SimplePromptTransform()
prompt_transform._calculate_rest_token = MagicMock(return_value=2000)
pre_prompt = "You are a helpful assistant {{name}}."
inputs = {"name": "John"}
context = "yes or no."
query = "How are you?"
prompt_messages, _ = prompt_transform._get_chat_model_prompt_messages(
app_mode=AppMode.CHAT,
pre_prompt=pre_prompt,
inputs=inputs,
query=query,
files=[],
context=context,
memory=memory_mock,
model_config=model_config_mock,
)
prompt_template = prompt_transform.get_prompt_template(
app_mode=AppMode.CHAT,
provider=model_config_mock.provider,
model=model_config_mock.model,
pre_prompt=pre_prompt,
has_context=True,
query_in_prompt=False,
with_memory_prompt=False,
)
full_inputs = {**inputs, "#context#": context}
real_system_prompt = prompt_template["prompt_template"].format(full_inputs)
assert len(prompt_messages) == 4
assert prompt_messages[0].content == real_system_prompt
assert prompt_messages[1].content == history_prompt_messages[0].content
assert prompt_messages[2].content == history_prompt_messages[1].content
assert prompt_messages[3].content == query
def test__get_completion_model_prompt_messages():
model_config_mock = MagicMock(spec=ModelConfigWithCredentialsEntity)
model_config_mock.provider = "openai"
model_config_mock.model = "gpt-3.5-turbo-instruct"
memory = TokenBufferMemory(conversation=Conversation(), model_instance=model_config_mock)
history_prompt_messages = [UserPromptMessage(content="Hi"), AssistantPromptMessage(content="Hello")]
memory.get_history_prompt_messages = MagicMock(return_value=history_prompt_messages)
prompt_transform = SimplePromptTransform()
prompt_transform._calculate_rest_token = MagicMock(return_value=2000)
pre_prompt = "You are a helpful assistant {{name}}."
inputs = {"name": "John"}
context = "yes or no."
query = "How are you?"
prompt_messages, stops = prompt_transform._get_completion_model_prompt_messages(
app_mode=AppMode.CHAT,
pre_prompt=pre_prompt,
inputs=inputs,
query=query,
files=[],
context=context,
memory=memory,
model_config=model_config_mock,
)
prompt_template = prompt_transform.get_prompt_template(
app_mode=AppMode.CHAT,
provider=model_config_mock.provider,
model=model_config_mock.model,
pre_prompt=pre_prompt,
has_context=True,
query_in_prompt=True,
with_memory_prompt=True,
)
prompt_rules = prompt_template["prompt_rules"]
full_inputs = {
**inputs,
"#context#": context,
"#query#": query,
"#histories#": memory.get_history_prompt_text(
max_token_limit=2000,
human_prefix=prompt_rules.get("human_prefix", "Human"),
ai_prefix=prompt_rules.get("assistant_prefix", "Assistant"),
),
}
real_prompt = prompt_template["prompt_template"].format(full_inputs)
assert len(prompt_messages) == 1
assert stops == prompt_rules.get("stops")
assert prompt_messages[0].content == real_prompt


@@ -0,0 +1,18 @@
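# Tests for MilvusConfig: required-field validation errors and the
# default database name.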
import pytest
from pydantic import ValidationError
from core.rag.datasource.vdb.milvus.milvus_vector import MilvusConfig
def test_default_value():
valid_config = {"uri": "http://localhost:19530", "user": "root", "password": "Milvus"}
for key in valid_config:
config = valid_config.copy()
del config[key]
with pytest.raises(ValidationError) as e:
MilvusConfig(**config)
assert e.value.errors()[0]["msg"] == f"Value error, config MILVUS_{key.upper()} is required"
config = MilvusConfig(**valid_config)
assert config.database == "default"


@@ -0,0 +1,26 @@
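# Test for FirecrawlApp.crawl_url in crawl mode, with requests.post mocked.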
import os
from core.rag.extractor.firecrawl.firecrawl_app import FirecrawlApp
from tests.unit_tests.core.rag.extractor.test_notion_extractor import _mock_response
def test_firecrawl_web_extractor_crawl_mode(mocker):
url = "https://firecrawl.dev"
api_key = os.getenv("FIRECRAWL_API_KEY") or "fc-"
base_url = "https://api.firecrawl.dev"
firecrawl_app = FirecrawlApp(api_key=api_key, base_url=base_url)
params = {
"includePaths": [],
"excludePaths": [],
"maxDepth": 1,
"limit": 1,
}
mocked_firecrawl = {
"id": "test",
}
mocker.patch("requests.post", return_value=_mock_response(mocked_firecrawl))
job_id = firecrawl_app.crawl_url(url, params)
print(f"job_id: {job_id}")
assert job_id is not None
assert isinstance(job_id, str)


@@ -0,0 +1,91 @@
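# Tests for NotionExtractor: extracting a page and a database from mocked
# Notion API responses.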
from unittest import mock
from core.rag.extractor import notion_extractor
user_id = "user1"
database_id = "database1"
page_id = "page1"
extractor = notion_extractor.NotionExtractor(
notion_workspace_id="x", notion_obj_id="x", notion_page_type="page", tenant_id="x", notion_access_token="x"
)
def _generate_page(page_title: str):
return {
"object": "page",
"id": page_id,
"properties": {
"Page": {
"type": "title",
"title": [{"type": "text", "text": {"content": page_title}, "plain_text": page_title}],
}
},
}
def _generate_block(block_id: str, block_type: str, block_text: str):
return {
"object": "block",
"id": block_id,
"parent": {"type": "page_id", "page_id": page_id},
"type": block_type,
"has_children": False,
block_type: {
"rich_text": [
{
"type": "text",
"text": {"content": block_text},
"plain_text": block_text,
}
]
},
}
def _mock_response(data):
response = mock.Mock()
response.status_code = 200
response.json.return_value = data
return response
def _remove_multiple_new_lines(text):
while "\n\n" in text:
text = text.replace("\n\n", "\n")
return text.strip()
def test_notion_page(mocker):
texts = ["Head 1", "1.1", "paragraph 1", "1.1.1"]
mocked_notion_page = {
"object": "list",
"results": [
_generate_block("b1", "heading_1", texts[0]),
_generate_block("b2", "heading_2", texts[1]),
_generate_block("b3", "paragraph", texts[2]),
_generate_block("b4", "heading_3", texts[3]),
],
"next_cursor": None,
}
mocker.patch("requests.request", return_value=_mock_response(mocked_notion_page))
page_docs = extractor._load_data_as_documents(page_id, "page")
assert len(page_docs) == 1
content = _remove_multiple_new_lines(page_docs[0].page_content)
assert content == "# Head 1\n## 1.1\nparagraph 1\n### 1.1.1"
def test_notion_database(mocker):
page_title_list = ["page1", "page2", "page3"]
mocked_notion_database = {
"object": "list",
"results": [_generate_page(i) for i in page_title_list],
"next_cursor": None,
}
mocker.patch("requests.post", return_value=_mock_response(mocked_notion_database))
database_docs = extractor._load_data_as_documents(database_id, "database")
assert len(database_docs) == 1
content = _remove_multiple_new_lines(database_docs[0].page_content)
assert content == "\n".join([f"Page:{i}" for i in page_title_list])


@@ -0,0 +1,56 @@
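# Tests for File.to_dict and for converting a legacy image-only file_upload
# feature dict through the Workflow.features property.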
import json
from core.file import File, FileTransferMethod, FileType, FileUploadConfig
from models.workflow import Workflow
def test_file_to_dict():
file = File(
id="file1",
tenant_id="tenant1",
type=FileType.IMAGE,
transfer_method=FileTransferMethod.REMOTE_URL,
remote_url="https://example.com/image1.jpg",
storage_key="storage_key",
)
file_dict = file.to_dict()
assert "_storage_key" not in file_dict
assert "url" in file_dict
def test_workflow_features_with_image():
# Create a feature dict that mimics the old structure with image config
features = {
"file_upload": {
"image": {"enabled": True, "number_limits": 5, "transfer_methods": ["remote_url", "local_file"]}
}
}
# Create a workflow instance with the features
workflow = Workflow(
tenant_id="tenant-1",
app_id="app-1",
type="chat",
version="1.0",
graph="{}",
features=json.dumps(features),
created_by="user-1",
environment_variables=[],
conversation_variables=[],
)
# Get the converted features through the property
converted_features = json.loads(workflow.features)
# Create FileUploadConfig from the converted features
file_upload_config = FileUploadConfig.model_validate(converted_features["file_upload"])
# Validate the config
assert file_upload_config.number_limits == 5
assert list(file_upload_config.allowed_file_types) == [FileType.IMAGE]
assert list(file_upload_config.allowed_file_upload_methods) == [
FileTransferMethod.REMOTE_URL,
FileTransferMethod.LOCAL_FILE,
]
assert list(file_upload_config.allowed_file_extensions) == []


@@ -0,0 +1,72 @@
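# Tests for LBModelManager.fetch_next: round-robin selection that skips
# configurations in cooldown, with Redis operations patched.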
from unittest.mock import MagicMock, patch
import pytest
import redis
from core.entities.provider_entities import ModelLoadBalancingConfiguration
from core.model_manager import LBModelManager
from core.model_runtime.entities.model_entities import ModelType
from extensions.ext_redis import redis_client
@pytest.fixture
def lb_model_manager():
load_balancing_configs = [
ModelLoadBalancingConfiguration(id="id1", name="__inherit__", credentials={}),
ModelLoadBalancingConfiguration(id="id2", name="first", credentials={"openai_api_key": "fake_key"}),
ModelLoadBalancingConfiguration(id="id3", name="second", credentials={"openai_api_key": "fake_key"}),
]
lb_model_manager = LBModelManager(
tenant_id="tenant_id",
provider="openai",
model_type=ModelType.LLM,
model="gpt-4",
load_balancing_configs=load_balancing_configs,
managed_credentials={"openai_api_key": "fake_key"},
)
lb_model_manager.cooldown = MagicMock(return_value=None)
def is_cooldown(config: ModelLoadBalancingConfiguration):
if config.id == "id1":
return True
return False
lb_model_manager.in_cooldown = MagicMock(side_effect=is_cooldown)
return lb_model_manager
def test_lb_model_manager_fetch_next(mocker, lb_model_manager):
# initialize redis client
redis_client.initialize(redis.Redis())
assert len(lb_model_manager._load_balancing_configs) == 3
config1 = lb_model_manager._load_balancing_configs[0]
config2 = lb_model_manager._load_balancing_configs[1]
config3 = lb_model_manager._load_balancing_configs[2]
assert lb_model_manager.in_cooldown(config1) is True
assert lb_model_manager.in_cooldown(config2) is False
assert lb_model_manager.in_cooldown(config3) is False
start_index = 0
def incr(key):
nonlocal start_index
start_index += 1
return start_index
with (
patch.object(redis_client, "incr", side_effect=incr),
patch.object(redis_client, "set", return_value=None),
patch.object(redis_client, "expire", return_value=None),
):
config = lb_model_manager.fetch_next()
assert config == config2
config = lb_model_manager.fetch_next()
assert config == config3


@@ -0,0 +1,190 @@
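# NOTE: this entire module is commented out in this commit; it exercised
# ProviderManager._to_model_settings.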
# from core.entities.provider_entities import ModelSettings
# from core.model_runtime.entities.model_entities import ModelType
# from core.model_runtime.model_providers.model_provider_factory import ModelProviderFactory
# from core.provider_manager import ProviderManager
# from models.provider import LoadBalancingModelConfig, ProviderModelSetting
# def test__to_model_settings(mocker):
# # Get all provider entities
# model_provider_factory = ModelProviderFactory("test_tenant")
# provider_entities = model_provider_factory.get_providers()
# provider_entity = None
# for provider in provider_entities:
# if provider.provider == "openai":
# provider_entity = provider
# # Mocking the inputs
# provider_model_settings = [
# ProviderModelSetting(
# id="id",
# tenant_id="tenant_id",
# provider_name="openai",
# model_name="gpt-4",
# model_type="text-generation",
# enabled=True,
# load_balancing_enabled=True,
# )
# ]
# load_balancing_model_configs = [
# LoadBalancingModelConfig(
# id="id1",
# tenant_id="tenant_id",
# provider_name="openai",
# model_name="gpt-4",
# model_type="text-generation",
# name="__inherit__",
# encrypted_config=None,
# enabled=True,
# ),
# LoadBalancingModelConfig(
# id="id2",
# tenant_id="tenant_id",
# provider_name="openai",
# model_name="gpt-4",
# model_type="text-generation",
# name="first",
# encrypted_config='{"openai_api_key": "fake_key"}',
# enabled=True,
# ),
# ]
# mocker.patch(
# "core.helper.model_provider_cache.ProviderCredentialsCache.get", return_value={"openai_api_key": "fake_key"}
# )
# provider_manager = ProviderManager()
# # Running the method
# result = provider_manager._to_model_settings(provider_entity,
# provider_model_settings, load_balancing_model_configs)
# # Asserting that the result is as expected
# assert len(result) == 1
# assert isinstance(result[0], ModelSettings)
# assert result[0].model == "gpt-4"
# assert result[0].model_type == ModelType.LLM
# assert result[0].enabled is True
# assert len(result[0].load_balancing_configs) == 2
# assert result[0].load_balancing_configs[0].name == "__inherit__"
# assert result[0].load_balancing_configs[1].name == "first"
# def test__to_model_settings_only_one_lb(mocker):
# # Get all provider entities
# model_provider_factory = ModelProviderFactory("test_tenant")
# provider_entities = model_provider_factory.get_providers()
# provider_entity = None
# for provider in provider_entities:
# if provider.provider == "openai":
# provider_entity = provider
# # Mocking the inputs
# provider_model_settings = [
# ProviderModelSetting(
# id="id",
# tenant_id="tenant_id",
# provider_name="openai",
# model_name="gpt-4",
# model_type="text-generation",
# enabled=True,
# load_balancing_enabled=True,
# )
# ]
# load_balancing_model_configs = [
# LoadBalancingModelConfig(
# id="id1",
# tenant_id="tenant_id",
# provider_name="openai",
# model_name="gpt-4",
# model_type="text-generation",
# name="__inherit__",
# encrypted_config=None,
# enabled=True,
# )
# ]
# mocker.patch(
# "core.helper.model_provider_cache.ProviderCredentialsCache.get", return_value={"openai_api_key": "fake_key"}
# )
# provider_manager = ProviderManager()
# # Running the method
# result = provider_manager._to_model_settings(
# provider_entity, provider_model_settings, load_balancing_model_configs)
# # Asserting that the result is as expected
# assert len(result) == 1
# assert isinstance(result[0], ModelSettings)
# assert result[0].model == "gpt-4"
# assert result[0].model_type == ModelType.LLM
# assert result[0].enabled is True
# assert len(result[0].load_balancing_configs) == 0
# def test__to_model_settings_lb_disabled(mocker):
# # Get all provider entities
# model_provider_factory = ModelProviderFactory("test_tenant")
# provider_entities = model_provider_factory.get_providers()
# provider_entity = None
# for provider in provider_entities:
# if provider.provider == "openai":
# provider_entity = provider
# # Mocking the inputs
# provider_model_settings = [
# ProviderModelSetting(
# id="id",
# tenant_id="tenant_id",
# provider_name="openai",
# model_name="gpt-4",
# model_type="text-generation",
# enabled=True,
# load_balancing_enabled=False,
# )
# ]
# load_balancing_model_configs = [
# LoadBalancingModelConfig(
# id="id1",
# tenant_id="tenant_id",
# provider_name="openai",
# model_name="gpt-4",
# model_type="text-generation",
# name="__inherit__",
# encrypted_config=None,
# enabled=True,
# ),
# LoadBalancingModelConfig(
# id="id2",
# tenant_id="tenant_id",
# provider_name="openai",
# model_name="gpt-4",
# model_type="text-generation",
# name="first",
# encrypted_config='{"openai_api_key": "fake_key"}',
# enabled=True,
# ),
# ]
# mocker.patch(
# "core.helper.model_provider_cache.ProviderCredentialsCache.get",
# return_value={"openai_api_key": "fake_key"}
# )
# provider_manager = ProviderManager()
# # Running the method
# result = provider_manager._to_model_settings(provider_entity,
# provider_model_settings, load_balancing_model_configs)
# # Asserting that the result is as expected
# assert len(result) == 1
# assert isinstance(result[0], ModelSettings)
# assert result[0].model == "gpt-4"
# assert result[0].model_type == ModelType.LLM
# assert result[0].enabled is True
# assert len(result[0].load_balancing_configs) == 0


@@ -0,0 +1,49 @@
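# Tests for ToolParameter: as_normal_type mappings and cast_value coercion
# for string, secret, select, boolean, and number types.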
from core.tools.entities.tool_entities import ToolParameter
def test_get_parameter_type():
assert ToolParameter.ToolParameterType.STRING.as_normal_type() == "string"
assert ToolParameter.ToolParameterType.SELECT.as_normal_type() == "string"
assert ToolParameter.ToolParameterType.SECRET_INPUT.as_normal_type() == "string"
assert ToolParameter.ToolParameterType.BOOLEAN.as_normal_type() == "boolean"
assert ToolParameter.ToolParameterType.NUMBER.as_normal_type() == "number"
assert ToolParameter.ToolParameterType.FILE.as_normal_type() == "file"
assert ToolParameter.ToolParameterType.FILES.as_normal_type() == "files"
def test_cast_parameter_by_type():
# string
assert ToolParameter.ToolParameterType.STRING.cast_value("test") == "test"
assert ToolParameter.ToolParameterType.STRING.cast_value(1) == "1"
assert ToolParameter.ToolParameterType.STRING.cast_value(1.0) == "1.0"
assert ToolParameter.ToolParameterType.STRING.cast_value(None) == ""
# secret input
assert ToolParameter.ToolParameterType.SECRET_INPUT.cast_value("test") == "test"
assert ToolParameter.ToolParameterType.SECRET_INPUT.cast_value(1) == "1"
assert ToolParameter.ToolParameterType.SECRET_INPUT.cast_value(1.0) == "1.0"
assert ToolParameter.ToolParameterType.SECRET_INPUT.cast_value(None) == ""
# select
assert ToolParameter.ToolParameterType.SELECT.cast_value("test") == "test"
assert ToolParameter.ToolParameterType.SELECT.cast_value(1) == "1"
assert ToolParameter.ToolParameterType.SELECT.cast_value(1.0) == "1.0"
assert ToolParameter.ToolParameterType.SELECT.cast_value(None) == ""
# boolean
true_values = [True, "True", "true", "1", "YES", "Yes", "yes", "y", "something"]
for value in true_values:
assert ToolParameter.ToolParameterType.BOOLEAN.cast_value(value) is True
false_values = [False, "False", "false", "0", "NO", "No", "no", "n", None, ""]
for value in false_values:
assert ToolParameter.ToolParameterType.BOOLEAN.cast_value(value) is False
# number
assert ToolParameter.ToolParameterType.NUMBER.cast_value("1") == 1
assert ToolParameter.ToolParameterType.NUMBER.cast_value("1.0") == 1.0
assert ToolParameter.ToolParameterType.NUMBER.cast_value("-1.0") == -1.0
assert ToolParameter.ToolParameterType.NUMBER.cast_value(1) == 1
assert ToolParameter.ToolParameterType.NUMBER.cast_value(1.0) == 1.0
assert ToolParameter.ToolParameterType.NUMBER.cast_value(-1.0) == -1.0
assert ToolParameter.ToolParameterType.NUMBER.cast_value(None) is None


@@ -0,0 +1,791 @@
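# Tests for Graph.init: root node detection, edge mappings, iteration
# subgraphs with extra edges, and parallel branch detection.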
from core.workflow.graph_engine.entities.graph import Graph
from core.workflow.graph_engine.entities.run_condition import RunCondition
from core.workflow.utils.condition.entities import Condition
def test_init():
graph_config = {
"edges": [
{
"id": "llm-source-answer-target",
"source": "llm",
"target": "answer",
},
{
"id": "start-source-qc-target",
"source": "start",
"target": "qc",
},
{
"id": "qc-1-llm-target",
"source": "qc",
"sourceHandle": "1",
"target": "llm",
},
{
"id": "qc-2-http-target",
"source": "qc",
"sourceHandle": "2",
"target": "http",
},
{
"id": "http-source-answer2-target",
"source": "http",
"target": "answer2",
},
],
"nodes": [
{"data": {"type": "start"}, "id": "start"},
{
"data": {
"type": "llm",
},
"id": "llm",
},
{
"data": {"type": "answer", "title": "answer", "answer": "1"},
"id": "answer",
},
{
"data": {"type": "question-classifier"},
"id": "qc",
},
{
"data": {
"type": "http-request",
},
"id": "http",
},
{
"data": {"type": "answer", "title": "answer", "answer": "1"},
"id": "answer2",
},
],
}
graph = Graph.init(graph_config=graph_config)
start_node_id = "start"
assert graph.root_node_id == start_node_id
assert graph.edge_mapping.get(start_node_id)[0].target_node_id == "qc"
assert {"llm", "http"} == {node.target_node_id for node in graph.edge_mapping.get("qc")}
def test__init_iteration_graph():
graph_config = {
"edges": [
{
"id": "llm-answer",
"source": "llm",
"sourceHandle": "source",
"target": "answer",
},
{
"id": "iteration-source-llm-target",
"source": "iteration",
"sourceHandle": "source",
"target": "llm",
},
{
"id": "template-transform-in-iteration-source-llm-in-iteration-target",
"source": "template-transform-in-iteration",
"sourceHandle": "source",
"target": "llm-in-iteration",
},
{
"id": "llm-in-iteration-source-answer-in-iteration-target",
"source": "llm-in-iteration",
"sourceHandle": "source",
"target": "answer-in-iteration",
},
{
"id": "start-source-code-target",
"source": "start",
"sourceHandle": "source",
"target": "code",
},
{
"id": "code-source-iteration-target",
"source": "code",
"sourceHandle": "source",
"target": "iteration",
},
],
"nodes": [
{
"data": {
"type": "start",
},
"id": "start",
},
{
"data": {
"type": "llm",
},
"id": "llm",
},
{
"data": {"type": "answer", "title": "answer", "answer": "1"},
"id": "answer",
},
{
"data": {"type": "iteration"},
"id": "iteration",
},
{
"data": {
"type": "template-transform",
},
"id": "template-transform-in-iteration",
"parentId": "iteration",
},
{
"data": {
"type": "llm",
},
"id": "llm-in-iteration",
"parentId": "iteration",
},
{
"data": {"type": "answer", "title": "answer", "answer": "1"},
"id": "answer-in-iteration",
"parentId": "iteration",
},
{
"data": {
"type": "code",
},
"id": "code",
},
],
}
graph = Graph.init(graph_config=graph_config, root_node_id="template-transform-in-iteration")
graph.add_extra_edge(
source_node_id="answer-in-iteration",
target_node_id="template-transform-in-iteration",
run_condition=RunCondition(
type="condition",
            conditions=[Condition(variable_selector=["iteration", "index"], comparison_operator="≤", value="5")],
),
)
# iteration:
# [template-transform-in-iteration -> llm-in-iteration -> answer-in-iteration]
assert graph.root_node_id == "template-transform-in-iteration"
assert graph.edge_mapping.get("template-transform-in-iteration")[0].target_node_id == "llm-in-iteration"
assert graph.edge_mapping.get("llm-in-iteration")[0].target_node_id == "answer-in-iteration"
assert graph.edge_mapping.get("answer-in-iteration")[0].target_node_id == "template-transform-in-iteration"
def test_parallels_graph():
graph_config = {
"edges": [
{
"id": "start-source-llm1-target",
"source": "start",
"target": "llm1",
},
{
"id": "start-source-llm2-target",
"source": "start",
"target": "llm2",
},
{
"id": "start-source-llm3-target",
"source": "start",
"target": "llm3",
},
{
"id": "llm1-source-answer-target",
"source": "llm1",
"target": "answer",
},
{
"id": "llm2-source-answer-target",
"source": "llm2",
"target": "answer",
},
{
"id": "llm3-source-answer-target",
"source": "llm3",
"target": "answer",
},
],
"nodes": [
{"data": {"type": "start"}, "id": "start"},
{
"data": {
"type": "llm",
},
"id": "llm1",
},
{
"data": {
"type": "llm",
},
"id": "llm2",
},
{
"data": {
"type": "llm",
},
"id": "llm3",
},
{
"data": {"type": "answer", "title": "answer", "answer": "1"},
"id": "answer",
},
],
}
graph = Graph.init(graph_config=graph_config)
assert graph.root_node_id == "start"
for i in range(3):
start_edges = graph.edge_mapping.get("start")
assert start_edges is not None
assert start_edges[i].target_node_id == f"llm{i + 1}"
llm_edges = graph.edge_mapping.get(f"llm{i + 1}")
assert llm_edges is not None
assert llm_edges[0].target_node_id == "answer"
assert len(graph.parallel_mapping) == 1
assert len(graph.node_parallel_mapping) == 3
for node_id in ["llm1", "llm2", "llm3"]:
assert node_id in graph.node_parallel_mapping
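
# Sketch (an assumed reading, not the real detection code) of the fan-out rule
# these parallel tests exercise: any node whose outgoing edges reach more than
# one target opens a parallel, and its branch nodes land in node_parallel_mapping.
def find_fan_outs_sketch(graph_config: dict) -> dict[str, list[str]]:
    targets_by_source: dict[str, list[str]] = {}
    for edge in graph_config["edges"]:
        targets_by_source.setdefault(edge["source"], []).append(edge["target"])
    return {src: tgts for src, tgts in targets_by_source.items() if len(tgts) > 1}
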
def test_parallels_graph2():
graph_config = {
"edges": [
{
"id": "start-source-llm1-target",
"source": "start",
"target": "llm1",
},
{
"id": "start-source-llm2-target",
"source": "start",
"target": "llm2",
},
{
"id": "start-source-llm3-target",
"source": "start",
"target": "llm3",
},
{
"id": "llm1-source-answer-target",
"source": "llm1",
"target": "answer",
},
{
"id": "llm2-source-answer-target",
"source": "llm2",
"target": "answer",
},
],
"nodes": [
{"data": {"type": "start"}, "id": "start"},
{
"data": {
"type": "llm",
},
"id": "llm1",
},
{
"data": {
"type": "llm",
},
"id": "llm2",
},
{
"data": {
"type": "llm",
},
"id": "llm3",
},
{
"data": {"type": "answer", "title": "answer", "answer": "1"},
"id": "answer",
},
],
}
graph = Graph.init(graph_config=graph_config)
assert graph.root_node_id == "start"
for i in range(3):
assert graph.edge_mapping.get("start")[i].target_node_id == f"llm{i + 1}"
if i < 2:
assert graph.edge_mapping.get(f"llm{i + 1}") is not None
assert graph.edge_mapping.get(f"llm{i + 1}")[0].target_node_id == "answer"
assert len(graph.parallel_mapping) == 1
assert len(graph.node_parallel_mapping) == 3
for node_id in ["llm1", "llm2", "llm3"]:
assert node_id in graph.node_parallel_mapping
def test_parallels_graph3():
graph_config = {
"edges": [
{
"id": "start-source-llm1-target",
"source": "start",
"target": "llm1",
},
{
"id": "start-source-llm2-target",
"source": "start",
"target": "llm2",
},
{
"id": "start-source-llm3-target",
"source": "start",
"target": "llm3",
},
],
"nodes": [
{"data": {"type": "start"}, "id": "start"},
{
"data": {
"type": "llm",
},
"id": "llm1",
},
{
"data": {
"type": "llm",
},
"id": "llm2",
},
{
"data": {
"type": "llm",
},
"id": "llm3",
},
{
"data": {"type": "answer", "title": "answer", "answer": "1"},
"id": "answer",
},
],
}
graph = Graph.init(graph_config=graph_config)
assert graph.root_node_id == "start"
for i in range(3):
assert graph.edge_mapping.get("start")[i].target_node_id == f"llm{i + 1}"
assert len(graph.parallel_mapping) == 1
assert len(graph.node_parallel_mapping) == 3
for node_id in ["llm1", "llm2", "llm3"]:
assert node_id in graph.node_parallel_mapping
def test_parallels_graph4():
graph_config = {
"edges": [
{
"id": "start-source-llm1-target",
"source": "start",
"target": "llm1",
},
{
"id": "start-source-llm2-target",
"source": "start",
"target": "llm2",
},
{
"id": "start-source-llm3-target",
"source": "start",
"target": "llm3",
},
{
"id": "llm1-source-answer-target",
"source": "llm1",
"target": "code1",
},
{
"id": "llm2-source-answer-target",
"source": "llm2",
"target": "code2",
},
{
"id": "llm3-source-code3-target",
"source": "llm3",
"target": "code3",
},
{
"id": "code1-source-answer-target",
"source": "code1",
"target": "answer",
},
{
"id": "code2-source-answer-target",
"source": "code2",
"target": "answer",
},
{
"id": "code3-source-answer-target",
"source": "code3",
"target": "answer",
},
],
"nodes": [
{"data": {"type": "start"}, "id": "start"},
{
"data": {
"type": "llm",
},
"id": "llm1",
},
{
"data": {
"type": "code",
},
"id": "code1",
},
{
"data": {
"type": "llm",
},
"id": "llm2",
},
{
"data": {
"type": "code",
},
"id": "code2",
},
{
"data": {
"type": "llm",
},
"id": "llm3",
},
{
"data": {
"type": "code",
},
"id": "code3",
},
{
"data": {"type": "answer", "title": "answer", "answer": "1"},
"id": "answer",
},
],
}
graph = Graph.init(graph_config=graph_config)
assert graph.root_node_id == "start"
for i in range(3):
assert graph.edge_mapping.get("start")[i].target_node_id == f"llm{i + 1}"
assert graph.edge_mapping.get(f"llm{i + 1}") is not None
assert graph.edge_mapping.get(f"llm{i + 1}")[0].target_node_id == f"code{i + 1}"
assert graph.edge_mapping.get(f"code{i + 1}") is not None
assert graph.edge_mapping.get(f"code{i + 1}")[0].target_node_id == "answer"
assert len(graph.parallel_mapping) == 1
assert len(graph.node_parallel_mapping) == 6
for node_id in ["llm1", "llm2", "llm3", "code1", "code2", "code3"]:
assert node_id in graph.node_parallel_mapping
def test_parallels_graph5():
graph_config = {
"edges": [
{
"id": "start-source-llm1-target",
"source": "start",
"target": "llm1",
},
{
"id": "start-source-llm2-target",
"source": "start",
"target": "llm2",
},
{
"id": "start-source-llm3-target",
"source": "start",
"target": "llm3",
},
{
"id": "start-source-llm3-target",
"source": "start",
"target": "llm4",
},
{
"id": "start-source-llm3-target",
"source": "start",
"target": "llm5",
},
{
"id": "llm1-source-code1-target",
"source": "llm1",
"target": "code1",
},
{
"id": "llm2-source-code1-target",
"source": "llm2",
"target": "code1",
},
{
"id": "llm3-source-code2-target",
"source": "llm3",
"target": "code2",
},
{
"id": "llm4-source-code2-target",
"source": "llm4",
"target": "code2",
},
{
"id": "llm5-source-code3-target",
"source": "llm5",
"target": "code3",
},
{
"id": "code1-source-answer-target",
"source": "code1",
"target": "answer",
},
{
"id": "code2-source-answer-target",
"source": "code2",
"target": "answer",
},
],
"nodes": [
{"data": {"type": "start"}, "id": "start"},
{
"data": {
"type": "llm",
},
"id": "llm1",
},
{
"data": {
"type": "code",
},
"id": "code1",
},
{
"data": {
"type": "llm",
},
"id": "llm2",
},
{
"data": {
"type": "code",
},
"id": "code2",
},
{
"data": {
"type": "llm",
},
"id": "llm3",
},
{
"data": {
"type": "code",
},
"id": "code3",
},
{
"data": {"type": "answer", "title": "answer", "answer": "1"},
"id": "answer",
},
{
"data": {
"type": "llm",
},
"id": "llm4",
},
{
"data": {
"type": "llm",
},
"id": "llm5",
},
],
}
graph = Graph.init(graph_config=graph_config)
assert graph.root_node_id == "start"
for i in range(5):
assert graph.edge_mapping.get("start")[i].target_node_id == f"llm{i + 1}"
assert graph.edge_mapping.get("llm1") is not None
assert graph.edge_mapping.get("llm1")[0].target_node_id == "code1"
assert graph.edge_mapping.get("llm2") is not None
assert graph.edge_mapping.get("llm2")[0].target_node_id == "code1"
assert graph.edge_mapping.get("llm3") is not None
assert graph.edge_mapping.get("llm3")[0].target_node_id == "code2"
assert graph.edge_mapping.get("llm4") is not None
assert graph.edge_mapping.get("llm4")[0].target_node_id == "code2"
assert graph.edge_mapping.get("llm5") is not None
assert graph.edge_mapping.get("llm5")[0].target_node_id == "code3"
assert graph.edge_mapping.get("code1") is not None
assert graph.edge_mapping.get("code1")[0].target_node_id == "answer"
assert graph.edge_mapping.get("code2") is not None
assert graph.edge_mapping.get("code2")[0].target_node_id == "answer"
assert len(graph.parallel_mapping) == 1
assert len(graph.node_parallel_mapping) == 8
for node_id in ["llm1", "llm2", "llm3", "llm4", "llm5", "code1", "code2", "code3"]:
assert node_id in graph.node_parallel_mapping
def test_parallels_graph6():
graph_config = {
"edges": [
{
"id": "start-source-llm1-target",
"source": "start",
"target": "llm1",
},
{
"id": "start-source-llm2-target",
"source": "start",
"target": "llm2",
},
{
"id": "start-source-llm3-target",
"source": "start",
"target": "llm3",
},
{
"id": "llm1-source-code1-target",
"source": "llm1",
"target": "code1",
},
{
"id": "llm1-source-code2-target",
"source": "llm1",
"target": "code2",
},
{
"id": "llm2-source-code3-target",
"source": "llm2",
"target": "code3",
},
{
"id": "code1-source-answer-target",
"source": "code1",
"target": "answer",
},
{
"id": "code2-source-answer-target",
"source": "code2",
"target": "answer",
},
{
"id": "code3-source-answer-target",
"source": "code3",
"target": "answer",
},
{
"id": "llm3-source-answer-target",
"source": "llm3",
"target": "answer",
},
],
"nodes": [
{"data": {"type": "start"}, "id": "start"},
{
"data": {
"type": "llm",
},
"id": "llm1",
},
{
"data": {
"type": "code",
},
"id": "code1",
},
{
"data": {
"type": "llm",
},
"id": "llm2",
},
{
"data": {
"type": "code",
},
"id": "code2",
},
{
"data": {
"type": "llm",
},
"id": "llm3",
},
{
"data": {
"type": "code",
},
"id": "code3",
},
{
"data": {"type": "answer", "title": "answer", "answer": "1"},
"id": "answer",
},
],
}
graph = Graph.init(graph_config=graph_config)
assert graph.root_node_id == "start"
for i in range(3):
assert graph.edge_mapping.get("start")[i].target_node_id == f"llm{i + 1}"
assert graph.edge_mapping.get("llm1") is not None
assert graph.edge_mapping.get("llm1")[0].target_node_id == "code1"
assert graph.edge_mapping.get("llm1") is not None
assert graph.edge_mapping.get("llm1")[1].target_node_id == "code2"
assert graph.edge_mapping.get("llm2") is not None
assert graph.edge_mapping.get("llm2")[0].target_node_id == "code3"
assert graph.edge_mapping.get("code1") is not None
assert graph.edge_mapping.get("code1")[0].target_node_id == "answer"
assert graph.edge_mapping.get("code2") is not None
assert graph.edge_mapping.get("code2")[0].target_node_id == "answer"
assert graph.edge_mapping.get("code3") is not None
assert graph.edge_mapping.get("code3")[0].target_node_id == "answer"
assert len(graph.parallel_mapping) == 2
assert len(graph.node_parallel_mapping) == 6
for node_id in ["llm1", "llm2", "llm3", "code1", "code2", "code3"]:
assert node_id in graph.node_parallel_mapping
    parent_parallel = None
    child_parallel = None
    for parallel in graph.parallel_mapping.values():
        if parallel.parent_parallel_id is None:
            parent_parallel = parallel
        else:
            child_parallel = parallel
    assert parent_parallel is not None
    assert child_parallel is not None
    for node_id in ["llm1", "llm2", "llm3", "code3"]:
        assert graph.node_parallel_mapping[node_id] == parent_parallel.id
    for node_id in ["code1", "code2"]:
        assert graph.node_parallel_mapping[node_id] == child_parallel.id
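
# In this topology llm1 itself fans out to code1 and code2, so those two nodes
# form a nested (child) parallel inside the outer start-driven parallel; that
# is why parallel_mapping has two entries and the parent/child split above holds.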

View File

@@ -0,0 +1,504 @@
from unittest.mock import patch
from core.app.entities.app_invoke_entities import InvokeFrom
from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeRunResult
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.enums import SystemVariableKey
from core.workflow.graph_engine.entities.event import (
BaseNodeEvent,
GraphRunFailedEvent,
GraphRunStartedEvent,
GraphRunSucceededEvent,
NodeRunFailedEvent,
NodeRunStartedEvent,
NodeRunStreamChunkEvent,
NodeRunSucceededEvent,
)
from core.workflow.graph_engine.entities.graph import Graph
from core.workflow.graph_engine.entities.runtime_route_state import RouteNodeState
from core.workflow.graph_engine.graph_engine import GraphEngine
from core.workflow.nodes.event import RunCompletedEvent, RunStreamChunkEvent
from core.workflow.nodes.llm.node import LLMNode
from models.enums import UserFrom
from models.workflow import WorkflowNodeExecutionStatus, WorkflowType
@patch("extensions.ext_database.db.session.remove")
@patch("extensions.ext_database.db.session.close")
def test_run_parallel_in_workflow(mock_close, mock_remove):
graph_config = {
"edges": [
{
"id": "1",
"source": "start",
"target": "llm1",
},
{
"id": "2",
"source": "llm1",
"target": "llm2",
},
{
"id": "3",
"source": "llm1",
"target": "llm3",
},
{
"id": "4",
"source": "llm2",
"target": "end1",
},
{
"id": "5",
"source": "llm3",
"target": "end2",
},
],
"nodes": [
{
"data": {
"type": "start",
"title": "start",
"variables": [
{
"label": "query",
"max_length": 48,
"options": [],
"required": True,
"type": "text-input",
"variable": "query",
}
],
},
"id": "start",
},
{
"data": {
"type": "llm",
"title": "llm1",
"context": {"enabled": False, "variable_selector": []},
"model": {
"completion_params": {"temperature": 0.7},
"mode": "chat",
"name": "gpt-4o",
"provider": "openai",
},
"prompt_template": [
{"role": "system", "text": "say hi"},
{"role": "user", "text": "{{#start.query#}}"},
],
"vision": {"configs": {"detail": "high", "variable_selector": []}, "enabled": False},
},
"id": "llm1",
},
{
"data": {
"type": "llm",
"title": "llm2",
"context": {"enabled": False, "variable_selector": []},
"model": {
"completion_params": {"temperature": 0.7},
"mode": "chat",
"name": "gpt-4o",
"provider": "openai",
},
"prompt_template": [
{"role": "system", "text": "say bye"},
{"role": "user", "text": "{{#start.query#}}"},
],
"vision": {"configs": {"detail": "high", "variable_selector": []}, "enabled": False},
},
"id": "llm2",
},
{
"data": {
"type": "llm",
"title": "llm3",
"context": {"enabled": False, "variable_selector": []},
"model": {
"completion_params": {"temperature": 0.7},
"mode": "chat",
"name": "gpt-4o",
"provider": "openai",
},
"prompt_template": [
{"role": "system", "text": "say good morning"},
{"role": "user", "text": "{{#start.query#}}"},
],
"vision": {"configs": {"detail": "high", "variable_selector": []}, "enabled": False},
},
"id": "llm3",
},
{
"data": {
"type": "end",
"title": "end1",
"outputs": [
{"value_selector": ["llm2", "text"], "variable": "result2"},
{"value_selector": ["start", "query"], "variable": "query"},
],
},
"id": "end1",
},
{
"data": {
"type": "end",
"title": "end2",
"outputs": [
{"value_selector": ["llm1", "text"], "variable": "result1"},
{"value_selector": ["llm3", "text"], "variable": "result3"},
],
},
"id": "end2",
},
],
}
graph = Graph.init(graph_config=graph_config)
variable_pool = VariablePool(
system_variables={SystemVariableKey.FILES: [], SystemVariableKey.USER_ID: "aaa"}, user_inputs={"query": "hi"}
)
graph_engine = GraphEngine(
tenant_id="111",
app_id="222",
workflow_type=WorkflowType.WORKFLOW,
workflow_id="333",
graph_config=graph_config,
user_id="444",
user_from=UserFrom.ACCOUNT,
invoke_from=InvokeFrom.WEB_APP,
call_depth=0,
graph=graph,
variable_pool=variable_pool,
max_execution_steps=500,
max_execution_time=1200,
)
def llm_generator(self):
contents = ["hi", "bye", "good morning"]
yield RunStreamChunkEvent(
chunk_content=contents[int(self.node_id[-1]) - 1], from_variable_selector=[self.node_id, "text"]
)
yield RunCompletedEvent(
run_result=NodeRunResult(
status=WorkflowNodeExecutionStatus.SUCCEEDED,
inputs={},
process_data={},
outputs={},
metadata={
NodeRunMetadataKey.TOTAL_TOKENS: 1,
NodeRunMetadataKey.TOTAL_PRICE: 1,
NodeRunMetadataKey.CURRENCY: "USD",
},
)
)
# print("")
with patch.object(LLMNode, "_run", new=llm_generator):
items = []
generator = graph_engine.run()
for item in generator:
items.append(item)
if isinstance(item, NodeRunSucceededEvent):
assert item.route_node_state.status == RouteNodeState.Status.SUCCESS
assert not isinstance(item, NodeRunFailedEvent)
assert not isinstance(item, GraphRunFailedEvent)
if isinstance(item, BaseNodeEvent) and item.route_node_state.node_id in {"llm2", "llm3", "end1", "end2"}:
assert item.parallel_id is not None
assert len(items) == 18
assert isinstance(items[0], GraphRunStartedEvent)
assert isinstance(items[1], NodeRunStartedEvent)
assert items[1].route_node_state.node_id == "start"
assert isinstance(items[2], NodeRunSucceededEvent)
assert items[2].route_node_state.node_id == "start"
@patch("extensions.ext_database.db.session.remove")
@patch("extensions.ext_database.db.session.close")
def test_run_parallel_in_chatflow(mock_close, mock_remove):
graph_config = {
"edges": [
{
"id": "1",
"source": "start",
"target": "answer1",
},
{
"id": "2",
"source": "answer1",
"target": "answer2",
},
{
"id": "3",
"source": "answer1",
"target": "answer3",
},
{
"id": "4",
"source": "answer2",
"target": "answer4",
},
{
"id": "5",
"source": "answer3",
"target": "answer5",
},
],
"nodes": [
{"data": {"type": "start", "title": "start"}, "id": "start"},
{"data": {"type": "answer", "title": "answer1", "answer": "1"}, "id": "answer1"},
{
"data": {"type": "answer", "title": "answer2", "answer": "2"},
"id": "answer2",
},
{
"data": {"type": "answer", "title": "answer3", "answer": "3"},
"id": "answer3",
},
{
"data": {"type": "answer", "title": "answer4", "answer": "4"},
"id": "answer4",
},
{
"data": {"type": "answer", "title": "answer5", "answer": "5"},
"id": "answer5",
},
],
}
graph = Graph.init(graph_config=graph_config)
variable_pool = VariablePool(
system_variables={
SystemVariableKey.QUERY: "what's the weather in SF",
SystemVariableKey.FILES: [],
SystemVariableKey.CONVERSATION_ID: "abababa",
SystemVariableKey.USER_ID: "aaa",
},
user_inputs={},
)
graph_engine = GraphEngine(
tenant_id="111",
app_id="222",
workflow_type=WorkflowType.CHAT,
workflow_id="333",
graph_config=graph_config,
user_id="444",
user_from=UserFrom.ACCOUNT,
invoke_from=InvokeFrom.WEB_APP,
call_depth=0,
graph=graph,
variable_pool=variable_pool,
max_execution_steps=500,
max_execution_time=1200,
)
# print("")
items = []
generator = graph_engine.run()
for item in generator:
items.append(item)
if isinstance(item, NodeRunSucceededEvent):
assert item.route_node_state.status == RouteNodeState.Status.SUCCESS
assert not isinstance(item, NodeRunFailedEvent)
assert not isinstance(item, GraphRunFailedEvent)
if isinstance(item, BaseNodeEvent) and item.route_node_state.node_id in {
"answer2",
"answer3",
"answer4",
"answer5",
}:
assert item.parallel_id is not None
assert len(items) == 23
assert isinstance(items[0], GraphRunStartedEvent)
assert isinstance(items[1], NodeRunStartedEvent)
assert items[1].route_node_state.node_id == "start"
assert isinstance(items[2], NodeRunSucceededEvent)
assert items[2].route_node_state.node_id == "start"
@patch("extensions.ext_database.db.session.remove")
@patch("extensions.ext_database.db.session.close")
def test_run_branch(mock_close, mock_remove):
graph_config = {
"edges": [
{
"id": "1",
"source": "start",
"target": "if-else-1",
},
{
"id": "2",
"source": "if-else-1",
"sourceHandle": "true",
"target": "answer-1",
},
{
"id": "3",
"source": "if-else-1",
"sourceHandle": "false",
"target": "if-else-2",
},
{
"id": "4",
"source": "if-else-2",
"sourceHandle": "true",
"target": "answer-2",
},
{
"id": "5",
"source": "if-else-2",
"sourceHandle": "false",
"target": "answer-3",
},
],
"nodes": [
{
"data": {
"title": "Start",
"type": "start",
"variables": [
{
"label": "uid",
"max_length": 48,
"options": [],
"required": True,
"type": "text-input",
"variable": "uid",
}
],
},
"id": "start",
},
{
"data": {"answer": "1 {{#start.uid#}}", "title": "Answer", "type": "answer", "variables": []},
"id": "answer-1",
},
{
"data": {
"cases": [
{
"case_id": "true",
"conditions": [
{
"comparison_operator": "contains",
"id": "b0f02473-08b6-4a81-af91-15345dcb2ec8",
"value": "hi",
"varType": "string",
"variable_selector": ["sys", "query"],
}
],
"id": "true",
"logical_operator": "and",
}
],
"desc": "",
"title": "IF/ELSE",
"type": "if-else",
},
"id": "if-else-1",
},
{
"data": {
"cases": [
{
"case_id": "true",
"conditions": [
{
"comparison_operator": "contains",
"id": "ae895199-5608-433b-b5f0-0997ae1431e4",
"value": "takatost",
"varType": "string",
"variable_selector": ["sys", "query"],
}
],
"id": "true",
"logical_operator": "and",
}
],
"title": "IF/ELSE 2",
"type": "if-else",
},
"id": "if-else-2",
},
{
"data": {
"answer": "2",
"title": "Answer 2",
"type": "answer",
},
"id": "answer-2",
},
{
"data": {
"answer": "3",
"title": "Answer 3",
"type": "answer",
},
"id": "answer-3",
},
],
}
graph = Graph.init(graph_config=graph_config)
variable_pool = VariablePool(
system_variables={
SystemVariableKey.QUERY: "hi",
SystemVariableKey.FILES: [],
SystemVariableKey.CONVERSATION_ID: "abababa",
SystemVariableKey.USER_ID: "aaa",
},
user_inputs={"uid": "takato"},
)
graph_engine = GraphEngine(
tenant_id="111",
app_id="222",
workflow_type=WorkflowType.CHAT,
workflow_id="333",
graph_config=graph_config,
user_id="444",
user_from=UserFrom.ACCOUNT,
invoke_from=InvokeFrom.WEB_APP,
call_depth=0,
graph=graph,
variable_pool=variable_pool,
max_execution_steps=500,
max_execution_time=1200,
)
# print("")
items = []
generator = graph_engine.run()
for item in generator:
items.append(item)
assert len(items) == 10
assert items[3].route_node_state.node_id == "if-else-1"
assert items[4].route_node_state.node_id == "if-else-1"
assert isinstance(items[5], NodeRunStreamChunkEvent)
assert isinstance(items[6], NodeRunStreamChunkEvent)
assert items[6].chunk_content == "takato"
assert items[7].route_node_state.node_id == "answer-1"
assert items[8].route_node_state.node_id == "answer-1"
assert items[8].route_node_state.node_run_result.outputs["answer"] == "1 takato"
assert isinstance(items[9], GraphRunSucceededEvent)
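
# items[6] is asserted to be the resolved uid "takato"; the preceding chunk
# (items[5]) is presumably the literal "1 " prefix of the answer template
# "1 {{#start.uid#}}", which matches the final output "1 takato".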

View File

@@ -0,0 +1,82 @@
import time
import uuid
from unittest.mock import MagicMock
from core.app.entities.app_invoke_entities import InvokeFrom
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.enums import SystemVariableKey
from core.workflow.graph_engine.entities.graph import Graph
from core.workflow.graph_engine.entities.graph_init_params import GraphInitParams
from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState
from core.workflow.nodes.answer.answer_node import AnswerNode
from extensions.ext_database import db
from models.enums import UserFrom
from models.workflow import WorkflowNodeExecutionStatus, WorkflowType
def test_execute_answer():
graph_config = {
"edges": [
{
"id": "start-source-llm-target",
"source": "start",
"target": "llm",
},
],
"nodes": [
{"data": {"type": "start"}, "id": "start"},
{
"data": {
"type": "llm",
},
"id": "llm",
},
],
}
graph = Graph.init(graph_config=graph_config)
init_params = GraphInitParams(
tenant_id="1",
app_id="1",
workflow_type=WorkflowType.WORKFLOW,
workflow_id="1",
graph_config=graph_config,
user_id="1",
user_from=UserFrom.ACCOUNT,
invoke_from=InvokeFrom.DEBUGGER,
call_depth=0,
)
# construct variable pool
pool = VariablePool(
system_variables={SystemVariableKey.FILES: [], SystemVariableKey.USER_ID: "aaa"},
user_inputs={},
environment_variables=[],
)
pool.add(["start", "weather"], "sunny")
pool.add(["llm", "text"], "You are a helpful AI.")
node = AnswerNode(
id=str(uuid.uuid4()),
graph_init_params=init_params,
graph=graph,
graph_runtime_state=GraphRuntimeState(variable_pool=pool, start_at=time.perf_counter()),
config={
"id": "answer",
"data": {
"title": "123",
"type": "answer",
"answer": "Today's weather is {{#start.weather#}}\n{{#llm.text#}}\n{{img}}\nFin.",
},
},
)
# Mock db.session.close()
db.session.close = MagicMock()
# execute node
result = node._run()
assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED
assert result.outputs["answer"] == "Today's weather is sunny\nYou are a helpful AI.\n{{img}}\nFin."

View File

@@ -0,0 +1,109 @@
from core.workflow.graph_engine.entities.graph import Graph
from core.workflow.nodes.answer.answer_stream_generate_router import AnswerStreamGeneratorRouter
def test_init():
graph_config = {
"edges": [
{
"id": "start-source-llm1-target",
"source": "start",
"target": "llm1",
},
{
"id": "start-source-llm2-target",
"source": "start",
"target": "llm2",
},
{
"id": "start-source-llm3-target",
"source": "start",
"target": "llm3",
},
{
"id": "llm3-source-llm4-target",
"source": "llm3",
"target": "llm4",
},
{
"id": "llm3-source-llm5-target",
"source": "llm3",
"target": "llm5",
},
{
"id": "llm4-source-answer2-target",
"source": "llm4",
"target": "answer2",
},
{
"id": "llm5-source-answer-target",
"source": "llm5",
"target": "answer",
},
{
"id": "answer2-source-answer-target",
"source": "answer2",
"target": "answer",
},
{
"id": "llm2-source-answer-target",
"source": "llm2",
"target": "answer",
},
{
"id": "llm1-source-answer-target",
"source": "llm1",
"target": "answer",
},
],
"nodes": [
{"data": {"type": "start"}, "id": "start"},
{
"data": {
"type": "llm",
},
"id": "llm1",
},
{
"data": {
"type": "llm",
},
"id": "llm2",
},
{
"data": {
"type": "llm",
},
"id": "llm3",
},
{
"data": {
"type": "llm",
},
"id": "llm4",
},
{
"data": {
"type": "llm",
},
"id": "llm5",
},
{
"data": {"type": "answer", "title": "answer", "answer": "1{{#llm2.text#}}2"},
"id": "answer",
},
{
"data": {"type": "answer", "title": "answer2", "answer": "1{{#llm3.text#}}2"},
"id": "answer2",
},
],
}
graph = Graph.init(graph_config=graph_config)
answer_stream_generate_route = AnswerStreamGeneratorRouter.init(
node_id_config_mapping=graph.node_id_config_mapping, reverse_edge_mapping=graph.reverse_edge_mapping
)
assert answer_stream_generate_route.answer_dependencies["answer"] == ["answer2"]
assert answer_stream_generate_route.answer_dependencies["answer2"] == []

View File

@@ -0,0 +1,216 @@
import uuid
from collections.abc import Generator
from datetime import UTC, datetime
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.enums import SystemVariableKey
from core.workflow.graph_engine.entities.event import (
GraphEngineEvent,
NodeRunStartedEvent,
NodeRunStreamChunkEvent,
NodeRunSucceededEvent,
)
from core.workflow.graph_engine.entities.graph import Graph
from core.workflow.graph_engine.entities.runtime_route_state import RouteNodeState
from core.workflow.nodes.answer.answer_stream_processor import AnswerStreamProcessor
from core.workflow.nodes.enums import NodeType
from core.workflow.nodes.start.entities import StartNodeData
def _recursive_process(graph: Graph, next_node_id: str) -> Generator[GraphEngineEvent, None, None]:
if next_node_id == "start":
yield from _publish_events(graph, next_node_id)
for edge in graph.edge_mapping.get(next_node_id, []):
yield from _publish_events(graph, edge.target_node_id)
for edge in graph.edge_mapping.get(next_node_id, []):
yield from _recursive_process(graph, edge.target_node_id)
def _publish_events(graph: Graph, next_node_id: str) -> Generator[GraphEngineEvent, None, None]:
route_node_state = RouteNodeState(node_id=next_node_id, start_at=datetime.now(UTC).replace(tzinfo=None))
parallel_id = graph.node_parallel_mapping.get(next_node_id)
parallel_start_node_id = None
if parallel_id:
parallel = graph.parallel_mapping.get(parallel_id)
parallel_start_node_id = parallel.start_from_node_id if parallel else None
node_execution_id = str(uuid.uuid4())
node_config = graph.node_id_config_mapping[next_node_id]
node_type = NodeType(node_config.get("data", {}).get("type"))
    mock_node_data = StartNodeData(title="demo", variables=[])
yield NodeRunStartedEvent(
id=node_execution_id,
node_id=next_node_id,
node_type=node_type,
node_data=mock_node_data,
route_node_state=route_node_state,
parallel_id=graph.node_parallel_mapping.get(next_node_id),
parallel_start_node_id=parallel_start_node_id,
)
if "llm" in next_node_id:
length = int(next_node_id[-1])
        for i in range(length):
yield NodeRunStreamChunkEvent(
id=node_execution_id,
node_id=next_node_id,
node_type=node_type,
node_data=mock_node_data,
chunk_content=str(i),
route_node_state=route_node_state,
from_variable_selector=[next_node_id, "text"],
parallel_id=parallel_id,
parallel_start_node_id=parallel_start_node_id,
)
route_node_state.status = RouteNodeState.Status.SUCCESS
route_node_state.finished_at = datetime.now(UTC).replace(tzinfo=None)
yield NodeRunSucceededEvent(
id=node_execution_id,
node_id=next_node_id,
node_type=node_type,
node_data=mock_node_data,
route_node_state=route_node_state,
parallel_id=parallel_id,
parallel_start_node_id=parallel_start_node_id,
)
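
# Convention used by _publish_events: a node named "llmN" streams N chunks,
# "0" through str(N - 1), so llm2 contributes "01" and llm3 contributes "012".
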
def test_process():
graph_config = {
"edges": [
{
"id": "start-source-llm1-target",
"source": "start",
"target": "llm1",
},
{
"id": "start-source-llm2-target",
"source": "start",
"target": "llm2",
},
{
"id": "start-source-llm3-target",
"source": "start",
"target": "llm3",
},
{
"id": "llm3-source-llm4-target",
"source": "llm3",
"target": "llm4",
},
{
"id": "llm3-source-llm5-target",
"source": "llm3",
"target": "llm5",
},
{
"id": "llm4-source-answer2-target",
"source": "llm4",
"target": "answer2",
},
{
"id": "llm5-source-answer-target",
"source": "llm5",
"target": "answer",
},
{
"id": "answer2-source-answer-target",
"source": "answer2",
"target": "answer",
},
{
"id": "llm2-source-answer-target",
"source": "llm2",
"target": "answer",
},
{
"id": "llm1-source-answer-target",
"source": "llm1",
"target": "answer",
},
],
"nodes": [
{"data": {"type": "start"}, "id": "start"},
{
"data": {
"type": "llm",
},
"id": "llm1",
},
{
"data": {
"type": "llm",
},
"id": "llm2",
},
{
"data": {
"type": "llm",
},
"id": "llm3",
},
{
"data": {
"type": "llm",
},
"id": "llm4",
},
{
"data": {
"type": "llm",
},
"id": "llm5",
},
{
"data": {"type": "answer", "title": "answer", "answer": "a{{#llm2.text#}}b"},
"id": "answer",
},
{
"data": {"type": "answer", "title": "answer2", "answer": "c{{#llm3.text#}}d"},
"id": "answer2",
},
],
}
graph = Graph.init(graph_config=graph_config)
variable_pool = VariablePool(
system_variables={
SystemVariableKey.QUERY: "what's the weather in SF",
SystemVariableKey.FILES: [],
SystemVariableKey.CONVERSATION_ID: "abababa",
SystemVariableKey.USER_ID: "aaa",
},
user_inputs={},
)
answer_stream_processor = AnswerStreamProcessor(graph=graph, variable_pool=variable_pool)
def graph_generator() -> Generator[GraphEngineEvent, None, None]:
# print("")
for event in _recursive_process(graph, "start"):
# print("[ORIGIN]", event.__class__.__name__ + ":", event.route_node_state.node_id,
# " " + (event.chunk_content if isinstance(event, NodeRunStreamChunkEvent) else ""))
if isinstance(event, NodeRunSucceededEvent):
if "llm" in event.route_node_state.node_id:
variable_pool.add(
[event.route_node_state.node_id, "text"],
"".join(str(i) for i in range(0, int(event.route_node_state.node_id[-1]))),
)
yield event
result_generator = answer_stream_processor.process(graph_generator())
stream_contents = ""
for event in result_generator:
# print("[ANSWER]", event.__class__.__name__ + ":", event.route_node_state.node_id,
# " " + (event.chunk_content if isinstance(event, NodeRunStreamChunkEvent) else ""))
if isinstance(event, NodeRunStreamChunkEvent):
stream_contents += event.chunk_content
pass
assert stream_contents == "c012da01b"

View File

@@ -0,0 +1,140 @@
from unittest.mock import Mock, PropertyMock, patch
import httpx
import pytest
from core.workflow.nodes.http_request.entities import Response
@pytest.fixture
def mock_response():
response = Mock(spec=httpx.Response)
response.headers = {}
return response
def test_is_file_with_attachment_disposition(mock_response):
"""Test is_file when content-disposition header contains 'attachment'"""
mock_response.headers = {"content-disposition": "attachment; filename=test.pdf", "content-type": "application/pdf"}
response = Response(mock_response)
assert response.is_file
def test_is_file_with_filename_disposition(mock_response):
"""Test is_file when content-disposition header contains filename parameter"""
mock_response.headers = {"content-disposition": "inline; filename=test.pdf", "content-type": "application/pdf"}
response = Response(mock_response)
assert response.is_file
@pytest.mark.parametrize("content_type", ["application/pdf", "image/jpeg", "audio/mp3", "video/mp4"])
def test_is_file_with_file_content_types(mock_response, content_type):
"""Test is_file with various file content types"""
mock_response.headers = {"content-type": content_type}
# Mock binary content
type(mock_response).content = PropertyMock(return_value=bytes([0x00, 0xFF] * 512))
response = Response(mock_response)
assert response.is_file, f"Content type {content_type} should be identified as a file"
@pytest.mark.parametrize(
"content_type",
[
"application/json",
"application/xml",
"application/javascript",
"application/x-www-form-urlencoded",
"application/yaml",
"application/graphql",
],
)
def test_text_based_application_types(mock_response, content_type):
"""Test common text-based application types are not identified as files"""
mock_response.headers = {"content-type": content_type}
response = Response(mock_response)
assert not response.is_file, f"Content type {content_type} should not be identified as a file"
@pytest.mark.parametrize(
("content", "content_type"),
[
(b'{"key": "value"}', "application/octet-stream"),
(b"[1, 2, 3]", "application/unknown"),
(b"function test() {}", "application/x-unknown"),
(b"<root>test</root>", "application/binary"),
(b"var x = 1;", "application/data"),
],
)
def test_content_based_detection(mock_response, content, content_type):
"""Test content-based detection for text-like content"""
mock_response.headers = {"content-type": content_type}
type(mock_response).content = PropertyMock(return_value=content)
response = Response(mock_response)
assert not response.is_file, f"Content {content} with type {content_type} should not be identified as a file"
@pytest.mark.parametrize(
("content", "content_type"),
[
(bytes([0x00, 0xFF] * 512), "application/octet-stream"),
(bytes([0x89, 0x50, 0x4E, 0x47]), "application/unknown"), # PNG magic numbers
(bytes([0xFF, 0xD8, 0xFF]), "application/binary"), # JPEG magic numbers
],
)
def test_binary_content_detection(mock_response, content, content_type):
"""Test content-based detection for binary content"""
mock_response.headers = {"content-type": content_type}
type(mock_response).content = PropertyMock(return_value=content)
response = Response(mock_response)
assert response.is_file, f"Binary content with type {content_type} should be identified as a file"
@pytest.mark.parametrize(
("content_type", "expected_main_type"),
[
("x-world/x-vrml", "model"), # VRML 3D model
("font/ttf", "application"), # TrueType font
("text/csv", "text"), # CSV text file
("unknown/xyz", None), # Unknown type
],
)
def test_mimetype_based_detection(mock_response, content_type, expected_main_type):
"""Test detection using mimetypes.guess_type for non-application content types"""
mock_response.headers = {"content-type": content_type}
type(mock_response).content = PropertyMock(return_value=bytes([0x00])) # Dummy content
with patch("core.workflow.nodes.http_request.entities.mimetypes.guess_type") as mock_guess_type:
# Mock the return value based on expected_main_type
if expected_main_type:
mock_guess_type.return_value = (f"{expected_main_type}/subtype", None)
else:
mock_guess_type.return_value = (None, None)
response = Response(mock_response)
# Check if the result matches our expectation
if expected_main_type in ("application", "image", "audio", "video"):
assert response.is_file, f"Content type {content_type} should be identified as a file"
else:
assert not response.is_file, f"Content type {content_type} should not be identified as a file"
# Verify that guess_type was called
mock_guess_type.assert_called_once()
def test_is_file_with_inline_disposition(mock_response):
"""Test is_file when content-disposition is 'inline'"""
mock_response.headers = {"content-disposition": "inline", "content-type": "application/pdf"}
# Mock binary content
type(mock_response).content = PropertyMock(return_value=bytes([0x00, 0xFF] * 512))
response = Response(mock_response)
assert response.is_file
def test_is_file_with_no_content_disposition(mock_response):
"""Test is_file when no content-disposition header is present"""
mock_response.headers = {"content-type": "application/pdf"}
# Mock binary content
type(mock_response).content = PropertyMock(return_value=bytes([0x00, 0xFF] * 512))
response = Response(mock_response)
assert response.is_file
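
# Hedged sketch of one binary-vs-text heuristic consistent with the cases above
# (the real Response.is_file also weighs content-disposition and mimetypes):
# bytes containing NUL or failing UTF-8 decoding are treated as file content.
def looks_like_binary_sketch(content: bytes) -> bool:
    if b"\x00" in content:
        return True
    try:
        content.decode("utf-8")
    except UnicodeDecodeError:
        return True
    return False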

View File

@@ -0,0 +1,336 @@
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.nodes.http_request import (
BodyData,
HttpRequestNodeAuthorization,
HttpRequestNodeBody,
HttpRequestNodeData,
)
from core.workflow.nodes.http_request.entities import HttpRequestNodeTimeout
from core.workflow.nodes.http_request.executor import Executor
def test_executor_with_json_body_and_number_variable():
# Prepare the variable pool
variable_pool = VariablePool(
system_variables={},
user_inputs={},
)
variable_pool.add(["pre_node_id", "number"], 42)
# Prepare the node data
node_data = HttpRequestNodeData(
title="Test JSON Body with Number Variable",
method="post",
url="https://api.example.com/data",
authorization=HttpRequestNodeAuthorization(type="no-auth"),
headers="Content-Type: application/json",
params="",
body=HttpRequestNodeBody(
type="json",
data=[
BodyData(
key="",
type="text",
value='{"number": {{#pre_node_id.number#}}}',
)
],
),
)
# Initialize the Executor
executor = Executor(
node_data=node_data,
timeout=HttpRequestNodeTimeout(connect=10, read=30, write=30),
variable_pool=variable_pool,
)
# Check the executor's data
assert executor.method == "post"
assert executor.url == "https://api.example.com/data"
assert executor.headers == {"Content-Type": "application/json"}
assert executor.params == []
assert executor.json == {"number": 42}
assert executor.data is None
assert executor.files is None
assert executor.content is None
# Check the raw request (to_log method)
raw_request = executor.to_log()
assert "POST /data HTTP/1.1" in raw_request
assert "Host: api.example.com" in raw_request
assert "Content-Type: application/json" in raw_request
assert '{"number": 42}' in raw_request
def test_executor_with_json_body_and_object_variable():
# Prepare the variable pool
variable_pool = VariablePool(
system_variables={},
user_inputs={},
)
variable_pool.add(["pre_node_id", "object"], {"name": "John Doe", "age": 30, "email": "john@example.com"})
# Prepare the node data
node_data = HttpRequestNodeData(
title="Test JSON Body with Object Variable",
method="post",
url="https://api.example.com/data",
authorization=HttpRequestNodeAuthorization(type="no-auth"),
headers="Content-Type: application/json",
params="",
body=HttpRequestNodeBody(
type="json",
data=[
BodyData(
key="",
type="text",
value="{{#pre_node_id.object#}}",
)
],
),
)
# Initialize the Executor
executor = Executor(
node_data=node_data,
timeout=HttpRequestNodeTimeout(connect=10, read=30, write=30),
variable_pool=variable_pool,
)
# Check the executor's data
assert executor.method == "post"
assert executor.url == "https://api.example.com/data"
assert executor.headers == {"Content-Type": "application/json"}
assert executor.params == []
assert executor.json == {"name": "John Doe", "age": 30, "email": "john@example.com"}
assert executor.data is None
assert executor.files is None
assert executor.content is None
# Check the raw request (to_log method)
raw_request = executor.to_log()
assert "POST /data HTTP/1.1" in raw_request
assert "Host: api.example.com" in raw_request
assert "Content-Type: application/json" in raw_request
assert '"name": "John Doe"' in raw_request
assert '"age": 30' in raw_request
assert '"email": "john@example.com"' in raw_request
def test_executor_with_json_body_and_nested_object_variable():
# Prepare the variable pool
variable_pool = VariablePool(
system_variables={},
user_inputs={},
)
variable_pool.add(["pre_node_id", "object"], {"name": "John Doe", "age": 30, "email": "john@example.com"})
# Prepare the node data
node_data = HttpRequestNodeData(
title="Test JSON Body with Nested Object Variable",
method="post",
url="https://api.example.com/data",
authorization=HttpRequestNodeAuthorization(type="no-auth"),
headers="Content-Type: application/json",
params="",
body=HttpRequestNodeBody(
type="json",
data=[
BodyData(
key="",
type="text",
value='{"object": {{#pre_node_id.object#}}}',
)
],
),
)
# Initialize the Executor
executor = Executor(
node_data=node_data,
timeout=HttpRequestNodeTimeout(connect=10, read=30, write=30),
variable_pool=variable_pool,
)
# Check the executor's data
assert executor.method == "post"
assert executor.url == "https://api.example.com/data"
assert executor.headers == {"Content-Type": "application/json"}
assert executor.params == []
assert executor.json == {"object": {"name": "John Doe", "age": 30, "email": "john@example.com"}}
assert executor.data is None
assert executor.files is None
assert executor.content is None
# Check the raw request (to_log method)
raw_request = executor.to_log()
assert "POST /data HTTP/1.1" in raw_request
assert "Host: api.example.com" in raw_request
assert "Content-Type: application/json" in raw_request
assert '"object": {' in raw_request
assert '"name": "John Doe"' in raw_request
assert '"age": 30' in raw_request
assert '"email": "john@example.com"' in raw_request
def test_extract_selectors_from_template_with_newline():
variable_pool = VariablePool()
variable_pool.add(("node_id", "custom_query"), "line1\nline2")
node_data = HttpRequestNodeData(
title="Test JSON Body with Nested Object Variable",
method="post",
url="https://api.example.com/data",
authorization=HttpRequestNodeAuthorization(type="no-auth"),
headers="Content-Type: application/json",
params="test: {{#node_id.custom_query#}}",
body=HttpRequestNodeBody(
type="none",
data=[],
),
)
executor = Executor(
node_data=node_data,
timeout=HttpRequestNodeTimeout(connect=10, read=30, write=30),
variable_pool=variable_pool,
)
assert executor.params == [("test", "line1\nline2")]
def test_executor_with_form_data():
# Prepare the variable pool
variable_pool = VariablePool(
system_variables={},
user_inputs={},
)
variable_pool.add(["pre_node_id", "text_field"], "Hello, World!")
variable_pool.add(["pre_node_id", "number_field"], 42)
# Prepare the node data
node_data = HttpRequestNodeData(
title="Test Form Data",
method="post",
url="https://api.example.com/upload",
authorization=HttpRequestNodeAuthorization(type="no-auth"),
headers="Content-Type: multipart/form-data",
params="",
body=HttpRequestNodeBody(
type="form-data",
data=[
BodyData(
key="text_field",
type="text",
value="{{#pre_node_id.text_field#}}",
),
BodyData(
key="number_field",
type="text",
value="{{#pre_node_id.number_field#}}",
),
],
),
)
# Initialize the Executor
executor = Executor(
node_data=node_data,
timeout=HttpRequestNodeTimeout(connect=10, read=30, write=30),
variable_pool=variable_pool,
)
# Check the executor's data
assert executor.method == "post"
assert executor.url == "https://api.example.com/upload"
assert "Content-Type" in executor.headers
assert "multipart/form-data" in executor.headers["Content-Type"]
assert executor.params == []
assert executor.json is None
assert executor.files is None
assert executor.content is None
# Check that the form data is correctly loaded in executor.data
assert isinstance(executor.data, dict)
assert "text_field" in executor.data
assert executor.data["text_field"] == "Hello, World!"
assert "number_field" in executor.data
assert executor.data["number_field"] == "42"
# Check the raw request (to_log method)
raw_request = executor.to_log()
assert "POST /upload HTTP/1.1" in raw_request
assert "Host: api.example.com" in raw_request
assert "Content-Type: multipart/form-data" in raw_request
assert "text_field" in raw_request
assert "Hello, World!" in raw_request
assert "number_field" in raw_request
assert "42" in raw_request
def test_init_headers():
def create_executor(headers: str) -> Executor:
node_data = HttpRequestNodeData(
title="test",
method="get",
url="http://example.com",
headers=headers,
params="",
authorization=HttpRequestNodeAuthorization(type="no-auth"),
)
timeout = HttpRequestNodeTimeout(connect=10, read=30, write=30)
return Executor(node_data=node_data, timeout=timeout, variable_pool=VariablePool())
executor = create_executor("aa\n cc:")
executor._init_headers()
assert executor.headers == {"aa": "", "cc": ""}
executor = create_executor("aa:bb\n cc:dd")
executor._init_headers()
assert executor.headers == {"aa": "bb", "cc": "dd"}
executor = create_executor("aa:bb\n cc:dd\n")
executor._init_headers()
assert executor.headers == {"aa": "bb", "cc": "dd"}
executor = create_executor("aa:bb\n\n cc : dd\n\n")
executor._init_headers()
assert executor.headers == {"aa": "bb", "cc": "dd"}
def test_init_params():
def create_executor(params: str) -> Executor:
node_data = HttpRequestNodeData(
title="test",
method="get",
url="http://example.com",
headers="",
params=params,
authorization=HttpRequestNodeAuthorization(type="no-auth"),
)
timeout = HttpRequestNodeTimeout(connect=10, read=30, write=30)
return Executor(node_data=node_data, timeout=timeout, variable_pool=VariablePool())
# Test basic key-value pairs
executor = create_executor("key1:value1\nkey2:value2")
executor._init_params()
assert executor.params == [("key1", "value1"), ("key2", "value2")]
# Test empty values
executor = create_executor("key1:\nkey2:")
executor._init_params()
assert executor.params == [("key1", ""), ("key2", "")]
# Test duplicate keys (which is allowed for params)
executor = create_executor("key1:value1\nkey1:value2")
executor._init_params()
assert executor.params == [("key1", "value1"), ("key1", "value2")]
# Test whitespace handling
executor = create_executor(" key1 : value1 \n key2 : value2 ")
executor._init_params()
assert executor.params == [("key1", "value1"), ("key2", "value2")]
# Test empty lines and extra whitespace
executor = create_executor("key1:value1\n\nkey2:value2\n\n")
executor._init_params()
assert executor.params == [("key1", "value1"), ("key2", "value2")]

View File

@@ -0,0 +1,196 @@
import httpx
from core.app.entities.app_invoke_entities import InvokeFrom
from core.file import File, FileTransferMethod, FileType
from core.variables import FileVariable
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.graph_engine import Graph, GraphInitParams, GraphRuntimeState
from core.workflow.nodes.answer import AnswerStreamGenerateRoute
from core.workflow.nodes.end import EndStreamParam
from core.workflow.nodes.http_request import (
BodyData,
HttpRequestNode,
HttpRequestNodeAuthorization,
HttpRequestNodeBody,
HttpRequestNodeData,
)
from models.enums import UserFrom
from models.workflow import WorkflowNodeExecutionStatus, WorkflowType
def test_http_request_node_binary_file(monkeypatch):
data = HttpRequestNodeData(
title="test",
method="post",
url="http://example.org/post",
authorization=HttpRequestNodeAuthorization(type="no-auth"),
headers="",
params="",
body=HttpRequestNodeBody(
type="binary",
data=[
BodyData(
key="file",
type="file",
value="",
file=["1111", "file"],
)
],
),
)
variable_pool = VariablePool(
system_variables={},
user_inputs={},
)
variable_pool.add(
["1111", "file"],
FileVariable(
name="file",
value=File(
tenant_id="1",
type=FileType.IMAGE,
transfer_method=FileTransferMethod.LOCAL_FILE,
related_id="1111",
storage_key="",
),
),
)
node = HttpRequestNode(
id="1",
config={
"id": "1",
"data": data.model_dump(),
},
graph_init_params=GraphInitParams(
tenant_id="1",
app_id="1",
workflow_type=WorkflowType.WORKFLOW,
workflow_id="1",
graph_config={},
user_id="1",
user_from=UserFrom.ACCOUNT,
invoke_from=InvokeFrom.SERVICE_API,
call_depth=0,
),
graph=Graph(
root_node_id="1",
answer_stream_generate_routes=AnswerStreamGenerateRoute(
answer_dependencies={},
answer_generate_route={},
),
end_stream_param=EndStreamParam(
end_dependencies={},
end_stream_variable_selector_mapping={},
),
),
graph_runtime_state=GraphRuntimeState(
variable_pool=variable_pool,
start_at=0,
),
)
monkeypatch.setattr(
"core.workflow.nodes.http_request.executor.file_manager.download",
lambda *args, **kwargs: b"test",
)
monkeypatch.setattr(
"core.helper.ssrf_proxy.post",
lambda *args, **kwargs: httpx.Response(200, content=kwargs["content"]),
)
result = node._run()
assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED
assert result.outputs is not None
assert result.outputs["body"] == "test"
def test_http_request_node_form_with_file(monkeypatch):
data = HttpRequestNodeData(
title="test",
method="post",
url="http://example.org/post",
authorization=HttpRequestNodeAuthorization(type="no-auth"),
headers="",
params="",
body=HttpRequestNodeBody(
type="form-data",
data=[
BodyData(
key="file",
type="file",
file=["1111", "file"],
),
BodyData(
key="name",
type="text",
value="test",
),
],
),
)
variable_pool = VariablePool(
system_variables={},
user_inputs={},
)
variable_pool.add(
["1111", "file"],
FileVariable(
name="file",
value=File(
tenant_id="1",
type=FileType.IMAGE,
transfer_method=FileTransferMethod.LOCAL_FILE,
related_id="1111",
storage_key="",
),
),
)
node = HttpRequestNode(
id="1",
config={
"id": "1",
"data": data.model_dump(),
},
graph_init_params=GraphInitParams(
tenant_id="1",
app_id="1",
workflow_type=WorkflowType.WORKFLOW,
workflow_id="1",
graph_config={},
user_id="1",
user_from=UserFrom.ACCOUNT,
invoke_from=InvokeFrom.SERVICE_API,
call_depth=0,
),
graph=Graph(
root_node_id="1",
answer_stream_generate_routes=AnswerStreamGenerateRoute(
answer_dependencies={},
answer_generate_route={},
),
end_stream_param=EndStreamParam(
end_dependencies={},
end_stream_variable_selector_mapping={},
),
),
graph_runtime_state=GraphRuntimeState(
variable_pool=variable_pool,
start_at=0,
),
)
monkeypatch.setattr(
"core.workflow.nodes.http_request.executor.file_manager.download",
lambda *args, **kwargs: b"test",
)
def attr_checker(*args, **kwargs):
assert kwargs["data"] == {"name": "test"}
assert kwargs["files"] == {"file": (None, b"test", "application/octet-stream")}
return httpx.Response(200, content=b"")
monkeypatch.setattr(
"core.helper.ssrf_proxy.post",
attr_checker,
)
result = node._run()
assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED
assert result.outputs is not None
assert result.outputs["body"] == ""

View File

@@ -0,0 +1,860 @@
import time
import uuid
from unittest.mock import patch
from core.app.entities.app_invoke_entities import InvokeFrom
from core.workflow.entities.node_entities import NodeRunResult
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.enums import SystemVariableKey
from core.workflow.graph_engine.entities.graph import Graph
from core.workflow.graph_engine.entities.graph_init_params import GraphInitParams
from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState
from core.workflow.nodes.event import RunCompletedEvent
from core.workflow.nodes.iteration.entities import ErrorHandleMode
from core.workflow.nodes.iteration.iteration_node import IterationNode
from core.workflow.nodes.template_transform.template_transform_node import TemplateTransformNode
from models.enums import UserFrom
from models.workflow import WorkflowNodeExecutionStatus, WorkflowType
def test_run():
graph_config = {
"edges": [
{
"id": "start-source-pe-target",
"source": "start",
"target": "pe",
},
{
"id": "iteration-1-source-answer-3-target",
"source": "iteration-1",
"target": "answer-3",
},
{
"id": "tt-source-if-else-target",
"source": "tt",
"target": "if-else",
},
{
"id": "if-else-true-answer-2-target",
"source": "if-else",
"sourceHandle": "true",
"target": "answer-2",
},
{
"id": "if-else-false-answer-4-target",
"source": "if-else",
"sourceHandle": "false",
"target": "answer-4",
},
{
"id": "pe-source-iteration-1-target",
"source": "pe",
"target": "iteration-1",
},
],
"nodes": [
{"data": {"title": "Start", "type": "start", "variables": []}, "id": "start"},
{
"data": {
"iterator_selector": ["pe", "list_output"],
"output_selector": ["tt", "output"],
"output_type": "array[string]",
"startNodeType": "template-transform",
"start_node_id": "tt",
"title": "iteration",
"type": "iteration",
},
"id": "iteration-1",
},
{
"data": {
"answer": "{{#tt.output#}}",
"iteration_id": "iteration-1",
"title": "answer 2",
"type": "answer",
},
"id": "answer-2",
},
{
"data": {
"iteration_id": "iteration-1",
"template": "{{ arg1 }} 123",
"title": "template transform",
"type": "template-transform",
"variables": [{"value_selector": ["sys", "query"], "variable": "arg1"}],
},
"id": "tt",
},
{
"data": {"answer": "{{#iteration-1.output#}}88888", "title": "answer 3", "type": "answer"},
"id": "answer-3",
},
{
"data": {
"conditions": [
{
"comparison_operator": "is",
"id": "1721916275284",
"value": "hi",
"variable_selector": ["sys", "query"],
}
],
"iteration_id": "iteration-1",
"logical_operator": "and",
"title": "if",
"type": "if-else",
},
"id": "if-else",
},
{
"data": {"answer": "no hi", "iteration_id": "iteration-1", "title": "answer 4", "type": "answer"},
"id": "answer-4",
},
{
"data": {
"instruction": "test1",
"model": {
"completion_params": {"temperature": 0.7},
"mode": "chat",
"name": "gpt-4o",
"provider": "openai",
},
"parameters": [
{"description": "test", "name": "list_output", "required": False, "type": "array[string]"}
],
"query": ["sys", "query"],
"reasoning_mode": "prompt",
"title": "pe",
"type": "parameter-extractor",
},
"id": "pe",
},
],
}
graph = Graph.init(graph_config=graph_config)
init_params = GraphInitParams(
tenant_id="1",
app_id="1",
workflow_type=WorkflowType.CHAT,
workflow_id="1",
graph_config=graph_config,
user_id="1",
user_from=UserFrom.ACCOUNT,
invoke_from=InvokeFrom.DEBUGGER,
call_depth=0,
)
# construct variable pool
pool = VariablePool(
system_variables={
SystemVariableKey.QUERY: "dify",
SystemVariableKey.FILES: [],
SystemVariableKey.CONVERSATION_ID: "abababa",
SystemVariableKey.USER_ID: "1",
},
user_inputs={},
environment_variables=[],
)
pool.add(["pe", "list_output"], ["dify-1", "dify-2"])
iteration_node = IterationNode(
id=str(uuid.uuid4()),
graph_init_params=init_params,
graph=graph,
graph_runtime_state=GraphRuntimeState(variable_pool=pool, start_at=time.perf_counter()),
config={
"data": {
"iterator_selector": ["pe", "list_output"],
"output_selector": ["tt", "output"],
"output_type": "array[string]",
"startNodeType": "template-transform",
"start_node_id": "tt",
"title": "迭代",
"type": "iteration",
},
"id": "iteration-1",
},
)
def tt_generator(self):
return NodeRunResult(
status=WorkflowNodeExecutionStatus.SUCCEEDED,
inputs={"iterator_selector": "dify"},
outputs={"output": "dify 123"},
)
with patch.object(TemplateTransformNode, "_run", new=tt_generator):
# execute node
result = iteration_node._run()
count = 0
for item in result:
count += 1
if isinstance(item, RunCompletedEvent):
assert item.run_result.status == WorkflowNodeExecutionStatus.SUCCEEDED
assert item.run_result.outputs == {"output": ["dify 123", "dify 123"]}
assert count == 20
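
# The iterator has two items (["dify-1", "dify-2"]) and the patched
# TemplateTransformNode returns "dify 123" for each pass, so the iteration's
# aggregated output is ["dify 123", "dify 123"].
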
def test_run_parallel():
graph_config = {
"edges": [
{
"id": "start-source-pe-target",
"source": "start",
"target": "pe",
},
{
"id": "iteration-1-source-answer-3-target",
"source": "iteration-1",
"target": "answer-3",
},
{
"id": "iteration-start-source-tt-target",
"source": "iteration-start",
"target": "tt",
},
{
"id": "iteration-start-source-tt-2-target",
"source": "iteration-start",
"target": "tt-2",
},
{
"id": "tt-source-if-else-target",
"source": "tt",
"target": "if-else",
},
{
"id": "tt-2-source-if-else-target",
"source": "tt-2",
"target": "if-else",
},
{
"id": "if-else-true-answer-2-target",
"source": "if-else",
"sourceHandle": "true",
"target": "answer-2",
},
{
"id": "if-else-false-answer-4-target",
"source": "if-else",
"sourceHandle": "false",
"target": "answer-4",
},
{
"id": "pe-source-iteration-1-target",
"source": "pe",
"target": "iteration-1",
},
],
"nodes": [
{"data": {"title": "Start", "type": "start", "variables": []}, "id": "start"},
{
"data": {
"iterator_selector": ["pe", "list_output"],
"output_selector": ["tt", "output"],
"output_type": "array[string]",
"startNodeType": "template-transform",
"start_node_id": "iteration-start",
"title": "iteration",
"type": "iteration",
},
"id": "iteration-1",
},
{
"data": {
"answer": "{{#tt.output#}}",
"iteration_id": "iteration-1",
"title": "answer 2",
"type": "answer",
},
"id": "answer-2",
},
{
"data": {
"iteration_id": "iteration-1",
"title": "iteration-start",
"type": "iteration-start",
},
"id": "iteration-start",
},
{
"data": {
"iteration_id": "iteration-1",
"template": "{{ arg1 }} 123",
"title": "template transform",
"type": "template-transform",
"variables": [{"value_selector": ["sys", "query"], "variable": "arg1"}],
},
"id": "tt",
},
{
"data": {
"iteration_id": "iteration-1",
"template": "{{ arg1 }} 321",
"title": "template transform",
"type": "template-transform",
"variables": [{"value_selector": ["sys", "query"], "variable": "arg1"}],
},
"id": "tt-2",
},
{
"data": {"answer": "{{#iteration-1.output#}}88888", "title": "answer 3", "type": "answer"},
"id": "answer-3",
},
{
"data": {
"conditions": [
{
"comparison_operator": "is",
"id": "1721916275284",
"value": "hi",
"variable_selector": ["sys", "query"],
}
],
"iteration_id": "iteration-1",
"logical_operator": "and",
"title": "if",
"type": "if-else",
},
"id": "if-else",
},
{
"data": {"answer": "no hi", "iteration_id": "iteration-1", "title": "answer 4", "type": "answer"},
"id": "answer-4",
},
{
"data": {
"instruction": "test1",
"model": {
"completion_params": {"temperature": 0.7},
"mode": "chat",
"name": "gpt-4o",
"provider": "openai",
},
"parameters": [
{"description": "test", "name": "list_output", "required": False, "type": "array[string]"}
],
"query": ["sys", "query"],
"reasoning_mode": "prompt",
"title": "pe",
"type": "parameter-extractor",
},
"id": "pe",
},
],
}
graph = Graph.init(graph_config=graph_config)
init_params = GraphInitParams(
tenant_id="1",
app_id="1",
workflow_type=WorkflowType.CHAT,
workflow_id="1",
graph_config=graph_config,
user_id="1",
user_from=UserFrom.ACCOUNT,
invoke_from=InvokeFrom.DEBUGGER,
call_depth=0,
)
# construct variable pool
pool = VariablePool(
system_variables={
SystemVariableKey.QUERY: "dify",
SystemVariableKey.FILES: [],
SystemVariableKey.CONVERSATION_ID: "abababa",
SystemVariableKey.USER_ID: "1",
},
user_inputs={},
environment_variables=[],
)
pool.add(["pe", "list_output"], ["dify-1", "dify-2"])
iteration_node = IterationNode(
id=str(uuid.uuid4()),
graph_init_params=init_params,
graph=graph,
graph_runtime_state=GraphRuntimeState(variable_pool=pool, start_at=time.perf_counter()),
config={
"data": {
"iterator_selector": ["pe", "list_output"],
"output_selector": ["tt", "output"],
"output_type": "array[string]",
"startNodeType": "template-transform",
"start_node_id": "iteration-start",
"title": "迭代",
"type": "iteration",
},
"id": "iteration-1",
},
)
def tt_generator(self):
return NodeRunResult(
status=WorkflowNodeExecutionStatus.SUCCEEDED,
inputs={"iterator_selector": "dify"},
outputs={"output": "dify 123"},
)
with patch.object(TemplateTransformNode, "_run", new=tt_generator):
# execute node
result = iteration_node._run()
count = 0
for item in result:
count += 1
if isinstance(item, RunCompletedEvent):
assert item.run_result.status == WorkflowNodeExecutionStatus.SUCCEEDED
assert item.run_result.outputs == {"output": ["dify 123", "dify 123"]}
assert count == 32
def test_iteration_run_in_parallel_mode():
graph_config = {
"edges": [
{
"id": "start-source-pe-target",
"source": "start",
"target": "pe",
},
{
"id": "iteration-1-source-answer-3-target",
"source": "iteration-1",
"target": "answer-3",
},
{
"id": "iteration-start-source-tt-target",
"source": "iteration-start",
"target": "tt",
},
{
"id": "iteration-start-source-tt-2-target",
"source": "iteration-start",
"target": "tt-2",
},
{
"id": "tt-source-if-else-target",
"source": "tt",
"target": "if-else",
},
{
"id": "tt-2-source-if-else-target",
"source": "tt-2",
"target": "if-else",
},
{
"id": "if-else-true-answer-2-target",
"source": "if-else",
"sourceHandle": "true",
"target": "answer-2",
},
{
"id": "if-else-false-answer-4-target",
"source": "if-else",
"sourceHandle": "false",
"target": "answer-4",
},
{
"id": "pe-source-iteration-1-target",
"source": "pe",
"target": "iteration-1",
},
],
"nodes": [
{"data": {"title": "Start", "type": "start", "variables": []}, "id": "start"},
{
"data": {
"iterator_selector": ["pe", "list_output"],
"output_selector": ["tt", "output"],
"output_type": "array[string]",
"startNodeType": "template-transform",
"start_node_id": "iteration-start",
"title": "iteration",
"type": "iteration",
},
"id": "iteration-1",
},
{
"data": {
"answer": "{{#tt.output#}}",
"iteration_id": "iteration-1",
"title": "answer 2",
"type": "answer",
},
"id": "answer-2",
},
{
"data": {
"iteration_id": "iteration-1",
"title": "iteration-start",
"type": "iteration-start",
},
"id": "iteration-start",
},
{
"data": {
"iteration_id": "iteration-1",
"template": "{{ arg1 }} 123",
"title": "template transform",
"type": "template-transform",
"variables": [{"value_selector": ["sys", "query"], "variable": "arg1"}],
},
"id": "tt",
},
{
"data": {
"iteration_id": "iteration-1",
"template": "{{ arg1 }} 321",
"title": "template transform",
"type": "template-transform",
"variables": [{"value_selector": ["sys", "query"], "variable": "arg1"}],
},
"id": "tt-2",
},
{
"data": {"answer": "{{#iteration-1.output#}}88888", "title": "answer 3", "type": "answer"},
"id": "answer-3",
},
{
"data": {
"conditions": [
{
"comparison_operator": "is",
"id": "1721916275284",
"value": "hi",
"variable_selector": ["sys", "query"],
}
],
"iteration_id": "iteration-1",
"logical_operator": "and",
"title": "if",
"type": "if-else",
},
"id": "if-else",
},
{
"data": {"answer": "no hi", "iteration_id": "iteration-1", "title": "answer 4", "type": "answer"},
"id": "answer-4",
},
{
"data": {
"instruction": "test1",
"model": {
"completion_params": {"temperature": 0.7},
"mode": "chat",
"name": "gpt-4o",
"provider": "openai",
},
"parameters": [
{"description": "test", "name": "list_output", "required": False, "type": "array[string]"}
],
"query": ["sys", "query"],
"reasoning_mode": "prompt",
"title": "pe",
"type": "parameter-extractor",
},
"id": "pe",
},
],
}
graph = Graph.init(graph_config=graph_config)
init_params = GraphInitParams(
tenant_id="1",
app_id="1",
workflow_type=WorkflowType.CHAT,
workflow_id="1",
graph_config=graph_config,
user_id="1",
user_from=UserFrom.ACCOUNT,
invoke_from=InvokeFrom.DEBUGGER,
call_depth=0,
)
# construct variable pool
pool = VariablePool(
system_variables={
SystemVariableKey.QUERY: "dify",
SystemVariableKey.FILES: [],
SystemVariableKey.CONVERSATION_ID: "abababa",
SystemVariableKey.USER_ID: "1",
},
user_inputs={},
environment_variables=[],
)
pool.add(["pe", "list_output"], ["dify-1", "dify-2"])
parallel_iteration_node = IterationNode(
id=str(uuid.uuid4()),
graph_init_params=init_params,
graph=graph,
graph_runtime_state=GraphRuntimeState(variable_pool=pool, start_at=time.perf_counter()),
config={
"data": {
"iterator_selector": ["pe", "list_output"],
"output_selector": ["tt", "output"],
"output_type": "array[string]",
"startNodeType": "template-transform",
"start_node_id": "iteration-start",
"title": "迭代",
"type": "iteration",
"is_parallel": True,
},
"id": "iteration-1",
},
)
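    # Sequential counterpart: identical configuration except parallel execution
    # is disabled, so the two runs can be compared event-for-event.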
sequential_iteration_node = IterationNode(
id=str(uuid.uuid4()),
graph_init_params=init_params,
graph=graph,
graph_runtime_state=GraphRuntimeState(variable_pool=pool, start_at=time.perf_counter()),
config={
"data": {
"iterator_selector": ["pe", "list_output"],
"output_selector": ["tt", "output"],
"output_type": "array[string]",
"startNodeType": "template-transform",
"start_node_id": "iteration-start",
"title": "迭代",
"type": "iteration",
"is_parallel": True,
},
"id": "iteration-1",
},
)
def tt_generator(self):
return NodeRunResult(
status=WorkflowNodeExecutionStatus.SUCCEEDED,
inputs={"iterator_selector": "dify"},
outputs={"output": "dify 123"},
)
with patch.object(TemplateTransformNode, "_run", new=tt_generator):
# execute node
parallel_result = parallel_iteration_node._run()
sequential_result = sequential_iteration_node._run()
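        # parallel_nums and error_handle_mode are not set in the config above,
        # so the IterationNodeData defaults are expected here.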
assert parallel_iteration_node.node_data.parallel_nums == 10
assert parallel_iteration_node.node_data.error_handle_mode == ErrorHandleMode.TERMINATED
count = 0
parallel_arr = []
sequential_arr = []
for item in parallel_result:
count += 1
parallel_arr.append(item)
if isinstance(item, RunCompletedEvent):
assert item.run_result.status == WorkflowNodeExecutionStatus.SUCCEEDED
assert item.run_result.outputs == {"output": ["dify 123", "dify 123"]}
assert count == 32
for item in sequential_result:
sequential_arr.append(item)
count += 1
if isinstance(item, RunCompletedEvent):
assert item.run_result.status == WorkflowNodeExecutionStatus.SUCCEEDED
assert item.run_result.outputs == {"output": ["dify 123", "dify 123"]}
assert count == 64
def test_iteration_run_error_handle():
graph_config = {
"edges": [
{
"id": "start-source-pe-target",
"source": "start",
"target": "pe",
},
{
"id": "iteration-1-source-answer-3-target",
"source": "iteration-1",
"target": "answer-3",
},
{
"id": "tt-source-if-else-target",
"source": "iteration-start",
"target": "if-else",
},
{
"id": "if-else-true-answer-2-target",
"source": "if-else",
"sourceHandle": "true",
"target": "tt",
},
{
"id": "if-else-false-answer-4-target",
"source": "if-else",
"sourceHandle": "false",
"target": "tt2",
},
{
"id": "pe-source-iteration-1-target",
"source": "pe",
"target": "iteration-1",
},
],
"nodes": [
{"data": {"title": "Start", "type": "start", "variables": []}, "id": "start"},
{
"data": {
"iterator_selector": ["pe", "list_output"],
"output_selector": ["tt2", "output"],
"output_type": "array[string]",
"start_node_id": "if-else",
"title": "iteration",
"type": "iteration",
},
"id": "iteration-1",
},
{
"data": {
"iteration_id": "iteration-1",
"template": "{{ arg1.split(arg2) }}",
"title": "template transform",
"type": "template-transform",
"variables": [
{"value_selector": ["iteration-1", "item"], "variable": "arg1"},
{"value_selector": ["iteration-1", "index"], "variable": "arg2"},
],
},
"id": "tt",
},
{
"data": {
"iteration_id": "iteration-1",
"template": "{{ arg1 }}",
"title": "template transform",
"type": "template-transform",
"variables": [
{"value_selector": ["iteration-1", "item"], "variable": "arg1"},
],
},
"id": "tt2",
},
{
"data": {"answer": "{{#iteration-1.output#}}88888", "title": "answer 3", "type": "answer"},
"id": "answer-3",
},
{
"data": {
"iteration_id": "iteration-1",
"title": "iteration-start",
"type": "iteration-start",
},
"id": "iteration-start",
},
{
"data": {
"conditions": [
{
"comparison_operator": "is",
"id": "1721916275284",
"value": "1",
"variable_selector": ["iteration-1", "item"],
}
],
"iteration_id": "iteration-1",
"logical_operator": "and",
"title": "if",
"type": "if-else",
},
"id": "if-else",
},
{
"data": {
"instruction": "test1",
"model": {
"completion_params": {"temperature": 0.7},
"mode": "chat",
"name": "gpt-4o",
"provider": "openai",
},
"parameters": [
{"description": "test", "name": "list_output", "required": False, "type": "array[string]"}
],
"query": ["sys", "query"],
"reasoning_mode": "prompt",
"title": "pe",
"type": "parameter-extractor",
},
"id": "pe",
},
],
}
graph = Graph.init(graph_config=graph_config)
init_params = GraphInitParams(
tenant_id="1",
app_id="1",
workflow_type=WorkflowType.CHAT,
workflow_id="1",
graph_config=graph_config,
user_id="1",
user_from=UserFrom.ACCOUNT,
invoke_from=InvokeFrom.DEBUGGER,
call_depth=0,
)
# construct variable pool
pool = VariablePool(
system_variables={
SystemVariableKey.QUERY: "dify",
SystemVariableKey.FILES: [],
SystemVariableKey.CONVERSATION_ID: "abababa",
SystemVariableKey.USER_ID: "1",
},
user_inputs={},
environment_variables=[],
)
pool.add(["pe", "list_output"], ["1", "1"])
iteration_node = IterationNode(
id=str(uuid.uuid4()),
graph_init_params=init_params,
graph=graph,
graph_runtime_state=GraphRuntimeState(variable_pool=pool, start_at=time.perf_counter()),
config={
"data": {
"iterator_selector": ["pe", "list_output"],
"output_selector": ["tt", "output"],
"output_type": "array[string]",
"startNodeType": "template-transform",
"start_node_id": "iteration-start",
"title": "iteration",
"type": "iteration",
"is_parallel": True,
"error_handle_mode": ErrorHandleMode.CONTINUE_ON_ERROR,
},
"id": "iteration-1",
},
)
    # run with CONTINUE_ON_ERROR: failed iterations contribute None to the output list
result = iteration_node._run()
result_arr = []
count = 0
for item in result:
result_arr.append(item)
count += 1
if isinstance(item, RunCompletedEvent):
assert item.run_result.status == WorkflowNodeExecutionStatus.SUCCEEDED
assert item.run_result.outputs == {"output": [None, None]}
assert count == 14
    # re-run with REMOVE_ABNORMAL_OUTPUT: failed iterations are dropped from the output list
iteration_node.node_data.error_handle_mode = ErrorHandleMode.REMOVE_ABNORMAL_OUTPUT
result = iteration_node._run()
count = 0
for item in result:
count += 1
if isinstance(item, RunCompletedEvent):
assert item.run_result.status == WorkflowNodeExecutionStatus.SUCCEEDED
assert item.run_result.outputs == {"output": []}
assert count == 14

View File

@@ -0,0 +1,467 @@
from collections.abc import Sequence
from typing import Optional
import pytest
from core.app.entities.app_invoke_entities import InvokeFrom, ModelConfigWithCredentialsEntity
from core.entities.provider_configuration import ProviderConfiguration, ProviderModelBundle
from core.entities.provider_entities import CustomConfiguration, SystemConfiguration
from core.file import File, FileTransferMethod, FileType
from core.model_runtime.entities.common_entities import I18nObject
from core.model_runtime.entities.message_entities import (
ImagePromptMessageContent,
PromptMessage,
PromptMessageRole,
TextPromptMessageContent,
UserPromptMessage,
)
from core.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelType
from core.model_runtime.model_providers.model_provider_factory import ModelProviderFactory
from core.variables import ArrayAnySegment, ArrayFileSegment, NoneSegment
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.graph_engine import Graph, GraphInitParams, GraphRuntimeState
from core.workflow.nodes.answer import AnswerStreamGenerateRoute
from core.workflow.nodes.end import EndStreamParam
from core.workflow.nodes.llm.entities import (
ContextConfig,
LLMNodeChatModelMessage,
LLMNodeData,
ModelConfig,
VisionConfig,
VisionConfigOptions,
)
from core.workflow.nodes.llm.node import LLMNode
from models.enums import UserFrom
from models.provider import ProviderType
from models.workflow import WorkflowType
class MockTokenBufferMemory:
def __init__(self, history_messages=None):
self.history_messages = history_messages or []
def get_history_prompt_messages(
self, max_token_limit: int = 2000, message_limit: Optional[int] = None
) -> Sequence[PromptMessage]:
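        # Each turn is a user/assistant pair, so a message_limit of N keeps the
        # last N * 2 raw messages.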
if message_limit is not None:
return self.history_messages[-message_limit * 2 :]
return self.history_messages
@pytest.fixture
def llm_node():
data = LLMNodeData(
title="Test LLM",
model=ModelConfig(provider="openai", name="gpt-3.5-turbo", mode="chat", completion_params={}),
prompt_template=[],
memory=None,
context=ContextConfig(enabled=False),
vision=VisionConfig(
enabled=True,
configs=VisionConfigOptions(
variable_selector=["sys", "files"],
detail=ImagePromptMessageContent.DETAIL.HIGH,
),
),
)
variable_pool = VariablePool(
system_variables={},
user_inputs={},
)
node = LLMNode(
id="1",
config={
"id": "1",
"data": data.model_dump(),
},
graph_init_params=GraphInitParams(
tenant_id="1",
app_id="1",
workflow_type=WorkflowType.WORKFLOW,
workflow_id="1",
graph_config={},
user_id="1",
user_from=UserFrom.ACCOUNT,
invoke_from=InvokeFrom.SERVICE_API,
call_depth=0,
),
graph=Graph(
root_node_id="1",
answer_stream_generate_routes=AnswerStreamGenerateRoute(
answer_dependencies={},
answer_generate_route={},
),
end_stream_param=EndStreamParam(
end_dependencies={},
end_stream_variable_selector_mapping={},
),
),
graph_runtime_state=GraphRuntimeState(
variable_pool=variable_pool,
start_at=0,
),
)
return node
@pytest.fixture
def model_config():
# Create actual provider and model type instances
model_provider_factory = ModelProviderFactory(tenant_id="test")
provider_instance = model_provider_factory.get_plugin_model_provider("openai")
model_type_instance = model_provider_factory.get_model_type_instance("openai", ModelType.LLM)
# Create a ProviderModelBundle
provider_model_bundle = ProviderModelBundle(
configuration=ProviderConfiguration(
tenant_id="1",
provider=provider_instance,
preferred_provider_type=ProviderType.CUSTOM,
using_provider_type=ProviderType.CUSTOM,
system_configuration=SystemConfiguration(enabled=False),
custom_configuration=CustomConfiguration(provider=None),
model_settings=[],
),
model_type_instance=model_type_instance,
)
# Create and return a ModelConfigWithCredentialsEntity
return ModelConfigWithCredentialsEntity(
provider="openai",
model="gpt-3.5-turbo",
model_schema=AIModelEntity(
model="gpt-3.5-turbo",
label=I18nObject(en_US="GPT-3.5 Turbo"),
model_type=ModelType.LLM,
fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
model_properties={},
),
mode="chat",
credentials={},
parameters={},
provider_model_bundle=provider_model_bundle,
)
def test_fetch_files_with_file_segment(llm_node):
file = File(
id="1",
tenant_id="test",
type=FileType.IMAGE,
filename="test.jpg",
transfer_method=FileTransferMethod.LOCAL_FILE,
related_id="1",
storage_key="",
)
llm_node.graph_runtime_state.variable_pool.add(["sys", "files"], file)
result = llm_node._fetch_files(selector=["sys", "files"])
assert result == [file]
def test_fetch_files_with_array_file_segment(llm_node):
files = [
File(
id="1",
tenant_id="test",
type=FileType.IMAGE,
filename="test1.jpg",
transfer_method=FileTransferMethod.LOCAL_FILE,
related_id="1",
storage_key="",
),
File(
id="2",
tenant_id="test",
type=FileType.IMAGE,
filename="test2.jpg",
transfer_method=FileTransferMethod.LOCAL_FILE,
related_id="2",
storage_key="",
),
]
llm_node.graph_runtime_state.variable_pool.add(["sys", "files"], ArrayFileSegment(value=files))
result = llm_node._fetch_files(selector=["sys", "files"])
assert result == files
def test_fetch_files_with_none_segment(llm_node):
llm_node.graph_runtime_state.variable_pool.add(["sys", "files"], NoneSegment())
result = llm_node._fetch_files(selector=["sys", "files"])
assert result == []
def test_fetch_files_with_array_any_segment(llm_node):
llm_node.graph_runtime_state.variable_pool.add(["sys", "files"], ArrayAnySegment(value=[]))
result = llm_node._fetch_files(selector=["sys", "files"])
assert result == []
def test_fetch_files_with_non_existent_variable(llm_node):
result = llm_node._fetch_files(selector=["sys", "files"])
assert result == []
# def test_fetch_prompt_messages__vision_disabled(faker, llm_node, model_config):
# TODO: Add test
# pass
# prompt_template = []
# llm_node.node_data.prompt_template = prompt_template
# fake_vision_detail = faker.random_element(
# [ImagePromptMessageContent.DETAIL.HIGH, ImagePromptMessageContent.DETAIL.LOW]
# )
# fake_remote_url = faker.url()
# files = [
# File(
# id="1",
# tenant_id="test",
# type=FileType.IMAGE,
# filename="test1.jpg",
# transfer_method=FileTransferMethod.REMOTE_URL,
# remote_url=fake_remote_url,
# storage_key="",
# )
# ]
# fake_query = faker.sentence()
# prompt_messages, _ = llm_node._fetch_prompt_messages(
# sys_query=fake_query,
# sys_files=files,
# context=None,
# memory=None,
# model_config=model_config,
# prompt_template=prompt_template,
# memory_config=None,
# vision_enabled=False,
# vision_detail=fake_vision_detail,
# variable_pool=llm_node.graph_runtime_state.variable_pool,
# jinja2_variables=[],
# )
# assert prompt_messages == [UserPromptMessage(content=fake_query)]
# def test_fetch_prompt_messages__basic(faker, llm_node, model_config):
# TODO: Add test
# pass
# Setup dify config
# dify_config.MULTIMODAL_SEND_FORMAT = "url"
# # Generate fake values for prompt template
# fake_assistant_prompt = faker.sentence()
# fake_query = faker.sentence()
# fake_context = faker.sentence()
# fake_window_size = faker.random_int(min=1, max=3)
# fake_vision_detail = faker.random_element(
# [ImagePromptMessageContent.DETAIL.HIGH, ImagePromptMessageContent.DETAIL.LOW]
# )
# fake_remote_url = faker.url()
# # Setup mock memory with history messages
# mock_history = [
# UserPromptMessage(content=faker.sentence()),
# AssistantPromptMessage(content=faker.sentence()),
# UserPromptMessage(content=faker.sentence()),
# AssistantPromptMessage(content=faker.sentence()),
# UserPromptMessage(content=faker.sentence()),
# AssistantPromptMessage(content=faker.sentence()),
# ]
# # Setup memory configuration
# memory_config = MemoryConfig(
# role_prefix=MemoryConfig.RolePrefix(user="Human", assistant="Assistant"),
# window=MemoryConfig.WindowConfig(enabled=True, size=fake_window_size),
# query_prompt_template=None,
# )
# memory = MockTokenBufferMemory(history_messages=mock_history)
# # Test scenarios covering different file input combinations
# test_scenarios = [
# LLMNodeTestScenario(
# description="No files",
# sys_query=fake_query,
# sys_files=[],
# features=[],
# vision_enabled=False,
# vision_detail=None,
# window_size=fake_window_size,
# prompt_template=[
# LLMNodeChatModelMessage(
# text=fake_context,
# role=PromptMessageRole.SYSTEM,
# edition_type="basic",
# ),
# LLMNodeChatModelMessage(
# text="{#context#}",
# role=PromptMessageRole.USER,
# edition_type="basic",
# ),
# LLMNodeChatModelMessage(
# text=fake_assistant_prompt,
# role=PromptMessageRole.ASSISTANT,
# edition_type="basic",
# ),
# ],
# expected_messages=[
# SystemPromptMessage(content=fake_context),
# UserPromptMessage(content=fake_context),
# AssistantPromptMessage(content=fake_assistant_prompt),
# ]
# + mock_history[fake_window_size * -2 :]
# + [
# UserPromptMessage(content=fake_query),
# ],
# ),
# LLMNodeTestScenario(
# description="User files",
# sys_query=fake_query,
# sys_files=[
# File(
# tenant_id="test",
# type=FileType.IMAGE,
# filename="test1.jpg",
# transfer_method=FileTransferMethod.REMOTE_URL,
# remote_url=fake_remote_url,
# extension=".jpg",
# mime_type="image/jpg",
# storage_key="",
# )
# ],
# vision_enabled=True,
# vision_detail=fake_vision_detail,
# features=[ModelFeature.VISION],
# window_size=fake_window_size,
# prompt_template=[
# LLMNodeChatModelMessage(
# text=fake_context,
# role=PromptMessageRole.SYSTEM,
# edition_type="basic",
# ),
# LLMNodeChatModelMessage(
# text="{#context#}",
# role=PromptMessageRole.USER,
# edition_type="basic",
# ),
# LLMNodeChatModelMessage(
# text=fake_assistant_prompt,
# role=PromptMessageRole.ASSISTANT,
# edition_type="basic",
# ),
# ],
# expected_messages=[
# SystemPromptMessage(content=fake_context),
# UserPromptMessage(content=fake_context),
# AssistantPromptMessage(content=fake_assistant_prompt),
# ]
# + mock_history[fake_window_size * -2 :]
# + [
# UserPromptMessage(
# content=[
# TextPromptMessageContent(data=fake_query),
# ImagePromptMessageContent(
# url=fake_remote_url, mime_type="image/jpg", format="jpg", detail=fake_vision_detail
# ),
# ]
# ),
# ],
# ),
# LLMNodeTestScenario(
# description="Prompt template with variable selector of File",
# sys_query=fake_query,
# sys_files=[],
# vision_enabled=False,
# vision_detail=fake_vision_detail,
# features=[ModelFeature.VISION],
# window_size=fake_window_size,
# prompt_template=[
# LLMNodeChatModelMessage(
# text="{{#input.image#}}",
# role=PromptMessageRole.USER,
# edition_type="basic",
# ),
# ],
# expected_messages=[
# UserPromptMessage(
# content=[
# ImagePromptMessageContent(
# url=fake_remote_url, mime_type="image/jpg", format="jpg", detail=fake_vision_detail
# ),
# ]
# ),
# ]
# + mock_history[fake_window_size * -2 :]
# + [UserPromptMessage(content=fake_query)],
# file_variables={
# "input.image": File(
# tenant_id="test",
# type=FileType.IMAGE,
# filename="test1.jpg",
# transfer_method=FileTransferMethod.REMOTE_URL,
# remote_url=fake_remote_url,
# extension=".jpg",
# mime_type="image/jpg",
# storage_key="",
# )
# },
# ),
# ]
# for scenario in test_scenarios:
# model_config.model_schema.features = scenario.features
# for k, v in scenario.file_variables.items():
# selector = k.split(".")
# llm_node.graph_runtime_state.variable_pool.add(selector, v)
# # Call the method under test
# prompt_messages, _ = llm_node._fetch_prompt_messages(
# sys_query=scenario.sys_query,
# sys_files=scenario.sys_files,
# context=fake_context,
# memory=memory,
# model_config=model_config,
# prompt_template=scenario.prompt_template,
# memory_config=memory_config,
# vision_enabled=scenario.vision_enabled,
# vision_detail=scenario.vision_detail,
# variable_pool=llm_node.graph_runtime_state.variable_pool,
# jinja2_variables=[],
# )
# # Verify the result
# assert len(prompt_messages) == len(scenario.expected_messages), f"Scenario failed: {scenario.description}"
# assert prompt_messages == scenario.expected_messages, (
# f"Message content mismatch in scenario: {scenario.description}"
# )
def test_handle_list_messages_basic(llm_node):
messages = [
LLMNodeChatModelMessage(
text="Hello, {#context#}",
role=PromptMessageRole.USER,
edition_type="basic",
)
]
context = "world"
jinja2_variables = []
variable_pool = llm_node.graph_runtime_state.variable_pool
vision_detail_config = ImagePromptMessageContent.DETAIL.HIGH
result = llm_node._handle_list_messages(
messages=messages,
context=context,
jinja2_variables=jinja2_variables,
variable_pool=variable_pool,
vision_detail_config=vision_detail_config,
)
assert len(result) == 1
assert isinstance(result[0], UserPromptMessage)
assert result[0].content == [TextPromptMessageContent(data="Hello, world")]

View File

@@ -0,0 +1,25 @@
from collections.abc import Mapping, Sequence
from pydantic import BaseModel, Field
from core.file import File
from core.model_runtime.entities.message_entities import PromptMessage
from core.model_runtime.entities.model_entities import ModelFeature
from core.workflow.nodes.llm.entities import LLMNodeChatModelMessage
class LLMNodeTestScenario(BaseModel):
"""Test scenario for LLM node testing."""
description: str = Field(..., description="Description of the test scenario")
sys_query: str = Field(..., description="User query input")
sys_files: Sequence[File] = Field(default_factory=list, description="List of user files")
vision_enabled: bool = Field(default=False, description="Whether vision is enabled")
vision_detail: str | None = Field(None, description="Vision detail level if vision is enabled")
features: Sequence[ModelFeature] = Field(default_factory=list, description="List of model features")
window_size: int = Field(..., description="Window size for memory")
prompt_template: Sequence[LLMNodeChatModelMessage] = Field(..., description="Template for prompt messages")
    file_variables: Mapping[str, File | Sequence[File]] = Field(
        default_factory=dict, description="Mapping of variable selector keys to file values"
    )
expected_messages: Sequence[PromptMessage] = Field(..., description="Expected messages after processing")
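
# A minimal usage sketch with hypothetical values, kept commented out so that
# importing this helper module stays side-effect free:
#
# scenario = LLMNodeTestScenario(
#     description="No files",
#     sys_query="hello",
#     sys_files=[],
#     window_size=2,
#     prompt_template=[],
#     expected_messages=[],
# )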

View File

@@ -0,0 +1,85 @@
import time
import uuid
from unittest.mock import MagicMock
from core.app.entities.app_invoke_entities import InvokeFrom
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.enums import SystemVariableKey
from core.workflow.graph_engine.entities.graph import Graph
from core.workflow.graph_engine.entities.graph_init_params import GraphInitParams
from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState
from core.workflow.nodes.answer.answer_node import AnswerNode
from extensions.ext_database import db
from models.enums import UserFrom
from models.workflow import WorkflowNodeExecutionStatus, WorkflowType
def test_execute_answer():
graph_config = {
"edges": [
{
"id": "start-source-answer-target",
"source": "start",
"target": "answer",
},
],
"nodes": [
{"data": {"type": "start"}, "id": "start"},
{
"data": {
"title": "123",
"type": "answer",
"answer": "Today's weather is {{#start.weather#}}\n{{#llm.text#}}\n{{img}}\nFin.",
},
"id": "answer",
},
],
}
graph = Graph.init(graph_config=graph_config)
init_params = GraphInitParams(
tenant_id="1",
app_id="1",
workflow_type=WorkflowType.WORKFLOW,
workflow_id="1",
graph_config=graph_config,
user_id="1",
user_from=UserFrom.ACCOUNT,
invoke_from=InvokeFrom.DEBUGGER,
call_depth=0,
)
# construct variable pool
variable_pool = VariablePool(
system_variables={SystemVariableKey.FILES: [], SystemVariableKey.USER_ID: "aaa"},
user_inputs={},
environment_variables=[],
conversation_variables=[],
)
variable_pool.add(["start", "weather"], "sunny")
variable_pool.add(["llm", "text"], "You are a helpful AI.")
node = AnswerNode(
id=str(uuid.uuid4()),
graph_init_params=init_params,
graph=graph,
graph_runtime_state=GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()),
config={
"id": "answer",
"data": {
"title": "123",
"type": "answer",
"answer": "Today's weather is {{#start.weather#}}\n{{#llm.text#}}\n{{img}}\nFin.",
},
},
)
# Mock db.session.close()
db.session.close = MagicMock()
# execute node
result = node._run()
assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED
assert result.outputs["answer"] == "Today's weather is sunny\nYou are a helpful AI.\n{{img}}\nFin."

View File

@@ -0,0 +1,508 @@
from core.app.entities.app_invoke_entities import InvokeFrom
from core.workflow.enums import SystemVariableKey
from core.workflow.graph_engine.entities.event import (
GraphRunPartialSucceededEvent,
NodeRunExceptionEvent,
NodeRunStreamChunkEvent,
)
from core.workflow.graph_engine.entities.graph import Graph
from core.workflow.graph_engine.graph_engine import GraphEngine
from models.enums import UserFrom
from models.workflow import WorkflowType
class ContinueOnErrorTestHelper:
@staticmethod
def get_code_node(
        code: str, error_strategy: str = "fail-branch", default_value: dict | None = None, retry_config: dict | None = None
):
"""Helper method to create a code node configuration"""
node = {
"id": "node",
"data": {
"outputs": {"result": {"type": "number"}},
"error_strategy": error_strategy,
"title": "code",
"variables": [],
"code_language": "python3",
"code": "\n".join([line[4:] for line in code.split("\n")]),
"type": "code",
                **(retry_config or {}),
},
}
if default_value:
node["data"]["default_value"] = default_value
return node
@staticmethod
def get_http_node(
error_strategy: str = "fail-branch",
default_value: dict | None = None,
authorization_success: bool = False,
        retry_config: dict | None = None,
):
"""Helper method to create a http node configuration"""
authorization = (
{
"type": "api-key",
"config": {
"type": "basic",
"api_key": "ak-xxx",
"header": "api-key",
},
}
if authorization_success
else {
"type": "api-key",
# missing config field
}
)
node = {
"id": "node",
"data": {
"title": "http",
"desc": "",
"method": "get",
"url": "http://example.com",
"authorization": authorization,
"headers": "X-Header:123",
"params": "A:b",
"body": None,
"type": "http-request",
"error_strategy": error_strategy,
                **(retry_config or {}),
},
}
if default_value:
node["data"]["default_value"] = default_value
return node
@staticmethod
def get_error_status_code_http_node(error_strategy: str = "fail-branch", default_value: dict | None = None):
"""Helper method to create a http node configuration"""
node = {
"id": "node",
"data": {
"type": "http-request",
"title": "HTTP Request",
"desc": "",
"variables": [],
"method": "get",
"url": "https://api.github.com/issues",
"authorization": {"type": "no-auth", "config": None},
"headers": "",
"params": "",
"body": {"type": "none", "data": []},
"timeout": {"max_connect_timeout": 0, "max_read_timeout": 0, "max_write_timeout": 0},
"error_strategy": error_strategy,
},
}
if default_value:
node["data"]["default_value"] = default_value
return node
@staticmethod
def get_tool_node(error_strategy: str = "fail-branch", default_value: dict | None = None):
"""Helper method to create a tool node configuration"""
node = {
"id": "node",
"data": {
"title": "a",
"desc": "a",
"provider_id": "maths",
"provider_type": "builtin",
"provider_name": "maths",
"tool_name": "eval_expression",
"tool_label": "eval_expression",
"tool_configurations": {},
"tool_parameters": {
"expression": {
"type": "variable",
"value": ["1", "123", "args1"],
}
},
"type": "tool",
"error_strategy": error_strategy,
},
}
if default_value:
            node["data"]["default_value"] = default_value
return node
@staticmethod
def get_llm_node(error_strategy: str = "fail-branch", default_value: dict | None = None):
"""Helper method to create a llm node configuration"""
node = {
"id": "node",
"data": {
"title": "123",
"type": "llm",
"model": {"provider": "openai", "name": "gpt-3.5-turbo", "mode": "chat", "completion_params": {}},
"prompt_template": [
{"role": "system", "text": "you are a helpful assistant.\ntoday's weather is {{#abc.output#}}."},
{"role": "user", "text": "{{#sys.query#}}"},
],
"memory": None,
"context": {"enabled": False},
"vision": {"enabled": False},
"error_strategy": error_strategy,
},
}
if default_value:
node["data"]["default_value"] = default_value
return node
@staticmethod
def create_test_graph_engine(graph_config: dict, user_inputs: dict | None = None):
"""Helper method to create a graph engine instance for testing"""
graph = Graph.init(graph_config=graph_config)
variable_pool = {
"system_variables": {
SystemVariableKey.QUERY: "clear",
SystemVariableKey.FILES: [],
SystemVariableKey.CONVERSATION_ID: "abababa",
SystemVariableKey.USER_ID: "aaa",
},
"user_inputs": user_inputs or {"uid": "takato"},
}
return GraphEngine(
tenant_id="111",
app_id="222",
workflow_type=WorkflowType.CHAT,
workflow_id="333",
graph_config=graph_config,
user_id="444",
user_from=UserFrom.ACCOUNT,
invoke_from=InvokeFrom.WEB_APP,
call_depth=0,
graph=graph,
variable_pool=variable_pool,
max_execution_steps=500,
max_execution_time=1200,
)
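
# Shared edge layouts: DEFAULT_VALUE_EDGE wires start -> node -> answer, while
# FAIL_BRANCH_EDGES routes the node's "source" (success) and "fail-branch"
# handles to separate answer nodes.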
DEFAULT_VALUE_EDGE = [
{
"id": "start-source-node-target",
"source": "start",
"target": "node",
"sourceHandle": "source",
},
{
"id": "node-source-answer-target",
"source": "node",
"target": "answer",
"sourceHandle": "source",
},
]
FAIL_BRANCH_EDGES = [
{
"id": "start-source-node-target",
"source": "start",
"target": "node",
"sourceHandle": "source",
},
{
"id": "node-true-success-target",
"source": "node",
"target": "success",
"sourceHandle": "source",
},
{
"id": "node-false-error-target",
"source": "node",
"target": "error",
"sourceHandle": "fail-branch",
},
]
def test_code_default_value_continue_on_error():
error_code = """
def main() -> dict:
return {
"result": 1 / 0,
}
"""
graph_config = {
"edges": DEFAULT_VALUE_EDGE,
"nodes": [
{"data": {"title": "start", "type": "start", "variables": []}, "id": "start"},
{"data": {"title": "answer", "type": "answer", "answer": "{{#node.result#}}"}, "id": "answer"},
ContinueOnErrorTestHelper.get_code_node(
error_code, "default-value", [{"key": "result", "type": "number", "value": 132123}]
),
],
}
graph_engine = ContinueOnErrorTestHelper.create_test_graph_engine(graph_config)
events = list(graph_engine.run())
assert any(isinstance(e, NodeRunExceptionEvent) for e in events)
assert any(isinstance(e, GraphRunPartialSucceededEvent) and e.outputs == {"answer": "132123"} for e in events)
assert sum(1 for e in events if isinstance(e, NodeRunStreamChunkEvent)) == 1
def test_code_fail_branch_continue_on_error():
error_code = """
def main() -> dict:
return {
"result": 1 / 0,
}
"""
graph_config = {
"edges": FAIL_BRANCH_EDGES,
"nodes": [
{"data": {"title": "Start", "type": "start", "variables": []}, "id": "start"},
{
"data": {"title": "success", "type": "answer", "answer": "node node run successfully"},
"id": "success",
},
{
"data": {"title": "error", "type": "answer", "answer": "node node run failed"},
"id": "error",
},
ContinueOnErrorTestHelper.get_code_node(error_code),
],
}
graph_engine = ContinueOnErrorTestHelper.create_test_graph_engine(graph_config)
events = list(graph_engine.run())
assert sum(1 for e in events if isinstance(e, NodeRunStreamChunkEvent)) == 1
assert any(isinstance(e, NodeRunExceptionEvent) for e in events)
assert any(
isinstance(e, GraphRunPartialSucceededEvent) and e.outputs == {"answer": "node node run failed"} for e in events
)
def test_http_node_default_value_continue_on_error():
"""Test HTTP node with default value error strategy"""
graph_config = {
"edges": DEFAULT_VALUE_EDGE,
"nodes": [
{"data": {"title": "start", "type": "start", "variables": []}, "id": "start"},
{"data": {"title": "answer", "type": "answer", "answer": "{{#node.response#}}"}, "id": "answer"},
ContinueOnErrorTestHelper.get_http_node(
"default-value", [{"key": "response", "type": "string", "value": "http node got error response"}]
),
],
}
graph_engine = ContinueOnErrorTestHelper.create_test_graph_engine(graph_config)
events = list(graph_engine.run())
assert any(isinstance(e, NodeRunExceptionEvent) for e in events)
assert any(
isinstance(e, GraphRunPartialSucceededEvent) and e.outputs == {"answer": "http node got error response"}
for e in events
)
assert sum(1 for e in events if isinstance(e, NodeRunStreamChunkEvent)) == 1
def test_http_node_fail_branch_continue_on_error():
"""Test HTTP node with fail-branch error strategy"""
graph_config = {
"edges": FAIL_BRANCH_EDGES,
"nodes": [
{"data": {"title": "Start", "type": "start", "variables": []}, "id": "start"},
{
"data": {"title": "success", "type": "answer", "answer": "HTTP request successful"},
"id": "success",
},
{
"data": {"title": "error", "type": "answer", "answer": "HTTP request failed"},
"id": "error",
},
ContinueOnErrorTestHelper.get_http_node(),
],
}
graph_engine = ContinueOnErrorTestHelper.create_test_graph_engine(graph_config)
events = list(graph_engine.run())
assert any(isinstance(e, NodeRunExceptionEvent) for e in events)
assert any(
isinstance(e, GraphRunPartialSucceededEvent) and e.outputs == {"answer": "HTTP request failed"} for e in events
)
assert sum(1 for e in events if isinstance(e, NodeRunStreamChunkEvent)) == 1
# def test_tool_node_default_value_continue_on_error():
# """Test tool node with default value error strategy"""
# graph_config = {
# "edges": DEFAULT_VALUE_EDGE,
# "nodes": [
# {"data": {"title": "start", "type": "start", "variables": []}, "id": "start"},
# {"data": {"title": "answer", "type": "answer", "answer": "{{#node.result#}}"}, "id": "answer"},
# ContinueOnErrorTestHelper.get_tool_node(
# "default-value", [{"key": "result", "type": "string", "value": "default tool result"}]
# ),
# ],
# }
# graph_engine = ContinueOnErrorTestHelper.create_test_graph_engine(graph_config)
# events = list(graph_engine.run())
# assert any(isinstance(e, NodeRunExceptionEvent) for e in events)
# assert any(
# isinstance(e, GraphRunPartialSucceededEvent) and e.outputs == {"answer": "default tool result"} for e in events # noqa: E501
# )
# assert sum(1 for e in events if isinstance(e, NodeRunStreamChunkEvent)) == 1
# def test_tool_node_fail_branch_continue_on_error():
# """Test HTTP node with fail-branch error strategy"""
# graph_config = {
# "edges": FAIL_BRANCH_EDGES,
# "nodes": [
# {"data": {"title": "Start", "type": "start", "variables": []}, "id": "start"},
# {
# "data": {"title": "success", "type": "answer", "answer": "tool execute successful"},
# "id": "success",
# },
# {
# "data": {"title": "error", "type": "answer", "answer": "tool execute failed"},
# "id": "error",
# },
# ContinueOnErrorTestHelper.get_tool_node(),
# ],
# }
# graph_engine = ContinueOnErrorTestHelper.create_test_graph_engine(graph_config)
# events = list(graph_engine.run())
# assert any(isinstance(e, NodeRunExceptionEvent) for e in events)
# assert any(
# isinstance(e, GraphRunPartialSucceededEvent) and e.outputs == {"answer": "tool execute failed"} for e in events # noqa: E501
# )
# assert sum(1 for e in events if isinstance(e, NodeRunStreamChunkEvent)) == 1
def test_llm_node_default_value_continue_on_error():
"""Test LLM node with default value error strategy"""
graph_config = {
"edges": DEFAULT_VALUE_EDGE,
"nodes": [
{"data": {"title": "start", "type": "start", "variables": []}, "id": "start"},
{"data": {"title": "answer", "type": "answer", "answer": "{{#node.answer#}}"}, "id": "answer"},
ContinueOnErrorTestHelper.get_llm_node(
"default-value", [{"key": "answer", "type": "string", "value": "default LLM response"}]
),
],
}
graph_engine = ContinueOnErrorTestHelper.create_test_graph_engine(graph_config)
events = list(graph_engine.run())
assert any(isinstance(e, NodeRunExceptionEvent) for e in events)
assert any(
isinstance(e, GraphRunPartialSucceededEvent) and e.outputs == {"answer": "default LLM response"} for e in events
)
assert sum(1 for e in events if isinstance(e, NodeRunStreamChunkEvent)) == 1
def test_llm_node_fail_branch_continue_on_error():
"""Test LLM node with fail-branch error strategy"""
graph_config = {
"edges": FAIL_BRANCH_EDGES,
"nodes": [
{"data": {"title": "Start", "type": "start", "variables": []}, "id": "start"},
{
"data": {"title": "success", "type": "answer", "answer": "LLM request successful"},
"id": "success",
},
{
"data": {"title": "error", "type": "answer", "answer": "LLM request failed"},
"id": "error",
},
ContinueOnErrorTestHelper.get_llm_node(),
],
}
graph_engine = ContinueOnErrorTestHelper.create_test_graph_engine(graph_config)
events = list(graph_engine.run())
assert any(isinstance(e, NodeRunExceptionEvent) for e in events)
assert any(
isinstance(e, GraphRunPartialSucceededEvent) and e.outputs == {"answer": "LLM request failed"} for e in events
)
assert sum(1 for e in events if isinstance(e, NodeRunStreamChunkEvent)) == 1
def test_status_code_error_http_node_fail_branch_continue_on_error():
"""Test HTTP node with fail-branch error strategy"""
graph_config = {
"edges": FAIL_BRANCH_EDGES,
"nodes": [
{"data": {"title": "Start", "type": "start", "variables": []}, "id": "start"},
{
"data": {"title": "success", "type": "answer", "answer": "http execute successful"},
"id": "success",
},
{
"data": {"title": "error", "type": "answer", "answer": "http execute failed"},
"id": "error",
},
ContinueOnErrorTestHelper.get_error_status_code_http_node(),
],
}
graph_engine = ContinueOnErrorTestHelper.create_test_graph_engine(graph_config)
events = list(graph_engine.run())
assert any(isinstance(e, NodeRunExceptionEvent) for e in events)
assert any(
isinstance(e, GraphRunPartialSucceededEvent) and e.outputs == {"answer": "http execute failed"} for e in events
)
assert sum(1 for e in events if isinstance(e, NodeRunStreamChunkEvent)) == 1
def test_variable_pool_error_type_variable():
graph_config = {
"edges": FAIL_BRANCH_EDGES,
"nodes": [
{"data": {"title": "Start", "type": "start", "variables": []}, "id": "start"},
{
"data": {"title": "success", "type": "answer", "answer": "http execute successful"},
"id": "success",
},
{
"data": {"title": "error", "type": "answer", "answer": "http execute failed"},
"id": "error",
},
ContinueOnErrorTestHelper.get_error_status_code_http_node(),
],
}
graph_engine = ContinueOnErrorTestHelper.create_test_graph_engine(graph_config)
list(graph_engine.run())
error_message = graph_engine.graph_runtime_state.variable_pool.get(["node", "error_message"])
error_type = graph_engine.graph_runtime_state.variable_pool.get(["node", "error_type"])
    assert error_message is not None
assert error_type.value == "HTTPResponseCodeError"
def test_no_node_in_fail_branch_continue_on_error():
"""Test HTTP node with fail-branch error strategy"""
graph_config = {
"edges": FAIL_BRANCH_EDGES[:-1],
"nodes": [
{"data": {"title": "Start", "type": "start", "variables": []}, "id": "start"},
{
"data": {"title": "success", "type": "answer", "answer": "HTTP request successful"},
"id": "success",
},
ContinueOnErrorTestHelper.get_http_node(),
],
}
graph_engine = ContinueOnErrorTestHelper.create_test_graph_engine(graph_config)
events = list(graph_engine.run())
assert any(isinstance(e, NodeRunExceptionEvent) for e in events)
assert any(isinstance(e, GraphRunPartialSucceededEvent) and e.outputs == {} for e in events)
assert sum(1 for e in events if isinstance(e, NodeRunStreamChunkEvent)) == 0

View File

@@ -0,0 +1,178 @@
from unittest.mock import Mock, patch
import pytest
from core.file import File, FileTransferMethod
from core.variables import ArrayFileSegment
from core.variables.variables import StringVariable
from core.workflow.entities.node_entities import NodeRunResult
from core.workflow.nodes.document_extractor import DocumentExtractorNode, DocumentExtractorNodeData
from core.workflow.nodes.document_extractor.node import (
_extract_text_from_docx,
_extract_text_from_pdf,
_extract_text_from_plain_text,
)
from core.workflow.nodes.enums import NodeType
from models.workflow import WorkflowNodeExecutionStatus
@pytest.fixture
def document_extractor_node():
node_data = DocumentExtractorNodeData(
title="Test Document Extractor",
variable_selector=["node_id", "variable_name"],
)
return DocumentExtractorNode(
id="test_node_id",
config={"id": "test_node_id", "data": node_data.model_dump()},
graph_init_params=Mock(),
graph=Mock(),
graph_runtime_state=Mock(),
)
@pytest.fixture
def mock_graph_runtime_state():
return Mock()
def test_run_variable_not_found(document_extractor_node, mock_graph_runtime_state):
document_extractor_node.graph_runtime_state = mock_graph_runtime_state
mock_graph_runtime_state.variable_pool.get.return_value = None
result = document_extractor_node._run()
assert isinstance(result, NodeRunResult)
assert result.status == WorkflowNodeExecutionStatus.FAILED
assert result.error is not None
assert "File variable not found" in result.error
def test_run_invalid_variable_type(document_extractor_node, mock_graph_runtime_state):
document_extractor_node.graph_runtime_state = mock_graph_runtime_state
mock_graph_runtime_state.variable_pool.get.return_value = StringVariable(
value="Not an ArrayFileSegment", name="test"
)
result = document_extractor_node._run()
assert isinstance(result, NodeRunResult)
assert result.status == WorkflowNodeExecutionStatus.FAILED
assert result.error is not None
assert "is not an ArrayFileSegment" in result.error
@pytest.mark.parametrize(
("mime_type", "file_content", "expected_text", "transfer_method", "extension"),
[
("text/plain", b"Hello, world!", ["Hello, world!"], FileTransferMethod.LOCAL_FILE, ".txt"),
(
"application/pdf",
b"%PDF-1.5\n%Test PDF content",
["Mocked PDF content"],
FileTransferMethod.LOCAL_FILE,
".pdf",
),
(
"application/vnd.openxmlformats-officedocument.wordprocessingml.document",
b"PK\x03\x04",
["Mocked DOCX content"],
FileTransferMethod.REMOTE_URL,
"",
),
("text/plain", b"Remote content", ["Remote content"], FileTransferMethod.REMOTE_URL, None),
],
)
def test_run_extract_text(
document_extractor_node,
mock_graph_runtime_state,
mime_type,
file_content,
expected_text,
transfer_method,
extension,
monkeypatch,
):
document_extractor_node.graph_runtime_state = mock_graph_runtime_state
mock_file = Mock(spec=File)
mock_file.mime_type = mime_type
mock_file.transfer_method = transfer_method
mock_file.related_id = "test_file_id" if transfer_method == FileTransferMethod.LOCAL_FILE else None
mock_file.remote_url = "https://example.com/file.txt" if transfer_method == FileTransferMethod.REMOTE_URL else None
mock_file.extension = extension
mock_array_file_segment = Mock(spec=ArrayFileSegment)
mock_array_file_segment.value = [mock_file]
mock_graph_runtime_state.variable_pool.get.return_value = mock_array_file_segment
mock_download = Mock(return_value=file_content)
mock_ssrf_proxy_get = Mock()
mock_ssrf_proxy_get.return_value.content = file_content
mock_ssrf_proxy_get.return_value.raise_for_status = Mock()
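    # Local files are fetched via file_manager.download, remote URLs via the
    # SSRF-guarded HTTP client, so both entry points are patched.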
monkeypatch.setattr("core.file.file_manager.download", mock_download)
monkeypatch.setattr("core.helper.ssrf_proxy.get", mock_ssrf_proxy_get)
if mime_type == "application/pdf":
mock_pdf_extract = Mock(return_value=expected_text[0])
monkeypatch.setattr("core.workflow.nodes.document_extractor.node._extract_text_from_pdf", mock_pdf_extract)
elif mime_type.startswith("application/vnd.openxmlformats"):
mock_docx_extract = Mock(return_value=expected_text[0])
monkeypatch.setattr("core.workflow.nodes.document_extractor.node._extract_text_from_docx", mock_docx_extract)
result = document_extractor_node._run()
assert isinstance(result, NodeRunResult)
assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED, result.error
assert result.outputs is not None
assert result.outputs["text"] == expected_text
if transfer_method == FileTransferMethod.REMOTE_URL:
mock_ssrf_proxy_get.assert_called_once_with("https://example.com/file.txt")
elif transfer_method == FileTransferMethod.LOCAL_FILE:
mock_download.assert_called_once_with(mock_file)
def test_extract_text_from_plain_text():
text = _extract_text_from_plain_text(b"Hello, world!")
assert text == "Hello, world!"
def test_extract_text_from_plain_text_non_utf8():
import tempfile
non_utf8_content = b"Hello, world\xa9." # \xA9 represents © in Latin-1
with tempfile.NamedTemporaryFile(delete=True) as temp_file:
temp_file.write(non_utf8_content)
temp_file.seek(0)
text = _extract_text_from_plain_text(temp_file.read())
assert text == "Hello, world."
@patch("pypdfium2.PdfDocument")
def test_extract_text_from_pdf(mock_pdf_document):
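    # pypdfium2 is fully mocked: the document iterates pages, and each page's
    # text page returns the extracted string via get_text_range().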
mock_page = Mock()
mock_text_page = Mock()
mock_text_page.get_text_range.return_value = "PDF content"
mock_page.get_textpage.return_value = mock_text_page
mock_pdf_document.return_value = [mock_page]
text = _extract_text_from_pdf(b"%PDF-1.5\n%Test PDF content")
assert text == "PDF content"
@patch("docx.Document")
def test_extract_text_from_docx(mock_document):
mock_paragraph1 = Mock()
mock_paragraph1.text = "Paragraph 1"
mock_paragraph2 = Mock()
mock_paragraph2.text = "Paragraph 2"
mock_document.return_value.paragraphs = [mock_paragraph1, mock_paragraph2]
text = _extract_text_from_docx(b"PK\x03\x04")
assert text == "Paragraph 1\nParagraph 2"
def test_node_type(document_extractor_node):
assert document_extractor_node._node_type == NodeType.DOCUMENT_EXTRACTOR

View File

@@ -0,0 +1,260 @@
import time
import uuid
from unittest.mock import MagicMock, Mock
from core.app.entities.app_invoke_entities import InvokeFrom
from core.file import File, FileTransferMethod, FileType
from core.variables import ArrayFileSegment
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.enums import SystemVariableKey
from core.workflow.graph_engine.entities.graph import Graph
from core.workflow.graph_engine.entities.graph_init_params import GraphInitParams
from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState
from core.workflow.nodes.if_else.entities import IfElseNodeData
from core.workflow.nodes.if_else.if_else_node import IfElseNode
from core.workflow.utils.condition.entities import Condition, SubCondition, SubVariableCondition
from extensions.ext_database import db
from models.enums import UserFrom
from models.workflow import WorkflowNodeExecutionStatus, WorkflowType
def test_execute_if_else_result_true():
graph_config = {"edges": [], "nodes": [{"data": {"type": "start"}, "id": "start"}]}
graph = Graph.init(graph_config=graph_config)
init_params = GraphInitParams(
tenant_id="1",
app_id="1",
workflow_type=WorkflowType.WORKFLOW,
workflow_id="1",
graph_config=graph_config,
user_id="1",
user_from=UserFrom.ACCOUNT,
invoke_from=InvokeFrom.DEBUGGER,
call_depth=0,
)
# construct variable pool
pool = VariablePool(
system_variables={SystemVariableKey.FILES: [], SystemVariableKey.USER_ID: "aaa"}, user_inputs={}
)
pool.add(["start", "array_contains"], ["ab", "def"])
pool.add(["start", "array_not_contains"], ["ac", "def"])
pool.add(["start", "contains"], "cabcde")
pool.add(["start", "not_contains"], "zacde")
pool.add(["start", "start_with"], "abc")
pool.add(["start", "end_with"], "zzab")
pool.add(["start", "is"], "ab")
pool.add(["start", "is_not"], "aab")
pool.add(["start", "empty"], "")
pool.add(["start", "not_empty"], "aaa")
pool.add(["start", "equals"], 22)
pool.add(["start", "not_equals"], 23)
pool.add(["start", "greater_than"], 23)
pool.add(["start", "less_than"], 21)
pool.add(["start", "greater_than_or_equal"], 22)
pool.add(["start", "less_than_or_equal"], 21)
pool.add(["start", "null"], None)
pool.add(["start", "not_null"], "1212")
node = IfElseNode(
id=str(uuid.uuid4()),
graph_init_params=init_params,
graph=graph,
graph_runtime_state=GraphRuntimeState(variable_pool=pool, start_at=time.perf_counter()),
config={
"id": "if-else",
"data": {
"title": "123",
"type": "if-else",
"logical_operator": "and",
"conditions": [
{
"comparison_operator": "contains",
"variable_selector": ["start", "array_contains"],
"value": "ab",
},
{
"comparison_operator": "not contains",
"variable_selector": ["start", "array_not_contains"],
"value": "ab",
},
{"comparison_operator": "contains", "variable_selector": ["start", "contains"], "value": "ab"},
{
"comparison_operator": "not contains",
"variable_selector": ["start", "not_contains"],
"value": "ab",
},
{"comparison_operator": "start with", "variable_selector": ["start", "start_with"], "value": "ab"},
{"comparison_operator": "end with", "variable_selector": ["start", "end_with"], "value": "ab"},
{"comparison_operator": "is", "variable_selector": ["start", "is"], "value": "ab"},
{"comparison_operator": "is not", "variable_selector": ["start", "is_not"], "value": "ab"},
{"comparison_operator": "empty", "variable_selector": ["start", "empty"], "value": "ab"},
{"comparison_operator": "not empty", "variable_selector": ["start", "not_empty"], "value": "ab"},
{"comparison_operator": "=", "variable_selector": ["start", "equals"], "value": "22"},
{"comparison_operator": "", "variable_selector": ["start", "not_equals"], "value": "22"},
{"comparison_operator": ">", "variable_selector": ["start", "greater_than"], "value": "22"},
{"comparison_operator": "<", "variable_selector": ["start", "less_than"], "value": "22"},
{
"comparison_operator": "",
"variable_selector": ["start", "greater_than_or_equal"],
"value": "22",
},
{"comparison_operator": "", "variable_selector": ["start", "less_than_or_equal"], "value": "22"},
{"comparison_operator": "null", "variable_selector": ["start", "null"]},
{"comparison_operator": "not null", "variable_selector": ["start", "not_null"]},
],
},
},
)
# Mock db.session.close()
db.session.close = MagicMock()
# execute node
result = node._run()
assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED
assert result.outputs is not None
assert result.outputs["result"] is True
def test_execute_if_else_result_false():
graph_config = {
"edges": [
{
"id": "start-source-llm-target",
"source": "start",
"target": "llm",
},
],
"nodes": [
{"data": {"type": "start"}, "id": "start"},
{
"data": {
"type": "llm",
},
"id": "llm",
},
],
}
graph = Graph.init(graph_config=graph_config)
init_params = GraphInitParams(
tenant_id="1",
app_id="1",
workflow_type=WorkflowType.WORKFLOW,
workflow_id="1",
graph_config=graph_config,
user_id="1",
user_from=UserFrom.ACCOUNT,
invoke_from=InvokeFrom.DEBUGGER,
call_depth=0,
)
# construct variable pool
pool = VariablePool(
system_variables={SystemVariableKey.FILES: [], SystemVariableKey.USER_ID: "aaa"},
user_inputs={},
environment_variables=[],
)
pool.add(["start", "array_contains"], ["1ab", "def"])
pool.add(["start", "array_not_contains"], ["ab", "def"])
node = IfElseNode(
id=str(uuid.uuid4()),
graph_init_params=init_params,
graph=graph,
graph_runtime_state=GraphRuntimeState(variable_pool=pool, start_at=time.perf_counter()),
config={
"id": "if-else",
"data": {
"title": "123",
"type": "if-else",
"logical_operator": "or",
"conditions": [
{
"comparison_operator": "contains",
"variable_selector": ["start", "array_contains"],
"value": "ab",
},
{
"comparison_operator": "not contains",
"variable_selector": ["start", "array_not_contains"],
"value": "ab",
},
],
},
},
)
# Mock db.session.close()
db.session.close = MagicMock()
# execute node
result = node._run()
assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED
assert result.outputs is not None
assert result.outputs["result"] is False
def test_array_file_contains_file_name():
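    # A "contains" condition on an ArrayFileSegment is evaluated through
    # sub-conditions applied per file attribute; here it matches any file whose
    # name contains "ab".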
node_data = IfElseNodeData(
title="123",
logical_operator="and",
cases=[
IfElseNodeData.Case(
case_id="true",
logical_operator="and",
conditions=[
Condition(
comparison_operator="contains",
variable_selector=["start", "array_contains"],
sub_variable_condition=SubVariableCondition(
logical_operator="and",
conditions=[
SubCondition(
key="name",
comparison_operator="contains",
value="ab",
)
],
),
)
],
)
],
)
node = IfElseNode(
id=str(uuid.uuid4()),
graph_init_params=Mock(),
graph=Mock(),
graph_runtime_state=Mock(),
config={
"id": "if-else",
"data": node_data.model_dump(),
},
)
node.graph_runtime_state.variable_pool.get.return_value = ArrayFileSegment(
value=[
File(
tenant_id="1",
type=FileType.IMAGE,
transfer_method=FileTransferMethod.LOCAL_FILE,
related_id="1",
filename="ab",
storage_key="",
),
],
)
result = node._run()
assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED
assert result.outputs is not None
assert result.outputs["result"] is True

View File

@@ -0,0 +1,168 @@
from unittest.mock import MagicMock
import pytest
from core.file import File, FileTransferMethod, FileType
from core.variables import ArrayFileSegment
from core.workflow.nodes.list_operator.entities import (
ExtractConfig,
FilterBy,
FilterCondition,
Limit,
ListOperatorNodeData,
OrderBy,
)
from core.workflow.nodes.list_operator.exc import InvalidKeyError
from core.workflow.nodes.list_operator.node import ListOperatorNode, _get_file_extract_string_func
from models.workflow import WorkflowNodeExecutionStatus
@pytest.fixture
def list_operator_node():
config = {
"variable": ["test_variable"],
"filter_by": FilterBy(
enabled=True,
conditions=[
FilterCondition(key="type", comparison_operator="in", value=[FileType.IMAGE, FileType.DOCUMENT])
],
),
"order_by": OrderBy(enabled=False, value="asc"),
"limit": Limit(enabled=False, size=0),
"extract_by": ExtractConfig(enabled=False, serial="1"),
"title": "Test Title",
}
node_data = ListOperatorNodeData(**config)
node = ListOperatorNode(
id="test_node_id",
config={
"id": "test_node_id",
"data": node_data.model_dump(),
},
graph_init_params=MagicMock(),
graph=MagicMock(),
graph_runtime_state=MagicMock(),
)
node.graph_runtime_state = MagicMock()
node.graph_runtime_state.variable_pool = MagicMock()
return node
def test_filter_files_by_type(list_operator_node):
# Setup test data
files = [
File(
filename="image1.jpg",
type=FileType.IMAGE,
tenant_id="tenant1",
transfer_method=FileTransferMethod.LOCAL_FILE,
related_id="related1",
storage_key="",
),
File(
filename="document1.pdf",
type=FileType.DOCUMENT,
tenant_id="tenant1",
transfer_method=FileTransferMethod.LOCAL_FILE,
related_id="related2",
storage_key="",
),
File(
filename="image2.png",
type=FileType.IMAGE,
tenant_id="tenant1",
transfer_method=FileTransferMethod.LOCAL_FILE,
related_id="related3",
storage_key="",
),
File(
filename="audio1.mp3",
type=FileType.AUDIO,
tenant_id="tenant1",
transfer_method=FileTransferMethod.LOCAL_FILE,
related_id="related4",
storage_key="",
),
]
variable = ArrayFileSegment(value=files)
list_operator_node.graph_runtime_state.variable_pool.get.return_value = variable
# Run the node
result = list_operator_node._run()
# Verify the result
expected_files = [
{
"filename": "image1.jpg",
"type": FileType.IMAGE,
"tenant_id": "tenant1",
"transfer_method": FileTransferMethod.LOCAL_FILE,
"related_id": "related1",
},
{
"filename": "document1.pdf",
"type": FileType.DOCUMENT,
"tenant_id": "tenant1",
"transfer_method": FileTransferMethod.LOCAL_FILE,
"related_id": "related2",
},
{
"filename": "image2.png",
"type": FileType.IMAGE,
"tenant_id": "tenant1",
"transfer_method": FileTransferMethod.LOCAL_FILE,
"related_id": "related3",
},
]
assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED
for expected_file, result_file in zip(expected_files, result.outputs["result"]):
assert expected_file["filename"] == result_file.filename
assert expected_file["type"] == result_file.type
assert expected_file["tenant_id"] == result_file.tenant_id
assert expected_file["transfer_method"] == result_file.transfer_method
assert expected_file["related_id"] == result_file.related_id
def test_get_file_extract_string_func():
# Create a File object
file = File(
tenant_id="test_tenant",
type=FileType.DOCUMENT,
transfer_method=FileTransferMethod.LOCAL_FILE,
filename="test_file.txt",
extension=".txt",
mime_type="text/plain",
remote_url="https://example.com/test_file.txt",
related_id="test_related_id",
storage_key="",
)
# Test each case
assert _get_file_extract_string_func(key="name")(file) == "test_file.txt"
assert _get_file_extract_string_func(key="type")(file) == "document"
assert _get_file_extract_string_func(key="extension")(file) == ".txt"
assert _get_file_extract_string_func(key="mime_type")(file) == "text/plain"
assert _get_file_extract_string_func(key="transfer_method")(file) == "local_file"
assert _get_file_extract_string_func(key="url")(file) == "https://example.com/test_file.txt"
# Test with empty values
empty_file = File(
tenant_id="test_tenant",
type=FileType.DOCUMENT,
transfer_method=FileTransferMethod.LOCAL_FILE,
filename=None,
extension=None,
mime_type=None,
remote_url=None,
related_id="test_related_id",
storage_key="",
)
assert _get_file_extract_string_func(key="name")(empty_file) == ""
assert _get_file_extract_string_func(key="extension")(empty_file) == ""
assert _get_file_extract_string_func(key="mime_type")(empty_file) == ""
assert _get_file_extract_string_func(key="url")(empty_file) == ""
# Test invalid key
with pytest.raises(InvalidKeyError):
_get_file_extract_string_func(key="invalid_key")

View File

@@ -0,0 +1,67 @@
from core.model_runtime.entities import ImagePromptMessageContent
from core.workflow.nodes.question_classifier import QuestionClassifierNodeData
def test_init_question_classifier_node_data():
data = {
"title": "test classifier node",
"query_variable_selector": ["id", "name"],
"model": {"provider": "openai", "name": "gpt-3.5-turbo", "mode": "completion", "completion_params": {}},
"classes": [{"id": "1", "name": "class 1"}],
"instruction": "This is a test instruction",
"memory": {
"role_prefix": {"user": "Human:", "assistant": "AI:"},
"window": {"enabled": True, "size": 5},
"query_prompt_template": "Previous conversation:\n{history}\n\nHuman: {query}\nAI:",
},
"vision": {"enabled": True, "configs": {"variable_selector": ["image"], "detail": "low"}},
}
node_data = QuestionClassifierNodeData(**data)
assert node_data.query_variable_selector == ["id", "name"]
assert node_data.model.provider == "openai"
assert node_data.classes[0].id == "1"
assert node_data.instruction == "This is a test instruction"
assert node_data.memory is not None
assert node_data.memory.role_prefix is not None
assert node_data.memory.role_prefix.user == "Human:"
assert node_data.memory.role_prefix.assistant == "AI:"
    assert node_data.memory.window.enabled is True
    assert node_data.memory.window.size == 5
    assert node_data.memory.query_prompt_template == "Previous conversation:\n{history}\n\nHuman: {query}\nAI:"
    assert node_data.vision.enabled is True
assert node_data.vision.configs.variable_selector == ["image"]
assert node_data.vision.configs.detail == ImagePromptMessageContent.DETAIL.LOW
def test_init_question_classifier_node_data_without_vision_config():
data = {
"title": "test classifier node",
"query_variable_selector": ["id", "name"],
"model": {"provider": "openai", "name": "gpt-3.5-turbo", "mode": "completion", "completion_params": {}},
"classes": [{"id": "1", "name": "class 1"}],
"instruction": "This is a test instruction",
"memory": {
"role_prefix": {"user": "Human:", "assistant": "AI:"},
"window": {"enabled": True, "size": 5},
"query_prompt_template": "Previous conversation:\n{history}\n\nHuman: {query}\nAI:",
},
}
node_data = QuestionClassifierNodeData(**data)
assert node_data.query_variable_selector == ["id", "name"]
assert node_data.model.provider == "openai"
assert node_data.classes[0].id == "1"
assert node_data.instruction == "This is a test instruction"
assert node_data.memory is not None
assert node_data.memory.role_prefix is not None
assert node_data.memory.role_prefix.user == "Human:"
assert node_data.memory.role_prefix.assistant == "AI:"
    assert node_data.memory.window.enabled is True
    assert node_data.memory.window.size == 5
    assert node_data.memory.query_prompt_template == "Previous conversation:\n{history}\n\nHuman: {query}\nAI:"
    assert node_data.vision.enabled is False
assert node_data.vision.configs.variable_selector == ["sys", "files"]
assert node_data.vision.configs.detail == ImagePromptMessageContent.DETAIL.HIGH

View File

@@ -0,0 +1,72 @@
from core.workflow.graph_engine.entities.event import (
GraphRunFailedEvent,
GraphRunPartialSucceededEvent,
NodeRunRetryEvent,
)
from tests.unit_tests.core.workflow.nodes.test_continue_on_error import ContinueOnErrorTestHelper
DEFAULT_VALUE_EDGE = [
{
"id": "start-source-node-target",
"source": "start",
"target": "node",
"sourceHandle": "source",
},
{
"id": "node-source-answer-target",
"source": "node",
"target": "answer",
"sourceHandle": "source",
},
]
def test_retry_default_value_partial_success():
"""retry default value node with partial success status"""
graph_config = {
"edges": DEFAULT_VALUE_EDGE,
"nodes": [
{"data": {"title": "start", "type": "start", "variables": []}, "id": "start"},
{"data": {"title": "answer", "type": "answer", "answer": "{{#node.result#}}"}, "id": "answer"},
ContinueOnErrorTestHelper.get_http_node(
"default-value",
[{"key": "result", "type": "string", "value": "http node got error response"}],
retry_config={"retry_config": {"max_retries": 2, "retry_interval": 1000, "retry_enabled": True}},
),
],
}
graph_engine = ContinueOnErrorTestHelper.create_test_graph_engine(graph_config)
events = list(graph_engine.run())
assert sum(1 for e in events if isinstance(e, NodeRunRetryEvent)) == 2
assert events[-1].outputs == {"answer": "http node got error response"}
assert any(isinstance(e, GraphRunPartialSucceededEvent) for e in events)
assert len(events) == 11
def test_retry_failed():
"""retry failed with success status"""
error_code = """
def main() -> dict:
return {
"result": 1 / 0,
}
"""
graph_config = {
"edges": DEFAULT_VALUE_EDGE,
"nodes": [
{"data": {"title": "start", "type": "start", "variables": []}, "id": "start"},
{"data": {"title": "answer", "type": "answer", "answer": "{{#node.result#}}"}, "id": "answer"},
ContinueOnErrorTestHelper.get_http_node(
None,
None,
retry_config={"retry_config": {"max_retries": 2, "retry_interval": 1000, "retry_enabled": True}},
),
],
}
graph_engine = ContinueOnErrorTestHelper.create_test_graph_engine(graph_config)
events = list(graph_engine.run())
assert sum(1 for e in events if isinstance(e, NodeRunRetryEvent)) == 2
assert any(isinstance(e, GraphRunFailedEvent) for e in events)
assert len(events) == 8
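

# A rough sketch of the retry contract these tests pin down, assuming the
# engine re-runs a failing node up to max_retries times (one retry event per
# re-run, with retry_interval as the pause between attempts) before giving
# up. Hypothetical helper, not the graph engine's code.
def run_with_retries(fn, max_retries: int):
    for attempt in range(max_retries + 1):
        try:
            return fn()
        except Exception:
            if attempt == max_retries:
                raise


# With max_retries=2 the callable runs 3 times: 1 initial attempt + 2 retries,
# matching the two NodeRunRetryEvent occurrences asserted above.
calls = []


def _always_fails():
    calls.append(1)
    raise RuntimeError("boom")


try:
    run_with_retries(_always_fails, max_retries=2)
except RuntimeError:
    pass
assert len(calls) == 3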

View File

@@ -0,0 +1,257 @@
import time
import uuid
from unittest import mock
from uuid import uuid4
from core.app.entities.app_invoke_entities import InvokeFrom
from core.variables import ArrayStringVariable, StringVariable
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.enums import SystemVariableKey
from core.workflow.graph_engine.entities.graph import Graph
from core.workflow.graph_engine.entities.graph_init_params import GraphInitParams
from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState
from core.workflow.nodes.variable_assigner.v1 import VariableAssignerNode
from core.workflow.nodes.variable_assigner.v1.node_data import WriteMode
from models.enums import UserFrom
from models.workflow import WorkflowType
DEFAULT_NODE_ID = "node_id"
def test_overwrite_string_variable():
graph_config = {
"edges": [
{
"id": "start-source-assigner-target",
"source": "start",
"target": "assigner",
},
],
"nodes": [
{"data": {"type": "start"}, "id": "start"},
{
"data": {
"type": "assigner",
},
"id": "assigner",
},
],
}
graph = Graph.init(graph_config=graph_config)
init_params = GraphInitParams(
tenant_id="1",
app_id="1",
workflow_type=WorkflowType.WORKFLOW,
workflow_id="1",
graph_config=graph_config,
user_id="1",
user_from=UserFrom.ACCOUNT,
invoke_from=InvokeFrom.DEBUGGER,
call_depth=0,
)
conversation_variable = StringVariable(
id=str(uuid4()),
name="test_conversation_variable",
value="the first value",
)
input_variable = StringVariable(
id=str(uuid4()),
name="test_string_variable",
value="the second value",
)
# construct variable pool
variable_pool = VariablePool(
system_variables={SystemVariableKey.CONVERSATION_ID: "conversation_id"},
user_inputs={},
environment_variables=[],
conversation_variables=[conversation_variable],
)
variable_pool.add(
[DEFAULT_NODE_ID, input_variable.name],
input_variable,
)
node = VariableAssignerNode(
id=str(uuid.uuid4()),
graph_init_params=init_params,
graph=graph,
graph_runtime_state=GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()),
config={
"id": "node_id",
"data": {
"title": "test",
"assigned_variable_selector": ["conversation", conversation_variable.name],
"write_mode": WriteMode.OVER_WRITE.value,
"input_variable_selector": [DEFAULT_NODE_ID, input_variable.name],
},
},
)
with mock.patch("core.workflow.nodes.variable_assigner.common.helpers.update_conversation_variable") as mock_run:
list(node.run())
mock_run.assert_called_once()
got = variable_pool.get(["conversation", conversation_variable.name])
assert got is not None
assert got.value == "the second value"
assert got.to_object() == "the second value"
def test_append_variable_to_array():
graph_config = {
"edges": [
{
"id": "start-source-assigner-target",
"source": "start",
"target": "assigner",
},
],
"nodes": [
{"data": {"type": "start"}, "id": "start"},
{
"data": {
"type": "assigner",
},
"id": "assigner",
},
],
}
graph = Graph.init(graph_config=graph_config)
init_params = GraphInitParams(
tenant_id="1",
app_id="1",
workflow_type=WorkflowType.WORKFLOW,
workflow_id="1",
graph_config=graph_config,
user_id="1",
user_from=UserFrom.ACCOUNT,
invoke_from=InvokeFrom.DEBUGGER,
call_depth=0,
)
conversation_variable = ArrayStringVariable(
id=str(uuid4()),
name="test_conversation_variable",
value=["the first value"],
)
input_variable = StringVariable(
id=str(uuid4()),
name="test_string_variable",
value="the second value",
)
variable_pool = VariablePool(
system_variables={SystemVariableKey.CONVERSATION_ID: "conversation_id"},
user_inputs={},
environment_variables=[],
conversation_variables=[conversation_variable],
)
variable_pool.add(
[DEFAULT_NODE_ID, input_variable.name],
input_variable,
)
node = VariableAssignerNode(
id=str(uuid.uuid4()),
graph_init_params=init_params,
graph=graph,
graph_runtime_state=GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()),
config={
"id": "node_id",
"data": {
"title": "test",
"assigned_variable_selector": ["conversation", conversation_variable.name],
"write_mode": WriteMode.APPEND.value,
"input_variable_selector": [DEFAULT_NODE_ID, input_variable.name],
},
},
)
with mock.patch("core.workflow.nodes.variable_assigner.common.helpers.update_conversation_variable") as mock_run:
list(node.run())
mock_run.assert_called_once()
got = variable_pool.get(["conversation", conversation_variable.name])
assert got is not None
assert got.to_object() == ["the first value", "the second value"]
def test_clear_array():
graph_config = {
"edges": [
{
"id": "start-source-assigner-target",
"source": "start",
"target": "assigner",
},
],
"nodes": [
{"data": {"type": "start"}, "id": "start"},
{
"data": {
"type": "assigner",
},
"id": "assigner",
},
],
}
graph = Graph.init(graph_config=graph_config)
init_params = GraphInitParams(
tenant_id="1",
app_id="1",
workflow_type=WorkflowType.WORKFLOW,
workflow_id="1",
graph_config=graph_config,
user_id="1",
user_from=UserFrom.ACCOUNT,
invoke_from=InvokeFrom.DEBUGGER,
call_depth=0,
)
conversation_variable = ArrayStringVariable(
id=str(uuid4()),
name="test_conversation_variable",
value=["the first value"],
)
variable_pool = VariablePool(
system_variables={SystemVariableKey.CONVERSATION_ID: "conversation_id"},
user_inputs={},
environment_variables=[],
conversation_variables=[conversation_variable],
)
node = VariableAssignerNode(
id=str(uuid.uuid4()),
graph_init_params=init_params,
graph=graph,
graph_runtime_state=GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()),
config={
"id": "node_id",
"data": {
"title": "test",
"assigned_variable_selector": ["conversation", conversation_variable.name],
"write_mode": WriteMode.CLEAR.value,
"input_variable_selector": [],
},
},
)
with mock.patch("core.workflow.nodes.variable_assigner.common.helpers.update_conversation_variable") as mock_run:
list(node.run())
mock_run.assert_called_once()
got = variable_pool.get(["conversation", conversation_variable.name])
assert got is not None
assert got.to_object() == []
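

# A condensed sketch of the three write modes exercised above. It mirrors
# the observable variable-pool behavior only; the real node also persists
# the change through update_conversation_variable (mocked in these tests).
def apply_write_mode(current, incoming, write_mode: WriteMode):
    if write_mode == WriteMode.OVER_WRITE:
        return incoming
    if write_mode == WriteMode.APPEND:
        return [*current, incoming]
    if write_mode == WriteMode.CLEAR:
        return []
    raise ValueError(f"unsupported write mode: {write_mode}")


assert apply_write_mode(["the first value"], "the second value", WriteMode.APPEND) == [
    "the first value",
    "the second value",
]
assert apply_write_mode(["the first value"], None, WriteMode.CLEAR) == []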

View File

@@ -0,0 +1,22 @@
from core.variables import SegmentType
from core.workflow.nodes.variable_assigner.v2.enums import Operation
from core.workflow.nodes.variable_assigner.v2.helpers import is_input_value_valid
def test_is_input_value_valid_overwrite_array_string():
# Valid cases
assert is_input_value_valid(
variable_type=SegmentType.ARRAY_STRING, operation=Operation.OVER_WRITE, value=["hello", "world"]
)
assert is_input_value_valid(variable_type=SegmentType.ARRAY_STRING, operation=Operation.OVER_WRITE, value=[])
# Invalid cases
assert not is_input_value_valid(
variable_type=SegmentType.ARRAY_STRING, operation=Operation.OVER_WRITE, value="not an array"
)
assert not is_input_value_valid(
variable_type=SegmentType.ARRAY_STRING, operation=Operation.OVER_WRITE, value=[1, 2, 3]
)
assert not is_input_value_valid(
variable_type=SegmentType.ARRAY_STRING, operation=Operation.OVER_WRITE, value=["valid", 123, "invalid"]
)
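

# The cases above pin down the overwrite rule for ARRAY_STRING: the value
# must be a list whose elements are all strings. A plausible equivalent
# predicate (an illustration, not the helper's actual implementation):
def looks_like_array_string(value) -> bool:
    return isinstance(value, list) and all(isinstance(item, str) for item in value)


assert looks_like_array_string(["hello", "world"])
assert looks_like_array_string([])
assert not looks_like_array_string("not an array")
assert not looks_like_array_string([1, 2, 3])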

View File

@@ -0,0 +1,46 @@
import pytest
from core.file import File, FileTransferMethod, FileType
from core.variables import FileSegment, StringSegment
from core.workflow.entities.variable_pool import VariablePool
@pytest.fixture
def pool():
return VariablePool(system_variables={}, user_inputs={})
@pytest.fixture
def file():
return File(
tenant_id="test_tenant_id",
type=FileType.DOCUMENT,
transfer_method=FileTransferMethod.LOCAL_FILE,
related_id="test_related_id",
remote_url="test_url",
filename="test_file.txt",
storage_key="",
)
def test_get_file_attribute(pool, file):
# Add a FileSegment to the pool
pool.add(("node_1", "file_var"), FileSegment(value=file))
# Test getting the 'name' attribute of the file
result = pool.get(("node_1", "file_var", "name"))
assert result is not None
assert result.value == file.filename
# Test getting a non-existent attribute
result = pool.get(("node_1", "file_var", "non_existent_attr"))
assert result is None
def test_use_long_selector(pool):
pool.add(("node_1", "part_1", "part_2"), StringSegment(value="test_value"))
result = pool.get(("node_1", "part_1", "part_2"))
assert result is not None
assert result.value == "test_value"

View File

@@ -0,0 +1,28 @@
from core.variables import SecretVariable
from core.workflow.entities.variable_entities import VariableSelector
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.enums import SystemVariableKey
from core.workflow.utils import variable_template_parser
def test_extract_selectors_from_template():
variable_pool = VariablePool(
system_variables={
SystemVariableKey("user_id"): "fake-user-id",
},
user_inputs={},
environment_variables=[
SecretVariable(name="secret_key", value="fake-secret-key"),
],
conversation_variables=[],
)
variable_pool.add(("node_id", "custom_query"), "fake-user-query")
template = (
"Hello, {{#sys.user_id#}}! Your query is {{#node_id.custom_query#}}. And your key is {{#env.secret_key#}}."
)
selectors = variable_template_parser.extract_selectors_from_template(template)
assert selectors == [
VariableSelector(variable="#sys.user_id#", value_selector=["sys", "user_id"]),
VariableSelector(variable="#node_id.custom_query#", value_selector=["node_id", "custom_query"]),
VariableSelector(variable="#env.secret_key#", value_selector=["env", "secret_key"]),
]
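

# A compact sketch of what extract_selectors_from_template does under the
# hood: templates reference variables as {{#part.part#}}, so one regex pass
# recovers each marker and splits it into its selector. The real parser's
# pattern is likely stricter about allowed characters; this is illustrative.
import re

_VARIABLE_PATTERN = re.compile(r"\{\{(#[a-zA-Z0-9_.]+#)\}\}")


def extract(template: str) -> list[tuple[str, list[str]]]:
    # Returns (marker, selector) pairs, e.g. ("#sys.user_id#", ["sys", "user_id"]).
    return [(marker, marker.strip("#").split(".")) for marker in _VARIABLE_PATTERN.findall(template)]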

View File

@@ -0,0 +1,21 @@
import pytest
from libs.helper import email
def test_email_with_valid_email():
assert email("test@example.com") == "test@example.com"
assert email("TEST12345@example.com") == "TEST12345@example.com"
assert email("test+test@example.com") == "test+test@example.com"
assert email("!#$%&'*+-/=?^_{|}~`@example.com") == "!#$%&'*+-/=?^_{|}~`@example.com"
def test_email_with_invalid_email():
with pytest.raises(ValueError, match="invalid_email is not a valid email."):
email("invalid_email")
with pytest.raises(ValueError, match="@example.com is not a valid email."):
email("@example.com")
with pytest.raises(ValueError, match="()@example.com is not a valid email."):
email("()@example.com")

View File

@@ -0,0 +1,58 @@
import pandas as pd
def test_pandas_csv(tmp_path, monkeypatch):
monkeypatch.chdir(tmp_path)
data = {"col1": [1, 2.2, -3.3, 4.0, 5], "col2": ["A", "B", "C", "D", "E"]}
df1 = pd.DataFrame(data)
# write to csv file
csv_file_path = tmp_path.joinpath("example.csv")
df1.to_csv(csv_file_path, index=False)
# read from csv file
df2 = pd.read_csv(csv_file_path, on_bad_lines="skip")
assert df2[df2.columns[0]].to_list() == data["col1"]
assert df2[df2.columns[1]].to_list() == data["col2"]
def test_pandas_xlsx(tmp_path, monkeypatch):
monkeypatch.chdir(tmp_path)
data = {"col1": [1, 2.2, -3.3, 4.0, 5], "col2": ["A", "B", "C", "D", "E"]}
df1 = pd.DataFrame(data)
# write to xlsx file
xlsx_file_path = tmp_path.joinpath("example.xlsx")
df1.to_excel(xlsx_file_path, index=False)
# read from xlsx file
df2 = pd.read_excel(xlsx_file_path)
assert df2[df2.columns[0]].to_list() == data["col1"]
assert df2[df2.columns[1]].to_list() == data["col2"]
def test_pandas_xlsx_with_sheets(tmp_path, monkeypatch):
monkeypatch.chdir(tmp_path)
data1 = {"col1": [1, 2, 3, 4, 5], "col2": ["A", "B", "C", "D", "E"]}
df1 = pd.DataFrame(data1)
data2 = {"col1": [6, 7, 8, 9, 10], "col2": ["F", "G", "H", "I", "J"]}
df2 = pd.DataFrame(data2)
# write to xlsx file with sheets
xlsx_file_path = tmp_path.joinpath("example_with_sheets.xlsx")
sheet1 = "Sheet1"
sheet2 = "Sheet2"
with pd.ExcelWriter(xlsx_file_path) as excel_writer:
df1.to_excel(excel_writer, sheet_name=sheet1, index=False)
df2.to_excel(excel_writer, sheet_name=sheet2, index=False)
# read from xlsx file with sheets
with pd.ExcelFile(xlsx_file_path) as excel_file:
df1 = pd.read_excel(excel_file, sheet_name=sheet1)
assert df1[df1.columns[0]].to_list() == data1["col1"]
assert df1[df1.columns[1]].to_list() == data1["col2"]
df2 = pd.read_excel(excel_file, sheet_name=sheet2)
assert df2[df2.columns[0]].to_list() == data2["col1"]
assert df2[df2.columns[1]].to_list() == data2["col2"]

View File

@@ -0,0 +1,29 @@
import rsa as pyrsa
from Crypto.PublicKey import RSA
from libs import gmpy2_pkcs10aep_cipher
def test_gmpy2_pkcs10aep_cipher() -> None:
rsa_key_pair = pyrsa.newkeys(2048)
public_key = rsa_key_pair[0].save_pkcs1()
private_key = rsa_key_pair[1].save_pkcs1()
public_rsa_key = RSA.import_key(public_key)
public_cipher_rsa2 = gmpy2_pkcs10aep_cipher.new(public_rsa_key)
private_rsa_key = RSA.import_key(private_key)
private_cipher_rsa = gmpy2_pkcs10aep_cipher.new(private_rsa_key)
raw_text = "raw_text"
raw_text_bytes = raw_text.encode()
# RSA encryption by public key and decryption by private key
encrypted_by_pub_key = public_cipher_rsa2.encrypt(message=raw_text_bytes)
decrypted_by_pub_key = private_cipher_rsa.decrypt(encrypted_by_pub_key)
assert decrypted_by_pub_key == raw_text_bytes
# RSA encryption and decryption by private key
encrypted_by_private_key = private_cipher_rsa.encrypt(message=raw_text_bytes)
decrypted_by_private_key = private_cipher_rsa.decrypt(encrypted_by_private_key)
assert decrypted_by_private_key == raw_text_bytes

View File

@@ -0,0 +1,29 @@
import pytest
from yarl import URL
def test_yarl_urls():
expected_1 = "https://dify.ai/api"
assert str(URL("https://dify.ai") / "api") == expected_1
assert str(URL("https://dify.ai/") / "api") == expected_1
expected_2 = "http://dify.ai:12345/api"
assert str(URL("http://dify.ai:12345") / "api") == expected_2
assert str(URL("http://dify.ai:12345/") / "api") == expected_2
expected_3 = "https://dify.ai/api/v1"
assert str(URL("https://dify.ai") / "api" / "v1") == expected_3
assert str(URL("https://dify.ai") / "api/v1") == expected_3
assert str(URL("https://dify.ai/") / "api/v1") == expected_3
assert str(URL("https://dify.ai/api") / "v1") == expected_3
assert str(URL("https://dify.ai/api/") / "v1") == expected_3
expected_4 = "api"
assert str(URL("") / "api") == expected_4
expected_5 = "/api"
assert str(URL("/") / "api") == expected_5
with pytest.raises(ValueError) as e1:
str(URL("https://dify.ai") / "/api")
assert str(e1.value) == "Appending path '/api' starting from slash is forbidden"

View File

@@ -0,0 +1,14 @@
from models.account import TenantAccountRole
def test_account_is_privileged_role() -> None:
assert TenantAccountRole.ADMIN == "admin"
assert TenantAccountRole.OWNER == "owner"
assert TenantAccountRole.EDITOR == "editor"
assert TenantAccountRole.NORMAL == "normal"
assert TenantAccountRole.is_privileged_role(TenantAccountRole.ADMIN)
assert TenantAccountRole.is_privileged_role(TenantAccountRole.OWNER)
assert not TenantAccountRole.is_privileged_role(TenantAccountRole.NORMAL)
assert not TenantAccountRole.is_privileged_role(TenantAccountRole.EDITOR)
assert not TenantAccountRole.is_privileged_role("")

View File

@@ -0,0 +1,26 @@
from uuid import uuid4
from core.variables import SegmentType
from factories import variable_factory
from models import ConversationVariable
def test_from_variable_and_to_variable():
variable = variable_factory.build_conversation_variable_from_mapping(
{
"id": str(uuid4()),
"name": "name",
"value_type": SegmentType.OBJECT,
"value": {
"key": {
"key": "value",
}
},
}
)
conversation_variable = ConversationVariable.from_variable(
app_id="app_id", conversation_id="conversation_id", variable=variable
)
assert conversation_variable.to_variable() == variable

View File

@@ -0,0 +1,139 @@
from unittest import mock
from uuid import uuid4
import contexts
from constants import HIDDEN_VALUE
from core.variables import FloatVariable, IntegerVariable, SecretVariable, StringVariable
from models.workflow import Workflow
def test_environment_variables():
contexts.tenant_id.set("tenant_id")
# Create a Workflow instance
workflow = Workflow(
tenant_id="tenant_id",
app_id="app_id",
type="workflow",
version="draft",
graph="{}",
features="{}",
created_by="account_id",
environment_variables=[],
conversation_variables=[],
)
# Create some EnvironmentVariable instances
variable1 = StringVariable.model_validate(
{"name": "var1", "value": "value1", "id": str(uuid4()), "selector": ["env", "var1"]}
)
variable2 = IntegerVariable.model_validate(
{"name": "var2", "value": 123, "id": str(uuid4()), "selector": ["env", "var2"]}
)
variable3 = SecretVariable.model_validate(
{"name": "var3", "value": "secret", "id": str(uuid4()), "selector": ["env", "var3"]}
)
variable4 = FloatVariable.model_validate(
{"name": "var4", "value": 3.14, "id": str(uuid4()), "selector": ["env", "var4"]}
)
with (
mock.patch("core.helper.encrypter.encrypt_token", return_value="encrypted_token"),
mock.patch("core.helper.encrypter.decrypt_token", return_value="secret"),
):
# Set the environment_variables property of the Workflow instance
variables = [variable1, variable2, variable3, variable4]
workflow.environment_variables = variables
# Get the environment_variables property and assert its value
assert workflow.environment_variables == variables
def test_update_environment_variables():
contexts.tenant_id.set("tenant_id")
# Create a Workflow instance
workflow = Workflow(
tenant_id="tenant_id",
app_id="app_id",
type="workflow",
version="draft",
graph="{}",
features="{}",
created_by="account_id",
environment_variables=[],
conversation_variables=[],
)
# Create some EnvironmentVariable instances
variable1 = StringVariable.model_validate(
{"name": "var1", "value": "value1", "id": str(uuid4()), "selector": ["env", "var1"]}
)
variable2 = IntegerVariable.model_validate(
{"name": "var2", "value": 123, "id": str(uuid4()), "selector": ["env", "var2"]}
)
variable3 = SecretVariable.model_validate(
{"name": "var3", "value": "secret", "id": str(uuid4()), "selector": ["env", "var3"]}
)
variable4 = FloatVariable.model_validate(
{"name": "var4", "value": 3.14, "id": str(uuid4()), "selector": ["env", "var4"]}
)
with (
mock.patch("core.helper.encrypter.encrypt_token", return_value="encrypted_token"),
mock.patch("core.helper.encrypter.decrypt_token", return_value="secret"),
):
variables = [variable1, variable2, variable3, variable4]
# Set the environment_variables property of the Workflow instance
workflow.environment_variables = variables
assert workflow.environment_variables == [variable1, variable2, variable3, variable4]
# Update the name of variable3 and keep the value as it is
variables[2] = variable3.model_copy(
update={
"name": "new name",
"value": HIDDEN_VALUE,
}
)
workflow.environment_variables = variables
assert workflow.environment_variables[2].name == "new name"
assert workflow.environment_variables[2].value == variable3.value
def test_to_dict():
contexts.tenant_id.set("tenant_id")
# Create a Workflow instance
workflow = Workflow(
tenant_id="tenant_id",
app_id="app_id",
type="workflow",
version="draft",
graph="{}",
features="{}",
created_by="account_id",
environment_variables=[],
conversation_variables=[],
)
# Create some EnvironmentVariable instances
with (
mock.patch("core.helper.encrypter.encrypt_token", return_value="encrypted_token"),
mock.patch("core.helper.encrypter.decrypt_token", return_value="secret"),
):
# Set the environment_variables property of the Workflow instance
workflow.environment_variables = [
SecretVariable.model_validate({"name": "secret", "value": "secret", "id": str(uuid4())}),
StringVariable.model_validate({"name": "text", "value": "text", "id": str(uuid4())}),
]
workflow_dict = workflow.to_dict()
assert workflow_dict["environment_variables"][0]["value"] == ""
assert workflow_dict["environment_variables"][1]["value"] == "text"
workflow_dict = workflow.to_dict(include_secret=True)
assert workflow_dict["environment_variables"][0]["value"] == "secret"
assert workflow_dict["environment_variables"][1]["value"] == "text"

View File

@@ -0,0 +1,100 @@
import os
import posixpath
from unittest.mock import MagicMock
import pytest
from _pytest.monkeypatch import MonkeyPatch
from oss2 import Bucket # type: ignore
from oss2.models import GetObjectResult, PutObjectResult # type: ignore
from tests.unit_tests.oss.__mock.base import (
get_example_bucket,
get_example_data,
get_example_filename,
get_example_filepath,
get_example_folder,
)
class MockResponse:
def __init__(self, status, headers, request_id):
self.status = status
self.headers = headers
self.request_id = request_id
class MockAliyunOssClass:
def __init__(
self,
auth,
endpoint,
bucket_name,
is_cname=False,
session=None,
connect_timeout=None,
app_name="",
enable_crc=True,
proxies=None,
region=None,
cloudbox_id=None,
is_path_style=False,
is_verify_object_strict=True,
):
self.bucket_name = get_example_bucket()
self.key = posixpath.join(get_example_folder(), get_example_filename())
self.content = get_example_data()
self.filepath = get_example_filepath()
self.resp = MockResponse(
200,
{
"etag": "ee8de918d05640145b18f70f4c3aa602",
"x-oss-version-id": "CAEQNhiBgMDJgZCA0BYiIDc4MGZjZGI2OTBjOTRmNTE5NmU5NmFhZjhjYmY0****",
},
"request_id",
)
def put_object(self, key, data, headers=None, progress_callback=None):
assert key == self.key
assert data == self.content
return PutObjectResult(self.resp)
def get_object(self, key, byte_range=None, headers=None, progress_callback=None, process=None, params=None):
assert key == self.key
get_object_output = MagicMock(GetObjectResult)
get_object_output.read.return_value = self.content
return get_object_output
def get_object_to_file(
self, key, filename, byte_range=None, headers=None, progress_callback=None, process=None, params=None
):
assert key == self.key
assert filename == self.filepath
def object_exists(self, key, headers=None):
assert key == self.key
return True
def delete_object(self, key, params=None, headers=None):
assert key == self.key
self.resp.headers["x-oss-delete-marker"] = True
return self.resp
MOCK = os.getenv("MOCK_SWITCH", "false").lower() == "true"
@pytest.fixture
def setup_aliyun_oss_mock(monkeypatch: MonkeyPatch):
if MOCK:
monkeypatch.setattr(Bucket, "__init__", MockAliyunOssClass.__init__)
monkeypatch.setattr(Bucket, "put_object", MockAliyunOssClass.put_object)
monkeypatch.setattr(Bucket, "get_object", MockAliyunOssClass.get_object)
monkeypatch.setattr(Bucket, "get_object_to_file", MockAliyunOssClass.get_object_to_file)
monkeypatch.setattr(Bucket, "object_exists", MockAliyunOssClass.object_exists)
monkeypatch.setattr(Bucket, "delete_object", MockAliyunOssClass.delete_object)
yield
if MOCK:
monkeypatch.undo()

View File

@@ -0,0 +1,62 @@
from collections.abc import Generator
import pytest
from extensions.storage.base_storage import BaseStorage
def get_example_folder() -> str:
return "~/dify"
def get_example_bucket() -> str:
return "dify"
def get_opendal_bucket() -> str:
return "./dify"
def get_example_filename() -> str:
return "test.txt"
def get_example_data() -> bytes:
return b"test"
def get_example_filepath() -> str:
return "~/test"
class BaseStorageTest:
@pytest.fixture(autouse=True)
def setup_method(self, *args, **kwargs):
"""Should be implemented in child classes to setup specific storage."""
self.storage: BaseStorage
def test_save(self):
"""Test saving data."""
self.storage.save(get_example_filename(), get_example_data())
def test_load_once(self):
"""Test loading data once."""
assert self.storage.load_once(get_example_filename()) == get_example_data()
def test_load_stream(self):
"""Test loading data as a stream."""
generator = self.storage.load_stream(get_example_filename())
assert isinstance(generator, Generator)
assert next(generator) == get_example_data()
def test_download(self):
"""Test downloading data."""
self.storage.download(get_example_filename(), get_example_filepath())
def test_exists(self):
"""Test checking if a file exists."""
assert self.storage.exists(get_example_filename())
def test_delete(self):
"""Test deleting a file."""
self.storage.delete(get_example_filename())

View File

@@ -0,0 +1,57 @@
import os
import shutil
from pathlib import Path
from unittest.mock import MagicMock, mock_open, patch
import pytest
from _pytest.monkeypatch import MonkeyPatch
from tests.unit_tests.oss.__mock.base import (
get_example_data,
get_example_filename,
get_example_filepath,
get_example_folder,
)
class MockLocalFSClass:
def write_bytes(self, data):
assert data == get_example_data()
def read_bytes(self):
return get_example_data()
@staticmethod
def copyfile(src, dst):
assert src == os.path.join(get_example_folder(), get_example_filename())
assert dst == get_example_filepath()
@staticmethod
def exists(path):
assert path == os.path.join(get_example_folder(), get_example_filename())
return True
@staticmethod
def remove(path):
assert path == os.path.join(get_example_folder(), get_example_filename())
MOCK = os.getenv("MOCK_SWITCH", "false").lower() == "true"
@pytest.fixture
def setup_local_fs_mock(monkeypatch: MonkeyPatch):
if MOCK:
monkeypatch.setattr(Path, "write_bytes", MockLocalFSClass.write_bytes)
monkeypatch.setattr(Path, "read_bytes", MockLocalFSClass.read_bytes)
monkeypatch.setattr(shutil, "copyfile", MockLocalFSClass.copyfile)
monkeypatch.setattr(os.path, "exists", MockLocalFSClass.exists)
monkeypatch.setattr(os, "remove", MockLocalFSClass.remove)
os.makedirs = MagicMock()
with patch("builtins.open", mock_open(read_data=get_example_data())):
yield
if MOCK:
monkeypatch.undo()

View File

@@ -0,0 +1,81 @@
import os
from unittest.mock import MagicMock
import pytest
from _pytest.monkeypatch import MonkeyPatch
from qcloud_cos import CosS3Client # type: ignore
from qcloud_cos.streambody import StreamBody # type: ignore
from tests.unit_tests.oss.__mock.base import (
get_example_bucket,
get_example_data,
get_example_filename,
get_example_filepath,
)
class MockTencentCosClass:
def __init__(self, conf, retry=1, session=None):
self.bucket_name = get_example_bucket()
self.key = get_example_filename()
self.content = get_example_data()
self.filepath = get_example_filepath()
self.resp = {
"ETag": "ee8de918d05640145b18f70f4c3aa602",
"Server": "tencent-cos",
"x-cos-hash-crc64ecma": 16749565679157681890,
"x-cos-request-id": "NWU5MDNkYzlfNjRiODJhMDlfMzFmYzhfMTFm****",
}
def put_object(self, Bucket, Body, Key, EnableMD5=False, **kwargs): # noqa: N803
assert Bucket == self.bucket_name
assert Key == self.key
assert Body == self.content
return self.resp
def get_object(self, Bucket, Key, KeySimplifyCheck=True, **kwargs): # noqa: N803
assert Bucket == self.bucket_name
assert Key == self.key
mock_stream_body = MagicMock(StreamBody)
mock_raw_stream = MagicMock()
mock_stream_body.get_raw_stream.return_value = mock_raw_stream
mock_raw_stream.read.return_value = self.content
mock_stream_body.get_stream_to_file = MagicMock()
def chunk_generator(chunk_size=2):
for i in range(0, len(self.content), chunk_size):
yield self.content[i : i + chunk_size]
mock_stream_body.get_stream.return_value = chunk_generator(chunk_size=4096)
return {"Body": mock_stream_body}
def object_exists(self, Bucket, Key): # noqa: N803
assert Bucket == self.bucket_name
assert Key == self.key
return True
def delete_object(self, Bucket, Key, **kwargs): # noqa: N803
assert Bucket == self.bucket_name
assert Key == self.key
self.resp.update({"x-cos-delete-marker": True})
return self.resp
MOCK = os.getenv("MOCK_SWITCH", "false").lower() == "true"
@pytest.fixture
def setup_tencent_cos_mock(monkeypatch: MonkeyPatch):
if MOCK:
monkeypatch.setattr(CosS3Client, "__init__", MockTencentCosClass.__init__)
monkeypatch.setattr(CosS3Client, "put_object", MockTencentCosClass.put_object)
monkeypatch.setattr(CosS3Client, "get_object", MockTencentCosClass.get_object)
monkeypatch.setattr(CosS3Client, "object_exists", MockTencentCosClass.object_exists)
monkeypatch.setattr(CosS3Client, "delete_object", MockTencentCosClass.delete_object)
yield
if MOCK:
monkeypatch.undo()

View File

@@ -0,0 +1,91 @@
import os
from collections import UserDict
from unittest.mock import MagicMock
import pytest
from _pytest.monkeypatch import MonkeyPatch
from tos import TosClientV2 # type: ignore
from tos.clientv2 import DeleteObjectOutput, GetObjectOutput, HeadObjectOutput, PutObjectOutput # type: ignore
from tests.unit_tests.oss.__mock.base import (
get_example_bucket,
get_example_data,
get_example_filename,
get_example_filepath,
)
class AttrDict(UserDict):
def __getattr__(self, item):
return self.get(item)
class MockVolcengineTosClass:
def __init__(self, ak="", sk="", endpoint="", region=""):
self.bucket_name = get_example_bucket()
self.key = get_example_filename()
self.content = get_example_data()
self.filepath = get_example_filepath()
self.resp = AttrDict(
{
"x-tos-server-side-encryption": "kms",
"x-tos-server-side-encryption-kms-key-id": "trn:kms:cn-beijing:****:keyrings/ring-test/keys/key-test",
"x-tos-server-side-encryption-customer-algorithm": "AES256",
"x-tos-version-id": "test",
"x-tos-hash-crc64ecma": 123456,
"request_id": "test",
"headers": {
"x-tos-id-2": "test",
"ETag": "123456",
},
"status": 200,
}
)
def put_object(self, bucket: str, key: str, content=None) -> PutObjectOutput:
assert bucket == self.bucket_name
assert key == self.key
assert content == self.content
return PutObjectOutput(self.resp)
def get_object(self, bucket: str, key: str) -> GetObjectOutput:
assert bucket == self.bucket_name
assert key == self.key
get_object_output = MagicMock(GetObjectOutput)
get_object_output.read.return_value = self.content
return get_object_output
def get_object_to_file(self, bucket: str, key: str, file_path: str):
assert bucket == self.bucket_name
assert key == self.key
assert file_path == self.filepath
def head_object(self, bucket: str, key: str) -> HeadObjectOutput:
assert bucket == self.bucket_name
assert key == self.key
return HeadObjectOutput(self.resp)
def delete_object(self, bucket: str, key: str):
assert bucket == self.bucket_name
assert key == self.key
return DeleteObjectOutput(self.resp)
MOCK = os.getenv("MOCK_SWITCH", "false").lower() == "true"
@pytest.fixture
def setup_volcengine_tos_mock(monkeypatch: MonkeyPatch):
if MOCK:
monkeypatch.setattr(TosClientV2, "__init__", MockVolcengineTosClass.__init__)
monkeypatch.setattr(TosClientV2, "put_object", MockVolcengineTosClass.put_object)
monkeypatch.setattr(TosClientV2, "get_object", MockVolcengineTosClass.get_object)
monkeypatch.setattr(TosClientV2, "get_object_to_file", MockVolcengineTosClass.get_object_to_file)
monkeypatch.setattr(TosClientV2, "head_object", MockVolcengineTosClass.head_object)
monkeypatch.setattr(TosClientV2, "delete_object", MockVolcengineTosClass.delete_object)
yield
if MOCK:
monkeypatch.undo()

View File

@@ -0,0 +1,22 @@
from unittest.mock import patch
import pytest
from oss2 import Auth # type: ignore
from extensions.storage.aliyun_oss_storage import AliyunOssStorage
from tests.unit_tests.oss.__mock.aliyun_oss import setup_aliyun_oss_mock
from tests.unit_tests.oss.__mock.base import (
BaseStorageTest,
get_example_bucket,
get_example_folder,
)
class TestAliyunOss(BaseStorageTest):
@pytest.fixture(autouse=True)
def setup_method(self, setup_aliyun_oss_mock):
"""Executed before each test method."""
with patch.object(Auth, "__init__", return_value=None):
self.storage = AliyunOssStorage()
self.storage.bucket_name = get_example_bucket()
self.storage.folder = get_example_folder()

View File

@@ -0,0 +1,85 @@
from collections.abc import Generator
from pathlib import Path
import pytest
from extensions.storage.opendal_storage import OpenDALStorage
from tests.unit_tests.oss.__mock.base import (
get_example_data,
get_example_filename,
get_opendal_bucket,
)
class TestOpenDAL:
@pytest.fixture(autouse=True)
def setup_method(self, *args, **kwargs):
"""Executed before each test method."""
self.storage = OpenDALStorage(
scheme="fs",
root=get_opendal_bucket(),
)
@pytest.fixture(scope="class", autouse=True)
    def teardown_class(self, request):
        """Clean up after all tests in the class."""

        def cleanup():
            folder = Path(get_opendal_bucket())
            if folder.exists() and folder.is_dir():
                for item in folder.iterdir():
                    if item.is_file():
                        item.unlink()
                    elif item.is_dir():
                        item.rmdir()
                folder.rmdir()

        # Register as a finalizer so cleanup runs after the class's tests;
        # calling cleanup() directly here would run it during fixture setup.
        request.addfinalizer(cleanup)
def test_save_and_exists(self):
"""Test saving data and checking existence."""
filename = get_example_filename()
data = get_example_data()
assert not self.storage.exists(filename)
self.storage.save(filename, data)
assert self.storage.exists(filename)
def test_load_once(self):
"""Test loading data once."""
filename = get_example_filename()
data = get_example_data()
self.storage.save(filename, data)
loaded_data = self.storage.load_once(filename)
assert loaded_data == data
def test_load_stream(self):
"""Test loading data as a stream."""
filename = get_example_filename()
data = get_example_data()
self.storage.save(filename, data)
generator = self.storage.load_stream(filename)
assert isinstance(generator, Generator)
assert next(generator) == data
def test_download(self):
"""Test downloading data to a file."""
filename = get_example_filename()
filepath = str(Path(get_opendal_bucket()) / filename)
data = get_example_data()
self.storage.save(filename, data)
self.storage.download(filename, filepath)
def test_delete(self):
"""Test deleting a file."""
filename = get_example_filename()
data = get_example_data()
self.storage.save(filename, data)
assert self.storage.exists(filename)
self.storage.delete(filename)
assert not self.storage.exists(filename)

View File

@@ -0,0 +1,20 @@
from unittest.mock import patch
import pytest
from qcloud_cos import CosConfig # type: ignore
from extensions.storage.tencent_cos_storage import TencentCosStorage
from tests.unit_tests.oss.__mock.base import (
BaseStorageTest,
get_example_bucket,
)
from tests.unit_tests.oss.__mock.tencent_cos import setup_tencent_cos_mock
class TestTencentCos(BaseStorageTest):
@pytest.fixture(autouse=True)
def setup_method(self, setup_tencent_cos_mock):
"""Executed before each test method."""
with patch.object(CosConfig, "__init__", return_value=None):
self.storage = TencentCosStorage()
self.storage.bucket_name = get_example_bucket()

View File

@@ -0,0 +1,23 @@
import pytest
from tos import TosClientV2 # type: ignore
from extensions.storage.volcengine_tos_storage import VolcengineTosStorage
from tests.unit_tests.oss.__mock.base import (
BaseStorageTest,
get_example_bucket,
)
from tests.unit_tests.oss.__mock.volcengine_tos import setup_volcengine_tos_mock
class TestVolcengineTos(BaseStorageTest):
@pytest.fixture(autouse=True)
def setup_method(self, setup_volcengine_tos_mock):
"""Executed before each test method."""
self.storage = VolcengineTosStorage()
self.storage.bucket_name = get_example_bucket()
self.storage.client = TosClientV2(
ak="dify",
sk="dify",
endpoint="https://xxx.volces.com",
region="cn-beijing",
)

View File

@@ -0,0 +1,424 @@
# test for api/services/workflow/workflow_converter.py
import json
from unittest.mock import MagicMock
import pytest
from core.app.app_config.entities import (
AdvancedChatMessageEntity,
AdvancedChatPromptTemplateEntity,
AdvancedCompletionPromptTemplateEntity,
DatasetEntity,
DatasetRetrieveConfigEntity,
ExternalDataVariableEntity,
ModelConfigEntity,
PromptTemplateEntity,
VariableEntity,
VariableEntityType,
)
from core.helper import encrypter
from core.model_runtime.entities.llm_entities import LLMMode
from core.model_runtime.entities.message_entities import PromptMessageRole
from models.api_based_extension import APIBasedExtension, APIBasedExtensionPoint
from models.model import AppMode
from services.workflow.workflow_converter import WorkflowConverter
@pytest.fixture
def default_variables():
value = [
VariableEntity(
variable="text_input",
label="text-input",
type=VariableEntityType.TEXT_INPUT,
),
VariableEntity(
variable="paragraph",
label="paragraph",
type=VariableEntityType.PARAGRAPH,
),
VariableEntity(
variable="select",
label="select",
type=VariableEntityType.SELECT,
),
]
return value
def test__convert_to_start_node(default_variables):
# act
result = WorkflowConverter()._convert_to_start_node(default_variables)
# assert
assert isinstance(result["data"]["variables"][0]["type"], str)
assert result["data"]["variables"][0]["type"] == "text-input"
assert result["data"]["variables"][0]["variable"] == "text_input"
assert result["data"]["variables"][1]["variable"] == "paragraph"
assert result["data"]["variables"][2]["variable"] == "select"
def test__convert_to_http_request_node_for_chatbot(default_variables):
"""
Test convert to http request nodes for chatbot
:return:
"""
app_model = MagicMock()
app_model.id = "app_id"
app_model.tenant_id = "tenant_id"
app_model.mode = AppMode.CHAT.value
api_based_extension_id = "api_based_extension_id"
mock_api_based_extension = APIBasedExtension(
id=api_based_extension_id,
name="api-1",
api_key="encrypted_api_key",
api_endpoint="https://dify.ai",
)
workflow_converter = WorkflowConverter()
workflow_converter._get_api_based_extension = MagicMock(return_value=mock_api_based_extension)
encrypter.decrypt_token = MagicMock(return_value="api_key")
external_data_variables = [
ExternalDataVariableEntity(
variable="external_variable", type="api", config={"api_based_extension_id": api_based_extension_id}
)
]
nodes, _ = workflow_converter._convert_to_http_request_node(
app_model=app_model, variables=default_variables, external_data_variables=external_data_variables
)
assert len(nodes) == 2
assert nodes[0]["data"]["type"] == "http-request"
http_request_node = nodes[0]
assert http_request_node["data"]["method"] == "post"
assert http_request_node["data"]["url"] == mock_api_based_extension.api_endpoint
assert http_request_node["data"]["authorization"]["type"] == "api-key"
assert http_request_node["data"]["authorization"]["config"] == {"type": "bearer", "api_key": "api_key"}
assert http_request_node["data"]["body"]["type"] == "json"
body_data = http_request_node["data"]["body"]["data"]
assert body_data
body_data_json = json.loads(body_data)
assert body_data_json["point"] == APIBasedExtensionPoint.APP_EXTERNAL_DATA_TOOL_QUERY.value
body_params = body_data_json["params"]
assert body_params["app_id"] == app_model.id
assert body_params["tool_variable"] == external_data_variables[0].variable
assert len(body_params["inputs"]) == 3
assert body_params["query"] == "{{#sys.query#}}" # for chatbot
code_node = nodes[1]
assert code_node["data"]["type"] == "code"
def test__convert_to_http_request_node_for_workflow_app(default_variables):
"""
Test convert to http request nodes for workflow app
:return:
"""
app_model = MagicMock()
app_model.id = "app_id"
app_model.tenant_id = "tenant_id"
app_model.mode = AppMode.WORKFLOW.value
api_based_extension_id = "api_based_extension_id"
mock_api_based_extension = APIBasedExtension(
id=api_based_extension_id,
name="api-1",
api_key="encrypted_api_key",
api_endpoint="https://dify.ai",
)
workflow_converter = WorkflowConverter()
workflow_converter._get_api_based_extension = MagicMock(return_value=mock_api_based_extension)
encrypter.decrypt_token = MagicMock(return_value="api_key")
external_data_variables = [
ExternalDataVariableEntity(
variable="external_variable", type="api", config={"api_based_extension_id": api_based_extension_id}
)
]
nodes, _ = workflow_converter._convert_to_http_request_node(
app_model=app_model, variables=default_variables, external_data_variables=external_data_variables
)
assert len(nodes) == 2
assert nodes[0]["data"]["type"] == "http-request"
http_request_node = nodes[0]
assert http_request_node["data"]["method"] == "post"
assert http_request_node["data"]["url"] == mock_api_based_extension.api_endpoint
assert http_request_node["data"]["authorization"]["type"] == "api-key"
assert http_request_node["data"]["authorization"]["config"] == {"type": "bearer", "api_key": "api_key"}
assert http_request_node["data"]["body"]["type"] == "json"
body_data = http_request_node["data"]["body"]["data"]
assert body_data
body_data_json = json.loads(body_data)
assert body_data_json["point"] == APIBasedExtensionPoint.APP_EXTERNAL_DATA_TOOL_QUERY.value
body_params = body_data_json["params"]
assert body_params["app_id"] == app_model.id
assert body_params["tool_variable"] == external_data_variables[0].variable
assert len(body_params["inputs"]) == 3
assert body_params["query"] == ""
code_node = nodes[1]
assert code_node["data"]["type"] == "code"
def test__convert_to_knowledge_retrieval_node_for_chatbot():
new_app_mode = AppMode.ADVANCED_CHAT
dataset_config = DatasetEntity(
dataset_ids=["dataset_id_1", "dataset_id_2"],
retrieve_config=DatasetRetrieveConfigEntity(
retrieve_strategy=DatasetRetrieveConfigEntity.RetrieveStrategy.MULTIPLE,
top_k=5,
score_threshold=0.8,
reranking_model={"reranking_provider_name": "cohere", "reranking_model_name": "rerank-english-v2.0"},
reranking_enabled=True,
),
)
model_config = ModelConfigEntity(provider="openai", model="gpt-4", mode="chat", parameters={}, stop=[])
node = WorkflowConverter()._convert_to_knowledge_retrieval_node(
new_app_mode=new_app_mode, dataset_config=dataset_config, model_config=model_config
)
assert node["data"]["type"] == "knowledge-retrieval"
assert node["data"]["query_variable_selector"] == ["sys", "query"]
assert node["data"]["dataset_ids"] == dataset_config.dataset_ids
assert node["data"]["retrieval_mode"] == dataset_config.retrieve_config.retrieve_strategy.value
assert node["data"]["multiple_retrieval_config"] == {
"top_k": dataset_config.retrieve_config.top_k,
"score_threshold": dataset_config.retrieve_config.score_threshold,
"reranking_model": dataset_config.retrieve_config.reranking_model,
}
def test__convert_to_knowledge_retrieval_node_for_workflow_app():
new_app_mode = AppMode.WORKFLOW
dataset_config = DatasetEntity(
dataset_ids=["dataset_id_1", "dataset_id_2"],
retrieve_config=DatasetRetrieveConfigEntity(
query_variable="query",
retrieve_strategy=DatasetRetrieveConfigEntity.RetrieveStrategy.MULTIPLE,
top_k=5,
score_threshold=0.8,
reranking_model={"reranking_provider_name": "cohere", "reranking_model_name": "rerank-english-v2.0"},
reranking_enabled=True,
),
)
model_config = ModelConfigEntity(provider="openai", model="gpt-4", mode="chat", parameters={}, stop=[])
node = WorkflowConverter()._convert_to_knowledge_retrieval_node(
new_app_mode=new_app_mode, dataset_config=dataset_config, model_config=model_config
)
assert node["data"]["type"] == "knowledge-retrieval"
assert node["data"]["query_variable_selector"] == ["start", dataset_config.retrieve_config.query_variable]
assert node["data"]["dataset_ids"] == dataset_config.dataset_ids
assert node["data"]["retrieval_mode"] == dataset_config.retrieve_config.retrieve_strategy.value
assert node["data"]["multiple_retrieval_config"] == {
"top_k": dataset_config.retrieve_config.top_k,
"score_threshold": dataset_config.retrieve_config.score_threshold,
"reranking_model": dataset_config.retrieve_config.reranking_model,
}
def test__convert_to_llm_node_for_chatbot_simple_chat_model(default_variables):
new_app_mode = AppMode.ADVANCED_CHAT
model = "gpt-4"
model_mode = LLMMode.CHAT
workflow_converter = WorkflowConverter()
start_node = workflow_converter._convert_to_start_node(default_variables)
graph = {
"nodes": [start_node],
"edges": [], # no need
}
model_config_mock = MagicMock(spec=ModelConfigEntity)
model_config_mock.provider = "openai"
model_config_mock.model = model
model_config_mock.mode = model_mode.value
model_config_mock.parameters = {}
model_config_mock.stop = []
prompt_template = PromptTemplateEntity(
prompt_type=PromptTemplateEntity.PromptType.SIMPLE,
simple_prompt_template="You are a helpful assistant {{text_input}}, {{paragraph}}, {{select}}.",
)
llm_node = workflow_converter._convert_to_llm_node(
original_app_mode=AppMode.CHAT,
new_app_mode=new_app_mode,
model_config=model_config_mock,
graph=graph,
prompt_template=prompt_template,
)
assert llm_node["data"]["type"] == "llm"
assert llm_node["data"]["model"]["name"] == model
assert llm_node["data"]["model"]["mode"] == model_mode.value
template = prompt_template.simple_prompt_template
for v in default_variables:
template = template.replace("{{" + v.variable + "}}", "{{#start." + v.variable + "#}}")
assert llm_node["data"]["prompt_template"][0]["text"] == template + "\n"
assert llm_node["data"]["context"]["enabled"] is False
def test__convert_to_llm_node_for_chatbot_simple_completion_model(default_variables):
new_app_mode = AppMode.ADVANCED_CHAT
model = "gpt-3.5-turbo-instruct"
model_mode = LLMMode.COMPLETION
workflow_converter = WorkflowConverter()
start_node = workflow_converter._convert_to_start_node(default_variables)
graph = {
"nodes": [start_node],
"edges": [], # no need
}
model_config_mock = MagicMock(spec=ModelConfigEntity)
model_config_mock.provider = "openai"
model_config_mock.model = model
model_config_mock.mode = model_mode.value
model_config_mock.parameters = {}
model_config_mock.stop = []
prompt_template = PromptTemplateEntity(
prompt_type=PromptTemplateEntity.PromptType.SIMPLE,
simple_prompt_template="You are a helpful assistant {{text_input}}, {{paragraph}}, {{select}}.",
)
llm_node = workflow_converter._convert_to_llm_node(
original_app_mode=AppMode.CHAT,
new_app_mode=new_app_mode,
model_config=model_config_mock,
graph=graph,
prompt_template=prompt_template,
)
assert llm_node["data"]["type"] == "llm"
assert llm_node["data"]["model"]["name"] == model
assert llm_node["data"]["model"]["mode"] == model_mode.value
template = prompt_template.simple_prompt_template
for v in default_variables:
template = template.replace("{{" + v.variable + "}}", "{{#start." + v.variable + "#}}")
assert llm_node["data"]["prompt_template"]["text"] == template + "\n"
assert llm_node["data"]["context"]["enabled"] is False
def test__convert_to_llm_node_for_chatbot_advanced_chat_model(default_variables):
new_app_mode = AppMode.ADVANCED_CHAT
model = "gpt-4"
model_mode = LLMMode.CHAT
workflow_converter = WorkflowConverter()
start_node = workflow_converter._convert_to_start_node(default_variables)
graph = {
"nodes": [start_node],
"edges": [], # no need
}
model_config_mock = MagicMock(spec=ModelConfigEntity)
model_config_mock.provider = "openai"
model_config_mock.model = model
model_config_mock.mode = model_mode.value
model_config_mock.parameters = {}
model_config_mock.stop = []
prompt_template = PromptTemplateEntity(
prompt_type=PromptTemplateEntity.PromptType.ADVANCED,
advanced_chat_prompt_template=AdvancedChatPromptTemplateEntity(
messages=[
AdvancedChatMessageEntity(
text="You are a helpful assistant named {{name}}.\n\nContext:\n{{#context#}}",
role=PromptMessageRole.SYSTEM,
),
AdvancedChatMessageEntity(text="Hi.", role=PromptMessageRole.USER),
AdvancedChatMessageEntity(text="Hello!", role=PromptMessageRole.ASSISTANT),
]
),
)
llm_node = workflow_converter._convert_to_llm_node(
original_app_mode=AppMode.CHAT,
new_app_mode=new_app_mode,
model_config=model_config_mock,
graph=graph,
prompt_template=prompt_template,
)
assert llm_node["data"]["type"] == "llm"
assert llm_node["data"]["model"]["name"] == model
assert llm_node["data"]["model"]["mode"] == model_mode.value
assert isinstance(llm_node["data"]["prompt_template"], list)
assert len(llm_node["data"]["prompt_template"]) == len(prompt_template.advanced_chat_prompt_template.messages)
template = prompt_template.advanced_chat_prompt_template.messages[0].text
for v in default_variables:
template = template.replace("{{" + v.variable + "}}", "{{#start." + v.variable + "#}}")
assert llm_node["data"]["prompt_template"][0]["text"] == template
def test__convert_to_llm_node_for_workflow_advanced_completion_model(default_variables):
new_app_mode = AppMode.ADVANCED_CHAT
model = "gpt-3.5-turbo-instruct"
model_mode = LLMMode.COMPLETION
workflow_converter = WorkflowConverter()
start_node = workflow_converter._convert_to_start_node(default_variables)
graph = {
"nodes": [start_node],
"edges": [], # no need
}
model_config_mock = MagicMock(spec=ModelConfigEntity)
model_config_mock.provider = "openai"
model_config_mock.model = model
model_config_mock.mode = model_mode.value
model_config_mock.parameters = {}
model_config_mock.stop = []
prompt_template = PromptTemplateEntity(
prompt_type=PromptTemplateEntity.PromptType.ADVANCED,
advanced_completion_prompt_template=AdvancedCompletionPromptTemplateEntity(
prompt="You are a helpful assistant named {{name}}.\n\nContext:\n{{#context#}}\n\nHuman: hi\nAssistant: ",
role_prefix=AdvancedCompletionPromptTemplateEntity.RolePrefixEntity(user="Human", assistant="Assistant"),
),
)
llm_node = workflow_converter._convert_to_llm_node(
original_app_mode=AppMode.CHAT,
new_app_mode=new_app_mode,
model_config=model_config_mock,
graph=graph,
prompt_template=prompt_template,
)
assert llm_node["data"]["type"] == "llm"
assert llm_node["data"]["model"]["name"] == model
assert llm_node["data"]["model"]["mode"] == model_mode.value
assert isinstance(llm_node["data"]["prompt_template"], dict)
template = prompt_template.advanced_completion_prompt_template.prompt
for v in default_variables:
template = template.replace("{{" + v.variable + "}}", "{{#start." + v.variable + "#}}")
assert llm_node["data"]["prompt_template"]["text"] == template
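

# Note: all three tests in this module assert the same placeholder rewrite:
# the converter turns legacy prompt variables into workflow variable selectors
# scoped to the start node, e.g. (illustrative only):
#   "You are {{name}}."  ->  "You are {{#start.name#}}."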

View File

@@ -0,0 +1,162 @@
from unittest.mock import MagicMock

import pytest

from models.model import App
from models.workflow import Workflow
from services.workflow_service import WorkflowService


class TestWorkflowService:
@pytest.fixture
def workflow_service(self):
return WorkflowService()

    @pytest.fixture
def mock_app(self):
app = MagicMock(spec=App)
app.id = "app-id-1"
app.workflow_id = "workflow-id-1"
app.tenant_id = "tenant-id-1"
return app

    @pytest.fixture
def mock_workflows(self):
workflows = []
for i in range(5):
workflow = MagicMock(spec=Workflow)
workflow.id = f"workflow-id-{i}"
workflow.app_id = "app-id-1"
workflow.created_at = f"2023-01-0{5 - i}" # Descending date order
workflow.created_by = "user-id-1" if i % 2 == 0 else "user-id-2"
workflow.marked_name = f"Workflow {i}" if i % 2 == 0 else ""
workflows.append(workflow)
return workflows

    def test_get_all_published_workflow_no_workflow_id(self, workflow_service, mock_app):
mock_app.workflow_id = None
mock_session = MagicMock()
workflows, has_more = workflow_service.get_all_published_workflow(
session=mock_session, app_model=mock_app, page=1, limit=10, user_id=None
)
assert workflows == []
assert has_more is False
mock_session.scalars.assert_not_called()

    def test_get_all_published_workflow_basic(self, workflow_service, mock_app, mock_workflows):
mock_session = MagicMock()
mock_scalar_result = MagicMock()
mock_scalar_result.all.return_value = mock_workflows[:3]
mock_session.scalars.return_value = mock_scalar_result
workflows, has_more = workflow_service.get_all_published_workflow(
session=mock_session, app_model=mock_app, page=1, limit=3, user_id=None
)
assert workflows == mock_workflows[:3]
assert has_more is False
mock_session.scalars.assert_called_once()

    def test_get_all_published_workflow_pagination(self, workflow_service, mock_app, mock_workflows):
mock_session = MagicMock()
mock_scalar_result = MagicMock()
# Return 4 items when limit is 3, which should indicate has_more=True
mock_scalar_result.all.return_value = mock_workflows[:4]
mock_session.scalars.return_value = mock_scalar_result
workflows, has_more = workflow_service.get_all_published_workflow(
session=mock_session, app_model=mock_app, page=1, limit=3, user_id=None
)
# Should return only the first 3 items
assert len(workflows) == 3
assert workflows == mock_workflows[:3]
assert has_more is True
# Test page 2
mock_scalar_result.all.return_value = mock_workflows[3:]
mock_session.scalars.return_value = mock_scalar_result
workflows, has_more = workflow_service.get_all_published_workflow(
session=mock_session, app_model=mock_app, page=2, limit=3, user_id=None
)
assert len(workflows) == 2
assert has_more is False

    def test_get_all_published_workflow_user_filter(self, workflow_service, mock_app, mock_workflows):
mock_session = MagicMock()
mock_scalar_result = MagicMock()
# Filter workflows for user-id-1
filtered_workflows = [w for w in mock_workflows if w.created_by == "user-id-1"]
mock_scalar_result.all.return_value = filtered_workflows
mock_session.scalars.return_value = mock_scalar_result
workflows, has_more = workflow_service.get_all_published_workflow(
session=mock_session, app_model=mock_app, page=1, limit=10, user_id="user-id-1"
)
assert workflows == filtered_workflows
assert has_more is False
mock_session.scalars.assert_called_once()
# Verify that the select contains a user filter clause
args = mock_session.scalars.call_args[0][0]
assert "created_by" in str(args)

    def test_get_all_published_workflow_named_only(self, workflow_service, mock_app, mock_workflows):
mock_session = MagicMock()
mock_scalar_result = MagicMock()
# Filter workflows that have a marked_name
named_workflows = [w for w in mock_workflows if w.marked_name]
mock_scalar_result.all.return_value = named_workflows
mock_session.scalars.return_value = mock_scalar_result
workflows, has_more = workflow_service.get_all_published_workflow(
session=mock_session, app_model=mock_app, page=1, limit=10, user_id=None, named_only=True
)
assert workflows == named_workflows
assert has_more is False
mock_session.scalars.assert_called_once()
# Verify that the select contains a named_only filter clause
args = mock_session.scalars.call_args[0][0]
assert "marked_name !=" in str(args)

    def test_get_all_published_workflow_combined_filters(self, workflow_service, mock_app, mock_workflows):
mock_session = MagicMock()
mock_scalar_result = MagicMock()
# Combined filter: user-id-1 and has marked_name
filtered_workflows = [w for w in mock_workflows if w.created_by == "user-id-1" and w.marked_name]
mock_scalar_result.all.return_value = filtered_workflows
mock_session.scalars.return_value = mock_scalar_result
workflows, has_more = workflow_service.get_all_published_workflow(
session=mock_session, app_model=mock_app, page=1, limit=10, user_id="user-id-1", named_only=True
)
assert workflows == filtered_workflows
assert has_more is False
mock_session.scalars.assert_called_once()
# Verify that both filters are applied
args = mock_session.scalars.call_args[0][0]
assert "created_by" in str(args)
assert "marked_name !=" in str(args)

    def test_get_all_published_workflow_empty_result(self, workflow_service, mock_app):
mock_session = MagicMock()
mock_scalar_result = MagicMock()
mock_scalar_result.all.return_value = []
mock_session.scalars.return_value = mock_scalar_result
workflows, has_more = workflow_service.get_all_published_workflow(
session=mock_session, app_model=mock_app, page=1, limit=10, user_id=None
)
assert workflows == []
assert has_more is False
mock_session.scalars.assert_called_once()
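

# A minimal sketch of the pagination contract the tests above rely on: the
# service is expected to fetch limit + 1 rows, return at most `limit` of them,
# and report has_more when the extra row exists. This is a hypothetical helper
# for illustration, not the actual WorkflowService implementation:
def _paginate_sketch(rows: list, limit: int) -> tuple[list, bool]:
    # Keep only `limit` rows; an extra fetched row signals another page.
    return rows[:limit], len(rows) > limit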

View File

@@ -0,0 +1,123 @@
from textwrap import dedent

import pytest

from core.helper.position_helper import get_position_map, is_filtered, pin_position_map, sort_by_position_map


@pytest.fixture
def prepare_example_positions_yaml(tmp_path, monkeypatch) -> str:
monkeypatch.chdir(tmp_path)
tmp_path.joinpath("example_positions.yaml").write_text(
dedent(
"""\
- first
- second
# - commented
- third
- 9999999999999
- forth
"""
)
)
return str(tmp_path)


@pytest.fixture
def prepare_empty_commented_positions_yaml(tmp_path, monkeypatch) -> str:
monkeypatch.chdir(tmp_path)
tmp_path.joinpath("example_positions_all_commented.yaml").write_text(
dedent(
"""\
# - commented1
# - commented2
-
-
"""
)
)
return str(tmp_path)


def test_position_helper(prepare_example_positions_yaml):
position_map = get_position_map(folder_path=prepare_example_positions_yaml, file_name="example_positions.yaml")
assert len(position_map) == 4
assert position_map == {
"first": 0,
"second": 1,
"third": 2,
"forth": 3,
}


def test_position_helper_with_all_commented(prepare_empty_commented_positions_yaml):
position_map = get_position_map(
folder_path=prepare_empty_commented_positions_yaml, file_name="example_positions_all_commented.yaml"
)
assert position_map == {}


def test_excluded_position_data(prepare_example_positions_yaml):
position_map = get_position_map(folder_path=prepare_example_positions_yaml, file_name="example_positions.yaml")
pin_list = ["forth", "first"]
include_set = set()
exclude_set = {"9999999999999"}
position_map = pin_position_map(original_position_map=position_map, pin_list=pin_list)
data = [
"forth",
"first",
"second",
"third",
"9999999999999",
"extra1",
"extra2",
]
# filter out the data
data = [item for item in data if not is_filtered(include_set, exclude_set, item, lambda x: x)]
# sort data by position map
sorted_data = sort_by_position_map(
position_map=position_map,
data=data,
name_func=lambda x: x,
)
# assert the result in the correct order
assert sorted_data == ["forth", "first", "second", "third", "extra1", "extra2"]


def test_included_position_data(prepare_example_positions_yaml):
position_map = get_position_map(folder_path=prepare_example_positions_yaml, file_name="example_positions.yaml")
pin_list = ["forth", "first"]
include_set = {"forth", "first"}
    exclude_set = set()  # an empty set; the {} literal would create a dict
position_map = pin_position_map(original_position_map=position_map, pin_list=pin_list)
data = [
"forth",
"first",
"second",
"third",
"9999999999999",
"extra1",
"extra2",
]
# filter out the data
data = [item for item in data if not is_filtered(include_set, exclude_set, item, lambda x: x)]
# sort data by position map
sorted_data = sort_by_position_map(
position_map=position_map,
data=data,
name_func=lambda x: x,
)
# assert the result in the correct order
assert sorted_data == ["forth", "first"]
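

# For orientation, a sketch of the pinning semantics exercised above
# (assumption: the real pin_position_map in core.helper.position_helper may
# differ in detail): pinned names move to the front in pin_list order, and the
# remaining names keep their original relative order behind them.
def _pin_position_map_sketch(position_map: dict[str, int], pin_list: list[str]) -> dict[str, int]:
    pinned = [name for name in pin_list if name in position_map]
    rest = sorted((name for name in position_map if name not in pinned), key=lambda name: position_map[name])
    return {name: index for index, name in enumerate(pinned + rest)}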

View File

@@ -0,0 +1,18 @@
import pytest

from core.tools.utils.text_processing_utils import remove_leading_symbols


@pytest.mark.parametrize(
("input_text", "expected_output"),
[
("...Hello, World!", "Hello, World!"),
("。测试中文标点", "测试中文标点"),
("!@#Test symbols", "Test symbols"),
("Hello, World!", "Hello, World!"),
("", ""),
(" ", " "),
],
)
def test_remove_leading_symbols(input_text, expected_output):
assert remove_leading_symbols(input_text) == expected_output
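

# A behavior-compatible sketch for the cases above (assumption: the real
# remove_leading_symbols helper may be implemented differently):
import re


def _remove_leading_symbols_sketch(text: str) -> str:
    # Strip leading punctuation/symbol characters, but leave empty or
    # whitespace-only input untouched.
    return re.sub(r"^[\W_]+", "", text) if text.strip() else text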

View File

@@ -0,0 +1,83 @@
from textwrap import dedent

import pytest
from yaml import YAMLError  # type: ignore

from core.tools.utils.yaml_utils import load_yaml_file

EXAMPLE_YAML_FILE = "example_yaml.yaml"
INVALID_YAML_FILE = "invalid_yaml.yaml"
NON_EXISTING_YAML_FILE = "non_existing_file.yaml"


@pytest.fixture
def prepare_example_yaml_file(tmp_path, monkeypatch) -> str:
monkeypatch.chdir(tmp_path)
file_path = tmp_path.joinpath(EXAMPLE_YAML_FILE)
file_path.write_text(
dedent(
"""\
address:
city: Example City
country: Example Country
age: 30
gender: male
languages:
- Python
- Java
- C++
empty_key:
"""
)
)
return str(file_path)


@pytest.fixture
def prepare_invalid_yaml_file(tmp_path, monkeypatch) -> str:
monkeypatch.chdir(tmp_path)
file_path = tmp_path.joinpath(INVALID_YAML_FILE)
file_path.write_text(
dedent(
"""\
address:
city: Example City
country: Example Country
age: 30
gender: male
languages:
- Python
- Java
- C++
"""
)
)
return str(file_path)


def test_load_yaml_non_existing_file():
assert load_yaml_file(file_path=NON_EXISTING_YAML_FILE) == {}
assert load_yaml_file(file_path="") == {}
with pytest.raises(FileNotFoundError):
load_yaml_file(file_path=NON_EXISTING_YAML_FILE, ignore_error=False)


def test_load_valid_yaml_file(prepare_example_yaml_file):
yaml_data = load_yaml_file(file_path=prepare_example_yaml_file)
assert len(yaml_data) > 0
assert yaml_data["age"] == 30
assert yaml_data["gender"] == "male"
assert yaml_data["address"]["city"] == "Example City"
assert set(yaml_data["languages"]) == {"Python", "Java", "C++"}
assert yaml_data.get("empty_key") is None
assert yaml_data.get("non_existed_key") is None


def test_load_invalid_yaml_file(prepare_invalid_yaml_file):
# yaml syntax error
with pytest.raises(YAMLError):
load_yaml_file(file_path=prepare_invalid_yaml_file, ignore_error=False)
# ignore error
assert load_yaml_file(file_path=prepare_invalid_yaml_file) == {}
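

# The contract the tests above pin down, sketched under assumptions (the real
# load_yaml_file in core.tools.utils.yaml_utils may differ internally):
import yaml


def _load_yaml_file_sketch(file_path: str, ignore_error: bool = True) -> dict:
    try:
        with open(file_path, encoding="utf-8") as f:
            return yaml.safe_load(f) or {}
    except Exception:
        if ignore_error:
            # Swallow missing-file and parse errors and fall back to an empty mapping.
            return {}
        raise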