import asyncio
from typing import Any, Dict, List, Optional, Union, cast

from langchain.chains.base import Chain
from langchain.chains.llm import LLMChain
from langchain_core.language_models import BaseLanguageModel
from langchain_core.prompts.few_shot import FewShotPromptTemplate
from langchain_core.utils.pydantic import is_basemodel_instance
from pydantic import BaseModel, ConfigDict, Field, model_validator
from typing_extensions import Self


class SyntheticDataGenerator(BaseModel):
    """Generate synthetic data using the given LLM and few-shot template.

    Utilizes the provided LLM to produce synthetic data based on the
    few-shot prompt template.

    Attributes:
        template (FewShotPromptTemplate): Template for few-shot prompting.
        llm (Optional[BaseLanguageModel]): Large Language Model to use for generation.
        results (list): Every output produced by this generator so far. It
            accumulates across calls to ``generate`` and ``agenerate``.
        llm_chain (Optional[Chain]): LLM chain with the LLM and few-shot template.
            Built from ``llm`` and ``template`` when not supplied.
        example_input_key (str): Key to use for storing example inputs.

    Usage Example:
        >>> template = FewShotPromptTemplate(...)
        >>> llm = BaseLanguageModel(...)
        >>> generator = SyntheticDataGenerator(template=template, llm=llm)
        >>> results = generator.generate(subject="climate change", runs=5)
    """

    template: FewShotPromptTemplate
    llm: Optional[BaseLanguageModel] = None
    # default_factory makes the "fresh list per instance" intent explicit
    # instead of relying on pydantic copying a mutable class-level default.
    results: list = Field(default_factory=list)
    llm_chain: Optional[Chain] = None
    example_input_key: str = "example"

    model_config = ConfigDict(
        validate_assignment=True,
    )

    @model_validator(mode="after")
    def set_llm_chain(self) -> Self:
        """Build ``llm_chain`` from ``llm`` and ``template`` when it is missing.

        Returns:
            The validated model instance.

        Raises:
            ValueError: If ``llm_chain`` is not given and ``llm`` (or
                ``template``) is None.
        """
        if self.llm_chain is None:
            if self.llm is None or self.template is None:
                raise ValueError(
                    "Both llm and few_shot_template must be provided if llm_chain is "
                    "not given."
                )
            # With validate_assignment=True this assignment can re-run this
            # validator; the re-entrant pass sees llm_chain set and returns.
            self.llm_chain = LLMChain(llm=self.llm, prompt=self.template)

        return self

    @staticmethod
    def _format_dict_to_string(input_dict: Dict) -> str:
        """Render a mapping as ``"key: value, key: value"``."""
        return ", ".join(f"{key}: {value}" for key, value in input_dict.items())

    def _require_llm_chain(self) -> Chain:
        """Return ``llm_chain``, raising ``ValueError`` if it is unset."""
        if self.llm_chain is None:
            raise ValueError(
                "llm_chain is none, set either llm_chain or llm at generator "
                "construction"
            )
        return self.llm_chain

    def _update_examples(self, example: Union[BaseModel, Dict[str, Any], str]) -> None:
        """Rotate a freshly generated example into the few-shot examples.

        Drops the oldest entry of ``template.examples`` and appends ``example``
        (stored under ``example_input_key``). Later prompts then show recent
        outputs, which discourages duplicates. This is a no-op when the template
        has no non-empty ``examples`` list.

        Args:
            example: A pydantic model, a dict, or any other value (rendered
                with ``str``).
        """
        if not (self.template and self.template.examples):
            return

        if is_basemodel_instance(example):
            model = cast(BaseModel, example)
            # ``.dict()`` is deprecated on pydantic v2 models. Use
            # ``model_dump`` when available, and fall back for models that
            # lack it (e.g. pydantic.v1).
            dump = getattr(model, "model_dump", None)
            formatted_example = self._format_dict_to_string(
                dump() if callable(dump) else model.dict()
            )
        elif isinstance(example, dict):
            formatted_example = self._format_dict_to_string(example)
        else:
            formatted_example = str(example)

        self.template.examples.pop(0)
        self.template.examples.append({self.example_input_key: formatted_example})

    def generate(self, subject: str, runs: int, *args: Any, **kwargs: Any) -> List[str]:
        """Generate synthetic data using the given subject string.

        Runs are sequential. After each run the output is rotated into the
        few-shot examples, which discourages duplicate outputs.

        Args:
            subject (str): The subject the synthetic data will be about.
            runs (int): Number of times to generate the data.
            *args: Positional inputs forwarded to ``llm_chain.run``. Note that
                ``subject`` is always passed by keyword, and ``Chain.run``
                generally rejects mixing positional and keyword inputs.
            **kwargs: Extra prompt inputs forwarded to ``llm_chain.run``, e.g.
                ``extra="..."`` for steerability if the template uses it.

        Returns:
            List[str]: The generator's ``results`` list. It holds this call's
            outputs appended after those of any earlier calls.

        Raises:
            ValueError: If ``llm_chain`` is not set.

        Usage Example:
            >>> results = generator.generate(subject="climate change", runs=5,
            extra="Focus on environmental impacts.")
        """
        chain = self._require_llm_chain()
        for _ in range(runs):
            result = chain.run(*args, subject=subject, **kwargs)
            self.results.append(result)
            self._update_examples(result)
        return self.results

    async def agenerate(
        self, subject: str, runs: int, extra: str = "", *args: Any, **kwargs: Any
    ) -> List[str]:
        """Generate synthetic data using the given subject asynchronously.

        Note: Since the LLM calls run concurrently, the few-shot examples are
        not updated between runs. You may get fewer duplicates by adding
        specific instructions to the "extra" keyword argument.

        Args:
            subject (str): The subject the synthetic data will be about.
            runs (int): Number of times to generate the data asynchronously.
            extra (str): Extra instructions for steerability in data generation.
            *args: Positional inputs forwarded to ``llm_chain.arun``. The
                same caveat as in ``generate`` applies.
            **kwargs: Extra prompt inputs forwarded to ``llm_chain.arun``.

        Returns:
            List[str]: The generator's ``results`` list. It holds this call's
            outputs, in completion order, appended after those of earlier calls.

        Raises:
            ValueError: If ``llm_chain`` is not set.

        Usage Example:
            >>> results = await generator.agenerate(subject="climate change", runs=5,
            extra="Focus on env impacts.")
        """
        chain = self._require_llm_chain()

        async def run_chain() -> None:
            # Forward the caller's extra inputs. Previously they were dropped.
            result = await chain.arun(*args, subject=subject, extra=extra, **kwargs)
            self.results.append(result)

        await asyncio.gather(*(run_chain() for _ in range(runs)))
        return self.results
