1/**
2 * @license
3 * Copyright 2016 Google Inc. All rights reserved.
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *   http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17package com.google.security.wycheproof;
18
19import java.io.ByteArrayInputStream;
20import java.io.IOException;
21import java.io.InputStream;
22import java.security.NoSuchAlgorithmException;
23import java.security.SecureRandom;
24import java.security.spec.AlgorithmParameterSpec;
25import java.util.ArrayList;
26import java.util.Arrays;
27import javax.crypto.Cipher;
28import javax.crypto.CipherInputStream;
29import javax.crypto.spec.GCMParameterSpec;
30import javax.crypto.spec.SecretKeySpec;
31import junit.framework.TestCase;
32
33/** CipherInputStream tests */
34public class CipherInputStreamTest extends TestCase {
35  static final SecureRandom rand = new SecureRandom();
36
37  static byte[] randomBytes(int size) {
38    byte[] bytes = new byte[size];
39    rand.nextBytes(bytes);
40    return bytes;
41  }
42
43  static SecretKeySpec randomKey(String algorithm, int keySizeInBytes) {
44    return new SecretKeySpec(randomBytes(keySizeInBytes), "AES");
45  }
46
47  static AlgorithmParameterSpec randomParameters(
48      String algorithm, int ivSizeInBytes, int tagSizeInBytes) {
49    if ("AES/GCM/NoPadding".equals(algorithm) || "AES/EAX/NoPadding".equals(algorithm)) {
50      return new GCMParameterSpec(8 * tagSizeInBytes, randomBytes(ivSizeInBytes));
51    }
52    return null;
53  }
54
55  /** Test vectors */
56  public static class TestVector {
57    public String algorithm;
58    public SecretKeySpec key;
59    public AlgorithmParameterSpec params;
60    public byte[] pt;
61    public byte[] aad;
62    public byte[] ct;
63
64    @SuppressWarnings("InsecureCryptoUsage")
65    public TestVector(
66        String algorithm, int keySize, int ivSize, int tagSize, int ptSize, int aadSize)
67        throws Exception {
68      this.algorithm = algorithm;
69      this.key = randomKey(algorithm, keySize);
70      this.params = randomParameters(algorithm, ivSize, tagSize);
71      this.pt = randomBytes(ptSize);
72      this.aad = randomBytes(aadSize);
73      Cipher cipher = Cipher.getInstance(algorithm);
74      cipher.init(Cipher.ENCRYPT_MODE, this.key, this.params);
75      cipher.updateAAD(aad);
76      this.ct = cipher.doFinal(pt);
77    }
78  }
79
80  Iterable<TestVector> getTestVectors(
81      String algorithm,
82      int[] keySizes,
83      int[] ivSizes,
84      int[] tagSizes,
85      int[] ptSizes,
86      int[] aadSizes)
87      throws Exception {
88    ArrayList<TestVector> result = new ArrayList<TestVector>();
89    for (int keySize : keySizes) {
90      for (int ivSize : ivSizes) {
91        for (int tagSize : tagSizes) {
92          for (int ptSize : ptSizes) {
93            for (int aadSize : aadSizes) {
94              result.add(new TestVector(algorithm, keySize, ivSize, tagSize, ptSize, aadSize));
95            }
96          }
97        }
98      }
99    }
100    return result;
101  }
102
103  @SuppressWarnings("InsecureCryptoUsage")
104  public void testEncrypt(Iterable<TestVector> tests) throws Exception {
105    for (TestVector t : tests) {
106      Cipher cipher = Cipher.getInstance(t.algorithm);
107      cipher.init(Cipher.ENCRYPT_MODE, t.key, t.params);
108      cipher.updateAAD(t.aad);
109      InputStream is = new ByteArrayInputStream(t.pt);
110      CipherInputStream cis = new CipherInputStream(is, cipher);
111      byte[] result = new byte[t.ct.length];
112      int totalLength = 0;
113      int length = 0;
114      do {
115        length = cis.read(result, totalLength, result.length - totalLength);
116        if (length > 0) {
117          totalLength += length;
118        }
119      } while (length >= 0 && totalLength != result.length);
120      assertEquals(-1, cis.read());
121      assertEquals(TestUtil.bytesToHex(t.ct), TestUtil.bytesToHex(result));
122      cis.close();
123    }
124  }
125
126  /** JDK-8016249: CipherInputStream in decrypt mode fails on close with AEAD ciphers */
127  @SuppressWarnings("InsecureCryptoUsage")
128  public void testDecrypt(Iterable<TestVector> tests) throws Exception {
129    for (TestVector t : tests) {
130      Cipher cipher = Cipher.getInstance(t.algorithm);
131      cipher.init(Cipher.DECRYPT_MODE, t.key, t.params);
132      cipher.updateAAD(t.aad);
133      InputStream is = new ByteArrayInputStream(t.ct);
134      CipherInputStream cis = new CipherInputStream(is, cipher);
135      byte[] result = new byte[t.pt.length];
136      int totalLength = 0;
137      int length = 0;
138      do {
139        length = cis.read(result, totalLength, result.length - totalLength);
140        if (length > 0) {
141          totalLength += length;
142        }
143      } while (length >= 0 && totalLength != result.length);
144      assertEquals(-1, cis.read());
145      cis.close();
146      assertEquals(TestUtil.bytesToHex(t.pt), TestUtil.bytesToHex(result));
147    }
148  }
149
150  /**
151   * JDK-8016171 : CipherInputStream masks ciphertext tampering with AEAD ciphers in decrypt mode
152   * Further description of the bug is here:
153   * https://blog.heckel.xyz/2014/03/01/cipherinputstream-for-aead-modes-is-broken-in-jdk7-gcm/
154   * BouncyCastle claims that this bug is fixed in version 1.51. However, the test below still fails
155   * with BouncyCastle v 1.52. A possible explanation is that BouncyCastle has its own
156   * implemenatation of CipherInputStream (org.bouncycastle.crypto.io.CipherInputStream).
157   */
158  @SuppressWarnings("InsecureCryptoUsage")
159  public void testCorruptDecrypt(Iterable<TestVector> tests) throws Exception {
160    for (TestVector t : tests) {
161      Cipher cipher = Cipher.getInstance(t.algorithm);
162      cipher.init(Cipher.DECRYPT_MODE, t.key, t.params);
163      cipher.updateAAD(t.aad);
164      byte[] ct = Arrays.copyOf(t.ct, t.ct.length);
165      ct[ct.length - 1] ^= (byte) 1;
166      InputStream is = new ByteArrayInputStream(ct);
167      CipherInputStream cis = new CipherInputStream(is, cipher);
168      try {
169        byte[] result = new byte[t.pt.length];
170        int totalLength = 0;
171        int length = 0;
172        do {
173          length = cis.read(result, totalLength, result.length - totalLength);
174          if (length > 0) {
175            totalLength += length;
176          }
177        } while (length >= 0 && totalLength != result.length);
178        cis.close();
179        if (result.length > 0) {
180          fail(
181              "this should fail; decrypted:"
182                  + TestUtil.bytesToHex(result)
183                  + " pt: "
184                  + TestUtil.bytesToHex(t.pt));
185        }
186      } catch (IOException ex) {
187        // expected
188      }
189    }
190  }
191
192  @SuppressWarnings("InsecureCryptoUsage")
193  public void testCorruptDecryptEmpty(Iterable<TestVector> tests) throws Exception {
194    for (TestVector t : tests) {
195      Cipher cipher = Cipher.getInstance(t.algorithm);
196      cipher.init(Cipher.DECRYPT_MODE, t.key, t.params);
197      cipher.updateAAD(t.aad);
198      byte[] ct = Arrays.copyOf(t.ct, t.ct.length);
199      ct[ct.length - 1] ^= (byte) 1;
200      InputStream is = new ByteArrayInputStream(ct);
201      CipherInputStream cis = new CipherInputStream(is, cipher);
202      try {
203        byte[] result = new byte[t.pt.length];
204        int totalLength = 0;
205        int length = 0;
206        do {
207          length = cis.read(result, totalLength, result.length - totalLength);
208          if (length > 0) {
209            totalLength += length;
210          }
211        } while (length >= 0 && totalLength != result.length);
212        cis.close();
213        fail("this should fail");
214      } catch (IOException ex) {
215        // expected
216      }
217    }
218  }
219
220  public void testAesGcm() throws Exception {
221    final int[] keySizes = {16, 32};
222    final int[] ivSizes = {12};
223    final int[] tagSizes = {12, 16};
224    final int[] ptSizes = {0, 8, 16, 65, 8100};
225    final int[] aadSizes = {0, 8, 24};
226    Iterable<TestVector> v =
227        getTestVectors("AES/GCM/NoPadding", keySizes, ivSizes, tagSizes, ptSizes, aadSizes);
228    testEncrypt(v);
229    testDecrypt(v);
230  }
231
232  public void testCorruptAesGcm() throws Exception {
233    final int[] keySizes = {16, 32};
234    final int[] ivSizes = {12};
235    final int[] tagSizes = {12, 16};
236    final int[] ptSizes = {8, 16, 65, 8100};
237    final int[] aadSizes = {0, 8, 24};
238    Iterable<TestVector> v =
239        getTestVectors("AES/GCM/NoPadding", keySizes, ivSizes, tagSizes, ptSizes, aadSizes);
240    testCorruptDecrypt(v);
241  }
242
243  /**
244   * Unfortunately Oracle thinks that returning an empty array is valid behaviour for corrupt
245   * ciphertexts. Because of this we test empty plaintext separately to distinguish behaviour
246   * considered acceptable by Oracle from other behaviour.
247   */
248  public void testEmptyPlaintext() throws Exception {
249    final int[] keySizes = {16, 32};
250    final int[] ivSizes = {12};
251    final int[] tagSizes = {12, 16};
252    final int[] ptSizes = {0};
253    final int[] aadSizes = {0, 8, 24};
254    Iterable<TestVector> v =
255        getTestVectors("AES/GCM/NoPadding", keySizes, ivSizes, tagSizes, ptSizes, aadSizes);
256    testCorruptDecryptEmpty(v);
257  }
258
259  /** Tests CipherOutputStream with AES-EAX if this algorithm is supported by the provider. */
260  public void testAesEax() throws Exception {
261    final String algorithm = "AES/EAX/NoPadding";
262    final int[] keySizes = {16, 32};
263    final int[] ivSizes = {12, 16};
264    final int[] tagSizes = {12, 16};
265    final int[] ptSizes = {0, 8, 16, 65, 8100};
266    final int[] aadSizes = {0, 8, 24};
267    try {
268      Cipher.getInstance(algorithm);
269    } catch (NoSuchAlgorithmException ex) {
270      System.out.println("Skipping testAesEax");
271      return;
272    }
273    Iterable<TestVector> v =
274        getTestVectors(algorithm, keySizes, ivSizes, tagSizes, ptSizes, aadSizes);
275    testEncrypt(v);
276    testDecrypt(v);
277    testCorruptDecrypt(v);
278  }
279}
280