#!/usr/bin/env python3
# -*- coding: utf-8 -*-

# Version 1.0 2017.03.12

# ***************************************************************************
# *   Copyright (C) 2017, Paul Lutus                                        *
# *                                                                         *
# *   This program is free software; you can redistribute it and/or modify  *
# *   it under the terms of the GNU General Public License as published by  *
# *   the Free Software Foundation; either version 2 of the License, or     *
# *   (at your option) any later version.                                   *
# *                                                                         *
# *   This program is distributed in the hope that it will be useful,       *
# *   but WITHOUT ANY WARRANTY; without even the implied warranty of        *
# *   MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the         *
# *   GNU General Public License for more details.                          *
# *                                                                         *
# *   You should have received a copy of the GNU General Public License     *
# *   along with this program; if not, write to the                         *
# *   Free Software Foundation, Inc.,                                       *
# *   59 Temple Place - Suite 330, Boston, MA  02111-1307, USA.             *
# ***************************************************************************

import os
import sys
import re
import subprocess
import time
import glob
import tempfile
import threading
import socket
import platform
import signal
import ast

class BlenderNetworkRender:
  """
  This class creates parallel threads to render Blender images
  on multiple processors/machines.

  Each participating machine renders one vertical strip of the image
  (a range of the render border's x axis, full height). The strips are
  copied back with scp and merged with ImageMagick 'convert'. Machine
  names, strip-width allocations and command paths are kept in a
  plain-text configuration file in the user's home directory.
  """

  def __init__(self):

    self.debug = False

    self.verbose = True # more messages printed on console

    # for a longer but more accurate speed test,
    # increase the size of this number either
    # here or on the command line

    self.speed_test_repeat_count = 5e8

    # seconds between console updates during render

    self.time_delay = 5

    # windows_system flag, identifies a windows system, changes many things

    self.windows_system = re.search('(?i)windows', platform.system()) is not None

    # used in resource searches

    if self.windows_system:
      self.search_str = r'cd "C:\Program Files" && dir /b %s.exe /s'
    else:
      self.search_str = 'which %s'


    self.call_suffix = ('> /dev/null 2>&1', '')[self.windows_system]

    # sink for the output of system calls when not debugging
    self.null_file_handle = open(os.devnull, "w")

    # cleared by exit_handler() to stop render monitoring
    self.running = True

    # trap Ctrl+C and other exit strategies

    signal.signal(signal.SIGINT, self.exit_handler)
    signal.signal(signal.SIGTERM, self.exit_handler)

    self.hostname = socket.gethostname()

    self.temp_dir = tempfile.gettempdir()

    self.program_name = re.sub(r'(.*[/\\])?(.*)\..*', r'\2', sys.argv[0])

    self.home_path = os.path.expanduser("~")

    self.init_file_path = '%s.ini' % os.path.join(self.home_path, '.' + self.program_name)

    self.python_temp_path = os.path.join(self.temp_dir, 'blender_temp.py')
    self.blender_temp_path = os.path.join(self.temp_dir, 'blender_temp.blend')
    self.graphic_temp_path = os.path.join(self.temp_dir, 'blender_temp.png')

    self.speed_test = False

    self.config = {
      'machines' : [],
      'allocations' : [],
      'coms' : {'blender':'', 'rm':'', 'cp':'', 'ssh':'', 'scp':'', 'convert':''}
    }

    # Python time test script
    # a simple test like this is carried out on only one core
    # so divide the test count by the number of cores

    self.speed_test_script = """
import multiprocessing, math

count = %d // multiprocessing.cpu_count()

for n in range(count):
  x = math.sqrt(2)
    """

    # Blender Python render script

    self.render_script = """
import bpy, sys
n = sys.argv.index('--')
bcsr = bpy.context.scene.render
bcsr.use_border = True
bcsr.border_min_x = float(sys.argv[n+1])
bcsr.border_max_x = float(sys.argv[n+2])
bcsr.border_min_y = 0
bcsr.border_max_y = 1
bcsr.filepath = sys.argv[n+3]
bpy.ops.render.render(write_still=True)
    """

    # per-machine render start times and durations (seconds);
    # an entry in duration_time marks that machine's strip as finished
    self.start_time = {}
    self.duration_time = {}

  # end of __init__()

  @staticmethod
  def option_function(array, index):
    """
    allow SSH relaxed security, for testing only

    inserts '-o StrictHostKeyChecking=no' into the command
    list 'array' starting at position 'index' (in place)
    """
    for n, item in enumerate(('-o', 'StrictHostKeyChecking=no')):
      array.insert(index+n, item)

  def verbose_print(self, s):
    """
    print 's' only when verbose output is enabled
    """
    if self.verbose:
      print(s)

  @staticmethod
  def write_file(path, data):
    """
    write the string 'data' to 'path', replacing any existing file
    """
    with open(path, 'w') as f:
      f.write(data)

  @staticmethod
  def read_file(path):
    """
    return the whole content of the text file 'path'
    """
    with open(path) as f:
      return f.read()

  # read and write program initialization file

  def write_config_file(self, prompt=False):
    """
    write plain-text configuration file
    as a string-formatted dict
    """
    if prompt:
      self.verbose_print('Writing system configuration file.')
    self.write_file(self.init_file_path, '%s\n' % str(self.config))

  def read_config_file(self, prompt=False):
    """
    read plain-text configuration file
    and merge it into self.config

    the file holds the dict literal written by write_config_file()
    and is parsed with ast.literal_eval (literals only, no code
    execution). Keys absent from the file -- e.g. a command entry
    added in a newer version -- keep their default values. An
    unreadable or malformed file is reported and ignored.
    """
    if not os.path.exists(self.init_file_path):
      return
    if prompt:
      self.verbose_print('Reading system configuration file.')
    try:
      data = ast.literal_eval(self.read_file(self.init_file_path))
      if not isinstance(data, dict):
        raise ValueError('content is not a dict')
    except (OSError, ValueError, SyntaxError) as err:
      print(
        'Error: cannot read configuration file %s (%s), using defaults.'
        % (self.init_file_path, err)
      )
      return
    # merge the command table separately so missing entries keep defaults
    coms = dict(self.config.get('coms', {}))
    stored_coms = data.get('coms')
    if isinstance(stored_coms, dict):
      coms.update(stored_coms)
    self.config.update(data)
    self.config['coms'] = coms

  @staticmethod
  def sys_write(s):
    """
    a convenient way to emit
    characters without linefeeds
    and make them visible immediately
    """
    sys.stdout.write(s)
    sys.stdout.flush()

  def test_ssh_login(self):
    """
    perform SSH login on a list of machines
    as a test of public-key authentication

    each machine is asked to echo '<machine>: OK.'; ssh output
    (including any password prompt or error) goes to the console
    """
    for machine in self.config['machines']:
      com = [
        self.config['coms']['ssh'],
        machine,
        'echo',
        '%s: OK.' % machine,
      ]
      # pass only the argument list: a second positional argument
      # to subprocess.call() would be interpreted as 'bufsize'
      subprocess.call(com)


  def available(self, s, machine=None):
    """
    test for existence of a named program
    either locally or on a network machine

    s       -- program name, substituted into self.search_str
    machine -- if given, the search runs on that host over ssh
    returns True when the search command exits with status 0

    the command is run through a shell as one string, so
    's' and 'machine' must be trusted values
    """
    if machine is not None:
      com = '%s %s %s %s' % \
      (
        self.config['coms']['ssh'],
        machine,
        self.search_str % s,
        self.call_suffix
      )
    else:
      com = '%s %s' % (self.search_str % s, self.call_suffix)
    (result, _) = subprocess.getstatusoutput(com)
    return result == 0


  @staticmethod
  def sec_to_hms(t):
    """
    convert floating-point seconds
    into d hh:mm:ss.ss string

    the first character is ' ' for t >= 0 and '-' for t < 0
    """
    sign = (' ', '-')[t < 0]
    t = abs(t)
    sec = t % 60
    t //= 60
    minute = t % 60
    t //= 60
    hour = t % 24
    day = t // 24
    return '%s%dd %02d:%02d:%05.2f' % (sign, day, hour, minute, sec)

  def list_machine_times(self):
    """
    display previously computed
    network machine time allocations
    """
    for n, machine in enumerate(self.config['machines']):
      print(
        '  %-30s : %.4f' %
        (
          machine,
          self.config['allocations'][n])
        )

  def system_call(self, com, use_shell=False):
    """
    use subprocess.call() to execute a system program
    and return status (0 = no error)

    com       -- argument list; com[0] is the program
    use_shell -- run through the shell (needed on Windows
                 for builtins such as 'del' and 'copy')
    output is discarded unless the debug flag is set; exceptions
    from subprocess (e.g. program not found) propagate to the caller
    """
    if not self.debug:
      result = subprocess.call(
        com,
        stdout=self.null_file_handle,
        stderr=self.null_file_handle,
        shell=use_shell
      )
      if result != 0:
        print('Error executing system call %s.' % com[0])
        print('To learn more, run with debug flag "-d".')
    else:
      print('*** system call: [%s]' % com)
      result = subprocess.call(
        com,
        shell=use_shell
      )
      print('*** call outcome: %s error' % ('no', '')[result != 0])
    return result

  @staticmethod
  def normalize_list(lst):
    """
    normalize a numeric list so sum of all values = 1

    when the values sum to zero they cannot be scaled, so each
    entry gets an equal share instead of raising ZeroDivisionError;
    an empty list yields []
    """
    total = sum(lst)
    if total == 0:
      return [1 / len(lst) for _ in lst]
    return [x/total for x in lst]

  @staticmethod
  def _blend_base_name(path):
    """
    return 'path' without a trailing '.blend' extension;
    only an extension at the very end is removed, so
    directory names containing '.blend' are left intact
    """
    return re.sub(r'\.blend$', '', path)

  # manage network machine time allocations

  def reset_allocations(self):
    """
    set default machine allocations so each
    machine gets 1/n of the processing time

    with no machines the allocation list is simply empty
    """
    size = len(self.config['machines'])
    # iterating range(size) never evaluates 1/size when size == 0
    self.config['allocations'] = [1/size for _ in range(size)]
    self.verbose_print('Machine speed allocations set to defaults.')

  def test_allocation_consistency(self):
    """
    make sure the lists of machines and allocations
    are the same size, otherwise reset the allocations
    """
    if len(self.config['allocations']) != len(self.config['machines']):
      self.reset_allocations()

  def remove_temps(self):
    """
    remove temporary files from remote and local machines

    remote: the Python script, the .blend copy and the rendered
    strip are deleted over ssh with the configured 'rm' command.
    local: the strip images copied back by create_result_image()
    are deleted. This is best-effort cleanup; failures are ignored.
    """
    for machine in self.config['machines']:
      com = [
        self.config['coms']['ssh'],
        machine,
        self.config['coms']['rm'],
        self.python_temp_path,
        self.blender_temp_path,
        self.graphic_temp_path
      ]
      # option_function(com, 1)
      try:
        self.system_call(com)
      except (OSError, ValueError):
        pass
    # expand the pattern in Python: a list combined with shell=True
    # makes a POSIX shell run only the first element (plain 'rm')
    pattern = os.path.join(self.temp_dir, 'blender_strip[0-9][0-9][0-9].png')
    for path in glob.glob(pattern):
      try:
        os.remove(path)
      except OSError:
        pass


  def find_local_machines(self):
    """
    use Avahi/Zeroconf to locate Blender-equipped network servers
    this only works on Linux at the moment

    machines advertising _ssh._tcp that also have 'blender' on
    their path become the new machine list (allocations reset)
    """
    self.sys_write('Searching local network for Blender-equipped machines ...')
    found = []
    if not self.available('avahi-browse'):
      print('\nTo use this feature, your system must have Avahi/Zeroconf installed.')
    else:
      data = subprocess.getoutput('avahi-browse _ssh._tcp -tr')
      for line in data.split('\n'):
        self.sys_write('.')
        if re.search('hostname = ', line):
          sysname = re.sub(r'(?i).*hostname.*?\[(.*?)\].*', r'\1', line)
          # the name comes from the network and is later pasted into a
          # shell command line, so accept plain host-name characters only
          if not re.fullmatch(r'[A-Za-z0-9._-]+', sysname):
            continue
          # if the system has blender installed
          if self.available('blender', sysname):
            found.append(sysname)
      # remove duplicates and sort
      found = sorted(set(found))
      print('\nLocated machines:')
      for machine in found:
        print('  %s' % machine)
      self.config['machines'] = found
      self.reset_allocations()
      self.write_config_file(True)

  def locate_system_resources(self, force=False):
    """
    create a dict of generic command names
    and platform-dependent definitions

    only empty entries are searched for unless 'force' is set;
    the configuration file is rewritten if anything was searched
    """
    changed = False
    for name in self.config['coms']:
      if len(self.config['coms'][name]) == 0 or force:
        changed = True
        target = name
        output = None
        status = 0
        # some Windows programs have different names
        if self.windows_system:
          # on Windows, 'rm' = 'del' builtin
          if name == 'rm':
            output = 'del'
          # on Windows, 'cp' = 'copy' builtin
          if name == 'cp':
            output = 'copy'
          # the ImageMagick 'convert' program has a different name on Windows
          if name == 'convert':
            target = 'magick'
        if output is None:
          (status, output) = subprocess.getstatusoutput(self.search_str % target)
        if status == 0:
          # a recursive 'dir /s' search can list several matches;
          # keep only the first so the stored value is one path
          lines = output.strip().splitlines()
          output = lines[0] if lines else ''
          self.verbose_print('Located resource "%s" = "%s"' % (name, output))
          self.config['coms'][name] = output
        else:
          self.verbose_print('Error: cannot locate resource "%s"' % name)
    if changed:
      self.write_config_file(True)

  def copy_identities(self):
    """
    copy Secure Shell identities to local network machines
    for use in public key authentication, using different
    methods for Linux and Windows

    the user picks one public key (*.pub in ~/.ssh) or all of
    them; each chosen key is appended over ssh to an
    authorized_keys file on every configured machine. Note: the
    remote file path is built from the LOCAL home directory.
    """
    path = os.path.join(self.home_path, '.ssh')
    if not os.path.exists(path):
      print('This system is not configured for Secure Shell.')
      return
    suffix = '.pub'
    # strip only the trailing '.pub' so names containing '.pub' elsewhere survive
    keys = sorted(s[:-len(suffix)] for s in os.listdir(path) if s.endswith(suffix))
    if not keys:
      print('This system has no defined Secure Shell keys.')
      return
    while True:
      print('Choose identity key(s):')
      for n, item in enumerate(keys):
        print('  %d: %s' % (n, item))
      print('  %d: All of them' % len(keys))
      result = input("Select:")
      try:
        x = int(result)
      except ValueError:
        x = -1
      # explicit range check (an assert would vanish under 'python -O')
      if 0 <= x <= len(keys):
        break
      print('Please enter a number in the range (0-%d).' % (len(keys)))
    idx = range(len(keys)) if x == len(keys) else range(x, x+1)
    auth = os.path.join(path, 'authorized_keys')
    for n in idx:
      kpath = os.path.join(path, keys[n] + suffix)
      data = self.read_file(kpath).strip()
      print('Processing key %d : %s' % (n, keys[n]))
      for machine in self.config['machines']:
        print('  Writing to %s ...' % machine)
        if self.windows_system:
          com = [
            self.config['coms']['ssh'],
            machine,
            '"',
            'powershell write-host \"%s`r`n\" >> \"%s\"' % (data, auth),
            '"',
          ]
        else:
          com = [
            self.config['coms']['ssh'],
            machine,
            'echo',
            '"%s"' % data,
            '>>',
            auth
          ]
        self.system_call(com)


  # begin script processing routines


  def threaded_call(
      self, machine,
      blender_script_path,
      apos, bpos
    ):
    """
    launched in a thread, this function
    runs an instance of Blender on a network machine
    as one of a set of a concurrent threads

    machine             -- ssh destination
    blender_script_path -- remote .blend path, or '' for the speed test
    apos, bpos          -- border_min_x / border_max_x values passed to
                           the render script after '--'

    self.duration_time[machine] is always recorded, even when the
    command cannot be started, because monitor_renders() waits for
    that entry and would otherwise never finish.
    """
    com1 = [
      '%s' % self.config['coms']['ssh'],
      #'-v', # testing only
      '%s' % machine
    ]
    com2 = [
      '%s' % self.config['coms']['blender'],
      '--background'
    ]
    if len(blender_script_path) > 0:
      com2 += ['%s' % blender_script_path]
    com2 += [
      '--python',
      '%s' % self.python_temp_path,
      # '--' prevents blender from reading
      # more arguments from the command line
      '--',
      # strip start and end numerical values
      '%.4f' % apos, '%.4f' % bpos,
      '%s' % self.graphic_temp_path
    ]
    # option_function(com1, 1)
    if self.windows_system:
      # the remote Windows shell needs each argument and the whole command quoted
      com2 = ['"'] + ['\"' + item + '\"' for item in com2] + ['"']
    com = com1 + com2
    self.start_time[machine] = time.time()
    try:
      self.system_call(com)
    except OSError as err:
      print('Error: cannot start render on %s: %s' % (machine, err))
    finally:
      self.duration_time[machine] = time.time() - self.start_time[machine]

  def launch_threads(self, blender_script_path, blender_string):
    """
    create a network thread for each participating machine

    machine n renders the x range [apos, apos + allocations[n]);
    the scripts are copied to every machine except the local one
    (whose name contains this host's name), which already has them
    """
    apos = 0
    for n, machine in enumerate(self.config['machines']):
      bpos = apos + self.config['allocations'][n]
      self.verbose_print(
        'Rendering strip %d of %s on %s ...'
        % (n, blender_string, machine)
      )
      # plain substring test: a hostname may contain regex metacharacters
      if self.hostname not in machine:
        com = [
          self.config['coms']['scp'],
          self.python_temp_path,
          self.blender_temp_path,
          '%s:%s/' % (machine, self.temp_dir)
        ]
        # option_function(com, 1)
        self.system_call(com)
      call_args = (
        machine,
        blender_script_path,
        apos, bpos
      )
      # start network concurrent thread for one image strip
      thread = threading.Thread(
        target=self.threaded_call,
        args=call_args
      )
      thread.start()
      apos = bpos

  def monitor_renders(self, process_start, blender_string):
    """
    monitor running threads for completion
    which is signaled by the machine's presence
    in self.duration_time[machine]

    polls once per second and prints a progress line every
    self.time_delay polls; stops early if self.running is cleared
    """
    busy = True
    count = 0
    while busy and self.running:
      # an entry in "duration_time" signals completion
      active = [m for m in self.config['machines'] if m not in self.duration_time]
      busy = len(active) > 0
      # if some machines are still rendering
      if busy and count % self.time_delay == 0:
        time_str = self.sec_to_hms(time.time() - process_start)
        self.verbose_print('%s : %s : %s' % (time_str, blender_string, ', '.join(active)))
      time.sleep(1)
      count += 1

  def create_result_image(self, total, blender_name):
    """
    combine generated image strips and create output image

    total        -- sum of all machines' render durations (seconds)
    blender_name -- output base name; the result is '<blender_name>.png'
    """
    size = len(self.config['machines'])
    if self.verbose:
      # get average render time
      mean = total / size
      print('Network Rendering Times:')
      for machine in self.config['machines']:
        duration = self.duration_time[machine]
        delta = (duration - mean) / mean if mean else 0.0
        print(
          '  %-30s : %s (deviation from mean %+7.2f%%)'
          % (machine, self.sec_to_hms(duration), delta * 100)
        )

    # copy image strips to local machine
    image_paths = []
    for n, machine in enumerate(self.config['machines']):
      # same naming scheme that remove_temps() cleans up
      image_path = os.path.join(self.temp_dir, 'blender_strip%03d.png' % n)
      image_paths.append(image_path)
      self.verbose_print('Copying strip %d from %s ...' % (n, machine))
      com = [
        self.config['coms']['scp'],
        '%s:%s' % (machine, self.graphic_temp_path),
        image_path
      ]
      # option_function(com, 1)
      self.system_call(com)

    # combine strips -- this method requires that the
    # Blender images be specified as RGBA PNG
    self.verbose_print(
      'Creating result image "%s.png" from strips ...'
      % (blender_name)
    )
    com = [self.config['coms']['convert']]
    com += image_paths
    com += [
      '-flatten',
      '%s.png' % blender_name
    ]
    self.system_call(com)


  def process_script(self, blender_script):
    """
    This is the main Blender definition file
    processing function -- it reads the definition,
    distributes it across the network, and
    launches a thread for each machine
    all of which run in parallel

    blender_script -- path of the .blend file, or '' for a
                      speed test without a file
    """
    if len(self.config['machines']) == 0:
      print('Error: no network machines defined.')
      print('Use option \"-m\" to enter machine names.')
      if not self.windows_system:
        print('Or option \"-f\" to search for them.')
      return

    if not self.running:
      return

    # save the Python and Blender scripts as local files

    if self.speed_test and len(blender_script) == 0:
      self.write_file(self.python_temp_path, self.speed_test_script % self.speed_test_repeat_count)
    else:
      self.write_file(self.python_temp_path, self.render_script)
    if not self.windows_system:
      # equivalent of 'chmod +x', without spawning a shell
      mode = os.stat(self.python_temp_path).st_mode
      os.chmod(self.python_temp_path, mode | 0o111)

    if len(blender_script) > 0:
      com = [
        self.config['coms']['cp'],
        blender_script,
        self.blender_temp_path
      ]
      self.system_call(com, self.windows_system)
      blender_string = blender_script
    else:
      self.write_file(self.blender_temp_path, '') # write empty placeholder file
      blender_string = '(network speed test)'

    if os.path.exists(self.init_file_path) and not self.speed_test:
      self.verbose_print(
        'Machine speed allocations read from %s.'
        % self.init_file_path
      )
      # stored allocations may not match an edited machine list
      self.test_allocation_consistency()
    else:
      self.reset_allocations()

    if self.verbose:
      print('Current machine speed allocations:')
      self.list_machine_times()

    blender_name = self._blend_base_name(blender_script)

    process_start = time.time()

    self.start_time.clear()
    self.duration_time.clear()

    # set up network render processes

    if len(blender_script) == 0:
      blender_script_path = ''
    else:
      blender_script_path = self.blender_temp_path

    self.launch_threads(blender_script_path, blender_string)

    if not self.running:
      return

    # now monitor the renders and wait for all to finish

    self.monitor_renders(process_start, blender_string)

    self.verbose_print('Done rendering strips for %s.' % blender_string)
    total = sum(self.duration_time.values())

    if not self.running:
      return

    if self.speed_test:
      # network speed test actions: faster machines (shorter
      # durations) receive proportionally wider strips
      allocation = [total / self.duration_time[machine] \
        for machine in self.config['machines']]
      allocation = self.normalize_list(allocation)
      self.config['allocations'] = allocation
      if self.verbose:
        print('Speed test machine running times:')
        for machine in self.config['machines']:
          print(
            '  %-30s : %s'
            % (machine, self.sec_to_hms(self.duration_time[machine]))
          )
        print('Resulting machine speed allocations:')
        self.list_machine_times()
      self.write_config_file(True)
      self.verbose_print(
        'Allocations saved to %s.'
        % self.init_file_path
      )

    else: # not speed test
      self.create_result_image(total, blender_name)

    process_end = time.time()
    self.verbose_print(
      'File: %s : elapsed time: %s'
      % (blender_string, self.sec_to_hms(process_end - process_start))
    )

  # end script processing function

  def exit_handler(self, signum, frame):
    """
    a program exit handler that (on Linux)
    closes orphaned network processes

    installed for SIGINT and SIGTERM in __init__. On non-Windows
    systems it runs 'ps x' on each machine over ssh and kills every
    process whose listing contains the temporary Python script path.
    Always ends by exiting the program.
    """
    #pylint: disable=unused-argument
    if self.running:
      self.running = False
      if not self.windows_system:
        print('\nUser interrupt.')
        print('Terminating remote processes:')
        # fall back to plain 'ssh' if the resource search found nothing
        ssh = self.config['coms']['ssh'] or 'ssh'
        for machine in self.config['machines']:
          data = subprocess.getoutput(
            '%s %s ps x'
            % (ssh, machine)
          )
          for line in data.split('\n'):
            # plain substring test: the path contains regex metacharacters
            if self.python_temp_path in line:
              pid = line.split()[0]
              print(
                '  Terminating process %6s on %s'
                % (pid, machine)
              )
              subprocess.call([ssh, machine, 'kill', pid])
    print('Exiting.')
    sys.exit()

  def show_help(self):
    """
    provide the user with some sketchy information
    about the program's options
    """
    print('Usage: -m [list of participating machine URIs] define participating machines')
    if not self.windows_system:
      print(
        '       -f find local-network Blender-equipped '
        + 'machines using Avahi/Zeroconf (Linux only)'
      )
    print('       -l locate system applications and resources')
    print(
      '       -s [Blender filename, integer or use default of %.0e] network speed test'
      % self.speed_test_repeat_count
    )
    print('       -r reset machine speed allocations to default values')
    print('       -c copy SSH identities to network machines')
    print('       -t test SSH login on participating machines')
    print('       -q quiet')
    print('       Blender script(s)')


  def main(self):
    """
    default class entry point --
    read configuration file,
    read and execute user commands

    options are processed left to right; every argument that is
    not an option is treated as a Blender file (wildcards are
    expanded here because Windows shells do not expand them)
    """
    if not self.available('ssh'):
      print('Error: must have Secure Shell installed, exiting.')
      sys.exit()

    self.read_config_file()

    self.locate_system_resources()

    sys.argv.pop(0) # discard program path and name

    if len(sys.argv) == 0:
      self.show_help()
      sys.exit()

    speed_numeric = False

    machine_names = False

    filenames = []

    n = 0
    while n < len(sys.argv):
      arg = sys.argv[n]
      if speed_numeric:
        speed_numeric = False
        try:
          # this looks weird, but it's needed so one
          # can enter '1e9' instead of nine zeros
          self.speed_test_repeat_count = int(float(arg))
        except (ValueError, OverflowError):
          # not a number: reprocess this argument normally
          n -= 1
      elif machine_names:
        machine_names = False
        names = []
        while n < len(sys.argv):
          if sys.argv[n].startswith('-'):
            n -= 1
            break
          else:
            names.append(sys.argv[n])
            n += 1
        self.config['machines'] = names
        print('Defined machines: %s' % names)
        self.reset_allocations()
        self.write_config_file()
      elif arg == '-m':
        machine_names = True
      elif arg == '-f' and not self.windows_system:
        self.find_local_machines()
      elif arg == '-l':
        self.locate_system_resources(True)
      elif arg == '-s':
        speed_numeric = True
        self.speed_test = True
      elif arg == '-r':
        self.reset_allocations()
      elif arg == '-c':
        self.copy_identities()
      elif arg == '-t':
        self.test_ssh_login()
      elif arg == '-d':
        self.debug = True
      elif arg == '-q':
        self.verbose = False
      else:
        # windows doesn't expand wildcard arguments
        gl_str = glob.glob(arg)
        if len(gl_str) == 0:
          filenames.append(arg)
        else:
          filenames.extend(gl_str)
      n += 1

    if len(filenames) == 0 and self.speed_test:
      filenames = ['']

    if len(filenames) > 0:

      process_start = time.time()

      action = ('processing', '(network speed test)')[self.speed_test]

      for blender_script in filenames:
        self.verbose_print(
          '*** Begin %s %s ***'
          % (action, blender_script)
        )
        if len(blender_script) > 0 \
          and not os.path.exists(blender_script):
          print(
            'Blender script %s doesn\'t exist, skipping.'
            % blender_script
          )
        else:
          self.process_script(blender_script)
          # only the first processed file doubles as a speed test
          self.speed_test = False
        self.verbose_print('*** End %s %s ***' % (action, blender_script))
        if not self.running:
          break

      self.remove_temps()

      self.null_file_handle.close()

      process_end = time.time()

      delta_time = self.sec_to_hms(process_end - process_start)

      print('Total elapsed time: %s' % delta_time)

  # end of BlenderNetworkRender class definition

# entry point: run the renderer only when this file is executed
# directly, never when it is imported as a module
if __name__ == '__main__':
  renderer = BlenderNetworkRender()
  renderer.main()
