1/*
2 * Conditions Of Use
3 *
4 * This software was developed by employees of the National Institute of
5 * Standards and Technology (NIST), an agency of the Federal Government.
6 * Pursuant to title 15 United States Code Section 105, works of NIST
7 * employees are not subject to copyright protection in the United States
8 * and are considered to be in the public domain.  As a result, a formal
9 * license is not needed to use the software.
10 *
11 * This software is provided by NIST as a service and is expressly
12 * provided "AS IS."  NIST MAKES NO WARRANTY OF ANY KIND, EXPRESS, IMPLIED
13 * OR STATUTORY, INCLUDING, WITHOUT LIMITATION, THE IMPLIED WARRANTY OF
14 * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE, NON-INFRINGEMENT
15 * AND DATA ACCURACY.  NIST does not warrant or make any representations
16 * regarding the use of the software or the results thereof, including but
17 * not limited to the correctness, accuracy, reliability or usefulness of
18 * the software.
19 *
20 * Permission to use this software is contingent upon your acceptance
21 * of the terms of this agreement
22 *
23 * .
24 *
25 */
26/*******************************************************************************
27 * Product of NIST/ITL Advanced Networking Technologies Division (ANTD).       *
28 *******************************************************************************/
29package gov.nist.javax.sip.stack;
30
31import gov.nist.core.StackLogger;
32import gov.nist.javax.sip.SipStackImpl;
33
34import java.io.*;
35import java.net.*;
36import java.util.Enumeration;
37import java.util.concurrent.ConcurrentHashMap;
38import java.util.concurrent.Semaphore;
39import java.util.concurrent.TimeUnit;
40
41import javax.net.ssl.HandshakeCompletedListener;
42import javax.net.ssl.SSLSocket;
43
44/*
45 * TLS support Added by Daniel J.Martinez Manzano <dani@dif.um.es>
46 *
47 */
48
49/**
 * Low-level input/output to a socket. Caches TCP connections and takes care of re-connecting to
 * the remote party if the other end drops the connection.
52 *
53 * @version 1.2
54 *
55 * @author M. Ranganathan <br/>
56 *
57 *
58 */
59
60class IOHandler {
61
62    private Semaphore ioSemaphore = new Semaphore(1);
63
64    private SipStackImpl sipStack;
65
66    private static String TCP = "tcp";
67
68    // Added by Daniel J. Martinez Manzano <dani@dif.um.es>
69    private static String TLS = "tls";
70
71    // A cache of client sockets that can be re-used for
72    // sending tcp messages.
73    private ConcurrentHashMap<String, Socket> socketTable;
74
75    protected static String makeKey(InetAddress addr, int port) {
76        return addr.getHostAddress() + ":" + port;
77
78    }
79
80    protected IOHandler(SIPTransactionStack sipStack) {
81        this.sipStack = (SipStackImpl) sipStack;
82        this.socketTable = new ConcurrentHashMap<String, Socket>();
83
84    }
85
86    protected void putSocket(String key, Socket sock) {
87        socketTable.put(key, sock);
88
89    }
90
91    protected Socket getSocket(String key) {
92        return (Socket) socketTable.get(key);
93    }
94
95    protected void removeSocket(String key) {
96        socketTable.remove(key);
97    }
98
99    /**
100     * A private function to write things out. This needs to be synchronized as writes can occur
101     * from multiple threads. We write in chunks to allow the other side to synchronize for large
102     * sized writes.
103     */
104    private void writeChunks(OutputStream outputStream, byte[] bytes, int length)
105            throws IOException {
106        // Chunk size is 16K - this hack is for large
107        // writes over slow connections.
108        synchronized (outputStream) {
109            // outputStream.write(bytes,0,length);
110            int chunksize = 8 * 1024;
111            for (int p = 0; p < length; p += chunksize) {
112                int chunk = p + chunksize < length ? chunksize : length - p;
113                outputStream.write(bytes, p, chunk);
114            }
115        }
116        outputStream.flush();
117    }
118
119    /**
120     * Creates and binds, if necessary, a socket connected to the specified destination address
121     * and port and then returns its local address.
122     *
123     * @param dst the destination address that the socket would need to connect to.
124     * @param dstPort the port number that the connection would be established with.
125     * @param localAddress the address that we would like to bind on (null for the "any" address).
126     * @param localPort the port that we'd like our socket to bind to (0 for a random port).
127     *
128     * @return the SocketAddress that this handler would use when connecting to the specified
129     *         destination address and port.
130     *
131     * @throws IOException
132     */
133    public SocketAddress obtainLocalAddress(InetAddress dst, int dstPort,
134            InetAddress localAddress, int localPort) throws IOException {
135        String key = makeKey(dst, dstPort);
136
137        Socket clientSock = getSocket(key);
138
139        if (clientSock == null) {
140            clientSock = sipStack.getNetworkLayer().createSocket(dst, dstPort, localAddress,
141                    localPort);
142            putSocket(key, clientSock);
143        }
144
145        return clientSock.getLocalSocketAddress();
146
147    }
148
149    /**
150     * Send an array of bytes.
151     *
152     * @param receiverAddress -- inet address
153     * @param contactPort -- port to connect to.
154     * @param transport -- tcp or udp.
155     * @param retry -- retry to connect if the other end closed connection
156     * @throws IOException -- if there is an IO exception sending message.
157     */
158
159    public Socket sendBytes(InetAddress senderAddress, InetAddress receiverAddress,
160            int contactPort, String transport, byte[] bytes, boolean retry,
161            MessageChannel messageChannel) throws IOException {
162        int retry_count = 0;
163        int max_retry = retry ? 2 : 1;
164        // Server uses TCP transport. TCP client sockets are cached
165        int length = bytes.length;
166        if (sipStack.isLoggingEnabled()) {
167            sipStack.getStackLogger().logDebug(
168                    "sendBytes " + transport + " inAddr " + receiverAddress.getHostAddress()
169                            + " port = " + contactPort + " length = " + length);
170        }
171        if (sipStack.isLoggingEnabled() && sipStack.isLogStackTraceOnMessageSend()) {
172            sipStack.getStackLogger().logStackTrace(StackLogger.TRACE_INFO);
173        }
174        if (transport.compareToIgnoreCase(TCP) == 0) {
175            String key = makeKey(receiverAddress, contactPort);
176            // This should be in a synchronized block ( reported by
177            // Jayashenkhar ( lucent ).
178
179            try {
180                boolean retval = this.ioSemaphore.tryAcquire(10000, TimeUnit.MILLISECONDS);
181                if (!retval) {
182                    throw new IOException(
183                            "Could not acquire IO Semaphore after 10 seconds -- giving up ");
184                }
185            } catch (InterruptedException ex) {
186                throw new IOException("exception in acquiring sem");
187            }
188            Socket clientSock = getSocket(key);
189
190            try {
191
192                while (retry_count < max_retry) {
193                    if (clientSock == null) {
194                        if (sipStack.isLoggingEnabled()) {
195                            sipStack.getStackLogger().logDebug("inaddr = " + receiverAddress);
196                            sipStack.getStackLogger().logDebug("port = " + contactPort);
197                        }
198                        // note that the IP Address for stack may not be
199                        // assigned.
200                        // sender address is the address of the listening point.
201                        // in version 1.1 all listening points have the same IP
202                        // address (i.e. that of the stack). In version 1.2
203                        // the IP address is on a per listening point basis.
204                        clientSock = sipStack.getNetworkLayer().createSocket(receiverAddress,
205                                contactPort, senderAddress);
206                        OutputStream outputStream = clientSock.getOutputStream();
207                        writeChunks(outputStream, bytes, length);
208                        putSocket(key, clientSock);
209                        break;
210                    } else {
211                        try {
212                            OutputStream outputStream = clientSock.getOutputStream();
213                            writeChunks(outputStream, bytes, length);
214                            break;
215                        } catch (IOException ex) {
216                            if (sipStack.isLoggingEnabled())
217                                sipStack.getStackLogger().logDebug(
218                                        "IOException occured retryCount " + retry_count);
219                            // old connection is bad.
220                            // remove from our table.
221                            removeSocket(key);
222                            try {
223                                clientSock.close();
224                            } catch (Exception e) {
225                            }
226                            clientSock = null;
227                            retry_count++;
228                        }
229                    }
230                }
231            } finally {
232                ioSemaphore.release();
233            }
234
235            if (clientSock == null) {
236
237                if (sipStack.isLoggingEnabled()) {
238                    sipStack.getStackLogger().logDebug(this.socketTable.toString());
239                    sipStack.getStackLogger().logError(
240                            "Could not connect to " + receiverAddress + ":" + contactPort);
241                }
242
243                throw new IOException("Could not connect to " + receiverAddress + ":"
244                        + contactPort);
245            } else
246                return clientSock;
247
248            // Added by Daniel J. Martinez Manzano <dani@dif.um.es>
249            // Copied and modified from the former section for TCP
250        } else if (transport.compareToIgnoreCase(TLS) == 0) {
251            String key = makeKey(receiverAddress, contactPort);
252            try {
253                boolean retval = this.ioSemaphore.tryAcquire(10000, TimeUnit.MILLISECONDS);
254                if (!retval)
255                    throw new IOException("Timeout acquiring IO SEM");
256            } catch (InterruptedException ex) {
257                throw new IOException("exception in acquiring sem");
258            }
259            Socket clientSock = getSocket(key);
260
261            try {
262                while (retry_count < max_retry) {
263                    if (clientSock == null) {
264                        if (sipStack.isLoggingEnabled()) {
265                            sipStack.getStackLogger().logDebug("inaddr = " + receiverAddress);
266                            sipStack.getStackLogger().logDebug("port = " + contactPort);
267                        }
268
269                        clientSock = sipStack.getNetworkLayer().createSSLSocket(receiverAddress,
270                                contactPort, senderAddress);
271                        SSLSocket sslsock = (SSLSocket) clientSock;
272                        HandshakeCompletedListener listner = new HandshakeCompletedListenerImpl(
273                                (TLSMessageChannel) messageChannel);
274                        ((TLSMessageChannel) messageChannel)
275                                .setHandshakeCompletedListener(listner);
276                        sslsock.addHandshakeCompletedListener(listner);
277                        sslsock.setEnabledProtocols(sipStack.getEnabledProtocols());
278                        sslsock.startHandshake();
279
280                        OutputStream outputStream = clientSock.getOutputStream();
281                        writeChunks(outputStream, bytes, length);
282                        putSocket(key, clientSock);
283                        break;
284                    } else {
285                        try {
286                            OutputStream outputStream = clientSock.getOutputStream();
287                            writeChunks(outputStream, bytes, length);
288                            break;
289                        } catch (IOException ex) {
290                            if (sipStack.isLoggingEnabled())
291                                sipStack.getStackLogger().logException(ex);
292                            // old connection is bad.
293                            // remove from our table.
294                            removeSocket(key);
295                            try {
296                                clientSock.close();
297                            } catch (Exception e) {
298                            }
299                            clientSock = null;
300                            retry_count++;
301                        }
302                    }
303                }
304            } finally {
305                ioSemaphore.release();
306            }
307            if (clientSock == null) {
308                throw new IOException("Could not connect to " + receiverAddress + ":"
309                        + contactPort);
310            } else
311                return clientSock;
312
313        } else {
314            // This is a UDP transport...
315            DatagramSocket datagramSock = sipStack.getNetworkLayer().createDatagramSocket();
316            datagramSock.connect(receiverAddress, contactPort);
317            DatagramPacket dgPacket = new DatagramPacket(bytes, 0, length, receiverAddress,
318                    contactPort);
319            datagramSock.send(dgPacket);
320            datagramSock.close();
321            return null;
322        }
323
324    }
325
326    /**
327     * Close all the cached connections.
328     */
329    public void closeAll() {
330        for (Enumeration<Socket> values = socketTable.elements(); values.hasMoreElements();) {
331            Socket s = (Socket) values.nextElement();
332            try {
333                s.close();
334            } catch (IOException ex) {
335            }
336        }
337
338    }
339
340}
341