import nltk
from nltk.corpus import wordnet as wn

# Fetch the NLTK data packages this module relies on. This runs at import
# time, once per package, in the order listed.
for _nltk_resource in ('averaged_perceptron_tagger', 'wordnet'):
    nltk.download(_nltk_resource)
del _nltk_resource  # keep the loop variable out of the module namespace

# Translate Treebank-style POS tags into the WordNet POS codes to query.
def get_wordnet_pos(treebank_tag):
    """
    Return the WordNet POS codes that correspond to a Treebank-style tag.

    The mapping is decided by the tag's first letter:
      'J' (adjectives) -> ['a', 's']  (plain and satellite adjectives)
      'V' (verbs)      -> ['v']
      'N' (nouns)      -> ['n']
      'R' (adverbs)    -> ['r']
    Any other tag (including the empty string) maps to an empty list.
    A new list is returned on every call, so callers may mutate it freely.
    Extend the table below for a more fine-grained tagset.
    """
    prefix_table = (
        ('J', ('a', 's')),
        ('V', ('v',)),
        ('N', ('n',)),
        ('R', ('r',)),
    )
    for prefix, wordnet_codes in prefix_table:
        if treebank_tag.startswith(prefix):
            return list(wordnet_codes)
    # Unrecognised tag: nothing to look up in WordNet.
    return []


def check_wordnet_pos(token, tag):
    """
    Tell whether *token* has at least one WordNet synset for the POS implied
    by the Treebank-style *tag*.

    The candidate WordNet POS codes come from get_wordnet_pos(); for
    adjectives both 'a' and 's' are tried. Lookups stop at the first code
    that yields a synset. A tag with no WordNet equivalent always gives False.
    """
    return any(
        wn.synsets(token, pos=candidate_pos)
        for candidate_pos in get_wordnet_pos(tag)
    )


def matches_pos_tag(pos_tag, allowed_tags):
    """
    Decide whether *pos_tag* is acceptable according to *allowed_tags*.

    This is a plain membership test (``pos_tag in allowed_tags``), so a set
    such as {"NN", "VB"} requires exact tag matches. To accept whole tag
    families instead (e.g. every tag starting with 'N'), change this one
    function; all filtering goes through it.
    """
    is_allowed = pos_tag in allowed_tags
    return is_allowed

def _merge_and_filter_sentence(pos_tagged_list, set_PosTag, set_StopWord,
                               set_NegationToken, char_NegationJoin, UNK):
    """
    Apply stopword masking, POS filtering and negation merging to one sentence.

    Returns a list of ``(token, tag, check_token)`` triples. ``check_token`` is
    the word that should be validated against WordNet for that entry, or
    ``None`` for entries that are already UNK placeholders (stopwords).
    """
    kept = []
    length = len(pos_tagged_list)
    i = 0
    while i < length:
        token, tag = pos_tagged_list[i]

        # 1) Stopwords keep their position as an UNK placeholder.
        if token in set_StopWord:
            kept.append((UNK, tag, None))
            i += 1
            continue

        # 2) Drop tokens whose tag is not allowed.
        if not matches_pos_tag(tag, set_PosTag):
            i += 1
            continue

        # 3) Merge a negation with the following token, but only when that
        #    token would itself survive steps 1-2. Otherwise fall through so
        #    the negation is kept as a plain token and the following token
        #    is handled normally on the next iteration.
        if token in set_NegationToken and i + 1 < length:
            next_token, next_tag = pos_tagged_list[i + 1]
            if (next_token not in set_StopWord
                    and matches_pos_tag(next_tag, set_PosTag)):
                combined_token = f"{token}{char_NegationJoin}{next_token}"
                # The combined string is not a WordNet lemma, so validate the
                # merged entry on the negated word alone.
                kept.append((combined_token, next_tag, next_token))
                i += 2
                continue

        kept.append((token, tag, token))
        i += 1
    return kept


def filter_pos_tagged_words(list_of_lists,
                            set_PosTag,
                            set_StopWord=None,
                            set_NegationToken=None,
                            return_string=False,
                            char_NegationJoin=' ',
                            UNK='x'):
    """
    Filter sentences of (token, tag) tuples and validate tokens against WordNet.

    For each sentence (sub-list of ``list_of_lists``):
      1) A token in set_StopWord is replaced by (UNK, tag); it keeps its
         position instead of being dropped (stopwords are checked before
         the POS filter).
      2) Any other token whose tag is rejected by matches_pos_tag() is dropped.
      3) A token in set_NegationToken is merged with the following token,
         joined by char_NegationJoin, when that following token would itself
         survive steps 1-2; the merged entry takes the following token's tag.
         If the negation cannot be merged (it is the last token, or the next
         token is a stopword or has a disallowed tag), the negation is kept as
         an ordinary token and the next token is processed normally.
      4) Each remaining token is checked in WordNet under the POS implied by
         its tag (adjectives under both 'a' and 's'). A merged negation is
         checked on the negated word only.
      5) Tokens that fail the check are replaced by UNK.
      6) If return_string=True, each sentence is returned as its tokens
         joined by single spaces; otherwise as a list of (token, tag).

    :param list_of_lists:        A list of sentences, each a list of
                                 (token, tag) tuples.
    :param set_PosTag:           Allowed POS tags (e.g. {"NN", "VB", "JJ"}),
                                 tested via matches_pos_tag() (exact
                                 membership).
    :param set_StopWord:         Tokens to mask as UNK (e.g. {"the", "and"}).
                                 None means no stopwords.
    :param set_NegationToken:    Tokens treated as negations (e.g. {"not", "no"}).
                                 None means no negation merging.
    :param return_string:        If True, return one string per sentence;
                                 otherwise a list of (token, tag) per sentence.
    :param char_NegationJoin:    String placed between a negation and the
                                 token it is merged with (default ' ').
    :param UNK:                  Placeholder for stopwords and for tokens not
                                 found in WordNet (default 'x').
    :return:                     A list with one entry per input sentence:
                                 a string or a list of (token, tag) tuples.
    """
    if set_StopWord is None:
        set_StopWord = set()
    if set_NegationToken is None:
        set_NegationToken = set()

    output = []
    for pos_tagged_list in list_of_lists:
        # Steps 1-3: stopword masking, POS filtering, negation merging.
        merged = _merge_and_filter_sentence(
            pos_tagged_list, set_PosTag, set_StopWord,
            set_NegationToken, char_NegationJoin, UNK)

        # Steps 4-5: keep tokens WordNet knows. Entries without a check
        # word are already UNK placeholders, so no lookup is needed.
        validated = [
            (token, tag)
            if check_token is not None and check_wordnet_pos(check_token, tag)
            else (UNK, tag)
            for token, tag, check_token in merged
        ]

        # Step 6: emit either a joined string or the (token, tag) list.
        if return_string:
            output.append(" ".join(token for token, _ in validated))
        else:
            output.append(validated)

    return output

