Skip to main content

Overview

Neurenix provides built-in API servers for serving models through multiple protocols:
  • RESTful API: Standard HTTP-based inference via FastAPI
  • WebSocket: Real-time bidirectional communication
  • gRPC: High-performance RPC for production systems

Quick Start

Simple REST Server

from neurenix.api_support import serve_model
from neurenix.nn import Sequential, Linear, ReLU

# Create your model
model = Sequential(
    Linear(784, 256),
    ReLU(),
    Linear(256, 10)
)

# Serve on REST API (default)
servers = serve_model(
    model=model,
    name="mnist_classifier",
    server_types=["rest"],
    host="0.0.0.0",
    ports={"rest": 8000}
)

print("Model serving on http://0.0.0.0:8000")

Making Predictions

curl -X POST http://localhost:8000/models/mnist_classifier/predict \
  -H "Content-Type: application/json" \
  -d '{"data": [[1.0, 2.0, 3.0, ...]]}'  # "..." is a placeholder; send all 784 feature values

APIManager

The APIManager class provides centralized management of API servers.

Initialization

from neurenix.api_support import APIManager

# Get singleton instance
manager = APIManager()

Creating Servers

# Create a REST server
rest_server = manager.create_server(
    server_type="rest",
    host="0.0.0.0",
    port=8000
)

# Create a WebSocket server
ws_server = manager.create_server(
    server_type="websocket",
    host="0.0.0.0",
    port=8001
)

# Create a gRPC server
grpc_server = manager.create_server(
    server_type="grpc",
    host="0.0.0.0",
    port=8002
)

Managing Models

# Add models to server
rest_server.add_model("model_v1", model_v1)
rest_server.add_model("model_v2", model_v2)

# Remove a model
rest_server.remove_model("model_v1")

# Start the server
rest_server.start()

# Stop the server
rest_server.stop()

RESTful API Server

Setup

from neurenix.api_support import RESTfulServer, create_rest_server

# Method 1: Direct instantiation
server = RESTfulServer(host="0.0.0.0", port=8000)
server.add_model("my_model", model)
server.start()

# Method 2: Using helper function
server = create_rest_server(host="0.0.0.0", port=8000)

API Endpoints

Root Endpoint:
GET /
# Response: {"message": "Neurenix API Server", "version": "1.0.0"}
List Models:
GET /models
# Response: {"models": ["model1", "model2"]}
Make Predictions:
POST /models/{model_name}/predict
Content-Type: application/json

{
  "data": [[1.0, 2.0, 3.0, ...]]
}

# Response: {"result": [[0.1, 0.9, ...]]}

Example with Python Client

import requests
import numpy as np

# Prepare input data
data = np.random.randn(1, 784).tolist()

# Make prediction
response = requests.post(
    "http://localhost:8000/models/mnist_classifier/predict",
    json={"data": data}
)

result = response.json()
print(f"Prediction: {result['result']}")

CORS Configuration

The REST server automatically enables CORS with permissive defaults. For production, tighten these settings (for example, restrict the allowed origins) where the CORS configuration is defined in the source code, at neurenix/api_support.py, lines 159-165.

WebSocket Server

Setup

from neurenix.api_support import WebSocketServer, create_websocket_server

server = WebSocketServer(host="0.0.0.0", port=8001)
server.add_model("my_model", model)
server.start()

WebSocket Protocol

Messages are JSON-formatted and carry an action field that selects the operation.

List Models:
{"action": "list_models"}
Make Prediction:
{
  "action": "predict",
  "model": "my_model",
  "data": [[1.0, 2.0, 3.0, ...]]
}

Python WebSocket Client

import asyncio
import websockets
import json

async def predict():
    uri = "ws://localhost:8001"
    async with websockets.connect(uri) as websocket:
        # Send prediction request
        request = {
            "action": "predict",
            "model": "mnist_classifier",
            "data": [[1.0, 2.0, 3.0, ...]]
        }
        await websocket.send(json.dumps(request))
        
        # Receive response
        response = await websocket.recv()
        result = json.loads(response)
        print(f"Result: {result}")

# Run the client
asyncio.run(predict())

JavaScript WebSocket Client

const ws = new WebSocket('ws://localhost:8001');

ws.onopen = () => {
    // Send prediction request
    ws.send(JSON.stringify({
        action: 'predict',
        model: 'mnist_classifier',
        data: [[1.0, 2.0, 3.0, ...]]
    }));
};

ws.onmessage = (event) => {
    const result = JSON.parse(event.data);
    console.log('Prediction:', result);
};

gRPC Server

Setup

from neurenix.api_support import GRPCServer, create_grpc_server

server = GRPCServer(host="0.0.0.0", port=8002)
server.add_model("my_model", model)
server.start()

Protocol Definition

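The gRPC server automatically generates a proto file at neurenix/proto/neurenix_api.proto. The excerpt below shows the service and the prediction messages; the ListModelsRequest and ListModelsResponse messages referenced by the ListModels RPC are not shown here: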
syntax = "proto3";

package neurenix;

service NeurenixService {
    rpc ListModels (ListModelsRequest) returns (ListModelsResponse);
    rpc Predict (PredictRequest) returns (PredictResponse);
}

message PredictRequest {
    string model_name = 1;
    repeated float data = 2;
    repeated int32 shape = 3;
}

message PredictResponse {
    repeated float result = 1;
    repeated int32 shape = 2;
}

Python gRPC Client

import grpc
import numpy as np
# Import generated modules (after server starts once)
from neurenix.proto import neurenix_api_pb2, neurenix_api_pb2_grpc

# Create channel
channel = grpc.insecure_channel('localhost:8002')
stub = neurenix_api_pb2_grpc.NeurenixServiceStub(channel)

# Prepare input
data = np.random.randn(1, 784).astype(np.float32)

# Make prediction
request = neurenix_api_pb2.PredictRequest(
    model_name="mnist_classifier",
    data=data.flatten().tolist(),
    shape=list(data.shape)
)

response = stub.Predict(request)
result = np.array(response.result).reshape(response.shape)
print(f"Prediction: {result}")

Multi-Protocol Serving

Serve a model on all three protocols simultaneously:
from neurenix.api_support import serve_model

servers = serve_model(
    model=model,
    name="multi_protocol_model",
    server_types=["rest", "websocket", "grpc"],
    host="0.0.0.0",
    ports={
        "rest": 8000,
        "websocket": 8001,
        "grpc": 8002
    }
)

print("Model available on:")
print("  REST:      http://localhost:8000")
print("  WebSocket: ws://localhost:8001")
print("  gRPC:      localhost:8002")

Advanced Configuration

Custom Preprocessing and Postprocessing

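Extend the server classes to customize input/output handling. Note that the example below wraps its result in Tensor, so Neurenix's Tensor class must also be imported; that import is not shown in the snippet: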
from neurenix.api_support import RESTfulServer
import cv2
import numpy as np

class ImageRESTServer(RESTfulServer):
    def _preprocess_input(self, data):
        # Custom image preprocessing
        if isinstance(data, dict) and "image_base64" in data:
            import base64
            img_bytes = base64.b64decode(data["image_base64"])
            img_array = np.frombuffer(img_bytes, dtype=np.uint8)
            img = cv2.imdecode(img_array, cv2.IMREAD_COLOR)
            img = cv2.resize(img, (224, 224))
            img = img.astype(np.float32) / 255.0
            return Tensor(img)
        return super()._preprocess_input(data)
    
    def _postprocess_output(self, tensor):
        # Custom output formatting
        output = tensor.numpy()
        class_idx = output.argmax()
        confidence = output.max()
        return {
            "class": int(class_idx),
            "confidence": float(confidence),
            "probabilities": output.tolist()
        }

Production Deployment

Using Uvicorn (REST)

For production REST deployments:
import uvicorn
from neurenix.api_support import RESTfulServer

server = RESTfulServer(host="0.0.0.0", port=8000)
server.add_model("model", model)

# Access the FastAPI app
app = server.app

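# Run with Uvicorn. When workers > 1, Uvicorn requires the application as an
# import string ("<module>:<attribute>") instead of the app object, so that
# each worker process can import it. Replace "serve" with the name of the
# module that contains this code.
if __name__ == "__main__":
    uvicorn.run("serve:app", host="0.0.0.0", port=8000, workers=4)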

Health Check Endpoints

Add health checks for production monitoring. This example reuses the app and server objects created in the Uvicorn example above:
@app.get("/health")
async def health_check():
    return {"status": "healthy", "models": list(server.models.keys())}

Best Practices

  1. Use gRPC for High Throughput: Better performance than REST for high-volume inference
  2. WebSocket for Real-Time: Ideal for streaming predictions or continuous inference
  3. Enable Authentication: Add auth middleware for production deployments
  4. Monitor Performance: Track latency, throughput, and error rates
  5. Load Balancing: Deploy multiple server instances behind a load balancer
  6. Input Validation: Validate input shapes and types before inference
  7. Error Handling: Return meaningful error messages for debugging

Next Steps