#!/usr/bin/env python3
"""
Oracle DB MCP Server - Tool Registration

This module handles the registration of all Oracle database tools with the MCP server.
"""

import os
import logging
import json
from typing import Any, Dict, List, Optional

from mcp.server import Server
from mcp.types import (
    CallToolResult,
    TextContent,
    Tool,
)

from server.oracle.pool import OraclePool
from server.oracle.exec import OracleExecutor
from server.oracle.meta import OracleMetadata
from server.oracle.plsql import OraclePLSQL
from util.log import setup_logging


def _resolve_timeout_s(arguments: Dict[str, Any]) -> Optional[int]:
    """Return the timeout, in seconds, requested by a tool call.

    An explicit ``timeout_s`` takes precedence and is passed through
    unchanged. Otherwise ``timeout_ms`` is converted to whole seconds,
    rounding *up* so that a sub-second request (e.g. 500 ms) becomes 1 s
    instead of being truncated to 0.

    Args:
        arguments: Raw arguments of the tool call.

    Returns:
        The timeout in seconds, or None when neither key is set or when
        ``timeout_ms`` cannot be converted to an integer.
    """
    timeout_s = arguments.get("timeout_s")
    if timeout_s is not None:
        return timeout_s

    timeout_ms = arguments.get("timeout_ms")
    if timeout_ms is None:
        return None

    try:
        millis = int(timeout_ms)
    except (TypeError, ValueError, OverflowError):
        # Best effort: an unusable timeout_ms is treated as "not specified".
        return None

    # Ceiling division without floats: 1..1000 -> 1, 1001..2000 -> 2, ...
    return -(-millis // 1000)


def _extract_sql(arguments: Dict[str, Any]) -> str:
    """Return the SQL text of a tool call, accepting 'query' as an alias.

    Args:
        arguments: Raw arguments of the tool call.

    Returns:
        The value of 'sql', or of 'query' when 'sql' is missing or empty.

    Raises:
        ValueError: If neither key holds a non-blank string.
    """
    sql_text = arguments.get("sql") or arguments.get("query")
    if not isinstance(sql_text, str) or not sql_text.strip():
        raise ValueError("Missing required argument 'sql' (or its alias 'query')")
    return sql_text


def _json_result(payload: Any, *, is_error: bool = False) -> CallToolResult:
    """Wrap *payload* as a single JSON text item in a CallToolResult.

    Values that ``json`` cannot encode natively are stringified via
    ``default=str`` so every tool serializes its payload the same way.

    Args:
        payload: Object to serialize.
        is_error: Whether the result should be flagged as an error.

    Returns:
        A CallToolResult carrying one TextContent with the JSON document.
    """
    json_text = json.dumps(payload, ensure_ascii=False, default=str)
    return CallToolResult(
        content=[TextContent(type="text", text=json_text)],
        isError=is_error,
    )


def create_server() -> Server:
    """
    Create and configure the MCP server with Oracle database tools.

    Sets up logging, builds the Oracle helper objects on top of a shared
    pool, and registers the ``list_tools`` and ``call_tool`` handlers.

    Returns:
        Server: Configured MCP server instance
    """
    # Setup logging
    setup_logging()
    logger = logging.getLogger(__name__)

    # Create MCP server
    server = Server("oracle-db")

    # Initialize Oracle components. `metadata` and `plsql` are not used by the
    # handlers below; they are still constructed as before in case their
    # constructors matter (NOTE(review): confirm whether they can be dropped).
    pool = OraclePool()
    executor = OracleExecutor(pool)
    metadata = OracleMetadata(pool)
    plsql = OraclePLSQL(pool)

    def _build_tool_list() -> List[Tool]:
        """Return the curated list of public MCP tools for Oracle DB operations.

        Only the tools approved in the current scope are exposed to the MCP client.
        Internal helper implementations remain in server/oracle/* for debugging and
        future use but are not published here unless explicitly added.
        """
        # NOTE(review): both SQL schemas list only 'sql' as required, so a
        # client that validates against the schema may reject calls that use
        # just the 'query' alias - confirm whether that is intended.
        return [
            Tool(
                name="sql.select",
                description=(
                    "Execute a SELECT query that returns rows. This tool only accepts "
                    "SELECT statements and returns a payload with 'columns' and 'rows'."
                ),
                inputSchema={
                    "type": "object",
                    "properties": {
                        "sql": {"type": "string", "description": "SELECT statement to execute"},
                        "query": {"type": "string", "description": "Alias for 'sql'"},
                        "params": {"type": "object", "additionalProperties": True, "description": "Bind parameters"},
                        "max_rows": {"type": "integer", "minimum": 1, "description": "Optional maximum rows to return"},
                        "timeout_ms": {"type": "integer", "minimum": 1, "description": "Optional timeout in milliseconds"},
                        "timeout_s": {"type": "integer", "minimum": 1, "description": "Optional timeout in seconds"},
                    },
                    "required": ["sql"],
                },
            ),
            Tool(
                name="sql.execute",
                description=(
                    "Execute non-SELECT SQL: DML, DDL, or PL/SQL blocks. Supports bind parameters "
                    "and will auto-bind missing names as OUT variables when executing PL/SQL."
                ),
                inputSchema={
                    "type": "object",
                    "properties": {
                        "sql": {"type": "string", "description": "SQL or PL/SQL block to execute"},
                        "query": {"type": "string", "description": "Alias for 'sql'"},
                        "params": {"type": "object", "additionalProperties": True, "description": "Bind parameters (use null for OUT binds)"},
                        "timeout_ms": {"type": "integer", "minimum": 1, "description": "Optional timeout in milliseconds"},
                        "timeout_s": {"type": "integer", "minimum": 1, "description": "Optional timeout in seconds"},
                    },
                    "required": ["sql"],
                },
            ),
            Tool(
                name="db.ping",
                description="Diagnostics: check connectivity, return db_version, agent_version, capabilities and settings.",
                inputSchema={"type": "object", "properties": {}, "required": []},
            ),
        ]

    @server.list_tools()
    async def list_tools() -> List[Tool]:
        """List all available Oracle database tools."""
        return _build_tool_list()

    @server.call_tool()
    async def call_tool(name: str, arguments: Dict[str, Any]) -> CallToolResult:
        """Handle tool calls for the curated set of Oracle database operations.

        Args:
            name: Tool name ('db.ping', legacy 'ping', 'sql.select' or 'sql.execute').
            arguments: Tool arguments as described by the tool's input schema.

        Returns:
            A CallToolResult whose single text item is a JSON document: the
            tool's payload on success, or ``{"status": "error", "message": ...}``
            with ``isError=True`` when anything raised.
        """
        # Tolerate a missing arguments mapping so `.get` below cannot fail.
        arguments = arguments or {}

        # Top-level boundary: every failure is reported to the client as an
        # error result rather than propagating out of the handler.
        try:
            # db.ping -> diagnostics (support both 'db.ping' and legacy 'ping')
            if name in ("db.ping", "ping"):
                return _json_result(pool.ping())

            # sql.select -> strict SELECT queries
            if name == "sql.select":
                result = await executor.sql_select(
                    _extract_sql(arguments),
                    params=arguments.get("params"),
                    timeout_s=_resolve_timeout_s(arguments),
                    max_rows=arguments.get("max_rows"),
                )
                return _json_result(result)

            # sql.execute -> DML / DDL / PL/SQL
            if name == "sql.execute":
                result = await executor.sql_execute(
                    _extract_sql(arguments),
                    params=arguments.get("params"),
                    timeout_s=_resolve_timeout_s(arguments),
                )
                return _json_result(result)

            raise ValueError(f"Unknown tool: {name}")

        except Exception as e:
            # logger.exception keeps the traceback alongside the message.
            logger.exception("Error executing tool %s: %s", name, e)
            return _json_result({"status": "error", "message": str(e)}, is_error=True)

    return server
