#  @author Silia Taider
#  @date 2023-08

################################################################################
# Copyright (C) 1995-2023, Rene Brun and Fons Rademakers.                      #
# All rights reserved.                                                         #
#                                                                              #
# For the licensing terms see $ROOTSYS/LICENSE.                                #
# For the list of contributors see $ROOTSYS/README/CREDITS.                    #
################################################################################
from __future__ import annotations

from typing import Dict, Optional, Callable, Any
import ast
import inspect
import warnings
from functools import singledispatch, partial


@singledispatch   
def LiveVisualize(drawable_callback_dict: Dict[Any, Optional[Callable]], 
                  global_callback: Optional[Callable] = None) -> None:
    """
    Enables real-time data representation for the given drawable objects.
    The objects are drawn and updated every time a partial result returns from distributed execution.

    Args:
        drawable_callback_dict (dict): A dictionary where keys are drawable objects 
        and values are optional corresponding callback functions. 

        global_callback (function): An optional global callback function that 
            is applied to all drawable objects.

    Raises:
        ValueError: If a passed drawable object is not valid.
    """
    # Check if the objects already have a value (the computation graph has already been triggered)
    if any(obj.proxied_node.value for obj in drawable_callback_dict):
        warnings.warn("LiveVisualize should be called before triggering the computation graph. Skipping live visualization.")
        return

    # Check if all drawables share the same headnode
    unique_headnodes = list({obj.proxied_node.get_head() for obj in drawable_callback_dict})
    if len(unique_headnodes) > 1:
        raise ValueError("Live visualization is not supported for operations belonging to different RDataFrame graphs.")
    
    global_callback_validated = process_callback(global_callback)
    drawable_id_callback_dict = {
        # Key: node_id of the drawable object's proxied_node
        # Value: List of validated callback functions for the drawable object
        obj.proxied_node.node_id: [process_callback(callback), global_callback_validated]
        for obj, callback in drawable_callback_dict.items()
        # Filter: Only include valid drawable objects
        if is_valid_drawable(obj)
    }

    unique_headnodes[0].drawables_dict = drawable_id_callback_dict


@LiveVisualize.register(list)
@LiveVisualize.register(tuple)
def _1(drawables, callback: Optional[Callable] = None) -> None:
    """
    Wrapper function to facilitate calling LiveVisualize with a list or a tuple of drawable objects.

    Args:
        drawables (list | tuple): Drawable objects to visualize.
		
        callback (function): An optional callback function to be applied to the drawable objects.

    Notes:
        This function constructs a dictionary of drawable objects and their associated callback functions,
        and then calls the main LiveVisualize function with the constructed dictionary.
    """
    if callback is None:
        drawable_callback_dict = {obj: None for obj in drawables}
    else:
        drawable_callback_dict = {obj: callback for obj in drawables}

    LiveVisualize(drawable_callback_dict)


def process_callback(callback: Callable) -> Callable:
    """
    Process and validate a callback function.

    Args:		
        callback: The callback function to be validated.

    Returns:
        validated_callback: The validated callback function, or None if not valid.
    """
    if callback is None:
        return None
		
    elif not callable(callback):
        warnings.warn("The provided callback is not callable. Skipping callback.")
        return None

    elif not has_correct_argument_count(callback):
        warnings.warn("The callback function should have exactly one parameter to fill. Skipping callback.")
        return None
    
    elif not is_callback_safe(callback):
        warnings.warn("The provided callback function contains blocked actions. Skipping callback.")
        return None    
        
    return callback


def has_correct_argument_count(callback: Callable) -> bool:
    """
    Checks if the provided callback function has exactly one unfilled argument.

    Args:
        callback (Callable): The callback function to check.

    Returns:
        bool: True if the callback function has exactly one unfilled argument,
              False otherwise.
    """
    # Get the values of the functions parameters
    unfilled_parameters = [param for param in inspect.signature(callback).parameters.values() if param.default == param.empty]
    if len(unfilled_parameters) != 1:
            return False
    
    return True


def is_callback_safe(callback: Callable) -> bool:
    """
    Checks if the provided callback function is safe for live visualization, 
        (does not contain blocked actions).

    Args:
        callback (function): The callback function to check.

    Returns:
        bool: True if the callback function is safe, 
            False otherwise.
    """
    # Parse the callback function's source code
    if isinstance(callback, partial):
            callback = callback.func
    callback_source_ast = ast.parse(inspect.getsource(callback))
    for node in ast.walk(callback_source_ast):
        if is_action_blocked(node):
            return False
        
    return True


def is_action_blocked(node: ast.AST) -> bool:
    """
    Checks if the given Abstract Syntax Tree (AST) node corresponds to a blocked action.

    Args:
        node (ast.AST): The AST node to check.

    Returns:
        bool: True if the AST node corresponds to a blocked action,
              False otherwise.
    """
    BLOCKED_ACTIONS = [ "Add", "AddBinContent", "BufferFill", "Build",  "ClearUnderflowAndOverflow", 
                       "Delete", "Divide", "DoFillN", "Fill", "FillN",  "FillRandom", "LabelsDeflate", 
                       "Merge",  "Multiply", "Rebin", "Rebuild", "RecursiveRemove", "Reset", "Scale", 
                       "SetBinContent", "SetBinError", "SetBins", "SetBinsLength",  "SetBuffer", 
                       "SetCellContent", "SetCellError", "SetContent", "SetDirectory", "SetEntries", 
                       "SetError", "SetMaximum", "SetMinimum",  "Smooth", "TransformHisto", 
                       "UpdateBinContent"]

    if isinstance(node, ast.Call) and isinstance(node.func, ast.Attribute) and node.func.attr in BLOCKED_ACTIONS:
                return True
            
    return False

	
def is_valid_drawable(obj: Any) -> bool:
    """
    Checks if the object is a valid drawable object for live visualization.

    Args:
        obj: The object to be checked.

    Returns:
        bool: True if the object is a valid drawable object for live visualization 
            according to the ALLOWED_OPERATIONS list , False otherwise.
    """
    ALLOWED_OPERATIONS = ["Histo1D", "Histo2D", "Histo3D",
                          "Graph", "GraphAsymmErrors",
                          "Profile1D", "Profile2D"]
    
    if hasattr(obj, "proxied_node") and hasattr(obj.proxied_node, "operation") and hasattr(obj.proxied_node.operation, "name"):
        if obj.proxied_node.operation.name in ALLOWED_OPERATIONS:
            return True
    
    raise ValueError(f"Allowed operations are: {ALLOWED_OPERATIONS}. Skipping live visualization.")