#!/usr/bin/env python
# -*- coding: utf-8 -*-

import os
import sys
import math
import copy
import re
import platform
import subprocess
import tempfile

from vector3d import *
from pdb import *

class System:
  """Builds the XML scene description of a molecular system.

  Instances hold the parsed PDB (``PDB``), the current atom selection
  (``data``) and the XML text generated so far (``scenedata``).  Settings
  shared by every instance live in class attributes: the STRIDE executable
  path, the registered atom types and the colors/radii used for
  secondary-structure elements.
  """
  # STRIDE executable used by genChains() for secondary-structure assignment
  stride = "/usr/local/lib/vmd/stride_LINUX"
  if platform.system() == "Windows":
    stride = "C:/Program Files/University of Illinois/VMD/stride_WIN32.exe"
  # registered atom types: upper-cased type name -> WMAtom
  definedAtoms = {}
  # largest vdw among registered atoms; also the cell edge of the
  # neighbour grid used by genBonds()
  maxvdw = 1.0
  # value written to the N attribute of every <CHAIN> element
  chainnum = 3
  # colors / radii of secondary structures; a radius <= 0 or an empty color
  # means the attribute is not written
  helixcolor = "0xFF0000"
  helixrad = 0
  strandcolor = "0x0000FF"
  strandrad = 0
  coilcolor = "0x00FF00"
  coilrad = 0
  # suffix appended to HELIX/STRAND/COIL tag names (see setNewSuffix)
  suffix = ""
  # <ALIAS> blocks accumulated by setNewSuffix()
  spooledtext = ""

  # ---- class methods -------------------------------------------------

  @staticmethod
  def _radiusColorAttrs( rad, color ):
    """Return ``radius="R" color="C" `` text; each part only if set.

    The radius is written only when positive, the color only when it is a
    non-empty string.
    """
    ret = ""
    if rad > 0:
      ret += "radius=\"" + str( rad ) + "\" "
    if len( color ) > 0:
      ret += "color=\"" + str( color ) + "\" "
    return( ret )

  @classmethod
  def globalSetting( cls ):
    """Return the <GLOBAL> element followed by the alias definitions.

    When both ``suffix`` and ``spooledtext`` are empty, one <ALIAS> block
    for the current settings is generated; otherwise the spooled <ALIAS>
    blocks collected by setNewSuffix() are emitted instead.
    """
    ret = "  <GLOBAL>\n"
    ret += "    <COIL " + cls._radiusColorAttrs( cls.coilrad, cls.coilcolor ) + "/>\n"
    ret += "  </GLOBAL>\n"
    if len( cls.suffix ) == 0 and len( cls.spooledtext ) == 0:
      ret += cls.aliasSetting()
    else:
      ret += cls.spooledtext
    return( ret )

  @classmethod
  def aliasSetting( cls ):
    """Return an <ALIAS> block mapping HELIX/STRAND/COIL(+suffix) to shapes.

    Helices and strands are drawn as RIBBON elements, coils as COIL
    elements, each with the configured radius/color attributes.
    """
    ret = "  <ALIAS>\n"
    ret += ( "    <HELIX" + cls.suffix + " elem=\"RIBBON\" "
             + cls._radiusColorAttrs( cls.helixrad, cls.helixcolor ) + "/>\n" )
    ret += ( "    <STRAND" + cls.suffix + " elem=\"RIBBON\" "
             + cls._radiusColorAttrs( cls.strandrad, cls.strandcolor ) + "/>\n" )
    ret += ( "    <COIL" + cls.suffix + " elem=\"COIL\" "
             + cls._radiusColorAttrs( cls.coilrad, cls.coilcolor ) + "/>\n" )
    ret += "  </ALIAS>\n"
    return( ret )

  @classmethod
  def setNewSuffix( cls, s ):
    """Switch the tag suffix to *s* and spool an <ALIAS> block for it."""
    cls.suffix = s
    cls.spooledtext += cls.aliasSetting()

  @classmethod
  def clearSpooledText( cls ):
    """Forget spooled <ALIAS> blocks and reset the suffix."""
    cls.spooledtext = ""
    cls.suffix = ""

  @classmethod
  def clearAtoms( cls ):
    """Unregister all atom types and reset maxvdw to 1.0."""
    cls.definedAtoms = {}
    cls.maxvdw = 1.0

  @classmethod
  def registerAtom( cls, args ):
    """Register the atom type described by the text line *args*.

    Returns True when the line defined an atom (non-empty name), False
    otherwise.  maxvdw is raised to the new atom's vdw if larger.
    """
    atom = WMAtom( args.strip() )
    if not atom.empty():
      cls.definedAtoms[atom.name] = atom
      cls.maxvdw = max( cls.maxvdw, atom.vdw )
      return( True )
    return( False )

  @classmethod
  def registerDefaultsFromFile( cls, filename ):
    """Replace the registered atom types with those listed in *filename*.

    *filename* is resolved relative to this module's directory.  Blank
    lines and lines starting with '#' are skipped.
    """
    cls.clearAtoms()
    path = os.path.join( os.path.dirname( __file__ ), filename )
    with open( path, 'r' ) as fh:
      for line in fh:
        line = line.strip()
        # blank lines used to raise IndexError on line[0]
        if not line or line.startswith( "#" ):
          continue
        cls.registerAtom( line )

  @classmethod
  def showAtoms( cls ):
    """Print every registered atom type to stderr."""
    if len( cls.definedAtoms ) == 0:
      sys.stderr.write( ">> no atoms registered\n" )
    else:
      for data in cls.definedAtoms.values():
        sys.stderr.write( ">> " + str( data ) + "\n" )

  # ---- instance methods ----------------------------------------------

  def __init__( self ):
    """Create an empty system (see allclear())."""
    self.allclear()

  def clear( self ):
    """Drop the current atom selection only."""
    self.data = []

  def allclear( self ):
    """Reset the whole instance.

    Note that this also reloads the class-wide atom types from the
    "DEFAULTS" file, discarding any atoms registered earlier.
    """
    self.PDB = Pdb()
    self.data = []
    self.registerDefaultsFromFile( "DEFAULTS" )
    self.syssize = Vector3D()
    self.cellsize = Vector3D()
    # misspelled attribute kept in case external code still reads it
    self.celsize = Vector3D()
    self.minPos = Vector3D()
    self.indexedlist = []
    self.scenedata = ""
    self.storedata = ""

  def purgeData( self, toWhat = sys.__stdout__ ):
    """Write a <SCENE> element (stored text + scene text) to *toWhat*.

    Nothing is written when no atoms are selected.
    """
    if not self.hasData():
      return
    toWhat.write( "  <SCENE>\n" )
    if len( self.storedata ) > 0:
      toWhat.write( self.storedata + "\n" )
    toWhat.write( self.scenedata + "  </SCENE>\n" )

  def storeData( self ):
    """Return the generated scene text, or None when no atoms are selected."""
    if not self.hasData():
      return
    return( self.scenedata )

  def readPDB( self, fromwhat ):
    """Load PDB data into ``self.PDB`` via Pdb.read()."""
    self.PDB.read( fromwhat )

  def hasData( self ):
    """Return True when the current selection is non-empty."""
    return( len( self.data ) > 0 )

  def select( self, query ):
    """Narrow the selection with *query*.

    The query is applied to the current selection if there is one, else to
    the whole PDB.
    """
    if len( self.data ) != 0:
      self.data = query.select( self.data )
    else:
      self.data = query.select( self.PDB )

  def totatom( self ):
    """Return the number of atoms in the loaded PDB."""
    return( self.PDB.natom() )

  def natom( self ):
    """Return the number of selected atoms."""
    return( len( self.data ) )

  def genAtoms( self, arg = None ):
    """Append one <ATOM> element per selected atom of a registered type.

    *arg* is an optional whitespace-separated option string:
    ``radius=R`` overrides every atom's radius (a value of 0 falls back to
    the default), ``norad`` omits the radius attribute altogether.
    Returns the number of lines appended to ``scenedata``.
    """
    if len( self.data ) == 0:
      return( 0 )
    norad = False
    myrad = None
    if arg:
      for a in arg.split():
        aa = a.split( '=' )
        if len( aa ) == 2 and aa[0] == "radius":
          myrad = float( aa[1] )
        elif a == "norad":
          norad = True
    lines = []
    for dat in self.data:
      defa = self.definedAtoms.get( dat.atomtype )
      if defa is None:
        # unknown atom types are silently skipped
        continue
      xml = "    <ATOM pos=\"" + str( dat.pos ) + "\" "
      if not norad:
        radius = myrad if myrad else defa.radius
        xml += " radius=\"" + str( radius ) + "\" "
      xml += " color=\"" + str( defa.color ) + "\" "
      xml += " />\n"
      lines.append( xml )
    added = "".join( lines )
    self.scenedata += added
    return( added.count( "\n" ) )

  def genBonds( self, arg = None ):
    """Append a <BOND> element for every bonded pair of selected atoms.

    Two atoms are bonded when their distance does not exceed half the sum
    of their vdw values.  Atoms are binned into cubic cells of edge maxvdw
    so only the same and neighbouring cells have to be compared.
    *arg* may contain ``radius=R`` and ``offset=O`` (applied when
    non-zero).  Returns the number of lines appended to ``scenedata``.
    """
    if len( self.data ) == 0:
      return( 0 )
    self.__setSystemSize()
    self.__genIndexedList()
    myrad = None
    myoffset = None
    if arg:
      for a in arg.split():
        aa = a.split( "=" )
        if len( aa ) == 2 and aa[0] == "radius":
          myrad = float( aa[1] )
        elif len( aa ) == 2 and aa[0] == "offset":
          myoffset = float( aa[1] )
    bondlist = []
    for i in range( self.cellsize.x ):
      for j in range( self.cellsize.y ):
        for k in range( self.cellsize.z ):
          cell = self.indexedlist[i][j][k]
          # pairs of atoms sharing this cell
          for ia in range( len( cell ) - 1 ):
            for ja in range( ia + 1, len( cell ) ):
              bd = self.__assignBond( self.data[cell[ia]], self.data[cell[ja]] )
              if not bd.empty():
                bondlist.append( bd )
          # pairs with atoms in neighbouring cells
          self.__searchBonds( bondlist, i, j, k )
    lines = []
    for bond in bondlist:
      if myrad:
        bond.radius = myrad
      if myoffset:
        bond.offset = myoffset
      lines.append( bond.genXML() + "\n" )
    added = "".join( lines )
    self.scenedata += added
    return( added.count( "\n" ) )

  def __assignBond( self, ai, aj ):
    """Return a Bond between *ai* and *aj*, or an empty Bond if none.

    Pairs involving an unregistered atom type are reported on stderr and
    ignored (2012/4/30).
    """
    ret = Bond()
    for atom in ( ai, aj ):
      if atom.atomtype not in System.definedAtoms:
        sys.stderr.write( ">> __assignBond: Unknown atom type - " + str( atom.atomtype ) + "\n" )
        sys.stderr.write( ">> ignore this atom pair\n" )
        return( ret )
    defi = System.definedAtoms[ai.atomtype]
    defj = System.definedAtoms[aj.atomtype]
    dist = ( ai.pos - aj.pos ).norm()
    if dist > 0.5 * ( defi.vdw + defj.vdw ):
      return( ret )
    ret.atom0 = ai
    ret.color0 = defi.color
    ret.atom1 = aj
    ret.color1 = defj.color
    return( ret )

  def __searchBonds( self, bondlist, ui, uj, uk ):
    """Append bonds between cell (ui, uj, uk) and its forward neighbours.

    Only the 13 neighbour offsets accepted by __isIgnored() are visited, so
    every pair of neighbouring cells is examined exactly once.
    """
    home = self.indexedlist[ui][uj][uk]
    for i in range( 0, 2 ):
      for j in range( -1, 2 ):
        for k in range( -1, 2 ):
          if self.__isIgnored( i, j, k ) or self.__invalidCell( ui, uj, uk, i, j, k ):
            continue
          other = self.indexedlist[ui+i][uj+j][uk+k]
          for l in home:
            for m in other:
              bd = self.__assignBond( self.data[l], self.data[m] )
              if not bd.empty():
                bondlist.append( bd )

  def __isIgnored( self, i, j, k ):
    """Return True for the zero offset and every 'backward' offset.

    An offset is backward when its first non-zero component is negative,
    which is exactly lexicographic ordering (i, j, k) <= (0, 0, 0).
    """
    return( ( i, j, k ) <= ( 0, 0, 0 ) )

  def __invalidCell( self, ui, uj, uk, i, j, k ):
    """Return True when cell (ui+i, uj+j, uk+k) lies outside the grid."""
    if i + ui >= self.cellsize.x or i + ui < 0:
      return( True )
    if j + uj >= self.cellsize.y or j + uj < 0:
      return( True )
    if k + uk >= self.cellsize.z or k + uk < 0:
      return( True )
    return( False )

  def __setSystemSize( self ):
    """Compute the bounding box (minPos, syssize) of the selected atoms."""
    if len( self.data ) == 0:
      return
    xs = [ dat.pos.x for dat in self.data ]
    ys = [ dat.pos.y for dat in self.data ]
    zs = [ dat.pos.z for dat in self.data ]
    self.minPos = Vector3D( min( xs ), min( ys ), min( zs ) )
    self.syssize = Vector3D()
    self.syssize.x = max( xs ) - self.minPos.x
    self.syssize.y = max( ys ) - self.minPos.y
    self.syssize.z = max( zs ) - self.minPos.z

  def __genIndexedList( self ):
    """Bin the indices of the selected atoms into cells of edge maxvdw."""
    grid = float( self.maxvdw )
    self.cellsize = Vector3D()
    self.cellsize.x = int( self.syssize.x / grid + 1 )
    self.cellsize.y = int( self.syssize.y / grid + 1 )
    self.cellsize.z = int( self.syssize.z / grid + 1 )
    self.__initIndexedList()
    for counter, data in enumerate( self.data ):
      myX = int( ( data.pos.x - self.minPos.x ) / grid )
      myY = int( ( data.pos.y - self.minPos.y ) / grid )
      myZ = int( ( data.pos.z - self.minPos.z ) / grid )
      self.indexedlist[myX][myY][myZ].append( counter )

  def __initIndexedList( self ):
    """Create an empty cellsize.x * cellsize.y * cellsize.z grid of lists."""
    # Rebuild from scratch: appending to the grid of a previous call would
    # keep stale cells (and stale atom indices) around.
    self.indexedlist = [ [ [ [] for k in range( self.cellsize.z ) ]
                           for j in range( self.cellsize.y ) ]
                         for i in range( self.cellsize.x ) ]

  def genChains( self, gencoil = False ):
    """Append <CHAIN> elements built from the CA atoms of amino acids.

    A new chain starts whenever the residue index decreases.  Unless
    *gencoil* is true, STRIDE is run on the selected amino-acid atoms to
    emit HELIX/STRAND/COIL ranges; with *gencoil* every chain is written as
    one coil.  Returns the number of lines appended to ``scenedata``.
    """
    if len( self.data ) == 0:
      return( 0 )
    linenum = self.scenedata.count( "\n" )
    localdata0 = isAminoAcid().select( self.data )
    myquery = PdbQuery()
    myquery.selectAtoms( "CA" )
    localdata = myquery.select( localdata0 )
    if len( localdata ) == 0:
      sys.stderr.write( ">> genChains: no CA atoms in the selection\n" )
      return( 0 )
    # per-CHAIN STRIDE "ASG" lines; not needed (and not computed) for gencoil
    ppstr = []
    if not gencoil:
      ppstr = self.__splitStrideLog( self.__runStride( localdata0 ) )
    self.scenedata += "    <CHAIN N=\"" + str( System.chainnum ) + "\">\n"
    initindex = localdata[0].resindex
    previndex = -1
    prev = None
    curchain = 0
    for dat in localdata:
      if dat.resindex != previndex + 1 and prev:
        if dat.isSameAtom( prev ):
          # ignore if the same atom appears
          # trick for atoms involving alternate location indicator
          continue
      if previndex > 0 and dat.resindex < previndex:
        # residue numbering restarted: must be a different CHAIN
        self.__closeChain( gencoil, initindex, previndex, ppstr, curchain )
        initindex = dat.resindex
        self.scenedata += "    <CHAIN N=\"" + str( System.chainnum ) + "\">\n"
        curchain += 1
      self.scenedata += "      <POINT index=\"" + str( dat.resindex ) + "\" pos=\"" + str( dat.pos ) + "\" />\n"
      previndex = dat.resindex
      prev = dat
    if initindex != previndex:
      self.__closeChain( gencoil, initindex, previndex, ppstr, curchain )
    else:
      # single-residue last chain: no secondary structure, but the element
      # must still be closed (it used to be left open)
      self.scenedata += "    </CHAIN>\n"
    return( self.scenedata.count( "\n" ) - linenum )

  def __closeChain( self, gencoil, initindex, lastindex, ppstr, curchain ):
    """Emit the secondary structure of chain *curchain* and close it."""
    sys.stderr.write( ">> CHAIN" + str( curchain ) + ": init:" + str( initindex ) + " last:" + str( lastindex ) + "\n" )
    if gencoil:
      self.__printSec( "Coil", initindex, lastindex )
    elif curchain < len( ppstr ):
      self.__genSecondary( initindex, lastindex, ppstr[curchain] )
    else:
      # STRIDE reported fewer chains than were found here
      sys.stderr.write( ">> no STRIDE output for CHAIN" + str( curchain ) + "; written as coil\n" )
      self.__printSec( "Coil", initindex, lastindex )
    self.scenedata += "    </CHAIN>\n"

  def __runStride( self, atoms ):
    """Run STRIDE on *atoms* (written to a temp PDB file); return ASG lines."""
    fd, path = tempfile.mkstemp()
    try:
      with os.fdopen( fd, "w" ) as fh:
        for dat in atoms:
          fh.write( dat.toPdbString() + "\n" )
      cmdline = [System.stride + " " + path]
      if platform.system() == "Windows":
        cmdline = [System.stride, path]
      p = subprocess.Popen( cmdline, shell=True, stdin=subprocess.PIPE,
                            stdout=subprocess.PIPE, stderr=subprocess.STDOUT,
                            close_fds=False, universal_newlines=True )
      # communicate() reads all output and waits, so the child is reaped
      output = p.communicate()[0]
    finally:
      os.remove( path )
    return( [ line for line in output.splitlines() if line.startswith( "ASG" ) ] )

  def __splitStrideLog( self, pstr ):
    """Split STRIDE ASG lines into per-CHAIN lists.

    A new chain starts whenever the residue number (field index 3) does not
    increase.
    """
    prevNum = -1
    ppstr = []
    lstr = []
    for line in pstr:
      resnum = int( line.split()[3] )
      if resnum <= prevNum:
        ppstr.append( lstr )
        lstr = []
      lstr.append( line )
      prevNum = resnum
    if len( lstr ) > 0:
      ppstr.append( lstr )
    # log message about number of CHAINs
    sys.stderr.write( ">> There seems to be " + str( len( ppstr ) ) + " CHAINs\n" )
    return( ppstr )

  def __genSecondary( self, initindex, lastindex, pstr ):
    """Emit HELIX/STRAND/COIL ranges for residues initindex..lastindex.

    *pstr* holds STRIDE ASG lines; field index 3 is read as the residue
    number and field index 6 as the structure name.  Consecutive residues
    of the same kind are merged into one range.
    NOTE: the secondary structure type of the terminal residue is ignored
    (modified on 2012/5/25).
    """
    prev = None
    prevnum = 0
    minnum = None
    for line in pstr:
      sline = line.split()
      resnum = int( sline[3] )
      if resnum < initindex:
        continue
      if resnum > lastindex:
        break
      # extract secondary structure type
      if "Helix" in sline[6]:
        secname = "Helix"
      elif "Strand" in sline[6]:
        secname = "Strand"
      else:
        secname = "Coil"
      if not prev:
        # first residue: a range may be emitted right away when it is also
        # the last one, so its start must be known (was UnboundLocalError)
        minnum = sline[3]
      # write secondary structure information
      if ( ( secname != prev or resnum - int( prevnum ) != 1 ) and prev ) or resnum == lastindex:
        if resnum == lastindex:
          prevnum = sline[3]
        self.__printSec( prev, minnum, prevnum )
        prev = None
      if not prev:
        minnum = sline[3]
      prevnum = sline[3]
      prev = secname

  def __printSecEach( self, objtype, minnum, prevnum ):
    """Append ``<OBJTYPE+suffix init=".." last=".." />`` to scenedata."""
    self.scenedata += "      <" + str( objtype ) + System.suffix + " "
    self.scenedata += "init=\"" + str( minnum ) + "\" last=\"" + str( prevnum ) + "\" />\n"

  def __printSec( self, sec, minnum, prevnum ):
    """Emit a range for *sec* ("Helix", "Strand", anything else = coil)."""
    tag = { "Helix": "HELIX", "Strand": "STRAND" }.get( sec, "COIL" )
    self.__printSecEach( tag, str( minnum ), str( prevnum ) )

  def write( self, toWhat ):
    """Write the selection (or the whole PDB if none) in PDB format."""
    if len( self.data ) == 0:
      self.PDB.write( toWhat )
    else:
      for dat in self.data:
        toWhat.write( dat.toPdbString() + "\n" )

# Pre-defined atom type (name, color, radius and vdw value).
# An instance describes an atom type, not an actual atom: the actual atoms
# are held by the System class above, which also produces the output.
class WMAtom:
  """A pre-defined atom type parsed from one line of text.

  The line has the form ``NAME COLOR RADIUS VDW`` (whitespace separated;
  extra fields are ignored).  NAME is upper-cased.  An atom whose name is
  empty is reported as empty() and is not registered by
  System.registerAtom().
  """

  def __init__( self, line = "" ):
    # Start from the defaults so every attribute exists even when *line* is
    # malformed; a short non-empty line used to leave the instance without
    # a ``name`` attribute, so empty() raised AttributeError.
    self.name = ""
    self.color = "0x00FF00"
    self.radius = 1
    self.vdw = 1.0
    if line:
      self.read( line )

  def read( self, line ):
    """Set name/color/radius/vdw from *line*.

    Lines with fewer than 4 fields are ignored.  Raises ValueError if the
    radius or vdw field is not a number; the atom is then left unchanged.
    """
    items = line.split()
    # simply ignore if number of values is insufficient
    if len( items ) < 4:
      return
    # parse everything first so a bad number does not half-update the atom
    name = items[0].upper()
    color = items[1]
    radius = float( items[2] )
    vdw = float( items[3] )
    self.name = name
    self.color = color
    self.radius = radius
    self.vdw = vdw

  def empty( self ):
    """Return True when no atom name has been set."""
    return( self.name == "" )

  def __str__( self ):
    return( " ".join( [ str( self.name ), str( self.color ), str( self.radius ), str( self.vdw ) ] ) )

# This class defines a bond between two concrete atoms.
# Unlike WMAtom, which describes a pre-defined atom type, it does not hold
# any information about pre-defined bond types.
class Bond:
  """One bond between two atoms, rendered as a <BOND> XML element.

  atom0 / atom1 are the bonded atoms (anything with a ``pos`` attribute),
  color0 / color1 their colors.  radius and offset are optional and are
  written only when truthy.
  """

  def __init__( self ):
    self.atom0 = self.atom1 = None
    self.color0 = self.color1 = None
    self.radius = self.offset = None

  def empty( self ):
    """Return True unless both endpoints have been assigned."""
    return( any( end is None for end in ( self.atom0, self.atom1 ) ) )

  def genXML( self ):
    """Return the <BOND .../> element text (no trailing newline).

    A single ``color`` attribute is written when both colors match,
    otherwise separate ``col0``/``col1`` attributes.
    """
    attrs = [ "pos0=\"" + str( self.atom0.pos ) + "\"",
              "pos1=\"" + str( self.atom1.pos ) + "\"" ]
    if self.radius:
      attrs.append( "radius=\"" + str( self.radius ) + "\"" )
    if self.offset:
      attrs.append( "offset=\"" + str( self.offset ) + "\"" )
    if self.color0 == self.color1:
      attrs.append( "color=\"" + str( self.color0 ) + "\"" )
    else:
      attrs.append( "col0=\"" + str( self.color0 ) + "\"" )
      attrs.append( "col1=\"" + str( self.color1 ) + "\"" )
    return( "    <BOND " + " ".join( attrs ) + " />" )
