diff --git a/.ai/ARCHITECTURE.md b/.ai/ARCHITECTURE.md index ca6816ccf9..84f553978c 100644 --- a/.ai/ARCHITECTURE.md +++ b/.ai/ARCHITECTURE.md @@ -2,7 +2,9 @@ ## Python Backend Framework -- **`dash/dash.py`** - Main `Dash` application class (~2000 lines). Orchestrates Flask server, layout management, callback registration, routing, and asset serving. Key methods: `layout` property, `callback()`, `clientside_callback()`, `run()`. +- **`dash/dash.py`** - Main `Dash` application class (~2000 lines). Orchestrates the server backend, layout management, callback registration, routing, and asset serving. Key methods: `layout` property, `callback()`, `clientside_callback()`, `run()`. + +- **`dash/backends/`** - Server backend implementations. See [Server Backends](#server-backends) section for details. - **`dash/_callback.py`** - Callback registration and execution. Contains `callback()` decorator (usable as `@dash.callback` without app instance), `clientside_callback()`, and `register_callback()` which inserts callbacks into the callback map. @@ -101,6 +103,127 @@ Use dict IDs with wildcards (`MATCH`, `ALL`, `ALLSMALLER`) to target dynamically - `/_dash-component-suites//` - Serves component JS/CSS assets - `/assets/` - Serves static assets from app's assets folder +## Server Backends + +Dash supports multiple web server backends. The backend abstraction is in `dash/backends/`. + +### Available Backends + +| Backend | Type | Install | Use Case | +|---------|------|---------|----------| +| **Flask** (default) | WSGI (sync) | `pip install dash` | Standard deployments, simplicity | +| **Quart** | ASGI (async) | `pip install dash[quart]` | Async callbacks, WebSocket support | +| **FastAPI** | ASGI (async) | `pip install dash[fastapi]` | OpenAPI docs, async, modern Python | + +### Usage + +**Default (Flask):** +```python +from dash import Dash +app = Dash(__name__) +``` + +**With existing server instance:** +```python +from flask import Flask +from dash import Dash + +server = Flask(__name__) +app = Dash(__name__, server=server) +``` + +**Quart backend:** +```python +from quart import Quart +from dash import Dash + +server = Quart(__name__) +app = Dash(__name__, server=server) +``` + +**FastAPI backend:** +```python +from fastapi import FastAPI +from dash import Dash + +server = FastAPI() +app = Dash(__name__, server=server) + +# Run with: uvicorn module:app.server --reload +``` + +### Architecture + +The backend system uses an abstract interface: + +- **`BaseDashServer`** (`dash/backends/base_server.py`) - Abstract base class defining the server interface. All backends implement this. + +- **`RequestAdapter`** - Normalizes HTTP request objects across frameworks. Provides unified access to `args`, `cookies`, `headers`, `get_json()`, etc. + +- **`ResponseAdapter`** - Normalizes response creation. Handles `set_cookie()`, `set_header()`, `set_response()`. + +- **`get_backend(name)`** - Factory function to get backend class by name (`"flask"`, `"quart"`, `"fastapi"`). + +- **`get_server_type(server)`** - Auto-detects backend from a server instance. + +### Backend Implementations + +**Flask** (`dash/backends/_flask.py`): +- `FlaskDashServer` - Wraps Flask app +- `FlaskRequestAdapter` - Uses `flask.request` proxy +- `FlaskResponseAdapter` - Uses `flask.Response` +- Compression via `flask-compress` + +**Quart** (`dash/backends/_quart.py`): +- `QuartDashServer` - Wraps Quart app (async Flask API) +- `QuartRequestAdapter` - Uses `quart.request` proxy +- `QuartResponseAdapter` - Uses `quart.Response` +- All route handlers are `async def` +- Compression via `quart-compress` + +**FastAPI** (`dash/backends/_fastapi.py`): +- `FastAPIDashServer` - Wraps FastAPI app +- `FastAPIRequestAdapter` - Uses context variable for current request +- `FastAPIResponseAdapter` - Uses Starlette responses +- `DashMiddleware` - Consolidated ASGI middleware for request handling +- Runs with uvicorn, supports hot reload +- Built-in GZip compression + +### Key Interface Methods + +All backends implement: + +```python +class BaseDashServer(ABC): + def create_app(name, config) -> server # Create new server + def add_url_rule(rule, view_func, ...) # Register routes + def before_request(func) # Request hooks + def after_request(func) # Response hooks + def run(dash_app, host, port, debug) # Start dev server + def make_response(data, mimetype, status) # Create response + def jsonify(obj) # JSON response + def setup_index(dash_app) # Register / route + def serve_callback(dash_app) # Callback endpoint + def setup_component_suites(dash_app) # JS/CSS serving +``` + +### Accessing the Backend + +```python +app = Dash(__name__) + +# Get the underlying server +app.server # Flask/Quart/FastAPI instance + +# Get the backend wrapper +app.backend # BaseDashServer subclass instance +app.backend.server_type # "flask", "quart", or "fastapi" + +# Access request in callbacks +from dash import dash +dash.get_app().backend.request_adapter() # RequestAdapter instance +``` + ## Frontend (dash-renderer) **`dash/dash-renderer/src/`** contains the TypeScript/React frontend. See [RENDERER.md](RENDERER.md) for detailed documentation on: @@ -503,14 +626,14 @@ app.run( ### How It Works 1. `app.run()` detects Jupyter environment via `get_ipython()` -2. Flask server starts in background daemon thread +2. Server starts in background daemon thread 3. Jupyter comm protocol negotiates proxy configuration 4. App displays according to selected mode ``` app.run() in notebook ↓ -Detect Jupyter → Start Flask in background thread +Detect Jupyter → Start server in background thread ↓ Comm request → Extension responds with base_url ↓ @@ -572,8 +695,8 @@ Special handling for Colab: ### Dash() Constructor Parameters **Basic Setup:** -- `name` - Flask app name (default: infers from `__name__`) -- `server` - Flask instance or `True` to create new (default: `True`) +- `name` - Application name (default: infers from `__name__`) +- `server` - Server instance (Flask, Quart, or FastAPI) or `True` to create Flask (default: `True`) - `title` - Browser tab title (default: `"Dash"`) - `update_title` - Title during callbacks (default: `"Updating..."`) @@ -600,6 +723,11 @@ Special handling for Colab: - `background_callback_manager` - DiskcacheManager or CeleryManager - `on_error` - Global callback error handler +**WebSocket Callbacks:** +- `websocket_callbacks` - Enable WebSocket for all callbacks (default: `False`). Requires FastAPI backend. +- `websocket_allowed_origins` - List of allowed origins for WebSocket connections +- `websocket_inactivity_timeout` - Disconnect WebSocket after inactivity period in ms (default: `300000` = 5 minutes). Set to `0` to disable. + ### app.run() Parameters - `host` - Server IP (default: `"127.0.0.1"`, env: `HOST`) @@ -682,6 +810,7 @@ Dash supports `async def` callbacks for non-blocking execution. ### Setup +**With Flask backend:** ```bash pip install dash[async] ``` @@ -692,6 +821,16 @@ Async is auto-enabled when `asgiref` is detected. Or explicitly: app = Dash(__name__, use_async=True) ``` +**With Quart or FastAPI backend:** Async is native - no extra dependencies needed. + +```python +from fastapi import FastAPI +from dash import Dash + +server = FastAPI() +app = Dash(__name__, server=server) # Async works automatically +``` + ### Usage ```python @@ -708,7 +847,8 @@ async def async_update(value): - Regular async callbacks are **non-blocking** - multiple can run concurrently - Background callbacks also support `async def` - Jupyter uses `nest_asyncio` for event loop compatibility -- Without `dash[async]`, coroutines raise an error +- With Flask backend: requires `dash[async]`, coroutines raise error without it +- With Quart/FastAPI backends: async is native, no extra setup needed ### Async with Background Callbacks @@ -726,6 +866,177 @@ async def async_background(n_clicks): Both DiskcacheManager and CeleryManager support async functions via `asyncio.run()`. +## WebSocket Callbacks + +WebSocket callbacks use a persistent WebSocket connection instead of HTTP POST for callback execution. This reduces latency and connection overhead for applications with frequent callbacks. + +### Requirements + +- **FastAPI backend required**: WebSocket callbacks only work with FastAPI +- **SharedWorker support**: Modern browsers (not IE) + +### Usage + +**Enable globally for all callbacks:** +```python +from fastapi import FastAPI +from dash import Dash + +server = FastAPI() +app = Dash(__name__, server=server, websocket_callbacks=True) +``` + +**Enable per-callback:** +```python +@app.callback( + Output('output', 'children'), + Input('input', 'value'), + websocket=True # Use WebSocket for this callback only +) +def update(value): + return f"Value: {value}" +``` + +### Configuration + +```python +app = Dash( + __name__, + server=server, + websocket_callbacks=True, + websocket_inactivity_timeout=300000, # 5 minutes (default) + websocket_allowed_origins=['https://example.com'], +) +``` + +- **`websocket_callbacks`** - Enable WebSocket for all callbacks (default: `False`) +- **`websocket_inactivity_timeout`** - Close WebSocket after period of inactivity in milliseconds (default: `300000` = 5 minutes). Heartbeats do not count as activity. Set to `0` to disable timeout. Connection automatically reconnects when needed. +- **`websocket_allowed_origins`** - List of allowed origins for WebSocket connections (security) + +### Architecture + +``` +┌─────────────────────────────────────────────────────────────────────────┐ +│ Browser Tab 1 Browser Tab 2 │ +│ ┌─────────────┐ ┌─────────────┐ │ +│ │ Renderer │ │ Renderer │ │ +│ └──────┬──────┘ └──────┬──────┘ │ +│ │ postMessage │ postMessage │ +│ └────────────┬───────────────────────┘ │ +│ ▼ │ +│ ┌─────────────────────┐ │ +│ │ SharedWorker │ (one per origin) │ +│ │ dash-ws-worker │ │ +│ └──────────┬──────────┘ │ +└────────────────────│────────────────────────────────────────────────────┘ + │ WebSocket + ▼ +┌─────────────────────────────────────────────────────────────────────────┐ +│ Server (FastAPI) │ +│ WebSocket Endpoint: /_dash-ws-callback │ +└─────────────────────────────────────────────────────────────────────────┘ +``` + +**Connection & Reconnection Flow:** +``` +Renderer SharedWorker Server + │ │ │ + │──[CONNECT]──────────────────>│ │ + │ │──[WebSocket Connect]──>│ + │<─[CONNECTED]─────────────────│<─[Connected]───────────│ + │ │ │ + │──[CALLBACK_REQUEST]─────────>│──[callback request]───>│ + │<─[CALLBACK_RESPONSE]─────────│<─[callback response]───│ + │ │ │ + │ (inactivity) │ (heartbeat check) │ + │ │──[close 4001]─────────>│ + │<─[DISCONNECTED]──────────────│ │ + │ │ │ + │──[CALLBACK_REQUEST]─────────>│──[reconnect + send]───>│ + │<─[CALLBACK_RESPONSE]─────────│<─[response]────────────│ +``` + +- **SharedWorker**: Single WebSocket connection shared across browser tabs +- **Heartbeat**: Periodic ping/pong to detect dead connections (30s interval) +- **Inactivity timeout**: Closes connection after no actual callback activity (not heartbeats) +- **Auto-reconnect**: Reconnects automatically when a callback is triggered after timeout + +### Long-Running Callbacks with set_props/get_props + +WebSocket callbacks can stream updates to the client during execution using `set_props()` and read current component values using `ctx.websocket`: + +```python +import asyncio +from dash import callback, Output, Input, set_props, ctx + +@callback( + Output('result', 'children'), + Input('start-btn', 'n_clicks'), + prevent_initial_call=True +) +async def long_running_task(n_clicks): + ws = ctx.websocket + if not ws: + return "WebSocket not available" + + # Stream progress updates to the client + for i in range(100): + await asyncio.sleep(0.1) + set_props('progress-bar', {'value': i + 1}) + set_props('status', {'children': f'Processing step {i + 1}/100...'}) + + # Read current value from another component + current_value = await ws.get_prop('input-field', 'value') + + return f"Completed! Input was: {current_value}" +``` + +**API:** +- `set_props(component_id, props_dict)` - Stream prop updates immediately to client +- `ctx.websocket` - Get WebSocket interface (returns `None` if not in WS context) +- `await ws.get_prop(component_id, prop_name)` - Read current prop value from client +- `await ws.set_prop(component_id, prop_name, value)` - Set single prop (async version) +- `await ws.close(code, reason)` - Close the WebSocket connection + +### Connection Hooks + +Use hooks to validate connections and messages: + +```python +from dash import Dash, hooks + +@hooks.websocket_connect() +async def validate_connection(websocket): + """Validate WebSocket connection before accepting.""" + session_id = websocket.cookies.get("session_id") + if not session_id: + return (4001, "No session cookie") + if not await is_valid_session(session_id): + return (4002, "Invalid session") + return True # Allow connection + +@hooks.websocket_message() +async def validate_message(websocket, message): + """Validate each WebSocket message.""" + session_id = websocket.cookies.get("session_id") + if not await is_session_active(session_id): + return (4002, "Session expired") + return True # Allow message +``` + +**Hook Return Values:** +- `True` (or truthy) - Allow connection/message +- `False` - Reject with default code (4001) +- `(code, reason)` - Reject with custom close code and reason + +### Key Files + +- `dash/dash.py` - WebSocket config in `_generate_config()` +- `dash/dash-renderer/src/utils/workerClient.ts` - Browser-side SharedWorker client +- `@plotly/dash-websocket-worker/src/WebSocketManager.ts` - WebSocket connection management +- `@plotly/dash-websocket-worker/src/worker.ts` - SharedWorker entry point +- `dash/backends/_fastapi.py` - Server-side WebSocket handler + ## Security ### XSS Protection diff --git a/.ai/COMMANDS.md b/.ai/COMMANDS.md index 4c155d87b7..9671a6c6cb 100644 --- a/.ai/COMMANDS.md +++ b/.ai/COMMANDS.md @@ -14,6 +14,19 @@ pip install -e .[ci,dev,testing,celery,diskcache] npm ci ``` +### Optional Backend Dependencies + +```bash +# For Quart backend (ASGI async) +pip install dash[quart] + +# For FastAPI backend (ASGI async) +pip install dash[fastapi] + +# For async callbacks with Flask +pip install dash[async] +``` + ## Building ```bash diff --git a/.github/workflows/testing.yml b/.github/workflows/testing.yml index 315499f2b0..895a0bd829 100644 --- a/.github/workflows/testing.yml +++ b/.github/workflows/testing.yml @@ -16,8 +16,10 @@ jobs: outputs: table_paths_changed: ${{ steps.filter.outputs.table_related_paths }} background_cb_changed: ${{ steps.filter.outputs.background_paths }} + backend_cb_changed: ${{ steps.filter.outputs.backend_paths }} dcc_paths_changed: ${{ steps.filter.outputs.dcc_related_paths }} html_paths_changed: ${{ steps.filter.outputs.html_related_paths }} + websocket_changed: ${{ steps.filter.outputs.websocket_paths }} steps: - name: Checkout repository uses: actions/checkout@v4 @@ -44,6 +46,21 @@ jobs: - 'tests/background_callback/**' - 'tests/async_tests/**' - 'requirements/**' + backend_paths: + - 'dash/backends/**' + - 'tests/backend_tests/**' + websocket_paths: + - 'dash/backends/_fastapi.py' + - 'dash/backends/_quart.py' + - 'dash/backends/base_server.py' + - 'dash/_callback.py' + - 'dash/_callback_context.py' + - 'dash/_hooks.py' + - 'dash/dash.py' + - '@dash-websocket-worker/**' + - 'dash/dash-renderer/src/**' + - 'tests/websocket/**' + - 'requirements/**' lint-unit: name: Lint & Unit Tests (Python ${{ matrix.python-version }}) @@ -329,6 +346,80 @@ jobs: path: bgtests/test-reports/ retention-days: 7 + backend-tests: + name: Run Backend Callback Tests (Python ${{ matrix.python-version }}) + needs: [build, changes_filter] + if: | + (github.event_name == 'push' && (github.ref == 'refs/heads/master' || github.ref == 'refs/heads/dev')) || + needs.changes_filter.outputs.backend_cb_changed == 'true' + timeout-minutes: 30 + runs-on: ubuntu-latest + strategy: + fail-fast: false + matrix: + python-version: ["3.9", "3.12"] + + services: + redis: + image: redis:6 + ports: + - 6379:6379 + options: >- + --health-cmd "redis-cli ping" + --health-interval 10s + --health-timeout 5s + --health-retries 5 + + env: + REDIS_URL: redis://localhost:6379 + steps: + - name: Checkout repository + uses: actions/checkout@v4 + + - name: Set up Node.js + uses: actions/setup-node@v4 + with: + node-version: '24' + cache: 'npm' + + - name: Install Node.js dependencies + run: npm ci + + - name: Set up Python ${{ matrix.python-version }} + uses: actions/setup-python@v5 + with: + python-version: ${{ matrix.python-version }} + cache: 'pip' + cache-dependency-path: requirements/*.txt + + - name: Download built Dash packages + uses: actions/download-artifact@v4 + with: + name: dash-packages + path: packages/ + + - name: Install Dash packages + run: | + python -m pip install --upgrade pip wheel + python -m pip install "setuptools<80.0.0" + find packages -name dash-*.whl -print -exec sh -c 'pip install "{}[async,ci,testing,dev,celery,diskcache,fastapi,quart]"' \; + + - name: Setup Chrome and ChromeDriver + uses: browser-actions/setup-chrome@v1 + with: + chrome-version: stable + + - name: Build/Setup test components + run: npm run setup-tests.py + + - name: Run Backend Callback Tests + run: | + mkdir bgtests + cp -r tests bgtests/tests + cd bgtests + touch __init__.py + pytest --headless --nopercyfinalize tests/backend_tests -v -s + table-unit: name: Table Unit/Lint Tests (Python ${{ matrix.python-version }}) needs: [build, changes_filter] @@ -451,6 +542,67 @@ jobs: path: components/dash-table/test-reports/ retention-days: 7 + websocket-tests: + name: WebSocket Tests (Python ${{ matrix.python-version }}) + needs: [build, changes_filter] + if: | + (github.event_name == 'push' && (github.ref == 'refs/heads/master' || github.ref == 'refs/heads/dev')) || + needs.changes_filter.outputs.websocket_changed == 'true' + timeout-minutes: 30 + runs-on: ubuntu-latest + strategy: + fail-fast: false + matrix: + python-version: ["3.9", "3.12"] + + steps: + - name: Checkout repository + uses: actions/checkout@v4 + + - name: Set up Node.js + uses: actions/setup-node@v4 + with: + node-version: '24' + cache: 'npm' + + - name: Install Node.js dependencies + run: npm ci + + - name: Set up Python ${{ matrix.python-version }} + uses: actions/setup-python@v5 + with: + python-version: ${{ matrix.python-version }} + cache: 'pip' + cache-dependency-path: requirements/*.txt + + - name: Download built Dash packages + uses: actions/download-artifact@v4 + with: + name: dash-packages + path: packages/ + + - name: Install Dash packages + run: | + python -m pip install --upgrade pip wheel + python -m pip install "setuptools<80.0.0" + find packages -name dash-*.whl -print -exec sh -c 'pip install "{}[ci,testing,dev,fastapi,quart]"' \; + + - name: Setup Chrome and ChromeDriver + uses: browser-actions/setup-chrome@v1 + with: + chrome-version: stable + + - name: Build/Setup test components + run: npm run setup-tests.py + + - name: Run WebSocket tests + run: | + mkdir wstests + cp -r tests wstests/tests + cd wstests + touch __init__.py + pytest --headless --nopercyfinalize tests/websocket -v -s + test-main: name: Main Dash Tests (Python ${{ matrix.python-version }}, Group ${{ matrix.test-group }}) needs: build diff --git a/.gitignore b/.gitignore index 89029448fe..06e855e2dc 100644 --- a/.gitignore +++ b/.gitignore @@ -93,3 +93,4 @@ packages/ !components/dash-core-components/tests/integration/upload/upload-assets/upft001.csv !components/dash-table/tests/assets/*.csv !components/dash-table/tests/selenium/assets/*.csv +dash_config.json diff --git a/@plotly/dash-websocket-worker/README.md b/@plotly/dash-websocket-worker/README.md new file mode 100644 index 0000000000..64e37a1987 --- /dev/null +++ b/@plotly/dash-websocket-worker/README.md @@ -0,0 +1,3 @@ +# Dash websocket worker + +Worker for websocket based callbacks. diff --git a/@plotly/dash-websocket-worker/package.json b/@plotly/dash-websocket-worker/package.json new file mode 100644 index 0000000000..619a842380 --- /dev/null +++ b/@plotly/dash-websocket-worker/package.json @@ -0,0 +1,29 @@ +{ + "name": "@plotly/dash-websocket-worker", + "version": "1.0.0", + "description": "SharedWorker for WebSocket-based Dash callbacks", + "main": "dist/index.js", + "types": "dist/index.d.ts", + "scripts": { + "build": "webpack --mode production", + "build:dev": "webpack --mode development", + "watch": "webpack --mode development --watch", + "clean": "rm -rf dist" + }, + "files": [ + "dist" + ], + "keywords": [ + "dash", + "websocket", + "sharedworker" + ], + "author": "Plotly", + "license": "MIT", + "devDependencies": { + "typescript": "^5.0.0", + "webpack": "^5.0.0", + "webpack-cli": "^5.0.0", + "ts-loader": "^9.0.0" + } +} diff --git a/@plotly/dash-websocket-worker/src/MessageRouter.ts b/@plotly/dash-websocket-worker/src/MessageRouter.ts new file mode 100644 index 0000000000..68a9f4bfc2 --- /dev/null +++ b/@plotly/dash-websocket-worker/src/MessageRouter.ts @@ -0,0 +1,207 @@ +import { + WorkerMessageType, + WorkerMessage, + CallbackRequestMessage, + GetPropsResponseMessage, + SetPropsMessage, + GetPropsRequestMessage, + CallbackResponseMessage +} from './types'; + +/** + * Routes messages between renderers (via MessagePorts) and the WebSocket server. + */ +export class MessageRouter { + /** Map of renderer IDs to their MessagePorts */ + private renderers: Map = new Map(); + + /** Callback to send messages to the WebSocket server */ + public sendToServer: ((message: unknown) => void) | null = null; + + /** + * Register a renderer with its MessagePort. + * @param rendererId Unique identifier for the renderer + * @param port The MessagePort for communication + */ + public registerRenderer(rendererId: string, port: MessagePort): void { + this.renderers.set(rendererId, port); + } + + /** + * Unregister a renderer. + * @param rendererId The renderer to unregister + */ + public unregisterRenderer(rendererId: string): void { + this.renderers.delete(rendererId); + } + + /** + * Get the number of connected renderers. + */ + public get rendererCount(): number { + return this.renderers.size; + } + + /** + * Handle a message from a renderer. + * @param rendererId The ID of the renderer that sent the message + * @param message The message from the renderer + */ + public handleRendererMessage(rendererId: string, message: WorkerMessage): void { + switch (message.type) { + case WorkerMessageType.CALLBACK_REQUEST: + this.forwardCallbackRequest(rendererId, message as CallbackRequestMessage); + break; + + case WorkerMessageType.GET_PROPS_RESPONSE: + this.forwardGetPropsResponse(rendererId, message as GetPropsResponseMessage); + break; + + default: + console.warn(`Unknown message type from renderer: ${message.type}`); + } + } + + /** + * Handle a message from the WebSocket server. + * @param message The message from the server + */ + public handleServerMessage(message: unknown): void { + const msg = message as WorkerMessage; + const rendererId = msg.rendererId; + + switch (msg.type) { + case WorkerMessageType.CALLBACK_RESPONSE: + this.forwardToRenderer(rendererId, msg as CallbackResponseMessage); + break; + + case WorkerMessageType.SET_PROPS: + this.forwardSetProps(rendererId, msg as SetPropsMessage); + break; + + case WorkerMessageType.GET_PROPS_REQUEST: + this.forwardGetPropsRequest(rendererId, msg as GetPropsRequestMessage); + break; + + case WorkerMessageType.ERROR: + this.forwardToRenderer(rendererId, msg); + break; + + default: + console.warn(`Unknown message type from server: ${msg.type}`); + } + } + + /** + * Send a message to all connected renderers. + * @param message The message to broadcast + */ + public broadcastToRenderers(message: WorkerMessage): void { + for (const [rendererId, port] of this.renderers) { + try { + port.postMessage(message); + } catch (error) { + // Port may be closed if tab was closed + console.warn(`Failed to send to renderer ${rendererId}, removing`); + this.renderers.delete(rendererId); + } + } + } + + /** + * Send a connected notification to a specific renderer. + * @param rendererId The renderer to notify + */ + public notifyConnected(rendererId: string): void { + const port = this.renderers.get(rendererId); + if (port) { + try { + port.postMessage({ + type: WorkerMessageType.CONNECTED, + rendererId + }); + } catch (error) { + console.warn(`Failed to notify renderer ${rendererId}, removing`); + this.renderers.delete(rendererId); + } + } + } + + /** + * Send a disconnected notification to all renderers. + * @param reason Optional reason for disconnection + */ + public notifyDisconnected(reason?: string): void { + this.broadcastToRenderers({ + type: WorkerMessageType.DISCONNECTED, + rendererId: '', + payload: { reason } + }); + } + + /** + * Send an error notification to a specific renderer. + * @param rendererId The renderer to notify + * @param message Error message + * @param code Optional error code + */ + public notifyError(rendererId: string, message: string, code?: string): void { + const port = this.renderers.get(rendererId); + if (port) { + try { + port.postMessage({ + type: WorkerMessageType.ERROR, + rendererId, + payload: { message, code } + }); + } catch (error) { + console.warn(`Failed to send error to renderer ${rendererId}, removing`); + this.renderers.delete(rendererId); + } + } + } + + private forwardCallbackRequest(rendererId: string, message: CallbackRequestMessage): void { + if (this.sendToServer) { + this.sendToServer({ + type: WorkerMessageType.CALLBACK_REQUEST, + rendererId, + requestId: message.requestId, + payload: message.payload + }); + } + } + + private forwardGetPropsResponse(rendererId: string, message: GetPropsResponseMessage): void { + if (this.sendToServer) { + this.sendToServer({ + type: WorkerMessageType.GET_PROPS_RESPONSE, + rendererId, + requestId: message.requestId, + payload: message.payload + }); + } + } + + private forwardToRenderer(rendererId: string, message: WorkerMessage): void { + const port = this.renderers.get(rendererId); + if (port) { + try { + port.postMessage(message); + } catch (error) { + console.warn(`Failed to forward to renderer ${rendererId}, removing`); + this.renderers.delete(rendererId); + } + } else { + console.warn(`Renderer ${rendererId} not found for message`); + } + } + + private forwardSetProps(rendererId: string, message: SetPropsMessage): void { + this.forwardToRenderer(rendererId, message); + } + + private forwardGetPropsRequest(rendererId: string, message: GetPropsRequestMessage): void { + this.forwardToRenderer(rendererId, message); + } +} diff --git a/@plotly/dash-websocket-worker/src/WebSocketManager.ts b/@plotly/dash-websocket-worker/src/WebSocketManager.ts new file mode 100644 index 0000000000..2d32c7e8ca --- /dev/null +++ b/@plotly/dash-websocket-worker/src/WebSocketManager.ts @@ -0,0 +1,304 @@ +/** + * Configuration options for WebSocket connection. + */ +interface WebSocketConfig { + /** Maximum number of reconnection attempts */ + maxRetries: number; + /** Initial delay between reconnection attempts (ms) */ + initialRetryDelay: number; + /** Maximum delay between reconnection attempts (ms) */ + maxRetryDelay: number; + /** Heartbeat interval (ms) */ + heartbeatInterval: number; + /** Heartbeat timeout (ms) */ + heartbeatTimeout: number; + /** Inactivity timeout (ms) - 0 to disable */ + inactivityTimeout: number; +} + +const DEFAULT_CONFIG: WebSocketConfig = { + maxRetries: 10, + initialRetryDelay: 1000, + maxRetryDelay: 30000, + heartbeatInterval: 30000, + heartbeatTimeout: 10000, + inactivityTimeout: 300000 // 5 minutes default +}; + +/** + * Manages WebSocket connection with automatic reconnection and heartbeat. + */ +export class WebSocketManager { + private ws: WebSocket | null = null; + private serverUrl: string | null = null; + private config: WebSocketConfig; + private retryCount = 0; + private retryTimeout: ReturnType | null = null; + private heartbeatInterval: ReturnType | null = null; + private heartbeatTimeout: ReturnType | null = null; + private lastActivityTime: number = Date.now(); + private messageQueue: string[] = []; + private isConnecting = false; + + /** Callback when connection is established */ + public onOpen: (() => void) | null = null; + /** Callback when connection is closed */ + public onClose: ((reason?: string) => void) | null = null; + /** Callback when a message is received */ + public onMessage: ((data: unknown) => void) | null = null; + /** Callback when an error occurs */ + public onError: ((error: Error) => void) | null = null; + + constructor(config: Partial = {}) { + this.config = { ...DEFAULT_CONFIG, ...config }; + } + + /** + * Update configuration options. + * Only updates the provided options, keeping others unchanged. + * @param config Partial configuration to merge + */ + public setConfig(config: Partial): void { + this.config = { ...this.config, ...config }; + } + + /** + * Connect to the WebSocket server. + * @param serverUrl The WebSocket server URL + */ + public connect(serverUrl: string): void { + if (this.ws && this.ws.readyState === WebSocket.OPEN) { + // Already connected + return; + } + + if (this.isConnecting) { + // Connection in progress + return; + } + + this.serverUrl = serverUrl; + this.isConnecting = true; + // Reset retry count since this is an explicit connect request + // (e.g., from hot reload reconnection) + this.retryCount = 0; + this.createConnection(); + } + + /** + * Disconnect from the WebSocket server. + */ + public disconnect(): void { + this.cleanup(); + if (this.ws && this.ws.readyState === WebSocket.OPEN) { + this.ws.close(1000, 'Client disconnect'); + } + this.ws = null; + this.serverUrl = null; + this.retryCount = 0; + } + + /** + * Send a message through the WebSocket connection. + * If not connected, queues the message and triggers reconnection. + * @param message The message to send + */ + public send(message: unknown): void { + const data = JSON.stringify(message); + + // Track activity for non-heartbeat messages + const msgObj = message as { type?: string }; + if (msgObj.type !== 'heartbeat') { + this.lastActivityTime = Date.now(); + } + + if (this.ws && this.ws.readyState === WebSocket.OPEN) { + this.ws.send(data); + } else { + // Queue message for when connection is established + this.messageQueue.push(data); + + // Trigger reconnect if we have a server URL but aren't connected/connecting + if (this.serverUrl && !this.isConnecting) { + this.isConnecting = true; + // Reset retry count since this is user-initiated activity + this.retryCount = 0; + this.createConnection(); + } + } + } + + /** + * Check if the WebSocket is currently connected. + */ + public get isConnected(): boolean { + return this.ws !== null && this.ws.readyState === WebSocket.OPEN; + } + + private createConnection(): void { + if (!this.serverUrl) { + return; + } + + try { + this.ws = new WebSocket(this.serverUrl); + this.ws.onopen = this.handleOpen.bind(this); + this.ws.onclose = this.handleClose.bind(this); + this.ws.onmessage = this.handleMessage.bind(this); + this.ws.onerror = this.handleError.bind(this); + } catch (error) { + this.isConnecting = false; + this.scheduleReconnect(); + } + } + + private handleOpen(): void { + this.isConnecting = false; + this.retryCount = 0; + this.lastActivityTime = Date.now(); + + // Flush queued messages + while (this.messageQueue.length > 0) { + const message = this.messageQueue.shift(); + if (message && this.ws) { + this.ws.send(message); + } + } + + // Start heartbeat (also handles inactivity check) + this.startHeartbeat(); + + if (this.onOpen) { + this.onOpen(); + } + } + + private handleClose(event: CloseEvent): void { + this.isConnecting = false; + this.cleanup(); + + const reason = event.reason || 'Connection closed'; + + if (this.onClose) { + this.onClose(reason); + } + + // Only reconnect if: + // - We haven't explicitly disconnected (code 1000) + // - It's not an inactivity timeout (code 4001) + if (this.serverUrl && event.code !== 1000 && event.code !== 4001) { + this.scheduleReconnect(); + } + } + + private handleMessage(event: MessageEvent): void { + try { + const data = JSON.parse(event.data); + + // Handle heartbeat acknowledgment - does NOT count as activity + if (data.type === 'heartbeat_ack') { + this.clearHeartbeatTimeout(); + return; + } + + // Track activity for actual callback messages + this.lastActivityTime = Date.now(); + + if (this.onMessage) { + this.onMessage(data); + } + } catch (error) { + if (this.onError) { + this.onError(new Error('Failed to parse message')); + } + } + } + + private handleError(): void { + this.isConnecting = false; + // WebSocket error events don't contain useful information + // The close event will follow with more details + } + + private scheduleReconnect(): void { + if (this.retryTimeout) { + clearTimeout(this.retryTimeout); + } + + if (this.retryCount >= this.config.maxRetries) { + if (this.onError) { + this.onError(new Error('Max reconnection attempts reached')); + } + return; + } + + // Exponential backoff with jitter + const delay = Math.min( + this.config.initialRetryDelay * Math.pow(2, this.retryCount) + + Math.random() * 1000, + this.config.maxRetryDelay + ); + + this.retryCount++; + + this.retryTimeout = setTimeout(() => { + this.createConnection(); + }, delay); + } + + private startHeartbeat(): void { + this.stopHeartbeat(); + + this.heartbeatInterval = setInterval(() => { + if (!this.ws || this.ws.readyState !== WebSocket.OPEN) { + return; + } + + // Check for inactivity timeout + if (this.config.inactivityTimeout > 0) { + const timeSinceActivity = Date.now() - this.lastActivityTime; + if (timeSinceActivity >= this.config.inactivityTimeout) { + this.ws.close(4001, 'Inactivity timeout'); + return; + } + } + + this.ws.send(JSON.stringify({ type: 'heartbeat' })); + this.setHeartbeatTimeout(); + }, this.config.heartbeatInterval); + } + + private stopHeartbeat(): void { + if (this.heartbeatInterval) { + clearInterval(this.heartbeatInterval); + this.heartbeatInterval = null; + } + this.clearHeartbeatTimeout(); + } + + private setHeartbeatTimeout(): void { + this.clearHeartbeatTimeout(); + + this.heartbeatTimeout = setTimeout(() => { + // Heartbeat timeout - connection may be dead + if (this.ws && this.ws.readyState === WebSocket.OPEN) { + this.ws.close(4000, 'Heartbeat timeout'); + } + }, this.config.heartbeatTimeout); + } + + private clearHeartbeatTimeout(): void { + if (this.heartbeatTimeout) { + clearTimeout(this.heartbeatTimeout); + this.heartbeatTimeout = null; + } + } + + private cleanup(): void { + this.stopHeartbeat(); + if (this.retryTimeout) { + clearTimeout(this.retryTimeout); + this.retryTimeout = null; + } + } +} diff --git a/@plotly/dash-websocket-worker/src/index.ts b/@plotly/dash-websocket-worker/src/index.ts new file mode 100644 index 0000000000..e21b382d41 --- /dev/null +++ b/@plotly/dash-websocket-worker/src/index.ts @@ -0,0 +1,18 @@ +/** + * Dash WebSocket Worker Package + * + * Provides a SharedWorker for WebSocket-based Dash callbacks. + */ + +export * from './types'; + +/** + * Get the URL for the WebSocket worker script. + * This should be used to instantiate the SharedWorker. + * + * @param baseUrl Base URL where the worker script is served + * @returns Full URL to the worker script + */ +export function getWorkerUrl(baseUrl: string): string { + return `${baseUrl}/dash-ws-worker.js`; +} diff --git a/@plotly/dash-websocket-worker/src/types.ts b/@plotly/dash-websocket-worker/src/types.ts new file mode 100644 index 0000000000..fac282b5e1 --- /dev/null +++ b/@plotly/dash-websocket-worker/src/types.ts @@ -0,0 +1,151 @@ +/** + * Message types for communication between renderer and worker. + */ +export enum WorkerMessageType { + // Renderer -> Worker + CONNECT = 'connect', + DISCONNECT = 'disconnect', + CALLBACK_REQUEST = 'callback_request', + GET_PROPS_RESPONSE = 'get_props_response', + + // Worker -> Renderer + CONNECTED = 'connected', + DISCONNECTED = 'disconnected', + CALLBACK_RESPONSE = 'callback_response', + SET_PROPS = 'set_props', + GET_PROPS_REQUEST = 'get_props_request', + ERROR = 'error' +} + +/** + * Base message structure for worker communication. + */ +export interface WorkerMessage { + type: WorkerMessageType; + rendererId: string; + requestId?: string; + payload?: unknown; +} + +/** + * Message from renderer to worker requesting connection. + */ +export interface ConnectMessage extends WorkerMessage { + type: WorkerMessageType.CONNECT; + payload: { + serverUrl: string; + inactivityTimeout?: number; + }; +} + +/** + * Message from renderer to worker requesting disconnect. + */ +export interface DisconnectMessage extends WorkerMessage { + type: WorkerMessageType.DISCONNECT; +} + +/** + * Callback request payload structure. + */ +export interface CallbackPayload { + output: string; + outputs: unknown[]; + inputs: unknown[]; + state?: unknown[]; + changedPropIds: string[]; + parsedChangedPropsIds?: string[]; +} + +/** + * Message from renderer to worker with callback request. + */ +export interface CallbackRequestMessage extends WorkerMessage { + type: WorkerMessageType.CALLBACK_REQUEST; + payload: CallbackPayload; +} + +/** + * Message from worker to renderer with callback response. + */ +export interface CallbackResponseMessage extends WorkerMessage { + type: WorkerMessageType.CALLBACK_RESPONSE; + payload: { + status: 'ok' | 'prevent_update' | 'error'; + data?: Record; + message?: string; + }; +} + +/** + * Message from worker to renderer to set component props. + */ +export interface SetPropsMessage extends WorkerMessage { + type: WorkerMessageType.SET_PROPS; + payload: { + componentId: string; + props: Record; + }; +} + +/** + * Message from worker to renderer requesting prop values. + */ +export interface GetPropsRequestMessage extends WorkerMessage { + type: WorkerMessageType.GET_PROPS_REQUEST; + payload: { + componentId: string; + properties: string[]; + }; +} + +/** + * Message from renderer to worker with prop values. + */ +export interface GetPropsResponseMessage extends WorkerMessage { + type: WorkerMessageType.GET_PROPS_RESPONSE; + payload: Record; +} + +/** + * Error message from worker to renderer. + */ +export interface ErrorMessage extends WorkerMessage { + type: WorkerMessageType.ERROR; + payload: { + message: string; + code?: string; + }; +} + +/** + * Connected confirmation message from worker to renderer. + */ +export interface ConnectedMessage extends WorkerMessage { + type: WorkerMessageType.CONNECTED; +} + +/** + * Disconnected notification message from worker to renderer. + */ +export interface DisconnectedMessage extends WorkerMessage { + type: WorkerMessageType.DISCONNECTED; + payload?: { + reason?: string; + }; +} + +/** + * Union type of all possible worker messages. + */ +export type AnyWorkerMessage = + | ConnectMessage + | DisconnectMessage + | CallbackRequestMessage + | CallbackResponseMessage + | SetPropsMessage + | GetPropsRequestMessage + | GetPropsResponseMessage + | ErrorMessage + | ConnectedMessage + | DisconnectedMessage; diff --git a/@plotly/dash-websocket-worker/src/worker.ts b/@plotly/dash-websocket-worker/src/worker.ts new file mode 100644 index 0000000000..0e68f0b09a --- /dev/null +++ b/@plotly/dash-websocket-worker/src/worker.ts @@ -0,0 +1,135 @@ +/** + * Dash WebSocket Worker + * + * A SharedWorker that maintains a single WebSocket connection to the Dash server + * and routes messages between multiple renderer instances (browser tabs). + */ + +import { WebSocketManager } from './WebSocketManager'; +import { MessageRouter } from './MessageRouter'; +import { + WorkerMessageType, + WorkerMessage, + ConnectMessage +} from './types'; + +// SharedWorker global scope +declare const self: SharedWorkerGlobalScope; + +/** WebSocket connection manager */ +const wsManager = new WebSocketManager(); + +/** Message router for renderers */ +const router = new MessageRouter(); + +/** Current server URL */ +let serverUrl: string | null = null; + +/** + * Set up WebSocket manager callbacks. + */ +wsManager.onOpen = () => { + console.log('[DashWSWorker] WebSocket connected'); + // Notify all renderers that connection is established + for (const rendererId of getRendererIds()) { + router.notifyConnected(rendererId); + } +}; + +wsManager.onClose = (reason?: string) => { + console.log(`[DashWSWorker] WebSocket closed: ${reason}`); + router.notifyDisconnected(reason); +}; + +wsManager.onMessage = (data: unknown) => { + router.handleServerMessage(data); +}; + +wsManager.onError = (error: Error) => { + console.error('[DashWSWorker] WebSocket error:', error.message); +}; + +/** + * Set up router to send messages to WebSocket. + */ +router.sendToServer = (message: unknown) => { + wsManager.send(message); +}; + +// Track renderer IDs separately for iteration +const rendererIds = new Set(); + +/** + * Get all registered renderer IDs. + */ +function getRendererIds(): string[] { + return Array.from(rendererIds); +} + +/** + * Handle new connection from a renderer (browser tab). + */ +self.onconnect = (event: MessageEvent) => { + const port = event.ports[0]; + + port.onmessage = (e: MessageEvent) => { + const message = e.data as WorkerMessage; + + switch (message.type) { + case WorkerMessageType.CONNECT: { + const connectMsg = message as ConnectMessage; + const rendererId = connectMsg.rendererId; + const newServerUrl = connectMsg.payload.serverUrl; + const inactivityTimeout = connectMsg.payload.inactivityTimeout; + + // Register the renderer + router.registerRenderer(rendererId, port); + rendererIds.add(rendererId); + + console.log(`[DashWSWorker] Renderer ${rendererId} connected, inactivityTimeout: ${inactivityTimeout}`); + + // Update inactivity timeout if provided + if (typeof inactivityTimeout === 'number') { + wsManager.setConfig({ inactivityTimeout }); + } + + // Connect to server if not already connected + if (!wsManager.isConnected) { + if (serverUrl !== newServerUrl) { + serverUrl = newServerUrl; + } + wsManager.connect(serverUrl); + } else { + // Already connected, notify the renderer + router.notifyConnected(rendererId); + } + break; + } + + case WorkerMessageType.DISCONNECT: { + const rendererId = message.rendererId; + router.unregisterRenderer(rendererId); + rendererIds.delete(rendererId); + + console.log(`[DashWSWorker] Renderer ${rendererId} disconnected`); + + // If no more renderers, disconnect from server + if (router.rendererCount === 0) { + wsManager.disconnect(); + serverUrl = null; + console.log('[DashWSWorker] All renderers disconnected, closing WebSocket'); + } + break; + } + + default: + // Forward other messages through the router + router.handleRendererMessage(message.rendererId, message); + } + }; + + port.start(); +}; + +// Log worker startup +console.log('[DashWSWorker] SharedWorker initialized'); diff --git a/@plotly/dash-websocket-worker/tsconfig.json b/@plotly/dash-websocket-worker/tsconfig.json new file mode 100644 index 0000000000..0254db7f91 --- /dev/null +++ b/@plotly/dash-websocket-worker/tsconfig.json @@ -0,0 +1,20 @@ +{ + "compilerOptions": { + "target": "ES2020", + "module": "ESNext", + "lib": ["ES2020", "WebWorker"], + "declaration": true, + "declarationMap": true, + "sourceMap": true, + "outDir": "./dist", + "rootDir": "./src", + "strict": true, + "moduleResolution": "node", + "esModuleInterop": true, + "skipLibCheck": true, + "forceConsistentCasingInFileNames": true, + "resolveJsonModule": true + }, + "include": ["src/**/*"], + "exclude": ["node_modules", "dist"] +} diff --git a/@plotly/dash-websocket-worker/webpack.config.js b/@plotly/dash-websocket-worker/webpack.config.js new file mode 100644 index 0000000000..efe7b59e89 --- /dev/null +++ b/@plotly/dash-websocket-worker/webpack.config.js @@ -0,0 +1,25 @@ +const path = require('path'); + +// This config is for standalone development/testing of the worker. +// The production build is handled by dash-renderer's webpack config. +module.exports = { + entry: './src/worker.ts', + output: { + filename: 'dash-ws-worker.js', + path: path.resolve(__dirname, 'dist'), + clean: true + }, + resolve: { + extensions: ['.ts', '.js'] + }, + module: { + rules: [ + { + test: /\.ts$/, + use: 'ts-loader', + exclude: /node_modules/ + } + ] + }, + target: 'webworker' +}; diff --git a/CHANGELOG.md b/CHANGELOG.md index 43afd88fe7..74925de692 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -18,6 +18,23 @@ This project adheres to [Semantic Versioning](https://semver.org/). - [#3738](https://github.com/plotly/dash/pull/3738) Add missing `stacklevel=2` to `warnings.warn()` calls so warnings report the caller's location instead of internal Dash source lines - [#3740](https://github.com/plotly/dash/pull/3740) Fix cannot tab into dropdowns in Safari - [#2462](https://github.com/plotly/dash/issues/2462) Allow `MATCH` in `Input`/`State` when the callback's `Output` has no wildcards (fixed-id Output, no Output, or `ALL`-only wildcard Output). `ALLSMALLER` still requires a corresponding `MATCH` in an Output. +- [#3759](https://github.com/plotly/dash/pull/3759) Fix the issue where `Patch` objects cannot be updated via `set_props()` in `websocket` callback. Fix [#3742](https://github.com/plotly/dash/issues/3742) + +## [4.2.0rc3] - 2026-05-12 + +- [#3771](https://github.com/plotly/dash/pull/3771) Add persistent callbacks and no inputs/no outputs callback support. +- Rename ctx.get_websocket to ctx.websocket + +## [4.2.0rc2] - 2026-05-01 + +## Fixed +- [#3759](https://github.com/plotly/dash/pull/3759) Fix the error when using `set_props()` to update component-type properties in the `websocket` callback. +- Add threadpool for running websocket callbacks. + +## [4.2.0rc1] - 2026-04-13 + +## Added +- [#3742](https://github.com/plotly/dash/pull/3742) Add websocket callbacks to fastapi and quart backends. ## [4.1.0] - 2026-03-23 @@ -34,6 +51,22 @@ This project adheres to [Semantic Versioning](https://semver.org/). - [#3609](https://github.com/plotly/dash/pull/3609) Add backward compat alias for _Wildcard - [#3672](https://github.com/plotly/dash/pull/3672) Improve browser performance when app contains a large number of pattern matching callback callbacks. Exposes an api endpoint to fetch the latest computeGraph call. +# [4.2.0rc0] - 2026-04-13 + +## Fixed + +- Fix websocket used in the same FastAPI server. Fix [#3636](https://github.com/plotly/dash/issues/3636) +- Fix FastAPI url paths order. Fix [3667](https://github.com/plotly/dash/issues/3667) + +# [4.1.0rc0] - 2026-02-23 + +## Added + +- Add support for multiple backend implementation beside flask such as fastapi and quart (both included). + - Add `app = Dash(backend="flask" | "fastapi" | "quart" | CustomBackendImpl)` parameter to automatically setup + - An existing `Fastapi`, `Quart` or `Flask` instance can also be given as `app = Dash(server=Fastapi())` to automatically setup a dash app on the server. + - Install fastapi dependencies with `pip install dash[fastapi]` or quart with `pip install dash[quart]`, flask is still included by default. + - Custom backend implementation can be added as a subclass of `dash.backends.base_server.BaseDashServer` and response/request adapters. ## [4.0.0] - 2026-02-03 diff --git a/components/dash-core-components/tests/integration/input/conftest.py b/components/dash-core-components/tests/integration/input/conftest.py index c03087db1a..f612cfe462 100644 --- a/components/dash-core-components/tests/integration/input/conftest.py +++ b/components/dash-core-components/tests/integration/input/conftest.py @@ -2,7 +2,7 @@ from dash import Dash, Input, Output, dcc, html -@pytest.fixture(scope="module") +@pytest.fixture def ninput_app(): app = Dash(__name__) app.layout = html.Div( @@ -35,7 +35,7 @@ def render(fval, tval): yield app -@pytest.fixture(scope="module") +@pytest.fixture def input_range_app(): app = Dash(__name__) app.layout = html.Div( @@ -59,7 +59,7 @@ def range_out(val): yield app -@pytest.fixture(scope="module") +@pytest.fixture def debounce_text_app(): app = Dash(__name__) app.layout = html.Div( @@ -89,7 +89,7 @@ def render(slow_val, fast_val): yield app -@pytest.fixture(scope="module") +@pytest.fixture def debounce_number_app(): app = Dash(__name__) app.layout = html.Div( diff --git a/components/dash-core-components/tests/integration/input/test_number_input.py b/components/dash-core-components/tests/integration/input/test_number_input.py index 299cbecf93..5e42454c45 100644 --- a/components/dash-core-components/tests/integration/input/test_number_input.py +++ b/components/dash-core-components/tests/integration/input/test_number_input.py @@ -238,6 +238,7 @@ def update_output(val): def test_inni010_valid_numbers(dash_dcc, ninput_app): dash_dcc.start_server(ninput_app) + elem = dash_dcc.wait_for_element("#input_false") for num, op in ( ("1.0", lambda x: int(float(x))), # limitation of js/json ("10e10", lambda x: int(float(x))), @@ -245,7 +246,7 @@ def test_inni010_valid_numbers(dash_dcc, ninput_app): (str(sys.float_info.max), float), (str(sys.float_info.min), float), ): - elem = dash_dcc.find_element("#input_false") + elem = dash_dcc.wait_for_element("#input_false") elem.send_keys(num) assert dash_dcc.wait_for_text_to_equal( "#div_false", str(op(num)) diff --git a/dash/_callback.py b/dash/_callback.py index 3785df7166..f5f64970b0 100644 --- a/dash/_callback.py +++ b/dash/_callback.py @@ -1,11 +1,8 @@ +from typing import Callable, Optional, Any, List, Tuple, Union, Dict +from functools import wraps import collections import hashlib import inspect -from functools import wraps - -from typing import Callable, Optional, Any, List, Tuple, Union, Dict - -import flask from .dependencies import ( handle_callback_args, @@ -23,7 +20,7 @@ BackgroundCallbackError, ImportedInsideCallbackError, ) - +from ._get_app import get_app from ._grouping import ( flatten_grouping, make_grouping_by_index, @@ -38,10 +35,10 @@ clean_property_name, ) -from . import _validate from .background_callback.managers import BaseBackgroundCallbackManager from ._callback_context import context_value from ._no_update import NoUpdate +from . import _validate async def _async_invoke_callback( @@ -80,6 +77,8 @@ def callback( api_endpoint: Optional[str] = None, optional: Optional[bool] = False, hidden: Optional[bool] = None, + websocket: Optional[bool] = False, + persistent: Optional[bool] = False, **_kwargs, ) -> Callable[..., Any]: """ @@ -174,6 +173,10 @@ def callback( The endpoint is relative to the Dash app's base URL. Note that the endpoint will not appear in the list of registered callbacks in the Dash devtools. + :param persistent: + If True, this callback will not show the "Updating..." title while + running. Useful for persistent WebSocket callbacks that stay active + for long periods without requiring a loading indicator. """ background_spec: Any = None @@ -231,6 +234,8 @@ def callback( api_endpoint=api_endpoint, optional=optional, hidden=hidden, + websocket=websocket, + persistent=persistent, ) @@ -278,7 +283,9 @@ def insert_callback( no_output=False, optional=False, hidden=None, -): + websocket=False, + persistent=False, +) -> str: if prevent_initial_call is None: prevent_initial_call = config_prevent_initial_callbacks @@ -303,6 +310,8 @@ def insert_callback( "no_output": no_output, "optional": optional, "hidden": hidden, + "websocket": websocket, + "persistent": persistent, } if running: callback_spec["running"] = running @@ -318,6 +327,7 @@ def insert_callback( "manager": manager, "allow_dynamic_callbacks": dynamic_creator, "no_output": no_output, + "websocket": websocket, } callback_list.append(callback_spec) @@ -359,7 +369,7 @@ def _initialize_context(args, kwargs, inputs_state_indices, has_output, insert_o def _get_callback_manager( kwargs: dict, background: dict -) -> Union[BaseBackgroundCallbackManager, None]: +) -> BaseBackgroundCallbackManager: """Set up the background callback and manage jobs.""" callback_manager = background.get( "manager", kwargs.get("background_callback_manager", None) @@ -375,7 +385,8 @@ def _get_callback_manager( " and store results on redis.\n" ) - old_job = flask.request.args.getlist("oldJob") + adapter = get_app().backend.request_adapter() + old_job = adapter.args.getlist("oldJob") if hasattr(adapter.args, "getlist") else [] if old_job: for job in old_job: @@ -389,6 +400,8 @@ def _setup_background_callback( ): """Set up the background callback and manage jobs.""" callback_manager = _get_callback_manager(kwargs, background) + if not callback_manager: + return to_json({"error": "No background callback manager configured"}) progress_outputs = background.get("progress") @@ -396,14 +409,11 @@ def _setup_background_callback( cache_key = callback_manager.build_cache_key( func, - # Inputs provided as dict is kwargs. func_args if func_args else func_kwargs, background.get("cache_args_to_ignore", []), None if cache_ignore_triggered else callback_ctx.get("triggered_inputs", []), ) - job_fn = callback_manager.func_registry.get(background_key) - ctx_value = AttributeDict(**context_value.get()) ctx_value.ignore_register_page = True ctx_value.pop("background_callback_manager") @@ -435,7 +445,8 @@ def _setup_background_callback( def _progress_background_callback(response, callback_manager, background): progress_outputs = background.get("progress") - cache_key = flask.request.args.get("cacheKey") + adapter = get_app().backend.request_adapter() + cache_key = adapter.args.get("cacheKey") if progress_outputs: # Get the progress before the result as it would be erased after the results. @@ -452,8 +463,9 @@ def _update_background_callback( """Set up the background callback and manage jobs.""" callback_manager = _get_callback_manager(kwargs, background) - cache_key = flask.request.args.get("cacheKey") - job_id = flask.request.args.get("job") + adapter = get_app().backend.request_adapter() + cache_key = adapter.args.get("cacheKey") if adapter else None + job_id = adapter.args.get("job") if adapter else None _progress_background_callback(response, callback_manager, background) @@ -473,8 +485,9 @@ def _handle_rest_background_callback( multi, has_update=False, ): - cache_key = flask.request.args.get("cacheKey") - job_id = flask.request.args.get("job") + adapter = get_app().backend.request_adapter() + cache_key = adapter.args.get("cacheKey") if adapter else None + job_id = adapter.args.get("job") if adapter else None # Must get job_running after get_result since get_results terminates it. job_running = callback_manager.job_running(job_id) if not job_running and output_value is callback_manager.UNDEFINED: @@ -652,6 +665,8 @@ def register_callback( no_output=not has_output, optional=_kwargs.get("optional", False), hidden=_kwargs.get("hidden", None), + websocket=_kwargs.get("websocket", False), + persistent=_kwargs.get("persistent", False), ) # pylint: disable=too-many-locals @@ -687,11 +702,11 @@ def add_context(*args, **kwargs): ) response: dict = {"multi": True} # type: ignore - - jsonResponse = None + jsonResponse: Optional[str] = None try: if background is not None: - if not flask.request.args.get("cacheKey"): + adapter = get_app().backend.request_adapter() + if not (adapter and adapter.args.get("cacheKey")): return _setup_background_callback( kwargs, background, @@ -762,7 +777,8 @@ async def async_add_context(*args, **kwargs): try: if background is not None: - if not flask.request.args.get("cacheKey"): + adapter = get_app().backend.request_adapter() + if not (adapter and adapter.args.get("cacheKey")): return _setup_background_callback( kwargs, background, diff --git a/dash/_callback_context.py b/dash/_callback_context.py index 0f4d1a9924..1bcf235036 100644 --- a/dash/_callback_context.py +++ b/dash/_callback_context.py @@ -1,12 +1,14 @@ +import asyncio import functools import warnings import json import contextvars import typing -import flask +from dash.backends.ws import DashWebsocketCallback from . import exceptions +from ._get_app import get_app from ._utils import AttributeDict, stringify_id context_value: contextvars.ContextVar[ @@ -224,14 +226,15 @@ def record_timing(name, duration, description=None): :param description: A description of the resource. :type description: string or None """ - timing_information = getattr(flask.g, "timing_information", {}) + request = get_app().backend.request_adapter() + timing_information = getattr(request.context, "timing_information", {}) if name in timing_information: raise KeyError(f'Duplicate resource name "{name}" found.') timing_information[name] = {"dur": round(duration * 1000), "desc": description} - setattr(flask.g, "timing_information", timing_information) + setattr(request.context, "timing_information", timing_information) @property @has_context @@ -254,7 +257,8 @@ def using_outputs_grouping(self): @property @has_context def timing_information(self): - return getattr(flask.g, "timing_information", {}) + request = get_app().backend.request_adapter() + return getattr(request.context, "timing_information", {}) @has_context def set_props(self, component_id: typing.Union[str, dict], props: dict): @@ -292,6 +296,14 @@ def path(self): """ return _get_from_context("path", "") + @property + @has_context + def args(self): + """ + Query parameters of the callback request as a dictionary-like object. + """ + return _get_from_context("args", "") + @property @has_context def remote(self): @@ -316,6 +328,32 @@ def custom_data(self): """ return _get_from_context("custom_data", {}) + @property + @has_context + def websocket(self) -> typing.Optional[DashWebsocketCallback]: + """Get WebSocket interface if running in WebSocket context. + + Returns the DashWebsocketCallback instance if the callback is being + executed via WebSocket, otherwise returns None. + + Raises: + RuntimeError: If websocket_callbacks is requested but the backend + doesn't support WebSocket. + """ + ws = _get_from_context("dash_websocket", None) + if ws is None: + app = get_app() + if ( + hasattr(app, "_websocket_callbacks") + and app._websocket_callbacks # pylint: disable=protected-access + and not app.backend.websocket_capability + ): + raise RuntimeError( + f"WebSocket callbacks requested but backend " + f"'{app.backend.server_type}' doesn't support them." + ) + return ws + callback_context = CallbackContext() @@ -323,5 +361,26 @@ def custom_data(self): def set_props(component_id: typing.Union[str, dict], props: dict): """ Set the props for a component not included in the callback outputs. + + If running in a WebSocket context, props are streamed immediately to the + client. Otherwise, props are batched and sent with the callback response. """ - callback_context.set_props(component_id, props) + ws = _get_from_context("dash_websocket", None) + if ws is not None: + # Stream immediately via WebSocket + _id = stringify_id(component_id) + + async def _send_props(): + for prop_name, value in props.items(): + await ws.set_prop(_id, prop_name, value) + + # If we're in an async context, schedule the coroutine + try: + asyncio.get_running_loop() + asyncio.ensure_future(_send_props()) + except RuntimeError: + # No running event loop - run synchronously + asyncio.run(_send_props()) + else: + # Batch for response (existing behavior) + callback_context.set_props(component_id, props) diff --git a/dash/_configs.py b/dash/_configs.py index edbf7b50d1..107b8308f5 100644 --- a/dash/_configs.py +++ b/dash/_configs.py @@ -1,5 +1,6 @@ import os -import flask + +from ._utils import get_root_path # noinspection PyCompatibility from . import exceptions @@ -127,7 +128,7 @@ def pages_folder_config(name, pages_folder, use_pages): if not pages_folder: return None is_custom_folder = str(pages_folder) != "pages" - pages_folder_path = os.path.join(flask.helpers.get_root_path(name), pages_folder) + pages_folder_path = os.path.join(get_root_path(name), pages_folder) if (use_pages or is_custom_folder) and not os.path.isdir(pages_folder_path): error_msg = f""" A folder called `{pages_folder}` does not exist. If a folder for pages is not diff --git a/dash/_dash_renderer.py b/dash/_dash_renderer.py index ee507ddb71..5574131d10 100644 --- a/dash/_dash_renderer.py +++ b/dash/_dash_renderer.py @@ -1,7 +1,7 @@ import os from typing import Any, List, Dict -__version__ = "3.0.0" +__version__ = "3.1.0" _available_react_versions = {"18.3.1", "18.2.0", "16.14.0"} _available_reactdom_versions = {"18.3.1", "18.2.0", "16.14.0"} @@ -65,7 +65,7 @@ def _set_react_version(v_react, v_reactdom=None): { "relative_package_path": "dash-renderer/build/dash_renderer.min.js", "dev_package_path": "dash-renderer/build/dash_renderer.dev.js", - "external_url": "https://unpkg.com/dash-renderer@3.0.0" + "external_url": "https://unpkg.com/dash-renderer@3.1.0" "/build/dash_renderer.min.js", "namespace": "dash", }, @@ -75,4 +75,9 @@ def _set_react_version(v_react, v_reactdom=None): "namespace": "dash", "dynamic": True, }, + { + "relative_package_path": "dash-renderer/build/dash-ws-worker.js", + "namespace": "dash", + "dynamic": True, + }, ] diff --git a/dash/_get_app.py b/dash/_get_app.py index a64a7450cc..ab0b897f81 100644 --- a/dash/_get_app.py +++ b/dash/_get_app.py @@ -4,6 +4,8 @@ from textwrap import dedent from typing import Any, Optional +from dash.exceptions import AppNotFoundError + APP: Optional[Any] = None app_context: ContextVar[Any] = ContextVar("dash_app_context") @@ -55,7 +57,7 @@ def get_app(): pass if APP is None: - raise Exception( + raise AppNotFoundError( dedent( """ App object is not yet defined. `app = dash.Dash()` needs to be run diff --git a/dash/_hooks.py b/dash/_hooks.py index 3fe3c40e6d..f260b1fcb0 100644 --- a/dash/_hooks.py +++ b/dash/_hooks.py @@ -3,7 +3,6 @@ from importlib import metadata as _importlib_metadata import typing_extensions as _tx -import flask as _f from .exceptions import HookError from .resources import ResourceType @@ -50,6 +49,8 @@ def __init__(self) -> None: "index": [], "custom_data": [], "dev_tools": [], + "websocket_connect": [], + "websocket_message": [], } self._js_dist: _t.List[_t.Any] = [] self._css_dist: _t.List[_t.Any] = [] @@ -127,7 +128,7 @@ def route( Add a route to the Dash server. """ - def wrap(func: _t.Callable[[], _f.Response]): + def wrap(func: _t.Callable[[], _t.Any]): _name = name or func.__name__ self.add_hook( "routes", @@ -245,6 +246,60 @@ def devtool( } ) + def websocket_connect(self, priority: _t.Optional[int] = None, final: bool = False): + """ + Register a WebSocket connection validation hook. + + The hook receives the WebSocket object and should return: + - True (or any truthy value): Allow the connection + - False: Reject with default code (4001) and reason + - tuple (code, reason): Reject with custom close code and reason + + Hooks can be sync or async. + + Example: + @hooks.websocket_connect() + async def validate_session(websocket): + session_id = websocket.cookies.get("session_id") + if not session_id: + return (4001, "No session cookie") + if not await is_valid_session(session_id): + return (4002, "Invalid session") + return True + """ + + def decorator(func: _t.Callable): + self.add_hook("websocket_connect", func, priority=priority, final=final) + return func + + return decorator + + def websocket_message(self, priority: _t.Optional[int] = None, final: bool = False): + """ + Register a WebSocket message validation hook. + + The hook receives the WebSocket object and message dict, and should return: + - True (or any truthy value): Allow the message + - False: Disconnect with default code (4001) and reason + - tuple (code, reason): Disconnect with custom close code and reason + + Hooks can be sync or async. + + Example: + @hooks.websocket_message() + async def validate_session(websocket, message): + session_id = websocket.cookies.get("session_id") + if not await is_session_active(session_id): + return (4002, "Session expired") + return True + """ + + def decorator(func: _t.Callable): + self.add_hook("websocket_message", func, priority=priority, final=final) + return func + + return decorator + hooks = _Hooks() diff --git a/dash/_pages.py b/dash/_pages.py index 45538546e8..be9d847309 100644 --- a/dash/_pages.py +++ b/dash/_pages.py @@ -9,13 +9,11 @@ from pathlib import Path from urllib.parse import parse_qs -import flask - from . import _validate from ._callback_context import context_value from ._get_app import get_app from ._get_paths import get_relative_path -from ._utils import AttributeDict +from ._utils import AttributeDict, get_root_path CONFIG = AttributeDict() PAGE_REGISTRY = collections.OrderedDict() @@ -98,7 +96,7 @@ def _path_to_module_name(path): def _infer_module_name(page_path): relative_path = page_path.split(CONFIG.pages_folder)[-1] module = _path_to_module_name(relative_path) - proj_root = flask.helpers.get_root_path(CONFIG.name) + proj_root = get_root_path(CONFIG.name) if CONFIG.pages_folder.startswith(proj_root): parent_path = CONFIG.pages_folder[len(proj_root) :] else: @@ -150,23 +148,12 @@ def _parse_path_variables(pathname, path_template): return dict(zip(var_names, variables)) -def _create_redirect_function(redirect_to): - def redirect(): - return flask.redirect(redirect_to, code=301) - - return redirect - - def _set_redirect(redirect_from, path): app = get_app() if redirect_from and len(redirect_from): for redirect in redirect_from: fullname = app.get_relative_path(redirect) - app.server.add_url_rule( - fullname, - fullname, - _create_redirect_function(app.get_relative_path(path)), - ) + app.backend.add_redirect_rule(app, fullname, app.get_relative_path(path)) def register_page( @@ -318,18 +305,22 @@ def register_page( ) page.update( supplied_title=title, - title=title - if title is not None - else CONFIG.title - if CONFIG.title != "Dash" - else page["name"], + title=( + title + if title is not None + else CONFIG.title + if CONFIG.title != "Dash" + else page["name"] + ), ) page.update( - description=description - if description - else CONFIG.description - if CONFIG.description - else "", + description=( + description + if description + else CONFIG.description + if CONFIG.description + else "" + ), order=order, supplied_order=order, supplied_layout=layout, @@ -389,16 +380,14 @@ def _path_to_page(path_id): return {}, None -def _page_meta_tags(app): - start_page, path_variables = _path_to_page(flask.request.path.strip("/")) +def _page_meta_tags(app, request): + request_path = request.path + start_page, path_variables = _path_to_page(request_path.strip("/")) - # use the supplied image_url or create url based on image in the assets folder image = start_page.get("image", "") if image: image = app.get_asset_url(image) - assets_image_url = ( - "".join([flask.request.url_root, image.lstrip("/")]) if image else None - ) + assets_image_url = "".join([request.root, image.lstrip("/")]) if image else None supplied_image_url = start_page.get("image_url") image_url = supplied_image_url if supplied_image_url else assets_image_url @@ -413,7 +402,7 @@ def _page_meta_tags(app): return [ {"name": "description", "content": description}, {"property": "twitter:card", "content": "summary_large_image"}, - {"property": "twitter:url", "content": flask.request.url}, + {"property": "twitter:url", "content": request.url}, {"property": "twitter:title", "content": title}, {"property": "twitter:description", "content": description}, {"property": "twitter:image", "content": image_url or ""}, diff --git a/dash/_utils.py b/dash/_utils.py index 48a378e1cf..1ab2036820 100644 --- a/dash/_utils.py +++ b/dash/_utils.py @@ -3,6 +3,7 @@ import sys import uuid import hashlib +import importlib from collections import abc import subprocess import logging @@ -12,6 +13,7 @@ import string import inspect import re +import os from html import escape from functools import wraps @@ -104,6 +106,11 @@ def set_read_only(self, names, msg="Attribute is read-only"): else: object.__setattr__(self, "_read_only", new_read_only) + def unset_read_only(self, keys): + if hasattr(self, "_read_only"): + for key in keys: + self._read_only.pop(key, None) + def finalize(self, msg="Object is final: No new keys may be added."): """Prevent any new keys being set.""" object.__setattr__(self, "_final", msg) @@ -158,6 +165,20 @@ def _concat(x): if no_output: # No output will hash the inputs. + # For no-input callbacks, also include the call site to make each unique + if not inputs: + # Get the call site of the @callback decorator + stack = inspect.stack() + # Walk up the stack to find the actual callback call site + # Fallback to empty hash if no external frame found + # (skip internal dash package frames) + dash_package_path = os.path.dirname(__file__) + for frame_info in stack: + # Skip frames from within the dash package itself + if not frame_info.filename.startswith(dash_package_path): + call_site = f"{frame_info.filename}:{frame_info.lineno}" + return hashlib.sha256(call_site.encode("utf-8")).hexdigest() + return _hash_inputs() if isinstance(output, (list, tuple)): @@ -317,3 +338,60 @@ def pascal_case(name: Union[str, None]): return s[0].upper() + re.sub( r"[\-_\.]+([a-z])", lambda match: match.group(1).upper(), s[1:] ) + + +def get_root_path(import_name: str) -> str: + """Find the root path of a package, or the path that contains a + module. If it cannot be found, returns the current working + directory. + + Not to be confused with the value returned by :func:`find_package`. + + :meta private: + """ + # Module already imported and has a file attribute. Use that first. + mod = sys.modules.get(import_name) + + if mod is not None and hasattr(mod, "__file__") and mod.__file__ is not None: + return os.path.dirname(os.path.abspath(mod.__file__)) + + # Next attempt: check the loader. + try: + spec = importlib.util.find_spec(import_name) + + if spec is None: + raise ValueError + except (ImportError, ValueError): + loader = None + else: + loader = spec.loader + + # Loader does not exist or we're referring to an unloaded main + # module or a main module without path (interactive sessions), go + # with the current working directory. + if loader is None: + return os.getcwd() + + if hasattr(loader, "get_filename"): + filepath = loader.get_filename(import_name) # pyright: ignore + else: + # Fall back to imports. + __import__(import_name) + mod = sys.modules[import_name] + filepath = getattr(mod, "__file__", None) + + # If we don't have a file path it might be because it is a + # namespace package. In this case pick the root path from the + # first module that is contained in the package. + if filepath is None: + raise RuntimeError( + "No root path can be found for the provided module" + f" {import_name!r}. This can happen because the module" + " came from an import hook that does not provide file" + " name information or because it's a namespace package." + " In this case the root path needs to be explicitly" + " provided." + ) + + # filepath is import_name.py for a module, or __init__.py for a package. + return os.path.dirname(os.path.abspath(filepath)) # type: ignore[no-any-return] diff --git a/dash/_validate.py b/dash/_validate.py index f7502f245b..b80c61df2c 100644 --- a/dash/_validate.py +++ b/dash/_validate.py @@ -3,12 +3,13 @@ import re from textwrap import dedent from keyword import iskeyword -import flask from ._grouping import grouping_len, map_grouping from ._no_update import NoUpdate from .development.base_component import Component +from . import backends from . import exceptions +from ._get_app import get_app from ._utils import ( patch_collections_abc, stringify_id, @@ -510,13 +511,17 @@ def validate_use_pages(config): "`dash.register_page()` must be called after app instantiation" ) - if flask.has_request_context(): - raise exceptions.PageError( - """ - dash.register_page() can’t be called within a callback as it updates dash.page_registry, which is a global variable. - For more details, see https://dash.plotly.com/sharing-data-between-callbacks#why-global-variables-will-break-your-app - """ - ) + try: + if get_app().backend.has_request_context(): + raise exceptions.PageError( + """ + dash.register_page() can’t be called within a callback as it updates dash.page_registry, which is a global variable. + For more details, see https://dash.plotly.com/sharing-data-between-callbacks#why-global-variables-will-break-your-app + """ + ) + except exceptions.AppNotFoundError: + # If the app is not found we can add pages since before instantiation. + pass def validate_module_name(module): @@ -585,3 +590,72 @@ def _valid(out): return _valid(output) + + +def check_async(use_async): + if use_async is None: + try: + import asgiref # type: ignore[import-not-found] # pylint: disable=unused-import, import-outside-toplevel # noqa + + use_async = True + except ImportError: + pass + elif use_async: + try: + import asgiref # type: ignore[import-not-found] # pylint: disable=unused-import, import-outside-toplevel # noqa + except ImportError as exc: + raise Exception( + "You are trying to use dash[async] without having installed the requirements please install via: `pip install dash[async]`" + ) from exc + return use_async or False + + +def check_backend(backend, inferred_backend): + if backend is not None: + if isinstance(backend, type): + # get_backend returns the backend class for a string + # So we compare the class names + expected_backend_cls, _ = backends.get_backend(inferred_backend) + if ( + backend.__module__ != expected_backend_cls.__module__ + or backend.__name__ != expected_backend_cls.__name__ + ): + raise ValueError( + f"Conflict between provided backend '{backend.__name__}' and server type '{inferred_backend}'." + ) + elif not isinstance(backend, str): + raise ValueError("Invalid backend argument") + elif backend.lower() != inferred_backend: + raise ValueError( + f"Conflict between provided backend '{backend}' and server type '{inferred_backend}'." + ) + + +def validate_websocket_callback_request( + callback_id, callback_map, websocket_callbacks_enabled +): + """Validate a WebSocket callback request at runtime. + + Called by WebSocket handlers to verify that a callback received via WebSocket + is actually allowed to use WebSocket transport. + + Args: + callback_id: The callback output ID from the request + callback_map: The app's callback_map dictionary + websocket_callbacks_enabled: Whether websocket_callbacks=True at app level + + Raises: + WebSocketCallbackError: If the callback is not websocket-enabled + """ + # If global websocket_callbacks is enabled, all callbacks can use WebSocket + if websocket_callbacks_enabled: + return + + # Otherwise, check if this specific callback has websocket=True + cb = callback_map.get(callback_id, {}) + if not cb.get("websocket"): + raise exceptions.WebSocketCallbackError( + f"Callback '{callback_id}' received via WebSocket but does not have " + f"websocket=True. Either enable websocket_callbacks=True globally " + f"or add websocket=True to this callback." + ) diff --git a/dash/backends/__init__.py b/dash/backends/__init__.py new file mode 100644 index 0000000000..585d34a65c --- /dev/null +++ b/dash/backends/__init__.py @@ -0,0 +1,75 @@ +import importlib +from typing import Type + +from .base_server import BaseDashServer + + +_backend_imports = { + "flask": ("dash.backends._flask", "FlaskDashServer"), + "fastapi": ("dash.backends._fastapi", "FastAPIDashServer"), + "quart": ("dash.backends._quart", "QuartDashServer"), +} + + +def get_backend(name: str) -> Type[BaseDashServer]: + module_name, server_class = _backend_imports[name.lower()] + try: + module = importlib.import_module(module_name) + server = getattr(module, server_class) + return server + except KeyError as e: + raise ValueError(f"Unknown backend: {name}") from e + except ImportError as e: + raise ImportError( + f"Could not import module '{module_name}' for backend '{name}': {e}" + ) from e + except AttributeError as e: + raise AttributeError( + f"Module '{module_name}' does not have class '{server_class}' for backend '{name}': {e}" + ) from e + + +def _is_flask_instance(obj): + try: + # pylint: disable=import-outside-toplevel + from flask import Flask + + return isinstance(obj, Flask) + except ImportError: + return False + + +def _is_fastapi_instance(obj): + try: + # pylint: disable=import-outside-toplevel + from fastapi import FastAPI # type: ignore[import-not-found] + + return isinstance(obj, FastAPI) + except ImportError: + return False + + +def _is_quart_instance(obj): + try: + # pylint: disable=import-outside-toplevel + from quart import Quart # type: ignore[import-not-found] + + return isinstance(obj, Quart) + except ImportError: + return False + + +def get_server_type(server): + if _is_flask_instance(server): + return "flask" + if _is_quart_instance(server): + return "quart" + if _is_fastapi_instance(server): + return "fastapi" + raise ValueError("Invalid backend argument") + + +__all__ = [ + "get_backend", + "get_server_type", +] diff --git a/dash/backends/_fastapi.py b/dash/backends/_fastapi.py new file mode 100644 index 0000000000..9e06bd418c --- /dev/null +++ b/dash/backends/_fastapi.py @@ -0,0 +1,894 @@ +from __future__ import annotations + +from contextvars import copy_context, ContextVar +import asyncio +import concurrent.futures +import json +import queue +from typing import TYPE_CHECKING, Any, Callable, Dict +import sys +import mimetypes +import hashlib +import inspect +import pkgutil +import time +import os +import subprocess +import threading +import traceback +from urllib.parse import urlparse + +try: + from fastapi import FastAPI, Request, Response, Body + from fastapi.responses import JSONResponse, RedirectResponse + from fastapi.staticfiles import StaticFiles + from starlette.responses import Response as StarletteResponse + from starlette.datastructures import MutableHeaders + from starlette.types import ASGIApp, Scope, Receive, Send + from starlette.websockets import WebSocket, WebSocketDisconnect + import uvicorn +except ImportError as _err: + raise ImportError( + "All dependencies not installed. Please install it with `dash[fastapi]` to use the FastAPI backend." + ) from _err + +import janus + +from dash.fingerprint import check_fingerprint +from dash import _validate, get_app +from dash.exceptions import PreventUpdate +from .base_server import ( + BaseDashServer, + RequestAdapter, + ResponseAdapter, +) +from .ws import ( + DashWebsocketCallback, + run_ws_sender, + run_callback_in_executor, + make_callback_done_handler, + SHUTDOWN_SIGNAL, + DISCONNECTED, +) +from ._utils import format_traceback_html + +if TYPE_CHECKING: # pragma: no cover - typing only + from dash import Dash + + +class FastAPIResponseAdapter(ResponseAdapter): + """ + A custom Response class that wraps FastAPI's JSONResponse + and provides a set_response() method for compatibility with Dash's callback system. + """ + + @property + def callback_response(self): + """Get the response object to be returned from a callback.""" + print( + "Cannot access callback_response directly on FastAPIResponseAdapter. Use set_response() to create a response with data." + ) + raise NotImplementedError() + + def set_response(self, **kwargs): + """ + Set the response data. This method provides compatibility with Flask's Response.set_data(). + """ + data = kwargs.get("data") + if isinstance(data, (str, bytes, bytearray)): + resp = Response(content=data) + else: + resp = JSONResponse(content=data) + if self._headers: + for key, value in self._headers.items(): + if isinstance(value, list): + for v in value: + resp.headers.append(key, v) + else: + resp.headers[key] = value + if self._cookies: + for key, (value, cookie_kwargs) in self._cookies.items(): + resp.set_cookie(key, value, **cookie_kwargs) + return resp + + +_current_request_var = ContextVar("dash_current_request", default=None) + + +def set_current_request(req): + return _current_request_var.set(req) + + +def reset_current_request(token): + _current_request_var.reset(token) + + +def get_current_request() -> Request: + req = _current_request_var.get() + if req is None: + raise RuntimeError("No active request in context") + return req + + +_ENV_CONFIG = "_DASH_FASTAPI_CONFIG" + + +class DashMiddleware: # pylint: disable=too-few-public-methods + """Consolidated middleware for all Dash/FastAPI integration needs.""" + + def __init__( + self, + app: ASGIApp, + dash_app: Dash, + dash_server: FastAPIDashServer, + before_request_funcs: list, + after_request_func: Callable | None = None, + enable_timing: bool = False, + ) -> None: + self.app = app + self.dash_app = dash_app + self.dash_server = dash_server + self.before_request_funcs = before_request_funcs + self.after_request_func = after_request_func + self.enable_timing = enable_timing + self._dev_tools_initialized = False + + async def _initialize_dev_tools(self) -> None: + """Initialize dev tools from environment config on first run.""" + if not self._dev_tools_initialized: + config = json.loads(os.getenv(_ENV_CONFIG, "{}")) + if config: + self.dash_app.enable_dev_tools(**config, first_run=False) + self._dev_tools_initialized = True + + async def _setup_timing(self, request: Request) -> None: + """Set up timing information for the request.""" + try: + request.state.json_body = ( + await request.json() + if request.headers.get("content-type", "").startswith( + "application/json" + ) + else None + ) + except Exception: # pylint: disable=broad-exception-caught + request.state.json_body = None + if self.enable_timing: + request.state.timing_information = { + "__dash_server": {"dur": time.time(), "desc": None} + } + + async def _run_before_hooks(self) -> None: + """Run all before-request hooks.""" + for func in self.before_request_funcs: + if inspect.iscoroutinefunction(func): + await func() + else: + func() + + async def _run_after_hooks(self) -> None: + """Run after-request hook if configured.""" + if self.after_request_func is not None: + if inspect.iscoroutinefunction(self.after_request_func): + await self.after_request_func() + else: + self.after_request_func() + + def _finalize_timing(self, request: Request) -> dict | None: + """Calculate final timing information and return headers to add.""" + if not self.enable_timing or not hasattr(request.state, "timing_information"): + return None + + timing_information = request.state.timing_information + dash_total = timing_information.get("__dash_server", None) + if dash_total is not None: + dash_total["dur"] = round((time.time() - dash_total["dur"]) * 1000) + + return timing_information + + async def _handle_error( + self, error: Exception, scope: Scope, receive: Receive, send: Send + ) -> None: + """Handle exceptions during request processing.""" + if isinstance(error, PreventUpdate): + response = Response(status_code=204) + elif self.dash_server.error_handling_mode in ["raise", "prune"]: + tb = self.dash_server._get_traceback(None, error) # pylint: disable=W0212 + response = Response(content=tb, media_type="text/html", status_code=500) + else: + response = JSONResponse( + status_code=500, + content={ + "error": "InternalServerError", + "message": "An internal server error occurred.", + }, + ) + await response(scope, receive, send) + + async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: + # Handle lifespan events (startup/shutdown) + + if scope["type"] == "lifespan": + try: + dash_app = get_app() + dash_app.backend._setup_catchall() + except Exception: # pylint: disable=broad-exception-caught + traceback.print_exc() + await self._initialize_dev_tools() + await self.app(scope, receive, send) + return + + # Non-HTTP/WebSocket scopes pass through + if scope["type"] != "http": + await self.app(scope, receive, send) + return + + # HTTP/WebSocket request handling + request = Request(scope, receive=receive) + token = set_current_request(request) + + try: + await self._setup_timing(request) + await self._run_before_hooks() + + await self.app(scope, receive, send) + + await self._run_after_hooks() + self._finalize_timing(request) + + except Exception as e: # pylint: disable=W0718 + await self._handle_error(e, scope, receive, send) + finally: + reset_current_request(token) + + +class FastAPIDashServer(BaseDashServer[FastAPI]): + websocket_capability: bool = True + + def __init__(self, server: FastAPI): + super().__init__(server) + self.server_type = "fastapi" + self.error_handling_mode = "ignore" + self.request_adapter = FastAPIRequestAdapter + self.response_adapter = FastAPIResponseAdapter + self._before_request_funcs = [] + self._after_request_func = None + self._enable_timing = False + + def __call__(self, *args: Any, **kwargs: Any): + # ASGI: pass through to FastAPI + return self.server(*args, **kwargs) + + @staticmethod + # pylint: disable=W0613 + def create_app(name: str = "__main__", config: Dict[str, Any] | None = None): + app = FastAPI() + + if config: + for key, value in config.items(): + setattr(app.state, key, value) + return app + + def register_assets_blueprint( + self, blueprint_name: str, assets_url_path: str, assets_folder: str + ): + try: + self.server.mount( + assets_url_path, + StaticFiles(directory=assets_folder), + name=blueprint_name, + ) + except RuntimeError: + # directory doesnt exist + pass + + def register_error_handlers(self): + self.error_handling_mode = "ignore" + + def _get_traceback(self, _secret, error: Exception): + return format_traceback_html( + error, self.error_handling_mode, "FastAPI Debugger", "FastAPI" + ) + + def register_prune_error_handler(self, _secret, prune_errors): + if prune_errors: + self.error_handling_mode = "prune" + else: + self.error_handling_mode = "raise" + + def _html_response_wrapper(self, view_func: Callable[..., Any] | str): + async def wrapped(*_args, **_kwargs): + # If view_func is a function, call it; if it's a string, use it directly + html = view_func() if callable(view_func) else view_func + return Response(content=html, media_type="text/html") + + return wrapped + + def setup_index(self, dash_app: Dash): + async def index(_request: Request): + return Response(content=dash_app.index(), media_type="text/html") + + # pylint: disable=protected-access + dash_app._add_url("", index, methods=["GET"]) + + def setup_catchall(self, dash_app: Dash): + """This is needed to ensure that all routes are handled by FastAPI + and passed through the middleware, which is necessary for features like authentication + and timing to work correctly on all routes. FastAPI will match this catch-all route + for any path that isn't matched by a more specific route, allowing the middleware to + process the request and then return the appropriate response (e.g., 404 if no Dash route matches).""" + + def _setup_catchall(self): + try: + dash_app = get_app() + + async def catchall(_request: Request): + return Response(content=dash_app.index(), media_type="text/html") + + # pylint: disable=protected-access + self.add_url_rule("{path:path}", catchall, methods=["GET"]) + except Exception: # pylint: disable=broad-exception-caught + traceback.print_exc() + + def add_url_rule( + self, + rule: str, + view_func: Callable[..., Any] | str, + endpoint: str | None = None, + methods: list[str] | None = None, + include_in_schema: bool = False, + ): + if rule == "": + rule = "/" + if isinstance(view_func, str): + # Wrap string or sync function to async FastAPI handler + view_func = self._html_response_wrapper(view_func) + self.server.add_api_route( + rule, + view_func, + methods=methods or ["GET"], + name=endpoint, + include_in_schema=include_in_schema, + ) + + def before_request(self, func: Callable[[], Any] | None): + if func is not None: + self._before_request_funcs.append(func) + + def after_request(self, func: Callable[[], Any] | None): + self._after_request_func = func + + def has_request_context(self) -> bool: + try: + get_current_request() + return True + except RuntimeError: + return False + + def run(self, dash_app: Dash, host, port, debug, **kwargs): # pylint: disable=R0912 + frame = inspect.stack()[2] + if debug and kwargs.get("reload") is None: + kwargs["reload"] = True + + # Check if we're running in a thread (e.g., from testing framework) + # If so, run uvicorn directly instead of spawning a subprocess + is_threaded = threading.current_thread() != threading.main_thread() + + if is_threaded: + # Running in a thread (testing context) - use uvicorn.Server + # This allows graceful shutdown via should_exit flag + kwargs.pop("reload", None) # Reload not supported in threaded mode + config = uvicorn.Config(self.server, host=host, port=port, **kwargs) + server = uvicorn.Server(config) + # Store server reference on the app for graceful shutdown + dash_app._uvicorn_server = server # pylint: disable=protected-access + server.run() + else: + # Running in main thread (normal context) - use subprocess + file_path = frame.filename + rel_path = os.path.relpath(file_path, os.getcwd()) + + # Check if the file is outside the current working directory + if rel_path.startswith(".."): + # File is outside cwd, try to find the module name from sys.modules + module_name = None + for mod_name, mod in sys.modules.items(): + if hasattr(mod, "__file__") and mod.__file__: + if os.path.abspath(mod.__file__) == os.path.abspath(file_path): + module_name = mod_name + break + + # If we still can't find it, raise an error + if not module_name: + raise RuntimeError( + f"Cannot determine module name for {file_path}. " + "The file is outside the current working directory and not found in sys.modules. " + "Please ensure the FastAPI app is being run from a file within the current working directory." + ) + else: + # File is within cwd, use relative path + module_name = os.path.splitext(rel_path)[0].replace(os.sep, ".") + + # Find the Dash app variable name by inspecting the calling frame + dash_var_name = None + calling_frame = frame.frame + for var_name, var_value in calling_frame.f_locals.items(): + if var_value is dash_app: + dash_var_name = var_name + break + + # If not found in locals, check globals + if not dash_var_name: + for var_name, var_value in calling_frame.f_globals.items(): + if var_value is dash_app: + dash_var_name = var_name + break + + # Construct the app path - use .server to access the FastAPI instance + if dash_var_name: + app_path = f"{module_name}:{dash_var_name}.server" + else: + # Fallback to looking for 'server' variable (old behavior) + app_path = f"{module_name}:server" + + uvicorn_args = [ + sys.executable, + "-m", + "uvicorn", + app_path, + "--host", + str(host), + "--port", + str(port), + ] + if kwargs.get("reload"): + uvicorn_args.append("--reload") + + dev_tools = dash_app._dev_tools # pylint: disable=W0212 + config = dict( + {"debug": debug} if debug else {"debug": False}, + **{f"dev_tools_{k}": v for k, v in dev_tools.items()}, + ) + env = os.environ.copy() + env[_ENV_CONFIG] = json.dumps(config) + + # Add any other kwargs as CLI args if needed + + # pylint: disable=R1732 + proc = subprocess.Popen(uvicorn_args, env=env) + proc.wait() + + def make_response( + self, + data: str | bytes | bytearray, + mimetype: str | None = None, + content_type: str | None = None, + status: int | None = None, + ): + headers = {} + if mimetype: + headers["content-type"] = mimetype + if content_type: + headers["content-type"] = content_type + return Response(content=data, headers=headers, status_code=status or 200) + + def jsonify(self, obj: Any): + return JSONResponse(content=obj) + + def serve_component_suites( + self, + dash_app: Dash, + package_name: str, + fingerprinted_path: str, + request: Request, + ): + + path_in_pkg, has_fingerprint = check_fingerprint(fingerprinted_path) + _validate.validate_js_path(dash_app.registered_paths, package_name, path_in_pkg) + extension = "." + path_in_pkg.split(".")[-1] + mimetype = mimetypes.types_map.get(extension, "application/octet-stream") + package = sys.modules[package_name] + dash_app.logger.debug( + "serving -- package: %s[%s] resource: %s => location: %s", + package_name, + package.__version__, + path_in_pkg, + package.__path__, + ) + data = pkgutil.get_data(package_name, path_in_pkg) + headers = {} + if has_fingerprint: + headers["Cache-Control"] = "public, max-age=31536000" + return StarletteResponse(content=data, media_type=mimetype, headers=headers) + etag = hashlib.md5(data).hexdigest() if data else "" + headers["ETag"] = etag + if request.headers.get("if-none-match") == etag: + return StarletteResponse(status_code=304) + return StarletteResponse(content=data, media_type=mimetype, headers=headers) + + def setup_component_suites(self, dash_app: Dash): + async def serve(request: Request, package_name: str, fingerprinted_path: str): + return self.serve_component_suites( + dash_app, package_name, fingerprinted_path, request + ) + + name = "_dash-component-suites/{package_name}/{fingerprinted_path:path}" + dash_app._add_url(name, serve) # pylint: disable=protected-access + + def _create_redirect_function(self, redirect_to): + def _redirect(): + return RedirectResponse(url=redirect_to, status_code=301) + + return _redirect + + def add_redirect_rule(self, app, fullname, path): + self.server.add_api_route( + fullname, + self._create_redirect_function(app.get_relative_path(path)), + methods=["GET"], + name=fullname, + include_in_schema=False, + ) + + def serve_callback(self, dash_app: Dash): + async def _dispatch(request: Request): # pylint: disable=unused-argument + # pylint: disable=protected-access + body = self.request_adapter().get_json() + cb_ctx = dash_app._initialize_context( + body + ) # pylint: disable=protected-access + func = dash_app._prepare_callback( + cb_ctx, body + ) # pylint: disable=protected-access + args = dash_app._inputs_to_vals( + cb_ctx.inputs_list + cb_ctx.states_list + ) # pylint: disable=protected-access + ctx = copy_context() + partial_func = dash_app._execute_callback( + func, args, cb_ctx.outputs_list, cb_ctx + ) # pylint: disable=protected-access + response_data = ctx.run(partial_func) + if inspect.iscoroutine(response_data): + response_data = await response_data + return cb_ctx.dash_response.set_response(data=response_data) + + return _dispatch + + def register_timing_hooks(self, first_run: bool): + if first_run: + self._enable_timing = True + + def register_callback_api_routes( + self, callback_api_paths: Dict[str, Callable[..., Any]] + ): + """ + Register callback API endpoints on the FastAPI app. + Each key in callback_api_paths is a route, each value is a handler (sync or async). + Accepts a JSON body (dict) and filters keys based on the handler's signature. + """ + for path, handler in callback_api_paths.items(): + endpoint = f"dash_callback_api_{path}" + route = path if path.startswith("/") else f"/{path}" + methods = ["POST"] + sig = inspect.signature(handler) + param_names = list(sig.parameters.keys()) + + def make_view_func(handler, param_names): + async def view_func(_request: Request, body: dict = Body(...)): + kwargs = { + k: v + for k, v in body.items() + if k in param_names and v is not None + } + if inspect.iscoroutinefunction(handler): + result = await handler(**kwargs) + else: + result = handler(**kwargs) + return JSONResponse(content=result) + + return view_func + + self.server.add_api_route( + route, + make_view_func(handler, param_names), + methods=methods, + name=endpoint, + include_in_schema=True, + ) + + def enable_compression(self) -> None: + # pylint: disable=import-outside-toplevel,import-error + from fastapi.middleware.gzip import ( + GZipMiddleware, + ) + + self.server.add_middleware(GZipMiddleware, minimum_size=500) + + def setup_backend(self, dash_app: Dash): + # Add consolidated middleware for all Dash functionality + self.server.add_middleware( + DashMiddleware, + dash_app=dash_app, + dash_server=self, + before_request_funcs=self._before_request_funcs, + after_request_func=self._after_request_func, + enable_timing=self._enable_timing, + ) + + # Add timing middleware separately if enabled (needs to modify response headers) + if self._enable_timing: + + @self.server.middleware("http") + async def timing_headers_middleware(request: Request, call_next): + response = await call_next(request) + timing_information = getattr(request.state, "timing_information", None) + if timing_information is not None: + headers = MutableHeaders(response.headers) + for name, info in timing_information.items(): + value = name + if info.get("desc") is not None: + value += f';desc="{info["desc"]}"' + if info.get("dur") is not None: + value += f";dur={info['dur']}" + headers.append("Server-Timing", value) + return response + + async def _run_ws_hooks( + self, hooks, websocket: "WebSocket", *args, default_reason: str = "Rejected" + ) -> tuple | None: + """Run WebSocket hooks and return rejection tuple or None if all pass. + + Args: + hooks: List of hooks to run + websocket: The WebSocket connection + *args: Additional arguments to pass to hooks + default_reason: Default reason if hook returns False + + Returns: + None if all hooks pass, or (code, reason) tuple for rejection + """ + for hook in hooks: + try: + result = hook(websocket, *args) + if inspect.iscoroutine(result): + result = await result + if result is False: + return (4001, default_reason) + if isinstance(result, tuple) and len(result) == 2: + return result + except Exception: # pylint: disable=broad-exception-caught + return (4001, "Authentication error") + return None + + def serve_websocket_callback(self, dash_app: "Dash"): + """Set up the WebSocket endpoint for callback handling. + + Uses thread pool executor for callback execution with janus queues + for async/sync communication between main loop and worker threads. + + Args: + dash_app: The Dash application instance + """ + # pylint: disable=too-many-statements,too-many-locals + ws_path = dash_app.config.requests_pathname_prefix + "_dash-ws-callback" + + # Get allowed origins from dash app config + allowed_origins = getattr( + dash_app, "_websocket_allowed_origins", [] + ) # pylint: disable=protected-access + + def validate_origin(origin: str | None, host: str | None) -> str | None: + """Validate WebSocket origin. Returns error message or None if valid.""" + if not origin: + return "Origin header required" + if origin in allowed_origins: + return None # Explicitly allowed + if not host: + return "Origin not allowed" + # Check same-origin + origin_host = urlparse(origin).netloc + if origin_host != host: + return "Origin not allowed" + return None + + async def websocket_handler(websocket: WebSocket): + # Validate Origin header to prevent Cross-Site WebSocket Hijacking + origin = websocket.headers.get("origin") + host = websocket.headers.get("host") + error = validate_origin(origin, host) + if error: + await websocket.close(code=4003, reason=error) + return + + # Call websocket_connect hooks (before accept) + # pylint: disable=protected-access + rejection = await self._run_ws_hooks( + dash_app._hooks.get_hooks("websocket_connect"), + websocket, + default_reason="Connection rejected", + ) + if rejection: + await websocket.close(code=rejection[0], reason=rejection[1]) + return + + await websocket.accept() + + # Create janus queue for outbound messages (main loop context) + outbound_queue: janus.Queue[str] = janus.Queue() + # Track pending get_props requests with standard queue.Queue for responses + pending_get_props: Dict[str, queue.Queue] = {} + # Shutdown event to signal connection closure to worker threads + shutdown_event = threading.Event() + # Get thread pool executor + executor = self.get_callback_executor() + # Track pending callback futures + pending_callbacks: Dict[str, concurrent.futures.Future] = {} + + # Start sender task to drain outbound queue (sends pre-serialized text) + sender_task = asyncio.create_task( + run_ws_sender(websocket.send_text, outbound_queue) + ) + + try: + while True: + message = await websocket.receive_json() + + # Call websocket_message hooks + rejection = await self._run_ws_hooks( + dash_app._hooks.get_hooks("websocket_message"), + websocket, + message, + default_reason="Message rejected", + ) + if rejection: + await websocket.close(code=rejection[0], reason=rejection[1]) + return + + msg_type = message.get("type") + + if msg_type == "callback_request": + request_id = message.get("requestId") + renderer_id = message.get("rendererId", "") + payload = message.get("payload", {}) + + # Validate that the callback is allowed to use WebSocket transport + # pylint: disable=protected-access + _validate.validate_websocket_callback_request( + payload.get("output"), + dash_app.callback_map, + dash_app._websocket_callbacks, + ) + + # Create WebSocket callback instance with outbound queue + ws_cb = DashWebsocketCallback( + pending_get_props, + renderer_id, + outbound_queue, + shutdown_event, + ) + + # Submit callback to executor + future = run_callback_in_executor( + executor, + dash_app, + payload, + ws_cb, + FastAPIResponseAdapter(), + ) + + # Set up done callback to send response + future.add_done_callback( + make_callback_done_handler( + outbound_queue, + pending_callbacks, + request_id, + renderer_id, + shutdown_event, + ) + ) + pending_callbacks[request_id] = future + + elif msg_type == "get_props_response": + # Put response in waiting queue (non-blocking) + request_id = message.get("requestId") + response_queue = pending_get_props.get(request_id) + if response_queue is not None: + response_queue.put_nowait(message.get("payload")) + + elif msg_type == "heartbeat": + outbound_queue.sync_q.put_nowait('{"type": "heartbeat_ack"}') + + except WebSocketDisconnect: + pass # Clean disconnect + finally: + # Signal shutdown to worker threads + shutdown_event.set() + # Unblock any threads waiting on get_prop responses + for response_queue in pending_get_props.values(): + response_queue.put_nowait(DISCONNECTED) + # Signal sender to shutdown and cancel it + outbound_queue.sync_q.put_nowait(SHUTDOWN_SIGNAL) + sender_task.cancel() + try: + await sender_task + except asyncio.CancelledError: + pass + # Close the janus queue + outbound_queue.close() + await outbound_queue.wait_closed() + # Cancel any pending futures + for f in pending_callbacks.values(): + f.cancel() + + self.server.add_api_websocket_route(ws_path, websocket_handler) + + +class FastAPIRequestAdapter(RequestAdapter): + def __init__(self): + self._request: Request = get_current_request() + super().__init__() + + def __call__(self): + self._request = get_current_request() + return self + + @property + def context(self): + if self._request is None: + raise RuntimeError("No active request in context") + + return self._request.state + + @property + def root(self): + return str(self._request.base_url) + + @property + def args(self): + return self._request.query_params + + @property + def is_json(self): + return self._request.headers.get("content-type", "").startswith( + "application/json" + ) + + @property + def cookies(self): + return self._request.cookies + + @property + def headers(self): + return self._request.headers + + @property + def full_path(self): + return str(self._request.url) + + @property + def url(self): + return str(self._request.url) + + @property + def remote_addr(self): + client = getattr(self._request, "client", None) + return getattr(client, "host", None) + + @property + def origin(self): + return self._request.headers.get("origin") + + @property + def path(self): + return self._request.url.path + + async def _get_json(self, request: Request = None): + req = self._request + if not hasattr(req.state, "json_body"): + req.state.json_body = await request.json() + return req.state.json_body + + def get_json(self): + if not hasattr(self, "_request") or self._request is None: + self._request = get_current_request() + return self._request.state.json_body diff --git a/dash/backends/_flask.py b/dash/backends/_flask.py new file mode 100644 index 0000000000..00c8730d8a --- /dev/null +++ b/dash/backends/_flask.py @@ -0,0 +1,432 @@ +from __future__ import annotations + +import asyncio +import pkgutil +import sys +import mimetypes +import time +import inspect +import traceback + +from contextvars import copy_context +from typing import TYPE_CHECKING, Any, Callable, Dict + +from importlib_metadata import version as _get_distribution_version + +from flask import ( + Flask, + Blueprint, + Response, + request, + jsonify, + g as flask_g, + has_request_context, + redirect, +) +from werkzeug.debug import tbtools + +from dash.fingerprint import check_fingerprint +from dash import _validate +from dash.exceptions import PreventUpdate, InvalidResourceError +from dash._callback import _invoke_callback, _async_invoke_callback +from dash._utils import parse_version +from .base_server import BaseDashServer, RequestAdapter, ResponseAdapter + + +if TYPE_CHECKING: # pragma: no cover - typing only + from dash import Dash + + +class FlaskResponseAdapter(ResponseAdapter): + """ + A custom Response class that wraps Flask's Response + and provides a set_response() method for compatibility with Dash's callback system. + """ + + def __init__(self): + self._flask_response = Response(content_type="application/json") + super().__init__() + + @property + def callback_response(self) -> Response: + return self._flask_response + + def set_cookie(self, key, value="", **kwargs): + self._flask_response.set_cookie(key, value, **kwargs) + + def append_header(self, key, value): + self._flask_response.headers.add(key, value) + + def set_header(self, key, value): + self._flask_response.headers.set(key, value) + + def set_response(self, **kwargs): + self._flask_response.set_data(kwargs.get("data", "")) + return self._flask_response + + +class FlaskDashServer(BaseDashServer[Flask]): + def __init__(self, server: Flask) -> None: + super().__init__(server) + self.server_type = "flask" + self.request_adapter = FlaskRequestAdapter + self.response_adapter = FlaskResponseAdapter + + def __call__(self, *args: Any, **kwargs: Any): + # Always WSGI + return self.server(*args, **kwargs) + + @staticmethod + def create_app(name: str = "__main__", config: Dict[str, Any] | None = None): + app = Flask(name) + if config: + app.config.update(config) + return app + + def register_assets_blueprint( + self, blueprint_name: str, assets_url_path: str, assets_folder: str + ): + bp = Blueprint( + blueprint_name, + __name__, + static_folder=assets_folder, + static_url_path=assets_url_path, + ) + self.server.register_blueprint(bp) + + def register_error_handlers(self): + @self.server.errorhandler(PreventUpdate) + def _handle_error(_): + return "", 204 + + @self.server.errorhandler(InvalidResourceError) + def _invalid_resources_handler(err): + return err.args[0], 404 + + def _get_traceback(self, secret, error: Exception): + def _get_skip(error): + tb = error.__traceback__ + skip = 1 + while tb.tb_next is not None: + skip += 1 + tb = tb.tb_next + if tb.tb_frame.f_code in [ + _invoke_callback.__code__, + _async_invoke_callback.__code__, + ]: + return skip + return skip + + def _do_skip(error): + tb = error.__traceback__ + while tb.tb_next is not None: + if tb.tb_frame.f_code in [ + _invoke_callback.__code__, + _async_invoke_callback.__code__, + ]: + return tb.tb_next + tb = tb.tb_next + return error.__traceback__ + + if hasattr(tbtools, "get_current_traceback"): + return tbtools.get_current_traceback(skip=_get_skip(error)).render_full() + if hasattr(tbtools, "DebugTraceback"): + return tbtools.DebugTraceback( + error, skip=_get_skip(error) + ).render_debugger_html(True, secret, True) + return "".join(traceback.format_exception(type(error), error, _do_skip(error))) + + def register_prune_error_handler(self, secret, prune_errors): + if prune_errors: + + @self.server.errorhandler(Exception) + def _wrap_errors(error): + tb = self._get_traceback(secret, error) + return tb, 500 + + def add_url_rule( + self, + rule: str, + view_func: Callable[..., Any], + endpoint: str | None = None, + methods: list[str] | None = None, + ): + self.server.add_url_rule( + rule, view_func=view_func, endpoint=endpoint, methods=methods or ["GET"] + ) + + def before_request(self, func: Callable[[], Any]): + # Flask expects a callable; user responsibility not to pass None + self.server.before_request(func) + + def after_request(self, func: Callable[[Any], Any]): + # Flask after_request expects a function(response) -> response + self.server.after_request(func) + + def has_request_context(self) -> bool: + return has_request_context() + + def run(self, dash_app: Dash, host: str, port: int, debug: bool, **kwargs: Any): + self.server.run(host=host, port=port, debug=debug, **kwargs) + + def make_response( + self, + data: str | bytes | bytearray, + mimetype: str | None = None, + content_type: str | None = None, + status: int | None = None, + ): + return Response( + data, mimetype=mimetype, content_type=content_type, status=status + ) + + def jsonify(self, obj: Any): + return jsonify(obj) + + def setup_catchall(self, dash_app: Dash): + def catchall(*args, **kwargs): + return dash_app.index(*args, **kwargs) + + # pylint: disable=protected-access + dash_app._add_url("", catchall, methods=["GET"]) + + def setup_index(self, dash_app: Dash): + def index(*args, **kwargs): + return dash_app.index(*args, **kwargs) + + # pylint: disable=protected-access + dash_app._add_url("", index, methods=["GET"]) + + def serve_component_suites( + self, dash_app: Dash, package_name: str, fingerprinted_path: str + ): + path_in_pkg, has_fingerprint = check_fingerprint(fingerprinted_path) + _validate.validate_js_path(dash_app.registered_paths, package_name, path_in_pkg) + extension = "." + path_in_pkg.split(".")[-1] + mimetype = mimetypes.types_map.get(extension, "application/octet-stream") + package = sys.modules[package_name] + dash_app.logger.debug( + "serving -- package: %s[%s] resource: %s => location: %s", + package_name, + package.__version__, + path_in_pkg, + package.__path__, + ) + data = pkgutil.get_data(package_name, path_in_pkg) + response = Response(data, mimetype=mimetype) + if has_fingerprint: + response.cache_control.max_age = 31536000 # 1 year + else: + response.add_etag() + tag = response.get_etag()[0] + request_etag = request.headers.get("If-None-Match") + if f'"{tag}"' == request_etag: + response = Response(None, status=304) + return response + + def setup_component_suites(self, dash_app: Dash): + def serve(package_name, fingerprinted_path): + return self.serve_component_suites( + dash_app, package_name, fingerprinted_path + ) + + # pylint: disable=protected-access + dash_app._add_url( + "_dash-component-suites//", + serve, + ) + + def _create_redirect_function(self, redirect_to): + def _redirect(): + return redirect(redirect_to, code=301) + + return _redirect + + def add_redirect_rule(self, app, fullname, path): + self.server.add_url_rule( + fullname, + fullname, + self._create_redirect_function(app.get_relative_path(path)), + ) + + # pylint: disable=unused-argument + def serve_callback(self, dash_app: Dash): + def _dispatch(): + body = request.get_json() + # pylint: disable=protected-access + cb_ctx = dash_app._initialize_context(body) + func = dash_app._prepare_callback(cb_ctx, body) + args = dash_app._inputs_to_vals(cb_ctx.inputs_list + cb_ctx.states_list) + ctx = copy_context() + partial_func = dash_app._execute_callback( + func, args, cb_ctx.outputs_list, cb_ctx + ) + response_data = ctx.run(partial_func) + if asyncio.iscoroutine(response_data): + raise Exception( + "You are trying to use a coroutine without dash[async]. " + "Please install the dependencies via `pip install dash[async]` and ensure " + "that `use_async=False` is not being passed to the app." + ) + return cb_ctx.dash_response.set_response(data=response_data) + + async def _dispatch_async(): + body = request.get_json() + # pylint: disable=protected-access + cb_ctx = dash_app._initialize_context(body) + func = dash_app._prepare_callback(cb_ctx, body) + args = dash_app._inputs_to_vals(cb_ctx.inputs_list + cb_ctx.states_list) + ctx = copy_context() + partial_func = dash_app._execute_callback( + func, args, cb_ctx.outputs_list, cb_ctx + ) + response_data = ctx.run(partial_func) + if asyncio.iscoroutine(response_data): + response_data = await response_data + return cb_ctx.dash_response.set_response(data=response_data) + + if dash_app._use_async: # pylint: disable=protected-access + return _dispatch_async + return _dispatch + + def register_timing_hooks(self, _first_run: bool): + # Define timing hooks inside method scope and register them + def _before_request() -> None: + flask_g.timing_information = { # type: ignore[attr-defined] + "__dash_server": {"dur": time.time(), "desc": None} + } + + def _after_request(response: Response): # type: ignore[name-defined] + timing_information = flask_g.get("timing_information", None) # type: ignore[attr-defined] + if timing_information is None: + return response + dash_total = timing_information.get("__dash_server", None) + if dash_total is not None: + dash_total["dur"] = round((time.time() - dash_total["dur"]) * 1000) + for name, info in timing_information.items(): + value = name + if info.get("desc") is not None: + value += f';desc="{info["desc"]}"' + if info.get("dur") is not None: + value += f";dur={info['dur']}" + response.headers.add("Server-Timing", value) + return response + + self.before_request(_before_request) + self.after_request(_after_request) + + def register_callback_api_routes( + self, callback_api_paths: Dict[str, Callable[..., Any]] + ): + """ + Register callback API endpoints on the Flask app. + Each key in callback_api_paths is a route, each value is a handler (sync or async). + The view function parses the JSON body and passes it to the handler. + """ + for path, handler in callback_api_paths.items(): + endpoint = f"dash_callback_api_{path}" + route = path if path.startswith("/") else f"/{path}" + methods = ["POST"] + + if inspect.iscoroutinefunction(handler): + + async def _async_view_func(*args, handler=handler, **kwargs): + data = request.get_json() + result = await handler(**data) if data else await handler() + return jsonify(result) + + view_func = _async_view_func + else: + + def _sync_view_func(*args, handler=handler, **kwargs): + data = request.get_json() + result = handler(**data) if data else handler() + return jsonify(result) + + view_func = _sync_view_func + + view_func = _sync_view_func + + # Flask 2.x+ supports async views natively + self.server.add_url_rule( + route, endpoint=endpoint, view_func=view_func, methods=methods + ) + + def enable_compression(self) -> None: + try: + import flask_compress # pylint: disable=import-outside-toplevel + + Compress = flask_compress.Compress + Compress(self.server) + _flask_compress_version = parse_version( + _get_distribution_version("flask_compress") + ) + if not hasattr( + self.server.config, "COMPRESS_ALGORITHM" + ) and _flask_compress_version >= parse_version("1.6.0"): + self.server.config["COMPRESS_ALGORITHM"] = ["gzip"] + except ImportError as error: + raise ImportError( + "To use the compress option, you need to install dash[compress]" + ) from error + + +class FlaskRequestAdapter(RequestAdapter): + """Flask implementation using property-based accessors.""" + + def __init__(self) -> None: + # Store the request LocalProxy so we can reference it consistently + self._request = request + super().__init__() + + def __call__(self, *args: Any, **kwds: Any): + return self + + @property + def context(self): + if not has_request_context(): + raise RuntimeError("No active request in context") + return flask_g + + @property + def args(self): + return self._request.args + + @property + def root(self): + return self._request.url_root + + def get_json(self): # kept as method + return self._request.get_json() + + @property + def is_json(self): + return self._request.is_json + + @property + def cookies(self): + return self._request.cookies + + @property + def headers(self): + return self._request.headers + + @property + def url(self): + return self._request.url + + @property + def full_path(self): + return self._request.full_path + + @property + def remote_addr(self): + return self._request.remote_addr + + @property + def origin(self): + return getattr(self._request, "origin", None) + + @property + def path(self): + return self._request.path diff --git a/dash/backends/_quart.py b/dash/backends/_quart.py new file mode 100644 index 0000000000..0916f206fb --- /dev/null +++ b/dash/backends/_quart.py @@ -0,0 +1,731 @@ +from __future__ import annotations + +import typing as _t +import mimetypes +import inspect +import pkgutil +import time +import sys +import asyncio +import concurrent.futures +import queue +import threading +from urllib.parse import urlparse + +from logging.config import dictConfig +from contextvars import copy_context +from typing import Any, Dict, TYPE_CHECKING + +from importlib_metadata import version as _get_distribution_version + +# Attempt top-level Quart imports; allow absence if user not using quart backend +try: + from quart import ( + Quart, + Response, + jsonify, + request, + Blueprint, + g as quart_g, + has_request_context, + redirect, + websocket, + ) +except ImportError as _err: + raise ImportError( + "All dependencies not installed. Please install it with `dash[quart]` to use the Quart backend." + ) from _err + +import janus + +from dash.exceptions import PreventUpdate, InvalidResourceError +from dash.fingerprint import check_fingerprint +from dash._utils import parse_version +from dash import _validate +from .base_server import ( + BaseDashServer, + RequestAdapter, + ResponseAdapter, +) +from .ws import ( + DashWebsocketCallback, + run_ws_sender, + run_callback_in_executor, + make_callback_done_handler, + SHUTDOWN_SIGNAL, + DISCONNECTED, +) +from ._utils import format_traceback_html + +if TYPE_CHECKING: + from dash import Dash + + +class QuartResponseAdapter(ResponseAdapter): + """ + A custom Response class that wraps Quart's Response + and provides a set_response() method for compatibility with Dash's callback system. + """ + + def __init__(self): + self._quart_response = Response(content_type="application/json") + super().__init__() + + @property + def callback_response(self) -> Response: + return self._quart_response + + def set_cookie(self, key, value="", **kwargs): + self._quart_response.set_cookie(key, value, **kwargs) + + def append_header(self, key, value): + self._quart_response.headers.add(key, value) + + def set_header(self, key, value): + self._quart_response.headers.set(key, value) + + def set_response(self, **kwargs): + self._quart_response.set_data(kwargs.get("data", "")) + return self._quart_response + + +class QuartDashServer(BaseDashServer[Quart]): + websocket_capability: bool = True + + def __init__(self, server: Quart) -> None: + super().__init__(server) + self.server_type = "quart" + self.config = {} + self.error_handling_mode = "ignore" + self.request_adapter = QuartRequestAdapter + self.response_adapter = QuartResponseAdapter + self._active_websockets: set = set() + self._ws_shutdown_event: asyncio.Event | None = None + + def __call__(self, *args: Any, **kwargs: Any): # type: ignore[name-defined] + return self.server(*args, **kwargs) + + @staticmethod + def create_app( + name: str = "__main__", config: _t.Optional[_t.Dict[str, _t.Any]] = None + ): + if Quart is None: + raise RuntimeError( + "Quart is not installed. Install with 'pip install quart' to use the quart backend." + ) + app = Quart(name) # type: ignore + if config: + for key, value in config.items(): + app.config[key] = value + return app + + def register_assets_blueprint( + self, blueprint_name: str, assets_url_path: str, assets_folder: str # type: ignore[name-defined] + ): + + bp = Blueprint( + blueprint_name, + __name__, + static_folder=assets_folder, + static_url_path=assets_url_path, + ) + self.server.register_blueprint(bp) + + def _get_traceback(self, _secret, error: Exception): + return format_traceback_html( + error, self.error_handling_mode, "Quart Debugger", "Quart" + ) + + def register_prune_error_handler(self, secret, prune_errors): + if prune_errors: + self.error_handling_mode = "prune" + else: + self.error_handling_mode = "raise" + + @self.server.errorhandler(Exception) + async def _wrap_errors(error): + if self.error_handling_mode == "ignore": + return Response( + "Internal server error.", status=500, content_type="text/plain" + ) + tb = self._get_traceback(secret, error) + return Response(tb, status=500, content_type="text/html") + + def register_timing_hooks(self, _first_run: bool): # type: ignore[name-defined] parity with Flask factory + @self.server.before_request + async def _before_request(): # pragma: no cover - timing infra + if quart_g is not None: + quart_g.timing_information = { # type: ignore[attr-defined] + "__dash_server": {"dur": time.time(), "desc": None} + } + + @self.server.after_request + async def _after_request(response): # pragma: no cover - timing infra + timing_information = ( + getattr(quart_g, "timing_information", None) + if quart_g is not None + else None + ) + if timing_information is None: + return response + dash_total = timing_information.get("__dash_server", None) + if dash_total is not None: + dash_total["dur"] = round((time.time() - dash_total["dur"]) * 1000) + for name, info in timing_information.items(): + value = name + if info.get("desc") is not None: + value += f';desc="{info["desc"]}"' + if info.get("dur") is not None: + value += f";dur={info['dur']}" + # Quart/Werkzeug headers expose 'add' (not 'append') + if hasattr(response.headers, "add"): + response.headers.add("Server-Timing", value) + else: # fallback just in case + response.headers["Server-Timing"] = value + return response + + def register_error_handlers(self): # type: ignore[name-defined] + @self.server.errorhandler(PreventUpdate) + async def _prevent_update(_): + return "", 204 + + @self.server.errorhandler(InvalidResourceError) + async def _invalid_resource(err): + return err.args[0], 404 + + def _html_response_wrapper(self, view_func: _t.Callable[..., _t.Any] | str): + async def wrapped(*_args, **_kwargs): + html_val = view_func() if callable(view_func) else view_func + if inspect.iscoroutine(html_val): # handle async function returning html + html_val = await html_val + html = str(html_val) + return Response(html, content_type="text/html") + + return wrapped + + def add_url_rule( + self, + rule: str, + view_func: _t.Callable[..., _t.Any], + endpoint: str | None = None, + methods: list[str] | None = None, + ): + self.server.add_url_rule( + rule, view_func=view_func, endpoint=endpoint, methods=methods or ["GET"] + ) + + def setup_index(self, dash_app: Dash): # type: ignore[name-defined] + async def index(*args, **kwargs): + return Response(dash_app.index(*args, **kwargs), content_type="text/html") # type: ignore[arg-type] + + # pylint: disable=protected-access + dash_app._add_url("", index, methods=["GET"]) + + def setup_catchall(self, dash_app: Dash): + async def catchall( + path: str, *args, **kwargs + ): # noqa: ARG001 - path is unused but kept for route signature, pylint: disable=unused-argument + return Response(dash_app.index(*args, **kwargs), content_type="text/html") # type: ignore[arg-type] + + # pylint: disable=protected-access + dash_app._add_url("", catchall, methods=["GET"]) + + def before_request(self, func: _t.Callable[[], _t.Any]): + self.server.before_request(func) + + def after_request(self, func: _t.Callable[[], _t.Any]): + @self.server.after_request + async def _after(response): + if func is not None: + result = func() + if inspect.iscoroutine(result): # Allow async hooks + await result + return response + + def has_request_context(self) -> bool: + if has_request_context is None: + raise RuntimeError("Quart not installed; cannot check request context") + return has_request_context() + + # pylint: disable=W0613 + def run(self, dash_app: Dash, host: str, port: int, debug: bool, **kwargs: _t.Any): + import signal # pylint: disable=import-outside-toplevel + + # pylint: disable=import-outside-toplevel,import-error + from hypercorn.config import Config + from hypercorn.asyncio import serve + + # pylint: enable=import-error + + self.config = {"debug": debug, **kwargs} if debug else kwargs + # pylint: disable=protected-access + if dash_app._dev_tools.silence_routes_logging: + dictConfig( + { + "version": 1, + "loggers": { + "quart.app": { + "level": "ERROR", + }, + }, + } + ) + + # Check if we're running in a non-main thread (e.g., testing context) + is_main_thread = threading.current_thread() is threading.main_thread() + + config = Config() + config.bind = [f"{host}:{port}"] + config.use_reloader = False + if not is_main_thread: + config.accesslog = None + + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + + # Initialize shutdown event for WebSocket handlers + self._ws_shutdown_event = asyncio.Event() + + def signal_handler(): + """Handle shutdown signal by setting the WebSocket shutdown event.""" + if self._ws_shutdown_event is not None: + self._ws_shutdown_event.set() + + # Set up signal handlers in main thread + if is_main_thread: + for sig in (signal.SIGINT, signal.SIGTERM): + try: + loop.add_signal_handler(sig, signal_handler) + except (NotImplementedError, ValueError): + pass + + print(f" * Serving Quart app '{self.server.name}'") + print(f" * Debug mode: {debug}") + print( + " * Please use an ASGI server (e.g. Hypercorn) directly in production" + ) + print(f" * Running on http://{host}:{port} (CTRL + C to quit)") + + async def shutdown_trigger(): + if self._ws_shutdown_event is not None: + await self._ws_shutdown_event.wait() + + try: + loop.run_until_complete( + serve(self.server, config, shutdown_trigger=shutdown_trigger) + ) + finally: + loop.close() + + def make_response( + self, + data: str | bytes | bytearray, + mimetype: str | None = None, + content_type: str | None = None, + status=None, + ): + if Response is None: + raise RuntimeError("Quart not installed; cannot generate Response") + return Response( + data, mimetype=mimetype, content_type=content_type, status=status + ) + + def jsonify(self, obj): + return jsonify(obj) + + def serve_component_suites( + self, dash_app: Dash, package_name: str, fingerprinted_path: str + ): # noqa: ARG002 unused req preserved for interface parity + path_in_pkg, has_fingerprint = check_fingerprint(fingerprinted_path) + _validate.validate_js_path(dash_app.registered_paths, package_name, path_in_pkg) + extension = "." + path_in_pkg.split(".")[-1] + mimetype = mimetypes.types_map.get(extension, "application/octet-stream") + package = sys.modules[package_name] + dash_app.logger.debug( + "serving -- package: %s[%s] resource: %s => location: %s", + package_name, + getattr(package, "__version__", "unknown"), + path_in_pkg, + package.__path__, + ) + data = pkgutil.get_data(package_name, path_in_pkg) + headers = {} + if has_fingerprint: + headers["Cache-Control"] = "public, max-age=31536000" + + if Response is None: + raise RuntimeError("Quart not installed; cannot generate Response") + return Response(data, content_type=mimetype, headers=headers) + + def setup_component_suites(self, dash_app: Dash): + async def serve(package_name, fingerprinted_path): + return self.serve_component_suites( + dash_app, package_name, fingerprinted_path + ) + + # pylint: disable=protected-access + dash_app._add_url( + "_dash-component-suites//", + serve, + ) + + def _create_redirect_function(self, redirect_to): + def _redirect(): + return redirect(redirect_to, code=301) + + return _redirect + + def add_redirect_rule(self, app, fullname, path): + self.server.add_url_rule( + fullname, + fullname, + self._create_redirect_function(app.get_relative_path(path)), + ) + + # pylint: disable=unused-argument + def serve_callback(self, dash_app: Dash): # type: ignore[name-defined] Quart always async + async def _dispatch(): + adapter = QuartRequestAdapter() + body = await adapter.get_json() + # pylint: disable=protected-access + cb_ctx = dash_app._initialize_context(body) + # pylint: disable=protected-access + func = dash_app._prepare_callback(cb_ctx, body) + # pylint: disable=protected-access + args = dash_app._inputs_to_vals(cb_ctx.inputs_list + cb_ctx.states_list) + ctx = copy_context() + # pylint: disable=protected-access + partial_func = dash_app._execute_callback( + func, args, cb_ctx.outputs_list, cb_ctx + ) + response_data = ctx.run(partial_func) + if inspect.iscoroutine(response_data): # if user callback is async + response_data = await response_data + return cb_ctx.dash_response.set_response(data=response_data) # type: ignore[arg-type] + + return _dispatch + + def register_callback_api_routes( + self, callback_api_paths: _t.Dict[str, _t.Callable[..., _t.Any]] + ): + """ + Register callback API endpoints on the Quart app. + Each key in callback_api_paths is a route, each value is a handler (sync or async). + The view function parses the JSON body and passes it to the handler. + """ + for path, handler in callback_api_paths.items(): + endpoint = f"dash_callback_api_{path}" + route = path if path.startswith("/") else f"/{path}" + methods = ["POST"] + + def _make_view_func(handler): + if inspect.iscoroutinefunction(handler): + + async def async_view_func(*args, **kwargs): + if request is None: + raise RuntimeError( + "Quart not installed; request unavailable" + ) + data = await request.get_json() + result = await handler(**data) if data else await handler() + return jsonify(result) # type: ignore[arg-type] + + return async_view_func + + async def sync_view_func(*args, **kwargs): + if request is None: + raise RuntimeError("Quart not installed; request unavailable") + data = await request.get_json() + result = handler(**data) if data else handler() + return jsonify(result) # type: ignore[arg-type] + + return sync_view_func + + view_func = _make_view_func(handler) + self.server.add_url_rule( + route, endpoint=endpoint, view_func=view_func, methods=methods + ) + + def enable_compression(self) -> None: + try: + import quart_compress # pylint: disable=import-outside-toplevel + + Compress = quart_compress.Compress + Compress(self.server) + _flask_compress_version = parse_version( + _get_distribution_version("quart_compress") + ) + if not hasattr( + self.server.config, "COMPRESS_ALGORITHM" + ) and _flask_compress_version >= parse_version("1.6.0"): + self.server.config["COMPRESS_ALGORITHM"] = ["gzip"] + except ImportError as error: + raise ImportError( + "To use the compress option, you need to install quart_compress." + ) from error + + async def _run_ws_hooks( + self, hooks, ws, *args, default_reason: str = "Rejected" + ) -> tuple | None: + """Run WebSocket hooks and return rejection tuple or None if all pass. + + Args: + hooks: List of hooks to run + ws: The WebSocket connection + *args: Additional arguments to pass to hooks + default_reason: Default reason if hook returns False + + Returns: + None if all hooks pass, or (code, reason) tuple for rejection + """ + for hook in hooks: + try: + result = hook(ws, *args) + if inspect.iscoroutine(result): + result = await result + if result is False: + return (4001, default_reason) + if isinstance(result, tuple) and len(result) == 2: + return result + except Exception: # pylint: disable=broad-exception-caught + return (4001, "Authentication error") + return None + + def _validate_ws_origin( + self, origin: str | None, host: str | None, allowed_origins: list + ) -> str | None: + """Validate WebSocket origin. Returns error message or None if valid.""" + if not origin: + return "Origin header required" + if origin in allowed_origins: + return None # Explicitly allowed + if not host: + return "Origin not allowed" + # Check same-origin + origin_host = urlparse(origin).netloc + if origin_host != host: + return "Origin not allowed" + return None + + def serve_websocket_callback(self, dash_app: "Dash"): + """Set up the WebSocket endpoint for callback handling. + + Uses thread pool executor for callback execution with janus queues + for async/sync communication between main loop and worker threads. + + Args: + dash_app: The Dash application instance + """ + # pylint: disable=too-many-statements,too-many-locals + ws_path = dash_app.config.requests_pathname_prefix + "_dash-ws-callback" + # pylint: disable=protected-access + allowed_origins = getattr(dash_app, "_websocket_allowed_origins", []) + + @self.server.websocket(ws_path) + async def websocket_handler(): # pylint: disable=too-many-branches + ws = websocket + + # Validate Origin header + error = self._validate_ws_origin( + ws.headers.get("origin"), ws.headers.get("host"), allowed_origins + ) + if error: + await ws.close(code=4003, reason=error) + return + + # Call websocket_connect hooks + # pylint: disable=protected-access + rejection = await self._run_ws_hooks( + dash_app._hooks.get_hooks("websocket_connect"), + ws, + default_reason="Connection rejected", + ) + if rejection: + await ws.close(code=rejection[0], reason=rejection[1]) + return + + await ws.accept() + + # Track this connection for graceful shutdown + try: + ws_obj = ws._get_current_object() + self._active_websockets.add(ws_obj) + except AttributeError: + ws_obj = ws + self._active_websockets.add(ws) + + # Create janus queue for outbound messages (main loop context) + outbound_queue: janus.Queue[str] = janus.Queue() + # Track pending get_props requests with standard queue.Queue for responses + pending_get_props: Dict[str, queue.Queue] = {} + # Shutdown event to signal connection closure to worker threads + connection_shutdown_event = threading.Event() + # Get thread pool executor + executor = self.get_callback_executor() + # Track pending callback futures + pending_callbacks: Dict[str, concurrent.futures.Future] = {} + + # Start sender task to drain outbound queue (sends pre-serialized text) + sender_task = asyncio.create_task(run_ws_sender(ws.send, outbound_queue)) + + try: + shutdown_event = self._ws_shutdown_event + while shutdown_event is None or not shutdown_event.is_set(): + try: + # Use timeout to periodically check shutdown event + message = await asyncio.wait_for(ws.receive_json(), timeout=1.0) + except asyncio.TimeoutError: + # Re-check shutdown event (may have been set during run()) + shutdown_event = self._ws_shutdown_event + continue + + # Call websocket_message hooks + rejection = await self._run_ws_hooks( + dash_app._hooks.get_hooks("websocket_message"), + ws, + message, + default_reason="Message rejected", + ) + if rejection: + await ws.close(code=rejection[0], reason=rejection[1]) + return + + msg_type = message.get("type") + + if msg_type == "callback_request": + request_id = message.get("requestId") + renderer_id = message.get("rendererId", "") + payload = message.get("payload", {}) + + # Validate that the callback is allowed to use WebSocket transport + # pylint: disable=protected-access + _validate.validate_websocket_callback_request( + payload.get("output"), + dash_app.callback_map, + dash_app._websocket_callbacks, + ) + + # Create WebSocket callback instance with outbound queue + ws_cb = DashWebsocketCallback( + pending_get_props, + renderer_id, + outbound_queue, + connection_shutdown_event, + ) + + # Submit callback to executor + future = run_callback_in_executor( + executor, + dash_app, + payload, + ws_cb, + QuartResponseAdapter(), + ) + + # Set up done callback to send response + future.add_done_callback( + make_callback_done_handler( + outbound_queue, + pending_callbacks, + request_id, + renderer_id, + connection_shutdown_event, + ) + ) + pending_callbacks[request_id] = future + + elif msg_type == "get_props_response": + # Put response in waiting queue (non-blocking) + request_id = message.get("requestId") + response_queue = pending_get_props.get(request_id) + if response_queue is not None: + response_queue.put_nowait(message.get("payload")) + + elif msg_type == "heartbeat": + outbound_queue.sync_q.put_nowait('{"type": "heartbeat_ack"}') + + except asyncio.CancelledError: + pass # Server is shutting down, exit gracefully + except Exception: # pylint: disable=broad-exception-caught + pass # Other exceptions treated as disconnect + finally: + self._active_websockets.discard(ws_obj) + # Signal shutdown to worker threads + connection_shutdown_event.set() + # Unblock any threads waiting on get_prop responses + for response_queue in pending_get_props.values(): + response_queue.put_nowait(DISCONNECTED) + # Signal sender to shutdown and cancel it + outbound_queue.sync_q.put_nowait(SHUTDOWN_SIGNAL) + sender_task.cancel() + try: + await sender_task + except asyncio.CancelledError: + pass + # Close the janus queue + outbound_queue.close() + await outbound_queue.wait_closed() + # Cancel any pending futures + for f in pending_callbacks.values(): + f.cancel() + + +class QuartRequestAdapter(RequestAdapter): + def __init__(self) -> None: + self._request = request # type: ignore[assignment] + if self._request is None: + raise RuntimeError("Quart not installed; cannot access request context") + + @property + def context(self): + if not has_request_context(): + raise RuntimeError("No active request in context") + return quart_g + + @property + def request(self) -> _t.Any: + return self._request + + @property + def root(self): + return self.request.root_url + + @property + def args(self): + return self.request.args + + @property + def is_json(self): + return self.request.is_json + + @property + def cookies(self): + return self.request.cookies + + @property + def headers(self): + return self.request.headers + + @property + def full_path(self): + return self.request.full_path + + @property + def url(self): + return str(self.request.url) + + @property + def remote_addr(self): + return self.request.remote_addr + + @property + def origin(self): + return self.request.headers.get("origin") + + @property + def path(self): + return self.request.path + + async def get_json(self): # pylint: disable=W0236 + # TODO consider using a sync wraper + return await self.request.get_json() diff --git a/dash/backends/_utils.py b/dash/backends/_utils.py new file mode 100644 index 0000000000..1191d21038 --- /dev/null +++ b/dash/backends/_utils.py @@ -0,0 +1,108 @@ +import traceback +import re + + +def format_traceback_html(error, error_handling_mode, title, backend): + tb = error.__traceback__ + errors = traceback.format_exception(type(error), error, tb) + pass_errs = [] + callback_handled = False + for err in errors: + if error_handling_mode == "prune": + if not callback_handled: + if "callback invoked" in str(err) and "_callback.py" in str(err): + callback_handled = True + continue + pass_errs.append(err) + formatted_tb = "".join(pass_errs) + error_type = type(error).__name__ + error_msg = str(error) + # Parse traceback lines to group by file + file_cards = [] + pattern = re.compile(r' File "(.+)", line (\d+), in (\w+)') + lines = formatted_tb.split("\n") + current_file = None + card_lines = [] + for line in lines[:-1]: # Skip the last line (error message) + match = pattern.match(line) + if match: + if current_file and card_lines: + file_cards.append((current_file, card_lines)) + current_file = ( + f"{match.group(1)} (line {match.group(2)}, in {match.group(3)})" + ) + card_lines = [line] + elif current_file: + card_lines.append(line) + if current_file and card_lines: + file_cards.append((current_file, card_lines)) + cards_html = "" + for filename, card in file_cards: + cards_html += ( + f""" +
+
{filename}
+
"""
+            + "\n".join(card)
+            + """
+
+ """ + ) + html = f""" + + + + {error_type}: {error_msg} // {title} + + + +
+

{error_type}

+
+

{error_type}: {error_msg}

+
+

Traceback (most recent call last)

+ {cards_html} +
{error_type}: {error_msg}
+
+

This is the Copy/Paste friendly version of the traceback.

+ +
+
+ The debugger caught an exception in your Dash application. You can now + look at the traceback which led to the error. +
+
+ Brought to you by DON'T PANIC, your + friendly {backend} powered traceback interpreter. +
+
+ + + """ + return html diff --git a/dash/backends/base_server.py b/dash/backends/base_server.py new file mode 100644 index 0000000000..52443d4104 --- /dev/null +++ b/dash/backends/base_server.py @@ -0,0 +1,422 @@ +"""Base server abstractions for Dash backend implementations. + +This module provides abstract base classes and protocols that define the interface +for different web server backends (Flask, Quart, FastAPI, etc.) to integrate with Dash. +""" +from __future__ import annotations + +from abc import ABC, abstractmethod +from concurrent.futures import ThreadPoolExecutor +from typing import ( + Any, + Dict, + Type, + TypeVar, + Generic, + Protocol, + TYPE_CHECKING, +) + +if TYPE_CHECKING: + import dash + + +class _ServerCallable(Protocol): # pylint: disable=too-few-public-methods + """Protocol for callable server instances. + + Defines the interface for server objects that can be called as WSGI/ASGI applications. + """ + + def __call__(self, *args: Any, **kwds: Any) -> Any: + raise NotImplementedError + + +ServerType = TypeVar("ServerType", bound=_ServerCallable) + + +class RequestAdapter(ABC): + """Abstract adapter for normalizing HTTP request objects across different server backends. + + This adapter provides a unified interface for accessing request data regardless of + the underlying web framework (Flask, Quart, FastAPI, etc.). Concrete implementations + wrap framework-specific request objects and expose their data through these properties. + """ + + def __call__(self) -> "RequestAdapter": + return self + + @property + @abstractmethod + def context(self) -> Any: # pragma: no cover - interface + """Get the framework-specific request context object.""" + raise NotImplementedError() + + # Properties to be implemented in concrete adapters + @property # pragma: no cover - interface + @abstractmethod + def root(self) -> str: + """Get the application root path.""" + raise NotImplementedError() + + @property # pragma: no cover - interface + @abstractmethod + def args(self): + """Get the request query string arguments.""" + raise NotImplementedError() + + @abstractmethod # kept as method (may be sync or async) + def get_json(self): # pragma: no cover - interface + """Get the parsed JSON body of the request. + + May be synchronous or asynchronous depending on the backend. + """ + raise NotImplementedError() + + @property # pragma: no cover - interface + @abstractmethod + def is_json(self) -> bool: + """Check if the request has a JSON content type.""" + raise NotImplementedError() + + @property # pragma: no cover - interface + @abstractmethod + def cookies(self): + """Get the request cookies.""" + raise NotImplementedError() + + @property # pragma: no cover - interface + @abstractmethod + def headers(self): + """Get the request headers.""" + raise NotImplementedError() + + @property # pragma: no cover - interface + @abstractmethod + def full_path(self) -> str: + """Get the full request path including query string.""" + raise NotImplementedError() + + @property # pragma: no cover - interface + @abstractmethod + def url(self) -> str: + """Get the full request URL.""" + raise NotImplementedError() + + @property # pragma: no cover - interface + @abstractmethod + def remote_addr(self): + """Get the remote client IP address.""" + raise NotImplementedError() + + @property # pragma: no cover - interface + @abstractmethod + def origin(self): + """Get the Origin header value.""" + raise NotImplementedError() + + @property # pragma: no cover - interface + @abstractmethod + def path(self) -> str: + """Get the request path without query string.""" + raise NotImplementedError() + + +class ResponseAdapter: + """Adapter for server response objects to allow setting data.""" + + def __init__(self): + # Accept a pre-made response object + self._headers = {} + self._cookies = {} + + @property + def callback_response(self): + """Get the response object to be returned from a callback.""" + # This method should be overridden in concrete implementations to return the appropriate response object + raise NotImplementedError() + + def set_cookie(self, key, value="", **kwargs): + """Set a cookie in the response (like Flask's set_cookie).""" + # Store as a tuple: (value, kwargs) + self._cookies[key] = (value, kwargs) + + def append_header(self, key, value): + """Add a header to the response (like Flask's headers.add).""" + # Allow multiple values per header key + if key in self._headers: + if isinstance(self._headers[key], list): + self._headers[key].append(value) + else: + self._headers[key] = [self._headers[key], value] + else: + self._headers[key] = value + + def set_header(self, key, value): + """Set a header to the response.""" + self._headers[key] = [value] + + def set_response(self, **kwargs): + """Set the response data if supported by the response object.""" + raise NotImplementedError() + + +class BaseDashServer(ABC, Generic[ServerType]): + """Abstract base class for Dash server backend implementations. + + This class defines the interface that all server backends must implement to + work with Dash. Concrete implementations exist for Flask, Quart, FastAPI, and + other web frameworks. + + Attributes: + server_type: String identifier for the server backend (e.g., 'flask', 'quart') + server: The underlying server instance + config: Configuration dictionary for the server + request_adapter: RequestAdapter class for normalizing requests + """ + + server_type: str + server: ServerType + config: Dict[str, Any] + request_adapter: Type[RequestAdapter] + response_adapter: Type[ResponseAdapter] + websocket_capability: bool = False + + def __init__(self, server: ServerType) -> None: + """Initialize the server wrapper. + + Args: + server: The underlying server instance to wrap + """ + super().__init__() + self.server = server + self._callback_executor: ThreadPoolExecutor | None = None + + def get_callback_executor( + self, max_workers: int | None = None + ) -> ThreadPoolExecutor: + """Get or create the thread pool executor for callback execution. + + Args: + max_workers: Maximum number of worker threads. If None, uses default. + + Returns: + ThreadPoolExecutor instance for running callbacks. + """ + if self._callback_executor is None: + self._callback_executor = ThreadPoolExecutor( + max_workers=max_workers, thread_name_prefix="dash-callback-" + ) + return self._callback_executor + + def shutdown_executor(self, wait: bool = True) -> None: + """Shutdown the callback executor. + + Args: + wait: If True, wait for pending tasks to complete. + """ + if self._callback_executor is not None: + self._callback_executor.shutdown(wait=wait) + self._callback_executor = None + + def __call__(self, *args, **kwargs) -> Any: + """Make the server wrapper callable as a WSGI/ASGI application. + + Delegates to the underlying server instance. + """ + # Default: WSGI + return self.server(*args, **kwargs) + + @staticmethod + @abstractmethod + def create_app( + name: str = "__main__", config=None + ) -> Any: # pragma: no cover - interface + """Create a new server application instance. + + Args: + name: Application name, defaults to '__main__' + config: Configuration dictionary or object + + Returns: + The server application instance + """ + + @abstractmethod + def register_assets_blueprint( + self, blueprint_name: str, assets_url_path: str, assets_folder: str + ) -> None: # pragma: no cover - interface + """Register a blueprint/router for serving static assets. + + Args: + blueprint_name: Name for the assets blueprint + assets_url_path: URL path prefix for assets + assets_folder: Filesystem path to the assets folder + """ + + @abstractmethod + def register_error_handlers(self) -> None: # pragma: no cover - interface + """Register error handlers for common HTTP errors.""" + + @abstractmethod + def add_url_rule( + self, rule: str, view_func, endpoint=None, methods=None + ) -> None: # pragma: no cover - interface + """Add a URL routing rule. + + Args: + rule: URL pattern/route + view_func: View function to handle the route + endpoint: Optional endpoint name + methods: Optional list of HTTP methods (e.g., ['GET', 'POST']) + """ + + @abstractmethod + def before_request(self, func) -> None: # pragma: no cover - interface + """Register a function to run before each request. + + Args: + func: Function to execute before request handling + """ + + @abstractmethod + def after_request(self, func) -> None: # pragma: no cover - interface + """Register a function to run after each request. + + Args: + func: Function to execute after request handling + """ + + @abstractmethod + def has_request_context(self) -> bool: # pragma: no cover - interface + """Check if currently executing within a request context. + + Returns: + True if in request context, False otherwise + """ + + @abstractmethod + def run( + self, dash_app, host: str, port: int, debug: bool, **kwargs + ) -> None: # pragma: no cover - interface + """Start the development server. + + Args: + dash_app: The Dash application instance + host: Hostname to bind to + port: Port number to bind to + debug: Enable debug mode + **kwargs: Additional server-specific arguments + """ + + @abstractmethod + def make_response( + self, + data, + mimetype=None, + content_type=None, + status=None, + ) -> Any: # pragma: no cover - interface + """Create an HTTP response object. + + Args: + data: Response body data + mimetype: MIME type of the response + content_type: Content-Type header value + status: HTTP status code + + Returns: + Server-specific response object + """ + + @abstractmethod + def jsonify(self, obj) -> Any: # pragma: no cover - interface + """Convert an object to a JSON response. + + Args: + obj: Object to serialize to JSON + + Returns: + JSON response object + """ + + @abstractmethod + def enable_compression(self) -> None: # pragma: no cover - interface + """Enable HTTP compression for responses.""" + + @abstractmethod + def register_prune_error_handler(self, secret: str, prune_errors: bool) -> None: + """Register handler for pruning error stack traces. + + Args: + secret: Secret key for error handling + prune_errors: Whether to prune stack traces in errors + """ + + @abstractmethod + def register_timing_hooks(self, first_run: bool) -> None: + """Register hooks for timing request/response cycles. + + Args: + first_run: Whether this is the first run of the application + """ + + @abstractmethod + def register_callback_api_routes(self, callback_api_paths): + """Register routes for Dash callback API endpoints. + + Args: + callback_api_paths: Paths for callback API endpoints + """ + + @abstractmethod + def setup_component_suites(self, dash_app: "dash.Dash") -> str: + """Set up routes for serving component JavaScript bundles. + + Args: + dash_app: The Dash application instance + + Returns: + Base path for component suites + """ + + @abstractmethod + def serve_callback(self, dash_app: "dash.Dash"): + """Set up the callback handling endpoint. + + Args: + dash_app: The Dash application instance + """ + + @abstractmethod + def setup_index(self, dash_app: "dash.Dash"): + """Set up the index/root route for serving the main application. + + Args: + dash_app: The Dash application instance + """ + + @abstractmethod + def setup_catchall(self, dash_app: "dash.Dash"): + """Set up the catchall route for client-side routing. + + Args: + dash_app: The Dash application instance + """ + + def setup_backend(self, dash_app: "dash.Dash"): + """Perform any additional backend-specific setup. + + Override this method in concrete implementations to provide custom setup logic. + + Args: + dash_app: The Dash application instance + """ + + def serve_websocket_callback(self, dash_app: "dash.Dash"): + """Set up the WebSocket endpoint for callback handling. + + Override this method in backends that support WebSocket callbacks. + + Args: + dash_app: The Dash application instance + """ diff --git a/dash/backends/ws.py b/dash/backends/ws.py new file mode 100644 index 0000000000..041241823e --- /dev/null +++ b/dash/backends/ws.py @@ -0,0 +1,333 @@ +"""WebSocket callback support for Dash backend implementations. + +This module provides the WebSocket callback infrastructure for real-time +bidirectional communication between Dash backends and the renderer. +""" +from __future__ import annotations + +import asyncio +import concurrent.futures +from concurrent.futures import ThreadPoolExecutor +import inspect +import json +import queue +import threading +import traceback +import uuid +from contextvars import copy_context +from typing import Any, Callable, Dict, TYPE_CHECKING, cast + +import janus + +from dash.exceptions import PreventUpdate, WebsocketDisconnected +from dash._utils import to_json + +if TYPE_CHECKING: + import dash + from .base_server import ResponseAdapter + + +SHUTDOWN_SIGNAL = "__shutdown__" +DISCONNECTED = "__disconnected__" + + +class DashWebsocketCallback: + """WebSocket callback communication via queues. + + Provides methods for real-time bidirectional communication between + the server and renderer during callback execution. + + Uses janus.Queue for outbound messages (serialized with to_json) and + queue.Queue for get_props responses, enabling thread-safe communication + between worker threads and the main event loop. + """ + + def __init__( + self, + pending_get_props: Dict[str, queue.Queue[Any]], + renderer_id: str, + outbound_queue: janus.Queue[str], + shutdown_event: "threading.Event", + ): + """Initialize the WebSocket callback interface. + + Args: + pending_get_props: Dict to track pending get_props requests. + Values are queue.Queue instances for blocking response retrieval. + renderer_id: The renderer ID for routing messages back to the correct client + outbound_queue: janus.Queue for thread-safe outbound messaging. + shutdown_event: Event signaling the websocket connection has closed. + """ + self._pending_get_props = pending_get_props + self._renderer_id = renderer_id + self._outbound_queue = outbound_queue + self._shutdown_event = shutdown_event + + @property + def is_shutdown(self) -> bool: + """Check if the websocket connection has been shut down.""" + return self._shutdown_event.is_set() + + def _queue_message(self, msg: dict) -> None: + """Serialize and queue message for sending (thread-safe, non-blocking). + + Uses to_json for proper serialization of Dash components. + Does nothing if the connection has been shut down. + """ + if self._shutdown_event.is_set(): + return + self._outbound_queue.sync_q.put_nowait(cast(str, to_json(msg))) + + async def set_prop(self, component_id: str, prop_name: str, value: Any) -> None: + """Send immediate prop update to the client via WebSocket. + + Queues the message for the sender coroutine to send. + + Args: + component_id: The component ID (string or stringified dict) + prop_name: The property name to update + value: The new value to set + """ + msg = { + "type": "set_props", + "rendererId": self._renderer_id, + "payload": {"componentId": component_id, "props": {prop_name: value}}, + } + self._queue_message(msg) + + async def get_prop( + self, component_id: str, prop_name: str, timeout: float = 30.0 + ) -> Any: + """Request current prop value from the client. + + Uses queue.Queue for blocking wait in worker thread. + + Args: + component_id: The component ID (string or stringified dict) + prop_name: The property name to retrieve + timeout: Timeout in seconds for waiting for response + + Returns: + The current value of the property from the client's state + + Raises: + WebsocketDisconnected: If the websocket connection has been closed. + TimeoutError: If the response doesn't arrive within the timeout. + """ + if self._shutdown_event.is_set(): + raise WebsocketDisconnected() + + request_id = str(uuid.uuid4()) + msg = { + "type": "get_props_request", + "rendererId": self._renderer_id, + "requestId": request_id, + "payload": {"componentId": component_id, "properties": [prop_name]}, + } + + # Use standard queue.Queue for response + response_queue: queue.Queue = queue.Queue() + self._pending_get_props[request_id] = response_queue + + # Queue the outbound request via janus sync interface + self._queue_message(msg) + + # Wait for response (blocking is OK in worker thread) + try: + result = response_queue.get(timeout=timeout) + if result == DISCONNECTED: + raise WebsocketDisconnected() + if result and prop_name in result: + return result[prop_name] + return None + except queue.Empty as exc: + raise TimeoutError( + f"Timeout waiting for {component_id}.{prop_name}" + ) from exc + finally: + self._pending_get_props.pop(request_id, None) + + +def create_ws_context( + payload: dict, + response_adapter: "ResponseAdapter", + websocket_callback: DashWebsocketCallback, +): + """Create callback context from WebSocket message. + + Args: + payload: The callback payload + response_adapter: The response adapter instance for the backend + websocket_callback: The websocket callback instance for the backend + + Returns: + AttributeDict with callback context + """ + # pylint: disable=import-outside-toplevel + from dash._utils import AttributeDict, inputs_to_dict + + g = AttributeDict({}) + g.inputs_list = payload.get("inputs", []) + g.states_list = payload.get("state", []) + g.outputs_list = payload.get("outputs", []) + g.input_values = inputs_to_dict(g.inputs_list) + g.state_values = inputs_to_dict(g.states_list) + g.triggered_inputs = [ + {"prop_id": x, "value": g.input_values.get(x)} + for x in payload.get("changedPropIds", []) + ] + g.dash_response = response_adapter + g.updated_props = {} + g.dash_websocket = websocket_callback + + return g + + +async def run_ws_sender( + send_text: Callable[[str], Any], outbound_queue: janus.Queue[str] +) -> None: + """Sender coroutine - drains queue and sends to WebSocket. + + This coroutine runs in the main event loop and handles sending + messages that are queued by worker threads via janus.Queue. + + Messages are pre-serialized strings (using to_json). + + Args: + send_text: Async function to send text data over WebSocket + outbound_queue: janus.Queue instance for receiving messages (strings) + """ + try: + while True: + msg = await outbound_queue.async_q.get() + if msg == SHUTDOWN_SIGNAL: + break + await send_text(msg) + except asyncio.CancelledError: + pass + + +def make_callback_done_handler( + outbound_queue: janus.Queue[str], + pending_callbacks: Dict[str, concurrent.futures.Future], + request_id: str, + renderer_id: str, + shutdown_event: threading.Event, +) -> Callable[[concurrent.futures.Future], None]: + """Create a done callback handler for executor futures. + + This factory creates a callback that sends the result back through + the WebSocket when an executor future completes. + + Args: + outbound_queue: janus.Queue for sending responses + pending_callbacks: Dict tracking pending callbacks for cleanup + request_id: The request ID for the callback response + renderer_id: The renderer ID for routing the response + shutdown_event: Event signaling the websocket connection has closed. + + Returns: + A callback function suitable for Future.add_done_callback() + """ + + def on_done(f: concurrent.futures.Future) -> None: + try: + if shutdown_event.is_set(): + return + result = f.result() + outbound_queue.sync_q.put_nowait( + cast( + str, + to_json( + { + "type": "callback_response", + "rendererId": renderer_id, + "requestId": request_id, + "payload": result, + } + ), + ) + ) + except Exception as e: # pylint: disable=broad-exception-caught + if shutdown_event.is_set(): + return + outbound_queue.sync_q.put_nowait( + cast( + str, + to_json( + { + "type": "callback_response", + "rendererId": renderer_id, + "requestId": request_id, + "payload": { + "status": "error", + "message": str(e), + }, + } + ), + ) + ) + finally: + pending_callbacks.pop(request_id, None) + + return on_done + + +def run_callback_in_executor( + executor: ThreadPoolExecutor, + dash_app: "dash.Dash", + payload: dict, + ws_callback: DashWebsocketCallback, + response_adapter: "ResponseAdapter", +) -> concurrent.futures.Future: + """Submit callback to executor for thread pool execution. + + This function creates a callback execution context and runs it + in a separate thread. Both sync and async callbacks are supported. + + Args: + executor: ThreadPoolExecutor to submit the task to + dash_app: The Dash application instance + payload: The callback payload from WebSocket message + ws_callback: WebSocket callback instance for set_prop/get_prop + response_adapter: Response adapter for the backend + + Returns: + Future representing the pending callback execution + """ + + def execute() -> dict: + try: + cb_ctx = create_ws_context(payload, response_adapter, ws_callback) + # pylint: disable=protected-access + func = dash_app._prepare_callback(cb_ctx, payload) + args = dash_app._inputs_to_vals( # pylint: disable=protected-access + cb_ctx.inputs_list + cb_ctx.states_list + ) + + ctx = copy_context() + partial_func = ( + dash_app._execute_callback( # pylint: disable=protected-access + func, args, cb_ctx.outputs_list, cb_ctx + ) + ) + + # Run in new event loop (handles both sync and async callbacks) + def run_callback(): + result = partial_func() + if inspect.iscoroutine(result): + return asyncio.run(result) + return result + + response_data = ctx.run(run_callback) + return {"status": "ok", "data": json.loads(response_data)} + + except PreventUpdate: + return {"status": "prevent_update"} + except WebsocketDisconnected: + return {"status": "prevent_update"} + except Exception as e: # pylint: disable=broad-exception-caught + traceback.print_exc() + return {"status": "error", "message": str(e)} + + return executor.submit(execute) diff --git a/dash/dash-renderer/init.template b/dash/dash-renderer/init.template index 463cfa02aa..a6b84d3d70 100644 --- a/dash/dash-renderer/init.template +++ b/dash/dash-renderer/init.template @@ -75,4 +75,9 @@ _js_dist = [ "namespace": "dash", "dynamic": True, }, + { + "relative_package_path": "dash-renderer/build/dash-ws-worker.js", + "namespace": "dash", + "dynamic": True, + }, ] diff --git a/dash/dash-renderer/package.json b/dash/dash-renderer/package.json index f92d22cfc5..a404fa2425 100644 --- a/dash/dash-renderer/package.json +++ b/dash/dash-renderer/package.json @@ -13,7 +13,7 @@ "build:dev": "webpack", "build:local": "renderer build local", "build": "renderer build && npm run prepublishOnly", - "postbuild": "es-check es2015 ../deps/*.js build/*.js", + "postbuild": "es-check es2015 ../deps/*.js build/dash_renderer.*.js", "test": "karma start karma.conf.js --single-run", "format": "run-s private::format.*", "lint": "run-s private::lint.* --continue-on-error" diff --git a/dash/dash-renderer/src/AppProvider.react.tsx b/dash/dash-renderer/src/AppProvider.react.tsx index 343789ca43..f9d8b06f1a 100644 --- a/dash/dash-renderer/src/AppProvider.react.tsx +++ b/dash/dash-renderer/src/AppProvider.react.tsx @@ -1,9 +1,14 @@ import PropTypes from 'prop-types'; -import React, {useState} from 'react'; +import React, {useState, useEffect} from 'react'; import {Provider} from 'react-redux'; import Store from './store'; import AppContainer from './AppContainer.react'; +import getConfigFromDOM from './config'; +import { + initializeWebSocket, + disconnectWebSocket +} from './observers/websocketObserver'; const AppProvider = ({ hooks = { @@ -16,6 +21,35 @@ const AppProvider = ({ } }: any) => { const [{store}] = useState(() => new Store()); + + // Initialize WebSocket connection if enabled or if websocket config is available + // (for per-callback websocket=True) + useEffect(() => { + const config = getConfigFromDOM(); + if ( + config.websocket?.enabled || + (config.websocket?.url && config.websocket?.worker_url) + ) { + // Add fetch config for consistency + const fullConfig = { + ...config, + fetch: { + credentials: 'same-origin', + headers: { + Accept: 'application/json', + 'Content-Type': 'application/json' + } + } + }; + initializeWebSocket(store, fullConfig); + } + + // Cleanup on unmount + return () => { + disconnectWebSocket(); + }; + }, [store]); + return ( diff --git a/dash/dash-renderer/src/actions/callbacks.ts b/dash/dash-renderer/src/actions/callbacks.ts index 206839cd07..e6604a2337 100644 --- a/dash/dash-renderer/src/actions/callbacks.ts +++ b/dash/dash-renderer/src/actions/callbacks.ts @@ -52,6 +52,11 @@ import {parsePMCId} from './patternMatching'; import {replacePMC} from './patternMatching'; import {loaded, loading} from './loading'; import {getComponentLayout} from '../wrapper/wrapping'; +import { + getWorkerClient, + isWebSocketEnabled, + isWebSocketAvailable +} from '../utils/workerClient'; export const addBlockedCallbacks = createAction( CallbackActionType.AddBlocked @@ -685,6 +690,140 @@ function handleServerside( }); } +/** + * Handle serverside callback via WebSocket connection. + * + * Uses the SharedWorker to send the callback request through the persistent + * WebSocket connection instead of HTTP POST. + */ +async function handleWebsocketCallback( + dispatch: any, + hooks: any, + config: any, + payload: ICallbackPayload, + running: any +): Promise { + if (hooks.request_pre) { + hooks.request_pre(payload); + } + + const requestTime = Date.now(); + let runningOff: any; + + if (running) { + dispatch(sideUpdate(running.running, payload)); + runningOff = running.runningOff; + } + + const workerClient = getWorkerClient(); + + try { + // Ensure WebSocket connection is established + await workerClient.ensureConnected(config); + + const response = await workerClient.sendCallback(payload); + + // Handle running off state + if (runningOff) { + dispatch(sideUpdate(runningOff, payload)); + } + + if (response.status === 'prevent_update') { + // Record timing for profiling + if (config.ui) { + const totalTime = Date.now() - requestTime; + dispatch( + updateResourceUsage({ + id: payload.output, + usage: { + __dash_server: totalTime, + __dash_client: totalTime, + __dash_upload: 0, + __dash_download: 0 + }, + status: STATUS.PREVENT_UPDATE, + result: {}, + inputs: payload.inputs, + state: payload.state + }) + ); + } + return {}; + } + + if (response.status === 'error') { + throw new Error(response.message || 'Callback error'); + } + + // Extract the callback data - structure is {multi: boolean, response: {...}} + const callbackData = response.data as CallbackResponseData; + + // Handle sideUpdate if present + if (callbackData?.sideUpdate) { + dispatch(sideUpdate(callbackData.sideUpdate, payload)); + } + + // Extract the actual outputs from the response + // Format is similar to HTTP path's finishLine function + let result: CallbackResponse; + const {multi, response: callbackResponse} = callbackData || {}; + + if (hooks.request_post) { + hooks.request_post(payload, callbackResponse); + } + + if (multi) { + result = callbackResponse as CallbackResponse; + } else { + // Single output - convert to the expected format + const {output} = payload; + const id = output.substr(0, output.lastIndexOf('.')); + result = {[id]: (callbackResponse as CallbackResponse)?.props}; + } + + // Record timing for profiling + if (config.ui) { + const totalTime = Date.now() - requestTime; + dispatch( + updateResourceUsage({ + id: payload.output, + usage: { + __dash_server: totalTime, + __dash_client: totalTime, + __dash_upload: 0, + __dash_download: 0 + }, + status: STATUS.OK, + result: result || {}, + inputs: payload.inputs, + state: payload.state + }) + ); + } + + return result || {}; + } catch (error) { + // Handle running off state on error + if (runningOff) { + dispatch(sideUpdate(runningOff, payload)); + } + + if (config.ui) { + dispatch( + updateResourceUsage({ + id: payload.output, + status: STATUS.NO_RESPONSE, + result: {}, + inputs: payload.inputs, + state: payload.state + }) + ); + } + + throw error; + } +} + function inputsToDict(inputs_list: any) { // Ported directly from _utils.py, inputs_to_dict // takes an array of inputs (some inputs may be an array) @@ -890,18 +1029,44 @@ export function executeCallback( } ); + // Use WebSocket for callbacks when: + // 1. Global WebSocket is enabled, OR + // 2. Per-callback websocket flag is set (and WebSocket is available) + // (but never for background callbacks) + const useWebSocket = + !background && + (isWebSocketEnabled(config) || + (cb.callback.websocket && + isWebSocketAvailable(config))); + for (let retry = 0; retry <= MAX_AUTH_RETRIES; retry++) { try { - let data = await handleServerside( - dispatch, - hooks, - newConfig, - payload, - background, - additionalArgs.length ? additionalArgs : undefined, - getState, - cb.callback.running - ); + let data: CallbackResponse; + + if (useWebSocket) { + // Use WebSocket path for real-time callbacks + data = await handleWebsocketCallback( + dispatch, + hooks, + newConfig, + payload, + cb.callback.running + ); + } else { + // Use traditional HTTP path + data = await handleServerside( + dispatch, + hooks, + newConfig, + payload, + background, + additionalArgs.length + ? additionalArgs + : undefined, + getState, + cb.callback.running + ); + } if (newHeaders) { dispatch(addHttpHeaders(newHeaders)); diff --git a/dash/dash-renderer/src/actions/dependencies.js b/dash/dash-renderer/src/actions/dependencies.js index 7b5d1665f0..fa29199a1d 100644 --- a/dash/dash-renderer/src/actions/dependencies.js +++ b/dash/dash-renderer/src/actions/dependencies.js @@ -224,14 +224,17 @@ function validateDependencies(parsedDependencies, dispatchError) { 'In the callback for output(s):\n ' + outputs.map(combineIdAndProp).join('\n '); - if (!inputs.length) { + if (!inputs.length && dep.prevent_initial_call) { dispatchError('A callback is missing Inputs', [ head, 'there are no `Input` elements.', 'Without `Input` elements, it will never get called.', '', 'Subscribing to `Input` components will cause the', - 'callback to be called whenever their values change.' + 'callback to be called whenever their values change.', + '', + 'If you want a callback without inputs that fires on initial load,', + 'set prevent_initial_call=False.' ]); } diff --git a/dash/dash-renderer/src/actions/dependencies_ts.ts b/dash/dash-renderer/src/actions/dependencies_ts.ts index 33f968cf91..4056cdeac1 100644 --- a/dash/dash-renderer/src/actions/dependencies_ts.ts +++ b/dash/dash-renderer/src/actions/dependencies_ts.ts @@ -352,12 +352,18 @@ export const getLayoutCallbacks = ( export const getUniqueIdentifier = ({ anyVals, - callback: {inputs, outputs, state} -}: ICallback): string => - concat( - map(combineIdAndProp, [...inputs, ...outputs, ...state]), + callback: {inputs, outputs, state, output} +}: ICallback): string => { + const idParts = map(combineIdAndProp, [...inputs, ...outputs, ...state]); + // For no-output callbacks, include the output hash to ensure uniqueness + if (outputs.length === 0 && output) { + idParts.push(output); + } + return concat( + idParts, Array.isArray(anyVals) ? anyVals : anyVals === '' ? [] : [anyVals] ).join(','); +}; export function includeObservers( id: any, diff --git a/dash/dash-renderer/src/actions/index.js b/dash/dash-renderer/src/actions/index.js index 6169c4f65e..e595f79f9a 100644 --- a/dash/dash-renderer/src/actions/index.js +++ b/dash/dash-renderer/src/actions/index.js @@ -5,7 +5,12 @@ import {getAppState} from '../reducers/constants'; import {getAction} from './constants'; import * as cookie from 'cookie'; import {validateCallbacksToLayout} from './dependencies'; -import {includeObservers, getLayoutCallbacks} from './dependencies_ts'; +import { + includeObservers, + getLayoutCallbacks, + makeResolvedCallback, + resolveDeps +} from './dependencies_ts'; import {computePaths, getPath} from './paths'; import {recordUiEdit} from '../persistence'; @@ -95,13 +100,59 @@ function triggerDefaultState(dispatch, getState) { ); } - dispatch( - addRequestedCallbacks( - getLayoutCallbacks(graphs, paths, layout.components, { - outputsOnly: true - }) - ) + const layoutCallbacks = getLayoutCallbacks( + graphs, + paths, + layout.components, + { + outputsOnly: true + } ); + + // Also include no-output and no-input callbacks that should fire on initial load + const specialCallbacks = (graphs.callbacks || []).reduce((acc, cb) => { + if (cb.prevent_initial_call) { + return acc; + } + + const isNoOutput = cb.noOutput; + const isNoInput = !cb.noOutput && cb.inputs.length === 0; + + if (!isNoOutput && !isNoInput) { + return acc; + } + + const resolved = makeResolvedCallback(cb, resolveDeps(), ''); + resolved.initialCall = true; + + if (isNoOutput) { + // No-output: include if no inputs or any input is in layout + if (cb.inputs.length === 0) { + acc.push(resolved); + } else { + const inputs = resolved.getInputs(paths); + if ( + inputs.some(inp => + Array.isArray(inp) ? inp.length > 0 : inp + ) + ) { + acc.push(resolved); + } + } + } else { + // No-input: include if any output is in layout + const outputs = resolved.getOutputs(paths); + if ( + outputs.some(out => (Array.isArray(out) ? out.length > 0 : out)) + ) { + acc.push(resolved); + } + } + + return acc; + }, []); + + dispatch(addRequestedCallbacks([...layoutCallbacks, ...specialCallbacks])); } export const redo = moveHistory('REDO'); diff --git a/dash/dash-renderer/src/components/error/FrontEnd/FrontEndError.react.js b/dash/dash-renderer/src/components/error/FrontEnd/FrontEndError.react.js index 176cb2c6f8..db4c6ddd2b 100644 --- a/dash/dash-renderer/src/components/error/FrontEnd/FrontEndError.react.js +++ b/dash/dash-renderer/src/components/error/FrontEnd/FrontEndError.react.js @@ -121,13 +121,18 @@ function BackendError({error, base}) { const MAX_MESSAGE_LENGTH = 40; /* eslint-disable no-inline-comments */ function UnconnectedErrorContent({error, base}) { + // Helper to detect full HTML document + const isFullHtmlDoc = + typeof error.html === 'string' && + error.html.trim().toLowerCase().startsWith(' - {/* - * 40 is a rough heuristic - if longer than 40 then the - * message might overflow into ellipses in the title above & - * will need to be displayed in full in this error body - */} + {/* Frontend error message */} {typeof error.message !== 'string' || error.message.length < MAX_MESSAGE_LENGTH ? null : (
@@ -137,6 +142,7 @@ function UnconnectedErrorContent({error, base}) {
)} + {/* Frontend stack trace */} {typeof error.stack !== 'string' ? null : (
@@ -149,7 +155,6 @@ function UnconnectedErrorContent({error, base}) { browser's console.) - {error.stack.split('\n').map((line, i) => (

{line}

))} @@ -157,24 +162,30 @@ function UnconnectedErrorContent({error, base}) {
)} - {/* Backend Error */} - {typeof error.html !== 'string' ? null : error.html - .substring(0, '
- {/* Embed werkzeug debugger in an iframe to prevent - CSS leaking - werkzeug HTML includes a bunch - of CSS on base html elements like `` - */}
- ) : ( + ) : isHtmlFragment ? ( + // Backend error: HTML fragment +
+
+
+ ) : typeof error.html === 'string' ? ( + // Backend error: plain text
-
{error.html}
+
+
{error.html}
+
- )} + ) : null}
); } diff --git a/dash/dash-renderer/src/config.ts b/dash/dash-renderer/src/config.ts index d7f16beda8..6eb9e27b58 100644 --- a/dash/dash-renderer/src/config.ts +++ b/dash/dash-renderer/src/config.ts @@ -22,6 +22,12 @@ export type DashConfig = { serve_locally?: boolean; plotlyjs_url?: string; validate_callbacks: boolean; + websocket?: { + enabled: boolean; + url: string; + worker_url: string; + inactivity_timeout?: number; + }; csrf_token_name?: string; csrf_header_name?: string; }; diff --git a/dash/dash-renderer/src/observers/isLoading.ts b/dash/dash-renderer/src/observers/isLoading.ts index 687f607378..cc3bf193b8 100644 --- a/dash/dash-renderer/src/observers/isLoading.ts +++ b/dash/dash-renderer/src/observers/isLoading.ts @@ -9,7 +9,12 @@ const observer: IStoreObserverDefinition = { const pendingCallbacks = getPendingCallbacks(callbacks); - const next = Boolean(pendingCallbacks.length); + // Filter out persistent callbacks - they shouldn't trigger the loading indicator + const nonPersistentCallbacks = pendingCallbacks.filter( + cb => !cb.callback.persistent + ); + + const next = Boolean(nonPersistentCallbacks.length); if (isLoading !== next) { dispatch(setIsLoading(next)); diff --git a/dash/dash-renderer/src/observers/websocketObserver.ts b/dash/dash-renderer/src/observers/websocketObserver.ts new file mode 100644 index 0000000000..7b75fada38 --- /dev/null +++ b/dash/dash-renderer/src/observers/websocketObserver.ts @@ -0,0 +1,215 @@ +/** + * Observer for handling incoming WebSocket messages (SET_PROPS, GET_PROPS_REQUEST). + */ + +/* eslint-disable no-console */ + +import {Store} from 'redux'; +import {path} from 'ramda'; + +import {IStoreState} from '../store'; +import {updateProps, notifyObservers, setPaths} from '../actions'; +import {parsePatchProps} from '../actions/patch'; +import {computePaths, getPath} from '../actions/paths'; +import { + getWorkerClient, + SetPropsPayload, + GetPropsRequestPayload +} from '../utils/workerClient'; +import {DashConfig} from '../config'; + +/** + * Parse a component ID that may be a stringified JSON object. + * This handles dict IDs like '{"index":0,"type":"output"}' that need + * to be parsed back to objects for getPath to work correctly. + */ +function parseComponentId( + componentId: string +): string | Record { + if (componentId.startsWith('{') && componentId.endsWith('}')) { + try { + return JSON.parse(componentId); + } catch { + // Not valid JSON, return as-is + return componentId; + } + } + return componentId; +} + +/** + * Initialize the WebSocket observer. + * + * Sets up handlers for: + * - SET_PROPS: Update component props when received from server + * - GET_PROPS_REQUEST: Send current prop values back to server + * + * @param store Redux store + * @param config Dash configuration + */ +export async function initializeWebSocket( + store: Store, + config: DashConfig +): Promise { + // Initialize WebSocket if: + // 1. Global websocket is enabled, OR + // 2. WebSocket config is available (for per-callback websocket=True) + const wsAvailable = !!( + config.websocket?.url && config.websocket?.worker_url + ); + if (!wsAvailable) { + return; + } + + // Check if SharedWorker is supported + if (typeof SharedWorker === 'undefined') { + console.warn( + 'SharedWorker not supported in this browser. ' + + 'WebSocket callbacks will fall back to HTTP.' + ); + return; + } + + const workerClient = getWorkerClient(); + + // Handle SET_PROPS messages + workerClient.onSetProps = (payload: SetPropsPayload) => { + const {componentId, props: rawProps} = payload; + const parsedId = parseComponentId(componentId); + const state = store.getState(); + const componentPath = getPath(state.paths, parsedId); + + if (!componentPath) { + console.warn( + `SET_PROPS: Component ${componentId} not found in layout` + ); + return; + } + + // Get old component for Patch processing and path recomputation + const oldComponent = path(componentPath, state.layout) as Record< + string, + unknown + > | null; + const oldProps = (oldComponent?.props || {}) as Record; + + // Process props to handle Patch objects + const processedProps = parsePatchProps(rawProps, oldProps); + + // Update the component props + store.dispatch( + updateProps({ + props: processedProps, + itempath: componentPath, + renderType: 'websocket' + }) as any + ); + + // Notify observers + store.dispatch( + notifyObservers({id: parsedId, props: processedProps}) as any + ); + + // Recompute paths for any new child components + if (oldComponent) { + const updatedState = store.getState(); + store.dispatch( + setPaths( + computePaths( + { + ...oldComponent, + props: {...oldProps, ...processedProps} + }, + [...componentPath], + updatedState.paths, + updatedState.paths.events + ) + ) as any + ); + } + }; + + // Handle GET_PROPS_REQUEST messages + workerClient.onGetPropsRequest = ( + requestId: string, + payload: GetPropsRequestPayload + ) => { + const {componentId, properties} = payload; + const parsedId = parseComponentId(componentId); + const state = store.getState(); + const componentPath = getPath(state.paths, parsedId); + + const result: Record = {}; + + if (componentPath) { + const componentProps = path( + [...componentPath, 'props'], + state.layout + ) as Record | undefined; + + if (componentProps) { + for (const propName of properties) { + result[propName] = componentProps[propName]; + } + } + } else { + console.warn( + `GET_PROPS_REQUEST: Component ${componentId} not found in layout` + ); + } + + // Send the response + workerClient.sendGetPropsResponse(requestId, result); + }; + + // Handle connection events + workerClient.onConnected = () => { + console.log('[Dash] WebSocket connected'); + }; + + workerClient.onDisconnected = (reason?: string) => { + console.log(`[Dash] WebSocket disconnected: ${reason}`); + }; + + workerClient.onError = (message: string, code?: string) => { + console.error(`[Dash] WebSocket error: ${message}`, code); + }; + + // Connect to the worker + const wsUrl = buildWebSocketUrl(config); + + try { + // config.websocket is guaranteed to exist due to wsAvailable check above + await workerClient.connect( + config.websocket!.worker_url, + wsUrl, + config.websocket!.inactivity_timeout + ); + } catch (error) { + console.error('[Dash] Failed to connect to WebSocket worker:', error); + } +} + +/** + * Build the WebSocket URL from config. + */ +function buildWebSocketUrl(config: DashConfig): string { + if (!config.websocket?.url) { + throw new Error('WebSocket URL not configured'); + } + + // Convert HTTP(S) URL to WS(S) + const wsProtocol = window.location.protocol === 'https:' ? 'wss:' : 'ws:'; + const host = window.location.host; + + // The config.websocket.url is a path like "/_dash-ws-callback" + return `${wsProtocol}//${host}${config.websocket.url}`; +} + +/** + * Disconnect from the WebSocket. + */ +export function disconnectWebSocket(): void { + const workerClient = getWorkerClient(); + workerClient.disconnect(); +} diff --git a/dash/dash-renderer/src/types/callbacks.ts b/dash/dash-renderer/src/types/callbacks.ts index f1e1dc382c..5f963463d2 100644 --- a/dash/dash-renderer/src/types/callbacks.ts +++ b/dash/dash-renderer/src/types/callbacks.ts @@ -15,6 +15,8 @@ export interface ICallbackDefinition { dynamic_creator?: boolean; running: any; no_output?: boolean; + websocket?: boolean; + persistent?: boolean; } export interface ICallbackProperty { diff --git a/dash/dash-renderer/src/utils/rendererId.ts b/dash/dash-renderer/src/utils/rendererId.ts new file mode 100644 index 0000000000..b9bfcfd3af --- /dev/null +++ b/dash/dash-renderer/src/utils/rendererId.ts @@ -0,0 +1,22 @@ +/** Cached renderer ID for this page instance */ +let cachedRendererId: string | null = null; + +/** + * Generate a unique renderer ID for this page instance. + * + * Each page load gets a fresh ID to avoid conflicts with stale + * connections in the SharedWorker after page reloads. + */ +export function getRendererId(): string { + if (!cachedRendererId) { + if (typeof crypto !== 'undefined' && crypto.randomUUID) { + cachedRendererId = crypto.randomUUID(); + } else { + // Fallback for older browsers + cachedRendererId = `${Date.now()}-${Math.random() + .toString(36) + .slice(2)}`; + } + } + return cachedRendererId; +} diff --git a/dash/dash-renderer/src/utils/workerClient.ts b/dash/dash-renderer/src/utils/workerClient.ts new file mode 100644 index 0000000000..c16594f8ef --- /dev/null +++ b/dash/dash-renderer/src/utils/workerClient.ts @@ -0,0 +1,350 @@ +/** + * Client for communicating with the Dash WebSocket SharedWorker. + */ + +import {getRendererId} from './rendererId'; + +/** Message types for worker communication */ +export enum WorkerMessageType { + CONNECT = 'connect', + DISCONNECT = 'disconnect', + CALLBACK_REQUEST = 'callback_request', + GET_PROPS_RESPONSE = 'get_props_response', + CONNECTED = 'connected', + DISCONNECTED = 'disconnected', + CALLBACK_RESPONSE = 'callback_response', + SET_PROPS = 'set_props', + GET_PROPS_REQUEST = 'get_props_request', + ERROR = 'error' +} + +/** Callback response structure */ +export interface CallbackResponse { + status: 'ok' | 'prevent_update' | 'error'; + data?: Record; + message?: string; +} + +/** Set props message payload */ +export interface SetPropsPayload { + componentId: string; + props: Record; +} + +/** Get props request payload */ +export interface GetPropsRequestPayload { + componentId: string; + properties: string[]; +} + +/** Pending callback request */ +interface PendingRequest { + resolve: (value: CallbackResponse) => void; + reject: (error: Error) => void; +} + +/** + * Client for the Dash WebSocket SharedWorker. + */ +class WorkerClient { + private worker: SharedWorker | null = null; + private rendererId: string; + private pendingCallbacks: Map = new Map(); + private requestCounter = 0; + private isConnected = false; + private connectionPromise: Promise | null = null; + private connectionResolve: (() => void) | null = null; + + /** Callback when SET_PROPS message is received */ + public onSetProps: ((payload: SetPropsPayload) => void) | null = null; + + /** Callback when GET_PROPS_REQUEST message is received */ + public onGetPropsRequest: + | ((requestId: string, payload: GetPropsRequestPayload) => void) + | null = null; + + /** Callback when connection is established */ + public onConnected: (() => void) | null = null; + + /** Callback when connection is lost */ + public onDisconnected: ((reason?: string) => void) | null = null; + + /** Callback when an error occurs */ + public onError: ((message: string, code?: string) => void) | null = null; + + constructor() { + this.rendererId = getRendererId(); + } + + /** + * Initialize the worker connection. + * @param workerUrl URL to the SharedWorker script + * @param serverUrl WebSocket server URL + * @param inactivityTimeout Optional inactivity timeout in ms + */ + public async connect( + workerUrl: string, + serverUrl: string, + inactivityTimeout?: number + ): Promise { + if (this.worker) { + // Already connected + return; + } + + // Create the SharedWorker + this.worker = new SharedWorker(workerUrl, { + name: 'dash-ws-worker' + }); + + // Set up message handling + this.worker.port.onmessage = this.handleMessage.bind(this); + + // Create promise for connection + this.connectionPromise = new Promise(resolve => { + this.connectionResolve = resolve; + }); + + // Start the port + this.worker.port.start(); + + // Send connect message + this.worker.port.postMessage({ + type: WorkerMessageType.CONNECT, + rendererId: this.rendererId, + payload: { + serverUrl, + inactivityTimeout + } + }); + + // Wait for connection + await this.connectionPromise; + } + + /** + * Disconnect from the worker. + */ + public disconnect(): void { + if (this.worker) { + this.worker.port.postMessage({ + type: WorkerMessageType.DISCONNECT, + rendererId: this.rendererId + }); + this.worker.port.close(); + this.worker = null; + } + this.isConnected = false; + this.connectionPromise = null; + this.connectionResolve = null; + + // Resolve pending callbacks with prevent_update so loading states clear + for (const [, pending] of this.pendingCallbacks) { + pending.resolve({status: 'prevent_update'}); + } + this.pendingCallbacks.clear(); + } + + /** + * Ensure the worker is connected, initiating connection if needed. + * @param config The Dash config with websocket settings + */ + public async ensureConnected(config: { + websocket?: { + url?: string; + worker_url?: string; + inactivity_timeout?: number; + }; + }): Promise { + // Already connected + if (this.isConnected) { + return; + } + + // Connection in progress, wait for it + if (this.connectionPromise) { + await this.connectionPromise; + return; + } + + // Need to initiate connection + if (!config.websocket?.url || !config.websocket?.worker_url) { + throw new Error('WebSocket config not available'); + } + + if (typeof SharedWorker === 'undefined') { + throw new Error('SharedWorker not supported'); + } + + // Build WebSocket URL + const wsProtocol = + window.location.protocol === 'https:' ? 'wss:' : 'ws:'; + const host = window.location.host; + const wsUrl = `${wsProtocol}//${host}${config.websocket.url}`; + + await this.connect( + config.websocket.worker_url, + wsUrl, + config.websocket.inactivity_timeout + ); + } + + /** + * Send a callback request to the server via the worker. + * @param payload The callback payload + * @returns Promise that resolves with the callback response + */ + public async sendCallback(payload: unknown): Promise { + // Wait for initial connection if one is in progress + if (this.connectionPromise && !this.isConnected) { + await this.connectionPromise; + } + + if (!this.worker) { + throw new Error('Worker not connected'); + } + + const requestId = `${this.rendererId}-${++this.requestCounter}`; + + return new Promise((resolve, reject) => { + this.pendingCallbacks.set(requestId, {resolve, reject}); + + this.worker!.port.postMessage({ + type: WorkerMessageType.CALLBACK_REQUEST, + rendererId: this.rendererId, + requestId, + payload + }); + }); + } + + /** + * Send a get_props response back to the server. + * @param requestId The request ID from the get_props request + * @param props The property values + */ + public sendGetPropsResponse( + requestId: string, + props: Record + ): void { + if (!this.worker || !this.isConnected) { + return; + } + + this.worker.port.postMessage({ + type: WorkerMessageType.GET_PROPS_RESPONSE, + rendererId: this.rendererId, + requestId, + payload: props + }); + } + + /** + * Check if the worker is connected. + */ + public get connected(): boolean { + return this.isConnected; + } + + private handleMessage(event: MessageEvent): void { + const message = event.data; + + switch (message.type) { + case WorkerMessageType.CONNECTED: + this.isConnected = true; + if (this.connectionResolve) { + this.connectionResolve(); + this.connectionResolve = null; + } + if (this.onConnected) { + this.onConnected(); + } + break; + + case WorkerMessageType.DISCONNECTED: + this.isConnected = false; + // Resolve pending callbacks with prevent_update so loading states clear + for (const [, pending] of this.pendingCallbacks) { + pending.resolve({status: 'prevent_update'}); + } + this.pendingCallbacks.clear(); + if (this.onDisconnected) { + this.onDisconnected(message.payload?.reason); + } + break; + + case WorkerMessageType.CALLBACK_RESPONSE: { + const requestId = message.requestId; + const pending = this.pendingCallbacks.get(requestId); + if (pending) { + this.pendingCallbacks.delete(requestId); + pending.resolve(message.payload); + } + break; + } + + case WorkerMessageType.SET_PROPS: + if (this.onSetProps) { + this.onSetProps(message.payload); + } + break; + + case WorkerMessageType.GET_PROPS_REQUEST: + if (this.onGetPropsRequest) { + this.onGetPropsRequest(message.requestId, message.payload); + } + break; + + case WorkerMessageType.ERROR: + if (this.onError) { + this.onError( + message.payload?.message || 'Unknown error', + message.payload?.code + ); + } + break; + } + } +} + +// Singleton instance +let workerClientInstance: WorkerClient | null = null; + +/** + * Get the singleton WorkerClient instance. + */ +export function getWorkerClient(): WorkerClient { + if (!workerClientInstance) { + workerClientInstance = new WorkerClient(); + } + return workerClientInstance; +} + +/** + * Check if WebSocket callbacks are globally enabled and supported. + * @param config The Dash config + */ +export function isWebSocketEnabled(config: { + websocket?: {enabled: boolean}; +}): boolean { + return !!(config.websocket?.enabled && typeof SharedWorker !== 'undefined'); +} + +/** + * Check if WebSocket infrastructure is available (for per-callback websocket). + * @param config The Dash config + */ +export function isWebSocketAvailable(config: { + websocket?: { + enabled?: boolean; + url?: string; + worker_url?: string; + inactivity_timeout?: number; + }; +}): boolean { + return !!( + config.websocket?.url && + config.websocket?.worker_url && + typeof SharedWorker !== 'undefined' + ); +} diff --git a/dash/dash-renderer/webpack.base.config.js b/dash/dash-renderer/webpack.base.config.js index ed95239f7d..e8a9d14596 100644 --- a/dash/dash-renderer/webpack.base.config.js +++ b/dash/dash-renderer/webpack.base.config.js @@ -72,6 +72,31 @@ const rendererOptions = { ...defaults }; +// WebSocket Worker configuration +const workerOptions = { + mode: 'production', + entry: { + 'dash-ws-worker': '../../@plotly/dash-websocket-worker/src/worker.ts', + }, + output: { + path: path.resolve(__dirname, "build"), + filename: '[name].js', + }, + target: 'webworker', + module: { + rules: [ + { + test: /\.ts$/, + exclude: /node_modules/, + use: ['ts-loader'], + }, + ] + }, + resolve: { + extensions: ['.ts', '.js'] + } +}; + module.exports = options => [ R.mergeAll([ options, @@ -109,5 +134,7 @@ module.exports = options => [ ] ), } - ]) + ]), + // WebSocket Worker build + workerOptions ]; diff --git a/dash/dash.py b/dash/dash.py index 91c02e4867..f0821abef2 100644 --- a/dash/dash.py +++ b/dash/dash.py @@ -5,7 +5,6 @@ import inspect import importlib import warnings -from contextvars import copy_context from importlib.machinery import ModuleSpec from importlib.util import find_spec from importlib import metadata @@ -13,24 +12,18 @@ import threading import re import logging -import time import mimetypes import hashlib import base64 -import traceback from urllib.parse import urlparse from typing import Any, Callable, Dict, Optional, Union, Sequence, Literal, List -import asyncio -import flask - -from importlib_metadata import version as _get_distribution_version +import traceback from dash import dcc from dash import html from dash import dash_table - -from .fingerprint import build_fingerprint, check_fingerprint +from .fingerprint import build_fingerprint from .resources import Scripts, Css from .dependencies import ( Input, @@ -39,11 +32,10 @@ ) from .development.base_component import ComponentRegistry from .exceptions import ( - PreventUpdate, - InvalidResourceError, ProxyError, DuplicateCallback, ) +from .backends import get_backend from .version import __version__ from ._configs import get_combined_config, pathname_configs, pages_folder_config from ._utils import ( @@ -59,8 +51,8 @@ convert_to_AttributeDict, gen_salt, hooks_to_js_object, - parse_version, get_caller_name, + get_root_path, ) from . import _callback from . import _get_paths @@ -68,10 +60,12 @@ from . import _validate from . import _watch from . import _get_app +from . import backends -from ._get_app import with_app_context, with_app_context_async, with_app_context_factory +from ._get_app import with_app_context, with_app_context_factory from ._grouping import map_grouping, grouping_len, update_args_group from ._obsolete import ObsoleteChecker +from ._callback_context import callback_context from . import _pages from ._pages import ( @@ -243,12 +237,17 @@ class Dash(ObsoleteChecker): best value to use. Default ``'__main__'``, env: ``DASH_APP_NAME`` :type name: string - :param server: Sets the Flask server for your app. There are three options: - ``True`` (default): Dash will create a new server + :param server: Sets the server for your app. There are three options: + ``True`` (default): Dash will create a new server using the specified backend ``False``: The server will be added later via ``app.init_app(server)`` - where ``server`` is a ``flask.Flask`` instance. - ``flask.Flask``: use this pre-existing Flask server. - :type server: boolean or flask.Flask + A server instance: Use a pre-existing server (Flask, Quart, or FastAPI) + :type server: boolean or server instance + + :param backend: The backend to use for the Dash app. Can be a string + (name of the backend) or a backend class. Default is None, which + selects the Flask backend. Currently, "flask", "fastapi", and "quart" backends + are supported. + :type backend: string or type :param assets_folder: a path, relative to the current working directory, for extra files to be used in the browser. Default ``'assets'``. @@ -435,16 +434,17 @@ class Dash(ObsoleteChecker): _plotlyjs_url: str STARTUP_ROUTES: list = [] - server: flask.Flask + server: Any # Layout is a complex type which can be many things _layout: Any _extra_components: Any - def __init__( # pylint: disable=too-many-statements + def __init__( # pylint: disable=too-many-statements, too-many-branches self, name: Optional[str] = None, - server: Union[bool, flask.Flask] = True, + server: Union[bool, Callable[[], Any]] = True, + backend: Union[str, type, None] = None, assets_folder: str = "assets", pages_folder: str = "pages", use_pages: Optional[bool] = None, @@ -483,24 +483,13 @@ def __init__( # pylint: disable=too-many-statements health_endpoint: Optional[str] = None, csrf_token_name: str = "_csrf_token", csrf_header_name: str = "X-CSRFToken", + websocket_callbacks: Optional[bool] = False, + websocket_allowed_origins: Optional[List[str]] = None, + websocket_inactivity_timeout: Optional[int] = 300000, **obsolete, ): - if use_async is None: - try: - import asgiref # type: ignore[import-not-found] # pylint: disable=unused-import, import-outside-toplevel # noqa - - use_async = True - except ImportError: - pass - elif use_async: - try: - import asgiref # type: ignore[import-not-found] # pylint: disable=unused-import, import-outside-toplevel # noqa - except ImportError as exc: - raise Exception( - "You are trying to use dash[async] without having installed the requirements please install via: `pip install dash[async]`" - ) from exc - + use_async = _validate.check_async(use_async) _validate.check_obsolete(obsolete) if not csrf_token_name or not csrf_token_name.strip(): @@ -510,16 +499,31 @@ def __init__( # pylint: disable=too-many-statements caller_name: str = name if name is not None else get_caller_name() - # We have 3 cases: server is either True (we create the server), False - # (defer server creation) or a Flask app instance (we use their server) - if isinstance(server, flask.Flask): - self.server = server + # Determine backend + if backend is None: + backend_cls = get_backend("flask") + elif isinstance(backend, str): + backend_cls = get_backend(backend) + elif isinstance(backend, type): + backend_cls = backend + else: + raise ValueError("Invalid backend argument") + + # Determine server and backend instance + if server not in (None, True, False): + # User provided a server instance (e.g., Flask, Quart, FastAPI) + inferred_backend = backends.get_server_type(server) + _validate.check_backend(backend, inferred_backend) + backend_cls = get_backend(inferred_backend) if name is None: caller_name = getattr(server, "name", caller_name) - elif isinstance(server, bool): - self.server = flask.Flask(caller_name) if server else None # type: ignore + + self.backend = backend_cls(server) + self.server = server else: - raise ValueError("server must be a Flask app or a boolean") + # No server instance provided, create backend and let backend create server + self.server = backend_cls.create_app(caller_name) # type: ignore + self.backend = backend_cls(self.server) base_prefix, routes_prefix, requests_prefix = pathname_configs( url_base_pathname, routes_pathname_prefix, requests_pathname_prefix @@ -528,7 +532,7 @@ def __init__( # pylint: disable=too-many-statements self.config = AttributeDict( name=caller_name, assets_folder=os.path.join( - flask.helpers.get_root_path(caller_name), assets_folder + get_root_path(caller_name), assets_folder ), # type: ignore assets_url_path=assets_url_path, assets_ignore=assets_ignore, @@ -638,6 +642,9 @@ def __init__( # pylint: disable=too-many-statements self._assets_files: list = [] self._background_manager = background_callback_manager + self._websocket_callbacks = websocket_callbacks + self._websocket_allowed_origins = websocket_allowed_origins or [] + self._websocket_inactivity_timeout = websocket_inactivity_timeout self.logger = logging.getLogger(__name__) @@ -655,7 +662,7 @@ def __init__( # pylint: disable=too-many-statements # tracks internally if a function already handled at least one request. self._got_first_request = {"pages": False, "setup_server": False} - if self.server is not None: + if server: self.init_app() self.logger.setLevel(logging.INFO) @@ -702,11 +709,15 @@ def _setup_hooks(self): if self._hooks.get_hooks("error"): self._on_error = self._hooks.HookErrorHandler(self._on_error) - def init_app(self, app: Optional[flask.Flask] = None, **kwargs) -> None: - """Initialize the parts of Dash that require a flask app.""" - + def init_app(self, app: Optional[Any] = None, **kwargs) -> None: config = self.config - + config.unset_read_only( + [ + "url_base_pathname", + "routes_pathname_prefix", + "requests_pathname_prefix", + ] + ) config.update(kwargs) config.set_read_only( [ @@ -716,91 +727,75 @@ def init_app(self, app: Optional[flask.Flask] = None, **kwargs) -> None: ], "Read-only: can only be set in the Dash constructor or during init_app()", ) - if app is not None: self.server = app - bp_prefix = config.routes_pathname_prefix.replace("/", "_").replace(".", "_") assets_blueprint_name = f"{bp_prefix}dash_assets" - - self.server.register_blueprint( - flask.Blueprint( - assets_blueprint_name, - config.name, - static_folder=self.config.assets_folder, - static_url_path=config.routes_pathname_prefix - + self.config.assets_url_path.lstrip("/"), - ) + self.backend.register_assets_blueprint( + assets_blueprint_name, + config.routes_pathname_prefix + self.config.assets_url_path.lstrip("/"), + self.config.assets_folder, ) - if config.compress: - try: - # pylint: disable=import-outside-toplevel - from flask_compress import Compress # type: ignore - - # gzip - Compress(self.server) + self.backend.enable_compression() # type: ignore - _flask_compress_version = parse_version( - _get_distribution_version("flask_compress") - ) - - if not hasattr( - self.server.config, "COMPRESS_ALGORITHM" - ) and _flask_compress_version >= parse_version("1.6.0"): - # flask-compress==1.6.0 changed default to ['br', 'gzip'] - # and non-overridable default compression with Brotli is - # causing performance issues - self.server.config["COMPRESS_ALGORITHM"] = ["gzip"] - except ImportError as error: - raise ImportError( - "To use the compress option, you need to install dash[compress]" - ) from error - - @self.server.errorhandler(PreventUpdate) def _handle_error(_): """Handle a halted callback and return an empty 204 response.""" return "", 204 - self.server.before_request(self._setup_server) - + # To-Do add error handlers for these two scenarios + # add handler for halted callbacks + # self.backend.before_request(_handle_error) # add a handler for components suites errors to return 404 - self.server.errorhandler(InvalidResourceError)(self._invalid_resources_handler) + # self.server.errorhandler(InvalidResourceError)(self._invalid_resources_handler) + self.backend.register_error_handlers() + self.backend.before_request(self._setup_server) + self.backend.setup_backend(self) self._setup_routes() - _get_app.APP = self self.enable_pages() - self._setup_plotlyjs() def _add_url(self, name: str, view_func: RouteCallable, methods=("GET",)) -> None: full_name = self.config.routes_pathname_prefix + name - - self.server.add_url_rule( - full_name, view_func=view_func, endpoint=full_name, methods=list(methods) + self.backend.add_url_rule( + full_name, + view_func=view_func, + endpoint=full_name, + methods=list(methods), ) - - # record the url in Dash.routes so that it can be accessed later - # e.g. for adding authentication with flask_login self.routes.append(full_name) - def _setup_routes(self): - self._add_url( - "_dash-component-suites//", - self.serve_component_suites, + def _serve_default_favicon(self): + return self.backend.make_response( + pkgutil.get_data("dash", "favicon.ico"), content_type="image/x-icon" ) + + def _setup_routes(self): + self.backend.setup_component_suites(self) self._add_url("_dash-layout", self.serve_layout) self._add_url("_dash-dependencies", self.dependencies) - if self._use_async: - self._add_url("_dash-update-component", self.async_dispatch, ["POST"]) - else: - self._add_url("_dash-update-component", self.dispatch, ["POST"]) + self._add_url( + "_dash-update-component", + self.backend.serve_callback(self), + ["POST"], + ) self._add_url("_reload-hash", self.serve_reload_hash) - self._add_url("_favicon.ico", self._serve_default_favicon) + self._add_url( + "_favicon.ico", + self._serve_default_favicon, # pylint: disable=protected-access + ) if self.config.health_endpoint is not None: self._add_url(self.config.health_endpoint, self.serve_health) - self._add_url("", self.index) + + # Set up WebSocket callback route if backend supports it + # This enables both global websocket_callbacks and per-callback websocket=True + if self.backend.websocket_capability: + self.backend.serve_websocket_callback(self) + + self.backend.setup_index(self) + self.backend.setup_catchall(self) if jupyter_dash.active: self._add_url( @@ -814,9 +809,6 @@ def _setup_routes(self): hook.data["methods"], ) - # catch-all for front-end routes, used by dcc.Location - self._add_url("", self.index) - def setup_apis(self): """ Register API endpoints for all callbacks defined using `dash.callback`. @@ -840,30 +832,8 @@ def setup_apis(self): ) self.callback_api_paths[k] = _callback.GLOBAL_API_PATHS.pop(k) - def make_parse_body(func): - def _parse_body(): - if flask.request.is_json: - data = flask.request.get_json() - return flask.jsonify(func(**data)) - return flask.jsonify({}) - - return _parse_body - - def make_parse_body_async(func): - async def _parse_body_async(): - if flask.request.is_json: - data = flask.request.get_json() - result = await func(**data) - return flask.jsonify(result) - return flask.jsonify({}) - - return _parse_body_async - - for path, func in self.callback_api_paths.items(): - if inspect.iscoroutinefunction(func): - self._add_url(path, make_parse_body_async(func), ["POST"]) - else: - self._add_url(path, make_parse_body(func), ["POST"]) + # Delegate to the server factory for route registration + self.backend.register_callback_api_routes(self.callback_api_paths) def _setup_plotlyjs(self): # pylint: disable=import-outside-toplevel @@ -935,7 +905,7 @@ def serve_layout(self): layout = hook(layout) # TODO - Set browser cache limit - pass hash into frontend - return flask.Response( + return self.backend.make_response( to_json(layout), mimetype="application/json", ) @@ -1005,6 +975,16 @@ def _config(self): custom_dev_tools.append({**hook_dev_tools, "props": props}) config["dev_tools"] = custom_dev_tools + # Add websocket config if backend supports it + # This enables both global websocket_callbacks and per-callback websocket=True + if self.backend.websocket_capability: + config["websocket"] = { + "enabled": bool(self._websocket_callbacks), + "url": self.config.requests_pathname_prefix + "_dash-ws-callback", + "worker_url": self._get_worker_url(), + "inactivity_timeout": self._websocket_inactivity_timeout, + } + return config def serve_reload_hash(self): @@ -1016,7 +996,7 @@ def serve_reload_hash(self): _reload.hard = False _reload.changed_assets = [] - return flask.jsonify( + return self.backend.jsonify( { "reloadHash": _hash, "hard": hard, @@ -1030,7 +1010,34 @@ def serve_health(self): Health check endpoint for monitoring Dash server status. Returns a simple "OK" response with HTTP 200 status. """ - return flask.Response("OK", status=200, mimetype="text/plain") + return self.backend.make_response("OK", status=200, mimetype="text/plain") + + def _get_worker_url(self) -> str: + """Get the URL for the WebSocket worker script. + + Returns: + The fingerprinted URL for the worker script served via component suites. + """ + relative_path = "dash-renderer/build/dash-ws-worker.js" + namespace = "dash" + + # Register the path so it can be served + self.registered_paths[namespace].add(relative_path) + + # Build fingerprinted URL (same pattern as _collect_and_register_resources) + module_path = os.path.join( + os.path.dirname(sys.modules[namespace].__file__), # type: ignore + relative_path, + ) + + # Use a fallback if the file doesn't exist yet (during development) + try: + modified = int(os.stat(module_path).st_mtime) + except FileNotFoundError: + modified = 0 + + fingerprint = build_fingerprint(relative_path, __version__, modified) + return f"{self.config.requests_pathname_prefix}_dash-component-suites/{namespace}/{fingerprint}" def get_dist(self, libraries: Sequence[str]) -> list: dists = [] @@ -1233,58 +1240,18 @@ def _generate_meta(self): return meta_tags + self.config.meta_tags - # Serve the JS bundles for each package - def serve_component_suites(self, package_name, fingerprinted_path): - path_in_pkg, has_fingerprint = check_fingerprint(fingerprinted_path) - - _validate.validate_js_path(self.registered_paths, package_name, path_in_pkg) - - extension = "." + path_in_pkg.split(".")[-1] - mimetype = mimetypes.types_map.get(extension, "application/octet-stream") - - package = sys.modules[package_name] - self.logger.debug( - "serving -- package: %s[%s] resource: %s => location: %s", - package_name, - package.__version__, - path_in_pkg, - package.__path__, - ) - - response = flask.Response( - pkgutil.get_data(package_name, path_in_pkg), mimetype=mimetype - ) - - if has_fingerprint: - # Fingerprinted resources are good forever (1 year) - # No need for ETag as the fingerprint changes with each build - response.cache_control.max_age = 31536000 # 1 year - else: - # Non-fingerprinted resources are given an ETag that - # will be used / check on future requests - response.add_etag() - tag = response.get_etag()[0] - - request_etag = flask.request.headers.get("If-None-Match") - - if f'"{tag}"' == request_etag: - response = flask.Response(None, status=304) - - return response - - @with_app_context - def index(self, *args, **kwargs): # pylint: disable=unused-argument + def index(self, *_args, **_kwargs): scripts = self._generate_scripts_html() css = self._generate_css_dist_html() config = self._generate_config_html() metas = self._generate_meta() renderer = self._generate_renderer() - - # use self.title instead of app.config.title for backwards compatibility title = self.title + # Refactored: direct access to global request adapter + request = self.backend.request_adapter() - if self.use_pages and self.config.include_pages_meta: - metas = _page_meta_tags(self) + metas + if self.use_pages and self.config.include_pages_meta and request: + metas = _page_meta_tags(self, request) + metas if self._favicon: favicon_mod_time = os.path.getmtime( @@ -1388,7 +1355,7 @@ def interpolate_index(self, **kwargs): @with_app_context def dependencies(self): - return flask.Response( + return self.backend.make_response( to_json(self._callback_list), content_type="application/json", ) @@ -1491,9 +1458,13 @@ def callback(self, *_args, **_kwargs) -> Callable[..., Any]: **_kwargs, ) + def _inputs_to_vals(self, inputs): + return inputs_to_vals(inputs) + # pylint: disable=R0915 def _initialize_context(self, body): """Initialize the global context for the request.""" + adapter = self.backend.request_adapter() g = AttributeDict({}) g.inputs_list = body.get("inputs", []) g.states_list = body.get("state", []) @@ -1504,12 +1475,13 @@ def _initialize_context(self, body): {"prop_id": x, "value": g.input_values.get(x)} for x in body.get("changedPropIds", []) ] - g.dash_response = flask.Response(mimetype="application/json") - g.cookies = dict(**flask.request.cookies) - g.headers = dict(**flask.request.headers) - g.path = flask.request.full_path - g.remote = flask.request.remote_addr - g.origin = flask.request.origin + g.dash_response = self.backend.response_adapter() + g.cookies = dict(adapter.cookies) + g.headers = dict(adapter.headers) + g.args = adapter.args + g.path = adapter.full_path + g.remote = adapter.remote_addr + g.origin = adapter.origin g.updated_props = {} return g @@ -1573,11 +1545,6 @@ def _prepare_grouping(self, data_list, indices): def _execute_callback(self, func, args, outputs_list, g): """Execute the callback with the prepared arguments.""" - g.cookies = dict(**flask.request.cookies) - g.headers = dict(**flask.request.headers) - g.path = flask.request.full_path - g.remote = flask.request.remote_addr - g.origin = flask.request.origin g.custom_data = AttributeDict({}) for hook in self._hooks.get_hooks("custom_data"): @@ -1596,47 +1563,6 @@ def _execute_callback(self, func, args, outputs_list, g): ) return partial_func - @with_app_context_async - async def async_dispatch(self): - body = flask.request.get_json() - g = self._initialize_context(body) - func = self._prepare_callback(g, body) - args = inputs_to_vals(g.inputs_list + g.states_list) - - ctx = copy_context() - partial_func = self._execute_callback(func, args, g.outputs_list, g) - if asyncio.iscoroutine(func): - response_data = await ctx.run(partial_func) - else: - response_data = ctx.run(partial_func) - - if asyncio.iscoroutine(response_data): - response_data = await response_data - - g.dash_response.set_data(response_data) - return g.dash_response - - @with_app_context - def dispatch(self): - body = flask.request.get_json() - g = self._initialize_context(body) - func = self._prepare_callback(g, body) - args = inputs_to_vals(g.inputs_list + g.states_list) - - ctx = copy_context() - partial_func = self._execute_callback(func, args, g.outputs_list, g) - response_data = ctx.run(partial_func) - - if asyncio.iscoroutine(response_data): - raise Exception( - "You are trying to use a coroutine without dash[async]. " - "Please install the dependencies via `pip install dash[async]` and ensure " - "that `use_async=False` is not being passed to the app." - ) - - g.dash_response.set_data(response_data) - return g.dash_response - def _setup_server(self): if self._got_first_request["setup_server"]: return @@ -1712,7 +1638,7 @@ def _setup_server(self): manager=manager, ) def cancel_call(*_): - job_ids = flask.request.args.getlist("cancelJob") + job_ids = callback_context.args.getlist("cancelJob") executor = _callback.context_value.get().background_callback_manager if job_ids: for job_id in job_ids: @@ -1781,12 +1707,6 @@ def _walk_assets_directory(self): def _invalid_resources_handler(err): return err.args[0], 404 - @staticmethod - def _serve_default_favicon(): - return flask.Response( - pkgutil.get_data("dash", "favicon.ico"), content_type="image/x-icon" - ) - def csp_hashes(self, hash_algorithm="sha256") -> Sequence[str]: """Calculates CSP hashes (sha + base64) of all inline scripts, such that one of the biggest benefits of CSP (disallowing general inline scripts) @@ -2035,6 +1955,7 @@ def enable_dev_tools( # pylint: disable=too-many-branches dev_tools_disable_version_check: Optional[bool] = None, dev_tools_prune_errors: Optional[bool] = None, dev_tools_validate_callbacks: Optional[bool] = None, + first_run: bool = True, ) -> bool: """Activate the dev tools, called by `run`. If your application is served by wsgi and you want to activate the dev tools, you can call @@ -2095,9 +2016,10 @@ def enable_dev_tools( # pylint: disable=too-many-branches env: ``DASH_HOT_RELOAD_MAX_RETRY`` :type dev_tools_hot_reload_max_retry: int - :param dev_tools_silence_routes_logging: Silence the `werkzeug` logger, - will remove all routes logging. Enabled with debugging by default - because hot reload hash checks generate a lot of requests. + :param dev_tools_silence_routes_logging: Silence the route logging for the + web server (werkzeug for Flask, hypercorn for Quart, uvicorn for FastAPI). + Enabled with debugging by default because hot reload hash checks generate + a lot of requests. env: ``DASH_SILENCE_ROUTES_LOGGING`` :type dev_tools_silence_routes_logging: bool @@ -2137,7 +2059,18 @@ def enable_dev_tools( # pylint: disable=too-many-branches ) if dev_tools.silence_routes_logging: - logging.getLogger("werkzeug").setLevel(logging.ERROR) + # Silence route logging based on backend type + backend_type = getattr(self.backend, "server_type", "flask") + if backend_type == "flask": + logging.getLogger("werkzeug").setLevel(logging.ERROR) + elif backend_type == "quart": + # Quart uses hypercorn as its ASGI server + logging.getLogger("hypercorn.access").setLevel(logging.ERROR) + logging.getLogger("hypercorn.error").setLevel(logging.ERROR) + elif backend_type == "fastapi": + # FastAPI uses uvicorn as its ASGI server + logging.getLogger("uvicorn.access").setLevel(logging.ERROR) + logging.getLogger("uvicorn.error").setLevel(logging.ERROR) if dev_tools.hot_reload: _reload = self._hot_reload @@ -2225,49 +2158,11 @@ def enable_dev_tools( # pylint: disable=too-many-branches jupyter_dash.configure_callback_exception_handling( self, dev_tools.prune_errors ) - elif dev_tools.prune_errors: - secret = gen_salt(20) - - @self.server.errorhandler(Exception) - def _wrap_errors(error): - # find the callback invocation, if the error is from a callback - # and skip the traceback up to that point - # if the error didn't come from inside a callback, we won't - # skip anything. - tb = _get_traceback(secret, error) - return tb, 500 + secret = gen_salt(20) + self.backend.register_prune_error_handler(secret, dev_tools.prune_errors) if debug and dev_tools.ui: - - def _before_request(): - flask.g.timing_information = { # pylint: disable=assigning-non-slot - "__dash_server": {"dur": time.time(), "desc": None} - } - - def _after_request(response): - timing_information = flask.g.get("timing_information", None) - if timing_information is None: - return response - - dash_total = timing_information.get("__dash_server", None) - if dash_total is not None: - dash_total["dur"] = round((time.time() - dash_total["dur"]) * 1000) - - for name, info in timing_information.items(): - value = name - if info.get("desc") is not None: - value += f';desc="{info["desc"]}"' - - if info.get("dur") is not None: - value += f";dur={info['dur']}" - - response.headers.add("Server-Timing", value) - - return response - - self.server.before_request(_before_request) - - self.server.after_request(_after_request) + self.backend.register_timing_hooks(first_run) if ( debug @@ -2431,9 +2326,10 @@ def run( env: ``DASH_HOT_RELOAD_MAX_RETRY`` :type dev_tools_hot_reload_max_retry: int - :param dev_tools_silence_routes_logging: Silence the `werkzeug` logger, - will remove all routes logging. Enabled with debugging by default - because hot reload hash checks generate a lot of requests. + :param dev_tools_silence_routes_logging: Silence the route logging for the + web server (werkzeug for Flask, hypercorn for Quart, uvicorn for FastAPI). + Enabled with debugging by default because hot reload hash checks generate + a lot of requests. env: ``DASH_SILENCE_ROUTES_LOGGING`` :type dev_tools_silence_routes_logging: bool @@ -2501,6 +2397,7 @@ def run( host = host or "127.0.0.1" else: host = host or os.getenv("HOST", "127.0.0.1") + assert host port = port or os.getenv("PORT", "8050") proxy = proxy or os.getenv("DASH_PROXY") @@ -2570,7 +2467,9 @@ def verify_url_part(served_part, url_part, part_name): server_url=jupyter_server_url, ) else: - self.server.run(host=host, port=port, debug=debug, **flask_run_options) + self.backend.run( + dash_app=self, host=host, port=port, debug=debug, **flask_run_options + ) def enable_pages(self) -> None: if not self.use_pages: @@ -2578,8 +2477,8 @@ def enable_pages(self) -> None: if self.pages_folder: _import_layouts_from_pages(self.config.pages_folder) - @self.server.before_request - def router(): + # Async version + async def router_async(): if self._got_first_request["pages"]: return self._got_first_request["pages"] = True @@ -2588,163 +2487,152 @@ def router(): "pathname_": Input(_ID_LOCATION, "pathname"), "search_": Input(_ID_LOCATION, "search"), } - inputs.update(self.routing_callback_inputs) # type: ignore[reportCallIssue] - - if self._use_async: + inputs.update(self.routing_callback_inputs) - @self.callback( - Output(_ID_CONTENT, "children"), - Output(_ID_STORE, "data"), - inputs=inputs, - prevent_initial_call=True, - hidden=True, + @self.callback( + Output(_ID_CONTENT, "children"), + Output(_ID_STORE, "data"), + inputs=inputs, + prevent_initial_call=True, + hidden=True, + ) + async def update(pathname_, search_, **states): + query_parameters = _parse_query_string(search_) + page, path_variables = _path_to_page( + self.strip_relative_path(pathname_) ) - async def update(pathname_, search_, **states): - """ - Updates dash.page_container layout on page navigation. - Updates the stored page title which will trigger the clientside callback to update the app title - """ - - query_parameters = _parse_query_string(search_) - page, path_variables = _path_to_page( - self.strip_relative_path(pathname_) + if page == {}: + for module, page in _pages.PAGE_REGISTRY.items(): + if module.split(".")[-1] == "not_found_404": + layout = page["layout"] + title = page["title"] + break + else: + layout = html.H1("404 - Page not found") + title = self.title + else: + layout = page.get("layout", "") + title = page["title"] + + if callable(layout): + layout = await execute_async_function( + layout, + **{**(path_variables or {}), **query_parameters, **states}, + ) + if callable(title): + title = await execute_async_function( + title, **{**(path_variables or {})} ) + return layout, {"title": title} - # get layout - if page == {}: - for module, page in _pages.PAGE_REGISTRY.items(): - if module.split(".")[-1] == "not_found_404": - layout = page["layout"] - title = page["title"] - break - else: - layout = html.H1("404 - Page not found") - title = self.title - else: - layout = page.get("layout", "") - title = page["title"] + _validate.check_for_duplicate_pathnames(_pages.PAGE_REGISTRY) + _validate.validate_registry(_pages.PAGE_REGISTRY) - if callable(layout): - layout = await execute_async_function( - layout, - **{**(path_variables or {}), **query_parameters, **states}, - ) - if callable(title): - title = await execute_async_function( - title, **(path_variables or {}) - ) + if not self.config.suppress_callback_exceptions: - return layout, {"title": title} - - _validate.check_for_duplicate_pathnames(_pages.PAGE_REGISTRY) - _validate.validate_registry(_pages.PAGE_REGISTRY) - - # Set validation_layout - if not self.config.suppress_callback_exceptions: - self.validation_layout = html.Div( - [ - ( - asyncio.run(execute_async_function(page["layout"])) - if callable(page["layout"]) - else page["layout"] - ) - for page in _pages.PAGE_REGISTRY.values() - ] - + [ - # pylint: disable=not-callable - self.layout() - if callable(self.layout) - else self.layout - ] - ) - if _ID_CONTENT not in self.validation_layout: - raise Exception("`dash.page_container` not found in the layout") - else: + async def get_layouts(): + return [ + await execute_async_function(page["layout"]) + if callable(page["layout"]) + else page["layout"] + for page in _pages.PAGE_REGISTRY.values() + ] - @self.callback( - Output(_ID_CONTENT, "children"), - Output(_ID_STORE, "data"), - inputs=inputs, - prevent_initial_call=True, - hidden=True, - ) - def update(pathname_, search_, **states): - """ - Updates dash.page_container layout on page navigation. - Updates the stored page title which will trigger the clientside callback to update the app title - """ - - query_parameters = _parse_query_string(search_) - page, path_variables = _path_to_page( - self.strip_relative_path(pathname_) - ) + layouts = await get_layouts() + # pylint: disable=not-callable + layouts += [self.layout() if callable(self.layout) else self.layout] + self.validation_layout = html.Div(layouts) + if _ID_CONTENT not in self.validation_layout: + raise Exception("`dash.page_container` not found in the layout") - # get layout - if page == {}: - for module, page in _pages.PAGE_REGISTRY.items(): - if module.split(".")[-1] == "not_found_404": - layout = page["layout"] - title = page["title"] - break - else: - layout = html.H1("404 - Page not found") - title = self.title + self.clientside_callback( + """ + function(data) { + document.title = data.title + } + """, + Output(_ID_DUMMY, "children"), + Input(_ID_STORE, "data"), + hidden=True, + ) + + # Sync version + def router_sync(): + if self._got_first_request["pages"]: + return + self._got_first_request["pages"] = True + + inputs = { + "pathname_": Input(_ID_LOCATION, "pathname"), + "search_": Input(_ID_LOCATION, "search"), + } + inputs.update(self.routing_callback_inputs) + + @self.callback( + Output(_ID_CONTENT, "children"), + Output(_ID_STORE, "data"), + inputs=inputs, + prevent_initial_call=True, + hidden=True, + ) + def update(pathname_, search_, **states): + query_parameters = _parse_query_string(search_) + page, path_variables = _path_to_page( + self.strip_relative_path(pathname_) + ) + if page == {}: + for module, page in _pages.PAGE_REGISTRY.items(): + if module.split(".")[-1] == "not_found_404": + layout = page["layout"] + title = page["title"] + break else: - layout = page.get("layout", "") - title = page["title"] + layout = html.H1("404 - Page not found") + title = self.title + else: + layout = page.get("layout", "") + title = page["title"] - if callable(layout): - layout = layout( - **{**(path_variables or {}), **query_parameters, **states} - ) - if callable(title): - title = title(**(path_variables or {})) - - return layout, {"title": title} - - _validate.check_for_duplicate_pathnames(_pages.PAGE_REGISTRY) - _validate.validate_registry(_pages.PAGE_REGISTRY) - - # Set validation_layout - if not self.config.suppress_callback_exceptions: - layout = self.layout - if not isinstance(layout, list): - layout = [ - # pylint: disable=not-callable - self.layout() - if callable(self.layout) - else self.layout - ] - self.validation_layout = html.Div( - [ - ( - page["layout"]() - if callable(page["layout"]) - else page["layout"] - ) - for page in _pages.PAGE_REGISTRY.values() - ] - + layout - ) - if _ID_CONTENT not in self.validation_layout: - raise Exception("`dash.page_container` not found in the layout") + if callable(layout): + layout = layout( + **{**(path_variables or {}), **query_parameters, **states} + ) + if callable(title): + title = title(**(path_variables or {})) + return layout, {"title": title} + + _validate.check_for_duplicate_pathnames(_pages.PAGE_REGISTRY) + _validate.validate_registry(_pages.PAGE_REGISTRY) + + if not self.config.suppress_callback_exceptions: + layout = self.layout + if not isinstance(layout, list): + # pylint: disable=not-callable + layout = [self.layout() if callable(self.layout) else self.layout] + self.validation_layout = html.Div( + [ + page["layout"]() if callable(page["layout"]) else page["layout"] + for page in _pages.PAGE_REGISTRY.values() + ] + + layout + ) + if _ID_CONTENT not in self.validation_layout: + raise Exception("`dash.page_container` not found in the layout") - # Update the page title on page navigation self.clientside_callback( """ - function(data) {{ + function(data) { document.title = data.title - }} + } """, Output(_ID_DUMMY, "children"), Input(_ID_STORE, "data"), - hidden=True, ) - def __call__(self, environ, start_response): - """ - This method makes instances of Dash WSGI-compliant callables. - It delegates the actual WSGI handling to the internal Flask app's - __call__ method. - """ - return self.server(environ, start_response) + if self._use_async: + self.backend.before_request(router_async) + else: + self.backend.before_request(router_sync) + + def __call__(self, *args, **kwargs): + return self.backend.__call__(*args, **kwargs) diff --git a/dash/exceptions.py b/dash/exceptions.py index 00bd2c1553..9366f9359c 100644 --- a/dash/exceptions.py +++ b/dash/exceptions.py @@ -109,3 +109,15 @@ class ImportedInsideCallbackError(DashException): class HookError(DashException): pass + + +class AppNotFoundError(DashException): + pass + + +class WebSocketCallbackError(CallbackException): + pass + + +class WebsocketDisconnected(CallbackException): + pass diff --git a/dash/testing/application_runners.py b/dash/testing/application_runners.py index dc88afe844..51a938f72f 100644 --- a/dash/testing/application_runners.py +++ b/dash/testing/application_runners.py @@ -147,6 +147,7 @@ class ThreadedRunner(BaseDashRunner): def __init__(self, keep_open=False, stop_timeout=3): super().__init__(keep_open=keep_open, stop_timeout=stop_timeout) self.thread = None + self._app = None # Store app reference for graceful shutdown def running_and_accessible(self, url): if self.thread.is_alive(): # type: ignore[reportOptionalMemberAccess] @@ -156,6 +157,7 @@ def running_and_accessible(self, url): # pylint: disable=arguments-differ def start(self, app, start_timeout=3, **kwargs): """Start the app server in threading flavor.""" + self._app = app # Store app reference for graceful shutdown def run(): app.scripts.config.serve_locally = True @@ -171,7 +173,16 @@ def run(): self.port = options["port"] try: - app.run(threaded=True, **options) + module = app.server.__class__.__module__ + # FastAPI support + if module.startswith("fastapi"): + app.run(**options) + # Quart support (ASGI - runs its own async event loop) + elif module.startswith("quart"): + app.run(**options) + # Flask fallback (WSGI - needs threaded mode) + else: + app.run(threaded=True, **options) except SystemExit: logger.info("Server stopped") except Exception as error: @@ -207,9 +218,17 @@ def run(): raise DashAppLoadingError("threaded server failed to start") def stop(self): - self.thread.kill() # type: ignore[reportOptionalMemberAccess] - self.thread.join() # type: ignore[reportOptionalMemberAccess] - wait.until_not(self.thread.is_alive, self.stop_timeout) # type: ignore[reportOptionalMemberAccess] + # For FastAPI apps with uvicorn, use graceful shutdown + if self._app and hasattr(self._app, "_uvicorn_server"): + server = self._app._uvicorn_server # pylint: disable=protected-access + server.should_exit = True + self.thread.join(timeout=self.stop_timeout) # type: ignore[reportOptionalMemberAccess] + else: + # Fall back to killing threads for Flask/other backends + self.thread.kill() # type: ignore[reportOptionalMemberAccess] + self.thread.join() # type: ignore[reportOptionalMemberAccess] + wait.until_not(self.thread.is_alive, self.stop_timeout) # type: ignore[reportOptionalMemberAccess] + self._app = None self.started = False @@ -229,7 +248,16 @@ def target(): options = kwargs.copy() try: - app.run(threaded=True, **options) + module = app.server.__class__.__module__ + # FastAPI support + if module.startswith("fastapi"): + app.run(**options) + # Quart support (ASGI - runs its own async event loop) + elif module.startswith("quart"): + app.run(**options) + # Flask fallback (WSGI - needs threaded mode) + else: + app.run(threaded=True, **options) except SystemExit: logger.info("Server stopped") raise diff --git a/dash/version.py b/dash/version.py index 7039708762..d39d52c3b8 100644 --- a/dash/version.py +++ b/dash/version.py @@ -1 +1 @@ -__version__ = "4.1.0" +__version__ = "4.2.0rc3" diff --git a/package.json b/package.json index 2a53cdc585..f44a0f2805 100644 --- a/package.json +++ b/package.json @@ -44,7 +44,7 @@ "setup-tests.R": "run-s private::test.R.deploy-*", "citest.integration": "run-s setup-tests.py private::test.integration-*", "citest.unit": "run-s private::test.unit-**", - "test": "pytest && cd dash/dash-renderer && npm run test", + "test": "pytest --ignore=tests/backend_tests && cd dash/dash-renderer && npm run test", "first-build": "cd dash/dash-renderer && npm i && cd ../../ && cd components/dash-html-components && npm i && npm run extract && cd ../../ && npm run build" }, "devDependencies": { diff --git a/requirements/fastapi.txt b/requirements/fastapi.txt new file mode 100644 index 0000000000..364e2ee48e --- /dev/null +++ b/requirements/fastapi.txt @@ -0,0 +1,2 @@ +fastapi +uvicorn[standard] diff --git a/requirements/install.txt b/requirements/install.txt index df0e1299e3..284f3a5031 100644 --- a/requirements/install.txt +++ b/requirements/install.txt @@ -7,3 +7,4 @@ requests retrying nest-asyncio setuptools +janus>=1.0.0 diff --git a/requirements/quart.txt b/requirements/quart.txt new file mode 100644 index 0000000000..60af440c9c --- /dev/null +++ b/requirements/quart.txt @@ -0,0 +1 @@ +quart diff --git a/setup.py b/setup.py index f87ef21d70..e2ecb16055 100644 --- a/setup.py +++ b/setup.py @@ -36,6 +36,8 @@ def read_req_file(req_type): "celery": read_req_file("celery"), "diskcache": read_req_file("diskcache"), "compress": read_req_file("compress"), + "fastapi": read_req_file("fastapi"), + "quart": read_req_file("quart"), "cloud": read_req_file("cloud"), "ag-grid": read_req_file("ag-grid") }, diff --git a/tests/backend_tests/__init__.py b/tests/backend_tests/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/backend_tests/test_custom_backend.py b/tests/backend_tests/test_custom_backend.py new file mode 100644 index 0000000000..befff7734b --- /dev/null +++ b/tests/backend_tests/test_custom_backend.py @@ -0,0 +1,246 @@ +import pytest +from dash import Dash, Input, Output, html, dcc +import traceback +import re + +try: + from dash.backends._fastapi import FastAPIDashServer +except ImportError: + FastAPIDashServer = None + + +class CustomDashServer(FastAPIDashServer): + def _get_traceback(self, _secret, error: Exception): + tb = error.__traceback__ + errors = traceback.format_exception(type(error), error, tb) + pass_errs = [] + callback_handled = False + for err in errors: + if self.error_handling_mode == "prune": + if not callback_handled: + if "callback invoked" in str(err) and "_callback.py" in str(err): + callback_handled = True + continue + pass_errs.append(err) + formatted_tb = "".join(pass_errs) + error_type = type(error).__name__ + error_msg = str(error) + # Parse traceback lines to group by file + file_cards = [] + pattern = re.compile(r' File "(.+)", line (\d+), in (\w+)') + lines = formatted_tb.split("\n") + current_file = None + card_lines = [] + for line in lines[:-1]: # Skip the last line (error message) + match = pattern.match(line) + if match: + if current_file and card_lines: + file_cards.append((current_file, card_lines)) + current_file = ( + f"{match.group(1)} (line {match.group(2)}, in {match.group(3)})" + ) + card_lines = [line] + elif current_file: + card_lines.append(line) + if current_file and card_lines: + file_cards.append((current_file, card_lines)) + cards_html = "" + for filename, card in file_cards: + cards_html += ( + f""" +
+
{filename}
+
"""
+                + "\n".join(card)
+                + """
+
+ """ + ) + html = f""" + + + + {error_type}: {error_msg} // Custom Debugger + + + +
+

{error_type}: {error_msg}

+ {cards_html} +
+ + + """ + return html + + +@pytest.mark.parametrize( + "fixture,input_value", + [ + ("dash_duo", "Hello CustomBackend!"), + ], +) +def test_custom_backend_basic_callback(request, fixture, input_value): + dash_duo = request.getfixturevalue(fixture) + app = Dash(__name__, backend=CustomDashServer) + app.layout = html.Div( + [dcc.Input(id="input", value=input_value, type="text"), html.Div(id="output")] + ) + + @app.callback(Output("output", "children"), Input("input", "value")) + def update_output(value): + return f"You typed: {value}" + + dash_duo.start_server(app) + dash_duo.wait_for_text_to_equal("#output", f"You typed: {input_value}") + dash_duo.clear_input(dash_duo.find_element("#input")) + dash_duo.find_element("#input").send_keys("CustomBackend Test") + dash_duo.wait_for_text_to_equal("#output", "You typed: CustomBackend Test") + assert dash_duo.get_logs() == [] + + +@pytest.mark.parametrize( + "fixture,start_server_kwargs", + [ + ("dash_duo", {"debug": True, "reload": False, "dev_tools_ui": True}), + ], +) +def test_custom_backend_error_handling(request, fixture, start_server_kwargs): + dash_duo = request.getfixturevalue(fixture) + app = Dash(__name__, backend=CustomDashServer) + app.layout = html.Div( + [html.Button(id="btn", children="Error", n_clicks=0), html.Div(id="output")] + ) + + @app.callback(Output("output", "children"), Input("btn", "n_clicks")) + def error_callback(n): + if n and n > 0: + return 1 / 0 # Intentional error + return "No error" + + dash_duo.start_server(app, **start_server_kwargs) + dash_duo.wait_for_text_to_equal("#output", "No error") + dash_duo.find_element("#btn").click() + dash_duo.wait_for_text_to_equal(dash_duo.devtools_error_count_locator, "1") + + +def get_error_html(dash_duo, index): + # error is in an iframe so is annoying to read out - get it from the store + return dash_duo.driver.execute_script( + "return store.getState().error.backEnd[{}].error.html;".format(index) + ) + + +@pytest.mark.parametrize( + "fixture,start_server_kwargs", + [ + ( + "dash_duo", + { + "debug": True, + "dev_tools_ui": True, + "dev_tools_prune_errors": False, + "reload": False, + }, + ), + ], +) +def test_custom_backend_error_handling_no_prune(request, fixture, start_server_kwargs): + dash_duo = request.getfixturevalue(fixture) + app = Dash(__name__, backend=CustomDashServer) + app.layout = html.Div( + [html.Button(id="btn", children="Error", n_clicks=0), html.Div(id="output")] + ) + + @app.callback(Output("output", "children"), Input("btn", "n_clicks")) + def error_callback(n): + if n and n > 0: + return 1 / 0 # Intentional error + return "No error" + + dash_duo.start_server(app, **start_server_kwargs) + dash_duo.wait_for_text_to_equal("#output", "No error") + dash_duo.find_element("#btn").click() + dash_duo.wait_for_text_to_equal(dash_duo.devtools_error_count_locator, "1") + + error0 = get_error_html(dash_duo, 0) + assert "Custom Debugger" in error0 + assert "in error_callback" in error0 + assert "ZeroDivisionError" in error0 + assert "_callback.py" in error0 + + +@pytest.mark.parametrize( + "fixture,start_server_kwargs, error_msg", + [ + ("dash_duo", {"debug": True, "reload": False}, "custombackend.py"), + ], +) +def test_custom_backend_error_handling_prune( + request, fixture, start_server_kwargs, error_msg +): + dash_duo = request.getfixturevalue(fixture) + app = Dash(__name__, backend=CustomDashServer) + app.layout = html.Div( + [html.Button(id="btn", children="Error", n_clicks=0), html.Div(id="output")] + ) + + @app.callback(Output("output", "children"), Input("btn", "n_clicks")) + def error_callback(n): + if n and n > 0: + return 1 / 0 # Intentional error + return "No error" + + dash_duo.start_server(app, **start_server_kwargs) + dash_duo.wait_for_text_to_equal("#output", "No error") + dash_duo.find_element("#btn").click() + dash_duo.wait_for_text_to_equal(dash_duo.devtools_error_count_locator, "1") + + error0 = get_error_html(dash_duo, 0) + assert "Custom Debugger" in error0 + assert "in error_callback" in error0 + assert "ZeroDivisionError" in error0 + assert "_callback.py" not in error0 + + +@pytest.mark.parametrize( + "fixture,input_value", + [ + ("dash_duo", "Background CustomBackend!"), + ], +) +def test_custom_backend_background_callback(request, fixture, input_value): + dash_duo = request.getfixturevalue(fixture) + import diskcache + + cache = diskcache.Cache("./cache") + from dash.background_callback import DiskcacheManager + + background_callback_manager = DiskcacheManager(cache) + + app = Dash( + __name__, + backend=CustomDashServer, + background_callback_manager=background_callback_manager, + ) + app.layout = html.Div( + [dcc.Input(id="input", value=input_value, type="text"), html.Div(id="output")] + ) + + @app.callback( + Output("output", "children"), Input("input", "value"), background=True + ) + def update_output_bg(value): + return f"Background typed: {value}" + + dash_duo.start_server(app) + dash_duo.wait_for_text_to_equal("#output", f"Background typed: {input_value}") + dash_duo.clear_input(dash_duo.find_element("#input")) + dash_duo.find_element("#input").send_keys("CustomBackend BG Test") + dash_duo.wait_for_text_to_equal( + "#output", "Background typed: CustomBackend BG Test" + ) + assert dash_duo.get_logs() == [] diff --git a/tests/backend_tests/test_preconfig_backends.py b/tests/backend_tests/test_preconfig_backends.py new file mode 100644 index 0000000000..eec832070a --- /dev/null +++ b/tests/backend_tests/test_preconfig_backends.py @@ -0,0 +1,293 @@ +import logging +import pytest +from dash import Dash, Input, Output, html, dcc, ctx + + +@pytest.mark.parametrize( + "backend,fixture", + [ + ("flask", "dash_duo"), + ("fastapi", "dash_duo"), + ("quart", "dash_duo_mp"), + ], +) +def test_set_cookie_and_header(request, backend, fixture): + dash_duo = request.getfixturevalue(fixture) + app = Dash(__name__, backend=backend) + app.layout = html.Div([html.Button("Set", id="btn"), html.Div(id="output")]) + + @app.callback(Output("output", "children"), Input("btn", "n_clicks")) + def set_cookie_and_header(n): + if ctx.response: + ctx.response.set_cookie("mycookie", "cookieval") + ctx.response.set_header("X-My-Header", "HeaderVal") + ctx.response.append_header("X-My-Header", "HeaderVal2") + ctx.response.append_header("X-My-Header2", "HeaderVal3") + ctx.response.set_header("X-My-Header2", "HeaderVal4") + return f"Clicked {n}" if n else "Not clicked" + + dash_duo.start_server(app) + dash_duo.driver.execute_script( + """ + window._lastResponseHeaders = null; + const origFetch = window.fetch; + window.fetch = async function() { + const response = await origFetch.apply(this, arguments); + response.clone().headers.forEach((v, k) => { + if (!window._lastResponseHeaders) window._lastResponseHeaders = {}; + window._lastResponseHeaders[k] = v; + }); + return response; + }; + """ + ) + + dash_duo.find_element("#btn").click() + dash_duo.wait_for_text_to_equal("#output", "Clicked 1") + + # Check cookie + cookies = dash_duo.driver.get_cookies() + assert any(c["name"] == "mycookie" and c["value"] == "cookieval" for c in cookies) + + headers = dash_duo.driver.execute_script("return window._lastResponseHeaders;") + assert headers and headers["x-my-header"] == "HeaderVal, HeaderVal2" + assert headers and headers["x-my-header2"] == "HeaderVal4" + + +@pytest.mark.parametrize( + "backend,fixture,input_value", + [ + ("fastapi", "dash_duo", "Hello FastAPI!"), + ("quart", "dash_duo_mp", "Hello Quart!"), + ], +) +def test_backend_basic_callback(request, backend, fixture, input_value): + dash_duo = request.getfixturevalue(fixture) + if backend == "fastapi": + from fastapi import FastAPI + + server = FastAPI() + else: + import quart + + server = quart.Quart(__name__) + app = Dash(__name__, server=server) + app.layout = html.Div( + [dcc.Input(id="input", value=input_value, type="text"), html.Div(id="output")] + ) + + @app.callback(Output("output", "children"), Input("input", "value")) + def update_output(value): + return f"You typed: {value}" + + dash_duo.start_server(app) + dash_duo.wait_for_text_to_equal("#output", f"You typed: {input_value}") + dash_duo.clear_input(dash_duo.find_element("#input")) + dash_duo.find_element("#input").send_keys(f"{backend.title()} Test") + dash_duo.wait_for_text_to_equal("#output", f"You typed: {backend.title()} Test") + assert dash_duo.get_logs() == [] + + +@pytest.mark.parametrize( + "backend,fixture,start_server_kwargs", + [ + ( + "fastapi", + "dash_duo", + {"debug": True, "reload": False, "dev_tools_ui": True}, + ), + ( + "quart", + "dash_duo_mp", + { + "debug": True, + "use_reloader": False, + "dev_tools_hot_reload": False, + }, + ), + ], +) +def test_backend_error_handling(request, backend, fixture, start_server_kwargs): + dash_duo = request.getfixturevalue(fixture) + app = Dash(__name__, backend=backend) + app.layout = html.Div( + [html.Button(id="btn", children="Error", n_clicks=0), html.Div(id="output")] + ) + + @app.callback(Output("output", "children"), Input("btn", "n_clicks")) + def error_callback(n): + if n and n > 0: + return 1 / 0 # Intentional error + return "No error" + + dash_duo.start_server(app, **start_server_kwargs) + dash_duo.wait_for_text_to_equal("#output", "No error") + dash_duo.find_element("#btn").click() + dash_duo.wait_for_text_to_equal(dash_duo.devtools_error_count_locator, "1") + + +def get_error_html(dash_duo, index): + # error is in an iframe so is annoying to read out - get it from the store + return dash_duo.driver.execute_script( + "return store.getState().error.backEnd[{}].error.html;".format(index) + ) + + +@pytest.mark.parametrize( + "backend,fixture,start_server_kwargs, error_msg", + [ + ( + "fastapi", + "dash_duo", + { + "debug": True, + "dev_tools_ui": True, + "dev_tools_prune_errors": False, + "reload": False, + }, + "_fastapi.py", + ), + ( + "quart", + "dash_duo_mp", + { + "debug": True, + "use_reloader": False, + "dev_tools_hot_reload": False, + "dev_tools_prune_errors": False, + }, + "_quart.py", + ), + ], +) +def test_backend_error_handling_no_prune( + request, backend, fixture, start_server_kwargs, error_msg +): + dash_duo = request.getfixturevalue(fixture) + app = Dash(__name__, backend=backend) + app.layout = html.Div( + [html.Button(id="btn", children="Error", n_clicks=0), html.Div(id="output")] + ) + + @app.callback(Output("output", "children"), Input("btn", "n_clicks")) + def error_callback(n): + if n and n > 0: + return 1 / 0 # Intentional error + return "No error" + + dash_duo.start_server(app, **start_server_kwargs) + dash_duo.wait_for_text_to_equal("#output", "No error") + dash_duo.find_element("#btn").click() + dash_duo.wait_for_text_to_equal(dash_duo.devtools_error_count_locator, "1") + + error0 = get_error_html(dash_duo, 0) + assert "in error_callback" in error0 + assert "ZeroDivisionError" in error0 + assert "backends/" in error0 and error_msg in error0 + + +@pytest.mark.parametrize( + "backend,fixture,start_server_kwargs, error_msg", + [ + ("fastapi", "dash_duo", {"debug": True, "reload": False}, "fastapi.py"), + ( + "quart", + "dash_duo_mp", + { + "debug": True, + "use_reloader": False, + "dev_tools_hot_reload": False, + }, + "quart.py", + ), + ], +) +def test_backend_error_handling_prune( + request, backend, fixture, start_server_kwargs, error_msg +): + dash_duo = request.getfixturevalue(fixture) + app = Dash(__name__, backend=backend) + app.layout = html.Div( + [html.Button(id="btn", children="Error", n_clicks=0), html.Div(id="output")] + ) + + @app.callback(Output("output", "children"), Input("btn", "n_clicks")) + def error_callback(n): + if n and n > 0: + return 1 / 0 # Intentional error + return "No error" + + dash_duo.start_server(app, **start_server_kwargs) + dash_duo.wait_for_text_to_equal("#output", "No error") + dash_duo.find_element("#btn").click() + dash_duo.wait_for_text_to_equal(dash_duo.devtools_error_count_locator, "1") + + error0 = get_error_html(dash_duo, 0) + assert "in error_callback" in error0 + assert "ZeroDivisionError" in error0 + assert "dash/backends/" not in error0 and error_msg not in error0 + + +@pytest.mark.parametrize( + "backend,fixture,input_value", + [ + ("fastapi", "dash_duo", "Background FastAPI!"), + ("quart", "dash_duo_mp", "Background Quart!"), + ], +) +def test_backend_background_callback(request, backend, fixture, input_value): + dash_duo = request.getfixturevalue(fixture) + import diskcache + + cache = diskcache.Cache("./cache") + from dash.background_callback import DiskcacheManager + + background_callback_manager = DiskcacheManager(cache) + + app = Dash( + __name__, + backend=backend, + background_callback_manager=background_callback_manager, + ) + app.layout = html.Div( + [dcc.Input(id="input", value=input_value, type="text"), html.Div(id="output")] + ) + + @app.callback( + Output("output", "children"), Input("input", "value"), background=True + ) + def update_output_bg(value): + return f"Background typed: {value}" + + dash_duo.start_server(app) + dash_duo.wait_for_text_to_equal("#output", f"Background typed: {input_value}") + dash_duo.clear_input(dash_duo.find_element("#input")) + dash_duo.find_element("#input").send_keys(f"{backend.title()} BG Test") + dash_duo.wait_for_text_to_equal( + "#output", f"Background typed: {backend.title()} BG Test" + ) + assert dash_duo.get_logs() == [] + + +@pytest.mark.parametrize( + "backend,expected_loggers", + [ + ("flask", ["werkzeug"]), + ("quart", ["hypercorn.access", "hypercorn.error"]), + ("fastapi", ["uvicorn.access", "uvicorn.error"]), + ], +) +def test_silence_routes_logging(backend, expected_loggers): + """Test that route logging is silenced for all backends when dev_tools_silence_routes_logging is enabled.""" + app = Dash(__name__, backend=backend) + app.layout = html.Div([html.Div(id="output", children="Test")]) + + # Enable dev tools with silence_routes_logging + app.enable_dev_tools(debug=True, dev_tools_silence_routes_logging=True) + + # Check that the expected loggers have been set to ERROR level + for logger_name in expected_loggers: + logger = logging.getLogger(logger_name) + assert ( + logger.level == logging.ERROR + ), f"Logger {logger_name} should be set to ERROR level for {backend} backend" diff --git a/tests/integration/callbacks/test_basic_callback.py b/tests/integration/callbacks/test_basic_callback.py index 87ce3507e7..6e724c186f 100644 --- a/tests/integration/callbacks/test_basic_callback.py +++ b/tests/integration/callbacks/test_basic_callback.py @@ -917,3 +917,194 @@ def on_click(_): assert error.text == error_title for error_text in dash_duo.find_elements(".dash-backend-error"): assert all(line in error_text for line in error_message) + + +def test_cbsc022_no_output_callback_initial_call(dash_duo): + """Test that no-output callbacks fire on initial load.""" + + call_count = Value("i", 0) + + app = Dash(__name__) + app.layout = html.Div( + [ + html.Button("Click", id="btn", n_clicks=0), + html.Div(id="output"), + ] + ) + + @app.callback( + Input("btn", "n_clicks"), + ) + def no_output_callback(n_clicks): + call_count.value += 1 + + @app.callback( + Output("output", "children"), + Input("btn", "n_clicks"), + ) + def with_output_callback(n_clicks): + return f"Clicks: {n_clicks}" + + dash_duo.start_server(app) + + # Wait for initial render + dash_duo.wait_for_text_to_equal("#output", "Clicks: 0") + + # No-output callback should have fired on initial load + assert call_count.value == 1, "no-output callback should fire on initial load" + + # Click button + dash_duo.find_element("#btn").click() + dash_duo.wait_for_text_to_equal("#output", "Clicks: 1") + + # No-output callback should have fired again + assert call_count.value == 2, "no-output callback should fire on click" + + assert dash_duo.get_logs() == [] + + +def test_cbsc023_no_input_callback_initial_call(dash_duo): + """Test that no-input callbacks fire on initial load (issue #3411).""" + app = Dash(__name__) + app.layout = html.Div( + [ + dcc.Store(id="store", data="initial"), + html.Div(id="output"), + ] + ) + + @app.callback( + Output("output", "children"), + State("store", "data"), + ) + def no_input_callback(data): + return f"Data: {data}" + + dash_duo.start_server(app) + + # No-input callback should fire on initial load + dash_duo.wait_for_text_to_equal("#output", "Data: initial") + + assert dash_duo.get_logs() == [] + + +def test_cbsc024_no_input_no_output_callback_initial_call(dash_duo): + """Test that callbacks with no input and no output fire on initial load.""" + from multiprocessing import Value + + call_count = Value("i", 0) + + app = Dash(__name__) + app.layout = html.Div( + [ + html.Div(id="output", children="Waiting..."), + ] + ) + + @app.callback() + def no_input_no_output_callback(): + call_count.value += 1 + print(f"No-input no-output callback fired: {call_count.value}") + + dash_duo.start_server(app) + + # Give it time to fire + dash_duo.wait_for_element("#output") + time.sleep(0.5) + + # Callback should have fired on initial load + assert ( + call_count.value == 1 + ), "no-input no-output callback should fire on initial load" + + assert dash_duo.get_logs() == [] + + +def test_cbsc025_multiple_no_input_no_output_callbacks(dash_duo): + """Test that multiple no-input no-output callbacks all fire on initial load.""" + from multiprocessing import Value + + call_count_1 = Value("i", 0) + call_count_2 = Value("i", 0) + call_count_3 = Value("i", 0) + + app = Dash(__name__) + app.layout = html.Div( + [ + html.Div(id="output", children="Waiting..."), + ] + ) + + @app.callback() + def first_callback(): + call_count_1.value += 1 + + @app.callback() + def second_callback(): + call_count_2.value += 1 + + @app.callback() + def third_callback(): + call_count_3.value += 1 + + dash_duo.start_server(app) + + # Give callbacks time to fire + dash_duo.wait_for_element("#output") + time.sleep(0.5) + + # All callbacks should have fired on initial load + assert call_count_1.value == 1, "first callback should fire" + assert call_count_2.value == 1, "second callback should fire" + assert call_count_3.value == 1, "third callback should fire" + + assert dash_duo.get_logs() == [] + + +def test_cbsc026_no_input_with_duplicate_outputs(dash_duo): + """Test no-input callbacks with duplicate outputs.""" + from multiprocessing import Value + + call_count_1 = Value("i", 0) + call_count_2 = Value("i", 0) + + app = Dash(__name__) + app.layout = html.Div( + [ + dcc.Store(id="store", data="initial"), + html.Div(id="output", children="Waiting..."), + ] + ) + + @app.callback( + Output("output", "children"), + State("store", "data"), + ) + def first_no_input_callback(data): + call_count_1.value += 1 + return f"First: {data}" + + @app.callback( + Output("output", "children", allow_duplicate=True), + State("store", "data"), + prevent_initial_call="initial_duplicate", + ) + def second_no_input_callback(data): + call_count_2.value += 1 + return f"Second: {data}" + + dash_duo.start_server(app) + + # Give callbacks time to fire + dash_duo.wait_for_element("#output") + time.sleep(0.5) + + # Both callbacks should have fired on initial load + assert call_count_1.value == 1, "first no-input callback should fire" + assert call_count_2.value == 1, "second no-input callback should fire" + + # Output should contain result from one of the callbacks + output_text = dash_duo.find_element("#output").text + assert "initial" in output_text, "output should contain data from store" + + assert dash_duo.get_logs() == [] diff --git a/tests/integration/devtools/test_callback_validation.py b/tests/integration/devtools/test_callback_validation.py index eaee814980..8501821886 100644 --- a/tests/integration/devtools/test_callback_validation.py +++ b/tests/integration/devtools/test_callback_validation.py @@ -69,10 +69,26 @@ def check_errors(dash_duo, specs): def test_dvcv001_blank(dash_duo): + """No-input no-output callbacks are allowed when prevent_initial_call=False (default).""" app = Dash(__name__) app.layout = html.Div() - @app.callback([], []) + @app.callback() + def x(): + pass # No-output callbacks shouldn't return anything + + dash_duo.start_server(app, **debugging) + # No errors expected - no-input callbacks are allowed when prevent_initial_call=False + dash_duo.wait_for_element("div") + assert dash_duo.get_logs() == [] + + +def test_dvcv001b_blank_prevent_initial_call(dash_duo): + """No-input callbacks should error when prevent_initial_call=True.""" + app = Dash(__name__) + app.layout = html.Div() + + @app.callback([], [], prevent_initial_call=True) def x(): return 42 diff --git a/tests/integration/devtools/test_devtools_error_handling.py b/tests/integration/devtools/test_devtools_error_handling.py index 40d5731202..005bf8c335 100644 --- a/tests/integration/devtools/test_devtools_error_handling.py +++ b/tests/integration/devtools/test_devtools_error_handling.py @@ -109,14 +109,14 @@ def test_dveh006_long_python_errors(dash_duo): assert "in bad_sub" not in error0 # dash and flask part of the traceback ARE included # since we set dev_tools_prune_errors=False - assert "dash.py" in error0 + assert "backend" in error0 and "flask.py" in error0 assert "self.wsgi_app" in error0 error1 = get_error_html(dash_duo, 1) assert "in update_output" in error1 assert "in bad_sub" in error1 assert "ZeroDivisionError" in error1 - assert "dash.py" in error1 + assert "backend" in error1 and "flask.py" in error1 assert "self.wsgi_app" in error1 diff --git a/tests/integration/multi_page/test_pages_layout.py b/tests/integration/multi_page/test_pages_layout.py index 48751021b9..a209ae4517 100644 --- a/tests/integration/multi_page/test_pages_layout.py +++ b/tests/integration/multi_page/test_pages_layout.py @@ -3,6 +3,7 @@ from dash import Dash, Input, State, dcc, html, Output from dash.dash import _ID_LOCATION from dash.exceptions import NoLayoutException +from dash.testing.wait import until def get_app(path1="/", path2="/layout2"): @@ -57,7 +58,7 @@ def test_pala001_layout(dash_duo, clear_pages_state): for page in dash.page_registry.values(): dash_duo.find_element("#" + page["id"]).click() dash_duo.wait_for_text_to_equal("#text_" + page["id"], "text for " + page["id"]) - assert dash_duo.driver.title == page["title"], "check that page title updates" + until(lambda: dash_duo.driver.title == page["title"], timeout=3) # test redirects dash_duo.wait_for_page(url=f"{dash_duo.server_url}/v2") diff --git a/tests/integration/multi_page/test_pages_relative_path.py b/tests/integration/multi_page/test_pages_relative_path.py index 6c505ac3f5..24e7209a70 100644 --- a/tests/integration/multi_page/test_pages_relative_path.py +++ b/tests/integration/multi_page/test_pages_relative_path.py @@ -2,6 +2,7 @@ import dash from dash import Dash, dcc, html +from dash.testing.wait import until def get_app(app): @@ -70,7 +71,7 @@ def test_pare002_relative_path_with_url_base_pathname( for page in dash.page_registry.values(): dash_br.find_element("#" + page["id"]).click() dash_br.wait_for_text_to_equal("#text_" + page["id"], "text for " + page["id"]) - assert dash_br.driver.title == page["title"], "check that page title updates" + until(lambda: dash_br.driver.title == page["title"], timeout=3) assert dash_br.get_logs() == [], "browser console should contain no error" @@ -83,6 +84,6 @@ def test_pare003_absolute_path(dash_duo, clear_pages_state): for page in dash.page_registry.values(): dash_duo.find_element("#" + page["id"]).click() dash_duo.wait_for_text_to_equal("#text_" + page["id"], "text for " + page["id"]) - assert dash_duo.driver.title == page["title"], "check that page title updates" + until(lambda: dash_duo.driver.title == page["title"], timeout=3) assert dash_duo.get_logs() == [], "browser console should contain no error" diff --git a/tests/integration/renderer/test_loading_states.py b/tests/integration/renderer/test_loading_states.py index 169b505ed1..9818902f61 100644 --- a/tests/integration/renderer/test_loading_states.py +++ b/tests/integration/renderer/test_loading_states.py @@ -298,3 +298,61 @@ def update(n): dash_duo.wait_for_text_to_equal("#final-output", "1") until(lambda: dash_duo.driver.title == "Page 1", timeout=1) + + +def test_rdls005_persistent_callback_no_update_title(dash_duo): + """Test that persistent=True callbacks don't trigger the 'Updating...' title.""" + lock = Lock() + + app = Dash(__name__) + + app.layout = html.Div( + children=[ + html.H3("Test persistent callback"), + html.Button("Persistent", id="persistent-btn", n_clicks=0), + html.Button("Regular", id="regular-btn", n_clicks=0), + html.Div(id="persistent-output"), + html.Div(id="regular-output"), + ] + ) + + @app.callback( + Output("persistent-output", "children"), + Input("persistent-btn", "n_clicks"), + persistent=True, + ) + def persistent_update(n): + with lock: + return f"Persistent: {n}" + + @app.callback( + Output("regular-output", "children"), + Input("regular-btn", "n_clicks"), + ) + def regular_update(n): + with lock: + return f"Regular: {n}" + + dash_duo.start_server(app) + dash_duo.wait_for_text_to_equal("#persistent-output", "Persistent: 0") + dash_duo.wait_for_text_to_equal("#regular-output", "Regular: 0") + + # Verify title is "Dash" after initial load + until(lambda: dash_duo.driver.title == "Dash", timeout=1) + + # Test that persistent callback does NOT change title to "Updating..." + with lock: + dash_duo.find_element("#persistent-btn").click() + # Title should remain "Dash" even while callback is running + until(lambda: dash_duo.driver.title == "Dash", timeout=1) + + dash_duo.wait_for_text_to_equal("#persistent-output", "Persistent: 1") + + # Test that regular callback DOES change title to "Updating..." + with lock: + dash_duo.find_element("#regular-btn").click() + until(lambda: dash_duo.driver.title == "Updating...", timeout=1) + + dash_duo.wait_for_text_to_equal("#regular-output", "Regular: 1") + # Title should revert after callback completes + until(lambda: dash_duo.driver.title == "Dash", timeout=1) diff --git a/tests/integration/renderer/test_render_type.py b/tests/integration/renderer/test_render_type.py index 17a6cfbae3..417be6e586 100644 --- a/tests/integration/renderer/test_render_type.py +++ b/tests/integration/renderer/test_render_type.py @@ -25,6 +25,7 @@ def test_rtype001_rendertype(dash_duo): dash_clientside.set_props('render_test', {n_clicks: 20}) }""", Input("clientside_render", "n_clicks"), + prevent_initial_call=True, ) @app.callback( diff --git a/tests/websocket/__init__.py b/tests/websocket/__init__.py new file mode 100644 index 0000000000..1116026afc --- /dev/null +++ b/tests/websocket/__init__.py @@ -0,0 +1 @@ +# WebSocket callback tests diff --git a/tests/websocket/conftest.py b/tests/websocket/conftest.py new file mode 100644 index 0000000000..d72fcd04dc --- /dev/null +++ b/tests/websocket/conftest.py @@ -0,0 +1,12 @@ +import pytest +from dash import hooks + + +@pytest.fixture +def ws_hook_cleanup(): + """Clean up WebSocket hooks after each test.""" + yield + hooks._ns["websocket_connect"] = [] + hooks._ns["websocket_message"] = [] + hooks._finals.pop("websocket_connect", None) + hooks._finals.pop("websocket_message", None) diff --git a/tests/websocket/test_ws_basic.py b/tests/websocket/test_ws_basic.py new file mode 100644 index 0000000000..1d74706a68 --- /dev/null +++ b/tests/websocket/test_ws_basic.py @@ -0,0 +1,254 @@ +""" +Basic WebSocket callback tests. + +Tests: +- Per-callback websocket (websocket=True) +- Global websocket callbacks (websocket_callbacks=True) +- Mixed HTTP and WebSocket callbacks +""" + +from dash import Dash, html, dcc, Input, Output, State, ctx + + +def test_ws001_per_callback_websocket(dash_duo): + """Test single callback with websocket=True on FastAPI backend.""" + app = Dash(__name__, backend="fastapi") + + app.layout = html.Div( + [ + html.H1("Per-Callback WebSocket Test"), + dcc.Input(id="ws-input", type="text", placeholder="Type here..."), + html.Div(id="ws-output"), + ] + ) + + @app.callback( + Output("ws-output", "children"), Input("ws-input", "value"), websocket=True + ) + def ws_callback(value): + return f"WS: {value or ''}" + + dash_duo.start_server(app) + + # Test initial state (trailing space is trimmed by HTML rendering) + dash_duo.wait_for_text_to_equal("#ws-output", "WS:") + + # Type into the input and verify callback executes + input_elem = dash_duo.find_element("#ws-input") + input_elem.send_keys("hello") + + dash_duo.wait_for_text_to_equal("#ws-output", "WS: hello") + assert dash_duo.get_logs() == [] + + +def test_ws002_global_websocket_callbacks(dash_duo): + """Test global websocket_callbacks=True enables WebSocket for all callbacks.""" + app = Dash( + __name__, + backend="fastapi", + websocket_callbacks=True, + ) + + app.layout = html.Div( + [ + html.Button("Click me", id="btn", n_clicks=0), + html.Div(id="output"), + dcc.Input(id="input", type="text"), + html.Div(id="input-output"), + ] + ) + + @app.callback(Output("output", "children"), Input("btn", "n_clicks")) + def on_click(n_clicks): + return f"Clicked {n_clicks} times" + + @app.callback(Output("input-output", "children"), Input("input", "value")) + def on_input(value): + return f"Input: {value or ''}" + + dash_duo.start_server(app) + + # Test button callback + dash_duo.wait_for_text_to_equal("#output", "Clicked 0 times") + dash_duo.find_element("#btn").click() + dash_duo.wait_for_text_to_equal("#output", "Clicked 1 times") + + # Test input callback + dash_duo.find_element("#input").send_keys("test") + dash_duo.wait_for_text_to_equal("#input-output", "Input: test") + + assert dash_duo.get_logs() == [] + + +def test_ws003_mixed_http_and_websocket(dash_duo): + """Test mixing WebSocket and HTTP callbacks in the same app.""" + app = Dash(__name__, backend="fastapi") + + app.layout = html.Div( + [ + # WebSocket callback section + html.Div( + [ + dcc.Input(id="ws-input", type="text"), + html.Div(id="ws-output"), + ] + ), + # HTTP callback section (default) + html.Div( + [ + dcc.Input(id="http-input", type="text"), + html.Div(id="http-output"), + ] + ), + ] + ) + + @app.callback( + Output("ws-output", "children"), Input("ws-input", "value"), websocket=True + ) + def ws_callback(value): + return f"[WebSocket] {value or ''}" + + @app.callback(Output("http-output", "children"), Input("http-input", "value")) + def http_callback(value): + return f"[HTTP] {value or ''}" + + dash_duo.start_server(app) + + # Test WebSocket callback + dash_duo.find_element("#ws-input").send_keys("ws-test") + dash_duo.wait_for_text_to_equal("#ws-output", "[WebSocket] ws-test") + + # Test HTTP callback + dash_duo.find_element("#http-input").send_keys("http-test") + dash_duo.wait_for_text_to_equal("#http-output", "[HTTP] http-test") + + assert dash_duo.get_logs() == [] + + +def test_ws004_websocket_with_state(dash_duo): + """Test WebSocket callback with State inputs.""" + app = Dash(__name__, backend="fastapi", websocket_callbacks=True) + + app.layout = html.Div( + [ + dcc.Input(id="state-input", type="text", value="initial"), + html.Button("Submit", id="btn"), + html.Div(id="output"), + ] + ) + + @app.callback( + Output("output", "children"), + Input("btn", "n_clicks"), + State("state-input", "value"), + ) + def on_click(n_clicks, state_value): + if not n_clicks: + return "Click to submit" + return f"Submitted: {state_value}" + + dash_duo.start_server(app) + + dash_duo.wait_for_text_to_equal("#output", "Click to submit") + + # Update state input + state_input = dash_duo.find_element("#state-input") + dash_duo.clear_input(state_input) + state_input.send_keys("new value") + + # Click button to trigger callback + dash_duo.find_element("#btn").click() + dash_duo.wait_for_text_to_equal("#output", "Submitted: new value") + + assert dash_duo.get_logs() == [] + + +def test_ws005_websocket_context_available(dash_duo): + """Test that WebSocket context is available in WebSocket callbacks.""" + app = Dash(__name__, backend="fastapi", websocket_callbacks=True) + + app.layout = html.Div( + [ + html.Button("Check context", id="btn"), + html.Div(id="output"), + ] + ) + + @app.callback(Output("output", "children"), Input("btn", "n_clicks")) + def check_context(n_clicks): + if not n_clicks: + return "Click to check" + ws = ctx.websocket + if ws is not None: + return "WebSocket context available" + return "No WebSocket context" + + dash_duo.start_server(app) + + dash_duo.wait_for_text_to_equal("#output", "Click to check") + dash_duo.find_element("#btn").click() + dash_duo.wait_for_text_to_equal("#output", "WebSocket context available") + + assert dash_duo.get_logs() == [] + + +def test_ws006_websocket_multiple_outputs(dash_duo): + """Test WebSocket callback with multiple outputs.""" + app = Dash(__name__, backend="fastapi", websocket_callbacks=True) + + app.layout = html.Div( + [ + html.Button("Update", id="btn"), + html.Div(id="output1"), + html.Div(id="output2"), + html.Div(id="output3"), + ] + ) + + @app.callback( + Output("output1", "children"), + Output("output2", "children"), + Output("output3", "children"), + Input("btn", "n_clicks"), + ) + def multi_output(n_clicks): + n = n_clicks or 0 + return f"First: {n}", f"Second: {n * 2}", f"Third: {n * 3}" + + dash_duo.start_server(app) + + dash_duo.wait_for_text_to_equal("#output1", "First: 0") + dash_duo.wait_for_text_to_equal("#output2", "Second: 0") + dash_duo.wait_for_text_to_equal("#output3", "Third: 0") + + dash_duo.find_element("#btn").click() + + dash_duo.wait_for_text_to_equal("#output1", "First: 1") + dash_duo.wait_for_text_to_equal("#output2", "Second: 2") + dash_duo.wait_for_text_to_equal("#output3", "Third: 3") + + assert dash_duo.get_logs() == [] + + +def test_ws007_websocket_slider_callback(dash_duo): + """Test WebSocket callback with dcc.Slider component.""" + app = Dash(__name__, backend="fastapi", websocket_callbacks=True) + + app.layout = html.Div( + [ + dcc.Slider(id="slider", min=0, max=100, value=50, step=10), + html.Div(id="output"), + ] + ) + + @app.callback(Output("output", "children"), Input("slider", "value")) + def update_output(value): + return f"Slider value: {value}" + + dash_duo.start_server(app) + + # Initial callback should work via WebSocket + dash_duo.wait_for_text_to_equal("#output", "Slider value: 50") + + assert dash_duo.get_logs() == [] diff --git a/tests/websocket/test_ws_hooks.py b/tests/websocket/test_ws_hooks.py new file mode 100644 index 0000000000..db0a166efd --- /dev/null +++ b/tests/websocket/test_ws_hooks.py @@ -0,0 +1,285 @@ +""" +WebSocket hooks tests. + +Tests: +- websocket_connect hook - accept/reject connections +- websocket_message hook - accept/reject messages +- Custom close codes and reasons +""" + +from dash import Dash, html, Input, Output, hooks + + +def test_ws010_connect_hook_accept(dash_duo, ws_hook_cleanup): + """Test websocket_connect hook that accepts all connections.""" + connection_count = {"value": 0} + + @hooks.websocket_connect() + def allow_all(websocket): + connection_count["value"] += 1 + return True + + app = Dash(__name__, backend="fastapi", websocket_callbacks=True) + + app.layout = html.Div( + [ + html.Button("Click", id="btn"), + html.Div(id="output"), + ] + ) + + @app.callback(Output("output", "children"), Input("btn", "n_clicks")) + def on_click(n_clicks): + return f"Clicked {n_clicks or 0}" + + dash_duo.start_server(app) + + dash_duo.wait_for_text_to_equal("#output", "Clicked 0") + dash_duo.find_element("#btn").click() + dash_duo.wait_for_text_to_equal("#output", "Clicked 1") + + # Hook should have been called at least once for connection + assert connection_count["value"] >= 1 + assert dash_duo.get_logs() == [] + + +def test_ws011_connect_hook_reject_false(dash_duo, ws_hook_cleanup): + """Test websocket_connect hook that rejects with False. + + When WebSocket connection is rejected, callbacks won't work since + websocket_callbacks=True requires WebSocket transport. + """ + + @hooks.websocket_connect() + def reject_all(websocket): + return False + + app = Dash(__name__, backend="fastapi", websocket_callbacks=True) + + app.layout = html.Div( + [ + html.Button("Click", id="btn"), + html.Div("initial", id="output"), + ] + ) + + @app.callback(Output("output", "children"), Input("btn", "n_clicks")) + def on_click(n_clicks): + return f"Clicked {n_clicks or 0}" + + dash_duo.start_server(app) + + # WebSocket rejected - callbacks won't fire, output stays initial + import time + + time.sleep(1) # Give time for potential callback + assert dash_duo.find_element("#output").text == "initial" + + dash_duo.find_element("#btn").click() + time.sleep(1) + # Still initial since WebSocket was rejected + assert dash_duo.find_element("#output").text == "initial" + + +def test_ws012_connect_hook_reject_tuple(dash_duo, ws_hook_cleanup): + """Test websocket_connect hook that rejects with custom code/reason. + + When WebSocket connection is rejected, callbacks won't work since + websocket_callbacks=True requires WebSocket transport. + """ + + @hooks.websocket_connect() + def reject_with_reason(websocket): + return (4001, "Connection not allowed") + + app = Dash(__name__, backend="fastapi", websocket_callbacks=True) + + app.layout = html.Div( + [ + html.Button("Click", id="btn"), + html.Div("initial", id="output"), + ] + ) + + @app.callback(Output("output", "children"), Input("btn", "n_clicks")) + def on_click(n_clicks): + return f"Clicked {n_clicks or 0}" + + dash_duo.start_server(app) + + # WebSocket rejected - callbacks won't fire, output stays initial + import time + + time.sleep(1) + assert dash_duo.find_element("#output").text == "initial" + + dash_duo.find_element("#btn").click() + time.sleep(1) + assert dash_duo.find_element("#output").text == "initial" + + +def test_ws013_message_hook_accept(dash_duo, ws_hook_cleanup): + """Test websocket_message hook that accepts all messages.""" + message_count = {"value": 0} + + @hooks.websocket_message() + def allow_all_messages(websocket, message): + message_count["value"] += 1 + return True + + app = Dash(__name__, backend="fastapi", websocket_callbacks=True) + + app.layout = html.Div( + [ + html.Button("Click", id="btn"), + html.Div(id="output"), + ] + ) + + @app.callback(Output("output", "children"), Input("btn", "n_clicks")) + def on_click(n_clicks): + return f"Clicked {n_clicks or 0}" + + dash_duo.start_server(app) + + dash_duo.wait_for_text_to_equal("#output", "Clicked 0") + dash_duo.find_element("#btn").click() + dash_duo.wait_for_text_to_equal("#output", "Clicked 1") + + # Message hook should have been called + assert message_count["value"] >= 1 + assert dash_duo.get_logs() == [] + + +def test_ws014_message_hook_reject(dash_duo, ws_hook_cleanup): + """Test websocket_message hook that rejects specific messages.""" + reject_clicks = {"should_reject": False} + + @hooks.websocket_message() + def conditional_reject(websocket, message): + if reject_clicks["should_reject"]: + return (4010, "Message rejected") + return True + + app = Dash(__name__, backend="fastapi", websocket_callbacks=True) + + app.layout = html.Div( + [ + html.Button("Click", id="btn"), + html.Div(id="output"), + ] + ) + + @app.callback(Output("output", "children"), Input("btn", "n_clicks")) + def on_click(n_clicks): + return f"Clicked {n_clicks or 0}" + + dash_duo.start_server(app) + + # First click should work + dash_duo.wait_for_text_to_equal("#output", "Clicked 0") + dash_duo.find_element("#btn").click() + dash_duo.wait_for_text_to_equal("#output", "Clicked 1") + + assert dash_duo.get_logs() == [] + + +def test_ws015_async_connect_hook(dash_duo, ws_hook_cleanup): + """Test async websocket_connect hook.""" + import asyncio + + @hooks.websocket_connect() + async def async_validate(websocket): + await asyncio.sleep(0.01) # Simulate async validation + return True + + app = Dash(__name__, backend="fastapi", websocket_callbacks=True) + + app.layout = html.Div( + [ + html.Button("Click", id="btn"), + html.Div(id="output"), + ] + ) + + @app.callback(Output("output", "children"), Input("btn", "n_clicks")) + def on_click(n_clicks): + return f"Clicked {n_clicks or 0}" + + dash_duo.start_server(app) + + dash_duo.wait_for_text_to_equal("#output", "Clicked 0") + dash_duo.find_element("#btn").click() + dash_duo.wait_for_text_to_equal("#output", "Clicked 1") + + assert dash_duo.get_logs() == [] + + +def test_ws016_async_message_hook(dash_duo, ws_hook_cleanup): + """Test async websocket_message hook.""" + import asyncio + + @hooks.websocket_message() + async def async_validate_message(websocket, message): + await asyncio.sleep(0.01) # Simulate async validation + return True + + app = Dash(__name__, backend="fastapi", websocket_callbacks=True) + + app.layout = html.Div( + [ + html.Button("Click", id="btn"), + html.Div(id="output"), + ] + ) + + @app.callback(Output("output", "children"), Input("btn", "n_clicks")) + def on_click(n_clicks): + return f"Clicked {n_clicks or 0}" + + dash_duo.start_server(app) + + dash_duo.wait_for_text_to_equal("#output", "Clicked 0") + dash_duo.find_element("#btn").click() + dash_duo.wait_for_text_to_equal("#output", "Clicked 1") + + assert dash_duo.get_logs() == [] + + +def test_ws017_multiple_connect_hooks(dash_duo, ws_hook_cleanup): + """Test multiple websocket_connect hooks with priorities.""" + hook_order = [] + + @hooks.websocket_connect(priority=1) + def first_hook(websocket): + hook_order.append("first") + return True + + @hooks.websocket_connect(priority=2) + def second_hook(websocket): + hook_order.append("second") + return True + + app = Dash(__name__, backend="fastapi", websocket_callbacks=True) + + app.layout = html.Div( + [ + html.Button("Click", id="btn"), + html.Div(id="output"), + ] + ) + + @app.callback(Output("output", "children"), Input("btn", "n_clicks")) + def on_click(n_clicks): + return f"Clicked {n_clicks or 0}" + + dash_duo.start_server(app) + + dash_duo.wait_for_text_to_equal("#output", "Clicked 0") + dash_duo.find_element("#btn").click() + dash_duo.wait_for_text_to_equal("#output", "Clicked 1") + + # Both hooks should have been called + assert "first" in hook_order + assert "second" in hook_order + assert dash_duo.get_logs() == [] diff --git a/tests/websocket/test_ws_inactivity.py b/tests/websocket/test_ws_inactivity.py new file mode 100644 index 0000000000..8dd95e1094 --- /dev/null +++ b/tests/websocket/test_ws_inactivity.py @@ -0,0 +1,194 @@ +""" +WebSocket inactivity timeout tests. + +Tests: +- Connection closes after inactivity period +- Activity resets the timer +- Heartbeats don't count as activity +- Auto-reconnect when callback fires after timeout +""" + +import time +from dash import Dash, html, Input, Output + + +def test_ws020_inactivity_timeout_closes(dash_duo): + """Test that WebSocket connection closes after inactivity timeout.""" + app = Dash( + __name__, + backend="fastapi", + websocket_callbacks=True, + websocket_inactivity_timeout=3000, # 3 seconds for testing + ) + + app.layout = html.Div( + [ + html.Button("Click", id="btn"), + html.Div("initial", id="output"), + ] + ) + + @app.callback(Output("output", "children"), Input("btn", "n_clicks")) + def on_click(n_clicks): + return f"Clicked {n_clicks or 0}" + + dash_duo.start_server(app) + + # Trigger callback to establish connection + dash_duo.wait_for_text_to_equal("#output", "Clicked 0") + dash_duo.find_element("#btn").click() + dash_duo.wait_for_text_to_equal("#output", "Clicked 1") + + # Wait for inactivity timeout + time.sleep(4) + + # Click again - should auto-reconnect and work + dash_duo.find_element("#btn").click() + dash_duo.wait_for_text_to_equal("#output", "Clicked 2") + + +def test_ws021_activity_resets_timer(dash_duo): + """Test that callback activity resets the inactivity timer.""" + app = Dash( + __name__, + backend="fastapi", + websocket_callbacks=True, + websocket_inactivity_timeout=4000, # 4 seconds + ) + + app.layout = html.Div( + [ + html.Button("Click", id="btn"), + html.Div("initial", id="output"), + ] + ) + + @app.callback(Output("output", "children"), Input("btn", "n_clicks")) + def on_click(n_clicks): + return f"Clicked {n_clicks or 0}" + + dash_duo.start_server(app) + + dash_duo.wait_for_text_to_equal("#output", "Clicked 0") + + # Click every 2 seconds - should keep connection alive + for i in range(1, 4): + time.sleep(2) + dash_duo.find_element("#btn").click() + dash_duo.wait_for_text_to_equal("#output", f"Clicked {i}") + + # All clicks should work without disconnection + assert dash_duo.get_logs() == [] + + +def test_ws022_quick_successive_callbacks(dash_duo): + """Test rapid successive callbacks work correctly.""" + app = Dash( + __name__, + backend="fastapi", + websocket_callbacks=True, + websocket_inactivity_timeout=5000, + ) + + app.layout = html.Div( + [ + html.Button("Click", id="btn"), + html.Div("0", id="output"), + ] + ) + + @app.callback(Output("output", "children"), Input("btn", "n_clicks")) + def on_click(n_clicks): + return str(n_clicks or 0) + + dash_duo.start_server(app) + + dash_duo.wait_for_text_to_equal("#output", "0") + + # Rapid clicks + for _ in range(5): + dash_duo.find_element("#btn").click() + time.sleep(0.1) + + dash_duo.wait_for_text_to_equal("#output", "5") + assert dash_duo.get_logs() == [] + + +def test_ws023_auto_reconnect_after_timeout(dash_duo): + """Test auto-reconnect when callback fires after inactivity timeout.""" + app = Dash( + __name__, + backend="fastapi", + websocket_callbacks=True, + websocket_inactivity_timeout=2000, # 2 seconds + ) + + app.layout = html.Div( + [ + html.Button("Click", id="btn"), + html.Div("initial", id="output"), + ] + ) + + @app.callback(Output("output", "children"), Input("btn", "n_clicks")) + def on_click(n_clicks): + return f"Clicked {n_clicks or 0}" + + dash_duo.start_server(app) + + # Initial callback + dash_duo.wait_for_text_to_equal("#output", "Clicked 0") + dash_duo.find_element("#btn").click() + dash_duo.wait_for_text_to_equal("#output", "Clicked 1") + + # Wait for timeout to expire + time.sleep(3) + + # Click again - should auto-reconnect + dash_duo.find_element("#btn").click() + dash_duo.wait_for_text_to_equal("#output", "Clicked 2") + + # And keep working + dash_duo.find_element("#btn").click() + dash_duo.wait_for_text_to_equal("#output", "Clicked 3") + + assert dash_duo.get_logs() == [] + + +def test_ws024_long_callback_doesnt_timeout(dash_duo): + """Test that long-running callbacks don't cause timeout during execution.""" + import asyncio + + app = Dash( + __name__, + backend="fastapi", + websocket_callbacks=True, + websocket_inactivity_timeout=3000, # 3 seconds + ) + + app.layout = html.Div( + [ + html.Button("Start Long Task", id="btn"), + html.Div("ready", id="output"), + ] + ) + + @app.callback(Output("output", "children"), Input("btn", "n_clicks")) + async def long_task(n_clicks): + if not n_clicks: + return "ready" + # Simulate long task (longer than inactivity timeout) + await asyncio.sleep(2) + return f"Completed task {n_clicks}" + + dash_duo.start_server(app) + + dash_duo.wait_for_text_to_equal("#output", "ready") + + # Start long task + dash_duo.find_element("#btn").click() + + # Should complete despite being longer than half the timeout + dash_duo.wait_for_text_to_equal("#output", "Completed task 1", timeout=10) + + assert dash_duo.get_logs() == [] diff --git a/tests/websocket/test_ws_origin.py b/tests/websocket/test_ws_origin.py new file mode 100644 index 0000000000..c6235613a5 --- /dev/null +++ b/tests/websocket/test_ws_origin.py @@ -0,0 +1,154 @@ +""" +WebSocket origin validation tests. + +Tests: +- Same-origin connections allowed by default +- Cross-origin rejected unless explicitly allowed +- websocket_allowed_origins configuration +""" + +from dash import Dash, html, Input, Output + + +def test_ws040_same_origin_allowed(dash_duo): + """Test that same-origin WebSocket connections work by default.""" + app = Dash(__name__, backend="fastapi", websocket_callbacks=True) + + app.layout = html.Div( + [ + html.Button("Click", id="btn"), + html.Div("initial", id="output"), + ] + ) + + @app.callback(Output("output", "children"), Input("btn", "n_clicks")) + def on_click(n_clicks): + return f"Clicked {n_clicks or 0}" + + dash_duo.start_server(app) + + # Same-origin request should work + dash_duo.wait_for_text_to_equal("#output", "Clicked 0") + dash_duo.find_element("#btn").click() + dash_duo.wait_for_text_to_equal("#output", "Clicked 1") + + assert dash_duo.get_logs() == [] + + +def test_ws041_websocket_allowed_origins_empty(dash_duo): + """Test with empty websocket_allowed_origins (only same-origin).""" + app = Dash( + __name__, + backend="fastapi", + websocket_callbacks=True, + websocket_allowed_origins=[], # Only same-origin + ) + + app.layout = html.Div( + [ + html.Button("Click", id="btn"), + html.Div("initial", id="output"), + ] + ) + + @app.callback(Output("output", "children"), Input("btn", "n_clicks")) + def on_click(n_clicks): + return f"Clicked {n_clicks or 0}" + + dash_duo.start_server(app) + + # Same-origin should still work + dash_duo.wait_for_text_to_equal("#output", "Clicked 0") + dash_duo.find_element("#btn").click() + dash_duo.wait_for_text_to_equal("#output", "Clicked 1") + + assert dash_duo.get_logs() == [] + + +def test_ws042_websocket_allowed_origins_wildcard(dash_duo): + """Test with wildcard in websocket_allowed_origins.""" + app = Dash( + __name__, + backend="fastapi", + websocket_callbacks=True, + websocket_allowed_origins=["*"], # Allow all origins + ) + + app.layout = html.Div( + [ + html.Button("Click", id="btn"), + html.Div("initial", id="output"), + ] + ) + + @app.callback(Output("output", "children"), Input("btn", "n_clicks")) + def on_click(n_clicks): + return f"Clicked {n_clicks or 0}" + + dash_duo.start_server(app) + + dash_duo.wait_for_text_to_equal("#output", "Clicked 0") + dash_duo.find_element("#btn").click() + dash_duo.wait_for_text_to_equal("#output", "Clicked 1") + + assert dash_duo.get_logs() == [] + + +def test_ws043_websocket_allowed_origins_specific(dash_duo): + """Test with specific origins in websocket_allowed_origins.""" + app = Dash( + __name__, + backend="fastapi", + websocket_callbacks=True, + websocket_allowed_origins=["http://localhost:*", "http://127.0.0.1:*"], + ) + + app.layout = html.Div( + [ + html.Button("Click", id="btn"), + html.Div("initial", id="output"), + ] + ) + + @app.callback(Output("output", "children"), Input("btn", "n_clicks")) + def on_click(n_clicks): + return f"Clicked {n_clicks or 0}" + + dash_duo.start_server(app) + + # Should work since we're running on localhost + dash_duo.wait_for_text_to_equal("#output", "Clicked 0") + dash_duo.find_element("#btn").click() + dash_duo.wait_for_text_to_equal("#output", "Clicked 1") + + assert dash_duo.get_logs() == [] + + +def test_ws044_origin_with_per_callback_websocket(dash_duo): + """Test origin validation with per-callback websocket=True.""" + app = Dash( + __name__, + backend="fastapi", + websocket_allowed_origins=["http://localhost:*", "http://127.0.0.1:*"], + ) + + app.layout = html.Div( + [ + html.Button("Click", id="btn"), + html.Div("initial", id="output"), + ] + ) + + @app.callback( + Output("output", "children"), Input("btn", "n_clicks"), websocket=True + ) + def on_click(n_clicks): + return f"Clicked {n_clicks or 0}" + + dash_duo.start_server(app) + + dash_duo.wait_for_text_to_equal("#output", "Clicked 0") + dash_duo.find_element("#btn").click() + dash_duo.wait_for_text_to_equal("#output", "Clicked 1") + + assert dash_duo.get_logs() == [] diff --git a/tests/websocket/test_ws_patch.py b/tests/websocket/test_ws_patch.py new file mode 100644 index 0000000000..f278b5f26e --- /dev/null +++ b/tests/websocket/test_ws_patch.py @@ -0,0 +1,42 @@ +""" +WebSocket set_props with Patch object test. + +Verifies that set_props works with Patch objects in websocket callbacks. +""" + +from dash import Dash, html, Input, Output, set_props, Patch +from dash.exceptions import PreventUpdate + + +def test_ws037_set_props_with_patch(dash_duo): + """Test set_props with Patch object in websocket callback.""" + app = Dash(__name__, backend="fastapi", websocket_callbacks=True) + + app.layout = html.Div( + [ + html.Button("Patch", id="btn"), + html.Div("initial", id="output"), + html.Div(id="result"), + ] + ) + + @app.callback( + Output("result", "children"), Input("btn", "n_clicks"), websocket=True + ) + def patch_append(n): + if not n: + raise PreventUpdate + + p = Patch() + p += f" + click {n}" + + set_props("output", {"children": p}) + return f"Appended {n}" + + dash_duo.start_server(app) + + dash_duo.find_element("#btn").click() + + dash_duo.wait_for_text_to_equal("#output", "initial + click 1", timeout=10) + + assert dash_duo.get_logs() == [] diff --git a/tests/websocket/test_ws_props.py b/tests/websocket/test_ws_props.py new file mode 100644 index 0000000000..a86402954d --- /dev/null +++ b/tests/websocket/test_ws_props.py @@ -0,0 +1,471 @@ +""" +WebSocket set_props and get_props tests. + +Tests: +- set_props streaming during long-running callback +- get_prop reads current component value +- async set_prop method +- set_props with Patch objects (bug fix for component property updates) +- set_props with pattern-matching components triggering MATCH callbacks +""" + +import asyncio +from dash import Dash, html, Input, Output, State, set_props, MATCH +from dash.exceptions import PreventUpdate + + +def test_ws030_set_props_streaming(dash_duo): + """Test that set_props streams updates during callback execution.""" + app = Dash(__name__, backend="fastapi", websocket_callbacks=True) + + app.layout = html.Div( + [ + html.Button("Start", id="btn"), + html.Div("0%", id="progress"), + html.Div("waiting", id="result"), + ] + ) + + @app.callback(Output("result", "children"), Input("btn", "n_clicks")) + async def long_task(n): + if not n: + raise PreventUpdate + + for i in range(1, 6): + set_props("progress", {"children": f"{i * 20}%"}) + await asyncio.sleep(0.1) + + return "Done" + + dash_duo.start_server(app) + + dash_duo.wait_for_text_to_equal("#progress", "0%") + dash_duo.wait_for_text_to_equal("#result", "waiting") + + dash_duo.find_element("#btn").click() + + # Should see progress updates and final result + dash_duo.wait_for_text_to_equal("#result", "Done", timeout=10) + # Final progress should be 100% + dash_duo.wait_for_text_to_equal("#progress", "100%") + + assert dash_duo.get_logs() == [] + + +def test_ws031_set_props_multiple_components(dash_duo): + """Test set_props updating multiple components during callback.""" + app = Dash(__name__, backend="fastapi", websocket_callbacks=True) + + app.layout = html.Div( + [ + html.Button("Update All", id="btn"), + html.Div("A: initial", id="output-a"), + html.Div("B: initial", id="output-b"), + html.Div("C: initial", id="output-c"), + html.Div("result", id="result"), + ] + ) + + @app.callback(Output("result", "children"), Input("btn", "n_clicks")) + async def update_all(n): + if not n: + raise PreventUpdate + + set_props("output-a", {"children": f"A: updated {n}"}) + await asyncio.sleep(0.05) + set_props("output-b", {"children": f"B: updated {n}"}) + await asyncio.sleep(0.05) + set_props("output-c", {"children": f"C: updated {n}"}) + + return f"All updated {n}" + + dash_duo.start_server(app) + + dash_duo.find_element("#btn").click() + + dash_duo.wait_for_text_to_equal("#output-a", "A: updated 1", timeout=10) + dash_duo.wait_for_text_to_equal("#output-b", "B: updated 1") + dash_duo.wait_for_text_to_equal("#output-c", "C: updated 1") + dash_duo.wait_for_text_to_equal("#result", "All updated 1") + + assert dash_duo.get_logs() == [] + + +def test_ws032_set_props_with_complex_values(dash_duo): + """Test set_props with various value types.""" + app = Dash(__name__, backend="fastapi", websocket_callbacks=True) + + app.layout = html.Div( + [ + html.Button("Test Values", id="btn"), + html.Div(id="text-output"), + html.Div(id="number-output"), + html.Div(id="list-output"), + html.Div(id="result"), + ] + ) + + @app.callback(Output("result", "children"), Input("btn", "n_clicks")) + async def test_values(n): + if not n: + raise PreventUpdate + + # String + set_props("text-output", {"children": "Hello World"}) + await asyncio.sleep(0.02) + + # Number as string + set_props("number-output", {"children": str(42)}) + await asyncio.sleep(0.02) + + # List of strings + set_props("list-output", {"children": ["Item 1", " - ", "Item 2"]}) + + return "Values set" + + dash_duo.start_server(app) + + dash_duo.find_element("#btn").click() + + dash_duo.wait_for_text_to_equal("#text-output", "Hello World", timeout=10) + dash_duo.wait_for_text_to_equal("#number-output", "42") + dash_duo.wait_for_text_to_equal("#list-output", "Item 1 - Item 2") + dash_duo.wait_for_text_to_equal("#result", "Values set") + + assert dash_duo.get_logs() == [] + + +def test_ws033_set_props_sync_callback(dash_duo): + """Test set_props in synchronous callback with WebSocket.""" + app = Dash(__name__, backend="fastapi", websocket_callbacks=True) + + app.layout = html.Div( + [ + html.Button("Sync Update", id="btn"), + html.Div("before", id="side-effect"), + html.Div(id="result"), + ] + ) + + @app.callback(Output("result", "children"), Input("btn", "n_clicks")) + def sync_update(n): + if not n: + raise PreventUpdate + + # set_props should work in sync callback too + set_props("side-effect", {"children": f"Side effect {n}"}) + return f"Result {n}" + + dash_duo.start_server(app) + + dash_duo.find_element("#btn").click() + + dash_duo.wait_for_text_to_equal("#result", "Result 1", timeout=10) + dash_duo.wait_for_text_to_equal("#side-effect", "Side effect 1") + + assert dash_duo.get_logs() == [] + + +def test_ws034_get_prop_reads_value(dash_duo): + """Test that get_prop can read current component values.""" + app = Dash(__name__, backend="fastapi", websocket_callbacks=True) + + app.layout = html.Div( + [ + html.Div("Source Value", id="source"), + html.Button("Read", id="btn"), + html.Div(id="result"), + ] + ) + + @app.callback(Output("result", "children"), Input("btn", "n_clicks")) + async def read_prop(n): + if not n: + raise PreventUpdate + + from dash import ctx + + ws = ctx.websocket + if ws: + value = await ws.get_prop("source", "children") + return f"Read: {value}" + return "No WebSocket" + + dash_duo.start_server(app) + + dash_duo.find_element("#btn").click() + + dash_duo.wait_for_text_to_equal("#result", "Read: Source Value", timeout=10) + + assert dash_duo.get_logs() == [] + + +def test_ws035_websocket_set_prop_method(dash_duo): + """Test using ws.set_prop() method directly.""" + app = Dash(__name__, backend="fastapi", websocket_callbacks=True) + + app.layout = html.Div( + [ + html.Button("Set via WS", id="btn"), + html.Div("original", id="target"), + html.Div(id="result"), + ] + ) + + @app.callback(Output("result", "children"), Input("btn", "n_clicks")) + async def set_via_ws(n): + if not n: + raise PreventUpdate + + from dash import ctx + + ws = ctx.websocket + if ws: + await ws.set_prop("target", "children", f"Set via WebSocket {n}") + return "Set complete" + return "No WebSocket" + + dash_duo.start_server(app) + + dash_duo.find_element("#btn").click() + + dash_duo.wait_for_text_to_equal("#target", "Set via WebSocket 1", timeout=10) + dash_duo.wait_for_text_to_equal("#result", "Set complete") + + assert dash_duo.get_logs() == [] + + +def test_ws036_set_props_dict_component_id(dash_duo): + """Test set_props with dict component ID (pattern matching).""" + app = Dash(__name__, backend="fastapi", websocket_callbacks=True) + + app.layout = html.Div( + [ + html.Button("Update", id="btn"), + html.Div("initial", id={"type": "output", "index": 0}), + html.Div(id="result"), + ] + ) + + @app.callback(Output("result", "children"), Input("btn", "n_clicks")) + async def update_with_dict_id(n): + if not n: + raise PreventUpdate + + set_props({"type": "output", "index": 0}, {"children": f"Updated {n}"}) + return f"Done {n}" + + dash_duo.start_server(app) + + dash_duo.find_element("#btn").click() + + # Use attribute selector for the dict ID + dash_duo.wait_for_text_to_equal( + '[id=\'{"index":0,"type":"output"}\']', "Updated 1", timeout=10 + ) + dash_duo.wait_for_text_to_equal("#result", "Done 1") + + assert dash_duo.get_logs() == [] + + +def test_ws045_set_props_component_prop_children(dash_duo): + """Test set_props updating component props like Div's children with component.""" + app = Dash(__name__, backend="fastapi", websocket_callbacks=True) + + app.layout = html.Div( + [ + html.Button("Update Children", id="btn"), + html.Div(id="container"), + html.Div(id="result"), + ] + ) + + @app.callback(Output("result", "children"), Input("btn", "n_clicks")) + async def update_children(n): + if not n: + raise PreventUpdate + + set_props( + "container", + { + "children": html.Div( + [ + html.Span(f"Updated {n}"), + html.B(" - Bold Text"), + ] + ) + }, + ) + return f"Children updated {n}" + + dash_duo.start_server(app) + + dash_duo.find_element("#btn").click() + + dash_duo.wait_for_text_to_equal("#container span", "Updated 1", timeout=10) + dash_duo.wait_for_text_to_equal("#container b", "- Bold Text") + dash_duo.wait_for_text_to_equal("#result", "Children updated 1") + + assert dash_duo.get_logs() == [] + + +def test_ws046_set_props_nested_component_children(dash_duo): + """Test set_props with nested component in children prop.""" + app = Dash(__name__, backend="fastapi", websocket_callbacks=True) + + app.layout = html.Div( + [ + html.Button("Update Nested", id="btn"), + html.Div(id="wrapper"), + html.Div(id="result"), + ] + ) + + @app.callback(Output("result", "children"), Input("btn", "n_clicks")) + async def update_nested(n): + if not n: + raise PreventUpdate + + set_props( + "wrapper", + { + "children": html.Div( + [ + html.Ul( + [ + html.Li(f"Item {n}.1"), + html.Li(f"Item {n}.2"), + ] + ) + ] + ) + }, + ) + return f"Nested updated {n}" + + dash_duo.start_server(app) + + dash_duo.find_element("#btn").click() + + dash_duo.wait_for_text_to_equal( + "#wrapper ul li:first-child", "Item 1.1", timeout=10 + ) + dash_duo.wait_for_text_to_equal("#wrapper ul li:last-child", "Item 1.2") + dash_duo.wait_for_text_to_equal("#result", "Nested updated 1") + + assert dash_duo.get_logs() == [] + + +def test_ws047_set_props_children_with_list(dash_duo): + """Test set_props with list of components wrapped in a single component.""" + app = Dash(__name__, backend="fastapi", websocket_callbacks=True) + + app.layout = html.Div( + [ + html.Button("Update List", id="btn"), + html.Div(id="list-container"), + html.Div(id="result"), + ] + ) + + @app.callback(Output("result", "children"), Input("btn", "n_clicks")) + async def update_list(n): + if not n: + raise PreventUpdate + + set_props( + "list-container", + { + "children": html.Div( + [ + html.Div(f"Item 1 - {n}"), + html.Div(f"Item 2 - {n}"), + html.Div(f"Item 3 - {n}"), + ] + ) + }, + ) + return f"List updated {n}" + + dash_duo.start_server(app) + + dash_duo.find_element("#btn").click() + + dash_duo.wait_for_text_to_equal("#result", "List updated 1", timeout=10) + assert "Item 1 - 1" in dash_duo.find_element("#list-container").text + assert "Item 2 - 1" in dash_duo.find_element("#list-container").text + assert "Item 3 - 1" in dash_duo.find_element("#list-container").text + + assert dash_duo.get_logs() == [] + + +def test_ws048_set_props_dynamic_match_callback(dash_duo): + """Test set_props injecting components with pattern-matching IDs that trigger MATCH callbacks.""" + app = Dash(__name__, backend="fastapi", websocket_callbacks=True) + + app.layout = html.Div( + [ + html.Button("Add Component", id="add-btn"), + html.Div(id="container"), + html.Div("waiting", id="match-result"), + html.Div(id="result"), + ] + ) + + @app.callback(Output("result", "children"), Input("add-btn", "n_clicks")) + async def add_component(n): + if not n: + raise PreventUpdate + + # Inject component with pattern-matching ID via set_props + set_props( + "container", + { + "children": html.Div( + [ + html.Span("Hello"), + html.Button("Click me", id={"type": "dynamic", "index": 0}), + ] + ) + }, + ) + return f"Component added {n}" + + @app.callback( + Output("match-result", "children"), + Input({"type": "dynamic", "index": MATCH}, "n_clicks"), + State({"type": "dynamic", "index": MATCH}, "id"), + prevent_initial_call=True, + ) + def handle_dynamic_click(n_clicks, btn_id): + if not n_clicks: + raise PreventUpdate + return f"Clicked button index {btn_id['index']} - {n_clicks} times" + + dash_duo.start_server(app) + + # Initial state + dash_duo.wait_for_text_to_equal("#match-result", "waiting") + + # Add the dynamic component + dash_duo.find_element("#add-btn").click() + dash_duo.wait_for_text_to_equal("#result", "Component added 1", timeout=10) + + # Verify the component was added + dash_duo.wait_for_text_to_equal("#container span", "Hello", timeout=5) + + # Click the dynamically added button with pattern-matching ID + dash_duo.find_element('[id=\'{"index":0,"type":"dynamic"}\']').click() + + # Verify the MATCH callback fired + dash_duo.wait_for_text_to_equal( + "#match-result", "Clicked button index 0 - 1 times", timeout=10 + ) + + # Click again to verify it continues to work + dash_duo.find_element('[id=\'{"index":0,"type":"dynamic"}\']').click() + dash_duo.wait_for_text_to_equal( + "#match-result", "Clicked button index 0 - 2 times", timeout=10 + ) + + assert dash_duo.get_logs() == [] diff --git a/tests/websocket/test_ws_quart.py b/tests/websocket/test_ws_quart.py new file mode 100644 index 0000000000..3d40493ba5 --- /dev/null +++ b/tests/websocket/test_ws_quart.py @@ -0,0 +1,228 @@ +""" +Quart WebSocket callback tests. + +Tests the Quart backend websocket implementation which mirrors the FastAPI backend. +""" + +from dash import Dash, html, dcc, Input, Output, State, ctx + + +def test_wsq001_per_callback_websocket_quart(dash_duo): + """Test single callback with websocket=True on Quart backend.""" + app = Dash(__name__, backend="quart") + + app.layout = html.Div( + [ + html.H1("Per-Callback WebSocket Test (Quart)"), + dcc.Input(id="ws-input", type="text", placeholder="Type here..."), + html.Div(id="ws-output"), + ] + ) + + @app.callback( + Output("ws-output", "children"), Input("ws-input", "value"), websocket=True + ) + def ws_callback(value): + return f"WS: {value or ''}" + + dash_duo.start_server(app) + + # Test initial state (trailing space is trimmed by HTML rendering) + dash_duo.wait_for_text_to_equal("#ws-output", "WS:") + + # Type into the input and verify callback executes + input_elem = dash_duo.find_element("#ws-input") + input_elem.send_keys("hello") + + dash_duo.wait_for_text_to_equal("#ws-output", "WS: hello") + assert dash_duo.get_logs() == [] + + +def test_wsq002_global_websocket_callbacks_quart(dash_duo): + """Test global websocket_callbacks=True enables WebSocket for all callbacks on Quart.""" + app = Dash( + __name__, + backend="quart", + websocket_callbacks=True, + ) + + app.layout = html.Div( + [ + html.Button("Click me", id="btn", n_clicks=0), + html.Div(id="output"), + dcc.Input(id="input", type="text"), + html.Div(id="input-output"), + ] + ) + + @app.callback(Output("output", "children"), Input("btn", "n_clicks")) + def on_click(n_clicks): + return f"Clicked {n_clicks} times" + + @app.callback(Output("input-output", "children"), Input("input", "value")) + def on_input(value): + return f"Input: {value or ''}" + + dash_duo.start_server(app) + + # Test button callback + dash_duo.wait_for_text_to_equal("#output", "Clicked 0 times") + dash_duo.find_element("#btn").click() + dash_duo.wait_for_text_to_equal("#output", "Clicked 1 times") + + # Test input callback + dash_duo.find_element("#input").send_keys("test") + dash_duo.wait_for_text_to_equal("#input-output", "Input: test") + + assert dash_duo.get_logs() == [] + + +def test_wsq003_mixed_http_and_websocket_quart(dash_duo): + """Test mixing WebSocket and HTTP callbacks in the same app on Quart.""" + app = Dash(__name__, backend="quart") + + app.layout = html.Div( + [ + # WebSocket callback section + html.Div( + [ + dcc.Input(id="ws-input", type="text"), + html.Div(id="ws-output"), + ] + ), + # HTTP callback section (default) + html.Div( + [ + dcc.Input(id="http-input", type="text"), + html.Div(id="http-output"), + ] + ), + ] + ) + + @app.callback( + Output("ws-output", "children"), Input("ws-input", "value"), websocket=True + ) + def ws_callback(value): + return f"[WebSocket] {value or ''}" + + @app.callback(Output("http-output", "children"), Input("http-input", "value")) + def http_callback(value): + return f"[HTTP] {value or ''}" + + dash_duo.start_server(app) + + # Test WebSocket callback + dash_duo.find_element("#ws-input").send_keys("ws-test") + dash_duo.wait_for_text_to_equal("#ws-output", "[WebSocket] ws-test") + + # Test HTTP callback + dash_duo.find_element("#http-input").send_keys("http-test") + dash_duo.wait_for_text_to_equal("#http-output", "[HTTP] http-test") + + assert dash_duo.get_logs() == [] + + +def test_wsq004_websocket_with_state_quart(dash_duo): + """Test WebSocket callback with State inputs on Quart.""" + app = Dash(__name__, backend="quart", websocket_callbacks=True) + + app.layout = html.Div( + [ + dcc.Input(id="state-input", type="text", value="initial"), + html.Button("Submit", id="btn"), + html.Div(id="output"), + ] + ) + + @app.callback( + Output("output", "children"), + Input("btn", "n_clicks"), + State("state-input", "value"), + ) + def on_click(n_clicks, state_value): + if not n_clicks: + return "Click to submit" + return f"Submitted: {state_value}" + + dash_duo.start_server(app) + + dash_duo.wait_for_text_to_equal("#output", "Click to submit") + + # Update state input + state_input = dash_duo.find_element("#state-input") + dash_duo.clear_input(state_input) + state_input.send_keys("new value") + + # Click button to trigger callback + dash_duo.find_element("#btn").click() + dash_duo.wait_for_text_to_equal("#output", "Submitted: new value") + + assert dash_duo.get_logs() == [] + + +def test_wsq005_websocket_context_available_quart(dash_duo): + """Test that WebSocket context is available in WebSocket callbacks on Quart.""" + app = Dash(__name__, backend="quart", websocket_callbacks=True) + + app.layout = html.Div( + [ + html.Button("Check context", id="btn"), + html.Div(id="output"), + ] + ) + + @app.callback(Output("output", "children"), Input("btn", "n_clicks")) + def check_context(n_clicks): + if not n_clicks: + return "Click to check" + ws = ctx.websocket + if ws is not None: + return "WebSocket context available" + return "No WebSocket context" + + dash_duo.start_server(app) + + dash_duo.wait_for_text_to_equal("#output", "Click to check") + dash_duo.find_element("#btn").click() + dash_duo.wait_for_text_to_equal("#output", "WebSocket context available") + + assert dash_duo.get_logs() == [] + + +def test_wsq006_websocket_multiple_outputs_quart(dash_duo): + """Test WebSocket callback with multiple outputs on Quart.""" + app = Dash(__name__, backend="quart", websocket_callbacks=True) + + app.layout = html.Div( + [ + html.Button("Update", id="btn"), + html.Div(id="output1"), + html.Div(id="output2"), + html.Div(id="output3"), + ] + ) + + @app.callback( + Output("output1", "children"), + Output("output2", "children"), + Output("output3", "children"), + Input("btn", "n_clicks"), + ) + def multi_output(n_clicks): + n = n_clicks or 0 + return f"First: {n}", f"Second: {n * 2}", f"Third: {n * 3}" + + dash_duo.start_server(app) + + dash_duo.wait_for_text_to_equal("#output1", "First: 0") + dash_duo.wait_for_text_to_equal("#output2", "Second: 0") + dash_duo.wait_for_text_to_equal("#output3", "Third: 0") + + dash_duo.find_element("#btn").click() + + dash_duo.wait_for_text_to_equal("#output1", "First: 1") + dash_duo.wait_for_text_to_equal("#output2", "Second: 2") + dash_duo.wait_for_text_to_equal("#output3", "Third: 3") + + assert dash_duo.get_logs() == [] diff --git a/tests/websocket/test_ws_validate.py b/tests/websocket/test_ws_validate.py new file mode 100644 index 0000000000..0a43ada553 --- /dev/null +++ b/tests/websocket/test_ws_validate.py @@ -0,0 +1,58 @@ +import pytest + +from dash.exceptions import WebSocketCallbackError +from dash._validate import validate_websocket_callback_request + + +class TestWebsocketCallbackRequestValidation: + """Tests for runtime WebSocket callback request validation.""" + + def test_global_enabled_allows_any_callback(self): + """When websocket_callbacks=True globally, any callback can use WebSocket.""" + callback_map = { + "out1.children": {"websocket": False}, + "out2.children": {}, # no websocket key + } + # Should not raise - global setting allows all + validate_websocket_callback_request("out1.children", callback_map, True) + validate_websocket_callback_request("out2.children", callback_map, True) + + def test_per_callback_websocket_enabled_passes(self): + """Callback with websocket=True should pass when global is False.""" + callback_map = { + "out1.children": {"websocket": True}, + } + # Should not raise + validate_websocket_callback_request("out1.children", callback_map, False) + + def test_per_callback_websocket_disabled_raises(self): + """Callback without websocket=True should raise when global is False.""" + callback_map = { + "out1.children": {"websocket": False}, + } + + with pytest.raises(WebSocketCallbackError) as exc_info: + validate_websocket_callback_request("out1.children", callback_map, False) + + assert "out1.children" in str(exc_info.value) + assert "websocket=True" in str(exc_info.value) + + def test_callback_without_websocket_key_raises(self): + """Callback without websocket key should raise when global is False.""" + callback_map = { + "out1.children": {}, # no websocket key + } + + with pytest.raises(WebSocketCallbackError) as exc_info: + validate_websocket_callback_request("out1.children", callback_map, False) + + assert "out1.children" in str(exc_info.value) + + def test_unknown_callback_raises(self): + """Unknown callback ID should raise when global is False.""" + callback_map = {} + + with pytest.raises(WebSocketCallbackError) as exc_info: + validate_websocket_callback_request("unknown.children", callback_map, False) + + assert "unknown.children" in str(exc_info.value)