1 #include "sqlite3.hpp"
2
3 #include <QSqlDatabase>
4 #include <QSqlError>
5 #include <QSqlQuery>
6 #include <QStringList>
7 #include <QThread>
8 #include <QVariant>
9 #include <QAtomicInt>
10
11 #include <cassert>
12 #include <cstring>
13 #include <cstdio>
14 #include <chrono>
15 #include <limits>
16 #include <climits>
17
18 #include <mbgl/util/chrono.hpp>
19 #include <mbgl/util/logging.hpp>
20 #include <mbgl/util/optional.hpp>
21 #include <mbgl/util/string.hpp>
22 #include <mbgl/util/traits.hpp>
23
24 namespace mapbox {
25 namespace sqlite {
26
27 // https://www.sqlite.org/rescode.html#ok
28 static_assert(mbgl::underlying_type(ResultCode::OK) == 0, "error");
29 // https://www.sqlite.org/rescode.html#cantopen
30 static_assert(mbgl::underlying_type(ResultCode::CantOpen) == 14, "error");
31 // https://www.sqlite.org/rescode.html#notadb
32 static_assert(mbgl::underlying_type(ResultCode::NotADB) == 26, "error");
33
checkQueryError(const QSqlQuery & query)34 void checkQueryError(const QSqlQuery& query) {
35 QSqlError lastError = query.lastError();
36 if (lastError.type() != QSqlError::NoError) {
37 #if QT_VERSION >= 0x050300
38 throw Exception { lastError.nativeErrorCode().toInt(), lastError.text().toStdString() };
39 #else
40 throw Exception { lastError.number(), lastError.text().toStdString() };
41 #endif
42 }
43 }
44
checkDatabaseError(const QSqlDatabase & db)45 void checkDatabaseError(const QSqlDatabase &db) {
46 QSqlError lastError = db.lastError();
47 if (lastError.type() != QSqlError::NoError) {
48 #if QT_VERSION >= 0x050300
49 throw Exception { lastError.nativeErrorCode().toInt(), lastError.text().toStdString() };
50 #else
51 throw Exception { lastError.number(), lastError.text().toStdString() };
52 #endif
53 }
54 }
55
56 namespace {
incrementCounter()57 QString incrementCounter() {
58 static QAtomicInt count = 0;
59 return QString::number(count.fetchAndAddAcquire(1));
60 }
61 }
62
63 class DatabaseImpl {
64 public:
DatabaseImpl(QString connectionName_)65 DatabaseImpl(QString connectionName_)
66 : connectionName(std::move(connectionName_))
67 {
68 }
69
~DatabaseImpl()70 ~DatabaseImpl() {
71 auto db = QSqlDatabase::database(connectionName);
72 db.close();
73 checkDatabaseError(db);
74 }
75
76 void setBusyTimeout(std::chrono::milliseconds timeout);
77 void exec(const std::string& sql);
78
79 QString connectionName;
80 };
81
82 class StatementImpl {
83 public:
StatementImpl(const QString & sql,const QSqlDatabase & db)84 StatementImpl(const QString& sql, const QSqlDatabase& db) : query(db) {
85 if (!query.prepare(sql)) {
86 checkQueryError(query);
87 }
88 }
89
~StatementImpl()90 ~StatementImpl() {
91 query.clear();
92 }
93
94 QSqlQuery query;
95 int64_t lastInsertRowId = 0;
96 int64_t changes = 0;
97 };
98
99 template <typename T>
100 using optional = std::experimental::optional<T>;
101
tryOpen(const std::string & filename,int flags)102 mapbox::util::variant<Database, Exception> Database::tryOpen(const std::string &filename, int flags) {
103 if (!QSqlDatabase::drivers().contains("QSQLITE")) {
104 return Exception { ResultCode::CantOpen, "SQLite driver not found." };
105 }
106
107 QString connectionName = QString::number(uint64_t(QThread::currentThread())) + incrementCounter();
108
109 assert(!QSqlDatabase::contains(connectionName));
110 auto db = QSqlDatabase::addDatabase("QSQLITE", connectionName);
111
112 QString connectOptions = db.connectOptions();
113 if (flags & OpenFlag::ReadOnly) {
114 if (!connectOptions.isEmpty()) connectOptions.append(';');
115 connectOptions.append("QSQLITE_OPEN_READONLY");
116 }
117
118 db.setConnectOptions(connectOptions);
119 db.setDatabaseName(QString(filename.c_str()));
120
121 if (!db.open()) {
122 // Assume every error when opening the data as CANTOPEN. Qt
123 // always returns -1 for `nativeErrorCode()` on database errors.
124 return Exception { ResultCode::CantOpen, "Error opening the database." };
125 }
126
127 return Database(std::make_unique<DatabaseImpl>(connectionName));
128 }
129
open(const std::string & filename,int flags)130 Database Database::open(const std::string &filename, int flags) {
131 auto result = tryOpen(filename, flags);
132 if (result.is<Exception>()) {
133 throw result.get<Exception>();
134 } else {
135 return std::move(result.get<Database>());
136 }
137 }
138
Database(std::unique_ptr<DatabaseImpl> impl_)139 Database::Database(std::unique_ptr<DatabaseImpl> impl_)
140 : impl(std::move(impl_))
141 {}
142
Database(Database && other)143 Database::Database(Database &&other)
144 : impl(std::move(other.impl)) {
145 assert(impl);
146 }
147
operator =(Database && other)148 Database &Database::operator=(Database &&other) {
149 std::swap(impl, other.impl);
150 assert(impl);
151 return *this;
152 }
153
~Database()154 Database::~Database() {
155 }
156
setBusyTimeout(std::chrono::milliseconds timeout)157 void Database::setBusyTimeout(std::chrono::milliseconds timeout) {
158 assert(impl);
159 impl->setBusyTimeout(timeout);
160 }
161
setBusyTimeout(std::chrono::milliseconds timeout)162 void DatabaseImpl::setBusyTimeout(std::chrono::milliseconds timeout) {
163 // std::chrono::milliseconds.count() is a long and Qt will cast
164 // internally to int, so we need to make sure the limits apply.
165 std::string timeoutStr = mbgl::util::toString(timeout.count() & INT_MAX);
166
167 auto db = QSqlDatabase::database(connectionName);
168 QString connectOptions = db.connectOptions();
169 if (connectOptions.isEmpty()) {
170 if (!connectOptions.isEmpty()) connectOptions.append(';');
171 connectOptions.append("QSQLITE_BUSY_TIMEOUT=").append(QString::fromStdString(timeoutStr));
172 }
173 if (db.isOpen()) {
174 db.close();
175 }
176 db.setConnectOptions(connectOptions);
177 if (!db.open()) {
178 // Assume every error when opening the data as CANTOPEN. Qt
179 // always returns -1 for `nativeErrorCode()` on database errors.
180 throw Exception { ResultCode::CantOpen, "Error opening the database." };
181 }
182 }
183
exec(const std::string & sql)184 void Database::exec(const std::string &sql) {
185 assert(impl);
186 impl->exec(sql);
187 }
188
exec(const std::string & sql)189 void DatabaseImpl::exec(const std::string& sql) {
190 QStringList statements = QString::fromStdString(sql).split(';', QString::SkipEmptyParts);
191 statements.removeAll("\n");
192 for (QString statement : statements) {
193 if (!statement.endsWith(';')) {
194 statement.append(';');
195 }
196 QSqlQuery query(QSqlDatabase::database(connectionName));
197 query.prepare(statement);
198
199 if (!query.exec()) {
200 checkQueryError(query);
201 }
202 }
203 }
204
Statement(Database & db,const char * sql)205 Statement::Statement(Database& db, const char* sql)
206 : impl(std::make_unique<StatementImpl>(QString(sql),
207 QSqlDatabase::database(db.impl->connectionName))) {
208 assert(impl);
209 }
210
~Statement()211 Statement::~Statement() {
212 #ifndef NDEBUG
213 // Crash if we're destructing this object while we know a Query object references this.
214 assert(!used);
215 #endif
216 }
217
Query(Statement & stmt_)218 Query::Query(Statement& stmt_) : stmt(stmt_) {
219 assert(stmt.impl);
220
221 #ifndef NDEBUG
222 assert(!stmt.used);
223 stmt.used = true;
224 #endif
225 }
226
~Query()227 Query::~Query() {
228 reset();
229 clearBindings();
230
231 #ifndef NDEBUG
232 stmt.used = false;
233 #endif
234 }
235
236 template void Query::bind(int, int64_t);
237
238 template <typename T>
bind(int offset,T value)239 void Query::bind(int offset, T value) {
240 assert(stmt.impl);
241 // Field numbering starts at 0.
242 stmt.impl->query.bindValue(offset - 1, QVariant::fromValue<T>(value), QSql::In);
243 checkQueryError(stmt.impl->query);
244 }
245
246 template <>
bind(int offset,std::nullptr_t)247 void Query::bind(int offset, std::nullptr_t) {
248 assert(stmt.impl);
249 // Field numbering starts at 0.
250 stmt.impl->query.bindValue(offset - 1, QVariant(QVariant::Invalid), QSql::In);
251 checkQueryError(stmt.impl->query);
252 }
253
254 template <>
bind(int offset,int32_t value)255 void Query::bind(int offset, int32_t value) {
256 bind(offset, static_cast<int64_t>(value));
257 }
258
259 template <>
bind(int offset,bool value)260 void Query::bind(int offset, bool value) {
261 bind(offset, static_cast<int>(value));
262 }
263
264 template <>
bind(int offset,int8_t value)265 void Query::bind(int offset, int8_t value) {
266 bind(offset, static_cast<int64_t>(value));
267 }
268
269 template <>
bind(int offset,uint8_t value)270 void Query::bind(int offset, uint8_t value) {
271 bind(offset, static_cast<int64_t>(value));
272 }
273
274 template <>
bind(int offset,mbgl::Timestamp value)275 void Query::bind(int offset, mbgl::Timestamp value) {
276 bind(offset, std::chrono::system_clock::to_time_t(value));
277 }
278
279 template <>
bind(int offset,optional<std::string> value)280 void Query::bind(int offset, optional<std::string> value) {
281 if (value) {
282 bind(offset, *value);
283 } else {
284 bind(offset, nullptr);
285 }
286 }
287
288 template <>
bind(int offset,optional<mbgl::Timestamp> value)289 void Query::bind(int offset, optional<mbgl::Timestamp> value) {
290 if (value) {
291 bind(offset, *value);
292 } else {
293 bind(offset, nullptr);
294 }
295 }
296
bind(int offset,const char * value,std::size_t length,bool)297 void Query::bind(int offset, const char* value, std::size_t length, bool /* retain */) {
298 assert(stmt.impl);
299 if (length > std::numeric_limits<int>::max()) {
300 // Kept for consistence with the default implementation.
301 throw std::range_error("value too long");
302 }
303
304 // Field numbering starts at 0.
305 stmt.impl->query.bindValue(offset - 1, QString(QByteArray(value, length)), QSql::In);
306
307 checkQueryError(stmt.impl->query);
308 }
309
bind(int offset,const std::string & value,bool retain)310 void Query::bind(int offset, const std::string& value, bool retain) {
311 bind(offset, value.data(), value.size(), retain);
312 }
313
bindBlob(int offset,const void * value_,std::size_t length,bool retain)314 void Query::bindBlob(int offset, const void* value_, std::size_t length, bool retain) {
315 assert(stmt.impl);
316 const char* value = reinterpret_cast<const char*>(value_);
317 if (length > std::numeric_limits<int>::max()) {
318 // Kept for consistence with the default implementation.
319 throw std::range_error("value too long");
320 }
321
322 // Field numbering starts at 0.
323 stmt.impl->query.bindValue(offset - 1, retain ? QByteArray(value, length) :
324 QByteArray::fromRawData(value, length), QSql::In | QSql::Binary);
325
326 checkQueryError(stmt.impl->query);
327 }
328
bindBlob(int offset,const std::vector<uint8_t> & value,bool retain)329 void Query::bindBlob(int offset, const std::vector<uint8_t>& value, bool retain) {
330 bindBlob(offset, value.data(), value.size(), retain);
331 }
332
run()333 bool Query::run() {
334 assert(stmt.impl);
335
336 if (!stmt.impl->query.isValid()) {
337 if (stmt.impl->query.exec()) {
338 stmt.impl->lastInsertRowId = stmt.impl->query.lastInsertId().value<int64_t>();
339 stmt.impl->changes = stmt.impl->query.numRowsAffected();
340 } else {
341 checkQueryError(stmt.impl->query);
342 }
343 }
344
345 const bool hasNext = stmt.impl->query.next();
346 if (!hasNext) stmt.impl->query.finish();
347
348 return hasNext;
349 }
350
351 template bool Query::get(int);
352 template int Query::get(int);
353 template int64_t Query::get(int);
354 template double Query::get(int);
355
get(int offset)356 template <typename T> T Query::get(int offset) {
357 assert(stmt.impl && stmt.impl->query.isValid());
358 QVariant value = stmt.impl->query.value(offset);
359 checkQueryError(stmt.impl->query);
360 return value.value<T>();
361 }
362
get(int offset)363 template <> std::vector<uint8_t> Query::get(int offset) {
364 assert(stmt.impl && stmt.impl->query.isValid());
365 QByteArray byteArray = stmt.impl->query.value(offset).toByteArray();
366 checkQueryError(stmt.impl->query);
367 std::vector<uint8_t> blob(byteArray.begin(), byteArray.end());
368 return blob;
369 }
370
get(int offset)371 template <> mbgl::Timestamp Query::get(int offset) {
372 assert(stmt.impl && stmt.impl->query.isValid());
373 QVariant value = stmt.impl->query.value(offset);
374 checkQueryError(stmt.impl->query);
375 return std::chrono::time_point_cast<std::chrono::seconds>(
376 std::chrono::system_clock::from_time_t(value.value<::time_t>()));
377 }
378
get(int offset)379 template <> optional<int64_t> Query::get(int offset) {
380 assert(stmt.impl && stmt.impl->query.isValid());
381 QVariant value = stmt.impl->query.value(offset);
382 checkQueryError(stmt.impl->query);
383 if (value.isNull())
384 return {};
385 return { value.value<int64_t>() };
386 }
387
get(int offset)388 template <> optional<double> Query::get(int offset) {
389 assert(stmt.impl && stmt.impl->query.isValid());
390 QVariant value = stmt.impl->query.value(offset);
391 checkQueryError(stmt.impl->query);
392 if (value.isNull())
393 return {};
394 return { value.value<double>() };
395 }
396
get(int offset)397 template <> std::string Query::get(int offset) {
398 assert(stmt.impl && stmt.impl->query.isValid());
399 QByteArray value = stmt.impl->query.value(offset).toByteArray();
400 checkQueryError(stmt.impl->query);
401 return std::string(value.constData(), value.size());
402 }
403
get(int offset)404 template <> optional<std::string> Query::get(int offset) {
405 assert(stmt.impl && stmt.impl->query.isValid());
406 QByteArray value = stmt.impl->query.value(offset).toByteArray();
407 checkQueryError(stmt.impl->query);
408 if (value.isNull())
409 return {};
410 return { std::string(value.constData(), value.size()) };
411 }
412
get(int offset)413 template <> optional<mbgl::Timestamp> Query::get(int offset) {
414 assert(stmt.impl && stmt.impl->query.isValid());
415 QVariant value = stmt.impl->query.value(offset);
416 checkQueryError(stmt.impl->query);
417 if (value.isNull())
418 return {};
419 return { std::chrono::time_point_cast<mbgl::Seconds>(
420 std::chrono::system_clock::from_time_t(value.value<::time_t>())) };
421 }
422
reset()423 void Query::reset() {
424 assert(stmt.impl);
425 stmt.impl->query.finish();
426 }
427
clearBindings()428 void Query::clearBindings() {
429 // no-op
430 }
431
lastInsertRowId() const432 int64_t Query::lastInsertRowId() const {
433 assert(stmt.impl);
434 return stmt.impl->lastInsertRowId;
435 }
436
changes() const437 uint64_t Query::changes() const {
438 assert(stmt.impl);
439 return (stmt.impl->changes < 0 ? 0 : stmt.impl->changes);
440 }
441
Transaction(Database & db_,Mode mode)442 Transaction::Transaction(Database& db_, Mode mode)
443 : dbImpl(*db_.impl) {
444 switch (mode) {
445 case Deferred:
446 dbImpl.exec("BEGIN DEFERRED TRANSACTION");
447 break;
448 case Immediate:
449 dbImpl.exec("BEGIN IMMEDIATE TRANSACTION");
450 break;
451 case Exclusive:
452 dbImpl.exec("BEGIN EXCLUSIVE TRANSACTION");
453 break;
454 }
455 }
456
~Transaction()457 Transaction::~Transaction() {
458 if (needRollback) {
459 try {
460 rollback();
461 } catch (...) {
462 // Ignore failed rollbacks in destructor.
463 }
464 }
465 }
466
commit()467 void Transaction::commit() {
468 needRollback = false;
469 dbImpl.exec("COMMIT TRANSACTION");
470 }
471
rollback()472 void Transaction::rollback() {
473 needRollback = false;
474 dbImpl.exec("ROLLBACK TRANSACTION");
475 }
476
477 } // namespace sqlite
478 } // namespace mapbox
479