import os
import sys
import subprocess
import pickle
import datetime
import numpy as np
import matplotlib.pyplot as plt
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.cluster import KMeans
from sklearn.metrics import silhouette_score
from tqdm import tqdm

# Package names checked and, if missing, passed to `pip install` by
# install_packages(). These are pip distribution names, which do not always
# match the importable module name: "scikit-learn" is imported as "sklearn".
REQUIRED_PACKAGES = ["numpy", "matplotlib", "tqdm", "scikit-learn"]

def install_packages(packages=None):
    """Install any required packages that cannot be imported.

    Each entry is a pip distribution name. Before importing, it is translated
    to its module name where the two differ (``"scikit-learn"`` ->
    ``"sklearn"``). Otherwise such packages would look missing on every call
    and be needlessly reinstalled.

    Installation is best-effort. A failed ``pip install`` (run with the current
    interpreter via ``sys.executable -m pip``) is reported on stdout, and the
    remaining packages are still processed.

    Args:
        packages: Iterable of pip distribution names to check. Defaults to
            the module-level ``REQUIRED_PACKAGES``.

    Returns:
        list[str]: The packages that could not be imported. An install was
        attempted for each, whether or not it succeeded.
    """
    import importlib

    if packages is None:
        packages = REQUIRED_PACKAGES

    # pip distribution name -> importable module name, where they differ.
    import_names = {"scikit-learn": "sklearn"}

    missing = []
    for package in packages:
        try:
            importlib.import_module(import_names.get(package, package))
        except ImportError:
            missing.append(package)
            print(f"Package '{package}' is missing. Attempting to install...")
            try:
                subprocess.check_call([sys.executable, "-m", "pip", "install", package])
            except (subprocess.CalledProcessError, OSError) as e:
                # Deliberately best-effort: report and move on to the next package.
                print(f"Failed to install package '{package}': {e}")
    return missing

# Ensure required packages are installed.
# NOTE(review): this call runs only after the numpy / matplotlib / sklearn /
# tqdm imports at the top of this module. If one of those packages is truly
# missing, that import has already raised ImportError before this line is
# reached, so the install can only help subsequent runs. It also runs as a
# side effect every time this module is imported.
install_packages()

def _dump_pickle(path, obj):
    """Pickle *obj* to *path*; the file is closed even if serialization fails."""
    with open(path, "wb") as fh:
        pickle.dump(obj, fh)


def _silhouette_or_nan(matrix, labels):
    """Return the silhouette score of *labels* on *matrix*, or NaN if undefined.

    The silhouette coefficient is only defined for 2 <= n_labels <= n_samples - 1.
    Outside that range NaN is returned instead of letting ``silhouette_score``
    raise. That happens for k == 1, k == n_samples, or a fit whose labels use
    fewer distinct clusters than requested.
    """
    n_labels = len(np.unique(labels))
    if 2 <= n_labels <= matrix.shape[0] - 1:
        return silhouette_score(matrix, labels)
    return np.nan


def _save_score_plot(path, ks, values, best_k, ylabel, title, best_label, color=None):
    """Save a line plot of *values* vs. *ks* with a dashed red marker at *best_k*."""
    fig, ax = plt.subplots(figsize=(8, 5))
    try:
        line_kwargs = {"marker": "o"}
        if color is not None:
            line_kwargs["color"] = color
        ax.plot(ks, values, **line_kwargs)
        ax.set_xlabel("Number of Clusters (k)")
        ax.set_ylabel(ylabel)
        ax.set_title(title)
        ax.axvline(best_k, color="red", linestyle="--", label=f"{best_label} = {best_k}")
        ax.grid()
        ax.legend()
        fig.savefig(path)
    finally:
        # Close explicitly so repeated runs don't accumulate open pyplot figures.
        plt.close(fig)


def explore_kmeans_clustering(
    ngram_texts,
    path_dset,
    dset_name,
    min_clusters=2,
    max_clusters=10,
    step=1,
    max_features=5000,
    min_df=3,
    max_df=0.8,
    random_state=42,
):
    """Sweep K-means over a range of k on TF-IDF vectors and persist the results.

    Each document in *ngram_texts* is an iterable of string tokens. Tokens are
    joined with single spaces before TF-IDF vectorization.

    K-means is fit for every k in ``range(min_clusters, max_clusters + 1, step)``,
    capped at the number of documents because K-means cannot fit more clusters
    than samples. Inertia and silhouette score are recorded for each k. The
    final model is fit at the k with the highest silhouette score. It falls
    back to the lowest-inertia k when no silhouette score is defined.

    The following are pickled into *path_dset*:
      - the final model,
      - the fitted vectorizer,
      - the final cluster labels,
      - the per-k scores.
    Elbow and silhouette PNG plots are saved alongside them. Every file name
    embeds *dset_name* and a minute-resolution timestamp, so two runs within
    the same minute overwrite each other's files.

    Args:
        ngram_texts: Iterable of documents, each an iterable of string tokens.
        path_dset: Output directory (created if missing).
        dset_name: Prefix used in every output file name.
        min_clusters: Smallest k to try; must be >= 1.
        max_clusters: Largest k to try (inclusive), capped at the document count.
        step: Step between consecutive k values.
        max_features, min_df, max_df: Forwarded to ``TfidfVectorizer``.
        random_state: Seed passed to every ``KMeans`` fit.

    Returns:
        dict containing:
          - output file paths: ``model_path``, ``vectorizer_path``,
            ``cluster_assignments_path``, ``scores_path``,
            ``elbow_plot_path``, ``silhouette_plot_path``;
          - the selected k values: ``best_k_inertia``,
            ``best_k_silhouette``, ``final_k``;
          - the per-k ``inertia_values`` and ``silhouette_values``
            (NaN where undefined);
          - the ``cluster_range`` actually evaluated, as a list.

    Raises:
        ValueError: if ``min_clusters < 1`` or the capped k range is empty.
    """
    if min_clusters < 1:
        raise ValueError(f"min_clusters must be >= 1, got {min_clusters}")

    timestamp = datetime.datetime.now().strftime("%Y%m%dT%H%M")
    output_dir = os.fspath(path_dset)
    os.makedirs(output_dir, exist_ok=True)

    docs_as_strings = [" ".join(doc) for doc in ngram_texts]
    vectorizer = TfidfVectorizer(max_features=max_features, min_df=min_df, max_df=max_df)
    tfidf_matrix = vectorizer.fit_transform(docs_as_strings)
    n_samples = tfidf_matrix.shape[0]

    cluster_range = range(min_clusters, min(max_clusters, n_samples) + 1, step)
    if len(cluster_range) == 0:
        raise ValueError(
            f"No k values to evaluate: min_clusters={min_clusters}, "
            f"max_clusters={max_clusters}, step={step}, n_documents={n_samples}"
        )

    inertia_values, silhouette_values = [], []
    for k in tqdm(cluster_range, desc="K-means Clustering"):
        kmeans = KMeans(n_clusters=k, random_state=random_state, n_init="auto")
        kmeans.fit(tfidf_matrix)
        inertia_values.append(kmeans.inertia_)
        silhouette_values.append(_silhouette_or_nan(tfidf_matrix, kmeans.labels_))

    # Inertia usually keeps falling as k grows, so this is typically the largest
    # k tried. It is reported and plotted, but not used to pick the final model.
    best_k_inertia = cluster_range[int(np.argmin(inertia_values))]
    scored = [(k, s) for k, s in zip(cluster_range, silhouette_values) if not np.isnan(s)]
    best_k_silhouette = max(scored, key=lambda pair: pair[1])[0] if scored else best_k_inertia

    # Select by silhouette (best_k_silhouette already falls back to inertia).
    final_k = best_k_silhouette
    final_kmeans = KMeans(n_clusters=final_k, random_state=random_state, n_init="auto")
    final_kmeans.fit(tfidf_matrix)
    cluster_assignments = final_kmeans.labels_

    model_path = os.path.join(output_dir, f"{dset_name}_kmeans_model_{timestamp}.pkl")
    _dump_pickle(model_path, final_kmeans)

    vectorizer_path = os.path.join(output_dir, f"{dset_name}_tfidf_vectorizer_{timestamp}.pkl")
    _dump_pickle(vectorizer_path, vectorizer)

    cluster_assignments_path = os.path.join(output_dir, f"{dset_name}_cluster_assignments_{timestamp}.pkl")
    _dump_pickle(cluster_assignments_path, cluster_assignments)

    scores_path = os.path.join(output_dir, f"{dset_name}_cluster_scores_{timestamp}.pkl")
    _dump_pickle(scores_path, {"inertia": inertia_values, "silhouette": silhouette_values})

    ks = list(cluster_range)

    elbow_plot_path = os.path.join(output_dir, f"{dset_name}_elbow_plot_{timestamp}.png")
    _save_score_plot(
        elbow_plot_path,
        ks,
        inertia_values,
        best_k_inertia,
        ylabel="Inertia (Sum of Squares)",
        title="Elbow Method: Inertia vs. Number of Clusters",
        best_label="Best k (Inertia)",
    )

    silhouette_plot_path = os.path.join(output_dir, f"{dset_name}_silhouette_plot_{timestamp}.png")
    _save_score_plot(
        silhouette_plot_path,
        ks,
        silhouette_values,
        best_k_silhouette,
        ylabel="Silhouette Score",
        title="Silhouette Score vs. Number of Clusters",
        best_label="Best k (Silhouette)",
        color="green",
    )

    return {
        "model_path": model_path,
        "vectorizer_path": vectorizer_path,
        "cluster_assignments_path": cluster_assignments_path,
        "scores_path": scores_path,
        "elbow_plot_path": elbow_plot_path,
        "silhouette_plot_path": silhouette_plot_path,
        "best_k_inertia": best_k_inertia,
        "best_k_silhouette": best_k_silhouette,
        "final_k": final_k,
        "inertia_values": inertia_values,
        "silhouette_values": silhouette_values,
        "cluster_range": ks,
    }
