001    /**
002     *  Licensed to the Apache Software Foundation (ASF) under one or more
003     *  contributor license agreements.  See the NOTICE file distributed with
004     *  this work for additional information regarding copyright ownership.
005     *  The ASF licenses this file to You under the Apache License, Version 2.0
006     *  (the "License"); you may not use this file except in compliance with
007     *  the License.  You may obtain a copy of the License at
008     *
009     *     http://www.apache.org/licenses/LICENSE-2.0
010     *
011     *  Unless required by applicable law or agreed to in writing, software
012     *  distributed under the License is distributed on an "AS IS" BASIS,
013     *  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
014     *  See the License for the specific language governing permissions and
015     *  limitations under the License.
016     */
017    
018    package org.apache.geronimo.security.realm.providers;
019    
020    import java.io.IOException;
021    import java.security.MessageDigest;
022    import java.security.NoSuchAlgorithmException;
023    import java.sql.Connection;
024    import java.sql.Driver;
025    import java.sql.PreparedStatement;
026    import java.sql.ResultSet;
027    import java.sql.SQLException;
028    import java.util.HashSet;
029    import java.util.Iterator;
030    import java.util.Map;
031    import java.util.Properties;
032    import java.util.Set;
033    import javax.security.auth.Subject;
034    import javax.security.auth.callback.Callback;
035    import javax.security.auth.callback.CallbackHandler;
036    import javax.security.auth.callback.NameCallback;
037    import javax.security.auth.callback.PasswordCallback;
038    import javax.security.auth.callback.UnsupportedCallbackException;
039    import javax.security.auth.login.FailedLoginException;
040    import javax.security.auth.login.LoginException;
041    import javax.security.auth.spi.LoginModule;
042    import javax.sql.DataSource;
043    
044    import org.apache.commons.logging.Log;
045    import org.apache.commons.logging.LogFactory;
046    import org.apache.geronimo.gbean.AbstractName;
047    import org.apache.geronimo.gbean.AbstractNameQuery;
048    import org.apache.geronimo.j2ee.j2eeobjectnames.NameFactory;
049    import org.apache.geronimo.kernel.GBeanNotFoundException;
050    import org.apache.geronimo.kernel.Kernel;
051    import org.apache.geronimo.kernel.KernelRegistry;
052    import org.apache.geronimo.management.geronimo.JCAManagedConnectionFactory;
053    import org.apache.geronimo.security.jaas.JaasLoginModuleUse;
054    import org.apache.geronimo.util.encoders.HexTranslator;
055    
056    
057    /**
058     * A login module that loads security information from a SQL database.  Expects
059     * to be run by a GenericSecurityRealm (doesn't work on its own).
060     * <p>
061     * This requires database connectivity information (either 1: a dataSourceName and
062     * optional dataSourceApplication or 2: a JDBC driver, URL, username, and password)
063     * and 2 SQL queries.
064     * <p>
065     * The userSelect query should return 2 values, the username and the password in
066     * that order.  It should include one PreparedStatement parameter (a ?) which
067     * will be filled in with the username.  In other words, the query should look
068     * like: <tt>SELECT user, password FROM users WHERE username=?</tt>
069     * <p>
070     * The groupSelect query should return 2 values, the username and the group name in
071     * that order (but it may return multiple rows, one per group).  It should include
072     * one PreparedStatement parameter (a ?) which will be filled in with the username.
073     * In other words, the query should look like:
074     * <tt>SELECT user, role FROM user_roles WHERE username=?</tt>
075     *
076     * @version $Rev: 487175 $ $Date: 2006-12-14 03:10:31 -0800 (Thu, 14 Dec 2006) $
077     */
078    public class SQLLoginModule implements LoginModule {
079        private static Log log = LogFactory.getLog(SQLLoginModule.class);
080        public final static String USER_SELECT = "userSelect";
081        public final static String GROUP_SELECT = "groupSelect";
082        public final static String CONNECTION_URL = "jdbcURL";
083        public final static String USER = "jdbcUser";
084        public final static String PASSWORD = "jdbcPassword";
085        public final static String DRIVER = "jdbcDriver";
086        public final static String DATABASE_POOL_NAME = "dataSourceName";
087        public final static String DATABASE_POOL_APP_NAME = "dataSourceApplication";
088        public final static String DIGEST = "digest";
089        private String connectionURL;
090        private Properties properties;
091        private Driver driver;
092        private JCAManagedConnectionFactory factory;
093        private String userSelect;
094        private String groupSelect;
095        private String digest;
096    
097        private Subject subject;
098        private CallbackHandler handler;
099        private String cbUsername;
100        private String cbPassword;
101        private final Set groups = new HashSet();
102    
103        public void initialize(Subject subject, CallbackHandler callbackHandler, Map sharedState, Map options) {
104            this.subject = subject;
105            this.handler = callbackHandler;
106            userSelect = (String) options.get(USER_SELECT);
107            groupSelect = (String) options.get(GROUP_SELECT);
108    
109            digest = (String) options.get(DIGEST);
110            if(digest != null && !digest.equals("")) {
111                // Check if the digest algorithm is available
112                try {
113                    MessageDigest.getInstance(digest);
114                } catch(NoSuchAlgorithmException e) {
115                    log.error("Initialization failed. Digest algorithm "+digest+" is not available.", e);
116                    throw new IllegalArgumentException("Unable to configure SQL login module: "+e.getMessage());
117                }
118            }
119    
120            String dataSourceName = (String) options.get(DATABASE_POOL_NAME);
121            if(dataSourceName != null) {
122                dataSourceName = dataSourceName.trim();
123                String dataSourceAppName = (String) options.get(DATABASE_POOL_APP_NAME);
124                if(dataSourceAppName == null || dataSourceAppName.trim().equals("")) {
125                    dataSourceAppName = "null";
126                } else {
127                    dataSourceAppName = dataSourceAppName.trim();
128                }
129                String kernelName = (String) options.get(JaasLoginModuleUse.KERNEL_NAME_LM_OPTION);
130                Kernel kernel = KernelRegistry.getKernel(kernelName);
131                Set set = kernel.listGBeans(new AbstractNameQuery(JCAManagedConnectionFactory.class.getName()));
132                JCAManagedConnectionFactory factory;
133                for (Iterator it = set.iterator(); it.hasNext();) {
134                    AbstractName name = (AbstractName) it.next();
135                    if(name.getName().get(NameFactory.J2EE_APPLICATION).equals(dataSourceAppName) &&
136                        name.getName().get(NameFactory.J2EE_NAME).equals(dataSourceName)) {
137                        try {
138                            factory = (JCAManagedConnectionFactory) kernel.getGBean(name);
139                            String type = factory.getConnectionFactoryInterface();
140                            if(type.equals(DataSource.class.getName())) {
141                                this.factory = factory;
142                                break;
143                            }
144                        } catch (GBeanNotFoundException e) {
145                            // ignore... GBean was unregistered
146                        }
147                    }
148                }
149            } else {
150                connectionURL = (String) options.get(CONNECTION_URL);
151                properties = new Properties();
152                if(options.get(USER) != null) {
153                    properties.put("user", options.get(USER));
154                }
155                if(options.get(PASSWORD) != null) {
156                    properties.put("password", options.get(PASSWORD));
157                }
158                ClassLoader cl = (ClassLoader) options.get(JaasLoginModuleUse.CLASSLOADER_LM_OPTION);
159                try {
160                    driver = (Driver) cl.loadClass((String) options.get(DRIVER)).newInstance();
161                } catch (ClassNotFoundException e) {
162                    throw new IllegalArgumentException("Driver class " + options.get(DRIVER) + " is not available.  Perhaps you need to add it as a dependency in your deployment plan?");
163                } catch (Exception e) {
164                    throw new IllegalArgumentException("Unable to load, instantiate, register driver " + options.get(DRIVER) + ": " + e.getMessage());
165                }
166            }
167        }
168    
169        public boolean login() throws LoginException {
170            Callback[] callbacks = new Callback[2];
171    
172            callbacks[0] = new NameCallback("User name");
173            callbacks[1] = new PasswordCallback("Password", false);
174            try {
175                handler.handle(callbacks);
176            } catch (IOException ioe) {
177                throw (LoginException) new LoginException().initCause(ioe);
178            } catch (UnsupportedCallbackException uce) {
179                throw (LoginException) new LoginException().initCause(uce);
180            }
181            assert callbacks.length == 2;
182            cbUsername = ((NameCallback) callbacks[0]).getName();
183            if (cbUsername == null || cbUsername.equals("")) {
184                return false;
185            }
186            char[] provided = ((PasswordCallback) callbacks[1]).getPassword();
187            cbPassword = provided == null ? null : new String(provided);
188    
189            boolean found = false;
190            try {
191                Connection conn;
192                if(factory != null) {
193                    DataSource ds = (DataSource) factory.getConnectionFactory();
194                    conn = ds.getConnection();
195                } else {
196                    conn = driver.connect(connectionURL, properties);
197                }
198    
199                try {
200                    PreparedStatement statement = conn.prepareStatement(userSelect);
201                    try {
202                        int count = countParameters(userSelect);
203                        for(int i=0; i<count; i++) {
204                            statement.setObject(i+1, cbUsername);
205                        }
206                        ResultSet result = statement.executeQuery();
207    
208                        try {
209                            while (result.next()) {
210                                String userName = result.getString(1);
211                                String userPassword = result.getString(2);
212    
213                                if (cbUsername.equals(userName)) {
214                                    found = (cbPassword == null && userPassword == null) ||
215                                            (cbPassword != null && userPassword != null && checkPassword(userPassword, cbPassword));
216                                    break;
217                                }
218                            }
219                        } finally {
220                            result.close();
221                        }
222                    } finally {
223                        statement.close();
224                    }
225    
226                    if (!found) {
227                        throw new FailedLoginException();
228                    }
229    
230                    statement = conn.prepareStatement(groupSelect);
231                    try {
232                        int count = countParameters(groupSelect);
233                        for(int i=0; i<count; i++) {
234                            statement.setObject(i+1, cbUsername);
235                        }
236                        ResultSet result = statement.executeQuery();
237    
238                        try {
239                            while (result.next()) {
240                                String userName = result.getString(1);
241                                String groupName = result.getString(2);
242    
243                                if (cbUsername.equals(userName)) {
244                                    groups.add(new GeronimoGroupPrincipal(groupName));
245                                }
246                            }
247                        } finally {
248                            result.close();
249                        }
250                    } finally {
251                        statement.close();
252                    }
253                } finally {
254                    conn.close();
255                }
256            } catch (SQLException sqle) {
257                throw (LoginException) new LoginException("SQL error").initCause(sqle);
258            }
259    
260            return true;
261        }
262    
263        public boolean commit() throws LoginException {
264            Set principals = subject.getPrincipals();
265            principals.add(new GeronimoUserPrincipal(cbUsername));
266            Iterator iter = groups.iterator();
267            while (iter.hasNext()) {
268                principals.add(iter.next());
269            }
270    
271            return true;
272        }
273    
274        public boolean abort() throws LoginException {
275            cbUsername = null;
276            cbPassword = null;
277    
278            return true;
279        }
280    
281        public boolean logout() throws LoginException {
282            cbUsername = null;
283            cbPassword = null;
284            //todo: should remove principals put in by commit
285            return true;
286        }
287    
288        private static int countParameters(String sql) {
289            int count = 0;
290            int pos = -1;
291            while((pos = sql.indexOf('?', pos+1)) != -1) {
292                ++count;
293            }
294            return count;
295        }
296    
297        /**
298         * This method checks if the provided password is correct.  The original password may have been digested.
299         * @param real      Original password in digested form if applicable
300         * @param provided  User provided password in clear text
301         * @return true     If the password is correct
302         */
303        private boolean checkPassword(String real, String provided){
304            if(digest == null || digest.equals("")) {
305                // No digest algorithm is used
306                return real.equals(provided);
307            }
308            try {
309                // Digest the user provided password
310                MessageDigest md = MessageDigest.getInstance(digest);
311                byte[] data = md.digest(provided.getBytes());
312                // Convert bytes to hex digits
313                byte[] hexData = new byte[data.length * 2];
314                HexTranslator ht = new HexTranslator();
315                ht.encode(data, 0, data.length, hexData, 0);
316                // Compare the digested provided password with the actual one
317                return real.equalsIgnoreCase(new String(hexData));
318            } catch (NoSuchAlgorithmException e) {
319                // Should not occur.  Availability of algorithm has been checked at initialization
320                log.error("Should not occur.  Availability of algorithm has been checked at initialization.", e);
321            }
322            return false;
323        }
324    }