1// Copyright 2014 The Chromium Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style license that can be
3// found in the LICENSE file.
4
5#include "chrome/installer/util/advanced_firewall_manager_win.h"
6
7#include "base/guid.h"
8#include "base/logging.h"
9#include "base/strings/stringprintf.h"
10#include "base/strings/utf_string_conversions.h"
11#include "base/win/scoped_bstr.h"
12#include "base/win/scoped_variant.h"
13
14namespace installer {
15
16AdvancedFirewallManager::AdvancedFirewallManager() {}
17
18AdvancedFirewallManager::~AdvancedFirewallManager() {}
19
20bool AdvancedFirewallManager::Init(const base::string16& app_name,
21                                   const base::FilePath& app_path) {
22  firewall_rules_ = NULL;
23  HRESULT hr = firewall_policy_.CreateInstance(CLSID_NetFwPolicy2);
24  if (FAILED(hr)) {
25    DLOG(ERROR) << logging::SystemErrorCodeToString(hr);
26    firewall_policy_ = NULL;
27    return false;
28  }
29  hr = firewall_policy_->get_Rules(firewall_rules_.Receive());
30  if (FAILED(hr)) {
31    DLOG(ERROR) << logging::SystemErrorCodeToString(hr);
32    firewall_rules_ = NULL;
33    return false;
34  }
35  app_name_ = app_name;
36  app_path_ = app_path;
37  return true;
38}
39
40bool AdvancedFirewallManager::IsFirewallEnabled() {
41  long profile_types = 0;
42  HRESULT hr = firewall_policy_->get_CurrentProfileTypes(&profile_types);
43  if (FAILED(hr))
44    return false;
45  // The most-restrictive active profile takes precedence.
46  const NET_FW_PROFILE_TYPE2 kProfileTypes[] = {
47    NET_FW_PROFILE2_PUBLIC,
48    NET_FW_PROFILE2_PRIVATE,
49    NET_FW_PROFILE2_DOMAIN
50  };
51  for (size_t i = 0; i < arraysize(kProfileTypes); ++i) {
52    if ((profile_types & kProfileTypes[i]) != 0) {
53      VARIANT_BOOL enabled = VARIANT_TRUE;
54      hr = firewall_policy_->get_FirewallEnabled(kProfileTypes[i], &enabled);
55      // Assume the firewall is enabled if we can't determine.
56      if (FAILED(hr) || enabled != VARIANT_FALSE)
57        return true;
58    }
59  }
60  return false;
61}
62
63bool AdvancedFirewallManager::HasAnyRule() {
64  std::vector<base::win::ScopedComPtr<INetFwRule> > rules;
65  GetAllRules(&rules);
66  return !rules.empty();
67}
68
69bool AdvancedFirewallManager::AddUDPRule(const base::string16& rule_name,
70                                         const base::string16& description,
71                                         uint16_t port) {
72  // Delete the rule. According MDSN |INetFwRules::Add| should replace rule with
73  // same "rule identifier". It's not clear what is "rule identifier", but it
74  // can successfully create many rule with same name.
75  DeleteRuleByName(rule_name);
76
77  // Create the rule and add it to the rule set (only succeeds if elevated).
78  base::win::ScopedComPtr<INetFwRule> udp_rule =
79      CreateUDPRule(rule_name, description, port);
80  if (!udp_rule.get())
81    return false;
82
83  HRESULT hr = firewall_rules_->Add(udp_rule);
84  DLOG_IF(ERROR, FAILED(hr)) << logging::SystemErrorCodeToString(hr);
85  return SUCCEEDED(hr);
86}
87
88void AdvancedFirewallManager::DeleteRuleByName(
89    const base::string16& rule_name) {
90  std::vector<base::win::ScopedComPtr<INetFwRule> > rules;
91  GetAllRules(&rules);
92  for (size_t i = 0; i < rules.size(); ++i) {
93    base::win::ScopedBstr name;
94    HRESULT hr = rules[i]->get_Name(name.Receive());
95    if (SUCCEEDED(hr) && name && base::string16(name) == rule_name) {
96      DeleteRule(rules[i]);
97    }
98  }
99}
100
101void AdvancedFirewallManager::DeleteRule(
102    base::win::ScopedComPtr<INetFwRule> rule) {
103  // Rename rule to unique name and delete by unique name. We can't just delete
104  // rule by name. Multiple rules with the same name and different app are
105  // possible.
106  base::win::ScopedBstr unique_name(
107      base::UTF8ToUTF16(base::GenerateGUID()).c_str());
108  rule->put_Name(unique_name);
109  firewall_rules_->Remove(unique_name);
110}
111
112void AdvancedFirewallManager::DeleteAllRules() {
113  std::vector<base::win::ScopedComPtr<INetFwRule> > rules;
114  GetAllRules(&rules);
115  for (size_t i = 0; i < rules.size(); ++i) {
116    DeleteRule(rules[i]);
117  }
118}
119
120base::win::ScopedComPtr<INetFwRule> AdvancedFirewallManager::CreateUDPRule(
121    const base::string16& rule_name,
122    const base::string16& description,
123    uint16_t port) {
124  base::win::ScopedComPtr<INetFwRule> udp_rule;
125
126  HRESULT hr = udp_rule.CreateInstance(CLSID_NetFwRule);
127  if (FAILED(hr)) {
128    DLOG(ERROR) << logging::SystemErrorCodeToString(hr);
129    return base::win::ScopedComPtr<INetFwRule>();
130  }
131
132  udp_rule->put_Name(base::win::ScopedBstr(rule_name.c_str()));
133  udp_rule->put_Description(base::win::ScopedBstr(description.c_str()));
134  udp_rule->put_ApplicationName(
135      base::win::ScopedBstr(app_path_.value().c_str()));
136  udp_rule->put_Protocol(NET_FW_IP_PROTOCOL_UDP);
137  udp_rule->put_Direction(NET_FW_RULE_DIR_IN);
138  udp_rule->put_Enabled(VARIANT_TRUE);
139  udp_rule->put_LocalPorts(
140      base::win::ScopedBstr(base::StringPrintf(L"%u", port).c_str()));
141  udp_rule->put_Grouping(base::win::ScopedBstr(app_name_.c_str()));
142  udp_rule->put_Profiles(NET_FW_PROFILE2_ALL);
143  udp_rule->put_Action(NET_FW_ACTION_ALLOW);
144
145  return udp_rule;
146}
147
148void AdvancedFirewallManager::GetAllRules(
149    std::vector<base::win::ScopedComPtr<INetFwRule> >* rules) {
150  base::win::ScopedComPtr<IUnknown> rules_enum_unknown;
151  HRESULT hr = firewall_rules_->get__NewEnum(rules_enum_unknown.Receive());
152  if (FAILED(hr)) {
153    DLOG(ERROR) << logging::SystemErrorCodeToString(hr);
154    return;
155  }
156
157  base::win::ScopedComPtr<IEnumVARIANT> rules_enum;
158  hr = rules_enum.QueryFrom(rules_enum_unknown);
159  if (FAILED(hr)) {
160    DLOG(ERROR) << logging::SystemErrorCodeToString(hr);
161    return;
162  }
163
164  for (;;) {
165    base::win::ScopedVariant rule_var;
166    hr = rules_enum->Next(1, rule_var.Receive(), NULL);
167    DLOG_IF(ERROR, FAILED(hr)) << logging::SystemErrorCodeToString(hr);
168    if (hr != S_OK)
169      break;
170    DCHECK_EQ(VT_DISPATCH, rule_var.type());
171    if (VT_DISPATCH != rule_var.type()) {
172      DLOG(ERROR) << "Unexpected type";
173      continue;
174    }
175    base::win::ScopedComPtr<INetFwRule> rule;
176    hr = rule.QueryFrom(V_DISPATCH(&rule_var));
177    if (FAILED(hr)) {
178      DLOG(ERROR) << logging::SystemErrorCodeToString(hr);
179      continue;
180    }
181
182    base::win::ScopedBstr path;
183    hr = rule->get_ApplicationName(path.Receive());
184    if (FAILED(hr)) {
185      DLOG(ERROR) << logging::SystemErrorCodeToString(hr);
186      continue;
187    }
188
189    if (!path ||
190        !base::FilePath::CompareEqualIgnoreCase(static_cast<BSTR>(path),
191                                                app_path_.value())) {
192      continue;
193    }
194
195    rules->push_back(rule);
196  }
197}
198
199}  // namespace installer
200