a
    bgT                     @   sD  d Z ddlZddlZddlZddlZddlZddlZddlZddlm	Z	 ddl
mZmZmZmZmZmZmZmZmZmZmZmZ ddlmZmZ ddlmZ ddlmZmZ ddlm Z m!Z!m"Z"m#Z#m$Z$m%Z% erddl&Z'ddl(m)  m*  m+Z, e-e.Z/d	Z0d
Z1dZ2ddddZ3G dd dZ4e4 Z5e4 Z6G dd de7e	Z8G dd de!Z9G dd de!Z:G dd dZ;ee7 e<eee7 eej= f dddZ>de7ddd Z?ed ee7ddf d!d"d#Z@eeAe4f ZBe7ZCG d$d% d%ZDeDZEe7ZFG d&d' d'e9e:eeEeFf ZGee7eef ZHeAZIG d(d) d)e9e:eeHeIf ZJeGZKeJZLeDZMdS )*z*A common module for NVIDIA Riva Runnables.    N)Enum)TYPE_CHECKINGAnyAsyncGeneratorAsyncIteratorDict	GeneratorIteratorListOptionalTupleUnioncast)
AnyMessageBaseMessage)PromptValue)RunnableConfigRunnableSerializable)
AnyHttpUrl	BaseModelFieldparse_obj_asroot_validator	validator      ?i  )
.!?   ¡   ¿zriva.clientreturnc               
   C   sB   zddl } W n. ty: } ztd|W Y d}~n
d}~0 0 | jS )z5Import the riva client and raise an error on failure.r   NziCould not import the NVIDIA Riva client library. Please install it with `pip install nvidia-riva-client`.)riva.clientImportErrorclient)rivaerr r(   w/var/www/html/cobodadashboardai.evdpl.com/venv/lib/python3.9/site-packages/langchain_community/utilities/nvidia_riva.py_import_riva_client1   s    r*   c                   @   s   e Zd ZdZdS )	SentinelTzAn empty Sentinel type.N)__name__
__module____qualname____doc__r(   r(   r(   r)   r+   >   s   r+   c                   @   sN   e Zd ZdZdZdZdZdZdZdZ	e
ed dd	d
ZeddddZdS )RivaAudioEncodinga  An enum of the possible choices for Riva audio encoding.

    The list of types exposed by the Riva GRPC Protobuf files can be found
    with the following commands:
    ```python
    import riva.client
    print(riva.client.AudioEncoding.keys())  # noqa: T201
    ```
    ALAWENCODING_UNSPECIFIEDFLAC
LINEAR_PCMMULAWOGGOPUS)format_coder"   c              
   C   sR   z| j | j| jd| W S  tyL } ztd| |W Y d}~n
d}~0 0 dS )zReturn the audio encoding specified by the format code in the wave file.

        ref: https://mmsp.ece.mcgill.ca/Documents/AudioFormats/WAVE/WAVE.html
        )         z>The following wave file format code is not supported by Riva: N)r4   r1   r5   KeyErrorNotImplementedError)clsr7   r'   r(   r(   r)   from_wave_format_codeX   s    z'RivaAudioEncoding.from_wave_format_codezriva.client.AudioEncodingr!   c                 C   s   t  }t|j| S )z-Returns the Riva API object for the encoding.)r*   getattrZAudioEncodingselfriva_clientr(   r(   r)   riva_pb2f   s    zRivaAudioEncoding.riva_pb2N)r,   r-   r.   r/   r1   r2   r3   r4   r5   r6   classmethodintr>   propertyrC   r(   r(   r(   r)   r0   F   s   
r0   c                   @   s   e Zd ZU dZeeddddgdZeeef e	d< eddd	Z
ee e	d
< eddddZeddddeeedddZdS )RivaAuthMixinzBConfiguration for the authentication to a Riva service connection.zhttp://localhost:50051z1The full URL where the Riva service can be found.z"https://user@pass:riva.example.com)descriptionZexamplesurlNz@A full path to the file where Riva's public ssl key can be read.rH   ssl_certzriva.client.Authr!   c                 C   sB   t  }tt| j}|jdk}t| jdd }|j| j||dS )z!Return a riva client auth object.https/   )rK   use_ssluri)	r*   r   r   rI   schemestrsplitZAuthrK   )rA   rB   rI   rO   Zurl_no_schemer(   r(   r)   authz   s    
zRivaAuthMixin.authT)preZallow_reusevalr"   c                 C   s$   t |trtttt|S tt|S )z:Do some initial conversations for the URL before checking.)
isinstancerR   r   r   r   )r=   rW   r(   r(   r)   _validate_url   s    
zRivaAuthMixin._validate_url)r,   r-   r.   r/   r   r   rI   r   rR   __annotations__rK   r   rF   rT   r   rD   r   rY   r(   r(   r(   r)   rG   m   s   

rG   c                   @   sP   e Zd ZU dZeejddZeed< edddZ	e
ed< edd	dZeed
< dS )RivaCommonConfigMixinz%A collection of common Riva settings.z!The encoding on the audio stream.)defaultrH   encodingi@  z*The sample rate frequency of audio stream.sample_rate_hertzzen-USzaThe [BCP-47 language code](https://www.rfc-editor.org/rfc/bcp/bcp47.txt) for the target language.language_codeN)r,   r-   r.   r/   r   r0   r4   r]   rZ   r^   rE   r_   rR   r(   r(   r(   r)   r[      s   
r[   c                   @   sz   e Zd ZU dZejed< ejed< ddddZdddd	Z	ddd
dZ
edddZddddZddddZdS )_Eventz3A combined event that is threadsafe and async safe._event_aeventNr!   c                 C   s   t  | _t | _dS )zInitialize the event.N)	threadingEventra   asynciorb   rA   r(   r(   r)   __init__   s    
z_Event.__init__c                 C   s   | j   | j  dS zSet the event.N)ra   setrb   rf   r(   r(   r)   ri      s    
z
_Event.setc                 C   s   | j   | j  dS rh   )ra   clearrb   rf   r(   r(   r)   rj      s    
z_Event.clearc                 C   s
   | j  S )zIndicate if the event is set.)ra   is_setrf   r(   r(   r)   rk      s    z_Event.is_setc                 C   s   | j   dS )zWait for the event to be set.N)ra   waitrf   r(   r(   r)   rl      s    z_Event.waitc                    s   | j  I dH  dS )z#Async wait for the event to be set.N)rb   rl   rf   r(   r(   r)   
async_wait   s    z_Event.async_wait)r,   r-   r.   r/   rc   rd   rZ   re   rg   ri   rj   boolrk   rl   rm   r(   r(   r(   r)   r`      s   


r`   )output_directorysample_rater"   c                 C   sr   | rnt jddd| d}|j}W d   n1 s20    Y  t|d}|d |d || ||fS d	S )
zECreate a new wave file and return the wave write object and filename.bxz.wavF)modesuffixdeletedirNwbr8   rN   )NN)tempfileNamedTemporaryFilenamewaveopenZsetnchannelsZsetsampwidthZsetframerate)ro   rp   fwav_file_namewav_filer(   r(   r)   _mk_wave_file   s    $


r   TTSInputTyperV   c                 C   s.   t | tr|  S t | tr&t| jS t| S )zAttempt to coerce the input value to a string.

    This is particularly useful for converting LangChain message to strings.
    )rX   r   Z	to_stringr   rR   contentrW   r(   r(   r)   _coerce_string   s
    


r   )inputsr"   c                 c   s   d}| D ]}t |}tD ]0}||v r||d\}}|| | V  d}qq||7 }t|tkrtdt|tD ]}|||d  V  qnd}q|r|V  dS )z9Filter the input chunks are return strings ready for TTS. r8   r      N)r   _SENTENCE_TERMINATORSrS   len_MAX_TEXT_LENGTHrange)r   bufferchunk
terminatorZlast_sentenceidxr(   r(   r)   _process_chunks   s    r   c                   @   sJ  e Zd ZU dZejed< ejed< ejed< e	ed< e	ed< e	ed< e
ej ed< d)ed
dddZeed
d
f dddZee dddZeedddZeedddZeedddZeedddZd*ee
e d
dddZd+ee
e d
dddZd,e
e d
d d!d"Zd-e
e d
d d#d$Zed% d
d&d'd(Zd
S ).AudioStreamz%A message containing streaming audio.	_put_lock_queueoutputhangupuser_talking
user_quiet_workerr   N)maxsizer"   c                 C   sD   t  | _tj|d| _t | _t | _t | _	t | _
d| _dS )zInitialize the queue.)r   N)rc   Lockr   queueQueuer   r   r`   r   r   r   r   )rA   r   r(   r(   r)   rg     s    

zAudioStream.__init__r!   c                 c   sL   z| j dt}W n tjy*   Y q Y n0 |tkr6qH|V  | j   q dS )zReturn an error.TN)r   get_QUEUE_GET_TIMEOUTr   EmptyHANGUP	task_donerA   Znext_valr(   r(   r)   __iter__  s    
zAudioStream.__iter__c                 C  s\   z"t  d| jjdtI dH }W n tjy:   Y q Y n0 |tkrFqX|V  | j	  q dS )z4Iterate through all items in the queue until HANGUP.NT)
re   get_event_looprun_in_executorr   r   r   r   r   r   r   r   r(   r(   r)   	__aiter__&  s    
zAudioStream.__aiter__c                 C   s
   | j  S )z(Indicate if the audio stream has hungup.)r   rk   rf   r(   r(   r)   hungup9  s    zAudioStream.hungupc                 C   s
   | j  S )z-Indicate in the input stream buffer is empty.)r   emptyrf   r(   r(   r)   r   >  s    zAudioStream.emptyc                 C   s4   | j o
| j}| jduo*| j  o*| j }|o2|S )z;Indicate if the audio stream has hungup and been processed.N)r   r   r   is_aliver   )rA   Z
input_doneZoutput_doner(   r(   r)   completeC  s    

zAudioStream.completec                 C   s   | j r| j  S dS )z&Indicate if the ASR stream is running.F)r   r   rf   r(   r(   r)   runningN  s    
zAudioStream.running)itemtimeoutr"   c                 C   sZ   | j @ | jrtd|tu r(| j  | jj||d W d   n1 sL0    Y  dS )zPut a new item into the queue.z?The audio stream has already been hungup. Cannot put more data.r   N)r   r   RuntimeErrorr   r   ri   r   put)rA   r   r   r(   r(   r)   r   U  s    
zAudioStream.putc                    s*   t  }t |d| j||I dH  dS )z$Async put a new item into the queue.N)re   r   wait_forr   r   )rA   r   r   loopr(   r(   r)   aput`  s    zAudioStream.aput)r   r"   c                 C   s   |  t| dS )zSend the hangup signal.N)r   r   rA   r   r(   r(   r)   closee  s    zAudioStream.closec                    s   |  t|I dH  dS )zAsync send the hangup signal.N)r   r   r   r(   r(   r)   aclosei  s    zAudioStream.aclosezrasr.StreamingRecognizeResponse)	responsesr"   c                    s^   j rtdtjddd dd fdd}tj|d	_d
j_j     dS )zIDrain the responses from the provided iterator and put them into a queue.z,An ASR instance has already been registered.rN   r   r   Nr!   c                     s       D ]x} | jsq| jD ]d}|js*q|jrdj  j  tt	|jd j
}j| qj sj  j  qqdS )zConsume the ASR Generator.r   N)rl   resultsZalternativesis_finalr   rj   r   ri   r   rR   
transcriptr   r   rk   )responseresultr   Zhas_startedr   rA   r(   r)   workert  s    




z$AudioStream.register.<locals>.worker)targetT)	r   r   rc   BarrierThreadr   daemonstartrl   )rA   r   r   r(   r   r)   registerm  s    
zAudioStream.register)r   )N)N)N)N) r,   r-   r.   r/   rc   r   rZ   r   r   r`   r   r   rE   rg   r   bytesr   r   StreamInputTyper   rF   rn   r   r   r   r   r   r   r   r   r	   r   r(   r(   r(   r)   r      s0   





r   c                   @   s   e Zd ZU dZdZeed< dZeed< edddZ	e
ed	< ed
ddZeed< ed
ddZeed< ed
deeeef eeef dddZeddddZddddZdeee eedddZdS )RivaASRzNA runnable that performs Automatic Speech Recognition (ASR) using NVIDIA Riva.Znvidia_riva_asrry   zA Runnable for converting audio bytes to a string.This is useful for feeding an audio stream into a chain andpreprocessing that audio to create an LLM prompt.rH   r8   z7The number of audio channels in the input audio stream.rJ   audio_channel_countTz\Controls whether or not Riva should attempt to filter profanity out of the transcribed text.profanity_filterz]Controls whether Riva should attempt to correct senetence puncuation in the transcribed text.enable_automatic_punctuationrU   valuesr"   c                 C   s
   t  }|S z4Validate the Python environment and input arguments.r*   r=   r   _r(   r(   r)   _validate_environment  s    zRivaASR._validate_environmentz&riva.client.StreamingRecognitionConfigr!   c                 C   s4   t  }|jd|j| j| j| jd| j| j| jddS )z)Create and return the riva config object.Tr8   )r]   r^   r   Zmax_alternativesr   r   r_   )Zinterim_resultsconfig)	r*   ZStreamingRecognitionConfigZRecognitionConfigr]   r^   r   r   r   r_   r@   r(   r(   r)   r     s    zRivaASR.configzriva.client.ASRServicec              
   C   sH   t  }z|| jW S  tyB } ztd|W Y d}~n
d}~0 0 dS );Connect to the riva service and return the a client object.z5Error raised while connecting to the Riva ASR server.N)r*   Z
ASRServicerT   	Exception
ValueErrorrA   rB   r'   r(   r(   r)   _get_service  s    zRivaASR._get_serviceNinputr   kwargsr"   c                 K   s   |j s(|  }|j|| jd}|| g }|js|jj |jjd}W d   n1 s^0    Y  |r,|j	 sz||j
 g7 }W n tjy   Y qlY n0 |j  qltdt| d| S q,dS )z3Transcribe the audio bytes into a string with Riva.)Zaudio_chunksZstreaming_configg?NzRiva ASR returning: %s r   )r   r   Zstreaming_response_generatorr   r   r   r   	not_emptyrl   r   
get_nowaitr   r   r   _LOGGERdebugreprjoinstrip)rA   r   r   r   servicer   Zfull_responsereadyr(   r(   r)   invoke  s*    

,

zRivaASR.invoke)N)r,   r-   r.   r/   ry   rR   rZ   rH   r   r   rE   r   rn   r   r   rD   r   r   r   rF   r   r   ASRInputTyper   r   ASROutputTyper   r(   r(   r(   r)   r     s8   

$ r   c                   @   s  e Zd ZU dZdZeed< dZeed< edddZ	eed	< ed
ddZ
ee ed< eddeeeef eeef dddZedeeedddZddddZd eee eedddZd!ee ee ee ee dddZd"ee ee ee eed
f dddZd
S )#RivaTTSz?A runnable that performs Text-to-Speech (TTS) with NVIDIA Riva.Znvidia_riva_ttsry   z_A tool for converting text to speech.This is useful for converting LLM output into audio bytes.rH   zEnglish-US.Female-1zThe voice model in Riva to use for speech. Pre-trained models are documented in [the Riva documentation](https://docs.nvidia.com/deeplearning/riva/user-guide/docs/tts/tts-overview.html).rJ   
voice_nameNzThe directory where all audio files should be saved. A null value indicates that wave files should not be saved. This is useful for debugging purposes.ro   Tr   r   c                 C   s
   t  }|S r   r   r   r(   r(   r)   r      s    zRivaTTS._validate_environment)vr"   c                 C   s,   |r(t |}|jddd t| S |S )NT)parentsexist_ok)pathlibPathmkdirrR   absolute)r=   r   dirpathr(   r(   r)   _output_directory_validator'  s
    
z#RivaTTS._output_directory_validatorz"riva.client.SpeechSynthesisServicer!   c              
   C   sH   t  }z|| jW S  tyB } ztd|W Y d}~n
d}~0 0 dS )r   z5Error raised while connecting to the Riva TTS server.N)r*   ZSpeechSynthesisServicerT   r   r   r   r(   r(   r)   r   0  s    zRivaTTS._get_servicer   c                 K   s   d | t|gS )zDPerform TTS by taking a string and outputting the entire audio file.    )r   	transformiter)rA   r   r   r   r(   r(   r)   r   :  s    zRivaTTS.invokec                 k   s   |   }t| j| j\}}t|D ]X}td| |j|| j| j	| j
j| jd}|D ]$}	tt|	j}
|rr||
 |
V  qTq"|r|  td| dS )zHPerform TTS by taking a stream of characters and streaming output bytes.zRiva TTS chunk: %s)textr   r_   r]   Zsample_rate_hzzRiva TTS wrote file: %sN)r   r   ro   r^   r   r   r   Zsynthesize_onliner   r_   r]   rC   r   r   audioZwriteframesrawr   )rA   r   r   r   r   r}   r~   r   r   respr   r(   r(   r)   r   C  s*    	

zRivaTTS.transformc           	        s   t  t t  ddfdd}tt dfddddfdd dd fd	d
}| }| }zt  dI dH }W n t j	j
y   Y qY n0   |tu rq|V  q|I dH  |I dH  dS )zGIntercept async transforms and route them to the synchronous transform.Nr!   c                     s,    2 z3 dH W }  |  q6  t dS )z#Produce input into the input queue.N)
put_nowait_TRANSFORM_ENDr   )r   input_queuer(   r)   	_produceru  s    z%RivaTTS.atransform.<locals>._producerc                  3   s@   z j dd} W n tjy(   Y q Y n0 | tkr4q<| V  q dS )zIterate over the input_queue.r   r   N)r   r   r   r   r   )r   r(   r)   _input_iterator{  s    
z+RivaTTS.atransform.<locals>._input_iteratorc                     s*      D ]} |  qt dS )z!Consume the input with transform.N)r   r   r   r   )r   	out_queuerA   r(   r)   	_consumer  s    z%RivaTTS.atransform.<locals>._consumerc                      s    d I dH  dS )z"Coroutine that wraps the consumer.N)r   r(   )r   r   r(   r)   _consumer_coro  s    z*RivaTTS.atransform.<locals>._consumer_coror   )re   get_running_loopr   r   r	   r   create_taskr   r   
exceptionsTimeoutErrorr   r   )	rA   r   r   r   r   r   ZproducerZconsumerrW   r(   )r   r   r   r   r   r   rA   r)   
atransformj  s&    

zRivaTTS.atransform)N)N)N)r,   r-   r.   r/   ry   rR   rZ   rH   r   r   ro   r   r   rD   r   r   r   r   r   r   r   r   TTSOutputTyper   r	   r   r   r   r  r(   r(   r(   r)   r      sP   

		$  * 
r   )Nr/   re   loggingr   r   rw   rc   rz   enumr   typingr   r   r   r   r   r   r	   r
   r   r   r   r   Zlangchain_core.messagesr   r   Zlangchain_core.prompt_valuesr   Zlangchain_core.runnablesr   r   Zpydanticr   r   r   r   r   r   r#   r&   Zriva.client.proto.riva_asr_pb2r%   protoZriva_asr_pb2Zrasr	getLoggerr,   r   r   r   r   r*   r+   r   r   rR   r0   rG   r[   r`   floatZ
Wave_writer   r   r   r   r   ZStreamOutputTyper   r   r   r   r   r  r   ZNVIDIARivaASRZNVIDIARivaTTSZNVIDIARivaStreamr(   r(   r(   r)   <module>   sn   8 	
'!# 

h

 $