1/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2 3Licensed under the Apache License, Version 2.0 (the "License"); 4you may not use this file except in compliance with the License. 5You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9Unless required by applicable law or agreed to in writing, software 10distributed under the License is distributed on an "AS IS" BASIS, 11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12See the License for the specific language governing permissions and 13limitations under the License. 14==============================================================================*/ 15 16package org.tensorflow.examples; 17 18import java.io.IOException; 19import java.io.PrintStream; 20import java.nio.charset.Charset; 21import java.nio.file.Files; 22import java.nio.file.Path; 23import java.nio.file.Paths; 24import java.util.Arrays; 25import java.util.List; 26import org.tensorflow.DataType; 27import org.tensorflow.Graph; 28import org.tensorflow.Output; 29import org.tensorflow.Session; 30import org.tensorflow.Tensor; 31import org.tensorflow.TensorFlow; 32import org.tensorflow.types.UInt8; 33 34/** Sample use of the TensorFlow Java API to label images using a pre-trained model. */ 35public class LabelImage { 36 private static void printUsage(PrintStream s) { 37 final String url = 38 "https://storage.googleapis.com/download.tensorflow.org/models/inception5h.zip"; 39 s.println( 40 "Java program that uses a pre-trained Inception model (http://arxiv.org/abs/1512.00567)"); 41 s.println("to label JPEG images."); 42 s.println("TensorFlow version: " + TensorFlow.version()); 43 s.println(); 44 s.println("Usage: label_image <model dir> <image file>"); 45 s.println(); 46 s.println("Where:"); 47 s.println("<model dir> is a directory containing the unzipped contents of the inception model"); 48 s.println(" (from " + url + ")"); 49 s.println("<image file> is the path to a JPEG image file"); 50 } 51 52 public static void main(String[] args) { 53 if (args.length != 2) { 54 printUsage(System.err); 55 System.exit(1); 56 } 57 String modelDir = args[0]; 58 String imageFile = args[1]; 59 60 byte[] graphDef = readAllBytesOrExit(Paths.get(modelDir, "tensorflow_inception_graph.pb")); 61 List<String> labels = 62 readAllLinesOrExit(Paths.get(modelDir, "imagenet_comp_graph_label_strings.txt")); 63 byte[] imageBytes = readAllBytesOrExit(Paths.get(imageFile)); 64 65 try (Tensor<Float> image = constructAndExecuteGraphToNormalizeImage(imageBytes)) { 66 float[] labelProbabilities = executeInceptionGraph(graphDef, image); 67 int bestLabelIdx = maxIndex(labelProbabilities); 68 System.out.println( 69 String.format("BEST MATCH: %s (%.2f%% likely)", 70 labels.get(bestLabelIdx), 71 labelProbabilities[bestLabelIdx] * 100f)); 72 } 73 } 74 75 private static Tensor<Float> constructAndExecuteGraphToNormalizeImage(byte[] imageBytes) { 76 try (Graph g = new Graph()) { 77 GraphBuilder b = new GraphBuilder(g); 78 // Some constants specific to the pre-trained model at: 79 // https://storage.googleapis.com/download.tensorflow.org/models/inception5h.zip 80 // 81 // - The model was trained with images scaled to 224x224 pixels. 82 // - The colors, represented as R, G, B in 1-byte each were converted to 83 // float using (value - Mean)/Scale. 84 final int H = 224; 85 final int W = 224; 86 final float mean = 117f; 87 final float scale = 1f; 88 89 // Since the graph is being constructed once per execution here, we can use a constant for the 90 // input image. If the graph were to be re-used for multiple input images, a placeholder would 91 // have been more appropriate. 92 final Output<String> input = b.constant("input", imageBytes); 93 final Output<Float> output = 94 b.div( 95 b.sub( 96 b.resizeBilinear( 97 b.expandDims( 98 b.cast(b.decodeJpeg(input, 3), Float.class), 99 b.constant("make_batch", 0)), 100 b.constant("size", new int[] {H, W})), 101 b.constant("mean", mean)), 102 b.constant("scale", scale)); 103 try (Session s = new Session(g)) { 104 return s.runner().fetch(output.op().name()).run().get(0).expect(Float.class); 105 } 106 } 107 } 108 109 private static float[] executeInceptionGraph(byte[] graphDef, Tensor<Float> image) { 110 try (Graph g = new Graph()) { 111 g.importGraphDef(graphDef); 112 try (Session s = new Session(g); 113 Tensor<Float> result = 114 s.runner().feed("input", image).fetch("output").run().get(0).expect(Float.class)) { 115 final long[] rshape = result.shape(); 116 if (result.numDimensions() != 2 || rshape[0] != 1) { 117 throw new RuntimeException( 118 String.format( 119 "Expected model to produce a [1 N] shaped tensor where N is the number of labels, instead it produced one with shape %s", 120 Arrays.toString(rshape))); 121 } 122 int nlabels = (int) rshape[1]; 123 return result.copyTo(new float[1][nlabels])[0]; 124 } 125 } 126 } 127 128 private static int maxIndex(float[] probabilities) { 129 int best = 0; 130 for (int i = 1; i < probabilities.length; ++i) { 131 if (probabilities[i] > probabilities[best]) { 132 best = i; 133 } 134 } 135 return best; 136 } 137 138 private static byte[] readAllBytesOrExit(Path path) { 139 try { 140 return Files.readAllBytes(path); 141 } catch (IOException e) { 142 System.err.println("Failed to read [" + path + "]: " + e.getMessage()); 143 System.exit(1); 144 } 145 return null; 146 } 147 148 private static List<String> readAllLinesOrExit(Path path) { 149 try { 150 return Files.readAllLines(path, Charset.forName("UTF-8")); 151 } catch (IOException e) { 152 System.err.println("Failed to read [" + path + "]: " + e.getMessage()); 153 System.exit(0); 154 } 155 return null; 156 } 157 158 // In the fullness of time, equivalents of the methods of this class should be auto-generated from 159 // the OpDefs linked into libtensorflow_jni.so. That would match what is done in other languages 160 // like Python, C++ and Go. 161 static class GraphBuilder { 162 GraphBuilder(Graph g) { 163 this.g = g; 164 } 165 166 Output<Float> div(Output<Float> x, Output<Float> y) { 167 return binaryOp("Div", x, y); 168 } 169 170 <T> Output<T> sub(Output<T> x, Output<T> y) { 171 return binaryOp("Sub", x, y); 172 } 173 174 <T> Output<Float> resizeBilinear(Output<T> images, Output<Integer> size) { 175 return binaryOp3("ResizeBilinear", images, size); 176 } 177 178 <T> Output<T> expandDims(Output<T> input, Output<Integer> dim) { 179 return binaryOp3("ExpandDims", input, dim); 180 } 181 182 <T, U> Output<U> cast(Output<T> value, Class<U> type) { 183 DataType dtype = DataType.fromClass(type); 184 return g.opBuilder("Cast", "Cast") 185 .addInput(value) 186 .setAttr("DstT", dtype) 187 .build() 188 .<U>output(0); 189 } 190 191 Output<UInt8> decodeJpeg(Output<String> contents, long channels) { 192 return g.opBuilder("DecodeJpeg", "DecodeJpeg") 193 .addInput(contents) 194 .setAttr("channels", channels) 195 .build() 196 .<UInt8>output(0); 197 } 198 199 <T> Output<T> constant(String name, Object value, Class<T> type) { 200 try (Tensor<T> t = Tensor.<T>create(value, type)) { 201 return g.opBuilder("Const", name) 202 .setAttr("dtype", DataType.fromClass(type)) 203 .setAttr("value", t) 204 .build() 205 .<T>output(0); 206 } 207 } 208 Output<String> constant(String name, byte[] value) { 209 return this.constant(name, value, String.class); 210 } 211 212 Output<Integer> constant(String name, int value) { 213 return this.constant(name, value, Integer.class); 214 } 215 216 Output<Integer> constant(String name, int[] value) { 217 return this.constant(name, value, Integer.class); 218 } 219 220 Output<Float> constant(String name, float value) { 221 return this.constant(name, value, Float.class); 222 } 223 224 private <T> Output<T> binaryOp(String type, Output<T> in1, Output<T> in2) { 225 return g.opBuilder(type, type).addInput(in1).addInput(in2).build().<T>output(0); 226 } 227 228 private <T, U, V> Output<T> binaryOp3(String type, Output<U> in1, Output<V> in2) { 229 return g.opBuilder(type, type).addInput(in1).addInput(in2).build().<T>output(0); 230 } 231 private Graph g; 232 } 233} 234