# Assertions that reject SQL identifiers (table/column names) which are too
# long or collide with SQL reserved words.  Mixed in wherever names must be
# checked before they reach the database.
module AssertSQLIdentifier
  # Identifiers whose string form is longer than this are rejected by
  # assert_length.
  # NOTE(review): String#length counts characters (not bytes) on Ruby 1.9+;
  # if this mirrors a byte-based database limit, multibyte names may pass —
  # confirm.
  LENGTH_MAX = 30

  ANSI_SQL_RESERVED = %w[ADD ALL ALTER AND ANY AS ASC BETWEEN BY CHAR CHECK CONNECT CREATE CURRENT DATE DECIMAL DEFAULT DELETE DESC DISTINCT DROP ELSE FLOAT FOR FROM GRANT GROUP HAVING IMMEDIATE IN INSERT INTEGER INTERSECT INTO IS LEVEL LIKE NOT NULL OF ON OPTION OR ORDER PRIOR PRIVILEGES PUBLIC REVOKE ROWS SELECT SESSION SET SIZE SMALLINT TABLE THEN TO UNION UNIQUE UPDATE USER VALUES VARCHAR VIEW WHENEVER WITH].freeze
  ORACLE_SQL_RESERVED = %w[ACCESS AUDIT A ABORT ACCESSED ACCOUNT ACTIVATE ADMINISTER ADMINISTRATOR ADVISE ADVISOR AFTER ALGORITHM ALIAS ALLOCATE ALLOW ANALYZE CLUSTER COLUMN COMMENT COMPRESS EXCLUSIVE EXISTS FILE IDENTIFIED INCREMENT INDEX INITIAL LOCK LONG MAXEXTENTS MINUS MLSLABEL MODE MODIFY NOAUDIT NOCOMPRESS NOWAIT NUMBER OFFLINE ONLINE PCTFREE RAW RENAME RESOURCE ROW ROWID ROWNUM SHARE START SUCCESSFUL SYNONYM SYSDATE TRIGGER UID VALIDATE VARCHAR2 WHERE].freeze # 'ADMIN' is also reserved in fact.
  LOCAL_RESERVED = %w[APPLICATION RECORD RECORD_ID RECORD_NAME RECORD_CODE].freeze # for record picker

  # Words rejected by assert_unreserved.  Computed once here instead of
  # rebuilding the union for every name checked.
  # NOTE(review): LOCAL_RESERVED is not part of this set (the check never
  # included it) — confirm whether it is used elsewhere or should be added.
  SQL_RESERVED = (ANSI_SQL_RESERVED | ORACLE_SQL_RESERVED).freeze

  # Checks that every name (String or Symbol) is at most LENGTH_MAX characters.
  #
  # Raises ArgumentError naming the first offending identifier.
  def assert_length(*names)
    names.each do |name|
      raise ArgumentError, "exceed maximum length of name: #{name}" if name.to_s.length > LENGTH_MAX
    end
  end

  # Checks, case-insensitively, that no name (String or Symbol) is an ANSI or
  # Oracle reserved word.
  #
  # Raises ArgumentError naming the first reserved identifier.
  def assert_unreserved(*names)
    names.each do |name|
      raise ArgumentError, "reserved identifier: #{name}" if SQL_RESERVED.include?(name.to_s.upcase)
    end
  end
end

# Validates the identifiers of tables created through the framework and adds
# the framework's standard columns to them.  The nested ForceDomain module
# scopes model queries and writes to the current domain.
module ForceColumns
  include AssertSQLIdentifier

  # Creates table +name+ via the next create_table up the ancestor chain,
  # after checking the table name with AssertSQLIdentifier.  Every column
  # declared on the table definition (including the standard ones) is
  # checked as well.
  #
  # Standard columns added around the caller's own columns:
  #   before: domain_id (integer, NOT NULL)
  #   after:  created_at, updated_at (string, limit 14)
  #           created_by, updated_by, created_in, updated_in (integer)
  #           lock_version (integer, NOT NULL, default 0)
  # Any of them is omitted when +options+ maps its name to +false+,
  # e.g. <tt>create_table :items, :lock_version => false</tt>.
  #
  # Raises ArgumentError when the table name or a column name is too long or
  # reserved.
  def create_table(name, options = {})
    # The ActiveRecordStore sessions table and schema_migrations are created
    # untouched.
    return super if %w[sessions schema_migrations].include?(name.to_s)

    # Pass through unchanged while loading db/schema.rb (:force set), unless
    # :force_columns is also given.
    return super if options[:force] && !options[:force_columns]

    assert_length name
    assert_unreserved name
    super do |t|
      # Give the table definition the assertion helpers once, rather than
      # re-extending on every column call.
      t.extend(AssertSQLIdentifier)
      class << t
        # Validates the column name, then delegates to the table
        # definition's own column method.
        def column(name, type, options = {})
          assert_length name
          assert_unreserved name
          super(name.to_s, type, options)
        end
      end

      # Adds one standard column unless the caller disabled it with
      # options[column_name] == false.
      add_standard_column = lambda do |column_name, type, column_options|
        t.column(column_name, type, column_options) unless options[column_name] == false
      end

      add_standard_column.call(:domain_id, :integer, :null => false)

      # create_table may be called without a block; yielding unconditionally
      # would raise LocalJumpError.
      yield t if block_given?

      [:created_at, :updated_at].each do |column_name|
        add_standard_column.call(column_name, :string, :limit => 14)
      end
      [:created_by, :updated_by, :created_in, :updated_in].each do |column_name|
        add_standard_column.call(column_name, :integer, {})
      end
      add_standard_column.call(:lock_version, :integer, :null => false, :default => 0)
    end
  end

  # Checks both the table name and the column name with
  # AssertSQLIdentifier, then delegates to the next add_column.
  #
  # Raises ArgumentError when either name is too long or reserved.
  def add_column(table_name, column_name, type, options = {})
    assert_length table_name, column_name
    assert_unreserved table_name, column_name
    super
  end

  # Included into a model class to restrict find/calculate to the current
  # domain and to fill in domain_id before create/update.  This only affects
  # models whose table has a domain_id column.  Query scoping additionally
  # requires Domain.current_id to be set.
  module ForceDomain
    def self.included(base) #:nodoc:
      super
      base.extend(ClassMethods)
      class << base
        alias_method_chain :find, :domain
        alias_method_chain :calculate, :domain
      end
      base.alias_method_chain :create, :force_domain
      base.alias_method_chain :update, :force_domain
    end

    module ClassMethods
      # find, restricted to rows of the current domain when applicable.
      def find_with_domain(*args)
        with_current_domain_scope { find_without_domain(*args) }
      end

      # calculate, restricted to rows of the current domain when applicable.
      def calculate_with_domain(*args)
        with_current_domain_scope { calculate_without_domain(*args) }
      end

      private

      # Runs the block inside a :find scope on "<table>.domain_id = ?" bound
      # to Domain.current_id when this model has a domain_id column and a
      # current domain is set.  Otherwise it runs the block unscoped.
      # Returns the block's value.
      def with_current_domain_scope(&block)
        if column_names.include?("domain_id") && Domain.current_id
          with_scope(:find => {:conditions => ["#{table_name}.domain_id = ?", Domain.current_id]}, &block)
        else
          yield
        end
      end
    end

    # Sets domain_id to Domain.current_id when the model has a domain_id
    # column and the attribute is still nil.
    def set_current_domain
      if self.class.column_names.include?("domain_id") && read_attribute("domain_id").nil?
        write_attribute("domain_id", Domain.current_id)
      end
    end

    # Arguments are forwarded untouched so this wrapper composes with other
    # alias_method_chain wrappers that call create with arguments.
    def create_with_force_domain(*args) #:nodoc:
      set_current_domain
      create_without_force_domain(*args)
    end

    # Arguments are forwarded untouched.
    # NOTE(review): some Rails versions call update with arguments (e.g. a
    # list of changed attribute names from dirty tracking); a zero-arity
    # wrapper would raise ArgumentError there — confirm for the Rails in use.
    def update_with_force_domain(*args) #:nodoc:
      set_current_domain
      update_without_force_domain(*args)
    end
  end
end
