# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import time
import pickle
import os
import logging
from multiprocessing.pool import ThreadPool
import threading
import _thread
from queue import Queue
import traceback
import datetime

import numpy as np
import faiss

from faiss.contrib.inspect_tools import get_invlist


class BigBatchSearcher:
    """
    Object that manages all the data related to the computation
    except the actual within-bucket matching and the organization of the
    computation (parallel or not)
    """

    def __init__(
            self,
            index, xq, k,
            verbose=0,
            use_float16=False):

        # verbosity
        self.verbose = verbose
        self.tictoc = None  # (name, t0) of the section being timed

        self.xq = xq
        self.index = index
        self.use_float16 = use_float16
        keep_max = faiss.is_similarity_metric(index.metric_type)
        self.rh = faiss.ResultHeap(len(xq), k, keep_max=keep_max)
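        # 6 accumulated timers, reported by report(): prep queries,
        # prep bucket, compute, add results, wait in, wait out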
        self.t_accu = [0] * 6
        self.t_display = self.t0 = time.time()

    def start_t_accu(self):
        self.t_accu_t0 = time.time()

    def stop_t_accu(self, n):
        self.t_accu[n] += time.time() - self.t_accu_t0

    def tic(self, name):
        self.tictoc = (name, time.time())
        if self.verbose > 0:
            print(name, end="\r", flush=True)

    def toc(self):
        name, t0 = self.tictoc
        dt = time.time() - t0
        if self.verbose > 0:
            print(f"{name}: {dt:.3f} s")
        return dt

    def report(self, l):
        if self.verbose == 1 or (
            self.verbose == 2 and (
                l > 1000 and time.time() < self.t_display + 1.0
            )
        ):
            return
        t = time.time() - self.t0
        print(
            f"[{t:.1f} s] list {l}/{self.index.nlist} "
            f"times prep q {self.t_accu[0]:.3f} prep b {self.t_accu[1]:.3f} "
            f"comp {self.t_accu[2]:.3f} res {self.t_accu[3]:.3f} "
            f"wait in {self.t_accu[4]:.3f} "
            f"wait out {self.t_accu[5]:.3f} "
            f"eta {datetime.timedelta(seconds=t*self.index.nlist/(l+1)-t)} "
            f"mem {faiss.get_mem_usage_kb()}",
             end="\r" if self.verbose <= 2 else "\n",
             flush=True,
        )
        self.t_display = time.time()

    def coarse_quantization(self):
        self.tic("coarse quantization")
        bs = 65536
        nq = len(self.xq)
        q_assign = np.empty((nq, self.index.nprobe), dtype='int32')
        for i0 in range(0, nq, bs):
            i1 = min(nq, i0 + bs)
            q_dis_i, q_assign_i = self.index.quantizer.search(
                self.xq[i0:i1], self.index.nprobe)
            # q_dis[i0:i1] = q_dis_i
            q_assign[i0:i1] = q_assign_i
        self.toc()
        self.q_assign = q_assign

    def reorder_assign(self):
        self.tic("bucket sort")
        q_assign = self.q_assign
        q_assign += 1   # move -1 -> 0
        self.bucket_lims = faiss.matrix_bucket_sort_inplace(
            self.q_assign, nbucket=self.index.nlist + 1, nt=16)
        self.query_ids = self.q_assign.ravel()
        if self.verbose > 0:
            print('  number of -1s:', self.bucket_lims[1])
        self.bucket_lims = self.bucket_lims[1:]  # shift back to ignore -1s
        del self.q_assign   # inplace so let's forget about the old version...
        self.toc()

    def prepare_bucket(self, l):
        """ prepare the queries and database items for bucket l"""
        t0 = time.time()
        index = self.index
        # prepare queries
        i0, i1 = self.bucket_lims[l], self.bucket_lims[l + 1]
        q_subset = self.query_ids[i0:i1]
        xq_l = self.xq[q_subset]
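        # in residual mode (e.g. IVFPQ with by_residual), the stored codes
        # encode x - centroid(l), so shift the queries by the centroid of
        # list l before matching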
        if self.by_residual:
            xq_l = xq_l - index.quantizer.reconstruct(l)
        t1 = time.time()
        # prepare database side
        list_ids, xb_l = get_invlist(index.invlists, l)

        if self.decode_func is None:
            xb_l = xb_l.ravel()
        else:
            xb_l = self.decode_func(xb_l)

        if self.use_float16:
            xb_l = xb_l.astype('float16')
            xq_l = xq_l.astype('float16')

        t2 = time.time()
        self.t_accu[0] += t1 - t0
        self.t_accu[1] += t2 - t1
        return q_subset, xq_l, list_ids, xb_l

    def add_results_to_heap(self, q_subset, D, list_ids, I):
        """add the bucket results to the heap structure"""
        if D is None:
            return
        t0 = time.time()
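        # I is None when D covers the whole inverted list (as with the
        # pairwise_distances method): the invlist ids are used directly;
        # otherwise I holds offsets into the invlist that must be translated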
        if I is None:
            I = list_ids
        else:
            I = list_ids[I]
        self.rh.add_result_subset(q_subset, D, I)
        self.t_accu[3] += time.time() - t0

    def sizes_in_checkpoint(self):
        return (self.xq.shape, self.index.nprobe, self.index.nlist)

    def write_checkpoint(self, fname, completed):
        # write to temp file then move to final file
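        # (os.replace is atomic within a filesystem, so a crash cannot
        # leave a half-written checkpoint behind)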
        tmpname = fname + ".tmp"
        with open(tmpname, "wb") as f:
            pickle.dump(
                {
                    "sizes": self.sizes_in_checkpoint(),
                    "completed": completed,
                    "rh": (self.rh.D, self.rh.I),
                }, f, -1)
        os.replace(tmpname, fname)

    def read_checkpoint(self, fname):
        with open(fname, "rb") as f:
            ckp = pickle.load(f)
        assert ckp["sizes"] == self.sizes_in_checkpoint()
        self.rh.D[:] = ckp["rh"][0]
        self.rh.I[:] = ckp["rh"][1]
        return ckp["completed"]


class BlockComputer:
    """ computation within one bucket """

    def __init__(
            self,
            index,
            method="knn_function",
            pairwise_distances=faiss.pairwise_distances,
            knn=faiss.knn):

        self.index = index
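        # index_help is a flat (non-IVF) index that shares the codec of the
        # IVF index: it decodes invlist codes and, for method="index",
        # searches them directly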
        if index.__class__ == faiss.IndexIVFFlat:
            index_help = faiss.IndexFlat(index.d, index.metric_type)
            decode_func = lambda x: x.view("float32")
            by_residual = False
        elif index.__class__ == faiss.IndexIVFPQ:
            index_help = faiss.IndexPQ(
                index.d, index.pq.M, index.pq.nbits, index.metric_type)
            index_help.pq = index.pq
            decode_func = index_help.pq.decode
            index_help.is_trained = True
            by_residual = index.by_residual
        elif index.__class__ == faiss.IndexIVFScalarQuantizer:
            index_help = faiss.IndexScalarQuantizer(
                index.d, index.sq.qtype, index.metric_type)
            index_help.sq = index.sq
            decode_func = index_help.sq.decode
            index_help.is_trained = True
            by_residual = index.by_residual
        else:
            raise RuntimeError(f"index type {index.__class__} not supported")
        self.index_help = index_help
        self.decode_func = None if method == "index" else decode_func
        self.by_residual = by_residual
        self.method = method
        self.pairwise_distances = pairwise_distances
        self.knn = knn

    def block_search(self, xq_l, xb_l, list_ids, k, **extra_args):
        metric_type = self.index.metric_type
        if xq_l.size == 0 or xb_l.size == 0:
            D = I = None
        elif self.method == "index":
            faiss.copy_array_to_vector(xb_l, self.index_help.codes)
            self.index_help.ntotal = len(list_ids)
            D, I = self.index_help.search(xq_l, k)
        elif self.method == "pairwise_distances":
            # TODO implement blockwise to avoid mem blowup
            D = self.pairwise_distances(xq_l, xb_l, metric=metric_type)
            I = None
        elif self.method == "knn_function":
            D, I = self.knn(xq_l, xb_l, k, metric=metric_type, **extra_args)

        return D, I

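# A possible drop-in for the `knn` argument of big_batch_search below, kept
# as a hedged sketch in comments (the name numpy_knn is illustrative, not
# part of this module): a brute-force NumPy k-NN with the same calling
# convention as faiss.knn. The **kwargs absorb the thread_id keyword that
# block_search adds when computation_threads > 1.
#
#   def numpy_knn(xq, xb, k, metric=faiss.METRIC_L2, **kwargs):
#       xq = np.ascontiguousarray(xq, dtype='float32')
#       xb = np.ascontiguousarray(xb, dtype='float32')
#       k = min(k, len(xb))  # buckets may hold fewer than k vectors
#       if metric == faiss.METRIC_INNER_PRODUCT:
#           D = xq @ xb.T  # higher is better
#           I = np.argsort(-D, axis=1)[:, :k]
#       else:  # squared L2, as returned by faiss
#           D = ((xq ** 2).sum(1)[:, None] - 2 * xq @ xb.T
#                + (xb ** 2).sum(1))
#           I = np.argsort(D, axis=1)[:, :k]
#       return np.take_along_axis(D, I, axis=1), I
#
#   D, I = big_batch_search(index, xq, k, method="knn_function",
#                           knn=numpy_knn)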

def big_batch_search(
        index, xq, k,
        method="knn_function",
        pairwise_distances=faiss.pairwise_distances,
        knn=faiss.knn,
        verbose=0,
        threaded=0,
        use_float16=False,
        prefetch_threads=1,
        computation_threads=1,
        q_assign=None,
        checkpoint=None,
        checkpoint_freq=7200,
        start_list=0,
        end_list=None,
        crash_at=-1
        ):
    """
    Search queries xq in the IVF index, with a search function that collects
    batches of query vectors per inverted list. This can be faster than the
    regular IVF search function.
    Supports IVFFlat, IVFPQ and IVFScalarQuantizer.

    Supports three computation methods:
    method = "index":
        build a flat index and populate it separately for each inverted list
    method = "pairwise_distances":
        decompress codes and compute all pairwise distances for the queries
        and index and add result to heap
    method = "knn_function":
        decompress codes and compute knn results for the queries

    threaded=0: sequential execution
    threaded=1: prefetch next bucket while computing the current one
    threaded=2 (or more): run a producer-consumer pipeline where
        prefetch_threads threads prepare buckets in parallel.

    computation_threads>1: the knn function will get an additional thread_id
        keyword argument that tells which worker should handle this bucket.

    In threaded mode, the computation is overlapped with the bucket
    preparation and the write-back of results (useful to maximize GPU
    utilization).

    use_float16: convert all matrices to float16 (faster for GPU gemm)

    q_assign: override the coarse assignment; should be a matrix of shape
        (nq, nprobe)

    checkpointing (only for threaded > 1):
    checkpoint: file where the checkpoints are stored
    checkpoint_freq: interval between checkpoints, in seconds

    start_list, end_list: process only a subset of invlists

    crash_at: deliberately crash when reaching this inverted list, to test
        checkpoint recovery (disabled when negative)
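
    Example (a minimal sketch; the dataset and the parameter values below
    are illustrative):

        import numpy as np
        import faiss
        from faiss.contrib.big_batch_search import big_batch_search

        xb = np.random.rand(100_000, 64).astype('float32')
        xq = np.random.rand(10_000, 64).astype('float32')
        index = faiss.index_factory(64, "IVF1024,Flat")
        index.train(xb)
        index.add(xb)
        index.nprobe = 16
        D, I = big_batch_search(index, xq, k=10, threaded=1, verbose=1)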
    """
    nprobe = index.nprobe

    assert method in ("index", "pairwise_distances", "knn_function")

    mem_queries = xq.nbytes
    mem_assign = len(xq) * nprobe * np.dtype('int32').itemsize
    mem_res = len(xq) * k * (
        np.dtype('int64').itemsize
        + np.dtype('float32').itemsize
    )
    mem_tot = mem_queries + mem_assign + mem_res
    if verbose > 0:
        logging.info(
            f"memory: queries {mem_queries} assign {mem_assign} "
            f"result {mem_res} total {mem_tot} = {mem_tot / (1<<30):.3f} GiB"
        )

    bbs = BigBatchSearcher(
        index, xq, k,
        verbose=verbose,
        use_float16=use_float16
    )

    comp = BlockComputer(
        index,
        method=method,
        pairwise_distances=pairwise_distances,
        knn=knn
    )

    bbs.decode_func = comp.decode_func
    bbs.by_residual = comp.by_residual
    if q_assign is None:
        bbs.coarse_quantization()
    else:
        bbs.q_assign = q_assign
    bbs.reorder_assign()

    if end_list is None:
        end_list = index.nlist

    completed = set()
    if checkpoint is not None:
        assert (start_list, end_list) == (0, index.nlist), \
            "checkpointing is only supported for the full range of invlists"
        if os.path.exists(checkpoint):
            logging.info(f"recovering checkpoint: {checkpoint}")
            completed = bbs.read_checkpoint(checkpoint)
            logging.info(f"   already completed: {len(completed)}")
        else:
            logging.info("no checkpoint: starting from scratch")

    if threaded == 0:
        # simple sequential version

        for l in range(start_list, end_list):
            bbs.report(l)
            q_subset, xq_l, list_ids, xb_l = bbs.prepare_bucket(l)
            t0i = time.time()
            D, I = comp.block_search(xq_l, xb_l, list_ids, k)
            bbs.t_accu[2] += time.time() - t0i
            bbs.add_results_to_heap(q_subset, D, list_ids, I)

    elif threaded == 1:

        # parallel version with granularity 1
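        # While the main thread runs block_search on the current bucket, a
        # single worker thread pushes the previous bucket's results to the
        # heap and prefetches the next bucket.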

        def add_results_and_prefetch(to_add, l):
            """ perform the addition for the previous bucket and
            prefetch the next (if applicable) """
            if to_add is not None:
                bbs.add_results_to_heap(*to_add)
            if l < end_list:
                return bbs.prepare_bucket(l)

        prefetched_bucket = bbs.prepare_bucket(start_list)
        to_add = None
        pool = ThreadPool(1)

        for l in range(start_list, end_list):
            bbs.report(l)
            prefetched_bucket_a = pool.apply_async(
                add_results_and_prefetch, (to_add, l + 1))
            q_subset, xq_l, list_ids, xb_l = prefetched_bucket
            bbs.start_t_accu()
            D, I = comp.block_search(xq_l, xb_l, list_ids, k)
            bbs.stop_t_accu(2)
            to_add = q_subset, D, list_ids, I
            bbs.start_t_accu()
            prefetched_bucket = prefetched_bucket_a.get()
            bbs.stop_t_accu(4)

        bbs.add_results_to_heap(*to_add)
        pool.close()
    else:

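        # Pipeline used for threaded >= 2: prefetch_threads preparer threads
        # feed prepared buckets into prepare_to_compute_queue;
        # computation_threads compute threads run block_search and push
        # results into compute_to_main_queue; the main thread merges the
        # results into the heap and writes checkpoints. A None value on a
        # queue signals the end of the stream.
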
        def task_manager_thread(
            task,
            pool_size,
            start_task,
            end_task,
            completed,
            output_queue,
            input_queue,
        ):
            try:
                with ThreadPool(pool_size) as pool:
                    res = [pool.apply_async(
                        task,
                        args=(i, output_queue, input_queue))
                        for i in range(start_task, end_task)
                        if i not in completed]
                    for r in res:
                        r.get()
                    pool.close()
                    pool.join()
                output_queue.put(None)
            except BaseException:
                traceback.print_exc()
                _thread.interrupt_main()
                raise

        def task_manager(*args):
            task_manager = threading.Thread(
                target=task_manager_thread,
                args=args,
            )
            task_manager.daemon = True
            task_manager.start()
            return task_manager

        def prepare_task(task_id, output_queue, input_queue=None):
            try:
                logging.info(f"Prepare start: {task_id}")
                q_subset, xq_l, list_ids, xb_l = bbs.prepare_bucket(task_id)
                output_queue.put((task_id, q_subset, xq_l, list_ids, xb_l))
                logging.info(f"Prepare end: {task_id}")
            except BaseException:
                traceback.print_exc()
                _thread.interrupt_main()
                raise

        def compute_task(task_id, output_queue, input_queue):
            try:
                logging.info(f"Compute start: {task_id}")
                t_wait_out = 0
                while True:
                    t0 = time.time()
                    logging.info(f'Compute input: task {task_id}')
                    input_value = input_queue.get()
                    t_wait_in = time.time() - t0
                    if input_value is None:
                        # signal for other compute tasks
                        input_queue.put(None)
                        break
                    centroid, q_subset, xq_l, list_ids, xb_l = input_value
                    logging.info(f'Compute work: task {task_id}, centroid {centroid}')
                    t0 = time.time()
                    if computation_threads > 1:
                        D, I = comp.block_search(
                            xq_l, xb_l, list_ids, k, thread_id=task_id
                        )
                    else:
                        D, I = comp.block_search(xq_l, xb_l, list_ids, k)
                    t_compute = time.time() - t0
                    logging.info(f'Compute output: task {task_id}, centroid {centroid}')
                    t0 = time.time()
                    output_queue.put(
                        (centroid, t_wait_in, t_wait_out, t_compute, q_subset, D, list_ids, I)
                    )
                    t_wait_out = time.time() - t0
                logging.info(f"Compute end: {task_id}")
            except BaseException:
                traceback.print_exc()
                _thread.interrupt_main()
                raise

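        # bounded queues (maxsize 2) provide backpressure, so that bucket
        # preparation cannot run arbitrarily far ahead of computation, nor
        # computation ahead of the result merging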
        prepare_to_compute_queue = Queue(2)
        compute_to_main_queue = Queue(2)
        compute_task_manager = task_manager(
            compute_task,
            computation_threads,
            0,
            computation_threads,
            set(),
            compute_to_main_queue,
            prepare_to_compute_queue,
        )
        prepare_task_manager = task_manager(
            prepare_task,
            prefetch_threads,
            start_list,
            end_list,
            completed,
            prepare_to_compute_queue,
            None,
        )

        t_checkpoint = time.time()
        while True:
            logging.info("Waiting for result")
            value = compute_to_main_queue.get()
            if value is None:
                break
            centroid, t_wait_in, t_wait_out, t_compute, q_subset, D, list_ids, I = value
            # deliberate crash to test checkpoint recovery (see crash_at)
            if centroid == crash_at:
                1 / 0
            bbs.t_accu[2] += t_compute
            bbs.t_accu[4] += t_wait_in
            bbs.t_accu[5] += t_wait_out
            logging.info(f"Adding to heap start: centroid {centroid}")
            bbs.add_results_to_heap(q_subset, D, list_ids, I)
            logging.info(f"Adding to heap end: centroid {centroid}")
            completed.add(centroid)
            bbs.report(centroid)
            if checkpoint is not None:
                if time.time() - t_checkpoint > checkpoint_freq:
                    logging.info("writing checkpoint")
                    bbs.write_checkpoint(checkpoint, completed)
                    t_checkpoint = time.time()

        prepare_task_manager.join()
        compute_task_manager.join()

    bbs.tic("finalize heap")
    bbs.rh.finalize()
    bbs.toc()

    return bbs.rh.D, bbs.rh.I
