deployd/pow.py
Arija A. 90724afb2d
Add multithreading to pow.py
Signed-off-by: Arija A. <ari@ari.lt>
2025-07-27 02:05:49 +03:00

254 lines
6.4 KiB
Python

#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""proof of work"""
import concurrent.futures
import hashlib
import socket
import ssl
# import os
import struct
import sys
import threading
import time
import typing as t
from warnings import filterwarnings as filter_warnings
def print_log_line(log: bytes) -> bool:
    """Print one log packet received from the server.

    A packet whose first byte is ``0x21`` marks the end of the log stream:
    a completion message is printed and ``True`` is returned. For any other
    packet the first three bytes are skipped, the remainder is written to
    stdout, and ``False`` is returned.

    NOTE(review): the meaning of the three skipped header bytes is not
    visible here (presumably a packet type plus a length) — confirm against
    the server.

    :param log: raw bytes of one packet as read from the socket.
    :return: whether the log stream has finished.
    :raises ValueError: if ``log`` is empty (e.g. the peer closed the
        connection), instead of failing with an opaque ``IndexError``.
    """
    if not log:
        raise ValueError("empty log packet (connection closed?)")

    if log[0] == 0x21:
        sys.stdout.write("\n")
        print("[log stream finished. Ready for a new request]")
        sys.stdout.flush()
        return True

    # The log text is not guaranteed to be ASCII. Decode as UTF-8 (a superset
    # of ASCII) and replace undecodable bytes rather than crashing mid-deploy.
    sys.stdout.write(log[3:].decode("utf-8", errors="replace"))
    # Flush so streamed output appears as it arrives.
    sys.stdout.flush()
    return False
def has_leading_zero_bits(h: bytes, difficulty: int) -> bool:
    """Return whether the first ``difficulty`` bits of ``h`` are all zero.

    The whole-byte part of the prefix is compared first. Any leftover bits
    (fewer than 8) are then checked at the most-significant end of the
    following byte.
    """
    whole_bytes, extra_bits = divmod(difficulty, 8)

    # bytes(n) is a run of n zero bytes. A too-short ``h`` yields a shorter
    # slice, so the comparison fails.
    if h[:whole_bytes] != bytes(whole_bytes):
        return False

    if not extra_bits:
        return True

    # Shifting away the low (8 - extra_bits) bits leaves only the top
    # ``extra_bits`` bits, which must all be zero.
    return h[whole_bytes] >> (8 - extra_bits) == 0
def count_ones(num: int) -> int:
    """Return the number of set bits in the low byte of ``num``.

    Only the least-significant 8 bits are counted, so ``0x1FF`` yields 8
    and ``0x100`` yields 0.
    """
    low_byte: int = num & 0xFF
    return bin(low_byte).count("1")
def has_xor_ones(h: bytes, ones: int, challenge: bytes) -> bool:
    """Check the XOR requirement of the proof of work.

    For each of the 32 byte positions, the byte of ``h`` is XORed with the
    challenge byte at the same position. The challenge is reused every 16
    bytes. Positions whose XOR has at least 6 set bits (per ``count_ones``)
    are counted.

    :return: ``True`` as soon as that count reaches ``ones``; ``False`` if
        it never does (including when ``ones`` <= 0 or > 32).
    """
    matched: int = 0
    for position in range(32):
        mixed: int = challenge[position % 16] ^ h[position]
        if count_ones(mixed) < 6:
            continue
        matched += 1
        if matched == ones:
            return True
    return False
def solve_pow_worker(
    difficulty: int,
    ones: int,
    challenge: bytes,
    start_nonce: int,
    end_nonce: int,
    stop_event: threading.Event,
) -> t.Tuple[int, bool]:
    """Search ``[start_nonce, end_nonce)`` for a nonce satisfying the PoW.

    Each candidate is hashed as its decimal ASCII string with BLAKE2s keyed
    by ``challenge``. The digest must pass both ``has_leading_zero_bits``
    and ``has_xor_ones``. The scan aborts early once ``stop_event`` is set
    (e.g. another worker found a solution).

    :return: ``(nonce, True)`` for a solution, after setting ``stop_event``
        to signal other workers. ``(nonce, False)`` with the nonce at which
        the scan was interrupted. ``(end_nonce, False)`` when the range is
        exhausted.
    """
    for candidate in range(start_nonce, end_nonce):
        if stop_event.is_set():
            return candidate, False

        digest: bytes = hashlib.blake2s(
            str(candidate).encode("ascii"), key=challenge
        ).digest()

        # Cheap prefix test first; the XOR scan only runs for survivors.
        if not has_leading_zero_bits(digest, difficulty):
            continue
        if not has_xor_ones(digest, ones, challenge):
            continue

        stop_event.set()
        return candidate, True

    return end_nonce, False
def solve_pow(
    difficulty: int,
    ones: int,
    challenge: bytes,
    nonce: int = 0,
    batch: int = 2**64 - 1,
    threads: int = 8,
) -> t.Tuple[int, bool]:
    """Search ``batch`` nonces starting at ``nonce`` for a PoW solution.

    The range ``[nonce, nonce + batch)`` is split into ``threads``
    contiguous slices. Each slice is searched by ``solve_pow_worker`` in a
    thread pool, and the last slice also takes the remainder of
    ``batch // threads``. All workers share one ``threading.Event`` so they
    stop once any of them finds a solution.

    NOTE(review): the work is pure-Python hashing of tiny inputs, so under
    CPython's GIL these threads most likely do not run in parallel. A
    process pool would be needed for a real multi-core speed-up.

    :param nonce: first nonce to try.
    :param batch: number of nonces to try in this call.
    :param threads: number of worker threads (must be >= 1).
    :return: ``(solution, True)`` if one was found, otherwise
        ``(nonce + batch, False)``, i.e. the next nonce to resume from.
    :raises ValueError: if ``threads`` < 1.
    """
    if threads < 1:
        raise ValueError(f"threads must be >= 1, got {threads}")

    stop_event = threading.Event()
    batch_per_thread = batch // threads

    with concurrent.futures.ThreadPoolExecutor(max_workers=threads) as executor:
        futures = []
        for i in range(threads):
            start_nonce = nonce + i * batch_per_thread
            # The last slice also covers any leftover nonces.
            end_nonce = (
                nonce + batch
                if i == threads - 1
                else start_nonce + batch_per_thread
            )
            futures.append(
                executor.submit(
                    solve_pow_worker,
                    difficulty,
                    ones,
                    challenge,
                    start_nonce,
                    end_nonce,
                    stop_event,
                )
            )

        try:
            for future in concurrent.futures.as_completed(futures):
                found_nonce, valid = future.result()
                if valid:
                    return found_nonce, True
        finally:
            # Leaving the ``with`` block waits for every worker. If a worker
            # raised, the others would otherwise keep scanning their whole
            # slice (up to ~2**64 / threads nonces with the default batch).
            # Setting the event makes them return promptly. On success the
            # winning worker has already set it, so this is a no-op there.
            stop_event.set()

    return nonce + batch, False
def read_server_info(socket: ssl.SSLSocket) -> None:
    """Read the server's greeting and print it.

    Wire format as read here: one version byte, one length byte ``n``,
    then ``n`` bytes of server information.

    :param socket: connected TLS socket. The name shadows the ``socket``
        module inside this function; it is kept for caller compatibility.
    :raises ConnectionError: if the peer closes the connection before the
        full greeting arrives.
    """

    def read_exact(count: int) -> bytes:
        # SSLSocket.read(n) returns *up to* n bytes, so loop until the whole
        # field has arrived. An empty read means the peer closed the stream.
        chunks: t.List[bytes] = []
        remaining: int = count
        while remaining > 0:
            chunk: bytes = socket.read(remaining)
            if not chunk:
                raise ConnectionError("connection closed while reading server info")
            chunks.append(chunk)
            remaining -= len(chunk)
        return b"".join(chunks)

    version: int = read_exact(1)[0]
    size: int = read_exact(1)[0]
    print("Server info:", read_exact(size), "version", version)
def do_pow(socket: ssl.SSLSocket) -> None:
    """Run the proof-of-work exchange with the server.

    1. Read a 16-byte challenge, a difficulty byte and a "ones" byte.
    2. Search nonces in batches of 512 * 1024. Before each batch, send a
       ``0x10`` ping and require a ``0x11`` reply.
    3. Once solved, send ``0x13`` followed by the nonce as a little-endian
       u64, and require ``0x12`` back.

    :raises ValueError: if the ping is not acknowledged or the solution is
        rejected. The socket is closed first.
    :raises ConnectionError: if the peer closes the connection mid-exchange.
        The socket is closed first.
    """

    def read_exact(count: int) -> bytes:
        # SSLSocket.read(n) may return fewer than n bytes, so loop until the
        # whole field is in. An empty read means the connection was closed.
        chunks: t.List[bytes] = []
        remaining: int = count
        while remaining > 0:
            chunk: bytes = socket.read(remaining)
            if not chunk:
                socket.close()
                raise ConnectionError("connection closed during proof-of-work")
            chunks.append(chunk)
            remaining -= len(chunk)
        return b"".join(chunks)

    challenge: bytes = read_exact(16)
    difficulty: int = read_exact(1)[0]
    ones: int = read_exact(1)[0]
    print(f"Calculating PoW with difficulty {difficulty} and {ones} ones")
    print(f"Challenge: {challenge.hex()}")

    nonce, solved = 0, False
    sys.stdout.write("Nonce progress: 0\r")
    sys.stdout.flush()
    while not solved:
        socket.sendall(b"\x10")
        if read_exact(1) != b"\x11":
            socket.close()
            raise ValueError("Something bad happened with PoW ping")
        nonce, solved = solve_pow(difficulty, ones, challenge, nonce, 512 * 1024)
        # "\r" does not flush a line-buffered stdout, so flush explicitly to
        # keep the progress counter visible while the search runs.
        sys.stdout.write(f"Nonce progress: {nonce}" + (" " * 20) + "\r")
        sys.stdout.flush()
    sys.stdout.write("\n")
    print(f"\rPoW solved: {nonce}" + (" " * 20))

    socket.sendall(b"\x13")
    socket.sendall(struct.pack("<Q", nonce))
    if read_exact(1) != b"\x12":
        socket.close()
        raise ValueError("We were not allowed")
def gen_token(secret: bytes, timestamp: int) -> bytes:
    """Derive a 16-byte keyed BLAKE2s token for the current 300 s window.

    The window index is the number of whole 300-second periods elapsed
    since ``timestamp``, truncated toward zero. It is packed as a
    little-endian u64 and hashed with BLAKE2s using ``secret`` as the key.
    """
    elapsed: float = time.time() - timestamp
    window_index: int = int(elapsed / 300)
    message: bytes = struct.pack("<Q", window_index)
    keyed_hash = hashlib.blake2s(message, digest_size=16, key=secret)
    return keyed_hash.digest()
def trigger_deploy(
    socket: ssl.SSLSocket,
    *,
    domain: str = "ari.lt",
    key: bytes = b"Mjhu8E1-WDvRf7mchNl30z9TQ2lDk4_6",
    secret: bytes = b"po6Dpd6rijdbvSbmEZMd-ZbX",
    timestamp: int = 1753303723,
) -> None:
    """Send a TRIGGER command and stream the deploy log to stdout.

    Packet layout as built here:

    - proto byte ``0x00`` (COMMAND)
    - command byte ``0x00`` (TRIGGER)
    - ``is_unsafe`` byte ``0x00``
    - ID ``1`` as a little-endian u64
    - domain length byte, then the ASCII domain
    - ``key``
    - the 16-byte token from ``gen_token(secret, timestamp)``

    SECURITY(review): the default ``key`` and ``secret`` are credentials
    hard-coded in source. Prefer passing them in (e.g. from the environment)
    and rotating the committed values.

    :raises ValueError: if the encoded ``domain`` is longer than 255 bytes
        (it must fit the one-byte length field).
    :raises ConnectionError: if the server closes the connection before
        signalling the end of the log stream.
    """
    domain_bytes: bytes = domain.encode("ascii")
    if len(domain_bytes) > 0xFF:
        raise ValueError("domain is too long for the one-byte length field")

    header: bytes = struct.pack(
        "<BBBQB",
        0x00,  # proto COMMAND
        0x00,  # command TRIGGER
        0x00,  # is_unsafe = false
        0x01,  # ID
        len(domain_bytes),  # domain_len
    )
    packet: bytes = b"".join(
        (header, domain_bytes, key, gen_token(secret, timestamp))
    )
    # sendall retries until the whole packet is written.
    socket.sendall(packet)

    # NOTE(review): this assumes each read returns exactly one log packet.
    # TLS is a byte stream, so packets may be split or coalesced — confirm the
    # server's framing (the 3 header bytes print_log_line skips may carry a
    # length).
    while True:
        log: bytes = socket.read(1024)
        if not log:
            raise ConnectionError("connection closed before the log stream finished")
        if print_log_line(log):
            break
def exit_packet(socket: ssl.SSLSocket) -> None:
    """Tell the server the session is over by sending the ``0x30`` byte."""
    exit_byte: bytes = bytes((0x30,))
    socket.send(exit_byte)
def main() -> int:
    """Entry point: run one deploy session against the local daemon.

    Connects over TLS to 127.0.0.1:1234, then in order:

    1. reads the server greeting;
    2. performs the proof-of-work exchange;
    3. triggers a deploy and streams its log;
    4. sends the exit packet.

    :return: process exit code (0 on success).
    """
    context: ssl.SSLContext = ssl.create_default_context()
    # SECURITY(review): certificate and hostname verification are disabled.
    # This is only acceptable for a trusted local daemon (presumably using a
    # self-signed certificate); do not reuse this context for remote hosts.
    context.check_hostname = False
    context.verify_mode = ssl.CERT_NONE

    # The ``with`` blocks close the sockets even if any step below raises.
    # wrap_socket detaches ``sock``, so closing it afterwards is harmless.
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
        with context.wrap_socket(sock, server_hostname="127.0.0.1") as ssl_sock:
            # Set the timeout before connecting so the TCP connect and the TLS
            # handshake cannot block forever either.
            ssl_sock.settimeout(5)
            ssl_sock.connect(("127.0.0.1", 1234))
            read_server_info(ssl_sock)
            do_pow(ssl_sock)
            trigger_deploy(ssl_sock)
            exit_packet(ssl_sock)

    return 0
if __name__ == "__main__":
    # Developer sanity check that main() is annotated to return an int exit
    # code. Like any assert, it is skipped when Python runs with -O.
    assert main.__annotations__.get("return") is int, "main() should return an integer"
    # Turn every warning into an exception so problems surface immediately.
    filter_warnings("error", category=Warning)
    exit_status: int = main()
    raise SystemExit(exit_status)