# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

"""
Simplistic RPC implementation.
Exposes all functions of a Server object.

This code is for demonstration purposes only, and does not include certain
security protections. It is not meant to be run on an untrusted network or
in a production environment.
"""

import importlib
import os
import pickle
import sys
import _thread
import traceback
import socket
import logging

LOG = logging.getLogger(__name__)

# Default TCP port used by Client and run_server when no port is given.
PORT = 12032

# Modules whose globals RestrictedUnpickler.find_class is allowed to resolve;
# a received pickle referencing any other module is rejected.
safe_modules = {'numpy', 'numpy.core.multiarray'}


class RestrictedUnpickler(pickle.Unpickler):
    """Unpickler that only resolves globals from modules in ``safe_modules``.

    Pickles made only of built-in containers and scalars never reach
    ``find_class`` and load normally.

    SECURITY NOTE(review): once a module is allowed, *any* attribute of it can
    be resolved (there is no per-name check), e.g. every function exposed by
    ``numpy``. Treat this as a coarse filter, not a sandbox.
    """

    def find_class(self, module, name):
        if module not in safe_modules:
            # Reject every global outside the allow-list.
            raise pickle.UnpicklingError("global '%s.%s' is forbidden" %
                                         (module, name))
        target_module = importlib.import_module(module)
        return getattr(target_module, name)


class FileSock:
    """Wrap a socket so that it can be used as a file object by pickle.

    ``pickle.dump`` needs ``write``; ``pickle.load`` / ``Unpickler`` need
    ``read`` and ``readline``.
    """

    def __init__(self, sock):
        self.sock = sock
        # counts read() calls; not used for anything else in this class
        self.nr = 0

    def write(self, buf):
        """Send every byte of ``buf`` over the socket.

        ``socket.send`` may transmit fewer bytes than requested, so keep
        sending, offering at most 512 KiB per call, until all of ``buf`` is out.
        """
        bs = 512 * 1024
        ns = 0
        while ns < len(buf):
            ns += self.sock.send(buf[ns:ns + bs])

    def read(self, bs=512 * 1024):
        """Return exactly ``bs`` bytes, or fewer if the peer closes the socket.

        A single ``recv`` may return less than asked, so loop until ``bs``
        bytes have been collected or ``recv`` signals EOF with ``b''``.
        """
        self.nr += 1
        chunks = []
        nb = 0
        # Bound the loop by the number of *bytes* received (the original
        # compared the number of chunks, and only stopped via a recv(0)).
        while nb < bs:
            rb = self.sock.recv(bs - nb)
            if not rb:
                break  # EOF: the peer closed its side
            chunks.append(rb)
            nb += len(rb)
        return b''.join(chunks)

    def readline(self):
        """Return bytes up to and including the next newline, or up to EOF."""
        line = bytearray()
        while True:
            c = self.read(1)
            line += c
            # Both stop conditions must be checked inside the loop; otherwise
            # this never returns (and spins forever once EOF is reached).
            if not c or c == b'\n':
                return bytes(line)

class ClientExit(Exception):
    """Raised by Server.one_function on EOFError while reading a request or
    writing a reply; Server.exec_loop logs it and leaves its loop."""

class ServerException(Exception):
    """Raised by Client.get_result when the server reports an error; the
    argument is the error string sent by the server."""


class Server:
    """
    server protocol. Methods from classes that subclass Server can be called
    transparently from a client

    Every attribute reachable through ``getattr(self, name)`` can be invoked
    by the peer, including the methods defined on this base class.
    """

    def __init__(self, s, logf=sys.stderr, log_prefix=''):
        """
        s: connected socket (or socket-like object with send/recv/close)
           used for both requests and replies.
        logf: text stream that receives log() lines and tracebacks.
        log_prefix: string inserted into every log() line.
        """
        self.logf = logf
        self.log_prefix = log_prefix

        # connection
        self.conn = s
        self.fs = FileSock(s)

    def log(self, s):
        """Write one line tagged with ``log_prefix`` to ``logf``."""
        self.logf.write("Server log %s: %s\n" % (self.log_prefix, s))

    def _run_method(self, fname, args):
        """Look up method ``fname`` on self and call it with ``args``.

        Returns ``(st, ret)``: ``st`` is None on success, otherwise a string
        describing the failure (unknown method, or traceback + message of the
        raised exception); ``ret`` is the return value, or None on failure.
        """
        try:
            f = getattr(self, fname)
        except AttributeError:
            self.log("unknown method")
            # Send a plain string: the client decodes replies with
            # RestrictedUnpickler, which would reject a pickled
            # builtins.AttributeError. Also skip the call entirely (the
            # original fell through and hit an unbound ``f``).
            return "unknown method " + fname, None

        try:
            return None, f(*args)
        except Exception as e:
            # due to a bug (in mod_python?), ServerException cannot be
            # unpickled, so send the string and make the exception on the
            # client side
            st = "".join(traceback.format_tb(sys.exc_info()[2])) + str(e)
            self.log("exception in method")
            traceback.print_exc(50, self.logf)
            self.logf.flush()
            return st, None

    def one_function(self):
        """
        Executes a single function with associated I/O.
        Protocol:
        - the arguments and results are serialized with the pickle protocol
          (requests are decoded with RestrictedUnpickler, replies are written
          with protocol 4)
        - client sends : (fname,args)
            fname = method name to call
            args = tuple of arguments
        - server sends result: (st,ret)
            st = None on success, otherwise a string describing the error
                 (unknown method, or traceback + message of the exception)
            ret = return value, or None if st is not None
        Raises ClientExit if an EOFError occurs while reading the request or
        writing the reply.
        """
        try:
            (fname, args) = RestrictedUnpickler(self.fs).load()
        except EOFError:
            raise ClientExit("read args")
        self.log("executing method %s" % (fname))

        st, ret = self._run_method(fname, args)

        LOG.info("return")
        try:
            pickle.dump((st, ret), self.fs, protocol=4)
        except EOFError:
            raise ClientExit("function return")

    def exec_loop(self):
        """ main execution loop. Loops and handles exit states"""

        self.log("in exec_loop")
        try:
            while True:
                self.one_function()
        except ClientExit as e:
            self.log("ClientExit %s" % e)
        except socket.error as e:
            self.log("socket error %s" % e)
            traceback.print_exc(50, self.logf)
        except EOFError:
            self.log("EOF during communication")
            traceback.print_exc(50, self.logf)
        except BaseException:
            # unexpected.
            # NOTE: when exec_loop runs in a thread started by run_server
            # (_thread.start_new_thread), SystemExit only ends that thread,
            # not the whole process.
            traceback.print_exc(50, sys.stderr)
            sys.exit(1)
        finally:
            # The session is over on every path: release the connection
            # instead of leaving it to the garbage collector.
            self.conn.close()

        LOG.info("exit server")

    def exec_loop_cleanup(self):
        """Hook for subclasses; nothing in this class calls it."""

    ###################################################################
    # spying stuff

    def get_ps_stats(self):
        """Return uptime and ``ps`` output for this host/process as one string."""
        cmd = ("echo ============ `hostname` uptime:; uptime;" +
               "echo ============ self:; " +
               "ps -p %d -o pid,vsize,rss,%%cpu,nlwp,psr; " % os.getpid() +
               "echo ============ run queue:;" +
               "ps ar -o user,pid,%cpu,%mem,ni,nlwp,psr,vsz,rss,cputime,command")
        # Use the pipe as a context manager so it is closed (and the shell
        # waited for) instead of being leaked.
        with os.popen(cmd) as f:
            return f.read()

class Client:
    """
    Methods of the server object can be called transparently. Exceptions are
    re-raised.

    Any attribute not found on the Client itself (except dunder names) is
    turned into a remote call: ``client.foo(a, b)`` sends ``("foo", (a, b))``
    and returns the server's result, or raises ServerException.
    """
    def __init__(self, HOST, port=PORT, v6=False):
        socktype = socket.AF_INET6 if v6 else socket.AF_INET

        sock = socket.socket(socktype, socket.SOCK_STREAM)
        LOG.info("connecting to %s:%d, socket type: %s", HOST, port, socktype)
        try:
            sock.connect((HOST, port))
        except BaseException:
            # Don't leak the file descriptor when the server is unreachable.
            sock.close()
            raise
        self.sock = sock
        self.fs = FileSock(sock)

    def generic_fun(self, fname, args):
        """Invoke ``fname(*args)`` on the server and return its result."""
        pickle.dump((fname, args), self.fs, protocol=4)
        return self.get_result()

    def get_result(self):
        """Read one ``(st, ret)`` reply; raise ServerException if st is set."""
        (st, ret) = RestrictedUnpickler(self.fs).load()
        if st is not None:
            raise ServerException(st)
        return ret

    def __getattr__(self, name):
        # Only reached for names not found through normal lookup. Refuse
        # dunder names so protocol probes (copy.deepcopy's __deepcopy__,
        # numpy's __array_interface__, ...) get AttributeError instead of
        # silently triggering a network round-trip.
        if name.startswith('__') and name.endswith('__'):
            raise AttributeError(name)
        return lambda *x: self.generic_fun(name, x)


def run_server(new_handler, port=PORT, report_to_file=None, v6=False):
    """Listen on ``port`` and serve each accepted connection in its own thread.

    new_handler: callable taking the accepted socket and returning an object
        with an ``exec_loop()`` method (e.g. a Server subclass).
    port: TCP port to bind on all local interfaces.
    report_to_file: if not None, "hostname:port " is written to this path
        once the socket is listening.
    v6: use an IPv6 socket instead of IPv4.

    Loops forever; it only leaves by raising, and the listening socket is
    closed on the way out.
    """
    HOST = ''                 # empty string = all local interfaces (INADDR_ANY)
    socktype = socket.AF_INET6 if v6 else socket.AF_INET
    s = socket.socket(socktype, socket.SOCK_STREAM)
    try:
        s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)

        LOG.info("bind %s:%d", HOST, port)
        s.bind((HOST, port))
        s.listen(5)

        LOG.info("accepting connections")
        if report_to_file is not None:
            LOG.info('storing host+port in %s', report_to_file)
            with open(report_to_file, 'w') as f:
                f.write('%s:%d ' % (socket.gethostname(), port))

        while True:
            try:
                conn, addr = s.accept()
            except InterruptedError:
                # The Python 2 check ``e[1] == 'Interrupted system call'``
                # raised TypeError in Python 3. Since Python 3.5 EINTR is
                # normally retried automatically; keep the retry for safety.
                continue

            LOG.info('Connected to %s', addr)

            ibs = new_handler(conn)

            tid = _thread.start_new_thread(ibs.exec_loop, ())

            LOG.debug("Thread ID: %d", tid)
    finally:
        s.close()
