Home Blog CV Projects Patterns Notes Book Colophon Search

Tiny Python PostgreSQL Driver

8 Apr, 2025

Please don't use this yet, but this is the bare minimum I needed to connect to my hosting provider's PostgreSQL instance and call a function from a CGI script I was writing:

import base64
import hashlib
import hmac
import os
import socket
import struct
import sys

def error(*k, **p):
    """Print the given arguments to standard error, exactly like print().

    Any ``file=`` keyword supplied by the caller is overridden, so output
    always goes to ``sys.stderr``; other print() keywords pass through.
    """
    print(*k, **dict(p, file=sys.stderr))

# --- Utility Functions for Message I/O ---

def send_startup_message(sock, user, database, extra_params=None):
    """
    Send the StartupMessage that opens a protocol 3.0 session.

    Wire format: Int32 total length (including itself), Int32 protocol
    version, then null-terminated key/value string pairs, ended by one
    extra null byte.

    Parameters:
      sock: connected socket (plain or SSL-wrapped) providing ``sendall``.
      user: login role name.
      database: database to connect to.
      extra_params: optional mapping of additional run-time parameters
        (e.g. {"application_name": "cgi"}); its entries override the
        defaults below.

    client_encoding defaults to UTF8 because the rest of this driver
    encodes and decodes all text as UTF-8. Names are encoded as UTF-8, so
    non-ASCII user/database names no longer fail.
    """
    protocol_version = 196608  # 3 << 16: protocol version 3.0
    params = {"user": user, "database": database, "client_encoding": "UTF8"}
    if extra_params:
        params.update(extra_params)
    payload = struct.pack("!I", protocol_version)
    payload += b"".join(
        key.encode("utf-8") + b"\x00" + value.encode("utf-8") + b"\x00"
        for key, value in params.items()
    )
    payload += b"\x00"  # terminator after the last key/value pair
    sock.sendall(struct.pack("!I", len(payload) + 4) + payload)

def _recv_exact(sock, n, what):
    """Read exactly n bytes from sock, raising if the peer closes first."""
    chunks = []
    remaining = n
    while remaining > 0:
        # recv() may legally return fewer bytes than requested.
        chunk = sock.recv(remaining)
        if not chunk:
            raise Exception(f"Incomplete {what}")
        chunks.append(chunk)
        remaining -= len(chunk)
    return b"".join(chunks)


def read_message(sock):
    """
    Read one complete backend message from the server.

    Each message is:
      - 1 byte: message type
      - 4 bytes: message length (includes these 4 bytes)
      - Remaining: payload

    Returns (msg_type, payload) with msg_type as a one-character str, or
    (None, None) if the connection is closed before a new message starts.
    Raises Exception if the connection closes part-way through a message
    or the length field is invalid.
    """
    type_byte = sock.recv(1)
    if not type_byte:
        return None, None
    msg_type = type_byte.decode('ascii')
    length = struct.unpack("!I", _recv_exact(sock, 4, "message length"))[0]
    if length < 4:
        # The length always counts its own 4 bytes.
        raise Exception(f"Invalid message length {length}")
    payload = _recv_exact(sock, length - 4, "message payload")
    return msg_type, payload

def parse_error(payload):
    """
    Extract the human-readable text from an ErrorResponse payload.

    The payload is a run of null-terminated fields, each starting with a
    one-byte field code. The first field whose code is 'M' (message) is
    returned, decoded as UTF-8 with replacement characters. If there is no
    such field, "Unknown error" is returned.
    """
    message_fields = (chunk[1:] for chunk in payload.split(b'\x00') if chunk[:1] == b'M')
    text = next(message_fields, None)
    if text is None:
        return "Unknown error"
    return text.decode('utf-8', 'replace')

def upgrade_to_ssl(sock, server_hostname=None, ca_file=None):
    """
    Send an SSLRequest and, if the server accepts, wrap the socket in TLS.

    Parameters:
      sock: freshly connected plain TCP socket.
      server_hostname: name checked against the server certificate.
      ca_file: optional CA bundle passed to ssl.create_default_context();
        None means the default system trust store.

    Returns the TLS-wrapped socket. Certificate and hostname verification
    are on, because ssl.create_default_context() enables them.

    Raises Exception if the server answers 'N' (SSL not supported), closes
    the connection, or sends any other reply. SSL is always required here;
    callers that do not want SSL simply do not call this function.
    """
    import ssl
    ssl_request_code = 80877103  # (1234 << 16) | 5679, fixed by the protocol
    sock.sendall(struct.pack("!II", 8, ssl_request_code))
    response = sock.recv(1)
    if response == b'S':
        context = ssl.create_default_context(cafile=ca_file)
        return context.wrap_socket(sock, server_hostname=server_hostname)
    if response == b'N':
        raise Exception("SSL required but server does not support SSL")
    if not response:
        raise Exception("Server closed the connection in response to the SSL request")
    raise Exception(f"Unexpected response to SSL request: {response!r}")


# --- SCRAM-SHA-256 Authentication Implementation ---

def _scram_attributes(message):
    """Split a SCRAM attribute list such as "r=abc,s=xyz,i=4096" into a dict."""
    attributes = {}
    for item in message.split(','):
        key, sep, value = item.partition('=')
        if sep:
            attributes[key] = value
    return attributes


def _read_sasl_data(sock, expected_code, description):
    """
    Read one Authentication ('R') message and return its SASL data.

    The SASL data is the payload after the 4-byte auth code, decoded as
    UTF-8. Raises Exception on an ErrorResponse, any other message type
    (including ReadyForQuery or EOF), a truncated payload, or an auth code
    other than expected_code.
    """
    msg_type, payload = read_message(sock)
    if msg_type == 'E':
        raise Exception("Server sent error during SCRAM auth: " + parse_error(payload))
    if msg_type != 'R':
        raise Exception(f"Expected {description} message from server, got {msg_type!r} instead")
    if len(payload) < 4:
        raise Exception(f"Truncated {description} message")
    auth_code = struct.unpack("!I", payload[:4])[0]
    if auth_code != expected_code:
        raise Exception(f"Expected auth code {expected_code} but got {auth_code}")
    return payload[4:].decode('utf-8')


def scram_authenticate(sock, username, password):
    """
    Perform the SCRAM-SHA-256 SASL exchange (RFC 5802 / RFC 7677).

    Steps:
      1. Send SASLInitialResponse ('p') with the client-first-message.
      2. Receive AuthenticationSASLContinue (code 11) with the
         server-first-message (combined nonce, salt, iteration count).
      3. Derive keys with PBKDF2-HMAC-SHA256 and compute the client proof.
      4. Send SASLResponse ('p') with the client-final-message.
      5. Receive AuthenticationSASLFinal (code 12) and verify the server
         signature, which proves the server also knows the password.

    Returns None on success. Raises Exception on any protocol violation,
    server error, or signature mismatch.

    NOTE(review): the password is used as raw UTF-8. RFC 5802 calls for
    SASLprep normalization, which is not implemented here, so some
    non-ASCII passwords may fail to authenticate.
    """
    # ---- Phase 1: Build and send client-first-message ----
    client_nonce = base64.b64encode(os.urandom(18)).decode('ascii')
    gs2_header = "n,,"  # no channel binding, no authorization identity
    # SCRAM requires '=' and ',' inside the username to be escaped.
    user_escaped = username.replace("=", "=3D").replace(",", "=2C")
    client_first_message_bare = f"n={user_escaped},r={client_nonce}"
    initial_response = (gs2_header + client_first_message_bare).encode('utf-8')
    # SASLInitialResponse payload: mechanism name (null-terminated),
    # Int32 length of the initial response, then the response bytes.
    payload = (b"SCRAM-SHA-256\x00"
               + struct.pack("!I", len(initial_response))
               + initial_response)
    sock.sendall(b'p' + struct.pack("!I", 4 + len(payload)) + payload)

    # ---- Phase 2: Receive and validate server-first-message ----
    server_first_message = _read_sasl_data(sock, 11, "AuthenticationSASLContinue")
    parts = _scram_attributes(server_first_message)
    if 'm' in parts:
        # RFC 5802: an 'm' (mandatory extension) attribute must fail authentication.
        raise Exception("Server requires an unsupported SCRAM extension")
    if 'r' not in parts or 's' not in parts or 'i' not in parts:
        raise Exception("Invalid server-first-message: " + server_first_message)
    server_nonce = parts['r']
    if not server_nonce.startswith(client_nonce):
        raise Exception("Server nonce does not begin with client nonce")
    salt = base64.b64decode(parts['s'])
    iterations = int(parts['i'])

    # ---- Phase 3: Compute client proof ----
    salted_password = hashlib.pbkdf2_hmac("sha256", password.encode('utf-8'), salt, iterations, dklen=32)
    client_key = hmac.new(salted_password, b"Client Key", hashlib.sha256).digest()
    stored_key = hashlib.sha256(client_key).digest()
    gs2_header_b64 = base64.b64encode(gs2_header.encode('ascii')).decode('ascii')
    client_final_without_proof = f"c={gs2_header_b64},r={server_nonce}"
    auth_message = f"{client_first_message_bare},{server_first_message},{client_final_without_proof}".encode('utf-8')
    client_signature = hmac.new(stored_key, auth_message, hashlib.sha256).digest()
    # ClientProof = ClientKey XOR ClientSignature
    client_proof = bytes(a ^ b for a, b in zip(client_key, client_signature))
    client_final_message = client_final_without_proof + ",p=" + base64.b64encode(client_proof).decode('ascii')

    # ---- Phase 4: Send client-final-message in SASLResponse ('p') ----
    final_bytes = client_final_message.encode('utf-8')
    sock.sendall(b'p' + struct.pack("!I", 4 + len(final_bytes)) + final_bytes)

    # ---- Phase 5: Receive server-final-message and verify the server ----
    server_final_message = _read_sasl_data(sock, 12, "AuthenticationSASLFinal")
    parts = _scram_attributes(server_final_message)
    if 'e' in parts:
        raise Exception("Server rejected SCRAM authentication: " + parts['e'])
    if 'v' not in parts:
        raise Exception("Server did not send a verifier in final message: " + server_final_message)
    server_key = hmac.new(salted_password, b"Server Key", hashlib.sha256).digest()
    expected_server_signature = hmac.new(server_key, auth_message, hashlib.sha256).digest()
    expected_verifier = base64.b64encode(expected_server_signature)
    # Constant-time comparison of the verifier bytes.
    if not hmac.compare_digest(parts['v'].encode('utf-8'), expected_verifier):
        raise Exception("Server signature does not match!")


def read_startup_response(sock, username, password):
    """
    Consume the server's startup replies until it reports ReadyForQuery.

    - Authentication code 0 (AuthenticationOk) is skipped.
    - Code 10 (SASL request) runs scram_authenticate().
    - Any other authentication code raises.
    - ErrorResponse ('E') raises, and so does a closed connection.
    - Every other message type (such as BackendKeyData 'K') is ignored.
    """
    while True:
        kind, body = read_message(sock)
        if kind is None:
            raise Exception("No response from server during startup")
        if kind == 'Z':
            # ReadyForQuery: startup is complete.
            return
        if kind == 'E':
            raise Exception("Startup error: " + parse_error(body))
        if kind != 'R':
            continue
        (code,) = struct.unpack("!I", body[:4])
        if code == 10:
            scram_authenticate(sock, username, password)
        elif code != 0:
            raise Exception(f"Unsupported authentication (code: {code})")

# --- Extended-query protocol: Parse, Bind, Execute, Sync and result reading ---

def send_parse(sock, query, param_types=(114,)):
    """
    Send a Parse ('P') message that prepares `query` as the unnamed statement.

    Parameters:
      sock: connected socket providing ``sendall``.
      query: SQL text; parameters are referenced as $1, $2, ...
      param_types: parameter type OIDs, one per placeholder. The default is
        a single JSON parameter (OID 114). An OID of 0 leaves the type
        unspecified so the server infers it.
    """
    statement_name = b""  # unnamed statement
    payload = statement_name + b"\x00" + query.encode("utf-8") + b"\x00"
    payload += struct.pack("!H", len(param_types))
    payload += b"".join(struct.pack("!I", oid) for oid in param_types)
    sock.sendall(b"P" + struct.pack("!I", len(payload) + 4) + payload)

def send_bind(sock, param_value):
    """
    Send a Bind ('B') message binding one text parameter to the unnamed
    statement, creating the unnamed portal with text-format results.

    Parameters:
      sock: connected socket providing ``sendall``.
      param_value: the parameter as a str (sent as UTF-8 text), or None to
        send SQL NULL.
    """
    portal_name = b"\x00"      # unnamed portal
    statement_name = b"\x00"   # unnamed statement
    text_format = struct.pack("!H", 0)
    if param_value is None:
        # A length of -1 means NULL, and no value bytes follow.
        param_bytes = struct.pack("!i", -1)
    else:
        encoded = param_value.encode("utf-8")
        param_bytes = struct.pack("!I", len(encoded)) + encoded
    payload = (portal_name + statement_name
               + struct.pack("!H", 1) + text_format   # 1 parameter format code
               + struct.pack("!H", 1) + param_bytes   # 1 parameter value
               + struct.pack("!H", 1) + text_format)  # 1 result format code
    sock.sendall(b"B" + struct.pack("!I", len(payload) + 4) + payload)

def send_execute(sock):
    """Send an Execute ('E') message for the unnamed portal with no row limit."""
    unnamed_portal = b'\x00'
    max_rows = 0  # 0 means no limit: return all rows
    body = unnamed_portal + struct.pack("!I", max_rows)
    header = struct.pack("!cI", b'E', 4 + len(body))
    sock.sendall(header + body)

def send_sync(sock):
    """Send a Sync ('S') message: a type byte plus a length of 4 and no body."""
    sock.sendall(struct.pack("!cI", b'S', 4))

def read_query_result(sock):
    """
    Read backend messages for one query until ReadyForQuery ('Z').

    Returns the text of the single column in the last DataRow ('D'), or
    None if the value was SQL NULL or no row arrived.

    Raises Exception:
      - if a row does not have exactly one column;
      - if the connection closes before ReadyForQuery;
      - if the server reported an error. That error is raised only after
        ReadyForQuery has been consumed, so the connection stays in sync
        for later queries.
    """
    result = None
    error_message = None
    while True:
        msg_type, payload = read_message(sock)
        if msg_type is None:
            # Without this check, EOF would make the loop spin forever.
            raise Exception("Connection closed before ReadyForQuery")
        if msg_type == 'D':
            (num_fields,) = struct.unpack_from("!H", payload, 0)
            if num_fields != 1:
                raise Exception("Expected one column")
            (field_len,) = struct.unpack_from("!i", payload, 2)
            # A length of -1 marks SQL NULL.
            result = None if field_len == -1 else payload[6:6 + field_len].decode('utf-8')
        elif msg_type == 'E':
            if error_message is None:
                error_message = parse_error(payload)
        elif msg_type == 'Z':
            break
    if error_message is not None:
        raise Exception("Query error: " + error_message)
    return result

def connect(host, port, user, database, password, on_connect, ssl_required=True, cert_host=None, ca_file=None):
    """
    Open an authenticated session, hand the socket to on_connect, then close it.

    Parameters:
      host, port: server address passed to socket.create_connection().
      user, database, password: login role, target database and password.
      on_connect: callable invoked with the ready socket; its return value
        is returned from connect().
      ssl_required: when True (default), negotiate SSL before anything else
        and fail if the server refuses.
      cert_host: hostname checked against the server certificate; defaults
        to host.
      ca_file: optional CA bundle used to verify the server certificate.

    The socket is always closed, including when SSL negotiation,
    authentication or on_connect raises.
    """
    sock = socket.create_connection((host, port))
    try:
        if ssl_required:
            server_hostname = host if cert_host is None else cert_host
            sock = upgrade_to_ssl(sock, server_hostname=server_hostname, ca_file=ca_file)
        send_startup_message(sock, user, database)
        read_startup_response(sock, user, password)
        result = on_connect(sock)
        # Best-effort Terminate ('X') so the server sees an orderly disconnect.
        # A failure here does not matter because the socket is closed anyway.
        try:
            sock.sendall(b'X' + struct.pack("!I", 4))
        except OSError:
            pass
        return result
    finally:
        sock.close()

def query(sock, name, json_arg):
    """
    Call the SQL function `name` with one JSON argument and return its result.

    Sends Parse, Bind, Execute and Sync, then reads messages until
    ReadyForQuery via read_query_result().

    WARNING: `name` is pasted into the SQL text unescaped, so pass only
    trusted, hard-coded function names. `json_arg` is sent as a bound
    parameter and is never interpolated into the SQL.
    """
    sql = "SELECT {}($1);".format(name)
    send_parse(sock, sql)
    send_bind(sock, json_arg)
    # Execute runs the bound portal; Sync closes the sequence so the
    # server finishes with ReadyForQuery.
    send_execute(sock)
    send_sync(sock)
    return read_query_result(sock)


def run(sock):
    """Demo callback: round-trip a JSON object through some_function and print one field."""
    import json
    argument = json.dumps({"hello": "world", "test": "💥"})
    response = json.loads(query(sock, 'some_function', argument))
    print(response['received']['test'])

if __name__ == '__main__':
    # Connection settings; adjust user, database and password as needed.
    settings = {
        'host': 'localhost',
        'port': 5432,
        'user': 'postgres',
        'database': 'postgres',
        'password': 'postgres',  # set your PostgreSQL password
        'ssl_required': True,
        'ca_file': 'certs/server.crt',
    }
    connect(on_connect=run, **settings)

The function can be created like this:

-- Demo function called from Python as SELECT some_function($1) with a JSON
-- argument. It returns a JSON object that echoes the input under "received"
-- together with a fixed "message" string.
CREATE OR REPLACE FUNCTION some_function(input JSON)
RETURNS JSON AS $$
BEGIN
    RETURN json_build_object(
        'received', input,
        'message', 'Hello from PostgreSQL!'
    );
END;
$$ LANGUAGE plpgsql;

And you can run a local test server by putting the above SQL in init.sql and then running this:

# Disposable (--rm) PostgreSQL 16 container: init.sql is mounted read-only into
# /docker-entrypoint-initdb.d and port 5432 is published on the host.
docker run --rm --name pg-test -e POSTGRES_PASSWORD=postgres -e POSTGRES_USER=postgres -v "$(pwd)/init.sql:/docker-entrypoint-initdb.d/init.sql:ro" -p 5432:5432 postgres:16

For SSL, first generate a self-signed certificate and key in a certs/ directory:

# openssl cannot create the output directory itself, so make sure it exists.
mkdir -p certs
# Self-signed certificate for CN=localhost, valid for 365 days, unencrypted key (-nodes).
openssl req -new -x509 -days 365 -nodes -text   -out certs/server.crt   -keyout certs/server.key   -subj "/CN=localhost"

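If you have sudo, restrict the key's permissions and change the owner of both files to UID/GID 999 (the postgres user inside the official image), then start the container with SSL enabled: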

# Make the private key readable by its owner only, then give both files to UID/GID 999.
chmod 600 certs/server.key
sudo chown 999:999 certs/server.crt certs/server.key
# Same container as before, plus the certs directory mounted read-only and SSL
# switched on with -c settings that point at the mounted files.
docker run --rm --name pg-test   -e POSTGRES_PASSWORD=postgres   -e POSTGRES_USER=postgres   -v "$(pwd)/init.sql:/docker-entrypoint-initdb.d/init.sql:ro"   -v "$(pwd)/certs:/var/lib/postgresql/certs:ro"   -p 5432:5432   postgres:16   -c ssl=on   -c ssl_cert_file=/var/lib/postgresql/certs/server.crt   -c ssl_key_file=/var/lib/postgresql/certs/server.key

It is probably better, though, to bake the certificate into a custom image with a postgres.Dockerfile, which avoids changing file ownership on the host:

# Copy the certificate and key into the image and fix their ownership and permissions there.
FROM postgres:16
COPY certs/server.crt /var/lib/postgresql/certs/server.crt
COPY certs/server.key /var/lib/postgresql/certs/server.key
# Owned by the postgres user; the key is readable by its owner only, the cert by everyone.
RUN chown postgres:postgres /var/lib/postgresql/certs/server.crt /var/lib/postgresql/certs/server.key && \
    chmod 600 /var/lib/postgresql/certs/server.key && \
    chmod 644 /var/lib/postgresql/certs/server.crt

Then:

# Build the image from postgres.Dockerfile, then run it with SSL enabled.
# The -c paths point at the files copied into the image, so no certs volume is needed.
docker build -f postgres.Dockerfile -t my-postgres-ssl .
docker run --rm --name pg-test -e POSTGRES_PASSWORD=postgres -e POSTGRES_USER=postgres \
  -v "$(pwd)/init.sql:/docker-entrypoint-initdb.d/init.sql:ro" \
  -p 5432:5432 \
  my-postgres-ssl \
  -c ssl=on \
  -c ssl_cert_file=/var/lib/postgresql/certs/server.crt \
  -c ssl_key_file=/var/lib/postgresql/certs/server.key

You can also use the same driver code in a CGI script by replacing run() with something like this:

def run(sock):
    """
    CGI variant of the demo callback: round-trip a JSON object through
    some_function and emit the echoed value as an HTML response.
    """
    import html
    import json
    import sys

    result = query(sock, 'some_function', json.dumps({"hello": "world", "test": "💥"}))
    message = json.loads(result)['received']['test']

    # The response is declared as UTF-8, so make stdout really encode UTF-8;
    # a CGI child process may otherwise inherit a different locale encoding.
    sys.stdout.reconfigure(encoding="utf-8")
    # CGI output is the header line, a blank line (the extra "\n"), then the body.
    # "UTF-8" is the registered charset name; browsers merely tolerate "UTF8".
    print("Content-Type: text/html; charset=UTF-8\n")
    print("This came from the database:", html.escape(message))

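Save the complete script (the driver code plus this run()) as cgi-bin/hello.cgi, add a #!/usr/bin/env python3 line at the top, and make it executable with chmod +x. Then create a CGI server for testing in serve.py: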

#!/usr/bin/env python3
from http.server import HTTPServer, CGIHTTPRequestHandler

def run(server_class=HTTPServer, handler_class=CGIHTTPRequestHandler, port=8000):
    """
    Serve scripts under ./cgi-bin/ over HTTP on all interfaces until interrupted.

    Parameters:
      server_class: HTTP server class to instantiate (default HTTPServer).
      handler_class: request handler class (default CGIHTTPRequestHandler).
      port: TCP port to listen on (default 8000).

    NOTE: CGIHTTPRequestHandler is deprecated since Python 3.13 and slated
    for removal, so treat this as a local testing aid only.
    """
    server_address = ('', port)  # '' = listen on all interfaces
    httpd = server_class(server_address, handler_class)
    print(f"Starting CGI server on http://localhost:{port}/cgi-bin/ ...")
    try:
        httpd.serve_forever()
    except KeyboardInterrupt:
        # Ctrl-C is the normal way to stop this test server.
        pass
    finally:
        httpd.server_close()


if __name__ == '__main__':
    run()

Run that server with python3 serve.py and then visit http://localhost:8000/cgi-bin/hello.cgi and you'll see:

This came from the database: 💥

Comments

Be the first to comment.

Add Comment





Copyright James Gardner 1996-2020 All Rights Reserved. Admin.