Skip to content

Commit

Permalink
Fix/plugin race condition (langgenius#14253)
Browse files Browse the repository at this point in the history
  • Loading branch information
Yeuoly authored Feb 25, 2025
1 parent 42b13bd commit 490b6d0
Show file tree
Hide file tree
Showing 11 changed files with 116 additions and 41 deletions.
7 changes: 7 additions & 0 deletions api/app_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import time

from configs import dify_config
from contexts.wrapper import RecyclableContextVar
from dify_app import DifyApp


Expand All @@ -16,6 +17,12 @@ def create_flask_app_with_configs() -> DifyApp:
dify_app = DifyApp(__name__)
dify_app.config.from_mapping(dify_config.model_dump())

# add before request hook
@dify_app.before_request
def before_request():
# add a unique identifier to each request
RecyclableContextVar.increment_thread_recycles()

return dify_app


Expand Down
19 changes: 15 additions & 4 deletions api/contexts/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@
from threading import Lock
from typing import TYPE_CHECKING

from contexts.wrapper import RecyclableContextVar

if TYPE_CHECKING:
from core.plugin.entities.plugin_daemon import PluginModelProviderEntity
from core.tools.plugin_tool.provider import PluginToolProviderController
Expand All @@ -12,8 +14,17 @@

workflow_variable_pool: ContextVar["VariablePool"] = ContextVar("workflow_variable_pool")

plugin_tool_providers: ContextVar[dict[str, "PluginToolProviderController"]] = ContextVar("plugin_tool_providers")
plugin_tool_providers_lock: ContextVar[Lock] = ContextVar("plugin_tool_providers_lock")
"""
To avoid race-conditions caused by gunicorn thread recycling, using RecyclableContextVar to replace with
"""
plugin_tool_providers: RecyclableContextVar[dict[str, "PluginToolProviderController"]] = RecyclableContextVar(
ContextVar("plugin_tool_providers")
)
plugin_tool_providers_lock: RecyclableContextVar[Lock] = RecyclableContextVar(ContextVar("plugin_tool_providers_lock"))

plugin_model_providers: ContextVar[list["PluginModelProviderEntity"] | None] = ContextVar("plugin_model_providers")
plugin_model_providers_lock: ContextVar[Lock] = ContextVar("plugin_model_providers_lock")
plugin_model_providers: RecyclableContextVar[list["PluginModelProviderEntity"] | None] = RecyclableContextVar(
ContextVar("plugin_model_providers")
)
plugin_model_providers_lock: RecyclableContextVar[Lock] = RecyclableContextVar(
ContextVar("plugin_model_providers_lock")
)
65 changes: 65 additions & 0 deletions api/contexts/wrapper.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
from contextvars import ContextVar
from typing import Generic, TypeVar

T = TypeVar("T")


class HiddenValue:
pass


_default = HiddenValue()


class RecyclableContextVar(Generic[T]):
"""
RecyclableContextVar is a wrapper around ContextVar
It's safe to use in gunicorn with thread recycling, but features like `reset` are not available for now
NOTE: you need to call `increment_thread_recycles` before requests
"""

_thread_recycles: ContextVar[int] = ContextVar("thread_recycles")

@classmethod
def increment_thread_recycles(cls):
try:
recycles = cls._thread_recycles.get()
cls._thread_recycles.set(recycles + 1)
except LookupError:
cls._thread_recycles.set(0)

def __init__(self, context_var: ContextVar[T]):
self._context_var = context_var
self._updates = ContextVar[int](context_var.name + "_updates", default=0)

def get(self, default: T | HiddenValue = _default) -> T:
thread_recycles = self._thread_recycles.get(0)
self_updates = self._updates.get()
if thread_recycles > self_updates:
self._updates.set(thread_recycles)

# check if thread is recycled and should be updated
if thread_recycles < self_updates:
return self._context_var.get()
else:
# thread_recycles >= self_updates, means current context is invalid
if isinstance(default, HiddenValue) or default is _default:
raise LookupError
else:
return default

def set(self, value: T):
# it leads to a situation that self.updates is less than cls.thread_recycles if `set` was never called before
# increase it manually
thread_recycles = self._thread_recycles.get(0)
self_updates = self._updates.get()
if thread_recycles > self_updates:
self._updates.set(thread_recycles)

if self._updates.get() == self._thread_recycles.get(0):
# after increment,
self._updates.set(self._updates.get() + 1)

# set the context
self._context_var.set(value)
4 changes: 2 additions & 2 deletions api/core/agent/entities.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from enum import StrEnum
from typing import Any, Optional, Union

from pydantic import BaseModel
from pydantic import BaseModel, Field

from core.tools.entities.tool_entities import ToolInvokeMessage, ToolProviderType

Expand All @@ -14,7 +14,7 @@ class AgentToolEntity(BaseModel):
provider_type: ToolProviderType
provider_id: str
tool_name: str
tool_parameters: dict[str, Any] = {}
tool_parameters: dict[str, Any] = Field(default_factory=dict)
plugin_unique_identifier: str | None = None


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,9 @@
from typing import Any

from core.app.app_config.entities import ModelConfigEntity
from core.entities import DEFAULT_PLUGIN_ID
from core.model_runtime.entities.model_entities import ModelPropertyKey, ModelType
from core.model_runtime.model_providers.model_provider_factory import ModelProviderFactory
from core.plugin.entities.plugin import ModelProviderID
from core.provider_manager import ProviderManager


Expand Down Expand Up @@ -61,9 +61,7 @@ def validate_and_set_defaults(cls, tenant_id: str, config: Mapping[str, Any]) ->
raise ValueError(f"model.provider is required and must be in {str(model_provider_names)}")

if "/" not in config["model"]["provider"]:
config["model"]["provider"] = (
f"{DEFAULT_PLUGIN_ID}/{config['model']['provider']}/{config['model']['provider']}"
)
config["model"]["provider"] = str(ModelProviderID(config["model"]["provider"]))

if config["model"]["provider"] not in model_provider_names:
raise ValueError(f"model.provider is required and must be in {str(model_provider_names)}")
Expand Down
8 changes: 4 additions & 4 deletions api/core/app/app_config/entities.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,8 @@ class ModelConfigEntity(BaseModel):
provider: str
model: str
mode: Optional[str] = None
parameters: dict[str, Any] = {}
stop: list[str] = []
parameters: dict[str, Any] = Field(default_factory=dict)
stop: list[str] = Field(default_factory=list)


class AdvancedChatMessageEntity(BaseModel):
Expand Down Expand Up @@ -132,7 +132,7 @@ class ExternalDataVariableEntity(BaseModel):

variable: str
type: str
config: dict[str, Any] = {}
config: dict[str, Any] = Field(default_factory=dict)


class DatasetRetrieveConfigEntity(BaseModel):
Expand Down Expand Up @@ -188,7 +188,7 @@ class SensitiveWordAvoidanceEntity(BaseModel):
"""

type: str
config: dict[str, Any] = {}
config: dict[str, Any] = Field(default_factory=dict)


class TextToSpeechEntity(BaseModel):
Expand Down
8 changes: 4 additions & 4 deletions api/core/app/entities/app_invoke_entities.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,9 +63,9 @@ class ModelConfigWithCredentialsEntity(BaseModel):
model_schema: AIModelEntity
mode: str
provider_model_bundle: ProviderModelBundle
credentials: dict[str, Any] = {}
parameters: dict[str, Any] = {}
stop: list[str] = []
credentials: dict[str, Any] = Field(default_factory=dict)
parameters: dict[str, Any] = Field(default_factory=dict)
stop: list[str] = Field(default_factory=list)

# pydantic configs
model_config = ConfigDict(protected_namespaces=())
Expand Down Expand Up @@ -94,7 +94,7 @@ class AppGenerateEntity(BaseModel):
call_depth: int = 0

# extra parameters, like: auto_generate_conversation_name
extras: dict[str, Any] = {}
extras: dict[str, Any] = Field(default_factory=dict)

# tracing instance
trace_manager: Optional[TraceQueueManager] = None
Expand Down
9 changes: 4 additions & 5 deletions api/core/entities/provider_configuration.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,10 @@
from json import JSONDecodeError
from typing import Optional

from pydantic import BaseModel, ConfigDict
from pydantic import BaseModel, ConfigDict, Field
from sqlalchemy import or_

from constants import HIDDEN_VALUE
from core.entities import DEFAULT_PLUGIN_ID
from core.entities.model_entities import ModelStatus, ModelWithProviderEntity, SimpleModelProviderEntity
from core.entities.provider_entities import (
CustomConfiguration,
Expand Down Expand Up @@ -1004,7 +1003,7 @@ class ProviderConfigurations(BaseModel):
"""

tenant_id: str
configurations: dict[str, ProviderConfiguration] = {}
configurations: dict[str, ProviderConfiguration] = Field(default_factory=dict)

def __init__(self, tenant_id: str):
super().__init__(tenant_id=tenant_id)
Expand Down Expand Up @@ -1060,7 +1059,7 @@ def to_list(self) -> list[ProviderConfiguration]:

def __getitem__(self, key):
if "/" not in key:
key = f"{DEFAULT_PLUGIN_ID}/{key}/{key}"
key = str(ModelProviderID(key))

return self.configurations[key]

Expand All @@ -1075,7 +1074,7 @@ def values(self) -> Iterator[ProviderConfiguration]:

def get(self, key, default=None) -> ProviderConfiguration | None:
if "/" not in key:
key = f"{DEFAULT_PLUGIN_ID}/{key}/{key}"
key = str(ModelProviderID(key))

return self.configurations.get(key, default) # type: ignore

Expand Down
6 changes: 5 additions & 1 deletion api/core/hosting_configuration.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,9 +41,13 @@ class HostedModerationConfig(BaseModel):


class HostingConfiguration:
provider_map: dict[str, HostingProvider] = {}
provider_map: dict[str, HostingProvider]
moderation_config: Optional[HostedModerationConfig] = None

def __init__(self) -> None:
self.provider_map = {}
self.moderation_config = None

def init_app(self, app: Flask) -> None:
if dify_config.EDITION != "CLOUD":
return
Expand Down
15 changes: 5 additions & 10 deletions api/core/model_runtime/model_providers/model_provider_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
from pydantic import BaseModel

import contexts
from core.entities import DEFAULT_PLUGIN_ID
from core.helper.position_helper import get_provider_position_map, sort_to_dict_by_position_map
from core.model_runtime.entities.model_entities import AIModelEntity, ModelType
from core.model_runtime.entities.provider_entities import ProviderConfig, ProviderEntity, SimpleProviderEntity
Expand All @@ -34,9 +33,11 @@ class ModelProviderExtension(BaseModel):


class ModelProviderFactory:
provider_position_map: dict[str, int] = {}
provider_position_map: dict[str, int]

def __init__(self, tenant_id: str) -> None:
self.provider_position_map = {}

self.tenant_id = tenant_id
self.plugin_model_manager = PluginModelManager()

Expand Down Expand Up @@ -360,11 +361,5 @@ def get_plugin_id_and_provider_name_from_provider(self, provider: str) -> tuple[
:param provider: provider name
:return: plugin id and provider name
"""
plugin_id = DEFAULT_PLUGIN_ID
provider_name = provider
if "/" in provider:
# get the plugin_id before provider
plugin_id = "/".join(provider.split("/")[:-1])
provider_name = provider.split("/")[-1]

return str(plugin_id), provider_name
provider_id = ModelProviderID(provider)
return provider_id.plugin_id, provider_id.provider_name
10 changes: 3 additions & 7 deletions api/services/dataset_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,10 @@
from werkzeug.exceptions import NotFound

from configs import dify_config
from core.entities import DEFAULT_PLUGIN_ID
from core.errors.error import LLMBadRequestError, ProviderTokenNotInitError
from core.model_manager import ModelManager
from core.model_runtime.entities.model_entities import ModelType
from core.plugin.entities.plugin import ModelProviderID
from core.rag.index_processor.constant.index_type import IndexType
from core.rag.retrieval.retrieval_methods import RetrievalMethod
from events.dataset_event import dataset_was_deleted
Expand Down Expand Up @@ -328,14 +328,10 @@ def update_dataset(dataset_id, data, user):
else:
# add default plugin id to both setting sets, to make sure the plugin model provider is consistent
plugin_model_provider = dataset.embedding_model_provider
if "/" not in plugin_model_provider:
plugin_model_provider = f"{DEFAULT_PLUGIN_ID}/{plugin_model_provider}/{plugin_model_provider}"
plugin_model_provider = str(ModelProviderID(plugin_model_provider))

new_plugin_model_provider = data["embedding_model_provider"]
if "/" not in new_plugin_model_provider:
new_plugin_model_provider = (
f"{DEFAULT_PLUGIN_ID}/{new_plugin_model_provider}/{new_plugin_model_provider}"
)
new_plugin_model_provider = str(ModelProviderID(new_plugin_model_provider))

if (
new_plugin_model_provider != plugin_model_provider
Expand Down

0 comments on commit 490b6d0

Please sign in to comment.