1// 7zAes.cpp
2
3#include "StdAfx.h"
4
5#include "../../../C/Sha256.h"
6
7#include "../../Common/ComTry.h"
8
9#ifndef _7ZIP_ST
10#include "../../Windows/Synchronization.h"
11#endif
12
13#include "../Common/StreamUtils.h"
14
15#include "7zAes.h"
16#include "MyAes.h"
17
18#ifndef EXTRACT_ONLY
19#include "RandGen.h"
20#endif
21
22namespace NCrypto {
23namespace N7z {
24
// Largest NumCyclesPower value that CDecoder::SetDecoderProperties2 accepts
// with S_OK when salt/IV data is present; the special value 0x3F is accepted
// separately. CKeyInfo::CalcKey runs (1 << NumCyclesPower) hash rounds, so
// this bounds the key-derivation cost.
static const unsigned k_NumCyclesPower_Supported_MAX = 24;
26
27bool CKeyInfo::IsEqualTo(const CKeyInfo &a) const
28{
29  if (SaltSize != a.SaltSize || NumCyclesPower != a.NumCyclesPower)
30    return false;
31  for (unsigned i = 0; i < SaltSize; i++)
32    if (Salt[i] != a.Salt[i])
33      return false;
34  return (Password == a.Password);
35}
36
37void CKeyInfo::CalcKey()
38{
39  if (NumCyclesPower == 0x3F)
40  {
41    unsigned pos;
42    for (pos = 0; pos < SaltSize; pos++)
43      Key[pos] = Salt[pos];
44    for (unsigned i = 0; i < Password.Size() && pos < kKeySize; i++)
45      Key[pos++] = Password[i];
46    for (; pos < kKeySize; pos++)
47      Key[pos] = 0;
48  }
49  else
50  {
51    size_t bufSize = 8 + SaltSize + Password.Size();
52    CObjArray<Byte> buf(bufSize);
53    memcpy(buf, Salt, SaltSize);
54    memcpy(buf + SaltSize, Password, Password.Size());
55
56    CSha256 sha;
57    Sha256_Init(&sha);
58
59    Byte *ctr = buf + SaltSize + Password.Size();
60
61    for (unsigned i = 0; i < 8; i++)
62      ctr[i] = 0;
63
64    UInt64 numRounds = (UInt64)1 << NumCyclesPower;
65
66    do
67    {
68      Sha256_Update(&sha, buf, bufSize);
69      for (unsigned i = 0; i < 8; i++)
70        if (++(ctr[i]) != 0)
71          break;
72    }
73    while (--numRounds != 0);
74
75    Sha256_Final(&sha, Key);
76  }
77}
78
79bool CKeyInfoCache::GetKey(CKeyInfo &key)
80{
81  FOR_VECTOR (i, Keys)
82  {
83    const CKeyInfo &cached = Keys[i];
84    if (key.IsEqualTo(cached))
85    {
86      for (unsigned j = 0; j < kKeySize; j++)
87        key.Key[j] = cached.Key[j];
88      if (i != 0)
89        Keys.MoveToFront(i);
90      return true;
91    }
92  }
93  return false;
94}
95
96void CKeyInfoCache::FindAndAdd(const CKeyInfo &key)
97{
98  FOR_VECTOR (i, Keys)
99  {
100    const CKeyInfo &cached = Keys[i];
101    if (key.IsEqualTo(cached))
102    {
103      if (i != 0)
104        Keys.MoveToFront(i);
105      return;
106    }
107  }
108  Add(key);
109}
110
111void CKeyInfoCache::Add(const CKeyInfo &key)
112{
113  if (Keys.Size() >= Size)
114    Keys.DeleteBack();
115  Keys.Insert(0, key);
116}
117
// File-level key cache shared by all coder objects; CBase::PrepareKey()
// consults it after the per-object cache misses.
// NOTE(review): 32 is presumably the maximum entry count (CKeyInfoCache::Size);
// the constructor is not visible in this file - confirm in 7zAes.h.
static CKeyInfoCache g_GlobalKeyCache(32);

// MT_LOCK: in multithreaded builds it declares a local CCriticalSectionLock
// named `lock` on g_GlobalKeyCacheCriticalSection; when _7ZIP_ST is defined
// it expands to nothing.
#ifndef _7ZIP_ST
  static NWindows::NSynchronization::CCriticalSection g_GlobalKeyCacheCriticalSection;
  #define MT_LOCK NWindows::NSynchronization::CCriticalSectionLock lock(g_GlobalKeyCacheCriticalSection);
#else
  #define MT_LOCK
#endif
126
127CBase::CBase():
128  _cachedKeys(16),
129  _ivSize(0)
130{
131  for (unsigned i = 0; i < sizeof(_iv); i++)
132    _iv[i] = 0;
133}
134
135void CBase::PrepareKey()
136{
137  // BCJ2 threads use same password. So we use long lock.
138  MT_LOCK
139
140  bool finded = false;
141  if (!_cachedKeys.GetKey(_key))
142  {
143    finded = g_GlobalKeyCache.GetKey(_key);
144    if (!finded)
145      _key.CalcKey();
146    _cachedKeys.Add(_key);
147  }
148  if (!finded)
149    g_GlobalKeyCache.FindAndAdd(_key);
150}
151
152#ifndef EXTRACT_ONLY
153
154/*
155STDMETHODIMP CEncoder::ResetSalt()
156{
157  _key.SaltSize = 4;
158  g_RandomGenerator.Generate(_key.Salt, _key.SaltSize);
159  return S_OK;
160}
161*/
162
163STDMETHODIMP CEncoder::ResetInitVector()
164{
165  for (unsigned i = 0; i < sizeof(_iv); i++)
166    _iv[i] = 0;
167  _ivSize = 8;
168  g_RandomGenerator.Generate(_iv, _ivSize);
169  return S_OK;
170}
171
172STDMETHODIMP CEncoder::WriteCoderProperties(ISequentialOutStream *outStream)
173{
174  Byte props[2 + sizeof(_key.Salt) + sizeof(_iv)];
175  unsigned propsSize = 1;
176
177  props[0] = (Byte)(_key.NumCyclesPower
178      | (_key.SaltSize == 0 ? 0 : (1 << 7))
179      | (_ivSize       == 0 ? 0 : (1 << 6)));
180
181  if (_key.SaltSize != 0 || _ivSize != 0)
182  {
183    props[1] = (Byte)(
184        ((_key.SaltSize == 0 ? 0 : _key.SaltSize - 1) << 4)
185        | (_ivSize      == 0 ? 0 : _ivSize - 1));
186    memcpy(props + 2, _key.Salt, _key.SaltSize);
187    propsSize = 2 + _key.SaltSize;
188    memcpy(props + propsSize, _iv, _ivSize);
189    propsSize += _ivSize;
190  }
191
192  return WriteStream(outStream, props, propsSize);
193}
194
195CEncoder::CEncoder()
196{
197  // _key.SaltSize = 4; g_RandomGenerator.Generate(_key.Salt, _key.SaltSize);
198  // _key.NumCyclesPower = 0x3F;
199  _key.NumCyclesPower = 19;
200  _aesFilter = new CAesCbcEncoder(kKeySize);
201}
202
203#endif
204
// Creates the AES-CBC decoding filter (constructed with kKeySize) that
// CBaseCoder::Init() and CBaseCoder::Filter() operate on.
CDecoder::CDecoder()
{
  _aesFilter = new CAesCbcDecoder(kKeySize);
}
209
210STDMETHODIMP CDecoder::SetDecoderProperties2(const Byte *data, UInt32 size)
211{
212  _key.ClearProps();
213
214  _ivSize = 0;
215  unsigned i;
216  for (i = 0; i < sizeof(_iv); i++)
217    _iv[i] = 0;
218
219  if (size == 0)
220    return S_OK;
221
222  Byte b0 = data[0];
223
224  _key.NumCyclesPower = b0 & 0x3F;
225  if ((b0 & 0xC0) == 0)
226    return size == 1 ? S_OK : E_INVALIDARG;
227
228  if (size <= 1)
229    return E_INVALIDARG;
230
231  Byte b1 = data[1];
232
233  unsigned saltSize = ((b0 >> 7) & 1) + (b1 >> 4);
234  unsigned ivSize   = ((b0 >> 6) & 1) + (b1 & 0x0F);
235
236  if (size != 2 + saltSize + ivSize)
237    return E_INVALIDARG;
238  _key.SaltSize = saltSize;
239  data += 2;
240  for (i = 0; i < saltSize; i++)
241    _key.Salt[i] = *data++;
242  for (i = 0; i < ivSize; i++)
243    _iv[i] = *data++;
244  return (_key.NumCyclesPower <= k_NumCyclesPower_Supported_MAX
245      || _key.NumCyclesPower == 0x3F) ? S_OK : E_NOTIMPL;
246}
247
248
// Stores the password: the `size` bytes at `data` are copied into
// _key.Password. The derived key is not recalculated here; that happens in
// PrepareKey(), which Init() calls.
// Returns S_OK on success.
// NOTE(review): COM_TRY_BEGIN / COM_TRY_END presumably convert exceptions
// (e.g. allocation failure in CopyFrom) into an HRESULT - confirm in ComTry.h.
STDMETHODIMP CBaseCoder::CryptoSetPassword(const Byte *data, UInt32 size)
{
  COM_TRY_BEGIN

  _key.Password.CopyFrom(data, (size_t)size);
  return S_OK;

  COM_TRY_END
}
258
259STDMETHODIMP CBaseCoder::Init()
260{
261  COM_TRY_BEGIN
262
263  PrepareKey();
264  CMyComPtr<ICryptoProperties> cp;
265  RINOK(_aesFilter.QueryInterface(IID_ICryptoProperties, &cp));
266  if (!cp)
267    return E_FAIL;
268  RINOK(cp->SetKey(_key.Key, kKeySize));
269  RINOK(cp->SetInitVector(_iv, sizeof(_iv)));
270  return _aesFilter->Init();
271
272  COM_TRY_END
273}
274
// Processes `size` bytes at `data` by forwarding the call to the AES filter
// and returns that filter's result unchanged.
STDMETHODIMP_(UInt32) CBaseCoder::Filter(Byte *data, UInt32 size)
{
  return _aesFilter->Filter(data, size);
}
279
280}}
281