1/* 2 * Copyright 2017 The TensorFlow Authors. All Rights Reserved. 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17package org.tensorflow.demo; 18 19import android.content.Context; 20import android.content.res.AssetManager; 21import android.graphics.Bitmap; 22import android.graphics.Bitmap.Config; 23import android.graphics.BitmapFactory; 24import android.graphics.Canvas; 25import android.graphics.Color; 26import android.graphics.Matrix; 27import android.graphics.Paint; 28import android.graphics.Paint.Style; 29import android.graphics.Rect; 30import android.graphics.Typeface; 31import android.media.ImageReader.OnImageAvailableListener; 32import android.os.Bundle; 33import android.os.SystemClock; 34import android.util.Size; 35import android.util.TypedValue; 36import android.view.Display; 37import android.view.MotionEvent; 38import android.view.View; 39import android.view.View.OnClickListener; 40import android.view.View.OnTouchListener; 41import android.view.ViewGroup; 42import android.widget.BaseAdapter; 43import android.widget.Button; 44import android.widget.GridView; 45import android.widget.ImageView; 46import android.widget.Toast; 47import java.io.IOException; 48import java.io.InputStream; 49import java.util.ArrayList; 50import java.util.Collections; 51import java.util.Vector; 52import org.tensorflow.contrib.android.TensorFlowInferenceInterface; 53import org.tensorflow.demo.OverlayView.DrawCallback; 54import org.tensorflow.demo.env.BorderedText; 55import org.tensorflow.demo.env.ImageUtils; 56import org.tensorflow.demo.env.Logger; 57import org.tensorflow.demo.R; // Explicit import needed for internal Google builds. 58 59/** 60 * Sample activity that stylizes the camera preview according to "A Learned Representation For 61 * Artistic Style" (https://arxiv.org/abs/1610.07629) 62 */ 63public class StylizeActivity extends CameraActivity implements OnImageAvailableListener { 64 private static final Logger LOGGER = new Logger(); 65 66 private static final String MODEL_FILE = "file:///android_asset/stylize_quantized.pb"; 67 private static final String INPUT_NODE = "input"; 68 private static final String STYLE_NODE = "style_num"; 69 private static final String OUTPUT_NODE = "transformer/expand/conv3/conv/Sigmoid"; 70 private static final int NUM_STYLES = 26; 71 72 private static final boolean SAVE_PREVIEW_BITMAP = false; 73 74 // Whether to actively manipulate non-selected sliders so that sum of activations always appears 75 // to be 1.0. The actual style input tensor will be normalized to sum to 1.0 regardless. 76 private static final boolean NORMALIZE_SLIDERS = true; 77 78 private static final float TEXT_SIZE_DIP = 12; 79 80 private static final boolean DEBUG_MODEL = false; 81 82 private static final int[] SIZES = {128, 192, 256, 384, 512, 720}; 83 84 private static final Size DESIRED_PREVIEW_SIZE = new Size(1280, 720); 85 86 // Start at a medium size, but let the user step up through smaller sizes so they don't get 87 // immediately stuck processing a large image. 88 private int desiredSizeIndex = -1; 89 private int desiredSize = 256; 90 private int initializedSize = 0; 91 92 private Integer sensorOrientation; 93 94 private long lastProcessingTimeMs; 95 private Bitmap rgbFrameBitmap = null; 96 private Bitmap croppedBitmap = null; 97 private Bitmap cropCopyBitmap = null; 98 99 private final float[] styleVals = new float[NUM_STYLES]; 100 private int[] intValues; 101 private float[] floatValues; 102 103 private int frameNum = 0; 104 105 private Bitmap textureCopyBitmap; 106 107 private Matrix frameToCropTransform; 108 private Matrix cropToFrameTransform; 109 110 private BorderedText borderedText; 111 112 private TensorFlowInferenceInterface inferenceInterface; 113 114 private int lastOtherStyle = 1; 115 116 private boolean allZero = false; 117 118 private ImageGridAdapter adapter; 119 private GridView grid; 120 121 private final OnTouchListener gridTouchAdapter = 122 new OnTouchListener() { 123 ImageSlider slider = null; 124 125 @Override 126 public boolean onTouch(final View v, final MotionEvent event) { 127 switch (event.getActionMasked()) { 128 case MotionEvent.ACTION_DOWN: 129 for (int i = 0; i < NUM_STYLES; ++i) { 130 final ImageSlider child = adapter.items[i]; 131 final Rect rect = new Rect(); 132 child.getHitRect(rect); 133 if (rect.contains((int) event.getX(), (int) event.getY())) { 134 slider = child; 135 slider.setHilighted(true); 136 } 137 } 138 break; 139 140 case MotionEvent.ACTION_MOVE: 141 if (slider != null) { 142 final Rect rect = new Rect(); 143 slider.getHitRect(rect); 144 145 final float newSliderVal = 146 (float) 147 Math.min( 148 1.0, 149 Math.max( 150 0.0, 1.0 - (event.getY() - slider.getTop()) / slider.getHeight())); 151 152 setStyle(slider, newSliderVal); 153 } 154 break; 155 156 case MotionEvent.ACTION_UP: 157 if (slider != null) { 158 slider.setHilighted(false); 159 slider = null; 160 } 161 break; 162 163 default: // fall out 164 165 } 166 return true; 167 } 168 }; 169 170 @Override 171 public void onCreate(final Bundle savedInstanceState) { 172 super.onCreate(savedInstanceState); 173 } 174 175 @Override 176 protected int getLayoutId() { 177 return R.layout.camera_connection_fragment_stylize; 178 } 179 180 @Override 181 protected Size getDesiredPreviewFrameSize() { 182 return DESIRED_PREVIEW_SIZE; 183 } 184 185 public static Bitmap getBitmapFromAsset(final Context context, final String filePath) { 186 final AssetManager assetManager = context.getAssets(); 187 188 Bitmap bitmap = null; 189 try { 190 final InputStream inputStream = assetManager.open(filePath); 191 bitmap = BitmapFactory.decodeStream(inputStream); 192 } catch (final IOException e) { 193 LOGGER.e("Error opening bitmap!", e); 194 } 195 196 return bitmap; 197 } 198 199 private class ImageSlider extends ImageView { 200 private float value = 0.0f; 201 private boolean hilighted = false; 202 203 private final Paint boxPaint; 204 private final Paint linePaint; 205 206 public ImageSlider(final Context context) { 207 super(context); 208 value = 0.0f; 209 210 boxPaint = new Paint(); 211 boxPaint.setColor(Color.BLACK); 212 boxPaint.setAlpha(128); 213 214 linePaint = new Paint(); 215 linePaint.setColor(Color.WHITE); 216 linePaint.setStrokeWidth(10.0f); 217 linePaint.setStyle(Style.STROKE); 218 } 219 220 @Override 221 public void onDraw(final Canvas canvas) { 222 super.onDraw(canvas); 223 final float y = (1.0f - value) * canvas.getHeight(); 224 225 // If all sliders are zero, don't bother shading anything. 226 if (!allZero) { 227 canvas.drawRect(0, 0, canvas.getWidth(), y, boxPaint); 228 } 229 230 if (value > 0.0f) { 231 canvas.drawLine(0, y, canvas.getWidth(), y, linePaint); 232 } 233 234 if (hilighted) { 235 canvas.drawRect(0, 0, getWidth(), getHeight(), linePaint); 236 } 237 } 238 239 @Override 240 protected void onMeasure(final int widthMeasureSpec, final int heightMeasureSpec) { 241 super.onMeasure(widthMeasureSpec, heightMeasureSpec); 242 setMeasuredDimension(getMeasuredWidth(), getMeasuredWidth()); 243 } 244 245 public void setValue(final float value) { 246 this.value = value; 247 postInvalidate(); 248 } 249 250 public void setHilighted(final boolean highlighted) { 251 this.hilighted = highlighted; 252 this.postInvalidate(); 253 } 254 } 255 256 private class ImageGridAdapter extends BaseAdapter { 257 final ImageSlider[] items = new ImageSlider[NUM_STYLES]; 258 final ArrayList<Button> buttons = new ArrayList<>(); 259 260 { 261 final Button sizeButton = 262 new Button(StylizeActivity.this) { 263 @Override 264 protected void onMeasure(final int widthMeasureSpec, final int heightMeasureSpec) { 265 super.onMeasure(widthMeasureSpec, heightMeasureSpec); 266 setMeasuredDimension(getMeasuredWidth(), getMeasuredWidth()); 267 } 268 }; 269 sizeButton.setText("" + desiredSize); 270 sizeButton.setOnClickListener( 271 new OnClickListener() { 272 @Override 273 public void onClick(final View v) { 274 desiredSizeIndex = (desiredSizeIndex + 1) % SIZES.length; 275 desiredSize = SIZES[desiredSizeIndex]; 276 sizeButton.setText("" + desiredSize); 277 sizeButton.postInvalidate(); 278 } 279 }); 280 281 final Button saveButton = 282 new Button(StylizeActivity.this) { 283 @Override 284 protected void onMeasure(final int widthMeasureSpec, final int heightMeasureSpec) { 285 super.onMeasure(widthMeasureSpec, heightMeasureSpec); 286 setMeasuredDimension(getMeasuredWidth(), getMeasuredWidth()); 287 } 288 }; 289 saveButton.setText("save"); 290 saveButton.setTextSize(12); 291 292 saveButton.setOnClickListener( 293 new OnClickListener() { 294 @Override 295 public void onClick(final View v) { 296 if (textureCopyBitmap != null) { 297 // TODO(andrewharp): Save as jpeg with guaranteed unique filename. 298 ImageUtils.saveBitmap(textureCopyBitmap, "stylized" + frameNum + ".png"); 299 Toast.makeText( 300 StylizeActivity.this, 301 "Saved image to: /sdcard/tensorflow/" + "stylized" + frameNum + ".png", 302 Toast.LENGTH_LONG) 303 .show(); 304 } 305 } 306 }); 307 308 buttons.add(sizeButton); 309 buttons.add(saveButton); 310 311 for (int i = 0; i < NUM_STYLES; ++i) { 312 LOGGER.v("Creating item %d", i); 313 314 if (items[i] == null) { 315 final ImageSlider slider = new ImageSlider(StylizeActivity.this); 316 final Bitmap bm = 317 getBitmapFromAsset(StylizeActivity.this, "thumbnails/style" + i + ".jpg"); 318 slider.setImageBitmap(bm); 319 320 items[i] = slider; 321 } 322 } 323 } 324 325 @Override 326 public int getCount() { 327 return buttons.size() + NUM_STYLES; 328 } 329 330 @Override 331 public Object getItem(final int position) { 332 if (position < buttons.size()) { 333 return buttons.get(position); 334 } else { 335 return items[position - buttons.size()]; 336 } 337 } 338 339 @Override 340 public long getItemId(final int position) { 341 return getItem(position).hashCode(); 342 } 343 344 @Override 345 public View getView(final int position, final View convertView, final ViewGroup parent) { 346 if (convertView != null) { 347 return convertView; 348 } 349 return (View) getItem(position); 350 } 351 } 352 353 @Override 354 public void onPreviewSizeChosen(final Size size, final int rotation) { 355 final float textSizePx = TypedValue.applyDimension( 356 TypedValue.COMPLEX_UNIT_DIP, TEXT_SIZE_DIP, getResources().getDisplayMetrics()); 357 borderedText = new BorderedText(textSizePx); 358 borderedText.setTypeface(Typeface.MONOSPACE); 359 360 inferenceInterface = new TensorFlowInferenceInterface(getAssets(), MODEL_FILE); 361 362 previewWidth = size.getWidth(); 363 previewHeight = size.getHeight(); 364 365 final Display display = getWindowManager().getDefaultDisplay(); 366 final int screenOrientation = display.getRotation(); 367 368 LOGGER.i("Sensor orientation: %d, Screen orientation: %d", rotation, screenOrientation); 369 370 sensorOrientation = rotation + screenOrientation; 371 372 addCallback( 373 new DrawCallback() { 374 @Override 375 public void drawCallback(final Canvas canvas) { 376 renderDebug(canvas); 377 } 378 }); 379 380 adapter = new ImageGridAdapter(); 381 grid = (GridView) findViewById(R.id.grid_layout); 382 grid.setAdapter(adapter); 383 grid.setOnTouchListener(gridTouchAdapter); 384 setStyle(adapter.items[0], 1.0f); 385 } 386 387 private void setStyle(final ImageSlider slider, final float value) { 388 slider.setValue(value); 389 390 if (NORMALIZE_SLIDERS) { 391 // Slider vals correspond directly to the input tensor vals, and normalization is visually 392 // maintained by remanipulating non-selected sliders. 393 float otherSum = 0.0f; 394 395 for (int i = 0; i < NUM_STYLES; ++i) { 396 if (adapter.items[i] != slider) { 397 otherSum += adapter.items[i].value; 398 } 399 } 400 401 if (otherSum > 0.0) { 402 float highestOtherVal = 0; 403 final float factor = otherSum > 0.0f ? (1.0f - value) / otherSum : 0.0f; 404 for (int i = 0; i < NUM_STYLES; ++i) { 405 final ImageSlider child = adapter.items[i]; 406 if (child == slider) { 407 continue; 408 } 409 final float newVal = child.value * factor; 410 child.setValue(newVal > 0.01f ? newVal : 0.0f); 411 412 if (child.value > highestOtherVal) { 413 lastOtherStyle = i; 414 highestOtherVal = child.value; 415 } 416 } 417 } else { 418 // Everything else is 0, so just pick a suitable slider to push up when the 419 // selected one goes down. 420 if (adapter.items[lastOtherStyle] == slider) { 421 lastOtherStyle = (lastOtherStyle + 1) % NUM_STYLES; 422 } 423 adapter.items[lastOtherStyle].setValue(1.0f - value); 424 } 425 } 426 427 final boolean lastAllZero = allZero; 428 float sum = 0.0f; 429 for (int i = 0; i < NUM_STYLES; ++i) { 430 sum += adapter.items[i].value; 431 } 432 allZero = sum == 0.0f; 433 434 // Now update the values used for the input tensor. If nothing is set, mix in everything 435 // equally. Otherwise everything is normalized to sum to 1.0. 436 for (int i = 0; i < NUM_STYLES; ++i) { 437 styleVals[i] = allZero ? 1.0f / NUM_STYLES : adapter.items[i].value / sum; 438 439 if (lastAllZero != allZero) { 440 adapter.items[i].postInvalidate(); 441 } 442 } 443 } 444 445 private void resetPreviewBuffers() { 446 croppedBitmap = Bitmap.createBitmap(desiredSize, desiredSize, Config.ARGB_8888); 447 448 frameToCropTransform = ImageUtils.getTransformationMatrix( 449 previewWidth, previewHeight, 450 desiredSize, desiredSize, 451 sensorOrientation, true); 452 453 cropToFrameTransform = new Matrix(); 454 frameToCropTransform.invert(cropToFrameTransform); 455 intValues = new int[desiredSize * desiredSize]; 456 floatValues = new float[desiredSize * desiredSize * 3]; 457 initializedSize = desiredSize; 458 } 459 460 @Override 461 protected void processImage() { 462 if (desiredSize != initializedSize) { 463 LOGGER.i( 464 "Initializing at size preview size %dx%d, stylize size %d", 465 previewWidth, previewHeight, desiredSize); 466 467 rgbFrameBitmap = Bitmap.createBitmap(previewWidth, previewHeight, Config.ARGB_8888); 468 croppedBitmap = Bitmap.createBitmap(desiredSize, desiredSize, Config.ARGB_8888); 469 frameToCropTransform = ImageUtils.getTransformationMatrix( 470 previewWidth, previewHeight, 471 desiredSize, desiredSize, 472 sensorOrientation, true); 473 474 cropToFrameTransform = new Matrix(); 475 frameToCropTransform.invert(cropToFrameTransform); 476 intValues = new int[desiredSize * desiredSize]; 477 floatValues = new float[desiredSize * desiredSize * 3]; 478 initializedSize = desiredSize; 479 } 480 rgbFrameBitmap.setPixels(getRgbBytes(), 0, previewWidth, 0, 0, previewWidth, previewHeight); 481 final Canvas canvas = new Canvas(croppedBitmap); 482 canvas.drawBitmap(rgbFrameBitmap, frameToCropTransform, null); 483 484 // For examining the actual TF input. 485 if (SAVE_PREVIEW_BITMAP) { 486 ImageUtils.saveBitmap(croppedBitmap); 487 } 488 489 runInBackground( 490 new Runnable() { 491 @Override 492 public void run() { 493 cropCopyBitmap = Bitmap.createBitmap(croppedBitmap); 494 final long startTime = SystemClock.uptimeMillis(); 495 stylizeImage(croppedBitmap); 496 lastProcessingTimeMs = SystemClock.uptimeMillis() - startTime; 497 textureCopyBitmap = Bitmap.createBitmap(croppedBitmap); 498 requestRender(); 499 readyForNextImage(); 500 } 501 }); 502 if (desiredSize != initializedSize) { 503 resetPreviewBuffers(); 504 } 505 } 506 507 private void stylizeImage(final Bitmap bitmap) { 508 ++frameNum; 509 bitmap.getPixels(intValues, 0, bitmap.getWidth(), 0, 0, bitmap.getWidth(), bitmap.getHeight()); 510 511 if (DEBUG_MODEL) { 512 // Create a white square that steps through a black background 1 pixel per frame. 513 final int centerX = (frameNum + bitmap.getWidth() / 2) % bitmap.getWidth(); 514 final int centerY = bitmap.getHeight() / 2; 515 final int squareSize = 10; 516 for (int i = 0; i < intValues.length; ++i) { 517 final int x = i % bitmap.getWidth(); 518 final int y = i / bitmap.getHeight(); 519 final float val = 520 Math.abs(x - centerX) < squareSize && Math.abs(y - centerY) < squareSize ? 1.0f : 0.0f; 521 floatValues[i * 3] = val; 522 floatValues[i * 3 + 1] = val; 523 floatValues[i * 3 + 2] = val; 524 } 525 } else { 526 for (int i = 0; i < intValues.length; ++i) { 527 final int val = intValues[i]; 528 floatValues[i * 3] = ((val >> 16) & 0xFF) / 255.0f; 529 floatValues[i * 3 + 1] = ((val >> 8) & 0xFF) / 255.0f; 530 floatValues[i * 3 + 2] = (val & 0xFF) / 255.0f; 531 } 532 } 533 534 // Copy the input data into TensorFlow. 535 LOGGER.i("Width: %s , Height: %s", bitmap.getWidth(), bitmap.getHeight()); 536 inferenceInterface.feed( 537 INPUT_NODE, floatValues, 1, bitmap.getWidth(), bitmap.getHeight(), 3); 538 inferenceInterface.feed(STYLE_NODE, styleVals, NUM_STYLES); 539 540 inferenceInterface.run(new String[] {OUTPUT_NODE}, isDebug()); 541 inferenceInterface.fetch(OUTPUT_NODE, floatValues); 542 543 for (int i = 0; i < intValues.length; ++i) { 544 intValues[i] = 545 0xFF000000 546 | (((int) (floatValues[i * 3] * 255)) << 16) 547 | (((int) (floatValues[i * 3 + 1] * 255)) << 8) 548 | ((int) (floatValues[i * 3 + 2] * 255)); 549 } 550 551 bitmap.setPixels(intValues, 0, bitmap.getWidth(), 0, 0, bitmap.getWidth(), bitmap.getHeight()); 552 } 553 554 private void renderDebug(final Canvas canvas) { 555 // TODO(andrewharp): move result display to its own View instead of using debug overlay. 556 final Bitmap texture = textureCopyBitmap; 557 if (texture != null) { 558 final Matrix matrix = new Matrix(); 559 final float scaleFactor = 560 DEBUG_MODEL 561 ? 4.0f 562 : Math.min( 563 (float) canvas.getWidth() / texture.getWidth(), 564 (float) canvas.getHeight() / texture.getHeight()); 565 matrix.postScale(scaleFactor, scaleFactor); 566 canvas.drawBitmap(texture, matrix, new Paint()); 567 } 568 569 if (!isDebug()) { 570 return; 571 } 572 573 final Bitmap copy = cropCopyBitmap; 574 if (copy == null) { 575 return; 576 } 577 578 canvas.drawColor(0x55000000); 579 580 final Matrix matrix = new Matrix(); 581 final float scaleFactor = 2; 582 matrix.postScale(scaleFactor, scaleFactor); 583 matrix.postTranslate( 584 canvas.getWidth() - copy.getWidth() * scaleFactor, 585 canvas.getHeight() - copy.getHeight() * scaleFactor); 586 canvas.drawBitmap(copy, matrix, new Paint()); 587 588 final Vector<String> lines = new Vector<>(); 589 590 final String[] statLines = inferenceInterface.getStatString().split("\n"); 591 Collections.addAll(lines, statLines); 592 593 lines.add(""); 594 595 lines.add("Frame: " + previewWidth + "x" + previewHeight); 596 lines.add("Crop: " + copy.getWidth() + "x" + copy.getHeight()); 597 lines.add("View: " + canvas.getWidth() + "x" + canvas.getHeight()); 598 lines.add("Rotation: " + sensorOrientation); 599 lines.add("Inference time: " + lastProcessingTimeMs + "ms"); 600 lines.add("Desired size: " + desiredSize); 601 lines.add("Initialized size: " + initializedSize); 602 603 borderedText.drawLines(canvas, 10, canvas.getHeight() - 10, lines); 604 } 605} 606