1/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7    http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15
16package org.tensorflow.examples;
17
18import java.io.IOException;
19import java.io.PrintStream;
20import java.nio.charset.Charset;
21import java.nio.file.Files;
22import java.nio.file.Path;
23import java.nio.file.Paths;
24import java.util.Arrays;
25import java.util.List;
26import org.tensorflow.DataType;
27import org.tensorflow.Graph;
28import org.tensorflow.Output;
29import org.tensorflow.Session;
30import org.tensorflow.Tensor;
31import org.tensorflow.TensorFlow;
32import org.tensorflow.types.UInt8;
33
34/** Sample use of the TensorFlow Java API to label images using a pre-trained model. */
35public class LabelImage {
36  private static void printUsage(PrintStream s) {
37    final String url =
38        "https://storage.googleapis.com/download.tensorflow.org/models/inception5h.zip";
39    s.println(
40        "Java program that uses a pre-trained Inception model (http://arxiv.org/abs/1512.00567)");
41    s.println("to label JPEG images.");
42    s.println("TensorFlow version: " + TensorFlow.version());
43    s.println();
44    s.println("Usage: label_image <model dir> <image file>");
45    s.println();
46    s.println("Where:");
47    s.println("<model dir> is a directory containing the unzipped contents of the inception model");
48    s.println("            (from " + url + ")");
49    s.println("<image file> is the path to a JPEG image file");
50  }
51
52  public static void main(String[] args) {
53    if (args.length != 2) {
54      printUsage(System.err);
55      System.exit(1);
56    }
57    String modelDir = args[0];
58    String imageFile = args[1];
59
60    byte[] graphDef = readAllBytesOrExit(Paths.get(modelDir, "tensorflow_inception_graph.pb"));
61    List<String> labels =
62        readAllLinesOrExit(Paths.get(modelDir, "imagenet_comp_graph_label_strings.txt"));
63    byte[] imageBytes = readAllBytesOrExit(Paths.get(imageFile));
64
65    try (Tensor<Float> image = constructAndExecuteGraphToNormalizeImage(imageBytes)) {
66      float[] labelProbabilities = executeInceptionGraph(graphDef, image);
67      int bestLabelIdx = maxIndex(labelProbabilities);
68      System.out.println(
69          String.format("BEST MATCH: %s (%.2f%% likely)",
70              labels.get(bestLabelIdx),
71              labelProbabilities[bestLabelIdx] * 100f));
72    }
73  }
74
75  private static Tensor<Float> constructAndExecuteGraphToNormalizeImage(byte[] imageBytes) {
76    try (Graph g = new Graph()) {
77      GraphBuilder b = new GraphBuilder(g);
78      // Some constants specific to the pre-trained model at:
79      // https://storage.googleapis.com/download.tensorflow.org/models/inception5h.zip
80      //
81      // - The model was trained with images scaled to 224x224 pixels.
82      // - The colors, represented as R, G, B in 1-byte each were converted to
83      //   float using (value - Mean)/Scale.
84      final int H = 224;
85      final int W = 224;
86      final float mean = 117f;
87      final float scale = 1f;
88
89      // Since the graph is being constructed once per execution here, we can use a constant for the
90      // input image. If the graph were to be re-used for multiple input images, a placeholder would
91      // have been more appropriate.
92      final Output<String> input = b.constant("input", imageBytes);
93      final Output<Float> output =
94          b.div(
95              b.sub(
96                  b.resizeBilinear(
97                      b.expandDims(
98                          b.cast(b.decodeJpeg(input, 3), Float.class),
99                          b.constant("make_batch", 0)),
100                      b.constant("size", new int[] {H, W})),
101                  b.constant("mean", mean)),
102              b.constant("scale", scale));
103      try (Session s = new Session(g)) {
104        return s.runner().fetch(output.op().name()).run().get(0).expect(Float.class);
105      }
106    }
107  }
108
109  private static float[] executeInceptionGraph(byte[] graphDef, Tensor<Float> image) {
110    try (Graph g = new Graph()) {
111      g.importGraphDef(graphDef);
112      try (Session s = new Session(g);
113          Tensor<Float> result =
114              s.runner().feed("input", image).fetch("output").run().get(0).expect(Float.class)) {
115        final long[] rshape = result.shape();
116        if (result.numDimensions() != 2 || rshape[0] != 1) {
117          throw new RuntimeException(
118              String.format(
119                  "Expected model to produce a [1 N] shaped tensor where N is the number of labels, instead it produced one with shape %s",
120                  Arrays.toString(rshape)));
121        }
122        int nlabels = (int) rshape[1];
123        return result.copyTo(new float[1][nlabels])[0];
124      }
125    }
126  }
127
128  private static int maxIndex(float[] probabilities) {
129    int best = 0;
130    for (int i = 1; i < probabilities.length; ++i) {
131      if (probabilities[i] > probabilities[best]) {
132        best = i;
133      }
134    }
135    return best;
136  }
137
138  private static byte[] readAllBytesOrExit(Path path) {
139    try {
140      return Files.readAllBytes(path);
141    } catch (IOException e) {
142      System.err.println("Failed to read [" + path + "]: " + e.getMessage());
143      System.exit(1);
144    }
145    return null;
146  }
147
148  private static List<String> readAllLinesOrExit(Path path) {
149    try {
150      return Files.readAllLines(path, Charset.forName("UTF-8"));
151    } catch (IOException e) {
152      System.err.println("Failed to read [" + path + "]: " + e.getMessage());
153      System.exit(0);
154    }
155    return null;
156  }
157
158  // In the fullness of time, equivalents of the methods of this class should be auto-generated from
159  // the OpDefs linked into libtensorflow_jni.so. That would match what is done in other languages
160  // like Python, C++ and Go.
161  static class GraphBuilder {
162    GraphBuilder(Graph g) {
163      this.g = g;
164    }
165
166    Output<Float> div(Output<Float> x, Output<Float> y) {
167      return binaryOp("Div", x, y);
168    }
169
170    <T> Output<T> sub(Output<T> x, Output<T> y) {
171      return binaryOp("Sub", x, y);
172    }
173
174    <T> Output<Float> resizeBilinear(Output<T> images, Output<Integer> size) {
175      return binaryOp3("ResizeBilinear", images, size);
176    }
177
178    <T> Output<T> expandDims(Output<T> input, Output<Integer> dim) {
179      return binaryOp3("ExpandDims", input, dim);
180    }
181
182    <T, U> Output<U> cast(Output<T> value, Class<U> type) {
183      DataType dtype = DataType.fromClass(type);
184      return g.opBuilder("Cast", "Cast")
185          .addInput(value)
186          .setAttr("DstT", dtype)
187          .build()
188          .<U>output(0);
189    }
190
191    Output<UInt8> decodeJpeg(Output<String> contents, long channels) {
192      return g.opBuilder("DecodeJpeg", "DecodeJpeg")
193          .addInput(contents)
194          .setAttr("channels", channels)
195          .build()
196          .<UInt8>output(0);
197    }
198
199    <T> Output<T> constant(String name, Object value, Class<T> type) {
200      try (Tensor<T> t = Tensor.<T>create(value, type)) {
201        return g.opBuilder("Const", name)
202            .setAttr("dtype", DataType.fromClass(type))
203            .setAttr("value", t)
204            .build()
205            .<T>output(0);
206      }
207    }
208    Output<String> constant(String name, byte[] value) {
209      return this.constant(name, value, String.class);
210    }
211
212    Output<Integer> constant(String name, int value) {
213      return this.constant(name, value, Integer.class);
214    }
215
216    Output<Integer> constant(String name, int[] value) {
217      return this.constant(name, value, Integer.class);
218    }
219
220    Output<Float> constant(String name, float value) {
221      return this.constant(name, value, Float.class);
222    }
223
224    private <T> Output<T> binaryOp(String type, Output<T> in1, Output<T> in2) {
225      return g.opBuilder(type, type).addInput(in1).addInput(in2).build().<T>output(0);
226    }
227
228    private <T, U, V> Output<T> binaryOp3(String type, Output<U> in1, Output<V> in2) {
229      return g.opBuilder(type, type).addInput(in1).addInput(in2).build().<T>output(0);
230    }
231    private Graph g;
232  }
233}
234