sqlite_utils.cc revision 06741cbc25cd4227a9fba40dfd0273bfcc1a587a
1// Copyright (c) 2010 The Chromium Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style license that can be
3// found in the LICENSE file.
4
5#include "chrome/common/sqlite_utils.h"
6
7#include <list>
8
9#include "base/at_exit.h"
10#include "base/file_path.h"
11#include "base/lock.h"
12#include "base/logging.h"
13#include "base/singleton.h"
14#include "base/stl_util-inl.h"
15#include "base/string16.h"
16
17// The vanilla error handler implements the common fucntionality for all the
18// error handlers. Specialized error handlers are expected to only override
19// the Handler() function.
20class VanillaSQLErrorHandler : public SQLErrorHandler {
21 public:
22  VanillaSQLErrorHandler() : error_(SQLITE_OK) {
23  }
24  virtual int GetLastError() const {
25    return error_;
26  }
27 protected:
28  int error_;
29};
30
31class DebugSQLErrorHandler: public VanillaSQLErrorHandler {
32 public:
33  virtual int HandleError(int error, sqlite3* db) {
34    error_ = error;
35    NOTREACHED() << "sqlite error " << error
36                 << " db " << static_cast<void*>(db);
37    return error;
38  }
39};
40
41class ReleaseSQLErrorHandler : public VanillaSQLErrorHandler {
42 public:
43  virtual int HandleError(int error, sqlite3* db) {
44    error_ = error;
45    // Used to have a CHECK here. Got lots of crashes.
46    return error;
47  }
48};
49
50// The default error handler factory is also in charge of managing the
51// lifetime of the error objects. This object is multi-thread safe.
52class DefaultSQLErrorHandlerFactory : public SQLErrorHandlerFactory {
53 public:
54  ~DefaultSQLErrorHandlerFactory() {
55    STLDeleteContainerPointers(errors_.begin(), errors_.end());
56  }
57
58  virtual SQLErrorHandler* Make() {
59    SQLErrorHandler* handler;
60#ifndef NDEBUG
61    handler = new DebugSQLErrorHandler;
62#else
63    handler = new ReleaseSQLErrorHandler;
64#endif  // NDEBUG
65    AddHandler(handler);
66    return handler;
67  }
68
69 private:
70  void AddHandler(SQLErrorHandler* handler) {
71    AutoLock lock(lock_);
72    errors_.push_back(handler);
73  }
74
75  typedef std::list<SQLErrorHandler*> ErrorList;
76  ErrorList errors_;
77  Lock lock_;
78};
79
80SQLErrorHandlerFactory* GetErrorHandlerFactory() {
81  // TODO(cpu): Testing needs to override the error handler.
82  // Destruction of DefaultSQLErrorHandlerFactory handled by at_exit manager.
83  return Singleton<DefaultSQLErrorHandlerFactory>::get();
84}
85
86namespace sqlite_utils {
87
88int OpenSqliteDb(const FilePath& filepath, sqlite3** database) {
89#if defined(OS_WIN)
90  // We want the default encoding to always be UTF-8, so we use the
91  // 8-bit version of open().
92  return sqlite3_open(WideToUTF8(filepath.value()).c_str(), database);
93#elif defined(OS_POSIX)
94  return sqlite3_open(filepath.value().c_str(), database);
95#endif
96}
97
98bool DoesSqliteTableExist(sqlite3* db,
99                          const char* db_name,
100                          const char* table_name) {
101  // sqlite doesn't allow binding parameters as table names, so we have to
102  // manually construct the sql
103  std::string sql("SELECT name FROM ");
104  if (db_name && db_name[0]) {
105    sql.append(db_name);
106    sql.push_back('.');
107  }
108  sql.append("sqlite_master WHERE type='table' AND name=?");
109
110  SQLStatement statement;
111  if (statement.prepare(db, sql.c_str()) != SQLITE_OK)
112    return false;
113
114  if (statement.bind_text(0, table_name) != SQLITE_OK)
115    return false;
116
117  // we only care about if this matched a row, not the actual data
118  return sqlite3_step(statement.get()) == SQLITE_ROW;
119}
120
121bool DoesSqliteColumnExist(sqlite3* db,
122                           const char* database_name,
123                           const char* table_name,
124                           const char* column_name,
125                           const char* column_type) {
126  SQLStatement s;
127  std::string sql;
128  sql.append("PRAGMA ");
129  if (database_name && database_name[0]) {
130    // optional database name specified
131    sql.append(database_name);
132    sql.push_back('.');
133  }
134  sql.append("TABLE_INFO(");
135  sql.append(table_name);
136  sql.append(")");
137
138  if (s.prepare(db, sql.c_str()) != SQLITE_OK)
139    return false;
140
141  while (s.step() == SQLITE_ROW) {
142    if (!s.column_string(1).compare(column_name)) {
143      if (column_type && column_type[0])
144        return !s.column_string(2).compare(column_type);
145      return true;
146    }
147  }
148  return false;
149}
150
151bool DoesSqliteTableHaveRow(sqlite3* db, const char* table_name) {
152  SQLStatement s;
153  std::string b;
154  b.append("SELECT * FROM ");
155  b.append(table_name);
156
157  if (s.prepare(db, b.c_str()) != SQLITE_OK)
158    return false;
159
160  return s.step() == SQLITE_ROW;
161}
162
163}  // namespace sqlite_utils
164
165SQLTransaction::SQLTransaction(sqlite3* db) : db_(db), began_(false) {
166}
167
168SQLTransaction::~SQLTransaction() {
169  if (began_) {
170    Rollback();
171  }
172}
173
174int SQLTransaction::BeginCommand(const char* command) {
175  int rv = SQLITE_ERROR;
176  if (!began_ && db_) {
177    rv = sqlite3_exec(db_, command, NULL, NULL, NULL);
178    began_ = (rv == SQLITE_OK);
179  }
180  return rv;
181}
182
183int SQLTransaction::EndCommand(const char* command) {
184  int rv = SQLITE_ERROR;
185  if (began_ && db_) {
186    rv = sqlite3_exec(db_, command, NULL, NULL, NULL);
187    began_ = (rv != SQLITE_OK);
188  }
189  return rv;
190}
191
192SQLNestedTransactionSite::~SQLNestedTransactionSite() {
193  DCHECK(!top_transaction_);
194}
195
196void SQLNestedTransactionSite::SetTopTransaction(SQLNestedTransaction* top) {
197  DCHECK(!top || !top_transaction_);
198  top_transaction_ = top;
199}
200
201SQLNestedTransaction::SQLNestedTransaction(SQLNestedTransactionSite* site)
202  : SQLTransaction(site->GetSqlite3DB()),
203    needs_rollback_(false),
204    site_(site) {
205  DCHECK(site);
206  if (site->GetTopTransaction() == NULL) {
207    site->SetTopTransaction(this);
208  }
209}
210
211SQLNestedTransaction::~SQLNestedTransaction() {
212  if (began_) {
213    Rollback();
214  }
215  if (site_->GetTopTransaction() == this) {
216    site_->SetTopTransaction(NULL);
217  }
218}
219
220int SQLNestedTransaction::BeginCommand(const char* command) {
221  DCHECK(db_);
222  DCHECK(site_ && site_->GetTopTransaction());
223  if (!db_ || began_) {
224    return SQLITE_ERROR;
225  }
226  if (site_->GetTopTransaction() == this) {
227    int rv = sqlite3_exec(db_, command, NULL, NULL, NULL);
228    began_ = (rv == SQLITE_OK);
229    if (began_) {
230      site_->OnBegin();
231    }
232    return rv;
233  } else {
234    if (site_->GetTopTransaction()->needs_rollback_) {
235      return SQLITE_ERROR;
236    }
237    began_ = true;
238    return SQLITE_OK;
239  }
240}
241
242int SQLNestedTransaction::EndCommand(const char* command) {
243  DCHECK(db_);
244  DCHECK(site_ && site_->GetTopTransaction());
245  if (!db_ || !began_) {
246    return SQLITE_ERROR;
247  }
248  if (site_->GetTopTransaction() == this) {
249    if (needs_rollback_) {
250      sqlite3_exec(db_, "ROLLBACK", NULL, NULL, NULL);
251      began_ = false;  // reset so we don't try to rollback or call
252                       // OnRollback() again
253      site_->OnRollback();
254      return SQLITE_ERROR;
255    } else {
256      int rv = sqlite3_exec(db_, command, NULL, NULL, NULL);
257      began_ = (rv != SQLITE_OK);
258      if (strcmp(command, "ROLLBACK") == 0) {
259        began_ = false;  // reset so we don't try to rollbck or call
260                         // OnRollback() again
261        site_->OnRollback();
262      } else {
263        DCHECK(strcmp(command, "COMMIT") == 0);
264        if (rv == SQLITE_OK) {
265          site_->OnCommit();
266        }
267      }
268      return rv;
269    }
270  } else {
271    if (strcmp(command, "ROLLBACK") == 0) {
272      site_->GetTopTransaction()->needs_rollback_ = true;
273    }
274    began_ = false;
275    return SQLITE_OK;
276  }
277}
278
279int SQLStatement::prepare(sqlite3* db, const char* sql, int sql_len) {
280  DCHECK(!stmt_);
281  int rv = sqlite3_prepare_v2(db, sql, sql_len, &stmt_, NULL);
282  if (rv != SQLITE_OK) {
283    SQLErrorHandler* error_handler = GetErrorHandlerFactory()->Make();
284    return error_handler->HandleError(rv, db);
285  }
286  return rv;
287}
288
289int SQLStatement::step() {
290  DCHECK(stmt_);
291  int status = sqlite3_step(stmt_);
292  if ((status == SQLITE_ROW) || (status == SQLITE_DONE))
293    return status;
294  // We got a problem.
295  SQLErrorHandler* error_handler = GetErrorHandlerFactory()->Make();
296  return error_handler->HandleError(status, db_handle());
297}
298
299int SQLStatement::reset() {
300  DCHECK(stmt_);
301  return sqlite3_reset(stmt_);
302}
303
304sqlite_int64 SQLStatement::last_insert_rowid() {
305  DCHECK(stmt_);
306  return sqlite3_last_insert_rowid(db_handle());
307}
308
309int SQLStatement::changes() {
310  DCHECK(stmt_);
311  return sqlite3_changes(db_handle());
312}
313
314sqlite3* SQLStatement::db_handle() {
315  DCHECK(stmt_);
316  return sqlite3_db_handle(stmt_);
317}
318
319int SQLStatement::bind_parameter_count() {
320  DCHECK(stmt_);
321  return sqlite3_bind_parameter_count(stmt_);
322}
323
324int SQLStatement::bind_blob(int index, std::vector<unsigned char>* blob) {
325  if (blob) {
326    const void* value = blob->empty() ? NULL : &(*blob)[0];
327    int len = static_cast<int>(blob->size());
328    return bind_blob(index, value, len);
329  } else {
330    return bind_null(index);
331  }
332}
333
334int SQLStatement::bind_blob(int index, const void* value, int value_len) {
335   return bind_blob(index, value, value_len, SQLITE_TRANSIENT);
336}
337
338int SQLStatement::bind_blob(int index, const void* value, int value_len,
339                            Function dtor) {
340  DCHECK(stmt_);
341  return sqlite3_bind_blob(stmt_, index + 1, value, value_len, dtor);
342}
343
344int SQLStatement::bind_double(int index, double value) {
345  DCHECK(stmt_);
346  return sqlite3_bind_double(stmt_, index + 1, value);
347}
348
349int SQLStatement::bind_bool(int index, bool value) {
350  DCHECK(stmt_);
351  return sqlite3_bind_int(stmt_, index + 1, value);
352}
353
354int SQLStatement::bind_int(int index, int value) {
355  DCHECK(stmt_);
356  return sqlite3_bind_int(stmt_, index + 1, value);
357}
358
359int SQLStatement::bind_int64(int index, sqlite_int64 value) {
360  DCHECK(stmt_);
361  return sqlite3_bind_int64(stmt_, index + 1, value);
362}
363
364int SQLStatement::bind_null(int index) {
365  DCHECK(stmt_);
366  return sqlite3_bind_null(stmt_, index + 1);
367}
368
369int SQLStatement::bind_text(int index, const char* value, int value_len,
370              Function dtor) {
371  DCHECK(stmt_);
372  return sqlite3_bind_text(stmt_, index + 1, value, value_len, dtor);
373}
374
375int SQLStatement::bind_text16(int index, const char16* value, int value_len,
376                Function dtor) {
377  DCHECK(stmt_);
378  value_len *= sizeof(char16);
379  return sqlite3_bind_text16(stmt_, index + 1, value, value_len, dtor);
380}
381
382int SQLStatement::bind_value(int index, const sqlite3_value* value) {
383  DCHECK(stmt_);
384  return sqlite3_bind_value(stmt_, index + 1, value);
385}
386
387int SQLStatement::column_count() {
388  DCHECK(stmt_);
389  return sqlite3_column_count(stmt_);
390}
391
392int SQLStatement::column_type(int index) {
393  DCHECK(stmt_);
394  return sqlite3_column_type(stmt_, index);
395}
396
397const void* SQLStatement::column_blob(int index) {
398  DCHECK(stmt_);
399  return sqlite3_column_blob(stmt_, index);
400}
401
402bool SQLStatement::column_blob_as_vector(int index,
403                                         std::vector<unsigned char>* blob) {
404  DCHECK(stmt_);
405  const void* p = column_blob(index);
406  size_t len = column_bytes(index);
407  blob->resize(len);
408  if (blob->size() != len) {
409    return false;
410  }
411  if (len > 0)
412    memcpy(&(blob->front()), p, len);
413  return true;
414}
415
416bool SQLStatement::column_blob_as_string(int index, std::string* blob) {
417  DCHECK(stmt_);
418  const void* p = column_blob(index);
419  size_t len = column_bytes(index);
420  blob->resize(len);
421  if (blob->size() != len) {
422    return false;
423  }
424  blob->assign(reinterpret_cast<const char*>(p), len);
425  return true;
426}
427
428int SQLStatement::column_bytes(int index) {
429  DCHECK(stmt_);
430  return sqlite3_column_bytes(stmt_, index);
431}
432
433int SQLStatement::column_bytes16(int index) {
434  DCHECK(stmt_);
435  return sqlite3_column_bytes16(stmt_, index);
436}
437
438double SQLStatement::column_double(int index) {
439  DCHECK(stmt_);
440  return sqlite3_column_double(stmt_, index);
441}
442
443bool SQLStatement::column_bool(int index) {
444  DCHECK(stmt_);
445  return sqlite3_column_int(stmt_, index) ? true : false;
446}
447
448int SQLStatement::column_int(int index) {
449  DCHECK(stmt_);
450  return sqlite3_column_int(stmt_, index);
451}
452
453sqlite_int64 SQLStatement::column_int64(int index) {
454  DCHECK(stmt_);
455  return sqlite3_column_int64(stmt_, index);
456}
457
458const char* SQLStatement::column_text(int index) {
459  DCHECK(stmt_);
460  return reinterpret_cast<const char*>(sqlite3_column_text(stmt_, index));
461}
462
463bool SQLStatement::column_string(int index, std::string* str) {
464  DCHECK(stmt_);
465  DCHECK(str);
466  const char* s = column_text(index);
467  str->assign(s ? s : std::string());
468  return s != NULL;
469}
470
471std::string SQLStatement::column_string(int index) {
472  std::string str;
473  column_string(index, &str);
474  return str;
475}
476
477const char16* SQLStatement::column_text16(int index) {
478  DCHECK(stmt_);
479  return static_cast<const char16*>(sqlite3_column_text16(stmt_, index));
480}
481
482bool SQLStatement::column_string16(int index, string16* str) {
483  DCHECK(stmt_);
484  DCHECK(str);
485  const char* s = column_text(index);
486  str->assign(s ? UTF8ToUTF16(s) : string16());
487  return (s != NULL);
488}
489
490string16 SQLStatement::column_string16(int index) {
491  string16 str;
492  column_string16(index, &str);
493  return str;
494}
495
496bool SQLStatement::column_wstring(int index, std::wstring* str) {
497  DCHECK(stmt_);
498  DCHECK(str);
499  const char* s = column_text(index);
500  str->assign(s ? UTF8ToWide(s) : std::wstring());
501  return (s != NULL);
502}
503
504std::wstring SQLStatement::column_wstring(int index) {
505  std::wstring wstr;
506  column_wstring(index, &wstr);
507  return wstr;
508}
509