/*
 *  Copyright 2010 argius
 *
 *  Licensed under the Apache License, Version 2.0 (the "License");
 *  you may not use this file except in compliance with the License.
 *  You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 *  Unless required by applicable law or agreed to in writing, software
 *  distributed under the License is distributed on an "AS IS" BASIS,
 *  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 *  See the License for the specific language governing permissions and
 *  limitations under the License.
 *
 */
package net.argius.stew.ui.window;

import static java.util.Arrays.asList;
import static java.util.Collections.nCopies;
import static net.argius.stew.Iteration.*;
import static net.argius.stew.ui.window.Resource.EMPTY_STRING;

import java.math.*;
import java.sql.*;
import java.util.*;
import java.util.regex.*;

import javax.swing.table.*;

import net.argius.logging.*;
import net.argius.stew.*;

/**
 * TableModel for result-set tables.
 * Provides a simple write-back facility (linked to the database): when the
 * result was produced by a query against a single table whose primary-key
 * columns are all present in the result, cell edits, inserted rows and
 * deleted rows are reflected to that table.
 */
final class ResultSetTableModel extends DefaultTableModel {

    private static final Logger log = LoggerFactory.getLogger(ResultSetTableModel.class);

    private static final long serialVersionUID = -8861356207097438822L;

    /**
     * Captures everything after the last <code>FROM</code> keyword of a SELECT statement.
     * The word boundary keeps identifiers such as <code>x_from</code> from being taken
     * as the keyword.
     */
    private static final String PTN1 = "^.*\\s*SELECT\\s*.+\\s*\\bFROM\\s+(.+)";

    /** Compiled once; DOTALL lets a statement span several lines. */
    private static final Pattern SELECT_FROM_PATTERN = Pattern.compile(PTN1, Pattern.CASE_INSENSITIVE
                                                                             | Pattern.DOTALL);

    private final int[] types;
    private final String commandString;
    private final Set<Integer> unlinkedRows;

    private Connection conn;
    private Object tableName;
    private String[] primaryKeys;
    /** Column indexes of {@link #primaryKeys}, in the same order. */
    private int[] primaryKeyColumns;
    private boolean updatable;

    /**
     * Constructor.
     * @param rs the result set (used for column types and for the linking analysis)
     * @param columnNames the column identifiers
     * @param cmd the command that produced the result set
     * @throws SQLException if the result set metadata cannot be read
     */
    ResultSetTableModel(ResultSet rs, Object[] columnNames, String cmd) throws SQLException {
        super(0, columnNames.length);
        ResultSetMetaData meta = rs.getMetaData();
        final int columnCount = getColumnCount();
        int[] types = new int[columnCount];
        for (int i = 0; i < columnCount; i++) {
            types[i] = meta.getColumnType(i + 1);
        }
        for (int i = 0; i < columnIdentifiers.size(); i++) {
            @SuppressWarnings({"unchecked", "unused"})
            Object o = columnIdentifiers.set(i, columnNames[i]);
        }
        this.types = types;
        this.commandString = cmd;
        this.unlinkedRows = new TreeSet<Integer>();
        try {
            analyzeForLinking(rs, cmd);
        } catch (Exception ex) {
            // linking is best-effort: on failure the model simply stays read-only
            log.warn("", ex);
        }
    }

    @Override
    public Class<?> getColumnClass(int columnIndex) {
        switch (types[columnIndex]) {
            case Types.CHAR:
            case Types.VARCHAR:
            case Types.LONGVARCHAR:
                return String.class;
            case Types.BIT:
                return Boolean.class;
            case Types.TINYINT:
                return Byte.class;
            case Types.SMALLINT:
                return Short.class;
            case Types.INTEGER:
                return Integer.class;
            case Types.BIGINT:
                return Long.class;
            case Types.REAL:
                return Float.class;
            case Types.DOUBLE:
            case Types.FLOAT:
                return Double.class;
            case Types.DECIMAL:
            case Types.NUMERIC:
                return BigDecimal.class;
            case Types.DATE:
            case Types.TIME:
            case Types.TIMESTAMP:
            default:
                return Object.class;
        }
    }

    @Override
    public boolean isCellEditable(int row, int column) {
        if (primaryKeys == null || primaryKeys.length == 0) {
            return false;
        }
        return super.isCellEditable(row, column);
    }

    /**
     * Sets a cell value; for a linked row of an updatable model the change is
     * first written to the database with an UPDATE statement.
     * @throws RuntimeException wrapping the SQLException if the UPDATE fails
     *         (the cell is left unchanged in that case)
     */
    @Override
    public void setValueAt(Object newValue, int row, int column) {
        final Object oldValue = getValueAt(row, column);
        final boolean changed;
        if (newValue == null) {
            changed = (newValue != oldValue);
        } else {
            changed = !newValue.equals(oldValue);
        }
        if (!changed) {
            if (log.isDebugEnabled()) {
                log.debug("skip to update");
            }
        } else if (!isLinkedRow(row)) {
            if (log.isDebugEnabled()) {
                log.debug("update unlinked row");
            }
        } else if (!updatable) {
            // nothing to write back to; only the model is changed
            if (log.isDebugEnabled()) {
                log.debug("skip to update (not updatable)");
            }
        } else {
            try {
                // runs before super.setValueAt, so the WHERE clause sees the old key values
                executeUpdate(row, column, newValue);
            } catch (SQLException ex) {
                throw new RuntimeException(ex);
            }
        }
        super.setValueAt(newValue, row, column);
    }

    /**
     * Removes a row from the model (without touching the database) and keeps the
     * indexes of unlinked rows in sync with the rows that shift up.
     */
    @Override
    public void removeRow(int row) {
        super.removeRow(row);
        unlinkedRows.remove(row);
        List<Integer> shifted = new ArrayList<Integer>();
        for (Iterator<Integer> it = unlinkedRows.iterator(); it.hasNext();) {
            final int index = it.next();
            if (index > row) {
                it.remove();
                shifted.add(index - 1);
            }
        }
        unlinkedRows.addAll(shifted);
    }

    /**
     * Adds an unlinked row (a row that does not yet exist in the database).
     * @param rowData row data
     */
    void addUnlinkedRow(Object[] rowData) {
        addUnlinkedRow(convertToVector(rowData));
    }

    /**
     * Adds an unlinked row (a row that does not yet exist in the database).
     * @param rowData row data
     */
    void addUnlinkedRow(Vector<?> rowData) {
        addRow(rowData);
        unlinkedRows.add(getRowCount() - 1);
    }

    /**
     * Links a row to the database by INSERTing it.
     * @param rowIndex row index
     * @return <code>true</code> if the row was linked,
     *         <code>false</code> if it was already linked
     * @throws SQLException if the INSERT fails
     */
    boolean linkRow(int rowIndex) throws SQLException {
        if (isLinkedRow(rowIndex)) {
            return false;
        }
        executeInsert(rowIndex);
        return unlinkedRows.remove(rowIndex);
    }

    /**
     * Removes a linked row: DELETEs it from the database, then from the model.
     * @param rowIndex row index
     * @return <code>true</code> if the row was removed,
     *         <code>false</code> if the row is not linked
     * @throws SQLException if the DELETE fails
     */
    boolean removeLinkedRow(int rowIndex) throws SQLException {
        if (!isLinkedRow(rowIndex)) {
            return false;
        }
        executeDelete(rowIndex);
        removeRow(rowIndex);
        return true;
    }

    /**
     * Checks whether this table can be written back to the database.
     * @return whether this table is updatable
     */
    boolean isUpdatable() {
        return updatable;
    }

    /**
     * Checks whether a row is linked (exists in the database).
     * @param rowIndex row index
     * @return whether the row is linked
     */
    boolean isLinkedRow(int rowIndex) {
        return !unlinkedRows.contains(rowIndex);
    }

    /**
     * Checks whether the given connection is the one held by this model.
     * @param conn the connection to compare
     * @return whether it is the same instance
     */
    boolean isSameConnection(Connection conn) {
        return conn == this.conn;
    }

    /**
     * Returns the command that produced this model.
     * @return the command
     */
    String getCommandString() {
        return commandString;
    }

    /**
     * Finds the index of a column by identifier: exact match first, then a
     * case-insensitive match (database metadata may report names in another case).
     * @param name column identifier
     * @return the column index, or -1 if not found
     */
    private int findColumnIndex(Object name) {
        final int exact = columnIdentifiers.indexOf(name);
        if (exact >= 0) {
            return exact;
        }
        final String s = String.valueOf(name);
        for (int i = 0, n = columnIdentifiers.size(); i < n; i++) {
            if (s.equalsIgnoreCase(String.valueOf(columnIdentifiers.get(i)))) {
                return i;
            }
        }
        return -1;
    }

    /**
     * Reflects a cell edit to the database (UPDATE).
     * @param rowIndex row index (its current values supply the primary key)
     * @param columnIndex index of the edited column
     * @param newValue the new value
     * @throws SQLException on SQL error
     */
    private void executeUpdate(int rowIndex, int columnIndex, Object newValue) throws SQLException {
        final String sql = String.format("UPDATE %s SET %s=? WHERE %s",
                                         tableName,
                                         quoteIfNeeds(columnIdentifiers.get(columnIndex)),
                                         getPrimaryKeyClauseString());
        final int keyCount = primaryKeyColumns.length;
        Object[] values = new Object[keyCount + 1];
        int[] columns = new int[keyCount + 1];
        values[0] = newValue;
        columns[0] = columnIndex;
        for (int i = 0; i < keyCount; i++) {
            values[i + 1] = getValueAt(rowIndex, primaryKeyColumns[i]);
            columns[i + 1] = primaryKeyColumns[i];
        }
        executeSql(sql, values, columns);
    }

    /**
     * Reflects a new row to the database (INSERT).
     * @param rowIndex row index
     * @throws SQLException on SQL error
     */
    private void executeInsert(int rowIndex) throws SQLException {
        // keyed by identifier: a repeated column name collapses to its last occurrence
        Map<Object, Integer> columnMap = new LinkedHashMap<Object, Integer>();
        for (int i = 0, n = columnIdentifiers.size(); i < n; i++) {
            columnMap.put(columnIdentifiers.get(i), i);
        }
        final int count = columnMap.size();
        List<String> columnNames = new ArrayList<String>(count);
        Object[] values = new Object[count];
        int[] columns = new int[count];
        int k = 0;
        for (Map.Entry<Object, Integer> entry : columnMap.entrySet()) {
            final int column = entry.getValue();
            // quoted the same way as the SET target of UPDATE
            columnNames.add(quoteIfNeeds(entry.getKey()));
            values[k] = getValueAt(rowIndex, column);
            columns[k] = column;
            ++k;
        }
        final String sql = String.format("INSERT INTO %s (%s) VALUES (%s)",
                                         tableName,
                                         join(columnNames, ", "),
                                         join(nCopies(count, "?"), ","));
        executeSql(sql, values, columns);
    }

    /**
     * Reflects a row deletion to the database (DELETE).
     * @param rowIndex row index (its current values supply the primary key)
     * @throws SQLException on SQL error
     */
    private void executeDelete(int rowIndex) throws SQLException {
        final String sql = String.format("DELETE FROM %s WHERE %s",
                                         tableName,
                                         getPrimaryKeyClauseString());
        Object[] values = new Object[primaryKeyColumns.length];
        for (int i = 0; i < primaryKeyColumns.length; i++) {
            values[i] = getValueAt(rowIndex, primaryKeyColumns[i]);
        }
        executeSql(sql, values, primaryKeyColumns);
    }

    /**
     * Builds the primary-key condition: <code>PK1=? AND PK2=? ...</code>.
     * @return the primary-key condition string
     */
    private String getPrimaryKeyClauseString() {
        return join(map(asList(primaryKeys), new Iteration.Correspondence<String, String>() {

            public String f(String s) {
                return String.format("%s=?", quoteIfNeeds(s));
            }

        }), " AND ");
    }

    /**
     * Executes a DML statement that must affect exactly one row.
     * @param sql SQL
     * @param parameters bind parameters
     * @param columnIndexes for each parameter, the index of the column it belongs to
     *        (used to choose the SQL type when binding NULL)
     * @throws SQLException on SQL error
     * @throws IllegalStateException if the update count is not 1
     */
    private void executeSql(String sql, Object[] parameters, int[] columnIndexes) throws SQLException {
        if (log.isDebugEnabled()) {
            log.debug("SQL: " + sql);
            log.debug("parameters: " + Arrays.asList(parameters));
        }
        PreparedStatement stmt = conn.prepareStatement(sql);
        try {
            ValueTransporter transfer = ValueTransporter.getInstance("");
            for (int i = 0; i < parameters.length; i++) {
                final Object o = parameters[i];
                final int column = columnIndexes[i];
                final int parameterIndex = i + 1;
                // a blank value in a non-string column is bound as NULL of that column's type
                final boolean blank = (o == null || String.valueOf(o).length() == 0);
                if (blank && getColumnClass(column) != String.class) {
                    stmt.setNull(parameterIndex, types[column]);
                } else {
                    transfer.setObject(stmt, parameterIndex, o);
                }
            }
            final int updatedCount = stmt.executeUpdate();
            if (updatedCount != 1) {
                throw new IllegalStateException("updated count = " + updatedCount);
            }
        } finally {
            stmt.close();
        }
    }

    /**
     * Encloses a symbol in double quotes if it contains a non-word character.
     * @param o symbol
     * @return the (possibly) quoted string
     */
    static String quoteIfNeeds(Object o) {
        final String s = String.valueOf(o);
        if (s.matches(".*\\W.*")) {
            return String.format("\"%s\"", s);
        }
        return s;
    }

    /**
     * Analyzes whether rows can be linked to the database, and if so, records
     * the connection, table name and primary-key columns and marks the model updatable.
     * @param rs result set
     * @param cmd command
     * @throws SQLException on SQL error
     */
    private void analyzeForLinking(ResultSet rs, String cmd) throws SQLException {
        if (rs == null) {
            return;
        }
        Statement stmt = rs.getStatement();
        if (stmt == null) {
            return;
        }
        Connection connection = stmt.getConnection();
        if (connection == null) {
            return;
        }
        if (connection.isReadOnly()) {
            return;
        }
        String table = findTableName(cmd);
        if (table.length() == 0) {
            return;
        }
        // cheap textual check before querying metadata
        if (findUnion(cmd)) {
            return;
        }
        List<String> pkList = findPrimaryKeys(connection, table);
        if (pkList.isEmpty()) {
            return;
        }
        // every key column must be in the result, otherwise rows cannot be identified
        int[] pkColumns = new int[pkList.size()];
        for (int i = 0; i < pkColumns.length; i++) {
            final int column = findColumnIndex(pkList.get(i));
            if (column < 0) {
                if (log.isDebugEnabled()) {
                    log.debug("primary key column not found: " + pkList.get(i));
                }
                return;
            }
            pkColumns[i] = column;
        }
        this.conn = connection;
        this.tableName = table;
        this.primaryKeys = pkList.toArray(new String[pkList.size()]);
        this.primaryKeyColumns = pkColumns;
        this.updatable = true;
    }

    /**
     * Finds the table name.
     * Returns the table name only when the SELECT targets a single table.
     * @param cmd command
     * @return the table name, or an empty string otherwise
     */
    private static String findTableName(String cmd) {
        if (cmd == null) {
            return EMPTY_STRING;
        }
        Matcher m = SELECT_FROM_PATTERN.matcher(cmd);
        if (!m.find()) {
            return EMPTY_STRING;
        }
        // split on whitespace runs so that "a ,b" or "a  b" yields no empty tokens
        String[] words = m.group(1).trim().split("\\s+");
        // a comma in the first two words means a multi-table FROM list
        for (int i = 0; i < 2 && i < words.length; i++) {
            if (words[i].indexOf(',') >= 0) {
                return EMPTY_STRING;
            }
        }
        final String word = words[0];
        if (word.matches("[A-Za-z0-9_\\.]+")) {
            return word;
        }
        return EMPTY_STRING;
    }

    /**
     * Finds the primary keys.
     * Returns them only when the table resolves to exactly one schema.
     * @param conn connection
     * @param tableName table name, optionally qualified (<code>schema.table</code>)
     * @return list of primary-key column names, or an empty list otherwise
     * @throws SQLException on SQL error
     */
    private static List<String> findPrimaryKeys(Connection conn, String tableName) throws SQLException {
        // build the search conditions in the case the database stores identifiers in
        DatabaseMetaData dbmeta = conn.getMetaData();
        String schema = dbmeta.getUserName();
        if (schema == null) {
            schema = EMPTY_STRING;
        }
        String schemaCondition;
        String tableNameCondition;
        if (dbmeta.storesLowerCaseIdentifiers()) {
            schemaCondition = schema.toLowerCase();
            tableNameCondition = tableName.toLowerCase();
        } else if (dbmeta.storesUpperCaseIdentifiers()) {
            schemaCondition = schema.toUpperCase();
            tableNameCondition = tableName.toUpperCase();
        } else {
            schemaCondition = schema;
            tableNameCondition = tableName;
        }
        if (tableNameCondition.indexOf('.') >= 0) {
            String[] parts = tableNameCondition.split("\\.");
            if (parts.length < 2) {
                // e.g. "a." - cannot be resolved to schema and table
                return new ArrayList<String>();
            }
            // the last two parts are schema and table (a leading catalog is ignored)
            schemaCondition = parts[parts.length - 2];
            tableNameCondition = parts[parts.length - 1];
        }
        // search with the schema first, then without it
        List<String> pkList = getPrimaryKeys(dbmeta, schemaCondition, tableNameCondition);
        if (pkList.isEmpty()) {
            return getPrimaryKeys(dbmeta, null, tableNameCondition);
        }
        return pkList;
    }

    /**
     * Retrieves the primary keys.
     * @param dbmeta DatabaseMetaData
     * @param schema schema (<code>null</code> = not used as a condition)
     * @param table table name
     * @return list of primary-key column names; empty when none are found or
     *         when the table matches in more than one schema
     * @throws SQLException on SQL error
     */
    private static List<String> getPrimaryKeys(DatabaseMetaData dbmeta, String schema, String table) throws SQLException {
        ResultSet rs = dbmeta.getPrimaryKeys(null, schema, table);
        try {
            List<String> pkList = new ArrayList<String>();
            Set<String> schemaSet = new HashSet<String>();
            while (rs.next()) {
                pkList.add(rs.getString(4)); // COLUMN_NAME
                schemaSet.add(rs.getString(2)); // TABLE_SCHEM
            }
            if (schemaSet.size() != 1) {
                pkList.clear();
            }
            return pkList;
        } finally {
            rs.close();
        }
    }

    /**
     * Searches for the UNION keyword outside string literals.
     * @param sql SQL
     * @return <code>true</code> if the UNION keyword is found,
     *         <code>false</code> otherwise
     */
    private static boolean findUnion(String sql) {
        String s = sql;
        if (s.indexOf('\'') >= 0) {
            // drop backslash-escaped quotes (a literal replace, not a regex), then blank out literals
            s = s.replace("\\'", "");
            s = s.replaceAll("'[^']*'", "''");
        }
        // parentheses are delimiters too, so "(SELECT ...)UNION(SELECT ...)" is detected
        StringTokenizer tokenizer = new StringTokenizer(s, " \t\n\r\f()");
        while (tokenizer.hasMoreTokens()) {
            String token = tokenizer.nextToken();
            if (token.equalsIgnoreCase("UNION")) {
                return true;
            }
        }
        return false;
    }

}
