Mega Code Archive

 
Categories / Java / Development Class
 

A random choice maker

//package gr.forth.ics.util; import java.util.ArrayDeque; import java.util.ArrayList; import java.util.Deque; import java.util.List; import java.util.Random; /**  * A random choice maker, where each choice is associated with a probability. This implementation  * is based on the fast <em>alias</em> method, where for each random choice two random  * numbers are generated and only a single table lookup performed.  *  * @param <T> the type of the choices to be made  * @see <a href="http://cg.scs.carleton.ca/~luc/rnbookindex.html">L. Devroye, Non-Uniform Random Variate Generation, 1986, p. 107</a>  * @author  Andreou Dimitris, email: jim.andreou (at) gmail (dot) com   */ public class RandomChooser<T> {     private final double[] probs;     private final int[] indexes;     private final List<T> events;          private final Random random;          private RandomChooser(List<Double> weights, List<T> events, Random random) {         double sum = 0.0;         for (double prob : weights) sum += prob;         this.probs = new double[weights.size()];         for (int i = 0; i < weights.size(); i++) {             probs[i] = weights.get(i) * weights.size() / sum; //average = 1.0         }         Deque<Integer> smaller = new ArrayDeque<Integer>(weights.size() / 2 + 2);         Deque<Integer> greater = new ArrayDeque<Integer>(weights.size() / 2 + 2);         for (int i = 0; i < probs.length; i++) {             if (probs[i] < 1.0) {                 smaller.push(i);             } else {                 greater.push(i);             }         }         indexes = new int[weights.size()];         while (!smaller.isEmpty()) {             Integer i = smaller.pop();             Integer k = greater.peek();             indexes[i] =  k;             probs[k] -= (1 - probs[i]);             if (probs[k] < 1.0) {                 greater.pop();                 if (greater.isEmpty()) break;                 smaller.push(k);             }         }         this.events = events;         this.random = random;     }     /**      * Returns a random choice.      *      * @return a random choice      * @see RandomChooserBuilder about how to configure the available choices      */     public T choose() {         int index = random.nextInt(probs.length);         double x = random.nextDouble();         return x < probs[index] ? events.get(index) : events.get(indexes[index]);     }      /**      * Creates a builder of a {@link RandomChooser} instance. The builder is responsible      * for configuring the choices and probabilities of the random chooser.      *      * @param <T> the type of the choices that will be randomly made      * @return a builder of a {@code RandomChooser} object      */     public static <T> RandomChooserBuilder<T> newInstance() {         return new RandomChooserBuilder<T>();     }     /**      * A builder of {@link RandomChooser}.      *       * @param <T> the type of the choices that the created {@code RandomChooser} will make      */     public static class RandomChooserBuilder<T> {         private final List<Double> probs = new ArrayList<Double>();         private final List<T> events = new ArrayList<T>();         private Random random = new Random(0);         private RandomChooserBuilder() { }         /**          * Adds the possibility of a given choice, weighted by a relative probability.          * (Relative means that it is not needed that all probabilities have sum {@code 1.0}).          *          * @param choice a possible choice          * @param prob the relative probability of the choice; must be {@code >= 0}          * @return this          */         public RandomChooserBuilder<T> choice(T choice, double prob) {             Args.gte(prob, 0.0);             Args.notNull(choice);             probs.add(prob);             events.add(choice);             return this;         }         /**          * Specifies the random number generator to be used by the created {@link RandomChooser}.          *          * @param random the random number generator to use          * @return this          */         public RandomChooserBuilder<T> setRandom(Random random) {             this.random = random;             return this;         }         /**          * Builds a {@link RandomChooser} instance, ready to make random choices based on the          * probabilities configured by this builder.          *          * @return a {@code RandomChooser}          */         public RandomChooser<T> build() {             if (probs.isEmpty()) {                 throw new IllegalStateException("No choice was defined");             }             return new RandomChooser<T>(                     new ArrayList<Double>(probs),                     new ArrayList<T>(events),                     random);         }     } } class Args {     private static final String GT = " must be greater than ";     private static final String GTE = " must be greater or equal to ";     private static final String LT = " must be less than ";     private static final String LTE = " must be less or equal to ";     private static final String EQUALS = " must be equal to ";          public static void doesNotContainNull(Iterable<?> iterable) {         notNull(iterable);         for (Object o : iterable) {             notNull("Iterable contains null", o);         }     }          public static void isTrue(boolean condition) {         isTrue("Condition failed", condition);     }          public static void isTrue(String msg, boolean condition) {         if (!condition) {             throw new RuntimeException(msg);         }     }          public static void notNull(Object o) {         notNull(null, o);     }          public static void notNull(String arg, Object o) {         if (arg == null) {             arg = "Argument";         }         if (o == null) {             throw new IllegalArgumentException(arg + " is null");         }     }          public static void notNull(Object... args) {         notNull(null, args);     }          public static void notNull(String message, Object... args) {         if (message == null) {             message = "Some argument";         }         for (Object o : args) {             notNull(message, o);         }     }          public static void notEmpty(Iterable<?> iter) {         notEmpty(null, iter);     }          public static void notEmpty(String arg, Iterable<?> iter) {         if (arg == null) {             arg = "Iterable";         }         notNull(iter);         if (iter.iterator().hasNext()) return;         throw new IllegalArgumentException(arg + " is empty");     }          public static void hasNoNull(Iterable<?> iter) {         hasNoNull(null, iter);     }          public static void hasNoNull(String arg, Iterable<?> iter) {         notNull(iter);         if (arg == null) {             arg = "Iterable";         }         for (Object o : iter) {             if (o == null) {                 throw new IllegalArgumentException(arg + " contains null");             }         }     }          public static void equals(int value, int expected) {         if (value == expected) return;         throw new IllegalArgumentException(value + EQUALS + expected);     }          public static void equals(long value, long expected) {         if (value == expected) return;         throw new IllegalArgumentException(value + EQUALS + expected);     }          public static void equals(double value, double expected) {         if (value == expected) return;         throw new IllegalArgumentException(value + EQUALS + expected);     }          public static void equals(float value, float expected) {         if (value == expected) return;         throw new IllegalArgumentException(value + EQUALS + expected);     }          public static void equals(char value, char expected) {         if (value == expected) return;         throw new IllegalArgumentException(value + EQUALS + expected);     }          public static void equals(short value, short expected) {         if (value == expected) return;         throw new IllegalArgumentException(value + EQUALS + expected);     }          public static void equals(byte value, byte expected) {         if (value == expected) return;         throw new IllegalArgumentException(value + EQUALS + expected);     }     public static void equals(Object value, Object expected) {         if (value == expected || value.equals(expected)) return;         throw new IllegalArgumentException(value + EQUALS + expected);     }          public static void gt(int value, int from) {         if (value > from) return;         throw new IllegalArgumentException(value + GT + from);     }          public static void lt(int value, int from) {         if (value < from) return;         throw new IllegalArgumentException(value + LT + from);     }          public static void gte(int value, int from) {         if (value >= from) return;         throw new IllegalArgumentException(value + GTE + from);     }          public static void lte(int value, int from) {         if (value <= from) return;         throw new IllegalArgumentException(value + LTE + from);     }          public static void gt(long value, long from) {         if (value > from) return;         throw new IllegalArgumentException(value + GT + from);     }          public static void lt(long value, long from) {         if (value < from) return;         throw new IllegalArgumentException(value + LT + from);     }          public static void gte(long value, long from) {         if (value >= from) return;         throw new IllegalArgumentException(value + GTE + from);     }          public static void lte(long value, long from) {         if (value <= from) return;         throw new IllegalArgumentException(value + LTE + from);     }          public static void gt(short value, short from) {         if (value > from) return;         throw new IllegalArgumentException(value + GT + from);     }          public static void lt(short value, short from) {         if (value < from) return;         throw new IllegalArgumentException(value + LT + from);     }          public static void gte(short value, short from) {         if (value >= from) return;         throw new IllegalArgumentException(value + GTE + from);     }          public static void lte(short value, short from) {         if (value <= from) return;         throw new IllegalArgumentException(value + LTE + from);     }          public static void gt(byte value, byte from) {         if (value > from) return;         throw new IllegalArgumentException(value + GT + from);     }          public static void lt(byte value, byte from) {         if (value < from) return;         throw new IllegalArgumentException(value + LT + from);     }          public static void gte(byte value, byte from) {         if (value >= from) return;         throw new IllegalArgumentException(value + GTE + from);     }          public static void lte(byte value, byte from) {         if (value <= from) return;         throw new IllegalArgumentException(value + LTE + from);     }          public static void gt(char value, char from) {         if (value > from) return;         throw new IllegalArgumentException(value + GT + from);     }          public static void lt(char value, char from) {         if (value < from) return;         throw new IllegalArgumentException(value + LT + from);     }          public static void gte(char value, char from) {         if (value >= from) return;         throw new IllegalArgumentException(value + GTE + from);     }          public static void lte(char value, char from) {         if (value <= from) return;         throw new IllegalArgumentException(value + LTE + from);     }          public static void gt(double value, double from) {         if (value > from) return;         throw new IllegalArgumentException(value + GT + from);     }          public static void lt(double value, double from) {         if (value < from) return;         throw new IllegalArgumentException(value + LT + from);     }          public static void gte(double value, double from) {         if (value >= from) return;         throw new IllegalArgumentException(value + GTE + from);     }          public static void lte(double value, double from) {         if (value <= from) return;         throw new IllegalArgumentException(value + LTE + from);     }          public static void gt(float value, float from) {         if (value > from) return;         throw new IllegalArgumentException(value + GT + from);     }          public static void lt(float value, float from) {         if (value < from) return;         throw new IllegalArgumentException(value + LT + from);     }          public static void gte(float value, float from) {         if (value >= from) return;         throw new IllegalArgumentException(value + GTE + from);     }          public static void lte(float value, float from) {         if (value <= from) return;         throw new IllegalArgumentException(value + LTE + from);     }          public static <T> void gt(Comparable<T> c1, T c2) {         if (c1.compareTo(c2) > 0) return;         throw new IllegalArgumentException(c1 + GT + c2);     }          public static <T> void lt(Comparable<T> c1, T c2) {         if (c1.compareTo(c2) < 0) return;         throw new IllegalArgumentException(c1 + LT + c2);     }          public static <T> void gte(Comparable<T> c1, T c2) {         if (c1.compareTo(c2) >= 0) return;         throw new IllegalArgumentException(c1 + GTE + c2);     }          public static <T> void lte(Comparable<T> c1, T c2) {         if (c1.compareTo(c2) <= 0) return;         throw new IllegalArgumentException(c1 + LTE + c2);     }          public static <T> void inRangeII(Comparable<T> value, T from, T to) {         gte(value, from);         lte(value, to);     }          public static <T> void inRangeEE(Comparable<T> value, T from, T to) {         gt(value, from);         lt(value, to);     }          public static <T> void inRangeIE(Comparable<T> value, T from, T to) {         gt(value, from);         lt(value, to);     }          public static <T> void inRangeEI(Comparable<T> value, T from, T to) {         gt(value, from);         lte(value, to);     }          public static void inRangeII(int value, int from, int to) {         gte(value, from);         lte(value, to);     }          public static void inRangeEE(int value, int from, int to) {         gt(value, from);         lt(value, to);     }          public static void inRangeIE(int value, int from, int to) {         gte(value, from);         lt(value, to);     }          public static void inRangeEI(int value, int from, int to) {         gt(value, from);         lte(value, to);     }          public static void inRangeII(long value, long from, long to) {         gte(value, from);         lte(value, to);     }          public static void inRangeEE(long value, long from, long to) {         gt(value, from);         lt(value, to);     }          public static void inRangeIE(long value, long from, long to) {         gte(value, from);         lt(value, to);     }          public static void inRangeEI(long value, long from, long to) {         gt(value, from);         lte(value, to);     }          public static void inRangeII(short value, short from, short to) {         gte(value, from);         lte(value, to);     }          public static void inRangeEE(short value, short from, short to) {         gt(value, from);         lt(value, to);     }          public static void inRangeIE(short value, short from, short to) {         gte(value, from);         lt(value, to);     }          public static void inRangeEI(short value, short from, short to) {         gt(value, from);         lte(value, to);     }          public static void inRangeII(byte value, byte from, byte to) {         gte(value, from);         lte(value, to);     }          public static void inRangeEE(byte value, byte from, byte to) {         gt(value, from);         lt(value, to);     }          public static void inRangeIE(byte value, byte from, byte to) {         gte(value, from);         lt(value, to);     }          public static void inRangeEI(byte value, byte from, byte to) {         gt(value, from);         lte(value, to);     }          public static void check(boolean assertion, String messageIfFailed) {         if (!assertion) {             throw new RuntimeException(messageIfFailed);         }     } }