Initial commit

This commit is contained in:
2025-10-14 14:17:21 +08:00
commit ac715a8b88
35011 changed files with 3834178 additions and 0 deletions

View File

View File

@@ -0,0 +1,17 @@
from typing import Optional
from werkzeug.exceptions import HTTPException
class BaseHTTPException(HTTPException):
    """HTTP exception that exposes a dict payload describing the error.

    The payload is stored on ``self.data`` and contains the keys ``code``
    (from ``error_code``), ``message`` (from ``description``) and ``status``
    (from ``code``).
    """

    # Identifier copied into the payload's "code" entry.
    error_code: str = "unknown"
    # Replaced per instance by ``__init__``; ``None`` only at class level.
    data: Optional[dict] = None

    def __init__(self, description=None, response=None):
        """Initialise the underlying exception, then build the payload."""
        super().__init__(description, response)
        payload = dict(
            code=self.error_code,
            message=self.description,
            status=self.code,
        )
        self.data = payload

View File

@@ -0,0 +1,119 @@
import re
import sys
from typing import Any
from flask import current_app, got_request_exception
from flask_restful import Api, http_status_message # type: ignore
from werkzeug.datastructures import Headers
from werkzeug.exceptions import HTTPException
from core.errors.error import AppInvokeQuotaExceededError
class ExternalApi(Api):
    """flask-restful ``Api`` that renders raised exceptions as dict payloads."""

    def handle_error(self, e):
        """Error handler for the API transforms a raised exception into a Flask
        response, with the appropriate HTTP status code and body.

        The payload defaults to ``{"code", "message", "status"}`` but is
        replaced by ``e.data`` when the exception carries one, and is updated
        from ``self.errors`` when the exception class name is registered there.

        :param e: the raised Exception object
        :type e: Exception
        :returns: the response built by ``make_response`` (or the exception's
            own response when it already has one)
        """
        got_request_exception.send(current_app, exception=e)

        headers = Headers()
        if isinstance(e, HTTPException):
            # An exception that already carries a response is returned untouched.
            if e.response is not None:
                resp = e.get_response()
                return resp

            status_code = e.code
            default_data = {
                # CamelCase class name -> snake_case error code.
                "code": re.sub(r"(?<!^)(?=[A-Z])", "_", type(e).__name__).lower(),
                "message": getattr(e, "description", http_status_message(status_code)),
                "status": status_code,
            }

            # Replace the raw decoder message with a friendlier one.
            if (
                default_data["message"]
                and default_data["message"] == "Failed to decode JSON object: Expecting value: line 1 column 1 (char 0)"
            ):
                default_data["message"] = "Invalid JSON payload received or JSON payload is empty."

            headers = e.get_response().headers
        elif isinstance(e, AppInvokeQuotaExceededError):
            # Checked BEFORE ValueError: the specific quota error must win.
            # NOTE(review): if AppInvokeQuotaExceededError derives from
            # ValueError, the previous ordering made this branch unreachable
            # and quota errors were reported as 400 — confirm the hierarchy.
            status_code = 429
            default_data = {
                "code": "too_many_requests",
                "message": str(e),
                "status": status_code,
            }
        elif isinstance(e, ValueError):
            status_code = 400
            default_data = {
                "code": "invalid_param",
                "message": str(e),
                "status": status_code,
            }
        else:
            status_code = 500
            default_data = {
                "message": http_status_message(status_code),
            }

        # Werkzeug exceptions generate a content-length header which is added
        # to the response in addition to the actual content-length header
        # https://github.com/flask-restful/flask-restful/issues/534
        remove_headers = ("Content-Length",)

        for header in remove_headers:
            headers.pop(header, None)

        # An exception-provided payload takes precedence over the default one.
        data = getattr(e, "data", default_data)

        error_cls_name = type(e).__name__
        if error_cls_name in self.errors:
            custom_data = self.errors.get(error_cls_name, {})
            custom_data = custom_data.copy()
            status_code = custom_data.get("status", 500)

            if "message" in custom_data:
                custom_data["message"] = custom_data["message"].format(
                    message=str(e.description if hasattr(e, "description") else e)
                )
            data.update(custom_data)

        # record the exception in the logs when we have a server error of status code: 500
        if status_code and status_code >= 500:
            exc_info: Any = sys.exc_info()
            if exc_info[1] is None:
                exc_info = None
            current_app.log_exception(exc_info)

        if status_code == 406 and self.default_mediatype is None:
            # if we are handling NotAcceptable (406), make sure that
            # make_response uses a representation we support as the
            # default mediatype (so that make_response doesn't throw
            # another NotAcceptable error).
            supported_mediatypes = list(self.representations.keys())  # only supported application/json
            fallback_mediatype = supported_mediatypes[0] if supported_mediatypes else "text/plain"
            data = {"code": "not_acceptable", "message": data.get("message")}
            resp = self.make_response(data, status_code, headers, fallback_mediatype=fallback_mediatype)
        elif status_code == 400:
            if isinstance(data.get("message"), dict):
                # Only the first (parameter, message) pair is reported.
                param_key, param_value = list(data.get("message", {}).items())[0]
                data = {"code": "invalid_param", "message": param_value, "params": param_key}
            else:
                if "code" not in data:
                    data["code"] = "unknown"

            resp = self.make_response(data, status_code, headers)
        else:
            if "code" not in data:
                data["code"] = "unknown"

            resp = self.make_response(data, status_code, headers)

        if status_code == 401:
            resp = self.unauthorized(resp)
        return resp

View File

@@ -0,0 +1,241 @@
#
# Cipher/PKCS1_OAEP.py : PKCS#1 OAEP
#
# ===================================================================
# The contents of this file are dedicated to the public domain. To
# the extent that dedication to the public domain is not available,
# everyone is granted a worldwide, perpetual, royalty-free,
# non-exclusive license to exercise all rights associated with the
# contents of this file for any purpose whatsoever.
# No rights are reserved.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
# EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
# MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
# NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS
# BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN
# ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
# CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
# ===================================================================
from hashlib import sha1
import Crypto.Hash.SHA1
import Crypto.Util.number
import gmpy2 # type: ignore
from Crypto import Random
from Crypto.Signature.pss import MGF1
from Crypto.Util.number import bytes_to_long, ceil_div, long_to_bytes
from Crypto.Util.py3compat import _copy_bytes, bord
from Crypto.Util.strxor import strxor
class PKCS1OAepCipher:
    """Cipher object for PKCS#1 OAEP (RFC 3447, section 7.1).

    The RSA exponentiations are performed with ``gmpy2.powmod``.
    Do not create directly: use :func:`new` instead."""

    def __init__(self, key, hashAlgo, mgfunc, label, randfunc):
        """Initialize this PKCS#1 OAEP cipher object.

        :Parameters:
         key : an RSA key object
                If a private half is given, both encryption and decryption are possible.
                If a public half is given, only encryption is possible.
         hashAlgo : hash object
                The hash function to use. This can be a module under `Crypto.Hash`
                or an existing hash object created from any of such modules. If not specified,
                `Crypto.Hash.SHA1` is used.
         mgfunc : callable
                A mask generation function that accepts two parameters: a string to
                use as seed, and the length of the mask to generate, in bytes.
                If not specified, the standard MGF1 consistent with ``hashAlgo`` is used (a safe choice).
         label : bytes/bytearray/memoryview
                A label to apply to this particular encryption. If not specified,
                an empty string is used. Specifying a label does not improve
                security.
         randfunc : callable
                A function that returns random bytes.

        :attention: Modify the mask generation function only if you know what you are doing.
                    Sender and receiver must use the same one.
        """
        self._key = key

        if hashAlgo:
            self._hashObj = hashAlgo
        else:
            self._hashObj = Crypto.Hash.SHA1

        if mgfunc:
            self._mgf = mgfunc
        else:
            self._mgf = lambda x, y: MGF1(x, y, self._hashObj)

        self._label = _copy_bytes(None, None, label)
        self._randfunc = randfunc

    def can_encrypt(self):
        """Legacy function to check if you can call :meth:`encrypt`.

        .. deprecated:: 3.0"""
        return self._key.can_encrypt()

    def can_decrypt(self):
        """Legacy function to check if you can call :meth:`decrypt`.

        .. deprecated:: 3.0"""
        return self._key.can_decrypt()

    def encrypt(self, message):
        """Encrypt a message with PKCS#1 OAEP.

        :param message:
            The message to encrypt, also known as plaintext. It can be of
            variable length, but not longer than the RSA modulus (in bytes)
            minus 2, minus twice the hash output size.
            For instance, if you use RSA 2048 and SHA-256, the longest message
            you can encrypt is 190 byte long.
        :type message: bytes/bytearray/memoryview

        :returns: The ciphertext, as large as the RSA modulus.
        :rtype: bytes

        :raises ValueError:
            if the message is too long.
        """
        # See 7.1.1 in RFC3447
        modBits = Crypto.Util.number.size(self._key.n)
        k = ceil_div(modBits, 8)  # Convert from bits to bytes
        hLen = self._hashObj.digest_size
        mLen = len(message)

        # Step 1b
        ps_len = k - mLen - 2 * hLen - 2
        if ps_len < 0:
            raise ValueError("Plaintext is too long.")
        # Step 2a
        # The label hash must come from the configured hash so that its length
        # equals hLen; a fixed SHA-1 digest only lines up when hashAlgo is SHA-1.
        lHash = self._hashObj.new(self._label).digest()
        # Step 2b
        ps = b"\x00" * ps_len
        # Step 2c
        db = lHash + ps + b"\x01" + _copy_bytes(None, None, message)
        # Step 2d
        ros = self._randfunc(hLen)
        # Step 2e
        dbMask = self._mgf(ros, k - hLen - 1)
        # Step 2f
        maskedDB = strxor(db, dbMask)
        # Step 2g
        seedMask = self._mgf(maskedDB, hLen)
        # Step 2h
        maskedSeed = strxor(ros, seedMask)
        # Step 2i
        em = b"\x00" + maskedSeed + maskedDB
        # Step 3a (OS2IP)
        em_int = bytes_to_long(em)
        # Step 3b (RSAEP)
        m_int = gmpy2.powmod(em_int, self._key.e, self._key.n)
        # Step 3c (I2OSP)
        c = long_to_bytes(m_int, k)
        return c

    def decrypt(self, ciphertext):
        """Decrypt a message with PKCS#1 OAEP.

        :param ciphertext: The encrypted message.
        :type ciphertext: bytes/bytearray/memoryview

        :returns: The original message (plaintext).
        :rtype: bytes

        :raises ValueError:
            if the ciphertext has the wrong length, or if decryption
            fails the integrity check (in which case, the decryption
            key is probably wrong).
        :raises TypeError:
            if the RSA key has no private half (i.e. you are trying
            to decrypt using a public key).
        """
        # See 7.1.2 in RFC3447
        modBits = Crypto.Util.number.size(self._key.n)
        k = ceil_div(modBits, 8)  # Convert from bits to bytes
        hLen = self._hashObj.digest_size

        # Step 1b and 1c
        if len(ciphertext) != k or k < hLen + 2:
            raise ValueError("Ciphertext with incorrect length.")
        # Step 2a (O2SIP)
        ct_int = bytes_to_long(ciphertext)
        # Step 2b (RSADP) — computed directly with gmpy2 rather than through
        # self._key._decrypt(ct_int).
        # NOTE(review): confirm that bypassing the key object's own decrypt
        # routine (and whatever protections it applies) is acceptable here.
        m_int = gmpy2.powmod(ct_int, self._key.d, self._key.n)
        # Complete step 2c (I2OSP)
        em = long_to_bytes(m_int, k)
        # Step 3a — same hash as in encrypt(), so lHash is hLen bytes long.
        lHash = self._hashObj.new(self._label).digest()
        # Step 3b
        y = em[0]
        # y must be 0, but we MUST NOT check it here in order not to
        # allow attacks like Manger's (http://dl.acm.org/citation.cfm?id=704143)
        maskedSeed = em[1 : hLen + 1]
        maskedDB = em[hLen + 1 :]
        # Step 3c
        seedMask = self._mgf(maskedDB, hLen)
        # Step 3d
        seed = strxor(maskedSeed, seedMask)
        # Step 3e
        dbMask = self._mgf(seed, k - hLen - 1)
        # Step 3f
        db = strxor(maskedDB, dbMask)
        # Step 3g
        # find() returns -1 when no 0x01 separator exists, which makes
        # one_pos < hLen and therefore flags the padding as invalid.
        one_pos = hLen + db[hLen:].find(b"\x01")
        lHash1 = db[:hLen]
        invalid = bord(y) | int(one_pos < hLen)  # type: ignore
        hash_compare = strxor(lHash1, lHash)
        for x in hash_compare:
            invalid |= bord(x)  # type: ignore
        for x in db[hLen:one_pos]:
            invalid |= bord(x)  # type: ignore
        if invalid != 0:
            raise ValueError("Incorrect decryption.")
        # Step 4
        return db[one_pos + 1 :]
def new(key, hashAlgo=None, mgfunc=None, label=b"", randfunc=None):
    """Return a cipher object :class:`PKCS1OAepCipher`
    that can be used to perform PKCS#1 OAEP encryption or decryption.

    :param key:
      The key object to use to encrypt or decrypt the message.
      Decryption is only possible with a private RSA key.
    :type key: RSA key object

    :param hashAlgo:
      The hash function to use. This can be a module under `Crypto.Hash`
      or an existing hash object created from any of such modules.
      If not specified, `Crypto.Hash.SHA1` is used.
    :type hashAlgo: hash object

    :param mgfunc:
      A mask generation function that accepts two parameters: a string to
      use as seed, and the length of the mask to generate, in bytes.
      If not specified, the standard MGF1 consistent with ``hashAlgo`` is used (a safe choice).
    :type mgfunc: callable

    :param label:
      A label to apply to this particular encryption. If not specified,
      an empty string is used. Specifying a label does not improve
      security.
    :type label: bytes/bytearray/memoryview

    :param randfunc:
      A function that returns random bytes.
      The default is `Random.get_random_bytes`.
    :type randfunc: callable
    """
    random_source = Random.get_random_bytes if randfunc is None else randfunc
    return PKCS1OAepCipher(key, hashAlgo, mgfunc, label, random_source)

View File

@@ -0,0 +1,311 @@
import json
import logging
import random
import re
import string
import subprocess
import time
import uuid
from collections.abc import Generator, Mapping
from datetime import datetime
from hashlib import sha256
from typing import TYPE_CHECKING, Any, Optional, Union, cast
from zoneinfo import available_timezones
from flask import Response, stream_with_context
from flask_restful import fields # type: ignore
from configs import dify_config
from core.app.features.rate_limiting.rate_limit import RateLimitGenerator
from core.file import helpers as file_helpers
from extensions.ext_redis import redis_client
if TYPE_CHECKING:
from models.account import Account
def run(script):
    """Run ``script`` through the shell after sourcing ``/root/.bashrc``.

    Returns the ``(exit_status, output)`` tuple produced by
    ``subprocess.getstatusoutput``.

    NOTE(review): ``script`` is concatenated into a shell command line
    unescaped — callers must never pass untrusted text.
    """
    command_line = "source /root/.bashrc && " + script
    return subprocess.getstatusoutput(command_line)
class AppIconUrlField(fields.Raw):
    """Output field yielding a signed URL for an app's or site's image icon."""

    def output(self, key, obj):
        """Return the signed icon URL, or ``None`` when there is no image icon.

        ``obj`` may be an ``App``/``Site`` instance or a dict holding one
        under the ``"app"`` key.
        """
        if obj is None:
            return None

        # Imported here rather than at module level, as in the original.
        from models.model import App, IconType, Site

        target = obj["app"] if isinstance(obj, dict) and "app" in obj else obj
        has_image_icon = isinstance(target, (App, Site)) and target.icon_type == IconType.IMAGE.value
        if not has_image_icon:
            return None
        return file_helpers.get_signed_file_url(target.icon)
class AvatarUrlField(fields.Raw):
    """Output field yielding a signed URL for an account's avatar."""

    def output(self, key, obj):
        """Return the signed avatar URL, or ``None`` when unavailable."""
        if obj is None:
            return None

        # Imported here rather than at module level, as in the original.
        from models.account import Account

        if not isinstance(obj, Account) or obj.avatar is None:
            return None
        return file_helpers.get_signed_file_url(obj.avatar)
class TimestampField(fields.Raw):
    """Output field rendering a value as whole seconds from ``.timestamp()``."""

    def format(self, value) -> int:
        """Return ``value.timestamp()`` truncated to an ``int``."""
        seconds = value.timestamp()
        return int(seconds)
def email(email):
    """Validate an e-mail address and return it unchanged.

    :param email: candidate address.
    :returns: the same string when it matches the accepted pattern.
    :raises ValueError: if the address does not match.
    """
    # Define a regex pattern for email addresses
    pattern = r"^[\w\.!#$%&'*+\-/=?^_`{|}~]+@([\w-]+\.)+[\w-]{2,}$"
    # fullmatch() is required: with match(), "$" also matches just before a
    # trailing newline, so "user@example.com\n" used to be accepted.
    if re.fullmatch(pattern, email) is not None:
        return email

    error = "{email} is not a valid email.".format(email=email)
    raise ValueError(error)
def uuid_value(value):
    """Validate a UUID string and return it in canonical hyphenated form.

    The empty string is passed through unchanged.

    :param value: candidate UUID text.
    :returns: ``str(uuid.UUID(value))``, or ``""`` for empty input.
    :raises ValueError: if ``value`` cannot be parsed as a UUID.
    """
    if value == "":
        return str(value)

    try:
        uuid_obj = uuid.UUID(value)
    except (ValueError, AttributeError, TypeError):
        # uuid.UUID raises AttributeError/TypeError for non-string input
        # (e.g. an int or None); report those as a validation failure too.
        error = "{value} is not a valid uuid.".format(value=value)
        raise ValueError(error) from None
    return str(uuid_obj)
def alphanumeric(value: str):
    """Validate that ``value`` consists only of ASCII letters, digits and ``_``.

    :param value: candidate string.
    :returns: ``value`` unchanged when valid.
    :raises ValueError: if ``value`` is empty or contains any other character.
    """
    # fullmatch() instead of match(): "$" would otherwise also match before a
    # trailing newline and let "abc\n" through.
    if re.fullmatch(r"^[a-zA-Z0-9_]+$", value):
        return value

    raise ValueError(f"{value} is not a valid alphanumeric value")
def timestamp_value(timestamp):
    """Validate a non-negative integer timestamp.

    :param timestamp: value convertible with ``int()``.
    :returns: the timestamp as an ``int``.
    :raises ValueError: if the value cannot be converted or is negative.
    """
    error = "{timestamp} is not a valid timestamp.".format(timestamp=timestamp)
    try:
        int_timestamp = int(timestamp)
    except (TypeError, ValueError):
        # TypeError covers None and other non-numeric objects, matching how
        # _get_float treats unconvertible input.
        raise ValueError(error) from None
    if int_timestamp < 0:
        raise ValueError(error)
    return int_timestamp
class StrLen:
    """Restrict input to values whose length does not exceed ``max_length``."""

    def __init__(self, max_length, argument="argument"):
        """Store the length limit and the argument name used in error text."""
        self.max_length = max_length
        self.argument = argument

    def __call__(self, value):
        """Return ``value`` unchanged, or raise ``ValueError`` if it is too long."""
        if len(value) > self.max_length:
            raise ValueError(
                f"Invalid {self.argument}: {value}. {self.argument} cannot exceed length {self.max_length}"
            )
        return value
class FloatRange:
    """Restrict input to a float in a range (inclusive)."""

    def __init__(self, low, high, argument="argument"):
        """Store the inclusive bounds and the argument name used in error text."""
        self.low = low
        self.high = high
        self.argument = argument

    def __call__(self, value):
        """Convert ``value`` with ``_get_float`` and check it against the bounds.

        Returns the converted float; raises ``ValueError`` when it lies
        outside ``[low, high]`` or cannot be converted.
        """
        number = _get_float(value)
        out_of_range = number < self.low or number > self.high
        if out_of_range:
            raise ValueError(
                f"Invalid {self.argument}: {number}. {self.argument} must be within the range {self.low} - {self.high}"
            )
        return number
class DatetimeString:
    """Validate that a string parses with a given ``strptime`` format."""

    def __init__(self, format, argument="argument"):
        """Store the ``strptime`` format and the argument name used in error text."""
        self.format = format
        self.argument = argument

    def __call__(self, value):
        """Return ``value`` unchanged if it parses; raise ``ValueError`` otherwise."""
        try:
            datetime.strptime(value, self.format)
        except ValueError:
            raise ValueError(
                f"Invalid {self.argument}: {value}. {self.argument} must be conform to the format {self.format}"
            )
        return value
def _get_float(value):
    """Convert ``value`` to ``float``; raise ``ValueError`` when impossible."""
    try:
        number = float(value)
    except (TypeError, ValueError):
        raise ValueError(f"{value} is not a valid float")
    return number
def timezone(timezone_string):
    """Return ``timezone_string`` if ``zoneinfo`` lists it as available.

    Raises ``ValueError`` for empty or unknown names.
    """
    is_known = bool(timezone_string) and timezone_string in available_timezones()
    if not is_known:
        raise ValueError(f"{timezone_string} is not a valid timezone.")
    return timezone_string
def generate_string(n):
    """Return ``n`` characters drawn with ``random.choice`` from ASCII letters and digits."""
    alphabet = string.ascii_letters + string.digits
    return "".join(random.choice(alphabet) for _ in range(n))
def extract_remote_ip(request) -> str:
    """Pick the client address for ``request``.

    Order of preference: the ``CF-Connecting-IP`` header, the first
    ``X-Forwarded-For`` header value, then ``request.remote_addr``.
    """
    headers = request.headers
    if headers.get("CF-Connecting-IP"):
        return cast(str, headers.get("Cf-Connecting-Ip"))

    forwarded_values = headers.getlist("X-Forwarded-For")
    if forwarded_values:
        return cast(str, forwarded_values[0])

    return cast(str, request.remote_addr)
def generate_text_hash(text: str) -> str:
    """Return the SHA-256 hex digest of ``str(text)`` followed by the literal ``"None"``."""
    salted = "".join((str(text), "None"))
    digest = sha256(salted.encode())
    return digest.hexdigest()
def compact_generate_response(response: Union[Mapping, Generator, RateLimitGenerator]) -> Response:
    """Wrap ``response`` in a Flask ``Response``.

    A ``dict`` is serialised with ``json.dumps`` as ``application/json``;
    anything else is iterated and streamed as ``text/event-stream``.
    """
    if not isinstance(response, dict):

        def generate() -> Generator:
            yield from response

        return Response(stream_with_context(generate()), status=200, mimetype="text/event-stream")

    body = json.dumps(response)
    return Response(response=body, status=200, mimetype="application/json")
class TokenManager:
    """Issue, look up and revoke short-lived tokens stored in Redis.

    Two kinds of keys are written: ``{token_type}:token:{token}`` holding the
    JSON token data, and ``{token_type}:account:{account_id}`` pointing at the
    account's current token so it can be revoked when a new one is issued.
    """

    @classmethod
    def generate_token(
        cls,
        token_type: str,
        account: Optional["Account"] = None,
        email: Optional[str] = None,
        additional_data: Optional[dict] = None,
    ) -> str:
        """Create a new token and store its data in Redis.

        :param token_type: token category; also selects the config entry
            ``{TOKEN_TYPE}_TOKEN_EXPIRY_MINUTES``.
        :param account: account the token belongs to (its ``id``/``email`` are stored).
        :param email: e-mail stored when no account is given.
        :param additional_data: extra entries merged into the stored data.
        :returns: the new token (a UUID4 string).
        :raises ValueError: if neither account nor email is given, or the
            expiry setting for ``token_type`` is missing.
        """
        if account is None and email is None:
            raise ValueError("Account or email must be provided")

        account_id = account.id if account else None
        account_email = account.email if account else email

        if account_id:
            # Only one live token per account and type: revoke the previous one.
            old_token = cls._get_current_token_for_account(account_id, token_type)
            if old_token:
                if isinstance(old_token, bytes):
                    old_token = old_token.decode("utf-8")
                cls.revoke_token(old_token, token_type)

        token = str(uuid.uuid4())
        token_data = {"account_id": account_id, "email": account_email, "token_type": token_type}
        if additional_data:
            token_data.update(additional_data)

        expiry_minutes = dify_config.model_dump().get(f"{token_type.upper()}_TOKEN_EXPIRY_MINUTES")
        if expiry_minutes is None:
            raise ValueError(f"Expiry minutes for {token_type} token is not set")
        token_key = cls._get_token_key(token, token_type)
        expiry_time = int(expiry_minutes * 60)
        redis_client.setex(token_key, expiry_time, json.dumps(token_data))

        if account_id:
            # Store the account pointer with the SAME lifetime as the token.
            # Previously the minutes value was handed to
            # _set_current_token_for_account, which treats its argument as
            # hours, so the pointer outlived the token by a factor of 60.
            account_key = cls._get_account_token_key(account_id, token_type)
            redis_client.setex(account_key, expiry_time, token)

        return token

    @classmethod
    def _get_token_key(cls, token: str, token_type: str) -> str:
        """Return the Redis key holding the data of ``token``."""
        return f"{token_type}:token:{token}"

    @classmethod
    def revoke_token(cls, token: str, token_type: str):
        """Delete the stored data of ``token``."""
        token_key = cls._get_token_key(token, token_type)
        redis_client.delete(token_key)

    @classmethod
    def get_token_data(cls, token: str, token_type: str) -> Optional[dict[str, Any]]:
        """Return the stored data of ``token``, or ``None`` (with a warning logged) if absent."""
        key = cls._get_token_key(token, token_type)
        token_data_json = redis_client.get(key)
        if token_data_json is None:
            logging.warning(f"{token_type} token {token} not found with key {key}")
            return None
        token_data: Optional[dict[str, Any]] = json.loads(token_data_json)
        return token_data

    @classmethod
    def _get_current_token_for_account(cls, account_id: str, token_type: str) -> Optional[str]:
        """Return the token currently recorded for the account, if any."""
        key = cls._get_account_token_key(account_id, token_type)
        current_token: Optional[str] = redis_client.get(key)
        return current_token

    @classmethod
    def _set_current_token_for_account(
        cls, account_id: str, token: str, token_type: str, expiry_hours: Union[int, float]
    ):
        """Record ``token`` as the account's current token for ``expiry_hours`` hours."""
        key = cls._get_account_token_key(account_id, token_type)
        expiry_time = int(expiry_hours * 60 * 60)
        redis_client.setex(key, expiry_time, token)

    @classmethod
    def _get_account_token_key(cls, account_id: str, token_type: str) -> str:
        """Return the Redis key pointing at the account's current token."""
        return f"{token_type}:account:{account_id}"
class RateLimiter:
    """Sliding-window attempt counter backed by a Redis sorted set.

    Each attempt is stored as one sorted-set member scored with its Unix
    timestamp; attempts older than ``time_window`` seconds are pruned before
    counting.
    """

    def __init__(self, prefix: str, max_attempts: int, time_window: int):
        """
        :param prefix: key namespace for this limiter.
        :param max_attempts: number of attempts at which the limit is reached.
        :param time_window: window length in seconds.
        """
        self.prefix = prefix
        self.max_attempts = max_attempts
        self.time_window = time_window

    def _get_key(self, email: str) -> str:
        """Return the Redis key tracking attempts for ``email``."""
        return f"{self.prefix}:{email}"

    def is_rate_limited(self, email: str) -> bool:
        """Return ``True`` when ``email`` has reached ``max_attempts`` within the window."""
        key = self._get_key(email)
        current_time = int(time.time())
        window_start_time = current_time - self.time_window

        # Drop attempts that fell out of the window, then count what is left.
        redis_client.zremrangebyscore(key, "-inf", window_start_time)
        attempts = redis_client.zcard(key)

        if attempts and int(attempts) >= self.max_attempts:
            return True
        return False

    def increment_rate_limit(self, email: str):
        """Record one attempt for ``email``."""
        key = self._get_key(email)
        current_time = int(time.time())

        # The member must be unique per attempt. Using the integer second as
        # the member made every attempt within the same second overwrite the
        # previous one, so bursts were counted as a single attempt.
        member = f"{current_time}:{uuid.uuid4().hex}"
        redis_client.zadd(key, {member: current_time})
        redis_client.expire(key, self.time_window * 2)

View File

@@ -0,0 +1,5 @@
class InfiniteScrollPagination:
    """Plain container for one page of an infinite-scroll listing."""

    def __init__(self, data, limit, has_more):
        """Store the page ``data``, the requested ``limit`` and the ``has_more`` flag."""
        self.data, self.limit, self.has_more = data, limit, has_more

View File

@@ -0,0 +1,46 @@
import json
from core.llm_generator.output_parser.errors import OutputParserError
def parse_json_markdown(json_string: str) -> dict:
    """Extract and parse the JSON found between backtick fences or braces.

    The first opening marker that occurs in the text (tried in the order
    "```json", "```", "``", "`", "{") fixes the start; the last occurrence of
    the first closing marker found after it ("```", "``", "`", "}") fixes the
    end. Braces are kept as part of the extracted text, fences are not.

    Raises ``ValueError`` when no such span exists; ``json.loads`` errors
    propagate unchanged.
    """
    text = json_string.strip()

    begin = -1
    for opener in ("```json", "```", "``", "`", "{"):
        begin = text.find(opener)
        if begin == -1:
            continue
        if text[begin] != "{":
            # Skip past the fence itself; a brace belongs to the payload.
            begin += len(opener)
        break

    finish = -1
    if begin != -1:
        for closer in ("```", "``", "`", "}"):
            finish = text.rfind(closer, begin)
            if finish == -1:
                continue
            if text[finish] == "}":
                # Include the closing brace in the slice.
                finish += 1
            break

    if begin == -1 or finish == -1 or begin >= finish:
        raise ValueError("could not find json block in the output.")

    return json.loads(text[begin:finish].strip())
def parse_and_check_json_markdown(text: str, expected_keys: list[str]) -> dict:
    """Parse ``text`` with ``parse_json_markdown`` and require ``expected_keys``.

    Raises ``OutputParserError`` when the JSON cannot be decoded or when any
    expected key is absent from the parsed object.
    """
    try:
        json_obj = parse_json_markdown(text)
    except json.JSONDecodeError as e:
        raise OutputParserError(f"got invalid json object. error: {e}")

    for required in expected_keys:
        if required in json_obj:
            continue
        raise OutputParserError(
            f"got invalid return object. expected key `{required}` to be present, but got {json_obj}"
        )
    return json_obj

View File

@@ -0,0 +1,107 @@
from functools import wraps
from typing import Any
from flask import current_app, g, has_request_context, request
from flask_login import user_logged_in # type: ignore
from flask_login.config import EXEMPT_METHODS # type: ignore
from werkzeug.exceptions import Unauthorized
from werkzeug.local import LocalProxy
from configs import dify_config
from extensions.ext_database import db
from models.account import Account, Tenant, TenantAccountJoin
from models.model import EndUser
#: A proxy for the current user. If no user is logged in, this will be an
#: anonymous user. The proxy resolves through ``_get_user()`` on access,
#: which returns ``None`` when there is no request context.
current_user: Any = LocalProxy(lambda: _get_user())
def login_required(func):
    """
    If you decorate a view with this, it will ensure that the current user is
    logged in and authenticated before calling the actual view. (If they are
    not, it calls the :attr:`LoginManager.unauthorized` callback.) For
    example::

        @app.route('/post')
        @login_required
        def post():
            pass

    If there are only certain times you need to require that your user is
    logged in, you can do so with::

        if not current_user.is_authenticated:
            return current_app.login_manager.unauthorized()

    ...which is essentially the code that this function adds to your views.

    It can be convenient to globally turn off authentication when unit testing.
    To enable this, if the application configuration variable `LOGIN_DISABLED`
    is set to `True`, this decorator will be ignored.

    When ``ADMIN_API_KEY_ENABLE`` is set and the request carries a
    ``Bearer`` token equal to ``ADMIN_API_KEY`` together with an
    ``X-WORKSPACE-ID`` header, the owner account of that workspace is logged
    in for the request before the check runs.

    .. Note ::

        Per `W3 guidelines for CORS preflight requests
        <http://www.w3.org/TR/cors/#cross-origin-request-with-preflight-0>`_,
        HTTP ``OPTIONS`` requests are exempt from login checks.

    :param func: The view function to decorate.
    :type func: function
    :raises Unauthorized: if the admin API key feature is enabled and the
        Authorization header is not of the form ``Bearer <api-key>``.
    """

    @wraps(func)
    def decorated_view(*args, **kwargs):
        auth_header = request.headers.get("Authorization")
        if dify_config.ADMIN_API_KEY_ENABLE:
            if auth_header:
                if " " not in auth_header:
                    raise Unauthorized("Invalid Authorization header format. Expected 'Bearer <api-key>' format.")
                # A header such as "Bearer " contains a space but splits into
                # a single element; reject it instead of letting the tuple
                # unpacking below fail with ValueError.
                header_parts = auth_header.split(None, 1)
                if len(header_parts) != 2:
                    raise Unauthorized("Invalid Authorization header format. Expected 'Bearer <api-key>' format.")
                auth_scheme, auth_token = header_parts
                auth_scheme = auth_scheme.lower()
                if auth_scheme != "bearer":
                    raise Unauthorized("Invalid Authorization header format. Expected 'Bearer <api-key>' format.")

                admin_api_key = dify_config.ADMIN_API_KEY
                if admin_api_key:
                    if admin_api_key == auth_token:
                        workspace_id = request.headers.get("X-WORKSPACE-ID")
                        if workspace_id:
                            # Look up the workspace together with its owner membership.
                            tenant_account_join = (
                                db.session.query(Tenant, TenantAccountJoin)
                                .filter(Tenant.id == workspace_id)
                                .filter(TenantAccountJoin.tenant_id == Tenant.id)
                                .filter(TenantAccountJoin.role == "owner")
                                .one_or_none()
                            )
                            if tenant_account_join:
                                tenant, ta = tenant_account_join
                                account = db.session.query(Account).filter_by(id=ta.account_id).first()
                                # Login admin
                                if account:
                                    account.current_tenant = tenant
                                    current_app.login_manager._update_request_context_with_user(account)  # type: ignore
                                    user_logged_in.send(current_app._get_current_object(), user=_get_user())  # type: ignore
        if request.method in EXEMPT_METHODS or dify_config.LOGIN_DISABLED:
            pass
        elif not current_user.is_authenticated:
            return current_app.login_manager.unauthorized()  # type: ignore

        # flask 1.x compatibility
        # current_app.ensure_sync is only available in Flask >= 2.0
        if callable(getattr(current_app, "ensure_sync", None)):
            return current_app.ensure_sync(func)(*args, **kwargs)
        return func(*args, **kwargs)

    return decorated_view
def _get_user() -> EndUser | Account | None:
    """Return the user stored on ``g`` for this request, loading it on first use.

    Returns ``None`` when called outside a request context.
    """
    if not has_request_context():
        return None

    if "_login_user" not in g:
        current_app.login_manager._load_user()  # type: ignore

    return g._login_user  # type: ignore

View File

@@ -0,0 +1,133 @@
import urllib.parse
from dataclasses import dataclass
from typing import Optional
import requests
@dataclass
class OAuthUserInfo:
    """User profile fields returned by ``OAuth.get_user_info``."""

    # Provider-side user identifier, as a string.
    id: str
    # Display name.
    name: str
    # E-mail address.
    email: str
class OAuth:
    """Base class for OAuth login providers.

    Subclasses implement the provider-specific methods; ``get_user_info``
    combines ``get_raw_user_info`` and ``_transform_user_info``.
    """

    def __init__(self, client_id: str, client_secret: str, redirect_uri: str):
        """Store the OAuth client credentials and the redirect URI."""
        self.client_id = client_id
        self.client_secret = client_secret
        self.redirect_uri = redirect_uri

    def get_authorization_url(self):
        """Build the provider's authorization URL (implemented by subclasses)."""
        raise NotImplementedError()

    def get_access_token(self, code: str):
        """Exchange ``code`` for an access token (implemented by subclasses)."""
        raise NotImplementedError()

    def get_raw_user_info(self, token: str):
        """Fetch the provider's raw user profile (implemented by subclasses)."""
        raise NotImplementedError()

    def get_user_info(self, token: str) -> OAuthUserInfo:
        """Fetch the raw profile for ``token`` and convert it to ``OAuthUserInfo``."""
        return self._transform_user_info(self.get_raw_user_info(token))

    def _transform_user_info(self, raw_info: dict) -> OAuthUserInfo:
        """Convert a raw profile into ``OAuthUserInfo`` (implemented by subclasses)."""
        raise NotImplementedError()
class GitHubOAuth(OAuth):
    """GitHub implementation of the OAuth login flow."""

    _AUTH_URL = "https://github.com/login/oauth/authorize"
    _TOKEN_URL = "https://github.com/login/oauth/access_token"
    _USER_INFO_URL = "https://api.github.com/user"
    _EMAIL_INFO_URL = "https://api.github.com/user/emails"

    def get_authorization_url(self, invite_token: Optional[str] = None):
        """Return the GitHub authorization URL.

        :param invite_token: when given, sent as the ``state`` parameter.
        """
        params = {
            "client_id": self.client_id,
            "redirect_uri": self.redirect_uri,
            "scope": "user:email",  # Request only basic user information
        }
        if invite_token:
            params["state"] = invite_token
        return f"{self._AUTH_URL}?{urllib.parse.urlencode(params)}"

    def get_access_token(self, code: str):
        """Exchange the authorization ``code`` for an access token.

        :raises ValueError: if the response carries no ``access_token``.
        """
        data = {
            "client_id": self.client_id,
            "client_secret": self.client_secret,
            "code": code,
            "redirect_uri": self.redirect_uri,
        }
        headers = {"Accept": "application/json"}
        response = requests.post(self._TOKEN_URL, data=data, headers=headers)

        response_json = response.json()
        access_token = response_json.get("access_token")

        if not access_token:
            raise ValueError(f"Error in GitHub OAuth: {response_json}")

        return access_token

    def get_raw_user_info(self, token: str):
        """Return the user profile merged with the primary e-mail address.

        The ``email`` entry is ``""`` when no primary address can be
        determined; ``_transform_user_info`` then substitutes a noreply
        address.
        """
        headers = {"Authorization": f"token {token}"}
        response = requests.get(self._USER_INFO_URL, headers=headers)
        response.raise_for_status()
        user_info = response.json()

        email_response = requests.get(self._EMAIL_INFO_URL, headers=headers)
        email_info = email_response.json()
        # The endpoint answers with a list of address objects on success; an
        # error answer is a JSON object, and iterating that used to raise
        # TypeError. Treat anything that is not a list as "no e-mail known".
        email_entries = email_info if isinstance(email_info, list) else []
        primary_email: dict = next(
            (entry for entry in email_entries if isinstance(entry, dict) and entry.get("primary")),
            {},
        )

        return {**user_info, "email": primary_email.get("email", "")}

    def _transform_user_info(self, raw_info: dict) -> OAuthUserInfo:
        """Convert GitHub's raw profile into ``OAuthUserInfo``."""
        email = raw_info.get("email")
        if not email:
            email = f"{raw_info['id']}+{raw_info['login']}@users.noreply.github.com"
        # NOTE(review): GitHub profiles may have no display name (null);
        # fall back to the login so ``name`` is always a string — confirm
        # this is the desired default.
        name = raw_info.get("name") or raw_info["login"]
        return OAuthUserInfo(id=str(raw_info["id"]), name=name, email=email)
class GoogleOAuth(OAuth):
    """Google implementation of the OAuth login flow."""

    _AUTH_URL = "https://accounts.google.com/o/oauth2/v2/auth"
    _TOKEN_URL = "https://oauth2.googleapis.com/token"
    _USER_INFO_URL = "https://www.googleapis.com/oauth2/v3/userinfo"

    def get_authorization_url(self, invite_token: Optional[str] = None):
        """Return the Google authorization URL; ``invite_token`` becomes ``state``."""
        params = {
            "client_id": self.client_id,
            "response_type": "code",
            "redirect_uri": self.redirect_uri,
            "scope": "openid email",
        }
        if invite_token:
            params["state"] = invite_token
        query = urllib.parse.urlencode(params)
        return f"{self._AUTH_URL}?{query}"

    def get_access_token(self, code: str):
        """Exchange ``code`` for an access token; raise ``ValueError`` if none is returned."""
        response = requests.post(
            self._TOKEN_URL,
            data={
                "client_id": self.client_id,
                "client_secret": self.client_secret,
                "code": code,
                "grant_type": "authorization_code",
                "redirect_uri": self.redirect_uri,
            },
            headers={"Accept": "application/json"},
        )
        payload = response.json()
        token = payload.get("access_token")
        if token:
            return token
        raise ValueError(f"Error in Google OAuth: {payload}")

    def get_raw_user_info(self, token: str):
        """Return the JSON body of the userinfo endpoint for ``token``."""
        response = requests.get(self._USER_INFO_URL, headers={"Authorization": f"Bearer {token}"})
        response.raise_for_status()
        return response.json()

    def _transform_user_info(self, raw_info: dict) -> OAuthUserInfo:
        """Map ``sub`` and ``email`` onto ``OAuthUserInfo``; the name is left empty."""
        subject = str(raw_info["sub"])
        return OAuthUserInfo(id=subject, name="", email=raw_info["email"])

View File

@@ -0,0 +1,303 @@
import datetime
import urllib.parse
from typing import Any
import requests
from flask_login import current_user # type: ignore
from extensions.ext_database import db
from models.source import DataSourceOauthBinding
class OAuthDataSource:
    """Base class for OAuth-backed data source integrations."""

    def __init__(self, client_id: str, client_secret: str, redirect_uri: str):
        """Store the OAuth client credentials and the redirect URI."""
        self.client_id, self.client_secret = client_id, client_secret
        self.redirect_uri = redirect_uri

    def get_authorization_url(self):
        """Build the provider's authorization URL (implemented by subclasses)."""
        raise NotImplementedError()

    def get_access_token(self, code: str):
        """Exchange ``code`` for an access token (implemented by subclasses)."""
        raise NotImplementedError()
class NotionOAuth(OAuthDataSource):
_AUTH_URL = "https://api.notion.com/v1/oauth/authorize"
_TOKEN_URL = "https://api.notion.com/v1/oauth/token"
_NOTION_PAGE_SEARCH = "https://api.notion.com/v1/search"
_NOTION_BLOCK_SEARCH = "https://api.notion.com/v1/blocks"
_NOTION_BOT_USER = "https://api.notion.com/v1/users/me"
def get_authorization_url(self):
    """Return the Notion authorization URL for this client and redirect URI."""
    query = urllib.parse.urlencode(
        {
            "client_id": self.client_id,
            "response_type": "code",
            "redirect_uri": self.redirect_uri,
            "owner": "user",
        }
    )
    return f"{self._AUTH_URL}?{query}"
def get_access_token(self, code: str):
    """Exchange ``code`` for a Notion access token and persist the binding.

    The token, workspace details from the token response and the list of
    authorized pages are stored on the tenant's ``notion`` binding: an
    existing binding with the same token is refreshed, otherwise a new one
    is created. Raises ``ValueError`` when no access token is returned.
    """
    response = requests.post(
        self._TOKEN_URL,
        data={"code": code, "grant_type": "authorization_code", "redirect_uri": self.redirect_uri},
        auth=(self.client_id, self.client_secret),
        headers={"Accept": "application/json"},
    )
    response_json = response.json()
    access_token = response_json.get("access_token")
    if not access_token:
        raise ValueError(f"Error in Notion OAuth: {response_json}")

    # get all authorized pages
    pages = self.get_authorized_pages(access_token)
    source_info = {
        "workspace_name": response_json.get("workspace_name"),
        "workspace_icon": response_json.get("workspace_icon"),
        "workspace_id": response_json.get("workspace_id"),
        "pages": pages,
        "total": len(pages),
    }

    # save data source binding
    existing_binding = DataSourceOauthBinding.query.filter(
        db.and_(
            DataSourceOauthBinding.tenant_id == current_user.current_tenant_id,
            DataSourceOauthBinding.provider == "notion",
            DataSourceOauthBinding.access_token == access_token,
        )
    ).first()
    if not existing_binding:
        db.session.add(
            DataSourceOauthBinding(
                tenant_id=current_user.current_tenant_id,
                access_token=access_token,
                source_info=source_info,
                provider="notion",
            )
        )
    else:
        existing_binding.source_info = source_info
        existing_binding.disabled = False
        existing_binding.updated_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None)
    db.session.commit()
def save_internal_access_token(self, access_token: str):
    """Persist a binding for a directly supplied Notion ``access_token``.

    The workspace name comes from ``notion_workspace_name``; the current
    tenant id is stored as the workspace id and no icon is recorded. An
    existing binding with the same token is refreshed, otherwise a new one
    is created.
    """
    tenant_id = current_user.current_tenant_id
    workspace_name = self.notion_workspace_name(access_token)
    # get all authorized pages
    pages = self.get_authorized_pages(access_token)
    source_info = {
        "workspace_name": workspace_name,
        "workspace_icon": None,
        "workspace_id": tenant_id,
        "pages": pages,
        "total": len(pages),
    }

    # save data source binding
    existing_binding = DataSourceOauthBinding.query.filter(
        db.and_(
            DataSourceOauthBinding.tenant_id == current_user.current_tenant_id,
            DataSourceOauthBinding.provider == "notion",
            DataSourceOauthBinding.access_token == access_token,
        )
    ).first()
    if not existing_binding:
        db.session.add(
            DataSourceOauthBinding(
                tenant_id=current_user.current_tenant_id,
                access_token=access_token,
                source_info=source_info,
                provider="notion",
            )
        )
    else:
        existing_binding.source_info = source_info
        existing_binding.disabled = False
        existing_binding.updated_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None)
    db.session.commit()
def sync_data_source(self, binding_id: str):
    """Re-fetch the authorized pages for an enabled Notion binding and persist them.

    The workspace name, icon and id already stored on the binding are kept;
    only ``pages`` and ``total`` are rebuilt from a fresh API listing.

    :param binding_id: id of the binding to refresh.
    :raises ValueError: if the current tenant has no enabled Notion binding with that id.
    """
    binding = DataSourceOauthBinding.query.filter(
        db.and_(
            DataSourceOauthBinding.tenant_id == current_user.current_tenant_id,
            DataSourceOauthBinding.provider == "notion",
            DataSourceOauthBinding.id == binding_id,
            DataSourceOauthBinding.disabled == False,
        )
    ).first()
    if not binding:
        raise ValueError("Data source binding not found")

    # get all authorized pages
    refreshed_pages = self.get_authorized_pages(binding.access_token)
    previous_info = binding.source_info
    binding.source_info = {
        "workspace_name": previous_info["workspace_name"],
        "workspace_icon": previous_info["workspace_icon"],
        "workspace_id": previous_info["workspace_id"],
        "pages": refreshed_pages,
        "total": len(refreshed_pages),
    }
    binding.disabled = False
    # Stored as a naive timestamp derived from the current UTC time.
    binding.updated_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None)
    db.session.commit()
def get_authorized_pages(self, access_token: str):
    """Return summaries of every page and database visible to ``access_token``.

    Pages come first, followed by databases. Each entry is a dict with the keys
    ``page_id``, ``page_name``, ``page_icon``, ``parent_id`` and ``type``
    ("page" or "database").

    :param access_token: Notion token forwarded to the search and block lookups.
    :return: list of summary dicts.
    """

    def resolve_parent_id(parent):
        # Workspace-level items are reported under the synthetic id "root";
        # block parents are resolved through notion_block_parent_page_id.
        kind = parent["type"]
        if kind == "block_id":
            return self.notion_block_parent_page_id(access_token, parent[kind])
        if kind == "workspace":
            return "root"
        return parent[kind]

    def url_icon(raw_icon, kind):
        # Links that do not start with "http" are prefixed with the Notion host.
        link = raw_icon[kind]["url"]
        if not link.startswith("http"):
            link = f"https://www.notion.so{link}"
        return {"type": "url", "url": link}

    page_results = self.notion_page_search(access_token)
    database_results = self.notion_database_search(access_token)
    summaries = []

    # get page detail
    for page_result in page_results:
        entry_id = page_result["id"]
        entry_name = "Untitled"
        # The last property carrying a non-empty title with "plain_text" wins.
        for prop in page_result["properties"].values():
            if "title" in prop and prop["title"]:
                title_parts = prop["title"]
                if len(title_parts) > 0 and "plain_text" in title_parts[0]:
                    entry_name = title_parts[0]["plain_text"]

        raw_icon = page_result["icon"]
        if not raw_icon:
            icon = None
        elif raw_icon["type"] in {"external", "file"}:
            icon = url_icon(raw_icon, raw_icon["type"])
        else:
            # Any other page icon type is reported as an emoji.
            icon = {"type": "emoji", "emoji": raw_icon[raw_icon["type"]]}

        summaries.append(
            {
                "page_id": entry_id,
                "page_name": entry_name,
                "page_icon": icon,
                "parent_id": resolve_parent_id(page_result["parent"]),
                "type": "page",
            }
        )

    # get database detail
    for database_result in database_results:
        entry_id = database_result["id"]
        title_parts = database_result["title"]
        entry_name = title_parts[0]["plain_text"] if len(title_parts) > 0 else "Untitled"

        raw_icon = database_result["icon"]
        if not raw_icon:
            icon = None
        elif raw_icon["type"] in {"external", "file"}:
            icon = url_icon(raw_icon, raw_icon["type"])
        else:
            # Unlike pages, databases keep the icon's own type as the key.
            icon_kind = raw_icon["type"]
            icon = {"type": icon_kind, icon_kind: raw_icon[icon_kind]}

        summaries.append(
            {
                "page_id": entry_id,
                "page_name": entry_name,
                "page_icon": icon,
                "parent_id": resolve_parent_id(database_result["parent"]),
                "type": "database",
            }
        )
    return summaries
def notion_page_search(self, access_token: str):
    """Collect every search result of object type "page", following pagination.

    Requests are POSTed to ``self._NOTION_PAGE_SEARCH`` until a response
    reports ``has_more`` as falsy (or omits it).

    :param access_token: Notion token sent as a Bearer credential.
    :return: concatenated ``results`` lists from all response pages.
    """
    request_headers = {
        "Content-Type": "application/json",
        "Authorization": f"Bearer {access_token}",
        "Notion-Version": "2022-06-28",
    }
    collected = []
    cursor = None
    while True:
        payload: dict[str, Any] = {"filter": {"value": "page", "property": "object"}}
        if cursor:
            # Continue from where the previous response left off.
            payload["start_cursor"] = cursor
        reply = requests.post(url=self._NOTION_PAGE_SEARCH, json=payload, headers=request_headers).json()
        collected.extend(reply.get("results", []))
        cursor = reply.get("next_cursor", None)
        if not reply.get("has_more", False):
            return collected
def notion_block_parent_page_id(self, access_token: str, block_id: str):
    """Walk up from a block until a parent that is not itself a block is found.

    :param access_token: Notion token sent as a Bearer credential.
    :param block_id: id of the block to start from.
    :return: the value stored under the first non-"block_id" parent type.
    :raises ValueError: if any lookup responds with a status other than 200.
    """
    request_headers = {
        "Authorization": f"Bearer {access_token}",
        "Notion-Version": "2022-06-28",
    }
    current_block = block_id
    while True:
        response = requests.get(url=f"{self._NOTION_BLOCK_SEARCH}/{current_block}", headers=request_headers)
        body = response.json()
        if response.status_code != 200:
            message = body.get("message", "unknown error")
            raise ValueError(f"Error fetching block parent page ID: {message}")
        parent = body["parent"]
        parent_kind = parent["type"]
        if parent_kind != "block_id":
            return parent[parent_kind]
        # Parent is another block: keep climbing.
        current_block = parent[parent_kind]
def notion_workspace_name(self, access_token: str):
    """Return the workspace name reported for the token's user, or "workspace".

    The fallback "workspace" is used when the response is not a "user" object
    or its type-specific section carries no ``workspace_name``.

    :param access_token: Notion token sent as a Bearer credential.
    """
    fallback = "workspace"
    response = requests.get(
        url=self._NOTION_BOT_USER,
        headers={
            "Authorization": f"Bearer {access_token}",
            "Notion-Version": "2022-06-28",
        },
    )
    body = response.json()
    if "object" not in body or body["object"] != "user":
        return fallback
    # The user's details live under a key named after the user type.
    details = body[body["type"]]
    if "workspace_name" not in details:
        return fallback
    return details["workspace_name"]
def notion_database_search(self, access_token: str):
    """Collect every search result of object type "database", following pagination.

    Uses the same ``self._NOTION_PAGE_SEARCH`` URL as the page search, with the
    filter value set to "database". Pagination stops once a response reports
    ``has_more`` as falsy (or omits it).

    :param access_token: Notion token sent as a Bearer credential.
    :return: concatenated ``results`` lists from all response pages.
    """
    request_headers = {
        "Content-Type": "application/json",
        "Authorization": f"Bearer {access_token}",
        "Notion-Version": "2022-06-28",
    }
    collected = []
    cursor = None
    while True:
        payload: dict[str, Any] = {"filter": {"value": "database", "property": "object"}}
        if cursor:
            # Continue from where the previous response left off.
            payload["start_cursor"] = cursor
        reply = requests.post(url=self._NOTION_PAGE_SEARCH, json=payload, headers=request_headers).json()
        collected.extend(reply.get("results", []))
        cursor = reply.get("next_cursor", None)
        if not reply.get("has_more", False):
            return collected

View File

@@ -0,0 +1,22 @@
import jwt
from werkzeug.exceptions import Unauthorized
from configs import dify_config
class PassportService:
    """Issues and verifies HS256-signed JWTs using the configured secret key."""

    def __init__(self):
        # Signing/verification key is read from the application configuration.
        self.sk = dify_config.SECRET_KEY

    def issue(self, payload):
        """Return ``payload`` encoded as an HS256-signed token."""
        signing_key = self.sk
        return jwt.encode(payload, signing_key, algorithm="HS256")

    def verify(self, token):
        """Decode ``token`` and return its payload.

        :raises Unauthorized: on a bad signature, an undecodable token, or an
            expired token (each with its own message).
        """
        try:
            claims = jwt.decode(token, self.sk, algorithms=["HS256"])
        # NOTE(review): handler order is kept as-is; presumably the signature
        # error is a more specific kind of decode error - confirm before reordering.
        except jwt.exceptions.InvalidSignatureError:
            raise Unauthorized("Invalid token signature.")
        except jwt.exceptions.DecodeError:
            raise Unauthorized("Invalid token.")
        except jwt.exceptions.ExpiredSignatureError:
            raise Unauthorized("Token has expired.")
        else:
            return claims

View File

@@ -0,0 +1,26 @@
import base64
import binascii
import hashlib
import re
# At least one ASCII letter, at least one digit, and 8 or more characters.
password_pattern = r"^(?=.*[a-zA-Z])(?=.*\d).{8,}$"


def valid_password(password):
    """Return ``password`` unchanged if it satisfies ``password_pattern``.

    :param password: candidate password string.
    :return: the same string that was passed in.
    :raises ValueError: if the password does not match the pattern.
    """
    # Guard clause: reject anything the pattern does not accept.
    if re.match(password_pattern, password) is None:
        raise ValueError("Password must contain letters and numbers, and the length must be greater than 8.")
    return password
def hash_password(password_str, salt_byte):
    """Derive a PBKDF2-HMAC-SHA256 digest (10000 iterations) and hex-encode it.

    :param password_str: password text; encoded as UTF-8 before hashing.
    :param salt_byte: salt bytes.
    :return: the derived key as lowercase hexadecimal ``bytes``.
    """
    derived_key = hashlib.pbkdf2_hmac(
        "sha256",
        password_str.encode("utf-8"),
        salt_byte,
        10000,
    )
    return derived_key.hex().encode("ascii")
def compare_password(password_str, password_hashed_base64, salt_base64):
    """Check a login password against a stored, base64-encoded hash.

    The candidate is hashed with ``hash_password`` using the decoded salt and
    compared with the decoded stored hash.

    :param password_str: password supplied by the user.
    :param password_hashed_base64: base64 encoding of the stored hash.
    :param salt_base64: base64 encoding of the salt.
    :return: ``True`` if the hashes match, otherwise ``False``.
    """
    import hmac

    expected_hash = base64.b64decode(password_hashed_base64)
    candidate_hash = hash_password(password_str, base64.b64decode(salt_base64))
    # Constant-time comparison: a plain ``==`` returns as soon as a byte
    # differs, which leaks how much of the hash matched through timing.
    return hmac.compare_digest(candidate_hash, expected_hash)

View File

@@ -0,0 +1,93 @@
import hashlib
from Crypto.Cipher import AES
from Crypto.PublicKey import RSA
from Crypto.Random import get_random_bytes
from extensions.ext_redis import redis_client
from extensions.ext_storage import storage
from libs import gmpy2_pkcs10aep_cipher
def generate_key_pair(tenant_id):
    """Generate a 2048-bit RSA key pair for a tenant.

    The private key is saved through ``storage`` under the tenant's
    ``privkeys/<tenant_id>/private.pem`` path; only the public key is returned.

    :param tenant_id: tenant identifier used to build the storage path.
    :return: the exported public key decoded to ``str``.
    """
    key_pair = RSA.generate(2048)
    private_pem = key_pair.export_key()
    public_pem = key_pair.publickey().export_key()

    private_key_path = "privkeys/{tenant_id}".format(tenant_id=tenant_id) + "/private.pem"
    storage.save(private_key_path, private_pem)

    return public_pem.decode()
# Marker prepended to ciphertexts produced by the hybrid RSA + AES scheme.
prefix_hybrid = b"HYBRID:"


def encrypt(text, public_key):
    """Encrypt ``text`` with a fresh AES key that is itself wrapped with RSA.

    Output layout: ``prefix_hybrid`` + RSA-encrypted AES key + AES nonce +
    authentication tag + ciphertext.

    :param text: plaintext string.
    :param public_key: RSA public key as ``str`` or ``bytes``.
    :return: the encrypted payload as ``bytes``.
    """
    key_material = public_key.encode() if isinstance(public_key, str) else public_key

    # Symmetric layer: random 16-byte AES key in EAX mode.
    session_key = get_random_bytes(16)
    aes_cipher = AES.new(session_key, AES.MODE_EAX)
    ciphertext, tag = aes_cipher.encrypt_and_digest(text.encode())

    # Asymmetric layer: wrap the AES key with the RSA public key.
    rsa_cipher = gmpy2_pkcs10aep_cipher.new(RSA.import_key(key_material))
    wrapped_key = rsa_cipher.encrypt(session_key)

    return b"".join((prefix_hybrid, wrapped_key, aes_cipher.nonce, tag, ciphertext))
def get_decrypt_decoding(tenant_id):
    """Load a tenant's RSA private key and build the matching cipher.

    The key is looked up in Redis first; on a miss it is loaded from storage
    and cached for 120 seconds.

    :param tenant_id: tenant identifier used to build the storage path.
    :return: tuple ``(rsa_key, cipher_rsa)``.
    :raises PrivkeyNotFoundError: if storage reports the key file as missing.
    """
    key_path = "privkeys/{tenant_id}".format(tenant_id=tenant_id) + "/private.pem"
    # The cache key is derived from a SHA3-256 digest of the storage path.
    path_digest = hashlib.sha3_256(key_path.encode()).hexdigest()
    cache_key = "tenant_privkey:{hash}".format(hash=path_digest)

    pem = redis_client.get(cache_key)
    if not pem:
        try:
            pem = storage.load(key_path)
        except FileNotFoundError:
            raise PrivkeyNotFoundError("Private key not found, tenant_id: {tenant_id}".format(tenant_id=tenant_id))
        redis_client.setex(cache_key, 120, pem)

    rsa_key = RSA.import_key(pem)
    return rsa_key, gmpy2_pkcs10aep_cipher.new(rsa_key)
def decrypt_token_with_decoding(encrypted_text, rsa_key, cipher_rsa):
    """Decrypt a payload using an already-loaded RSA key and cipher.

    Payloads starting with ``prefix_hybrid`` are split into RSA-wrapped AES
    key, 16-byte nonce, 16-byte tag and ciphertext; anything else is decrypted
    directly with ``cipher_rsa``.

    :param encrypted_text: encrypted bytes.
    :param rsa_key: key object providing ``size_in_bytes()``.
    :param cipher_rsa: cipher object providing ``decrypt()``.
    :return: the decrypted bytes decoded to ``str``.
    """
    if not encrypted_text.startswith(prefix_hybrid):
        # No hybrid marker: the whole payload is RSA-encrypted.
        return cipher_rsa.decrypt(encrypted_text).decode()

    payload = encrypted_text[len(prefix_hybrid) :]
    key_end = rsa_key.size_in_bytes()
    nonce_end = key_end + 16
    tag_end = key_end + 32

    wrapped_key = payload[:key_end]
    nonce = payload[key_end:nonce_end]
    tag = payload[nonce_end:tag_end]
    ciphertext = payload[tag_end:]

    session_key = cipher_rsa.decrypt(wrapped_key)
    aes_cipher = AES.new(session_key, AES.MODE_EAX, nonce=nonce)
    return aes_cipher.decrypt_and_verify(ciphertext, tag).decode()
def decrypt(encrypted_text, tenant_id):
    """Decrypt ``encrypted_text`` with the private key belonging to ``tenant_id``.

    :param encrypted_text: encrypted bytes.
    :param tenant_id: tenant whose private key is loaded.
    :return: the decrypted text as ``str``.
    """
    key_and_cipher = get_decrypt_decoding(tenant_id)
    return decrypt_token_with_decoding(encrypted_text, *key_and_cipher)
class PrivkeyNotFoundError(Exception):
    """Raised when a tenant's private key file cannot be found in storage."""

View File

@@ -0,0 +1,52 @@
import logging
import smtplib
from email.mime.multipart import MIMEMultipart
from email.mime.text import MIMEText
class SMTPClient:
    """Small SMTP sender for single-recipient HTML e-mails.

    Connection mode is chosen from the two TLS flags:

    * ``use_tls`` and ``opportunistic_tls``: plain connection upgraded with STARTTLS.
    * ``use_tls`` only: implicit TLS via ``smtplib.SMTP_SSL``.
    * neither: plain, unencrypted connection.
    """

    def __init__(
        self, server: str, port: int, username: str, password: str, _from: str, use_tls=False, opportunistic_tls=False
    ):
        """Store the connection settings; no connection is opened here.

        :param server: SMTP host name.
        :param port: SMTP port.
        :param username: login name; login is skipped when this or ``password`` is empty.
        :param password: login password.
        :param _from: sender address used for the envelope and the From header.
        :param use_tls: whether to use TLS at all.
        :param opportunistic_tls: with ``use_tls``, use STARTTLS instead of implicit TLS.
        """
        self.server = server
        self.port = port
        self._from = _from
        self.username = username
        self.password = password
        self.use_tls = use_tls
        self.opportunistic_tls = opportunistic_tls

    def send(self, mail: dict):
        """Send one HTML message.

        :param mail: dict with the keys ``subject``, ``to`` and ``html``.
        :raises smtplib.SMTPException: on SMTP-level failures (logged, then re-raised).
        :raises TimeoutError: when the 10-second socket timeout expires (logged, then re-raised).
        :raises Exception: any other failure is logged and re-raised unchanged.
        """
        smtp = None
        try:
            if self.use_tls:
                if self.opportunistic_tls:
                    smtp = smtplib.SMTP(self.server, self.port, timeout=10)
                    smtp.starttls()
                else:
                    smtp = smtplib.SMTP_SSL(self.server, self.port, timeout=10)
            else:
                smtp = smtplib.SMTP(self.server, self.port, timeout=10)

            if self.username and self.password:
                smtp.login(self.username, self.password)

            msg = MIMEMultipart()
            msg["Subject"] = mail["subject"]
            msg["From"] = self._from
            msg["To"] = mail["to"]
            msg.attach(MIMEText(mail["html"], "html"))

            smtp.sendmail(self._from, mail["to"], msg.as_string())
        except smtplib.SMTPException:
            logging.exception("SMTP error occurred")
            raise
        except TimeoutError:
            logging.exception("Timeout occurred while sending email")
            raise
        except Exception:
            logging.exception(f"Unexpected error occurred while sending email to {mail['to']}")
            raise
        finally:
            if smtp is not None:
                # Cleanup is best-effort: if the server has already dropped the
                # connection, QUIT fails, and letting that propagate would hide
                # the original error or fail a send that actually succeeded.
                try:
                    smtp.quit()
                except smtplib.SMTPServerDisconnected:
                    smtp.close()