#!/usr/bin/env python
# -*- encoding: utf-8 -*-

# Copyright (c) 2012, tamanegi (tamanegi@users.sourceforge.jp)
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to
# deal in the Software without restriction, including without limitation the
# rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
# sell copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
# 
# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.
# 
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
# IN THE SOFTWARE.

# Gaussian-type basis set related code

import copy
import math
import os
import sys

filedirname = os.path.dirname( os.path.abspath( __file__ ) )
commonpath = os.path.join( filedirname, "../common" )
if not commonpath in sys.path:
  sys.path.append( commonpath )
from vector3d import *

# A primitive Cartesian Gaussian basis function of the form
#   x^l y^m z^n exp( -alpha * r^2 )
class Basis:
  """A primitive Cartesian Gaussian, x^l y^m z^n exp(-alpha * r^2).

  Name-mangled state:
    __exponent    -- alpha
    __contract    -- contraction coefficient of this primitive
    __l, __m, __n -- powers of x, y and z
    __norm        -- self-overlap integral of the unnormalized primitive,
                     rewritten by every calcNormFactor() call
    __normalize   -- 1 / sqrt(__norm); folded into getCoeff()
  """

  def __init__( self, exponent = 1.0, contract = 1.0, l = 0, m = 0, n = 0 ):
    self.setValues( exponent, contract, l, m, n )

  def setValues( self, exponent = 1.0, contract = 1.0, l = 0, m = 0, n = 0 ):
    """Assign every parameter, then compute the normalization exactly once."""
    # suppress per-setter recomputation; the factor is computed below
    self.setExponent( exponent, False )
    self.setContractionCoeff( contract, False )
    self.setAzimuthal( l, m, n, False )
    self.__norm = 1.0
    self.__normalize = self.calcNormFactor()

  # ---- accessors ----

  def getExponent( self ):
    """Return alpha as a float."""
    return float( self.__exponent )

  def getCoeff( self ):
    """Return contraction coefficient times normalization factor, as a float."""
    return float( self.__contract ) * float( self.__normalize )

  def getContractionCoeff( self ):
    """Return the raw contraction coefficient (as stored, not converted)."""
    return self.__contract

  def getNormalize( self ):
    """Return the cached normalization factor 1/sqrt(norm)."""
    return self.__normalize

  def getNorm( self ):
    """Return the cached self-overlap of the unnormalized primitive."""
    return self.__norm

  def getL( self ):
    return self.__l

  def getM( self ):
    return self.__m

  def getN( self ):
    return self.__n

  # ---- mutators (each optionally refreshes the normalization) ----

  def setExponent( self, exponent, recalcNormFactor = True ):
    """Set alpha; refresh the normalization unless recalcNormFactor is false."""
    self.__exponent = exponent
    if not recalcNormFactor:
      return
    self.__normalize = self.calcNormFactor()

  def setContractionCoeff( self, contract, recalcNormFactor = True ):
    """Set the contraction coefficient; optionally refresh the normalization."""
    self.__contract = contract
    if not recalcNormFactor:
      return
    self.__normalize = self.calcNormFactor()

  def setAzimuthal( self, l, m, n, recalcNormFactor = True ):
    """Set the x/y/z powers; optionally refresh the normalization."""
    self.__l, self.__m, self.__n = l, m, n
    if not recalcNormFactor:
      return
    self.__normalize = self.calcNormFactor()

  def calcNormFactor( self ):
    """Return 1/sqrt(S) and store S in __norm, where

      S = (pi / 2a)^(3/2) * (2l-1)!! (2m-1)!! (2n-1)!! / (4a)^(l+m+n)

    is the integral of the squared unnormalized primitive.
    """
    alpha = float( self.__exponent )
    overlap = math.pow( math.sqrt( math.pi / ( 2.0 * alpha ) ), 3 )
    # total angular momentum = how many factors of 4a end up in the denominator
    total = int( self.__l ) + int( self.__m ) + int( self.__n )
    for power in ( self.__l, self.__m, self.__n ):
      overlap *= self.__calcCoeff( power )
    overlap /= math.pow( 4.0 * alpha, total )
    self.__norm = overlap
    return 1.0 / math.sqrt( self.__norm )

  def __calcCoeff( self, num ):
    # odd double factorial (2*num - 1)!! = 3 * 5 * ... * (2*num - 1); 1 for num <= 1
    product = 1
    for factor in range( 3, 2 * num, 2 ):
      product *= factor
    return product

  def getValueAt( self, rsq, vec ):
    """Evaluate coeff * x^l y^m z^n * exp(-alpha * rsq) at displacement vec.

    rsq is expected to be the squared length of vec (the caller supplies both).
    """
    angular = math.pow( vec.x, self.__l ) * math.pow( vec.y, self.__m ) * math.pow( vec.z, self.__n )
    radial = math.exp( - float( self.__exponent ) * rsq )
    return self.getCoeff() * angular * radial

  def __str__( self ):
    fields = [ "Contraction Coeff: ", str( self.__contract ),
               " Exponent: ", str( self.__exponent ),
               " XYZ = ", str( self.__l ), " ", str( self.__m ), " ", str( self.__n ) ]
    return "".join( fields )

def normBases( basis0, basis1 ):
  """Return the overlap integral of two normalized primitives basis0, basis1.

  For shared powers (l, m, n), the product of Gaussians with exponents a0 and
  a1 integrates exactly like the square of one primitive whose exponent is
  (a0 + a1) / 2, so that primitive's cached norm is reused.  Only basis0's
  powers are read: both inputs are ASSUMED to share l, m, n.
  """
  mean_exponent = ( basis0.getExponent() + basis1.getExponent() ) * 0.5
  prefactor = basis0.getCoeff() * basis1.getCoeff()
  merged = Basis( mean_exponent, 1.0, basis0.getL(), basis0.getM(), basis0.getN() )
  return prefactor * merged.getNorm()

# class for an orbital such as 1s, 2s, 2px, 3dx
class Orbital:
  """A contracted orbital (e.g. 1s, 2s, 2px): a sum of primitive gaussians.

  Each primitive must provide getValueAt(rsq, vec) and, for initialize(),
  be accepted by normBases (i.e. behave like Basis).
  """
  def __init__( self ):
    # overall normalization factor; 0.0 until initialize() computes it
    self.__normalize = 0.0
    # list of primitive gaussians
    self.__gaussians = []

  def addGaussian( self, basis ):
    """Append one primitive; the normalization is only updated by initialize()."""
    self.__gaussians.append( basis )

  def initialize( self ):
    """Compute the normalization factor 1 / sqrt(<phi|phi>).

    <phi|phi> is the sum of normBases over every ordered pair of primitives.
    An orbital with no primitives is the zero function, so its factor is
    left at 0.0 instead of dividing by zero.
    """
    if not self.__gaussians:
      # previously 1.0 / 0.0 -> ZeroDivisionError for an empty orbital
      self.__normalize = 0.0
      return
    norm = 0.0
    for basis0 in self.__gaussians:
      for basis1 in self.__gaussians:
        norm += normBases( basis0, basis1 )
    self.__normalize = math.sqrt( 1.0 / norm )

  def getValueAt( self, mypos, pos ):
    """Return the orbital value at pos for an orbital centred at mypos.

    pos and mypos must support subtraction yielding an object with
    square() and x/y/z attributes (e.g. Vector3D).
    """
    vec = pos - mypos
    rsq = vec.square()
    total = sum( ( gaussian.getValueAt( rsq, vec ) for gaussian in self.__gaussians ), 0.0 )
    return self.__normalize * total

  def size( self ):
    """Return the number of primitives."""
    return len( self.__gaussians )

  def empty( self ):
    """Return True when the orbital has no primitives."""
    return self.size() == 0

  def __str__( self ):
    lines = [ "Normalize Factor: " + str( self.__normalize ) ]
    lines.extend( str( gs ) for gs in self.__gaussians )
    return "\n".join( lines ) + "\n"

# atomic orbital; AO
class AtomicOrbital:
  """The orbitals attached to one atom, together with the atom's position."""
  def __init__( self ):
    # atom label; left out of str() when empty
    self.atomname = ""
    # orbitals in insertion order
    self.__orbitals = []
    # atom position; a default-constructed Vector3D until setPosition()
    self.__position = Vector3D()

  def setOrbitals( self, ao ):
    """Replace this atom's orbitals with a deep copy of ao's orbitals."""
    source = ao.getOrbitals()
    self.__orbitals = copy.deepcopy( source )

  def getOrbitals( self ):
    """Return the internal orbital list itself (not a copy)."""
    return self.__orbitals

  def addOrbital( self, orbital ):
    """Append orbital, then call its initialize() to compute its normalization."""
    self.__orbitals.append( orbital )
    orbital.initialize()

  def setPosition( self, pos ):
    """Store pos by reference as the atom position."""
    self.__position = pos

  def getPosition( self ):
    return self.__position

  def size( self ):
    """Return the number of orbitals."""
    return len( self.__orbitals )

  def empty( self ):
    """Return True when no orbitals have been added."""
    return self.size() == 0

  def __str__( self ):
    header = self.atomname + "\n" if self.atomname != "" else ""
    body = "".join( str( orb ) + "\n" for orb in self.__orbitals )
    return header + body
