Initial commit

This commit is contained in:
2025-10-14 14:17:21 +08:00
commit ac715a8b88
35011 changed files with 3834178 additions and 0 deletions

View File

@@ -0,0 +1,11 @@
from flask import Blueprint
from libs.external_api import ExternalApi
bp = Blueprint("service_api", __name__, url_prefix="/v1")
api = ExternalApi(bp)
from . import index
from .app import annotation, app, audio, completion, conversation, file, message, site, workflow
from .dataset import dataset, document, hit_testing, metadata, segment, upload_file
from .workspace import models

View File

@@ -0,0 +1,107 @@
from flask import request
from flask_restful import Resource, marshal, marshal_with, reqparse
from werkzeug.exceptions import Forbidden
from controllers.service_api import api
from controllers.service_api.wraps import FetchUserArg, WhereisUserArg, validate_app_token
from extensions.ext_redis import redis_client
from fields.annotation_fields import (
annotation_fields,
)
from libs.login import current_user
from models.model import App, EndUser
from services.annotation_service import AppAnnotationService
class AnnotationReplyActionApi(Resource):
@validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.JSON))
def post(self, app_model: App, end_user: EndUser, action):
parser = reqparse.RequestParser()
parser.add_argument("score_threshold", required=True, type=float, location="json")
parser.add_argument("embedding_provider_name", required=True, type=str, location="json")
parser.add_argument("embedding_model_name", required=True, type=str, location="json")
args = parser.parse_args()
if action == "enable":
result = AppAnnotationService.enable_app_annotation(args, app_model.id)
elif action == "disable":
result = AppAnnotationService.disable_app_annotation(app_model.id)
else:
raise ValueError("Unsupported annotation reply action")
return result, 200
class AnnotationReplyActionStatusApi(Resource):
@validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.QUERY))
def get(self, app_model: App, end_user: EndUser, job_id, action):
job_id = str(job_id)
app_annotation_job_key = "{}_app_annotation_job_{}".format(action, str(job_id))
cache_result = redis_client.get(app_annotation_job_key)
if cache_result is None:
raise ValueError("The job does not exist.")
job_status = cache_result.decode()
error_msg = ""
if job_status == "error":
app_annotation_error_key = "{}_app_annotation_error_{}".format(action, str(job_id))
error_msg = redis_client.get(app_annotation_error_key).decode()
return {"job_id": job_id, "job_status": job_status, "error_msg": error_msg}, 200
class AnnotationListApi(Resource):
@validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.QUERY))
def get(self, app_model: App, end_user: EndUser):
page = request.args.get("page", default=1, type=int)
limit = request.args.get("limit", default=20, type=int)
keyword = request.args.get("keyword", default="", type=str)
annotation_list, total = AppAnnotationService.get_annotation_list_by_app_id(app_model.id, page, limit, keyword)
response = {
"data": marshal(annotation_list, annotation_fields),
"has_more": len(annotation_list) == limit,
"limit": limit,
"total": total,
"page": page,
}
return response, 200
@validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.JSON))
@marshal_with(annotation_fields)
def post(self, app_model: App, end_user: EndUser):
parser = reqparse.RequestParser()
parser.add_argument("question", required=True, type=str, location="json")
parser.add_argument("answer", required=True, type=str, location="json")
args = parser.parse_args()
annotation = AppAnnotationService.insert_app_annotation_directly(args, app_model.id)
return annotation
class AnnotationUpdateDeleteApi(Resource):
@validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.JSON))
@marshal_with(annotation_fields)
def put(self, app_model: App, end_user: EndUser, annotation_id):
if not current_user.is_editor:
raise Forbidden()
annotation_id = str(annotation_id)
parser = reqparse.RequestParser()
parser.add_argument("question", required=True, type=str, location="json")
parser.add_argument("answer", required=True, type=str, location="json")
args = parser.parse_args()
annotation = AppAnnotationService.update_app_annotation_directly(args, app_model.id, annotation_id)
return annotation
@validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.QUERY))
def delete(self, app_model: App, end_user: EndUser, annotation_id):
if not current_user.is_editor:
raise Forbidden()
annotation_id = str(annotation_id)
AppAnnotationService.delete_app_annotation(app_model.id, annotation_id)
return {"result": "success"}, 204
api.add_resource(AnnotationReplyActionApi, "/apps/annotation-reply/<string:action>")
api.add_resource(AnnotationReplyActionStatusApi, "/apps/annotation-reply/<string:action>/status/<uuid:job_id>")
api.add_resource(AnnotationListApi, "/apps/annotations")
api.add_resource(AnnotationUpdateDeleteApi, "/apps/annotations/<uuid:annotation_id>")

View File

@@ -0,0 +1,55 @@
from flask_restful import Resource, marshal_with
from controllers.common import fields
from controllers.service_api import api
from controllers.service_api.app.error import AppUnavailableError
from controllers.service_api.wraps import validate_app_token
from core.app.app_config.common.parameters_mapping import get_parameters_from_feature_dict
from models.model import App, AppMode
from services.app_service import AppService
class AppParameterApi(Resource):
"""Resource for app variables."""
@validate_app_token
@marshal_with(fields.parameters_fields)
def get(self, app_model: App):
"""Retrieve app parameters."""
if app_model.mode in {AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value}:
workflow = app_model.workflow
if workflow is None:
raise AppUnavailableError()
features_dict = workflow.features_dict
user_input_form = workflow.user_input_form(to_old_structure=True)
else:
app_model_config = app_model.app_model_config
if app_model_config is None:
raise AppUnavailableError()
features_dict = app_model_config.to_dict()
user_input_form = features_dict.get("user_input_form", [])
return get_parameters_from_feature_dict(features_dict=features_dict, user_input_form=user_input_form)
class AppMetaApi(Resource):
@validate_app_token
def get(self, app_model: App):
"""Get app meta"""
return AppService().get_app_meta(app_model)
class AppInfoApi(Resource):
@validate_app_token
def get(self, app_model: App):
"""Get app information"""
tags = [tag.name for tag in app_model.tags]
return {"name": app_model.name, "description": app_model.description, "tags": tags, "mode": app_model.mode}
api.add_resource(AppParameterApi, "/parameters")
api.add_resource(AppMetaApi, "/meta")
api.add_resource(AppInfoApi, "/info")

View File

@@ -0,0 +1,125 @@
import logging
from flask import request
from flask_restful import Resource, reqparse
from werkzeug.exceptions import InternalServerError
import services
from controllers.service_api import api
from controllers.service_api.app.error import (
AppUnavailableError,
AudioTooLargeError,
CompletionRequestError,
NoAudioUploadedError,
ProviderModelCurrentlyNotSupportError,
ProviderNotInitializeError,
ProviderNotSupportSpeechToTextError,
ProviderQuotaExceededError,
UnsupportedAudioTypeError,
)
from controllers.service_api.wraps import FetchUserArg, WhereisUserArg, validate_app_token
from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
from core.model_runtime.errors.invoke import InvokeError
from models.model import App, AppMode, EndUser
from services.audio_service import AudioService
from services.errors.audio import (
AudioTooLargeServiceError,
NoAudioUploadedServiceError,
ProviderNotSupportSpeechToTextServiceError,
UnsupportedAudioTypeServiceError,
)
class AudioApi(Resource):
@validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.FORM))
def post(self, app_model: App, end_user: EndUser):
file = request.files["file"]
try:
response = AudioService.transcript_asr(app_model=app_model, file=file, end_user=end_user)
return response
except services.errors.app_model_config.AppModelConfigBrokenError:
logging.exception("App model config broken.")
raise AppUnavailableError()
except NoAudioUploadedServiceError:
raise NoAudioUploadedError()
except AudioTooLargeServiceError as e:
raise AudioTooLargeError(str(e))
except UnsupportedAudioTypeServiceError:
raise UnsupportedAudioTypeError()
except ProviderNotSupportSpeechToTextServiceError:
raise ProviderNotSupportSpeechToTextError()
except ProviderTokenNotInitError as ex:
raise ProviderNotInitializeError(ex.description)
except QuotaExceededError:
raise ProviderQuotaExceededError()
except ModelCurrentlyNotSupportError:
raise ProviderModelCurrentlyNotSupportError()
except InvokeError as e:
raise CompletionRequestError(e.description)
except ValueError as e:
raise e
except Exception as e:
logging.exception("internal server error.")
raise InternalServerError()
class TextApi(Resource):
@validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.JSON))
def post(self, app_model: App, end_user: EndUser):
try:
parser = reqparse.RequestParser()
parser.add_argument("message_id", type=str, required=False, location="json")
parser.add_argument("voice", type=str, location="json")
parser.add_argument("text", type=str, location="json")
parser.add_argument("streaming", type=bool, location="json")
args = parser.parse_args()
message_id = args.get("message_id", None)
text = args.get("text", None)
if (
app_model.mode in {AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value}
and app_model.workflow
and app_model.workflow.features_dict
):
text_to_speech = app_model.workflow.features_dict.get("text_to_speech", {})
voice = args.get("voice") or text_to_speech.get("voice")
else:
try:
voice = args.get("voice") or app_model.app_model_config.text_to_speech_dict.get("voice")
except Exception:
voice = None
response = AudioService.transcript_tts(
app_model=app_model, message_id=message_id, end_user=end_user.external_user_id, voice=voice, text=text
)
return response
except services.errors.app_model_config.AppModelConfigBrokenError:
logging.exception("App model config broken.")
raise AppUnavailableError()
except NoAudioUploadedServiceError:
raise NoAudioUploadedError()
except AudioTooLargeServiceError as e:
raise AudioTooLargeError(str(e))
except UnsupportedAudioTypeServiceError:
raise UnsupportedAudioTypeError()
except ProviderNotSupportSpeechToTextServiceError:
raise ProviderNotSupportSpeechToTextError()
except ProviderTokenNotInitError as ex:
raise ProviderNotInitializeError(ex.description)
except QuotaExceededError:
raise ProviderQuotaExceededError()
except ModelCurrentlyNotSupportError:
raise ProviderModelCurrentlyNotSupportError()
except InvokeError as e:
raise CompletionRequestError(e.description)
except ValueError as e:
raise e
except Exception as e:
logging.exception("internal server error.")
raise InternalServerError()
api.add_resource(AudioApi, "/audio-to-text")
api.add_resource(TextApi, "/text-to-audio")

View File

@@ -0,0 +1,161 @@
import logging
from flask_restful import Resource, reqparse
from werkzeug.exceptions import InternalServerError, NotFound
import services
from controllers.service_api import api
from controllers.service_api.app.error import (
AppUnavailableError,
CompletionRequestError,
ConversationCompletedError,
NotChatAppError,
ProviderModelCurrentlyNotSupportError,
ProviderNotInitializeError,
ProviderQuotaExceededError,
)
from controllers.service_api.wraps import FetchUserArg, WhereisUserArg, validate_app_token
from controllers.web.error import InvokeRateLimitError as InvokeRateLimitHttpError
from core.app.apps.base_app_queue_manager import AppQueueManager
from core.app.entities.app_invoke_entities import InvokeFrom
from core.errors.error import (
ModelCurrentlyNotSupportError,
ProviderTokenNotInitError,
QuotaExceededError,
)
from core.model_runtime.errors.invoke import InvokeError
from libs import helper
from libs.helper import uuid_value
from models.model import App, AppMode, EndUser
from services.app_generate_service import AppGenerateService
from services.errors.llm import InvokeRateLimitError
class CompletionApi(Resource):
@validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.JSON, required=True))
def post(self, app_model: App, end_user: EndUser):
if app_model.mode != "completion":
raise AppUnavailableError()
parser = reqparse.RequestParser()
parser.add_argument("inputs", type=dict, required=True, location="json")
parser.add_argument("query", type=str, location="json", default="")
parser.add_argument("files", type=list, required=False, location="json")
parser.add_argument("response_mode", type=str, choices=["blocking", "streaming"], location="json")
parser.add_argument("retriever_from", type=str, required=False, default="dev", location="json")
args = parser.parse_args()
streaming = args["response_mode"] == "streaming"
args["auto_generate_name"] = False
try:
response = AppGenerateService.generate(
app_model=app_model,
user=end_user,
args=args,
invoke_from=InvokeFrom.SERVICE_API,
streaming=streaming,
)
return helper.compact_generate_response(response)
except services.errors.conversation.ConversationNotExistsError:
raise NotFound("Conversation Not Exists.")
except services.errors.conversation.ConversationCompletedError:
raise ConversationCompletedError()
except services.errors.app_model_config.AppModelConfigBrokenError:
logging.exception("App model config broken.")
raise AppUnavailableError()
except ProviderTokenNotInitError as ex:
raise ProviderNotInitializeError(ex.description)
except QuotaExceededError:
raise ProviderQuotaExceededError()
except ModelCurrentlyNotSupportError:
raise ProviderModelCurrentlyNotSupportError()
except InvokeError as e:
raise CompletionRequestError(e.description)
except ValueError as e:
raise e
except Exception:
logging.exception("internal server error.")
raise InternalServerError()
class CompletionStopApi(Resource):
@validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.JSON, required=True))
def post(self, app_model: App, end_user: EndUser, task_id):
if app_model.mode != "completion":
raise AppUnavailableError()
AppQueueManager.set_stop_flag(task_id, InvokeFrom.SERVICE_API, end_user.id)
return {"result": "success"}, 200
class ChatApi(Resource):
@validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.JSON, required=True))
def post(self, app_model: App, end_user: EndUser):
app_mode = AppMode.value_of(app_model.mode)
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
raise NotChatAppError()
parser = reqparse.RequestParser()
parser.add_argument("inputs", type=dict, required=True, location="json")
parser.add_argument("query", type=str, required=True, location="json")
parser.add_argument("files", type=list, required=False, location="json")
parser.add_argument("response_mode", type=str, choices=["blocking", "streaming"], location="json")
parser.add_argument("conversation_id", type=uuid_value, location="json")
parser.add_argument("retriever_from", type=str, required=False, default="dev", location="json")
parser.add_argument("auto_generate_name", type=bool, required=False, default=True, location="json")
args = parser.parse_args()
streaming = args["response_mode"] == "streaming"
try:
response = AppGenerateService.generate(
app_model=app_model, user=end_user, args=args, invoke_from=InvokeFrom.SERVICE_API, streaming=streaming
)
return helper.compact_generate_response(response)
except services.errors.conversation.ConversationNotExistsError:
raise NotFound("Conversation Not Exists.")
except services.errors.conversation.ConversationCompletedError:
raise ConversationCompletedError()
except services.errors.app_model_config.AppModelConfigBrokenError:
logging.exception("App model config broken.")
raise AppUnavailableError()
except ProviderTokenNotInitError as ex:
raise ProviderNotInitializeError(ex.description)
except QuotaExceededError:
raise ProviderQuotaExceededError()
except ModelCurrentlyNotSupportError:
raise ProviderModelCurrentlyNotSupportError()
except InvokeRateLimitError as ex:
raise InvokeRateLimitHttpError(ex.description)
except InvokeError as e:
raise CompletionRequestError(e.description)
except ValueError as e:
raise e
except Exception:
logging.exception("internal server error.")
raise InternalServerError()
class ChatStopApi(Resource):
@validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.JSON, required=True))
def post(self, app_model: App, end_user: EndUser, task_id):
app_mode = AppMode.value_of(app_model.mode)
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
raise NotChatAppError()
AppQueueManager.set_stop_flag(task_id, InvokeFrom.SERVICE_API, end_user.id)
return {"result": "success"}, 200
api.add_resource(CompletionApi, "/completion-messages")
api.add_resource(CompletionStopApi, "/completion-messages/<string:task_id>/stop")
api.add_resource(ChatApi, "/chat-messages")
api.add_resource(ChatStopApi, "/chat-messages/<string:task_id>/stop")

View File

@@ -0,0 +1,126 @@
from flask_restful import Resource, marshal_with, reqparse
from flask_restful.inputs import int_range
from sqlalchemy.orm import Session
from werkzeug.exceptions import NotFound
import services
from controllers.service_api import api
from controllers.service_api.app.error import NotChatAppError
from controllers.service_api.wraps import FetchUserArg, WhereisUserArg, validate_app_token
from core.app.entities.app_invoke_entities import InvokeFrom
from extensions.ext_database import db
from fields.conversation_fields import (
conversation_delete_fields,
conversation_infinite_scroll_pagination_fields,
simple_conversation_fields,
)
from fields.conversation_variable_fields import (
conversation_variable_infinite_scroll_pagination_fields,
)
from libs.helper import uuid_value
from models.model import App, AppMode, EndUser
from services.conversation_service import ConversationService
class ConversationApi(Resource):
@validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.QUERY))
@marshal_with(conversation_infinite_scroll_pagination_fields)
def get(self, app_model: App, end_user: EndUser):
app_mode = AppMode.value_of(app_model.mode)
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
raise NotChatAppError()
parser = reqparse.RequestParser()
parser.add_argument("last_id", type=uuid_value, location="args")
parser.add_argument("limit", type=int_range(1, 100), required=False, default=20, location="args")
parser.add_argument(
"sort_by",
type=str,
choices=["created_at", "-created_at", "updated_at", "-updated_at"],
required=False,
default="-updated_at",
location="args",
)
args = parser.parse_args()
try:
with Session(db.engine) as session:
return ConversationService.pagination_by_last_id(
session=session,
app_model=app_model,
user=end_user,
last_id=args["last_id"],
limit=args["limit"],
invoke_from=InvokeFrom.SERVICE_API,
sort_by=args["sort_by"],
)
except services.errors.conversation.LastConversationNotExistsError:
raise NotFound("Last Conversation Not Exists.")
class ConversationDetailApi(Resource):
@validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.JSON))
@marshal_with(conversation_delete_fields)
def delete(self, app_model: App, end_user: EndUser, c_id):
app_mode = AppMode.value_of(app_model.mode)
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
raise NotChatAppError()
conversation_id = str(c_id)
try:
ConversationService.delete(app_model, conversation_id, end_user)
except services.errors.conversation.ConversationNotExistsError:
raise NotFound("Conversation Not Exists.")
return {"result": "success"}, 204
class ConversationRenameApi(Resource):
@validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.JSON))
@marshal_with(simple_conversation_fields)
def post(self, app_model: App, end_user: EndUser, c_id):
app_mode = AppMode.value_of(app_model.mode)
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
raise NotChatAppError()
conversation_id = str(c_id)
parser = reqparse.RequestParser()
parser.add_argument("name", type=str, required=False, location="json")
parser.add_argument("auto_generate", type=bool, required=False, default=False, location="json")
args = parser.parse_args()
try:
return ConversationService.rename(app_model, conversation_id, end_user, args["name"], args["auto_generate"])
except services.errors.conversation.ConversationNotExistsError:
raise NotFound("Conversation Not Exists.")
class ConversationVariablesApi(Resource):
@validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.QUERY))
@marshal_with(conversation_variable_infinite_scroll_pagination_fields)
def get(self, app_model: App, end_user: EndUser, c_id):
# conversational variable only for chat app
app_mode = AppMode.value_of(app_model.mode)
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
raise NotChatAppError()
conversation_id = str(c_id)
parser = reqparse.RequestParser()
parser.add_argument("last_id", type=uuid_value, location="args")
parser.add_argument("limit", type=int_range(1, 100), required=False, default=20, location="args")
args = parser.parse_args()
try:
return ConversationService.get_conversational_variable(
app_model, conversation_id, end_user, args["limit"], args["last_id"]
)
except services.errors.conversation.ConversationNotExistsError:
raise NotFound("Conversation Not Exists.")
api.add_resource(ConversationRenameApi, "/conversations/<uuid:c_id>/name", endpoint="conversation_name")
api.add_resource(ConversationApi, "/conversations")
api.add_resource(ConversationDetailApi, "/conversations/<uuid:c_id>", endpoint="conversation_detail")
api.add_resource(ConversationVariablesApi, "/conversations/<uuid:c_id>/variables", endpoint="conversation_variables")

View File

@@ -0,0 +1,109 @@
from libs.exception import BaseHTTPException
class AppUnavailableError(BaseHTTPException):
error_code = "app_unavailable"
description = "App unavailable, please check your app configurations."
code = 400
class NotCompletionAppError(BaseHTTPException):
error_code = "not_completion_app"
description = "Please check if your Completion app mode matches the right API route."
code = 400
class NotChatAppError(BaseHTTPException):
error_code = "not_chat_app"
description = "Please check if your app mode matches the right API route."
code = 400
class NotWorkflowAppError(BaseHTTPException):
error_code = "not_workflow_app"
description = "Please check if your app mode matches the right API route."
code = 400
class ConversationCompletedError(BaseHTTPException):
error_code = "conversation_completed"
description = "The conversation has ended. Please start a new conversation."
code = 400
class ProviderNotInitializeError(BaseHTTPException):
error_code = "provider_not_initialize"
description = (
"No valid model provider credentials found. "
"Please go to Settings -> Model Provider to complete your provider credentials."
)
code = 400
class ProviderQuotaExceededError(BaseHTTPException):
error_code = "provider_quota_exceeded"
description = (
"Your quota for Dify Hosted OpenAI has been exhausted. "
"Please go to Settings -> Model Provider to complete your own provider credentials."
)
code = 400
class ProviderModelCurrentlyNotSupportError(BaseHTTPException):
error_code = "model_currently_not_support"
description = "Dify Hosted OpenAI trial currently not support the GPT-4 model."
code = 400
class CompletionRequestError(BaseHTTPException):
error_code = "completion_request_error"
description = "Completion request failed."
code = 400
class NoAudioUploadedError(BaseHTTPException):
error_code = "no_audio_uploaded"
description = "Please upload your audio."
code = 400
class AudioTooLargeError(BaseHTTPException):
error_code = "audio_too_large"
description = "Audio size exceeded. {message}"
code = 413
class UnsupportedAudioTypeError(BaseHTTPException):
error_code = "unsupported_audio_type"
description = "Audio type not allowed."
code = 415
class ProviderNotSupportSpeechToTextError(BaseHTTPException):
error_code = "provider_not_support_speech_to_text"
description = "Provider not support speech to text."
code = 400
class NoFileUploadedError(BaseHTTPException):
error_code = "no_file_uploaded"
description = "Please upload your file."
code = 400
class TooManyFilesError(BaseHTTPException):
error_code = "too_many_files"
description = "Only one file is allowed."
code = 400
class FileTooLargeError(BaseHTTPException):
error_code = "file_too_large"
description = "File size exceeded. {message}"
code = 413
class UnsupportedFileTypeError(BaseHTTPException):
error_code = "unsupported_file_type"
description = "File type not allowed."
code = 415

View File

@@ -0,0 +1,53 @@
from flask import request
from flask_restful import Resource, marshal_with
import services
from controllers.common.errors import FilenameNotExistsError
from controllers.service_api import api
from controllers.service_api.app.error import (
FileTooLargeError,
NoFileUploadedError,
TooManyFilesError,
UnsupportedFileTypeError,
)
from controllers.service_api.wraps import FetchUserArg, WhereisUserArg, validate_app_token
from fields.file_fields import file_fields
from models.model import App, EndUser
from services.file_service import FileService
class FileApi(Resource):
@validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.FORM))
@marshal_with(file_fields)
def post(self, app_model: App, end_user: EndUser):
file = request.files["file"]
# check file
if "file" not in request.files:
raise NoFileUploadedError()
if not file.mimetype:
raise UnsupportedFileTypeError()
if len(request.files) > 1:
raise TooManyFilesError()
if not file.filename:
raise FilenameNotExistsError
try:
upload_file = FileService.upload_file(
filename=file.filename,
content=file.read(),
mimetype=file.mimetype,
user=end_user,
)
except services.errors.file.FileTooLargeError as file_too_large_error:
raise FileTooLargeError(file_too_large_error.description)
except services.errors.file.UnsupportedFileTypeError:
raise UnsupportedFileTypeError()
return upload_file, 201
api.add_resource(FileApi, "/files/upload")

View File

@@ -0,0 +1,134 @@
import json
import logging
from flask_restful import Resource, fields, marshal_with, reqparse
from flask_restful.inputs import int_range
from werkzeug.exceptions import BadRequest, InternalServerError, NotFound
import services
from controllers.service_api import api
from controllers.service_api.app.error import NotChatAppError
from controllers.service_api.wraps import FetchUserArg, WhereisUserArg, validate_app_token
from core.app.entities.app_invoke_entities import InvokeFrom
from fields.conversation_fields import message_file_fields
from fields.message_fields import agent_thought_fields, feedback_fields
from fields.raws import FilesContainedField
from libs.helper import TimestampField, uuid_value
from models.model import App, AppMode, EndUser
from services.errors.message import SuggestedQuestionsAfterAnswerDisabledError
from services.message_service import MessageService
class MessageListApi(Resource):
message_fields = {
"id": fields.String,
"conversation_id": fields.String,
"parent_message_id": fields.String,
"inputs": FilesContainedField,
"query": fields.String,
"answer": fields.String(attribute="re_sign_file_url_answer"),
"message_files": fields.List(fields.Nested(message_file_fields)),
"feedback": fields.Nested(feedback_fields, attribute="user_feedback", allow_null=True),
"retriever_resources": fields.Raw(
attribute=lambda obj: json.loads(obj.message_metadata).get("retriever_resources", [])
if obj.message_metadata
else []
),
"created_at": TimestampField,
"agent_thoughts": fields.List(fields.Nested(agent_thought_fields)),
"status": fields.String,
"error": fields.String,
}
message_infinite_scroll_pagination_fields = {
"limit": fields.Integer,
"has_more": fields.Boolean,
"data": fields.List(fields.Nested(message_fields)),
}
@validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.QUERY))
@marshal_with(message_infinite_scroll_pagination_fields)
def get(self, app_model: App, end_user: EndUser):
app_mode = AppMode.value_of(app_model.mode)
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
raise NotChatAppError()
parser = reqparse.RequestParser()
parser.add_argument("conversation_id", required=True, type=uuid_value, location="args")
parser.add_argument("first_id", type=uuid_value, location="args")
parser.add_argument("limit", type=int_range(1, 100), required=False, default=20, location="args")
args = parser.parse_args()
try:
return MessageService.pagination_by_first_id(
app_model, end_user, args["conversation_id"], args["first_id"], args["limit"]
)
except services.errors.conversation.ConversationNotExistsError:
raise NotFound("Conversation Not Exists.")
except services.errors.message.FirstMessageNotExistsError:
raise NotFound("First Message Not Exists.")
class MessageFeedbackApi(Resource):
@validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.JSON, required=True))
def post(self, app_model: App, end_user: EndUser, message_id):
message_id = str(message_id)
parser = reqparse.RequestParser()
parser.add_argument("rating", type=str, choices=["like", "dislike", None], location="json")
parser.add_argument("content", type=str, location="json")
args = parser.parse_args()
try:
MessageService.create_feedback(
app_model=app_model,
message_id=message_id,
user=end_user,
rating=args.get("rating"),
content=args.get("content"),
)
except services.errors.message.MessageNotExistsError:
raise NotFound("Message Not Exists.")
return {"result": "success"}
class AppGetFeedbacksApi(Resource):
@validate_app_token
def get(self, app_model: App):
"""Get All Feedbacks of an app"""
parser = reqparse.RequestParser()
parser.add_argument("page", type=int, default=1, location="args")
parser.add_argument("limit", type=int_range(1, 101), required=False, default=20, location="args")
args = parser.parse_args()
feedbacks = MessageService.get_all_messages_feedbacks(app_model, page=args["page"], limit=args["limit"])
return {"data": feedbacks}
class MessageSuggestedApi(Resource):
@validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.QUERY, required=True))
def get(self, app_model: App, end_user: EndUser, message_id):
message_id = str(message_id)
app_mode = AppMode.value_of(app_model.mode)
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
raise NotChatAppError()
try:
questions = MessageService.get_suggested_questions_after_answer(
app_model=app_model, user=end_user, message_id=message_id, invoke_from=InvokeFrom.SERVICE_API
)
except services.errors.message.MessageNotExistsError:
raise NotFound("Message Not Exists.")
except SuggestedQuestionsAfterAnswerDisabledError:
raise BadRequest("Suggested Questions Is Disabled.")
except Exception:
logging.exception("internal server error.")
raise InternalServerError()
return {"result": "success", "data": questions}
api.add_resource(MessageListApi, "/messages")
api.add_resource(MessageFeedbackApi, "/messages/<uuid:message_id>/feedbacks")
api.add_resource(MessageSuggestedApi, "/messages/<uuid:message_id>/suggested")
api.add_resource(AppGetFeedbacksApi, "/app/feedbacks")

View File

@@ -0,0 +1,30 @@
from flask_restful import Resource, marshal_with
from werkzeug.exceptions import Forbidden
from controllers.common import fields
from controllers.service_api import api
from controllers.service_api.wraps import validate_app_token
from extensions.ext_database import db
from models.account import TenantStatus
from models.model import App, Site
class AppSiteApi(Resource):
"""Resource for app sites."""
@validate_app_token
@marshal_with(fields.site_fields)
def get(self, app_model: App):
"""Retrieve app site info."""
site = db.session.query(Site).filter(Site.app_id == app_model.id).first()
if not site:
raise Forbidden()
if app_model.tenant.status == TenantStatus.ARCHIVE:
raise Forbidden()
return site
api.add_resource(AppSiteApi, "/site")

View File

@@ -0,0 +1,168 @@
import logging
from dateutil.parser import isoparse
from flask_restful import Resource, fields, marshal_with, reqparse
from flask_restful.inputs import int_range
from sqlalchemy.orm import Session
from werkzeug.exceptions import InternalServerError
from controllers.service_api import api
from controllers.service_api.app.error import (
CompletionRequestError,
NotWorkflowAppError,
ProviderModelCurrentlyNotSupportError,
ProviderNotInitializeError,
ProviderQuotaExceededError,
)
from controllers.service_api.wraps import FetchUserArg, WhereisUserArg, validate_app_token
from controllers.web.error import InvokeRateLimitError as InvokeRateLimitHttpError
from core.app.apps.base_app_queue_manager import AppQueueManager
from core.app.entities.app_invoke_entities import InvokeFrom
from core.errors.error import (
ModelCurrentlyNotSupportError,
ProviderTokenNotInitError,
QuotaExceededError,
)
from core.model_runtime.errors.invoke import InvokeError
from extensions.ext_database import db
from fields.workflow_app_log_fields import workflow_app_log_pagination_fields
from libs import helper
from libs.helper import TimestampField
from models.model import App, AppMode, EndUser
from models.workflow import WorkflowRun, WorkflowRunStatus
from services.app_generate_service import AppGenerateService
from services.errors.llm import InvokeRateLimitError
from services.workflow_app_service import WorkflowAppService
logger = logging.getLogger(__name__)
workflow_run_fields = {
"id": fields.String,
"workflow_id": fields.String,
"status": fields.String,
"inputs": fields.Raw,
"outputs": fields.Raw,
"error": fields.String,
"total_steps": fields.Integer,
"total_tokens": fields.Integer,
"created_at": TimestampField,
"finished_at": TimestampField,
"elapsed_time": fields.Float,
}
class WorkflowRunDetailApi(Resource):
@validate_app_token
@marshal_with(workflow_run_fields)
def get(self, app_model: App, workflow_run_id: str):
"""
Get a workflow task running detail
"""
app_mode = AppMode.value_of(app_model.mode)
if app_mode not in [AppMode.WORKFLOW, AppMode.ADVANCED_CHAT]:
raise NotWorkflowAppError()
workflow_run = db.session.query(WorkflowRun).filter(WorkflowRun.id == workflow_run_id).first()
return workflow_run
class WorkflowRunApi(Resource):
@validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.JSON, required=True))
def post(self, app_model: App, end_user: EndUser):
"""
Run workflow
"""
app_mode = AppMode.value_of(app_model.mode)
if app_mode != AppMode.WORKFLOW:
raise NotWorkflowAppError()
parser = reqparse.RequestParser()
parser.add_argument("inputs", type=dict, required=True, nullable=False, location="json")
parser.add_argument("files", type=list, required=False, location="json")
parser.add_argument("response_mode", type=str, choices=["blocking", "streaming"], location="json")
args = parser.parse_args()
streaming = args.get("response_mode") == "streaming"
try:
response = AppGenerateService.generate(
app_model=app_model, user=end_user, args=args, invoke_from=InvokeFrom.SERVICE_API, streaming=streaming
)
return helper.compact_generate_response(response)
except ProviderTokenNotInitError as ex:
raise ProviderNotInitializeError(ex.description)
except QuotaExceededError:
raise ProviderQuotaExceededError()
except ModelCurrentlyNotSupportError:
raise ProviderModelCurrentlyNotSupportError()
except InvokeRateLimitError as ex:
raise InvokeRateLimitHttpError(ex.description)
except InvokeError as e:
raise CompletionRequestError(e.description)
except ValueError as e:
raise e
except Exception:
logging.exception("internal server error.")
raise InternalServerError()
class WorkflowTaskStopApi(Resource):
@validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.JSON, required=True))
def post(self, app_model: App, end_user: EndUser, task_id: str):
"""
Stop workflow task
"""
app_mode = AppMode.value_of(app_model.mode)
if app_mode != AppMode.WORKFLOW:
raise NotWorkflowAppError()
AppQueueManager.set_stop_flag(task_id, InvokeFrom.SERVICE_API, end_user.id)
return {"result": "success"}
class WorkflowAppLogApi(Resource):
@validate_app_token
@marshal_with(workflow_app_log_pagination_fields)
def get(self, app_model: App):
"""
Get workflow app logs
"""
parser = reqparse.RequestParser()
parser.add_argument("keyword", type=str, location="args")
parser.add_argument("status", type=str, choices=["succeeded", "failed", "stopped"], location="args")
parser.add_argument("created_at__before", type=str, location="args")
parser.add_argument("created_at__after", type=str, location="args")
parser.add_argument("page", type=int_range(1, 99999), default=1, location="args")
parser.add_argument("limit", type=int_range(1, 100), default=20, location="args")
args = parser.parse_args()
args.status = WorkflowRunStatus(args.status) if args.status else None
if args.created_at__before:
args.created_at__before = isoparse(args.created_at__before)
if args.created_at__after:
args.created_at__after = isoparse(args.created_at__after)
# get paginate workflow app logs
workflow_app_service = WorkflowAppService()
with Session(db.engine) as session:
workflow_app_log_pagination = workflow_app_service.get_paginate_workflow_app_logs(
session=session,
app_model=app_model,
keyword=args.keyword,
status=args.status,
created_at_before=args.created_at__before,
created_at_after=args.created_at__after,
page=args.page,
limit=args.limit,
)
return workflow_app_log_pagination
api.add_resource(WorkflowRunApi, "/workflows/run")
api.add_resource(WorkflowRunDetailApi, "/workflows/run/<string:workflow_run_id>")
api.add_resource(WorkflowTaskStopApi, "/workflows/tasks/<string:task_id>/stop")
api.add_resource(WorkflowAppLogApi, "/workflows/logs")

View File

@@ -0,0 +1,324 @@
from flask import request
from flask_restful import marshal, reqparse
from werkzeug.exceptions import Forbidden, NotFound
import services.dataset_service
from controllers.service_api import api
from controllers.service_api.dataset.error import DatasetInUseError, DatasetNameDuplicateError
from controllers.service_api.wraps import DatasetApiResource
from core.model_runtime.entities.model_entities import ModelType
from core.plugin.entities.plugin import ModelProviderID
from core.provider_manager import ProviderManager
from fields.dataset_fields import dataset_detail_fields
from libs.login import current_user
from models.dataset import Dataset, DatasetPermissionEnum
from services.dataset_service import DatasetPermissionService, DatasetService
from services.entities.knowledge_entities.knowledge_entities import RetrievalModel
def _validate_name(name):
if not name or len(name) < 1 or len(name) > 40:
raise ValueError("Name must be between 1 to 40 characters.")
return name
def _validate_description_length(description):
if len(description) > 400:
raise ValueError("Description cannot exceed 400 characters.")
return description
class DatasetListApi(DatasetApiResource):
"""Resource for datasets."""
def get(self, tenant_id):
"""Resource for getting datasets."""
page = request.args.get("page", default=1, type=int)
limit = request.args.get("limit", default=20, type=int)
# provider = request.args.get("provider", default="vendor")
search = request.args.get("keyword", default=None, type=str)
tag_ids = request.args.getlist("tag_ids")
include_all = request.args.get("include_all", default="false").lower() == "true"
datasets, total = DatasetService.get_datasets(
page, limit, tenant_id, current_user, search, tag_ids, include_all
)
# check embedding setting
provider_manager = ProviderManager()
configurations = provider_manager.get_configurations(tenant_id=current_user.current_tenant_id)
embedding_models = configurations.get_models(model_type=ModelType.TEXT_EMBEDDING, only_active=True)
model_names = []
for embedding_model in embedding_models:
model_names.append(f"{embedding_model.model}:{embedding_model.provider.provider}")
data = marshal(datasets, dataset_detail_fields)
for item in data:
if item["indexing_technique"] == "high_quality" and item["embedding_model_provider"]:
item["embedding_model_provider"] = str(ModelProviderID(item["embedding_model_provider"]))
item_model = f"{item['embedding_model']}:{item['embedding_model_provider']}"
if item_model in model_names:
item["embedding_available"] = True
else:
item["embedding_available"] = False
else:
item["embedding_available"] = True
response = {"data": data, "has_more": len(datasets) == limit, "limit": limit, "total": total, "page": page}
return response, 200
def post(self, tenant_id):
"""Resource for creating datasets."""
parser = reqparse.RequestParser()
parser.add_argument(
"name",
nullable=False,
required=True,
help="type is required. Name must be between 1 to 40 characters.",
type=_validate_name,
)
parser.add_argument(
"description",
type=str,
nullable=True,
required=False,
default="",
)
parser.add_argument(
"indexing_technique",
type=str,
location="json",
choices=Dataset.INDEXING_TECHNIQUE_LIST,
help="Invalid indexing technique.",
)
parser.add_argument(
"permission",
type=str,
location="json",
choices=(DatasetPermissionEnum.ONLY_ME, DatasetPermissionEnum.ALL_TEAM, DatasetPermissionEnum.PARTIAL_TEAM),
help="Invalid permission.",
required=False,
nullable=False,
)
parser.add_argument(
"external_knowledge_api_id",
type=str,
nullable=True,
required=False,
default="_validate_name",
)
parser.add_argument(
"provider",
type=str,
nullable=True,
required=False,
default="vendor",
)
parser.add_argument(
"external_knowledge_id",
type=str,
nullable=True,
required=False,
)
parser.add_argument("retrieval_model", type=dict, required=False, nullable=True, location="json")
parser.add_argument("embedding_model", type=str, required=False, nullable=True, location="json")
parser.add_argument("embedding_model_provider", type=str, required=False, nullable=True, location="json")
args = parser.parse_args()
try:
dataset = DatasetService.create_empty_dataset(
tenant_id=tenant_id,
name=args["name"],
description=args["description"],
indexing_technique=args["indexing_technique"],
account=current_user,
permission=args["permission"],
provider=args["provider"],
external_knowledge_api_id=args["external_knowledge_api_id"],
external_knowledge_id=args["external_knowledge_id"],
embedding_model_provider=args["embedding_model_provider"],
embedding_model_name=args["embedding_model"],
retrieval_model=RetrievalModel(**args["retrieval_model"])
if args["retrieval_model"] is not None
else None,
)
except services.errors.dataset.DatasetNameDuplicateError:
raise DatasetNameDuplicateError()
return marshal(dataset, dataset_detail_fields), 200
class DatasetApi(DatasetApiResource):
"""Resource for dataset."""
def get(self, _, dataset_id):
dataset_id_str = str(dataset_id)
dataset = DatasetService.get_dataset(dataset_id_str)
if dataset is None:
raise NotFound("Dataset not found.")
try:
DatasetService.check_dataset_permission(dataset, current_user)
except services.errors.account.NoPermissionError as e:
raise Forbidden(str(e))
data = marshal(dataset, dataset_detail_fields)
if data.get("permission") == "partial_members":
part_users_list = DatasetPermissionService.get_dataset_partial_member_list(dataset_id_str)
data.update({"partial_member_list": part_users_list})
# check embedding setting
provider_manager = ProviderManager()
configurations = provider_manager.get_configurations(tenant_id=current_user.current_tenant_id)
embedding_models = configurations.get_models(model_type=ModelType.TEXT_EMBEDDING, only_active=True)
model_names = []
for embedding_model in embedding_models:
model_names.append(f"{embedding_model.model}:{embedding_model.provider.provider}")
if data["indexing_technique"] == "high_quality":
item_model = f"{data['embedding_model']}:{data['embedding_model_provider']}"
if item_model in model_names:
data["embedding_available"] = True
else:
data["embedding_available"] = False
else:
data["embedding_available"] = True
if data.get("permission") == "partial_members":
part_users_list = DatasetPermissionService.get_dataset_partial_member_list(dataset_id_str)
data.update({"partial_member_list": part_users_list})
return data, 200
def patch(self, _, dataset_id):
dataset_id_str = str(dataset_id)
dataset = DatasetService.get_dataset(dataset_id_str)
if dataset is None:
raise NotFound("Dataset not found.")
parser = reqparse.RequestParser()
parser.add_argument(
"name",
nullable=False,
help="type is required. Name must be between 1 to 40 characters.",
type=_validate_name,
)
parser.add_argument("description", location="json", store_missing=False, type=_validate_description_length)
parser.add_argument(
"indexing_technique",
type=str,
location="json",
choices=Dataset.INDEXING_TECHNIQUE_LIST,
nullable=True,
help="Invalid indexing technique.",
)
parser.add_argument(
"permission",
type=str,
location="json",
choices=(DatasetPermissionEnum.ONLY_ME, DatasetPermissionEnum.ALL_TEAM, DatasetPermissionEnum.PARTIAL_TEAM),
help="Invalid permission.",
)
parser.add_argument("embedding_model", type=str, location="json", help="Invalid embedding model.")
parser.add_argument(
"embedding_model_provider", type=str, location="json", help="Invalid embedding model provider."
)
parser.add_argument("retrieval_model", type=dict, location="json", help="Invalid retrieval model.")
parser.add_argument("partial_member_list", type=list, location="json", help="Invalid parent user list.")
parser.add_argument(
"external_retrieval_model",
type=dict,
required=False,
nullable=True,
location="json",
help="Invalid external retrieval model.",
)
parser.add_argument(
"external_knowledge_id",
type=str,
required=False,
nullable=True,
location="json",
help="Invalid external knowledge id.",
)
parser.add_argument(
"external_knowledge_api_id",
type=str,
required=False,
nullable=True,
location="json",
help="Invalid external knowledge api id.",
)
args = parser.parse_args()
data = request.get_json()
# check embedding model setting
if data.get("indexing_technique") == "high_quality":
DatasetService.check_embedding_model_setting(
dataset.tenant_id, data.get("embedding_model_provider"), data.get("embedding_model")
)
# The role of the current user in the ta table must be admin, owner, editor, or dataset_operator
DatasetPermissionService.check_permission(
current_user, dataset, data.get("permission"), data.get("partial_member_list")
)
dataset = DatasetService.update_dataset(dataset_id_str, args, current_user)
if dataset is None:
raise NotFound("Dataset not found.")
result_data = marshal(dataset, dataset_detail_fields)
tenant_id = current_user.current_tenant_id
if data.get("partial_member_list") and data.get("permission") == "partial_members":
DatasetPermissionService.update_partial_member_list(
tenant_id, dataset_id_str, data.get("partial_member_list")
)
# clear partial member list when permission is only_me or all_team_members
elif (
data.get("permission") == DatasetPermissionEnum.ONLY_ME
or data.get("permission") == DatasetPermissionEnum.ALL_TEAM
):
DatasetPermissionService.clear_partial_member_list(dataset_id_str)
partial_member_list = DatasetPermissionService.get_dataset_partial_member_list(dataset_id_str)
result_data.update({"partial_member_list": partial_member_list})
return result_data, 200
def delete(self, _, dataset_id):
"""
Deletes a dataset given its ID.
Args:
_: ignore
dataset_id (UUID): The ID of the dataset to be deleted.
Returns:
dict: A dictionary with a key 'result' and a value 'success'
if the dataset was successfully deleted. Omitted in HTTP response.
int: HTTP status code 204 indicating that the operation was successful.
Raises:
NotFound: If the dataset with the given ID does not exist.
"""
dataset_id_str = str(dataset_id)
try:
if DatasetService.delete_dataset(dataset_id_str, current_user):
DatasetPermissionService.clear_partial_member_list(dataset_id_str)
return 204
else:
raise NotFound("Dataset not found.")
except services.errors.dataset.DatasetInUseError:
raise DatasetInUseError()
api.add_resource(DatasetListApi, "/datasets")
api.add_resource(DatasetApi, "/datasets/<uuid:dataset_id>")

View File

@@ -0,0 +1,422 @@
import json
from flask import request
from flask_restful import marshal, reqparse
from sqlalchemy import desc, select
from werkzeug.exceptions import NotFound
import services
from controllers.common.errors import FilenameNotExistsError
from controllers.service_api import api
from controllers.service_api.app.error import (
FileTooLargeError,
NoFileUploadedError,
ProviderNotInitializeError,
TooManyFilesError,
UnsupportedFileTypeError,
)
from controllers.service_api.dataset.error import (
ArchivedDocumentImmutableError,
DocumentIndexingError,
)
from controllers.service_api.wraps import DatasetApiResource, cloud_edition_billing_resource_check
from core.errors.error import ProviderTokenNotInitError
from extensions.ext_database import db
from fields.document_fields import document_fields, document_status_fields
from libs.login import current_user
from models.dataset import Dataset, Document, DocumentSegment
from services.dataset_service import DocumentService
from services.entities.knowledge_entities.knowledge_entities import KnowledgeConfig
from services.file_service import FileService
class DocumentAddByTextApi(DatasetApiResource):
"""Resource for documents."""
@cloud_edition_billing_resource_check("vector_space", "dataset")
@cloud_edition_billing_resource_check("documents", "dataset")
def post(self, tenant_id, dataset_id):
"""Create document by text."""
parser = reqparse.RequestParser()
parser.add_argument("name", type=str, required=True, nullable=False, location="json")
parser.add_argument("text", type=str, required=True, nullable=False, location="json")
parser.add_argument("process_rule", type=dict, required=False, nullable=True, location="json")
parser.add_argument("original_document_id", type=str, required=False, location="json")
parser.add_argument("doc_form", type=str, default="text_model", required=False, nullable=False, location="json")
parser.add_argument(
"doc_language", type=str, default="English", required=False, nullable=False, location="json"
)
parser.add_argument(
"indexing_technique", type=str, choices=Dataset.INDEXING_TECHNIQUE_LIST, nullable=False, location="json"
)
parser.add_argument("retrieval_model", type=dict, required=False, nullable=True, location="json")
parser.add_argument("embedding_model", type=str, required=False, nullable=True, location="json")
parser.add_argument("embedding_model_provider", type=str, required=False, nullable=True, location="json")
args = parser.parse_args()
dataset_id = str(dataset_id)
tenant_id = str(tenant_id)
dataset = db.session.query(Dataset).filter(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first()
if not dataset:
raise ValueError("Dataset does not exist.")
if not dataset.indexing_technique and not args["indexing_technique"]:
raise ValueError("indexing_technique is required.")
text = args.get("text")
name = args.get("name")
if text is None or name is None:
raise ValueError("Both 'text' and 'name' must be non-null values.")
upload_file = FileService.upload_text(text=str(text), text_name=str(name))
data_source = {
"type": "upload_file",
"info_list": {"data_source_type": "upload_file", "file_info_list": {"file_ids": [upload_file.id]}},
}
args["data_source"] = data_source
knowledge_config = KnowledgeConfig(**args)
# validate args
DocumentService.document_create_args_validate(knowledge_config)
try:
documents, batch = DocumentService.save_document_with_dataset_id(
dataset=dataset,
knowledge_config=knowledge_config,
account=current_user,
dataset_process_rule=dataset.latest_process_rule if "process_rule" not in args else None,
created_from="api",
)
except ProviderTokenNotInitError as ex:
raise ProviderNotInitializeError(ex.description)
document = documents[0]
documents_and_batch_fields = {"document": marshal(document, document_fields), "batch": batch}
return documents_and_batch_fields, 200
class DocumentUpdateByTextApi(DatasetApiResource):
"""Resource for update documents."""
@cloud_edition_billing_resource_check("vector_space", "dataset")
def post(self, tenant_id, dataset_id, document_id):
"""Update document by text."""
parser = reqparse.RequestParser()
parser.add_argument("name", type=str, required=False, nullable=True, location="json")
parser.add_argument("text", type=str, required=False, nullable=True, location="json")
parser.add_argument("process_rule", type=dict, required=False, nullable=True, location="json")
parser.add_argument("doc_form", type=str, default="text_model", required=False, nullable=False, location="json")
parser.add_argument(
"doc_language", type=str, default="English", required=False, nullable=False, location="json"
)
parser.add_argument("retrieval_model", type=dict, required=False, nullable=False, location="json")
args = parser.parse_args()
dataset_id = str(dataset_id)
tenant_id = str(tenant_id)
dataset = db.session.query(Dataset).filter(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first()
if not dataset:
raise ValueError("Dataset does not exist.")
# indexing_technique is already set in dataset since this is an update
args["indexing_technique"] = dataset.indexing_technique
if args["text"]:
text = args.get("text")
name = args.get("name")
if text is None or name is None:
raise ValueError("Both text and name must be strings.")
upload_file = FileService.upload_text(text=str(text), text_name=str(name))
data_source = {
"type": "upload_file",
"info_list": {"data_source_type": "upload_file", "file_info_list": {"file_ids": [upload_file.id]}},
}
args["data_source"] = data_source
# validate args
args["original_document_id"] = str(document_id)
knowledge_config = KnowledgeConfig(**args)
DocumentService.document_create_args_validate(knowledge_config)
try:
documents, batch = DocumentService.save_document_with_dataset_id(
dataset=dataset,
knowledge_config=knowledge_config,
account=current_user,
dataset_process_rule=dataset.latest_process_rule if "process_rule" not in args else None,
created_from="api",
)
except ProviderTokenNotInitError as ex:
raise ProviderNotInitializeError(ex.description)
document = documents[0]
documents_and_batch_fields = {"document": marshal(document, document_fields), "batch": batch}
return documents_and_batch_fields, 200
class DocumentAddByFileApi(DatasetApiResource):
"""Resource for documents."""
@cloud_edition_billing_resource_check("vector_space", "dataset")
@cloud_edition_billing_resource_check("documents", "dataset")
def post(self, tenant_id, dataset_id):
"""Create document by upload file."""
args = {}
if "data" in request.form:
args = json.loads(request.form["data"])
if "doc_form" not in args:
args["doc_form"] = "text_model"
if "doc_language" not in args:
args["doc_language"] = "English"
# get dataset info
dataset_id = str(dataset_id)
tenant_id = str(tenant_id)
dataset = db.session.query(Dataset).filter(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first()
if not dataset:
raise ValueError("Dataset does not exist.")
if not dataset.indexing_technique and not args.get("indexing_technique"):
raise ValueError("indexing_technique is required.")
# save file info
file = request.files["file"]
# check file
if "file" not in request.files:
raise NoFileUploadedError()
if len(request.files) > 1:
raise TooManyFilesError()
if not file.filename:
raise FilenameNotExistsError
upload_file = FileService.upload_file(
filename=file.filename,
content=file.read(),
mimetype=file.mimetype,
user=current_user,
source="datasets",
)
data_source = {
"type": "upload_file",
"info_list": {"data_source_type": "upload_file", "file_info_list": {"file_ids": [upload_file.id]}},
}
args["data_source"] = data_source
# validate args
knowledge_config = KnowledgeConfig(**args)
DocumentService.document_create_args_validate(knowledge_config)
try:
documents, batch = DocumentService.save_document_with_dataset_id(
dataset=dataset,
knowledge_config=knowledge_config,
account=dataset.created_by_account,
dataset_process_rule=dataset.latest_process_rule if "process_rule" not in args else None,
created_from="api",
)
except ProviderTokenNotInitError as ex:
raise ProviderNotInitializeError(ex.description)
document = documents[0]
documents_and_batch_fields = {"document": marshal(document, document_fields), "batch": batch}
return documents_and_batch_fields, 200
class DocumentUpdateByFileApi(DatasetApiResource):
"""Resource for update documents."""
@cloud_edition_billing_resource_check("vector_space", "dataset")
def post(self, tenant_id, dataset_id, document_id):
"""Update document by upload file."""
args = {}
if "data" in request.form:
args = json.loads(request.form["data"])
if "doc_form" not in args:
args["doc_form"] = "text_model"
if "doc_language" not in args:
args["doc_language"] = "English"
# get dataset info
dataset_id = str(dataset_id)
tenant_id = str(tenant_id)
dataset = db.session.query(Dataset).filter(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first()
if not dataset:
raise ValueError("Dataset does not exist.")
# indexing_technique is already set in dataset since this is an update
args["indexing_technique"] = dataset.indexing_technique
if "file" in request.files:
# save file info
file = request.files["file"]
if len(request.files) > 1:
raise TooManyFilesError()
if not file.filename:
raise FilenameNotExistsError
try:
upload_file = FileService.upload_file(
filename=file.filename,
content=file.read(),
mimetype=file.mimetype,
user=current_user,
source="datasets",
)
except services.errors.file.FileTooLargeError as file_too_large_error:
raise FileTooLargeError(file_too_large_error.description)
except services.errors.file.UnsupportedFileTypeError:
raise UnsupportedFileTypeError()
data_source = {
"type": "upload_file",
"info_list": {"data_source_type": "upload_file", "file_info_list": {"file_ids": [upload_file.id]}},
}
args["data_source"] = data_source
# validate args
args["original_document_id"] = str(document_id)
knowledge_config = KnowledgeConfig(**args)
DocumentService.document_create_args_validate(knowledge_config)
try:
documents, batch = DocumentService.save_document_with_dataset_id(
dataset=dataset,
knowledge_config=knowledge_config,
account=dataset.created_by_account,
dataset_process_rule=dataset.latest_process_rule if "process_rule" not in args else None,
created_from="api",
)
except ProviderTokenNotInitError as ex:
raise ProviderNotInitializeError(ex.description)
document = documents[0]
documents_and_batch_fields = {"document": marshal(document, document_fields), "batch": document.batch}
return documents_and_batch_fields, 200
class DocumentDeleteApi(DatasetApiResource):
def delete(self, tenant_id, dataset_id, document_id):
"""Delete document."""
document_id = str(document_id)
dataset_id = str(dataset_id)
tenant_id = str(tenant_id)
# get dataset info
dataset = db.session.query(Dataset).filter(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first()
if not dataset:
raise ValueError("Dataset does not exist.")
document = DocumentService.get_document(dataset.id, document_id)
# 404 if document not found
if document is None:
raise NotFound("Document Not Exists.")
# 403 if document is archived
if DocumentService.check_archived(document):
raise ArchivedDocumentImmutableError()
try:
# delete document
DocumentService.delete_document(document)
except services.errors.document.DocumentIndexingError:
raise DocumentIndexingError("Cannot delete document during indexing.")
return 204
class DocumentListApi(DatasetApiResource):
def get(self, tenant_id, dataset_id):
dataset_id = str(dataset_id)
tenant_id = str(tenant_id)
page = request.args.get("page", default=1, type=int)
limit = request.args.get("limit", default=20, type=int)
search = request.args.get("keyword", default=None, type=str)
dataset = db.session.query(Dataset).filter(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first()
if not dataset:
raise NotFound("Dataset not found.")
query = select(Document).filter_by(dataset_id=str(dataset_id), tenant_id=tenant_id)
if search:
search = f"%{search}%"
query = query.filter(Document.name.like(search))
query = query.order_by(desc(Document.created_at), desc(Document.position))
paginated_documents = db.paginate(select=query, page=page, per_page=limit, max_per_page=100, error_out=False)
documents = paginated_documents.items
response = {
"data": marshal(documents, document_fields),
"has_more": len(documents) == limit,
"limit": limit,
"total": paginated_documents.total,
"page": page,
}
return response
class DocumentIndexingStatusApi(DatasetApiResource):
def get(self, tenant_id, dataset_id, batch):
dataset_id = str(dataset_id)
batch = str(batch)
tenant_id = str(tenant_id)
# get dataset
dataset = db.session.query(Dataset).filter(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first()
if not dataset:
raise NotFound("Dataset not found.")
# get documents
documents = DocumentService.get_batch_documents(dataset_id, batch)
if not documents:
raise NotFound("Documents not found.")
documents_status = []
for document in documents:
completed_segments = (
db.session.query(DocumentSegment)
.filter(
DocumentSegment.completed_at.isnot(None),
DocumentSegment.document_id == str(document.id),
DocumentSegment.status != "re_segment",
)
.count()
)
total_segments = (
db.session.query(DocumentSegment)
.filter(DocumentSegment.document_id == str(document.id), DocumentSegment.status != "re_segment")
.count()
)
document.completed_segments = completed_segments
document.total_segments = total_segments
if document.is_paused:
document.indexing_status = "paused"
documents_status.append(marshal(document, document_status_fields))
data = {"data": documents_status}
return data
api.add_resource(
DocumentAddByTextApi,
"/datasets/<uuid:dataset_id>/document/create_by_text",
"/datasets/<uuid:dataset_id>/document/create-by-text",
)
api.add_resource(
DocumentAddByFileApi,
"/datasets/<uuid:dataset_id>/document/create_by_file",
"/datasets/<uuid:dataset_id>/document/create-by-file",
)
api.add_resource(
DocumentUpdateByTextApi,
"/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/update_by_text",
"/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/update-by-text",
)
api.add_resource(
DocumentUpdateByFileApi,
"/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/update_by_file",
"/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/update-by-file",
)
api.add_resource(DocumentDeleteApi, "/datasets/<uuid:dataset_id>/documents/<uuid:document_id>")
api.add_resource(DocumentListApi, "/datasets/<uuid:dataset_id>/documents")
api.add_resource(DocumentIndexingStatusApi, "/datasets/<uuid:dataset_id>/documents/<string:batch>/indexing-status")

View File

@@ -0,0 +1,79 @@
from libs.exception import BaseHTTPException
class NoFileUploadedError(BaseHTTPException):
error_code = "no_file_uploaded"
description = "Please upload your file."
code = 400
class TooManyFilesError(BaseHTTPException):
error_code = "too_many_files"
description = "Only one file is allowed."
code = 400
class FileTooLargeError(BaseHTTPException):
error_code = "file_too_large"
description = "File size exceeded. {message}"
code = 413
class UnsupportedFileTypeError(BaseHTTPException):
error_code = "unsupported_file_type"
description = "File type not allowed."
code = 415
class HighQualityDatasetOnlyError(BaseHTTPException):
error_code = "high_quality_dataset_only"
description = "Current operation only supports 'high-quality' datasets."
code = 400
class DatasetNotInitializedError(BaseHTTPException):
error_code = "dataset_not_initialized"
description = "The dataset is still being initialized or indexing. Please wait a moment."
code = 400
class ArchivedDocumentImmutableError(BaseHTTPException):
error_code = "archived_document_immutable"
description = "The archived document is not editable."
code = 403
class DatasetNameDuplicateError(BaseHTTPException):
error_code = "dataset_name_duplicate"
description = "The dataset name already exists. Please modify your dataset name."
code = 409
class InvalidActionError(BaseHTTPException):
error_code = "invalid_action"
description = "Invalid action."
code = 400
class DocumentAlreadyFinishedError(BaseHTTPException):
error_code = "document_already_finished"
description = "The document has been processed. Please refresh the page or go to the document details."
code = 400
class DocumentIndexingError(BaseHTTPException):
error_code = "document_indexing"
description = "The document is being processed and cannot be edited."
code = 400
class InvalidMetadataError(BaseHTTPException):
error_code = "invalid_metadata"
description = "The metadata content is incorrect. Please check and verify."
code = 400
class DatasetInUseError(BaseHTTPException):
error_code = "dataset_in_use"
description = "The dataset is being used by some apps. Please remove the dataset from the apps before deleting it."
code = 409

View File

@@ -0,0 +1,17 @@
from controllers.console.datasets.hit_testing_base import DatasetsHitTestingBase
from controllers.service_api import api
from controllers.service_api.wraps import DatasetApiResource
class HitTestingApi(DatasetApiResource, DatasetsHitTestingBase):
def post(self, tenant_id, dataset_id):
dataset_id_str = str(dataset_id)
dataset = self.get_and_validate_dataset(dataset_id_str)
args = self.parse_args()
self.hit_testing_args_check(args)
return self.perform_hit_testing(dataset, args)
api.add_resource(HitTestingApi, "/datasets/<uuid:dataset_id>/hit-testing", "/datasets/<uuid:dataset_id>/retrieve")

View File

@@ -0,0 +1,114 @@
from flask_login import current_user # type: ignore
from flask_restful import marshal, reqparse
from werkzeug.exceptions import NotFound
from controllers.service_api import api
from controllers.service_api.wraps import DatasetApiResource
from fields.dataset_fields import dataset_metadata_fields
from services.dataset_service import DatasetService
from services.entities.knowledge_entities.knowledge_entities import (
MetadataArgs,
MetadataOperationData,
)
from services.metadata_service import MetadataService
class DatasetMetadataCreateServiceApi(DatasetApiResource):
def post(self, tenant_id, dataset_id):
parser = reqparse.RequestParser()
parser.add_argument("type", type=str, required=True, nullable=True, location="json")
parser.add_argument("name", type=str, required=True, nullable=True, location="json")
args = parser.parse_args()
metadata_args = MetadataArgs(**args)
dataset_id_str = str(dataset_id)
dataset = DatasetService.get_dataset(dataset_id_str)
if dataset is None:
raise NotFound("Dataset not found.")
DatasetService.check_dataset_permission(dataset, current_user)
metadata = MetadataService.create_metadata(dataset_id_str, metadata_args)
return marshal(metadata, dataset_metadata_fields), 201
def get(self, tenant_id, dataset_id):
dataset_id_str = str(dataset_id)
dataset = DatasetService.get_dataset(dataset_id_str)
if dataset is None:
raise NotFound("Dataset not found.")
return MetadataService.get_dataset_metadatas(dataset), 200
class DatasetMetadataServiceApi(DatasetApiResource):
def patch(self, tenant_id, dataset_id, metadata_id):
parser = reqparse.RequestParser()
parser.add_argument("name", type=str, required=True, nullable=True, location="json")
args = parser.parse_args()
dataset_id_str = str(dataset_id)
metadata_id_str = str(metadata_id)
dataset = DatasetService.get_dataset(dataset_id_str)
if dataset is None:
raise NotFound("Dataset not found.")
DatasetService.check_dataset_permission(dataset, current_user)
metadata = MetadataService.update_metadata_name(dataset_id_str, metadata_id_str, args.get("name"))
return marshal(metadata, dataset_metadata_fields), 200
def delete(self, tenant_id, dataset_id, metadata_id):
dataset_id_str = str(dataset_id)
metadata_id_str = str(metadata_id)
dataset = DatasetService.get_dataset(dataset_id_str)
if dataset is None:
raise NotFound("Dataset not found.")
DatasetService.check_dataset_permission(dataset, current_user)
MetadataService.delete_metadata(dataset_id_str, metadata_id_str)
return 204
class DatasetMetadataBuiltInFieldServiceApi(DatasetApiResource):
def get(self, tenant_id):
built_in_fields = MetadataService.get_built_in_fields()
return {"fields": built_in_fields}, 200
class DatasetMetadataBuiltInFieldActionServiceApi(DatasetApiResource):
def post(self, tenant_id, dataset_id, action):
dataset_id_str = str(dataset_id)
dataset = DatasetService.get_dataset(dataset_id_str)
if dataset is None:
raise NotFound("Dataset not found.")
DatasetService.check_dataset_permission(dataset, current_user)
if action == "enable":
MetadataService.enable_built_in_field(dataset)
elif action == "disable":
MetadataService.disable_built_in_field(dataset)
return 200
class DocumentMetadataEditServiceApi(DatasetApiResource):
def post(self, tenant_id, dataset_id):
dataset_id_str = str(dataset_id)
dataset = DatasetService.get_dataset(dataset_id_str)
if dataset is None:
raise NotFound("Dataset not found.")
DatasetService.check_dataset_permission(dataset, current_user)
parser = reqparse.RequestParser()
parser.add_argument("operation_data", type=list, required=True, nullable=True, location="json")
args = parser.parse_args()
metadata_args = MetadataOperationData(**args)
MetadataService.update_documents_metadata(dataset, metadata_args)
return 200
api.add_resource(DatasetMetadataCreateServiceApi, "/datasets/<uuid:dataset_id>/metadata")
api.add_resource(DatasetMetadataServiceApi, "/datasets/<uuid:dataset_id>/metadata/<uuid:metadata_id>")
api.add_resource(DatasetMetadataBuiltInFieldServiceApi, "/datasets/metadata/built-in")
api.add_resource(
DatasetMetadataBuiltInFieldActionServiceApi, "/datasets/<uuid:dataset_id>/metadata/built-in/<string:action>"
)
api.add_resource(DocumentMetadataEditServiceApi, "/datasets/<uuid:dataset_id>/documents/metadata")

View File

@@ -0,0 +1,402 @@
from flask import request
from flask_login import current_user
from flask_restful import marshal, reqparse
from werkzeug.exceptions import NotFound
from controllers.service_api import api
from controllers.service_api.app.error import ProviderNotInitializeError
from controllers.service_api.wraps import (
DatasetApiResource,
cloud_edition_billing_knowledge_limit_check,
cloud_edition_billing_resource_check,
)
from core.errors.error import LLMBadRequestError, ProviderTokenNotInitError
from core.model_manager import ModelManager
from core.model_runtime.entities.model_entities import ModelType
from extensions.ext_database import db
from fields.segment_fields import child_chunk_fields, segment_fields
from models.dataset import Dataset
from services.dataset_service import DatasetService, DocumentService, SegmentService
from services.entities.knowledge_entities.knowledge_entities import SegmentUpdateArgs
from services.errors.chunk import (
ChildChunkDeleteIndexError,
ChildChunkIndexingError,
)
from services.errors.chunk import (
ChildChunkDeleteIndexError as ChildChunkDeleteIndexServiceError,
)
from services.errors.chunk import (
ChildChunkIndexingError as ChildChunkIndexingServiceError,
)
class SegmentApi(DatasetApiResource):
"""Resource for segments."""
@cloud_edition_billing_resource_check("vector_space", "dataset")
@cloud_edition_billing_knowledge_limit_check("add_segment", "dataset")
def post(self, tenant_id, dataset_id, document_id):
"""Create single segment."""
# check dataset
dataset_id = str(dataset_id)
tenant_id = str(tenant_id)
dataset = db.session.query(Dataset).filter(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first()
if not dataset:
raise NotFound("Dataset not found.")
# check document
document_id = str(document_id)
document = DocumentService.get_document(dataset.id, document_id)
if not document:
raise NotFound("Document not found.")
if document.indexing_status != "completed":
raise NotFound("Document is not completed.")
if not document.enabled:
raise NotFound("Document is disabled.")
# check embedding model setting
if dataset.indexing_technique == "high_quality":
try:
model_manager = ModelManager()
model_manager.get_model_instance(
tenant_id=current_user.current_tenant_id,
provider=dataset.embedding_model_provider,
model_type=ModelType.TEXT_EMBEDDING,
model=dataset.embedding_model,
)
except LLMBadRequestError:
raise ProviderNotInitializeError(
"No Embedding Model available. Please configure a valid provider in the Settings -> Model Provider."
)
except ProviderTokenNotInitError as ex:
raise ProviderNotInitializeError(ex.description)
# validate args
parser = reqparse.RequestParser()
parser.add_argument("segments", type=list, required=False, nullable=True, location="json")
args = parser.parse_args()
if args["segments"] is not None:
for args_item in args["segments"]:
SegmentService.segment_create_args_validate(args_item, document)
segments = SegmentService.multi_create_segment(args["segments"], document, dataset)
return {"data": marshal(segments, segment_fields), "doc_form": document.doc_form}, 200
else:
return {"error": "Segments is required"}, 400
def get(self, tenant_id, dataset_id, document_id):
"""Get segments."""
# check dataset
dataset_id = str(dataset_id)
tenant_id = str(tenant_id)
page = request.args.get("page", default=1, type=int)
limit = request.args.get("limit", default=20, type=int)
dataset = db.session.query(Dataset).filter(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first()
if not dataset:
raise NotFound("Dataset not found.")
# check document
document_id = str(document_id)
document = DocumentService.get_document(dataset.id, document_id)
if not document:
raise NotFound("Document not found.")
# check embedding model setting
if dataset.indexing_technique == "high_quality":
try:
model_manager = ModelManager()
model_manager.get_model_instance(
tenant_id=current_user.current_tenant_id,
provider=dataset.embedding_model_provider,
model_type=ModelType.TEXT_EMBEDDING,
model=dataset.embedding_model,
)
except LLMBadRequestError:
raise ProviderNotInitializeError(
"No Embedding Model available. Please configure a valid provider in the Settings -> Model Provider."
)
except ProviderTokenNotInitError as ex:
raise ProviderNotInitializeError(ex.description)
parser = reqparse.RequestParser()
parser.add_argument("status", type=str, action="append", default=[], location="args")
parser.add_argument("keyword", type=str, default=None, location="args")
args = parser.parse_args()
segments, total = SegmentService.get_segments(
document_id=document_id,
tenant_id=current_user.current_tenant_id,
status_list=args["status"],
keyword=args["keyword"],
page=page,
limit=limit,
)
response = {
"data": marshal(segments, segment_fields),
"doc_form": document.doc_form,
"total": total,
"has_more": len(segments) == limit,
"limit": limit,
"page": page,
}
return response, 200
class DatasetSegmentApi(DatasetApiResource):
def delete(self, tenant_id, dataset_id, document_id, segment_id):
# check dataset
dataset_id = str(dataset_id)
tenant_id = str(tenant_id)
dataset = db.session.query(Dataset).filter(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first()
if not dataset:
raise NotFound("Dataset not found.")
# check user's model setting
DatasetService.check_dataset_model_setting(dataset)
# check document
document_id = str(document_id)
document = DocumentService.get_document(dataset_id, document_id)
if not document:
raise NotFound("Document not found.")
# check segment
segment_id = str(segment_id)
segment = SegmentService.get_segment_by_id(segment_id=segment_id, tenant_id=current_user.current_tenant_id)
if not segment:
raise NotFound("Segment not found.")
SegmentService.delete_segment(segment, document, dataset)
return 204
@cloud_edition_billing_resource_check("vector_space", "dataset")
def post(self, tenant_id, dataset_id, document_id, segment_id):
# check dataset
dataset_id = str(dataset_id)
tenant_id = str(tenant_id)
dataset = db.session.query(Dataset).filter(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first()
if not dataset:
raise NotFound("Dataset not found.")
# check user's model setting
DatasetService.check_dataset_model_setting(dataset)
# check document
document_id = str(document_id)
document = DocumentService.get_document(dataset_id, document_id)
if not document:
raise NotFound("Document not found.")
if dataset.indexing_technique == "high_quality":
# check embedding model setting
try:
model_manager = ModelManager()
model_manager.get_model_instance(
tenant_id=current_user.current_tenant_id,
provider=dataset.embedding_model_provider,
model_type=ModelType.TEXT_EMBEDDING,
model=dataset.embedding_model,
)
except LLMBadRequestError:
raise ProviderNotInitializeError(
"No Embedding Model available. Please configure a valid provider in the Settings -> Model Provider."
)
except ProviderTokenNotInitError as ex:
raise ProviderNotInitializeError(ex.description)
# check segment
segment_id = str(segment_id)
segment = SegmentService.get_segment_by_id(segment_id=segment_id, tenant_id=current_user.current_tenant_id)
if not segment:
raise NotFound("Segment not found.")
# validate args
parser = reqparse.RequestParser()
parser.add_argument("segment", type=dict, required=False, nullable=True, location="json")
args = parser.parse_args()
updated_segment = SegmentService.update_segment(
SegmentUpdateArgs(**args["segment"]), segment, document, dataset
)
return {"data": marshal(updated_segment, segment_fields), "doc_form": document.doc_form}, 200
class ChildChunkApi(DatasetApiResource):
"""Resource for child chunks."""
@cloud_edition_billing_resource_check("vector_space", "dataset")
@cloud_edition_billing_knowledge_limit_check("add_segment", "dataset")
def post(self, tenant_id, dataset_id, document_id, segment_id):
"""Create child chunk."""
# check dataset
dataset_id = str(dataset_id)
tenant_id = str(tenant_id)
dataset = db.session.query(Dataset).filter(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first()
if not dataset:
raise NotFound("Dataset not found.")
# check document
document_id = str(document_id)
document = DocumentService.get_document(dataset.id, document_id)
if not document:
raise NotFound("Document not found.")
# check segment
segment_id = str(segment_id)
segment = SegmentService.get_segment_by_id(segment_id=segment_id, tenant_id=current_user.current_tenant_id)
if not segment:
raise NotFound("Segment not found.")
# check embedding model setting
if dataset.indexing_technique == "high_quality":
try:
model_manager = ModelManager()
model_manager.get_model_instance(
tenant_id=current_user.current_tenant_id,
provider=dataset.embedding_model_provider,
model_type=ModelType.TEXT_EMBEDDING,
model=dataset.embedding_model,
)
except LLMBadRequestError:
raise ProviderNotInitializeError(
"No Embedding Model available. Please configure a valid provider in the Settings -> Model Provider."
)
except ProviderTokenNotInitError as ex:
raise ProviderNotInitializeError(ex.description)
# validate args
parser = reqparse.RequestParser()
parser.add_argument("content", type=str, required=True, nullable=False, location="json")
args = parser.parse_args()
try:
child_chunk = SegmentService.create_child_chunk(args.get("content"), segment, document, dataset)
except ChildChunkIndexingServiceError as e:
raise ChildChunkIndexingError(str(e))
return {"data": marshal(child_chunk, child_chunk_fields)}, 200
def get(self, tenant_id, dataset_id, document_id, segment_id):
"""Get child chunks."""
# check dataset
dataset_id = str(dataset_id)
tenant_id = str(tenant_id)
dataset = db.session.query(Dataset).filter(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first()
if not dataset:
raise NotFound("Dataset not found.")
# check document
document_id = str(document_id)
document = DocumentService.get_document(dataset.id, document_id)
if not document:
raise NotFound("Document not found.")
# check segment
segment_id = str(segment_id)
segment = SegmentService.get_segment_by_id(segment_id=segment_id, tenant_id=current_user.current_tenant_id)
if not segment:
raise NotFound("Segment not found.")
parser = reqparse.RequestParser()
parser.add_argument("limit", type=int, default=20, location="args")
parser.add_argument("keyword", type=str, default=None, location="args")
parser.add_argument("page", type=int, default=1, location="args")
args = parser.parse_args()
page = args["page"]
limit = min(args["limit"], 100)
keyword = args["keyword"]
child_chunks = SegmentService.get_child_chunks(segment_id, document_id, dataset_id, page, limit, keyword)
return {
"data": marshal(child_chunks.items, child_chunk_fields),
"total": child_chunks.total,
"total_pages": child_chunks.pages,
"page": page,
"limit": limit,
}, 200
class DatasetChildChunkApi(DatasetApiResource):
"""Resource for updating child chunks."""
@cloud_edition_billing_knowledge_limit_check("add_segment", "dataset")
def delete(self, tenant_id, dataset_id, document_id, segment_id, child_chunk_id):
"""Delete child chunk."""
# check dataset
dataset_id = str(dataset_id)
tenant_id = str(tenant_id)
dataset = db.session.query(Dataset).filter(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first()
if not dataset:
raise NotFound("Dataset not found.")
# check document
document_id = str(document_id)
document = DocumentService.get_document(dataset.id, document_id)
if not document:
raise NotFound("Document not found.")
# check segment
segment_id = str(segment_id)
segment = SegmentService.get_segment_by_id(segment_id=segment_id, tenant_id=current_user.current_tenant_id)
if not segment:
raise NotFound("Segment not found.")
# check child chunk
child_chunk_id = str(child_chunk_id)
child_chunk = SegmentService.get_child_chunk_by_id(
child_chunk_id=child_chunk_id, tenant_id=current_user.current_tenant_id
)
if not child_chunk:
raise NotFound("Child chunk not found.")
try:
SegmentService.delete_child_chunk(child_chunk, dataset)
except ChildChunkDeleteIndexServiceError as e:
raise ChildChunkDeleteIndexError(str(e))
return 204
@cloud_edition_billing_resource_check("vector_space", "dataset")
@cloud_edition_billing_knowledge_limit_check("add_segment", "dataset")
def patch(self, tenant_id, dataset_id, document_id, segment_id, child_chunk_id):
"""Update child chunk."""
# check dataset
dataset_id = str(dataset_id)
tenant_id = str(tenant_id)
dataset = db.session.query(Dataset).filter(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first()
if not dataset:
raise NotFound("Dataset not found.")
# get document
document = DocumentService.get_document(dataset_id, document_id)
if not document:
raise NotFound("Document not found.")
# get segment
segment = SegmentService.get_segment_by_id(segment_id=segment_id, tenant_id=current_user.current_tenant_id)
if not segment:
raise NotFound("Segment not found.")
# get child chunk
child_chunk = SegmentService.get_child_chunk_by_id(
child_chunk_id=child_chunk_id, tenant_id=current_user.current_tenant_id
)
if not child_chunk:
raise NotFound("Child chunk not found.")
# validate args
parser = reqparse.RequestParser()
parser.add_argument("content", type=str, required=True, nullable=False, location="json")
args = parser.parse_args()
try:
child_chunk = SegmentService.update_child_chunk(
args.get("content"), child_chunk, segment, document, dataset
)
except ChildChunkIndexingServiceError as e:
raise ChildChunkIndexingError(str(e))
return {"data": marshal(child_chunk, child_chunk_fields)}, 200
api.add_resource(SegmentApi, "/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/segments")
api.add_resource(
DatasetSegmentApi, "/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/segments/<uuid:segment_id>"
)
api.add_resource(
ChildChunkApi, "/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/segments/<uuid:segment_id>/child_chunks"
)
api.add_resource(
DatasetChildChunkApi,
"/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/segments/<uuid:segment_id>/child_chunks/<uuid:child_chunk_id>",
)

View File

@@ -0,0 +1,54 @@
from werkzeug.exceptions import NotFound
from controllers.service_api import api
from controllers.service_api.wraps import (
DatasetApiResource,
)
from core.file import helpers as file_helpers
from extensions.ext_database import db
from models.dataset import Dataset
from models.model import UploadFile
from services.dataset_service import DocumentService
class UploadFileApi(DatasetApiResource):
def get(self, tenant_id, dataset_id, document_id):
"""Get upload file."""
# check dataset
dataset_id = str(dataset_id)
tenant_id = str(tenant_id)
dataset = db.session.query(Dataset).filter(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first()
if not dataset:
raise NotFound("Dataset not found.")
# check document
document_id = str(document_id)
document = DocumentService.get_document(dataset.id, document_id)
if not document:
raise NotFound("Document not found.")
# check upload file
if document.data_source_type != "upload_file":
raise ValueError(f"Document data source type ({document.data_source_type}) is not upload_file.")
data_source_info = document.data_source_info_dict
if data_source_info and "upload_file_id" in data_source_info:
file_id = data_source_info["upload_file_id"]
upload_file = db.session.query(UploadFile).filter(UploadFile.id == file_id).first()
if not upload_file:
raise NotFound("UploadFile not found.")
else:
raise ValueError("Upload file id not found in document data source info.")
url = file_helpers.get_signed_file_url(upload_file_id=upload_file.id)
return {
"id": upload_file.id,
"name": upload_file.name,
"size": upload_file.size,
"extension": upload_file.extension,
"url": url,
"download_url": f"{url}&as_attachment=true",
"mime_type": upload_file.mime_type,
"created_by": upload_file.created_by,
"created_at": upload_file.created_at.timestamp(),
}, 200
api.add_resource(UploadFileApi, "/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/upload-file")

View File

@@ -0,0 +1,16 @@
from flask_restful import Resource
from configs import dify_config
from controllers.service_api import api
class IndexApi(Resource):
def get(self):
return {
"welcome": "Dify OpenAPI",
"api_version": "v1",
"server_version": dify_config.CURRENT_VERSION,
}
api.add_resource(IndexApi, "/")

View File

@@ -0,0 +1,21 @@
from flask_login import current_user
from flask_restful import Resource
from controllers.service_api import api
from controllers.service_api.wraps import validate_dataset_token
from core.model_runtime.utils.encoders import jsonable_encoder
from services.model_provider_service import ModelProviderService
class ModelProviderAvailableModelApi(Resource):
@validate_dataset_token
def get(self, _, model_type):
tenant_id = current_user.current_tenant_id
model_provider_service = ModelProviderService()
models = model_provider_service.get_models_by_model_type(tenant_id=tenant_id, model_type=model_type)
return jsonable_encoder({"data": models})
api.add_resource(ModelProviderAvailableModelApi, "/workspaces/current/models/model-types/<string:model_type>")

View File

@@ -0,0 +1,314 @@
import time
from collections.abc import Callable
from datetime import UTC, datetime, timedelta
from enum import Enum
from functools import wraps
from typing import Optional
from flask import current_app, request
from flask_login import user_logged_in # type: ignore
from flask_restful import Resource
from pydantic import BaseModel
from sqlalchemy import select, update
from sqlalchemy.orm import Session
from werkzeug.exceptions import Forbidden, Unauthorized
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from libs.login import _get_user
from models.account import Account, Tenant, TenantAccountJoin, TenantStatus
from models.dataset import RateLimitLog
from models.model import ApiToken, App, EndUser
from services.feature_service import FeatureService
class WhereisUserArg(Enum):
"""
Enum for whereis_user_arg.
"""
QUERY = "query"
JSON = "json"
FORM = "form"
class FetchUserArg(BaseModel):
fetch_from: WhereisUserArg
required: bool = False
def validate_app_token(view: Optional[Callable] = None, *, fetch_user_arg: Optional[FetchUserArg] = None):
def decorator(view_func):
@wraps(view_func)
def decorated_view(*args, **kwargs):
api_token = validate_and_get_api_token("app")
app_model = db.session.query(App).filter(App.id == api_token.app_id).first()
if not app_model:
raise Forbidden("The app no longer exists.")
if app_model.status != "normal":
raise Forbidden("The app's status is abnormal.")
if not app_model.enable_api:
raise Forbidden("The app's API service has been disabled.")
tenant = db.session.query(Tenant).filter(Tenant.id == app_model.tenant_id).first()
if tenant is None:
raise ValueError("Tenant does not exist.")
if tenant.status == TenantStatus.ARCHIVE:
raise Forbidden("The workspace's status is archived.")
tenant_account_join = (
db.session.query(Tenant, TenantAccountJoin)
.filter(Tenant.id == api_token.tenant_id)
.filter(TenantAccountJoin.tenant_id == Tenant.id)
.filter(TenantAccountJoin.role.in_(["owner"]))
.filter(Tenant.status == TenantStatus.NORMAL)
.one_or_none()
) # TODO: only owner information is required, so only one is returned.
if tenant_account_join:
tenant, ta = tenant_account_join
account = db.session.query(Account).filter(Account.id == ta.account_id).first()
# Login admin
if account:
account.current_tenant = tenant
current_app.login_manager._update_request_context_with_user(account) # type: ignore
user_logged_in.send(current_app._get_current_object(), user=_get_user()) # type: ignore
else:
raise Unauthorized("Tenant owner account does not exist.")
else:
raise Unauthorized("Tenant does not exist.")
kwargs["app_model"] = app_model
if fetch_user_arg:
if fetch_user_arg.fetch_from == WhereisUserArg.QUERY:
user_id = request.args.get("user")
elif fetch_user_arg.fetch_from == WhereisUserArg.JSON:
user_id = request.get_json().get("user")
elif fetch_user_arg.fetch_from == WhereisUserArg.FORM:
user_id = request.form.get("user")
else:
# use default-user
user_id = None
if not user_id and fetch_user_arg.required:
raise ValueError("Arg user must be provided.")
if user_id:
user_id = str(user_id)
kwargs["end_user"] = create_or_update_end_user_for_user_id(app_model, user_id)
return view_func(*args, **kwargs)
return decorated_view
if view is None:
return decorator
else:
return decorator(view)
def cloud_edition_billing_resource_check(resource: str, api_token_type: str):
def interceptor(view):
def decorated(*args, **kwargs):
api_token = validate_and_get_api_token(api_token_type)
features = FeatureService.get_features(api_token.tenant_id)
if features.billing.enabled:
members = features.members
apps = features.apps
vector_space = features.vector_space
documents_upload_quota = features.documents_upload_quota
if resource == "members" and 0 < members.limit <= members.size:
raise Forbidden("The number of members has reached the limit of your subscription.")
elif resource == "apps" and 0 < apps.limit <= apps.size:
raise Forbidden("The number of apps has reached the limit of your subscription.")
elif resource == "vector_space" and 0 < vector_space.limit <= vector_space.size:
raise Forbidden("The capacity of the vector space has reached the limit of your subscription.")
elif resource == "documents" and 0 < documents_upload_quota.limit <= documents_upload_quota.size:
raise Forbidden("The number of documents has reached the limit of your subscription.")
else:
return view(*args, **kwargs)
return view(*args, **kwargs)
return decorated
return interceptor
def cloud_edition_billing_knowledge_limit_check(resource: str, api_token_type: str):
def interceptor(view):
@wraps(view)
def decorated(*args, **kwargs):
api_token = validate_and_get_api_token(api_token_type)
features = FeatureService.get_features(api_token.tenant_id)
if features.billing.enabled:
if resource == "add_segment":
if features.billing.subscription.plan == "sandbox":
raise Forbidden(
"To unlock this feature and elevate your Dify experience, please upgrade to a paid plan."
)
else:
return view(*args, **kwargs)
return view(*args, **kwargs)
return decorated
return interceptor
def cloud_edition_billing_rate_limit_check(resource: str, api_token_type: str):
def interceptor(view):
@wraps(view)
def decorated(*args, **kwargs):
api_token = validate_and_get_api_token(api_token_type)
if resource == "knowledge":
knowledge_rate_limit = FeatureService.get_knowledge_rate_limit(api_token.tenant_id)
if knowledge_rate_limit.enabled:
current_time = int(time.time() * 1000)
key = f"rate_limit_{api_token.tenant_id}"
redis_client.zadd(key, {current_time: current_time})
redis_client.zremrangebyscore(key, 0, current_time - 60000)
request_count = redis_client.zcard(key)
if request_count > knowledge_rate_limit.limit:
# add ratelimit record
rate_limit_log = RateLimitLog(
tenant_id=api_token.tenant_id,
subscription_plan=knowledge_rate_limit.subscription_plan,
operation="knowledge",
)
db.session.add(rate_limit_log)
db.session.commit()
raise Forbidden(
"Sorry, you have reached the knowledge base request rate limit of your subscription."
)
return view(*args, **kwargs)
return decorated
return interceptor
def validate_dataset_token(view=None):
def decorator(view):
@wraps(view)
def decorated(*args, **kwargs):
api_token = validate_and_get_api_token("dataset")
tenant_account_join = (
db.session.query(Tenant, TenantAccountJoin)
.filter(Tenant.id == api_token.tenant_id)
.filter(TenantAccountJoin.tenant_id == Tenant.id)
.filter(TenantAccountJoin.role.in_(["owner"]))
.filter(Tenant.status == TenantStatus.NORMAL)
.one_or_none()
) # TODO: only owner information is required, so only one is returned.
if tenant_account_join:
tenant, ta = tenant_account_join
account = db.session.query(Account).filter(Account.id == ta.account_id).first()
# Login admin
if account:
account.current_tenant = tenant
current_app.login_manager._update_request_context_with_user(account) # type: ignore
user_logged_in.send(current_app._get_current_object(), user=_get_user()) # type: ignore
else:
raise Unauthorized("Tenant owner account does not exist.")
else:
raise Unauthorized("Tenant does not exist.")
return view(api_token.tenant_id, *args, **kwargs)
return decorated
if view:
return decorator(view)
# if view is None, it means that the decorator is used without parentheses
# use the decorator as a function for method_decorators
return decorator
def validate_and_get_api_token(scope: str | None = None):
"""
Validate and get API token.
"""
auth_header = request.headers.get("Authorization")
if auth_header is None or " " not in auth_header:
raise Unauthorized("Authorization header must be provided and start with 'Bearer'")
auth_scheme, auth_token = auth_header.split(None, 1)
auth_scheme = auth_scheme.lower()
if auth_scheme != "bearer":
raise Unauthorized("Authorization scheme must be 'Bearer'")
current_time = datetime.now(UTC).replace(tzinfo=None)
cutoff_time = current_time - timedelta(minutes=1)
with Session(db.engine, expire_on_commit=False) as session:
update_stmt = (
update(ApiToken)
.where(
ApiToken.token == auth_token,
(ApiToken.last_used_at.is_(None) | (ApiToken.last_used_at < cutoff_time)),
ApiToken.type == scope,
)
.values(last_used_at=current_time)
.returning(ApiToken)
)
result = session.execute(update_stmt)
api_token = result.scalar_one_or_none()
if not api_token:
stmt = select(ApiToken).where(ApiToken.token == auth_token, ApiToken.type == scope)
api_token = session.scalar(stmt)
if not api_token:
raise Unauthorized("Access token is invalid")
else:
session.commit()
return api_token
def create_or_update_end_user_for_user_id(app_model: App, user_id: Optional[str] = None) -> EndUser:
"""
Create or update session terminal based on user ID.
"""
if not user_id:
user_id = "DEFAULT-USER"
end_user = (
db.session.query(EndUser)
.filter(
EndUser.tenant_id == app_model.tenant_id,
EndUser.app_id == app_model.id,
EndUser.session_id == user_id,
EndUser.type == "service_api",
)
.first()
)
if end_user is None:
end_user = EndUser(
tenant_id=app_model.tenant_id,
app_id=app_model.id,
type="service_api",
is_anonymous=user_id == "DEFAULT-USER",
session_id=user_id,
)
db.session.add(end_user)
db.session.commit()
return end_user
class DatasetApiResource(Resource):
method_decorators = [validate_dataset_token]