001    /**
002     *  Licensed to the Apache Software Foundation (ASF) under one or more
003     *  contributor license agreements.  See the NOTICE file distributed with
004     *  this work for additional information regarding copyright ownership.
005     *  The ASF licenses this file to You under the Apache License, Version 2.0
006     *  (the "License"); you may not use this file except in compliance with
007     *  the License.  You may obtain a copy of the License at
008     *
009     *     http://www.apache.org/licenses/LICENSE-2.0
010     *
011     *  Unless required by applicable law or agreed to in writing, software
012     *  distributed under the License is distributed on an "AS IS" BASIS,
013     *  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
014     *  See the License for the specific language governing permissions and
015     *  limitations under the License.
016     */
017    
018    package org.apache.geronimo.security.network.protocol;
019    
020    import org.activeio.AsyncChannel;
021    import org.activeio.FilterAsyncChannel;
022    import org.activeio.Packet;
023    import org.activeio.adapter.PacketOutputStream;
024    import org.activeio.adapter.PacketToInputStream;
025    import org.activeio.packet.AppendedPacket;
026    import org.activeio.packet.ByteArrayPacket;
027    import org.activeio.packet.FilterPacket;
028    import org.apache.geronimo.security.ContextManager;
029    import org.apache.geronimo.security.IdentificationPrincipal;
030    import org.apache.geronimo.security.SubjectId;
031    
032    import javax.security.auth.Subject;
033    import java.io.DataInputStream;
034    import java.io.DataOutputStream;
035    import java.io.IOException;
036    import java.security.AccessController;
037    import java.util.Collection;
038    
039    /**
040     * SubjectCarryingChannel is a FilterAsynchChannel that allows you to send
041     * the subject associated with the current write operation down to the remote
042     * end of the channel.
043     *
044     * @version $Rev: 487175 $ $Date: 2006-12-14 03:10:31 -0800 (Thu, 14 Dec 2006) $
045     */
046    public class SubjectCarryingChannel extends FilterAsyncChannel {
047    
048        static final byte PASSTHROUGH = (byte) 0x00;
049        static final byte SET_SUBJECT = (byte) 0x01;
050        static final byte CLEAR_SUBJECT = (byte) 0x2;
051    
052        final private ByteArrayPacket header = new ByteArrayPacket(new byte[1 + 8 + 4]);
053    
054        private Subject remoteSubject;
055        private Subject localSubject;
056    
057        private final boolean enableLocalSubjectPublishing;
058        private final boolean enableRemoteSubjectConsumption;
059    
060        public SubjectCarryingChannel(AsyncChannel next) {
061            this(next, true, true);
062        }
063    
064        public SubjectCarryingChannel(AsyncChannel next, boolean enableLocalSubjectPublishing, boolean enableRemoteSubjectConsumption) {
065            super(next);
066            this.enableLocalSubjectPublishing = enableLocalSubjectPublishing;
067            this.enableRemoteSubjectConsumption = enableRemoteSubjectConsumption;
068        }
069    
070        public void write(Packet packet) throws IOException {
071    
072            // Don't add anything to the packet stream if subject writing is not enabled.
073            if (!enableLocalSubjectPublishing) {
074                super.write(packet);
075                return;
076            }
077    
078            Subject subject = Subject.getSubject(AccessController.getContext());
079            if (remoteSubject != subject) {
080                remoteSubject = subject;
081                Collection principals = remoteSubject.getPrincipals(IdentificationPrincipal.class);
082    
083                if (principals.isEmpty()) {
084                    super.write(createClearSubjectPackt());
085                } else {
086                    IdentificationPrincipal principal = (IdentificationPrincipal) principals.iterator().next();
087                    SubjectId subjectId = principal.getId();
088                    super.write(createSubjectPacket(subjectId.getSubjectId(), subjectId.getHash()));
089                }
090    
091            }
092            super.write(createPassthroughPacket(packet));
093        }
094    
095        public class SubjectPacketFilter extends FilterPacket {
096    
097            SubjectPacketFilter(Packet packet) {
098                super(packet);
099            }
100    
101            public Object narrow(Class target) {
102                if (target == SubjectContext.class) {
103                    return new SubjectContext() {
104                        public Subject getSubject() {
105                            return remoteSubject;
106                        }
107                    };
108                }
109                return super.narrow(target);
110            }
111    
112            public Packet filter(Packet packet) {
113                return new SubjectPacketFilter(packet);
114            }
115    
116        }
117    
118        public void onPacket(Packet packet) {
119    
120            // Don't take anything to the packet stream if subject reading is not enabled.
121            if (!enableRemoteSubjectConsumption) {
122                super.onPacket(packet);
123                return;
124            }
125    
126            try {
127                switch (packet.read()) {
128                    case CLEAR_SUBJECT:
129                        localSubject = null;
130                        return;
131                    case SET_SUBJECT:
132                        SubjectId subjectId = extractSubjectId(packet);
133                        localSubject = ContextManager.getRegisteredSubject(subjectId);
134                        return;
135                    case PASSTHROUGH:
136                        super.onPacket(new SubjectPacketFilter(packet));
137                }
138            } catch (IOException e) {
139                super.onPacketError(e);
140            }
141    
142            super.onPacket(packet);
143        }
144    
145        /**
146         */
147        private SubjectId extractSubjectId(Packet packet) throws IOException {
148            DataInputStream is = new DataInputStream(new PacketToInputStream(packet));
149            Long id = new Long(is.readLong());
150            byte hash[] = new byte[is.readInt()];
151            return new SubjectId(id, hash);
152        }
153    
154        private Packet createClearSubjectPackt() {
155            header.clear();
156            header.write(CLEAR_SUBJECT);
157            header.flip();
158            return header;
159        }
160    
161        private Packet createSubjectPacket(Long subjectId, byte[] hash) throws IOException {
162            header.clear();
163            DataOutputStream os = new DataOutputStream(new PacketOutputStream(header));
164            os.writeByte(SET_SUBJECT);
165            os.writeLong(subjectId.longValue());
166            os.writeInt(hash.length);
167            os.close();
168            header.flip();
169            return AppendedPacket.join(header, new ByteArrayPacket(hash));
170        }
171    
172        private Packet createPassthroughPacket(Packet packet) {
173            header.clear();
174            header.write(PASSTHROUGH);
175            header.flip();
176            return AppendedPacket.join(header, packet);
177        }
178    
179        public Subject getLocalSubject() {
180            return localSubject;
181        }
182    
183        public Subject getRemoteSubject() {
184            return remoteSubject;
185        }
186    
187    }