deployd/server/proto.c
Arija A. d1d0aec439
Add dproto documentation
Signed-off-by: Arija A. <ari@ari.lt>
2025-07-24 23:55:42 +03:00

378 lines
12 KiB
C

#include "include/conf.h"
#include "include/auth.h"
#include "include/proto.h"
#include "include/deploys.h"
#include <string.h>
/*
 * Map a dp_PacketError code to a short, static, human-readable description.
 *
 * Never returns NULL: codes not handled by the switch fall through to
 * "unknown error". The switch deliberately has no `default` so that the
 * compiler can warn when a new dp_PacketError value is added but not
 * described here.
 */
const char *dp_PacketError_to_str(dp_PacketError code) {
    switch (code) {
    case dp_PacketError_internal: return "internal";
    case dp_PacketError_type: return "invalid packet type";
    case dp_PacketError_status: return "invalid/unexpected status packet";
    case dp_PacketError_auth_token:
        return "invalid/expired authentication token";
    case dp_PacketError_auth_key: return "invalid authentication key";
    case dp_PacketError_proto_packet_too_short:
        return "protocol packet too short";
    case dp_PacketError_proto_domain_invalid:
        /* Raised for both an empty and an over-long domain, so do not
         * describe it as only "too short". */
        return "protocol domain is invalid";
    case dp_PacketError_proto_packet_too_long:
        return "protocol packet too long";
    case dp_PacketError_proto_domain_not_found:
        return "protocol domain not found";
    case dp_PacketError_proto_packet_invalid:
        return "protocol packet is invalid";
    case dp_PacketError_pow_too_many_pings:
        return "too many Proof-of-Work pings";
    case dp_PacketError_pow_bad_solution:
        return "bad Proof-of-Work solution";
    case dp_PacketError_deploy_error: return "deploy";
    case dp_PacketError_invalid_command: return "invalid deploy command";
    }
    return "unknown error";
}
/*
 * Handle a COMMAND packet: parse the payload, authenticate the caller
 * against the `secrets` table, and on success run the requested deploy
 * command via dp_trigger_online_deploy().
 *
 * `packet` / `bytes` is the COMMAND payload, laid out as:
 *   command    (1 byte, a dp_DeployCommand value)
 *   is_unsafe  (1 byte, 0x01 means true, anything else false)
 *   id         (8 bytes, little-endian 64-bit unsigned integer)
 *   domain_len (1 byte, 1..DP_DOMAIN_LEN)
 *   domain     (domain_len bytes, not NUL-terminated)
 *   key        (DP_AUTH_KEY_SIZE bytes, always static)
 *   token      (DP_AUTH_TOKEN_SIZE bytes, changing)
 * The payload must be exactly this long; trailing bytes are rejected.
 *
 * Returns false immediately (without replying) on NULL arguments or a
 * non-positive length. Every other failure sends a dp_proto_error() to the
 * peer before returning false. Returns true only if the deploy succeeded.
 */
bool dp_proto_command(SSL *ssl,
                      const uint8_t *packet,
                      int bytes,
                      sqlite3 *database,
                      dp_Logger *logger) {
    /* Reject negative lengths too: a negative `bytes` could be converted to
     * a huge unsigned value in the size checks below, and `packet + bytes`
     * would point before the buffer. */
    if (!ssl || !packet || bytes <= 0 || !database || !logger) {
        return false;
    }

    /* Minimum size assumes the shortest legal domain (1 byte). */
    if (bytes < (1 + 1 + 8 + 1 + 1 + DP_AUTH_KEY_SIZE + DP_AUTH_TOKEN_SIZE)) {
        dp_proto_error(ssl, dp_PacketError_proto_packet_too_short,
                       "Supplied packet is too short for a COMMAND packet.",
                       logger);
        return false;
    }

    const uint8_t *end = packet + bytes;

    /* 1. Read command (1 byte) */
    const dp_DeployCommand command = (dp_DeployCommand)*packet;
    ++packet;

    switch (command) {
    case dp_DeployCommand_trigger:
    case dp_DeployCommand_teardown:
    case dp_DeployCommand_deploy:
    case dp_DeployCommand_rollback:
    case dp_DeployCommand_cleanup:
    case dp_DeployCommand_restart:
    case dp_DeployCommand_sysadmin:
    case dp_DeployCommand_logs: break;
    default:
        dp_proto_error(ssl, dp_PacketError_invalid_command,
                       "No such command found", logger);
        return false;
    }

    /* 2. Read is_unsafe (1 byte); only exactly 0x01 enables it */
    const bool is_unsafe = (*packet == 0x01 ? true : false);
    ++packet;

    /* 3. Read token ID (8 bytes, little endian) */
    const uint64_t tid = dp_buf2u64_le(packet);
    packet += 8;

    /* 4. Read domain_len (1 byte) */
    const uint8_t domain_len = *packet;
    packet += 1;

    /* Check domain_len validity and total length */
    if (domain_len == 0 || domain_len > DP_DOMAIN_LEN) {
        dp_proto_error(ssl, dp_PacketError_proto_domain_invalid,
                       "Supplied domain is invalid", logger);
        return false;
    }

    if (bytes <
        1 + 1 + 8 + 1 + domain_len + DP_AUTH_KEY_SIZE + DP_AUTH_TOKEN_SIZE) {
        dp_proto_error(ssl, dp_PacketError_proto_packet_too_short,
                       "Supplied packet is too short for the claimed "
                       "domain_len in COMMAND.",
                       logger);
        return false;
    }

    /* 5. Read domain string (domain_len bytes) */
    const uint8_t *domain = packet;
    packet += domain_len;

    /* 6. Read key (DP_AUTH_KEY_SIZE bytes) */
    const uint8_t *key = packet;
    packet += DP_AUTH_KEY_SIZE;

    /* 7. Read token (DP_AUTH_TOKEN_SIZE bytes) */
    const uint8_t *token = packet;
    packet += DP_AUTH_TOKEN_SIZE;

    /* The length checks above guarantee packet <= end here; anything left
     * over means the sender appended extra bytes. */
    if (packet != end) {
        dp_proto_error(ssl, dp_PacketError_proto_packet_too_long,
                       "Supplied packet is too long.", logger);
        return false;
    }

    sqlite3_stmt *stmt = NULL;
    const char *sql =
        "SELECT key_hash, key_salt, secret, timestamp FROM secrets WHERE id "
        "= ? AND domain = ?";

    if (sqlite3_prepare_v2(database, sql, -1, &stmt, NULL) != SQLITE_OK) {
        dp_logf(logger, DP_LOG_ERROR, DP_PROTO_LOG "/trigger",
                "Failed to prepare statement: %s", sqlite3_errmsg(database));
        dp_proto_error(ssl, dp_PacketError_internal,
                       "Failed to set up SQL statement", logger);
        return false;
    }

    sqlite3_bind_int64(stmt, 1, (sqlite3_int64)tid);
    /* domain is not NUL-terminated, so bind with an explicit length and let
     * SQLite copy it. */
    sqlite3_bind_text(stmt, 2, (const char *)domain, domain_len,
                      SQLITE_TRANSIENT);

    if (sqlite3_step(stmt) != SQLITE_ROW) {
        dp_proto_error(ssl, dp_PacketError_proto_domain_not_found,
                       "Domain not found", logger);
        goto error;
    }

    /* These pointers stay valid only until the statement is stepped again
     * or finalized, so every use happens before sqlite3_finalize(). */
    const uint8_t *key_hash = sqlite3_column_blob(stmt, 0);
    const uint8_t *key_salt = sqlite3_column_blob(stmt, 1);
    const uint8_t *secret = sqlite3_column_text(stmt, 2);
    const uint64_t timestamp = (uint64_t)sqlite3_column_int64(stmt, 3);

    /* SQLite returns NULL for NULL or zero-length columns (and on OOM);
     * never hand such pointers to the auth helpers. */
    if (!key_hash || !key_salt || !secret) {
        dp_logf(logger, DP_LOG_ERROR, DP_PROTO_LOG "/trigger",
                "Secret row for token ID %llu has missing fields",
                (unsigned long long)tid);
        dp_proto_error(ssl, dp_PacketError_internal,
                       "Stored secret is incomplete", logger);
        goto error;
    }
    /* NOTE(review): blob lengths are not checked against the sizes that
     * dp_BLAKE2s_check_password_hash reads; confirm they are enforced when
     * rows are inserted. */

    if (!dp_auth_verify_token(secret, timestamp, token)) {
        dp_proto_error(ssl, dp_PacketError_auth_token,
                       "Invalid or expired token", logger);
        goto error;
    }

    if (!dp_BLAKE2s_check_password_hash(key, DP_AUTH_KEY_SIZE, key_hash,
                                        key_salt)) {
        dp_proto_error(ssl, dp_PacketError_auth_key, "Invalid key", logger);
        goto error;
    }

    if (!dp_trigger_online_deploy(ssl, (const char *)domain, domain_len, logger,
                                  command, is_unsafe)) {
        dp_proto_error(ssl, dp_PacketError_deploy_error,
                       "Online deploy failed!", logger);
        goto error;
    }

    sqlite3_finalize(stmt);
    return true;

error:
    sqlite3_finalize(stmt);
    return false;
}
/*
 * Log an error locally and report it to the peer as an error packet.
 *
 * Wire layout:
 *   type     (1 byte, dp_PacketType_error)
 *   msg_len  (2 bytes, little endian)
 *   code     (2 bytes, little endian)
 *   message  (msg_len bytes, no terminator)
 * Messages longer than the space left in DP_PACKET_SIZE are truncated.
 *
 * Returns false without sending if any argument is NULL. Returns false
 * after logging (but without sending) if the message is empty or longer
 * than UINT16_MAX. Otherwise returns whether SSL_write accepted the whole
 * packet.
 */
bool dp_proto_error(SSL *ssl,
                    dp_PacketError code,
                    const char *msg,
                    dp_Logger *logger) {
    enum { DP_ERR_HEADER_LEN = 1 + 2 + 2 };

    if (ssl == NULL || msg == NULL || logger == NULL) {
        return false;
    }

    dp_logf(logger, DP_LOG_ERROR, DP_PROTO_LOG "/error", "%s (%s)", msg,
            dp_PacketError_to_str(code));

    const size_t full_len = strlen(msg);
    if (full_len == 0 || full_len > UINT16_MAX) {
        return false;
    }

    const size_t body_len = DP_MIN(DP_PACKET_SIZE - DP_ERR_HEADER_LEN, full_len);

    uint8_t packet[DP_PACKET_SIZE] = {0};
    packet[0] = (uint8_t)dp_PacketType_error;
    packet[1] = (uint8_t)(body_len & 0xFFu);
    packet[2] = (uint8_t)((body_len >> 8) & 0xFFu);
    packet[3] = (uint8_t)((unsigned)code & 0xFFu);
    packet[4] = (uint8_t)(((unsigned)code >> 8) & 0xFFu);
    memcpy(packet + DP_ERR_HEADER_LEN, msg, body_len);

    const int total_len = DP_ERR_HEADER_LEN + (int)body_len;
    return total_len == SSL_write(ssl, packet, total_len);
}
/*
 * Stream `chunk` (chunk_size bytes) to the peer as a series of log packets.
 *
 * Each packet is:
 *   type        (1 byte, dp_PacketType_log)
 *   payload_len (2 bytes, little endian)
 *   payload     (payload_len bytes)
 * Each payload is at most DP_PACKET_SIZE - 3 bytes and never more than
 * UINT16_MAX, because the length field is only 16 bits wide.
 *
 * Returns false on a NULL ssl/chunk, an empty chunk, or if any SSL_write
 * does not accept a whole packet. Packets already sent before a failure
 * are not retracted.
 */
bool dp_proto_log_chunker(SSL *ssl,
                          const char *chunk,
                          size_t chunk_size,
                          dp_Logger *logger) {
    if (!ssl || !chunk || chunk_size == 0) {
        return false;
    }

    /* Never emit a payload the 16-bit length header cannot describe;
     * otherwise the header would be silently truncated and the peer would
     * lose framing. */
    const size_t max_payload_size =
        DP_MIN((size_t)DP_PACKET_SIZE - 3, (size_t)UINT16_MAX);
    size_t offset = 0;
    uint8_t buffer[DP_PACKET_SIZE] = {0};

    dp_log(logger, DP_LOG_DEBUG, DP_PROTO_LOG "/log_chunker",
           "[begin log stream]");

    while (chunk_size > 0) {
        const size_t to_send =
            (chunk_size > max_payload_size) ? max_payload_size : chunk_size;

        buffer[0] = (uint8_t)dp_PacketType_log;
        buffer[1] = (uint8_t)(to_send & 0xFFu);
        buffer[2] = (uint8_t)((to_send >> 8) & 0xFFu);
        memcpy(buffer + 3, chunk + offset, to_send);

        dp_logf(logger, DP_LOG_DEBUG, DP_PROTO_LOG "/log_chunker",
                "Sending %zu log chunk to client", to_send);

        const int packet_size = 3 + (int)to_send;
        if (SSL_write(ssl, buffer, packet_size) != packet_size) {
            return false;
        }

        offset += to_send;
        chunk_size -= to_send;
    }

    return true;
}
/*
 * Mark the end of a log stream by sending the single byte
 * dp_PacketType_logs_end.
 *
 * Returns false if ssl is NULL (nothing is logged or sent in that case).
 * Otherwise returns whether SSL_write accepted the byte.
 */
bool dp_proto_log_end(SSL *ssl, dp_Logger *logger) {
    static const uint8_t end_marker = (uint8_t)dp_PacketType_logs_end;

    if (ssl == NULL) {
        return false;
    }

    dp_log(logger, DP_LOG_DEBUG, DP_PROTO_LOG "/log_chunker",
           "[end log stream]");

    return 1 == SSL_write(ssl, &end_marker, 1);
}
/*
 * Send the single-byte dp_PacketType_exit packet.
 * Returns false if ssl is NULL or the byte could not be written.
 */
bool dp_proto_exit(SSL *ssl) {
    static const uint8_t exit_byte = (uint8_t)dp_PacketType_exit;

    return ssl != NULL && SSL_write(ssl, &exit_byte, 1) == 1;
}
/*
 * Send the single-byte dp_PacketType_ping packet.
 * Returns false if ssl is NULL or the byte could not be written.
 */
bool dp_proto_ping(SSL *ssl) {
    static const uint8_t ping_byte = (uint8_t)dp_PacketType_ping;

    if (ssl == NULL) {
        return false;
    }
    return SSL_write(ssl, &ping_byte, 1) == 1;
}
/*
 * Answer a ping by sending the single-byte dp_PacketType_ping_reply packet.
 * Returns false if ssl is NULL or the byte could not be written.
 */
bool dp_proto_ping_reply(SSL *ssl) {
    static const uint8_t reply_byte = (uint8_t)dp_PacketType_ping_reply;

    if (ssl == NULL) {
        return false;
    }
    return SSL_write(ssl, &reply_byte, 1) == 1;
}
/*
 * Send the single-byte dp_PacketType_ready packet.
 * Returns false if ssl is NULL or the byte could not be written.
 */
bool dp_proto_ready(SSL *ssl) {
    static const uint8_t ready_byte = (uint8_t)dp_PacketType_ready;

    if (ssl == NULL) {
        return false;
    }
    return SSL_write(ssl, &ready_byte, 1) == 1;
}
/*
 * Read exactly `len` bytes from `ssl` into `buf`.
 *
 * SSL_read may legitimately return fewer bytes than requested (e.g. when the
 * data spans TLS records), so this loops until `len` bytes have arrived.
 * Returns false on EOF or any read error.
 */
static bool dp_proto_skip_read_exact(SSL *ssl, uint8_t *buf, size_t len) {
    size_t got = 0;
    while (got < len) {
        const int n = SSL_read(ssl, buf + got, (int)(len - got));
        if (n <= 0) {
            return false;
        }
        got += (size_t)n;
    }
    return true;
}

/*
 * Consume and discard the payload of a packet whose type byte (`type`) has
 * already been read, keeping the stream framed for the next packet.
 *
 * Returns false if ssl is NULL, `type` is not a known packet type, or the
 * connection fails / closes before the whole payload has been read.
 */
bool dp_proto_skip(SSL *ssl, dp_PacketType type) {
    if (!ssl) {
        return false;
    }

    size_t to_read = 0;
    uint8_t buf[DP_PACKET_SIZE] = {0};

    switch (type) {
    case dp_PacketType_command:
        /*
         * Payload (same layout dp_proto_command parses):
         *   command    (1 byte)
         *   is_unsafe  (1 byte)
         *   id         (8 bytes, little-endian 64 bit unsigned integer)
         *   domain_len (1 byte, value - up to DP_DOMAIN_LEN)
         *   domain     (domain_len bytes)
         *   key        (DP_AUTH_KEY_SIZE bytes, always static)
         *   token      (DP_AUTH_TOKEN_SIZE bytes, changing)
         * The fixed prefix must include the command byte, otherwise the
         * wrong byte is taken as domain_len and the stream desyncs.
         */
        if (!dp_proto_skip_read_exact(ssl, buf, 1 + 1 + 8 + 1)) {
            return false;
        }
        to_read = (size_t)buf[1 + 1 + 8] + DP_AUTH_KEY_SIZE +
                  DP_AUTH_TOKEN_SIZE;
        break;
    case dp_PacketType_ping:
    case dp_PacketType_ping_reply:
    case dp_PacketType_allowed:
    case dp_PacketType_ready: break;
    case dp_PacketType_log:
        /*
         * Payload:
         *   buffer_size (2 bytes, little endian)
         *   buffer      (buffer_size bytes)
         */
        if (!dp_proto_skip_read_exact(ssl, buf, 2)) {
            return false;
        }
        to_read =
            (uint16_t)buf[0] | (uint16_t)((uint16_t)buf[1] << (uint16_t)8);
        break;
    case dp_PacketType_logs_end:
    case dp_PacketType_exit: break;
    case dp_PacketType_error:
        /*
         * Payload:
         *   msg_len    (2 bytes, little endian)
         *   error_code (2 bytes)
         *   message    (msg_len bytes)
         */
        if (!dp_proto_skip_read_exact(ssl, buf, 2)) {
            return false;
        }
        to_read =
            (uint16_t)buf[0] | (uint16_t)((uint16_t)buf[1] << (uint16_t)8);
        to_read += 2; /* for error_code */
        break;
    default: return false;
    }

    /* Drain the variable-length remainder in buffer-sized pieces. */
    while (to_read > 0) {
        const int bytes = SSL_read(ssl, buf, (int)DP_MIN(to_read, sizeof(buf)));
        if (bytes <= 0) {
            return false;
        }
        to_read -= (size_t)bytes;
    }

    return true;
}