1// Copyright (c) 2010 The Chromium Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style license that can be
3// found in the LICENSE file.
4
5#include "chrome/common/sqlite_utils.h"
6
7#include <list>
8
9#include "base/file_path.h"
10#include "base/lazy_instance.h"
11#include "base/logging.h"
12#include "base/stl_util-inl.h"
13#include "base/string16.h"
14#include "base/synchronization/lock.h"
15
16// The vanilla error handler implements the common fucntionality for all the
17// error handlers. Specialized error handlers are expected to only override
18// the Handler() function.
19class VanillaSQLErrorHandler : public SQLErrorHandler {
20 public:
21  VanillaSQLErrorHandler() : error_(SQLITE_OK) {
22  }
23  virtual int GetLastError() const {
24    return error_;
25  }
26 protected:
27  int error_;
28};
29
30class DebugSQLErrorHandler: public VanillaSQLErrorHandler {
31 public:
32  virtual int HandleError(int error, sqlite3* db) {
33    error_ = error;
34    NOTREACHED() << "sqlite error " << error
35                 << " db " << static_cast<void*>(db);
36    return error;
37  }
38};
39
40class ReleaseSQLErrorHandler : public VanillaSQLErrorHandler {
41 public:
42  virtual int HandleError(int error, sqlite3* db) {
43    error_ = error;
44    // Used to have a CHECK here. Got lots of crashes.
45    return error;
46  }
47};
48
49// The default error handler factory is also in charge of managing the
50// lifetime of the error objects. This object is multi-thread safe.
51class DefaultSQLErrorHandlerFactory : public SQLErrorHandlerFactory {
52 public:
53  ~DefaultSQLErrorHandlerFactory() {
54    STLDeleteContainerPointers(errors_.begin(), errors_.end());
55  }
56
57  virtual SQLErrorHandler* Make() {
58    SQLErrorHandler* handler;
59#ifndef NDEBUG
60    handler = new DebugSQLErrorHandler;
61#else
62    handler = new ReleaseSQLErrorHandler;
63#endif  // NDEBUG
64    AddHandler(handler);
65    return handler;
66  }
67
68 private:
69  void AddHandler(SQLErrorHandler* handler) {
70    base::AutoLock lock(lock_);
71    errors_.push_back(handler);
72  }
73
74  typedef std::list<SQLErrorHandler*> ErrorList;
75  ErrorList errors_;
76  base::Lock lock_;
77};
78
79static base::LazyInstance<DefaultSQLErrorHandlerFactory>
80    g_default_sql_error_handler_factory(base::LINKER_INITIALIZED);
81
82SQLErrorHandlerFactory* GetErrorHandlerFactory() {
83  // TODO(cpu): Testing needs to override the error handler.
84  // Destruction of DefaultSQLErrorHandlerFactory handled by at_exit manager.
85  return g_default_sql_error_handler_factory.Pointer();
86}
87
88namespace sqlite_utils {
89
90int OpenSqliteDb(const FilePath& filepath, sqlite3** database) {
91#if defined(OS_WIN)
92  // We want the default encoding to always be UTF-8, so we use the
93  // 8-bit version of open().
94  return sqlite3_open(WideToUTF8(filepath.value()).c_str(), database);
95#elif defined(OS_POSIX)
96  return sqlite3_open(filepath.value().c_str(), database);
97#endif
98}
99
100bool DoesSqliteTableExist(sqlite3* db,
101                          const char* db_name,
102                          const char* table_name) {
103  // sqlite doesn't allow binding parameters as table names, so we have to
104  // manually construct the sql
105  std::string sql("SELECT name FROM ");
106  if (db_name && db_name[0]) {
107    sql.append(db_name);
108    sql.push_back('.');
109  }
110  sql.append("sqlite_master WHERE type='table' AND name=?");
111
112  SQLStatement statement;
113  if (statement.prepare(db, sql.c_str()) != SQLITE_OK)
114    return false;
115
116  if (statement.bind_text(0, table_name) != SQLITE_OK)
117    return false;
118
119  // we only care about if this matched a row, not the actual data
120  return sqlite3_step(statement.get()) == SQLITE_ROW;
121}
122
123bool DoesSqliteColumnExist(sqlite3* db,
124                           const char* database_name,
125                           const char* table_name,
126                           const char* column_name,
127                           const char* column_type) {
128  SQLStatement s;
129  std::string sql;
130  sql.append("PRAGMA ");
131  if (database_name && database_name[0]) {
132    // optional database name specified
133    sql.append(database_name);
134    sql.push_back('.');
135  }
136  sql.append("TABLE_INFO(");
137  sql.append(table_name);
138  sql.append(")");
139
140  if (s.prepare(db, sql.c_str()) != SQLITE_OK)
141    return false;
142
143  while (s.step() == SQLITE_ROW) {
144    if (!s.column_string(1).compare(column_name)) {
145      if (column_type && column_type[0])
146        return !s.column_string(2).compare(column_type);
147      return true;
148    }
149  }
150  return false;
151}
152
153bool DoesSqliteTableHaveRow(sqlite3* db, const char* table_name) {
154  SQLStatement s;
155  std::string b;
156  b.append("SELECT * FROM ");
157  b.append(table_name);
158
159  if (s.prepare(db, b.c_str()) != SQLITE_OK)
160    return false;
161
162  return s.step() == SQLITE_ROW;
163}
164
165}  // namespace sqlite_utils
166
167SQLTransaction::SQLTransaction(sqlite3* db) : db_(db), began_(false) {
168}
169
170SQLTransaction::~SQLTransaction() {
171  if (began_) {
172    Rollback();
173  }
174}
175
176int SQLTransaction::BeginCommand(const char* command) {
177  int rv = SQLITE_ERROR;
178  if (!began_ && db_) {
179    rv = sqlite3_exec(db_, command, NULL, NULL, NULL);
180    began_ = (rv == SQLITE_OK);
181  }
182  return rv;
183}
184
185int SQLTransaction::EndCommand(const char* command) {
186  int rv = SQLITE_ERROR;
187  if (began_ && db_) {
188    rv = sqlite3_exec(db_, command, NULL, NULL, NULL);
189    began_ = (rv != SQLITE_OK);
190  }
191  return rv;
192}
193
194SQLNestedTransactionSite::~SQLNestedTransactionSite() {
195  DCHECK(!top_transaction_);
196}
197
198void SQLNestedTransactionSite::SetTopTransaction(SQLNestedTransaction* top) {
199  DCHECK(!top || !top_transaction_);
200  top_transaction_ = top;
201}
202
203SQLNestedTransaction::SQLNestedTransaction(SQLNestedTransactionSite* site)
204  : SQLTransaction(site->GetSqlite3DB()),
205    needs_rollback_(false),
206    site_(site) {
207  DCHECK(site);
208  if (site->GetTopTransaction() == NULL) {
209    site->SetTopTransaction(this);
210  }
211}
212
213SQLNestedTransaction::~SQLNestedTransaction() {
214  if (began_) {
215    Rollback();
216  }
217  if (site_->GetTopTransaction() == this) {
218    site_->SetTopTransaction(NULL);
219  }
220}
221
222int SQLNestedTransaction::BeginCommand(const char* command) {
223  DCHECK(db_);
224  DCHECK(site_ && site_->GetTopTransaction());
225  if (!db_ || began_) {
226    return SQLITE_ERROR;
227  }
228  if (site_->GetTopTransaction() == this) {
229    int rv = sqlite3_exec(db_, command, NULL, NULL, NULL);
230    began_ = (rv == SQLITE_OK);
231    if (began_) {
232      site_->OnBegin();
233    }
234    return rv;
235  } else {
236    if (site_->GetTopTransaction()->needs_rollback_) {
237      return SQLITE_ERROR;
238    }
239    began_ = true;
240    return SQLITE_OK;
241  }
242}
243
244int SQLNestedTransaction::EndCommand(const char* command) {
245  DCHECK(db_);
246  DCHECK(site_ && site_->GetTopTransaction());
247  if (!db_ || !began_) {
248    return SQLITE_ERROR;
249  }
250  if (site_->GetTopTransaction() == this) {
251    if (needs_rollback_) {
252      sqlite3_exec(db_, "ROLLBACK", NULL, NULL, NULL);
253      began_ = false;  // reset so we don't try to rollback or call
254                       // OnRollback() again
255      site_->OnRollback();
256      return SQLITE_ERROR;
257    } else {
258      int rv = sqlite3_exec(db_, command, NULL, NULL, NULL);
259      began_ = (rv != SQLITE_OK);
260      if (strcmp(command, "ROLLBACK") == 0) {
261        began_ = false;  // reset so we don't try to rollbck or call
262                         // OnRollback() again
263        site_->OnRollback();
264      } else {
265        DCHECK(strcmp(command, "COMMIT") == 0);
266        if (rv == SQLITE_OK) {
267          site_->OnCommit();
268        }
269      }
270      return rv;
271    }
272  } else {
273    if (strcmp(command, "ROLLBACK") == 0) {
274      site_->GetTopTransaction()->needs_rollback_ = true;
275    }
276    began_ = false;
277    return SQLITE_OK;
278  }
279}
280
281int SQLStatement::prepare(sqlite3* db, const char* sql, int sql_len) {
282  DCHECK(!stmt_);
283  int rv = sqlite3_prepare_v2(db, sql, sql_len, &stmt_, NULL);
284  if (rv != SQLITE_OK) {
285    SQLErrorHandler* error_handler = GetErrorHandlerFactory()->Make();
286    return error_handler->HandleError(rv, db);
287  }
288  return rv;
289}
290
291int SQLStatement::step() {
292  DCHECK(stmt_);
293  int status = sqlite3_step(stmt_);
294  if ((status == SQLITE_ROW) || (status == SQLITE_DONE))
295    return status;
296  // We got a problem.
297  SQLErrorHandler* error_handler = GetErrorHandlerFactory()->Make();
298  return error_handler->HandleError(status, db_handle());
299}
300
301int SQLStatement::reset() {
302  DCHECK(stmt_);
303  return sqlite3_reset(stmt_);
304}
305
306sqlite_int64 SQLStatement::last_insert_rowid() {
307  DCHECK(stmt_);
308  return sqlite3_last_insert_rowid(db_handle());
309}
310
311int SQLStatement::changes() {
312  DCHECK(stmt_);
313  return sqlite3_changes(db_handle());
314}
315
316sqlite3* SQLStatement::db_handle() {
317  DCHECK(stmt_);
318  return sqlite3_db_handle(stmt_);
319}
320
321int SQLStatement::bind_parameter_count() {
322  DCHECK(stmt_);
323  return sqlite3_bind_parameter_count(stmt_);
324}
325
326int SQLStatement::bind_blob(int index, std::vector<unsigned char>* blob) {
327  if (blob) {
328    const void* value = blob->empty() ? NULL : &(*blob)[0];
329    int len = static_cast<int>(blob->size());
330    return bind_blob(index, value, len);
331  } else {
332    return bind_null(index);
333  }
334}
335
336int SQLStatement::bind_blob(int index, const void* value, int value_len) {
337   return bind_blob(index, value, value_len, SQLITE_TRANSIENT);
338}
339
340int SQLStatement::bind_blob(int index, const void* value, int value_len,
341                            Function dtor) {
342  DCHECK(stmt_);
343  return sqlite3_bind_blob(stmt_, index + 1, value, value_len, dtor);
344}
345
346int SQLStatement::bind_double(int index, double value) {
347  DCHECK(stmt_);
348  return sqlite3_bind_double(stmt_, index + 1, value);
349}
350
351int SQLStatement::bind_bool(int index, bool value) {
352  DCHECK(stmt_);
353  return sqlite3_bind_int(stmt_, index + 1, value);
354}
355
356int SQLStatement::bind_int(int index, int value) {
357  DCHECK(stmt_);
358  return sqlite3_bind_int(stmt_, index + 1, value);
359}
360
361int SQLStatement::bind_int64(int index, sqlite_int64 value) {
362  DCHECK(stmt_);
363  return sqlite3_bind_int64(stmt_, index + 1, value);
364}
365
366int SQLStatement::bind_null(int index) {
367  DCHECK(stmt_);
368  return sqlite3_bind_null(stmt_, index + 1);
369}
370
371int SQLStatement::bind_text(int index, const char* value, int value_len,
372              Function dtor) {
373  DCHECK(stmt_);
374  return sqlite3_bind_text(stmt_, index + 1, value, value_len, dtor);
375}
376
377int SQLStatement::bind_text16(int index, const char16* value, int value_len,
378                Function dtor) {
379  DCHECK(stmt_);
380  value_len *= sizeof(char16);
381  return sqlite3_bind_text16(stmt_, index + 1, value, value_len, dtor);
382}
383
384int SQLStatement::bind_value(int index, const sqlite3_value* value) {
385  DCHECK(stmt_);
386  return sqlite3_bind_value(stmt_, index + 1, value);
387}
388
389int SQLStatement::column_count() {
390  DCHECK(stmt_);
391  return sqlite3_column_count(stmt_);
392}
393
394int SQLStatement::column_type(int index) {
395  DCHECK(stmt_);
396  return sqlite3_column_type(stmt_, index);
397}
398
399const void* SQLStatement::column_blob(int index) {
400  DCHECK(stmt_);
401  return sqlite3_column_blob(stmt_, index);
402}
403
404bool SQLStatement::column_blob_as_vector(int index,
405                                         std::vector<unsigned char>* blob) {
406  DCHECK(stmt_);
407  const void* p = column_blob(index);
408  size_t len = column_bytes(index);
409  blob->resize(len);
410  if (blob->size() != len) {
411    return false;
412  }
413  if (len > 0)
414    memcpy(&(blob->front()), p, len);
415  return true;
416}
417
418bool SQLStatement::column_blob_as_string(int index, std::string* blob) {
419  DCHECK(stmt_);
420  const void* p = column_blob(index);
421  size_t len = column_bytes(index);
422  blob->resize(len);
423  if (blob->size() != len) {
424    return false;
425  }
426  blob->assign(reinterpret_cast<const char*>(p), len);
427  return true;
428}
429
430int SQLStatement::column_bytes(int index) {
431  DCHECK(stmt_);
432  return sqlite3_column_bytes(stmt_, index);
433}
434
435int SQLStatement::column_bytes16(int index) {
436  DCHECK(stmt_);
437  return sqlite3_column_bytes16(stmt_, index);
438}
439
440double SQLStatement::column_double(int index) {
441  DCHECK(stmt_);
442  return sqlite3_column_double(stmt_, index);
443}
444
445bool SQLStatement::column_bool(int index) {
446  DCHECK(stmt_);
447  return sqlite3_column_int(stmt_, index) ? true : false;
448}
449
450int SQLStatement::column_int(int index) {
451  DCHECK(stmt_);
452  return sqlite3_column_int(stmt_, index);
453}
454
455sqlite_int64 SQLStatement::column_int64(int index) {
456  DCHECK(stmt_);
457  return sqlite3_column_int64(stmt_, index);
458}
459
460const char* SQLStatement::column_text(int index) {
461  DCHECK(stmt_);
462  return reinterpret_cast<const char*>(sqlite3_column_text(stmt_, index));
463}
464
465bool SQLStatement::column_string(int index, std::string* str) {
466  DCHECK(stmt_);
467  DCHECK(str);
468  const char* s = column_text(index);
469  str->assign(s ? s : std::string());
470  return s != NULL;
471}
472
473std::string SQLStatement::column_string(int index) {
474  std::string str;
475  column_string(index, &str);
476  return str;
477}
478
479const char16* SQLStatement::column_text16(int index) {
480  DCHECK(stmt_);
481  return static_cast<const char16*>(sqlite3_column_text16(stmt_, index));
482}
483
484bool SQLStatement::column_string16(int index, string16* str) {
485  DCHECK(stmt_);
486  DCHECK(str);
487  const char* s = column_text(index);
488  str->assign(s ? UTF8ToUTF16(s) : string16());
489  return (s != NULL);
490}
491
492string16 SQLStatement::column_string16(int index) {
493  string16 str;
494  column_string16(index, &str);
495  return str;
496}
497
498bool SQLStatement::column_wstring(int index, std::wstring* str) {
499  DCHECK(stmt_);
500  DCHECK(str);
501  const char* s = column_text(index);
502  str->assign(s ? UTF8ToWide(s) : std::wstring());
503  return (s != NULL);
504}
505
506std::wstring SQLStatement::column_wstring(int index) {
507  std::wstring wstr;
508  column_wstring(index, &wstr);
509  return wstr;
510}
511