"""Utility functions for clusterings and embeddings
"""

from __init__scr import *
# from sklearn.decomposition import PCA
import time
import numpy as np
import pandas as pd
import igraph as ig
from scipy import sparse
from annoy import AnnoyIndex
import leidenalg
#from umap import UMAP
import umap.umap_ as UMAP
import logging

from utils import create_logger

# major change in annoy functions 5/7/2019 
def build_knn_map(X, metric='euclidean', n_trees=10, verbose=True):
    """Build an Annoy index over the rows of X.

    X is expected to have low feature dimensions (n_obs, n_features) with (n_features <= 50)

    Arguments:
        X: 2-dimensional array; row i is stored in the index under item id i
        metric: metric name handed to AnnoyIndex
        n_trees: number of trees handed to the index's build()
        verbose: if True, print the time spent building the index

    return:
         the built annoy knn object, can be used in the following ways
                get_nns_by_vector
                get_nns_by_item
    """
    started = time.time()

    # unpacking into two names requires X to be exactly 2-dimensional
    _, n_features = X.shape
    index = AnnoyIndex(n_features, metric=metric)  # n_features = length of each indexed vector
    for item_id, row in enumerate(X):
        index.add_item(item_id, row)
    index.build(n_trees)

    if verbose:
        print("Time used to build kNN map {}".format(time.time()-started))
    return index

def get_knn_by_items(t, k,
    form='list',
    search_k=-1,
    include_distances=False,
    verbose=True,
    ):
    """Get kNN for each item in the knn map t

    Arguments:
        t: knn map providing get_n_items() and get_nns_by_item()
           (for example the object returned by build_knn_map)
        k: number of neighbors requested per item; lowered to the number of
           indexed items (with a printed notice) when it is larger
        form: 'list' or 'adj' (see return)
        search_k: forwarded unchanged to t.get_nns_by_item
        include_distances: if True, also collect the distances reported by t
        verbose: if True, print the time spent querying

    return:
        form == 'list': array of neighbor ids, one row per item; a tuple
            (ids, distances) when include_distances is True
        form == 'adj': sparse COO matrix of shape (n_items, n_items) with one
            entry per (item, neighbor) pair; the entry is the distance when
            include_distances is True, otherwise 1

    raise:
        ValueError if form is neither 'adj' nor 'list'
    """
    # Reject a bad `form` before running any (potentially slow) query.
    if form not in ('adj', 'list'):
        raise ValueError("Choose from 'adj' and 'list'")

    ti = time.time()
    n_obs = t.get_n_items()
    if k > n_obs:
        print("Actual k: {}->{} due to low n_obs".format(k, n_obs))
        k = n_obs

    knn = [0]*n_obs
    knn_dist = [0]*n_obs
    for i in range(n_obs):
        res = t.get_nns_by_item(i, k, search_k=search_k, include_distances=include_distances)
        if include_distances:
            knn[i] = res[0]
            knn_dist[i] = res[1]
        else:
            knn[i] = res

    if verbose:
        print("Time used to get kNN {}".format(time.time()-ti))

    if form == 'list':
        if include_distances:
            return np.array(knn), np.array(knn_dist)
        return np.array(knn)

    # form == 'adj'
    # Row indices follow the number of neighbors actually returned for each
    # item, so a result list shorter than k does not misalign rows and columns.
    counts = np.array([len(neighbors) for neighbors in knn], dtype=int)
    row_inds = np.repeat(np.arange(n_obs), counts)
    if n_obs:
        col_inds = np.concatenate(knn).astype(int, copy=False)
    else:
        col_inds = np.empty(0, dtype=int)

    if include_distances:
        data = np.concatenate(knn_dist) if n_obs else np.empty(0)
    else:
        data = np.ones(len(row_inds), dtype=int)
    return sparse.coo_matrix((data, (row_inds, col_inds)), shape=(n_obs, n_obs))

def get_knn_by_vectors(t, X, k,
    form='list',
    search_k=-1,
    include_distances=False,
    verbose=True,
    ):
    """Get kNN for each row vector of X

    Arguments:
        t: knn map providing get_n_items(), get_nns_by_vector() and the
           attribute f (length of the indexed vectors)
        X: 2-dimensional array of query vectors, one per row; its number of
           columns has to equal t.f
        k: number of neighbors requested per query; lowered to the number of
           indexed items (with a printed notice) when it is larger
        form: 'list' or 'adj' (see return)
        search_k: forwarded unchanged to t.get_nns_by_vector
        include_distances: if True, also collect the distances reported by t
        verbose: if True, print the time spent querying

    return:
        form == 'list': array of neighbor ids, one row per query; a tuple
            (ids, distances) when include_distances is True
        form == 'adj': sparse COO matrix of shape (n_queries, n_items) with one
            entry per (query, neighbor) pair; the entry is the distance when
            include_distances is True, otherwise 1

    raise:
        ValueError if form is neither 'adj' nor 'list', or if the number of
        columns of X differs from t.f
    """
    # Reject a bad `form` before running any (potentially slow) query.
    if form not in ('adj', 'list'):
        raise ValueError("Choose from 'adj' and 'list'")

    ti = time.time()
    n_obs = t.get_n_items()
    n_f = t.f
    n_obs_test, n_f_test = X.shape
    # Explicit raise instead of assert: asserts are stripped under `python -O`.
    if n_f_test != n_f:
        raise ValueError(
            "X has {} features but the knn map was built with {}".format(n_f_test, n_f))

    if k > n_obs:
        print("Actual k: {}->{} due to low n_obs".format(k, n_obs))
        k = n_obs

    knn = [0]*n_obs_test
    knn_dist = [0]*n_obs_test
    for i, vector in enumerate(X):
        res = t.get_nns_by_vector(vector, k, search_k=search_k, include_distances=include_distances)
        if include_distances:
            knn[i] = res[0]
            knn_dist[i] = res[1]
        else:
            knn[i] = res

    if verbose:
        print("Time used to get kNN {}".format(time.time()-ti))

    if form == 'list':
        if include_distances:
            return np.array(knn), np.array(knn_dist)
        return np.array(knn)

    # form == 'adj'
    # Row indices follow the number of neighbors actually returned for each
    # query, so a result list shorter than k does not misalign rows and columns.
    counts = np.array([len(neighbors) for neighbors in knn], dtype=int)
    row_inds = np.repeat(np.arange(n_obs_test), counts)
    if n_obs_test:
        col_inds = np.concatenate(knn).astype(int, copy=False)
    else:
        col_inds = np.empty(0, dtype=int)

    if include_distances:
        data = np.concatenate(knn_dist) if n_obs_test else np.empty(0)
    else:
        data = np.ones(len(row_inds), dtype=int)
    return sparse.coo_matrix((data, (row_inds, col_inds)), shape=(n_obs_test, n_obs))

def gen_knn_annoy(X, k, form='list', 
    metric='euclidean', n_trees=10, search_k=-1, verbose=True, 
    include_distances=False,
    ):
    """Build a knn map over the rows of X, then return the kNN of every row.

    X is expected to have low feature dimensions (n_obs, n_features) with (n_features <= 50)

    metric, n_trees and verbose go to build_knn_map; k, form, search_k,
    include_distances and verbose go to get_knn_by_items, whose result is
    returned unchanged.
    """
    # unpacking into two names requires X to be exactly 2-dimensional
    _n_rows, _n_cols = X.shape

    knn_map = build_knn_map(X, metric=metric, n_trees=n_trees, verbose=verbose)
    query_options = dict(
        form=form,
        search_k=search_k,
        include_distances=include_distances,
        verbose=verbose,
    )
    return get_knn_by_items(knn_map, k, **query_options)

def gen_knn_annoy_train_test(X_train, X_test, k, 
    form='list', 
    metric='euclidean', n_trees=10, search_k=-1, verbose=True, 
    include_distances=False,
    ):
    """For each row in X_test, find k nearest neighbors in X_train

    X is expected to have low feature dimensions (n_obs, n_features) with (n_features <= 50)

    A knn map is built on X_train with build_knn_map and queried with the
    rows of X_test through get_knn_by_vectors, whose result is returned
    unchanged. Both inputs must be 2-dimensional with the same number of
    columns.
    """
    _, n_features_train = X_train.shape
    _, n_features_test = X_test.shape
    assert n_features_train == n_features_test

    knn_map = build_knn_map(X_train, metric=metric, n_trees=n_trees, verbose=verbose)
    query_options = dict(
        form=form,
        search_k=search_k,
        include_distances=include_distances,
        verbose=verbose,
    )
    return get_knn_by_vectors(knn_map, X_test, k, **query_options)
    
def compute_jaccard_weights_from_knn(X):
    """compute jaccard index on a knn graph

    For every pair of rows (i, j) that share at least one neighbor, the weight
    is |N_i & N_j| / |N_i | N_j|, where N_i is the set of columns marked in
    row i. The size of the union is computed from the per-row neighbor counts,
    so rows do not need to have the same number of neighbors; when they all
    have k neighbors this equals shared / (2*k - shared).

    Arguments:
        X (unweighted) kNN adjacency matrix (each row Xi* gives the kNNs of cell i)
        X has to be 0-1 valued and square

    output: sparse CSR matrix Y of jaccard weights, same shape as X

    raise:
        ValueError if X is not square
    """
    X = sparse.csr_matrix(X)
    ni, nj = X.shape
    if ni != nj:
        raise ValueError('kNN adjacency matrix should be a square matrix')

    # number of neighbors of each row
    degrees = np.asarray(X.sum(axis=1)).ravel()

    # Y[i, j] = number of neighbors shared by rows i and j
    Y = sparse.csr_matrix(X.dot(X.T))

    # row index of every stored entry, in the same order as Y.data / Y.indices
    rows = np.repeat(np.arange(ni), np.diff(Y.indptr))
    cols = Y.indices
    # |A | B| = |A| + |B| - |A & B|
    Y.data = Y.data/(degrees[rows] + degrees[cols] - Y.data)

    return Y

def adjacency_to_igraph(adj_mtx, weighted=False):
    """
    Converts an adjacency matrix to a directed igraph object

    One edge i -> j is created for every stored entry (i, j) of the matrix
    whose value is not zero.

    Args:
        adj_mtx (sparse matrix): Square adjacency matrix
        weighted (bool): If True, the non-zero matrix values are stored as the
            'weight' attribute of the corresponding edges

    Returns:
        G (igraph object): igraph object of adjacency matrix

    Raises:
        ValueError: If adj_mtx is not square

    Uses code from:
        https://github.com/igraph/python-igraph/issues/168
        https://stackoverflow.com/questions/29655111

    Author:
        Wayne Doyle 
        (Fangming Xie modified) 
    """
    nrow, ncol = adj_mtx.shape
    if nrow != ncol:
        raise ValueError('Adjacency matrix should be a square matrix')
    vcount = nrow

    # Take edges and weights from the same COO view with the same mask, so the
    # i-th weight always belongs to the i-th edge even when the matrix stores
    # explicit zeros (which are not turned into edges).
    coo = sparse.coo_matrix(adj_mtx)
    nonzero_mask = coo.data != 0
    sources = coo.row[nonzero_mask]
    targets = coo.col[nonzero_mask]

    edgelist = list(zip(sources.tolist(), targets.tolist()))
    G = ig.Graph(n=vcount, edges=edgelist, directed=True)
    if weighted:
        G.es['weight'] = coo.data[nonzero_mask].tolist()
    return G

def leiden_lite(g, cell_list, resolution=1, weighted=False, verbose=True, num_starts=None, seed=1):
    """Cluster the vertices of graph g with leidenalg (RBConfigurationVertexPartition).

    Code from Ethan Armand and Wayne Doyle, ./mukamel_lab/mop
    slightly modified by Fangming Xie 05/13/2019

    Arguments:
        g: igraph graph; needs a 'weight' edge attribute when weighted is True
        cell_list: labels used as the index of the result; vertex id v of g
            is matched to cell_list[v]
        resolution: passed as resolution_parameter to the partition
        weighted: if True, g.es['weight'] is passed as edge weights
        verbose: if True, print the time spent clustering
        num_starts: if None, run once with `seed`. Otherwise draw num_starts
            seeds from a random generator seeded with `seed`, run once per
            drawn seed, and keep the partition with the highest quality()
        seed: seed of the single run, or of the generator that draws the
            per-run seeds when num_starts is given

    return:
        DataFrame indexed by cell_list (index named 'sample') with one column
        'cluster' holding 1-based cluster labels; an entry not covered by the
        partition keeps the label 0

    raise:
        ValueError if num_starts is given but smaller than 1
    """
    ti = time.time()

    # Options shared by every find_partition call; weights are only passed
    # for weighted graphs.
    partition_kwargs = {'resolution_parameter': resolution}
    if weighted:
        partition_kwargs['weights'] = g.es['weight']

    if num_starts is not None:
        if num_starts < 1:
            raise ValueError("num_starts must be at least 1, got {}".format(num_starts))
        # A private generator yields the same seeds as seeding the global
        # numpy generator would, without resetting the global random state.
        rng = np.random.RandomState(seed)
        start_seeds = rng.randint(10*num_starts, size=num_starts)

        partitions = []
        quality = []
        for start_seed in start_seeds:
            temp_partition = leidenalg.find_partition(
                g,
                leidenalg.RBConfigurationVertexPartition,
                seed=int(start_seed),
                **partition_kwargs
            )
            quality.append(temp_partition.quality())
            partitions.append(temp_partition)
        partition1 = partitions[np.argmax(quality)]
    else:
        partition1 = leidenalg.find_partition(
            g,
            leidenalg.RBConfigurationVertexPartition,
            seed=seed,
            **partition_kwargs
        )

    # get cluster labels from partition1 (labels start at 1)
    labels = [0]*(len(cell_list))
    for i, cluster in enumerate(partition1):
        for element in cluster:
            labels[element] = i+1

    df_res = pd.DataFrame(index=cell_list)
    df_res['cluster'] = labels
    df_res = df_res.rename_axis('sample', inplace=False)

    if verbose:
        print("Time spent on leiden clustering: {}".format(time.time()-ti))

    return df_res

def clustering_routine(X, cell_list, k, 
    seed=1, verbose=True,
    resolution=1, metric='euclidean', option='plain', n_trees=10, search_k=-1, num_starts=None):
    """kNN graph construction followed by leiden clustering.

    X is a (n_obs, n_feature) matrix, n_feature <=50 is recommended
    cell_list has one entry per row of X
    option: {'plain', 'jaccard'}
        'plain'   cluster the unweighted kNN graph
        'jaccard' weight the kNN graph with compute_jaccard_weights_from_knn first

    Returns whatever leiden_lite returns for the resulting graph.
    Raises ValueError for any other option.
    """
    assert len(cell_list) == len(X)

    if option not in ('plain', 'jaccard'):
        raise ValueError('Choose from "plain" and "jaccard"')
    use_weights = (option == 'jaccard')

    adjacency = gen_knn_annoy(X, k, form='adj', metric=metric,
                              n_trees=n_trees, search_k=search_k, verbose=verbose)
    if use_weights:
        adjacency = compute_jaccard_weights_from_knn(adjacency)

    graph = adjacency_to_igraph(adjacency, weighted=use_weights)
    return leiden_lite(graph, cell_list, resolution=resolution, seed=seed,
                       weighted=use_weights, verbose=verbose, num_starts=num_starts)

def clustering_routine_multiple_resolutions(X, cell_list, k, 
    seed=1, verbose=True,
    resolutions=(1,), metric='euclidean', option='plain', n_trees=10, search_k=-1, num_starts=None):
    """kNN graph construction followed by leiden clustering at several resolutions.

    The kNN graph is built once and reused for every resolution.

    X is a (n_obs, n_feature) matrix, n_feature <=50 is recommended
    cell_list has one entry per row of X
    resolutions: iterable of resolution values, at least one
    option: {'plain', 'jaccard'}
        'plain'   cluster the unweighted kNN graph
        'jaccard' weight the kNN graph with compute_jaccard_weights_from_knn first

    Returns the leiden_lite results concatenated column-wise, with the
    'cluster' column of each run renamed to 'cluster_r<resolution>'.
    Raises ValueError for any other option or when resolutions is empty.
    """
    assert len(cell_list) == len(X)

    if option not in ('plain', 'jaccard'):
        raise ValueError('Choose from "plain" and "jaccard"')

    # Materialize once and check before the expensive graph construction,
    # instead of failing afterwards when there is nothing to concatenate.
    resolutions = list(resolutions)
    if not resolutions:
        raise ValueError('resolutions must contain at least one value')

    weighted = (option == 'jaccard')

    g_knn = gen_knn_annoy(X, k, form='adj', metric=metric,
                          n_trees=n_trees, search_k=search_k, verbose=verbose)
    if weighted:
        g_knn = compute_jaccard_weights_from_knn(g_knn)
    G = adjacency_to_igraph(g_knn, weighted=weighted)

    res = []
    for resolution in resolutions:
        df_res = leiden_lite(G, cell_list, resolution=resolution, seed=seed,
                             weighted=weighted, verbose=verbose, num_starts=num_starts)
        df_res = df_res.rename(columns={'cluster': 'cluster_r{}'.format(resolution)})
        res.append(df_res)

    return pd.concat(res, axis=1)

def run_umap_lite(X, cell_list, n_neighbors=15, min_dist=0.1, n_dim=2, 
             random_state=1, output_file=None, **kwargs):
    """run umap on X (n_obs, n_features)

    Arguments:
        X: input matrix handed to the UMAP fit_transform
        cell_list: one label per row of X, used as the index of the result
        n_neighbors, min_dist, random_state: passed to the UMAP constructor
        n_dim: number of embedding dimensions (passed as n_components)
        output_file: if given, the coordinates are also written to this path
            as a tab-separated table
        **kwargs: any further options for the UMAP constructor

    return:
        DataFrame indexed by cell_list (index named 'sample') with columns
        umap_1 ... umap_<n_dim>
    """
    ti = time.time()

    logging.info("Running UMAP: {} n_neighbors, {} min_dist , {} dim.\n\
                  Input shape: (# observations, # features) = {}"
                        .format(n_neighbors, min_dist, n_dim, X.shape))

    reducer = UMAP.UMAP(n_components=n_dim, random_state=random_state,
                n_neighbors=n_neighbors, min_dist=min_dist, **kwargs)
    ts = reducer.fit_transform(X)

    columns = ['umap_{}'.format(i+1) for i in range(n_dim)]
    # Build the index directly from the labels: assigning a pandas Series as a
    # column would align it on its own index and could leave NaN sample ids.
    df_umap = pd.DataFrame(ts, columns=columns,
                           index=pd.Index(cell_list, name='sample'))

    if output_file:
        df_umap.to_csv(output_file, sep="\t", na_rep='NA', header=True, index=True)
        logging.info("Saved coordinates to file. {}".format(output_file))

    tf = time.time()
    logging.info("Done. running time: {} seconds.".format(tf - ti))

    return df_umap
