# -*- coding: utf-8 -*-

from __future__ import division, print_function
import os
from libtbx.utils import Sorry
from libtbx import group_args
try:
  from phenix.program_template import ProgramTemplate
except ImportError:
  from libtbx.program_template import ProgramTemplate

# =============================================================================
class Program(ProgramTemplate):
  description = '''
Replace values in B-factor field with estimated B values.
Optionally remove low-confidence residues and split into domains.

Inputs: Model file (PDB, mmCIF)
'''

  datatypes = ['phil', 'model', 'sequence']

  master_phil_str = """

  job_title = None
    .type = str
    .input_size = 400
    .help = Job title in PHENIX GUI, not used on command line
    .style = noauto bold

  input_files {
    chain_id = None
      .type = str
      .short_caption = chain_id
      .help = If specified, find domains in this chain only. NOTE: only one \
                chain can be used for finding domains at a time.
      .input_size = 400

    selection = None
      .type = str
      .short_caption = Selection
      .help = If specified, use only selected part of model
       .input_size = 400

    pae_file = None
      .type = path
      .help = Optional input json file with matrix of inter-residue estimated\
               errors (pae file)
      .short_caption = PAE file

    distance_model_file = None
      .type = path
      .help = Distance_model_file. A PDB or mmCIF file containing the model \
               corresponding to the PAE  \
               matrix. Only needed if weight_by_ca_ca_distances is True and \
               you want to specify a file other than your model file. (Default\
               is to use the model file)
      .short_caption = Distance model file

    model = None
      .type = path
      .help = Input predicted model (e.g., AlphaFold model).  Assumed to \
                have LDDT values in B-value field (or RMSD values).
      .style = file_type:pdb input_file
      .short_caption = Predicted model
      .expert_level = 3

  }
  output_files {
    processed_model_prefix = None
      .type = str
      .help = Output file with processed models will begin with this prefix.\
               If not specified, the input model file name will be used with\
               the suffix _processed.
      .short_caption = Output file prefix (optional)

    remainder_seq_file_prefix = None
      .type = str
      .help = Output file with sequences of deleted parts of model \
          will begin with this prefix
      .short_caption = Output remainder seq file prefix

     maximum_output_b = 999.
       .type = float
       .help = Limit output B values (so that they fit in old-style PDB \
              format). Note that this limit is applied just before writing \
              output model files, it does not affect anything else. Also \
              if output model is trimmed normally, high-B value residues are\
              already removed, so in most cases this keyword has no effect.
       .short_caption = Maximum output B

     remove_hydrogen = True
       .type = bool
       .help = Remove hydrogen atoms from model on input
       .short_caption = Remove hydrogen

     single_letter_chain_ids = False
       .type = bool
       .help = Write output files with all chain IDS as single characters.\
                Default is to use original chain ID and to add digits (1-9)\
                for domains.
       .short_caption = Use only single-letter chain ID

  }

  include scope mmtbx.process_predicted_model.master_phil_str
  control {

    write_files = True
      .type = bool
      .help = Write output files
      .short_caption = Write output files
   }


  gui
    .help = "GUI-specific parameter required for output directory"
  {
    output_dir = None
    .type = path
    .style = output_dir
  }


  """

  def run(self):
    """Process the predicted model and, if requested, write the results.

    Reads the inputs with get_data_inputs(), calls
    mmtbx.process_predicted_model.process_predicted_model() and stores the
    returned model_list and model on self.  When control.write_files is set
    the following are written through the data manager:

      <prefix>.pdb                 the full processed model
      <prefix>_<chain_id>_<n>.pdb  one file per chain of the processed model
      <remainder prefix>.seq       remainder sequence, if one was returned

    and summarize_predicted_model() is called.  Results are available
    afterwards from get_results().

    Raises Sorry if the processed model mixes a blank chain ID with other
    chains.
    """

    self.data_manager.set_overwrite(True)

    self.get_data_inputs()  # get any file-based information
    self.starting_model = self.model
    self.model_list = []
    self.processed_model = None
    self.processed_model_text = None
    self.processed_model_file_name = None
    self.processed_model_file_name_list = []
    # Defined up front so the attribute exists even if no remainder
    # sequence file is written
    self.remainder_sequence_file_name = None

    from mmtbx.process_predicted_model import process_predicted_model
    info = process_predicted_model(
      model = self.model,
      distance_model = self.distance_model,
      params = self.params,
      pae_matrix = self.pae_matrix,
      log = self.logger)

    if not info:
      print("Unable to process predicted model", file = self.logger)
      return

    self.model_list = info.model_list
    self.processed_model = info.model
    # summarize_predicted_model() reads the processed model from this holder
    self.dock_and_rebuild = group_args(
      group_args_type = 'dummy dock_and_rebuild for summary',
      processed_model = self.processed_model)

    if not self.params.control.write_files:
      return  # done

    output_params = self.params.output_files
    maximum_output_b = output_params.maximum_output_b

    starting_residues = self.model.get_hierarchy().overall_counts().n_residues
    print("\nStarting residues: %s" %(starting_residues), file = self.logger)

    if not info.model or info.model.overall_counts().n_residues < 1:
      print("No residues obtained after processing...", file = self.logger)
      return None

    if output_params.processed_model_prefix:
      prefix = output_params.processed_model_prefix
    else:
      model_root = os.path.splitext(self.params.input_files.model)[0]
      prefix = os.path.basename("%s_processed" %(model_root))
    # Recorded only now that we know the file is going to be written, so
    # get_results() never reports a file that does not exist
    self.processed_model_file_name = "%s.pdb" %(prefix)

    if (maximum_output_b is not None) and (
        info.model.get_b_iso().min_max_mean().max > maximum_output_b):
      print("Limiting output B values to %.0f" %(maximum_output_b),
        file = self.logger)
    full_model = limit_output_b(info.model,
      maximum_output_b = maximum_output_b)

    if output_params.single_letter_chain_ids:
      # convert all chain_ids to single character (A)
      model_to_split = convert_chain_ids_to_single_character(
        full_model, chain_id = "A")
    else:
      model_to_split = full_model

    # The full model is always written with its original chain IDs
    self.data_manager.write_model_file(
      full_model, self.processed_model_file_name)

    final_residues = full_model.get_hierarchy().overall_counts().n_residues
    print("Final residues: %s\n" %(final_residues), file = self.logger)

    # Split up processed model and write each chain as well
    if len(model_to_split.chain_ids()) > 1 or \
         output_params.single_letter_chain_ids:
      chain_models = model_to_split.as_model_manager_each_chain()
    else:
      chain_models = [model_to_split]
    for count, m in enumerate(chain_models, 1):
      # Check for an empty model first so that nothing is reported as
      # copied (and no chain ID is looked up) for a model that is skipped
      if not m or not m.overall_counts().n_residues:
        print("Skipping #%s (no residues)" %(count), file = self.logger)
        continue
      chain_id = m.first_chain_id().strip()
      if not chain_id:
        if len(chain_models) > 1:
          raise Sorry(
           "Input model cannot have a blank chain ID and non-blank chain IDS")
        chain_id = "A"
      fn = "%s_%s_%s.pdb" %(prefix, chain_id, count)
      print("Copying predicted model chain %s (#%s)to %s" %(
           chain_id, count, fn), file = self.logger)
      chain_model = limit_output_b(m, maximum_output_b = maximum_output_b)
      self.data_manager.write_model_file(chain_model, fn)
      self.processed_model_file_name_list.append(fn)

    # Write out seq file for remainder (unused part) of model
    if info.remainder_sequence_str:
      if output_params.remainder_seq_file_prefix:
        seq_prefix = output_params.remainder_seq_file_prefix
      else:
        model_root = os.path.splitext(self.params.input_files.model)[0]
        seq_prefix = os.path.basename("%s_remainder" %(model_root))
      self.remainder_sequence_file_name = os.path.join(
        os.getcwd(), "%s.seq" %(seq_prefix))
      self.data_manager.write_sequence_file(info.remainder_sequence_str,
        filename = self.remainder_sequence_file_name)

    # Summarize each step
    self.summarize_predicted_model()

    print ('\nFinished with process_predicted_model', file=self.logger)
  # ---------------------------------------------------------------------------
  def get_results(self):
    """Return the attributes set by run() bundled in one group_args object.

    The object carries the starting and processed models, the names of the
    written files, the list of models, the summary text and the parameters.
    """
    results = dict(
      group_args_type = 'results of process_predicted_model',
      starting_model = self.starting_model,
      processed_model = self.processed_model,
      processed_model_file_name = self.processed_model_file_name,
      processed_model_file_name_list = self.processed_model_file_name_list,
      model_list = self.model_list,
      processed_model_text = self.processed_model_text,
      params = self.params,
    )
    return group_args(**results)

# =============================================================================
#    Custom operations
# =============================================================================
#

  def summarize_predicted_model(self):
    """Log a per-chain residue summary of the processed model.

    Reads the processed model from self.dock_and_rebuild.  The summary (a
    header line followed by one line per chain ID) is printed once to
    self.logger and stored in self.processed_model_text.  If there is no
    processed model, only a placeholder text is stored and nothing is
    printed.
    """
    processed_model = self.dock_and_rebuild.processed_model
    if not processed_model:
      self.processed_model_text = "No processed model information available"
      return

    chain_id_list = processed_model.chain_ids()
    pieces = [
      "Processed model with %s domains and %s residues " %(
        len(chain_id_list),
        processed_model.get_hierarchy().overall_counts().n_residues),
      "\n\nResidues by domain (as chains):",
    ]
    for chain_id in chain_id_list:
      m = processed_model.apply_selection_string("chain '%s'" %(chain_id))
      n_found = m.get_hierarchy().overall_counts().n_residues
      pieces.append("\nCHAIN: %s   Residues: %s " %(chain_id, n_found))

    # Emit and store the summary once, after all chains are collected;
    # doing this per chain repeats the header and every earlier chain line
    text = "".join(pieces)
    print(text, file = self.logger)
    self.processed_model_text = text


  def set_defaults(self):
    """Fill in params.input_files.model if it was not given.

    A model file name that was specified must exist (Sorry is raised
    otherwise).  If none was specified, the data manager is asked for its
    default model name; failure to get one is ignored here.  Also resets
    self.titles to an empty list.
    """
    params = self.params
    self.titles = []

    model_file = getattr(params.input_files, 'model', None)
    if model_file:
      # Explicitly specified: it only has to exist
      if not os.path.isfile(model_file):
        raise Sorry("The file %s is missing?" %(model_file))
      return

    # Nothing specified: best-effort guess from the data manager
    try:
      model_file = self.data_manager.get_default_model_name()
      self.params.input_files.model = model_file
    except Exception:
      pass # did not work

  def get_data_inputs(self):  # get any file-based information
    """Load the model and the optional distance model and PAE matrix.

    Sets:
      self.model           input model, after optional hydrogen removal and
                           the optional user selection
      self.distance_model  model read from distance_model_file (defaulting
                           to the model file) if weight_by_ca_ca_distance
                           is set, otherwise None
      self.pae_matrix      result of parse_pae_file() if pae_file is given
                           and exists, otherwise None

    Raises Sorry if the model file name is unknown, the file is missing or
    cannot be read, no model is obtained, or the model does not have
    exactly one chain ID.
    """
    self.set_defaults()
    input_params = self.params.input_files
    file_name = input_params.model
    if not file_name:
      raise Sorry("Unable to guess model file name...please specify")
    if not os.path.isfile(file_name):
      raise Sorry("Missing the model file '%s'" %(file_name))
    try:
      self.model = self.data_manager.get_model(filename=file_name)
    except Exception as e:
      # Keep the underlying reason in the message instead of discarding it
      raise Sorry("Failed to read model file '%s': %s" %(file_name, e))

    print("Read model from %s" %(file_name), file = self.logger)

    if not self.model:
      raise Sorry("Missing model")

    self.model.add_crystal_symmetry_if_necessary()

    # Remove hydrogens and apply user selection
    selections = []
    if self.params.output_files.remove_hydrogen:
      selections.append("(not (element H))")
    if input_params.selection:
      selections.append("(%s)" %(input_params.selection))
    if selections:
      self.model = self.model.apply_selection_string(" and ".join(selections))

    if self.params.process_predicted_model.weight_by_ca_ca_distance:
      if not input_params.distance_model_file:
        # Default is to take distances from the input model file itself
        input_params.distance_model_file = input_params.model
      file_name = input_params.distance_model_file
      self.distance_model = self.data_manager.get_model(filename=file_name)
      print("Read distance model from %s" %(file_name), file = self.logger)
      self.distance_model.add_crystal_symmetry_if_necessary()
    else:
      self.distance_model = None

    pae_file = input_params.pae_file
    if pae_file and os.path.isfile(pae_file):
      from mmtbx.domains_from_pae import parse_pae_file
      self.pae_matrix = parse_pae_file(pae_file)
    else:
      if pae_file:
        # A PAE file was requested but is not there: keep going without it
        # (as before) but tell the user instead of ignoring it silently
        print("Warning: PAE file '%s' not found, continuing without it" %(
          pae_file), file = self.logger)
      self.pae_matrix = None

    chain_ids = self.model.chain_ids()
    if len(chain_ids) != 1:
      raise Sorry("Input model should have exactly one chain id. (Found: %s)" %(
        " ".join(chain_ids)))

  def validate(self):  # make sure we have files
    """Always report success; no checks are performed here."""
    is_valid = True
    return is_valid

  def print_params(self):
    """Print the current parameter values to self.logger.

    The master PHIL is parsed from the class attribute master_phil_str with
    includes processed, so that the scope pulled in with 'include scope' is
    part of the master when self.params is formatted against it.
    """
    import iotbx.phil
    master_phil = iotbx.phil.parse(self.master_phil_str,
      process_includes = True)
    print ("\nInput parameters for process_predicted_model:\n", file = self.logger)
    master_phil.format(python_object = self.params).show(out = self.logger)

  def print_version_info(self):
    """Print a banner with the program name and current time, the working
    directory and the PHENIX_VERSION environment value ('svn' if unset) to
    self.logger."""

    import time
    log = self.logger
    rule = 60*"*"
    title = 10*" " + "PHENIX process_predicted_model" + "  " + \
      str(time.asctime())
    print ("\n" + rule + "\n" + title + "\n" + rule + "\n", file=log)

    working_directory = os.getcwd()
    print ("Working directory: ", working_directory, "\n", file=log)

    version = os.environ.get('PHENIX_VERSION','svn')
    print ("PHENIX VERSION: ", version, "\n", file=log)

def convert_chain_ids_to_single_character(m, chain_id = "A"):
  """Return a deep copy of model m after set_chain_id() has been applied
  to its hierarchy with chain_id.  The input model m is not passed to
  set_chain_id; only the copy is.

  Intent per the original note: rename all chains to "A" but keep separate
  chains for different segments."""

  from mmtbx.secondary_structure.find_ss_from_ca import set_chain_id
  renamed_model = m.deep_copy()
  hierarchy = renamed_model.get_hierarchy()
  set_chain_id(hierarchy, chain_id)
  # the hierarchy was edited directly, so let the model manager catch up
  renamed_model.reset_after_changing_hierarchy()
  return renamed_model

def get_available_letter(c, all_letters, used_ids):
  """Return a single-character ID based on c that is not in used_ids.

  Order of preference:
    1. c itself, if it is one character long and not used
    2. the first character of c, if not used
    3. the first entry of all_letters that is not used

  c may be None or empty, in which case only step 3 applies.
  Returns None if nothing is available.
  """

  if c:  # None or "" has no character to offer
    if len(c) == 1 and c not in used_ids:
      return c
    if c[0] not in used_ids:
      return c[0]

  return next((a for a in all_letters if a not in used_ids), None)



def limit_output_b(m, maximum_output_b = None):
  """ Cap the isotropic B values of model m at maximum_output_b.

      If a cap is given and at least one value exceeds it, return a deep
      copy of m in which every value above the cap is replaced by the cap
      (the capped values are set on the copy, not on m).  If no cap is given
      or no value exceeds it, return the original model object itself
      (not a deep copy)"""

  if maximum_output_b is None:
    return m  # nothing to limit

  capped_b = m.get_b_iso()
  if not (capped_b.min_max_mean().max > maximum_output_b):
    return m  # all values already within the limit

  too_high = (capped_b > maximum_output_b)
  capped_b.set_selected(too_high, maximum_output_b)
  capped_model = m.deep_copy() # REQUIRED so b-values are not set on m itself
  capped_model.set_b_iso(capped_b)
  return capped_model
# =============================================================================
# for reference documentation keywords: module-level alias of the PHIL
# definitions in Program.master_phil_str (presumably looked up by name by the
# documentation tooling -- confirm before renaming or removing)
master_phil_str = Program.master_phil_str
