Initial commit

This commit is contained in:
2025-10-14 14:17:21 +08:00
commit ac715a8b88
35011 changed files with 3834178 additions and 0 deletions

View File

@@ -0,0 +1,19 @@
.env
*.env.*
storage/generate_files/*
storage/privkeys/*
storage/tools/*
storage/upload_files/*
# Logs
logs
*.log*
# jetbrains
.idea
.mypy_cache
.ruff_cache
# venv
.venv

490
dify_1.4.0/api/.env.example Normal file
View File

@@ -0,0 +1,490 @@
# Your App secret key will be used for securely signing the session cookie
# Make sure you are changing this key for your deployment with a strong key.
# You can generate a strong key using `openssl rand -base64 42`.
# Alternatively you can set it with `SECRET_KEY` environment variable.
SECRET_KEY=
# Console API base URL
CONSOLE_API_URL=http://127.0.0.1:5001
CONSOLE_WEB_URL=http://127.0.0.1:3000
# Service API base URL
SERVICE_API_URL=http://127.0.0.1:5001
# Web APP base URL
APP_WEB_URL=http://127.0.0.1:3000
# Files URL
FILES_URL=http://127.0.0.1:5001
# The time in seconds after the signature is rejected
FILES_ACCESS_TIMEOUT=300
# Access token expiration time in minutes
ACCESS_TOKEN_EXPIRE_MINUTES=60
# Refresh token expiration time in days
REFRESH_TOKEN_EXPIRE_DAYS=30
# redis configuration
REDIS_HOST=localhost
REDIS_PORT=6379
REDIS_USERNAME=
REDIS_PASSWORD=difyai123456
REDIS_USE_SSL=false
REDIS_DB=0
# redis Sentinel configuration.
REDIS_USE_SENTINEL=false
REDIS_SENTINELS=
REDIS_SENTINEL_SERVICE_NAME=
REDIS_SENTINEL_USERNAME=
REDIS_SENTINEL_PASSWORD=
REDIS_SENTINEL_SOCKET_TIMEOUT=0.1
# redis Cluster configuration.
REDIS_USE_CLUSTERS=false
REDIS_CLUSTERS=
REDIS_CLUSTERS_PASSWORD=
# celery configuration
CELERY_BROKER_URL=redis://:difyai123456@localhost:${REDIS_PORT}/1
# PostgreSQL database configuration
DB_USERNAME=postgres
DB_PASSWORD=difyai123456
DB_HOST=localhost
DB_PORT=5432
DB_DATABASE=dify
# Storage configuration
# use for store upload files, private keys...
# storage type: opendal, s3, aliyun-oss, azure-blob, baidu-obs, google-storage, huawei-obs, oci-storage, tencent-cos, volcengine-tos, supabase
STORAGE_TYPE=opendal
# Apache OpenDAL storage configuration, refer to https://github.com/apache/opendal
OPENDAL_SCHEME=fs
OPENDAL_FS_ROOT=storage
# S3 Storage configuration
S3_USE_AWS_MANAGED_IAM=false
S3_ENDPOINT=https://your-bucket-name.storage.s3.cloudflare.com
S3_BUCKET_NAME=your-bucket-name
S3_ACCESS_KEY=your-access-key
S3_SECRET_KEY=your-secret-key
S3_REGION=your-region
# Azure Blob Storage configuration
AZURE_BLOB_ACCOUNT_NAME=your-account-name
AZURE_BLOB_ACCOUNT_KEY=your-account-key
AZURE_BLOB_CONTAINER_NAME=your-container-name
AZURE_BLOB_ACCOUNT_URL=https://<your_account_name>.blob.core.windows.net
# Aliyun oss Storage configuration
ALIYUN_OSS_BUCKET_NAME=your-bucket-name
ALIYUN_OSS_ACCESS_KEY=your-access-key
ALIYUN_OSS_SECRET_KEY=your-secret-key
ALIYUN_OSS_ENDPOINT=your-endpoint
ALIYUN_OSS_AUTH_VERSION=v1
ALIYUN_OSS_REGION=your-region
# Don't start with '/'. OSS doesn't support leading slash in object names.
ALIYUN_OSS_PATH=your-path
# Google Storage configuration
GOOGLE_STORAGE_BUCKET_NAME=your-bucket-name
GOOGLE_STORAGE_SERVICE_ACCOUNT_JSON_BASE64=your-google-service-account-json-base64-string
# Tencent COS Storage configuration
TENCENT_COS_BUCKET_NAME=your-bucket-name
TENCENT_COS_SECRET_KEY=your-secret-key
TENCENT_COS_SECRET_ID=your-secret-id
TENCENT_COS_REGION=your-region
TENCENT_COS_SCHEME=your-scheme
# Huawei OBS Storage Configuration
HUAWEI_OBS_BUCKET_NAME=your-bucket-name
HUAWEI_OBS_SECRET_KEY=your-secret-key
HUAWEI_OBS_ACCESS_KEY=your-access-key
HUAWEI_OBS_SERVER=your-server-url
# Baidu OBS Storage Configuration
BAIDU_OBS_BUCKET_NAME=your-bucket-name
BAIDU_OBS_SECRET_KEY=your-secret-key
BAIDU_OBS_ACCESS_KEY=your-access-key
BAIDU_OBS_ENDPOINT=your-server-url
# OCI Storage configuration
OCI_ENDPOINT=your-endpoint
OCI_BUCKET_NAME=your-bucket-name
OCI_ACCESS_KEY=your-access-key
OCI_SECRET_KEY=your-secret-key
OCI_REGION=your-region
# Volcengine tos Storage configuration
VOLCENGINE_TOS_ENDPOINT=your-endpoint
VOLCENGINE_TOS_BUCKET_NAME=your-bucket-name
VOLCENGINE_TOS_ACCESS_KEY=your-access-key
VOLCENGINE_TOS_SECRET_KEY=your-secret-key
VOLCENGINE_TOS_REGION=your-region
# Supabase Storage Configuration
SUPABASE_BUCKET_NAME=your-bucket-name
SUPABASE_API_KEY=your-access-key
SUPABASE_URL=your-server-url
# CORS configuration
WEB_API_CORS_ALLOW_ORIGINS=http://127.0.0.1:3000,*
CONSOLE_CORS_ALLOW_ORIGINS=http://127.0.0.1:3000,*
# Vector database configuration
# support: weaviate, qdrant, milvus, myscale, relyt, pgvecto_rs, pgvector, pgvector, chroma, opensearch, tidb_vector, couchbase, vikingdb, upstash, lindorm, oceanbase, opengauss, tablestore
VECTOR_STORE=weaviate
# Weaviate configuration
WEAVIATE_ENDPOINT=http://localhost:8080
WEAVIATE_API_KEY=WVF5YThaHlkYwhGUSmCRgsX3tD5ngdN8pkih
WEAVIATE_GRPC_ENABLED=false
WEAVIATE_BATCH_SIZE=100
# Qdrant configuration, use `http://localhost:6333` for local mode or `https://your-qdrant-cluster-url.qdrant.io` for remote mode
QDRANT_URL=http://localhost:6333
QDRANT_API_KEY=difyai123456
QDRANT_CLIENT_TIMEOUT=20
QDRANT_GRPC_ENABLED=false
QDRANT_GRPC_PORT=6334
#Couchbase configuration
COUCHBASE_CONNECTION_STRING=127.0.0.1
COUCHBASE_USER=Administrator
COUCHBASE_PASSWORD=password
COUCHBASE_BUCKET_NAME=Embeddings
COUCHBASE_SCOPE_NAME=_default
# Milvus configuration
MILVUS_URI=http://127.0.0.1:19530
MILVUS_TOKEN=
MILVUS_USER=root
MILVUS_PASSWORD=Milvus
MILVUS_ANALYZER_PARAMS=
# MyScale configuration
MYSCALE_HOST=127.0.0.1
MYSCALE_PORT=8123
MYSCALE_USER=default
MYSCALE_PASSWORD=
MYSCALE_DATABASE=default
MYSCALE_FTS_PARAMS=
# Relyt configuration
RELYT_HOST=127.0.0.1
RELYT_PORT=5432
RELYT_USER=postgres
RELYT_PASSWORD=postgres
RELYT_DATABASE=postgres
# Tencent configuration
TENCENT_VECTOR_DB_URL=http://127.0.0.1
TENCENT_VECTOR_DB_API_KEY=dify
TENCENT_VECTOR_DB_TIMEOUT=30
TENCENT_VECTOR_DB_USERNAME=dify
TENCENT_VECTOR_DB_DATABASE=dify
TENCENT_VECTOR_DB_SHARD=1
TENCENT_VECTOR_DB_REPLICAS=2
TENCENT_VECTOR_DB_ENABLE_HYBRID_SEARCH=false
# ElasticSearch configuration
ELASTICSEARCH_HOST=127.0.0.1
ELASTICSEARCH_PORT=9200
ELASTICSEARCH_USERNAME=elastic
ELASTICSEARCH_PASSWORD=elastic
# PGVECTO_RS configuration
PGVECTO_RS_HOST=localhost
PGVECTO_RS_PORT=5431
PGVECTO_RS_USER=postgres
PGVECTO_RS_PASSWORD=difyai123456
PGVECTO_RS_DATABASE=postgres
# PGVector configuration
PGVECTOR_HOST=127.0.0.1
PGVECTOR_PORT=5433
PGVECTOR_USER=postgres
PGVECTOR_PASSWORD=postgres
PGVECTOR_DATABASE=postgres
PGVECTOR_MIN_CONNECTION=1
PGVECTOR_MAX_CONNECTION=5
# TableStore Vector configuration
TABLESTORE_ENDPOINT=https://instance-name.cn-hangzhou.ots.aliyuncs.com
TABLESTORE_INSTANCE_NAME=instance-name
TABLESTORE_ACCESS_KEY_ID=xxx
TABLESTORE_ACCESS_KEY_SECRET=xxx
# Tidb Vector configuration
TIDB_VECTOR_HOST=xxx.eu-central-1.xxx.aws.tidbcloud.com
TIDB_VECTOR_PORT=4000
TIDB_VECTOR_USER=xxx.root
TIDB_VECTOR_PASSWORD=xxxxxx
TIDB_VECTOR_DATABASE=dify
# Tidb on qdrant configuration
TIDB_ON_QDRANT_URL=http://127.0.0.1
TIDB_ON_QDRANT_API_KEY=dify
TIDB_ON_QDRANT_CLIENT_TIMEOUT=20
TIDB_ON_QDRANT_GRPC_ENABLED=false
TIDB_ON_QDRANT_GRPC_PORT=6334
TIDB_PUBLIC_KEY=dify
TIDB_PRIVATE_KEY=dify
TIDB_API_URL=http://127.0.0.1
TIDB_IAM_API_URL=http://127.0.0.1
TIDB_REGION=regions/aws-us-east-1
TIDB_PROJECT_ID=dify
TIDB_SPEND_LIMIT=100
# Chroma configuration
CHROMA_HOST=127.0.0.1
CHROMA_PORT=8000
CHROMA_TENANT=default_tenant
CHROMA_DATABASE=default_database
CHROMA_AUTH_PROVIDER=chromadb.auth.token_authn.TokenAuthenticationServerProvider
CHROMA_AUTH_CREDENTIALS=difyai123456
# AnalyticDB configuration
ANALYTICDB_KEY_ID=your-ak
ANALYTICDB_KEY_SECRET=your-sk
ANALYTICDB_REGION_ID=cn-hangzhou
ANALYTICDB_INSTANCE_ID=gp-ab123456
ANALYTICDB_ACCOUNT=testaccount
ANALYTICDB_PASSWORD=testpassword
ANALYTICDB_NAMESPACE=dify
ANALYTICDB_NAMESPACE_PASSWORD=difypassword
ANALYTICDB_HOST=gp-test.aliyuncs.com
ANALYTICDB_PORT=5432
ANALYTICDB_MIN_CONNECTION=1
ANALYTICDB_MAX_CONNECTION=5
# OpenSearch configuration
OPENSEARCH_HOST=127.0.0.1
OPENSEARCH_PORT=9200
OPENSEARCH_USER=admin
OPENSEARCH_PASSWORD=admin
OPENSEARCH_SECURE=true
# Baidu configuration
BAIDU_VECTOR_DB_ENDPOINT=http://127.0.0.1:5287
BAIDU_VECTOR_DB_CONNECTION_TIMEOUT_MS=30000
BAIDU_VECTOR_DB_ACCOUNT=root
BAIDU_VECTOR_DB_API_KEY=dify
BAIDU_VECTOR_DB_DATABASE=dify
BAIDU_VECTOR_DB_SHARD=1
BAIDU_VECTOR_DB_REPLICAS=3
# Upstash configuration
UPSTASH_VECTOR_URL=your-server-url
UPSTASH_VECTOR_TOKEN=your-access-token
# ViKingDB configuration
VIKINGDB_ACCESS_KEY=your-ak
VIKINGDB_SECRET_KEY=your-sk
VIKINGDB_REGION=cn-shanghai
VIKINGDB_HOST=api-vikingdb.xxx.volces.com
VIKINGDB_SCHEMA=http
VIKINGDB_CONNECTION_TIMEOUT=30
VIKINGDB_SOCKET_TIMEOUT=30
# Lindorm configuration
LINDORM_URL=http://ld-*******************-proxy-search-pub.lindorm.aliyuncs.com:30070
LINDORM_USERNAME=admin
LINDORM_PASSWORD=admin
USING_UGC_INDEX=False
LINDORM_QUERY_TIMEOUT=1
# OceanBase Vector configuration
OCEANBASE_VECTOR_HOST=127.0.0.1
OCEANBASE_VECTOR_PORT=2881
OCEANBASE_VECTOR_USER=root@test
OCEANBASE_VECTOR_PASSWORD=difyai123456
OCEANBASE_VECTOR_DATABASE=test
OCEANBASE_MEMORY_LIMIT=6G
OCEANBASE_ENABLE_HYBRID_SEARCH=false
# openGauss configuration
OPENGAUSS_HOST=127.0.0.1
OPENGAUSS_PORT=6600
OPENGAUSS_USER=postgres
OPENGAUSS_PASSWORD=Dify@123
OPENGAUSS_DATABASE=dify
OPENGAUSS_MIN_CONNECTION=1
OPENGAUSS_MAX_CONNECTION=5
# Upload configuration
UPLOAD_FILE_SIZE_LIMIT=15
UPLOAD_FILE_BATCH_LIMIT=5
UPLOAD_IMAGE_FILE_SIZE_LIMIT=10
UPLOAD_VIDEO_FILE_SIZE_LIMIT=100
UPLOAD_AUDIO_FILE_SIZE_LIMIT=50
# Model configuration
MULTIMODAL_SEND_FORMAT=base64
PROMPT_GENERATION_MAX_TOKENS=512
CODE_GENERATION_MAX_TOKENS=1024
PLUGIN_BASED_TOKEN_COUNTING_ENABLED=false
# Mail configuration, support: resend, smtp
MAIL_TYPE=
MAIL_DEFAULT_SEND_FROM=no-reply <no-reply@dify.ai>
RESEND_API_KEY=
RESEND_API_URL=https://api.resend.com
# smtp configuration
SMTP_SERVER=smtp.gmail.com
SMTP_PORT=465
SMTP_USERNAME=123
SMTP_PASSWORD=abc
SMTP_USE_TLS=true
SMTP_OPPORTUNISTIC_TLS=false
# Sentry configuration
SENTRY_DSN=
# DEBUG
DEBUG=false
SQLALCHEMY_ECHO=false
# Notion import configuration, support public and internal
NOTION_INTEGRATION_TYPE=public
NOTION_CLIENT_SECRET=you-client-secret
NOTION_CLIENT_ID=you-client-id
NOTION_INTERNAL_SECRET=you-internal-secret
ETL_TYPE=dify
UNSTRUCTURED_API_URL=
UNSTRUCTURED_API_KEY=
SCARF_NO_ANALYTICS=true
#ssrf
SSRF_PROXY_HTTP_URL=
SSRF_PROXY_HTTPS_URL=
SSRF_DEFAULT_MAX_RETRIES=3
SSRF_DEFAULT_TIME_OUT=5
SSRF_DEFAULT_CONNECT_TIME_OUT=5
SSRF_DEFAULT_READ_TIME_OUT=5
SSRF_DEFAULT_WRITE_TIME_OUT=5
BATCH_UPLOAD_LIMIT=10
KEYWORD_DATA_SOURCE_TYPE=database
# Workflow file upload limit
WORKFLOW_FILE_UPLOAD_LIMIT=10
# CODE EXECUTION CONFIGURATION
CODE_EXECUTION_ENDPOINT=http://127.0.0.1:8194
CODE_EXECUTION_API_KEY=dify-sandbox
CODE_MAX_NUMBER=9223372036854775807
CODE_MIN_NUMBER=-9223372036854775808
CODE_MAX_STRING_LENGTH=80000
TEMPLATE_TRANSFORM_MAX_LENGTH=80000
CODE_MAX_STRING_ARRAY_LENGTH=30
CODE_MAX_OBJECT_ARRAY_LENGTH=30
CODE_MAX_NUMBER_ARRAY_LENGTH=1000
# API Tool configuration
API_TOOL_DEFAULT_CONNECT_TIMEOUT=10
API_TOOL_DEFAULT_READ_TIMEOUT=60
# HTTP Node configuration
HTTP_REQUEST_MAX_CONNECT_TIMEOUT=300
HTTP_REQUEST_MAX_READ_TIMEOUT=600
HTTP_REQUEST_MAX_WRITE_TIMEOUT=600
HTTP_REQUEST_NODE_MAX_BINARY_SIZE=10485760
HTTP_REQUEST_NODE_MAX_TEXT_SIZE=1048576
HTTP_REQUEST_NODE_SSL_VERIFY=True
# Respect X-* headers to redirect clients
RESPECT_XFORWARD_HEADERS_ENABLED=false
# Log file path
LOG_FILE=
# Log file max size, the unit is MB
LOG_FILE_MAX_SIZE=20
# Log file max backup count
LOG_FILE_BACKUP_COUNT=5
# Log dateformat
LOG_DATEFORMAT=%Y-%m-%d %H:%M:%S
# Log Timezone
LOG_TZ=UTC
# Log format
LOG_FORMAT=%(asctime)s,%(msecs)d %(levelname)-2s [%(filename)s:%(lineno)d] %(req_id)s %(message)s
# Indexing configuration
INDEXING_MAX_SEGMENTATION_TOKENS_LENGTH=4000
# Workflow runtime configuration
WORKFLOW_MAX_EXECUTION_STEPS=500
WORKFLOW_MAX_EXECUTION_TIME=1200
WORKFLOW_CALL_MAX_DEPTH=5
WORKFLOW_PARALLEL_DEPTH_LIMIT=3
MAX_VARIABLE_SIZE=204800
# Workflow storage configuration
# Options: rdbms, hybrid
# rdbms: Use only the relational database (default)
# hybrid: Save new data to object storage, read from both object storage and RDBMS
WORKFLOW_NODE_EXECUTION_STORAGE=rdbms
# App configuration
APP_MAX_EXECUTION_TIME=1200
APP_MAX_ACTIVE_REQUESTS=0
# Celery beat configuration
CELERY_BEAT_SCHEDULER_TIME=1
# Position configuration
POSITION_TOOL_PINS=
POSITION_TOOL_INCLUDES=
POSITION_TOOL_EXCLUDES=
POSITION_PROVIDER_PINS=
POSITION_PROVIDER_INCLUDES=
POSITION_PROVIDER_EXCLUDES=
# Plugin configuration
PLUGIN_DAEMON_KEY=lYkiYYT6owG+71oLerGzA7GXCgOT++6ovaezWAjpCjf+Sjc3ZtU+qUEi
PLUGIN_DAEMON_URL=http://127.0.0.1:5002
PLUGIN_REMOTE_INSTALL_PORT=5003
PLUGIN_REMOTE_INSTALL_HOST=localhost
PLUGIN_MAX_PACKAGE_SIZE=15728640
INNER_API_KEY_FOR_PLUGIN=QaHbTe77CtuXmsfyhR7+vRjI/+XbV1AaFy691iy+kGDv2Jvy0/eAh8Y1
# Marketplace configuration
MARKETPLACE_ENABLED=true
MARKETPLACE_API_URL=https://marketplace.dify.ai
# Endpoint configuration
ENDPOINT_URL_TEMPLATE=http://localhost:5002/e/{hook_id}
# Reset password token expiry minutes
RESET_PASSWORD_TOKEN_EXPIRY_MINUTES=5
CREATE_TIDB_SERVICE_JOB_ENABLED=false
# Maximum number of submitted thread count in a ThreadPool for parallel node execution
MAX_SUBMIT_COUNT=100
# Lockout duration in seconds
LOGIN_LOCKOUT_DURATION=86400
# Enable OpenTelemetry
ENABLE_OTEL=false
OTLP_BASE_ENDPOINT=http://localhost:4318
OTLP_API_KEY=
OTEL_EXPORTER_OTLP_PROTOCOL=
OTEL_EXPORTER_TYPE=otlp
OTEL_SAMPLING_RATE=0.1
OTEL_BATCH_EXPORT_SCHEDULE_DELAY=5000
OTEL_MAX_QUEUE_SIZE=2048
OTEL_MAX_EXPORT_BATCH_SIZE=512
OTEL_METRIC_EXPORT_INTERVAL=60000
OTEL_BATCH_EXPORT_TIMEOUT=10000
OTEL_METRIC_EXPORT_TIMEOUT=30000
# Prevent Clickjacking
ALLOW_EMBED=false

101
dify_1.4.0/api/.ruff.toml Normal file
View File

@@ -0,0 +1,101 @@
exclude = [
"migrations/*",
]
line-length = 120
[format]
quote-style = "double"
[lint]
preview = false
select = [
"B", # flake8-bugbear rules
"C4", # flake8-comprehensions
"E", # pycodestyle E rules
"F", # pyflakes rules
"FURB", # refurb rules
"I", # isort rules
"N", # pep8-naming
"PT", # flake8-pytest-style rules
"PLC0208", # iteration-over-set
"PLC0414", # useless-import-alias
"PLE0604", # invalid-all-object
"PLE0605", # invalid-all-format
"PLR0402", # manual-from-import
"PLR1711", # useless-return
"PLR1714", # repeated-equality-comparison
"RUF013", # implicit-optional
"RUF019", # unnecessary-key-check
"RUF100", # unused-noqa
"RUF101", # redirected-noqa
"RUF200", # invalid-pyproject-toml
"RUF022", # unsorted-dunder-all
"S506", # unsafe-yaml-load
"SIM", # flake8-simplify rules
"TRY400", # error-instead-of-exception
"TRY401", # verbose-log-message
"UP", # pyupgrade rules
"W191", # tab-indentation
"W605", # invalid-escape-sequence
# security related linting rules
# RCE proctection (sort of)
"S102", # exec-builtin, disallow use of `exec`
"S307", # suspicious-eval-usage, disallow use of `eval` and `ast.literal_eval`
"S301", # suspicious-pickle-usage, disallow use of `pickle` and its wrappers.
"S302", # suspicious-marshal-usage, disallow use of `marshal` module
]
ignore = [
"E402", # module-import-not-at-top-of-file
"E711", # none-comparison
"E712", # true-false-comparison
"E721", # type-comparison
"E722", # bare-except
"F821", # undefined-name
"F841", # unused-variable
"FURB113", # repeated-append
"FURB152", # math-constant
"UP007", # non-pep604-annotation
"UP032", # f-string
"UP045", # non-pep604-annotation-optional
"B005", # strip-with-multi-characters
"B006", # mutable-argument-default
"B007", # unused-loop-control-variable
"B026", # star-arg-unpacking-after-keyword-arg
"B903", # class-as-data-structure
"B904", # raise-without-from-inside-except
"B905", # zip-without-explicit-strict
"N806", # non-lowercase-variable-in-function
"N815", # mixed-case-variable-in-class-scope
"PT011", # pytest-raises-too-broad
"SIM102", # collapsible-if
"SIM103", # needless-bool
"SIM105", # suppressible-exception
"SIM107", # return-in-try-except-finally
"SIM108", # if-else-block-instead-of-if-exp
"SIM113", # enumerate-for-loop
"SIM117", # multiple-with-statements
"SIM210", # if-expr-with-true-false
]
[lint.per-file-ignores]
"__init__.py" = [
"F401", # unused-import
"F811", # redefined-while-unused
]
"configs/*" = [
"N802", # invalid-function-name
]
"libs/gmpy2_pkcs10aep_cipher.py" = [
"N803", # invalid-argument-name
]
"tests/*" = [
"F811", # redefined-while-unused
]
[lint.pyflakes]
allowed-unused-imports = [
"_pytest.monkeypatch",
"tests.integration_tests",
"tests.unit_tests",
]

79
dify_1.4.0/api/Dockerfile Normal file
View File

@@ -0,0 +1,79 @@
# base image
FROM python:3.12-slim-bookworm AS base
WORKDIR /app/api
# Install uv
ENV UV_VERSION=0.6.14
RUN pip install --no-cache-dir uv==${UV_VERSION}
FROM base AS packages
# if you located in China, you can use aliyun mirror to speed up
# RUN sed -i 's@deb.debian.org@mirrors.aliyun.com@g' /etc/apt/sources.list.d/debian.sources
RUN apt-get update \
&& apt-get install -y --no-install-recommends gcc g++ libc-dev libffi-dev libgmp-dev libmpfr-dev libmpc-dev
# Install Python dependencies
COPY pyproject.toml uv.lock ./
RUN uv sync --locked
# production stage
FROM base AS production
ENV FLASK_APP=app.py
ENV EDITION=SELF_HOSTED
ENV DEPLOY_ENV=PRODUCTION
ENV CONSOLE_API_URL=http://127.0.0.1:5001
ENV CONSOLE_WEB_URL=http://127.0.0.1:3000
ENV SERVICE_API_URL=http://127.0.0.1:5001
ENV APP_WEB_URL=http://127.0.0.1:3000
EXPOSE 5001
# set timezone
ENV TZ=UTC
WORKDIR /app/api
RUN \
apt-get update \
# Install dependencies
&& apt-get install -y --no-install-recommends \
# basic environment
curl nodejs libgmp-dev libmpfr-dev libmpc-dev \
# For Security
expat libldap-2.5-0 perl libsqlite3-0 zlib1g \
# install a package to improve the accuracy of guessing mime type and file extension
media-types \
# install libmagic to support the use of python-magic guess MIMETYPE
libmagic1 \
&& apt-get autoremove -y \
&& rm -rf /var/lib/apt/lists/*
# Copy Python environment and packages
ENV VIRTUAL_ENV=/app/api/.venv
COPY --from=packages ${VIRTUAL_ENV} ${VIRTUAL_ENV}
ENV PATH="${VIRTUAL_ENV}/bin:${PATH}"
# Download nltk data
RUN python -c "import nltk; nltk.download('punkt'); nltk.download('averaged_perceptron_tagger')"
ENV TIKTOKEN_CACHE_DIR=/app/api/.tiktoken_cache
RUN python -c "import tiktoken; tiktoken.encoding_for_model('gpt2')"
# Copy source code
COPY . /app/api/
# Copy entrypoint
COPY docker/entrypoint.sh /entrypoint.sh
RUN chmod +x /entrypoint.sh
ARG COMMIT_SHA
ENV COMMIT_SHA=${COMMIT_SHA}
ENTRYPOINT ["/bin/bash", "/entrypoint.sh"]

93
dify_1.4.0/api/README.md Normal file
View File

@@ -0,0 +1,93 @@
# Dify Backend API
## Usage
> [!IMPORTANT]
>
> In the v1.3.0 release, `poetry` has been replaced with
> [`uv`](https://docs.astral.sh/uv/) as the package manager
> for Dify API backend service.
1. Start the docker-compose stack
The backend require some middleware, including PostgreSQL, Redis, and Weaviate, which can be started together using `docker-compose`.
```bash
cd ../docker
cp middleware.env.example middleware.env
# change the profile to other vector database if you are not using weaviate
docker compose -f docker-compose.middleware.yaml --profile weaviate -p dify up -d
cd ../api
```
2. Copy `.env.example` to `.env`
```cli
cp .env.example .env
```
3. Generate a `SECRET_KEY` in the `.env` file.
bash for Linux
```bash for Linux
sed -i "/^SECRET_KEY=/c\SECRET_KEY=$(openssl rand -base64 42)" .env
```
bash for Mac
```bash for Mac
secret_key=$(openssl rand -base64 42)
sed -i '' "/^SECRET_KEY=/c\\
SECRET_KEY=${secret_key}" .env
```
4. Create environment.
Dify API service uses [UV](https://docs.astral.sh/uv/) to manage dependencies.
First, you need to add the uv package manager, if you don't have it already.
```bash
pip install uv
# Or on macOS
brew install uv
```
5. Install dependencies
```bash
uv sync --dev
```
6. Run migrate
Before the first launch, migrate the database to the latest version.
```bash
uv run flask db upgrade
```
7. Start backend
```bash
uv run flask run --host 0.0.0.0 --port=5001 --debug
```
8. Start Dify [web](../web) service.
9. Setup your application by visiting `http://localhost:3000`.
10. If you need to handle and debug the async tasks (e.g. dataset importing and documents indexing), please start the worker service.
```bash
uv run celery -A app.celery worker -P gevent -c 1 --loglevel INFO -Q dataset,generation,mail,ops_trace,app_deletion
```
## Testing
1. Install dependencies for both the backend and the test environment
```bash
uv sync --dev
```
2. Run the tests locally with mocked system environment variables in `tool.pytest_env` section in `pyproject.toml`
```bash
uv run -P api bash dev/pytest/pytest_all_tests.sh
```

41
dify_1.4.0/api/app.py Normal file
View File

@@ -0,0 +1,41 @@
import os
import sys
def is_db_command():
if len(sys.argv) > 1 and sys.argv[0].endswith("flask") and sys.argv[1] == "db":
return True
return False
# create app
if is_db_command():
from app_factory import create_migrations_app
app = create_migrations_app()
else:
# It seems that JetBrains Python debugger does not work well with gevent,
# so we need to disable gevent in debug mode.
# If you are using debugpy and set GEVENT_SUPPORT=True, you can debug with gevent.
if (flask_debug := os.environ.get("FLASK_DEBUG", "0")) and flask_debug.lower() in {"false", "0", "no"}:
from gevent import monkey
# gevent
monkey.patch_all()
from grpc.experimental import gevent as grpc_gevent # type: ignore
# grpc gevent
grpc_gevent.init_gevent()
import psycogreen.gevent # type: ignore
psycogreen.gevent.patch_psycopg()
from app_factory import create_app
app = create_app()
celery = app.extensions["celery"]
if __name__ == "__main__":
app.run(host="0.0.0.0", port=5001)

View File

@@ -0,0 +1,110 @@
import logging
import time
from configs import dify_config
from contexts.wrapper import RecyclableContextVar
from dify_app import DifyApp
# ----------------------------
# Application Factory Function
# ----------------------------
def create_flask_app_with_configs() -> DifyApp:
"""
create a raw flask app
with configs loaded from .env file
"""
dify_app = DifyApp(__name__)
dify_app.config.from_mapping(dify_config.model_dump())
# add before request hook
@dify_app.before_request
def before_request():
# add an unique identifier to each request
RecyclableContextVar.increment_thread_recycles()
return dify_app
def create_app() -> DifyApp:
start_time = time.perf_counter()
app = create_flask_app_with_configs()
initialize_extensions(app)
end_time = time.perf_counter()
if dify_config.DEBUG:
logging.info(f"Finished create_app ({round((end_time - start_time) * 1000, 2)} ms)")
return app
def initialize_extensions(app: DifyApp):
from extensions import (
ext_app_metrics,
ext_blueprints,
ext_celery,
ext_code_based_extension,
ext_commands,
ext_compress,
ext_database,
ext_hosting_provider,
ext_import_modules,
ext_logging,
ext_login,
ext_mail,
ext_migrate,
ext_otel,
ext_proxy_fix,
ext_redis,
ext_sentry,
ext_set_secretkey,
ext_storage,
ext_timezone,
ext_warnings,
)
extensions = [
ext_timezone,
ext_logging,
ext_warnings,
ext_import_modules,
ext_set_secretkey,
ext_compress,
ext_code_based_extension,
ext_database,
ext_app_metrics,
ext_migrate,
ext_redis,
ext_storage,
ext_celery,
ext_login,
ext_mail,
ext_hosting_provider,
ext_sentry,
ext_proxy_fix,
ext_blueprints,
ext_commands,
ext_otel,
]
for ext in extensions:
short_name = ext.__name__.split(".")[-1]
is_enabled = ext.is_enabled() if hasattr(ext, "is_enabled") else True
if not is_enabled:
if dify_config.DEBUG:
logging.info(f"Skipped {short_name}")
continue
start_time = time.perf_counter()
ext.init_app(app)
end_time = time.perf_counter()
if dify_config.DEBUG:
logging.info(f"Loaded {short_name} ({round((end_time - start_time) * 1000, 2)} ms)")
def create_migrations_app():
app = create_flask_app_with_configs()
from extensions import ext_database, ext_migrate
# Initialize only required extensions
ext_database.init_app(app)
ext_migrate.init_app(app)
return app

1152
dify_1.4.0/api/commands.py Normal file

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,3 @@
from .app_config import DifyConfig
dify_config = DifyConfig()

View File

@@ -0,0 +1,102 @@
import logging
from typing import Any
from pydantic.fields import FieldInfo
from pydantic_settings import BaseSettings, PydanticBaseSettingsSource, SettingsConfigDict
from .deploy import DeploymentConfig
from .enterprise import EnterpriseFeatureConfig
from .extra import ExtraServiceConfig
from .feature import FeatureConfig
from .middleware import MiddlewareConfig
from .observability import ObservabilityConfig
from .packaging import PackagingInfo
from .remote_settings_sources import RemoteSettingsSource, RemoteSettingsSourceConfig, RemoteSettingsSourceName
from .remote_settings_sources.apollo import ApolloSettingsSource
from .remote_settings_sources.nacos import NacosSettingsSource
logger = logging.getLogger(__name__)
class RemoteSettingsSourceFactory(PydanticBaseSettingsSource):
def __init__(self, settings_cls: type[BaseSettings]):
super().__init__(settings_cls)
def get_field_value(self, field: FieldInfo, field_name: str) -> tuple[Any, str, bool]:
raise NotImplementedError
def __call__(self) -> dict[str, Any]:
current_state = self.current_state
remote_source_name = current_state.get("REMOTE_SETTINGS_SOURCE_NAME")
if not remote_source_name:
return {}
remote_source: RemoteSettingsSource | None = None
match remote_source_name:
case RemoteSettingsSourceName.APOLLO:
remote_source = ApolloSettingsSource(current_state)
case RemoteSettingsSourceName.NACOS:
remote_source = NacosSettingsSource(current_state)
case _:
logger.warning(f"Unsupported remote source: {remote_source_name}")
return {}
d: dict[str, Any] = {}
for field_name, field in self.settings_cls.model_fields.items():
field_value, field_key, value_is_complex = remote_source.get_field_value(field, field_name)
field_value = remote_source.prepare_field_value(field_name, field, field_value, value_is_complex)
if field_value is not None:
d[field_key] = field_value
return d
class DifyConfig(
# Packaging info
PackagingInfo,
# Deployment configs
DeploymentConfig,
# Feature configs
FeatureConfig,
# Middleware configs
MiddlewareConfig,
# Extra service configs
ExtraServiceConfig,
# Observability configs
ObservabilityConfig,
# Remote source configs
RemoteSettingsSourceConfig,
# Enterprise feature configs
# **Before using, please contact business@dify.ai by email to inquire about licensing matters.**
EnterpriseFeatureConfig,
):
model_config = SettingsConfigDict(
# read from dotenv format config file
env_file=".env",
env_file_encoding="utf-8",
# ignore extra attributes
extra="ignore",
)
# Before adding any config,
# please consider to arrange it in the proper config group of existed or added
# for better readability and maintainability.
# Thanks for your concentration and consideration.
@classmethod
def settings_customise_sources(
cls,
settings_cls: type[BaseSettings],
init_settings: PydanticBaseSettingsSource,
env_settings: PydanticBaseSettingsSource,
dotenv_settings: PydanticBaseSettingsSource,
file_secret_settings: PydanticBaseSettingsSource,
) -> tuple[PydanticBaseSettingsSource, ...]:
return (
init_settings,
env_settings,
RemoteSettingsSourceFactory(settings_cls),
dotenv_settings,
file_secret_settings,
)

View File

@@ -0,0 +1,28 @@
from pydantic import Field
from pydantic_settings import BaseSettings
class DeploymentConfig(BaseSettings):
"""
Configuration settings for application deployment
"""
APPLICATION_NAME: str = Field(
description="Name of the application, used for identification and logging purposes",
default="langgenius/dify",
)
DEBUG: bool = Field(
description="Enable debug mode for additional logging and development features",
default=False,
)
EDITION: str = Field(
description="Deployment edition of the application (e.g., 'SELF_HOSTED', 'CLOUD')",
default="SELF_HOSTED",
)
DEPLOY_ENV: str = Field(
description="Deployment environment (e.g., 'PRODUCTION', 'DEVELOPMENT'), default to PRODUCTION",
default="PRODUCTION",
)

View File

@@ -0,0 +1,20 @@
from pydantic import Field
from pydantic_settings import BaseSettings
class EnterpriseFeatureConfig(BaseSettings):
"""
Configuration for enterprise-level features.
**Before using, please contact business@dify.ai by email to inquire about licensing matters.**
"""
ENTERPRISE_ENABLED: bool = Field(
description="Enable or disable enterprise-level features."
"Before using, please contact business@dify.ai by email to inquire about licensing matters.",
default=False,
)
CAN_REPLACE_LOGO: bool = Field(
description="Allow customization of the enterprise logo.",
default=False,
)

View File

@@ -0,0 +1,10 @@
from configs.extra.notion_config import NotionConfig
from configs.extra.sentry_config import SentryConfig
class ExtraServiceConfig(
# place the configs in alphabet order
NotionConfig,
SentryConfig,
):
pass

View File

@@ -0,0 +1,36 @@
from typing import Optional
from pydantic import Field
from pydantic_settings import BaseSettings
class NotionConfig(BaseSettings):
"""
Configuration settings for Notion integration
"""
NOTION_CLIENT_ID: Optional[str] = Field(
description="Client ID for Notion API authentication. Required for OAuth 2.0 flow.",
default=None,
)
NOTION_CLIENT_SECRET: Optional[str] = Field(
description="Client secret for Notion API authentication. Required for OAuth 2.0 flow.",
default=None,
)
NOTION_INTEGRATION_TYPE: Optional[str] = Field(
description="Type of Notion integration."
" Set to 'internal' for internal integrations, or None for public integrations.",
default=None,
)
NOTION_INTERNAL_SECRET: Optional[str] = Field(
description="Secret key for internal Notion integrations. Required when NOTION_INTEGRATION_TYPE is 'internal'.",
default=None,
)
NOTION_INTEGRATION_TOKEN: Optional[str] = Field(
description="Integration token for Notion API access. Used for direct API calls without OAuth flow.",
default=None,
)

View File

@@ -0,0 +1,28 @@
from typing import Optional
from pydantic import Field, NonNegativeFloat
from pydantic_settings import BaseSettings
class SentryConfig(BaseSettings):
"""
Configuration settings for Sentry error tracking and performance monitoring
"""
SENTRY_DSN: Optional[str] = Field(
description="Sentry Data Source Name (DSN)."
" This is the unique identifier of your Sentry project, used to send events to the correct project.",
default=None,
)
SENTRY_TRACES_SAMPLE_RATE: NonNegativeFloat = Field(
description="Sample rate for Sentry performance monitoring traces."
" Value between 0.0 and 1.0, where 1.0 means 100% of traces are sent to Sentry.",
default=1.0,
)
SENTRY_PROFILES_SAMPLE_RATE: NonNegativeFloat = Field(
description="Sample rate for Sentry profiling."
" Value between 0.0 and 1.0, where 1.0 means 100% of profiles are sent to Sentry.",
default=1.0,
)

View File

@@ -0,0 +1,906 @@
from typing import Annotated, Literal, Optional
from pydantic import (
AliasChoices,
Field,
HttpUrl,
NegativeInt,
NonNegativeInt,
PositiveFloat,
PositiveInt,
computed_field,
)
from pydantic_settings import BaseSettings
from .hosted_service import HostedServiceConfig
class SecurityConfig(BaseSettings):
"""
Security-related configurations for the application
"""
SECRET_KEY: str = Field(
description="Secret key for secure session cookie signing."
"Make sure you are changing this key for your deployment with a strong key."
"Generate a strong key using `openssl rand -base64 42` or set via the `SECRET_KEY` environment variable.",
default="",
)
RESET_PASSWORD_TOKEN_EXPIRY_MINUTES: PositiveInt = Field(
description="Duration in minutes for which a password reset token remains valid",
default=5,
)
LOGIN_DISABLED: bool = Field(
description="Whether to disable login checks",
default=False,
)
ADMIN_API_KEY_ENABLE: bool = Field(
description="Whether to enable admin api key for authentication",
default=False,
)
ADMIN_API_KEY: Optional[str] = Field(
description="admin api key for authentication",
default=None,
)
class AppExecutionConfig(BaseSettings):
"""
Configuration parameters for application execution
"""
APP_MAX_EXECUTION_TIME: PositiveInt = Field(
description="Maximum allowed execution time for the application in seconds",
default=1200,
)
APP_MAX_ACTIVE_REQUESTS: NonNegativeInt = Field(
description="Maximum number of concurrent active requests per app (0 for unlimited)",
default=0,
)
APP_DAILY_RATE_LIMIT: NonNegativeInt = Field(
description="Maximum number of requests per app per day",
default=5000,
)
class CodeExecutionSandboxConfig(BaseSettings):
"""
Configuration for the code execution sandbox environment
"""
CODE_EXECUTION_ENDPOINT: HttpUrl = Field(
description="URL endpoint for the code execution service",
default=HttpUrl("http://sandbox:8194"),
)
CODE_EXECUTION_API_KEY: str = Field(
description="API key for accessing the code execution service",
default="dify-sandbox",
)
CODE_EXECUTION_CONNECT_TIMEOUT: Optional[float] = Field(
description="Connection timeout in seconds for code execution requests",
default=10.0,
)
CODE_EXECUTION_READ_TIMEOUT: Optional[float] = Field(
description="Read timeout in seconds for code execution requests",
default=60.0,
)
CODE_EXECUTION_WRITE_TIMEOUT: Optional[float] = Field(
description="Write timeout in seconds for code execution request",
default=10.0,
)
CODE_MAX_NUMBER: PositiveInt = Field(
description="Maximum allowed numeric value in code execution",
default=9223372036854775807,
)
CODE_MIN_NUMBER: NegativeInt = Field(
description="Minimum allowed numeric value in code execution",
default=-9223372036854775807,
)
CODE_MAX_DEPTH: PositiveInt = Field(
description="Maximum allowed depth for nested structures in code execution",
default=5,
)
CODE_MAX_PRECISION: PositiveInt = Field(
description="Maximum number of decimal places for floating-point numbers in code execution",
default=20,
)
CODE_MAX_STRING_LENGTH: PositiveInt = Field(
description="Maximum allowed length for strings in code execution",
default=80000,
)
CODE_MAX_STRING_ARRAY_LENGTH: PositiveInt = Field(
description="Maximum allowed length for string arrays in code execution",
default=30,
)
CODE_MAX_OBJECT_ARRAY_LENGTH: PositiveInt = Field(
description="Maximum allowed length for object arrays in code execution",
default=30,
)
CODE_MAX_NUMBER_ARRAY_LENGTH: PositiveInt = Field(
description="Maximum allowed length for numeric arrays in code execution",
default=1000,
)
class PluginConfig(BaseSettings):
"""
Plugin configs
"""
PLUGIN_DAEMON_URL: HttpUrl = Field(
description="Plugin API URL",
default=HttpUrl("http://localhost:5002"),
)
PLUGIN_DAEMON_KEY: str = Field(
description="Plugin API key",
default="plugin-api-key",
)
INNER_API_KEY_FOR_PLUGIN: str = Field(description="Inner api key for plugin", default="inner-api-key")
PLUGIN_REMOTE_INSTALL_HOST: str = Field(
description="Plugin Remote Install Host",
default="localhost",
)
PLUGIN_REMOTE_INSTALL_PORT: PositiveInt = Field(
description="Plugin Remote Install Port",
default=5003,
)
PLUGIN_MAX_PACKAGE_SIZE: PositiveInt = Field(
description="Maximum allowed size for plugin packages in bytes",
default=15728640,
)
PLUGIN_MAX_BUNDLE_SIZE: PositiveInt = Field(
description="Maximum allowed size for plugin bundles in bytes",
default=15728640 * 12,
)
class MarketplaceConfig(BaseSettings):
"""
Configuration for marketplace
"""
MARKETPLACE_ENABLED: bool = Field(
description="Enable or disable marketplace",
default=True,
)
MARKETPLACE_API_URL: HttpUrl = Field(
description="Marketplace API URL",
default=HttpUrl("https://marketplace.dify.ai"),
)
class EndpointConfig(BaseSettings):
"""
Configuration for various application endpoints and URLs
"""
CONSOLE_API_URL: str = Field(
description="Base URL for the console API,"
"used for login authentication callback or notion integration callbacks",
default="",
)
CONSOLE_WEB_URL: str = Field(
description="Base URL for the console web interface,used for frontend references and CORS configuration",
default="",
)
SERVICE_API_URL: str = Field(
description="Base URL for the service API, displayed to users for API access",
default="",
)
APP_WEB_URL: str = Field(
description="Base URL for the web application, used for frontend references",
default="",
)
ENDPOINT_URL_TEMPLATE: str = Field(
description="Template url for endpoint plugin", default="http://localhost:5002/e/{hook_id}"
)
class FileAccessConfig(BaseSettings):
"""
Configuration for file access and handling
"""
FILES_URL: str = Field(
description="Base URL for file preview or download,"
" used for frontend display and multi-model inputs"
"Url is signed and has expiration time.",
validation_alias=AliasChoices("FILES_URL", "CONSOLE_API_URL"),
alias_priority=1,
default="",
)
FILES_ACCESS_TIMEOUT: int = Field(
description="Expiration time in seconds for file access URLs",
default=300,
)
class FileUploadConfig(BaseSettings):
"""
Configuration for file upload limitations
"""
UPLOAD_FILE_SIZE_LIMIT: NonNegativeInt = Field(
description="Maximum allowed file size for uploads in megabytes",
default=15,
)
UPLOAD_FILE_BATCH_LIMIT: NonNegativeInt = Field(
description="Maximum number of files allowed in a single upload batch",
default=5,
)
UPLOAD_IMAGE_FILE_SIZE_LIMIT: NonNegativeInt = Field(
description="Maximum allowed image file size for uploads in megabytes",
default=10,
)
UPLOAD_VIDEO_FILE_SIZE_LIMIT: NonNegativeInt = Field(
description="video file size limit in Megabytes for uploading files",
default=100,
)
UPLOAD_AUDIO_FILE_SIZE_LIMIT: NonNegativeInt = Field(
description="audio file size limit in Megabytes for uploading files",
default=50,
)
BATCH_UPLOAD_LIMIT: NonNegativeInt = Field(
description="Maximum number of files allowed in a batch upload operation",
default=20,
)
WORKFLOW_FILE_UPLOAD_LIMIT: PositiveInt = Field(
description="Maximum number of files allowed in a workflow upload operation",
default=10,
)
class HttpConfig(BaseSettings):
"""
HTTP-related configurations for the application
"""
API_COMPRESSION_ENABLED: bool = Field(
description="Enable or disable gzip compression for HTTP responses",
default=False,
)
inner_CONSOLE_CORS_ALLOW_ORIGINS: str = Field(
description="Comma-separated list of allowed origins for CORS in the console",
validation_alias=AliasChoices("CONSOLE_CORS_ALLOW_ORIGINS", "CONSOLE_WEB_URL"),
default="",
)
@computed_field
def CONSOLE_CORS_ALLOW_ORIGINS(self) -> list[str]:
return self.inner_CONSOLE_CORS_ALLOW_ORIGINS.split(",")
inner_WEB_API_CORS_ALLOW_ORIGINS: str = Field(
description="",
validation_alias=AliasChoices("WEB_API_CORS_ALLOW_ORIGINS"),
default="*",
)
@computed_field
def WEB_API_CORS_ALLOW_ORIGINS(self) -> list[str]:
return self.inner_WEB_API_CORS_ALLOW_ORIGINS.split(",")
HTTP_REQUEST_MAX_CONNECT_TIMEOUT: Annotated[
PositiveInt, Field(ge=10, description="Maximum connection timeout in seconds for HTTP requests")
] = 10
HTTP_REQUEST_MAX_READ_TIMEOUT: Annotated[
PositiveInt, Field(ge=60, description="Maximum read timeout in seconds for HTTP requests")
] = 60
HTTP_REQUEST_MAX_WRITE_TIMEOUT: Annotated[
PositiveInt, Field(ge=10, description="Maximum write timeout in seconds for HTTP requests")
] = 20
HTTP_REQUEST_NODE_MAX_BINARY_SIZE: PositiveInt = Field(
description="Maximum allowed size in bytes for binary data in HTTP requests",
default=10 * 1024 * 1024,
)
HTTP_REQUEST_NODE_MAX_TEXT_SIZE: PositiveInt = Field(
description="Maximum allowed size in bytes for text data in HTTP requests",
default=1 * 1024 * 1024,
)
HTTP_REQUEST_NODE_SSL_VERIFY: bool = Field(
description="Enable or disable SSL verification for HTTP requests",
default=True,
)
SSRF_DEFAULT_MAX_RETRIES: PositiveInt = Field(
description="Maximum number of retries for network requests (SSRF)",
default=3,
)
SSRF_PROXY_ALL_URL: Optional[str] = Field(
description="Proxy URL for HTTP or HTTPS requests to prevent Server-Side Request Forgery (SSRF)",
default=None,
)
SSRF_PROXY_HTTP_URL: Optional[str] = Field(
description="Proxy URL for HTTP requests to prevent Server-Side Request Forgery (SSRF)",
default=None,
)
SSRF_PROXY_HTTPS_URL: Optional[str] = Field(
description="Proxy URL for HTTPS requests to prevent Server-Side Request Forgery (SSRF)",
default=None,
)
SSRF_DEFAULT_TIME_OUT: PositiveFloat = Field(
description="The default timeout period used for network requests (SSRF)",
default=5,
)
SSRF_DEFAULT_CONNECT_TIME_OUT: PositiveFloat = Field(
description="The default connect timeout period used for network requests (SSRF)",
default=5,
)
SSRF_DEFAULT_READ_TIME_OUT: PositiveFloat = Field(
description="The default read timeout period used for network requests (SSRF)",
default=5,
)
SSRF_DEFAULT_WRITE_TIME_OUT: PositiveFloat = Field(
description="The default write timeout period used for network requests (SSRF)",
default=5,
)
RESPECT_XFORWARD_HEADERS_ENABLED: bool = Field(
description="Enable handling of X-Forwarded-For, X-Forwarded-Proto, and X-Forwarded-Port headers"
" when the app is behind a single trusted reverse proxy.",
default=False,
)
class InnerAPIConfig(BaseSettings):
"""
Configuration for internal API functionality
"""
INNER_API: bool = Field(
description="Enable or disable the internal API",
default=False,
)
INNER_API_KEY: Optional[str] = Field(
description="API key for accessing the internal API",
default=None,
)
class LoggingConfig(BaseSettings):
"""
Configuration for application logging
"""
LOG_LEVEL: str = Field(
description="Logging level, default to INFO. Set to ERROR for production environments.",
default="INFO",
)
LOG_FILE: Optional[str] = Field(
description="File path for log output.",
default=None,
)
LOG_FILE_MAX_SIZE: PositiveInt = Field(
description="Maximum file size for file rotation retention, the unit is megabytes (MB)",
default=20,
)
LOG_FILE_BACKUP_COUNT: PositiveInt = Field(
description="Maximum file backup count file rotation retention",
default=5,
)
LOG_FORMAT: str = Field(
description="Format string for log messages",
default="%(asctime)s.%(msecs)03d %(levelname)s [%(threadName)s] [%(filename)s:%(lineno)d] - %(message)s",
)
LOG_DATEFORMAT: Optional[str] = Field(
description="Date format string for log timestamps",
default=None,
)
LOG_TZ: Optional[str] = Field(
description="Timezone for log timestamps (e.g., 'America/New_York')",
default="UTC",
)
class ModelLoadBalanceConfig(BaseSettings):
"""
Configuration for model load balancing and token counting
"""
MODEL_LB_ENABLED: bool = Field(
description="Enable or disable load balancing for models",
default=False,
)
PLUGIN_BASED_TOKEN_COUNTING_ENABLED: bool = Field(
description="Enable or disable plugin based token counting. If disabled, token counting will return 0.",
default=False,
)
class BillingConfig(BaseSettings):
"""
Configuration for platform billing features
"""
BILLING_ENABLED: bool = Field(
description="Enable or disable billing functionality",
default=False,
)
class UpdateConfig(BaseSettings):
"""
Configuration for application update checks
"""
CHECK_UPDATE_URL: str = Field(
description="URL to check for application updates",
default="https://updates.dify.ai",
)
class WorkflowConfig(BaseSettings):
"""
Configuration for workflow execution
"""
WORKFLOW_MAX_EXECUTION_STEPS: PositiveInt = Field(
description="Maximum number of steps allowed in a single workflow execution",
default=500,
)
WORKFLOW_MAX_EXECUTION_TIME: PositiveInt = Field(
description="Maximum execution time in seconds for a single workflow",
default=1200,
)
WORKFLOW_CALL_MAX_DEPTH: PositiveInt = Field(
description="Maximum allowed depth for nested workflow calls",
default=5,
)
WORKFLOW_PARALLEL_DEPTH_LIMIT: PositiveInt = Field(
description="Maximum allowed depth for nested parallel executions",
default=3,
)
MAX_VARIABLE_SIZE: PositiveInt = Field(
description="Maximum size in bytes for a single variable in workflows. Default to 200 KB.",
default=200 * 1024,
)
class WorkflowNodeExecutionConfig(BaseSettings):
"""
Configuration for workflow node execution
"""
MAX_SUBMIT_COUNT: PositiveInt = Field(
description="Maximum number of submitted thread count in a ThreadPool for parallel node execution",
default=100,
)
WORKFLOW_NODE_EXECUTION_STORAGE: str = Field(
default="rdbms",
description="Storage backend for WorkflowNodeExecution. Options: 'rdbms', 'hybrid'",
)
class AuthConfig(BaseSettings):
"""
Configuration for authentication and OAuth
"""
OAUTH_REDIRECT_PATH: str = Field(
description="Redirect path for OAuth authentication callbacks",
default="/console/api/oauth/authorize",
)
GITHUB_CLIENT_ID: Optional[str] = Field(
description="GitHub OAuth client ID",
default=None,
)
GITHUB_CLIENT_SECRET: Optional[str] = Field(
description="GitHub OAuth client secret",
default=None,
)
GOOGLE_CLIENT_ID: Optional[str] = Field(
description="Google OAuth client ID",
default=None,
)
GOOGLE_CLIENT_SECRET: Optional[str] = Field(
description="Google OAuth client secret",
default=None,
)
ACCESS_TOKEN_EXPIRE_MINUTES: PositiveInt = Field(
description="Expiration time for access tokens in minutes",
default=60,
)
REFRESH_TOKEN_EXPIRE_DAYS: PositiveFloat = Field(
description="Expiration time for refresh tokens in days",
default=30,
)
LOGIN_LOCKOUT_DURATION: PositiveInt = Field(
description="Time (in seconds) a user must wait before retrying login after exceeding the rate limit.",
default=86400,
)
FORGOT_PASSWORD_LOCKOUT_DURATION: PositiveInt = Field(
description="Time (in seconds) a user must wait before retrying password reset after exceeding the rate limit.",
default=86400,
)
class ModerationConfig(BaseSettings):
"""
Configuration for content moderation
"""
MODERATION_BUFFER_SIZE: PositiveInt = Field(
description="Size of the buffer for content moderation processing",
default=300,
)
class ToolConfig(BaseSettings):
"""
Configuration for tool management
"""
TOOL_ICON_CACHE_MAX_AGE: PositiveInt = Field(
description="Maximum age in seconds for caching tool icons",
default=3600,
)
class MailConfig(BaseSettings):
"""
Configuration for email services
"""
MAIL_TYPE: Optional[str] = Field(
description="Email service provider type ('smtp' or 'resend'), default to None.",
default=None,
)
MAIL_DEFAULT_SEND_FROM: Optional[str] = Field(
description="Default email address to use as the sender",
default=None,
)
RESEND_API_KEY: Optional[str] = Field(
description="API key for Resend email service",
default=None,
)
RESEND_API_URL: Optional[str] = Field(
description="API URL for Resend email service",
default=None,
)
SMTP_SERVER: Optional[str] = Field(
description="SMTP server hostname",
default=None,
)
SMTP_PORT: Optional[int] = Field(
description="SMTP server port number",
default=465,
)
SMTP_USERNAME: Optional[str] = Field(
description="Username for SMTP authentication",
default=None,
)
SMTP_PASSWORD: Optional[str] = Field(
description="Password for SMTP authentication",
default=None,
)
SMTP_USE_TLS: bool = Field(
description="Enable TLS encryption for SMTP connections",
default=False,
)
SMTP_OPPORTUNISTIC_TLS: bool = Field(
description="Enable opportunistic TLS for SMTP connections",
default=False,
)
EMAIL_SEND_IP_LIMIT_PER_MINUTE: PositiveInt = Field(
description="Maximum number of emails allowed to be sent from the same IP address in a minute",
default=50,
)
class RagEtlConfig(BaseSettings):
"""
Configuration for RAG ETL processes
"""
# TODO: This config is not only for rag etl, it is also for file upload, we should move it to file upload config
ETL_TYPE: str = Field(
description="RAG ETL type ('dify' or 'Unstructured'), default to 'dify'",
default="dify",
)
KEYWORD_DATA_SOURCE_TYPE: str = Field(
description="Data source type for keyword extraction"
" ('database' or other supported types), default to 'database'",
default="database",
)
UNSTRUCTURED_API_URL: Optional[str] = Field(
description="API URL for Unstructured.io service",
default=None,
)
UNSTRUCTURED_API_KEY: Optional[str] = Field(
description="API key for Unstructured.io service",
default="",
)
SCARF_NO_ANALYTICS: Optional[str] = Field(
description="This is about whether to disable Scarf analytics in Unstructured library.",
default="false",
)
class DataSetConfig(BaseSettings):
"""
Configuration for dataset management
"""
PLAN_SANDBOX_CLEAN_DAY_SETTING: PositiveInt = Field(
description="Interval in days for dataset cleanup operations - plan: sandbox",
default=30,
)
PLAN_PRO_CLEAN_DAY_SETTING: PositiveInt = Field(
description="Interval in days for dataset cleanup operations - plan: pro and team",
default=7,
)
DATASET_OPERATOR_ENABLED: bool = Field(
description="Enable or disable dataset operator functionality",
default=False,
)
TIDB_SERVERLESS_NUMBER: PositiveInt = Field(
description="number of tidb serverless cluster",
default=500,
)
CREATE_TIDB_SERVICE_JOB_ENABLED: bool = Field(
description="Enable or disable create tidb service job",
default=False,
)
PLAN_SANDBOX_CLEAN_MESSAGE_DAY_SETTING: PositiveInt = Field(
description="Interval in days for message cleanup operations - plan: sandbox",
default=30,
)
class WorkspaceConfig(BaseSettings):
"""
Configuration for workspace management
"""
INVITE_EXPIRY_HOURS: PositiveInt = Field(
description="Expiration time in hours for workspace invitation links",
default=72,
)
class IndexingConfig(BaseSettings):
"""
Configuration for indexing operations
"""
INDEXING_MAX_SEGMENTATION_TOKENS_LENGTH: PositiveInt = Field(
description="Maximum token length for text segmentation during indexing",
default=4000,
)
CHILD_CHUNKS_PREVIEW_NUMBER: PositiveInt = Field(
description="Maximum number of child chunks to preview",
default=50,
)
class MultiModalTransferConfig(BaseSettings):
MULTIMODAL_SEND_FORMAT: Literal["base64", "url"] = Field(
description="Format for sending files in multimodal contexts ('base64' or 'url'), default is base64",
default="base64",
)
class CeleryBeatConfig(BaseSettings):
CELERY_BEAT_SCHEDULER_TIME: int = Field(
description="Interval in days for Celery Beat scheduler execution, default to 1 day",
default=1,
)
class PositionConfig(BaseSettings):
POSITION_PROVIDER_PINS: str = Field(
description="Comma-separated list of pinned model providers",
default="",
)
POSITION_PROVIDER_INCLUDES: str = Field(
description="Comma-separated list of included model providers",
default="",
)
POSITION_PROVIDER_EXCLUDES: str = Field(
description="Comma-separated list of excluded model providers",
default="",
)
POSITION_TOOL_PINS: str = Field(
description="Comma-separated list of pinned tools",
default="",
)
POSITION_TOOL_INCLUDES: str = Field(
description="Comma-separated list of included tools",
default="",
)
POSITION_TOOL_EXCLUDES: str = Field(
description="Comma-separated list of excluded tools",
default="",
)
@property
def POSITION_PROVIDER_PINS_LIST(self) -> list[str]:
return [item.strip() for item in self.POSITION_PROVIDER_PINS.split(",") if item.strip() != ""]
@property
def POSITION_PROVIDER_INCLUDES_SET(self) -> set[str]:
return {item.strip() for item in self.POSITION_PROVIDER_INCLUDES.split(",") if item.strip() != ""}
@property
def POSITION_PROVIDER_EXCLUDES_SET(self) -> set[str]:
return {item.strip() for item in self.POSITION_PROVIDER_EXCLUDES.split(",") if item.strip() != ""}
@property
def POSITION_TOOL_PINS_LIST(self) -> list[str]:
return [item.strip() for item in self.POSITION_TOOL_PINS.split(",") if item.strip() != ""]
@property
def POSITION_TOOL_INCLUDES_SET(self) -> set[str]:
return {item.strip() for item in self.POSITION_TOOL_INCLUDES.split(",") if item.strip() != ""}
@property
def POSITION_TOOL_EXCLUDES_SET(self) -> set[str]:
return {item.strip() for item in self.POSITION_TOOL_EXCLUDES.split(",") if item.strip() != ""}
class LoginConfig(BaseSettings):
ENABLE_EMAIL_CODE_LOGIN: bool = Field(
description="whether to enable email code login",
default=False,
)
ENABLE_EMAIL_PASSWORD_LOGIN: bool = Field(
description="whether to enable email password login",
default=True,
)
ENABLE_SOCIAL_OAUTH_LOGIN: bool = Field(
description="whether to enable github/google oauth login",
default=False,
)
EMAIL_CODE_LOGIN_TOKEN_EXPIRY_MINUTES: PositiveInt = Field(
description="expiry time in minutes for email code login token",
default=5,
)
ALLOW_REGISTER: bool = Field(
description="whether to enable register",
default=False,
)
ALLOW_CREATE_WORKSPACE: bool = Field(
description="whether to enable create workspace",
default=False,
)
class AccountConfig(BaseSettings):
ACCOUNT_DELETION_TOKEN_EXPIRY_MINUTES: PositiveInt = Field(
description="Duration in minutes for which a account deletion token remains valid",
default=5,
)
EDUCATION_ENABLED: bool = Field(
description="whether to enable education identity",
default=False,
)
class FeatureConfig(
# place the configs in alphabet order
AppExecutionConfig,
AuthConfig, # Changed from OAuthConfig to AuthConfig
BillingConfig,
CodeExecutionSandboxConfig,
PluginConfig,
MarketplaceConfig,
DataSetConfig,
EndpointConfig,
FileAccessConfig,
FileUploadConfig,
HttpConfig,
InnerAPIConfig,
IndexingConfig,
LoggingConfig,
MailConfig,
ModelLoadBalanceConfig,
ModerationConfig,
MultiModalTransferConfig,
PositionConfig,
RagEtlConfig,
SecurityConfig,
ToolConfig,
UpdateConfig,
WorkflowConfig,
WorkflowNodeExecutionConfig,
WorkspaceConfig,
LoginConfig,
AccountConfig,
# hosted services config
HostedServiceConfig,
CeleryBeatConfig,
):
pass

View File

@@ -0,0 +1,239 @@
from typing import Optional
from pydantic import Field, NonNegativeInt
from pydantic_settings import BaseSettings
class HostedCreditConfig(BaseSettings):
HOSTED_MODEL_CREDIT_CONFIG: str = Field(
description="Model credit configuration in format 'model:credits,model:credits', e.g., 'gpt-4:20,gpt-4o:10'",
default="",
)
def get_model_credits(self, model_name: str) -> int:
"""
Get credit value for a specific model name.
Returns 1 if model is not found in configuration (default credit).
:param model_name: The name of the model to search for
:return: The credit value for the model
"""
if not self.HOSTED_MODEL_CREDIT_CONFIG:
return 1
try:
credit_map = dict(
item.strip().split(":", 1) for item in self.HOSTED_MODEL_CREDIT_CONFIG.split(",") if ":" in item
)
# Search for matching model pattern
for pattern, credit in credit_map.items():
if pattern.strip() == model_name:
return int(credit)
return 1 # Default quota if no match found
except (ValueError, AttributeError):
return 1 # Return default quota if parsing fails
class HostedOpenAiConfig(BaseSettings):
"""
Configuration for hosted OpenAI service
"""
HOSTED_OPENAI_API_KEY: Optional[str] = Field(
description="API key for hosted OpenAI service",
default=None,
)
HOSTED_OPENAI_API_BASE: Optional[str] = Field(
description="Base URL for hosted OpenAI API",
default=None,
)
HOSTED_OPENAI_API_ORGANIZATION: Optional[str] = Field(
description="Organization ID for hosted OpenAI service",
default=None,
)
HOSTED_OPENAI_TRIAL_ENABLED: bool = Field(
description="Enable trial access to hosted OpenAI service",
default=False,
)
HOSTED_OPENAI_TRIAL_MODELS: str = Field(
description="Comma-separated list of available models for trial access",
default="gpt-3.5-turbo,"
"gpt-3.5-turbo-1106,"
"gpt-3.5-turbo-instruct,"
"gpt-3.5-turbo-16k,"
"gpt-3.5-turbo-16k-0613,"
"gpt-3.5-turbo-0613,"
"gpt-3.5-turbo-0125,"
"text-davinci-003",
)
HOSTED_OPENAI_QUOTA_LIMIT: NonNegativeInt = Field(
description="Quota limit for hosted OpenAI service usage",
default=200,
)
HOSTED_OPENAI_PAID_ENABLED: bool = Field(
description="Enable paid access to hosted OpenAI service",
default=False,
)
HOSTED_OPENAI_PAID_MODELS: str = Field(
description="Comma-separated list of available models for paid access",
default="gpt-4,"
"gpt-4-turbo-preview,"
"gpt-4-turbo-2024-04-09,"
"gpt-4-1106-preview,"
"gpt-4-0125-preview,"
"gpt-3.5-turbo,"
"gpt-3.5-turbo-16k,"
"gpt-3.5-turbo-16k-0613,"
"gpt-3.5-turbo-1106,"
"gpt-3.5-turbo-0613,"
"gpt-3.5-turbo-0125,"
"gpt-3.5-turbo-instruct,"
"text-davinci-003",
)
class HostedAzureOpenAiConfig(BaseSettings):
"""
Configuration for hosted Azure OpenAI service
"""
HOSTED_AZURE_OPENAI_ENABLED: bool = Field(
description="Enable hosted Azure OpenAI service",
default=False,
)
HOSTED_AZURE_OPENAI_API_KEY: Optional[str] = Field(
description="API key for hosted Azure OpenAI service",
default=None,
)
HOSTED_AZURE_OPENAI_API_BASE: Optional[str] = Field(
description="Base URL for hosted Azure OpenAI API",
default=None,
)
HOSTED_AZURE_OPENAI_QUOTA_LIMIT: NonNegativeInt = Field(
description="Quota limit for hosted Azure OpenAI service usage",
default=200,
)
class HostedAnthropicConfig(BaseSettings):
"""
Configuration for hosted Anthropic service
"""
HOSTED_ANTHROPIC_API_BASE: Optional[str] = Field(
description="Base URL for hosted Anthropic API",
default=None,
)
HOSTED_ANTHROPIC_API_KEY: Optional[str] = Field(
description="API key for hosted Anthropic service",
default=None,
)
HOSTED_ANTHROPIC_TRIAL_ENABLED: bool = Field(
description="Enable trial access to hosted Anthropic service",
default=False,
)
HOSTED_ANTHROPIC_QUOTA_LIMIT: NonNegativeInt = Field(
description="Quota limit for hosted Anthropic service usage",
default=600000,
)
HOSTED_ANTHROPIC_PAID_ENABLED: bool = Field(
description="Enable paid access to hosted Anthropic service",
default=False,
)
class HostedMinmaxConfig(BaseSettings):
"""
Configuration for hosted Minmax service
"""
HOSTED_MINIMAX_ENABLED: bool = Field(
description="Enable hosted Minmax service",
default=False,
)
class HostedSparkConfig(BaseSettings):
"""
Configuration for hosted Spark service
"""
HOSTED_SPARK_ENABLED: bool = Field(
description="Enable hosted Spark service",
default=False,
)
class HostedZhipuAIConfig(BaseSettings):
"""
Configuration for hosted ZhipuAI service
"""
HOSTED_ZHIPUAI_ENABLED: bool = Field(
description="Enable hosted ZhipuAI service",
default=False,
)
class HostedModerationConfig(BaseSettings):
"""
Configuration for hosted Moderation service
"""
HOSTED_MODERATION_ENABLED: bool = Field(
description="Enable hosted Moderation service",
default=False,
)
HOSTED_MODERATION_PROVIDERS: str = Field(
description="Comma-separated list of moderation providers",
default="",
)
class HostedFetchAppTemplateConfig(BaseSettings):
"""
Configuration for fetching app templates
"""
HOSTED_FETCH_APP_TEMPLATES_MODE: str = Field(
description="Mode for fetching app templates: remote, db, or builtin default to remote,",
default="remote",
)
HOSTED_FETCH_APP_TEMPLATES_REMOTE_DOMAIN: str = Field(
description="Domain for fetching remote app templates",
default="https://tmpl.dify.ai",
)
class HostedServiceConfig(
# place the configs in alphabet order
HostedAnthropicConfig,
HostedAzureOpenAiConfig,
HostedFetchAppTemplateConfig,
HostedMinmaxConfig,
HostedOpenAiConfig,
HostedSparkConfig,
HostedZhipuAIConfig,
# moderation
HostedModerationConfig,
# credit config
HostedCreditConfig,
):
pass

View File

@@ -0,0 +1,307 @@
import os
from typing import Any, Literal, Optional
from urllib.parse import parse_qsl, quote_plus
from pydantic import Field, NonNegativeInt, PositiveFloat, PositiveInt, computed_field
from pydantic_settings import BaseSettings
from .cache.redis_config import RedisConfig
from .storage.aliyun_oss_storage_config import AliyunOSSStorageConfig
from .storage.amazon_s3_storage_config import S3StorageConfig
from .storage.azure_blob_storage_config import AzureBlobStorageConfig
from .storage.baidu_obs_storage_config import BaiduOBSStorageConfig
from .storage.google_cloud_storage_config import GoogleCloudStorageConfig
from .storage.huawei_obs_storage_config import HuaweiCloudOBSStorageConfig
from .storage.oci_storage_config import OCIStorageConfig
from .storage.opendal_storage_config import OpenDALStorageConfig
from .storage.supabase_storage_config import SupabaseStorageConfig
from .storage.tencent_cos_storage_config import TencentCloudCOSStorageConfig
from .storage.volcengine_tos_storage_config import VolcengineTOSStorageConfig
from .vdb.analyticdb_config import AnalyticdbConfig
from .vdb.baidu_vector_config import BaiduVectorDBConfig
from .vdb.chroma_config import ChromaConfig
from .vdb.couchbase_config import CouchbaseConfig
from .vdb.elasticsearch_config import ElasticsearchConfig
from .vdb.huawei_cloud_config import HuaweiCloudConfig
from .vdb.lindorm_config import LindormConfig
from .vdb.milvus_config import MilvusConfig
from .vdb.myscale_config import MyScaleConfig
from .vdb.oceanbase_config import OceanBaseVectorConfig
from .vdb.opengauss_config import OpenGaussConfig
from .vdb.opensearch_config import OpenSearchConfig
from .vdb.oracle_config import OracleConfig
from .vdb.pgvector_config import PGVectorConfig
from .vdb.pgvectors_config import PGVectoRSConfig
from .vdb.qdrant_config import QdrantConfig
from .vdb.relyt_config import RelytConfig
from .vdb.tablestore_config import TableStoreConfig
from .vdb.tencent_vector_config import TencentVectorDBConfig
from .vdb.tidb_on_qdrant_config import TidbOnQdrantConfig
from .vdb.tidb_vector_config import TiDBVectorConfig
from .vdb.upstash_config import UpstashConfig
from .vdb.vastbase_vector_config import VastbaseVectorConfig
from .vdb.vikingdb_config import VikingDBConfig
from .vdb.weaviate_config import WeaviateConfig
class StorageConfig(BaseSettings):
STORAGE_TYPE: Literal[
"opendal",
"s3",
"aliyun-oss",
"azure-blob",
"baidu-obs",
"google-storage",
"huawei-obs",
"oci-storage",
"tencent-cos",
"volcengine-tos",
"supabase",
"local",
] = Field(
description="Type of storage to use."
" Options: 'opendal', '(deprecated) local', 's3', 'aliyun-oss', 'azure-blob', 'baidu-obs', 'google-storage', "
"'huawei-obs', 'oci-storage', 'tencent-cos', 'volcengine-tos', 'supabase'. Default is 'opendal'.",
default="opendal",
)
STORAGE_LOCAL_PATH: str = Field(
description="Path for local storage when STORAGE_TYPE is set to 'local'.",
default="storage",
deprecated=True,
)
class VectorStoreConfig(BaseSettings):
VECTOR_STORE: Optional[str] = Field(
description="Type of vector store to use for efficient similarity search."
" Set to None if not using a vector store.",
default=None,
)
VECTOR_STORE_WHITELIST_ENABLE: Optional[bool] = Field(
description="Enable whitelist for vector store.",
default=False,
)
class KeywordStoreConfig(BaseSettings):
KEYWORD_STORE: str = Field(
description="Method for keyword extraction and storage."
" Default is 'jieba', a Chinese text segmentation library.",
default="jieba",
)
class DatabaseConfig(BaseSettings):
DB_HOST: str = Field(
description="Hostname or IP address of the database server.",
default="localhost",
)
DB_PORT: PositiveInt = Field(
description="Port number for database connection.",
default=5432,
)
DB_USERNAME: str = Field(
description="Username for database authentication.",
default="postgres",
)
DB_PASSWORD: str = Field(
description="Password for database authentication.",
default="",
)
DB_DATABASE: str = Field(
description="Name of the database to connect to.",
default="dify",
)
DB_CHARSET: str = Field(
description="Character set for database connection.",
default="",
)
DB_EXTRAS: str = Field(
description="Additional database connection parameters. Example: 'keepalives_idle=60&keepalives=1'",
default="",
)
SQLALCHEMY_DATABASE_URI_SCHEME: str = Field(
description="Database URI scheme for SQLAlchemy connection.",
default="postgresql",
)
@computed_field
def SQLALCHEMY_DATABASE_URI(self) -> str:
db_extras = (
f"{self.DB_EXTRAS}&client_encoding={self.DB_CHARSET}" if self.DB_CHARSET else self.DB_EXTRAS
).strip("&")
db_extras = f"?{db_extras}" if db_extras else ""
return (
f"{self.SQLALCHEMY_DATABASE_URI_SCHEME}://"
f"{quote_plus(self.DB_USERNAME)}:{quote_plus(self.DB_PASSWORD)}@{self.DB_HOST}:{self.DB_PORT}/{self.DB_DATABASE}"
f"{db_extras}"
)
SQLALCHEMY_POOL_SIZE: NonNegativeInt = Field(
description="Maximum number of database connections in the pool.",
default=30,
)
SQLALCHEMY_MAX_OVERFLOW: NonNegativeInt = Field(
description="Maximum number of connections that can be created beyond the pool_size.",
default=10,
)
SQLALCHEMY_POOL_RECYCLE: NonNegativeInt = Field(
description="Number of seconds after which a connection is automatically recycled.",
default=3600,
)
SQLALCHEMY_POOL_PRE_PING: bool = Field(
description="If True, enables connection pool pre-ping feature to check connections.",
default=False,
)
SQLALCHEMY_ECHO: bool | str = Field(
description="If True, SQLAlchemy will log all SQL statements.",
default=False,
)
RETRIEVAL_SERVICE_EXECUTORS: NonNegativeInt = Field(
description="Number of processes for the retrieval service, default to CPU cores.",
default=os.cpu_count() or 1,
)
@computed_field # type: ignore[misc]
@property
def SQLALCHEMY_ENGINE_OPTIONS(self) -> dict[str, Any]:
# Parse DB_EXTRAS for 'options'
db_extras_dict = dict(parse_qsl(self.DB_EXTRAS))
options = db_extras_dict.get("options", "")
# Always include timezone
timezone_opt = "-c timezone=UTC"
if options:
# Merge user options and timezone
merged_options = f"{options} {timezone_opt}"
else:
merged_options = timezone_opt
connect_args = {"options": merged_options}
return {
"pool_size": self.SQLALCHEMY_POOL_SIZE,
"max_overflow": self.SQLALCHEMY_MAX_OVERFLOW,
"pool_recycle": self.SQLALCHEMY_POOL_RECYCLE,
"pool_pre_ping": self.SQLALCHEMY_POOL_PRE_PING,
"connect_args": connect_args,
}
class CeleryConfig(DatabaseConfig):
CELERY_BACKEND: str = Field(
description="Backend for Celery task results. Options: 'database', 'redis'.",
default="database",
)
CELERY_BROKER_URL: Optional[str] = Field(
description="URL of the message broker for Celery tasks.",
default=None,
)
CELERY_USE_SENTINEL: Optional[bool] = Field(
description="Whether to use Redis Sentinel for high availability.",
default=False,
)
CELERY_SENTINEL_MASTER_NAME: Optional[str] = Field(
description="Name of the Redis Sentinel master.",
default=None,
)
CELERY_SENTINEL_SOCKET_TIMEOUT: Optional[PositiveFloat] = Field(
description="Timeout for Redis Sentinel socket operations in seconds.",
default=0.1,
)
@computed_field
def CELERY_RESULT_BACKEND(self) -> str | None:
return (
"db+{}".format(self.SQLALCHEMY_DATABASE_URI)
if self.CELERY_BACKEND == "database"
else self.CELERY_BROKER_URL
)
@property
def BROKER_USE_SSL(self) -> bool:
return self.CELERY_BROKER_URL.startswith("rediss://") if self.CELERY_BROKER_URL else False
class InternalTestConfig(BaseSettings):
"""
Configuration settings for Internal Test
"""
AWS_SECRET_ACCESS_KEY: Optional[str] = Field(
description="Internal test AWS secret access key",
default=None,
)
AWS_ACCESS_KEY_ID: Optional[str] = Field(
description="Internal test AWS access key ID",
default=None,
)
class MiddlewareConfig(
# place the configs in alphabet order
CeleryConfig,
DatabaseConfig,
KeywordStoreConfig,
RedisConfig,
# configs of storage and storage providers
StorageConfig,
AliyunOSSStorageConfig,
AzureBlobStorageConfig,
BaiduOBSStorageConfig,
GoogleCloudStorageConfig,
HuaweiCloudOBSStorageConfig,
OCIStorageConfig,
OpenDALStorageConfig,
S3StorageConfig,
SupabaseStorageConfig,
TencentCloudCOSStorageConfig,
VolcengineTOSStorageConfig,
# configs of vdb and vdb providers
VectorStoreConfig,
AnalyticdbConfig,
ChromaConfig,
HuaweiCloudConfig,
MilvusConfig,
MyScaleConfig,
OpenSearchConfig,
OracleConfig,
PGVectorConfig,
VastbaseVectorConfig,
PGVectoRSConfig,
QdrantConfig,
RelytConfig,
TencentVectorDBConfig,
TiDBVectorConfig,
WeaviateConfig,
ElasticsearchConfig,
CouchbaseConfig,
InternalTestConfig,
VikingDBConfig,
UpstashConfig,
TidbOnQdrantConfig,
LindormConfig,
OceanBaseVectorConfig,
BaiduVectorDBConfig,
OpenGaussConfig,
TableStoreConfig,
):
pass

View File

View File

@@ -0,0 +1,95 @@
from typing import Optional
from pydantic import Field, NonNegativeInt, PositiveFloat, PositiveInt
from pydantic_settings import BaseSettings
class RedisConfig(BaseSettings):
"""
Configuration settings for Redis connection
"""
REDIS_HOST: str = Field(
description="Hostname or IP address of the Redis server",
default="localhost",
)
REDIS_PORT: PositiveInt = Field(
description="Port number on which the Redis server is listening",
default=6379,
)
REDIS_USERNAME: Optional[str] = Field(
description="Username for Redis authentication (if required)",
default=None,
)
REDIS_PASSWORD: Optional[str] = Field(
description="Password for Redis authentication (if required)",
default=None,
)
REDIS_DB: NonNegativeInt = Field(
description="Redis database number to use (0-15)",
default=0,
)
REDIS_USE_SSL: bool = Field(
description="Enable SSL/TLS for the Redis connection",
default=False,
)
REDIS_USE_SENTINEL: Optional[bool] = Field(
description="Enable Redis Sentinel mode for high availability",
default=False,
)
REDIS_SENTINELS: Optional[str] = Field(
description="Comma-separated list of Redis Sentinel nodes (host:port)",
default=None,
)
REDIS_SENTINEL_SERVICE_NAME: Optional[str] = Field(
description="Name of the Redis Sentinel service to monitor",
default=None,
)
REDIS_SENTINEL_USERNAME: Optional[str] = Field(
description="Username for Redis Sentinel authentication (if required)",
default=None,
)
REDIS_SENTINEL_PASSWORD: Optional[str] = Field(
description="Password for Redis Sentinel authentication (if required)",
default=None,
)
REDIS_SENTINEL_SOCKET_TIMEOUT: Optional[PositiveFloat] = Field(
description="Socket timeout in seconds for Redis Sentinel connections",
default=0.1,
)
REDIS_USE_CLUSTERS: bool = Field(
description="Enable Redis Clusters mode for high availability",
default=False,
)
REDIS_CLUSTERS: Optional[str] = Field(
description="Comma-separated list of Redis Clusters nodes (host:port)",
default=None,
)
REDIS_CLUSTERS_PASSWORD: Optional[str] = Field(
description="Password for Redis Clusters authentication (if required)",
default=None,
)
REDIS_SERIALIZATION_PROTOCOL: int = Field(
description="Redis serialization protocol (RESP) version",
default=3,
)
REDIS_ENABLE_CLIENT_SIDE_CACHE: bool = Field(
description="Enable client side cache in redis",
default=False,
)

View File

@@ -0,0 +1,45 @@
from typing import Optional
from pydantic import Field
from pydantic_settings import BaseSettings
class AliyunOSSStorageConfig(BaseSettings):
"""
Configuration settings for Aliyun Object Storage Service (OSS)
"""
ALIYUN_OSS_BUCKET_NAME: Optional[str] = Field(
description="Name of the Aliyun OSS bucket to store and retrieve objects",
default=None,
)
ALIYUN_OSS_ACCESS_KEY: Optional[str] = Field(
description="Access key ID for authenticating with Aliyun OSS",
default=None,
)
ALIYUN_OSS_SECRET_KEY: Optional[str] = Field(
description="Secret access key for authenticating with Aliyun OSS",
default=None,
)
ALIYUN_OSS_ENDPOINT: Optional[str] = Field(
description="URL of the Aliyun OSS endpoint for your chosen region",
default=None,
)
ALIYUN_OSS_REGION: Optional[str] = Field(
description="Aliyun OSS region where your bucket is located (e.g., 'oss-cn-hangzhou')",
default=None,
)
ALIYUN_OSS_AUTH_VERSION: Optional[str] = Field(
description="Version of the authentication protocol to use with Aliyun OSS (e.g., 'v4')",
default=None,
)
ALIYUN_OSS_PATH: Optional[str] = Field(
description="Base path within the bucket to store objects (e.g., 'my-app-data/')",
default=None,
)

View File

@@ -0,0 +1,45 @@
from typing import Optional
from pydantic import Field
from pydantic_settings import BaseSettings
class S3StorageConfig(BaseSettings):
"""
Configuration settings for S3-compatible object storage
"""
S3_ENDPOINT: Optional[str] = Field(
description="URL of the S3-compatible storage endpoint (e.g., 'https://s3.amazonaws.com')",
default=None,
)
S3_REGION: Optional[str] = Field(
description="Region where the S3 bucket is located (e.g., 'us-east-1')",
default=None,
)
S3_BUCKET_NAME: Optional[str] = Field(
description="Name of the S3 bucket to store and retrieve objects",
default=None,
)
S3_ACCESS_KEY: Optional[str] = Field(
description="Access key ID for authenticating with the S3 service",
default=None,
)
S3_SECRET_KEY: Optional[str] = Field(
description="Secret access key for authenticating with the S3 service",
default=None,
)
S3_ADDRESS_STYLE: str = Field(
description="S3 addressing style: 'auto', 'path', or 'virtual'",
default="auto",
)
S3_USE_AWS_MANAGED_IAM: bool = Field(
description="Use AWS managed IAM roles for authentication instead of access/secret keys",
default=False,
)

View File

@@ -0,0 +1,30 @@
from typing import Optional
from pydantic import Field
from pydantic_settings import BaseSettings
class AzureBlobStorageConfig(BaseSettings):
"""
Configuration settings for Azure Blob Storage
"""
AZURE_BLOB_ACCOUNT_NAME: Optional[str] = Field(
description="Name of the Azure Storage account (e.g., 'mystorageaccount')",
default=None,
)
AZURE_BLOB_ACCOUNT_KEY: Optional[str] = Field(
description="Access key for authenticating with the Azure Storage account",
default=None,
)
AZURE_BLOB_CONTAINER_NAME: Optional[str] = Field(
description="Name of the Azure Blob container to store and retrieve objects",
default=None,
)
AZURE_BLOB_ACCOUNT_URL: Optional[str] = Field(
description="URL of the Azure Blob storage endpoint (e.g., 'https://mystorageaccount.blob.core.windows.net')",
default=None,
)

View File

@@ -0,0 +1,30 @@
from typing import Optional
from pydantic import Field
from pydantic_settings import BaseSettings
class BaiduOBSStorageConfig(BaseSettings):
"""
Configuration settings for Baidu Object Storage Service (OBS)
"""
BAIDU_OBS_BUCKET_NAME: Optional[str] = Field(
description="Name of the Baidu OBS bucket to store and retrieve objects (e.g., 'my-obs-bucket')",
default=None,
)
BAIDU_OBS_ACCESS_KEY: Optional[str] = Field(
description="Access Key ID for authenticating with Baidu OBS",
default=None,
)
BAIDU_OBS_SECRET_KEY: Optional[str] = Field(
description="Secret Access Key for authenticating with Baidu OBS",
default=None,
)
BAIDU_OBS_ENDPOINT: Optional[str] = Field(
description="URL of the Baidu OSS endpoint for your chosen region (e.g., 'https://.bj.bcebos.com')",
default=None,
)

View File

@@ -0,0 +1,20 @@
from typing import Optional
from pydantic import Field
from pydantic_settings import BaseSettings
class GoogleCloudStorageConfig(BaseSettings):
"""
Configuration settings for Google Cloud Storage
"""
GOOGLE_STORAGE_BUCKET_NAME: Optional[str] = Field(
description="Name of the Google Cloud Storage bucket to store and retrieve objects (e.g., 'my-gcs-bucket')",
default=None,
)
GOOGLE_STORAGE_SERVICE_ACCOUNT_JSON_BASE64: Optional[str] = Field(
description="Base64-encoded JSON key file for Google Cloud service account authentication",
default=None,
)

View File

@@ -0,0 +1,30 @@
from typing import Optional
from pydantic import Field
from pydantic_settings import BaseSettings
class HuaweiCloudOBSStorageConfig(BaseSettings):
"""
Configuration settings for Huawei Cloud Object Storage Service (OBS)
"""
HUAWEI_OBS_BUCKET_NAME: Optional[str] = Field(
description="Name of the Huawei Cloud OBS bucket to store and retrieve objects (e.g., 'my-obs-bucket')",
default=None,
)
HUAWEI_OBS_ACCESS_KEY: Optional[str] = Field(
description="Access Key ID for authenticating with Huawei Cloud OBS",
default=None,
)
HUAWEI_OBS_SECRET_KEY: Optional[str] = Field(
description="Secret Access Key for authenticating with Huawei Cloud OBS",
default=None,
)
HUAWEI_OBS_SERVER: Optional[str] = Field(
description="Endpoint URL for Huawei Cloud OBS (e.g., 'https://obs.cn-north-4.myhuaweicloud.com')",
default=None,
)

View File

@@ -0,0 +1,35 @@
from typing import Optional
from pydantic import Field
from pydantic_settings import BaseSettings
class OCIStorageConfig(BaseSettings):
"""
Configuration settings for Oracle Cloud Infrastructure (OCI) Object Storage
"""
OCI_ENDPOINT: Optional[str] = Field(
description="URL of the OCI Object Storage endpoint (e.g., 'https://objectstorage.us-phoenix-1.oraclecloud.com')",
default=None,
)
OCI_REGION: Optional[str] = Field(
description="OCI region where the bucket is located (e.g., 'us-phoenix-1')",
default=None,
)
OCI_BUCKET_NAME: Optional[str] = Field(
description="Name of the OCI Object Storage bucket to store and retrieve objects (e.g., 'my-oci-bucket')",
default=None,
)
OCI_ACCESS_KEY: Optional[str] = Field(
description="Access key (also known as API key) for authenticating with OCI Object Storage",
default=None,
)
OCI_SECRET_KEY: Optional[str] = Field(
description="Secret key associated with the access key for authenticating with OCI Object Storage",
default=None,
)

View File

@@ -0,0 +1,9 @@
from pydantic import Field
from pydantic_settings import BaseSettings
class OpenDALStorageConfig(BaseSettings):
OPENDAL_SCHEME: str = Field(
default="fs",
description="OpenDAL scheme.",
)

View File

@@ -0,0 +1,25 @@
from typing import Optional
from pydantic import Field
from pydantic_settings import BaseSettings
class SupabaseStorageConfig(BaseSettings):
"""
Configuration settings for Supabase Object Storage Service
"""
SUPABASE_BUCKET_NAME: Optional[str] = Field(
description="Name of the Supabase bucket to store and retrieve objects (e.g., 'dify-bucket')",
default=None,
)
SUPABASE_API_KEY: Optional[str] = Field(
description="API KEY for authenticating with Supabase",
default=None,
)
SUPABASE_URL: Optional[str] = Field(
description="URL of the Supabase",
default=None,
)

View File

@@ -0,0 +1,35 @@
from typing import Optional
from pydantic import Field
from pydantic_settings import BaseSettings
class TencentCloudCOSStorageConfig(BaseSettings):
"""
Configuration settings for Tencent Cloud Object Storage (COS)
"""
TENCENT_COS_BUCKET_NAME: Optional[str] = Field(
description="Name of the Tencent Cloud COS bucket to store and retrieve objects",
default=None,
)
TENCENT_COS_REGION: Optional[str] = Field(
description="Tencent Cloud region where the COS bucket is located (e.g., 'ap-guangzhou')",
default=None,
)
TENCENT_COS_SECRET_ID: Optional[str] = Field(
description="SecretId for authenticating with Tencent Cloud COS (part of API credentials)",
default=None,
)
TENCENT_COS_SECRET_KEY: Optional[str] = Field(
description="SecretKey for authenticating with Tencent Cloud COS (part of API credentials)",
default=None,
)
TENCENT_COS_SCHEME: Optional[str] = Field(
description="Protocol scheme for COS requests: 'https' (recommended) or 'http'",
default=None,
)

View File

@@ -0,0 +1,35 @@
from typing import Optional
from pydantic import Field
from pydantic_settings import BaseSettings
class VolcengineTOSStorageConfig(BaseSettings):
"""
Configuration settings for Volcengine Tinder Object Storage (TOS)
"""
VOLCENGINE_TOS_BUCKET_NAME: Optional[str] = Field(
description="Name of the Volcengine TOS bucket to store and retrieve objects (e.g., 'my-tos-bucket')",
default=None,
)
VOLCENGINE_TOS_ACCESS_KEY: Optional[str] = Field(
description="Access Key ID for authenticating with Volcengine TOS",
default=None,
)
VOLCENGINE_TOS_SECRET_KEY: Optional[str] = Field(
description="Secret Access Key for authenticating with Volcengine TOS",
default=None,
)
VOLCENGINE_TOS_ENDPOINT: Optional[str] = Field(
description="URL of the Volcengine TOS endpoint (e.g., 'https://tos-cn-beijing.volces.com')",
default=None,
)
VOLCENGINE_TOS_REGION: Optional[str] = Field(
description="Volcengine region where the TOS bucket is located (e.g., 'cn-beijing')",
default=None,
)

View File

@@ -0,0 +1,51 @@
from typing import Optional
from pydantic import Field, PositiveInt
from pydantic_settings import BaseSettings
class AnalyticdbConfig(BaseSettings):
"""
Configuration for connecting to Alibaba Cloud AnalyticDB for PostgreSQL.
Refer to the following documentation for details on obtaining credentials:
https://www.alibabacloud.com/help/en/analyticdb-for-postgresql/getting-started/create-an-instance-instances-with-vector-engine-optimization-enabled
"""
ANALYTICDB_KEY_ID: Optional[str] = Field(
default=None, description="The Access Key ID provided by Alibaba Cloud for API authentication."
)
ANALYTICDB_KEY_SECRET: Optional[str] = Field(
default=None, description="The Secret Access Key corresponding to the Access Key ID for secure API access."
)
ANALYTICDB_REGION_ID: Optional[str] = Field(
default=None,
description="The region where the AnalyticDB instance is deployed (e.g., 'cn-hangzhou', 'ap-southeast-1').",
)
ANALYTICDB_INSTANCE_ID: Optional[str] = Field(
default=None,
description="The unique identifier of the AnalyticDB instance you want to connect to.",
)
ANALYTICDB_ACCOUNT: Optional[str] = Field(
default=None,
description="The account name used to log in to the AnalyticDB instance"
" (usually the initial account created with the instance).",
)
ANALYTICDB_PASSWORD: Optional[str] = Field(
default=None, description="The password associated with the AnalyticDB account for database authentication."
)
ANALYTICDB_NAMESPACE: Optional[str] = Field(
default=None, description="The namespace within AnalyticDB for schema isolation (if using namespace feature)."
)
ANALYTICDB_NAMESPACE_PASSWORD: Optional[str] = Field(
default=None,
description="The password for accessing the specified namespace within the AnalyticDB instance"
" (if namespace feature is enabled).",
)
ANALYTICDB_HOST: Optional[str] = Field(
default=None, description="The host of the AnalyticDB instance you want to connect to."
)
ANALYTICDB_PORT: PositiveInt = Field(
default=5432, description="The port of the AnalyticDB instance you want to connect to."
)
ANALYTICDB_MIN_CONNECTION: PositiveInt = Field(default=1, description="Min connection of the AnalyticDB database.")
ANALYTICDB_MAX_CONNECTION: PositiveInt = Field(default=5, description="Max connection of the AnalyticDB database.")

View File

@@ -0,0 +1,45 @@
from typing import Optional
from pydantic import Field, NonNegativeInt, PositiveInt
from pydantic_settings import BaseSettings
class BaiduVectorDBConfig(BaseSettings):
"""
Configuration settings for Baidu Vector Database
"""
BAIDU_VECTOR_DB_ENDPOINT: Optional[str] = Field(
description="URL of the Baidu Vector Database service (e.g., 'http://vdb.bj.baidubce.com')",
default=None,
)
BAIDU_VECTOR_DB_CONNECTION_TIMEOUT_MS: PositiveInt = Field(
description="Timeout in milliseconds for Baidu Vector Database operations (default is 30000 milliseconds)",
default=30000,
)
BAIDU_VECTOR_DB_ACCOUNT: Optional[str] = Field(
description="Account for authenticating with the Baidu Vector Database",
default=None,
)
BAIDU_VECTOR_DB_API_KEY: Optional[str] = Field(
description="API key for authenticating with the Baidu Vector Database service",
default=None,
)
BAIDU_VECTOR_DB_DATABASE: Optional[str] = Field(
description="Name of the specific Baidu Vector Database to connect to",
default=None,
)
BAIDU_VECTOR_DB_SHARD: PositiveInt = Field(
description="Number of shards for the Baidu Vector Database (default is 1)",
default=1,
)
BAIDU_VECTOR_DB_REPLICAS: NonNegativeInt = Field(
description="Number of replicas for the Baidu Vector Database (default is 3)",
default=3,
)

View File

@@ -0,0 +1,40 @@
from typing import Optional
from pydantic import Field, PositiveInt
from pydantic_settings import BaseSettings
class ChromaConfig(BaseSettings):
"""
Configuration settings for Chroma vector database
"""
CHROMA_HOST: Optional[str] = Field(
description="Hostname or IP address of the Chroma server (e.g., 'localhost' or '192.168.1.100')",
default=None,
)
CHROMA_PORT: PositiveInt = Field(
description="Port number on which the Chroma server is listening (default is 8000)",
default=8000,
)
CHROMA_TENANT: Optional[str] = Field(
description="Tenant identifier for multi-tenancy support in Chroma",
default=None,
)
CHROMA_DATABASE: Optional[str] = Field(
description="Name of the Chroma database to connect to",
default=None,
)
CHROMA_AUTH_PROVIDER: Optional[str] = Field(
description="Authentication provider for Chroma (e.g., 'basic', 'token', or a custom provider)",
default=None,
)
CHROMA_AUTH_CREDENTIALS: Optional[str] = Field(
description="Authentication credentials for Chroma (format depends on the auth provider)",
default=None,
)

View File

@@ -0,0 +1,35 @@
from typing import Optional
from pydantic import Field
from pydantic_settings import BaseSettings
class CouchbaseConfig(BaseSettings):
"""
Couchbase configs
"""
COUCHBASE_CONNECTION_STRING: Optional[str] = Field(
description="COUCHBASE connection string",
default=None,
)
COUCHBASE_USER: Optional[str] = Field(
description="COUCHBASE user",
default=None,
)
COUCHBASE_PASSWORD: Optional[str] = Field(
description="COUCHBASE password",
default=None,
)
COUCHBASE_BUCKET_NAME: Optional[str] = Field(
description="COUCHBASE bucket name",
default=None,
)
COUCHBASE_SCOPE_NAME: Optional[str] = Field(
description="COUCHBASE scope name",
default=None,
)

View File

@@ -0,0 +1,30 @@
from typing import Optional
from pydantic import Field, PositiveInt
from pydantic_settings import BaseSettings
class ElasticsearchConfig(BaseSettings):
"""
Configuration settings for Elasticsearch
"""
ELASTICSEARCH_HOST: Optional[str] = Field(
description="Hostname or IP address of the Elasticsearch server (e.g., 'localhost' or '192.168.1.100')",
default="127.0.0.1",
)
ELASTICSEARCH_PORT: PositiveInt = Field(
description="Port number on which the Elasticsearch server is listening (default is 9200)",
default=9200,
)
ELASTICSEARCH_USERNAME: Optional[str] = Field(
description="Username for authenticating with Elasticsearch (default is 'elastic')",
default="elastic",
)
ELASTICSEARCH_PASSWORD: Optional[str] = Field(
description="Password for authenticating with Elasticsearch (default is 'elastic')",
default="elastic",
)

View File

@@ -0,0 +1,25 @@
from typing import Optional
from pydantic import Field
from pydantic_settings import BaseSettings
class HuaweiCloudConfig(BaseSettings):
"""
Configuration settings for Huawei cloud search service
"""
HUAWEI_CLOUD_HOSTS: Optional[str] = Field(
description="Hostname or IP address of the Huawei cloud search service instance",
default=None,
)
HUAWEI_CLOUD_USER: Optional[str] = Field(
description="Username for authenticating with Huawei cloud search service",
default=None,
)
HUAWEI_CLOUD_PASSWORD: Optional[str] = Field(
description="Password for authenticating with Huawei cloud search service",
default=None,
)

View File

@@ -0,0 +1,35 @@
from typing import Optional
from pydantic import Field
from pydantic_settings import BaseSettings
class LindormConfig(BaseSettings):
"""
Lindorm configs
"""
LINDORM_URL: Optional[str] = Field(
description="Lindorm url",
default=None,
)
LINDORM_USERNAME: Optional[str] = Field(
description="Lindorm user",
default=None,
)
LINDORM_PASSWORD: Optional[str] = Field(
description="Lindorm password",
default=None,
)
DEFAULT_INDEX_TYPE: Optional[str] = Field(
description="Lindorm Vector Index Type, hnsw or flat is available in dify",
default="hnsw",
)
DEFAULT_DISTANCE_TYPE: Optional[str] = Field(
description="Vector Distance Type, support l2, cosinesimil, innerproduct", default="l2"
)
USING_UGC_INDEX: Optional[bool] = Field(
description="Using UGC index will store the same type of Index in a single index but can retrieve separately.",
default=False,
)
LINDORM_QUERY_TIMEOUT: Optional[float] = Field(description="The lindorm search request timeout (s)", default=2.0)

View File

@@ -0,0 +1,46 @@
from typing import Optional
from pydantic import Field
from pydantic_settings import BaseSettings
class MilvusConfig(BaseSettings):
"""
Configuration settings for Milvus vector database
"""
MILVUS_URI: Optional[str] = Field(
description="URI for connecting to the Milvus server (e.g., 'http://localhost:19530' or 'https://milvus-instance.example.com:19530')",
default="http://127.0.0.1:19530",
)
MILVUS_TOKEN: Optional[str] = Field(
description="Authentication token for Milvus, if token-based authentication is enabled",
default=None,
)
MILVUS_USER: Optional[str] = Field(
description="Username for authenticating with Milvus, if username/password authentication is enabled",
default=None,
)
MILVUS_PASSWORD: Optional[str] = Field(
description="Password for authenticating with Milvus, if username/password authentication is enabled",
default=None,
)
MILVUS_DATABASE: str = Field(
description="Name of the Milvus database to connect to (default is 'default')",
default="default",
)
MILVUS_ENABLE_HYBRID_SEARCH: bool = Field(
description="Enable hybrid search features (requires Milvus >= 2.5.0). Set to false for compatibility with "
"older versions",
default=True,
)
MILVUS_ANALYZER_PARAMS: Optional[str] = Field(
description='Milvus text analyzer parameters, e.g., {"type": "chinese"} for Chinese segmentation support.',
default=None,
)

View File

@@ -0,0 +1,38 @@
from pydantic import Field, PositiveInt
from pydantic_settings import BaseSettings
class MyScaleConfig(BaseSettings):
"""
Configuration settings for MyScale vector database
"""
MYSCALE_HOST: str = Field(
description="Hostname or IP address of the MyScale server (e.g., 'localhost' or 'myscale.example.com')",
default="localhost",
)
MYSCALE_PORT: PositiveInt = Field(
description="Port number on which the MyScale server is listening (default is 8123)",
default=8123,
)
MYSCALE_USER: str = Field(
description="Username for authenticating with MyScale (default is 'default')",
default="default",
)
MYSCALE_PASSWORD: str = Field(
description="Password for authenticating with MyScale (default is an empty string)",
default="",
)
MYSCALE_DATABASE: str = Field(
description="Name of the MyScale database to connect to (default is 'default')",
default="default",
)
MYSCALE_FTS_PARAMS: str = Field(
description="Additional parameters for MyScale Full Text Search index)",
default="",
)

View File

@@ -0,0 +1,41 @@
from typing import Optional
from pydantic import Field, PositiveInt
from pydantic_settings import BaseSettings
class OceanBaseVectorConfig(BaseSettings):
"""
Configuration settings for OceanBase Vector database
"""
OCEANBASE_VECTOR_HOST: Optional[str] = Field(
description="Hostname or IP address of the OceanBase Vector server (e.g. 'localhost')",
default=None,
)
OCEANBASE_VECTOR_PORT: Optional[PositiveInt] = Field(
description="Port number on which the OceanBase Vector server is listening (default is 2881)",
default=2881,
)
OCEANBASE_VECTOR_USER: Optional[str] = Field(
description="Username for authenticating with the OceanBase Vector database",
default=None,
)
OCEANBASE_VECTOR_PASSWORD: Optional[str] = Field(
description="Password for authenticating with the OceanBase Vector database",
default=None,
)
OCEANBASE_VECTOR_DATABASE: Optional[str] = Field(
description="Name of the OceanBase Vector database to connect to",
default=None,
)
OCEANBASE_ENABLE_HYBRID_SEARCH: bool = Field(
description="Enable hybrid search features (requires OceanBase >= 4.3.5.1). Set to false for compatibility "
"with older versions",
default=False,
)

View File

@@ -0,0 +1,50 @@
from typing import Optional
from pydantic import Field, PositiveInt
from pydantic_settings import BaseSettings
class OpenGaussConfig(BaseSettings):
"""
Configuration settings for OpenGauss
"""
OPENGAUSS_HOST: Optional[str] = Field(
description="Hostname or IP address of the OpenGauss server(e.g., 'localhost')",
default=None,
)
OPENGAUSS_PORT: PositiveInt = Field(
description="Port number on which the OpenGauss server is listening (default is 6600)",
default=6600,
)
OPENGAUSS_USER: Optional[str] = Field(
description="Username for authenticating with the OpenGauss database",
default=None,
)
OPENGAUSS_PASSWORD: Optional[str] = Field(
description="Password for authenticating with the OpenGauss database",
default=None,
)
OPENGAUSS_DATABASE: Optional[str] = Field(
description="Name of the OpenGauss database to connect to",
default=None,
)
OPENGAUSS_MIN_CONNECTION: PositiveInt = Field(
description="Min connection of the OpenGauss database",
default=1,
)
OPENGAUSS_MAX_CONNECTION: PositiveInt = Field(
description="Max connection of the OpenGauss database",
default=5,
)
OPENGAUSS_ENABLE_PQ: bool = Field(
description="Enable openGauss PQ acceleration feature",
default=False,
)

View File

@@ -0,0 +1,58 @@
import enum
from typing import Literal, Optional
from pydantic import Field, PositiveInt
from pydantic_settings import BaseSettings
class OpenSearchConfig(BaseSettings):
"""
Configuration settings for OpenSearch
"""
class AuthMethod(enum.StrEnum):
"""
Authentication method for OpenSearch
"""
BASIC = "basic"
AWS_MANAGED_IAM = "aws_managed_iam"
OPENSEARCH_HOST: Optional[str] = Field(
description="Hostname or IP address of the OpenSearch server (e.g., 'localhost' or 'opensearch.example.com')",
default=None,
)
OPENSEARCH_PORT: PositiveInt = Field(
description="Port number on which the OpenSearch server is listening (default is 9200)",
default=9200,
)
OPENSEARCH_SECURE: bool = Field(
description="Whether to use SSL/TLS encrypted connection for OpenSearch (True for HTTPS, False for HTTP)",
default=False,
)
OPENSEARCH_AUTH_METHOD: AuthMethod = Field(
description="Authentication method for OpenSearch connection (default is 'basic')",
default=AuthMethod.BASIC,
)
OPENSEARCH_USER: Optional[str] = Field(
description="Username for authenticating with OpenSearch",
default=None,
)
OPENSEARCH_PASSWORD: Optional[str] = Field(
description="Password for authenticating with OpenSearch",
default=None,
)
OPENSEARCH_AWS_REGION: Optional[str] = Field(
description="AWS region for OpenSearch (e.g. 'us-west-2')",
default=None,
)
OPENSEARCH_AWS_SERVICE: Optional[Literal["es", "aoss"]] = Field(
description="AWS service for OpenSearch (e.g. 'aoss' for OpenSearch Serverless)", default=None
)

View File

@@ -0,0 +1,46 @@
from typing import Optional
from pydantic import Field
from pydantic_settings import BaseSettings
class OracleConfig(BaseSettings):
"""
Configuration settings for Oracle database
"""
ORACLE_USER: Optional[str] = Field(
description="Username for authenticating with the Oracle database",
default=None,
)
ORACLE_PASSWORD: Optional[str] = Field(
description="Password for authenticating with the Oracle database",
default=None,
)
ORACLE_DSN: Optional[str] = Field(
description="Oracle database connection string. For traditional database, use format 'host:port/service_name'. "
"For autonomous database, use the service name from tnsnames.ora in the wallet",
default=None,
)
ORACLE_CONFIG_DIR: Optional[str] = Field(
description="Directory containing the tnsnames.ora configuration file. Only used in thin mode connection",
default=None,
)
ORACLE_WALLET_LOCATION: Optional[str] = Field(
description="Oracle wallet directory path containing the wallet files for secure connection",
default=None,
)
ORACLE_WALLET_PASSWORD: Optional[str] = Field(
description="Password to decrypt the Oracle wallet, if it is encrypted",
default=None,
)
ORACLE_IS_AUTONOMOUS: bool = Field(
description="Flag indicating whether connecting to Oracle Autonomous Database",
default=False,
)

View File

@@ -0,0 +1,50 @@
from typing import Optional
from pydantic import Field, PositiveInt
from pydantic_settings import BaseSettings
class PGVectorConfig(BaseSettings):
"""
Configuration settings for PGVector (PostgreSQL with vector extension)
"""
PGVECTOR_HOST: Optional[str] = Field(
description="Hostname or IP address of the PostgreSQL server with PGVector extension (e.g., 'localhost')",
default=None,
)
PGVECTOR_PORT: PositiveInt = Field(
description="Port number on which the PostgreSQL server is listening (default is 5433)",
default=5433,
)
PGVECTOR_USER: Optional[str] = Field(
description="Username for authenticating with the PostgreSQL database",
default=None,
)
PGVECTOR_PASSWORD: Optional[str] = Field(
description="Password for authenticating with the PostgreSQL database",
default=None,
)
PGVECTOR_DATABASE: Optional[str] = Field(
description="Name of the PostgreSQL database to connect to",
default=None,
)
PGVECTOR_MIN_CONNECTION: PositiveInt = Field(
description="Min connection of the PostgreSQL database",
default=1,
)
PGVECTOR_MAX_CONNECTION: PositiveInt = Field(
description="Max connection of the PostgreSQL database",
default=5,
)
PGVECTOR_PG_BIGM: bool = Field(
description="Whether to use pg_bigm module for full text search",
default=False,
)

View File

@@ -0,0 +1,35 @@
from typing import Optional
from pydantic import Field, PositiveInt
from pydantic_settings import BaseSettings
class PGVectoRSConfig(BaseSettings):
"""
Configuration settings for PGVecto.RS (Rust-based vector extension for PostgreSQL)
"""
PGVECTO_RS_HOST: Optional[str] = Field(
description="Hostname or IP address of the PostgreSQL server with PGVecto.RS extension (e.g., 'localhost')",
default=None,
)
PGVECTO_RS_PORT: PositiveInt = Field(
description="Port number on which the PostgreSQL server with PGVecto.RS is listening (default is 5431)",
default=5431,
)
PGVECTO_RS_USER: Optional[str] = Field(
description="Username for authenticating with the PostgreSQL database using PGVecto.RS",
default=None,
)
PGVECTO_RS_PASSWORD: Optional[str] = Field(
description="Password for authenticating with the PostgreSQL database using PGVecto.RS",
default=None,
)
PGVECTO_RS_DATABASE: Optional[str] = Field(
description="Name of the PostgreSQL database with PGVecto.RS extension to connect to",
default=None,
)

View File

@@ -0,0 +1,35 @@
from typing import Optional
from pydantic import Field, NonNegativeInt, PositiveInt
from pydantic_settings import BaseSettings
class QdrantConfig(BaseSettings):
"""
Configuration settings for Qdrant vector database
"""
QDRANT_URL: Optional[str] = Field(
description="URL of the Qdrant server (e.g., 'http://localhost:6333' or 'https://qdrant.example.com')",
default=None,
)
QDRANT_API_KEY: Optional[str] = Field(
description="API key for authenticating with the Qdrant server",
default=None,
)
QDRANT_CLIENT_TIMEOUT: NonNegativeInt = Field(
description="Timeout in seconds for Qdrant client operations (default is 20 seconds)",
default=20,
)
QDRANT_GRPC_ENABLED: bool = Field(
description="Whether to enable gRPC support for Qdrant connection (True for gRPC, False for HTTP)",
default=False,
)
QDRANT_GRPC_PORT: PositiveInt = Field(
description="Port number for gRPC connection to Qdrant server (default is 6334)",
default=6334,
)

View File

@@ -0,0 +1,35 @@
from typing import Optional
from pydantic import Field, PositiveInt
from pydantic_settings import BaseSettings
class RelytConfig(BaseSettings):
"""
Configuration settings for Relyt database
"""
RELYT_HOST: Optional[str] = Field(
description="Hostname or IP address of the Relyt server (e.g., 'localhost' or 'relyt.example.com')",
default=None,
)
RELYT_PORT: PositiveInt = Field(
description="Port number on which the Relyt server is listening (default is 9200)",
default=9200,
)
RELYT_USER: Optional[str] = Field(
description="Username for authenticating with the Relyt database",
default=None,
)
RELYT_PASSWORD: Optional[str] = Field(
description="Password for authenticating with the Relyt database",
default=None,
)
RELYT_DATABASE: Optional[str] = Field(
description="Name of the Relyt database to connect to (default is 'default')",
default="default",
)

View File

@@ -0,0 +1,30 @@
from typing import Optional
from pydantic import Field
from pydantic_settings import BaseSettings
class TableStoreConfig(BaseSettings):
"""
Configuration settings for TableStore.
"""
TABLESTORE_ENDPOINT: Optional[str] = Field(
description="Endpoint address of the TableStore server (e.g. 'https://instance-name.cn-hangzhou.ots.aliyuncs.com')",
default=None,
)
TABLESTORE_INSTANCE_NAME: Optional[str] = Field(
description="Instance name to access TableStore server (eg. 'instance-name')",
default=None,
)
TABLESTORE_ACCESS_KEY_ID: Optional[str] = Field(
description="AccessKey id for the instance name",
default=None,
)
TABLESTORE_ACCESS_KEY_SECRET: Optional[str] = Field(
description="AccessKey secret for the instance name",
default=None,
)

View File

@@ -0,0 +1,55 @@
from typing import Optional
from pydantic import Field, NonNegativeInt, PositiveInt
from pydantic_settings import BaseSettings
class TencentVectorDBConfig(BaseSettings):
"""
Configuration settings for Tencent Vector Database
"""
TENCENT_VECTOR_DB_URL: Optional[str] = Field(
description="URL of the Tencent Vector Database service (e.g., 'https://vectordb.tencentcloudapi.com')",
default=None,
)
TENCENT_VECTOR_DB_API_KEY: Optional[str] = Field(
description="API key for authenticating with the Tencent Vector Database service",
default=None,
)
TENCENT_VECTOR_DB_TIMEOUT: PositiveInt = Field(
description="Timeout in seconds for Tencent Vector Database operations (default is 30 seconds)",
default=30,
)
TENCENT_VECTOR_DB_USERNAME: Optional[str] = Field(
description="Username for authenticating with the Tencent Vector Database (if required)",
default=None,
)
TENCENT_VECTOR_DB_PASSWORD: Optional[str] = Field(
description="Password for authenticating with the Tencent Vector Database (if required)",
default=None,
)
TENCENT_VECTOR_DB_SHARD: PositiveInt = Field(
description="Number of shards for the Tencent Vector Database (default is 1)",
default=1,
)
TENCENT_VECTOR_DB_REPLICAS: NonNegativeInt = Field(
description="Number of replicas for the Tencent Vector Database (default is 2)",
default=2,
)
TENCENT_VECTOR_DB_DATABASE: Optional[str] = Field(
description="Name of the specific Tencent Vector Database to connect to",
default=None,
)
TENCENT_VECTOR_DB_ENABLE_HYBRID_SEARCH: bool = Field(
description="Enable hybrid search features",
default=False,
)

View File

@@ -0,0 +1,70 @@
from typing import Optional
from pydantic import Field, NonNegativeInt, PositiveInt
from pydantic_settings import BaseSettings
class TidbOnQdrantConfig(BaseSettings):
"""
Tidb on Qdrant configs
"""
TIDB_ON_QDRANT_URL: Optional[str] = Field(
description="Tidb on Qdrant url",
default=None,
)
TIDB_ON_QDRANT_API_KEY: Optional[str] = Field(
description="Tidb on Qdrant api key",
default=None,
)
TIDB_ON_QDRANT_CLIENT_TIMEOUT: NonNegativeInt = Field(
description="Tidb on Qdrant client timeout in seconds",
default=20,
)
TIDB_ON_QDRANT_GRPC_ENABLED: bool = Field(
description="whether enable grpc support for Tidb on Qdrant connection",
default=False,
)
TIDB_ON_QDRANT_GRPC_PORT: PositiveInt = Field(
description="Tidb on Qdrant grpc port",
default=6334,
)
TIDB_PUBLIC_KEY: Optional[str] = Field(
description="Tidb account public key",
default=None,
)
TIDB_PRIVATE_KEY: Optional[str] = Field(
description="Tidb account private key",
default=None,
)
TIDB_API_URL: Optional[str] = Field(
description="Tidb API url",
default=None,
)
TIDB_IAM_API_URL: Optional[str] = Field(
description="Tidb IAM API url",
default=None,
)
TIDB_REGION: Optional[str] = Field(
description="Tidb serverless region",
default="regions/aws-us-east-1",
)
TIDB_PROJECT_ID: Optional[str] = Field(
description="Tidb project id",
default=None,
)
TIDB_SPEND_LIMIT: Optional[int] = Field(
description="Tidb spend limit",
default=100,
)

View File

@@ -0,0 +1,35 @@
from typing import Optional
from pydantic import Field, PositiveInt
from pydantic_settings import BaseSettings
class TiDBVectorConfig(BaseSettings):
"""
Configuration settings for TiDB Vector database
"""
TIDB_VECTOR_HOST: Optional[str] = Field(
description="Hostname or IP address of the TiDB Vector server (e.g., 'localhost' or 'tidb.example.com')",
default=None,
)
TIDB_VECTOR_PORT: Optional[PositiveInt] = Field(
description="Port number on which the TiDB Vector server is listening (default is 4000)",
default=4000,
)
TIDB_VECTOR_USER: Optional[str] = Field(
description="Username for authenticating with the TiDB Vector database",
default=None,
)
TIDB_VECTOR_PASSWORD: Optional[str] = Field(
description="Password for authenticating with the TiDB Vector database",
default=None,
)
TIDB_VECTOR_DATABASE: Optional[str] = Field(
description="Name of the TiDB Vector database to connect to",
default=None,
)

View File

@@ -0,0 +1,20 @@
from typing import Optional
from pydantic import Field
from pydantic_settings import BaseSettings
class UpstashConfig(BaseSettings):
"""
Configuration settings for Upstash vector database
"""
UPSTASH_VECTOR_URL: Optional[str] = Field(
description="URL of the upstash server (e.g., 'https://vector.upstash.io')",
default=None,
)
UPSTASH_VECTOR_TOKEN: Optional[str] = Field(
description="Token for authenticating with the upstash server",
default=None,
)

View File

@@ -0,0 +1,45 @@
from typing import Optional
from pydantic import Field, PositiveInt
from pydantic_settings import BaseSettings
class VastbaseVectorConfig(BaseSettings):
"""
Configuration settings for Vector (Vastbase with vector extension)
"""
VASTBASE_HOST: Optional[str] = Field(
description="Hostname or IP address of the Vastbase server with Vector extension (e.g., 'localhost')",
default=None,
)
VASTBASE_PORT: PositiveInt = Field(
description="Port number on which the Vastbase server is listening (default is 5432)",
default=5432,
)
VASTBASE_USER: Optional[str] = Field(
description="Username for authenticating with the Vastbase database",
default=None,
)
VASTBASE_PASSWORD: Optional[str] = Field(
description="Password for authenticating with the Vastbase database",
default=None,
)
VASTBASE_DATABASE: Optional[str] = Field(
description="Name of the Vastbase database to connect to",
default=None,
)
VASTBASE_MIN_CONNECTION: PositiveInt = Field(
description="Min connection of the Vastbase database",
default=1,
)
VASTBASE_MAX_CONNECTION: PositiveInt = Field(
description="Max connection of the Vastbase database",
default=5,
)

View File

@@ -0,0 +1,50 @@
from typing import Optional
from pydantic import Field
from pydantic_settings import BaseSettings
class VikingDBConfig(BaseSettings):
"""
Configuration for connecting to Volcengine VikingDB.
Refer to the following documentation for details on obtaining credentials:
https://www.volcengine.com/docs/6291/65568
"""
VIKINGDB_ACCESS_KEY: Optional[str] = Field(
description="The Access Key provided by Volcengine VikingDB for API authentication."
"Refer to the following documentation for details on obtaining credentials:"
"https://www.volcengine.com/docs/6291/65568",
default=None,
)
VIKINGDB_SECRET_KEY: Optional[str] = Field(
description="The Secret Key provided by Volcengine VikingDB for API authentication.",
default=None,
)
VIKINGDB_REGION: str = Field(
description="The region of the Volcengine VikingDB service.(e.g., 'cn-shanghai', 'cn-beijing').",
default="cn-shanghai",
)
VIKINGDB_HOST: str = Field(
description="The host of the Volcengine VikingDB service.(e.g., 'api-vikingdb.volces.com', \
'api-vikingdb.mlp.cn-shanghai.volces.com')",
default="api-vikingdb.mlp.cn-shanghai.volces.com",
)
VIKINGDB_SCHEME: str = Field(
description="The scheme of the Volcengine VikingDB service.(e.g., 'http', 'https').",
default="http",
)
VIKINGDB_CONNECTION_TIMEOUT: int = Field(
description="The connection timeout of the Volcengine VikingDB service.",
default=30,
)
VIKINGDB_SOCKET_TIMEOUT: int = Field(
description="The socket timeout of the Volcengine VikingDB service.",
default=30,
)

View File

@@ -0,0 +1,30 @@
from typing import Optional
from pydantic import Field, PositiveInt
from pydantic_settings import BaseSettings
class WeaviateConfig(BaseSettings):
"""
Configuration settings for Weaviate vector database
"""
WEAVIATE_ENDPOINT: Optional[str] = Field(
description="URL of the Weaviate server (e.g., 'http://localhost:8080' or 'https://weaviate.example.com')",
default=None,
)
WEAVIATE_API_KEY: Optional[str] = Field(
description="API key for authenticating with the Weaviate server",
default=None,
)
WEAVIATE_GRPC_ENABLED: bool = Field(
description="Whether to enable gRPC for Weaviate connection (True for gRPC, False for HTTP)",
default=True,
)
WEAVIATE_BATCH_SIZE: PositiveInt = Field(
description="Number of objects to be processed in a single batch operation (default is 100)",
default=100,
)

View File

@@ -0,0 +1,9 @@
from configs.observability.otel.otel_config import OTelConfig
class ObservabilityConfig(OTelConfig):
"""
Observability configuration settings
"""
pass

View File

@@ -0,0 +1,49 @@
from pydantic import Field
from pydantic_settings import BaseSettings
class OTelConfig(BaseSettings):
"""
OpenTelemetry configuration settings
"""
ENABLE_OTEL: bool = Field(
description="Whether to enable OpenTelemetry",
default=False,
)
OTLP_BASE_ENDPOINT: str = Field(
description="OTLP base endpoint",
default="http://localhost:4318",
)
OTLP_API_KEY: str = Field(
description="OTLP API key",
default="",
)
OTEL_EXPORTER_TYPE: str = Field(
description="OTEL exporter type",
default="otlp",
)
OTEL_EXPORTER_OTLP_PROTOCOL: str = Field(
description="OTLP exporter protocol ('grpc' or 'http')",
default="http",
)
OTEL_SAMPLING_RATE: float = Field(default=0.1, description="Sampling rate for traces (0.0 to 1.0)")
OTEL_BATCH_EXPORT_SCHEDULE_DELAY: int = Field(
default=5000, description="Batch export schedule delay in milliseconds"
)
OTEL_MAX_QUEUE_SIZE: int = Field(default=2048, description="Maximum queue size for the batch span processor")
OTEL_MAX_EXPORT_BATCH_SIZE: int = Field(default=512, description="Maximum export batch size")
OTEL_METRIC_EXPORT_INTERVAL: int = Field(default=60000, description="Metric export interval in milliseconds")
OTEL_BATCH_EXPORT_TIMEOUT: int = Field(default=10000, description="Batch export timeout in milliseconds")
OTEL_METRIC_EXPORT_TIMEOUT: int = Field(default=30000, description="Metric export timeout in milliseconds")

View File

@@ -0,0 +1,18 @@
from pydantic import Field
from pydantic_settings import BaseSettings
class PackagingInfo(BaseSettings):
"""
Packaging build information
"""
CURRENT_VERSION: str = Field(
description="Dify version",
default="1.4.0",
)
COMMIT_SHA: str = Field(
description="SHA-1 checksum of the git commit used to build the app",
default="",
)

View File

@@ -0,0 +1,15 @@
from pydantic import Field
from .apollo import ApolloSettingsSourceInfo
from .base import RemoteSettingsSource
from .enums import RemoteSettingsSourceName
class RemoteSettingsSourceConfig(ApolloSettingsSourceInfo):
REMOTE_SETTINGS_SOURCE_NAME: RemoteSettingsSourceName | str = Field(
description="name of remote config source",
default="",
)
__all__ = ["RemoteSettingsSource", "RemoteSettingsSourceConfig", "RemoteSettingsSourceName"]

View File

@@ -0,0 +1,55 @@
from collections.abc import Mapping
from typing import Any, Optional
from pydantic import Field
from pydantic.fields import FieldInfo
from pydantic_settings import BaseSettings
from configs.remote_settings_sources.base import RemoteSettingsSource
from .client import ApolloClient
class ApolloSettingsSourceInfo(BaseSettings):
"""
Packaging build information
"""
APOLLO_APP_ID: Optional[str] = Field(
description="apollo app_id",
default=None,
)
APOLLO_CLUSTER: Optional[str] = Field(
description="apollo cluster",
default=None,
)
APOLLO_CONFIG_URL: Optional[str] = Field(
description="apollo config url",
default=None,
)
APOLLO_NAMESPACE: Optional[str] = Field(
description="apollo namespace",
default=None,
)
class ApolloSettingsSource(RemoteSettingsSource):
def __init__(self, configs: Mapping[str, Any]):
self.client = ApolloClient(
app_id=configs["APOLLO_APP_ID"],
cluster=configs["APOLLO_CLUSTER"],
config_url=configs["APOLLO_CONFIG_URL"],
start_hot_update=False,
_notification_map={configs["APOLLO_NAMESPACE"]: -1},
)
self.namespace = configs["APOLLO_NAMESPACE"]
self.remote_configs = self.client.get_all_dicts(self.namespace)
def get_field_value(self, field: FieldInfo, field_name: str) -> tuple[Any, str, bool]:
if not isinstance(self.remote_configs, dict):
raise ValueError(f"remote configs is not dict, but {type(self.remote_configs)}")
field_value = self.remote_configs.get(field_name)
return field_value, field_name, False

View File

@@ -0,0 +1,304 @@
import hashlib
import json
import logging
import os
import threading
import time
from collections.abc import Mapping
from pathlib import Path
from .python_3x import http_request, makedirs_wrapper
from .utils import (
CONFIGURATIONS,
NAMESPACE_NAME,
NOTIFICATION_ID,
get_value_from_dict,
init_ip,
no_key_cache_key,
signature,
url_encode_wrapper,
)
logger = logging.getLogger(__name__)
class ApolloClient:
def __init__(
self,
config_url,
app_id,
cluster="default",
secret="",
start_hot_update=True,
change_listener=None,
_notification_map=None,
):
# Core routing parameters
self.config_url = config_url
self.cluster = cluster
self.app_id = app_id
# Non-core parameters
self.ip = init_ip()
self.secret = secret
# Check the parameter variables
# Private control variables
self._cycle_time = 5
self._stopping = False
self._cache = {}
self._no_key = {}
self._hash = {}
self._pull_timeout = 75
self._cache_file_path = os.path.expanduser("~") + "/.dify/config/remote-settings/apollo/cache/"
self._long_poll_thread = None
self._change_listener = change_listener # "add" "delete" "update"
if _notification_map is None:
_notification_map = {"application": -1}
self._notification_map = _notification_map
self.last_release_key = None
# Private startup method
self._path_checker()
if start_hot_update:
self._start_hot_update()
# start the heartbeat thread
heartbeat = threading.Thread(target=self._heart_beat)
heartbeat.daemon = True
heartbeat.start()
def get_json_from_net(self, namespace="application"):
url = "{}/configs/{}/{}/{}?releaseKey={}&ip={}".format(
self.config_url, self.app_id, self.cluster, namespace, "", self.ip
)
try:
code, body = http_request(url, timeout=3, headers=self._sign_headers(url))
if code == 200:
if not body:
logger.error(f"get_json_from_net load configs failed, body is {body}")
return None
data = json.loads(body)
data = data["configurations"]
return_data = {CONFIGURATIONS: data}
return return_data
else:
return None
except Exception:
logger.exception("an error occurred in get_json_from_net")
return None
def get_value(self, key, default_val=None, namespace="application"):
try:
# read memory configuration
namespace_cache = self._cache.get(namespace)
val = get_value_from_dict(namespace_cache, key)
if val is not None:
return val
no_key = no_key_cache_key(namespace, key)
if no_key in self._no_key:
return default_val
# read the network configuration
namespace_data = self.get_json_from_net(namespace)
val = get_value_from_dict(namespace_data, key)
if val is not None:
self._update_cache_and_file(namespace_data, namespace)
return val
# read the file configuration
namespace_cache = self._get_local_cache(namespace)
val = get_value_from_dict(namespace_cache, key)
if val is not None:
self._update_cache_and_file(namespace_cache, namespace)
return val
# If all of them are not obtained, the default value is returned
# and the local cache is set to None
self._set_local_cache_none(namespace, key)
return default_val
except Exception:
logger.exception("get_value has error, [key is %s], [namespace is %s]", key, namespace)
return default_val
# Set the key of a namespace to none, and do not set default val
# to ensure the real-time correctness of the function call.
# If the user does not have the same default val twice
# and the default val is used here, there may be a problem.
def _set_local_cache_none(self, namespace, key):
no_key = no_key_cache_key(namespace, key)
self._no_key[no_key] = key
def _start_hot_update(self):
self._long_poll_thread = threading.Thread(target=self._listener)
# When the asynchronous thread is started, the daemon thread will automatically exit
# when the main thread is launched.
self._long_poll_thread.daemon = True
self._long_poll_thread.start()
def stop(self):
self._stopping = True
logger.info("Stopping listener...")
# Call the set callback function, and if it is abnormal, try it out
def _call_listener(self, namespace, old_kv, new_kv):
if self._change_listener is None:
return
if old_kv is None:
old_kv = {}
if new_kv is None:
new_kv = {}
try:
for key in old_kv:
new_value = new_kv.get(key)
old_value = old_kv.get(key)
if new_value is None:
# If newValue is empty, it means key, and the value is deleted.
self._change_listener("delete", namespace, key, old_value)
continue
if new_value != old_value:
self._change_listener("update", namespace, key, new_value)
continue
for key in new_kv:
new_value = new_kv.get(key)
old_value = old_kv.get(key)
if old_value is None:
self._change_listener("add", namespace, key, new_value)
except BaseException as e:
logger.warning(str(e))
def _path_checker(self):
if not os.path.isdir(self._cache_file_path):
makedirs_wrapper(self._cache_file_path)
# update the local cache and file cache
def _update_cache_and_file(self, namespace_data, namespace="application"):
# update the local cache
self._cache[namespace] = namespace_data
# update the file cache
new_string = json.dumps(namespace_data)
new_hash = hashlib.md5(new_string.encode("utf-8")).hexdigest()
if self._hash.get(namespace) == new_hash:
pass
else:
file_path = Path(self._cache_file_path) / f"{self.app_id}_configuration_{namespace}.txt"
file_path.write_text(new_string)
self._hash[namespace] = new_hash
# get the configuration from the local file
def _get_local_cache(self, namespace="application"):
cache_file_path = os.path.join(self._cache_file_path, f"{self.app_id}_configuration_{namespace}.txt")
if os.path.isfile(cache_file_path):
with open(cache_file_path) as f:
result = json.loads(f.readline())
return result
return {}
def _long_poll(self):
notifications = []
for key in self._cache:
namespace_data = self._cache[key]
notification_id = -1
if NOTIFICATION_ID in namespace_data:
notification_id = self._cache[key][NOTIFICATION_ID]
notifications.append({NAMESPACE_NAME: key, NOTIFICATION_ID: notification_id})
try:
# if the length is 0 it is returned directly
if len(notifications) == 0:
return
url = "{}/notifications/v2".format(self.config_url)
params = {
"appId": self.app_id,
"cluster": self.cluster,
"notifications": json.dumps(notifications, ensure_ascii=False),
}
param_str = url_encode_wrapper(params)
url = url + "?" + param_str
code, body = http_request(url, self._pull_timeout, headers=self._sign_headers(url))
http_code = code
if http_code == 304:
logger.debug("No change, loop...")
return
if http_code == 200:
if not body:
logger.error(f"_long_poll load configs failed,body is {body}")
return
data = json.loads(body)
for entry in data:
namespace = entry[NAMESPACE_NAME]
n_id = entry[NOTIFICATION_ID]
logger.info("%s has changes: notificationId=%d", namespace, n_id)
self._get_net_and_set_local(namespace, n_id, call_change=True)
return
else:
logger.warning("Sleep...")
except Exception as e:
logger.warning(str(e))
def _get_net_and_set_local(self, namespace, n_id, call_change=False):
namespace_data = self.get_json_from_net(namespace)
if not namespace_data:
return
namespace_data[NOTIFICATION_ID] = n_id
old_namespace = self._cache.get(namespace)
self._update_cache_and_file(namespace_data, namespace)
if self._change_listener is not None and call_change and old_namespace:
old_kv = old_namespace.get(CONFIGURATIONS)
new_kv = namespace_data.get(CONFIGURATIONS)
self._call_listener(namespace, old_kv, new_kv)
def _listener(self):
logger.info("start long_poll")
while not self._stopping:
self._long_poll()
time.sleep(self._cycle_time)
logger.info("stopped, long_poll")
# add the need for endorsement to the header
def _sign_headers(self, url: str) -> Mapping[str, str]:
headers: dict[str, str] = {}
if self.secret == "":
return headers
uri = url[len(self.config_url) : len(url)]
time_unix_now = str(int(round(time.time() * 1000)))
headers["Authorization"] = "Apollo " + self.app_id + ":" + signature(time_unix_now, uri, self.secret)
headers["Timestamp"] = time_unix_now
return headers
def _heart_beat(self):
while not self._stopping:
for namespace in self._notification_map:
self._do_heart_beat(namespace)
time.sleep(60 * 10) # 10 minutes
def _do_heart_beat(self, namespace):
url = "{}/configs/{}/{}/{}?ip={}".format(self.config_url, self.app_id, self.cluster, namespace, self.ip)
try:
code, body = http_request(url, timeout=3, headers=self._sign_headers(url))
if code == 200:
if not body:
logger.error(f"_do_heart_beat load configs failed,body is {body}")
return None
data = json.loads(body)
if self.last_release_key == data["releaseKey"]:
return None
self.last_release_key = data["releaseKey"]
data = data["configurations"]
self._update_cache_and_file(data, namespace)
else:
return None
except Exception:
logger.exception("an error occurred in _do_heart_beat")
return None
def get_all_dicts(self, namespace):
namespace_data = self._cache.get(namespace)
if namespace_data is None:
net_namespace_data = self.get_json_from_net(namespace)
if not net_namespace_data:
return namespace_data
namespace_data = net_namespace_data.get(CONFIGURATIONS)
if namespace_data:
self._update_cache_and_file(namespace_data, namespace)
return namespace_data

View File

@@ -0,0 +1,41 @@
import logging
import os
import ssl
import urllib.request
from urllib import parse
from urllib.error import HTTPError
# Create an SSL context that allows for a lower level of security
ssl_context = ssl.create_default_context()
ssl_context.set_ciphers("HIGH:!DH:!aNULL")
ssl_context.check_hostname = False
ssl_context.verify_mode = ssl.CERT_NONE
# Create an opener object and pass in a custom SSL context
opener = urllib.request.build_opener(urllib.request.HTTPSHandler(context=ssl_context))
urllib.request.install_opener(opener)
logger = logging.getLogger(__name__)
def http_request(url, timeout, headers={}):
try:
request = urllib.request.Request(url, headers=headers)
res = urllib.request.urlopen(request, timeout=timeout)
body = res.read().decode("utf-8")
return res.code, body
except HTTPError as e:
if e.code == 304:
logger.warning("http_request error,code is 304, maybe you should check secret")
return 304, None
logger.warning("http_request error,code is %d, msg is %s", e.code, e.msg)
raise e
def url_encode(params):
return parse.urlencode(params)
def makedirs_wrapper(path):
os.makedirs(path, exist_ok=True)

View File

@@ -0,0 +1,51 @@
import hashlib
import socket
from .python_3x import url_encode
# define constants
CONFIGURATIONS = "configurations"
NOTIFICATION_ID = "notificationId"
NAMESPACE_NAME = "namespaceName"
# add timestamps uris and keys
def signature(timestamp, uri, secret):
import base64
import hmac
string_to_sign = "" + timestamp + "\n" + uri
hmac_code = hmac.new(secret.encode(), string_to_sign.encode(), hashlib.sha1).digest()
return base64.b64encode(hmac_code).decode()
def url_encode_wrapper(params):
return url_encode(params)
def no_key_cache_key(namespace, key):
return "{}{}{}".format(namespace, len(namespace), key)
# Returns whether the obtained value is obtained, and None if it does not
def get_value_from_dict(namespace_cache, key):
if namespace_cache:
kv_data = namespace_cache.get(CONFIGURATIONS)
if kv_data is None:
return None
if key in kv_data:
return kv_data[key]
return None
def init_ip():
ip = ""
s = None
try:
s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
s.connect(("8.8.8.8", 53))
ip = s.getsockname()[0]
finally:
if s:
s.close()
return ip

View File

@@ -0,0 +1,15 @@
from collections.abc import Mapping
from typing import Any
from pydantic.fields import FieldInfo
class RemoteSettingsSource:
def __init__(self, configs: Mapping[str, Any]):
pass
def get_field_value(self, field: FieldInfo, field_name: str) -> tuple[Any, str, bool]:
raise NotImplementedError
def prepare_field_value(self, field_name: str, field: FieldInfo, value: Any, value_is_complex: bool) -> Any:
return value

View File

@@ -0,0 +1,6 @@
from enum import StrEnum
class RemoteSettingsSourceName(StrEnum):
APOLLO = "apollo"
NACOS = "nacos"

View File

@@ -0,0 +1,52 @@
import logging
import os
from collections.abc import Mapping
from typing import Any
from pydantic.fields import FieldInfo
from .http_request import NacosHttpClient
logger = logging.getLogger(__name__)
from configs.remote_settings_sources.base import RemoteSettingsSource
from .utils import _parse_config
class NacosSettingsSource(RemoteSettingsSource):
def __init__(self, configs: Mapping[str, Any]):
self.configs = configs
self.remote_configs: dict[str, Any] = {}
self.async_init()
def async_init(self):
data_id = os.getenv("DIFY_ENV_NACOS_DATA_ID", "dify-api-env.properties")
group = os.getenv("DIFY_ENV_NACOS_GROUP", "nacos-dify")
tenant = os.getenv("DIFY_ENV_NACOS_NAMESPACE", "")
params = {"dataId": data_id, "group": group, "tenant": tenant}
try:
content = NacosHttpClient().http_request("/nacos/v1/cs/configs", method="GET", headers={}, params=params)
self.remote_configs = self._parse_config(content)
except Exception as e:
logger.exception("[get-access-token] exception occurred")
raise
def _parse_config(self, content: str) -> dict:
if not content:
return {}
try:
return _parse_config(self, content)
except Exception as e:
raise RuntimeError(f"Failed to parse config: {e}")
def get_field_value(self, field: FieldInfo, field_name: str) -> tuple[Any, str, bool]:
if not isinstance(self.remote_configs, dict):
raise ValueError(f"remote configs is not dict, but {type(self.remote_configs)}")
field_value = self.remote_configs.get(field_name)
if field_value is None:
return None, field_name, False
return field_value, field_name, False

View File

@@ -0,0 +1,83 @@
import base64
import hashlib
import hmac
import logging
import os
import time
import requests
logger = logging.getLogger(__name__)
class NacosHttpClient:
def __init__(self):
self.username = os.getenv("DIFY_ENV_NACOS_USERNAME")
self.password = os.getenv("DIFY_ENV_NACOS_PASSWORD")
self.ak = os.getenv("DIFY_ENV_NACOS_ACCESS_KEY")
self.sk = os.getenv("DIFY_ENV_NACOS_SECRET_KEY")
self.server = os.getenv("DIFY_ENV_NACOS_SERVER_ADDR", "localhost:8848")
self.token = None
self.token_ttl = 18000
self.token_expire_time: float = 0
def http_request(self, url, method="GET", headers=None, params=None):
try:
self._inject_auth_info(headers, params)
response = requests.request(method, url="http://" + self.server + url, headers=headers, params=params)
response.raise_for_status()
return response.text
except requests.exceptions.RequestException as e:
return f"Request to Nacos failed: {e}"
def _inject_auth_info(self, headers, params, module="config"):
headers.update({"User-Agent": "Nacos-Http-Client-In-Dify:v0.0.1"})
if module == "login":
return
ts = str(int(time.time() * 1000))
if self.ak and self.sk:
sign_str = self.get_sign_str(params["group"], params["tenant"], ts)
headers["Spas-AccessKey"] = self.ak
headers["Spas-Signature"] = self.__do_sign(sign_str, self.sk)
headers["timeStamp"] = ts
if self.username and self.password:
self.get_access_token(force_refresh=False)
params["accessToken"] = self.token
def __do_sign(self, sign_str, sk):
return (
base64.encodebytes(hmac.new(sk.encode(), sign_str.encode(), digestmod=hashlib.sha1).digest())
.decode()
.strip()
)
def get_sign_str(self, group, tenant, ts):
sign_str = ""
if tenant:
sign_str = tenant + "+"
if group:
sign_str = sign_str + group + "+"
if sign_str:
sign_str += ts
return sign_str
def get_access_token(self, force_refresh=False):
current_time = time.time()
if self.token and not force_refresh and self.token_expire_time > current_time:
return self.token
params = {"username": self.username, "password": self.password}
url = "http://" + self.server + "/nacos/v1/auth/login"
try:
resp = requests.request("POST", url, headers=None, params=params)
resp.raise_for_status()
response_data = resp.json()
self.token = response_data.get("accessToken")
self.token_ttl = response_data.get("tokenTtl", 18000)
self.token_expire_time = current_time + self.token_ttl - 10
except Exception as e:
logger.exception("[get-access-token] exception occur")
raise

View File

@@ -0,0 +1,31 @@
def _parse_config(self, content: str) -> dict[str, str]:
config: dict[str, str] = {}
if not content:
return config
for line in content.splitlines():
cleaned_line = line.strip()
if not cleaned_line or cleaned_line.startswith(("#", "!")):
continue
separator_index = -1
for i, c in enumerate(cleaned_line):
if c in ("=", ":") and (i == 0 or cleaned_line[i - 1] != "\\"):
separator_index = i
break
if separator_index == -1:
continue
key = cleaned_line[:separator_index].strip()
raw_value = cleaned_line[separator_index + 1 :].strip()
try:
decoded_value = bytes(raw_value, "utf-8").decode("unicode_escape")
decoded_value = decoded_value.replace(r"\=", "=").replace(r"\:", ":")
except UnicodeDecodeError:
decoded_value = raw_value
config[key] = decoded_value
return config

View File

@@ -0,0 +1,40 @@
from configs import dify_config
HIDDEN_VALUE = "[__HIDDEN__]"
UUID_NIL = "00000000-0000-0000-0000-000000000000"
DEFAULT_FILE_NUMBER_LIMITS = 3
IMAGE_EXTENSIONS = ["jpg", "jpeg", "png", "webp", "gif", "svg"]
IMAGE_EXTENSIONS.extend([ext.upper() for ext in IMAGE_EXTENSIONS])
VIDEO_EXTENSIONS = ["mp4", "mov", "mpeg", "mpga"]
VIDEO_EXTENSIONS.extend([ext.upper() for ext in VIDEO_EXTENSIONS])
AUDIO_EXTENSIONS = ["mp3", "m4a", "wav", "webm", "amr"]
AUDIO_EXTENSIONS.extend([ext.upper() for ext in AUDIO_EXTENSIONS])
if dify_config.ETL_TYPE == "Unstructured":
DOCUMENT_EXTENSIONS = ["txt", "markdown", "md", "mdx", "pdf", "html", "htm", "xlsx", "xls", "vtt", "properties"]
DOCUMENT_EXTENSIONS.extend(("doc", "docx", "csv", "eml", "msg", "pptx", "xml", "epub"))
if dify_config.UNSTRUCTURED_API_URL:
DOCUMENT_EXTENSIONS.append("ppt")
DOCUMENT_EXTENSIONS.extend([ext.upper() for ext in DOCUMENT_EXTENSIONS])
else:
DOCUMENT_EXTENSIONS = [
"txt",
"markdown",
"md",
"mdx",
"pdf",
"html",
"htm",
"xlsx",
"xls",
"docx",
"csv",
"vtt",
"properties",
]
DOCUMENT_EXTENSIONS.extend([ext.upper() for ext in DOCUMENT_EXTENSIONS])

View File

@@ -0,0 +1,32 @@
language_timezone_mapping = {
"en-US": "America/New_York",
"zh-Hans": "Asia/Shanghai",
"zh-Hant": "Asia/Taipei",
"pt-BR": "America/Sao_Paulo",
"es-ES": "Europe/Madrid",
"fr-FR": "Europe/Paris",
"de-DE": "Europe/Berlin",
"ja-JP": "Asia/Tokyo",
"ko-KR": "Asia/Seoul",
"ru-RU": "Europe/Moscow",
"it-IT": "Europe/Rome",
"uk-UA": "Europe/Kyiv",
"vi-VN": "Asia/Ho_Chi_Minh",
"ro-RO": "Europe/Bucharest",
"pl-PL": "Europe/Warsaw",
"hi-IN": "Asia/Kolkata",
"tr-TR": "Europe/Istanbul",
"fa-IR": "Asia/Tehran",
"sl-SI": "Europe/Ljubljana",
"th-TH": "Asia/Bangkok",
}
languages = list(language_timezone_mapping.keys())
def supported_language(lang):
if lang in languages:
return lang
error = "{lang} is not a valid language.".format(lang=lang)
raise ValueError(error)

View File

@@ -0,0 +1,7 @@
# The two constants below should keep in sync.
# Default content type for files which have no explicit content type.
DEFAULT_MIME_TYPE = "application/octet-stream"
# Default file extension for files which have no explicit content type, should
# correspond to the `DEFAULT_MIME_TYPE` above.
DEFAULT_EXTENSION = ".bin"

View File

@@ -0,0 +1,84 @@
import json
from collections.abc import Mapping
from models.model import AppMode
default_app_templates: Mapping[AppMode, Mapping] = {
# workflow default mode
AppMode.WORKFLOW: {
"app": {
"mode": AppMode.WORKFLOW.value,
"enable_site": True,
"enable_api": True,
}
},
# completion default mode
AppMode.COMPLETION: {
"app": {
"mode": AppMode.COMPLETION.value,
"enable_site": True,
"enable_api": True,
},
"model_config": {
"model": {
"provider": "openai",
"name": "gpt-4o",
"mode": "chat",
"completion_params": {},
},
"user_input_form": json.dumps(
[
{
"paragraph": {
"label": "Query",
"variable": "query",
"required": True,
"default": "",
},
},
]
),
"pre_prompt": "{{query}}",
},
},
# chat default mode
AppMode.CHAT: {
"app": {
"mode": AppMode.CHAT.value,
"enable_site": True,
"enable_api": True,
},
"model_config": {
"model": {
"provider": "openai",
"name": "gpt-4o",
"mode": "chat",
"completion_params": {},
},
},
},
# advanced-chat default mode
AppMode.ADVANCED_CHAT: {
"app": {
"mode": AppMode.ADVANCED_CHAT.value,
"enable_site": True,
"enable_api": True,
},
},
# agent-chat default mode
AppMode.AGENT_CHAT: {
"app": {
"mode": AppMode.AGENT_CHAT.value,
"enable_site": True,
"enable_api": True,
},
"model_config": {
"model": {
"provider": "openai",
"name": "gpt-4o",
"mode": "chat",
"completion_params": {},
},
},
},
}

File diff suppressed because one or more lines are too long

View File

@@ -0,0 +1,4 @@
TTS_AUTO_PLAY_TIMEOUT = 5
# sleep 20 ms ( 40ms => 1280 byte audio file,20ms => 640 byte audio file)
TTS_AUTO_PLAY_YIELD_CPU_TIME = 0.02

View File

@@ -0,0 +1,39 @@
from contextvars import ContextVar
from threading import Lock
from typing import TYPE_CHECKING
from contexts.wrapper import RecyclableContextVar
if TYPE_CHECKING:
from core.model_runtime.entities.model_entities import AIModelEntity
from core.plugin.entities.plugin_daemon import PluginModelProviderEntity
from core.tools.plugin_tool.provider import PluginToolProviderController
from core.workflow.entities.variable_pool import VariablePool
tenant_id: ContextVar[str] = ContextVar("tenant_id")
workflow_variable_pool: ContextVar["VariablePool"] = ContextVar("workflow_variable_pool")
"""
To avoid race-conditions caused by gunicorn thread recycling, using RecyclableContextVar to replace with
"""
plugin_tool_providers: RecyclableContextVar[dict[str, "PluginToolProviderController"]] = RecyclableContextVar(
ContextVar("plugin_tool_providers")
)
plugin_tool_providers_lock: RecyclableContextVar[Lock] = RecyclableContextVar(ContextVar("plugin_tool_providers_lock"))
plugin_model_providers: RecyclableContextVar[list["PluginModelProviderEntity"] | None] = RecyclableContextVar(
ContextVar("plugin_model_providers")
)
plugin_model_providers_lock: RecyclableContextVar[Lock] = RecyclableContextVar(
ContextVar("plugin_model_providers_lock")
)
plugin_model_schema_lock: RecyclableContextVar[Lock] = RecyclableContextVar(ContextVar("plugin_model_schema_lock"))
plugin_model_schemas: RecyclableContextVar[dict[str, "AIModelEntity"]] = RecyclableContextVar(
ContextVar("plugin_model_schemas")
)

View File

@@ -0,0 +1,65 @@
from contextvars import ContextVar
from typing import Generic, TypeVar
T = TypeVar("T")
class HiddenValue:
pass
_default = HiddenValue()
class RecyclableContextVar(Generic[T]):
"""
RecyclableContextVar is a wrapper around ContextVar
It's safe to use in gunicorn with thread recycling, but features like `reset` are not available for now
NOTE: you need to call `increment_thread_recycles` before requests
"""
_thread_recycles: ContextVar[int] = ContextVar("thread_recycles")
@classmethod
def increment_thread_recycles(cls):
try:
recycles = cls._thread_recycles.get()
cls._thread_recycles.set(recycles + 1)
except LookupError:
cls._thread_recycles.set(0)
def __init__(self, context_var: ContextVar[T]):
self._context_var = context_var
self._updates = ContextVar[int](context_var.name + "_updates", default=0)
def get(self, default: T | HiddenValue = _default) -> T:
thread_recycles = self._thread_recycles.get(0)
self_updates = self._updates.get()
if thread_recycles > self_updates:
self._updates.set(thread_recycles)
# check if thread is recycled and should be updated
if thread_recycles < self_updates:
return self._context_var.get()
else:
# thread_recycles >= self_updates, means current context is invalid
if isinstance(default, HiddenValue) or default is _default:
raise LookupError
else:
return default
def set(self, value: T):
# it leads to a situation that self.updates is less than cls.thread_recycles if `set` was never called before
# increase it manually
thread_recycles = self._thread_recycles.get(0)
self_updates = self._updates.get()
if thread_recycles > self_updates:
self._updates.set(thread_recycles)
if self._updates.get() == self._thread_recycles.get(0):
# after increment,
self._updates.set(self._updates.get() + 1)
# set the context
self._context_var.set(value)

View File

@@ -0,0 +1 @@

View File

@@ -0,0 +1,11 @@
from werkzeug.exceptions import HTTPException
class FilenameNotExistsError(HTTPException):
code = 400
description = "The specified filename does not exist."
class RemoteFileUploadError(HTTPException):
code = 400
description = "Error uploading remote file."

View File

@@ -0,0 +1,43 @@
from flask_restful import fields
from libs.helper import AppIconUrlField
parameters__system_parameters = {
"image_file_size_limit": fields.Integer,
"video_file_size_limit": fields.Integer,
"audio_file_size_limit": fields.Integer,
"file_size_limit": fields.Integer,
"workflow_file_upload_limit": fields.Integer,
}
parameters_fields = {
"opening_statement": fields.String,
"suggested_questions": fields.Raw,
"suggested_questions_after_answer": fields.Raw,
"speech_to_text": fields.Raw,
"text_to_speech": fields.Raw,
"retriever_resource": fields.Raw,
"annotation_reply": fields.Raw,
"more_like_this": fields.Raw,
"user_input_form": fields.Raw,
"sensitive_word_avoidance": fields.Raw,
"file_upload": fields.Raw,
"system_parameters": fields.Nested(parameters__system_parameters),
}
site_fields = {
"title": fields.String,
"chat_color_theme": fields.String,
"chat_color_theme_inverted": fields.Boolean,
"icon_type": fields.String,
"icon": fields.String,
"icon_background": fields.String,
"icon_url": AppIconUrlField,
"description": fields.String,
"copyright": fields.String,
"privacy_policy": fields.String,
"custom_disclaimer": fields.String,
"default_language": fields.String,
"show_workflow_steps": fields.Boolean,
"use_icon_as_answer_icon": fields.Boolean,
}

View File

@@ -0,0 +1,85 @@
import mimetypes
import os
import platform
import re
import urllib.parse
import warnings
from uuid import uuid4
import httpx
try:
import magic
except ImportError:
if platform.system() == "Windows":
warnings.warn(
"To use python-magic guess MIMETYPE, you need to run `pip install python-magic-bin`", stacklevel=2
)
elif platform.system() == "Darwin":
warnings.warn("To use python-magic guess MIMETYPE, you need to run `brew install libmagic`", stacklevel=2)
elif platform.system() == "Linux":
warnings.warn(
"To use python-magic guess MIMETYPE, you need to run `sudo apt-get install libmagic1`", stacklevel=2
)
else:
warnings.warn("To use python-magic guess MIMETYPE, you need to install `libmagic`", stacklevel=2)
magic = None # type: ignore
from pydantic import BaseModel
class FileInfo(BaseModel):
filename: str
extension: str
mimetype: str
size: int
def guess_file_info_from_response(response: httpx.Response):
url = str(response.url)
# Try to extract filename from URL
parsed_url = urllib.parse.urlparse(url)
url_path = parsed_url.path
filename = os.path.basename(url_path)
# If filename couldn't be extracted, use Content-Disposition header
if not filename:
content_disposition = response.headers.get("Content-Disposition")
if content_disposition:
filename_match = re.search(r'filename="?(.+)"?', content_disposition)
if filename_match:
filename = filename_match.group(1)
# If still no filename, generate a unique one
if not filename:
unique_name = str(uuid4())
filename = f"{unique_name}"
# Guess MIME type from filename first, then URL
mimetype, _ = mimetypes.guess_type(filename)
if mimetype is None:
mimetype, _ = mimetypes.guess_type(url)
if mimetype is None:
# If guessing fails, use Content-Type from response headers
mimetype = response.headers.get("Content-Type", "application/octet-stream")
# Use python-magic to guess MIME type if still unknown or generic
if mimetype == "application/octet-stream" and magic is not None:
try:
mimetype = magic.from_buffer(response.content[:1024], mime=True)
except magic.MagicException:
pass
extension = os.path.splitext(filename)[1]
# Ensure filename has an extension
if not extension:
extension = mimetypes.guess_extension(mimetype) or ".bin"
filename = f"{filename}{extension}"
return FileInfo(
filename=filename,
extension=extension,
mimetype=mimetype,
size=int(response.headers.get("Content-Length", -1)),
)

View File

@@ -0,0 +1,182 @@
from flask import Blueprint
from libs.external_api import ExternalApi
from .app.app_import import AppImportApi, AppImportCheckDependenciesApi, AppImportConfirmApi
from .explore.audio import ChatAudioApi, ChatTextApi
from .explore.completion import ChatApi, ChatStopApi, CompletionApi, CompletionStopApi
from .explore.conversation import (
ConversationApi,
ConversationListApi,
ConversationPinApi,
ConversationRenameApi,
ConversationUnPinApi,
)
from .explore.message import (
MessageFeedbackApi,
MessageListApi,
MessageMoreLikeThisApi,
MessageSuggestedQuestionApi,
)
from .explore.workflow import (
InstalledAppWorkflowRunApi,
InstalledAppWorkflowTaskStopApi,
)
from .files import FileApi, FilePreviewApi, FileSupportTypeApi
from .remote_files import RemoteFileInfoApi, RemoteFileUploadApi
bp = Blueprint("console", __name__, url_prefix="/console/api")
api = ExternalApi(bp)
# File
api.add_resource(FileApi, "/files/upload")
api.add_resource(FilePreviewApi, "/files/<uuid:file_id>/preview")
api.add_resource(FileSupportTypeApi, "/files/support-type")
# Remote files
api.add_resource(RemoteFileInfoApi, "/remote-files/<path:url>")
api.add_resource(RemoteFileUploadApi, "/remote-files/upload")
# Import App
api.add_resource(AppImportApi, "/apps/imports")
api.add_resource(AppImportConfirmApi, "/apps/imports/<string:import_id>/confirm")
api.add_resource(AppImportCheckDependenciesApi, "/apps/imports/<string:app_id>/check-dependencies")
# Import other controllers
from . import admin, apikey, extension, feature, ping, setup, version
# Import app controllers
from .app import (
advanced_prompt_template,
agent,
annotation,
app,
audio,
completion,
conversation,
conversation_variables,
generator,
message,
model_config,
ops_trace,
site,
statistic,
workflow,
workflow_app_log,
workflow_run,
workflow_statistic,
)
# Import auth controllers
from .auth import activate, data_source_bearer_auth, data_source_oauth, forgot_password, login, oauth
# Import billing controllers
from .billing import billing, compliance
# Import datasets controllers
from .datasets import (
data_source,
datasets,
datasets_document,
datasets_segments,
external,
hit_testing,
metadata,
website,
)
# Import explore controllers
from .explore import (
installed_app,
parameter,
recommended_app,
saved_message,
)
# Explore Audio
api.add_resource(ChatAudioApi, "/installed-apps/<uuid:installed_app_id>/audio-to-text", endpoint="installed_app_audio")
api.add_resource(ChatTextApi, "/installed-apps/<uuid:installed_app_id>/text-to-audio", endpoint="installed_app_text")
# Explore Completion
api.add_resource(
CompletionApi, "/installed-apps/<uuid:installed_app_id>/completion-messages", endpoint="installed_app_completion"
)
api.add_resource(
CompletionStopApi,
"/installed-apps/<uuid:installed_app_id>/completion-messages/<string:task_id>/stop",
endpoint="installed_app_stop_completion",
)
api.add_resource(
ChatApi, "/installed-apps/<uuid:installed_app_id>/chat-messages", endpoint="installed_app_chat_completion"
)
api.add_resource(
ChatStopApi,
"/installed-apps/<uuid:installed_app_id>/chat-messages/<string:task_id>/stop",
endpoint="installed_app_stop_chat_completion",
)
# Explore Conversation
api.add_resource(
ConversationRenameApi,
"/installed-apps/<uuid:installed_app_id>/conversations/<uuid:c_id>/name",
endpoint="installed_app_conversation_rename",
)
api.add_resource(
ConversationListApi, "/installed-apps/<uuid:installed_app_id>/conversations", endpoint="installed_app_conversations"
)
api.add_resource(
ConversationApi,
"/installed-apps/<uuid:installed_app_id>/conversations/<uuid:c_id>",
endpoint="installed_app_conversation",
)
api.add_resource(
ConversationPinApi,
"/installed-apps/<uuid:installed_app_id>/conversations/<uuid:c_id>/pin",
endpoint="installed_app_conversation_pin",
)
api.add_resource(
ConversationUnPinApi,
"/installed-apps/<uuid:installed_app_id>/conversations/<uuid:c_id>/unpin",
endpoint="installed_app_conversation_unpin",
)
# Explore Message
api.add_resource(MessageListApi, "/installed-apps/<uuid:installed_app_id>/messages", endpoint="installed_app_messages")
api.add_resource(
MessageFeedbackApi,
"/installed-apps/<uuid:installed_app_id>/messages/<uuid:message_id>/feedbacks",
endpoint="installed_app_message_feedback",
)
api.add_resource(
MessageMoreLikeThisApi,
"/installed-apps/<uuid:installed_app_id>/messages/<uuid:message_id>/more-like-this",
endpoint="installed_app_more_like_this",
)
api.add_resource(
MessageSuggestedQuestionApi,
"/installed-apps/<uuid:installed_app_id>/messages/<uuid:message_id>/suggested-questions",
endpoint="installed_app_suggested_question",
)
# Explore Workflow
api.add_resource(InstalledAppWorkflowRunApi, "/installed-apps/<uuid:installed_app_id>/workflows/run")
api.add_resource(
InstalledAppWorkflowTaskStopApi, "/installed-apps/<uuid:installed_app_id>/workflows/tasks/<string:task_id>/stop"
)
# Import tag controllers
from .tag import tags
# Import workspace controllers
from .workspace import (
account,
agent_providers,
endpoint,
load_balancing_config,
members,
model_providers,
models,
plugin,
tool_providers,
workspace,
)

View File

@@ -0,0 +1,151 @@
from functools import wraps
from flask import request
from flask_restful import Resource, reqparse
from sqlalchemy import select
from sqlalchemy.orm import Session
from werkzeug.exceptions import NotFound, Unauthorized
from configs import dify_config
from constants.languages import supported_language
from controllers.console import api
from controllers.console.wraps import only_edition_cloud
from extensions.ext_database import db
from models.model import App, InstalledApp, RecommendedApp
def admin_required(view):
@wraps(view)
def decorated(*args, **kwargs):
if not dify_config.ADMIN_API_KEY:
raise Unauthorized("API key is invalid.")
auth_header = request.headers.get("Authorization")
if auth_header is None:
raise Unauthorized("Authorization header is missing.")
if " " not in auth_header:
raise Unauthorized("Invalid Authorization header format. Expected 'Bearer <api-key>' format.")
auth_scheme, auth_token = auth_header.split(None, 1)
auth_scheme = auth_scheme.lower()
if auth_scheme != "bearer":
raise Unauthorized("Invalid Authorization header format. Expected 'Bearer <api-key>' format.")
if auth_token != dify_config.ADMIN_API_KEY:
raise Unauthorized("API key is invalid.")
return view(*args, **kwargs)
return decorated
class InsertExploreAppListApi(Resource):
@only_edition_cloud
@admin_required
def post(self):
parser = reqparse.RequestParser()
parser.add_argument("app_id", type=str, required=True, nullable=False, location="json")
parser.add_argument("desc", type=str, location="json")
parser.add_argument("copyright", type=str, location="json")
parser.add_argument("privacy_policy", type=str, location="json")
parser.add_argument("custom_disclaimer", type=str, location="json")
parser.add_argument("language", type=supported_language, required=True, nullable=False, location="json")
parser.add_argument("category", type=str, required=True, nullable=False, location="json")
parser.add_argument("position", type=int, required=True, nullable=False, location="json")
args = parser.parse_args()
with Session(db.engine) as session:
app = session.execute(select(App).filter(App.id == args["app_id"])).scalar_one_or_none()
if not app:
raise NotFound(f"App '{args['app_id']}' is not found")
site = app.site
if not site:
desc = args["desc"] or ""
copy_right = args["copyright"] or ""
privacy_policy = args["privacy_policy"] or ""
custom_disclaimer = args["custom_disclaimer"] or ""
else:
desc = site.description or args["desc"] or ""
copy_right = site.copyright or args["copyright"] or ""
privacy_policy = site.privacy_policy or args["privacy_policy"] or ""
custom_disclaimer = site.custom_disclaimer or args["custom_disclaimer"] or ""
with Session(db.engine) as session:
recommended_app = session.execute(
select(RecommendedApp).filter(RecommendedApp.app_id == args["app_id"])
).scalar_one_or_none()
if not recommended_app:
recommended_app = RecommendedApp(
app_id=app.id,
description=desc,
copyright=copy_right,
privacy_policy=privacy_policy,
custom_disclaimer=custom_disclaimer,
language=args["language"],
category=args["category"],
position=args["position"],
)
db.session.add(recommended_app)
app.is_public = True
db.session.commit()
return {"result": "success"}, 201
else:
recommended_app.description = desc
recommended_app.copyright = copy_right
recommended_app.privacy_policy = privacy_policy
recommended_app.custom_disclaimer = custom_disclaimer
recommended_app.language = args["language"]
recommended_app.category = args["category"]
recommended_app.position = args["position"]
app.is_public = True
db.session.commit()
return {"result": "success"}, 200
class InsertExploreAppApi(Resource):
@only_edition_cloud
@admin_required
def delete(self, app_id):
with Session(db.engine) as session:
recommended_app = session.execute(
select(RecommendedApp).filter(RecommendedApp.app_id == str(app_id))
).scalar_one_or_none()
if not recommended_app:
return {"result": "success"}, 204
with Session(db.engine) as session:
app = session.execute(select(App).filter(App.id == recommended_app.app_id)).scalar_one_or_none()
if app:
app.is_public = False
with Session(db.engine) as session:
installed_apps = session.execute(
select(InstalledApp).filter(
InstalledApp.app_id == recommended_app.app_id,
InstalledApp.tenant_id != InstalledApp.app_owner_tenant_id,
)
).all()
for installed_app in installed_apps:
db.session.delete(installed_app)
db.session.delete(recommended_app)
db.session.commit()
return {"result": "success"}, 204
api.add_resource(InsertExploreAppListApi, "/admin/insert-explore-apps")
api.add_resource(InsertExploreAppApi, "/admin/insert-explore-apps/<uuid:app_id>")

View File

@@ -0,0 +1,186 @@
from typing import Any
import flask_restful
from flask_login import current_user
from flask_restful import Resource, fields, marshal_with
from sqlalchemy import select
from sqlalchemy.orm import Session
from werkzeug.exceptions import Forbidden
from extensions.ext_database import db
from libs.helper import TimestampField
from libs.login import login_required
from models.dataset import Dataset
from models.model import ApiToken, App
from . import api
from .wraps import account_initialization_required, setup_required
api_key_fields = {
"id": fields.String,
"type": fields.String,
"token": fields.String,
"last_used_at": TimestampField,
"created_at": TimestampField,
}
api_key_list = {"data": fields.List(fields.Nested(api_key_fields), attribute="items")}
def _get_resource(resource_id, tenant_id, resource_model):
if resource_model == App:
with Session(db.engine) as session:
resource = session.execute(
select(resource_model).filter_by(id=resource_id, tenant_id=tenant_id)
).scalar_one_or_none()
else:
with Session(db.engine) as session:
resource = session.execute(
select(resource_model).filter_by(id=resource_id, tenant_id=tenant_id)
).scalar_one_or_none()
if resource is None:
flask_restful.abort(404, message=f"{resource_model.__name__} not found.")
return resource
class BaseApiKeyListResource(Resource):
method_decorators = [account_initialization_required, login_required, setup_required]
resource_type: str | None = None
resource_model: Any = None
resource_id_field: str | None = None
token_prefix: str | None = None
max_keys = 10
@marshal_with(api_key_list)
def get(self, resource_id):
assert self.resource_id_field is not None, "resource_id_field must be set"
resource_id = str(resource_id)
_get_resource(resource_id, current_user.current_tenant_id, self.resource_model)
keys = (
db.session.query(ApiToken)
.filter(ApiToken.type == self.resource_type, getattr(ApiToken, self.resource_id_field) == resource_id)
.all()
)
return {"items": keys}
@marshal_with(api_key_fields)
def post(self, resource_id):
assert self.resource_id_field is not None, "resource_id_field must be set"
resource_id = str(resource_id)
_get_resource(resource_id, current_user.current_tenant_id, self.resource_model)
if not current_user.is_editor:
raise Forbidden()
current_key_count = (
db.session.query(ApiToken)
.filter(ApiToken.type == self.resource_type, getattr(ApiToken, self.resource_id_field) == resource_id)
.count()
)
if current_key_count >= self.max_keys:
flask_restful.abort(
400,
message=f"Cannot create more than {self.max_keys} API keys for this resource type.",
code="max_keys_exceeded",
)
key = ApiToken.generate_api_key(self.token_prefix, 24)
api_token = ApiToken()
setattr(api_token, self.resource_id_field, resource_id)
api_token.tenant_id = current_user.current_tenant_id
api_token.token = key
api_token.type = self.resource_type
db.session.add(api_token)
db.session.commit()
return api_token, 201
class BaseApiKeyResource(Resource):
method_decorators = [account_initialization_required, login_required, setup_required]
resource_type: str | None = None
resource_model: Any = None
resource_id_field: str | None = None
def delete(self, resource_id, api_key_id):
assert self.resource_id_field is not None, "resource_id_field must be set"
resource_id = str(resource_id)
api_key_id = str(api_key_id)
_get_resource(resource_id, current_user.current_tenant_id, self.resource_model)
# The role of the current user in the ta table must be admin or owner
if not current_user.is_admin_or_owner:
raise Forbidden()
key = (
db.session.query(ApiToken)
.filter(
getattr(ApiToken, self.resource_id_field) == resource_id,
ApiToken.type == self.resource_type,
ApiToken.id == api_key_id,
)
.first()
)
if key is None:
flask_restful.abort(404, message="API key not found")
db.session.query(ApiToken).filter(ApiToken.id == api_key_id).delete()
db.session.commit()
return {"result": "success"}, 204
class AppApiKeyListResource(BaseApiKeyListResource):
def after_request(self, resp):
resp.headers["Access-Control-Allow-Origin"] = "*"
resp.headers["Access-Control-Allow-Credentials"] = "true"
return resp
resource_type = "app"
resource_model = App
resource_id_field = "app_id"
token_prefix = "app-"
class AppApiKeyResource(BaseApiKeyResource):
def after_request(self, resp):
resp.headers["Access-Control-Allow-Origin"] = "*"
resp.headers["Access-Control-Allow-Credentials"] = "true"
return resp
resource_type = "app"
resource_model = App
resource_id_field = "app_id"
class DatasetApiKeyListResource(BaseApiKeyListResource):
def after_request(self, resp):
resp.headers["Access-Control-Allow-Origin"] = "*"
resp.headers["Access-Control-Allow-Credentials"] = "true"
return resp
resource_type = "dataset"
resource_model = Dataset
resource_id_field = "dataset_id"
token_prefix = "ds-"
class DatasetApiKeyResource(BaseApiKeyResource):
def after_request(self, resp):
resp.headers["Access-Control-Allow-Origin"] = "*"
resp.headers["Access-Control-Allow-Credentials"] = "true"
return resp
resource_type = "dataset"
resource_model = Dataset
resource_id_field = "dataset_id"
api.add_resource(AppApiKeyListResource, "/apps/<uuid:resource_id>/api-keys")
api.add_resource(AppApiKeyResource, "/apps/<uuid:resource_id>/api-keys/<uuid:api_key_id>")
api.add_resource(DatasetApiKeyListResource, "/datasets/<uuid:resource_id>/api-keys")
api.add_resource(DatasetApiKeyResource, "/datasets/<uuid:resource_id>/api-keys/<uuid:api_key_id>")

View File

@@ -0,0 +1,24 @@
from flask_restful import Resource, reqparse
from controllers.console import api
from controllers.console.wraps import account_initialization_required, setup_required
from libs.login import login_required
from services.advanced_prompt_template_service import AdvancedPromptTemplateService
class AdvancedPromptTemplateList(Resource):
@setup_required
@login_required
@account_initialization_required
def get(self):
parser = reqparse.RequestParser()
parser.add_argument("app_mode", type=str, required=True, location="args")
parser.add_argument("model_mode", type=str, required=True, location="args")
parser.add_argument("has_context", type=str, required=False, default="true", location="args")
parser.add_argument("model_name", type=str, required=True, location="args")
args = parser.parse_args()
return AdvancedPromptTemplateService.get_prompt(args)
api.add_resource(AdvancedPromptTemplateList, "/app/prompt-templates")

View File

@@ -0,0 +1,28 @@
from flask_restful import Resource, reqparse
from controllers.console import api
from controllers.console.app.wraps import get_app_model
from controllers.console.wraps import account_initialization_required, setup_required
from libs.helper import uuid_value
from libs.login import login_required
from models.model import AppMode
from services.agent_service import AgentService
class AgentLogApi(Resource):
@setup_required
@login_required
@account_initialization_required
@get_app_model(mode=[AppMode.AGENT_CHAT])
def get(self, app_model):
"""Get agent logs"""
parser = reqparse.RequestParser()
parser.add_argument("message_id", type=uuid_value, required=True, location="args")
parser.add_argument("conversation_id", type=uuid_value, required=True, location="args")
args = parser.parse_args()
return AgentService.get_agent_logs(app_model, args["conversation_id"], args["message_id"])
api.add_resource(AgentLogApi, "/apps/<uuid:app_id>/agent/logs")

View File

@@ -0,0 +1,275 @@
from flask import request
from flask_login import current_user
from flask_restful import Resource, marshal, marshal_with, reqparse
from werkzeug.exceptions import Forbidden
from controllers.console import api
from controllers.console.app.error import NoFileUploadedError
from controllers.console.datasets.error import TooManyFilesError
from controllers.console.wraps import (
account_initialization_required,
cloud_edition_billing_resource_check,
setup_required,
)
from extensions.ext_redis import redis_client
from fields.annotation_fields import (
annotation_fields,
annotation_hit_history_fields,
)
from libs.login import login_required
from services.annotation_service import AppAnnotationService
class AnnotationReplyActionApi(Resource):
@setup_required
@login_required
@account_initialization_required
@cloud_edition_billing_resource_check("annotation")
def post(self, app_id, action):
if not current_user.is_editor:
raise Forbidden()
app_id = str(app_id)
parser = reqparse.RequestParser()
parser.add_argument("score_threshold", required=True, type=float, location="json")
parser.add_argument("embedding_provider_name", required=True, type=str, location="json")
parser.add_argument("embedding_model_name", required=True, type=str, location="json")
args = parser.parse_args()
if action == "enable":
result = AppAnnotationService.enable_app_annotation(args, app_id)
elif action == "disable":
result = AppAnnotationService.disable_app_annotation(app_id)
else:
raise ValueError("Unsupported annotation reply action")
return result, 200
class AppAnnotationSettingDetailApi(Resource):
@setup_required
@login_required
@account_initialization_required
def get(self, app_id):
if not current_user.is_editor:
raise Forbidden()
app_id = str(app_id)
result = AppAnnotationService.get_app_annotation_setting_by_app_id(app_id)
return result, 200
class AppAnnotationSettingUpdateApi(Resource):
@setup_required
@login_required
@account_initialization_required
def post(self, app_id, annotation_setting_id):
if not current_user.is_editor:
raise Forbidden()
app_id = str(app_id)
annotation_setting_id = str(annotation_setting_id)
parser = reqparse.RequestParser()
parser.add_argument("score_threshold", required=True, type=float, location="json")
args = parser.parse_args()
result = AppAnnotationService.update_app_annotation_setting(app_id, annotation_setting_id, args)
return result, 200
class AnnotationReplyActionStatusApi(Resource):
@setup_required
@login_required
@account_initialization_required
@cloud_edition_billing_resource_check("annotation")
def get(self, app_id, job_id, action):
if not current_user.is_editor:
raise Forbidden()
job_id = str(job_id)
app_annotation_job_key = "{}_app_annotation_job_{}".format(action, str(job_id))
cache_result = redis_client.get(app_annotation_job_key)
if cache_result is None:
raise ValueError("The job does not exist.")
job_status = cache_result.decode()
error_msg = ""
if job_status == "error":
app_annotation_error_key = "{}_app_annotation_error_{}".format(action, str(job_id))
error_msg = redis_client.get(app_annotation_error_key).decode()
return {"job_id": job_id, "job_status": job_status, "error_msg": error_msg}, 200
class AnnotationListApi(Resource):
@setup_required
@login_required
@account_initialization_required
def get(self, app_id):
if not current_user.is_editor:
raise Forbidden()
page = request.args.get("page", default=1, type=int)
limit = request.args.get("limit", default=20, type=int)
keyword = request.args.get("keyword", default="", type=str)
app_id = str(app_id)
annotation_list, total = AppAnnotationService.get_annotation_list_by_app_id(app_id, page, limit, keyword)
response = {
"data": marshal(annotation_list, annotation_fields),
"has_more": len(annotation_list) == limit,
"limit": limit,
"total": total,
"page": page,
}
return response, 200
class AnnotationExportApi(Resource):
@setup_required
@login_required
@account_initialization_required
def get(self, app_id):
if not current_user.is_editor:
raise Forbidden()
app_id = str(app_id)
annotation_list = AppAnnotationService.export_annotation_list_by_app_id(app_id)
response = {"data": marshal(annotation_list, annotation_fields)}
return response, 200
class AnnotationCreateApi(Resource):
@setup_required
@login_required
@account_initialization_required
@cloud_edition_billing_resource_check("annotation")
@marshal_with(annotation_fields)
def post(self, app_id):
if not current_user.is_editor:
raise Forbidden()
app_id = str(app_id)
parser = reqparse.RequestParser()
parser.add_argument("question", required=True, type=str, location="json")
parser.add_argument("answer", required=True, type=str, location="json")
args = parser.parse_args()
annotation = AppAnnotationService.insert_app_annotation_directly(args, app_id)
return annotation
class AnnotationUpdateDeleteApi(Resource):
@setup_required
@login_required
@account_initialization_required
@cloud_edition_billing_resource_check("annotation")
@marshal_with(annotation_fields)
def post(self, app_id, annotation_id):
if not current_user.is_editor:
raise Forbidden()
app_id = str(app_id)
annotation_id = str(annotation_id)
parser = reqparse.RequestParser()
parser.add_argument("question", required=True, type=str, location="json")
parser.add_argument("answer", required=True, type=str, location="json")
args = parser.parse_args()
annotation = AppAnnotationService.update_app_annotation_directly(args, app_id, annotation_id)
return annotation
@setup_required
@login_required
@account_initialization_required
def delete(self, app_id, annotation_id):
if not current_user.is_editor:
raise Forbidden()
app_id = str(app_id)
annotation_id = str(annotation_id)
AppAnnotationService.delete_app_annotation(app_id, annotation_id)
return {"result": "success"}, 204
class AnnotationBatchImportApi(Resource):
@setup_required
@login_required
@account_initialization_required
@cloud_edition_billing_resource_check("annotation")
def post(self, app_id):
if not current_user.is_editor:
raise Forbidden()
app_id = str(app_id)
# get file from request
file = request.files["file"]
# check file
if "file" not in request.files:
raise NoFileUploadedError()
if len(request.files) > 1:
raise TooManyFilesError()
# check file type
if not file.filename.endswith(".csv"):
raise ValueError("Invalid file type. Only CSV files are allowed")
return AppAnnotationService.batch_import_app_annotations(app_id, file)
class AnnotationBatchImportStatusApi(Resource):
@setup_required
@login_required
@account_initialization_required
@cloud_edition_billing_resource_check("annotation")
def get(self, app_id, job_id):
if not current_user.is_editor:
raise Forbidden()
job_id = str(job_id)
indexing_cache_key = "app_annotation_batch_import_{}".format(str(job_id))
cache_result = redis_client.get(indexing_cache_key)
if cache_result is None:
raise ValueError("The job does not exist.")
job_status = cache_result.decode()
error_msg = ""
if job_status == "error":
indexing_error_msg_key = "app_annotation_batch_import_error_msg_{}".format(str(job_id))
error_msg = redis_client.get(indexing_error_msg_key).decode()
return {"job_id": job_id, "job_status": job_status, "error_msg": error_msg}, 200
class AnnotationHitHistoryListApi(Resource):
@setup_required
@login_required
@account_initialization_required
def get(self, app_id, annotation_id):
if not current_user.is_editor:
raise Forbidden()
page = request.args.get("page", default=1, type=int)
limit = request.args.get("limit", default=20, type=int)
app_id = str(app_id)
annotation_id = str(annotation_id)
annotation_hit_history_list, total = AppAnnotationService.get_annotation_hit_histories(
app_id, annotation_id, page, limit
)
response = {
"data": marshal(annotation_hit_history_list, annotation_hit_history_fields),
"has_more": len(annotation_hit_history_list) == limit,
"limit": limit,
"total": total,
"page": page,
}
return response
api.add_resource(AnnotationReplyActionApi, "/apps/<uuid:app_id>/annotation-reply/<string:action>")
api.add_resource(
AnnotationReplyActionStatusApi, "/apps/<uuid:app_id>/annotation-reply/<string:action>/status/<uuid:job_id>"
)
api.add_resource(AnnotationListApi, "/apps/<uuid:app_id>/annotations")
api.add_resource(AnnotationExportApi, "/apps/<uuid:app_id>/annotations/export")
api.add_resource(AnnotationUpdateDeleteApi, "/apps/<uuid:app_id>/annotations/<uuid:annotation_id>")
api.add_resource(AnnotationBatchImportApi, "/apps/<uuid:app_id>/annotations/batch-import")
api.add_resource(AnnotationBatchImportStatusApi, "/apps/<uuid:app_id>/annotations/batch-import-status/<uuid:job_id>")
api.add_resource(AnnotationHitHistoryListApi, "/apps/<uuid:app_id>/annotations/<uuid:annotation_id>/hit-histories")
api.add_resource(AppAnnotationSettingDetailApi, "/apps/<uuid:app_id>/annotation-setting")
api.add_resource(AppAnnotationSettingUpdateApi, "/apps/<uuid:app_id>/annotation-settings/<uuid:annotation_setting_id>")

View File

@@ -0,0 +1,350 @@
import uuid
from typing import cast
from flask_login import current_user
from flask_restful import Resource, inputs, marshal, marshal_with, reqparse
from sqlalchemy import select
from sqlalchemy.orm import Session
from werkzeug.exceptions import BadRequest, Forbidden, abort
from controllers.console import api
from controllers.console.app.wraps import get_app_model
from controllers.console.wraps import (
account_initialization_required,
cloud_edition_billing_resource_check,
enterprise_license_required,
setup_required,
)
from core.ops.ops_trace_manager import OpsTraceManager
from extensions.ext_database import db
from fields.app_fields import (
app_detail_fields,
app_detail_fields_with_site,
app_pagination_fields,
)
from libs.login import login_required
from models import Account, App
from services.app_dsl_service import AppDslService, ImportMode
from services.app_service import AppService
ALLOW_CREATE_APP_MODES = ["chat", "agent-chat", "advanced-chat", "workflow", "completion"]
class AppListApi(Resource):
@setup_required
@login_required
@account_initialization_required
@enterprise_license_required
def get(self):
"""Get app list"""
def uuid_list(value):
try:
return [str(uuid.UUID(v)) for v in value.split(",")]
except ValueError:
abort(400, message="Invalid UUID format in tag_ids.")
parser = reqparse.RequestParser()
parser.add_argument("page", type=inputs.int_range(1, 99999), required=False, default=1, location="args")
parser.add_argument("limit", type=inputs.int_range(1, 100), required=False, default=20, location="args")
parser.add_argument(
"mode",
type=str,
choices=[
"completion",
"chat",
"advanced-chat",
"workflow",
"agent-chat",
"channel",
"all",
],
default="all",
location="args",
required=False,
)
parser.add_argument("name", type=str, location="args", required=False)
parser.add_argument("tag_ids", type=uuid_list, location="args", required=False)
parser.add_argument("is_created_by_me", type=inputs.boolean, location="args", required=False)
args = parser.parse_args()
# get app list
app_service = AppService()
app_pagination = app_service.get_paginate_apps(current_user.id, current_user.current_tenant_id, args)
if not app_pagination:
return {"data": [], "total": 0, "page": 1, "limit": 20, "has_more": False}
return marshal(app_pagination, app_pagination_fields)
@setup_required
@login_required
@account_initialization_required
@marshal_with(app_detail_fields)
@cloud_edition_billing_resource_check("apps")
def post(self):
"""Create app"""
parser = reqparse.RequestParser()
parser.add_argument("name", type=str, required=True, location="json")
parser.add_argument("description", type=str, location="json")
parser.add_argument("mode", type=str, choices=ALLOW_CREATE_APP_MODES, location="json")
parser.add_argument("icon_type", type=str, location="json")
parser.add_argument("icon", type=str, location="json")
parser.add_argument("icon_background", type=str, location="json")
args = parser.parse_args()
# The role of the current user in the ta table must be admin, owner, or editor
if not current_user.is_editor:
raise Forbidden()
if "mode" not in args or args["mode"] is None:
raise BadRequest("mode is required")
app_service = AppService()
app = app_service.create_app(current_user.current_tenant_id, args, current_user)
return app, 201
class AppApi(Resource):
@setup_required
@login_required
@account_initialization_required
@enterprise_license_required
@get_app_model
@marshal_with(app_detail_fields_with_site)
def get(self, app_model):
"""Get app detail"""
app_service = AppService()
app_model = app_service.get_app(app_model)
return app_model
@setup_required
@login_required
@account_initialization_required
@get_app_model
@marshal_with(app_detail_fields_with_site)
def put(self, app_model):
"""Update app"""
# The role of the current user in the ta table must be admin, owner, or editor
if not current_user.is_editor:
raise Forbidden()
parser = reqparse.RequestParser()
parser.add_argument("name", type=str, required=True, nullable=False, location="json")
parser.add_argument("description", type=str, location="json")
parser.add_argument("icon_type", type=str, location="json")
parser.add_argument("icon", type=str, location="json")
parser.add_argument("icon_background", type=str, location="json")
parser.add_argument("use_icon_as_answer_icon", type=bool, location="json")
args = parser.parse_args()
app_service = AppService()
app_model = app_service.update_app(app_model, args)
return app_model
@setup_required
@login_required
@account_initialization_required
@get_app_model
def delete(self, app_model):
"""Delete app"""
# The role of the current user in the ta table must be admin, owner, or editor
if not current_user.is_editor:
raise Forbidden()
app_service = AppService()
app_service.delete_app(app_model)
return {"result": "success"}, 204
class AppCopyApi(Resource):
@setup_required
@login_required
@account_initialization_required
@get_app_model
@marshal_with(app_detail_fields_with_site)
def post(self, app_model):
"""Copy app"""
# The role of the current user in the ta table must be admin, owner, or editor
if not current_user.is_editor:
raise Forbidden()
parser = reqparse.RequestParser()
parser.add_argument("name", type=str, location="json")
parser.add_argument("description", type=str, location="json")
parser.add_argument("icon_type", type=str, location="json")
parser.add_argument("icon", type=str, location="json")
parser.add_argument("icon_background", type=str, location="json")
args = parser.parse_args()
with Session(db.engine) as session:
import_service = AppDslService(session)
yaml_content = import_service.export_dsl(app_model=app_model, include_secret=True)
account = cast(Account, current_user)
result = import_service.import_app(
account=account,
import_mode=ImportMode.YAML_CONTENT.value,
yaml_content=yaml_content,
name=args.get("name"),
description=args.get("description"),
icon_type=args.get("icon_type"),
icon=args.get("icon"),
icon_background=args.get("icon_background"),
)
session.commit()
stmt = select(App).where(App.id == result.app_id)
app = session.scalar(stmt)
return app, 201
class AppExportApi(Resource):
@setup_required
@login_required
@account_initialization_required
@get_app_model
def get(self, app_model):
"""Export app"""
# The role of the current user in the ta table must be admin, owner, or editor
if not current_user.is_editor:
raise Forbidden()
# Add include_secret params
parser = reqparse.RequestParser()
parser.add_argument("include_secret", type=inputs.boolean, default=False, location="args")
args = parser.parse_args()
return {"data": AppDslService.export_dsl(app_model=app_model, include_secret=args["include_secret"])}
class AppNameApi(Resource):
@setup_required
@login_required
@account_initialization_required
@get_app_model
@marshal_with(app_detail_fields)
def post(self, app_model):
# The role of the current user in the ta table must be admin, owner, or editor
if not current_user.is_editor:
raise Forbidden()
parser = reqparse.RequestParser()
parser.add_argument("name", type=str, required=True, location="json")
args = parser.parse_args()
app_service = AppService()
app_model = app_service.update_app_name(app_model, args.get("name"))
return app_model
class AppIconApi(Resource):
@setup_required
@login_required
@account_initialization_required
@get_app_model
@marshal_with(app_detail_fields)
def post(self, app_model):
# The role of the current user in the ta table must be admin, owner, or editor
if not current_user.is_editor:
raise Forbidden()
parser = reqparse.RequestParser()
parser.add_argument("icon", type=str, location="json")
parser.add_argument("icon_background", type=str, location="json")
args = parser.parse_args()
app_service = AppService()
app_model = app_service.update_app_icon(app_model, args.get("icon"), args.get("icon_background"))
return app_model
class AppSiteStatus(Resource):
@setup_required
@login_required
@account_initialization_required
@get_app_model
@marshal_with(app_detail_fields)
def post(self, app_model):
# The role of the current user in the ta table must be admin, owner, or editor
if not current_user.is_editor:
raise Forbidden()
parser = reqparse.RequestParser()
parser.add_argument("enable_site", type=bool, required=True, location="json")
args = parser.parse_args()
app_service = AppService()
app_model = app_service.update_app_site_status(app_model, args.get("enable_site"))
return app_model
class AppApiStatus(Resource):
@setup_required
@login_required
@account_initialization_required
@get_app_model
@marshal_with(app_detail_fields)
def post(self, app_model):
# The role of the current user in the ta table must be admin or owner
if not current_user.is_admin_or_owner:
raise Forbidden()
parser = reqparse.RequestParser()
parser.add_argument("enable_api", type=bool, required=True, location="json")
args = parser.parse_args()
app_service = AppService()
app_model = app_service.update_app_api_status(app_model, args.get("enable_api"))
return app_model
class AppTraceApi(Resource):
@setup_required
@login_required
@account_initialization_required
def get(self, app_id):
"""Get app trace"""
app_trace_config = OpsTraceManager.get_app_tracing_config(app_id=app_id)
return app_trace_config
@setup_required
@login_required
@account_initialization_required
def post(self, app_id):
# add app trace
if not current_user.is_editor:
raise Forbidden()
parser = reqparse.RequestParser()
parser.add_argument("enabled", type=bool, required=True, location="json")
parser.add_argument("tracing_provider", type=str, required=True, location="json")
args = parser.parse_args()
OpsTraceManager.update_app_tracing_config(
app_id=app_id,
enabled=args["enabled"],
tracing_provider=args["tracing_provider"],
)
return {"result": "success"}
api.add_resource(AppListApi, "/apps")
api.add_resource(AppApi, "/apps/<uuid:app_id>")
api.add_resource(AppCopyApi, "/apps/<uuid:app_id>/copy")
api.add_resource(AppExportApi, "/apps/<uuid:app_id>/export")
api.add_resource(AppNameApi, "/apps/<uuid:app_id>/name")
api.add_resource(AppIconApi, "/apps/<uuid:app_id>/icon")
api.add_resource(AppSiteStatus, "/apps/<uuid:app_id>/site-enable")
api.add_resource(AppApiStatus, "/apps/<uuid:app_id>/api-enable")
api.add_resource(AppTraceApi, "/apps/<uuid:app_id>/trace")

View File

@@ -0,0 +1,111 @@
from typing import cast
from flask_login import current_user
from flask_restful import Resource, marshal_with, reqparse
from sqlalchemy.orm import Session
from werkzeug.exceptions import Forbidden
from controllers.console.app.wraps import get_app_model
from controllers.console.wraps import (
account_initialization_required,
cloud_edition_billing_resource_check,
setup_required,
)
from extensions.ext_database import db
from fields.app_fields import app_import_check_dependencies_fields, app_import_fields
from libs.login import login_required
from models import Account
from models.model import App
from services.app_dsl_service import AppDslService, ImportStatus
class AppImportApi(Resource):
@setup_required
@login_required
@account_initialization_required
@marshal_with(app_import_fields)
@cloud_edition_billing_resource_check("apps")
def post(self):
# Check user role first
if not current_user.is_editor:
raise Forbidden()
parser = reqparse.RequestParser()
parser.add_argument("mode", type=str, required=True, location="json")
parser.add_argument("yaml_content", type=str, location="json")
parser.add_argument("yaml_url", type=str, location="json")
parser.add_argument("name", type=str, location="json")
parser.add_argument("description", type=str, location="json")
parser.add_argument("icon_type", type=str, location="json")
parser.add_argument("icon", type=str, location="json")
parser.add_argument("icon_background", type=str, location="json")
parser.add_argument("app_id", type=str, location="json")
args = parser.parse_args()
# Create service with session
with Session(db.engine) as session:
import_service = AppDslService(session)
# Import app
account = cast(Account, current_user)
result = import_service.import_app(
account=account,
import_mode=args["mode"],
yaml_content=args.get("yaml_content"),
yaml_url=args.get("yaml_url"),
name=args.get("name"),
description=args.get("description"),
icon_type=args.get("icon_type"),
icon=args.get("icon"),
icon_background=args.get("icon_background"),
app_id=args.get("app_id"),
)
session.commit()
# Return appropriate status code based on result
status = result.status
if status == ImportStatus.FAILED.value:
return result.model_dump(mode="json"), 400
elif status == ImportStatus.PENDING.value:
return result.model_dump(mode="json"), 202
return result.model_dump(mode="json"), 200
class AppImportConfirmApi(Resource):
@setup_required
@login_required
@account_initialization_required
@marshal_with(app_import_fields)
def post(self, import_id):
# Check user role first
if not current_user.is_editor:
raise Forbidden()
# Create service with session
with Session(db.engine) as session:
import_service = AppDslService(session)
# Confirm import
account = cast(Account, current_user)
result = import_service.confirm_import(import_id=import_id, account=account)
session.commit()
# Return appropriate status code based on result
if result.status == ImportStatus.FAILED.value:
return result.model_dump(mode="json"), 400
return result.model_dump(mode="json"), 200
class AppImportCheckDependenciesApi(Resource):
@setup_required
@login_required
@get_app_model
@account_initialization_required
@marshal_with(app_import_check_dependencies_fields)
def get(self, app_model: App):
if not current_user.is_editor:
raise Forbidden()
with Session(db.engine) as session:
import_service = AppDslService(session)
result = import_service.check_dependencies(app_model=app_model)
return result.model_dump(mode="json"), 200

View File

@@ -0,0 +1,181 @@
import logging
from flask import request
from flask_restful import Resource, reqparse
from werkzeug.exceptions import InternalServerError
import services
from controllers.console import api
from controllers.console.app.error import (
AppUnavailableError,
AudioTooLargeError,
CompletionRequestError,
NoAudioUploadedError,
ProviderModelCurrentlyNotSupportError,
ProviderNotInitializeError,
ProviderNotSupportSpeechToTextError,
ProviderQuotaExceededError,
UnsupportedAudioTypeError,
)
from controllers.console.app.wraps import get_app_model
from controllers.console.wraps import account_initialization_required, setup_required
from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
from core.model_runtime.errors.invoke import InvokeError
from libs.login import login_required
from models import App, AppMode
from services.audio_service import AudioService
from services.errors.audio import (
AudioTooLargeServiceError,
NoAudioUploadedServiceError,
ProviderNotSupportSpeechToTextServiceError,
UnsupportedAudioTypeServiceError,
)
class ChatMessageAudioApi(Resource):
@setup_required
@login_required
@account_initialization_required
@get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT])
def post(self, app_model):
file = request.files["file"]
try:
response = AudioService.transcript_asr(
app_model=app_model,
file=file,
end_user=None,
)
return response
except services.errors.app_model_config.AppModelConfigBrokenError:
logging.exception("App model config broken.")
raise AppUnavailableError()
except NoAudioUploadedServiceError:
raise NoAudioUploadedError()
except AudioTooLargeServiceError as e:
raise AudioTooLargeError(str(e))
except UnsupportedAudioTypeServiceError:
raise UnsupportedAudioTypeError()
except ProviderNotSupportSpeechToTextServiceError:
raise ProviderNotSupportSpeechToTextError()
except ProviderTokenNotInitError as ex:
raise ProviderNotInitializeError(ex.description)
except QuotaExceededError:
raise ProviderQuotaExceededError()
except ModelCurrentlyNotSupportError:
raise ProviderModelCurrentlyNotSupportError()
except InvokeError as e:
raise CompletionRequestError(e.description)
except ValueError as e:
raise e
except Exception as e:
logging.exception("Failed to handle post request to ChatMessageAudioApi")
raise InternalServerError()
class ChatMessageTextApi(Resource):
@setup_required
@login_required
@account_initialization_required
@get_app_model
def post(self, app_model: App):
try:
parser = reqparse.RequestParser()
parser.add_argument("message_id", type=str, location="json")
parser.add_argument("text", type=str, location="json")
parser.add_argument("voice", type=str, location="json")
parser.add_argument("streaming", type=bool, location="json")
args = parser.parse_args()
message_id = args.get("message_id", None)
text = args.get("text", None)
if (
app_model.mode in {AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value}
and app_model.workflow
and app_model.workflow.features_dict
):
text_to_speech = app_model.workflow.features_dict.get("text_to_speech")
if text_to_speech is None:
raise ValueError("TTS is not enabled")
voice = args.get("voice") or text_to_speech.get("voice")
else:
try:
if app_model.app_model_config is None:
raise ValueError("AppModelConfig not found")
voice = args.get("voice") or app_model.app_model_config.text_to_speech_dict.get("voice")
except Exception:
voice = None
response = AudioService.transcript_tts(app_model=app_model, text=text, message_id=message_id, voice=voice)
return response
except services.errors.app_model_config.AppModelConfigBrokenError:
logging.exception("App model config broken.")
raise AppUnavailableError()
except NoAudioUploadedServiceError:
raise NoAudioUploadedError()
except AudioTooLargeServiceError as e:
raise AudioTooLargeError(str(e))
except UnsupportedAudioTypeServiceError:
raise UnsupportedAudioTypeError()
except ProviderNotSupportSpeechToTextServiceError:
raise ProviderNotSupportSpeechToTextError()
except ProviderTokenNotInitError as ex:
raise ProviderNotInitializeError(ex.description)
except QuotaExceededError:
raise ProviderQuotaExceededError()
except ModelCurrentlyNotSupportError:
raise ProviderModelCurrentlyNotSupportError()
except InvokeError as e:
raise CompletionRequestError(e.description)
except ValueError as e:
raise e
except Exception as e:
logging.exception("Failed to handle post request to ChatMessageTextApi")
raise InternalServerError()
class TextModesApi(Resource):
@setup_required
@login_required
@account_initialization_required
@get_app_model
def get(self, app_model):
try:
parser = reqparse.RequestParser()
parser.add_argument("language", type=str, required=True, location="args")
args = parser.parse_args()
response = AudioService.transcript_tts_voices(
tenant_id=app_model.tenant_id,
language=args["language"],
)
return response
except services.errors.audio.ProviderNotSupportTextToSpeechLanageServiceError:
raise AppUnavailableError("Text to audio voices language parameter loss.")
except NoAudioUploadedServiceError:
raise NoAudioUploadedError()
except AudioTooLargeServiceError as e:
raise AudioTooLargeError(str(e))
except UnsupportedAudioTypeServiceError:
raise UnsupportedAudioTypeError()
except ProviderNotSupportSpeechToTextServiceError:
raise ProviderNotSupportSpeechToTextError()
except ProviderTokenNotInitError as ex:
raise ProviderNotInitializeError(ex.description)
except QuotaExceededError:
raise ProviderQuotaExceededError()
except ModelCurrentlyNotSupportError:
raise ProviderModelCurrentlyNotSupportError()
except InvokeError as e:
raise CompletionRequestError(e.description)
except ValueError as e:
raise e
except Exception as e:
logging.exception("Failed to handle get request to TextModesApi")
raise InternalServerError()
api.add_resource(ChatMessageAudioApi, "/apps/<uuid:app_id>/audio-to-text")
api.add_resource(ChatMessageTextApi, "/apps/<uuid:app_id>/text-to-audio")
api.add_resource(TextModesApi, "/apps/<uuid:app_id>/text-to-audio/voices")

View File

@@ -0,0 +1,166 @@
import logging
import flask_login
from flask_restful import Resource, reqparse
from werkzeug.exceptions import InternalServerError, NotFound
import services
from controllers.console import api
from controllers.console.app.error import (
AppUnavailableError,
CompletionRequestError,
ConversationCompletedError,
ProviderModelCurrentlyNotSupportError,
ProviderNotInitializeError,
ProviderQuotaExceededError,
)
from controllers.console.app.wraps import get_app_model
from controllers.console.wraps import account_initialization_required, setup_required
from controllers.web.error import InvokeRateLimitError as InvokeRateLimitHttpError
from core.app.apps.base_app_queue_manager import AppQueueManager
from core.app.entities.app_invoke_entities import InvokeFrom
from core.errors.error import (
ModelCurrentlyNotSupportError,
ProviderTokenNotInitError,
QuotaExceededError,
)
from core.model_runtime.errors.invoke import InvokeError
from libs import helper
from libs.helper import uuid_value
from libs.login import login_required
from models.model import AppMode
from services.app_generate_service import AppGenerateService
from services.errors.llm import InvokeRateLimitError
# define completion message api for user
class CompletionMessageApi(Resource):
@setup_required
@login_required
@account_initialization_required
@get_app_model(mode=AppMode.COMPLETION)
def post(self, app_model):
parser = reqparse.RequestParser()
parser.add_argument("inputs", type=dict, required=True, location="json")
parser.add_argument("query", type=str, location="json", default="")
parser.add_argument("files", type=list, required=False, location="json")
parser.add_argument("model_config", type=dict, required=True, location="json")
parser.add_argument("response_mode", type=str, choices=["blocking", "streaming"], location="json")
parser.add_argument("retriever_from", type=str, required=False, default="dev", location="json")
args = parser.parse_args()
streaming = args["response_mode"] != "blocking"
args["auto_generate_name"] = False
account = flask_login.current_user
try:
response = AppGenerateService.generate(
app_model=app_model, user=account, args=args, invoke_from=InvokeFrom.DEBUGGER, streaming=streaming
)
return helper.compact_generate_response(response)
except services.errors.conversation.ConversationNotExistsError:
raise NotFound("Conversation Not Exists.")
except services.errors.conversation.ConversationCompletedError:
raise ConversationCompletedError()
except services.errors.app_model_config.AppModelConfigBrokenError:
logging.exception("App model config broken.")
raise AppUnavailableError()
except ProviderTokenNotInitError as ex:
raise ProviderNotInitializeError(ex.description)
except QuotaExceededError:
raise ProviderQuotaExceededError()
except ModelCurrentlyNotSupportError:
raise ProviderModelCurrentlyNotSupportError()
except InvokeError as e:
raise CompletionRequestError(e.description)
except ValueError as e:
raise e
except Exception as e:
logging.exception("internal server error.")
raise InternalServerError()
class CompletionMessageStopApi(Resource):
@setup_required
@login_required
@account_initialization_required
@get_app_model(mode=AppMode.COMPLETION)
def post(self, app_model, task_id):
account = flask_login.current_user
AppQueueManager.set_stop_flag(task_id, InvokeFrom.DEBUGGER, account.id)
return {"result": "success"}, 200
class ChatMessageApi(Resource):
@setup_required
@login_required
@account_initialization_required
@get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT])
def post(self, app_model):
parser = reqparse.RequestParser()
parser.add_argument("inputs", type=dict, required=True, location="json")
parser.add_argument("query", type=str, required=True, location="json")
parser.add_argument("files", type=list, required=False, location="json")
parser.add_argument("model_config", type=dict, required=True, location="json")
parser.add_argument("conversation_id", type=uuid_value, location="json")
parser.add_argument("parent_message_id", type=uuid_value, required=False, location="json")
parser.add_argument("response_mode", type=str, choices=["blocking", "streaming"], location="json")
parser.add_argument("retriever_from", type=str, required=False, default="dev", location="json")
args = parser.parse_args()
streaming = args["response_mode"] != "blocking"
args["auto_generate_name"] = False
account = flask_login.current_user
try:
response = AppGenerateService.generate(
app_model=app_model, user=account, args=args, invoke_from=InvokeFrom.DEBUGGER, streaming=streaming
)
return helper.compact_generate_response(response)
except services.errors.conversation.ConversationNotExistsError:
raise NotFound("Conversation Not Exists.")
except services.errors.conversation.ConversationCompletedError:
raise ConversationCompletedError()
except services.errors.app_model_config.AppModelConfigBrokenError:
logging.exception("App model config broken.")
raise AppUnavailableError()
except ProviderTokenNotInitError as ex:
raise ProviderNotInitializeError(ex.description)
except QuotaExceededError:
raise ProviderQuotaExceededError()
except ModelCurrentlyNotSupportError:
raise ProviderModelCurrentlyNotSupportError()
except InvokeRateLimitError as ex:
raise InvokeRateLimitHttpError(ex.description)
except InvokeError as e:
raise CompletionRequestError(e.description)
except ValueError as e:
raise e
except Exception as e:
logging.exception("internal server error.")
raise InternalServerError()
class ChatMessageStopApi(Resource):
@setup_required
@login_required
@account_initialization_required
@get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT])
def post(self, app_model, task_id):
account = flask_login.current_user
AppQueueManager.set_stop_flag(task_id, InvokeFrom.DEBUGGER, account.id)
return {"result": "success"}, 200
api.add_resource(CompletionMessageApi, "/apps/<uuid:app_id>/completion-messages")
api.add_resource(CompletionMessageStopApi, "/apps/<uuid:app_id>/completion-messages/<string:task_id>/stop")
api.add_resource(ChatMessageApi, "/apps/<uuid:app_id>/chat-messages")
api.add_resource(ChatMessageStopApi, "/apps/<uuid:app_id>/chat-messages/<string:task_id>/stop")

View File

@@ -0,0 +1,322 @@
from datetime import UTC, datetime
import pytz # pip install pytz
from flask_login import current_user
from flask_restful import Resource, marshal_with, reqparse
from flask_restful.inputs import int_range
from sqlalchemy import func, or_
from sqlalchemy.orm import joinedload
from werkzeug.exceptions import Forbidden, NotFound
from controllers.console import api
from controllers.console.app.wraps import get_app_model
from controllers.console.wraps import account_initialization_required, setup_required
from core.app.entities.app_invoke_entities import InvokeFrom
from extensions.ext_database import db
from fields.conversation_fields import (
conversation_detail_fields,
conversation_message_detail_fields,
conversation_pagination_fields,
conversation_with_summary_pagination_fields,
)
from libs.helper import DatetimeString
from libs.login import login_required
from models import Conversation, EndUser, Message, MessageAnnotation
from models.model import AppMode
class CompletionConversationApi(Resource):
@setup_required
@login_required
@account_initialization_required
@get_app_model(mode=AppMode.COMPLETION)
@marshal_with(conversation_pagination_fields)
def get(self, app_model):
if not current_user.is_editor:
raise Forbidden()
parser = reqparse.RequestParser()
parser.add_argument("keyword", type=str, location="args")
parser.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
parser.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
parser.add_argument(
"annotation_status", type=str, choices=["annotated", "not_annotated", "all"], default="all", location="args"
)
parser.add_argument("page", type=int_range(1, 99999), default=1, location="args")
parser.add_argument("limit", type=int_range(1, 100), default=20, location="args")
args = parser.parse_args()
query = db.select(Conversation).where(Conversation.app_id == app_model.id, Conversation.mode == "completion")
if args["keyword"]:
query = query.join(Message, Message.conversation_id == Conversation.id).filter(
or_(
Message.query.ilike("%{}%".format(args["keyword"])),
Message.answer.ilike("%{}%".format(args["keyword"])),
)
)
account = current_user
timezone = pytz.timezone(account.timezone)
utc_timezone = pytz.utc
if args["start"]:
start_datetime = datetime.strptime(args["start"], "%Y-%m-%d %H:%M")
start_datetime = start_datetime.replace(second=0)
start_datetime_timezone = timezone.localize(start_datetime)
start_datetime_utc = start_datetime_timezone.astimezone(utc_timezone)
query = query.where(Conversation.created_at >= start_datetime_utc)
if args["end"]:
end_datetime = datetime.strptime(args["end"], "%Y-%m-%d %H:%M")
end_datetime = end_datetime.replace(second=59)
end_datetime_timezone = timezone.localize(end_datetime)
end_datetime_utc = end_datetime_timezone.astimezone(utc_timezone)
query = query.where(Conversation.created_at < end_datetime_utc)
# FIXME, the type ignore in this file
if args["annotation_status"] == "annotated":
query = query.options(joinedload(Conversation.message_annotations)).join( # type: ignore
MessageAnnotation, MessageAnnotation.conversation_id == Conversation.id
)
elif args["annotation_status"] == "not_annotated":
query = (
query.outerjoin(MessageAnnotation, MessageAnnotation.conversation_id == Conversation.id)
.group_by(Conversation.id)
.having(func.count(MessageAnnotation.id) == 0)
)
query = query.order_by(Conversation.created_at.desc())
conversations = db.paginate(query, page=args["page"], per_page=args["limit"], error_out=False)
return conversations
class CompletionConversationDetailApi(Resource):
@setup_required
@login_required
@account_initialization_required
@get_app_model(mode=AppMode.COMPLETION)
@marshal_with(conversation_message_detail_fields)
def get(self, app_model, conversation_id):
if not current_user.is_editor:
raise Forbidden()
conversation_id = str(conversation_id)
return _get_conversation(app_model, conversation_id)
@setup_required
@login_required
@account_initialization_required
@get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT])
def delete(self, app_model, conversation_id):
if not current_user.is_editor:
raise Forbidden()
conversation_id = str(conversation_id)
conversation = (
db.session.query(Conversation)
.filter(Conversation.id == conversation_id, Conversation.app_id == app_model.id)
.first()
)
if not conversation:
raise NotFound("Conversation Not Exists.")
conversation.is_deleted = True
db.session.commit()
return {"result": "success"}, 204
class ChatConversationApi(Resource):
@setup_required
@login_required
@account_initialization_required
@get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT])
@marshal_with(conversation_with_summary_pagination_fields)
def get(self, app_model):
if not current_user.is_editor:
raise Forbidden()
parser = reqparse.RequestParser()
parser.add_argument("keyword", type=str, location="args")
parser.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
parser.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
parser.add_argument(
"annotation_status", type=str, choices=["annotated", "not_annotated", "all"], default="all", location="args"
)
parser.add_argument("message_count_gte", type=int_range(1, 99999), required=False, location="args")
parser.add_argument("page", type=int_range(1, 99999), required=False, default=1, location="args")
parser.add_argument("limit", type=int_range(1, 100), required=False, default=20, location="args")
parser.add_argument(
"sort_by",
type=str,
choices=["created_at", "-created_at", "updated_at", "-updated_at"],
required=False,
default="-updated_at",
location="args",
)
args = parser.parse_args()
subquery = (
db.session.query(
Conversation.id.label("conversation_id"), EndUser.session_id.label("from_end_user_session_id")
)
.outerjoin(EndUser, Conversation.from_end_user_id == EndUser.id)
.subquery()
)
query = db.select(Conversation).where(Conversation.app_id == app_model.id)
if args["keyword"]:
keyword_filter = "%{}%".format(args["keyword"])
query = (
query.join(
Message,
Message.conversation_id == Conversation.id,
)
.join(subquery, subquery.c.conversation_id == Conversation.id)
.filter(
or_(
Message.query.ilike(keyword_filter),
Message.answer.ilike(keyword_filter),
Conversation.name.ilike(keyword_filter),
Conversation.introduction.ilike(keyword_filter),
subquery.c.from_end_user_session_id.ilike(keyword_filter),
),
)
.group_by(Conversation.id)
)
account = current_user
timezone = pytz.timezone(account.timezone)
utc_timezone = pytz.utc
if args["start"]:
start_datetime = datetime.strptime(args["start"], "%Y-%m-%d %H:%M")
start_datetime = start_datetime.replace(second=0)
start_datetime_timezone = timezone.localize(start_datetime)
start_datetime_utc = start_datetime_timezone.astimezone(utc_timezone)
match args["sort_by"]:
case "updated_at" | "-updated_at":
query = query.where(Conversation.updated_at >= start_datetime_utc)
case "created_at" | "-created_at" | _:
query = query.where(Conversation.created_at >= start_datetime_utc)
if args["end"]:
end_datetime = datetime.strptime(args["end"], "%Y-%m-%d %H:%M")
end_datetime = end_datetime.replace(second=59)
end_datetime_timezone = timezone.localize(end_datetime)
end_datetime_utc = end_datetime_timezone.astimezone(utc_timezone)
match args["sort_by"]:
case "updated_at" | "-updated_at":
query = query.where(Conversation.updated_at <= end_datetime_utc)
case "created_at" | "-created_at" | _:
query = query.where(Conversation.created_at <= end_datetime_utc)
if args["annotation_status"] == "annotated":
query = query.options(joinedload(Conversation.message_annotations)).join( # type: ignore
MessageAnnotation, MessageAnnotation.conversation_id == Conversation.id
)
elif args["annotation_status"] == "not_annotated":
query = (
query.outerjoin(MessageAnnotation, MessageAnnotation.conversation_id == Conversation.id)
.group_by(Conversation.id)
.having(func.count(MessageAnnotation.id) == 0)
)
if args["message_count_gte"] and args["message_count_gte"] >= 1:
query = (
query.options(joinedload(Conversation.messages)) # type: ignore
.join(Message, Message.conversation_id == Conversation.id)
.group_by(Conversation.id)
.having(func.count(Message.id) >= args["message_count_gte"])
)
if app_model.mode == AppMode.ADVANCED_CHAT.value:
query = query.where(Conversation.invoke_from != InvokeFrom.DEBUGGER.value)
match args["sort_by"]:
case "created_at":
query = query.order_by(Conversation.created_at.asc())
case "-created_at":
query = query.order_by(Conversation.created_at.desc())
case "updated_at":
query = query.order_by(Conversation.updated_at.asc())
case "-updated_at":
query = query.order_by(Conversation.updated_at.desc())
case _:
query = query.order_by(Conversation.created_at.desc())
conversations = db.paginate(query, page=args["page"], per_page=args["limit"], error_out=False)
return conversations
class ChatConversationDetailApi(Resource):
@setup_required
@login_required
@account_initialization_required
@get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT])
@marshal_with(conversation_detail_fields)
def get(self, app_model, conversation_id):
if not current_user.is_editor:
raise Forbidden()
conversation_id = str(conversation_id)
return _get_conversation(app_model, conversation_id)
@setup_required
@login_required
@get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT])
@account_initialization_required
def delete(self, app_model, conversation_id):
if not current_user.is_editor:
raise Forbidden()
conversation_id = str(conversation_id)
conversation = (
db.session.query(Conversation)
.filter(Conversation.id == conversation_id, Conversation.app_id == app_model.id)
.first()
)
if not conversation:
raise NotFound("Conversation Not Exists.")
conversation.is_deleted = True
db.session.commit()
return {"result": "success"}, 204
api.add_resource(CompletionConversationApi, "/apps/<uuid:app_id>/completion-conversations")
api.add_resource(CompletionConversationDetailApi, "/apps/<uuid:app_id>/completion-conversations/<uuid:conversation_id>")
api.add_resource(ChatConversationApi, "/apps/<uuid:app_id>/chat-conversations")
api.add_resource(ChatConversationDetailApi, "/apps/<uuid:app_id>/chat-conversations/<uuid:conversation_id>")
def _get_conversation(app_model, conversation_id):
conversation = (
db.session.query(Conversation)
.filter(Conversation.id == conversation_id, Conversation.app_id == app_model.id)
.first()
)
if not conversation:
raise NotFound("Conversation Not Exists.")
if not conversation.read_at:
conversation.read_at = datetime.now(UTC).replace(tzinfo=None)
conversation.read_account_id = current_user.id
db.session.commit()
return conversation

View File

@@ -0,0 +1,60 @@
from flask_restful import Resource, marshal_with, reqparse
from sqlalchemy import select
from sqlalchemy.orm import Session
from controllers.console import api
from controllers.console.app.wraps import get_app_model
from controllers.console.wraps import account_initialization_required, setup_required
from extensions.ext_database import db
from fields.conversation_variable_fields import paginated_conversation_variable_fields
from libs.login import login_required
from models import ConversationVariable
from models.model import AppMode
class ConversationVariablesApi(Resource):
@setup_required
@login_required
@account_initialization_required
@get_app_model(mode=AppMode.ADVANCED_CHAT)
@marshal_with(paginated_conversation_variable_fields)
def get(self, app_model):
parser = reqparse.RequestParser()
parser.add_argument("conversation_id", type=str, location="args")
args = parser.parse_args()
stmt = (
select(ConversationVariable)
.where(ConversationVariable.app_id == app_model.id)
.order_by(ConversationVariable.created_at)
)
if args["conversation_id"]:
stmt = stmt.where(ConversationVariable.conversation_id == args["conversation_id"])
else:
raise ValueError("conversation_id is required")
# NOTE: This is a temporary solution to avoid performance issues.
page = 1
page_size = 100
stmt = stmt.limit(page_size).offset((page - 1) * page_size)
with Session(db.engine) as session:
rows = session.scalars(stmt).all()
return {
"page": page,
"limit": page_size,
"total": len(rows),
"has_more": False,
"data": [
{
"created_at": row.created_at,
"updated_at": row.updated_at,
**row.to_variable().model_dump(),
}
for row in rows
],
}
api.add_resource(ConversationVariablesApi, "/apps/<uuid:app_id>/conversation-variables")

View File

@@ -0,0 +1,129 @@
from libs.exception import BaseHTTPException
class AppNotFoundError(BaseHTTPException):
error_code = "app_not_found"
description = "App not found."
code = 404
class ProviderNotInitializeError(BaseHTTPException):
error_code = "provider_not_initialize"
description = (
"No valid model provider credentials found. "
"Please go to Settings -> Model Provider to complete your provider credentials."
)
code = 400
class ProviderQuotaExceededError(BaseHTTPException):
error_code = "provider_quota_exceeded"
description = (
"Your quota for Dify Hosted Model Provider has been exhausted. "
"Please go to Settings -> Model Provider to complete your own provider credentials."
)
code = 400
class ProviderModelCurrentlyNotSupportError(BaseHTTPException):
error_code = "model_currently_not_support"
description = "Dify Hosted OpenAI trial currently not support the GPT-4 model."
code = 400
class ConversationCompletedError(BaseHTTPException):
error_code = "conversation_completed"
description = "The conversation has ended. Please start a new conversation."
code = 400
class AppUnavailableError(BaseHTTPException):
error_code = "app_unavailable"
description = "App unavailable, please check your app configurations."
code = 400
class CompletionRequestError(BaseHTTPException):
error_code = "completion_request_error"
description = "Completion request failed."
code = 400
class AppMoreLikeThisDisabledError(BaseHTTPException):
error_code = "app_more_like_this_disabled"
description = "The 'More like this' feature is disabled. Please refresh your page."
code = 403
class NoAudioUploadedError(BaseHTTPException):
error_code = "no_audio_uploaded"
description = "Please upload your audio."
code = 400
class AudioTooLargeError(BaseHTTPException):
error_code = "audio_too_large"
description = "Audio size exceeded. {message}"
code = 413
class UnsupportedAudioTypeError(BaseHTTPException):
error_code = "unsupported_audio_type"
description = "Audio type not allowed."
code = 415
class ProviderNotSupportSpeechToTextError(BaseHTTPException):
error_code = "provider_not_support_speech_to_text"
description = "Provider not support speech to text."
code = 400
class NoFileUploadedError(BaseHTTPException):
error_code = "no_file_uploaded"
description = "Please upload your file."
code = 400
class TooManyFilesError(BaseHTTPException):
error_code = "too_many_files"
description = "Only one file is allowed."
code = 400
class DraftWorkflowNotExist(BaseHTTPException):
error_code = "draft_workflow_not_exist"
description = "Draft workflow need to be initialized."
code = 400
class DraftWorkflowNotSync(BaseHTTPException):
error_code = "draft_workflow_not_sync"
description = "Workflow graph might have been modified, please refresh and resubmit."
code = 400
class TracingConfigNotExist(BaseHTTPException):
error_code = "trace_config_not_exist"
description = "Trace config not exist."
code = 400
class TracingConfigIsExist(BaseHTTPException):
error_code = "trace_config_is_exist"
description = "Trace config is exist."
code = 400
class TracingConfigCheckError(BaseHTTPException):
error_code = "trace_config_check_error"
description = "Invalid Credentials."
code = 400
class InvokeRateLimitError(BaseHTTPException):
"""Raised when the Invoke returns rate limit error."""
error_code = "rate_limit_error"
description = "Rate Limit Error"
code = 429

View File

@@ -0,0 +1,119 @@
import os
from flask_login import current_user
from flask_restful import Resource, reqparse
from controllers.console import api
from controllers.console.app.error import (
CompletionRequestError,
ProviderModelCurrentlyNotSupportError,
ProviderNotInitializeError,
ProviderQuotaExceededError,
)
from controllers.console.wraps import account_initialization_required, setup_required
from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
from core.llm_generator.llm_generator import LLMGenerator
from core.model_runtime.errors.invoke import InvokeError
from libs.login import login_required
class RuleGenerateApi(Resource):
@setup_required
@login_required
@account_initialization_required
def post(self):
parser = reqparse.RequestParser()
parser.add_argument("instruction", type=str, required=True, nullable=False, location="json")
parser.add_argument("model_config", type=dict, required=True, nullable=False, location="json")
parser.add_argument("no_variable", type=bool, required=True, default=False, location="json")
args = parser.parse_args()
account = current_user
PROMPT_GENERATION_MAX_TOKENS = int(os.getenv("PROMPT_GENERATION_MAX_TOKENS", "512"))
try:
rules = LLMGenerator.generate_rule_config(
tenant_id=account.current_tenant_id,
instruction=args["instruction"],
model_config=args["model_config"],
no_variable=args["no_variable"],
rule_config_max_tokens=PROMPT_GENERATION_MAX_TOKENS,
)
except ProviderTokenNotInitError as ex:
raise ProviderNotInitializeError(ex.description)
except QuotaExceededError:
raise ProviderQuotaExceededError()
except ModelCurrentlyNotSupportError:
raise ProviderModelCurrentlyNotSupportError()
except InvokeError as e:
raise CompletionRequestError(e.description)
return rules
class RuleCodeGenerateApi(Resource):
@setup_required
@login_required
@account_initialization_required
def post(self):
parser = reqparse.RequestParser()
parser.add_argument("instruction", type=str, required=True, nullable=False, location="json")
parser.add_argument("model_config", type=dict, required=True, nullable=False, location="json")
parser.add_argument("no_variable", type=bool, required=True, default=False, location="json")
parser.add_argument("code_language", type=str, required=False, default="javascript", location="json")
args = parser.parse_args()
account = current_user
CODE_GENERATION_MAX_TOKENS = int(os.getenv("CODE_GENERATION_MAX_TOKENS", "1024"))
try:
code_result = LLMGenerator.generate_code(
tenant_id=account.current_tenant_id,
instruction=args["instruction"],
model_config=args["model_config"],
code_language=args["code_language"],
max_tokens=CODE_GENERATION_MAX_TOKENS,
)
except ProviderTokenNotInitError as ex:
raise ProviderNotInitializeError(ex.description)
except QuotaExceededError:
raise ProviderQuotaExceededError()
except ModelCurrentlyNotSupportError:
raise ProviderModelCurrentlyNotSupportError()
except InvokeError as e:
raise CompletionRequestError(e.description)
return code_result
class RuleStructuredOutputGenerateApi(Resource):
@setup_required
@login_required
@account_initialization_required
def post(self):
parser = reqparse.RequestParser()
parser.add_argument("instruction", type=str, required=True, nullable=False, location="json")
parser.add_argument("model_config", type=dict, required=True, nullable=False, location="json")
args = parser.parse_args()
account = current_user
try:
structured_output = LLMGenerator.generate_structured_output(
tenant_id=account.current_tenant_id,
instruction=args["instruction"],
model_config=args["model_config"],
)
except ProviderTokenNotInitError as ex:
raise ProviderNotInitializeError(ex.description)
except QuotaExceededError:
raise ProviderQuotaExceededError()
except ModelCurrentlyNotSupportError:
raise ProviderModelCurrentlyNotSupportError()
except InvokeError as e:
raise CompletionRequestError(e.description)
return structured_output
api.add_resource(RuleGenerateApi, "/rule-generate")
api.add_resource(RuleCodeGenerateApi, "/rule-code-generate")
api.add_resource(RuleStructuredOutputGenerateApi, "/rule-structured-output-generate")

View File

@@ -0,0 +1,246 @@
import logging
from flask_login import current_user
from flask_restful import Resource, fields, marshal_with, reqparse
from flask_restful.inputs import int_range
from werkzeug.exceptions import Forbidden, InternalServerError, NotFound
from controllers.console import api
from controllers.console.app.error import (
CompletionRequestError,
ProviderModelCurrentlyNotSupportError,
ProviderNotInitializeError,
ProviderQuotaExceededError,
)
from controllers.console.app.wraps import get_app_model
from controllers.console.explore.error import AppSuggestedQuestionsAfterAnswerDisabledError
from controllers.console.wraps import (
account_initialization_required,
cloud_edition_billing_resource_check,
setup_required,
)
from core.app.entities.app_invoke_entities import InvokeFrom
from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
from core.model_runtime.errors.invoke import InvokeError
from extensions.ext_database import db
from fields.conversation_fields import annotation_fields, message_detail_fields
from libs.helper import uuid_value
from libs.infinite_scroll_pagination import InfiniteScrollPagination
from libs.login import login_required
from models.model import AppMode, Conversation, Message, MessageAnnotation, MessageFeedback
from services.annotation_service import AppAnnotationService
from services.errors.conversation import ConversationNotExistsError
from services.errors.message import MessageNotExistsError, SuggestedQuestionsAfterAnswerDisabledError
from services.message_service import MessageService
class ChatMessageListApi(Resource):
message_infinite_scroll_pagination_fields = {
"limit": fields.Integer,
"has_more": fields.Boolean,
"data": fields.List(fields.Nested(message_detail_fields)),
}
@setup_required
@login_required
@get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT])
@account_initialization_required
@marshal_with(message_infinite_scroll_pagination_fields)
def get(self, app_model):
parser = reqparse.RequestParser()
parser.add_argument("conversation_id", required=True, type=uuid_value, location="args")
parser.add_argument("first_id", type=uuid_value, location="args")
parser.add_argument("limit", type=int_range(1, 100), required=False, default=20, location="args")
args = parser.parse_args()
conversation = (
db.session.query(Conversation)
.filter(Conversation.id == args["conversation_id"], Conversation.app_id == app_model.id)
.first()
)
if not conversation:
raise NotFound("Conversation Not Exists.")
if args["first_id"]:
first_message = (
db.session.query(Message)
.filter(Message.conversation_id == conversation.id, Message.id == args["first_id"])
.first()
)
if not first_message:
raise NotFound("First message not found")
history_messages = (
db.session.query(Message)
.filter(
Message.conversation_id == conversation.id,
Message.created_at < first_message.created_at,
Message.id != first_message.id,
)
.order_by(Message.created_at.desc())
.limit(args["limit"])
.all()
)
else:
history_messages = (
db.session.query(Message)
.filter(Message.conversation_id == conversation.id)
.order_by(Message.created_at.desc())
.limit(args["limit"])
.all()
)
has_more = False
if len(history_messages) == args["limit"]:
current_page_first_message = history_messages[-1]
rest_count = (
db.session.query(Message)
.filter(
Message.conversation_id == conversation.id,
Message.created_at < current_page_first_message.created_at,
Message.id != current_page_first_message.id,
)
.count()
)
if rest_count > 0:
has_more = True
history_messages = list(reversed(history_messages))
return InfiniteScrollPagination(data=history_messages, limit=args["limit"], has_more=has_more)
class MessageFeedbackApi(Resource):
@setup_required
@login_required
@account_initialization_required
@get_app_model
def post(self, app_model):
parser = reqparse.RequestParser()
parser.add_argument("message_id", required=True, type=uuid_value, location="json")
parser.add_argument("rating", type=str, choices=["like", "dislike", None], location="json")
args = parser.parse_args()
message_id = str(args["message_id"])
message = db.session.query(Message).filter(Message.id == message_id, Message.app_id == app_model.id).first()
if not message:
raise NotFound("Message Not Exists.")
feedback = message.admin_feedback
if not args["rating"] and feedback:
db.session.delete(feedback)
elif args["rating"] and feedback:
feedback.rating = args["rating"]
elif not args["rating"] and not feedback:
raise ValueError("rating cannot be None when feedback not exists")
else:
feedback = MessageFeedback(
app_id=app_model.id,
conversation_id=message.conversation_id,
message_id=message.id,
rating=args["rating"],
from_source="admin",
from_account_id=current_user.id,
)
db.session.add(feedback)
db.session.commit()
return {"result": "success"}
class MessageAnnotationApi(Resource):
@setup_required
@login_required
@account_initialization_required
@cloud_edition_billing_resource_check("annotation")
@get_app_model
@marshal_with(annotation_fields)
def post(self, app_model):
if not current_user.is_editor:
raise Forbidden()
parser = reqparse.RequestParser()
parser.add_argument("message_id", required=False, type=uuid_value, location="json")
parser.add_argument("question", required=True, type=str, location="json")
parser.add_argument("answer", required=True, type=str, location="json")
parser.add_argument("annotation_reply", required=False, type=dict, location="json")
args = parser.parse_args()
annotation = AppAnnotationService.up_insert_app_annotation_from_message(args, app_model.id)
return annotation
class MessageAnnotationCountApi(Resource):
@setup_required
@login_required
@account_initialization_required
@get_app_model
def get(self, app_model):
count = db.session.query(MessageAnnotation).filter(MessageAnnotation.app_id == app_model.id).count()
return {"count": count}
class MessageSuggestedQuestionApi(Resource):
@setup_required
@login_required
@account_initialization_required
@get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT])
def get(self, app_model, message_id):
message_id = str(message_id)
try:
questions = MessageService.get_suggested_questions_after_answer(
app_model=app_model, message_id=message_id, user=current_user, invoke_from=InvokeFrom.DEBUGGER
)
except MessageNotExistsError:
raise NotFound("Message not found")
except ConversationNotExistsError:
raise NotFound("Conversation not found")
except ProviderTokenNotInitError as ex:
raise ProviderNotInitializeError(ex.description)
except QuotaExceededError:
raise ProviderQuotaExceededError()
except ModelCurrentlyNotSupportError:
raise ProviderModelCurrentlyNotSupportError()
except InvokeError as e:
raise CompletionRequestError(e.description)
except SuggestedQuestionsAfterAnswerDisabledError:
raise AppSuggestedQuestionsAfterAnswerDisabledError()
except Exception:
logging.exception("internal server error.")
raise InternalServerError()
return {"data": questions}
class MessageApi(Resource):
@setup_required
@login_required
@account_initialization_required
@get_app_model
@marshal_with(message_detail_fields)
def get(self, app_model, message_id):
message_id = str(message_id)
message = db.session.query(Message).filter(Message.id == message_id, Message.app_id == app_model.id).first()
if not message:
raise NotFound("Message Not Exists.")
return message
api.add_resource(MessageSuggestedQuestionApi, "/apps/<uuid:app_id>/chat-messages/<uuid:message_id>/suggested-questions")
api.add_resource(ChatMessageListApi, "/apps/<uuid:app_id>/chat-messages", endpoint="console_chat_messages")
api.add_resource(MessageFeedbackApi, "/apps/<uuid:app_id>/feedbacks")
api.add_resource(MessageAnnotationApi, "/apps/<uuid:app_id>/annotations")
api.add_resource(MessageAnnotationCountApi, "/apps/<uuid:app_id>/annotations/count")
api.add_resource(MessageApi, "/apps/<uuid:app_id>/messages/<uuid:message_id>", endpoint="console_message")

View File

@@ -0,0 +1,147 @@
import json
from typing import cast
from flask import request
from flask_login import current_user
from flask_restful import Resource
from controllers.console import api
from controllers.console.app.wraps import get_app_model
from controllers.console.wraps import account_initialization_required, setup_required
from core.agent.entities import AgentToolEntity
from core.tools.tool_manager import ToolManager
from core.tools.utils.configuration import ToolParameterConfigurationManager
from events.app_event import app_model_config_was_updated
from extensions.ext_database import db
from libs.login import login_required
from models.model import AppMode, AppModelConfig
from services.app_model_config_service import AppModelConfigService
class ModelConfigResource(Resource):
@setup_required
@login_required
@account_initialization_required
@get_app_model(mode=[AppMode.AGENT_CHAT, AppMode.CHAT, AppMode.COMPLETION])
def post(self, app_model):
"""Modify app model config"""
# validate config
model_configuration = AppModelConfigService.validate_configuration(
tenant_id=current_user.current_tenant_id,
config=cast(dict, request.json),
app_mode=AppMode.value_of(app_model.mode),
)
new_app_model_config = AppModelConfig(
app_id=app_model.id,
created_by=current_user.id,
updated_by=current_user.id,
)
new_app_model_config = new_app_model_config.from_model_config_dict(model_configuration)
if app_model.mode == AppMode.AGENT_CHAT.value or app_model.is_agent:
# get original app model config
original_app_model_config = (
db.session.query(AppModelConfig).filter(AppModelConfig.id == app_model.app_model_config_id).first()
)
if original_app_model_config is None:
raise ValueError("Original app model config not found")
agent_mode = original_app_model_config.agent_mode_dict
# decrypt agent tool parameters if it's secret-input
parameter_map = {}
masked_parameter_map = {}
tool_map = {}
for tool in agent_mode.get("tools") or []:
if not isinstance(tool, dict) or len(tool.keys()) <= 3:
continue
agent_tool_entity = AgentToolEntity(**tool)
# get tool
try:
tool_runtime = ToolManager.get_agent_tool_runtime(
tenant_id=current_user.current_tenant_id,
app_id=app_model.id,
agent_tool=agent_tool_entity,
)
manager = ToolParameterConfigurationManager(
tenant_id=current_user.current_tenant_id,
tool_runtime=tool_runtime,
provider_name=agent_tool_entity.provider_id,
provider_type=agent_tool_entity.provider_type,
identity_id=f"AGENT.{app_model.id}",
)
except Exception:
continue
# get decrypted parameters
if agent_tool_entity.tool_parameters:
parameters = manager.decrypt_tool_parameters(agent_tool_entity.tool_parameters or {})
masked_parameter = manager.mask_tool_parameters(parameters or {})
else:
parameters = {}
masked_parameter = {}
key = f"{agent_tool_entity.provider_id}.{agent_tool_entity.provider_type}.{agent_tool_entity.tool_name}"
masked_parameter_map[key] = masked_parameter
parameter_map[key] = parameters
tool_map[key] = tool_runtime
# encrypt agent tool parameters if it's secret-input
agent_mode = new_app_model_config.agent_mode_dict
for tool in agent_mode.get("tools") or []:
agent_tool_entity = AgentToolEntity(**tool)
# get tool
key = f"{agent_tool_entity.provider_id}.{agent_tool_entity.provider_type}.{agent_tool_entity.tool_name}"
if key in tool_map:
tool_runtime = tool_map[key]
else:
try:
tool_runtime = ToolManager.get_agent_tool_runtime(
tenant_id=current_user.current_tenant_id,
app_id=app_model.id,
agent_tool=agent_tool_entity,
)
except Exception:
continue
manager = ToolParameterConfigurationManager(
tenant_id=current_user.current_tenant_id,
tool_runtime=tool_runtime,
provider_name=agent_tool_entity.provider_id,
provider_type=agent_tool_entity.provider_type,
identity_id=f"AGENT.{app_model.id}",
)
manager.delete_tool_parameters_cache()
# override parameters if it equals to masked parameters
if agent_tool_entity.tool_parameters:
if key not in masked_parameter_map:
continue
for masked_key, masked_value in masked_parameter_map[key].items():
if (
masked_key in agent_tool_entity.tool_parameters
and agent_tool_entity.tool_parameters[masked_key] == masked_value
):
agent_tool_entity.tool_parameters[masked_key] = parameter_map[key].get(masked_key)
# encrypt parameters
if agent_tool_entity.tool_parameters:
tool["tool_parameters"] = manager.encrypt_tool_parameters(agent_tool_entity.tool_parameters or {})
# update app model config
new_app_model_config.agent_mode = json.dumps(agent_mode)
db.session.add(new_app_model_config)
db.session.flush()
app_model.app_model_config_id = new_app_model_config.id
db.session.commit()
app_model_config_was_updated.send(app_model, app_model_config=new_app_model_config)
return {"result": "success"}
api.add_resource(ModelConfigResource, "/apps/<uuid:app_id>/model-config")

View File

@@ -0,0 +1,92 @@
from flask_restful import Resource, reqparse
from werkzeug.exceptions import BadRequest
from controllers.console import api
from controllers.console.app.error import TracingConfigCheckError, TracingConfigIsExist, TracingConfigNotExist
from controllers.console.wraps import account_initialization_required, setup_required
from libs.login import login_required
from services.ops_service import OpsService
class TraceAppConfigApi(Resource):
"""
Manage trace app configurations
"""
@setup_required
@login_required
@account_initialization_required
def get(self, app_id):
parser = reqparse.RequestParser()
parser.add_argument("tracing_provider", type=str, required=True, location="args")
args = parser.parse_args()
try:
trace_config = OpsService.get_tracing_app_config(app_id=app_id, tracing_provider=args["tracing_provider"])
if not trace_config:
return {"has_not_configured": True}
return trace_config
except Exception as e:
raise BadRequest(str(e))
@setup_required
@login_required
@account_initialization_required
def post(self, app_id):
"""Create a new trace app configuration"""
parser = reqparse.RequestParser()
parser.add_argument("tracing_provider", type=str, required=True, location="json")
parser.add_argument("tracing_config", type=dict, required=True, location="json")
args = parser.parse_args()
try:
result = OpsService.create_tracing_app_config(
app_id=app_id, tracing_provider=args["tracing_provider"], tracing_config=args["tracing_config"]
)
if not result:
raise TracingConfigIsExist()
if result.get("error"):
raise TracingConfigCheckError()
return result
except Exception as e:
raise BadRequest(str(e))
@setup_required
@login_required
@account_initialization_required
def patch(self, app_id):
"""Update an existing trace app configuration"""
parser = reqparse.RequestParser()
parser.add_argument("tracing_provider", type=str, required=True, location="json")
parser.add_argument("tracing_config", type=dict, required=True, location="json")
args = parser.parse_args()
try:
result = OpsService.update_tracing_app_config(
app_id=app_id, tracing_provider=args["tracing_provider"], tracing_config=args["tracing_config"]
)
if not result:
raise TracingConfigNotExist()
return {"result": "success"}
except Exception as e:
raise BadRequest(str(e))
@setup_required
@login_required
@account_initialization_required
def delete(self, app_id):
"""Delete an existing trace app configuration"""
parser = reqparse.RequestParser()
parser.add_argument("tracing_provider", type=str, required=True, location="args")
args = parser.parse_args()
try:
result = OpsService.delete_tracing_app_config(app_id=app_id, tracing_provider=args["tracing_provider"])
if not result:
raise TracingConfigNotExist()
return {"result": "success"}, 204
except Exception as e:
raise BadRequest(str(e))
api.add_resource(TraceAppConfigApi, "/apps/<uuid:app_id>/trace-config")

View File

@@ -0,0 +1,111 @@
from datetime import UTC, datetime
from flask_login import current_user
from flask_restful import Resource, marshal_with, reqparse
from werkzeug.exceptions import Forbidden, NotFound
from constants.languages import supported_language
from controllers.console import api
from controllers.console.app.wraps import get_app_model
from controllers.console.wraps import account_initialization_required, setup_required
from extensions.ext_database import db
from fields.app_fields import app_site_fields
from libs.login import login_required
from models import Site
def parse_app_site_args():
parser = reqparse.RequestParser()
parser.add_argument("title", type=str, required=False, location="json")
parser.add_argument("icon_type", type=str, required=False, location="json")
parser.add_argument("icon", type=str, required=False, location="json")
parser.add_argument("icon_background", type=str, required=False, location="json")
parser.add_argument("description", type=str, required=False, location="json")
parser.add_argument("default_language", type=supported_language, required=False, location="json")
parser.add_argument("chat_color_theme", type=str, required=False, location="json")
parser.add_argument("chat_color_theme_inverted", type=bool, required=False, location="json")
parser.add_argument("customize_domain", type=str, required=False, location="json")
parser.add_argument("copyright", type=str, required=False, location="json")
parser.add_argument("privacy_policy", type=str, required=False, location="json")
parser.add_argument("custom_disclaimer", type=str, required=False, location="json")
parser.add_argument(
"customize_token_strategy", type=str, choices=["must", "allow", "not_allow"], required=False, location="json"
)
parser.add_argument("prompt_public", type=bool, required=False, location="json")
parser.add_argument("show_workflow_steps", type=bool, required=False, location="json")
parser.add_argument("use_icon_as_answer_icon", type=bool, required=False, location="json")
return parser.parse_args()
class AppSite(Resource):
@setup_required
@login_required
@account_initialization_required
@get_app_model
@marshal_with(app_site_fields)
def post(self, app_model):
args = parse_app_site_args()
# The role of the current user in the ta table must be editor, admin, or owner
if not current_user.is_editor:
raise Forbidden()
site = db.session.query(Site).filter(Site.app_id == app_model.id).first()
if not site:
raise NotFound
for attr_name in [
"title",
"icon_type",
"icon",
"icon_background",
"description",
"default_language",
"chat_color_theme",
"chat_color_theme_inverted",
"customize_domain",
"copyright",
"privacy_policy",
"custom_disclaimer",
"customize_token_strategy",
"prompt_public",
"show_workflow_steps",
"use_icon_as_answer_icon",
]:
value = args.get(attr_name)
if value is not None:
setattr(site, attr_name, value)
site.updated_by = current_user.id
site.updated_at = datetime.now(UTC).replace(tzinfo=None)
db.session.commit()
return site
class AppSiteAccessTokenReset(Resource):
@setup_required
@login_required
@account_initialization_required
@get_app_model
@marshal_with(app_site_fields)
def post(self, app_model):
# The role of the current user in the ta table must be admin or owner
if not current_user.is_admin_or_owner:
raise Forbidden()
site = db.session.query(Site).filter(Site.app_id == app_model.id).first()
if not site:
raise NotFound
site.code = Site.generate_code(16)
site.updated_by = current_user.id
site.updated_at = datetime.now(UTC).replace(tzinfo=None)
db.session.commit()
return site
api.add_resource(AppSite, "/apps/<uuid:app_id>/site")
api.add_resource(AppSiteAccessTokenReset, "/apps/<uuid:app_id>/site/access-token-reset")

Some files were not shown because too many files have changed in this diff Show More