flip_in_mem_edsm_server.cc revision c7f5f8508d98d5952d42ed7648c2a8f30a4da156
1// Copyright (c) 2009 The Chromium Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style license that can be
3// found in the LICENSE file.
4
#include <dirent.h>
#include <errno.h>
#include <fcntl.h>
#include <linux/tcp.h>  // For TCP_NODELAY
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <sys/socket.h>
#include <sys/stat.h>
#include <sys/types.h>
#include <unistd.h>
#include <openssl/ssl.h>

#include <deque>
#include <iostream>
#include <limits>
#include <list>
#include <map>
#include <string>
#include <vector>
17
18#include "base/logging.h"
19#include "base/simple_thread.h"
20#include "base/timer.h"
21#include "base/lock.h"
22#include "net/flip/flip_frame_builder.h"
23#include "net/flip/flip_framer.h"
24#include "net/flip/flip_protocol.h"
25#include "net/tools/flip_server/balsa_enums.h"
26#include "net/tools/flip_server/balsa_frame.h"
27#include "net/tools/flip_server/balsa_headers.h"
28#include "net/tools/flip_server/balsa_visitor_interface.h"
29#include "net/tools/flip_server/buffer_interface.h"
30#include "net/tools/flip_server/create_listener.h"
31#include "net/tools/flip_server/epoll_server.h"
32#include "net/tools/flip_server/loadtime_measurement.h"
33#include "net/tools/flip_server/other_defines.h"
34#include "net/tools/flip_server/ring_buffer.h"
35#include "net/tools/flip_server/simple_buffer.h"
36#include "net/tools/flip_server/split.h"
37#include "net/tools/flip_server/url_to_filename_encoder.h"
38#include "net/tools/flip_server/url_utilities.h"
39
40////////////////////////////////////////////////////////////////////////////////
41
42using base::StringPiece;
43using base::SimpleThread;
44// using base::Lock;  // heh, this isn't in base namespace?!
45// using base::AutoLock;  // ditto!
46using flip::CONTROL_FLAG_NONE;
47using flip::DATA_FLAG_COMPRESSED;
48using flip::DATA_FLAG_FIN;
49using flip::FIN_STREAM;
50using flip::FlipControlFrame;
51using flip::FlipDataFlags;
52using flip::FlipDataFrame;
53using flip::FlipFinStreamControlFrame;
54using flip::FlipFrame;
55using flip::FlipFrameBuilder;
56using flip::FlipFramer;
57using flip::FlipFramerVisitorInterface;
58using flip::FlipHeaderBlock;
59using flip::FlipStreamId;
60using flip::FlipSynReplyControlFrame;
61using flip::FlipSynStreamControlFrame;
62using flip::SYN_REPLY;
63using flip::SYN_STREAM;
64using net::BalsaFrame;
65using net::BalsaFrameEnums;
66using net::BalsaHeaders;
67using net::BalsaHeadersEnums;
68using net::BalsaVisitorInterface;
69using net::EpollAlarmCallbackInterface;
70using net::EpollCallbackInterface;
71using net::EpollEvent;
72using net::EpollServer;
73using net::RingBuffer;
74using net::SimpleBuffer;
75using net::SplitStringPieceToVector;
76using net::UrlUtilities;
77using std::deque;
78using std::map;
79using std::pair;
80using std::string;
81using std::vector;
82using std::list;
83using std::ostream;
84using std::cerr;
85
86////////////////////////////////////////////////////////////////////////////////
87
// Server configuration, held in plain global variables named like
// command-line flags. Nothing in this part of the file parses them.

// If true, the server acts as an SSL server for both HTTP and FLIP.
bool FLAGS_use_ssl = true;

// The name of the certificate .pem file.
string FLAGS_ssl_cert_name = "cert.pem";

// The name of the private key .pem file.
string FLAGS_ssl_key_name = "key.pem";

// The number of responses given before the server closes the connection.
int32 FLAGS_response_count_until_close = 1000*1000;

// If true, disables the Nagle algorithm.
bool FLAGS_no_nagle = true;

// The number of times that accept() will be called when the alarm goes off
// when the accept_using_alarm flag is set to true. If set to 0, accept() is
// performed until the accept queue is completely drained and the accept()
// call returns an error.
// NOTE(review): no accept_using_alarm flag is defined in this file.
int32 FLAGS_accepts_per_wake = 0;

// The port on which the FLIP server listens.
int32 FLAGS_flip_port = 10040;

// The port on which the HTTP server listens.
int32 FLAGS_port = 16002;

// The size of the TCP accept backlog.
int32 FLAGS_accept_backlog_size = 1024;

// The directory where the cache is located. MemoryCache::AddFiles() loads
// files from "<this dir>/GET_".
string FLAGS_cache_base_dir = ".";

// If true, encode URLs to filenames.
bool FLAGS_need_to_encode_url = true;

// If false, a single socket is used. If true, a new socket is created for
// each accept thread. This only works with kernels that support
// SO_REUSEPORT.
bool FLAGS_reuseport = false;

// The amount of time, in seconds, the server delays before sending back
// the reply.
double FLAGS_server_think_time_in_s = 0;

// Whether the server sends X-Subresource headers
// (see MemoryCache::UpdateHeaders).
bool FLAGS_use_xsub = false;

// Whether the server sends X-Associated-Content headers
// (see MemoryCache::UpdateHeaders).
bool FLAGS_use_xac = false;

// Whether the server advances cwnd by sending no-op packets.
bool FLAGS_use_cwnd_opener = false;

// Whether the server compresses data frames.
bool FLAGS_use_compression = false;

// The path to the file listing the URLs used for testing.
string FLAGS_urls_file = "experimental/users/fenix/flip/urls.txt";

// The path to the HTML page that performs the page load in an iframe.
string FLAGS_pageload_html_file =
  "experimental/users/fenix/flip/loadtime_measurement.html";

// If true, record each connection's requests in a file named after the fd
// and the accept time ("<FLAGS_record_path>/<fd>_<ms>").
bool FLAGS_record_mode = false;

// The directory in which record files are saved.
string FLAGS_record_path = ".";
160
161////////////////////////////////////////////////////////////////////////////////
162
163// Creates a socket with domain, type and protocol parameters.
164// Assigns the return value of socket() to *fd.
165// Returns errno if an error occurs, else returns zero.
166int CreateSocket(int domain, int type, int protocol, int *fd) {
167  CHECK(fd != NULL);
168  *fd = ::socket(domain, type, protocol);
169  return (*fd == -1) ? errno : 0;
170}
171
172////////////////////////////////////////////////////////////////////////////////
173
174// Sets an FD to be nonblocking.
175void SetNonBlocking(int fd) {
176  DCHECK(fd >= 0);
177
178  int fcntl_return = fcntl(fd, F_GETFL, 0);
179  CHECK_NE(fcntl_return, -1)
180    << "error doing fcntl(fd, F_GETFL, 0) fd: " << fd
181    << " errno=" << errno;
182
183  if (fcntl_return & O_NONBLOCK)
184    return;
185
186  fcntl_return = fcntl(fd, F_SETFL, fcntl_return | O_NONBLOCK);
187  CHECK_NE(fcntl_return, -1)
188    << "error doing fcntl(fd, F_SETFL, fcntl_return) fd: " << fd
189    << " errno=" << errno;
190}
191
192////////////////////////////////////////////////////////////////////////////////
193
// Process-wide load-time measurement helper, constructed during static
// initialization from FLAGS_urls_file and FLAGS_pageload_html_file. Both
// flags are defined earlier in this translation unit, so they are already
// initialized when this object is constructed.
// NOTE(review): whether the constructor copies the paths or reads the files
// right away depends on LoadtimeMeasurement (not visible here). Either way,
// the values are captured at static-init time, so later changes to the flag
// variables are presumably not seen — TODO confirm.
LoadtimeMeasurement global_loadtime_measurement(FLAGS_urls_file,
                                                FLAGS_pageload_html_file);
196
197////////////////////////////////////////////////////////////////////////////////
198
// Process-wide OpenSSL state, filled in by flip_init_ssl().
struct GlobalSSLState {
  // Method returned by TLSv1_server_method().
  SSL_METHOD* ssl_method;
  // Context loaded with the server certificate and private key. Each
  // accepted connection creates its SSL object from this context
  // (flip_new_ssl()).
  SSL_CTX* ssl_ctx;
};

////////////////////////////////////////////////////////////////////////////////

// SSL configuration shared by all connections. When NULL, connections are
// served in plaintext, because InitSMServerConnection() creates an SSL object
// only when this is non-NULL. Where it is allocated is not visible in this
// part of the file (presumably startup code guarded by FLAGS_use_ssl —
// TODO confirm).
GlobalSSLState* global_ssl_state = NULL;
207
208////////////////////////////////////////////////////////////////////////////////
209
210// SSL stuff
211void flip_init_ssl(GlobalSSLState* state) {
212  SSL_library_init();
213  SSL_load_error_strings();
214
215  state->ssl_method = TLSv1_server_method();
216  state->ssl_ctx = SSL_CTX_new(state->ssl_method);
217  if (!state->ssl_ctx) {
218    LOG(FATAL) << "Unable to create SSL context";
219  }
220  if (SSL_CTX_use_certificate_file(state->ssl_ctx,
221                                   FLAGS_ssl_cert_name.c_str(),
222                                   SSL_FILETYPE_PEM) <= 0) {
223    LOG(FATAL) << "Unable to use cert.pem as SSL cert.";
224  }
225  if (SSL_CTX_use_PrivateKey_file(state->ssl_ctx,
226                                  FLAGS_ssl_key_name.c_str(),
227                                  SSL_FILETYPE_PEM) <= 0) {
228    LOG(FATAL) << "Unable to use key.pem as SSL key.";
229  }
230  if (!SSL_CTX_check_private_key(state->ssl_ctx)) {
231    LOG(FATAL) << "The cert.pem and key.pem files don't match";
232  }
233}
234
235SSL* flip_new_ssl(SSL_CTX* ssl_ctx) {
236  SSL* ssl = SSL_new(ssl_ctx);
237  SSL_set_accept_state(ssl);
238  return ssl;
239}
240
241////////////////////////////////////////////////////////////////////////////////
242
// Payload budget for data frames. 2 * 1460 presumably targets two full TCP
// segments (1460 bytes is the usual MSS on a 1500-byte-MTU Ethernet link —
// an assumption, not established here). FlipFrame::size() (presumably the
// frame header length) is subtracted so header plus payload fit that budget.
// kInitialDataSendersThreshold seeds MemCacheIter::max_segment_size.
// NOTE(review): both constants currently have the same value. Since they call
// FlipFrame::size(), they are dynamically initialized at static-init time.
const int kInitialDataSendersThreshold =  (2 * 1460) - FlipFrame::size();
const int kNormalSegmentSize = (2 * 1460) - FlipFrame::size();
245
246////////////////////////////////////////////////////////////////////////////////
247
// One chunk of outgoing bytes queued on a connection's OutputList.
// |data| and |size| describe the bytes. |index| is presumably the offset of
// the next byte to write (TODO confirm against DoWrite). When
// |delete_when_done| is true, the frame owns |data|, which must have been
// allocated with new[].
class DataFrame {
 public:
  const char* data;
  size_t size;
  bool delete_when_done;
  size_t index;
  DataFrame() : data(NULL), size(0), delete_when_done(false), index(0) {}

  // Frees |data| if this frame owns it. Afterwards the frame no longer points
  // at the freed buffer (data is NULL, delete_when_done is false), so calling
  // MaybeDelete() again is harmless instead of a double delete[]. DataFrames
  // are copied by value (OutputList is a list<DataFrame>), so other copies
  // still hold the old pointer and must not free it too.
  void MaybeDelete() {
    if (delete_when_done) {
      delete[] data;
      data = NULL;
      delete_when_done = false;
    }
  }
};
261
262////////////////////////////////////////////////////////////////////////////////
263
264class StoreBodyAndHeadersVisitor: public BalsaVisitorInterface {
265 public:
266  BalsaHeaders headers;
267  string body;
268  bool error_;
269
270  virtual void ProcessBodyInput(const char *input, size_t size) {}
271  virtual void ProcessBodyData(const char *input, size_t size) {
272    body.append(input, size);
273  }
274  virtual void ProcessHeaderInput(const char *input, size_t size) {}
275  virtual void ProcessTrailerInput(const char *input, size_t size) {}
276  virtual void ProcessHeaders(const BalsaHeaders& headers) {
277    // nothing to do here-- we're assuming that the BalsaFrame has
278    // been handed our headers.
279  }
280  virtual void ProcessRequestFirstLine(const char* line_input,
281                                       size_t line_length,
282                                       const char* method_input,
283                                       size_t method_length,
284                                       const char* request_uri_input,
285                                       size_t request_uri_length,
286                                       const char* version_input,
287                                       size_t version_length) {}
288  virtual void ProcessResponseFirstLine(const char *line_input,
289                                        size_t line_length,
290                                        const char *version_input,
291                                        size_t version_length,
292                                        const char *status_input,
293                                        size_t status_length,
294                                        const char *reason_input,
295                                        size_t reason_length) {}
296  virtual void ProcessChunkLength(size_t chunk_length) {}
297  virtual void ProcessChunkExtensions(const char *input, size_t size) {}
298  virtual void HeaderDone() {}
299  virtual void MessageDone() {}
300  virtual void HandleHeaderError(BalsaFrame* framer) { HandleError(); }
301  virtual void HandleHeaderWarning(BalsaFrame* framer) { HandleError(); }
302  virtual void HandleChunkingError(BalsaFrame* framer) { HandleError(); }
303  virtual void HandleBodyError(BalsaFrame* framer) { HandleError(); }
304
305  void HandleError() { error_ = true; }
306};
307
308////////////////////////////////////////////////////////////////////////////////
309
310struct FileData {
311  void CopyFrom(const FileData& file_data) {
312    headers = new BalsaHeaders;
313    headers->CopyFrom(*(file_data.headers));
314    filename = file_data.filename;
315    related_files = file_data.related_files;
316    body = file_data.body;
317  }
318  FileData(BalsaHeaders* h, const string& b) : headers(h), body(b) {}
319  FileData() {}
320  BalsaHeaders* headers;
321  string filename;
322  vector< pair<int, string> > related_files;   // priority, filename
323  string body;
324};
325
326////////////////////////////////////////////////////////////////////////////////
327
328class MemCacheIter {
329 public:
330  MemCacheIter() :
331      file_data(NULL),
332      priority(0),
333      transformed_header(false),
334      body_bytes_consumed(0),
335      stream_id(0),
336      max_segment_size(kInitialDataSendersThreshold),
337      bytes_sent(0) {}
338  explicit MemCacheIter(FileData* fd) :
339      file_data(fd),
340      priority(0),
341      transformed_header(false),
342      body_bytes_consumed(0),
343      stream_id(0),
344      max_segment_size(kInitialDataSendersThreshold),
345      bytes_sent(0) {}
346  FileData* file_data;
347  int priority;
348  bool transformed_header;
349  size_t body_bytes_consumed;
350  uint32 stream_id;
351  uint32 max_segment_size;
352  size_t bytes_sent;
353};
354
355////////////////////////////////////////////////////////////////////////////////
356
357class MemoryCache {
358 public:
359  typedef map<string, FileData> Files;
360
361 public:
362  Files files_;
363  string cwd_;
364
365  void CloneFrom(const MemoryCache& mc) {
366    for (Files::const_iterator i = mc.files_.begin();
367         i != mc.files_.end();
368         ++i) {
369      Files::iterator out_i =
370        files_.insert(make_pair(i->first, FileData())).first;
371      out_i->second.CopyFrom(i->second);
372      cwd_ = mc.cwd_;
373    }
374  }
375
376  void AddFiles() {
377    LOG(INFO) << "Adding files!";
378    deque<string> paths;
379    cwd_ = FLAGS_cache_base_dir;
380    paths.push_back(cwd_ + "/GET_");
381    DIR* current_dir = NULL;
382    while (!paths.empty()) {
383      while (current_dir == NULL && !paths.empty()) {
384        string current_dir_name = paths.front();
385        VLOG(1) << "Attempting to open dir: \"" << current_dir_name << "\"";
386        current_dir = opendir(current_dir_name.c_str());
387        paths.pop_front();
388
389        if (current_dir == NULL) {
390          perror("Unable to open directory. ");
391          current_dir_name.clear();
392          continue;
393        }
394
395        if (current_dir) {
396          VLOG(1) << "Succeeded opening";
397          for (struct dirent* dir_data = readdir(current_dir);
398               dir_data != NULL;
399               dir_data = readdir(current_dir)) {
400            string current_entry_name =
401              current_dir_name + "/" + dir_data->d_name;
402            if (dir_data->d_type == DT_REG) {
403              VLOG(1) << "Found file: " << current_entry_name;
404              ReadAndStoreFileContents(current_entry_name.c_str());
405            } else if (dir_data->d_type == DT_DIR) {
406              VLOG(1) << "Found subdir: " << current_entry_name;
407              if (string(dir_data->d_name) != "." &&
408                  string(dir_data->d_name) != "..") {
409                VLOG(1) << "Adding to search path: " << current_entry_name;
410                paths.push_front(current_entry_name);
411              }
412            }
413          }
414          VLOG(1) << "Oops, no data left. Closing dir.";
415          closedir(current_dir);
416          current_dir = NULL;
417        }
418      }
419    }
420  }
421
422  void ReadToString(const char* filename, string* output) {
423    output->clear();
424    int fd = open(filename, 0, "r");
425    if (fd == -1)
426      return;
427    char buffer[4096];
428    ssize_t read_status = read(fd, buffer, sizeof(buffer));
429    while (read_status > 0) {
430      output->append(buffer, static_cast<size_t>(read_status));
431      do {
432        read_status = read(fd, buffer, sizeof(buffer));
433      } while (read_status <= 0 && errno == EINTR);
434    }
435    close(fd);
436  }
437
438  void ReadAndStoreFileContents(const char* filename) {
439    StoreBodyAndHeadersVisitor visitor;
440    BalsaFrame framer;
441    framer.set_balsa_visitor(&visitor);
442    framer.set_balsa_headers(&(visitor.headers));
443    string filename_contents;
444    ReadToString(filename, &filename_contents);
445
446    // Ugly hack to make everything look like 1.1.
447    if (filename_contents.find("HTTP/1.0") == 0)
448      filename_contents[7] = '1';
449
450    size_t pos = 0;
451    size_t old_pos = 0;
452    while (true) {
453      old_pos = pos;
454      pos += framer.ProcessInput(filename_contents.data() + pos,
455                                 filename_contents.size() - pos);
456      if (framer.Error() || pos == old_pos) {
457        LOG(ERROR) << "Unable to make forward progress, or error"
458          " framing file: " << filename;
459        if (framer.Error()) {
460          LOG(INFO) << "********************************************ERROR!";
461          return;
462        }
463        return;
464      }
465      if (framer.MessageFullyRead()) {
466        // If no Content-Length or Transfer-Encoding was captured in the
467        // file, then the rest of the data is the body.  Many of the captures
468        // from within Chrome don't have content-lengths.
469        if (!visitor.body.length())
470          visitor.body = filename_contents.substr(pos);
471        break;
472      }
473    }
474    visitor.headers.RemoveAllOfHeader("content-length");
475    visitor.headers.RemoveAllOfHeader("transfer-encoding");
476    visitor.headers.RemoveAllOfHeader("connection");
477    visitor.headers.AppendHeader("transfer-encoding", "chunked");
478    visitor.headers.AppendHeader("connection", "keep-alive");
479
480    // Experiment with changing headers for forcing use of cached
481    // versions of content.
482    // TODO(mbelshe) REMOVE ME
483#if 0
484    // TODO(mbelshe): append current date.
485    visitor.headers.RemoveAllOfHeader("date");
486    if (visitor.headers.HasHeader("expires")) {
487      visitor.headers.RemoveAllOfHeader("expires");
488      visitor.headers.AppendHeader("expires",
489                                 "Fri, 30 Aug, 2019 12:00:00 GMT");
490    }
491#endif
492    BalsaHeaders* headers = new BalsaHeaders;
493    headers->CopyFrom(visitor.headers);
494    string filename_stripped =
495      string(filename).substr(cwd_.size() + 1);
496//    LOG(INFO) << "Adding file (" << visitor.body.length() << " bytes): "
497//              << filename_stripped;
498    files_[filename_stripped] = FileData();
499    FileData& fd = files_[filename_stripped];
500    fd = FileData(headers, visitor.body);
501    fd.filename = string(filename_stripped,
502                         filename_stripped.find_first_of('/'));
503    if (headers->HasHeader("X-Associated-Content")) {
504      string content =
505        headers->GetHeader("X-Associated-Content").as_string();
506      vector<StringPiece> urls_and_priorities;
507      SplitStringPieceToVector(content, "||", &urls_and_priorities, true);
508      VLOG(1) << "Examining X-Associated-Content header";
509      for (unsigned int i = 0; i < urls_and_priorities.size(); ++i) {
510        const StringPiece& url_and_priority_pair = urls_and_priorities[i];
511        vector<StringPiece> url_and_priority;
512        SplitStringPieceToVector(url_and_priority_pair, "??",
513                                 &url_and_priority, true);
514        if (url_and_priority.size() >= 2) {
515          string priority_string(url_and_priority[0].data(),
516                                 url_and_priority[0].size());
517          string filename_string(url_and_priority[1].data(),
518                                 url_and_priority[1].size());
519          int priority;
520          char* last_eaten_char;
521          priority = strtol(priority_string.c_str(), &last_eaten_char, 0);
522          if (last_eaten_char ==
523              priority_string.c_str() + priority_string.size()) {
524            pair<int, string> entry(priority, filename_string);
525            VLOG(1) << "Adding associated content: " << filename_string;
526            fd.related_files.push_back(entry);
527          }
528        }
529      }
530    }
531  }
532
533  // Called at runtime to update learned headers
534  // |url| is a url which contains a referrer header.
535  // |referrer| is the referring URL
536  // Adds an X-Subresource or X-Associated-Content to |referer| for |url|
537  void UpdateHeaders(string referrer, string file_url) {
538    if (!FLAGS_use_xac && !FLAGS_use_xsub)
539      return;
540
541    string referrer_host_path =
542      net::UrlToFilenameEncoder::Encode(referrer, "GET_/");
543
544    FileData* fd1 = GetFileData(string("GET_") + file_url);
545    if (!fd1) {
546      LOG(ERROR) << "Updating headers for unknown url: " << file_url;
547      return;
548    }
549    string url = fd1->headers->GetHeader("X-Original-Url").as_string();
550    string content_type = fd1->headers->GetHeader("Content-Type").as_string();
551    if (content_type.length() == 0) {
552      LOG(ERROR) << "Skipping subresource with unknown content-type";
553      return;
554    }
555
556    // Now, lets see if this is the same host or not
557    bool same_host = (UrlUtilities::GetUrlHost(referrer) ==
558                      UrlUtilities::GetUrlHost(url));
559
560    // This is a hacked algorithm for figuring out what priority
561    // to use with pushed content.
562    int priority = 4;
563    if (content_type.find("css") != string::npos)
564      priority = 1;
565    else if (content_type.find("cript") != string::npos)
566      priority = 1;
567    else if (content_type.find("html") != string::npos)
568      priority = 2;
569
570    LOG(ERROR) << "Attempting update for " << referrer_host_path;
571
572    FileData* fd2 = GetFileData(referrer_host_path);
573    if (fd2 != NULL) {
574      // If they are on the same host, we'll use X-Associated-Content
575      string header_name;
576      string new_value;
577      string delimiter;
578      bool related_files = false;
579      if (same_host && FLAGS_use_xac) {
580        header_name = "X-Associated-Content";
581        char pri_ch = priority + '0';
582        new_value = pri_ch + string("??") + url;
583        delimiter = "||";
584        related_files = true;
585      } else {
586        if (!FLAGS_use_xsub)
587          return;
588        header_name = "X-Subresource";
589        new_value = content_type + "!!" + url;
590        delimiter = "!!";
591      }
592
593      if (fd2->headers->HasNonEmptyHeader(header_name)) {
594        string existing_header =
595            fd2->headers->GetHeader(header_name).as_string();
596        if (existing_header.find(url) != string::npos)
597          return;  // header already recorded
598
599        // Don't let these lists grow too long for low pri stuff.
600        // TODO(mbelshe) We need better algorithms for this.
601        if (existing_header.length() > 256 && priority > 2)
602          return;
603
604        new_value = existing_header + delimiter + new_value;
605      }
606
607      LOG(INFO) << "Recording " << header_name << " for " << new_value;
608      fd2->headers->ReplaceOrAppendHeader(header_name, new_value);
609
610      // Add it to the related files so that it will actually get sent out.
611      if (related_files) {
612        pair<int, string> entry(4, file_url);
613        fd2->related_files.push_back(entry);
614      }
615    } else {
616      LOG(ERROR) << "Failed to update headers:";
617      LOG(ERROR) << "FAIL url: " << url;
618      LOG(ERROR) << "FAIL ref: " << referrer_host_path;
619    }
620  }
621
622  FileData* GetFileData(const string& filename) {
623    Files::iterator fi = files_.end();
624    if (filename.compare(filename.length() - 5, 5, ".html", 5) == 0) {
625      string new_filename(filename.data(), filename.size() - 5);
626      new_filename += ".http";
627      fi = files_.find(new_filename);
628    }
629    if (fi == files_.end())
630      fi = files_.find(filename);
631
632    if (fi == files_.end()) {
633      return NULL;
634    }
635    return &(fi->second);
636  }
637
638  bool AssignFileData(const string& filename, MemCacheIter* mci) {
639    mci->file_data = GetFileData(filename);
640    if (mci->file_data == NULL) {
641      LOG(ERROR) << "Could not find file data for " << filename;
642      return false;
643    }
644    return true;
645  }
646};
647
648////////////////////////////////////////////////////////////////////////////////
649
// Single-method callback interface. SMServerConnection implements it; its
// Notify() implementation and the callers are not visible in this part of
// the file.
class NotifierInterface {
 public:
  virtual ~NotifierInterface() {}
  virtual void Notify() = 0;
};
655
656////////////////////////////////////////////////////////////////////////////////
657
// Protocol state machine driven by an SMServerConnection: it parses incoming
// bytes and produces outgoing frames. Each connection creates its instance
// through an SMInterfaceFactory.
class SMInterface {
 public:
  // Consumes up to |len| bytes of input. Returns the number of bytes
  // actually consumed. DoConsumeReadData treats 0 as "no progress, wait for
  // more data".
  virtual size_t ProcessInput(const char* data, size_t len) = 0;
  // True once a complete request has been parsed. The connection then
  // handles it and calls Reset().
  virtual bool MessageFullyRead() const = 0;
  // True if parsing failed. ErrorAsString() describes the failure for logs.
  virtual bool Error() const = 0;
  virtual const char* ErrorAsString() const = 0;
  // Prepares for the next message. Called after each complete request and
  // when a connection is (re)initialized.
  virtual void Reset() = 0;
  // Not called in the visible part of this file; presumably a heavier reset
  // used between connections (TODO confirm).
  virtual void ResetForNewConnection() = 0;

  // Called at the end of SMServerConnection::InitSMServerConnection, after
  // the fd is registered and any SSL object is attached.
  virtual void PostAcceptHook() = 0;

  // Output-side operations. Their callers and implementations are not
  // visible here; the descriptions follow the names and should be confirmed.
  virtual void NewStream(uint32 stream_id, uint32 priority,
                         const string& filename) = 0;
  virtual void SendEOF(uint32 stream_id) = 0;
  virtual void SendErrorNotFound(uint32 stream_id) = 0;
  virtual size_t SendSynStream(uint32 stream_id,
                              const BalsaHeaders& headers) = 0;
  virtual size_t SendSynReply(uint32 stream_id,
                              const BalsaHeaders& headers) = 0;
  virtual void SendDataFrame(uint32 stream_id, const char* data, int64 len,
                             uint32 flags, bool compress) = 0;
  virtual void GetOutput() = 0;

  virtual ~SMInterface() {}
};
683
684////////////////////////////////////////////////////////////////////////////////
685
686class SMServerConnection;
687typedef SMInterface*(SMInterfaceFactory)(SMServerConnection* conn);
688
689////////////////////////////////////////////////////////////////////////////////
690
691typedef list<DataFrame> OutputList;
692
693////////////////////////////////////////////////////////////////////////////////
694
695class SMServerConnection;
696
697class SMServerConnectionPoolInterface {
698 public:
699  virtual ~SMServerConnectionPoolInterface() {}
700  // SMServerConnections will use this:
701  virtual void SMServerConnectionDone(SMServerConnection* conn) = 0;
702};
703
704////////////////////////////////////////////////////////////////////////////////
705
706class SMServerConnection: public EpollCallbackInterface,
707                          public NotifierInterface {
708 private:
  // Private; create instances through NewSMServerConnection(). The object
  // starts unbound (fd_ == -1, not initialized); InitSMServerConnection()
  // attaches it to a socket.
  // Members are initialized in declaration order. sm_interface_factory is
  // therefore called with a partially constructed |this|: the members
  // declared after sm_interface_ (max_bytes_sent_per_dowrite_, ssl_) are
  // not yet set when the factory runs.
  SMServerConnection(SMInterfaceFactory* sm_interface_factory,
                     MemoryCache* memory_cache,
                     EpollServer* epoll_server) :
      fd_(-1),
      record_fd_(-1),
      events_(0),

      registered_in_epoll_server_(false),
      initialized_(false),

      connection_pool_(NULL),
      epoll_server_(epoll_server),

      // Presumably the ring buffer's capacity in bytes (TODO confirm).
      read_buffer_(4096*10),
      memory_cache_(memory_cache),
      sm_interface_(sm_interface_factory(this)),

      max_bytes_sent_per_dowrite_(128),

      ssl_(NULL) {}
729
  // Connected socket, or -1 when unbound.
  int fd_;
  // Record-mode log file, or -1 when not recording.
  int record_fd_;
  // Epoll events accumulated in OnEvent() and processed by HandleEvents().
  int events_;

  // Tracked through OnRegistration() / OnUnregistration().
  bool registered_in_epoll_server_;
  // Set by InitSMServerConnection().
  bool initialized_;

  // Owner pool; not owned. Set by InitSMServerConnection().
  SMServerConnectionPoolInterface* connection_pool_;
  // Epoll server this connection is registered with; not owned.
  EpollServer* epoll_server_;

  // Bytes received but not yet consumed by sm_interface_.
  RingBuffer read_buffer_;

  // Frames queued for writing (EnqueueDataFrame()).
  OutputList output_list_;
  // Shared response cache; not owned.
  MemoryCache* memory_cache_;
  // Protocol state machine created by the factory. It is not deleted in
  // the visible part of this class — NOTE(review): confirm ownership.
  SMInterface* sm_interface_;

  // Presumably caps the bytes written per DoWrite() pass (DoWrite is not
  // visible here — TODO confirm).
  size_t max_bytes_sent_per_dowrite_;

  // Per-connection SSL object when SSL is enabled; NULL means plaintext
  // (DoRead() uses recv() then).
  SSL* ssl_;
749 public:
750  EpollServer* epoll_server() { return epoll_server_; }
751  OutputList* output_list() { return &output_list_; }
752  MemoryCache* memory_cache() { return memory_cache_; }
753  int record_fd() { return record_fd_; }
754  void close_record_fd() {
755    if (record_fd_ != -1) {
756      close(record_fd_);
757      record_fd_ = -1;
758    }
759  }
760  void ReadyToSend() {
761    epoll_server_->SetFDReady(fd_, EPOLLIN | EPOLLOUT);
762  }
763  void EnqueueDataFrame(const DataFrame& df) {
764    output_list_.push_back(df);
765    VLOG(2) << "EnqueueDataFrame. Setting FD ready.";
766    ReadyToSend();
767  }
768
769 public:
770  ~SMServerConnection() {
771    if (initialized()) {
772      Reset();
773    }
774  }
775  static SMServerConnection* NewSMServerConnection(SMInterfaceFactory* smif,
776                                                   MemoryCache* memory_cache,
777                                                   EpollServer* epoll_server) {
778    return new SMServerConnection(smif, memory_cache, epoll_server);
779  }
780
781  bool initialized() const { return initialized_; }
782
783  void InitSMServerConnection(SMServerConnectionPoolInterface* connection_pool,
784                              EpollServer* epoll_server,
785                              int fd) {
786    if (initialized_) {
787      LOG(FATAL) << "Attempted to initialize already initialized server";
788      return;
789    }
790    if (epoll_server_ && registered_in_epoll_server_ && fd_ != -1) {
791      epoll_server_->UnregisterFD(fd_);
792    }
793    if (fd_ != -1) {
794      VLOG(2) << "Closing pre-existing fd";
795      close(fd_);
796      fd_ = -1;
797    }
798    if (FLAGS_record_mode) {
799      char record_file_name[1024];
800      snprintf(record_file_name, sizeof(record_file_name), "%s/%d_%ld",
801              FLAGS_record_path.c_str(), fd, epoll_server->NowInUsec()/1000);
802      record_fd_ = open(record_file_name, O_CREAT|O_APPEND|O_WRONLY, S_IRWXU);
803      if (record_fd_ < 0) {
804        LOG(ERROR) << "Open record file for fd " << fd << " failed";
805        record_fd_ = -1;
806      }
807    }
808
809    fd_ = fd;
810
811    registered_in_epoll_server_ = false;
812    initialized_ = true;
813
814    connection_pool_ = connection_pool;
815    epoll_server_ = epoll_server;
816
817    sm_interface_->Reset();
818    read_buffer_.Clear();
819
820    epoll_server_->RegisterFD(fd_, this, EPOLLIN | EPOLLOUT | EPOLLET);
821
822    if (global_ssl_state) {
823      ssl_ = flip_new_ssl(global_ssl_state->ssl_ctx);
824      SSL_set_fd(ssl_, fd_);
825    }
826    sm_interface_->PostAcceptHook();
827  }
828
829  int Send(const char* bytes, int len, int flags) {
830    return send(fd_, bytes, len, flags);
831  }
832
833  // the following are from the EpollCallbackInterface
834  virtual void OnRegistration(EpollServer* eps, int fd, int event_mask) {
835    registered_in_epoll_server_ = true;
836  }
837  virtual void OnModification(int fd, int event_mask) { }
838  virtual void OnEvent(int fd, EpollEvent* event) {
839    events_ |= event->in_events;
840    HandleEvents();
841    if (events_) {
842      event->out_ready_mask = events_;
843      events_ = 0;
844    }
845  }
846  virtual void OnUnregistration(int fd, bool replaced) {
847    registered_in_epoll_server_ = false;
848  }
849  virtual void OnShutdown(EpollServer* eps, int fd) {
850    Cleanup("OnShutdown");
851    return;
852  }
853
854 private:
855  void HandleEvents() {
856    VLOG(1) << "Received: " << EpollServer::EventMaskToString(events_);
857    if (events_ & EPOLLIN) {
858      if (!DoRead())
859        goto handle_close_or_error;
860    }
861
862    if (events_ & EPOLLOUT) {
863      if (!DoWrite())
864        goto handle_close_or_error;
865    }
866
867    if (events_ & (EPOLLHUP | EPOLLERR)) {
868      VLOG(2) << "!!!! Got HUP or ERR";
869      goto handle_close_or_error;
870    }
871    return;
872
873 handle_close_or_error:
874    Cleanup("HandleEvents");
875  }
876
877  bool DoRead() {
878    VLOG(2) << "DoRead()";
879    if (fd_ == -1) {
880      VLOG(2) << "DoRead(): fd_ == -1. Invalid FD. Returning false";
881      return false;
882    }
883    while (!read_buffer_.Full()) {
884      char* bytes;
885      int size;
886      read_buffer_.GetWritablePtr(&bytes, &size);
887      ssize_t bytes_read = 0;
888      if (ssl_) {
889        bytes_read = SSL_read(ssl_, bytes, size);
890      } else {
891        bytes_read = recv(fd_, bytes, size, MSG_DONTWAIT);
892      }
893      int stored_errno = errno;
894      if (bytes_read == -1) {
895        switch (stored_errno) {
896          case EAGAIN:
897            events_ &= ~EPOLLIN;
898            VLOG(2) << "Got EAGAIN while reading";
899            goto done;
900          case EINTR:
901            VLOG(2) << "Got EINTR while reading";
902            continue;
903          default:
904            VLOG(2) << "While calling recv, got error: " << stored_errno
905              << " " << strerror(stored_errno);
906            goto error_or_close;
907        }
908      } else if (bytes_read > 0) {
909        VLOG(2) << "Read: " << bytes_read << " bytes from fd: " << fd_;
910        read_buffer_.AdvanceWritablePtr(bytes_read);
911        if (!DoConsumeReadData()) {
912          goto error_or_close;
913        }
914        continue;
915      } else {  // bytes_read == 0
916        VLOG(2) << "0 bytes read with recv call.";
917      }
918      goto error_or_close;
919    }
920   done:
921    return true;
922
923   error_or_close:
924    VLOG(2) << "DoRead(): error_or_close. Cleaning up, then returning false";
925    Cleanup("DoRead");
926    return false;
927  }
928
929  bool DoConsumeReadData() {
930    char* bytes;
931    int size;
932    read_buffer_.GetReadablePtr(&bytes, &size);
933    while (size != 0) {
934      size_t bytes_consumed = sm_interface_->ProcessInput(bytes, size);
935      VLOG(2) << "consumed: " << bytes_consumed << " from socket fd: " << fd_;
936      if (bytes_consumed == 0) {
937        break;
938      }
939      read_buffer_.AdvanceReadablePtr(bytes_consumed);
940      if (sm_interface_->MessageFullyRead()) {
941        VLOG(2) << "HandleRequestFullyRead";
942        HandleRequestFullyRead();
943        sm_interface_->Reset();
944        events_ |= EPOLLOUT;
945      } else if (sm_interface_->Error()) {
946        LOG(ERROR) << "Framer error detected: "
947                   << sm_interface_->ErrorAsString();
948        // this causes everything to be closed/cleaned up.
949        events_ |= EPOLLOUT;
950        return false;
951      }
952      read_buffer_.GetReadablePtr(&bytes, &size);
953    }
954    return true;
955  }
956
957  void WriteResponse() {
958    // this happens asynchronously from separate threads
959    // feeding files into the output buffer.
960  }
961
962  void HandleRequestFullyRead() {
963  }
964
965  void Notify() {
966  }
967
968  bool DoWrite() {
969    size_t bytes_sent = 0;
970    int flags = MSG_NOSIGNAL | MSG_DONTWAIT;
971    if (fd_ == -1) {
972      VLOG(2) << "DoWrite: fd == -1. Returning false.";
973      return false;
974    }
975    if (output_list_.empty()) {
976      sm_interface_->GetOutput();
977      if (output_list_.empty())
978        events_ &= ~EPOLLOUT;
979    }
980    while (!output_list_.empty()) {
981      if (bytes_sent >= max_bytes_sent_per_dowrite_) {
982        events_ |= EPOLLOUT;
983        break;
984      }
985      if (output_list_.size() < 2) {
986        sm_interface_->GetOutput();
987      }
988      DataFrame& data_frame = output_list_.front();
989      const char*  bytes = data_frame.data;
990      int size = data_frame.size;
991      bytes += data_frame.index;
992      size -= data_frame.index;
993      DCHECK_GE(size, 0);
994      if (size <= 0) {
995        data_frame.MaybeDelete();
996        output_list_.pop_front();
997        continue;
998      }
999
1000      flags = MSG_NOSIGNAL | MSG_DONTWAIT;
1001      if (output_list_.size() > 1) {
1002        flags |= MSG_MORE;
1003      }
1004      ssize_t bytes_written = 0;
1005      if (ssl_) {
1006        bytes_written = SSL_write(ssl_, bytes, size);
1007      } else {
1008        bytes_written = send(fd_, bytes, size, flags);
1009      }
1010      int stored_errno = errno;
1011      if (bytes_written == -1) {
1012        switch (stored_errno) {
1013          case EAGAIN:
1014            events_ &= ~EPOLLOUT;
1015            VLOG(2) << " Got EAGAIN while writing";
1016            goto done;
1017          case EINTR:
1018            VLOG(2) << " Got EINTR while writing";
1019            continue;
1020          default:
1021            VLOG(2) << "While calling send, got error: " << stored_errno
1022              << " " << strerror(stored_errno);
1023            goto error_or_close;
1024        }
1025      } else if (bytes_written > 0) {
1026        VLOG(1) << "Wrote: " << bytes_written  << " bytes to socket fd: "
1027          << fd_;
1028        data_frame.index += bytes_written;
1029        bytes_sent += bytes_written;
1030        continue;
1031      }
1032      VLOG(2) << "0 bytes written to socket " << fd_ << " with send call.";
1033      goto error_or_close;
1034    }
1035   done:
1036    return true;
1037
1038   error_or_close:
1039    VLOG(2) << "DoWrite: error_or_close. Returning false after cleaning up";
1040    Cleanup("DoWrite");
1041    return false;
1042  }
1043
1044  friend ostream& operator<<(ostream& os, const SMServerConnection& c) {
1045    os << &c << "\n";
1046    return os;
1047  }
1048
1049  void Reset() {
1050    VLOG(2) << "Resetting";
1051    if (ssl_) {
1052      SSL_shutdown(ssl_);
1053      SSL_free(ssl_);
1054    }
1055    if (registered_in_epoll_server_) {
1056      epoll_server_->UnregisterFD(fd_);
1057      registered_in_epoll_server_ = false;
1058    }
1059    if (fd_ >= 0) {
1060      VLOG(2) << "Closing connection";
1061      close(fd_);
1062      fd_ = -1;
1063    }
1064    sm_interface_->ResetForNewConnection();
1065    read_buffer_.Clear();
1066    initialized_ = false;
1067    events_ = 0;
1068    output_list_.clear();
1069  }
1070
1071  void Cleanup(const char* cleanup) {
1072    VLOG(2) << "Cleaning up: " << cleanup;
1073    if (!initialized_) {
1074      return;
1075    }
1076    Reset();
1077    connection_pool_->SMServerConnectionDone(this);
1078  }
1079};
1080
1081////////////////////////////////////////////////////////////////////////////////
1082
// Decides which stream's data a connection writes next.
//
// Lifecycle of a stream:
//  1. AddToOutputOrder() registers it and arms an epoll alarm of
//     FLAGS_server_think_time_in_s (see BeginOutputtingAlarm).
//  2. When the alarm fires, MoveToActive() appends the stream to
//     |first_data_senders_|. GetIter() serves that ring round-robin, ahead
//     of everything else, in segments of kInitialDataSendersThreshold bytes.
//  3. Once the stream's bytes_sent reaches |first_data_senders_threshold_|,
//     it moves to the ring for its priority in |priority_map_|.
//
// Among the priority rings, only the one with the smallest priority key is
// served (round-robin, kNormalSegmentSize segments) until it is empty.
class OutputOrdering {
 public:
  // Streams served round-robin: GetIter() rotates the front entry to the back.
  typedef list<MemCacheIter> PriorityRing;

  // Priority value -> ring of streams with that priority.
  typedef map<uint32, PriorityRing> PriorityMap;

  // Where one stream currently lives.
  // - While its start-up alarm is pending, |alarm_enabled| is true and
  //   |alarm_token| identifies the alarm.
  // - After the alarm has fired, |ring| and |it| locate the stream's
  //   MemCacheIter inside |first_data_senders_| or a PriorityMap ring.
  struct PriorityMapPointer {
    PriorityMapPointer(): ring(NULL), alarm_enabled(false) {}
    PriorityRing* ring;
    PriorityRing::iterator it;
    bool alarm_enabled;
    EpollServer::AlarmRegToken alarm_token;
  };
  // Stream id -> location of that stream.
  typedef map<uint32, PriorityMapPointer> StreamIdToPriorityMap;

  StreamIdToPriorityMap stream_ids_;  // every registered stream
  PriorityMap priority_map_;          // streams past the initial phase
  PriorityRing first_data_senders_;   // streams in their initial phase
  uint32 first_data_senders_threshold_;  // when you've passed this, you're no
                                         // longer a first_data_sender...
  SMServerConnection* connection_;
  EpollServer* epoll_server_;

  explicit OutputOrdering(SMServerConnection* connection) :
      first_data_senders_threshold_(kInitialDataSendersThreshold),
      connection_(connection),
      epoll_server_(connection->epoll_server()) {
  }

  // Forgets every stream, cancelling start-up alarms that are still pending.
  // NOTE(review): UnregisterAlarm() does not appear to free the
  // BeginOutputtingAlarm object (it only deletes itself in OnAlarm()), so a
  // cancelled alarm may leak here. Its |pmp_| would also dangle after the
  // erase below. Confirm against EpollServer.
  void Reset() {
    while (!stream_ids_.empty()) {
      StreamIdToPriorityMap::iterator sitpmi = stream_ids_.begin();
      PriorityMapPointer& pmp = sitpmi->second;
      if (pmp.alarm_enabled) {
        epoll_server_->UnregisterAlarm(pmp.alarm_token);
      }
      stream_ids_.erase(sitpmi);
    }
    priority_map_.clear();
    first_data_senders_.clear();
  }

  // True if |stream_id| is registered, whether its alarm is pending or it is
  // already active.
  bool ExistsInPriorityMaps(uint32 stream_id) {
    StreamIdToPriorityMap::iterator sitpmi = stream_ids_.find(stream_id);
    return sitpmi != stream_ids_.end();
  }

  // One-shot epoll alarm. When it fires, it moves a newly added stream into
  // |first_data_senders_| via MoveToActive(), then deletes itself.
  // |pmp_| points at an element of |stream_ids_|. Map elements stay in place
  // until erased, so the pointer is valid while the stream is registered.
  struct BeginOutputtingAlarm : public EpollAlarmCallbackInterface {
   public:
    BeginOutputtingAlarm(OutputOrdering* oo,
                         OutputOrdering::PriorityMapPointer* pmp,
                         const MemCacheIter& mci) :
        output_ordering_(oo), pmp_(pmp), mci_(mci), epoll_server_(NULL) {}

    // Clears alarm_enabled first: the alarm has fired, so neither
    // RemoveStreamId() nor the destructor should try to unregister it.
    // NOTE(review): returning 0 presumably means "do not re-arm"; confirm
    // against EpollServer.
    int64 OnAlarm() {
      OnUnregistration();
      output_ordering_->MoveToActive(pmp_, mci_);
      VLOG(1) << "ON ALARM! Should now start to output...";
      delete this;
      return 0;
    }
    // Records the registration token so the alarm can be cancelled later.
    void OnRegistration(const EpollServer::AlarmRegToken& tok,
                        EpollServer* eps) {
      epoll_server_ = eps;
      pmp_->alarm_token = tok;
      pmp_->alarm_enabled = true;
    }
    void OnUnregistration() {
      pmp_->alarm_enabled = false;
    }
    void OnShutdown(EpollServer* eps) {
      OnUnregistration();
    }
    // Cancels the alarm if it is somehow still registered. In the code
    // visible here, deletion only happens from OnAlarm(), after
    // alarm_enabled has been cleared.
    ~BeginOutputtingAlarm() {
      if (epoll_server_ && pmp_->alarm_enabled)
        epoll_server_->UnregisterAlarm(pmp_->alarm_token);
    }
   private:
    OutputOrdering* output_ordering_;
    OutputOrdering::PriorityMapPointer* pmp_;
    MemCacheIter mci_;
    EpollServer* epoll_server_;
  };

  // Called when a stream's start-up alarm fires. Appends the stream to
  // |first_data_senders_|, records its position in |pmp|, and tells the
  // connection it has output to send.
  void MoveToActive(PriorityMapPointer* pmp, MemCacheIter mci) {
    VLOG(1) <<"Moving to active!";
    first_data_senders_.push_back(mci);
    pmp->ring = &first_data_senders_;
    pmp->it = first_data_senders_.end();
    --pmp->it;
    connection_->ReadyToSend();
  }

  // Registers a new stream and arms its start-up alarm. A duplicate stream
  // id is fatal (LOG(FATAL)).
  // NOTE(review): ids for client requests appear to come straight from
  // SYN_STREAM frames, so a client repeating an id could abort the server.
  // Confirm this is acceptable.
  void AddToOutputOrder(const MemCacheIter& mci) {
    if (ExistsInPriorityMaps(mci.stream_id))
      LOG(FATAL) << "OOps, already was inserted here?!";

    StreamIdToPriorityMap::iterator sitpmi;
    sitpmi = stream_ids_.insert(
        pair<uint32, PriorityMapPointer>(mci.stream_id,
                                         PriorityMapPointer())).first;
    PriorityMapPointer& pmp = sitpmi->second;

    // The alarm owns itself and deletes itself when it fires.
    BeginOutputtingAlarm* boa = new BeginOutputtingAlarm(this, &pmp, mci);
    // Seconds scaled by 1e6; presumably the delta is in microseconds.
    epoll_server_->RegisterAlarmApproximateDelta(
        FLAGS_server_think_time_in_s * 1000000, boa);
  }

  // Moves the stream at |pri| out of |first_data_senders_| into the ring for
  // its priority (creating that ring if needed) and updates the stream's
  // recorded ring. std::list::splice keeps |pri| valid, so the stored
  // PriorityMapPointer::it still refers to the moved element.
  void SpliceToPriorityRing(PriorityRing::iterator pri) {
    MemCacheIter& mci = *pri;
    PriorityMap::iterator pmi = priority_map_.find(mci.priority);
    if (pmi == priority_map_.end()) {
      pmi = priority_map_.insert(
          pair<uint32, PriorityRing>(mci.priority, PriorityRing())).first;
    }

    pmi->second.splice(pmi->second.end(),
                       first_data_senders_,
                       pri);
    StreamIdToPriorityMap::iterator sitpmi = stream_ids_.find(mci.stream_id);
    sitpmi->second.ring = &(pmi->second);
  }

  // Returns the stream to write from next and sets its max_segment_size, or
  // returns NULL when no stream is ready. Along the way it:
  //  - moves first-data senders that reached the threshold to their
  //    priority ring,
  //  - drops empty priority rings.
  MemCacheIter* GetIter() {
    while (!first_data_senders_.empty()) {
      MemCacheIter& mci = first_data_senders_.front();
      if (mci.bytes_sent >= first_data_senders_threshold_) {
        SpliceToPriorityRing(first_data_senders_.begin());
      } else {
        // Rotate to the back for round-robin service.
        first_data_senders_.splice(first_data_senders_.end(),
                                  first_data_senders_,
                                  first_data_senders_.begin());
        mci.max_segment_size = kInitialDataSendersThreshold;
        return &mci;
      }
    }
    while (!priority_map_.empty()) {
      PriorityRing& first_ring = priority_map_.begin()->second;
      if (first_ring.empty()) {
        priority_map_.erase(priority_map_.begin());
        continue;
      }
      MemCacheIter& mci = first_ring.front();
      // Rotate to the back for round-robin service.
      first_ring.splice(first_ring.end(),
                        first_ring,
                        first_ring.begin());
      mci.max_segment_size = kNormalSegmentSize;
      return &mci;
    }
    return NULL;
  }

  // Stops all further output for |stream_id|. It cancels the pending alarm,
  // or erases the stream from whichever ring holds it. Unknown ids are
  // ignored.
  // NOTE(review): as in Reset(), a cancelled BeginOutputtingAlarm does not
  // appear to be freed here.
  void RemoveStreamId(uint32 stream_id) {
    StreamIdToPriorityMap::iterator sitpmi = stream_ids_.find(stream_id);
    if (sitpmi == stream_ids_.end())
      return;
    PriorityMapPointer& pmp = sitpmi->second;
    if (pmp.alarm_enabled) {
      epoll_server_->UnregisterAlarm(pmp.alarm_token);
    } else {
      pmp.ring->erase(pmp.it);
    }

    stream_ids_.erase(sitpmi);
  }
};
1249
1250////////////////////////////////////////////////////////////////////////////////
1251
1252class FlipSM : public FlipFramerVisitorInterface, public SMInterface {
1253 private:
1254  uint64 seq_num_;
1255  FlipFramer* framer_;
1256
1257  SMServerConnection* connection_;
1258  OutputList* output_list_;
1259  OutputOrdering output_ordering_;
1260  MemoryCache* memory_cache_;
1261  uint32 next_outgoing_stream_id_;
1262 public:
1263  explicit FlipSM(SMServerConnection* connection) :
1264      seq_num_(0),
1265      framer_(new FlipFramer),
1266      connection_(connection),
1267      output_list_(connection->output_list()),
1268      output_ordering_(connection),
1269      memory_cache_(connection->memory_cache()),
1270      next_outgoing_stream_id_(2) {
1271    framer_->set_visitor(this);
1272  }
1273 private:
1274  virtual void OnError(FlipFramer* framer) {
1275    /* do nothing with this right now */
1276  }
1277
1278  virtual void OnControl(const FlipControlFrame* frame) {
1279    FlipHeaderBlock headers;
1280    bool parsed_headers = false;
1281    switch (frame->type()) {
1282      case SYN_STREAM:
1283        {
1284        parsed_headers = framer_->ParseHeaderBlock(frame, &headers);
1285        VLOG(2) << "OnSyn(" << frame->stream_id() << ")";
1286        VLOG(2) << "headers parsed?: " << (parsed_headers? "yes": "no");
1287        if (parsed_headers) {
1288          VLOG(2) << "# headers: " << headers.size();
1289        }
1290        unsigned int j = 0;
1291        for (FlipHeaderBlock::iterator i = headers.begin();
1292             i != headers.end();
1293             ++i) {
1294          VLOG(2) << i->first << ": " << i->second;
1295          if (FLAGS_record_mode && connection_->record_fd() > 0) {
1296            // If record mode is enabled and corresponding server connection
1297            // has file opened, then save the request headers into the file.
1298            // All the requests from the same connection is save in one file.
1299            // This file will be used to replay and generate FLIP requests
1300            // load.
1301            string header = i->first + ": " + i->second + "\n";
1302            ++j;
1303            if (j == headers.size()) {
1304              header += "\n";  // add an additional empty lime
1305            }
1306            int r = write(
1307                connection_->record_fd(), header.c_str(), header.size());
1308            if (r < 0) {
1309              perror("unable to write to record file:");
1310            }
1311          }
1312        }
1313
1314        FlipHeaderBlock::iterator method = headers.find("method");
1315        FlipHeaderBlock::iterator url = headers.find("url");
1316        if (url == headers.end() || method == headers.end()) {
1317          VLOG(2) << "didn't find method or url or method. Not creating stream";
1318          break;
1319        }
1320
1321        FlipHeaderBlock::iterator referer = headers.find("referer");
1322        if (referer != headers.end() && method->second == "GET") {
1323          memory_cache_->UpdateHeaders(referer->second, url->second);
1324        }
1325        string uri = UrlUtilities::GetUrlPath(url->second);
1326        string host = UrlUtilities::GetUrlHost(url->second);
1327        // requests started with /testing are loadtime measurement related
1328        // urls, use LoadtimeMeasurement class to handle them.
1329        if (uri.find("/testing") == 0) {
1330          string output;
1331          global_loadtime_measurement.ProcessRequest(uri, output);
1332          SendOKResponse(frame->stream_id(), &output);
1333        } else {
1334          string filename;
1335          if (FLAGS_need_to_encode_url) {
1336            filename = net::UrlToFilenameEncoder::Encode(
1337                "http://" + host + uri, method->second + "_/");
1338          } else {
1339            filename = string(method->second + "_" + url->second);
1340          }
1341
1342          NewStream(frame->stream_id(),
1343                    reinterpret_cast<const FlipSynStreamControlFrame*>(frame)->
1344                      priority(),
1345                    filename);
1346          }
1347        }
1348        break;
1349
1350      case SYN_REPLY:
1351        parsed_headers = framer_->ParseHeaderBlock(frame, &headers);
1352        VLOG(2) << "OnSynReply(" << frame->stream_id() << ")";
1353        break;
1354      case FIN_STREAM:
1355        VLOG(2) << "OnFin(" << frame->stream_id() << ")";
1356        output_ordering_.RemoveStreamId(frame->stream_id());
1357
1358        break;
1359      default:
1360        LOG(DFATAL) << "Unknown control frame type";
1361    }
1362  }
1363  virtual void OnStreamFrameData(
1364    FlipStreamId stream_id,
1365    const char* data, size_t len) {
1366    VLOG(2) << "StreamData(" << stream_id << ", [" << len << "])";
1367    /* do nothing with this right now */
1368  }
1369  virtual void OnLameDuck() {
1370    /* do nothing with this right now */
1371  }
1372
1373 public:
1374  ~FlipSM() {
1375    Reset();
1376  }
1377  size_t ProcessInput(const char* data, size_t len) {
1378    return framer_->ProcessInput(data, len);
1379  }
1380
1381  bool MessageFullyRead() const {
1382    return framer_->MessageFullyRead();
1383  }
1384
1385  bool Error() const {
1386    return framer_->HasError();
1387  }
1388
1389  const char* ErrorAsString() const {
1390    return FlipFramer::ErrorCodeToString(framer_->error_code());
1391  }
1392
1393  void Reset() {}
1394  void ResetForNewConnection() {
1395    // seq_num is not cleared, intentionally.
1396    delete framer_;
1397    framer_ = new FlipFramer;
1398    framer_->set_visitor(this);
1399    output_ordering_.Reset();
1400    next_outgoing_stream_id_ = 2;
1401  }
1402
1403  // Send a couple of NOOP packets to force opening of cwnd.
1404  void PostAcceptHook() {
1405    if (!FLAGS_use_cwnd_opener)
1406      return;
1407
1408    // We send 2 because that is the initial cwnd, and also because
1409    // we have to in order to get an ACK back from the client due to
1410    // delayed ACK.
1411    const int kPkts = 2;
1412
1413    LOG(ERROR) << "Sending NOP FRAMES";
1414
1415    scoped_ptr<FlipControlFrame> frame(FlipFramer::CreateNopFrame());
1416    for (int i = 0; i < kPkts; ++i) {
1417      char* bytes = frame->data();
1418      size_t size = FlipFrame::size();
1419      ssize_t bytes_written = connection_->Send(bytes, size, MSG_DONTWAIT);
1420      if (bytes_written > 0 && static_cast<size_t>(bytes_written) != size) {
1421        LOG(ERROR) << "Trouble sending Nop packet! (" << errno << ")";
1422        if (errno == EAGAIN)
1423          break;
1424      }
1425    }
1426  }
1427
1428  void AddAssociatedContent(FileData* file_data) {
1429    for (unsigned int i = 0; i < file_data->related_files.size(); ++i) {
1430      pair<int, string>& related_file = file_data->related_files[i];
1431      MemCacheIter mci;
1432      string filename  = "GET_";
1433      filename += related_file.second;
1434      if (!memory_cache_->AssignFileData(filename, &mci)) {
1435        VLOG(1) << "Unable to find associated content for: " << filename;
1436        continue;
1437      }
1438      VLOG(1) << "Adding associated content: " << filename;
1439      mci.stream_id = next_outgoing_stream_id_;
1440      next_outgoing_stream_id_ += 2;
1441      mci.priority =  related_file.first;
1442      AddToOutputOrder(mci);
1443    }
1444  }
1445
1446  void NewStream(uint32 stream_id, uint32 priority, const string& filename) {
1447    MemCacheIter mci;
1448    mci.stream_id = stream_id;
1449    mci.priority = priority;
1450    if (!memory_cache_->AssignFileData(filename, &mci)) {
1451      // error creating new stream.
1452      VLOG(2) << "Sending ErrorNotFound";
1453      SendErrorNotFound(stream_id);
1454    } else {
1455      AddToOutputOrder(mci);
1456      if (FLAGS_use_xac) {
1457        AddAssociatedContent(mci.file_data);
1458      }
1459    }
1460  }
1461
1462  void AddToOutputOrder(const MemCacheIter& mci) {
1463    output_ordering_.AddToOutputOrder(mci);
1464  }
1465
1466  void SendEOF(uint32 stream_id) {
1467    SendEOFImpl(stream_id);
1468  }
1469
1470  void SendErrorNotFound(uint32 stream_id) {
1471    SendErrorNotFoundImpl(stream_id);
1472  }
1473
1474  void SendOKResponse(uint32 stream_id, string* output) {
1475    SendOKResponseImpl(stream_id, output);
1476  }
1477
1478  size_t SendSynStream(uint32 stream_id, const BalsaHeaders& headers) {
1479    return SendSynStreamImpl(stream_id, headers);
1480  }
1481
1482  size_t SendSynReply(uint32 stream_id, const BalsaHeaders& headers) {
1483    return SendSynReplyImpl(stream_id, headers);
1484  }
1485
1486  void SendDataFrame(uint32 stream_id, const char* data, int64 len,
1487                     uint32 flags, bool compress) {
1488    FlipDataFlags flip_flags = static_cast<FlipDataFlags>(flags);
1489    SendDataFrameImpl(stream_id, data, len, flip_flags, compress);
1490  }
1491
1492  FlipFramer* flip_framer() { return framer_; }
1493
1494 private:
1495  void SendEOFImpl(uint32 stream_id) {
1496    SendDataFrame(stream_id, NULL, 0, DATA_FLAG_FIN, false);
1497    VLOG(2) << "Sending EOF: " << stream_id;
1498    KillStream(stream_id);
1499  }
1500
1501  void SendErrorNotFoundImpl(uint32 stream_id) {
1502    BalsaHeaders my_headers;
1503    my_headers.SetFirstlineFromStringPieces("HTTP/1.1", "404", "Not Found");
1504    SendSynReplyImpl(stream_id, my_headers);
1505    SendDataFrame(stream_id, "wtf?", 4, DATA_FLAG_FIN, false);
1506    output_ordering_.RemoveStreamId(stream_id);
1507  }
1508
1509  void SendOKResponseImpl(uint32 stream_id, string* output) {
1510    BalsaHeaders my_headers;
1511    my_headers.SetFirstlineFromStringPieces("HTTP/1.1", "200", "OK");
1512    SendSynReplyImpl(stream_id, my_headers);
1513    SendDataFrame(
1514        stream_id, output->c_str(), output->size(), DATA_FLAG_FIN, false);
1515    output_ordering_.RemoveStreamId(stream_id);
1516  }
1517
1518  void KillStream(uint32 stream_id) {
1519    output_ordering_.RemoveStreamId(stream_id);
1520  }
1521
1522  void CopyHeaders(FlipHeaderBlock& dest, const BalsaHeaders& headers) {
1523    for (BalsaHeaders::const_header_lines_iterator hi =
1524         headers.header_lines_begin();
1525         hi != headers.header_lines_end();
1526         ++hi) {
1527      FlipHeaderBlock::iterator fhi = dest.find(hi->first.as_string());
1528      if (fhi == dest.end()) {
1529        dest[hi->first.as_string()] = hi->second.as_string();
1530      } else {
1531        dest[hi->first.as_string()] = (
1532            string(fhi->second.data(), fhi->second.size()) + "," +
1533            string(hi->second.data(), hi->second.size()));
1534      }
1535    }
1536
1537    // These headers have no value
1538    dest.erase("X-Associated-Content");  // TODO(mbelshe): case-sensitive
1539    dest.erase("X-Original-Url");  // TODO(mbelshe): case-sensitive
1540  }
1541
1542  size_t SendSynStreamImpl(uint32 stream_id, const BalsaHeaders& headers) {
1543    FlipHeaderBlock block;
1544    block["method"] = headers.request_method().as_string();
1545    if (!headers.HasHeader("status"))
1546      block["status"] = headers.response_code().as_string();
1547    if (!headers.HasHeader("version"))
1548      block["version"] =headers.response_version().as_string();
1549    if (headers.HasHeader("X-Original-Url")) {
1550      string original_url = headers.GetHeader("X-Original-Url").as_string();
1551      block["path"] = UrlUtilities::GetUrlPath(original_url);
1552    } else {
1553      block["path"] = headers.request_uri().as_string();
1554    }
1555    CopyHeaders(block, headers);
1556
1557    FlipSynStreamControlFrame* fsrcf =
1558      framer_->CreateSynStream(stream_id, 0, CONTROL_FLAG_NONE, true, &block);
1559    DataFrame df;
1560    df.size = fsrcf->length() + FlipFrame::size();
1561    size_t df_size = df.size;
1562    df.data = fsrcf->data();
1563    df.delete_when_done = true;
1564    EnqueueDataFrame(df);
1565
1566    VLOG(2) << "Sending SynStreamheader " << stream_id;
1567    return df_size;
1568  }
1569
1570  size_t SendSynReplyImpl(uint32 stream_id, const BalsaHeaders& headers) {
1571    FlipHeaderBlock block;
1572    CopyHeaders(block, headers);
1573    block["status"] = headers.response_code().as_string() + " " +
1574                      headers.response_reason_phrase().as_string();
1575    block["version"] = headers.response_version().as_string();
1576
1577    FlipSynReplyControlFrame* fsrcf =
1578      framer_->CreateSynReply(stream_id, CONTROL_FLAG_NONE, true, &block);
1579    DataFrame df;
1580    df.size = fsrcf->length() + FlipFrame::size();
1581    size_t df_size = df.size;
1582    df.data = fsrcf->data();
1583    df.delete_when_done = true;
1584    EnqueueDataFrame(df);
1585
1586    VLOG(2) << "Sending SynReplyheader " << stream_id;
1587    return df_size;
1588  }
1589
1590  void SendDataFrameImpl(uint32 stream_id, const char* data, int64 len,
1591                         FlipDataFlags flags, bool compress) {
1592    // Force compression off if disabled via command line.
1593    if (!FLAGS_use_compression)
1594      flags = static_cast<FlipDataFlags>(flags & ~DATA_FLAG_COMPRESSED);
1595
1596    // TODO(mbelshe):  We can't compress here - before going into the
1597    //                 priority queue.  Compression needs to be done
1598    //                 with late binding.
1599    FlipDataFrame* fdf = framer_->CreateDataFrame(stream_id, data, len,
1600                                                  flags);
1601    DataFrame df;
1602    df.size = fdf->length() + FlipFrame::size();
1603    df.data = fdf->data();
1604    df.delete_when_done = true;
1605    EnqueueDataFrame(df);
1606
1607    VLOG(2) << "Sending data frame" << stream_id << " [" << len << "]"
1608            << " shrunk to " << fdf->length();
1609  }
1610
1611  void EnqueueDataFrame(const DataFrame& df) {
1612    connection_->EnqueueDataFrame(df);
1613  }
1614
1615  void GetOutput() {
1616    while (output_list_->size() < 2) {
1617      MemCacheIter* mci = output_ordering_.GetIter();
1618      if (mci == NULL) {
1619        VLOG(2) << "GetOutput: nothing to output!?";
1620        return;
1621      }
1622      if (!mci->transformed_header) {
1623        mci->transformed_header = true;
1624        VLOG(2) << "GetOutput transformed header stream_id: ["
1625          << mci->stream_id << "]";
1626        if ((mci->stream_id % 2) == 0) {
1627          // this is a server initiated stream.
1628          // Ideally, we'd do a 'syn-push' here, instead of a syn-reply.
1629          BalsaHeaders headers;
1630          headers.CopyFrom(*(mci->file_data->headers));
1631          headers.ReplaceOrAppendHeader("status", "200");
1632          headers.ReplaceOrAppendHeader("version", "http/1.1");
1633          headers.SetRequestFirstlineFromStringPieces("PUSH",
1634                                                      mci->file_data->filename,
1635                                                      "");
1636          mci->bytes_sent = SendSynStream(mci->stream_id, headers);
1637        } else {
1638          BalsaHeaders headers;
1639          headers.CopyFrom(*(mci->file_data->headers));
1640          mci->bytes_sent = SendSynReply(mci->stream_id, headers);
1641        }
1642        return;
1643      }
1644      if (mci->body_bytes_consumed >= mci->file_data->body.size()) {
1645        VLOG(2) << "GetOutput remove_stream_id: [" << mci->stream_id << "]";
1646        SendEOF(mci->stream_id);
1647        return;
1648      }
1649      size_t num_to_write =
1650        mci->file_data->body.size() - mci->body_bytes_consumed;
1651      if (num_to_write > mci->max_segment_size)
1652        num_to_write = mci->max_segment_size;
1653
1654      bool should_compress = false;
1655      if (!mci->file_data->headers->HasHeader("content-encoding")) {
1656        if (mci->file_data->headers->HasHeader("content-type")) {
1657          string content_type =
1658              mci->file_data->headers->GetHeader("content-type").as_string();
1659          if (content_type.find("image") == content_type.npos)
1660            should_compress = true;
1661        }
1662      }
1663
1664      SendDataFrame(mci->stream_id,
1665                    mci->file_data->body.data() + mci->body_bytes_consumed,
1666                    num_to_write, 0, should_compress);
1667      VLOG(2) << "GetOutput SendDataFrame[" << mci->stream_id
1668        << "]: " << num_to_write;
1669      mci->body_bytes_consumed += num_to_write;
1670      mci->bytes_sent += num_to_write;
1671    }
1672  }
1673};
1674
1675////////////////////////////////////////////////////////////////////////////////
1676
1677class HTTPSM : public BalsaVisitorInterface, public SMInterface {
1678 private:
1679  uint64 seq_num_;
1680  BalsaFrame* framer_;
1681  BalsaHeaders headers_;
1682  uint32 stream_id_;
1683
1684  SMServerConnection* connection_;
1685  OutputList* output_list_;
1686  OutputOrdering output_ordering_;
1687  MemoryCache* memory_cache_;
1688 public:
1689  explicit HTTPSM(SMServerConnection* connection) :
1690      seq_num_(0),
1691      framer_(new BalsaFrame),
1692      stream_id_(1),
1693      connection_(connection),
1694      output_list_(connection->output_list()),
1695      output_ordering_(connection),
1696      memory_cache_(connection->memory_cache()) {
1697    framer_->set_balsa_visitor(this);
1698    framer_->set_balsa_headers(&headers_);
1699  }
1700 private:
1701  typedef map<string, uint32> ClientTokenMap;
1702 private:
1703    virtual void ProcessBodyInput(const char *input, size_t size) {
1704    }
1705    virtual void ProcessBodyData(const char *input, size_t size) {
1706      // ignoring this.
1707    }
1708    virtual void ProcessHeaderInput(const char *input, size_t size) {
1709    }
1710    virtual void ProcessTrailerInput(const char *input, size_t size) {}
1711    virtual void ProcessHeaders(const BalsaHeaders& headers) {
1712      VLOG(2) << "Got new request!";
1713      // requests started with /testing are loadtime measurement related
1714      // urls, use LoadtimeMeasurement class to handle them.
1715      if (headers.request_uri().as_string().find("/testing") == 0) {
1716        string output;
1717        global_loadtime_measurement.ProcessRequest(
1718            headers.request_uri().as_string(), output);
1719        SendOKResponse(stream_id_, &output);
1720        stream_id_ += 2;
1721      } else {
1722        string filename;
1723        if (FLAGS_need_to_encode_url) {
1724          filename = net::UrlToFilenameEncoder::Encode(
1725              headers.GetHeader("Host").as_string() +
1726              headers.request_uri().as_string(),
1727              headers.request_method().as_string() + "_/");
1728        } else {
1729         filename = headers.request_method().as_string() + "_" +
1730                    headers.request_uri().as_string();
1731        }
1732        NewStream(stream_id_, 0, filename);
1733        stream_id_ += 2;
1734      }
1735    }
1736    virtual void ProcessRequestFirstLine(const char* line_input,
1737                                         size_t line_length,
1738                                         const char* method_input,
1739                                         size_t method_length,
1740                                         const char* request_uri_input,
1741                                         size_t request_uri_length,
1742                                         const char* version_input,
1743                                         size_t version_length) {}
1744    virtual void ProcessResponseFirstLine(const char *line_input,
1745                                          size_t line_length,
1746                                          const char *version_input,
1747                                          size_t version_length,
1748                                          const char *status_input,
1749                                          size_t status_length,
1750                                          const char *reason_input,
1751                                          size_t reason_length) {}
1752    virtual void ProcessChunkLength(size_t chunk_length) {}
1753    virtual void ProcessChunkExtensions(const char *input, size_t size) {}
1754    virtual void HeaderDone() {}
1755    virtual void MessageDone() {
1756      VLOG(2) << "MessageDone!";
1757    }
1758    virtual void HandleHeaderError(BalsaFrame* framer) {
1759      HandleError();
1760    }
1761    virtual void HandleHeaderWarning(BalsaFrame* framer) {}
1762    virtual void HandleChunkingError(BalsaFrame* framer) {
1763      HandleError();
1764    }
1765    virtual void HandleBodyError(BalsaFrame* framer) {
1766      HandleError();
1767    }
1768
1769    void HandleError() {
1770      VLOG(2) << "Error detected";
1771    }
1772
1773 public:
1774  ~HTTPSM() {
1775    Reset();
1776  }
1777  size_t ProcessInput(const char* data, size_t len) {
1778    return framer_->ProcessInput(data, len);
1779  }
1780
1781  bool MessageFullyRead() const {
1782    return framer_->MessageFullyRead();
1783  }
1784
1785  bool Error() const {
1786    return framer_->Error();
1787  }
1788
1789  const char* ErrorAsString() const {
1790    return BalsaFrameEnums::ErrorCodeToString(framer_->ErrorCode());
1791  }
1792
1793  void Reset() {
1794    framer_->Reset();
1795  }
1796
1797  void ResetForNewConnection() {
1798    seq_num_ = 0;
1799    output_ordering_.Reset();
1800    framer_->Reset();
1801  }
1802
1803  void PostAcceptHook() {
1804  }
1805
1806  void NewStream(uint32 stream_id, uint32 priority, const string& filename) {
1807    MemCacheIter mci;
1808    mci.stream_id = stream_id;
1809    mci.priority = priority;
1810    if (!memory_cache_->AssignFileData(filename, &mci)) {
1811      SendErrorNotFound(stream_id);
1812    } else {
1813      AddToOutputOrder(mci);
1814    }
1815  }
1816
1817  void AddToOutputOrder(const MemCacheIter& mci) {
1818    output_ordering_.AddToOutputOrder(mci);
1819  }
1820
1821  void SendEOF(uint32 stream_id) {
1822    SendEOFImpl(stream_id);
1823  }
1824
1825  void SendErrorNotFound(uint32 stream_id) {
1826    SendErrorNotFoundImpl(stream_id);
1827  }
1828
1829  void SendOKResponse(uint32 stream_id, string* output) {
1830    SendOKResponseImpl(stream_id, output);
1831  }
1832
1833  size_t SendSynStream(uint32 stream_id, const BalsaHeaders& headers) {
1834    return 0;
1835  }
1836
1837  size_t SendSynReply(uint32 stream_id, const BalsaHeaders& headers) {
1838    return SendSynReplyImpl(stream_id, headers);
1839  }
1840
1841  void SendDataFrame(uint32 stream_id, const char* data, int64 len,
1842                     uint32 flags, bool compress) {
1843    SendDataFrameImpl(stream_id, data, len, flags, compress);
1844  }
1845
1846  BalsaFrame* flip_framer() { return framer_; }
1847
1848 private:
1849  void SendEOFImpl(uint32 stream_id) {
1850    DataFrame df;
1851    df.data = "0\r\n\r\n";
1852    df.size = 5;
1853    df.delete_when_done = false;
1854    EnqueueDataFrame(df);
1855  }
1856
1857  void SendErrorNotFoundImpl(uint32 stream_id) {
1858    BalsaHeaders my_headers;
1859    my_headers.SetFirstlineFromStringPieces("HTTP/1.1", "404", "Not Found");
1860    my_headers.RemoveAllOfHeader("content-length");
1861    my_headers.HackHeader("transfer-encoding", "chunked");
1862    SendSynReplyImpl(stream_id, my_headers);
1863    SendDataFrame(stream_id, "wtf?", 4, 0, false);
1864    SendEOFImpl(stream_id);
1865    output_ordering_.RemoveStreamId(stream_id);
1866  }
1867
1868  void SendOKResponseImpl(uint32 stream_id, string* output) {
1869    BalsaHeaders my_headers;
1870    my_headers.SetFirstlineFromStringPieces("HTTP/1.1", "200", "OK");
1871    my_headers.RemoveAllOfHeader("content-length");
1872    my_headers.HackHeader("transfer-encoding", "chunked");
1873    SendSynReplyImpl(stream_id, my_headers);
1874    SendDataFrame(stream_id, output->c_str(), output->size(), 0, false);
1875    SendEOFImpl(stream_id);
1876    output_ordering_.RemoveStreamId(stream_id);
1877  }
1878
1879  size_t SendSynReplyImpl(uint32 stream_id, const BalsaHeaders& headers) {
1880    SimpleBuffer sb;
1881    headers.WriteHeaderAndEndingToBuffer(&sb);
1882    DataFrame df;
1883    df.size = sb.ReadableBytes();
1884    char* buffer = new char[df.size];
1885    df.data = buffer;
1886    df.delete_when_done = true;
1887    sb.Read(buffer, df.size);
1888    VLOG(2) << "******************Sending HTTP Reply header " << stream_id;
1889    size_t df_size = df.size;
1890    EnqueueDataFrame(df);
1891    return df_size;
1892  }
1893
1894  size_t SendSynStreamImpl(uint32 stream_id, const BalsaHeaders& headers) {
1895    SimpleBuffer sb;
1896    headers.WriteHeaderAndEndingToBuffer(&sb);
1897    DataFrame df;
1898    df.size = sb.ReadableBytes();
1899    char* buffer = new char[df.size];
1900    df.data = buffer;
1901    df.delete_when_done = true;
1902    sb.Read(buffer, df.size);
1903    VLOG(2) << "******************Sending HTTP Reply header " << stream_id;
1904    size_t df_size = df.size;
1905    EnqueueDataFrame(df);
1906    return df_size;
1907  }
1908
1909  void SendDataFrameImpl(uint32 stream_id, const char* data, int64 len,
1910                         uint32 flags, bool compress) {
1911    char chunk_buf[128];
1912    snprintf(chunk_buf, sizeof(chunk_buf), "%x\r\n", (unsigned int)len);
1913    string chunk_description(chunk_buf);
1914    DataFrame df;
1915    df.size = chunk_description.size() + len + 2;
1916    char* buffer = new char[df.size];
1917    df.data = buffer;
1918    df.delete_when_done = true;
1919    memcpy(buffer, chunk_description.data(), chunk_description.size());
1920    memcpy(buffer + chunk_description.size(), data, len);
1921    memcpy(buffer + chunk_description.size() + len, "\r\n", 2);
1922    EnqueueDataFrame(df);
1923  }
1924
1925  void EnqueueDataFrame(const DataFrame& df) {
1926    connection_->EnqueueDataFrame(df);
1927  }
1928
1929  void GetOutput() {
1930    MemCacheIter* mci = output_ordering_.GetIter();
1931    if (mci == NULL) {
1932      VLOG(2) << "GetOutput: nothing to output!?";
1933      return;
1934    }
1935    if (!mci->transformed_header) {
1936      mci->bytes_sent = SendSynReply(mci->stream_id,
1937                                     *(mci->file_data->headers));
1938      mci->transformed_header = true;
1939      VLOG(2) << "GetOutput transformed header stream_id: ["
1940        << mci->stream_id << "]";
1941      return;
1942    }
1943    if (mci->body_bytes_consumed >= mci->file_data->body.size()) {
1944      SendEOF(mci->stream_id);
1945      output_ordering_.RemoveStreamId(mci->stream_id);
1946      VLOG(2) << "GetOutput remove_stream_id: [" << mci->stream_id << "]";
1947      return;
1948    }
1949    size_t num_to_write =
1950      mci->file_data->body.size() - mci->body_bytes_consumed;
1951    if (num_to_write > mci->max_segment_size)
1952      num_to_write = mci->max_segment_size;
1953    SendDataFrame(mci->stream_id,
1954                  mci->file_data->body.data() + mci->body_bytes_consumed,
1955                  num_to_write, 0, true);
1956    VLOG(2) << "GetOutput SendDataFrame[" << mci->stream_id
1957      << "]: " << num_to_write;
1958    mci->body_bytes_consumed += num_to_write;
1959    mci->bytes_sent += num_to_write;
1960  }
1961};
1962
1963////////////////////////////////////////////////////////////////////////////////
1964
1965class Notification {
1966 public:
1967  explicit Notification(bool value) : value_(value) {}
1968
1969  void Notify() {
1970    AutoLock al(lock_);
1971    value_ = true;
1972  }
1973  bool HasBeenNotified() {
1974    AutoLock al(lock_);
1975    return value_;
1976  }
1977  bool value_;
1978  Lock lock_;
1979};
1980
1981////////////////////////////////////////////////////////////////////////////////
1982
1983class SMAcceptorThread : public SimpleThread,
1984                         public EpollCallbackInterface,
1985                         public SMServerConnectionPoolInterface {
1986  EpollServer epoll_server_;
1987  int listen_fd_;
1988  int accepts_per_wake_;
1989
1990  vector<SMServerConnection*> unused_server_connections_;
1991  vector<SMServerConnection*> tmp_unused_server_connections_;
1992  vector<SMServerConnection*> allocated_server_connections_;
1993  Notification quitting_;
1994  SMInterfaceFactory* sm_interface_factory_;
1995  MemoryCache* memory_cache_;
1996 public:
1997
1998  SMAcceptorThread(int listen_fd,
1999                   int accepts_per_wake,
2000                   SMInterfaceFactory* smif,
2001                   MemoryCache* memory_cache) :
2002      SimpleThread("SMAcceptorThread"),
2003      listen_fd_(listen_fd),
2004      accepts_per_wake_(accepts_per_wake),
2005      quitting_(false),
2006      sm_interface_factory_(smif),
2007      memory_cache_(memory_cache) {
2008  }
2009
2010  ~SMAcceptorThread() {
2011    for (vector<SMServerConnection*>::iterator i =
2012           allocated_server_connections_.begin();
2013         i != allocated_server_connections_.end();
2014         ++i) {
2015      delete *i;
2016    }
2017  }
2018
2019  SMServerConnection* NewConnection() {
2020    SMServerConnection* server =
2021      SMServerConnection::NewSMServerConnection(sm_interface_factory_,
2022                                                memory_cache_,
2023                                                &epoll_server_);
2024    allocated_server_connections_.push_back(server);
2025    VLOG(3) << "Making new server: " << server;
2026    return server;
2027  }
2028
2029  SMServerConnection* FindOrMakeNewSMServerConnection() {
2030    if (unused_server_connections_.empty()) {
2031      return NewConnection();
2032    }
2033    SMServerConnection* retval = unused_server_connections_.back();
2034    unused_server_connections_.pop_back();
2035    return retval;
2036  }
2037
2038
2039  void InitWorker() {
2040    epoll_server_.RegisterFD(listen_fd_, this, EPOLLIN | EPOLLET);
2041  }
2042
2043  void HandleConnection(int client_fd) {
2044    SMServerConnection* server_connection = FindOrMakeNewSMServerConnection();
2045    if (server_connection == NULL) {
2046      VLOG(2) << "Closing " << client_fd;
2047      close(client_fd);
2048      return;
2049    }
2050    server_connection->InitSMServerConnection(this,
2051                                            &epoll_server_,
2052                                            client_fd);
2053  }
2054
2055  void AcceptFromListenFD() {
2056    if (accepts_per_wake_ > 0) {
2057      for (int i = 0; i < accepts_per_wake_; ++i) {
2058        struct sockaddr address;
2059        socklen_t socklen = sizeof(address);
2060        int fd = accept(listen_fd_, &address, &socklen);
2061        if (fd == -1) {
2062          VLOG(2) << "accept fail(" << listen_fd_ << "): " << errno;
2063          break;
2064        }
2065        VLOG(2) << "********************Accepted fd: " << fd << "\n\n\n";
2066        HandleConnection(fd);
2067      }
2068    } else {
2069      while (true) {
2070        struct sockaddr address;
2071        socklen_t socklen = sizeof(address);
2072        int fd = accept(listen_fd_, &address, &socklen);
2073        if (fd == -1) {
2074          VLOG(2) << "accept fail(" << listen_fd_ << "): " << errno;
2075          break;
2076        }
2077        VLOG(2) << "********************Accepted fd: " << fd << "\n\n\n";
2078        HandleConnection(fd);
2079      }
2080    }
2081  }
2082
2083  // EpollCallbackInteface virtual functions.
2084  virtual void OnRegistration(EpollServer* eps, int fd, int event_mask) { }
2085  virtual void OnModification(int fd, int event_mask) { }
2086  virtual void OnEvent(int fd, EpollEvent* event) {
2087    if (event->in_events | EPOLLIN) {
2088      VLOG(2) << "Accepting based upon epoll events";
2089      AcceptFromListenFD();
2090    }
2091  }
2092  virtual void OnUnregistration(int fd, bool replaced) { }
2093  virtual void OnShutdown(EpollServer* eps, int fd) { }
2094
2095  void Quit() {
2096    quitting_.Notify();
2097  }
2098
2099  void Run() {
2100    while (!quitting_.HasBeenNotified()) {
2101      epoll_server_.set_timeout_in_us(10 * 1000);  // 10 ms
2102      epoll_server_.WaitForEventsAndExecuteCallbacks();
2103      unused_server_connections_.insert(unused_server_connections_.end(),
2104                                        tmp_unused_server_connections_.begin(),
2105                                        tmp_unused_server_connections_.end());
2106      tmp_unused_server_connections_.clear();
2107    }
2108  }
2109
2110  // SMServerConnections will use this:
2111  virtual void SMServerConnectionDone(SMServerConnection* sc) {
2112    VLOG(3) << "Done with server connection: " << sc;
2113    sc->close_record_fd();
2114    tmp_unused_server_connections_.push_back(sc);
2115  }
2116};
2117
2118////////////////////////////////////////////////////////////////////////////////
2119
2120SMInterface* NewFlipSM(SMServerConnection* connection) {
2121  return new FlipSM(connection);
2122}
2123
2124SMInterface* NewHTTPSM(SMServerConnection* connection) {
2125  return new HTTPSM(connection);
2126}
2127
2128////////////////////////////////////////////////////////////////////////////////
2129
2130int CreateListeningSocket(int port, int backlog_size,
2131                          bool reuseport, bool no_nagle) {
2132  int listening_socket = 0;
2133  char port_buf[256];
2134  snprintf(port_buf, sizeof(port_buf), "%d", port);
2135  cerr <<" Attempting to listen on port: " << port_buf << "\n";
2136  cerr <<" input port: " << port << "\n";
2137  net::CreateListeningSocket("",
2138                              port_buf,
2139                              true,
2140                              backlog_size,
2141                              &listening_socket,
2142                              true,
2143                              reuseport,
2144                              &cerr);
2145  SetNonBlocking(listening_socket);
2146  if (no_nagle) {
2147    // set SO_REUSEADDR on the listening socket.
2148    int on = 1;
2149    int rc;
2150    rc = setsockopt(listening_socket, IPPROTO_TCP,  TCP_NODELAY,
2151                    reinterpret_cast<char *>(&on), sizeof(on));
2152    if (rc < 0) {
2153      close(listening_socket);
2154      LOG(FATAL) << "setsockopt() failed fd=" << listening_socket << "\n";
2155    }
2156  }
2157  return listening_socket;
2158}
2159
2160////////////////////////////////////////////////////////////////////////////////
2161
2162bool GotQuitFromStdin() {
2163  // Make stdin nonblocking. Yes this is done each time. Oh well.
2164  fcntl(0, F_SETFL, O_NONBLOCK);
2165  char c;
2166  string maybequit;
2167  while (read(0, &c, 1) > 0) {
2168    maybequit += c;
2169  }
2170  if (maybequit.size()) {
2171    VLOG(2) << "scanning string: \"" << maybequit << "\"";
2172  }
2173  return (maybequit.size() > 1 &&
2174          (maybequit.c_str()[0] == 'q' ||
2175           maybequit.c_str()[0] == 'Q'));
2176}
2177
2178
2179////////////////////////////////////////////////////////////////////////////////
2180
// Returns the literal "true" or "false" for |b|; used for log output.
const char* BoolToStr(bool b) {
  return b ? "true" : "false";
}
2186
2187////////////////////////////////////////////////////////////////////////////////
2188
2189int main(int argc, char**argv) {
2190  bool use_ssl = FLAGS_use_ssl;
2191  int response_count_until_close = FLAGS_response_count_until_close;
2192  int flip_port = FLAGS_flip_port;
2193  int port = FLAGS_port;
2194  int backlog_size = FLAGS_accept_backlog_size;
2195  bool reuseport = FLAGS_reuseport;
2196  bool no_nagle = FLAGS_no_nagle;
2197  double server_think_time_in_s = FLAGS_server_think_time_in_s;
2198  int accepts_per_wake = FLAGS_accepts_per_wake;
2199  int num_threads = 1;
2200
2201  MemoryCache flip_memory_cache;
2202  flip_memory_cache.AddFiles();
2203
2204  MemoryCache http_memory_cache;
2205  http_memory_cache.CloneFrom(flip_memory_cache);
2206
2207  LOG(INFO) <<
2208    "Starting up with the following state: \n"
2209    "                      use_ssl: " << use_ssl << "\n"
2210    "   response_count_until_close: " << response_count_until_close << "\n"
2211    "                         port: " << port << "\n"
2212    "                    flip_port: " << flip_port << "\n"
2213    "                 backlog_size: " << backlog_size << "\n"
2214    "                    reuseport: " << BoolToStr(reuseport) << "\n"
2215    "                     no_nagle: " << BoolToStr(no_nagle) << "\n"
2216    "       server_think_time_in_s: " << server_think_time_in_s << "\n"
2217    "             accepts_per_wake: " << accepts_per_wake << "\n"
2218    "                  num_threads: " << num_threads << "\n"
2219    "                     use_xsub: " << BoolToStr(FLAGS_use_xsub) << "\n"
2220    "                      use_xac: " << BoolToStr(FLAGS_use_xac) << "\n";
2221
2222  if (use_ssl) {
2223    global_ssl_state = new GlobalSSLState;
2224    flip_init_ssl(global_ssl_state);
2225  } else {
2226    global_ssl_state = NULL;
2227  }
2228  EpollServer epoll_server;
2229  vector<SMAcceptorThread*> sm_worker_threads_;
2230
2231  {
2232    // flip
2233    int listen_fd = -1;
2234
2235    if (reuseport || listen_fd == -1) {
2236      listen_fd = CreateListeningSocket(flip_port, backlog_size,
2237                                        reuseport, no_nagle);
2238      if (listen_fd < 0) {
2239        LOG(FATAL) << "Unable to open listening socket on flip_port: "
2240          << flip_port;
2241      } else {
2242        LOG(INFO) << "Listening for flip on port: " << flip_port;
2243      }
2244    }
2245    sm_worker_threads_.push_back(
2246        new SMAcceptorThread(listen_fd,
2247                             accepts_per_wake,
2248                             &NewFlipSM,
2249                             &flip_memory_cache));
2250    // Note that flip_memory_cache is not threadsafe, it is merely
2251    // thread compatible. Thus, if ever we are to spawn multiple threads,
2252    // we either must make the MemoryCache threadsafe, or use
2253    // a separate MemoryCache for each thread.
2254    //
2255    // The latter is what is currently being done as we spawn
2256    // two threads (one for flip, one for http).
2257    sm_worker_threads_.back()->InitWorker();
2258    sm_worker_threads_.back()->Start();
2259  }
2260
2261  {
2262    // http
2263    int listen_fd = -1;
2264    if (reuseport || listen_fd == -1) {
2265      listen_fd = CreateListeningSocket(port, backlog_size,
2266                                        reuseport, no_nagle);
2267      if (listen_fd < 0) {
2268        LOG(FATAL) << "Unable to open listening socket on port: " << port;
2269      } else {
2270        LOG(INFO) << "Listening for HTTP on port: " << port;
2271      }
2272    }
2273    sm_worker_threads_.push_back(
2274        new SMAcceptorThread(listen_fd,
2275                             accepts_per_wake,
2276                             &NewHTTPSM,
2277                             &http_memory_cache));
2278    // Note that flip_memory_cache is not threadsafe, it is merely
2279    // thread compatible. Thus, if ever we are to spawn multiple threads,
2280    // we either must make the MemoryCache threadsafe, or use
2281    // a separate MemoryCache for each thread.
2282    //
2283    // The latter is what is currently being done as we spawn
2284    // two threads (one for flip, one for http).
2285    sm_worker_threads_.back()->InitWorker();
2286    sm_worker_threads_.back()->Start();
2287  }
2288
2289  while (true) {
2290    if (GotQuitFromStdin()) {
2291      for (unsigned int i = 0; i < sm_worker_threads_.size(); ++i) {
2292        sm_worker_threads_[i]->Quit();
2293      }
2294      for (unsigned int i = 0; i < sm_worker_threads_.size(); ++i) {
2295        sm_worker_threads_[i]->Join();
2296      }
2297      return 0;
2298    }
2299    usleep(1000*10);  // 10 ms
2300  }
2301  return 0;
2302}
2303
2304