"""Chain for applying self-critique using the SmartGPT workflow."""

from typing import Any, Dict, List, Optional, Tuple, Type

from langchain.base_language import BaseLanguageModel
from langchain.chains.base import Chain
from langchain.input import get_colored_text
from langchain.schema import LLMResult, PromptValue
from langchain_core.callbacks.manager import CallbackManagerForChainRun
from langchain_core.prompts.base import BasePromptTemplate
from langchain_core.prompts.chat import (
    AIMessagePromptTemplate,
    BaseMessagePromptTemplate,
    ChatPromptTemplate,
    HumanMessagePromptTemplate,
)
from pydantic import ConfigDict, Field, model_validator


class SmartLLMChain(Chain):
    """Chain for applying self-critique using the SmartGPT workflow.

    See details at https://youtu.be/wVzuvf9D9BU

    Instead of simply passing the prompt to the LLM (as a basic LLMChain does),
    a SmartLLMChain performs these 3 steps:
    1. Ideate: Pass the user prompt to an ideation LLM n_ideas times,
       each result is an "idea"
    2. Critique: Pass the ideas to a critique LLM which looks for flaws in the ideas
       & picks the best one
    3. Resolve: Pass the critique to a resolver LLM which improves upon the best idea
       & outputs only the (improved version of) the best output

    In total, each run of a SmartLLMChain makes n_ideas + 2 LLM calls.

    Note that SmartLLMChain will only improve results (compared to a basic LLMChain)
    when the underlying models have the capability for reflection, which smaller
    models often don't.

    Finally, a SmartLLMChain assumes that each underlying LLM outputs exactly 1 result.
    """

    class SmartLLMChainHistory:
        """Intermediate state (question, ideas, critique) of a single chain run."""

        question: str
        ideas: List[str]
        critique: str

        def __init__(
            self,
            question: str = "",
            ideas: Optional[List[str]] = None,
            critique: str = "",
        ) -> None:
            self.question = question
            # A fresh list per instance: a class-level `[]` default would be
            # shared by every history object.
            self.ideas = [] if ideas is None else list(ideas)
            self.critique = critique

        @property
        def n_ideas(self) -> int:
            return len(self.ideas)

        def ideation_prompt_inputs(self) -> Dict[str, Any]:
            return {"question": self.question}

        def critique_prompt_inputs(self) -> Dict[str, Any]:
            return {
                "question": self.question,
                **{f"idea_{i+1}": idea for i, idea in enumerate(self.ideas)},
            }

        def resolve_prompt_inputs(self) -> Dict[str, Any]:
            return {
                "question": self.question,
                **{f"idea_{i+1}": idea for i, idea in enumerate(self.ideas)},
                "critique": self.critique,
            }

    prompt: BasePromptTemplate
    """Prompt object to use."""
    output_key: str = "resolution"
    """Output key under which the final (resolved) answer is returned."""
    ideation_llm: Optional[BaseLanguageModel] = None
    """LLM to use in ideation step. If None given, 'llm' will be used."""
    critique_llm: Optional[BaseLanguageModel] = None
    """LLM to use in critique step. If None given, 'llm' will be used."""
    resolver_llm: Optional[BaseLanguageModel] = None
    """LLM to use in resolve step. If None given, 'llm' will be used."""
    llm: Optional[BaseLanguageModel] = None
    """LLM to use for every step that has no step-specific LLM."""
    n_ideas: int = 3
    """Number of ideas to generate in idea step"""
    return_intermediate_steps: bool = False
    """Whether to return ideas and critique, in addition to resolution."""
    history: SmartLLMChainHistory = Field(default_factory=SmartLLMChainHistory)
    """State of the most recently started run (kept for inspection only)."""

    model_config = ConfigDict(
        extra="forbid",
    )

    @model_validator(mode="before")
    @classmethod
    def validate_inputs(cls, values: Dict[str, Any]) -> Any:
        """Ensure we have an LLM for each step and that `llm` is not given unused.

        Raises:
            ValueError: if a step has neither its step-specific LLM nor `llm`,
                or if all three step-specific LLMs are given together with `llm`.
        """
        step_llm_keys = ("ideation_llm", "critique_llm", "resolver_llm")
        llm = values.get("llm")

        for step_llm_key in step_llm_keys:
            if not llm and not values.get(step_llm_key):
                raise ValueError(
                    f"Either {step_llm_key} or llm needs to be given. Pass llm, "
                    "if you want to use the same llm for all steps, or pass "
                    "ideation_llm, critique_llm and resolver_llm if you want "
                    "to use different llms for each step."
                )
        if llm and all(values.get(key) for key in step_llm_keys):
            raise ValueError(
                "LLMs are given for each step (ideation_llm, critique_llm,"
                " resolver_llm), but backup LLM (llm) is also given, which"
                " would not be used."
            )
        return values

    @property
    def input_keys(self) -> List[str]:
        """Defines the input keys."""
        return self.prompt.input_variables

    @property
    def output_keys(self) -> List[str]:
        """Defines the output keys."""
        if self.return_intermediate_steps:
            return ["ideas", "critique", self.output_key]
        return [self.output_key]

    def prep_prompts(
        self,
        inputs: Dict[str, Any],
        run_manager: Optional[CallbackManagerForChainRun] = None,
    ) -> Tuple[PromptValue, Optional[List[str]]]:
        """Format the user prompt from `inputs`.

        Returns:
            The formatted prompt, and the stop sequence found under
            `inputs["stop"]` (None when that key is absent).
        """
        stop = inputs.get("stop")
        selected_inputs = {k: inputs[k] for k in self.prompt.input_variables}
        prompt = self.prompt.format_prompt(**selected_inputs)
        if run_manager:
            _colored_text = get_colored_text(prompt.to_string(), "green")
            run_manager.on_text(
                "Prompt after formatting:\n" + _colored_text,
                end="\n",
                verbose=self.verbose,
            )
        return prompt, stop

    def _call(
        self,
        input_list: Dict[str, Any],
        run_manager: Optional[CallbackManagerForChainRun] = None,
    ) -> Dict[str, Any]:
        prompt, stop = self.prep_prompts(input_list, run_manager=run_manager)
        # Each run works on its own history object, so concurrent runs of this
        # chain (e.g. via `batch`) cannot overwrite each other's ideas/critique.
        # It is also published on `self.history` for callers that inspect it.
        history = self.SmartLLMChainHistory(question=prompt.to_string())
        self.history = history

        ideas = self._ideate(stop, run_manager, history=history)
        history.ideas = ideas
        critique = self._critique(stop, run_manager, history=history)
        history.critique = critique
        resolution = self._resolve(stop, run_manager, history=history)

        if self.return_intermediate_steps:
            return {"ideas": ideas, "critique": critique, self.output_key: resolution}
        return {self.output_key: resolution}

    def _get_text_from_llm_result(self, result: LLMResult, step: str) -> str:
        """Between steps, only the LLM result text is passed, not the LLMResult object.
        This function extracts the text from an LLMResult.

        Raises:
            ValueError: if the result does not hold exactly one generation list
                containing exactly one generation.
        """
        if len(result.generations) != 1:
            raise ValueError(
                f"In SmartLLM the LLM result in step {step} is not "
                "exactly 1 element. This should never happen"
            )
        n_outputs = len(result.generations[0])
        if n_outputs != 1:
            raise ValueError(
                f"In SmartLLM the LLM in step {step} returned {n_outputs} "
                "outputs instead of 1. SmartLLM only works with LLMs returning "
                "exactly 1 output."
            )
        return result.generations[0][0].text

    def get_prompt_strings(
        self, stage: str
    ) -> List[Tuple[Type[BaseMessagePromptTemplate], str]]:
        """Return the (message role, template) pairs for `stage`.

        Each stage's prompt extends the previous one: ideation is the question,
        critique adds the ideas plus the researcher instruction, and resolve
        adds the critique plus the resolver instruction.

        Raises:
            ValueError: if `stage` is not 'ideation', 'critique' or 'resolve'.
        """
        role_strings: List[Tuple[Type[BaseMessagePromptTemplate], str]] = [
            (
                HumanMessagePromptTemplate,
                "Question: {question}\nAnswer: Let's work this out in a step by "
                "step way to be sure we have the right answer:",
            )
        ]
        if stage == "ideation":
            return role_strings
        role_strings.extend(
            [
                *[
                    (AIMessagePromptTemplate, f"Idea {i + 1}: {{idea_{i + 1}}}")
                    for i in range(self.n_ideas)
                ],
                (
                    HumanMessagePromptTemplate,
                    "You are a researcher tasked with investigating the "
                    f"{self.n_ideas} response options provided. List the flaws and "
                    "faulty logic of each answer option. Let's work this out in a step"
                    " by step way to be sure we have all the errors:",
                ),
            ]
        )
        if stage == "critique":
            return role_strings
        role_strings.extend(
            [
                (AIMessagePromptTemplate, "Critique: {critique}"),
                (
                    HumanMessagePromptTemplate,
                    "You are a resolver tasked with 1) finding which of "
                    f"the {self.n_ideas} answer options the researcher thought was  "
                    "best, 2) improving that answer and 3) printing the answer in "
                    "full. Don't output anything for step 1 or 2, only the full "
                    "answer in 3. Let's work this out in a step by step way to "
                    "be sure we have the right answer:",
                ),
            ]
        )
        if stage == "resolve":
            return role_strings
        raise ValueError(
            "stage should be either 'ideation', 'critique' or 'resolve',"
            f" but it is '{stage}'. This should never happen."
        )

    def ideation_prompt(self) -> ChatPromptTemplate:
        return ChatPromptTemplate.from_strings(self.get_prompt_strings("ideation"))

    def critique_prompt(self) -> ChatPromptTemplate:
        return ChatPromptTemplate.from_strings(self.get_prompt_strings("critique"))

    def resolve_prompt(self) -> ChatPromptTemplate:
        return ChatPromptTemplate.from_strings(self.get_prompt_strings("resolve"))

    def _generate_step_text(
        self,
        step_llm: Optional[BaseLanguageModel],
        prompt: PromptValue,
        step: str,
        stop: Optional[List[str]],
        run_manager: Optional[CallbackManagerForChainRun],
    ) -> str:
        """Run one LLM call for `step` (falling back to `self.llm`); return its text."""
        llm = step_llm if step_llm else self.llm
        if not llm:
            raise ValueError("llm is none, which should never happen")
        # get_child() nests the LLM run under this chain run (parent run id,
        # tags, metadata). Passing raw handlers would detach it.
        callbacks = run_manager.get_child() if run_manager else None
        result = llm.generate_prompt([prompt], stop=stop, callbacks=callbacks)
        return self._get_text_from_llm_result(result, step=step)

    def _log_step_text(
        self,
        run_manager: Optional[CallbackManagerForChainRun],
        label: str,
        text: str,
        color: str,
    ) -> None:
        """Emit `label:` followed by the colored step output, if a run manager exists."""
        if run_manager:
            _text = f"{label}:\n" + get_colored_text(text, color)
            run_manager.on_text(_text, end="\n", verbose=self.verbose)

    def _ideate(
        self,
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForChainRun] = None,
        history: Optional[SmartLLMChainHistory] = None,
    ) -> List[str]:
        """Generate n_ideas ideas as response to user prompt.

        `history` holds the run state; it defaults to `self.history`.
        """
        history = self.history if history is None else history
        prompt = self.ideation_prompt().format_prompt(
            **history.ideation_prompt_inputs()
        )
        ideas = [
            self._generate_step_text(
                self.ideation_llm, prompt, "ideate", stop, run_manager
            )
            for _ in range(self.n_ideas)
        ]
        for i, idea in enumerate(ideas):
            self._log_step_text(run_manager, f"Idea {i+1}", idea, "blue")
        return ideas

    def _critique(
        self,
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForChainRun] = None,
        history: Optional[SmartLLMChainHistory] = None,
    ) -> str:
        """Critique each of the ideas from ideation stage & select best one.

        `history` holds the run state; it defaults to `self.history`.
        """
        history = self.history if history is None else history
        prompt = self.critique_prompt().format_prompt(
            **history.critique_prompt_inputs()
        )
        critique = self._generate_step_text(
            self.critique_llm, prompt, "critique", stop, run_manager
        )
        self._log_step_text(run_manager, "Critique", critique, "yellow")
        return critique

    def _resolve(
        self,
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForChainRun] = None,
        history: Optional[SmartLLMChainHistory] = None,
    ) -> str:
        """Improve upon the best idea as chosen in critique step & return it.

        `history` holds the run state; it defaults to `self.history`.
        """
        history = self.history if history is None else history
        prompt = self.resolve_prompt().format_prompt(
            **history.resolve_prompt_inputs()
        )
        resolution = self._generate_step_text(
            self.resolver_llm, prompt, "resolve", stop, run_manager
        )
        self._log_step_text(run_manager, "Resolution", resolution, "green")
        return resolution
