Regressione logistica in Java

1. Introduzione

La regressione logistica è uno strumento importante nella cassetta degli attrezzi per professionisti dell'apprendimento automatico (ML).

In questo tutorial, esploreremo l'idea principale alla base della regressione logistica .

Innanzitutto, iniziamo con una breve panoramica dei paradigmi e degli algoritmi ML.

2. Panoramica

Il ML ci consente di risolvere problemi che possiamo formulare in termini a misura d'uomo. Tuttavia, questo fatto può rappresentare una sfida per noi sviluppatori di software. Ci siamo abituati ad affrontare i problemi che possiamo formulare in termini informatici. Ad esempio, come esseri umani, possiamo facilmente rilevare gli oggetti su una foto o stabilire l'atmosfera di una frase. Come potremmo formulare un problema del genere per un computer?

Per trovare una soluzione, in ML esiste una fase speciale chiamata formazione . Durante questa fase, forniamo i dati di input al nostro algoritmo in modo che cerchi di trovare un insieme ottimale di parametri (i cosiddetti pesi). Più dati di input possiamo fornire all'algoritmo, più previsioni precise possiamo aspettarci da esso.

La formazione fa parte di un flusso di lavoro ML iterativo:

Cominciamo con l'acquisizione dei dati. Spesso i dati provengono da fonti diverse. Pertanto, dobbiamo renderlo dello stesso formato. Dovremmo anche controllare che il set di dati rappresenti in modo equo il dominio di studio. Se il modello non è mai stato addestrato sulle mele rosse, difficilmente può prevederlo.

Successivamente, dovremmo costruire un modello che consumerà i dati e sarà in grado di fare previsioni. In ML, non ci sono modelli predefiniti che funzionano bene in tutte le situazioni.

Durante la ricerca del modello corretto, potrebbe facilmente accadere che costruiamo un modello, lo addestriamo, vediamo le sue previsioni e scartiamo il modello perché non siamo soddisfatti delle previsioni che fa. In questo caso, dovremmo fare un passo indietro e costruire un altro modello e ripetere nuovamente il processo.

3. Paradigmi di ML

In ML, in base al tipo di dati di input che abbiamo a nostra disposizione, possiamo individuare tre paradigmi principali:

  • apprendimento supervisionato (classificazione delle immagini, riconoscimento degli oggetti, analisi del sentiment)
  • apprendimento non supervisionato (rilevamento anomalie)
  • apprendimento per rinforzo (strategie di gioco)

Il caso che descriveremo in questo tutorial appartiene all'apprendimento supervisionato.

4. Casella degli strumenti ML

In ML, c'è una serie di strumenti che possiamo applicare durante la creazione di un modello. Citiamo alcuni di loro:

  • Regressione lineare
  • Regressione logistica
  • Reti neurali
  • Supporta la macchina vettoriale
  • k-Nearest Neighbors

È possibile combinare diversi strumenti durante la creazione di un modello con un'elevata predittività. Infatti, per questo tutorial, il nostro modello utilizzerà la regressione logistica e le reti neurali.

5. Librerie ML

Anche se Java non è il linguaggio più popolare per la prototipazione di modelli ML,ha la reputazione di uno strumento affidabile per la creazione di software robusto in molte aree, incluso il ML. Pertanto, possiamo trovare librerie ML scritte in Java.

In questo contesto, possiamo menzionare la libreria standard de facto Tensorflow che ha anche una versione Java. Un altro degno di nota è una libreria di deep learning chiamata Deeplearning4j. Questo è uno strumento molto potente e lo useremo anche in questo tutorial.

6. Regressione logistica sul riconoscimento delle cifre

L'idea principale della regressione logistica è costruire un modello che preveda le etichette dei dati di input nel modo più preciso possibile.

Addestriamo il modello fino a quando la cosiddetta funzione di perdita o funzione obiettivo raggiunge un valore minimo. La funzione di perdita dipende dalle previsioni del modello effettivo e da quelle previste (le etichette dei dati di input). Il nostro obiettivo è ridurre al minimo la divergenza tra le previsioni del modello effettivo e quelle previste.

Se non siamo soddisfatti di quel valore minimo, dovremmo costruire un altro modello ed eseguire nuovamente l'addestramento.

Per vedere in azione la regressione logistica, la illustriamo sul riconoscimento delle cifre scritte a mano. Questo problema è già diventato classico. La libreria Deeplearning4j ha una serie di esempi realistici che mostrano come utilizzare la sua API. La parte relativa al codice di questo tutorial è fortemente basata su MNIST Classifier .

6.1. Dati in ingresso

Come dati di input, utilizziamo il noto database MNIST di cifre scritte a mano. Come dati di input, abbiamo immagini in scala di grigi 28 × 28 pixel. Ogni immagine ha un'etichetta naturale che è la cifra che l'immagine rappresenta:

Per stimare l'efficienza del modello che andremo a costruire, suddividiamo i dati di input in training e set di test:

DataSetIterator train = new RecordReaderDataSetIterator(...); DataSetIterator test = new RecordReaderDataSetIterator(...);

Una volta etichettate le immagini in ingresso e divise nei due set, la fase di “elaborazione dati” è terminata e si può passare alla “costruzione del modello”.

6.2. Modellismo

Come abbiamo detto, non esistono modelli che funzionino bene in ogni situazione. Tuttavia, dopo molti anni di ricerca in ML, gli scienziati hanno trovato modelli che funzionano molto bene nel riconoscimento delle cifre scritte a mano. Qui usiamo il cosiddetto modello LeNet-5.

LeNet-5 è una rete neurale costituita da una serie di strati che trasformano l'immagine di 28 × 28 pixel in un vettore a dieci dimensioni:

Il vettore di output a dieci dimensioni contiene probabilità che l'etichetta dell'immagine di input sia 0, o 1, o 2 e così via.

Ad esempio, se il vettore di output ha la seguente forma:

{0.1, 0.0, 0.3, 0.2, 0.1, 0.1, 0.0, 0.1, 0.1, 0.0}

significa che la probabilità che l'immagine in ingresso sia zero è 0,1, a uno è 0, a due è 0,3, ecc. Vediamo che la probabilità massima (0,3) corrisponde all'etichetta 3.

Let's dive into details of model building. We omit Java-specific details and concentrate on ML concepts.

We set up the model by creating a MultiLayerNetwork object:

MultiLayerNetwork model = new MultiLayerNetwork(config);

In its constructor, we should pass a MultiLayerConfiguration object. This is the very object that describes the geometry of the neural network. In order to define the network geometry, we should define every layer.

Let's show how we do this with the first and the second one:

ConvolutionLayer layer1 = new ConvolutionLayer .Builder(5, 5).nIn(channels) .stride(1, 1) .nOut(20) .activation(Activation.IDENTITY) .build(); SubsamplingLayer layer2 = new SubsamplingLayer .Builder(SubsamplingLayer.PoolingType.MAX) .kernelSize(2, 2) .stride(2, 2) .build();

We see that layers' definitions contain a considerable amount of ad-hoc parameters which impact significantly on the whole network performance. This is exactly where our ability to find a good model in the landscape of all ones becomes crucial.

Now, we are ready to construct the MultiLayerConfiguration object:

MultiLayerConfiguration config = new NeuralNetConfiguration.Builder() // preparation steps .list() .layer(layer1) .layer(layer2) // other layers and final steps .build();

that we pass to the MultiLayerNetwork constructor.

6.3. Training

The model that we constructed contains 431080 parameters or weights. We're not going to give here the exact calculation of this number, but we should be aware that just the first layer has more than 24x24x20 = 11520 weights.

The training stage is as simple as:

model.fit(train); 

Initially, the 431080 parameters have some random values, but after the training, they acquire some values that determine the model performance. We may evaluate the model's predictiveness:

Evaluation eval = model.evaluate(test); logger.info(eval.stats());

The LeNet-5 model achieves quite a high accuracy of almost 99% even in just a single training iteration (epoch). If we want to achieve higher accuracy, we should make more iterations using a plain for-loop:

for (int i = 0; i < epochs; i++) { model.fit(train); train.reset(); test.reset(); } 

6.4. Prediction

Now, as we trained the model and we are happy with its predictions on the test data, we can try the model on some absolutely new input. To this end, let's create a new class MnistPrediction in which we'll load an image from a file that we select from the filesystem:

INDArray image = new NativeImageLoader(height, width, channels).asMatrix(file); new ImagePreProcessingScaler(0, 1).transform(image);

The variable image contains our picture being reduced to 28×28 grayscale one. We can feed it to our model:

INDArray output = model.output(image);

The variable output will contain the probabilities of the image to be zero, one, two, etc.

Ora giochiamo un po 'e scriviamo una cifra 2, digitalizziamo questa immagine e alimentiamo il modello. Potremmo ottenere qualcosa del genere:

Come si vede, il componente con valore massimo 0,99 ha indice due. Significa che il modello ha riconosciuto correttamente la nostra cifra scritta a mano.

7. Conclusione

In questo tutorial, abbiamo descritto i concetti generali dell'apprendimento automatico. Abbiamo illustrato questi concetti su un esempio di regressione logistica che abbiamo applicato a un riconoscimento di cifre scritte a mano.

Come sempre, possiamo trovare gli snippet di codice corrispondenti nel nostro repository GitHub.