Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions .github/workflows/testing.yml
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,7 @@ jobs:
run: |
python -m pip install --upgrade pip wheel
python -m pip install "setuptools<80.0.0"
find packages -name dash-*.whl -print -exec sh -c 'pip install "{}[dev,ci,testing]"' \;
find packages -name dash-*.whl -print -exec sh -c 'pip install "{}[dev,ci,testing,diskcache]"' \;

- name: Install dash-renderer dependencies
working-directory: dash/dash-renderer
Expand Down Expand Up @@ -231,7 +231,7 @@ jobs:
run: |
python -m pip install --upgrade pip wheel
python -m pip install "setuptools<80.0.0"
find packages -name dash-*.whl -print -exec sh -c 'pip install "{}[ci,testing,dev]"' \;
find packages -name dash-*.whl -print -exec sh -c 'pip install "{}[ci,testing,dev,diskcache]"' \;

- name: Build/Setup test components
run: npm run setup-tests.py # TODO build the packages and save them to packages/ in build job
Expand Down
47 changes: 35 additions & 12 deletions dash/_callback.py
Original file line number Diff line number Diff line change
Expand Up @@ -452,10 +452,13 @@ def _setup_background_callback(
return to_json(data)


def _progress_background_callback(response, callback_manager, background):
def _progress_background_callback(
response, callback_manager, background, cache_key=None
):
progress_outputs = background.get("progress")
adapter = get_app().backend.request_adapter()
cache_key = adapter.args.get("cacheKey")
if cache_key is None:
adapter = get_app().backend.request_adapter()
cache_key = adapter.args.get("cacheKey")

if progress_outputs:
# Get the progress before the result as it would be erased after the results.
Expand All @@ -467,21 +470,38 @@ def _progress_background_callback(response, callback_manager, background):


def _update_background_callback(
error_handler, callback_ctx, response, kwargs, background, multi
error_handler,
callback_ctx,
response,
kwargs,
background,
multi,
cache_key=None,
job_id=None,
):
"""Set up the background callback and manage jobs."""
callback_manager = _get_callback_manager(kwargs, background)

adapter = get_app().backend.request_adapter()
cache_key = adapter.args.get("cacheKey") if adapter else None
job_id = adapter.args.get("job") if adapter else None
if cache_key is None or job_id is None:
adapter = get_app().backend.request_adapter()
cache_key = cache_key or (adapter.args.get("cacheKey") if adapter else None)
job_id = job_id or (adapter.args.get("job") if adapter else None)

_progress_background_callback(response, callback_manager, background)
_progress_background_callback(
response, callback_manager, background, cache_key=cache_key
)

output_value = callback_manager.get_result(cache_key, job_id)

return _handle_rest_background_callback(
output_value, callback_manager, response, error_handler, callback_ctx, multi
output_value,
callback_manager,
response,
error_handler,
callback_ctx,
multi,
cache_key=cache_key,
job_id=job_id,
)


Expand All @@ -493,10 +513,13 @@ def _handle_rest_background_callback(
callback_ctx,
multi,
has_update=False,
cache_key=None,
job_id=None,
):
adapter = get_app().backend.request_adapter()
cache_key = adapter.args.get("cacheKey") if adapter else None
job_id = adapter.args.get("job") if adapter else None
if cache_key is None or job_id is None:
adapter = get_app().backend.request_adapter()
cache_key = cache_key or (adapter.args.get("cacheKey") if adapter else None)
job_id = job_id or (adapter.args.get("job") if adapter else None)
# Must get job_running after get_result since get_results terminates it.
job_running = callback_manager.job_running(job_id)
if not job_running and output_value is callback_manager.UNDEFINED:
Expand Down
1 change: 1 addition & 0 deletions dash/_configs.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ def load_dash_env_vars():
"DASH_MCP_ENABLED",
"DASH_MCP_PATH",
"DASH_MCP_EXPOSE_DOCSTRINGS",
"DASH_MCP_AUTHORIZATION_SERVER",
"HOST",
"PORT",
)
Expand Down
2 changes: 1 addition & 1 deletion dash/background_callback/managers/diskcache_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ def __init__(self, cache=None, cache_by=None, expire=None):
is determined by the default behavior of the ``cache`` instance.
"""
try:
import diskcache # type: ignore[import-not-found] # pylint: disable=import-outside-toplevel
import diskcache # type: ignore[import-not-found,import-untyped] # pylint: disable=import-outside-toplevel
import psutil # type: ignore[import-untyped] # noqa: F401,E402 pylint: disable=import-outside-toplevel,unused-import,unused-variable,import-error
import multiprocess # type: ignore[import-untyped] # noqa: F401,E402 pylint: disable=import-outside-toplevel,unused-import,unused-variable
except ImportError as missing_imports:
Expand Down
10 changes: 9 additions & 1 deletion dash/dash.py
Original file line number Diff line number Diff line change
Expand Up @@ -489,6 +489,7 @@ def __init__( # pylint: disable=too-many-statements, too-many-branches
enable_mcp: Optional[bool] = None,
mcp_path: Optional[str] = None,
mcp_expose_docstrings: Optional[bool] = None,
mcp_authorization_server: Optional[str] = None,
**obsolete,
):

Expand Down Expand Up @@ -609,6 +610,9 @@ def __init__( # pylint: disable=too-many-statements, too-many-branches
self._mcp_path = (
_mcp_path.lstrip("/") if isinstance(_mcp_path, str) else _mcp_path
)
self._mcp_authorization_server = get_combined_config(
"mcp_authorization_server", mcp_authorization_server
)

# list of dependencies - this one is used by the back end for dispatching
self.callback_map: dict = {}
Expand Down Expand Up @@ -829,7 +833,11 @@ def _setup_routes(self):
)

try:
enable_mcp_server(self, self._mcp_path)
enable_mcp_server(
self,
self._mcp_path,
mcp_authorization_server=self._mcp_authorization_server,
)
except Exception as e: # pylint: disable=broad-exception-caught
self._enable_mcp = False
self.logger.warning(
Expand Down
69 changes: 67 additions & 2 deletions dash/mcp/_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,9 @@
import json
import logging
import uuid
from functools import reduce
from typing import TYPE_CHECKING, Any
from urllib.parse import urljoin

from mcp.types import (
LATEST_PROTOCOL_VERSION,
Expand All @@ -32,6 +34,7 @@
list_tools,
read_resource,
)
from dash.mcp.tasks import get_task, get_task_result, cancel_task
from dash.mcp.primitives.tools.callback_adapter_collection import (
CallbackAdapterCollection,
)
Expand All @@ -44,7 +47,61 @@
logger = logging.getLogger(__name__)


def enable_mcp_server(app: Dash, mcp_path: str) -> None:
def _url_from_path(app: Dash, *parts: str) -> str:
"""Build an absolute URL by joining path parts onto the current request origin.

Behind a reverse proxy, TLS terminates at the proxy so
the scheme may report HTTP even when the client connected
over HTTPS. Use HTTPS unless running on localhost.
"""
from urllib.parse import urlparse # pylint: disable=import-outside-toplevel

adapter = app.backend.request_adapter()
parsed = urlparse(adapter.url)
host = parsed.netloc
is_localhost = host.startswith("localhost") or host.startswith("127.0.0.1")
scheme = "http" if is_localhost else "https"
path = reduce(urljoin, parts, "/")
return f"{scheme}://{host}{path}"


def _setup_mcp_oauth(app: Dash, mcp_path: str, mcp_authorization_server: str) -> None:
"""Register RFC 9728 Protected Resource Metadata endpoint for MCP.

Serves discovery metadata so MCP clients can find the authorization
server. Auth enforcement is the responsibility of the hosting platform
(e.g. Plotly Cloud gateway, Dash Embedded, or a reverse proxy).
"""
well_known_path = urljoin("/.well-known/oauth-protected-resource/", mcp_path)

def _serve_resource_metadata():
return app.backend.make_response(
json.dumps(
{
"resource": _url_from_path(
app, app.config.requests_pathname_prefix, mcp_path
),
"authorization_servers": [mcp_authorization_server],
"bearer_methods_supported": ["header"],
}
),
content_type="application/json",
)

# pylint: disable-next=protected-access
app._add_url(well_known_path.lstrip("/"), _serve_resource_metadata)

logger.info(
"MCP OAuth discovery enabled, authorization server: %s",
mcp_authorization_server,
)


def enable_mcp_server(
app: Dash,
mcp_path: str,
mcp_authorization_server: str | None = None,
) -> None:
"""Add MCP routes to a Dash app."""

app.mcp_decorated_functions = dict(MCP_DECORATED_FUNCTIONS)
Expand Down Expand Up @@ -184,6 +241,9 @@ def _handle_not_allowed():
)
app.routes.append(mcp_url)

if mcp_authorization_server:
_setup_mcp_oauth(app, mcp_path, mcp_authorization_server)

logger.info(
"MCP routes registered at %s%s",
app.config.routes_pathname_prefix,
Expand Down Expand Up @@ -229,11 +289,16 @@ def _process_mcp_message(data: dict[str, Any]) -> dict[str, Any] | None:
"initialize": _handle_initialize,
"tools/list": list_tools,
"tools/call": lambda: call_tool(
params.get("name", ""), params.get("arguments", {})
tool_name=params.get("name", ""),
arguments=params.get("arguments", {}),
task=params.get("task"),
),
"resources/list": list_resources,
"resources/templates/list": list_resource_templates,
"resources/read": lambda: read_resource(params.get("uri", "")),
"tasks/get": lambda: get_task(task_id=params.get("taskId", "")),
"tasks/result": lambda: get_task_result(task_id=params.get("taskId", "")),
"tasks/cancel": lambda: cancel_task(task_id=params.get("taskId", "")),
}

try:
Expand Down
16 changes: 12 additions & 4 deletions dash/mcp/primitives/tools/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,17 +4,19 @@

from typing import Any

from mcp.types import CallToolResult, ListToolsResult
from mcp.types import CallToolResult, CreateTaskResult, ListToolsResult

from dash.mcp.types import ToolNotFoundError

from .base import MCPToolProvider
from .tool_background_tasks import BackgroundTaskTools
from .tool_decorated_mcp_functions import DecoratedFunctionTools
from .tool_get_dash_component import GetDashComponentTool
from .tools_callbacks import CallbackTools

_TOOL_PROVIDERS: list[type[MCPToolProvider]] = [
CallbackTools,
BackgroundTaskTools,
GetDashComponentTool,
DecoratedFunctionTools,
]
Expand All @@ -28,11 +30,17 @@ def list_tools() -> ListToolsResult:
return ListToolsResult(tools=tools)


def call_tool(tool_name: str, arguments: dict[str, Any]) -> CallToolResult:
"""Route a tools/call request by tool name."""
def call_tool(
tool_name: str, arguments: dict[str, Any], task: dict | None = None
) -> CallToolResult | CreateTaskResult:
"""Route a tools/call request by tool name.

The optional ``task`` parameter (per MCP Tasks protocol) is passed
through to providers that support background callbacks.
"""
for provider in _TOOL_PROVIDERS:
if tool_name in provider.get_tool_names():
return provider.call_tool(tool_name, arguments)
return provider.call_tool(tool_name, arguments, task=task)
raise ToolNotFoundError(
f"Tool not found: {tool_name}."
" The app's callbacks may have changed."
Expand Down
6 changes: 4 additions & 2 deletions dash/mcp/primitives/tools/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

from typing import Any

from mcp.types import CallToolResult, Tool
from mcp.types import CallToolResult, CreateTaskResult, Tool


class MCPToolProvider:
Expand All @@ -24,5 +24,7 @@ def list_tools(cls) -> list[Tool]:
raise NotImplementedError

@classmethod
def call_tool(cls, tool_name: str, arguments: dict[str, Any]) -> CallToolResult:
def call_tool(
cls, tool_name: str, arguments: dict[str, Any], task: dict | None = None
) -> CallToolResult | CreateTaskResult:
raise NotImplementedError
2 changes: 2 additions & 0 deletions dash/mcp/primitives/tools/descriptions/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from typing import TYPE_CHECKING

from .base import ToolDescriptionSource
from .description_background_callbacks import BackgroundCallbackDescription
from .description_docstring import DocstringDescription
from .description_outputs import OutputSummaryDescription

Expand All @@ -22,6 +23,7 @@
_SOURCES: list[type[ToolDescriptionSource]] = [
OutputSummaryDescription,
DocstringDescription,
BackgroundCallbackDescription,
]


Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
"""Description for background (long-running) callbacks.

Informs the LLM that the tool returns a taskId immediately
and must be polled via the background task result tool.
"""

from __future__ import annotations

from typing import TYPE_CHECKING

from ..tool_background_tasks import GET_RESULT_TOOL_NAME
from .base import ToolDescriptionSource

if TYPE_CHECKING:
from dash.mcp.primitives.tools.callback_adapter import CallbackAdapter


class BackgroundCallbackDescription(ToolDescriptionSource):
"""Add async polling instructions for background callbacks."""

@classmethod
def describe(cls, callback: CallbackAdapter) -> list[str]:
# pylint: disable-next=protected-access
if not callback._cb_info.get("background"):
return []

return [
"",
"This is a long-running background operation. "
"It returns a taskId immediately. "
f"Call tool `{GET_RESULT_TOOL_NAME}` with the taskId to poll for the result.",
]
Loading
Loading