# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

"""
This contrib module contains PyTorch code for quantization.
"""

import torch
import faiss
import math
from faiss.contrib.torch import clustering
# clustering.kmeans can return either torch or numpy centroids


class Quantizer:
    """
    Base class for quantizers: objects that encode n-by-d float arrays
    into n-by-code_size uint8 codes and decode them back.
    """

    def __init__(self, d, code_size):
        """
        d: dimension of vectors
        code_size: nb of bytes of the code (per vector)
        """
        self.d = d
        self.code_size = code_size

    def train(self, x):
        """
        takes an n-by-d array and performs training
        """
        pass

    def encode(self, x):
        """
        takes an n-by-d float array, encodes it to an n-by-code_size uint8 array
        """
        pass

    def decode(self, codes):
        """
        takes an n-by-code_size uint8 array, returns an n-by-d array
        """
        pass
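

# A minimal concrete example of the Quantizer interface (an illustrative
# sketch, not part of the faiss API): "quantize" by storing each vector
# verbatim as float16, so code_size = 2 * d bytes and no training is needed.
class Float16Quantizer(Quantizer):

    def __init__(self, d):
        Quantizer.__init__(self, d, 2 * d)

    def encode(self, x):
        # reinterpret the float16 buffer as raw bytes: (n, d) -> (n, 2 * d)
        return x.contiguous().half().view(torch.uint8)

    def decode(self, codes):
        # reinterpret the bytes as float16 and widen back to float32
        return codes.contiguous().view(torch.float16).float()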


class VectorQuantizer(Quantizer):

    def __init__(self, d, k):
        """
        k: number of centroids (the code of a vector is its centroid index)
        """
        code_size = int(math.ceil(math.log2(k) / 8))
        Quantizer.__init__(self, d, code_size)
        self.k = k

    def train(self, x):
        pass  # todo: train the centroids, e.g. with clustering.kmeans


class ProductQuantizer(Quantizer):
    def __init__(self, d, M, nbits):
        """
        M: number of subquantizers (d must be divisible by M)
        nbits: number of bits per subquantizer code, so each vector is
            encoded into M * nbits bits
        """
        assert d % M == 0
        assert nbits == 8  # todo: implement other nbits values
        code_size = int(math.ceil(M * nbits / 8))
        Quantizer.__init__(self, d, code_size)
        self.M = M
        self.nbits = nbits

    def train(self, x):
        nc = 2 ** self.nbits   # number of centroids per subquantizer
        sd = self.d // self.M  # dimension of each subvector
        dev = x.device
        dtype = x.dtype
        self.codebook = torch.zeros((self.M, nc, sd), device=dev, dtype=dtype)
        # run an independent k-means on each subvector slice
        for m in range(self.M):
            xsub = x[:, m * sd: (m + 1) * sd]
            data = clustering.DatasetAssign(xsub.contiguous())
            self.codebook[m] = clustering.kmeans(nc, data)

    def encode(self, x):
        sd = self.d // self.M
        codes = torch.zeros(
            (x.shape[0], self.code_size), dtype=torch.uint8, device=x.device)
        for m in range(self.M):
            xsub = x[:, m * sd: (m + 1) * sd]
            # code m is the index of the nearest centroid of subquantizer m
            _, I = faiss.knn(xsub.contiguous(), self.codebook[m], 1)
            codes[:, m] = I.ravel()
        return codes

    def decode(self, codes):
        # per-subquantizer centroid indices
        idxs = [codes[:, m].long() for m in range(self.M)]
        # look up the centroids: M tensors of shape (n, sd)
        vectors = [self.codebook[m, idxs[m], :] for m in range(self.M)]
        # concatenate the subvectors: (n, M, sd) -> (n, d)
        stacked_vectors = torch.stack(vectors, dim=1)
        return stacked_vectors.reshape(-1, self.d)
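

# A minimal usage sketch, runnable as a script. The sizes (d=64, M=8) and
# the random training data are illustrative only; like encode() above, it
# assumes faiss.knn and clustering.kmeans accept torch tensors.
if __name__ == "__main__":
    torch.manual_seed(123)
    xt = torch.rand(2000, 64)
    pq = ProductQuantizer(64, 8, 8)
    pq.train(xt)
    codes = pq.encode(xt)     # (2000, 8) uint8 codes
    x_rec = pq.decode(codes)  # (2000, 64) float reconstructions
    mse = ((xt - x_rec) ** 2).sum(1).mean().item()
    print(f"PQ reconstruction MSE: {mse:.4f}")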
