diff --git a/wallet/db.c b/wallet/db.c index 49fa4f92e..1b1963fd0 100644 --- a/wallet/db.c +++ b/wallet/db.c @@ -452,6 +452,24 @@ void db_set_intvar(struct db *db, char *varname, s64 val) varname, val); } +void *sqlite3_column_arr_(const tal_t *ctx, sqlite3_stmt *stmt, int col, + size_t bytes, const char *label, const char *caller) +{ + size_t sourcelen = sqlite3_column_bytes(stmt, col); + void *p; + + if (sqlite3_column_type(stmt, col) == SQLITE_NULL) + return NULL; + + if (sourcelen % bytes != 0) + fatal("%s: column size %zu not a multiple of %s (%zu)", + caller, sourcelen, label, bytes); + + p = tal_alloc_arr_(ctx, bytes, sourcelen / bytes, false, true, label); + memcpy(p, sqlite3_column_blob(stmt, col), sourcelen); + return p; +} + bool sqlite3_bind_short_channel_id(sqlite3_stmt *stmt, int col, const struct short_channel_id *id) { @@ -646,15 +664,7 @@ bool sqlite3_column_sha256_double(sqlite3_stmt *stmt, int col, struct sha256_do struct secret *sqlite3_column_secrets(const tal_t *ctx, sqlite3_stmt *stmt, int col) { - struct secret *secrets; - size_t n = sqlite3_column_bytes(stmt, col) / sizeof(*secrets); - - /* Must fit exactly */ - assert(n * sizeof(struct secret) == sqlite3_column_bytes(stmt, col)); - if (n == 0) - return NULL; - secrets = tal_arr(ctx, struct secret, n); - return memcpy(secrets, sqlite3_column_blob(stmt, col), tal_len(secrets)); + return sqlite3_column_arr(ctx, stmt, col, struct secret); } bool sqlite3_bind_sha256_double(sqlite3_stmt *stmt, int col, const struct sha256_double *p) diff --git a/wallet/db.h b/wallet/db.h index 0c8151bfb..2176b5372 100644 --- a/wallet/db.h +++ b/wallet/db.h @@ -113,6 +113,13 @@ bool db_exec_prepared_mayfail_(const char *caller, struct db *db, sqlite3_stmt *stmt); +#define sqlite3_column_arr(ctx, stmt, col, type) \ + ((type *)sqlite3_column_arr_((ctx), (stmt), (col), \ + sizeof(type), TAL_LABEL(type, "[]"), \ + __func__)) +void *sqlite3_column_arr_(const tal_t *ctx, sqlite3_stmt *stmt, int col, + size_t bytes, const char *label, const char *caller); + bool sqlite3_bind_short_channel_id(sqlite3_stmt *stmt, int col, const struct short_channel_id *id); bool sqlite3_column_short_channel_id(sqlite3_stmt *stmt, int col,