added them for testing

This commit is contained in:
Jimmy Song 2023-10-17 21:11:32 -05:00
parent 994358c5ea
commit 8fb382449c
5 changed files with 221 additions and 20 deletions

View File

@ -1046,3 +1046,184 @@ True
#endexercise
"""
from unittest import TestCase
import hash
import op
from ecc import (
G,
N,
PrivateKey,
S256Point,
SchnorrSignature,
)
from hash import (
hash_aux,
hash_challenge,
hash_nonce,
hash_tapbranch,
hash_tapleaf,
hash_taptweak,
sha256,
)
from helper import (
big_endian_to_int,
int_to_big_endian,
int_to_byte,
xor_bytes,
)
from op import decode_num, encode_num
from taproot import TapLeaf, TapBranch, ControlBlock
def tagged_hash(tag, msg):
tag_hash = sha256(tag)
return sha256(tag_hash + tag_hash + msg)
def xonly(self):
if self.x is None:
return b"\x00" * 32
return int_to_big_endian(self.x.num, 32)
def verify_schnorr(self, msg, schnorr_sig):
if self.parity:
point = -1 * self
else:
point = self
if schnorr_sig.r.x is None:
return False
message = schnorr_sig.r.xonly() + point.xonly() + msg
challenge = big_endian_to_int(hash_challenge(message)) % N
result = -challenge * point + schnorr_sig.s * G
if result.x is None:
return False
if result.parity:
return False
return result.xonly() == schnorr_sig.r.xonly()
def sign_schnorr(self, msg, aux=None):
if aux is None:
aux = b"\x00" * 32
if self.point.parity:
d = N - self.secret
else:
d = self.secret
if len(msg) != 32:
raise ValueError("msg needs to be 32 bytes")
if len(aux) != 32:
raise ValueError("aux needs to be 32 bytes")
t = xor_bytes(int_to_big_endian(d, 32), hash_aux(aux))
k = big_endian_to_int(hash_nonce(t + self.point.xonly() + msg)) % N
r = k * G
if r.parity:
k = N - k
r = k * G
commitment = r.xonly() + self.point.xonly() + msg
e = big_endian_to_int(hash_challenge(commitment)) % N
s = (k + e * d) % N
sig = SchnorrSignature(r, s)
if not self.point.verify_schnorr(msg, sig):
raise RuntimeError("Bad Signature")
return sig
def tweak(self, merkle_root=b""):
tweak = hash_taptweak(self.xonly() + merkle_root)
return tweak
def tweaked_key(self, merkle_root=b""):
tweak = self.tweak(merkle_root)
t = big_endian_to_int(tweak)
external_key = self + t * G
return external_key
def p2tr_script(self, merkle_root=b""):
from script import P2TRScriptPubKey
external_pubkey = self.tweaked_key(merkle_root)
return P2TRScriptPubKey(external_pubkey)
def tweaked_key_priv(self, merkle_root=b""):
tweak = self.point.tweak(merkle_root)
t = big_endian_to_int(tweak)
new_secret = (self.secret + t) % N
return self.__class__(new_secret)
def op_checksigadd_schnorr(stack, tx_obj, input_index):
if len(stack) < 3:
return False
pubkey = stack.pop()
n = decode_num(stack.pop())
signature = stack.pop()
point = S256Point.parse_xonly(pubkey)
if len(signature) == 0:
stack.append(encode_num(n))
return True
if len(signature) == 65:
hash_type = signature[-1]
signature = signature[:-1]
else:
hash_type = None
sig = SchnorrSignature.parse(signature)
msg = tx_obj.sig_hash(input_index, hash_type)
if point.verify_schnorr(msg, sig):
stack.append(encode_num(n + 1))
else:
stack.append(encode_num(n))
return True
def hash_leaf(self):
content = int_to_byte(self.tapleaf_version) + self.tap_script.serialize()
return hash_tapleaf(content)
def hash_branch(self):
left_hash = self.left.hash()
right_hash = self.right.hash()
if left_hash < right_hash:
return hash_tapbranch(left_hash + right_hash)
else:
return hash_tapbranch(right_hash + left_hash)
def merkle_root(self, tap_script):
leaf = TapLeaf(tap_script, self.tapleaf_version)
current = leaf.hash()
for h in self.hashes:
if current < h:
current = hash_tapbranch(current + h)
else:
current = hash_tapbranch(h + current)
return current
def external_pubkey(self, tap_script):
merkle_root = self.merkle_root(tap_script)
return self.internal_pubkey.tweaked_key(merkle_root)
class ATest(TestCase):
def test_apply(self):
hash.tagged_hash = tagged_hash
S256Point.xonly = xonly
S256Point.verify_schnorr = verify_schnorr
PrivateKey.sign_schnorr = sign_schnorr
S256Point.tweak = tweak
S256Point.tweaked_key = tweaked_key
S256Point.p2tr_script = p2tr_script
PrivateKey.tweaked_key = tweaked_key_priv
op.op_checksigadd_schnorr = op_checksigadd_schnorr
TapLeaf.hash = hash_leaf
TapBranch.hash = hash_branch
ControlBlock.merkle_root = merkle_root
ControlBlock.external_pubkey = external_pubkey

View File

@ -278,13 +278,13 @@ class S256Point(Point):
def tweak(self, merkle_root=b""):
"""returns the tweak for use in p2tr if there's no script path"""
# take the hash_taptweak of the xonly
# take the hash_taptweak of the xonly and the merkle root
tweak = hash_taptweak(self.xonly() + merkle_root)
return tweak
def tweaked_key(self, merkle_root=b""):
"""Creates the tweaked external key for a particular tweak."""
# Get the tweak from the tweak method
# Get the tweak from the merkle root
tweak = self.tweak(merkle_root)
# t is the tweak interpreted as big endian
t = big_endian_to_int(tweak)
@ -368,19 +368,27 @@ class S256Point(Point):
return self.verify(z, sig)
def verify_schnorr(self, msg, schnorr_sig):
# define point as self if it's even, -1 * self if odd
if self.parity:
point = -1 * self
else:
point = self
# if the sig's R is the point at infinity, return False
if schnorr_sig.r.x is None:
return False
message = schnorr_sig.r.xonly() + point.xonly() + msg
challenge = big_endian_to_int(hash_challenge(message)) % N
result = -challenge * point + schnorr_sig.s
# commitment is R||P||m use the xonly serializations
commitment = schnorr_sig.r.xonly() + point.xonly() + msg
# hash_challenge the commitment and interpret as big endian modded by N
challenge = big_endian_to_int(hash_challenge(commitment)) % N
# -hP+sG is what we want
result = -challenge * point + schnorr_sig.s * G
# make sure the resulting point is not the point at infinity
if result.x is None:
return False
# make sure the resulting point is not odd
if result.parity:
return False
# check that the xonly of the result is the same as the xonly of R
return result.xonly() == schnorr_sig.r.xonly()
@classmethod

View File

@ -78,8 +78,10 @@ class TapBranch:
self._leaves = None
def hash(self):
# get the left and right hashes
left_hash = self.left.hash()
right_hash = self.right.hash()
# use hash_tapbranch on them in alphabetical order
if left_hash < right_hash:
return hash_tapbranch(left_hash + right_hash)
else:
@ -182,6 +184,11 @@ class ControlBlock:
return cls(tapleaf_version, parity, internal_pubkey, hashes)
class TapScript(ScriptPubKey):
def tap_leaf(self):
return TapLeaf(self)
class TapRootTest(TestCase):
def test_tapleaf_hash(self):
tap_script = TapScript.parse(
@ -265,8 +272,3 @@ class TapRootTest(TestCase):
)
cb = tap_root.control_block(internal_pubkey, tap_leaf_2)
self.assertEqual(cb.serialize().hex(), hex_cb)
class TapScript(ScriptPubKey):
def tap_leaf(self):
return TapLeaf(self)

View File

@ -278,13 +278,13 @@ class S256Point(Point):
def tweak(self, merkle_root=b""):
"""returns the tweak for use in p2tr if there's no script path"""
# take the hash_taptweak of the xonly
# take the hash_taptweak of the xonly and the merkle root
tweak = hash_taptweak(self.xonly() + merkle_root)
return tweak
def tweaked_key(self, merkle_root=b""):
"""Creates the tweaked external key for a particular tweak."""
# Get the tweak from the tweak method
# Get the tweak from the merkle root
tweak = self.tweak(merkle_root)
# t is the tweak interpreted as big endian
t = big_endian_to_int(tweak)
@ -368,19 +368,27 @@ class S256Point(Point):
return self.verify(z, sig)
def verify_schnorr(self, msg, schnorr_sig):
# define point as self if it's even, -1 * self if odd
if self.parity:
point = -1 * self
else:
point = self
# if the sig's R is the point at infinity, return False
if schnorr_sig.r.x is None:
return False
message = schnorr_sig.r.xonly() + point.xonly() + msg
challenge = big_endian_to_int(hash_challenge(message)) % N
result = -challenge * point + schnorr_sig.s
# commitment is R||P||m use the xonly serializations
commitment = schnorr_sig.r.xonly() + point.xonly() + msg
# hash_challenge the commitment and interpret as big endian modded by N
challenge = big_endian_to_int(hash_challenge(commitment)) % N
# -hP+sG is what we want
result = -challenge * point + schnorr_sig.s * G
# make sure the resulting point is not the point at infinity
if result.x is None:
return False
# make sure the resulting point is not odd
if result.parity:
return False
# check that the xonly of the result is the same as the xonly of R
return result.xonly() == schnorr_sig.r.xonly()
@classmethod

View File

@ -78,8 +78,10 @@ class TapBranch:
self._leaves = None
def hash(self):
# get the left and right hashes
left_hash = self.left.hash()
right_hash = self.right.hash()
# use hash_tapbranch on them in alphabetical order
if left_hash < right_hash:
return hash_tapbranch(left_hash + right_hash)
else:
@ -182,6 +184,11 @@ class ControlBlock:
return cls(tapleaf_version, parity, internal_pubkey, hashes)
class TapScript(ScriptPubKey):
def tap_leaf(self):
return TapLeaf(self)
class TapRootTest(TestCase):
def test_tapleaf_hash(self):
tap_script = TapScript.parse(
@ -265,8 +272,3 @@ class TapRootTest(TestCase):
)
cb = tap_root.control_block(internal_pubkey, tap_leaf_2)
self.assertEqual(cb.serialize().hex(), hex_cb)
class TapScript(ScriptPubKey):
def tap_leaf(self):
return TapLeaf(self)