import os
import pickle
import datetime
import numpy as np
import matplotlib.pyplot as plt
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.cluster import KMeans
from sklearn.metrics import silhouette_score
from tqdm import tqdm

# Packages this module depends on; every entry corresponds to an import at the
# top of the file. NOTE(review): nothing in this file reads the list, so it is
# presumably consumed by an external dependency checker -- confirm.
REQUIRED_PACKAGES = [
    "numpy",
    "matplotlib",
    "pickle",
    "tqdm",
    "sklearn",
    "os",
    "datetime",
]

def _dump_pickle(obj, path):
    """Pickle ``obj`` to ``path`` (binary mode) and return ``path``."""
    with open(path, "wb") as f:
        pickle.dump(obj, f)
    return path


def _save_k_plot(path, k_values, scores, best_k, ylabel, title, best_label, color=None):
    """Save a line plot of ``scores`` vs. ``k_values`` with ``best_k`` marked; return ``path``."""
    line_kwargs = {"marker": "o"}
    # Only forward a colour when one was requested so the default colour cycle
    # is used otherwise.
    if color is not None:
        line_kwargs["color"] = color

    fig = plt.figure(figsize=(8, 5))
    plt.plot(k_values, scores, **line_kwargs)
    plt.xlabel("Number of Clusters (k)")
    plt.ylabel(ylabel)
    plt.title(title)
    plt.axvline(best_k, color="red", linestyle="--", label=best_label)
    plt.grid()
    plt.legend()
    plt.savefig(path)
    # Close this specific figure so repeated calls do not accumulate figures.
    plt.close(fig)
    return path


def explore_kmeans_clustering(
    ngram_texts,
    path_dset,
    dset_name,
    min_clusters=2,
    max_clusters=10,
    step=1,
    max_features=5000,
    min_df=3,
    max_df=0.8
):
    """
    Perform K-means clustering with TF-IDF vectors over a range of k values (clusters).
    Computes both Inertia (Elbow method) and Silhouette scores for each k, but selects
    the final model based on the best (lowest) Inertia.

    Note:
        Inertia typically keeps decreasing as k grows, so "lowest inertia" will
        usually select the largest k that was evaluated. The silhouette-based
        optimum is reported alongside (``best_k_silhouette``) for comparison, and
        the saved elbow plot should be inspected before trusting ``final_k``.

    Args:
        ngram_texts (list of list of str): Tokenized text data (e.g., unigrams + bigrams).
        path_dset (str): Output directory path where files are saved. Created if missing.
        dset_name (str): Prefix or identifier for saved files.
        min_clusters (int): Minimum number of clusters to evaluate (>= 1).
        max_clusters (int): Maximum number of clusters to evaluate (inclusive).
        step (int): Step size for the range of clusters (>= 1).
        max_features (int): Maximum vocabulary size for the TF-IDF vectorizer.
        min_df (int or float): When building the vocabulary, ignore terms that occur in
                               fewer than min_df documents (int) or in less than min_df
                               fraction (float). Defaults to 3.
        max_df (float): Ignore terms that occur in more than max_df fraction of documents.
                        Can be an int for an absolute cutoff. Defaults to 0.8.

    Returns:
        dict: Paths and info about saved artifacts. Keys: ``model_path``,
        ``vectorizer_path``, ``cluster_assignments_path``, ``scores_path``,
        ``elbow_plot_path``, ``silhouette_plot_path``, ``best_k_inertia``,
        ``best_k_silhouette``, ``final_k``, ``inertia_values``,
        ``silhouette_values`` and ``cluster_range`` (list of evaluated k).
        ``silhouette_values`` holds ``np.nan`` for every k whose labelling does
        not admit a silhouette score.

    Raises:
        ValueError: If ``step`` < 1, ``min_clusters`` < 1, or the resulting range
            of k values is empty. Errors raised by the vectorizer or by K-means
            itself (e.g. empty vocabulary, more clusters than samples) propagate
            unchanged.
    """
    # Validate the requested k range up front, before any fitting or file I/O,
    # so misconfiguration fails fast with an actionable message.
    if step < 1:
        raise ValueError(f"step must be >= 1, got {step}")
    if min_clusters < 1:
        raise ValueError(f"min_clusters must be >= 1, got {min_clusters}")
    k_values = list(range(min_clusters, max_clusters + 1, step))
    if not k_values:
        raise ValueError(
            "Empty cluster range: min_clusters="
            f"{min_clusters} is greater than max_clusters={max_clusters}"
        )

    # Timestamp (minute resolution) used to tag every artifact of this run
    timestamp = datetime.datetime.now().strftime("%Y%m%dT%H%M")

    # Create output directory if it doesn't exist
    output_dir = os.fspath(path_dset)
    os.makedirs(output_dir, exist_ok=True)

    # Convert each tokenized doc into a single string for TF-IDF
    docs_as_strings = [" ".join(doc) for doc in ngram_texts]

    # Build TF-IDF matrix with min_df and max_df
    vectorizer = TfidfVectorizer(
        max_features=max_features,
        min_df=min_df,
        max_df=max_df
    )
    tfidf_matrix = vectorizer.fit_transform(docs_as_strings)
    n_samples = tfidf_matrix.shape[0]

    # Store both inertia and silhouette for each k
    inertia_values = []
    silhouette_values = []
    final_kmeans = None

    # Fit KMeans once per k, record both metrics, and keep the lowest-inertia model
    for k in tqdm(k_values, desc="K-means Clustering"):
        kmeans = KMeans(n_clusters=k, random_state=42)
        kmeans.fit(tfidf_matrix)

        inertia_values.append(kmeans.inertia_)

        # Strict "<" keeps the first model on ties (same choice np.argmin would
        # make). Reusing it avoids refitting an identical model afterwards
        # (same data, same n_clusters, same random_state).
        if final_kmeans is None or kmeans.inertia_ < final_kmeans.inertia_:
            final_kmeans = kmeans

        # silhouette_score needs 2 <= n_labels <= n_samples - 1 and raises
        # otherwise; record NaN for degenerate labellings (k == 1, every point
        # its own cluster, or clusters collapsing to a single label).
        labels = kmeans.labels_
        n_labels = np.unique(labels).size
        if 2 <= n_labels < n_samples:
            sil = silhouette_score(tfidf_matrix, labels)
        else:
            sil = np.nan
        silhouette_values.append(sil)

    # Final K determined by Inertia (lowest value over the evaluated range)
    best_k_inertia = final_kmeans.n_clusters
    final_k = best_k_inertia

    # Best k for Silhouette (max silhouette), purely for reference/plot;
    # falls back to the inertia choice when no silhouette could be computed.
    scored = [(k, sil) for k, sil in zip(k_values, silhouette_values) if not np.isnan(sil)]
    if scored:
        best_k_silhouette = max(scored, key=lambda pair: pair[1])[0]
    else:
        best_k_silhouette = best_k_inertia

    # Cluster id of each document under the selected model
    cluster_assignments = final_kmeans.labels_

    def artifact_path(kind, ext):
        """Build '<output_dir>/<dset_name>_<kind>_<timestamp>.<ext>'."""
        return os.path.join(output_dir, f"{dset_name}_{kind}_{timestamp}.{ext}")

    # --- Save artifacts ---
    model_path = _dump_pickle(final_kmeans, artifact_path("kmeans_model", "pkl"))
    vectorizer_path = _dump_pickle(vectorizer, artifact_path("tfidf_vectorizer", "pkl"))
    cluster_assignments_path = _dump_pickle(
        cluster_assignments, artifact_path("cluster_assignments", "pkl")
    )
    scores_path = _dump_pickle(
        {"inertia": inertia_values, "silhouette": silhouette_values},
        artifact_path("cluster_scores", "pkl"),
    )

    # --- Create two separate plots: Elbow Plot & Silhouette Plot ---
    elbow_plot_path = _save_k_plot(
        artifact_path("elbow_plot", "png"),
        k_values,
        inertia_values,
        best_k_inertia,
        ylabel="Inertia (Sum of Squares)",
        title="Elbow Method: Inertia vs. Number of Clusters",
        best_label=f"Best k (Inertia) = {best_k_inertia}",
    )
    silhouette_plot_path = _save_k_plot(
        artifact_path("silhouette_plot", "png"),
        k_values,
        silhouette_values,
        best_k_silhouette,
        ylabel="Silhouette Score",
        title="Silhouette Score vs. Number of Clusters",
        best_label=f"Best k (Silhouette) = {best_k_silhouette}",
        color="green",
    )

    # Return paths/info
    return {
        "model_path": model_path,
        "vectorizer_path": vectorizer_path,
        "cluster_assignments_path": cluster_assignments_path,
        "scores_path": scores_path,
        "elbow_plot_path": elbow_plot_path,
        "silhouette_plot_path": silhouette_plot_path,
        "best_k_inertia": best_k_inertia,
        "best_k_silhouette": best_k_silhouette,
        "final_k": final_k,
        "inertia_values": inertia_values,
        "silhouette_values": silhouette_values,
        "cluster_range": k_values,
    }