001/** 002 * Licensed to the Apache Software Foundation (ASF) under one or more 003 * contributor license agreements. See the NOTICE file distributed with 004 * this work for additional information regarding copyright ownership. 005 * The ASF licenses this file to You under the Apache License, Version 2.0 006 * (the "License"); you may not use this file except in compliance with 007 * the License. You may obtain a copy of the License at 008 * 009 * http://www.apache.org/licenses/LICENSE-2.0 010 * 011 * Unless required by applicable law or agreed to in writing, software 012 * distributed under the License is distributed on an "AS IS" BASIS, 013 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 014 * See the License for the specific language governing permissions and 015 * limitations under the License. 016 */ 017package org.apache.activemq.util; 018 019import java.io.IOException; 020import java.io.InputStream; 021import java.io.ObjectInputStream; 022import java.io.ObjectStreamClass; 023import java.lang.reflect.Proxy; 024import java.util.*; 025 026import org.slf4j.Logger; 027import org.slf4j.LoggerFactory; 028 029public class ClassLoadingAwareObjectInputStream extends ObjectInputStream { 030 031 private static final Logger LOG = LoggerFactory.getLogger(ClassLoadingAwareObjectInputStream.class); 032 private static final ClassLoader FALLBACK_CLASS_LOADER = 033 ClassLoadingAwareObjectInputStream.class.getClassLoader(); 034 035 public static final String[] serializablePackages; 036 037 private List<String> trustedPackages = new ArrayList<String>(); 038 private boolean trustAllPackages = false; 039 040 private final ClassLoader inLoader; 041 042 static { 043 serializablePackages = System.getProperty("org.apache.activemq.SERIALIZABLE_PACKAGES","java.lang,org.apache.activemq,org.fusesource.hawtbuf,com.thoughtworks.xstream.mapper").split(","); 044 } 045 046 public ClassLoadingAwareObjectInputStream(InputStream in) throws IOException { 047 super(in); 048 inLoader = in.getClass().getClassLoader(); 049 trustedPackages.addAll(Arrays.asList(serializablePackages)); 050 } 051 052 @Override 053 protected Class<?> resolveClass(ObjectStreamClass classDesc) throws IOException, ClassNotFoundException { 054 ClassLoader cl = Thread.currentThread().getContextClassLoader(); 055 Class clazz = load(classDesc.getName(), cl, inLoader); 056 checkSecurity(clazz); 057 return clazz; 058 } 059 060 @Override 061 protected Class<?> resolveProxyClass(String[] interfaces) throws IOException, ClassNotFoundException { 062 ClassLoader cl = Thread.currentThread().getContextClassLoader(); 063 Class[] cinterfaces = new Class[interfaces.length]; 064 for (int i = 0; i < interfaces.length; i++) { 065 cinterfaces[i] = load(interfaces[i], cl); 066 } 067 068 Class clazz = null; 069 try { 070 clazz = Proxy.getProxyClass(cl, cinterfaces); 071 } catch (IllegalArgumentException e) { 072 try { 073 clazz = Proxy.getProxyClass(inLoader, cinterfaces); 074 } catch (IllegalArgumentException e1) { 075 // ignore 076 } 077 try { 078 clazz = Proxy.getProxyClass(FALLBACK_CLASS_LOADER, cinterfaces); 079 } catch (IllegalArgumentException e2) { 080 // ignore 081 } 082 } 083 084 if (clazz != null) { 085 checkSecurity(clazz); 086 return clazz; 087 } else { 088 throw new ClassNotFoundException(null); 089 } 090 } 091 092 public static boolean isAllAllowed() { 093 return serializablePackages.length == 1 && serializablePackages[0].equals("*"); 094 } 095 096 private boolean trustAllPackages() { 097 return trustAllPackages || (trustedPackages.size() == 1 && trustedPackages.get(0).equals("*")); 098 } 099 100 private void checkSecurity(Class clazz) throws ClassNotFoundException { 101 if (trustAllPackages() || clazz.isPrimitive()) { 102 return; 103 } 104 105 boolean found = false; 106 Package thePackage = clazz.getPackage(); 107 if (thePackage != null) { 108 for (String trustedPackage : getTrustedPackages()) { 109 if (thePackage.getName().equals(trustedPackage) || thePackage.getName().startsWith(trustedPackage + ".")) { 110 found = true; 111 break; 112 } 113 } 114 if (!found) { 115 throw new ClassNotFoundException("Forbidden " + clazz + "! This class is not trusted to be serialized as ObjectMessage payload. Please take a look at http://activemq.apache.org/objectmessage.html for more information on how to configure trusted classes."); 116 } 117 } 118 } 119 120 private Class<?> load(String className, ClassLoader... cl) throws ClassNotFoundException { 121 // check for simple types first 122 final Class<?> clazz = loadSimpleType(className); 123 if (clazz != null) { 124 LOG.trace("Loaded class: {} as simple type -> {}", className, clazz); 125 return clazz; 126 } 127 128 // try the different class loaders 129 for (ClassLoader loader : cl) { 130 LOG.trace("Attempting to load class: {} using classloader: {}", className, cl); 131 try { 132 Class<?> answer = Class.forName(className, false, loader); 133 if (LOG.isTraceEnabled()) { 134 LOG.trace("Loaded class: {} using classloader: {} -> {}", className, cl, answer); 135 } 136 return answer; 137 } catch (ClassNotFoundException e) { 138 LOG.trace("Class not found: {} using classloader: {}", className, cl); 139 // ignore 140 } 141 } 142 143 // and then the fallback class loader 144 return Class.forName(className, false, FALLBACK_CLASS_LOADER); 145 } 146 147 /** 148 * Load a simple type 149 * 150 * @param name the name of the class to load 151 * @return the class or <tt>null</tt> if it could not be loaded 152 */ 153 public static Class<?> loadSimpleType(String name) { 154 // code from ObjectHelper.loadSimpleType in Apache Camel 155 156 // special for byte[] or Object[] as its common to use 157 if ("java.lang.byte[]".equals(name) || "byte[]".equals(name)) { 158 return byte[].class; 159 } else if ("java.lang.Byte[]".equals(name) || "Byte[]".equals(name)) { 160 return Byte[].class; 161 } else if ("java.lang.Object[]".equals(name) || "Object[]".equals(name)) { 162 return Object[].class; 163 } else if ("java.lang.String[]".equals(name) || "String[]".equals(name)) { 164 return String[].class; 165 // and these is common as well 166 } else if ("java.lang.String".equals(name) || "String".equals(name)) { 167 return String.class; 168 } else if ("java.lang.Boolean".equals(name) || "Boolean".equals(name)) { 169 return Boolean.class; 170 } else if ("boolean".equals(name)) { 171 return boolean.class; 172 } else if ("java.lang.Integer".equals(name) || "Integer".equals(name)) { 173 return Integer.class; 174 } else if ("int".equals(name)) { 175 return int.class; 176 } else if ("java.lang.Long".equals(name) || "Long".equals(name)) { 177 return Long.class; 178 } else if ("long".equals(name)) { 179 return long.class; 180 } else if ("java.lang.Short".equals(name) || "Short".equals(name)) { 181 return Short.class; 182 } else if ("short".equals(name)) { 183 return short.class; 184 } else if ("java.lang.Byte".equals(name) || "Byte".equals(name)) { 185 return Byte.class; 186 } else if ("byte".equals(name)) { 187 return byte.class; 188 } else if ("java.lang.Float".equals(name) || "Float".equals(name)) { 189 return Float.class; 190 } else if ("float".equals(name)) { 191 return float.class; 192 } else if ("java.lang.Double".equals(name) || "Double".equals(name)) { 193 return Double.class; 194 } else if ("double".equals(name)) { 195 return double.class; 196 } else if ("void".equals(name)) { 197 return void.class; 198 } 199 200 return null; 201 } 202 203 public List<String> getTrustedPackages() { 204 return trustedPackages; 205 } 206 207 public void setTrustedPackages(List<String> trustedPackages) { 208 this.trustedPackages = trustedPackages; 209 } 210 211 public void addTrustedPackage(String trustedPackage) { 212 this.trustedPackages.add(trustedPackage); 213 } 214 215 public boolean isTrustAllPackages() { 216 return trustAllPackages; 217 } 218 219 public void setTrustAllPackages(boolean trustAllPackages) { 220 this.trustAllPackages = trustAllPackages; 221 } 222}