Initial commit

This commit is contained in:
2025-10-14 14:17:21 +08:00
commit ac715a8b88
35011 changed files with 3834178 additions and 0 deletions

View File

@@ -0,0 +1,3 @@
# Package initializer: expose the ``errors`` submodule as the package's
# public API so callers can do ``from <package> import errors``.
from . import errors
__all__ = ["errors"]

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,97 @@
import copy
from core.prompt.prompt_templates.advanced_prompt_templates import (
BAICHUAN_CHAT_APP_CHAT_PROMPT_CONFIG,
BAICHUAN_CHAT_APP_COMPLETION_PROMPT_CONFIG,
BAICHUAN_COMPLETION_APP_CHAT_PROMPT_CONFIG,
BAICHUAN_COMPLETION_APP_COMPLETION_PROMPT_CONFIG,
BAICHUAN_CONTEXT,
CHAT_APP_CHAT_PROMPT_CONFIG,
CHAT_APP_COMPLETION_PROMPT_CONFIG,
COMPLETION_APP_CHAT_PROMPT_CONFIG,
COMPLETION_APP_COMPLETION_PROMPT_CONFIG,
CONTEXT,
)
from models.model import AppMode
class AdvancedPromptTemplateService:
    """Serves the default advanced-mode prompt templates.

    Baichuan-family models get dedicated templates; every other model
    falls back to the common template set.
    """

    @classmethod
    def get_prompt(cls, args: dict) -> dict:
        """Dispatch to the Baichuan or common template set by model name."""
        app_mode = args["app_mode"]
        model_mode = args["model_mode"]
        model_name = args["model_name"]
        has_context = args["has_context"]

        if "baichuan" in model_name.lower():
            return cls.get_baichuan_prompt(app_mode, model_mode, has_context)
        return cls.get_common_prompt(app_mode, model_mode, has_context)

    @classmethod
    def get_common_prompt(cls, app_mode: str, model_mode: str, has_context: str) -> dict:
        """Return the common template for (app_mode, model_mode), or {} if unknown."""
        context_prompt = copy.deepcopy(CONTEXT)

        dispatch = {
            (AppMode.CHAT.value, "completion"): (cls.get_completion_prompt, CHAT_APP_COMPLETION_PROMPT_CONFIG),
            (AppMode.CHAT.value, "chat"): (cls.get_chat_prompt, CHAT_APP_CHAT_PROMPT_CONFIG),
            (AppMode.COMPLETION.value, "completion"): (
                cls.get_completion_prompt,
                COMPLETION_APP_COMPLETION_PROMPT_CONFIG,
            ),
            (AppMode.COMPLETION.value, "chat"): (cls.get_chat_prompt, COMPLETION_APP_CHAT_PROMPT_CONFIG),
        }
        entry = dispatch.get((app_mode, model_mode))
        if entry is None:
            # unrecognized combination: default to an empty template
            return {}
        build, template = entry
        return build(copy.deepcopy(template), has_context, context_prompt)

    @classmethod
    def get_completion_prompt(cls, prompt_template: dict, has_context: str, context: str) -> dict:
        """Prepend *context* to the completion prompt text when has_context == "true"."""
        if has_context == "true":
            prompt = prompt_template["completion_prompt_config"]["prompt"]
            prompt["text"] = context + prompt["text"]
        return prompt_template

    @classmethod
    def get_chat_prompt(cls, prompt_template: dict, has_context: str, context: str) -> dict:
        """Prepend *context* to the first chat message text when has_context == "true"."""
        if has_context == "true":
            first_message = prompt_template["chat_prompt_config"]["prompt"][0]
            first_message["text"] = context + first_message["text"]
        return prompt_template

    @classmethod
    def get_baichuan_prompt(cls, app_mode: str, model_mode: str, has_context: str) -> dict:
        """Return the Baichuan template for (app_mode, model_mode), or {} if unknown."""
        baichuan_context_prompt = copy.deepcopy(BAICHUAN_CONTEXT)

        dispatch = {
            (AppMode.CHAT.value, "completion"): (
                cls.get_completion_prompt,
                BAICHUAN_CHAT_APP_COMPLETION_PROMPT_CONFIG,
            ),
            (AppMode.CHAT.value, "chat"): (cls.get_chat_prompt, BAICHUAN_CHAT_APP_CHAT_PROMPT_CONFIG),
            (AppMode.COMPLETION.value, "completion"): (
                cls.get_completion_prompt,
                BAICHUAN_COMPLETION_APP_COMPLETION_PROMPT_CONFIG,
            ),
            (AppMode.COMPLETION.value, "chat"): (cls.get_chat_prompt, BAICHUAN_COMPLETION_APP_CHAT_PROMPT_CONFIG),
        }
        entry = dispatch.get((app_mode, model_mode))
        if entry is None:
            return {}
        build, template = entry
        return build(copy.deepcopy(template), has_context, baichuan_context_prompt)

View File

@@ -0,0 +1,176 @@
import threading
from typing import Optional
import pytz
from flask_login import current_user # type: ignore
import contexts
from core.app.app_config.easy_ui_based_app.agent.manager import AgentConfigManager
from core.plugin.manager.agent import PluginAgentManager
from core.plugin.manager.exc import PluginDaemonClientSideError
from core.tools.tool_manager import ToolManager
from extensions.ext_database import db
from models.account import Account
from models.model import App, Conversation, EndUser, Message, MessageAgentThought
class AgentService:
    """Read-side services for agent apps: execution logs and strategy providers."""

    @classmethod
    def get_agent_logs(cls, app_model: App, conversation_id: str, message_id: str) -> dict:
        """
        Service to get agent logs

        Builds a report for one agent message: run metadata (executor,
        timing, token totals), one entry per agent iteration with its tool
        calls and resolved icons, plus the raw thought/observation text.

        Raises ValueError when the conversation or message cannot be found,
        or when the app lacks a model config / agent config.
        """
        # Tool icon resolution below goes through ToolManager, which reads
        # these context variables; seed them for the current request.
        contexts.plugin_tool_providers.set({})
        contexts.plugin_tool_providers_lock.set(threading.Lock())

        conversation: Conversation | None = (
            db.session.query(Conversation)
            .filter(
                Conversation.id == conversation_id,
                Conversation.app_id == app_model.id,
            )
            .first()
        )
        if not conversation:
            raise ValueError(f"Conversation not found: {conversation_id}")

        message: Optional[Message] = (
            db.session.query(Message)
            .filter(
                Message.id == message_id,
                Message.conversation_id == conversation_id,
            )
            .first()
        )
        if not message:
            raise ValueError(f"Message not found: {message_id}")

        agent_thoughts: list[MessageAgentThought] = message.agent_thoughts

        # Resolve the display name of whoever started the conversation
        # (end user or workspace account).
        if conversation.from_end_user_id:
            # only select name field
            executor = (
                db.session.query(EndUser, EndUser.name).filter(EndUser.id == conversation.from_end_user_id).first()
            )
        else:
            executor = (
                db.session.query(Account, Account.name).filter(Account.id == conversation.from_account_id).first()
            )
        if executor:
            executor = executor.name
        else:
            executor = "Unknown"

        # Timestamps are rendered in the requesting user's timezone.
        timezone = pytz.timezone(current_user.timezone)

        app_model_config = app_model.app_model_config
        if not app_model_config:
            raise ValueError("App model config not found")

        result = {
            "meta": {
                "status": "success",
                "executor": executor,
                "start_time": message.created_at.astimezone(timezone).isoformat(),
                "elapsed_time": message.provider_response_latency,
                "total_tokens": message.answer_tokens + message.message_tokens,
                # falls back to the "react" strategy when none is configured
                "agent_mode": app_model_config.agent_mode_dict.get("strategy", "react"),
                "iterations": len(agent_thoughts),
            },
            "iterations": [],
            "files": message.message_files,
        }

        agent_config = AgentConfigManager.convert(app_model_config.to_dict())
        if not agent_config:
            raise ValueError("Agent config not found")

        agent_tools = agent_config.tools or []

        def find_agent_tool(tool_name: str):
            # Returns the configured agent tool entity, or None (implicit)
            # when the tool is not part of the app's agent configuration.
            for agent_tool in agent_tools:
                if agent_tool.tool_name == tool_name:
                    return agent_tool

        for agent_thought in agent_thoughts:
            tools = agent_thought.tools
            tool_labels = agent_thought.tool_labels
            tool_meta = agent_thought.tool_meta
            tool_inputs = agent_thought.tool_inputs_dict
            tool_outputs = agent_thought.tool_outputs_dict or {}
            tool_calls = []
            for tool in tools:
                tool_name = tool
                tool_label = tool_labels.get(tool_name, tool_name)
                tool_input = tool_inputs.get(tool_name, {})
                tool_output = tool_outputs.get(tool_name, {})
                tool_meta_data = tool_meta.get(tool_name, {})
                tool_config = tool_meta_data.get("tool_config", {})
                if tool_config.get("tool_provider_type", "") != "dataset-retrieval":
                    tool_icon = ToolManager.get_tool_icon(
                        tenant_id=app_model.tenant_id,
                        provider_type=tool_config.get("tool_provider_type", ""),
                        provider_id=tool_config.get("tool_provider", ""),
                    )
                    if not tool_icon:
                        # Fall back to the icon of the configured agent tool.
                        tool_entity = find_agent_tool(tool_name)
                        if tool_entity:
                            tool_icon = ToolManager.get_tool_icon(
                                tenant_id=app_model.tenant_id,
                                provider_type=tool_entity.provider_type,
                                provider_id=tool_entity.provider_id,
                            )
                else:
                    # dataset-retrieval tools carry no provider icon
                    tool_icon = ""
                tool_calls.append(
                    {
                        "status": "success" if not tool_meta_data.get("error") else "error",
                        "error": tool_meta_data.get("error"),
                        "time_cost": tool_meta_data.get("time_cost", 0),
                        "tool_name": tool_name,
                        "tool_label": tool_label,
                        "tool_input": tool_input,
                        "tool_output": tool_output,
                        "tool_parameters": tool_meta_data.get("tool_parameters", {}),
                        "tool_icon": tool_icon,
                    }
                )
            result["iterations"].append(
                {
                    "tokens": agent_thought.tokens,
                    "tool_calls": tool_calls,
                    "tool_raw": {
                        "inputs": agent_thought.tool_input,
                        "outputs": agent_thought.observation,
                    },
                    "thought": agent_thought.thought,
                    "created_at": agent_thought.created_at.isoformat(),
                    "files": agent_thought.files,
                }
            )
        return result

    @classmethod
    def list_agent_providers(cls, user_id: str, tenant_id: str):
        """
        List agent providers

        Fetches the agent strategy providers available to the tenant from
        the plugin daemon.
        """
        manager = PluginAgentManager()
        return manager.fetch_agent_strategy_providers(tenant_id)

    @classmethod
    def get_agent_provider(cls, user_id: str, tenant_id: str, provider_name: str):
        """
        Get agent provider

        Raises ValueError (chained) when the plugin daemon reports a
        client-side error.
        """
        manager = PluginAgentManager()
        try:
            return manager.fetch_agent_strategy_provider(tenant_id, provider_name)
        except PluginDaemonClientSideError as e:
            raise ValueError(str(e)) from e

View File

@@ -0,0 +1,444 @@
import datetime
import uuid
from typing import cast
import pandas as pd
from flask_login import current_user # type: ignore
from sqlalchemy import or_
from werkzeug.datastructures import FileStorage
from werkzeug.exceptions import NotFound
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from models.model import App, AppAnnotationHitHistory, AppAnnotationSetting, Message, MessageAnnotation
from services.feature_service import FeatureService
from tasks.annotation.add_annotation_to_index_task import add_annotation_to_index_task
from tasks.annotation.batch_import_annotations_task import batch_import_annotations_task
from tasks.annotation.delete_annotation_index_task import delete_annotation_index_task
from tasks.annotation.disable_annotation_reply_task import disable_annotation_reply_task
from tasks.annotation.enable_annotation_reply_task import enable_annotation_reply_task
from tasks.annotation.update_annotation_to_index_task import update_annotation_to_index_task
class AppAnnotationService:
    """Manage message annotations for an app: CRUD, batch import/export,
    hit-history tracking, and keeping the annotation-reply vector index
    in sync via async tasks.
    """

    @classmethod
    def _get_app(cls, app_id: str) -> App:
        """Return the current tenant's app in "normal" status.

        Centralizes the tenant-scoped lookup that every public method
        performed inline.

        Raises:
            NotFound: if the app is missing, owned by another tenant, or
                not in "normal" status.
        """
        app = (
            db.session.query(App)
            .filter(App.id == app_id, App.tenant_id == current_user.current_tenant_id, App.status == "normal")
            .first()
        )
        if not app:
            raise NotFound("App not found")
        return app

    @classmethod
    def up_insert_app_annotation_from_message(cls, args: dict, app_id: str) -> MessageAnnotation:
        """Create or update an annotation, optionally bound to a message.

        When ``args["message_id"]`` is present the annotation is attached
        to that message (updating the existing one if any); otherwise a
        standalone annotation is created. If annotation reply is enabled,
        the annotation is also pushed to the vector index asynchronously.
        """
        app = cls._get_app(app_id)
        if args.get("message_id"):
            message_id = str(args["message_id"])
            # get message info
            message = db.session.query(Message).filter(Message.id == message_id, Message.app_id == app.id).first()
            if not message:
                raise NotFound("Message Not Exists.")
            annotation = message.annotation
            # save the message annotation
            if annotation:
                annotation.content = args["answer"]
                annotation.question = args["question"]
            else:
                annotation = MessageAnnotation(
                    app_id=app.id,
                    conversation_id=message.conversation_id,
                    message_id=message.id,
                    content=args["answer"],
                    question=args["question"],
                    account_id=current_user.id,
                )
        else:
            annotation = MessageAnnotation(
                app_id=app.id, content=args["answer"], question=args["question"], account_id=current_user.id
            )
        db.session.add(annotation)
        db.session.commit()
        # if annotation reply is enabled, mirror the annotation into the index
        annotation_setting = (
            db.session.query(AppAnnotationSetting).filter(AppAnnotationSetting.app_id == app_id).first()
        )
        if annotation_setting:
            add_annotation_to_index_task.delay(
                annotation.id,
                args["question"],
                current_user.current_tenant_id,
                app_id,
                annotation_setting.collection_binding_id,
            )
        return cast(MessageAnnotation, annotation)

    @classmethod
    def enable_app_annotation(cls, args: dict, app_id: str) -> dict:
        """Kick off the async job that enables annotation reply for an app.

        Returns the in-flight job if one is already being processed.
        """
        enable_app_annotation_key = "enable_app_annotation_{}".format(str(app_id))
        cache_result = redis_client.get(enable_app_annotation_key)
        if cache_result is not None:
            return {"job_id": cache_result, "job_status": "processing"}
        # async job
        job_id = str(uuid.uuid4())
        enable_app_annotation_job_key = "enable_app_annotation_job_{}".format(str(job_id))
        # mark the job as waiting before dispatching the task
        redis_client.setnx(enable_app_annotation_job_key, "waiting")
        enable_annotation_reply_task.delay(
            str(job_id),
            app_id,
            current_user.id,
            current_user.current_tenant_id,
            args["score_threshold"],
            args["embedding_provider_name"],
            args["embedding_model_name"],
        )
        return {"job_id": job_id, "job_status": "waiting"}

    @classmethod
    def disable_app_annotation(cls, app_id: str) -> dict:
        """Kick off the async job that disables annotation reply for an app.

        Returns the in-flight job if one is already being processed.
        """
        disable_app_annotation_key = "disable_app_annotation_{}".format(str(app_id))
        cache_result = redis_client.get(disable_app_annotation_key)
        if cache_result is not None:
            return {"job_id": cache_result, "job_status": "processing"}
        # async job
        job_id = str(uuid.uuid4())
        disable_app_annotation_job_key = "disable_app_annotation_job_{}".format(str(job_id))
        # mark the job as waiting before dispatching the task
        redis_client.setnx(disable_app_annotation_job_key, "waiting")
        disable_annotation_reply_task.delay(str(job_id), app_id, current_user.current_tenant_id)
        return {"job_id": job_id, "job_status": "waiting"}

    @classmethod
    def get_annotation_list_by_app_id(cls, app_id: str, page: int, limit: int, keyword: str):
        """Return one page of annotations as ``(items, total)``.

        When *keyword* is non-empty, filters case-insensitively over both
        question and answer text.
        """
        cls._get_app(app_id)
        query = MessageAnnotation.query.filter(MessageAnnotation.app_id == app_id)
        if keyword:
            query = query.filter(
                or_(
                    MessageAnnotation.question.ilike("%{}%".format(keyword)),
                    MessageAnnotation.content.ilike("%{}%".format(keyword)),
                )
            )
        annotations = query.order_by(MessageAnnotation.created_at.desc(), MessageAnnotation.id.desc()).paginate(
            page=page, per_page=limit, max_per_page=100, error_out=False
        )
        return annotations.items, annotations.total

    @classmethod
    def export_annotation_list_by_app_id(cls, app_id: str):
        """Return every annotation of the app, newest first."""
        cls._get_app(app_id)
        annotations = (
            db.session.query(MessageAnnotation)
            .filter(MessageAnnotation.app_id == app_id)
            .order_by(MessageAnnotation.created_at.desc())
            .all()
        )
        return annotations

    @classmethod
    def insert_app_annotation_directly(cls, args: dict, app_id: str) -> MessageAnnotation:
        """Create a standalone annotation and, if annotation reply is
        enabled, index it asynchronously."""
        app = cls._get_app(app_id)
        annotation = MessageAnnotation(
            app_id=app.id, content=args["answer"], question=args["question"], account_id=current_user.id
        )
        db.session.add(annotation)
        db.session.commit()
        # if annotation reply is enabled, mirror the annotation into the index
        annotation_setting = (
            db.session.query(AppAnnotationSetting).filter(AppAnnotationSetting.app_id == app_id).first()
        )
        if annotation_setting:
            add_annotation_to_index_task.delay(
                annotation.id,
                args["question"],
                current_user.current_tenant_id,
                app_id,
                annotation_setting.collection_binding_id,
            )
        return annotation

    @classmethod
    def update_app_annotation_directly(cls, args: dict, app_id: str, annotation_id: str):
        """Update an annotation's question/answer and re-index it when
        annotation reply is enabled."""
        cls._get_app(app_id)
        annotation = db.session.query(MessageAnnotation).filter(MessageAnnotation.id == annotation_id).first()
        if not annotation:
            raise NotFound("Annotation not found")
        annotation.content = args["answer"]
        annotation.question = args["question"]
        db.session.commit()
        # if annotation reply is enabled, update the index entry
        app_annotation_setting = (
            db.session.query(AppAnnotationSetting).filter(AppAnnotationSetting.app_id == app_id).first()
        )
        if app_annotation_setting:
            update_annotation_to_index_task.delay(
                annotation.id,
                annotation.question,
                current_user.current_tenant_id,
                app_id,
                app_annotation_setting.collection_binding_id,
            )
        return annotation

    @classmethod
    def delete_app_annotation(cls, app_id: str, annotation_id: str):
        """Delete an annotation, its hit histories, and (if annotation
        reply is enabled) its index entry."""
        cls._get_app(app_id)
        annotation = db.session.query(MessageAnnotation).filter(MessageAnnotation.id == annotation_id).first()
        if not annotation:
            raise NotFound("Annotation not found")
        db.session.delete(annotation)
        annotation_hit_histories = (
            db.session.query(AppAnnotationHitHistory)
            .filter(AppAnnotationHitHistory.annotation_id == annotation_id)
            .all()
        )
        if annotation_hit_histories:
            for annotation_hit_history in annotation_hit_histories:
                db.session.delete(annotation_hit_history)
        db.session.commit()
        # if annotation reply is enabled, delete the index entry
        app_annotation_setting = (
            db.session.query(AppAnnotationSetting).filter(AppAnnotationSetting.app_id == app_id).first()
        )
        if app_annotation_setting:
            delete_annotation_index_task.delay(
                annotation.id, app_id, current_user.current_tenant_id, app_annotation_setting.collection_binding_id
            )

    @classmethod
    def batch_import_app_annotations(cls, app_id, file: FileStorage) -> dict:
        """Parse a CSV of (question, answer) rows and dispatch the batch
        import task.

        Returns ``{"error_msg": ...}`` on any failure, otherwise the
        queued job descriptor.
        """
        cls._get_app(app_id)
        try:
            # pandas treats the first CSV row as the header, so it is skipped
            df = pd.read_csv(file)
            result = []
            for _, row in df.iterrows():
                content = {"question": row.iloc[0], "answer": row.iloc[1]}
                result.append(content)
            if len(result) == 0:
                raise ValueError("The CSV file is empty.")
            # check annotation quota before queuing the import
            features = FeatureService.get_features(current_user.current_tenant_id)
            if features.billing.enabled:
                annotation_quota_limit = features.annotation_quota_limit
                if annotation_quota_limit.limit < len(result) + annotation_quota_limit.size:
                    raise ValueError("The number of annotations exceeds the limit of your subscription.")
            # async job
            job_id = str(uuid.uuid4())
            indexing_cache_key = "app_annotation_batch_import_{}".format(str(job_id))
            # mark the job as waiting before dispatching the task
            redis_client.setnx(indexing_cache_key, "waiting")
            batch_import_annotations_task.delay(
                str(job_id), result, app_id, current_user.current_tenant_id, current_user.id
            )
        except Exception as e:
            return {"error_msg": str(e)}
        return {"job_id": job_id, "job_status": "waiting"}

    @classmethod
    def get_annotation_hit_histories(cls, app_id: str, annotation_id: str, page, limit):
        """Return one page of hit histories for an annotation as
        ``(items, total)``, newest first."""
        cls._get_app(app_id)
        annotation = db.session.query(MessageAnnotation).filter(MessageAnnotation.id == annotation_id).first()
        if not annotation:
            raise NotFound("Annotation not found")
        annotation_hit_histories = (
            AppAnnotationHitHistory.query.filter(
                AppAnnotationHitHistory.app_id == app_id,
                AppAnnotationHitHistory.annotation_id == annotation_id,
            )
            .order_by(AppAnnotationHitHistory.created_at.desc())
            .paginate(page=page, per_page=limit, max_per_page=100, error_out=False)
        )
        return annotation_hit_histories.items, annotation_hit_histories.total

    @classmethod
    def get_annotation_by_id(cls, annotation_id: str) -> MessageAnnotation | None:
        """Return the annotation by id, or None when it does not exist."""
        annotation = db.session.query(MessageAnnotation).filter(MessageAnnotation.id == annotation_id).first()
        if not annotation:
            return None
        return annotation

    @classmethod
    def add_annotation_history(
        cls,
        annotation_id: str,
        app_id: str,
        annotation_question: str,
        annotation_content: str,
        query: str,
        user_id: str,
        message_id: str,
        from_source: str,
        score: float,
    ):
        """Record that an annotation was matched for *query* and bump its
        hit counter atomically."""
        # add hit count to annotation without loading the row
        db.session.query(MessageAnnotation).filter(MessageAnnotation.id == annotation_id).update(
            {MessageAnnotation.hit_count: MessageAnnotation.hit_count + 1}, synchronize_session=False
        )
        annotation_hit_history = AppAnnotationHitHistory(
            annotation_id=annotation_id,
            app_id=app_id,
            account_id=user_id,
            question=query,
            source=from_source,
            score=score,
            message_id=message_id,
            annotation_question=annotation_question,
            annotation_content=annotation_content,
        )
        db.session.add(annotation_hit_history)
        db.session.commit()

    @classmethod
    def get_app_annotation_setting_by_app_id(cls, app_id: str):
        """Return the app's annotation-reply settings, or
        ``{"enabled": False}`` when the feature is off."""
        cls._get_app(app_id)
        annotation_setting = (
            db.session.query(AppAnnotationSetting).filter(AppAnnotationSetting.app_id == app_id).first()
        )
        if annotation_setting:
            collection_binding_detail = annotation_setting.collection_binding_detail
            return {
                "id": annotation_setting.id,
                "enabled": True,
                "score_threshold": annotation_setting.score_threshold,
                "embedding_model": {
                    "embedding_provider_name": collection_binding_detail.provider_name,
                    "embedding_model_name": collection_binding_detail.model_name,
                },
            }
        return {"enabled": False}

    @classmethod
    def update_app_annotation_setting(cls, app_id: str, annotation_setting_id: str, args: dict):
        """Update the score threshold of the app's annotation-reply
        settings and return the refreshed settings payload."""
        cls._get_app(app_id)
        annotation_setting = (
            db.session.query(AppAnnotationSetting)
            .filter(
                AppAnnotationSetting.app_id == app_id,
                AppAnnotationSetting.id == annotation_setting_id,
            )
            .first()
        )
        if not annotation_setting:
            raise NotFound("App annotation not found")
        annotation_setting.score_threshold = args["score_threshold"]
        annotation_setting.updated_user_id = current_user.id
        # stored naive in UTC, matching the rest of the schema
        annotation_setting.updated_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None)
        db.session.add(annotation_setting)
        db.session.commit()
        collection_binding_detail = annotation_setting.collection_binding_detail
        return {
            "id": annotation_setting.id,
            "enabled": True,
            "score_threshold": annotation_setting.score_threshold,
            "embedding_model": {
                "embedding_provider_name": collection_binding_detail.provider_name,
                "embedding_model_name": collection_binding_detail.model_name,
            },
        }

View File

@@ -0,0 +1,105 @@
from core.extension.api_based_extension_requestor import APIBasedExtensionRequestor
from core.helper.encrypter import decrypt_token, encrypt_token
from extensions.ext_database import db
from models.api_based_extension import APIBasedExtension, APIBasedExtensionPoint
class APIBasedExtensionService:
    """Persistence and validation helpers for API-based extensions.

    API keys are stored encrypted; read paths decrypt them before
    returning records to the caller.
    """

    @staticmethod
    def get_all_by_tenant_id(tenant_id: str) -> list[APIBasedExtension]:
        """Return every extension of the tenant, newest first, with the
        API key decrypted in place."""
        records = (
            db.session.query(APIBasedExtension)
            .filter_by(tenant_id=tenant_id)
            .order_by(APIBasedExtension.created_at.desc())
            .all()
        )
        for record in records:
            record.api_key = decrypt_token(record.tenant_id, record.api_key)
        return records

    @classmethod
    def save(cls, extension_data: APIBasedExtension) -> APIBasedExtension:
        """Validate the extension, encrypt its API key, and persist it."""
        cls._validation(extension_data)
        extension_data.api_key = encrypt_token(extension_data.tenant_id, extension_data.api_key)
        db.session.add(extension_data)
        db.session.commit()
        return extension_data

    @staticmethod
    def delete(extension_data: APIBasedExtension) -> None:
        """Remove the extension record."""
        db.session.delete(extension_data)
        db.session.commit()

    @staticmethod
    def get_with_tenant_id(tenant_id: str, api_based_extension_id: str) -> APIBasedExtension:
        """Fetch one extension scoped to the tenant, decrypting its key.

        Raises ValueError when no matching record exists.
        """
        record = (
            db.session.query(APIBasedExtension)
            .filter_by(tenant_id=tenant_id)
            .filter_by(id=api_based_extension_id)
            .first()
        )
        if not record:
            raise ValueError("API based extension is not found")
        record.api_key = decrypt_token(record.tenant_id, record.api_key)
        return record

    @classmethod
    def _validation(cls, extension_data: APIBasedExtension) -> None:
        """Check required fields, per-tenant name uniqueness, API key
        length, and that the endpoint answers the ping probe."""
        # name
        if not extension_data.name:
            raise ValueError("name must not be empty")

        # name must be unique per tenant; when updating, exclude the
        # record itself from the duplicate check
        duplicate_query = (
            db.session.query(APIBasedExtension)
            .filter_by(tenant_id=extension_data.tenant_id)
            .filter_by(name=extension_data.name)
        )
        if extension_data.id:
            duplicate_query = duplicate_query.filter(APIBasedExtension.id != extension_data.id)
        if duplicate_query.first():
            raise ValueError("name must be unique, it is already existed")

        # api_endpoint
        if not extension_data.api_endpoint:
            raise ValueError("api_endpoint must not be empty")

        # api_key
        if not extension_data.api_key:
            raise ValueError("api_key must not be empty")
        if len(extension_data.api_key) < 5:
            raise ValueError("api_key must be at least 5 characters")

        # check endpoint liveness
        cls._ping_connection(extension_data)

    @staticmethod
    def _ping_connection(extension_data: APIBasedExtension) -> None:
        """Probe the extension's ping endpoint; any failure (transport or
        a non-"pong" reply) surfaces as ValueError."""
        try:
            requestor = APIBasedExtensionRequestor(extension_data.api_endpoint, extension_data.api_key)
            response = requestor.request(point=APIBasedExtensionPoint.PING, params={})
            if response.get("result") != "pong":
                raise ValueError(response)
        except Exception as exc:
            raise ValueError("connection error: {}".format(exc))

View File

@@ -0,0 +1,722 @@
import logging
import uuid
from collections.abc import Mapping
from enum import StrEnum
from typing import Optional
from urllib.parse import urlparse
from uuid import uuid4
import yaml # type: ignore
from packaging import version
from pydantic import BaseModel, Field
from sqlalchemy import select
from sqlalchemy.orm import Session
from core.helper import ssrf_proxy
from core.model_runtime.utils.encoders import jsonable_encoder
from core.plugin.entities.plugin import PluginDependency
from core.workflow.nodes.enums import NodeType
from core.workflow.nodes.knowledge_retrieval.entities import KnowledgeRetrievalNodeData
from core.workflow.nodes.llm.entities import LLMNodeData
from core.workflow.nodes.parameter_extractor.entities import ParameterExtractorNodeData
from core.workflow.nodes.question_classifier.entities import QuestionClassifierNodeData
from core.workflow.nodes.tool.entities import ToolNodeData
from events.app_event import app_model_config_was_updated, app_was_created
from extensions.ext_redis import redis_client
from factories import variable_factory
from models import Account, App, AppMode
from models.model import AppModelConfig
from models.workflow import Workflow
from services.plugin.dependencies_analysis import DependenciesAnalysisService
from services.workflow_service import WorkflowService
logger = logging.getLogger(__name__)

# Redis key prefixes under which pending import state and
# dependency-check state are stored.
IMPORT_INFO_REDIS_KEY_PREFIX = "app_import_info:"
CHECK_DEPENDENCIES_REDIS_KEY_PREFIX = "app_check_dependencies:"
IMPORT_INFO_REDIS_EXPIRY = 10 * 60  # 10 minutes
DSL_MAX_SIZE = 10 * 1024 * 1024  # 10MB
# DSL schema version this code emits and compares imports against.
CURRENT_DSL_VERSION = "0.1.5"
class ImportMode(StrEnum):
    """How the DSL payload is supplied to the import endpoint."""

    YAML_CONTENT = "yaml-content"
    YAML_URL = "yaml-url"
class ImportStatus(StrEnum):
    """Outcome of an app-DSL import attempt."""

    COMPLETED = "completed"
    COMPLETED_WITH_WARNINGS = "completed-with-warnings"
    # PENDING means the import needs explicit user confirmation
    PENDING = "pending"
    FAILED = "failed"
class Import(BaseModel):
    """Response payload describing one import attempt."""

    id: str
    status: ImportStatus
    # id of the created/updated app; None until the import completes
    app_id: Optional[str] = None
    current_dsl_version: str = CURRENT_DSL_VERSION
    # version declared in the imported DSL, "" if not yet parsed
    imported_dsl_version: str = ""
    error: str = ""
class CheckDependenciesResult(BaseModel):
    """Plugin dependencies that are required but not installed."""

    leaked_dependencies: list[PluginDependency] = Field(default_factory=list)
def _check_version_compatibility(imported_version: str) -> ImportStatus:
    """Determine import status based on version comparison.

    Unparsable version -> FAILED; differing major or minor -> PENDING
    (requires confirmation); differing patch only -> COMPLETED_WITH_WARNINGS;
    exact major.minor.patch match -> COMPLETED.
    """
    try:
        ours = version.parse(CURRENT_DSL_VERSION)
        theirs = version.parse(imported_version)
    except version.InvalidVersion:
        return ImportStatus.FAILED

    # Major/minor must match for a silent import.
    if (ours.major, ours.minor) != (theirs.major, theirs.minor):
        return ImportStatus.PENDING
    if ours.micro != theirs.micro:
        return ImportStatus.COMPLETED_WITH_WARNINGS
    return ImportStatus.COMPLETED
class PendingData(BaseModel):
    """Import arguments stashed in Redis while awaiting user confirmation."""

    import_mode: str
    yaml_content: str
    name: str | None
    description: str | None
    icon_type: str | None
    icon: str | None
    icon_background: str | None
    # app to overwrite; None when creating a new app
    app_id: str | None
class CheckDependenciesPendingData(BaseModel):
    """Dependency list stashed in Redis for a later dependency check."""

    dependencies: list[PluginDependency]
    app_id: str | None
class AppDslService:
def __init__(self, session: Session):
self._session = session
def import_app(
self,
*,
account: Account,
import_mode: str,
yaml_content: Optional[str] = None,
yaml_url: Optional[str] = None,
name: Optional[str] = None,
description: Optional[str] = None,
icon_type: Optional[str] = None,
icon: Optional[str] = None,
icon_background: Optional[str] = None,
app_id: Optional[str] = None,
) -> Import:
"""Import an app from YAML content or URL."""
import_id = str(uuid.uuid4())
# Validate import mode
try:
mode = ImportMode(import_mode)
except ValueError:
raise ValueError(f"Invalid import_mode: {import_mode}")
# Get YAML content
content: str = ""
if mode == ImportMode.YAML_URL:
if not yaml_url:
return Import(
id=import_id,
status=ImportStatus.FAILED,
error="yaml_url is required when import_mode is yaml-url",
)
try:
parsed_url = urlparse(yaml_url)
if (
parsed_url.scheme == "https"
and parsed_url.netloc == "github.com"
and parsed_url.path.endswith((".yml", ".yaml"))
):
yaml_url = yaml_url.replace("https://github.com", "https://raw.githubusercontent.com")
yaml_url = yaml_url.replace("/blob/", "/")
response = ssrf_proxy.get(yaml_url.strip(), follow_redirects=True, timeout=(10, 10))
response.raise_for_status()
content = response.content.decode()
if len(content) > DSL_MAX_SIZE:
return Import(
id=import_id,
status=ImportStatus.FAILED,
error="File size exceeds the limit of 10MB",
)
if not content:
return Import(
id=import_id,
status=ImportStatus.FAILED,
error="Empty content from url",
)
except Exception as e:
return Import(
id=import_id,
status=ImportStatus.FAILED,
error=f"Error fetching YAML from URL: {str(e)}",
)
elif mode == ImportMode.YAML_CONTENT:
if not yaml_content:
return Import(
id=import_id,
status=ImportStatus.FAILED,
error="yaml_content is required when import_mode is yaml-content",
)
content = yaml_content
# Process YAML content
try:
# Parse YAML to validate format
data = yaml.safe_load(content)
if not isinstance(data, dict):
return Import(
id=import_id,
status=ImportStatus.FAILED,
error="Invalid YAML format: content must be a mapping",
)
# Validate and fix DSL version
if not data.get("version"):
data["version"] = "0.1.0"
if not data.get("kind") or data.get("kind") != "app":
data["kind"] = "app"
imported_version = data.get("version", "0.1.0")
# check if imported_version is a float-like string
if not isinstance(imported_version, str):
raise ValueError(f"Invalid version type, expected str, got {type(imported_version)}")
status = _check_version_compatibility(imported_version)
# Extract app data
app_data = data.get("app")
if not app_data:
return Import(
id=import_id,
status=ImportStatus.FAILED,
error="Missing app data in YAML content",
)
# If app_id is provided, check if it exists
app = None
if app_id:
stmt = select(App).where(App.id == app_id, App.tenant_id == account.current_tenant_id)
app = self._session.scalar(stmt)
if not app:
return Import(
id=import_id,
status=ImportStatus.FAILED,
error="App not found",
)
if app.mode not in [AppMode.WORKFLOW.value, AppMode.ADVANCED_CHAT.value]:
return Import(
id=import_id,
status=ImportStatus.FAILED,
error="Only workflow or advanced chat apps can be overwritten",
)
# If major version mismatch, store import info in Redis
if status == ImportStatus.PENDING:
pending_data = PendingData(
import_mode=import_mode,
yaml_content=content,
name=name,
description=description,
icon_type=icon_type,
icon=icon,
icon_background=icon_background,
app_id=app_id,
)
redis_client.setex(
f"{IMPORT_INFO_REDIS_KEY_PREFIX}{import_id}",
IMPORT_INFO_REDIS_EXPIRY,
pending_data.model_dump_json(),
)
return Import(
id=import_id,
status=status,
app_id=app_id,
imported_dsl_version=imported_version,
)
# Extract dependencies
dependencies = data.get("dependencies", [])
check_dependencies_pending_data = None
if dependencies:
check_dependencies_pending_data = [PluginDependency.model_validate(d) for d in dependencies]
elif imported_version <= "0.1.5":
if "workflow" in data:
graph = data.get("workflow", {}).get("graph", {})
dependencies_list = self._extract_dependencies_from_workflow_graph(graph)
else:
dependencies_list = self._extract_dependencies_from_model_config(data.get("model_config", {}))
check_dependencies_pending_data = DependenciesAnalysisService.generate_latest_dependencies(
dependencies_list
)
# Create or update app
app = self._create_or_update_app(
app=app,
data=data,
account=account,
name=name,
description=description,
icon_type=icon_type,
icon=icon,
icon_background=icon_background,
dependencies=check_dependencies_pending_data,
)
return Import(
id=import_id,
status=status,
app_id=app.id,
imported_dsl_version=imported_version,
)
except yaml.YAMLError as e:
return Import(
id=import_id,
status=ImportStatus.FAILED,
error=f"Invalid YAML format: {str(e)}",
)
except Exception as e:
logger.exception("Failed to import app")
return Import(
id=import_id,
status=ImportStatus.FAILED,
error=str(e),
)
    def confirm_import(self, *, import_id: str, account: Account) -> Import:
        """
        Confirm an import that requires confirmation.

        Replays the import payload that was stashed in Redis under
        ``import_id`` by the original import call, creates or overwrites the
        app for the caller's tenant, and clears the Redis entry on success.

        :param import_id: id returned by the pending import
        :param account: confirming account; scopes the app lookup to its tenant
        :return: COMPLETED Import on success, FAILED when the pending payload
            expired, is malformed, or the create/update raised
        """
        redis_key = f"{IMPORT_INFO_REDIS_KEY_PREFIX}{import_id}"
        pending_data = redis_client.get(redis_key)
        if not pending_data:
            # The setex TTL elapsed or the id is unknown - nothing to confirm.
            return Import(
                id=import_id,
                status=ImportStatus.FAILED,
                error="Import information expired or does not exist",
            )
        try:
            if not isinstance(pending_data, str | bytes):
                # Defensive: the redis client is expected to return str/bytes.
                return Import(
                    id=import_id,
                    status=ImportStatus.FAILED,
                    error="Invalid import information",
                )
            pending_data = PendingData.model_validate_json(pending_data)
            data = yaml.safe_load(pending_data.yaml_content)
            app = None
            if pending_data.app_id:
                # Overwrite flow: only match an app inside the caller's tenant.
                stmt = select(App).where(App.id == pending_data.app_id, App.tenant_id == account.current_tenant_id)
                app = self._session.scalar(stmt)
            # Create or update app
            app = self._create_or_update_app(
                app=app,
                data=data,
                account=account,
                name=pending_data.name,
                description=pending_data.description,
                icon_type=pending_data.icon_type,
                icon=pending_data.icon,
                icon_background=pending_data.icon_background,
            )
            # Delete import info from Redis
            redis_client.delete(redis_key)
            return Import(
                id=import_id,
                status=ImportStatus.COMPLETED,
                app_id=app.id,
                current_dsl_version=CURRENT_DSL_VERSION,
                imported_dsl_version=data.get("version", "0.1.0"),
            )
        except Exception as e:
            logger.exception("Error confirming import")
            return Import(
                id=import_id,
                status=ImportStatus.FAILED,
                error=str(e),
            )
def check_dependencies(
self,
*,
app_model: App,
) -> CheckDependenciesResult:
"""Check dependencies"""
# Get dependencies from Redis
redis_key = f"{CHECK_DEPENDENCIES_REDIS_KEY_PREFIX}{app_model.id}"
dependencies = redis_client.get(redis_key)
if not dependencies:
return CheckDependenciesResult()
# Extract dependencies
dependencies = CheckDependenciesPendingData.model_validate_json(dependencies)
# Get leaked dependencies
leaked_dependencies = DependenciesAnalysisService.get_leaked_dependencies(
tenant_id=app_model.tenant_id, dependencies=dependencies.dependencies
)
return CheckDependenciesResult(
leaked_dependencies=leaked_dependencies,
)
    def _create_or_update_app(
        self,
        *,
        app: Optional[App],
        data: dict,
        account: Account,
        name: Optional[str] = None,
        description: Optional[str] = None,
        icon_type: Optional[str] = None,
        icon: Optional[str] = None,
        icon_background: Optional[str] = None,
        dependencies: Optional[list[PluginDependency]] = None,
    ) -> App:
        """Create a new app or update an existing one.

        Applies the DSL payload *data* to *app* (or to a freshly created App
        when *app* is None), then initializes the mode-specific configuration:
        a draft workflow for workflow/advanced-chat apps, or an AppModelConfig
        for chat/agent-chat/completion apps.  Explicit keyword overrides
        (name, icon, ...) take precedence over the values embedded in the DSL.

        :param app: existing app to overwrite, or None to create a new one
        :param data: parsed DSL dict; expects "app" plus "workflow" or "model_config"
        :param dependencies: plugin dependencies to cache in Redis for later checks
        :raises ValueError: on a missing/unknown app mode or missing mode data
        """
        app_data = data.get("app", {})
        app_mode = app_data.get("mode")
        if not app_mode:
            raise ValueError("loss app mode")
        app_mode = AppMode(app_mode)
        # Set icon type; anything other than emoji/link falls back to emoji.
        icon_type_value = icon_type or app_data.get("icon_type")
        if icon_type_value in ["emoji", "link"]:
            icon_type = icon_type_value
        else:
            icon_type = "emoji"
        icon = icon or str(app_data.get("icon", ""))
        if app:
            # Update existing app in place; commit happens on session flush
            # elsewhere (only the create branch commits here).
            app.name = name or app_data.get("name", app.name)
            app.description = description or app_data.get("description", app.description)
            app.icon_type = icon_type
            app.icon = icon
            app.icon_background = icon_background or app_data.get("icon_background", app.icon_background)
            app.updated_by = account.id
        else:
            if account.current_tenant_id is None:
                raise ValueError("Current tenant is not set")
            # Create new app
            app = App()
            app.id = str(uuid4())
            app.tenant_id = account.current_tenant_id
            app.mode = app_mode.value
            app.name = name or app_data.get("name", "")
            app.description = description or app_data.get("description", "")
            app.icon_type = icon_type
            app.icon = icon
            app.icon_background = icon_background or app_data.get("icon_background", "#FFFFFF")
            app.enable_site = True
            app.enable_api = True
            app.use_icon_as_answer_icon = app_data.get("use_icon_as_answer_icon", False)
            app.created_by = account.id
            app.updated_by = account.id
            self._session.add(app)
            self._session.commit()
            app_was_created.send(app, account=account)
        # save dependencies so a later check_dependencies call can inspect them
        if dependencies:
            redis_client.setex(
                f"{CHECK_DEPENDENCIES_REDIS_KEY_PREFIX}{app.id}",
                IMPORT_INFO_REDIS_EXPIRY,
                CheckDependenciesPendingData(app_id=app.id, dependencies=dependencies).model_dump_json(),
            )
        # Initialize app based on mode
        if app_mode in {AppMode.ADVANCED_CHAT, AppMode.WORKFLOW}:
            workflow_data = data.get("workflow")
            if not workflow_data or not isinstance(workflow_data, dict):
                raise ValueError("Missing workflow data for workflow/advanced chat app")
            environment_variables_list = workflow_data.get("environment_variables", [])
            environment_variables = [
                variable_factory.build_environment_variable_from_mapping(obj) for obj in environment_variables_list
            ]
            conversation_variables_list = workflow_data.get("conversation_variables", [])
            conversation_variables = [
                variable_factory.build_conversation_variable_from_mapping(obj) for obj in conversation_variables_list
            ]
            workflow_service = WorkflowService()
            current_draft_workflow = workflow_service.get_draft_workflow(app_model=app)
            if current_draft_workflow:
                # Reuse the current draft's hash so sync updates it in place.
                unique_hash = current_draft_workflow.unique_hash
            else:
                unique_hash = None
            workflow_service.sync_draft_workflow(
                app_model=app,
                graph=workflow_data.get("graph", {}),
                features=workflow_data.get("features", {}),
                unique_hash=unique_hash,
                account=account,
                environment_variables=environment_variables,
                conversation_variables=conversation_variables,
            )
        elif app_mode in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.COMPLETION}:
            # Initialize model config
            model_config = data.get("model_config")
            if not model_config or not isinstance(model_config, dict):
                raise ValueError("Missing model_config for chat/agent-chat/completion app")
            # Initialize or update model config
            # NOTE(review): only the missing-config case writes a config; an
            # existing app_model_config is left untouched on overwrite -
            # confirm this is intentional.
            if not app.app_model_config:
                app_model_config = AppModelConfig().from_model_config_dict(model_config)
                app_model_config.id = str(uuid4())
                app_model_config.app_id = app.id
                app_model_config.created_by = account.id
                app_model_config.updated_by = account.id
                app.app_model_config_id = app_model_config.id
                self._session.add(app_model_config)
                app_model_config_was_updated.send(app, app_model_config=app_model_config)
        else:
            raise ValueError("Invalid app mode")
        return app
@classmethod
def export_dsl(cls, app_model: App, include_secret: bool = False) -> str:
"""
Export app
:param app_model: App instance
:return:
"""
app_mode = AppMode.value_of(app_model.mode)
export_data = {
"version": CURRENT_DSL_VERSION,
"kind": "app",
"app": {
"name": app_model.name,
"mode": app_model.mode,
"icon": "🤖" if app_model.icon_type == "image" else app_model.icon,
"icon_background": "#FFEAD5" if app_model.icon_type == "image" else app_model.icon_background,
"description": app_model.description,
"use_icon_as_answer_icon": app_model.use_icon_as_answer_icon,
},
}
if app_mode in {AppMode.ADVANCED_CHAT, AppMode.WORKFLOW}:
cls._append_workflow_export_data(
export_data=export_data, app_model=app_model, include_secret=include_secret
)
else:
cls._append_model_config_export_data(export_data, app_model)
return yaml.dump(export_data, allow_unicode=True) # type: ignore
@classmethod
def _append_workflow_export_data(cls, *, export_data: dict, app_model: App, include_secret: bool) -> None:
"""
Append workflow export data
:param export_data: export data
:param app_model: App instance
"""
workflow_service = WorkflowService()
workflow = workflow_service.get_draft_workflow(app_model)
if not workflow:
raise ValueError("Missing draft workflow configuration, please check.")
export_data["workflow"] = workflow.to_dict(include_secret=include_secret)
dependencies = cls._extract_dependencies_from_workflow(workflow)
export_data["dependencies"] = [
jsonable_encoder(d.model_dump())
for d in DependenciesAnalysisService.generate_dependencies(
tenant_id=app_model.tenant_id, dependencies=dependencies
)
]
@classmethod
def _append_model_config_export_data(cls, export_data: dict, app_model: App) -> None:
"""
Append model config export data
:param export_data: export data
:param app_model: App instance
"""
app_model_config = app_model.app_model_config
if not app_model_config:
raise ValueError("Missing app configuration, please check.")
export_data["model_config"] = app_model_config.to_dict()
dependencies = cls._extract_dependencies_from_model_config(app_model_config.to_dict())
export_data["dependencies"] = [
jsonable_encoder(d.model_dump())
for d in DependenciesAnalysisService.generate_dependencies(
tenant_id=app_model.tenant_id, dependencies=dependencies
)
]
@classmethod
def _extract_dependencies_from_workflow(cls, workflow: Workflow) -> list[str]:
"""
Extract dependencies from workflow
:param workflow: Workflow instance
:return: dependencies list format like ["langgenius/google"]
"""
graph = workflow.graph_dict
dependencies = cls._extract_dependencies_from_workflow_graph(graph)
return dependencies
    @classmethod
    def _extract_dependencies_from_workflow_graph(cls, graph: Mapping) -> list[str]:
        """
        Extract dependencies from workflow graph.

        Walks every node in the graph; node types that reference an external
        tool or model provider contribute a dependency identifier.  Failures
        on a single node are logged and skipped so one malformed node cannot
        abort the whole extraction.

        :param graph: Workflow graph
        :return: dependencies list format like ["langgenius/google"]
        """
        dependencies = []
        for node in graph.get("nodes", []):
            try:
                typ = node.get("data", {}).get("type")
                match typ:
                    case NodeType.TOOL.value:
                        tool_entity = ToolNodeData(**node["data"])
                        dependencies.append(
                            DependenciesAnalysisService.analyze_tool_dependency(tool_entity.provider_id),
                        )
                    case NodeType.LLM.value:
                        llm_entity = LLMNodeData(**node["data"])
                        dependencies.append(
                            DependenciesAnalysisService.analyze_model_provider_dependency(llm_entity.model.provider),
                        )
                    case NodeType.QUESTION_CLASSIFIER.value:
                        question_classifier_entity = QuestionClassifierNodeData(**node["data"])
                        dependencies.append(
                            DependenciesAnalysisService.analyze_model_provider_dependency(
                                question_classifier_entity.model.provider
                            ),
                        )
                    case NodeType.PARAMETER_EXTRACTOR.value:
                        parameter_extractor_entity = ParameterExtractorNodeData(**node["data"])
                        dependencies.append(
                            DependenciesAnalysisService.analyze_model_provider_dependency(
                                parameter_extractor_entity.model.provider
                            ),
                        )
                    case NodeType.KNOWLEDGE_RETRIEVAL.value:
                        # Knowledge retrieval can pull in a reranking model or
                        # embedding provider (multiple mode) or the retrieval
                        # model itself (single mode).
                        knowledge_retrieval_entity = KnowledgeRetrievalNodeData(**node["data"])
                        if knowledge_retrieval_entity.retrieval_mode == "multiple":
                            if knowledge_retrieval_entity.multiple_retrieval_config:
                                if (
                                    knowledge_retrieval_entity.multiple_retrieval_config.reranking_mode
                                    == "reranking_model"
                                ):
                                    if knowledge_retrieval_entity.multiple_retrieval_config.reranking_model:
                                        dependencies.append(
                                            DependenciesAnalysisService.analyze_model_provider_dependency(
                                                knowledge_retrieval_entity.multiple_retrieval_config.reranking_model.provider
                                            ),
                                        )
                                elif (
                                    knowledge_retrieval_entity.multiple_retrieval_config.reranking_mode
                                    == "weighted_score"
                                ):
                                    if knowledge_retrieval_entity.multiple_retrieval_config.weights:
                                        vector_setting = (
                                            knowledge_retrieval_entity.multiple_retrieval_config.weights.vector_setting
                                        )
                                        dependencies.append(
                                            DependenciesAnalysisService.analyze_model_provider_dependency(
                                                vector_setting.embedding_provider_name
                                            ),
                                        )
                        elif knowledge_retrieval_entity.retrieval_mode == "single":
                            model_config = knowledge_retrieval_entity.single_retrieval_config
                            if model_config:
                                dependencies.append(
                                    DependenciesAnalysisService.analyze_model_provider_dependency(
                                        model_config.model.provider
                                    ),
                                )
                    case _:
                        # TODO: Handle default case or unknown node types
                        pass
            except Exception as e:
                logger.exception("Error extracting node dependency", exc_info=e)
        return dependencies
    @classmethod
    def _extract_dependencies_from_model_config(cls, model_config: Mapping) -> list[str]:
        """
        Extract dependencies from model config.

        Collects the completion model provider, any dataset reranking model
        providers, and agent tool providers.  Any failure is logged and an
        empty/partial list is returned rather than raising.

        :param model_config: model config dict
        :return: dependencies list format like ["langgenius/google"]
        """
        dependencies = []
        try:
            # completion model
            model_dict = model_config.get("model", {})
            if model_dict:
                dependencies.append(
                    DependenciesAnalysisService.analyze_model_provider_dependency(model_dict.get("provider", ""))
                )
            # reranking model
            dataset_configs = model_config.get("dataset_configs", {})
            if dataset_configs:
                for dataset_config in dataset_configs.get("datasets", {}).get("datasets", []):
                    if dataset_config.get("reranking_model"):
                        # NOTE(review): the chained .get() treats
                        # "reranking_provider_name" as a nested dict; if the
                        # schema stores it as a plain string this raises and is
                        # swallowed by the except below - confirm the schema.
                        dependencies.append(
                            DependenciesAnalysisService.analyze_model_provider_dependency(
                                dataset_config.get("reranking_model", {})
                                .get("reranking_provider_name", {})
                                .get("provider")
                            )
                        )
            # tools
            agent_configs = model_config.get("agent_mode", {})
            if agent_configs:
                for agent_config in agent_configs.get("tools", []):
                    dependencies.append(
                        DependenciesAnalysisService.analyze_tool_dependency(agent_config.get("provider_id"))
                    )
        except Exception as e:
            logger.exception("Error extracting model config dependency", exc_info=e)
        return dependencies
@classmethod
def get_leaked_dependencies(cls, tenant_id: str, dsl_dependencies: list[dict]) -> list[PluginDependency]:
"""
Returns the leaked dependencies in current workspace
"""
dependencies = [PluginDependency(**dep) for dep in dsl_dependencies]
if not dependencies:
return []
return DependenciesAnalysisService.get_leaked_dependencies(tenant_id=tenant_id, dependencies=dependencies)

View File

@@ -0,0 +1,203 @@
from collections.abc import Generator, Mapping
from typing import Any, Union
from openai._exceptions import RateLimitError
from configs import dify_config
from core.app.apps.advanced_chat.app_generator import AdvancedChatAppGenerator
from core.app.apps.agent_chat.app_generator import AgentChatAppGenerator
from core.app.apps.chat.app_generator import ChatAppGenerator
from core.app.apps.completion.app_generator import CompletionAppGenerator
from core.app.apps.workflow.app_generator import WorkflowAppGenerator
from core.app.entities.app_invoke_entities import InvokeFrom
from core.app.features.rate_limiting import RateLimit
from models.model import Account, App, AppMode, EndUser
from models.workflow import Workflow
from services.errors.llm import InvokeRateLimitError
from services.workflow_service import WorkflowService
class AppGenerateService:
    """Entry point for app content generation.

    Routes a generation request to the generator matching the app's mode and
    wraps every call in a per-app active-request rate limit.
    """

    @classmethod
    def generate(
        cls,
        app_model: App,
        user: Union[Account, EndUser],
        args: Mapping[str, Any],
        invoke_from: InvokeFrom,
        streaming: bool = True,
    ):
        """
        App Content Generate
        :param app_model: app model
        :param user: user
        :param args: args
        :param invoke_from: invoke from
        :param streaming: streaming
        :return:
        """
        max_active_request = AppGenerateService._get_max_active_requests(app_model)
        rate_limit = RateLimit(app_model.id, max_active_request)
        request_id = RateLimit.gen_request_key()
        try:
            # NOTE(review): enter() appears to register this request against
            # the app's active-request quota - confirm whether it can raise
            # RateLimitError itself or only the model calls below do.
            request_id = rate_limit.enter(request_id)
            if app_model.mode == AppMode.COMPLETION.value:
                return rate_limit.generate(
                    CompletionAppGenerator.convert_to_event_stream(
                        CompletionAppGenerator().generate(
                            app_model=app_model, user=user, args=args, invoke_from=invoke_from, streaming=streaming
                        ),
                    ),
                    request_id=request_id,
                )
            elif app_model.mode == AppMode.AGENT_CHAT.value or app_model.is_agent:
                return rate_limit.generate(
                    AgentChatAppGenerator.convert_to_event_stream(
                        AgentChatAppGenerator().generate(
                            app_model=app_model, user=user, args=args, invoke_from=invoke_from, streaming=streaming
                        ),
                    ),
                    request_id,
                )
            elif app_model.mode == AppMode.CHAT.value:
                return rate_limit.generate(
                    ChatAppGenerator.convert_to_event_stream(
                        ChatAppGenerator().generate(
                            app_model=app_model, user=user, args=args, invoke_from=invoke_from, streaming=streaming
                        ),
                    ),
                    request_id=request_id,
                )
            elif app_model.mode == AppMode.ADVANCED_CHAT.value:
                workflow = cls._get_workflow(app_model, invoke_from)
                return rate_limit.generate(
                    AdvancedChatAppGenerator.convert_to_event_stream(
                        AdvancedChatAppGenerator().generate(
                            app_model=app_model,
                            workflow=workflow,
                            user=user,
                            args=args,
                            invoke_from=invoke_from,
                            streaming=streaming,
                        ),
                    ),
                    request_id=request_id,
                )
            elif app_model.mode == AppMode.WORKFLOW.value:
                workflow = cls._get_workflow(app_model, invoke_from)
                return rate_limit.generate(
                    WorkflowAppGenerator.convert_to_event_stream(
                        WorkflowAppGenerator().generate(
                            app_model=app_model,
                            workflow=workflow,
                            user=user,
                            args=args,
                            invoke_from=invoke_from,
                            streaming=streaming,
                            call_depth=0,
                            workflow_thread_pool_id=None,
                        ),
                    ),
                    request_id,
                )
            else:
                raise ValueError(f"Invalid app mode {app_model.mode}")
        except RateLimitError as e:
            # Translate the openai rate-limit error into the service-level one.
            raise InvokeRateLimitError(str(e))
        except Exception:
            rate_limit.exit(request_id)
            raise
        finally:
            # NOTE(review): for streaming responses the slot is presumably
            # released by rate_limit.generate when the stream finishes; only
            # non-streaming requests release here - confirm.
            if not streaming:
                rate_limit.exit(request_id)

    @staticmethod
    def _get_max_active_requests(app_model: App) -> int:
        # Per-app limit wins; fall back to the global config value when unset.
        max_active_requests = app_model.max_active_requests
        if max_active_requests is None:
            max_active_requests = int(dify_config.APP_MAX_ACTIVE_REQUESTS)
        return max_active_requests

    @classmethod
    def generate_single_iteration(cls, app_model: App, user: Account, node_id: str, args: Any, streaming: bool = True):
        """Run a single iteration node in the app's draft workflow (debugger only)."""
        if app_model.mode == AppMode.ADVANCED_CHAT.value:
            workflow = cls._get_workflow(app_model, InvokeFrom.DEBUGGER)
            return AdvancedChatAppGenerator.convert_to_event_stream(
                AdvancedChatAppGenerator().single_iteration_generate(
                    app_model=app_model, workflow=workflow, node_id=node_id, user=user, args=args, streaming=streaming
                )
            )
        elif app_model.mode == AppMode.WORKFLOW.value:
            # NOTE(review): this wraps a WorkflowAppGenerator stream with
            # AdvancedChatAppGenerator.convert_to_event_stream - presumably a
            # shared base classmethod so behavior is identical, but confirm.
            workflow = cls._get_workflow(app_model, InvokeFrom.DEBUGGER)
            return AdvancedChatAppGenerator.convert_to_event_stream(
                WorkflowAppGenerator().single_iteration_generate(
                    app_model=app_model, workflow=workflow, node_id=node_id, user=user, args=args, streaming=streaming
                )
            )
        else:
            raise ValueError(f"Invalid app mode {app_model.mode}")

    @classmethod
    def generate_single_loop(cls, app_model: App, user: Account, node_id: str, args: Any, streaming: bool = True):
        """Run a single loop node in the app's draft workflow (debugger only)."""
        if app_model.mode == AppMode.ADVANCED_CHAT.value:
            workflow = cls._get_workflow(app_model, InvokeFrom.DEBUGGER)
            return AdvancedChatAppGenerator.convert_to_event_stream(
                AdvancedChatAppGenerator().single_loop_generate(
                    app_model=app_model, workflow=workflow, node_id=node_id, user=user, args=args, streaming=streaming
                )
            )
        elif app_model.mode == AppMode.WORKFLOW.value:
            # NOTE(review): same AdvancedChat/Workflow wrapper mix as in
            # generate_single_iteration - confirm intent.
            workflow = cls._get_workflow(app_model, InvokeFrom.DEBUGGER)
            return AdvancedChatAppGenerator.convert_to_event_stream(
                WorkflowAppGenerator().single_loop_generate(
                    app_model=app_model, workflow=workflow, node_id=node_id, user=user, args=args, streaming=streaming
                )
            )
        else:
            raise ValueError(f"Invalid app mode {app_model.mode}")

    @classmethod
    def generate_more_like_this(
        cls,
        app_model: App,
        user: Union[Account, EndUser],
        message_id: str,
        invoke_from: InvokeFrom,
        streaming: bool = True,
    ) -> Union[Mapping, Generator]:
        """
        Generate more like this
        :param app_model: app model
        :param user: user
        :param message_id: message id
        :param invoke_from: invoke from
        :param streaming: streaming
        :return:
        """
        return CompletionAppGenerator().generate_more_like_this(
            app_model=app_model, message_id=message_id, user=user, invoke_from=invoke_from, stream=streaming
        )

    @classmethod
    def _get_workflow(cls, app_model: App, invoke_from: InvokeFrom) -> Workflow:
        """
        Get workflow
        :param app_model: app model
        :param invoke_from: invoke from
        :return:
        :raises ValueError: when the required draft/published workflow is missing
        """
        workflow_service = WorkflowService()
        if invoke_from == InvokeFrom.DEBUGGER:
            # fetch draft workflow by app_model
            workflow = workflow_service.get_draft_workflow(app_model=app_model)
            if not workflow:
                raise ValueError("Workflow not initialized")
        else:
            # fetch published workflow by app_model
            workflow = workflow_service.get_published_workflow(app_model=app_model)
            if not workflow:
                raise ValueError("Workflow not published")
        return workflow

View File

@@ -0,0 +1,17 @@
from core.app.apps.agent_chat.app_config_manager import AgentChatAppConfigManager
from core.app.apps.chat.app_config_manager import ChatAppConfigManager
from core.app.apps.completion.app_config_manager import CompletionAppConfigManager
from models.model import AppMode
class AppModelConfigService:
    """Validates an app model configuration according to the app's mode."""

    @classmethod
    def validate_configuration(cls, tenant_id: str, config: dict, app_mode: AppMode) -> dict:
        """Validate *config* with the manager matching *app_mode*.

        :raises ValueError: for app modes without a config manager
        """
        managers = {
            AppMode.CHAT: ChatAppConfigManager,
            AppMode.AGENT_CHAT: AgentChatAppConfigManager,
            AppMode.COMPLETION: CompletionAppConfigManager,
        }
        manager = managers.get(app_mode)
        if manager is None:
            raise ValueError(f"Invalid app mode: {app_mode}")
        return manager.config_validate(tenant_id, config)

View File

@@ -0,0 +1,376 @@
import json
import logging
from datetime import UTC, datetime
from typing import Optional, cast
from flask_login import current_user # type: ignore
from flask_sqlalchemy.pagination import Pagination
from configs import dify_config
from constants.model_template import default_app_templates
from core.agent.entities import AgentToolEntity
from core.app.features.rate_limiting import RateLimit
from core.errors.error import LLMBadRequestError, ProviderTokenNotInitError
from core.model_manager import ModelManager
from core.model_runtime.entities.model_entities import ModelPropertyKey, ModelType
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
from core.tools.tool_manager import ToolManager
from core.tools.utils.configuration import ToolParameterConfigurationManager
from events.app_event import app_was_created
from extensions.ext_database import db
from models.account import Account
from models.model import App, AppMode, AppModelConfig
from models.tools import ApiToolProvider
from services.tag_service import TagService
from tasks.remove_app_and_related_data_task import remove_app_and_related_data_task
class AppService:
    """CRUD, configuration and metadata operations for apps in a workspace."""

    def get_paginate_apps(self, user_id: str, tenant_id: str, args: dict) -> Pagination | None:
        """
        Get app list with pagination
        :param user_id: user id
        :param tenant_id: tenant id
        :param args: request args
        :return: Pagination of apps, or None when the tag filter matches nothing
        """
        filters = [App.tenant_id == tenant_id, App.is_universal == False]
        if args["mode"] == "workflow":
            filters.append(App.mode.in_([AppMode.WORKFLOW.value, AppMode.COMPLETION.value]))
        elif args["mode"] == "chat":
            filters.append(App.mode.in_([AppMode.CHAT.value, AppMode.ADVANCED_CHAT.value]))
        elif args["mode"] == "agent-chat":
            filters.append(App.mode == AppMode.AGENT_CHAT.value)
        elif args["mode"] == "channel":
            filters.append(App.mode == AppMode.CHANNEL.value)
        if args.get("is_created_by_me", False):
            filters.append(App.created_by == user_id)
        if args.get("name"):
            # Truncate to keep the ILIKE pattern bounded.
            name = args["name"][:30]
            filters.append(App.name.ilike(f"%{name}%"))
        if args.get("tag_ids"):
            target_ids = TagService.get_target_ids_by_tag_ids("app", tenant_id, args["tag_ids"])
            if target_ids:
                filters.append(App.id.in_(target_ids))
            else:
                # Tag filter matched no apps - no page to return.
                return None
        app_models = db.paginate(
            db.select(App).where(*filters).order_by(App.created_at.desc()),
            page=args["page"],
            per_page=args["limit"],
            error_out=False,
        )
        return app_models

    def create_app(self, tenant_id: str, args: dict, account: Account) -> App:
        """
        Create app
        :param tenant_id: tenant id
        :param args: request args
        :param account: Account instance
        """
        app_mode = AppMode.value_of(args["mode"])
        app_template = default_app_templates[app_mode]
        # get model config
        default_model_config = app_template.get("model_config")
        default_model_config = default_model_config.copy() if default_model_config else None
        if default_model_config and "model" in default_model_config:
            # get model provider
            model_manager = ModelManager()
            # get default model instance
            try:
                model_instance = model_manager.get_default_model_instance(
                    tenant_id=account.current_tenant_id or "", model_type=ModelType.LLM
                )
            except (ProviderTokenNotInitError, LLMBadRequestError):
                # No usable provider configured yet - fall back below.
                model_instance = None
            except Exception:
                # Lazy %-formatting: the message is only built if actually logged.
                logging.exception("Get default model instance failed, tenant_id: %s", tenant_id)
                model_instance = None
            if model_instance:
                if (
                    model_instance.model == default_model_config["model"]["name"]
                    and model_instance.provider == default_model_config["model"]["provider"]
                ):
                    # Template already points at the workspace default model.
                    default_model_dict = default_model_config["model"]
                else:
                    llm_model = cast(LargeLanguageModel, model_instance.model_type_instance)
                    model_schema = llm_model.get_model_schema(model_instance.model, model_instance.credentials)
                    if model_schema is None:
                        raise ValueError(f"model schema not found for model {model_instance.model}")
                    default_model_dict = {
                        "provider": model_instance.provider,
                        "name": model_instance.model,
                        "mode": model_schema.model_properties.get(ModelPropertyKey.MODE),
                        "completion_params": {},
                    }
            else:
                provider, model = model_manager.get_default_provider_model_name(
                    tenant_id=account.current_tenant_id or "", model_type=ModelType.LLM
                )
                default_model_config["model"]["provider"] = provider
                default_model_config["model"]["name"] = model
                default_model_dict = default_model_config["model"]
            # AppModelConfig stores the model section as a JSON string.
            default_model_config["model"] = json.dumps(default_model_dict)
        app = App(**app_template["app"])
        app.name = args["name"]
        app.description = args.get("description", "")
        app.mode = args["mode"]
        app.icon_type = args.get("icon_type", "emoji")
        app.icon = args["icon"]
        app.icon_background = args["icon_background"]
        app.tenant_id = tenant_id
        app.api_rph = args.get("api_rph", 0)
        app.api_rpm = args.get("api_rpm", 0)
        app.created_by = account.id
        app.updated_by = account.id
        db.session.add(app)
        # Flush to obtain app.id before wiring up the model config.
        db.session.flush()
        if default_model_config:
            app_model_config = AppModelConfig(**default_model_config)
            app_model_config.app_id = app.id
            app_model_config.created_by = account.id
            app_model_config.updated_by = account.id
            db.session.add(app_model_config)
            db.session.flush()
            app.app_model_config_id = app_model_config.id
        db.session.commit()
        app_was_created.send(app, account=account)
        return app

    def get_app(self, app: App) -> App:
        """
        Get App.

        For agent apps, decrypts and masks secret agent-tool parameters so the
        returned config never exposes raw secrets; the app's stored config is
        not modified (a shadow object overrides app_model_config).
        """
        # get original app model config
        if app.mode == AppMode.AGENT_CHAT.value or app.is_agent:
            model_config = app.app_model_config
            agent_mode = model_config.agent_mode_dict
            # decrypt agent tool parameters if it's secret-input
            for tool in agent_mode.get("tools") or []:
                if not isinstance(tool, dict) or len(tool.keys()) <= 3:
                    # Not a full tool entry (legacy/partial format) - skip.
                    continue
                agent_tool_entity = AgentToolEntity(**tool)
                # get tool
                try:
                    tool_runtime = ToolManager.get_agent_tool_runtime(
                        tenant_id=current_user.current_tenant_id,
                        app_id=app.id,
                        agent_tool=agent_tool_entity,
                    )
                    manager = ToolParameterConfigurationManager(
                        tenant_id=current_user.current_tenant_id,
                        tool_runtime=tool_runtime,
                        provider_name=agent_tool_entity.provider_id,
                        provider_type=agent_tool_entity.provider_type,
                        identity_id=f"AGENT.{app.id}",
                    )
                    # get decrypted parameters
                    if agent_tool_entity.tool_parameters:
                        parameters = manager.decrypt_tool_parameters(agent_tool_entity.tool_parameters or {})
                        masked_parameter = manager.mask_tool_parameters(parameters or {})
                    else:
                        masked_parameter = {}
                    # override tool parameters
                    tool["tool_parameters"] = masked_parameter
                except Exception:
                    # Best-effort: a tool that fails to resolve keeps its raw
                    # parameters rather than failing the whole app fetch.
                    logging.exception("Failed to decrypt agent tool parameters, app_id: %s", app.id)
            # override agent mode
            model_config.agent_mode = json.dumps(agent_mode)

            class ModifiedApp(App):
                """
                Modified App class
                """

                def __init__(self, app):
                    self.__dict__.update(app.__dict__)

                @property
                def app_model_config(self):
                    # Serve the masked config instead of the stored one.
                    return model_config

            app = ModifiedApp(app)
        return app

    def update_app(self, app: App, args: dict) -> App:
        """
        Update app
        :param app: App instance
        :param args: request args
        :return: App instance
        """
        app.name = args.get("name")
        app.description = args.get("description", "")
        app.max_active_requests = args.get("max_active_requests")
        app.icon_type = args.get("icon_type", "emoji")
        app.icon = args.get("icon")
        app.icon_background = args.get("icon_background")
        app.use_icon_as_answer_icon = args.get("use_icon_as_answer_icon", False)
        app.updated_by = current_user.id
        app.updated_at = datetime.now(UTC).replace(tzinfo=None)
        db.session.commit()
        if app.max_active_requests is not None:
            # Propagate the new limit to the shared rate-limit cache.
            rate_limit = RateLimit(app.id, app.max_active_requests)
            rate_limit.flush_cache(use_local_value=True)
        return app

    def update_app_name(self, app: App, name: str) -> App:
        """
        Update app name
        :param app: App instance
        :param name: new name
        :return: App instance
        """
        app.name = name
        app.updated_by = current_user.id
        app.updated_at = datetime.now(UTC).replace(tzinfo=None)
        db.session.commit()
        return app

    def update_app_icon(self, app: App, icon: str, icon_background: str) -> App:
        """
        Update app icon
        :param app: App instance
        :param icon: new icon
        :param icon_background: new icon_background
        :return: App instance
        """
        app.icon = icon
        app.icon_background = icon_background
        app.updated_by = current_user.id
        app.updated_at = datetime.now(UTC).replace(tzinfo=None)
        db.session.commit()
        return app

    def update_app_site_status(self, app: App, enable_site: bool) -> App:
        """
        Update app site status
        :param app: App instance
        :param enable_site: enable site status
        :return: App instance
        """
        if enable_site == app.enable_site:
            # No-op: avoid a useless commit/updated_at bump.
            return app
        app.enable_site = enable_site
        app.updated_by = current_user.id
        app.updated_at = datetime.now(UTC).replace(tzinfo=None)
        db.session.commit()
        return app

    def update_app_api_status(self, app: App, enable_api: bool) -> App:
        """
        Update app api status
        :param app: App instance
        :param enable_api: enable api status
        :return: App instance
        """
        if enable_api == app.enable_api:
            # No-op: avoid a useless commit/updated_at bump.
            return app
        app.enable_api = enable_api
        app.updated_by = current_user.id
        app.updated_at = datetime.now(UTC).replace(tzinfo=None)
        db.session.commit()
        return app

    def delete_app(self, app: App) -> None:
        """
        Delete app
        :param app: App instance
        """
        db.session.delete(app)
        db.session.commit()
        # Trigger asynchronous deletion of app and related data
        remove_app_and_related_data_task.delay(tenant_id=app.tenant_id, app_id=app.id)

    def get_app_meta(self, app_model: App) -> dict:
        """
        Get app meta info (currently the icon URL/emoji for each tool used).
        :param app_model: app model
        :return: dict with a "tool_icons" mapping of tool name -> icon
        """
        app_mode = AppMode.value_of(app_model.mode)
        meta: dict = {"tool_icons": {}}
        if app_mode in {AppMode.ADVANCED_CHAT, AppMode.WORKFLOW}:
            workflow = app_model.workflow
            if workflow is None:
                return meta
            graph = workflow.graph_dict
            nodes = graph.get("nodes", [])
            tools = []
            for node in nodes:
                if node.get("data", {}).get("type") == "tool":
                    node_data = node.get("data", {})
                    tools.append(
                        {
                            "provider_type": node_data.get("provider_type"),
                            "provider_id": node_data.get("provider_id"),
                            "tool_name": node_data.get("tool_name"),
                            "tool_parameters": {},
                        }
                    )
        else:
            app_model_config: Optional[AppModelConfig] = app_model.app_model_config
            if not app_model_config:
                return meta
            agent_config = app_model_config.agent_mode_dict
            # get all tools
            tools = agent_config.get("tools", [])
        url_prefix = dify_config.CONSOLE_API_URL + "/console/api/workspaces/current/tool-provider/builtin/"
        for tool in tools:
            keys = list(tool.keys())
            if len(keys) >= 4:
                # current tool standard
                provider_type = tool.get("provider_type", "")
                provider_id = tool.get("provider_id", "")
                tool_name = tool.get("tool_name", "")
                if provider_type == "builtin":
                    meta["tool_icons"][tool_name] = url_prefix + provider_id + "/icon"
                elif provider_type == "api":
                    try:
                        provider: Optional[ApiToolProvider] = (
                            db.session.query(ApiToolProvider).filter(ApiToolProvider.id == provider_id).first()
                        )
                        if provider is None:
                            raise ValueError(f"provider not found for tool {tool_name}")
                        meta["tool_icons"][tool_name] = json.loads(provider.icon)
                    except Exception:
                        # Was a bare `except:` (would swallow SystemExit /
                        # KeyboardInterrupt); fall back to a default icon.
                        meta["tool_icons"][tool_name] = {"background": "#252525", "content": "\ud83d\ude01"}
        return meta

View File

@@ -0,0 +1,161 @@
import io
import logging
import uuid
from typing import Optional
from werkzeug.datastructures import FileStorage
from core.model_manager import ModelManager
from core.model_runtime.entities.model_entities import ModelType
from models.model import App, AppMode, AppModelConfig, Message
from services.errors.audio import (
AudioTooLargeServiceError,
NoAudioUploadedServiceError,
ProviderNotSupportSpeechToTextServiceError,
ProviderNotSupportTextToSpeechServiceError,
UnsupportedAudioTypeServiceError,
)
FILE_SIZE = 30
FILE_SIZE_LIMIT = FILE_SIZE * 1024 * 1024
ALLOWED_EXTENSIONS = ["mp3", "mp4", "mpeg", "mpga", "m4a", "wav", "webm", "amr"]
logger = logging.getLogger(__name__)
class AudioService:
    """Speech-to-text and text-to-speech helpers built on the tenant's
    default SPEECH2TEXT / TTS model providers."""

    @classmethod
    def transcript_asr(cls, app_model: App, file: FileStorage, end_user: Optional[str] = None):
        """Transcribe an uploaded audio file to text.

        :param app_model: app whose config/workflow gates the feature
        :param file: uploaded audio (validated by MIME type and size)
        :param end_user: optional end-user id forwarded to the model provider
        :return: dict with a single "text" key holding the transcript
        :raises ValueError: when speech-to-text is disabled for the app
        :raises NoAudioUploadedServiceError / UnsupportedAudioTypeServiceError /
            AudioTooLargeServiceError: on invalid uploads
        :raises ProviderNotSupportSpeechToTextServiceError: no default model
        """
        # Workflow-based apps keep the feature flag on the workflow;
        # classic apps keep it on the app model config.
        if app_model.mode in {AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value}:
            workflow = app_model.workflow
            if workflow is None:
                raise ValueError("Speech to text is not enabled")
            features_dict = workflow.features_dict
            if "speech_to_text" not in features_dict or not features_dict["speech_to_text"].get("enabled"):
                raise ValueError("Speech to text is not enabled")
        else:
            app_model_config: AppModelConfig = app_model.app_model_config
            if not app_model_config.speech_to_text_dict["enabled"]:
                raise ValueError("Speech to text is not enabled")
        if file is None:
            raise NoAudioUploadedServiceError()
        # NOTE: despite the name, this is the request MIME type (e.g.
        # "audio/mp3"), not a filename extension.
        extension = file.mimetype
        if extension not in [f"audio/{ext}" for ext in ALLOWED_EXTENSIONS]:
            raise UnsupportedAudioTypeServiceError()
        file_content = file.read()
        file_size = len(file_content)
        if file_size > FILE_SIZE_LIMIT:
            message = f"Audio size larger than {FILE_SIZE} mb"
            raise AudioTooLargeServiceError(message)
        model_manager = ModelManager()
        model_instance = model_manager.get_default_model_instance(
            tenant_id=app_model.tenant_id, model_type=ModelType.SPEECH2TEXT
        )
        if model_instance is None:
            raise ProviderNotSupportSpeechToTextServiceError()
        # Wrap bytes in a named buffer: providers typically infer the audio
        # format from the file name.
        buffer = io.BytesIO(file_content)
        buffer.name = "temp.mp3"
        return {"text": model_instance.invoke_speech2text(file=buffer, user=end_user)}

    @classmethod
    def transcript_tts(
        cls,
        app_model: App,
        text: Optional[str] = None,
        voice: Optional[str] = None,
        end_user: Optional[str] = None,
        message_id: Optional[str] = None,
    ):
        """Synthesize speech for either a stored message (``message_id``) or
        raw ``text``.

        Returns a streaming ``audio/mpeg`` Flask Response when the provider
        yields a generator, the raw provider result otherwise, or None when
        ``message_id`` is invalid / the message has nothing to speak yet.
        """
        # Imports are local because this method spins up a Flask app context
        # (it may be called from a background thread).
        from collections.abc import Generator
        from flask import Response, stream_with_context
        from app import app
        from extensions.ext_database import db

        def invoke_tts(text_content: str, app_model: App, voice: Optional[str] = None):
            # Runs inside an app context so DB/config access works when
            # invoked outside a request.
            with app.app_context():
                if app_model.mode in {AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value}:
                    workflow = app_model.workflow
                    if workflow is None:
                        raise ValueError("TTS is not enabled")
                    features_dict = workflow.features_dict
                    if "text_to_speech" not in features_dict or not features_dict["text_to_speech"].get("enabled"):
                        raise ValueError("TTS is not enabled")
                    # Explicit voice argument wins over the configured one.
                    voice = features_dict["text_to_speech"].get("voice") if voice is None else voice
                else:
                    if app_model.app_model_config is None:
                        raise ValueError("AppModelConfig not found")
                    text_to_speech_dict = app_model.app_model_config.text_to_speech_dict
                    if not text_to_speech_dict.get("enabled"):
                        raise ValueError("TTS is not enabled")
                    voice = text_to_speech_dict.get("voice") if voice is None else voice
                model_manager = ModelManager()
                model_instance = model_manager.get_default_model_instance(
                    tenant_id=app_model.tenant_id, model_type=ModelType.TTS
                )
                try:
                    # No voice configured anywhere: fall back to the
                    # provider's first advertised voice.
                    if not voice:
                        voices = model_instance.get_tts_voices()
                        if voices:
                            voice = voices[0].get("value")
                            if not voice:
                                raise ValueError("Sorry, no voice available.")
                        else:
                            raise ValueError("Sorry, no voice available.")
                    return model_instance.invoke_tts(
                        content_text=text_content.strip(), user=end_user, tenant_id=app_model.tenant_id, voice=voice
                    )
                except Exception as e:
                    raise e

        if message_id:
            # Reject malformed ids up front instead of querying with them.
            try:
                uuid.UUID(message_id)
            except ValueError:
                return None
            message = db.session.query(Message).filter(Message.id == message_id).first()
            if message is None:
                return None
            # An empty answer on a still-"normal" message means there is
            # nothing to speak yet.
            if message.answer == "" and message.status == "normal":
                return None
            else:
                response = invoke_tts(message.answer, app_model=app_model, voice=voice)
                if isinstance(response, Generator):
                    return Response(stream_with_context(response), content_type="audio/mpeg")
                return response
        else:
            if text is None:
                raise ValueError("Text is required")
            response = invoke_tts(text, app_model, voice)
            if isinstance(response, Generator):
                return Response(stream_with_context(response), content_type="audio/mpeg")
            return response

    @classmethod
    def transcript_tts_voices(cls, tenant_id: str, language: str):
        """List voices the tenant's default TTS model offers for *language*.

        :raises ProviderNotSupportTextToSpeechServiceError: no default TTS model
        """
        model_manager = ModelManager()
        model_instance = model_manager.get_default_model_instance(tenant_id=tenant_id, model_type=ModelType.TTS)
        if model_instance is None:
            raise ProviderNotSupportTextToSpeechServiceError()
        try:
            return model_instance.get_tts_voices(language)
        except Exception as e:
            raise e

View File

View File

@@ -0,0 +1,10 @@
from abc import ABC, abstractmethod
class ApiKeyAuthBase(ABC):
    """Common base for data-source API-key authentication providers.

    Concrete subclasses receive the raw credential payload and must implement
    ``validate_credentials`` to verify it against their provider.
    """

    def __init__(self, credentials: dict):
        # Keep the raw credential dict as-is; subclasses extract the
        # provider-specific fields (auth_type, config.api_key, ...) themselves.
        self.credentials = credentials

    @abstractmethod
    def validate_credentials(self):
        """Check the stored credentials against the remote provider."""
        raise NotImplementedError

View File

@@ -0,0 +1,25 @@
from services.auth.api_key_auth_base import ApiKeyAuthBase
from services.auth.auth_type import AuthType
class ApiKeyAuthFactory:
    """Resolves and drives the provider-specific API-key validator."""

    def __init__(self, provider: str, credentials: dict):
        # Resolve the concrete auth class once, then instantiate it with the
        # supplied credential payload.
        auth_cls = self.get_apikey_auth_factory(provider)
        self.auth = auth_cls(credentials)

    def validate_credentials(self):
        """Delegate validation to the resolved provider implementation."""
        return self.auth.validate_credentials()

    @staticmethod
    def get_apikey_auth_factory(provider: str) -> type[ApiKeyAuthBase]:
        """Return the ApiKeyAuthBase subclass registered for *provider*.

        Provider modules are imported lazily so their dependencies load only
        when that provider is actually used.
        """
        if provider == AuthType.FIRECRAWL:
            from services.auth.firecrawl.firecrawl import FirecrawlAuth

            return FirecrawlAuth
        if provider == AuthType.JINA:
            from services.auth.jina.jina import JinaAuth

            return JinaAuth
        raise ValueError("Invalid provider")

View File

@@ -0,0 +1,74 @@
import json
from core.helper import encrypter
from extensions.ext_database import db
from models.source import DataSourceApiKeyAuthBinding
from services.auth.api_key_auth_factory import ApiKeyAuthFactory
class ApiKeyAuthService:
    """CRUD + validation for per-tenant data-source API-key bindings."""

    @staticmethod
    def get_provider_auth_list(tenant_id: str) -> list:
        """Return all non-disabled API-key bindings for *tenant_id*."""
        data_source_api_key_bindings = (
            db.session.query(DataSourceApiKeyAuthBinding)
            .filter(DataSourceApiKeyAuthBinding.tenant_id == tenant_id, DataSourceApiKeyAuthBinding.disabled.is_(False))
            .all()
        )
        return data_source_api_key_bindings
    @staticmethod
    def create_provider_auth(tenant_id: str, args: dict):
        """Validate the given credentials against the provider, then persist
        them with the API key encrypted for the tenant.

        NOTE: if validation returns a falsy value nothing is saved; current
        providers raise on failure, so the falsy branch is defensive.
        Mutates ``args["credentials"]`` in place (api_key is replaced by its
        encrypted form).
        """
        auth_result = ApiKeyAuthFactory(args["provider"], args["credentials"]).validate_credentials()
        if auth_result:
            # Encrypt the api key
            api_key = encrypter.encrypt_token(tenant_id, args["credentials"]["config"]["api_key"])
            args["credentials"]["config"]["api_key"] = api_key
            data_source_api_key_binding = DataSourceApiKeyAuthBinding()
            data_source_api_key_binding.tenant_id = tenant_id
            data_source_api_key_binding.category = args["category"]
            data_source_api_key_binding.provider = args["provider"]
            data_source_api_key_binding.credentials = json.dumps(args["credentials"], ensure_ascii=False)
            db.session.add(data_source_api_key_binding)
            db.session.commit()
    @staticmethod
    def get_auth_credentials(tenant_id: str, category: str, provider: str):
        """Return the stored credentials dict for the first active binding
        matching (tenant, category, provider), or None.

        NOTE: the api_key inside is returned as stored, i.e. still
        encrypted; callers must decrypt it themselves.
        """
        data_source_api_key_bindings = (
            db.session.query(DataSourceApiKeyAuthBinding)
            .filter(
                DataSourceApiKeyAuthBinding.tenant_id == tenant_id,
                DataSourceApiKeyAuthBinding.category == category,
                DataSourceApiKeyAuthBinding.provider == provider,
                DataSourceApiKeyAuthBinding.disabled.is_(False),
            )
            .first()
        )
        if not data_source_api_key_bindings:
            return None
        credentials = json.loads(data_source_api_key_bindings.credentials)
        return credentials
    @staticmethod
    def delete_provider_auth(tenant_id: str, binding_id: str):
        """Delete the binding *binding_id* if it belongs to *tenant_id*;
        silently a no-op otherwise."""
        data_source_api_key_binding = (
            db.session.query(DataSourceApiKeyAuthBinding)
            .filter(DataSourceApiKeyAuthBinding.tenant_id == tenant_id, DataSourceApiKeyAuthBinding.id == binding_id)
            .first()
        )
        if data_source_api_key_binding:
            db.session.delete(data_source_api_key_binding)
            db.session.commit()
    @classmethod
    def validate_api_key_auth_args(cls, args):
        """Validate the request payload shape for create_provider_auth.

        :raises ValueError: on any missing/empty/mistyped field
        """
        if "category" not in args or not args["category"]:
            raise ValueError("category is required")
        if "provider" not in args or not args["provider"]:
            raise ValueError("provider is required")
        if "credentials" not in args or not args["credentials"]:
            raise ValueError("credentials is required")
        if not isinstance(args["credentials"], dict):
            raise ValueError("credentials must be a dictionary")
        if "auth_type" not in args["credentials"] or not args["credentials"]["auth_type"]:
            raise ValueError("auth_type is required")

View File

@@ -0,0 +1,6 @@
from enum import StrEnum
class AuthType(StrEnum):
    """Data-source providers that authenticate with an API key.

    Values are the provider identifiers dispatched on by
    ``ApiKeyAuthFactory.get_apikey_auth_factory``.
    """

    FIRECRAWL = "firecrawl"
    JINA = "jinareader"

View File

@@ -0,0 +1,49 @@
import json
import requests
from services.auth.api_key_auth_base import ApiKeyAuthBase
class FirecrawlAuth(ApiKeyAuthBase):
    """Validates Firecrawl API keys by issuing a minimal crawl request."""

    def __init__(self, credentials: dict):
        super().__init__(credentials)
        if credentials.get("auth_type") != "bearer":
            raise ValueError("Invalid auth type, Firecrawl auth type must be Bearer")
        config = credentials.get("config", {})
        self.api_key = config.get("api_key", None)
        # Self-hosted deployments may point at their own endpoint.
        self.base_url = config.get("base_url", "https://api.firecrawl.dev")
        if not self.api_key:
            raise ValueError("No API key provided")

    def validate_credentials(self):
        """Start a one-page crawl of example.com; HTTP 200 means the key works."""
        payload = {
            "url": "https://example.com",
            "includePaths": [],
            "excludePaths": [],
            "limit": 1,
            "scrapeOptions": {"onlyMainContent": True},
        }
        response = self._post_request(f"{self.base_url}/v1/crawl", payload, self._prepare_headers())
        if response.status_code != 200:
            # Always raises with a status-specific message.
            self._handle_error(response)
        return True

    def _prepare_headers(self):
        # Firecrawl expects a standard Bearer token header.
        return {"Content-Type": "application/json", "Authorization": f"Bearer {self.api_key}"}

    def _post_request(self, url, data, headers):
        return requests.post(url, headers=headers, json=data)

    def _handle_error(self, response):
        # Known billing/conflict/server statuses carry a JSON {"error": ...} body.
        if response.status_code in {402, 409, 500}:
            error_message = response.json().get("error", "Unknown error occurred")
            raise Exception(f"Failed to authorize. Status code: {response.status_code}. Error: {error_message}")
        if response.text:
            error_message = json.loads(response.text).get("error", "Unknown error occurred")
            raise Exception(f"Failed to authorize. Status code: {response.status_code}. Error: {error_message}")
        raise Exception(f"Unexpected error occurred while trying to authorize. Status code: {response.status_code}")

View File

@@ -0,0 +1,44 @@
import json
import requests
from services.auth.api_key_auth_base import ApiKeyAuthBase
class JinaAuth(ApiKeyAuthBase):
    """Validates Jina Reader API keys by fetching one page through r.jina.ai."""

    def __init__(self, credentials: dict):
        super().__init__(credentials)
        if credentials.get("auth_type") != "bearer":
            raise ValueError("Invalid auth type, Jina Reader auth type must be Bearer")
        self.api_key = credentials.get("config", {}).get("api_key", None)
        if not self.api_key:
            raise ValueError("No API key provided")

    def validate_credentials(self):
        """Read example.com via the Reader endpoint; HTTP 200 means the key works."""
        payload = {
            "url": "https://example.com",
        }
        response = self._post_request("https://r.jina.ai", payload, self._prepare_headers())
        if response.status_code != 200:
            # Always raises with a status-specific message.
            self._handle_error(response)
        return True

    def _prepare_headers(self):
        # Jina Reader expects a standard Bearer token header.
        return {"Content-Type": "application/json", "Authorization": f"Bearer {self.api_key}"}

    def _post_request(self, url, data, headers):
        return requests.post(url, headers=headers, json=data)

    def _handle_error(self, response):
        # Known billing/conflict/server statuses carry a JSON {"error": ...} body.
        if response.status_code in {402, 409, 500}:
            error_message = response.json().get("error", "Unknown error occurred")
            raise Exception(f"Failed to authorize. Status code: {response.status_code}. Error: {error_message}")
        if response.text:
            error_message = json.loads(response.text).get("error", "Unknown error occurred")
            raise Exception(f"Failed to authorize. Status code: {response.status_code}. Error: {error_message}")
        raise Exception(f"Unexpected error occurred while trying to authorize. Status code: {response.status_code}")

View File

@@ -0,0 +1,44 @@
import json
import requests
from services.auth.api_key_auth_base import ApiKeyAuthBase
class JinaAuth(ApiKeyAuthBase):
    """Credential validator for the Jina Reader data-source provider."""
    def __init__(self, credentials: dict):
        super().__init__(credentials)
        # Only Bearer-token auth is supported.
        auth_type = credentials.get("auth_type")
        if auth_type != "bearer":
            raise ValueError("Invalid auth type, Jina Reader auth type must be Bearer")
        self.api_key = credentials.get("config", {}).get("api_key", None)
        if not self.api_key:
            raise ValueError("No API key provided")
    def validate_credentials(self):
        """Read example.com via r.jina.ai; returns True on HTTP 200,
        otherwise _handle_error raises."""
        headers = self._prepare_headers()
        options = {
            "url": "https://example.com",
        }
        response = self._post_request("https://r.jina.ai", options, headers)
        if response.status_code == 200:
            return True
        else:
            self._handle_error(response)
    def _prepare_headers(self):
        # Standard Bearer token header.
        return {"Content-Type": "application/json", "Authorization": f"Bearer {self.api_key}"}
    def _post_request(self, url, data, headers):
        return requests.post(url, headers=headers, json=data)
    def _handle_error(self, response):
        # 402/409/500 carry a JSON {"error": ...} body; other statuses are
        # parsed from the raw text when present. Always raises.
        if response.status_code in {402, 409, 500}:
            error_message = response.json().get("error", "Unknown error occurred")
            raise Exception(f"Failed to authorize. Status code: {response.status_code}. Error: {error_message}")
        else:
            if response.text:
                error_message = json.loads(response.text).get("error", "Unknown error occurred")
                raise Exception(f"Failed to authorize. Status code: {response.status_code}. Error: {error_message}")
            raise Exception(f"Unexpected error occurred while trying to authorize. Status code: {response.status_code}")

View File

@@ -0,0 +1,133 @@
import os
from typing import Literal, Optional
import httpx
from tenacity import retry, retry_if_exception_type, stop_before_delay, wait_fixed
from extensions.ext_database import db
from libs.helper import RateLimiter
from models.account import TenantAccountJoin, TenantAccountRole
class BillingService:
    """Client for the external billing backend (subscriptions, invoices,
    account lifecycle, compliance downloads)."""

    # Defaults echo the env-var names: placeholder values, not real endpoints.
    base_url = os.environ.get("BILLING_API_URL", "BILLING_API_URL")
    secret_key = os.environ.get("BILLING_API_SECRET_KEY", "BILLING_API_SECRET_KEY")
    # Rate limiter for compliance downloads — presumably 4 requests per 60s
    # window per key; verify against RateLimiter's signature.
    compliance_download_rate_limiter = RateLimiter("compliance_download_rate_limiter", 4, 60)
    @classmethod
    def get_info(cls, tenant_id: str):
        """Fetch the tenant's subscription info."""
        params = {"tenant_id": tenant_id}
        billing_info = cls._send_request("GET", "/subscription/info", params=params)
        return billing_info
    @classmethod
    def get_knowledge_rate_limit(cls, tenant_id: str):
        """Fetch the tenant's knowledge-base rate limit, with sandbox defaults."""
        params = {"tenant_id": tenant_id}
        knowledge_rate_limit = cls._send_request("GET", "/subscription/knowledge-rate-limit", params=params)
        return {
            "limit": knowledge_rate_limit.get("limit", 10),
            "subscription_plan": knowledge_rate_limit.get("subscription_plan", "sandbox"),
        }
    @classmethod
    def get_subscription(cls, plan: str, interval: str, prefilled_email: str = "", tenant_id: str = ""):
        """Get a payment link for subscribing to *plan* at *interval*."""
        params = {"plan": plan, "interval": interval, "prefilled_email": prefilled_email, "tenant_id": tenant_id}
        return cls._send_request("GET", "/subscription/payment-link", params=params)
    @classmethod
    def get_model_provider_payment_link(cls, provider_name: str, tenant_id: str, account_id: str, prefilled_email: str):
        """Get a payment link for a paid model provider."""
        params = {
            "provider_name": provider_name,
            "tenant_id": tenant_id,
            "account_id": account_id,
            "prefilled_email": prefilled_email,
        }
        return cls._send_request("GET", "/model-provider/payment-link", params=params)
    @classmethod
    def get_invoices(cls, prefilled_email: str = "", tenant_id: str = ""):
        """List the tenant's invoices."""
        params = {"prefilled_email": prefilled_email, "tenant_id": tenant_id}
        return cls._send_request("GET", "/invoices", params=params)
    @classmethod
    # Retry transient transport failures every 2s for up to ~10s total.
    @retry(
        wait=wait_fixed(2),
        stop=stop_before_delay(10),
        retry=retry_if_exception_type(httpx.RequestError),
        reraise=True,
    )
    def _send_request(cls, method: Literal["GET", "POST", "DELETE"], endpoint: str, json=None, params=None):
        """Send an authenticated request to the billing API and return the
        parsed JSON body.

        :raises ValueError: for GET requests that do not return HTTP 200
        """
        headers = {"Content-Type": "application/json", "Billing-Api-Secret-Key": cls.secret_key}
        url = f"{cls.base_url}{endpoint}"
        response = httpx.request(method, url, json=json, params=params, headers=headers)
        # Only GETs are checked; POST/DELETE results pass through as-is.
        if method == "GET" and response.status_code != httpx.codes.OK:
            raise ValueError("Unable to retrieve billing information. Please try again later or contact support.")
        return response.json()
    @staticmethod
    def is_tenant_owner_or_admin(current_user):
        """Raise unless *current_user* holds a privileged role in their
        current tenant.

        :raises ValueError: when the join row is missing or the role is not
            privileged
        """
        tenant_id = current_user.current_tenant_id
        join: Optional[TenantAccountJoin] = (
            db.session.query(TenantAccountJoin)
            .filter(TenantAccountJoin.tenant_id == tenant_id, TenantAccountJoin.account_id == current_user.id)
            .first()
        )
        if not join:
            raise ValueError("Tenant account join not found")
        if not TenantAccountRole.is_privileged_role(join.role):
            raise ValueError("Only team owner or team admin can perform this action")
    @classmethod
    def delete_account(cls, account_id: str):
        """Delete account."""
        params = {"account_id": account_id}
        return cls._send_request("DELETE", "/account/", params=params)
    @classmethod
    def is_email_in_freeze(cls, email: str) -> bool:
        """Best-effort check whether *email* is in the deletion freeze window;
        any backend failure is treated as "not frozen"."""
        params = {"email": email}
        try:
            response = cls._send_request("GET", "/account/in-freeze", params=params)
            return bool(response.get("data", False))
        except Exception:
            return False
    @classmethod
    def update_account_deletion_feedback(cls, email: str, feedback: str):
        """Update account deletion feedback."""
        json = {"email": email, "feedback": feedback}
        return cls._send_request("POST", "/account/delete-feedback", json=json)
    @classmethod
    def get_compliance_download_link(
        cls,
        doc_name: str,
        account_id: str,
        tenant_id: str,
        ip: str,
        device_info: str,
    ):
        """Request a compliance-document download link, rate-limited per
        account+tenant; the limiter only counts successful requests."""
        limiter_key = f"{account_id}:{tenant_id}"
        if cls.compliance_download_rate_limiter.is_rate_limited(limiter_key):
            # Local import avoids a controller<->service import cycle.
            # "Compilance" (sic) matches the error class defined elsewhere.
            from controllers.console.error import CompilanceRateLimitError
            raise CompilanceRateLimitError()
        json = {
            "doc_name": doc_name,
            "account_id": account_id,
            "tenant_id": tenant_id,
            "ip_address": ip,
            "device_info": device_info,
        }
        res = cls._send_request("POST", "/compliance/download", json=json)
        cls.compliance_download_rate_limiter.increment_rate_limit(limiter_key)
        return res

View File

@@ -0,0 +1,16 @@
from extensions.ext_code_based_extension import code_based_extension
class CodeBasedExtensionService:
    """Exposes non-builtin code-based extension points for a module."""

    @staticmethod
    def get_code_based_extension(module: str) -> list[dict]:
        """Return name/label/form-schema descriptors for every custom
        (non-builtin) extension registered under *module*."""
        descriptors = []
        for extension in code_based_extension.module_extensions(module):
            # Builtin extensions are not user-configurable; skip them.
            if extension.builtin:
                continue
            descriptors.append(
                {
                    "name": extension.name,
                    "label": extension.label,
                    "form_schema": extension.form_schema,
                }
            )
        return descriptors

View File

@@ -0,0 +1,168 @@
from collections.abc import Callable, Sequence
from datetime import UTC, datetime
from typing import Optional, Union
from sqlalchemy import asc, desc, func, or_, select
from sqlalchemy.orm import Session
from core.app.entities.app_invoke_entities import InvokeFrom
from core.llm_generator.llm_generator import LLMGenerator
from extensions.ext_database import db
from libs.infinite_scroll_pagination import InfiniteScrollPagination
from models.account import Account
from models.model import App, Conversation, EndUser, Message
from services.errors.conversation import ConversationNotExistsError, LastConversationNotExistsError
from services.errors.message import MessageNotExistsError
class ConversationService:
    """Listing, renaming, auto-naming and soft deletion of app conversations."""

    @classmethod
    def pagination_by_last_id(
        cls,
        *,
        session: Session,
        app_model: App,
        user: Optional[Union[Account, EndUser]],
        last_id: Optional[str],
        limit: int,
        invoke_from: InvokeFrom,
        include_ids: Optional[Sequence[str]] = None,
        exclude_ids: Optional[Sequence[str]] = None,
        sort_by: str = "-updated_at",
    ) -> InfiniteScrollPagination:
        """Cursor-paginate a user's conversations for *app_model*.

        :param last_id: id of the last row of the previous page (the cursor);
            None fetches the first page
        :param sort_by: sort field, prefixed with "-" for descending order
        :raises LastConversationNotExistsError: when last_id matches no
            conversation visible to this user
        """
        if not user:
            return InfiniteScrollPagination(data=[], limit=limit, has_more=False)
        # NOTE: SQLAlchemy needs `== False` (not `is False`) to build SQL.
        stmt = select(Conversation).where(
            Conversation.is_deleted == False,
            Conversation.app_id == app_model.id,
            Conversation.from_source == ("api" if isinstance(user, EndUser) else "console"),
            Conversation.from_end_user_id == (user.id if isinstance(user, EndUser) else None),
            Conversation.from_account_id == (user.id if isinstance(user, Account) else None),
            or_(Conversation.invoke_from.is_(None), Conversation.invoke_from == invoke_from.value),
        )
        if include_ids is not None:
            stmt = stmt.where(Conversation.id.in_(include_ids))
        if exclude_ids is not None:
            stmt = stmt.where(~Conversation.id.in_(exclude_ids))
        # define sort fields and directions
        sort_field, sort_direction = cls._get_sort_params(sort_by)
        if last_id:
            last_conversation = session.scalar(stmt.where(Conversation.id == last_id))
            if not last_conversation:
                raise LastConversationNotExistsError()
            # build filters based on sorting: only rows strictly after the cursor
            filter_condition = cls._build_filter_condition(
                sort_field=sort_field,
                sort_direction=sort_direction,
                reference_conversation=last_conversation,
            )
            stmt = stmt.where(filter_condition)
        query_stmt = stmt.order_by(sort_direction(getattr(Conversation, sort_field))).limit(limit)
        conversations = session.scalars(query_stmt).all()
        # A full page may still be the last one: probe whether any rows remain
        # beyond this page's final conversation.
        has_more = False
        if len(conversations) == limit:
            current_page_last_conversation = conversations[-1]
            rest_filter_condition = cls._build_filter_condition(
                sort_field=sort_field,
                sort_direction=sort_direction,
                reference_conversation=current_page_last_conversation,
            )
            count_stmt = select(func.count()).select_from(stmt.where(rest_filter_condition).subquery())
            rest_count = session.scalar(count_stmt) or 0
            if rest_count > 0:
                has_more = True
        return InfiniteScrollPagination(data=conversations, limit=limit, has_more=has_more)

    @classmethod
    def _get_sort_params(cls, sort_by: str):
        """Split a sort spec into (field_name, sqlalchemy direction callable)."""
        if sort_by.startswith("-"):
            return sort_by[1:], desc
        return sort_by, asc

    @classmethod
    def _build_filter_condition(cls, sort_field: str, sort_direction: Callable, reference_conversation: Conversation):
        """Cursor predicate: rows strictly after *reference_conversation* in
        the given sort order."""
        field_value = getattr(reference_conversation, sort_field)
        if sort_direction == desc:
            return getattr(Conversation, sort_field) < field_value
        else:
            return getattr(Conversation, sort_field) > field_value

    @classmethod
    def rename(
        cls,
        app_model: App,
        conversation_id: str,
        user: Optional[Union[Account, EndUser]],
        name: str,
        auto_generate: bool,
    ):
        """Rename a conversation, or regenerate its name via LLM when
        *auto_generate* is set (in which case *name* is ignored)."""
        conversation = cls.get_conversation(app_model, conversation_id, user)
        if auto_generate:
            return cls.auto_generate_name(app_model, conversation)
        else:
            conversation.name = name
            conversation.updated_at = datetime.now(UTC).replace(tzinfo=None)
            db.session.commit()
            return conversation

    @classmethod
    def auto_generate_name(cls, app_model: App, conversation: Conversation):
        """Generate a conversation name from its first message via LLM.

        Name generation is best-effort: on any generation failure the
        existing name is kept.

        :raises MessageNotExistsError: when the conversation has no messages
        """
        # get conversation first message
        message = (
            db.session.query(Message)
            .filter(Message.app_id == app_model.id, Message.conversation_id == conversation.id)
            .order_by(Message.created_at.asc())
            .first()
        )
        if not message:
            raise MessageNotExistsError()
        # generate conversation name
        try:
            name = LLMGenerator.generate_conversation_name(
                app_model.tenant_id, message.query, conversation.id, app_model.id
            )
            conversation.name = name
        except Exception:
            # Fix: was a bare `except:`, which also swallowed SystemExit and
            # KeyboardInterrupt. Generation failures remain non-fatal.
            pass
        db.session.commit()
        return conversation

    @classmethod
    def get_conversation(cls, app_model: App, conversation_id: str, user: Optional[Union[Account, EndUser]]):
        """Fetch a non-deleted conversation owned by *user* within *app_model*.

        :raises ConversationNotExistsError: when no matching row exists
        """
        conversation = (
            db.session.query(Conversation)
            .filter(
                Conversation.id == conversation_id,
                Conversation.app_id == app_model.id,
                Conversation.from_source == ("api" if isinstance(user, EndUser) else "console"),
                Conversation.from_end_user_id == (user.id if isinstance(user, EndUser) else None),
                Conversation.from_account_id == (user.id if isinstance(user, Account) else None),
                Conversation.is_deleted == False,
            )
            .first()
        )
        if not conversation:
            raise ConversationNotExistsError()
        return conversation

    @classmethod
    def delete(cls, app_model: App, conversation_id: str, user: Optional[Union[Account, EndUser]]):
        """Soft-delete a conversation (sets is_deleted; the row is kept)."""
        conversation = cls.get_conversation(app_model, conversation_id, user)
        conversation.is_deleted = True
        conversation.updated_at = datetime.now(UTC).replace(tzinfo=None)
        db.session.commit()

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,20 @@
import os
import requests
class EnterpriseRequest:
    """Thin HTTP client for the enterprise backend API."""

    # Defaults echo the env-var names: placeholder values, not real endpoints.
    base_url = os.environ.get("ENTERPRISE_API_URL", "ENTERPRISE_API_URL")
    secret_key = os.environ.get("ENTERPRISE_API_SECRET_KEY", "ENTERPRISE_API_SECRET_KEY")

    # Empty proxy map: bypass any system-level HTTP(S) proxy configuration.
    proxies = {
        "http": "",
        "https": "",
    }

    @classmethod
    def send_request(cls, method, endpoint, json=None, params=None):
        """Send *method* to ``base_url + endpoint`` and return the parsed JSON body."""
        request_headers = {"Content-Type": "application/json", "Enterprise-Api-Secret-Key": cls.secret_key}
        target_url = f"{cls.base_url}{endpoint}"
        response = requests.request(
            method, target_url, json=json, params=params, headers=request_headers, proxies=cls.proxies
        )
        return response.json()

View File

@@ -0,0 +1,11 @@
from services.enterprise.base import EnterpriseRequest
class EnterpriseService:
    """Queries enterprise-edition settings from the enterprise backend."""

    @classmethod
    def get_info(cls):
        """Fetch global enterprise deployment info."""
        return EnterpriseRequest.send_request("GET", "/info")

    @classmethod
    def get_app_web_sso_enabled(cls, app_code):
        """Fetch the web SSO setting for the app identified by *app_code*."""
        endpoint = f"/app-sso-setting?appCode={app_code}"
        return EnterpriseRequest.send_request("GET", endpoint)

View File

@@ -0,0 +1,26 @@
from typing import Literal, Optional, Union
from pydantic import BaseModel
class AuthorizationConfig(BaseModel):
    """Credential details for an api-key style authorization."""
    # None means no scheme selected; "custom" uses the custom header below.
    type: Literal[None, "basic", "bearer", "custom"]
    api_key: Union[None, str] = None
    # Header name used when type == "custom".
    header: Union[None, str] = None
class Authorization(BaseModel):
    """Authorization mode for an external knowledge API endpoint."""
    type: Literal["no-auth", "api-key"]
    # Present only when type == "api-key".
    config: Optional[AuthorizationConfig] = None
class ProcessStatusSetting(BaseModel):
    """Endpoint used to poll an external processing status."""
    request_method: str
    url: str
class ExternalKnowledgeApiSetting(BaseModel):
    """Connection settings for calling an external knowledge API."""
    url: str
    request_method: str
    headers: Optional[dict] = None
    params: Optional[dict] = None

View File

@@ -0,0 +1,126 @@
from enum import Enum
from typing import Literal, Optional
from pydantic import BaseModel
class SegmentUpdateEntity(BaseModel):
    """Payload for updating an existing dataset segment."""
    content: str
    answer: Optional[str] = None
    keywords: Optional[list[str]] = None
    enabled: Optional[bool] = None
class ParentMode(str, Enum):
    """Granularity of parent chunks in hierarchical segmentation."""
    FULL_DOC = "full-doc"
    PARAGRAPH = "paragraph"
class NotionIcon(BaseModel):
    """Icon attached to a Notion page (URL image or emoji)."""
    type: str
    url: Optional[str] = None
    emoji: Optional[str] = None
class NotionPage(BaseModel):
    """A Notion page selected for import."""
    page_id: str
    page_name: str
    page_icon: Optional[NotionIcon] = None
    type: str
class NotionInfo(BaseModel):
    """Pages to import from a single Notion workspace."""
    workspace_id: str
    pages: list[NotionPage]
class WebsiteInfo(BaseModel):
    """A website-crawl job whose results should be imported."""
    provider: str
    job_id: str
    urls: list[str]
    only_main_content: bool = True
class FileInfo(BaseModel):
    """Uploaded files selected as a data source."""
    file_ids: list[str]
class InfoList(BaseModel):
    """Discriminated data-source payload; exactly one *_info_list is
    expected to match data_source_type."""
    data_source_type: Literal["upload_file", "notion_import", "website_crawl"]
    notion_info_list: Optional[list[NotionInfo]] = None
    file_info_list: Optional[FileInfo] = None
    website_info_list: Optional[WebsiteInfo] = None
class DataSource(BaseModel):
    """Wrapper around the selected data-source info."""
    info_list: InfoList
class PreProcessingRule(BaseModel):
    """Toggle for a named pre-processing step."""
    id: str
    enabled: bool
class Segmentation(BaseModel):
    """Chunking parameters for splitting documents."""
    separator: str = "\n"
    max_tokens: int
    chunk_overlap: int = 0
class Rule(BaseModel):
    """Custom processing rules (used when ProcessRule.mode != automatic)."""
    pre_processing_rules: Optional[list[PreProcessingRule]] = None
    segmentation: Optional[Segmentation] = None
    parent_mode: Optional[Literal["full-doc", "paragraph"]] = None
    subchunk_segmentation: Optional[Segmentation] = None
class ProcessRule(BaseModel):
    """How documents are processed into segments."""
    mode: Literal["automatic", "custom", "hierarchical"]
    rules: Optional[Rule] = None
class RerankingModel(BaseModel):
    """Reranking model selection for retrieval."""
    reranking_provider_name: Optional[str] = None
    reranking_model_name: Optional[str] = None
class RetrievalModel(BaseModel):
    """Retrieval configuration for a dataset."""
    search_method: Literal["hybrid_search", "semantic_search", "full_text_search"]
    reranking_enable: bool
    reranking_model: Optional[RerankingModel] = None
    top_k: int
    score_threshold_enabled: bool
    score_threshold: Optional[float] = None
class MetaDataConfig(BaseModel):
    """Document metadata attached during indexing."""
    doc_type: str
    doc_metadata: dict
class KnowledgeConfig(BaseModel):
    """Top-level configuration for creating/updating a knowledge dataset."""
    # When set, an existing document is being re-indexed/updated.
    original_document_id: Optional[str] = None
    duplicate: bool = True
    indexing_technique: Literal["high_quality", "economy"]
    data_source: Optional[DataSource] = None
    process_rule: Optional[ProcessRule] = None
    retrieval_model: Optional[RetrievalModel] = None
    doc_form: str = "text_model"
    doc_language: str = "English"
    embedding_model: Optional[str] = None
    embedding_model_provider: Optional[str] = None
    name: Optional[str] = None
    metadata: Optional[MetaDataConfig] = None
class SegmentUpdateArgs(BaseModel):
    """API arguments for updating a segment."""
    content: Optional[str] = None
    answer: Optional[str] = None
    keywords: Optional[list[str]] = None
    regenerate_child_chunks: bool = False
    enabled: Optional[bool] = None
class ChildChunkUpdateArgs(BaseModel):
    """API arguments for creating/updating a child chunk (id set on update)."""
    id: Optional[str] = None
    content: str

View File

@@ -0,0 +1,169 @@
from enum import Enum
from typing import Optional
from pydantic import BaseModel, ConfigDict
from configs import dify_config
from core.entities.model_entities import (
ModelWithProviderEntity,
ProviderModelWithStatusEntity,
)
from core.entities.provider_entities import ProviderQuotaType, QuotaConfiguration
from core.model_runtime.entities.common_entities import I18nObject
from core.model_runtime.entities.model_entities import ModelType
from core.model_runtime.entities.provider_entities import (
ConfigurateMethod,
ModelCredentialSchema,
ProviderCredentialSchema,
ProviderHelpEntity,
SimpleProviderEntity,
)
from models.provider import ProviderType
class CustomConfigurationStatus(Enum):
    """
    Enum class for custom configuration status.
    """
    ACTIVE = "active"
    NO_CONFIGURE = "no-configure"
class CustomConfigurationResponse(BaseModel):
    """
    Model class for provider custom configuration response.
    """
    status: CustomConfigurationStatus
class SystemConfigurationResponse(BaseModel):
    """
    Model class for provider system configuration response.
    """
    enabled: bool
    current_quota_type: Optional[ProviderQuotaType] = None
    # Pydantic copies this default per instance, so the mutable [] is safe here.
    quota_configurations: list[QuotaConfiguration] = []
class ProviderResponse(BaseModel):
    """
    Model class for provider response.

    On construction, icon fields are rewritten to console API URLs so the
    frontend can fetch localized icons directly.
    """
    tenant_id: str
    provider: str
    label: I18nObject
    description: Optional[I18nObject] = None
    icon_small: Optional[I18nObject] = None
    icon_large: Optional[I18nObject] = None
    background: Optional[str] = None
    help: Optional[ProviderHelpEntity] = None
    supported_model_types: list[ModelType]
    configurate_methods: list[ConfigurateMethod]
    provider_credential_schema: Optional[ProviderCredentialSchema] = None
    model_credential_schema: Optional[ModelCredentialSchema] = None
    preferred_provider_type: ProviderType
    custom_configuration: CustomConfigurationResponse
    system_configuration: SystemConfigurationResponse
    # pydantic configs
    model_config = ConfigDict(protected_namespaces=())
    def __init__(self, **data) -> None:
        super().__init__(**data)
        # Replace raw icon values with console URLs (localized per language).
        url_prefix = (
            dify_config.CONSOLE_API_URL + f"/console/api/workspaces/{self.tenant_id}/model-providers/{self.provider}"
        )
        if self.icon_small is not None:
            self.icon_small = I18nObject(
                en_US=f"{url_prefix}/icon_small/en_US", zh_Hans=f"{url_prefix}/icon_small/zh_Hans"
            )
        if self.icon_large is not None:
            self.icon_large = I18nObject(
                en_US=f"{url_prefix}/icon_large/en_US", zh_Hans=f"{url_prefix}/icon_large/zh_Hans"
            )
class ProviderWithModelsResponse(BaseModel):
    """
    Model class for provider with models response.

    Same icon-URL rewriting as ProviderResponse (logic duplicated here).
    """
    tenant_id: str
    provider: str
    label: I18nObject
    icon_small: Optional[I18nObject] = None
    icon_large: Optional[I18nObject] = None
    status: CustomConfigurationStatus
    models: list[ProviderModelWithStatusEntity]
    def __init__(self, **data) -> None:
        super().__init__(**data)
        url_prefix = (
            dify_config.CONSOLE_API_URL + f"/console/api/workspaces/{self.tenant_id}/model-providers/{self.provider}"
        )
        if self.icon_small is not None:
            self.icon_small = I18nObject(
                en_US=f"{url_prefix}/icon_small/en_US", zh_Hans=f"{url_prefix}/icon_small/zh_Hans"
            )
        if self.icon_large is not None:
            self.icon_large = I18nObject(
                en_US=f"{url_prefix}/icon_large/en_US", zh_Hans=f"{url_prefix}/icon_large/zh_Hans"
            )
class SimpleProviderEntityResponse(SimpleProviderEntity):
    """
    Simple provider entity response.

    Extends the core entity with tenant_id and the same icon-URL rewriting.
    """
    tenant_id: str
    def __init__(self, **data) -> None:
        super().__init__(**data)
        url_prefix = (
            dify_config.CONSOLE_API_URL + f"/console/api/workspaces/{self.tenant_id}/model-providers/{self.provider}"
        )
        if self.icon_small is not None:
            self.icon_small = I18nObject(
                en_US=f"{url_prefix}/icon_small/en_US", zh_Hans=f"{url_prefix}/icon_small/zh_Hans"
            )
        if self.icon_large is not None:
            self.icon_large = I18nObject(
                en_US=f"{url_prefix}/icon_large/en_US", zh_Hans=f"{url_prefix}/icon_large/zh_Hans"
            )
class DefaultModelResponse(BaseModel):
    """
    Default model entity.
    """
    model: str
    model_type: ModelType
    provider: SimpleProviderEntityResponse
    # pydantic configs
    model_config = ConfigDict(protected_namespaces=())
class ModelWithProviderEntityResponse(ProviderModelWithStatusEntity):
    """
    Model with provider entity.

    NOTE: constructor signature deliberately diverges from BaseModel — it
    adapts a core ModelWithProviderEntity, injecting tenant_id into the
    nested provider payload.
    """
    provider: SimpleProviderEntityResponse
    def __init__(self, tenant_id: str, model: ModelWithProviderEntity) -> None:
        dump_model = model.model_dump()
        dump_model["provider"]["tenant_id"] = tenant_id
        super().__init__(**dump_model)

View File

@@ -0,0 +1,29 @@
from . import (
account,
app,
app_model_config,
audio,
base,
completion,
conversation,
dataset,
document,
file,
index,
message,
)
__all__ = [
"account",
"app",
"app_model_config",
"audio",
"base",
"completion",
"conversation",
"dataset",
"document",
"file",
"index",
"message",
]

View File

@@ -0,0 +1,61 @@
from services.errors.base import BaseServiceError
class AccountNotFoundError(BaseServiceError):
    """No account matches the lookup criteria."""

    pass
class AccountRegisterError(BaseServiceError):
    """Account registration failed."""

    pass
class AccountLoginError(BaseServiceError):
    """Account login failed."""

    pass
class AccountPasswordError(BaseServiceError):
    """Supplied password is incorrect."""

    pass
class AccountNotLinkTenantError(BaseServiceError):
    """Account is not linked to any tenant."""

    pass
class CurrentPasswordIncorrectError(BaseServiceError):
    """Current password check failed during a password change."""

    pass
class LinkAccountIntegrateError(BaseServiceError):
    """Linking a third-party identity to the account failed."""

    pass
class TenantNotFoundError(BaseServiceError):
    """No tenant matches the lookup criteria."""

    pass
class AccountAlreadyInTenantError(BaseServiceError):
    """Account is already a member of the tenant."""

    pass
class InvalidActionError(BaseServiceError):
    """Requested action is not valid in the current state."""

    pass
class CannotOperateSelfError(BaseServiceError):
    """Operation may not target the acting account itself."""

    pass
class NoPermissionError(BaseServiceError):
    """Acting account lacks permission for the operation."""

    pass
class MemberNotInTenantError(BaseServiceError):
    """Target member does not belong to the tenant."""

    pass
class RoleAlreadyAssignedError(BaseServiceError):
    """Member already holds the requested role."""

    pass
class RateLimitExceededError(BaseServiceError):
    """Operation exceeded its rate limit."""

    pass

View File

@@ -0,0 +1,6 @@
class MoreLikeThisDisabledError(Exception):
    """"More like this" generation is disabled for the app."""

    pass
class WorkflowHashNotEqualError(Exception):
    """Submitted workflow hash does not match the stored version."""

    pass

View File

@@ -0,0 +1,5 @@
from services.errors.base import BaseServiceError
class AppModelConfigBrokenError(BaseServiceError):
    """Stored app model configuration is invalid/unparsable."""

    pass

View File

@@ -0,0 +1,22 @@
class NoAudioUploadedServiceError(Exception):
    """Request contained no audio file."""

    pass
class AudioTooLargeServiceError(Exception):
    """Uploaded audio exceeds the size limit."""

    pass
class UnsupportedAudioTypeServiceError(Exception):
    """Uploaded audio MIME type is not in the allowed set."""

    pass
class ProviderNotSupportSpeechToTextServiceError(Exception):
    """No default speech-to-text model is configured for the tenant."""

    pass
class ProviderNotSupportTextToSpeechServiceError(Exception):
    """No default text-to-speech model is configured for the tenant."""

    pass
class ProviderNotSupportTextToSpeechLanageServiceError(Exception):
    """TTS provider does not support the requested language.

    NOTE: "Lanage" (sic) — kept misspelled for import compatibility.
    """

    pass

View File

@@ -0,0 +1,6 @@
from typing import Optional
class BaseServiceError(ValueError):
    """Root of the service-layer error hierarchy.

    Carries a human-readable ``description`` that API controllers surface to
    clients.
    """

    def __init__(self, description: Optional[str] = None):
        # Fix: forward the message to ValueError so str(e) and tracebacks
        # show it (previously e.args was empty and str(e) was ""). Only
        # forwarded when given, so the no-arg case keeps str(e) == "".
        if description is not None:
            super().__init__(description)
        self.description = description

View File

@@ -0,0 +1,9 @@
from services.errors.base import BaseServiceError
class ChildChunkIndexingError(BaseServiceError):
    """Raised when indexing a child chunk fails."""
    # "{message}" looks like a template placeholder formatted by the raising
    # code — TODO confirm against call sites.
    description = "{message}"


class ChildChunkDeleteIndexError(BaseServiceError):
    """Raised when removing a child chunk from the index fails."""
    description = "{message}"

View File

@@ -0,0 +1,5 @@
from services.errors.base import BaseServiceError
class CompletionStoppedError(BaseServiceError):
    """Raised when a completion run is stopped before finishing."""
    pass

View File

@@ -0,0 +1,13 @@
from services.errors.base import BaseServiceError
class LastConversationNotExistsError(BaseServiceError):
    """Raised when the conversation referenced as pagination cursor does not exist."""
    pass


class ConversationNotExistsError(BaseServiceError):
    """Raised when the requested conversation does not exist."""
    pass


# NOTE(review): derives from Exception, not BaseServiceError, unlike its
# siblings — confirm this asymmetry is intentional before unifying.
class ConversationCompletedError(Exception):
    """Raised when operating on a conversation that has already completed."""
    pass

View File

@@ -0,0 +1,9 @@
from services.errors.base import BaseServiceError
class DatasetNameDuplicateError(BaseServiceError):
    """Raised when creating a dataset whose name already exists in the tenant."""
    pass


class DatasetInUseError(BaseServiceError):
    """Raised when deleting or modifying a dataset that is still in use."""
    pass

View File

@@ -0,0 +1,5 @@
from services.errors.base import BaseServiceError
class DocumentIndexingError(BaseServiceError):
    """Raised when a document cannot be (re)indexed in its current state."""
    pass

View File

@@ -0,0 +1,13 @@
from services.errors.base import BaseServiceError
class FileNotExistsError(BaseServiceError):
    """Raised when the referenced upload file does not exist."""
    pass


class FileTooLargeError(BaseServiceError):
    """Raised when an uploaded file exceeds the configured size limit."""
    # "{message}" looks like a template placeholder formatted by the raising
    # code — TODO confirm against call sites.
    description = "{message}"


class UnsupportedFileTypeError(BaseServiceError):
    """Raised when the uploaded file's extension is not accepted."""
    pass

View File

@@ -0,0 +1,5 @@
from services.errors.base import BaseServiceError
class IndexNotInitializedError(BaseServiceError):
    """Raised when an index is used before it has been initialized."""
    pass

View File

@@ -0,0 +1,19 @@
from typing import Optional
class InvokeError(Exception):
    """Base class for all LLM exceptions.

    Carries an optional human-readable description; when the description is
    missing (or empty), the string form degrades to the concrete class name.
    """

    description: Optional[str] = None

    def __init__(self, description: Optional[str] = None) -> None:
        self.description = description

    def __str__(self):
        # An empty/None description falls back to the subclass name,
        # matching the original truthiness-based behavior.
        if self.description:
            return self.description
        return self.__class__.__name__
class InvokeRateLimitError(InvokeError):
    """Raised when the Invoke returns rate limit error."""

    # Class-level default: __str__ falls back to this when no instance
    # description is provided.
    description = "Rate Limit Error"

View File

@@ -0,0 +1,17 @@
from services.errors.base import BaseServiceError
class FirstMessageNotExistsError(BaseServiceError):
    """Raised when the message used as the ``first_id`` pagination cursor does not exist."""
    pass


class LastMessageNotExistsError(BaseServiceError):
    """Raised when the message used as the ``last_id`` pagination cursor does not exist."""
    pass


class MessageNotExistsError(BaseServiceError):
    """Raised when the requested message does not exist."""
    pass


class SuggestedQuestionsAfterAnswerDisabledError(BaseServiceError):
    """Raised when suggested questions are requested but the feature is disabled."""
    pass

View File

@@ -0,0 +1,10 @@
# NOTE(review): these derive from ValueError rather than the shared
# BaseServiceError used by sibling modules — confirm before unifying.
class WorkflowInUseError(ValueError):
    """Raised when attempting to delete a workflow that's in use by an app"""

    pass


class DraftWorkflowDeletionError(ValueError):
    """Raised when attempting to delete a draft workflow"""

    pass

View File

@@ -0,0 +1,9 @@
from services.errors.base import BaseServiceError
class WorkSpaceNotAllowedCreateError(BaseServiceError):
    """Raised when workspace creation is not permitted for the account."""
    pass


class WorkSpaceNotFoundError(BaseServiceError):
    """Raised when the referenced workspace does not exist."""
    pass

View File

@@ -0,0 +1,288 @@
import json
from copy import deepcopy
from datetime import UTC, datetime
from typing import Any, Optional, Union, cast
import httpx
import validators
from constants import HIDDEN_VALUE
from core.helper import ssrf_proxy
from extensions.ext_database import db
from models.dataset import (
Dataset,
ExternalKnowledgeApis,
ExternalKnowledgeBindings,
)
from services.entities.external_knowledge_entities.external_knowledge_entities import (
Authorization,
ExternalKnowledgeApiSetting,
)
from services.errors.dataset import DatasetNameDuplicateError
class ExternalDatasetService:
@staticmethod
def get_external_knowledge_apis(page, per_page, tenant_id, search=None) -> tuple[list[ExternalKnowledgeApis], int]:
query = ExternalKnowledgeApis.query.filter(ExternalKnowledgeApis.tenant_id == tenant_id).order_by(
ExternalKnowledgeApis.created_at.desc()
)
if search:
query = query.filter(ExternalKnowledgeApis.name.ilike(f"%{search}%"))
external_knowledge_apis = query.paginate(page=page, per_page=per_page, max_per_page=100, error_out=False)
return external_knowledge_apis.items, external_knowledge_apis.total
@classmethod
def validate_api_list(cls, api_settings: dict):
if not api_settings:
raise ValueError("api list is empty")
if "endpoint" not in api_settings and not api_settings["endpoint"]:
raise ValueError("endpoint is required")
if "api_key" not in api_settings and not api_settings["api_key"]:
raise ValueError("api_key is required")
@staticmethod
def create_external_knowledge_api(tenant_id: str, user_id: str, args: dict) -> ExternalKnowledgeApis:
settings = args.get("settings")
if settings is None:
raise ValueError("settings is required")
ExternalDatasetService.check_endpoint_and_api_key(settings)
external_knowledge_api = ExternalKnowledgeApis(
tenant_id=tenant_id,
created_by=user_id,
updated_by=user_id,
name=args.get("name"),
description=args.get("description", ""),
settings=json.dumps(args.get("settings"), ensure_ascii=False),
)
db.session.add(external_knowledge_api)
db.session.commit()
return external_knowledge_api
@staticmethod
def check_endpoint_and_api_key(settings: dict):
if "endpoint" not in settings or not settings["endpoint"]:
raise ValueError("endpoint is required")
if "api_key" not in settings or not settings["api_key"]:
raise ValueError("api_key is required")
endpoint = f"{settings['endpoint']}/retrieval"
api_key = settings["api_key"]
if not validators.url(endpoint, simple_host=True):
if not endpoint.startswith("http://") and not endpoint.startswith("https://"):
raise ValueError(f"invalid endpoint: {endpoint} must start with http:// or https://")
else:
raise ValueError(f"invalid endpoint: {endpoint}")
try:
response = httpx.post(endpoint, headers={"Authorization": f"Bearer {api_key}"})
except Exception as e:
raise ValueError(f"failed to connect to the endpoint: {endpoint}")
if response.status_code == 502:
raise ValueError(f"Bad Gateway: failed to connect to the endpoint: {endpoint}")
if response.status_code == 404:
raise ValueError(f"Not Found: failed to connect to the endpoint: {endpoint}")
if response.status_code == 403:
raise ValueError(f"Forbidden: Authorization failed with api_key: {api_key}")
@staticmethod
def get_external_knowledge_api(external_knowledge_api_id: str) -> ExternalKnowledgeApis:
external_knowledge_api: Optional[ExternalKnowledgeApis] = ExternalKnowledgeApis.query.filter_by(
id=external_knowledge_api_id
).first()
if external_knowledge_api is None:
raise ValueError("api template not found")
return external_knowledge_api
@staticmethod
def update_external_knowledge_api(tenant_id, user_id, external_knowledge_api_id, args) -> ExternalKnowledgeApis:
external_knowledge_api: Optional[ExternalKnowledgeApis] = ExternalKnowledgeApis.query.filter_by(
id=external_knowledge_api_id, tenant_id=tenant_id
).first()
if external_knowledge_api is None:
raise ValueError("api template not found")
if args.get("settings") and args.get("settings").get("api_key") == HIDDEN_VALUE:
args.get("settings")["api_key"] = external_knowledge_api.settings_dict.get("api_key")
external_knowledge_api.name = args.get("name")
external_knowledge_api.description = args.get("description", "")
external_knowledge_api.settings = json.dumps(args.get("settings"), ensure_ascii=False)
external_knowledge_api.updated_by = user_id
external_knowledge_api.updated_at = datetime.now(UTC).replace(tzinfo=None)
db.session.commit()
return external_knowledge_api
@staticmethod
def delete_external_knowledge_api(tenant_id: str, external_knowledge_api_id: str):
external_knowledge_api = ExternalKnowledgeApis.query.filter_by(
id=external_knowledge_api_id, tenant_id=tenant_id
).first()
if external_knowledge_api is None:
raise ValueError("api template not found")
db.session.delete(external_knowledge_api)
db.session.commit()
@staticmethod
def external_knowledge_api_use_check(external_knowledge_api_id: str) -> tuple[bool, int]:
count = ExternalKnowledgeBindings.query.filter_by(external_knowledge_api_id=external_knowledge_api_id).count()
if count > 0:
return True, count
return False, 0
@staticmethod
def get_external_knowledge_binding_with_dataset_id(tenant_id: str, dataset_id: str) -> ExternalKnowledgeBindings:
external_knowledge_binding: Optional[ExternalKnowledgeBindings] = ExternalKnowledgeBindings.query.filter_by(
dataset_id=dataset_id, tenant_id=tenant_id
).first()
if not external_knowledge_binding:
raise ValueError("external knowledge binding not found")
return external_knowledge_binding
@staticmethod
def document_create_args_validate(tenant_id: str, external_knowledge_api_id: str, process_parameter: dict):
external_knowledge_api = ExternalKnowledgeApis.query.filter_by(
id=external_knowledge_api_id, tenant_id=tenant_id
).first()
if external_knowledge_api is None:
raise ValueError("api template not found")
settings = json.loads(external_knowledge_api.settings)
for setting in settings:
custom_parameters = setting.get("document_process_setting")
if custom_parameters:
for parameter in custom_parameters:
if parameter.get("required", False) and not process_parameter.get(parameter.get("name")):
raise ValueError(f"{parameter.get('name')} is required")
@staticmethod
def process_external_api(
settings: ExternalKnowledgeApiSetting, files: Union[None, dict[str, Any]]
) -> httpx.Response:
"""
do http request depending on api bundle
"""
kwargs = {
"url": settings.url,
"headers": settings.headers,
"follow_redirects": True,
}
response: httpx.Response = getattr(ssrf_proxy, settings.request_method)(
data=json.dumps(settings.params), files=files, **kwargs
)
return response
@staticmethod
def assembling_headers(authorization: Authorization, headers: Optional[dict] = None) -> dict[str, Any]:
authorization = deepcopy(authorization)
if headers:
headers = deepcopy(headers)
else:
headers = {}
if authorization.type == "api-key":
if authorization.config is None:
raise ValueError("authorization config is required")
if authorization.config.api_key is None:
raise ValueError("api_key is required")
if not authorization.config.header:
authorization.config.header = "Authorization"
if authorization.config.type == "bearer":
headers[authorization.config.header] = f"Bearer {authorization.config.api_key}"
elif authorization.config.type == "basic":
headers[authorization.config.header] = f"Basic {authorization.config.api_key}"
elif authorization.config.type == "custom":
headers[authorization.config.header] = authorization.config.api_key
return headers
@staticmethod
def get_external_knowledge_api_settings(settings: dict) -> ExternalKnowledgeApiSetting:
return ExternalKnowledgeApiSetting.parse_obj(settings)
@staticmethod
def create_external_dataset(tenant_id: str, user_id: str, args: dict) -> Dataset:
# check if dataset name already exists
if Dataset.query.filter_by(name=args.get("name"), tenant_id=tenant_id).first():
raise DatasetNameDuplicateError(f"Dataset with name {args.get('name')} already exists.")
external_knowledge_api = ExternalKnowledgeApis.query.filter_by(
id=args.get("external_knowledge_api_id"), tenant_id=tenant_id
).first()
if external_knowledge_api is None:
raise ValueError("api template not found")
dataset = Dataset(
tenant_id=tenant_id,
name=args.get("name"),
description=args.get("description", ""),
provider="external",
retrieval_model=args.get("external_retrieval_model"),
created_by=user_id,
)
db.session.add(dataset)
db.session.flush()
external_knowledge_binding = ExternalKnowledgeBindings(
tenant_id=tenant_id,
dataset_id=dataset.id,
external_knowledge_api_id=args.get("external_knowledge_api_id"),
external_knowledge_id=args.get("external_knowledge_id"),
created_by=user_id,
)
db.session.add(external_knowledge_binding)
db.session.commit()
return dataset
@staticmethod
def fetch_external_knowledge_retrieval(
tenant_id: str, dataset_id: str, query: str, external_retrieval_parameters: dict
) -> list:
external_knowledge_binding = ExternalKnowledgeBindings.query.filter_by(
dataset_id=dataset_id, tenant_id=tenant_id
).first()
if not external_knowledge_binding:
raise ValueError("external knowledge binding not found")
external_knowledge_api = ExternalKnowledgeApis.query.filter_by(
id=external_knowledge_binding.external_knowledge_api_id
).first()
if not external_knowledge_api:
raise ValueError("external api template not found")
settings = json.loads(external_knowledge_api.settings)
headers = {"Content-Type": "application/json"}
if settings.get("api_key"):
headers["Authorization"] = f"Bearer {settings.get('api_key')}"
score_threshold_enabled = external_retrieval_parameters.get("score_threshold_enabled") or False
score_threshold = external_retrieval_parameters.get("score_threshold", 0.0) if score_threshold_enabled else 0.0
request_params = {
"retrieval_setting": {
"top_k": external_retrieval_parameters.get("top_k"),
"score_threshold": score_threshold,
},
"query": query,
"knowledge_id": external_knowledge_binding.external_knowledge_id,
}
response = ExternalDatasetService.process_external_api(
ExternalKnowledgeApiSetting(
url=f"{settings.get('endpoint')}/retrieval",
request_method="post",
headers=headers,
params=request_params,
),
None,
)
if response.status_code == 200:
return cast(list[Any], response.json().get("records", []))
return []

View File

@@ -0,0 +1,207 @@
from enum import StrEnum
from pydantic import BaseModel, ConfigDict
from configs import dify_config
from services.billing_service import BillingService
from services.enterprise.enterprise_service import EnterpriseService
class SubscriptionModel(BaseModel):
    """Billing subscription descriptor: plan name and billing interval."""

    plan: str = "sandbox"  # defaults to the sandbox plan
    interval: str = ""
class BillingModel(BaseModel):
    """Billing state: whether billing is enabled and the active subscription."""

    enabled: bool = False
    subscription: SubscriptionModel = SubscriptionModel()
class LimitationModel(BaseModel):
    """A quota pair: current usage (size) against the allowed maximum (limit)."""

    size: int = 0
    limit: int = 0
class LicenseStatus(StrEnum):
    """Lifecycle states of an enterprise license."""

    NONE = "none"
    INACTIVE = "inactive"
    ACTIVE = "active"
    EXPIRING = "expiring"
    EXPIRED = "expired"
    LOST = "lost"
class LicenseModel(BaseModel):
    """Enterprise license status and its expiry timestamp (as a string)."""

    status: LicenseStatus = LicenseStatus.NONE
    expired_at: str = ""
class FeatureModel(BaseModel):
    """Per-tenant feature flags and quotas; defaults reflect the free tier."""

    billing: BillingModel = BillingModel()
    members: LimitationModel = LimitationModel(size=0, limit=1)
    apps: LimitationModel = LimitationModel(size=0, limit=10)
    vector_space: LimitationModel = LimitationModel(size=0, limit=5)
    knowledge_rate_limit: int = 10
    annotation_quota_limit: LimitationModel = LimitationModel(size=0, limit=10)
    documents_upload_quota: LimitationModel = LimitationModel(size=0, limit=50)
    docs_processing: str = "standard"
    can_replace_logo: bool = False
    model_load_balancing_enabled: bool = False
    dataset_operator_enabled: bool = False

    # pydantic configs: allow field names starting with "model_" without warnings
    model_config = ConfigDict(protected_namespaces=())
class KnowledgeRateLimitModel(BaseModel):
    """Knowledge-API rate limit info; only populated when billing is enabled."""

    enabled: bool = False
    limit: int = 10
    subscription_plan: str = ""
class SystemFeatureModel(BaseModel):
    """Instance-wide (not per-tenant) feature switches: SSO, login methods,
    registration, marketplace and licensing."""

    sso_enforced_for_signin: bool = False
    sso_enforced_for_signin_protocol: str = ""
    sso_enforced_for_web: bool = False
    sso_enforced_for_web_protocol: str = ""
    enable_web_sso_switch_component: bool = False
    enable_marketplace: bool = True
    max_plugin_package_size: int = dify_config.PLUGIN_MAX_PACKAGE_SIZE
    enable_email_code_login: bool = False
    enable_email_password_login: bool = True
    enable_social_oauth_login: bool = False
    is_allow_register: bool = False
    is_allow_create_workspace: bool = False
    is_email_setup: bool = False
    license: LicenseModel = LicenseModel()
class FeatureService:
    """Aggregates feature flags, quotas and system settings from the environment
    config and, when enabled, the billing and enterprise backends."""

    @classmethod
    def get_features(cls, tenant_id: str) -> FeatureModel:
        """Build the tenant feature set: env defaults first, billing overrides second."""
        features = FeatureModel()
        cls._fulfill_params_from_env(features)

        if dify_config.BILLING_ENABLED and tenant_id:
            cls._fulfill_params_from_billing_api(features, tenant_id)

        return features

    @classmethod
    def get_knowledge_rate_limit(cls, tenant_id: str):
        """Return the knowledge rate limit; populated from billing only when enabled."""
        rate_limit = KnowledgeRateLimitModel()
        if dify_config.BILLING_ENABLED and tenant_id:
            rate_limit.enabled = True
            limit_info = BillingService.get_knowledge_rate_limit(tenant_id)
            rate_limit.limit = limit_info.get("limit", 10)
            rate_limit.subscription_plan = limit_info.get("subscription_plan", "sandbox")
        return rate_limit

    @classmethod
    def get_system_features(cls) -> SystemFeatureModel:
        """Return instance-wide features, layering enterprise and marketplace flags on env values."""
        system_features = SystemFeatureModel()
        cls._fulfill_system_params_from_env(system_features)

        if dify_config.ENTERPRISE_ENABLED:
            system_features.enable_web_sso_switch_component = True
            cls._fulfill_params_from_enterprise(system_features)

        if dify_config.MARKETPLACE_ENABLED:
            system_features.enable_marketplace = True

        return system_features

    @classmethod
    def _fulfill_system_params_from_env(cls, system_features: SystemFeatureModel):
        """Copy login/registration toggles straight from the environment config."""
        cfg = dify_config
        system_features.enable_email_code_login = cfg.ENABLE_EMAIL_CODE_LOGIN
        system_features.enable_email_password_login = cfg.ENABLE_EMAIL_PASSWORD_LOGIN
        system_features.enable_social_oauth_login = cfg.ENABLE_SOCIAL_OAUTH_LOGIN
        system_features.is_allow_register = cfg.ALLOW_REGISTER
        system_features.is_allow_create_workspace = cfg.ALLOW_CREATE_WORKSPACE
        system_features.is_email_setup = cfg.MAIL_TYPE is not None and cfg.MAIL_TYPE != ""

    @classmethod
    def _fulfill_params_from_env(cls, features: FeatureModel):
        """Copy tenant-level toggles from the environment config."""
        features.can_replace_logo = dify_config.CAN_REPLACE_LOGO
        features.model_load_balancing_enabled = dify_config.MODEL_LB_ENABLED
        features.dataset_operator_enabled = dify_config.DATASET_OPERATOR_ENABLED

    @classmethod
    def _fulfill_params_from_billing_api(cls, features: FeatureModel, tenant_id: str):
        """Overlay quotas and flags from the billing backend onto ``features``.

        Every key is optional except the top-level billing/subscription fields;
        absent keys leave the env/default values untouched.
        """
        billing_info = BillingService.get_info(tenant_id)

        features.billing.enabled = billing_info["enabled"]
        features.billing.subscription.plan = billing_info["subscription"]["plan"]
        features.billing.subscription.interval = billing_info["subscription"]["interval"]

        # All size/limit quotas share one copy pattern.
        quota_targets = (
            ("members", features.members),
            ("apps", features.apps),
            ("vector_space", features.vector_space),
            ("documents_upload_quota", features.documents_upload_quota),
            ("annotation_quota_limit", features.annotation_quota_limit),
        )
        for key, quota in quota_targets:
            if key in billing_info:
                quota.size = billing_info[key]["size"]
                quota.limit = billing_info[key]["limit"]

        if "docs_processing" in billing_info:
            features.docs_processing = billing_info["docs_processing"]
        if "can_replace_logo" in billing_info:
            features.can_replace_logo = billing_info["can_replace_logo"]
        if "model_load_balancing_enabled" in billing_info:
            features.model_load_balancing_enabled = billing_info["model_load_balancing_enabled"]
        if "knowledge_rate_limit" in billing_info:
            features.knowledge_rate_limit = billing_info["knowledge_rate_limit"]["limit"]

    @classmethod
    def _fulfill_params_from_enterprise(cls, features):
        """Overlay SSO/login/registration flags and license info from the enterprise backend."""
        enterprise_info = EnterpriseService.get_info()

        # These attributes map one-to-one onto identically named payload keys.
        passthrough_keys = (
            "sso_enforced_for_signin",
            "sso_enforced_for_signin_protocol",
            "sso_enforced_for_web",
            "sso_enforced_for_web_protocol",
            "enable_email_code_login",
            "enable_email_password_login",
            "is_allow_register",
            "is_allow_create_workspace",
        )
        for key in passthrough_keys:
            if key in enterprise_info:
                setattr(features, key, enterprise_info[key])

        if "license" in enterprise_info:
            license_info = enterprise_info["license"]

            if "status" in license_info:
                features.license.status = LicenseStatus(license_info.get("status", LicenseStatus.INACTIVE))

            if "expired_at" in license_info:
                features.license.expired_at = license_info["expired_at"]

View File

@@ -0,0 +1,205 @@
import datetime
import hashlib
import uuid
from typing import Any, Literal, Union
from flask_login import current_user # type: ignore
from werkzeug.exceptions import NotFound
from configs import dify_config
from constants import (
AUDIO_EXTENSIONS,
DOCUMENT_EXTENSIONS,
IMAGE_EXTENSIONS,
VIDEO_EXTENSIONS,
)
from core.file import helpers as file_helpers
from core.rag.extractor.extract_processor import ExtractProcessor
from extensions.ext_database import db
from extensions.ext_storage import storage
from models.account import Account
from models.enums import CreatedByRole
from models.model import EndUser, UploadFile
from .errors.file import FileTooLargeError, UnsupportedFileTypeError
# Maximum number of characters returned by FileService.get_file_preview.
PREVIEW_WORDS_LIMIT = 3000
class FileService:
    """Upload, preview and signed-download helpers for files kept in the storage backend."""

    @staticmethod
    def upload_file(
        *,
        filename: str,
        content: bytes,
        mimetype: str,
        user: Union[Account, EndUser, Any],
        source: Literal["datasets"] | None = None,
        source_url: str = "",
    ) -> UploadFile:
        """Persist ``content`` to storage and record an UploadFile row for it.

        Raises:
            UnsupportedFileTypeError: for dataset uploads with a non-document extension.
            FileTooLargeError: when the size exceeds the per-category limit.
        """
        # get file extension
        extension = filename.split(".")[-1].lower()

        # Bound the stored name at 200 chars while keeping the extension.
        if len(filename) > 200:
            filename = filename.split(".")[0][:200] + "." + extension

        if source == "datasets" and extension not in DOCUMENT_EXTENSIONS:
            raise UnsupportedFileTypeError()

        # get file size
        file_size = len(content)

        # check if the file size is exceeded
        if not FileService.is_file_size_within_limit(extension=extension, file_size=file_size):
            raise FileTooLargeError

        # generate file key
        file_uuid = str(uuid.uuid4())

        if isinstance(user, Account):
            current_tenant_id = user.current_tenant_id
        else:
            # end_user
            current_tenant_id = user.tenant_id

        file_key = "upload_files/" + (current_tenant_id or "") + "/" + file_uuid + "." + extension

        # save file to storage
        storage.save(file_key, content)

        # save file to db
        upload_file = UploadFile(
            tenant_id=current_tenant_id or "",
            storage_type=dify_config.STORAGE_TYPE,
            key=file_key,
            name=filename,
            size=file_size,
            extension=extension,
            mime_type=mimetype,
            created_by_role=(CreatedByRole.ACCOUNT if isinstance(user, Account) else CreatedByRole.END_USER),
            created_by=user.id,
            created_at=datetime.datetime.now(datetime.UTC).replace(tzinfo=None),
            used=False,
            hash=hashlib.sha3_256(content).hexdigest(),
            source_url=source_url,
        )

        db.session.add(upload_file)
        db.session.commit()

        return upload_file

    @staticmethod
    def is_file_size_within_limit(*, extension: str, file_size: int) -> bool:
        """Return True when ``file_size`` fits the limit for the extension's media category.

        Limits are configured in MiB per category (image/video/audio/other).
        """
        if extension in IMAGE_EXTENSIONS:
            file_size_limit = dify_config.UPLOAD_IMAGE_FILE_SIZE_LIMIT * 1024 * 1024
        elif extension in VIDEO_EXTENSIONS:
            file_size_limit = dify_config.UPLOAD_VIDEO_FILE_SIZE_LIMIT * 1024 * 1024
        elif extension in AUDIO_EXTENSIONS:
            file_size_limit = dify_config.UPLOAD_AUDIO_FILE_SIZE_LIMIT * 1024 * 1024
        else:
            file_size_limit = dify_config.UPLOAD_FILE_SIZE_LIMIT * 1024 * 1024

        return file_size <= file_size_limit

    @staticmethod
    def upload_text(text: str, text_name: str) -> UploadFile:
        """Store ``text`` as a .txt upload owned by the current (flask_login) user."""
        if len(text_name) > 200:
            text_name = text_name[:200]
        # user uuid as file name
        file_uuid = str(uuid.uuid4())
        file_key = "upload_files/" + current_user.current_tenant_id + "/" + file_uuid + ".txt"

        # save file to storage
        storage.save(file_key, text.encode("utf-8"))

        # save file to db
        upload_file = UploadFile(
            tenant_id=current_user.current_tenant_id,
            storage_type=dify_config.STORAGE_TYPE,
            key=file_key,
            name=text_name,
            # NOTE(review): records len(text) (characters), while the stored blob
            # is UTF-8 encoded — byte length may differ. Confirm intended.
            size=len(text),
            extension="txt",
            mime_type="text/plain",
            created_by=current_user.id,
            created_by_role=CreatedByRole.ACCOUNT,
            created_at=datetime.datetime.now(datetime.UTC).replace(tzinfo=None),
            used=True,
            used_by=current_user.id,
            used_at=datetime.datetime.now(datetime.UTC).replace(tzinfo=None),
        )

        db.session.add(upload_file)
        db.session.commit()

        return upload_file

    @staticmethod
    def get_file_preview(file_id: str):
        """Return up to PREVIEW_WORDS_LIMIT characters of extracted text from a document upload.

        Raises:
            NotFound: when the upload row is missing.
            UnsupportedFileTypeError: for non-document extensions.
        """
        upload_file = db.session.query(UploadFile).filter(UploadFile.id == file_id).first()

        if not upload_file:
            raise NotFound("File not found")

        # extract text from file
        extension = upload_file.extension
        if extension.lower() not in DOCUMENT_EXTENSIONS:
            raise UnsupportedFileTypeError()

        text = ExtractProcessor.load_from_upload_file(upload_file, return_text=True)
        text = text[0:PREVIEW_WORDS_LIMIT] if text else ""

        return text

    @staticmethod
    def get_image_preview(file_id: str, timestamp: str, nonce: str, sign: str):
        """Return (stream_generator, mime_type) for a signed image preview request.

        Raises:
            NotFound: when the signature is invalid or the upload row is missing.
            UnsupportedFileTypeError: for non-image extensions.
        """
        result = file_helpers.verify_image_signature(
            upload_file_id=file_id, timestamp=timestamp, nonce=nonce, sign=sign
        )
        if not result:
            raise NotFound("File not found or signature is invalid")

        upload_file = db.session.query(UploadFile).filter(UploadFile.id == file_id).first()

        if not upload_file:
            raise NotFound("File not found or signature is invalid")

        # extract text from file
        extension = upload_file.extension
        if extension.lower() not in IMAGE_EXTENSIONS:
            raise UnsupportedFileTypeError()

        generator = storage.load(upload_file.key, stream=True)

        return generator, upload_file.mime_type

    @staticmethod
    def get_file_generator_by_file_id(file_id: str, timestamp: str, nonce: str, sign: str):
        """Return (stream_generator, upload_file) for a signed file download request.

        Raises:
            NotFound: when the signature is invalid or the upload row is missing.
        """
        result = file_helpers.verify_file_signature(upload_file_id=file_id, timestamp=timestamp, nonce=nonce, sign=sign)
        if not result:
            raise NotFound("File not found or signature is invalid")

        upload_file = db.session.query(UploadFile).filter(UploadFile.id == file_id).first()

        if not upload_file:
            raise NotFound("File not found or signature is invalid")

        generator = storage.load(upload_file.key, stream=True)

        return generator, upload_file

    @staticmethod
    def get_public_image_preview(file_id: str):
        """Return (content, mime_type) for an image upload without signature verification.

        Raises:
            NotFound: when the upload row is missing.
            UnsupportedFileTypeError: for non-image extensions.
        """
        upload_file = db.session.query(UploadFile).filter(UploadFile.id == file_id).first()

        if not upload_file:
            raise NotFound("File not found or signature is invalid")

        # extract text from file
        extension = upload_file.extension
        if extension.lower() not in IMAGE_EXTENSIONS:
            raise UnsupportedFileTypeError()

        generator = storage.load(upload_file.key)

        return generator, upload_file.mime_type

View File

@@ -0,0 +1,146 @@
import logging
import time
from typing import Any
from core.rag.datasource.retrieval_service import RetrievalService
from core.rag.models.document import Document
from core.rag.retrieval.retrieval_methods import RetrievalMethod
from extensions.ext_database import db
from models.account import Account
from models.dataset import Dataset, DatasetQuery
# Fallback retrieval configuration used when a dataset carries no retrieval
# model of its own and the caller supplies none.
default_retrieval_model = {
    "search_method": RetrievalMethod.SEMANTIC_SEARCH.value,
    "reranking_enable": False,
    "reranking_model": {"reranking_provider_name": "", "reranking_model_name": ""},
    "top_k": 2,
    "score_threshold_enabled": False,
}
class HitTestingService:
    """Runs hit-testing retrieval queries against internal or external datasets
    and records each query in the dataset's query log."""

    @classmethod
    def retrieve(
        cls,
        dataset: Dataset,
        query: str,
        account: Account,
        retrieval_model: Any,  # FIXME drop this any
        external_retrieval_model: dict,
        limit: int = 10,
    ) -> dict:
        """Retrieve matching segments for ``query`` from an internal dataset.

        Returns a dict with the query echo and matched records; empty records
        when the dataset has no available documents or segments.
        """
        if dataset.available_document_count == 0 or dataset.available_segment_count == 0:
            return {
                "query": {
                    "content": query,
                    "tsne_position": {"x": 0, "y": 0},
                },
                "records": [],
            }

        start = time.perf_counter()

        # get retrieval model , if the model is not setting , using default
        if not retrieval_model:
            retrieval_model = dataset.retrieval_model or default_retrieval_model

        all_documents = RetrievalService.retrieve(
            retrieval_method=retrieval_model.get("search_method", "semantic_search"),
            dataset_id=dataset.id,
            query=query,
            top_k=retrieval_model.get("top_k", 2),
            # ROBUSTNESS: use .get() for the enable flags — a caller-supplied
            # partial dict previously raised KeyError here while every other
            # field already fell back via .get().
            score_threshold=retrieval_model.get("score_threshold", 0.0)
            if retrieval_model.get("score_threshold_enabled", False)
            else 0.0,
            reranking_model=retrieval_model.get("reranking_model", None)
            if retrieval_model.get("reranking_enable", False)
            else None,
            reranking_mode=retrieval_model.get("reranking_mode") or "reranking_model",
            weights=retrieval_model.get("weights", None),
        )

        end = time.perf_counter()
        # Lazy %-formatting so the message is only built when DEBUG is enabled.
        logging.debug("Hit testing retrieve in %0.4f seconds", end - start)

        dataset_query = DatasetQuery(
            dataset_id=dataset.id, content=query, source="hit_testing", created_by_role="account", created_by=account.id
        )

        db.session.add(dataset_query)
        db.session.commit()

        return cls.compact_retrieve_response(query, all_documents)  # type: ignore

    @classmethod
    def external_retrieve(
        cls,
        dataset: Dataset,
        query: str,
        account: Account,
        external_retrieval_model: dict,
    ) -> dict:
        """Retrieve records for ``query`` from an external-provider dataset.

        Non-external datasets short-circuit to an empty result.
        """
        if dataset.provider != "external":
            return {
                "query": {"content": query},
                "records": [],
            }

        start = time.perf_counter()

        all_documents = RetrievalService.external_retrieve(
            dataset_id=dataset.id,
            query=cls.escape_query_for_search(query),
            external_retrieval_model=external_retrieval_model,
        )

        end = time.perf_counter()
        logging.debug("External knowledge hit testing retrieve in %0.4f seconds", end - start)

        dataset_query = DatasetQuery(
            dataset_id=dataset.id, content=query, source="hit_testing", created_by_role="account", created_by=account.id
        )

        db.session.add(dataset_query)
        db.session.commit()

        return dict(cls.compact_external_retrieve_response(dataset, query, all_documents))

    @classmethod
    def compact_retrieve_response(cls, query: str, documents: list[Document]):
        """Shape internal retrieval results into the API response format."""
        records = RetrievalService.format_retrieval_documents(documents)

        return {
            "query": {
                "content": query,
            },
            "records": [record.model_dump() for record in records],
        }

    @classmethod
    def compact_external_retrieve_response(cls, dataset: Dataset, query: str, documents: list) -> dict[Any, Any]:
        """Shape external retrieval results; empty records for non-external datasets."""
        records = []
        if dataset.provider == "external":
            for document in documents:
                record = {
                    "content": document.get("content", None),
                    "title": document.get("title", None),
                    "score": document.get("score", None),
                    "metadata": document.get("metadata", None),
                }
                records.append(record)
            return {
                "query": {"content": query},
                "records": records,
            }
        return {"query": {"content": query}, "records": []}

    @classmethod
    def hit_testing_args_check(cls, args):
        """Validate hit-testing request args: query must be non-empty and <= 250 chars.

        Raises:
            ValueError: when the query is missing, empty, or too long.
        """
        query = args["query"]

        if not query or len(query) > 250:
            raise ValueError("Query is required and cannot exceed 250 characters")

    @staticmethod
    def escape_query_for_search(query: str) -> str:
        """Backslash-escape double quotes so the query is safe inside a quoted search string."""
        return query.replace('"', '\\"')

View File

@@ -0,0 +1,45 @@
import boto3 # type: ignore
from configs import dify_config
class ExternalDatasetTestService:
    # this service is only for internal testing
    @staticmethod
    def knowledge_retrieval(retrieval_setting: dict, query: str, knowledge_id: str):
        """Query an AWS Bedrock knowledge base and return score-filtered hits.

        Returns ``{"records": [...]}``; records are empty unless the Bedrock
        response reports HTTP 200.
        """
        # get bedrock client
        bedrock = boto3.client(
            "bedrock-agent-runtime",
            aws_secret_access_key=dify_config.AWS_SECRET_ACCESS_KEY,
            aws_access_key_id=dify_config.AWS_ACCESS_KEY_ID,
            # example: us-east-1
            region_name="us-east-1",
        )
        # fetch external knowledge retrieval
        response = bedrock.retrieve(
            knowledgeBaseId=knowledge_id,
            retrievalConfiguration={
                "vectorSearchConfiguration": {
                    "numberOfResults": retrieval_setting.get("top_k"),
                    "overrideSearchType": "HYBRID",
                }
            },
            retrievalQuery={"text": query},
        )

        # parse response
        records = []
        metadata = response.get("ResponseMetadata")
        if metadata and metadata.get("HTTPStatusCode") == 200:
            threshold = retrieval_setting.get("score_threshold", 0.0)
            for hit in response.get("retrievalResults") or []:
                # filter out results with score less than threshold
                if hit.get("score") < threshold:
                    continue
                records.append(
                    {
                        "metadata": hit.get("metadata"),
                        "score": hit.get("score"),
                        "title": hit.get("metadata").get("x-amz-bedrock-kb-source-uri"),
                        "content": hit.get("content").get("text"),
                    }
                )
        return {"records": records}

View File

@@ -0,0 +1,291 @@
import json
from typing import Optional, Union
from core.app.apps.advanced_chat.app_config_manager import AdvancedChatAppConfigManager
from core.app.entities.app_invoke_entities import InvokeFrom
from core.llm_generator.llm_generator import LLMGenerator
from core.memory.token_buffer_memory import TokenBufferMemory
from core.model_manager import ModelManager
from core.model_runtime.entities.model_entities import ModelType
from core.ops.entities.trace_entity import TraceTaskName
from core.ops.ops_trace_manager import TraceQueueManager, TraceTask
from core.ops.utils import measure_time
from extensions.ext_database import db
from libs.infinite_scroll_pagination import InfiniteScrollPagination
from models.account import Account
from models.model import App, AppMode, AppModelConfig, EndUser, Message, MessageFeedback
from services.conversation_service import ConversationService
from services.errors.conversation import ConversationCompletedError, ConversationNotExistsError
from services.errors.message import (
FirstMessageNotExistsError,
LastMessageNotExistsError,
MessageNotExistsError,
SuggestedQuestionsAfterAnswerDisabledError,
)
from services.workflow_service import WorkflowService
class MessageService:
    """Query, feedback and post-answer helpers for chat messages.

    All methods are classmethods operating on the shared DB session; callers
    identify themselves via an Account (console) or EndUser (API) so that
    message visibility and feedback attribution follow the request source.
    """

    @classmethod
    def pagination_by_first_id(
        cls,
        app_model: App,
        user: Optional[Union[Account, EndUser]],
        conversation_id: str,
        first_id: Optional[str],
        limit: int,
        order: str = "asc",
    ) -> InfiniteScrollPagination:
        """Page through one conversation's history, anchored at *first_id*.

        Returns up to *limit* messages created strictly before the anchor
        message (or the newest messages when no anchor is given).
        Raises FirstMessageNotExistsError when the anchor id is not found
        in the conversation.
        """
        # No user or no conversation id: nothing to paginate.
        if not user:
            return InfiniteScrollPagination(data=[], limit=limit, has_more=False)
        if not conversation_id:
            return InfiniteScrollPagination(data=[], limit=limit, has_more=False)
        # Resolves and authorizes the conversation for this user.
        conversation = ConversationService.get_conversation(
            app_model=app_model, user=user, conversation_id=conversation_id
        )
        # Fetch one extra row so we can tell whether another page exists.
        fetch_limit = limit + 1
        if first_id:
            first_message = (
                db.session.query(Message)
                .filter(Message.conversation_id == conversation.id, Message.id == first_id)
                .first()
            )
            if not first_message:
                raise FirstMessageNotExistsError()
            # Messages strictly older than the anchor, newest first.
            history_messages = (
                db.session.query(Message)
                .filter(
                    Message.conversation_id == conversation.id,
                    Message.created_at < first_message.created_at,
                    Message.id != first_message.id,
                )
                .order_by(Message.created_at.desc())
                .limit(fetch_limit)
                .all()
            )
        else:
            # No anchor: start from the newest messages of the conversation.
            history_messages = (
                db.session.query(Message)
                .filter(Message.conversation_id == conversation.id)
                .order_by(Message.created_at.desc())
                .limit(fetch_limit)
                .all()
            )
        has_more = False
        if len(history_messages) > limit:
            has_more = True
            # Drop the sentinel row that was only fetched to detect more pages.
            history_messages = history_messages[:-1]
        if order == "asc":
            # Query returns newest-first; flip to chronological order on demand.
            history_messages = list(reversed(history_messages))
        return InfiniteScrollPagination(data=history_messages, limit=limit, has_more=has_more)

    @classmethod
    def pagination_by_last_id(
        cls,
        app_model: App,
        user: Optional[Union[Account, EndUser]],
        last_id: Optional[str],
        limit: int,
        conversation_id: Optional[str] = None,
        include_ids: Optional[list] = None,
    ) -> InfiniteScrollPagination:
        """Page messages newest-first, anchored at *last_id*.

        Unlike pagination_by_first_id, the conversation is optional and the
        result may additionally be restricted to *include_ids*. Raises
        LastMessageNotExistsError when the anchor id does not match the
        filtered query. Result order is always newest-first.
        """
        if not user:
            return InfiniteScrollPagination(data=[], limit=limit, has_more=False)
        base_query = db.session.query(Message)
        # One extra row to detect whether more pages exist.
        fetch_limit = limit + 1
        if conversation_id is not None:
            conversation = ConversationService.get_conversation(
                app_model=app_model, user=user, conversation_id=conversation_id
            )
            base_query = base_query.filter(Message.conversation_id == conversation.id)
        if include_ids is not None:
            base_query = base_query.filter(Message.id.in_(include_ids))
        if last_id:
            last_message = base_query.filter(Message.id == last_id).first()
            if not last_message:
                raise LastMessageNotExistsError()
            # Messages strictly older than the anchor, newest first.
            history_messages = (
                base_query.filter(Message.created_at < last_message.created_at, Message.id != last_message.id)
                .order_by(Message.created_at.desc())
                .limit(fetch_limit)
                .all()
            )
        else:
            history_messages = base_query.order_by(Message.created_at.desc()).limit(fetch_limit).all()
        has_more = False
        if len(history_messages) > limit:
            has_more = True
            # Drop the sentinel row used only for has_more detection.
            history_messages = history_messages[:-1]
        return InfiniteScrollPagination(data=history_messages, limit=limit, has_more=has_more)

    @classmethod
    def create_feedback(
        cls,
        *,
        app_model: App,
        message_id: str,
        user: Optional[Union[Account, EndUser]],
        rating: Optional[str],
        content: Optional[str],
    ):
        """Create, update, or delete feedback on a message.

        Semantics by (rating, existing feedback):
        no rating + feedback -> delete; rating + feedback -> update;
        no rating + no feedback -> ValueError; rating + no feedback -> create.
        End users write user_feedback ("user"/"api" source), accounts write
        admin_feedback ("admin"/"console" source).
        """
        if not user:
            raise ValueError("user cannot be None")
        message = cls.get_message(app_model=app_model, user=user, message_id=message_id)
        # End users and console accounts have separate feedback slots.
        feedback = message.user_feedback if isinstance(user, EndUser) else message.admin_feedback
        if not rating and feedback:
            # Empty rating acts as "retract my feedback".
            db.session.delete(feedback)
        elif rating and feedback:
            feedback.rating = rating
            feedback.content = content
        elif not rating and not feedback:
            raise ValueError("rating cannot be None when feedback not exists")
        else:
            feedback = MessageFeedback(
                app_id=app_model.id,
                conversation_id=message.conversation_id,
                message_id=message.id,
                rating=rating,
                content=content,
                from_source=("user" if isinstance(user, EndUser) else "admin"),
                from_end_user_id=(user.id if isinstance(user, EndUser) else None),
                from_account_id=(user.id if isinstance(user, Account) else None),
            )
            db.session.add(feedback)
        db.session.commit()
        # NOTE(review): in the delete branch this returns the already-deleted
        # (detached) feedback object.
        return feedback

    @classmethod
    def get_message(cls, app_model: App, user: Optional[Union[Account, EndUser]], message_id: str):
        """Fetch one message visible to *user*, or raise MessageNotExistsError.

        Visibility is enforced by matching the message's source ("api" for
        end users, "console" for accounts) and the originating user/account id.
        """
        message = (
            db.session.query(Message)
            .filter(
                Message.id == message_id,
                Message.app_id == app_model.id,
                Message.from_source == ("api" if isinstance(user, EndUser) else "console"),
                Message.from_end_user_id == (user.id if isinstance(user, EndUser) else None),
                Message.from_account_id == (user.id if isinstance(user, Account) else None),
            )
            .first()
        )
        if not message:
            raise MessageNotExistsError()
        return message

    @classmethod
    def get_suggested_questions_after_answer(
        cls, app_model: App, user: Optional[Union[Account, EndUser]], message_id: str, invoke_from: InvokeFrom
    ) -> list[Message]:
        """Generate follow-up question suggestions for an answered message.

        Resolves the model to use (workflow default for advanced-chat apps,
        otherwise the app/conversation model config), feeds recent history to
        the LLM generator, and enqueues a trace task for observability.
        Raises when the conversation is missing/finished or when the
        suggested-questions feature is disabled for the app.
        """
        if not user:
            raise ValueError("user cannot be None")
        message = cls.get_message(app_model=app_model, user=user, message_id=message_id)
        conversation = ConversationService.get_conversation(
            app_model=app_model, conversation_id=message.conversation_id, user=user
        )
        if not conversation:
            raise ConversationNotExistsError()
        if conversation.status != "normal":
            raise ConversationCompletedError()
        model_manager = ModelManager()
        if app_model.mode == AppMode.ADVANCED_CHAT.value:
            # Advanced chat: the feature flag lives on the workflow's app config.
            workflow_service = WorkflowService()
            if invoke_from == InvokeFrom.DEBUGGER:
                workflow = workflow_service.get_draft_workflow(app_model=app_model)
            else:
                workflow = workflow_service.get_published_workflow(app_model=app_model)
            if workflow is None:
                return []
            app_config = AdvancedChatAppConfigManager.get_app_config(app_model=app_model, workflow=workflow)
            if not app_config.additional_features.suggested_questions_after_answer:
                raise SuggestedQuestionsAfterAnswerDisabledError()
            model_instance = model_manager.get_default_model_instance(
                tenant_id=app_model.tenant_id, model_type=ModelType.LLM
            )
        else:
            # Classic apps: config comes from the stored AppModelConfig, or the
            # conversation-level override serialized as JSON.
            if not conversation.override_model_configs:
                app_model_config = (
                    db.session.query(AppModelConfig)
                    .filter(
                        AppModelConfig.id == conversation.app_model_config_id, AppModelConfig.app_id == app_model.id
                    )
                    .first()
                )
            else:
                conversation_override_model_configs = json.loads(conversation.override_model_configs)
                app_model_config = AppModelConfig(
                    id=conversation.app_model_config_id,
                    app_id=app_model.id,
                )
                app_model_config = app_model_config.from_model_config_dict(conversation_override_model_configs)
            if not app_model_config:
                raise ValueError("did not find app model config")
            suggested_questions_after_answer = app_model_config.suggested_questions_after_answer_dict
            if suggested_questions_after_answer.get("enabled", False) is False:
                raise SuggestedQuestionsAfterAnswerDisabledError()
            model_instance = model_manager.get_model_instance(
                tenant_id=app_model.tenant_id,
                provider=app_model_config.model_dict["provider"],
                model_type=ModelType.LLM,
                model=app_model_config.model_dict["name"],
            )
        # get memory of conversation (read-only)
        memory = TokenBufferMemory(conversation=conversation, model_instance=model_instance)
        histories = memory.get_history_prompt_text(
            max_token_limit=3000,
            message_limit=3,
        )
        with measure_time() as timer:
            # NOTE(review): annotated list[Message] but LLMGenerator likely
            # returns plain question strings — confirm against LLMGenerator.
            questions: list[Message] = LLMGenerator.generate_suggested_questions_after_answer(
                tenant_id=app_model.tenant_id, histories=histories
            )
        # get tracing instance
        trace_manager = TraceQueueManager(app_id=app_model.id)
        trace_manager.add_trace_task(
            TraceTask(
                TraceTaskName.SUGGESTED_QUESTION_TRACE, message_id=message_id, suggested_question=questions, timer=timer
            )
        )
        return questions

View File

@@ -0,0 +1,571 @@
import datetime
import json
import logging
from json import JSONDecodeError
from typing import Optional, Union
from constants import HIDDEN_VALUE
from core.entities.provider_configuration import ProviderConfiguration
from core.helper import encrypter
from core.helper.model_provider_cache import ProviderCredentialsCache, ProviderCredentialsCacheType
from core.model_manager import LBModelManager
from core.model_runtime.entities.model_entities import ModelType
from core.model_runtime.entities.provider_entities import (
ModelCredentialSchema,
ProviderCredentialSchema,
)
from core.model_runtime.model_providers.model_provider_factory import ModelProviderFactory
from core.provider_manager import ProviderManager
from extensions.ext_database import db
from models.provider import LoadBalancingModelConfig
logger = logging.getLogger(__name__)
class ModelLoadBalancingService:
    """Manage per-model load-balancing credential configurations.

    Wraps ProviderManager to toggle load balancing for a provider model and
    to CRUD the LoadBalancingModelConfig rows, including the special
    "__inherit__" entry that reuses the provider/model custom credentials.
    Secret credential fields are stored encrypted and obfuscated on read.
    """

    def __init__(self) -> None:
        # Single entry point for reading workspace provider configurations.
        self.provider_manager = ProviderManager()

    def enable_model_load_balancing(self, tenant_id: str, provider: str, model: str, model_type: str) -> None:
        """
        enable model load balancing.
        :param tenant_id: workspace id
        :param provider: provider name
        :param model: model name
        :param model_type: model type
        :return:
        """
        # Get all provider configurations of the current workspace
        provider_configurations = self.provider_manager.get_configurations(tenant_id)
        # Get provider configuration
        provider_configuration = provider_configurations.get(provider)
        if not provider_configuration:
            raise ValueError(f"Provider {provider} does not exist.")
        # Enable model load balancing
        provider_configuration.enable_model_load_balancing(model=model, model_type=ModelType.value_of(model_type))

    def disable_model_load_balancing(self, tenant_id: str, provider: str, model: str, model_type: str) -> None:
        """
        disable model load balancing.
        :param tenant_id: workspace id
        :param provider: provider name
        :param model: model name
        :param model_type: model type
        :return:
        """
        # Get all provider configurations of the current workspace
        provider_configurations = self.provider_manager.get_configurations(tenant_id)
        # Get provider configuration
        provider_configuration = provider_configurations.get(provider)
        if not provider_configuration:
            raise ValueError(f"Provider {provider} does not exist.")
        # disable model load balancing
        provider_configuration.disable_model_load_balancing(model=model, model_type=ModelType.value_of(model_type))

    def get_load_balancing_configs(
        self, tenant_id: str, provider: str, model: str, model_type: str
    ) -> tuple[bool, list[dict]]:
        """
        Get load balancing configurations.
        :param tenant_id: workspace id
        :param provider: provider name
        :param model: model name
        :param model_type: model type
        :return: (load balancing enabled flag, list of config dicts with
                  obfuscated credentials plus cooldown status/ttl)
        """
        # Get all provider configurations of the current workspace
        provider_configurations = self.provider_manager.get_configurations(tenant_id)
        # Get provider configuration
        provider_configuration = provider_configurations.get(provider)
        if not provider_configuration:
            raise ValueError(f"Provider {provider} does not exist.")
        # Convert model type to ModelType
        model_type_enum = ModelType.value_of(model_type)
        # Get provider model setting
        provider_model_setting = provider_configuration.get_provider_model_setting(
            model_type=model_type_enum,
            model=model,
        )
        is_load_balancing_enabled = False
        if provider_model_setting and provider_model_setting.load_balancing_enabled:
            is_load_balancing_enabled = True
        # Get load balancing configurations
        load_balancing_configs = (
            db.session.query(LoadBalancingModelConfig)
            .filter(
                LoadBalancingModelConfig.tenant_id == tenant_id,
                LoadBalancingModelConfig.provider_name == provider_configuration.provider.provider,
                LoadBalancingModelConfig.model_type == model_type_enum.to_origin_model_type(),
                LoadBalancingModelConfig.model_name == model,
            )
            .order_by(LoadBalancingModelConfig.created_at)
            .all()
        )
        if provider_configuration.custom_configuration.provider:
            # check if the inherit configuration exists,
            # inherit is represented for the provider or model custom credentials
            inherit_config_exists = False
            for load_balancing_config in load_balancing_configs:
                if load_balancing_config.name == "__inherit__":
                    inherit_config_exists = True
                    break
            if not inherit_config_exists:
                # Initialize the inherit configuration
                inherit_config = self._init_inherit_config(tenant_id, provider, model, model_type_enum)
                # prepend the inherit configuration
                load_balancing_configs.insert(0, inherit_config)
            else:
                # move the inherit configuration to the first
                for i, load_balancing_config in enumerate(load_balancing_configs[:]):
                    if load_balancing_config.name == "__inherit__":
                        inherit_config = load_balancing_configs.pop(i)
                        load_balancing_configs.insert(0, inherit_config)
        # Get credential form schemas from model credential schema or provider credential schema
        credential_schemas = self._get_credential_schema(provider_configuration)
        # Get decoding rsa key and cipher for decrypting credentials
        decoding_rsa_key, decoding_cipher_rsa = encrypter.get_decrypt_decoding(tenant_id)
        # fetch status and ttl for each config
        datas = []
        for load_balancing_config in load_balancing_configs:
            in_cooldown, ttl = LBModelManager.get_config_in_cooldown_and_ttl(
                tenant_id=tenant_id,
                provider=provider,
                model=model,
                model_type=model_type_enum,
                config_id=load_balancing_config.id,
            )
            try:
                # Stored config is a JSON blob; missing/corrupt data degrades
                # to empty credentials instead of failing the whole listing.
                if load_balancing_config.encrypted_config:
                    credentials = json.loads(load_balancing_config.encrypted_config)
                else:
                    credentials = {}
            except JSONDecodeError:
                credentials = {}
            # Get provider credential secret variables
            credential_secret_variables = provider_configuration.extract_secret_variables(
                credential_schemas.credential_form_schemas
            )
            # decrypt credentials
            for variable in credential_secret_variables:
                if variable in credentials:
                    try:
                        credentials[variable] = encrypter.decrypt_token_with_decoding(
                            credentials.get(variable), decoding_rsa_key, decoding_cipher_rsa
                        )
                    except ValueError:
                        # Undecryptable value is left as stored.
                        pass
            # Obfuscate credentials
            credentials = provider_configuration.obfuscated_credentials(
                credentials=credentials, credential_form_schemas=credential_schemas.credential_form_schemas
            )
            datas.append(
                {
                    "id": load_balancing_config.id,
                    "name": load_balancing_config.name,
                    "credentials": credentials,
                    "enabled": load_balancing_config.enabled,
                    "in_cooldown": in_cooldown,
                    "ttl": ttl,
                }
            )
        return is_load_balancing_enabled, datas

    def get_load_balancing_config(
        self, tenant_id: str, provider: str, model: str, model_type: str, config_id: str
    ) -> Optional[dict]:
        """
        Get load balancing configuration.
        :param tenant_id: workspace id
        :param provider: provider name
        :param model: model name
        :param model_type: model type
        :param config_id: load balancing config id
        :return: config dict with obfuscated credentials, or None if not found
        """
        # Get all provider configurations of the current workspace
        provider_configurations = self.provider_manager.get_configurations(tenant_id)
        # Get provider configuration
        provider_configuration = provider_configurations.get(provider)
        if not provider_configuration:
            raise ValueError(f"Provider {provider} does not exist.")
        # Convert model type to ModelType
        model_type_enum = ModelType.value_of(model_type)
        # Get load balancing configurations
        load_balancing_model_config = (
            db.session.query(LoadBalancingModelConfig)
            .filter(
                LoadBalancingModelConfig.tenant_id == tenant_id,
                LoadBalancingModelConfig.provider_name == provider_configuration.provider.provider,
                LoadBalancingModelConfig.model_type == model_type_enum.to_origin_model_type(),
                LoadBalancingModelConfig.model_name == model,
                LoadBalancingModelConfig.id == config_id,
            )
            .first()
        )
        if not load_balancing_model_config:
            return None
        try:
            # Corrupt/empty stored config degrades to empty credentials.
            if load_balancing_model_config.encrypted_config:
                credentials = json.loads(load_balancing_model_config.encrypted_config)
            else:
                credentials = {}
        except JSONDecodeError:
            credentials = {}
        # Get credential form schemas from model credential schema or provider credential schema
        credential_schemas = self._get_credential_schema(provider_configuration)
        # Obfuscate credentials
        credentials = provider_configuration.obfuscated_credentials(
            credentials=credentials, credential_form_schemas=credential_schemas.credential_form_schemas
        )
        return {
            "id": load_balancing_model_config.id,
            "name": load_balancing_model_config.name,
            "credentials": credentials,
            "enabled": load_balancing_model_config.enabled,
        }

    def _init_inherit_config(
        self, tenant_id: str, provider: str, model: str, model_type: ModelType
    ) -> LoadBalancingModelConfig:
        """
        Initialize the inherit configuration.
        :param tenant_id: workspace id
        :param provider: provider name
        :param model: model name
        :param model_type: model type
        :return: the newly persisted "__inherit__" row
        """
        # Initialize the inherit configuration
        inherit_config = LoadBalancingModelConfig(
            tenant_id=tenant_id,
            provider_name=provider,
            model_type=model_type.to_origin_model_type(),
            model_name=model,
            name="__inherit__",
        )
        db.session.add(inherit_config)
        db.session.commit()
        return inherit_config

    def update_load_balancing_configs(
        self, tenant_id: str, provider: str, model: str, model_type: str, configs: list[dict]
    ) -> None:
        """
        Update load balancing configurations.

        Upserts the given configs (rows with an id are updated, rows without
        are created) and deletes any existing row not present in *configs*.
        :param tenant_id: workspace id
        :param provider: provider name
        :param model: model name
        :param model_type: model type
        :param configs: load balancing configs
        :return:
        """
        # Get all provider configurations of the current workspace
        provider_configurations = self.provider_manager.get_configurations(tenant_id)
        # Get provider configuration
        provider_configuration = provider_configurations.get(provider)
        if not provider_configuration:
            raise ValueError(f"Provider {provider} does not exist.")
        # Convert model type to ModelType
        model_type_enum = ModelType.value_of(model_type)
        if not isinstance(configs, list):
            raise ValueError("Invalid load balancing configs")
        current_load_balancing_configs = (
            db.session.query(LoadBalancingModelConfig)
            .filter(
                LoadBalancingModelConfig.tenant_id == tenant_id,
                LoadBalancingModelConfig.provider_name == provider_configuration.provider.provider,
                LoadBalancingModelConfig.model_type == model_type_enum.to_origin_model_type(),
                LoadBalancingModelConfig.model_name == model,
            )
            .all()
        )
        # id as key, config as value
        current_load_balancing_configs_dict = {config.id: config for config in current_load_balancing_configs}
        updated_config_ids = set()
        for config in configs:
            if not isinstance(config, dict):
                raise ValueError("Invalid load balancing config")
            config_id = config.get("id")
            name = config.get("name")
            credentials = config.get("credentials")
            enabled = config.get("enabled")
            if not name:
                raise ValueError("Invalid load balancing config name")
            if enabled is None:
                raise ValueError("Invalid load balancing config enabled")
            # is config exists
            if config_id:
                config_id = str(config_id)
                if config_id not in current_load_balancing_configs_dict:
                    raise ValueError("Invalid load balancing config id: {}".format(config_id))
                updated_config_ids.add(config_id)
                load_balancing_config = current_load_balancing_configs_dict[config_id]
                # check duplicate name
                for current_load_balancing_config in current_load_balancing_configs:
                    if current_load_balancing_config.id != config_id and current_load_balancing_config.name == name:
                        raise ValueError("Load balancing config name {} already exists".format(name))
                if credentials:
                    if not isinstance(credentials, dict):
                        raise ValueError("Invalid load balancing config credentials")
                    # validate custom provider config
                    # validate=False: encrypt/normalize only, skip the remote
                    # credential check during bulk update.
                    credentials = self._custom_credentials_validate(
                        tenant_id=tenant_id,
                        provider_configuration=provider_configuration,
                        model_type=model_type_enum,
                        model=model,
                        credentials=credentials,
                        load_balancing_model_config=load_balancing_config,
                        validate=False,
                    )
                    # update load balancing config
                    load_balancing_config.encrypted_config = json.dumps(credentials)
                load_balancing_config.name = name
                load_balancing_config.enabled = enabled
                load_balancing_config.updated_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None)
                db.session.commit()
                self._clear_credentials_cache(tenant_id, config_id)
            else:
                # create load balancing config
                if name == "__inherit__":
                    raise ValueError("Invalid load balancing config name")
                # check duplicate name
                for current_load_balancing_config in current_load_balancing_configs:
                    if current_load_balancing_config.name == name:
                        raise ValueError("Load balancing config name {} already exists".format(name))
                if not credentials:
                    raise ValueError("Invalid load balancing config credentials")
                if not isinstance(credentials, dict):
                    raise ValueError("Invalid load balancing config credentials")
                # validate custom provider config
                credentials = self._custom_credentials_validate(
                    tenant_id=tenant_id,
                    provider_configuration=provider_configuration,
                    model_type=model_type_enum,
                    model=model,
                    credentials=credentials,
                    validate=False,
                )
                # create load balancing config
                load_balancing_model_config = LoadBalancingModelConfig(
                    tenant_id=tenant_id,
                    provider_name=provider_configuration.provider.provider,
                    model_type=model_type_enum.to_origin_model_type(),
                    model_name=model,
                    name=name,
                    encrypted_config=json.dumps(credentials),
                )
                db.session.add(load_balancing_model_config)
                db.session.commit()
        # get deleted config ids
        deleted_config_ids = set(current_load_balancing_configs_dict.keys()) - updated_config_ids
        for config_id in deleted_config_ids:
            db.session.delete(current_load_balancing_configs_dict[config_id])
            db.session.commit()
            self._clear_credentials_cache(tenant_id, config_id)

    def validate_load_balancing_credentials(
        self,
        tenant_id: str,
        provider: str,
        model: str,
        model_type: str,
        credentials: dict,
        config_id: Optional[str] = None,
    ) -> None:
        """
        Validate load balancing credentials.
        :param tenant_id: workspace id
        :param provider: provider name
        :param model_type: model type
        :param model: model name
        :param credentials: credentials
        :param config_id: load balancing config id
        :return:
        """
        # Get all provider configurations of the current workspace
        provider_configurations = self.provider_manager.get_configurations(tenant_id)
        # Get provider configuration
        provider_configuration = provider_configurations.get(provider)
        if not provider_configuration:
            raise ValueError(f"Provider {provider} does not exist.")
        # Convert model type to ModelType
        model_type_enum = ModelType.value_of(model_type)
        load_balancing_model_config = None
        if config_id:
            # Get load balancing config
            load_balancing_model_config = (
                db.session.query(LoadBalancingModelConfig)
                .filter(
                    LoadBalancingModelConfig.tenant_id == tenant_id,
                    LoadBalancingModelConfig.provider_name == provider,
                    LoadBalancingModelConfig.model_type == model_type_enum.to_origin_model_type(),
                    LoadBalancingModelConfig.model_name == model,
                    LoadBalancingModelConfig.id == config_id,
                )
                .first()
            )
            if not load_balancing_model_config:
                raise ValueError(f"Load balancing config {config_id} does not exist.")
        # Validate custom provider config
        self._custom_credentials_validate(
            tenant_id=tenant_id,
            provider_configuration=provider_configuration,
            model_type=model_type_enum,
            model=model,
            credentials=credentials,
            load_balancing_model_config=load_balancing_model_config,
        )

    def _custom_credentials_validate(
        self,
        tenant_id: str,
        provider_configuration: ProviderConfiguration,
        model_type: ModelType,
        model: str,
        credentials: dict,
        load_balancing_model_config: Optional[LoadBalancingModelConfig] = None,
        validate: bool = True,
    ) -> dict:
        """
        Validate custom credentials.

        Restores [__HIDDEN__] placeholders from the stored config, optionally
        runs provider/model validation, then re-encrypts secret fields.
        :param tenant_id: workspace id
        :param provider_configuration: provider configuration
        :param model_type: model type
        :param model: model name
        :param credentials: credentials
        :param load_balancing_model_config: load balancing model config
        :param validate: validate credentials
        :return: credentials with secret fields encrypted
        """
        # Get credential form schemas from model credential schema or provider credential schema
        credential_schemas = self._get_credential_schema(provider_configuration)
        # Get provider credential secret variables
        provider_credential_secret_variables = provider_configuration.extract_secret_variables(
            credential_schemas.credential_form_schemas
        )
        if load_balancing_model_config:
            try:
                # fix origin data
                if load_balancing_model_config.encrypted_config:
                    original_credentials = json.loads(load_balancing_model_config.encrypted_config)
                else:
                    original_credentials = {}
            except JSONDecodeError:
                original_credentials = {}
            # encrypt credentials
            for key, value in credentials.items():
                if key in provider_credential_secret_variables:
                    # if send [__HIDDEN__] in secret input, it will be same as original value
                    if value == HIDDEN_VALUE and key in original_credentials:
                        credentials[key] = encrypter.decrypt_token(tenant_id, original_credentials[key])
        if validate:
            model_provider_factory = ModelProviderFactory(tenant_id)
            # Model-level schema validates against the concrete model;
            # otherwise fall back to provider-level validation.
            if isinstance(credential_schemas, ModelCredentialSchema):
                credentials = model_provider_factory.model_credentials_validate(
                    provider=provider_configuration.provider.provider,
                    model_type=model_type,
                    model=model,
                    credentials=credentials,
                )
            else:
                credentials = model_provider_factory.provider_credentials_validate(
                    provider=provider_configuration.provider.provider, credentials=credentials
                )
        for key, value in credentials.items():
            if key in provider_credential_secret_variables:
                credentials[key] = encrypter.encrypt_token(tenant_id, value)
        return credentials

    def _get_credential_schema(
        self, provider_configuration: ProviderConfiguration
    ) -> Union[ModelCredentialSchema, ProviderCredentialSchema]:
        """Get form schemas, preferring the model-level credential schema."""
        if provider_configuration.provider.model_credential_schema:
            return provider_configuration.provider.model_credential_schema
        elif provider_configuration.provider.provider_credential_schema:
            return provider_configuration.provider.provider_credential_schema
        else:
            raise ValueError("No credential schema found")

    def _clear_credentials_cache(self, tenant_id: str, config_id: str) -> None:
        """
        Clear credentials cache.
        :param tenant_id: workspace id
        :param config_id: load balancing config id
        :return:
        """
        provider_model_credentials_cache = ProviderCredentialsCache(
            tenant_id=tenant_id, identity_id=config_id, cache_type=ProviderCredentialsCacheType.LOAD_BALANCING_MODEL
        )
        provider_model_credentials_cache.delete()

View File

@@ -0,0 +1,481 @@
import logging
from typing import Optional
from core.entities.model_entities import ModelStatus, ModelWithProviderEntity, ProviderModelWithStatusEntity
from core.model_runtime.entities.model_entities import ModelType, ParameterRule
from core.model_runtime.model_providers.model_provider_factory import ModelProviderFactory
from core.provider_manager import ProviderManager
from models.provider import ProviderType
from services.entities.model_provider_entities import (
CustomConfigurationResponse,
CustomConfigurationStatus,
DefaultModelResponse,
ModelWithProviderEntityResponse,
ProviderResponse,
ProviderWithModelsResponse,
SimpleProviderEntityResponse,
SystemConfigurationResponse,
)
logger = logging.getLogger(__name__)
class ModelProviderService:
"""
Model Provider Service
"""
    def __init__(self) -> None:
        # Shared entry point for reading workspace provider configurations.
        self.provider_manager = ProviderManager()
def get_provider_list(self, tenant_id: str, model_type: Optional[str] = None) -> list[ProviderResponse]:
"""
get provider list.
:param tenant_id: workspace id
:param model_type: model type
:return:
"""
# Get all provider configurations of the current workspace
provider_configurations = self.provider_manager.get_configurations(tenant_id)
provider_responses = []
for provider_configuration in provider_configurations.values():
if model_type:
model_type_entity = ModelType.value_of(model_type)
if model_type_entity not in provider_configuration.provider.supported_model_types:
continue
provider_response = ProviderResponse(
tenant_id=tenant_id,
provider=provider_configuration.provider.provider,
label=provider_configuration.provider.label,
description=provider_configuration.provider.description,
icon_small=provider_configuration.provider.icon_small,
icon_large=provider_configuration.provider.icon_large,
background=provider_configuration.provider.background,
help=provider_configuration.provider.help,
supported_model_types=provider_configuration.provider.supported_model_types,
configurate_methods=provider_configuration.provider.configurate_methods,
provider_credential_schema=provider_configuration.provider.provider_credential_schema,
model_credential_schema=provider_configuration.provider.model_credential_schema,
preferred_provider_type=provider_configuration.preferred_provider_type,
custom_configuration=CustomConfigurationResponse(
status=CustomConfigurationStatus.ACTIVE
if provider_configuration.is_custom_configuration_available()
else CustomConfigurationStatus.NO_CONFIGURE
),
system_configuration=SystemConfigurationResponse(
enabled=provider_configuration.system_configuration.enabled,
current_quota_type=provider_configuration.system_configuration.current_quota_type,
quota_configurations=provider_configuration.system_configuration.quota_configurations,
),
)
provider_responses.append(provider_response)
return provider_responses
def get_models_by_provider(self, tenant_id: str, provider: str) -> list[ModelWithProviderEntityResponse]:
"""
get provider models.
For the model provider page,
only supports passing in a single provider to query the list of supported models.
:param tenant_id:
:param provider:
:return:
"""
# Get all provider configurations of the current workspace
provider_configurations = self.provider_manager.get_configurations(tenant_id)
# Get provider available models
return [
ModelWithProviderEntityResponse(tenant_id=tenant_id, model=model)
for model in provider_configurations.get_models(provider=provider)
]
def get_provider_credentials(self, tenant_id: str, provider: str) -> Optional[dict]:
"""
get provider credentials.
"""
provider_configurations = self.provider_manager.get_configurations(tenant_id)
provider_configuration = provider_configurations.get(provider)
if not provider_configuration:
raise ValueError(f"Provider {provider} does not exist.")
return provider_configuration.get_custom_credentials(obfuscated=True)
def provider_credentials_validate(self, tenant_id: str, provider: str, credentials: dict) -> None:
"""
validate provider credentials.
:param tenant_id:
:param provider:
:param credentials:
"""
# Get all provider configurations of the current workspace
provider_configurations = self.provider_manager.get_configurations(tenant_id)
# Get provider configuration
provider_configuration = provider_configurations.get(provider)
if not provider_configuration:
raise ValueError(f"Provider {provider} does not exist.")
provider_configuration.custom_credentials_validate(credentials)
def save_provider_credentials(self, tenant_id: str, provider: str, credentials: dict) -> None:
"""
save custom provider config.
:param tenant_id: workspace id
:param provider: provider name
:param credentials: provider credentials
:return:
"""
# Get all provider configurations of the current workspace
provider_configurations = self.provider_manager.get_configurations(tenant_id)
# Get provider configuration
provider_configuration = provider_configurations.get(provider)
if not provider_configuration:
raise ValueError(f"Provider {provider} does not exist.")
# Add or update custom provider credentials.
provider_configuration.add_or_update_custom_credentials(credentials)
def remove_provider_credentials(self, tenant_id: str, provider: str) -> None:
"""
remove custom provider config.
:param tenant_id: workspace id
:param provider: provider name
:return:
"""
# Get all provider configurations of the current workspace
provider_configurations = self.provider_manager.get_configurations(tenant_id)
# Get provider configuration
provider_configuration = provider_configurations.get(provider)
if not provider_configuration:
raise ValueError(f"Provider {provider} does not exist.")
# Remove custom provider credentials.
provider_configuration.delete_custom_credentials()
def get_model_credentials(self, tenant_id: str, provider: str, model_type: str, model: str) -> Optional[dict]:
"""
get model credentials.
:param tenant_id: workspace id
:param provider: provider name
:param model_type: model type
:param model: model name
:return:
"""
# Get all provider configurations of the current workspace
provider_configurations = self.provider_manager.get_configurations(tenant_id)
# Get provider configuration
provider_configuration = provider_configurations.get(provider)
if not provider_configuration:
raise ValueError(f"Provider {provider} does not exist.")
# Get model custom credentials from ProviderModel if exists
return provider_configuration.get_custom_model_credentials(
model_type=ModelType.value_of(model_type), model=model, obfuscated=True
)
def model_credentials_validate(
    self, tenant_id: str, provider: str, model_type: str, model: str, credentials: dict
) -> None:
    """
    Validate custom credentials for a specific model; raises on invalid credentials.

    :param tenant_id: workspace id
    :param provider: provider name
    :param model_type: model type
    :param model: model name
    :param credentials: credentials to check
    :raises ValueError: if the provider is unknown in this workspace
    """
    configurations = self.provider_manager.get_configurations(tenant_id)
    configuration = configurations.get(provider)
    if not configuration:
        raise ValueError(f"Provider {provider} does not exist.")
    configuration.custom_model_credentials_validate(
        model_type=ModelType.value_of(model_type), model=model, credentials=credentials
    )
def save_model_credentials(
    self, tenant_id: str, provider: str, model_type: str, model: str, credentials: dict
) -> None:
    """
    Save (create or update) custom credentials for a specific model.

    :param tenant_id: workspace id
    :param provider: provider name
    :param model_type: model type
    :param model: model name
    :param credentials: credentials to store
    :raises ValueError: if the provider is unknown in this workspace
    """
    configurations = self.provider_manager.get_configurations(tenant_id)
    configuration = configurations.get(provider)
    if not configuration:
        raise ValueError(f"Provider {provider} does not exist.")
    configuration.add_or_update_custom_model_credentials(
        model_type=ModelType.value_of(model_type), model=model, credentials=credentials
    )
def remove_model_credentials(self, tenant_id: str, provider: str, model_type: str, model: str) -> None:
    """
    Delete the custom credentials stored for a specific model.

    :param tenant_id: workspace id
    :param provider: provider name
    :param model_type: model type
    :param model: model name
    :raises ValueError: if the provider is unknown in this workspace
    """
    configurations = self.provider_manager.get_configurations(tenant_id)
    configuration = configurations.get(provider)
    if not configuration:
        raise ValueError(f"Provider {provider} does not exist.")
    configuration.delete_custom_model_credentials(model_type=ModelType.value_of(model_type), model=model)
def get_models_by_model_type(self, tenant_id: str, model_type: str) -> list[ProviderWithModelsResponse]:
    """
    List active (non-deprecated) models of the given type, grouped per provider.

    :param tenant_id: workspace id
    :param model_type: model type
    :return: one ProviderWithModelsResponse per provider that has usable models
    """
    provider_configurations = self.provider_manager.get_configurations(tenant_id)
    available_models = provider_configurations.get_models(model_type=ModelType.value_of(model_type))

    # Group models by provider name. The key is registered on first sight of the
    # provider (even for filtered-out models) so output ordering follows the
    # order providers first appear in `available_models`.
    grouped: dict[str, list[ModelWithProviderEntity]] = {}
    for candidate in available_models:
        bucket = grouped.setdefault(candidate.provider.provider, [])
        if candidate.deprecated:
            continue
        if candidate.status != ModelStatus.ACTIVE:
            continue
        bucket.append(candidate)

    # Build the response objects; providers left with no usable models are skipped.
    responses: list[ProviderWithModelsResponse] = []
    for provider_name, provider_model_list in grouped.items():
        if not provider_model_list:
            continue
        first_model = provider_model_list[0]
        responses.append(
            ProviderWithModelsResponse(
                tenant_id=tenant_id,
                provider=provider_name,
                label=first_model.provider.label,
                icon_small=first_model.provider.icon_small,
                icon_large=first_model.provider.icon_large,
                status=CustomConfigurationStatus.ACTIVE,
                models=[
                    ProviderModelWithStatusEntity(
                        model=entry.model,
                        label=entry.label,
                        model_type=entry.model_type,
                        features=entry.features,
                        fetch_from=entry.fetch_from,
                        model_properties=entry.model_properties,
                        status=entry.status,
                        load_balancing_enabled=entry.load_balancing_enabled,
                    )
                    for entry in provider_model_list
                ],
            )
        )
    return responses
def get_model_parameter_rules(self, tenant_id: str, provider: str, model: str) -> list[ParameterRule]:
    """
    Get the parameter rules of an LLM model (LLM only).

    :param tenant_id: workspace id
    :param provider: provider name
    :param model: model name
    :return: the model's parameter rules, or [] when credentials/schema are unavailable
    :raises ValueError: if the provider is unknown in this workspace
    """
    configurations = self.provider_manager.get_configurations(tenant_id)
    configuration = configurations.get(provider)
    if not configuration:
        raise ValueError(f"Provider {provider} does not exist.")

    # Without credentials the schema cannot be resolved — treat as "no rules".
    credentials = configuration.get_current_credentials(model_type=ModelType.LLM, model=model)
    if not credentials:
        return []

    model_schema = configuration.get_model_schema(model_type=ModelType.LLM, model=model, credentials=credentials)
    if not model_schema:
        return []
    return model_schema.parameter_rules
def get_default_model_of_model_type(self, tenant_id: str, model_type: str) -> Optional[DefaultModelResponse]:
    """
    Get the default model configured for a model type, if any.

    Lookup failures are treated as "no default" (logged at debug level) rather
    than propagated, since callers use this as a best-effort hint.

    :param tenant_id: workspace id
    :param model_type: model type
    :return: the default model description, or None when unset or lookup fails
    """
    model_type_enum = ModelType.value_of(model_type)
    try:
        result = self.provider_manager.get_default_model(tenant_id=tenant_id, model_type=model_type_enum)
        if not result:
            return None
        return DefaultModelResponse(
            model=result.model,
            model_type=result.model_type,
            provider=SimpleProviderEntityResponse(
                tenant_id=tenant_id,
                provider=result.provider.provider,
                label=result.provider.label,
                icon_small=result.provider.icon_small,
                icon_large=result.provider.icon_large,
                supported_model_types=result.provider.supported_model_types,
            ),
        )
    except Exception as e:
        # Lazy %-formatting: the message is only built if debug logging is enabled
        # (the original f-string was formatted eagerly on every failure).
        logger.debug("get_default_model_of_model_type error: %s", e)
        return None
def update_default_model_of_model_type(self, tenant_id: str, model_type: str, provider: str, model: str) -> None:
    """
    Set the workspace default model for a model type.

    :param tenant_id: workspace id
    :param model_type: model type
    :param provider: provider name
    :param model: model name
    """
    self.provider_manager.update_default_model_record(
        tenant_id=tenant_id,
        model_type=ModelType.value_of(model_type),
        provider=provider,
        model=model,
    )
def get_model_provider_icon(
    self, tenant_id: str, provider: str, icon_type: str, lang: str
) -> tuple[Optional[bytes], Optional[str]]:
    """
    Get a provider icon as raw bytes plus its MIME type.

    :param tenant_id: workspace id
    :param provider: provider name
    :param icon_type: icon type (icon_small or icon_large)
    :param lang: language (zh_Hans or en_US)
    :return: (icon bytes, mime type) — either may be None when unavailable
    """
    factory = ModelProviderFactory(tenant_id)
    return factory.get_provider_icon(provider, icon_type, lang)
def switch_preferred_provider(self, tenant_id: str, provider: str, preferred_provider_type: str) -> None:
    """
    Switch the preferred provider type (e.g. system vs. custom) for a provider.

    :param tenant_id: workspace id
    :param provider: provider name
    :param preferred_provider_type: preferred provider type
    :raises ValueError: if the provider is unknown in this workspace
    """
    configurations = self.provider_manager.get_configurations(tenant_id)
    # Validate/convert the type string before touching the configuration.
    preferred_type = ProviderType.value_of(preferred_provider_type)
    configuration = configurations.get(provider)
    if not configuration:
        raise ValueError(f"Provider {provider} does not exist.")
    configuration.switch_preferred_provider_type(preferred_type)
def enable_model(self, tenant_id: str, provider: str, model: str, model_type: str) -> None:
    """
    Enable a model for use within a workspace.

    :param tenant_id: workspace id
    :param provider: provider name
    :param model: model name
    :param model_type: model type
    :raises ValueError: if the provider is unknown in this workspace
    """
    configurations = self.provider_manager.get_configurations(tenant_id)
    configuration = configurations.get(provider)
    if not configuration:
        raise ValueError(f"Provider {provider} does not exist.")
    configuration.enable_model(model=model, model_type=ModelType.value_of(model_type))
def disable_model(self, tenant_id: str, provider: str, model: str, model_type: str) -> None:
    """
    disable model.
    :param tenant_id: workspace id
    :param provider: provider name
    :param model: model name
    :param model_type: model type
    :return:
    """
    # Get all provider configurations of the current workspace
    provider_configurations = self.provider_manager.get_configurations(tenant_id)

    # Get provider configuration
    provider_configuration = provider_configurations.get(provider)
    if not provider_configuration:
        raise ValueError(f"Provider {provider} does not exist.")

    # Disable model (the original comment said "Enable" — a copy/paste slip)
    provider_configuration.disable_model(model=model, model_type=ModelType.value_of(model_type))

View File

@@ -0,0 +1,23 @@
from typing import Optional
from core.moderation.factory import ModerationFactory, ModerationOutputsResult
from extensions.ext_database import db
from models.model import App, AppModelConfig
class ModerationService:
    """Runs output moderation based on an app's sensitive-word-avoidance config."""

    def moderation_for_outputs(self, app_id: str, app_model: App, text: str) -> ModerationOutputsResult:
        """
        Moderate model output text using the app's configured moderation provider.

        :param app_id: app id (used as the moderation scope)
        :param app_model: App ORM row whose app_model_config_id points at the settings
        :param text: model output text to check
        :return: the moderation result
        :raises ValueError: if the app has no model config row
        """
        # (Removed a dead `app_model_config = None` initialization that was
        # immediately overwritten by the query result.)
        app_model_config: Optional[AppModelConfig] = (
            db.session.query(AppModelConfig).filter(AppModelConfig.id == app_model.app_model_config_id).first()
        )
        if not app_model_config:
            raise ValueError("app model config not found")

        # sensitive_word_avoidance_dict carries {"type": <provider name>, "config": {...}}
        name = app_model_config.sensitive_word_avoidance_dict["type"]
        config = app_model_config.sensitive_word_avoidance_dict["config"]

        moderation = ModerationFactory(name, app_id, app_model.tenant_id, config)
        return moderation.moderation_for_outputs(text)

View File

@@ -0,0 +1,29 @@
import os
import requests
class OperationService:
    """Thin HTTP client for the billing/operations API."""

    # NOTE(review): the fallbacks are literal placeholder strings, not usable
    # URLs/keys — requests will fail unless the env vars are set. Confirm intent.
    base_url = os.environ.get("BILLING_API_URL", "BILLING_API_URL")
    secret_key = os.environ.get("BILLING_API_SECRET_KEY", "BILLING_API_SECRET_KEY")

    @classmethod
    def _send_request(cls, method, endpoint, json=None, params=None):
        """
        Send an authenticated request to the billing API and return the decoded JSON body.

        :param method: HTTP method, e.g. "POST"
        :param endpoint: path appended to base_url, e.g. "/tenant_utms"
        :param json: optional JSON request body
        :param params: optional query parameters
        """
        headers = {"Content-Type": "application/json", "Billing-Api-Secret-Key": cls.secret_key}
        url = f"{cls.base_url}{endpoint}"
        # Without a timeout a stuck billing API would hang the caller forever;
        # (connect, read) timeouts keep the failure bounded.
        response = requests.request(method, url, json=json, params=params, headers=headers, timeout=(10, 60))
        return response.json()

    @classmethod
    def record_utm(cls, tenant_id: str, utm_info: dict):
        """
        Record UTM attribution for a tenant; missing UTM fields default to "".
        """
        params = {
            "tenant_id": tenant_id,
            "utm_source": utm_info.get("utm_source", ""),
            "utm_medium": utm_info.get("utm_medium", ""),
            "utm_campaign": utm_info.get("utm_campaign", ""),
            "utm_content": utm_info.get("utm_content", ""),
            "utm_term": utm_info.get("utm_term", ""),
        }
        # NOTE(review): data is sent as query params on a POST — confirm that is
        # what the billing endpoint expects (vs. a JSON body).
        return cls._send_request("POST", "/tenant_utms", params=params)

View File

@@ -0,0 +1,199 @@
from typing import Optional
from core.ops.ops_trace_manager import OpsTraceManager, provider_config_map
from extensions.ext_database import db
from models.model import App, TraceAppConfig
class OpsService:
    """CRUD for per-app tracing (observability) provider configuration."""

    @classmethod
    def get_tracing_app_config(cls, app_id: str, tracing_provider: str):
        """
        Get tracing app config
        :param app_id: app id
        :param tracing_provider: tracing provider
        :return: config dict with decrypted-then-obfuscated credentials and a
            best-effort ``project_url``, or None when no config / app row exists
        """
        trace_config_data: Optional[TraceAppConfig] = (
            db.session.query(TraceAppConfig)
            .filter(TraceAppConfig.app_id == app_id, TraceAppConfig.tracing_provider == tracing_provider)
            .first()
        )

        if not trace_config_data:
            return None

        # decrypt_token and obfuscated_token
        # NOTE(review): the variable is named `tenant` but holds an App row; the
        # tenant id is read off the app.
        tenant = db.session.query(App).filter(App.id == app_id).first()
        if not tenant:
            return None
        tenant_id = tenant.tenant_id
        decrypt_tracing_config = OpsTraceManager.decrypt_tracing_config(
            tenant_id, tracing_provider, trace_config_data.tracing_config
        )

        new_decrypt_tracing_config = OpsTraceManager.obfuscated_decrypt_token(tracing_provider, decrypt_tracing_config)

        # Backfill a provider-specific project_url when the stored config lacks
        # one; on any failure fall back to the provider's base URL.
        if tracing_provider == "langfuse" and (
            "project_key" not in decrypt_tracing_config or not decrypt_tracing_config.get("project_key")
        ):
            try:
                project_key = OpsTraceManager.get_trace_config_project_key(decrypt_tracing_config, tracing_provider)
                new_decrypt_tracing_config.update(
                    {
                        "project_url": "{host}/project/{key}".format(
                            host=decrypt_tracing_config.get("host"), key=project_key
                        )
                    }
                )
            except Exception:
                new_decrypt_tracing_config.update(
                    {"project_url": "{host}/".format(host=decrypt_tracing_config.get("host"))}
                )

        if tracing_provider == "langsmith" and (
            "project_url" not in decrypt_tracing_config or not decrypt_tracing_config.get("project_url")
        ):
            try:
                project_url = OpsTraceManager.get_trace_config_project_url(decrypt_tracing_config, tracing_provider)
                new_decrypt_tracing_config.update({"project_url": project_url})
            except Exception:
                new_decrypt_tracing_config.update({"project_url": "https://smith.langchain.com/"})

        if tracing_provider == "opik" and (
            "project_url" not in decrypt_tracing_config or not decrypt_tracing_config.get("project_url")
        ):
            try:
                project_url = OpsTraceManager.get_trace_config_project_url(decrypt_tracing_config, tracing_provider)
                new_decrypt_tracing_config.update({"project_url": project_url})
            except Exception:
                new_decrypt_tracing_config.update({"project_url": "https://www.comet.com/opik/"})

        # NOTE(review): mutates the ORM object in-session without a commit; the
        # dict returned below is what callers consume — confirm the session is
        # never flushed elsewhere with the obfuscated config.
        trace_config_data.tracing_config = new_decrypt_tracing_config
        return trace_config_data.to_dict()

    @classmethod
    def create_tracing_app_config(cls, app_id: str, tracing_provider: str, tracing_config: dict):
        """
        Create tracing app config
        :param app_id: app id
        :param tracing_provider: tracing provider
        :param tracing_config: tracing config
        :return: {"result": "success"}, an {"error": ...} dict on validation
            failure, or None when the config already exists / app row is missing
        """
        # NOTE(review): with a falsy tracing_provider this guard is skipped and
        # the provider_config_map lookup below raises KeyError — confirm callers
        # never pass "".
        if tracing_provider not in provider_config_map and tracing_provider:
            return {"error": f"Invalid tracing provider: {tracing_provider}"}

        config_class, other_keys = (
            provider_config_map[tracing_provider]["config_class"],
            provider_config_map[tracing_provider]["other_keys"],
        )
        # FIXME: ignore type error
        default_config_instance = config_class(**tracing_config)  # type: ignore
        # Backfill empty-string fields from the provider's default config values.
        for key in other_keys:  # type: ignore
            if key in tracing_config and tracing_config[key] == "":
                tracing_config[key] = getattr(default_config_instance, key, None)

        # api check
        if not OpsTraceManager.check_trace_config_is_effective(tracing_config, tracing_provider):
            return {"error": "Invalid Credentials"}

        # get project url
        if tracing_provider == "langfuse":
            project_key = OpsTraceManager.get_trace_config_project_key(tracing_config, tracing_provider)
            project_url = "{host}/project/{key}".format(host=tracing_config.get("host"), key=project_key)
        elif tracing_provider in ("langsmith", "opik"):
            project_url = OpsTraceManager.get_trace_config_project_url(tracing_config, tracing_provider)
        else:
            project_url = None

        # check if trace config already exists
        trace_config_data: Optional[TraceAppConfig] = (
            db.session.query(TraceAppConfig)
            .filter(TraceAppConfig.app_id == app_id, TraceAppConfig.tracing_provider == tracing_provider)
            .first()
        )

        if trace_config_data:
            return None

        # get tenant id
        tenant = db.session.query(App).filter(App.id == app_id).first()
        if not tenant:
            return None
        tenant_id = tenant.tenant_id
        # Credentials are encrypted at rest; project_url is stored alongside.
        tracing_config = OpsTraceManager.encrypt_tracing_config(tenant_id, tracing_provider, tracing_config)
        if project_url:
            tracing_config["project_url"] = project_url
        trace_config_data = TraceAppConfig(
            app_id=app_id,
            tracing_provider=tracing_provider,
            tracing_config=tracing_config,
        )
        db.session.add(trace_config_data)
        db.session.commit()

        return {"result": "success"}

    @classmethod
    def update_tracing_app_config(cls, app_id: str, tracing_provider: str, tracing_config: dict):
        """
        Update tracing app config
        :param app_id: app id
        :param tracing_provider: tracing provider
        :param tracing_config: tracing config
        :return: the updated config dict, or None when no existing config / app row
        :raises ValueError: on unknown provider or ineffective credentials
        """
        if tracing_provider not in provider_config_map:
            raise ValueError(f"Invalid tracing provider: {tracing_provider}")

        # check if trace config already exists
        current_trace_config = (
            db.session.query(TraceAppConfig)
            .filter(TraceAppConfig.app_id == app_id, TraceAppConfig.tracing_provider == tracing_provider)
            .first()
        )

        if not current_trace_config:
            return None

        # get tenant id
        tenant = db.session.query(App).filter(App.id == app_id).first()
        if not tenant:
            return None
        tenant_id = tenant.tenant_id
        # Re-encrypt, merging with the previously stored config (so unchanged
        # secret fields survive the update).
        tracing_config = OpsTraceManager.encrypt_tracing_config(
            tenant_id, tracing_provider, tracing_config, current_trace_config.tracing_config
        )

        # api check
        # decrypt_token
        decrypt_tracing_config = OpsTraceManager.decrypt_tracing_config(tenant_id, tracing_provider, tracing_config)
        if not OpsTraceManager.check_trace_config_is_effective(decrypt_tracing_config, tracing_provider):
            raise ValueError("Invalid Credentials")

        current_trace_config.tracing_config = tracing_config
        db.session.commit()

        return current_trace_config.to_dict()

    @classmethod
    def delete_tracing_app_config(cls, app_id: str, tracing_provider: str):
        """
        Delete tracing app config
        :param app_id: app id
        :param tracing_provider: tracing provider
        :return: True on deletion, None when no config existed
        """
        trace_config = (
            db.session.query(TraceAppConfig)
            .filter(TraceAppConfig.app_id == app_id, TraceAppConfig.tracing_provider == tracing_provider)
            .first()
        )

        if not trace_config:
            return None

        db.session.delete(trace_config)
        db.session.commit()

        return True

View File

@@ -0,0 +1,185 @@
import json
import logging
import click
from core.entities import DEFAULT_PLUGIN_ID
from models.engine import db
logger = logging.getLogger(__name__)
class PluginDataMigration:
    """Rewrites legacy provider names into plugin-scoped names
    (``{DEFAULT_PLUGIN_ID}/{provider}/{provider}``) across all provider-bearing tables."""

    @classmethod
    def migrate(cls) -> None:
        """Run the full migration over every table that stores provider names."""
        cls.migrate_db_records("providers", "provider_name")  # large table
        cls.migrate_db_records("provider_models", "provider_name")
        cls.migrate_db_records("provider_orders", "provider_name")
        cls.migrate_db_records("tenant_default_models", "provider_name")
        cls.migrate_db_records("tenant_preferred_model_providers", "provider_name")
        cls.migrate_db_records("provider_model_settings", "provider_name")
        cls.migrate_db_records("load_balancing_model_configs", "provider_name")
        cls.migrate_datasets()
        cls.migrate_db_records("embeddings", "provider_name")  # large table
        cls.migrate_db_records("dataset_collection_bindings", "provider_name")
        cls.migrate_db_records("tool_builtin_providers", "provider")

    @classmethod
    def migrate_datasets(cls) -> None:
        """
        Migrate the datasets table: rewrites embedding_model_provider and, when
        present, the reranking provider embedded in the retrieval_model JSON.
        """
        table_name = "datasets"
        provider_column_name = "embedding_model_provider"

        click.echo(click.style(f"Migrating [{table_name}] data for plugin", fg="white"))

        processed_count = 0
        failed_ids = []
        while True:
            # Rows already migrated contain a "/" and no longer match this filter,
            # so each batch picks up the next 1000 unmigrated rows.
            sql = f"""select id, {provider_column_name} as provider_name, retrieval_model from {table_name}
where {provider_column_name} not like '%/%' and {provider_column_name} is not null and {provider_column_name} != ''
limit 1000"""
            with db.engine.begin() as conn:
                rs = conn.execute(db.text(sql))

                current_iter_count = 0
                for i in rs:
                    record_id = str(i.id)
                    provider_name = str(i.provider_name)
                    retrieval_model = i.retrieval_model
                    # (Removed a leftover debug `print(type(retrieval_model))`.)

                    if record_id in failed_ids:
                        # Don't count skipped failures: if a batch is only failed
                        # rows, current_iter_count stays 0 and the loop ends.
                        continue

                    retrieval_model_changed = False
                    if retrieval_model:
                        if (
                            "reranking_model" in retrieval_model
                            and "reranking_provider_name" in retrieval_model["reranking_model"]
                            and retrieval_model["reranking_model"]["reranking_provider_name"]
                            and "/" not in retrieval_model["reranking_model"]["reranking_provider_name"]
                        ):
                            click.echo(
                                click.style(
                                    f"[{processed_count}] Migrating {table_name} {record_id} "
                                    f"(reranking_provider_name: "
                                    f"{retrieval_model['reranking_model']['reranking_provider_name']})",
                                    fg="white",
                                )
                            )
                            retrieval_model["reranking_model"]["reranking_provider_name"] = (
                                f"{DEFAULT_PLUGIN_ID}/{retrieval_model['reranking_model']['reranking_provider_name']}/{retrieval_model['reranking_model']['reranking_provider_name']}"
                            )
                            retrieval_model_changed = True

                    click.echo(
                        click.style(
                            f"[{processed_count}] Migrating [{table_name}] {record_id} ({provider_name})",
                            fg="white",
                        )
                    )

                    try:
                        # update provider name append with "langgenius/{provider_name}/{provider_name}"
                        params = {"record_id": record_id}
                        update_retrieval_model_sql = ""
                        if retrieval_model and retrieval_model_changed:
                            update_retrieval_model_sql = ", retrieval_model = :retrieval_model"
                            params["retrieval_model"] = json.dumps(retrieval_model)

                        sql = f"""update {table_name}
set {provider_column_name} =
    concat('{DEFAULT_PLUGIN_ID}/', {provider_column_name}, '/', {provider_column_name})
    {update_retrieval_model_sql}
where id = :record_id"""
                        conn.execute(db.text(sql), params)
                        click.echo(
                            click.style(
                                f"[{processed_count}] Migrated [{table_name}] {record_id} ({provider_name})",
                                fg="green",
                            )
                        )
                    except Exception:
                        failed_ids.append(record_id)
                        click.echo(
                            click.style(
                                f"[{processed_count}] Failed to migrate [{table_name}] {record_id} ({provider_name})",
                                fg="red",
                            )
                        )
                        logger.exception(
                            f"[{processed_count}] Failed to migrate [{table_name}] {record_id} ({provider_name})"
                        )
                        continue

                    current_iter_count += 1
                    processed_count += 1

            if not current_iter_count:
                break

        click.echo(
            click.style(f"Migrate [{table_name}] data for plugin completed, total: {processed_count}", fg="green")
        )

    @classmethod
    def migrate_db_records(cls, table_name: str, provider_column_name: str) -> None:
        """
        Migrate one table's provider-name column in batches of 1000.

        :param table_name: table to rewrite
        :param provider_column_name: column holding the legacy provider name
        """
        click.echo(click.style(f"Migrating [{table_name}] data for plugin", fg="white"))

        processed_count = 0
        failed_ids = []
        while True:
            sql = f"""select id, {provider_column_name} as provider_name from {table_name}
where {provider_column_name} not like '%/%' and {provider_column_name} is not null and {provider_column_name} != ''
limit 1000"""
            with db.engine.begin() as conn:
                rs = conn.execute(db.text(sql))

                current_iter_count = 0
                for i in rs:
                    record_id = str(i.id)
                    provider_name = str(i.provider_name)

                    if record_id in failed_ids:
                        # BUGFIX: previously the counters were incremented BEFORE
                        # this skip. Failed rows still match the WHERE clause, so a
                        # batch consisting only of failed rows kept
                        # current_iter_count > 0 and the while-loop never
                        # terminated. Counting only attempted rows lets the loop
                        # exit once nothing migratable remains.
                        continue

                    current_iter_count += 1
                    processed_count += 1

                    click.echo(
                        click.style(
                            f"[{processed_count}] Migrating [{table_name}] {record_id} ({provider_name})",
                            fg="white",
                        )
                    )

                    try:
                        # update provider name append with "langgenius/{provider_name}/{provider_name}"
                        sql = f"""update {table_name}
set {provider_column_name} =
    concat('{DEFAULT_PLUGIN_ID}/', {provider_column_name}, '/', {provider_column_name})
where id = :record_id"""
                        conn.execute(db.text(sql), {"record_id": record_id})
                        click.echo(
                            click.style(
                                f"[{processed_count}] Migrated [{table_name}] {record_id} ({provider_name})",
                                fg="green",
                            )
                        )
                    except Exception:
                        failed_ids.append(record_id)
                        click.echo(
                            click.style(
                                f"[{processed_count}] Failed to migrate [{table_name}] {record_id} ({provider_name})",
                                fg="red",
                            )
                        )
                        logger.exception(
                            f"[{processed_count}] Failed to migrate [{table_name}] {record_id} ({provider_name})"
                        )
                        continue

            if not current_iter_count:
                break

        click.echo(
            click.style(f"Migrate [{table_name}] data for plugin completed, total: {processed_count}", fg="green")
        )

View File

@@ -0,0 +1,121 @@
from core.helper import marketplace
from core.plugin.entities.plugin import ModelProviderID, PluginDependency, PluginInstallationSource, ToolProviderID
from core.plugin.manager.plugin import PluginInstallationManager
class DependenciesAnalysisService:
    """Resolves plugin dependencies for tools/model providers and builds dependency lists."""

    @classmethod
    def analyze_tool_dependency(cls, tool_id: str) -> str:
        """
        Analyze the dependency of a tool.
        Convert the tool id to the plugin_id
        """
        # The original wrapped this in `try: ... except Exception as e: raise e`,
        # which only re-raised the same exception; let it propagate naturally.
        return ToolProviderID(tool_id).plugin_id

    @classmethod
    def analyze_model_provider_dependency(cls, model_provider_id: str) -> str:
        """
        Analyze the dependency of a model provider.
        Convert the model provider id to the plugin_id
        """
        return ModelProviderID(model_provider_id).plugin_id

    @classmethod
    def get_leaked_dependencies(cls, tenant_id: str, dependencies: list[PluginDependency]) -> list[PluginDependency]:
        """
        Check dependencies, returns the leaked dependencies in current workspace

        :param tenant_id: workspace id
        :param dependencies: dependencies declared by an app/DSL
        :return: dependencies that are NOT installed in the workspace, annotated
            with the current identifier reported by the plugin daemon
        """
        required_plugin_unique_identifiers = [dep.value.plugin_unique_identifier for dep in dependencies]

        manager = PluginInstallationManager()

        # get leaked dependencies
        missing_plugins = manager.fetch_missing_dependencies(tenant_id, required_plugin_unique_identifiers)
        missing_plugin_unique_identifiers = {plugin.plugin_unique_identifier: plugin for plugin in missing_plugins}

        leaked_dependencies = []
        for dependency in dependencies:
            unique_identifier = dependency.value.plugin_unique_identifier
            if unique_identifier in missing_plugin_unique_identifiers:
                leaked_dependencies.append(
                    PluginDependency(
                        type=dependency.type,
                        value=dependency.value,
                        current_identifier=missing_plugin_unique_identifiers[unique_identifier].current_identifier,
                    )
                )

        return leaked_dependencies

    @classmethod
    def generate_dependencies(cls, tenant_id: str, dependencies: list[str]) -> list[PluginDependency]:
        """
        Generate dependencies through the list of plugin ids

        :raises ValueError: for remotely-installed plugins (not exportable) or
            an unknown installation source
        """
        # De-duplicate plugin ids before querying installations.
        dependencies = list(set(dependencies))
        manager = PluginInstallationManager()
        plugins = manager.fetch_plugin_installation_by_ids(tenant_id, dependencies)
        result = []
        for plugin in plugins:
            if plugin.source == PluginInstallationSource.Github:
                result.append(
                    PluginDependency(
                        type=PluginDependency.Type.Github,
                        value=PluginDependency.Github(
                            repo=plugin.meta["repo"],
                            version=plugin.meta["version"],
                            package=plugin.meta["package"],
                            github_plugin_unique_identifier=plugin.plugin_unique_identifier,
                        ),
                    )
                )
            elif plugin.source == PluginInstallationSource.Marketplace:
                result.append(
                    PluginDependency(
                        type=PluginDependency.Type.Marketplace,
                        value=PluginDependency.Marketplace(
                            marketplace_plugin_unique_identifier=plugin.plugin_unique_identifier
                        ),
                    )
                )
            elif plugin.source == PluginInstallationSource.Package:
                result.append(
                    PluginDependency(
                        type=PluginDependency.Type.Package,
                        value=PluginDependency.Package(plugin_unique_identifier=plugin.plugin_unique_identifier),
                    )
                )
            elif plugin.source == PluginInstallationSource.Remote:
                raise ValueError(
                    f"You used a remote plugin: {plugin.plugin_unique_identifier} in the app, please remove it first"
                    " if you want to export the DSL."
                )
            else:
                raise ValueError(f"Unknown plugin source: {plugin.source}")

        return result

    @classmethod
    def generate_latest_dependencies(cls, dependencies: list[str]) -> list[PluginDependency]:
        """
        Generate the latest version of dependencies
        """
        dependencies = list(set(dependencies))
        deps = marketplace.batch_fetch_plugin_manifests(dependencies)
        return [
            PluginDependency(
                type=PluginDependency.Type.Marketplace,
                value=PluginDependency.Marketplace(marketplace_plugin_unique_identifier=dep.latest_package_identifier),
            )
            for dep in deps
        ]

View File

@@ -0,0 +1,66 @@
from core.plugin.manager.endpoint import PluginEndpointManager
class EndpointService:
    """Facade over PluginEndpointManager for endpoint CRUD and activation."""

    @classmethod
    def create_endpoint(cls, tenant_id: str, user_id: str, plugin_unique_identifier: str, name: str, settings: dict):
        """Create a plugin endpoint for the given tenant/user."""
        manager = PluginEndpointManager()
        return manager.create_endpoint(
            tenant_id=tenant_id,
            user_id=user_id,
            plugin_unique_identifier=plugin_unique_identifier,
            name=name,
            settings=settings,
        )

    @classmethod
    def list_endpoints(cls, tenant_id: str, user_id: str, page: int, page_size: int):
        """List endpoints for a tenant, paginated."""
        manager = PluginEndpointManager()
        return manager.list_endpoints(tenant_id=tenant_id, user_id=user_id, page=page, page_size=page_size)

    @classmethod
    def list_endpoints_for_single_plugin(cls, tenant_id: str, user_id: str, plugin_id: str, page: int, page_size: int):
        """List endpoints of one plugin for a tenant, paginated."""
        manager = PluginEndpointManager()
        return manager.list_endpoints_for_single_plugin(
            tenant_id=tenant_id,
            user_id=user_id,
            plugin_id=plugin_id,
            page=page,
            page_size=page_size,
        )

    @classmethod
    def update_endpoint(cls, tenant_id: str, user_id: str, endpoint_id: str, name: str, settings: dict):
        """Update an endpoint's name and settings."""
        manager = PluginEndpointManager()
        return manager.update_endpoint(
            tenant_id=tenant_id,
            user_id=user_id,
            endpoint_id=endpoint_id,
            name=name,
            settings=settings,
        )

    @classmethod
    def delete_endpoint(cls, tenant_id: str, user_id: str, endpoint_id: str):
        """Delete an endpoint."""
        manager = PluginEndpointManager()
        return manager.delete_endpoint(tenant_id=tenant_id, user_id=user_id, endpoint_id=endpoint_id)

    @classmethod
    def enable_endpoint(cls, tenant_id: str, user_id: str, endpoint_id: str):
        """Enable an endpoint."""
        manager = PluginEndpointManager()
        return manager.enable_endpoint(tenant_id=tenant_id, user_id=user_id, endpoint_id=endpoint_id)

    @classmethod
    def disable_endpoint(cls, tenant_id: str, user_id: str, endpoint_id: str):
        """Disable an endpoint."""
        manager = PluginEndpointManager()
        return manager.disable_endpoint(tenant_id=tenant_id, user_id=user_id, endpoint_id=endpoint_id)

View File

@@ -0,0 +1,505 @@
import datetime
import json
import logging
import time
from collections.abc import Mapping, Sequence
from concurrent.futures import ThreadPoolExecutor
from pathlib import Path
from typing import Any, Optional
from uuid import uuid4
import click
import tqdm
from flask import Flask, current_app
from sqlalchemy.orm import Session
from core.agent.entities import AgentToolEntity
from core.helper import marketplace
from core.plugin.entities.plugin import ModelProviderID, PluginInstallationSource, ToolProviderID
from core.plugin.entities.plugin_daemon import PluginInstallTaskStatus
from core.plugin.manager.plugin import PluginInstallationManager
from core.tools.entities.tool_entities import ToolProviderType
from models.account import Tenant
from models.engine import db
from models.model import App, AppMode, AppModelConfig
from models.tools import BuiltinToolProvider
from models.workflow import Workflow
logger = logging.getLogger(__name__)
excluded_providers = ["time", "audio", "code", "webscraper"]
class PluginMigration:
    @classmethod
    def extract_plugins(cls, filepath: str, workers: int) -> None:
        """
        Migrate plugin.

        Walks tenants in creation-time batches (dynamically sized to ~100 tenants
        per batch), extracts each tenant's installed plugin ids on a thread pool,
        and appends one JSON line per tenant to `filepath`.
        """
        from threading import Lock

        click.echo(click.style("Migrating models/tools to new plugin Mechanism", fg="white"))
        # Scan window: from a fixed epoch up to "now".
        ended_at = datetime.datetime.now()
        started_at = datetime.datetime(2023, 4, 3, 8, 59, 24)
        current_time = started_at

        with Session(db.engine) as session:
            total_tenant_count = session.query(Tenant.id).count()
            click.echo(click.style(f"Total tenant count: {total_tenant_count}", fg="white"))

            handled_tenant_count = 0
            file_lock = Lock()
            counter_lock = Lock()

            thread_pool = ThreadPoolExecutor(max_workers=workers)

            def process_tenant(flask_app: Flask, tenant_id: str) -> None:
                # Runs on a worker thread; needs its own Flask app context for DB access.
                with flask_app.app_context():
                    nonlocal handled_tenant_count
                    try:
                        plugins = cls.extract_installed_plugin_ids(tenant_id)
                        # Use lock when writing to file
                        with file_lock:
                            with open(filepath, "a") as f:
                                f.write(json.dumps({"tenant_id": tenant_id, "plugins": plugins}) + "\n")

                        # Use lock when updating counter
                        with counter_lock:
                            # NOTE(review): duplicate `nonlocal` declaration — harmless
                            # but redundant with the one at the top of this function.
                            nonlocal handled_tenant_count
                            handled_tenant_count += 1
                            click.echo(
                                click.style(
                                    f"[{datetime.datetime.now()}] "
                                    f"Processed {handled_tenant_count} tenants "
                                    f"({(handled_tenant_count / total_tenant_count) * 100:.1f}%), "
                                    f"{handled_tenant_count}/{total_tenant_count}",
                                    fg="green",
                                )
                            )
                    except Exception:
                        logger.exception(f"Failed to process tenant {tenant_id}")

            futures = []

            while current_time < ended_at:
                # NOTE(review): the log labels datetime.now() as "Started at" — misleading.
                click.echo(click.style(f"Current time: {current_time}, Started at: {datetime.datetime.now()}", fg="white"))
                # Initial interval of 1 day, will be dynamically adjusted based on tenant count
                interval = datetime.timedelta(days=1)
                # Process tenants in this batch
                with Session(db.engine) as session:
                    # Calculate tenant count in next batch with current interval
                    # Try different intervals until we find one with a reasonable tenant count
                    test_intervals = [
                        datetime.timedelta(days=1),
                        datetime.timedelta(hours=12),
                        datetime.timedelta(hours=6),
                        datetime.timedelta(hours=3),
                        datetime.timedelta(hours=1),
                    ]

                    for test_interval in test_intervals:
                        tenant_count = (
                            session.query(Tenant.id)
                            .filter(Tenant.created_at.between(current_time, current_time + test_interval))
                            .count()
                        )
                        if tenant_count <= 100:
                            interval = test_interval
                            break
                    else:
                        # If all intervals have too many tenants, use minimum interval
                        interval = datetime.timedelta(hours=1)

                    # Adjust interval to target ~100 tenants per batch
                    if tenant_count > 0:
                        # Scale interval based on ratio to target count
                        interval = min(
                            datetime.timedelta(days=1),  # Max 1 day
                            max(
                                datetime.timedelta(hours=1),  # Min 1 hour
                                interval * (100 / tenant_count),  # Scale to target 100
                            ),
                        )

                    batch_end = min(current_time + interval, ended_at)

                    rs = (
                        session.query(Tenant.id)
                        .filter(Tenant.created_at.between(current_time, batch_end))
                        .order_by(Tenant.created_at)
                    )

                    # NOTE(review): `tenants` is appended to but never read afterwards;
                    # submission happens per-row below.
                    tenants = []
                    for row in rs:
                        tenant_id = str(row.id)
                        try:
                            tenants.append(tenant_id)
                        except Exception:
                            logger.exception(f"Failed to process tenant {tenant_id}")
                            continue

                        futures.append(
                            thread_pool.submit(
                                process_tenant,
                                current_app._get_current_object(),  # type: ignore[attr-defined]
                                tenant_id,
                            )
                        )

                current_time = batch_end

            # wait for all threads to finish
            # NOTE(review): `futures` accumulates across ALL batches, so memory grows
            # with total tenant count; results are only joined at the very end.
            for future in futures:
                future.result()
@classmethod
def extract_installed_plugin_ids(cls, tenant_id: str) -> Sequence[str]:
"""
Extract installed plugin ids.
"""
tools = cls.extract_tool_tables(tenant_id)
models = cls.extract_model_tables(tenant_id)
workflows = cls.extract_workflow_tables(tenant_id)
apps = cls.extract_app_tables(tenant_id)
return list({*tools, *models, *workflows, *apps})
@classmethod
def extract_model_tables(cls, tenant_id: str) -> Sequence[str]:
"""
Extract model tables.
"""
models: list[str] = []
table_pairs = [
("providers", "provider_name"),
("provider_models", "provider_name"),
("provider_orders", "provider_name"),
("tenant_default_models", "provider_name"),
("tenant_preferred_model_providers", "provider_name"),
("provider_model_settings", "provider_name"),
("load_balancing_model_configs", "provider_name"),
]
for table, column in table_pairs:
models.extend(cls.extract_model_table(tenant_id, table, column))
# duplicate models
models = list(set(models))
return models
@classmethod
def extract_model_table(cls, tenant_id: str, table: str, column: str) -> Sequence[str]:
"""
Extract model table.
"""
with Session(db.engine) as session:
rs = session.execute(
db.text(f"SELECT DISTINCT {column} FROM {table} WHERE tenant_id = :tenant_id"), {"tenant_id": tenant_id}
)
result = []
for row in rs:
provider_name = str(row[0])
result.append(ModelProviderID(provider_name).plugin_id)
return result
@classmethod
def extract_tool_tables(cls, tenant_id: str) -> Sequence[str]:
"""
Extract tool tables.
"""
with Session(db.engine) as session:
rs = session.query(BuiltinToolProvider).filter(BuiltinToolProvider.tenant_id == tenant_id).all()
result = []
for row in rs:
result.append(ToolProviderID(row.provider).plugin_id)
return result
@classmethod
def extract_workflow_tables(cls, tenant_id: str) -> Sequence[str]:
"""
Extract workflow tables, only ToolNode is required.
"""
with Session(db.engine) as session:
rs = session.query(Workflow).filter(Workflow.tenant_id == tenant_id).all()
result = []
for row in rs:
graph = row.graph_dict
# get nodes
nodes = graph.get("nodes", [])
for node in nodes:
data = node.get("data", {})
if data.get("type") == "tool":
provider_name = data.get("provider_name")
provider_type = data.get("provider_type")
if provider_name not in excluded_providers and provider_type == ToolProviderType.BUILT_IN.value:
result.append(ToolProviderID(provider_name).plugin_id)
return result
    @classmethod
    def extract_app_tables(cls, tenant_id: str) -> Sequence[str]:
        """
        Extract plugin ids referenced by agent tools configured on the tenant's apps.

        Only agent-mode apps are inspected; builtin tool providers outside the
        excluded list contribute their plugin id to the result.
        """
        with Session(db.engine) as session:
            apps = session.query(App).filter(App.tenant_id == tenant_id).all()
            if not apps:
                return []
            # Only agent apps carry tool configuration in their model config.
            agent_app_model_config_ids = [
                app.app_model_config_id for app in apps if app.is_agent or app.mode == AppMode.AGENT_CHAT.value
            ]
            rs = session.query(AppModelConfig).filter(AppModelConfig.id.in_(agent_app_model_config_ids)).all()
            result = []
            for row in rs:
                agent_config = row.agent_mode_dict
                if "tools" in agent_config and isinstance(agent_config["tools"], list):
                    for tool in agent_config["tools"]:
                        # Tool entries are expected to be dicts; anything else is skipped.
                        if isinstance(tool, dict):
                            try:
                                tool_entity = AgentToolEntity(**tool)
                                # Only builtin providers map to installable plugins.
                                if (
                                    tool_entity.provider_type == ToolProviderType.BUILT_IN.value
                                    and tool_entity.provider_id not in excluded_providers
                                ):
                                    result.append(ToolProviderID(tool_entity.provider_id).plugin_id)
                            except Exception:
                                # A malformed tool entry must not abort the whole scan.
                                logger.exception(f"Failed to process tool {tool}")
                                continue
            return result
@classmethod
def _fetch_plugin_unique_identifier(cls, plugin_id: str) -> Optional[str]:
"""
Fetch plugin unique identifier using plugin id.
"""
plugin_manifest = marketplace.batch_fetch_plugin_manifests([plugin_id])
if not plugin_manifest:
return None
return plugin_manifest[0].latest_package_identifier
@classmethod
def extract_unique_plugins_to_file(cls, extracted_plugins: str, output_file: str) -> None:
"""
Extract unique plugins.
"""
Path(output_file).write_text(json.dumps(cls.extract_unique_plugins(extracted_plugins)))
@classmethod
def extract_unique_plugins(cls, extracted_plugins: str) -> Mapping[str, Any]:
plugins: dict[str, str] = {}
plugin_ids = []
plugin_not_exist = []
logger.info(f"Extracting unique plugins from {extracted_plugins}")
with open(extracted_plugins) as f:
for line in f:
data = json.loads(line)
new_plugin_ids = data.get("plugins", [])
for plugin_id in new_plugin_ids:
if plugin_id not in plugin_ids:
plugin_ids.append(plugin_id)
def fetch_plugin(plugin_id):
try:
unique_identifier = cls._fetch_plugin_unique_identifier(plugin_id)
if unique_identifier:
plugins[plugin_id] = unique_identifier
else:
plugin_not_exist.append(plugin_id)
except Exception:
logger.exception(f"Failed to fetch plugin unique identifier for {plugin_id}")
plugin_not_exist.append(plugin_id)
with ThreadPoolExecutor(max_workers=10) as executor:
list(tqdm.tqdm(executor.map(fetch_plugin, plugin_ids), total=len(plugin_ids)))
return {"plugins": plugins, "plugin_not_exist": plugin_not_exist}
@classmethod
def install_plugins(cls, extracted_plugins: str, output_file: str, workers: int = 100) -> None:
"""
Install plugins.
"""
manager = PluginInstallationManager()
plugins = cls.extract_unique_plugins(extracted_plugins)
not_installed = []
plugin_install_failed = []
# use a fake tenant id to install all the plugins
fake_tenant_id = uuid4().hex
logger.info(f"Installing {len(plugins['plugins'])} plugin instances for fake tenant {fake_tenant_id}")
thread_pool = ThreadPoolExecutor(max_workers=workers)
response = cls.handle_plugin_instance_install(fake_tenant_id, plugins["plugins"])
if response.get("failed"):
plugin_install_failed.extend(response.get("failed", []))
def install(tenant_id: str, plugin_ids: list[str]) -> None:
logger.info(f"Installing {len(plugin_ids)} plugins for tenant {tenant_id}")
# fetch plugin already installed
installed_plugins = manager.list_plugins(tenant_id)
installed_plugins_ids = [plugin.plugin_id for plugin in installed_plugins]
logger.info(f"Installing installed_plugins_ids:")
logger.info(installed_plugins_ids)
# at most 64 plugins one batch
for i in range(0, len(plugin_ids), 64):
batch_plugin_ids = plugin_ids[i : i + 64]
batch_plugin_identifiers = [
plugins["plugins"][plugin_id]
for plugin_id in batch_plugin_ids
if plugin_id not in installed_plugins_ids and plugin_id in plugins["plugins"]
]
manager.install_from_identifiers(
tenant_id,
batch_plugin_identifiers,
PluginInstallationSource.Marketplace,
metas=[
{
"plugin_unique_identifier": identifier,
}
for identifier in batch_plugin_identifiers
],
)
with open(extracted_plugins) as f:
"""
Read line by line, and install plugins for each tenant.
"""
for line in f:
data = json.loads(line)
tenant_id = data.get("tenant_id")
plugin_ids = data.get("plugins", [])
current_not_installed = {
"tenant_id": tenant_id,
"plugin_not_exist": [],
}
# get plugin unique identifier
for plugin_id in plugin_ids:
unique_identifier = plugins.get(plugin_id)
if unique_identifier:
current_not_installed["plugin_not_exist"].append(plugin_id)
if current_not_installed["plugin_not_exist"]:
not_installed.append(current_not_installed)
thread_pool.submit(install, tenant_id, plugin_ids)
thread_pool.shutdown(wait=True)
logger.info("Uninstall plugins")
# get installation
try:
installation = manager.list_plugins(fake_tenant_id)
while installation:
for plugin in installation:
manager.uninstall(fake_tenant_id, plugin.installation_id)
installation = manager.list_plugins(fake_tenant_id)
except Exception:
logger.exception(f"Failed to get installation for tenant {fake_tenant_id}")
Path(output_file).write_text(
json.dumps(
{
"not_installed": not_installed,
"plugin_install_failed": plugin_install_failed,
}
)
)
@classmethod
def handle_plugin_instance_install(
cls, tenant_id: str, plugin_identifiers_map: Mapping[str, str]
) -> Mapping[str, Any]:
"""
Install plugins for a tenant.
"""
manager = PluginInstallationManager()
# download all the plugins and upload
thread_pool = ThreadPoolExecutor(max_workers=10)
futures = []
for plugin_id, plugin_identifier in plugin_identifiers_map.items():
def download_and_upload(tenant_id, plugin_id, plugin_identifier):
plugin_package = marketplace.download_plugin_pkg(plugin_identifier)
if not plugin_package:
raise Exception(f"Failed to download plugin {plugin_identifier}")
# upload
manager.upload_pkg(tenant_id, plugin_package, verify_signature=True)
futures.append(thread_pool.submit(download_and_upload, tenant_id, plugin_id, plugin_identifier))
# Wait for all downloads to complete
for future in futures:
future.result() # This will raise any exceptions that occurred
thread_pool.shutdown(wait=True)
success = []
failed = []
reverse_map = {v: k for k, v in plugin_identifiers_map.items()}
# at most 8 plugins one batch
for i in range(0, len(plugin_identifiers_map), 8):
batch_plugin_ids = list(plugin_identifiers_map.keys())[i : i + 8]
batch_plugin_identifiers = [plugin_identifiers_map[plugin_id] for plugin_id in batch_plugin_ids]
try:
response = manager.install_from_identifiers(
tenant_id=tenant_id,
identifiers=batch_plugin_identifiers,
source=PluginInstallationSource.Marketplace,
metas=[
{
"plugin_unique_identifier": identifier,
}
for identifier in batch_plugin_identifiers
],
)
except Exception:
# add to failed
failed.extend(batch_plugin_identifiers)
continue
if response.all_installed:
success.extend(batch_plugin_identifiers)
continue
task_id = response.task_id
done = False
while not done:
status = manager.fetch_plugin_installation_task(tenant_id, task_id)
if status.status in [PluginInstallTaskStatus.Failed, PluginInstallTaskStatus.Success]:
for plugin in status.plugins:
if plugin.status == PluginInstallTaskStatus.Success:
success.append(reverse_map[plugin.plugin_unique_identifier])
else:
failed.append(reverse_map[plugin.plugin_unique_identifier])
logger.error(
f"Failed to install plugin {plugin.plugin_unique_identifier}, error: {plugin.message}"
)
done = True
else:
time.sleep(1)
return {"success": success, "failed": failed}

View File

@@ -0,0 +1,34 @@
from sqlalchemy.orm import Session
from extensions.ext_database import db
from models.account import TenantPluginPermission
class PluginPermissionService:
    """CRUD helpers for per-tenant plugin install/debug permissions."""

    @staticmethod
    def get_permission(tenant_id: str) -> TenantPluginPermission | None:
        """Return the tenant's plugin permission record, or None when unset."""
        with Session(db.engine) as session:
            query = session.query(TenantPluginPermission)
            return query.filter(TenantPluginPermission.tenant_id == tenant_id).first()

    @staticmethod
    def change_permission(
        tenant_id: str,
        install_permission: TenantPluginPermission.InstallPermission,
        debug_permission: TenantPluginPermission.DebugPermission,
    ):
        """Create or update the tenant's plugin permission record. Always returns True."""
        with Session(db.engine) as session:
            existing = (
                session.query(TenantPluginPermission).filter(TenantPluginPermission.tenant_id == tenant_id).first()
            )
            if not existing:
                session.add(
                    TenantPluginPermission(
                        tenant_id=tenant_id,
                        install_permission=install_permission,
                        debug_permission=debug_permission,
                    )
                )
            else:
                existing.install_permission = install_permission
                existing.debug_permission = debug_permission
            session.commit()
        return True

View File

@@ -0,0 +1,364 @@
import logging
from collections.abc import Mapping, Sequence
from mimetypes import guess_type
from typing import Optional
from pydantic import BaseModel
from configs import dify_config
from core.helper import marketplace
from core.helper.download import download_with_size_limit
from core.helper.marketplace import download_plugin_pkg
from core.plugin.entities.bundle import PluginBundleDependency
from core.plugin.entities.plugin import (
GenericProviderID,
PluginDeclaration,
PluginEntity,
PluginInstallation,
PluginInstallationSource,
)
from core.plugin.entities.plugin_daemon import PluginInstallTask, PluginUploadResponse
from core.plugin.manager.asset import PluginAssetManager
from core.plugin.manager.debugging import PluginDebuggingManager
from core.plugin.manager.plugin import PluginInstallationManager
from extensions.ext_redis import redis_client
logger = logging.getLogger(__name__)
class PluginService:
    """Application-facing operations for plugin management.

    Wraps the plugin daemon managers (installation, assets, debugging) and the
    marketplace helpers, and caches latest-version lookups in Redis.
    """

    class LatestPluginCache(BaseModel):
        # Cached "latest release" info for one marketplace plugin.
        plugin_id: str
        version: str
        unique_identifier: str

    # Redis key prefix and TTL for the latest-version cache entries.
    REDIS_KEY_PREFIX = "plugin_service:latest_plugin:"
    REDIS_TTL = 60 * 5  # 5 minutes

    @staticmethod
    def fetch_latest_plugin_version(plugin_ids: Sequence[str]) -> Mapping[str, Optional[LatestPluginCache]]:
        """
        Fetch the latest plugin version for each id, Redis-cached.

        Cache misses are resolved in one marketplace batch call; ids the
        marketplace does not know map to None. On any error the partial result
        gathered so far is returned.
        """
        result: dict[str, Optional[PluginService.LatestPluginCache]] = {}
        try:
            cache_not_exists = []
            # Try to get from Redis first
            for plugin_id in plugin_ids:
                cached_data = redis_client.get(f"{PluginService.REDIS_KEY_PREFIX}{plugin_id}")
                if cached_data:
                    result[plugin_id] = PluginService.LatestPluginCache.model_validate_json(cached_data)
                else:
                    cache_not_exists.append(plugin_id)
            if cache_not_exists:
                manifests = {
                    manifest.plugin_id: manifest
                    for manifest in marketplace.batch_fetch_plugin_manifests(cache_not_exists)
                }
                for plugin_id, manifest in manifests.items():
                    latest_plugin = PluginService.LatestPluginCache(
                        plugin_id=plugin_id,
                        version=manifest.latest_version,
                        unique_identifier=manifest.latest_package_identifier,
                    )
                    # Store in Redis
                    redis_client.setex(
                        f"{PluginService.REDIS_KEY_PREFIX}{plugin_id}",
                        PluginService.REDIS_TTL,
                        latest_plugin.model_dump_json(),
                    )
                    result[plugin_id] = latest_plugin
                    # pop plugin_id from cache_not_exists
                    cache_not_exists.remove(plugin_id)
                # Anything still unresolved is unknown to the marketplace.
                for plugin_id in cache_not_exists:
                    result[plugin_id] = None
            return result
        except Exception:
            logger.exception("failed to fetch latest plugin version")
            return result

    @staticmethod
    def get_debugging_key(tenant_id: str) -> str:
        """
        get the debugging key of the tenant
        """
        manager = PluginDebuggingManager()
        return manager.get_debugging_key(tenant_id)

    @staticmethod
    def list(tenant_id: str) -> list[PluginEntity]:
        """
        list all plugins of the tenant, annotating marketplace-sourced plugins
        with their latest available version/identifier (best effort).
        """
        manager = PluginInstallationManager()
        plugins = manager.list_plugins(tenant_id)
        plugin_ids = [plugin.plugin_id for plugin in plugins if plugin.source == PluginInstallationSource.Marketplace]
        try:
            manifests = PluginService.fetch_latest_plugin_version(plugin_ids)
        except Exception:
            # Version annotation is optional; listing must still succeed.
            manifests = {}
            logger.exception("failed to fetch plugin manifests")
        for plugin in plugins:
            if plugin.source == PluginInstallationSource.Marketplace:
                if plugin.plugin_id in manifests:
                    latest_plugin_cache = manifests[plugin.plugin_id]
                    if latest_plugin_cache:
                        # set latest_version
                        plugin.latest_version = latest_plugin_cache.version
                        plugin.latest_unique_identifier = latest_plugin_cache.unique_identifier
        return plugins

    @staticmethod
    def list_installations_from_ids(tenant_id: str, ids: Sequence[str]) -> Sequence[PluginInstallation]:
        """
        List plugin installations from ids
        """
        manager = PluginInstallationManager()
        return manager.fetch_plugin_installation_by_ids(tenant_id, ids)

    @staticmethod
    def get_asset(tenant_id: str, asset_file: str) -> tuple[bytes, str]:
        """
        get the asset file of the plugin

        :return: (file bytes, guessed mime type — octet-stream when unknown)
        """
        manager = PluginAssetManager()
        # guess mime type
        mime_type, _ = guess_type(asset_file)
        return manager.fetch_asset(tenant_id, asset_file), mime_type or "application/octet-stream"

    @staticmethod
    def check_plugin_unique_identifier(tenant_id: str, plugin_unique_identifier: str) -> bool:
        """
        check if the plugin unique identifier is already installed by other tenant
        """
        manager = PluginInstallationManager()
        return manager.fetch_plugin_by_identifier(tenant_id, plugin_unique_identifier)

    @staticmethod
    def fetch_plugin_manifest(tenant_id: str, plugin_unique_identifier: str) -> PluginDeclaration:
        """
        Fetch plugin manifest
        """
        manager = PluginInstallationManager()
        return manager.fetch_plugin_manifest(tenant_id, plugin_unique_identifier)

    @staticmethod
    def fetch_install_tasks(tenant_id: str, page: int, page_size: int) -> Sequence[PluginInstallTask]:
        """
        Fetch plugin installation tasks (paginated)
        """
        manager = PluginInstallationManager()
        return manager.fetch_plugin_installation_tasks(tenant_id, page, page_size)

    @staticmethod
    def fetch_install_task(tenant_id: str, task_id: str) -> PluginInstallTask:
        """Fetch a single plugin installation task by id."""
        manager = PluginInstallationManager()
        return manager.fetch_plugin_installation_task(tenant_id, task_id)

    @staticmethod
    def delete_install_task(tenant_id: str, task_id: str) -> bool:
        """
        Delete a plugin installation task
        """
        manager = PluginInstallationManager()
        return manager.delete_plugin_installation_task(tenant_id, task_id)

    @staticmethod
    def delete_all_install_task_items(
        tenant_id: str,
    ) -> bool:
        """
        Delete all plugin installation task items
        """
        manager = PluginInstallationManager()
        return manager.delete_all_plugin_installation_task_items(tenant_id)

    @staticmethod
    def delete_install_task_item(tenant_id: str, task_id: str, identifier: str) -> bool:
        """
        Delete a plugin installation task item
        """
        manager = PluginInstallationManager()
        return manager.delete_plugin_installation_task_item(tenant_id, task_id, identifier)

    @staticmethod
    def upgrade_plugin_with_marketplace(
        tenant_id: str, original_plugin_unique_identifier: str, new_plugin_unique_identifier: str
    ):
        """
        Upgrade plugin with marketplace

        Ensures the new package is present on the daemon (downloading it if
        needed), then delegates the upgrade.
        """
        if original_plugin_unique_identifier == new_plugin_unique_identifier:
            raise ValueError("you should not upgrade plugin with the same plugin")
        # check if plugin pkg is already downloaded
        manager = PluginInstallationManager()
        try:
            manager.fetch_plugin_manifest(tenant_id, new_plugin_unique_identifier)
            # already downloaded, skip, and record install event
            marketplace.record_install_plugin_event(new_plugin_unique_identifier)
        except Exception:
            # plugin not installed, download and upload pkg
            pkg = download_plugin_pkg(new_plugin_unique_identifier)
            manager.upload_pkg(tenant_id, pkg, verify_signature=False)
        return manager.upgrade_plugin(
            tenant_id,
            original_plugin_unique_identifier,
            new_plugin_unique_identifier,
            PluginInstallationSource.Marketplace,
            {
                "plugin_unique_identifier": new_plugin_unique_identifier,
            },
        )

    @staticmethod
    def upgrade_plugin_with_github(
        tenant_id: str,
        original_plugin_unique_identifier: str,
        new_plugin_unique_identifier: str,
        repo: str,
        version: str,
        package: str,
    ):
        """
        Upgrade plugin with github
        """
        manager = PluginInstallationManager()
        return manager.upgrade_plugin(
            tenant_id,
            original_plugin_unique_identifier,
            new_plugin_unique_identifier,
            PluginInstallationSource.Github,
            {
                "repo": repo,
                "version": version,
                "package": package,
            },
        )

    @staticmethod
    def upload_pkg(tenant_id: str, pkg: bytes, verify_signature: bool = False) -> PluginUploadResponse:
        """
        Upload plugin package files
        returns: plugin_unique_identifier
        """
        manager = PluginInstallationManager()
        return manager.upload_pkg(tenant_id, pkg, verify_signature)

    @staticmethod
    def upload_pkg_from_github(
        tenant_id: str, repo: str, version: str, package: str, verify_signature: bool = False
    ) -> PluginUploadResponse:
        """
        Install plugin from github release package files,
        returns plugin_unique_identifier
        """
        # Download is capped at PLUGIN_MAX_PACKAGE_SIZE to bound memory use.
        pkg = download_with_size_limit(
            f"https://github.com/{repo}/releases/download/{version}/{package}", dify_config.PLUGIN_MAX_PACKAGE_SIZE
        )
        manager = PluginInstallationManager()
        return manager.upload_pkg(
            tenant_id,
            pkg,
            verify_signature,
        )

    @staticmethod
    def upload_bundle(
        tenant_id: str, bundle: bytes, verify_signature: bool = False
    ) -> Sequence[PluginBundleDependency]:
        """
        Upload a plugin bundle and return the dependencies.
        """
        manager = PluginInstallationManager()
        return manager.upload_bundle(tenant_id, bundle, verify_signature)

    @staticmethod
    def install_from_local_pkg(tenant_id: str, plugin_unique_identifiers: Sequence[str]):
        """Install previously uploaded local packages by their unique identifiers."""
        manager = PluginInstallationManager()
        return manager.install_from_identifiers(
            tenant_id,
            plugin_unique_identifiers,
            PluginInstallationSource.Package,
            [{}],
        )

    @staticmethod
    def install_from_github(tenant_id: str, plugin_unique_identifier: str, repo: str, version: str, package: str):
        """
        Install plugin from github release package files,
        returns plugin_unique_identifier
        """
        manager = PluginInstallationManager()
        return manager.install_from_identifiers(
            tenant_id,
            [plugin_unique_identifier],
            PluginInstallationSource.Github,
            [
                {
                    "repo": repo,
                    "version": version,
                    "package": package,
                }
            ],
        )

    @staticmethod
    def install_from_marketplace_pkg(
        tenant_id: str, plugin_unique_identifiers: Sequence[str], verify_signature: bool = False
    ):
        """
        Install plugin from marketplace package files,
        returns installation task id
        """
        manager = PluginInstallationManager()
        # check if already downloaded
        for plugin_unique_identifier in plugin_unique_identifiers:
            try:
                manager.fetch_plugin_manifest(tenant_id, plugin_unique_identifier)
                # already downloaded, skip
            except Exception:
                # plugin not installed, download and upload pkg
                pkg = download_plugin_pkg(plugin_unique_identifier)
                manager.upload_pkg(tenant_id, pkg, verify_signature)
        return manager.install_from_identifiers(
            tenant_id,
            plugin_unique_identifiers,
            PluginInstallationSource.Marketplace,
            [
                {
                    "plugin_unique_identifier": plugin_unique_identifier,
                }
                for plugin_unique_identifier in plugin_unique_identifiers
            ],
        )

    @staticmethod
    def uninstall(tenant_id: str, plugin_installation_id: str) -> bool:
        """Uninstall a plugin by its installation id."""
        manager = PluginInstallationManager()
        return manager.uninstall(tenant_id, plugin_installation_id)

    @staticmethod
    def check_tools_existence(tenant_id: str, provider_ids: Sequence[GenericProviderID]) -> Sequence[bool]:
        """
        Check if the tools exist
        """
        manager = PluginInstallationManager()
        return manager.check_tools_existence(tenant_id, provider_ids)

View File

@@ -0,0 +1,64 @@
import json
from os import path
from pathlib import Path
from typing import Optional
from flask import current_app
from services.recommend_app.recommend_app_base import RecommendAppRetrievalBase
from services.recommend_app.recommend_app_type import RecommendAppType
class BuildInRecommendAppRetrieval(RecommendAppRetrievalBase):
    """
    Retrieval recommended app from buildin, the location is constants/recommended_apps.json
    """

    # Lazily-loaded, process-wide cache of the bundled JSON payload.
    builtin_data: Optional[dict] = None

    def get_type(self) -> str:
        return RecommendAppType.BUILDIN

    def get_recommended_apps_and_categories(self, language: str) -> dict:
        return self.fetch_recommended_apps_from_builtin(language)

    def get_recommend_app_detail(self, app_id: str):
        return self.fetch_recommended_app_detail_from_builtin(app_id)

    @classmethod
    def _get_builtin_data(cls) -> dict:
        """
        Load and cache the bundled recommended-apps JSON payload.
        :return: parsed JSON dict
        """
        if cls.builtin_data:
            return cls.builtin_data
        json_file = Path(path.join(current_app.root_path, "constants", "recommended_apps.json"))
        cls.builtin_data = json.loads(json_file.read_text(encoding="utf-8"))
        return cls.builtin_data or {}

    @classmethod
    def fetch_recommended_apps_from_builtin(cls, language: str) -> dict:
        """
        Return the builtin recommended apps for a language (empty dict when absent).
        :param language: language
        """
        data: dict[str, dict[str, dict]] = cls._get_builtin_data()
        return data.get("recommended_apps", {}).get(language, {})

    @classmethod
    def fetch_recommended_app_detail_from_builtin(cls, app_id: str) -> Optional[dict]:
        """
        Return the builtin detail payload for an app id, or None when unknown.
        :param app_id: App ID
        """
        data: dict[str, dict[str, dict]] = cls._get_builtin_data()
        return data.get("app_details", {}).get(app_id)

View File

@@ -0,0 +1,105 @@
from typing import Optional
from constants.languages import languages
from extensions.ext_database import db
from models.model import App, RecommendedApp
from services.app_dsl_service import AppDslService
from services.recommend_app.recommend_app_base import RecommendAppRetrievalBase
from services.recommend_app.recommend_app_type import RecommendAppType
class DatabaseRecommendAppRetrieval(RecommendAppRetrievalBase):
    """
    Retrieval recommended app from database
    """

    def get_recommended_apps_and_categories(self, language: str) -> dict:
        return self.fetch_recommended_apps_from_db(language)

    def get_recommend_app_detail(self, app_id: str):
        return self.fetch_recommended_app_detail_from_db(app_id)

    def get_type(self) -> str:
        return RecommendAppType.DATABASE

    @classmethod
    def fetch_recommended_apps_from_db(cls, language: str) -> dict:
        """
        Fetch recommended apps from db.
        :param language: language
        :return: {"recommended_apps": [...], "categories": [...]}
        """

        def _listed_apps_for(lang: str) -> list:
            # Only listed apps for the given language qualify.
            return (
                db.session.query(RecommendedApp)
                .filter(RecommendedApp.is_listed == True, RecommendedApp.language == lang)
                .all()
            )

        recommended_apps = _listed_apps_for(language)
        if not recommended_apps:
            # Fall back to the default language when nothing is listed for the requested one.
            recommended_apps = _listed_apps_for(languages[0])

        categories = set()
        apps_payload = []
        for rec in recommended_apps:
            app = rec.app
            # Skip entries whose app is missing, private, or has no site.
            if not app or not app.is_public:
                continue
            site = app.site
            if not site:
                continue
            apps_payload.append(
                {
                    "id": rec.id,
                    "app": rec.app,
                    "app_id": rec.app_id,
                    "description": site.description,
                    "copyright": site.copyright,
                    "privacy_policy": site.privacy_policy,
                    "custom_disclaimer": site.custom_disclaimer,
                    "category": rec.category,
                    "position": rec.position,
                    "is_listed": rec.is_listed,
                }
            )
            categories.add(rec.category)
        return {"recommended_apps": apps_payload, "categories": sorted(categories)}

    @classmethod
    def fetch_recommended_app_detail_from_db(cls, app_id: str) -> Optional[dict]:
        """
        Fetch recommended app detail from db.
        :param app_id: App ID
        :return: detail dict, or None when the app is unlisted or not public
        """
        # is in public recommended list
        listed = (
            db.session.query(RecommendedApp)
            .filter(RecommendedApp.is_listed == True, RecommendedApp.app_id == app_id)
            .first()
        )
        if not listed:
            return None
        # get app detail
        app_model = db.session.query(App).filter(App.id == app_id).first()
        if not app_model or not app_model.is_public:
            return None
        return {
            "id": app_model.id,
            "name": app_model.name,
            "icon": app_model.icon,
            "icon_background": app_model.icon_background,
            "mode": app_model.mode,
            "export_data": AppDslService.export_dsl(app_model=app_model),
        }

View File

@@ -0,0 +1,17 @@
from abc import ABC, abstractmethod
class RecommendAppRetrievalBase(ABC):
    """Interface for recommend app retrieval.

    Implementations back the recommended-apps feature with a concrete source
    (remote endpoint, local database, or bundled JSON).
    """

    @abstractmethod
    def get_recommended_apps_and_categories(self, language: str) -> dict:
        """Return the recommended apps and their categories for *language*."""
        raise NotImplementedError

    @abstractmethod
    def get_recommend_app_detail(self, app_id: str):
        """Return the detail payload for one recommended app."""
        raise NotImplementedError

    @abstractmethod
    def get_type(self) -> str:
        """Return the RecommendAppType value identifying this backend."""
        raise NotImplementedError

View File

@@ -0,0 +1,23 @@
from services.recommend_app.buildin.buildin_retrieval import BuildInRecommendAppRetrieval
from services.recommend_app.database.database_retrieval import DatabaseRecommendAppRetrieval
from services.recommend_app.recommend_app_base import RecommendAppRetrievalBase
from services.recommend_app.recommend_app_type import RecommendAppType
from services.recommend_app.remote.remote_retrieval import RemoteRecommendAppRetrieval
class RecommendAppRetrievalFactory:
    """Maps a retrieval mode string to the matching retrieval implementation."""

    @staticmethod
    def get_recommend_app_factory(mode: str) -> type[RecommendAppRetrievalBase]:
        """Return the retrieval class for *mode*; raise ValueError for unknown modes."""
        retrieval_classes: dict[str, type[RecommendAppRetrievalBase]] = {
            RecommendAppType.REMOTE: RemoteRecommendAppRetrieval,
            RecommendAppType.DATABASE: DatabaseRecommendAppRetrieval,
            RecommendAppType.BUILDIN: BuildInRecommendAppRetrieval,
        }
        retrieval_cls = retrieval_classes.get(mode)
        if retrieval_cls is None:
            raise ValueError(f"invalid fetch recommended apps mode: {mode}")
        return retrieval_cls

    @staticmethod
    def get_buildin_recommend_app_retrieval():
        """Return the builtin retrieval class (used as the fallback source)."""
        return BuildInRecommendAppRetrieval

View File

@@ -0,0 +1,7 @@
from enum import StrEnum
class RecommendAppType(StrEnum):
    """Source used to fetch recommended apps (selected via configuration)."""

    REMOTE = "remote"  # fetch from the hosted/official endpoint
    BUILDIN = "builtin"  # bundled JSON shipped with the codebase
    DATABASE = "db"  # read from the local database

View File

@@ -0,0 +1,71 @@
import logging
from typing import Optional
import requests
from configs import dify_config
from services.recommend_app.buildin.buildin_retrieval import BuildInRecommendAppRetrieval
from services.recommend_app.recommend_app_base import RecommendAppRetrievalBase
from services.recommend_app.recommend_app_type import RecommendAppType
logger = logging.getLogger(__name__)
class RemoteRecommendAppRetrieval(RecommendAppRetrievalBase):
    """
    Retrieval recommended app from dify official, falling back to the builtin data on failure.
    """

    def get_recommend_app_detail(self, app_id: str):
        try:
            return self.fetch_recommended_app_detail_from_dify_official(app_id)
        except Exception as e:
            logger.warning(f"fetch recommended app detail from dify official failed: {e}, switch to built-in.")
            return BuildInRecommendAppRetrieval.fetch_recommended_app_detail_from_builtin(app_id)

    def get_recommended_apps_and_categories(self, language: str) -> dict:
        try:
            return self.fetch_recommended_apps_from_dify_official(language)
        except Exception as e:
            logger.warning(f"fetch recommended apps from dify official failed: {e}, switch to built-in.")
            return BuildInRecommendAppRetrieval.fetch_recommended_apps_from_builtin(language)

    def get_type(self) -> str:
        return RecommendAppType.REMOTE

    @classmethod
    def fetch_recommended_app_detail_from_dify_official(cls, app_id: str) -> Optional[dict]:
        """
        Fetch one recommended app's detail from the official endpoint.
        :param app_id: App ID
        :return: detail dict, or None on a non-200 response
        """
        base_url = dify_config.HOSTED_FETCH_APP_TEMPLATES_REMOTE_DOMAIN
        response = requests.get(f"{base_url}/apps/{app_id}", timeout=(3, 10))
        if response.status_code != 200:
            return None
        payload: dict = response.json()
        return payload

    @classmethod
    def fetch_recommended_apps_from_dify_official(cls, language: str) -> dict:
        """
        Fetch the recommended-apps listing from the official endpoint.
        :param language: language
        :raises ValueError: on a non-200 response
        """
        base_url = dify_config.HOSTED_FETCH_APP_TEMPLATES_REMOTE_DOMAIN
        response = requests.get(f"{base_url}/apps?language={language}", timeout=(3, 10))
        if response.status_code != 200:
            raise ValueError(f"fetch recommended apps failed, status code: {response.status_code}")
        listing: dict = response.json()
        if "categories" in listing:
            listing["categories"] = sorted(listing["categories"])
        return listing

View File

@@ -0,0 +1,37 @@
from typing import Optional
from configs import dify_config
from services.recommend_app.recommend_app_factory import RecommendAppRetrievalFactory
class RecommendedAppService:
    """Facade that picks the configured retrieval backend for recommended apps."""

    @classmethod
    def get_recommended_apps_and_categories(cls, language: str) -> dict:
        """
        Get recommended apps and categories.
        :param language: language
        :return:
        """
        retrieval = RecommendAppRetrievalFactory.get_recommend_app_factory(
            dify_config.HOSTED_FETCH_APP_TEMPLATES_MODE
        )()
        result = retrieval.get_recommended_apps_and_categories(language)
        if not result.get("recommended_apps") and language != "en-US":
            # Nothing listed for this language: fall back to the builtin en-US data.
            builtin = RecommendAppRetrievalFactory.get_buildin_recommend_app_retrieval()
            result = builtin.fetch_recommended_apps_from_builtin("en-US")
        return result

    @classmethod
    def get_recommend_app_detail(cls, app_id: str) -> Optional[dict]:
        """
        Get recommend app detail.
        :param app_id: app id
        :return:
        """
        retrieval = RecommendAppRetrievalFactory.get_recommend_app_factory(
            dify_config.HOSTED_FETCH_APP_TEMPLATES_MODE
        )()
        detail: dict = retrieval.get_recommend_app_detail(app_id)
        return detail

View File

@@ -0,0 +1,83 @@
from typing import Optional, Union
from extensions.ext_database import db
from libs.infinite_scroll_pagination import InfiniteScrollPagination
from models.account import Account
from models.model import App, EndUser
from models.web import SavedMessage
from services.message_service import MessageService
class SavedMessageService:
    """Manage a user's saved (bookmarked) messages for an app."""

    @classmethod
    def pagination_by_last_id(
        cls, app_model: App, user: Optional[Union[Account, EndUser]], last_id: Optional[str], limit: int
    ) -> InfiniteScrollPagination:
        """Page through the user's messages, restricted to the ones they saved."""
        if not user:
            raise ValueError("User is required")
        role = "account" if isinstance(user, Account) else "end_user"
        saved = (
            db.session.query(SavedMessage)
            .filter(
                SavedMessage.app_id == app_model.id,
                SavedMessage.created_by_role == role,
                SavedMessage.created_by == user.id,
            )
            .order_by(SavedMessage.created_at.desc())
            .all()
        )
        return MessageService.pagination_by_last_id(
            app_model=app_model,
            user=user,
            last_id=last_id,
            limit=limit,
            include_ids=[record.message_id for record in saved],
        )

    @classmethod
    def save(cls, app_model: App, user: Optional[Union[Account, EndUser]], message_id: str):
        """Bookmark a message for the user; no-op when already saved or user missing."""
        if not user:
            return
        role = "account" if isinstance(user, Account) else "end_user"
        already_saved = (
            db.session.query(SavedMessage)
            .filter(
                SavedMessage.app_id == app_model.id,
                SavedMessage.message_id == message_id,
                SavedMessage.created_by_role == role,
                SavedMessage.created_by == user.id,
            )
            .first()
        )
        if already_saved:
            return
        # Resolves the message through MessageService before bookmarking it.
        message = MessageService.get_message(app_model=app_model, user=user, message_id=message_id)
        db.session.add(
            SavedMessage(
                app_id=app_model.id,
                message_id=message.id,
                created_by_role=role,
                created_by=user.id,
            )
        )
        db.session.commit()

    @classmethod
    def delete(cls, app_model: App, user: Optional[Union[Account, EndUser]], message_id: str):
        """Remove the user's bookmark for a message; no-op when absent or user missing."""
        if not user:
            return
        role = "account" if isinstance(user, Account) else "end_user"
        record = (
            db.session.query(SavedMessage)
            .filter(
                SavedMessage.app_id == app_model.id,
                SavedMessage.message_id == message_id,
                SavedMessage.created_by_role == role,
                SavedMessage.created_by == user.id,
            )
            .first()
        )
        if not record:
            return
        db.session.delete(record)
        db.session.commit()

View File

@@ -0,0 +1,158 @@
import uuid
from typing import Optional
from flask_login import current_user # type: ignore
from sqlalchemy import func
from werkzeug.exceptions import NotFound
from extensions.ext_database import db
from models.dataset import Dataset
from models.model import App, Tag, TagBinding
class TagService:
    """CRUD operations for tags and tag bindings, scoped to a tenant.

    Tags can be bound to apps (type "app") or knowledge datasets
    (type "knowledge") through TagBinding rows.
    """

    @staticmethod
    def get_tags(tag_type: str, current_tenant_id: str, keyword: Optional[str] = None) -> list:
        """List tags of the given type for a tenant with each tag's binding count.

        :param keyword: optional case-insensitive substring filter on the tag name.
        """
        query = (
            db.session.query(Tag.id, Tag.type, Tag.name, func.count(TagBinding.id).label("binding_count"))
            .outerjoin(TagBinding, Tag.id == TagBinding.tag_id)
            .filter(Tag.type == tag_type, Tag.tenant_id == current_tenant_id)
        )
        if keyword:
            # single condition: no db.and_() wrapper needed
            query = query.filter(Tag.name.ilike(f"%{keyword}%"))
        query = query.group_by(Tag.id)
        results: list = query.order_by(Tag.created_at.desc()).all()
        return results

    @staticmethod
    def get_target_ids_by_tag_ids(tag_type: str, current_tenant_id: str, tag_ids: list) -> list:
        """Return the ids of all targets bound to any of the given tags."""
        tags = (
            db.session.query(Tag)
            .filter(Tag.id.in_(tag_ids), Tag.tenant_id == current_tenant_id, Tag.type == tag_type)
            .all()
        )
        if not tags:
            return []
        # keep only ids of tags that actually exist for this tenant/type
        tag_ids = [tag.id for tag in tags]
        tag_bindings = (
            db.session.query(TagBinding.target_id)
            .filter(TagBinding.tag_id.in_(tag_ids), TagBinding.tenant_id == current_tenant_id)
            .all()
        )
        if not tag_bindings:
            return []
        results = [tag_binding.target_id for tag_binding in tag_bindings]
        return results

    @staticmethod
    def get_tags_by_target_id(tag_type: str, current_tenant_id: str, target_id: str) -> list:
        """Return all tags of the given type bound to a specific target."""
        tags = (
            db.session.query(Tag)
            .join(TagBinding, Tag.id == TagBinding.tag_id)
            .filter(
                TagBinding.target_id == target_id,
                TagBinding.tenant_id == current_tenant_id,
                Tag.tenant_id == current_tenant_id,
                Tag.type == tag_type,
            )
            .all()
        )
        return tags or []

    @staticmethod
    def save_tags(args: dict) -> Tag:
        """Create a new tag owned by the current user's tenant.

        :param args: dict with "name" and "type" keys.
        """
        tag = Tag(
            id=str(uuid.uuid4()),
            name=args["name"],
            type=args["type"],
            created_by=current_user.id,
            tenant_id=current_user.current_tenant_id,
        )
        db.session.add(tag)
        db.session.commit()
        return tag

    @staticmethod
    def update_tags(args: dict, tag_id: str) -> Tag:
        """Rename an existing tag.

        :raises NotFound: if the tag does not exist.
        """
        tag = db.session.query(Tag).filter(Tag.id == tag_id).first()
        if not tag:
            raise NotFound("Tag not found")
        tag.name = args["name"]
        db.session.commit()
        return tag

    @staticmethod
    def get_tag_binding_count(tag_id: str) -> int:
        """Return how many targets are bound to the tag."""
        count = db.session.query(TagBinding).filter(TagBinding.tag_id == tag_id).count()
        return count

    @staticmethod
    def delete_tag(tag_id: str):
        """Delete a tag together with all of its bindings.

        :raises NotFound: if the tag does not exist.
        """
        tag = db.session.query(Tag).filter(Tag.id == tag_id).first()
        if not tag:
            raise NotFound("Tag not found")
        db.session.delete(tag)
        # delete all bindings referencing the tag (loop is a no-op when empty)
        tag_bindings = db.session.query(TagBinding).filter(TagBinding.tag_id == tag_id).all()
        for tag_binding in tag_bindings:
            db.session.delete(tag_binding)
        db.session.commit()

    @staticmethod
    def save_tag_binding(args):
        """Bind each tag in args["tag_ids"] to args["target_id"], skipping
        bindings that already exist.

        :raises NotFound: if the target does not exist or the type is invalid.
        """
        # check if target exists
        TagService.check_target_exists(args["type"], args["target_id"])
        # save tag binding
        for tag_id in args["tag_ids"]:
            existing = (
                db.session.query(TagBinding)
                .filter(TagBinding.tag_id == tag_id, TagBinding.target_id == args["target_id"])
                .first()
            )
            if existing:
                continue
            new_tag_binding = TagBinding(
                tag_id=tag_id,
                target_id=args["target_id"],
                tenant_id=current_user.current_tenant_id,
                created_by=current_user.id,
            )
            db.session.add(new_tag_binding)
        db.session.commit()

    @staticmethod
    def delete_tag_binding(args):
        """Remove the binding between args["tag_id"] and args["target_id"], if any.

        :raises NotFound: if the target does not exist or the type is invalid.
        """
        # check if target exists
        TagService.check_target_exists(args["type"], args["target_id"])
        # delete tag binding
        tag_binding = (
            db.session.query(TagBinding)
            .filter(TagBinding.target_id == args["target_id"], TagBinding.tag_id == (args["tag_id"]))
            .first()
        )
        if tag_binding:
            db.session.delete(tag_binding)
            db.session.commit()

    @staticmethod
    def check_target_exists(type: str, target_id: str):
        """Verify the binding target exists for the current tenant.

        :param type: "knowledge" (Dataset) or "app" (App).
        :raises NotFound: if the target is missing or the type is unknown.
        """
        if type == "knowledge":
            dataset = (
                db.session.query(Dataset)
                .filter(Dataset.tenant_id == current_user.current_tenant_id, Dataset.id == target_id)
                .first()
            )
            if not dataset:
                raise NotFound("Dataset not found")
        elif type == "app":
            app = (
                db.session.query(App)
                .filter(App.tenant_id == current_user.current_tenant_id, App.id == target_id)
                .first()
            )
            if not app:
                raise NotFound("App not found")
        else:
            raise NotFound("Invalid binding type")

View File

@@ -0,0 +1,483 @@
import json
import logging
from collections.abc import Mapping
from typing import Any, cast
from httpx import get
from core.entities.provider_entities import ProviderConfig
from core.model_runtime.utils.encoders import jsonable_encoder
from core.tools.__base.tool_runtime import ToolRuntime
from core.tools.custom_tool.provider import ApiToolProviderController
from core.tools.entities.api_entities import ToolApiEntity, ToolProviderApiEntity
from core.tools.entities.common_entities import I18nObject
from core.tools.entities.tool_bundle import ApiToolBundle
from core.tools.entities.tool_entities import (
ApiProviderAuthType,
ApiProviderSchemaType,
)
from core.tools.tool_label_manager import ToolLabelManager
from core.tools.tool_manager import ToolManager
from core.tools.utils.configuration import ProviderConfigEncrypter
from core.tools.utils.parser import ApiBasedToolSchemaParser
from extensions.ext_database import db
from models.tools import ApiToolProvider
from services.tools.tools_transform_service import ToolTransformService
logger = logging.getLogger(__name__)
class ApiToolManageService:
    """Lifecycle management for custom (OpenAPI/Swagger-based) tool providers:
    schema parsing, CRUD on ApiToolProvider rows, credential encryption,
    and live testing of individual tools."""

    @staticmethod
    def parser_api_schema(schema: str) -> Mapping[str, Any]:
        """
        parse api schema to tool bundle

        Returns the parsed parameter schema, the static credentials schema
        offered to the UI (auth type / api key header / api key value),
        and any parser warnings.

        :raises ValueError: if the schema cannot be parsed.
        """
        try:
            warnings: dict[str, str] = {}
            try:
                tool_bundles, schema_type = ApiBasedToolSchemaParser.auto_parse_to_tool_bundle(schema, warning=warnings)
            except Exception as e:
                raise ValueError(f"invalid schema: {str(e)}")
            # Static credentials schema: auth is either "none" or an api key header.
            credentials_schema = [
                ProviderConfig(
                    name="auth_type",
                    type=ProviderConfig.Type.SELECT,
                    required=True,
                    default="none",
                    options=[
                        ProviderConfig.Option(value="none", label=I18nObject(en_US="None", zh_Hans="")),
                        ProviderConfig.Option(value="api_key", label=I18nObject(en_US="Api Key", zh_Hans="Api Key")),
                    ],
                    placeholder=I18nObject(en_US="Select auth type", zh_Hans="选择认证方式"),
                ),
                ProviderConfig(
                    name="api_key_header",
                    type=ProviderConfig.Type.TEXT_INPUT,
                    required=False,
                    placeholder=I18nObject(en_US="Enter api key header", zh_Hans="输入 api key headerX-API-KEY"),
                    default="api_key",
                    help=I18nObject(en_US="HTTP header name for api key", zh_Hans="HTTP 头部字段名,用于传递 api key"),
                ),
                ProviderConfig(
                    name="api_key_value",
                    type=ProviderConfig.Type.TEXT_INPUT,
                    required=False,
                    placeholder=I18nObject(en_US="Enter api key", zh_Hans="输入 api key"),
                    default="",
                ),
            ]
            return cast(
                Mapping,
                jsonable_encoder(
                    {
                        "schema_type": schema_type,
                        "parameters_schema": tool_bundles,
                        "credentials_schema": credentials_schema,
                        "warning": warnings,
                    }
                ),
            )
        except Exception as e:
            # outer guard: any failure (including ProviderConfig construction)
            # is surfaced uniformly as an invalid-schema error
            raise ValueError(f"invalid schema: {str(e)}")

    @staticmethod
    def convert_schema_to_tool_bundles(schema: str, extra_info: dict | None = None) -> tuple[list[ApiToolBundle], str]:
        """
        convert schema to tool bundles

        :param extra_info: optional dict populated in place by the parser
            (e.g. a "description" key).
        :return: the list of tool bundles and the detected schema type
        :raises ValueError: if the schema cannot be parsed.
        """
        try:
            return ApiBasedToolSchemaParser.auto_parse_to_tool_bundle(schema, extra_info=extra_info)
        except Exception as e:
            raise ValueError(f"invalid schema: {str(e)}")

    @staticmethod
    def create_api_tool_provider(
        user_id: str,
        tenant_id: str,
        provider_name: str,
        icon: dict,
        credentials: dict,
        schema_type: str,
        schema: str,
        privacy_policy: str,
        custom_disclaimer: str,
        labels: list[str],
    ):
        """
        create api tool provider

        Parses the schema, encrypts the credentials, and persists a new
        ApiToolProvider for the tenant.

        :raises ValueError: on invalid schema type, duplicate provider name,
            more than 100 apis, or missing auth configuration.
        """
        if schema_type not in [member.value for member in ApiProviderSchemaType]:
            # NOTE(review): message interpolates the whole schema, not schema_type — confirm intent
            raise ValueError(f"invalid schema type {schema}")
        provider_name = provider_name.strip()
        # check if the provider exists
        provider = (
            db.session.query(ApiToolProvider)
            .filter(
                ApiToolProvider.tenant_id == tenant_id,
                ApiToolProvider.name == provider_name,
            )
            .first()
        )
        if provider is not None:
            raise ValueError(f"provider {provider_name} already exists")
        # parse openapi to tool bundle
        extra_info: dict[str, str] = {}
        # extra info like description will be set here
        tool_bundles, schema_type = ApiToolManageService.convert_schema_to_tool_bundles(schema, extra_info)
        if len(tool_bundles) > 100:
            raise ValueError("the number of apis should be less than 100")
        # create db provider
        db_provider = ApiToolProvider(
            tenant_id=tenant_id,
            user_id=user_id,
            name=provider_name,
            icon=json.dumps(icon),
            schema=schema,
            description=extra_info.get("description", ""),
            schema_type_str=schema_type,
            tools_str=json.dumps(jsonable_encoder(tool_bundles)),
            credentials_str={},  # placeholder; replaced with encrypted JSON below before commit
            privacy_policy=privacy_policy,
            custom_disclaimer=custom_disclaimer,
        )
        if "auth_type" not in credentials:
            raise ValueError("auth_type is required")
        # get auth type, none or api key
        auth_type = ApiProviderAuthType.value_of(credentials["auth_type"])
        # create provider entity
        provider_controller = ApiToolProviderController.from_db(db_provider, auth_type)
        # load tools into provider entity
        provider_controller.load_bundled_tools(tool_bundles)
        # encrypt credentials
        tool_configuration = ProviderConfigEncrypter(
            tenant_id=tenant_id,
            config=list(provider_controller.get_credentials_schema()),
            provider_type=provider_controller.provider_type.value,
            provider_identity=provider_controller.entity.identity.name,
        )
        encrypted_credentials = tool_configuration.encrypt(credentials)
        db_provider.credentials_str = json.dumps(encrypted_credentials)
        db.session.add(db_provider)
        db.session.commit()
        # update labels
        ToolLabelManager.update_tool_labels(provider_controller, labels)
        return {"result": "success"}

    @staticmethod
    def get_api_tool_provider_remote_schema(user_id: str, tenant_id: str, url: str):
        """
        get api tool provider remote schema

        Fetches the schema text from a remote url and validates that it
        parses, rejecting anything that is not a usable schema.

        :raises ValueError: on non-200 response or unparseable schema.
        """
        # browser-like headers so schema hosts do not reject the request
        headers = {
            "User-Agent": "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36 (KHTML, like Gecko)"
            " Chrome/120.0.0.0 Safari/537.36 Edg/120.0.0.0",
            "Accept": "*/*",
        }
        try:
            response = get(url, headers=headers, timeout=10)
            if response.status_code != 200:
                raise ValueError(f"Got status code {response.status_code}")
            schema = response.text
            # try to parse schema, avoid SSRF attack
            ApiToolManageService.parser_api_schema(schema)
        except Exception:
            logger.exception("parse api schema error")
            raise ValueError("invalid schema, please check the url you provided")
        return {"schema": schema}

    @staticmethod
    def list_api_tool_provider_tools(user_id: str, tenant_id: str, provider_name: str) -> list[ToolApiEntity]:
        """
        list api tool provider tools

        :raises ValueError: if the provider has not been added for this tenant.
        """
        provider: ApiToolProvider | None = (
            db.session.query(ApiToolProvider)
            .filter(
                ApiToolProvider.tenant_id == tenant_id,
                ApiToolProvider.name == provider_name,
            )
            .first()
        )
        if provider is None:
            raise ValueError(f"you have not added provider {provider_name}")
        controller = ToolTransformService.api_provider_to_controller(db_provider=provider)
        labels = ToolLabelManager.get_tool_labels(controller)
        return [
            ToolTransformService.convert_tool_entity_to_api_entity(
                tool_bundle,
                tenant_id=tenant_id,
                labels=labels,
            )
            for tool_bundle in provider.tools
        ]

    @staticmethod
    def update_api_tool_provider(
        user_id: str,
        tenant_id: str,
        provider_name: str,
        original_provider: str,
        icon: dict,
        credentials: dict,
        schema_type: str,
        schema: str,
        privacy_policy: str,
        custom_disclaimer: str,
        labels: list[str],
    ):
        """
        update api tool provider

        Looks the provider up by its original name (supports renames),
        re-parses the schema, and re-encrypts credentials, preserving any
        masked values the client sent back unchanged.

        :raises ValueError: on invalid schema type, unknown provider,
            or missing auth configuration.
        """
        if schema_type not in [member.value for member in ApiProviderSchemaType]:
            # NOTE(review): message interpolates the whole schema, not schema_type — confirm intent
            raise ValueError(f"invalid schema type {schema}")
        provider_name = provider_name.strip()
        # check if the provider exists
        provider = (
            db.session.query(ApiToolProvider)
            .filter(
                ApiToolProvider.tenant_id == tenant_id,
                ApiToolProvider.name == original_provider,
            )
            .first()
        )
        if provider is None:
            raise ValueError(f"api provider {provider_name} does not exists")
        # parse openapi to tool bundle
        extra_info: dict[str, str] = {}
        # extra info like description will be set here
        tool_bundles, schema_type = ApiToolManageService.convert_schema_to_tool_bundles(schema, extra_info)
        # update db provider
        provider.name = provider_name
        provider.icon = json.dumps(icon)
        provider.schema = schema
        provider.description = extra_info.get("description", "")
        # NOTE(review): schema_type_str is forced to OPENAPI here, discarding the
        # parsed schema_type (unlike create_api_tool_provider) — confirm intent
        provider.schema_type_str = ApiProviderSchemaType.OPENAPI.value
        provider.tools_str = json.dumps(jsonable_encoder(tool_bundles))
        provider.privacy_policy = privacy_policy
        provider.custom_disclaimer = custom_disclaimer
        if "auth_type" not in credentials:
            raise ValueError("auth_type is required")
        # get auth type, none or api key
        auth_type = ApiProviderAuthType.value_of(credentials["auth_type"])
        # create provider entity
        provider_controller = ApiToolProviderController.from_db(provider, auth_type)
        # load tools into provider entity
        provider_controller.load_bundled_tools(tool_bundles)
        # get original credentials if exists
        tool_configuration = ProviderConfigEncrypter(
            tenant_id=tenant_id,
            config=list(provider_controller.get_credentials_schema()),
            provider_type=provider_controller.provider_type.value,
            provider_identity=provider_controller.entity.identity.name,
        )
        original_credentials = tool_configuration.decrypt(provider.credentials)
        masked_credentials = tool_configuration.mask_tool_credentials(original_credentials)
        # check if the credential has changed, save the original credential
        for name, value in credentials.items():
            if name in masked_credentials and value == masked_credentials[name]:
                credentials[name] = original_credentials[name]
        credentials = tool_configuration.encrypt(credentials)
        provider.credentials_str = json.dumps(credentials)
        db.session.add(provider)
        db.session.commit()
        # delete cache
        tool_configuration.delete_tool_credentials_cache()
        # update labels
        ToolLabelManager.update_tool_labels(provider_controller, labels)
        return {"result": "success"}

    @staticmethod
    def delete_api_tool_provider(user_id: str, tenant_id: str, provider_name: str):
        """
        delete tool provider

        :raises ValueError: if the provider has not been added for this tenant.
        """
        provider = (
            db.session.query(ApiToolProvider)
            .filter(
                ApiToolProvider.tenant_id == tenant_id,
                ApiToolProvider.name == provider_name,
            )
            .first()
        )
        if provider is None:
            raise ValueError(f"you have not added provider {provider_name}")
        db.session.delete(provider)
        db.session.commit()
        return {"result": "success"}

    @staticmethod
    def get_api_tool_provider(user_id: str, tenant_id: str, provider: str):
        """
        get api tool provider
        """
        return ToolManager.user_get_api_provider(provider=provider, tenant_id=tenant_id)

    @staticmethod
    def test_api_tool_preview(
        tenant_id: str,
        provider_name: str,
        tool_name: str,
        credentials: dict,
        parameters: dict,
        schema_type: str,
        schema: str,
    ):
        """
        test api tool before adding api tool provider

        Invokes one tool from the (possibly not-yet-saved) schema against the
        given credentials/parameters and reports either the result or the
        validation error.

        :return: {"result": ...} on success, {"error": ...} on failure.
        :raises ValueError: on invalid schema type, unparseable schema,
            unknown tool name, or missing auth configuration.
        """
        if schema_type not in [member.value for member in ApiProviderSchemaType]:
            raise ValueError(f"invalid schema type {schema_type}")
        try:
            tool_bundles, _ = ApiBasedToolSchemaParser.auto_parse_to_tool_bundle(schema)
        except Exception:
            raise ValueError("invalid schema")
        # get tool bundle
        tool_bundle = next(filter(lambda tb: tb.operation_id == tool_name, tool_bundles), None)
        if tool_bundle is None:
            raise ValueError(f"invalid tool name {tool_name}")
        db_provider = (
            db.session.query(ApiToolProvider)
            .filter(
                ApiToolProvider.tenant_id == tenant_id,
                ApiToolProvider.name == provider_name,
            )
            .first()
        )
        if not db_provider:
            # create a fake db provider (never persisted) so the controller can be built
            db_provider = ApiToolProvider(
                tenant_id="",
                user_id="",
                name="",
                icon="",
                schema=schema,
                description="",
                schema_type_str=ApiProviderSchemaType.OPENAPI.value,
                tools_str=json.dumps(jsonable_encoder(tool_bundles)),
                credentials_str=json.dumps(credentials),
            )
        if "auth_type" not in credentials:
            raise ValueError("auth_type is required")
        # get auth type, none or api key
        auth_type = ApiProviderAuthType.value_of(credentials["auth_type"])
        # create provider entity
        provider_controller = ApiToolProviderController.from_db(db_provider, auth_type)
        # load tools into provider entity
        provider_controller.load_bundled_tools(tool_bundles)
        # decrypt credentials
        if db_provider.id:
            # provider exists in DB: replace masked values the client echoed
            # back with the stored originals
            tool_configuration = ProviderConfigEncrypter(
                tenant_id=tenant_id,
                config=list(provider_controller.get_credentials_schema()),
                provider_type=provider_controller.provider_type.value,
                provider_identity=provider_controller.entity.identity.name,
            )
            decrypted_credentials = tool_configuration.decrypt(credentials)
            # check if the credential has changed, save the original credential
            masked_credentials = tool_configuration.mask_tool_credentials(decrypted_credentials)
            for name, value in credentials.items():
                if name in masked_credentials and value == masked_credentials[name]:
                    credentials[name] = decrypted_credentials[name]
        try:
            provider_controller.validate_credentials_format(credentials)
            # get tool
            tool = provider_controller.get_tool(tool_name)
            tool = tool.fork_tool_runtime(
                runtime=ToolRuntime(
                    credentials=credentials,
                    tenant_id=tenant_id,
                )
            )
            result = tool.validate_credentials(credentials, parameters)
        except Exception as e:
            return {"error": str(e)}
        return {"result": result or "empty response"}

    @staticmethod
    def list_api_tools(user_id: str, tenant_id: str) -> list[ToolProviderApiEntity]:
        """
        list api tools

        Returns every api tool provider of the tenant, with decrypted
        credentials, labels, icon url, and its tools attached.
        """
        # get all api providers
        db_providers: list[ApiToolProvider] = (
            db.session.query(ApiToolProvider).filter(ApiToolProvider.tenant_id == tenant_id).all() or []
        )
        result: list[ToolProviderApiEntity] = []
        for provider in db_providers:
            # convert provider controller to user provider
            provider_controller = ToolTransformService.api_provider_to_controller(db_provider=provider)
            labels = ToolLabelManager.get_tool_labels(provider_controller)
            user_provider = ToolTransformService.api_provider_to_user_provider(
                provider_controller, db_provider=provider, decrypt_credentials=True
            )
            user_provider.labels = labels
            # add icon
            ToolTransformService.repack_provider(tenant_id=tenant_id, provider=user_provider)
            tools = provider_controller.get_tools(tenant_id=tenant_id)
            for tool in tools or []:
                user_provider.tools.append(
                    ToolTransformService.convert_tool_entity_to_api_entity(
                        tenant_id=tenant_id, tool=tool, credentials=user_provider.original_credentials, labels=labels
                    )
                )
            result.append(user_provider)
        return result

View File

@@ -0,0 +1,330 @@
import json
import logging
from pathlib import Path
from sqlalchemy.orm import Session
from configs import dify_config
from core.helper.position_helper import is_filtered
from core.model_runtime.utils.encoders import jsonable_encoder
from core.plugin.entities.plugin import GenericProviderID, ToolProviderID
from core.plugin.manager.exc import PluginDaemonClientSideError
from core.tools.builtin_tool.providers._positions import BuiltinToolProviderSort
from core.tools.entities.api_entities import ToolApiEntity, ToolProviderApiEntity
from core.tools.errors import ToolNotFoundError, ToolProviderCredentialValidationError, ToolProviderNotFoundError
from core.tools.tool_label_manager import ToolLabelManager
from core.tools.tool_manager import ToolManager
from core.tools.utils.configuration import ProviderConfigEncrypter
from extensions.ext_database import db
from models.tools import BuiltinToolProvider
from services.tools.tools_transform_service import ToolTransformService
logger = logging.getLogger(__name__)
class BuiltinToolManageService:
    """Management of builtin (and plugin-backed) tool providers: listing
    tools, encrypted credential storage, and provider icons."""

    @staticmethod
    def list_builtin_tool_provider_tools(tenant_id: str, provider: str) -> list[ToolApiEntity]:
        """
        list builtin tool provider tools

        :param tenant_id: the id of the tenant
        :param provider: the name of the provider
        :return: the list of tools
        """
        provider_controller = ToolManager.get_builtin_provider(provider, tenant_id)
        tools = provider_controller.get_tools()
        tool_provider_configurations = ProviderConfigEncrypter(
            tenant_id=tenant_id,
            config=[x.to_basic_provider_config() for x in provider_controller.get_credentials_schema()],
            provider_type=provider_controller.provider_type.value,
            provider_identity=provider_controller.entity.identity.name,
        )
        # check if user has added the provider
        builtin_provider = BuiltinToolManageService._fetch_builtin_provider(provider, tenant_id)
        credentials = {}
        if builtin_provider is not None:
            # get credentials
            credentials = builtin_provider.credentials
            credentials = tool_provider_configurations.decrypt(credentials)
        result: list[ToolApiEntity] = []
        for tool in tools or []:
            result.append(
                ToolTransformService.convert_tool_entity_to_api_entity(
                    tool=tool,
                    credentials=credentials,
                    tenant_id=tenant_id,
                    labels=ToolLabelManager.get_tool_labels(provider_controller),
                )
            )
        return result

    @staticmethod
    def get_builtin_tool_provider_info(user_id: str, tenant_id: str, provider: str):
        """
        get builtin tool provider info

        Returns the user-facing provider entity with credentials decrypted
        internally but original_credentials stripped from the response.
        """
        provider_controller = ToolManager.get_builtin_provider(provider, tenant_id)
        tool_provider_configurations = ProviderConfigEncrypter(
            tenant_id=tenant_id,
            config=[x.to_basic_provider_config() for x in provider_controller.get_credentials_schema()],
            provider_type=provider_controller.provider_type.value,
            provider_identity=provider_controller.entity.identity.name,
        )
        # check if user has added the provider
        builtin_provider = BuiltinToolManageService._fetch_builtin_provider(provider, tenant_id)
        credentials = {}
        if builtin_provider is not None:
            # get credentials
            credentials = builtin_provider.credentials
            credentials = tool_provider_configurations.decrypt(credentials)
        entity = ToolTransformService.builtin_provider_to_user_provider(
            provider_controller=provider_controller,
            db_provider=builtin_provider,
            decrypt_credentials=True,
        )
        # never expose decrypted credentials to the caller
        entity.original_credentials = {}
        return entity

    @staticmethod
    def list_builtin_provider_credentials_schema(provider_name: str, tenant_id: str):
        """
        list builtin provider credentials schema

        :param provider_name: the name of the provider
        :param tenant_id: the id of the tenant
        :return: the list of tool providers
        """
        provider = ToolManager.get_builtin_provider(provider_name, tenant_id)
        return jsonable_encoder(provider.get_credentials_schema())

    @staticmethod
    def update_builtin_tool_provider(
        session: Session, user_id: str, tenant_id: str, provider_name: str, credentials: dict
    ):
        """
        update builtin tool provider

        Validates and re-encrypts the credentials, creating the
        BuiltinToolProvider row if it does not yet exist.

        NOTE(review): the `session` parameter is unused; all writes go
        through db.session — confirm intent.

        :raises ValueError: if the provider needs no credentials or
            validation fails.
        """
        # get if the provider exists
        provider = BuiltinToolManageService._fetch_builtin_provider(provider_name, tenant_id)
        try:
            # get provider
            provider_controller = ToolManager.get_builtin_provider(provider_name, tenant_id)
            if not provider_controller.need_credentials:
                raise ValueError(f"provider {provider_name} does not need credentials")
            tool_configuration = ProviderConfigEncrypter(
                tenant_id=tenant_id,
                config=[x.to_basic_provider_config() for x in provider_controller.get_credentials_schema()],
                provider_type=provider_controller.provider_type.value,
                provider_identity=provider_controller.entity.identity.name,
            )
            # get original credentials if exists
            if provider is not None:
                original_credentials = tool_configuration.decrypt(provider.credentials)
                masked_credentials = tool_configuration.mask_tool_credentials(original_credentials)
                # check if the credential has changed, save the original credential
                for name, value in credentials.items():
                    if name in masked_credentials and value == masked_credentials[name]:
                        credentials[name] = original_credentials[name]
            # validate credentials
            provider_controller.validate_credentials(user_id, credentials)
            # encrypt credentials
            credentials = tool_configuration.encrypt(credentials)
        except (
            PluginDaemonClientSideError,
            ToolProviderNotFoundError,
            ToolNotFoundError,
            ToolProviderCredentialValidationError,
        ) as e:
            raise ValueError(str(e))
        if provider is None:
            # create provider
            provider = BuiltinToolProvider(
                tenant_id=tenant_id,
                user_id=user_id,
                provider=provider_name,
                encrypted_credentials=json.dumps(credentials),
            )
            db.session.add(provider)
        else:
            provider.encrypted_credentials = json.dumps(credentials)
            # delete cache
            tool_configuration.delete_tool_credentials_cache()
        db.session.commit()
        return {"result": "success"}

    @staticmethod
    def get_builtin_tool_provider_credentials(tenant_id: str, provider_name: str):
        """
        get builtin tool provider credentials

        :return: masked (never plaintext) credentials, or {} if the provider
            has not been added.
        """
        provider_obj = BuiltinToolManageService._fetch_builtin_provider(provider_name, tenant_id)
        if provider_obj is None:
            return {}
        provider_controller = ToolManager.get_builtin_provider(provider_obj.provider, tenant_id)
        tool_configuration = ProviderConfigEncrypter(
            tenant_id=tenant_id,
            config=[x.to_basic_provider_config() for x in provider_controller.get_credentials_schema()],
            provider_type=provider_controller.provider_type.value,
            provider_identity=provider_controller.entity.identity.name,
        )
        credentials = tool_configuration.decrypt(provider_obj.credentials)
        credentials = tool_configuration.mask_tool_credentials(credentials)
        return credentials

    @staticmethod
    def delete_builtin_tool_provider(user_id: str, tenant_id: str, provider_name: str):
        """
        delete tool provider

        :raises ValueError: if the provider has not been added for this tenant.
        """
        provider_obj = BuiltinToolManageService._fetch_builtin_provider(provider_name, tenant_id)
        if provider_obj is None:
            raise ValueError(f"you have not added provider {provider_name}")
        db.session.delete(provider_obj)
        db.session.commit()
        # delete cache
        provider_controller = ToolManager.get_builtin_provider(provider_name, tenant_id)
        tool_configuration = ProviderConfigEncrypter(
            tenant_id=tenant_id,
            config=[x.to_basic_provider_config() for x in provider_controller.get_credentials_schema()],
            provider_type=provider_controller.provider_type.value,
            provider_identity=provider_controller.entity.identity.name,
        )
        tool_configuration.delete_tool_credentials_cache()
        return {"result": "success"}

    @staticmethod
    def get_builtin_tool_provider_icon(provider: str):
        """
        get tool provider icon and it's mimetype
        """
        icon_path, mime_type = ToolManager.get_hardcoded_provider_icon(provider)
        icon_bytes = Path(icon_path).read_bytes()
        return icon_bytes, mime_type

    @staticmethod
    def list_builtin_tools(user_id: str, tenant_id: str) -> list[ToolProviderApiEntity]:
        """
        list builtin tools

        Returns every builtin provider (after include/exclude position
        filtering), merged with the tenant's stored credentials, sorted by
        the builtin provider position file.
        """
        # get all builtin providers
        provider_controllers = ToolManager.list_builtin_providers(tenant_id)
        with db.session.no_autoflush:
            # get all user added providers
            db_providers: list[BuiltinToolProvider] = (
                db.session.query(BuiltinToolProvider).filter(BuiltinToolProvider.tenant_id == tenant_id).all() or []
            )
            # rewrite db_providers: normalize stored provider names to full ids
            for db_provider in db_providers:
                db_provider.provider = str(ToolProviderID(db_provider.provider))

            # find provider
            def find_provider(provider):
                return next(filter(lambda db_provider: db_provider.provider == provider, db_providers), None)

            result: list[ToolProviderApiEntity] = []
            for provider_controller in provider_controllers:
                try:
                    # handle include, exclude
                    if is_filtered(
                        include_set=dify_config.POSITION_TOOL_INCLUDES_SET,  # type: ignore
                        exclude_set=dify_config.POSITION_TOOL_EXCLUDES_SET,  # type: ignore
                        data=provider_controller,
                        name_func=lambda x: x.identity.name,
                    ):
                        continue
                    # convert provider controller to user provider
                    user_builtin_provider = ToolTransformService.builtin_provider_to_user_provider(
                        provider_controller=provider_controller,
                        db_provider=find_provider(provider_controller.entity.identity.name),
                        decrypt_credentials=True,
                    )
                    # add icon
                    ToolTransformService.repack_provider(tenant_id=tenant_id, provider=user_builtin_provider)
                    tools = provider_controller.get_tools()
                    for tool in tools or []:
                        user_builtin_provider.tools.append(
                            ToolTransformService.convert_tool_entity_to_api_entity(
                                tenant_id=tenant_id,
                                tool=tool,
                                credentials=user_builtin_provider.original_credentials,
                                labels=ToolLabelManager.get_tool_labels(provider_controller),
                            )
                        )
                    result.append(user_builtin_provider)
                except Exception as e:
                    # NOTE(review): bare re-raise is a no-op wrapper — presumably a
                    # placeholder for per-provider error handling; confirm intent
                    raise e
        return BuiltinToolProviderSort.sort(result)

    @staticmethod
    def _fetch_builtin_provider(provider_name: str, tenant_id: str) -> BuiltinToolProvider | None:
        """Look up the tenant's BuiltinToolProvider row for a provider name.

        Handles both full plugin ids (org/name) and legacy bare names,
        normalizing the returned row's provider field to the full id form.
        Returns None when the provider has not been added.
        """
        try:
            full_provider_name = provider_name
            provider_id_entity = GenericProviderID(provider_name)
            provider_name = provider_id_entity.provider_name
            if provider_id_entity.organization != "langgenius":
                # third-party plugin: only the full org/name id is stored
                provider_obj = (
                    db.session.query(BuiltinToolProvider)
                    .filter(
                        BuiltinToolProvider.tenant_id == tenant_id,
                        BuiltinToolProvider.provider == full_provider_name,
                    )
                    .first()
                )
            else:
                # first-party provider: may be stored under the bare or full name
                provider_obj = (
                    db.session.query(BuiltinToolProvider)
                    .filter(
                        BuiltinToolProvider.tenant_id == tenant_id,
                        (BuiltinToolProvider.provider == provider_name)
                        | (BuiltinToolProvider.provider == full_provider_name),
                    )
                    .first()
                )
            if provider_obj is None:
                return None
            provider_obj.provider = GenericProviderID(provider_obj.provider).to_string()
            return provider_obj
        except Exception:
            # it's an old provider without organization
            return (
                db.session.query(BuiltinToolProvider)
                .filter(
                    BuiltinToolProvider.tenant_id == tenant_id,
                    (BuiltinToolProvider.provider == provider_name),
                )
                .first()
            )

View File

@@ -0,0 +1,8 @@
from core.tools.entities.tool_entities import ToolLabel
from core.tools.entities.values import default_tool_labels
class ToolLabelsService:
    """Read-only access to the predefined set of tool labels."""

    @classmethod
    def list_tool_labels(cls) -> list[ToolLabel]:
        """Return all predefined tool labels.

        A shallow copy is returned so callers cannot mutate the shared
        module-level `default_tool_labels` list.
        """
        return list(default_tool_labels)

View File

@@ -0,0 +1,26 @@
import logging
from core.tools.entities.api_entities import ToolProviderTypeApiLiteral
from core.tools.tool_manager import ToolManager
from services.tools.tools_transform_service import ToolTransformService
logger = logging.getLogger(__name__)
class ToolCommonService:
    @staticmethod
    def list_tool_providers(user_id: str, tenant_id: str, typ: ToolProviderTypeApiLiteral = None):
        """
        list tool providers

        :param typ: optional provider-type filter; None means all types
        :return: the list of tool providers, serialized to dicts
        """
        providers = ToolManager.list_providers_from_api(user_id, tenant_id, typ)
        # attach icon urls in place before serializing
        for entry in providers:
            ToolTransformService.repack_provider(tenant_id=tenant_id, provider=entry)
        return [entry.to_dict() for entry in providers]

View File

@@ -0,0 +1,305 @@
import json
import logging
from typing import Optional, Union, cast
from yarl import URL
from configs import dify_config
from core.tools.__base.tool import Tool
from core.tools.__base.tool_runtime import ToolRuntime
from core.tools.builtin_tool.provider import BuiltinToolProviderController
from core.tools.custom_tool.provider import ApiToolProviderController
from core.tools.entities.api_entities import ToolApiEntity, ToolProviderApiEntity
from core.tools.entities.common_entities import I18nObject
from core.tools.entities.tool_bundle import ApiToolBundle
from core.tools.entities.tool_entities import (
ApiProviderAuthType,
ToolParameter,
ToolProviderType,
)
from core.tools.plugin_tool.provider import PluginToolProviderController
from core.tools.utils.configuration import ProviderConfigEncrypter
from core.tools.workflow_as_tool.provider import WorkflowToolProviderController
from core.tools.workflow_as_tool.tool import WorkflowTool
from models.tools import ApiToolProvider, BuiltinToolProvider, WorkflowToolProvider
logger = logging.getLogger(__name__)
class ToolTransformService:
    """Transform tool/provider domain objects into API-facing entities.

    Responsibilities: resolving icon URLs for the console, converting
    builtin/API/workflow providers into ``ToolProviderApiEntity``, and
    decrypting + masking stored provider credentials.
    """

    @classmethod
    def get_plugin_icon_url(cls, tenant_id: str, filename: str) -> str:
        """Build the console URL serving a plugin-provided icon file."""
        url_prefix = (
            URL(dify_config.CONSOLE_API_URL or "/") / "console" / "api" / "workspaces" / "current" / "plugin" / "icon"
        )
        # yarl's `%` operator appends the mapping as query-string parameters
        return str(url_prefix % {"tenant_id": tenant_id, "filename": filename})

    @classmethod
    def get_tool_provider_icon_url(cls, provider_type: str, provider_name: str, icon: str | dict) -> Union[str, dict]:
        """
        get tool provider icon url

        Builtin providers get a console endpoint URL; API/workflow providers
        store their icon as a JSON document (emoji background/content).
        """
        url_prefix = (
            URL(dify_config.CONSOLE_API_URL or "/") / "console" / "api" / "workspaces" / "current" / "tool-provider"
        )

        if provider_type == ToolProviderType.BUILT_IN.value:
            return str(url_prefix / "builtin" / provider_name / "icon")
        elif provider_type in {ToolProviderType.API.value, ToolProviderType.WORKFLOW.value}:
            try:
                if isinstance(icon, str):
                    return cast(dict, json.loads(icon))
                return icon
            except Exception:
                # stored icon JSON is invalid — fall back to a default emoji icon
                return {"background": "#252525", "content": "\ud83d\ude01"}

        return ""

    @staticmethod
    def repack_provider(tenant_id: str, provider: Union[dict, ToolProviderApiEntity]):
        """
        repack provider

        Rewrites the provider's ``icon`` field in place into a servable URL
        (or parsed icon dict), for both dict and entity representations.

        :param provider: the provider dict
        """
        if isinstance(provider, dict) and "icon" in provider:
            provider["icon"] = ToolTransformService.get_tool_provider_icon_url(
                provider_type=provider["type"], provider_name=provider["name"], icon=provider["icon"]
            )
        elif isinstance(provider, ToolProviderApiEntity):
            if provider.plugin_id:
                # plugin-backed providers serve their icon through the plugin endpoint
                if isinstance(provider.icon, str):
                    provider.icon = ToolTransformService.get_plugin_icon_url(
                        tenant_id=tenant_id, filename=provider.icon
                    )
            else:
                provider.icon = ToolTransformService.get_tool_provider_icon_url(
                    provider_type=provider.type.value, provider_name=provider.name, icon=provider.icon
                )

    @classmethod
    def builtin_provider_to_user_provider(
        cls,
        provider_controller: BuiltinToolProviderController | PluginToolProviderController,
        db_provider: Optional[BuiltinToolProvider],
        decrypt_credentials: bool = True,
    ) -> ToolProviderApiEntity:
        """
        convert provider controller to user provider

        :param provider_controller: the builtin/plugin provider controller
        :param db_provider: the tenant's stored credentials record, if any
        :param decrypt_credentials: when True, decrypt and mask stored credentials
        """
        result = ToolProviderApiEntity(
            id=provider_controller.entity.identity.name,
            author=provider_controller.entity.identity.author,
            name=provider_controller.entity.identity.name,
            description=provider_controller.entity.identity.description,
            icon=provider_controller.entity.identity.icon,
            label=provider_controller.entity.identity.label,
            type=ToolProviderType.BUILT_IN,
            masked_credentials={},
            is_team_authorization=False,
            plugin_id=None,
            tools=[],
            labels=provider_controller.tool_labels,
        )

        if isinstance(provider_controller, PluginToolProviderController):
            result.plugin_id = provider_controller.plugin_id
            result.plugin_unique_identifier = provider_controller.plugin_unique_identifier

        # get credentials schema
        schema = {x.to_basic_provider_config().name: x for x in provider_controller.get_credentials_schema()}

        # pre-fill every credential field name with an empty masked value.
        # NOTE: guarding on truthiness here would always skip the loop, because
        # masked_credentials starts as an empty (falsy) dict — guard on None instead.
        if result.masked_credentials is not None:
            for name in schema:
                result.masked_credentials[name] = ""

        # check if the provider need credentials
        if not provider_controller.need_credentials:
            result.is_team_authorization = True
            result.allow_delete = False
        elif db_provider:
            result.is_team_authorization = True

            if decrypt_credentials:
                credentials = db_provider.credentials

                # init tool configuration
                tool_configuration = ProviderConfigEncrypter(
                    tenant_id=db_provider.tenant_id,
                    config=[x.to_basic_provider_config() for x in provider_controller.get_credentials_schema()],
                    provider_type=provider_controller.provider_type.value,
                    provider_identity=provider_controller.entity.identity.name,
                )

                # decrypt the credentials and mask the credentials
                decrypted_credentials = tool_configuration.decrypt(data=credentials)
                masked_credentials = tool_configuration.mask_tool_credentials(data=decrypted_credentials)

                result.masked_credentials = masked_credentials
                result.original_credentials = decrypted_credentials

        return result

    @staticmethod
    def api_provider_to_controller(
        db_provider: ApiToolProvider,
    ) -> ApiToolProviderController:
        """
        convert provider controller to user provider
        """
        # package tool provider controller
        controller = ApiToolProviderController.from_db(
            db_provider=db_provider,
            auth_type=ApiProviderAuthType.API_KEY
            if db_provider.credentials["auth_type"] == "api_key"
            else ApiProviderAuthType.NONE,
        )

        return controller

    @staticmethod
    def workflow_provider_to_controller(db_provider: WorkflowToolProvider) -> WorkflowToolProviderController:
        """
        convert provider controller to provider
        """
        return WorkflowToolProviderController.from_db(db_provider)

    @staticmethod
    def workflow_provider_to_user_provider(
        provider_controller: WorkflowToolProviderController, labels: list[str] | None = None
    ):
        """
        convert provider controller to user provider
        """
        return ToolProviderApiEntity(
            id=provider_controller.provider_id,
            author=provider_controller.entity.identity.author,
            name=provider_controller.entity.identity.name,
            description=provider_controller.entity.identity.description,
            icon=provider_controller.entity.identity.icon,
            label=provider_controller.entity.identity.label,
            type=ToolProviderType.WORKFLOW,
            masked_credentials={},
            # workflow tools are always owned by the tenant, so always authorized
            is_team_authorization=True,
            plugin_id=None,
            plugin_unique_identifier=None,
            tools=[],
            labels=labels or [],
        )

    @classmethod
    def api_provider_to_user_provider(
        cls,
        provider_controller: ApiToolProviderController,
        db_provider: ApiToolProvider,
        decrypt_credentials: bool = True,
        labels: list[str] | None = None,
    ) -> ToolProviderApiEntity:
        """
        convert provider controller to user provider

        :raises ValueError: if the provider's owning user record is missing
        """
        if db_provider.user is None:
            raise ValueError(f"user is None for api provider {db_provider.id}")

        # resolving the user name is best-effort: fall back to "Anonymous"
        # (and log) rather than failing the whole conversion
        username = "Anonymous"
        try:
            username = db_provider.user.name
        except Exception:
            logger.exception("failed to get user name for api provider %s", db_provider.id)

        # add provider into providers
        credentials = db_provider.credentials
        result = ToolProviderApiEntity(
            id=db_provider.id,
            author=username,
            name=db_provider.name,
            description=I18nObject(
                en_US=db_provider.description,
                zh_Hans=db_provider.description,
            ),
            icon=db_provider.icon,
            label=I18nObject(
                en_US=db_provider.name,
                zh_Hans=db_provider.name,
            ),
            type=ToolProviderType.API,
            plugin_id=None,
            plugin_unique_identifier=None,
            masked_credentials={},
            is_team_authorization=True,
            tools=[],
            labels=labels or [],
        )

        if decrypt_credentials:
            # init tool configuration
            tool_configuration = ProviderConfigEncrypter(
                tenant_id=db_provider.tenant_id,
                config=[x.to_basic_provider_config() for x in provider_controller.get_credentials_schema()],
                provider_type=provider_controller.provider_type.value,
                provider_identity=provider_controller.entity.identity.name,
            )

            # decrypt the credentials and mask the credentials
            decrypted_credentials = tool_configuration.decrypt(data=credentials)
            masked_credentials = tool_configuration.mask_tool_credentials(data=decrypted_credentials)

            result.masked_credentials = masked_credentials

        return result

    @staticmethod
    def convert_tool_entity_to_api_entity(
        tool: Union[ApiToolBundle, WorkflowTool, Tool],
        tenant_id: str,
        credentials: dict | None = None,
        labels: list[str] | None = None,
    ) -> ToolApiEntity:
        """
        convert tool to user tool

        For ``Tool`` instances the declared parameters are overlaid with the
        runtime parameters; ``ApiToolBundle`` entries are converted directly.
        """
        if isinstance(tool, Tool):
            # fork tool runtime
            tool = tool.fork_tool_runtime(
                runtime=ToolRuntime(
                    credentials=credentials or {},
                    tenant_id=tenant_id,
                )
            )

            # get tool parameters
            parameters = tool.entity.parameters or []
            # get tool runtime parameters
            runtime_parameters = tool.get_runtime_parameters()
            # override parameters: runtime values win on (name, form) match;
            # unmatched FORM-type runtime parameters are appended
            current_parameters = parameters.copy()
            for runtime_parameter in runtime_parameters:
                found = False
                for index, parameter in enumerate(current_parameters):
                    if parameter.name == runtime_parameter.name and parameter.form == runtime_parameter.form:
                        current_parameters[index] = runtime_parameter
                        found = True
                        break

                if not found and runtime_parameter.form == ToolParameter.ToolParameterForm.FORM:
                    current_parameters.append(runtime_parameter)

            return ToolApiEntity(
                author=tool.entity.identity.author,
                name=tool.entity.identity.name,
                label=tool.entity.identity.label,
                description=tool.entity.description.human if tool.entity.description else I18nObject(en_US=""),
                output_schema=tool.entity.output_schema,
                parameters=current_parameters,
                labels=labels or [],
            )
        if isinstance(tool, ApiToolBundle):
            return ToolApiEntity(
                author=tool.author,
                name=tool.operation_id or "",
                label=I18nObject(en_US=tool.operation_id, zh_Hans=tool.operation_id),
                description=I18nObject(en_US=tool.summary or "", zh_Hans=tool.summary or ""),
                parameters=tool.parameters,
                labels=labels or [],
            )

View File

@@ -0,0 +1,339 @@
import json
from collections.abc import Mapping
from datetime import datetime
from typing import Any
from sqlalchemy import or_
from core.model_runtime.utils.encoders import jsonable_encoder
from core.tools.__base.tool_provider import ToolProviderController
from core.tools.entities.api_entities import ToolApiEntity, ToolProviderApiEntity
from core.tools.tool_label_manager import ToolLabelManager
from core.tools.utils.workflow_configuration_sync import WorkflowToolConfigurationUtils
from core.tools.workflow_as_tool.provider import WorkflowToolProviderController
from core.tools.workflow_as_tool.tool import WorkflowTool
from extensions.ext_database import db
from models.model import App
from models.tools import WorkflowToolProvider
from models.workflow import Workflow
from services.tools.tools_transform_service import ToolTransformService
class WorkflowToolManageService:
    """
    Service class for managing workflow tools.

    CRUD operations for workflow-app-backed tool providers, including
    name/app uniqueness checks and label synchronization.
    """

    @staticmethod
    def create_workflow_tool(
        *,
        user_id: str,
        tenant_id: str,
        workflow_app_id: str,
        name: str,
        label: str,
        icon: dict,
        description: str,
        parameters: list[Mapping[str, Any]],
        privacy_policy: str = "",
        labels: list[str] | None = None,
    ) -> dict:
        """
        Create a workflow tool provider from a workflow app.

        :raises ValueError: if the name/app is already used, the app or its
            workflow is missing, or the configuration fails validation
        """
        WorkflowToolConfigurationUtils.check_parameter_configurations(parameters)

        # check if the name is unique
        existing_workflow_tool_provider = (
            db.session.query(WorkflowToolProvider)
            .filter(
                WorkflowToolProvider.tenant_id == tenant_id,
                # name or app_id
                or_(WorkflowToolProvider.name == name, WorkflowToolProvider.app_id == workflow_app_id),
            )
            .first()
        )

        if existing_workflow_tool_provider is not None:
            raise ValueError(f"Tool with name {name} or app_id {workflow_app_id} already exists")

        app: App | None = db.session.query(App).filter(App.id == workflow_app_id, App.tenant_id == tenant_id).first()
        if app is None:
            raise ValueError(f"App {workflow_app_id} not found")

        workflow: Workflow | None = app.workflow
        if workflow is None:
            raise ValueError(f"Workflow not found for app {workflow_app_id}")

        workflow_tool_provider = WorkflowToolProvider(
            tenant_id=tenant_id,
            user_id=user_id,
            app_id=workflow_app_id,
            name=name,
            label=label,
            icon=json.dumps(icon),
            description=description,
            parameter_configuration=json.dumps(parameters),
            privacy_policy=privacy_policy,
            version=workflow.version,
        )

        try:
            # validate the configuration before persisting
            WorkflowToolProviderController.from_db(workflow_tool_provider)
        except Exception as e:
            raise ValueError(str(e)) from e

        db.session.add(workflow_tool_provider)
        db.session.commit()

        if labels is not None:
            ToolLabelManager.update_tool_labels(
                ToolTransformService.workflow_provider_to_controller(workflow_tool_provider), labels
            )
        return {"result": "success"}

    @classmethod
    def update_workflow_tool(
        cls,
        user_id: str,
        tenant_id: str,
        workflow_tool_id: str,
        name: str,
        label: str,
        icon: dict,
        description: str,
        parameters: list[Mapping[str, Any]],
        privacy_policy: str = "",
        labels: list[str] | None = None,
    ) -> dict:
        """
        Update a workflow tool.
        :param user_id: the user id
        :param tenant_id: the tenant id
        :param workflow_tool_id: workflow tool id
        :param name: name
        :param label: label
        :param icon: icon
        :param description: description
        :param parameters: parameters
        :param privacy_policy: privacy policy
        :param labels: labels
        :return: the updated tool
        """
        WorkflowToolConfigurationUtils.check_parameter_configurations(parameters)

        # check if the name is unique (excluding the tool being updated)
        existing_workflow_tool_provider = (
            db.session.query(WorkflowToolProvider)
            .filter(
                WorkflowToolProvider.tenant_id == tenant_id,
                WorkflowToolProvider.name == name,
                WorkflowToolProvider.id != workflow_tool_id,
            )
            .first()
        )

        if existing_workflow_tool_provider is not None:
            raise ValueError(f"Tool with name {name} already exists")

        workflow_tool_provider: WorkflowToolProvider | None = (
            db.session.query(WorkflowToolProvider)
            .filter(WorkflowToolProvider.tenant_id == tenant_id, WorkflowToolProvider.id == workflow_tool_id)
            .first()
        )

        if workflow_tool_provider is None:
            raise ValueError(f"Tool {workflow_tool_id} not found")

        app: App | None = (
            db.session.query(App).filter(App.id == workflow_tool_provider.app_id, App.tenant_id == tenant_id).first()
        )

        if app is None:
            raise ValueError(f"App {workflow_tool_provider.app_id} not found")

        workflow: Workflow | None = app.workflow
        if workflow is None:
            raise ValueError(f"Workflow not found for app {workflow_tool_provider.app_id}")

        workflow_tool_provider.name = name
        workflow_tool_provider.label = label
        workflow_tool_provider.icon = json.dumps(icon)
        workflow_tool_provider.description = description
        workflow_tool_provider.parameter_configuration = json.dumps(parameters)
        workflow_tool_provider.privacy_policy = privacy_policy
        # re-sync the provider to the app's current workflow version
        workflow_tool_provider.version = workflow.version
        workflow_tool_provider.updated_at = datetime.now()

        try:
            # validate the updated configuration before persisting
            WorkflowToolProviderController.from_db(workflow_tool_provider)
        except Exception as e:
            raise ValueError(str(e)) from e

        db.session.add(workflow_tool_provider)
        db.session.commit()

        if labels is not None:
            ToolLabelManager.update_tool_labels(
                ToolTransformService.workflow_provider_to_controller(workflow_tool_provider), labels
            )
        return {"result": "success"}

    @classmethod
    def list_tenant_workflow_tools(cls, user_id: str, tenant_id: str) -> list[ToolProviderApiEntity]:
        """
        List workflow tools.
        :param user_id: the user id
        :param tenant_id: the tenant id
        :return: the list of tools
        """
        db_tools = db.session.query(WorkflowToolProvider).filter(WorkflowToolProvider.tenant_id == tenant_id).all()

        tools: list[WorkflowToolProviderController] = []
        for provider in db_tools:
            try:
                tools.append(ToolTransformService.workflow_provider_to_controller(provider))
            except Exception:
                # skip deleted tools
                pass

        labels = ToolLabelManager.get_tools_labels([t for t in tools if isinstance(t, ToolProviderController)])

        result = []

        for tool in tools:
            user_tool_provider = ToolTransformService.workflow_provider_to_user_provider(
                provider_controller=tool, labels=labels.get(tool.provider_id, [])
            )
            ToolTransformService.repack_provider(tenant_id=tenant_id, provider=user_tool_provider)
            user_tool_provider.tools = [
                ToolTransformService.convert_tool_entity_to_api_entity(
                    tool=tool.get_tools(tenant_id)[0],
                    labels=labels.get(tool.provider_id, []),
                    tenant_id=tenant_id,
                )
            ]
            result.append(user_tool_provider)

        return result

    @classmethod
    def delete_workflow_tool(cls, user_id: str, tenant_id: str, workflow_tool_id: str) -> dict:
        """
        Delete a workflow tool.
        :param user_id: the user id
        :param tenant_id: the tenant id
        :param workflow_tool_id: the workflow tool id
        """
        db.session.query(WorkflowToolProvider).filter(
            WorkflowToolProvider.tenant_id == tenant_id, WorkflowToolProvider.id == workflow_tool_id
        ).delete()

        db.session.commit()

        return {"result": "success"}

    @classmethod
    def get_workflow_tool_by_tool_id(cls, user_id: str, tenant_id: str, workflow_tool_id: str) -> dict:
        """
        Get a workflow tool.
        :param user_id: the user id
        :param tenant_id: the tenant id
        :param workflow_tool_id: the workflow tool id
        :return: the tool
        """
        db_tool: WorkflowToolProvider | None = (
            db.session.query(WorkflowToolProvider)
            .filter(WorkflowToolProvider.tenant_id == tenant_id, WorkflowToolProvider.id == workflow_tool_id)
            .first()
        )
        return cls._get_workflow_tool(tenant_id, db_tool)

    @classmethod
    def get_workflow_tool_by_app_id(cls, user_id: str, tenant_id: str, workflow_app_id: str) -> dict:
        """
        Get a workflow tool.
        :param user_id: the user id
        :param tenant_id: the tenant id
        :param workflow_app_id: the workflow app id
        :return: the tool
        """
        db_tool: WorkflowToolProvider | None = (
            db.session.query(WorkflowToolProvider)
            .filter(WorkflowToolProvider.tenant_id == tenant_id, WorkflowToolProvider.app_id == workflow_app_id)
            .first()
        )
        return cls._get_workflow_tool(tenant_id, db_tool)

    @classmethod
    def _get_workflow_tool(cls, tenant_id: str, db_tool: WorkflowToolProvider | None) -> dict:
        """
        Get a workflow tool.
        :param tenant_id: the tenant id
        :param db_tool: the database tool
        :return: the tool
        :raises ValueError: if the tool, its app, or its workflow is missing
        """
        if db_tool is None:
            raise ValueError("Tool not found")

        workflow_app: App | None = (
            db.session.query(App).filter(App.id == db_tool.app_id, App.tenant_id == db_tool.tenant_id).first()
        )

        if workflow_app is None:
            raise ValueError(f"App {db_tool.app_id} not found")

        workflow = workflow_app.workflow
        if not workflow:
            raise ValueError("Workflow not found")

        tool = ToolTransformService.workflow_provider_to_controller(db_tool)
        workflow_tools: list[WorkflowTool] = tool.get_tools(tenant_id)
        if len(workflow_tools) == 0:
            raise ValueError(f"Tool {db_tool.id} not found")

        return {
            "name": db_tool.name,
            "label": db_tool.label,
            "workflow_tool_id": db_tool.id,
            "workflow_app_id": db_tool.app_id,
            "icon": json.loads(db_tool.icon),
            "description": db_tool.description,
            "parameters": jsonable_encoder(db_tool.parameter_configurations),
            "tool": ToolTransformService.convert_tool_entity_to_api_entity(
                # reuse the already-fetched tool list instead of querying again
                tool=workflow_tools[0],
                labels=ToolLabelManager.get_tool_labels(tool),
                tenant_id=tenant_id,
            ),
            # whether the provider still matches the app's current workflow version
            "synced": workflow.version == db_tool.version,
            "privacy_policy": db_tool.privacy_policy,
        }

    @classmethod
    def list_single_workflow_tools(cls, user_id: str, tenant_id: str, workflow_tool_id: str) -> list[ToolApiEntity]:
        """
        List workflow tool provider tools.
        :param user_id: the user id
        :param tenant_id: the tenant id
        :param workflow_tool_id: the workflow tool id
        :return: the list of tools
        """
        db_tool: WorkflowToolProvider | None = (
            db.session.query(WorkflowToolProvider)
            .filter(WorkflowToolProvider.tenant_id == tenant_id, WorkflowToolProvider.id == workflow_tool_id)
            .first()
        )

        if db_tool is None:
            raise ValueError(f"Tool {workflow_tool_id} not found")

        tool = ToolTransformService.workflow_provider_to_controller(db_tool)
        workflow_tools: list[WorkflowTool] = tool.get_tools(tenant_id)
        if len(workflow_tools) == 0:
            raise ValueError(f"Tool {workflow_tool_id} not found")

        return [
            ToolTransformService.convert_tool_entity_to_api_entity(
                # reuse the already-fetched tool list instead of querying again
                tool=workflow_tools[0],
                labels=ToolLabelManager.get_tool_labels(tool),
                tenant_id=tenant_id,
            )
        ]

View File

@@ -0,0 +1,217 @@
from typing import Optional
from core.model_manager import ModelInstance, ModelManager
from core.model_runtime.entities.model_entities import ModelType
from core.rag.datasource.keyword.keyword_factory import Keyword
from core.rag.datasource.vdb.vector_factory import Vector
from core.rag.index_processor.constant.index_type import IndexType
from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
from core.rag.models.document import Document
from extensions.ext_database import db
from models.dataset import ChildChunk, Dataset, DatasetProcessRule, DocumentSegment
from models.dataset import Document as DatasetDocument
from services.entities.knowledge_entities.knowledge_entities import ParentMode
class VectorService:
    """Maintain vector/keyword indexes for dataset segments and child chunks."""

    @classmethod
    def create_segments_vector(
        cls, keywords_list: Optional[list[list[str]]], segments: list[DocumentSegment], dataset: Dataset, doc_form: str
    ):
        """Index newly created segments.

        Parent-child segments get child chunks generated per segment; other
        forms are batch-loaded into the dataset's index processor.

        :raises ValueError: if the parent document, its processing rule is
            missing, or the dataset is not high-quality for parent-child mode
        """
        documents = []
        for segment in segments:
            if doc_form == IndexType.PARENT_CHILD_INDEX:
                document = db.session.query(DatasetDocument).filter_by(id=segment.document_id).first()
                # guard: the owning document may have been deleted concurrently
                if not document:
                    raise ValueError(f"Document {segment.document_id} not found.")
                # get the process rule
                processing_rule = (
                    db.session.query(DatasetProcessRule)
                    .filter(DatasetProcessRule.id == document.dataset_process_rule_id)
                    .first()
                )
                if not processing_rule:
                    raise ValueError("No processing rule found.")
                # get embedding model instance
                if dataset.indexing_technique == "high_quality":
                    # check embedding model setting
                    model_manager = ModelManager()

                    if dataset.embedding_model_provider:
                        embedding_model_instance = model_manager.get_model_instance(
                            tenant_id=dataset.tenant_id,
                            provider=dataset.embedding_model_provider,
                            model_type=ModelType.TEXT_EMBEDDING,
                            model=dataset.embedding_model,
                        )
                    else:
                        embedding_model_instance = model_manager.get_default_model_instance(
                            tenant_id=dataset.tenant_id,
                            model_type=ModelType.TEXT_EMBEDDING,
                        )
                else:
                    raise ValueError("The knowledge base index technique is not high quality!")
                cls.generate_child_chunks(segment, document, dataset, embedding_model_instance, processing_rule, False)
            else:
                document = Document(
                    page_content=segment.content,
                    metadata={
                        "doc_id": segment.index_node_id,
                        "doc_hash": segment.index_node_hash,
                        "document_id": segment.document_id,
                        "dataset_id": segment.dataset_id,
                    },
                )
                documents.append(document)
        if len(documents) > 0:
            index_processor = IndexProcessorFactory(doc_form).init_index_processor()
            index_processor.load(dataset, documents, with_keywords=True, keywords_list=keywords_list)

    @classmethod
    def update_segment_vector(cls, keywords: Optional[list[str]], segment: DocumentSegment, dataset: Dataset):
        """Rebuild one segment's vector (high-quality only) and keyword index."""
        # update segment index task

        # format new index
        document = Document(
            page_content=segment.content,
            metadata={
                "doc_id": segment.index_node_id,
                "doc_hash": segment.index_node_hash,
                "document_id": segment.document_id,
                "dataset_id": segment.dataset_id,
            },
        )
        if dataset.indexing_technique == "high_quality":
            # update vector index: delete the old node, then re-add
            vector = Vector(dataset=dataset)
            vector.delete_by_ids([segment.index_node_id])
            vector.add_texts([document], duplicate_check=True)

        # update keyword index
        keyword = Keyword(dataset)
        keyword.delete_by_ids([segment.index_node_id])

        # save keyword index
        if keywords and len(keywords) > 0:
            keyword.add_texts([document], keywords_list=[keywords])
        else:
            keyword.add_texts([document])

    @classmethod
    def generate_child_chunks(
        cls,
        segment: DocumentSegment,
        dataset_document: DatasetDocument,
        dataset: Dataset,
        embedding_model_instance: ModelInstance,
        processing_rule: DatasetProcessRule,
        regenerate: bool = False,
    ):
        """Split a segment into child chunks, index them, and persist the rows.

        :param regenerate: when True, existing child chunks are removed first
        """
        index_processor = IndexProcessorFactory(dataset.doc_form).init_index_processor()
        if regenerate:
            # delete child chunks
            index_processor.clean(dataset, [segment.index_node_id], with_keywords=True, delete_child_chunks=True)

        # generate child chunks
        document = Document(
            page_content=segment.content,
            metadata={
                "doc_id": segment.index_node_id,
                "doc_hash": segment.index_node_hash,
                "document_id": segment.document_id,
                "dataset_id": segment.dataset_id,
            },
        )
        # use full doc mode to generate segment's child chunk
        processing_rule_dict = processing_rule.to_dict()
        processing_rule_dict["rules"]["parent_mode"] = ParentMode.FULL_DOC.value
        documents = index_processor.transform(
            [document],
            embedding_model_instance=embedding_model_instance,
            process_rule=processing_rule_dict,
            tenant_id=dataset.tenant_id,
            doc_language=dataset_document.doc_language,
        )
        # save child chunks
        if documents and documents[0].children:
            index_processor.load(dataset, documents)

            for position, child_chunk in enumerate(documents[0].children, start=1):
                child_segment = ChildChunk(
                    tenant_id=dataset.tenant_id,
                    dataset_id=dataset.id,
                    document_id=dataset_document.id,
                    segment_id=segment.id,
                    position=position,
                    index_node_id=child_chunk.metadata["doc_id"],
                    index_node_hash=child_chunk.metadata["doc_hash"],
                    content=child_chunk.page_content,
                    word_count=len(child_chunk.page_content),
                    type="automatic",
                    created_by=dataset_document.created_by,
                )
                db.session.add(child_segment)
            db.session.commit()

    @classmethod
    def create_child_chunk_vector(cls, child_segment: ChildChunk, dataset: Dataset):
        """Add a single child chunk to the vector index (high-quality datasets only)."""
        child_document = Document(
            page_content=child_segment.content,
            metadata={
                "doc_id": child_segment.index_node_id,
                "doc_hash": child_segment.index_node_hash,
                "document_id": child_segment.document_id,
                "dataset_id": child_segment.dataset_id,
            },
        )
        if dataset.indexing_technique == "high_quality":
            # save vector index
            vector = Vector(dataset=dataset)
            vector.add_texts([child_document], duplicate_check=True)

    @classmethod
    def update_child_chunk_vector(
        cls,
        new_child_chunks: list[ChildChunk],
        update_child_chunks: list[ChildChunk],
        delete_child_chunks: list[ChildChunk],
        dataset: Dataset,
    ):
        """Apply a batched add/update/delete of child chunks to the vector index.

        Updated chunks are re-indexed by deleting the old node and adding the
        new document in one pass.
        """
        documents = []
        delete_node_ids = []
        for new_child_chunk in new_child_chunks:
            new_child_document = Document(
                page_content=new_child_chunk.content,
                metadata={
                    "doc_id": new_child_chunk.index_node_id,
                    "doc_hash": new_child_chunk.index_node_hash,
                    "document_id": new_child_chunk.document_id,
                    "dataset_id": new_child_chunk.dataset_id,
                },
            )
            documents.append(new_child_document)
        for update_child_chunk in update_child_chunks:
            child_document = Document(
                page_content=update_child_chunk.content,
                metadata={
                    "doc_id": update_child_chunk.index_node_id,
                    "doc_hash": update_child_chunk.index_node_hash,
                    "document_id": update_child_chunk.document_id,
                    "dataset_id": update_child_chunk.dataset_id,
                },
            )
            documents.append(child_document)
            delete_node_ids.append(update_child_chunk.index_node_id)
        for delete_child_chunk in delete_child_chunks:
            delete_node_ids.append(delete_child_chunk.index_node_id)
        if dataset.indexing_technique == "high_quality":
            # update vector index
            vector = Vector(dataset=dataset)
            if delete_node_ids:
                vector.delete_by_ids(delete_node_ids)
            if documents:
                vector.add_texts(documents, duplicate_check=True)

    @classmethod
    def delete_child_chunk_vector(cls, child_chunk: ChildChunk, dataset: Dataset):
        """Remove a single child chunk's node from the vector index."""
        vector = Vector(dataset=dataset)
        vector.delete_by_ids([child_chunk.index_node_id])

View File

@@ -0,0 +1,113 @@
from typing import Optional, Union
from sqlalchemy import select
from sqlalchemy.orm import Session
from core.app.entities.app_invoke_entities import InvokeFrom
from extensions.ext_database import db
from libs.infinite_scroll_pagination import InfiniteScrollPagination
from models.account import Account
from models.model import App, EndUser
from models.web import PinnedConversation
from services.conversation_service import ConversationService
class WebConversationService:
    """Web-app conversation listing plus pin/unpin bookkeeping."""

    @staticmethod
    def _created_by_role(user: Union[Account, EndUser]) -> str:
        """Map the user type to the role string stored on PinnedConversation."""
        return "account" if isinstance(user, Account) else "end_user"

    @classmethod
    def _get_pinned_conversation(
        cls, app_model: App, conversation_id: str, user: Union[Account, EndUser]
    ) -> Optional[PinnedConversation]:
        """Fetch the user's pin record for a conversation, if any."""
        return (
            db.session.query(PinnedConversation)
            .filter(
                PinnedConversation.app_id == app_model.id,
                PinnedConversation.conversation_id == conversation_id,
                PinnedConversation.created_by_role == cls._created_by_role(user),
                PinnedConversation.created_by == user.id,
            )
            .first()
        )

    @classmethod
    def pagination_by_last_id(
        cls,
        *,
        session: Session,
        app_model: App,
        user: Optional[Union[Account, EndUser]],
        last_id: Optional[str],
        limit: int,
        invoke_from: InvokeFrom,
        pinned: Optional[bool] = None,
        sort_by="-updated_at",
    ) -> InfiniteScrollPagination:
        """Paginate conversations, optionally restricted to (un)pinned ones.

        :param pinned: True to list only pinned conversations, False to exclude
            them, None for no pin filtering
        :raises ValueError: if no user is provided
        """
        if not user:
            raise ValueError("User is required")

        include_ids = None
        exclude_ids = None
        if pinned is not None and user:
            stmt = (
                select(PinnedConversation.conversation_id)
                .where(
                    PinnedConversation.app_id == app_model.id,
                    PinnedConversation.created_by_role == cls._created_by_role(user),
                    PinnedConversation.created_by == user.id,
                )
                .order_by(PinnedConversation.created_at.desc())
            )
            pinned_conversation_ids = session.scalars(stmt).all()
            if pinned:
                include_ids = pinned_conversation_ids
            else:
                exclude_ids = pinned_conversation_ids

        return ConversationService.pagination_by_last_id(
            session=session,
            app_model=app_model,
            user=user,
            last_id=last_id,
            limit=limit,
            invoke_from=invoke_from,
            include_ids=include_ids,
            exclude_ids=exclude_ids,
            sort_by=sort_by,
        )

    @classmethod
    def pin(cls, app_model: App, conversation_id: str, user: Optional[Union[Account, EndUser]]):
        """Pin a conversation for the user; no-op if already pinned or no user."""
        if not user:
            return
        if cls._get_pinned_conversation(app_model, conversation_id, user):
            # already pinned
            return

        # validates the conversation exists and is visible to the user
        conversation = ConversationService.get_conversation(
            app_model=app_model, conversation_id=conversation_id, user=user
        )

        pinned_conversation = PinnedConversation(
            app_id=app_model.id,
            conversation_id=conversation.id,
            created_by_role=cls._created_by_role(user),
            created_by=user.id,
        )

        db.session.add(pinned_conversation)
        db.session.commit()

    @classmethod
    def unpin(cls, app_model: App, conversation_id: str, user: Optional[Union[Account, EndUser]]):
        """Remove the user's pin on a conversation; no-op if not pinned or no user."""
        if not user:
            return
        pinned_conversation = cls._get_pinned_conversation(app_model, conversation_id, user)
        if not pinned_conversation:
            return

        db.session.delete(pinned_conversation)
        db.session.commit()

View File

@@ -0,0 +1,227 @@
import datetime
import json
from typing import Any
import requests
from flask_login import current_user # type: ignore
from core.helper import encrypter
from core.rag.extractor.firecrawl.firecrawl_app import FirecrawlApp
from extensions.ext_redis import redis_client
from extensions.ext_storage import storage
from services.auth.api_key_auth_service import ApiKeyAuthService
class WebsiteService:
@classmethod
def document_create_args_validate(cls, args: dict):
if "url" not in args or not args["url"]:
raise ValueError("url is required")
if "options" not in args or not args["options"]:
raise ValueError("options is required")
if "limit" not in args["options"] or not args["options"]["limit"]:
raise ValueError("limit is required")
@classmethod
def crawl_url(cls, args: dict) -> dict:
provider = args.get("provider", "")
url = args.get("url")
options = args.get("options", "")
credentials = ApiKeyAuthService.get_auth_credentials(current_user.current_tenant_id, "website", provider)
if provider == "firecrawl":
# decrypt api_key
api_key = encrypter.decrypt_token(
tenant_id=current_user.current_tenant_id, token=credentials.get("config").get("api_key")
)
firecrawl_app = FirecrawlApp(api_key=api_key, base_url=credentials.get("config").get("base_url", None))
crawl_sub_pages = options.get("crawl_sub_pages", False)
only_main_content = options.get("only_main_content", False)
if not crawl_sub_pages:
params = {
"includePaths": [],
"excludePaths": [],
"limit": 1,
"scrapeOptions": {"onlyMainContent": only_main_content},
}
else:
includes = options.get("includes").split(",") if options.get("includes") else []
excludes = options.get("excludes").split(",") if options.get("excludes") else []
params = {
"includePaths": includes,
"excludePaths": excludes,
"limit": options.get("limit", 1),
"scrapeOptions": {"onlyMainContent": only_main_content},
}
if options.get("max_depth"):
params["maxDepth"] = options.get("max_depth")
job_id = firecrawl_app.crawl_url(url, params)
website_crawl_time_cache_key = f"website_crawl_{job_id}"
time = str(datetime.datetime.now().timestamp())
redis_client.setex(website_crawl_time_cache_key, 3600, time)
return {"status": "active", "job_id": job_id}
elif provider == "jinareader":
api_key = encrypter.decrypt_token(
tenant_id=current_user.current_tenant_id, token=credentials.get("config").get("api_key")
)
crawl_sub_pages = options.get("crawl_sub_pages", False)
if not crawl_sub_pages:
response = requests.get(
f"https://r.jina.ai/{url}",
headers={"Accept": "application/json", "Authorization": f"Bearer {api_key}"},
)
if response.json().get("code") != 200:
raise ValueError("Failed to crawl")
return {"status": "active", "data": response.json().get("data")}
else:
response = requests.post(
"https://adaptivecrawl-kir3wx7b3a-uc.a.run.app",
json={
"url": url,
"maxPages": options.get("limit", 1),
"useSitemap": options.get("use_sitemap", True),
},
headers={
"Content-Type": "application/json",
"Authorization": f"Bearer {api_key}",
},
)
if response.json().get("code") != 200:
raise ValueError("Failed to crawl")
return {"status": "active", "job_id": response.json().get("data", {}).get("taskId")}
else:
raise ValueError("Invalid provider")
@classmethod
def get_crawl_status(cls, job_id: str, provider: str) -> dict:
credentials = ApiKeyAuthService.get_auth_credentials(current_user.current_tenant_id, "website", provider)
if provider == "firecrawl":
# decrypt api_key
api_key = encrypter.decrypt_token(
tenant_id=current_user.current_tenant_id, token=credentials.get("config").get("api_key")
)
firecrawl_app = FirecrawlApp(api_key=api_key, base_url=credentials.get("config").get("base_url", None))
result = firecrawl_app.check_crawl_status(job_id)
crawl_status_data = {
"status": result.get("status", "active"),
"job_id": job_id,
"total": result.get("total", 0),
"current": result.get("current", 0),
"data": result.get("data", []),
}
if crawl_status_data["status"] == "completed":
website_crawl_time_cache_key = f"website_crawl_{job_id}"
start_time = redis_client.get(website_crawl_time_cache_key)
if start_time:
end_time = datetime.datetime.now().timestamp()
time_consuming = abs(end_time - float(start_time))
crawl_status_data["time_consuming"] = f"{time_consuming:.2f}"
redis_client.delete(website_crawl_time_cache_key)
elif provider == "jinareader":
api_key = encrypter.decrypt_token(
tenant_id=current_user.current_tenant_id, token=credentials.get("config").get("api_key")
)
response = requests.post(
"https://adaptivecrawlstatus-kir3wx7b3a-uc.a.run.app",
headers={"Content-Type": "application/json", "Authorization": f"Bearer {api_key}"},
json={"taskId": job_id},
)
data = response.json().get("data", {})
crawl_status_data = {
"status": data.get("status", "active"),
"job_id": job_id,
"total": len(data.get("urls", [])),
"current": len(data.get("processed", [])) + len(data.get("failed", [])),
"data": [],
"time_consuming": data.get("duration", 0) / 1000,
}
if crawl_status_data["status"] == "completed":
response = requests.post(
"https://adaptivecrawlstatus-kir3wx7b3a-uc.a.run.app",
headers={"Content-Type": "application/json", "Authorization": f"Bearer {api_key}"},
json={"taskId": job_id, "urls": list(data.get("processed", {}).keys())},
)
data = response.json().get("data", {})
formatted_data = [
{
"title": item.get("data", {}).get("title"),
"source_url": item.get("data", {}).get("url"),
"description": item.get("data", {}).get("description"),
"markdown": item.get("data", {}).get("content"),
}
for item in data.get("processed", {}).values()
]
crawl_status_data["data"] = formatted_data
else:
raise ValueError("Invalid provider")
return crawl_status_data
@classmethod
def get_crawl_url_data(cls, job_id: str, provider: str, url: str, tenant_id: str) -> dict[Any, Any] | None:
credentials = ApiKeyAuthService.get_auth_credentials(tenant_id, "website", provider)
# decrypt api_key
api_key = encrypter.decrypt_token(tenant_id=tenant_id, token=credentials.get("config").get("api_key"))
# FIXME data is redefine too many times here, use Any to ease the type checking, fix it later
data: Any
if provider == "firecrawl":
file_key = "website_files/" + job_id + ".txt"
if storage.exists(file_key):
d = storage.load_once(file_key)
if d:
data = json.loads(d.decode("utf-8"))
else:
firecrawl_app = FirecrawlApp(api_key=api_key, base_url=credentials.get("config").get("base_url", None))
result = firecrawl_app.check_crawl_status(job_id)
if result.get("status") != "completed":
raise ValueError("Crawl job is not completed")
data = result.get("data")
if data:
for item in data:
if item.get("source_url") == url:
return dict(item)
return None
elif provider == "jinareader":
if not job_id:
response = requests.get(
f"https://r.jina.ai/{url}",
headers={"Accept": "application/json", "Authorization": f"Bearer {api_key}"},
)
if response.json().get("code") != 200:
raise ValueError("Failed to crawl")
return dict(response.json().get("data", {}))
else:
api_key = encrypter.decrypt_token(tenant_id=tenant_id, token=credentials.get("config").get("api_key"))
response = requests.post(
"https://adaptivecrawlstatus-kir3wx7b3a-uc.a.run.app",
headers={"Content-Type": "application/json", "Authorization": f"Bearer {api_key}"},
json={"taskId": job_id},
)
data = response.json().get("data", {})
if data.get("status") != "completed":
raise ValueError("Crawl job is not completed")
response = requests.post(
"https://adaptivecrawlstatus-kir3wx7b3a-uc.a.run.app",
headers={"Content-Type": "application/json", "Authorization": f"Bearer {api_key}"},
json={"taskId": job_id, "urls": list(data.get("processed", {}).keys())},
)
data = response.json().get("data", {})
for item in data.get("processed", {}).values():
if item.get("data", {}).get("url") == url:
return dict(item.get("data", {}))
return None
else:
raise ValueError("Invalid provider")
@classmethod
def get_scrape_url_data(cls, provider: str, url: str, tenant_id: str, only_main_content: bool) -> dict:
credentials = ApiKeyAuthService.get_auth_credentials(tenant_id, "website", provider)
if provider == "firecrawl":
# decrypt api_key
api_key = encrypter.decrypt_token(tenant_id=tenant_id, token=credentials.get("config").get("api_key"))
firecrawl_app = FirecrawlApp(api_key=api_key, base_url=credentials.get("config").get("base_url", None))
params = {"onlyMainContent": only_main_content}
result = firecrawl_app.scrape_url(url, params)
return result
else:
raise ValueError("Invalid provider")

View File

@@ -0,0 +1,630 @@
import json
from typing import Any, Optional
from core.app.app_config.entities import (
DatasetEntity,
DatasetRetrieveConfigEntity,
EasyUIBasedAppConfig,
ExternalDataVariableEntity,
ModelConfigEntity,
PromptTemplateEntity,
VariableEntity,
)
from core.app.apps.agent_chat.app_config_manager import AgentChatAppConfigManager
from core.app.apps.chat.app_config_manager import ChatAppConfigManager
from core.app.apps.completion.app_config_manager import CompletionAppConfigManager
from core.file.models import FileUploadConfig
from core.helper import encrypter
from core.model_runtime.entities.llm_entities import LLMMode
from core.model_runtime.utils.encoders import jsonable_encoder
from core.prompt.simple_prompt_transform import SimplePromptTransform
from core.workflow.nodes import NodeType
from events.app_event import app_was_created
from extensions.ext_database import db
from models.account import Account
from models.api_based_extension import APIBasedExtension, APIBasedExtensionPoint
from models.model import App, AppMode, AppModelConfig
from models.workflow import Workflow, WorkflowType
class WorkflowConverter:
"""
App Convert to Workflow Mode
"""
    def convert_to_workflow(
        self, app_model: App, account: Account, name: str, icon_type: str, icon: str, icon_background: str
    ):
        """
        Convert app to workflow

        - basic mode of chatbot app
        - expert mode of chatbot app
        - completion app

        :param app_model: App instance
        :param account: Account
        :param name: new app name
        :param icon: new app icon
        :param icon_type: new app icon type
        :param icon_background: new app icon background
        :return: new App instance
        """
        # convert app model config
        if not app_model.app_model_config:
            raise ValueError("App model config is required")

        workflow = self.convert_app_model_config_to_workflow(
            app_model=app_model, app_model_config=app_model.app_model_config, account_id=account.id
        )

        # create new app; falsy arguments fall back to the source app's values
        new_app = App()
        new_app.tenant_id = app_model.tenant_id
        new_app.name = name or app_model.name + "(workflow)"
        # chat apps become advanced-chat; everything else becomes a plain workflow app
        new_app.mode = AppMode.ADVANCED_CHAT.value if app_model.mode == AppMode.CHAT.value else AppMode.WORKFLOW.value
        new_app.icon_type = icon_type or app_model.icon_type
        new_app.icon = icon or app_model.icon
        new_app.icon_background = icon_background or app_model.icon_background
        new_app.enable_site = app_model.enable_site
        new_app.enable_api = app_model.enable_api
        new_app.api_rpm = app_model.api_rpm
        new_app.api_rph = app_model.api_rph
        new_app.is_demo = False
        new_app.is_public = app_model.is_public
        new_app.created_by = account.id
        new_app.updated_by = account.id
        db.session.add(new_app)
        db.session.flush()  # flush so new_app.id is assigned before linking the workflow
        db.session.commit()

        # attach the converted workflow to the freshly created app
        workflow.app_id = new_app.id
        db.session.commit()

        app_was_created.send(new_app, account=account)

        return new_app
    def convert_app_model_config_to_workflow(self, app_model: App, app_model_config: AppModelConfig, account_id: str):
        """
        Convert app model config to workflow mode

        :param app_model: App instance
        :param app_model_config: AppModelConfig instance
        :param account_id: Account ID
        :return: draft Workflow persisted for the app
        """
        # get new app mode
        new_app_mode = self._get_new_app_mode(app_model)

        # convert app model config
        app_config = self._convert_to_app_config(app_model=app_model, app_model_config=app_model_config)

        # init workflow graph
        graph: dict[str, Any] = {"nodes": [], "edges": []}

        # Convert list:
        # - variables -> start
        # - model_config -> llm
        # - prompt_template -> llm
        # - file_upload -> llm
        # - external_data_variables -> http-request
        # - dataset -> knowledge-retrieval
        # - show_retrieve_source -> knowledge-retrieval

        # convert to start node
        start_node = self._convert_to_start_node(variables=app_config.variables)

        graph["nodes"].append(start_node)

        # convert to http request node (one http-request + one code node per external data tool)
        external_data_variable_node_mapping: dict[str, str] = {}
        if app_config.external_data_variables:
            http_request_nodes, external_data_variable_node_mapping = self._convert_to_http_request_node(
                app_model=app_model,
                variables=app_config.variables,
                external_data_variables=app_config.external_data_variables,
            )

            for http_request_node in http_request_nodes:
                graph = self._append_node(graph, http_request_node)

        # convert to knowledge retrieval node
        if app_config.dataset:
            knowledge_retrieval_node = self._convert_to_knowledge_retrieval_node(
                new_app_mode=new_app_mode, dataset_config=app_config.dataset, model_config=app_config.model
            )

            # may be None when a completion app has no query variable configured
            if knowledge_retrieval_node:
                graph = self._append_node(graph, knowledge_retrieval_node)

        # convert to llm node
        llm_node = self._convert_to_llm_node(
            original_app_mode=AppMode.value_of(app_model.mode),
            new_app_mode=new_app_mode,
            graph=graph,
            model_config=app_config.model,
            prompt_template=app_config.prompt_template,
            file_upload=app_config.additional_features.file_upload,
            external_data_variable_node_mapping=external_data_variable_node_mapping,
        )

        graph = self._append_node(graph, llm_node)

        if new_app_mode == AppMode.WORKFLOW:
            # convert to end node by app mode
            end_node = self._convert_to_end_node()
            graph = self._append_node(graph, end_node)
        else:
            answer_node = self._convert_to_answer_node()
            graph = self._append_node(graph, answer_node)

        app_model_config_dict = app_config.app_model_config_dict

        # features: advanced-chat keeps the full chat feature set, workflow keeps a subset
        if new_app_mode == AppMode.ADVANCED_CHAT:
            features = {
                "opening_statement": app_model_config_dict.get("opening_statement"),
                "suggested_questions": app_model_config_dict.get("suggested_questions"),
                "suggested_questions_after_answer": app_model_config_dict.get("suggested_questions_after_answer"),
                "speech_to_text": app_model_config_dict.get("speech_to_text"),
                "text_to_speech": app_model_config_dict.get("text_to_speech"),
                "file_upload": app_model_config_dict.get("file_upload"),
                "sensitive_word_avoidance": app_model_config_dict.get("sensitive_word_avoidance"),
                "retriever_resource": app_model_config_dict.get("retriever_resource"),
            }
        else:
            features = {
                "text_to_speech": app_model_config_dict.get("text_to_speech"),
                "file_upload": app_model_config_dict.get("file_upload"),
                "sensitive_word_avoidance": app_model_config_dict.get("sensitive_word_avoidance"),
            }

        # create workflow record
        workflow = Workflow(
            tenant_id=app_model.tenant_id,
            app_id=app_model.id,
            type=WorkflowType.from_app_mode(new_app_mode).value,
            version="draft",
            graph=json.dumps(graph),
            features=json.dumps(features),
            created_by=account_id,
            environment_variables=[],
            conversation_variables=[],
        )

        db.session.add(workflow)
        db.session.commit()

        return workflow
def _convert_to_app_config(self, app_model: App, app_model_config: AppModelConfig) -> EasyUIBasedAppConfig:
app_mode_enum = AppMode.value_of(app_model.mode)
app_config: EasyUIBasedAppConfig
if app_mode_enum == AppMode.AGENT_CHAT or app_model.is_agent:
app_model.mode = AppMode.AGENT_CHAT.value
app_config = AgentChatAppConfigManager.get_app_config(
app_model=app_model, app_model_config=app_model_config
)
elif app_mode_enum == AppMode.CHAT:
app_config = ChatAppConfigManager.get_app_config(app_model=app_model, app_model_config=app_model_config)
elif app_mode_enum == AppMode.COMPLETION:
app_config = CompletionAppConfigManager.get_app_config(
app_model=app_model, app_model_config=app_model_config
)
else:
raise ValueError("Invalid app mode")
return app_config
def _convert_to_start_node(self, variables: list[VariableEntity]) -> dict:
"""
Convert to Start Node
:param variables: list of variables
:return:
"""
return {
"id": "start",
"position": None,
"data": {
"title": "START",
"type": NodeType.START.value,
"variables": [jsonable_encoder(v) for v in variables],
},
}
    def _convert_to_http_request_node(
        self, app_model: App, variables: list[VariableEntity], external_data_variables: list[ExternalDataVariableEntity]
    ) -> tuple[list[dict], dict[str, str]]:
        """
        Convert API Based Extension to HTTP Request Node

        :param app_model: App instance
        :param variables: list of variables
        :param external_data_variables: list of external data variables
        :return: (http-request/code nodes, mapping of variable name -> parsing code node id)
        """
        index = 1
        nodes = []
        external_data_variable_node_mapping = {}
        tenant_id = app_model.tenant_id
        for external_data_variable in external_data_variables:
            tool_type = external_data_variable.type
            if tool_type != "api":
                # only API-based external data tools have an HTTP-request equivalent
                continue

            tool_variable = external_data_variable.variable
            tool_config = external_data_variable.config

            # get params from config
            api_based_extension_id = tool_config.get("api_based_extension_id")
            if not api_based_extension_id:
                continue

            # get api_based_extension
            api_based_extension = self._get_api_based_extension(
                tenant_id=tenant_id, api_based_extension_id=api_based_extension_id
            )

            # decrypt api_key
            api_key = encrypter.decrypt_token(tenant_id=tenant_id, token=api_based_extension.api_key)

            # forward every app input to the extension via a start-node selector
            inputs = {}
            for v in variables:
                inputs[v.variable] = "{{#start." + v.variable + "#}}"

            request_body = {
                "point": APIBasedExtensionPoint.APP_EXTERNAL_DATA_TOOL_QUERY.value,
                "params": {
                    "app_id": app_model.id,
                    "tool_variable": tool_variable,
                    "inputs": inputs,
                    "query": "{{#sys.query#}}" if app_model.mode == AppMode.CHAT.value else "",
                },
            }

            request_body_json = json.dumps(request_body)
            # NOTE(review): removes backslash-escapes from brace pairs so {{#...#}}
            # selectors stay substitutable after serialization -- confirm what
            # produces the escaped form in practice
            request_body_json = request_body_json.replace(r"\{\{", "{{").replace(r"\}\}", "}}")

            http_request_node = {
                "id": f"http_request_{index}",
                "position": None,
                "data": {
                    "title": f"HTTP REQUEST {api_based_extension.name}",
                    "type": NodeType.HTTP_REQUEST.value,
                    "method": "post",
                    "url": api_based_extension.api_endpoint,
                    "authorization": {"type": "api-key", "config": {"type": "bearer", "api_key": api_key}},
                    "headers": "",
                    "params": "",
                    "body": {"type": "json", "data": request_body_json},
                },
            }

            nodes.append(http_request_node)

            # append code node for response body parsing
            code_node: dict[str, Any] = {
                "id": f"code_{index}",
                "position": None,
                "data": {
                    "title": f"Parse {api_based_extension.name} Response",
                    "type": NodeType.CODE.value,
                    "variables": [{"variable": "response_json", "value_selector": [http_request_node["id"], "body"]}],
                    "code_language": "python3",
                    "code": "import json\n\ndef main(response_json: str) -> str:\n    response_body = json.loads("
                    'response_json)\n    return {\n        "result": response_body["result"]\n    }',
                    "outputs": {"result": {"type": "string"}},
                },
            }

            nodes.append(code_node)

            # later template substitution resolves this variable from the code node's result
            external_data_variable_node_mapping[external_data_variable.variable] = code_node["id"]
            index += 1

        return nodes, external_data_variable_node_mapping
def _convert_to_knowledge_retrieval_node(
self, new_app_mode: AppMode, dataset_config: DatasetEntity, model_config: ModelConfigEntity
) -> Optional[dict]:
"""
Convert datasets to Knowledge Retrieval Node
:param new_app_mode: new app mode
:param dataset_config: dataset
:param model_config: model config
:return:
"""
retrieve_config = dataset_config.retrieve_config
if new_app_mode == AppMode.ADVANCED_CHAT:
query_variable_selector = ["sys", "query"]
elif retrieve_config.query_variable:
# fetch query variable
query_variable_selector = ["start", retrieve_config.query_variable]
else:
return None
return {
"id": "knowledge_retrieval",
"position": None,
"data": {
"title": "KNOWLEDGE RETRIEVAL",
"type": NodeType.KNOWLEDGE_RETRIEVAL.value,
"query_variable_selector": query_variable_selector,
"dataset_ids": dataset_config.dataset_ids,
"retrieval_mode": retrieve_config.retrieve_strategy.value,
"single_retrieval_config": {
"model": {
"provider": model_config.provider,
"name": model_config.model,
"mode": model_config.mode,
"completion_params": {
**model_config.parameters,
"stop": model_config.stop,
},
}
}
if retrieve_config.retrieve_strategy == DatasetRetrieveConfigEntity.RetrieveStrategy.SINGLE
else None,
"multiple_retrieval_config": {
"top_k": retrieve_config.top_k,
"score_threshold": retrieve_config.score_threshold,
"reranking_model": retrieve_config.reranking_model,
}
if retrieve_config.retrieve_strategy == DatasetRetrieveConfigEntity.RetrieveStrategy.MULTIPLE
else None,
},
}
    def _convert_to_llm_node(
        self,
        original_app_mode: AppMode,
        new_app_mode: AppMode,
        graph: dict,
        model_config: ModelConfigEntity,
        prompt_template: PromptTemplateEntity,
        file_upload: Optional[FileUploadConfig] = None,
        external_data_variable_node_mapping: dict[str, str] | None = None,
    ) -> dict:
        """
        Convert to LLM Node

        :param original_app_mode: original app mode
        :param new_app_mode: new app mode
        :param graph: graph
        :param model_config: model config
        :param prompt_template: prompt template
        :param file_upload: file upload config (optional)
        :param external_data_variable_node_mapping: external data variable node mapping
        :return: LLM node dict
        """
        # fetch start and knowledge retrieval node (start node always exists by this point)
        start_node = next(filter(lambda n: n["data"]["type"] == NodeType.START.value, graph["nodes"]))
        knowledge_retrieval_node = next(
            filter(lambda n: n["data"]["type"] == NodeType.KNOWLEDGE_RETRIEVAL.value, graph["nodes"]), None
        )

        role_prefix = None
        prompts: Any = None

        # Chat Model: prompts is a list of role/text messages
        if model_config.mode == LLMMode.CHAT.value:
            if prompt_template.prompt_type == PromptTemplateEntity.PromptType.SIMPLE:
                if not prompt_template.simple_prompt_template:
                    raise ValueError("Simple prompt template is required")
                # get prompt template
                prompt_transform = SimplePromptTransform()
                prompt_template_config = prompt_transform.get_prompt_template(
                    app_mode=original_app_mode,
                    provider=model_config.provider,
                    model=model_config.model,
                    pre_prompt=prompt_template.simple_prompt_template,
                    has_context=knowledge_retrieval_node is not None,
                    query_in_prompt=False,
                )

                template = prompt_template_config["prompt_template"].template
                if not template:
                    prompts = []
                else:
                    # rewrite legacy {{var}} placeholders to workflow variable selectors
                    template = self._replace_template_variables(
                        template, start_node["data"]["variables"], external_data_variable_node_mapping
                    )

                    prompts = [{"role": "user", "text": template}]
            else:
                advanced_chat_prompt_template = prompt_template.advanced_chat_prompt_template

                prompts = []
                if advanced_chat_prompt_template:
                    for m in advanced_chat_prompt_template.messages:
                        text = m.text
                        text = self._replace_template_variables(
                            text, start_node["data"]["variables"], external_data_variable_node_mapping
                        )

                        prompts.append({"role": m.role.value, "text": text})
        # Completion Model: prompts is a single {"text": ...} template
        else:
            if prompt_template.prompt_type == PromptTemplateEntity.PromptType.SIMPLE:
                if not prompt_template.simple_prompt_template:
                    raise ValueError("Simple prompt template is required")
                # get prompt template
                prompt_transform = SimplePromptTransform()
                prompt_template_config = prompt_transform.get_prompt_template(
                    app_mode=original_app_mode,
                    provider=model_config.provider,
                    model=model_config.model,
                    pre_prompt=prompt_template.simple_prompt_template,
                    has_context=knowledge_retrieval_node is not None,
                    query_in_prompt=False,
                )

                template = prompt_template_config["prompt_template"].template
                template = self._replace_template_variables(
                    template=template,
                    variables=start_node["data"]["variables"],
                    external_data_variable_node_mapping=external_data_variable_node_mapping,
                )

                prompts = {"text": template}

                prompt_rules = prompt_template_config["prompt_rules"]
                # completion models need explicit speaker prefixes for chat history
                role_prefix = {
                    "user": prompt_rules.get("human_prefix", "Human"),
                    "assistant": prompt_rules.get("assistant_prefix", "Assistant"),
                }
            else:
                advanced_completion_prompt_template = prompt_template.advanced_completion_prompt_template
                if advanced_completion_prompt_template:
                    text = advanced_completion_prompt_template.prompt
                    text = self._replace_template_variables(
                        template=text,
                        variables=start_node["data"]["variables"],
                        external_data_variable_node_mapping=external_data_variable_node_mapping,
                    )
                else:
                    text = ""

                # legacy {{#query#}} placeholder maps to the system query selector
                text = text.replace("{{#query#}}", "{{#sys.query#}}")

                prompts = {
                    "text": text,
                }

                if advanced_completion_prompt_template and advanced_completion_prompt_template.role_prefix:
                    role_prefix = {
                        "user": advanced_completion_prompt_template.role_prefix.user,
                        "assistant": advanced_completion_prompt_template.role_prefix.assistant,
                    }

        # memory is only meaningful for chat-style workflows
        memory = None
        if new_app_mode == AppMode.ADVANCED_CHAT:
            memory = {"role_prefix": role_prefix, "window": {"enabled": False}}

        completion_params = model_config.parameters
        completion_params.update({"stop": model_config.stop})
        return {
            "id": "llm",
            "position": None,
            "data": {
                "title": "LLM",
                "type": NodeType.LLM.value,
                "model": {
                    "provider": model_config.provider,
                    "name": model_config.model,
                    "mode": model_config.mode,
                    "completion_params": completion_params,
                },
                "prompt_template": prompts,
                "memory": memory,
                # context feeds the knowledge-retrieval result into the prompt, when present
                "context": {
                    "enabled": knowledge_retrieval_node is not None,
                    "variable_selector": ["knowledge_retrieval", "result"]
                    if knowledge_retrieval_node is not None
                    else None,
                },
                "vision": {
                    "enabled": file_upload is not None,
                    "variable_selector": ["sys", "files"] if file_upload is not None else None,
                    "configs": {"detail": file_upload.image_config.detail}
                    if file_upload is not None and file_upload.image_config is not None
                    else None,
                },
            },
        }
def _replace_template_variables(
self, template: str, variables: list[dict], external_data_variable_node_mapping: dict[str, str] | None = None
) -> str:
"""
Replace Template Variables
:param template: template
:param variables: list of variables
:param external_data_variable_node_mapping: external data variable node mapping
:return:
"""
for v in variables:
template = template.replace("{{" + v["variable"] + "}}", "{{#start." + v["variable"] + "#}}")
if external_data_variable_node_mapping:
for variable, code_node_id in external_data_variable_node_mapping.items():
template = template.replace("{{" + variable + "}}", "{{#" + code_node_id + ".result#}}")
return template
def _convert_to_end_node(self) -> dict:
"""
Convert to End Node
:return:
"""
# for original completion app
return {
"id": "end",
"position": None,
"data": {
"title": "END",
"type": NodeType.END.value,
"outputs": [{"variable": "result", "value_selector": ["llm", "text"]}],
},
}
def _convert_to_answer_node(self) -> dict:
"""
Convert to Answer Node
:return:
"""
# for original chat app
return {
"id": "answer",
"position": None,
"data": {"title": "ANSWER", "type": NodeType.ANSWER.value, "answer": "{{#llm.text#}}"},
}
def _create_edge(self, source: str, target: str) -> dict:
"""
Create Edge
:param source: source node id
:param target: target node id
:return:
"""
return {"id": f"{source}-{target}", "source": source, "target": target}
def _append_node(self, graph: dict, node: dict) -> dict:
"""
Append Node to Graph
:param graph: Graph, include: nodes, edges
:param node: Node to append
:return:
"""
previous_node = graph["nodes"][-1]
graph["nodes"].append(node)
graph["edges"].append(self._create_edge(previous_node["id"], node["id"]))
return graph
def _get_new_app_mode(self, app_model: App) -> AppMode:
"""
Get new app mode
:param app_model: App instance
:return: AppMode
"""
if app_model.mode == AppMode.COMPLETION.value:
return AppMode.WORKFLOW
else:
return AppMode.ADVANCED_CHAT
def _get_api_based_extension(self, tenant_id: str, api_based_extension_id: str):
"""
Get API Based Extension
:param tenant_id: tenant id
:param api_based_extension_id: api based extension id
:return:
"""
api_based_extension = (
db.session.query(APIBasedExtension)
.filter(APIBasedExtension.tenant_id == tenant_id, APIBasedExtension.id == api_based_extension_id)
.first()
)
if not api_based_extension:
raise ValueError(f"API Based Extension not found, id: {api_based_extension_id}")
return api_based_extension

View File

@@ -0,0 +1,103 @@
import uuid
from datetime import datetime
from sqlalchemy import and_, func, or_, select
from sqlalchemy.orm import Session
from models import App, EndUser, WorkflowAppLog, WorkflowRun
from models.enums import CreatedByRole
from models.workflow import WorkflowRunStatus
class WorkflowAppService:
    def get_paginate_workflow_app_logs(
        self,
        *,
        session: Session,
        app_model: App,
        keyword: str | None = None,
        status: WorkflowRunStatus | None = None,
        created_at_before: datetime | None = None,
        created_at_after: datetime | None = None,
        page: int = 1,
        limit: int = 20,
    ) -> dict:
        """
        Get paginate workflow app logs using SQLAlchemy 2.0 style

        :param session: SQLAlchemy session
        :param app_model: app model
        :param keyword: search keyword
        :param status: filter by status
        :param created_at_before: filter logs created before this timestamp
        :param created_at_after: filter logs created after this timestamp
        :param page: page number
        :param limit: items per page
        :return: Pagination object
        """
        # Build base statement using SQLAlchemy 2.0 style
        stmt = select(WorkflowAppLog).where(
            WorkflowAppLog.tenant_id == app_model.tenant_id, WorkflowAppLog.app_id == app_model.id
        )

        # keyword and status filters both live on the joined WorkflowRun row
        if keyword or status:
            stmt = stmt.join(WorkflowRun, WorkflowRun.id == WorkflowAppLog.workflow_run_id)

        if keyword:
            # keyword is capped at 30 chars and unicode-escaped; NOTE(review): this looks
            # intended to match non-ASCII text stored as \uXXXX escape sequences inside the
            # inputs/outputs columns (hence doubling the backslash for LIKE) -- confirm
            # against the actual storage encoding
            keyword_like_val = f"%{keyword[:30].encode('unicode_escape').decode('utf-8')}%".replace(r"\u", r"\\u")
            keyword_conditions = [
                WorkflowRun.inputs.ilike(keyword_like_val),
                WorkflowRun.outputs.ilike(keyword_like_val),
                # filter keyword by end user session id if created by end user role
                and_(WorkflowRun.created_by_role == "end_user", EndUser.session_id.ilike(keyword_like_val)),
            ]

            # filter keyword by workflow run id
            keyword_uuid = self._safe_parse_uuid(keyword)
            if keyword_uuid:
                keyword_conditions.append(WorkflowRun.id == keyword_uuid)

            # outer join so runs created by accounts (no EndUser row) are still matched
            stmt = stmt.outerjoin(
                EndUser,
                and_(WorkflowRun.created_by == EndUser.id, WorkflowRun.created_by_role == CreatedByRole.END_USER),
            ).where(or_(*keyword_conditions))

        if status:
            stmt = stmt.where(WorkflowRun.status == status)

        # Add time-based filtering
        if created_at_before:
            stmt = stmt.where(WorkflowAppLog.created_at <= created_at_before)

        if created_at_after:
            stmt = stmt.where(WorkflowAppLog.created_at >= created_at_after)

        stmt = stmt.order_by(WorkflowAppLog.created_at.desc())

        # Get total count using the same filters
        count_stmt = select(func.count()).select_from(stmt.subquery())
        total = session.scalar(count_stmt) or 0

        # Apply pagination limits
        offset_stmt = stmt.offset((page - 1) * limit).limit(limit)

        # Execute query and get items
        items = list(session.scalars(offset_stmt).all())

        return {
            "page": page,
            "limit": limit,
            "total": total,
            "has_more": total > page * limit,
            "data": items,
        }

    @staticmethod
    def _safe_parse_uuid(value: str):
        # fast check: any textual UUID (with or without dashes) is at least 32 chars,
        # so shorter keywords can skip the parse attempt entirely
        if len(value) < 32:
            return None

        try:
            return uuid.UUID(value)
        except ValueError:
            # not a UUID; callers treat None as "no run-id filter"
            return None

View File

@@ -0,0 +1,143 @@
import threading
from typing import Optional
import contexts
from extensions.ext_database import db
from libs.infinite_scroll_pagination import InfiniteScrollPagination
from models.enums import WorkflowRunTriggeredFrom
from models.model import App
from models.workflow import (
WorkflowNodeExecution,
WorkflowNodeExecutionTriggeredFrom,
WorkflowRun,
)
class WorkflowRunService:
    def get_paginate_advanced_chat_workflow_runs(self, app_model: App, args: dict) -> InfiniteScrollPagination:
        """
        Get advanced chat app workflow run list
        Only return triggered_from == advanced_chat

        :param app_model: app model
        :param args: request args
        """

        class WorkflowWithMessage:
            # lightweight proxy adding message/conversation ids on top of a WorkflowRun;
            # all other attribute access is delegated to the wrapped run
            message_id: str
            conversation_id: str

            def __init__(self, workflow_run: WorkflowRun):
                self._workflow_run = workflow_run

            def __getattr__(self, item):
                # only called for attributes not set on the proxy itself
                return getattr(self._workflow_run, item)

        pagination = self.get_paginate_workflow_runs(app_model, args)

        with_message_workflow_runs = []
        for workflow_run in pagination.data:
            message = workflow_run.message
            with_message_workflow_run = WorkflowWithMessage(workflow_run=workflow_run)
            if message:
                with_message_workflow_run.message_id = message.id
                with_message_workflow_run.conversation_id = message.conversation_id

            with_message_workflow_runs.append(with_message_workflow_run)

        pagination.data = with_message_workflow_runs
        return pagination

    def get_paginate_workflow_runs(self, app_model: App, args: dict) -> InfiniteScrollPagination:
        """
        Get debug workflow run list
        Only return triggered_from == debugging

        Cursor-based pagination: ``args["last_id"]`` is the id of the last run the
        client already has; results are the next ``limit`` runs older than it.

        :param app_model: app model
        :param args: request args
        """
        limit = int(args.get("limit", 20))

        base_query = db.session.query(WorkflowRun).filter(
            WorkflowRun.tenant_id == app_model.tenant_id,
            WorkflowRun.app_id == app_model.id,
            WorkflowRun.triggered_from == WorkflowRunTriggeredFrom.DEBUGGING.value,
        )

        if args.get("last_id"):
            last_workflow_run = base_query.filter(
                WorkflowRun.id == args.get("last_id"),
            ).first()

            if not last_workflow_run:
                raise ValueError("Last workflow run not exists")

            workflow_runs = (
                base_query.filter(
                    WorkflowRun.created_at < last_workflow_run.created_at, WorkflowRun.id != last_workflow_run.id
                )
                .order_by(WorkflowRun.created_at.desc())
                .limit(limit)
                .all()
            )
        else:
            workflow_runs = base_query.order_by(WorkflowRun.created_at.desc()).limit(limit).all()

        # a full page *may* mean more rows exist; count what is older than the
        # last returned run to decide
        has_more = False
        if len(workflow_runs) == limit:
            current_page_first_workflow_run = workflow_runs[-1]
            rest_count = base_query.filter(
                WorkflowRun.created_at < current_page_first_workflow_run.created_at,
                WorkflowRun.id != current_page_first_workflow_run.id,
            ).count()

            if rest_count > 0:
                has_more = True

        return InfiniteScrollPagination(data=workflow_runs, limit=limit, has_more=has_more)

    def get_workflow_run(self, app_model: App, run_id: str) -> Optional[WorkflowRun]:
        """
        Get workflow run detail

        :param app_model: app model
        :param run_id: workflow run id
        """
        workflow_run = (
            db.session.query(WorkflowRun)
            .filter(
                WorkflowRun.tenant_id == app_model.tenant_id,
                WorkflowRun.app_id == app_model.id,
                WorkflowRun.id == run_id,
            )
            .first()
        )

        return workflow_run

    def get_workflow_run_node_executions(self, app_model: App, run_id: str) -> list[WorkflowNodeExecution]:
        """
        Get workflow run node execution list
        """
        workflow_run = self.get_workflow_run(app_model, run_id)

        # initialize per-request plugin tool provider context before touching executions
        contexts.plugin_tool_providers.set({})
        contexts.plugin_tool_providers_lock.set(threading.Lock())

        if not workflow_run:
            return []

        node_executions = (
            db.session.query(WorkflowNodeExecution)
            .filter(
                WorkflowNodeExecution.tenant_id == app_model.tenant_id,
                WorkflowNodeExecution.app_id == app_model.id,
                WorkflowNodeExecution.workflow_id == workflow_run.workflow_id,
                WorkflowNodeExecution.triggered_from == WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN.value,
                WorkflowNodeExecution.workflow_run_id == run_id,
            )
            .order_by(WorkflowNodeExecution.index.desc())
            .all()
        )

        return node_executions

View File

@@ -0,0 +1,519 @@
import json
import time
from collections.abc import Callable, Generator, Sequence
from datetime import UTC, datetime
from typing import Any, Optional
from uuid import uuid4
from sqlalchemy import select
from sqlalchemy.orm import Session
from core.app.apps.advanced_chat.app_config_manager import AdvancedChatAppConfigManager
from core.app.apps.workflow.app_config_manager import WorkflowAppConfigManager
from core.model_runtime.utils.encoders import jsonable_encoder
from core.variables import Variable
from core.workflow.entities.node_entities import NodeRunResult
from core.workflow.errors import WorkflowNodeRunFailedError
from core.workflow.graph_engine.entities.event import InNodeEvent
from core.workflow.nodes import NodeType
from core.workflow.nodes.base.node import BaseNode
from core.workflow.nodes.enums import ErrorStrategy
from core.workflow.nodes.event import RunCompletedEvent
from core.workflow.nodes.event.types import NodeEvent
from core.workflow.nodes.node_mapping import LATEST_VERSION, NODE_TYPE_CLASSES_MAPPING
from core.workflow.workflow_entry import WorkflowEntry
from events.app_event import app_draft_workflow_was_synced, app_published_workflow_was_updated
from extensions.ext_database import db
from models.account import Account
from models.enums import CreatedByRole
from models.model import App, AppMode
from models.workflow import (
Workflow,
WorkflowNodeExecution,
WorkflowNodeExecutionStatus,
WorkflowNodeExecutionTriggeredFrom,
WorkflowType,
)
from services.errors.app import WorkflowHashNotEqualError
from services.workflow.workflow_converter import WorkflowConverter
from .errors.workflow_service import DraftWorkflowDeletionError, WorkflowInUseError
class WorkflowService:
    """
    Workflow Service

    CRUD and lifecycle operations for app workflows: draft management and
    syncing, publishing, single-node debug runs, default node configs, and
    conversion of legacy chat/completion apps into workflow apps.
    """

    def get_draft_workflow(self, app_model: App) -> Optional[Workflow]:
        """
        Get draft workflow

        :param app_model: app the draft belongs to
        :return: the draft Workflow row, or None if the app has no draft yet
        """
        # fetch draft workflow by app_model
        # NOTE: a draft is identified by the sentinel version string "draft";
        # published versions use a timestamp string (see publish_workflow below).
        workflow = (
            db.session.query(Workflow)
            .filter(
                Workflow.tenant_id == app_model.tenant_id, Workflow.app_id == app_model.id, Workflow.version == "draft"
            )
            .first()
        )
        # return draft workflow
        return workflow

    def get_published_workflow(self, app_model: App) -> Optional[Workflow]:
        """
        Get published workflow

        :param app_model: app whose currently-published workflow is wanted
        :return: the Workflow referenced by app_model.workflow_id, or None if
                 the app has never been published
        """
        if not app_model.workflow_id:
            return None
        # fetch published workflow by workflow_id
        workflow = (
            db.session.query(Workflow)
            .filter(
                Workflow.tenant_id == app_model.tenant_id,
                Workflow.app_id == app_model.id,
                Workflow.id == app_model.workflow_id,
            )
            .first()
        )
        return workflow

    def get_all_published_workflow(
        self,
        *,
        session: Session,
        app_model: App,
        page: int,
        limit: int,
        user_id: str | None,
        named_only: bool = False,
    ) -> tuple[Sequence[Workflow], bool]:
        """
        Get published workflow with pagination

        :param session: caller-managed SQLAlchemy session
        :param app_model: app whose published versions are listed
        :param page: 1-based page number
        :param limit: page size
        :param user_id: if set, only versions created by this account
        :param named_only: if True, only versions with a non-empty marked_name
        :return: (workflows for the requested page, has_more flag)
        """
        if not app_model.workflow_id:
            # app was never published, so there are no versions to list
            return [], False
        # fetch limit + 1 rows so we can tell whether another page exists
        # without issuing a separate COUNT query
        stmt = (
            select(Workflow)
            .where(Workflow.app_id == app_model.id)
            .order_by(Workflow.version.desc())
            .limit(limit + 1)
            .offset((page - 1) * limit)
        )
        if user_id:
            stmt = stmt.where(Workflow.created_by == user_id)
        if named_only:
            stmt = stmt.where(Workflow.marked_name != "")
        workflows = session.scalars(stmt).all()
        has_more = len(workflows) > limit
        if has_more:
            # drop the sentinel row that belongs to the next page
            workflows = workflows[:-1]
        return workflows, has_more

    def sync_draft_workflow(
        self,
        *,
        app_model: App,
        graph: dict,
        features: dict,
        unique_hash: Optional[str],
        account: Account,
        environment_variables: Sequence[Variable],
        conversation_variables: Sequence[Variable],
    ) -> Workflow:
        """
        Sync draft workflow

        Creates the draft if it does not exist yet, otherwise updates it in
        place. The caller's unique_hash is compared against the stored draft's
        hash as an optimistic-concurrency check.

        :param app_model: app owning the draft
        :param graph: workflow graph definition (stored as JSON)
        :param features: app feature config (validated, stored as JSON)
        :param unique_hash: hash of the draft the client last saw; None skips
                            no check only if no draft exists yet
        :param account: account performing the sync (recorded as creator/updater)
        :param environment_variables: draft environment variables
        :param conversation_variables: draft conversation variables
        :raises WorkflowHashNotEqualError
        """
        # fetch draft workflow by app_model
        workflow = self.get_draft_workflow(app_model=app_model)
        if workflow and workflow.unique_hash != unique_hash:
            # another client modified the draft since this caller loaded it
            raise WorkflowHashNotEqualError()
        # validate features structure
        self.validate_features_structure(app_model=app_model, features=features)
        # create draft workflow if not found
        if not workflow:
            workflow = Workflow(
                tenant_id=app_model.tenant_id,
                app_id=app_model.id,
                type=WorkflowType.from_app_mode(app_model.mode).value,
                version="draft",
                graph=json.dumps(graph),
                features=json.dumps(features),
                created_by=account.id,
                environment_variables=environment_variables,
                conversation_variables=conversation_variables,
            )
            db.session.add(workflow)
        # update draft workflow if found
        else:
            workflow.graph = json.dumps(graph)
            workflow.features = json.dumps(features)
            workflow.updated_by = account.id
            # timestamps are stored as naive UTC throughout this service
            workflow.updated_at = datetime.now(UTC).replace(tzinfo=None)
            workflow.environment_variables = environment_variables
            workflow.conversation_variables = conversation_variables
        # commit db session changes
        db.session.commit()
        # trigger app workflow events
        app_draft_workflow_was_synced.send(app_model, synced_draft_workflow=workflow)
        # return draft workflow
        return workflow

    def publish_workflow(
        self,
        *,
        session: Session,
        app_model: App,
        account: Account,
        marked_name: str = "",
        marked_comment: str = "",
    ) -> Workflow:
        """
        Publish the app's current draft as a new immutable workflow version.

        The new row is added to *session* but not committed here — the caller
        owns the transaction. NOTE(review): app_model.workflow_id is not
        updated in this method; presumably a listener of
        app_published_workflow_was_updated does that — confirm against callers.

        :param session: caller-managed SQLAlchemy session
        :param app_model: app to publish
        :param account: publishing account (recorded as creator)
        :param marked_name: optional human-readable version name
        :param marked_comment: optional version comment
        :return: the newly created Workflow version
        :raises ValueError: if the app has no draft workflow
        """
        draft_workflow_stmt = select(Workflow).where(
            Workflow.tenant_id == app_model.tenant_id,
            Workflow.app_id == app_model.id,
            Workflow.version == "draft",
        )
        draft_workflow = session.scalar(draft_workflow_stmt)
        if not draft_workflow:
            raise ValueError("No valid workflow found.")
        # create new workflow
        # the version string is the publish timestamp (naive UTC), which also
        # makes version order sortable in get_all_published_workflow
        workflow = Workflow.new(
            tenant_id=app_model.tenant_id,
            app_id=app_model.id,
            type=draft_workflow.type,
            version=str(datetime.now(UTC).replace(tzinfo=None)),
            graph=draft_workflow.graph,
            features=draft_workflow.features,
            created_by=account.id,
            environment_variables=draft_workflow.environment_variables,
            conversation_variables=draft_workflow.conversation_variables,
            marked_name=marked_name,
            marked_comment=marked_comment,
        )
        # commit db session changes
        session.add(workflow)
        # trigger app workflow events
        app_published_workflow_was_updated.send(app_model, published_workflow=workflow)
        # return new workflow
        return workflow

    def get_default_block_configs(self) -> list[dict]:
        """
        Get default block configs

        :return: default config dicts for every registered node type that
                 provides one (node types without a default are skipped)
        """
        # return default block config
        default_block_configs = []
        for node_class_mapping in NODE_TYPE_CLASSES_MAPPING.values():
            # always use the latest version of each node implementation
            node_class = node_class_mapping[LATEST_VERSION]
            default_config = node_class.get_default_config()
            if default_config:
                default_block_configs.append(default_config)
        return default_block_configs

    def get_default_block_config(self, node_type: str, filters: Optional[dict] = None) -> Optional[dict]:
        """
        Get default config of node.
        :param node_type: node type
        :param filters: filter by node config parameters.
        :return: the default config dict, or None if the node type is unknown
                 or provides no default
        :raises ValueError: if node_type is not a valid NodeType value
        """
        node_type_enum = NodeType(node_type)
        # return default block config
        if node_type_enum not in NODE_TYPE_CLASSES_MAPPING:
            return None
        node_class = NODE_TYPE_CLASSES_MAPPING[node_type_enum][LATEST_VERSION]
        default_config = node_class.get_default_config(filters=filters)
        if not default_config:
            return None
        return default_config

    def run_draft_workflow_node(
        self, app_model: App, node_id: str, user_inputs: dict, account: Account
    ) -> WorkflowNodeExecution:
        """
        Run draft workflow node

        Executes a single node of the app's draft workflow (single-step debug)
        and persists the resulting WorkflowNodeExecution record.

        :param app_model: app whose draft workflow contains the node
        :param node_id: id of the node to run
        :param user_inputs: inputs supplied by the user for this node
        :param account: account performing the run
        :raises ValueError: if the app has no draft workflow
        """
        # fetch draft workflow by app_model
        draft_workflow = self.get_draft_workflow(app_model=app_model)
        if not draft_workflow:
            raise ValueError("Workflow not initialized")
        # run draft workflow node
        start_at = time.perf_counter()
        workflow_node_execution = self._handle_node_run_result(
            getter=lambda: WorkflowEntry.single_step_run(
                workflow=draft_workflow,
                node_id=node_id,
                user_inputs=user_inputs,
                user_id=account.id,
            ),
            start_at=start_at,
            tenant_id=app_model.tenant_id,
            node_id=node_id,
        )
        workflow_node_execution.app_id = app_model.id
        workflow_node_execution.created_by = account.id
        workflow_node_execution.workflow_id = draft_workflow.id
        db.session.add(workflow_node_execution)
        db.session.commit()
        return workflow_node_execution

    def run_free_workflow_node(
        self, node_data: dict, tenant_id: str, user_id: str, node_id: str, user_inputs: dict[str, Any]
    ) -> WorkflowNodeExecution:
        """
        Run draft workflow node

        Runs a standalone node (not attached to a stored workflow) from raw
        node data. Unlike run_draft_workflow_node, the resulting execution
        record is NOT persisted to the database here.

        :param node_data: raw node definition
        :param tenant_id: tenant the run belongs to
        :param user_id: account id performing the run
        :param node_id: id of the node to run
        :param user_inputs: inputs supplied by the user for this node
        """
        # run draft workflow node
        start_at = time.perf_counter()
        workflow_node_execution = self._handle_node_run_result(
            getter=lambda: WorkflowEntry.run_free_node(
                node_id=node_id,
                node_data=node_data,
                tenant_id=tenant_id,
                user_id=user_id,
                user_inputs=user_inputs,
            ),
            start_at=start_at,
            tenant_id=tenant_id,
            node_id=node_id,
        )
        return workflow_node_execution

    def _handle_node_run_result(
        self,
        getter: Callable[[], tuple[BaseNode, Generator[NodeEvent | InNodeEvent, None, None]]],
        start_at: float,
        tenant_id: str,
        node_id: str,
    ) -> WorkflowNodeExecution:
        """
        Handle node run result

        Drives the generator returned by *getter* until the node completes,
        applies the node's continue-on-error strategy if needed, and maps the
        outcome onto a (non-persisted) WorkflowNodeExecution record.

        :param getter: Callable[[], tuple[BaseNode, Generator[RunEvent | InNodeEvent, None, None]]]
        :param start_at: float
        :param tenant_id: str
        :param node_id: str
        """
        try:
            node_instance, generator = getter()
            node_run_result: NodeRunResult | None = None
            for event in generator:
                # only the completion event carries the final result; other
                # intermediate events are ignored in single-step mode
                if isinstance(event, RunCompletedEvent):
                    node_run_result = event.run_result
                    # sign output files
                    node_run_result.outputs = WorkflowEntry.handle_special_values(node_run_result.outputs)
                    break
            if not node_run_result:
                raise ValueError("Node run failed with no run result")
            # single step debug mode error handling return
            # a FAILED result on a continue-on-error node is downgraded to
            # EXCEPTION so the debug UI can show the applied error strategy
            if node_run_result.status == WorkflowNodeExecutionStatus.FAILED and node_instance.should_continue_on_error:
                node_error_args: dict[str, Any] = {
                    "status": WorkflowNodeExecutionStatus.EXCEPTION,
                    "error": node_run_result.error,
                    "inputs": node_run_result.inputs,
                    "metadata": {"error_strategy": node_instance.node_data.error_strategy},
                }
                if node_instance.node_data.error_strategy is ErrorStrategy.DEFAULT_VALUE:
                    # DEFAULT_VALUE strategy: outputs are the node's configured
                    # defaults plus the error details
                    node_run_result = NodeRunResult(
                        **node_error_args,
                        outputs={
                            **node_instance.node_data.default_value_dict,
                            "error_message": node_run_result.error,
                            "error_type": node_run_result.error_type,
                        },
                    )
                else:
                    # other strategies: outputs carry only the error details
                    node_run_result = NodeRunResult(
                        **node_error_args,
                        outputs={
                            "error_message": node_run_result.error,
                            "error_type": node_run_result.error_type,
                        },
                    )
            # EXCEPTION counts as "succeeded" here: the node finished via its
            # error strategy rather than aborting the run
            run_succeeded = node_run_result.status in (
                WorkflowNodeExecutionStatus.SUCCEEDED,
                WorkflowNodeExecutionStatus.EXCEPTION,
            )
            error = node_run_result.error if not run_succeeded else None
        except WorkflowNodeRunFailedError as e:
            # the entry point failed before producing a result; the raising
            # node instance is attached to the exception
            node_instance = e.node_instance
            run_succeeded = False
            node_run_result = None
            error = e.error
        workflow_node_execution = WorkflowNodeExecution()
        workflow_node_execution.id = str(uuid4())
        workflow_node_execution.tenant_id = tenant_id
        workflow_node_execution.triggered_from = WorkflowNodeExecutionTriggeredFrom.SINGLE_STEP.value
        # single-step runs always record index 1 (only one node executed)
        workflow_node_execution.index = 1
        workflow_node_execution.node_id = node_id
        workflow_node_execution.node_type = node_instance.node_type
        workflow_node_execution.title = node_instance.node_data.title
        workflow_node_execution.elapsed_time = time.perf_counter() - start_at
        workflow_node_execution.created_by_role = CreatedByRole.ACCOUNT.value
        workflow_node_execution.created_at = datetime.now(UTC).replace(tzinfo=None)
        workflow_node_execution.finished_at = datetime.now(UTC).replace(tzinfo=None)
        if run_succeeded and node_run_result:
            # create workflow node execution
            inputs = WorkflowEntry.handle_special_values(node_run_result.inputs) if node_run_result.inputs else None
            process_data = (
                WorkflowEntry.handle_special_values(node_run_result.process_data)
                if node_run_result.process_data
                else None
            )
            outputs = WorkflowEntry.handle_special_values(node_run_result.outputs) if node_run_result.outputs else None
            # NOTE(review): json.dumps(None) stores the string "null" rather
            # than SQL NULL for empty inputs/process_data/outputs — confirm
            # readers expect that
            workflow_node_execution.inputs = json.dumps(inputs)
            workflow_node_execution.process_data = json.dumps(process_data)
            workflow_node_execution.outputs = json.dumps(outputs)
            workflow_node_execution.execution_metadata = (
                json.dumps(jsonable_encoder(node_run_result.metadata)) if node_run_result.metadata else None
            )
            if node_run_result.status == WorkflowNodeExecutionStatus.SUCCEEDED:
                workflow_node_execution.status = WorkflowNodeExecutionStatus.SUCCEEDED.value
            elif node_run_result.status == WorkflowNodeExecutionStatus.EXCEPTION:
                workflow_node_execution.status = WorkflowNodeExecutionStatus.EXCEPTION.value
                workflow_node_execution.error = node_run_result.error
        else:
            # create workflow node execution
            workflow_node_execution.status = WorkflowNodeExecutionStatus.FAILED.value
            workflow_node_execution.error = error
        return workflow_node_execution

    def convert_to_workflow(self, app_model: App, account: Account, args: dict) -> App:
        """
        Basic mode of chatbot app(expert mode) to workflow
        Completion App to Workflow App
        :param app_model: App instance
        :param account: Account instance
        :param args: dict with optional keys "name", "icon_type", "icon",
                     "icon_background" for the new app's presentation
        :return: the newly created workflow-mode App
        :raises ValueError: if the app mode is not chat or completion
        """
        # chatbot convert to workflow mode
        workflow_converter = WorkflowConverter()
        if app_model.mode not in {AppMode.CHAT.value, AppMode.COMPLETION.value}:
            raise ValueError(f"Current App mode: {app_model.mode} is not supported convert to workflow.")
        # convert to workflow
        new_app: App = workflow_converter.convert_to_workflow(
            app_model=app_model,
            account=account,
            name=args.get("name", "Default Name"),
            icon_type=args.get("icon_type", "emoji"),
            icon=args.get("icon", "🤖"),
            icon_background=args.get("icon_background", "#FFEAD5"),
        )
        return new_app

    def validate_features_structure(self, app_model: App, features: dict) -> dict:
        """
        Validate the structure of a features dict for the app's mode.

        Delegates to the mode-specific config manager with
        only_structure_validate=True (shape check, not full validation).

        :raises ValueError: if the app mode has no feature validator
        """
        if app_model.mode == AppMode.ADVANCED_CHAT.value:
            return AdvancedChatAppConfigManager.config_validate(
                tenant_id=app_model.tenant_id, config=features, only_structure_validate=True
            )
        elif app_model.mode == AppMode.WORKFLOW.value:
            return WorkflowAppConfigManager.config_validate(
                tenant_id=app_model.tenant_id, config=features, only_structure_validate=True
            )
        else:
            raise ValueError(f"Invalid app mode: {app_model.mode}")

    def update_workflow(
        self, *, session: Session, workflow_id: str, tenant_id: str, account_id: str, data: dict
    ) -> Optional[Workflow]:
        """
        Update workflow attributes
        :param session: SQLAlchemy database session
        :param workflow_id: Workflow ID
        :param tenant_id: Tenant ID
        :param account_id: Account ID (for permission check)
        :param data: Dictionary containing fields to update
        :return: Updated workflow or None if not found
        """
        stmt = select(Workflow).where(Workflow.id == workflow_id, Workflow.tenant_id == tenant_id)
        workflow = session.scalar(stmt)
        if not workflow:
            return None
        # only the display fields are updatable; anything else in *data*
        # is silently ignored
        allowed_fields = ["marked_name", "marked_comment"]
        for field, value in data.items():
            if field in allowed_fields:
                setattr(workflow, field, value)
        workflow.updated_by = account_id
        workflow.updated_at = datetime.now(UTC).replace(tzinfo=None)
        # caller owns the session; no commit here
        return workflow

    def delete_workflow(self, *, session: Session, workflow_id: str, tenant_id: str) -> bool:
        """
        Delete a workflow
        :param session: SQLAlchemy database session
        :param workflow_id: Workflow ID
        :param tenant_id: Tenant ID
        :return: True if successful
        :raises: ValueError if workflow not found
        :raises: WorkflowInUseError if workflow is in use
        :raises: DraftWorkflowDeletionError if workflow is a draft version
        """
        stmt = select(Workflow).where(Workflow.id == workflow_id, Workflow.tenant_id == tenant_id)
        workflow = session.scalar(stmt)
        if not workflow:
            raise ValueError(f"Workflow with ID {workflow_id} not found")
        # Check if workflow is a draft version
        if workflow.version == "draft":
            raise DraftWorkflowDeletionError("Cannot delete draft workflow versions")
        # Check if this workflow is currently referenced by an app
        stmt = select(App).where(App.workflow_id == workflow_id)
        app = session.scalar(stmt)
        if app:
            # Cannot delete a workflow that's currently in use by an app
            raise WorkflowInUseError(f"Cannot delete workflow that is currently in use by app '{app.name}'")
        session.delete(workflow)
        # caller owns the session; no commit here
        return True

View File

@@ -0,0 +1,53 @@
from flask_login import current_user # type: ignore
from configs import dify_config
from extensions.ext_database import db
from models.account import Tenant, TenantAccountJoin, TenantAccountJoinRole
from services.account_service import TenantService
from services.feature_service import FeatureService
class WorkspaceService:
    """Assembles workspace (tenant) info payloads for the current user."""

    @classmethod
    def get_tenant_info(cls, tenant: Tenant):
        """Return an info dict for *tenant*, including the current user's role
        and — when permitted — the workspace's webapp branding overrides.

        Returns None when no tenant is given.
        """
        if not tenant:
            return None

        # Resolve the current user's membership in this tenant.
        membership = (
            db.session.query(TenantAccountJoin)
            .filter(TenantAccountJoin.tenant_id == tenant.id, TenantAccountJoin.account_id == current_user.id)
            .first()
        )
        assert membership is not None, "TenantAccountJoin not found"

        # NOTE(review): "in_trail" looks like a typo for "in_trial", but it is
        # part of the API payload consumed by clients, so it is kept as-is.
        info = {
            "id": tenant.id,
            "name": tenant.name,
            "plan": tenant.plan,
            "status": tenant.status,
            "created_at": tenant.created_at,
            "in_trail": True,
            "trial_end_reason": None,
            "role": membership.role,
        }

        # Branding overrides are exposed only when the feature is enabled for
        # the workspace AND the current user is an owner or admin.
        allowed_to_customize = FeatureService.get_features(info["id"]).can_replace_logo and TenantService.has_roles(
            tenant, [TenantAccountJoinRole.OWNER, TenantAccountJoinRole.ADMIN]
        )
        if allowed_to_customize:
            files_base = dify_config.FILES_URL
            custom_logo = None
            if tenant.custom_config_dict.get("replace_webapp_logo"):
                custom_logo = f"{files_base}/files/workspaces/{tenant.id}/webapp-logo"
            info["custom_config"] = {
                "remove_webapp_brand": tenant.custom_config_dict.get("remove_webapp_brand", False),
                "replace_webapp_logo": custom_logo,
            }

        return info