diff --git a/wallet/wallet.c b/wallet/wallet.c index 962bcf911..b6ec8da26 100644 --- a/wallet/wallet.c +++ b/wallet/wallet.c @@ -1601,7 +1601,7 @@ static struct channel *wallet_stmt2channel(struct wallet *w, struct db_stmt *stm chan = new_channel(peer, db_col_u64(stmt, "id"), &wshachain, - db_col_int(stmt, "state"), + channel_state_in_db(db_col_int(stmt, "state")), db_col_int(stmt, "funder"), NULL, /* Set up fresh log */ "Loaded from database", @@ -2168,7 +2168,7 @@ void wallet_channel_save(struct wallet *w, struct channel *chan) db_bind_null(stmt); db_bind_channel_id(stmt, &chan->cid); - db_bind_int(stmt, chan->state); + db_bind_int(stmt, channel_state_in_db(chan->state)); db_bind_int(stmt, chan->opener); db_bind_int(stmt, chan->channel_flags); db_bind_int(stmt, chan->minimum_depth); @@ -2352,8 +2352,8 @@ void wallet_state_change_add(struct wallet *w, db_bind_u64(stmt, channel_id); db_bind_timeabs(stmt, *timestamp); - db_bind_int(stmt, old_state); - db_bind_int(stmt, new_state); + db_bind_int(stmt, channel_state_in_db(old_state)); + db_bind_int(stmt, channel_state_in_db(new_state)); db_bind_int(stmt, state_change_in_db(cause)); db_bind_text(stmt, message); @@ -2525,7 +2525,7 @@ void wallet_channel_close(struct wallet *w, u64 wallet_id) stmt = db_prepare_v2(w->db, SQL("UPDATE channels " "SET state=? " "WHERE channels.id=?")); - db_bind_u64(stmt, CLOSED); + db_bind_u64(stmt, channel_state_in_db(CLOSED)); db_bind_u64(stmt, wallet_id); db_exec_prepared_v2(take(stmt)); } diff --git a/wallet/wallet.h b/wallet/wallet.h index b0f2c1a9a..5dfce3c38 100644 --- a/wallet/wallet.h +++ b/wallet/wallet.h @@ -9,6 +9,7 @@ #include #include #include +#include #include #include #include @@ -23,8 +24,6 @@ struct node_id; struct oneshot; struct peer; struct timers; -enum channel_state; -enum state_change; struct wallet { struct lightningd *ld; @@ -273,6 +272,50 @@ static inline enum htlc_state htlc_state_in_db(enum htlc_state s) fatal("%s: %u is invalid", __func__, s); } +/* DB wrapper to check channel_state */ +static inline enum channel_state channel_state_in_db(enum channel_state s) +{ + switch (s) { + case CHANNELD_AWAITING_LOCKIN: + BUILD_ASSERT(CHANNELD_AWAITING_LOCKIN == 2); + return s; + case CHANNELD_NORMAL: + BUILD_ASSERT(CHANNELD_NORMAL == 3); + return s; + case CHANNELD_SHUTTING_DOWN: + BUILD_ASSERT(CHANNELD_SHUTTING_DOWN == 4); + return s; + case CLOSINGD_SIGEXCHANGE: + BUILD_ASSERT(CLOSINGD_SIGEXCHANGE == 5); + return s; + case CLOSINGD_COMPLETE: + BUILD_ASSERT(CLOSINGD_COMPLETE == 6); + return s; + case AWAITING_UNILATERAL: + BUILD_ASSERT(AWAITING_UNILATERAL == 7); + return s; + case FUNDING_SPEND_SEEN: + BUILD_ASSERT(FUNDING_SPEND_SEEN == 8); + return s; + case ONCHAIN: + BUILD_ASSERT(ONCHAIN == 9); + return s; + case CLOSED: + BUILD_ASSERT(CLOSED == 10); + return s; + case DUALOPEND_OPEN_INIT: + BUILD_ASSERT(DUALOPEND_OPEN_INIT == 11); + return s; + case DUALOPEND_AWAITING_LOCKIN: + BUILD_ASSERT(DUALOPEND_AWAITING_LOCKIN == 12); + return s; + case CHANNELD_AWAITING_SPLICE: + BUILD_ASSERT(CHANNELD_AWAITING_SPLICE == 13); + return s; + } + fatal("%s: %u is invalid", __func__, s); +} + struct forwarding { /* channel_out is all-zero if unknown. */ struct short_channel_id channel_in, channel_out;