wallet: wrap htlc_state enum in db function.

All enums in the db should be wrapped this way on reading/writing them.

Signed-off-by: Rusty Russell <rusty@rustcorp.com.au>
This commit is contained in:
Rusty Russell 2021-10-13 14:12:43 +10:30 committed by Christian Decker
parent 2ab4e5b42b
commit bdaec48400
2 changed files with 76 additions and 6 deletions

View File

@ -968,7 +968,7 @@ static struct fee_states *wallet_channel_fee_states_load(struct wallet *w,
/* Start with blank slate. */
fee_states = new_fee_states(w, opener, NULL);
while (db_step(stmt)) {
enum htlc_state hstate = db_column_int(stmt, 0);
enum htlc_state hstate = htlc_state_in_db(db_column_int(stmt, 0));
u32 feerate = db_column_int(stmt, 1);
if (fee_states->feerate[hstate] != NULL) {
@ -1004,7 +1004,7 @@ static struct height_states *wallet_channel_height_states_load(struct wallet *w,
/* Start with blank slate. */
states = new_height_states(w, opener, NULL);
while (db_step(stmt)) {
enum htlc_state hstate = db_column_int(stmt, 0);
enum htlc_state hstate = htlc_state_in_db(db_column_int(stmt, 0));
u32 blockheight = db_column_int(stmt, 1);
if (states->height[hstate] != NULL) {
@ -1936,7 +1936,7 @@ void wallet_channel_save(struct wallet *w, struct channel *chan)
stmt = db_prepare_v2(w->db, SQL("INSERT INTO channel_feerates "
" VALUES(?, ?, ?)"));
db_bind_u64(stmt, 0, chan->dbid);
db_bind_int(stmt, 1, i);
db_bind_int(stmt, 1, htlc_state_in_db(i));
db_bind_int(stmt, 2, *chan->fee_states->feerate[i]);
db_exec_prepared_v2(take(stmt));
}
@ -1955,7 +1955,7 @@ void wallet_channel_save(struct wallet *w, struct channel *chan)
stmt = db_prepare_v2(w->db, SQL("INSERT INTO channel_blockheights "
" VALUES(?, ?, ?)"));
db_bind_u64(stmt, 0, chan->dbid);
db_bind_int(stmt, 1, i);
db_bind_int(stmt, 1, htlc_state_in_db(i));
db_bind_int(stmt, 2, *chan->blockheight_states->height[i]);
db_exec_prepared_v2(take(stmt));
}
@ -2429,8 +2429,7 @@ void wallet_htlc_update(struct wallet *wallet, const u64 htlc_dbid,
"we_filled=?"
" WHERE id=?"));
/* FIXME: htlc_state_in_db */
db_bind_int(stmt, 0, new_state);
db_bind_int(stmt, 0, htlc_state_in_db(new_state));
db_bind_u64(stmt, 6, htlc_dbid);
if (payment_key)

View File

@ -159,6 +159,77 @@ static inline const char* forward_status_name(enum forward_status status)
bool string_to_forward_status(const char *status_str, enum forward_status *status);
/* DB wrapper to check htlc_state */
static inline enum htlc_state htlc_state_in_db(enum htlc_state s)
{
switch (s) {
case SENT_ADD_HTLC:
BUILD_ASSERT(SENT_ADD_HTLC == 0);
return s;
case SENT_ADD_COMMIT:
BUILD_ASSERT(SENT_ADD_COMMIT == 1);
return s;
case RCVD_ADD_REVOCATION:
BUILD_ASSERT(RCVD_ADD_REVOCATION == 2);
return s;
case RCVD_ADD_ACK_COMMIT:
BUILD_ASSERT(RCVD_ADD_ACK_COMMIT == 3);
return s;
case SENT_ADD_ACK_REVOCATION:
BUILD_ASSERT(SENT_ADD_ACK_REVOCATION == 4);
return s;
case RCVD_REMOVE_HTLC:
BUILD_ASSERT(RCVD_REMOVE_HTLC == 5);
return s;
case RCVD_REMOVE_COMMIT:
BUILD_ASSERT(RCVD_REMOVE_COMMIT == 6);
return s;
case SENT_REMOVE_REVOCATION:
BUILD_ASSERT(SENT_REMOVE_REVOCATION == 7);
return s;
case SENT_REMOVE_ACK_COMMIT:
BUILD_ASSERT(SENT_REMOVE_ACK_COMMIT == 8);
return s;
case RCVD_REMOVE_ACK_REVOCATION:
BUILD_ASSERT(RCVD_REMOVE_ACK_REVOCATION == 9);
return s;
case RCVD_ADD_HTLC:
BUILD_ASSERT(RCVD_ADD_HTLC == 10);
return s;
case RCVD_ADD_COMMIT:
BUILD_ASSERT(RCVD_ADD_COMMIT == 11);
return s;
case SENT_ADD_REVOCATION:
BUILD_ASSERT(SENT_ADD_REVOCATION == 12);
return s;
case SENT_ADD_ACK_COMMIT:
BUILD_ASSERT(SENT_ADD_ACK_COMMIT == 13);
return s;
case RCVD_ADD_ACK_REVOCATION:
BUILD_ASSERT(RCVD_ADD_ACK_REVOCATION == 14);
return s;
case SENT_REMOVE_HTLC:
BUILD_ASSERT(SENT_REMOVE_HTLC == 15);
return s;
case SENT_REMOVE_COMMIT:
BUILD_ASSERT(SENT_REMOVE_COMMIT == 16);
return s;
case RCVD_REMOVE_REVOCATION:
BUILD_ASSERT(RCVD_REMOVE_REVOCATION == 17);
return s;
case RCVD_REMOVE_ACK_COMMIT:
BUILD_ASSERT(RCVD_REMOVE_ACK_COMMIT == 18);
return s;
case SENT_REMOVE_ACK_REVOCATION:
BUILD_ASSERT(SENT_REMOVE_ACK_REVOCATION == 19);
return s;
case HTLC_STATE_INVALID:
/* Not in db! */
break;
}
fatal("%s: %u is invalid", __func__, s);
}
struct forwarding {
struct short_channel_id channel_in, channel_out;
struct amount_msat msat_in, msat_out, fee;