Introduzione a Tensorflow per Java

1. Panoramica

TensorFlow è una libreria open source per la programmazione del flusso di dati . Questo è stato originariamente sviluppato da Google ed è disponibile per un'ampia gamma di piattaforme. Sebbene TensorFlow possa funzionare su un singolo core, può facilmente trarre vantaggio dalla disponibilità di più CPU, GPU o TPU .

In questo tutorial, esamineremo le basi di TensorFlow e come utilizzarlo in Java. Tieni presente che l'API Java di TensorFlow è un'API sperimentale e quindi non è coperta da alcuna garanzia di stabilità. Tratteremo più avanti nel tutorial i possibili casi d'uso per l'utilizzo dell'API Java di TensorFlow.

2. Nozioni di base

Il calcolo di TensorFlow ruota fondamentalmente attorno a due concetti fondamentali: Grafico e Sessione . Esaminiamoli rapidamente per ottenere lo sfondo necessario per passare attraverso il resto del tutorial.

2.1. Grafico TensorFlow

Per cominciare, comprendiamo gli elementi costitutivi fondamentali dei programmi TensorFlow. I calcoli sono rappresentati come grafici in TensorFlow . Un grafico è tipicamente un grafico aciclico diretto di operazioni e dati, ad esempio:

L'immagine sopra rappresenta il grafico computazionale per la seguente equazione:

f(x, y) = z = a*x + b*y

Un grafico computazionale TensorFlow è costituito da due elementi:

  1. Tensor: sono le unità di dati principali in TensorFlow. Sono rappresentati come i bordi in un grafico computazionale, che rappresenta il flusso di dati attraverso il grafico. Un tensore può avere una forma con qualsiasi numero di dimensioni. Il numero di dimensioni in un tensore viene solitamente indicato come il suo rango. Quindi uno scalare è un tensore di rango 0, un vettore è un tensore di rango 1, una matrice è un tensore di rango 2 e così via.
  2. Operazione: questi sono i nodi in un grafo computazionale. Si riferiscono a un'ampia varietà di calcoli che possono essere eseguiti sui tensori che alimentano l'operazione. Spesso risultano anche in tensori che emanano dall'operazione in un grafo computazionale.

2.2. TensorFlow Session

Ora, un grafico TensorFlow è un semplice schema del calcolo che in realtà non contiene valori. Tale grafico deve essere eseguito all'interno di quella che viene chiamata sessione TensorFlow affinché i tensori nel grafico vengano valutati . La sessione può richiedere un gruppo di tensori per essere valutata da un grafico come parametri di input. Quindi scorre all'indietro nel grafico ed esegue tutti i nodi necessari per valutare quei tensori.

Con questa conoscenza, ora siamo pronti per prenderlo e applicarlo all'API Java!

3. Installazione di Maven

Configureremo un rapido progetto Maven per creare ed eseguire un grafico TensorFlow in Java. Abbiamo solo bisogno della dipendenza tensorflow :

 org.tensorflow tensorflow 1.12.0 

4. Creazione del grafico

Proviamo ora a costruire il grafico di cui abbiamo discusso nella sezione precedente utilizzando l'API Java di TensorFlow. Più precisamente, per questo tutorial utilizzeremo l'API Java di TensorFlow per risolvere la funzione rappresentata dalla seguente equazione:

z = 3*x + 2*y

Il primo passo è dichiarare e inizializzare un grafico:

Graph graph = new Graph()

Ora dobbiamo definire tutte le operazioni richieste. Ricorda che le operazioni in TensorFlow consumano e producono zero o più tensori . Inoltre, ogni nodo nel grafico è un'operazione che include costanti e segnaposto. Può sembrare controintuitivo, ma sopportalo per un momento!

La classe Graph ha una funzione generica chiamata opBuilder () per costruire qualsiasi tipo di operazione su TensorFlow.

4.1. Definizione di costanti

Per cominciare, definiamo le operazioni costanti nel nostro grafico sopra. Nota che un'operazione costante avrà bisogno di un tensore per il suo valore :

Operation a = graph.opBuilder("Const", "a") .setAttr("dtype", DataType.fromClass(Double.class)) .setAttr("value", Tensor.create(3.0, Double.class)) .build(); Operation b = graph.opBuilder("Const", "b") .setAttr("dtype", DataType.fromClass(Double.class)) .setAttr("value", Tensor.create(2.0, Double.class)) .build();

Qui, abbiamo definito un funzionamento di tipo continuo, alimentando nel Tensor con doppie valori 2.0 e 3.0. All'inizio può sembrare un po 'opprimente, ma per ora è proprio così nell'API Java. Questi costrutti sono molto più concisi in linguaggi come Python.

4.2. Definizione di segnaposto

Sebbene sia necessario fornire valori alle nostre costanti, i segnaposto non hanno bisogno di un valore al momento della definizione . I valori dei segnaposto devono essere forniti quando il grafico viene eseguito all'interno di una sessione. Esamineremo quella parte più avanti nel tutorial.

Per ora, vediamo come possiamo definire i nostri segnaposto:

Operation x = graph.opBuilder("Placeholder", "x") .setAttr("dtype", DataType.fromClass(Double.class)) .build(); Operation y = graph.opBuilder("Placeholder", "y") .setAttr("dtype", DataType.fromClass(Double.class)) .build();

Tieni presente che non dovevamo fornire alcun valore per i nostri segnaposto. Questi valori verranno forniti come tensori durante l'esecuzione.

4.3. Definizione di funzioni

Infine, dobbiamo definire le operazioni matematiche della nostra equazione, vale a dire moltiplicazione e addizione per ottenere il risultato.

Anche in questo caso non sono altro che le operazioni in TensorFlow e Graph.opBuilder () che sono di nuovo utili:

Operation ax = graph.opBuilder("Mul", "ax") .addInput(a.output(0)) .addInput(x.output(0)) .build(); Operation by = graph.opBuilder("Mul", "by") .addInput(b.output(0)) .addInput(y.output(0)) .build(); Operation z = graph.opBuilder("Add", "z") .addInput(ax.output(0)) .addInput(by.output(0)) .build();

Qui abbiamo definito l' operazione , due per moltiplicare i nostri input e l'ultima per sommare i risultati intermedi. Nota che le operazioni qui ricevono tensori che non sono altro che l'output delle nostre operazioni precedenti.

Please note that we are getting the output Tensor from the Operation using index ‘0'. As we discussed earlier, an Operation can result in one or more Tensor and hence while retrieving a handle for it, we need to mention the index. Since we know that our operations are only returning one Tensor, ‘0' works just fine!

5. Visualizing the Graph

It is difficult to keep a tab on the graph as it grows in size. This makes it important to visualize it in some way. We can always create a hand drawing like the small graph we created previously but it is not practical for larger graphs. TensorFlow provides a utility called TensorBoard to facilitate this.

Unfortunately, Java API doesn't have the capability to generate an event file which is consumed by TensorBoard. But using APIs in Python we can generate an event file like:

writer = tf.summary.FileWriter('.') ...... writer.add_graph(tf.get_default_graph()) writer.flush()

Please do not bother if this does not make sense in the context of Java, this has been added here just for the sake of completeness and not necessary to continue rest of the tutorial.

We can now load and visualize the event file in TensorBoard like:

tensorboard --logdir .

TensorBoard comes as part of TensorFlow installation.

Note the similarity between this and the manually drawn graph earlier!

6. Working with Session

We have now created a computational graph for our simple equation in TensorFlow Java API. But how do we run it? Before addressing that, let's see what is the state of Graph we have just created at this point. If we try to print the output of our final Operation “z”:

System.out.println(z.output(0));

This will result in something like:


    

This isn't what we expected! But if we recall what we discussed earlier, this actually makes sense. The Graph we have just defined has not been run yet, so the tensors therein do not actually hold any actual value. The output above just says that this will be a Tensor of type Double.

Let's now define a Session to run our Graph:

Session sess = new Session(graph)

Finally, we are now ready to run our Graph and get the output we have been expecting:

Tensor tensor = sess.runner().fetch("z") .feed("x", Tensor.create(3.0, Double.class)) .feed("y", Tensor.create(6.0, Double.class)) .run().get(0).expect(Double.class); System.out.println(tensor.doubleValue());

So what are we doing here? It should be fairly intuitive:

  • Get a Runner from the Session
  • Define the Operation to fetch by its name “z”
  • Feed in tensors for our placeholders “x” and “y”
  • Run the Graph in the Session

And now we see the scalar output:

21.0

This is what we expected, isn't it!

7. The Use Case for Java API

At this point, TensorFlow may sound like overkill for performing basic operations. But, of course, TensorFlow is meant to run graphs much much larger than this.

Additionally, the tensors it deals with in real-world models are much larger in size and rank. These are the actual machine learning models where TensorFlow finds its real use.

It's not difficult to see that working with the core API in TensorFlow can become very cumbersome as the size of the graph increases. To this end, TensorFlow provides high-level APIs like Keras to work with complex models. Unfortunately, there is little to no official support for Keras on Java just yet.

However, we can use Python to define and train complex models either directly in TensorFlow or using high-level APIs like Keras. Subsequently, we can export a trained model and use that in Java using the TensorFlow Java API.

Now, why would we want to do something like that? This is particularly useful for situations where we want to use machine learning enabled features in existing clients running on Java. For instance, recommending caption for user images on an Android device. Nevertheless, there are several instances where we are interested in the output of a machine learning model but do not necessarily want to create and train that model in Java.

This is where TensorFlow Java API finds the bulk of its use. We'll go through how this can be achieved in the next section.

8. Using Saved Models

We'll now understand how we can save a model in TensorFlow to the file system and load that back possibly in a completely different language and platform. TensorFlow provides APIs to generate model files in a language and platform neutral structure called Protocol Buffer.

8.1. Saving Models to the File System

We'll begin by defining the same graph we created earlier in Python and saving that to the file system.

Let's see we can do this in Python:

import tensorflow as tf graph = tf.Graph() builder = tf.saved_model.builder.SavedModelBuilder('./model') with graph.as_default(): a = tf.constant(2, name="a") b = tf.constant(3, name="b") x = tf.placeholder(tf.int32, name="x") y = tf.placeholder(tf.int32, name="y") z = tf.math.add(a*x, b*y, name="z") sess = tf.Session() sess.run(z, feed_dict = {x: 2, y: 3}) builder.add_meta_graph_and_variables(sess, [tf.saved_model.tag_constants.SERVING]) builder.save()

As the focus of this tutorial in Java, let's not pay much attention to the details of this code in Python, except for the fact that it generates a file called “saved_model.pb”. Do note in passing the brevity in defining a similar graph compared to Java!

8.2. Loading Models from the File System

We'll now load “saved_model.pb” into Java. Java TensorFlow API has SavedModelBundle to work with saved models:

SavedModelBundle model = SavedModelBundle.load("./model", "serve"); Tensor tensor = model.session().runner().fetch("z") .feed("x", Tensor.create(3, Integer.class)) .feed("y", Tensor.create(3, Integer.class)) .run().get(0).expect(Integer.class); System.out.println(tensor.intValue());

It should by now be fairly intuitive to understand what the above code is doing. It simply loads the model graph from the protocol buffer and makes available the session therein. From there onward, we can pretty much do anything with this graph as we would have done for a locally-defined graph.

9. Conclusion

To sum up, in this tutorial we went through the basic concepts related to the TensorFlow computational graph. We saw how to use the TensorFlow Java API to create and run such a graph. Then, we talked about the use cases for the Java API with respect to TensorFlow.

In the process, we also understood how to visualize the graph using TensorBoard, and save and reload a model using Protocol Buffer.

Come sempre, il codice per gli esempi è disponibile su GitHub.