1//
2// Copyright (C) 2015 The Android Open Source Project
3//
4// Licensed under the Apache License, Version 2.0 (the "License");
5// you may not use this file except in compliance with the License.
6// You may obtain a copy of the License at
7//
8//      http://www.apache.org/licenses/LICENSE-2.0
9//
10// Unless required by applicable law or agreed to in writing, software
11// distributed under the License is distributed on an "AS IS" BASIS,
12// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13// See the License for the specific language governing permissions and
14// limitations under the License.
15//
16
17#include "trunks/policy_session_impl.h"
18
19#include <crypto/sha2.h>
20#include <gmock/gmock.h>
21#include <gtest/gtest.h>
22
23#include "trunks/error_codes.h"
24#include "trunks/mock_session_manager.h"
25#include "trunks/mock_tpm.h"
26#include "trunks/tpm_generated.h"
27#include "trunks/trunks_factory_for_test.h"
28
29using testing::_;
30using testing::NiceMock;
31using testing::Return;
32using testing::SaveArg;
33using testing::SetArgPointee;
34
35namespace trunks {
36
37class PolicySessionTest : public testing::Test {
38 public:
39  PolicySessionTest() {}
40  ~PolicySessionTest() override {}
41
42  void SetUp() override {
43    factory_.set_session_manager(&mock_session_manager_);
44    factory_.set_tpm(&mock_tpm_);
45  }
46
47  HmacAuthorizationDelegate* GetHmacDelegate(PolicySessionImpl* session) {
48    return &(session->hmac_delegate_);
49  }
50
51 protected:
52  TrunksFactoryForTest factory_;
53  NiceMock<MockSessionManager> mock_session_manager_;
54  NiceMock<MockTpm> mock_tpm_;
55};
56
57TEST_F(PolicySessionTest, GetDelegateUninitialized) {
58  PolicySessionImpl session(factory_);
59  EXPECT_CALL(mock_session_manager_, GetSessionHandle())
60      .WillRepeatedly(Return(kUninitializedHandle));
61  EXPECT_EQ(nullptr, session.GetDelegate());
62}
63
64TEST_F(PolicySessionTest, GetDelegateSuccess) {
65  PolicySessionImpl session(factory_);
66  EXPECT_EQ(GetHmacDelegate(&session), session.GetDelegate());
67}
68
69TEST_F(PolicySessionTest, StartBoundSessionSuccess) {
70  PolicySessionImpl session(factory_);
71  EXPECT_EQ(TPM_RC_SUCCESS,
72            session.StartBoundSession(TPM_RH_FIRST, "auth", true));
73}
74
75TEST_F(PolicySessionTest, StartBoundSessionFailure) {
76  PolicySessionImpl session(factory_);
77  TPM_HANDLE handle = TPM_RH_FIRST;
78  EXPECT_CALL(mock_session_manager_, StartSession(TPM_SE_POLICY, handle,
79                                                  _, true, _))
80      .WillRepeatedly(Return(TPM_RC_FAILURE));
81  EXPECT_EQ(TPM_RC_FAILURE, session.StartBoundSession(handle, "auth", true));
82}
83
84TEST_F(PolicySessionTest, StartBoundSessionBadType) {
85  PolicySessionImpl session(factory_, TPM_SE_HMAC);
86  EXPECT_EQ(SAPI_RC_INVALID_SESSIONS,
87            session.StartBoundSession(TPM_RH_FIRST, "auth", true));
88}
89
90TEST_F(PolicySessionTest, StartUnboundSessionSuccess) {
91  PolicySessionImpl session(factory_);
92  EXPECT_EQ(TPM_RC_SUCCESS, session.StartUnboundSession(true));
93}
94
95TEST_F(PolicySessionTest, StartUnboundSessionFailure) {
96  PolicySessionImpl session(factory_);
97  EXPECT_CALL(mock_session_manager_, StartSession(TPM_SE_POLICY, TPM_RH_NULL,
98                                                  _, true, _))
99      .WillRepeatedly(Return(TPM_RC_FAILURE));
100  EXPECT_EQ(TPM_RC_FAILURE, session.StartUnboundSession(true));
101}
102
103TEST_F(PolicySessionTest, GetDigestSuccess) {
104  PolicySessionImpl session(factory_);
105  std::string digest;
106  TPM2B_DIGEST policy_digest;
107  policy_digest.size = SHA256_DIGEST_SIZE;
108  EXPECT_CALL(mock_tpm_, PolicyGetDigestSync(_, _, _, _))
109      .WillOnce(DoAll(SetArgPointee<2>(policy_digest),
110                      Return(TPM_RC_SUCCESS)));
111  EXPECT_EQ(TPM_RC_SUCCESS, session.GetDigest(&digest));
112  EXPECT_EQ(static_cast<size_t>(SHA256_DIGEST_SIZE), digest.size());
113}
114
115TEST_F(PolicySessionTest, GetDigestFailure) {
116  PolicySessionImpl session(factory_);
117  std::string digest;
118  EXPECT_CALL(mock_tpm_, PolicyGetDigestSync(_, _, _, _))
119      .WillOnce(Return(TPM_RC_FAILURE));
120  EXPECT_EQ(TPM_RC_FAILURE, session.GetDigest(&digest));
121}
122
123TEST_F(PolicySessionTest, PolicyORSuccess) {
124  PolicySessionImpl session(factory_);
125  std::vector<std::string> digests;
126  digests.push_back("digest1");
127  digests.push_back("digest2");
128  digests.push_back("digest3");
129  TPML_DIGEST tpm_digests;
130  EXPECT_CALL(mock_tpm_, PolicyORSync(_, _, _, _))
131      .WillOnce(DoAll(SaveArg<2>(&tpm_digests),
132                      Return(TPM_RC_SUCCESS)));
133  EXPECT_EQ(TPM_RC_SUCCESS, session.PolicyOR(digests));
134  EXPECT_EQ(tpm_digests.count, digests.size());
135  EXPECT_EQ(StringFrom_TPM2B_DIGEST(tpm_digests.digests[0]), digests[0]);
136  EXPECT_EQ(StringFrom_TPM2B_DIGEST(tpm_digests.digests[1]), digests[1]);
137  EXPECT_EQ(StringFrom_TPM2B_DIGEST(tpm_digests.digests[2]), digests[2]);
138}
139
140TEST_F(PolicySessionTest, PolicyORBadParam) {
141  PolicySessionImpl session(factory_);
142  std::vector<std::string> digests;
143  // We use 9 here because the maximum number of digests allowed by the TPM
144  // is 8. Therefore having 9 digests here should cause the code to fail.
145  digests.resize(9);
146  EXPECT_EQ(SAPI_RC_BAD_PARAMETER, session.PolicyOR(digests));
147}
148
149TEST_F(PolicySessionTest, PolicyORFailure) {
150  PolicySessionImpl session(factory_);
151  std::vector<std::string> digests;
152  EXPECT_CALL(mock_tpm_, PolicyORSync(_, _, _, _))
153      .WillOnce(Return(TPM_RC_FAILURE));
154  EXPECT_EQ(TPM_RC_FAILURE, session.PolicyOR(digests));
155}
156
157TEST_F(PolicySessionTest, PolicyPCRSuccess) {
158  PolicySessionImpl session(factory_);
159  std::string pcr_digest("digest");
160  int pcr_index = 1;
161  TPML_PCR_SELECTION pcr_select;
162  TPM2B_DIGEST pcr_value;
163  EXPECT_CALL(mock_tpm_, PolicyPCRSync(_, _, _, _, _))
164      .WillOnce(DoAll(SaveArg<2>(&pcr_value),
165                      SaveArg<3>(&pcr_select),
166                      Return(TPM_RC_SUCCESS)));
167  EXPECT_EQ(TPM_RC_SUCCESS, session.PolicyPCR(pcr_index, pcr_digest));
168  uint8_t pcr_select_index = pcr_index / 8;
169  uint8_t pcr_select_byte = 1 << (pcr_index % 8);
170  EXPECT_EQ(pcr_select.count, 1u);
171  EXPECT_EQ(pcr_select.pcr_selections[0].hash, TPM_ALG_SHA256);
172  EXPECT_EQ(pcr_select.pcr_selections[0].sizeof_select, PCR_SELECT_MIN);
173  EXPECT_EQ(pcr_select.pcr_selections[0].pcr_select[pcr_select_index],
174            pcr_select_byte);
175  EXPECT_EQ(StringFrom_TPM2B_DIGEST(pcr_value),
176            crypto::SHA256HashString(pcr_digest));
177}
178
179TEST_F(PolicySessionTest, PolicyPCRFailure) {
180  PolicySessionImpl session(factory_);
181  EXPECT_CALL(mock_tpm_, PolicyPCRSync(_, _, _, _, _))
182      .WillOnce(Return(TPM_RC_FAILURE));
183  EXPECT_EQ(TPM_RC_FAILURE, session.PolicyPCR(1, "pcr_digest"));
184}
185
186TEST_F(PolicySessionTest, PolicyPCRTrialWithNoDigest) {
187  PolicySessionImpl session(factory_, TPM_SE_TRIAL);
188  EXPECT_EQ(SAPI_RC_BAD_PARAMETER, session.PolicyPCR(1, ""));
189}
190
191TEST_F(PolicySessionTest, PolicyCommandCodeSuccess) {
192  PolicySessionImpl session(factory_);
193  TPM_CC command_code = TPM_CC_FIRST;
194  EXPECT_CALL(mock_tpm_, PolicyCommandCodeSync(_, _, command_code, _))
195      .WillOnce(Return(TPM_RC_SUCCESS));
196  EXPECT_EQ(TPM_RC_SUCCESS, session.PolicyCommandCode(TPM_CC_FIRST));
197}
198
199TEST_F(PolicySessionTest, PolicyCommandCodeFailure) {
200  PolicySessionImpl session(factory_);
201  EXPECT_CALL(mock_tpm_, PolicyCommandCodeSync(_, _, _, _))
202      .WillOnce(Return(TPM_RC_FAILURE));
203  EXPECT_EQ(TPM_RC_FAILURE, session.PolicyCommandCode(TPM_CC_FIRST));
204}
205
206TEST_F(PolicySessionTest, PolicyAuthValueSuccess) {
207  PolicySessionImpl session(factory_);
208  EXPECT_CALL(mock_tpm_, PolicyAuthValueSync(_, _, _))
209      .WillOnce(Return(TPM_RC_SUCCESS));
210  EXPECT_EQ(TPM_RC_SUCCESS, session.PolicyAuthValue());
211}
212
213TEST_F(PolicySessionTest, PolicyAuthValueFailure) {
214  PolicySessionImpl session(factory_);
215  EXPECT_CALL(mock_tpm_, PolicyAuthValueSync(_, _, _))
216      .WillOnce(Return(TPM_RC_FAILURE));
217  EXPECT_EQ(TPM_RC_FAILURE, session.PolicyAuthValue());
218}
219
220TEST_F(PolicySessionTest, EntityAuthorizationForwardingTest) {
221  PolicySessionImpl session(factory_);
222  std::string test_auth("test_auth");
223  session.SetEntityAuthorizationValue(test_auth);
224  HmacAuthorizationDelegate* hmac_delegate = GetHmacDelegate(&session);
225  std::string entity_auth = hmac_delegate->entity_authorization_value();
226  EXPECT_EQ(0, test_auth.compare(entity_auth));
227}
228
229}  // namespace trunks
230