Discesa gradiente in Java

1. Introduzione

In questo tutorial, impareremo l'algoritmo Gradient Descent. Implementeremo l'algoritmo in Java e lo illustreremo passo dopo passo.

2. Che cos'è la discesa del gradiente?

Gradient Descent è un algoritmo di ottimizzazione utilizzato per trovare un minimo locale di una data funzione. È ampiamente utilizzato all'interno di algoritmi di apprendimento automatico di alto livello per ridurre al minimo le funzioni di perdita.

Gradiente è un'altra parola per pendenza e discesa significa scendere. Come suggerisce il nome, Gradient Descent scende lungo la pendenza di una funzione fino a raggiungere la fine.

3. Proprietà della discesa del gradiente

Gradient Descent trova un minimo locale, che può essere diverso dal minimo globale. Il punto locale di partenza è dato come parametro all'algoritmo.

È un algoritmo iterativo e in ogni passaggio cerca di spostarsi lungo il pendio e avvicinarsi al minimo locale.

In pratica, l'algoritmo sta tornando indietro . Illustreremo e implementeremo il backtracking Gradient Descent in questo tutorial.

4. Illustrazione dettagliata

Gradient Descent necessita di una funzione e di un punto di partenza come input. Definiamo e tracciamo una funzione:

Possiamo iniziare in qualsiasi punto desiderato. Cominciamo da x = 1:

Nella prima fase, Gradient Descent scende lungo il pendio con una dimensione del gradino predefinita:

Successivamente, va oltre con la stessa dimensione del passo. Tuttavia, questa volta finisce con un y maggiore dell'ultimo passaggio:

Ciò indica che l'algoritmo ha superato il minimo locale, quindi torna indietro con una dimensione del passo ridotta:

Successivamente, ogni volta che la y corrente è maggiore della y precedente , la dimensione del passo viene abbassata e negata. L'iterazione prosegue fino a ottenere la precisione desiderata.

Come possiamo vedere, Gradient Descent ha trovato un minimo locale qui, ma non è il minimo globale. Se partiamo da x = -1 invece di x = 1, verrà trovato il minimo globale.

5. Implementazione in Java

Esistono diversi modi per implementare Gradient Descent. Qui non calcoliamo la derivata della funzione per trovare la direzione della pendenza, quindi la nostra implementazione funziona anche per le funzioni non differenziabili.

Definiamo precision e stepCoefficient e diamo loro i valori iniziali:

double precision = 0.000001; double stepCoefficient = 0.1;

Nel primo passaggio, non abbiamo un y precedente per il confronto. Possiamo aumentare o diminuire il valore di x per vedere se y diminuisce o aumenta. Un passo positivo Coefficiente significa che stiamo aumentando il valore di x .

Ora eseguiamo il primo passaggio:

double previousX = initialX; double previousY = f.apply(previousX); currentX += stepCoefficient * previousY;

Nel codice precedente, f è una funzione e l' inizialeX è un doppio , entrambi forniti come input.

Un altro punto chiave da considerare è che non è garantito che Gradient Descent converga. Per evitare di rimanere bloccati nel ciclo, stabiliamo un limite al numero di iterazioni:

int iter = 100;

Successivamente, decrementeremo l' iter di uno ad ogni iterazione. Di conseguenza, usciremo dal ciclo a un massimo di 100 iterazioni.

Ora che abbiamo una X precedente , possiamo impostare il nostro ciclo:

while (previousStep > precision && iter > 0) { iter--; double currentY = f.apply(currentX); if (currentY > previousY) { stepCoefficient = -stepCoefficient/2; } previousX = currentX; currentX += stepCoefficient * previousY; previousY = currentY; previousStep = StrictMath.abs(currentX - previousX); }

In ogni iterazione, calcoliamo la nuova y e la confrontiamo con la y precedente . Se currentY è maggiore di previousY , cambiamo direzione e diminuiamo la dimensione del passo.

Il ciclo continua fino a quando la dimensione del nostro passo è inferiore alla precisione desiderata . Infine, possiamo restituire currentX come minimo locale:

return currentX;

6. Conclusione

In questo articolo, abbiamo esaminato l'algoritmo Gradient Descent con un'illustrazione dettagliata.

Abbiamo anche implementato Gradient Descent in Java. Il codice è disponibile su GitHub.