"""
@author: buchhornm
@author: daemsd

This module periodically reports the memory usage of PySpark executors to a Kafka topic.
"""
import os
import json
from threading import Timer
import datetime
import psutil
import pytz


class MemLogger(object):
    """Periodically report the memory usage of a process tree to a Kafka topic.

    Once started, a one-shot ``threading.Timer`` is re-armed every
    ``interval`` seconds. On each tick the resident set size (RSS) of the
    parent of ``self.pid`` and of all its descendants is summed with psutil
    and sent as a JSON record to ``kafka_topic``.
    """

    def __init__(self,
                 interval,
                 kafka_brokers='epod-master1.vgt.vito.be:6668,epod-master2.vgt.vito.be:6668,epod-master3.vgt.vito.be:6668',
                 kafka_topic='spark-memlogs',
                 **kafka_client_configs):
        """
        :param interval: seconds between two measurements (passed to threading.Timer)
        :param kafka_brokers: bootstrap servers handed to KafkaProducer
        :param kafka_topic: topic the JSON records are sent to
        :param kafka_client_configs: extra keyword arguments forwarded to KafkaProducer
        """
        self._timer = None
        self.interval = interval
        self.is_running = False
        self.pid = self.get_pid()

        self.kafka_brokers = kafka_brokers
        self.kafka_client_configs = kafka_client_configs
        # Created in start(); None whenever the logger is not running.
        self.producer = None

        self.kafka_topic = kafka_topic
        # Resolved once here and reused for every record sent by log_mem.
        self.host = self.get_host()
        self.application_id = self.get_application_id()
        self.user_id = self.get_user()

    # Kafka messages must be bytes. In Python 3, bytes(str) without an
    # encoding raises TypeError, so bytes(source, 'utf-8') is needed; in
    # Python 2, bytes is str and takes a single argument.
    try:
        THE_BYTES = bytes('xxx')
    except TypeError:
        @staticmethod
        def to_topic(source):
            """Return ``source`` encoded as UTF-8 bytes (Python 3)."""
            return bytes(source, 'utf-8')

    else:
        @staticmethod
        def to_topic(source):
            """Return ``source`` as bytes (Python 2, where bytes is str)."""
            return bytes(source)

    @staticmethod
    def get_pid():
        """Return the pid to monitor: $JVM_PID if set, else this Python process's pid."""
        if "JVM_PID" in os.environ:
            # use the JVM pid
            pid = int(os.environ["JVM_PID"])
        else:
            # otherwise get the pid of the current python job
            pid = os.getpid()
        return pid

    @staticmethod
    def get_application_id():
        """Return the YARN application id found in $HADOOP_TOKEN_FILE_LOCATION.

        The first path component starting with ``application_`` is returned;
        ``"UNKNOWN"`` if the variable is unset or contains no such component.
        """
        if "HADOOP_TOKEN_FILE_LOCATION" in os.environ:
            split_info = os.environ["HADOOP_TOKEN_FILE_LOCATION"].split("/")
            for split in split_info:
                if split.startswith("application_"):
                    return split
        return "UNKNOWN"

    @staticmethod
    def get_user():
        """Return $USER if set, otherwise the name reported by getpass.getuser()."""
        import getpass
        if "USER" in os.environ:
            username = os.environ["USER"]
        else:
            username = getpass.getuser()
        return username

    @staticmethod
    def get_host():
        """Return a name or address identifying this host.

        The hostname is returned if it contains a dot (looks fully
        qualified). Otherwise this returns the local IP address the OS picks
        for a UDP "connection" to 10.255.255.255, or '127.0.0.1' if that fails.
        """
        import socket
        host_addr = socket.gethostname()
        if (host_addr is not None) and ("." in host_addr):
            return host_addr
        the_socket = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
        try:
            # doesn't even have to be reachable
            the_socket.connect(('10.255.255.255', 1))
            host_addr = the_socket.getsockname()[0]
        except Exception:
            host_addr = '127.0.0.1'
        finally:
            the_socket.close()
        return host_addr

    def _run(self):
        """Timer callback: re-arm the timer, then take one measurement."""
        # Timer.cancel() cannot stop a callback that has already started.
        # Without this check, such a callback would re-arm the timer after
        # stop() and keep logging against a closed producer.
        if not self.is_running:
            return
        self._schedule_timer()
        self.log_mem(self.pid)

    def _schedule_timer(self):
        """Arm a one-shot timer that calls _run after ``interval`` seconds."""
        timer = Timer(self.interval, self._run)
        # A daemon thread does not keep the interpreter alive, so a logger
        # that is never stopped cannot block process exit.
        timer.daemon = True
        self._timer = timer
        timer.start()

    def start(self):
        """Create the Kafka producer and start periodic logging (no-op if already running)."""
        if not self.is_running:
            from kafka import KafkaProducer
            self.producer = KafkaProducer(bootstrap_servers=self.kafka_brokers,
                                          **self.kafka_client_configs)
            self.is_running = True
            self._schedule_timer()

    def stop(self):
        """Cancel the pending measurement and close the producer (no-op if not running)."""
        if self.is_running:
            # Flip the flag first so a callback that is already executing
            # does not re-arm the timer (see _run).
            self.is_running = False
            if self._timer is not None:
                self._timer.cancel()
            if self.producer is not None:
                self.producer.close()
            self.producer = None

    def log_mem(self, pid):
        """Send the RSS of ``pid``'s parent process tree to the Kafka topic.

        The RSS of the parent of ``pid`` and of all its descendants
        (recursively) is summed and sent as a JSON record with keys
        ``job_id``, ``user``, ``host``, ``mem_rss_mb`` (integer MiB, floored,
        as a string) and ``time`` (current UTC time, ISO 8601).

        Measurement is best effort. Nothing is sent if access to a process is
        denied, if ``pid`` no longer exists, or if it has no parent. Errors
        while sending are printed rather than raised.
        """
        try:
            # Measure from the parent of the (JVM) pid so the whole tree
            # below it is included.
            parent = psutil.Process(pid).parent()
            if parent is None:
                return

            # now get the memory of parent and all children
            rss_mem = parent.memory_info().rss
            for child in parent.children(recursive=True):
                try:
                    rss_mem += child.memory_info().rss
                except psutil.NoSuchProcess:
                    # The child exited after it was listed; it no longer uses memory.
                    continue
        except psutil.AccessDenied:
            return
        except psutil.NoSuchProcess as ex:
            print("Unable to read memory usage: " + str(ex))
            return

        try:
            job_info = {
                'job_id': self.application_id,
                'user': self.user_id,
                'host': self.host,
                'mem_rss_mb': str(rss_mem >> 20),
                'time': datetime.datetime.now(tz=pytz.utc).isoformat(),
            }
            self.producer.send(self.kafka_topic,
                               self.to_topic(json.dumps(job_info)))
            self.producer.flush()
        except Exception as ex:
            print("Unable to write to Kafka topic: " + str(ex))
