Home Blog CV Projects Patterns Notes Book Colophon Search

OpenAPI to PL/pgSQL

15 Apr, 2025

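Here's a Python script that reads an OpenAPI JSON document on stdin and generates PostgreSQL domains, composite types and PL/pgSQL wrapper functions (plus optional implementation stubs) from it.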

#!/usr/bin/env python3
# cat openapi.json | python3 generate_sql.py --current_user_config "app.current_user" --impl_file impl.sql api.sql

import argparse
import json
import os
import sys

# ---- Helper Functions for SQL Generation ------------------------------------

def generate_domain_sql(property_name, pattern):
    """Return a CREATE DOMAIN statement restricting TEXT to a regex pattern.

    The domain is named after the property/parameter that declared the
    pattern, and is checked with PostgreSQL's ``~`` regex operator.

    Args:
        property_name: Name used for the domain (emitted unquoted).
        pattern: The OpenAPI ``pattern`` string.

    Returns:
        The SQL statement as a string.
    """
    # Double embedded single quotes so the pattern cannot terminate the SQL
    # string literal early. NOTE(review): backslashes are left untouched,
    # which assumes standard_conforming_strings is on (the PostgreSQL default).
    escaped_pattern = pattern.replace("'", "''")
    return f"CREATE DOMAIN {property_name} AS TEXT\n  CHECK (VALUE ~ '{escaped_pattern}');"

def generate_composite_type(schema_name, schema):
    """Return a CREATE TYPE statement for an object schema from components.

    Each property becomes one attribute of the composite type. A property
    with a ``pattern`` uses the domain of the same name (created separately);
    other properties are mapped from their JSON type, with anything unknown
    (arrays, nested objects, ``$ref``, missing ``type``) falling back to TEXT.

    Args:
        schema_name: Name of the composite type (emitted unquoted).
        schema: The component schema dict; only ``properties`` is read.

    Returns:
        The SQL statement as a string.
    """
    # JSON Schema type -> PostgreSQL column type.
    type_map = {
        "string": "TEXT",
        "integer": "INT",
        "number": "NUMERIC",
        "boolean": "BOOLEAN",
    }

    fields = []
    for prop_name, prop in schema.get("properties", {}).items():
        if "pattern" in prop:
            col_type = prop_name  # the domain shares the property's name
        elif prop.get("type") == "integer" and prop.get("format") == "int64":
            col_type = "BIGINT"  # INT is only 32-bit
        else:
            col_type = type_map.get(prop.get("type", "string"), "TEXT")
        fields.append(f"    {prop_name} {col_type}")

    fields_str = ",\n".join(fields)
    return f"CREATE TYPE {schema_name} AS (\n{fields_str}\n);"

def generate_parameter_composite_type(type_name, parameters):
    """Return a CREATE TYPE statement bundling an operation's parameters.

    Each parameter becomes one attribute. A parameter whose schema has a
    ``pattern`` uses the domain of the same name (created separately);
    others are mapped from their JSON type, falling back to TEXT.

    Args:
        type_name: Name of the composite type (emitted unquoted).
        parameters: List of OpenAPI parameter objects; each needs ``name``.

    Returns:
        The SQL statement as a string.
    """
    # JSON Schema type -> PostgreSQL column type.
    type_map = {
        "string": "TEXT",
        "integer": "INT",
        "number": "NUMERIC",
        "boolean": "BOOLEAN",
    }

    fields = []
    for param in parameters:
        pname = param["name"]
        schema = param.get("schema", {})
        if "pattern" in schema:
            col_type = pname  # the domain shares the parameter's name
        elif schema.get("type") == "integer" and schema.get("format") == "int64":
            col_type = "BIGINT"  # INT is only 32-bit
        else:
            col_type = type_map.get(schema.get("type", "string"), "TEXT")
        fields.append(f"    {pname} {col_type}")

    fields_str = ",\n".join(fields)
    return f"CREATE TYPE {type_name} AS (\n{fields_str}\n);"

def _json_schema_ref(container):
    """Return the ``$ref`` under ``content -> application/json -> schema``.

    ``container`` is a requestBody or response object. Returns None when any
    level of that path, or the ``$ref`` itself, is missing.
    """
    return (
        container.get("content", {})
        .get("application/json", {})
        .get("schema", {})
        .get("$ref")
    )


def generate_api_function(op, current_user_config):
    """Return the public PL/pgSQL wrapper function for one operation.

    The wrapper sets ``current_user_config`` via ``set_config``. Secured
    operations set it to the extra ``i_sub`` argument; other operations set
    it to ''. The wrapper then delegates to ``impl_<operationId>`` and clears
    the setting again on both the success and the exception path.

    Args:
        op: The OpenAPI operation object; must contain ``operationId``.
        current_user_config: Setting name passed to ``set_config``.

    Returns:
        Tuple ``(sql, input_type, output_type, input_param_name)``. When the
        operation has neither a request body nor parameters, ``input_type``
        is the sentinel ``"void"`` and the generated function takes no input
        argument (PL/pgSQL does not accept pseudo-type arguments).
    """
    op_id = op["operationId"]

    # Determine input type and parameter name.
    if "requestBody" in op:
        # Request bodies are expected to $ref a component schema.
        ref = _json_schema_ref(op["requestBody"])
        input_type = ref.split("/")[-1] if ref else "unknown_type"
        input_param_name = "i_" + input_type.replace("_type", "").lower()
    elif op.get("parameters"):
        # Parameter-driven operations take a composite named <op_id>_input.
        input_type = op_id + "_input"
        input_param_name = "i_" + input_type
    else:
        # Sentinel meaning "no input argument" (see generate_impl_function).
        input_type = "void"
        input_param_name = "i_void"

    # Determine output type from the 200 response.
    ref = _json_schema_ref(op.get("responses", {}).get("200", {}))
    output_type = ref.split("/")[-1] if ref else "unknown_type"

    secured = bool(op.get("security"))
    has_input = input_type != "void"

    params = [f"{input_param_name} {input_type}"] if has_input else []
    if secured:
        params.append("i_sub TEXT")
    impl_call_args = input_param_name if has_input else ""
    user_value = "i_sub" if secured else "''"
    reset_user = f"    PERFORM set_config('{current_user_config}', '', false);"

    func_lines = [
        f"-- Generated function for operation {op_id}",
        f"CREATE OR REPLACE FUNCTION {op_id}(",
    ]
    if params:
        func_lines.append("    " + ",\n    ".join(params))
    func_lines += [
        ")",
        f"RETURNS {output_type}",
        "AS $$",
        "DECLARE",
        f"    v_result {output_type};",
        "BEGIN",
        f"    PERFORM set_config('{current_user_config}', {user_value}, false);",
        "",
        f"    v_result := impl_{op_id}({impl_call_args});",
        "",
        reset_user,
        "",
        "    RETURN v_result;",
        "EXCEPTION WHEN OTHERS THEN",
        # Explicitly clear the setting before re-raising the error.
        reset_user,
        "    RAISE;",
        "END;",
        "$$ LANGUAGE plpgsql;",
        "",
    ]
    return "\n".join(func_lines), input_type, output_type, input_param_name


def generate_impl_function(op, input_type, output_type, input_param_name):
    """Return a stub ``impl_<operationId>`` function for the developer to fill in.

    Args:
        op: The OpenAPI operation object; must contain ``operationId``.
        input_type: SQL type of the input argument, or ``"void"`` for none.
        output_type: SQL return type.
        input_param_name: Name of the input argument (unused for ``"void"``).

    Returns:
        The CREATE OR REPLACE FUNCTION statement as a string. The stub body
        only returns the never-assigned ``v_result``.
    """
    op_id = op["operationId"]
    lines = [
        f"-- Implementation function for operation {op_id}",
        f"CREATE OR REPLACE FUNCTION impl_{op_id}(",
    ]
    # PL/pgSQL rejects pseudo-type arguments such as void, so an operation
    # without input gets a zero-argument implementation function.
    if input_type != "void":
        lines.append(f"    {input_param_name} {input_type}")
    lines += [
        ")",
        f"RETURNS {output_type}",
        "AS $$",
        "DECLARE",
        f"    v_result {output_type};",
        "BEGIN",
        "    -- TODO: Implement operation logic",
        "",
        "    RETURN v_result;",
        "END;",
        "$$ LANGUAGE plpgsql;",
        "",
    ]
    return "\n".join(lines)

# ---- Main Code --------------------------------------------------------------

def main():
    parser = argparse.ArgumentParser(
        description="Generate SQL files from an OpenAPI JSON document sent via stdin."
    )
    parser.add_argument("api_file", help="Output file for generated API SQL")
    parser.add_argument("--impl_file", help="Output file for implementation SQL (optional)")
    parser.add_argument("--current_user_config", default="app.current_user",
                        help="Configuration variable for current user (default: app.current_user)")

    args = parser.parse_args()

    openapi_json = sys.stdin.read()
    spec = json.loads(openapi_json)

    sql_api_lines = []
    sql_impl_lines = []

    # ---- Generate Domains ---------------------------------------------------
    domains = {}

    # Scan component schemas
    components = spec.get("components", {}).get("schemas", {})
    for schema_name, schema in components.items():
        if schema.get("type") == "object":
            for prop_name, prop in schema.get("properties", {}).items():
                if "pattern" in prop and prop_name not in domains:
                    domains[prop_name] = generate_domain_sql(prop_name, prop["pattern"])

    # Scan inline parameters (for GET requests)
    for path_item in spec.get("paths", {}).values():
        for method, op in path_item.items():
            if "parameters" in op:
                for param in op["parameters"]:
                    pname = param["name"]
                    schema = param.get("schema", {})
                    if "pattern" in schema and pname not in domains:
                        domains[pname] = generate_domain_sql(pname, schema["pattern"])

    for domain_sql in domains.values():
        sql_api_lines.append(domain_sql)
        sql_api_lines.append("")

    # ---- Generate Composite Types for Components ---------------------------
    composite_types = {}
    for schema_name, schema in components.items():
        if schema.get("type") == "object":
            comp_sql = generate_composite_type(schema_name, schema)
            composite_types[schema_name] = comp_sql
            sql_api_lines.append(comp_sql)
            sql_api_lines.append("")

    # ---- Generate Composite Types for GET Parameters -------------------------
    generated_input_types = {}
    for path_item in spec.get("paths", {}).values():
        for method, op in path_item.items():
            if method.lower() == "get" and "parameters" in op:
                op_id = op["operationId"]
                input_type_name = op_id + "_input"
                if input_type_name not in generated_input_types:
                    comp_sql = generate_parameter_composite_type(input_type_name, op["parameters"])
                    generated_input_types[input_type_name] = comp_sql
                    sql_api_lines.append(comp_sql)
                    sql_api_lines.append("")

    # ---- Generate Functions --------------------------------------------------
    # For each operation in each path (supporting GET and POST in this pattern).
    for path_item in spec.get("paths", {}).values():
        for method, op in path_item.items():
            if method.lower() in ("get", "post") and "operationId" in op:
                api_func_sql, inp_type, out_type, inp_param = generate_api_function(op, args.current_user_config)
                sql_api_lines.append(api_func_sql)

                # Generate implementation stub only if an impl file is requested.
                if args.impl_file:
                    impl_func_sql = generate_impl_function(op, inp_type, out_type, inp_param)
                    sql_impl_lines.append(impl_func_sql)

    # ---- Write out SQL files -----------------------------------------------
    with open(args.api_file, "w") as f:
        f.write("\n".join(sql_api_lines))

    if args.impl_file:
        with open(args.impl_file, "w") as f:
            f.write("\n".join(sql_impl_lines))

    print(f"API SQL generated to {args.api_file}")
    if args.impl_file:
        print(f"Implementation SQL generated to {args.impl_file}")

# Run the generator only when the file is executed as a script, not on import.
if __name__ == "__main__":
    main()

Comments

Be the first to comment.

Add Comment





Copyright James Gardner 1996-2020 All Rights Reserved. Admin.