1/* 2 * Copyright (C) 2018 The Android Open Source Project 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17package androidx.textclassifier; 18 19import android.os.Parcel; 20import android.os.Parcelable; 21 22import androidx.annotation.FloatRange; 23import androidx.annotation.NonNull; 24import androidx.annotation.RestrictTo; 25import androidx.collection.ArrayMap; 26import androidx.collection.SimpleArrayMap; 27import androidx.core.util.Preconditions; 28 29import java.util.ArrayList; 30import java.util.Collections; 31import java.util.Comparator; 32import java.util.List; 33import java.util.Map; 34 35/** 36 * Helper object for setting and getting entity scores for classified text. 37 * 38 * @hide 39 */ 40@RestrictTo(RestrictTo.Scope.LIBRARY) 41final class EntityConfidence implements Parcelable { 42 43 private final ArrayMap<String, Float> mEntityConfidence = new ArrayMap<>(); 44 private final ArrayList<String> mSortedEntities = new ArrayList<>(); 45 46 EntityConfidence() {} 47 48 EntityConfidence(@NonNull EntityConfidence source) { 49 Preconditions.checkNotNull(source); 50 mEntityConfidence.putAll((SimpleArrayMap<String, Float>) source.mEntityConfidence); 51 mSortedEntities.addAll(source.mSortedEntities); 52 } 53 54 /** 55 * Constructs an EntityConfidence from a map of entity to confidence. 56 * 57 * Map entries that have 0 confidence are removed, and values greater than 1 are clamped to 1. 58 * 59 * @param source a map from entity to a confidence value in the range 0 (low confidence) to 60 * 1 (high confidence). 61 */ 62 EntityConfidence(@NonNull Map<String, Float> source) { 63 Preconditions.checkNotNull(source); 64 65 // Prune non-existent entities and clamp to 1. 66 mEntityConfidence.ensureCapacity(source.size()); 67 for (Map.Entry<String, Float> it : source.entrySet()) { 68 if (it.getValue() <= 0) continue; 69 mEntityConfidence.put(it.getKey(), Math.min(1, it.getValue())); 70 } 71 resetSortedEntitiesFromMap(); 72 } 73 74 /** 75 * Returns an immutable list of entities found in the classified text ordered from 76 * high confidence to low confidence. 77 */ 78 @NonNull 79 public List<String> getEntities() { 80 return Collections.unmodifiableList(mSortedEntities); 81 } 82 83 /** 84 * Returns the confidence score for the specified entity. The value ranges from 85 * 0 (low confidence) to 1 (high confidence). 0 indicates that the entity was not found for the 86 * classified text. 87 */ 88 @FloatRange(from = 0.0, to = 1.0) 89 public float getConfidenceScore(String entity) { 90 if (mEntityConfidence.containsKey(entity)) { 91 return mEntityConfidence.get(entity); 92 } 93 return 0; 94 } 95 96 @Override 97 public String toString() { 98 return mEntityConfidence.toString(); 99 } 100 101 @Override 102 public int describeContents() { 103 return 0; 104 } 105 106 @Override 107 public void writeToParcel(Parcel dest, int flags) { 108 dest.writeInt(mEntityConfidence.size()); 109 for (Map.Entry<String, Float> entry : mEntityConfidence.entrySet()) { 110 dest.writeString(entry.getKey()); 111 dest.writeFloat(entry.getValue()); 112 } 113 } 114 115 public static final Parcelable.Creator<EntityConfidence> CREATOR = 116 new Parcelable.Creator<EntityConfidence>() { 117 @Override 118 public EntityConfidence createFromParcel(Parcel in) { 119 return new EntityConfidence(in); 120 } 121 122 @Override 123 public EntityConfidence[] newArray(int size) { 124 return new EntityConfidence[size]; 125 } 126 }; 127 128 private EntityConfidence(Parcel in) { 129 final int numEntities = in.readInt(); 130 mEntityConfidence.ensureCapacity(numEntities); 131 for (int i = 0; i < numEntities; ++i) { 132 mEntityConfidence.put(in.readString(), in.readFloat()); 133 } 134 resetSortedEntitiesFromMap(); 135 } 136 137 private void resetSortedEntitiesFromMap() { 138 mSortedEntities.clear(); 139 mSortedEntities.ensureCapacity(mEntityConfidence.size()); 140 mSortedEntities.addAll(mEntityConfidence.keySet()); 141 Collections.sort(mSortedEntities, new EntityConfidenceComparator()); 142 } 143 144 /** Helper to sort entities according to their confidence. */ 145 private class EntityConfidenceComparator implements Comparator<String> { 146 @Override 147 public int compare(String e1, String e2) { 148 float score1 = mEntityConfidence.get(e1); 149 float score2 = mEntityConfidence.get(e2); 150 return Float.compare(score2, score1); 151 } 152 } 153} 154