1// 7zAes.cpp
2
3#include "StdAfx.h"
4
5#include "../../../C/Sha256.h"
6
7#include "../../Windows/Synchronization.h"
8
9#include "../Common/StreamObjects.h"
10#include "../Common/StreamUtils.h"
11
12#include "7zAes.h"
13#include "MyAes.h"
14
15#ifndef EXTRACT_ONLY
16#include "RandGen.h"
17#endif
18
19using namespace NWindows;
20
21namespace NCrypto {
22namespace NSevenZ {
23
24bool CKeyInfo::IsEqualTo(const CKeyInfo &a) const
25{
26  if (SaltSize != a.SaltSize || NumCyclesPower != a.NumCyclesPower)
27    return false;
28  for (UInt32 i = 0; i < SaltSize; i++)
29    if (Salt[i] != a.Salt[i])
30      return false;
31  return (Password == a.Password);
32}
33
34void CKeyInfo::CalculateDigest()
35{
36  if (NumCyclesPower == 0x3F)
37  {
38    UInt32 pos;
39    for (pos = 0; pos < SaltSize; pos++)
40      Key[pos] = Salt[pos];
41    for (UInt32 i = 0; i < Password.Size() && pos < kKeySize; i++)
42      Key[pos++] = Password[i];
43    for (; pos < kKeySize; pos++)
44      Key[pos] = 0;
45  }
46  else
47  {
48    CSha256 sha;
49    Sha256_Init(&sha);
50    const UInt64 numRounds = (UInt64)1 << NumCyclesPower;
51    Byte temp[8] = { 0,0,0,0,0,0,0,0 };
52    for (UInt64 round = 0; round < numRounds; round++)
53    {
54      Sha256_Update(&sha, Salt, (size_t)SaltSize);
55      Sha256_Update(&sha, Password, Password.Size());
56      Sha256_Update(&sha, temp, 8);
57      for (int i = 0; i < 8; i++)
58        if (++(temp[i]) != 0)
59          break;
60    }
61    Sha256_Final(&sha, Key);
62  }
63}
64
65bool CKeyInfoCache::Find(CKeyInfo &key)
66{
67  FOR_VECTOR (i, Keys)
68  {
69    const CKeyInfo &cached = Keys[i];
70    if (key.IsEqualTo(cached))
71    {
72      for (int j = 0; j < kKeySize; j++)
73        key.Key[j] = cached.Key[j];
74      if (i != 0)
75        Keys.MoveToFront(i);
76      return true;
77    }
78  }
79  return false;
80}
81
82void CKeyInfoCache::Add(CKeyInfo &key)
83{
84  if (Find(key))
85    return;
86  if (Keys.Size() >= Size)
87    Keys.DeleteBack();
88  Keys.Insert(0, key);
89}
90
// File-scope key cache consulted by CBase::CalculateDigest() in addition to
// each coder's own _cachedKeys. Constructed with 32 — presumably the maximum
// number of entries (CKeyInfoCache::Add compares Keys.Size() against Size).
static CKeyInfoCache g_GlobalKeyCache(32);
// Critical section locked at the top of CBase::CalculateDigest(), i.e. around
// every access to g_GlobalKeyCache made in this file.
static NSynchronization::CCriticalSection g_GlobalKeyCacheCriticalSection;
93
94CBase::CBase():
95  _cachedKeys(16),
96  _ivSize(0)
97{
98  for (int i = 0; i < sizeof(_iv); i++)
99    _iv[i] = 0;
100}
101
102void CBase::CalculateDigest()
103{
104  NSynchronization::CCriticalSectionLock lock(g_GlobalKeyCacheCriticalSection);
105  if (_cachedKeys.Find(_key))
106    g_GlobalKeyCache.Add(_key);
107  else
108  {
109    if (!g_GlobalKeyCache.Find(_key))
110    {
111      _key.CalculateDigest();
112      g_GlobalKeyCache.Add(_key);
113    }
114    _cachedKeys.Add(_key);
115  }
116}
117
118#ifndef EXTRACT_ONLY
119
120/*
121STDMETHODIMP CEncoder::ResetSalt()
122{
123  _key.SaltSize = 4;
124  g_RandomGenerator.Generate(_key.Salt, _key.SaltSize);
125  return S_OK;
126}
127*/
128
129STDMETHODIMP CEncoder::ResetInitVector()
130{
131  _ivSize = 8;
132  g_RandomGenerator.Generate(_iv, (unsigned)_ivSize);
133  return S_OK;
134}
135
136STDMETHODIMP CEncoder::WriteCoderProperties(ISequentialOutStream *outStream)
137{
138   // _key.Init();
139   for (UInt32 i = _ivSize; i < sizeof(_iv); i++)
140    _iv[i] = 0;
141
142  UInt32 ivSize = _ivSize;
143
144  // _key.NumCyclesPower = 0x3F;
145  _key.NumCyclesPower = 19;
146
147  Byte firstByte = (Byte)(_key.NumCyclesPower |
148    (((_key.SaltSize == 0) ? 0 : 1) << 7) |
149    (((ivSize == 0) ? 0 : 1) << 6));
150  RINOK(outStream->Write(&firstByte, 1, NULL));
151  if (_key.SaltSize == 0 && ivSize == 0)
152    return S_OK;
153  Byte saltSizeSpec = (Byte)((_key.SaltSize == 0) ? 0 : (_key.SaltSize - 1));
154  Byte ivSizeSpec = (Byte)((ivSize == 0) ? 0 : (ivSize - 1));
155  Byte secondByte = (Byte)(((saltSizeSpec) << 4) | ivSizeSpec);
156  RINOK(outStream->Write(&secondByte, 1, NULL));
157  if (_key.SaltSize > 0)
158  {
159    RINOK(WriteStream(outStream, _key.Salt, _key.SaltSize));
160  }
161  if (ivSize > 0)
162  {
163    RINOK(WriteStream(outStream, _iv, ivSize));
164  }
165  return S_OK;
166}
167
168HRESULT CEncoder::CreateFilter()
169{
170  _aesFilter = new CAesCbcEncoder(kKeySize);
171  return S_OK;
172}
173
174#endif
175
176STDMETHODIMP CDecoder::SetDecoderProperties2(const Byte *data, UInt32 size)
177{
178  _key.Init();
179  UInt32 i;
180  for (i = 0; i < sizeof(_iv); i++)
181    _iv[i] = 0;
182  if (size == 0)
183    return S_OK;
184  UInt32 pos = 0;
185  Byte firstByte = data[pos++];
186
187  _key.NumCyclesPower = firstByte & 0x3F;
188  if ((firstByte & 0xC0) == 0)
189    return S_OK;
190  _key.SaltSize = (firstByte >> 7) & 1;
191  UInt32 ivSize = (firstByte >> 6) & 1;
192
193  if (pos >= size)
194    return E_INVALIDARG;
195  Byte secondByte = data[pos++];
196
197  _key.SaltSize += (secondByte >> 4);
198  ivSize += (secondByte & 0x0F);
199
200  if (pos + _key.SaltSize + ivSize > size)
201    return E_INVALIDARG;
202  for (i = 0; i < _key.SaltSize; i++)
203    _key.Salt[i] = data[pos++];
204  for (i = 0; i < ivSize; i++)
205    _iv[i] = data[pos++];
206  return (_key.NumCyclesPower <= 24) ? S_OK :  E_NOTIMPL;
207}
208
209STDMETHODIMP CBaseCoder::CryptoSetPassword(const Byte *data, UInt32 size)
210{
211  _key.Password.CopyFrom(data, (size_t)size);
212  return S_OK;
213}
214
215STDMETHODIMP CBaseCoder::Init()
216{
217  CalculateDigest();
218  if (_aesFilter == 0)
219  {
220    RINOK(CreateFilter());
221  }
222  CMyComPtr<ICryptoProperties> cp;
223  RINOK(_aesFilter.QueryInterface(IID_ICryptoProperties, &cp));
224  RINOK(cp->SetKey(_key.Key, sizeof(_key.Key)));
225  RINOK(cp->SetInitVector(_iv, sizeof(_iv)));
226  return _aesFilter->Init();
227}
228
229STDMETHODIMP_(UInt32) CBaseCoder::Filter(Byte *data, UInt32 size)
230{
231  return _aesFilter->Filter(data, size);
232}
233
234HRESULT CDecoder::CreateFilter()
235{
236  _aesFilter = new CAesCbcDecoder(kKeySize);
237  return S_OK;
238}
239
240}}
241