import copy
from abc import ABC, abstractmethod
from collections import deque
from itertools import product
from threading import Semaphore, Thread
from typing import Any, Dict, Iterator, List

from hyperon_das_atomdb import WILDCARD

import hyperon_das.link_filters as link_filters
from hyperon_das.query_engines.query_engine_protocol import QueryEngine
from hyperon_das.utils import Assignment, QueryAnswer


class QueryAnswerIterator(ABC):
    """Abstract base class for iterators over query answers.

    Keeps three pieces of state: `source` (the wrapped object), `iterator`
    (what `__next__` draws elements from) and `current_value` (the element
    most recently produced). `iterator` and `current_value` are initialised
    to None here and are never assigned a real iterator by this class itself.
    """

    def __init__(self, source: Any):
        self.source = source
        self.current_value = None
        self.iterator = None

    def __iter__(self):
        return self

    def __next__(self):
        """
        Advances the underlying iterator and returns the element it yields.

        Returns:
            Any: The next element, which also becomes `current_value`.

        Raises:
            StopIteration: If `source` is falsy, if no iterator has been set,
                or if the underlying iterator is exhausted (in that last case
                `current_value` is reset to None first).
        """
        if not self.source or self.iterator is None:
            raise StopIteration
        try:
            upcoming = next(self.iterator)
        except StopIteration:
            self.current_value = None
            raise
        self.current_value = upcoming
        return upcoming

    def __str__(self):
        return str(self.source)

    @abstractmethod
    def is_empty(self) -> bool:
        """
        Tells whether the iterator has nothing (left) to iterate over.

        Returns:
            bool: True if the iterator is empty and has no more elements to yield, False otherwise.
        """
        ...

    def get(self) -> Any:
        """
        Returns the current element without advancing the iterator.

        Returns:
            Any: The value stored in `current_value`.

        Raises:
            StopIteration: If `source` is falsy or `current_value` is None.
        """
        if self.source and self.current_value is not None:
            return self.current_value
        raise StopIteration


class ListIterator(QueryAnswerIterator):
    """Iterator over an in-memory list.

    For a non-empty list, `current_value` is primed with the first element
    and iteration starts from the beginning of the list. For an empty (or
    otherwise falsy) source, both stay None.
    """

    def __init__(self, source: List[Any]):
        super().__init__(source)
        if not source:
            return
        self.iterator = iter(self.source)
        self.current_value = source[0]

    def is_empty(self) -> bool:
        """
        Returns:
            bool: True when the wrapped list is falsy. The result does not
            change as elements are consumed.
        """
        return not self.source


class ProductIterator(QueryAnswerIterator):
    """Iterator over the cartesian product of several iterators.

    When none of the given iterators reports itself empty, `current_value`
    is primed with a tuple of each iterator's `get()` value and elements are
    drawn from `itertools.product` over the iterators.
    """

    def __init__(self, source: List[QueryAnswerIterator]):
        super().__init__(source)
        if self.is_empty():
            return
        # A list comprehension (not a generator) is used on purpose: a
        # StopIteration raised by `get()` must propagate unchanged.
        heads = [member.get() for member in source]
        self.current_value = tuple(heads)
        self.iterator = product(*self.source)

    def is_empty(self) -> bool:
        """
        Returns:
            bool: True if at least one of the wrapped iterators is empty.
        """
        return any(member.is_empty() for member in self.source)


class AndEvaluator(ProductIterator):
    """Product iterator that only yields combinations whose assignments compose.

    Each candidate combination is a tuple of query answers. Their assignments
    are combined with `Assignment.compose`; combinations for which that
    result is falsy are skipped.
    """

    def __init__(self, source: List[QueryAnswerIterator]):
        super().__init__(source)

    def __next__(self):
        """
        Returns:
            QueryAnswer: Built from the list of subgraphs of the combination
            and the composed assignment.

        Raises:
            StopIteration: When the underlying product is exhausted.
        """
        merged = None
        while not merged:
            combination = super().__next__()
            merged = Assignment.compose([answer.assignment for answer in combination])
        subgraphs = [answer.subgraph for answer in combination]
        return QueryAnswer(subgraphs, merged)


class LazyQueryEvaluator(ProductIterator):
    """Product iterator that turns target combinations into matching links.

    Every combination produced by the parent product is converted into a list
    of target handles (with `WILDCARD` for variables), which is sent to
    `query_engine.get_links`. The links that come back are buffered and
    yielded one at a time as `QueryAnswer` objects.
    """

    def __init__(
        self, link_type: str, source: List[QueryAnswerIterator], query_engine: QueryEngine
    ):
        super().__init__(source)
        self.link_type = link_type
        self.query_engine = query_engine
        self.buffered_answer = None

    def _replace_target_handles(self, link: Dict[str, Any]) -> Dict[str, Any]:
        """
        Returns a deep copy of `link` whose "targets" entry holds atom
        documents instead of handles. Targets that themselves carry a
        "targets" entry are expanded recursively.
        """
        resolved = []
        for handle in link["targets"]:
            atom = self.query_engine.get_atom(handle)
            if atom.get("targets", None) is not None:
                atom = self._replace_target_handles(atom)
            resolved.append(atom)
        expanded = copy.deepcopy(link)
        expanded["targets"] = resolved
        return expanded

    def _build_pattern(self, target_info) -> tuple:
        """
        Builds the handle pattern for one target combination.

        Returns:
            tuple: (handles, needs_assignment). `needs_assignment` is True
            when any target is a variable or carries a truthy assignment.
        """
        handles = []
        needs_assignment = False
        for answer in target_info:
            subgraph = answer.subgraph
            if answer.assignment:
                needs_assignment = True
            if subgraph.get("atom_type", None) == "variable":
                handles.append(WILDCARD)
                needs_assignment = True
            else:
                handles.append(subgraph["handle"])
        return handles, needs_assignment

    def _bind(self, target_info, link: Dict[str, Any]):
        """
        Builds the assignment produced by matching `link` against the targets.

        Returns:
            The frozen assignment, or None as soon as one `assign`/`merge`
            call reports failure.
        """
        assignment = Assignment()
        for answer, handle in zip(target_info, link["targets"]):
            subgraph = answer.subgraph
            if subgraph.get("atom_type", None) == "variable":
                accepted = assignment.assign(subgraph["name"], handle)
            else:
                accepted = assignment.merge(answer.assignment)
            if not accepted:
                return None
        assignment.freeze()
        return assignment

    def __next__(self):
        """
        Returns:
            QueryAnswer: The next buffered answer, fetching a new batch from
            the query engine when the buffer has run out.

        Raises:
            StopIteration: When the underlying product is exhausted.
        """
        if self.buffered_answer:
            try:
                return next(self.buffered_answer)
            except StopIteration:
                self.buffered_answer = None
        # Keep pulling combinations until one of them yields at least one answer.
        while True:
            target_info = super().__next__()
            handles, needs_assignment = self._build_pattern(target_info)
            matches = self.query_engine.get_links(
                link_filters.Targets(handles, self.link_type)
            )
            answers = []
            for link in matches:
                assignment = None
                if needs_assignment:
                    assignment = self._bind(target_info, link)
                    if assignment is None:
                        continue
                answers.append(QueryAnswer(self._replace_target_handles(link), assignment))
            if answers:
                self.buffered_answer = ListIterator(answers)
                return next(self.buffered_answer)


class BaseLinksIterator(QueryAnswerIterator, ABC):
    """Base class for iterators that page their results in chunk by chunk.

    The elements of the current chunk (`source`) are served while a worker
    thread fetches the following chunk into `buffer_queue`. When the current
    chunk runs out, the buffered chunk becomes the new `source` and, if the
    cursor indicates that more data is available, another fetch is started.
    A cursor of 0 or None means "nothing more to fetch".

    Subclasses supply the four `get_*` hooks declared at the bottom.
    """

    def __init__(self, source: ListIterator, **kwargs) -> None:
        """
        Args:
            source (ListIterator): The first chunk of results.
            **kwargs: Recognised keys are `backend` (ignored if the subclass
                already set `self.backend`), `chunk_size` (default 1000) and
                `cursor` (default 0).

        When `source` is empty none of the paging attributes are created and
        iteration stops immediately.
        """
        super().__init__(source)
        if not self.source.is_empty():
            if not hasattr(self, 'backend'):
                self.backend = kwargs.get('backend')
            self.chunk_size = kwargs.get('chunk_size', 1000)
            self.cursor = kwargs.get('cursor', 0)
            self.buffer_queue = deque()
            # Always created so that _fetch_data/_refresh_iterator can rely on it.
            self.semaphore = Semaphore(1)
            self.iterator = self.source
            self.current_value = self.get_current_value()
            self.fetch_data_thread = Thread(target=self._fetch_data)
            if self.cursor not in (0, None):
                self.fetch_data_thread.start()

    def __next__(self) -> Any:
        """
        Returns:
            Any: The next element, as produced by `get_next_value`.

        Raises:
            StopIteration: When the current chunk is exhausted, the cursor is
                0 or None and nothing is buffered; or when no iterator is set.
        """
        if self.iterator:
            try:
                return self.get_next_value()
            except StopIteration:
                self.current_value = None
                self.iterator = None
                # Wait for a fetch in flight so cursor and buffer are final.
                if self.fetch_data_thread.is_alive():
                    self.fetch_data_thread.join()
                if self.cursor in (0, None) and not self.buffer_queue:
                    raise
                self._refresh_iterator()
                self.fetch_data_thread = Thread(target=self._fetch_data)
                # Same end-of-data test as in __init__ and above: a None
                # cursor must not trigger another fetch.
                if self.cursor not in (0, None):
                    self.fetch_data_thread.start()
                return self.__next__()
        raise StopIteration

    def _fetch_data(self) -> None:
        """
        Fetches one chunk and appends it to `buffer_queue`, updating `cursor`.

        Runs in the worker thread. If the fetch raises, the exception ends the
        thread and `cursor` and `buffer_queue` are left untouched.
        """
        kwargs = self.get_fetch_data_kwargs()
        # Blocking acquire; released automatically even if the fetch raises.
        with self.semaphore:
            cursor, answer = self.get_fetch_data(**kwargs)
            self.cursor = cursor
            self.buffer_queue.extend(answer)

    def _refresh_iterator(self) -> None:
        """Makes the buffered chunk the current `source` and empties the buffer."""
        with self.semaphore:
            self.source = ListIterator(list(self.buffer_queue))
            self.iterator = self.source
            self.current_value = self.get_current_value()
            self.buffer_queue.clear()

    def is_empty(self) -> bool:
        """
        Returns:
            bool: True when no iterator is currently set.
        """
        return not self.iterator

    @abstractmethod
    def get_next_value(self) -> Any:
        """Advance within the current chunk and return the next element."""
        raise NotImplementedError("Subclasses must implement get_next_value method")

    @abstractmethod
    def get_current_value(self) -> Any:
        """Return the element the current chunk is positioned on."""
        raise NotImplementedError("Subclasses must implement get_current_value method")

    @abstractmethod
    def get_fetch_data_kwargs(self) -> Dict[str, Any]:
        """Return the keyword arguments that `_fetch_data` passes to `get_fetch_data`."""
        raise NotImplementedError("Subclasses must implement get_fetch_data_kwargs method")

    @abstractmethod
    def get_fetch_data(self, **kwargs) -> tuple:
        """Fetch one chunk; `_fetch_data` unpacks the result as (cursor, answer)."""
        raise NotImplementedError("Subclasses must implement get_fetch_data method")


class LocalIncomingLinks(BaseLinksIterator):
    """Chunked iterator over incoming links, resolved through `backend`.

    The chunks hold link handles; each handle is turned into a document with
    `backend.get_atom` at the moment it is served.
    """

    def __init__(self, source: ListIterator, **kwargs) -> None:
        """
        Args:
            source (ListIterator): First chunk of link handles.
            **kwargs: `atom_handle` and `targets_document` (default False) are
                read here; everything is forwarded to the base class.
        """
        self.atom_handle = kwargs.get('atom_handle')
        self.targets_document = kwargs.get('targets_document', False)
        super().__init__(source, **kwargs)

    def get_next_value(self) -> Any:
        """
        Returns:
            Any: The document of the next handle. If the iterator is empty or
            no backend is set, `current_value` is returned unchanged.
        """
        if self.is_empty() or not self.backend:
            return self.current_value
        handle = next(self.iterator)
        self.current_value = self.backend.get_atom(
            handle, targets_document=self.targets_document
        )
        return self.current_value

    def get_current_value(self) -> Any:
        """
        Returns:
            Any: The document of the handle the current chunk points at, or
            None when there is no backend or a StopIteration is raised.
        """
        if not self.backend:
            return None
        try:
            return self.backend.get_atom(
                self.source.get(), targets_document=self.targets_document
            )
        except StopIteration:
            return None

    def get_fetch_data_kwargs(self) -> Dict[str, Any]:
        """
        Returns:
            Dict[str, Any]: Paging arguments, asking for handles only.
        """
        return dict(handles_only=True, cursor=self.cursor, chunk_size=self.chunk_size)

    def get_fetch_data(self, **kwargs) -> tuple:
        """
        Returns:
            tuple: Whatever `backend.get_incoming_links` returns for
            `atom_handle`; None when no backend is set.
        """
        if not self.backend:
            return None
        return self.backend.get_incoming_links(self.atom_handle, **kwargs)


class RemoteIncomingLinks(BaseLinksIterator):
    """Chunked iterator over incoming links that skips repeated handles.

    Chunks hold link documents (or handles). Handles that were already served
    are remembered in `returned_handles` and skipped on later occurrences.
    """

    def __init__(self, source: ListIterator, **kwargs) -> None:
        """
        Args:
            source (ListIterator): First chunk of links.
            **kwargs: `atom_handle` and `targets_document` (default False) are
                read here; everything is forwarded to the base class.
        """
        self.atom_handle = kwargs.get('atom_handle')
        self.targets_document = kwargs.get('targets_document', False)
        self.returned_handles = set()
        super().__init__(source, **kwargs)

    @staticmethod
    def _extract_handle(link_document: Any) -> Any:
        """
        Returns the handle of a link given as a tuple/list whose first item is
        a document, as a document dict, or as a plain handle string.

        Raises:
            ValueError: For any other kind of value.
        """
        if isinstance(link_document, (tuple, list)):
            return link_document[0]['handle']
        if isinstance(link_document, dict):
            return link_document['handle']
        if isinstance(link_document, str):
            return link_document
        raise ValueError(f"Invalid link document: {link_document}")

    def get_next_value(self) -> Any:
        """
        Returns:
            Any: The next link whose handle has not been served yet. If the
            iterator is empty, `current_value` is returned unchanged.

        Raises:
            StopIteration: When the current chunk runs out.
        """
        if self.is_empty():
            return self.current_value
        while True:
            link_document = next(self.iterator)
            handle = self._extract_handle(link_document)
            if handle in self.returned_handles:
                continue
            self.returned_handles.add(handle)
            self.current_value = link_document
            return link_document

    def get_current_value(self) -> Any:
        """
        Returns:
            Any: The element the current chunk points at, or None if there is none.
        """
        try:
            current = self.source.get()
        except StopIteration:
            return None
        return current

    def get_fetch_data_kwargs(self) -> Dict[str, Any]:
        """
        Returns:
            Dict[str, Any]: Paging arguments plus the `targets_document` flag.
        """
        return dict(
            cursor=self.cursor,
            chunk_size=self.chunk_size,
            targets_document=self.targets_document,
        )

    def get_fetch_data(self, **kwargs) -> tuple:
        """
        Returns:
            tuple: Whatever `backend.get_incoming_links` returns for
            `atom_handle`; None when no backend is set.
        """
        if not self.backend:
            return None
        return self.backend.get_incoming_links(self.atom_handle, **kwargs)


class CustomQuery(BaseLinksIterator):
    """Chunked iterator over the results of an index-based query.

    Further chunks are fetched with `backend.custom_query` when `is_remote`
    is true and with `backend.get_atoms_by_index` otherwise.
    """

    def __init__(self, source: ListIterator, **kwargs) -> None:
        """
        Args:
            source (ListIterator): First chunk of results.
            **kwargs: `index_id`, `backend` and `is_remote` (default False)
                are removed and stored on the instance. The remaining keys
                are kept in `self.kwargs` and forwarded to the base class.
        """
        self.index_id = kwargs.pop('index_id', None)
        self.backend = kwargs.pop('backend', None)
        self.is_remote = kwargs.pop('is_remote', False)
        self.kwargs = kwargs
        super().__init__(source, **kwargs)

    def get_next_value(self) -> Any:
        """
        Returns:
            Any: The next element of the current chunk. If the iterator is
            empty, `current_value` is returned unchanged.
        """
        if not self.is_empty():
            self.current_value = next(self.iterator)
        return self.current_value

    def get_current_value(self) -> Any:
        """
        Returns:
            Any: The element the current chunk points at, or None if there is none.
        """
        try:
            return self.source.get()
        except StopIteration:
            return None

    def get_fetch_data_kwargs(self) -> Dict[str, Any]:
        """
        Returns:
            Dict[str, Any]: A new dict with the stored keyword arguments,
            overridden by the current `cursor` and `chunk_size`.
        """
        # Build a fresh dict so the stored kwargs are not mutated or aliased.
        return {**self.kwargs, 'cursor': self.cursor, 'chunk_size': self.chunk_size}

    def get_fetch_data(self, **kwargs) -> tuple:
        """
        Fetches one chunk from the backend.

        Args:
            **kwargs: Forwarded to the backend call. `query` (default `[]`)
                is passed as the `query` keyword.

        Returns:
            tuple: Whatever the backend call returns; None when no backend is set.
        """
        if self.backend:
            # Take `query` out of kwargs: passing it both explicitly and via
            # **kwargs raises "got multiple values for keyword argument".
            query = kwargs.pop('query', [])
            if self.is_remote:
                return self.backend.custom_query(self.index_id, query=query, **kwargs)
            else:
                return self.backend.get_atoms_by_index(self.index_id, query=query, **kwargs)


class TraverseLinksIterator(QueryAnswerIterator):
    """Iterator over links that pass a set of optional filters.

    Each element of `source` is either a `(link, targets)` tuple or a link
    dict that may carry its target documents under 'targets_document'. Links
    that pass `_filter` are yielded; with `targets_only` the target documents
    are yielded instead of the link.
    """

    def __init__(
        self, source: LocalIncomingLinks | RemoteIncomingLinks | Iterator, **kwargs
    ) -> None:
        """
        Args:
            source: Iterator producing the links to traverse.
            **kwargs: Optional settings:
                cursor: Handle compared against the link target at `cursor_position`.
                targets_only (bool): Yield target documents instead of links. Defaults to False.
                link_type: Required value of the link's 'named_type'.
                cursor_position (int): Index into the link's 'targets'.
                target_type: 'named_type' that at least one target must have.
                filter: Callable taking an atom dict and returning a bool.
        """
        super().__init__(source)
        self.cursor = kwargs.get('cursor')
        self.targets_only = kwargs.get('targets_only', False)
        self.buffer = None
        self.link_type = kwargs.get('link_type')
        self.cursor_position = kwargs.get('cursor_position')
        self.target_type = kwargs.get('target_type')
        self.custom_filter = kwargs.get('filter')
        if not self.source.is_empty():
            self.iterator = self.source
            # The first valid element is consumed eagerly and parked in
            # `buffer` so that the first `__next__` call can return it.
            self.current_value = self._find_first_valid_element()
            self.buffer = self.current_value

    def __next__(self):
        """
        Returns:
            The next link (or its targets, with `targets_only`) that passes the filters.

        Raises:
            StopIteration: When `source` is exhausted.
            ValueError: If an element is neither a tuple nor a dict.
        """
        if self.buffer:
            buffered_value, self.buffer = self.buffer, None
            return buffered_value
        while True:
            link, targets = self._split_link(super().__next__())
            # `_filter` accepts every link when no criterion is configured.
            if self._filter(link, targets):
                self.current_value = targets if self.targets_only else link
                return self.current_value

    @staticmethod
    def _split_link(link: Any) -> tuple:
        """
        Splits a source element into `(link, targets)`.

        For a dict, 'targets_document' is removed from the link, so every
        yielded link has the same shape.

        Raises:
            ValueError: If `link` is neither a tuple nor a dict.
        """
        if isinstance(link, tuple):
            link, targets = link
        elif isinstance(link, dict):
            targets = link.pop('targets_document', [])
        else:
            raise ValueError(f"Invalid link document: {link}")
        return link, targets

    def _find_first_valid_element(self):
        """
        Consumes `source` up to and including the first element that passes
        the filters and returns it (or its targets, with `targets_only`).
        Returns None if no element passes.
        """
        if self.source:
            for element in self.source:
                link, targets = self._split_link(element)
                if self._filter(link, targets):
                    return targets if self.targets_only else link
        return None

    def _filter(self, link: Dict[str, Any], targets: list[dict[str, Any]]) -> bool:
        """
        Checks a link against every configured criterion.

        Returns:
            bool: True if the link passes all configured criteria (or none is
            configured), False otherwise.
        """
        if self.link_type and self.link_type != link['named_type']:
            return False

        if self.cursor_position is not None:
            try:
                if self.cursor != link['targets'][self.cursor_position]:
                    return False
            except IndexError:
                # The link has no target at the requested position.
                return False

        if self.target_type:
            if not any(target['named_type'] == self.target_type for target in targets):
                return False

        if self.custom_filter:
            # The filter sees a shallow copy with handles replaced by documents.
            deep_link = link.copy()
            deep_link['targets'] = targets
            if self._apply_custom_filter(deep_link) is False:
                return False

        return True

    def _apply_custom_filter(self, atom: Dict[str, Any], F=None) -> bool:
        """
        Applies `F` (or, if falsy, the iterator's own custom filter) to `atom`.

        Returns:
            bool: True if the filter accepts the atom, False otherwise.

        Raises:
            AssertionError: If the filter is not callable.
            Exception: If the filter itself raises; the original error is chained.
        """
        custom_filter = F if F else self.custom_filter

        assert callable(
            custom_filter
        ), "The custom_filter must be a function with this signature 'def func(atom: dict) -> bool: ...'"

        try:
            return bool(custom_filter(atom))
        except Exception as e:
            raise Exception(f"Error while applying the custom filter: {e}") from e

    def is_empty(self) -> bool:
        """
        Returns:
            bool: True when `current_value` is falsy.
        """
        return not self.current_value


class TraverseNeighborsIterator(QueryAnswerIterator):
    """Iterator over distinct neighbor atoms reached through traversed links.

    `source` yields lists of target documents. Each target that is not the
    cursor, was not yielded before and passes the optional type and custom
    filters is yielded once.
    """

    def __init__(self, source: TraverseLinksIterator, **kwargs) -> None:
        """
        Args:
            source (TraverseLinksIterator): Iterator yielding lists of target
                documents; its `cursor` and `target_type` are reused here.
            **kwargs: `filter` is an optional callable applied to each target.
        """
        super().__init__(source)
        self.buffered_answer = None
        self.cursor = self.source.cursor
        self.target_type = self.source.target_type
        # Handles in the order they were accepted.
        self.visited_neighbors = []
        # Same handles as a set, for O(1) membership tests in `_filter`.
        self._visited_handles = set()
        self.custom_filter = kwargs.get('filter')
        if not self.source.is_empty():
            self.iterator = source
            self.current_value = self._find_first_valid_element()

    def __next__(self):
        """
        Returns:
            The next neighbor document, which also becomes `current_value`.

        Raises:
            StopIteration: When `source` is exhausted.
        """
        if self.buffered_answer:
            try:
                neighbor = next(self.buffered_answer)
            except StopIteration:
                self.buffered_answer = None
            else:
                # Keep `get()` in sync with what was just returned.
                self.current_value = neighbor
                return neighbor

        while True:
            targets = super().__next__()
            new_neighbors, match_found = self._process_targets(targets)
            if match_found:
                self.buffered_answer = ListIterator(new_neighbors)
                self.current_value = next(self.buffered_answer)
                return self.current_value

    def _find_first_valid_element(self):
        """
        Consumes `source` until a list of targets yields at least one
        neighbor, buffers those neighbors and returns the first one.
        Returns None if nothing matches.
        """
        for targets in self.iterator:
            new_neighbors, match_found = self._process_targets(targets)
            if match_found:
                self.buffered_answer = ListIterator(new_neighbors)
                return self.buffered_answer.get()
        return None

    def _process_targets(self, targets: list) -> tuple:
        """
        Collects the targets that pass `_filter` and marks them as visited.

        Returns:
            tuple: (accepted targets, whether any target was accepted).
        """
        answer = []
        for target in targets:
            if self._filter(target):
                handle = target['handle']
                self.visited_neighbors.append(handle)
                self._visited_handles.add(handle)
                answer.append(target)
        return (answer, bool(answer))

    def _filter(self, target: Dict[str, Any]) -> bool:
        """
        Returns:
            bool: True if the target is not the cursor, has not been visited,
            matches `target_type` (when set) and passes the custom filter
            (when set).
        """
        handle = target['handle']
        if handle == self.cursor or handle in self._visited_handles:
            return False
        if self.target_type and self.target_type != target['named_type']:
            return False

        if self.custom_filter:
            if self.source._apply_custom_filter(target, F=self.custom_filter) is False:
                return False

        return True

    def is_empty(self) -> bool:
        """
        Returns:
            bool: True when `current_value` is falsy.
        """
        return not self.current_value