package de.fu_berlin.ties.classify.winnow;

import java.util.Arrays;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.TreeSet;

import org.apache.commons.lang.ArrayUtils;
import org.apache.commons.lang.ObjectUtils;
import org.apache.commons.lang.builder.ToStringBuilder;
import org.dom4j.Element;
import org.dom4j.QName;

import de.fu_berlin.ties.ContextMap;
import de.fu_berlin.ties.ProcessingException;
import de.fu_berlin.ties.TiesConfiguration;
import de.fu_berlin.ties.classify.PredictionDistribution;
import de.fu_berlin.ties.classify.TrainableClassifier;
import de.fu_berlin.ties.classify.feature.Feature;
import de.fu_berlin.ties.classify.feature.FeatureSet;
import de.fu_berlin.ties.classify.feature.FeatureTransformer;
import de.fu_berlin.ties.classify.feature.FeatureVector;
import de.fu_berlin.ties.io.ObjectElement;
import de.fu_berlin.ties.util.Util;
import de.fu_berlin.ties.xml.dom.DOMUtils;

/***
 * Classifier implementing the Winnow algorithm (Nick Littlestone). <b>Winnow
 * supports <em>only</em> error-driven training, so you always have to use the
 * {@link #trainOnError(FeatureVector, String, Set)} method. Trying to
 * call the {@link
 * de.fu_berlin.ties.classify.TrainableClassifier#train(FeatureVector, String)}
 * method instead will result in an
 * {@link java.lang.UnsupportedOperationException}.</b>
 *
 * <p>Instances of this class are thread-safe.
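 *
 * <p>A minimal usage sketch (for illustration only: it assumes a populated
 * {@linkplain TiesConfiguration#CONF standard configuration} and an already
 * prepared {@link FeatureVector} <code>vector</code>; the class labels are
 * made up and exception handling is omitted):
 *
 * <pre>
 * Set&lt;String&gt; classes = new TreeSet&lt;String&gt;(
 *         Arrays.asList("ham", "spam"));
 * Winnow winnow = new Winnow(classes, TiesConfiguration.CONF);
 * // Winnow learns error-driven: use trainOnError, never train
 * winnow.trainOnError(vector, "spam", classes);
 * PredictionDistribution dist = winnow.classify(vector, classes);
 * </pre>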
 *
 * @author Christian Siefkes
 * @version $Revision: 1.80 $, $Date: 2006/10/21 16:03:59 $, $Author: siefkes $
 */
public class Winnow extends TrainableClassifier {

    /***
     * Attribute name used for XML serialization.
     */
    private static final QName ATTRIB_PROMOTION =
            DOMUtils.defaultName("promotion");

    /***
     * Attribute name used for XML serialization.
     */
    private static final QName ATTRIB_DEMOTION =
            DOMUtils.defaultName("demotion");

    /***
     * Attribute name used for XML serialization.
     */
    private static final QName ATTRIB_THRESHOLD_THICKNESS =
            DOMUtils.defaultName("threshold-thickness");

    /***
     * Attribute name used for XML serialization.
     */
    private static final QName ATTRIB_BALANCED =
            DOMUtils.defaultName("balanced");

    /***
     * Attribute name used for XML serialization.
     */
    private static final QName ATTRIB_IGNORE_EXPONENT =
            DOMUtils.defaultName("ignore-exponent");

    /***
     * Configuration key: How feature frequencies are considered when
     * calculating strength values.
     */
    public static final String CONFIG_STRENGTH_FREQUENCY =
            "classifier.winnow.strength.frequency";

    /***
     * Delta used when ignoring irrelevant features.
     */
    private static final double IGNORE_DELTA = 0.001;

    /***
     * Helper method used to initialize the lower limit used when ignoring
     * irrelevant features.
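     *
     * <p>For example (factors chosen purely for illustration): with a
     * demotion factor of 0.8 and an ignore exponent of 1, the lower limit
     * is 0.8 / 1.001, roughly 0.7992.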
     *
     * @param demotion the demotion factor used by the algorithm
     * @param ignoreExponent exponent used to calculate which features to
     * consider irrelevant
     * @return the lower limit:
     * <code>demotion</code>^<code>ignoreExponent</code> /
     * (1.0 + {@link #IGNORE_DELTA})
     */
    private static float initLowerIgnoreLimit(final float demotion,
            final int ignoreExponent) {
        return (float) (Math.pow(demotion, ignoreExponent)
                / (1.0 + IGNORE_DELTA));
    }

    /***
     * Helper method used to initialize the upper limit used when ignoring
     * irrelevant features.
     *
     * @param promotion the promotion factor used by the algorithm
     * @param ignoreExponent exponent used to calculate which features to
     * consider irrelevant
     * @return the upper limit:
     * <code>promotion</code>^<code>ignoreExponent</code> *
     * (1.0 + {@link #IGNORE_DELTA})
     */
    private static float initUpperIgnoreLimit(final float promotion,
            final int ignoreExponent) {
        return (float) (Math.pow(promotion, ignoreExponent)
                * (1.0 + IGNORE_DELTA));
    }

    /***
     * Whether the Balanced Winnow or the standard Winnow algorithm is used.
     * Balanced Winnow keeps <em>two</em> weights per feature and class,
     * a positive and a negative one.
     */
    private final boolean balanced;

    /***
     * The promotion factor used by the algorithm.
     */
    private final float promotion;

    /***
     * The demotion factor used by the algorithm.
     */
    private final float demotion;

    /***
     * The thickness of the threshold if the "thick threshold" heuristic is
     * used (must be < 1.0), 0.0 otherwise.
     */
    private final float thresholdThickness;

    /***
     * Exponent used to calculate which features to consider irrelevant for
     * classification.
     */
    private final int ignoreExponent;

    /***
     * The lower limit used when ignoring irrelevant features.
     */
    private final float lowerIgnoreLimit;

    /***
     * The upper limit used when ignoring irrelevant features.
     */
    private final float upperIgnoreLimit;

    /***
     * Stores the feature weights, using a variation of the LRU mechanism for
     * pruning surplus features. Every access should be synchronized on
     * <strong>this</strong>.
     */
    private final WinnowStore store;

    /***
     * Creates a new instance from an XML element, fulfilling the
     * recommendation of the {@link de.fu_berlin.ties.io.XMLStorable}
     * interface.
     *
     * @param element the XML element containing the serialized representation
     * @throws InstantiationException if the given element does not contain
     * a valid classifier description
     */
    public Winnow(final Element element) throws InstantiationException {
        super(element);

        promotion = Util.asFloat(element.attributeValue(ATTRIB_PROMOTION));
        demotion = Util.asFloat(element.attributeValue(ATTRIB_DEMOTION));
        thresholdThickness =
                Util.asFloat(element.attributeValue(ATTRIB_THRESHOLD_THICKNESS));

        ignoreExponent = Util.asInt(ObjectUtils.defaultIfNull(
                element.attributeValue(ATTRIB_IGNORE_EXPONENT), 1));
        balanced = Util.asBoolean(element.attributeValue(ATTRIB_BALANCED));

        checkArguments(promotion, demotion, thresholdThickness);
        lowerIgnoreLimit = initLowerIgnoreLimit(demotion, ignoreExponent);
        upperIgnoreLimit = initUpperIgnoreLimit(promotion, ignoreExponent);

        store = (WinnowStore) ObjectElement.createObject(
                element.element(WinnowStore.ELEMENT_MAIN));

        if (store.isIgnoringIrrelevant()) {
            Util.LOG.debug(
                    "Classification will ignore features with all weights in ["
                    + lowerIgnoreLimit + ".." + upperIgnoreLimit + "] range");
        }
    }

    /***
     * Creates a new instance based on the
     * {@linkplain TiesConfiguration#CONF standard configuration}.
     *
     * @param allValidClasses the set of all valid classes
     * @throws IllegalArgumentException if one of the parameters is outside
     * the allowed range
     * @throws ProcessingException if an error occurred while creating
     * the feature transformer(s)
     */
    public Winnow(final Set<String> allValidClasses)
            throws IllegalArgumentException, ProcessingException {
        this(allValidClasses, (String) null);
    }

    /***
     * Creates a new instance based on the
     * {@linkplain TiesConfiguration#CONF standard configuration}.
     *
     * @param allValidClasses the set of all valid classes
     * @param configSuffix optional suffix appended to the configuration keys
     * when configuring this instance; might be <code>null</code>
     * @throws IllegalArgumentException if one of the parameters is outside
     * the allowed range
     * @throws ProcessingException if an error occurred while creating
     * the feature transformer(s)
     */
    protected Winnow(final Set<String> allValidClasses,
            final String configSuffix)
            throws IllegalArgumentException, ProcessingException {
        this(allValidClasses, TiesConfiguration.CONF, configSuffix);
    }

    /***
     * Creates a new instance based on the provided configuration.
     *
     * @param allValidClasses the set of all valid classes
     * @param config contains configuration properties
     * @throws IllegalArgumentException if one of the parameters is outside
     * the allowed range
     * @throws ProcessingException if an error occurred while creating
     * the feature transformer(s)
     */
    public Winnow(final Set<String> allValidClasses,
            final TiesConfiguration config)
            throws IllegalArgumentException, ProcessingException {
        this(allValidClasses, config, null);
    }

    /***
     * Creates a new instance based on the provided configuration.
     *
     * @param allValidClasses the set of all valid classes
     * @param config contains configuration properties
     * @param configSuffix optional suffix appended to the configuration keys
     * when configuring this instance; might be <code>null</code>
     * @throws IllegalArgumentException if one of the parameters is outside
     * the allowed range
     * @throws ProcessingException if an error occurred while creating
     * the feature transformer(s)
     */
    protected Winnow(final Set<String> allValidClasses,
            final TiesConfiguration config, final String configSuffix)
            throws IllegalArgumentException, ProcessingException {
        this(allValidClasses, FeatureTransformer.createTransformer(config),
                config, configSuffix);
    }

    /***
     * Creates a new instance based on the provided configuration.
     *
     * @param allValidClasses the set of all valid classes
     * @param trans the last transformer in the transformer chain to use, or
     * <code>null</code> if no feature transformers should be used
     * @param config contains configuration properties
     * @throws IllegalArgumentException if one of the parameters is outside
     * the allowed range
     * @throws ProcessingException if an error occurred while creating
     * the feature transformer(s)
     */
    public Winnow(final Set<String> allValidClasses,
            final FeatureTransformer trans, final TiesConfiguration config)
            throws IllegalArgumentException, ProcessingException {
        this(allValidClasses, trans, config, null);
    }

    /***
     * Creates a new instance based on the provided configuration.
     *
     * @param allValidClasses the set of all valid classes
     * @param trans the last transformer in the transformer chain to use, or
     * <code>null</code> if no feature transformers should be used
     * @param config contains configuration properties
     * @param configSuffix optional suffix appended to the configuration keys
     * when configuring this instance; might be <code>null</code>
     * @throws IllegalArgumentException if one of the parameters is outside
     * the allowed range
     * @throws ProcessingException if an error occurred while creating
     * the feature transformer(s)
     */
    protected Winnow(final Set<String> allValidClasses,
            final FeatureTransformer trans, final TiesConfiguration config,
            final String configSuffix)
            throws IllegalArgumentException, ProcessingException {
        this(allValidClasses, trans,
                config.getBoolean(config.adaptKey(
                        "classifier.winnow.balanced", configSuffix)),
                config.getFloat(config.adaptKey(
                        "classifier.winnow.promotion", configSuffix)),
                config.getFloat(config.adaptKey(
                        "classifier.winnow.demotion", configSuffix)),
                config.getFloat(config.adaptKey(
                        "classifier.winnow.threshold.thickness", configSuffix)),
                config.getInt(config.adaptKey(
                        "classifier.winnow.ignore.exponent", configSuffix)),
                config, configSuffix);
    }

    /***
     * Creates a new instance.
     *
     * @param allValidClasses the set of all valid classes
     * @param trans the last transformer in the transformer chain to use, or
     * <code>null</code> if no feature transformers should be used
     * @param balance whether to use the Balanced Winnow or the standard
     * Winnow algorithm
     * @param promotionFactor the promotion factor used by the algorithm;
     * must be > 1.0
     * @param demotionFactor the demotion factor used by the algorithm; must
     * be < 1.0
     * @param thresholdThick the thickness of the threshold if the "thick
     * threshold" heuristic is used (must be < 1.0), 0.0 otherwise
     * @param ignoreExp exponent used to calculate which features to consider
     * irrelevant for classification (if any)
     * @param config contains configuration properties
     * @param configSuffix optional suffix appended to the configuration keys
     * when configuring this instance; might be <code>null</code>
     * @throws IllegalArgumentException if one of the parameters is outside
     * the allowed range
     */
    public Winnow(final Set<String> allValidClasses,
            final FeatureTransformer trans, final boolean balance,
            final float promotionFactor, final float demotionFactor,
            final float thresholdThick, final int ignoreExp,
            final TiesConfiguration config, final String configSuffix)
            throws IllegalArgumentException {
        super(new TreeSet<String>(allValidClasses), trans, config);

        checkArguments(promotionFactor, demotionFactor, thresholdThick);
        balanced = balance;
        promotion = promotionFactor;
        demotion = demotionFactor;
        thresholdThickness = thresholdThick;
        ignoreExponent = ignoreExp;
        lowerIgnoreLimit = initLowerIgnoreLimit(demotion, ignoreExponent);
        upperIgnoreLimit = initUpperIgnoreLimit(promotion, ignoreExponent);
        store = WinnowStore.create(initWeight(), config, configSuffix);

        if (store.isIgnoringIrrelevant()) {
            Util.LOG.debug(
                    "Classification will ignore features with all weights in ["
                    + lowerIgnoreLimit + ".." + upperIgnoreLimit + "] range");
        }
    }

    /***
     * Adjusts the weights of a feature for all classes. This method should be
     * called in a synchronized context.
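     *
     * <p>For example (factors chosen for illustration only): with a promotion
     * factor of 1.25, a demotion factor of 0.8, and directions {+1, 0, -1}
     * for three classes, the first class's weight is multiplied by 1.25, the
     * second is left unchanged, and the third is multiplied by 0.8; under
     * Balanced Winnow the corresponding negative weights are adjusted in the
     * opposite direction.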
     *
     * @param feature the feature to process
     * @param directions an array specifying for each class (in alphabetic
     * order) whether it should be promoted (positive value), demoted (negative
     * value) or left unmodified (0)
     */
    protected void adjustWeights(final Feature feature,
            final short[] directions) {
        float[] weights = store.getWeights(feature);
        final int length = getAllClasses().size();

        if (directions.length != length) {
            throw new IllegalArgumentException("Array of directions has "
                    + directions.length
                    + " members instead of one for each of the "
                    + length + " classes");
        }

        if (weights == null) {
            weights = initWeightArray();
            store.putWeights(feature, weights);
        }

        for (int i = 0; i < length; i++) {
            if (directions[i] < 0) {
                weights[i] *= demotion;

                if (balanced) {
                    weights[i + length] *= promotion;
                }
            } else if (directions[i] > 0) {
                weights[i] *= promotion;

                if (balanced) {
                    weights[i + length] *= demotion;
                }
            }
        }

        if (store.isIgnoringIrrelevant()) {
            store.setRelevant(feature, checkRelevance(weights));
        }
    }

    /***
     * Chooses the classes to promote and the classes to demote. This method
     * chooses the <code>targetClass</code> for promotion if its score is
     * less than or equal to the {@linkplain #majorThreshold(float, float)
     * major threshold}. It chooses all other classes for demotion if their
     * score is greater than the {@linkplain #minorThreshold(float, float)
     * minor threshold}.
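     *
     * <p>A worked example (numbers invented for illustration): with a
     * threshold of 10.0, a raw threshold of 10.0 and a threshold thickness
     * of 0.05, the major threshold is 10.5 and the minor threshold is 9.5;
     * the target class is promoted unless its raw score already exceeds
     * 10.5, while any other class scoring above 9.5 is demoted.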
     *
     * @param winnowDist the prediction distribution returned by
     * {@link #classify(FeatureVector, Set)}
     * @param targetClass the expected class of this instance; must be
     * contained in the set of <code>candidateClasses</code>
     * @param classesToPromote the classes to promote are added to this set
     * @param classesToDemote the classes to demote are added to this set
     */
    protected void chooseClassesToAdjust(final WinnowDistribution winnowDist,
            final String targetClass, final Set<String> classesToPromote,
            final Set<String> classesToDemote) {
        final Iterator predIter = winnowDist.iterator();
        WinnowPrediction pred;
        final float minorThreshold = minorThreshold(winnowDist.getThreshold(),
                winnowDist.getRawThreshold());
        final float majorThreshold = majorThreshold(winnowDist.getThreshold(),
                winnowDist.getRawThreshold());

        while (predIter.hasNext()) {
            pred = (WinnowPrediction) predIter.next();

            if (targetClass.equals(pred.getType())) {
                if (pred.getRawScore() <= majorThreshold) {
                    classesToPromote.add(pred.getType());
                }
            } else {
                if (pred.getRawScore() > minorThreshold) {
                    classesToDemote.add(pred.getType());
                }
            }
        }
    }

    /***
     * Converts a {@linkplain #normalizeScore(float, float, float) normalized
     * activation value} into a confidence estimate.
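     *
     * <p>For instance, a normalized value of 1.5 with a sum of 6.0 yields a
     * confidence of 0.25 (numbers chosen for illustration).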
     *
     * @param normalized the {@linkplain #normalizeScore(float, float, float)
     * normalized activation value} to convert
     * @param sum the sum of all normalized activation values
     * @return the estimated confidence: <code>normalized / sum</code>
     */
    protected double confidence(final float normalized, final float sum) {
        return (double) normalized / sum;
    }

    /***
     * Helper method that checks whether the provided arguments are valid.
     *
     * @param promotionFactor the promotion factor used by the algorithm;
     * must be > 1.0
     * @param demotionFactor the demotion factor used by the algorithm; must
     * be < 1.0
     * @param thresholdThick the thickness of the threshold if the "thick
     * threshold" heuristic is used (must be < 1.0), 0.0 otherwise
     * @throws IllegalArgumentException if one of the parameters is outside
     * the allowed range
     */
    private void checkArguments(final float promotionFactor,
            final float demotionFactor, final float thresholdThick)
            throws IllegalArgumentException {
        if (promotionFactor <= 1.0) {
            throw new IllegalArgumentException("Promotion factor must be > 1: "
                    + promotionFactor);
        }
        if ((demotionFactor >= 1.0) || (demotionFactor <= 0.0)) {
            throw new IllegalArgumentException(
                    "Demotion factor must be in ]0, 1[ range: "
                    + demotionFactor);
        }
        if ((thresholdThick >= 1.0) || (thresholdThick < 0.0)) {
            throw new IllegalArgumentException("Threshold thickness must be "
                    + "in [0, 1[ range: " + thresholdThick);
        }
    }

    /***
     * Checks whether a feature is relevant for classification.
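     *
     * <p>A feature is irrelevant if all its weights lie within the ignore
     * window <code>[lowerIgnoreLimit, upperIgnoreLimit]</code>. For example
     * (factors chosen for illustration), with a promotion factor of 1.25, a
     * demotion factor of 0.8 and an ignore exponent of 1: if every weight of
     * a feature was promoted and demoted equally often (1.25 * 0.8 = 1.0),
     * the feature stays within the window and is ignored.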
     *
     * @param weights the weights of the feature
     * @return <code>true</code> iff the feature is relevant for
     * classification, i.e. at least one of its weights lies outside the
     * ignore window; <code>false</code> otherwise
     */
    protected boolean checkRelevance(final float[] weights) {
        for (int i = 0; i < weights.length; i++) {
            if ((weights[i] < lowerIgnoreLimit)
                    || (weights[i] > upperIgnoreLimit)) {
                return true;
            }
        }

        return false;
    }

    /***
     * Returns the default weight to use if a feature is unknown. This
     * implementation returns 0.0 in case of {@link #isBalanced() Balanced
     * Winnow} (where positive and negative weights should cancel each other
     * out), {@link #initWeight()} otherwise.
     *
     * @return the default weight
     */
    protected float defaultWeight() {
        if (balanced) {
            return 0.0f;
        } else {
            return initWeight();
        }
    }

    /***
     * {@inheritDoc}
     */
    public void destroy() {
        store.destroy();
    }

    /***
     * {@inheritDoc}
     */
    protected PredictionDistribution doClassify(final FeatureVector features,
            final Set candidateClasses, final ContextMap context) {
        final FeatureSet featureSet = featureSet(features);
        final float[] scores = initScores();
        final Iterator featureIter = featureSet.iterator();
        Feature currentFeature;

        synchronized (this) {
            while (featureIter.hasNext()) {
                currentFeature = (Feature) featureIter.next();

                if (store.isRelevant(currentFeature)) {
                    updateScores(currentFeature, scores);
                }
            }
        }

        final float rawThreshold = rawThreshold(featureSet);
        final float threshold = threshold(rawThreshold);
        final float[] normalizedScores = new float[scores.length];
        float normalizedSum = 0.0f;
        int i;

        for (i = 0; i < scores.length; i++) {
            normalizedScores[i] =
                    normalizeScore(scores[i], threshold, rawThreshold);
            normalizedSum += normalizedScores[i];
        }

        final WinnowDistribution result =
                new WinnowDistribution(threshold, rawThreshold);
        final Iterator classIter = candidateClasses.iterator();
        String className;
        i = 0;

        while (classIter.hasNext()) {
            className = (String) classIter.next();
            result.add(new WinnowPrediction(className,
                    confidence(normalizedScores[i], normalizedSum), scores[i],
                    normalizedScores[i]));
            i++;
        }

        return result;
    }

    /***
     * <b>Winnow supports <em>only</em> error-driven training, so you always
     * have to use the {@link #trainOnError(FeatureVector, String, Set)} method
     * instead of this one. Trying to call this method will result in an
     * {@link java.lang.UnsupportedOperationException}.</b>
     *
     * @param features ignored by this method
     * @param targetClass ignored by this method
     * @param context ignored by this method
     * @throws UnsupportedOperationException always thrown by this method;
     * use {@link #trainOnError(FeatureVector, String, Set)} instead
     */
    protected void doTrain(final FeatureVector features,
            final String targetClass, final ContextMap context)
            throws UnsupportedOperationException {
        throw new UnsupportedOperationException("Winnow supports only "
                + "error-driven training -- call trainOnError instead of train");
    }

    /***
     * Converts a feature vector into a {@link FeatureSet} (a multi-set of
     * features). If the
     * {@linkplain FeatureVector#lastTransformation() last transformation}
     * of the provided vector already is a <code>FeatureSet</code> instance,
     * it is cast and returned. Otherwise a new <code>FeatureSet</code> with
     * the same contents is created, reading the method used to factor
     * feature frequencies into strength values from the
     * "classifier.winnow.strength.frequency" configuration key.
     *
     * @param fv the feature vector to convert
     * @return a feature set with the same contents as the provided vector
     */
    protected FeatureSet featureSet(final FeatureVector fv) {
        final FeatureSet result;
        final FeatureVector transformed = fv.lastTransformation();

        if (transformed instanceof FeatureSet) {
            result = (FeatureSet) transformed;
        } else {
            result = new FeatureSet();
            result.addAll(transformed);
            transformed.setTransformed(result);
        }

        return result;
    }

    /***
     * Returns the demotion factor used by the algorithm.
     *
     * @return the value of the attribute
     */
    public float getDemotion() {
        return demotion;
    }

    /***
     * Returns the promotion factor used by the algorithm.
     *
     * @return the value of the attribute
     */
    public float getPromotion() {
        return promotion;
    }

    /***
     * Whether the Balanced Winnow or the standard Winnow algorithm is used.
     * Balanced Winnow keeps <em>two</em> weights per feature and class,
     * a positive and a negative one.
     *
     * @return the value of the attribute
     */
    public boolean isBalanced() {
        return balanced;
    }

    /***
     * Initializes the scores (activation values) to use for all classes.
     *
     * @return an array of floats containing the initial score for each class;
     * the value of each float will be 0.0
     */
    protected float[] initScores() {
        return new float[getAllClasses().size()];
    }

    /***
     * Returns the thickness of the threshold if the "thick threshold"
     * heuristic is used.
     *
     * @return the value of the attribute, will be < 1.0; 0.0 if no
     * thick threshold is used
     */
    public float getThresholdThickness() {
        return thresholdThickness;
    }

    /***
     * Returns the initial weight to use for each feature per class. This
     * implementation returns 1.0.
     *
     * @return the initial weight
     */
    protected float initWeight() {
        return 1.0f;
    }

    /***
     * Returns the initial weight array to use for a feature for all classes.
     * The array returned by this implementation will contain one weight for
     * each class in case of normal Winnow, two weights in case of
     * {@link #isBalanced() Balanced} Winnow. Each element is initialized to
     * {@link #initWeight()}.
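     *
     * <p>For instance, with three classes and Balanced Winnow the array has
     * six elements: indices 0 to 2 hold the positive weights and indices
     * 3 to 5 the corresponding negative weights (see
     * {@link #adjustWeights(Feature, short[])}).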
     *
     * @return the initial weight array
     */
    protected float[] initWeightArray() {
        final float[] result;

        if (balanced) {
            result = new float[getAllClasses().size() * 2];
        } else {
            result = new float[getAllClasses().size()];
        }

        final float initWeight = initWeight();
        for (int i = 0; i < result.length; i++) {
            result[i] = initWeight;
        }

        return result;
    }

    /***
     * Calculates the major threshold (<em>theta+</em>) to use for
     * classification with the "thick threshold" heuristic. This
     * implementation multiplies <em>theta<sub>r</sub></em> with the
     * {@linkplain #getThresholdThickness() threshold thickness} and adds
     * the result to <em>theta</em>. Subclasses can override this method to
     * calculate the major threshold in a different way.
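     *
     * <p>For example (values invented for illustration): with
     * <em>theta</em> = 10, <em>theta<sub>r</sub></em> = 10 and a threshold
     * thickness of 0.05, the major threshold is 10 + 0.05 * 10 = 10.5 (and
     * the {@linkplain #minorThreshold(float, float) minor threshold} 9.5).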
     *
     * @param threshold the {@linkplain #threshold(float) threshold}
     * <em>theta</em>
     * @param rawThreshold the {@linkplain #rawThreshold(FeatureSet) raw
     * threshold} <em>theta<sub>r</sub></em>
     * @return the major threshold (<em>theta+</em>) to use for classification
     * @see #minorThreshold(float, float)
     */
    protected float majorThreshold(final float threshold,
            final float rawThreshold) {
        return threshold + getThresholdThickness() * rawThreshold;
    }

    /***
     * Calculates the minor threshold (<em>theta-</em>) to use for
     * classification with the "thick threshold" heuristic. This
     * implementation multiplies <em>theta<sub>r</sub></em> with the
     * {@linkplain #getThresholdThickness() threshold thickness} and subtracts
     * the result from <em>theta</em>. Subclasses can override this method to
     * calculate the minor threshold in a different way.
     *
     * @param threshold the {@linkplain #threshold(float) threshold}
     * <em>theta</em>
     * @param rawThreshold the {@linkplain #rawThreshold(FeatureSet) raw
     * threshold} <em>theta<sub>r</sub></em>
     * @return the minor threshold (<em>theta-</em>) to use for classification
     * @see #majorThreshold(float, float)
     */
    protected float minorThreshold(final float threshold,
            final float rawThreshold) {
        return threshold - getThresholdThickness() * rawThreshold;
    }

    /***
     * Converts the raw <em>score</em> (activation value) to a normalized value
     * depending on the threshold <em>theta</em>.
     * In this implementation this is calculated as follows:
     *
     * <p>
     * <em>norm</em>(<em>score</em>, <em>theta</em>,
     * <em>theta<sub>r</sub></em>) =
     * e^((<em>score</em> - <em>theta</em>) / <em>theta<sub>r</sub></em>)
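     *
     * <p>For example (values invented for illustration): a score of 12 with
     * <em>theta</em> = 10 and <em>theta<sub>r</sub></em> = 10 is normalized
     * to e^0.2, roughly 1.22.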
     *
     * @param score the raw <em>score</em> (activation value); must be a
     * positive value in case of normal (non-balanced) Winnow
     * @param threshold the {@linkplain #threshold(float) threshold}
     * <em>theta</em> used for this instance
     * @param rawThreshold the {@linkplain #rawThreshold(FeatureSet) raw
     * threshold} <em>theta<sub>r</sub></em> used for this instance
     * @return the normalized score calculated as described above
     */
    protected float normalizeScore(final float score, final float threshold,
            final float rawThreshold) {
        return (float) Math.exp((score - threshold) / rawThreshold);
    }

    /***
     * Calculates the raw threshold (<em>theta<sub>r</sub></em>) to use for
     * classification, based on the number of active features. This
     * implementation returns the number of relevant features if irrelevant
     * features are ignored, the total number of features otherwise.
     * Subclasses can override this method to calculate the threshold in a
     * different way.
     *
     * @param features the feature set to consider
     * @return the raw threshold (<em>theta<sub>r</sub></em>) to use
     */
    protected float rawThreshold(final FeatureSet features) {
        if (store.isIgnoringIrrelevant()) {
            final Iterator featureIter = features.iterator();
            Feature feature;
            int summedStrength = 0;

            while (featureIter.hasNext()) {
                feature = (Feature) featureIter.next();

                if (store.isRelevant(feature)) {
                    summedStrength++;
                }
            }

            return summedStrength;
        } else {
            return features.size();
        }
    }

    /***
     * {@inheritDoc}
     */
    public void reset() {
        store.reset();
    }

    /***
     * Returns a mapping from feature representations to weights. Features
     * that are {@linkplain WinnowStore#isRelevant(Feature) irrelevant} or
     * unknown (never seen during training) or contain only a comment are
     * skipped. For each other feature, the returned map will contain an
     * array of weights for all classes stored in the order of
     * {@link TrainableClassifier#getAllClasses()}.
     *
     * <p>This method exists for debugging and demonstration purposes.
     *
     * @param features the feature vector to consider
     * @return a mapping from known relevant feature representations to weights
     */
    public Map<String, List<Float>> showFeatureWeights(
            final FeatureVector features) {
        final Map<String, List<Float>> result =
                new HashMap<String, List<Float>>();
        final FeatureVector actualFeatures = features.lastTransformation();
        final Iterator featureIter = actualFeatures.iterator();
        Feature currentFeature;
        String currentRep;
        float[] currentWeights;

        while (featureIter.hasNext()) {
            currentFeature = (Feature) featureIter.next();
            currentRep = currentFeature.getRepresentation();
            currentWeights = store.getWeights(currentFeature);

            if ((currentRep != null) && (currentWeights != null)
                    && store.isRelevant(currentFeature)) {
                result.put(currentRep,
                        Arrays.asList(ArrayUtils.toObject(currentWeights)));
            }
        }

        return result;
    }

    /***
     * Calculates the threshold (<em>theta</em>) to use for classification.
     * This implementation returns the <code>rawThreshold</code> multiplied
     * with the {@linkplain #defaultWeight() default weight}. Subclasses can
     * override this method to calculate the threshold in a different way.
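     *
     * <p>Thus, for standard Winnow (default weight 1.0) the threshold equals
     * the raw threshold, while for Balanced Winnow (default weight 0.0) it
     * is 0.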
     *
     * @param rawThreshold the {@linkplain #rawThreshold(FeatureSet) raw
     * threshold}
     * @return the threshold (<em>theta</em>) to use for classification
     */
    protected float threshold(final float rawThreshold) {
        return rawThreshold * defaultWeight();
    }

    /***
     * Hook implementing error-driven learning, promoting and demoting weights
     * as required.
     *
     * @param predDist the prediction distribution returned by
     * {@link #classify(FeatureVector, Set)}; must be a
     * {@link WinnowDistribution}
     * @param features the feature vector to consider
     * @param targetClass the expected class of this feature vector; must be
     * contained in the set of <code>candidateClasses</code>
     * @param candidateClasses a set of classes that are allowed for this item
     * (the actual <code>targetClass</code> must be one of them)
     * @param context ignored by this implementation
     * @return this implementation always returns <code>true</code> to signal
     * that any error-driven learning was already handled
     * @throws ProcessingException if an error occurs during training
     */
    protected boolean trainOnErrorHook(final PredictionDistribution predDist,
            final FeatureVector features, final String targetClass,
            final Set candidateClasses, final ContextMap context)
            throws ProcessingException {
        final FeatureSet featureSet = featureSet(features);
        final Set<String> classesToPromote = new HashSet<String>();
        final Set<String> classesToDemote = new HashSet<String>();

        chooseClassesToAdjust((WinnowDistribution) predDist, targetClass,
                classesToPromote, classesToDemote);

        if (!(classesToPromote.isEmpty() && classesToDemote.isEmpty())) {
            final short[] directions = new short[getAllClasses().size()];
            final Iterator classIter = getAllClasses().iterator();
            String currentClass;
            int i = 0;

            while (classIter.hasNext()) {
                currentClass = (String) classIter.next();
                if (classesToDemote.contains(currentClass)) {
                    directions[i] = -1;
                } else if (classesToPromote.contains(currentClass)) {
                    directions[i] = 1;
                } else {
                    directions[i] = 0;
                }
                i++;
            }

            final Iterator featureIter = featureSet.iterator();
            Feature currentFeature;

            synchronized (this) {
                while (featureIter.hasNext()) {
                    currentFeature = (Feature) featureIter.next();
                    adjustWeights(currentFeature, directions);
                }
            }
        }

        return true;
    }

    /***
     * {@inheritDoc}
     */
    public ObjectElement toElement() {
        final ObjectElement result = super.toElement();
        result.addAttribute(ATTRIB_PROMOTION, Float.toString(promotion));
        result.addAttribute(ATTRIB_DEMOTION, Float.toString(demotion));
        result.addAttribute(ATTRIB_THRESHOLD_THICKNESS,
                Float.toString(thresholdThickness));
        result.addAttribute(ATTRIB_BALANCED, Boolean.toString(balanced));
        result.addAttribute(ATTRIB_IGNORE_EXPONENT,
                Integer.toString(ignoreExponent));
        result.add(store.toElement());
        return result;
    }

    /***
     * Returns a string representation of this object.
     *
     * @return a textual representation
     */
    public String toString() {
        final ToStringBuilder builder = new ToStringBuilder(this)
                .appendSuper(super.toString())
                .append("balanced", balanced)
                .append("promotion", promotion)
                .append("demotion", demotion)
                .append("threshold thickness", thresholdThickness)
                .append("feature store", store);

        if (store.isIgnoringIrrelevant()) {
            builder.append("lower ignore limit", lowerIgnoreLimit)
                    .append("upper ignore limit", upperIgnoreLimit);
        }

        return builder.toString();
    }

    /***
     * Updates the scores (activation values) for all classes by adding the
     * weights of a feature. This method should be called in a synchronized
     * context.
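     *
     * <p>For Balanced Winnow the negative weight is subtracted from the
     * positive one: e.g. (values invented for illustration), with two classes
     * and weights {1.2, 0.9, 0.8, 1.1}, the scores are increased by
     * 1.2 - 0.8 = 0.4 and 0.9 - 1.1 = -0.2, respectively.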
     *
     * @param feature the feature to process
     * @param scores an array of floats containing the scores for each
     * class; will be updated by this method
     */
    protected void updateScores(final Feature feature, final float[] scores) {
        final float[] weights = store.getWeights(feature);
        final int length = getAllClasses().size();

        if (scores.length != length) {
            throw new IllegalArgumentException("Array of scores has "
                    + scores.length
                    + " members instead of one for each of the "
                    + length + " classes");
        }

        if (weights != null) {
            for (int i = 0; i < length; i++) {
                scores[i] += weights[i];
            }

            if (balanced) {
                for (int i = 0; i < length; i++) {
                    scores[i] -= weights[i + length];
                }
            }
        } else {
            final float defaultWeight = defaultWeight();

            if (defaultWeight != 0.0f) {
                for (int i = 0; i < length; i++) {
                    scores[i] += defaultWeight;
                }
            }
        }
    }

}