#include "sqlite3.hpp"

#include <QSqlDatabase>
#include <QSqlError>
#include <QSqlQuery>
#include <QStringList>
#include <QThread>
#include <QVariant>
#include <QAtomicInt>

#include <cassert>
#include <cstring>
#include <cstdio>
#include <chrono>
#include <limits>
#include <climits>
#include <type_traits>

#include <mbgl/util/chrono.hpp>
#include <mbgl/util/logging.hpp>
#include <mbgl/util/optional.hpp>
#include <mbgl/util/string.hpp>
#include <mbgl/util/traits.hpp>
23 
24 namespace mapbox {
25 namespace sqlite {
26 
27 // https://www.sqlite.org/rescode.html#ok
28 static_assert(mbgl::underlying_type(ResultCode::OK) == 0, "error");
29 // https://www.sqlite.org/rescode.html#cantopen
30 static_assert(mbgl::underlying_type(ResultCode::CantOpen) == 14, "error");
31 // https://www.sqlite.org/rescode.html#notadb
32 static_assert(mbgl::underlying_type(ResultCode::NotADB) == 26, "error");
33 
checkQueryError(const QSqlQuery & query)34 void checkQueryError(const QSqlQuery& query) {
35     QSqlError lastError = query.lastError();
36     if (lastError.type() != QSqlError::NoError) {
37 #if QT_VERSION >= 0x050300
38         throw Exception { lastError.nativeErrorCode().toInt(), lastError.text().toStdString() };
39 #else
40         throw Exception { lastError.number(), lastError.text().toStdString() };
41 #endif
42     }
43 }
44 
checkDatabaseError(const QSqlDatabase & db)45 void checkDatabaseError(const QSqlDatabase &db) {
46     QSqlError lastError = db.lastError();
47     if (lastError.type() != QSqlError::NoError) {
48 #if QT_VERSION >= 0x050300
49         throw Exception { lastError.nativeErrorCode().toInt(), lastError.text().toStdString() };
50 #else
51         throw Exception { lastError.number(), lastError.text().toStdString() };
52 #endif
53     }
54 }
55 
56 namespace {
incrementCounter()57     QString incrementCounter() {
58         static QAtomicInt count = 0;
59         return QString::number(count.fetchAndAddAcquire(1));
60     }
61 }
62 
63 class DatabaseImpl {
64 public:
DatabaseImpl(QString connectionName_)65     DatabaseImpl(QString connectionName_)
66         : connectionName(std::move(connectionName_))
67     {
68     }
69 
~DatabaseImpl()70     ~DatabaseImpl() {
71         auto db = QSqlDatabase::database(connectionName);
72         db.close();
73         checkDatabaseError(db);
74     }
75 
76     void setBusyTimeout(std::chrono::milliseconds timeout);
77     void exec(const std::string& sql);
78 
79     QString connectionName;
80 };
81 
82 class StatementImpl {
83 public:
StatementImpl(const QString & sql,const QSqlDatabase & db)84     StatementImpl(const QString& sql, const QSqlDatabase& db) : query(db) {
85         if (!query.prepare(sql)) {
86             checkQueryError(query);
87         }
88     }
89 
~StatementImpl()90     ~StatementImpl() {
91         query.clear();
92     }
93 
94     QSqlQuery query;
95     int64_t lastInsertRowId = 0;
96     int64_t changes = 0;
97 };
98 
// Local shorthand for the optional type used by the bind/get specializations below.
template <class T>
using optional = std::experimental::optional<T>;
101 
tryOpen(const std::string & filename,int flags)102 mapbox::util::variant<Database, Exception> Database::tryOpen(const std::string &filename, int flags) {
103     if (!QSqlDatabase::drivers().contains("QSQLITE")) {
104         return Exception { ResultCode::CantOpen, "SQLite driver not found." };
105     }
106 
107     QString connectionName = QString::number(uint64_t(QThread::currentThread())) + incrementCounter();
108 
109     assert(!QSqlDatabase::contains(connectionName));
110     auto db = QSqlDatabase::addDatabase("QSQLITE", connectionName);
111 
112     QString connectOptions = db.connectOptions();
113     if (flags & OpenFlag::ReadOnly) {
114         if (!connectOptions.isEmpty()) connectOptions.append(';');
115         connectOptions.append("QSQLITE_OPEN_READONLY");
116     }
117 
118     db.setConnectOptions(connectOptions);
119     db.setDatabaseName(QString(filename.c_str()));
120 
121     if (!db.open()) {
122         // Assume every error when opening the data as CANTOPEN. Qt
123         // always returns -1 for `nativeErrorCode()` on database errors.
124         return Exception { ResultCode::CantOpen, "Error opening the database." };
125     }
126 
127     return Database(std::make_unique<DatabaseImpl>(connectionName));
128 }
129 
open(const std::string & filename,int flags)130 Database Database::open(const std::string &filename, int flags) {
131     auto result = tryOpen(filename, flags);
132     if (result.is<Exception>()) {
133         throw result.get<Exception>();
134     } else {
135         return std::move(result.get<Database>());
136     }
137 }
138 
Database(std::unique_ptr<DatabaseImpl> impl_)139 Database::Database(std::unique_ptr<DatabaseImpl> impl_)
140     : impl(std::move(impl_))
141 {}
142 
Database(Database && other)143 Database::Database(Database &&other)
144         : impl(std::move(other.impl)) {
145     assert(impl);
146 }
147 
operator =(Database && other)148 Database &Database::operator=(Database &&other) {
149     std::swap(impl, other.impl);
150     assert(impl);
151     return *this;
152 }
153 
~Database()154 Database::~Database() {
155 }
156 
setBusyTimeout(std::chrono::milliseconds timeout)157 void Database::setBusyTimeout(std::chrono::milliseconds timeout) {
158     assert(impl);
159     impl->setBusyTimeout(timeout);
160 }
161 
setBusyTimeout(std::chrono::milliseconds timeout)162 void DatabaseImpl::setBusyTimeout(std::chrono::milliseconds timeout) {
163     // std::chrono::milliseconds.count() is a long and Qt will cast
164     // internally to int, so we need to make sure the limits apply.
165     std::string timeoutStr = mbgl::util::toString(timeout.count() & INT_MAX);
166 
167     auto db = QSqlDatabase::database(connectionName);
168     QString connectOptions = db.connectOptions();
169     if (connectOptions.isEmpty()) {
170         if (!connectOptions.isEmpty()) connectOptions.append(';');
171         connectOptions.append("QSQLITE_BUSY_TIMEOUT=").append(QString::fromStdString(timeoutStr));
172     }
173     if (db.isOpen()) {
174         db.close();
175     }
176     db.setConnectOptions(connectOptions);
177     if (!db.open()) {
178         // Assume every error when opening the data as CANTOPEN. Qt
179         // always returns -1 for `nativeErrorCode()` on database errors.
180         throw Exception { ResultCode::CantOpen, "Error opening the database." };
181     }
182 }
183 
exec(const std::string & sql)184 void Database::exec(const std::string &sql) {
185     assert(impl);
186     impl->exec(sql);
187 }
188 
exec(const std::string & sql)189 void DatabaseImpl::exec(const std::string& sql) {
190     QStringList statements = QString::fromStdString(sql).split(';', QString::SkipEmptyParts);
191     statements.removeAll("\n");
192     for (QString statement : statements) {
193         if (!statement.endsWith(';')) {
194             statement.append(';');
195         }
196         QSqlQuery query(QSqlDatabase::database(connectionName));
197         query.prepare(statement);
198 
199         if (!query.exec()) {
200             checkQueryError(query);
201         }
202     }
203 }
204 
Statement(Database & db,const char * sql)205 Statement::Statement(Database& db, const char* sql)
206     : impl(std::make_unique<StatementImpl>(QString(sql),
207                                            QSqlDatabase::database(db.impl->connectionName))) {
208     assert(impl);
209 }
210 
~Statement()211 Statement::~Statement() {
212 #ifndef NDEBUG
213     // Crash if we're destructing this object while we know a Query object references this.
214     assert(!used);
215 #endif
216 }
217 
Query(Statement & stmt_)218 Query::Query(Statement& stmt_) : stmt(stmt_) {
219     assert(stmt.impl);
220 
221 #ifndef NDEBUG
222     assert(!stmt.used);
223     stmt.used = true;
224 #endif
225 }
226 
~Query()227 Query::~Query() {
228     reset();
229     clearBindings();
230 
231 #ifndef NDEBUG
232     stmt.used = false;
233 #endif
234 }
235 
236 template void Query::bind(int, int64_t);
237 
238 template <typename T>
bind(int offset,T value)239 void Query::bind(int offset, T value) {
240     assert(stmt.impl);
241     // Field numbering starts at 0.
242     stmt.impl->query.bindValue(offset - 1, QVariant::fromValue<T>(value), QSql::In);
243     checkQueryError(stmt.impl->query);
244 }
245 
246 template <>
bind(int offset,std::nullptr_t)247 void Query::bind(int offset, std::nullptr_t) {
248     assert(stmt.impl);
249     // Field numbering starts at 0.
250     stmt.impl->query.bindValue(offset - 1, QVariant(QVariant::Invalid), QSql::In);
251     checkQueryError(stmt.impl->query);
252 }
253 
254 template <>
bind(int offset,int32_t value)255 void Query::bind(int offset, int32_t value) {
256     bind(offset, static_cast<int64_t>(value));
257 }
258 
259 template <>
bind(int offset,bool value)260 void Query::bind(int offset, bool value) {
261     bind(offset, static_cast<int>(value));
262 }
263 
264 template <>
bind(int offset,int8_t value)265 void Query::bind(int offset, int8_t value) {
266     bind(offset, static_cast<int64_t>(value));
267 }
268 
269 template <>
bind(int offset,uint8_t value)270 void Query::bind(int offset, uint8_t value) {
271     bind(offset, static_cast<int64_t>(value));
272 }
273 
274 template <>
bind(int offset,mbgl::Timestamp value)275 void Query::bind(int offset, mbgl::Timestamp value) {
276     bind(offset, std::chrono::system_clock::to_time_t(value));
277 }
278 
279 template <>
bind(int offset,optional<std::string> value)280 void Query::bind(int offset, optional<std::string> value) {
281     if (value) {
282         bind(offset, *value);
283     } else {
284         bind(offset, nullptr);
285     }
286 }
287 
288 template <>
bind(int offset,optional<mbgl::Timestamp> value)289 void Query::bind(int offset, optional<mbgl::Timestamp> value) {
290     if (value) {
291         bind(offset, *value);
292     } else {
293         bind(offset, nullptr);
294     }
295 }
296 
bind(int offset,const char * value,std::size_t length,bool)297 void Query::bind(int offset, const char* value, std::size_t length, bool /* retain */) {
298     assert(stmt.impl);
299     if (length > std::numeric_limits<int>::max()) {
300         // Kept for consistence with the default implementation.
301         throw std::range_error("value too long");
302     }
303 
304     // Field numbering starts at 0.
305     stmt.impl->query.bindValue(offset - 1, QString(QByteArray(value, length)), QSql::In);
306 
307     checkQueryError(stmt.impl->query);
308 }
309 
bind(int offset,const std::string & value,bool retain)310 void Query::bind(int offset, const std::string& value, bool retain) {
311     bind(offset, value.data(), value.size(), retain);
312 }
313 
bindBlob(int offset,const void * value_,std::size_t length,bool retain)314 void Query::bindBlob(int offset, const void* value_, std::size_t length, bool retain) {
315     assert(stmt.impl);
316     const char* value = reinterpret_cast<const char*>(value_);
317     if (length > std::numeric_limits<int>::max()) {
318         // Kept for consistence with the default implementation.
319         throw std::range_error("value too long");
320     }
321 
322     // Field numbering starts at 0.
323     stmt.impl->query.bindValue(offset - 1, retain ? QByteArray(value, length) :
324             QByteArray::fromRawData(value, length), QSql::In | QSql::Binary);
325 
326     checkQueryError(stmt.impl->query);
327 }
328 
bindBlob(int offset,const std::vector<uint8_t> & value,bool retain)329 void Query::bindBlob(int offset, const std::vector<uint8_t>& value, bool retain) {
330     bindBlob(offset, value.data(), value.size(), retain);
331 }
332 
run()333 bool Query::run() {
334     assert(stmt.impl);
335 
336     if (!stmt.impl->query.isValid()) {
337        if (stmt.impl->query.exec()) {
338            stmt.impl->lastInsertRowId = stmt.impl->query.lastInsertId().value<int64_t>();
339            stmt.impl->changes = stmt.impl->query.numRowsAffected();
340        } else {
341            checkQueryError(stmt.impl->query);
342        }
343     }
344 
345     const bool hasNext = stmt.impl->query.next();
346     if (!hasNext) stmt.impl->query.finish();
347 
348     return hasNext;
349 }
350 
351 template bool Query::get(int);
352 template int Query::get(int);
353 template int64_t Query::get(int);
354 template double Query::get(int);
355 
get(int offset)356 template <typename T> T Query::get(int offset) {
357     assert(stmt.impl && stmt.impl->query.isValid());
358     QVariant value = stmt.impl->query.value(offset);
359     checkQueryError(stmt.impl->query);
360     return value.value<T>();
361 }
362 
get(int offset)363 template <> std::vector<uint8_t> Query::get(int offset) {
364     assert(stmt.impl && stmt.impl->query.isValid());
365     QByteArray byteArray = stmt.impl->query.value(offset).toByteArray();
366     checkQueryError(stmt.impl->query);
367     std::vector<uint8_t> blob(byteArray.begin(), byteArray.end());
368     return blob;
369 }
370 
get(int offset)371 template <> mbgl::Timestamp Query::get(int offset) {
372     assert(stmt.impl && stmt.impl->query.isValid());
373     QVariant value = stmt.impl->query.value(offset);
374     checkQueryError(stmt.impl->query);
375     return std::chrono::time_point_cast<std::chrono::seconds>(
376         std::chrono::system_clock::from_time_t(value.value<::time_t>()));
377 }
378 
get(int offset)379 template <> optional<int64_t> Query::get(int offset) {
380     assert(stmt.impl && stmt.impl->query.isValid());
381     QVariant value = stmt.impl->query.value(offset);
382     checkQueryError(stmt.impl->query);
383     if (value.isNull())
384         return {};
385     return { value.value<int64_t>() };
386 }
387 
get(int offset)388 template <> optional<double> Query::get(int offset) {
389     assert(stmt.impl && stmt.impl->query.isValid());
390     QVariant value = stmt.impl->query.value(offset);
391     checkQueryError(stmt.impl->query);
392     if (value.isNull())
393         return {};
394     return { value.value<double>() };
395 }
396 
get(int offset)397 template <> std::string Query::get(int offset) {
398     assert(stmt.impl && stmt.impl->query.isValid());
399     QByteArray value = stmt.impl->query.value(offset).toByteArray();
400     checkQueryError(stmt.impl->query);
401     return std::string(value.constData(), value.size());
402 }
403 
get(int offset)404 template <> optional<std::string> Query::get(int offset) {
405     assert(stmt.impl && stmt.impl->query.isValid());
406     QByteArray value = stmt.impl->query.value(offset).toByteArray();
407     checkQueryError(stmt.impl->query);
408     if (value.isNull())
409         return {};
410     return { std::string(value.constData(), value.size()) };
411 }
412 
get(int offset)413 template <> optional<mbgl::Timestamp> Query::get(int offset) {
414     assert(stmt.impl && stmt.impl->query.isValid());
415     QVariant value = stmt.impl->query.value(offset);
416     checkQueryError(stmt.impl->query);
417     if (value.isNull())
418         return {};
419     return { std::chrono::time_point_cast<mbgl::Seconds>(
420         std::chrono::system_clock::from_time_t(value.value<::time_t>())) };
421 }
422 
reset()423 void Query::reset() {
424     assert(stmt.impl);
425     stmt.impl->query.finish();
426 }
427 
clearBindings()428 void Query::clearBindings() {
429     // no-op
430 }
431 
lastInsertRowId() const432 int64_t Query::lastInsertRowId() const {
433     assert(stmt.impl);
434     return stmt.impl->lastInsertRowId;
435 }
436 
changes() const437 uint64_t Query::changes() const {
438     assert(stmt.impl);
439     return (stmt.impl->changes < 0 ? 0 : stmt.impl->changes);
440 }
441 
Transaction(Database & db_,Mode mode)442 Transaction::Transaction(Database& db_, Mode mode)
443     : dbImpl(*db_.impl) {
444     switch (mode) {
445     case Deferred:
446         dbImpl.exec("BEGIN DEFERRED TRANSACTION");
447         break;
448     case Immediate:
449         dbImpl.exec("BEGIN IMMEDIATE TRANSACTION");
450         break;
451     case Exclusive:
452         dbImpl.exec("BEGIN EXCLUSIVE TRANSACTION");
453         break;
454     }
455 }
456 
~Transaction()457 Transaction::~Transaction() {
458     if (needRollback) {
459         try {
460             rollback();
461         } catch (...) {
462             // Ignore failed rollbacks in destructor.
463         }
464     }
465 }
466 
commit()467 void Transaction::commit() {
468     needRollback = false;
469     dbImpl.exec("COMMIT TRANSACTION");
470 }
471 
rollback()472 void Transaction::rollback() {
473     needRollback = false;
474     dbImpl.exec("ROLLBACK TRANSACTION");
475 }
476 
477 } // namespace sqlite
478 } // namespace mapbox
479