#!/usr/bin/env python3
"""
Oracle Database Query Execution

This module handles SQL query execution, explain plans, and session information.
"""

import os
import logging
from typing import Any, Dict, List, Optional, Union

import cx_Oracle
import datetime
import re

from server.oracle.pool import OraclePool
from util.json import OracleJSONEncoder


class OracleExecutor:
    """
    Handles execution of SQL queries against Oracle database.

    Each public coroutine borrows a connection from ``pool``, runs its
    statement(s) through cx_Oracle and returns a plain dict whose ``"status"``
    is ``"success"`` or ``"error"``; database failures are reported in that
    dict rather than raised. The driver calls are blocking and are not
    offloaded, so they run on the event-loop thread. Every borrowed connection
    is handed back through ``pool.release_connection`` with any per-call
    ``callTimeout`` restored first, so a timeout never leaks to the next user
    of a pooled connection.
    """

    # Leading whitespace / SQL comments skipped when finding a statement's first keyword.
    _LEADING_NOISE_RE = re.compile(r"\s+|--[^\n]*|/\*.*?\*/", re.DOTALL)
    # The first word of a statement (e.g. SELECT, BEGIN, CREATE).
    _KEYWORD_RE = re.compile(r"[A-Za-z]+")
    # String literals, quoted identifiers and comments. They are blanked out before
    # scanning for bind variables so text such as 'a:b' or "-- uses :x" is not treated
    # as a bind. (q'[...]' alternative-quote literals are not recognized.)
    _LITERAL_OR_COMMENT_RE = re.compile(
        r"'(?:[^']|'')*'|\"[^\"]*\"|--[^\n]*|/\*.*?\*/", re.DOTALL
    )
    # Named bind variables such as :name; numeric binds like :1 are not matched.
    _BIND_RE = re.compile(r":([A-Za-z_][\w$#]*)")
    # Statements whose trailing semicolon is kept: removing the final semicolon of a
    # PL/SQL block breaks its syntax (PLS-00103).
    _KEEP_SEMICOLON_KEYWORDS = frozenset({"BEGIN", "DECLARE", "CREATE", "CALL", "REPLACE"})
    # DDL cannot take bind variables, and ":new"/":old" inside CREATE TRIGGER bodies
    # are correlation names, not binds, so no OUT variables are auto-bound for these.
    _DDL_KEYWORDS = frozenset(
        {"CREATE", "ALTER", "DROP", "TRUNCATE", "GRANT", "REVOKE", "COMMENT", "RENAME"}
    )

    def __init__(self, pool: "OraclePool"):
        """
        Initialize the executor with a connection pool.

        Args:
            pool: OraclePool instance

        ``max_rows`` (default 500) and ``max_text_size`` (default 5000) are read
        from the ``MCP_MAX_ROWS`` / ``MCP_MAX_TEXT_SIZE`` environment variables.
        """
        self.pool = pool
        self.logger = logging.getLogger(__name__)
        self.max_rows = int(os.getenv("MCP_MAX_ROWS", "500"))
        self.max_text_size = int(os.getenv("MCP_MAX_TEXT_SIZE", "5000"))

    async def ping(self) -> Dict[str, Any]:
        """
        Test database connectivity.

        Returns:
            dict: Result of ``pool.test_connection()``
        """
        return await self.pool.test_connection()

    async def execute_query(
        self, query: str, max_rows: Optional[int] = None
    ) -> Dict[str, Any]:
        """
        Execute a SQL query and return at most ``max_rows`` rows.

        Trailing semicolons are stripped unless the statement may be PL/SQL,
        exactly as in :meth:`query`.

        Args:
            query: SQL query to execute
            max_rows: Maximum number of rows to return (defaults to ``self.max_rows``)

        Returns:
            dict: On success ``columns``, ``rows`` (column->value dicts),
            ``row_count``, ``total_rows``, ``limited`` (True only when more rows
            were available than were returned) and ``limit``; otherwise an
            error dict with ``status``/``message`` (and ``code`` for database errors).
        """
        conn = None
        cursor = None

        try:
            if not query or not query.strip():
                return {"status": "error", "message": "Empty query provided"}

            limit = max_rows if max_rows is not None else self.max_rows
            query = self._normalize_query(query)

            conn = self.pool.get_connection()
            cursor = conn.cursor()

            self.logger.info(f"Executing query: {query[:500]}...")
            cursor.execute(query)

            columns = (
                [col[0] for col in cursor.description] if cursor.description else []
            )
            rows, limited = self._fetch_dict_rows(cursor, columns, limit)
            row_count = len(rows)

            # cursor.rowcount counts rows fetched from the cursor so far; when the
            # result was cut off it is only a lower bound on the real total.
            total_rows = cursor.rowcount if cursor.rowcount >= 0 else row_count

            return {
                "status": "success",
                "query": query,
                "columns": columns,
                "rows": rows,
                "row_count": row_count,
                "total_rows": total_rows,
                "limited": limited,
                "limit": limit,
            }

        except cx_Oracle.DatabaseError as e:
            return self._database_error(e, "executing query")
        except Exception as e:
            self.logger.error(f"Error executing query: {e}")
            return {"status": "error", "message": f"Error executing query: {str(e)}"}
        finally:
            self._release(conn, cursor)

    async def sql_select(
        self,
        sql: str,
        params: Optional[Dict[str, Any]] = None,
        timeout_s: Optional[int] = None,
        max_rows: Optional[int] = None,
    ) -> Dict[str, Any]:
        """
        Execute a SELECT statement only. Returns columns (list) and rows (list of lists).

        The statement's first keyword, after skipping leading whitespace and
        ``--`` / ``/* */`` comments, must be SELECT; anything else is rejected
        without touching the database. Execution is delegated to :meth:`query`,
        which honours ``params``, ``timeout_s`` and ``max_rows``.

        Returns:
            dict: ``{"status": "success", "columns": [...], "rows": [[...], ...]}``
            with row values in column order, or the error dict from validation /
            :meth:`query`.
        """
        if not sql or not sql.strip():
            return {"status": "error", "message": "Empty SQL provided"}
        if self._leading_keyword(sql) != "SELECT":
            return {"status": "error", "message": "sql_select only supports SELECT statements"}

        result = await self.query(sql, params=params, max_rows=max_rows, timeout_s=timeout_s)
        if result.get("status") != "success":
            return result

        # query() yields rows as column->value dicts; re-shape them into lists.
        columns = result.get("columns", [])
        rows = [[row.get(col) for col in columns] for row in result.get("rows", [])]
        return {"status": "success", "columns": columns, "rows": rows}

    async def sql_execute(
        self,
        sql: str,
        params: Optional[Dict[str, Any]] = None,
        timeout_s: Optional[int] = None,
    ) -> Dict[str, Any]:
        """
        Execute DML/DDL/PLSQL statements. Handles:
         - DML/DDL: returns rows_affected and message (the statement is committed
           on a best-effort basis)
         - PL/SQL blocks: supports OUT bind variables by auto-binding missing params
           (any bind variable present in the SQL but not provided in params is treated
           as an OUT variable and will be returned). Matching against ``params`` is
           case-insensitive, text inside string literals / comments is ignored, and
           no auto-binding happens for DDL statements (see ``_DDL_KEYWORDS``).

        The SQL is passed through unchanged (semicolons are kept because they
        terminate PL/SQL blocks).

        Args:
            sql: Statement to execute
            params: IN bind values (callers may also pass their own cursor vars)
            timeout_s: Call timeout in seconds for this call only

        Returns:
            dict: ``rows_affected``/``message`` (plus ``output_values`` and, for a
            single OUT variable named ret/return/retval/return_value,
            ``return_value``) when there is no result set; otherwise ``columns``,
            ``rows`` (lists), ``row_count`` and ``output_values``. Error dict on failure.
        """
        conn = None
        cursor = None
        previous_timeout = None
        try:
            if not sql or not sql.strip():
                return {"status": "error", "message": "Empty SQL provided"}

            timeout_ms = self._timeout_ms(timeout_s)
            conn = self.pool.get_connection()
            previous_timeout = self._apply_call_timeout(conn, timeout_ms)
            cursor = conn.cursor()

            binds = dict(params) if params else {}
            auto_out_vars = {}
            for name in self._unbound_names(sql, binds):
                # STRING(4000) default; callers needing another type pass their own var.
                var = cursor.var(cx_Oracle.STRING, 4000)
                binds[name] = var
                auto_out_vars[name] = var

            if binds:
                cursor.execute(sql, binds)
            else:
                cursor.execute(sql)

            # No result set (DDL/DML/most PL/SQL): commit and report rows_affected.
            if cursor.description is None:
                self._commit(conn)
                rowcount = self._affected_rows(cursor)
                output_values = self._collect_out_values(auto_out_vars)

                # A single OUT variable with a "return"-style name is surfaced as return_value.
                return_value = None
                if len(output_values) == 1:
                    only_name = next(iter(output_values))
                    if only_name.lower() in ("ret", "return", "retval", "return_value"):
                        return_value = output_values[only_name]

                resp = {
                    "status": "success",
                    "query": sql,
                    "rows_affected": rowcount,
                    "message": "Statement executed",
                }
                if output_values:
                    resp["output_values"] = output_values
                if return_value is not None:
                    resp["return_value"] = return_value
                return resp

            # A result set is rare here but returned in full when present.
            columns = [col[0] for col in cursor.description]
            rows = [[self._process_value(v) for v in row] for row in cursor]
            output_values = self._collect_out_values(auto_out_vars)

            return {
                "status": "success",
                "query": sql,
                "columns": columns,
                "rows": rows,
                "row_count": len(rows),
                "output_values": output_values if output_values else None,
            }

        except cx_Oracle.DatabaseError as e:
            return self._database_error(e, "executing statement")
        except Exception as e:
            self.logger.error(f"Error executing statement: {e}")
            return {"status": "error", "message": f"Error executing statement: {str(e)}"}
        finally:
            self._release(conn, cursor, previous_timeout)

    async def explain_plan(self, query: str) -> Dict[str, Any]:
        """
        Get execution plan for a SQL query.

        Runs ``EXPLAIN PLAN FOR <query>`` and then reads the formatted plan from
        ``DBMS_XPLAN.DISPLAY()`` on the same connection.

        Args:
            query: SQL query to explain (trailing semicolons are stripped, ADR-005)

        Returns:
            dict: ``{"status": "success", "query": ..., "plan": [line, ...]}`` or
            an error dict.
        """
        conn = None
        cursor = None

        try:
            if not query or not query.strip():
                return {"status": "error", "message": "Empty query provided"}

            query = self._strip_trailing_semicolons(query, "explain plan query")

            conn = self.pool.get_connection()
            cursor = conn.cursor()

            # The query text is the caller's SQL itself, so it is embedded verbatim.
            cursor.execute(f"EXPLAIN PLAN FOR {query}")
            cursor.execute(
                """
                SELECT * FROM TABLE(DBMS_XPLAN.DISPLAY())
            """
            )

            plan_rows = [row[0] if row else "" for row in cursor]
            return {"status": "success", "query": query, "plan": plan_rows}

        except cx_Oracle.DatabaseError as e:
            return self._database_error(e, "getting execution plan")
        except Exception as e:
            self.logger.error(f"Error getting execution plan: {e}")
            return {
                "status": "error",
                "message": f"Error getting execution plan: {str(e)}",
            }
        finally:
            self._release(conn, cursor)

    async def get_session_info(self) -> Dict[str, Any]:
        """
        Get current session information.

        Returns:
            dict: ``session`` (SYS_CONTEXT('USERENV', ...) values keyed by the
            driver's column names, raw values), ``pool`` (``pool.get_pool_status()``)
            and the configured ``limits``; or an error dict.
        """
        conn = None
        cursor = None

        try:
            conn = self.pool.get_connection()
            cursor = conn.cursor()

            cursor.execute(
                """
                SELECT 
                    USER as username,
                    SYS_CONTEXT('USERENV', 'SESSION_USER') as session_user,
                    SYS_CONTEXT('USERENV', 'SESSION_SCHEMA') as session_schema,
                    SYS_CONTEXT('USERENV', 'SESSIONID') as session_id,
                    SYS_CONTEXT('USERENV', 'INSTANCE') as instance,
                    SYS_CONTEXT('USERENV', 'SERVICE_NAME') as service_name,
                    SYS_CONTEXT('USERENV', 'HOST') as host,
                    SYS_CONTEXT('USERENV', 'IP_ADDRESS') as ip_address,
                    SYSDATE as current_time
                FROM DUAL
            """
            )

            row = cursor.fetchone()
            if row:
                columns = [col[0] for col in cursor.description]
                session_info = dict(zip(columns, row))
            else:
                session_info = {}

            pool_status = self.pool.get_pool_status()

            return {
                "status": "success",
                "session": session_info,
                "pool": pool_status,
                "limits": {
                    "max_rows": self.max_rows,
                    "max_text_size": self.max_text_size,
                },
            }

        except Exception as e:
            self.logger.error(f"Error getting session info: {e}")
            return {
                "status": "error",
                "message": f"Error getting session info: {str(e)}",
            }
        finally:
            self._release(conn, cursor)

    def _process_value(self, value: Any) -> Any:
        """
        Process a database value for JSON serialization.

        LOBs are read in full (bytes decoded as UTF-8 with replacement) and
        datetimes become ISO-8601 strings; no truncation is applied. Any other
        value is returned unchanged.

        Args:
            value: Raw database value

        Returns:
            Any: Processed value suitable for JSON
        """
        if value is None:
            return None

        if isinstance(value, cx_Oracle.LOB):
            try:
                lob_value = value.read()
                if isinstance(lob_value, bytes):
                    lob_value = lob_value.decode("utf-8", errors="replace")
                return lob_value
            except Exception as e:
                self.logger.warning(f"Error reading LOB value: {e}")
                return "[LOB read error]"

        if isinstance(value, datetime.datetime):
            try:
                return value.isoformat()
            except Exception:
                return str(value)

        return value

    async def query(
        self,
        query: str,
        params: Optional[Dict[str, Any]] = None,
        max_rows: Optional[int] = None,
        timeout_s: Optional[int] = None,
    ) -> Dict[str, Any]:
        """
        Execute a SQL statement with parameter binding and optional limits.

        No row limit or call timeout is imposed by default (full queries are
        allowed); they apply only when the caller passes ``max_rows`` /
        ``timeout_s``. Statements without a result set (DDL, DML, most PL/SQL)
        are committed on a best-effort basis.

        Args:
            query: SQL statement to execute. Trailing semicolons are stripped
                unless the statement may be PL/SQL (ADR-005).
            params: Dictionary of parameters for query binding
            max_rows: Maximum number of rows to return; ``None``, negative or
                non-integer values mean "no limit".
            timeout_s: Call timeout in seconds, applied to the borrowed
                connection for this call only.

        Returns:
            dict: For result sets, ``util.json.format_query_result`` applied to a
            dict with ``columns``, ``rows`` (column->value dicts), ``row_count``,
            ``total_rows``, ``truncated``, ``limit`` and ``timeout_s``. For
            statements without a result set, that dict unformatted with
            ``row_count`` = affected rows. An error dict on failure.
        """
        conn = None
        cursor = None
        previous_timeout = None

        try:
            if not query or not query.strip():
                return {"status": "error", "message": "Empty query provided"}

            query = self._normalize_query(query)
            timeout_ms = self._timeout_ms(timeout_s)
            limit = self._row_limit(max_rows)
            reported_timeout = (timeout_ms // 1000) if timeout_ms else None

            conn = self.pool.get_connection()
            previous_timeout = self._apply_call_timeout(conn, timeout_ms)
            cursor = conn.cursor()

            self.logger.info(f"Executing query: {query[:200]}...")

            if params:
                cursor.execute(query, params)
            else:
                cursor.execute(query)

            # No result set (DDL, DML without RETURNING, many PL/SQL blocks).
            if cursor.description is None:
                self._commit(conn)
                rowcount = self._affected_rows(cursor)
                return {
                    "status": "success",
                    "query": query,
                    "message": "Statement executed",
                    "columns": [],
                    "rows": [],
                    "row_count": rowcount,
                    "total_rows": rowcount,
                    "truncated": False,
                    "limit": None,
                    "timeout_s": reported_timeout,
                }

            columns = [col[0] for col in cursor.description]
            rows, truncated = self._fetch_dict_rows(cursor, columns, limit)
            row_count = len(rows)

            # Rows fetched from the cursor so far per the driver; a lower bound when truncated.
            rowcount = getattr(cursor, "rowcount", -1)
            total_rows = rowcount if rowcount is not None and rowcount >= 0 else row_count

            result = {
                "status": "success",
                "query": query,
                "columns": columns,
                "rows": rows,
                "row_count": row_count,
                "total_rows": total_rows,
                "truncated": truncated,
                "limit": limit,
                "timeout_s": reported_timeout,
            }

            from util.json import format_query_result

            return format_query_result(result)

        except cx_Oracle.DatabaseError as e:
            return self._database_error(e, "executing query")
        except Exception as e:
            self.logger.error(f"Error executing query: {e}")
            return {"status": "error", "message": f"Error executing query: {str(e)}"}
        finally:
            self._release(conn, cursor, previous_timeout)

    # ------------------------------------------------------------------ helpers

    @classmethod
    def _leading_keyword(cls, sql: str) -> str:
        """Return the upper-cased first word of ``sql`` after leading whitespace/comments."""
        pos = 0
        while True:
            noise = cls._LEADING_NOISE_RE.match(sql, pos)
            if not noise or noise.end() == pos:
                break
            pos = noise.end()
        word = cls._KEYWORD_RE.match(sql, pos)
        return word.group(0).upper() if word else ""

    def _strip_trailing_semicolons(self, text: str, what: str) -> str:
        """Strip trailing whitespace and semicolons; warn only if a semicolon was removed."""
        trimmed = text.rstrip()
        stripped = trimmed.rstrip(";")
        if len(stripped) < len(trimmed):
            self.logger.warning(f"Stripped trailing semicolons from {what}: {text}")
        return stripped

    def _normalize_query(self, query: str) -> str:
        """Apply ADR-005 semicolon stripping unless the statement may contain PL/SQL."""
        if self._leading_keyword(query) in self._KEEP_SEMICOLON_KEYWORDS:
            return query
        return self._strip_trailing_semicolons(query, "query")

    @classmethod
    def _unbound_names(cls, sql: str, binds: Dict[str, Any]) -> List[str]:
        """Return bind names used in ``sql`` that have no entry in ``binds``.

        Oracle bind names are case-insensitive, so matching is done on upper-case
        names; each missing name is reported once, spelled as first seen in ``sql``.
        """
        if cls._leading_keyword(sql) in cls._DDL_KEYWORDS:
            return []
        scan_text = cls._LITERAL_OR_COMMENT_RE.sub(" ", sql)
        provided = {str(key).upper() for key in binds}
        missing: Dict[str, str] = {}
        for name in cls._BIND_RE.findall(scan_text):
            key = name.upper()
            if key not in provided and key not in missing:
                missing[key] = name
        return list(missing.values())

    def _fetch_dict_rows(
        self, cursor: Any, columns: List[str], limit: Optional[int]
    ) -> "tuple[List[Dict[str, Any]], bool]":
        """Fetch rows as column->value dicts; return (rows, truncated).

        ``truncated`` is True only when a row beyond ``limit`` was actually seen.
        ``limit=None`` means no limit.
        """
        rows: List[Dict[str, Any]] = []
        for row in cursor:
            if limit is not None and len(rows) >= limit:
                return rows, True
            rows.append(dict(zip(columns, [self._process_value(v) for v in row])))
        return rows, False

    def _collect_out_values(self, out_vars: Dict[str, Any]) -> Dict[str, Any]:
        """Read auto-bound OUT variables; a value that cannot be read becomes None."""
        values: Dict[str, Any] = {}
        for name, var in out_vars.items():
            try:
                raw = var.getvalue() if hasattr(var, "getvalue") else var
                values[name] = self._process_value(raw)
            except Exception as e:
                self.logger.debug(f"Could not read OUT variable {name}: {e}")
                values[name] = None
        return values

    @staticmethod
    def _affected_rows(cursor: Any) -> int:
        """Return ``cursor.rowcount`` or 0 when the driver reports None/negative."""
        rowcount = getattr(cursor, "rowcount", -1)
        return rowcount if rowcount is not None and rowcount >= 0 else 0

    @staticmethod
    def _timeout_ms(timeout_s: Any) -> Optional[int]:
        """Convert a timeout in seconds to milliseconds; None if absent or not an int."""
        if timeout_s is None:
            return None
        try:
            return int(timeout_s) * 1000
        except (TypeError, ValueError, OverflowError):
            return None

    def _row_limit(self, max_rows: Any) -> Optional[int]:
        """Interpret ``max_rows``: a non-negative int limits rows, anything else means no limit."""
        if max_rows is None:
            return None
        try:
            limit = int(max_rows)
        except (TypeError, ValueError, OverflowError):
            self.logger.warning(f"Ignoring invalid max_rows: {max_rows!r}")
            return None
        return limit if limit >= 0 else None

    def _apply_call_timeout(self, conn: Any, timeout_ms: Optional[int]) -> Optional[int]:
        """Set ``conn.callTimeout`` for this call; return the previous value to restore.

        Returns None when nothing was changed (no timeout requested, or the
        driver rejected it — execution then proceeds without a timeout).
        """
        if not timeout_ms:
            return None
        try:
            previous = conn.callTimeout
            conn.callTimeout = timeout_ms
        except Exception as e:
            self.logger.debug(f"callTimeout not applied: {e}")
            return None
        return previous

    def _commit(self, conn: Any) -> None:
        """Best-effort commit; failures (e.g. read-only environments) are logged, not raised."""
        try:
            conn.commit()
        except Exception as e:
            self.logger.warning(f"Commit failed (ignored): {e}")

    def _database_error(self, exc: Exception, action: str) -> Dict[str, Any]:
        """Log a cx_Oracle DatabaseError and build the standard error response.

        Tolerates exceptions whose ``args`` are not a single error object.
        """
        error_obj = exc.args[0] if exc.args else exc
        message = getattr(error_obj, "message", None) or str(error_obj)
        self.logger.error(f"Database error {action}: {message}")
        return {
            "status": "error",
            "message": f"Database error: {message}",
            "code": getattr(error_obj, "code", None),
        }

    def _release(self, conn: Any, cursor: Any, previous_timeout: Optional[int] = None) -> None:
        """Close ``cursor``, restore ``callTimeout`` if it was changed, and return ``conn`` to the pool."""
        if cursor is not None:
            try:
                cursor.close()
            except Exception:
                pass  # already closed or connection broken; nothing left to free
        if conn is not None:
            if previous_timeout is not None:
                try:
                    conn.callTimeout = previous_timeout
                except Exception as e:
                    self.logger.debug(f"Could not restore callTimeout: {e}")
            self.pool.release_connection(conn)