8 Apr, 2025
Please don't rely on this yet, but here is the bare minimum I needed to connect to my hosting provider's PostgreSQL instance and call a function from a CGI script I was writing. It uses only the Python standard library:
import base64
import hashlib
import hmac
import os
import socket
import struct
import sys
def error(*k, **p):
    """
    Print the given arguments to standard error.

    Accepts exactly what print() accepts. Any ``file`` keyword supplied by the
    caller is overridden, so output always lands on sys.stderr.
    """
    print(*k, **{**p, 'file': sys.stderr})
# --- Utility Functions for Message I/O ---
def send_startup_message(sock, user, database, extra_params=None):
    """
    Send the StartupMessage that opens a protocol 3.0 session.

    Besides ``user`` and ``database`` this always requests
    ``client_encoding=UTF8``. The rest of this module encodes query text and
    decodes results as UTF-8, so the server must use that encoding rather than
    the database's default one.

    :param sock: connected (optionally TLS-wrapped) socket.
    :param user: role name to log in as.
    :param database: database to connect to.
    :param extra_params: optional mapping of additional run-time parameters,
        e.g. ``{"application_name": "cgi"}``. Entries override the defaults above.
    """
    protocol_version = 196608  # 3 << 16: protocol major 3, minor 0
    params = {"user": user, "database": database, "client_encoding": "UTF8"}
    if extra_params:
        params.update(extra_params)
    # Each parameter is a NUL-terminated name followed by a NUL-terminated value;
    # a single extra NUL ends the list.
    body = struct.pack("!I", protocol_version)
    body += b"".join(
        key.encode("utf-8") + b"\x00" + value.encode("utf-8") + b"\x00"
        for key, value in params.items()
    )
    body += b"\x00"
    # The startup packet has no type byte; its length prefix counts itself.
    sock.sendall(struct.pack("!I", len(body) + 4) + body)
def _recv_exact(sock, n, what="message"):
    """Read exactly n bytes from sock, raising ConnectionError if the peer closes first."""
    buf = bytearray()
    while len(buf) < n:
        chunk = sock.recv(n - len(buf))
        if not chunk:
            raise ConnectionError(f"Incomplete {what}: got {len(buf)} of {n} bytes")
        buf += chunk
    return bytes(buf)


def read_message(sock):
    """
    Read one complete backend message from the server.

    Each message is laid out as:
    - 1 byte: message type
    - 4 bytes: message length (counts these 4 bytes but not the type byte)
    - remaining: payload

    socket.recv() may return fewer bytes than requested, so the length and
    payload are read in a loop until complete.

    :returns: ``(msg_type, payload)`` where msg_type is a one-character str,
        or ``(None, None)`` if the connection was closed cleanly before a new
        message started.
    :raises ConnectionError: if the connection closes in the middle of a message.
    :raises Exception: if the length field is smaller than its own 4 bytes.
    """
    type_byte = sock.recv(1)
    if not type_byte:
        return None, None
    msg_type = type_byte.decode('ascii')
    (length,) = struct.unpack("!I", _recv_exact(sock, 4, "message length"))
    if length < 4:
        raise Exception(f"Invalid message length {length} for message type {msg_type!r}")
    return msg_type, _recv_exact(sock, length - 4, "message payload")
def parse_error(payload):
    """
    Extract the human-readable text from an ErrorResponse payload.

    The payload is a run of NUL-separated fields, each beginning with a
    one-byte field code. The first field whose code is 'M' supplies the
    message, decoded as UTF-8 with undecodable bytes replaced.
    Returns "Unknown error" when no such field exists.
    """
    message = next(
        (field[1:] for field in payload.split(b'\x00') if field.startswith(b'M')),
        None,
    )
    if message is None:
        return "Unknown error"
    return message.decode('utf-8', 'replace')
def upgrade_to_ssl(sock, server_hostname=None, ca_file=None):
    """
    Negotiate TLS on a freshly opened connection and return the wrapped socket.

    Sends PostgreSQL's SSLRequest packet and reads the one-byte reply.
    On b'S' the socket is wrapped using ``ssl.create_default_context(cafile=ca_file)``,
    which verifies both the certificate and the hostname, so ``server_hostname``
    must be given. Any other reply is an error: this function never falls back
    to an unencrypted connection.

    :param sock: connected TCP socket on which nothing has been sent yet.
    :param server_hostname: name the server certificate must match.
    :param ca_file: PEM file of trusted CA certificates; None uses the
        system's default trust store.
    :returns: the ``ssl.SSLSocket`` wrapping ``sock``.
    :raises ConnectionError: if the server closes the connection instead of answering.
    :raises Exception: if the server declines SSL ('N') or sends anything unexpected.
    """
    import ssl
    ssl_request_code = 80877103  # (1234 << 16) | 5679, PostgreSQL's SSLRequest code
    sock.sendall(struct.pack("!II", 8, ssl_request_code))
    response = sock.recv(1)
    if response == b'S':
        # create_default_context() enables certificate and hostname verification.
        context = ssl.create_default_context(cafile=ca_file)
        return context.wrap_socket(sock, server_hostname=server_hostname)
    if not response:
        raise ConnectionError("Server closed the connection in response to SSLRequest")
    if response == b'N':
        raise Exception("SSL required but server does not support SSL")
    raise Exception(f"Unexpected response to SSLRequest: {response!r}")
# --- SCRAM-SHA-256 Authentication Implementation ---
def _scram_attributes(message):
    """
    Split a SCRAM message such as "r=...,s=...,i=4096" into {attribute: value}.

    Items without "=" are skipped. Only the first "=" separates name from value.
    """
    attrs = {}
    for item in message.split(','):
        name, sep, value = item.partition('=')
        if sep:
            attrs[name] = value
    return attrs


def _read_sasl_message(sock, expected_code, label):
    """
    Read one Authentication ('R') message and return its SASL data as text.

    Raises Exception in each of these cases:
    - an ErrorResponse (the server's message is included);
    - any other message type, including ReadyForQuery or a closed connection;
    - a truncated payload;
    - an auth code other than ``expected_code``.
    """
    msg_type, payload = read_message(sock)
    if msg_type == 'E':
        raise Exception("Server sent error during SCRAM auth: " + parse_error(payload))
    if msg_type != 'R':
        raise Exception(f"Expected {label} message from server, got {msg_type!r} instead")
    if len(payload) < 4:
        raise Exception(f"Truncated {label} message from server")
    auth_code = struct.unpack("!I", payload[:4])[0]
    if auth_code != expected_code:
        raise Exception(f"Expected auth code {expected_code} but got {auth_code}")
    return payload[4:].decode('utf-8')


def scram_authenticate(sock, username, password):
    """
    Perform the client side of a SCRAM-SHA-256 exchange.

    Used after the server requests SASL authentication (auth code 10). Steps:
      1. Send SASLInitialResponse with the client-first-message.
      2. Receive the server-first-message (auth code 11).
      3. Derive the keys and compute the client proof.
      4. Send SASLResponse with the client-final-message.
      5. Receive the server-final-message (auth code 12) and verify the server
         signature, which proves the server also knows the password.

    :param sock: connected socket, positioned just after the SASL request.
    :param username: user name placed in the client-first-message.
    :param password: plaintext password, encoded as UTF-8 (no SASLprep
        normalization is applied).
    :raises Exception: on a server error, a protocol violation, or a server
        signature that does not match.
    """
    # ---- Phase 1: build and send the client-first-message ----
    client_nonce = base64.b64encode(os.urandom(18)).decode('ascii')
    gs2_header = "n,,"  # "n": no channel binding; empty authorization identity
    # "=" and "," delimit SCRAM attributes, so they are escaped in the user name.
    user_escaped = username.replace("=", "=3D").replace(",", "=2C")
    client_first_message_bare = f"n={user_escaped},r={client_nonce}"
    client_first_bytes = (gs2_header + client_first_message_bare).encode('utf-8')
    # SASLInitialResponse ('p'): NUL-terminated mechanism name, then the
    # length-prefixed initial response.
    body = b"SCRAM-SHA-256\x00" + struct.pack("!I", len(client_first_bytes)) + client_first_bytes
    sock.sendall(b'p' + struct.pack("!I", 4 + len(body)) + body)

    # ---- Phase 2: server-first-message "r=<nonce>,s=<salt>,i=<iterations>" ----
    server_first_message = _read_sasl_message(sock, 11, "AuthenticationSASLContinue")
    first = _scram_attributes(server_first_message)
    if not all(attr in first for attr in ('r', 's', 'i')):
        raise Exception("Invalid server-first-message: " + server_first_message)
    server_nonce = first['r']
    if not server_nonce.startswith(client_nonce):
        # The server's nonce must extend the one we sent.
        raise Exception("Server nonce does not begin with client nonce")
    salt = base64.b64decode(first['s'])
    iterations = int(first['i'])

    # ---- Phase 3: derive keys and compute the client proof ----
    salted_password = hashlib.pbkdf2_hmac(
        "sha256", password.encode('utf-8'), salt, iterations, dklen=32
    )
    client_key = hmac.new(salted_password, b"Client Key", hashlib.sha256).digest()
    stored_key = hashlib.sha256(client_key).digest()
    # c= carries the base64 GS2 header back to the server.
    channel_binding = base64.b64encode(gs2_header.encode('utf-8')).decode('ascii')
    client_final_without_proof = f"c={channel_binding},r={server_nonce}"
    auth_message = ",".join(
        (client_first_message_bare, server_first_message, client_final_without_proof)
    ).encode('utf-8')
    client_signature = hmac.new(stored_key, auth_message, hashlib.sha256).digest()
    proof_bytes = bytes(k ^ s for k, s in zip(client_key, client_signature))
    client_proof = base64.b64encode(proof_bytes).decode('ascii')
    client_final_message = f"{client_final_without_proof},p={client_proof}"

    # ---- Phase 4: SASLResponse ('p') carries the client-final-message as-is ----
    final_bytes = client_final_message.encode('utf-8')
    sock.sendall(b'p' + struct.pack("!I", 4 + len(final_bytes)) + final_bytes)

    # ---- Phase 5: server-final-message "v=<signature>" or "e=<error>" ----
    server_final_message = _read_sasl_message(sock, 12, "AuthenticationSASLFinal")
    final = _scram_attributes(server_final_message)
    if 'e' in final:
        raise Exception("Server rejected SCRAM authentication: " + final['e'])
    if 'v' not in final:
        raise Exception("Server did not send a verifier in final message: " + server_final_message)
    server_key = hmac.new(salted_password, b"Server Key", hashlib.sha256).digest()
    expected_signature = hmac.new(server_key, auth_message, hashlib.sha256).digest()
    expected_verifier = base64.b64encode(expected_signature)
    # Constant-time comparison of the base64 forms of the signature.
    if not hmac.compare_digest(final['v'].encode('utf-8'), expected_verifier):
        raise Exception("Server signature does not match!")
def read_startup_response(sock, username, password):
    """
    Process the server's startup replies until it reports ReadyForQuery ('Z').

    Auth code 0 (AuthenticationOk) is simply acknowledged. Auth code 10 hands
    off to scram_authenticate(). Any other auth code, an ErrorResponse, or a
    closed connection raises. Other message types (e.g. BackendKeyData, 'K')
    are skipped.
    """
    while True:
        kind, body = read_message(sock)
        if kind is None:
            raise Exception("No response from server during startup")
        if kind == 'Z':
            # The backend is idle and ready to accept queries.
            return
        if kind == 'E':
            raise Exception("Startup error: " + parse_error(body))
        if kind != 'R':
            continue
        code = struct.unpack("!I", body[:4])[0]
        if code == 10:
            scram_authenticate(sock, username, password)
        elif code != 0:
            raise Exception(f"Unsupported authentication (code: {code})")
# --- Extended query protocol: Parse / Bind / Execute / Sync and result reading ---
def send_parse(sock, query, param_types=(114,)):
    """
    Send a Parse message that prepares ``query`` as the unnamed statement.

    :param sock: connected socket.
    :param query: SQL text; parameters are referenced as $1, $2, ...
    :param param_types: type OIDs to declare for the parameters, in order.
        Defaults to a single JSON parameter (OID 114). A 0 entry leaves that
        parameter's type for the server to infer.
    """
    statement_name = b''  # empty name selects the unnamed statement
    body = statement_name + b'\x00' + query.encode('utf-8') + b'\x00'
    body += struct.pack("!H", len(param_types))
    body += b''.join(struct.pack("!I", oid) for oid in param_types)
    sock.sendall(b'P' + struct.pack("!I", len(body) + 4) + body)
def send_bind(sock, param_value):
    """
    Send a Bind message that creates the unnamed portal from the unnamed statement.

    ``param_value`` (a str) is sent as the single parameter in text format.
    Results are also requested in text format.
    """
    value = param_value.encode()
    fields = [
        b'\x00',                   # destination portal: unnamed
        b'\x00',                   # source statement: unnamed
        struct.pack("!HH", 1, 0),  # one parameter format code: 0 = text
        struct.pack("!H", 1),      # one parameter value follows
        struct.pack("!I", len(value)),
        value,
        struct.pack("!HH", 1, 0),  # one result format code: 0 = text
    ]
    body = b''.join(fields)
    sock.sendall(b'B' + struct.pack("!I", 4 + len(body)) + body)
def send_execute(sock):
    """Send an Execute message for the unnamed portal with a row limit of 0 (no limit)."""
    body = struct.pack("!xI", 0)  # pad byte = empty portal name's NUL, then max rows
    sock.sendall(b'E' + struct.pack("!I", len(body) + 4) + body)
def send_sync(sock):
    """Send a Sync message (type 'S', no body), closing the current extended-query cycle."""
    sync_message = struct.pack("!cI", b'S', 4)
    sock.sendall(sync_message)
def read_query_result(sock):
    """
    Read the server's replies to one Parse/Bind/Execute/Sync cycle.

    Messages are always consumed up to and including ReadyForQuery ('Z'),
    even after an error. This keeps the connection in step for the next query.

    :returns: the text value of the single column of the last DataRow, or
        None if no row arrived or that value was SQL NULL.
    :raises Exception: on an ErrorResponse, or on a row that does not have
        exactly one column. Raised only once ReadyForQuery has been read.
    :raises ConnectionError: if the connection closes before ReadyForQuery.
    """
    result = None
    failure = None
    while True:
        msg_type, payload = read_message(sock)
        if msg_type is None:
            # read_message keeps returning (None, None) after EOF; stop instead of spinning.
            raise ConnectionError("Connection closed before ReadyForQuery")
        if msg_type == 'Z':
            break
        if failure is not None:
            continue  # already failed: just drain up to ReadyForQuery
        if msg_type == 'E':
            failure = "Query error: " + parse_error(payload)
        elif msg_type == 'D':
            num_fields = struct.unpack("!H", payload[:2])[0]
            if num_fields != 1:
                failure = "Expected one column"
                continue
            # Signed length: -1 marks SQL NULL.
            field_len = struct.unpack("!i", payload[2:6])[0]
            if field_len == -1:
                result = None
            else:
                result = payload[6:6 + field_len].decode('utf-8')
        # Other message types need no handling here.
    if failure is not None:
        raise Exception(failure)
    return result
def connect(host, port, user, database, password, on_connect, ssl_required=True, cert_host=None, ca_file=None):
    """
    Open an authenticated PostgreSQL session, pass it to ``on_connect``, then close it.

    :param host: server host name or address for the TCP connection.
    :param port: server TCP port.
    :param user: role to log in as.
    :param database: database to connect to.
    :param password: password used for SCRAM-SHA-256 authentication.
    :param on_connect: callable invoked with the ready socket.
    :param ssl_required: negotiate TLS before anything else (no plaintext fallback).
    :param cert_host: host name the server certificate must match; defaults to ``host``.
    :param ca_file: PEM CA file used to verify the server certificate; None
        leaves certificate verification to the system's default trust store.
    :returns: whatever ``on_connect`` returns.
    """
    sock = socket.create_connection((host, port))
    try:
        if ssl_required:
            # Rebind only on success, so the finally clause always closes
            # whichever socket object we currently hold.
            sock = upgrade_to_ssl(
                sock,
                server_hostname=host if cert_host is None else cert_host,
                ca_file=ca_file,
            )
        send_startup_message(sock, user, database)
        read_startup_response(sock, user, password)
        # Authenticated and ReadyForQuery has been received.
        return on_connect(sock)
    finally:
        sock.close()
def query(sock, name, json_arg):
    """
    Call the SQL function ``name`` with one JSON argument and return its result.

    Runs one Parse/Bind/Execute/Sync cycle on the unnamed statement.
    ``name`` is pasted into the SQL text verbatim, so only pass trusted
    function names. ``json_arg`` (a JSON string) is sent separately as a
    bound parameter.

    :returns: the value produced by read_query_result().
    """
    sql = f"SELECT {name}($1);"
    send_parse(sock, sql)
    send_bind(sock, json_arg)
    send_execute(sock)
    send_sync(sock)
    return read_query_result(sock)
def run(sock):
    """
    Demo on_connect callback.

    Calls some_function with a small JSON object and prints the "test" value
    echoed back under "received".
    """
    import json
    payload = json.dumps({"hello": "world", "test": "💥"})
    reply = json.loads(query(sock, 'some_function', payload))
    print(reply['received']['test'])
if __name__ == '__main__':
    # Local test settings: adjust user, database and password for your server.
    settings = dict(
        host='localhost',
        port=5432,
        user='postgres',
        database='postgres',
        password='postgres',
        ssl_required=True,
        ca_file='certs/server.crt',
    )
    connect(on_connect=run, **settings)
The some_function database function that the script calls can be created like this:
-- Demo function for the Python client: returns its JSON argument under
-- "received" together with a fixed greeting under "message".
CREATE OR REPLACE FUNCTION some_function(input JSON)
RETURNS JSON AS $$
BEGIN
    RETURN json_build_object(
        'received', input,
        'message', 'Hello from PostgreSQL!'
    );
END;
$$ LANGUAGE plpgsql;
To run a local test server, save the SQL above as init.sql and then run:
# Start a disposable PostgreSQL 16 server on port 5432 with init.sql mounted
# into the image's init directory.
docker run --rm --name pg-test \
  -e POSTGRES_PASSWORD=postgres \
  -e POSTGRES_USER=postgres \
  -v "$(pwd)/init.sql:/docker-entrypoint-initdb.d/init.sql:ro" \
  -p 5432:5432 \
  postgres:16
For SSL, first generate a self-signed certificate and key for localhost:
# openssl will not create the output directory, so make sure it exists.
mkdir -p certs
# Self-signed certificate; CN=localhost matches the host the Python client connects to.
openssl req -new -x509 -days 365 -nodes -text -out certs/server.crt -keyout certs/server.key -subj "/CN=localhost"
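Now, if you have sudo, you can restrict the key's permissions and change the owner of both files to UID/GID 999 (the postgres user inside the official image), then start the server with SSL enabled: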
# Owner-only access: PostgreSQL rejects a private key owned by its user that group or others can read.
chmod 600 certs/server.key
# 999 is the postgres user's UID/GID in the postgres:16 image, so the server can read the mounted files.
sudo chown 999:999 certs/server.crt certs/server.key
# Same server as before, plus the certs directory mounted read-only. The -c
# options after the image name turn SSL on using the mounted files.
docker run --rm --name pg-test \
  -e POSTGRES_PASSWORD=postgres \
  -e POSTGRES_USER=postgres \
  -v "$(pwd)/init.sql:/docker-entrypoint-initdb.d/init.sql:ro" \
  -v "$(pwd)/certs:/var/lib/postgresql/certs:ro" \
  -p 5432:5432 \
  postgres:16 \
  -c ssl=on \
  -c ssl_cert_file=/var/lib/postgresql/certs/server.crt \
  -c ssl_key_file=/var/lib/postgresql/certs/server.key
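But it is probably better to build a custom image that copies the certificate and key in with the right owner and permissions, so no sudo is needed. Create postgres.Dockerfile: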
# PostgreSQL 16 image with the TLS certificate and key copied in.
FROM postgres:16
COPY certs/server.crt /var/lib/postgresql/certs/server.crt
COPY certs/server.key /var/lib/postgresql/certs/server.key
# Hand both files to the postgres user. Make the key owner-only (PostgreSQL
# rejects a key it owns that others can read) and keep the certificate world-readable.
RUN chown postgres:postgres /var/lib/postgresql/certs/server.crt /var/lib/postgresql/certs/server.key && \
    chmod 600 /var/lib/postgresql/certs/server.key && \
    chmod 644 /var/lib/postgresql/certs/server.crt
Then build the image and run it:
# Build the image described by postgres.Dockerfile and tag it my-postgres-ssl.
docker build -f postgres.Dockerfile -t my-postgres-ssl .
# Run it. The -c options after the image name enable SSL with the files baked into the image.
docker run --rm --name pg-test -e POSTGRES_PASSWORD=postgres -e POSTGRES_USER=postgres \
  -v "$(pwd)/init.sql:/docker-entrypoint-initdb.d/init.sql:ro" \
  -p 5432:5432 \
  my-postgres-ssl \
  -c ssl=on \
  -c ssl_cert_file=/var/lib/postgresql/certs/server.crt \
  -c ssl_key_file=/var/lib/postgresql/certs/server.key
You can also use the same code as a CGI script by replacing the run() function with something like this:
def run(sock):
    """
    CGI variant of the on_connect callback.

    Calls some_function with a small JSON object and writes an HTML response:
    the CGI header block, then the "test" value the database echoed back.
    """
    import html
    import json
    result = query(sock, 'some_function', json.dumps({"hello": "world", "test": "💥"}))
    message = json.loads(result)['received']['test']
    # The embedded "\n" plus print's own newline produce the blank line that ends the CGI headers.
    print("Content-Type: text/html; charset=utf-8\n")
    # The value is placed into an HTML page, so escape it.
    print("This came from the database: " + html.escape(message))
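Save the complete script, with this run() in place of the original one, as cgi-bin/hello.cgi. On Unix-like systems the CGI server executes the file directly, so give it a #!/usr/bin/env python3 first line and make it executable (chmod +x cgi-bin/hello.cgi). Then create a CGI server for testing in serve.py (note that CGIHTTPRequestHandler is deprecated as of Python 3.13, so this is only for local testing):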
#!/usr/bin/env python3
from http.server import HTTPServer, CGIHTTPRequestHandler
def run(server_class=HTTPServer, handler_class=CGIHTTPRequestHandler):
    """Start an HTTP server on port 8000 using handler_class and serve requests forever."""
    server = server_class(('', 8000), handler_class)  # '' binds every interface
    print("Starting CGI server on http://localhost:8000/cgi-bin/ ...")
    server.serve_forever()


if __name__ == '__main__':
    run()
Run that server with python3 serve.py from the directory that contains cgi-bin/, then visit http://localhost:8000/cgi-bin/hello.cgi and you'll see:
This came from the database: 💥
Be the first to comment.
Copyright James Gardner 1996-2020 All Rights Reserved. Admin.