37  Implementacion Practica: PyTorch y JAX

Autor/a

Diego Villalba

Fecha de publicación

19 de mayo de 2026

En el capitulo anterior construimos redes neuronales desde cero utilizando exclusivamente NumPy. El objetivo de ese recorrido fue exponer con total transparencia la matematica subyacente: la propagacion hacia adelante, el calculo del gradiente mediante la regla de la cadena, la actualizacion de pesos en el descenso por gradiente estocastico y las convoluciones como operaciones lineales sobre mapas de caracteristicas. Ese ejercicio es indispensable para comprender por que los algoritmos funcionan, pero no es la forma en que se construyen sistemas de aprendizaje profundo en la practica.

El presente capitulo es el complemento practico: tomamos los mismos conceptos y los implementamos utilizando frameworks de produccion. Existen tres opciones principales segun el nivel de abstracion y el uso previsto. Scikit-learn (Pedregosa et al. 2011) ofrece un MLPClassifier que se integra con la API fit/predict estandar, suficiente para conjuntos de datos tabulares pequenos pero sin capacidad de escalar a arquitecturas personalizadas. PyTorch (Paszke et al. 2019) provee grafos computacionales dinamicos, diferenciacion automatica y una API orientada a objetos que equilibra flexibilidad con ergonomia, lo que lo convierte en el framework dominante en investigacion academica. JAX (Bradbury et al. 2018), desarrollado en Google, adopta un paradigma funcional: las funciones deben ser puras, el estado es externo y la compilacion XLA es transparente mediante jax.jit. Sobre JAX se construye Flax (Heek et al. 2023) como libreria de modulos de redes neuronales, y Optax (DeepMind et al. 2020) como libreria de optimizadores. Este capitulo desarrolla ejemplos completos en PyTorch y JAX, comparando sus filosofias y mostrando como expresar el mismo modelo en cada framework.

1 El ecosistema de aprendizaje profundo

1.1 scikit-learn: MLPClassifier

Scikit-learn (Pedregosa et al. 2011) incluye sklearn.neural_network.MLPClassifier, un perceptron multicapa que se integra sin fricciones en cualquier pipeline de sklearn. Acepta los mismos hiperparametros conceptuales que ya conocemos: numero y tamano de capas ocultas (hidden_layer_sizes), funcion de activacion, optimizador (lbfgs, sgd o adam) y tasa de aprendizaje. La interfaz es identica a la de cualquier otro clasificador de sklearn: fit para entrenar, predict para inferencia y score para exactitud. Esta uniformidad hace que sea trivial sustituirlo por un arbol de decision o una SVM en un experimento comparativo.

Las limitaciones son igualmente claras. No permite definir arquitecturas personalizadas: capas convolucionales, mecanismos de atencion, conexiones residuales o cualquier topologia no secuencial estan fuera de alcance. El entrenamiento ocurre en CPU y no escala a conjuntos de datos de imagen o texto de tamano real. La ausencia de control sobre el ciclo de entrenamiento impide aplicar tecnicas avanzadas como curriculum learning, ciclos de tasa de aprendizaje o perdidas compuestas. Para todo esto se requieren frameworks de bajo nivel.

from sklearn.neural_network import MLPClassifier
from sklearn.datasets import make_moons
import numpy as np

X, y = make_moons(n_samples=400, noise=0.2, random_state=42)
clf = MLPClassifier(hidden_layer_sizes=(16, 16), max_iter=500, random_state=42)
clf.fit(X, y)
print(f"Accuracy: {clf.score(X, y):.3f}")

1.2 PyTorch

PyTorch (Paszke et al. 2019) introdujo en 2017 el concepto de grafo computacional dinamico (define-by-run): el grafo se construye en tiempo de ejecucion a medida que se realizan operaciones sobre tensores, a diferencia del grafo estatico de TensorFlow 1.x que debia compilarse antes de ejecutarse. Esta decision de diseno tiene consecuencias profundas: el codigo de PyTorch es Python normal, se puede depurar con print y pdb, y las estructuras de control condicionales o los bucles se comportan exactamente como se esperaria.

El motor de diferenciacion automatica de PyTorch se llama Autograd. Cada tensor que tiene requires_grad=True registra las operaciones que se le aplican; al llamar .backward() sobre un escalar (la perdida), Autograd recorre el grafo en sentido inverso aplicando la regla de la cadena y acumula los gradientes en el atributo .grad de cada parametro. La clase base para todos los modelos es nn.Module, que organiza los parametros y sub-modulos, y obliga a definir un metodo forward que describe el computo. El ciclo de entrenamiento es explicito: el programador escribe el bucle, llama al paso hacia adelante, calcula la perdida, llama a .backward() y actualiza los pesos con el optimizador. Esta explicitacion es una caracteristica, no un defecto: otorga control total.

1.3 JAX

JAX (Bradbury et al. 2018) parte de una premisa diferente: las transformaciones funcionales son ciudadanos de primera clase. El nucleo de JAX son cuatro transformaciones componibles: jax.grad (diferenciacion automatica de funciones escalares), jax.jit (compilacion XLA transparente), jax.vmap (vectorizacion sobre dimensiones de lote) y jax.pmap (paralelismo en multiples dispositivos). Estas transformaciones se pueden anidar y componer libremente.

Para beneficiarse de jax.jit, las funciones deben ser puras: no pueden tener efectos secundarios ni depender de estado global mutable. El estado del modelo (los pesos) se almacena como pytrees: estructuras de datos anidadas de Python (diccionarios o listas de arrays de JAX) que el sistema puede recorrer automaticamente para aplicar transformaciones. Flax (Heek et al. 2023) provee una API de modulos similar a la de PyTorch, pero manteniendo la disciplina funcional: model.init(key, x) devuelve los parametros iniciales y model.apply(params, x) ejecuta el paso hacia adelante sin modificar ningun estado interno. Optax (DeepMind et al. 2020) proporciona optimizadores funcionales compatibles con esta convencion.

1.4 Comparativa de frameworks

La eleccion entre frameworks depende del caso de uso. Para prototipado rapido con datos tabulares, sklearn es imbatible en simplicidad. Para investigacion en arquitecturas novedosas o fine-tuning de modelos preentrenados del ecosistema HuggingFace, PyTorch es el estandar de facto. Para investigacion que requiere transformaciones funcionales, diferenciacion de orden superior o integracion nativa con TPUs de Google, JAX es la opcion mas poderosa. El siguiente radar resume estas caracteristicas de forma comparativa.

Mostrar codigo
import plotly.graph_objects as go

categories = [
    "Facilidad de uso",
    "Flexibilidad",
    "Velocidad de prototipado",
    "Escalabilidad",
    "Soporte para investigacion"
]

frameworks = {
    "scikit-learn": {"scores": [5, 2, 5, 2, 1], "color": "rgba(99,110,250,0.6)"},
    "PyTorch": {"scores": [3, 5, 4, 5, 5], "color": "rgba(239,85,59,0.6)"},
    "JAX": {"scores": [2, 5, 3, 5, 5], "color": "rgba(0,204,150,0.6)"},
}

fig = go.Figure()

for name, data in frameworks.items():
    scores = data["scores"] + [data["scores"][0]]
    cats = categories + [categories[0]]
    fig.add_trace(go.Scatterpolar(
        r=scores,
        theta=cats,
        fill="toself",
        name=name,
        line_color=data["color"],
        fillcolor=data["color"],
        opacity=0.7
    ))

fig.update_layout(
    polar=dict(
        radialaxis=dict(visible=True, range=[0, 5], tickfont=dict(size=10)),
        angularaxis=dict(tickfont=dict(size=11))
    ),
    showlegend=True,
    title=dict(
        text="Comparativa de frameworks de aprendizaje profundo",
        x=0.5,
        font=dict(size=15)
    ),
    legend=dict(orientation="h", yanchor="bottom", y=-0.15, xanchor="center", x=0.5),
    height=480,
    margin=dict(t=60, b=80)
)

fig.show()
Figura 1: Comparativa de frameworks de aprendizaje profundo en cinco dimensiones.

2 PyTorch: fundamentos

2.1 Tensores y autograd

Los tensores de PyTorch son analogos a los arrays de NumPy: admiten las mismas operaciones aritmeticas, indexacion y broadcasting. La diferencia fundamental es que un tensor puede residir en GPU (mediante .to("cuda")) y puede tener activado el rastreo de gradientes mediante requires_grad=True. Cuando se activa el rastreo, cada operacion sobre el tensor crea un nodo en el grafo computacional y registra la funcion que permite calcular el gradiente local.

Para un parametro \(\theta\) y una perdida escalar \(\mathcal{L}\), la llamada loss.backward() computa:

\[ \frac{\partial \mathcal{L}}{\partial \theta} \quad \forall \theta \text{ con } \texttt{requires\_grad=True} \tag{1}\]

El gradiente se acumula en theta.grad. Es importante notar que .backward() acumula (suma) el gradiente al valor existente de .grad; por eso el ciclo de entrenamiento debe llamar optimizer.zero_grad() antes de cada paso para limpiar los gradientes del paso anterior. Si no se limpiaran, el gradiente del paso \(t\) se sumaria al del paso \(t-1\), lo que producria actualizaciones incorrectas.

Para funciones compuestas, Autograd aplica la regla de la cadena automaticamente. Consideremos \(z = (w \cdot x + b)^2\) y \(\mathcal{L} = z\). El gradiente respecto a \(w\) es \(\partial \mathcal{L} / \partial w = 2(wx + b) \cdot x\), y Autograd lo calcula recorriendo el grafo en sentido inverso nodo a nodo. La siguiente figura ilustra este grafo para valores concretos \(w = 0.5\), \(x = 2.0\), \(b = 0.3\).

Mostrar codigo
import plotly.graph_objects as go
import numpy as np

# Node positions (x, y)
nodes = {
    "x":    (0.0, 2.0),
    "w":    (0.0, 0.0),
    "b":    (0.0, 4.0),
    "mul":  (1.5, 1.0),
    "add":  (3.0, 2.0),
    "sq":   (4.5, 2.0),
    "L":    (6.0, 2.0),
}

node_labels = {
    "x":   "x = 2.0",
    "w":   "w = 0.5",
    "b":   "b = 0.3",
    "mul": "w*x = 1.0",
    "add": "z1 = 1.3",
    "sq":  "z = 1.69",
    "L":   "L = 1.69",
}

# Forward edges: (from, to)
forward_edges = [
    ("x", "mul"),
    ("w", "mul"),
    ("mul", "add"),
    ("b", "add"),
    ("add", "sq"),
    ("sq", "L"),
]

# Backward annotations: (edge, formula)
backward_annots = [
    (("x", "mul"),  "dL/dx = 2z1*w = 1.3"),
    (("w", "mul"),  "dL/dw = 2z1*x = 2.6"),
    (("b", "add"),  "dL/db = 2z1 = 2.6"),
    (("mul", "add"),"dL/dmul = 2z1 = 2.6"),
    (("add", "sq"), "dL/dadd = 2z1 = 2.6"),
    (("sq", "L"),   "dL/dsq = 1"),
]

fig = go.Figure()

# Draw forward edges (black arrows)
for (src, dst) in forward_edges:
    x0, y0 = nodes[src]
    x1, y1 = nodes[dst]
    fig.add_annotation(
        x=x1, y=y1, ax=x0, ay=y0,
        xref="x", yref="y", axref="x", ayref="y",
        showarrow=True,
        arrowhead=3, arrowsize=1.3, arrowwidth=2,
        arrowcolor="black"
    )

# Draw backward gradient annotations (orange dashed)
for ((src, dst), formula) in backward_annots:
    x0, y0 = nodes[src]
    x1, y1 = nodes[dst]
    # midpoint for label
    mx = (x0 + x1) / 2
    my = (y0 + y1) / 2 - 0.35
    fig.add_annotation(
        x=x0, y=y0, ax=x1, ay=y1,
        xref="x", yref="y", axref="x", ayref="y",
        showarrow=True,
        arrowhead=2, arrowsize=1.1, arrowwidth=1.5,
        arrowcolor="orange",
        opacity=0.85
    )
    fig.add_annotation(
        x=mx, y=my,
        text=formula,
        showarrow=False,
        font=dict(size=8, color="darkorange"),
        bgcolor="rgba(255,255,220,0.8)",
        bordercolor="orange",
        borderwidth=1
    )

# Draw nodes
node_x = [v[0] for v in nodes.values()]
node_y = [v[1] for v in nodes.values()]
node_text = [node_labels[k] for k in nodes.keys()]

fig.add_trace(go.Scatter(
    x=node_x, y=node_y,
    mode="markers+text",
    marker=dict(size=42, color="steelblue", line=dict(color="white", width=2)),
    text=node_text,
    textposition="top center",
    textfont=dict(size=9),
    hoverinfo="none"
))

# Legend items (manual)
fig.add_trace(go.Scatter(
    x=[None], y=[None], mode="lines",
    line=dict(color="black", width=2),
    name="Forward pass"
))
fig.add_trace(go.Scatter(
    x=[None], y=[None], mode="lines",
    line=dict(color="orange", width=2, dash="dash"),
    name="Backward pass (gradientes)"
))

fig.update_layout(
    title=dict(
        text="Grafo computacional y propagacion hacia atras en PyTorch",
        x=0.5, font=dict(size=14)
    ),
    xaxis=dict(visible=False, range=[-0.5, 6.8]),
    yaxis=dict(visible=False, range=[-0.8, 5.0]),
    showlegend=True,
    legend=dict(x=0.01, y=0.01),
    height=400,
    margin=dict(t=60, b=20, l=20, r=20),
    plot_bgcolor="white"
)

fig.show()
Figura 2: Grafo computacional para z = (w*x + b)^2 con propagacion hacia adelante (negro) y hacia atras (naranja).

2.2 El patron nn.Module

La clase nn.Module es la pieza central de PyTorch para definir modelos. Toda arquitectura, ya sea un MLP simple o un transformador con miles de millones de parametros, es una subclase de nn.Module. El contrato es sencillo: el constructor __init__ declara las capas y sub-modulos (que deben ser atributos de la instancia para que PyTorch los registre), y el metodo forward describe el computo que convierte una entrada en una salida.

Mostrar codigo
import torch
import torch.nn as nn

class MLP(nn.Module):
    def __init__(self, input_dim, hidden_dims, output_dim):
        super().__init__()
        dims = [input_dim] + hidden_dims + [output_dim]
        layers = []
        for i in range(len(dims) - 1):
            layers.append(nn.Linear(dims[i], dims[i+1]))
            if i < len(dims) - 2:
                layers.append(nn.ReLU())
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        return self.net(x)

nn.Linear(in, out) implementa la transformacion afin \(y = xW^T + b\) con pesos inicializados por el metodo de Kaiming y bias en cero. nn.ReLU() aplica la activacion rectificadora \(\max(0, z)\) elemento a elemento. nn.Sequential encadena modulos de modo que la salida de cada uno es la entrada del siguiente, evitando la necesidad de escribir explicitamente cada paso en forward.

Dos metodos de nn.Module son especialmente utiles en la practica. model.parameters() devuelve un iterador sobre todos los tensores con requires_grad=True en el modelo (pesos y biases de todas las capas); este iterador se pasa directamente al optimizador. model.state_dict() devuelve un diccionario ordenado con los valores actuales de todos los parametros, lo que permite guardar y cargar modelos con torch.save y torch.load. Adicionalmente, model.train() y model.eval() cambian el modo del modelo: en modo evaluacion, capas como nn.Dropout desactivan el abandono aleatorio y nn.BatchNorm usa estadisticas acumuladas en lugar de estadisticas de lote, lo que es esencial para obtener predicciones deterministas y correctas.

2.3 El ciclo de entrenamiento

El ciclo de entrenamiento explicito es una de las caracteristicas mas apreciadas de PyTorch. A diferencia de los frameworks que ocultan el bucle interno (como Keras o sklearn), PyTorch expone cada paso, lo que facilita la implementacion de logicas de entrenamiento no estandar.

Mostrar codigo
# Patron ilustrativo — no ejecutar directamente
import torch
import torch.nn as nn

def training_loop_example():
    # Suponer que model, X_train, y_train ya estan definidos
    model = MLP(2, [16], 1)
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
    criterion = nn.BCEWithLogitsLoss()
    n_epochs = 10

    for epoch in range(n_epochs):
        model.train()
        optimizer.zero_grad()           # limpiar gradientes acumulados
        logits = model(torch.zeros(8, 2))  # forward pass
        loss = criterion(logits, torch.zeros(8, 1))
        loss.backward()                 # backward pass: calcula gradientes
        optimizer.step()                # actualizar pesos: theta -= lr * grad

Cada linea del bucle cumple una funcion especifica. optimizer.zero_grad() es necesario porque PyTorch acumula gradientes por defecto; llamarlo al inicio de cada paso garantiza que los gradientes del epoch anterior no contaminen el actual. model.train() activa el modo entrenamiento antes del paso hacia adelante. loss.backward() recorre el grafo computacional en sentido inverso y llena .grad en todos los parametros. optimizer.step() aplica la actualizacion: para Adam, esto implica mantener momentos de primer y segundo orden y calcular la actualizacion adaptativa.

Se prefiere nn.BCEWithLogitsLoss sobre nn.BCELoss por razones numericas: la primera combina la sigmoide y la entropia cruzada en una sola operacion usando el truco log-sum-exp, lo que evita desbordamientos cuando el logit es muy grande o muy negativo. La formula interna es:

\[ \mathcal{L} = \frac{1}{N} \sum_{i=1}^{N} \left[ \max(z_i, 0) - z_i y_i + \log\left(1 + e^{-|z_i|}\right) \right] \tag{2}\]

donde \(z_i\) es el logit (salida antes de la sigmoide) e \(y_i \in \{0, 1\}\) es la etiqueta.

3 Clasificacion con PyTorch

Aplicamos el patron descrito en la seccion anterior a un problema de clasificacion binaria con make_moons. Este conjunto de datos es un clasico de prueba para fronteras de decision no lineales: dos semilunares entrelazadas que no son separables linealmente. Una red con dos capas ocultas de 32 neuronas con activacion ReLU deberia aprender la frontera correctamente.

La arquitectura es \([2 \to 32 \to 32 \to 1]\): la capa de entrada recibe las dos coordenadas \((x_1, x_2)\), las dos capas ocultas aprenden representaciones no lineales, y la capa de salida produce un unico logit que se convierte en probabilidad mediante la sigmoide. Para la clasificacion, el umbral es 0.5: si \(\sigma(z) > 0.5\) se predice clase 1, de lo contrario clase 0.

Mostrar codigo
import torch
import torch.nn as nn
import numpy as np
from sklearn.datasets import make_moons
import plotly.graph_objects as go
from plotly.subplots import make_subplots

# ---- Data preparation ----
X_np, y_np = make_moons(n_samples=500, noise=0.25, random_state=42)
n_train = int(0.8 * len(X_np))
idx = np.random.RandomState(42).permutation(len(X_np))
train_idx, val_idx = idx[:n_train], idx[n_train:]

X_train = torch.tensor(X_np[train_idx], dtype=torch.float32)
y_train = torch.tensor(y_np[train_idx], dtype=torch.float32).unsqueeze(1)
X_val   = torch.tensor(X_np[val_idx],   dtype=torch.float32)
y_val   = torch.tensor(y_np[val_idx],   dtype=torch.float32).unsqueeze(1)

# ---- Model definition ----
class MLP(nn.Module):
    def __init__(self, input_dim, hidden_dims, output_dim):
        super().__init__()
        dims = [input_dim] + hidden_dims + [output_dim]
        layers = []
        for i in range(len(dims) - 1):
            layers.append(nn.Linear(dims[i], dims[i+1]))
            if i < len(dims) - 2:
                layers.append(nn.ReLU())
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        return self.net(x)

torch.manual_seed(42)
model_pt = MLP(2, [32, 32], 1)
optimizer = torch.optim.Adam(model_pt.parameters(), lr=0.01)
criterion = nn.BCEWithLogitsLoss()

# ---- Training ----
train_losses, val_losses = [], []
train_accs, val_accs = [], []

for epoch in range(300):
    model_pt.train()
    optimizer.zero_grad()
    logits = model_pt(X_train)
    loss = criterion(logits, y_train)
    loss.backward()
    optimizer.step()

    with torch.no_grad():
        train_pred = (torch.sigmoid(logits) >= 0.5).float()
        train_acc = (train_pred == y_train).float().mean().item()

        model_pt.eval()
        val_logits = model_pt(X_val)
        val_loss = criterion(val_logits, y_val).item()
        val_pred = (torch.sigmoid(val_logits) >= 0.5).float()
        val_acc = (val_pred == y_val).float().mean().item()

    train_losses.append(loss.item())
    val_losses.append(val_loss)
    train_accs.append(train_acc)
    val_accs.append(val_acc)

epochs = list(range(1, 301))

# ---- Figure ----
fig = make_subplots(rows=1, cols=2,
    subplot_titles=["Perdida (BCE)", "Exactitud"])

fig.add_trace(go.Scatter(x=epochs, y=train_losses, name="Train loss",
    line=dict(color="steelblue", width=2)), row=1, col=1)
fig.add_trace(go.Scatter(x=epochs, y=val_losses, name="Val loss",
    line=dict(color="salmon", width=2, dash="dash")), row=1, col=1)

fig.add_trace(go.Scatter(x=epochs, y=train_accs, name="Train acc",
    line=dict(color="steelblue", width=2), showlegend=False), row=1, col=2)
fig.add_trace(go.Scatter(x=epochs, y=val_accs, name="Val acc",
    line=dict(color="salmon", width=2, dash="dash"), showlegend=False), row=1, col=2)

fig.update_xaxes(title_text="Epoch")
fig.update_yaxes(title_text="BCE Loss", row=1, col=1)
fig.update_yaxes(title_text="Accuracy", row=1, col=2)

fig.update_layout(
    title=dict(text="Entrenamiento del MLP en PyTorch (make_moons)", x=0.5, font=dict(size=14)),
    height=400,
    legend=dict(orientation="h", yanchor="bottom", y=-0.25, xanchor="center", x=0.5)
)
fig.show()
Figura 3: Curvas de perdida y exactitud durante el entrenamiento del MLP en PyTorch sobre make_moons.

Las curvas de perdida muestran el comportamiento tipico de un modelo bien ajustado: la perdida de entrenamiento disminuye monotonamente mientras que la perdida de validacion converge a un valor cercano pero ligeramente superior, indicando que el modelo generaliza sin sobreajuste significativo. Las curvas de exactitud muestran que ambas fracciones (entrenamiento y validacion) alcanzan valores por encima del 90% tras los primeros 50 epochs, con una mejora marginal en los epochs restantes.

3.1 Sensibilidad a la tasa de aprendizaje

La tasa de aprendizaje \(\eta\) es el hiperparametro mas critico de cualquier red neuronal. Un valor demasiado bajo produce convergencia lenta o estancamiento en minimos locales poco favorables; un valor demasiado alto provoca oscilaciones en la perdida o incluso divergencia. La siguiente figura entrena cuatro instancias del mismo MLP con tasas de aprendizaje que cubren tres ordenes de magnitud, permitiendo observar estos regimenes de forma directa.

Mostrar codigo
import torch
import torch.nn as nn
import numpy as np
from sklearn.datasets import make_moons
import plotly.graph_objects as go

X_lr, y_lr = make_moons(n_samples=500, noise=0.25, random_state=42)
n_tr_lr = int(0.8 * len(X_lr))
idx_lr = np.random.RandomState(42).permutation(len(X_lr))
X_tr_lr = torch.tensor(X_lr[idx_lr[:n_tr_lr]], dtype=torch.float32)
y_tr_lr = torch.tensor(y_lr[idx_lr[:n_tr_lr]], dtype=torch.float32).unsqueeze(1)
X_vl_lr = torch.tensor(X_lr[idx_lr[n_tr_lr:]], dtype=torch.float32)
y_vl_lr = torch.tensor(y_lr[idx_lr[n_tr_lr:]], dtype=torch.float32).unsqueeze(1)

learning_rates = [0.0005, 0.005, 0.05, 0.5]
lr_colors      = ["#636EFA", "#00CC96", "#EF553B", "#AB63FA"]
n_ep_lr        = 300
epochs_lr      = list(range(1, n_ep_lr + 1))
criterion_lr   = nn.BCEWithLogitsLoss()

fig = go.Figure()

for lr, color in zip(learning_rates, lr_colors):
    torch.manual_seed(42)
    net = nn.Sequential(
        nn.Linear(2, 32), nn.ReLU(),
        nn.Linear(32, 32), nn.ReLU(),
        nn.Linear(32, 1)
    )
    opt_lr = torch.optim.Adam(net.parameters(), lr=lr)
    vl_curve = []

    for _ in range(n_ep_lr):
        net.train()
        opt_lr.zero_grad()
        loss_lr = criterion_lr(net(X_tr_lr), y_tr_lr)
        loss_lr.backward()
        opt_lr.step()
        with torch.no_grad():
            net.eval()
            vl_curve.append(criterion_lr(net(X_vl_lr), y_vl_lr).item())

    fig.add_trace(go.Scatter(
        x=epochs_lr, y=vl_curve,
        name=f"lr = {lr}",
        line=dict(color=color, width=2)
    ))

fig.add_shape(type="line",
    x0=0, x1=n_ep_lr, y0=0.15, y1=0.15,
    line=dict(color="gray", dash="dot", width=1))
fig.add_annotation(
    x=250, y=0.12,
    text="Umbral de buen ajuste",
    showarrow=False, font=dict(size=10, color="gray"))

fig.update_layout(
    title=dict(text="Sensibilidad a la tasa de aprendizaje (val loss)", x=0.5, font=dict(size=14)),
    xaxis_title="Epoch",
    yaxis_title="Perdida de validacion (BCE)",
    legend=dict(orientation="h", yanchor="bottom", y=-0.22, xanchor="center", x=0.5),
    height=420,
    margin=dict(t=60, b=80)
)
fig.show()
Figura 4: Perdida de validacion para cuatro tasas de aprendizaje distintas. La zona optima produce una curva suave y decreciente; tasas demasiado bajas convergen lentamente y tasas demasiado altas oscilan o divergen.

La frontera de decision aprendida por la red se puede visualizar evaluando el modelo sobre una malla densa de puntos en el espacio de entrada. Cada punto de la malla recibe una probabilidad predicha \(\hat{p} = \sigma(f_\theta(x))\), y la isocurva \(\hat{p} = 0.5\) define la frontera de decision.

Mostrar codigo
import plotly.graph_objects as go
import numpy as np
import torch

# Build grid
x1_range = np.linspace(X_np[:, 0].min() - 0.5, X_np[:, 0].max() + 0.5, 200)
x2_range = np.linspace(X_np[:, 1].min() - 0.5, X_np[:, 1].max() + 0.5, 200)
xx1, xx2 = np.meshgrid(x1_range, x2_range)
grid = np.c_[xx1.ravel(), xx2.ravel()]
grid_t = torch.tensor(grid, dtype=torch.float32)

model_pt.eval()
with torch.no_grad():
    probs = torch.sigmoid(model_pt(grid_t)).numpy().reshape(200, 200)

final_val_acc = val_accs[-1]

# Scatter colors
colors_train = ["#EF553B" if c == 1 else "#636EFA" for c in y_np[train_idx]]

fig = go.Figure()

fig.add_trace(go.Contour(
    x=x1_range, y=x2_range, z=probs,
    colorscale="RdBu_r",
    opacity=0.75,
    showscale=True,
    colorbar=dict(title="P(clase=1)"),
    contours=dict(
        start=0, end=1, size=0.05,
        showlines=False
    )
))

# White decision boundary at 0.5
fig.add_trace(go.Contour(
    x=x1_range, y=x2_range, z=probs,
    showscale=False,
    contours=dict(
        start=0.5, end=0.5, size=0,
        coloring="lines"
    ),
    line=dict(color="white", width=3),
    name="Frontera (p=0.5)"
))

fig.add_trace(go.Scatter(
    x=X_np[train_idx, 0], y=X_np[train_idx, 1],
    mode="markers",
    marker=dict(
        color=y_np[train_idx].tolist(),
        colorscale=[[0, "#636EFA"], [1, "#EF553B"]],
        size=6, line=dict(color="white", width=0.5)
    ),
    name="Puntos de entrenamiento"
))

fig.add_annotation(
    x=0.02, y=0.97, xref="paper", yref="paper",
    text=f"Val accuracy: {final_val_acc:.3f}",
    showarrow=False,
    font=dict(size=12, color="white"),
    bgcolor="rgba(0,0,0,0.5)",
    bordercolor="white", borderwidth=1
)

fig.update_layout(
    title=dict(text="Frontera de decision del MLP entrenado con PyTorch", x=0.5, font=dict(size=14)),
    xaxis_title="x1", yaxis_title="x2",
    height=480,
    legend=dict(orientation="h", yanchor="bottom", y=-0.15, xanchor="center", x=0.5)
)
fig.show()
Figura 5: Frontera de decision del MLP entrenado con PyTorch sobre make_moons. La isocurva blanca corresponde a probabilidad 0.5.

La frontera de decision resultante es claramente no lineal: la red ha aprendido a separar las dos lunas con una curva suave que sigue la estructura geometrica de los datos. La isocurva blanca en \(\hat{p} = 0.5\) delimita la region de incertidumbre maxima, donde el modelo considera igualmente probables ambas clases. Las regiones de color intenso corresponden a zonas de alta confianza en la prediccion.

3.2 Animacion del aprendizaje

Una de las visualizaciones mas reveladoras del proceso de entrenamiento de una red neuronal es observar como evoluciona la frontera de decision a lo largo de las epocas. Al comienzo del entrenamiento, con pesos aleatorios pequenos, la frontera es casi lineal porque las activaciones ReLU operan en su region lineal. A medida que los pesos crecen y las no linealidades se activan, la frontera adquiere curvatura progresivamente hasta ajustarse a la geometria de los datos. El deslizador permite navegar por este proceso epoch a epoch.

Mostrar codigo
import torch
import torch.nn as nn
import numpy as np
from sklearn.datasets import make_moons
import plotly.graph_objects as go

# Dataset
X_an, y_an = make_moons(n_samples=500, noise=0.25, random_state=42)
n_tr_an = int(0.8 * len(X_an))
idx_an  = np.random.RandomState(42).permutation(len(X_an))
X_tr_an = torch.tensor(X_an[idx_an[:n_tr_an]], dtype=torch.float32)
y_tr_an = torch.tensor(y_an[idx_an[:n_tr_an]], dtype=torch.float32).unsqueeze(1)

# Decision boundary grid
h_an   = 0.05
x0_min, x0_max = X_an[:, 0].min() - 0.4, X_an[:, 0].max() + 0.4
x1_min, x1_max = X_an[:, 1].min() - 0.4, X_an[:, 1].max() + 0.4
gx, gy = np.meshgrid(np.arange(x0_min, x0_max, h_an),
                     np.arange(x1_min, x1_max, h_an))
grid_an = torch.tensor(np.c_[gx.ravel(), gy.ravel()], dtype=torch.float32)
gx_vals = np.arange(x0_min, x0_max, h_an)
gy_vals = np.arange(x1_min, x1_max, h_an)

# Train and capture snapshots every 10 epochs
torch.manual_seed(42)
net_an = nn.Sequential(
    nn.Linear(2, 32), nn.ReLU(),
    nn.Linear(32, 32), nn.ReLU(),
    nn.Linear(32, 1)
)
opt_an  = torch.optim.Adam(net_an.parameters(), lr=0.01)
crit_an = nn.BCEWithLogitsLoss()

snap_epochs = [0] + list(range(10, 301, 10))
snapshots_an = []

with torch.no_grad():
    net_an.eval()
    zz0 = torch.sigmoid(net_an(grid_an)).numpy().reshape(gx.shape)
snapshots_an.append((0, zz0.copy()))

for ep in range(1, 301):
    net_an.train()
    opt_an.zero_grad()
    crit_an(net_an(X_tr_an), y_tr_an).backward()
    opt_an.step()
    if ep in snap_epochs:
        with torch.no_grad():
            net_an.eval()
            zz = torch.sigmoid(net_an(grid_an)).numpy().reshape(gx.shape)
        snapshots_an.append((ep, zz.copy()))

cl0_an = X_an[y_an == 0]
cl1_an = X_an[y_an == 1]

def make_frame_data(zz):
    return [
        go.Contour(z=zz, x=gx_vals, y=gy_vals,
                   colorscale="RdBu", zmin=0, zmax=1,
                   showscale=False, opacity=0.65,
                   contours=dict(showlines=False)),
        go.Contour(z=zz, x=gx_vals, y=gy_vals,
                   colorscale=[[0, "white"], [1, "white"]],
                   zmin=0.495, zmax=0.505, showscale=False,
                   contours=dict(coloring="lines", showlines=True,
                                 start=0.5, end=0.5, size=0.01),
                   line=dict(color="white", width=2.5)),
        go.Scatter(x=cl0_an[:, 0], y=cl0_an[:, 1], mode="markers",
                   marker=dict(color="#EF553B", size=5, opacity=0.85),
                   name="Clase 0", showlegend=False),
        go.Scatter(x=cl1_an[:, 0], y=cl1_an[:, 1], mode="markers",
                   marker=dict(color="#636EFA", size=5, opacity=0.85),
                   name="Clase 1", showlegend=False),
    ]

frames_an = [
    go.Frame(data=make_frame_data(zz),
             name=str(ep),
             layout=go.Layout(title_text=f"Epoch {ep}  |  frontera de decision"))
    for ep, zz in snapshots_an
]

init_ep_an, init_zz_an = snapshots_an[0]
init_data = make_frame_data(init_zz_an)
init_data[0]["showscale"] = True
init_data[0]["colorbar"]  = dict(title="P(clase 1)", len=0.7)

fig = go.Figure(data=init_data, frames=frames_an)

slider_steps = [
    dict(args=[[str(ep)],
               dict(mode="immediate",
                    frame=dict(duration=0, redraw=True),
                    transition=dict(duration=0))],
         label=str(ep), method="animate")
    for ep, _ in snapshots_an
]

fig.update_layout(
    title=dict(text=f"Epoch 0  |  frontera de decision", x=0.5, font=dict(size=14)),
    xaxis=dict(title="x1", showgrid=False),
    yaxis=dict(title="x2", showgrid=False, scaleanchor="x"),
    height=530,
    margin=dict(t=60, b=130),
    updatemenus=[dict(
        type="buttons", showactive=False,
        y=-0.1, x=0.5, xanchor="center",
        buttons=[
            dict(label="Reproducir",
                 method="animate",
                 args=[None, dict(frame=dict(duration=180, redraw=True),
                                  fromcurrent=True, mode="immediate")]),
            dict(label="Pausa",
                 method="animate",
                 args=[[None], dict(frame=dict(duration=0, redraw=False),
                                    mode="immediate")])
        ]
    )],
    sliders=[dict(
        active=0,
        steps=slider_steps,
        currentvalue=dict(prefix="Epoch: ", font=dict(size=12)),
        pad=dict(t=45, b=10),
        len=0.9, x=0.05
    )]
)
fig.show()
Figura 6: Evolucion de la frontera de decision del MLP durante 300 epocas de entrenamiento. Al inicio la frontera es casi lineal; progresivamente adquiere la curvatura necesaria para separar las dos lunas. Usa el deslizador o los botones Reproducir/Pausa para navegar por el entrenamiento.

4 CNN con PyTorch

4.1 Arquitectura CNN en PyTorch

Para ilustrar las redes convolucionales en PyTorch utilizamos el conjunto de datos digits de sklearn, que contiene 1797 imagenes de digitos manuscritos en escala de grises de \(8 \times 8\) pixeles y 10 clases (del 0 al 9). Aunque es considerablemente mas pequeno que MNIST (\(28 \times 28\)), es suficiente para demostrar la arquitectura CNN y su entrenamiento en un tiempo razonable.

PyTorch utiliza la convencion de canales primero (channels-first): los tensores de imagen tienen forma (N, C, H, W), donde \(N\) es el tamano del lote, \(C\) es el numero de canales, \(H\) es la altura y \(W\) es el ancho. Para digits, las imagenes son escala de grises (\(C=1\)), por lo que la forma es (N, 1, 8, 8).

nn.Conv2d(in_channels, out_channels, kernel_size, padding) implementa la convolucion discreta: aplica out_channels filtros de tamano kernel_size x kernel_size a la entrada, produciendo out_channels mapas de caracteristicas. Con padding=1 y kernel_size=3, las dimensiones espaciales se conservan. nn.MaxPool2d(2) reduce las dimensiones a la mitad seleccionando el maximo en ventanas de \(2 \times 2\). nn.Flatten() convierte el tensor multidimensional en un vector 1D para conectar con capas densas. nn.Dropout(0.3) desactiva aleatoriamente el 30% de las neuronas durante el entrenamiento para regularizar.

Mostrar codigo
import torch
import torch.nn as nn
import numpy as np
from sklearn.datasets import load_digits
import plotly.graph_objects as go
from plotly.subplots import make_subplots

# ---- Data ----
digits = load_digits()
X_dig = digits.data.astype(np.float32) / 16.0   # normalize to [0, 1]
y_dig = digits.target.astype(np.int64)

rng = np.random.RandomState(42)
idx_dig = rng.permutation(len(X_dig))
n_tr = int(0.8 * len(X_dig))
tr_idx, vl_idx = idx_dig[:n_tr], idx_dig[n_tr:]

X_tr_t = torch.tensor(X_dig[tr_idx].reshape(-1, 1, 8, 8))
y_tr_t = torch.tensor(y_dig[tr_idx])
X_vl_t = torch.tensor(X_dig[vl_idx].reshape(-1, 1, 8, 8))
y_vl_t = torch.tensor(y_dig[vl_idx])

# ---- Model ----
class CNN(nn.Module):
    def __init__(self):
        super().__init__()
        self.features = nn.Sequential(
            nn.Conv2d(1, 16, kernel_size=3, padding=1),   # (N, 16, 8, 8)
            nn.ReLU(),
            nn.MaxPool2d(2),                               # (N, 16, 4, 4)
            nn.Conv2d(16, 32, kernel_size=3, padding=1),  # (N, 32, 4, 4)
            nn.ReLU(),
            nn.MaxPool2d(2),                               # (N, 32, 2, 2)
        )
        self.classifier = nn.Sequential(
            nn.Flatten(),
            nn.Linear(32 * 2 * 2, 64),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(64, 10),
        )

    def forward(self, x):
        return self.classifier(self.features(x))

torch.manual_seed(42)
cnn_pt = CNN()
opt_cnn = torch.optim.Adam(cnn_pt.parameters(), lr=1e-3)
crit_cnn = nn.CrossEntropyLoss()

cnn_tr_losses, cnn_vl_losses = [], []
cnn_tr_accs, cnn_vl_accs = [], []

for epoch in range(50):
    cnn_pt.train()
    opt_cnn.zero_grad()
    logits_cnn = cnn_pt(X_tr_t)
    loss_cnn = crit_cnn(logits_cnn, y_tr_t)
    loss_cnn.backward()
    opt_cnn.step()

    with torch.no_grad():
        tr_pred = logits_cnn.argmax(dim=1)
        tr_acc = (tr_pred == y_tr_t).float().mean().item()

        cnn_pt.eval()
        vl_logits = cnn_pt(X_vl_t)
        vl_loss = crit_cnn(vl_logits, y_vl_t).item()
        vl_pred = vl_logits.argmax(dim=1)
        vl_acc = (vl_pred == y_vl_t).float().mean().item()

    cnn_tr_losses.append(loss_cnn.item())
    cnn_vl_losses.append(vl_loss)
    cnn_tr_accs.append(tr_acc)
    cnn_vl_accs.append(vl_acc)

epochs_cnn = list(range(1, 51))

fig = make_subplots(rows=1, cols=2,
    subplot_titles=["Perdida (Cross-Entropy)", "Exactitud"])

fig.add_trace(go.Scatter(x=epochs_cnn, y=cnn_tr_losses, name="Train loss",
    line=dict(color="steelblue", width=2)), row=1, col=1)
fig.add_trace(go.Scatter(x=epochs_cnn, y=cnn_vl_losses, name="Val loss",
    line=dict(color="salmon", width=2, dash="dash")), row=1, col=1)

fig.add_trace(go.Scatter(x=epochs_cnn, y=cnn_tr_accs, name="Train acc",
    line=dict(color="steelblue", width=2), showlegend=False), row=1, col=2)
fig.add_trace(go.Scatter(x=epochs_cnn, y=cnn_vl_accs, name="Val acc",
    line=dict(color="salmon", width=2, dash="dash"), showlegend=False), row=1, col=2)

fig.update_xaxes(title_text="Epoch")
fig.update_yaxes(title_text="Cross-Entropy Loss", row=1, col=1)
fig.update_yaxes(title_text="Accuracy", row=1, col=2)

fig.update_layout(
    title=dict(text="Entrenamiento de la CNN en PyTorch (digits)", x=0.5, font=dict(size=14)),
    height=400,
    legend=dict(orientation="h", yanchor="bottom", y=-0.25, xanchor="center", x=0.5)
)
fig.show()
Figura 7: Curvas de perdida y exactitud durante el entrenamiento de la CNN en PyTorch sobre el conjunto digits.

La CNN converge rapidamente en las imagenes de \(8 \times 8\) de digits. La exactitud de validacion supera el 95% en pocos epochs, lo que refleja que estas imagenes son relativamente faciles para una red convolucional: los patrones de cada digito son altamente consistentes y la resolucion baja no introduce demasiada variabilidad. El Dropout con tasa 0.3 previene el sobreajuste, como se observa en la cercania entre las curvas de entrenamiento y validacion.

Una metrica especialmente util para la clasificacion multiclase es la matriz de confusion, que muestra cuantas instancias de la clase verdadera \(i\) fueron predichas como clase \(j\). Los elementos diagonales corresponden a predicciones correctas; los elementos fuera de la diagonal revelan que pares de clases se confunden entre si, lo que puede guiar el analisis de errores y el ajuste de la arquitectura.

Mostrar codigo
import plotly.graph_objects as go
import numpy as np
import torch

# Compute confusion matrix
cnn_pt.eval()
with torch.no_grad():
    vl_preds_np = cnn_pt(X_vl_t).argmax(dim=1).numpy()
    vl_true_np  = y_vl_t.numpy()

n_classes = 10
conf_mat = np.zeros((n_classes, n_classes), dtype=int)
for t, p in zip(vl_true_np, vl_preds_np):
    conf_mat[t, p] += 1

labels = [str(i) for i in range(10)]
text_mat = [[str(conf_mat[i, j]) for j in range(n_classes)] for i in range(n_classes)]

fig = go.Figure(go.Heatmap(
    z=conf_mat.tolist(),
    x=labels, y=labels,
    colorscale="Blues",
    text=text_mat,
    texttemplate="%{text}",
    textfont=dict(size=11),
    colorbar=dict(title="Conteo")
))

fig.update_layout(
    title=dict(text="Matriz de confusion -- CNN en PyTorch (digits)", x=0.5, font=dict(size=14)),
    xaxis=dict(title="Prediccion", tickfont=dict(size=11)),
    yaxis=dict(title="Etiqueta real", tickfont=dict(size=11), autorange="reversed"),
    height=500,
    width=560
)
fig.show()
Figura 8: Matriz de confusion de la CNN entrenada en PyTorch, evaluada sobre el conjunto de validacion de digits.

La diagonal dominante en la matriz de confusion confirma el buen desempeno del modelo. Las confusiones mas frecuentes suelen ocurrir entre digitos con morfologia similar: el 4 y el 9 comparten una curvatura superior cerrada, y el 3 y el 8 tienen regiones de activacion superpuestas. Estos errores son consistentes con los que comete un clasificador humano cuando la escritura es ambigua, lo que sugiere que el modelo ha aprendido representaciones genuinamente relevantes y no artefactos espurios.

4.2 Filtros aprendidos por la primera capa convolucional

Una propiedad notable de las CNNs es que sus filtros no son disenados manualmente: emergen del proceso de optimizacion como los detectores de caracteristicas mas utiles para la tarea. En los primeros niveles de una CNN entrenada en imagenes naturales (ImageNet), los filtros suelen especializarse en bordes orientados, manchas de color y texturas de baja frecuencia. En nuestra CNN entrenada sobre digitos 8x8, la primera capa aprende 16 filtros 3x3 que capturan los patrones mas discriminativos de la escritura numerica.

Mostrar codigo
import numpy as np
import plotly.graph_objects as go
from plotly.subplots import make_subplots

# Extract conv1 weights: shape (16, 1, 3, 3) -> (16, 3, 3)
conv1_w = cnn_pt.features[0].weight.detach().numpy()[:, 0, :, :]  # (16, 3, 3)
n_filt  = conv1_w.shape[0]  # 16

fig = make_subplots(
    rows=2, cols=8,
    subplot_titles=[f"F{i+1}" for i in range(n_filt)],
    horizontal_spacing=0.02,
    vertical_spacing=0.14
)

for i in range(n_filt):
    row = i // 8 + 1
    col = i  % 8 + 1
    w   = conv1_w[i]
    vmax = max(abs(w.min()), abs(w.max())) + 1e-6
    show_cb = (i == n_filt - 1)
    fig.add_trace(
        go.Heatmap(
            z=w,
            colorscale="RdBu",
            zmin=-vmax, zmax=vmax,
            showscale=show_cb,
            colorbar=dict(title="w", len=0.45, y=0.25, thickness=12)
                   if show_cb else None,
            xgap=1, ygap=1
        ),
        row=row, col=col
    )
    fig.update_xaxes(showticklabels=False, row=row, col=col)
    fig.update_yaxes(showticklabels=False, row=row, col=col)

fig.update_layout(
    title=dict(text="Filtros aprendidos — Conv1 (16 filtros 3x3)", x=0.5, font=dict(size=14)),
    height=310,
    margin=dict(t=55, b=15, l=10, r=30)
)
fig.show()
Figura 9: Pesos de los 16 filtros 3x3 aprendidos por la primera capa convolucional de la CNN entrenada en load_digits. Los colores rojo y azul representan pesos positivos y negativos respectivamente. Los patrones revelan detectores de bordes, esquinas y cambios de contraste que la red desarrollo de forma automatica.

5 JAX: fundamentos

5.1 El paradigma funcional

JAX (Bradbury et al. 2018) redefine la relacion entre Python y el hardware de aceleracion. En lugar de un motor de ejecucion ansiosa con estado interno (como PyTorch en modo eager), JAX trata el codigo Python como una descripcion de un programa funcional que puede ser transformado, compilado y ejecutado de forma optima. La condicion es que las funciones sean puras: para los mismos argumentos, siempre devuelven los mismos resultados y no modifican ningun estado externo.

jax.numpy es un reemplazo casi identico a NumPy en terminos de API, pero sus operaciones pueden ejecutarse en GPU o TPU sin cambio de codigo. jax.grad(f)(x) calcula el gradiente de la funcion escalar f en el punto x usando diferenciacion automatica en modo inverso, analogamente al .backward() de PyTorch pero de forma completamente funcional. jax.jit(f) compila f con XLA: la primera llamada traza la funcion y genera codigo optimizado; las llamadas subsiguientes con argumentos de la misma forma reutilizan el codigo compilado y son significativamente mas rapidas. jax.vmap(f) transforma una funcion que opera sobre un solo ejemplo en una que opera sobre un lote, sin necesidad de escribir bucles o indices de lote manualmente.

Mostrar codigo
import jax
import jax.numpy as jnp

def loss(w):
    return jnp.sum(w ** 2)

grad_loss = jax.grad(loss)
w = jnp.array([1.0, 2.0, 3.0])
print("Gradiente de ||w||^2:", grad_loss(w))   # [2., 4., 6.]
Gradiente de ||w||^2: [2. 4. 6.]

5.2 Parametros como pytrees

En JAX no existe un objeto “modelo” con estado interno. Los parametros son diccionarios de arrays de JAX que se pasan explicitamente como argumentos a la funcion de paso hacia adelante. Esta disciplina, aunque inicialmente extranya, tiene ventajas importantes: los parametros son inspeccionables en cualquier momento, se pueden guardar y cargar como diccionarios de Python, y las transformaciones de JAX pueden recorrerlos automaticamente.

Un pytree es cualquier estructura de datos de Python (diccionario, lista, tupla, o combinacion anidada) cuyos “hojas” son arrays de JAX. JAX sabe recorrer estas estructuras y aplicar transformaciones hoja a hoja. Por ejemplo, jax.grad sobre una funcion cuyo primer argumento es un diccionario de parametros devolvera un diccionario de gradientes con exactamente la misma estructura, lo que simplifica enormemente la implementacion del paso de actualizacion.

Mostrar codigo
import jax
import jax.numpy as jnp

def mlp_forward(params, x):
    for i in range(len(params["W"]) - 1):
        x = jnp.maximum(0, x @ params["W"][i] + params["b"][i])
    # last layer: no activation (logit)
    x = x @ params["W"][-1] + params["b"][-1]
    return x

5.3 Visualizacion: grad de la entropia cruzada

Para ilustrar jax.grad de forma concreta, calculamos la perdida de entropia cruzada binaria y su gradiente respecto a un peso escalar \(w\) variando en \([-3, 3]\), con entrada fija \(x = 1\) e etiqueta \(y = 1\). La perdida es:

\[ \mathcal{L}(w) = -\left[ y \log \sigma(wx) + (1-y) \log(1 - \sigma(wx)) \right] \tag{3}\]

donde \(\sigma(z) = 1 / (1 + e^{-z})\). Con \(y = x = 1\), la perdida simplifica a \(\mathcal{L}(w) = -\log \sigma(w)\), que es una funcion decreciente (a mayor \(w\), la prediccion \(\sigma(w)\) se acerca a 1 y la perdida se acerca a 0). El gradiente es \(\partial \mathcal{L} / \partial w = \sigma(w) - 1 \leq 0\), siempre negativo: el descenso por gradiente siempre incrementa \(w\) para que la prediccion se acerque a \(y=1\).

Mostrar codigo
import jax
import jax.numpy as jnp
import numpy as np
import plotly.graph_objects as go
from plotly.subplots import make_subplots

x_fixed = 1.0
y_fixed = 1.0

def bce_single(w):
    z = w * x_fixed
    sigma = jax.nn.sigmoid(z)
    eps = 1e-7
    sigma = jnp.clip(sigma, eps, 1 - eps)
    return -(y_fixed * jnp.log(sigma) + (1 - y_fixed) * jnp.log(1 - sigma))

grad_bce = jax.grad(bce_single)

w_vals = np.linspace(-3.0, 3.0, 200)
loss_vals = np.array([float(bce_single(float(w))) for w in w_vals])
grad_vals = np.array([float(grad_bce(float(w))) for w in w_vals])

fig = make_subplots(rows=2, cols=1,
    subplot_titles=["Perdida BCE", "Gradiente dL/dw"],
    shared_xaxes=True, vertical_spacing=0.12)

fig.add_trace(go.Scatter(
    x=w_vals, y=loss_vals,
    line=dict(color="steelblue", width=2.5),
    name="L(w)"
), row=1, col=1)

fig.add_trace(go.Scatter(
    x=w_vals, y=grad_vals,
    line=dict(color="darkorange", width=2.5),
    name="dL/dw"
), row=2, col=1)

fig.add_hline(y=0, line=dict(color="gray", dash="dot", width=1), row=2, col=1)

fig.update_xaxes(title_text="w", row=2, col=1)
fig.update_yaxes(title_text="Perdida", row=1, col=1)
fig.update_yaxes(title_text="Gradiente", row=2, col=1)

fig.update_layout(
    title=dict(text="jax.grad: perdida y gradiente de la entropia cruzada", x=0.5, font=dict(size=14)),
    height=480,
    showlegend=True,
    legend=dict(orientation="h", yanchor="bottom", y=-0.12, xanchor="center", x=0.5)
)
fig.show()
Figura 10: Perdida de entropia cruzada binaria y su gradiente calculado con jax.grad, para x=1, y=1.

La figura confirma la intuicion matematica: la perdida es alta y decreciente para \(w\) negativo (la red esta haciendo predicciones incorrectas con alta confianza) y se aproxima a cero a medida que \(w\) crece. El gradiente es siempre negativo, lo que significa que el descenso por gradiente siempre empuja \(w\) hacia valores mas grandes. En \(w = 0\) (inicializacion tipica), el gradiente vale exactamente \(-0.5\), lo que corresponde al caso de incertidumbre maxima donde \(\sigma(0) = 0.5\).

5.4 Derivadas de orden superior: una ventaja exclusiva de JAX

Una capacidad que distingue a JAX de otros frameworks es la diferenciacion de orden superior mediante composicion de jax.grad. En PyTorch es posible pero requiere configuraciones especiales (create_graph=True); en JAX es natural: jax.grad(jax.grad(f)) es exactamente lo que parece. Esta capacidad es fundamental para metodos de optimizacion de segundo orden (Newton-Raphson, Gauss-Newton), regularizacion basada en la traza del Hessiano, y meta-aprendizaje (donde se diferencia a traves de un bucle de entrenamiento completo).

La segunda derivada de la perdida respecto a un peso, \(\partial^2 \mathcal{L} / \partial w^2\), mide la curvatura del paisaje de optimizacion. Una curvatura alta indica que pequenos cambios en el peso producen grandes cambios en la perdida; una curvatura baja indica una region plana. Los optimizadores de segundo orden utilizan esta informacion para dar pasos adaptados a la geometria local, a diferencia de SGD que usa el mismo tamano de paso en todas las direcciones.

Mostrar codigo
import jax
import jax.numpy as jnp
import numpy as np
import plotly.graph_objects as go
from plotly.subplots import make_subplots

def bce_ho(w):
    sigma = jax.nn.sigmoid(w)
    return -jnp.log(jnp.clip(sigma, 1e-7, 1.0 - 1e-7))

grad1_ho = jax.grad(bce_ho)
grad2_ho = jax.grad(jax.grad(bce_ho))

w_ho = np.linspace(-5, 5, 300)
loss_ho  = np.array([float(bce_ho(jnp.array(float(w))))  for w in w_ho])
g1_ho    = np.array([float(grad1_ho(jnp.array(float(w)))) for w in w_ho])
g2_ho    = np.array([float(grad2_ho(jnp.array(float(w)))) for w in w_ho])

fig = make_subplots(
    rows=1, cols=3,
    subplot_titles=["Perdida L(w)", "Primera derivada  dL/dw", "Segunda derivada  d2L/dw2"]
)

palette = ["#636EFA", "#EF553B", "#00CC96"]
for col, (vals, name) in enumerate(
    zip([loss_ho, g1_ho, g2_ho], ["L(w)", "dL/dw", "d2L/dw2"]), start=1
):
    fig.add_trace(
        go.Scatter(x=w_ho, y=vals, mode="lines",
                   line=dict(color=palette[col - 1], width=2.5), name=name),
        row=1, col=col
    )
    fig.add_hline(y=0, line_dash="dot", line_color="gray", line_width=1, row=1, col=col)
    fig.add_vline(x=0, line_dash="dot", line_color="gray", line_width=1, row=1, col=col)

fig.add_annotation(
    x=0, y=float(grad2_ho(jnp.array(0.0))),
    text="Curvatura max en w=0",
    showarrow=True, arrowhead=2, ax=60, ay=-30,
    font=dict(size=10), row=1, col=3
)

fig.update_xaxes(title_text="w")
fig.update_layout(
    title=dict(text="Derivadas de orden superior con jax.grad compuesto", x=0.5, font=dict(size=14)),
    showlegend=False,
    height=380,
    margin=dict(t=70, b=50)
)
fig.show()
Figura 11: Perdida, primer derivada y segunda derivada de la entropia cruzada binaria, calculadas con jax.grad anidado. La segunda derivada (curvatura) es siempre positiva, confirmando que la perdida es convexa respecto a este peso. La curvatura maxima ocurre en w=0 donde la incertidumbre es maxima.

6 Clasificacion con JAX puro

Implementamos un MLP completo en JAX puro (sin Flax) sobre el mismo conjunto make_moons usado en la seccion de PyTorch. Esto permite una comparacion directa del codigo y los resultados. La arquitectura es identica: \([2 \to 32 \to 32 \to 1]\).

En JAX puro, la inicializacion de parametros produce un diccionario Python con listas de matrices de pesos y vectores de bias. La clave de diseno es que este diccionario se pasa explicitamente a cada funcion que lo necesite. El optimizador de Optax (DeepMind et al. 2020) tambien es funcional: optimizer.init(params) devuelve el estado inicial del optimizador (momentos de Adam, contadores de paso), y optimizer.update(grads, opt_state, params) devuelve las actualizaciones y el nuevo estado, sin modificar nada in situ.

Mostrar codigo
import jax
import jax.numpy as jnp
import optax
import numpy as np
from sklearn.datasets import make_moons
import plotly.graph_objects as go
from plotly.subplots import make_subplots

# ---- Data (same split as PyTorch experiment) ----
X_np, y_np = make_moons(n_samples=500, noise=0.25, random_state=42)
n_train = int(0.8 * len(X_np))
idx_jax = np.random.RandomState(42).permutation(len(X_np))
tr_idx_j, vl_idx_j = idx_jax[:n_train], idx_jax[n_train:]

X_tr_j = jnp.array(X_np[tr_idx_j], dtype=jnp.float32)
y_tr_j = jnp.array(y_np[tr_idx_j], dtype=jnp.float32)
X_vl_j = jnp.array(X_np[vl_idx_j], dtype=jnp.float32)
y_vl_j = jnp.array(y_np[vl_idx_j], dtype=jnp.float32)

# ---- Parameter initialization ----
def init_params(layer_dims, key):
    params = {"W": [], "b": []}
    for i in range(len(layer_dims) - 1):
        key, subkey = jax.random.split(key)
        fan_in = layer_dims[i]
        W = jax.random.normal(subkey, (layer_dims[i], layer_dims[i+1])) * jnp.sqrt(2.0 / fan_in)
        b = jnp.zeros(layer_dims[i+1])
        params["W"].append(W)
        params["b"].append(b)
    return params

# ---- Forward pass ----
def forward_jax(params, x):
    for i in range(len(params["W"]) - 1):
        x = x @ params["W"][i] + params["b"][i]
        x = jax.nn.relu(x)
    x = x @ params["W"][-1] + params["b"][-1]
    return jax.nn.sigmoid(x).squeeze(-1)

# ---- Loss function ----
def bce_loss_jax(params, x, y):
    yhat = forward_jax(params, x)
    eps = 1e-7
    yhat = jnp.clip(yhat, eps, 1 - eps)
    return -jnp.mean(y * jnp.log(yhat) + (1 - y) * jnp.log(1 - yhat))

# ---- Optimizer ----
optimizer_jax = optax.adam(1e-2)
key = jax.random.PRNGKey(42)
params_jax = init_params([2, 32, 32, 1], key)
opt_state_jax = optimizer_jax.init(params_jax)

# ---- JIT-compiled training step ----
@jax.jit
def train_step_jax(params, opt_state, x, y):
    loss_val, grads = jax.value_and_grad(bce_loss_jax)(params, x, y)
    updates, new_opt_state = optimizer_jax.update(grads, opt_state, params)
    new_params = optax.apply_updates(params, updates)
    return new_params, new_opt_state, loss_val

# ---- Training loop ----
jax_tr_losses, jax_vl_losses = [], []
jax_tr_accs, jax_vl_accs = [], []

for epoch in range(300):
    params_jax, opt_state_jax, tr_loss = train_step_jax(
        params_jax, opt_state_jax, X_tr_j, y_tr_j
    )

    tr_preds = jax.device_get(forward_jax(params_jax, X_tr_j))
    tr_acc = np.mean((tr_preds >= 0.5).astype(int) == np.array(y_tr_j))

    vl_preds = jax.device_get(forward_jax(params_jax, X_vl_j))
    vl_loss = float(bce_loss_jax(params_jax, X_vl_j, y_vl_j))
    vl_acc = np.mean((vl_preds >= 0.5).astype(int) == np.array(y_vl_j))

    jax_tr_losses.append(float(tr_loss))
    jax_vl_losses.append(vl_loss)
    jax_tr_accs.append(tr_acc)
    jax_vl_accs.append(vl_acc)

epochs_j = list(range(1, 301))

# ---- Figure ----
fig = make_subplots(rows=1, cols=2,
    subplot_titles=["Perdida (BCE)", "Exactitud"])

fig.add_trace(go.Scatter(x=epochs_j, y=jax_tr_losses, name="Train loss",
    line=dict(color="mediumseagreen", width=2)), row=1, col=1)
fig.add_trace(go.Scatter(x=epochs_j, y=jax_vl_losses, name="Val loss",
    line=dict(color="coral", width=2, dash="dash")), row=1, col=1)

fig.add_trace(go.Scatter(x=epochs_j, y=jax_tr_accs, name="Train acc",
    line=dict(color="mediumseagreen", width=2), showlegend=False), row=1, col=2)
fig.add_trace(go.Scatter(x=epochs_j, y=jax_vl_accs, name="Val acc",
    line=dict(color="coral", width=2, dash="dash"), showlegend=False), row=1, col=2)

fig.update_xaxes(title_text="Epoch")
fig.update_yaxes(title_text="BCE Loss", row=1, col=1)
fig.update_yaxes(title_text="Accuracy", row=1, col=2)

fig.update_layout(
    title=dict(text="Entrenamiento del MLP en JAX puro (make_moons)", x=0.5, font=dict(size=14)),
    height=400,
    legend=dict(orientation="h", yanchor="bottom", y=-0.25, xanchor="center", x=0.5)
)
fig.show()
Figura 12: Curvas de perdida y exactitud durante el entrenamiento del MLP en JAX puro sobre make_moons.

Las curvas de entrenamiento en JAX son practicamente identicas a las de PyTorch bajo las mismas condiciones: misma arquitectura, mismo optimizador (Adam con lr=0.01), mismo conjunto de datos y misma particion. Esto ilustra que los dos frameworks, a pesar de sus filosofias radicalmente distintas, convergen a resultados equivalentes cuando los hiperparametros son los mismos. La diferencia observable es que la primera llamada a train_step_jax es mas lenta (trazado y compilacion JIT), pero las 299 iteraciones siguientes son sustancialmente mas rapidas que su equivalente en PyTorch eager.

La frontera de decision aprendida por la red JAX deberia ser visualmente indistinguible de la del MLP de PyTorch, confirmando que el aprendizaje de representaciones no esta ligado al framework sino a la arquitectura y los datos.

Mostrar codigo
import plotly.graph_objects as go
import numpy as np
import jax
import jax.numpy as jnp

# Build grid
x1r = np.linspace(X_np[:, 0].min() - 0.5, X_np[:, 0].max() + 0.5, 200)
x2r = np.linspace(X_np[:, 1].min() - 0.5, X_np[:, 1].max() + 0.5, 200)
xx1j, xx2j = np.meshgrid(x1r, x2r)
grid_j = jnp.array(np.c_[xx1j.ravel(), xx2j.ravel()], dtype=jnp.float32)

probs_j = jax.device_get(forward_jax(params_jax, grid_j)).reshape(200, 200)
final_vl_acc_j = jax_vl_accs[-1]

fig = go.Figure()

fig.add_trace(go.Contour(
    x=x1r, y=x2r, z=probs_j.tolist(),
    colorscale="Teal",
    opacity=0.75,
    showscale=True,
    colorbar=dict(title="P(clase=1)"),
    contours=dict(start=0, end=1, size=0.05, showlines=False)
))

fig.add_trace(go.Contour(
    x=x1r, y=x2r, z=probs_j.tolist(),
    showscale=False,
    contours=dict(start=0.5, end=0.5, size=0, coloring="lines"),
    line=dict(color="white", width=3),
    name="Frontera (p=0.5)"
))

fig.add_trace(go.Scatter(
    x=X_np[tr_idx_j, 0], y=X_np[tr_idx_j, 1],
    mode="markers",
    marker=dict(
        color=y_np[tr_idx_j].tolist(),
        colorscale=[[0, "#636EFA"], [1, "#EF553B"]],
        size=6, line=dict(color="white", width=0.5)
    ),
    name="Puntos de entrenamiento"
))

fig.add_annotation(
    x=0.02, y=0.97, xref="paper", yref="paper",
    text=f"Val accuracy: {final_vl_acc_j:.3f}",
    showarrow=False,
    font=dict(size=12, color="white"),
    bgcolor="rgba(0,0,0,0.5)",
    bordercolor="white", borderwidth=1
)

fig.update_layout(
    title=dict(text="Frontera de decision del MLP entrenado con JAX", x=0.5, font=dict(size=14)),
    xaxis_title="x1", yaxis_title="x2",
    height=480,
    legend=dict(orientation="h", yanchor="bottom", y=-0.15, xanchor="center", x=0.5)
)
fig.show()
Figura 13: Frontera de decision del MLP entrenado con JAX puro sobre make_moons. La isocurva blanca corresponde a probabilidad 0.5.

La frontera aprendida por JAX captura la misma estructura curvilinea que la de PyTorch. Las pequenas diferencias visuales, si las hay, se deben a diferencias en la inicializacion aleatoria (JAX usa claves PRNG explicitas mientras que PyTorch usa un generador global) y posiblemente a diferencias en el orden de las operaciones de punto flotante producidas por XLA versus la cadena de herramientas de PyTorch. En ninguno de los dos casos estas diferencias son significativas para la calidad del clasificador final.

7 CNN con Flax

7.1 flax.linen.Module

Flax (Heek et al. 2023) ofrece una API de modulos de nivel superior sobre JAX que es superficialmente similar a nn.Module de PyTorch, pero mantiene el paradigma funcional en su nucleo. La diferencia clave es que una instancia de flax.linen.Module no almacena los parametros en su interior: es simplemente una descripcion del computo. Los parametros se obtienen llamando a model.init(key, x) (que devuelve un diccionario de parametros inicializados) y se pasan explicitamente a model.apply(params, x) en cada evaluacion.

El decorador @nn.compact permite definir el computo en un unico metodo __call__, declarando las capas inline en lugar de en __init__. Flax nombra automaticamente cada sublayer segun su posicion en el codigo, lo que produce un diccionario de parametros con claves como "Conv_0", "Dense_1", etc.

Una diferencia crucial respecto a PyTorch es la convencion de forma para tensores de imagen. Mientras que PyTorch usa channels-first (N, C, H, W), JAX/Flax usa channels-last (N, H, W, C). Para el conjunto digits de \(8 \times 8\) pixeles con un canal, la forma correcta en Flax es (N, 8, 8, 1) en lugar de (N, 1, 8, 8). Esta distincion es facil de olvidar al portar codigo entre frameworks y causa errores de forma confusos.

Mostrar codigo
import jax
import jax.numpy as jnp
import optax
import numpy as np
import flax.linen as fnn
from sklearn.datasets import load_digits
import plotly.graph_objects as go
from plotly.subplots import make_subplots

# ---- Data: channels-last (N, H, W, C) for Flax ----
digits = load_digits()
X_dig_f = digits.data.astype(np.float32) / 16.0
y_dig_f = digits.target.astype(np.int32)

rng_f = np.random.RandomState(42)
idx_f = rng_f.permutation(len(X_dig_f))
n_tr_f = int(0.8 * len(X_dig_f))
tr_f, vl_f = idx_f[:n_tr_f], idx_f[n_tr_f:]

# Channels-last: (N, 8, 8, 1)
X_tr_f = jnp.array(X_dig_f[tr_f].reshape(-1, 8, 8, 1))
y_tr_f = jnp.array(y_dig_f[tr_f])
X_vl_f = jnp.array(X_dig_f[vl_f].reshape(-1, 8, 8, 1))
y_vl_f = jnp.array(y_dig_f[vl_f])

# ---- Flax CNN definition ----
class FlaxCNN(fnn.Module):
    @fnn.compact
    def __call__(self, x):
        x = fnn.Conv(features=16, kernel_size=(3, 3), padding="SAME")(x)
        x = fnn.relu(x)
        x = fnn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))
        x = fnn.Conv(features=32, kernel_size=(3, 3), padding="SAME")(x)
        x = fnn.relu(x)
        x = fnn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))
        x = x.reshape((x.shape[0], -1))   # flatten
        x = fnn.Dense(64)(x)
        x = fnn.relu(x)
        x = fnn.Dense(10)(x)
        return x

# ---- Initialize model ----
key_f = jax.random.PRNGKey(42)
model_flax = FlaxCNN()
params_flax = model_flax.init(key_f, X_tr_f[:1])

# ---- Optimizer ----
optimizer_flax = optax.adam(1e-3)
opt_state_flax = optimizer_flax.init(params_flax)

# ---- JIT-compiled training step ----
@jax.jit
def train_step_flax(params, opt_state, x, y):
    def loss_fn(params):
        logits = model_flax.apply(params, x)
        return optax.softmax_cross_entropy_with_integer_labels(logits, y).mean()
    loss_val, grads = jax.value_and_grad(loss_fn)(params)
    updates, new_opt_state = optimizer_flax.update(grads, opt_state, params)
    new_params = optax.apply_updates(params, updates)
    return new_params, new_opt_state, loss_val

# ---- Training loop ----
flax_tr_losses, flax_vl_losses = [], []
flax_tr_accs, flax_vl_accs = [], []

for epoch in range(50):
    params_flax, opt_state_flax, tr_loss_f = train_step_flax(
        params_flax, opt_state_flax, X_tr_f, y_tr_f
    )

    # Validation metrics (use jax.device_get before numpy operations)
    tr_logits_f = jax.device_get(model_flax.apply(params_flax, X_tr_f))
    tr_acc_f = np.mean(tr_logits_f.argmax(axis=1) == np.array(y_tr_f))

    vl_logits_f = jax.device_get(model_flax.apply(params_flax, X_vl_f))
    vl_loss_f = float(
        optax.softmax_cross_entropy_with_integer_labels(
            jnp.array(vl_logits_f), y_vl_f
        ).mean()
    )
    vl_acc_f = np.mean(vl_logits_f.argmax(axis=1) == np.array(y_vl_f))

    flax_tr_losses.append(float(tr_loss_f))
    flax_vl_losses.append(vl_loss_f)
    flax_tr_accs.append(tr_acc_f)
    flax_vl_accs.append(vl_acc_f)

epochs_f = list(range(1, 51))

# ---- Figure ----
fig = make_subplots(rows=1, cols=2,
    subplot_titles=["Perdida (Cross-Entropy)", "Exactitud"])

fig.add_trace(go.Scatter(x=epochs_f, y=flax_tr_losses, name="Train loss",
    line=dict(color="darkorchid", width=2)), row=1, col=1)
fig.add_trace(go.Scatter(x=epochs_f, y=flax_vl_losses, name="Val loss",
    line=dict(color="goldenrod", width=2, dash="dash")), row=1, col=1)

fig.add_trace(go.Scatter(x=epochs_f, y=flax_tr_accs, name="Train acc",
    line=dict(color="darkorchid", width=2), showlegend=False), row=1, col=2)
fig.add_trace(go.Scatter(x=epochs_f, y=flax_vl_accs, name="Val acc",
    line=dict(color="goldenrod", width=2, dash="dash"), showlegend=False), row=1, col=2)

fig.update_xaxes(title_text="Epoch")
fig.update_yaxes(title_text="Cross-Entropy Loss", row=1, col=1)
fig.update_yaxes(title_text="Accuracy", row=1, col=2)

fig.update_layout(
    title=dict(text="Entrenamiento de la CNN con Flax/JAX (digits)", x=0.5, font=dict(size=14)),
    height=400,
    legend=dict(orientation="h", yanchor="bottom", y=-0.25, xanchor="center", x=0.5)
)
fig.show()
Figura 14: Curvas de perdida y exactitud durante el entrenamiento de la CNN con Flax/JAX sobre el conjunto digits.

La CNN de Flax alcanza una exactitud de validacion comparable a la version de PyTorch en el mismo numero de epochs, lo que confirma que la capacidad del modelo es equivalente. Una diferencia observable es que usamos avg_pool en lugar de MaxPool2d: el average pooling es ligeramente menos agresivo en la seleccion de caracteristicas pero igualmente efectivo para reducir la dimension espacial. En redes mas profundas, la eleccion entre max y average pooling puede tener un impacto measurable en el rendimiento.

El flujo de trabajo con Flax merece un comentario adicional. La llamada model_flax.apply(params_flax, x) es una funcion pura: produce la misma salida para los mismos parametros y la misma entrada, sin ninguna dependencia de estado oculto. Esto facilita la verificacion formal, el testing y la reproduccion de resultados. El precio es la verbosidad: hay que pasar params_flax explicitamente en cada llamada, lo que contrasta con el model(x) de PyTorch donde los parametros estan implicitamente encapsulados en el objeto.

8 Comparativa PyTorch vs JAX

8.1 Filosofia de diseno

La diferencia mas profunda entre PyTorch y JAX no es tecnica sino filosofica. PyTorch adopta el paradigma imperativo orientado a objetos: un modelo es un objeto que encapsula sus parametros, y el ciclo de entrenamiento es un bucle Python ordinario donde el programador invoca metodos sobre objetos. Esta filosofia maximiza la ergonomia y la facilidad de debugging: se puede insertar un print o un punto de ruptura en cualquier lugar del codigo y el estado del sistema es inmediatamente inspeccionable.

JAX adopta el paradigma funcional: los modelos son funciones, los parametros son datos externos (pytrees) y las transformaciones (grad, jit, vmap) son operaciones de alto orden sobre funciones. Esta filosofia maximiza la componibilidad: cualquier transformacion se puede aplicar a cualquier funcion pura, incluyendo combinaciones como jax.jit(jax.grad(loss)) o jax.vmap(jax.grad(loss)) (gradiente por elemento para un lote de datos). El precio es la curva de aprendizaje mas empinada y la mayor dificultad para debuggear codigo dentro de una region jit-compilada (los efectos secundarios como print no funcionan dentro de jax.jit).

8.2 Flujo de trabajo tipico

La siguiente tabla resume las diferencias practicas entre los dos frameworks para las operaciones mas comunes en el desarrollo de redes neuronales.

Aspecto PyTorch JAX / Flax
Definicion del modelo nn.Module con estado flax.linen.Module + params externos
Forward pass model(x) model.apply(params, x)
Gradientes loss.backward() automatico jax.grad(loss_fn)(params) explicito
Optimizador torch.optim.Adam optax.adam
Compilacion Eager por defecto, torch.compile opcional jax.jit explicito, XLA siempre
Debugging Facil (eager, pythonic) Mas complejo (trazado estatico en jit)
Ecosistema Vasto (torchvision, HuggingFace) Creciente (Flax, Optax, Haiku, Equinox)

Ninguno de los dos frameworks es objetivamente superior. PyTorch es la mejor opcion por defecto para la mayoria de los practicantes: su ecosistema es mucho mas amplio (modelos preentrenados, datasets, herramientas de deployment), su curva de aprendizaje es mas suave y su comunidad es enorme. JAX es la opcion preferida cuando se necesita diferenciacion de orden superior, vectorizacion automatica eficiente o integracion nativa con TPUs de Google Cloud. En los ultimos anyos, proyectos como Equinox han emergido para ofrecer en JAX una experiencia de usuario mas cercana a PyTorch, mientras que torch.compile ha acercado PyTorch a la velocidad de JAX para los casos de uso comunes.

8.3 Comparativa cuantitativa

La siguiente figura presenta una comparacion subjetiva pero estructurada de las dos opciones principales en seis dimensiones relevantes para el practicante de aprendizaje profundo.

Mostrar codigo
import plotly.graph_objects as go

metrics = [
    "Curva de aprendizaje",
    "Velocidad (JIT)",
    "Flexibilidad de arq.",
    "Tamano del ecosistema",
    "Facilidad de debugging",
    "Soporte para investigacion"
]

pytorch_scores = [4, 3, 5, 5, 5, 5]
jax_scores     = [2, 5, 5, 3, 2, 5]

fig = go.Figure()

fig.add_trace(go.Bar(
    name="PyTorch",
    x=metrics,
    y=pytorch_scores,
    marker_color="steelblue",
    text=pytorch_scores,
    textposition="outside"
))

fig.add_trace(go.Bar(
    name="JAX / Flax",
    x=metrics,
    y=jax_scores,
    marker_color="mediumseagreen",
    text=jax_scores,
    textposition="outside"
))

fig.update_layout(
    barmode="group",
    title=dict(
        text="PyTorch vs JAX: comparativa de caracteristicas",
        x=0.5, font=dict(size=14)
    ),
    yaxis=dict(title="Puntuacion (1-5)", range=[0, 6.5]),
    xaxis=dict(tickfont=dict(size=11)),
    legend=dict(orientation="h", yanchor="bottom", y=-0.25, xanchor="center", x=0.5),
    height=450,
    margin=dict(t=60, b=100)
)

fig.show()
Figura 15: Comparativa de PyTorch versus JAX en seis dimensiones de interes practico (escala 1-5).

La comparativa revela los patrones complementarios de los dos frameworks. PyTorch destaca en facilidad de debugging, tamano del ecosistema y accesibilidad para nuevos usuarios (curva de aprendizaje mas suave). JAX destaca en velocidad de ejecucion gracias a la compilacion XLA, que produce codigo significativamente mas rapido en hardware especializado. Ambos frameworks son equivalentes en flexibilidad de arquitecturas (cualquier computo diferenciable es expresable en ambos) y soporte para investigacion (las dos opciones son ampliamente usadas en publicaciones de primer nivel). La puntuacion baja de JAX en “Curva de aprendizaje” no refleja una deficiencia del framework sino la naturaleza del paradigma funcional, que requiere un cambio conceptual sustancial para programadores acostumbrados al estilo imperativo.

8.4 Convergencia directa: PyTorch vs JAX

La prueba definitiva de que ambos frameworks son equivalentes en capacidad expresiva es comparar sus curvas de convergencia sobre el mismo problema, con la misma arquitectura y los mismos hiperparametros. La siguiente figura superpone la exactitud de validacion de los dos modelos MLP entrenados a lo largo de este capitulo — el de PyTorch (azul solido) y el de JAX puro (verde discontinuo) — sobre make_moons con arquitectura \([2 \to 32 \to 32 \to 1]\), optimizador Adam con \(\eta = 0.01\) y 300 epocas.

Mostrar codigo
import plotly.graph_objects as go

fig = go.Figure()

fig.add_trace(go.Scatter(
    x=list(range(1, len(val_accs) + 1)),
    y=val_accs,
    name="PyTorch",
    line=dict(color="steelblue", width=2.5)
))

fig.add_trace(go.Scatter(
    x=list(range(1, len(jax_vl_accs) + 1)),
    y=jax_vl_accs,
    name="JAX puro",
    line=dict(color="mediumseagreen", width=2.5, dash="dash")
))

pt_final  = val_accs[-1]
jax_final = jax_vl_accs[-1]

fig.add_annotation(
    x=len(val_accs) - 1, y=pt_final,
    text=f"PyTorch final: {pt_final:.3f}",
    showarrow=True, arrowhead=2, ax=-70, ay=-30,
    font=dict(color="steelblue", size=11)
)
fig.add_annotation(
    x=len(jax_vl_accs) - 1, y=jax_final,
    text=f"JAX final: {jax_final:.3f}",
    showarrow=True, arrowhead=2, ax=-70, ay=30,
    font=dict(color="mediumseagreen", size=11)
)

fig.update_layout(
    title=dict(
        text="PyTorch vs JAX: exactitud de validacion (make_moons, misma arquitectura)",
        x=0.5, font=dict(size=14)
    ),
    xaxis_title="Epoch",
    yaxis_title="Exactitud de validacion",
    yaxis=dict(range=[0.5, 1.02]),
    legend=dict(orientation="h", yanchor="bottom", y=-0.22, xanchor="center", x=0.5),
    height=410,
    margin=dict(t=60, b=80)
)
fig.show()
Figura 16: Comparacion directa de exactitud de validacion entre el MLP de PyTorch y el MLP de JAX puro sobre make_moons. Ambos modelos convergen al mismo nivel de rendimiento, confirmando que la eleccion del framework no afecta la capacidad del modelo sino unicamente el paradigma de programacion y las posibilidades de optimizacion del sistema.

9 Resumen

Este capitulo presento la implementacion practica de redes neuronales utilizando tres frameworks: scikit-learn (Pedregosa et al. 2011), PyTorch (Paszke et al. 2019) y JAX (Bradbury et al. 2018) con Flax (Heek et al. 2023) y Optax (DeepMind et al. 2020).

Los puntos clave son los siguientes. En PyTorch, el patron central es: definir una subclase de nn.Module con forward, instanciar un optimizador con model.parameters(), y escribir el bucle de entrenamiento con zero_grad() / backward() / step(). En JAX, el patron es: los parametros son pytrees externos, las funciones son puras, jax.value_and_grad computa la perdida y los gradientes en un solo paso, y @jax.jit compila el paso de entrenamiento para ejecucion eficiente. Flax provee un nivel de abstraccion sobre JAX que simplifica la definicion de arquitecturas manteniendo la disciplina funcional.

Los resultados experimentales confirman que ambos frameworks producen modelos equivalentes bajo las mismas condiciones: el MLP sobre make_moons y la CNN sobre digits alcanzan exactitudes de validacion similares con las mismas arquitecturas y optimizadores. La eleccion entre frameworks es principalmente una decision de ergonomia y ecosistema para la mayoria de los casos de uso, y una decision de rendimiento y expresividad funcional para casos avanzados.

Referencias