EntityConfidence.java revision f27b1ffc67228d73326ec3426fef4c9db75cd6fd
/*
 * Copyright (C) 2018 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package androidx.textclassifier;

import android.os.Parcel;
import android.os.Parcelable;
import android.support.annotation.FloatRange;
import android.support.annotation.NonNull;
import android.support.annotation.RestrictTo;
import android.support.v4.util.ArrayMap;
import android.support.v4.util.Preconditions;
import android.support.v4.util.SimpleArrayMap;

import java.util.ArrayList;
import java.util.Collections;
import java.util.Comparator;
import java.util.List;
import java.util.Map;

/**
 * Helper object for setting and getting entity scores for classified text.
 *
 * <p>Holds a map from entity to confidence score together with a list of the same entities
 * sorted from highest to lowest score. The list is rebuilt from the map whenever an instance is
 * constructed from a map or from a {@link Parcel}.
 *
 * @hide
 */
@RestrictTo(RestrictTo.Scope.LIBRARY)
final class EntityConfidence implements Parcelable {

    /** Entity -> confidence score. Only positive, non-NaN scores of at most 1 are stored. */
    private final ArrayMap<String, Float> mEntityConfidence = new ArrayMap<>();

    /** Keys of {@link #mEntityConfidence}, ordered from highest to lowest score. */
    private final ArrayList<String> mSortedEntities = new ArrayList<>();

    /** Constructs an empty EntityConfidence. */
    EntityConfidence() {}

    /**
     * Constructs a copy of another EntityConfidence.
     *
     * @param source the instance whose scores and entity ordering are copied; must not be null.
     */
    EntityConfidence(@NonNull EntityConfidence source) {
        Preconditions.checkNotNull(source);
        // The cast selects the SimpleArrayMap overload of putAll; ArrayMap is also a Map, so
        // the call would otherwise be ambiguous.
        mEntityConfidence.putAll((SimpleArrayMap<String, Float>) source.mEntityConfidence);
        mSortedEntities.addAll(source.mSortedEntities);
    }

    /**
     * Constructs an EntityConfidence from a map of entity to confidence.
     *
     * <p>Map entries whose confidence is {@code null}, NaN, or not greater than 0 are dropped,
     * and values greater than 1 are clamped to 1.
     *
     * @param source a map from entity to a confidence value in the range 0 (low confidence) to
     *               1 (high confidence); must not be null.
     */
    EntityConfidence(@NonNull Map<String, Float> source) {
        Preconditions.checkNotNull(source);

        // Prune non-existent entities and clamp to 1.
        mEntityConfidence.ensureCapacity(source.size());
        for (Map.Entry<String, Float> it : source.entrySet()) {
            final Float confidence = it.getValue();
            // NaN fails every ordered comparison, so "<= 0" alone would let it through; it
            // would then be stored and, because Float.compare ranks NaN above all other
            // values, sorted as the most confident entity. A null value is treated the same
            // way as a missing score instead of failing on unboxing.
            if (confidence == null || Float.isNaN(confidence) || confidence <= 0) {
                continue;
            }
            mEntityConfidence.put(it.getKey(), Math.min(1f, confidence));
        }
        resetSortedEntitiesFromMap();
    }

    /**
     * Returns an immutable list of entities found in the classified text ordered from
     * high confidence to low confidence.
     */
    @NonNull
    public List<String> getEntities() {
        return Collections.unmodifiableList(mSortedEntities);
    }

    /**
     * Returns the confidence score for the specified entity. The value ranges from
     * 0 (low confidence) to 1 (high confidence). 0 indicates that the entity was not found for the
     * classified text.
     *
     * @param entity the entity to look up.
     */
    @FloatRange(from = 0.0, to = 1.0)
    public float getConfidenceScore(String entity) {
        // Null values are never stored, so a null result means the entity is absent.
        final Float score = mEntityConfidence.get(entity);
        return score != null ? score : 0f;
    }

    @Override
    public String toString() {
        return mEntityConfidence.toString();
    }

    @Override
    public int describeContents() {
        return 0;
    }

    /**
     * Writes the number of entries followed by each entity and its score.
     *
     * <p>Only the map is written; the sorted entity list is rebuilt when reading.
     */
    @Override
    public void writeToParcel(Parcel dest, int flags) {
        dest.writeInt(mEntityConfidence.size());
        for (Map.Entry<String, Float> entry : mEntityConfidence.entrySet()) {
            dest.writeString(entry.getKey());
            dest.writeFloat(entry.getValue());
        }
    }

    public static final Parcelable.Creator<EntityConfidence> CREATOR =
            new Parcelable.Creator<EntityConfidence>() {
                @Override
                public EntityConfidence createFromParcel(Parcel in) {
                    return new EntityConfidence(in);
                }

                @Override
                public EntityConfidence[] newArray(int size) {
                    return new EntityConfidence[size];
                }
            };

    /**
     * Reads an entry count and then that many entity/score pairs, in the layout produced by
     * {@link #writeToParcel(Parcel, int)}, and rebuilds the sorted entity list.
     */
    private EntityConfidence(Parcel in) {
        final int numEntities = in.readInt();
        mEntityConfidence.ensureCapacity(numEntities);
        for (int i = 0; i < numEntities; ++i) {
            mEntityConfidence.put(in.readString(), in.readFloat());
        }
        resetSortedEntitiesFromMap();
    }

    /** Rebuilds the sorted entity list from the current keys of the score map. */
    private void resetSortedEntitiesFromMap() {
        mSortedEntities.clear();
        mSortedEntities.ensureCapacity(mEntityConfidence.size());
        mSortedEntities.addAll(mEntityConfidence.keySet());
        Collections.sort(mSortedEntities, new EntityConfidenceComparator());
    }

    /**
     * Helper to sort entities according to their confidence, highest first.
     *
     * <p>Both entities being compared must be keys of the enclosing instance's score map.
     */
    private class EntityConfidenceComparator implements Comparator<String> {
        @Override
        public int compare(String e1, String e2) {
            float score1 = mEntityConfidence.get(e1);
            float score2 = mEntityConfidence.get(e2);
            // Arguments are swapped to obtain descending order.
            return Float.compare(score2, score1);
        }
    }
}