#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Fri Jun  2 16:21:21 2023

@author: maslyaev
"""

import numpy as np
import json
import pandas as pd

from itertools import product
from typing import Union

from epde.structure.main_structures import SoEq, Equation
from epde.interface.equation_translator import translate_equation, parse_equation_str

def equations_match(eq_checked: Equation, eq_ref: Equation):
    '''
    Compare the terms of a discovered equation with those of a reference one.
    Returns a tuple (matching, missing, extra): the numbers of reference terms recovered,
    reference terms missed and spurious terms present only in the checked equation.
    '''
    def parse_equation(equation: Equation, eps = 1e-9):
        # Collect the final weight of each term; the target term gets a weight of 1.
        term_weights = []
        for idx, term in enumerate(equation.structure):
            if idx < equation.target_idx:
                term_weights.append(equation.weights_final[idx])
            elif idx == equation.target_idx:
                term_weights.append(1)
            else:
                term_weights.append(equation.weights_final[idx - 1])
        # Keep only the terms with non-negligible weights, each represented as a
        # frozenset of (factor label, power) pairs.
        return {frozenset({(factor.label, factor.param(name = 'power')) 
                           for factor in term.structure}) 
                for idx, term in enumerate(equation.structure) if np.abs(term_weights[idx]) > eps}
    
    checked_parsed = parse_equation(eq_checked)
    ref_parsed = parse_equation(eq_ref)
    
    matching = sum(term in checked_parsed for term in ref_parsed)
    missing = len(ref_parsed) - matching
    extra = sum(term not in ref_parsed for term in checked_parsed)
    return (matching, missing, extra)

def systems_match(checked_system: SoEq, reference_system: SoEq):
    '''
    Apply equations_match to every equation of a system, pairing the equations by
    the variable they describe.
    '''
    return [equations_match(checked_system.vals[var], reference_system.vals[var]) 
            for var in checked_system.vars_to_describe]
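
# A minimal usage sketch (the candidate and reference objects are assumed to come
# from a finished EPDE search and are hypothetical here):
#
#   matching, missing, extra = equations_match(candidate_eq, reference_eq)
#   per_equation_scores = systems_match(candidate_system, reference_system)
#
# A perfect reconstruction of the reference yields missing == 0 and extra == 0
# for every equation of the system.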
        
class Logger():
    '''
    Accumulates the results of equation discovery runs and dumps them into a JSON file.
    '''
    def __init__(self, name, referential_equation = None, pool = None):
        self.reset(name = name)
        if isinstance(referential_equation, (str, dict)):
            if pool is None:
                raise ValueError('Cannot translate equations: the pool of tokens is missing')
            referential_equation = translate_equation(referential_equation, pool)
        self._referential_equation = referential_equation
        
    def reset(self, name = None):
        # Flush the previously accumulated log (if any) to disk before clearing the state.
        if getattr(self, '_log', None):
            self.dump()
        
        self._log = {}
        self._meta = {'aggregation_key' : []}
        if name is not None:
            self.name = name
        
    def dump(self):
        # Write the accumulated log together with its metadata into a JSON file and
        # clear the in-memory containers.
        self._log['meta'] = self._meta
        with open(self.name, 'w') as outfile:
            json.dump(self._log, outfile)
        self._log = {}
        self._meta = {'aggregation_key' : []}
        
    def add_log(self, key, entry, aggregation_key = None, **kwargs):
        # Compare the discovered system with the reference one, if a reference was provided.
        match = systems_match(entry, self._referential_equation) if self._referential_equation is not None else (0, 0, 0)
        try:
            # Cast to built-in floats to keep the log entry JSON-serializable.
            mae = [float(np.mean(np.abs(eq.evaluate(False, True)[0]))) for eq in entry]
        except KeyError:
            mae = 0
        
        log_entry = {'equation_form': entry.text_form,
                     'term_match': match,
                     'mae_train': mae,
                     'aggregation_key': aggregation_key
                     }

        if aggregation_key not in self._meta['aggregation_key']:
            self._meta['aggregation_key'].append(aggregation_key)
        log_entry = {**log_entry, **kwargs}
        self._log[key] = log_entry
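
# A minimal usage sketch of the Logger (the pool, reference equation and discovered
# systems are assumed to come from an EPDE optimization run and are hypothetical here):
#
#   logger = Logger('experiment_log.json', referential_equation=ref_eq_text, pool=pool)
#   for run_idx, system in enumerate(discovered_systems):
#       logger.add_log(key=f'run_{run_idx}', entry=system,
#                      aggregation_key=(noise_level, run_idx))
#   logger.dump()  # writes the accumulated entries to 'experiment_log.json'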
        
class LogLoader(object):
    '''
    Object for the basic analytics of the equation discovery process.
    '''
    def __init__(self, filename: Union[str, list]):
        self.reset()
        filenames = [filename,] if isinstance(filename, str) else filename
        for specific_filename in filenames:
            with open(specific_filename, 'r') as file:
                self._log.append(json.load(file))

    def reset(self):
        self._log = []
        self._variables = None
    
    @staticmethod
    def eq_analytics(equation_string: str):
        '''
        Parse a single equation string into a mapping from the frozenset of a term's
        factors to its coefficient. The right-hand-side (target) term is assigned a
        coefficient of 1, while the last left-hand-side term (the free constant) is skipped.
        '''
        eq_terms = parse_equation_str(equation_string)
        term_stats = {frozenset(eq_terms[-1]) : 1.}
        for term in eq_terms[:-2]:
            term_stats[frozenset(term[1:])] = str(term[0])
        return term_stats
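
    # For a purely illustrative equation string such as
    #   '0.5 * du/dx1{power: 1.0} + 0.01 = du/dx0{power: 1.0}'
    # eq_analytics would return
    #   {frozenset({'du/dx0{power: 1.0}'}): 1.0, frozenset({'du/dx1{power: 1.0}'}): '0.5'},
    # i.e. the left-hand-side coefficients are kept as the strings found in the text.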

    def get_aggregation_keys(self):
        keys = []
        for log in self._log:
            keys.append(log['meta']['aggregation_key'])
        return keys

    def obtain_parsed_log(self, variables: list = ['u',], aggregation_key: tuple = None):
        '''
        Collect, for every variable of the system, the coefficient values of each term
        across all logged experiments that share the passed aggregation key.
        '''
        if self._variables is None:
            self._variables = variables
        else:
            assert self._variables == variables

        def parse_system_str(system_string: str):
            def strap_cases(eq_string: str):
                # Strip the bracket symbols framing the equations of a multi-equation system.
                return eq_string.replace(' / ', '').replace(' | ', '').replace(' \\ ', '')

            if '/' in system_string:
                return [strap_cases(eq_string) for eq_string in system_string.split(sep = '\n')[:-1]]
            else:
                return system_string.split(sep = '\n')[:-1]
        
        term_presence_log = {key : {} for key in range(len(variables))} # replaced self._term_presence_log
        for log_entry in self._log:
            for exp_key, exp_log in log_entry.items():
                if exp_key == 'meta':
                    continue
                
                if 'aggregation_key' not in exp_log or aggregation_key != exp_log['aggregation_key']:
                    continue

                system_list = parse_system_str(exp_log['equation_form'])
                assert len(variables) == len(system_list)
                for eq_idx, eq_string in enumerate(system_list):
                    term_stats = self.eq_analytics(eq_string)
                    for key, value in term_stats.items():
                        term_presence_log[eq_idx].setdefault(key, []).append(float(value))
        return term_presence_log
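
    # The structure returned by obtain_parsed_log has the form (factor labels illustrative):
    #   {0: {frozenset({'du/dx1{power: 1.0}'}): [0.52, 0.49, ...], ...}, 1: {...}, ...}
    # i.e. a per-equation mapping from the factors of a term to the coefficient values
    # observed across the experiments that match the requested aggregation key.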
        
    @staticmethod
    def get_stats(terms: Union[list, tuple], parsed_log: dict, metrics: list = [np.mean, np.var, np.size], 
                  metric_names: list = ['mean', 'var', 'disc_num'], variable_eq: str = 'u', variables = ['u',]):
        '''
        Evaluate the passed metrics on the non-zero coefficient magnitudes of every term
        for the equation of the selected variable. Returns the flat list of labels
        (term-metric combinations) and the matching flat array of statistic values.
        '''
        assert all(isinstance(term, frozenset) for term in terms)
        stats = []
        for term in terms:
            if term in parsed_log[variables.index(variable_eq)].keys():
                term_coeff_vals = np.array(parsed_log[variables.index(variable_eq)][term])
                stats.append([metric(np.abs(term_coeff_vals[term_coeff_vals != 0])) 
                              for metric in metrics])
            else:
                # The term never appeared in the logs matching this aggregation key.
                stats.append([np.nan] * len(metrics))

        stats = np.array(stats).reshape(-1)

        label_sep = '_'; term_sep = '*'
        term_names = [term_sep.join(term) for term in terms]
        labels = [label_sep.join(map(str, x)) for x in product(term_names, metric_names)]

        return labels, stats
    
    def to_pandas(self, terms: Union[list, tuple], metrics: list = [np.mean, np.var, np.size], metric_names: list = ['mean', 'var', 'disc_num'], 
                  variable: str = 'u', variables: list = ['u',]):
        '''
        Assemble the term statistics into one pandas DataFrame per loaded log file,
        with a row for every aggregation key found in that file.
        '''
        metric_frames = []
        for log_entry in self._log:
            aggregation_keys = log_entry['meta']['aggregation_key']
            data = []
            row_labels = ['_'.join(map(str, key)) for key in aggregation_keys]
            for aggr_key in aggregation_keys:
                parsed_log = self.obtain_parsed_log(variables = variables, aggregation_key = aggr_key)
                keys, stats = self.get_stats(terms = terms, parsed_log = parsed_log, metrics = metrics, metric_names = metric_names, 
                                             variable_eq = variable, variables = variables)
                data.append(dict(zip(keys, stats)))

            metric_frames.append(pd.DataFrame(data, index = row_labels))
        return metric_frames
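
if __name__ == '__main__':
    # A minimal post-hoc analysis sketch; the log file name and the terms of interest
    # are hypothetical and depend on a previously dumped Logger output.
    loader = LogLoader('experiment_log.json')
    terms_of_interest = [frozenset({'du/dx1{power: 1.0}',}),
                         frozenset({'d^2u/dx1^2{power: 1.0}',})]
    frames = loader.to_pandas(terms = terms_of_interest, variable = 'u', variables = ['u',])
    for frame in frames:
        print(frame)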