1/*
2 * Copyright 2017 The TensorFlow Authors. All Rights Reserved.
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *       http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17package org.tensorflow.demo;
18
19import android.content.Context;
20import android.content.res.AssetManager;
21import android.graphics.Bitmap;
22import android.graphics.Bitmap.Config;
23import android.graphics.BitmapFactory;
24import android.graphics.Canvas;
25import android.graphics.Color;
26import android.graphics.Matrix;
27import android.graphics.Paint;
28import android.graphics.Paint.Style;
29import android.graphics.Rect;
30import android.graphics.Typeface;
31import android.media.ImageReader.OnImageAvailableListener;
32import android.os.Bundle;
33import android.os.SystemClock;
34import android.util.Size;
35import android.util.TypedValue;
36import android.view.Display;
37import android.view.MotionEvent;
38import android.view.View;
39import android.view.View.OnClickListener;
40import android.view.View.OnTouchListener;
41import android.view.ViewGroup;
42import android.widget.BaseAdapter;
43import android.widget.Button;
44import android.widget.GridView;
45import android.widget.ImageView;
46import android.widget.Toast;
47import java.io.IOException;
48import java.io.InputStream;
49import java.util.ArrayList;
50import java.util.Collections;
51import java.util.Vector;
52import org.tensorflow.contrib.android.TensorFlowInferenceInterface;
53import org.tensorflow.demo.OverlayView.DrawCallback;
54import org.tensorflow.demo.env.BorderedText;
55import org.tensorflow.demo.env.ImageUtils;
56import org.tensorflow.demo.env.Logger;
57import org.tensorflow.demo.R; // Explicit import needed for internal Google builds.
58
59/**
60 * Sample activity that stylizes the camera preview according to "A Learned Representation For
61 * Artistic Style" (https://arxiv.org/abs/1610.07629)
62 */
63public class StylizeActivity extends CameraActivity implements OnImageAvailableListener {
64  private static final Logger LOGGER = new Logger();
65
66  private static final String MODEL_FILE = "file:///android_asset/stylize_quantized.pb";
67  private static final String INPUT_NODE = "input";
68  private static final String STYLE_NODE = "style_num";
69  private static final String OUTPUT_NODE = "transformer/expand/conv3/conv/Sigmoid";
70  private static final int NUM_STYLES = 26;
71
72  private static final boolean SAVE_PREVIEW_BITMAP = false;
73
74  // Whether to actively manipulate non-selected sliders so that sum of activations always appears
75  // to be 1.0. The actual style input tensor will be normalized to sum to 1.0 regardless.
76  private static final boolean NORMALIZE_SLIDERS = true;
77
78  private static final float TEXT_SIZE_DIP = 12;
79
80  private static final boolean DEBUG_MODEL = false;
81
82  private static final int[] SIZES = {128, 192, 256, 384, 512, 720};
83
84  private static final Size DESIRED_PREVIEW_SIZE = new Size(1280, 720);
85
86  // Start at a medium size, but let the user step up through smaller sizes so they don't get
87  // immediately stuck processing a large image.
88  private int desiredSizeIndex = -1;
89  private int desiredSize = 256;
90  private int initializedSize = 0;
91
92  private Integer sensorOrientation;
93
94  private long lastProcessingTimeMs;
95  private Bitmap rgbFrameBitmap = null;
96  private Bitmap croppedBitmap = null;
97  private Bitmap cropCopyBitmap = null;
98
99  private final float[] styleVals = new float[NUM_STYLES];
100  private int[] intValues;
101  private float[] floatValues;
102
103  private int frameNum = 0;
104
105  private Bitmap textureCopyBitmap;
106
107  private Matrix frameToCropTransform;
108  private Matrix cropToFrameTransform;
109
110  private BorderedText borderedText;
111
112  private TensorFlowInferenceInterface inferenceInterface;
113
114  private int lastOtherStyle = 1;
115
116  private boolean allZero = false;
117
118  private ImageGridAdapter adapter;
119  private GridView grid;
120
121  private final OnTouchListener gridTouchAdapter =
122      new OnTouchListener() {
123        ImageSlider slider = null;
124
125        @Override
126        public boolean onTouch(final View v, final MotionEvent event) {
127          switch (event.getActionMasked()) {
128            case MotionEvent.ACTION_DOWN:
129              for (int i = 0; i < NUM_STYLES; ++i) {
130                final ImageSlider child = adapter.items[i];
131                final Rect rect = new Rect();
132                child.getHitRect(rect);
133                if (rect.contains((int) event.getX(), (int) event.getY())) {
134                  slider = child;
135                  slider.setHilighted(true);
136                }
137              }
138              break;
139
140            case MotionEvent.ACTION_MOVE:
141              if (slider != null) {
142                final Rect rect = new Rect();
143                slider.getHitRect(rect);
144
145                final float newSliderVal =
146                    (float)
147                        Math.min(
148                            1.0,
149                            Math.max(
150                                0.0, 1.0 - (event.getY() - slider.getTop()) / slider.getHeight()));
151
152                setStyle(slider, newSliderVal);
153              }
154              break;
155
156            case MotionEvent.ACTION_UP:
157              if (slider != null) {
158                slider.setHilighted(false);
159                slider = null;
160              }
161              break;
162
163            default: // fall out
164
165          }
166          return true;
167        }
168      };
169
170  @Override
171  public void onCreate(final Bundle savedInstanceState) {
172    super.onCreate(savedInstanceState);
173  }
174
175  @Override
176  protected int getLayoutId() {
177    return R.layout.camera_connection_fragment_stylize;
178  }
179
180  @Override
181  protected Size getDesiredPreviewFrameSize() {
182    return DESIRED_PREVIEW_SIZE;
183  }
184
185  public static Bitmap getBitmapFromAsset(final Context context, final String filePath) {
186    final AssetManager assetManager = context.getAssets();
187
188    Bitmap bitmap = null;
189    try {
190      final InputStream inputStream = assetManager.open(filePath);
191      bitmap = BitmapFactory.decodeStream(inputStream);
192    } catch (final IOException e) {
193      LOGGER.e("Error opening bitmap!", e);
194    }
195
196    return bitmap;
197  }
198
199  private class ImageSlider extends ImageView {
200    private float value = 0.0f;
201    private boolean hilighted = false;
202
203    private final Paint boxPaint;
204    private final Paint linePaint;
205
206    public ImageSlider(final Context context) {
207      super(context);
208      value = 0.0f;
209
210      boxPaint = new Paint();
211      boxPaint.setColor(Color.BLACK);
212      boxPaint.setAlpha(128);
213
214      linePaint = new Paint();
215      linePaint.setColor(Color.WHITE);
216      linePaint.setStrokeWidth(10.0f);
217      linePaint.setStyle(Style.STROKE);
218    }
219
220    @Override
221    public void onDraw(final Canvas canvas) {
222      super.onDraw(canvas);
223      final float y = (1.0f - value) * canvas.getHeight();
224
225      // If all sliders are zero, don't bother shading anything.
226      if (!allZero) {
227        canvas.drawRect(0, 0, canvas.getWidth(), y, boxPaint);
228      }
229
230      if (value > 0.0f) {
231        canvas.drawLine(0, y, canvas.getWidth(), y, linePaint);
232      }
233
234      if (hilighted) {
235        canvas.drawRect(0, 0, getWidth(), getHeight(), linePaint);
236      }
237    }
238
239    @Override
240    protected void onMeasure(final int widthMeasureSpec, final int heightMeasureSpec) {
241      super.onMeasure(widthMeasureSpec, heightMeasureSpec);
242      setMeasuredDimension(getMeasuredWidth(), getMeasuredWidth());
243    }
244
245    public void setValue(final float value) {
246      this.value = value;
247      postInvalidate();
248    }
249
250    public void setHilighted(final boolean highlighted) {
251      this.hilighted = highlighted;
252      this.postInvalidate();
253    }
254  }
255
256  private class ImageGridAdapter extends BaseAdapter {
257    final ImageSlider[] items = new ImageSlider[NUM_STYLES];
258    final ArrayList<Button> buttons = new ArrayList<>();
259
260    {
261      final Button sizeButton =
262          new Button(StylizeActivity.this) {
263            @Override
264            protected void onMeasure(final int widthMeasureSpec, final int heightMeasureSpec) {
265              super.onMeasure(widthMeasureSpec, heightMeasureSpec);
266              setMeasuredDimension(getMeasuredWidth(), getMeasuredWidth());
267            }
268          };
269      sizeButton.setText("" + desiredSize);
270      sizeButton.setOnClickListener(
271          new OnClickListener() {
272            @Override
273            public void onClick(final View v) {
274              desiredSizeIndex = (desiredSizeIndex + 1) % SIZES.length;
275              desiredSize = SIZES[desiredSizeIndex];
276              sizeButton.setText("" + desiredSize);
277              sizeButton.postInvalidate();
278            }
279          });
280
281      final Button saveButton =
282          new Button(StylizeActivity.this) {
283            @Override
284            protected void onMeasure(final int widthMeasureSpec, final int heightMeasureSpec) {
285              super.onMeasure(widthMeasureSpec, heightMeasureSpec);
286              setMeasuredDimension(getMeasuredWidth(), getMeasuredWidth());
287            }
288          };
289      saveButton.setText("save");
290      saveButton.setTextSize(12);
291
292      saveButton.setOnClickListener(
293          new OnClickListener() {
294            @Override
295            public void onClick(final View v) {
296              if (textureCopyBitmap != null) {
297                // TODO(andrewharp): Save as jpeg with guaranteed unique filename.
298                ImageUtils.saveBitmap(textureCopyBitmap, "stylized" + frameNum + ".png");
299                Toast.makeText(
300                        StylizeActivity.this,
301                        "Saved image to: /sdcard/tensorflow/" + "stylized" + frameNum + ".png",
302                        Toast.LENGTH_LONG)
303                    .show();
304              }
305            }
306          });
307
308      buttons.add(sizeButton);
309      buttons.add(saveButton);
310
311      for (int i = 0; i < NUM_STYLES; ++i) {
312        LOGGER.v("Creating item %d", i);
313
314        if (items[i] == null) {
315          final ImageSlider slider = new ImageSlider(StylizeActivity.this);
316          final Bitmap bm =
317              getBitmapFromAsset(StylizeActivity.this, "thumbnails/style" + i + ".jpg");
318          slider.setImageBitmap(bm);
319
320          items[i] = slider;
321        }
322      }
323    }
324
325    @Override
326    public int getCount() {
327      return buttons.size() + NUM_STYLES;
328    }
329
330    @Override
331    public Object getItem(final int position) {
332      if (position < buttons.size()) {
333        return buttons.get(position);
334      } else {
335        return items[position - buttons.size()];
336      }
337    }
338
339    @Override
340    public long getItemId(final int position) {
341      return getItem(position).hashCode();
342    }
343
344    @Override
345    public View getView(final int position, final View convertView, final ViewGroup parent) {
346      if (convertView != null) {
347        return convertView;
348      }
349      return (View) getItem(position);
350    }
351  }
352
353  @Override
354  public void onPreviewSizeChosen(final Size size, final int rotation) {
355    final float textSizePx = TypedValue.applyDimension(
356        TypedValue.COMPLEX_UNIT_DIP, TEXT_SIZE_DIP, getResources().getDisplayMetrics());
357    borderedText = new BorderedText(textSizePx);
358    borderedText.setTypeface(Typeface.MONOSPACE);
359
360    inferenceInterface = new TensorFlowInferenceInterface(getAssets(), MODEL_FILE);
361
362    previewWidth = size.getWidth();
363    previewHeight = size.getHeight();
364
365    final Display display = getWindowManager().getDefaultDisplay();
366    final int screenOrientation = display.getRotation();
367
368    LOGGER.i("Sensor orientation: %d, Screen orientation: %d", rotation, screenOrientation);
369
370    sensorOrientation = rotation + screenOrientation;
371
372    addCallback(
373        new DrawCallback() {
374          @Override
375          public void drawCallback(final Canvas canvas) {
376            renderDebug(canvas);
377          }
378        });
379
380    adapter = new ImageGridAdapter();
381    grid = (GridView) findViewById(R.id.grid_layout);
382    grid.setAdapter(adapter);
383    grid.setOnTouchListener(gridTouchAdapter);
384    setStyle(adapter.items[0], 1.0f);
385  }
386
387  private void setStyle(final ImageSlider slider, final float value) {
388    slider.setValue(value);
389
390    if (NORMALIZE_SLIDERS) {
391      // Slider vals correspond directly to the input tensor vals, and normalization is visually
392      // maintained by remanipulating non-selected sliders.
393      float otherSum = 0.0f;
394
395      for (int i = 0; i < NUM_STYLES; ++i) {
396        if (adapter.items[i] != slider) {
397          otherSum += adapter.items[i].value;
398        }
399      }
400
401      if (otherSum > 0.0) {
402        float highestOtherVal = 0;
403        final float factor = otherSum > 0.0f ? (1.0f - value) / otherSum : 0.0f;
404        for (int i = 0; i < NUM_STYLES; ++i) {
405          final ImageSlider child = adapter.items[i];
406          if (child == slider) {
407            continue;
408          }
409          final float newVal = child.value * factor;
410          child.setValue(newVal > 0.01f ? newVal : 0.0f);
411
412          if (child.value > highestOtherVal) {
413            lastOtherStyle = i;
414            highestOtherVal = child.value;
415          }
416        }
417      } else {
418        // Everything else is 0, so just pick a suitable slider to push up when the
419        // selected one goes down.
420        if (adapter.items[lastOtherStyle] == slider) {
421          lastOtherStyle = (lastOtherStyle + 1) % NUM_STYLES;
422        }
423        adapter.items[lastOtherStyle].setValue(1.0f - value);
424      }
425    }
426
427    final boolean lastAllZero = allZero;
428    float sum = 0.0f;
429    for (int i = 0; i < NUM_STYLES; ++i) {
430      sum += adapter.items[i].value;
431    }
432    allZero = sum == 0.0f;
433
434    // Now update the values used for the input tensor. If nothing is set, mix in everything
435    // equally. Otherwise everything is normalized to sum to 1.0.
436    for (int i = 0; i < NUM_STYLES; ++i) {
437      styleVals[i] = allZero ? 1.0f / NUM_STYLES : adapter.items[i].value / sum;
438
439      if (lastAllZero != allZero) {
440        adapter.items[i].postInvalidate();
441      }
442    }
443  }
444
445  private void resetPreviewBuffers() {
446    croppedBitmap = Bitmap.createBitmap(desiredSize, desiredSize, Config.ARGB_8888);
447
448    frameToCropTransform = ImageUtils.getTransformationMatrix(
449        previewWidth, previewHeight,
450        desiredSize, desiredSize,
451        sensorOrientation, true);
452
453    cropToFrameTransform = new Matrix();
454    frameToCropTransform.invert(cropToFrameTransform);
455    intValues = new int[desiredSize * desiredSize];
456    floatValues = new float[desiredSize * desiredSize * 3];
457    initializedSize = desiredSize;
458  }
459
460  @Override
461  protected void processImage() {
462    if (desiredSize != initializedSize) {
463      LOGGER.i(
464          "Initializing at size preview size %dx%d, stylize size %d",
465          previewWidth, previewHeight, desiredSize);
466
467      rgbFrameBitmap = Bitmap.createBitmap(previewWidth, previewHeight, Config.ARGB_8888);
468      croppedBitmap = Bitmap.createBitmap(desiredSize, desiredSize, Config.ARGB_8888);
469      frameToCropTransform = ImageUtils.getTransformationMatrix(
470          previewWidth, previewHeight,
471          desiredSize, desiredSize,
472          sensorOrientation, true);
473
474      cropToFrameTransform = new Matrix();
475      frameToCropTransform.invert(cropToFrameTransform);
476      intValues = new int[desiredSize * desiredSize];
477      floatValues = new float[desiredSize * desiredSize * 3];
478      initializedSize = desiredSize;
479    }
480    rgbFrameBitmap.setPixels(getRgbBytes(), 0, previewWidth, 0, 0, previewWidth, previewHeight);
481    final Canvas canvas = new Canvas(croppedBitmap);
482    canvas.drawBitmap(rgbFrameBitmap, frameToCropTransform, null);
483
484    // For examining the actual TF input.
485    if (SAVE_PREVIEW_BITMAP) {
486      ImageUtils.saveBitmap(croppedBitmap);
487    }
488
489    runInBackground(
490        new Runnable() {
491          @Override
492          public void run() {
493            cropCopyBitmap = Bitmap.createBitmap(croppedBitmap);
494            final long startTime = SystemClock.uptimeMillis();
495            stylizeImage(croppedBitmap);
496            lastProcessingTimeMs = SystemClock.uptimeMillis() - startTime;
497            textureCopyBitmap = Bitmap.createBitmap(croppedBitmap);
498            requestRender();
499            readyForNextImage();
500          }
501        });
502    if (desiredSize != initializedSize) {
503      resetPreviewBuffers();
504    }
505  }
506
507  private void stylizeImage(final Bitmap bitmap) {
508    ++frameNum;
509    bitmap.getPixels(intValues, 0, bitmap.getWidth(), 0, 0, bitmap.getWidth(), bitmap.getHeight());
510
511    if (DEBUG_MODEL) {
512      // Create a white square that steps through a black background 1 pixel per frame.
513      final int centerX = (frameNum + bitmap.getWidth() / 2) % bitmap.getWidth();
514      final int centerY = bitmap.getHeight() / 2;
515      final int squareSize = 10;
516      for (int i = 0; i < intValues.length; ++i) {
517        final int x = i % bitmap.getWidth();
518        final int y = i / bitmap.getHeight();
519        final float val =
520            Math.abs(x - centerX) < squareSize && Math.abs(y - centerY) < squareSize ? 1.0f : 0.0f;
521        floatValues[i * 3] = val;
522        floatValues[i * 3 + 1] = val;
523        floatValues[i * 3 + 2] = val;
524      }
525    } else {
526      for (int i = 0; i < intValues.length; ++i) {
527        final int val = intValues[i];
528        floatValues[i * 3] = ((val >> 16) & 0xFF) / 255.0f;
529        floatValues[i * 3 + 1] = ((val >> 8) & 0xFF) / 255.0f;
530        floatValues[i * 3 + 2] = (val & 0xFF) / 255.0f;
531      }
532    }
533
534    // Copy the input data into TensorFlow.
535    LOGGER.i("Width: %s , Height: %s", bitmap.getWidth(), bitmap.getHeight());
536    inferenceInterface.feed(
537        INPUT_NODE, floatValues, 1, bitmap.getWidth(), bitmap.getHeight(), 3);
538    inferenceInterface.feed(STYLE_NODE, styleVals, NUM_STYLES);
539
540    inferenceInterface.run(new String[] {OUTPUT_NODE}, isDebug());
541    inferenceInterface.fetch(OUTPUT_NODE, floatValues);
542
543    for (int i = 0; i < intValues.length; ++i) {
544      intValues[i] =
545          0xFF000000
546              | (((int) (floatValues[i * 3] * 255)) << 16)
547              | (((int) (floatValues[i * 3 + 1] * 255)) << 8)
548              | ((int) (floatValues[i * 3 + 2] * 255));
549    }
550
551    bitmap.setPixels(intValues, 0, bitmap.getWidth(), 0, 0, bitmap.getWidth(), bitmap.getHeight());
552  }
553
554  private void renderDebug(final Canvas canvas) {
555    // TODO(andrewharp): move result display to its own View instead of using debug overlay.
556    final Bitmap texture = textureCopyBitmap;
557    if (texture != null) {
558      final Matrix matrix = new Matrix();
559      final float scaleFactor =
560          DEBUG_MODEL
561              ? 4.0f
562              : Math.min(
563                  (float) canvas.getWidth() / texture.getWidth(),
564                  (float) canvas.getHeight() / texture.getHeight());
565      matrix.postScale(scaleFactor, scaleFactor);
566      canvas.drawBitmap(texture, matrix, new Paint());
567    }
568
569    if (!isDebug()) {
570      return;
571    }
572
573    final Bitmap copy = cropCopyBitmap;
574    if (copy == null) {
575      return;
576    }
577
578    canvas.drawColor(0x55000000);
579
580    final Matrix matrix = new Matrix();
581    final float scaleFactor = 2;
582    matrix.postScale(scaleFactor, scaleFactor);
583    matrix.postTranslate(
584        canvas.getWidth() - copy.getWidth() * scaleFactor,
585        canvas.getHeight() - copy.getHeight() * scaleFactor);
586    canvas.drawBitmap(copy, matrix, new Paint());
587
588    final Vector<String> lines = new Vector<>();
589
590    final String[] statLines = inferenceInterface.getStatString().split("\n");
591    Collections.addAll(lines, statLines);
592
593    lines.add("");
594
595    lines.add("Frame: " + previewWidth + "x" + previewHeight);
596    lines.add("Crop: " + copy.getWidth() + "x" + copy.getHeight());
597    lines.add("View: " + canvas.getWidth() + "x" + canvas.getHeight());
598    lines.add("Rotation: " + sensorOrientation);
599    lines.add("Inference time: " + lastProcessingTimeMs + "ms");
600    lines.add("Desired size: " + desiredSize);
601    lines.add("Initialized size: " + initializedSize);
602
603    borderedText.drawLines(canvas, 10, canvas.getHeight() - 10, lines);
604  }
605}
606