#!/usr/bin/env python3
"""
Oracle Database PL/SQL Operations

This module handles PL/SQL block execution, stored procedure creation,
procedure calling, and testing functionality.
"""

import datetime
import logging
import os
import re
from typing import Any, Dict, List, Optional, Union

import cx_Oracle

from server.oracle.pool import OraclePool
from util.json import OracleJSONEncoder


class OraclePLSQL:
    """
    Handles PL/SQL operations including block execution, stored procedure
    creation, procedure calling, and testing.

    Each public coroutine borrows a connection from ``pool``, enables
    autocommit and a call timeout on it, and in a ``finally`` clause closes
    the cursor, turns autocommit back off and returns the connection to the
    pool. Failures while doing the database work are reported as
    ``{"status": "error", "message": ...}`` dictionaries instead of raised.

    NOTE(review): the cx_Oracle calls made here are synchronous, so although
    the public methods are ``async`` they block the running event loop while
    the database works.
    """

    # Captures the procedure name (optionally schema-qualified and/or
    # double-quoted) from a "CREATE [OR REPLACE] [[NON]EDITIONABLE] PROCEDURE"
    # header.
    _PROC_HEADER_RE = re.compile(
        r'\bCREATE\s+(?:OR\s+REPLACE\s+)?(?:(?:NON)?EDITIONABLE\s+)?PROCEDURE\s+'
        r'((?:"[^"]+"|[A-Za-z][\w$#]*)(?:\.(?:"[^"]+"|[A-Za-z][\w$#]*))?)',
        re.IGNORECASE,
    )

    def __init__(self, pool: "OraclePool"):
        """
        Initialize the PL/SQL handler with a connection pool.

        Args:
            pool: OraclePool instance providing get_connection() and
                release_connection()
        """
        self.pool = pool
        self.logger = logging.getLogger(__name__)
        # Maximum length of text values returned to callers (see _process_value).
        self.max_text_size = int(os.getenv("MCP_MAX_TEXT_SIZE", "5000"))

    async def exec_plsql(
        self,
        plsql_block: str,
        params: Optional[Dict[str, Any]] = None,
        timeout_s: Optional[int] = None,
    ) -> Dict[str, Any]:
        """
        Execute a PL/SQL block with autocommit.

        Args:
            plsql_block: PL/SQL block to execute
            params: Dictionary of parameters for bind variables. A ``None``
                value is bound as a VARCHAR2(4000) variable, so the block can
                assign an output to it (e.g. ``:result := ...``).
            timeout_s: Query timeout in seconds (overrides MCP_TIMEOUT_S,
                default 30)

        Returns:
            dict: Execution results. On success ``output_values`` maps every
            bind name to its value after execution (IN, OUT and IN OUT alike),
            or is None when nothing was bound. On failure ``status`` is
            "error" and ``message`` describes the problem.
        """
        conn = None
        cursor = None

        try:
            if not plsql_block or not plsql_block.strip():
                return {"status": "error", "message": "Empty PL/SQL block provided"}

            timeout_ms = self._resolve_timeout_ms(timeout_s, "30")

            conn = self.pool.get_connection()
            conn.autocommit = True
            conn.callTimeout = timeout_ms
            cursor = conn.cursor()

            self.logger.info(f"Executing PL/SQL block: {plsql_block[:100]}...")

            if params:
                cursor.execute(plsql_block, self._prepare_named_binds(cursor, params))
            else:
                cursor.execute(plsql_block)

            # A PL/SQL block never produces a result set (cursor.description
            # stays None), so outputs are read back from the bind variables.
            output_values = self._read_bind_values(cursor)

            return {
                "status": "success",
                "plsql_block": plsql_block,
                "message": "PL/SQL block executed successfully",
                "output_values": self._process_outputs(output_values),
                "timeout_s": timeout_ms // 1000 if timeout_ms else None,
            }

        except cx_Oracle.DatabaseError as e:
            return self._db_error_result(e, "executing PL/SQL")
        except Exception as e:
            self.logger.error(f"Error executing PL/SQL: {e}")
            return {"status": "error", "message": f"Error executing PL/SQL: {str(e)}"}
        finally:
            self._cleanup(conn, cursor)

    async def create_or_replace_proc(
        self, procedure_name: str, procedure_text: str, timeout_s: Optional[int] = None
    ) -> Dict[str, Any]:
        """
        Create or replace a stored procedure from text.

        Any existing procedure called ``procedure_name`` is dropped first. Then
        ``procedure_text`` is executed, and USER_OBJECTS is queried to confirm
        the result.

        Args:
            procedure_name: Name of the procedure to create. Unquoted names
                are looked up in upper case, which is how Oracle stores them.
                Wrap the name in double quotes for a case-sensitive identifier.
            procedure_text: Complete procedure definition text
            timeout_s: Query timeout in seconds (overrides MCP_TIMEOUT_S,
                default 60)

        Returns:
            dict: Creation results with status and ``procedure_info``. That is
            the USER_OBJECTS row (OBJECT_NAME, OBJECT_TYPE, STATUS, CREATED),
            or ``{}`` when no matching row was found.
        """
        conn = None
        cursor = None

        try:
            if not procedure_name or not procedure_name.strip():
                return {"status": "error", "message": "Procedure name is required"}

            if not procedure_text or not procedure_text.strip():
                return {"status": "error", "message": "Procedure text is required"}

            # DDL gets a longer default timeout than plain execution.
            timeout_ms = self._resolve_timeout_ms(timeout_s, "60")

            conn = self.pool.get_connection()
            conn.autocommit = True
            conn.callTimeout = timeout_ms
            cursor = conn.cursor()

            self.logger.info(f"Creating/replacing procedure: {procedure_name}")

            # Dropping first lets a plain CREATE (without OR REPLACE) succeed
            # when the procedure already exists. NOTE: dropping discards grants
            # on the old procedure, and procedure_name is interpolated into the
            # SQL as-is.
            if self._try_drop_procedure(cursor, procedure_name):
                self.logger.info(f"Dropped existing procedure: {procedure_name}")
            else:
                self.logger.info(
                    f"Procedure {procedure_name} does not exist, will create new"
                )

            cursor.execute(procedure_text)

            verify_sql = """
                SELECT OBJECT_NAME, OBJECT_TYPE, STATUS, CREATED
                FROM USER_OBJECTS
                WHERE OBJECT_NAME = :1 AND OBJECT_TYPE = 'PROCEDURE'
            """
            # USER_OBJECTS stores unquoted identifiers upper-cased, so the raw
            # name (e.g. "my_proc") would never match.
            cursor.execute(verify_sql, [self._normalize_object_name(procedure_name)])
            row = cursor.fetchone()

            if row:
                columns = [col[0] for col in cursor.description]
                procedure_info = {
                    column: self._process_value(value)
                    for column, value in zip(columns, row)
                }
            else:
                procedure_info = {}

            return {
                "status": "success",
                "procedure_name": procedure_name,
                "message": f"Procedure {procedure_name} created successfully",
                "procedure_info": procedure_info,
                "timeout_s": timeout_ms // 1000 if timeout_ms else None,
            }

        except cx_Oracle.DatabaseError as e:
            return self._db_error_result(e, "creating procedure")
        except Exception as e:
            self.logger.error(f"Error creating procedure: {e}")
            return {"status": "error", "message": f"Error creating procedure: {str(e)}"}
        finally:
            self._cleanup(conn, cursor)

    async def call_proc(
        self,
        procedure_name: str,
        params: Optional[Dict[str, Any]] = None,
        timeout_s: Optional[int] = None,
    ) -> Dict[str, Any]:
        """
        Call a stored procedure with parameters.

        Parameters are passed to the procedure positionally, in the dict's
        insertion order. A ``None`` value marks an OUT parameter, and its value
        after the call is returned in ``output_values``. It is bound as a
        NUMBER variable if any supplied value is an int/float, otherwise as
        VARCHAR2(4000). The call runs with autocommit enabled, like the other
        methods of this class.

        Args:
            procedure_name: Name of the procedure to call
            params: Dictionary of parameters (IN, OUT, IN OUT)
            timeout_s: Query timeout in seconds (overrides MCP_TIMEOUT_S,
                default 30)

        Returns:
            dict: Call results with output parameters and return values
        """
        conn = None
        cursor = None

        try:
            if not procedure_name or not procedure_name.strip():
                return {"status": "error", "message": "Procedure name is required"}

            timeout_ms = self._resolve_timeout_ms(timeout_s, "30")

            conn = self.pool.get_connection()
            # Without autocommit nothing in this method ever commits, so any
            # changes made by the procedure would be discarded.
            conn.autocommit = True
            conn.callTimeout = timeout_ms
            cursor = conn.cursor()

            self.logger.info(f"Calling procedure: {procedure_name}")

            output_values: Dict[str, Any] = {}
            if params:
                bind_list, out_vars = self._prepare_call_binds(cursor, params)
                # Positional placeholders (:1, :2, ...) avoid invalid bind
                # names when a parameter key is a reserved word. The procedure
                # therefore runs exactly once, with no retry that could repeat
                # its side effects.
                placeholders = ", ".join(
                    f":{position}" for position in range(1, len(bind_list) + 1)
                )
                cursor.execute(
                    f"BEGIN {procedure_name}({placeholders}); END;", bind_list
                )
                output_values = {
                    name: var.getvalue() for name, var in out_vars.items()
                }
            else:
                cursor.execute(f"BEGIN {procedure_name}; END;")

            return {
                "status": "success",
                "procedure_name": procedure_name,
                "message": f"Procedure {procedure_name} called successfully",
                "output_values": self._process_outputs(output_values),
                "timeout_s": timeout_ms // 1000 if timeout_ms else None,
            }

        except cx_Oracle.DatabaseError as e:
            return self._db_error_result(e, "calling procedure")
        except Exception as e:
            self.logger.error(f"Error calling procedure: {e}")
            return {"status": "error", "message": f"Error calling procedure: {str(e)}"}
        finally:
            self._cleanup(conn, cursor)

    async def run_test(
        self,
        test_name: str,
        test_procedure: str,
        test_params: Optional[Dict[str, Any]] = None,
        timeout_s: Optional[int] = None,
    ) -> Dict[str, Any]:
        """
        Simple procedure testing functionality.

        Creates the procedure defined by ``test_procedure`` and calls it with
        ``test_params`` (positionally, in insertion order). The procedure is
        always dropped afterwards, even if creation or the call failed, so do
        not reuse the name of a procedure you want to keep.

        The procedure name is taken from the text's
        ``CREATE [OR REPLACE] PROCEDURE <name>`` header. If no header is
        found, a generated ``TEST_<TEST_NAME>_<id>`` name is used instead.

        Args:
            test_name: Name for the test (used in the fallback procedure name)
            test_procedure: Complete procedure definition for testing
            test_params: Parameters to pass to the test procedure
            timeout_s: Query timeout in seconds (overrides MCP_TIMEOUT_S,
                default 60)

        Returns:
            dict: Test results with execution details, including the
            ``procedure_name`` that was called
        """
        conn = None
        cursor = None

        try:
            if not test_name or not test_name.strip():
                return {"status": "error", "message": "Test name is required"}

            if not test_procedure or not test_procedure.strip():
                return {"status": "error", "message": "Test procedure text is required"}

            timeout_ms = self._resolve_timeout_ms(timeout_s, "60")

            conn = self.pool.get_connection()
            conn.autocommit = True
            conn.callTimeout = timeout_ms
            cursor = conn.cursor()

            self.logger.info(f"Running test: {test_name}")

            # Call the procedure the text actually creates. The generated name
            # cannot appear in user-supplied text, so it is only a fallback.
            proc_name = self._extract_procedure_name(test_procedure) or (
                f"TEST_{test_name.upper().replace(' ', '_')}_{id(self)}"
            )

            # Remove leftovers so a plain CREATE (without OR REPLACE) succeeds.
            self._try_drop_procedure(cursor, proc_name)

            try:
                cursor.execute(test_procedure)
                if test_params:
                    placeholders = ", ".join(
                        f":{position}" for position in range(1, len(test_params) + 1)
                    )
                    cursor.execute(
                        f"BEGIN {proc_name}({placeholders}); END;",
                        list(test_params.values()),
                    )
                else:
                    cursor.execute(f"BEGIN {proc_name}; END;")
            finally:
                # Clean up even on failure. A procedure that compiled with
                # errors still exists (INVALID) and must be removed too.
                self._try_drop_procedure(cursor, proc_name)

            return {
                "status": "success",
                "test_name": test_name,
                "procedure_name": proc_name,
                "message": f"Test '{test_name}' completed successfully",
                "test_params": test_params,
                "timeout_s": timeout_ms // 1000 if timeout_ms else None,
            }

        except cx_Oracle.DatabaseError as e:
            return self._db_error_result(e, "running test")
        except Exception as e:
            self.logger.error(f"Error running test: {e}")
            return {"status": "error", "message": f"Error running test: {str(e)}"}
        finally:
            self._cleanup(conn, cursor)

    # ------------------------------------------------------------------
    # Internal helpers
    # ------------------------------------------------------------------

    @staticmethod
    def _resolve_timeout_ms(timeout_s: Optional[float], default_s: str) -> int:
        """Return the call timeout in ms: ``timeout_s`` if given, else MCP_TIMEOUT_S (default ``default_s``)."""
        seconds = (
            timeout_s
            if timeout_s is not None
            else int(os.getenv("MCP_TIMEOUT_S", default_s))
        )
        # int() keeps callTimeout an integer even for fractional timeout_s.
        return int(seconds * 1000)

    @classmethod
    def _extract_procedure_name(cls, text: str) -> Optional[str]:
        """Return the procedure name from a CREATE PROCEDURE header in ``text``, or None."""
        match = cls._PROC_HEADER_RE.search(text)
        return match.group(1) if match else None

    @staticmethod
    def _normalize_object_name(name: str) -> str:
        """Return ``name`` as Oracle stores it: quotes removed if quoted, else upper-cased."""
        name = name.strip()
        if len(name) >= 2 and name.startswith('"') and name.endswith('"'):
            return name[1:-1]
        return name.upper()

    @staticmethod
    def _prepare_named_binds(cursor: Any, params: Dict[str, Any]) -> Dict[str, Any]:
        """Map ``params`` to named binds, turning ``None`` into a VARCHAR2(4000) variable."""
        # cx_Oracle would otherwise bind None as a 1-character string, so
        # assigning a longer output to it would fail.
        return {
            name: cursor.var(cx_Oracle.STRING, 4000) if value is None else value
            for name, value in params.items()
        }

    @staticmethod
    def _prepare_call_binds(cursor: Any, params: Dict[str, Any]) -> tuple:
        """Build positional binds for ``params``; return (bind list, {name: OUT variable})."""
        # Heuristic: if any supplied value is numeric, every OUT parameter is
        # assumed to be NUMBER. NOTE(review): an OUT VARCHAR2 parameter next
        # to a numeric IN parameter will therefore be bound with the wrong type.
        any_numeric_in = any(
            isinstance(value, (int, float)) for value in params.values() if value is not None
        )
        bind_list: List[Any] = []
        out_vars: Dict[str, Any] = {}
        for name, value in params.items():
            if value is None:
                if any_numeric_in:
                    var = cursor.var(cx_Oracle.NUMBER)
                else:
                    var = cursor.var(cx_Oracle.STRING, 4000)
                out_vars[name] = var
                bind_list.append(var)
            else:
                bind_list.append(value)
        return bind_list, out_vars

    @staticmethod
    def _read_bind_values(cursor: Any) -> Dict[str, Any]:
        """Return {bind name: value after execution} for a cursor executed with named binds."""
        bind_vars = getattr(cursor, "bindvars", None)
        if not isinstance(bind_vars, dict):
            return {}
        return {
            name: var.getvalue()
            for name, var in bind_vars.items()
            if hasattr(var, "getvalue")
        }

    def _process_outputs(self, values: Dict[str, Any]) -> Optional[Dict[str, Any]]:
        """Return ``values`` made JSON-friendly via _process_value, or None when empty."""
        if not values:
            return None
        return {key: self._process_value(value) for key, value in values.items()}

    def _try_drop_procedure(self, cursor: Any, name: str) -> bool:
        """Run ``DROP PROCEDURE name``; return True on success, False (logged) on a driver error."""
        try:
            cursor.execute(f"DROP PROCEDURE {name}")
        except cx_Oracle.Error as exc:
            # Usually ORA-04043 (object does not exist), which is expected here.
            self.logger.debug(f"DROP PROCEDURE {name} failed: {exc}")
            return False
        return True

    def _db_error_result(self, exc: Exception, action: str) -> Dict[str, Any]:
        """Log a database error and return the standard error dictionary for it."""
        # cx_Oracle puts an _Error object (with .message/.code) in args[0];
        # fall back to str(exc) when the exception has a different shape.
        error_obj = exc.args[0] if exc.args else None
        message = getattr(error_obj, "message", None)
        if message is None:
            message = str(exc)
        code = getattr(error_obj, "code", None)
        self.logger.error(f"Database error {action}: {message}")
        return {
            "status": "error",
            "message": f"Database error: {message}",
            "code": code,
        }

    def _cleanup(self, conn: Any, cursor: Any) -> None:
        """Close ``cursor``, reset autocommit and return ``conn`` to the pool (best effort)."""
        if cursor is not None:
            try:
                cursor.close()
            except cx_Oracle.Error as exc:
                self.logger.debug(f"Ignoring error while closing cursor: {exc}")
        if conn is not None:
            try:
                conn.autocommit = False
            except cx_Oracle.Error as exc:
                self.logger.debug(f"Ignoring error while resetting autocommit: {exc}")
            self.pool.release_connection(conn)

    def _process_value(self, value: Any) -> Any:
        """
        Process a database value for JSON serialization.

        LOBs are read (bytes decoded as UTF-8 with replacement), dates and
        datetimes become ISO-8601 strings, and text longer than
        ``self.max_text_size`` is truncated with a "... [truncated]" marker.

        Args:
            value: Raw database value

        Returns:
            Any: Processed value suitable for JSON
        """
        if value is None:
            return None

        if isinstance(value, cx_Oracle.LOB):
            try:
                lob_value = value.read()
                if isinstance(lob_value, bytes):
                    lob_value = lob_value.decode("utf-8", errors="replace")

                if len(lob_value) > self.max_text_size:
                    lob_value = lob_value[: self.max_text_size] + "... [truncated]"

                return lob_value
            except Exception as e:
                self.logger.warning(f"Error reading LOB value: {e}")
                return "[LOB read error]"

        # datetime.datetime is a subclass of datetime.date, so this covers both.
        if isinstance(value, datetime.date):
            return value.isoformat()

        if isinstance(value, str) and len(value) > self.max_text_size:
            return value[: self.max_text_size] + "... [truncated]"

        return value
