a
    ^gh                     @   s6  d Z ddlZddlZddlZddlZddlZddlZdd Zdd Z	dd Z
d	d
 Zdd Zdd Zejd"ddZd#ddZdd Zejd ZeeD ]0ZeeeZeereZeeejree qejdfddZeededd ddejddfddZeededd dejdfdd Zeed!edd dS )$a  

This is a set of function wrappers that override the default numpy versions.

Interoperability functions for pytorch and Faiss: Importing this will allow
pytorch Tensors (CPU or GPU) to be used as arguments to Faiss indexes and
other functions. Torch GPU tensors can only be used with Faiss GPU indexes.
If this is imported with a package that supports Faiss GPU, the necessary
stream synchronization with the current pytorch stream will be automatically
performed.

Numpy ndarrays can continue to be used in the Faiss python interface after
importing this file. All arguments must be uniformly either numpy ndarrays
or Torch tensors; no mixing is allowed.

    Nc                 C   s6   |   sJ | jtjksJ t|   |   S )A gets a Faiss SWIG pointer from a pytorch tensor (on CPU or GPU) )	is_contiguousdtypetorchuint8faissZcast_integer_to_uint8_ptruntyped_storagedata_ptrstorage_offsetx r   g/var/www/html/cobodadashboardai.evdpl.com/venv/lib/python3.9/site-packages/faiss/contrib/torch_utils.pyswig_ptr_from_UInt8Tensor#   s
    r   c                 C   s:   |   sJ | jtjksJ t|   |  d  S r      )	r   r   r   float16r   cast_integer_to_void_ptrr   r	   r
   r   r   r   r   swig_ptr_from_HalfTensor+   s
    r   c                 C   s:   |   sJ | jtjksJ t|   |  d  S )r      )	r   r   r   float32r   Zcast_integer_to_float_ptrr   r	   r
   r   r   r   r   swig_ptr_from_FloatTensor4   s
    r   c                 C   s:   |   sJ | jtjksJ t|   |  d  S r   )	r   r   r   bfloat16r   r   r   r	   r
   r   r   r   r   swig_ptr_from_BFloat16Tensor;   s
    r   c                 C   sD   |   sJ | jtjks&J d| j t|   |  d  S )r   dtype=%sr   )	r   r   r   int32r   Zcast_integer_to_int_ptrr   r	   r
   r   r   r   r   swig_ptr_from_IntTensorC   s
    r   c                 C   sD   |   sJ | jtjks&J d| j t|   |  d  S )r   r      )	r   r   r   int64r   Zcast_integer_to_idx_t_ptrr   r	   r
   r   r   r   r   swig_ptr_from_IndicesTensorK   s
    r   c              
   c   st   |du rt j }t|j}t j }| t j }| t j | zdV  W | || n| || 0 dS )z Creates a scoping object to make Faiss GPU use the same stream
        as pytorch, based on torch.cuda.current_stream().
        Or, a specific pytorch stream can be passed in as a second
        argument, in which case we will use that stream.
    N)	r   cudaZcurrent_streamr   Zcast_integer_to_cudastream_tZcuda_streamZcurrent_deviceZgetDefaultStreamZsetDefaultStream)resZpytorch_streamZcuda_stream_sZ	prior_devZprior_streamr   r   r   using_streamV   s    

r"   Fc                 C   sv   zt | |}W n ty,   |r&Y d S  Y n0 |jd| kr@d S |sV|jd| ksVJ t| |d | t| || d S )NZtorch_replacement_Zreplacement_Z_numpy)getattrAttributeError__name__setattr)	the_classnamereplacementignore_missingZignore_no_baseZorig_methodr   r   r   torch_replace_methodp   s    r+   c                    sR  dd }dd }d2dd}dd }d	d
  d3 fdd	}d4 fdd	}d d d fdd
}dd }d5dd}	d6dd}
dd }dd }d7dd}d8d d!}t | d"| t | d#| t | d$| t | d%| t | d&| t | d'| t | d(|	 t | d)|
 t | d*| t | d+|d,d- t | d.|d,d- t | d/|d,d- t | d0| t | d1| d S )9Nc                 S   s   t |tju r| |S t |tju s*J |j\}}|| jksBJ t|}|j	rt
| dsbJ dt|   | || W d    q1 s0    Y  n| || d S N	getDevice#GPU tensor on CPU index not allowed)typenpndarrayZ	add_numpyr   Tensorshapedr   is_cudahasattrr"   getResourcesZadd_cselfr   nr4   x_ptrr   r   r   torch_replacement_add   s    

,z1handle_torch_Index.<locals>.torch_replacement_addc                 S   s   t |tju r| ||S t |tju s,J |j\}}|| jksDJ t|}t |tju s^J |j|fksrJ dt	|}|j
rt| dsJ dt|   | ||| W d    q1 s0    Y  n| ||| d S )Nz!not same number of vectors as idsr-   r.   )r/   r0   r1   Zadd_with_ids_numpyr   r2   r3   r4   r   r   r5   r6   r"   r7   Zadd_with_ids_c)r9   r   Zidsr:   r4   r;   Zids_ptrr   r   r   torch_replacement_add_with_ids   s    
.z:handle_torch_Index.<locals>.torch_replacement_add_with_idsc                 S   s  t |tju r| |||S t |tju s.J |j\}}|| jksFJ t|}|d u rntj	|||j
tjd}n$t |tju sJ |j||fksJ t|}|jrt| dsJ dt|    | |||| W d    n1 s0    Y  n| |||| |S )Ndevicer   r-   r.   )r/   r0   r1   Zassign_numpyr   r2   r3   r4   r   emptyr?   r   r   r5   r6   r"   r7   Zassign_c)r9   r   klabelsr:   r4   r;   ZL_ptrr   r   r   torch_replacement_assign   s"    
0z4handle_torch_Index.<locals>.torch_replacement_assignc                 S   s   t |tju r| |S t |tju s*J |j\}}|| jksBJ t|}|j	rt
| dsbJ dt|   | || W d    q1 s0    Y  n| || d S r,   )r/   r0   r1   Ztrain_numpyr   r2   r3   r4   r   r5   r6   r"   r7   Ztrain_cr8   r   r   r   torch_replacement_train   s    

,z3handle_torch_Index.<locals>.torch_replacement_trainc           	      S   s   | j \}}t| }|d u r2tj||| jtjd}n$t|tju sDJ |j ||fksVJ t|}|d u r~tj||| jtjd}n$t|tju sJ |j ||fksJ t	|}|||||fS )Nr>   )
r3   r   r   r@   r?   r   r/   r2   r   r   )	r   rA   DIr:   r4   r;   D_ptrI_ptrr   r   r   search_methods_common   s    
z1handle_torch_Index.<locals>.search_methods_commonc           
         s   t |tju r | j||||dS t |tju s2J |j\}}|| jksJJ  ||||\}}}	}}|jrt	| dszJ dt
|  " | |||||	 W d    q1 s0    Y  n| |||||	 ||fS )NrE   rF   r-   r.   )r/   r0   r1   Zsearch_numpyr   r2   r3   r4   r5   r6   r"   r7   Zsearch_c)
r9   r   rA   rE   rF   r:   r4   r;   rG   rH   rI   r   r   torch_replacement_search   s    
2z4handle_torch_Index.<locals>.torch_replacement_searchc              	      s0  t |tju r"| j|||||dS t |tju s4J |j\}}|| jksLJ  ||||\}}	}
}}|d u rtj||||j	tj
d}n&t |tju sJ |j|||fksJ t|}|jrt| dsJ dt|  $ | ||||	|
| W d    n1 s0    Y  n| ||||	|
| |||fS )N)rE   rF   Rr>   r-   r.   )r/   r0   r1   Zsearch_and_reconstruct_numpyr   r2   r3   r4   r@   r?   r   r   r5   r6   r"   r7   Zsearch_and_reconstruct_c)r9   r   rA   rE   rF   rM   r:   r4   r;   rG   rH   ZR_ptrrK   r   r   (torch_replacement_search_and_reconstruct  s"    
6zDhandle_torch_Index.<locals>.torch_replacement_search_and_reconstructrJ   c                   s:  t |tju r$| j||||||dS t |tju s6J |j\}}|| jksNJ  ||||\}	}
}}}|j|| jfkszJ |	 }t
|}|d ur|	 }|j|jksJ t|}nd }|jrt| dsJ dt|  ( | ||	||||
|d W d    n1 s0    Y  n| ||	||||
|d ||fS )NrJ   r-   r.   F)r/   r0   r1   Zsearch_preassigned_numpyr   r2   r3   r4   Znprobe
contiguousr   r   r5   r6   r"   r7   Zsearch_preassigned_c)r9   r   rA   ZIqZDqrE   rF   r:   r4   r;   rG   rH   ZIq_ptrZDq_ptrrK   r   r   $torch_replacement_search_preassigned,  s(    

:z@handle_torch_Index.<locals>.torch_replacement_search_preassignedc                 S   s    t |tjusJ d| |S )Nz(remove_ids not yet implemented for torch)r/   r   r2   Zremove_ids_numpy)r9   r   r   r   r   torch_replacement_remove_idsN  s    z8handle_torch_Index.<locals>.torch_replacement_remove_idsc                 S   s   |d ur"t |tju r"| ||S td}t| drFtd|  }|d u rdtj| j	|tj
d}n$t |tju svJ |j| j	fksJ t|}|jrt| dsJ dt|   | || W d    q1 s0    Y  n| || |S )Ncpur-   r    r>   r.   )r/   r0   r1   Zreconstruct_numpyr   r?   r6   r-   r@   r4   r   r2   r3   r   r5   r"   r7   Zreconstruct_c)r9   keyr   r?   r;   r   r   r   torch_replacement_reconstructS  s     

,z9handle_torch_Index.<locals>.torch_replacement_reconstructr   c                 S   s
  |dkr| j }|d ur2t|tju r2| |||S td}t| drVtd|  }|d u rvtj	|| j
|tjd}n&t|tju sJ |j|| j
fksJ t|}|jrt| dsJ dt|   | ||| W d    n1 s0    Y  n| ||| |S )NrU   rR   r-   r    r>   r.   )Zntotalr/   r0   r1   Zreconstruct_n_numpyr   r?   r6   r-   r@   r4   r   r2   r3   r   r5   r"   r7   Zreconstruct_n_c)r9   Zn0nir   r?   r;   r   r   r   torch_replacement_reconstruct_nu  s$    

.z;handle_torch_Index.<locals>.torch_replacement_reconstruct_nc                 S   s   t |tju r| ||S t |tju s,J |j\}t|}t |tju sNJ |j|| jfksbJ t	|}|j
rt| dsJ dt|   | ||| W d    q1 s0    Y  n| ||| d S r,   )r/   r0   r1   Zupdate_vectors_numpyr   r2   r3   r   r4   r   r5   r6   r"   r7   Zupdate_vectors_c)r9   keysr   r:   Zkeys_ptrr;   r   r   r    torch_replacement_update_vectors  s    .z<handle_torch_Index.<locals>.torch_replacement_update_vectorsc                 S   s   t |tju r| ||S t |tju s,J |j\}}|| jksDJ t|}|j	rZJ dt
| drlJ dt|}| |||| tt|j|d  d}t|d }tt|j| }	tt|j| }
||	|
fS )Nz1Range search using GPU tensor not yet implementedr-   z-Range search on GPU index not yet implemented   r   rU   )r/   r0   r1   Zrange_search_numpyr   r2   r3   r4   r   r5   r6   r   ZRangeSearchResultZrange_search_cZ
from_numpyZrev_swig_ptrlimscopyZastypeintZ	distancesrB   )r9   r   Zthreshr:   r4   r;   r!   r[   ndrE   rF   r   r   r   torch_replacement_range_search  s    

"z:handle_torch_Index.<locals>.torch_replacement_range_searchc                 S   s   t |tju r| ||S t |tju s,J |j\}}|| jksDJ t|}|d u rltj	|| 
 tjd}n|j|| 
 fksJ t|}|jrt| dsJ dt|   | ||| W d    q1 s0    Y  n| ||| |S N)r   r-   r.   )r/   r0   r1   Zsa_encode_numpyr   r2   r3   r4   r   r@   sa_code_sizer   r   r5   r6   r"   r7   Zsa_encode_c)r9   r   codesr:   r4   r;   	codes_ptrr   r   r   torch_replacement_sa_encode  s     
.z7handle_torch_Index.<locals>.torch_replacement_sa_encodec                 S   s   t |tju r| ||S t |tju s,J |j\}}||  ksFJ t|}|d u rltj	|| j
tjd}n&t |tju s~J |j|| j
fksJ t|}|jrt| dsJ dt|   | ||| W d    q1 s0    Y  n| ||| |S r`   )r/   r0   r1   Zsa_decode_numpyr   r2   r3   ra   r   r@   r4   r   r   r5   r6   r"   r7   Zsa_decode_c)r9   rb   r   r:   csrc   r;   r   r   r   torch_replacement_sa_decode  s"    
.z7handle_torch_Index.<locals>.torch_replacement_sa_decodeaddZadd_with_idsZassigntrainsearchZ
remove_idsZreconstructZreconstruct_nZrange_searchZupdate_vectorsT)r*   Zsearch_and_reconstructZsearch_preassignedZ	sa_encodeZ	sa_decode)N)NN)NNN)N)r   rU   N)N)N)r+   )r'   r<   r=   rC   rD   rL   rN   rP   rQ   rT   rW   rY   r_   rd   rf   r   rK   r   handle_torch_Index   sF    
"
"
%

rj   r   c                 C   sV  t |tju r"tj| ||||dS | \}}| s:J |jtj	ksJJ |j
rXJ d|  \}}||kspJ |  s|J | jtj	ksJ | j
rJ dtj|||jtj	d}	tj|||jtjd}
t|
}t|	}t|}t| }|tjkrt|||||||| nB|tjkr2t|||||||| nt||||||||||
 |	|
fS )N)metric
metric_argzuse knn_gpu for GPU tensorsr>   )r/   r0   r1   r   Z	knn_numpysizer   r   r   r   r5   r@   r?   r   r   r   	METRIC_L2Z	knn_L2sqrZMETRIC_INNER_PRODUCTZknn_inner_productZknn_extra_metrics)xqxbrA   rk   rl   nbr4   nqd2rE   rF   rH   rG   xb_ptrxq_ptrr   r   r   torch_replacement_knn&  s@    rv   ZknnTrU   c	              
   C   s  t |tju r&t| |||||||S | \}	}
| r@d}n"|  rZ| }d}ntd|j	t
jkr~tj}t|}n@|j	t
jkrtj}t|}n$|j	t
jkrtj}t|}ntd| \}}||
ksJ | rd}n"|  r| }d}ntd|j	t
jkr$tj}t|}nD|j	t
jkrBtj}t|}n&|j	t
jkr`tj}t|}ntd|d u rt
j|||jt
jd}n&|j||fksJ |j	t
jksJ |d u rt
j|||jt
jd}n|j||fksJ |j	t
jkrtj}t|}n:|j	|j	  kr"t
jkr6n ntj}t|}ntdt|}t }||_||_|
|_ ||_!||_"||_#|	|_$||_%||_&||_'||_(||_)||_*||_+||_||_,t-|  t.| | W d    n1 s0    Y  ||fS )NTFz$matrix should be row or column-majorz'xq must be float32, float16 or bfloat16r>   zI must be i64 or i32)/r/   r0   r1   r   Zknn_gpu_numpyrm   r   t	TypeErrorr   r   r   DistanceDataType_F32r   r   DistanceDataType_F16r   r   ZDistanceDataType_BF16r   r@   r?   r3   r   ZIndicesDataType_I64r   r   ZIndicesDataType_I32r   GpuDistanceParamsrk   rA   dimsvectorsvectorsRowMajor
vectorType
numVectorsqueriesqueriesRowMajor	queryType
numQueriesoutDistancesZ
outIndicesZoutIndicesTypeuse_cuvsr"   bfKnn)r!   ro   rp   rA   rE   rF   rk   r?   r   rq   r4   xb_row_majorxb_typert   rr   rs   xq_row_majorxq_typeru   ZI_typerH   rG   argsr   r   r   torch_replacement_knn_gpuT  s    








"

,r   Zknn_gpuc                 C   s  t |tju r t| ||||S | \}}| r:d}n"|  rT| }d}ntd|j	t
jkrxtj}	t|}
n$|j	t
jkrtj}	t|}
ntd| \}}||ksJ | rd}n"|  r| }d}ntd|j	t
jkrtj}t|}n&|j	t
jkr tj}t|}ntd|d u rJt
j|||jt
jd}n&|j||fks^J |j	t
jkspJ t|}t }||_d|_||_|
|_||_|	|_||_||_||_||_||_||_ ||_t!|  t"| | W d    n1 s0    Y  |S )	NTFz'xb matrix should be row or column-majorzxb must be float32 or float16z'xq matrix should be row or column-majorzxq must be float32 or float16r>   rU   )#r/   r0   r1   r   Zpairwise_distance_gpu_numpyrm   r   rw   rx   r   r   r   ry   r   r   rz   r   r@   r?   r3   r{   rk   rA   r|   r}   r~   r   r   r   r   r   r   r   r"   r   )r!   ro   rp   rE   rk   r?   rq   r4   r   r   rt   rr   rs   r   r   ru   rG   r   r   r   r   'torch_replacement_pairwise_distance_gpu  sj    





,r   Zpairwise_distance_gpu)N)FF) __doc__r   r   
contextlibinspectsysnumpyr0   r   r   r   r   r   r   contextmanagerr"   r+   rj   modulesZfaiss_moduledirsymbolr#   objisclassr'   
issubclassIndexrn   rv   r   r   r   r   r   r   <module>   sB   	 
   



*^G