diff --git a/.gitignore b/.gitignore index f5b7d72..d19a162 100644 --- a/.gitignore +++ b/.gitignore @@ -1,6 +1,19 @@ *.class *.log +*.kml +*.gz +*.class +*.orig +target/ +tmp/ +*~ +project/boot +lib_managed/ +data/models/*.bin +data/gazetteers/*.zip + + # sbt specific dist/* target/ diff --git a/CIAWFBFixer.scala b/CIAWFBFixer.scala new file mode 100644 index 0000000..916f677 --- /dev/null +++ b/CIAWFBFixer.scala @@ -0,0 +1,78 @@ +import java.io._ +import java.util._ + +import scala.collection.JavaConversions._ + +object CIAWFBFixer extends App { + + val countriesToCoords = new HashMap[String, (Double, Double)] + + var country = "" + for(line <- scala.io.Source.fromFile(args(0)).getLines) { + if(line.endsWith(":")) { + country = line.dropRight(1).toLowerCase + if(country.equals("korea, south")) + country = "south korea" + else if(country.equals("korea, north")) + country = "north korea" + else if(line.contains(",")) { + country = line.slice(0, line.indexOf(",")).toLowerCase + } + //println(country) + } + + if(line.length >= 5 && line.startsWith(" ")) { + val tokens = line.trim.split("[^0-9SWNE]+") + if(tokens.length >= 6) { + val lat = (tokens(0).toDouble + tokens(1).toDouble / 60.0) * (if(tokens(2).equals("S")) -1 else 1) + val lon = (tokens(3).toDouble + tokens(4).toDouble / 60.0) * (if(tokens(5).equals("W")) -1 else 1) + countriesToCoords.put(country, (lat, lon)) + + if(country.contains("bosnia")) + countriesToCoords.put("bosnia", (lat, lon)) + + if(country.equals("yugoslavia")) + countriesToCoords.put("serbia", (lat, lon)) + + if(country.equals("holy see (vatican city)")) + countriesToCoords.put("vatican", (lat, lon)) + } + } + } + + countriesToCoords.put("montenegro", (42.5,19.1)) // from Google + + //countriesToCoords.foreach(p => println(p._1 + " " + p._2._1 + "," + p._2._2)) + + val lineRE = """^(.*lat=\")([^\"]+)(.*long=\")(-0)(.*humanPath=\")([^\"]+)(.*)$""".r + + val inDir = new File(if(args(1).endsWith("/")) args(1).dropRight(1) else args(1)) + val outDir = new File(if(args(2).endsWith("/")) args(2).dropRight(1) else args(2)) + for(file <- inDir.listFiles.filter(_.getName.endsWith(".xml"))) { + + val out = new BufferedWriter(new FileWriter(outDir+"/"+file.getName)) + + for(line <- scala.io.Source.fromFile(file).getLines) { + if(line.contains("CIAWFB") && line.contains("long=\"-0\"")) { + val lineRE(beg, lat0, mid, lon0, humpath, countryName, end) = line + + var lon = 0.0 + if(countriesToCoords.contains(countryName.toLowerCase)) { + lon = countriesToCoords(countryName.toLowerCase)._2 + } + + var lat = lat0 + if(countryName.toLowerCase.equals("vatican")) + lat = countriesToCoords(countryName.toLowerCase)._1.toString + + out.write(beg+lat+mid+lon+humpath+countryName+end+"\n") + //println(beg+lat+mid+lon+humpath+countryName+end+"\n") + } + else + out.write(line+"\n") + } + + out.close + } + +} diff --git a/LICENSE.txt b/LICENSE.txt new file mode 100644 index 0000000..d645695 --- /dev/null +++ b/LICENSE.txt @@ -0,0 +1,202 @@ + + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. 
+ + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. 
Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. 
This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + + Copyright [yyyy] [name of copyright owner] + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. 
diff --git a/bin/cwarxml2txt.sh b/bin/cwarxml2txt.sh new file mode 100755 index 0000000..5a52c81 --- /dev/null +++ b/bin/cwarxml2txt.sh @@ -0,0 +1,11 @@ +#!/bin/bash + +indir=${1%/} +outdir=${2%/} + +for f in $indir/*.xml +do + filename=$(basename $f) + filename=${filename%.*} + grep ']*>//g' > $outdir/$filename.txt +done diff --git a/bin/fieldspring b/bin/fieldspring new file mode 100755 index 0000000..8b25b51 --- /dev/null +++ b/bin/fieldspring @@ -0,0 +1,397 @@ +#!/bin/sh + +# Amount of memory (in megabytes) to reserve for system operation when +# setting the maximum heap size. +RESERVED_MEMORY=512 + +FIELDSPRING_VERSION=0.1.0 + +HADOOP_FAILURE_OK_ARGS="-Dmapred.max.map.failures.percent=20 -Dmapred.max.reduce.failures.percent=20" + +if [ -z "$FIELDSPRING_DIR" ]; then + echo "Must set FIELDSPRING_DIR to top level of Fieldspring distribution" + exit 1 +fi + +JAVA="$JAVA_HOME/bin/java" +HADOOP_BINARY="${HADOOP_BINARY:-hadoop}" + +# NOTE: If environment var TG_JAVA_OPT is set on entry, it will be used. + +# Process options + +VERBOSE=no +DEBUG=no +HADOOP= +HADOOP_NONDIST= +MEMORY= +JAVA_MISC_OPT= +JAVA_USER_OPT= +while true; do + case "$1" in + -verbose | --verbose ) VERBOSE=yes; shift ;; + -debug-class | --debug-class ) + # I think that -verbose:class is the same as -XX:+TraceClassLoading. + JAVA_MISC_OPT="$JAVA_MISC_OPT -verbose:class -XX:+TraceClassUnloading" + shift ;; + -debug | --debug ) DEBUG=yes; shift ;; + -m | -memory | --memory ) MEMORY="$2"; shift 2 ;; + -minheap | --minheap ) + JAVA_MISC_OPT="$JAVA_MISC_OPT -XX:MinHeapFreeRatio=$2"; shift 2 ;; + -maxheap | --maxheap ) + JAVA_MISC_OPT="$JAVA_MISC_OPT -XX:MaxHeapFreeRatio=$2"; shift 2 ;; + -escape-analysis | --escape-analysis ) + JAVA_MISC_OPT="$JAVA_MISC_OPT -XX:+DoEscapeAnalysis"; shift ;; + -compressed-oops | --compressed-oops ) + JAVA_MISC_OPT="$JAVA_MISC_OPT -XX:+UseCompressedOops"; shift ;; + -java-opt | --java-opt ) + JAVA_USER_OPT="$JAVA_USER_OPT $2"; shift 2 ;; + -hadoop | --hadoop) HADOOP=yes; shift ;; + -hadoop-nondist | --hadoop-nondist) HADOOP_NONDIST=yes; shift ;; + -- ) shift; break ;; + * ) break ;; + esac +done + +# For info on Sun JVM options, see: + +# http://java.sun.com/docs/hotspot/VMOptions.html +# +# (redirects to: +# +# http://www.oracle.com/technetwork/java/javase/tech/vmoptions-jsp-140102.html +# +# ) +# +# Also see the following for tuning garbage collection: +# +# http://www.oracle.com/technetwork/java/gc-tuning-5-138395.html + +JAVA_MEMORY_OPT= +# Try to set the maximum heap size to something slightly less than +# the physical memory of the machine. +if [ -n "$MEMORY" ]; then + JAVA_MEMORY_OPT="-Xmx$MEMORY" +elif [ -n "$TG_SET_JVM_MEMORY" ]; then + MEMMB=`$FIELDSPRING_DIR/bin/fieldspring-memory` + if [ "$VERBOSE" = yes ]; then + echo "Output from fieldspring-memory is: $MEMMB" + fi + if [ "$MEMMB" = unknown ]; then + # The old way we set the heap size, to a very high virtual size. + if [ -z "$MEMORY" ]; then + if $JAVA -version 2>&1 | grep '64-Bit' > /dev/null; then + JAVA_IS_64=yes + # Maximum on Linux is about 127t (127 TB, i.e. 130,048 GB). Maximum on + # MacOS X 10.6 (Snow Leopard) is about 125t, but values that big cause a + # pause of about 6 seconds at the beginning and a couple of seconds at + # the end on my 4GB Mac. 4t doesn't cause much of a pause. 
+ MEMORY=4t + else + JAVA_IS_64=no + MEMORY=2g + fi + fi + else + MEMORY="`expr $MEMMB - $RESERVED_MEMORY`m" + fi + JAVA_MEMORY_OPT="-Xmx$MEMORY" +fi + +if [ "$VERBOSE" = yes -a -n "$JAVA_MEMORY_OPT" ]; then + echo "Setting maximum JVM heap size to $MEMORY" +fi + +JAVA_DEBUG_OPT= +if [ "$DEBUG" = yes ]; then + # Print details about when and how garbage collection happens; recommended + # in http://www.oracle.com/technetwork/java/gc-tuning-5-138395.html + JAVA_DEBUG_OPT="-verbose:gc -XX:+PrintGCDetails -XX:+PrintGCTimeStamps -XX:+TraceClassUnloading" + #This will output a lot of stuff about class loading. Enable it using + # --debug-class if you want. + #JAVA_DEBUG_OPT="$JAVA_DEBUG_OPT -XX:+TraceClassLoading" +fi + +JARS="`echo $FIELDSPRING_DIR/lib/*.jar $FIELDSPRING_DIR/lib_managed/*/*.jar $FIELDSPRING_DIR/lib_managed/*/*/*.jar $FIELDSPRING_DIR/lib_managed/*/*/*/*.jar $FIELDSPRING_DIR/output/*.jar $FIELDSPRING_DIR/target/*.jar | tr ' ' ':'`" +SCALA_LIB="$HOME/.sbt/boot/scala-2.9.2/lib/scala-library.jar" +CP="$FIELDSPRING_DIR/target/classes:$SCALA_LIB:$JARS:$CLASSPATH" + +# Later options override earlier ones, so put command-line options after +# the ones taken from environment variables (TG_JAVA_OPT and to some extent +# JAVA_MEMORY_OPT, because it depends on env var TG_SET_JVM_MEMORY). +JAVA_COMMAND="$JAVA $TG_JAVA_OPT $JAVA_MEMORY_OPT $JAVA_DEBUG_OPT $JAVA_MISC_OPT $JAVA_USER_OPT -classpath $CP" + +CMD="$1" +shift + +help() +{ +cat < +// cp filter {x => Seq("jasper-compiler-5.5.12.jar", "jasper-runtime-5.5.12.jar", "commons-beanutils-1.7.0.jar", "servlet-api-2.5-20081211.jar") contains x.data.getName } +//} + +// FUCK ME TO (JAR) HELL! This is an awful hack. Boys and girls, repeat after +// me: say "fragile library problem" and "Java sucks rocks compared with C#". +// Now repeat 100 times. +// +// Here the problem is that, as a program increases in size and includes +// dependencies from various sources, each with their own sub-dependencies, +// you'll inevitably end up with different versions of the same library as +// sub-dependencies of different dependencies. This is the infamous "fragile +// library problem" (aka DLL hell, JAR hell, etc.). Java has no solution to +// this problem. C# does. (As with 100 other nasty Java problems that don't +// exist in C#.) +// +// On top of this, SBT makes things even worse by not even providing a way +// of automatically picking the most recent library. In fact, it doesn't +// provide any solution at all that doesn't require you to write your own +// code (see below) -- a horrendous solution typical of packages written by +// programmers who are obsessed with the mantra of "customizability" but +// have no sense of proper design, no knowledge of how to write user +// interfaces, and no skill in creating understandable documentation. +// The "solution" below arbitrarily picks the first library version found. +// Is this newer or older? Will it cause weird random breakage? Who knows? 
+ +mergeStrategy in assembly <<= (mergeStrategy in assembly) { (old) => + { + case x => { + val oldstrat = old(x) + if (oldstrat == MergeStrategy.deduplicate) MergeStrategy.first + else oldstrat + } + } +} + +// jarName in assembly := "fieldspring-assembly.jar" diff --git a/data/gazetteers/getGeoNames.sh b/data/gazetteers/getGeoNames.sh new file mode 100755 index 0000000..f19e68a --- /dev/null +++ b/data/gazetteers/getGeoNames.sh @@ -0,0 +1,3 @@ +#!/bin/bash + +wget http://download.geonames.org/export/dump/allCountries.zip diff --git a/data/lists/stopwords.english b/data/lists/stopwords.english new file mode 100644 index 0000000..c6ca14f --- /dev/null +++ b/data/lists/stopwords.english @@ -0,0 +1,572 @@ +'s +a +a's +able +about +above +according +accordingly +across +actually +after +afterwards +again +against +ain't +all +allow +allows +almost +alone +along +already +also +although +always +am +among +amongst +an +and +another +any +anybody +anyhow +anyone +anything +anyway +anyways +anywhere +apart +appear +appreciate +appropriate +are +aren't +around +as +aside +ask +asking +associated +at +available +away +awfully +b +be +became +because +become +becomes +becoming +been +before +beforehand +behind +being +believe +below +beside +besides +best +better +between +beyond +both +brief +but +by +c +c'mon +c's +came +can +can't +cannot +cant +cause +causes +certain +certainly +changes +clearly +co +com +come +comes +concerning +consequently +consider +considering +contain +containing +contains +corresponding +could +couldn't +course +currently +d +definitely +described +despite +did +didn't +different +do +does +doesn't +doing +don't +done +down +downwards +during +e +each +edu +eg +eight +either +else +elsewhere +enough +entirely +especially +et +etc +even +ever +every +everybody +everyone +everything +everywhere +ex +exactly +example +except +f +far +few +fifth +first +five +followed +following +follows +for +former +formerly +forth +four +from +further +furthermore +g +get +gets +getting +given +gives +go +goes +going +gone +got +gotten +greetings +h +had +hadn't +happens +hardly +has +hasn't +have +haven't +having +he +he's +hello +help +hence +her +here +here's +hereafter +hereby +herein +hereupon +hers +herself +hi +him +himself +his +hither +hopefully +how +howbeit +however +i +i'd +i'll +i'm +i've +ie +if +ignored +immediate +in +inasmuch +inc +indeed +indicate +indicated +indicates +inner +insofar +instead +into +inward +is +isn't +it +it'd +it'll +it's +its +itself +j +just +k +keep +keeps +kept +know +knows +known +l +last +lately +later +latter +latterly +least +less +lest +let +let's +like +liked +likely +little +look +looking +looks +ltd +m +mainly +many +may +maybe +me +mean +meanwhile +merely +might +more +moreover +most +mostly +much +must +my +myself +n +name +namely +nd +near +nearly +necessary +need +needs +neither +never +nevertheless +new +next +nine +no +nobody +non +none +noone +nor +normally +not +nothing +novel +now +nowhere +o +obviously +of +off +often +oh +ok +okay +old +on +once +one +ones +only +onto +or +other +others +otherwise +ought +our +ours +ourselves +out +outside +over +overall +own +p +particular +particularly +per +perhaps +placed +please +plus +possible +presumably +probably +provides +q +que +quite +qv +r +rather +rd +re +really +reasonably +regarding +regardless +regards +relatively +respectively +right +s +said +same +saw +say +saying +says +second +secondly +see +seeing +seem +seemed +seeming +seems +seen +self +selves +sensible +sent +serious 
+seriously +seven +several +shall +she +should +shouldn't +since +six +so +some +somebody +somehow +someone +something +sometime +sometimes +somewhat +somewhere +soon +sorry +specified +specify +specifying +still +sub +such +sup +sure +t +t's +take +taken +tell +tends +th +than +thank +thanks +thanx +that +that's +thats +the +their +theirs +them +themselves +then +thence +there +there's +thereafter +thereby +therefore +therein +theres +thereupon +these +they +they'd +they'll +they're +they've +think +third +this +thorough +thoroughly +those +though +three +through +throughout +thru +thus +to +together +too +took +toward +towards +tried +tries +truly +try +trying +twice +two +u +un +under +unfortunately +unless +unlikely +until +unto +up +upon +us +use +used +useful +uses +using +usually +uucp +v +value +various +very +via +viz +vs +w +want +wants +was +wasn't +way +we +we'd +we'll +we're +we've +welcome +well +went +were +weren't +what +what's +whatever +when +whence +whenever +where +where's +whereafter +whereas +whereby +wherein +whereupon +wherever +whether +which +while +whither +who +who's +whoever +whole +whom +whose +why +will +willing +wish +with +within +without +won't +wonder +would +would +wouldn't +x +y +yes +yet +you +you'd +you'll +you're +you've +your +yours +yourself +yourselves +z +zero diff --git a/data/models/getOpenNLPModels.sh b/data/models/getOpenNLPModels.sh new file mode 100755 index 0000000..7c99cf3 --- /dev/null +++ b/data/models/getOpenNLPModels.sh @@ -0,0 +1,5 @@ +#!/bin/bash + +wget http://opennlp.sourceforge.net/models-1.5/en-ner-location.bin +wget http://opennlp.sourceforge.net/models-1.5/en-token.bin +wget http://opennlp.sourceforge.net/models-1.5/en-sent.bin \ No newline at end of file diff --git a/lib/argot_2.9.1-0.3.5-benwing.jar b/lib/argot_2.9.1-0.3.5-benwing.jar new file mode 100644 index 0000000..f0a4829 Binary files /dev/null and b/lib/argot_2.9.1-0.3.5-benwing.jar differ diff --git a/lib/codeanticode-GLGraphics-0.9.4.jar b/lib/codeanticode-GLGraphics-0.9.4.jar new file mode 100644 index 0000000..99a746f Binary files /dev/null and b/lib/codeanticode-GLGraphics-0.9.4.jar differ diff --git a/lib/controlP5-1.5.2.jar b/lib/controlP5-1.5.2.jar new file mode 100644 index 0000000..3f5b6a3 Binary files /dev/null and b/lib/controlP5-1.5.2.jar differ diff --git a/lib/fhpotsdam-unfolding-0.9.1.jar b/lib/fhpotsdam-unfolding-0.9.1.jar new file mode 100644 index 0000000..bb2472c Binary files /dev/null and b/lib/fhpotsdam-unfolding-0.9.1.jar differ diff --git a/lib/lift-json_2.9.1-2.4.jar b/lib/lift-json_2.9.1-2.4.jar new file mode 100644 index 0000000..fc42c03 Binary files /dev/null and b/lib/lift-json_2.9.1-2.4.jar differ diff --git a/lib/opengl-core-20120724.jar b/lib/opengl-core-20120724.jar new file mode 100644 index 0000000..49561bc Binary files /dev/null and b/lib/opengl-core-20120724.jar differ diff --git a/lib/processing-opengl-20120724.jar b/lib/processing-opengl-20120724.jar new file mode 100644 index 0000000..1a3165e Binary files /dev/null and b/lib/processing-opengl-20120724.jar differ diff --git a/lib/scoobi_2.9.2-0.6.0-cdh3-SNAPSHOT-benwing.jar b/lib/scoobi_2.9.2-0.6.0-cdh3-SNAPSHOT-benwing.jar new file mode 100644 index 0000000..ad912b8 Binary files /dev/null and b/lib/scoobi_2.9.2-0.6.0-cdh3-SNAPSHOT-benwing.jar differ diff --git a/lib/trove-scala_2.9.1-0.0.2-SNAPSHOT.jar b/lib/trove-scala_2.9.1-0.0.2-SNAPSHOT.jar new file mode 100644 index 0000000..ab4c74a Binary files /dev/null and b/lib/trove-scala_2.9.1-0.0.2-SNAPSHOT.jar differ diff --git 
a/lib/upenn-junto-1.1-assembly.jar b/lib/upenn-junto-1.1-assembly.jar new file mode 100644 index 0000000..17ee04e Binary files /dev/null and b/lib/upenn-junto-1.1-assembly.jar differ diff --git a/project/plugins.sbt b/project/plugins.sbt new file mode 100644 index 0000000..ec35b46 --- /dev/null +++ b/project/plugins.sbt @@ -0,0 +1 @@ +addSbtPlugin("com.eed3si9n" % "sbt-assembly" % "0.8.3") diff --git a/src/main/java/ags/utils/KdTree.java b/src/main/java/ags/utils/KdTree.java new file mode 100644 index 0000000..1e1a34a --- /dev/null +++ b/src/main/java/ags/utils/KdTree.java @@ -0,0 +1,389 @@ +/** + * Copyright 2009 Rednaxela + * + * This software is provided 'as-is', without any express or implied + * warranty. In no event will the authors be held liable for any damages + * arising from the use of this software. + * + * Permission is granted to anyone to use this software for any purpose, + * including commercial applications, and to alter it and redistribute it + * freely, subject to the following restrictions: + * + * 1. The origin of this software must not be misrepresented; you must not + * claim that you wrote the original software. If you use this software + * in a product, an acknowledgment in the product documentation would be + * appreciated but is not required. + * + * 2. This notice may not be removed or altered from any source + * distribution. + */ + +package ags.utils; + +import java.util.Arrays; +import java.util.Collections; +import java.util.HashMap; +import java.util.ArrayList; +import java.util.List; +import java.util.LinkedList; + +/** + * An efficient well-optimized kd-tree + * + * @author Rednaxela + */ +public class KdTree { + // split method enum + public enum SplitMethod { HALFWAY, MEDIAN, MAX_MARGIN } + + // All types + private final int dimensions; + public final KdTree parent; + private int bucketSize; + private SplitMethod splitMethod; + + // Leaf only + private double[][] locations; + private int locationCount; + + // Stem only + private KdTree left, right; + private int splitDimension; + private double splitValue; + + // Bounds + public double[] minLimit, maxLimit; + private boolean singularity; + + /** + * Construct a KdTree with a given number of dimensions and a limit on + * maxiumum size (after which it throws away old points) + */ + public KdTree(int dimensions, int bucketSize, SplitMethod splitMethod) { + this.bucketSize = bucketSize; + this.dimensions = dimensions; + this.splitMethod = splitMethod; + + // Init as leaf + this.locations = new double[bucketSize][]; + this.locationCount = 0; + this.singularity = true; + + // Init as root + this.parent = null; + } + + /** + * Constructor for child nodes. Internal use only. 
+     */
+    private KdTree(KdTree parent, boolean right) {
+        this.dimensions = parent.dimensions;
+        this.bucketSize = parent.bucketSize;
+        this.splitMethod = parent.splitMethod;
+
+        // Init as leaf
+        this.locations = new double[Math.max(bucketSize, parent.locationCount)][];
+        this.locationCount = 0;
+        this.singularity = true;
+
+        // Init as non-root
+        this.parent = parent;
+    }
+
+    /**
+     * Get the number of points in the tree
+     */
+    public int size() {
+        return locationCount;
+    }
+
+    public KdTree getLeaf(double[] location) {
+        if (left == null || right == null)
+            return this;
+        else if (location[splitDimension] <= splitValue)
+            return left.getLeaf(location);
+        else
+            return right.getLeaf(location);
+    }
+
+    /**
+     * Add a point and associated value to the tree
+     */
+    public void addPoint(double[] location) {
+        if (locationCount >= locations.length) {
+            double[][] newLocations = new double[locations.length * 2][];
+            System.arraycopy(locations, 0, newLocations, 0, locationCount);
+            locations = newLocations;
+        }
+
+        locations[locationCount] = location;
+        locationCount++;
+        extendBounds(location);
+    }
+
+    /**
+     * Extends the bounds of this node to include a new location
+     */
+    private final void extendBounds(double[] location) {
+        if (minLimit == null) {
+            minLimit = new double[dimensions];
+            System.arraycopy(location, 0, minLimit, 0, dimensions);
+            maxLimit = new double[dimensions];
+            System.arraycopy(location, 0, maxLimit, 0, dimensions);
+            return;
+        }
+
+        for (int i = 0; i < dimensions; i++) {
+            if (Double.isNaN(location[i])) {
+                minLimit[i] = Double.NaN;
+                maxLimit[i] = Double.NaN;
+                singularity = false;
+            }
+            else if (minLimit[i] > location[i]) {
+                minLimit[i] = location[i];
+                singularity = false;
+            }
+            else if (maxLimit[i] < location[i]) {
+                maxLimit[i] = location[i];
+                singularity = false;
+            }
+        }
+    }
+
+    private List<double[]> getLocations() {
+        LinkedList<double[]> l = new LinkedList<double[]>();
+        getLocationsHelper(l);
+        return l;
+    }
+
+    private void getLocationsHelper(List<double[]> l) {
+        if (left == null || right == null) {
+            for (int i=0; i<locationCount; i++) {
+                l.add(locations[i]);
+            }
+        }
+        else {
+            left.getLocationsHelper(l);
+            right.getLocationsHelper(l);
+        }
+    }
+
+    // Finds the widest axis of this node's bounds, weighted by getAxisWeightHint.
+    private int findWidestAxis() {
+        int widest = 0;
+        double width = (maxLimit[0] - minLimit[0]) * getAxisWeightHint(0);
+        if (Double.isNaN(width)) width = 0;
+        for (int i = 1; i < dimensions; i++) {
+            double nwidth = (maxLimit[i] - minLimit[i]) * getAxisWeightHint(i);
+            if (Double.isNaN(nwidth)) nwidth = 0;
+            if (nwidth > width) {
+                widest = i;
+                width = nwidth;
+            }
+        }
+        return widest;
+    }
+
+    public List<KdTree> getNodes() {
+        List<KdTree> list = new ArrayList<KdTree>();
+        this.getNodesHelper(list);
+        return list;
+    }
+
+    private void getNodesHelper(List<KdTree> list) {
+        list.add(this);
+        if (left != null) left.getNodesHelper(list);
+        if (right != null) right.getNodesHelper(list);
+    }
+
+    public List<KdTree> getLeaves() {
+        List<KdTree> list = new ArrayList<KdTree>();
+        this.getLeavesHelper(list);
+        return list;
+    }
+
+    private void getLeavesHelper(List<KdTree> list) {
+        if (left == null && right == null)
+            list.add(this);
+        else {
+            if (left != null)
+                left.getLeavesHelper(list);
+            if (right != null)
+                right.getLeavesHelper(list);
+        }
+    }
+
+    public void balance() {
+        nodeSplit(this);
+    }
+
+    private void nodeSplit(KdTree cursor) {
+        if (cursor.locationCount > cursor.bucketSize) {
+            cursor.splitDimension = cursor.findWidestAxis();
+
+            if (splitMethod == SplitMethod.HALFWAY) {
+                cursor.splitValue = (cursor.minLimit[cursor.splitDimension] +
+                    cursor.maxLimit[cursor.splitDimension]) * 0.5;
+            } else if (splitMethod == SplitMethod.MEDIAN) {
+                // split on the median of the elements
+                List<Double> list = new ArrayList<Double>();
+                for(int i=0; i<cursor.locationCount; i++) {
+                    list.add(cursor.locations[i][cursor.splitDimension]);
+                }
+                Collections.sort(list);
+                cursor.splitValue = list.get(list.size() / 2);
+            } else if (splitMethod == SplitMethod.MAX_MARGIN) {
+                // split at the largest gap between consecutive sorted values
+                List<Double> list = new ArrayList<Double>();
+                for(int i = 0; i < cursor.locationCount; i++) {
+                    list.add(cursor.locations[i][cursor.splitDimension]);
+                }
+                Collections.sort(list);
+                double maxMargin = 0.0;
+                double splitValue = Double.NaN;
+                for (int i = 0; i < list.size() - 1; i++) {
+                    double delta = list.get(i+1) - list.get(i);
+                    if (delta > maxMargin) {
+
maxMargin = delta; + splitValue = list.get(i) + 0.5 * delta; + } + } + cursor.splitValue = splitValue; + } + + // Never split on infinity or NaN + if (cursor.splitValue == Double.POSITIVE_INFINITY) { + cursor.splitValue = Double.MAX_VALUE; + } + else if (cursor.splitValue == Double.NEGATIVE_INFINITY) { + cursor.splitValue = -Double.MAX_VALUE; + } + else if (Double.isNaN(cursor.splitValue)) { + cursor.splitValue = 0; + } + + // Don't split node if it has no width in any axis. Double the + // bucket size instead + if (cursor.minLimit[cursor.splitDimension] == cursor.maxLimit[cursor.splitDimension]) { + double[][] newLocations = new double[cursor.locations.length * 2][]; + System.arraycopy(cursor.locations, 0, newLocations, 0, cursor.locationCount); + cursor.locations = newLocations; + return; + } + + // Don't let the split value be the same as the upper value as + // can happen due to rounding errors! + if (cursor.splitValue == cursor.maxLimit[cursor.splitDimension]) { + cursor.splitValue = cursor.minLimit[cursor.splitDimension]; + } + + // Create child leaves + KdTree left = new ChildNode(cursor, false); + KdTree right = new ChildNode(cursor, true); + + // Move locations into children + for (int i = 0; i < cursor.locationCount; i++) { + double[] oldLocation = cursor.locations[i]; + if (oldLocation[cursor.splitDimension] > cursor.splitValue) { + // Right + right.locations[right.locationCount] = oldLocation; + right.locationCount++; + right.extendBounds(oldLocation); + } + else { + // Left + left.locations[left.locationCount] = oldLocation; + left.locationCount++; + left.extendBounds(oldLocation); + } + } + + // Make into stem + cursor.left = left; + cursor.right = right; + cursor.locations = null; + cursor.nodeSplit(left); + cursor.nodeSplit(right); + } + } + + + protected double pointDist(double[] p1, double[] p2) { + double d = 0; + + for (int i = 0; i < p1.length; i++) { + double diff = (p1[i] - p2[i]); + if (!Double.isNaN(diff)) { + d += diff * diff; + } + } + + return d; + } + + protected double pointRegionDist(double[] point, double[] min, double[] max) { + double d = 0; + + for (int i = 0; i < point.length; i++) { + double diff = 0; + if (point[i] > max[i]) { + diff = (point[i] - max[i]); + } + else if (point[i] < min[i]) { + diff = (point[i] - min[i]); + } + + if (!Double.isNaN(diff)) { + d += diff * diff; + } + } + + return d; + } + + protected double getAxisWeightHint(int i) { + return 1.0; + } + + /** + * Internal class for child nodes + */ + private class ChildNode extends KdTree { + private ChildNode(KdTree parent, boolean right) { + super(parent, right); + } + + // Distance measurements are always called from the root node + protected double pointDist(double[] p1, double[] p2) { + throw new IllegalStateException(); + } + + protected double pointRegionDist(double[] point, double[] min, double[] max) { + throw new IllegalStateException(); + } + } + + private static String darrayToString(double[] array) { + String retval = ""; + for (int i = 0; i < array.length; i++) { + retval += array[i] + " "; + } + return retval; + } + +} + diff --git a/src/main/java/opennlp/fieldspring/tr/app/BaseApp.java b/src/main/java/opennlp/fieldspring/tr/app/BaseApp.java new file mode 100644 index 0000000..e03aec8 --- /dev/null +++ b/src/main/java/opennlp/fieldspring/tr/app/BaseApp.java @@ -0,0 +1,471 @@ +/* + * Base app for running resolvers and/or other functionality such as evaluation and visualization generation. 
+ */ + +package opennlp.fieldspring.tr.app; + +import org.apache.commons.cli.*; +import opennlp.fieldspring.tr.topo.*; +import java.io.*; + +public class BaseApp { + + private Options options = new Options(); + + private String inputPath = null; + private String additionalInputPath = null; + private String graphInputPath = null; + private String outputPath = null; + private String xmlInputPath = null; + private String kmlOutputPath = null; + private String dKmlOutputPath = null; + private String logFilePath = null; + private boolean outputGoldLocations = false; + private boolean outputUserKML = false; + protected boolean useGoldToponyms = false; + private String geoGazetteerFilename = null; + private String serializedGazetteerPath = null; + private String serializedCorpusInputPath = null; + private String serializedCorpusOutputPath = null; + private String maxentModelDirInputPath = null; + + private double popComponentCoefficient = 0.0; + private boolean dgProbOnly = false; + private boolean meProbOnly = false; + + private boolean doOracleEval = false; + + private int sentsPerDocument = -1; + + private boolean highRecallNER = false; + + private Region boundingBox = null; + + protected boolean doKMeans = false; + + private String graphOutputPath = null; + private String seedOutputPath = null; + private String wikiInputPath = null; + private String stoplistInputPath = null; + + private int numIterations = 1; + private boolean readWeightsFromFile = false; + + private int knnForLP = -1; + + private double dpc = 10; + private double threshold = -1.0; // sentinel value indicating NOT to use threshold, not real default + + public static enum RESOLVER_TYPE { + RANDOM, + POPULATION, + BASIC_MIN_DIST, + WEIGHTED_MIN_DIST, + DOC_DIST, + TOPO_AS_DOC_DIST, + LABEL_PROP, + LABEL_PROP_DEFAULT_RULE, + LABEL_PROP_CONTEXT_SENSITIVE, + LABEL_PROP_COMPLEX, + MAXENT, + PROB, + BAYES_RULE, + CONSTRUCTION_TPP, + HEURISTIC_TPP + } + protected Enum resolverType = RESOLVER_TYPE.BASIC_MIN_DIST; + + public static enum CORPUS_FORMAT { + PLAIN, + TRCONLL, + GEOTEXT, + WIKITEXT + } + protected Enum corpusFormat = CORPUS_FORMAT.PLAIN; + + + protected void initializeOptionsFromCommandLine(String[] args) throws Exception { + + options.addOption("i", "input", true, "input path"); + options.addOption("ix", "input-xml", true, "xml input path"); + options.addOption("im", "input-models", true, "maxent model input directory"); + options.addOption("ia", "input-additional", true, "path to additional input data to be used in training but not evaluation"); + options.addOption("ig", "input-graph", true, "path to input graph for label propagation resolvers"); + options.addOption("r", "resolver", true, "resolver (RandomResolver, BasicMinDistResolver, WeightedMinDistResolver, LabelPropDefaultRuleResolver, LabelPropContextSensitiveResolver, LabelPropComplexResolver) [default = BasicMinDistResolver]"); + options.addOption("it", "iterations", true, "number of iterations for iterative models [default = 1]"); + options.addOption("rwf", "read-weights-file", false, "read initial weights from probToWMD.dat"); + options.addOption("o", "output", true, "output path"); + options.addOption("ok", "output-kml", true, "kml output path"); + options.addOption("okd", "output-kml-dynamic", true, "dynamic kml output path"); + options.addOption("oku", "output-kml-users", false, "output user-based KML rather than toponym-based KML"); + options.addOption("gold", "output-gold-locations", false, "output gold locations rather than system locations in KML"); + 
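// Options for gold toponyms, the GeoNames gazetteer, and corpus serialization/format.
+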
options.addOption("gt", "gold-toponyms", false, "use gold toponyms (named entities) if available"); + options.addOption("g", "geo-gazetteer-filename", true, "GeoNames gazetteer filename"); + options.addOption("sg", "serialized-gazetteer-path", true, "path to serialized GeoNames gazetteer"); + options.addOption("sci", "serialized-corpus-input-path", true, "path to serialized corpus for input"); + //options.addOption("sgci", "serialized-gold-corpus-input-path", true, "path to serialized gold corpus for input"); + options.addOption("sco", "serialized-corpus-output-path", true, "path to serialized corpus for output"); + //options.addOption("tr", "tr-conll", false, "read input path as TR-CoNLL directory"); + options.addOption("cf", "corpus-format", true, "corpus format (Plain, TrCoNLL, GeoText) [default = Plain]"); + + options.addOption("oracle", "oracle", false, "use oracle evaluation"); + + options.addOption("spd", "sentences-per-document", true, "sentences per document (-1 for unlimited) [default = -1]"); + + options.addOption("pc", "pop-comp-coeff", true, "population component coefficient (for PROBABILISTIC resolver)"); + options.addOption("pdg", "prob-doc-geo", false, "use probability from document geolocator only (for PROBABILISTIC resolver)"); + options.addOption("pme", "prob-maxent", false, "use probability from MaxEnt local context component only (for PROBABILISTIC resolver)"); + + options.addOption("minlat", "minimum-latitude", true, + "minimum latitude for bounding box"); + options.addOption("maxlat", "maximum-latitude", true, + "maximum latitude for bounding box"); + options.addOption("minlon", "minimum-longitude", true, + "minimum longitude for bounding box"); + options.addOption("maxlon", "maximum-longitude", true, + "maximum longitude for bounding box"); + + options.addOption("dkm", "do-k-means-multipoints", false, + "(import-gazetteer only) run k-means and create multipoint representations of regions (e.g. 
countries)"); + + options.addOption("og", "output-graph", true, + "(preprocess-labelprop only) path to output graph file"); + options.addOption("os", "output-seed", true, + "(preprocess-labelprop only) path to output seed file"); + options.addOption("iw", "input-wiki", true, + "(preprocess-labelprop only) path to wikipedia file (article titles, article IDs, and word lists)"); + options.addOption("is", "input-stoplist", true, + "(preprocess-labelprob only) path to stop list input file (one stop word per line)"); + + options.addOption("l", "log-file-input", true, "log file input, from document geolocation"); + options.addOption("knn", "knn", true, "k nearest neighbors to consider from document geolocation log file"); + + options.addOption("dpc", "degrees-per-cell", true, "degrees per cell for grid-based TPP resolvers"); + options.addOption("t", "threshold", true, "threshold in kilometers for agglomerative clustering"); + + options.addOption("ner", "named-entity-recognizer", true, + "option for using High Recall NER"); + + options.addOption("h", "help", false, "print help"); + + Double minLat = null; + Double maxLat = null; + Double minLon = null; + Double maxLon = null; + + CommandLineParser optparse = new PosixParser(); + CommandLine cline = optparse.parse(options, args); + + if (cline.hasOption('h')) { + HelpFormatter formatter = new HelpFormatter(); + formatter.printHelp("fieldspring [command] ", options); + System.exit(0); + } + + for (Option option : cline.getOptions()) { + String value = option.getValue(); + switch (option.getOpt().charAt(0)) { + case 'i': + if(option.getOpt().equals("i")) + inputPath = value; + else if(option.getOpt().equals("it")) + numIterations = Integer.parseInt(value); + else if(option.getOpt().equals("ia")) + additionalInputPath = value; + else if(option.getOpt().equals("ig")) + graphInputPath = value; + else if(option.getOpt().equals("iw")) + wikiInputPath = value; + else if(option.getOpt().equals("is")) + stoplistInputPath = value; + else if(option.getOpt().equals("ix")) + xmlInputPath = value; + else if(option.getOpt().equals("im")) + maxentModelDirInputPath = value; + break; + case 'o': + if(option.getOpt().equals("o")) + outputPath = value; + else if(option.getOpt().equals("og")) + graphOutputPath = value; + else if(option.getOpt().equals("ok")) + kmlOutputPath = value; + else if(option.getOpt().equals("oku")) + outputUserKML = true; + else if(option.getOpt().equals("okd")) + dKmlOutputPath = value; + else if(option.getOpt().equals("os")) + seedOutputPath = value; + else if(option.getOpt().equals("oracle")) + doOracleEval = true; + break; + case 'r': + if(option.getOpt().equals("r")) { + if(value.toLowerCase().startsWith("r")) + resolverType = RESOLVER_TYPE.RANDOM; + else if(value.toLowerCase().startsWith("w")) + resolverType = RESOLVER_TYPE.WEIGHTED_MIN_DIST; + else if(value.toLowerCase().startsWith("d")) + resolverType = RESOLVER_TYPE.DOC_DIST; + else if(value.toLowerCase().startsWith("t")) + resolverType = RESOLVER_TYPE.TOPO_AS_DOC_DIST; + else if(value.equalsIgnoreCase("labelprop")) + resolverType = RESOLVER_TYPE.LABEL_PROP; + else if(value.toLowerCase().startsWith("labelpropd")) + resolverType = RESOLVER_TYPE.LABEL_PROP_DEFAULT_RULE; + else if(value.toLowerCase().startsWith("labelpropcontext")) + resolverType = RESOLVER_TYPE.LABEL_PROP_CONTEXT_SENSITIVE; + else if(value.toLowerCase().startsWith("labelpropcomplex")) + resolverType = RESOLVER_TYPE.LABEL_PROP_COMPLEX; + else if(value.toLowerCase().startsWith("m")) + resolverType = RESOLVER_TYPE.MAXENT; + 
else if(value.toLowerCase().startsWith("pr")) + resolverType = RESOLVER_TYPE.PROB; + else if(value.toLowerCase().startsWith("po")) + resolverType = RESOLVER_TYPE.POPULATION; + else if(value.toLowerCase().startsWith("bayes")) + resolverType = RESOLVER_TYPE.BAYES_RULE; + else if(value.toLowerCase().startsWith("h")) + resolverType = RESOLVER_TYPE.HEURISTIC_TPP; + else if(value.toLowerCase().startsWith("c")) + resolverType = RESOLVER_TYPE.CONSTRUCTION_TPP; + else + resolverType = RESOLVER_TYPE.BASIC_MIN_DIST; + } + else if(option.getOpt().equals("rwf")) { + readWeightsFromFile = true; + } + break; + case 'g': + if(option.getOpt().equals("g")) + geoGazetteerFilename = value; + else if(option.getOpt().equals("gold")) + outputGoldLocations = true; + else if(option.getOpt().equals("gt")) + useGoldToponyms = true; + break; + case 's': + if(option.getOpt().equals("sg")) + serializedGazetteerPath = value; + else if(option.getOpt().equals("sci")) + serializedCorpusInputPath = value; + else if(option.getOpt().equals("sco")) + serializedCorpusOutputPath = value; + else if(option.getOpt().equals("spd")) + sentsPerDocument = Integer.parseInt(value); + //else if(option.getOpt().equals("sgci")) + // serializedGoldCorpusInputPath = value; + break; + case 'c': + if(value.toLowerCase().startsWith("t")) + corpusFormat = CORPUS_FORMAT.TRCONLL; + else if(value.toLowerCase().startsWith("g")) + corpusFormat = CORPUS_FORMAT.GEOTEXT; + else//if(value.toLowerCase().startsWith("p")) + corpusFormat = CORPUS_FORMAT.PLAIN; + break; + /*case 't': + readAsTR = true; + break;*/ + case 'l': + if(option.getOpt().equals("l")) + logFilePath = value; + break; + case 'k': + if(option.getOpt().equals("knn")) + knnForLP = Integer.parseInt(value); + break; + case 'm': + if(option.getOpt().equals("minlat")) + minLat = Double.parseDouble(value.replaceAll("n", "-")); + else if(option.getOpt().equals("maxlat")) + maxLat = Double.parseDouble(value.replaceAll("n", "-")); + else if(option.getOpt().equals("minlon")) + minLon = Double.parseDouble(value.replaceAll("n", "-")); + else if(option.getOpt().equals("maxlon")) + maxLon = Double.parseDouble(value.replaceAll("n", "-")); + break; + case 'n': + if(option.getOpt().equals("ner")) + setHighRecallNER(new Integer(value)!=0); + break; + case 'd': + if(option.getOpt().equals("dkm")) + doKMeans = true; + else if(option.getOpt().equals("dpc")) + dpc = Double.parseDouble(value); + break; + case 't': + threshold = Double.parseDouble(value); + break; + case 'p': + if(option.getOpt().equals("pc")) + popComponentCoefficient = Double.parseDouble(value); + else if(option.getOpt().equals("pme")) + meProbOnly = true; + else + dgProbOnly = true; + } + } + + if(minLat != null && maxLat != null && minLon != null && maxLon != null) + boundingBox = RectRegion.fromDegrees(minLat, maxLat, minLon, maxLon); + } + + public static void checkExists(String filename) throws Exception { + if(filename == null) { + System.out.println("Null filename; aborting."); + System.exit(0); + } + File f = new File(filename); + if(!f.exists()) { + System.out.println(filename + " doesn't exist; aborting."); + System.exit(0); + } + } + + public String getInputPath() { + return inputPath; + } + + public String getXMLInputPath() { + return xmlInputPath; + } + + public String getAdditionalInputPath() { + return additionalInputPath; + } + + public String getMaxentModelDirInputPath() { + return maxentModelDirInputPath; + } + + public String getGraphInputPath() { + return graphInputPath; + } + + public Enum getResolverType() { + return 
resolverType;
+    }
+
+    public int getNumIterations() {
+        return numIterations;
+    }
+
+    public boolean getReadWeightsFromFile() {
+        return readWeightsFromFile;
+    }
+
+    public String getOutputPath() {
+        return outputPath;
+    }
+
+    public String getKMLOutputPath() {
+        return kmlOutputPath;
+    }
+
+    public String getDKMLOutputPath() {
+        return dKmlOutputPath;
+    }
+
+    public boolean getOutputGoldLocations() {
+        return outputGoldLocations;
+    }
+
+    public boolean getUseGoldToponyms() {
+        return useGoldToponyms;
+    }
+
+    public boolean getOutputUserKML() {
+        return outputUserKML;
+    }
+
+    public String getGeoGazetteerFilename() {
+        return geoGazetteerFilename;
+    }
+
+    public String getSerializedGazetteerPath() {
+        return serializedGazetteerPath;
+    }
+
+    public String getSerializedCorpusInputPath() {
+        return serializedCorpusInputPath;
+    }
+
+    public String getSerializedCorpusOutputPath() {
+        return serializedCorpusOutputPath;
+    }
+
+
+    public Enum getCorpusFormat() {
+        return corpusFormat;
+    }
+
+    public int getSentsPerDocument() {
+        return sentsPerDocument;
+    }
+
+    public boolean isDoingKMeans() {
+        return doKMeans;
+    }
+
+    public String getGraphOutputPath() {
+        return graphOutputPath;
+    }
+
+    public String getSeedOutputPath() {
+        return seedOutputPath;
+    }
+
+    public String getWikiInputPath() {
+        return wikiInputPath;
+    }
+
+    public String getStoplistInputPath() {
+        return stoplistInputPath;
+    }
+
+    public String getLogFilePath() {
+        return logFilePath;
+    }
+
+    public int getKnnForLP() {
+        return knnForLP;
+    }
+
+    public Region getBoundingBox() {
+        return boundingBox;
+    }
+
+    public double getPopComponentCoefficient() {
+        return popComponentCoefficient;
+    }
+
+    public boolean getDGProbOnly() {
+        return dgProbOnly;
+    }
+
+    public boolean getMEProbOnly() {
+        return meProbOnly;
+    }
+
+    public boolean getDoOracleEval() {
+        return doOracleEval;
+    }
+
+    public double getDPC() {
+        return dpc;
+    }
+
+    public double getThreshold() {
+        return threshold;
+    }
+
+    public void setHighRecallNER(boolean highRecallNER) {
+        this.highRecallNER = highRecallNER;
+    }
+
+    public boolean isHighRecallNER() {
+        return highRecallNER;
+    }
+}
diff --git a/src/main/java/opennlp/fieldspring/tr/app/EvaluateCorpus.java b/src/main/java/opennlp/fieldspring/tr/app/EvaluateCorpus.java
new file mode 100644
index 0000000..1270b21
--- /dev/null
+++ b/src/main/java/opennlp/fieldspring/tr/app/EvaluateCorpus.java
@@ -0,0 +1,103 @@
+/* Evaluates a given corpus with system disambiguated toponyms against a given gold corpus.
+ */ + +package opennlp.fieldspring.tr.app; + +import opennlp.fieldspring.tr.resolver.*; +import opennlp.fieldspring.tr.text.*; +import opennlp.fieldspring.tr.text.io.*; +import opennlp.fieldspring.tr.text.prep.*; +import opennlp.fieldspring.tr.topo.gaz.*; +import opennlp.fieldspring.tr.eval.*; +import opennlp.fieldspring.tr.util.*; +import java.io.*; +import java.util.*; +import java.util.zip.*; + +public class EvaluateCorpus extends BaseApp { + + public static void main(String[] args) throws Exception { + + EvaluateCorpus currentRun = new EvaluateCorpus(); + currentRun.initializeOptionsFromCommandLine(args); + + if(currentRun.getCorpusFormat() == CORPUS_FORMAT.TRCONLL) { + if(currentRun.getInputPath() == null || (currentRun.getSerializedCorpusInputPath() == null && currentRun.getXMLInputPath() == null)) { + System.out.println("Please specify both a system annotated corpus file via the -sci or -ix flag and a gold plaintext corpus file via the -i flag."); + System.exit(0); + } + } + else { + if(currentRun.getSerializedCorpusInputPath() == null && currentRun.getXMLInputPath() == null) { + System.out.println("Please specify a system annotated corpus file via the -sci or -ix flag."); + System.exit(0); + } + } + + StoredCorpus systemCorpus; + if(currentRun.getSerializedCorpusInputPath() != null) { + System.out.print("Reading serialized system corpus from " + currentRun.getSerializedCorpusInputPath() + " ..."); + systemCorpus = TopoUtil.readStoredCorpusFromSerialized(currentRun.getSerializedCorpusInputPath()); + System.out.println("done."); + } + else {// if(getXMLInputPath() != null) { + Tokenizer tokenizer = new OpenNLPTokenizer(); + systemCorpus = Corpus.createStoredCorpus(); + systemCorpus.addSource(new CorpusXMLSource(new BufferedReader(new FileReader(currentRun.getXMLInputPath())), + tokenizer)); + systemCorpus.setFormat(currentRun.getCorpusFormat()==null?CORPUS_FORMAT.PLAIN:currentRun.getCorpusFormat()); + systemCorpus.load(); + } + + StoredCorpus goldCorpus = null; + + if(currentRun.getInputPath() != null && currentRun.getCorpusFormat() == CORPUS_FORMAT.TRCONLL) { + Tokenizer tokenizer = new OpenNLPTokenizer(); + System.out.print("Reading plaintext gold corpus from " + currentRun.getInputPath() + " ..."); + goldCorpus = Corpus.createStoredCorpus(); + goldCorpus.addSource(new TrXMLDirSource(new File(currentRun.getInputPath()), tokenizer)); + goldCorpus.load(); + System.out.println("done."); + } + + currentRun.doEval(systemCorpus, goldCorpus, currentRun.getCorpusFormat(), currentRun.getUseGoldToponyms(), currentRun.getDoOracleEval()); + } + + public void doEval(Corpus systemCorpus, Corpus goldCorpus, Enum corpusFormat, boolean useGoldToponyms) throws Exception { + this.doEval(systemCorpus, goldCorpus, corpusFormat, useGoldToponyms, false); + } + + public void doEval(Corpus systemCorpus, Corpus goldCorpus, Enum corpusFormat, boolean useGoldToponyms, boolean doOracleEval) throws Exception { + System.out.print("\nEvaluating..."); + if(corpusFormat == CORPUS_FORMAT.GEOTEXT) { + DocDistanceEvaluator evaluator = new DocDistanceEvaluator(systemCorpus); + DistanceReport dreport = evaluator.evaluate(); + + System.out.println("\nMinimum error distance (km): " + dreport.getMinDistance()); + System.out.println("Maximum error distance (km): " + dreport.getMaxDistance()); + System.out.println("\nMean error distance (km): " + dreport.getMeanDistance()); + System.out.println("Median error distance (km): " + dreport.getMedianDistance()); + System.out.println("Fraction of distances within 161 km: " + 
dreport.getFractionDistancesWithinThreshold(161.0)); + System.out.println("\nTotal documents evaluated: " + dreport.getNumDistances()); + } + + else { + SignatureEvaluator evaluator = new SignatureEvaluator(goldCorpus, doOracleEval); + Report report = evaluator.evaluate(systemCorpus, false); + DistanceReport dreport = evaluator.getDistanceReport(); + + System.out.println("\nP: " + report.getPrecision()); + System.out.println("R: " + report.getRecall()); + System.out.println("F: " + report.getFScore()); + //System.out.println("A: " + report.getAccuracy()); + + System.out.println("\nMinimum error distance (km): " + dreport.getMinDistance()); + System.out.println("Maximum error distance (km): " + dreport.getMaxDistance()); + System.out.println("\nMean error distance (km): " + dreport.getMeanDistance()); + System.out.println("Median error distance (km): " + dreport.getMedianDistance()); + System.out.println("Fraction of distances within 161 km: " + dreport.getFractionDistancesWithinThreshold(161.0)); + System.out.println("\nTotal toponyms evaluated: " + dreport.getNumDistances()); + } + } + +} diff --git a/src/main/java/opennlp/fieldspring/tr/app/ImportCorpus.java b/src/main/java/opennlp/fieldspring/tr/app/ImportCorpus.java new file mode 100644 index 0000000..316a246 --- /dev/null +++ b/src/main/java/opennlp/fieldspring/tr/app/ImportCorpus.java @@ -0,0 +1,147 @@ +/* This class takes a serialized gazetteer and a corpus, and outputs a preprocessed, serialized version of that corpus, ready to be read in quickly by RunResolver. + */ + +package opennlp.fieldspring.tr.app; + +import opennlp.fieldspring.tr.text.*; +import opennlp.fieldspring.tr.text.io.*; +import opennlp.fieldspring.tr.text.prep.*; +import opennlp.fieldspring.tr.topo.gaz.*; +import java.io.*; +import java.util.zip.*; + +public class ImportCorpus extends BaseApp { + + //private static int sentsPerDocument; + + public static void main(String[] args) throws Exception { + + ImportCorpus currentRun = new ImportCorpus(); + + currentRun.initializeOptionsFromCommandLine(args); + //sentsPerDocument = currentRun.getSentsPerDocument(); + + if(currentRun.getSerializedCorpusOutputPath() == null && currentRun.getOutputPath() == null) { + System.out.println("Please specify a serialized corpus output file with the -sco flag and/or an XML output file with the -o flag."); + System.exit(0); + } + + StoredCorpus corpus = currentRun.doImport(currentRun.getInputPath(), currentRun.getSerializedGazetteerPath(), currentRun.getCorpusFormat(), currentRun.getUseGoldToponyms(), currentRun.getSentsPerDocument()); + + if(currentRun.getSerializedCorpusOutputPath() != null) + currentRun.serialize(corpus, currentRun.getSerializedCorpusOutputPath()); + if(currentRun.getOutputPath() != null) + currentRun.writeToXML(corpus, currentRun.getOutputPath()); + } + + public StoredCorpus doImport(String corpusInputPath, String serGazInputPath, + Enum corpusFormat) throws Exception { + return doImport(corpusInputPath, serGazInputPath, corpusFormat, false, -1); + } + + public StoredCorpus doImport(String corpusInputPath, String serGazInputPath, + Enum corpusFormat, + boolean useGoldToponyms, int sentsPerDocument) throws Exception { + + checkExists(corpusInputPath); + if(!useGoldToponyms || doKMeans) + checkExists(serGazInputPath); + + Tokenizer tokenizer = new OpenNLPTokenizer(); + OpenNLPRecognizer recognizer = new OpenNLPRecognizer(); + + GeoNamesGazetteer gnGaz = null; + System.out.println("Reading serialized GeoNames gazetteer from " + serGazInputPath + " ..."); + 
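// The serialized gazetteer may be gzip-compressed; choose the input stream based on the file extension.
+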
ObjectInputStream ois = null; + if(serGazInputPath.toLowerCase().endsWith(".gz")) { + GZIPInputStream gis = new GZIPInputStream(new FileInputStream(serGazInputPath)); + ois = new ObjectInputStream(gis); + } + else { + FileInputStream fis = new FileInputStream(serGazInputPath); + ois = new ObjectInputStream(fis); + } + gnGaz = (GeoNamesGazetteer) ois.readObject(); + if(isHighRecallNER()) + recognizer = new HighRecallToponymRecognizer(gnGaz.getUniqueLocationNameSet()); + System.out.println("Done."); + + System.out.print("Reading raw corpus from " + corpusInputPath + " ..."); + StoredCorpus corpus = Corpus.createStoredCorpus(); + if(corpusFormat == CORPUS_FORMAT.TRCONLL) { + File corpusInputFile = new File(corpusInputPath); + if(useGoldToponyms) { + if(corpusInputFile.isDirectory()) + corpus.addSource(new CandidateRepopulator(new TrXMLDirSource(new File(corpusInputPath), tokenizer, sentsPerDocument), gnGaz)); + else + corpus.addSource(new CandidateRepopulator(new TrXMLSource(new BufferedReader(new FileReader(corpusInputPath)), tokenizer, sentsPerDocument), gnGaz)); + } + else { + if(corpusInputFile.isDirectory()) + corpus.addSource(new ToponymAnnotator( + new ToponymRemover(new TrXMLDirSource(new File(corpusInputPath), tokenizer, sentsPerDocument)), + recognizer, gnGaz, null)); + else + corpus.addSource(new ToponymAnnotator( + new ToponymRemover(new TrXMLSource(new BufferedReader(new FileReader(corpusInputPath)), tokenizer, sentsPerDocument)), + recognizer, gnGaz, null)); + } + } + else if(corpusFormat == CORPUS_FORMAT.GEOTEXT) { + corpus.addSource(new ToponymAnnotator(new GeoTextSource( + new BufferedReader(new FileReader(corpusInputPath)), tokenizer), + recognizer, gnGaz, null)); + } + else if (corpusInputPath.endsWith("txt")) { + corpus.addSource(new ToponymAnnotator(new PlainTextSource( + new BufferedReader(new FileReader(corpusInputPath)), new OpenNLPSentenceDivider(), tokenizer, corpusInputPath), + recognizer, gnGaz, null)); + } + else { + corpus.addSource(new ToponymAnnotator(new PlainTextDirSource( + new File(corpusInputPath), new OpenNLPSentenceDivider(), tokenizer), + recognizer, gnGaz, null)); + } + corpus.setFormat(corpusFormat); + //if(corpusFormat != CORPUS_FORMAT.GEOTEXT) + corpus.load(); + System.out.println("done."); + + System.out.println("\nNumber of documents: " + corpus.getDocumentCount()); + System.out.println("Number of word tokens: " + corpus.getTokenCount()); + System.out.println("Number of word types: " + corpus.getTokenTypeCount()); + System.out.println("Number of toponym tokens: " + corpus.getToponymTokenCount()); + System.out.println("Number of toponym types: " + corpus.getToponymTypeCount()); + System.out.println("Average ambiguity (locations per toponym): " + corpus.getAvgToponymAmbiguity()); + System.out.println("Maximum ambiguity (locations per toponym): " + corpus.getMaxToponymAmbiguity()); + + return corpus; + } + + public void serialize(Corpus corpus, String serializedCorpusPath) throws Exception { + + System.out.print("\nSerializing corpus to " + serializedCorpusPath + " ..."); + + ObjectOutputStream oos = null; + if(serializedCorpusPath.toLowerCase().endsWith(".gz")) { + GZIPOutputStream gos = new GZIPOutputStream(new FileOutputStream(serializedCorpusPath)); + oos = new ObjectOutputStream(gos); + } + else { + FileOutputStream fos = new FileOutputStream(serializedCorpusPath); + oos = new ObjectOutputStream(fos); + } + oos.writeObject(corpus); + oos.close(); + + System.out.println("done."); + } + + public void writeToXML(Corpus corpus, String xmlOutputPath) 
throws Exception { + System.out.print("\nWriting corpus in XML format to " + xmlOutputPath + " ..."); + CorpusXMLWriter w = new CorpusXMLWriter(corpus); + w.write(new File(xmlOutputPath)); + System.out.println("done."); + } + +} diff --git a/src/main/java/opennlp/fieldspring/tr/app/ImportGazetteer.java b/src/main/java/opennlp/fieldspring/tr/app/ImportGazetteer.java new file mode 100644 index 0000000..9f11335 --- /dev/null +++ b/src/main/java/opennlp/fieldspring/tr/app/ImportGazetteer.java @@ -0,0 +1,59 @@ +/* + * This class imports a gazetteer from a text file and serializes it, to be read quickly by RunResolver quickly. + */ + +package opennlp.fieldspring.tr.app; + +import opennlp.fieldspring.tr.topo.gaz.*; +import opennlp.fieldspring.tr.util.*; +import java.io.*; +import java.util.zip.*; + +public class ImportGazetteer extends BaseApp { + + public static void main(String[] args) throws Exception { + ImportGazetteer currentRun = new ImportGazetteer(); + currentRun.initializeOptionsFromCommandLine(args); + currentRun.serialize(currentRun.doImport(currentRun.getInputPath(), currentRun.isDoingKMeans()), currentRun.getOutputPath()); + } + + public GeoNamesGazetteer doImport(String gazInputPath, boolean runKMeans) throws Exception { + System.out.println("Reading GeoNames gazetteer from " + gazInputPath + " ..."); + + checkExists(gazInputPath); + + GeoNamesGazetteer gnGaz = null; + if(gazInputPath.toLowerCase().endsWith(".zip")) { + ZipFile zf = new ZipFile(gazInputPath); + ZipInputStream zis = new ZipInputStream(new FileInputStream(gazInputPath)); + ZipEntry ze = zis.getNextEntry(); + gnGaz = new GeoNamesGazetteer(new BufferedReader(new InputStreamReader(zf.getInputStream(ze))), runKMeans); + zis.close(); + } + else { + gnGaz = new GeoNamesGazetteer(new BufferedReader(new FileReader(gazInputPath)), runKMeans); + } + + System.out.println("Done."); + + return gnGaz; + } + + public void serialize(GeoNamesGazetteer gnGaz, String serializedGazOutputPath) throws Exception { + System.out.print("Serializing GeoNames gazetteer to " + serializedGazOutputPath + " ..."); + + ObjectOutputStream oos = null; + if(serializedGazOutputPath.toLowerCase().endsWith(".gz")) { + GZIPOutputStream gos = new GZIPOutputStream(new FileOutputStream(serializedGazOutputPath)); + oos = new ObjectOutputStream(gos); + } + else { + FileOutputStream fos = new FileOutputStream(serializedGazOutputPath); + oos = new ObjectOutputStream(fos); + } + oos.writeObject(gnGaz); + oos.close(); + + System.out.println("done."); + } +} diff --git a/src/main/java/opennlp/fieldspring/tr/app/LabelPropPreproc.java b/src/main/java/opennlp/fieldspring/tr/app/LabelPropPreproc.java new file mode 100644 index 0000000..bea0b9c --- /dev/null +++ b/src/main/java/opennlp/fieldspring/tr/app/LabelPropPreproc.java @@ -0,0 +1,215 @@ +package opennlp.fieldspring.tr.app; + +import opennlp.fieldspring.tr.text.*; +import opennlp.fieldspring.tr.text.io.*; +import opennlp.fieldspring.tr.text.prep.*; +import opennlp.fieldspring.tr.topo.*; +import opennlp.fieldspring.tr.topo.gaz.*; +import opennlp.fieldspring.tr.util.*; +import java.io.*; +import java.util.*; +import java.util.zip.*; + +public class LabelPropPreproc extends BaseApp { + + + private static final double DPC = 1.0; // degrees per cell + + // the public constants are used by LabelPropComplexResolver // + private static final String CELL_ = "cell_"; + public static final String CELL_LABEL_ = "cell_label_"; + private static final String LOC_ = "loc_"; + private static final String TPNM_TYPE_ = "tpnm_type_"; 
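+ // These prefixes name the nodes of the label propagation graph written out below: grid cells and
+ // their label nodes, gazetteer locations, toponym types, documents, per-document types, and toponym
+ // tokens. For example, token 7 of a document with id "d42" containing the toponym "London" yields
+ // the token node "doc_d42_tok_7", linked to the per-document type node "doc_d42_type_London" (the
+ // id and form are hypothetical; the naming follows the string concatenations in main below).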
+ public static final String DOC_ = "doc_"; + private static final String TYPE_ = "type_"; + public static final String TOK_ = "tok_"; + + + + public static void main(String[] args) throws Exception { + + LabelPropPreproc currentRun = new LabelPropPreproc(); + + currentRun.initializeOptionsFromCommandLine(args); + StoredCorpus corpus = currentRun.loadCorpus(currentRun.getInputPath(), currentRun.getSerializedGazetteerPath(), currentRun.getSerializedCorpusInputPath(), currentRun.getCorpusFormat()); + + Map > locationCellEdges = new HashMap >(); + Set uniqueToponyms = new HashSet(); + Map > docToponyms = new HashMap >(); + Map docTokenToDocTypeEdges = new HashMap(); + Map toponymTokenToDocEdges = new HashMap(); + Map linearTopTokToTopTokEdges = new HashMap(); + + for(Document doc : corpus) { + docToponyms.put(doc.getId(), new HashSet()); + int tokenIndex = 0; + String prevTokenString = null; + for(Sentence sent : doc) { + for(Toponym toponym : sent.getToponyms()) { + if(toponym.getAmbiguity() > 0) { + uniqueToponyms.add(toponym); + docToponyms.get(doc.getId()).add(toponym); + String curTokenString = DOC_ + doc.getId() + "_" + TOK_ + tokenIndex; + docTokenToDocTypeEdges.put(curTokenString, + DOC_ + doc.getId() + "_" + TYPE_ + toponym.getForm()); + toponymTokenToDocEdges.put(curTokenString, DOC_ + doc.getId()); + if(prevTokenString != null) + linearTopTokToTopTokEdges.put(prevTokenString, curTokenString); + prevTokenString = curTokenString; + for(Location location : toponym.getCandidates()) { + int locationID = location.getId(); + Set curLocationCellEdges = locationCellEdges.get(locationID); + if(curLocationCellEdges != null) + continue; // already processed this location + curLocationCellEdges = new HashSet(); + for(int cellNumber : TopoUtil.getCellNumbers(location, DPC)) { + curLocationCellEdges.add(cellNumber); + } + locationCellEdges.put(locationID, curLocationCellEdges); + } + } + } + tokenIndex++; + } + } + + currentRun.writeCellSeeds(locationCellEdges, currentRun.getSeedOutputPath()); + + currentRun.writeCellCellEdges(currentRun.getGraphOutputPath()); + currentRun.writeLocationCellEdges(locationCellEdges, currentRun.getGraphOutputPath()); + currentRun.writeToponymTypeLocationEdges(uniqueToponyms, currentRun.getGraphOutputPath()); + currentRun.writeDocTypeToponymTypeEdges(docToponyms, currentRun.getGraphOutputPath()); + currentRun.writeStringStringEdges(docTokenToDocTypeEdges, currentRun.getGraphOutputPath()); + currentRun.writeStringStringEdges(toponymTokenToDocEdges, currentRun.getGraphOutputPath()); + currentRun.writeStringStringEdges(linearTopTokToTopTokEdges, currentRun.getGraphOutputPath()); + } + + private void writeCellSeeds(Map > locationCellEdges, String seedOutputPath) throws Exception { + BufferedWriter out = new BufferedWriter(new FileWriter(seedOutputPath)); + + Set uniqueCellNumbers = new HashSet(); + + for(int locationID : locationCellEdges.keySet()) { + Set curLocationCellEdges = locationCellEdges.get(locationID); + uniqueCellNumbers.addAll(curLocationCellEdges); + } + + for(int cellNumber : uniqueCellNumbers) { + writeEdge(out, CELL_ + cellNumber, CELL_LABEL_ + cellNumber, 1.0); + } + + out.close(); + } + + private void writeCellCellEdges(String graphOutputPath) throws Exception { + BufferedWriter out = new BufferedWriter(new FileWriter(graphOutputPath)); + + for(int lon = 0; lon < 360 / DPC; lon += DPC) { + for(int lat = 0; lat < 180 / DPC; lat += DPC) { + int curCellNumber = TopoUtil.getCellNumber(lat, lon, DPC); + int leftCellNumber = TopoUtil.getCellNumber(lat, lon - 
DPC, DPC); + int rightCellNumber = TopoUtil.getCellNumber(lat, lon + DPC, DPC); + int topCellNumber = TopoUtil.getCellNumber(lat + DPC, lon, DPC); + int bottomCellNumber = TopoUtil.getCellNumber(lat - DPC, lon, DPC); + + writeEdge(out, CELL_ + curCellNumber, CELL_ + leftCellNumber, 1.0); + writeEdge(out, CELL_ + curCellNumber, CELL_ + rightCellNumber, 1.0); + if(topCellNumber >= 0) + writeEdge(out, CELL_ + curCellNumber, CELL_ + topCellNumber, 1.0); + if(bottomCellNumber >= 0) + writeEdge(out, CELL_ + curCellNumber, CELL_ + bottomCellNumber, 1.0); + } + } + + out.close(); + } + + private void writeLocationCellEdges(Map > locationCellEdges, String graphOutputPath) throws Exception { + BufferedWriter out = new BufferedWriter(new FileWriter(graphOutputPath, true)); + + for(int locationID : locationCellEdges.keySet()) { + Set curLocationCellEdges = locationCellEdges.get(locationID); + for(int cellNumber : curLocationCellEdges) { + writeEdge(out, LOC_ + locationID, CELL_ + cellNumber, 1.0); + } + //if(curLocationCellEdges.size() > 1) + // System.out.println("Wrote " + curLocationCellEdges.size() + " edges for location " + locationID); + } + + out.close(); + } + + private void writeToponymTypeLocationEdges(Set uniqueToponyms, String graphOutputPath) throws Exception { + BufferedWriter out = new BufferedWriter(new FileWriter(graphOutputPath, true)); + + Set toponymNamesAlreadyWritten = new HashSet(); + + for(Toponym toponym : uniqueToponyms) { + if(!toponymNamesAlreadyWritten.contains(toponym.getForm())) { + for(Location location : toponym.getCandidates()) { + writeEdge(out, TPNM_TYPE_ + toponym.getForm(), LOC_ + location.getId(), 1.0); + } + toponymNamesAlreadyWritten.add(toponym.getForm()); + } + } + + out.close(); + } + + private void writeDocTypeToponymTypeEdges(Map > docToponyms, String graphOutputPath) throws Exception { + BufferedWriter out = new BufferedWriter(new FileWriter(graphOutputPath, true)); + + Set docTypesAlreadyWritten = new HashSet(); + + for(String docId : docToponyms.keySet()) { + for(Toponym toponym : docToponyms.get(docId)) { + String docType = DOC_ + docId + "_" + TYPE_ + toponym.getForm(); + if(!docTypesAlreadyWritten.contains(docType)) { + writeEdge(out, docType, TPNM_TYPE_ + toponym.getForm(), 1.0); + docTypesAlreadyWritten.add(docType); + } + } + } + + out.close(); + } + + private void writeStringStringEdges(Map edgeMap, String graphOutputPath) throws Exception { + BufferedWriter out = new BufferedWriter(new FileWriter(graphOutputPath, true)); + + for(String key : edgeMap.keySet()) { + writeEdge(out, key, edgeMap.get(key), 1.0); + } + + out.close(); + } + + private void writeEdge(BufferedWriter out, String node1, String node2, double weight) throws Exception { + out.write(node1 + "\t" + node2 + "\t" + weight + "\n"); + } + + private StoredCorpus loadCorpus(String corpusInputPath, String serGazPath, String serCorpusPath, Enum corpusFormat) throws Exception { + + StoredCorpus corpus; + if(serCorpusPath != null) { + System.out.print("Reading serialized corpus from " + serCorpusPath + " ..."); + ObjectInputStream ois = null; + if(serCorpusPath.toLowerCase().endsWith(".gz")) { + GZIPInputStream gis = new GZIPInputStream(new FileInputStream(serCorpusPath)); + ois = new ObjectInputStream(gis); + } + else { + FileInputStream fis = new FileInputStream(serCorpusPath); + ois = new ObjectInputStream(fis); + } + corpus = (StoredCorpus) ois.readObject(); + System.out.println("done."); + } + else { + ImportCorpus importCorpus = new ImportCorpus(); + corpus = 
importCorpus.doImport(corpusInputPath, serGazPath, corpusFormat); + } + + return corpus; + } +} diff --git a/src/main/java/opennlp/fieldspring/tr/app/LabelPropPreprocOld.java b/src/main/java/opennlp/fieldspring/tr/app/LabelPropPreprocOld.java new file mode 100644 index 0000000..ad1385f --- /dev/null +++ b/src/main/java/opennlp/fieldspring/tr/app/LabelPropPreprocOld.java @@ -0,0 +1,276 @@ +package opennlp.fieldspring.tr.app; + +import opennlp.fieldspring.tr.text.*; +import opennlp.fieldspring.tr.text.io.*; +import opennlp.fieldspring.tr.text.prep.*; +import opennlp.fieldspring.tr.topo.*; +import opennlp.fieldspring.tr.topo.gaz.*; +import opennlp.fieldspring.tr.util.*; +import java.io.*; +import java.util.*; +import java.util.zip.*; + +public class LabelPropPreprocOld extends BaseApp { + + /* + + private static final int DEGREES_PER_REGION = 1; + private static final double REGION_REGION_WEIGHT = 0.9; + + private static final double WORD_WORD_WEIGHT_THRESHOLD = 0.0; + + private static final int IN_DOC_COUNT_THRESHOLD = 5; + private static final int MATRIX_COUNT_THRESHOLD = 2; + + private static final String ARTICLE_TITLE = "Article title: "; + private static final String ARTICLE_ID = "Article ID: "; + + //private static int toponymLexiconSize; + + public static void main(String[] args) throws Exception { + + initializeOptionsFromCommandLine(args); + + Tokenizer tokenizer = new OpenNLPTokenizer(); + OpenNLPRecognizer recognizer = new OpenNLPRecognizer(); + + System.out.println("Reading serialized GeoNames gazetteer from " + getSerializedGazetteerPath() + " ..."); + GZIPInputStream gis = new GZIPInputStream(new FileInputStream(getSerializedGazetteerPath())); + ObjectInputStream ois = new ObjectInputStream(gis); + GeoNamesGazetteer gnGaz = (GeoNamesGazetteer) ois.readObject(); + System.out.println("Done."); + + StoredCorpus corpus = Corpus.createStoredCorpus(); + System.out.print("Reading TR-CoNLL corpus from " + getInputPath() + " ..."); + //corpus.addSource(new TrXMLDirSource(new File(getInputPath()), tokenizer)); + corpus.addSource(new ToponymAnnotator(new ToponymRemover(new TrXMLDirSource(new File(getInputPath()), tokenizer)), recognizer, gnGaz)); + corpus.load(); + System.out.println("done."); + + Map > toponymRegionEdges = new HashMap >(); + + Lexicon toponymLexicon = TopoUtil.buildLexicon(corpus); + //toponymLexiconSize = toponymLexicon.size(); + + for(Document doc : corpus) { + for(Sentence sent : doc) { + for(Toponym toponym : sent.getToponyms()) { + if(toponym.getAmbiguity() > 0) { + int idx = toponymLexicon.get(toponym.getForm()); + Set regionSet = toponymRegionEdges.get(idx); + if(regionSet == null) { + regionSet = new HashSet(); + for(Location location : toponym.getCandidates()) { + //int regionNumber = TopoUtil.getCellNumbers(location, DEGREES_PER_REGION); + //regionSet.add(regionNumber); + regionSet.addAll(TopoUtil.getCellNumbers(location, DEGREES_PER_REGION)); + } + toponymRegionEdges.put(idx, regionSet); + } + } + } + } + } + + writeToponymRegionEdges(toponymRegionEdges, getGraphOutputPath()); + writeRegionRegionEdges(getGraphOutputPath()); + //writeWordWordEdges(toponymLexicon, getWikiInputPath(), getGraphOutputPath(), getStoplistInputPath()); + + writeRegionLabels(toponymRegionEdges, getSeedOutputPath()); + } + + private static Set buildStoplist(String stoplistFilename) throws Exception { + Set stoplist = new HashSet(); + + BufferedReader in = new BufferedReader(new FileReader(stoplistFilename)); + + String curLine; + while(true) { + curLine = in.readLine(); + if(curLine == 
null) + break; + + if(curLine.length() > 0) + stoplist.add(curLine.toLowerCase()); + } + + in.close(); + + return stoplist; + } + + private static void writeWordWordEdges(Lexicon lexicon, String wikiFilename, + String outputFilename, String stoplistFilename) throws Exception { + + Set stoplist = buildStoplist(stoplistFilename); + + int docCount = 0; + Map > countMatrix = new HashMap >(); + + BufferedReader wikiIn = new BufferedReader(new FileReader(wikiFilename)); + + boolean skip = true; + + String curLine; + String articleTitle = null; + Set wordsInDoc = null; + while(true) { + curLine = wikiIn.readLine(); + if(curLine == null) + break; + + if(curLine.startsWith(ARTICLE_ID)) { + System.err.println(curLine + (skip?" skipped":"")); + continue; + } + + if(curLine.startsWith(ARTICLE_TITLE)) { + if(wordsInDoc != null && wordsInDoc.size() > 0) { + for(Integer i1 : wordsInDoc) { + Map curMap = countMatrix.get(i1); + if(curMap == null) { + curMap = new HashMap(); + } + for(Integer i2 : wordsInDoc) { + Integer curCount = curMap.get(i2); + if(curCount == null) { + curCount = 0; + } + curMap.put(i2, curCount + 1); + } + countMatrix.put(i1, curMap); + } + + docCount++; + } + + articleTitle = curLine.substring(ARTICLE_TITLE.length()).trim().toLowerCase(); + if(lexicon.contains(articleTitle)) { + skip = false; + wordsInDoc = new HashSet(); + } + else + skip = true; + } + + else if(!skip) { + //System.err.println(curLine); + + String[] tokens = curLine.split(" "); + String word = tokens[0].toLowerCase(); + + if(!stoplist.contains(word) && Integer.parseInt(tokens[2]) >= IN_DOC_COUNT_THRESHOLD) { + wordsInDoc.add(lexicon.getOrAdd(word)); + } + //System.err.println(curLine); + } + + } + + //System.err.println(countMatrix.get(17).get(17)); + + wikiIn.close(); + + ObjectOutputStream oos = new ObjectOutputStream(new FileOutputStream("lexicon.ser")); + oos.writeObject(lexicon); + oos.close(); + + BufferedWriter out = new BufferedWriter(new FileWriter(outputFilename, true)); + + for(int i1 : countMatrix.keySet()) { + Map innerMap = countMatrix.get(i1); + double i1count = innerMap.get(i1); + if(i1count < MATRIX_COUNT_THRESHOLD) continue; + for(int i2 : innerMap.keySet()) { + if(i1 != i2) { + + double i2count = countMatrix.get(i2).get(i2); + double i1i2count = innerMap.get(i2); + + if(i1i2count < MATRIX_COUNT_THRESHOLD) continue; + + double probi1 = i1count / docCount; + double probi2 = i2count / docCount; + double probi1i2 = i1i2count / docCount; + + /*System.err.println(i1); + System.err.println(i2); + System.err.println(docCount); + + System.err.println(probi1); + System.err.println(probi2); + System.err.println(probi1i2);*SLASH + + double wordWordWeight = Math.log(probi1i2 / (probi1 * probi2)); + + /*System.err.println(pmi); + System.err.println("---");*SLASH + + + if(wordWordWeight > WORD_WORD_WEIGHT_THRESHOLD) + out.write(i1 + "\t" + i2 + "\t" + wordWordWeight + "\n"); + } + } + } + + out.close(); + } + + private static void writeToponymRegionEdges(Map > toponymRegionEdges, String filename) throws Exception { + BufferedWriter out = new BufferedWriter(new FileWriter(filename)); + + for(int idx : toponymRegionEdges.keySet()) { + int size = toponymRegionEdges.get(idx).size(); + double weight = 1.0/size; + for(int regionNumber : toponymRegionEdges.get(idx)) { + out.write(idx + "\t" + regionNumber + "R\t" + weight + "\n"); + out.write(regionNumber + "R\t" + idx + "\t1.0\n"); + } + } + + out.close(); + } + + private static void writeRegionRegionEdges(String filename) throws Exception { + BufferedWriter out = new 
BufferedWriter(new FileWriter(filename, true)); + + for(int lon = 0; lon < 360 / DEGREES_PER_REGION; lon += DEGREES_PER_REGION) { + for(int lat = 0; lat < 180 / DEGREES_PER_REGION; lat += DEGREES_PER_REGION) { + int curRegionNumber = TopoUtil.getCellNumber(lat, lon, DEGREES_PER_REGION); + int leftRegionNumber = TopoUtil.getCellNumber(lat, lon - DEGREES_PER_REGION, DEGREES_PER_REGION); + int rightRegionNumber = TopoUtil.getCellNumber(lat, lon + DEGREES_PER_REGION, DEGREES_PER_REGION); + int topRegionNumber = TopoUtil.getCellNumber(lat + DEGREES_PER_REGION, lon, DEGREES_PER_REGION); + int bottomRegionNumber = TopoUtil.getCellNumber(lat - DEGREES_PER_REGION, lon, DEGREES_PER_REGION); + + out.write(curRegionNumber + "R\t" + leftRegionNumber + "R\t" + REGION_REGION_WEIGHT + "\n"); + out.write(curRegionNumber + "R\t" + rightRegionNumber + "R\t" + REGION_REGION_WEIGHT + "\n"); + if(topRegionNumber >= 0) + out.write(curRegionNumber + "R\t" + topRegionNumber + "R\t" + REGION_REGION_WEIGHT + "\n"); + if(bottomRegionNumber >= 0) + out.write(curRegionNumber + "R\t" + bottomRegionNumber + "R\t" + REGION_REGION_WEIGHT + "\n"); + } + } + + out.close(); + } + + private static void writeRegionLabels(Map > toponymRegionEdges, String filename) throws Exception { + BufferedWriter out = new BufferedWriter(new FileWriter(filename)); + + Set uniqueRegionNumbers = new HashSet(); + + for(int idx : toponymRegionEdges.keySet()) { + for(int regionNumber : toponymRegionEdges.get(idx)) { + uniqueRegionNumbers.add(regionNumber); + } + } + + for(int regionNumber : uniqueRegionNumbers) { + out.write(regionNumber + "R\t" + regionNumber + "L\t1.0\n"); + } + + out.close(); + } + +*/ +} diff --git a/src/main/java/opennlp/fieldspring/tr/app/RunResolver.java b/src/main/java/opennlp/fieldspring/tr/app/RunResolver.java new file mode 100644 index 0000000..bf19523 --- /dev/null +++ b/src/main/java/opennlp/fieldspring/tr/app/RunResolver.java @@ -0,0 +1,204 @@ +/* + * This class runs the resolvers in opennlp.fieldspring.tr.resolver + */ + +package opennlp.fieldspring.tr.app; + +import opennlp.fieldspring.tr.resolver.*; +import opennlp.fieldspring.tr.text.*; +import opennlp.fieldspring.tr.text.io.*; +import opennlp.fieldspring.tr.text.prep.*; +import opennlp.fieldspring.tr.topo.gaz.*; +import opennlp.fieldspring.tr.eval.*; +import opennlp.fieldspring.tr.util.*; +import java.io.*; +import java.util.*; +import java.util.zip.*; + +public class RunResolver extends BaseApp { + + public static void main(String[] args) throws Exception { + + long startTime = System.currentTimeMillis(); + + RunResolver currentRun = new RunResolver(); + currentRun.initializeOptionsFromCommandLine(args); + + if(currentRun.getSerializedGazetteerPath() == null && currentRun.getSerializedCorpusInputPath() == null) { + System.out.println("Abort: you must specify a path to a serialized gazetteer or corpus. 
To generate one, run ImportGazetteer and/or ImportCorpus."); + System.exit(0); + } + + Tokenizer tokenizer = new OpenNLPTokenizer(); + OpenNLPRecognizer recognizer = new OpenNLPRecognizer(); + + StoredCorpus goldCorpus = null; + if(currentRun.getCorpusFormat() == CORPUS_FORMAT.TRCONLL) { + System.out.print("Reading gold corpus from " + currentRun.getInputPath() + " ..."); + goldCorpus = Corpus.createStoredCorpus(); + File goldFile = new File(currentRun.getInputPath()); + if(goldFile.isDirectory()) + goldCorpus.addSource(new TrXMLDirSource(goldFile, tokenizer)); + else + goldCorpus.addSource(new TrXMLSource(new BufferedReader(new FileReader(goldFile)), tokenizer)); + goldCorpus.setFormat(CORPUS_FORMAT.TRCONLL); + goldCorpus.load(); + System.out.println("done."); + } + + StoredCorpus testCorpus; + if(currentRun.getSerializedCorpusInputPath() != null) { + System.out.print("Reading serialized corpus from " + currentRun.getSerializedCorpusInputPath() + " ..."); + testCorpus = TopoUtil.readStoredCorpusFromSerialized(currentRun.getSerializedCorpusInputPath()); + System.out.println("done."); + } + else { + ImportCorpus importCorpus = new ImportCorpus(); + testCorpus = importCorpus.doImport(currentRun.getInputPath(), currentRun.getSerializedGazetteerPath(), currentRun.getCorpusFormat(), currentRun.getUseGoldToponyms(), currentRun.getSentsPerDocument()); + } + + StoredCorpus trainCorpus = Corpus.createStoredCorpus(); + if(currentRun.getAdditionalInputPath() != null) { + System.out.print("Reading additional training corpus from " + currentRun.getAdditionalInputPath() + " ..."); + List gazList = new ArrayList(); + LoadableGazetteer trGaz = new InMemoryGazetteer(); + trGaz.load(new CorpusGazetteerReader(testCorpus)); + LoadableGazetteer otherGaz = new InMemoryGazetteer(); + otherGaz.load(new WorldReader(new File(Constants.getGazetteersDir() + File.separator + "dataen-fixed.txt.gz"))); + gazList.add(trGaz); + gazList.add(otherGaz); + Gazetteer multiGaz = new MultiGazetteer(gazList); + /*trainCorpus.addSource(new ToponymAnnotator(new PlainTextSource( + new BufferedReader(new FileReader(currentRun.getAdditionalInputPath())), new OpenNLPSentenceDivider(), tokenizer), + recognizer, + multiGaz));*/ + trainCorpus.addSource(new ToponymAnnotator(new GigawordSource( + new BufferedReader(new InputStreamReader( + new GZIPInputStream(new FileInputStream(currentRun.getAdditionalInputPath())))), 10, 40000), + recognizer, + multiGaz)); + trainCorpus.addSource(new TrXMLDirSource(new File(currentRun.getInputPath()), tokenizer)); + trainCorpus.setFormat(currentRun.getCorpusFormat()); + trainCorpus.load(); + System.out.println("done."); + } + + long endTime = System.currentTimeMillis(); + float seconds = (endTime - startTime) / 1000F; + System.out.println("\nInitialization took " + Float.toString(seconds/(float)60.0) + " minutes."); + + Resolver resolver; + if(currentRun.getResolverType() == RESOLVER_TYPE.RANDOM) { + System.out.print("Running RANDOM resolver..."); + resolver = new RandomResolver(); + } + else if(currentRun.getResolverType() == RESOLVER_TYPE.POPULATION) { + System.out.print("Running POPULATION resolver..."); + resolver = new PopulationResolver(); + } + else if(currentRun.getResolverType() == RESOLVER_TYPE.WEIGHTED_MIN_DIST) { + System.out.println("Running WEIGHTED MINIMUM DISTANCE resolver with " + currentRun.getNumIterations() + " iteration(s)..."); + resolver = new WeightedMinDistResolver(currentRun.getNumIterations(), currentRun.getReadWeightsFromFile(), + currentRun.getLogFilePath()); + } + else 
if(currentRun.getResolverType() == RESOLVER_TYPE.DOC_DIST) { + System.out.println("Running DOC DIST resolver, using log file at " + currentRun.getLogFilePath() + " ..."); + resolver = new DocDistResolver(currentRun.getLogFilePath()); + } + else if(currentRun.getResolverType() == RESOLVER_TYPE.TOPO_AS_DOC_DIST) { + System.out.println("Running TOPO AS DOC DIST resolver..."); + resolver = new ToponymAsDocDistResolver(currentRun.getLogFilePath()); + } + else if(currentRun.getResolverType() == RESOLVER_TYPE.LABEL_PROP) { + System.out.print("Running LABEL PROP resolver..."); + resolver = new LabelPropResolver(currentRun.getLogFilePath(), currentRun.getKnnForLP()); + } + else if(currentRun.getResolverType() == RESOLVER_TYPE.LABEL_PROP_DEFAULT_RULE) { + System.out.print("Running LABEL PROP DEFAULT RULE resolver, using graph at " + currentRun.getGraphInputPath() + " ..."); + resolver = new LabelPropDefaultRuleResolver(currentRun.getGraphInputPath()); + } + else if(currentRun.getResolverType() == RESOLVER_TYPE.LABEL_PROP_CONTEXT_SENSITIVE) { + System.out.print("Running LABEL PROP CONTEXT SENSITIVE resolver, using graph at " + currentRun.getGraphInputPath() + " ..."); + resolver = new LabelPropContextSensitiveResolver(currentRun.getGraphInputPath()); + } + else if(currentRun.getResolverType() == RESOLVER_TYPE.LABEL_PROP_COMPLEX) { + System.out.print("Running LABEL PROP COMPLEX resolver, using graph at " + currentRun.getGraphInputPath() + " ..."); + resolver = new LabelPropComplexResolver(currentRun.getGraphInputPath()); + } + else if(currentRun.getResolverType() == RESOLVER_TYPE.MAXENT) { + System.out.println("Running MAXENT resolver, using models at " + currentRun.getMaxentModelDirInputPath() + " and log file at " + currentRun.getLogFilePath() + " ..."); + resolver = new MaxentResolver(currentRun.getLogFilePath(), currentRun.getMaxentModelDirInputPath()); + } + else if(currentRun.getResolverType() == RESOLVER_TYPE.PROB) { + System.out.println("Running PROBABILISTIC resolver, using models at " + currentRun.getMaxentModelDirInputPath() + " and log file at " + currentRun.getLogFilePath()); + + resolver = new ProbabilisticResolver(currentRun.getLogFilePath(), currentRun.getMaxentModelDirInputPath(), currentRun.getPopComponentCoefficient(), currentRun.getDGProbOnly(), currentRun.getMEProbOnly()); + } + else if(currentRun.getResolverType() == RESOLVER_TYPE.BAYES_RULE) { + System.out.println("Running BAYES RULE resolver, using models at " + currentRun.getMaxentModelDirInputPath() + " and log file at " + currentRun.getLogFilePath()); + + resolver = new BayesRuleResolver(currentRun.getLogFilePath(), currentRun.getMaxentModelDirInputPath()); + } + else if(currentRun.getResolverType() == RESOLVER_TYPE.HEURISTIC_TPP) { + System.out.println("Running HEURISTIC TPP resolver..."); + + resolver = new HeuristicTPPResolver(); + } + else if(currentRun.getResolverType() == RESOLVER_TYPE.CONSTRUCTION_TPP) { + System.out.println("Running CONSTRUCTION TPP resolver..."); + + resolver = new ConstructionTPPResolver(currentRun.getDPC(), currentRun.getThreshold(), testCorpus, currentRun.getMaxentModelDirInputPath()); + } + else {//if(getResolverType() == RESOLVER_TYPE.BASIC_MIN_DIST) { + System.out.print("Running BASIC MINIMUM DISTANCE resolver..."); + resolver = new BasicMinDistResolver(); + } + + if(currentRun.getAdditionalInputPath() != null) + resolver.train(trainCorpus); + StoredCorpus disambiguated = resolver.disambiguate(testCorpus); + disambiguated.setFormat(currentRun.getCorpusFormat()); + 
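+ // For GEOTEXT corpora, toponym-level disambiguation is followed by document-level resolution:
+ // a SimpleDocumentResolver is run over the corpus, optionally restricted to a bounding box if one
+ // was configured.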
if(currentRun.getCorpusFormat() == CORPUS_FORMAT.GEOTEXT) { + if(currentRun.getBoundingBox() != null) + System.out.println("\nOnly disambiguating documents within bounding box: " + currentRun.getBoundingBox().toString()); + SimpleDocumentResolver dresolver = new SimpleDocumentResolver(); + disambiguated = dresolver.disambiguate(disambiguated, currentRun.getBoundingBox()); + } + + System.out.println("done.\n"); + + if(goldCorpus != null || currentRun.getCorpusFormat() == CORPUS_FORMAT.GEOTEXT) { + EvaluateCorpus evaluateCorpus = new EvaluateCorpus(); + evaluateCorpus.doEval(disambiguated, goldCorpus, currentRun.getCorpusFormat(), true, currentRun.getDoOracleEval()); + } + + if(currentRun.getSerializedCorpusOutputPath() != null) { + ImportCorpus importCorpus = new ImportCorpus(); + importCorpus.serialize(disambiguated, currentRun.getSerializedCorpusOutputPath()); + } + + if(currentRun.getOutputPath() != null) { + System.out.print("Writing resolved corpus in XML format to " + currentRun.getOutputPath() + " ..."); + CorpusXMLWriter w = new CorpusXMLWriter(disambiguated); + w.write(new File(currentRun.getOutputPath())); + System.out.println("done."); + } + + if(currentRun.getKMLOutputPath() != null) { + WriteCorpusToKML writeCorpusToKML = new WriteCorpusToKML(); + writeCorpusToKML.writeToKML(disambiguated, currentRun.getKMLOutputPath(), currentRun.getOutputGoldLocations(), currentRun.getOutputUserKML(), currentRun.getCorpusFormat()); + } + + if(currentRun.getDKMLOutputPath() != null) { + System.out.print("Writing resolved corpus in Dynamic KML format to " + currentRun.getDKMLOutputPath() + " ..."); + DynamicKMLWriter w = new DynamicKMLWriter(disambiguated); + w.write(new File(currentRun.getDKMLOutputPath())); + System.out.println("done."); + } + + endTime = System.currentTimeMillis(); + seconds = (endTime - startTime) / 1000F; + System.out.println("\nTotal time elapsed: " + Float.toString(seconds/(float)60.0) + " minutes."); + } + +} diff --git a/src/main/java/opennlp/fieldspring/tr/app/WriteCorpusToKML.java b/src/main/java/opennlp/fieldspring/tr/app/WriteCorpusToKML.java new file mode 100644 index 0000000..f71a567 --- /dev/null +++ b/src/main/java/opennlp/fieldspring/tr/app/WriteCorpusToKML.java @@ -0,0 +1,46 @@ +/* + * This class takes a corpus with system resolved toponyms and generates a KML file visualizable in Google Earth. 
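+ * A typical invocation (file names here are hypothetical) supplies a serialized corpus via -sci and
+ * a KML output path via -ok, the two options checked in main below:
+ *   WriteCorpusToKML -sci resolved-corpus.ser.gz -ok resolved-corpus.kml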
+ */ + +package opennlp.fieldspring.tr.app; + +import opennlp.fieldspring.tr.text.*; +import opennlp.fieldspring.tr.text.io.*; +import opennlp.fieldspring.tr.util.*; +import java.io.*; + +public class WriteCorpusToKML extends BaseApp { + + public static void main(String[] args) throws Exception { + + WriteCorpusToKML currentRun = new WriteCorpusToKML(); + currentRun.initializeOptionsFromCommandLine(args); + + if(currentRun.getSerializedCorpusInputPath() == null) { + System.out.println("Please specify an input corpus in serialized format via the -sci flag."); + System.exit(0); + } + + if(currentRun.getKMLOutputPath() == null) { + System.out.println("Please specify a KML output path via the -ok flag."); + System.exit(0); + } + + System.out.print("Reading serialized corpus from " + currentRun.getSerializedCorpusInputPath() + " ..."); + Corpus corpus = TopoUtil.readCorpusFromSerialized(currentRun.getSerializedCorpusInputPath()); + System.out.println("done."); + + currentRun.writeToKML(corpus, currentRun.getKMLOutputPath(), currentRun.getOutputGoldLocations(), currentRun.getOutputUserKML(), currentRun.getCorpusFormat()); + } + + public void writeToKML(Corpus corpus, String kmlOutputPath, boolean outputGoldLocations, boolean outputUserKML, Enum corpusFormat) throws Exception { + System.out.print("Writing visualizable corpus in KML format to " + kmlOutputPath + " ..."); + CorpusKMLWriter kw; + if(corpusFormat == CORPUS_FORMAT.GEOTEXT && outputUserKML) + kw = new GeoTextCorpusKMLWriter(corpus, outputGoldLocations); + else + kw = new CorpusKMLWriter(corpus, outputGoldLocations); + kw.write(new File(kmlOutputPath)); + System.out.println("done."); + } +} diff --git a/src/main/java/opennlp/fieldspring/tr/eval/AccuracyEvaluator.java b/src/main/java/opennlp/fieldspring/tr/eval/AccuracyEvaluator.java new file mode 100644 index 0000000..1ae305f --- /dev/null +++ b/src/main/java/opennlp/fieldspring/tr/eval/AccuracyEvaluator.java @@ -0,0 +1,43 @@ +/* + * This is a simple Evaluator that assumes gold named entities were used in preprocessing. For each gold disambiguated toponym, the model + * either got that Location right or wrong, and a Report containing the accuracy figure on this task is returned. 
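+ * Accuracy here is TP / total instances: a toponym whose selected candidate index equals its gold
+ * index counts as a true positive, and every other gold toponym only increments the instance count.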
+ */ + +package opennlp.fieldspring.tr.eval; + +import opennlp.fieldspring.tr.text.*; + +public class AccuracyEvaluator extends Evaluator { + + public AccuracyEvaluator(Corpus corpus) { + super(corpus); + } + + @Override + public Report evaluate() { + + Report report = new Report(); + + for(Document doc : corpus) { + for(Sentence sent : doc) { + for(Toponym toponym : sent.getToponyms()) { + if(toponym.hasGold()) { + if(toponym.getGoldIdx() == toponym.getSelectedIdx()) { + report.incrementTP(); + } + else { + report.incrementInstanceCount(); + } + } + } + } + } + + return report; + } + + @Override + public Report evaluate(Corpus pred, boolean useSelected) { + return null; + } +} diff --git a/src/main/java/opennlp/fieldspring/tr/eval/DistanceReport.java b/src/main/java/opennlp/fieldspring/tr/eval/DistanceReport.java new file mode 100644 index 0000000..4f08d51 --- /dev/null +++ b/src/main/java/opennlp/fieldspring/tr/eval/DistanceReport.java @@ -0,0 +1,61 @@ +package opennlp.fieldspring.tr.eval; + +import java.util.*; + +public class DistanceReport { + + private List distances = new ArrayList(); + private boolean isSorted = true; + + public void addDistance(double distance) { + distances.add(distance); + isSorted = false; + } + + public double getMeanDistance() { + if(distances.size() == 0) return -1; + + double total = 0.0; + for(double distance : distances) { + total += distance; + } + return total / distances.size(); + } + + public double getMedianDistance() { + if(distances.size() == 0) return -1; + sort(); + return distances.get(distances.size() / 2); + } + + public int getNumDistances() { + return distances.size(); + } + + public double getFractionDistancesWithinThreshold(double threshold) { + int count = 0; + for(double distance : distances) + if(distance <= threshold) + count++; + return ((double)count) / distances.size(); + } + + public double getMinDistance() { + if(distances.size() == 0) return -1; + sort(); + return distances.get(0); + } + + public double getMaxDistance() { + if(distances.size() == 0) return -1; + sort(); + return distances.get(distances.size()-1); + } + + private void sort() { + if(isSorted) + return; + Collections.sort(distances); + isSorted = true; + } +} diff --git a/src/main/java/opennlp/fieldspring/tr/eval/DocDistanceEvaluator.java b/src/main/java/opennlp/fieldspring/tr/eval/DocDistanceEvaluator.java new file mode 100644 index 0000000..2a76bf0 --- /dev/null +++ b/src/main/java/opennlp/fieldspring/tr/eval/DocDistanceEvaluator.java @@ -0,0 +1,35 @@ +package opennlp.fieldspring.tr.eval; + +import opennlp.fieldspring.tr.text.*; +import opennlp.fieldspring.tr.topo.*; + +public class DocDistanceEvaluator { + + protected final Corpus corpus; + + public DocDistanceEvaluator(Corpus corpus) { + this.corpus = (Corpus)corpus; + } + + /* Evaluate the "selected" candidates in the corpus using its "gold" + * candidates. 
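+ * For every non-training document that has both a system coordinate and a gold coordinate, the
+ * distance in km between the two (Coordinate.distanceInKm) is added to the returned DistanceReport.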
*/ + public DistanceReport evaluate() { + DistanceReport dreport = new DistanceReport(); + + for(Document doc : corpus) { + + if(!doc.isTrain()) { + + Coordinate systemCoord = doc.getSystemCoord(); + Coordinate goldCoord = doc.getGoldCoord(); + + if(systemCoord != null && goldCoord != null) { + dreport.addDistance(systemCoord.distanceInKm(goldCoord)); + } + } + } + + return dreport; + } + +} diff --git a/src/main/java/opennlp/fieldspring/tr/eval/EDEvaluator.java b/src/main/java/opennlp/fieldspring/tr/eval/EDEvaluator.java new file mode 100644 index 0000000..0200cdf --- /dev/null +++ b/src/main/java/opennlp/fieldspring/tr/eval/EDEvaluator.java @@ -0,0 +1,52 @@ +/////////////////////////////////////////////////////////////////////////////// +// Copyright (C) 2010 Travis Brown, The University of Texas at Austin +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +/////////////////////////////////////////////////////////////////////////////// +package opennlp.fieldspring.tr.eval; + +import java.util.Iterator; + +import opennlp.fieldspring.tr.text.Document; +import opennlp.fieldspring.tr.text.Corpus; +import opennlp.fieldspring.tr.text.Sentence; +import opennlp.fieldspring.tr.text.Token; + +public class EDEvaluator extends Evaluator { + public EDEvaluator(Corpus corpus) { + super(corpus); + } + + public Report evaluate() { + return null; + } + + public Report evaluate(Corpus pred, boolean useSelected) { + Iterator> goldDocs = this.corpus.iterator(); + Iterator> predDocs = pred.iterator(); + + while (goldDocs.hasNext() && predDocs.hasNext()) { + Iterator> goldSents = goldDocs.next().iterator(); + Iterator> predSents = predDocs.next().iterator(); + + while (goldSents.hasNext() && predSents.hasNext()) { + } + + assert !goldSents.hasNext() && !predSents.hasNext() : "Documents have different numbers of sentences."; + } + + assert !goldDocs.hasNext() && !predDocs.hasNext() : "Corpora have different numbers of documents."; + return null; + } +} + diff --git a/src/main/java/opennlp/fieldspring/tr/eval/Evaluator.java b/src/main/java/opennlp/fieldspring/tr/eval/Evaluator.java new file mode 100644 index 0000000..d98f791 --- /dev/null +++ b/src/main/java/opennlp/fieldspring/tr/eval/Evaluator.java @@ -0,0 +1,43 @@ +/////////////////////////////////////////////////////////////////////////////// +// Copyright (C) 2010 Travis Brown, The University of Texas at Austin +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+/////////////////////////////////////////////////////////////////////////////// +package opennlp.fieldspring.tr.eval; + +import opennlp.fieldspring.tr.text.Corpus; +import opennlp.fieldspring.tr.text.Token; + +public abstract class Evaluator { + protected final Corpus corpus; + + /* The given corpus should include either gold or selected candidates or + * both. */ + public Evaluator(Corpus corpus) { + this.corpus = (Corpus) corpus; + } + + /* Evaluate the "selected" candidates in the corpus using its "gold" + * candidates. */ + public abstract Report evaluate(); + + /* Evaluate the given corpus using either the gold or selected candidates in + * the current corpus. */ + public abstract Report evaluate(Corpus pred, boolean useSelected); + + /* A convenience method providing a default for evaluate. */ + public Report evaluate(Corpus pred) { + return this.evaluate(pred, false); + } +} + diff --git a/src/main/java/opennlp/fieldspring/tr/eval/Report.java b/src/main/java/opennlp/fieldspring/tr/eval/Report.java new file mode 100644 index 0000000..03c043e --- /dev/null +++ b/src/main/java/opennlp/fieldspring/tr/eval/Report.java @@ -0,0 +1,84 @@ +/////////////////////////////////////////////////////////////////////////////// +// Copyright (C) 2010 Travis Brown, The University of Texas at Austin +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+/////////////////////////////////////////////////////////////////////////////// +package opennlp.fieldspring.tr.eval; + +public class Report { + + private int tp; + private int fp; + private int fn; + private int totalInstances; + + public int getFN() { + return fn; + } + + public int getFP() { + return fp; + } + + public int getTP() { + return tp; + } + + public int getInstanceCount() { + return totalInstances; + } + + public void incrementTP() { + tp++; + totalInstances++; + } + + public void incrementFP() { + fp++; + totalInstances++; + } + + public void incrementFN() { + fn++; + totalInstances++; + } + + public void incrementFPandFN() { + fp++; + fn++; + totalInstances++; + } + + public void incrementInstanceCount() { + totalInstances++; + } + + public double getAccuracy() { + return (double) tp / totalInstances; + } + + public double getPrecision() { + return (double) tp / (tp + fp); + } + + public double getRecall() { + return (double) tp / (tp + fn); + } + + public double getFScore() { + double p = getPrecision(); + double r = getRecall(); + return (2 * p * r) / (p + r); + } +} + diff --git a/src/main/java/opennlp/fieldspring/tr/eval/SharedNEEvaluator.java b/src/main/java/opennlp/fieldspring/tr/eval/SharedNEEvaluator.java new file mode 100644 index 0000000..83bf28d --- /dev/null +++ b/src/main/java/opennlp/fieldspring/tr/eval/SharedNEEvaluator.java @@ -0,0 +1,101 @@ +/////////////////////////////////////////////////////////////////////////////// +// Copyright (C) 2010 Travis Brown, The University of Texas at Austin +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +/////////////////////////////////////////////////////////////////////////////// +package opennlp.fieldspring.tr.eval; + +import java.util.Iterator; +import java.util.List; + +import opennlp.fieldspring.tr.text.Corpus; +import opennlp.fieldspring.tr.text.Document; +import opennlp.fieldspring.tr.text.Sentence; +import opennlp.fieldspring.tr.text.Token; +import opennlp.fieldspring.tr.text.Toponym; +import opennlp.fieldspring.tr.topo.Location; +import opennlp.fieldspring.tr.topo.Region; + +public class SharedNEEvaluator extends Evaluator { + /* The given corpus should include either gold or selected candidates or + * both. */ + public SharedNEEvaluator(Corpus corpus) { + super(corpus); + } + + /* Evaluate the "selected" candidates in the corpus using its "gold" + * candidates. */ + public Report evaluate() { + return this.evaluate(this.corpus, false); + } + + /* Evaluate the given corpus using either the gold or selected candidates in + * the current corpus. */ + public Report evaluate(Corpus pred, boolean useSelected) { + Report report = new Report(); + + Iterator> goldDocs = this.corpus.iterator(); + Iterator> predDocs = pred.iterator(); + + /* Iterate over documents in sync. */ + while (goldDocs.hasNext() && predDocs.hasNext()) { + Iterator> goldSents = goldDocs.next().iterator(); + Iterator> predSents = predDocs.next().iterator(); + + /* Iterate over sentences in sync. 
*/ + while (goldSents.hasNext() && predSents.hasNext()) { + List goldToponyms = goldSents.next().getToponyms(); + List predToponyms = predSents.next().getToponyms(); + + /* Confirm that we have the same number of toponyms and loop through + * them. */ + assert goldToponyms.size() == predToponyms.size() : "Named entity spans do not match!"; + for (int i = 0; i < goldToponyms.size(); i++) { + Toponym predToponym = predToponyms.get(i); + if (predToponym.hasSelected()) { + Region predRegion = predToponym.getSelected().getRegion(); + List candidates = goldToponyms.get(i).getCandidates(); + + double minDist = Double.POSITIVE_INFINITY; + int minIdx = -1; + for (int j = 0; j < candidates.size(); j++) { + double dist = predRegion.distance(candidates.get(j).getRegion().getCenter()); + if (dist < minDist) { + minDist = dist; + minIdx = j; + } + } + /*System.out.format("Size: %d, minDist: %f, minIdx: %d, goldIdx: %d\n", + candidates.size(), minDist, minIdx, goldToponyms.get(i).getGoldIdx());*/ + + if (minIdx == goldToponyms.get(i).getGoldIdx()) { + report.incrementTP(); + } else { + report.incrementInstanceCount(); + } + } else { + report.incrementInstanceCount(); + } + } + } + } + + return report; + } + + /* A convenience method providing a default for evaluate. */ + public Report evaluate(Corpus pred) { + return this.evaluate(pred, false); + } +} + diff --git a/src/main/java/opennlp/fieldspring/tr/eval/SignatureEvaluator.java b/src/main/java/opennlp/fieldspring/tr/eval/SignatureEvaluator.java new file mode 100644 index 0000000..089c255 --- /dev/null +++ b/src/main/java/opennlp/fieldspring/tr/eval/SignatureEvaluator.java @@ -0,0 +1,203 @@ +/* + * Evaluator that uses signatures around each gold and predicted toponym to be used in the computation of P/R/F. + */ + +package opennlp.fieldspring.tr.eval; + +import opennlp.fieldspring.tr.text.*; +import opennlp.fieldspring.tr.topo.*; +import java.util.*; +import java.io.*; + +public class SignatureEvaluator extends Evaluator { + + private static final int CONTEXT_WINDOW_SIZE = 20; + + private static final double FP_PENALTY = 20037.5; + private static final double FN_PENALTY = 20037.5; + + private boolean doOracleEval; + + private Map > predCandidates = new HashMap >(); + + public SignatureEvaluator(Corpus goldCorpus, boolean doOracleEval) { + super(goldCorpus); + this.doOracleEval = doOracleEval; + } + + public SignatureEvaluator(Corpus goldCorpus) { + this(goldCorpus, false); + } + + public Report evaluate() { + return null; + } + + private Map populateSigsAndLocations(Corpus corpus, boolean getGoldLocations) { + Map locs = new HashMap(); + + for(Document doc : corpus) { + //System.out.println("Document id: " + doc.getId()); + for(Sentence sent : doc) { + StringBuffer sb = new StringBuffer(); + List toponymStarts = new ArrayList(); + List curLocations = new ArrayList(); + List > curCandidates = new ArrayList >(); + for(Token token : sent) { + //System.out.println(token.getForm()); + if(token.isToponym()) { + Toponym toponym = (Toponym) token; + if((getGoldLocations && toponym.hasGold()) || + (!getGoldLocations && (toponym.hasSelected() || toponym.getAmbiguity() == 0))) { + toponymStarts.add(sb.length()); + if(getGoldLocations) { + /*if(toponym.getGoldIdx() == 801) { + System.out.println(toponym.getForm()+": "+toponym.getGoldIdx()+"/"+toponym.getCandidates().size()); + }*/ + curLocations.add(toponym.getCandidates().get(toponym.getGoldIdx() 0) + curLocations.add(toponym.getCandidates().get(toponym.getSelectedIdx())); + else + curLocations.add(null); + 
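+ // The full candidate list for each predicted toponym is also recorded so that evaluate() can run
+ // its closest-match and oracle checks against it.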
curCandidates.add(toponym.getCandidates()); + } + } + } + sb.append(token.getForm().replaceAll("[^a-z0-9]", "")); + } + for(int i = 0; i < toponymStarts.size(); i++) { + int toponymStart = toponymStarts.get(i); + Location curLoc = curLocations.get(i); + String context = getSignature(sb, toponymStart, CONTEXT_WINDOW_SIZE); + locs.put(context, curLoc); + if(!getGoldLocations) + predCandidates.put(context, curCandidates.get(i)); + } + } + } + + return locs; + } + + private DistanceReport dreport = null; + public DistanceReport getDistanceReport() { return dreport; } + + @Override + public Report evaluate(Corpus pred, boolean useSelected) { + + Report report = new Report(); + dreport = new DistanceReport(); + + Map goldLocs = populateSigsAndLocations(corpus, true); + Map predLocs = populateSigsAndLocations(pred, false); + + Map > errors = new HashMap >(); + + for(String context : goldLocs.keySet()) { + if(predLocs.containsKey(context)) { + Location goldLoc = goldLocs.get(context); + Location predLoc = predLocs.get(context); + + if(predLoc != null && !doOracleEval) { + double dist = goldLoc.distanceInKm(predLoc); + dreport.addDistance(dist); + String key = goldLoc.getName().toLowerCase(); + if(!errors.containsKey(key)) + errors.put(key, new ArrayList()); + errors.get(key).add(dist); + } + + if(doOracleEval) { + if(predCandidates.get(context).size() > 0) { + Location closestMatch = getClosestMatch(goldLoc, predCandidates.get(context)); + dreport.addDistance(goldLoc.distanceInKm(closestMatch)); + report.incrementTP(); + } + } + else { + if(isClosestMatch(goldLoc, predLoc, predCandidates.get(context))) {//goldLocs.get(context) == predLocs.get(context)) { + //System.out.println("TP: " + context + "|" + goldLocs.get(context)); + report.incrementTP(); + } + else { + //System.out.println("FP and FN: " + context + "|" + goldLocs.get(context) + " vs. 
" + predLocs.get(context)); + //report.incrementFP(); + //report.incrementFN(); + report.incrementFPandFN(); + } + } + } + else { + //System.out.println("FN: " + context + "| not found in pred"); + report.incrementFN(); + //dreport.addDistance(FN_PENALTY); + + } + } + for(String context : predLocs.keySet()) { + if(!goldLocs.containsKey(context)) { + //System.out.println("FP: " + context + "| not found in gold"); + report.incrementFP(); + //dreport.addDistance(FP_PENALTY); + } + } + + try { + BufferedWriter errOut = new BufferedWriter(new FileWriter("errors.txt")); + + for(String toponym : errors.keySet()) { + List errorList = errors.get(toponym); + double sum = 0.0; + for(double error : errorList) { + sum += error; + } + errOut.write(toponym+" & "+errorList.size()+" & "+(sum/errorList.size())+" & "+sum+"\\\\\n"); + } + + errOut.close(); + + } catch(Exception e) { + e.printStackTrace(); + System.exit(1); + } + + return report; + } + + private boolean isClosestMatch(Location goldLoc, Location predLoc, List curPredCandidates) { + if(predLoc == null) + return false; + + double distanceToBeat = predLoc.distance(goldLoc); + + for(Location otherLoc : curPredCandidates) { + if(otherLoc.distance(goldLoc) < distanceToBeat) + return false; + } + return true; + } + + private Location getClosestMatch(Location goldLoc, List curPredCandidates) { + double minDist = Double.POSITIVE_INFINITY; + Location toReturn = null; + + for(Location otherLoc : curPredCandidates) { + double dist = otherLoc.distance(goldLoc); + if(dist < minDist) { + minDist = dist; + toReturn = otherLoc; + } + } + + return toReturn; + } + + private String getSignature(StringBuffer wholeContext, int centerIndex, int windowSize) { + int beginIndex = Math.max(0, centerIndex - windowSize); + int endIndex = Math.min(wholeContext.length(), centerIndex + windowSize); + + return wholeContext.substring(beginIndex, endIndex); + } +} diff --git a/src/main/java/opennlp/fieldspring/tr/resolver/BasicMinDistResolver.java b/src/main/java/opennlp/fieldspring/tr/resolver/BasicMinDistResolver.java new file mode 100644 index 0000000..d0b052e --- /dev/null +++ b/src/main/java/opennlp/fieldspring/tr/resolver/BasicMinDistResolver.java @@ -0,0 +1,162 @@ +/* + * Basic Minimum Distance resolver. For each toponym, the location is selected that minimizes the total distance to some disambiguation + * of the other toponyms in the same document. + */ + +package opennlp.fieldspring.tr.resolver; + +import opennlp.fieldspring.tr.text.*; +import opennlp.fieldspring.tr.topo.*; +import java.util.*; + +public class BasicMinDistResolver extends Resolver { + + /* This implementation of disambiguate immediately stops computing distance + * totals for candidates when it becomes clear that they aren't minimal. 
*/ + @Override + public StoredCorpus disambiguate(StoredCorpus corpus) { + for (Document doc : corpus) { + if(!doc.isTrain() && !doc.isTest()) { + for (Sentence sent : doc) { + for (Toponym toponym : sent.getToponyms()) { + double min = Double.MAX_VALUE; + int minIdx = -1; + + int idx = 0; + for (Location candidate : toponym) { + Double candidateMin = this.checkCandidate(toponym, candidate, doc, min); + if (candidateMin != null) { + min = candidateMin; + minIdx = idx; + } + idx++; + } + + if (minIdx > -1) { + toponym.setSelectedIdx(minIdx); + } + } + } + } + } + + // Backoff to Random: + Resolver randResolver = new RandomResolver(); + randResolver.overwriteSelecteds = false; + corpus = randResolver.disambiguate(corpus); + + return corpus; + } + + /* Returns the minimum total distance to all other locations in the document + * for the candidate, or null if it's greater than the current minimum. */ + public Double checkCandidate(Toponym toponym, Location candidate, Document doc, double currentMinTotal) { + Double total = 0.0; + int seen = 0; + + for (Sentence otherSent : doc) { + for (Toponym otherToponym : otherSent.getToponyms()) { + + /* We don't want to compute distances if this other toponym is the + * same as the current one, or if it has no candidates. */ + if (!otherToponym.equals(toponym) && otherToponym.getAmbiguity() > 0) { + double min = Double.MAX_VALUE; + + for (Location otherLoc : otherToponym) { + double dist = candidate.distance(otherLoc); + if (dist < min) { + min = dist; + } + } + + seen++; + total += min; + + /* If the running total is greater than the current minimum, we can + * stop. */ + if (total >= currentMinTotal) { + return null; + } + } + } + } + + /* Abstain if we haven't seen any other toponyms. */ + return seen > 0 ? total : null; + } + + /* The previous implementation of disambiguate. */ + public StoredCorpus disambiguateOld(StoredCorpus corpus) { + for(Document doc : corpus) { + for(Sentence sent : doc) { + for(Token token : sent.getToponyms()) { + //if(token.isToponym()) { + Toponym toponym = (Toponym) token; + + basicMinDistDisambiguate(toponym, doc); + //} + } + } + } + return corpus; + } + + /* + * Sets the selected index of toponymToDisambiguate according to the Location with the minimum total + * distance to some disambiguation of all the Locations of the Toponyms in doc. + */ + private void basicMinDistDisambiguate(Toponym toponymToDisambiguate, Document doc) { + //HashMap totalDistances = new HashMap(); + List totalDistances = new ArrayList(); + + // Compute the total minimum distances from each candidate Location of toponymToDisambiguate to some disambiguation + // of all the Toponyms in doc; store these in totalDistances + for(Location curLoc : toponymToDisambiguate) { + Double totalDistSoFar = 0.0; + int seen = 0; + + for(Sentence sent : doc) { + for(Token token : sent.getToponyms()) { + //if(token.isToponym()) { + Toponym otherToponym = (Toponym) token; + + /* We don't want to compute distances if this other toponym is the + * same as the current one, or if it has no candidates. */ + if (!otherToponym.equals(toponymToDisambiguate) && otherToponym.getAmbiguity() > 0) { + double minDist = Double.MAX_VALUE; + for(Location otherLoc : otherToponym) { + double curDist = curLoc.distance(otherLoc); + if(curDist < minDist) { + minDist = curDist; + } + } + totalDistSoFar += minDist; + seen++; + } + //} + } + } + + /* Abstain if we haven't seen any other toponyms. */ + totalDistances.add(seen > 0 ? 
totalDistSoFar : Double.MAX_VALUE); + } + + // Find the overall minimum of all the total minimum distances computed above + double minTotalDist = Double.MAX_VALUE; + int indexOfMin = -1; + for(int curLocIndex = 0; curLocIndex < totalDistances.size(); curLocIndex++) { + double totalDist = totalDistances.get(curLocIndex); + if(totalDist < minTotalDist) { + minTotalDist = totalDist; + indexOfMin = curLocIndex; + } + } + + // Set toponymToDisambiguate's index to the index of the Location with the overall minimum distance + // from above, if one was found + if(indexOfMin >= 0) { + toponymToDisambiguate.setSelectedIdx(indexOfMin); + } + } +} + diff --git a/src/main/java/opennlp/fieldspring/tr/resolver/LabelPropComplexResolver.java b/src/main/java/opennlp/fieldspring/tr/resolver/LabelPropComplexResolver.java new file mode 100644 index 0000000..04acf14 --- /dev/null +++ b/src/main/java/opennlp/fieldspring/tr/resolver/LabelPropComplexResolver.java @@ -0,0 +1,100 @@ +package opennlp.fieldspring.tr.resolver; + +import opennlp.fieldspring.tr.text.*; +import opennlp.fieldspring.tr.topo.*; +import opennlp.fieldspring.tr.util.*; +import opennlp.fieldspring.tr.app.*; +import java.io.*; +import java.util.*; + +public class LabelPropComplexResolver extends Resolver { + + public static final double DPC = 1.0; // degrees per cell + + private String pathToGraph; + private Map > cellDistributions = null; + + public LabelPropComplexResolver(String pathToGraph) { + this.pathToGraph = pathToGraph; + } + + @Override + public void train(StoredCorpus corpus) { + cellDistributions = new HashMap >(); + + try { + BufferedReader in = new BufferedReader(new FileReader(pathToGraph)); + + String curLine; + while(true) { + curLine = in.readLine(); + if(curLine == null) + break; + + String[] tokens = curLine.split("\t"); + + if(!tokens[0].startsWith(LabelPropPreproc.DOC_) || !tokens[0].contains(LabelPropPreproc.TOK_)) + continue; + + int docIdBeginIndex = tokens[0].indexOf(LabelPropPreproc.DOC_) + LabelPropPreproc.DOC_.length(); + int lastTOKIndex = tokens[0].lastIndexOf(LabelPropPreproc.TOK_); + int docIdEndIndex = lastTOKIndex - 1; // - 1 for intermediary "_" + String docId = tokens[0].substring(docIdBeginIndex, docIdEndIndex); + + String tokenIndex = tokens[0].substring(lastTOKIndex + LabelPropPreproc.TOK_.length()); + + String key = docId + ";" + tokenIndex; + + Map cellDistribution = new HashMap(); + + for(int i = 1; i < tokens.length; i++) { + String curToken = tokens[i]; + if(curToken.length() == 0) + continue; + + String[] innerTokens = curToken.split(" "); + for(int j = 0; j < innerTokens.length; j++) { + if(/*!innerTokens[j].startsWith("__DUMMY__") && */innerTokens[j].startsWith(LabelPropPreproc.CELL_LABEL_)) { + int cellNumber = Integer.parseInt(innerTokens[j].substring(LabelPropPreproc.CELL_LABEL_.length())); + double mass = Double.parseDouble(innerTokens[j+1]); + cellDistribution.put(cellNumber, mass); + } + } + + cellDistributions.put(key, cellDistribution); + } + + + } + + in.close(); + + } catch(Exception e) { + e.printStackTrace(); + System.exit(1); + } + } + + @Override + public StoredCorpus disambiguate(StoredCorpus corpus) { + + if(cellDistributions == null) + train(corpus); + + for(Document doc : corpus) { + int tokenIndex = 0; + for(Sentence sent : doc) { + for(Toponym toponym : sent.getToponyms()) { + if(toponym.getAmbiguity() > 0) { + int indexToSelect = TopoUtil.getCorrectCandidateIndex(toponym, cellDistributions.get(doc.getId() + ";" + tokenIndex), DPC); + if(indexToSelect != -1) + 
toponym.setSelectedIdx(indexToSelect); + } + } + tokenIndex++; + } + } + + return corpus; + } +} diff --git a/src/main/java/opennlp/fieldspring/tr/resolver/LabelPropContextSensitiveResolver.java b/src/main/java/opennlp/fieldspring/tr/resolver/LabelPropContextSensitiveResolver.java new file mode 100644 index 0000000..527bb61 --- /dev/null +++ b/src/main/java/opennlp/fieldspring/tr/resolver/LabelPropContextSensitiveResolver.java @@ -0,0 +1,195 @@ +package opennlp.fieldspring.tr.resolver; + +import opennlp.fieldspring.tr.text.*; +import opennlp.fieldspring.tr.topo.*; +import opennlp.fieldspring.tr.util.*; +import java.io.*; +import java.util.*; + +public class LabelPropContextSensitiveResolver extends Resolver { + + public static final int DEGREES_PER_REGION = 1; + + private static final double CUR_TOP_WEIGHT = 1.0; + private static final double SAME_SENT_WEIGHT = .5; + private static final double OTHER_WEIGHT = .33; + + private String pathToGraph; + private Lexicon lexicon = null;// = new SimpleLexicon(); + //private HashMap reverseLexicon = new HashMap(); + + private HashMap > regionDistributions = null; + + public LabelPropContextSensitiveResolver(String pathToGraph) { + this.pathToGraph = pathToGraph; + } + + @Override + public void train(StoredCorpus corpus){ + //TopoUtil.buildLexicons(corpus, lexicon, reverseLexicon); + try { + ObjectInputStream ois = new ObjectInputStream(new FileInputStream("lexicon.ser")); + lexicon = (Lexicon)ois.readObject(); + ois.close(); + } catch(Exception e) { + e.printStackTrace(); + System.exit(0); + } + + regionDistributions = new HashMap >(); + + try { + BufferedReader in = new BufferedReader(new FileReader(pathToGraph)); + + String curLine; + while(true) { + curLine = in.readLine(); + if(curLine == null) + break; + + String[] tokens = curLine.split("\t"); + if(tokens[0].endsWith("R")) continue; //tokens[0] = tokens[0].substring(0, tokens[0].length()-1); + + int idx = Integer.parseInt(tokens[0]); + + //if(!reverseLexicon.containsKey(idx)) + // continue; + + HashMap curDist = new HashMap(); + regionDistributions.put(idx, curDist); + + //int regionNumber = -1; + for(int i = 1; i < tokens.length; i++) { + String curToken = tokens[i]; + if(curToken.length() == 0) + continue; + + String[] innerTokens = curToken.split(" "); + for(int j = 0; j < innerTokens.length; j++) { + if(/*!innerTokens[j].startsWith("__DUMMY__") && */innerTokens[j].endsWith("L")) { + int regionNumber = Integer.parseInt(innerTokens[j].substring(0, innerTokens[j].length()-1)); + double labelWeight = Double.parseDouble(innerTokens[j+1]); + curDist.put(regionNumber, labelWeight); + } + } + } + } + + in.close(); + } catch(Exception e) { + e.printStackTrace(); + System.exit(1); + } + } + + @Override + public StoredCorpus disambiguate(StoredCorpus corpus) { + + if(regionDistributions == null) + train(corpus); + + for(Document doc : corpus) { + int outerSentIndex = 0; + for(Sentence outerSent : doc) { + for(Toponym outerToponym : outerSent.getToponyms()) { + if(outerToponym.getAmbiguity() > 0) { + int idx = lexicon.get(outerToponym.getForm()); + + HashMap wordWeights = new HashMap(); + wordWeights.put(idx, CUR_TOP_WEIGHT); + + for(Token otherToken : outerSent.getTokens()) {//.getToponyms()) { + //if(otherToponym.getAmbiguity() > 0) { + Integer otherTokenIdx = lexicon.get(otherToken.getForm()); + if(otherTokenIdx == null) + continue; + if(!wordWeights.containsKey(otherTokenIdx)) + wordWeights.put(otherTokenIdx, SAME_SENT_WEIGHT); + //} + } + + int innerSentIndex = 0; + for(Sentence innerSent : doc) 
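                  // Note: wordWeights already holds CUR_TOP_WEIGHT for the toponym itself and
                  // SAME_SENT_WEIGHT for the other tokens of its sentence; the loop below adds
                  // every remaining token in the document at OTHER_WEIGHT unless already present.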
{ + for(Token innerToken : innerSent.getTokens()) { + //if(innerToken.getAmbiguity() > 0) { + Integer innerTokenIdx = lexicon.get(innerToken.getForm()); + if(innerTokenIdx == null) + continue; + if(!wordWeights.containsKey(innerTokenIdx)) + wordWeights.put(innerTokenIdx, OTHER_WEIGHT); + //} + } + innerSentIndex++; + } + + //int bestRegionNumber = getBestRegionNumber(outerToponym, wordWeights); + //int indexToSelect = TopoUtil.getCorrectCandidateIndex(outerToponym, bestRegionNumber, DEGREES_PER_REGION); + Map weightedSum = getWeightedSum(wordWeights); + int indexToSelect = TopoUtil.getCorrectCandidateIndex(outerToponym, weightedSum, DEGREES_PER_REGION); + if(indexToSelect == -1) { + System.out.println(outerToponym.getForm()); + } + outerToponym.setSelectedIdx(indexToSelect); + } + } + outerSentIndex++; + } + } + + return corpus; + } + + private Map getWeightedSum(Map wordWeights) { + + Map weightedSum = new HashMap(); + + for(int outerIdx : wordWeights.keySet()) { + double weight = wordWeights.get(outerIdx); + Map curDist = regionDistributions.get(outerIdx); + if(curDist != null) { + for(int innerIdx : curDist.keySet()) { + Double prev = weightedSum.get(innerIdx); + if(prev == null) + prev = 0.0; + weightedSum.put(innerIdx, prev + weight * curDist.get(innerIdx)); + } + } + } + + return weightedSum; + } + + private int getBestRegionNumber(Toponym toponym, Map wordWeights) { + + Map weightedSum = getWeightedSum(wordWeights);//new HashMap(); + + /* + for(int outerIdx : wordWeights.keySet()) { + double weight = wordWeights.get(outerIdx); + Map curDist = regionDistributions.get(outerIdx); + if(curDist != null) { + for(int innerIdx : curDist.keySet()) { + Double prev = weightedSum.get(innerIdx); + if(prev == null) + prev = 0.0; + weightedSum.put(innerIdx, prev + weight * curDist.get(innerIdx)); + } + } + } + */ + + int bestRegionNumber = -1; + double greatestMass = 0.0; + for(int idx : weightedSum.keySet()) { + if(TopoUtil.getCorrectCandidateIndex(toponym, idx, DEGREES_PER_REGION) >= 0) { + double curMass = weightedSum.get(idx); + if(curMass > greatestMass) { + bestRegionNumber = idx; + greatestMass = curMass; + } + } + } + + return bestRegionNumber; + } +} diff --git a/src/main/java/opennlp/fieldspring/tr/resolver/LabelPropDefaultRuleResolver.java b/src/main/java/opennlp/fieldspring/tr/resolver/LabelPropDefaultRuleResolver.java new file mode 100644 index 0000000..4e8aaab --- /dev/null +++ b/src/main/java/opennlp/fieldspring/tr/resolver/LabelPropDefaultRuleResolver.java @@ -0,0 +1,117 @@ +package opennlp.fieldspring.tr.resolver; + +import opennlp.fieldspring.tr.text.*; +import opennlp.fieldspring.tr.topo.*; +import opennlp.fieldspring.tr.util.*; +import java.io.*; +import java.util.*; + +public class LabelPropDefaultRuleResolver extends Resolver { + + public static final int DEGREES_PER_REGION = 1; + + private String pathToGraph; + private Lexicon lexicon = new SimpleLexicon(); + private HashMap reverseLexicon = new HashMap(); + //private HashMap defaultRegions = null; + private HashMap > regionDistributions = null;//new HashMap >(); + private HashMap indexCache = new HashMap(); + + public LabelPropDefaultRuleResolver(String pathToGraph) { + this.pathToGraph = pathToGraph; + } + + @Override + public void train(StoredCorpus corpus) { + TopoUtil.buildLexicons(corpus, lexicon, reverseLexicon); + + //defaultRegions = new HashMap(); + regionDistributions = new HashMap >(); + + try { + BufferedReader in = new BufferedReader(new FileReader(pathToGraph)); + + String curLine; + while(true) { + curLine = 
in.readLine(); + if(curLine == null) + break; + + String[] tokens = curLine.split("\t"); + + if(tokens[0].endsWith("R")) continue;//tokens[0] = tokens[0].substring(0, tokens[0].length()-1); + + int idx = Integer.parseInt(tokens[0]); + + if(!reverseLexicon.containsKey(idx)) + continue; + + HashMap regionDistribution = regionDistributions.get(idx); + if(regionDistribution == null) + regionDistribution = new HashMap(); + + //int regionNumber = -1; + for(int i = 1; i < tokens.length; i++) { + String curToken = tokens[i]; + if(curToken.length() == 0) + continue; + + String[] innerTokens = curToken.split(" "); + for(int j = 0; j < innerTokens.length; j++) { + if(/*!innerTokens[j].startsWith("__DUMMY__") && */innerTokens[j].endsWith("L")) { + int regionNumber = Integer.parseInt(innerTokens[j].substring(0, innerTokens[j].length()-1)); + double mass = Double.parseDouble(innerTokens[j+1]); + regionDistribution.put(regionNumber, mass); + //break; + } + } + } + + regionDistributions.put(idx, regionDistribution); + + /*if(regionNumber == -1) { + System.out.println("-1"); + continue; + }*/ + + //defaultRegions.put(idx, regionNumber); + } + + in.close(); + } catch(Exception e) { + e.printStackTrace(); + System.exit(1); + } + } + + @Override + public StoredCorpus disambiguate(StoredCorpus corpus) { + + //if(defaultRegions == null) + if(regionDistributions == null) + train(corpus); + + for(Document doc : corpus) { + for(Sentence sent : doc) { + for(Toponym toponym : sent.getToponyms()) { + if(toponym.getAmbiguity() > 0) { + int idx = lexicon.get(toponym.getForm()); + Integer indexToSelect = indexCache.get(idx); + if(indexToSelect == null) { + //int regionNumber = defaultRegions.get(idx); + //if(regionDistributions.get(idx) == null) + // System.err.println("region dist null for " + reverseLexicon.get(idx)); + indexToSelect = TopoUtil.getCorrectCandidateIndex(toponym, regionDistributions.get(idx), DEGREES_PER_REGION); + indexCache.put(idx, indexToSelect); + } + //System.out.println("index selected for " + toponym.getForm() + ": " + indexToSelect); + if(indexToSelect != -1) + toponym.setSelectedIdx(indexToSelect); + } + } + } + } + + return corpus; + } +} diff --git a/src/main/java/opennlp/fieldspring/tr/resolver/RandomResolver.java b/src/main/java/opennlp/fieldspring/tr/resolver/RandomResolver.java new file mode 100644 index 0000000..a82b058 --- /dev/null +++ b/src/main/java/opennlp/fieldspring/tr/resolver/RandomResolver.java @@ -0,0 +1,30 @@ +/* + * Random baseline resolver. Selects a random location for each toponym. + */ + +package opennlp.fieldspring.tr.resolver; + +import opennlp.fieldspring.tr.text.*; +import java.util.*; + +public class RandomResolver extends Resolver { + + private Random rand = new Random(); + + @Override + public StoredCorpus disambiguate(StoredCorpus corpus) { + + for(Document doc : corpus) { + for(Sentence sent : doc) { + for(Toponym toponym : sent.getToponyms()) { + int ambiguity = toponym.getAmbiguity(); + if (ambiguity > 0 && (overwriteSelecteds || !toponym.hasSelected())) { + toponym.setSelectedIdx(rand.nextInt(ambiguity)); + } + } + } + } + + return corpus; + } +} diff --git a/src/main/java/opennlp/fieldspring/tr/resolver/Resolver.java b/src/main/java/opennlp/fieldspring/tr/resolver/Resolver.java new file mode 100644 index 0000000..d007fef --- /dev/null +++ b/src/main/java/opennlp/fieldspring/tr/resolver/Resolver.java @@ -0,0 +1,27 @@ +/* + * This version of Resolver (started 9/22/10) is just an abstract class with the disambiguate(Corpus) method. 
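 *
 * Subclasses must implement disambiguate(StoredCorpus); train(StoredCorpus) is optional and
 * throws UnsupportedOperationException unless overridden. A minimal subclass, given here only
 * as an illustrative sketch modeled on RandomResolver elsewhere in this patch (the class name
 * is hypothetical), might look like:
 *
 *   public class FirstCandidateResolver extends Resolver {
 *     @Override
 *     public StoredCorpus disambiguate(StoredCorpus corpus) {
 *       for (Document doc : corpus)
 *         for (Sentence sent : doc)
 *           for (Toponym toponym : sent.getToponyms())
 *             // honor overwriteSelecteds so this resolver can serve as a backoff step
 *             if (toponym.getAmbiguity() > 0 && (overwriteSelecteds || !toponym.hasSelected()))
 *               toponym.setSelectedIdx(0); // always take the first gazetteer candidate
 *       return corpus;
 *     }
 *   }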
+ */ + +package opennlp.fieldspring.tr.resolver; + +import opennlp.fieldspring.tr.text.*; + +/** + * @param corpus + * a corpus without any selected candidates for each toponym (or ignores the selections if they are present) + * @return + * a corpus with selected candidates, ready for evaluation + */ +public abstract class Resolver { + + // Make this false to have a resolver only resolve toponyms that don't already have a selected candidate + // (not implemented in all resolvers yet) + public boolean overwriteSelecteds = true; + + public void train(StoredCorpus corpus) { + throw new UnsupportedOperationException("This type of resolver cannot be trained."); + } + + public abstract StoredCorpus disambiguate(StoredCorpus corpus); + +} diff --git a/src/main/java/opennlp/fieldspring/tr/resolver/SimpleDocumentResolver.java b/src/main/java/opennlp/fieldspring/tr/resolver/SimpleDocumentResolver.java new file mode 100644 index 0000000..e9fa3a8 --- /dev/null +++ b/src/main/java/opennlp/fieldspring/tr/resolver/SimpleDocumentResolver.java @@ -0,0 +1,70 @@ +/* + * This resolves each Document (which currently must be a GeoTextDocument) to a particular coordinate given that its toponyms have already + * been resolved. + */ + +package opennlp.fieldspring.tr.resolver; + +import opennlp.fieldspring.tr.text.*; +import opennlp.fieldspring.tr.topo.*; +import java.util.*; + +public class SimpleDocumentResolver extends Resolver { + + public StoredCorpus disambiguate(StoredCorpus corpus) { + return disambiguate(corpus, null); + } + + public StoredCorpus disambiguate(StoredCorpus corpus, Region boundingBox) { + + for(Document doc : corpus) { + + Map locationCounts = new HashMap(); + int greatestLocFreq = 0; + Location mostCommonLoc = null; + + /*if(doc instanceof GeoTextDocument) { + System.out.println("doc " + doc.getId() + " is a GeoTextDocument."); + } + else + System.out.println("doc " + doc.getId() + " is NOT a GeoTextDocument; it's a " + doc.getClass().getName()); + */ + + for(Sentence sent : doc) { + for(Toponym toponym : sent.getToponyms()) { + if (toponym.hasSelected()) { + Location systemLoc = toponym.getCandidates().get(toponym.getSelectedIdx()); + if(systemLoc.getRegion().getRepresentatives().size() == 1 + && (boundingBox == null || boundingBox.contains(systemLoc.getRegion().getCenter()))) { + int locId = systemLoc.getId(); + Integer prevCount = locationCounts.get(locId); + if(prevCount == null) + prevCount = 0; + locationCounts.put(locId, prevCount + 1); + + if(prevCount + 1 > greatestLocFreq) { + greatestLocFreq = prevCount + 1; + mostCommonLoc = systemLoc; + //System.out.println(mostCommonLoc.getName() + " is now most common with " + (prevCount+1)); + } + } + } + } + } + + if(mostCommonLoc != null) { + doc.setSystemCoord(mostCommonLoc.getRegion().getCenter()); + //System.out.println("Setting mostCommonLoc for " + doc.getId() + " to " + mostCommonLoc.getName()); + //System.out.println("goldCoord was " + doc.getGoldCoord()); + } + else if(boundingBox != null) { + doc.setSystemCoord(boundingBox.getCenter()); + } + else { + doc.setSystemCoord(Coordinate.fromDegrees(0.0, 0.0)); + } + } + + return corpus; + } +} diff --git a/src/main/java/opennlp/fieldspring/tr/resolver/WeightedMinDistResolver.java b/src/main/java/opennlp/fieldspring/tr/resolver/WeightedMinDistResolver.java new file mode 100644 index 0000000..356ca0e --- /dev/null +++ b/src/main/java/opennlp/fieldspring/tr/resolver/WeightedMinDistResolver.java @@ -0,0 +1,412 @@ +/* + * Weighted Minimum Distance resolver. 
Iterative algorithm that builds on BasicMinDistResolver by incorporating corpus-level + * prominence of various locations into toponym resolution. + */ + +package opennlp.fieldspring.tr.resolver; + +import opennlp.fieldspring.tr.text.*; +import opennlp.fieldspring.tr.topo.*; +import opennlp.fieldspring.tr.util.*; +import java.util.*; +import java.io.*; + +public class WeightedMinDistResolver extends Resolver { + + // weights and toponym lexicon (for indexing into weights) are stored so that a different + // corpus/corpora can be used for training than for disambiguating + private List > weights = null; + Lexicon toponymLexicon = null; + + private int numIterations; + private boolean readWeightsFromFile; + private String logFilePath; + private List > weightsFromFile = null; + //private Map distanceCache = new HashMap(); + //private int maxCoeff = Integer.MAX_VALUE; + private DistanceTable distanceTable; + private static final int PHANTOM_COUNT = 0; // phantom/imagined counts for smoothing + + public WeightedMinDistResolver(int numIterations, boolean readWeightsFromFile, String logFilePath) { + super(); + this.numIterations = numIterations; + this.readWeightsFromFile = readWeightsFromFile; + this.logFilePath = logFilePath; + + if(readWeightsFromFile && logFilePath == null) { + System.err.println("Error: need logFilePath via -l for backoff to DocDist."); + System.exit(0); + } + } + + public WeightedMinDistResolver(int numIterations, boolean readWeightsFromFile) { + this(numIterations, readWeightsFromFile, null); + } + + @Override + public void train(StoredCorpus corpus) { + + distanceTable = new DistanceTable();//corpus.getToponymTypeCount()); + + toponymLexicon = TopoUtil.buildLexicon(corpus); + List > counts = new ArrayList >(toponymLexicon.size()); + for(int i = 0; i < toponymLexicon.size(); i++) counts.add(null); + weights = new ArrayList >(toponymLexicon.size()); + for(int i = 0; i < toponymLexicon.size(); i++) weights.add(null); + + if(readWeightsFromFile) { + weightsFromFile = new ArrayList >(toponymLexicon.size()); + try { + DataInputStream in = new DataInputStream(new FileInputStream("probToWMD.dat")); + for(int i = 0; i < toponymLexicon.size(); i++) { + int ambiguity = in.readInt(); + weightsFromFile.add(new ArrayList(ambiguity)); + for(int j = 0; j < ambiguity; j++) { + weightsFromFile.get(i).add(in.readDouble()); + } + } + in.close(); + } catch(Exception e) { + e.printStackTrace(); + System.exit(1); + } + + /*for(int i = 0; i < weightsFromFile.size(); i++) { + for(int j = 0; j < weightsFromFile.get(i).size(); j++) { + System.out.println(weightsFromFile.get(i).get(j)); + } + System.out.println(); + }*/ + } + + initializeCountsAndWeights(counts, weights, corpus, toponymLexicon, PHANTOM_COUNT, weightsFromFile); + + for(int i = 0; i < numIterations; i++) { + System.out.println("Iteration: " + (i+1)); + updateWeights(corpus, counts, PHANTOM_COUNT, weights, toponymLexicon); + } + } + + @Override + public StoredCorpus disambiguate(StoredCorpus corpus) { + + if(weights == null) + train(corpus); + + TopoUtil.addToponymsToLexicon(toponymLexicon, corpus); + weights = expandWeightsArray(toponymLexicon, corpus, weights); + + StoredCorpus disambiguated = finalDisambiguationStep(corpus, weights, toponymLexicon); + + if(readWeightsFromFile) { + // Backoff to DocDist: + Resolver docDistResolver = new DocDistResolver(logFilePath); + docDistResolver.overwriteSelecteds = false; + disambiguated = docDistResolver.disambiguate(corpus); + } + else { + // Backoff to Random: + Resolver randResolver = new 
RandomResolver(); + randResolver.overwriteSelecteds = false; + disambiguated = randResolver.disambiguate(corpus); + } + + return disambiguated; + } + + // adds a weight of 1.0 to candidate locations of toponyms found in lexicon but not in oldWeights + private List > expandWeightsArray(Lexicon lexicon, StoredCorpus corpus, List > oldWeights) { + if(oldWeights.size() >= lexicon.size()) + return oldWeights; + + List > newWeights = new ArrayList >(lexicon.size()); + for(int i = 0; i < lexicon.size(); i++) newWeights.add(null); + + for(int i = 0; i < oldWeights.size(); i++) { + newWeights.set(i, oldWeights.get(i)); + } + + initializeWeights(newWeights, corpus, lexicon); + + return newWeights; + } + + private void initializeCountsAndWeights(List > counts, List > weights, + StoredCorpus corpus, Lexicon lexicon, int initialCount, + List > weightsFromFile) { + + for(Document doc : corpus) { + for(Sentence sent : doc) { + for(Toponym toponym : sent.getToponyms()) { + if(toponym.getAmbiguity() > 0) { + int index = lexicon.get(toponym.getForm()); + if(counts.get(index) == null) { + counts.set(index, new ArrayList(toponym.getAmbiguity())); + weights.set(index, new ArrayList(toponym.getAmbiguity())); + for(int i = 0; i < toponym.getAmbiguity(); i++) { + counts.get(index).add(initialCount); + if(weightsFromFile != null + && weightsFromFile.get(index).size() > 0) + weights.get(index).add(weightsFromFile.get(index).get(i)); + else + weights.get(index).add(1.0); + } + } + } + } + } + } + } + + private void initializeWeights(List > weights, StoredCorpus corpus, Lexicon lexicon) { + for(Document doc : corpus) { + for(Sentence sent : doc) { + for(Toponym toponym : sent.getToponyms()) { + if(toponym.getAmbiguity() > 0) { + int index = lexicon.get(toponym.getForm()); + if(weights.get(index) == null) { + weights.set(index, new ArrayList(toponym.getAmbiguity())); + for(int i = 0; i < toponym.getAmbiguity(); i++) { + weights.get(index).add(1.0); + } + } + } + } + } + } + } + + private void updateWeights(StoredCorpus corpus, List > counts, int initialCount, List > weights, Lexicon lexicon) { + + for(int i = 0; i < counts.size(); i++) + for(int j = 0; j < counts.get(i).size(); j++) + counts.get(i).set(j, initialCount); + + List sums = new ArrayList(counts.size()); + for(int i = 0; i < counts.size(); i++) sums.add(initialCount * counts.get(i).size()); + + for (Document doc : corpus) { + for (Sentence sent : doc) { + for (Toponym toponym : sent.getToponyms()) { + double min = Double.MAX_VALUE; + int minIdx = -1; + + int idx = 0; + for (Location candidate : toponym) { + Double candidateMin = this.checkCandidate(toponym, candidate, idx, doc, min, weights, lexicon); + if (candidateMin != null) { + min = candidateMin; + minIdx = idx; + } + idx++; + } + + + if(minIdx == -1) { // Most likely happens when there was only one toponym in the document; + // so, choose the location with the greatest weight (unless all are uniform) + double maxWeight = 1.0; + int locationIdx = 0; + for(Location candidate : toponym) { + double thisWeight = weights.get(lexicon.get(toponym.getForm())).get(locationIdx); + if(thisWeight > maxWeight) { + maxWeight = thisWeight; + minIdx = locationIdx; + } + locationIdx++; + } + } + + + if (minIdx > -1) { + int countIndex = lexicon.get(toponym.getForm()); + int prevCount = counts.get(countIndex) + .get(minIdx); + counts.get(countIndex).set(minIdx, prevCount + 1); + int prevSum = sums.get(countIndex); + sums.set(countIndex, prevSum + 1); + + } + + } + } + } + + + for(int i = 0; i < weights.size(); i++) { + 
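            // Re-estimate each candidate's weight as its share of this iteration's counts,
            // rescaled by the number of candidates: weight_j = (count_j / sum_k count_k) * |candidates|.
            // With uniform counts every weight comes out to 1.0, so evenly used candidates stay neutral.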
List curWeights = weights.get(i); + List curCounts = counts.get(i); + int curSum = sums.get(i); + for(int j = 0; j < curWeights.size(); j++) { + curWeights.set(j, ((double)curCounts.get(j) / curSum) * curWeights.size()); + } + } + } + + + /* This implementation of disambiguate immediately stops computing distance + * totals for candidates when it becomes clear that they aren't minimal. */ + private StoredCorpus finalDisambiguationStep(StoredCorpus corpus, List > weights, Lexicon lexicon) { + for (Document doc : corpus) { + for (Sentence sent : doc) { + for (Toponym toponym : sent.getToponyms()) { + double min = Double.MAX_VALUE; + int minIdx = -1; + + int idx = 0; + for (Location candidate : toponym) { + Double candidateMin = this.checkCandidate(toponym, candidate, idx, doc, min, weights, lexicon); + if (candidateMin != null) { + min = candidateMin; + minIdx = idx; + } + idx++; + } + + if(minIdx == -1) { // Most likely happens when there was only one toponym in the document; + // so, choose the location with the greatest weight (unless all are 1.0) + double maxWeight = 1.0; + int locationIdx = 0; + for(Location candidate : toponym) { + double thisWeight = weights.get(lexicon.get(toponym.getForm())).get(locationIdx); + if(thisWeight > maxWeight) { + maxWeight = thisWeight; + minIdx = locationIdx; + } + locationIdx++; + } + } + + if (minIdx > -1) { + toponym.setSelectedIdx(minIdx); + } + } + } + + } + + return corpus; + } + + /* Returns the minimum total distance to all other locations in the document + * for the candidate, or null if it's greater than the current minimum. */ + private Double checkCandidate(Toponym toponymTemp, Location candidate, int locationIndex, Document doc, + double currentMinTotal, List > weights, Lexicon lexicon) { + StoredToponym toponym = (StoredToponym) toponymTemp; + Double total = 0.0; + int seen = 0; + + for (Sentence otherSent : doc) { + for (Toponym otherToponymTemp : otherSent.getToponyms()) { + StoredToponym otherToponym = (StoredToponym) otherToponymTemp; + + /*Map normalizationDenoms = new HashMap(); + for(Location tempOtherLoc : otherToponym) {//int i = 0; i < otherToponym.getAmbiguity(); i++) { + //Location tempOtherLoc = otherToponym.getCandidates().get(i); + double normalizationDenomTemp = 0.0; + for(Location tempLoc : toponym) { //int j = 0; j < toponym.getAmbiguity(); j++) { + //Location tempLoc = toponym.getCandidates().get(j); + normalizationDenomTemp += tempOtherLoc.distance(tempLoc); + } + normalizationDenoms.put(tempOtherLoc, normalizationDenomTemp); + }*/ + + /* We don't want to compute distances if this other toponym is the + * same as the current one, or if it has no candidates. 
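         * Below, each raw candidate-to-candidate distance is divided by the product of the two
         * candidates' current weights (thisWeight * otherWeight), so pairs of prominent,
         * frequently chosen candidates behave as if they were closer together; this is how the
         * corpus-level weights estimated in updateWeights enter the minimum-distance objective.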
*/ + if (!otherToponym.equals(toponym) && otherToponym.getAmbiguity() > 0) { + double min = Double.MAX_VALUE; + //double sum = 0.0; + + int otherLocIndex = 0; + for (Location otherLoc : otherToponym) { + + /*double normalizationDenom = 0.0; + for(Location tempCand : toponym) { + normalizationDenom += tempCand.distance(otherLoc) / weights.get(otherLoc) ; + } + */ + //double normalizationDenom = normalizationDenoms.get(otherToponym); + + //double weightedDist = distanceTable.getDistance(toponym, locationIndex, otherToponym, otherLocIndex);//candidate.distance(otherLoc) /* / weights.get(otherLoc) */ ; + //double weightedDist = candidate.distance(otherLoc); + double weightedDist = distanceTable.distance(candidate, otherLoc); + double thisWeight = weights.get(lexicon.get(toponym.getForm())).get(locationIndex); + double otherWeight = weights.get(lexicon.get(otherToponym.getForm())).get(otherLocIndex); + weightedDist /= (thisWeight * otherWeight); // weighting; was just otherWeight before + //weightedDist /= normalizationDenoms.get(otherLoc); // normalization + if (weightedDist < min) { + min = weightedDist; + } + //sum += weightedDist; + otherLocIndex++; + } + + seen++; + total += min; + //total += sum; + + /* If the running total is greater than the current minimum, we can + * stop. */ + if (total >= currentMinTotal) { + return null; + } + } + } + } + + /* Abstain if we haven't seen any other toponyms. */ + return seen > 0 ? total : null; + } + + /*private double getDistance(StoredToponym t1, int i1, StoredToponym t2, int i2) { + int t1idx = t1.getIdx(); + int t2idx = t2.getIdx(); + + long key = t1idx + i1 * maxCoeff + t2idx * maxCoeff * maxCoeff + i2 * maxCoeff * maxCoeff * maxCoeff; + + Double dist = distanceCache.get(key); + + if(dist == null) { + dist = t1.getCandidates().get(i1).distance(t2.getCandidates().get(i2)); + distanceCache.put(key, dist); + long key2 = t2idx + i2 * maxCoeff + t1idx * maxCoeff * maxCoeff + i1 * maxCoeff * maxCoeff * maxCoeff; + distanceCache.put(key2, dist); + } + + return dist; + }*/ + + /*private class DistanceTable { + //private double[][][][] allDistances; + + public DistanceTable(int numToponymTypes) { + //allDistances = new double[numToponymTypes][numToponymTypes][][]; + } + + public double getDistance(StoredToponym t1, int i1, StoredToponym t2, int i2) { + + return t1.getCandidates().get(i1).distance(t2.getCandidates().get(i2)); + + /* int t1idx = t1.getIdx(); + int t2idx = t2.getIdx(); + + double[][] distanceMatrix = allDistances[t1idx][t2idx]; + if(distanceMatrix == null) { + distanceMatrix = new double[t1.getAmbiguity()][t2.getAmbiguity()]; + for(int i = 0; i < distanceMatrix.length; i++) { + for(int j = 0; j < distanceMatrix[i].length; j++) { + distanceMatrix[i][j] = t1.getCandidates().get(i).distance(t2.getCandidates().get(j)); + } + } + }*SLASH + /*double distance = distanceMatrix[i1][i2]; + if(distance == 0.0) { + distance = t1.getCandidates().get(i1).distance(t2.getCandidates().get(i2)); + distanceMatrix[i1][i2] = distance; + }*SLASH + //return distanceMatrix[i1][i2]; + + } + + //public + }*/ +} diff --git a/src/main/java/opennlp/fieldspring/tr/text/CompactCorpus.java b/src/main/java/opennlp/fieldspring/tr/text/CompactCorpus.java new file mode 100644 index 0000000..5c08260 --- /dev/null +++ b/src/main/java/opennlp/fieldspring/tr/text/CompactCorpus.java @@ -0,0 +1,441 @@ +/////////////////////////////////////////////////////////////////////////////// +// Copyright (C) 2010 Travis Brown, The University of Texas at Austin +// +// Licensed under the 
Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +/////////////////////////////////////////////////////////////////////////////// +package opennlp.fieldspring.tr.text; + +import java.util.ArrayList; +import java.util.Iterator; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.io.*; + +import opennlp.fieldspring.tr.topo.*; +import opennlp.fieldspring.tr.util.*; +import opennlp.fieldspring.tr.app.*; + +public class CompactCorpus extends StoredCorpus implements Serializable { + + private static final long serialVersionUID = 42L; + + private Corpus wrapped; + + private final CountingLexicon tokenLexicon; + private final CountingLexicon toponymLexicon; + + private final CountingLexicon tokenOrigLexicon; + private final CountingLexicon toponymOrigLexicon; + + private int[] tokenOrigMap; + private int[] toponymOrigMap; + + private int maxToponymAmbiguity; + private double avgToponymAmbiguity = 0.0; + private int tokenCount = 0; + private int toponymTokenCount = 0; + + private final ArrayList> documents; + private final ArrayList> candidateLists; + + CompactCorpus(Corpus wrapped) { + this.wrapped = wrapped; + + this.tokenLexicon = new SimpleCountingLexicon(); + this.toponymLexicon = new SimpleCountingLexicon(); + this.tokenOrigLexicon = new SimpleCountingLexicon(); + this.toponymOrigLexicon = new SimpleCountingLexicon(); + this.maxToponymAmbiguity = 0; + + this.documents = new ArrayList>(); + this.candidateLists = new ArrayList>(); + } + + public int getDocumentCount() { + return this.documents.size(); + } + + public int getTokenTypeCount() { + return this.tokenLexicon.size(); + } + + public int getTokenOrigTypeCount() { + return this.tokenOrigLexicon.size(); + } + + public int getToponymTypeCount() { + return this.toponymLexicon.size(); + } + + public int getToponymOrigTypeCount() { + return this.toponymOrigLexicon.size(); + } + + public int getMaxToponymAmbiguity() { + return this.maxToponymAmbiguity; + } + + public double getAvgToponymAmbiguity() { + return this.avgToponymAmbiguity; + } + + public int getTokenCount() { + return this.tokenCount; + } + + public int getToponymTokenCount() { + return this.toponymTokenCount; + } + + public void load() { + for (Document document : wrapped) { + ArrayList> sentences = new ArrayList>(); + + for (Sentence sentence : document) { + List tokens = sentence.getTokens(); + int[] tokenIdxs = new int[tokens.size()]; + this.tokenCount += tokens.size(); + + for (int i = 0; i < tokenIdxs.length; i++) { + Token token = tokens.get(i); + tokenIdxs[i] = this.tokenOrigLexicon.getOrAdd(token.getOrigForm()); + this.tokenLexicon.getOrAdd(token.getForm()); + } + + StoredSentence stored = new StoredSentence(sentence.getId(), tokenIdxs); + + for (Iterator> it = sentence.toponymSpans(); it.hasNext(); ) { + Span span = it.next(); + Toponym toponym = (Toponym) span.getItem(); + + this.toponymTokenCount++; + + this.avgToponymAmbiguity += toponym.getAmbiguity(); + + if (toponym.getAmbiguity() > this.maxToponymAmbiguity) { + this.maxToponymAmbiguity = 
toponym.getAmbiguity(); + } + + int idx = this.toponymOrigLexicon.getOrAdd(toponym.getOrigForm()); + this.toponymLexicon.getOrAdd(toponym.getForm()); + + if (toponym.hasGold()) { + int goldIdx = toponym.getGoldIdx(); + if (toponym.hasSelected()) { + int selectedIdx = toponym.getSelectedIdx(); + stored.addToponym(span.getStart(), span.getEnd(), idx, goldIdx, selectedIdx); + } else { + stored.addToponym(span.getStart(), span.getEnd(), idx, goldIdx); + } + } else { + if(toponym.hasSelected()) { + int selectedIdx = toponym.getSelectedIdx(); + stored.addToponym(span.getStart(), span.getEnd(), idx, -1, selectedIdx); + } + else { + stored.addToponym(span.getStart(), span.getEnd(), idx); + } + } + + if (this.candidateLists.size() <= idx) { + this.candidateLists.add(toponym.getCandidates()); + } else { + this.candidateLists.set(idx, toponym.getCandidates()); + } + } + + stored.compact(); + sentences.add(stored); + } + + sentences.trimToSize(); + if(this.getFormat() == BaseApp.CORPUS_FORMAT.GEOTEXT) { + this.documents.add(new StoredDocument(document.getId(), sentences, + document.getTimestamp(), + document.getGoldCoord(), document.getSystemCoord(), document.getSection())); + } + else + this.documents.add(new StoredDocument(document.getId(), sentences)); + } + + this.avgToponymAmbiguity /= this.toponymTokenCount; + + this.tokenOrigMap = new int[this.tokenOrigLexicon.size()]; + this.toponymOrigMap = new int[this.toponymOrigLexicon.size()]; + + int i = 0; + for (String entry : this.tokenOrigLexicon) { + this.tokenOrigMap[i] = this.tokenLexicon.get(entry.toLowerCase()); + i++; + } + + i = 0; + for (String entry : this.toponymOrigLexicon) { + this.toponymOrigMap[i] = this.toponymLexicon.get(entry.toLowerCase()); + i++; + } + + this.wrapped.close(); + this.wrapped = null; + + this.removeNaNs(); + } + + private void removeNaNs() { + for(List candidates : candidateLists) { + for(Location loc : candidates) { + List reps = loc.getRegion().getRepresentatives(); + int prevSize = reps.size(); + Coordinate.removeNaNs(reps); + if(reps.size() < prevSize) + loc.getRegion().setCenter(Coordinate.centroid(reps)); + } + } + } + + public void addSource(DocumentSource source) { + if (this.wrapped == null) { + throw new UnsupportedOperationException("Cannot add a source to a stored corpus after it has been loaded."); + } else { + this.wrapped.addSource(source); + } + } + + public Iterator> iterator() { + if (this.wrapped != null) { + this.load(); + } + + return this.documents.iterator(); + } + + private class StoredDocument extends Document implements Serializable { + + private static final long serialVersionUID = 42L; + + private final List> sentences; + + private StoredDocument(String id, List> sentences) { + super(id); + this.sentences = sentences; + } + + private StoredDocument(String id, List> sentences, String timestamp, Coordinate goldCoord, Coordinate systemCoord) { + this(id, sentences, timestamp, goldCoord); + this.systemCoord = systemCoord; + } + + private StoredDocument(String id, List> sentences, String timestamp, double goldLat, double goldLon) { + this(id, sentences); + this.timestamp = timestamp; + this.goldCoord = Coordinate.fromDegrees(goldLat, goldLon); + } + + private StoredDocument(String id, List> sentences, String timestamp, Coordinate goldCoord) { + this(id, sentences); + this.timestamp = timestamp; + this.goldCoord = goldCoord; + } + + private StoredDocument(String id, List> sentences, String timestamp, Coordinate goldCoord, Coordinate systemCoord, Enum section) { + this(id, sentences, timestamp, 
goldCoord, systemCoord); + this.section = section; + } + + public Iterator> iterator() { + return this.sentences.iterator(); + } + } + + private class StoredSentence extends Sentence implements Serializable { + + private static final long serialVersionUID = 42L; + + private final int[] tokens; + private final ArrayList> toponymSpans; + + private StoredSentence(String id, int[] tokens) { + super(id); + this.tokens = tokens; + this.toponymSpans = new ArrayList>(); + } + + private void addToponym(int start, int end, int toponymIdx) { + this.toponymSpans.add(new Span(start, end, new CompactToponym(toponymIdx))); + } + + private void addToponym(int start, int end, int toponymIdx, int goldIdx) { + this.toponymSpans.add(new Span(start, end, new CompactToponym(toponymIdx, goldIdx))); + } + + private void addToponym(int start, int end, int toponymIdx, int goldIdx, int selectedIdx) { + this.toponymSpans.add(new Span(start, end, new CompactToponym(toponymIdx, goldIdx, selectedIdx))); + } + + private void compact() { + this.toponymSpans.trimToSize(); + } + + private class CompactToponym implements StoredToponym { + + private static final long serialVersionUID = 42L; + + private final int idx; + private int goldIdx; + private int selectedIdx; + + private CompactToponym(int idx) { + this(idx, -1); + } + + private CompactToponym(int idx, int goldIdx) { + this(idx, goldIdx, -1); + } + + private CompactToponym(int idx, int goldIdx, int selectedIdx) { + this.idx = idx; + this.goldIdx = goldIdx; + this.selectedIdx = selectedIdx; + } + + public String getForm() { + return CompactCorpus.this.toponymLexicon.atIndex(CompactCorpus.this.toponymOrigMap[this.idx]); + } + + public String getOrigForm() { + return CompactCorpus.this.toponymOrigLexicon.atIndex(this.idx); + } + + public boolean isToponym() { + return true; + } + + public boolean hasGold() { return this.goldIdx > -1; } + public Location getGold() { + if (this.goldIdx == -1) { + return null; + } else { + List candList = CompactCorpus.this.candidateLists.get(this.idx); + return CompactCorpus.this.candidateLists.get(this.idx).get(this.goldIdx -1; } + public Location getSelected() { + if (this.selectedIdx == -1) { + return null; + } else { + return CompactCorpus.this.candidateLists.get(this.idx).get(this.selectedIdx); + } + } + + public int getSelectedIdx() { return this.selectedIdx; } + public void setSelectedIdx(int idx) { this.selectedIdx = idx; } + + public int getAmbiguity() { return CompactCorpus.this.candidateLists.get(this.idx).size(); } + public List getCandidates() { return CompactCorpus.this.candidateLists.get(this.idx); } + public void setCandidates(List candidates) { CompactCorpus.this.candidateLists.set(this.idx, candidates); } + public Iterator iterator() { return CompactCorpus.this.candidateLists.get(this.idx).iterator(); } + + public List getTokens() { throw new UnsupportedOperationException(); } + + public int getIdx() { + return CompactCorpus.this.toponymOrigMap[this.idx]; + } + + public int getOrigIdx() { + return this.idx; + } + + public int getTypeCount() { + return CompactCorpus.this.toponymLexicon.countAtIndex(CompactCorpus.this.toponymOrigMap[this.idx]); + } + + public int getOrigTypeCount() { + return CompactCorpus.this.toponymOrigLexicon.countAtIndex(idx); + } + + @Override + public boolean equals(Object other) { + return other != null && + other.getClass() == this.getClass() && + ((StoredToponym) other).getIdx() == this.getIdx(); + } + } + + public Iterator tokens() { + return new Iterator() { + private int current = 0; + + public 
boolean hasNext() { + return this.current < StoredSentence.this.tokens.length; + } + + public StoredToken next() { + final int current = this.current++; + final int idx = StoredSentence.this.tokens[current]; + return new StoredToken() { + + private static final long serialVersionUID = 42L; + + public String getForm() { + return CompactCorpus.this.tokenLexicon.atIndex(CompactCorpus.this.tokenOrigMap[idx]); + } + + public String getOrigForm() { + return CompactCorpus.this.tokenOrigLexicon.atIndex(idx); + } + + public boolean isToponym() { + return false; + } + + public int getIdx() { + return CompactCorpus.this.tokenOrigMap[idx]; + } + + public int getOrigIdx() { + return idx; + } + + public int getTypeCount() { + return CompactCorpus.this.tokenLexicon.countAtIndex(CompactCorpus.this.tokenOrigMap[idx]); + } + + public int getOrigTypeCount() { + return CompactCorpus.this.tokenOrigLexicon.countAtIndex(idx); + } + }; + } + + public void remove() { + throw new UnsupportedOperationException(); + } + }; + } + + public Iterator> toponymSpans() { + return this.toponymSpans.iterator(); + } + } + + public void close() { + if (this.wrapped != null) { + this.wrapped.close(); + } + } +} + diff --git a/src/main/java/opennlp/fieldspring/tr/text/Corpus.java b/src/main/java/opennlp/fieldspring/tr/text/Corpus.java new file mode 100644 index 0000000..c541cc6 --- /dev/null +++ b/src/main/java/opennlp/fieldspring/tr/text/Corpus.java @@ -0,0 +1,61 @@ +/////////////////////////////////////////////////////////////////////////////// +// Copyright (C) 2010 Travis Brown, The University of Texas at Austin +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+/////////////////////////////////////////////////////////////////////////////// +package opennlp.fieldspring.tr.text; + +import java.util.Iterator; + +import opennlp.fieldspring.tr.util.Lexicon; +import opennlp.fieldspring.tr.app.*; +import java.io.*; + +public abstract class Corpus implements Iterable>, Serializable { + + private static Enum corpusFormat = null;//BaseApp.CORPUS_FORMAT.PLAIN; + + public abstract void addSource(DocumentSource source); + public abstract void close(); + + public static Corpus createStreamCorpus() { + return new StreamCorpus(); + } + + public static StoredCorpus createStoredCorpus() { + return new CompactCorpus(Corpus.createStreamCorpus()); + } + + public DocumentSource asSource() { + final Iterator> iterator = this.iterator(); + + return new DocumentSource() { + public boolean hasNext() { + return iterator.hasNext(); + } + + public Document next() { + return (Document) iterator.next(); + } + }; + } + + public Enum getFormat() { + return corpusFormat; + } + + public void setFormat(Enum corpusFormat) { + this.corpusFormat = corpusFormat; + } +} + diff --git a/src/main/java/opennlp/fieldspring/tr/text/Document.java b/src/main/java/opennlp/fieldspring/tr/text/Document.java new file mode 100644 index 0000000..8c1ace7 --- /dev/null +++ b/src/main/java/opennlp/fieldspring/tr/text/Document.java @@ -0,0 +1,106 @@ +/////////////////////////////////////////////////////////////////////////////// +// Copyright (C) 2010 Travis Brown, The University of Texas at Austin +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +/////////////////////////////////////////////////////////////////////////////// +package opennlp.fieldspring.tr.text; + +import java.io.*; +import opennlp.fieldspring.tr.topo.*; + +public abstract class Document implements Iterable>, Serializable { + protected final String id; + public String title = null; + protected Coordinate goldCoord; + protected Coordinate systemCoord; + protected String timestamp; + + public static enum SECTION { + TRAIN, + DEV, + TEST, + ANY + } + protected Enum
section; + + public Document(String id) { + this(id, null, null, null); + } + + public Document(String id, String title) { + this(id); + this.title = title; + } + + public Document(String id, String timestamp, Coordinate goldCoord) { + this(id, timestamp, goldCoord, null); + } + + public Document(String id, String timestamp, Coordinate goldCoord, Coordinate systemCoord) { + this(id, timestamp, goldCoord, systemCoord, SECTION.ANY); + } + + public Document(String id, String timestamp, Coordinate goldCoord, Coordinate systemCoord, Enum
<Document.SECTION> section) { + this(id, timestamp, goldCoord, systemCoord, section, null); + } + + public Document(String id, String timestamp, Coordinate goldCoord, Coordinate systemCoord, Enum<Document.SECTION>
section, String title) { + this.id = id; + this.timestamp = timestamp; + this.goldCoord = goldCoord; + this.systemCoord = systemCoord; + this.section = section; + this.title = title; + } + + public String getId() { + return this.id; + } + + public Coordinate getGoldCoord() { + return this.goldCoord; + } + + public Coordinate getSystemCoord() { + return this.systemCoord; + } + + public void setSystemCoord(Coordinate systemCoord) { + this.systemCoord = systemCoord; + } + + public void setSystemCoord(double systemLat, double systemLon) { + this.systemCoord = Coordinate.fromDegrees(systemLat, systemLon); + } + + public String getTimestamp() { + return this.timestamp; + } + + public Enum
getSection() { + return section; + } + + public boolean isTrain() { + return getSection() == SECTION.TRAIN; + } + + public boolean isDev() { + return getSection() == SECTION.DEV; + } + + public boolean isTest() { + return getSection() == SECTION.TEST; + } +} + diff --git a/src/main/java/opennlp/fieldspring/tr/text/DocumentSource.java b/src/main/java/opennlp/fieldspring/tr/text/DocumentSource.java new file mode 100644 index 0000000..e123794 --- /dev/null +++ b/src/main/java/opennlp/fieldspring/tr/text/DocumentSource.java @@ -0,0 +1,38 @@ +/////////////////////////////////////////////////////////////////////////////// +// Copyright (C) 2010 Travis Brown, The University of Texas at Austin +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +/////////////////////////////////////////////////////////////////////////////// +package opennlp.fieldspring.tr.text; + +import java.util.Iterator; + +import opennlp.fieldspring.tr.text.Document; +import opennlp.fieldspring.tr.text.Sentence; +import opennlp.fieldspring.tr.text.Token; + +public abstract class DocumentSource implements Iterator> { + public void close() { + } + + public void remove() { + throw new UnsupportedOperationException("Cannot remove a document from a source."); + } + + protected abstract class SentenceIterator implements Iterator> { + public void remove() { + throw new UnsupportedOperationException("Cannot remove a sentence from a source."); + } + } +} + diff --git a/src/main/java/opennlp/fieldspring/tr/text/DocumentSourceWrapper.java b/src/main/java/opennlp/fieldspring/tr/text/DocumentSourceWrapper.java new file mode 100644 index 0000000..4812806 --- /dev/null +++ b/src/main/java/opennlp/fieldspring/tr/text/DocumentSourceWrapper.java @@ -0,0 +1,58 @@ +/////////////////////////////////////////////////////////////////////////////// +// Copyright (C) 2010 Travis Brown, The University of Texas at Austin +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +/////////////////////////////////////////////////////////////////////////////// +package opennlp.fieldspring.tr.text; + +import java.util.Iterator; + +import opennlp.fieldspring.tr.text.Document; +import opennlp.fieldspring.tr.text.Sentence; +import opennlp.fieldspring.tr.text.Token; + +/** + * Wraps a document source in order to perform some operation on it. 
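 * A subclass typically implements next() to transform or filter each document pulled from the
 * underlying source (available via getSource()), while hasNext() and close() are inherited here
 * and simply delegate to the wrapped source.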
+ * + * @author Travis Brown + * @version 0.1.0 + */ +public abstract class DocumentSourceWrapper extends DocumentSource { + private final DocumentSource source; + + public DocumentSourceWrapper(DocumentSource source) { + this.source = source; + } + + /** + * Closes the underlying source. + */ + public void close() { + this.source.close(); + } + + /** + * Indicates whether the underlying source has more documents. + */ + public boolean hasNext() { + return this.source.hasNext(); + } + + /** + * Returns the underlying source (for use in subclasses). + */ + protected DocumentSource getSource() { + return this.source; + } +} + diff --git a/src/main/java/opennlp/fieldspring/tr/text/GeoTextDocument.java b/src/main/java/opennlp/fieldspring/tr/text/GeoTextDocument.java new file mode 100644 index 0000000..9cf0105 --- /dev/null +++ b/src/main/java/opennlp/fieldspring/tr/text/GeoTextDocument.java @@ -0,0 +1,47 @@ +package opennlp.fieldspring.tr.text; + +import java.io.*; +import java.util.*; + +import opennlp.fieldspring.tr.topo.*; + +public class GeoTextDocument extends Document { + + private static final long serialVersionUID = 42L; + + private List> sentences; + + public GeoTextDocument(String id, String timestamp, double goldLat, double goldLon) { + super(id); + this.timestamp = timestamp; + this.goldCoord = Coordinate.fromDegrees(goldLat, goldLon); + this.sentences = new ArrayList>(); + this.systemCoord = null; + this.timestamp = null; + } + + public GeoTextDocument(String id, String timestamp, double goldLat, double goldLon, Enum section) { + this(id, timestamp, goldLat, goldLon); + this.section = section; + } + + public GeoTextDocument(String id, String timestamp, double goldLat, double goldLon, long fold) { + this(id, timestamp, goldLat, goldLon); + if(fold >= 1 && fold <= 3) + this.section = Document.SECTION.TRAIN; + else if(fold == 4) + this.section = Document.SECTION.DEV; + else if(fold == 5) + this.section = Document.SECTION.TEST; + else + this.section = Document.SECTION.ANY; + } + + public void addSentence(Sentence sentence) { + sentences.add(sentence); + } + + public Iterator> iterator() { + return sentences.iterator(); + } +} diff --git a/src/main/java/opennlp/fieldspring/tr/text/Sentence.java b/src/main/java/opennlp/fieldspring/tr/text/Sentence.java new file mode 100644 index 0000000..9fe77cd --- /dev/null +++ b/src/main/java/opennlp/fieldspring/tr/text/Sentence.java @@ -0,0 +1,103 @@ +/////////////////////////////////////////////////////////////////////////////// +// Copyright (C) 2010 Travis Brown, The University of Texas at Austin +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+/////////////////////////////////////////////////////////////////////////////// +package opennlp.fieldspring.tr.text; + +import java.util.ArrayList; +import java.util.Iterator; +import java.util.List; +import java.util.NoSuchElementException; +import java.io.*; + +import opennlp.fieldspring.tr.util.Span; + +public abstract class Sentence implements Iterable, Serializable { + private final String id; + + protected Sentence(String id) { + this.id = id; + } + + public abstract Iterator tokens(); + + public Iterator> toponymSpans() { + return new Iterator>() { + public boolean hasNext() { + return false; + } + + public Span next() { + throw new NoSuchElementException(); + } + + public void remove() { + throw new UnsupportedOperationException(); + } + }; + } + + public List getTokens() { + List tokens = new ArrayList(); + for (Iterator it = this.tokens(); it.hasNext(); ) { + tokens.add(it.next()); + } + return tokens; + } + + public List getToponyms() { + List toponyms = new ArrayList(); + for (Iterator> it = this.toponymSpans(); it.hasNext(); ) { + toponyms.add((Toponym) it.next().getItem()); + } + return toponyms; + } + + public String getId() { + return this.id; + } + + public Iterator iterator() { + return new Iterator() { + private final Iterator tokens = Sentence.this.tokens(); + private final Iterator> spans = Sentence.this.toponymSpans(); + private int current = 0; + private Span span = this.spans.hasNext() ? this.spans.next() : null; + + public boolean hasNext() { + return this.tokens.hasNext(); + } + + public A next() { + if (this.span != null && this.span.getStart() == this.current) { + A toponym = span.getItem(); + for (int i = 0; i < this.span.getEnd() - this.span.getStart(); i++) { + this.tokens.next(); + } + this.current = this.span.getEnd(); + this.span = this.spans.hasNext() ? this.spans.next() : null; + return toponym; + } else { + this.current++; + return this.tokens.next(); + } + } + + public void remove() { + throw new UnsupportedOperationException(); + } + }; + } +} + diff --git a/src/main/java/opennlp/fieldspring/tr/text/SimpleSentence.java b/src/main/java/opennlp/fieldspring/tr/text/SimpleSentence.java new file mode 100644 index 0000000..e5c153b --- /dev/null +++ b/src/main/java/opennlp/fieldspring/tr/text/SimpleSentence.java @@ -0,0 +1,51 @@ +/////////////////////////////////////////////////////////////////////////////// +// Copyright (C) 2010 Travis Brown, The University of Texas at Austin +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+/////////////////////////////////////////////////////////////////////////////// +package opennlp.fieldspring.tr.text; + +import java.util.ArrayList; +import java.util.Iterator; +import java.util.List; +import java.util.NoSuchElementException; + +import opennlp.fieldspring.tr.util.Span; +import java.io.*; + +public class SimpleSentence extends Sentence implements Serializable { + + private static final long serialVersionUID = 42L; + + private final List tokens; + private final List> toponymSpans; + + public SimpleSentence(String id, List tokens) { + this(id, tokens, new ArrayList>()); + } + + public SimpleSentence(String id, List tokens, List> toponymSpans) { + super(id); + this.tokens = tokens; + this.toponymSpans = toponymSpans; + } + + public Iterator tokens() { + return this.tokens.iterator(); + } + + public Iterator> toponymSpans() { + return this.toponymSpans.iterator(); + } +} + diff --git a/src/main/java/opennlp/fieldspring/tr/text/SimpleToken.java b/src/main/java/opennlp/fieldspring/tr/text/SimpleToken.java new file mode 100644 index 0000000..4b87ebd --- /dev/null +++ b/src/main/java/opennlp/fieldspring/tr/text/SimpleToken.java @@ -0,0 +1,42 @@ +/////////////////////////////////////////////////////////////////////////////// +// Copyright (C) 2010 Travis Brown, The University of Texas at Austin +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +/////////////////////////////////////////////////////////////////////////////// +package opennlp.fieldspring.tr.text; + +import java.io.*; + +public class SimpleToken implements Token, Serializable { + + private static final long serialVersionUID = 42L; + + private final String form; + + public SimpleToken(String form) { + this.form = form; + } + + public String getForm() { + return this.form.toLowerCase(); + } + + public String getOrigForm() { + return this.form; + } + + public boolean isToponym() { + return false; + } +} + diff --git a/src/main/java/opennlp/fieldspring/tr/text/SimpleToponym.java b/src/main/java/opennlp/fieldspring/tr/text/SimpleToponym.java new file mode 100644 index 0000000..d8e6740 --- /dev/null +++ b/src/main/java/opennlp/fieldspring/tr/text/SimpleToponym.java @@ -0,0 +1,104 @@ +/////////////////////////////////////////////////////////////////////////////// +// Copyright (C) 2010 Travis Brown, The University of Texas at Austin +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
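// Illustrative usage sketch: a SimpleSentence over plain SimpleTokens. getForm() lowercases
// the surface string while getOrigForm() preserves it, which is the form the XML writers
// below emit in their tok/term attributes.
List<Token> tokens = new ArrayList<Token>();
tokens.add(new SimpleToken("Visited"));
tokens.add(new SimpleToken("Austin"));
Sentence<Token> s = new SimpleSentence<Token>("s0", tokens);
for (Token t : s.getTokens()) {
  System.out.println(t.getForm() + " / " + t.getOrigForm());   // "visited / Visited" etc.
}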
+/////////////////////////////////////////////////////////////////////////////// +package opennlp.fieldspring.tr.text; + +import java.util.Iterator; +import java.util.List; + +import opennlp.fieldspring.tr.topo.Location; +import java.io.*; + +public class SimpleToponym extends SimpleToken implements Toponym, Serializable { + + private static final long serialVersionUID = 42L; + + private List candidates; + private int goldIdx; + private int selectedIdx; + + public SimpleToponym(String form, List candidates) { + this(form, candidates, -1); + } + + public SimpleToponym(String form, List candidates, int goldIdx) { + this(form, candidates, goldIdx, -1); + } + + public SimpleToponym(String form, List candidates, int goldIdx, int selectedIdx) { + super(form); + this.candidates = candidates; + this.goldIdx = goldIdx; + this.selectedIdx = selectedIdx; + } + + public boolean hasGold() { + return this.goldIdx > -1; + } + + public Location getGold() { + return this.candidates.get(this.goldIdx); + } + + public int getGoldIdx() { + return this.goldIdx; + } + + public void setGoldIdx(int idx) { + this.goldIdx = idx; + } + + public boolean hasSelected() { + return this.selectedIdx > -1; + } + + public Location getSelected() { + return this.candidates.get(this.selectedIdx); + } + + public int getSelectedIdx() { + return this.selectedIdx; + } + + public void setSelectedIdx(int idx) { + this.selectedIdx = idx; + } + + public int getAmbiguity() { + return this.candidates.size(); + } + + public List getCandidates() { + return this.candidates; + } + + public void setCandidates(List candidates) { + this.candidates = candidates; + } + + public List getTokens() { + throw new UnsupportedOperationException("Can't currently get tokens."); + } + + public Iterator iterator() { + return this.candidates.iterator(); + } + + @Override + public boolean isToponym() { + return true; + } +} + diff --git a/src/main/java/opennlp/fieldspring/tr/text/StoredCorpus.java b/src/main/java/opennlp/fieldspring/tr/text/StoredCorpus.java new file mode 100644 index 0000000..f1ca273 --- /dev/null +++ b/src/main/java/opennlp/fieldspring/tr/text/StoredCorpus.java @@ -0,0 +1,35 @@ +/////////////////////////////////////////////////////////////////////////////// +// Copyright (C) 2010 Travis Brown, The University of Texas at Austin +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
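// Illustrative usage sketch: a SimpleToponym with its candidate Locations, where index 0 is
// the gold referent and index 1 the system selection. 'candidates' stands for a List of
// Location objects obtained elsewhere (e.g. from a gazetteer lookup).
SimpleToponym top = new SimpleToponym("Paris", candidates, 0, 1);
if (top.hasGold() && top.hasSelected()) {
  boolean correct = top.getGoldIdx() == top.getSelectedIdx();   // false in this setup
}
System.out.println("ambiguity: " + top.getAmbiguity());          // == candidates.size()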
+/////////////////////////////////////////////////////////////////////////////// +package opennlp.fieldspring.tr.text; + +import opennlp.fieldspring.tr.topo.Location; +import opennlp.fieldspring.tr.util.CountingLexicon; +import opennlp.fieldspring.tr.util.SimpleCountingLexicon; +import opennlp.fieldspring.tr.util.Span; +import java.io.*; + +public abstract class StoredCorpus extends Corpus implements Serializable { + public abstract int getDocumentCount(); + public abstract int getTokenTypeCount(); + public abstract int getTokenOrigTypeCount(); + public abstract int getToponymTypeCount(); + public abstract int getToponymOrigTypeCount(); + public abstract int getMaxToponymAmbiguity(); + public abstract double getAvgToponymAmbiguity(); + public abstract int getTokenCount(); + public abstract int getToponymTokenCount(); + public abstract void load(); +} diff --git a/src/main/java/opennlp/fieldspring/tr/text/StoredToken.java b/src/main/java/opennlp/fieldspring/tr/text/StoredToken.java new file mode 100644 index 0000000..8b864cd --- /dev/null +++ b/src/main/java/opennlp/fieldspring/tr/text/StoredToken.java @@ -0,0 +1,25 @@ +/////////////////////////////////////////////////////////////////////////////// +// Copyright (C) 2010 Travis Brown, The University of Texas at Austin +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +/////////////////////////////////////////////////////////////////////////////// +package opennlp.fieldspring.tr.text; +import java.io.*; + +public interface StoredToken extends Token, Serializable { + public int getIdx(); + public int getOrigIdx(); + public int getTypeCount(); + public int getOrigTypeCount(); +} + diff --git a/src/main/java/opennlp/fieldspring/tr/text/StoredToponym.java b/src/main/java/opennlp/fieldspring/tr/text/StoredToponym.java new file mode 100644 index 0000000..2d4e2ad --- /dev/null +++ b/src/main/java/opennlp/fieldspring/tr/text/StoredToponym.java @@ -0,0 +1,21 @@ +/////////////////////////////////////////////////////////////////////////////// +// Copyright (C) 2010 Travis Brown, The University of Texas at Austin +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
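// Illustrative usage sketch: a concrete StoredCorpus implementation exposes corpus-wide
// statistics through the abstract accessors above once load() has been called.
// 'storedCorpus' stands for whatever concrete subclass the pipeline constructs.
storedCorpus.load();
System.out.println("documents:       " + storedCorpus.getDocumentCount());
System.out.println("token types:     " + storedCorpus.getTokenTypeCount());
System.out.println("toponym tokens:  " + storedCorpus.getToponymTokenCount());
System.out.println("max ambiguity:   " + storedCorpus.getMaxToponymAmbiguity());
System.out.println("avg ambiguity:   " + storedCorpus.getAvgToponymAmbiguity());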
+/////////////////////////////////////////////////////////////////////////////// +package opennlp.fieldspring.tr.text; +import java.io.*; + +public interface StoredToponym extends StoredToken, Toponym, Serializable { +} + diff --git a/src/main/java/opennlp/fieldspring/tr/text/StreamCorpus.java b/src/main/java/opennlp/fieldspring/tr/text/StreamCorpus.java new file mode 100644 index 0000000..176d61f --- /dev/null +++ b/src/main/java/opennlp/fieldspring/tr/text/StreamCorpus.java @@ -0,0 +1,56 @@ +/////////////////////////////////////////////////////////////////////////////// +// Copyright (C) 2010 Travis Brown, The University of Texas at Austin +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +/////////////////////////////////////////////////////////////////////////////// +package opennlp.fieldspring.tr.text; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.Iterator; +import java.util.List; + +import com.google.common.collect.Iterators; + +public class StreamCorpus extends Corpus { + + private static final long serialVersionUID = 42L; + + private final List sources; + private boolean read; + + StreamCorpus() { + this.sources = new ArrayList(); + this.read = false; + } + + public Iterator> iterator() { + if (this.read) { + throw new UnsupportedOperationException("Cannot read a stream corpus more than once."); + } else { + this.read = true; + return Iterators.concat(this.sources.iterator()); + } + } + + public void addSource(DocumentSource source) { + this.sources.add(source); + } + + public void close() { + for (DocumentSource source : this.sources) { + source.close(); + } + } +} + diff --git a/src/main/java/opennlp/fieldspring/tr/text/Token.java b/src/main/java/opennlp/fieldspring/tr/text/Token.java new file mode 100644 index 0000000..7e85fec --- /dev/null +++ b/src/main/java/opennlp/fieldspring/tr/text/Token.java @@ -0,0 +1,25 @@ +/////////////////////////////////////////////////////////////////////////////// +// Copyright (C) 2010 Travis Brown, The University of Texas at Austin +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
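// Illustrative usage sketch: a StreamCorpus is strictly single-pass -- a second call to
// iterator() throws UnsupportedOperationException -- so consume it once and keep whatever
// you need. Its constructor is package-private, so 'streamCorpus' stands for an instance
// obtained through the library's corpus factory; 'tokenizer' is likewise assumed.
streamCorpus.addSource(new TrXMLDirSource(new File("corpora/tr-xml"), tokenizer));
for (Document<Token> doc : streamCorpus) {
  for (Sentence<Token> sent : doc) {
    // first and only pass over the underlying sources
  }
}
streamCorpus.close();   // closes every registered DocumentSource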
+/////////////////////////////////////////////////////////////////////////////// +package opennlp.fieldspring.tr.text; + +import java.io.*; + +public interface Token extends Serializable { + public String getForm(); + public String getOrigForm(); + public boolean isToponym(); +} + diff --git a/src/main/java/opennlp/fieldspring/tr/text/Toponym.java b/src/main/java/opennlp/fieldspring/tr/text/Toponym.java new file mode 100644 index 0000000..8558edd --- /dev/null +++ b/src/main/java/opennlp/fieldspring/tr/text/Toponym.java @@ -0,0 +1,40 @@ +/////////////////////////////////////////////////////////////////////////////// +// Copyright (C) 2010 Travis Brown, The University of Texas at Austin +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +/////////////////////////////////////////////////////////////////////////////// +package opennlp.fieldspring.tr.text; + +import java.util.List; +import java.io.*; + +import opennlp.fieldspring.tr.topo.Location; + +public interface Toponym extends Token, Iterable, Serializable { + public boolean hasGold(); + public Location getGold(); + public int getGoldIdx(); + public void setGoldIdx(int idx); + + public boolean hasSelected(); + public Location getSelected(); + public int getSelectedIdx(); + public void setSelectedIdx(int idx); + + public int getAmbiguity(); + public List getCandidates(); + public void setCandidates(List candidates); + + public List getTokens(); +} + diff --git a/src/main/java/opennlp/fieldspring/tr/text/io/CorpusKMLWriter.java b/src/main/java/opennlp/fieldspring/tr/text/io/CorpusKMLWriter.java new file mode 100644 index 0000000..774d758 --- /dev/null +++ b/src/main/java/opennlp/fieldspring/tr/text/io/CorpusKMLWriter.java @@ -0,0 +1,179 @@ +package opennlp.fieldspring.tr.text.io; + +import java.io.*; +import java.util.*; +import javax.xml.datatype.*; +import javax.xml.stream.*; + +import opennlp.fieldspring.tr.text.*; +import opennlp.fieldspring.tr.topo.*; +import opennlp.fieldspring.tr.util.*; + +public class CorpusKMLWriter { + + protected final Corpus corpus; + protected final XMLOutputFactory factory; + + protected Map locationCounts; + protected Map > contexts; + + protected boolean outputGoldLocations; + + public CorpusKMLWriter(Corpus corpus, boolean outputGoldLocations) { + this.corpus = corpus; + this.factory = XMLOutputFactory.newInstance(); + this.outputGoldLocations = outputGoldLocations; + } + + public CorpusKMLWriter(Corpus corpus) { + this(corpus, false); + } + + private void countLocationsAndPopulateContexts(Corpus corpus) { + locationCounts = new HashMap(); + contexts = new HashMap >(); + + for(Document doc : corpus) { + int sentIndex = 0; + for(Sentence sent : doc) { + int tokenIndex = 0; + //for(Toponym toponym : sent.getToponyms()) { + for(Token token : sent) { + if(token.isToponym()) { + Toponym toponym = (Toponym) token; + if((!outputGoldLocations && toponym.getAmbiguity() > 0 && toponym.hasSelected()) + || (outputGoldLocations && toponym.hasGold())) { + opennlp.fieldspring.tr.topo.Location loc; + 
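// The branch below resolves each counted toponym to a single Location: by default the
// system-selected candidate, or the gold candidate when outputGoldLocations was requested.
// That Location then keys both the per-location mention counts and the list of
// surrounding-sentence contexts used later for the KML spiral points.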
if(!outputGoldLocations) + loc = toponym.getCandidates().get(toponym.getSelectedIdx()); + else + loc = toponym.getCandidates().get(toponym.getGoldIdx()); + Integer prevCount = locationCounts.get(loc); + if(prevCount == null) + prevCount = 0; + locationCounts.put(loc, prevCount + 1); + + List curContexts = contexts.get(loc); + if(curContexts == null) + curContexts = new ArrayList(); + curContexts.add(getContextAround(doc, sentIndex, tokenIndex)); + contexts.put(loc, curContexts); + } + } + tokenIndex++; + } + sentIndex++; + } + } + } + + private String getContextAround(Document doc, int sentIndex, int tokenIndex) { + StringBuffer sb = new StringBuffer(); + + int curSentIndex = 0; + for(Sentence sent : doc) { + if(curSentIndex == sentIndex - 1 || curSentIndex == sentIndex + 1) { + for(Token token : sent) { + sb.append(token.getOrigForm()).append(" "); + //if(StringUtil.containsAlphanumeric(token.getOrigForm())) + // sb.append(" "); + } + if(curSentIndex == sentIndex + 1) + break; + } + else if(curSentIndex == sentIndex) { + int curTokenIndex = 0; + for(Token token : sent) { + if(curTokenIndex == tokenIndex) + sb.append(" "); + sb.append(token.getOrigForm()).append(" "); + if(curTokenIndex == tokenIndex) + sb.append(" "); + //if(StringUtil.containsAlphanumeric(token.getOrigForm())) + // sb.append(" "); + curTokenIndex++; + } + } + curSentIndex++; + } + + if(sb.length() > 0 || sb.charAt(sb.length() - 1) == ' ') + sb.deleteCharAt(sb.length()-1); + + return sb.toString(); + } + + protected XMLGregorianCalendar getCalendar() throws Exception { + return this.getCalendar(new Date()); + } + + protected XMLGregorianCalendar getCalendar(Date time) throws Exception { + XMLGregorianCalendar xgc = null; + GregorianCalendar gc = new GregorianCalendar(); + gc.setTime(time); + + xgc = DatatypeFactory.newInstance().newXMLGregorianCalendar(gc); + + return xgc; + } + + protected XMLStreamWriter createXMLStreamWriter(Writer writer) throws XMLStreamException { + return this.factory.createXMLStreamWriter(writer); + } + + protected XMLStreamWriter createXMLStreamWriter(OutputStream stream) throws XMLStreamException { + return this.factory.createXMLStreamWriter(stream, "UTF-8"); + } + + public void write(File file) throws Exception { + assert(!file.isDirectory()); + OutputStream stream = new BufferedOutputStream(new FileOutputStream(file)); + this.write(this.createXMLStreamWriter(stream)); + stream.close(); + } + + public void write(OutputStream stream) throws Exception { + this.write(this.createXMLStreamWriter(stream)); + } + + public void write(Writer writer) throws Exception { + this.write(this.createXMLStreamWriter(writer)); + } + + protected void write(XMLStreamWriter out) throws Exception { + + countLocationsAndPopulateContexts(corpus); + + KMLUtil.writeHeader(out, "corpus"); + + for (opennlp.fieldspring.tr.topo.Location loc : locationCounts.keySet()) { + this.writePlacemarkAndPolygon(out, loc); + this.writeContexts(out, loc); + } + + KMLUtil.writeFooter(out); + + out.close(); + } + + protected void writePlacemarkAndPolygon(XMLStreamWriter out, opennlp.fieldspring.tr.topo.Location loc) throws Exception { + String name = loc.getName(); + Coordinate coord = loc.getRegion().getCenter(); + int count = locationCounts.get(loc); + + KMLUtil.writePolygon(out, name, coord, KMLUtil.SIDES, KMLUtil.RADIUS, Math.log(count) * KMLUtil.BARSCALE); + } + + protected void writeContexts(XMLStreamWriter out, opennlp.fieldspring.tr.topo.Location loc) { + int i = 0; + for(String curContext : contexts.get(loc)) { + try { + 
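// Each stored context becomes one KML point placed on a spiral around the location's
// center (getNthSpiralPoint), so repeated mentions of the same place fan out instead of
// stacking on a single coordinate; the scaled polygon for the location was written just
// above by writePlacemarkAndPolygon.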
KMLUtil.writeSpiralPoint(out, loc.getName(), i, curContext, loc.getRegion().getCenter().getNthSpiralPoint(i, KMLUtil.SPIRAL_RADIUS), KMLUtil.RADIUS); + } catch(Exception e) { + e.printStackTrace(); + System.exit(1); + } + i++; + } + } +} diff --git a/src/main/java/opennlp/fieldspring/tr/text/io/CorpusXMLSource.java b/src/main/java/opennlp/fieldspring/tr/text/io/CorpusXMLSource.java new file mode 100644 index 0000000..e2c270e --- /dev/null +++ b/src/main/java/opennlp/fieldspring/tr/text/io/CorpusXMLSource.java @@ -0,0 +1,175 @@ +package opennlp.fieldspring.tr.text.io; + +import javax.xml.stream.XMLInputFactory; +import javax.xml.stream.XMLStreamException; +import javax.xml.stream.XMLStreamReader; +import java.io.*; +import java.util.*; + +import opennlp.fieldspring.tr.text.*; +import opennlp.fieldspring.tr.text.prep.*; +import opennlp.fieldspring.tr.util.*; +import opennlp.fieldspring.tr.topo.*; + +public class CorpusXMLSource extends DocumentSource { + private final XMLStreamReader in; + private final Tokenizer tokenizer; + + public CorpusXMLSource(Reader reader, Tokenizer tokenizer) throws XMLStreamException { + this.tokenizer = tokenizer; + + XMLInputFactory factory = XMLInputFactory.newInstance(); + this.in = factory.createXMLStreamReader(reader); + + while (this.in.hasNext() && this.in.next() != XMLStreamReader.START_ELEMENT) {} + if (this.in.getLocalName().equals("corpus")) { + this.in.nextTag(); + } + } + + private void nextTag() { + try { + this.in.nextTag(); + } catch (XMLStreamException e) { + System.err.println("Error while advancing TR-XML file."); + } + } + + public void close() { + try { + this.in.close(); + } catch (XMLStreamException e) { + System.err.println("Error while closing TR-XML file."); + } + } + + public boolean hasNext() { + return this.in.isStartElement() && this.in.getLocalName().equals("doc"); + } + + public Document next() { + assert this.in.isStartElement() && this.in.getLocalName().equals("doc"); + String id = CorpusXMLSource.this.in.getAttributeValue(null, "id"); + String goldLatS = CorpusXMLSource.this.in.getAttributeValue(null, "goldLat"); + String goldLngS = CorpusXMLSource.this.in.getAttributeValue(null, "goldLng"); + Coordinate goldCoord = null; + if(goldLatS != null && goldLngS != null) + goldCoord = Coordinate.fromDegrees(Double.parseDouble(goldLatS), Double.parseDouble(goldLngS)); + String systemLatS = CorpusXMLSource.this.in.getAttributeValue(null, "systemLat"); + String systemLngS = CorpusXMLSource.this.in.getAttributeValue(null, "systemLng"); + Coordinate systemCoord = null; + if(systemLatS != null && systemLngS != null) + systemCoord = Coordinate.fromDegrees(Double.parseDouble(systemLatS), Double.parseDouble(systemLngS)); + String timestamp = CorpusXMLSource.this.in.getAttributeValue(null, "timestamp"); + + CorpusXMLSource.this.nextTag(); + + return new Document(id, timestamp, goldCoord, systemCoord) { + private static final long serialVersionUID = 42L; + public Iterator> iterator() { + return new SentenceIterator() { + public boolean hasNext() { + if (CorpusXMLSource.this.in.isStartElement() && + CorpusXMLSource.this.in.getLocalName().equals("s")) { + return true; + } else { + return false; + } + } + + public Sentence next() { + String id = CorpusXMLSource.this.in.getAttributeValue(null, "id"); + List tokens = new ArrayList(); + List> toponymSpans = new ArrayList>(); + + try { + while (CorpusXMLSource.this.in.nextTag() == XMLStreamReader.START_ELEMENT && + (CorpusXMLSource.this.in.getLocalName().equals("w") || + 
CorpusXMLSource.this.in.getLocalName().equals("toponym"))) { + String name = CorpusXMLSource.this.in.getLocalName(); + + if (name.equals("w")) { + tokens.add(new SimpleToken(CorpusXMLSource.this.in.getAttributeValue(null, "tok"))); + } else { + int spanStart = tokens.size(); + String form = CorpusXMLSource.this.in.getAttributeValue(null, "term"); + List formTokens = CorpusXMLSource.this.tokenizer.tokenize(form); + + for (String formToken : CorpusXMLSource.this.tokenizer.tokenize(form)) { + tokens.add(new SimpleToken(formToken)); + } + + ArrayList locations = new ArrayList(); + int goldIdx = -1; + int selectedIdx = -1; + + if (CorpusXMLSource.this.in.nextTag() == XMLStreamReader.START_ELEMENT && + CorpusXMLSource.this.in.getLocalName().equals("candidates")) { + while (CorpusXMLSource.this.in.nextTag() == XMLStreamReader.START_ELEMENT && + CorpusXMLSource.this.in.getLocalName().equals("cand")) { + String gold = CorpusXMLSource.this.in.getAttributeValue(null, "gold"); + String selected = CorpusXMLSource.this.in.getAttributeValue(null, "selected"); + if (selected != null && (selected.equals("yes") || selected.equals("true"))) { + selectedIdx = locations.size(); + } + if (gold != null && (gold.equals("yes") || selected.equals("true"))) { + goldIdx = locations.size(); + } + + String locId = CorpusXMLSource.this.in.getAttributeValue(null, "id"); + String type = CorpusXMLSource.this.in.getAttributeValue(null, "type"); + String popString = CorpusXMLSource.this.in.getAttributeValue(null, "population"); + Integer population = null; + if(popString != null) + population = Integer.parseInt(popString); + String admin1code = CorpusXMLSource.this.in.getAttributeValue(null, "admin1code"); + + ArrayList representatives = new ArrayList(); + if(CorpusXMLSource.this.in.nextTag() == XMLStreamReader.START_ELEMENT && + CorpusXMLSource.this.in.getLocalName().equals("representatives")) { + while(CorpusXMLSource.this.in.nextTag() == XMLStreamReader.START_ELEMENT && + CorpusXMLSource.this.in.getLocalName().equals("rep")) { + double lat = Double.parseDouble(CorpusXMLSource.this.in.getAttributeValue(null, "lat")); + double lng = Double.parseDouble(CorpusXMLSource.this.in.getAttributeValue(null, "long")); + representatives.add(Coordinate.fromDegrees(lat, lng)); + CorpusXMLSource.this.nextTag(); + assert CorpusXMLSource.this.in.isEndElement() && + CorpusXMLSource.this.in.getLocalName().equals("rep"); + } + } + + Region region = new PointSetRegion(representatives); + Location loc = new Location(locId, form, region, type, population, admin1code); + locations.add(loc); + CorpusXMLSource.this.nextTag(); + assert CorpusXMLSource.this.in.isEndElement() && + CorpusXMLSource.this.in.getLocalName().equals("cand"); + } + } + + if (locations.size() > 0 /*&& goldIdx > -1*/) { + Toponym toponym = new SimpleToponym(form, locations, goldIdx, selectedIdx); + toponymSpans.add(new Span(spanStart, tokens.size(), toponym)); + } + } + CorpusXMLSource.this.nextTag(); + assert CorpusXMLSource.this.in.isStartElement() && + (CorpusXMLSource.this.in.getLocalName().equals("w") || + CorpusXMLSource.this.in.getLocalName().equals("toponym")); + } + } catch (XMLStreamException e) { + System.err.println("Error while reading TR-XML file."); + } + + CorpusXMLSource.this.nextTag(); + if(CorpusXMLSource.this.in.getLocalName().equals("doc") + && CorpusXMLSource.this.in.isEndElement()) + CorpusXMLSource.this.nextTag(); + return new SimpleSentence(id, tokens, toponymSpans); + } + }; + } + }; + } + +} diff --git 
a/src/main/java/opennlp/fieldspring/tr/text/io/CorpusXMLWriter.java b/src/main/java/opennlp/fieldspring/tr/text/io/CorpusXMLWriter.java new file mode 100644 index 0000000..4a6ae47 --- /dev/null +++ b/src/main/java/opennlp/fieldspring/tr/text/io/CorpusXMLWriter.java @@ -0,0 +1,264 @@ +/////////////////////////////////////////////////////////////////////////////// +// Copyright (C) 2010 Travis Brown, The University of Texas at Austin +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +/////////////////////////////////////////////////////////////////////////////// +package opennlp.fieldspring.tr.text.io; + +import java.io.BufferedOutputStream; +import java.io.File; +import java.io.FileOutputStream; +import java.io.IOException; +import java.io.OutputStream; +import java.io.Writer; + +import java.util.ArrayList; +import java.util.Date; +import java.util.Iterator; +import java.util.GregorianCalendar; +import java.util.List; + +import javax.xml.datatype.DatatypeConfigurationException; +import javax.xml.datatype.DatatypeFactory; +import javax.xml.datatype.XMLGregorianCalendar; + +import javax.xml.stream.XMLOutputFactory; +import javax.xml.stream.XMLStreamException; +import javax.xml.stream.XMLStreamWriter; + +import opennlp.fieldspring.tr.text.Corpus; +import opennlp.fieldspring.tr.text.Document; +import opennlp.fieldspring.tr.text.Sentence; +import opennlp.fieldspring.tr.text.Token; +import opennlp.fieldspring.tr.text.Toponym; + +import opennlp.fieldspring.tr.topo.Location; +import opennlp.fieldspring.tr.topo.Coordinate; + +import opennlp.fieldspring.tr.app.BaseApp; + +public class CorpusXMLWriter { + protected final Corpus corpus; + protected final XMLOutputFactory factory; + + public CorpusXMLWriter(Corpus corpus) { + this.corpus = corpus; + this.factory = XMLOutputFactory.newInstance(); + } + + protected XMLGregorianCalendar getCalendar() { + return this.getCalendar(new Date()); + } + + protected XMLGregorianCalendar getCalendar(Date time) { + XMLGregorianCalendar xgc = null; + GregorianCalendar gc = new GregorianCalendar(); + gc.setTime(time); + try { + xgc = DatatypeFactory.newInstance().newXMLGregorianCalendar(gc); + } catch (DatatypeConfigurationException e) { + System.err.println(e); + System.exit(1); + } + return xgc; + } + + protected XMLStreamWriter createXMLStreamWriter(Writer writer) throws XMLStreamException { + return this.factory.createXMLStreamWriter(writer); + } + + protected XMLStreamWriter createXMLStreamWriter(OutputStream stream) throws XMLStreamException { + return this.factory.createXMLStreamWriter(stream, "UTF-8"); + } + + protected void writeDocument(XMLStreamWriter out, Document document) throws XMLStreamException { + out.writeStartElement("doc"); + if (document.getId() != null) { + out.writeAttribute("id", document.getId()); + } + Coordinate systemCoord = document.getSystemCoord(); + if(systemCoord != null) { + out.writeAttribute("systemLat", systemCoord.getLatDegrees() + ""); + out.writeAttribute("systemLng", systemCoord.getLngDegrees() + ""); + } + Coordinate 
goldCoord = document.getGoldCoord(); + if(goldCoord != null) { + out.writeAttribute("goldLat", goldCoord.getLatDegrees() + ""); + out.writeAttribute("goldLng", goldCoord.getLngDegrees() + ""); + } + if(document.getTimestamp() != null) { + out.writeAttribute("timestamp", document.getTimestamp()); + } + for (Sentence sentence : document) { + this.writeSentence(out, sentence); + } + out.writeEndElement(); + } + + // BUGGY!!! Won't output multiword toponyms as such!!! + protected void writeSentence(XMLStreamWriter out, Sentence sentence) throws XMLStreamException { + out.writeStartElement("s"); + if (sentence.getId() != null) { + out.writeAttribute("id", sentence.getId()); + } + for (Token token : sentence) { + if (token.isToponym()) { + this.writeToponym(out, (Toponym) token); + } else { + this.writeToken(out, token); + } + } + out.writeEndElement(); + } + + private static String okChars = "!?:;,'\"|+=-_*^%$#@`~(){}[]\\/"; + + public static boolean isSanitary(/*Enum corpusFormat, */String s) { + //if(corpusFormat != BaseApp.CORPUS_FORMAT.GEOTEXT) + // return true; + for(int i = 0; i < s.length(); i++) { + char curChar = s.charAt(i); + if(!Character.isLetterOrDigit(curChar) && !okChars.contains(curChar + "")) { + return false; + } + } + return true; + } + + protected void writeToken(XMLStreamWriter out, Token token) throws XMLStreamException { + out.writeStartElement("w"); + if(isSanitary(/*corpus.getFormat(), */token.getOrigForm())) + out.writeAttribute("tok", token.getOrigForm()); + else + out.writeAttribute("tok", " "); + out.writeEndElement(); + } + + protected void writeToponym(XMLStreamWriter out, Toponym toponym) throws XMLStreamException { + out.writeStartElement("toponym"); + if(isSanitary(/*corpus.getFormat(), */toponym.getOrigForm())) + out.writeAttribute("term", toponym.getOrigForm()); + else + out.writeAttribute("term", " "); + out.writeStartElement("candidates"); + Location gold = toponym.hasGold() ? toponym.getGold() : null; + Location selected = toponym.hasSelected() ? 
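// (The gold and selected Locations captured in these two locals are later compared by
// reference in writeLocation, which emits gold="true" / selected="true" on the matching
// <cand> elements so that CorpusXMLSource can reconstruct both indices on re-read.)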
toponym.getSelected() : null; + + for (Location location : toponym) { + this.writeLocation(out, location, gold, selected); + } + out.writeEndElement(); + out.writeEndElement(); + } + + protected void writeLocation(XMLStreamWriter out, Location location, Location gold, Location selected) throws XMLStreamException { + //location.removeNaNs(); + out.writeStartElement("cand"); + out.writeAttribute("id", String.format("c%d", location.getId())); + out.writeAttribute("lat", String.format("%f", location.getRegion().getCenter().getLatDegrees())); + out.writeAttribute("long", String.format("%f", location.getRegion().getCenter().getLngDegrees())); + out.writeAttribute("type", String.format("%s", location.getType())); + out.writeAttribute("admin1code", String.format("%s", location.getAdmin1Code())); + int population = location.getPopulation(); + if (population > 0) { + out.writeAttribute("population", String.format("%d", population)); + } + if (location == gold) { + out.writeAttribute("gold", "true"); + } + if (location == selected) { + out.writeAttribute("selected", "true"); + } + + out.writeStartElement("representatives"); + for(Coordinate coord : location.getRegion().getRepresentatives()) { + out.writeStartElement("rep"); + out.writeAttribute("lat", String.format("%f", coord.getLatDegrees())); + out.writeAttribute("long", String.format("%f", coord.getLngDegrees())); + out.writeEndElement(); + } + out.writeEndElement(); + + out.writeEndElement(); + } + + public void write(File file) { + this.write(file, "doc-"); + } + + public void write(File file, String prefix) { + try { + if (file.isDirectory()) { + int idx = 0; + for (Document document : this.corpus) { + File docFile = new File(file, String.format("%s%06d.xml", prefix, idx)); + OutputStream stream = new BufferedOutputStream(new FileOutputStream(docFile)); + XMLStreamWriter out = this.createXMLStreamWriter(stream); + out.writeStartDocument("UTF-8", "1.0"); + out.writeStartElement("corpus"); + out.writeAttribute("created", this.getCalendar().toString()); + this.writeDocument(out, document); + out.writeEndElement(); + out.close(); + stream.close(); + idx++; + } + } else { + OutputStream stream = new BufferedOutputStream(new FileOutputStream(file)); + this.write(this.createXMLStreamWriter(stream)); + stream.close(); + } + } catch (XMLStreamException e) { + System.err.println(e); + System.exit(1); + } catch (IOException e) { + System.err.println(e); + System.exit(1); + } + } + + public void write(OutputStream stream) { + try { + this.write(this.createXMLStreamWriter(stream)); + } catch (XMLStreamException e) { + System.err.println(e); + System.exit(1); + } + } + + public void write(Writer writer) { + try { + this.write(this.createXMLStreamWriter(writer)); + } catch (XMLStreamException e) { + System.err.println(e); + System.exit(1); + } + } + + protected void write(XMLStreamWriter out) { + try { + out.writeStartDocument("UTF-8", "1.0"); + out.writeStartElement("corpus"); + out.writeAttribute("created", this.getCalendar().toString()); + for (Document document : this.corpus) { + this.writeDocument(out, document); + } + out.writeEndElement(); + out.close(); + } catch (XMLStreamException e) { + System.err.println(e); + System.exit(1); + } + } +} + diff --git a/src/main/java/opennlp/fieldspring/tr/text/io/GeoTextCorpusKMLWriter.java b/src/main/java/opennlp/fieldspring/tr/text/io/GeoTextCorpusKMLWriter.java new file mode 100644 index 0000000..7989218 --- /dev/null +++ b/src/main/java/opennlp/fieldspring/tr/text/io/GeoTextCorpusKMLWriter.java @@ -0,0 +1,61 
@@ +package opennlp.fieldspring.tr.text.io; + +import opennlp.fieldspring.tr.text.*; +import javax.xml.stream.*; +import opennlp.fieldspring.tr.util.*; +import opennlp.fieldspring.tr.topo.*; + +public class GeoTextCorpusKMLWriter extends CorpusKMLWriter { + public GeoTextCorpusKMLWriter(Corpus corpus, boolean outputGoldLocations) { + super(corpus, outputGoldLocations); + } + + public GeoTextCorpusKMLWriter(Corpus corpus) { + this(corpus, false); + } + + protected void writeDocument(XMLStreamWriter out, Document document) throws XMLStreamException { + Coordinate coord = outputGoldLocations ? document.getGoldCoord() : document.getSystemCoord(); + + KMLUtil.writePlacemark(out, document.getId(), coord, KMLUtil.RADIUS); + int sentIndex = 0; + for(Sentence sent : document) { + StringBuffer curTweetSB = new StringBuffer(); + for(Token token : sent) { + if(isSanitary(token.getOrigForm())) + curTweetSB.append(token.getOrigForm()).append(" "); + } + String curTweet = curTweetSB.toString().trim(); + + KMLUtil.writeSpiralPoint(out, document.getId(), + sentIndex, curTweet, + coord.getNthSpiralPoint(sentIndex, KMLUtil.SPIRAL_RADIUS), KMLUtil.RADIUS); + sentIndex++; + } + } + + private String okChars = "!?:;,'\"|+=-_*^%$#@`~(){}[]\\/"; + + private boolean isSanitary(String s) { + for(int i = 0; i < s.length(); i++) { + char curChar = s.charAt(i); + if(!Character.isLetterOrDigit(curChar) && !okChars.contains(curChar + "")) { + return false; + } + } + return true; + } + + protected void write(XMLStreamWriter out) throws Exception { + + KMLUtil.writeHeader(out, "corpus"); + + for(Document doc : corpus) { + writeDocument(out, doc); + } + + KMLUtil.writeFooter(out); + + out.close(); + } +} diff --git a/src/main/java/opennlp/fieldspring/tr/text/io/GeoTextSource.java b/src/main/java/opennlp/fieldspring/tr/text/io/GeoTextSource.java new file mode 100644 index 0000000..8290261 --- /dev/null +++ b/src/main/java/opennlp/fieldspring/tr/text/io/GeoTextSource.java @@ -0,0 +1,195 @@ +package opennlp.fieldspring.tr.text.io; + +import java.io.*; +import java.util.*; + +//import javax.xml.stream.XMLInputFactory; +//import javax.xml.stream.XMLStreamException; +//import javax.xml.stream.XMLStreamReader; + +import opennlp.fieldspring.tr.text.*; +import opennlp.fieldspring.tr.text.prep.*; +import opennlp.fieldspring.tr.topo.*; +import opennlp.fieldspring.tr.util.*; + +public class GeoTextSource extends DocumentSource { + //private final XMLStreamReader in; + private final Tokenizer tokenizer; + private List documents; + private int curDocIndex = 0; + + + public GeoTextSource(Reader reader, Tokenizer tokenizer) throws Exception { + BufferedReader breader = new BufferedReader(reader); + this.tokenizer = tokenizer; + + this.documents = new ArrayList(); + + String curLine; + String prevDocId = "-1"; + GeoTextDocument curDoc = null; + int sentIndex = -1; + while(true) { + curLine = breader.readLine(); + if(curLine == null) + break; + + String[] tokens = curLine.split("\t"); + + if(tokens.length < 6) + continue; + + String docId = tokens[0]; + long userId = Long.parseLong(docId.substring(docId.indexOf("_")+1), 16); + + long fold = (userId % 5); + fold = fold==0? 
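// Remap a remainder of 0 to fold 5 so the folds range over 1..5; only folds 1-4
// (train + dev) are read below, and GeoTextDocument maps folds 1-3 to TRAIN and 4 to DEV.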
5 : fold; + if(fold >= 1 && fold <= 4) { // reads train and dev set only + + if(!docId.equals(prevDocId)) { + curDoc = new GeoTextDocument(docId, tokens[1], + Double.parseDouble(tokens[3]), + Double.parseDouble(tokens[4]), + fold); + documents.add(curDoc); + sentIndex = -1; + } + prevDocId = docId; + + String rawSent = tokens[5]; + sentIndex++; + List wList = new ArrayList(); + + for(String w : tokenizer.tokenize(rawSent)) { + wList.add(new SimpleToken(w)); + } + + curDoc.addSentence(new SimpleSentence("" + sentIndex, wList)); + } + } + + //XMLInputFactory factory = XMLInputFactory.newInstance(); + //this.in = factory.createXMLStreamReader(reader); + + //while (this.in.hasNext() && this.in.next() != XMLStreamReader.START_ELEMENT) {} + //if (this.in.getLocalName().equals("corpus")) { + // this.in.nextTag(); + //} + } + + /*private void nextTag() { + try { + this.in.nextTag(); + } catch (XMLStreamException e) { + System.err.println("Error while advancing TR-XML file."); + } + } + + public void close() { + try { + this.in.close(); + } catch (XMLStreamException e) { + System.err.println("Error while closing TR-XML file."); + } + }*/ + + public Iterator iterator() { + return documents.iterator(); + } + + public boolean hasNext() { + return this.curDocIndex < this.documents.size();//this.in.isStartElement() && this.in.getLocalName().equals("doc"); + } + + public Document next() { + + return documents.get(curDocIndex++); + /*return new Document(id) { + public Iterator> iterator() { + return new SentenceIterator() { + public boolean hasNext() { + + } + } + } + }*/ + /*assert this.in.isStartElement() && this.in.getLocalName().equals("doc"); + String id = TrXMLSource.this.in.getAttributeValue(null, "id"); + TrXMLSource.this.nextTag(); + + return new Document(id) { + public Iterator> iterator() { + return new SentenceIterator() { + public boolean hasNext() { + if (TrXMLSource.this.in.isStartElement() && + TrXMLSource.this.in.getLocalName().equals("s")) { + return true; + } else { + return false; + } + } + + public Sentence next() { + String id = TrXMLSource.this.in.getAttributeValue(null, "id"); + List tokens = new ArrayList(); + List> toponymSpans = new ArrayList>(); + + try { + while (TrXMLSource.this.in.nextTag() == XMLStreamReader.START_ELEMENT && + (TrXMLSource.this.in.getLocalName().equals("w") || + TrXMLSource.this.in.getLocalName().equals("toponym"))) { + String name = TrXMLSource.this.in.getLocalName(); + + if (name.equals("w")) { + tokens.add(new SimpleToken(TrXMLSource.this.in.getAttributeValue(null, "tok"))); + } else { + int spanStart = tokens.size(); + String form = TrXMLSource.this.in.getAttributeValue(null, "term"); + List formTokens = TrXMLSource.this.tokenizer.tokenize(form); + + for (String formToken : TrXMLSource.this.tokenizer.tokenize(form)) { + tokens.add(new SimpleToken(formToken)); + } + + ArrayList locations = new ArrayList(); + int goldIdx = -1; + + if (TrXMLSource.this.in.nextTag() == XMLStreamReader.START_ELEMENT && + TrXMLSource.this.in.getLocalName().equals("candidates")) { + while (TrXMLSource.this.in.nextTag() == XMLStreamReader.START_ELEMENT && + TrXMLSource.this.in.getLocalName().equals("cand")) { + String selected = TrXMLSource.this.in.getAttributeValue(null, "selected"); + if (selected != null && selected.equals("yes")) { + goldIdx = locations.size(); + } + + double lat = Double.parseDouble(TrXMLSource.this.in.getAttributeValue(null, "lat")); + double lng = Double.parseDouble(TrXMLSource.this.in.getAttributeValue(null, "long")); + Region region = new 
PointRegion(Coordinate.fromDegrees(lat, lng)); + locations.add(new Location(form, region)); + TrXMLSource.this.nextTag(); + assert TrXMLSource.this.in.isEndElement() && + TrXMLSource.this.in.getLocalName().equals("cand"); + } + } + + if (locations.size() > 0 && goldIdx > -1) { + Toponym toponym = new SimpleToponym(form, locations, goldIdx); + toponymSpans.add(new Span(spanStart, tokens.size(), toponym)); + } + } + TrXMLSource.this.nextTag(); + } + } catch (XMLStreamException e) { + System.err.println("Error while reading TR-XML file."); + } + + TrXMLSource.this.nextTag(); + return new SimpleSentence(id, tokens, toponymSpans); + } + }; + } + };*/ + } +} + diff --git a/src/main/java/opennlp/fieldspring/tr/text/io/PlainTextDirSource.java b/src/main/java/opennlp/fieldspring/tr/text/io/PlainTextDirSource.java new file mode 100644 index 0000000..72b4733 --- /dev/null +++ b/src/main/java/opennlp/fieldspring/tr/text/io/PlainTextDirSource.java @@ -0,0 +1,106 @@ +/** + * + */ +package opennlp.fieldspring.tr.text.io; + +import java.io.BufferedReader; +import java.io.File; +import java.io.FileNotFoundException; +import java.io.FileReader; +import java.io.FilenameFilter; +import java.io.IOException; +import java.util.Arrays; +import java.util.Vector; + +import javax.xml.stream.XMLStreamException; + +import opennlp.fieldspring.tr.text.Document; +import opennlp.fieldspring.tr.text.DocumentSource; +import opennlp.fieldspring.tr.text.Token; +import opennlp.fieldspring.tr.text.prep.SentenceDivider; +import opennlp.fieldspring.tr.text.prep.Tokenizer; + +/** + * @author abhimanu kumar + * + */ +public class PlainTextDirSource extends DocumentSource { + + private final Tokenizer tokenizer; + private final SentenceDivider divider; + private final Vector files; + private int currentIdx; + private PlainTextSource current; + + + public PlainTextDirSource(File directory, SentenceDivider divider, Tokenizer tokenizer) { + this.divider = divider; + this.tokenizer = tokenizer; + files=new Vector(); + FilenameFilter filter=new FilenameFilter() { + public boolean accept(File dir, String name) { + return name.endsWith(".txt"); + } + }; + listFiles(directory,filter); + + // this.files = files == null ? 
new File[0] : files; + // Arrays.sort(this.files); + + this.currentIdx = 0; + this.nextFile(); + } + + private void listFiles(File directory, FilenameFilter filter) { + File[] childrenTextFiles=directory.listFiles(filter); + for(File file : childrenTextFiles){ + if(file!=null && !file.isDirectory()) + files.add(file); + } + File[] childrenDir=directory.listFiles(); + for(File file:childrenDir){ + if(file.isDirectory()) + listFiles(file,filter); + } + return; + } + + private void nextFile() { + try { + if (this.current != null) { + this.current.close(); + } + if (this.currentIdx < this.files.size()) { + File currentFile = this.files.get(this.currentIdx); + this.current = new PlainTextSource(new BufferedReader(new FileReader(currentFile)), this.divider, this.tokenizer, currentFile.getName()); + } + } catch (IOException e) { + System.err.println("Error while reading text file "+this.files.get(this.currentIdx).getName()); + } + } + + public void close() { + if (this.current != null) { + this.current.close(); + } + } + + public boolean hasNext() { + if (this.currentIdx < this.files.size()) { + if (this.current.hasNext()) { + return true; + } else { + this.currentIdx++; + this.nextFile(); + return this.hasNext(); + } + } else { + return false; + } + } + + public Document next() { + return this.current.next(); + } + +} diff --git a/src/main/java/opennlp/fieldspring/tr/text/io/PlainTextSource.java b/src/main/java/opennlp/fieldspring/tr/text/io/PlainTextSource.java new file mode 100644 index 0000000..bbbd804 --- /dev/null +++ b/src/main/java/opennlp/fieldspring/tr/text/io/PlainTextSource.java @@ -0,0 +1,115 @@ +/////////////////////////////////////////////////////////////////////////////// +// Copyright (C) 2010 Travis Brown, The University of Texas at Austin +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
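// Illustrative usage sketch: PlainTextDirSource recursively collects every *.txt file under
// a directory and streams Documents from each one via PlainTextSource. The directory path
// is invented, and 'divider'/'tokenizer' stand for whatever SentenceDivider and Tokenizer
// implementations the pipeline is configured with.
DocumentSource source = new PlainTextDirSource(new File("corpora/plain"), divider, tokenizer);
while (source.hasNext()) {
  Document<Token> doc = source.next();
  // process doc ...
}
source.close();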
+/////////////////////////////////////////////////////////////////////////////// +package opennlp.fieldspring.tr.text.io; + +import java.io.BufferedReader; +import java.io.IOException; +import java.util.ArrayList; +import java.util.Iterator; +import java.util.List; + +import opennlp.fieldspring.tr.text.Document; +import opennlp.fieldspring.tr.text.Sentence; +import opennlp.fieldspring.tr.text.SimpleToken; +import opennlp.fieldspring.tr.text.Token; +//import opennlp.fieldspring.tr.text.ner.NamedEntityRecognizer; +import opennlp.fieldspring.tr.text.prep.Tokenizer; +import opennlp.fieldspring.tr.text.prep.SentenceDivider; +//import opennlp.fieldspring.tr.topo.gaz.Gazetteer; + +public class PlainTextSource extends TextSource { + private final SentenceDivider divider; + private final Tokenizer tokenizer; + //private final NamedEntityRecognizer recognizer; + //private final Gazetteer gazetteer; + + private int number; + private String current; + private int currentIdx; + private int parasPerDocument; + private String docId; + + public PlainTextSource(BufferedReader reader, SentenceDivider divider, Tokenizer tokenizer, String id) + throws IOException { + this(reader, divider, tokenizer, id, -1); + } + + public PlainTextSource(BufferedReader reader, SentenceDivider divider, Tokenizer tokenizer, + String id, int parasPerDocument) + throws IOException { + super(reader); + this.divider = divider; + this.tokenizer = tokenizer; + //this.recognizer = recognizer; + //this.gazetteer = gazetteer; + this.current = this.readLine(); + this.number = 0; + this.currentIdx = 0; + this.parasPerDocument = parasPerDocument; + this.docId = id; + } + + public boolean hasNext() { + return this.current != null; + } + + public Document next() { + return new Document(this.docId) { + private static final long serialVersionUID = 42L; + public Iterator> iterator() { + final List> sentences = new ArrayList>(); + while (PlainTextSource.this.current != null && + (PlainTextSource.this.parasPerDocument == -1 || + (PlainTextSource.this.current.length() > 0 && + PlainTextSource.this.number < PlainTextSource.this.parasPerDocument))) { + for (String sentence : PlainTextSource.this.divider.divide(PlainTextSource.this.current)) { + List tokens = new ArrayList(); + for (String token : PlainTextSource.this.tokenizer.tokenize(sentence)) { + tokens.add(new SimpleToken(token)); + } + sentences.add(tokens); + } + PlainTextSource.this.number++; + PlainTextSource.this.current = PlainTextSource.this.readLine(); + } + PlainTextSource.this.number = 0; + if (PlainTextSource.this.current != null && + PlainTextSource.this.current.length() == 0) { + PlainTextSource.this.current = PlainTextSource.this.readLine(); + } + + return new SentenceIterator() { + private int idx = 0; + + public boolean hasNext() { + return this.idx < sentences.size(); + } + + public Sentence next() { + final int idx = this.idx++; + return new Sentence(null) { + private static final long serialVersionUID = 42L; + public Iterator tokens() { + return sentences.get(idx).iterator(); + } + }; + } + }; + } + }; + } +} + diff --git a/src/main/java/opennlp/fieldspring/tr/text/io/TextSource.java b/src/main/java/opennlp/fieldspring/tr/text/io/TextSource.java new file mode 100644 index 0000000..0f746f9 --- /dev/null +++ b/src/main/java/opennlp/fieldspring/tr/text/io/TextSource.java @@ -0,0 +1,52 @@ +/////////////////////////////////////////////////////////////////////////////// +// Copyright (C) 2010 Travis Brown, The University of Texas at Austin +// +// Licensed under the Apache 
License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +/////////////////////////////////////////////////////////////////////////////// +package opennlp.fieldspring.tr.text.io; + +import java.io.BufferedReader; +import java.io.Closeable; +import java.io.IOException; +import java.util.Iterator; + +import opennlp.fieldspring.tr.text.Document; +import opennlp.fieldspring.tr.text.DocumentSource; +import opennlp.fieldspring.tr.text.Token; + +public abstract class TextSource extends DocumentSource { + protected final BufferedReader reader; + + public TextSource(BufferedReader reader) throws IOException { + this.reader = reader; + } + + protected String readLine() { + String line = null; + try { + line = this.reader.readLine(); + } catch (IOException e) { + System.err.println("Error while reading document source."); + } + return line; + } + + public void close() { + try { + this.reader.close(); + } catch (IOException e) { + System.err.println("Error while closing document source."); + } + } +} + diff --git a/src/main/java/opennlp/fieldspring/tr/text/io/TrXMLDirSource.java b/src/main/java/opennlp/fieldspring/tr/text/io/TrXMLDirSource.java new file mode 100644 index 0000000..80a1280 --- /dev/null +++ b/src/main/java/opennlp/fieldspring/tr/text/io/TrXMLDirSource.java @@ -0,0 +1,107 @@ +/////////////////////////////////////////////////////////////////////////////// +// Copyright (C) 2010 Travis Brown, The University of Texas at Austin +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
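// Illustrative usage sketch: PlainTextSource (above) treats each non-empty input line as a
// paragraph and groups up to parasPerDocument of them into one Document; an empty line also
// closes the current Document, and the whole file becomes a single Document when that
// argument is -1 (as in the four-argument constructor). The file name is invented, and
// 'divider'/'tokenizer' are assumed to be supplied by the caller.
BufferedReader reader = new BufferedReader(new FileReader("article.txt"));
PlainTextSource source = new PlainTextSource(reader, divider, tokenizer, "article.txt", 5);
while (source.hasNext()) {
  for (Sentence<Token> sent : source.next()) {
    for (Token tok : sent) {
      // tok.getOrigForm() preserves the surface form; getForm() lowercases it
    }
  }
}
source.close();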
+/////////////////////////////////////////////////////////////////////////////// +package opennlp.fieldspring.tr.text.io; + +import java.io.BufferedReader; +import java.io.File; +import java.io.FilenameFilter; +import java.io.FileNotFoundException; +import java.io.FileReader; +import java.io.IOException; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Iterator; +import java.util.List; + +import javax.xml.stream.XMLInputFactory; +import javax.xml.stream.XMLStreamException; +import javax.xml.stream.XMLStreamReader; + +import com.google.common.collect.Iterators; + +import opennlp.fieldspring.tr.text.Document; +import opennlp.fieldspring.tr.text.DocumentSource; +import opennlp.fieldspring.tr.text.Token; +import opennlp.fieldspring.tr.text.prep.Tokenizer; + +public class TrXMLDirSource extends DocumentSource { + private final Tokenizer tokenizer; + private final File[] files; + private int currentIdx; + private TrXMLSource current; + private int sentsPerDocument; + + public TrXMLDirSource(File directory, Tokenizer tokenizer) { + this(directory, tokenizer, -1); + } + + public TrXMLDirSource(File directory, Tokenizer tokenizer, int sentsPerDocument) { + this.tokenizer = tokenizer; + this.sentsPerDocument = sentsPerDocument; + File[] files = directory.listFiles(new FilenameFilter() { + public boolean accept(File dir, String name) { + return name.endsWith(".xml"); + } + }); + + this.files = files == null ? new File[0] : files; + Arrays.sort(this.files); + + this.currentIdx = 0; + this.nextFile(); + } + + private void nextFile() { + try { + if (this.current != null) { + this.current.close(); + } + if (this.currentIdx < this.files.length) { + this.current = new TrXMLSource(new BufferedReader(new FileReader(this.files[this.currentIdx])), this.tokenizer, this.sentsPerDocument); + } + } catch (XMLStreamException e) { + System.err.println("Error while reading TR-XML directory file."); + } catch (FileNotFoundException e) { + System.err.println("Error while reading TR-XML directory file."); + } + } + + public void close() { + if (this.current != null) { + this.current.close(); + } + } + + public boolean hasNext() { + if (this.currentIdx < this.files.length) { + if (this.current.hasNext()) { + return true; + } else { + this.currentIdx++; + this.nextFile(); + return this.hasNext(); + } + } else { + return false; + } + } + + public Document next() { + return this.current.next(); + } +} + diff --git a/src/main/java/opennlp/fieldspring/tr/text/io/TrXMLSource.java b/src/main/java/opennlp/fieldspring/tr/text/io/TrXMLSource.java new file mode 100644 index 0000000..d2b397e --- /dev/null +++ b/src/main/java/opennlp/fieldspring/tr/text/io/TrXMLSource.java @@ -0,0 +1,198 @@ +/////////////////////////////////////////////////////////////////////////////// +// Copyright (C) 2010 Travis Brown, The University of Texas at Austin +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
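// Illustrative usage sketch: TrXMLDirSource streams Documents from every *.xml file in a
// directory (sorted by file name), optionally splitting each document every sentsPerDocument
// sentences via TrXMLSource below, in which case the parts get ids like "<docid>.p0",
// "<docid>.p1", ... The path is invented and 'tokenizer' is assumed to be supplied.
TrXMLDirSource source = new TrXMLDirSource(new File("corpora/tr-xml"), tokenizer, 50);
while (source.hasNext()) {
  Document<Token> doc = source.next();
  // process doc ...
}
source.close();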
+/////////////////////////////////////////////////////////////////////////////// +package opennlp.fieldspring.tr.text.io; + +import java.io.Reader; +import java.io.IOException; + +import java.util.ArrayList; +import java.util.Iterator; +import java.util.List; + +import javax.xml.stream.XMLInputFactory; +import javax.xml.stream.XMLStreamException; +import javax.xml.stream.XMLStreamReader; + +import opennlp.fieldspring.tr.text.Corpus; +import opennlp.fieldspring.tr.text.Document; +import opennlp.fieldspring.tr.text.DocumentSource; +import opennlp.fieldspring.tr.text.Sentence; +import opennlp.fieldspring.tr.text.SimpleSentence; +import opennlp.fieldspring.tr.text.SimpleToken; +import opennlp.fieldspring.tr.text.SimpleToponym; +import opennlp.fieldspring.tr.text.Token; +import opennlp.fieldspring.tr.text.Toponym; +import opennlp.fieldspring.tr.text.prep.Tokenizer; +import opennlp.fieldspring.tr.topo.Coordinate; +import opennlp.fieldspring.tr.topo.Location; +import opennlp.fieldspring.tr.topo.PointRegion; +import opennlp.fieldspring.tr.topo.Region; +import opennlp.fieldspring.tr.util.Span; + +public class TrXMLSource extends DocumentSource { + private final XMLStreamReader in; + private final Tokenizer tokenizer; + private boolean corpusWrapped = false; + private int sentsPerDocument; + private int partOfDoc = 0; + private String curDocId; + + public TrXMLSource(Reader reader, Tokenizer tokenizer) throws XMLStreamException { + this(reader, tokenizer, -1); + } + + public TrXMLSource(Reader reader, Tokenizer tokenizer, int sentsPerDocument) throws XMLStreamException { + this.tokenizer = tokenizer; + this.sentsPerDocument = sentsPerDocument; + + XMLInputFactory factory = XMLInputFactory.newInstance(); + this.in = factory.createXMLStreamReader(reader); + + while (this.in.hasNext() && this.in.next() != XMLStreamReader.START_ELEMENT) {} + if (this.in.getLocalName().equals("corpus")) { + this.in.nextTag(); + this.corpusWrapped = true; + } + } + + private void nextTag() { + try { + this.in.nextTag(); + } catch (XMLStreamException e) { + System.err.println("Error while advancing TR-XML file."); + } + } + + public void close() { + try { + this.in.close(); + } catch (XMLStreamException e) { + System.err.println("Error while closing TR-XML file."); + } + } + + public boolean hasNext() { + //try { + if(this.in.isEndElement() && this.in.getLocalName().equals("doc") && this.corpusWrapped) { + this.nextTag(); + } + //} catch(XMLStreamException e) { + //System.err.println("Error while closing TR-XML file."); + //} + if(this.in.getLocalName().equals("doc")) + TrXMLSource.this.partOfDoc = 0; + return this.in.isStartElement() && (this.in.getLocalName().equals("doc") || this.in.getLocalName().equals("s")); + } + + public Document next() { + //assert this.in.isStartElement() && this.in.getLocalName().equals("doc"); + String id; + if(TrXMLSource.this.sentsPerDocument <= 0) + id = TrXMLSource.this.in.getAttributeValue(null, "id"); + else { + if(TrXMLSource.this.partOfDoc == 0) + TrXMLSource.this.curDocId = TrXMLSource.this.in.getAttributeValue(null, "id"); + + id = TrXMLSource.this.curDocId + ".p" + TrXMLSource.this.partOfDoc; + } + if(this.in.getLocalName().equals("doc")) + TrXMLSource.this.nextTag(); + + return new Document(id) { + private static final long serialVersionUID = 42L; + public Iterator> iterator() { + return new SentenceIterator() { + int sentNumber = 0; + public boolean hasNext() { + if (TrXMLSource.this.in.isStartElement() && + TrXMLSource.this.in.getLocalName().equals("s") && + 
(TrXMLSource.this.sentsPerDocument <= 0 || + sentNumber < TrXMLSource.this.sentsPerDocument)) { + return true; + } else { + return false; + } + } + + public Sentence next() { + String id = TrXMLSource.this.in.getAttributeValue(null, "id"); + List tokens = new ArrayList(); + List> toponymSpans = new ArrayList>(); + + try { + while (TrXMLSource.this.in.nextTag() == XMLStreamReader.START_ELEMENT && + (TrXMLSource.this.in.getLocalName().equals("w") || + TrXMLSource.this.in.getLocalName().equals("toponym"))) { + String name = TrXMLSource.this.in.getLocalName(); + + if (name.equals("w")) { + tokens.add(new SimpleToken(TrXMLSource.this.in.getAttributeValue(null, "tok"))); + } else { + int spanStart = tokens.size(); + String form = TrXMLSource.this.in.getAttributeValue(null, "term"); + List formTokens = TrXMLSource.this.tokenizer.tokenize(form); + + for (String formToken : TrXMLSource.this.tokenizer.tokenize(form)) { + tokens.add(new SimpleToken(formToken)); + } + + ArrayList locations = new ArrayList(); + int goldIdx = -1; + + if (TrXMLSource.this.in.nextTag() == XMLStreamReader.START_ELEMENT && + TrXMLSource.this.in.getLocalName().equals("candidates")) { + while (TrXMLSource.this.in.nextTag() == XMLStreamReader.START_ELEMENT && + TrXMLSource.this.in.getLocalName().equals("cand")) { + String selected = TrXMLSource.this.in.getAttributeValue(null, "selected"); + if (selected != null && selected.equals("yes")) { + goldIdx = locations.size(); + } + + double lat = Double.parseDouble(TrXMLSource.this.in.getAttributeValue(null, "lat")); + double lng = Double.parseDouble(TrXMLSource.this.in.getAttributeValue(null, "long")); + Region region = new PointRegion(Coordinate.fromDegrees(lat, lng)); + locations.add(new Location(form, region)); + TrXMLSource.this.nextTag(); + assert TrXMLSource.this.in.isEndElement() && + TrXMLSource.this.in.getLocalName().equals("cand"); + } + } + + if (locations.size() > 0 && goldIdx > -1) { + Toponym toponym = new SimpleToponym(form, locations, goldIdx); + if(toponym.getGoldIdx() >= toponym.getCandidates().size()) + System.out.println(toponym.getForm()+": "+toponym.getGoldIdx()+"/"+toponym.getCandidates().size()); + toponymSpans.add(new Span(spanStart, tokens.size(), toponym)); + } + } + TrXMLSource.this.nextTag(); + } + } catch (XMLStreamException e) { + System.err.println("Error while reading TR-XML file."); + } + + TrXMLSource.this.nextTag(); + sentNumber++; + if(sentNumber == TrXMLSource.this.sentsPerDocument) + TrXMLSource.this.partOfDoc++; + return new SimpleSentence(id, tokens, toponymSpans); + } + }; + } + }; + } +} + diff --git a/src/main/java/opennlp/fieldspring/tr/text/prep/CandidateAnnotator.java b/src/main/java/opennlp/fieldspring/tr/text/prep/CandidateAnnotator.java new file mode 100644 index 0000000..1af3301 --- /dev/null +++ b/src/main/java/opennlp/fieldspring/tr/text/prep/CandidateAnnotator.java @@ -0,0 +1,90 @@ +/////////////////////////////////////////////////////////////////////////////// +// Copyright (C) 2010 Travis Brown, The University of Texas at Austin +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+// See the License for the specific language governing permissions and +// limitations under the License. +/////////////////////////////////////////////////////////////////////////////// +package opennlp.fieldspring.tr.text.prep; + +import java.util.ArrayList; +import java.util.Iterator; +import java.util.List; + +import opennlp.fieldspring.tr.text.Corpus; +import opennlp.fieldspring.tr.text.Document; +import opennlp.fieldspring.tr.text.DocumentSource; +import opennlp.fieldspring.tr.text.DocumentSourceWrapper; +import opennlp.fieldspring.tr.text.Sentence; +import opennlp.fieldspring.tr.text.SimpleSentence; +import opennlp.fieldspring.tr.text.SimpleToponym; +import opennlp.fieldspring.tr.text.Token; +import opennlp.fieldspring.tr.text.Toponym; +import opennlp.fieldspring.tr.topo.Location; +import opennlp.fieldspring.tr.topo.gaz.Gazetteer; +import opennlp.fieldspring.tr.util.Span; + +/** + * Wraps a document source, removes any toponym spans, and identifies toponyms + * using a named entity recognizer and a gazetteer. + * + * @author Travis Brown + * @version 0.1.0 + */ +public class CandidateAnnotator extends DocumentSourceWrapper { + private final Gazetteer gazetteer; + + public CandidateAnnotator(DocumentSource source, + Gazetteer gazetteer) { + super(source); + this.gazetteer = gazetteer; + } + + public Document next() { + final Document document = this.getSource().next(); + final Iterator> sentences = document.iterator(); + + return new Document(document.getId()) { + private static final long serialVersionUID = 42L; + public Iterator> iterator() { + return new SentenceIterator() { + public boolean hasNext() { + return sentences.hasNext(); + } + + public Sentence next() { + Sentence sentence = sentences.next(); + List tokens = sentence.getTokens(); + + List> toponymSpans = new ArrayList>(); + + Iterator> spans = sentence.toponymSpans(); + while (spans.hasNext()) { + Span span = /*(Span)*/ spans.next(); + Toponym toponym = (Toponym) span.getItem(); + String form = toponym.getOrigForm(); + + List candidates = CandidateAnnotator.this.gazetteer.lookup(form.toLowerCase()); + Toponym newToponym = toponym; + if (candidates != null) { + newToponym = new SimpleToponym(form, candidates); + } + toponymSpans.add(new Span(span.getStart(), span.getEnd(), newToponym)); + } + + return new SimpleSentence(sentence.getId(), tokens, toponymSpans); + } + }; + } + }; + } +} + diff --git a/src/main/java/opennlp/fieldspring/tr/text/prep/CandidateRepopulator.java b/src/main/java/opennlp/fieldspring/tr/text/prep/CandidateRepopulator.java new file mode 100644 index 0000000..006e8c1 --- /dev/null +++ b/src/main/java/opennlp/fieldspring/tr/text/prep/CandidateRepopulator.java @@ -0,0 +1,61 @@ +package opennlp.fieldspring.tr.text.prep; + +import java.util.ArrayList; +import java.util.Iterator; +import java.util.List; + +import opennlp.fieldspring.tr.text.Corpus; +import opennlp.fieldspring.tr.text.Document; +import opennlp.fieldspring.tr.text.DocumentSource; +import opennlp.fieldspring.tr.text.DocumentSourceWrapper; +import opennlp.fieldspring.tr.text.Sentence; +import opennlp.fieldspring.tr.text.SimpleSentence; +import opennlp.fieldspring.tr.text.SimpleToponym; +import opennlp.fieldspring.tr.text.Token; +import opennlp.fieldspring.tr.text.Toponym; +import opennlp.fieldspring.tr.topo.gaz.Gazetteer; +import opennlp.fieldspring.tr.topo.Location; +import opennlp.fieldspring.tr.util.Span; + + +public class CandidateRepopulator extends DocumentSourceWrapper { + + private final Gazetteer gazetteer; + + public 
CandidateRepopulator(DocumentSource source, Gazetteer gazetteer) { + super(source); + this.gazetteer = gazetteer; + } + + public Document next() { + final Document document = this.getSource().next(); + final Iterator> sentences = document.iterator(); + + return new Document(document.getId()) { + private static final long serialVersionUID = 42L; + public Iterator> iterator() { + return new SentenceIterator() { + public boolean hasNext() { + return sentences.hasNext(); + } + + public Sentence next() { + Sentence sentence = sentences.next(); + for(Token token : sentence) { + if(token.isToponym()) { + Toponym toponym = (Toponym) token; + List candidates = gazetteer.lookup(toponym.getForm()); + if(candidates == null) candidates = new ArrayList(); + toponym.setCandidates(candidates); + toponym.setGoldIdx(-1); + } + } + return sentence; + //return new SimpleSentence(sentence.getId(), sentence.getTokens()); + } + }; + } + }; + } +} + diff --git a/src/main/java/opennlp/fieldspring/tr/text/prep/HighRecallToponymRecognizer.java b/src/main/java/opennlp/fieldspring/tr/text/prep/HighRecallToponymRecognizer.java new file mode 100644 index 0000000..ce02e5a --- /dev/null +++ b/src/main/java/opennlp/fieldspring/tr/text/prep/HighRecallToponymRecognizer.java @@ -0,0 +1,251 @@ +/** + * + */ +package opennlp.fieldspring.tr.text.prep; + +import java.io.File; +import java.io.FileInputStream; +import java.io.FileNotFoundException; +import java.io.IOException; +import java.io.ObjectInputStream; +import java.util.ArrayList; +import java.util.HashMap; +import java.util.HashSet; +import java.util.Iterator; +import java.util.List; +import java.util.Set; +import java.util.regex.Pattern; +import java.util.zip.GZIPInputStream; + +import opennlp.fieldspring.tr.topo.gaz.GeoNamesGazetteer; +import opennlp.fieldspring.tr.util.Constants; +import opennlp.fieldspring.tr.util.Span; +import opennlp.tools.namefind.NameFinderME; +import opennlp.tools.namefind.TokenNameFinder; +import opennlp.tools.namefind.TokenNameFinderModel; +import opennlp.tools.util.HashList; +import opennlp.tools.util.InvalidFormatException; +import scala.actors.threadpool.Arrays; + +/** + * @author abhimanu kumar + * + */ +public class HighRecallToponymRecognizer extends OpenNLPRecognizer { + private HashMap nameHashMap; + private HashMap stringInternHashMap; + private TokenNameFinder personFinder; + private TokenNameFinder orgFinder; + private Pattern splitPattern; + private Pattern historicalPattern; + private static final String spaceRegEx = " "; + private static final String historyRegEx = "(historical)"; + private boolean flagStart = false; + private int stringCount=Integer.MIN_VALUE; + private int lineCount=0; + + public HighRecallToponymRecognizer(GeoNamesGazetteer gnGz) throws IOException, InvalidFormatException { + super(); + getCleanedNameSet(gnGz.getUniqueLocationNameSet()); + getNLPModels(); + + } + + + + private void getNLPModels() throws InvalidFormatException, IOException { + personFinder=new NameFinderME(new TokenNameFinderModel(new FileInputStream( + Constants.getOpenNLPModelsDir() + File.separator + "en-ner-person.bin"))); + orgFinder=new NameFinderME(new TokenNameFinderModel(new FileInputStream( + Constants.getOpenNLPModelsDir() + File.separator + "en-ner-organization.bin"))); + } + + + + private void getCleanedNameSet(Set keySet) { +// System.out.println("Formatting Locations..."); + this.nameHashMap = new HashMap(500000); + this.stringInternHashMap = new HashMap(100000); + compilePatterns(); + for (Iterator iterator = keySet.iterator(); 
iterator.hasNext();) { + String toponym = (String) iterator.next(); + toponym = historicalPattern.matcher(toponym).replaceAll(""); + + + if(!nameHashMap.containsKey(toponym)){ + String[] toponymTokens = splitPattern.split(toponym); + int xorNum = 0; + ArrayList tokenIntList = new ArrayList(); + for (int i = 0; i < toponymTokens.length; i++) { + String token = toponymTokens[i];//.trim(); + if(stringInternHashMap.containsKey(token)){ + tokenIntList.add(stringInternHashMap.get(token)); + }else{ + stringInternHashMap.put(token, stringCount); + tokenIntList.add(stringCount); + stringCount++; + } + } + nameHashMap.put(toponym,tokenIntList.toArray(new Integer[0])); + } + } +// System.out.println(nameHashMap.size()); +// System.out.println(stringInternHashMap.size()); + } + + + + private void compilePatterns() { + splitPattern = Pattern.compile(spaceRegEx); + historicalPattern = Pattern.compile(historyRegEx); + } + + + + public HighRecallToponymRecognizer(String gazPath) throws Exception{ + super(); + GZIPInputStream gis; + ObjectInputStream ois; + GeoNamesGazetteer gnGaz = null; + gis = new GZIPInputStream(new FileInputStream(gazPath)); + ois = new ObjectInputStream(gis); + gnGaz = (GeoNamesGazetteer) ois.readObject(); + getCleanedNameSet(gnGaz.getUniqueLocationNameSet()); + getNLPModels(); + } + + public HighRecallToponymRecognizer(Set uniqueLocationNameSet) throws IOException, InvalidFormatException { + super(); + getCleanedNameSet(uniqueLocationNameSet); + getNLPModels(); + } + + + + public List> recognize(List tokens) { + if(!flagStart){ + System.out.print("\nRaw Corpus: Searching for Toponyms "); + flagStart=true; + } + if(lineCount==1000){ + System.out.print("."); + lineCount=0; + } + lineCount++; + List> spans = new ArrayList>(); + String[] tokensToBeLookedArray = (String[]) Arrays.copyOf(tokens.toArray(),tokens.toArray().length,String[].class); + for (opennlp.tools.util.Span span : this.finder.find(tokens.toArray(new String[0]))) { + spans.add(new Span(span.getStart(), span.getEnd(), this.type)); + + for (int i = span.getStart(); i < span.getEnd(); i++) { + tokensToBeLookedArray[i]=" "; + } + } + + for (opennlp.tools.util.Span span : this.personFinder.find(tokens.toArray(new String[0]))) { + for (int i = span.getStart(); i < span.getEnd(); i++) { +// System.out.println("PERSON "+tokensToBeLookedArray[i]); + tokensToBeLookedArray[i]=" "; + } + } + + for (opennlp.tools.util.Span span : this.orgFinder.find(tokens.toArray(new String[0]))) { + for (int i = span.getStart(); i < span.getEnd(); i++) { +// System.out.println("ORG "+tokensToBeLookedArray[i]); + tokensToBeLookedArray[i]=" "; + } + } + + for (int i = 0; i < tokensToBeLookedArray.length; i++) { + String token = tokensToBeLookedArray[i]; + if(token.length()==1) + continue;; + if(startsWithCaps(token)){ + boolean matched=true; + int k=0; + for (Iterator iterator = nameHashMap.keySet().iterator(); iterator.hasNext();) { + String toponymTokens = ((String) iterator.next()); + + int toponymLength=nameHashMap.get(toponymTokens).length; + if(toponymLength>tokensToBeLookedArray.length-i) + continue; + + matched=true; + + ArrayList suspectedTokenSet = new ArrayList(); + for (int j = 0; j < toponymLength; j++) { + if(!startsWithCaps(tokensToBeLookedArray[j+i])){ + matched=false; + break; + } + suspectedTokenSet.add(stringInternHashMap.get(tokensToBeLookedArray[j+i].toLowerCase())); + + } + if(!stringMatch(suspectedTokenSet.toArray(new Integer[0]),nameHashMap.get(toponymTokens))){ + matched=false; + continue; + } + if(matched){ + 
k=i+toponymLength-1; + break; + } + } + if(matched){ +// String startToken = tokensToBeLookedArray[i]; +// String endToken = tokensToBeLookedArray[k]; + spans.add(new Span(i, k+1, this.type)); + i=k; +// System.out.println(startToken+endToken); + } + } + } + return spans; + } + + + private boolean stringMatch(Integer[] suspectedSet,Integer[] toponymSet) { + for (int i = 0; i < suspectedSet.length; i++) { + if(suspectedSet[i]!=toponymSet[i]) + return false; + } + return true; + } + + private boolean stringMatch(StringBuilder tokenCombined, + StringBuilder toponymCombined) { + if(tokenCombined.length()!=toponymCombined.length()) + return false; + for (int i = 0; i < tokenCombined.length(); i++) { + if(tokenCombined.charAt(i)!=toponymCombined.charAt(i)) + return false; + } + return true; + } + + + + private boolean startsWithCaps(String tobeLooked) { + return new Integer('A')<=new Integer(tobeLooked.charAt(0)) && new Integer(tobeLooked.charAt(0))<=new Integer('Z'); + } + + +} + +class HashString { + + private String string; + private int hash; + + HashString(String string ,int hash){ + this.string=string; + this.hash=hash; + } + + public String toString(){ + return string; + } + + public int hashCode() { + return hash; + } +} diff --git a/src/main/java/opennlp/fieldspring/tr/text/prep/JythonNER.java b/src/main/java/opennlp/fieldspring/tr/text/prep/JythonNER.java new file mode 100644 index 0000000..a05eadf --- /dev/null +++ b/src/main/java/opennlp/fieldspring/tr/text/prep/JythonNER.java @@ -0,0 +1,55 @@ +/////////////////////////////////////////////////////////////////////////////// +// Copyright (C) 2010 Travis Brown, The University of Texas at Austin +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+/////////////////////////////////////////////////////////////////////////////// +package opennlp.fieldspring.tr.text.prep; + +import javax.script.ScriptEngine; +import javax.script.ScriptException; +import java.util.ArrayList; +import java.util.List; + +import opennlp.fieldspring.tr.util.Span; + +public class JythonNER extends ScriptNER { + public JythonNER(String name, NamedEntityType type) { + super("python", name, type); + } + + public JythonNER(String name) { + this(name, NamedEntityType.LOCATION); + } + + public List> recognize(List tokens) { + ScriptEngine engine = this.getEngine(); + engine.put("tokens", tokens); + + try { + engine.eval("spans = recognize(tokens)"); + } catch (ScriptException e) { + return null; + } + + List> tuples = (List>) engine.get("spans"); + List> spans = + new ArrayList>(tuples.size()); + + for (List tuple : tuples) { + spans.add(new Span(tuple.get(0), tuple.get(1), this.getType())); + } + + return spans; + } +} + diff --git a/src/main/java/opennlp/fieldspring/tr/text/prep/NamedEntityRecognizer.java b/src/main/java/opennlp/fieldspring/tr/text/prep/NamedEntityRecognizer.java new file mode 100644 index 0000000..98b836a --- /dev/null +++ b/src/main/java/opennlp/fieldspring/tr/text/prep/NamedEntityRecognizer.java @@ -0,0 +1,25 @@ +/////////////////////////////////////////////////////////////////////////////// +// Copyright (C) 2010 Travis Brown, The University of Texas at Austin +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +/////////////////////////////////////////////////////////////////////////////// +package opennlp.fieldspring.tr.text.prep; + +import java.util.List; + +import opennlp.fieldspring.tr.util.Span; + +public interface NamedEntityRecognizer { + public List> recognize(List tokens); +} + diff --git a/src/main/java/opennlp/fieldspring/tr/text/prep/NamedEntityType.java b/src/main/java/opennlp/fieldspring/tr/text/prep/NamedEntityType.java new file mode 100644 index 0000000..354e0ea --- /dev/null +++ b/src/main/java/opennlp/fieldspring/tr/text/prep/NamedEntityType.java @@ -0,0 +1,27 @@ +/////////////////////////////////////////////////////////////////////////////// +// Copyright (C) 2010 Travis Brown, The University of Texas at Austin +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+/////////////////////////////////////////////////////////////////////////////// +package opennlp.fieldspring.tr.text.prep; + +public enum NamedEntityType { + DATE, + LOCATION, + MONEY, + ORGANIZATION, + PERCENTAGE, + PERSON, + TIME; +} + diff --git a/src/main/java/opennlp/fieldspring/tr/text/prep/OpenNLPRecognizer.java b/src/main/java/opennlp/fieldspring/tr/text/prep/OpenNLPRecognizer.java new file mode 100644 index 0000000..5575c3a --- /dev/null +++ b/src/main/java/opennlp/fieldspring/tr/text/prep/OpenNLPRecognizer.java @@ -0,0 +1,58 @@ +/////////////////////////////////////////////////////////////////////////////// +// Copyright (C) 2010 Travis Brown, The University of Texas at Austin +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +/////////////////////////////////////////////////////////////////////////////// +package opennlp.fieldspring.tr.text.prep; + +import java.io.File; +import java.io.FileInputStream; +import java.io.IOException; +import java.io.InputStream; +import java.util.ArrayList; +import java.util.Iterator; +import java.util.List; + +import opennlp.tools.namefind.NameFinderME; +import opennlp.tools.namefind.TokenNameFinder; +import opennlp.tools.namefind.TokenNameFinderModel; +import opennlp.tools.util.InvalidFormatException; + +import opennlp.fieldspring.tr.util.Constants; +import opennlp.fieldspring.tr.util.Span; + +public class OpenNLPRecognizer implements NamedEntityRecognizer { + protected final TokenNameFinder finder; + protected final NamedEntityType type; + + public OpenNLPRecognizer() throws IOException, InvalidFormatException { + this(new FileInputStream( + Constants.getOpenNLPModelsDir() + File.separator + "en-ner-location.bin"), + NamedEntityType.LOCATION); + } + + public OpenNLPRecognizer(InputStream in, NamedEntityType type) + throws IOException, InvalidFormatException { + this.finder = new NameFinderME(new TokenNameFinderModel(in)); + this.type = type; + } + + public List> recognize(List tokens) { + List> spans = new ArrayList>(); + for (opennlp.tools.util.Span span : this.finder.find(tokens.toArray(new String[0]))) { + spans.add(new Span(span.getStart(), span.getEnd(), this.type)); + } + return spans; + } +} + diff --git a/src/main/java/opennlp/fieldspring/tr/text/prep/OpenNLPSentenceDivider.java b/src/main/java/opennlp/fieldspring/tr/text/prep/OpenNLPSentenceDivider.java new file mode 100644 index 0000000..04f6e4f --- /dev/null +++ b/src/main/java/opennlp/fieldspring/tr/text/prep/OpenNLPSentenceDivider.java @@ -0,0 +1,47 @@ +/////////////////////////////////////////////////////////////////////////////// +// Copyright (C) 2010 Travis Brown, The University of Texas at Austin +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +/////////////////////////////////////////////////////////////////////////////// +package opennlp.fieldspring.tr.text.prep; + +import java.io.File; +import java.io.FileInputStream; +import java.io.IOException; +import java.io.InputStream; +import java.util.Arrays; +import java.util.List; + +import opennlp.tools.sentdetect.SentenceDetector; +import opennlp.tools.sentdetect.SentenceDetectorME; +import opennlp.tools.sentdetect.SentenceModel; +import opennlp.tools.util.InvalidFormatException; + +import opennlp.fieldspring.tr.util.Constants; + +public class OpenNLPSentenceDivider implements SentenceDivider { + private final SentenceDetector detector; + + public OpenNLPSentenceDivider() throws IOException, InvalidFormatException { + this(new FileInputStream(Constants.getOpenNLPModelsDir() + File.separator + "en-sent.bin")); + } + + public OpenNLPSentenceDivider(InputStream in) throws IOException, InvalidFormatException { + this.detector = new SentenceDetectorME(new SentenceModel(in)); + } + + public List divide(String text) { + return Arrays.asList(this.detector.sentDetect(text)); + } +} + diff --git a/src/main/java/opennlp/fieldspring/tr/text/prep/OpenNLPTokenizer.java b/src/main/java/opennlp/fieldspring/tr/text/prep/OpenNLPTokenizer.java new file mode 100644 index 0000000..b347beb --- /dev/null +++ b/src/main/java/opennlp/fieldspring/tr/text/prep/OpenNLPTokenizer.java @@ -0,0 +1,46 @@ +/////////////////////////////////////////////////////////////////////////////// +// Copyright (C) 2010 Travis Brown, The University of Texas at Austin +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+/////////////////////////////////////////////////////////////////////////////// +package opennlp.fieldspring.tr.text.prep; + +import java.io.File; +import java.io.FileInputStream; +import java.io.IOException; +import java.io.InputStream; +import java.util.Arrays; +import java.util.List; + +import opennlp.tools.tokenize.TokenizerME; +import opennlp.tools.tokenize.TokenizerModel; +import opennlp.tools.util.InvalidFormatException; + +import opennlp.fieldspring.tr.util.Constants; + +public class OpenNLPTokenizer implements Tokenizer { + private final opennlp.tools.tokenize.Tokenizer tokenizer; + + public OpenNLPTokenizer() throws IOException, InvalidFormatException { + this(new FileInputStream(Constants.getOpenNLPModelsDir() + File.separator + "en-token.bin")); + } + + public OpenNLPTokenizer(InputStream in) throws IOException, InvalidFormatException { + this.tokenizer = new TokenizerME(new TokenizerModel(in)); + } + + public List tokenize(String text) { + return Arrays.asList(this.tokenizer.tokenize(text)); + } +} + diff --git a/src/main/java/opennlp/fieldspring/tr/text/prep/ScriptNER.java b/src/main/java/opennlp/fieldspring/tr/text/prep/ScriptNER.java new file mode 100644 index 0000000..ec61410 --- /dev/null +++ b/src/main/java/opennlp/fieldspring/tr/text/prep/ScriptNER.java @@ -0,0 +1,73 @@ +/////////////////////////////////////////////////////////////////////////////// +// Copyright (C) 2010 Travis Brown, The University of Texas at Austin +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +/////////////////////////////////////////////////////////////////////////////// +package opennlp.fieldspring.tr.text.prep; + +import java.io.InputStream; +import java.io.InputStreamReader; +import java.io.IOException; +import javax.script.ScriptEngine; +import javax.script.ScriptEngineManager; +import javax.script.ScriptException; + +public abstract class ScriptNER implements NamedEntityRecognizer { + private final String language; + private final String name; + private final NamedEntityType type; + private final ScriptEngine engine; + + /** + * Constructor for classes that use the JSR-223 scripting engine to perform + * named entity recognition. 
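+   * The named script resource is loaded from the classpath and evaluated once
+   * here; a subclass such as {@code JythonNER} then calls a
+   * {@code recognize(tokens)} function defined by that script, which is
+   * expected to return a list of {@code (start, end)} token-index pairs
+   * (see {@code JythonNER#recognize}). A minimal usage sketch; the resource
+   * path and the {@code tokens} list below are assumed for illustration:
+   * <pre>{@code
+   * NamedEntityRecognizer ner = new JythonNER("/python/ner.py");
+   * List<Span<NamedEntityType>> spans = ner.recognize(tokens);
+   * }</pre>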
+ * + * @param language The JSR-223 name of the scripting language + * @param name The path to the resource containing the script + * @param type The kind of named entity that is recognized + */ + public ScriptNER(String language, String name, NamedEntityType type) { + this.language = language; + this.name = name; + this.type = type; + + ScriptEngineManager manager = new ScriptEngineManager(); + this.engine = manager.getEngineByName(this.language); + + try { + InputStream stream = ScriptNER.class.getResourceAsStream(this.name); + InputStreamReader reader = new InputStreamReader(stream); + this.engine.eval(reader); + stream.close(); + } catch (ScriptException e) { + System.err.println(e); + System.exit(1); + } catch (IOException e) { + System.err.println(e); + System.exit(1); + } + } + + public ScriptNER(String language, String name) { + this(language, name, NamedEntityType.LOCATION); + } + + protected ScriptEngine getEngine() { + return this.engine; + } + + protected NamedEntityType getType() { + return this.type; + } +} + diff --git a/src/main/java/opennlp/fieldspring/tr/text/prep/SentenceDivider.java b/src/main/java/opennlp/fieldspring/tr/text/prep/SentenceDivider.java new file mode 100644 index 0000000..9497d3c --- /dev/null +++ b/src/main/java/opennlp/fieldspring/tr/text/prep/SentenceDivider.java @@ -0,0 +1,23 @@ +/////////////////////////////////////////////////////////////////////////////// +// Copyright (C) 2010 Travis Brown, The University of Texas at Austin +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +/////////////////////////////////////////////////////////////////////////////// +package opennlp.fieldspring.tr.text.prep; + +import java.util.List; + +public interface SentenceDivider { + public List divide(String text); +} + diff --git a/src/main/java/opennlp/fieldspring/tr/text/prep/Tokenizer.java b/src/main/java/opennlp/fieldspring/tr/text/prep/Tokenizer.java new file mode 100644 index 0000000..bba2364 --- /dev/null +++ b/src/main/java/opennlp/fieldspring/tr/text/prep/Tokenizer.java @@ -0,0 +1,23 @@ +/////////////////////////////////////////////////////////////////////////////// +// Copyright (C) 2010 Travis Brown, The University of Texas at Austin +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+/////////////////////////////////////////////////////////////////////////////// +package opennlp.fieldspring.tr.text.prep; + +import java.util.List; + +public interface Tokenizer { + public List tokenize(String text); +} + diff --git a/src/main/java/opennlp/fieldspring/tr/text/prep/ToponymAnnotator.java b/src/main/java/opennlp/fieldspring/tr/text/prep/ToponymAnnotator.java new file mode 100644 index 0000000..5ac9a83 --- /dev/null +++ b/src/main/java/opennlp/fieldspring/tr/text/prep/ToponymAnnotator.java @@ -0,0 +1,121 @@ +/////////////////////////////////////////////////////////////////////////////// +// Copyright (C) 2010 Travis Brown, The University of Texas at Austin +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +/////////////////////////////////////////////////////////////////////////////// +package opennlp.fieldspring.tr.text.prep; + +import java.util.ArrayList; +import java.util.Iterator; +import java.util.List; + +import opennlp.fieldspring.tr.text.*; +import opennlp.fieldspring.tr.topo.*; +import opennlp.fieldspring.tr.topo.gaz.*; +import opennlp.fieldspring.tr.util.*; + +/** + * Wraps a document source, removes any toponym spans, and identifies toponyms + * using a named entity recognizer and a gazetteer. + * + * @author Travis Brown + * @version 0.1.0 + */ +public class ToponymAnnotator extends DocumentSourceWrapper { + private final NamedEntityRecognizer recognizer; + private final Gazetteer gazetteer; + private final Region boundingBox; + + public ToponymAnnotator(DocumentSource source, + NamedEntityRecognizer recognizer, + Gazetteer gazetteer) { + this(source, recognizer, gazetteer, null); + } + + public ToponymAnnotator(DocumentSource source, + NamedEntityRecognizer recognizer, + Gazetteer gazetteer, + Region boundingBox) { + super(source); + this.recognizer = recognizer; + this.gazetteer = gazetteer; + this.boundingBox = boundingBox; + } + + public Document next() { + final Document document = this.getSource().next(); + final Iterator> sentences = document.iterator(); + + return new Document(document.getId(), document.getTimestamp(), document.getGoldCoord(), document.getSystemCoord(), document.getSection(), document.title) { + private static final long serialVersionUID = 42L; + public Iterator> iterator() { + return new SentenceIterator() { + public boolean hasNext() { + return sentences.hasNext(); + } + + public Sentence next() { + Sentence sentence = sentences.next(); + List forms = new ArrayList(); + List tokens = sentence.getTokens(); + + for (Token token : tokens) { + forms.add(token.getOrigForm()); + } + + List> spans = ToponymAnnotator.this.recognizer.recognize(forms); + List> toponymSpans = new ArrayList>(); + + for (Span span : spans) { + if (span.getItem() == NamedEntityType.LOCATION) { + StringBuilder builder = new StringBuilder(); + for (int i = span.getStart(); i < span.getEnd(); i++) { + builder.append(forms.get(i)); + if (i < span.getEnd() - 1) { + builder.append(" "); + } + } + + String form = builder.toString(); + //List candidates = 
ToponymAnnotator.this.gazetteer.lookup(form.toLowerCase()); + List candidates = TopoUtil.filter( + ToponymAnnotator.this.gazetteer.lookup(form.toLowerCase()), boundingBox); + if(candidates != null) { + for(Location loc : candidates) { + List reps = loc.getRegion().getRepresentatives(); + int prevSize = reps.size(); + Coordinate.removeNaNs(reps); + if(reps.size() < prevSize) + loc.getRegion().setCenter(Coordinate.centroid(reps)); + } + } + //if(form.equalsIgnoreCase("united states")) + // for(Location loc : ToponymAnnotator.this.gazetteer.lookup("united states")) + // System.out.println(loc.getRegion().getCenter()); + if (candidates != null) { + Toponym toponym = new SimpleToponym(form, candidates); + //if(form.equalsIgnoreCase("united states")) + // System.out.println(toponym.getCandidates().get(0).getRegion().getCenter()); + toponymSpans.add(new Span(span.getStart(), span.getEnd(), toponym)); + } + } + } + + return new SimpleSentence(sentence.getId(), tokens, toponymSpans); + } + }; + } + }; + } +} + diff --git a/src/main/java/opennlp/fieldspring/tr/text/prep/ToponymRemover.java b/src/main/java/opennlp/fieldspring/tr/text/prep/ToponymRemover.java new file mode 100644 index 0000000..32ed228 --- /dev/null +++ b/src/main/java/opennlp/fieldspring/tr/text/prep/ToponymRemover.java @@ -0,0 +1,67 @@ +/////////////////////////////////////////////////////////////////////////////// +// Copyright (C) 2010 Travis Brown, The University of Texas at Austin +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +/////////////////////////////////////////////////////////////////////////////// +package opennlp.fieldspring.tr.text.prep; + +import java.util.ArrayList; +import java.util.Iterator; +import java.util.List; + +import opennlp.fieldspring.tr.text.Corpus; +import opennlp.fieldspring.tr.text.Document; +import opennlp.fieldspring.tr.text.DocumentSource; +import opennlp.fieldspring.tr.text.DocumentSourceWrapper; +import opennlp.fieldspring.tr.text.Sentence; +import opennlp.fieldspring.tr.text.SimpleSentence; +import opennlp.fieldspring.tr.text.SimpleToponym; +import opennlp.fieldspring.tr.text.Token; +import opennlp.fieldspring.tr.text.Toponym; +import opennlp.fieldspring.tr.topo.gaz.Gazetteer; +import opennlp.fieldspring.tr.util.Span; + +/** + * Wraps a document source and removes any toponyms spans that it contains, + * returning only the tokens. 
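+ * Concretely, each sentence is passed through as a {@code SimpleSentence} built
+ * from the original id and tokens only, so any existing toponym spans are
+ * dropped. A minimal usage sketch, assuming (as with the other wrappers in this
+ * package) that the wrapper can itself be used as a {@code DocumentSource}, and
+ * that {@code source}, {@code recognizer} and {@code gazetteer} already exist:
+ * <pre>{@code
+ * DocumentSource cleaned = new ToponymRemover(source);
+ * DocumentSource reannotated = new ToponymAnnotator(cleaned, recognizer, gazetteer);
+ * }</pre>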
+ * + * @author Travis Brown + * @version 0.1.0 + */ +public class ToponymRemover extends DocumentSourceWrapper { + public ToponymRemover(DocumentSource source) { + super(source); + } + + public Document next() { + final Document document = this.getSource().next(); + final Iterator> sentences = document.iterator(); + + return new Document(document.getId()) { + private static final long serialVersionUID = 42L; + public Iterator> iterator() { + return new SentenceIterator() { + public boolean hasNext() { + return sentences.hasNext(); + } + + public Sentence next() { + Sentence sentence = sentences.next(); + return new SimpleSentence(sentence.getId(), sentence.getTokens()); + } + }; + } + }; + } +} + diff --git a/src/main/java/opennlp/fieldspring/tr/topo/Coordinate.java b/src/main/java/opennlp/fieldspring/tr/topo/Coordinate.java new file mode 100644 index 0000000..8088f67 --- /dev/null +++ b/src/main/java/opennlp/fieldspring/tr/topo/Coordinate.java @@ -0,0 +1,174 @@ +/////////////////////////////////////////////////////////////////////////////// +// Copyright (C) 2010 Travis Brown, The University of Texas at Austin +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +/////////////////////////////////////////////////////////////////////////////// +package opennlp.fieldspring.tr.topo; + +import java.util.List; +import java.util.ArrayList; +import java.io.Serializable; + +import opennlp.fieldspring.tr.util.FastTrig; + +public class Coordinate implements Serializable { + + private static final long serialVersionUID = 42L; + + private final double lng; + private final double lat; + + public Coordinate(double lat, double lng) { + this.lng = lng; + this.lat = lat; + } + + public static Coordinate fromRadians(double lat, double lng) { + return new Coordinate(lat, lng); + } + + public static Coordinate fromDegrees(double lat, double lng) { + return new Coordinate(lat * Math.PI / 180.0, lng * Math.PI / 180.0); + } + + public double getLat() { + return this.lat; + } + + public double getLng() { + return this.lng; + } + + public double getLatDegrees() { + return this.lat * 180.0 / Math.PI; + } + + public double getLngDegrees() { + return this.lng * 180.0 / Math.PI; + } + + /** + * Compare two coordinates to see if they're sufficiently close together. + * @param other The other coordinate being compared + * @param maxDiff Both lat and lng must be within this value + * @return Whether the two coordinates are sufficiently close + */ + public boolean looselyMatches(Coordinate other, double maxDiff) { + return Math.abs(this.lat - other.lat) <= maxDiff && + Math.abs(this.lng - other.lng) <= maxDiff; + } + + /** + * Generate a new Coordinate that is the `n'th point along a spiral + * radiating outward from the given coordinate. `initRadius' controls where + * on the spiral the zeroth point is located. The constant local variable + * `radianUnit' controls the spacing of the points (FIXME, should be an + * optional parameter). 
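+   * Concretely, point {@code n} is offset from this coordinate by
+   * {@code radius = initRadius + (initRadius * 0.1) * n} degrees at angle
+   * {@code radianUnit / 2 + 1.1 * radianUnit * n} radians; the offset is added
+   * to the latitude and longitude in degrees and then converted back to radians.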
The radius of the spiral increases by 1/10 (FIXME, + * should be controllable) of `initRadius' every point. + * + * @param n + * How far along the spiral to return a coordinate for + * @param initRadius + * Where along the spiral the 0th point is located; this also + * controls how quickly the spiral grows outward + * @return A new coordinate along the spiral + */ + public Coordinate getNthSpiralPoint(int n, double initRadius) { + if (n == 0) { + return this; + } + + final double radianUnit = Math.PI / 10.0; + double radius = initRadius + (initRadius * 0.1) * n; + double angle = radianUnit / 2.0 + 1.1 * radianUnit * n; + + double newLatDegrees = this.getLatDegrees() + radius * Math.cos(angle); + double newLngDegrees = this.getLngDegrees() + radius * Math.sin(angle); + + return new Coordinate(newLatDegrees * Math.PI / 180.0, newLngDegrees * Math.PI / 180.0); + } + + public String toString() { + return String.format("%.02f,%.02f", this.getLatDegrees(), this.getLngDegrees()); + } + + public double distance(Coordinate other) { + if(this.lat == other.lat && this.lng == other.lng) + return 0; + return Math.acos(Math.sin(this.lat) * Math.sin(other.lat) + + Math.cos(this.lat) * Math.cos(other.lat) * Math.cos(other.lng - this.lng)); + } + + public double distanceInKm(Coordinate other) { + return 6372.8 * this.distance(other); + } + + public double distanceInMi(Coordinate other) { + return .621371 * this.distanceInKm(other); + } + + /** + * Compute the approximate centroid by taking the average of the latitudes + * and longitudes. + */ + public static Coordinate centroid(List coordinates) { + double latSins = 0.0; + double latCoss = 0.0; + double lngSins = 0.0; + double lngCoss = 0.0; + + for (int i = 0; i < coordinates.size(); i++) { + latSins += Math.sin(coordinates.get(i).getLat()); + latCoss += Math.cos(coordinates.get(i).getLat()); + lngSins += Math.sin(coordinates.get(i).getLng()); + lngCoss += Math.cos(coordinates.get(i).getLng()); + } + + latSins /= coordinates.size(); + latCoss /= coordinates.size(); + lngSins /= coordinates.size(); + lngCoss /= coordinates.size(); + + double lat = Math.atan2(latSins, latCoss); + double lng = Math.atan2(lngSins, lngCoss); + + return Coordinate.fromRadians(lat, lng); + } + + public static List removeNaNs(List coordinates) { + List toReturn = new ArrayList(); + for(Coordinate coord : coordinates) { + if(!(Double.isNaN(coord.getLatDegrees()) || Double.isNaN(coord.getLngDegrees()))) { + toReturn.add(coord); + } + } + return toReturn; + } + + @Override + public boolean equals(Object other) { + return other != null && + other.getClass() == this.getClass() && + ((Coordinate) other).lat == this.lat && + ((Coordinate) other).lng == this.lng; + } + + @Override + public int hashCode() { + int hash = 3; + hash = 29 * hash + (int) (Double.doubleToLongBits(this.lng) ^ (Double.doubleToLongBits(this.lng) >>> 32)); + hash = 29 * hash + (int) (Double.doubleToLongBits(this.lat) ^ (Double.doubleToLongBits(this.lat) >>> 32)); + return hash; + } +} + diff --git a/src/main/java/opennlp/fieldspring/tr/topo/Location.java b/src/main/java/opennlp/fieldspring/tr/topo/Location.java new file mode 100644 index 0000000..e693466 --- /dev/null +++ b/src/main/java/opennlp/fieldspring/tr/topo/Location.java @@ -0,0 +1,205 @@ +/////////////////////////////////////////////////////////////////////////////// +// Copyright (C) 2010 Travis Brown, The University of Texas at Austin +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in 
compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +/////////////////////////////////////////////////////////////////////////////// +package opennlp.fieldspring.tr.topo; + +import java.io.Serializable; +import java.util.*; + +public class Location implements Serializable { + + private static final long serialVersionUID = 42L; + + public enum Type { + STATE, WATER, CITY, SITE, PARK, TRANSPORT, MOUNTAIN, UNDERSEA, FOREST, UNKNOWN + } + + private final int id; + private final String name; + private Region region; + private Location.Type type; + private int population; + private String admin1code; + private double threshold; + + public Location(int id, String name, Region region, String typeString, Integer population, String admin1code) { + this(id, name, region, Location.convertTypeString(typeString), + population==null?0:population, admin1code==null?"00":admin1code, 10.0); + } + + public Location(String idWithC, String name, Region region, String typeString, Integer population, String admin1code) { + this(Integer.parseInt(idWithC.substring(1)), name, region, typeString, population, admin1code); + } + + public Location(int id, String name, Region region, Location.Type type, int population, String admin1code, double threshold) { + this.id = id; + this.name = name; + this.region = region; + this.type = type; + this.population = population; + this.admin1code = admin1code; + this.threshold = threshold; + } + + public Location(int id, String name, Region region, Location.Type type, int population) { + this(id, name, region, type, population, "00", 10.0); + } + + public Location(int id, String name, Region region, Location.Type type) { + this(id, name, region, type, 0); + } + + public Location(int id, String name, Region region) { + this(id, name, region, Location.Type.UNKNOWN); + } + + public Location(String name, Region region) { + this(-1, name, region); + } + + public Location(String name, Region region, Location.Type type) { + this(-1, name, region, type); + } + + public Location(String name, Region region, Location.Type type, int population) { + this(-1, name, region, type, population); + } + + /*public void removeNaNs() { + List reps = this.getRegion().getRepresentatives(); + int prevSize = reps.size(); + reps = Coordinate.removeNaNs(reps); + if(reps.size() < prevSize) { + this.getRegion().setRepresentatives(reps); + System.out.println("Recalculating centroid"); + this.getRegion().setCenter(Coordinate.centroid(reps)); + } + }*/ + + public int getId() { + return this.id; + } + + public String getName() { + return this.name; + } + + public Region getRegion() { + return this.region; + } + + public void setRegion(Region region) { + this.region = region; + } + + public void setType(Location.Type type) { + this.type = type; + } + + public Location.Type getType() { + return this.type; + } + + public void setThreshold(double threshold) { + this.threshold = threshold; + } + + public double getThreshold() { + return this.threshold; + } + + public void recomputeThreshold() { + // Commented out for now since experiments with this didn't perform well + /*if(this.getRegion().getRepresentatives().size() > 1) { 
+ //int count = 0; + double minDist = Double.POSITIVE_INFINITY; + for(int i = 0; i < this.getRegion().getRepresentatives().size(); i++) { + for(int j = 0; j < this.getRegion().getRepresentatives().size(); j++) { + if(i != j) { + double dist = this.getRegion().getRepresentatives().get(i).distanceInKm( + this.getRegion().getRepresentatives().get(j)); + if(dist < minDist) + minDist = dist; + //count++; + } + } + } + //dist /= count; + this.setThreshold(minDist / 2); + }*/ + } + + public static Location.Type convertTypeString(String typeString) { + typeString = typeString.toUpperCase(); + if(typeString.equals("STATE")) + return Location.Type.STATE; + if(typeString.equals("WATER")) + return Location.Type.WATER; + if(typeString.equals("CITY")) + return Location.Type.CITY; + if(typeString.equals("SITE")) + return Location.Type.SITE; + if(typeString.equals("PARK")) + return Location.Type.PARK; + if(typeString.equals("TRANSPORT")) + return Location.Type.TRANSPORT; + if(typeString.equals("MOUNTAIN")) + return Location.Type.MOUNTAIN; + if(typeString.equals("UNDERSEA")) + return Location.Type.UNDERSEA; + if(typeString.equals("FOREST")) + return Location.Type.FOREST; + else + return Location.Type.UNKNOWN; + + } + + public int getPopulation() { + return this.population; + } + + public String getAdmin1Code() { + return admin1code; + } + + public double distance(Location other) { + return this.getRegion().distance(other.getRegion()); + } + + public double distanceInKm(Location other) { + return this.getRegion().distanceInKm(other.getRegion()); + } + + @Override + public String toString() { + return String.format("%8d (%s), %s, (%s), %d", this.id, this.name, this.type, this.region.getCenter(), this.population); + } + + /** + * Two Locations are the same if they have the same class and same ID. + */ + @Override + public boolean equals(Object other) { + return other != null && + other.getClass() == this.getClass() && + ((Location) other).id == this.id; + } + + @Override + public int hashCode() { + return this.id; + } +} + diff --git a/src/main/java/opennlp/fieldspring/tr/topo/PointRegion.java b/src/main/java/opennlp/fieldspring/tr/topo/PointRegion.java new file mode 100644 index 0000000..267b5bb --- /dev/null +++ b/src/main/java/opennlp/fieldspring/tr/topo/PointRegion.java @@ -0,0 +1,69 @@ +/////////////////////////////////////////////////////////////////////////////// +// Copyright (C) 2010 Travis Brown, The University of Texas at Austin +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+/////////////////////////////////////////////////////////////////////////////// +package opennlp.fieldspring.tr.topo; + +import java.util.ArrayList; +import java.util.List; + +public class PointRegion extends Region { + + private static final long serialVersionUID = 42L; + + private Coordinate coordinate; + + public PointRegion(Coordinate coordinate) { + this.coordinate = coordinate; + } + + public Coordinate getCenter() { + return this.coordinate; + } + + public void setCenter(Coordinate coord) { + this.coordinate = coord; + } + + public boolean contains(double lat, double lng) { + return lat == this.coordinate.getLat() && lng == this.coordinate.getLng(); + } + + public double getMinLat() { + return this.coordinate.getLat(); + } + + public double getMaxLat() { + return this.coordinate.getLat(); + } + + public double getMinLng() { + return this.coordinate.getLng(); + } + + public double getMaxLng() { + return this.coordinate.getLng(); + } + + public List getRepresentatives() { + List representatives = new ArrayList(1); + representatives.add(this.coordinate); + return representatives; + } + + public void setRepresentatives(List coordinates) { + this.coordinate = coordinates.get(0); + } +} + diff --git a/src/main/java/opennlp/fieldspring/tr/topo/PointSetRegion.java b/src/main/java/opennlp/fieldspring/tr/topo/PointSetRegion.java new file mode 100644 index 0000000..060dca6 --- /dev/null +++ b/src/main/java/opennlp/fieldspring/tr/topo/PointSetRegion.java @@ -0,0 +1,107 @@ +/////////////////////////////////////////////////////////////////////////////// +// Copyright (C) 2010 Travis Brown, The University of Texas at Austin +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+/////////////////////////////////////////////////////////////////////////////// +package opennlp.fieldspring.tr.topo; + +import java.util.List; + +public class PointSetRegion extends Region { + + private static final long serialVersionUID = 42L; + + private List coordinates; + private Coordinate center; + private final double minLat; + private final double maxLat; + private final double minLng; + private final double maxLng; + + public PointSetRegion(List coordinates) { + this.coordinates = coordinates; + this.center = Coordinate.centroid(this.coordinates); + + double minLat = Double.POSITIVE_INFINITY; + double maxLat = Double.NEGATIVE_INFINITY; + double minLng = Double.POSITIVE_INFINITY; + double maxLng = Double.NEGATIVE_INFINITY; + + for (Coordinate coordinate : this.coordinates) { + double lat = coordinate.getLat(); + double lng = coordinate.getLng(); + + if (lat < minLat) minLat = lat; + if (lat > maxLat) maxLat = lat; + if (lng < minLng) minLng = lng; + if (lng > maxLng) maxLng = lng; + } + + this.minLat = minLat; + this.maxLat = maxLat; + this.minLng = minLng; + this.maxLng = maxLng; + } + + public Coordinate getCenter() { + return this.center; + } + + public void setCenter(Coordinate coord) { + this.center = coord; + } + + public boolean contains(double lat, double lng) { + return lat == this.center.getLat() && lng == this.center.getLng(); + } + + public double getMinLat() { + return this.minLat; + } + + public double getMaxLat() { + return this.maxLat; + } + + public double getMinLng() { + return this.minLng; + } + + public double getMaxLng() { + return this.maxLng; + } + + public List getRepresentatives() { + return this.coordinates; + } + + public void setRepresentatives(List representatives) { + this.coordinates = representatives; + this.center = Coordinate.centroid(representatives); + } + + @Override + public double distance(Coordinate coordinate) { + double minDistance = Double.POSITIVE_INFINITY; + + for (Coordinate representative : this.coordinates) { + double distance = representative.distance(coordinate); + if (distance < minDistance) { + minDistance = distance; + } + } + + return minDistance; + } +} + diff --git a/src/main/java/opennlp/fieldspring/tr/topo/RectRegion.java b/src/main/java/opennlp/fieldspring/tr/topo/RectRegion.java new file mode 100644 index 0000000..987f170 --- /dev/null +++ b/src/main/java/opennlp/fieldspring/tr/topo/RectRegion.java @@ -0,0 +1,109 @@ +/////////////////////////////////////////////////////////////////////////////// +// Copyright (C) 2010 Travis Brown, The University of Texas at Austin +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+///////////////////////////////////////////////////////////////////////////////
+package opennlp.fieldspring.tr.topo;
+
+import java.util.ArrayList;
+import java.util.List;
+
+public class RectRegion extends Region {
+
+  private static final long serialVersionUID = 42L;
+
+  private final double minLat;
+  private final double maxLat;
+  private final double minLng;
+  private final double maxLng;
+
+  private RectRegion(double minLat, double maxLat, double minLng, double maxLng) {
+    this.minLat = minLat;
+    this.maxLat = maxLat;
+    this.minLng = minLng;
+    this.maxLng = maxLng;
+  }
+
+  public static RectRegion fromRadians(double minLat, double maxLat, double minLng, double maxLng) {
+    return new RectRegion(minLat, maxLat, minLng, maxLng);
+  }
+
+  public static RectRegion fromDegrees(double minLat, double maxLat, double minLng, double maxLng) {
+    return new RectRegion(minLat * Math.PI / 180.0,
+                          maxLat * Math.PI / 180.0,
+                          minLng * Math.PI / 180.0,
+                          maxLng * Math.PI / 180.0);
+  }
+
+  /**
+   * Returns the average of the minimum and maximum values for latitude and
+   * longitude. Should be changed to avoid problems for boxes that cross the
+   * 180/-180 longitude discontinuity.
+   */
+  public Coordinate getCenter() {
+    return Coordinate.fromRadians((this.maxLat + this.minLat) / 2.0,
+                                  (this.maxLng + this.minLng) / 2.0);
+  }
+
+  public void setCenter(Coordinate coord) {
+  }
+
+  public boolean contains(double lat, double lng) {
+    if (this.minLng <= this.maxLng) {
+      return lat >= this.minLat &&
+             lat <= this.maxLat &&
+             lng >= this.minLng &&
+             lng <= this.maxLng;
+    }
+    // For boxes that wrap around the 180/-180 longitude boundary
+    // (minLng > maxLng); all values are in radians, so the boundary is +/-pi.
+    return (lat >= this.minLat && lat <= this.maxLat) &&
+           ((lng >= this.minLng && lng <= Math.PI) ||
+            (lng >= -Math.PI && lng <= this.maxLng));
+  }
+
+  public double getMinLat() {
+    return this.minLat;
+  }
+
+  public double getMaxLat() {
+    return this.maxLat;
+  }
+
+  public double getMinLng() {
+    return this.minLng;
+  }
+
+  public double getMaxLng() {
+    return this.maxLng;
+  }
+
+  public List<Coordinate> getRepresentatives() {
+    List<Coordinate> representatives = new ArrayList<Coordinate>(4);
+    representatives.add(Coordinate.fromRadians(this.minLat, this.minLng));
+    representatives.add(Coordinate.fromRadians(this.maxLat, this.minLng));
+    representatives.add(Coordinate.fromRadians(this.maxLat, this.maxLng));
+    representatives.add(Coordinate.fromRadians(this.minLat, this.maxLng));
+    return representatives;
+  }
+
+  public void setRepresentatives(List<Coordinate> coordinates) {
+    System.err.println("Warning: can't set representatives of RectRegion.");
+  }
+
+  public String toString() {
+    return "lat: [" + (minLat * 180.0 / Math.PI) + ", " +
+           (maxLat * 180.0 / Math.PI) + "] lon: [" +
+           (minLng * 180.0 / Math.PI) + ", " +
+           (maxLng * 180.0 / Math.PI) + "]";
+  }
+}
+
diff --git a/src/main/java/opennlp/fieldspring/tr/topo/Region.java b/src/main/java/opennlp/fieldspring/tr/topo/Region.java
new file mode 100644
index 0000000..0354978
--- /dev/null
+++ b/src/main/java/opennlp/fieldspring/tr/topo/Region.java
@@ -0,0 +1,96 @@
+///////////////////////////////////////////////////////////////////////////////
+// Copyright (C) 2010 Travis Brown, The University of Texas at Austin
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
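A sketch of how a dateline-crossing RectRegion behaves with the radian-based wrap-around check in contains() above; editorial, not part of the patch, with illustrative bounds. It assumes Coordinate.fromDegrees and the Region.contains(Coordinate) overload.

    import opennlp.fieldspring.tr.topo.Coordinate;
    import opennlp.fieldspring.tr.topo.RectRegion;

    public class RectRegionSketch {
      public static void main(String[] args) {
        // A box spanning the 180/-180 boundary: minLng (170) is greater than maxLng (-170).
        RectRegion box = RectRegion.fromDegrees(-21.0, -12.0, 170.0, -170.0);

        System.out.println(box.contains(Coordinate.fromDegrees(-17.7, 178.0)));  // true: east of 170
        System.out.println(box.contains(Coordinate.fromDegrees(-17.7, -175.0))); // true: west of -170
        System.out.println(box.contains(Coordinate.fromDegrees(-17.7, 150.0)));  // false: outside the box
      }
    }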
+// See the License for the specific language governing permissions and
+// limitations under the License.
+///////////////////////////////////////////////////////////////////////////////
+package opennlp.fieldspring.tr.topo;
+
+import java.util.List;
+import java.io.Serializable;
+
+public abstract class Region implements Serializable {
+
+  private static final long serialVersionUID = 42L;
+
+  public abstract Coordinate getCenter();
+  public abstract void setCenter(Coordinate coord);
+  public abstract boolean contains(double lat, double lng);
+
+  public abstract double getMinLat();
+  public abstract double getMaxLat();
+  public abstract double getMinLng();
+  public abstract double getMaxLng();
+  public abstract List<Coordinate> getRepresentatives();
+  public abstract void setRepresentatives(List<Coordinate> coordinates);
+
+  public boolean contains(Coordinate coordinate) {
+    return this.contains(coordinate.getLat(), coordinate.getLng());
+  }
+
+  // Regions store latitudes and longitudes in radians; these accessors
+  // convert the bounding values to degrees.
+  public double getMinLatDegrees() {
+    return this.getMinLat() * 180.0 / Math.PI;
+  }
+
+  public double getMaxLatDegrees() {
+    return this.getMaxLat() * 180.0 / Math.PI;
+  }
+
+  public double getMinLngDegrees() {
+    return this.getMinLng() * 180.0 / Math.PI;
+  }
+
+  public double getMaxLngDegrees() {
+    return this.getMaxLng() * 180.0 / Math.PI;
+  }
+
+  public double distance(Region other) {
+    //return this.distance(other.getCenter());
+    double minDist = Double.POSITIVE_INFINITY;
+    for (Coordinate coord : this.getRepresentatives()) {
+      for (Coordinate otherCoord : other.getRepresentatives()) {
+        double curDist = coord.distance(otherCoord);
+        if (curDist < minDist)
+          minDist = curDist;
+      }
+    }
+    return minDist;
+  }
+
+  public double distance(Coordinate coordinate) {
+    //return this.getCenter().distance(coordinate);
+    double minDist = Double.POSITIVE_INFINITY;
+    for (Coordinate coord : this.getRepresentatives()) {
+      double curDist = coord.distance(coordinate);
+      if (curDist < minDist)
+        minDist = curDist;
+    }
+    return minDist;
+  }
+
+  public double distanceInKm(Region other) {
+    if (this.getRepresentatives().size() == 1 && other.getRepresentatives().size() == 1)
+      return this.getCenter().distanceInKm(other.getCenter());
+    //return this.distanceInKm(other.getCenter());
+    /*double minDist = Double.POSITIVE_INFINITY;
+    for(Coordinate coord : this.getRepresentatives()) {
+      for(Coordinate otherCoord : other.getRepresentatives()) {
+        double curDist = coord.distanceInKm(otherCoord);
+        if(curDist < minDist)
+          minDist = curDist;
+      }
+    }
+    return minDist;*/
+    return 6372.8 * this.distance(other);
+  }
+}
+
diff --git a/src/main/java/opennlp/fieldspring/tr/topo/gaz/CandidateList.java b/src/main/java/opennlp/fieldspring/tr/topo/gaz/CandidateList.java
new file mode 100644
index 0000000..94b27bb
--- /dev/null
+++ b/src/main/java/opennlp/fieldspring/tr/topo/gaz/CandidateList.java
@@ -0,0 +1,131 @@
+///////////////////////////////////////////////////////////////////////////////
+// Copyright (C) 2010 Travis Brown, The University of Texas at Austin
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
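The distance methods above take the minimum over representative points and, in distanceInKm, scale the angular distance (in radians) by an Earth radius of 6372.8 km, with a shortcut when both regions are single points. A usage sketch, editorial and not part of the patch:

    import opennlp.fieldspring.tr.topo.Coordinate;
    import opennlp.fieldspring.tr.topo.PointRegion;
    import opennlp.fieldspring.tr.topo.Region;

    public class RegionDistanceSketch {
      public static void main(String[] args) {
        Region austin = new PointRegion(Coordinate.fromDegrees(30.27, -97.74));
        Region london = new PointRegion(Coordinate.fromDegrees(51.51, -0.13));

        // Both regions have a single representative, so this takes the
        // Coordinate.distanceInKm shortcut rather than 6372.8 * distance().
        System.out.println(austin.distanceInKm(london)); // roughly 7900 km
      }
    }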
+// See the License for the specific language governing permissions and +// limitations under the License. +/////////////////////////////////////////////////////////////////////////////// +package opennlp.fieldspring.tr.topo.gaz; + +import java.util.ArrayList; +import java.util.Collections; +import java.util.Iterator; +import java.util.List; +import java.util.NoSuchElementException; + +import opennlp.fieldspring.tr.topo.Coordinate; +import opennlp.fieldspring.tr.topo.Location; + +/** + * Represents a list of candidates for a given toponym. + * + * @author Travis Brown + */ +public abstract class CandidateList implements Iterable { + /** + * The size of the list. + */ + public abstract int size(); + + /** + * The location at a given index. + */ + public abstract Location get(int index); + + /** + * Iterate over the candidates in the list in order of increasing distance + * from a given point. This should be overridden in subclasses where the + * structure of the gazetteer makes a more efficient implementation + * possible. + */ + public Iterator getNearest(Coordinate point) { + List items = new ArrayList(this.size()); + for (int i = 0; i < this.size(); i++) { + items.add(new SortedItem(i, this.get(i).getRegion().distance(point))); + } + Collections.sort(items); + return items.iterator(); + } + + /** + * Represents an item in the list sorted with respect to distance from a + * given coordinate. + */ + public class SortedItem implements Comparable { + private final int index; + private final double distance; + + private SortedItem(int index, double distance) { + this.index = index; + this.distance = distance; + } + + /** + * The index of this candidate in the list. + */ + public int getIndex() { + return this.index; + } + + /** + * The distance between this candidate and the given point. + */ + public double getDistance() { + return this.distance; + } + + /** + * A convenience method providing direct access to the location object for + * this candidate. + */ + public Location getLocation() { + return CandidateList.this.get(this.index); + } + + /** + * Returns a point in the region for the location coordinate. When + * possible for region locations this should be the point in the region + * nearest to the given point (in which case this implementation should be + * overridden). + */ + public Coordinate getCoordinate() { + return CandidateList.this.get(this.index).getRegion().getCenter(); + } + + /** + * We need to be able to sort by distance, with the lower index coming + * first in the case that the distances are the same. + */ + public int compareTo(SortedItem other) { + double diff = this.getDistance() - other.getDistance(); + if (diff < 0.0) { + return -1; + } else if (diff > 0.0) { + return 1; + } else { + return this.getIndex() - other.getIndex(); + } + } + } + + /** + * A convenience class that we can extend in subclasses to avoid repeatedly + * implementing the unsupported removal operation. + */ + private abstract class CandidateListIterator { + public void remove() { + throw new UnsupportedOperationException( + "Cannot remove a location from a candidate list." 
+ ); + } + } +} + diff --git a/src/main/java/opennlp/fieldspring/tr/topo/gaz/FilteredGeoNamesReader.java b/src/main/java/opennlp/fieldspring/tr/topo/gaz/FilteredGeoNamesReader.java new file mode 100644 index 0000000..3fbf70c --- /dev/null +++ b/src/main/java/opennlp/fieldspring/tr/topo/gaz/FilteredGeoNamesReader.java @@ -0,0 +1,49 @@ +/////////////////////////////////////////////////////////////////////////////// +// Copyright (C) 2010 Travis Brown, The University of Texas at Austin +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +/////////////////////////////////////////////////////////////////////////////// +package opennlp.fieldspring.tr.topo.gaz; + +import java.io.File; +import java.io.FileInputStream; +import java.io.FileNotFoundException; +import java.io.InputStreamReader; +import java.io.IOException; +import java.io.BufferedReader; +import java.util.zip.GZIPInputStream; + +import opennlp.fieldspring.tr.topo.Location; + +public class FilteredGeoNamesReader extends GeoNamesReader { + public FilteredGeoNamesReader(File file) throws FileNotFoundException, IOException { + this(new BufferedReader(new InputStreamReader(new GZIPInputStream(new FileInputStream(file))))); + } + + public FilteredGeoNamesReader(BufferedReader reader) + throws FileNotFoundException, IOException { + super(reader); + } + + protected Location parseLine(String line, int currentId) { + Location location = super.parseLine(line, currentId); + if (location != null) { + Location.Type type = location.getType(); + if (type != Location.Type.STATE && type != Location.Type.CITY) { + location = null; + } + } + return location; + } +} + diff --git a/src/main/java/opennlp/fieldspring/tr/topo/gaz/Gazetteer.java b/src/main/java/opennlp/fieldspring/tr/topo/gaz/Gazetteer.java new file mode 100644 index 0000000..c34f097 --- /dev/null +++ b/src/main/java/opennlp/fieldspring/tr/topo/gaz/Gazetteer.java @@ -0,0 +1,34 @@ +/////////////////////////////////////////////////////////////////////////////// +// Copyright (C) 2010 Travis Brown, The University of Texas at Austin +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +/////////////////////////////////////////////////////////////////////////////// +package opennlp.fieldspring.tr.topo.gaz; + +import java.util.List; + +import opennlp.fieldspring.tr.topo.Location; + +/** + * Represents a mapping from toponym strings to lists of location candidates. 
+ * + * @author Travis Brown + */ +public interface Gazetteer { + /** + * Lookup a toponym in the gazetteer, returning null if no candidate list is + * found. + */ + public List lookup(String query); +} + diff --git a/src/main/java/opennlp/fieldspring/tr/topo/gaz/GazetteerFileReader.java b/src/main/java/opennlp/fieldspring/tr/topo/gaz/GazetteerFileReader.java new file mode 100644 index 0000000..e2c0b21 --- /dev/null +++ b/src/main/java/opennlp/fieldspring/tr/topo/gaz/GazetteerFileReader.java @@ -0,0 +1,55 @@ +/////////////////////////////////////////////////////////////////////////////// +// Copyright (C) 2010 Travis Brown, The University of Texas at Austin +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +/////////////////////////////////////////////////////////////////////////////// +package opennlp.fieldspring.tr.topo.gaz; + +import java.io.BufferedReader; +import java.io.FileInputStream; +import java.io.FileNotFoundException; +import java.io.InputStreamReader; +import java.io.IOException; +import java.util.Iterator; +import java.util.zip.GZIPInputStream; + +public abstract class GazetteerFileReader extends GazetteerReader { + private final BufferedReader reader; + + protected GazetteerFileReader(BufferedReader reader) + throws FileNotFoundException, IOException { + this.reader = reader; + } + + protected String readLine() { + String line = null; + try { + line = this.reader.readLine(); + } catch (IOException e) { + System.err.format("Error while reading gazetteer file: %s\n", e); + e.printStackTrace(); + } + return line; + } + + public void close() { + try { + this.reader.close(); + } catch (IOException e) { + System.err.format("Error closing gazetteer file: %s\n", e); + e.printStackTrace(); + System.exit(1); + } + } +} + diff --git a/src/main/java/opennlp/fieldspring/tr/topo/gaz/GazetteerLineReader.java b/src/main/java/opennlp/fieldspring/tr/topo/gaz/GazetteerLineReader.java new file mode 100644 index 0000000..78d5ed7 --- /dev/null +++ b/src/main/java/opennlp/fieldspring/tr/topo/gaz/GazetteerLineReader.java @@ -0,0 +1,59 @@ +/////////////////////////////////////////////////////////////////////////////// +// Copyright (C) 2010 Travis Brown, The University of Texas at Austin +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
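Since Gazetteer.lookup above returns null rather than an empty list when a toponym is unknown, callers generally need a guard. A small helper sketch, editorial and not part of the patch, assuming the List<Location> signature the interface implies:

    import java.util.Collections;
    import java.util.List;

    import opennlp.fieldspring.tr.topo.Location;
    import opennlp.fieldspring.tr.topo.gaz.Gazetteer;

    public final class Gazetteers {
      private Gazetteers() {}

      /** Returns the candidate list for a toponym, or an empty list if there is none. */
      public static List<Location> lookupOrEmpty(Gazetteer gazetteer, String toponym) {
        List<Location> candidates = gazetteer.lookup(toponym);
        return candidates == null ? Collections.<Location>emptyList() : candidates;
      }
    }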
+/////////////////////////////////////////////////////////////////////////////// +package opennlp.fieldspring.tr.topo.gaz; + +import java.io.BufferedReader; +import java.io.FileNotFoundException; +import java.io.IOException; +import java.util.Iterator; + +import opennlp.fieldspring.tr.topo.Location; + +public abstract class GazetteerLineReader extends GazetteerFileReader { + private Location current; + private int currentId; + + protected GazetteerLineReader(BufferedReader reader) + throws FileNotFoundException, IOException { + super(reader); + this.current = this.nextLocation(); + this.currentId = 1; + } + + protected abstract Location parseLine(String line, int currentId); + + private Location nextLocation() { + Location location = null; + for (String line = this.readLine(); line != null; line = this.readLine()) { + location = this.parseLine(line, this.currentId); + if (location != null) break; + } + this.currentId++; + //if (this.currentId % 50000 == 0) { System.out.format("At location id: %d.\n", this.currentId); } + return location; + } + + public boolean hasNext() { + return this.current != null; + } + + public Location next() { + Location location = this.current; + this.current = this.nextLocation(); + return location; + } +} + diff --git a/src/main/java/opennlp/fieldspring/tr/topo/gaz/GazetteerReader.java b/src/main/java/opennlp/fieldspring/tr/topo/gaz/GazetteerReader.java new file mode 100644 index 0000000..bf40687 --- /dev/null +++ b/src/main/java/opennlp/fieldspring/tr/topo/gaz/GazetteerReader.java @@ -0,0 +1,41 @@ +/////////////////////////////////////////////////////////////////////////////// +// Copyright (C) 2010 Travis Brown, The University of Texas at Austin +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +/////////////////////////////////////////////////////////////////////////////// +package opennlp.fieldspring.tr.topo.gaz; + +import java.util.Iterator; +import opennlp.fieldspring.tr.topo.Location; + +public abstract class GazetteerReader implements Iterable, + Iterator { + public abstract void close(); + + protected Location.Type getLocationType(String code) { + return Location.Type.UNKNOWN; + } + + protected Location.Type getLocationType(String code, String fine) { + return this.getLocationType(code); + } + + public Iterator iterator() { + return this; + } + + public void remove() { + throw new UnsupportedOperationException("Cannot remove location from gazetteer."); + } +} + diff --git a/src/main/java/opennlp/fieldspring/tr/topo/gaz/GeoNamesGazetteer.java b/src/main/java/opennlp/fieldspring/tr/topo/gaz/GeoNamesGazetteer.java new file mode 100644 index 0000000..f59d4ab --- /dev/null +++ b/src/main/java/opennlp/fieldspring/tr/topo/gaz/GeoNamesGazetteer.java @@ -0,0 +1,386 @@ +/////////////////////////////////////////////////////////////////////////////// +// Copyright (C) 2010 Travis Brown, The University of Texas at Austin +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +/////////////////////////////////////////////////////////////////////////////// +package opennlp.fieldspring.tr.topo.gaz; + +import java.io.*; +import java.util.*; + +import opennlp.fieldspring.tr.topo.Coordinate; +import opennlp.fieldspring.tr.topo.Location; +import opennlp.fieldspring.tr.topo.PointRegion; +import opennlp.fieldspring.tr.topo.PointSetRegion; +import opennlp.fieldspring.tr.topo.Region; +import opennlp.fieldspring.tr.topo.SphericalGeometry; +import opennlp.fieldspring.tr.util.cluster.Clusterer; +import opennlp.fieldspring.tr.util.cluster.KMeans; +import opennlp.fieldspring.tr.util.TopoUtil; + +public class GeoNamesGazetteer implements Gazetteer, Serializable { + /** + * + */ + private static final long serialVersionUID = 1L; + private final boolean expandRegions; + private final double pointRatio; + private final int minPoints; + private final int maxPoints; + private final int maxConsidered; + + private final List locations; + private final Map> names; + private final Map ipes; + private final Map adms; + private final Map> ipePoints; + private final Map> admPoints; + + public GeoNamesGazetteer(BufferedReader reader) throws IOException { + this(reader, true, -1.0); + } + + public GeoNamesGazetteer(BufferedReader reader, boolean expandRegions) throws IOException { + this(reader, expandRegions, -1.0); + } + + public GeoNamesGazetteer(BufferedReader reader, boolean expandRegions, int kPoints) + throws IOException { + this(reader, expandRegions, 1.0, kPoints, kPoints); + } + + public GeoNamesGazetteer(BufferedReader reader, boolean expandRegions, double pointRatio) + throws IOException { + this(reader, expandRegions, pointRatio, 5, 30); + } + + public GeoNamesGazetteer(BufferedReader reader, boolean expandRegions, double pointRatio, int minPoints, int maxPoints) + throws IOException { + this(reader, expandRegions, pointRatio, minPoints, maxPoints, 3000); + } + + public GeoNamesGazetteer(BufferedReader reader, boolean expandRegions, double pointRatio, int minPoints, int maxPoints, int maxConsidered) + throws IOException { + this.expandRegions = expandRegions; + this.pointRatio = pointRatio; + this.minPoints = minPoints; + this.maxPoints = maxPoints; + this.maxConsidered = maxConsidered; + + this.locations = new ArrayList(); + this.names = new HashMap>(); + this.ipes = new HashMap(); + this.adms = new HashMap(); + this.ipePoints = new HashMap>(); + this.admPoints = new HashMap>(); + + this.load(reader); + if (this.expandRegions) { + this.expandRegionsHelper(this.ipes, this.ipePoints);//this.expandIPE(); + this.expandRegionsHelper(this.adms, this.admPoints);//this.expandADM(); + } + } + + private boolean ignore(String cat, String type) { + return (cat.equals("H") || cat.equals("L") || cat.equals("S") || cat.equals("U") || cat.equals("V") + || cat.equals("R") || cat.equals("T")); + } + + private boolean store(String cat, String type) { + return true; + } + + private void expandRegionsHelper(Map regions, Map> regionPoints) { + Clusterer clusterer = new KMeans(); + + System.out.println("Selecting points for " + regions.size() + " regions."); + for (String region : 
regions.keySet()) { + Location location = this.locations.get(regions.get(region)); + List contained = regionPoints.get(region);// ALL points in e.g. USA + + int k = 0; + + if(this.pointRatio > 0) { + k = (int) Math.floor(contained.size() * this.pointRatio); + if (k < this.minPoints) { + k = this.minPoints; + } + if (k > this.maxPoints) { + k = this.maxPoints; + } + } + + if(contained.size() > this.maxConsidered) { + Collections.shuffle(contained); + contained = contained.subList(0, this.maxConsidered); + } + + if(this.pointRatio <= 0) { + Set cellsOverlapped = new HashSet(); + for(Coordinate coord : contained) + cellsOverlapped.add(TopoUtil.getCellNumber(coord, 1.0)); + k = cellsOverlapped.size();// / 4; + if(k < 1) k = 1; + System.out.println(location.getName() + " " + k); + } + + if (contained.size() > 0) { + List representatives = clusterer.clusterList(contained, k, SphericalGeometry.g()); + representatives = Coordinate.removeNaNs(representatives); + location.setRegion(new PointSetRegion(representatives)); + location.recomputeThreshold(); + } + } + } + + + /*private void expandIPE() { + Clusterer clusterer = new KMeans(); + + System.out.println("Selecting points for " + this.ipes.size() + " independent political entities."); + for (String ipe : this.ipes.keySet()) { + Location location = this.locations.get(this.ipes.get(ipe)); + List contained = this.ipePoints.get(ipe);// ALL points in e.g. USA + + int k = 0; + + if(this.pointRatio > 0) { + k = (int) Math.floor(contained.size() * this.pointRatio); + if (k < this.minPoints) { + k = this.minPoints; + } + if (k > this.maxPoints) { + k = this.maxPoints; + } + } + else { + Set cellsOverlapped = new HashSet(); + for(Coordinate coord : contained) + cellsOverlapped.add(TopoUtil.getCellNumber(coord, 1.0)); + k = cellsOverlapped.size(); + } + + //System.err.format("Clustering: %d points for %s.\n", k, location.getName()); + + if (contained.size() > this.maxConsidered) { + Collections.shuffle(contained); + contained = contained.subList(0, this.maxConsidered); + } + + if (contained.size() > 0) { + List representatives = clusterer.clusterList(contained, k, SphericalGeometry.g()); + representatives = Coordinate.removeNaNs(representatives); + location.setRegion(new PointSetRegion(representatives)); + location.recomputeThreshold(); + //contained.clear(); + //contained = null; + } + //this.ipePoints.get(ipe).clear(); + } + //this.ipePoints.clear(); + //this.ipePoints = null; + } + + private void expandADM() { + Clusterer clusterer = new KMeans(); + + System.out.println("Selecting points for " + this.adms.size() + " administrative regions."); + for (String adm : this.adms.keySet()) { + Location location = this.locations.get(this.adms.get(adm)); + List contained = this.admPoints.get(adm); + + if (contained != null) { + int k = (int) Math.floor(contained.size() * this.pointRatio); + if (k < this.minPoints) { + k = this.minPoints; + } + if (k > this.maxPoints) { + k = this.maxPoints; + } + + //System.err.format("Clustering: %d points for %s.\n", k, location.getName()); + + if (contained.size() > this.maxConsidered) { + Collections.shuffle(contained); + contained = contained.subList(0, this.maxConsidered); + } + + if (contained.size() > 0) { + List representatives = clusterer.clusterList(contained, k, SphericalGeometry.g()); + representatives = Coordinate.removeNaNs(representatives); + location.setRegion(new PointSetRegion(representatives)); + location.recomputeThreshold(); + } + } + } + }*/ + + private String standardize(String name) { + return 
name.toLowerCase().replace("’", "'"); + } + + private int load(BufferedReader reader) { + int index = 0; + int count = 0; + try { + System.out.print("["); + for (String line = reader.readLine(); + line != null; line = reader.readLine()) { + String[] fields = line.split("\t"); + if (fields.length > 14) { + String primaryName = fields[1]; + count++; + if(count % 750000 == 0) { + System.out.print("."); + } + Set nameSet = new HashSet(); + nameSet.add(this.standardize(primaryName)); + + String[] names = fields[3].split(","); + for (int i = 0; i < names.length; i++) { + nameSet.add(this.standardize(names[i])); + } + + String cat = fields[6]; + String type = fields[7]; + + if (this.ignore(cat, type)) { + continue; + } + + String ipe = fields[8]; + String adm = ipe + fields[10]; + + String admin1code = ipe + "." + fields[10]; + + double lat = 0.0; + double lng = 0.0; + try { + lat = Double.parseDouble(fields[4]); + lng = Double.parseDouble(fields[5]); + } catch (NumberFormatException e) { + System.err.format("Invalid coordinates: %s\n", primaryName); + } + + //if(primaryName.equalsIgnoreCase("united states")) + // System.out.println(lat + ", " + lng); + + // try to get coordinates from right side in the case of weird characters in names messing up tabs between fields: + if((Double.isNaN(lat) || Double.isNaN(lng)) && fields.length > 19) { + try { + lat = Double.parseDouble(fields[fields.length-15]); + lng = Double.parseDouble(fields[fields.length-14]); + } catch (NumberFormatException e) { + System.err.format("Invalid coordinates: %s\n", primaryName); + } + } + + //if(primaryName.equalsIgnoreCase("united states")) + // System.out.println(lat + ", " + lng); + + // give up on trying to get coordinates: + if(Double.isNaN(lat) || Double.isNaN(lng)) + continue; + + Coordinate coordinate = Coordinate.fromDegrees(lat, lng); + + int population = 0; + if (fields[14].length() > 0) { + try { + population = Integer.parseInt(fields[14]); + } catch (NumberFormatException e) { + System.err.format("Invalid population: %s\n", primaryName); + } + } + + if (!this.ipePoints.containsKey(ipe)) { + this.ipePoints.put(ipe, new ArrayList()); + } + this.ipePoints.get(ipe).add(coordinate); + + if (!this.admPoints.containsKey(adm)) { + this.admPoints.put(adm, new ArrayList()); + } + this.admPoints.get(adm).add(coordinate); + + if (type.equals("PCLI")) { + this.ipes.put(ipe, index); + } else if (type.equals("ADM1")) { + this.adms.put(adm, index); + } + + if (this.store(cat, type)) { + Region region = new PointRegion(coordinate); + Location location = new Location(index, primaryName, region, this.getLocationType(cat), population, admin1code, + 10.0); + this.locations.add(location); + + for (String name : nameSet) { + if (!this.names.containsKey(name)) { + this.names.put(name, new ArrayList()); + } + this.names.get(name).add(location); + } + + index += 1; + } + } + } + System.out.println("]"); + reader.close(); + } catch (IOException e) { + System.err.format("Error while reading GeoNames file: %s\n", e); + e.printStackTrace(); + } + + return index; + } + + private Location.Type getLocationType(String cat) { + Location.Type type = Location.Type.UNKNOWN; + if (cat.length() > 0) { + if (cat.equals("A")) { + type = Location.Type.STATE; + } else if (cat.equals("H")) { + type = Location.Type.WATER; + } else if (cat.equals("L")) { + type = Location.Type.PARK; + } else if (cat.equals("P")) { + type = Location.Type.CITY; + } else if (cat.equals("R")) { + type = Location.Type.TRANSPORT; + } else if (cat.equals("S")) { + type = 
Location.Type.SITE; + } else if (cat.equals("T")) { + type = Location.Type.MOUNTAIN; + } else if (cat.equals("U")) { + type = Location.Type.UNDERSEA; + } else if (cat.equals("V")) { + type = Location.Type.FOREST; + } + } + return type; + } + + /** + * Lookup a toponym in the gazetteer, returning null if no candidate list is + * found. + */ + public List lookup(String query) { + return this.names.get(query.toLowerCase()); + } + + public Set getUniqueLocationNameSet(){ + return names.keySet(); + } +} diff --git a/src/main/java/opennlp/fieldspring/tr/topo/gaz/GeoNamesGazetteerWithList.java b/src/main/java/opennlp/fieldspring/tr/topo/gaz/GeoNamesGazetteerWithList.java new file mode 100644 index 0000000..bcbe312 --- /dev/null +++ b/src/main/java/opennlp/fieldspring/tr/topo/gaz/GeoNamesGazetteerWithList.java @@ -0,0 +1,331 @@ +/////////////////////////////////////////////////////////////////////////////// +// Copyright (C) 2010 Travis Brown, The University of Texas at Austin +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +/////////////////////////////////////////////////////////////////////////////// +package opennlp.fieldspring.tr.topo.gaz; + +import java.io.*; +import java.util.*; + +import opennlp.fieldspring.tr.topo.Coordinate; +import opennlp.fieldspring.tr.topo.Location; +import opennlp.fieldspring.tr.topo.PointRegion; +import opennlp.fieldspring.tr.topo.PointSetRegion; +import opennlp.fieldspring.tr.topo.Region; +import opennlp.fieldspring.tr.topo.SphericalGeometry; +import opennlp.fieldspring.tr.util.cluster.Clusterer; +import opennlp.fieldspring.tr.util.cluster.KMeans; +import opennlp.fieldspring.tr.util.*; + +public class GeoNamesGazetteerWithList implements Gazetteer, Serializable { + + private static final long serialVersionUID = 42L; + + private final boolean expandRegions; + private final double pointRatio; + private final int minPoints; + private final int maxPoints; + private final int maxConsidered; + + private List locations; + //private final Map> names; + private Map ipes; + //private final Map adms; + private Map> ipePoints; // made mutable so can assign to null when done for faster GC + //private final Map> admPoints; + + private final Lexicon lexicon; + private final List > mainGaz; // instead of names + + public GeoNamesGazetteerWithList(BufferedReader reader) throws IOException { + this(reader, true, 0.005); + } + + public GeoNamesGazetteerWithList(BufferedReader reader, boolean expandRegions) throws IOException { + this(reader, expandRegions, 0.005); + } + + public GeoNamesGazetteerWithList(BufferedReader reader, boolean expandRegions, int kPoints) + throws IOException { + this(reader, expandRegions, 1.0, kPoints, kPoints); + } + + public GeoNamesGazetteerWithList(BufferedReader reader, boolean expandRegions, double pointRatio) + throws IOException { + this(reader, expandRegions, pointRatio, 5, 30); + } + + public GeoNamesGazetteerWithList(BufferedReader reader, boolean expandRegions, double pointRatio, int minPoints, int maxPoints) + throws IOException { + 
this(reader, expandRegions, pointRatio, minPoints, maxPoints, 2000); + } + + public GeoNamesGazetteerWithList(BufferedReader reader, boolean expandRegions, double pointRatio, int minPoints, int maxPoints, int maxConsidered) + throws IOException { + this.expandRegions = expandRegions; + this.pointRatio = pointRatio; + this.minPoints = minPoints; + this.maxPoints = maxPoints; + this.maxConsidered = maxConsidered; + + this.locations = new ArrayList(); + //this.names = new HashMap>(); + this.ipes = new HashMap(); + //this.adms = new HashMap(); + this.ipePoints = new HashMap>(); + //this.admPoints = new HashMap>(); + + this.lexicon = new SimpleLexicon(); + this.mainGaz = new ArrayList >(); + + this.load(reader); + if (this.expandRegions) { + this.expandIPE(); + //this.expandADM(); + } + } + + private boolean ignore(String cat, String type) { + return (cat.equals("H") || cat.equals("L") || cat.equals("S") || cat.equals("U") || cat.equals("V")); + } + + private boolean store(String cat, String type) { + return true; + } + + private void expandIPE() { + Clusterer clusterer = new KMeans(); + + System.out.println("Selecting points for " + this.ipes.size() + " independent political entities."); + for (String ipe : this.ipes.keySet()) { + Location location = this.locations.get(this.ipes.get(ipe)); + List contained = this.ipePoints.get(ipe);// ALL points in e.g. USA + + int k = (int) Math.floor(contained.size() * this.pointRatio); + if (k < this.minPoints) { + k = this.minPoints; + } + if (k > this.maxPoints) { + k = this.maxPoints; + } + + //System.err.format("Clustering: %d points for %s.\n", k, location.getName()); + + if (contained.size() > this.maxConsidered) { + Collections.shuffle(contained); + contained = contained.subList(0, this.maxConsidered); + } + + if (contained.size() > 0) { + List representatives = clusterer.clusterList(contained, k, SphericalGeometry.g()); + location.setRegion(new PointSetRegion(representatives)); + contained.clear(); + contained = null; + } + this.ipePoints.get(ipe).clear(); + } + this.ipePoints.clear(); + this.ipePoints = null; + this.ipes.clear(); + this.ipes = null; + this.locations.clear(); + this.locations = null; + System.gc(); + } + + /*private void expandADM() { + Clusterer clusterer = new KMeans(); + + System.out.println("Selecting points for " + this.adms.size() + " administrative regions."); + for (String adm : this.adms.keySet()) { + Location location = this.locations.get(this.adms.get(adm)); + List contained = this.admPoints.get(adm); + + if (contained != null) { + int k = (int) Math.floor(contained.size() * this.pointRatio); + if (k < this.minPoints) { + k = this.minPoints; + } + if (k > this.maxPoints) { + k = this.maxPoints; + } + + //System.err.format("Clustering: %d points for %s.\n", k, location.getName()); + + if (contained.size() > this.maxConsidered) { + Collections.shuffle(contained); + contained = contained.subList(0, this.maxConsidered); + } + + if (contained.size() > 0) { + List representatives = clusterer.clusterList(contained, k, SphericalGeometry.g()); + location.setRegion(new PointSetRegion(representatives)); + + /*for (Coordinate c : representatives) { + System.out.println("" + + c.getLngDegrees() + "," + c.getLatDegrees() + + ""); + }*SLASH + } + } + } + }*/ + + private String standardize(String name) { + return name.toLowerCase().replace("’", "'"); + } + + private int load(BufferedReader reader) { + int index = 0; + int count = 0; + try { + System.out.print("["); + for (String line = reader.readLine(); + line != null; line = 
reader.readLine()) { + String[] fields = line.split("\t"); + if (fields.length > 14) { + String primaryName = fields[1]; + count++; + if(count % 750000 == 0) { + System.out.print("."); + } + Set nameSet = new HashSet(); + nameSet.add(this.standardize(primaryName)); + + String[] names = fields[3].split(","); + for (int i = 0; i < names.length; i++) { + nameSet.add(this.standardize(names[i])); + } + + String cat = fields[6]; + String type = fields[7]; + + if (this.ignore(cat, type)) { + continue; + } + + String ipe = fields[8]; + String adm = ipe + fields[10]; + + String admin1code = ipe + "." + fields[10]; + + double lat = 0.0; + double lng = 0.0; + try { + lat = Double.parseDouble(fields[4]); + lng = Double.parseDouble(fields[5]); + } catch (NumberFormatException e) { + System.err.format("Invalid coordinates: %s\n", primaryName); + } + Coordinate coordinate = Coordinate.fromDegrees(lat, lng); + + int population = 0; + if (fields[14].length() > 0) { + try { + population = Integer.parseInt(fields[14]); + } catch (NumberFormatException e) { + System.err.format("Invalid population: %s\n", primaryName); + } + } + + if (!this.ipePoints.containsKey(ipe)) { + this.ipePoints.put(ipe, new ArrayList()); + } + this.ipePoints.get(ipe).add(coordinate); + + /*if (!this.admPoints.containsKey(adm)) { + this.admPoints.put(adm, new ArrayList()); + } + this.admPoints.get(adm).add(coordinate);*/ + + if (type.equals("PCLI")) { + this.ipes.put(ipe, index); + } //else if (type.equals("ADM1")) { + //this.adms.put(adm, index); + //} + + if (this.store(cat, type)) { + Region region = new PointRegion(coordinate); + Location location = new Location(index, primaryName, region, this.getLocationType(cat), population, admin1code, + 10.0); + this.locations.add(location); + + for (String name : nameSet) { + + int idx = this.lexicon.getOrAdd(name); + + while(this.mainGaz.size() < idx+1) { + this.mainGaz.add(new ArrayList()); + } + + this.mainGaz.get(idx).add(location); + + /*if (!this.names.containsKey(name)) { + this.names.put(name, new ArrayList()); + } + this.names.get(name).add(location);*/ + } + + index += 1; + } + } + } + System.out.println("]"); + reader.close(); + } catch (IOException e) { + System.err.format("Error while reading GeoNames file: %s\n", e); + e.printStackTrace(); + } + + return index; + } + + private Location.Type getLocationType(String cat) { + Location.Type type = Location.Type.UNKNOWN; + if (cat.length() > 0) { + if (cat.equals("A")) { + type = Location.Type.STATE; + } else if (cat.equals("H")) { + type = Location.Type.WATER; + } else if (cat.equals("L")) { + type = Location.Type.PARK; + } else if (cat.equals("P")) { + type = Location.Type.CITY; + } else if (cat.equals("R")) { + type = Location.Type.TRANSPORT; + } else if (cat.equals("S")) { + type = Location.Type.SITE; + } else if (cat.equals("T")) { + type = Location.Type.MOUNTAIN; + } else if (cat.equals("U")) { + type = Location.Type.UNDERSEA; + } else if (cat.equals("V")) { + type = Location.Type.FOREST; + } + } + return type; + } + + /** + * Lookup a toponym in the gazetteer, returning null if no candidate list is + * found. 
+ */ + public List lookup(String query) { + //return this.names.get(query.toLowerCase()); + int idx = this.lexicon.get(query.toLowerCase()); + if(idx == -1) + return null; + return this.mainGaz.get(idx); + } +} diff --git a/src/main/java/opennlp/fieldspring/tr/topo/gaz/GeoNamesReader.java b/src/main/java/opennlp/fieldspring/tr/topo/gaz/GeoNamesReader.java new file mode 100644 index 0000000..67c1c1c --- /dev/null +++ b/src/main/java/opennlp/fieldspring/tr/topo/gaz/GeoNamesReader.java @@ -0,0 +1,91 @@ +/////////////////////////////////////////////////////////////////////////////// +// Copyright (C) 2010 Travis Brown, The University of Texas at Austin +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +/////////////////////////////////////////////////////////////////////////////// +package opennlp.fieldspring.tr.topo.gaz; + +import java.io.File; +import java.io.FileInputStream; +import java.io.FileNotFoundException; +import java.io.InputStreamReader; +import java.io.IOException; +import java.io.BufferedReader; +import java.util.zip.GZIPInputStream; + +import opennlp.fieldspring.tr.topo.Coordinate; +import opennlp.fieldspring.tr.topo.Location; +import opennlp.fieldspring.tr.topo.PointRegion; +import opennlp.fieldspring.tr.topo.Region; + +public class GeoNamesReader extends GazetteerLineReader { + public GeoNamesReader(File file) throws FileNotFoundException, IOException { + this(new BufferedReader(new InputStreamReader(new GZIPInputStream(new FileInputStream(file))))); + } + + public GeoNamesReader(BufferedReader reader) + throws FileNotFoundException, IOException { + super(reader); + } + + @Override + protected Location.Type getLocationType(String code) { + Location.Type type = Location.Type.UNKNOWN; + if (code.length() > 0) { + if (code.equals("A")) { + type = Location.Type.STATE; + } else if (code.equals("H")) { + type = Location.Type.WATER; + } else if (code.equals("L")) { + type = Location.Type.PARK; + } else if (code.equals("P")) { + type = Location.Type.CITY; + } else if (code.equals("R")) { + type = Location.Type.TRANSPORT; + } else if (code.equals("S")) { + type = Location.Type.SITE; + } else if (code.equals("T")) { + type = Location.Type.MOUNTAIN; + } else if (code.equals("U")) { + type = Location.Type.UNDERSEA; + } else if (code.equals("V")) { + type = Location.Type.FOREST; + } + } + return type; + } + + protected Location parseLine(String line, int currentId) { + Location location = null; + String[] fields = line.split("\t"); + if (fields.length > 14) { + try { + String name = fields[1].toLowerCase(); + Location.Type type = this.getLocationType(fields[6], fields[7]); + + double lat = Double.parseDouble(fields[4]); + double lng = Double.parseDouble(fields[5]); + Coordinate coordinate = Coordinate.fromDegrees(lat, lng); + Region region = new PointRegion(coordinate); + + int population = fields[14].length() == 0 ? 
0 : Integer.parseInt(fields[14]); + + location = new Location(currentId, name, region, type, population); + } catch (NumberFormatException e) { + System.err.format("Invalid population: %s\n", fields[14]); + } + } + return location; + } +} + diff --git a/src/main/java/opennlp/fieldspring/tr/topo/gaz/InMemoryGazetteer.java b/src/main/java/opennlp/fieldspring/tr/topo/gaz/InMemoryGazetteer.java new file mode 100644 index 0000000..6708449 --- /dev/null +++ b/src/main/java/opennlp/fieldspring/tr/topo/gaz/InMemoryGazetteer.java @@ -0,0 +1,46 @@ +/////////////////////////////////////////////////////////////////////////////// +// Copyright (C) 2010 Travis Brown, The University of Texas at Austin +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +/////////////////////////////////////////////////////////////////////////////// +package opennlp.fieldspring.tr.topo.gaz; + +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +import opennlp.fieldspring.tr.topo.Location; + +public class InMemoryGazetteer extends LoadableGazetteer { + private final Map> map; + + public InMemoryGazetteer() { + this.map = new HashMap>(); + } + + public void add(String name, Location location) { + name = name.toLowerCase(); + List locations = this.map.get(name); + if (locations == null) { + locations = new ArrayList(); + } + locations.add(location); + this.map.put(name, locations); + } + + public List lookup(String query) { + return this.map.get(query.toLowerCase()); + } +} + diff --git a/src/main/java/opennlp/fieldspring/tr/topo/gaz/LoadableGazetteer.java b/src/main/java/opennlp/fieldspring/tr/topo/gaz/LoadableGazetteer.java new file mode 100644 index 0000000..a1dc9d8 --- /dev/null +++ b/src/main/java/opennlp/fieldspring/tr/topo/gaz/LoadableGazetteer.java @@ -0,0 +1,38 @@ +/////////////////////////////////////////////////////////////////////////////// +// Copyright (C) 2010 Travis Brown, The University of Texas at Austin +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+/////////////////////////////////////////////////////////////////////////////// +package opennlp.fieldspring.tr.topo.gaz; + +import java.util.List; +import opennlp.fieldspring.tr.topo.Location; + +public abstract class LoadableGazetteer implements Gazetteer { + public abstract void add(String name, Location location); + + public int load(GazetteerReader reader) { + int count = 0; + for (Location location : reader) { + count++; + this.add(location.getName(), location); + } + reader.close(); + this.finishLoading(); + return count; + } + + public void finishLoading() {} + public void close() {} +} + diff --git a/src/main/java/opennlp/fieldspring/tr/topo/gaz/MultiGazetteer.java b/src/main/java/opennlp/fieldspring/tr/topo/gaz/MultiGazetteer.java new file mode 100644 index 0000000..2e4ec4d --- /dev/null +++ b/src/main/java/opennlp/fieldspring/tr/topo/gaz/MultiGazetteer.java @@ -0,0 +1,39 @@ +/////////////////////////////////////////////////////////////////////////////// +// Copyright (C) 2010 Travis Brown, The University of Texas at Austin +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +/////////////////////////////////////////////////////////////////////////////// +package opennlp.fieldspring.tr.topo.gaz; + +import java.util.ArrayList; +import java.util.List; +import opennlp.fieldspring.tr.topo.Location; + +public class MultiGazetteer implements Gazetteer { + private final List gazetteers; + + public MultiGazetteer(List gazetteers) { + this.gazetteers = gazetteers; + } + + public List lookup(String query) { + for (Gazetteer gazetteer : this.gazetteers) { + List candidates = gazetteer.lookup(query); + if (candidates != null) { + return candidates; + } + } + return null; + } +} + diff --git a/src/main/java/opennlp/fieldspring/tr/topo/gaz/WorldReader.java b/src/main/java/opennlp/fieldspring/tr/topo/gaz/WorldReader.java new file mode 100644 index 0000000..f1d6d0f --- /dev/null +++ b/src/main/java/opennlp/fieldspring/tr/topo/gaz/WorldReader.java @@ -0,0 +1,74 @@ +/////////////////////////////////////////////////////////////////////////////// +// Copyright (C) 2010 Travis Brown, The University of Texas at Austin +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
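Putting the reader and gazetteer classes above together, a sketch of the intended loading pattern; this is editorial, not part of the patch. The file path is hypothetical and is assumed to point at a gzipped GeoNames dump of the kind GeoNamesReader parses, and the generic signatures (e.g. List<Location>) are assumed.

    import java.io.File;
    import java.util.Arrays;
    import java.util.List;

    import opennlp.fieldspring.tr.topo.Location;
    import opennlp.fieldspring.tr.topo.gaz.*;

    public class GazetteerLoadingSketch {
      public static void main(String[] args) throws Exception {
        File dump = new File("data/gazetteers/geonames-dump.txt.gz"); // hypothetical path

        // A filtered gazetteer (states and cities only) tried first, with the
        // unfiltered GeoNames entries as a fallback via MultiGazetteer.
        InMemoryGazetteer filtered = new InMemoryGazetteer();
        filtered.load(new FilteredGeoNamesReader(dump));

        InMemoryGazetteer full = new InMemoryGazetteer();
        full.load(new GeoNamesReader(dump));

        Gazetteer gazetteer = new MultiGazetteer(Arrays.<Gazetteer>asList(filtered, full));

        List<Location> candidates = gazetteer.lookup("paris");
        System.out.println(candidates == null ? 0 : candidates.size());
      }
    }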
+/////////////////////////////////////////////////////////////////////////////// +package opennlp.fieldspring.tr.topo.gaz; + +import java.io.File; +import java.io.FileInputStream; +import java.io.FileNotFoundException; +import java.io.InputStreamReader; +import java.io.IOException; +import java.io.BufferedReader; +import java.util.zip.GZIPInputStream; + +import opennlp.fieldspring.tr.topo.Coordinate; +import opennlp.fieldspring.tr.topo.Location; +import opennlp.fieldspring.tr.topo.PointRegion; +import opennlp.fieldspring.tr.topo.Region; + +public class WorldReader extends GazetteerLineReader { + public WorldReader(File file) throws FileNotFoundException, IOException { + this(new BufferedReader(new InputStreamReader(new GZIPInputStream(new FileInputStream(file))))); + } + + public WorldReader(BufferedReader reader) + throws FileNotFoundException, IOException { + super(reader); + } + + private double convertNumber(String number) { + boolean negative = number.charAt(0) == '-'; + number = "00" + (negative ? number.substring(1) : number); + int split = number.length() - 2; + number = number.substring(0, split) + "." + number.substring(split); + return Double.parseDouble(number) * (negative ? -1 : 1); + } + + protected Location parseLine(String line, int currentId) { + Location location = null; + String[] fields = line.split("\t"); + if (fields.length > 7 && fields[6].length() > 0 && fields[7].length() > 0 && + !(fields[6].equals("0") && fields[7].equals("9999"))) { + String name = fields[1].toLowerCase(); + Location.Type type = this.getLocationType(fields[4].toLowerCase()); + + double lat = this.convertNumber(fields[6].trim()); + double lng = this.convertNumber(fields[7].trim()); + Coordinate coordinate = Coordinate.fromDegrees(lat, lng); + Region region = new PointRegion(coordinate); + + int population = fields[5].trim().length() > 0 ? Integer.parseInt(fields[5]) : 0; + + /*String container = null; + if (fields.length > 10 && fields[10].trim().length() > 0) { + container = fields[10].trim().toLowerCase(); + }*/ + + location = new Location(currentId, name, region, type, population); + } + return location; + } +} + diff --git a/src/main/java/opennlp/fieldspring/tr/util/Constants.java b/src/main/java/opennlp/fieldspring/tr/util/Constants.java new file mode 100644 index 0000000..d52c625 --- /dev/null +++ b/src/main/java/opennlp/fieldspring/tr/util/Constants.java @@ -0,0 +1,83 @@ +/////////////////////////////////////////////////////////////////////////////// +// Copyright (C) 2007 Jason Baldridge, The University of Texas at Austin +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +/////////////////////////////////////////////////////////////////////////////// +package opennlp.fieldspring.tr.util; + +import java.io.File; +import java.text.DecimalFormat; + +/** + * Class for keeping constant values. 
+ * + * @author Jason Baldridge + * @version $Revision: 1.53 $, $Date: 2006/10/12 21:20:44 $ + */ +public class Constants { + + /** + * Machine epsilon for comparing equality in floating point numbers. + */ + public static final double EPSILON = 1e-6; + + // the location of Fieldspring + public final static String FIELDSPRING_DIR = System.getenv("FIELDSPRING_DIR"); + public final static String FIELDSPRING_DATA = "data"; + + // the location of the OpenNLP models + public final static String OPENNLP_MODELS = Constants.getOpenNLPModelsDir(); + + + public static String getOpenNLPModelsDir() { + String dir = System.getenv("OPENNLP_MODELS"); + if (dir == null) { + dir = System.getProperty("opennlp.models"); + if (dir == null) { + dir = FIELDSPRING_DIR + File.separator + "data/models"; + //dir = System.getProperty("user.dir") + File.separator + "data/models"; + } + } + return dir; + } + + public static String getGazetteersDir() { + String dir = (System.getenv("FIELDSPRING_DATA")!=null)? System.getenv("FIELDSPRING_DATA"):FIELDSPRING_DATA + File.separator + "gazetteers"; + if (dir == null) { + dir = System.getProperty("gazetteers"); + if (dir == null) { + dir = FIELDSPRING_DIR + File.separator + "data/gazetteers"; + //dir = System.getProperty("user.dir") + File.separator + "data/gazetteers"; + } + } + return dir; + } + + // the location of the World Gazetteer database file +// public final static String WGDB_PATH = System.getenv("WGDB_PATH"); + + // the location of the TR-CoNLL database file +// public static final String TRDB_PATH = System.getenv("TRDB_PATH"); + + // the location of the user's home directory + public final static String USER_HOME = System.getProperty("user.home"); + + // the current working directory + public final static String CWD = System.getProperty("user.dir"); + + // The format for printing precision, recall and f-scores. + public static final DecimalFormat PERCENT_FORMAT = + new DecimalFormat("#,##0.00%"); + + +} diff --git a/src/main/java/opennlp/fieldspring/tr/util/CountingLexicon.java b/src/main/java/opennlp/fieldspring/tr/util/CountingLexicon.java new file mode 100644 index 0000000..eade148 --- /dev/null +++ b/src/main/java/opennlp/fieldspring/tr/util/CountingLexicon.java @@ -0,0 +1,25 @@ +/////////////////////////////////////////////////////////////////////////////// +// Copyright (C) 2010 Travis Brown, The University of Texas at Austin +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
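Stepping back to WorldReader.convertNumber above: its padding logic implies that the World Gazetteer coordinate fields carry two implied decimal places (so "3027" means 30.27). A standalone sketch of the same conversion, editorial and not part of the patch:

    public class ImpliedDecimalSketch {
      // Same idea as WorldReader.convertNumber: the last two digits are the
      // fractional part, and a leading '-' marks a negative value.
      static double convert(String number) {
        boolean negative = number.startsWith("-");
        String digits = negative ? number.substring(1) : number;
        double value = Double.parseDouble(digits) / 100.0;
        return negative ? -value : value;
      }

      public static void main(String[] args) {
        System.out.println(convert("3027"));  // 30.27
        System.out.println(convert("-9774")); // -97.74
        System.out.println(convert("5"));     // 0.05
      }
    }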
+/////////////////////////////////////////////////////////////////////////////// +package opennlp.fieldspring.tr.util; + +import java.io.Serializable; +import java.util.List; + +public interface CountingLexicon extends Lexicon, Serializable { + public int count(A entry); + public int countAtIndex(int index); +} + diff --git a/src/main/java/opennlp/fieldspring/tr/util/DoubleStringPair.java b/src/main/java/opennlp/fieldspring/tr/util/DoubleStringPair.java new file mode 100644 index 0000000..77599a3 --- /dev/null +++ b/src/main/java/opennlp/fieldspring/tr/util/DoubleStringPair.java @@ -0,0 +1,59 @@ +/////////////////////////////////////////////////////////////////////////////// +// Copyright (C) 2010 Taesun Moon, The University of Texas at Austin +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +/////////////////////////////////////////////////////////////////////////////// +package opennlp.fieldspring.tr.util; + +/** + * A pair of an double and String + * + * @author Taesun Moon + */ +public final class DoubleStringPair implements Comparable { + /** + * + */ + public double doubleValue; + /** + * + */ + public String stringValue; + + /** + * + * + * @param d + * @param s + */ + public DoubleStringPair (double d, String s) { + doubleValue = d; + stringValue = s; + } + + /** + * sorting order is reversed -- higher (int) values come first + * + * @param p + * @return + */ + public int compareTo (DoubleStringPair p) { + if (doubleValue < p.doubleValue) + return 1; + else if (doubleValue > p.doubleValue) + return -1; + else + return 0; + } + +} diff --git a/src/main/java/opennlp/fieldspring/tr/util/EditMapper.java b/src/main/java/opennlp/fieldspring/tr/util/EditMapper.java new file mode 100644 index 0000000..e4a7b3b --- /dev/null +++ b/src/main/java/opennlp/fieldspring/tr/util/EditMapper.java @@ -0,0 +1,127 @@ +/////////////////////////////////////////////////////////////////////////////// +// Copyright (C) 2010 Travis Brown, The University of Texas at Austin +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+/////////////////////////////////////////////////////////////////////////////// +package opennlp.fieldspring.tr.util; + +import java.util.ArrayList; +import java.util.Collections; +import java.util.Iterator; +import java.util.List; + +import opennlp.fieldspring.tr.util.Span; + +public class EditMapper { + protected enum Operation { + SUB, DEL, INS + } + + private final int cost; + private final List operations; + + public EditMapper(List s, List t) { + int[][] ds = new int[s.size() + 1][t.size() + 1]; + Operation[][] os = new Operation[s.size() + 1][t.size() + 1]; + + for (int i = 0; i <= s.size(); i++) { ds[i][0] = i; os[i][0] = Operation.DEL; } + for (int j = 0; j <= t.size(); j++) { ds[0][j] = j; os[0][j] = Operation.INS; } + + for (int i = 1; i <= s.size(); i++) { + for (int j = 1; j <= t.size(); j++) { + int del = ds[i - 1][j] + delCost(t.get(j - 1)); + int ins = ds[i][j - 1] + insCost(s.get(i - 1)); + int sub = ds[i - 1][j - 1] + subCost(s.get(i - 1), t.get(j - 1)); + + if (sub <= del) { + if (sub <= ins) { + ds[i][j] = sub; + os[i][j] = Operation.SUB; + } else { + ds[i][j] = ins; + os[i][j] = Operation.INS; + } + } else { + if (del <= ins) { + ds[i][j] = del; + os[i][j] = Operation.DEL; + } else { + ds[i][j] = ins; + os[i][j] = Operation.INS; + } + } + } + } + + this.cost = ds[s.size()][t.size()]; + this.operations = new ArrayList(); + + int i = s.size(); + int j = t.size(); + + while (i > 0 || j > 0) { + this.operations.add(os[i][j]); + switch (os[i][j]) { + case SUB: i--; j--; break; + case INS: j--; break; + case DEL: i--; break; + } + } + + Collections.reverse(this.operations); + } + + public int getCost() { + return this.cost; + } + + public List getOperations() { + return this.operations; + } + + protected int delCost(A x) { return 1; } + protected int insCost(A x) { return 1; } + protected int subCost(A x, A y) { return x.equals(y) ? -1 : 1; } + + public Span map(Span span) { + int start = span.getStart(); + int end = span.getEnd(); + + int current = 0; + for (Operation operation : this.operations) { + switch (operation) { + case SUB: + current++; + break; + case DEL: + if (current < span.getStart()) { + start--; + } else if (current < span.getEnd()) { + end--; + } + current++; + break; + case INS: + if (current < span.getStart()) { + start++; + } else if (current < span.getEnd()) { + end++; + } + break; + } + } + + return new Span(start, end, span.getItem()); + } +} + diff --git a/src/main/java/opennlp/fieldspring/tr/util/FastTrig.java b/src/main/java/opennlp/fieldspring/tr/util/FastTrig.java new file mode 100644 index 0000000..590b5c3 --- /dev/null +++ b/src/main/java/opennlp/fieldspring/tr/util/FastTrig.java @@ -0,0 +1,272 @@ +/** + * Copyright (C) 2007 J4ME, 2010 Travis Brown + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package opennlp.fieldspring.tr.util; + +/** + * Provides faster implementations of inverse trigonometric functions. 
+ * Adapted from code developed by J4ME , some + * of which was adapted from a Pascal implementation posted by Everything2 + * user Gorgonzola . + * + * @author Dean Browne + * @author Randy Simon + * @author Michael Ebbage + * @author Travis Brown + */ +public final class FastTrig { + /** + * Constant used in the atan calculation. + */ + private static final double ATAN_CONSTANT = 1.732050807569; + + /** + * Returns the arc cosine of an angle, in the range of 0.0 through Math.PI. + * Special case: + *
+   * <ul>
+   *   <li>If the argument is NaN or its absolute value is greater than 1,
+   *       then the result is NaN.</li>
+   * </ul>
+ * + * @param a - the value whose arc cosine is to be returned. + * @return the arc cosine of the argument. + */ + public static double acos(double a) { + if (Double.isNaN(a) || Math.abs(a) > 1.0) { + return Double.NaN; + } else { + return FastTrig.atan2(Math.sqrt(1.0 - a * a), a); + } + } + + /** + * Returns the arc sine of an angle, in the range of -Math.PI/2 through + * Math.PI/2. Special cases: + *
+   * <ul>
+   *   <li>If the argument is NaN or its absolute value is greater than 1,
+   *       then the result is NaN.</li>
+   *   <li>If the argument is zero, then the result is a zero with the same sign
+   *       as the argument.</li>
+   * </ul>
+ * + * @param a - the value whose arc sine is to be returned. + * @return the arc sine of the argument. + */ + public static double asin(double a) { + if (Double.isNaN(a) || Math.abs(a) > 1.0) { + return Double.NaN; + } else if (a == 0.0) { + return a; + } else { + return FastTrig.atan2(a, Math.sqrt(1.0 - a * a)); + } + } + + /** + * Returns the arc tangent of an angle, in the range of -Math.PI/2 + * through Math.PI/2. Special cases: + *
+   * <ul>
+   *   <li>If the argument is NaN, then the result is NaN.</li>
+   *   <li>If the argument is zero, then the result is a zero with the same
+   *       sign as the argument.</li>
+   * </ul>
+   *
+ * A result must be within 1 ulp of the correctly rounded result. Results + * must be semi-monotonic. + * + * @param a - the value whose arc tangent is to be returned. + * @return the arc tangent of the argument. + */ + public static double atan(double a) { + if (Double.isNaN(a) || Math.abs(a) > 1.0) { + return Double.NaN; + } else if (a == 0.0) { + return a; + } else { + boolean negative = false; + boolean greaterThanOne = false; + int i = 0; + + if (a < 0.0) { + a = -a; + negative = true; + } + + if (a > 1.0) { + a = 1.0 / a; + greaterThanOne = true; + } + + for (double t = 0.0; a > Math.PI / 12.0; a *= t) { + i++; + t = 1.0 / (a + ATAN_CONSTANT); + a *= ATAN_CONSTANT; + a -= 1.0; + } + + double arcTangent = a * (0.55913709 + / (a * a + 1.4087812) + + 0.60310578999999997 + - 0.051604539999999997 * a * a); + + for (; i > 0; i--) { + arcTangent += Math.PI / 6.0; + } + + if (greaterThanOne) { + arcTangent = Math.PI / 2.0 - arcTangent; + } + + if (negative) + { + arcTangent = -arcTangent; + } + + return arcTangent; + } + } + + /** + * Converts rectangular coordinates (x, y) to polar (r, theta). This method + * computes the phase theta by computing an arc tangent of y/x in the range + * of -pi to pi. Special cases: + *

+   * <ul>
+   *   <li>If either argument is NaN, then the result is NaN.</li>
+   *   <li>If the first argument is positive zero and the second argument is
+   *       positive, or the first argument is positive and finite and the second
+   *       argument is positive infinity, then the result is positive zero.</li>
+   *   <li>If the first argument is negative zero and the second argument is
+   *       positive, or the first argument is negative and finite and the second
+   *       argument is positive infinity, then the result is negative zero.</li>
+   *   <li>If the first argument is positive zero and the second argument is
+   *       negative, or the first argument is positive and finite and the second
+   *       argument is negative infinity, then the result is the double value
+   *       closest to pi.</li>
+   *   <li>If the first argument is negative zero and the second argument is
+   *       negative, or the first argument is negative and finite and the second
+   *       argument is negative infinity, then the result is the double value
+   *       closest to -pi.</li>
+   *   <li>If the first argument is positive and the second argument is positive
+   *       zero or negative zero, or the first argument is positive infinity and
+   *       the second argument is finite, then the result is the double value
+   *       closest to pi/2.</li>
+   *   <li>If the first argument is negative and the second argument is positive
+   *       zero or negative zero, or the first argument is negative infinity and
+   *       the second argument is finite, then the result is the double value
+   *       closest to -pi/2.</li>
+   *   <li>If both arguments are positive infinity, then the result is the double
+   *       value closest to pi/4.</li>
+   *   <li>If the first argument is positive infinity and the second argument is
+   *       negative infinity, then the result is the double value closest to 3*pi/4.</li>
+   *   <li>If the first argument is negative infinity and the second argument is
+   *       positive infinity, then the result is the double value closest to -pi/4.</li>
+   *   <li>If both arguments are negative infinity, then the result is the double
+   *       value closest to -3*pi/4.</li>
+   * </ul>
+   *
+ * A result must be within 2 ulps of the correctly rounded result. Results + * must be semi-monotonic. + * + * @param y - the ordinate coordinate + * @param x - the abscissa coordinate + * @return the theta component of the point (r, theta) in polar + * coordinates that corresponds to the point (x, y) in Cartesian coordinates. + */ + public static double atan2(double y, double x) { + if (Double.isNaN(y) || Double.isNaN(x)) { + return Double.NaN; + } else if (Double.isInfinite(y)) { + if (y > 0.0) { + if (Double.isInfinite(x)) { + if (x > 0.0) { + return Math.PI / 4.0; + } else { + return 3.0 * Math.PI / 4.0; + } + } else if (x != 0.0) { + return Math.PI / 2.0; + } + } + else { + if (Double.isInfinite(x)) { + if (x > 0.0) { + return -Math.PI / 4.0; + } else { + return -3.0 * Math.PI / 4.0; + } + } else if (x != 0.0) { + return -Math.PI / 2.0; + } + } + } else if (y == 0.0) { + if (x > 0.0) { + return y; + } else if (x < 0.0) { + return Math.PI; + } + } else if (Double.isInfinite(x)) { + if (x > 0.0) { + if (y > 0.0) { + return 0.0; + } else if (y < 0.0) { + return -0.0; + } + } + else { + if (y > 0.0) { + return Math.PI; + } else if (y < 0.0) { + return -Math.PI; + } + } + } else if (x == 0.0) { + if (y > 0.0) { + return Math.PI / 2.0; + } else if (y < 0.0) { + return -Math.PI / 2.0; + } + } + + // Implementation a simple version ported from a PASCAL implementation at + // . + double arcTangent; + + // Use arctan() avoiding division by zero. + if (Math.abs(x) > Math.abs(y)) { + arcTangent = FastTrig.atan(y / x); + } else { + arcTangent = FastTrig.atan(x / y); // -PI/4 <= a <= PI/4. + + if (arcTangent < 0.0) { + // a is negative, so we're adding. + arcTangent = -Math.PI / 2 - arcTangent; + } else { + arcTangent = Math.PI / 2 - arcTangent; + } + } + + // Adjust result to be from [-PI, PI] + if (x < 0.0) { + if (y < 0.0) { + arcTangent = arcTangent - Math.PI; + } else { + arcTangent = arcTangent + Math.PI; + } + } + + return arcTangent; + } +} + diff --git a/src/main/java/opennlp/fieldspring/tr/util/FastTrig.java~ b/src/main/java/opennlp/fieldspring/tr/util/FastTrig.java~ new file mode 100644 index 0000000..2890ebd --- /dev/null +++ b/src/main/java/opennlp/fieldspring/tr/util/FastTrig.java~ @@ -0,0 +1,272 @@ +/** + * Copyright (C) 2007 J4ME, 2010 Travis Brown + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package opennlp.fieldspring.util; + +/** + * Provides faster implementations of inverse trigonometric functions. + * Adapted from code developed by J4ME , some + * of which was adapted from a Pascal implementation posted by Everything2 + * user Gorgonzola . + * + * @author Dean Browne + * @author Randy Simon + * @author Michael Ebbage + * @author Travis Brown + */ +public final class FastTrig { + /** + * Constant used in the atan calculation. + */ + private static final double ATAN_CONSTANT = 1.732050807569; + + /** + * Returns the arc cosine of an angle, in the range of 0.0 through Math.PI. + * Special case: + *

+   * <ul>
+   *   <li>If the argument is NaN or its absolute value is greater than 1,
+   *       then the result is NaN.</li>
+   * </ul>
+ * + * @param a - the value whose arc cosine is to be returned. + * @return the arc cosine of the argument. + */ + public static double acos(double a) { + if (Double.isNaN(a) || Math.abs(a) > 1.0) { + return Double.NaN; + } else { + return FastTrig.atan2(Math.sqrt(1.0 - a * a), a); + } + } + + /** + * Returns the arc sine of an angle, in the range of -Math.PI/2 through + * Math.PI/2. Special cases: + *
+   * <ul>
+   *   <li>If the argument is NaN or its absolute value is greater than 1,
+   *       then the result is NaN.</li>
+   *   <li>If the argument is zero, then the result is a zero with the same sign
+   *       as the argument.</li>
+   * </ul>
+ * + * @param a - the value whose arc sine is to be returned. + * @return the arc sine of the argument. + */ + public static double asin(double a) { + if (Double.isNaN(a) || Math.abs(a) > 1.0) { + return Double.NaN; + } else if (a == 0.0) { + return a; + } else { + return FastTrig.atan2(a, Math.sqrt(1.0 - a * a)); + } + } + + /** + * Returns the arc tangent of an angle, in the range of -Math.PI/2 + * through Math.PI/2. Special cases: + *
+   * <ul>
+   *   <li>If the argument is NaN, then the result is NaN.</li>
+   *   <li>If the argument is zero, then the result is a zero with the same
+   *       sign as the argument.</li>
+   * </ul>
+   *
+ * A result must be within 1 ulp of the correctly rounded result. Results + * must be semi-monotonic. + * + * @param a - the value whose arc tangent is to be returned. + * @return the arc tangent of the argument. + */ + public static double atan(double a) { + if (Double.isNaN(a) || Math.abs(a) > 1.0) { + return Double.NaN; + } else if (a == 0.0) { + return a; + } else { + boolean negative = false; + boolean greaterThanOne = false; + int i = 0; + + if (a < 0.0) { + a = -a; + negative = true; + } + + if (a > 1.0) { + a = 1.0 / a; + greaterThanOne = true; + } + + for (double t = 0.0; a > Math.PI / 12.0; a *= t) { + i++; + t = 1.0 / (a + ATAN_CONSTANT); + a *= ATAN_CONSTANT; + a -= 1.0; + } + + double arcTangent = a * (0.55913709 + / (a * a + 1.4087812) + + 0.60310578999999997 + - 0.051604539999999997 * a * a); + + for (; i > 0; i--) { + arcTangent += Math.PI / 6.0; + } + + if (greaterThanOne) { + arcTangent = Math.PI / 2.0 - arcTangent; + } + + if (negative) + { + arcTangent = -arcTangent; + } + + return arcTangent; + } + } + + /** + * Converts rectangular coordinates (x, y) to polar (r, theta). This method + * computes the phase theta by computing an arc tangent of y/x in the range + * of -pi to pi. Special cases: + *

+   * <ul>
+   *   <li>If either argument is NaN, then the result is NaN.</li>
+   *   <li>If the first argument is positive zero and the second argument is
+   *       positive, or the first argument is positive and finite and the second
+   *       argument is positive infinity, then the result is positive zero.</li>
+   *   <li>If the first argument is negative zero and the second argument is
+   *       positive, or the first argument is negative and finite and the second
+   *       argument is positive infinity, then the result is negative zero.</li>
+   *   <li>If the first argument is positive zero and the second argument is
+   *       negative, or the first argument is positive and finite and the second
+   *       argument is negative infinity, then the result is the double value
+   *       closest to pi.</li>
+   *   <li>If the first argument is negative zero and the second argument is
+   *       negative, or the first argument is negative and finite and the second
+   *       argument is negative infinity, then the result is the double value
+   *       closest to -pi.</li>
+   *   <li>If the first argument is positive and the second argument is positive
+   *       zero or negative zero, or the first argument is positive infinity and
+   *       the second argument is finite, then the result is the double value
+   *       closest to pi/2.</li>
+   *   <li>If the first argument is negative and the second argument is positive
+   *       zero or negative zero, or the first argument is negative infinity and
+   *       the second argument is finite, then the result is the double value
+   *       closest to -pi/2.</li>
+   *   <li>If both arguments are positive infinity, then the result is the double
+   *       value closest to pi/4.</li>
+   *   <li>If the first argument is positive infinity and the second argument is
+   *       negative infinity, then the result is the double value closest to 3*pi/4.</li>
+   *   <li>If the first argument is negative infinity and the second argument is
+   *       positive infinity, then the result is the double value closest to -pi/4.</li>
+   *   <li>If both arguments are negative infinity, then the result is the double
+   *       value closest to -3*pi/4.</li>
+   * </ul>

+     * Joins the elements of the provided array into a single String
+     * containing the provided list of elements.
+     *
+     * No delimiter is added before or after the list.
+     * A null separator is the same as an empty String ("").
+     * Null objects or empty strings within the array are represented by
+     * empty strings.
+     *
+     * StringUtils.join(null, *)                = null
+     * StringUtils.join([], *)                  = ""
+     * StringUtils.join([null], *)              = ""
+     * StringUtils.join(["a", "b", "c"], "--")  = "a--b--c"
+     * StringUtils.join(["a", "b", "c"], null)  = "abc"
+     * StringUtils.join(["a", "b", "c"], "")    = "abc"
+     * StringUtils.join([null, "", "a"], ',')   = ",,a"
+     * 
+ * + * @param array the array of values to join together, may be null + * @param separator the separator character to use, null treated as "" + * @param startIndex the first index to start joining from. It is + * an error to pass in an end index past the end of the array + * @param endIndex the index to stop joining from (exclusive). It is + * an error to pass in an end index past the end of the array + * @return the joined String, null if null array input + */ + public static String join(Object[] array, String separator, int startIndex, + int endIndex) { + if (array == null) { + return null; + } + if (separator == null) { + separator = ""; + } + + // endIndex - startIndex > 0: Len = NofStrings *(len(firstString) + len(separator)) + // (Assuming that all Strings are roughly equally long) + int bufSize = (endIndex - startIndex); + if (bufSize <= 0) { + return ""; + } + + bufSize *= ((array[startIndex] == null ? 16 + : array[startIndex].toString().length()) + + separator.length()); + + StringBuffer buf = new StringBuffer(bufSize); + + for (int i = startIndex; i < endIndex; i++) { + if (i > startIndex) { + buf.append(separator); + } + if (array[i] != null) { + buf.append(array[i]); + } + } + return buf.toString(); + } + + /** + *

+     * Joins the elements of the provided array into a single String
+     * containing the provided list of elements.
+     *
+     * No delimiter is added before or after the list.
+     * A null separator is the same as an empty String ("").
+     * Null objects or empty strings within the array are represented by
+     * empty strings.
+     *
+     * StringUtils.join(null, *)                = null
+     * StringUtils.join([], *)                  = ""
+     * StringUtils.join([null], *)              = ""
+     * StringUtils.join(["a", "b", "c"], "--")  = "a--b--c"
+     * StringUtils.join(["a", "b", "c"], null)  = "abc"
+     * StringUtils.join(["a", "b", "c"], "")    = "abc"
+     * StringUtils.join([null, "", "a"], ',')   = ",,a"
+     * 
+ * + * @param array the array of values to join together, may be null + * @param separator the separator character to use, null treated as "" + * @param startIndex the first index to start joining from. It is + * an error to pass in an end index past the end of the array + * @param endIndex the index to stop joining from (exclusive). It is + * an error to pass in an end index past the end of the array + * @param internalSep separator for each element in array. each element + * will be split with the separator and the first element will be + * added to the buffer + * @return the joined String, null if null array input + */ + public static String join(Object[] array, String separator, int startIndex, + int endIndex, String internalSep) { + if (array == null) { + return null; + } + if (separator == null) { + separator = ""; + } + + // endIndex - startIndex > 0: Len = NofStrings *(len(firstString) + len(separator)) + // (Assuming that all Strings are roughly equally long) + int bufSize = (endIndex - startIndex); + if (bufSize <= 0) { + return ""; + } + + bufSize *= ((array[startIndex] == null ? 16 + : array[startIndex].toString().length()) + + separator.length()); + + StringBuffer buf = new StringBuffer(bufSize); + + for (int i = startIndex; i < endIndex; i++) { + if (i > startIndex) { + buf.append(separator); + } + if (array[i] != null) { + String token = array[i].toString().split(internalSep)[0]; + buf.append(token); + } + } + return buf.toString(); + } +} diff --git a/src/main/java/opennlp/fieldspring/tr/util/TopoUtil.java b/src/main/java/opennlp/fieldspring/tr/util/TopoUtil.java new file mode 100644 index 0000000..dd4c3c8 --- /dev/null +++ b/src/main/java/opennlp/fieldspring/tr/util/TopoUtil.java @@ -0,0 +1,196 @@ +package opennlp.fieldspring.tr.util; + +import opennlp.fieldspring.tr.text.*; +import opennlp.fieldspring.tr.topo.*; +import java.util.*; +import opennlp.fieldspring.tr.resolver.*; +import opennlp.fieldspring.tr.text.io.*; +import opennlp.fieldspring.tr.text.prep.*; +import opennlp.fieldspring.tr.topo.gaz.*; +import opennlp.fieldspring.tr.eval.*; +import opennlp.fieldspring.tr.util.*; +import java.io.*; +import java.util.zip.*; + +public class TopoUtil { + + public static Lexicon buildLexicon(StoredCorpus corpus) { + Lexicon lexicon = new SimpleLexicon(); + + addToponymsToLexicon(lexicon, corpus); + + return lexicon; + } + + public static void addToponymsToLexicon(Lexicon lexicon, StoredCorpus corpus) { + for(Document doc : corpus) { + for(Sentence sent : doc) { + for(Toponym toponym : sent.getToponyms()) { + if(toponym.getAmbiguity() > 0) { + lexicon.getOrAdd(toponym.getForm()); + } + } + } + } + } + + public static void buildLexicons(StoredCorpus corpus, Lexicon lexicon, HashMap reverseLexicon) { + for(Document doc : corpus) { + for(Sentence sent : doc) { + for(Toponym toponym : sent.getToponyms()) { + if(toponym.getAmbiguity() > 0) { + int idx = lexicon.getOrAdd(toponym.getForm()); + reverseLexicon.put(idx, toponym.getForm()); + } + } + } + } + } + + public static Set getCellNumbers(Location location, double dpc) { + + Set cellNumbers = new HashSet(); + + for(Coordinate coord : location.getRegion().getRepresentatives()) { + + /*int x = (int) ((coord.getLng() + 180.0) / dpc); + int y = (int) ((coord.getLat() + 90.0) / dpc);*/ + + cellNumbers.add(getCellNumber(coord.getLatDegrees(), coord.getLngDegrees(), dpc)/*x * 1000 + y*/); + } + + return cellNumbers; + } + + public static int getCellNumber(double lat, double lon, double dpc) { + + int x = (int) ((lon + 180.0) / dpc); + 
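+        // x and y index the grid cell (dpc degrees per cell) containing the point;
+        // the checks below reject out-of-range latitudes, wrap longitude around
+        // the antimeridian, and pack the pair into a single int as x * 1000 + y.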
int y = (int) ((lat + 90.0) / dpc); + + if(y < 0 || y >= 180/dpc) return -1; + if(x < 0) x += 360/dpc; + if(x > 360/dpc) x -= 360/dpc; + + return x * 1000 + y; + + // Make everything positive: + /*lat += 90/dpc; + lon += 180/dpc; + + if(lat < 0 || lat >= 180/dpc) return -1; + if(lon < 0) lon += 360/dpc; // wrap lon around + if(lon >= 360/dpc) lon -= 360/dpc; // wrap lon around + + return (int)((int)(lat/dpc) * 1000 + (lon/dpc));*/ + } + + public static int getCellNumber(Coordinate coord, double dpc) { + return getCellNumber(coord.getLatDegrees(), coord.getLngDegrees(), dpc); + } + + public static Coordinate getCellCenter(int cellNumber, double dpc) { + int x = cellNumber / 1000; + int y = cellNumber % 1000; + + double lat = (dpc * y - 90) + dpc/2.0; + double lon = (dpc * x - 180) + dpc/2.0; + + if(lat >= 90) lat -= 90; + if(lon >= 180) lon -= 180; + + return Coordinate.fromDegrees(lat, lon); + } + + public static int getCorrectCandidateIndex(Toponym toponym, Map cellDistribution, double dpc) { + double maxMass = Double.NEGATIVE_INFINITY; + int maxIndex = -1; + int index = 0; + for(Location location : toponym.getCandidates()) { + double totalMass = 0.0; + for(int cellNumber : getCellNumbers(location, dpc)) { + //if(regionDistribution == null) + // System.err.println("regionDistribution is null!"); + + Double mass = cellDistribution.get(cellNumber); + //if(mass == null) + // System.err.println("mass null for regionNumber " + regionNumber); + if(mass != null) + totalMass += mass; + } + if(totalMass > maxMass) { + maxMass = totalMass; + maxIndex = index; + } + index++; + } + + return maxIndex; + } + + public static int getCorrectCandidateIndex(Toponym toponym, int cellNumber, double dpc) { + if(cellNumber == -1) System.out.println("-1"); + int index = 0; + for(Location location : toponym.getCandidates()) { + if(getCellNumbers(location, dpc).contains(cellNumber)) + return index; + index++; + } + + return -1; + } + + public static Corpus readCorpusFromSerialized(String serializedCorpusInputPath) throws Exception { + + Corpus corpus; + ObjectInputStream ois = null; + if(serializedCorpusInputPath.toLowerCase().endsWith(".gz")) { + GZIPInputStream gis = new GZIPInputStream(new FileInputStream(serializedCorpusInputPath)); + ois = new ObjectInputStream(gis); + } + else { + FileInputStream fis = new FileInputStream(serializedCorpusInputPath); + ois = new ObjectInputStream(fis); + } + corpus = (StoredCorpus) ois.readObject(); + + return corpus; + } + + + public static StoredCorpus readStoredCorpusFromSerialized(String serializedCorpusInputPath) throws Exception { + + StoredCorpus corpus; + ObjectInputStream ois = null; + if(serializedCorpusInputPath.toLowerCase().endsWith(".gz")) { + GZIPInputStream gis = new GZIPInputStream(new FileInputStream(serializedCorpusInputPath)); + ois = new ObjectInputStream(gis); + } + else { + FileInputStream fis = new FileInputStream(serializedCorpusInputPath); + ois = new ObjectInputStream(fis); + } + corpus = (StoredCorpus) ois.readObject(); + + return corpus; + } + + public static List filter(List locs, Region boundingBox) { + if(boundingBox == null || locs == null) return locs; + + List toReturn = new ArrayList(); + + for(Location loc : locs) { + boolean containsAllPoints = true; + for(Coordinate coord : loc.getRegion().getRepresentatives()) { + if(!boundingBox.contains(coord)) { + containsAllPoints = false; + break; + } + } + if(containsAllPoints) + toReturn.add(loc); + } + + return toReturn; + } +} diff --git 
a/src/main/java/opennlp/fieldspring/tr/util/ToponymFinder.java b/src/main/java/opennlp/fieldspring/tr/util/ToponymFinder.java new file mode 100644 index 0000000..4354c6d --- /dev/null +++ b/src/main/java/opennlp/fieldspring/tr/util/ToponymFinder.java @@ -0,0 +1,86 @@ +/** + * + */ +package opennlp.fieldspring.tr.util; + +import java.io.BufferedReader; +import java.io.FileInputStream; +import java.io.FileReader; +import java.io.IOException; +import java.io.InputStreamReader; +import java.util.ArrayList; +import java.util.HashSet; +import java.util.List; +import java.util.zip.ZipEntry; +import java.util.zip.ZipFile; +import java.util.zip.ZipInputStream; + +import opennlp.fieldspring.tr.text.prep.HighRecallToponymRecognizer; +import opennlp.fieldspring.tr.text.prep.NamedEntityRecognizer; +import opennlp.fieldspring.tr.text.prep.NamedEntityType; +import opennlp.fieldspring.tr.text.prep.OpenNLPRecognizer; +import opennlp.fieldspring.tr.text.prep.OpenNLPSentenceDivider; +import opennlp.fieldspring.tr.text.prep.OpenNLPTokenizer; +import opennlp.fieldspring.tr.text.prep.SentenceDivider; +import opennlp.fieldspring.tr.text.prep.Tokenizer; +import opennlp.fieldspring.tr.topo.gaz.GeoNamesGazetteer; +import opennlp.fieldspring.tr.util.Span; +import opennlp.tools.util.InvalidFormatException; + +/** + * @author abhimanu kumar + * + */ +public class ToponymFinder { + + /** + * @param args + */ + private final SentenceDivider sentDivider; + private final Tokenizer tokenizer; + private final NamedEntityRecognizer recognizer; + private BufferedReader input; + + public ToponymFinder(BufferedReader reader, String gazPath) throws Exception{ + sentDivider = new OpenNLPSentenceDivider(); + tokenizer = new OpenNLPTokenizer(); + recognizer = new HighRecallToponymRecognizer(gazPath); + this.input = reader; + } + + + public static void main(String[] args) throws Exception { + ToponymFinder finder = new ToponymFinder(new BufferedReader(new FileReader(args[0]/*"TheStoryTemp.txt"*/)),args[1]/*"data/gazetteers/US.ser.gz"*/); +// long startTime = System.currentTimeMillis(); + finder.find(); +// long stopTime = System.currentTimeMillis(); +// System.out.println((stopTime-startTime)/1000 + "secs"); + } + + + private HashSet find() throws IOException { + String line; + HashSet resultSet = new HashSet(); + while((line=input.readLine())!=null){ + List sentencesString = sentDivider.divide(line); + for (String sentence : sentencesString){ + List tokens = new ArrayList(); + for(String token : tokenizer.tokenize(sentence)){ + tokens.add(token); + } + List> spans =recognizer.recognize(tokens); + for(Span span:spans){ + StringBuilder resultToken= new StringBuilder(); + for(int i=span.getStart();i toReturnAL = new ArrayList(); + + if(hasCorpusElement) { + + NodeList corpusDocs = doc.getChildNodes().item(0).getChildNodes(); + + for(int d = 0; d < corpusDocs.getLength(); d++) { + if(!corpusDocs.item(d).getNodeName().equals("doc")) + continue; + + //System.out.println(doc.getChildNodes().getLength()); + + NodeList sentences = corpusDocs.item(d).getChildNodes(); + + for(int i = 0; i < sentences.getLength(); i++) { + if(!sentences.item(i).getNodeName().equals("s")) + continue; + NodeList tokens = sentences.item(i).getChildNodes(); + for(int j = 0; j < tokens.getLength(); j++) { + Node tokenNode = tokens.item(j); + if(tokenNode.getNodeName().equals("toponym")) { + toReturnAL.add(tokenNode.getAttributes().getNamedItem("term").getNodeValue()); + } + else if(tokenNode.getNodeName().equals("w")) { + 
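+                        // "w" elements are ordinary word tokens whose surface form is stored
+                        // in the "tok" attribute (toponym elements, handled above, use "term")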
toReturnAL.add(tokenNode.getAttributes().getNamedItem("tok").getNodeValue()); + } + + } + } + } + } + else { + NodeList sentences = doc.getChildNodes().item(1).getChildNodes(); + + for(int i = 0; i < sentences.getLength(); i++) { + if(!sentences.item(i).getNodeName().equals("s")) + continue; + NodeList tokens = sentences.item(i).getChildNodes(); + for(int j = 0; j < tokens.getLength(); j++) { + Node tokenNode = tokens.item(j); + if(tokenNode.getNodeName().equals("toponym")) { + toReturnAL.add(tokenNode.getAttributes().getNamedItem("term").getNodeValue()); + } + else if(tokenNode.getNodeName().equals("w")) { + toReturnAL.add(tokenNode.getAttributes().getNamedItem("tok").getNodeValue()); + } + + } + } + } + + return toReturnAL.toArray(new String[0]); + } + + public static String[] getAllTokens(Document doc) { + return getAllTokens(doc, false); + /*ArrayList toReturnAL = new ArrayList(); + + NodeList sentences = doc.getChildNodes().item(1).getChildNodes(); + + for(int i = 0; i < sentences.getLength(); i++) { + if(!sentences.item(i).getNodeName().equals("s")) + continue; + NodeList tokens = sentences.item(i).getChildNodes(); + for(int j = 0; j < tokens.getLength(); j++) { + Node tokenNode = tokens.item(j); + if(tokenNode.getNodeName().equals("toponym")) { + toReturnAL.add(tokenNode.getAttributes().getNamedItem("term").getNodeValue()); + } + else if(tokenNode.getNodeName().equals("w")) { + toReturnAL.add(tokenNode.getAttributes().getNamedItem("tok").getNodeValue()); + } + + } + } + + return toReturnAL.toArray(new String[0]); + */ + } + + public static String[] getContextWindow(String[] a, int index, int windowSize) { + ArrayList toReturnAL = new ArrayList(); + + int begin = Math.max(0, index - windowSize); + int end = Math.min(a.length, index + windowSize + 1); + + for(int i = begin; i < end; i++) { + if(i == index) + toReturnAL.add("[h]" + a[i] + "[/h]"); + else + toReturnAL.add(a[i]); + } + + return toReturnAL.toArray(new String[0]); + } +} diff --git a/src/main/python/article_statistics.py b/src/main/python/article_statistics.py new file mode 100755 index 0000000..154d664 --- /dev/null +++ b/src/main/python/article_statistics.py @@ -0,0 +1,148 @@ +#!/usr/bin/env python + +####### +####### article_statistics.py +####### +####### Copyright (c) 2010 Ben Wing. +####### + +# Counts and outputs various statistics about the articles in the +# article-data file. 
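The getContextWindow helper just above keeps a window of windowSize tokens on either side of position index, truncated at the array boundaries, and wraps the head token in [h]...[/h] markers. A rough Python sketch of the same convention, for illustration only (not part of this patch; the function name and sample tokens are invented here):

    def context_window(tokens, index, window_size):
        # Keep tokens from index - window_size through index + window_size,
        # clipped to the list bounds, and mark the head token with [h]...[/h].
        begin = max(0, index - window_size)
        end = min(len(tokens), index + window_size + 1)
        return [("[h]" + tokens[i] + "[/h]" if i == index else tokens[i])
                for i in range(begin, end)]

    # context_window(["I", "flew", "to", "Paris", "in", "May"], 3, 2)
    # => ["flew", "to", "[h]Paris[/h]", "in", "May"]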
+ +from nlputil import * +import process_article_data as pad + +######################## Statistics about articles + +class ArticleStatistics(object): + def __init__(self): + self.total_num = 0 + self.num_by_split = intdict() + self.num_redir = 0 + self.num_by_namespace = intdict() + self.num_list_of = 0 + self.num_disambig = 0 + self.num_list = 0 + + def record_article(self, art): + self.total_num += 1 + self.num_by_split[art.split] += 1 + self.num_by_namespace[art.namespace] += 1 + if art.redir: + self.num_redir += 1 + if art.is_list_of: + self.num_list_of += 1 + if art.is_disambig: + self.num_disambig += 1 + if art.is_list: + self.num_list += 1 + + def output_stats(self, stats_type, outfile=sys.stdout): + def outprint(foo): + uniprint(foo, outfile=outfile) + outprint("Article statistics about %s" % stats_type) + outprint(" Total number = %s" % self.total_num) + outprint(" Number of redirects = %s" % self.num_redir) + outprint(" Number of 'List of' articles = %s" % self.num_list_of) + outprint(" Number of disambiguation pages = %s" % self.num_disambig) + outprint(" Number of list articles = %s" % self.num_list) + outprint(" Statistics by split:") + output_reverse_sorted_table(self.num_by_split, indent=" ") + outprint(" Statistics by namespace:") + output_reverse_sorted_table(self.num_by_namespace, indent=" ") + +class ArticleStatisticsSet(object): + def __init__(self, set_name): + self.set_name = set_name + self.stats_all = ArticleStatistics() + self.stats_redir = ArticleStatistics() + self.stats_non_redir = ArticleStatistics() + self.stats_list_of = ArticleStatistics() + self.stats_disambig = ArticleStatistics() + self.stats_list = ArticleStatistics() + self.stats_non_list = ArticleStatistics() + self.stats_by_split = {} + self.stats_by_namespace = {} + + def record_article(self, art): + def record_by_value(art, table, value): + if value not in table: + table[value] = ArticleStatistics() + table[value].record_article(art) + + self.stats_all.record_article(art) + if art.redir: + self.stats_redir.record_article(art) + else: + self.stats_non_redir.record_article(art) + if art.is_list_of: + self.stats_list_of.record_article(art) + if art.is_disambig: + self.stats_disambig.record_article(art) + if art.is_list: + self.stats_list.record_article(art) + else: + self.stats_non_list.record_article(art) + record_by_value(art, self.stats_by_split, art.split) + record_by_value(art, self.stats_by_namespace, art.namespace) + + def output_stats(self): + def out(stats, stats_type): + stats.output_stats("%s (%s)" % (stats_type, self.set_name)) + out(self.stats_all, "all articles") + out(self.stats_redir, "redirect articles") + out(self.stats_non_redir, "non-redirect articles") + out(self.stats_list_of, "'List of' articles") + out(self.stats_disambig, "disambiguation pages") + out(self.stats_list, "list articles") + for (split, stat) in self.stats_by_split.iteritems(): + out(stat, "articles in %s split" % split) + for (namespace, stat) in self.stats_by_namespace.iteritems(): + out(stat, "articles in namespace %s" % namespace) + uniprint("") + +Stats_set_all = ArticleStatisticsSet("all articles") +Stats_set_non_redir = ArticleStatisticsSet("non-redirect articles") +Stats_set_main_non_redir = ArticleStatisticsSet("non-redirect articles, namespace Main") +Stats_set_main_non_redir_non_list = ArticleStatisticsSet("non-redirect articles, namespace Main, non-list articles") + +def note_article_for_global_stats(art): + Stats_set_all.record_article(art) + if not art.redir: + Stats_set_non_redir.record_article(art) 
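+    # Within the non-redirects, also track the progressively narrower subsets:
+    # Main-namespace articles, and Main-namespace articles that are not lists.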
+ if art.namespace == 'Main': + Stats_set_main_non_redir.record_article(art) + if not art.is_list: + Stats_set_main_non_redir_non_list.record_article(art) + +def output_global_stats(): + Stats_set_all.output_stats() + Stats_set_non_redir.output_stats() + Stats_set_main_non_redir.output_stats() + Stats_set_main_non_redir_non_list.output_stats() + +def generate_article_stats(filename): + def process(art): + note_article_for_global_stats(art) + pad.read_article_data_file(filename, process=process, + maxtime=Opts.max_time_per_stage) + output_global_stats() + +############################################################################ +# Main code # +############################################################################ + +class ArticleStatisticsProgram(NLPProgram): + def argument_usage(self): + return "article-data-file" + + def handle_arguments(self, opts, op, args): + global Opts + Opts = opts + if len(args) != 1: + op.error("Must specify exactly one article-data file as an argument") + + def implement_main(self, opts, params, args): + generate_article_stats(args[0]) + +ArticleStatisticsProgram() diff --git a/src/main/python/convert-infochimps.py b/src/main/python/convert-infochimps.py new file mode 100644 index 0000000..8945878 --- /dev/null +++ b/src/main/python/convert-infochimps.py @@ -0,0 +1,37 @@ +#!/usr/bin/env python + +""" + +Steps for converting Infochimps to our format: + +1) Input is a series of files, e.g. part-00000.gz, each about 180 MB. +2) Each line looks like this: + + +100000018081132545 20110807002716 25430513 GTheHardWay Niggas Lost in the Sauce ..smh better slow yo roll and tell them hoes to get a job nigga #MRIloveRatsIcanchange&amp;saveherassNIGGA <a href="http://twitter.com/download/android" rel="nofollow">Twitter for Android</a> en 42.330165 -83.045913 +The fields are: + +1) Tweet ID +2) Time +3) User ID +4) User name +5) Empty? +6) User name being replied to (FIXME: which JSON field is this?) +7) User ID for replied-to user name (but sometimes different ID's for same user name) +8) Empty? +9) Tweet text -- double HTML-encoded (e.g. & becomes &amp;) +10) HTML anchor text indicating a link of some sort, HTML-encoded (FIXME: which JSON field is this?) +11) Language, as a two-letter code +12) Latitude +13) Longitude +14) Empty? +15) Empty? +16) Empty? +17) Empty? + + +3) We want to convert each to two files: (1) containing the article-data + +""" + + diff --git a/src/main/python/convert_to_new_article_format.py b/src/main/python/convert_to_new_article_format.py new file mode 100755 index 0000000..35c8d4f --- /dev/null +++ b/src/main/python/convert_to_new_article_format.py @@ -0,0 +1,32 @@ +#!/usr/bin/env python + +####### +####### convert_to_new_article_format.py +####### +####### Copyright (c) 2011 Ben Wing. +####### + +# A one-off program to convert article-data files to the new order, which +# puts the most important article fields (id, name, split, coords) before +# other fields that may be specific to the type of article (e.g. Wikipedia +# article, Twitter feed, tweet, etc.). + +import sys +from nlputil import * +from process_article_data import * + +def output_combined_article_data(filename): + arts_seen = [] + def note_article(art): + arts_seen.append(art) + # Note that the article data file indicates the field names at the + # beginning. 
+ read_article_data_file(filename, note_article) + errprint("Writing combined data to stdout ...") + write_article_data_file(sys.stdout, + outfields = combined_article_data_outfields, + articles = arts_seen) + errprint("Done.") + +for filename in sys.argv[1:]: + output_combined_article_data(filename) diff --git a/src/main/python/find-first-tweet-time.py b/src/main/python/find-first-tweet-time.py new file mode 100755 index 0000000..5cce3b3 --- /dev/null +++ b/src/main/python/find-first-tweet-time.py @@ -0,0 +1,129 @@ +#!/usr/bin/env python + +####### +####### find-first-tweet-time.py +####### +####### Copyright (c) 2012 Ben Wing. +####### + +import sys, re +import math +import fileinput +from subprocess import * +from nlputil import * +import itertools +import time +import os.path +import traceback +import calendar + +############################################################################ +# Quick Start # +############################################################################ + +# This program reads in data from the specified (possibly bzipped) files, +# outputting the time of the first tweet in each as a time in milliseconds +# since the Unix Epoch. (Redirect stdout but not stderr to get these +# values independent of status messages.) + +####################################################################### +# Process files # +####################################################################### + +def process_file(infile): + inproc = None + desc = None + try: + if infile.endswith(".bz2"): + errprint("Opening compressed input %s..." % infile) + # close_fds is necessary to avoid deadlock in certain circumstances + # (see split_bzip.py). Probably not here but won't hurt. + inproc = Popen("bzcat", stdin=open(infile, "rb"), stdout=PIPE, close_fds=True) + desc = inproc.stdout + else: + errprint("Opening input %s..." % infile) + desc = open(infile, "rb") + lineno = 0 + for full_line in desc: + lineno += 1 + line = full_line[:-1] + if not line.startswith('{"'): + errprint("%s: Unparsable line, not JSON?, #%s: %s" % (infile, lineno, line)) + else: + json = None + try: + json = split_json(line) + except Exception, exc: + errprint("%s: Exception parsing JSON in line #%s: %s" % (infile, lineno, line)) + errprint("Exception is %s" % exc) + traceback.print_exc() + if json: + json = json[0] + #errprint("Processing JSON %s:" % json) + #errprint("Length: %s" % len(json)) + for i in xrange(len(json)): + #errprint("Saw %s=%s" % (i, json[i])) + if json[i] == '"created_at"': + #errprint("Saw created") + if i + 2 >= len(json) or json[i+1] != ':' or json[i+2][0] != '"' or json[i+2][-1] != '"': + errprint("%s: Something weird with JSON in line #%s, around here: %s" % (infile, lineno, json[i-1:i+4])) + else: + json_time = json[i+2][1:-1].replace(" +0000 ", " UTC ") + tweet_time = time.strptime(json_time, + "%a %b %d %H:%M:%S %Z %Y") + if not tweet_time: + errprint("%s: Can't parse time in line #%s: %s" % (infile, lineno, json_time)) + else: + print "%s\t%s" % (infile, calendar.timegm(tweet_time)*1000L) + return + finally: + if inproc: + inproc.kill() + inproc.wait() + if desc: desc.close() + +# A very simple JSON splitter. Doesn't take the next step of assembling +# into dictionaries, but easily could. +# +# FIXME: This is totally unnecessary, as Python has a built-in JSON parser. +# (I didn't realize this when I wrote the function.) 
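As the comment above notes, the hand-rolled splitter is redundant given the standard library; a minimal sketch of the json-module equivalent for pulling out the creation time (illustrative only, not part of this patch, and it assumes the same "%a %b %d %H:%M:%S %Z %Y" timestamp layout that process_file() above already expects):

    import json, time, calendar

    def created_at_millis(line):
        # Parse one tweet's JSON object and return its creation time in
        # milliseconds since the Unix epoch, as process_file() does by hand.
        tweet = json.loads(line)
        created = tweet["created_at"].replace(" +0000 ", " UTC ")
        parsed = time.strptime(created, "%a %b %d %H:%M:%S %Z %Y")
        return calendar.timegm(parsed) * 1000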
+def split_json(line): + split = re.split(r'("(?:\\.|[^"])*?"|[][:{},])', line) + split = (x for x in split if x) # Filter out empty strings + curind = 0 + def get_nested(endnest): + nest = [] + try: + while True: + item = next(split) + if item == endnest: + return nest + elif item == '{': + nest += [get_nested('}')] + elif item == '[': + nest += [get_nested(']')] + else: + nest += [item] + except StopIteration: + if not endnest: + return nest + else: + raise + return get_nested(None) + +####################################################################### +# Main code # +####################################################################### + +def main(): + op = OptionParser(usage="%prog [options] FILES ...") + + opts, args = op.parse_args() + + if not args: + op.error("No input files specified") + + for infile in args: + process_file(infile) + +main() diff --git a/src/main/python/fix_redirects.py b/src/main/python/fix_redirects.py new file mode 100755 index 0000000..c0e859b --- /dev/null +++ b/src/main/python/fix_redirects.py @@ -0,0 +1,48 @@ +#!/usr/bin/env python + +####### +####### fix_redirects.py +####### +####### Copyright (c) 2010 Ben Wing. +####### + +# This is a one-off program to fix the redirect fields so that they +# always begin with a capital letter, as is required of article titles in +# Wikipedia. + +import sys +from nlputil import * +from process_article_data import * + +def fix_redirects(filename): + articles_seen = [] + def process(art): + if art.redir: + art.redir = capfirst(art.redir) + articles_seen.append(art) + errprint("Reading from %s..." % filename) + fields = read_article_data_file(filename, process, + maxtime=Opts.max_time_per_stage) + errprint("Writing to stdout...") + write_article_data_file(sys.stdout, outfields=fields, + articles=articles_seen) + errprint("Done.") + +############################################################################ +# Main code # +############################################################################ + +class FixRedirectsProgram(NLPProgram): + def argument_usage(self): + return "article-data-file" + + def handle_arguments(self, opts, op, args): + global Opts + Opts = opts + if len(args) != 1: + op.error("Must specify exactly one article-data file as an argument") + + def implement_main(self, opts, params, args): + fix_redirects(args[0]) + +FixRedirectsProgram() diff --git a/src/main/python/format-thresh-grid.py b/src/main/python/format-thresh-grid.py new file mode 100755 index 0000000..56c1680 --- /dev/null +++ b/src/main/python/format-thresh-grid.py @@ -0,0 +1,34 @@ +#!/usr/bin/env python + +from nlputil import * +import fileinput +import re + +switch_thresh_and_grid = True + +errdists = dictdict() + +for line in fileinput.input(): + line = line.strip() + m = re.match(r'.*thresh: ([0-9.]*), grid: ([0-9.]*),.*true error distance.*\(([0-9.]*) km.*', line) + if not m: + errprint("Can't parse line: %s", line) + else: + thresh = float(m.group(1)) + grid = float(m.group(2)) + dist = float(m.group(3)) + if switch_thresh_and_grid: + errdists[grid][thresh] = dist + else: + errdists[thresh][grid] = dist + +first = True +for (thresh, dic) in key_sorted_items(errdists): + if first: + first = False + errprint(r" & %s \\" % ( + ' & '.join(["%g" % grid for grid in sorted(dic.keys())]))) + errprint(r"\hline") + errprint(r"%g & %s \\" % (thresh, + ' & '.join(["%g" % dist for (grid, dist) in key_sorted_items(dic)]))) + diff --git a/src/main/python/generate-numbers.py b/src/main/python/generate-numbers.py new file mode 100755 index 
0000000..4404ed4 --- /dev/null +++ b/src/main/python/generate-numbers.py @@ -0,0 +1,37 @@ +#!/usr/bin/env python + +from nlputil import * +import fileinput +import re + +def median(values): + values = sorted(values) + vallen = len(values) + + if vallen % 2 == 1: + return values[(vallen+1)/2-1] + else: + lower = values[vallen/2-1] + upper = values[vallen/2] + + return (float(lower + upper)) / 2 + +def mean(values): + return float(sum(values))/len(values) + +def tokm(val): + return 1.609*val + +vals = [] +for line in fileinput.input(): + line = line.strip() + m = re.match(r'.*Distance (.*?) miles to predicted region center', line) + if not m: + errprint("Can't parse line: %s", line) + else: + vals += [float(m.group(1))] + +med = median(vals) +mn = mean(vals) +uniprint("Median: %g miles (%g km)" % (med, tokm(med))) +uniprint("Mean: %g miles (%g km)" % (mn, tokm(mn))) diff --git a/src/main/python/generate_combined.py b/src/main/python/generate_combined.py new file mode 100755 index 0000000..a732800 --- /dev/null +++ b/src/main/python/generate_combined.py @@ -0,0 +1,109 @@ +#!/usr/bin/env python + +####### +####### generate_combined.py +####### +####### Copyright (c) 2010 Ben Wing. +####### + +# Essentially does an SQL JOIN on the --article-data-file, --coords-file +# and --links-file, keeping only the article records with a coordinate and +# attaching the coordinate. + +import sys +from nlputil import * +from process_article_data import * + +def read_incoming_link_info(filename, articles_hash): + errprint("Reading incoming link info from %s..." % filename) + status = StatusMessage('article') + + for line in uchompopen(filename): + if rematch('------------------ Count of incoming links: ------------', + line): continue + elif rematch('==========================================', line): + return + else: + assert rematch('(.*) = ([0-9]+)$', line) + title = m_[1] + links = int(m_[2]) + title = capfirst(title) + art = articles_hash.get(title, None) + if art: + art.incoming_links = int(links) + if status.item_processed(maxtime=Opts.max_time_per_stage): + break + +# Parse the result of a previous run of --only-coords or coords-counts for +# articles with coordinates. +def read_coordinates_file(filename): + errprint("Reading article coordinates from %s..." 
% filename) + status = StatusMessage('article') + coords_hash = {} + for line in uchompopen(filename): + if rematch('Article title: (.*)$', line): + title = m_[1] + elif rematch('Article coordinates: (.*),(.*)$', line): + coords_hash[title] = Coord(safe_float(m_[1]), safe_float(m_[2])) + if status.item_processed(maxtime=Opts.max_time_per_stage): + break + return coords_hash + +def output_combined_article_data(filename, coords_file, links_file): + coords_hash = read_coordinates_file(coords_file) + articles_hash = {} + articles_seen = [] + + def process(art): + if art.namespace != 'Main': + return + coord = coords_hash.get(art.title, None) + if coord: + art.coord = coord + elif art.redir and capfirst(art.redir) in coords_hash: + pass + else: + return + articles_hash[art.title] = art + articles_seen.append(art) + read_article_data_file(filename, process, maxtime=Opts.max_time_per_stage) + + read_incoming_link_info(links_file, articles_hash) + + errprint("Writing combined data to stdout ...") + write_article_data_file(sys.stdout, + outfields = combined_article_data_outfields, + articles = articles_seen) + errprint("Done.") + +############################################################################ +# Main code # +############################################################################ + +class GenerateCombinedProgram(NLPProgram): + def populate_options(self, op): + op.add_option("-l", "--links-file", + help="""File containing incoming link information for +Wikipedia articles. Output by processwiki.py --find-links.""", + metavar="FILE") + op.add_option("-a", "--article-data-file", + help="""File containing info about Wikipedia articles.""", + metavar="FILE") + op.add_option("-c", "--coords-file", + help="""File containing output from a prior run of +--coords-counts or --only-coords, listing all the articles with associated +coordinates. May be filtered only for articles and coordinates.""", + metavar="FILE") + + def handle_arguments(self, opts, op, args): + global Opts + Opts = opts + self.need('article_data_file') + self.need('coords_file') + self.need('links_file') + + def implement_main(self, opts, params, args): + output_combined_article_data(opts.article_data_file, opts.coords_file, + opts.links_file) + +GenerateCombinedProgram() diff --git a/src/main/python/ner/DummyNER.py b/src/main/python/ner/DummyNER.py new file mode 100644 index 0000000..0419dac --- /dev/null +++ b/src/main/python/ner/DummyNER.py @@ -0,0 +1,18 @@ +#!/usr/bin/env python + +# Any Python NER component should provide a function named "recognize" that +# takes a list of string tokens and returns a list of tuples indicating the +# beginning and ending of named entity ranges in the token list. This simple +# example considers all capitalized words to be single-token named entities. 
+def recognize(tokens): + spans = [] + for i in range(0, len(tokens)): + if tokens[i].isalnum() and tokens[i] == tokens[i].capitalize(): + spans.append((i, i + 1)) + return spans + +#text = ["They", "had", "sailed", "from", "Deptford", +# ",", "from", "Greenwich", ",", "from", "Erith", "..."] + +#print(recognize(text)) + diff --git a/src/main/python/ner/stanford2places.py b/src/main/python/ner/stanford2places.py new file mode 100644 index 0000000..7e5cdaa --- /dev/null +++ b/src/main/python/ner/stanford2places.py @@ -0,0 +1,26 @@ +import sys, re + +locationString = re.compile(r'(\w*/LOCATION(\s*\w*/LOCATION)*)') + +def processFile(filename): + inFile = open(filename,'r') + curLine = inFile.readline() + while(curLine != ""): + for ls in locationString.findall(curLine): + lineToPrint = ls[0].replace("/LOCATION", "").strip() + if(len(lineToPrint) > 0): + print lineToPrint + curLine = inFile.readline() + +def processDirectory(dirname): + fileList = os.listdir(dirname) + if(not dirname[-1] == "/"): + dirname += "/" + for filename in fileList: + if(os.path.isdir(dirname + filename)): + processDirectory(dirname + filename) + elif(os.path.isfile(dirname + filename)): + processFile(dirname + filename) + +for filename in sys.argv[1:]: + processFile(filename) diff --git a/src/main/python/nlputil.py b/src/main/python/nlputil.py new file mode 100644 index 0000000..44fa24b --- /dev/null +++ b/src/main/python/nlputil.py @@ -0,0 +1,1328 @@ +from __future__ import with_statement # For chompopen(), uchompopen() +from optparse import OptionParser +from itertools import * +import itertools +import re # For regexp wrappers +import sys, codecs # For uchompopen() +import math # For float_with_commas() +import bisect # For sorted lists +import time # For status messages, resource usage +from heapq import * # For priority queue +import UserDict # For SortedList, LRUCache +import resource # For resource usage +from collections import deque # For breadth-first search +from subprocess import * # For backquote +from errno import * # For backquote +import os # For get_program_memory_usage_ps() +import os.path # For exists of /proc, etc. +import fileinput # For uchompopen() etc. + +############################################################################# +# Regular expression functions # +############################################################################# + + +#### Some simple wrappers around basic text-processing Python functions to +#### make them easier to use. +#### +#### 1. rematch() and research(): +#### +#### The functions 'rematch' and 'research' are wrappers around re.match() +#### and re.search(), respectively, but instead of returning a match object +#### None, they return True or False, and in the case a match object would +#### have been returned, a corresponding WreMatch object is stored in the +#### global variable m_. Groups can be accessed from this variable using +#### m_.group() or m_.groups(), but they can also be accessed through direct +#### subscripting, i.e. m_[###] = m_.group(###). 
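A brief usage sketch of the rematch()/research() wrappers documented above (illustrative only, not part of this patch; the input string and pattern are invented, though they mirror how generate_combined.py uses these helpers):

    line = "Article coordinates: 30.28,-97.73"
    if rematch('Article coordinates: (.*),(.*)$', line):
        lat, lon = m_[1], m_[2]     # same as m_.group(1), m_.group(2)
    elif research(r'coordinates', line):
        pass                        # re.search-style match anywhere in the string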
+ +class WreMatch(object): + def setmatch(self, match): + self.match = match + + def groups(self, *foo): + return self.match.groups(*foo) + + def group(self, *foo): + return self.match.group(*foo) + + def __getitem__(self, key): + return self.match.group(key) + +m_ = WreMatch() + +def rematch(pattern, string, flags=0): + m = re.match(pattern, string, flags) + if m: + m_.setmatch(m) + return True + return False + +def research(pattern, string, flags=0): + global m_ + m = re.search(pattern, string, flags) + if m: + m_.setmatch(m) + return True + return False + +############################################################################# +# File reading functions # +############################################################################# + +### NOTE NOTE NOTE: Only works on Python 2.5 and above, due to using the +### "with" statement. + +#### 1. chompopen(): +#### +#### A generator that yields lines from a file, with any terminating newline +#### removed (but no other whitespace removed). Ensures that the file +#### will be automatically closed under all circumstances. +#### +#### 2. uchompopen(): +#### +#### Same as chompopen() but specifically open the file as 'utf-8' and +#### return Unicode strings. + +#""" +#Test gopen +# +#import nlputil +#for line in nlputil.gopen("foo.txt"): +# print line +#for line in nlputil.gopen("foo.txt", chomp=True): +# print line +#for line in nlputil.gopen("foo.txt", encoding='utf-8'): +# print line +#for line in nlputil.gopen("foo.txt", encoding='utf-8', chomp=True): +# print line +#for line in nlputil.gopen("foo.txt", encoding='iso-8859-1'): +# print line +#for line in nlputil.gopen(["foo.txt"], encoding='iso-8859-1'): +# print line +#for line in nlputil.gopen(["foo.txt"], encoding='utf-8'): +# print line +#for line in nlputil.gopen(["foo.txt"], encoding='iso-8859-1', chomp=True): +# print line +#for line in nlputil.gopen(["foo.txt", "foo2.txt"], encoding='iso-8859-1', chomp=True): +# print line +#""" + +# General function for opening a file, with automatic closure after iterating +# through the lines. The encoding can be specified (e.g. 'utf-8'), and if so, +# the error-handling can be given. Whether to remove the final newline +# (chomp=True) can be specified. The filename can be either a regular +# filename (opened with open) or codecs.open(), or a list of filenames or +# None, in which case the argument is passed to fileinput.input() +# (if a non-empty list is given, opens the list of filenames one after the +# other; if an empty list is given, opens stdin; if None is given, takes +# list from the command-line arguments and proceeds as above). When using +# fileinput.input(), the arguments "inplace", "backup" and "bufsize" can be +# given, appropriate to that function (e.g. to do in-place filtering of a +# file). 
In all cases, +def gopen(filename, mode='r', encoding=None, errors='strict', chomp=False, + inplace=0, backup="", bufsize=0): + if isinstance(filename, basestring): + def yieldlines(): + if encoding is None: + mgr = open(filename) + else: + mgr = codecs.open(filename, mode, encoding=encoding, errors=errors) + with mgr as f: + for line in f: + yield line + iterator = yieldlines() + else: + if encoding is None: + openhook = None + else: + def openhook(filename, mode): + return codecs.open(filename, mode, encoding=encoding, errors=errors) + iterator = fileinput.input(filename, inplace=inplace, backup=backup, + bufsize=bufsize, mode=mode, openhook=openhook) + if chomp: + for line in iterator: + if line and line[-1] == '\n': line = line[:-1] + yield line + else: + for line in iterator: + yield line + +# Open a filename with UTF-8-encoded input and yield lines converted to +# Unicode strings, but with any terminating newline removed (similar to +# "chomp" in Perl). Basically same as gopen() but with defaults set +# differently. +def uchompopen(filename=None, mode='r', encoding='utf-8', errors='strict', + chomp=True, inplace=0, backup="", bufsize=0): + return gopen(filename, mode=mode, encoding=encoding, errors=errors, + chomp=chomp, inplace=inplace, backup=backup, bufsize=bufsize) + +# Open a filename and yield lines, but with any terminating newline +# removed (similar to "chomp" in Perl). Basically same as gopen() but +# with defaults set differently. +def chompopen(filename, mode='r', encoding=None, errors='strict', + chomp=True, inplace=0, backup="", bufsize=0): + return gopen(filename, mode=mode, encoding=encoding, errors=errors, + chomp=chomp, inplace=inplace, backup=backup, bufsize=bufsize) + +# Open a filename with UTF-8-encoded input. Basically same as gopen() +# but with defaults set differently. +def uopen(filename, mode='r', encoding='utf-8', errors='strict', + chomp=False, inplace=0, backup="", bufsize=0): + return gopen(filename, mode=mode, encoding=encoding, errors=errors, + chomp=chomp, inplace=inplace, backup=backup, bufsize=bufsize) + +############################################################################# +# Other basic utility functions # +############################################################################# + +def internasc(text): + '''Intern a string (for more efficient memory use, potentially faster lookup. +If string is Unicode, automatically convert to UTF-8.''' + if type(text) is unicode: text = text.encode("utf-8") + return intern(text) + +def uniprint(text, outfile=sys.stdout, nonl=False, flush=False): + '''Print text string using 'print', converting Unicode as necessary. +If string is Unicode, automatically convert to UTF-8, so it can be output +without errors. Send output to the file given in OUTFILE (default is +stdout). Uses the 'print' command, and normally outputs a newline; but +this can be suppressed using NONL. Output is not normally flushed (unless +the stream does this automatically); but this can be forced using FLUSH.''' + + if type(text) is unicode: + text = text.encode("utf-8") + if nonl: + print >>outfile, text, + else: + print >>outfile, text + if flush: + outfile.flush() + +def uniout(text, outfile=sys.stdout, flush=False): + '''Output text string, converting Unicode as necessary. +If string is Unicode, automatically convert to UTF-8, so it can be output +without errors. Send output to the file given in OUTFILE (default is +stdout). Uses the write() function, which outputs the text directly, +without adding spaces or newlines. 
Output is not normally flushed (unless +the stream does this automatically); but this can be forced using FLUSH.''' + + if type(text) is unicode: + text = text.encode("utf-8") + outfile.write(text) + if flush: + outfile.flush() + +def errprint(text, nonl=False): + '''Print text to stderr using 'print', converting Unicode as necessary. +If string is Unicode, automatically convert to UTF-8, so it can be output +without errors. Uses the 'print' command, and normally outputs a newline; but +this can be suppressed using NONL.''' + uniprint(text, outfile=sys.stderr, nonl=nonl) + +def errout(text): + '''Output text to stderr, converting Unicode as necessary. +If string is Unicode, automatically convert to UTF-8, so it can be output +without errors. Uses the write() function, which outputs the text directly, +without adding spaces or newlines.''' + uniout(text, outfile=sys.stderr) + +def warning(text): + '''Output a warning, formatting into UTF-8 as necessary''' + errprint("Warning: %s" % text) + +def safe_float(x): + '''Convert a string to floating point, but don't crash on errors; +instead, output a warning.''' + try: + return float(x) + except: + x = x.strip() + if x: + warning("Expected number, saw %s" % x) + return 0. + +def pluralize(word): + '''Pluralize an English word, using a basic but effective algorithm.''' + if word[-1] >= 'A' and word[-1] <= 'Z': upper = True + else: upper = False + lowerword = word.lower() + if re.match(r'.*[b-df-hj-np-tv-z]y$', lowerword): + if upper: return word[:-1] + 'IES' + else: return word[:-1] + 'ies' + elif re.match(r'.*([cs]h|[sx])$', lowerword): + if upper: return word + 'ES' + else: return word + 'es' + else: + if upper: return word + 'S' + else: return word + 's' + +def capfirst(st): + '''Capitalize the first letter of string, leaving the remainder alone.''' + if not st: return st + return st[0].capitalize() + st[1:] + +# From: http://stackoverflow.com/questions/1823058/how-to-print-number-with-commas-as-thousands-separators-in-python-2-x +def int_with_commas(x): + if type(x) not in [type(0), type(0L)]: + raise TypeError("Parameter must be an integer.") + if x < 0: + return '-' + int_with_commas(-x) + result = '' + while x >= 1000: + x, r = divmod(x, 1000) + result = ",%03d%s" % (r, result) + return "%d%s" % (x, result) + +# My own version +def float_with_commas(x): + intpart = int(math.floor(x)) + fracpart = x - intpart + return int_with_commas(intpart) + ("%.2f" % fracpart)[1:] + +def median(list): + "Return the median value of a sorted list." + l = len(list) + if l % 2 == 1: + return list[l // 2] + else: + l = l // 2 + return 0.5*(list[l-1] + list[l]) + +def mean(list): + "Return the mean of a list." + return sum(list) / float(len(list)) + +def split_text_into_words(text, ignore_punc=False, include_nl=False): + # This regexp splits on whitespace, but also handles the following cases: + # 1. Any of , ; . etc. at the end of a word + # 2. Parens or quotes in words like (foo) or "bar" + # These punctuation characters are returned as separate words, unless + # 'ignore_punc' is given. Also, if 'include_nl' is given, newlines are + # returned as their own words; otherwise, they are treated like all other + # whitespace (i.e. ignored). + if include_nl: + split_punc_re = r'[ \t]+' + else: + split_punc_re = r'\s+' + # The use of izip and cycle will pair True with return values that come + # from the grouping in the split re, and False with regular words. 
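+  # Illustrative sketch (not in the original): with a capturing group,
+  # re.split interleaves words and separators, e.g.
+  #   re.split('([,;."):]*\s+[("]*)', 'foo, bar') -> ['foo', ', ', 'bar']
+  # so pairing the result with cycle([False, True]) marks every second
+  # element (the separator chunks) with ispunc=True.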
+ for (ispunc, word) in izip(cycle([False, True]), + re.split('([,;."):]*\s+[("]*)', text)): + if not word: continue + if ispunc: + # Divide the punctuation up + for punc in word: + if punc == '\n': + if include_nl: yield punc + elif punc in ' \t\r\f\v': continue + elif not ignore_punc: yield punc + else: + yield word + +def fromto(fro, to, step=1): + if fro <= to: + step = abs(step) + to += 1 + else: + step = -abs(step) + to -= 1 + return xrange(fro, to, step) + +############################################################################# +# Default dictionaries # +############################################################################# + +# Our own version similar to collections.defaultdict(). The difference is +# that we can specify whether or not simply referencing an unseen key +# automatically causes the key to permanently spring into existence with +# the "missing" value. collections.defaultdict() always behaves as if +# 'add_upon_ref'=True, same as our default. Basically: +# +# foo = defdict(list, add_upon_ref=False) +# foo['bar'] -> [] +# 'bar' in foo -> False +# +# foo = defdict(list, add_upon_ref=True) +# foo['bar'] -> [] +# 'bar' in foo -> True +# +# The former may be useful where you may make many queries involving +# non-existent keys, and you don't want all these keys added to the dict. +# The latter is useful with mutable objects like lists. If I create +# +# foo = defdict(list, add_upon_ref=False) +# +# and then call +# +# foo['glorplebargle'].append('shazbat') +# +# where 'glorplebargle' is a previously non-existent key, the call to +# 'append' will "appear" to work but in fact nothing will happen, because +# the reference foo['glorplebargle'] will create a new list and return +# it, but not store it in the dict, and 'append' will add to this +# temporary list, which will soon disappear. Note that using += will +# actually work, but this is fragile behavior, not something to depend on. +# +class defdict(dict): + def __init__(self, factory, add_upon_ref=True): + super(defdict, self).__init__() + self.factory = factory + self.add_upon_ref = add_upon_ref + + def __missing__(self, key): + val = self.factory() + if self.add_upon_ref: + self[key] = val + return val + + +# A dictionary where asking for the value of a missing key causes 0 (or 0.0, +# etc.) to be returned. Useful for dictionaries that track counts of items. + +def intdict(): + return defdict(int, add_upon_ref=False) + +def floatdict(): + return defdict(float, add_upon_ref=False) + +def booldict(): + return defdict(float, add_upon_ref=False) + +# Similar but the default value is an empty collection. We set +# 'add_upon_ref' to True whenever the collection is mutable; see comments +# above. + +def listdict(): + return defdict(list, add_upon_ref=True) + +def strdict(): + return defdict(str, add_upon_ref=False) + +def dictdict(): + return defdict(dict, add_upon_ref=True) + +def tupledict(): + return defdict(tuple, add_upon_ref=False) + +def setdict(): + return defdict(set, add_upon_ref=True) + +############################################################################# +# Sorted lists # +############################################################################# + +# Return a tuple (keys, values) of lists of items corresponding to a hash +# table. Stored in sorted order according to the keys. Use +# lookup_sorted_list(key) to find the corresponding value. The purpose of +# doing this, rather than just directly using a hash table, is to save +# memory. 
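+# A small usage sketch (illustrative only; the literals are made up):
+#
+#   slist = make_sorted_list({'b': 2, 'a': 1, 'c': 3})
+#   lookup_sorted_list(slist, 'b')      -> 2
+#   lookup_sorted_list(slist, 'z', 0)   -> 0 (the default)
+#
+# SortedList below wraps the same (keys, values) pair in a read-only
+# dict-like interface, e.g. SortedList({'b': 2, 'a': 1, 'c': 3})['b'] -> 2.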
+ +def make_sorted_list(table): + items = sorted(table.items(), key=lambda x:x[0]) + keys = ['']*len(items) + values = ['']*len(items) + for i in xrange(len(items)): + item = items[i] + keys[i] = item[0] + values[i] = item[1] + return (keys, values) + +# Given a sorted list in the tuple form (KEYS, VALUES), look up the item KEY. +# If found, return the corresponding value; else return None. + +def lookup_sorted_list(sorted_list, key, default=None): + (keys, values) = sorted_list + i = bisect.bisect_left(keys, key) + if i != len(keys) and keys[i] == key: + return values[i] + return default + +# A class that provides a dictionary-compatible interface to a sorted list + +class SortedList(object, UserDict.DictMixin): + def __init__(self, table): + self.sorted_list = make_sorted_list(table) + + def __len__(self): + return len(self.sorted_list[0]) + + def __getitem__(self, key): + retval = lookup_sorted_list(self.sorted_list, key) + if retval is None: + raise KeyError(key) + return retval + + def __contains__(self, key): + return lookup_sorted_list(self.sorted_list, key) is not None + + def __iter__(self): + (keys, values) = self.sorted_list + for x in keys: + yield x + + def keys(self): + return self.sorted_list[0] + + def itervalues(self): + (keys, values) = self.sorted_list + for x in values: + yield x + + def iteritems(self): + (keys, values) = self.sorted_list + for (key, value) in izip(keys, values): + yield (key, value) + +############################################################################# +# Table Output # +############################################################################# + +def key_sorted_items(d): + return sorted(d.iteritems(), key=lambda x:x[0]) + +def value_sorted_items(d): + return sorted(d.iteritems(), key=lambda x:x[1]) + +def reverse_key_sorted_items(d): + return sorted(d.iteritems(), key=lambda x:x[0], reverse=True) + +def reverse_value_sorted_items(d): + return sorted(d.iteritems(), key=lambda x:x[1], reverse=True) + +# Given a list of tuples, where the second element of the tuple is a number and +# the first a key, output the list, sorted on the numbers from bigger to +# smaller. Within a given number, sort the items alphabetically, unless +# keep_secondary_order is True, in which case the original order of items is +# left. If 'outfile' is specified, send output to this stream instead of +# stdout. If 'indent' is specified, indent all rows by this string (usually +# some number of spaces). If 'maxrows' is specified, output at most this many +# rows. +def output_reverse_sorted_list(items, outfile=sys.stdout, indent="", + keep_secondary_order=False, maxrows=None): + if not keep_secondary_order: + items = sorted(items, key=lambda x:x[0]) + items = sorted(items, key=lambda x:x[1], reverse=True) + if maxrows: + items = items[0:maxrows] + for key, value in items: + uniprint("%s%s = %s" % (indent, key, value), outfile=outfile) + +# Given a table with values that are numbers, output the table, sorted +# on the numbers from bigger to smaller. Within a given number, sort the +# items alphabetically, unless keep_secondary_order is True, in which case +# the original order of items is left. If 'outfile' is specified, send +# output to this stream instead of stdout. If 'indent' is specified, indent +# all rows by this string (usually some number of spaces). If 'maxrows' +# is specified, output at most this many rows. 
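+# For instance (an illustrative sketch; 'counts' is a hypothetical table of
+# word counts built with intdict()):
+#
+#   output_reverse_sorted_table(counts, indent="  ", maxrows=10)
+#
+# is intended to print at most ten "  word = count" lines, largest counts
+# first, with ties in alphabetical order.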
+def output_reverse_sorted_table(table, outfile=sys.stdout, indent="", + keep_secondary_order=False, maxrows=None): + output_reverse_sorted_list(table.iteritems()) + +############################################################################# +# Status Messages # +############################################################################# + +# Output status messages periodically, at some multiple of +# 'secs_between_output', measured in real time. 'item_name' is the name +# of the items being processed. Every time an item is processed, call +# item_processed() +class StatusMessage(object): + def __init__(self, item_name, secs_between_output=15): + self.item_name = item_name + self.plural_item_name = pluralize(item_name) + self.secs_between_output = secs_between_output + self.items_processed = 0 + self.first_time = time.time() + self.last_time = self.first_time + + def num_processed(self): + return self.items_processed + + def elapsed_time(self): + return time.time() - self.first_time + + def item_unit(self): + if self.items_processed == 1: + return self.item_name + else: + return self.plural_item_name + + def item_processed(self, maxtime=0): + curtime = time.time() + self.items_processed += 1 + total_elapsed_secs = int(curtime - self.first_time) + last_elapsed_secs = int(curtime - self.last_time) + if last_elapsed_secs >= self.secs_between_output: + # Rather than directly recording the time, round it down to the nearest + # multiple of secs_between_output; else we will eventually see something + # like 0, 15, 45, 60, 76, 91, 107, 122, ... + # rather than + # like 0, 15, 45, 60, 76, 90, 106, 120, ... + rounded_elapsed = (int(total_elapsed_secs / self.secs_between_output) * + self.secs_between_output) + self.last_time = self.first_time + rounded_elapsed + errprint("Elapsed time: %s minutes %s seconds, %s %s processed" + % (int(total_elapsed_secs / 60), total_elapsed_secs % 60, + self.items_processed, self.item_unit())) + if maxtime and total_elapsed_secs >= maxtime: + errprint("Maximum time reached, interrupting processing after %s %s" + % (self.items_processed, self.item_unit())) + return True + return False + +############################################################################# +# File Splitting # +############################################################################# + +# Return the next file to output to, when the instances being output to the +# files are meant to be split according to SPLIT_FRACTIONS. The absolute +# quantities in SPLIT_FRACTIONS don't matter, only the values relative to +# the other values, i.e. [20, 60, 10] is the same as [4, 12, 2]. This +# function implements an algorithm that is deterministic (same results +# each time it is run), and spreads out the instances as much as possible. +# For example, if all values are equal, it will cycle successively through +# the different split files; if the values are [1, 1.5, 1], the output +# will be [1, 2, 3, 2, 1, 2, 3, ...]; etc. + +def next_split_set(split_fractions): + + num_splits = len(split_fractions) + cumulative_articles = [0]*num_splits + + # Normalize so that the smallest value is 1. + + minval = min(split_fractions) + split_fractions = [float(val)/minval for val in split_fractions] + + # The algorithm used is as follows. We cycle through the output sets in + # order; each time we return a set, we increment the corresponding + # cumulative count, but before returning a set, we check to see if the + # count has reached the corresponding fraction and skip this set if so. 
+ # If we have run through an entire cycle without returning any sets, + # then for each set we subtract the fraction value from the cumulative + # value. This way, if the fraction value is not a whole number, then + # any fractional quantity (e.g. 0.6 for a value of 7.6) is left over, + # any will ensure that the total ratios still work out appropriately. + + while True: + this_output = False + for j in xrange(num_splits): + #print "j=%s, this_output=%s" % (j, this_output) + if cumulative_articles[j] < split_fractions[j]: + yield j + cumulative_articles[j] += 1 + this_output = True + if not this_output: + for j in xrange(num_splits): + while cumulative_articles[j] >= split_fractions[j]: + cumulative_articles[j] -= split_fractions[j] + +############################################################################# +# NLP Programs # +############################################################################# + +def output_option_parameters(opts, params=None): + errprint("Parameter values:") + for opt in dir(opts): + if not opt.startswith('_') and opt not in \ + ['ensure_value', 'read_file', 'read_module']: + errprint("%30s: %s" % (opt, getattr(opts, opt))) + if params: + for opt in dir(params): + if not opt.startswith('_'): + errprint("%30s: %s" % (opt, getattr(params, opt))) + errprint("") + +class NLPProgram(object): + def __init__(self): + if self.run_main_on_init(): + self.main() + + def implement_main(self, opts, params, args): + pass + + def populate_options(self, op): + pass + + def handle_arguments(self, opts, op, args): + pass + + def argument_usage(self): + return "" + + def get_usage(self): + argusage = self.argument_usage() + if argusage: + argusage = ' ' + argusage + return "%%prog [options]%s" % argusage + + def run_main_on_init(self): + return True + + def populate_shared_options(self, op): + op.add_option("--max-time-per-stage", "--mts", type='int', default=0, + help="""Maximum time per stage in seconds. If 0, no limit. +Used for testing purposes. 
Default %default.""") + op.add_option("-d", "--debug", metavar="FLAGS", + help="Output debug info of the given types (separated by spaces or commas)") + + def need(self, arg, arg_english=None): + if not arg_english: + arg_english=arg.replace('_', ' ') + if not getattr(self.opts, arg): + self.op.error("Must specify %s using --%s" % + (arg_english, arg.replace('_', '-'))) + + def main(self): + errprint("Beginning operation at %s" % (time.ctime())) + + self.op = OptionParser(usage="%prog [options]") + self.populate_shared_options(self.op) + self.canon_options = self.populate_options(self.op) + + errprint("Arguments: %s" % ' '.join(sys.argv)) + + ### Process the command-line options and set other values from them ### + + self.opts, self.args = self.op.parse_args() + # If a mapper for canonicalizing options is given, apply the mappings + if self.canon_options: + for (opt, mapper) in self.canon_options.iteritems(): + val = getattr(self.opts, opt) + if isinstance(val, list): + # If a list, then map all members, in case of multi-options + val = [mapper.get(x, x) for x in val] + setattr(self.opts, opt, val) + elif val in mapper: + setattr(self.opts, opt, mapper[val]) + + params = self.handle_arguments(self.opts, self.op, self.args) + + output_option_parameters(self.opts, params) + + retval = self.implement_main(self.opts, params, self.args) + errprint("Ending operation at %s" % (time.ctime())) + return retval + +############################################################################# +# Priority Queues # +############################################################################# + +# Priority queue implementation, based on Python heapq documentation. +# Note that in Python 2.6 and on, there is a priority queue implementation +# in the Queue module. +class PriorityQueue(object): + INVALID = 0 # mark an entry as deleted + + def __init__(self): + self.pq = [] # the priority queue list + self.counter = itertools.count(1) # unique sequence count + self.task_finder = {} # mapping of tasks to entries + + def add_task(self, priority, task, count=None): + if count is None: + count = self.counter.next() + entry = [priority, count, task] + self.task_finder[task] = entry + heappush(self.pq, entry) + + #Return the top-priority task. If 'return_priority' is false, just + #return the task itself; otherwise, return a tuple (task, priority). 
+ def get_top_priority(self, return_priority=False): + while True: + priority, count, task = heappop(self.pq) + if count is not PriorityQueue.INVALID: + del self.task_finder[task] + if return_priority: + return (task, priority) + else: + return task + + def delete_task(self, task): + entry = self.task_finder[task] + entry[1] = PriorityQueue.INVALID + + def reprioritize(self, priority, task): + entry = self.task_finder[task] + self.add_task(priority, task, entry[1]) + entry[1] = PriorityQueue.INVALID + +############################################################################# +# Least-recently-used (LRU) Caches # +############################################################################# + +class LRUCache(object, UserDict.DictMixin): + def __init__(self, maxsize=1000): + self.cache = {} + self.pq = PriorityQueue() + self.maxsize = maxsize + self.time = 0 + + def __len__(self): + return len(self.cache) + + def __getitem__(self, key): + if key in self.cache: + time = self.time + self.time += 1 + self.pq.reprioritize(time, key) + return self.cache[key] + + def __delitem__(self, key): + del self.cache[key] + self.pq.delete_task(key) + + def __setitem__(self, key, value): + time = self.time + self.time += 1 + if key in self.cache: + self.pq.reprioritize(time, key) + else: + while len(self.cache) >= self.maxsize: + delkey = self.pq.get_top_priority() + del self.cache[delkey] + self.pq.add_task(time, key) + self.cache[key] = value + + def keys(self): + return self.cache.keys() + + def __contains__(self, key): + return key in self.cache + + def __iter__(self): + return self.cache.iterkeys() + + def iteritems(self): + return self.cache.iteritems() + +############################################################################# +# Resource Usage # +############################################################################# + +beginning_prog_time = time.time() + +def get_program_time_usage(): + return time.time() - beginning_prog_time + +def get_program_memory_usage(): + if os.path.exists("/proc/self/status"): + return get_program_memory_usage_proc() + else: + try: + return get_program_memory_usage_ps() + except: + return get_program_memory_rusage() + + +def get_program_memory_usage_rusage(): + res = resource.getrusage(resource.RUSAGE_SELF) + # FIXME! This is "maximum resident set size". There are other more useful + # values, but on the Mac at least they show up as 0 in this structure. + # On Linux, alas, all values show up as 0 or garbage (e.g. negative). + return res.ru_maxrss + +# Get memory usage by running 'ps'; getrusage() doesn't seem to work very +# well. The following seems to work on both Mac OS X and Linux, at least. +def get_program_memory_usage_ps(): + pid = os.getpid() + input = backquote("ps -p %s -o rss" % pid) + lines = re.split(r'\n', input) + for line in lines: + if line.strip() == 'RSS': continue + return 1024*int(line.strip()) + +# Get memory usage by running 'proc'; this works on Linux and doesn't require +# spawning a subprocess, which can crash when your program is very large. 
+def get_program_memory_usage_proc(): + with open("/proc/self/status") as f: + for line in f: + line = line.strip() + if line.startswith('VmRSS:'): + rss = int(line.split()[1]) + return 1024*rss + return 0 + +def format_minutes_seconds(secs): + mins = int(secs / 60) + secs = secs % 60 + hours = int(mins / 60) + mins = mins % 60 + if hours > 0: + hourstr = "%s hour%s " % (hours, "" if hours == 1 else "s") + else: + hourstr = "" + secstr = "%s" % secs if type(secs) is int else "%0.1f" % secs + return hourstr + "%s minute%s %s second%s" % ( + mins, "" if mins == 1 else "s", + secstr, "" if secs == 1 else "s") + +def output_resource_usage(): + errprint("Total elapsed time since program start: %s" % + format_minutes_seconds(get_program_time_usage())) + errprint("Memory usage: %s bytes" % + int_with_commas(get_program_memory_usage())) + +############################################################################# +# Hash tables by range # +############################################################################# + +# A table that groups all keys in a specific range together. Instead of +# directly storing the values for a group of keys, we store an object (termed a +# "collector") that the user can use to keep track of the keys and values. +# This way, the user can choose to use a list of values, a set of values, a +# table of keys and values, etc. + +class TableByRange(object): + # Create a new object. 'ranges' is a sorted list of numbers, indicating the + # boundaries of the ranges. One range includes all keys that are + # numerically below the first number, one range includes all keys that are + # at or above the last number, and there is a range going from each number + # up to, but not including, the next number. 'collector' is used to create + # the collectors used to keep track of keys and values within each range; + # it is either a type or a no-argument factory function. We only create + # ranges and collectors as needed. 'lowest_bound' is the value of the + # lower bound of the lowest range; default is 0. This is used only + # it iter_ranges() when returning the lower bound of the lowest range, + # and can be an item of any type, e.g. the number 0, the string "-infinity", + # etc. + def __init__(self, ranges, collector, lowest_bound=0): + self.ranges = ranges + self.collector = collector + self.lowest_bound = lowest_bound + self.items_by_range = {} + + def get_collector(self, key): + lower_range = self.lowest_bound + # upper_range = 'infinity' + for i in self.ranges: + if i <= key: + lower_range = i + else: + # upper_range = i + break + if lower_range not in self.items_by_range: + self.items_by_range[lower_range] = self.collector() + return self.items_by_range[lower_range] + + def iter_ranges(self, unseen_between=True, unseen_all=False): + """Return an iterator over ranges in the table. Each returned value is +a tuple (LOWER, UPPER, COLLECTOR), giving the lower and upper bounds +(inclusive and exclusive, respectively), and the collector item for this +range. The lower bound of the lowest range comes from the value of +'lowest_bound' specified during creation, and the upper bound of the range +that is higher than any numbers specified during creation in the 'ranges' +list will be the string "infinity" is such a range is returned. + +The optional arguments 'unseen_between' and 'unseen_all' control the +behavior of this iterator with respect to ranges that have never been seen +(i.e. no keys in this range have been passed to 'get_collector'). 
If +'unseen_all' is true, all such ranges will be returned; else if +'unseen_between' is true, only ranges between the lowest and highest +actually-seen ranges will be returned.""" + highest_seen = None + for (lower, upper) in ( + izip(chain([self.lowest_bound], self.ranges), + chain(self.ranges, ['infinity']))): + if lower in self.items_by_range: + highest_seen = upper + + seen_any = False + for (lower, upper) in ( + izip(chain([self.lowest_bound], self.ranges), + chain(self.ranges, ['infinity']))): + collector = self.items_by_range.get(lower, None) + if collector is None: + if not unseen_all: + if not unseen_between: continue + if not seen_any: continue + if upper == 'infinity' or upper > highest_seen: continue + collector = self.collector() + else: + seen_any = True + yield (lower, upper, collector) + + +############################################################################# +# Depth-, breadth-first search # +############################################################################# + +# General depth-first search. 'node' is the node to search, the top of a +# tree. 'matches' indicates whether a given node matches. 'children' +# returns a list of child nodes. +def depth_first_search(node, matches, children): + nodelist = [node] + while len(nodelist) > 0: + node = nodelist.pop() + if matches(node): + yield node + nodelist.extend(reversed(children(node))) + +# General breadth-first search. 'node' is the node to search, the top of a +# tree. 'matches' indicates whether a given node matches. 'children' +# returns a list of child nodes. +def breadth_first_search(node, matches, children): + nodelist = deque([node]) + while len(nodelist) > 0: + node = nodelist.popLeft() + if matches(node): + yield node + nodelist.extend(children(node)) + +############################################################################# +# Merge sequences # +############################################################################# + +# Return an iterator over all elements in all the given sequences, omitting +# elements seen more than once and keeping the order. +def merge_sequences_uniquely(*seqs): + table = {} + for seq in seqs: + for s in seq: + if s not in table: + table[s] = True + yield s + + +############################################################################# +# Subprocesses # +############################################################################# + +# Run the specified command; return its combined output and stderr as a string. +# 'command' can either be a string or a list of individual arguments. Optional +# argument 'shell' indicates whether to pass the command to the shell to run. +# If unspecified, it defaults to True if 'command' is a string, False if a +# list. If optional arg 'input' is given, pass this string as the stdin to the +# command. If 'include_stderr' is True, stderr will be included along with +# the output. If return code is non-zero, throw CommandError if 'throw' is +# specified; else, return tuple of (output, return-code). 
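+# Example usage (a sketch, not part of the original code; the 'ls' path is
+# made up):
+#
+#   rss_output = backquote("ps -p %s -o rss" % os.getpid())
+#   (out, code) = backquote(["ls", "/nonexistent/path"], throw=False)
+#
+# The first form is a string, so it goes through the shell; the second is an
+# argument list run directly, and because it exits non-zero with throw=False
+# it returns a tuple of (combined output, return code) instead of raising.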
+def backquote(command, input=None, shell=None, include_stderr=True, throw=True): + #logdebug("backquote called: %s" % command) + if shell is None: + if isinstance(command, basestring): + shell = True + else: + shell = False + stderrval = STDOUT if include_stderr else PIPE + if input is not None: + popen = Popen(command, stdin=PIPE, stdout=PIPE, stderr=stderrval, + shell=shell, close_fds=True) + output = popen.communicate(input) + else: + popen = Popen(command, stdout=PIPE, stderr=stderrval, + shell=shell, close_fds=True) + output = popen.communicate() + if popen.returncode != 0: + if throw: + if output[0]: + outputstr = "Command's output:\n%s" % output[0] + if outputstr[-1] != '\n': + outputstr += '\n' + errstr = output[1] + if errstr and errstr[-1] != '\n': + errstr += '\n' + errmess = ("Error running command: %s\n\n%s\n%s" % + (command, output[0], output[1])) + #log.error(errmess) + oserror(errmess, EINVAL) + else: + return (output[0], popen.returncode) + return output[0] + +def oserror(mess, err): + e = OSError(mess) + e.errno = err + raise e + +############################################################################# +# Generating XML # +############################################################################# + +# This is old code I wrote originally for ccg.ply (the ccg2xml converter), +# for generating XML. It doesn't use the functions in xml.dom.minidom, +# which in any case are significantly more cumbersome than the list/tuple-based +# structure used below. + +# --------- XML ---------- +# +# Thankfully, the structure of XML is extremely simple. We represent +# a single XML statement of the form +# +# +# +# ... +# gurgle +# +# +# as a list +# +# ['biteme', [('foo', '1'), ('blorp', 'baz')], +# ['bitemetoo', ...], +# 'gurgle' +# ] +# +# i.e. an XML statement corresponds to a list where the first element +# is the statement name, the second element lists any properties, and +# the remaining elements list items inside the statement. +# +# ----------- Property lists ------------- +# +# The second element of an XML statement in list form is a "property list", +# a list of two-element tuples (property and value). Some functions below +# (e.g. `getprop', `putprop') manipulate property lists. +# +# FIXME: Just use a hash table. 
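+# For example (an illustrative sketch of the intended rendering): the list
+#
+#   ['biteme', [('foo', '1'), ('blorp', 'baz')],
+#     ['bitemetoo', ...],
+#     'gurgle'
+#   ]
+#
+# corresponds to XML along the lines of
+#
+#   <biteme foo="1" blorp="baz">
+#     <bitemetoo ...>
+#       ...
+#     </bitemetoo>
+#     gurgle
+#   </biteme>
+#
+# and print_xml(file, xml) below is the entry point intended to produce
+# this form.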
+ +def check_arg_type(errtype, arg, ty): + if type(arg) is not ty: + raise TypeError("%s: Type is not %s: %s" % (errtype, ty, arg)) + +def xml_sub(text): + if not isinstance(text, basestring): + text = text.__str__() + if type(text) is unicode: + text = text.encode("utf-8") + text = text.replace('&', '&') + text = text.replace('<', '<') + text = text.replace('>', '>') + return text + +def print_xml_1(file, xml, indent=0): + #if xml_debug > 1: + # errout("%sPrinting: %s\n" % (' ' * indent, str(xml))) + if type(xml) is not list: + file.write('%s%s\n' % (' ' * indent, xml_sub(xml))) + else: + check_arg_type("XML statement", xml[0], str) + file.write(' ' * indent) + file.write('<%s' % xml_sub(xml[0])) + for x in xml[1]: + check_arg_type("XML statement", x, tuple) + if len(x) != 2: + raise TypeError("Bad tuple pair: " + str(x)) + file.write(' %s="%s"' % (xml_sub(x[0]), xml_sub(x[1]))) + subargs = xml[2:] + if len(subargs) == 1 and type(subargs[0]) is not list: + file.write('>%s\n' % (xml_sub(subargs[0]), xml_sub(xml[0]))) + elif not subargs: + file.write('/>\n') + else: + file.write('>\n') + for x in subargs: + print_xml_1(file, x, indent + 2) + file.write(' ' * indent) + file.write('\n' % xml_sub(xml[0])) + +# Pretty-print a section of XML, in the format above, to FILE. +# Start at indent INDENT. + +def print_xml(file, xml): + print_xml_1(file, xml) + +# Function to output a particular XML file +def output_xml_file(filename, xml): + fil = open(filename, 'w') + fil.write('\n') + print_xml(fil, xml) + fil.close() + +# Return True if PROP is seen as a property in PROPLIST, a list of tuples +# of (prop, value) +def property_specified(prop, proplist): + return not not ['foo' for (x,y) in proplist if x == prop] + +# Return value of property PROP in PROPLIST; signal an error if not found. +def getprop(prop, proplist): + for (x,y) in proplist: + if x == prop: + return y + raise ValueError("Property %s not found in %s" % (prop, proplist)) + +# Return value of property PROP in PROPLIST, or DEFAULT. +def getoptprop(prop, proplist, default=None): + for (x,y) in proplist: + if x == prop: + return y + return default + +# Replace value of property PROP with VALUE in PROPLIST. +def putprop(prop, value, proplist): + for i in xrange(len(proplist)): + if proplist[i][0] == prop: + proplist[i] = (prop, value) + return + else: + proplist += [(prop, value)] + +# Replace property named PROP with NEW in PROPLIST. Often this is called with +# with PROP equal to None; the None occurs when a PROP=VALUE clause is expected +# but a bare value is supplied. The context will supply a particular default +# property (e.g. 'name') to be used when the property name is omitted, but the +# generic code to handle property-value clauses doesn't know what this is. +# The surrounding code calls property_name_replace() to fill in the proper name. + +def property_name_replace(prop, new, proplist): + for i in xrange(len(proplist)): + if proplist[i][0] == prop: + proplist[i] = (new, proplist[i][1]) + + +############################################################################# +# Extra functions for working with sequences # +# Part of the Python docs for itertools # +############################################################################# + +def take(n, iterable): + "Return first n items of the iterable as a list" + return list(islice(iterable, n)) + +def tabulate(function, start=0): + "Return function(0), function(1), ..." + return imap(function, count(start)) + +def consume(iterator, n): + "Advance the iterator n-steps ahead. 
If n is none, consume entirely." + # Use functions that consume iterators at C speed. + if n is None: + # feed the entire iterator into a zero-length deque + collections.deque(iterator, maxlen=0) + else: + # advance to the empty slice starting at position n + next(islice(iterator, n, n), None) + +def nth(iterable, n, default=None): + "Returns the nth item or a default value" + return next(islice(iterable, n, None), default) + +def quantify(iterable, pred=bool): + "Count how many times the predicate is true" + return sum(imap(pred, iterable)) + +def padnone(iterable): + """Returns the sequence elements and then returns None indefinitely. + + Useful for emulating the behavior of the built-in map() function. + """ + return chain(iterable, repeat(None)) + +def ncycles(iterable, n): + "Returns the sequence elements n times" + return chain.from_iterable(repeat(tuple(iterable), n)) + +def dotproduct(vec1, vec2): + return sum(imap(operator.mul, vec1, vec2)) + +def flatten(listOfLists): + "Flatten one level of nesting" + return chain.from_iterable(listOfLists) + +def repeatfunc(func, times=None, *args): + """Repeat calls to func with specified arguments. + + Example: repeatfunc(random.random) + """ + if times is None: + return starmap(func, repeat(args)) + return starmap(func, repeat(args, times)) + +def pairwise(iterable): + "s -> (s0,s1), (s1,s2), (s2, s3), ..." + a, b = tee(iterable) + next(b, None) + return izip(a, b) + +def grouper(n, iterable, fillvalue=None): + "grouper(3, 'ABCDEFG', 'x') --> ABC DEF Gxx" + args = [iter(iterable)] * n + return izip_longest(fillvalue=fillvalue, *args) + +def roundrobin(*iterables): + "roundrobin('ABC', 'D', 'EF') --> A D E B F C" + # Recipe credited to George Sakkis + pending = len(iterables) + nexts = cycle(iter(it).next for it in iterables) + while pending: + try: + for next in nexts: + yield next() + except StopIteration: + pending -= 1 + nexts = cycle(islice(nexts, pending)) + +def powerset(iterable): + "powerset([1,2,3]) --> () (1,) (2,) (3,) (1,2) (1,3) (2,3) (1,2,3)" + s = list(iterable) + return chain.from_iterable(combinations(s, r) for r in range(len(s)+1)) + +def unique_everseen(iterable, key=None): + "List unique elements, preserving order. Remember all elements ever seen." + # unique_everseen('AAAABBBCCDAABBB') --> A B C D + # unique_everseen('ABBCcAD', str.lower) --> A B C D + seen = set() + seen_add = seen.add + if key is None: + for element in ifilterfalse(seen.__contains__, iterable): + seen_add(element) + yield element + else: + for element in iterable: + k = key(element) + if k not in seen: + seen_add(k) + yield element + +def unique_justseen(iterable, key=None): + "List unique elements, preserving order. Remember only the element just seen." + # unique_justseen('AAAABBBCCDAABBB') --> A B C D A B + # unique_justseen('ABBCcAD', str.lower) --> A B C A D + return imap(next, imap(itemgetter(1), groupby(iterable, key))) + +def iter_except(func, exception, first=None): + """ Call a function repeatedly until an exception is raised. + + Converts a call-until-exception interface to an iterator interface. + Like __builtin__.iter(func, sentinel) but uses an exception instead + of a sentinel to end the loop. 
+ + Examples: + bsddbiter = iter_except(db.next, bsddb.error, db.first) + heapiter = iter_except(functools.partial(heappop, h), IndexError) + dictiter = iter_except(d.popitem, KeyError) + dequeiter = iter_except(d.popleft, IndexError) + queueiter = iter_except(q.get_nowait, Queue.Empty) + setiter = iter_except(s.pop, KeyError) + + """ + try: + if first is not None: + yield first() + while 1: + yield func() + except exception: + pass + +def random_product(*args, **kwds): + "Random selection from itertools.product(*args, **kwds)" + pools = map(tuple, args) * kwds.get('repeat', 1) + return tuple(random.choice(pool) for pool in pools) + +def random_permutation(iterable, r=None): + "Random selection from itertools.permutations(iterable, r)" + pool = tuple(iterable) + r = len(pool) if r is None else r + return tuple(random.sample(pool, r)) + +def random_combination(iterable, r): + "Random selection from itertools.combinations(iterable, r)" + pool = tuple(iterable) + n = len(pool) + indices = sorted(random.sample(xrange(n), r)) + return tuple(pool[i] for i in indices) + +def random_combination_with_replacement(iterable, r): + "Random selection from itertools.combinations_with_replacement(iterable, r)" + pool = tuple(iterable) + n = len(pool) + indices = sorted(random.randrange(n) for i in xrange(r)) + return tuple(pool[i] for i in indices) + diff --git a/src/main/python/parse-wex.py b/src/main/python/parse-wex.py new file mode 100755 index 0000000..7fb9517 --- /dev/null +++ b/src/main/python/parse-wex.py @@ -0,0 +1,92 @@ +#!/usr/bin/env python + +import fileinput +import re +import sys +import itertools +from nlputil import * + +def printable(line): + return line.replace('\n', r'\n') + +def process_params(article_wpid, params, args): + #ty = None + #reg = None + #for param in these_params: + # for (key, val) in param: + # if key == 'type' and ty is None: + # ty = val + # if key == 'region' and reg is None: + # reg = val + errprint("article_wpid: %s, %s, %s" % (article_wpid, params, args)) + +def process_article_lines(article_id, rawargs): + for call_id, idargs in itertools.groupby(rawargs, key=lambda x:x[0]): + params = {} + args = [] + for idarg in idargs: + try: + (_, name, xml, _, id, section_id, template_article_name, line) = idarg + if call_id != id: + warning("Call id is %s but id is %s; line [%s]" % ( + call_id, id, line)) + m = re.match('(.*)$', xml.strip()) + if not m: + warning("Can't parse line: [%s]" % line) + continue + paramname = m.group(1) + if paramname != name: + warning("paramname is %s but name is %s; line [%s]" % ( + paramname, name, line)) + xmlval = m.group(2) + if ':' in name: + xmlval = name + ':' + xmlval + elif not re.match(r'^[0-9]+$', name): + params[name] = xmlval + continue + if ':' in xmlval: + param_pairs = xmlval.split('_') + for pair in param_pairs: + (key, val) = pair.split(':') + params[key] = val + else: + args += [xmlval] + except Exception, e: + warning("Saw exception %s: Line is [%s]" % (e, line)) + process_params(article_id, params, args) + +# Read in the raw file containing output from PostGreSQL, ignore headers, +# join continued lines +def yield_joined_lines(file): + # First two lines are headers + file.next() + file.next() + + for line in file: + line = line.strip() + try: + while not line.endswith('Template:Coord'): + contline = file.next().strip() + line += contline + except StopIteration: + warning("Partial, unfinished line [%s]" % line) + break + yield line + +# Read in lines, split and yield lists of the arguments, with the line itself +# as the 
last argument +def yield_arguments(lines): + for line in lines: + rawarg = [x.strip() for x in line.split('|')] + if len(rawarg) != 7: + warning("Wrong number of fields in line [%s]" % line) + continue + yield rawarg + [line] + +def process_file(): + gen = yield_arguments(yield_joined_lines(fileinput.input())) + + for article_id, article_rawargs in itertools.groupby(gen, key=lambda x:x[3]): + process_article_lines(article_id, sorted(article_rawargs)) + +process_file() diff --git a/src/main/python/permute_wiki.py b/src/main/python/permute_wiki.py new file mode 100755 index 0000000..0283cd1 --- /dev/null +++ b/src/main/python/permute_wiki.py @@ -0,0 +1,214 @@ +#!/usr/bin/env python + +####### +####### permute_wiki.py +####### +####### Copyright (c) 2010 Ben Wing. +####### + +from __future__ import with_statement + +import random +import re +import sys +from nlputil import * +from process_article_data import * + +""" +We want to randomly permute the articles and then reorder the articles +in the dump file accordingly. A simple algorithm would be to load the +entire dump file into memory as huge list, permute the list randomly, +then write the results. However, this might easily exceed the memory +of the computer. So instead, we split the task into chunks, proceeding +as follows, keeping in mind that we have the list of article names +available in a separate file: + +1. The dump consists of a prolog, a bunch of articles, and an epilog. +2. (The 'permute' step:) Take the full list of articles, permute randomly + and output again. +3. (The 'split' step:) Split the dump into pieces, of perhaps a few GB each, + the idea being that we can sort each piece separately and concatenate the + results. (This is the 'split' step.) The idea is that we first split the + permuted list of articles into some number of pieces (default 8), and + create a mapping listing which split each article goes in; then we create + a file for each split; then we read through the dump file, and each time we + find an article, we look up its split and write it to the corresponding + split file. We also write the prolog and epilog into separate files. + Note that in this step we have effectively done a rough sort of the + articles by split, preserving the original order within each split. +4. (The 'sort' step:) Sort each split. For each split, we read the permuted + article list into memory to get the proper order, then we read the entire + split into memory and output the articles in the order indicated in the + article list. +5. Concatenate the results. + +Note that this might be a perfect task for Hadoop; it automatically does +the splitting, sorting and merging. 
+ +""" + +all_articles = [] + +def read_article_data(filename): + def process(art): + all_articles.append(art) + + infields = read_article_data_file(filename, process, + maxtime=Opts.max_time_per_stage) + return infields + +def write_permutation(infields): + random.shuffle(all_articles) + errprint("Writing combined data to stdout ...") + write_article_data_file(sys.stdout, outfields=infields, + articles=all_articles) + errprint("Done.") + + +def break_up_xml_dump(infile): + prolog = '' + inpage = False + for x in infile: + if re.match('.*', x): + thispage = [x] + inpage = True + break + else: + prolog += x + + if prolog: + yield ('prolog', prolog) + + thisnonpage = '' + for x in infile: + if inpage: + if re.match('.*', x): + inpage = False + thispage.append(x) + thisnonpage = '' + yield ('page', ''.join(thispage)) + else: + thispage.append(x) + else: + if re.match('.*', x): + if thisnonpage: + yield ('nonpage', thisnonpage) + thispage = [x] + inpage = True + else: + thisnonpage += x + if inpage: + warning("Saw but no ") + if thisnonpage: + yield ('epilog', thisnonpage) + + +def get_id_from_page(text): + m = re.match('(?s).*?(.*?)', text) + if not m: + warning("Can't find ID in article; beginning of article text follows:") + maxlen = min(100, len(text)) + errprint(text[0:maxlen]) + return -1 + id = m.group(1) + try: + id = int(id) + except ValueError: + print "Exception when parsing %s, assumed non-int" % id + return -1 + return id + + +def split_files(infields, split_prefix, num_splits): + errprint("Generating this split article-table files...") + splits = {} + num_arts = len(all_articles) + splitsize = (num_arts + num_splits - 1) // num_splits + for i in xrange(num_splits): + minval = i * splitsize + maxval = min(num_arts, (i + 1) * splitsize) + outarts = [] + for j in xrange(minval, maxval): + art = all_articles[j] + splits[art.id] = i + outarts.append(art) + with open("%s.%s.articles" % (split_prefix, i), 'w') as outfile: + write_article_data_file(outfile, outfields=infields, articles=outarts) + + # Clear the big array when we're done with it + del all_articles[:] + + splitfiles = [None]*num_splits + for i in xrange(num_splits): + splitfiles[i] = open("%s.%s" % (split_prefix, i), 'w') + + errprint("Splitting the dump....") + status = StatusMessage("article") + for (type, text) in break_up_xml_dump(sys.stdin): + if type == 'prolog': + with open("%s.prolog" % split_prefix, 'w') as prolog: + prolog.write(text) + elif type == 'epilog': + with open("%s.epilog" % split_prefix, 'w') as epilog: + epilog.write(text) + elif type == 'nonpage': + warning("Saw non-page text %s" % text) + else: + id = get_id_from_page(text) + if id not in splits: + warning("Can't find article %s in article data file" % id) + else: + splitfiles[splits[id]].write(text) + if status.item_processed(maxtime=Opts.max_time_per_stage): + errprint("Interrupting processing") + break + +def sort_file(): + all_pages = {} + for (type, text) in break_up_xml_dump(sys.stdin): + if type != 'page': + warning("Shouldn't see type '%s' in split file: %s" % (type, text)) + else: + id = get_id_from_page(text) + all_pages[id] = text + for art in all_articles: + text = all_pages.get(art.id, None) + if text is None: + warning("Didn't see article ID %s in XML file" % art.id) + else: + sys.stdout.write(text) + + +############################################################################ +# Main code # +############################################################################ + +class PermuteWikipediaDumpProgram(NLPProgram): + def 
populate_options(self, op): + op.add_option("-a", "--article-data-file", + help="""File containing all the articles.""") + op.add_option("-s", "--number-of-splits", type='int', default=8, + help="""Number of splits.""") + op.add_option("--split-prefix", help="""Prefix for split files.""") + op.add_option("-m", "--mode", type='choice', + default=None, choices=['permute', 'split', 'sort'], + help="""Format of evaluation file(s). Default '%default'.""") + + def handle_arguments(self, opts, op, args): + global Opts + Opts = opts + self.need('mode') + self.need('article_data_file') + if opts.mode == 'split': + self.need('split_prefix') + + def implement_main(self, opts, params, args): + infields = read_article_data(opts.article_data_file) + if opts.mode == 'permute': + write_permutation(infields) + elif opts.mode == 'split': + split_files(infields, opts.split_prefix, opts.number_of_splits) + elif opts.mode == 'sort': + sort_file() + +PermuteWikipediaDumpProgram() diff --git a/src/main/python/process_article_data.py b/src/main/python/process_article_data.py new file mode 100755 index 0000000..fe9fe06 --- /dev/null +++ b/src/main/python/process_article_data.py @@ -0,0 +1,188 @@ +#!/usr/bin/env python + +####### +####### process_article_data.py +####### +####### Copyright (c) 2010 Ben Wing. +####### + +from nlputil import * + +#!/usr/bin/env python + +############################################################################ +# Main code # +############################################################################ + +minimum_latitude = -90.0 +maximum_latitude = 90.0 +minimum_longitude = -180.0 +maximum_longitude = 180.0 - 1e-10 + +# A 2-dimensional coordinate. +# +# The following fields are defined: +# +# lat, long: Latitude and longitude of coordinate. + +class Coord(object): + __slots__ = ['lat', 'long'] + + ### If coerce_within_bounds=True, then force the values to be within + ### the allowed range, by wrapping longitude and bounding latitude. + def __init__(self, lat, long, coerce_within_bounds=False): + if coerce_within_bounds: + if lat > maximum_latitude: lat = maximum_latitude + while long > maximum_longitude: long -= 360. + if lat < minimum_latitude: lat = minimum_latitude + while long < minimum_longitude: long += 360. + self.lat = lat + self.long = long + + def __str__(self): + return '(%.2f,%.2f)' % (self.lat, self.long) + +# A Wikipedia article. Defined fields: +# +# title: Title of article. +# id: ID of article, as an int. +# coord: Coordinates of article. +# incoming_links: Number of incoming links, or None if unknown. +# split: Split of article ('training', 'dev', 'test') +# redir: If this is a redirect, article title that it redirects to; else +# an empty string. +# namespace: Namespace of article (e.g. 'Main', 'Wikipedia', 'File') +# is_list_of: Whether article title is 'List of *' +# is_disambig: Whether article is a disambiguation page. 
+# is_list: Whether article is a list of any type ('List of *', disambig, +# or in Category or Book namespaces) +class Article(object): + __slots__ = ['title', 'id', 'coord', 'incoming_links', 'split', 'redir', + 'namespace', 'is_list_of', 'is_disambig', 'is_list'] + def __init__(self, title='unknown', id=None, coord=None, incoming_links=None, + split='unknown', redir='', namespace='Main', is_list_of=False, + is_disambig=False, is_list=False): + self.title = title + self.id = id + self.coord = coord + self.incoming_links = incoming_links + self.split = split + self.redir = redir + self.namespace = namespace + self.is_list_of = is_list_of + self.is_disambig = is_disambig + self.is_list = is_list + + def __str__(self): + coordstr = " at %s" % self.coord if self.coord else "" + redirstr = ", redirect to %s" % self.redir if self.redir else "" + return '%s(%s)%s%s' % (self.title, self.id, coordstr, redirstr) + + # Output row of an article-data file, normal format. 'outfields' is + # a list of the fields to output, and 'outfield_types' is a list of + # corresponding types, determined by a call to get_output_field_types(). + def output_row(self, outfile, outfields, outfield_types): + fieldvals = [t(getattr(self, f)) for f,t in zip(outfields, outfield_types)] + uniprint('\t'.join(fieldvals), outfile=outfile) + +def yesno_to_boolean(foo): + if foo == 'yes': return True + else: + if foo != 'no': + warning("Expected yes or no, saw '%s'" % foo) + return False + +def boolean_to_yesno(foo): + if foo: return 'yes' + else: return 'no' + +def commaval_to_coord(foo): + if foo: + (lat, long) = foo.split(',') + return Coord(float(lat), float(long)) + return None + +def coord_to_commaval(foo): + if foo: + return "%s,%s" % (foo.lat, foo.long) + return '' + +def get_int_or_blank(foo): + if not foo: return None + else: return int(foo) + +def put_int_or_blank(foo): + if foo == None: return '' + else: return "%s" % foo + +def identity(foo): + return foo + +def tostr(foo): + return "%s" % foo + +known_fields_input = {'id':int, 'title':identity, 'split':identity, + 'redir':identity, 'namespace':identity, + 'is_list_of':yesno_to_boolean, + 'is_disambig':yesno_to_boolean, + 'is_list':yesno_to_boolean, 'coord':commaval_to_coord, + 'incoming_links':get_int_or_blank} + +known_fields_output = {'id':tostr, 'title':tostr, 'split':tostr, + 'redir':tostr, 'namespace':tostr, + 'is_list_of':boolean_to_yesno, + 'is_disambig':boolean_to_yesno, + 'is_list':boolean_to_yesno, 'coord':coord_to_commaval, + 'incoming_links':put_int_or_blank} + +combined_article_data_outfields = ['id', 'title', 'split', 'coord', + 'incoming_links', 'redir', 'namespace', 'is_list_of', 'is_disambig', + 'is_list'] + +def get_field_types(field_table, field_list): + for f in field_list: + if f not in field_table: + warning("Saw unknown field name %s" % f) + return [field_table.get(f, identity) for f in field_list] + +def get_input_field_types(field_list): + return get_field_types(known_fields_input, field_list) + +def get_output_field_types(field_list): + return get_field_types(known_fields_output, field_list) + +# Read in the article data file. Call PROCESS on each article. +# The type of the article created is given by ARTICLE_TYPE, which defaults +# to Article. MAXTIME is a value in seconds, which limits the total +# processing time (real time, not CPU time) used for reading in the +# file, for testing purposes. +def read_article_data_file(filename, process, article_type=Article, + maxtime=0): + errprint("Reading article data from %s..." 
% filename) + status = StatusMessage('article') + + fi = uchompopen(filename) + fields = fi.next().split('\t') + field_types = get_input_field_types(fields) + for line in fi: + fieldvals = line.split('\t') + if len(fieldvals) != len(field_types): + warning("""Strange record at line #%s, expected %s fields, saw %s fields; + skipping line=%s""" % (status.num_processed(), len(field_types), + len(fieldvals), line)) + continue + record = dict([(str(f),t(v)) for f,v,t in zip(fields, fieldvals, field_types)]) + art = article_type(**record) + process(art) + if status.item_processed(maxtime=maxtime): + break + errprint("Finished reading %s articles." % (status.num_processed())) + output_resource_usage() + return fields + +def write_article_data_file(outfile, outfields, articles): + field_types = get_output_field_types(outfields) + uniprint('\t'.join(outfields), outfile=outfile) + for art in articles: + art.output_row(outfile, outfields, field_types) + outfile.close() diff --git a/src/main/python/processwiki.py b/src/main/python/processwiki.py new file mode 100755 index 0000000..a42571c --- /dev/null +++ b/src/main/python/processwiki.py @@ -0,0 +1,2456 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +####### +####### processwiki.py +####### +####### Copyright (c) 2010 Ben Wing. +####### + +### FIXME: +### +### Cases to fix involving coordinates: + +# 1. Nested coordinates: + +#{{Infobox Australian Place +#| name = Lauderdale +#| image = Lauderdale Canal.JPG +#| caption = +#| loc-x = +#| loc-y = +#| coordinates = {{coord|42|54|40|S|147|29|34|E|display=inline,title}} +#| state = tas +#... +#}} + +import sys, re +from optparse import OptionParser +from nlputil import * +import itertools +import time +from process_article_data import * + +from xml.sax import make_parser +from xml.sax.handler import ContentHandler + +# Debug flags. Different flags indicate different info to output. +debug = booldict() + +# Program options +progopts = None + +disambig_pages_by_id = set() + +article_namespaces = ['User', 'Wikipedia', 'File', 'MediaWiki', 'Template', + 'Help', 'Category', 'Thread', 'Summary', 'Portal', + 'Book'] + +article_namespaces = {} +article_namespaces_lower = {} + +article_namespace_aliases = { + 'P':'Portal', 'H':'Help', 'T':'Template', + 'CAT':'Category', 'Cat':'Category', 'C':'Category', + 'MOS':'Wikipedia', 'MoS':'Wikipedia', 'Mos':'Wikipedia'} + +# Count number of incoming links for articles +incoming_link_count = intdict() + +# Map anchor text to a hash that maps articles to counts +anchor_text_map = {} + +# Set listing articles containing coordinates +coordinate_articles = set() + +debug_cur_title = None + +# Parse the result of a previous run of --coords-counts for articles with +# coordinates +def read_coordinates_file(filename): + errprint("Reading coordinates file %s..." % filename) + status = StatusMessage('article') + for line in uchompopen(filename): + m = re.match('Article title: (.*)', line) + if m: + title = capfirst(m.group(1)) + elif re.match('Article coordinates: ', line): + coordinate_articles.add(title) + if status.item_processed(maxtime=Opts.max_time_per_stage): + break + +# Read in redirects. Record redirects as additional articles with coordinates +# if the article pointed to has coordinates. NOTE: Must be done *AFTER* +# reading coordinates. +def read_redirects_from_article_data(filename): + assert coordinate_articles + errprint("Reading redirects from article data file %s..." 
% filename) + + def process(art): + if art.namespace != 'Main': + return + if art.redir and capfirst(art.redir) in coordinate_articles: + coordinate_articles.add(art.title) + + read_article_data_file(filename, process, maxtime=Opts.max_time_per_stage) + +# Read the list of disambiguation article ID's. +def read_disambig_id_file(filename): + errprint("Reading disambig ID file %s..." % filename) + status = StatusMessage("article") + for line in uchompopen(filename): + disambig_pages_by_id.add(line) + if status.item_processed(maxtime=Opts.max_time_per_stage): + break + +############################################################################ +# Documentation # +############################################################################ + +##### Quick start + +# This program processes the article dump from Wikipedia. Dump is on +# stdin. Outputs to stdout. Written flexibly so that it can be modified +# to do various things. To run it, use something like this: +# +# bzcat enwiki-20100905-pages-articles.xml.bz2 | processwiki.py > wiki-words.out + +##### How this program works + +# Currently it does the following: +# +# 1. Locate the article title and text. +# +# 2. Find any coordinates specified either in Infobox or Coord templates. +# If found, the first such coordinate in an article is output with lines +# like +# +# Article title: Politics of Angola +# Article coordinates: 13.3166666667,-169.15 +# +# 3. For articles with coordinates in them, locate all the "useful" words in +# the article text. This ignores HTML codes like , comments, +# stuff like [[ or ]], anything inside of ..., etc. It +# tries to do intelligent things with templates (stuff inside of {{...}}) +# and internal links (inside of [[...]]), and ignores external links +# ([http:...]). The words are split on whitespace, ignoring punctuation +# such as periods and commas, and the resulting words are counted up, and +# the count of each different word is output, one per line like +# +# Birmingham = 48 +# +# There is also a debug flag. If set, lots of additional stuff is output. +# Among them are warnings like +# +# Warning: Nesting level would drop below 0; string = }, prevstring = (19 +# +# Note that since words are broken on spaces, there will never be a space +# in the outputted words. Hence, the lines containing directives (e.g. +# the article title) can always be distinguished from lines containing words. +# +# Note also that the following terminology is used here, which may not be +# standard: +# +# Internal link: A link to another Wikipedia article, of the form [[...]]. +# External link: A link to an external URL, of the form [...]. +# Template: An expression of the form {{...}}, with arguments separated by +# the pipe symbol |, that processes the arguments and subsitutes +# the processed text; it may also trigger other sorts of actions. +# Similar to the macros in C or M4. +# Macro: An expression that results in some other text getting substituted, +# either a template {{...}} or an internal link [[...]]. + +##### Internal workings; how to extend the program + +# Note that this program is written so that it can be flexibly extended to +# allow for different sorts of processing of the Wikipedia dump. See the +# following description, which indicates where to change in order to +# implement different behavior. +# +# The basic functioning of this code is controlled by an article handler class. +# The default handler class is ArticleHandler. 
Usually it is +# sufficient to subclass this handler class, as it provides hooks to do +# interesting things, which by default do nothing. You can also subclass +# ArticleHandlerForUsefulText if you want the source text processed for +# "useful text" (what the Wikipedia user sees, plus similar-quality +# hidden text). +# +# SAX is used to process the XML of the raw Wikipedia dump file. +# Simple SAX handler functions are what invokes the article handler +# functions in the article handler class. +# +# For each article, the article handler function process_article_text() is +# called to process the text of the article, and is passed the article title +# and full text, with entity expressions such as   replaced appropriately. +# This function operates in two passes. The first pass, performed by +# the article handler process_text_for_data(), extracts useful data, e.g. +# coordinates or links. It returns True or False, indicating whether the +# second pass should operate. The purpose of the second pass is to do +# processing involving the article text itself, e.g. counting up words. +# It is implemented by the article handler process_text_for_text(). +# The default handler does two things: +# +# 1. Process the text, filtering out some junk +# (see format_text_second_pass()). +# 2. Use process_source_text() to extract chunks of "actual +# text" (as opposed to directives of various sorts), i.e. text that +# is useful for constructing a language model that can be used +# for classifying a document to find the most similar article. +# Join together and then split into words. Pass the generator +# of words to the article handler process_text_for_words(). +# +# process_source_text() is is a generator that yields processed +# textual chunks containing only "actual text". This function works by +# calling parse_simple_balanced_text() to parse the text into balanced chunks +# (text delimited by balanced braces or brackets, i.e. {...} or [...], +# or text without any braces or brackets), and then handling the chunks +# according to their type: +# +# -- if [[...]], use process_internal_link() +# -- if {{...}}, use process_template() +# -- if {|...|}, use process_table() +# -- if [...] but not [[...]], use process_external_link() +# -- else, return the text unchanged +# +# Each of the above functions is a generator that yields chunks of +# "actual text". Different sorts of processing can be implemented here. +# Note also that a similar structure can and probably should be +# implemented in process_text_for_data(). +# +# As mentioned above, the chunks are concatenated before being split again. +# Concatentation helps in the case of article text like +# +# ... the [[latent variable|hidden value]]s ... +# +# which will get processed into chunks +# +# '... the ', 'latent variable hidden value', 's ...' +# +# Concatenating again will generate a single word "values". +# +# The resulting text is split to find words, using split_text_into_words(). +# This splits on whitespace, but is a bit smarter; it also ignores +# punctuation such as periods or commas that occurs at the end of words, +# as well as parens and quotes at word boundaries, and ignores entirely +# any word with a colon in the middle (a likely URL or other directive that +# has slipped through), and separates on # and _ (which may occur in +# internal links or such). 
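As an aside, a minimal standalone sketch of the splitting rules just described -- not the program's own split_text_into_words() (defined further below), just an approximation for illustration, reusing the same whitespace/punctuation pattern:

import re

def toy_split_words(text):
    # Split on whitespace, stripping trailing punctuation and leading
    # parens/quotes, dropping any word with an internal colon (a likely
    # URL or directive), and breaking the rest on '#' and '_'.
    for word in re.split(r'[,;."):]*\s+[("]*', text):
        if ':' in word:
            continue
        for piece in re.split(r'[#_]', word):
            if piece:
                yield piece

# list(toy_split_words('See "Two-port_network#ABCD-parameters", not http://x.org/y'))
# => ['See', 'Two-port', 'network', 'ABCD-parameters', 'not']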
+# +# Note also that prior to processing the text for data and again prior to +# processing the text for words, it is formatted to make it nicer to +# process and to get rid of certain sorts of non-useful text. This is +# implemented in the functions format_text_first_pass() and +# format_text_second_pass(), respectively. This includes things like: +# +# -- removing comments +# -- removing ... sections, which contain differently-formatted +# text (specifically, in TeX format), which will screw up processing +# of templates and links and contains very little useful text +# -- handling certain sorts of embedded entity expressions, e.g. cases +# where &nbsp; appears in the raw dump file. This corresponds to +# cases where   appears in the source text of the article. +# Wikipedia's servers process the source text into HTML and then spit +# out the HTML, which gets rendered by the browser and which will +# handle embedded entity expressions (e.g. convert   into a +# non-breaking-space character). Note that something like   +# appears directly in the raw dump file only when a literal +# non-breaking-space character appears in the article source text. +# -- handling embedded HTML expressions like 2, where < appears +# in the raw dump as <. These get processed by the user's browser. +# We handle them in a simple fashion, special-casing
and +# into whitespace and just removing all the others. +# -- removing === characters in headers like ===Introduction=== +# -- removing multiple single-quote characters, which indicate boldface +# or italics + +##### About generators + +# The code in this program relies heavily on generators, a special type of +# Python function. The following is a quick intro for programmers who +# might not be familiar with generators. +# +# A generator is any function containing a "yield foo" expression. +# Logically speaking, a generator function returns multiple values in +# succession rather than returning a single value. In the actual +# implementation, the result of calling a generator function is a +# generator object, which can be iterated over in a for loop, list +# comprehension or generator expression, e.g. +# +# 1. The following uses a for loop to print out the objects returned by +# a generator function. +# +# for x in generator(): +# print x +# +# 2. The following returns a list resulting from calling a function fun() on +# each object returned by a generator function. +# +# [fun(x) for x in generator()] +# +# 3. The following returns another generator expression resulting from +# calling a function fun() on each object returned by a generator function. +# +# (fun(x) for x in generator()) +# +# There are some subtleties involved in writing generators: +# +# -- A generator can contain a "return" statement, but cannot return a value. +# Returning from a generator, whether explicitly through a "return" +# statement or implicitly by falling off the end of the function, triggers +# a "raise StopIteration" statement. This terminates the iteration loop +# over the values returned by the generator. +# -- Chaining generators, i.e. calling one generator inside of another, is +# a bit tricky. If you have a generator function generator(), and you +# want to pass back the values from another generator function generator2(), +# you cannot simply call "return generator2()", since generators can't +# return values. If you just write "generator2()", nothing will happen; +# the value from generator2() gets discarded. So you usually have to +# write a for loop: +# +# for foo in generator2(): +# yield foo +# +# Note that "return generator2()" *will* work inside of a function that is +# not a generator, i.e. has no "yield" statement in it. + +####################################################################### +# Splitting the output # +####################################################################### + +# Files to output to, when splitting output +split_output_files = None + +# List of split suffixes +split_suffixes = None + +# Current file to output to +cur_output_file = sys.stdout +debug_to_stderr = False + +# Name of current split (training, dev, test) +cur_split_name = '' + +# Generator of files to output to +split_file_gen = None + +# Initialize the split output files, using PREFIX as the prefix +def init_output_files(prefix, split_fractions, the_split_suffixes): + assert len(split_fractions) == len(the_split_suffixes) + global split_output_files + split_output_files = [None]*len(the_split_suffixes) + global split_suffixes + split_suffixes = the_split_suffixes + for i in range(len(the_split_suffixes)): + split_output_files[i] = open("%s.%s" % (prefix, the_split_suffixes[i]), "w") + global split_file_gen + split_file_gen = next_split_set(split_fractions) + +# Find the next split file to output to and set CUR_OUTPUT_FILE appropriately; +# don't do anything if the user hasn't called for splitting. 
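The generator driving the split choice, next_split_set(), lives in nlputil and is not part of this diff. The sketch below is purely an assumption about the kind of weighted round-robin it is expected to implement, yielding split indices in proportion to the requested fractions:

def next_split_set_sketch(split_fractions):
    # Hypothetical stand-in for nlputil's next_split_set(): a smooth
    # weighted round-robin over split indices, so that index i is
    # yielded with long-run frequency proportional to split_fractions[i].
    total = float(sum(split_fractions))
    credit = [0.0] * len(split_fractions)
    while True:
        for j, frac in enumerate(split_fractions):
            credit[j] += frac
        winner = max(range(len(credit)), key=lambda j: credit[j])
        credit[winner] -= total
        yield winner

init_output_files() above stores the real generator in split_file_gen; set_next_split_file() below simply pulls the next index from it to select the current output file and split name.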
+def set_next_split_file(): + global cur_output_file + global cur_split_name + if split_file_gen: + nextid = split_file_gen.next() + cur_output_file = split_output_files[nextid] + cur_split_name = split_suffixes[nextid] + +####################################################################### +# Chunk text into balanced sections # +####################################################################### + +### Return chunks of balanced text, for use in handling template chunks +### and such. The chunks consist either of text without any braces or +### brackets, chunks consisting of a brace or bracket and all the text +### up to and including the matching brace or bracket, or lone unmatched +### right braces/brackets. Currently, if a chunk is closed with the +### wrong type of character (brace when bracket is expected or vice-versa), +### we still treat it as the closing character, but output a warning. +### +### In addition, some of the versions below will split off additional +### characters if they occur at the top level (e.g. pipe symbols or +### newlines). In these cases, if such a character occurs, three +### successive chunks will be seen: The text up to but not including the +### dividing character, a chunk with only the character, and a chunk +### with the following text. Note that if the dividing character occurs +### inside of bracketed or braced text, it will not divide the text. +### This way, for example, arguments of a template or internal link +### (which are separated by a pipe symbol) can be sectioned off without +### also sectioning off the arguments inside of nested templates or +### internal links. Then, the parser can be called recursively if +### necessary to handle such expressions. + +left_ref_re = r'' +# Return braces and brackets separately from other text. +simple_balanced_re = re.compile(left_ref_re + r'||[^{}\[\]<]+|[{}\[\]]|<') +#simple_balanced_re = re.compile(r'[^{}\[\]]+|[{}\[\]]') + +# Return braces, brackets and pipe symbols separately from other text. +balanced_pipe_re = re.compile(left_ref_re + r'||[^{}\[\]|<]+|[{}\[\]|]|<') +#balanced_pipe_re = re.compile(r'[^{}\[\]|]+|[{}\[\]|]') + +# Return braces, brackets, and newlines separately from other text. +# Useful for handling Wikipedia tables, denoted with {| ... |}. +balanced_table_re = re.compile(left_ref_re + r'||[^{}\[\]\n<]+|[{}\[\]\n]|<') +#balanced_table_re = re.compile(r'[^{}\[\]\n]+|[{}\[\]\n]') + +left_match_chars = {'{':'}', '[':']', '':''} +right_match_chars = {'}':'{', ']':'[', '':''} + +def parse_balanced_text(textre, text, throw_away = 0): + '''Parse text in TEXT containing balanced expressions surrounded by single +or double braces or brackets. This is a generator; it successively yields +chunks of text consisting either of sections without any braces or brackets, +or balanced expressions delimited by single or double braces or brackets, or +unmatched single or double right braces or brackets. 
TEXTRE is used to +separate the text into chunks; it can be used to separate out additional +top-level separators, such as vertical bar.''' + strbuf = [] + prevstring = "(at beginning)" + leftmatches = [] + parenlevel = 0 + for string in textre.findall(text): + if debug['debugparens']: + errprint("pbt: Saw %s, parenlevel=%s" % (string, parenlevel)) + if string.startswith(''): + wikiwarning("Strange parsing, saw odd ref tag: %s" % string) + if string.endswith('/>'): + continue + string = '' + if string in right_match_chars: + if parenlevel == 0: + wikiwarning("Nesting level would drop below 0; string = %s, prevstring = %s" % (string, prevstring.replace('\n','\\n'))) + yield string + else: + strbuf.append(string) + assert len(leftmatches) == parenlevel + should_left = right_match_chars[string] + should_pop_off = 1 + the_left = leftmatches[-should_pop_off] + if should_left != the_left: + if should_left == '': + wikiwarning("Saw unmatched ") + in_ref = any([match for match in leftmatches if match == '']) + if not in_ref: + wikiwarning("Stray ??; prevstring = %s" % prevstring.replace('\n','\\n')) + should_pop_off = 0 + else: + while (len(leftmatches) - should_pop_off >= 0 and + should_left != leftmatches[len(leftmatches)-should_pop_off]): + should_pop_off += 1 + if should_pop_off >= 0: + wikiwarning("%s non-matching brackets inside of ...: %s ; prevstring = %s" % (should_pop_off - 1, ' '.join(left_match_chars[x] for x in leftmatches[len(leftmatches)-should_pop_off:]), prevstring.replace('\n','\\n'))) + else: + wikiwarning("Inside of but still interpreted as stray ??; prevstring = %s" % prevstring.replace('\n','\\n')) + should_pop_off = 0 + elif the_left == '': + wikiwarning("Stray %s inside of ...; prevstring = %s" % (string, prevstring.replace('\n','\\n'))) + should_pop_off = 0 + else: + wikiwarning("Non-matching brackets: Saw %s, expected %s; prevstring = %s" % (string, left_match_chars[the_left], prevstring.replace('\n','\\n'))) + if should_pop_off > 0: + parenlevel -= should_pop_off + if debug['debugparens']: + errprint("pbt: Decreasing parenlevel by 1 to %s" % parenlevel) + leftmatches = leftmatches[:-should_pop_off] + if parenlevel == 0: + yield ''.join(strbuf) + strbuf = [] + else: + if string in left_match_chars: + if throw_away > 0: + wikiwarning("Throwing away left bracket %s as a reparse strategy" + % string) + throw_away -= 1 + else: + parenlevel += 1 + if debug['debugparens']: + errprint("pbt: Increasing parenlevel by 1 to %s" % parenlevel) + leftmatches.append(string) + if parenlevel > 0: + strbuf.append(string) + else: + yield string + prevstring = string + leftover = ''.join(strbuf) + if leftover: + wikiwarning("Unmatched left paren, brace or bracket: %s characters remaining" % len(leftover)) + wikiwarning("Remaining text: [%s]" % bound_string_length(leftover)) + wikiwarning("Reparsing:") + for string in parse_balanced_text(textre, leftover, throw_away = parenlevel): + yield string + +def parse_simple_balanced_text(text): + '''Parse text in TEXT containing balanced expressions surrounded by single +or double braces or brackets. 
This is a generator; it successively yields +chunks of text consisting either of sections without any braces or brackets, +or balanced expressions delimited by single or double braces or brackets, or +unmatched single or double right braces or brackets.''' + return parse_balanced_text(simple_balanced_re, text) + +####################################################################### +### Utility functions ### +####################################################################### + +def splitprint(text): + '''Print text (possibly Unicode) to the appropriate output, either stdout +or one of the split output files.''' + uniprint(text, outfile=cur_output_file) + +def outprint(text): + '''Print text (possibly Unicode) to stdout (but stderr in certain debugging +modes).''' + if debug_to_stderr: + errprint(text) + else: + uniprint(text) + +def wikiwarning(foo): + warning("Article %s: %s" % (debug_cur_title, foo)) + +# Output a string of maximum length, adding ... if too long +def bound_string_length(str, maxlen=60): + if len(str) <= maxlen: + return str + else: + return '%s...' % str[0:maxlen] + +def find_template_params(args, strip_values): + '''Find the parameters specified in template arguments, i.e. the arguments +to a template that are of the form KEY=VAL. Given the arguments ARGS of a +template, return a tuple (HASH, NONPARAM) where HASH is the hash table of +KEY->VAL parameter mappings and NONPARAM is a list of all the remaining, +non-parameter arguments. If STRIP_VALUES is true, strip whitespace off the +beginning and ending of values in the hash table (keys will always be +lowercased and have the whitespace stripped from them).''' + hash = {} + nonparam_args = [] + for arg in args: + m = re.match(r'(?s)(.*?)=(.*)', arg) + if m: + key = m.group(1).strip().lower().replace('_','').replace(' ','') + value = m.group(2) + if strip_values: + value = value.strip() + hash[key] = value + else: + #errprint("Unable to process template argument %s" % arg) + nonparam_args.append(arg) + return (hash, nonparam_args) + +def get_macro_args(macro): + '''Split macro MACRO (either a {{...}} or [[...]] expression) +by arguments (separated by | occurrences), but intelligently so that +arguments in nested macros are not also sectioned off. In the case +of a template, i.e. {{...}}, the first "argument" returned will be +the template type, e.g. "Cite web" or "Coord". At least one argument +will always be returned (in the case of an empty macro, it will be +the string "empty macro"), so that code that parses templates need +not worry about crashing on these syntactic errors.''' + + macroargs1 = [foo for foo in + parse_balanced_text(balanced_pipe_re, macro[2:-2])] + macroargs2 = [] + # Concatenate adjacent args if neither one is a | + for x in macroargs1: + if x == '|' or len(macroargs2) == 0 or macroargs2[-1] == '|': + macroargs2 += [x] + else: + macroargs2[-1] += x + macroargs = [x for x in macroargs2 if x != '|'] + if not macroargs: + wikiwarning("Strange macro with no arguments: %s" % macroargs) + return ['empty macro'] + return macroargs + +####################################################################### +# Process source text # +####################################################################### + +# Handle the text of a given article. Yield chunks of processed text. 
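To make the two helpers above concrete, here is an illustrative trace on a typical Coord template (the template text is only an example; the values follow the definitions of get_macro_args() and find_template_params() above):

args = get_macro_args('{{Coord|44|6.72|N|87|54.78|W|display=inline,title}}')
# args == ['Coord', '44', '6.72', 'N', '87', '54.78', 'W', 'display=inline,title']

params, nonparams = find_template_params(args[1:], True)
# params    == {'display': 'inline,title'}   (keys lowercased, '_' and ' ' removed)
# nonparams == ['44', '6.72', 'N', '87', '54.78', 'W']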
+ +class SourceTextHandler(object): + def process_internal_link(self, text): + yield text + + def process_template(self, text): + yield text + + def process_table(self, text): + yield text + + def process_external_link(self, text): + yield text + + def process_reference(self, text): + yield text + + def process_text_chunk(self, text): + yield text + + def process_source_text(self, text): + # Look for all template and link expressions in the text and do something + # sensible with them. Yield the resulting text chunks. The idea is that + # when the chunks are joined back together, we will get raw text that can + # be directly separated into words, without any remaining macros (templates, + # internal or external links, tables, etc.) and with as much extraneous + # junk (directives of various sorts, instead of relevant text) as possible + # filtered out. Note that when we process macros and extract the relevant + # text from them, we need to recursively process that text. + + if debug['lots']: errprint("Entering process_source_text: [%s]" % text) + + for foo in parse_simple_balanced_text(text): + if debug['lots']: errprint("parse_simple_balanced_text yields: [%s]" % foo) + + if foo.startswith('[['): + gen = self.process_internal_link(foo) + + elif foo.startswith('{{'): + gen = self.process_template(foo) + + elif foo.startswith('{|'): + gen = self.process_table(foo) + + elif foo.startswith('['): + gen = self.process_external_link(foo) + + elif foo.startswith(' 60: + wikiwarning("Out-of-bounds minutes %s" % min) + return None + if sec > 60: + wikiwarning("Out-of-bounds seconds %s" % sec) + return None + return nsew*(lat + min/60. + sec/3600.) + +convert_ns = {'N':1, 'S':-1} +convert_ew = {'E':1, 'W':-1, 'L':1, 'O':-1} +# Blah!! O=Ost="east" in German but O=Oeste="west" in Spanish/Portuguese +convert_ew_german = {'E':1, 'W':-1, 'O':1} + +# Get the default value for the hemisphere, as a multiplier +1 or -1. +# We need to handle the following as S latitude, E longitude: +# -- Infobox Australia +# -- Info/Localidade de Angola +# -- Info/Município de Angola +# -- Info/Localidade de Moçambique + +# We need to handle the following as N latitude, W longitude: +# -- Infobox Pittsburgh neighborhood +# -- Info/Assentamento/Madeira +# -- Info/Localidade da Madeira +# -- Info/Assentamento/Marrocos +# -- Info/Localidade dos EUA +# -- Info/PousadaPC +# -- Info/Antigas freguesias de Portugal +# Otherwise assume +1, so that we leave the values alone. This is important +# because some fields may specifically use signed values to indicate the +# hemisphere directly, or use other methods of indicating hemisphere (e.g. +# "German"-style "72/50/35/W"). +def get_hemisphere(temptype, is_lat): + for x in ('infobox australia', 'info/localidade de angola', + u'info/município de angola', u'info/localidade de moçambique'): + if temptype.lower().startswith(x): + if is_lat: return -1 + else: return 1 + for x in ('infobox pittsburgh neighborhood', 'info/assentamento/madeira', + 'info/assentamento/marrocos', 'info/localidade dos eua', 'info/pousadapc', + 'info/antigas freguesias de portugal'): + if temptype.lower().startswith(x): + if is_lat: return 1 + else: return -1 + return 1 + +# Get an argument (ARGSEARCH) by name from a hash table (ARGS). Multiple +# synonymous names can be looked up by giving a list or tuple for ARGSEARCH. +# Other parameters control warning messages. 
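For instance (an illustrative call with made-up values): given a parameter hash as produced by find_template_params(), getarg() below returns the value of the first synonymous name that is present, and None otherwise:

paramshash = {'latdeg': '51', 'longdeg': '0', 'name': 'Greenwich'}
rawargs = ['lat_deg = 51', 'long_deg = 0', 'name = Greenwich']

getarg(('latd', 'latg', 'latdeg'), 'Infobox example', paramshash, rawargs)
# => '51'
getarg('elevation', 'Infobox example', paramshash, rawargs, warnifnot=False)
# => None (and no warning, since warnifnot=False and the 'some' debug flag is off)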
+def getarg(argsearch, temptype, args, rawargs, warnifnot=True): + if isinstance(argsearch, tuple) or isinstance(argsearch, list): + for x in argsearch: + val = args.get(x, None) + if val is not None: + return val + if warnifnot or debug['some']: + wikiwarning("None of params %s seen in template {{%s|%s}}" % ( + ','.join(argsearch), temptype, bound_string_length('|'.join(rawargs)))) + else: + val = args.get(argsearch, None) + if val is not None: + return val + if warnifnot or debug['some']: + wikiwarning("Param %s not seen in template {{%s|%s}}" % ( + argsearch, temptype, bound_string_length('|'.join(rawargs)))) + return None + +# Utility function for get_latd_coord(). +# Extract out either latitude or longitude from a template of type +# TEMPTYPE with arguments ARGS. LATD/LATM/LATS are lists or tuples of +# parameters to look up to retrieve the appropriate value. OFFPARAM is the +# list of possible parameters indicating the offset to the N, S, E or W. +# IS_LAT is True if a latitude is being extracted, False for longitude. +def get_lat_long_1(temptype, args, rawargs, latd, latm, lats, offparam, is_lat): + d = getarg(latd, temptype, args, rawargs) + m = getarg(latm, temptype, args, rawargs, warnifnot=False) + s = getarg(lats, temptype, args, rawargs, warnifnot=False) + hemis = getarg(offparam, temptype, args, rawargs) + if hemis is None: + hemismult = get_hemisphere(temptype, is_lat) + else: + if is_lat: + convert = convert_ns + else: + convert = convert_ew + hemismult = convert.get(hemis, None) + if hemismult is None: + wikiwarning("%s for template type %s has bad value: [%s]" % + (offparam, temptype, hemis)) + return None + return convert_dms(hemismult, d, m, s) + +latd_arguments = ('latd', 'latg', 'latdeg', 'latdegrees', 'latitudedegrees', + 'latitudinegradi', 'latgradi', 'latitudined', 'latitudegraden', + 'breitengrad', 'breddegrad', 'breddegrad') +def get_latd_coord(temptype, args, rawargs): + '''Given a template of type TEMPTYPE with arguments ARGS (converted into +a hash table; also available in raw form as RAWARGS), assumed to have +a latitude/longitude specification in it using latd/lat_deg/etc. 
and +longd/lon_deg/etc., extract out and return a tuple of decimal +(latitude, longitude) values.''' + lat = get_lat_long_1(temptype, args, rawargs, + latd_arguments, + ('latm', 'latmin', 'latminutes', 'latitudeminutes', + 'latitudineprimi', 'latprimi', + 'latitudineminuti', 'latminuti', 'latitudinem', 'latitudeminuten', + 'breitenminute', 'breddemin'), + ('lats', 'latsec', 'latseconds', 'latitudeseconds', + 'latitudinesecondi', 'latsecondi', 'latitudines', 'latitudeseconden', + 'breitensekunde'), + ('latns', 'latp', 'lap', 'latdir', 'latdirection', 'latitudinens'), + is_lat=True) + long = get_lat_long_1(temptype, args, rawargs, + # Typos like Longtitude do occur in the Spanish Wikipedia at least + ('longd', 'lond', 'longg', 'long', 'longdeg', 'londeg', + 'longdegrees', 'londegrees', + 'longitudinegradi', 'longgradi', 'longitudined', + 'longitudedegrees', 'longtitudedegrees', + 'longitudegraden', + u'längengrad', 'laengengrad', 'lengdegrad', u'længdegrad'), + ('longm', 'lonm', 'longmin', 'lonmin', + 'longminutes', 'lonminutes', + 'longitudineprimi', 'longprimi', + 'longitudineminuti', 'longminuti', 'longitudinem', + 'longitudeminutes', 'longtitudeminutes', + 'longitudeminuten', + u'längenminute', u'længdemin'), + ('longs', 'lons', 'longsec', 'lonsec', + 'longseconds', 'lonseconds', + 'longitudinesecondi', 'longsecondi', 'longitudines', + 'longitudeseconds', 'longtitudeseconds', + 'longitudeseconden', + u'längensekunde'), + ('longew', 'lonew', 'longp', 'lonp', 'longdir', 'londir', + 'longdirection', 'londirection', 'longitudineew'), + is_lat=False) + return (lat, long) + +def get_built_in_lat_long_1(temptype, args, rawargs, latd, latm, lats, is_lat): + d = getarg(latd, temptype, args, rawargs) + m = getarg(latm, temptype, args, rawargs, warnifnot=False) + s = getarg(lats, temptype, args, rawargs, warnifnot=False) + return convert_dms(mult, d, m, s) + +built_in_latd_north_arguments = ('stopnin') +built_in_latd_south_arguments = ('stopnis') +built_in_longd_north_arguments = ('stopnie') +built_in_longd_south_arguments = ('stopniw') + +def get_built_in_lat_coord(temptype, args, rawargs): + '''Given a template of type TEMPTYPE with arguments ARGS (converted into +a hash table; also available in raw form as RAWARGS), assumed to have +a latitude/longitude specification in it using stopniN/etc. 
(where the +direction NSEW is built into the argument name), extract out and return a +tuple of decimal (latitude, longitude) values.''' + if getarg(built_in_latd_north_arguments, temptype, args, rawargs) is not None: + mult = 1 + elif getarg(built_in_latd_south_arguments, temptype, args, rawargs) is not None: + mult = -1 + else: + wikiwarning("Didn't see any appropriate stopniN/stopniS param") + mult = 1 # Arbitrarily set to N, probably accurate in Poland + lat = get_built_in_lat_long_1(temptype, args, rawargs, + ('stopnin', 'stopnis'), + ('minutn', 'minuts'), + ('sekundn', 'sekunds'), + mult) + if getarg(built_in_longd_north_arguments, temptype, args, rawargs) is not None: + mult = 1 + elif getarg(built_in_longd_south_arguments, temptype, args, rawargs) is not None: + mult = -1 + else: + wikiwarning("Didn't see any appropriate stopniE/stopniW param") + mult = 1 # Arbitrarily set to E, probably accurate in Poland + long = get_built_in_lat_long_1(temptype, args, rawargs, + ('stopnie', 'stopniw'), + ('minute', 'minutw'), + ('sekunde', 'sekundw'), + mult) + return (lat, long) + +latitude_arguments = ('latitude', 'latitud', 'latitudine', + 'breitengrad', + # 'breite', Sometimes used for latitudes but also for other types of width + #'lat' # Appears in non-article coordinates + #'latdec' # Appears to be associated with non-Earth coordinates + ) +longitude_arguments = ('longitude', 'longitud', 'longitudine', + u'längengrad', u'laengengrad', + # u'länge', u'laenge', Sometimes used for longitudes but also for other lengths + #'long' # Appears in non-article coordinates + #'longdec' # Appears to be associated with non-Earth coordinates + ) + +def get_latitude_coord(temptype, args, rawargs): + '''Given a template of type TEMPTYPE with arguments ARGS, assumed to have +a latitude/longitude specification in it, extract out and return a tuple of +decimal (latitude, longitude) values.''' + # German-style (e.g. 72/53/15/E) also occurs with 'latitude' and such, + # so just check for it everywhere. + lat = get_german_style_coord(getarg(latitude_arguments, + temptype, args, rawargs)) + long = get_german_style_coord(getarg(longitude_arguments, + temptype, args, rawargs)) + return (lat, long) + +def get_infobox_ort_coord(temptype, args, rawargs): + '''Given a template 'Infobox Ort' with arguments ARGS, assumed to have +a latitude/longitude specification in it, extract out and return a tuple of +decimal (latitude, longitude) values.''' + # German-style (e.g. 72/53/15/E) also occurs with 'latitude' and such, + # so just check for it everywhere. + lat = get_german_style_coord(getarg((u'breite',), + temptype, args, rawargs)) + long = get_german_style_coord(getarg((u'länge', u'laenge'), + temptype, args, rawargs)) + return (lat, long) + +# Utility function for get_coord(). Extract out the latitude or longitude +# values out of a Coord structure. Return a tuple (OFFSET, VAL) for decimal +# latitude or longitude VAL and OFFSET indicating the offset of the next +# argument after the arguments used to produce the value. +def get_coord_1(args, convert_nsew): + if args[1] in convert_nsew: + d = args[0]; m = 0; s = 0; i = 1 + elif args[2] in convert_nsew: + d = args[0]; m = args[1]; s = 0; i = 2 + elif args[3] in convert_nsew: + d = args[0]; m = args[1]; s = args[2]; i = 3 + else: + # Will happen e.g. in the style where only positive/negative are given + return (1, convert_dms(1, args[0], 0, 0)) + return (i+1, convert_dms(convert_nsew[args[i]], d, m, s)) + +# FIXME! 
To be more accurate, we need to look at the template parameters, +# which, despite the claim below, ARE quite interesting. In fact, if the +# parameter 'display=title' is seen (or variant like 'display=inline,title'), +# then we have *THE* correct coordinate for the article. So we need to +# return this fact if known, as an additional argument. See comments +# below at extract_coordinates_from_article(). + +def get_coord(temptype, args): + '''Parse a Coord template and return a tuple (lat,long) for latitude and +longitude. TEMPTYPE is the template name. ARGS is the raw arguments for +the template. Coord templates are one of four types: + +{{Coord|44.112|-87.913}} +{{Coord|44.112|N|87.913|W}} +{{Coord|44|6.72|N|87|54.78|W}} +{{Coord|44|6|43.2|N|87|54|46.8|W}} + +Note that all four of the above are equivalent. + +In addition, extra "template" or "coordinate" parameters can be given. +The template parameters mostly control display and are basically +uninteresting. (FIXME: Not true, see above.) However, the coordinate +parameters contain lots of potentially useful information that can be +used as features or whatever. See +http://en.wikipedia.org/wiki/Template:Coord for more information. + +The types of coordinate parameters are: + +type: country, city, city(###) where ### is the population, isle, river, etc. + Very useful feature; can also be used to filter uninteresting info as + some articles will have multiple coordinates in them. +scale: indicates the map scale (note that type: also specifies a default scale) +dim: diameter of viewing circle centered on coordinate (gives some sense of + how big the feature is) +region: the "political region for terrestrial coordinates", i.e. the country + the coordinate is in, as a two-letter ISO 3166-1 alpha-2 code, or the + country plus next-level subdivision (state, province, etc.) +globe: which planet or satellite the coordinate is on (esp. if not the Earth) +''' + if debug['some']: errprint("Coord: Passed in args %s" % args) + # Filter out optional "template arguments", add a bunch of blank arguments + # at the end to make sure we don't get out-of-bounds errors in + # get_coord_1() + filtargs = [x for x in args if '=' not in x] + if filtargs: + filtargs += ['','','','','',''] + (i, lat) = get_coord_1(filtargs, convert_ns) + (_, long) = get_coord_1(filtargs[i:], convert_ew) + return (lat, long) + else: + (paramshash, _) = find_template_params(args, True) + lat = paramshash.get('lat', None) or paramshash.get('latitude', None) + long = paramshash.get('long', None) or paramshash.get('longitude', None) + if lat is None or long is None: + wikiwarning("Can't find latitude/longitude in {{%s|%s}}" % + (temptype, '|'.join(args))) + lat = safe_float(lat) + long = safe_float(long) + return (lat, long) + +def check_for_bad_globe(paramshash): + if debug['some']: errprint("check_for_bad_globe: Passed in args %s" % paramshash) + globe = paramshash.get('globe', "").strip() + if globe: + if globe == "earth": + wikiwarning("Interesting, saw globe=earth") + else: + wikiwarning("Rejecting as non-earth, in template 'Coordinate/Coord/etc.' saw globe=%s" + % globe) + return True + return False + +def get_coordinate_coord(extract_coords_obj, temptype, rawargs): + '''Parse a Coordinate template and return a tuple (lat,long) for latitude and +longitude. TEMPTYPE is the template name. ARGS is the raw arguments for +the template. These templates tend to occur in the German Wikipedia. 
Examples: + +{{Coordinate|text=DMS|article=DMS|NS=51.50939|EW=-0.11832|type=city|pop=7825200|region=GB-LND}} +{{Coordinate|article=/|NS=41/00/00/N|EW=16/43/00/E|type=adm1st|region=IT-23}} +{Coordinate|NS=51/14/08.16/N|EW=6/48/37.43/E|text=DMS|name=Bronzetafel – Mittelpunkt Düsseldorfs|type=landmark|dim=50|region=DE-NW}} +{{Coordinate|NS=46.421401 <!-- {{subst:CH1903-WGS84|777.367|143.725||koor=B }} -->|EW=9.746124 <!-- {{subst:CH1903-WGS84|777.367|143.725||koor=L }} -->|region=CH-GR|text=DMS|type=isle|dim=500|name=Chaviolas}} +''' + if debug['some']: errprint("Passed in args %s" % rawargs) + (paramshash, _) = find_template_params(rawargs, True) + if check_for_bad_globe(paramshash): + extract_coords_obj.notearth = True + return (None, None) + lat = get_german_style_coord(getarg('ns', temptype, paramshash, rawargs)) + long = get_german_style_coord(getarg('ew', temptype, paramshash, rawargs)) + return (lat, long) + +def get_coord_params(temptype, args): + '''Parse a Coord template and return a list of tuples of coordinate +parameters (see comment under get_coord).''' + if debug['some']: errprint("Passed in args %s" % args) + # Filter out optional "template arguments" + filtargs = [x for x in args if '=' not in x] + if debug['some']: errprint("get_coord_params: filtargs: %s" % filtargs) + hash = {} + if filtargs and ':' in filtargs[-1]: + for x in filtargs[-1].split('_'): + if ':' in x: + (key, value) = x.split(':', 1) + hash[key] = value + return hash + +def get_geocoordenadas_coord(temptype, args): + '''Parse a geocoordenadas template (common in the Portuguese Wikipedia) and +return a tuple (lat,long) for latitude and longitude. TEMPTYPE is the +template name. ARGS is the raw arguments for the template. Typical example +is: + +{{geocoordenadas|39_15_34_N_24_57_9_E_type:waterbody|39° 15′ 34" N, 24° 57′ 9" O}} +''' + if debug['some']: errprint("Passed in args %s" % args) + # Filter out optional "template arguments", add a bunch of blank arguments + # at the end to make sure we don't get out-of-bounds errors in + # get_coord_1() + if len(args) == 0: + wikiwarning("No arguments to template 'geocoordenadas'") + return (None, None) + else: + # Yes, every one of the following problems occurs: Extra spaces; commas + # used instead of periods; lowercase nsew; use of O (Oeste) for "West", + # "L" (Leste) for "East" + arg = args[0].upper().strip().replace(',','.') + m = re.match(r'([0-9.]+)(?:_([0-9.]+))?(?:_([0-9.]+))?_([NS])_([0-9.]+)(?:_([0-9.]+))?(?:_([0-9.]+))?_([EWOL])(?:_.*)?$', arg) + if not m: + wikiwarning("Unrecognized argument %s to template 'geocoordenadas'" % + args[0]) + return (None, None) + else: + (latd, latm, lats, latns, longd, longm, longs, longew) = \ + m.groups() + return (convert_dms(convert_ns[latns], latd, latm, lats), + convert_dms(convert_ew[longew], longd, longm, longs)) + +class ExtractCoordinatesFromSource(RecursiveSourceTextHandler): + '''Given the article text TEXT of an article (in general, after first- +stage processing), extract coordinates out of templates that have coordinates +in them (Infobox, Coord, etc.). Record each coordinate into COORD. + +We don't recursively process text inside of templates or links. If we want +to do that, change this class to inherit from RecursiveSourceTextHandler. + +See process_article_text() for a description of the formatting that is +applied to the text before being sent here.''' + + def __init__(self): + self.coords = [] + self.notearth = False + + def process_template(self, text): + # Look for a Coord, Infobox, etc. 
template that may have coordinates in it + lat = long = None + if debug['some']: errprint("Enter process_template: [%s]" % text) + tempargs = get_macro_args(text) + temptype = tempargs[0].strip() + if debug['some']: errprint("Template type: %s" % temptype) + lowertemp = temptype.lower() + rawargs = tempargs[1:] + if (lowertemp.startswith('info/crater') or + lowertemp.endswith(' crater data') or + lowertemp.startswith('marsgeo') or + lowertemp.startswith('encelgeo') or + # All of the following are for heavenly bodies + lowertemp.startswith('infobox feature on ') or + lowertemp in (u'info/acidente geográfico de vênus', + u'infobox außerirdische region', + 'infobox lunar mare', 'encelgeo-crater', + 'infobox marskrater', 'infobox mondkrater', + 'infobox mondstruktur')): + self.notearth = True + wikiwarning("Rejecting as not on Earth because saw template %s" % temptype) + return [] + # Look for a coordinate template + if lowertemp in ('coord', 'coordp', 'coords', + 'koord', #Norwegian + 'coor', 'coor d', 'coor dm', 'coor dms', + 'coor title d', 'coor title dm', 'coor title dms', + 'coor dec', 'coorheader') \ + or lowertemp.startswith('geolinks') \ + or lowertemp.startswith('mapit') \ + or lowertemp.startswith('koordynaty'): # Coordinates in Polish: + (lat, long) = get_coord(temptype, rawargs) + coord_params = get_coord_params(temptype, tempargs[1:]) + if check_for_bad_globe(coord_params): + self.notearth = True + return [] + elif lowertemp == 'coordinate': + (lat, long) = get_coordinate_coord(self, temptype, rawargs) + elif lowertemp in ('geocoordenadas', u'coördinaten'): + # geocoordenadas is Portuguese, coördinaten is Dutch, and they work + # the same way + (lat, long) = get_geocoordenadas_coord(temptype, rawargs) + else: + # Look for any other template with a 'latd' or 'latitude' parameter. + # Usually these will be Infobox-type templates. Possibly we should only + # look at templates whose lowercased name begins with "infobox". + (paramshash, _) = find_template_params(rawargs, True) + if getarg(latd_arguments, temptype, paramshash, rawargs, warnifnot=False) is not None: + #errprint("seen: [%s] in {{%s|%s}}" % (getarg(latd_arguments, temptype, paramshash, rawargs), temptype, rawargs)) + (lat, long) = get_latd_coord(temptype, paramshash, rawargs) + # NOTE: DO NOT CHANGE ORDER. We want to check latd first and check + # latitude afterwards for various reasons (e.g. so that cases where + # breitengrad and breitenminute occur get found). FIXME: Maybe we + # don't need get_latitude_coord at all, but get_latd_coord will + # suffice. 
+ elif getarg(latitude_arguments, temptype, paramshash, rawargs, warnifnot=False) is not None: + #errprint("seen: [%s] in {{%s|%s}}" % (getarg(latitude_arguments, temptype, paramshash, rawargs), temptype, rawargs)) + (lat, long) = get_latitude_coord(temptype, paramshash, rawargs) + elif (getarg(built_in_latd_north_arguments, temptype, paramshash, + rawargs, warnifnot=False) is not None or + getarg(built_in_latd_south_arguments, temptype, paramshash, + rawargs, warnifnot=False) is not None): + #errprint("seen: [%s] in {{%s|%s}}" % (getarg(built_in_latd_north_arguments, temptype, paramshash, rawargs), temptype, rawargs)) + #errprint("seen: [%s] in {{%s|%s}}" % (getarg(built_in_latd_south_arguments, temptype, paramshash, rawargs), temptype, rawargs)) + (lat, long) = get_built_in_lat_coord(temptype, paramshash, rawargs) + elif lowertemp in ('infobox ort', 'infobox verwaltungseinheit'): + (lat, long) = get_infobox_ort_coord(temptype, paramshash, rawargs) + + if debug['some']: wikiwarning("Saw coordinate %s,%s in template type %s" % + (lat, long, temptype)) + if lat is None and long is not None: + wikiwarning("Saw longitude %s but no latitude in template: %s" % + (long, bound_string_length(text))) + if long is None and lat is not None: + wikiwarning("Saw latitude %s but no longitude in template: %s" % + (lat, bound_string_length(text))) + if lat is not None and long is not None: + if lat == 0.0 and long == 0.0: + wikiwarning("Rejecting coordinate because zero latitude and longitude seen") + elif lat > 90.0 or lat < -90.0 or long > 180.0 or long < -180.0: + wikiwarning("Rejecting coordinate because out of bounds latitude or longitude: (%s,%s)" % (lat, long)) + else: + if lat == 0.0 or long == 0.0: + wikiwarning("Zero value in latitude and/or longitude: (%s,%s)" % + (lat, long)) + self.coords.append((lowertemp,lat,long)) + templates_with_coords[lowertemp] += 1 + # Recursively process the text inside the template in case there are + # coordinates in it. + return self.process_source_text(text[2:-2]) + +#category_types = [ +# ['neighbourhoods', 'neighborhood'], +# ['neighborhoods', 'neighborhood'], +# ['mountains', 'mountain'], +# ['stations', ('landmark', 'railwaystation')], +# ['rivers', 'river'], +# ['islands', 'isle'], +# ['counties', 'adm2nd'], +# ['parishes', 'adm2nd'], +# ['municipalities', 'city'], +# ['communities', 'city'], +# ['towns', 'city'], +# ['villages', 'city'], +# ['hamlets', 'city'], +# ['communes', 'city'], +# ['suburbs', 'city'], +# ['universities', 'edu'], +# ['colleges', 'edu'], +# ['schools', 'edu'], +# ['educational institutions', 'edu'], +# ['reserves', '?'], +# ['buildings', '?'], +# ['structures', '?'], +# ['landfills' '?'], +# ['streets', '?'], +# ['museums', '?'], +# ['galleries', '?'] +# ['organizations', '?'], +# ['groups', '?'], +# ['lighthouses', '?'], +# ['attractions', '?'], +# ['border crossings', '?'], +# ['forts', '?'], +# ['parks', '?'], +# ['townships', '?'], +# ['cathedrals', '?'], +# ['skyscrapers', '?'], +# ['waterfalls', '?'], +# ['caves', '?'], +# ['beaches', '?'], +# ['cemeteries'], +# ['prisons'], +# ['territories'], +# ['states'], +# ['countries'], +# ['dominions'], +# ['airports', 'airport'], +# ['bridges'], +# ] + + +class ExtractLocationTypeFromSource(RecursiveSourceTextHandler): + '''Given the article text TEXT of an article (in general, after first- +stage processing), extract info about the type of location (if any). 
+Record info found in 'loctype'.''' + + def __init__(self): + self.loctype = [] + self.categories = [] + + def process_internal_link(self, text): + tempargs = get_macro_args(text) + arg0 = tempargs[0].strip() + if arg0.startswith('Category:'): + self.categories += [arg0[9:].strip()] + return self.process_source_text(text[2:-2]) + + def process_template(self, text): + # Look for a Coord, Infobox, etc. template that may have coordinates in it + lat = long = None + tempargs = get_macro_args(text) + temptype = tempargs[0].strip() + lowertemp = temptype.lower() + # Look for a coordinate template + if lowertemp in ('coord', 'coor d', 'coor dm', 'coor dms', + 'coor dec', 'coorheader') \ + or lowertemp.startswith('geolinks') \ + or lowertemp.startswith('mapit'): + params = get_coord_params(temptype, tempargs[1:]) + if params: + # WARNING, this returns a hash table, not a list of tuples + # like the others do below. + self.loctype += [['coord-params', params]] + else: + (paramshash, _) = find_template_params(tempargs[1:], True) + if lowertemp == 'infobox settlement': + params = [] + for x in ['settlementtype', + 'subdivisiontype', 'subdivisiontype1', 'subdivisiontype2', + 'subdivisionname', 'subdivisionname1', 'subdivisionname2', + 'coordinatestype', 'coordinatesregion']: + val = paramshash.get(x, None) + if val: + params += [(x, val)] + self.loctype += [['infobox-settlement', params]] + elif ('latd' in paramshash or 'latdeg' in paramshash or + 'latitude' in paramshash): + self.loctype += \ + [['other-template-with-coord', [('template', temptype)]]] + # Recursively process the text inside the template in case there are + # coordinates in it. + return self.process_source_text(text[2:-2]) + +####################################################################### +# Process text for words # +####################################################################### + +# For a "macro" (e.g. internal link or template) with arguments, and +# a generator that returns the interesting arguments separately, process +# each of these arguments into chunks, join the chunks of an argument back +# together, and join the processed arguments, with spaces separating them. +# The idea is that for something like +# +# The [[latent variable|hidden node]]s are ... +# +# We will ultimately get something like +# +# The latent variable hidden nodes are ... +# +# after joining chunks. (Even better would be to correct handle something +# like +# +# The sub[[latent variable|node]]s are ... +# +# into +# +# The latent variable subnodes are ... +# +# But that's a major hassle, and such occurrences should be rare.) + +# Process an internal link into separate chunks for each interesting +# argument. Yield the chunks. They will be recursively processed, and +# joined by spaces. 
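A hedged illustration of the internal-link handling defined below (output shown for the default options; --raw-text changes the first case):

list(yield_internal_link_args('[[latent variable|hidden node]]'))
# => ['latent variable', 'hidden node']   (default: link target plus anchor text)
# => ['hidden node']                      (with --raw-text: anchor text only)

list(yield_internal_link_args('[[fr:Paris]]'))
# => []   (two- or three-letter interlanguage prefixes are skipped)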
+def yield_internal_link_args(text): + tempargs = get_macro_args(text) + m = re.match(r'(?s)\s*([a-zA-Z0-9_]+)\s*:(.*)', tempargs[0]) + if m: + # Something like [[Image:...]] or [[wikt:...]] or [[fr:...]] + namespace = m.group(1).lower() + namespace = article_namespaces_lower.get(namespace, namespace) + if namespace in ('image', 6): # 6 = file + # For image links, filter out non-interesting args + for arg in tempargs[1:]: + # Ignore uninteresting args + if re.match(r'thumb|left|(up)?right|[0-9+](\s*px)?$', arg.strip()): pass + # For alt text, ignore the alt= but use the rest + else: + # Look for parameter spec + m = re.match(r'(?s)\s*([a-zA-Z0-9_]+)\s*=(.*)', arg) + if m: + (param, value) = m.groups() + if param.lower() == 'alt': + yield value + # Skip other parameters + # Use non-parameter args + else: yield arg + elif len(namespace) == 2 or len(namespace) == 3 or namespace == 'simple': + # A link to the equivalent page in another language; foreign words + # probably won't help for word matching. However, this might be + # useful in some other way. + pass + else: + # Probably either a category or wikt (wiktionary). + # The category is probably useful; the wiktionary entry maybe. + # In both cases, go ahead and use. + link = m.group(2) + # Skip "Appendix:" in "wikt:Appendix" + m = re.match(r'(?s)\s*[Aa]ppendix\s*:(.*)', link) + if m: yield m.group(1) + else: yield link + for arg in tempargs[1:]: yield arg + else: + # For textual internal link, use all arguments, unless --raw-text + if Opts.raw_text: + yield tempargs[-1] + else: + for chunk in tempargs: yield chunk + +# Process a template into separate chunks for each interesting +# argument. Yield the chunks. They will be recursively processed, and +# joined by spaces. +def yield_template_args(text): + # For a template, do something smart depending on the template. + if debug['lots']: errprint("yield_template_args called with: %s" % text) + + # OK, this is a hack, but a useful one. There are lots of templates that + # look like {{Emancipation Proclamation draft}} or + # {{Time measurement and standards}} or similar that are useful as words. + # So we look for templates without arguments that look like this. + # Note that we require the first word to have at least two letters, so + # we filter out things like {{R from related word}} or similar redirection- + # related indicators. Note that similar-looking templates that begin with + # a lowercase letter are sometimes useful like {{aviation lists}} or + # {{global warming}} but often are non-useful things like {{de icon}} or + # {{nowrap begin}} or {{other uses}}. Potentially we could be smarter + # about this. + if re.match(r'{{[A-Z][a-z]+ [A-Za-z ]+}}$', text): + yield text[2:-2] + return + + tempargs = get_macro_args(text) + if debug['lots']: errprint("template args: %s" % tempargs) + temptype = tempargs[0].strip().lower() + + if debug['some']: + all_templates[temptype] += 1 + + # Extract the parameter and non-parameter arguments. + (paramhash, nonparam) = find_template_params(tempargs[1:], False) + #errprint("params: %s" % paramhash) + #errprint("nonparam: %s" % nonparam) + + # For certain known template types, use the values from the interesting + # parameter args and ignore the others. For other template types, + # assume the parameter are uninteresting. + if re.match(r'v?cite', temptype): + # A citation, a very common type of template. + for (key,value) in paramhash.items(): + # A fairly arbitrary list of "interesting" parameters. 
+ if re.match(r'(last|first|authorlink)[1-9]?$', key) or \ + re.match(r'(author|editor)[1-9]?-(last|first|link)$', key) or \ + key in ('coauthors', 'others', 'title', 'transtitle', + 'quote', 'work', 'contribution', 'chapter', 'transchapter', + 'series', 'volume'): + yield value + elif re.match(r'infobox', temptype): + # Handle Infoboxes. + for (key,value) in paramhash.items(): + # A fairly arbitrary list of "interesting" parameters. + # Remember that _ and space are removed. + if key in ('name', 'fullname', 'nickname', 'altname', 'former', + 'alt', 'caption', 'description', 'title', 'titleorig', + 'imagecaption', 'imagecaption', 'mapcaption', + # Associated with states, etc. + 'motto', 'mottoenglish', 'slogan', 'demonym', 'capital', + # Add more here + ): + yield value + elif re.match(r'coord', temptype): + return + + # For other template types, ignore all parameters and yield the + # remaining arguments. + # Yield any non-parameter arguments. + for arg in nonparam: + yield arg + +# Process a table into separate chunks. Unlike code for processing +# internal links, the chunks should have whitespace added where necessary. +def yield_table_chunks(text): + if debug['lots']: errprint("Entering yield_table_chunks: [%s]" % text) + + # Given a single line or part of a line, and an indication (ATSTART) of + # whether we just saw a beginning-of-line separator, split on within-line + # separators (|| or !!) and remove table directives that can occur at + # the beginning of a field (terminated by a |). Yield the resulting + # arguments as chunks. + def process_table_chunk_1(text, atstart): + for arg in re.split(r'(?:\|\||!!)', text): + if atstart: + m = re.match('(?s)[^|]*\|(.*)', arg) + if m: + yield m.group(1) + ' ' + continue + yield arg + atstart = True + + # Just a wrapper function around process_table_chunk_1() for logging + # purposes. + def process_table_chunk(text, atstart): + if debug['lots']: errprint("Entering process_table_chunk: [%s], %s" % (text, atstart)) + for chunk in process_table_chunk_1(text, atstart): + if debug['lots']: errprint("process_table_chunk yields: [%s]" % chunk) + yield chunk + + # Strip off {| and |} + text = text[2:-2] + ignore_text = True + at_line_beg = False + + # Loop over balanced chunks, breaking top-level text at newlines. + # Strip out notations like | and |- that separate fields, and strip out + # table directives (e.g. which occur after |-). Pass the remainder to + # process_table_chunk(), which will split a line on within-line separators + # (e.g. || or !!) and strip out directives. + for arg in parse_balanced_text(balanced_table_re, text): + if debug['lots']: errprint("parse_balanced_text(balanced_table_re) yields: [%s]" % arg) + # If we see a newline, reset the flags and yield the newline. This way, + # a whitespace will always be inserted. + if arg == '\n': + ignore_text = False + at_line_beg = True + yield arg + if at_line_beg: + if arg.startswith('|-'): + ignore_text = True + continue + elif arg.startswith('|') or arg.startswith('!'): + arg = arg[1:] + if arg and arg[0] == '+': arg = arg[1:] + # The chunks returned here are separate fields. Make sure whitespace + # separates them. + yield ' '.join(process_table_chunk(arg, True)) + continue + elif ignore_text: continue + # Add whitespace between fields, as above. + yield ' '.join(process_table_chunk(arg, False)) + +# Given raw text, split it into words, filtering out punctuation, and +# yield the words. Also ignore words with a colon in the middle, indicating +# likely URL's and similar directives. 
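Before moving on to word splitting, a quick illustration of yield_template_args() above on a made-up citation (field values are examples only; only the "interesting" citation parameters survive):

list(yield_template_args(
    '{{cite web|url=http://example.org|last=Smith|first=John|title=An Example Page}}'))
# => ['Smith', 'John', 'An Example Page']   (in hash order; 'url' is not an
#    interesting parameter and is dropped)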
+def split_text_into_words(text): + (text, _) = re.subn(left_ref_re, r' ', text) + if Opts.no_tokenize: + # No tokenization requested. Just split on whitespace. But still try + # to eliminate URL's. Rather than just look for :, we look for :/, which + # URL's are likely to contain. Possibly we should look for a colon in + # the middle of a word, which is effectively what the checks down below + # do (or modify those checks to look for :/). + for word in re.split('\s+', text): + if ':/' not in word: + yield word + elif Opts.raw_text: + # This regexp splits on whitespace, but also handles the following cases: + # 1. Any of , ; . etc. at the end of a word + # 2. Parens or quotes in words like (foo) or "bar" + off = 0 + for word in re.split(r'([,;."):]*)\s+([("]*)', text): + if (off % 3) != 0: + for c in word: + yield c + else: + # Sometimes URL's or other junk slips through. Much of this junk has + # a colon in it and little useful stuff does. + if ':' not in word: + # Handle things like "Two-port_network#ABCD-parameters". Do this after + # filtering for : so URL's don't get split up. + for word2 in re.split('[#_]', word): + if word2: yield word2 + off += 1 + else: + # This regexp splits on whitespace, but also handles the following cases: + # 1. Any of , ; . etc. at the end of a word + # 2. Parens or quotes in words like (foo) or "bar" + for word in re.split(r'[,;."):]*\s+[("]*', text): + # Sometimes URL's or other junk slips through. Much of this junk has + # a colon in it and little useful stuff does. + if ':' not in word: + # Handle things like "Two-port_network#ABCD-parameters". Do this after + # filtering for : so URL's don't get split up. + for word2 in re.split('[#_]', word): + if word2: yield word2 + +# Extract "useful" text (generally, text that will be seen by the user, +# or hidden text of similar quality) and yield up chunks. + +class ExtractUsefulText(SourceTextHandler): + def process_and_join_arguments(self, args_of_macro): + return ' '.join(''.join(self.process_source_text(chunk)) + for chunk in args_of_macro) + + def process_internal_link(self, text): + '''Process an internal link into chunks of raw text and yield them.''' + # Find the interesting arguments of an internal link and join + # with spaces. + yield self.process_and_join_arguments(yield_internal_link_args(text)) + + def process_template(self, text): + '''Process a template into chunks of raw text and yield them.''' + # Find the interesting arguments of a template and join with spaces. 
+    yield self.process_and_join_arguments(yield_template_args(text))
+
+  def process_table(self, text):
+    '''Process a table into chunks of raw text and yield them.'''
+    for bar in yield_table_chunks(text):
+      if debug['lots']: errprint("process_table yields: [%s]" % bar)
+      for baz in self.process_source_text(bar):
+        yield baz
+
+  def process_external_link(self, text):
+    '''Process an external link into chunks of raw text and yield them.'''
+    # For an external link, use the anchor text of the link, if any
+    splitlink = re.split(r'\s+', text[1:-1], 1)
+    if len(splitlink) == 2:
+      (link, linktext) = splitlink
+      for chunk in self.process_source_text(linktext):
+        yield chunk
+
+  def process_reference(self, text):
+    return self.process_source_text(" " + text[5:-6] + " ")
+
+#######################################################################
+# Formatting text to make processing easier #
+#######################################################################
+
+# Process the text in various ways in preparation for extracting data
+# from the text.
+def format_text_first_pass(text):
+  # Remove all comments from the text; may contain malformed stuff of
+  # various sorts, and generally stuff we don't want to index
+  (text, _) = re.subn(r'(?s)<!--.*?-->', '', text)
+
+  # Get rid of all text inside of <math>...</math>, which is in a different
+  # format (TeX), and mostly non-useful.
+  (text, _) = re.subn(r'(?s)<math>.*?</math>', '', text)
+
+  # Try getting rid of everything in a reference
+  #(text, _) = re.subn(r'(?s)<ref.*?>.*?</ref>', '', text)
+  #(text, _) = re.subn(r'(?s)<ref[^<>/]*?/>', '', text)
+
+  # Convert occurrences of &nbsp; and &ndash; and similar, which occur often
+  # (note that SAX itself should handle entities like this; occurrences that
+  # remain must have had the ampersand converted to &amp;)
+  (text, _) = re.subn(r'&nbsp;', ' ', text)
+  (text, _) = re.subn(r'&#160;', ' ', text)
+  (text, _) = re.subn(r'&[nm]dash;', '-', text)
+  (text, _) = re.subn(r'&minus;', '-', text)
+  (text, _) = re.subn(r'&amp;', '&', text)
+  (text, _) = re.subn(r'&times;', '*', text)
+  (text, _) = re.subn(r'&hellip;', '...', text)
+  (text, _) = re.subn(r'&lt;', '<', text)
+  (text, _) = re.subn(r'&gt;', '>', text)
+  #(text, _) = re.subn(r'&#91;', '[', text)
+  #(text, _) = re.subn(r'&#93;', ']', text)
+
+  return text
+
+# Process the text in various ways in preparation for extracting
+# the words from the text.
+def format_text_second_pass(text):
+  # Convert breaks into newlines
+  (text, _) = re.subn(r'<br */?>', r'\n', text)
+
+  # Remove references, but convert to whitespace to avoid concatenating
+  # words outside and inside a reference together
+  #(text, _) = re.subn(r'(?s)<ref.*?</ref>', ' ', text)
+
+  # An alternative approach.
+  # Convert references to simple tags.
+  (text, _) = re.subn(r'(?s)<ref[^>]*?/>', ' ', text)
+  (text, _) = re.subn(r'(?s)<ref.*?>', '< ref>', text)
+  (text, _) = re.subn(r'(?s)</ref>', '< /ref>', text)
+
+  # Similar for nowiki, which may have <'s, brackets and such inside.
+  (text, _) = re.subn(r'(?s)<nowiki>.*?</nowiki>', ' ', text)
+
+  # Another hack: Inside of <gallery>...</gallery>, there are raw filenames.
+  # Get rid of.
+
+  def process_gallery(text):
+    # Split on gallery blocks (FIXME, recursion not handled). Putting a
+    # group around the split text ensures we get it returned along with the
+    # other text.
+    chunks = re.split(r'(?s)(<gallery.*?>.*?</gallery>)', text)
+    for chunk in chunks:
+      # If a gallery, extract the stuff inside <gallery>...</gallery>
+      m = re.match(r'^(?s)<gallery.*?>(.*?)</gallery>$', chunk)
+      if m:
+        chunk = m.group(1)
+        # ...
then remove files and images, but keep any text after | + (chunk, _) = re.subn(r'(?m)^(?:File|Image):[^|\n]*$', '', chunk) + (chunk, _) = re.subn(r'(?m)^(?:File|Image):[^|\n]*\|(.*)$', + r'\1', chunk) + yield chunk + + text = ''.join(process_gallery(text)) + + # Remove remaining HTML codes from the text + (text, _) = re.subn(r'(?s)<[A-Za-z/].*?>', '', text) + + (text, _) = re.subn(r'< (/?ref)>', r'<\1>', text) + + # Remove multiple sequences of quotes (indicating boldface or italics) + (text, _) = re.subn(r"''+", '', text) + + # Remove beginning-of-line markers indicating indentation, lists, headers, + # etc. + (text, _) = re.subn(r"(?m)^[*#:]+", '', text) + + # Remove end-of-line markers indicating headers (e.g. ===Introduction===) + (text, _) = re.subn(r"(?m)^=+(.*?)=+$", r'\1', text) + + return text + +####################################################################### +# Article handlers # +####################################################################### + + + +### Default handler class for processing article text. Subclass this to +### implement your own handlers. +class ArticleHandler(object): + def __init__(self): + self.title = None + self.id = None + + redirect_commands = "|".join([ + # English, etc. + 'redirect', 'redirect to', + # Italian (IT) + 'rinvia', 'rinvio', + # Polish (PL) + 'patrz', 'przekieruj', 'tam', + # Dutch (NL) + 'doorverwijzing', + # French (FR) + 'redirection', + # Spanish (ES) + u'redirección', + # Portuguese (PT) + 'redirecionamento', + # German (DE) + 'weiterleitung', + # Russian (RU) + u'перенаправление', + ]) + + global redirect_re + redirect_re = re.compile(ur'(?i)#(?:%s)\s*:?\s*\[\[(.*?)\]\]' % + redirect_commands) + + # Process the text of article TITLE, with text TEXT. The default + # implementation does the following: + # + # 1. Remove comments, math, and other unuseful stuff. + # 2. If article is a redirect, call self.process_redirect() to handle it. + # 3. Else, call self.process_text_for_data() to extract data out. + # 4. If that handler returned True, call self.process_text_for_text() + # to do processing of the text itself (e.g. for words). + + def process_article_text(self, text, title, id, redirect): + self.title = title + self.id = id + global debug_cur_title + debug_cur_title = title + + if debug['some']: + errprint("Article title: %s" % title) + errprint("Article ID: %s" % id) + errprint("Article is redirect: %s" % redirect) + errprint("Original article text:\n%s" % text) + + ### Preliminary processing of text, removing stuff unuseful even for + ### extracting data. + + text = format_text_first_pass(text) + + ### Look to see if the article is a redirect + + if redirect: + m = redirect_re.match(text.strip()) + if m: + self.process_redirect(m.group(1)) + # NOTE: There may be additional templates specified along with a + # redirection page, typically something like {{R from misspelling}} + # that gives the reason for the redirection. Currently, we ignore + # such templates. + return + else: + wikiwarning( + "Article %s (ID %s) is a redirect but can't parse redirect spec %s" + % (title, id, text)) + + ### Extract the data out of templates; if it returns True, also process + ### text for words + + if self.process_text_for_data(text): + self.process_text_for_text(text) + + # Process the text itself, e.g. for words. Default implementation does + # nothing. + def process_text_for_text(self, text): + pass + + # Process an article that is a redirect. Default implementation does + # nothing. 
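process_article_text() above matches redirect_re against the article text and, on a match, hands the link target to process_redirect(), defined next. A minimal standalone sketch of that matching, using a reduced English-only command list in place of the full multilingual list built above:

    import re

    # Reduced command list; the real redirect_re covers many languages.
    redirect_commands = "|".join(['redirect', 'redirect to'])
    redirect_re = re.compile(r'(?i)#(?:%s)\s*:?\s*\[\[(.*?)\]\]' % redirect_commands)

    for sample in ['#REDIRECT [[Austin, Texas]]',
                   '#redirect: [[Travis County, Texas]]',
                   'Plain article text, not a redirect']:
        m = redirect_re.match(sample.strip())
        print(m.group(1) if m else None)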
+ + def process_redirect(self, redirtitle): + pass + + # Process the text and extract data. Return True if further processing of + # the article should happen. (Extracting the real text in + # process_text_for_text() currently takes up the vast majority of running + # time, so skipping it is a big win.) + # + # Default implementation just returns True. + + def process_text_for_data(self, text): + return True + + def finish_processing(self): + pass + + + +### Default handler class for processing article text, including returning +### "useful" text (what the Wikipedia user sees, plus similar-quality +### hidden text). +class ArticleHandlerForUsefulText(ArticleHandler): + # Process the text itself, e.g. for words. Input it text that has been + # preprocessed as described above (remove comments, etc.). Default + # handler does two things: + # + # 1. Further process the text (see format_text_second_pass()) + # 2. Use process_source_text() to extract chunks of useful + # text. Join together and then split into words. Pass the generator + # of words to self.process_text_for_words(). + + def process_text_for_text(self, text): + # Now process the text in various ways in preparation for extracting + # the words from the text + text = format_text_second_pass(text) + # Now process the resulting text into chunks. Join them back together + # again (to handle cases like "the [[latent variable]]s are ..."), and + # split to find words. + self.process_text_for_words( + split_text_into_words( + ''.join(ExtractUsefulText().process_source_text(text)))) + + # Process the real words of the text of an article. Default implementation + # does nothing. + + def process_text_for_words(self, word_generator): + pass + + + +# Print out the info passed in for article words; as for the implementation of +# process_text_for_data(), uses ExtractCoordinatesFromSource() to extract +# coordinates, and outputs all the coordinates seen. Always returns True. + +class OutputAllWords(ArticleHandlerForUsefulText): + def process_text_for_words(self, word_generator): + splitprint("Article title: %s" % self.title) + splitprint("Article ID: %s" % self.id) + for word in word_generator: + if debug['some']: errprint("Saw word: %s" % word) + else: splitprint("%s" % word) + + def process_text_for_data(self, text): + #handler = ExtractCoordinatesFromSource() + #for foo in handler.process_source_text(text): pass + #for (temptype,lat,long) in handler.coords: + # splitprint("Article coordinates: %s,%s" % (lat, long)) + return True + + def finish_processing(self): + ### Output all of the templates that were seen with coordinates in them, + ### along with counts of how many times each template was seen. + if debug['some']: + print("Templates with coordinates:") + output_reverse_sorted_table(templates_with_coords, + outfile=cur_output_file) + + print("All templates:") + output_reverse_sorted_table(all_templates, outfile=cur_output_file) + + print "Notice: ending processing" + + +class OutputCoordWords(OutputAllWords): + def process_text_for_data(self, text): + if extract_coordinates_from_article(text): + return True + return False + + +# Just find redirects. 
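The FindRedirects handler below writes one three-line record per redirect via splitprint(). Assuming exactly the 'Article title:' / 'Article ID:' / 'Redirect to:' prefixes used below (and ignoring how splitprint() routes output across split files), a sketch of reading such records back might look like this:

    def read_redirect_records(lines):
        # Group 'Article title:' / 'Article ID:' / 'Redirect to:' lines into dicts.
        record = {}
        for line in lines:
            key, sep, value = line.rstrip('\n').partition(': ')
            if not sep:
                continue
            record[key] = value
            if key == 'Redirect to':
                yield record
                record = {}

    # Toy input in the format the handler emits.
    sample = ['Article title: Example title\n',
              'Article ID: 12345\n',
              'Redirect to: Example target\n']
    for rec in read_redirect_records(sample):
        print(rec)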
+ +class FindRedirects(ArticleHandler): + def process_redirect(self, redirtitle): + splitprint("Article title: %s" % self.title) + splitprint("Article ID: %s" % self.id) + splitprint("Redirect to: %s" % redirtitle) + +def output_title(title, id): + splitprint("Article title: %s" % title) + splitprint("Article ID: %s" % id) + +def output_title_and_coordinates(title, id, lat, long): + output_title(title, id) + splitprint("Article coordinates: %s,%s" % (lat, long)) + +# FIXME: +# +# (1) Figure out whether coordinates had a display=title in them. +# If so, use the last one. +# (2) Else, use the last other Coord, but possibly limit to Coords that +# appear on a line by themselves or at least are at top level (not +# inside some other template, table, etc.). +# (3) Else, do what we prevously did. +# +# Also, we should test to see whether it's better in (2) to limit Coords +# to those that apear on a line by themselves. To do that, we'd generate +# coordinates for Wikipedia, and in the process note +# +# (1) Whether it was step 1, 2 or 3 above that produced the coordinate; +# (2) If step 2, would the result have been different if we did step 2 +# differently? Check the possibilities: No limit in step 2; +# (maybe, if not too hard) limit to those things at top level; +# limit to be on line by itself; don't ever use Coords in step 2. +# If there is a difference among the results of any of these strategies +# debug-output this fact along with the different values and the +# strategies that produced them. +# +# Then +# +# (1) Output counts of how many resolved through steps 1, 2, 3, and how +# many in step 2 triggered a debug-output. +# (2) Go through manually and check e.g. 50 of the ones with debug-output +# and see which one is more correct. + +def extract_coordinates_from_article(text): + handler = ExtractCoordinatesFromSource() + for foo in handler.process_source_text(text): pass + if handler.notearth: + return None + elif len(handler.coords) > 0: + # Prefer a coordinate specified using {{Coord|...}} or similar to + # a coordinate in an Infobox, because the latter tend to be less + # accurate. + for (temptype, lat, long) in handler.coords: + if temptype.startswith('coor'): + return (lat, long) + (temptype, lat, long) = handler.coords[0] + return (lat, long) + else: return None + +def extract_and_output_coordinates_from_article(title, id, text): + retval = extract_coordinates_from_article(text) + if retval == None: return False + (lat, long) = retval + output_title_and_coordinates(title, id, lat, long) + return True + +def extract_location_type(text): + handler = ExtractLocationTypeFromSource() + for foo in handler.process_source_text(text): pass + for (ty, vals) in handler.loctype: + splitprint(" %s: %s" % (ty, vals)) + for cat in handler.categories: + splitprint(" category: %s" % cat) + +# Handler to output count information on words. Only processes articles +# with coordinates in them. Computes the count of each word in the article +# text, after filtering text for "actual text" (as opposed to directives +# etc.), and outputs the counts. 
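A minimal sketch of the counting step this handler performs, with collections.Counter standing in for the script's intdict() and the reverse-sorted dump only approximating what output_reverse_sorted_table() produces:

    from collections import Counter

    def count_words(word_generator):
        counts = Counter(w for w in word_generator if w)
        # Dump counts in descending order of frequency.
        for word, count in counts.most_common():
            print('%s = %s' % (word, count))

    count_words(iter(['austin', 'texas', 'austin', '', 'river']))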
+ +class OutputCoordCounts(ArticleHandlerForUsefulText): + def process_text_for_words(self, word_generator): + wordhash = intdict() + for word in word_generator: + if word: wordhash[word] += 1 + output_reverse_sorted_table(wordhash, outfile=cur_output_file) + + def process_text_for_data(self, text): + if extract_coordinates_from_article(text): + output_title(self.title, self.id) + return True + return False + +# Same as above but output counts for all articles, not just those with +# coordinates in them. + +class OutputAllCounts(OutputCoordCounts): + def process_text_for_data(self, text): + output_title(self.title, self.id) + return True + +# Handler to output just coordinate information. +class OutputCoords(ArticleHandler): + def process_text_for_data(self, text): + return extract_and_output_coordinates_from_article(self.title, self.id, + text) + +# Handler to try to determine the type of an article with coordinates. +class OutputLocationType(ArticleHandler): + def process_text_for_data(self, text): + iscoord = extract_and_output_coordinates_from_article(self.title, self.id, + text) + if iscoord: + extract_location_type(text) + return iscoord + + +class ToponymEvalDataHandler(ExtractUsefulText): + def join_arguments_as_generator(self, args_of_macro): + first = True + for chunk in args_of_macro: + if not first: yield ' ' + first = False + for chu in self.process_source_text(chunk): + yield chu + + # OK, this is a bit tricky. The definitions of process_template() and + # process_internal_link() in ExtractUsefulText() use yield_template_args() + # and yield_internal_link_args(), respectively, to yield arguments, and + # then call process_source_text() to recursively process the arguments and + # then join everything together into a string, with spaces between the + # chunks corresponding to separate arguments. The joining together + # happens inside of process_and_join_arguments(). This runs into problems + # if we have an internal link inside of another internal link, which often + # happens with images, which are internal links that have an extra caption + # argument, which frequently contains (nested) internal links. The + # reason is that we've overridden process_internal_link() to sometimes + # return a tuple (which signals the outer handler that we found a link + # of the appropriate sort), and the joining together chokes on non-string + # arguments. So instead, we "join" arguments by just yielding everything + # in sequence, with spaces inserted as needed between arguments; this + # happens in join_arguments_as_generator(). We specifically need to + # override process_template() (and already override process_internal_link()), + # because it's exactly those two that currently call + # process_and_join_arguments(). + # + # The idea is that we never join arguments together at any level of + # recursion, but just yield chunks. At the topmost level, we will join + # as necessary and resplit for word boundaries. + + def process_template(self, text): + for chunk in self.join_arguments_as_generator(yield_template_args(text)): + yield chunk + + def process_internal_link(self, text): + tempargs = get_macro_args(text) + m = re.match(r'(?s)\s*([a-zA-Z0-9_]+)\s*:(.*)', tempargs[0]) + if m: + # Something like [[Image:...]] or [[wikt:...]] or [[fr:...]] + # For now, just skip them all; eventually, might want to do something + # useful with some, e.g. 
categories + pass + else: + article = capfirst(tempargs[0]) + # Skip links to articles without coordinates + if coordinate_articles and article not in coordinate_articles: + pass + else: + yield ('link', tempargs) + return + + for chunk in self.join_arguments_as_generator(yield_internal_link_args(text)): + yield chunk + + +class GenerateToponymEvalData(ArticleHandler): + # Process the text itself, e.g. for words. Input it text that has been + # preprocessed as described above (remove comments, etc.). Default + # handler does two things: + # + # 1. Further process the text (see format_text_second_pass()) + # 2. Use process_source_text() to extract chunks of useful + # text. Join together and then split into words. Pass the generator + # of words to self.process_text_for_words(). + + def process_text_for_text(self, text): + # Now process the text in various ways in preparation for extracting + # the words from the text + text = format_text_second_pass(text) + + splitprint("Article title: %s" % self.title) + chunkgen = ToponymEvalDataHandler().process_source_text(text) + #for chunk in chunkgen: + # errprint("Saw chunk: %s" % (chunk,)) + # groupby() allows us to group all the non-link chunks (which are raw + # strings) together efficiently + for k, g in itertools.groupby(chunkgen, + lambda chunk: type(chunk) is tuple): + #args = [arg for arg in g] + #errprint("Saw k=%s, g=%s" % (k,args)) + if k: + for (linktext, linkargs) in g: + splitprint("Link: %s" % '|'.join(linkargs)) + else: + # Now process the resulting text into chunks. Join them back together + # again (to handle cases like "the [[latent variable]]s are ..."), and + # split to find words. + for word in split_text_into_words(''.join(g)): + if word: + splitprint("%s" % word) + +# Generate article data of various sorts +class GenerateArticleData(ArticleHandler): + def process_article(self, redirtitle): + if rematch('(.*?):', self.title): + namespace = m_[1] + if namespace in article_namespace_aliases: + namespace = article_namespace_aliases[namespace] + elif namespace not in article_namespaces: + namespace = 'Main' + else: + namespace = 'Main' + yesno = {True:'yes', False:'no'} + listof = self.title.startswith('List of ') + disambig = self.id in disambig_pages_by_id + nskey = article_namespace_aliases.get(namespace, namespace) + list = listof or disambig or nskey in (14, 108) # ('Category', 'Book') + outprint("%s\t%s\t%s\t%s\t%s\t%s\t%s\t%s" % + (self.id, self.title, cur_split_name, redirtitle, namespace, + yesno[listof], yesno[disambig], yesno[list])) + + def process_redirect(self, redirtitle): + self.process_article(capfirst(redirtitle)) + + def process_text_for_data(self, text): + self.process_article('') + return False + +# Handler to output link information as well as coordinate information. +# Note that a link consists of two parts: The anchor text and the article +# name. For all links, we keep track of all the possible articles for a +# given anchor text and their counts. We also count all of the incoming +# links to an article (can be used for computing prior probabilities of +# an article). + +class ProcessSourceForCoordLinks(RecursiveSourceTextHandler): + useful_text_handler = ExtractUsefulText() + def process_internal_link(self, text): + tempargs = get_macro_args(text) + m = re.match(r'(?s)\s*([a-zA-Z0-9_]+)\s*:(.*)', tempargs[0]) + if m: + # Something like [[Image:...]] or [[wikt:...]] or [[fr:...]] + # For now, just skip them all; eventually, might want to do something + # useful with some, e.g. 
categories + pass + else: + article = capfirst(tempargs[0]) + # Skip links to articles without coordinates + if coordinate_articles and article not in coordinate_articles: + pass + else: + anchor = ''.join(self.useful_text_handler. + process_source_text(tempargs[-1])) + incoming_link_count[article] += 1 + if anchor not in anchor_text_map: + nested_anchor_text_map = intdict() + anchor_text_map[anchor] = nested_anchor_text_map + else: + nested_anchor_text_map = anchor_text_map[anchor] + nested_anchor_text_map[article] += 1 + + # Also recursively process all the arguments for links, etc. + return self.process_source_text(text[2:-2]) + +class FindCoordLinks(ArticleHandler): + def process_text_for_data(self, text): + handler = ProcessSourceForCoordLinks() + for foo in handler.process_source_text(text): pass + return False + + def finish_processing(self): + print "------------------ Count of incoming links: ---------------" + output_reverse_sorted_table(incoming_link_count, outfile=cur_output_file) + + print "===========================================================" + print "===========================================================" + print "===========================================================" + print "" + for (anchor,map) in anchor_text_map.items(): + splitprint("-------- Anchor text->article for %s: " % anchor) + output_reverse_sorted_table(map, outfile=cur_output_file) + +####################################################################### +# SAX handler for processing raw dump files # +####################################################################### + +class FinishParsing: + pass + +# We do a very simple-minded way of handling the XML. We maintain the +# path of nested elements that we're within, and we track the text since the +# last time we saw the beginning of an element. We reset the text we're +# tracking every time we see an element begin tag, and we don't record +# text at all after an end tag, until we see a begin tag again. Basically, +# this means we don't handle cases where tags are nested inside of text. +# This isn't a problem since cases like this don't occur in the Wikipedia +# dump. + +class WikipediaDumpSaxHandler(ContentHandler): + '''SAX handler for processing Wikipedia dumps. Note that SAX is a +simple interface for handling XML in a serial fashion (as opposed to a +DOM-type interface, which reads the entire XML file into memory and allows +it to be dynamically manipulated). Given the size of the XML dump file +(around 25 GB uncompressed), we can't read it all into memory.''' + def __init__(self, output_handler): + errprint("Beginning processing of Wikipedia dump...") + self.curpath = [] + self.curattrs = [] + self.curtext = None + self.output_handler = output_handler + self.status = StatusMessage('article') + + def startElement(self, name, attrs): + '''Handler for beginning of XML element.''' + if debug['sax']: + errprint("startElement() saw %s/%s" % (name, attrs)) + for (key,val) in attrs.items(): errprint(" Attribute (%s,%s)" % (key,val)) + # We should never see an element inside of the Wikipedia text. + if self.curpath: + assert self.curpath[-1] != 'text' + self.curpath.append(name) + self.curattrs.append(attrs) + self.curtext = [] + # We care about the title, ID, and redirect status. Reset them for + # every page; this is especially important for redirect status. + if name == 'page': + self.title = None + self.id = None + self.redirect = False + + def characters(self, text): + '''Handler for chunks of text. Accumulate all adjacent chunks. 
When +the end element is seen, process_article_text() will be called on the +combined chunks.''' + if debug['sax']: errprint("characters() saw %s" % text) + # None means the last directive we saw was an end tag; we don't track + # text any more until the next begin tag. + if self.curtext != None: + self.curtext.append(text) + + def endElement(self, name): + '''Handler for end of XML element.''' + eltext = ''.join(self.curtext) if self.curtext else '' + self.curtext = None # Stop tracking text + self.curpath.pop() + attrs = self.curattrs.pop() + if name == 'title': + self.title = eltext + # ID's occur in three places: the page ID, revision ID and contributor ID. + # We only want the page ID, so check to make sure we've got the right one. + elif name == 'id' and self.curpath[-1] == 'page': + self.id = eltext + elif name == 'redirect': + self.redirect = True + elif name == 'namespace': + key = attrs.getValue("key") + if debug['sax']: errprint("Saw namespace, key=%s, eltext=%s" % + (key, eltext)) + article_namespaces[eltext] = key + article_namespaces_lower[eltext.lower()] = key + elif name == 'text': + # If we saw the end of the article text, join all the text chunks + # together and call process_article_text() on it. + set_next_split_file() + if debug['lots']: + max_text_len = 150 + endslice = min(max_text_len, len(eltext)) + truncated = len(eltext) > max_text_len + errprint( + """Calling process_article_text with title=%s, id=%s, redirect=%s; + text=[%s%s]""" % (self.title, self.id, self.redirect, eltext[0:endslice], + "..." if truncated else "")) + self.output_handler.process_article_text(text=eltext, title=self.title, + id=self.id, redirect=self.redirect) + if self.status.item_processed(maxtime=Opts.max_time_per_stage): + raise FinishParsing() + +####################################################################### +# Main code # +####################################################################### + + +def main_process_input(wiki_handler): + ### Create the SAX parser and run it on stdin. + sax_parser = make_parser() + sax_handler = WikipediaDumpSaxHandler(wiki_handler) + sax_parser.setContentHandler(sax_handler) + try: + sax_parser.parse(sys.stdin) + except FinishParsing: + pass + wiki_handler.finish_processing() + +def main(): + + op = OptionParser(usage="%prog [options] < file") + op.add_option("--output-all-words", + help="Output words of text, for all articles.", + action="store_true") + op.add_option("--output-coord-words", + help="Output text, but only for articles with coordinates.", + action="store_true") + op.add_option("--raw-text", help="""When outputting words, make output +resemble some concept of "raw text". Currently, this just includes +punctuation instead of omitting it, and shows only the anchor text of a +link rather than both the anchor text and actual article name linked to, +when different.""", action="store_true") + op.add_option("--no-tokenize", help="""When outputting words, don't tokenize. +This causes words to only be split on whitespace, rather than also on +punctuation.""", action="store_true") + op.add_option("--find-coord-links", + help="""Find all links and print info about them, for +articles with coordinates or redirects to such articles. 
Includes count of +incoming links, and, for each anchor-text form, counts of all articles it +maps to.""", + action="store_true") + op.add_option("--output-all-counts", + help="Print info about counts of words, for all articles.", + action="store_true") + op.add_option("--output-coord-counts", + help="Print info about counts of words, but only for articles with coodinates.", + action="store_true") + op.add_option("--output-coords", + help="Print info about coordinates of articles with coordinates.", + action="store_true") + op.add_option("--output-location-type", + help="Print info about type of articles with coordinates.", + action="store_true") + op.add_option("--find-redirects", + help="Output all redirects.", + action="store_true") + op.add_option("--generate-toponym-eval", + help="Generate data files for use in toponym evaluation.", + action="store_true") + op.add_option("--generate-article-data", + help="""Generate file listing all articles and info about them. +If using this option, the --disambig-id-file and --split-training-dev-test +options should also be used. + +The format is + +ID TITLE SPLIT REDIR NAMESPACE LIST-OF DISAMBIG LIST + +where each field is separated by a tab character. + +The fields are + +ID = Numeric ID of article, given by wikiprep +TITLE = Title of article +SPLIT = Split to assign the article to; one of 'training', 'dev', or 'test'. +REDIR = If the article is a redirect, lists the article it redirects to; + else, blank. +NAMESPACE = Namespace of the article, one of 'Main', 'User', 'Wikipedia', + 'File', 'MediaWiki', 'Template', 'Help', 'Category', 'Thread', + 'Summary', 'Portal', 'Book'. These are the basic namespaces + defined in [[Wikipedia:Namespace]]. Articles of the appropriate + namespace begin with the namespace prefix, e.g. 'File:*', except + for articles in the main namespace, which includes everything + else. Note that some of these namespaces don't actually appear + in the article dump; likewise, talk pages don't appear in the + dump. In addition, we automatically include the common namespace + abbreviations in the appropriate space, i.e. + + P Portal + H Help + T Template + CAT, Cat, C Category + MOS, MoS, Mos Wikipedia (used for "Manual of Style" pages) +LIST-OF = 'yes' if article title is of the form 'List of *', typically + containing a list; else 'no'. +DISAMBIG = 'yes' if article is a disambiguation page (used to disambiguate + multiple concepts with the same name); else 'no'. +LIST = 'yes' if article is a list of some sort, else no. This includes + 'List of' articles, disambiguation pages, and articles in the 'Category' + and 'Book' namespaces.""", + action="store_true") + op.add_option("--split-training-dev-test", + help="""Split output into training, dev and test files. +Use the specified value as the file prefix, suffixed with '.train', '.dev' +and '.test' respectively.""", + metavar="FILE") + op.add_option("--training-fraction", type='float', default=80, + help="""Fraction of total articles to use for training. +The absolute amount doesn't matter, only the value relative to the test +and dev fractions, as the values are normalized. Default %default.""", + metavar="FRACTION") + op.add_option("--dev-fraction", type='float', default=10, + help="""Fraction of total articles to use for dev set. +The absolute amount doesn't matter, only the value relative to the training +and test fractions, as the values are normalized. 
Default %default.""", + metavar="FRACTION") + op.add_option("--test-fraction", type='float', default=10, + help="""Fraction of total articles to use for test set. +The absolute amount doesn't matter, only the value relative to the training +and dev fractions, as the values are normalized. Default %default.""", + metavar="FRACTION") + op.add_option("--coords-file", + help="""File containing output from a prior run of +--coords-counts, listing all the articles with associated coordinates. +This is used to limit the operation of --find-coord-links to only consider +links to articles with coordinates. Currently, if this is not done, then +using --coords-file requires at least 10GB, perhaps more, of memory in order +to store the entire table of anchor->article mappings in memory. (If this +entire table is needed, it may be necessary to implement a MapReduce-style +process where smaller chunks are processed separately and then the results +combined.)""", + metavar="FILE") + op.add_option("--article-data-file", + help="""File containing article data. Used by +--find-coord-links to find the redirects pointing to articles with +coordinates.""", + metavar="FILE") + op.add_option("--disambig-id-file", + help="""File containing list of article ID's that are +disambiguation pages.""", + metavar="FILE") + op.add_option("--max-time-per-stage", "--mts", type='int', default=0, + help="""Maximum time per stage in seconds. If 0, no limit. +Used for testing purposes. Default %default.""") + op.add_option("--debug", metavar="FLAGS", + help="Output debug info of the given types (separated by spaces or commas)") + + errprint("Arguments: %s" % ' '.join(sys.argv)) + opts, args = op.parse_args() + output_option_parameters(opts) + + global Opts + Opts = opts + + global debug + if opts.debug: + flags = re.split(r'[,\s]+', opts.debug) + for f in flags: + debug[f] = True + if debug['err'] or debug['some'] or debug['lots'] or debug['sax']: + cur_output_file = sys.stderr + debug_to_stderr = True + + if opts.split_training_dev_test: + init_output_files(opts.split_training_dev_test, + [opts.training_fraction, opts.dev_fraction, + opts.test_fraction], + ['training', 'dev', 'test']) + + if opts.coords_file: + read_coordinates_file(opts.coords_file) + if opts.article_data_file: + read_redirects_from_article_data(opts.article_data_file) + if opts.disambig_id_file: + read_disambig_id_file(opts.disambig_id_file) + if opts.output_all_words: + main_process_input(OutputAllWords()) + elif opts.output_coord_words: + main_process_input(OutputCoordWords()) + elif opts.find_coord_links: + main_process_input(FindCoordLinks()) + elif opts.find_redirects: + main_process_input(FindRedirects()) + elif opts.output_coords: + main_process_input(OutputCoords()) + elif opts.output_all_counts: + main_process_input(OutputAllCounts()) + elif opts.output_coord_counts: + main_process_input(OutputCoordCounts()) + elif opts.output_location_type: + main_process_input(OutputLocationType()) + elif opts.generate_toponym_eval: + main_process_input(GenerateToponymEvalData()) + elif opts.generate_article_data: + outprint('id\ttitle\tsplit\tredir\tnamespace\tis_list_of\tis_disambig\tis_list') + main_process_input(GenerateArticleData()) + +#import cProfile +#cProfile.run('main()', 'process-wiki.prof') +main() diff --git a/src/main/python/run-geolocate-exper.py b/src/main/python/run-geolocate-exper.py new file mode 100755 index 0000000..490ae97 --- /dev/null +++ b/src/main/python/run-geolocate-exper.py @@ -0,0 +1,1241 @@ +#!/usr/bin/env python + +import os +from 
nlputil import * + +# Run a series of geolocation experiments. + +tgdir = os.environ['TEXTGROUNDER_DIR'] +if not tgdir: + raise EnvironmentError("TEXTGROUNDER_DIR must be set to the base of the TextGrounder distribution.") +tgbin = '%s/bin' % tgdir + +def runit(fun, id, args): + command='%s --id %s %s' % (Opts.run_cmd, id, args) + errprint("Executing: %s" % command) + if not Opts.dry_run: + os.system("%s" % command) + +def combine(*funs): + def do_combine(fun, *args): + for f in funs: + f(fun, *args) + return do_combine + +def iterate(paramname, vals): + def do_iterate(fun, id, args): + for val in vals: + fun('%s.%s' % (id, val), '%s %s %s' % (args, paramname, val)) + return do_iterate + +def add_param(param): + def do_add_param(fun, id, args): + fun(id, '%s %s' % (args, param)) + return do_add_param + +def recurse(funs, *args): + if not funs: + return + (funs[0])(lambda *args: recurse(funs[1:], *args), *args) + +def nest(*nest_funs): + def do_nest(fun, *args): + recurse(nest_funs + (fun,), *args) + return do_nest + +def run_exper(exper, expername): + exper(lambda fun, *args: runit(id, *args), expername, '') + +def main(): + op = OptionParser(usage="%prog [options] experiment [...]") + op.add_option("-n", "--dry-run", action="store_true", + help="Don't execute anything; just output the commands that would be executed.") + def_runcmd = '%s/nohup-geolocate-wikipedia' % tgbin + op.add_option("-c", "--run-cmd", "--cmd", default=def_runcmd, + help="Command to execute; default '%default'.") + (opts, args) = op.parse_args() + global Opts + Opts = opts + if not args: + op.print_help() + for exper in args: + run_exper(eval(exper), exper) + +############################################################################## +# Description of experiments # +############################################################################## + +MTS10 = iterate('--max-time-per-stage', [10]) +MTS50 = iterate('--max-time-per-stage', [50]) +MTS300 = iterate('--max-time-per-stage', [300]) +Train200k = iterate('--num-training-docs', [200000]) +Train100k = iterate('--num-training-docs', [100000]) +Test2k = iterate('--num-test-docs', [2000]) +Test1k = iterate('--num-test-docs', [1000]) +Test500 = iterate('--num-test-docs', [500]) +#CombinedNonBaselineStrategies = add_param('--strategy partial-kl-divergence --strategy cosine-similarity --strategy naive-bayes-with-baseline --strategy average-cell-probability') +CombinedNonBaselineStrategies = add_param('--strategy partial-kl-divergence --strategy smoothed-cosine-similarity --strategy naive-bayes-with-baseline --strategy average-cell-probability') +CombinedNonBaselineNoCosineStrategies = add_param('--strategy partial-kl-divergence --strategy naive-bayes-with-baseline --strategy average-cell-probability') +NonBaselineStrategies = iterate('--strategy', + ['partial-kl-divergence', 'average-cell-probability', 'naive-bayes-with-baseline', 'smoothed-cosine-similarity']) +BaselineStrategies = iterate('--strategy baseline --baseline-strategy', + ['link-most-common-toponym', 'regdist-most-common-toponym', + 'internal-link', 'num-articles', 'random']) +CombinedBaselineStrategies1 = add_param('--strategy baseline --baseline-strategy link-most-common-toponym --baseline-strategy regdist-most-common-toponym') +CombinedBaselineStrategies2 = add_param('--strategy baseline --baseline-strategy internal-link --baseline-strategy num-articles --baseline-strategy random') +CombinedBaselineStrategies = combine(CombinedBaselineStrategies1, CombinedBaselineStrategies2) +AllStrategies = 
combine(NonBaselineStrategies, BaselineStrategies) +CombinedKL = add_param('--strategy symmetric-partial-kl-divergence --strategy symmetric-full-kl-divergence --strategy partial-kl-divergence --strategy full-kl-divergence') +CombinedCosine = add_param('--strategy cosine-similarity --strategy smoothed-cosine-similarity --strategy partial-cosine-similarity --strategy smoothed-partial-cosine-similarity') +KLDivStrategy = iterate('--strategy', ['partial-kl-divergence']) +FullKLDivStrategy = iterate('--strategy', ['full-kl-divergence']) +SmoothedCosineStrategy = iterate('--strategy', ['smoothed-cosine-similarity']) +NBStrategy = iterate('--strategy', ['naive-bayes-no-baseline']) + +Coarser1DPR = iterate('--degrees-per-region', [0.1, 10]) +Coarser2DPR = iterate('--degrees-per-region', [0.5, 1, 5]) +CoarseDPR = iterate('--degrees-per-region', + #[90, 30, 10, 5, 3, 2, 1, 0.5] + #[0.5, 1, 2, 3, 5, 10, 30, 90] + [0.5, 1, 2, 3, 5, 10]) +OldFineDPR = iterate('--degrees-per-region', + [90, 75, 60, 50, 40, 30, 25, 20, 15, 12, 10, 9, 8, 7, 6, 5, 4, 3, 2.5, 2, + 1.75, 1.5, 1.25, 1, 0.87, 0.75, 0.63, 0.5, 0.4, 0.3, 0.25, 0.2, 0.15, 0.1] + ) +DPRList1 = iterate('--degrees-per-region', [0.5, 1, 3]) + +DPR3 = iterate('--degrees-per-region', [3]) +DPR7 = iterate('--degrees-per-region', [7]) +DPR5 = iterate('--degrees-per-region', [5]) +DPR10 = iterate('--degrees-per-region', [10]) +DPR1 = iterate('--degrees-per-region', [1]) +DPRpoint5 = iterate('--degrees-per-region', [0.5]) +DPRpoint1 = iterate('--degrees-per-region', [0.1]) + +MinWordCount = iterate('--minimum-word-count', [1, 2, 3, 4, 5]) + +CoarseDisambig = nest(MTS300, AllStrategies, CoarseDPR) + +# PCL experiments +PCLDPR = iterate('--degrees-per-region', [1.5, 0.5, 1, 2, 3, 5]) +corpora_dir = os.getenv('CORPORA_DIR') or '/groups/corpora' +PCLEvalFile = add_param('-f pcl-travel -e %s/pcl_travel/books' % corpora_dir) +PCLDisambig = nest(MTS300, PCLEvalFile, NonBaselineStrategies, PCLDPR) + +# Param experiments + +ParamExper = nest(MTS300, DPRList1, MinWordCount, NonBaselineStrategies) + +# Fine experiments + +FinerDPR = iterate('--degrees-per-region', [0.3, 0.2, 0.1]) +EvenFinerDPR = iterate('--degrees-per-region', [0.1, 0.05]) +Finer3DPR = iterate('--degrees-per-region', [0.01, 0.05]) +FinerExper = nest(MTS300, FinerDPR, KLDivStrategy) +EvenFinerExper = nest(MTS300, EvenFinerDPR, KLDivStrategy) + +# Missing experiments + +MissingNonBaselineStrategies = iterate('--strategy', + ['naive-bayes-no-baseline', 'partial-cosine-similarity', 'cosine-similarity']) +MissingBaselineStrategies = iterate('--strategy baseline --baseline-strategy', + ['link-most-common-toponym' + #, 'regdist-most-common-toponym' + ]) +MissingOtherNonBaselineStrategies = iterate('--strategy', + ['partial-cosine-similarity', 'cosine-similarity']) +MissingAllButNBStrategies = combine(MissingOtherNonBaselineStrategies, + MissingBaselineStrategies) +#Original MissingExper failed on or didn't include all but +#regdist-most-common-toponym. 
+#MissingExper = nest(MTS300, CoarseDPR, MissingAllStrategies) + +MissingNBExper = nest(MTS300, CoarseDPR, NBStrategy) +MissingOtherExper = nest(MTS300, CoarseDPR, MissingAllButNBStrategies) +MissingBaselineExper = nest(MTS300, CoarseDPR, MissingBaselineStrategies) +FullKLDivExper = nest(MTS300, CoarseDPR, FullKLDivStrategy) + +# Newer experiments on 200k/1k + +#CombinedKLExper = nest(Train100k, Test1k, DPR5, CombinedKL) +#CombinedCosineExper = nest(Train100k, Test1k, DPR5, CombinedCosine) +CombinedKLExper = nest(Train100k, Test500, DPR5, CombinedKL) +CombinedCosineExper = nest(Train100k, Test500, DPR5, CombinedCosine) + +NewCoarser1Exper = nest(Train100k, Test500, Coarser1DPR, CombinedNonBaselineStrategies) +NewCoarser2Exper = nest(Train100k, Test500, Coarser2DPR, CombinedNonBaselineStrategies) +NewFiner3Exper = nest(Train100k, Test500, Finer3DPR, KLDivStrategy) +NewIndiv4Exper = nest(Train100k, Test500, DPRpoint5, CombinedNonBaselineNoCosineStrategies) +NewIndiv5Exper = nest(Train100k, Test500, DPRpoint5, CombinedBaselineStrategies1) + +NewDPR = iterate('--degrees-per-region', [0.1, 0.5, 1, 5]) +NewDPR2 = iterate('--degrees-per-region', [0.1, 0.5, 1, 5, 10]) +New10DPR = iterate('--degrees-per-region', [10]) +New510DPR = iterate('--degrees-per-region', [5, 10]) +New1DPR = iterate('--degrees-per-region', [1]) +NewSmoothedCosineExper = nest(Train100k, Test500, SmoothedCosineStrategy, NewDPR) +NewSmoothedCosineExper2 = nest(Train100k, Test500, SmoothedCosineStrategy, New10DPR) +New10Exper = nest(Train100k, Test500, New10DPR, CombinedNonBaselineStrategies) +NewBaselineExper = nest(Train100k, Test500, NewDPR2, CombinedBaselineStrategies) +NewBaseline2Exper1 = nest(Train100k, Test500, New1DPR, CombinedBaselineStrategies2) +NewBaseline2Exper2 = nest(Train100k, Test500, New510DPR, CombinedBaselineStrategies) + +# Final experiments performed prior to original submission, c. Dec 17 2010 + +TestDPR = iterate('--degrees-per-region', [0.1]) +TestSet = add_param('--eval-set test') +TestStrat1 = iterate('--strategy', ['partial-kl-divergence']) +TestStrat2 = iterate('--strategy', ['average-cell-probability']) +TestStrat3 = iterate('--strategy', ['naive-bayes-with-baseline']) +Test2Sec1 = add_param('--skip-initial 31 --every-nth 3') +Test2Sec2 = add_param('--skip-initial 32 --every-nth 3') +Test2Sec3 = add_param('--skip-initial 33 --every-nth 6') +Test2Sec4 = add_param('--skip-initial 36 --every-nth 6') +TestExper1 = nest(Train100k, Test1k, TestSet, TestDPR, TestStrat1) +TestExper2 = nest(Train100k, Test1k, TestSet, TestDPR, TestStrat2) +TestExper2Sec1 = nest(Train100k, Test1k, TestSet, TestDPR, TestStrat2, Test2Sec1) +TestExper2Sec2 = nest(Train100k, Test1k, TestSet, TestDPR, TestStrat2, Test2Sec2) +TestExper2Sec3 = nest(Train100k, Test1k, TestSet, TestDPR, TestStrat2, Test2Sec3) +TestExper2Sec4 = nest(Train100k, Test1k, TestSet, TestDPR, TestStrat2, Test2Sec4) +TestExper3 = nest(Train100k, Test1k, TestSet, TestDPR, TestStrat3) + +TestStratBase1 = add_param('--strategy baseline --baseline-strategy link-most-common-toponym --baseline-strategy regdist-most-common-toponym') +TestStratBase2 = add_param('--strategy baseline --baseline-strategy num-articles --baseline-strategy random') +TestExperBase1 = nest(Train100k, Test1k, TestSet, TestDPR, TestStratBase1) +TestExperBase2 = nest(Train100k, Test500, TestSet, TestDPR, TestStratBase2) + +# Final experiments performed prior to final submission, c. 
Apr 10-15 2011 +WikiFinalKL = nest(Test2k, TestSet, TestDPR, TestStrat1) + +WikiFinal1 = nest(TestSet, TestDPR, TestStrat1) +WikiFinal2 = nest(TestSet, TestDPR, TestStrat2) +WikiFinal3 = nest(TestSet, TestDPR, TestStrat3) + +Final1Sec1 = add_param('--skip-initial 0 --every-nth 6') +Final1Sec2 = add_param('--skip-initial 1 --every-nth 6') +Final1Sec3 = add_param('--skip-initial 2 --every-nth 6') +Final1Sec4 = add_param('--skip-initial 3 --every-nth 6') +Final1Sec5 = add_param('--skip-initial 4 --every-nth 6') +Final1Sec6 = add_param('--skip-initial 5 --every-nth 6') + +OracleOnly = add_param('--oracle-results') +WikiOracleDPRpoint1 = nest(TestSet, DPRpoint1, TestStratBase1, OracleOnly) +WikiOracleDPRpoint5 = nest(TestSet, DPRpoint5, TestStratBase1, OracleOnly) +WikiOracleDPR1 = nest(TestSet, DPR1, TestStratBase1, OracleOnly) +WikiOracleDPR5 = nest(TestSet, DPR5, TestStratBase1, OracleOnly) +WikiOracleDPR5Test = nest(TestSet, DPR5, TestStratBase1, OracleOnly, MTS10) + +WikiFinal1Sec1 = nest(TestSet, TestDPR, TestStrat1, Final1Sec1) +WikiFinal1Sec2 = nest(TestSet, TestDPR, TestStrat1, Final1Sec2) +WikiFinal1Sec3 = nest(TestSet, TestDPR, TestStrat1, Final1Sec3) +WikiFinal1Sec4 = nest(TestSet, TestDPR, TestStrat1, Final1Sec4) +WikiFinal1Sec5 = nest(TestSet, TestDPR, TestStrat1, Final1Sec5) +WikiFinal1Sec6 = nest(TestSet, TestDPR, TestStrat1, Final1Sec6) + +Final1Sec7 = add_param('--skip-initial 33786 --every-nth 6') +Final1Sec8 = add_param('--skip-initial 34453 --every-nth 6') +Final1Sec9 = add_param('--skip-initial 33272 --every-nth 6') +Final1Sec10 = add_param('--skip-initial 35121 --every-nth 6') +Final1Sec11 = add_param('--skip-initial 33796 --every-nth 6') +Final1Sec12 = add_param('--skip-initial 35363 --every-nth 6') + +WikiFinal1Sec7 = nest(TestSet, TestDPR, TestStrat1, Final1Sec7) +WikiFinal1Sec8 = nest(TestSet, TestDPR, TestStrat1, Final1Sec8) +WikiFinal1Sec9 = nest(TestSet, TestDPR, TestStrat1, Final1Sec9) +WikiFinal1Sec10 = nest(TestSet, TestDPR, TestStrat1, Final1Sec10) +WikiFinal1Sec11 = nest(TestSet, TestDPR, TestStrat1, Final1Sec11) +WikiFinal1Sec12 = nest(TestSet, TestDPR, TestStrat1, Final1Sec12) + +# Experiments to test memory usage and speed with different sizes of LRU +# cache. 
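Each experiment below is just a nest() of the iterate()/add_param() combinators defined earlier, which thread an (id, args) pair down to runit(). As a self-contained sketch of how such a nest() expands into concrete command arguments (re-deriving simplified copies of the combinators rather than importing this script; the experiment shown is a made-up example in the same style):

    def iterate(paramname, vals):
        def do_iterate(fun, id, args):
            for val in vals:
                fun('%s.%s' % (id, val), '%s %s %s' % (args, paramname, val))
        return do_iterate

    def add_param(param):
        def do_add_param(fun, id, args):
            fun(id, '%s %s' % (args, param))
        return do_add_param

    def recurse(funs, *args):
        if not funs:
            return
        (funs[0])(lambda *rest: recurse(funs[1:], *rest), *args)

    def nest(*nest_funs):
        def do_nest(fun, *args):
            recurse(nest_funs + (fun,), *args)
        return do_nest

    def show(fun, id, args):
        # Stand-in for runit(): print the command tail it would build.
        print('--id %s%s' % (id, args))

    # Made-up experiment in the same style as the definitions around here.
    DemoLRU = nest(add_param('--eval-set test'),
                   iterate('--degrees-per-region', [0.1]),
                   iterate('--strategy', ['average-cell-probability']),
                   iterate('--lru', ['400', '1200']))
    DemoLRU(show, 'DemoLRU', '')

Running this prints one command tail per --lru value, with the id suffixed by each iterated value, which is the same expansion runit() receives for the real experiments.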
+TestLRU150 = iterate('--lru', ['150']) +TestLRU200 = iterate('--lru', ['200']) +TestLRU300 = iterate('--lru', ['300']) +TestLRU350 = iterate('--lru', ['350']) +TestLRU400 = iterate('--lru', ['400']) +TestLRU500 = iterate('--lru', ['500']) +TestLRU600 = iterate('--lru', ['600']) +TestLRU700 = iterate('--lru', ['700']) +TestLRU1200 = iterate('--lru', ['1200']) +TestLRU4000 = iterate('--lru', ['4000']) + +WikiFinal2LRU400 = nest(TestSet, TestDPR, TestStrat2, TestLRU400) +WikiFinal2LRU1200 = nest(TestSet, TestDPR, TestStrat2, TestLRU1200) +WikiFinal2LRU4000 = nest(TestSet, TestDPR, TestStrat2, TestLRU4000) + +TestSkip59 = add_param('--every-nth 60') +TestSkip31 = add_param('--every-nth 32') + +TestOffset0 = iterate('--skip-initial', ['0']) +TestOffset1 = iterate('--skip-initial', ['1']) +TestOffset2 = iterate('--skip-initial', ['2']) +TestOffset3 = iterate('--skip-initial', ['3']) +TestOffset4 = iterate('--skip-initial', ['4']) +TestOffset5 = iterate('--skip-initial', ['5']) +TestOffset6 = iterate('--skip-initial', ['6']) +TestOffset7 = iterate('--skip-initial', ['7']) +TestOffset8 = iterate('--skip-initial', ['8']) +TestOffset9 = iterate('--skip-initial', ['9']) +TestOffset10 = iterate('--skip-initial', ['10']) +TestOffset11 = iterate('--skip-initial', ['11']) +TestOffset12 = iterate('--skip-initial', ['12']) +TestOffset13 = iterate('--skip-initial', ['13']) +TestOffset14 = iterate('--skip-initial', ['14']) +TestOffset15 = iterate('--skip-initial', ['15']) +TestOffset16 = iterate('--skip-initial', ['16']) +TestOffset17 = iterate('--skip-initial', ['17']) +TestOffset18 = iterate('--skip-initial', ['18']) +TestOffset19 = iterate('--skip-initial', ['19']) +TestOffset20 = iterate('--skip-initial', ['20']) +TestOffset21 = iterate('--skip-initial', ['21']) +TestOffset22 = iterate('--skip-initial', ['22']) +TestOffset23 = iterate('--skip-initial', ['23']) +TestOffset24 = iterate('--skip-initial', ['24']) +TestOffset25 = iterate('--skip-initial', ['25']) +TestOffset26 = iterate('--skip-initial', ['26']) +TestOffset27 = iterate('--skip-initial', ['27']) +TestOffset28 = iterate('--skip-initial', ['28']) +TestOffset29 = iterate('--skip-initial', ['29']) +TestOffset30 = iterate('--skip-initial', ['30']) +TestOffset31 = iterate('--skip-initial', ['31']) +TestOffset32 = iterate('--skip-initial', ['32']) +TestOffset33 = iterate('--skip-initial', ['33']) +TestOffset34 = iterate('--skip-initial', ['34']) +TestOffset35 = iterate('--skip-initial', ['35']) +TestOffset36 = iterate('--skip-initial', ['36']) +TestOffset37 = iterate('--skip-initial', ['37']) +TestOffset38 = iterate('--skip-initial', ['38']) +TestOffset39 = iterate('--skip-initial', ['39']) +TestOffset40 = iterate('--skip-initial', ['40']) +TestOffset41 = iterate('--skip-initial', ['41']) +TestOffset42 = iterate('--skip-initial', ['42']) +TestOffset43 = iterate('--skip-initial', ['43']) +TestOffset44 = iterate('--skip-initial', ['44']) +TestOffset45 = iterate('--skip-initial', ['45']) +TestOffset46 = iterate('--skip-initial', ['46']) +TestOffset47 = iterate('--skip-initial', ['47']) +TestOffset48 = iterate('--skip-initial', ['48']) +TestOffset49 = iterate('--skip-initial', ['49']) +TestOffset50 = iterate('--skip-initial', ['50']) +TestOffset51 = iterate('--skip-initial', ['51']) +TestOffset52 = iterate('--skip-initial', ['52']) +TestOffset53 = iterate('--skip-initial', ['53']) +TestOffset54 = iterate('--skip-initial', ['54']) +TestOffset55 = iterate('--skip-initial', ['55']) +TestOffset56 = iterate('--skip-initial', ['56']) +TestOffset57 = 
iterate('--skip-initial', ['57']) +TestOffset58 = iterate('--skip-initial', ['58']) +TestOffset59 = iterate('--skip-initial', ['59']) + +WikiFinal2Sec0 = nest(TestSet, TestDPR, TestStrat2, TestSkip59, TestLRU500, TestOffset0) +WikiFinal2Sec1 = nest(TestSet, TestDPR, TestStrat2, TestSkip59, TestLRU500, TestOffset1) +WikiFinal2Sec2 = nest(TestSet, TestDPR, TestStrat2, TestSkip59, TestLRU500, TestOffset2) +WikiFinal2Sec3 = nest(TestSet, TestDPR, TestStrat2, TestSkip59, TestLRU500, TestOffset3) +WikiFinal2Sec4 = nest(TestSet, TestDPR, TestStrat2, TestSkip59, TestLRU500, TestOffset4) +WikiFinal2Sec5 = nest(TestSet, TestDPR, TestStrat2, TestSkip59, TestLRU500, TestOffset5) +WikiFinal2Sec6 = nest(TestSet, TestDPR, TestStrat2, TestSkip59, TestLRU500, TestOffset6) +WikiFinal2Sec7 = nest(TestSet, TestDPR, TestStrat2, TestSkip59, TestLRU500, TestOffset7) +WikiFinal2Sec8 = nest(TestSet, TestDPR, TestStrat2, TestSkip59, TestLRU500, TestOffset8) +WikiFinal2Sec9 = nest(TestSet, TestDPR, TestStrat2, TestSkip59, TestLRU500, TestOffset9) +WikiFinal2Sec10 = nest(TestSet, TestDPR, TestStrat2, TestSkip59, TestLRU500, TestOffset10) +WikiFinal2Sec11 = nest(TestSet, TestDPR, TestStrat2, TestSkip59, TestLRU500, TestOffset11) +WikiFinal2Sec12 = nest(TestSet, TestDPR, TestStrat2, TestSkip59, TestLRU500, TestOffset12) +WikiFinal2Sec13 = nest(TestSet, TestDPR, TestStrat2, TestSkip59, TestLRU500, TestOffset13) +WikiFinal2Sec14 = nest(TestSet, TestDPR, TestStrat2, TestSkip59, TestLRU500, TestOffset14) +WikiFinal2Sec15 = nest(TestSet, TestDPR, TestStrat2, TestSkip59, TestLRU500, TestOffset15) +WikiFinal2Sec16 = nest(TestSet, TestDPR, TestStrat2, TestSkip59, TestLRU500, TestOffset16) +WikiFinal2Sec17 = nest(TestSet, TestDPR, TestStrat2, TestSkip59, TestLRU500, TestOffset17) +WikiFinal2Sec18 = nest(TestSet, TestDPR, TestStrat2, TestSkip59, TestLRU500, TestOffset18) +WikiFinal2Sec19 = nest(TestSet, TestDPR, TestStrat2, TestSkip59, TestLRU500, TestOffset19) +WikiFinal2Sec20 = nest(TestSet, TestDPR, TestStrat2, TestSkip59, TestLRU500, TestOffset20) +WikiFinal2Sec21 = nest(TestSet, TestDPR, TestStrat2, TestSkip59, TestLRU500, TestOffset21) +WikiFinal2Sec22 = nest(TestSet, TestDPR, TestStrat2, TestSkip59, TestLRU500, TestOffset22) +WikiFinal2Sec23 = nest(TestSet, TestDPR, TestStrat2, TestSkip59, TestLRU500, TestOffset23) +WikiFinal2Sec24 = nest(TestSet, TestDPR, TestStrat2, TestSkip59, TestLRU500, TestOffset24) +WikiFinal2Sec25 = nest(TestSet, TestDPR, TestStrat2, TestSkip59, TestLRU500, TestOffset25) +WikiFinal2Sec26 = nest(TestSet, TestDPR, TestStrat2, TestSkip59, TestLRU500, TestOffset26) +WikiFinal2Sec27 = nest(TestSet, TestDPR, TestStrat2, TestSkip59, TestLRU500, TestOffset27) +WikiFinal2Sec28 = nest(TestSet, TestDPR, TestStrat2, TestSkip59, TestLRU500, TestOffset28) +WikiFinal2Sec29 = nest(TestSet, TestDPR, TestStrat2, TestSkip59, TestLRU500, TestOffset29) +WikiFinal2Sec30 = nest(TestSet, TestDPR, TestStrat2, TestSkip59, TestLRU500, TestOffset30) +WikiFinal2Sec31 = nest(TestSet, TestDPR, TestStrat2, TestSkip59, TestLRU500, TestOffset31) +WikiFinal2Sec32 = nest(TestSet, TestDPR, TestStrat2, TestSkip59, TestLRU500, TestOffset32) +WikiFinal2Sec33 = nest(TestSet, TestDPR, TestStrat2, TestSkip59, TestLRU500, TestOffset33) +WikiFinal2Sec34 = nest(TestSet, TestDPR, TestStrat2, TestSkip59, TestLRU500, TestOffset34) +WikiFinal2Sec35 = nest(TestSet, TestDPR, TestStrat2, TestSkip59, TestLRU500, TestOffset35) +WikiFinal2Sec36 = nest(TestSet, TestDPR, TestStrat2, TestSkip59, TestLRU700, TestOffset36) +WikiFinal2Sec37 = nest(TestSet, 
TestDPR, TestStrat2, TestSkip59, TestLRU700, TestOffset37) +WikiFinal2Sec38 = nest(TestSet, TestDPR, TestStrat2, TestSkip59, TestLRU700, TestOffset38) +WikiFinal2Sec39 = nest(TestSet, TestDPR, TestStrat2, TestSkip59, TestLRU700, TestOffset39) +WikiFinal2Sec40 = nest(TestSet, TestDPR, TestStrat2, TestSkip59, TestLRU700, TestOffset40) +WikiFinal2Sec41 = nest(TestSet, TestDPR, TestStrat2, TestSkip59, TestLRU700, TestOffset41) +WikiFinal2Sec42 = nest(TestSet, TestDPR, TestStrat2, TestSkip59, TestLRU700, TestOffset42) +WikiFinal2Sec43 = nest(TestSet, TestDPR, TestStrat2, TestSkip59, TestLRU700, TestOffset43) +WikiFinal2Sec44 = nest(TestSet, TestDPR, TestStrat2, TestSkip59, TestLRU700, TestOffset44) +WikiFinal2Sec45 = nest(TestSet, TestDPR, TestStrat2, TestSkip59, TestLRU700, TestOffset45) +WikiFinal2Sec46 = nest(TestSet, TestDPR, TestStrat2, TestSkip59, TestLRU700, TestOffset46) +WikiFinal2Sec47 = nest(TestSet, TestDPR, TestStrat2, TestSkip59, TestLRU700, TestOffset47) +WikiFinal2Sec48 = nest(TestSet, TestDPR, TestStrat2, TestSkip59, TestLRU700, TestOffset48) +WikiFinal2Sec49 = nest(TestSet, TestDPR, TestStrat2, TestSkip59, TestLRU700, TestOffset49) +WikiFinal2Sec50 = nest(TestSet, TestDPR, TestStrat2, TestSkip59, TestLRU700, TestOffset50) +WikiFinal2Sec51 = nest(TestSet, TestDPR, TestStrat2, TestSkip59, TestLRU700, TestOffset51) +WikiFinal2Sec52 = nest(TestSet, TestDPR, TestStrat2, TestSkip59, TestLRU700, TestOffset52) +WikiFinal2Sec53 = nest(TestSet, TestDPR, TestStrat2, TestSkip59, TestLRU700, TestOffset53) +WikiFinal2Sec54 = nest(TestSet, TestDPR, TestStrat2, TestSkip59, TestLRU700, TestOffset54) +WikiFinal2Sec55 = nest(TestSet, TestDPR, TestStrat2, TestSkip59, TestLRU700, TestOffset55) +WikiFinal2Sec56 = nest(TestSet, TestDPR, TestStrat2, TestSkip59, TestLRU700, TestOffset56) +WikiFinal2Sec57 = nest(TestSet, TestDPR, TestStrat2, TestSkip59, TestLRU700, TestOffset57) +WikiFinal2Sec58 = nest(TestSet, TestDPR, TestStrat2, TestSkip59, TestLRU700, TestOffset58) +WikiFinal2Sec59 = nest(TestSet, TestDPR, TestStrat2, TestSkip59, TestLRU700, TestOffset59) + +TestOffset2a0= iterate('--skip-initial', ['46201']) +TestOffset2a1= iterate('--skip-initial', ['45312']) +TestOffset2a2= iterate('--skip-initial', ['48496']) +TestOffset2a3= iterate('--skip-initial', ['48258']) +TestOffset2a4= iterate('--skip-initial', ['47359']) +TestOffset2a5= iterate('--skip-initial', ['46404']) +TestOffset2a6= iterate('--skip-initial', ['48386']) +TestOffset2a7= iterate('--skip-initial', ['47127']) +TestOffset2a8= iterate('--skip-initial', ['46588']) +TestOffset2a9= iterate('--skip-initial', ['46953']) +TestOffset2a10= iterate('--skip-initial', ['47104']) +TestOffset2a11= iterate('--skip-initial', ['44466']) +TestOffset2a12= iterate('--skip-initial', ['43927']) + +TestOffset2b0= iterate('--skip-initial', ['48000']) +TestOffset2b1= iterate('--skip-initial', ['47300']) +TestOffset2b2= iterate('--skip-initial', ['46600']) +TestOffset2b3= iterate('--skip-initial', ['45900']) +TestOffset2b4= iterate('--skip-initial', ['45200']) +TestOffset2b5= iterate('--skip-initial', ['44500']) +TestOffset2b6= iterate('--skip-initial', ['43800']) + +WikiFinal2aSec0 = nest(TestSet, TestDPR, TestStrat2, TestSkip59, TestLRU150, TestOffset2a0) +WikiFinal2aSec1 = nest(TestSet, TestDPR, TestStrat2, TestSkip59, TestLRU150, TestOffset2a1) +WikiFinal2aSec2 = nest(TestSet, TestDPR, TestStrat2, TestSkip59, TestLRU700, TestOffset2a2) +WikiFinal2aSec3 = nest(TestSet, TestDPR, TestStrat2, TestSkip59, TestLRU700, TestOffset2a3) +WikiFinal2aSec4 = 
nest(TestSet, TestDPR, TestStrat2, TestSkip59, TestLRU700, TestOffset2a4) +WikiFinal2aSec5 = nest(TestSet, TestDPR, TestStrat2, TestSkip59, TestLRU150, TestOffset2a5) +WikiFinal2aSec6 = nest(TestSet, TestDPR, TestStrat2, TestSkip59, TestLRU700, TestOffset2a6) +WikiFinal2aSec7 = nest(TestSet, TestDPR, TestStrat2, TestSkip59, TestLRU700, TestOffset2a7) +WikiFinal2aSec8 = nest(TestSet, TestDPR, TestStrat2, TestSkip59, TestLRU150, TestOffset2a8) +WikiFinal2aSec9 = nest(TestSet, TestDPR, TestStrat2, TestSkip59, TestLRU150, TestOffset2a9) +WikiFinal2aSec10 = nest(TestSet, TestDPR, TestStrat2, TestSkip59, TestLRU700, TestOffset2a10) +WikiFinal2aSec11 = nest(TestSet, TestDPR, TestStrat2, TestSkip59, TestLRU700, TestOffset2a11) +WikiFinal2aSec12 = nest(TestSet, TestDPR, TestStrat2, TestSkip59, TestLRU150, TestOffset2a12) + +WikiFinal2bSec0 = nest(TestSet, TestDPR, TestStrat2, TestLRU700, TestOffset2b0) +WikiFinal2bSec1 = nest(TestSet, TestDPR, TestStrat2, TestLRU700, TestOffset2b1) +WikiFinal2bSec2 = nest(TestSet, TestDPR, TestStrat2, TestLRU700, TestOffset2b2) +WikiFinal2bSec3 = nest(TestSet, TestDPR, TestStrat2, TestLRU700, TestOffset2b3) +WikiFinal2bSec4 = nest(TestSet, TestDPR, TestStrat2, TestLRU700, TestOffset2b4) +WikiFinal2bSec5 = nest(TestSet, TestDPR, TestStrat2, TestLRU700, TestOffset2b5) +WikiFinal2bSec6 = nest(TestSet, TestDPR, TestStrat2, TestLRU700, TestOffset2b6) + + + +WikiFinal3Sec0 = nest(TestSet, TestDPR, TestStrat3, TestSkip31, TestOffset0) +WikiFinal3Sec1 = nest(TestSet, TestDPR, TestStrat3, TestSkip31, TestOffset1) +WikiFinal3Sec2 = nest(TestSet, TestDPR, TestStrat3, TestSkip31, TestOffset2) +WikiFinal3Sec3 = nest(TestSet, TestDPR, TestStrat3, TestSkip31, TestOffset3) +WikiFinal3Sec4 = nest(TestSet, TestDPR, TestStrat3, TestSkip31, TestOffset4) +WikiFinal3Sec5 = nest(TestSet, TestDPR, TestStrat3, TestSkip31, TestOffset5) +WikiFinal3Sec6 = nest(TestSet, TestDPR, TestStrat3, TestSkip31, TestOffset6) +WikiFinal3Sec7 = nest(TestSet, TestDPR, TestStrat3, TestSkip31, TestOffset7) +WikiFinal3Sec8 = nest(TestSet, TestDPR, TestStrat3, TestSkip31, TestOffset8) +WikiFinal3Sec9 = nest(TestSet, TestDPR, TestStrat3, TestSkip31, TestOffset9) +WikiFinal3Sec10 = nest(TestSet, TestDPR, TestStrat3, TestSkip31, TestOffset10) +WikiFinal3Sec11 = nest(TestSet, TestDPR, TestStrat3, TestSkip31, TestOffset11) +WikiFinal3Sec12 = nest(TestSet, TestDPR, TestStrat3, TestSkip31, TestOffset12) +WikiFinal3Sec13 = nest(TestSet, TestDPR, TestStrat3, TestSkip31, TestOffset13) +WikiFinal3Sec14 = nest(TestSet, TestDPR, TestStrat3, TestSkip31, TestOffset14) +WikiFinal3Sec15 = nest(TestSet, TestDPR, TestStrat3, TestSkip31, TestOffset15) +WikiFinal3Sec16 = nest(TestSet, TestDPR, TestStrat3, TestSkip31, TestOffset16) +WikiFinal3Sec17 = nest(TestSet, TestDPR, TestStrat3, TestSkip31, TestOffset17) +WikiFinal3Sec18 = nest(TestSet, TestDPR, TestStrat3, TestSkip31, TestOffset18) +WikiFinal3Sec19 = nest(TestSet, TestDPR, TestStrat3, TestSkip31, TestOffset19) +WikiFinal3Sec20 = nest(TestSet, TestDPR, TestStrat3, TestSkip31, TestOffset20) +WikiFinal3Sec21 = nest(TestSet, TestDPR, TestStrat3, TestSkip31, TestOffset21) +WikiFinal3Sec22 = nest(TestSet, TestDPR, TestStrat3, TestSkip31, TestOffset22) +WikiFinal3Sec23 = nest(TestSet, TestDPR, TestStrat3, TestSkip31, TestOffset23) +WikiFinal3Sec24 = nest(TestSet, TestDPR, TestStrat3, TestSkip31, TestOffset24) +WikiFinal3Sec25 = nest(TestSet, TestDPR, TestStrat3, TestSkip31, TestOffset25) +WikiFinal3Sec26 = nest(TestSet, TestDPR, TestStrat3, TestSkip31, TestOffset26) 
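# Note on the pattern above: the WikiFinal2Sec*, WikiFinal2aSec*, WikiFinal2bSec*,
# and WikiFinal3Sec* assignments all instantiate one template: a single evaluation
# configuration is split into numbered sections by varying --skip-initial, and each
# section is nested with a fixed strategy/cache setup. A loop-based sketch of one
# such family follows; it is illustrative only (the dict name is hypothetical) and
# is kept commented out so it does not add or rename any configuration actually
# defined in this file. nest() and iterate() are the helpers defined earlier in
# this file, used exactly as in the assignments above.
#
#   WikiFinal3SecByOffset = {}
#   for i in range(32):
#       Offset = iterate('--skip-initial', [str(i)])
#       WikiFinal3SecByOffset[i] = nest(TestSet, TestDPR, TestStrat3, TestSkip31, Offset)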
+WikiFinal3Sec27 = nest(TestSet, TestDPR, TestStrat3, TestSkip31, TestOffset27) +WikiFinal3Sec28 = nest(TestSet, TestDPR, TestStrat3, TestSkip31, TestOffset28) +WikiFinal3Sec29 = nest(TestSet, TestDPR, TestStrat3, TestSkip31, TestOffset29) +WikiFinal3Sec30 = nest(TestSet, TestDPR, TestStrat3, TestSkip31, TestOffset30) +WikiFinal3Sec31 = nest(TestSet, TestDPR, TestStrat3, TestSkip31, TestOffset31) + + +TestFinalStratBase1 = add_param('--strategy baseline --baseline-strategy link-most-common-toponym') +TestFinalStratBase2 = add_param('--strategy baseline --baseline-strategy regdist-most-common-toponym') +TestFinalStratBase3 = add_param('--strategy baseline --baseline-strategy num-articles') +TestFinalStratBase4 = add_param('--strategy baseline --baseline-strategy random') + +WikiFinalBase1 = nest(TestSet, TestDPR, TestFinalStratBase1) +FinalBase1aSkip = add_param('--skip-initial 46916') +WikiFinalBase1a = nest(TestSet, TestDPR, TestFinalStratBase1, FinalBase1aSkip) +WikiFinalBase2 = nest(TestSet, TestDPR, TestFinalStratBase2) +WikiFinalBase3 = nest(TestSet, TestDPR, TestFinalStratBase3) +WikiFinalBase4 = nest(TestSet, TestDPR, TestFinalStratBase4) + +FinalBase2aSkip = add_param('--skip-initial 40105') +WikiFinalBase2a = nest(TestSet, TestDPR, TestFinalStratBase2, FinalBase2aSkip) + +Split32Sec0 = add_param('--skip-initial 0 --every-nth 32') +Split32Sec1 = add_param('--skip-initial 1 --every-nth 32') +Split32Sec2 = add_param('--skip-initial 2 --every-nth 32') +Split32Sec3 = add_param('--skip-initial 3 --every-nth 32') +Split32Sec4 = add_param('--skip-initial 4 --every-nth 32') +Split32Sec5 = add_param('--skip-initial 5 --every-nth 32') +Split32Sec6 = add_param('--skip-initial 6 --every-nth 32') +Split32Sec7 = add_param('--skip-initial 7 --every-nth 32') +Split32Sec8 = add_param('--skip-initial 8 --every-nth 32') +Split32Sec9 = add_param('--skip-initial 9 --every-nth 32') +Split32Sec10 = add_param('--skip-initial 10 --every-nth 32') +Split32Sec11 = add_param('--skip-initial 11 --every-nth 32') +Split32Sec12 = add_param('--skip-initial 12 --every-nth 32') +Split32Sec13 = add_param('--skip-initial 13 --every-nth 32') +Split32Sec14 = add_param('--skip-initial 14 --every-nth 32') +Split32Sec15 = add_param('--skip-initial 15 --every-nth 32') +Split32Sec16 = add_param('--skip-initial 16 --every-nth 32') +Split32Sec17 = add_param('--skip-initial 17 --every-nth 32') +Split32Sec18 = add_param('--skip-initial 18 --every-nth 32') +Split32Sec19 = add_param('--skip-initial 19 --every-nth 32') +Split32Sec20 = add_param('--skip-initial 20 --every-nth 32') +Split32Sec21 = add_param('--skip-initial 21 --every-nth 32') +Split32Sec22 = add_param('--skip-initial 22 --every-nth 32') +Split32Sec23 = add_param('--skip-initial 23 --every-nth 32') +Split32Sec24 = add_param('--skip-initial 24 --every-nth 32') +Split32Sec25 = add_param('--skip-initial 25 --every-nth 32') +Split32Sec26 = add_param('--skip-initial 26 --every-nth 32') +Split32Sec27 = add_param('--skip-initial 27 --every-nth 32') +Split32Sec28 = add_param('--skip-initial 28 --every-nth 32') +Split32Sec29 = add_param('--skip-initial 29 --every-nth 32') +Split32Sec30 = add_param('--skip-initial 30 --every-nth 32') +Split32Sec31 = add_param('--skip-initial 31 --every-nth 32') + +WikiFinalBase1point5Sec0 = nest(TestSet, DPRpoint5, TestFinalStratBase1, Split32Sec0) +WikiFinalBase1point5Sec1 = nest(TestSet, DPRpoint5, TestFinalStratBase1, Split32Sec1) +WikiFinalBase1point5Sec2 = nest(TestSet, DPRpoint5, TestFinalStratBase1, Split32Sec2) +WikiFinalBase1point5Sec3 
= nest(TestSet, DPRpoint5, TestFinalStratBase1, Split32Sec3) +WikiFinalBase1point5Sec4 = nest(TestSet, DPRpoint5, TestFinalStratBase1, Split32Sec4) +WikiFinalBase1point5Sec5 = nest(TestSet, DPRpoint5, TestFinalStratBase1, Split32Sec5) +WikiFinalBase1point5Sec6 = nest(TestSet, DPRpoint5, TestFinalStratBase1, Split32Sec6) +WikiFinalBase1point5Sec7 = nest(TestSet, DPRpoint5, TestFinalStratBase1, Split32Sec7) +WikiFinalBase1point5Sec8 = nest(TestSet, DPRpoint5, TestFinalStratBase1, Split32Sec8) +WikiFinalBase1point5Sec9 = nest(TestSet, DPRpoint5, TestFinalStratBase1, Split32Sec9) +WikiFinalBase1point5Sec10 = nest(TestSet, DPRpoint5, TestFinalStratBase1, Split32Sec10) +WikiFinalBase1point5Sec11 = nest(TestSet, DPRpoint5, TestFinalStratBase1, Split32Sec11) +WikiFinalBase1point5Sec12 = nest(TestSet, DPRpoint5, TestFinalStratBase1, Split32Sec12) +WikiFinalBase1point5Sec13 = nest(TestSet, DPRpoint5, TestFinalStratBase1, Split32Sec13) +WikiFinalBase1point5Sec14 = nest(TestSet, DPRpoint5, TestFinalStratBase1, Split32Sec14) +WikiFinalBase1point5Sec15 = nest(TestSet, DPRpoint5, TestFinalStratBase1, Split32Sec15) +WikiFinalBase1point5Sec16 = nest(TestSet, DPRpoint5, TestFinalStratBase1, Split32Sec16) +WikiFinalBase1point5Sec17 = nest(TestSet, DPRpoint5, TestFinalStratBase1, Split32Sec17) +WikiFinalBase1point5Sec18 = nest(TestSet, DPRpoint5, TestFinalStratBase1, Split32Sec18) +WikiFinalBase1point5Sec19 = nest(TestSet, DPRpoint5, TestFinalStratBase1, Split32Sec19) +WikiFinalBase1point5Sec20 = nest(TestSet, DPRpoint5, TestFinalStratBase1, Split32Sec20) +WikiFinalBase1point5Sec21 = nest(TestSet, DPRpoint5, TestFinalStratBase1, Split32Sec21) +WikiFinalBase1point5Sec22 = nest(TestSet, DPRpoint5, TestFinalStratBase1, Split32Sec22) +WikiFinalBase1point5Sec23 = nest(TestSet, DPRpoint5, TestFinalStratBase1, Split32Sec23) +WikiFinalBase1point5Sec24 = nest(TestSet, DPRpoint5, TestFinalStratBase1, Split32Sec24) +WikiFinalBase1point5Sec25 = nest(TestSet, DPRpoint5, TestFinalStratBase1, Split32Sec25) +WikiFinalBase1point5Sec26 = nest(TestSet, DPRpoint5, TestFinalStratBase1, Split32Sec26) +WikiFinalBase1point5Sec27 = nest(TestSet, DPRpoint5, TestFinalStratBase1, Split32Sec27) +WikiFinalBase1point5Sec28 = nest(TestSet, DPRpoint5, TestFinalStratBase1, Split32Sec28) +WikiFinalBase1point5Sec29 = nest(TestSet, DPRpoint5, TestFinalStratBase1, Split32Sec29) +WikiFinalBase1point5Sec30 = nest(TestSet, DPRpoint5, TestFinalStratBase1, Split32Sec30) +WikiFinalBase1point5Sec31 = nest(TestSet, DPRpoint5, TestFinalStratBase1, Split32Sec31) + +WikiFinalBase3DPR5Sec0 = nest(TestSet, DPR5, TestFinalStratBase3, Split32Sec0) +WikiFinalBase3DPR5Sec1 = nest(TestSet, DPR5, TestFinalStratBase3, Split32Sec1) +WikiFinalBase3DPR5Sec2 = nest(TestSet, DPR5, TestFinalStratBase3, Split32Sec2) +WikiFinalBase3DPR5Sec3 = nest(TestSet, DPR5, TestFinalStratBase3, Split32Sec3) +WikiFinalBase3DPR5Sec4 = nest(TestSet, DPR5, TestFinalStratBase3, Split32Sec4) +WikiFinalBase3DPR5Sec5 = nest(TestSet, DPR5, TestFinalStratBase3, Split32Sec5) +WikiFinalBase3DPR5Sec6 = nest(TestSet, DPR5, TestFinalStratBase3, Split32Sec6) +WikiFinalBase3DPR5Sec7 = nest(TestSet, DPR5, TestFinalStratBase3, Split32Sec7) +WikiFinalBase3DPR5Sec8 = nest(TestSet, DPR5, TestFinalStratBase3, Split32Sec8) +WikiFinalBase3DPR5Sec9 = nest(TestSet, DPR5, TestFinalStratBase3, Split32Sec9) +WikiFinalBase3DPR5Sec10 = nest(TestSet, DPR5, TestFinalStratBase3, Split32Sec10) +WikiFinalBase3DPR5Sec11 = nest(TestSet, DPR5, TestFinalStratBase3, Split32Sec11) +WikiFinalBase3DPR5Sec12 = nest(TestSet, DPR5, 
TestFinalStratBase3, Split32Sec12) +WikiFinalBase3DPR5Sec13 = nest(TestSet, DPR5, TestFinalStratBase3, Split32Sec13) +WikiFinalBase3DPR5Sec14 = nest(TestSet, DPR5, TestFinalStratBase3, Split32Sec14) +WikiFinalBase3DPR5Sec15 = nest(TestSet, DPR5, TestFinalStratBase3, Split32Sec15) +WikiFinalBase3DPR5Sec16 = nest(TestSet, DPR5, TestFinalStratBase3, Split32Sec16) +WikiFinalBase3DPR5Sec17 = nest(TestSet, DPR5, TestFinalStratBase3, Split32Sec17) +WikiFinalBase3DPR5Sec18 = nest(TestSet, DPR5, TestFinalStratBase3, Split32Sec18) +WikiFinalBase3DPR5Sec19 = nest(TestSet, DPR5, TestFinalStratBase3, Split32Sec19) +WikiFinalBase3DPR5Sec20 = nest(TestSet, DPR5, TestFinalStratBase3, Split32Sec20) +WikiFinalBase3DPR5Sec21 = nest(TestSet, DPR5, TestFinalStratBase3, Split32Sec21) +WikiFinalBase3DPR5Sec22 = nest(TestSet, DPR5, TestFinalStratBase3, Split32Sec22) +WikiFinalBase3DPR5Sec23 = nest(TestSet, DPR5, TestFinalStratBase3, Split32Sec23) +WikiFinalBase3DPR5Sec24 = nest(TestSet, DPR5, TestFinalStratBase3, Split32Sec24) +WikiFinalBase3DPR5Sec25 = nest(TestSet, DPR5, TestFinalStratBase3, Split32Sec25) +WikiFinalBase3DPR5Sec26 = nest(TestSet, DPR5, TestFinalStratBase3, Split32Sec26) +WikiFinalBase3DPR5Sec27 = nest(TestSet, DPR5, TestFinalStratBase3, Split32Sec27) +WikiFinalBase3DPR5Sec28 = nest(TestSet, DPR5, TestFinalStratBase3, Split32Sec28) +WikiFinalBase3DPR5Sec29 = nest(TestSet, DPR5, TestFinalStratBase3, Split32Sec29) +WikiFinalBase3DPR5Sec30 = nest(TestSet, DPR5, TestFinalStratBase3, Split32Sec30) +WikiFinalBase3DPR5Sec31 = nest(TestSet, DPR5, TestFinalStratBase3, Split32Sec31) + +# Twitter Experiments + + +TwitterDPR1 = iterate('--degrees-per-region', [0.5, 1, 0.1, 5, 10]) +TwitterExper1 = nest(TwitterDPR1, KLDivStrategy) +TwitterDPR2 = iterate('--degrees-per-region', [1, 5, 10, 0.5, 0.1]) +TwitterStrategy2 = add_param('--strategy naive-bayes-with-baseline --strategy smoothed-cosine-similarity --strategy average-cell-probability') +TwitterExper2 = nest(TwitterDPR2, TwitterStrategy2) +TwitterDPR3 = iterate('--degrees-per-region', [5, 10, 1, 0.5, 0.1]) +TwitterStrategy3 = add_param('--strategy cosine-similarity') +TwitterExper3 = nest(TwitterDPR3, TwitterStrategy3) +TwitterBaselineExper1 = nest(TwitterDPR3, BaselineStrategies) +TwitterAllThresh1 = iterate('--doc-thresh', [40, 5, 0, 20, 10, 3, 2]) +TwitterAllThresh2 = iterate('--doc-thresh', [20, 10, 3, 2]) +TwitterAllThresh2 = iterate('--doc-thresh', [5, 10, 3, 2]) +TwitterDevSet = add_param('--eval-set dev') +TwitterTestSet = add_param('--eval-set test') + +Thresh0 = iterate('--doc-thresh', [0]) +Thresh1 = iterate('--doc-thresh', [1]) # Same as thresh 0 +Thresh2 = iterate('--doc-thresh', [2]) +Thresh3 = iterate('--doc-thresh', [3]) +Thresh5 = iterate('--doc-thresh', [5]) +Thresh10 = iterate('--doc-thresh', [10]) +Thresh20 = iterate('--doc-thresh', [20]) +Thresh40 = iterate('--doc-thresh', [40]) + +#TwitterAllThreshExper1 = nest(TwitterAllThresh1, TwitterExper1) +TwitterAllThreshExper1 = nest(TwitterAllThresh1, TwitterDPR1, KLDivStrategy) +TwitterDPR5 = iterate('--degrees-per-region', [5]) +TwitterAllThreshExper2 = nest(TwitterAllThresh2, TwitterDPR5, KLDivStrategy) +TwitterAllThreshExper3 = nest(TwitterAllThresh2, TwitterDPR1, KLDivStrategy) + +TwitterDevStrat1Thresh20DPR1 = nest(Thresh20, DPR1, TestStrat1, TwitterDevSet) +TwitterDevStrat1Thresh20DPR5 = nest(Thresh20, DPR5, TestStrat1, TwitterDevSet) +TwitterDevStrat1Thresh20DPR10 = nest(Thresh20, DPR10, TestStrat1, TwitterDevSet) +TwitterDevStrat1Thresh20DPRpoint1 = nest(Thresh20, DPRpoint1, 
TestStrat1, TwitterDevSet) +TwitterDevStrat1Thresh20DPRpoint5 = nest(Thresh20, DPRpoint5, TestStrat1, TwitterDevSet) +TwitterDevStrat1Thresh40DPR1 = nest(Thresh40, DPR1, TestStrat1, TwitterDevSet) +TwitterDevStrat1Thresh40DPR5 = nest(Thresh40, DPR5, TestStrat1, TwitterDevSet) +TwitterDevStrat1Thresh40DPR10 = nest(Thresh40, DPR10, TestStrat1, TwitterDevSet) +TwitterDevStrat1Thresh40DPRpoint1 = nest(Thresh40, DPRpoint1, TestStrat1, TwitterDevSet) +TwitterDevStrat1Thresh40DPRpoint5 = nest(Thresh40, DPRpoint5, TestStrat1, TwitterDevSet) +TwitterDevStrat1Thresh10DPR1 = nest(Thresh10, DPR1, TestStrat1, TwitterDevSet) +TwitterDevStrat1Thresh10DPR5 = nest(Thresh10, DPR5, TestStrat1, TwitterDevSet) +TwitterDevStrat1Thresh10DPR10 = nest(Thresh10, DPR10, TestStrat1, TwitterDevSet) +TwitterDevStrat1Thresh10DPRpoint1 = nest(Thresh10, DPRpoint1, TestStrat1, TwitterDevSet) +TwitterDevStrat1Thresh10DPRpoint5 = nest(Thresh10, DPRpoint5, TestStrat1, TwitterDevSet) +TwitterDevStrat1Thresh5DPR1 = nest(Thresh5, DPR1, TestStrat1, TwitterDevSet) +TwitterDevStrat1Thresh5DPR5 = nest(Thresh5, DPR5, TestStrat1, TwitterDevSet) +TwitterDevStrat1Thresh5DPR10 = nest(Thresh5, DPR10, TestStrat1, TwitterDevSet) +TwitterDevStrat1Thresh5DPRpoint1 = nest(Thresh5, DPRpoint1, TestStrat1, TwitterDevSet) +TwitterDevStrat1Thresh5DPRpoint5 = nest(Thresh5, DPRpoint5, TestStrat1, TwitterDevSet) +TwitterDevStrat1Thresh2DPR1 = nest(Thresh2, DPR1, TestStrat1, TwitterDevSet) +TwitterDevStrat1Thresh2DPR5 = nest(Thresh2, DPR5, TestStrat1, TwitterDevSet) +TwitterDevStrat1Thresh2DPR10 = nest(Thresh2, DPR10, TestStrat1, TwitterDevSet) +TwitterDevStrat1Thresh2DPRpoint1 = nest(Thresh2, DPRpoint1, TestStrat1, TwitterDevSet) +TwitterDevStrat1Thresh2DPRpoint5 = nest(Thresh2, DPRpoint5, TestStrat1, TwitterDevSet) +TwitterDevStrat1Thresh3DPR1 = nest(Thresh3, DPR1, TestStrat1, TwitterDevSet) +TwitterDevStrat1Thresh3DPR5 = nest(Thresh3, DPR5, TestStrat1, TwitterDevSet) +TwitterDevStrat1Thresh3DPR10 = nest(Thresh3, DPR10, TestStrat1, TwitterDevSet) +TwitterDevStrat1Thresh3DPRpoint1 = nest(Thresh3, DPRpoint1, TestStrat1, TwitterDevSet) +TwitterDevStrat1Thresh3DPRpoint5 = nest(Thresh3, DPRpoint5, TestStrat1, TwitterDevSet) +TwitterDevStrat1Thresh1DPR1 = nest(Thresh1, DPR1, TestStrat1, TwitterDevSet) +TwitterDevStrat1Thresh1DPR5 = nest(Thresh1, DPR5, TestStrat1, TwitterDevSet) +TwitterDevStrat1Thresh1DPR10 = nest(Thresh1, DPR10, TestStrat1, TwitterDevSet) +TwitterDevStrat1Thresh1DPRpoint1 = nest(Thresh1, DPRpoint1, TestStrat1, TwitterDevSet) +TwitterDevStrat1Thresh1DPRpoint5 = nest(Thresh1, DPRpoint5, TestStrat1, TwitterDevSet) +TwitterDevStrat1Thresh0DPR1 = nest(Thresh0, DPR1, TestStrat1, TwitterDevSet) +TwitterDevStrat1Thresh0DPR5 = nest(Thresh0, DPR5, TestStrat1, TwitterDevSet) +TwitterDevStrat1Thresh0DPR10 = nest(Thresh0, DPR10, TestStrat1, TwitterDevSet) +TwitterDevStrat1Thresh0DPRpoint1 = nest(Thresh0, DPRpoint1, TestStrat1, TwitterDevSet) +TwitterDevStrat1Thresh0DPRpoint5 = nest(Thresh0, DPRpoint5, TestStrat1, TwitterDevSet) + +TwitterDevStrat2Thresh20DPR1 = nest(Thresh20, DPR1, TestStrat2, TwitterDevSet) +TwitterDevStrat2Thresh20DPR5 = nest(Thresh20, DPR5, TestStrat2, TwitterDevSet) +TwitterDevStrat2Thresh20DPR10 = nest(Thresh20, DPR10, TestStrat2, TwitterDevSet) +TwitterDevStrat2Thresh20DPRpoint1 = nest(Thresh20, DPRpoint1, TestStrat2, TwitterDevSet) +TwitterDevStrat2Thresh20DPRpoint5 = nest(Thresh20, DPRpoint5, TestStrat2, TwitterDevSet) +TwitterDevStrat2Thresh40DPR1 = nest(Thresh40, DPR1, TestStrat2, TwitterDevSet) +TwitterDevStrat2Thresh40DPR5 
= nest(Thresh40, DPR5, TestStrat2, TwitterDevSet) +TwitterDevStrat2Thresh40DPR10 = nest(Thresh40, DPR10, TestStrat2, TwitterDevSet) +TwitterDevStrat2Thresh40DPRpoint1 = nest(Thresh40, DPRpoint1, TestStrat2, TwitterDevSet) +TwitterDevStrat2Thresh40DPRpoint5 = nest(Thresh40, DPRpoint5, TestStrat2, TwitterDevSet) +TwitterDevStrat2Thresh10DPR1 = nest(Thresh10, DPR1, TestStrat2, TwitterDevSet) +TwitterDevStrat2Thresh10DPR5 = nest(Thresh10, DPR5, TestStrat2, TwitterDevSet) +TwitterDevStrat2Thresh10DPR10 = nest(Thresh10, DPR10, TestStrat2, TwitterDevSet) +TwitterDevStrat2Thresh10DPRpoint1 = nest(Thresh10, DPRpoint1, TestStrat2, TwitterDevSet) +TwitterDevStrat2Thresh10DPRpoint5 = nest(Thresh10, DPRpoint5, TestStrat2, TwitterDevSet) +TwitterDevStrat2Thresh5DPR1 = nest(Thresh5, DPR1, TestStrat2, TwitterDevSet) +TwitterDevStrat2Thresh5DPR5 = nest(Thresh5, DPR5, TestStrat2, TwitterDevSet) +TwitterDevStrat2Thresh5DPR10 = nest(Thresh5, DPR10, TestStrat2, TwitterDevSet) +TwitterDevStrat2Thresh5DPRpoint1 = nest(Thresh5, DPRpoint1, TestStrat2, TwitterDevSet) +TwitterDevStrat2Thresh5DPRpoint5 = nest(Thresh5, DPRpoint5, TestStrat2, TwitterDevSet) +TwitterDevStrat2Thresh2DPR1 = nest(Thresh2, DPR1, TestStrat2, TwitterDevSet) +TwitterDevStrat2Thresh2DPR5 = nest(Thresh2, DPR5, TestStrat2, TwitterDevSet) +TwitterDevStrat2Thresh2DPR10 = nest(Thresh2, DPR10, TestStrat2, TwitterDevSet) +TwitterDevStrat2Thresh2DPRpoint1 = nest(Thresh2, DPRpoint1, TestStrat2, TwitterDevSet) +TwitterDevStrat2Thresh2DPRpoint5 = nest(Thresh2, DPRpoint5, TestStrat2, TwitterDevSet) +TwitterDevStrat2Thresh3DPR1 = nest(Thresh3, DPR1, TestStrat2, TwitterDevSet) +TwitterDevStrat2Thresh3DPR5 = nest(Thresh3, DPR5, TestStrat2, TwitterDevSet) +TwitterDevStrat2Thresh3DPR10 = nest(Thresh3, DPR10, TestStrat2, TwitterDevSet) +TwitterDevStrat2Thresh3DPRpoint1 = nest(Thresh3, DPRpoint1, TestStrat2, TwitterDevSet) +TwitterDevStrat2Thresh3DPRpoint5 = nest(Thresh3, DPRpoint5, TestStrat2, TwitterDevSet) +TwitterDevStrat2Thresh1DPR1 = nest(Thresh1, DPR1, TestStrat2, TwitterDevSet) +TwitterDevStrat2Thresh1DPR5 = nest(Thresh1, DPR5, TestStrat2, TwitterDevSet) +TwitterDevStrat2Thresh1DPR10 = nest(Thresh1, DPR10, TestStrat2, TwitterDevSet) +TwitterDevStrat2Thresh1DPRpoint1 = nest(Thresh1, DPRpoint1, TestStrat2, TwitterDevSet) +TwitterDevStrat2Thresh1DPRpoint5 = nest(Thresh1, DPRpoint5, TestStrat2, TwitterDevSet) +TwitterDevStrat2Thresh0DPR1 = nest(Thresh0, DPR1, TestStrat2, TwitterDevSet) +TwitterDevStrat2Thresh0DPR5 = nest(Thresh0, DPR5, TestStrat2, TwitterDevSet) +TwitterDevStrat2Thresh0DPR10 = nest(Thresh0, DPR10, TestStrat2, TwitterDevSet) +TwitterDevStrat2Thresh0DPRpoint1 = nest(Thresh0, DPRpoint1, TestStrat2, TwitterDevSet) +TwitterDevStrat2Thresh0DPRpoint5 = nest(Thresh0, DPRpoint5, TestStrat2, TwitterDevSet) + +TwitterDevStrat3Thresh20DPR1 = nest(Thresh20, DPR1, TestStrat3, TwitterDevSet) +TwitterDevStrat3Thresh20DPR5 = nest(Thresh20, DPR5, TestStrat3, TwitterDevSet) +TwitterDevStrat3Thresh20DPR10 = nest(Thresh20, DPR10, TestStrat3, TwitterDevSet) +TwitterDevStrat3Thresh20DPRpoint1 = nest(Thresh20, DPRpoint1, TestStrat3, TwitterDevSet) +TwitterDevStrat3Thresh20DPRpoint5 = nest(Thresh20, DPRpoint5, TestStrat3, TwitterDevSet) +TwitterDevStrat3Thresh40DPR1 = nest(Thresh40, DPR1, TestStrat3, TwitterDevSet) +TwitterDevStrat3Thresh40DPR5 = nest(Thresh40, DPR5, TestStrat3, TwitterDevSet) +TwitterDevStrat3Thresh40DPR10 = nest(Thresh40, DPR10, TestStrat3, TwitterDevSet) +TwitterDevStrat3Thresh40DPRpoint1 = nest(Thresh40, DPRpoint1, TestStrat3, TwitterDevSet) 
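# The TwitterDevStrat{1,2,3}Thresh*DPR* assignments enumerate a full grid over the
# document threshold (Thresh0 through Thresh40) and degrees-per-region (DPR1 through
# DPRpoint5) for each strategy on the dev split. A loop-based sketch of the same
# grid follows; it is illustrative only (the dict name is hypothetical) and is kept
# commented out so it does not add to the configurations defined by name in this file.
#
#   TwitterDevGrid = {}
#   for tname, t in [('0', Thresh0), ('1', Thresh1), ('2', Thresh2), ('3', Thresh3),
#                    ('5', Thresh5), ('10', Thresh10), ('20', Thresh20), ('40', Thresh40)]:
#       for dname, d in [('1', DPR1), ('5', DPR5), ('10', DPR10),
#                        ('point1', DPRpoint1), ('point5', DPRpoint5)]:
#           for sname, s in [('1', TestStrat1), ('2', TestStrat2), ('3', TestStrat3)]:
#               TwitterDevGrid[(sname, tname, dname)] = nest(t, d, s, TwitterDevSet)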
+TwitterDevStrat3Thresh40DPRpoint5 = nest(Thresh40, DPRpoint5, TestStrat3, TwitterDevSet) +TwitterDevStrat3Thresh10DPR1 = nest(Thresh10, DPR1, TestStrat3, TwitterDevSet) +TwitterDevStrat3Thresh10DPR5 = nest(Thresh10, DPR5, TestStrat3, TwitterDevSet) +TwitterDevStrat3Thresh10DPR10 = nest(Thresh10, DPR10, TestStrat3, TwitterDevSet) +TwitterDevStrat3Thresh10DPRpoint1 = nest(Thresh10, DPRpoint1, TestStrat3, TwitterDevSet) +TwitterDevStrat3Thresh10DPRpoint5 = nest(Thresh10, DPRpoint5, TestStrat3, TwitterDevSet) +TwitterDevStrat3Thresh5DPR1 = nest(Thresh5, DPR1, TestStrat3, TwitterDevSet) +TwitterDevStrat3Thresh5DPR5 = nest(Thresh5, DPR5, TestStrat3, TwitterDevSet) +TwitterDevStrat3Thresh5DPR10 = nest(Thresh5, DPR10, TestStrat3, TwitterDevSet) +TwitterDevStrat3Thresh5DPRpoint1 = nest(Thresh5, DPRpoint1, TestStrat3, TwitterDevSet) +TwitterDevStrat3Thresh5DPRpoint5 = nest(Thresh5, DPRpoint5, TestStrat3, TwitterDevSet) +TwitterDevStrat3Thresh2DPR1 = nest(Thresh2, DPR1, TestStrat3, TwitterDevSet) +TwitterDevStrat3Thresh2DPR5 = nest(Thresh2, DPR5, TestStrat3, TwitterDevSet) +TwitterDevStrat3Thresh2DPR10 = nest(Thresh2, DPR10, TestStrat3, TwitterDevSet) +TwitterDevStrat3Thresh2DPRpoint1 = nest(Thresh2, DPRpoint1, TestStrat3, TwitterDevSet) +TwitterDevStrat3Thresh2DPRpoint5 = nest(Thresh2, DPRpoint5, TestStrat3, TwitterDevSet) +TwitterDevStrat3Thresh3DPR1 = nest(Thresh3, DPR1, TestStrat3, TwitterDevSet) +TwitterDevStrat3Thresh3DPR5 = nest(Thresh3, DPR5, TestStrat3, TwitterDevSet) +TwitterDevStrat3Thresh3DPR10 = nest(Thresh3, DPR10, TestStrat3, TwitterDevSet) +TwitterDevStrat3Thresh3DPRpoint1 = nest(Thresh3, DPRpoint1, TestStrat3, TwitterDevSet) +TwitterDevStrat3Thresh3DPRpoint5 = nest(Thresh3, DPRpoint5, TestStrat3, TwitterDevSet) +TwitterDevStrat3Thresh1DPR1 = nest(Thresh1, DPR1, TestStrat3, TwitterDevSet) +TwitterDevStrat3Thresh1DPR5 = nest(Thresh1, DPR5, TestStrat3, TwitterDevSet) +TwitterDevStrat3Thresh1DPR10 = nest(Thresh1, DPR10, TestStrat3, TwitterDevSet) +TwitterDevStrat3Thresh1DPRpoint1 = nest(Thresh1, DPRpoint1, TestStrat3, TwitterDevSet) +TwitterDevStrat3Thresh1DPRpoint5 = nest(Thresh1, DPRpoint5, TestStrat3, TwitterDevSet) +TwitterDevStrat3Thresh0DPR1 = nest(Thresh0, DPR1, TestStrat3, TwitterDevSet) +TwitterDevStrat3Thresh0DPR5 = nest(Thresh0, DPR5, TestStrat3, TwitterDevSet) +TwitterDevStrat3Thresh0DPR10 = nest(Thresh0, DPR10, TestStrat3, TwitterDevSet) +TwitterDevStrat3Thresh0DPRpoint1 = nest(Thresh0, DPRpoint1, TestStrat3, TwitterDevSet) +TwitterDevStrat3Thresh0DPRpoint5 = nest(Thresh0, DPRpoint5, TestStrat3, TwitterDevSet) + +TwitterDevStratBase1Thresh20DPR1 = nest(Thresh20, DPR1, TestFinalStratBase1, TwitterDevSet) +TwitterDevStratBase1Thresh20DPR5 = nest(Thresh20, DPR5, TestFinalStratBase1, TwitterDevSet) +TwitterDevStratBase1Thresh20DPR10 = nest(Thresh20, DPR10, TestFinalStratBase1, TwitterDevSet) +TwitterDevStratBase1Thresh20DPRpoint1 = nest(Thresh20, DPRpoint1, TestFinalStratBase1, TwitterDevSet) +TwitterDevStratBase1Thresh20DPRpoint5 = nest(Thresh20, DPRpoint5, TestFinalStratBase1, TwitterDevSet) +TwitterDevStratBase1Thresh40DPR1 = nest(Thresh40, DPR1, TestFinalStratBase1, TwitterDevSet) +TwitterDevStratBase1Thresh40DPR5 = nest(Thresh40, DPR5, TestFinalStratBase1, TwitterDevSet) +TwitterDevStratBase1Thresh40DPR10 = nest(Thresh40, DPR10, TestFinalStratBase1, TwitterDevSet) +TwitterDevStratBase1Thresh40DPRpoint1 = nest(Thresh40, DPRpoint1, TestFinalStratBase1, TwitterDevSet) +TwitterDevStratBase1Thresh40DPRpoint5 = nest(Thresh40, DPRpoint5, TestFinalStratBase1, TwitterDevSet) 
+TwitterDevStratBase1Thresh10DPR1 = nest(Thresh10, DPR1, TestFinalStratBase1, TwitterDevSet) +TwitterDevStratBase1Thresh10DPR5 = nest(Thresh10, DPR5, TestFinalStratBase1, TwitterDevSet) +TwitterDevStratBase1Thresh10DPR10 = nest(Thresh10, DPR10, TestFinalStratBase1, TwitterDevSet) +TwitterDevStratBase1Thresh10DPRpoint1 = nest(Thresh10, DPRpoint1, TestFinalStratBase1, TwitterDevSet) +TwitterDevStratBase1Thresh10DPRpoint5 = nest(Thresh10, DPRpoint5, TestFinalStratBase1, TwitterDevSet) +TwitterDevStratBase1Thresh5DPR1 = nest(Thresh5, DPR1, TestFinalStratBase1, TwitterDevSet) +TwitterDevStratBase1Thresh5DPR5 = nest(Thresh5, DPR5, TestFinalStratBase1, TwitterDevSet) +TwitterDevStratBase1Thresh5DPR10 = nest(Thresh5, DPR10, TestFinalStratBase1, TwitterDevSet) +TwitterDevStratBase1Thresh5DPRpoint1 = nest(Thresh5, DPRpoint1, TestFinalStratBase1, TwitterDevSet) +TwitterDevStratBase1Thresh5DPRpoint5 = nest(Thresh5, DPRpoint5, TestFinalStratBase1, TwitterDevSet) +TwitterDevStratBase1Thresh2DPR1 = nest(Thresh2, DPR1, TestFinalStratBase1, TwitterDevSet) +TwitterDevStratBase1Thresh2DPR5 = nest(Thresh2, DPR5, TestFinalStratBase1, TwitterDevSet) +TwitterDevStratBase1Thresh2DPR10 = nest(Thresh2, DPR10, TestFinalStratBase1, TwitterDevSet) +TwitterDevStratBase1Thresh2DPRpoint1 = nest(Thresh2, DPRpoint1, TestFinalStratBase1, TwitterDevSet) +TwitterDevStratBase1Thresh2DPRpoint5 = nest(Thresh2, DPRpoint5, TestFinalStratBase1, TwitterDevSet) +TwitterDevStratBase1Thresh3DPR1 = nest(Thresh3, DPR1, TestFinalStratBase1, TwitterDevSet) +TwitterDevStratBase1Thresh3DPR5 = nest(Thresh3, DPR5, TestFinalStratBase1, TwitterDevSet) +TwitterDevStratBase1Thresh3DPR10 = nest(Thresh3, DPR10, TestFinalStratBase1, TwitterDevSet) +TwitterDevStratBase1Thresh3DPRpoint1 = nest(Thresh3, DPRpoint1, TestFinalStratBase1, TwitterDevSet) +TwitterDevStratBase1Thresh3DPRpoint5 = nest(Thresh3, DPRpoint5, TestFinalStratBase1, TwitterDevSet) +TwitterDevStratBase1Thresh1DPR1 = nest(Thresh1, DPR1, TestFinalStratBase1, TwitterDevSet) +TwitterDevStratBase1Thresh1DPR5 = nest(Thresh1, DPR5, TestFinalStratBase1, TwitterDevSet) +TwitterDevStratBase1Thresh1DPR10 = nest(Thresh1, DPR10, TestFinalStratBase1, TwitterDevSet) +TwitterDevStratBase1Thresh1DPRpoint1 = nest(Thresh1, DPRpoint1, TestFinalStratBase1, TwitterDevSet) +TwitterDevStratBase1Thresh1DPRpoint5 = nest(Thresh1, DPRpoint5, TestFinalStratBase1, TwitterDevSet) +TwitterDevStratBase1Thresh0DPR1 = nest(Thresh0, DPR1, TestFinalStratBase1, TwitterDevSet) +TwitterDevStratBase1Thresh0DPR5 = nest(Thresh0, DPR5, TestFinalStratBase1, TwitterDevSet) +TwitterDevStratBase1Thresh0DPR10 = nest(Thresh0, DPR10, TestFinalStratBase1, TwitterDevSet) +TwitterDevStratBase1Thresh0DPRpoint1 = nest(Thresh0, DPRpoint1, TestFinalStratBase1, TwitterDevSet) +TwitterDevStratBase1Thresh0DPRpoint5 = nest(Thresh0, DPRpoint5, TestFinalStratBase1, TwitterDevSet) + +TwitterDevStratBase2Thresh20DPR1 = nest(Thresh20, DPR1, TestFinalStratBase2, TwitterDevSet) +TwitterDevStratBase2Thresh20DPR5 = nest(Thresh20, DPR5, TestFinalStratBase2, TwitterDevSet) +TwitterDevStratBase2Thresh20DPR10 = nest(Thresh20, DPR10, TestFinalStratBase2, TwitterDevSet) +TwitterDevStratBase2Thresh20DPRpoint1 = nest(Thresh20, DPRpoint1, TestFinalStratBase2, TwitterDevSet) +TwitterDevStratBase2Thresh20DPRpoint5 = nest(Thresh20, DPRpoint5, TestFinalStratBase2, TwitterDevSet) +TwitterDevStratBase2Thresh40DPR1 = nest(Thresh40, DPR1, TestFinalStratBase2, TwitterDevSet) +TwitterDevStratBase2Thresh40DPR5 = nest(Thresh40, DPR5, TestFinalStratBase2, TwitterDevSet) 
+TwitterDevStratBase2Thresh40DPR10 = nest(Thresh40, DPR10, TestFinalStratBase2, TwitterDevSet) +TwitterDevStratBase2Thresh40DPRpoint1 = nest(Thresh40, DPRpoint1, TestFinalStratBase2, TwitterDevSet) +TwitterDevStratBase2Thresh40DPRpoint5 = nest(Thresh40, DPRpoint5, TestFinalStratBase2, TwitterDevSet) +TwitterDevStratBase2Thresh10DPR1 = nest(Thresh10, DPR1, TestFinalStratBase2, TwitterDevSet) +TwitterDevStratBase2Thresh10DPR5 = nest(Thresh10, DPR5, TestFinalStratBase2, TwitterDevSet) +TwitterDevStratBase2Thresh10DPR10 = nest(Thresh10, DPR10, TestFinalStratBase2, TwitterDevSet) +TwitterDevStratBase2Thresh10DPRpoint1 = nest(Thresh10, DPRpoint1, TestFinalStratBase2, TwitterDevSet) +TwitterDevStratBase2Thresh10DPRpoint5 = nest(Thresh10, DPRpoint5, TestFinalStratBase2, TwitterDevSet) +TwitterDevStratBase2Thresh5DPR1 = nest(Thresh5, DPR1, TestFinalStratBase2, TwitterDevSet) +TwitterDevStratBase2Thresh5DPR5 = nest(Thresh5, DPR5, TestFinalStratBase2, TwitterDevSet) +TwitterDevStratBase2Thresh5DPR10 = nest(Thresh5, DPR10, TestFinalStratBase2, TwitterDevSet) +TwitterDevStratBase2Thresh5DPRpoint1 = nest(Thresh5, DPRpoint1, TestFinalStratBase2, TwitterDevSet) +TwitterDevStratBase2Thresh5DPRpoint5 = nest(Thresh5, DPRpoint5, TestFinalStratBase2, TwitterDevSet) +TwitterDevStratBase2Thresh2DPR1 = nest(Thresh2, DPR1, TestFinalStratBase2, TwitterDevSet) +TwitterDevStratBase2Thresh2DPR5 = nest(Thresh2, DPR5, TestFinalStratBase2, TwitterDevSet) +TwitterDevStratBase2Thresh2DPR10 = nest(Thresh2, DPR10, TestFinalStratBase2, TwitterDevSet) +TwitterDevStratBase2Thresh2DPRpoint1 = nest(Thresh2, DPRpoint1, TestFinalStratBase2, TwitterDevSet) +TwitterDevStratBase2Thresh2DPRpoint5 = nest(Thresh2, DPRpoint5, TestFinalStratBase2, TwitterDevSet) +TwitterDevStratBase2Thresh3DPR1 = nest(Thresh3, DPR1, TestFinalStratBase2, TwitterDevSet) +TwitterDevStratBase2Thresh3DPR5 = nest(Thresh3, DPR5, TestFinalStratBase2, TwitterDevSet) +TwitterDevStratBase2Thresh3DPR10 = nest(Thresh3, DPR10, TestFinalStratBase2, TwitterDevSet) +TwitterDevStratBase2Thresh3DPRpoint1 = nest(Thresh3, DPRpoint1, TestFinalStratBase2, TwitterDevSet) +TwitterDevStratBase2Thresh3DPRpoint5 = nest(Thresh3, DPRpoint5, TestFinalStratBase2, TwitterDevSet) +TwitterDevStratBase2Thresh1DPR1 = nest(Thresh1, DPR1, TestFinalStratBase2, TwitterDevSet) +TwitterDevStratBase2Thresh1DPR5 = nest(Thresh1, DPR5, TestFinalStratBase2, TwitterDevSet) +TwitterDevStratBase2Thresh1DPR10 = nest(Thresh1, DPR10, TestFinalStratBase2, TwitterDevSet) +TwitterDevStratBase2Thresh1DPRpoint1 = nest(Thresh1, DPRpoint1, TestFinalStratBase2, TwitterDevSet) +TwitterDevStratBase2Thresh1DPRpoint5 = nest(Thresh1, DPRpoint5, TestFinalStratBase2, TwitterDevSet) +TwitterDevStratBase2Thresh0DPR1 = nest(Thresh0, DPR1, TestFinalStratBase2, TwitterDevSet) +TwitterDevStratBase2Thresh0DPR5 = nest(Thresh0, DPR5, TestFinalStratBase2, TwitterDevSet) +TwitterDevStratBase2Thresh0DPR10 = nest(Thresh0, DPR10, TestFinalStratBase2, TwitterDevSet) +TwitterDevStratBase2Thresh0DPRpoint1 = nest(Thresh0, DPRpoint1, TestFinalStratBase2, TwitterDevSet) +TwitterDevStratBase2Thresh0DPRpoint5 = nest(Thresh0, DPRpoint5, TestFinalStratBase2, TwitterDevSet) + +TwitterDevStratBase3Thresh20DPR1 = nest(Thresh20, DPR1, TestFinalStratBase3, TwitterDevSet) +TwitterDevStratBase3Thresh20DPR5 = nest(Thresh20, DPR5, TestFinalStratBase3, TwitterDevSet) +TwitterDevStratBase3Thresh20DPR10 = nest(Thresh20, DPR10, TestFinalStratBase3, TwitterDevSet) +TwitterDevStratBase3Thresh20DPRpoint1 = nest(Thresh20, DPRpoint1, TestFinalStratBase3, 
TwitterDevSet) +TwitterDevStratBase3Thresh20DPRpoint5 = nest(Thresh20, DPRpoint5, TestFinalStratBase3, TwitterDevSet) +TwitterDevStratBase3Thresh40DPR1 = nest(Thresh40, DPR1, TestFinalStratBase3, TwitterDevSet) +TwitterDevStratBase3Thresh40DPR5 = nest(Thresh40, DPR5, TestFinalStratBase3, TwitterDevSet) +TwitterDevStratBase3Thresh40DPR10 = nest(Thresh40, DPR10, TestFinalStratBase3, TwitterDevSet) +TwitterDevStratBase3Thresh40DPRpoint1 = nest(Thresh40, DPRpoint1, TestFinalStratBase3, TwitterDevSet) +TwitterDevStratBase3Thresh40DPRpoint5 = nest(Thresh40, DPRpoint5, TestFinalStratBase3, TwitterDevSet) +TwitterDevStratBase3Thresh10DPR1 = nest(Thresh10, DPR1, TestFinalStratBase3, TwitterDevSet) +TwitterDevStratBase3Thresh10DPR5 = nest(Thresh10, DPR5, TestFinalStratBase3, TwitterDevSet) +TwitterDevStratBase3Thresh10DPR10 = nest(Thresh10, DPR10, TestFinalStratBase3, TwitterDevSet) +TwitterDevStratBase3Thresh10DPRpoint1 = nest(Thresh10, DPRpoint1, TestFinalStratBase3, TwitterDevSet) +TwitterDevStratBase3Thresh10DPRpoint5 = nest(Thresh10, DPRpoint5, TestFinalStratBase3, TwitterDevSet) +TwitterDevStratBase3Thresh5DPR1 = nest(Thresh5, DPR1, TestFinalStratBase3, TwitterDevSet) +TwitterDevStratBase3Thresh5DPR5 = nest(Thresh5, DPR5, TestFinalStratBase3, TwitterDevSet) +TwitterDevStratBase3Thresh5DPR10 = nest(Thresh5, DPR10, TestFinalStratBase3, TwitterDevSet) +TwitterDevStratBase3Thresh5DPRpoint1 = nest(Thresh5, DPRpoint1, TestFinalStratBase3, TwitterDevSet) +TwitterDevStratBase3Thresh5DPRpoint5 = nest(Thresh5, DPRpoint5, TestFinalStratBase3, TwitterDevSet) +TwitterDevStratBase3Thresh2DPR1 = nest(Thresh2, DPR1, TestFinalStratBase3, TwitterDevSet) +TwitterDevStratBase3Thresh2DPR5 = nest(Thresh2, DPR5, TestFinalStratBase3, TwitterDevSet) +TwitterDevStratBase3Thresh2DPR10 = nest(Thresh2, DPR10, TestFinalStratBase3, TwitterDevSet) +TwitterDevStratBase3Thresh2DPRpoint1 = nest(Thresh2, DPRpoint1, TestFinalStratBase3, TwitterDevSet) +TwitterDevStratBase3Thresh2DPRpoint5 = nest(Thresh2, DPRpoint5, TestFinalStratBase3, TwitterDevSet) +TwitterDevStratBase3Thresh3DPR1 = nest(Thresh3, DPR1, TestFinalStratBase3, TwitterDevSet) +TwitterDevStratBase3Thresh3DPR5 = nest(Thresh3, DPR5, TestFinalStratBase3, TwitterDevSet) +TwitterDevStratBase3Thresh3DPR10 = nest(Thresh3, DPR10, TestFinalStratBase3, TwitterDevSet) +TwitterDevStratBase3Thresh3DPRpoint1 = nest(Thresh3, DPRpoint1, TestFinalStratBase3, TwitterDevSet) +TwitterDevStratBase3Thresh3DPRpoint5 = nest(Thresh3, DPRpoint5, TestFinalStratBase3, TwitterDevSet) +TwitterDevStratBase3Thresh1DPR1 = nest(Thresh1, DPR1, TestFinalStratBase3, TwitterDevSet) +TwitterDevStratBase3Thresh1DPR5 = nest(Thresh1, DPR5, TestFinalStratBase3, TwitterDevSet) +TwitterDevStratBase3Thresh1DPR10 = nest(Thresh1, DPR10, TestFinalStratBase3, TwitterDevSet) +TwitterDevStratBase3Thresh1DPRpoint1 = nest(Thresh1, DPRpoint1, TestFinalStratBase3, TwitterDevSet) +TwitterDevStratBase3Thresh1DPRpoint5 = nest(Thresh1, DPRpoint5, TestFinalStratBase3, TwitterDevSet) +TwitterDevStratBase3Thresh0DPR1 = nest(Thresh0, DPR1, TestFinalStratBase3, TwitterDevSet) +TwitterDevStratBase3Thresh0DPR5 = nest(Thresh0, DPR5, TestFinalStratBase3, TwitterDevSet) +TwitterDevStratBase3Thresh0DPR10 = nest(Thresh0, DPR10, TestFinalStratBase3, TwitterDevSet) +TwitterDevStratBase3Thresh0DPRpoint1 = nest(Thresh0, DPRpoint1, TestFinalStratBase3, TwitterDevSet) +TwitterDevStratBase3Thresh0DPRpoint5 = nest(Thresh0, DPRpoint5, TestFinalStratBase3, TwitterDevSet) + +TwitterDevStratBase4Thresh20DPR1 = nest(Thresh20, DPR1, 
TestFinalStratBase4, TwitterDevSet) +TwitterDevStratBase4Thresh20DPR5 = nest(Thresh20, DPR5, TestFinalStratBase4, TwitterDevSet) +TwitterDevStratBase4Thresh20DPR10 = nest(Thresh20, DPR10, TestFinalStratBase4, TwitterDevSet) +TwitterDevStratBase4Thresh20DPRpoint1 = nest(Thresh20, DPRpoint1, TestFinalStratBase4, TwitterDevSet) +TwitterDevStratBase4Thresh20DPRpoint5 = nest(Thresh20, DPRpoint5, TestFinalStratBase4, TwitterDevSet) +TwitterDevStratBase4Thresh40DPR1 = nest(Thresh40, DPR1, TestFinalStratBase4, TwitterDevSet) +TwitterDevStratBase4Thresh40DPR5 = nest(Thresh40, DPR5, TestFinalStratBase4, TwitterDevSet) +TwitterDevStratBase4Thresh40DPR10 = nest(Thresh40, DPR10, TestFinalStratBase4, TwitterDevSet) +TwitterDevStratBase4Thresh40DPRpoint1 = nest(Thresh40, DPRpoint1, TestFinalStratBase4, TwitterDevSet) +TwitterDevStratBase4Thresh40DPRpoint5 = nest(Thresh40, DPRpoint5, TestFinalStratBase4, TwitterDevSet) +TwitterDevStratBase4Thresh10DPR1 = nest(Thresh10, DPR1, TestFinalStratBase4, TwitterDevSet) +TwitterDevStratBase4Thresh10DPR5 = nest(Thresh10, DPR5, TestFinalStratBase4, TwitterDevSet) +TwitterDevStratBase4Thresh10DPR10 = nest(Thresh10, DPR10, TestFinalStratBase4, TwitterDevSet) +TwitterDevStratBase4Thresh10DPRpoint1 = nest(Thresh10, DPRpoint1, TestFinalStratBase4, TwitterDevSet) +TwitterDevStratBase4Thresh10DPRpoint5 = nest(Thresh10, DPRpoint5, TestFinalStratBase4, TwitterDevSet) +TwitterDevStratBase4Thresh5DPR1 = nest(Thresh5, DPR1, TestFinalStratBase4, TwitterDevSet) +TwitterDevStratBase4Thresh5DPR5 = nest(Thresh5, DPR5, TestFinalStratBase4, TwitterDevSet) +TwitterDevStratBase4Thresh5DPR10 = nest(Thresh5, DPR10, TestFinalStratBase4, TwitterDevSet) +TwitterDevStratBase4Thresh5DPRpoint1 = nest(Thresh5, DPRpoint1, TestFinalStratBase4, TwitterDevSet) +TwitterDevStratBase4Thresh5DPRpoint5 = nest(Thresh5, DPRpoint5, TestFinalStratBase4, TwitterDevSet) +TwitterDevStratBase4Thresh2DPR1 = nest(Thresh2, DPR1, TestFinalStratBase4, TwitterDevSet) +TwitterDevStratBase4Thresh2DPR5 = nest(Thresh2, DPR5, TestFinalStratBase4, TwitterDevSet) +TwitterDevStratBase4Thresh2DPR10 = nest(Thresh2, DPR10, TestFinalStratBase4, TwitterDevSet) +TwitterDevStratBase4Thresh2DPRpoint1 = nest(Thresh2, DPRpoint1, TestFinalStratBase4, TwitterDevSet) +TwitterDevStratBase4Thresh2DPRpoint5 = nest(Thresh2, DPRpoint5, TestFinalStratBase4, TwitterDevSet) +TwitterDevStratBase4Thresh3DPR1 = nest(Thresh3, DPR1, TestFinalStratBase4, TwitterDevSet) +TwitterDevStratBase4Thresh3DPR5 = nest(Thresh3, DPR5, TestFinalStratBase4, TwitterDevSet) +TwitterDevStratBase4Thresh3DPR10 = nest(Thresh3, DPR10, TestFinalStratBase4, TwitterDevSet) +TwitterDevStratBase4Thresh3DPRpoint1 = nest(Thresh3, DPRpoint1, TestFinalStratBase4, TwitterDevSet) +TwitterDevStratBase4Thresh3DPRpoint5 = nest(Thresh3, DPRpoint5, TestFinalStratBase4, TwitterDevSet) +TwitterDevStratBase4Thresh1DPR1 = nest(Thresh1, DPR1, TestFinalStratBase4, TwitterDevSet) +TwitterDevStratBase4Thresh1DPR5 = nest(Thresh1, DPR5, TestFinalStratBase4, TwitterDevSet) +TwitterDevStratBase4Thresh1DPR10 = nest(Thresh1, DPR10, TestFinalStratBase4, TwitterDevSet) +TwitterDevStratBase4Thresh1DPRpoint1 = nest(Thresh1, DPRpoint1, TestFinalStratBase4, TwitterDevSet) +TwitterDevStratBase4Thresh1DPRpoint5 = nest(Thresh1, DPRpoint5, TestFinalStratBase4, TwitterDevSet) +TwitterDevStratBase4Thresh0DPR1 = nest(Thresh0, DPR1, TestFinalStratBase4, TwitterDevSet) +TwitterDevStratBase4Thresh0DPR5 = nest(Thresh0, DPR5, TestFinalStratBase4, TwitterDevSet) +TwitterDevStratBase4Thresh0DPR10 = nest(Thresh0, DPR10, 
TestFinalStratBase4, TwitterDevSet) +TwitterDevStratBase4Thresh0DPRpoint1 = nest(Thresh0, DPRpoint1, TestFinalStratBase4, TwitterDevSet) +TwitterDevStratBase4Thresh0DPRpoint5 = nest(Thresh0, DPRpoint5, TestFinalStratBase4, TwitterDevSet) + +TwitterTestStrat1Thresh20DPR1 = nest(Thresh20, DPR1, TestStrat1, TwitterTestSet) +TwitterTestStrat1Thresh20DPR5 = nest(Thresh20, DPR5, TestStrat1, TwitterTestSet) +TwitterTestStrat1Thresh20DPR10 = nest(Thresh20, DPR10, TestStrat1, TwitterTestSet) +TwitterTestStrat1Thresh20DPRpoint1 = nest(Thresh20, DPRpoint1, TestStrat1, TwitterTestSet) +TwitterTestStrat1Thresh20DPRpoint5 = nest(Thresh20, DPRpoint5, TestStrat1, TwitterTestSet) +TwitterTestStrat1Thresh40DPR1 = nest(Thresh40, DPR1, TestStrat1, TwitterTestSet) +TwitterTestStrat1Thresh40DPR5 = nest(Thresh40, DPR5, TestStrat1, TwitterTestSet) +TwitterTestStrat1Thresh40DPR10 = nest(Thresh40, DPR10, TestStrat1, TwitterTestSet) +TwitterTestStrat1Thresh40DPRpoint1 = nest(Thresh40, DPRpoint1, TestStrat1, TwitterTestSet) +TwitterTestStrat1Thresh40DPRpoint5 = nest(Thresh40, DPRpoint5, TestStrat1, TwitterTestSet) +TwitterTestStrat1Thresh10DPR1 = nest(Thresh10, DPR1, TestStrat1, TwitterTestSet) +TwitterTestStrat1Thresh10DPR5 = nest(Thresh10, DPR5, TestStrat1, TwitterTestSet) +TwitterTestStrat1Thresh10DPR10 = nest(Thresh10, DPR10, TestStrat1, TwitterTestSet) +TwitterTestStrat1Thresh10DPRpoint1 = nest(Thresh10, DPRpoint1, TestStrat1, TwitterTestSet) +TwitterTestStrat1Thresh10DPRpoint5 = nest(Thresh10, DPRpoint5, TestStrat1, TwitterTestSet) +TwitterTestStrat1Thresh5DPR1 = nest(Thresh5, DPR1, TestStrat1, TwitterTestSet) +TwitterTestStrat1Thresh5DPR5 = nest(Thresh5, DPR5, TestStrat1, TwitterTestSet) +TwitterTestStrat1Thresh5DPR10 = nest(Thresh5, DPR10, TestStrat1, TwitterTestSet) +TwitterTestStrat1Thresh5DPRpoint1 = nest(Thresh5, DPRpoint1, TestStrat1, TwitterTestSet) +TwitterTestStrat1Thresh5DPRpoint5 = nest(Thresh5, DPRpoint5, TestStrat1, TwitterTestSet) +TwitterTestStrat1Thresh2DPR1 = nest(Thresh2, DPR1, TestStrat1, TwitterTestSet) +TwitterTestStrat1Thresh2DPR5 = nest(Thresh2, DPR5, TestStrat1, TwitterTestSet) +TwitterTestStrat1Thresh2DPR10 = nest(Thresh2, DPR10, TestStrat1, TwitterTestSet) +TwitterTestStrat1Thresh2DPRpoint1 = nest(Thresh2, DPRpoint1, TestStrat1, TwitterTestSet) +TwitterTestStrat1Thresh2DPRpoint5 = nest(Thresh2, DPRpoint5, TestStrat1, TwitterTestSet) +TwitterTestStrat1Thresh3DPR1 = nest(Thresh3, DPR1, TestStrat1, TwitterTestSet) +TwitterTestStrat1Thresh3DPR5 = nest(Thresh3, DPR5, TestStrat1, TwitterTestSet) +TwitterTestStrat1Thresh3DPR10 = nest(Thresh3, DPR10, TestStrat1, TwitterTestSet) +TwitterTestStrat1Thresh3DPRpoint1 = nest(Thresh3, DPRpoint1, TestStrat1, TwitterTestSet) +TwitterTestStrat1Thresh3DPRpoint5 = nest(Thresh3, DPRpoint5, TestStrat1, TwitterTestSet) +TwitterTestStrat1Thresh1DPR1 = nest(Thresh1, DPR1, TestStrat1, TwitterTestSet) +TwitterTestStrat1Thresh1DPR5 = nest(Thresh1, DPR5, TestStrat1, TwitterTestSet) +TwitterTestStrat1Thresh1DPR10 = nest(Thresh1, DPR10, TestStrat1, TwitterTestSet) +TwitterTestStrat1Thresh1DPRpoint1 = nest(Thresh1, DPRpoint1, TestStrat1, TwitterTestSet) +TwitterTestStrat1Thresh1DPRpoint5 = nest(Thresh1, DPRpoint5, TestStrat1, TwitterTestSet) +TwitterTestStrat1Thresh0DPR1 = nest(Thresh0, DPR1, TestStrat1, TwitterTestSet) +TwitterTestStrat1Thresh0DPR5 = nest(Thresh0, DPR5, TestStrat1, TwitterTestSet) +TwitterTestStrat1Thresh0DPR10 = nest(Thresh0, DPR10, TestStrat1, TwitterTestSet) +TwitterTestStrat1Thresh0DPRpoint1 = nest(Thresh0, DPRpoint1, TestStrat1, 
TwitterTestSet) +TwitterTestStrat1Thresh0DPRpoint5 = nest(Thresh0, DPRpoint5, TestStrat1, TwitterTestSet) + +TwitterTestStrat2Thresh20DPR1 = nest(Thresh20, DPR1, TestStrat2, TwitterTestSet) +TwitterTestStrat2Thresh20DPR5 = nest(Thresh20, DPR5, TestStrat2, TwitterTestSet) +TwitterTestStrat2Thresh20DPR10 = nest(Thresh20, DPR10, TestStrat2, TwitterTestSet) +TwitterTestStrat2Thresh20DPRpoint1 = nest(Thresh20, DPRpoint1, TestStrat2, TwitterTestSet) +TwitterTestStrat2Thresh20DPRpoint5 = nest(Thresh20, DPRpoint5, TestStrat2, TwitterTestSet) +TwitterTestStrat2Thresh40DPR1 = nest(Thresh40, DPR1, TestStrat2, TwitterTestSet) +TwitterTestStrat2Thresh40DPR5 = nest(Thresh40, DPR5, TestStrat2, TwitterTestSet) +TwitterTestStrat2Thresh40DPR10 = nest(Thresh40, DPR10, TestStrat2, TwitterTestSet) +TwitterTestStrat2Thresh40DPRpoint1 = nest(Thresh40, DPRpoint1, TestStrat2, TwitterTestSet) +TwitterTestStrat2Thresh40DPRpoint5 = nest(Thresh40, DPRpoint5, TestStrat2, TwitterTestSet) +TwitterTestStrat2Thresh10DPR1 = nest(Thresh10, DPR1, TestStrat2, TwitterTestSet) +TwitterTestStrat2Thresh10DPR5 = nest(Thresh10, DPR5, TestStrat2, TwitterTestSet) +TwitterTestStrat2Thresh10DPR10 = nest(Thresh10, DPR10, TestStrat2, TwitterTestSet) +TwitterTestStrat2Thresh10DPRpoint1 = nest(Thresh10, DPRpoint1, TestStrat2, TwitterTestSet) +TwitterTestStrat2Thresh10DPRpoint5 = nest(Thresh10, DPRpoint5, TestStrat2, TwitterTestSet) +TwitterTestStrat2Thresh5DPR1 = nest(Thresh5, DPR1, TestStrat2, TwitterTestSet) +TwitterTestStrat2Thresh5DPR5 = nest(Thresh5, DPR5, TestStrat2, TwitterTestSet) +TwitterTestStrat2Thresh5DPR10 = nest(Thresh5, DPR10, TestStrat2, TwitterTestSet) +TwitterTestStrat2Thresh5DPRpoint1 = nest(Thresh5, DPRpoint1, TestStrat2, TwitterTestSet) +TwitterTestStrat2Thresh5DPRpoint5 = nest(Thresh5, DPRpoint5, TestStrat2, TwitterTestSet) +TwitterTestStrat2Thresh2DPR1 = nest(Thresh2, DPR1, TestStrat2, TwitterTestSet) +TwitterTestStrat2Thresh2DPR5 = nest(Thresh2, DPR5, TestStrat2, TwitterTestSet) +TwitterTestStrat2Thresh2DPR10 = nest(Thresh2, DPR10, TestStrat2, TwitterTestSet) +TwitterTestStrat2Thresh2DPRpoint1 = nest(Thresh2, DPRpoint1, TestStrat2, TwitterTestSet) +TwitterTestStrat2Thresh2DPRpoint5 = nest(Thresh2, DPRpoint5, TestStrat2, TwitterTestSet) +TwitterTestStrat2Thresh3DPR1 = nest(Thresh3, DPR1, TestStrat2, TwitterTestSet) +TwitterTestStrat2Thresh3DPR5 = nest(Thresh3, DPR5, TestStrat2, TwitterTestSet) +TwitterTestStrat2Thresh3DPR10 = nest(Thresh3, DPR10, TestStrat2, TwitterTestSet) +TwitterTestStrat2Thresh3DPRpoint1 = nest(Thresh3, DPRpoint1, TestStrat2, TwitterTestSet) +TwitterTestStrat2Thresh3DPRpoint5 = nest(Thresh3, DPRpoint5, TestStrat2, TwitterTestSet) +TwitterTestStrat2Thresh1DPR1 = nest(Thresh1, DPR1, TestStrat2, TwitterTestSet) +TwitterTestStrat2Thresh1DPR5 = nest(Thresh1, DPR5, TestStrat2, TwitterTestSet) +TwitterTestStrat2Thresh1DPR10 = nest(Thresh1, DPR10, TestStrat2, TwitterTestSet) +TwitterTestStrat2Thresh1DPRpoint1 = nest(Thresh1, DPRpoint1, TestStrat2, TwitterTestSet) +TwitterTestStrat2Thresh1DPRpoint5 = nest(Thresh1, DPRpoint5, TestStrat2, TwitterTestSet) +TwitterTestStrat2Thresh0DPR1 = nest(Thresh0, DPR1, TestStrat2, TwitterTestSet) +TwitterTestStrat2Thresh0DPR5 = nest(Thresh0, DPR5, TestStrat2, TwitterTestSet) +TwitterTestStrat2Thresh0DPR10 = nest(Thresh0, DPR10, TestStrat2, TwitterTestSet) +TwitterTestStrat2Thresh0DPRpoint1 = nest(Thresh0, DPRpoint1, TestStrat2, TwitterTestSet) +TwitterTestStrat2Thresh0DPRpoint5 = nest(Thresh0, DPRpoint5, TestStrat2, TwitterTestSet) + +TwitterTestStrat3Thresh20DPR1 = 
nest(Thresh20, DPR1, TestStrat3, TwitterTestSet) +TwitterTestStrat3Thresh20DPR5 = nest(Thresh20, DPR5, TestStrat3, TwitterTestSet) +TwitterTestStrat3Thresh20DPR10 = nest(Thresh20, DPR10, TestStrat3, TwitterTestSet) +TwitterTestStrat3Thresh20DPRpoint1 = nest(Thresh20, DPRpoint1, TestStrat3, TwitterTestSet) +TwitterTestStrat3Thresh20DPRpoint5 = nest(Thresh20, DPRpoint5, TestStrat3, TwitterTestSet) +TwitterTestStrat3Thresh40DPR1 = nest(Thresh40, DPR1, TestStrat3, TwitterTestSet) +TwitterTestStrat3Thresh40DPR5 = nest(Thresh40, DPR5, TestStrat3, TwitterTestSet) +TwitterTestStrat3Thresh40DPR10 = nest(Thresh40, DPR10, TestStrat3, TwitterTestSet) +TwitterTestStrat3Thresh40DPRpoint1 = nest(Thresh40, DPRpoint1, TestStrat3, TwitterTestSet) +TwitterTestStrat3Thresh40DPRpoint5 = nest(Thresh40, DPRpoint5, TestStrat3, TwitterTestSet) +TwitterTestStrat3Thresh10DPR1 = nest(Thresh10, DPR1, TestStrat3, TwitterTestSet) +TwitterTestStrat3Thresh10DPR5 = nest(Thresh10, DPR5, TestStrat3, TwitterTestSet) +TwitterTestStrat3Thresh10DPR10 = nest(Thresh10, DPR10, TestStrat3, TwitterTestSet) +TwitterTestStrat3Thresh10DPRpoint1 = nest(Thresh10, DPRpoint1, TestStrat3, TwitterTestSet) +TwitterTestStrat3Thresh10DPRpoint5 = nest(Thresh10, DPRpoint5, TestStrat3, TwitterTestSet) +TwitterTestStrat3Thresh5DPR1 = nest(Thresh5, DPR1, TestStrat3, TwitterTestSet) +TwitterTestStrat3Thresh5DPR5 = nest(Thresh5, DPR5, TestStrat3, TwitterTestSet) +TwitterTestStrat3Thresh5DPR10 = nest(Thresh5, DPR10, TestStrat3, TwitterTestSet) +TwitterTestStrat3Thresh5DPRpoint1 = nest(Thresh5, DPRpoint1, TestStrat3, TwitterTestSet) +TwitterTestStrat3Thresh5DPRpoint5 = nest(Thresh5, DPRpoint5, TestStrat3, TwitterTestSet) +TwitterTestStrat3Thresh2DPR1 = nest(Thresh2, DPR1, TestStrat3, TwitterTestSet) +TwitterTestStrat3Thresh2DPR5 = nest(Thresh2, DPR5, TestStrat3, TwitterTestSet) +TwitterTestStrat3Thresh2DPR10 = nest(Thresh2, DPR10, TestStrat3, TwitterTestSet) +TwitterTestStrat3Thresh2DPRpoint1 = nest(Thresh2, DPRpoint1, TestStrat3, TwitterTestSet) +TwitterTestStrat3Thresh2DPRpoint5 = nest(Thresh2, DPRpoint5, TestStrat3, TwitterTestSet) +TwitterTestStrat3Thresh3DPR1 = nest(Thresh3, DPR1, TestStrat3, TwitterTestSet) +TwitterTestStrat3Thresh3DPR5 = nest(Thresh3, DPR5, TestStrat3, TwitterTestSet) +TwitterTestStrat3Thresh3DPR10 = nest(Thresh3, DPR10, TestStrat3, TwitterTestSet) +TwitterTestStrat3Thresh3DPRpoint1 = nest(Thresh3, DPRpoint1, TestStrat3, TwitterTestSet) +TwitterTestStrat3Thresh3DPRpoint5 = nest(Thresh3, DPRpoint5, TestStrat3, TwitterTestSet) +TwitterTestStrat3Thresh1DPR1 = nest(Thresh1, DPR1, TestStrat3, TwitterTestSet) +TwitterTestStrat3Thresh1DPR5 = nest(Thresh1, DPR5, TestStrat3, TwitterTestSet) +TwitterTestStrat3Thresh1DPR10 = nest(Thresh1, DPR10, TestStrat3, TwitterTestSet) +TwitterTestStrat3Thresh1DPRpoint1 = nest(Thresh1, DPRpoint1, TestStrat3, TwitterTestSet) +TwitterTestStrat3Thresh1DPRpoint5 = nest(Thresh1, DPRpoint5, TestStrat3, TwitterTestSet) +TwitterTestStrat3Thresh0DPR1 = nest(Thresh0, DPR1, TestStrat3, TwitterTestSet) +TwitterTestStrat3Thresh0DPR5 = nest(Thresh0, DPR5, TestStrat3, TwitterTestSet) +TwitterTestStrat3Thresh0DPR10 = nest(Thresh0, DPR10, TestStrat3, TwitterTestSet) +TwitterTestStrat3Thresh0DPRpoint1 = nest(Thresh0, DPRpoint1, TestStrat3, TwitterTestSet) +TwitterTestStrat3Thresh0DPRpoint5 = nest(Thresh0, DPRpoint5, TestStrat3, TwitterTestSet) + +TwitterTestStratBase1Thresh20DPR1 = nest(Thresh20, DPR1, TestFinalStratBase1, TwitterTestSet) +TwitterTestStratBase1Thresh20DPR5 = nest(Thresh20, DPR5, TestFinalStratBase1, 
TwitterTestSet) +TwitterTestStratBase1Thresh20DPR10 = nest(Thresh20, DPR10, TestFinalStratBase1, TwitterTestSet) +TwitterTestStratBase1Thresh20DPRpoint1 = nest(Thresh20, DPRpoint1, TestFinalStratBase1, TwitterTestSet) +TwitterTestStratBase1Thresh20DPRpoint5 = nest(Thresh20, DPRpoint5, TestFinalStratBase1, TwitterTestSet) +TwitterTestStratBase1Thresh40DPR1 = nest(Thresh40, DPR1, TestFinalStratBase1, TwitterTestSet) +TwitterTestStratBase1Thresh40DPR5 = nest(Thresh40, DPR5, TestFinalStratBase1, TwitterTestSet) +TwitterTestStratBase1Thresh40DPR10 = nest(Thresh40, DPR10, TestFinalStratBase1, TwitterTestSet) +TwitterTestStratBase1Thresh40DPRpoint1 = nest(Thresh40, DPRpoint1, TestFinalStratBase1, TwitterTestSet) +TwitterTestStratBase1Thresh40DPRpoint5 = nest(Thresh40, DPRpoint5, TestFinalStratBase1, TwitterTestSet) +TwitterTestStratBase1Thresh10DPR1 = nest(Thresh10, DPR1, TestFinalStratBase1, TwitterTestSet) +TwitterTestStratBase1Thresh10DPR5 = nest(Thresh10, DPR5, TestFinalStratBase1, TwitterTestSet) +TwitterTestStratBase1Thresh10DPR10 = nest(Thresh10, DPR10, TestFinalStratBase1, TwitterTestSet) +TwitterTestStratBase1Thresh10DPRpoint1 = nest(Thresh10, DPRpoint1, TestFinalStratBase1, TwitterTestSet) +TwitterTestStratBase1Thresh10DPRpoint5 = nest(Thresh10, DPRpoint5, TestFinalStratBase1, TwitterTestSet) +TwitterTestStratBase1Thresh5DPR1 = nest(Thresh5, DPR1, TestFinalStratBase1, TwitterTestSet) +TwitterTestStratBase1Thresh5DPR5 = nest(Thresh5, DPR5, TestFinalStratBase1, TwitterTestSet) +TwitterTestStratBase1Thresh5DPR10 = nest(Thresh5, DPR10, TestFinalStratBase1, TwitterTestSet) +TwitterTestStratBase1Thresh5DPRpoint1 = nest(Thresh5, DPRpoint1, TestFinalStratBase1, TwitterTestSet) +TwitterTestStratBase1Thresh5DPRpoint5 = nest(Thresh5, DPRpoint5, TestFinalStratBase1, TwitterTestSet) +TwitterTestStratBase1Thresh2DPR1 = nest(Thresh2, DPR1, TestFinalStratBase1, TwitterTestSet) +TwitterTestStratBase1Thresh2DPR5 = nest(Thresh2, DPR5, TestFinalStratBase1, TwitterTestSet) +TwitterTestStratBase1Thresh2DPR10 = nest(Thresh2, DPR10, TestFinalStratBase1, TwitterTestSet) +TwitterTestStratBase1Thresh2DPRpoint1 = nest(Thresh2, DPRpoint1, TestFinalStratBase1, TwitterTestSet) +TwitterTestStratBase1Thresh2DPRpoint5 = nest(Thresh2, DPRpoint5, TestFinalStratBase1, TwitterTestSet) +TwitterTestStratBase1Thresh3DPR1 = nest(Thresh3, DPR1, TestFinalStratBase1, TwitterTestSet) +TwitterTestStratBase1Thresh3DPR5 = nest(Thresh3, DPR5, TestFinalStratBase1, TwitterTestSet) +TwitterTestStratBase1Thresh3DPR10 = nest(Thresh3, DPR10, TestFinalStratBase1, TwitterTestSet) +TwitterTestStratBase1Thresh3DPRpoint1 = nest(Thresh3, DPRpoint1, TestFinalStratBase1, TwitterTestSet) +TwitterTestStratBase1Thresh3DPRpoint5 = nest(Thresh3, DPRpoint5, TestFinalStratBase1, TwitterTestSet) +TwitterTestStratBase1Thresh1DPR1 = nest(Thresh1, DPR1, TestFinalStratBase1, TwitterTestSet) +TwitterTestStratBase1Thresh1DPR5 = nest(Thresh1, DPR5, TestFinalStratBase1, TwitterTestSet) +TwitterTestStratBase1Thresh1DPR10 = nest(Thresh1, DPR10, TestFinalStratBase1, TwitterTestSet) +TwitterTestStratBase1Thresh1DPRpoint1 = nest(Thresh1, DPRpoint1, TestFinalStratBase1, TwitterTestSet) +TwitterTestStratBase1Thresh1DPRpoint5 = nest(Thresh1, DPRpoint5, TestFinalStratBase1, TwitterTestSet) +TwitterTestStratBase1Thresh0DPR1 = nest(Thresh0, DPR1, TestFinalStratBase1, TwitterTestSet) +TwitterTestStratBase1Thresh0DPR5 = nest(Thresh0, DPR5, TestFinalStratBase1, TwitterTestSet) +TwitterTestStratBase1Thresh0DPR10 = nest(Thresh0, DPR10, TestFinalStratBase1, TwitterTestSet) 
+TwitterTestStratBase1Thresh0DPRpoint1 = nest(Thresh0, DPRpoint1, TestFinalStratBase1, TwitterTestSet) +TwitterTestStratBase1Thresh0DPRpoint5 = nest(Thresh0, DPRpoint5, TestFinalStratBase1, TwitterTestSet) + +TwitterTestStratBase2Thresh20DPR1 = nest(Thresh20, DPR1, TestFinalStratBase2, TwitterTestSet) +TwitterTestStratBase2Thresh20DPR5 = nest(Thresh20, DPR5, TestFinalStratBase2, TwitterTestSet) +TwitterTestStratBase2Thresh20DPR10 = nest(Thresh20, DPR10, TestFinalStratBase2, TwitterTestSet) +TwitterTestStratBase2Thresh20DPRpoint1 = nest(Thresh20, DPRpoint1, TestFinalStratBase2, TwitterTestSet) +TwitterTestStratBase2Thresh20DPRpoint5 = nest(Thresh20, DPRpoint5, TestFinalStratBase2, TwitterTestSet) +TwitterTestStratBase2Thresh40DPR1 = nest(Thresh40, DPR1, TestFinalStratBase2, TwitterTestSet) +TwitterTestStratBase2Thresh40DPR5 = nest(Thresh40, DPR5, TestFinalStratBase2, TwitterTestSet) +TwitterTestStratBase2Thresh40DPR10 = nest(Thresh40, DPR10, TestFinalStratBase2, TwitterTestSet) +TwitterTestStratBase2Thresh40DPRpoint1 = nest(Thresh40, DPRpoint1, TestFinalStratBase2, TwitterTestSet) +TwitterTestStratBase2Thresh40DPRpoint5 = nest(Thresh40, DPRpoint5, TestFinalStratBase2, TwitterTestSet) +TwitterTestStratBase2Thresh10DPR1 = nest(Thresh10, DPR1, TestFinalStratBase2, TwitterTestSet) +TwitterTestStratBase2Thresh10DPR5 = nest(Thresh10, DPR5, TestFinalStratBase2, TwitterTestSet) +TwitterTestStratBase2Thresh10DPR10 = nest(Thresh10, DPR10, TestFinalStratBase2, TwitterTestSet) +TwitterTestStratBase2Thresh10DPRpoint1 = nest(Thresh10, DPRpoint1, TestFinalStratBase2, TwitterTestSet) +TwitterTestStratBase2Thresh10DPRpoint5 = nest(Thresh10, DPRpoint5, TestFinalStratBase2, TwitterTestSet) +TwitterTestStratBase2Thresh5DPR1 = nest(Thresh5, DPR1, TestFinalStratBase2, TwitterTestSet) +TwitterTestStratBase2Thresh5DPR5 = nest(Thresh5, DPR5, TestFinalStratBase2, TwitterTestSet) +TwitterTestStratBase2Thresh5DPR10 = nest(Thresh5, DPR10, TestFinalStratBase2, TwitterTestSet) +TwitterTestStratBase2Thresh5DPRpoint1 = nest(Thresh5, DPRpoint1, TestFinalStratBase2, TwitterTestSet) +TwitterTestStratBase2Thresh5DPRpoint5 = nest(Thresh5, DPRpoint5, TestFinalStratBase2, TwitterTestSet) +TwitterTestStratBase2Thresh2DPR1 = nest(Thresh2, DPR1, TestFinalStratBase2, TwitterTestSet) +TwitterTestStratBase2Thresh2DPR5 = nest(Thresh2, DPR5, TestFinalStratBase2, TwitterTestSet) +TwitterTestStratBase2Thresh2DPR10 = nest(Thresh2, DPR10, TestFinalStratBase2, TwitterTestSet) +TwitterTestStratBase2Thresh2DPRpoint1 = nest(Thresh2, DPRpoint1, TestFinalStratBase2, TwitterTestSet) +TwitterTestStratBase2Thresh2DPRpoint5 = nest(Thresh2, DPRpoint5, TestFinalStratBase2, TwitterTestSet) +TwitterTestStratBase2Thresh3DPR1 = nest(Thresh3, DPR1, TestFinalStratBase2, TwitterTestSet) +TwitterTestStratBase2Thresh3DPR5 = nest(Thresh3, DPR5, TestFinalStratBase2, TwitterTestSet) +TwitterTestStratBase2Thresh3DPR10 = nest(Thresh3, DPR10, TestFinalStratBase2, TwitterTestSet) +TwitterTestStratBase2Thresh3DPRpoint1 = nest(Thresh3, DPRpoint1, TestFinalStratBase2, TwitterTestSet) +TwitterTestStratBase2Thresh3DPRpoint5 = nest(Thresh3, DPRpoint5, TestFinalStratBase2, TwitterTestSet) +TwitterTestStratBase2Thresh1DPR1 = nest(Thresh1, DPR1, TestFinalStratBase2, TwitterTestSet) +TwitterTestStratBase2Thresh1DPR5 = nest(Thresh1, DPR5, TestFinalStratBase2, TwitterTestSet) +TwitterTestStratBase2Thresh1DPR10 = nest(Thresh1, DPR10, TestFinalStratBase2, TwitterTestSet) +TwitterTestStratBase2Thresh1DPRpoint1 = nest(Thresh1, DPRpoint1, TestFinalStratBase2, TwitterTestSet) 
+TwitterTestStratBase2Thresh1DPRpoint5 = nest(Thresh1, DPRpoint5, TestFinalStratBase2, TwitterTestSet) +TwitterTestStratBase2Thresh0DPR1 = nest(Thresh0, DPR1, TestFinalStratBase2, TwitterTestSet) +TwitterTestStratBase2Thresh0DPR5 = nest(Thresh0, DPR5, TestFinalStratBase2, TwitterTestSet) +TwitterTestStratBase2Thresh0DPR10 = nest(Thresh0, DPR10, TestFinalStratBase2, TwitterTestSet) +TwitterTestStratBase2Thresh0DPRpoint1 = nest(Thresh0, DPRpoint1, TestFinalStratBase2, TwitterTestSet) +TwitterTestStratBase2Thresh0DPRpoint5 = nest(Thresh0, DPRpoint5, TestFinalStratBase2, TwitterTestSet) + +TwitterTestStratBase3Thresh20DPR1 = nest(Thresh20, DPR1, TestFinalStratBase3, TwitterTestSet) +TwitterTestStratBase3Thresh20DPR5 = nest(Thresh20, DPR5, TestFinalStratBase3, TwitterTestSet) +TwitterTestStratBase3Thresh20DPR10 = nest(Thresh20, DPR10, TestFinalStratBase3, TwitterTestSet) +TwitterTestStratBase3Thresh20DPRpoint1 = nest(Thresh20, DPRpoint1, TestFinalStratBase3, TwitterTestSet) +TwitterTestStratBase3Thresh20DPRpoint5 = nest(Thresh20, DPRpoint5, TestFinalStratBase3, TwitterTestSet) +TwitterTestStratBase3Thresh40DPR1 = nest(Thresh40, DPR1, TestFinalStratBase3, TwitterTestSet) +TwitterTestStratBase3Thresh40DPR5 = nest(Thresh40, DPR5, TestFinalStratBase3, TwitterTestSet) +TwitterTestStratBase3Thresh40DPR10 = nest(Thresh40, DPR10, TestFinalStratBase3, TwitterTestSet) +TwitterTestStratBase3Thresh40DPRpoint1 = nest(Thresh40, DPRpoint1, TestFinalStratBase3, TwitterTestSet) +TwitterTestStratBase3Thresh40DPRpoint5 = nest(Thresh40, DPRpoint5, TestFinalStratBase3, TwitterTestSet) +TwitterTestStratBase3Thresh10DPR1 = nest(Thresh10, DPR1, TestFinalStratBase3, TwitterTestSet) +TwitterTestStratBase3Thresh10DPR5 = nest(Thresh10, DPR5, TestFinalStratBase3, TwitterTestSet) +TwitterTestStratBase3Thresh10DPR10 = nest(Thresh10, DPR10, TestFinalStratBase3, TwitterTestSet) +TwitterTestStratBase3Thresh10DPRpoint1 = nest(Thresh10, DPRpoint1, TestFinalStratBase3, TwitterTestSet) +TwitterTestStratBase3Thresh10DPRpoint5 = nest(Thresh10, DPRpoint5, TestFinalStratBase3, TwitterTestSet) +TwitterTestStratBase3Thresh5DPR1 = nest(Thresh5, DPR1, TestFinalStratBase3, TwitterTestSet) +TwitterTestStratBase3Thresh5DPR5 = nest(Thresh5, DPR5, TestFinalStratBase3, TwitterTestSet) +TwitterTestStratBase3Thresh5DPR10 = nest(Thresh5, DPR10, TestFinalStratBase3, TwitterTestSet) +TwitterTestStratBase3Thresh5DPRpoint1 = nest(Thresh5, DPRpoint1, TestFinalStratBase3, TwitterTestSet) +TwitterTestStratBase3Thresh5DPRpoint5 = nest(Thresh5, DPRpoint5, TestFinalStratBase3, TwitterTestSet) +TwitterTestStratBase3Thresh2DPR1 = nest(Thresh2, DPR1, TestFinalStratBase3, TwitterTestSet) +TwitterTestStratBase3Thresh2DPR5 = nest(Thresh2, DPR5, TestFinalStratBase3, TwitterTestSet) +TwitterTestStratBase3Thresh2DPR10 = nest(Thresh2, DPR10, TestFinalStratBase3, TwitterTestSet) +TwitterTestStratBase3Thresh2DPRpoint1 = nest(Thresh2, DPRpoint1, TestFinalStratBase3, TwitterTestSet) +TwitterTestStratBase3Thresh2DPRpoint5 = nest(Thresh2, DPRpoint5, TestFinalStratBase3, TwitterTestSet) +TwitterTestStratBase3Thresh3DPR1 = nest(Thresh3, DPR1, TestFinalStratBase3, TwitterTestSet) +TwitterTestStratBase3Thresh3DPR5 = nest(Thresh3, DPR5, TestFinalStratBase3, TwitterTestSet) +TwitterTestStratBase3Thresh3DPR10 = nest(Thresh3, DPR10, TestFinalStratBase3, TwitterTestSet) +TwitterTestStratBase3Thresh3DPRpoint1 = nest(Thresh3, DPRpoint1, TestFinalStratBase3, TwitterTestSet) +TwitterTestStratBase3Thresh3DPRpoint5 = nest(Thresh3, DPRpoint5, TestFinalStratBase3, TwitterTestSet) 
+TwitterTestStratBase3Thresh1DPR1 = nest(Thresh1, DPR1, TestFinalStratBase3, TwitterTestSet) +TwitterTestStratBase3Thresh1DPR5 = nest(Thresh1, DPR5, TestFinalStratBase3, TwitterTestSet) +TwitterTestStratBase3Thresh1DPR10 = nest(Thresh1, DPR10, TestFinalStratBase3, TwitterTestSet) +TwitterTestStratBase3Thresh1DPRpoint1 = nest(Thresh1, DPRpoint1, TestFinalStratBase3, TwitterTestSet) +TwitterTestStratBase3Thresh1DPRpoint5 = nest(Thresh1, DPRpoint5, TestFinalStratBase3, TwitterTestSet) +TwitterTestStratBase3Thresh0DPR1 = nest(Thresh0, DPR1, TestFinalStratBase3, TwitterTestSet) +TwitterTestStratBase3Thresh0DPR5 = nest(Thresh0, DPR5, TestFinalStratBase3, TwitterTestSet) +TwitterTestStratBase3Thresh0DPR10 = nest(Thresh0, DPR10, TestFinalStratBase3, TwitterTestSet) +TwitterTestStratBase3Thresh0DPRpoint1 = nest(Thresh0, DPRpoint1, TestFinalStratBase3, TwitterTestSet) +TwitterTestStratBase3Thresh0DPRpoint5 = nest(Thresh0, DPRpoint5, TestFinalStratBase3, TwitterTestSet) + +TwitterTestStratBase4Thresh20DPR1 = nest(Thresh20, DPR1, TestFinalStratBase4, TwitterTestSet) +TwitterTestStratBase4Thresh20DPR5 = nest(Thresh20, DPR5, TestFinalStratBase4, TwitterTestSet) +TwitterTestStratBase4Thresh20DPR10 = nest(Thresh20, DPR10, TestFinalStratBase4, TwitterTestSet) +TwitterTestStratBase4Thresh20DPRpoint1 = nest(Thresh20, DPRpoint1, TestFinalStratBase4, TwitterTestSet) +TwitterTestStratBase4Thresh20DPRpoint5 = nest(Thresh20, DPRpoint5, TestFinalStratBase4, TwitterTestSet) +TwitterTestStratBase4Thresh40DPR1 = nest(Thresh40, DPR1, TestFinalStratBase4, TwitterTestSet) +TwitterTestStratBase4Thresh40DPR5 = nest(Thresh40, DPR5, TestFinalStratBase4, TwitterTestSet) +TwitterTestStratBase4Thresh40DPR10 = nest(Thresh40, DPR10, TestFinalStratBase4, TwitterTestSet) +TwitterTestStratBase4Thresh40DPRpoint1 = nest(Thresh40, DPRpoint1, TestFinalStratBase4, TwitterTestSet) +TwitterTestStratBase4Thresh40DPRpoint5 = nest(Thresh40, DPRpoint5, TestFinalStratBase4, TwitterTestSet) +TwitterTestStratBase4Thresh10DPR1 = nest(Thresh10, DPR1, TestFinalStratBase4, TwitterTestSet) +TwitterTestStratBase4Thresh10DPR5 = nest(Thresh10, DPR5, TestFinalStratBase4, TwitterTestSet) +TwitterTestStratBase4Thresh10DPR10 = nest(Thresh10, DPR10, TestFinalStratBase4, TwitterTestSet) +TwitterTestStratBase4Thresh10DPRpoint1 = nest(Thresh10, DPRpoint1, TestFinalStratBase4, TwitterTestSet) +TwitterTestStratBase4Thresh10DPRpoint5 = nest(Thresh10, DPRpoint5, TestFinalStratBase4, TwitterTestSet) +TwitterTestStratBase4Thresh5DPR1 = nest(Thresh5, DPR1, TestFinalStratBase4, TwitterTestSet) +TwitterTestStratBase4Thresh5DPR5 = nest(Thresh5, DPR5, TestFinalStratBase4, TwitterTestSet) +TwitterTestStratBase4Thresh5DPR10 = nest(Thresh5, DPR10, TestFinalStratBase4, TwitterTestSet) +TwitterTestStratBase4Thresh5DPRpoint1 = nest(Thresh5, DPRpoint1, TestFinalStratBase4, TwitterTestSet) +TwitterTestStratBase4Thresh5DPRpoint5 = nest(Thresh5, DPRpoint5, TestFinalStratBase4, TwitterTestSet) +TwitterTestStratBase4Thresh2DPR1 = nest(Thresh2, DPR1, TestFinalStratBase4, TwitterTestSet) +TwitterTestStratBase4Thresh2DPR5 = nest(Thresh2, DPR5, TestFinalStratBase4, TwitterTestSet) +TwitterTestStratBase4Thresh2DPR10 = nest(Thresh2, DPR10, TestFinalStratBase4, TwitterTestSet) +TwitterTestStratBase4Thresh2DPRpoint1 = nest(Thresh2, DPRpoint1, TestFinalStratBase4, TwitterTestSet) +TwitterTestStratBase4Thresh2DPRpoint5 = nest(Thresh2, DPRpoint5, TestFinalStratBase4, TwitterTestSet) +TwitterTestStratBase4Thresh3DPR1 = nest(Thresh3, DPR1, TestFinalStratBase4, TwitterTestSet) 
+TwitterTestStratBase4Thresh3DPR5 = nest(Thresh3, DPR5, TestFinalStratBase4, TwitterTestSet) +TwitterTestStratBase4Thresh3DPR10 = nest(Thresh3, DPR10, TestFinalStratBase4, TwitterTestSet) +TwitterTestStratBase4Thresh3DPRpoint1 = nest(Thresh3, DPRpoint1, TestFinalStratBase4, TwitterTestSet) +TwitterTestStratBase4Thresh3DPRpoint5 = nest(Thresh3, DPRpoint5, TestFinalStratBase4, TwitterTestSet) +TwitterTestStratBase4Thresh1DPR1 = nest(Thresh1, DPR1, TestFinalStratBase4, TwitterTestSet) +TwitterTestStratBase4Thresh1DPR5 = nest(Thresh1, DPR5, TestFinalStratBase4, TwitterTestSet) +TwitterTestStratBase4Thresh1DPR10 = nest(Thresh1, DPR10, TestFinalStratBase4, TwitterTestSet) +TwitterTestStratBase4Thresh1DPRpoint1 = nest(Thresh1, DPRpoint1, TestFinalStratBase4, TwitterTestSet) +TwitterTestStratBase4Thresh1DPRpoint5 = nest(Thresh1, DPRpoint5, TestFinalStratBase4, TwitterTestSet) +TwitterTestStratBase4Thresh0DPR1 = nest(Thresh0, DPR1, TestFinalStratBase4, TwitterTestSet) +TwitterTestStratBase4Thresh0DPR5 = nest(Thresh0, DPR5, TestFinalStratBase4, TwitterTestSet) +TwitterTestStratBase4Thresh0DPR10 = nest(Thresh0, DPR10, TestFinalStratBase4, TwitterTestSet) +TwitterTestStratBase4Thresh0DPRpoint1 = nest(Thresh0, DPRpoint1, TestFinalStratBase4, TwitterTestSet) +TwitterTestStratBase4Thresh0DPRpoint5 = nest(Thresh0, DPRpoint5, TestFinalStratBase4, TwitterTestSet) + +TwitterDevStrategy1 = add_param('--strategy partial-kl-divergence --strategy naive-bayes-with-baseline --strategy average-cell-probability') +#TwitterDevStrategy2 = add_param('--strategy baseline --baseline-strategy link-most-common-toponym --baseline-strategy regdist-most-common-toponym') +TwitterDevStrategy2 = add_param('--strategy baseline --baseline-strategy link-most-common-toponym') +TwitterDevStrategy3 = add_param('--strategy baseline --baseline-strategy num-articles --baseline-strategy random') +TwitterDev1 = nest(TwitterDPR3, TwitterDevSet, TwitterDevStrategy1) +TwitterDev2 = nest(TwitterDPR3, TwitterDevSet, TwitterDevStrategy2) +TwitterDev3 = nest(TwitterDPR3, TwitterDevSet, TwitterDevStrategy3) +TwitterDPR4 = iterate('--degrees-per-region', [3, 4, 6, 7]) +TwitterDev4 = nest(TwitterDPR4, TwitterDevSet, TwitterDevStrategy1) +TwitterDev5 = nest(TwitterDPR4, TwitterDevSet, TwitterDevStrategy2) +TwitterDev6 = nest(TwitterDPR4, TwitterDevSet, TwitterDevStrategy3) + + +WithStopwords = add_param('--include-stopwords-in-article-dists') +TwitterExper4 = nest(WithStopwords, TwitterDPR3, KLDivStrategy) +TwitterExper5 = nest(WithStopwords, TwitterDPR3, TwitterStrategy2) +TwitterExper6 = nest(WithStopwords, TwitterDPR3, BaselineStrategies) + +TwitterWikiNumTest = iterate('--num-test-docs', [1894]) +TwitterWikiDPR1 = iterate('--degrees-per-region', [0.1]) +TwitterWikiStrategyAll = add_param('--strategy partial-kl-divergence --strategy naive-bayes-with-baseline --strategy smoothed-cosine-similarity --strategy average-cell-probability') +TwitterWikiDPR2 = iterate('--degrees-per-region', [0.5]) +TwitterWikiDPR3 = iterate('--degrees-per-region', [5, 10, 1]) +TwitterWikiStrategy3 = add_param('--strategy partial-kl-divergence') +TwitterWikiDPR4 = iterate('--degrees-per-region', [5, 10, 1]) +TwitterWikiStrategy4 = add_param('--strategy naive-bayes-with-baseline --strategy smoothed-cosine-similarity --strategy average-cell-probability') +TwitterWikiExper1 = nest(Train100k, TwitterWikiNumTest, TwitterWikiDPR1, TwitterWikiStrategyAll) +TwitterWikiExper2 = nest(Train100k, TwitterWikiNumTest, TwitterWikiDPR2, TwitterWikiStrategyAll) +TwitterWikiExper3 = 
nest(Train100k, TwitterWikiNumTest, TwitterWikiDPR3, TwitterWikiStrategy3) +TwitterWikiExper4 = nest(Train100k, TwitterWikiNumTest, TwitterWikiDPR4, TwitterWikiStrategy4) + +# Error Analysis + +TwitterOracleDPRpoint1 = nest(Thresh5, TestSet, DPRpoint1, TestStratBase1, OracleOnly) +TwitterOracleDPRpoint5 = nest(Thresh5, TestSet, DPRpoint5, TestStratBase1, OracleOnly) +TwitterOracleDPR1 = nest(Thresh5, TestSet, DPR1, TestStratBase1, OracleOnly) +TwitterOracleDPR5 = nest(Thresh5, TestSet, DPR5, TestStratBase1, OracleOnly) +TwitterOracleThresh40DPR5 = nest(Thresh40, TestSet, DPR5, TestStratBase1, OracleOnly) +TwitterOracleThresh40DPR10 = nest(Thresh40, TestSet, DPR10, TestStratBase1, OracleOnly) +TwitterOracleDPR5Test = nest(Thresh5, TestSet, DPR5, TestStratBase1, OracleOnly, MTS10) + + +DebugKLDiv = iterate('--debug', ['kldiv']) + + +DebugTwitterDevStrat1Thresh5DPR5 = nest(Thresh5, DPR5, TestStrat1, TwitterDevSet, DebugKLDiv) + +# Test + +main() diff --git a/src/main/python/split_bzip.py b/src/main/python/split_bzip.py new file mode 100755 index 0000000..3d20a77 --- /dev/null +++ b/src/main/python/split_bzip.py @@ -0,0 +1,236 @@ +#!/usr/bin/env python + +####### +####### split_bzip.py +####### +####### Copyright (c) 2011 Ben Wing. +####### + +import sys, re +import math +import fileinput +from subprocess import * +from nlputil import * +import itertools +import time +import os.path +import traceback + +############################################################################ +# Quick Start # +############################################################################ + +# This program reads in data from the specified bzipped files, concatenates +# them, splits them at newlines after a certain amount of data has been +# read, and bzips the results. The files are assumed to contains tweets in +# JSON format, and the resulting split files after named after the date +# in the first tweet of the split. We take some care to ensure that we +# start the file with a valid tweet, in case something invalid is in the +# file. + +############################################################################ +# Notes # +############################################################################ + +# When this script was run, an example run was +# +# run-nohup ~tgp/split_bzip.py -s 1450000000 -o split.global. ../global.tweets.2011-*.bz2 & +# +# The value given to -s is the uncompressed size of each split and has been +# empirically determined to give compressed sizes slightly under 192 MB -- +# useful for Hadoop as it means that each split will take slightly under 3 +# HDFS blocks at the default 64MB block size. +# +# NOTE: On the Longhorn work machines with 48GB of memory, it takes about 12 +# hours to process 20-21 days of global tweets and 24-25 days of spritzer +# tweets. Given that you only have a maximum of 24 hours of time, you +# should probably not process more than about a month's worth of tweets in +# a single run. (As an alternative, if your process gets terminated due to +# running out of time or for any other reason, try removing the last, +# partially written split file and then rerunning the command with the +# additional option --skip-existing. This will cause it to redo the same split +# but not overwrite the files that already exist. Since bzip compression takes +# up most of the time, this should fairly rapidly scan through all of the +# already-written files and then do the rest of them. 
As an example, a split +# run on spritzer output that took 24 hours to process 49 days took only +# 3.5 hours to skip through them when --skip-existing was used.) + +####################################################################### +# Process files # +####################################################################### + +def split_tweet_bzip_files(opts, args): + status = StatusMessage("tweet") + totalsize = 0 + outproc = None + skip_tweets = False + + def finish_outproc(outproc): + errprint("Total uncompressed size this split: %s" % totalsize) + errprint("Total number of tweets so far: %s" % status.num_processed()) + outproc.stdin.close() + errprint("Waiting for termination of output process ...") + outproc.wait() + errprint("Waiting for termination of output process ... done.") + + for infile in args: + errprint("Opening input %s..." % infile) + errprint("Total uncompressed size this split so far: %s" % totalsize) + errprint("Total number of tweets so far: %s" % status.num_processed()) + # NOTE: close_fds=True turns out to be necessary to avoid a deadlock in + # the following circumstance: + # + # 1) Open input from bzcat. + # 2) Open output to bzip2. + # 3) bzcat ends partway through a split (possibly after multiple splits, + # and hence multiple invocations of bzip2). + # 4) Wait for bzcat to finish, then start another bzcat for the next file. + # 5) End the split, close bzip2's stdin (our output pipe), and wait for + # bzip2 to finish. + # 6) Blammo! Deadlock while waiting for bzip2 to finish. + # + # When we opened the second bzcat, if we don't call close_fds, it + # inherits the file descriptor of the pipe to bzip2, and that screws + # things up. Presumably, the file descriptor inheritance means that + # there's still a file descriptor to the pipe to bzip2, so closing the + # output doesn't actually cause the pipe to get closed -- hence bzip2 + # waits indefinitely for more input. + inproc = Popen("bzcat", stdin=open(infile, "rb"), stdout=PIPE, close_fds=True) + for full_line in inproc.stdout: + line = full_line[:-1] + status.item_processed() + if not line.startswith('{"'): + errprint("Unparsable line, not JSON?, #%s: %s" % (status.num_processed(), line)) + else: + if totalsize >= opts.split_size or (not outproc and not skip_tweets): + # We need to open a new file. But keep writing the old file + # (if any) until we see a tweet with a time in it. + json = None + try: + json = split_json(line) + except Exception, exc: + errprint("Exception parsing JSON in line #%s: %s" % (status.num_processed(), line)) + errprint("Exception is %s" % exc) + traceback.print_exc() + if json: + json = json[0] + #errprint("Processing JSON %s:" % json) + #errprint("Length: %s" % len(json)) + for i in xrange(len(json)): + #errprint("Saw %s=%s" % (i, json[i])) + if json[i] == '"created_at"': + #errprint("Saw created") + if i + 2 >= len(json) or json[i+1] != ':' or json[i+2][0] != '"' or json[i+2][-1] != '"': + errprint("Something weird with JSON in line #%s, around here: %s" % (status.num_processed(), json[i-1:i+4])) + else: + json_time = json[i+2][1:-1].replace(" +0000 ", " UTC ") + tweet_time = time.strptime(json_time, + "%a %b %d %H:%M:%S %Z %Y") + if not tweet_time: + errprint("Can't parse time in line #%s: %s" % (status.num_processed(), json_time)) + else: + # Now we're ready to create a new split. 
+ skip_tweets = False + timesuff = time.strftime("%Y-%m-%d.%H%M-UTC", tweet_time) + def make_filename(suff): + return opts.output_prefix + suff + ".bz2" + outfile = make_filename(timesuff) + if os.path.exists(outfile): + if opts.skip_existing: + errprint("Skipping writing tweets to existing %s" % outfile) + skip_tweets = True + else: + errprint("Warning, path %s exists, not overwriting" % outfile) + for ind in itertools.count(1): + # Use _ not - because - sorts before the . of .bz2 but + # _ sorts after (as well as after all letters and numbers). + outfile = make_filename(timesuff + ("_%03d" % ind)) + if not os.path.exists(outfile): + break + if outproc: + finish_outproc(outproc) + outproc = None + totalsize = 0 + if not skip_tweets: + errprint("About to write to %s..." % outfile) + outfd = open(outfile, "wb") + outproc = Popen("bzip2", stdin=PIPE, stdout=outfd, close_fds=True) + outfd.close() + break + totalsize += len(full_line) + if skip_tweets: + pass + elif outproc: + outproc.stdin.write(full_line) + else: + errprint("Warning: Nowhere to write bad line #%s, skipping: %s" % (status.num_processed(), line)) + errprint("Waiting for termination of input process ...") + inproc.stdout.close() + # This sleep probably isn't necessary. I put it in while attempting to + # solve the deadlock when closing the pipe to bzip2 (see comments above + # about close_fds). + sleep_time = 5 + errprint("Sleeping %s seconds ..." % sleep_time) + time.sleep(sleep_time) + inproc.wait() + errprint("Waiting for termination of input process ... done.") + if outproc: + finish_outproc(outproc) + outproc = None + +# A very simple JSON splitter. Doesn't take the next step of assembling +# into dictionaries, but easily could. +# +# FIXME: This is totally unnecessary, as Python has a built-in JSON parser. +# (I didn't realize this when I wrote the function.) +def split_json(line): + split = re.split(r'("(?:\\.|[^"])*?"|[][:{},])', line) + split = (x for x in split if x) # Filter out empty strings + curind = 0 + def get_nested(endnest): + nest = [] + try: + while True: + item = next(split) + if item == endnest: + return nest + elif item == '{': + nest += [get_nested('}')] + elif item == '[': + nest += [get_nested(']')] + else: + nest += [item] + except StopIteration: + if not endnest: + return nest + else: + raise + return get_nested(None) + +####################################################################### +# Main code # +####################################################################### + +def main(): + op = OptionParser(usage="%prog [options] input_dir") + op.add_option("-s", "--split-size", metavar="SIZE", + type="int", default=1000000000, + help="""Size (uncompressed) of each split. Note that JSON +tweets compress in bzip about 8 to 1, hence 1 GB is a good uncompressed size +for Hadoop. 
Default %default.""") + op.add_option("-o", "--output-prefix", metavar="PREFIX", + help="""Prefix to use for all splits.""") + op.add_option("--skip-existing", action="store_true", + help="""If we would try and open an existing file, +skip writing any of the corresponding tweets.""") + + opts, args = op.parse_args() + + if not opts.output_prefix: + op.error("Must specify output prefix using -o or --output-prefix") + if not args: + op.error("No input files specified") + + split_tweet_bzip_files(opts, args) + +main() diff --git a/src/main/python/splitdevtest.py b/src/main/python/splitdevtest.py new file mode 100644 index 0000000..9e13625 --- /dev/null +++ b/src/main/python/splitdevtest.py @@ -0,0 +1,17 @@ +import sys, shutil, os + +def processDirectory(dirname): + fileList = os.listdir(dirname) + if(not dirname[-1] == "/"): + dirname += "/" + count = 0 + for filename in fileList: + if(count % 3 == 2): + shutil.copy(dirname + filename, sys.argv[3]) + print (dirname + filename) + " --> " + sys.argv[3] + else: + shutil.copy(dirname + filename, sys.argv[2]) + print (dirname + filename) + " --> " + sys.argv[2] + count += 1 + +processDirectory(sys.argv[1]) diff --git a/src/main/python/tei2txt.py b/src/main/python/tei2txt.py new file mode 100755 index 0000000..87e2c23 --- /dev/null +++ b/src/main/python/tei2txt.py @@ -0,0 +1,69 @@ +#! /usr/bin/env python + +import sys +import os +import re +import gzip +import fnmatch + +from codecs import latin_1_decode +from unicodedata import normalize +from tei_entities import pcl_tei_entities + +commaRE = re.compile(",") +nonAlpha = re.compile("[^A-Za-z]") + +pte = pcl_tei_entities() + +def cleanWord(word): + word = word.lower() + if len(word) < 2: + word = "" + return word + +def strip_text (text): + text = latin_1_decode(text)[0] + text = normalize('NFD',text).encode('ascii','ignore') + + text = re.sub('&mdash+;', ' ', text) # convert mdash to " " +# text = re.sub('&', ' and ', text) # convert mdash to " " + text = pte.replace_entities(text) +# text = re.sub('&[A-Za-z]+;', '', text) # convert ampersand stuff to "" + text = re.sub('<[^>]*>', ' ', text) # strip HTML markup + text = re.sub('\s+', ' ', text) # strip whitespace + + return text + +directory_name = sys.argv[1] +output_raw_dir = sys.argv[2] + +if not os.path.exists(output_raw_dir): + os.makedirs(output_raw_dir) + +files = os.listdir(directory_name) +for file in files: + add_line = False + write_text = False + if fnmatch.fnmatch(file,"*.xml"): + print "******",file + newname = file[:-4]+".txt" + raw_writer = open(output_raw_dir+"/"+newname,"w") + file_reader = open(directory_name+"/"+file) + text = "" + + header_end = False + while not header_end: + line = file_reader.readline() + m = re.search('\s*\]>', line) + if m: + header_end = True + + for line in file_reader.readlines(): + text = line.strip() + text = strip_text(text).strip() + if text != "": + raw_writer.write(text) + raw_writer.write("\n") + + raw_writer.close() + diff --git a/src/main/python/tei_entities.py b/src/main/python/tei_entities.py new file mode 100644 index 0000000..d880ad7 --- /dev/null +++ b/src/main/python/tei_entities.py @@ -0,0 +1,484 @@ +import re, sys + +# Class which replaces XML entities in raw text with the corresponding +# Unicode character. +# +# Comments from Ben: This is garbage! SAX and similar packages that are +# built into Python surely have functions to automatically do this. 
+ +from codecs import latin_1_decode + +class pcl_tei_entities: + + def __init__(self): + self.entity_code_dict = { + "amp": 0x0026, + "pound": 0x00A3, + "aacute": 0x00E1, + "ampersand": 0x0026, + "Aacute": 0x00C1, + "acirc": 0x00E2, + "Acirc": 0x00C2, + "agrave": 0x00E0, + "Agrave": 0x00C0, + "aring": 0x00E5, + "Aring": 0x00C5, + "atilde": 0x00E3, + "Atilde": 0x00C3, + "auml": 0x00E4, + "Auml": 0x00C4, + "aelig": 0x00E6, + "AElig": 0x00C6, + "ccedil": 0x00E7, + "Ccedil": 0x00C7, + "eth": 0x00F0, + "ETH": 0x00D0, + "eacute": 0x00E9, + "Eacute": 0x00C9, + "ecirc": 0x00EA, + "Ecirc": 0x00CA, + "egrave": 0x00E8, + "Egrave": 0x00C8, + "euml": 0x00EB, + "Euml": 0x00CB, + "iacute": 0x00ED, + "Iacute": 0x00CD, + "icirc": 0x00EE, + "Icirc": 0x00CE, + "igrave": 0x00EC, + "Igrave": 0x00CC, + "iuml": 0x00EF, + "Iuml": 0x00CF, + "ntilde": 0x00F1, + "Ntilde": 0x00D1, + "oacute": 0x00F3, + "Oacute": 0x00D3, + "ocirc": 0x00F4, + "Ocirc": 0x00D4, + "ograve": 0x00F2, + "Ograve": 0x00D2, + "oslash": 0x00F8, + "Oslash": 0x00D8, + "otilde": 0x00F5, + "Otilde": 0x00D5, + "ouml": 0x00F6, + "Ouml": 0x00D6, + "szlig": 0x00DF, + "thorn": 0x00FE, + "THORN": 0x00DE, + "uacute": 0x00FA, + "Uacute": 0x00DA, + "ucirc": 0x00FB, + "Ucirc": 0x00DB, + "ugrave": 0x00F9, + "Ugrave": 0x00D9, + "uuml": 0x00FC, + "Uuml": 0x00DC, + "yacute": 0x00FD, + "Yacute": 0x00DD, + "yuml": 0x00FF, + "abreve": 0x0103, + "Abreve": 0x0102, + "amacr": 0x0101, + "Amacr": 0x0100, + "aogon": 0x0105, + "Aogon": 0x0104, + "cacute": 0x0107, + "Cacute": 0x0106, + "ccaron": 0x010D, + "Ccaron": 0x010C, + "ccirc": 0x0109, + "Ccirc": 0x0108, + "cdot": 0x010B, + "Cdot": 0x010A, + "dcaron": 0x010F, + "Dcaron": 0x010E, + "dstrok": 0x0111, + "Dstrok": 0x0110, + "ecaron": 0x011B, + "Ecaron": 0x011A, + "edot": 0x0117, + "Edot": 0x0116, + "emacr": 0x0113, + "Emacr": 0x0112, + "eogon": 0x0119, + "Eogon": 0x0118, + "gacute": 0x01F5, + "gbreve": 0x011F, + "Gbreve": 0x011E, + "Gcedil": 0x0122, + "gcirc": 0x011D, + "Gcirc": 0x011C, + "gdot": 0x0121, + "Gdot": 0x0120, + "hcirc": 0x0125, + "Hcirc": 0x0124, + "hstrok": 0x0127, + "Hstrok": 0x0126, + "Idot": 0x0130, + "Imacr": 0x012A, + "imacr": 0x012B, + "ijlig": 0x0133, + "IJlig": 0x0132, + "inodot": 0x0131, + "iogon": 0x012F, + "Iogon": 0x012E, + "itilde": 0x0129, + "Itilde": 0x0128, + "jcirc": 0x0135, + "Jcirc": 0x0134, + "kcedil": 0x0137, + "Kcedil": 0x0136, + "kgreen": 0x0138, + "lacute": 0x013A, + "Lacute": 0x0139, + "lcaron": 0x013E, + "Lcaron": 0x013D, + "lcedil": 0x013C, + "Lcedil": 0x013B, + "lmidot": 0x0140, + "Lmidot": 0x013F, + "lstrok": 0x0142, + "Lstrok": 0x0141, + "nacute": 0x0144, + "Nacute": 0x0143, + "eng": 0x014B, + "ENG": 0x014A, + "napos": 0x0149, + "ncaron": 0x0148, + "Ncaron": 0x0147, + "ncedil": 0x0146, + "Ncedil": 0x0145, + "odblac": 0x0151, + "Odblac": 0x0150, + "Omacr": 0x014C, + "omacr": 0x014D, + "oelig": 0x0153, + "OElig": 0x0152, + "racute": 0x0155, + "Racute": 0x0154, + "rcaron": 0x0159, + "Rcaron": 0x0158, + "rcedil": 0x0157, + "Rcedil": 0x0156, + "sacute": 0x015B, + "Sacute": 0x015A, + "scaron": 0x0161, + "Scaron": 0x0160, + "scedil": 0x015F, + "Scedil": 0x015E, + "scirc": 0x015D, + "Scirc": 0x015C, + "tcaron": 0x0165, + "Tcaron": 0x0164, + "tcedil": 0x0163, + "Tcedil": 0x0162, + "tstrok": 0x0167, + "Tstrok": 0x0166, + "ubreve": 0x016D, + "Ubreve": 0x016C, + "udblac": 0x0171, + "Udblac": 0x0170, + "umacr": 0x016B, + "Umacr": 0x016A, + "uogon": 0x0173, + "Uogon": 0x0172, + "uring": 0x016F, + "Uring": 0x016E, + "utilde": 0x0169, + "Utilde": 0x0168, + "wcirc": 0x0175, + "Wcirc": 0x0174, + 
"ycirc": 0x0177, + "Ycirc": 0x0176, + "Yuml": 0x0178, + "zacute": 0x017A, + "Zacute": 0x0179, + "zcaron": 0x017E, + "Zcaron": 0x017D, + "zdot": 0x017C, + "Zdot": 0x017B, + "agr": 0x03B1, + "Agr": 0x0391, + "bgr": 0x03B2, + "Bgr": 0x0392, + "ggr": 0x03B3, + "Ggr": 0x0393, + "dgr": 0x03B4, + "Dgr": 0x0394, + "egr": 0x03B5, + "Egr": 0x0395, + "zgr": 0x03B6, + "Zgr": 0x0396, + "eegr": 0x03B7, + "EEgr": 0x0397, + "thgr": 0x03B8, + "THgr": 0x0398, + "igr": 0x03B9, + "Igr": 0x0399, + "kgr": 0x03BA, + "Kgr": 0x039A, + "lgr": 0x03BB, + "Lgr": 0x039B, + "mgr": 0x03BC, + "Mgr": 0x039C, + "ngr": 0x03BD, + "Ngr": 0x039D, + "xgr": 0x03BE, + "Xgr": 0x039E, + "ogr": 0x03BF, + "Ogr": 0x039F, + "pgr": 0x03C0, + "Pgr": 0x03A0, + "rgr": 0x03C1, + "Rgr": 0x03A1, + "sgr": 0x03C3, + "Sgr": 0x03A3, + "sfgr": 0x03C2, + "tgr": 0x03C4, + "Tgr": 0x03A4, + "ugr": 0x03C5, + "Ugr": 0x03A5, + "phgr": 0x03C6, + "PHgr": 0x03A6, + "khgr": 0x03C7, + "KHgr": 0x03A7, + "psgr": 0x03C8, + "PSgr": 0x03A8, + "ohgr": 0x03C9, + "OHgr": 0x03A9, + "half": 0x00BD, + "frac12": 0x00BD, + "frac14": 0x00BC, + "frac34": 0x00BE, + "frac18": 0x215B, + "frac38": 0x215C, + "frac58": 0x215D, + "frac78": 0x215E, + "sup1": 0x00B9, + "sup2": 0x00B2, + "sup3": 0x00B3, + "plus": 0x002B, + "plusmn": 0x00B1, + "equals": 0x003D, + "gt": 0x003E, + "divide": 0x00F7, + "times": 0x00D7, + "curren": 0x00A4, + "pound": 0x00A3, + "dollar": 0x0024, + "cent": 0x00A2, + "yen": 0x00A5, + "num": 0x0023, + "percnt": 0x0025, + "ast": 0x2217, + "commat": 0x0040, + "lsqb": 0x005B, + "bsol": 0x005C, + "rsqb": 0x005D, + "lcub": 0x007B, + "horbar": 0x2015, + "verbar": 0x007C, + "rcub": 0x007D, + "micro": 0x00B5, + "ohm": 0x2126, + "deg": 0x00B0, + "ordm": 0x00BA, + "ordf": 0x00AA, + "sect": 0x00A7, + "para": 0x00B6, + "middot": 0x00B7, + "larr": 0x2190, + "rarr": 0x2192, + "uarr": 0x2191, + "darr": 0x2193, + "copy": 0x00A9, + "reg": 0x00AF, + "trade": 0x2122, + "brvbar": 0x00A6, + "not": 0x00AC, + "sung": 0x2669, + "excl": 0x0021, + "iexcl": 0x00A1, + "quot": 0x0022, + "apos": 0x0027, + "lpar": 0x0028, + "rpar": 0x0029, + "comma": 0x002C, + "lowbar": 0x005F, + "hyphen": 0xE4F8, + "period": 0x002E, + "sol": 0x002F, + "colon": 0x003A, + "semi": 0x003B, + "quest": 0x003F, + "iquest": 0x00BF, + "laquo": 0x00AB, + "raquo": 0x00BB, + "lsquo": 0x2018, + "rsquo": 0x2019, + "ldquo": 0x201C, + "rdquo": 0x201D, + "nbsp": 0x00A0, + "shy": 0x00AD, + "acute": 0x00B4, + "breve": 0x02D8, + "caron": 0x02C7, + "cedil": 0x00B8, + "circ": 0x2218, + "dblac": 0x02DD, + "die": 0x00A8, + "dot": 0x02D9, + "grave": 0x0060, + "macr": 0x00AF, + "ogon": 0x02DB, + "ring": 0x02DA, + "tilde": 0x007E, + "uml": 0x00A8, + "emsp": 0x2003, + "ensp": 0x2002, + "emsp13": 0x2004, + "emsp14": 0x2005, + "numsp": 0x2007, + "puncsp": 0x2008, + "thinsp": 0x2009, + "hairsp": 0x200A, + "mdash": 0x2014, + "ndash": 0x2013, + "dash": 0x2010, + "blank": 0x2423, + "hellip": 0x2026, + "nldr": 0x2025, + "frac13": 0x2153, + "frac23": 0x2154, + "frac15": 0x2155, + "frac25": 0x2156, + "frac35": 0x2157, + "frac45": 0x2158, + "frac16": 0x2159, + "frac56": 0x215A, + "incare": 0x2105, + "block": 0x2588, + "uhblk": 0x2580, + "lhblk": 0x2584, + "blk14": 0x2591, + "blk12": 0x2592, + "blk34": 0x2593, + "marker": 0x25AE, + "cir": 0x25CB, + "squ": 0x25A1, + "rect": 0x25AD, + "utri": 0x25B5, + "dtri": 0x25BF, + "star": 0x22C6, + "bull": 0x2022, + "squf": 0x25AA, + "utrif": 0x25B4, + "dtrif": 0x25BE, + "ltrif": 0x25C2, + "rtrif": 0x25B8, + "clubs": 0x2663, + "diams": 0x2666, + "hearts": 0x2665, + "spades": 0x2660, + "malt": 0x2720, + 
"dagger": 0x2020, + "Dagger": 0x2021, + "check": 0x2713, + "cross": 0x2717, + "sharp": 0x266F, + "flat": 0x266D, + "male": 0x2642, + "female": 0x2640, + "phone": 0x260E, + "telrec": 0x2315, + "copysr": 0x2117, + "caret": 0x2041, + "lsquor": 0x201A, + "ldquor": 0x201E, + "fflig": 0xFB00, + "filig": 0xFB01, + "ffilig": 0xFB03, + "ffllig": 0xFB04, + "fllig": 0xFB02, + "mldr": 0x2026, + "rdquor": 0x201C, + "rsquor": 0x2018, + "vellip": 0x22EE, + "hybull": 0x2043, + "loz": 0x25CA, + "lozf": 0x2726, + "ltri": 0x25C3, + "rtri": 0x25B9, + "starf": 0x2605, + "natur": 0x266E, + "rx": 0x211E, + "sext": 0x2736, + "target": 0x2316, + "dlcrop": 0x230D, + "drcrop": 0x230C, + "ulcrop": 0x230F, + "urcrop": 0x230E, + "boxh": 0x2500, + "boxv": 0x2502, + "boxur": 0x2514, + "boxul": 0x2518, + "boxdl": 0x2510, + "boxdr": 0x250C, + "boxvr": 0x251C, + "boxhu": 0x2534, + "boxvl": 0x2524, + "boxhd": 0x252C, + "boxvh": 0x253C, + "boxvR": 0x255E, + "boxhU": 0x2567, + "boxvL": 0x2561, + "boxhD": 0x2564, + "boxvH": 0x256A, + "boxH": 0x2550, + "boxV": 0x2551, + "boxUR": 0x2558, + "boxUL": 0x255B, + "boxDL": 0x2555, + "boxDR": 0x2552, + "boxVR": 0x255F, + "boxHU": 0x2568, + "boxVL": 0x2562, + "boxHD": 0x2565, + "boxVH": 0x256B, + "boxVr": 0x2560, + "boxHu": 0x2569, + "boxVl": 0x2563, + "boxHd": 0x2566, + "boxVh": 0x256C, + "boxuR": 0x2559, + "boxUl": 0x255C, + "boxdL": 0x2556, + "boxDr": 0x2553, + "boxUr": 0x255A, + "boxuL": 0x255D, + "boxDl": 0x2557, + "boxdR": 0x2554 + } + self.entity_char_dict = {} + + for ent, code in self.entity_code_dict.iteritems(): + try: + self.entity_char_dict[ent] = latin_1_decode(chr(code),"utf8")[0] + except ValueError,UnicodeEncodeError: + self.entity_char_dict[ent] = unichr(code) + + def print_chars(self): + for ent, code in self.entity_char_dict.iteritems(): + print ent, code + + def replace_entities(self,text): + text = re.sub('&(?P[A-Za-z]+);', r'%(\?)', text) + otext = text + try: + text = text % self.entity_code_dict + return text + except KeyError: + print >>sys.stderr, "KeyError occurred at :", text + print >>sys.stderr, "Original text was :", otext + diff --git a/src/main/python/trrraw2plain.py b/src/main/python/trrraw2plain.py new file mode 100644 index 0000000..3ef7e81 --- /dev/null +++ b/src/main/python/trrraw2plain.py @@ -0,0 +1,35 @@ +import sys, os + +outDir = sys.argv[2] +if(not outDir[-1] == "/"): + outDir += "/" + +def processFile(filename): + global outDir + inFile = open(filename,'r') + newFilename = filename[filename.rfind("/")+1:-3] + ".txt" + outFile = open(outDir + newFilename, 'w') + wroteSomething = False + while(True): + curLine = inFile.readline() + if(curLine == ""): break + if(curLine.startswith(" ") or curLine.startswith("\t")): continue + nextToken = curLine.split()[0] + processedToken = nextToken.replace("&equo;", "'").replace("&dquo;", '"').replace("$", "$").replace("‐", "-").replace("&", "&").replace("×", "*") + if(processedToken[0].isalnum() and wroteSomething): + outFile.write(" ") + outFile.write(processedToken) + wroteSomething = True + inFile.close() + +def processDirectory(dirname): + fileList = os.listdir(dirname) + if(not dirname[-1] == "/"): + dirname += "/" + for filename in fileList: + if(os.path.isdir(dirname + filename)): + processDirectory(dirname + filename) + elif(os.path.isfile(dirname + filename)): + processFile(dirname + filename) + +processDirectory(sys.argv[1]) diff --git a/src/main/python/twitter-graphs/twitter.py b/src/main/python/twitter-graphs/twitter.py new file mode 100644 index 0000000..31a4f1e --- /dev/null +++ 
b/src/main/python/twitter-graphs/twitter.py @@ -0,0 +1,3908 @@ +#!/usr/bin/python2.4 +# +# Copyright 2007 The Python-Twitter Developers +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +'''A library that provides a Python interface to the Twitter API''' + +__author__ = 'python-twitter@googlegroups.com' +__version__ = '0.8.2' + + +import base64 +import calendar +import datetime +import httplib +import os +import rfc822 +import sys +import tempfile +import textwrap +import time +import calendar +import urllib +import urllib2 +import urlparse +import gzip +import StringIO + +try: + # Python >= 2.6 + import json as simplejson +except ImportError: + try: + # Python < 2.6 + import simplejson + except ImportError: + try: + # Google App Engine + from django.utils import simplejson + except ImportError: + raise ImportError, "Unable to load a json library" + +# parse_qsl moved to urlparse module in v2.6 +try: + from urlparse import parse_qsl, parse_qs +except ImportError: + from cgi import parse_qsl, parse_qs + +try: + from hashlib import md5 +except ImportError: + from md5 import md5 + +import oauth2 as oauth + + +CHARACTER_LIMIT = 140 + +# A singleton representing a lazily instantiated FileCache. +DEFAULT_CACHE = object() + +REQUEST_TOKEN_URL = 'https://api.twitter.com/oauth/request_token' +ACCESS_TOKEN_URL = 'https://api.twitter.com/oauth/access_token' +AUTHORIZATION_URL = 'https://api.twitter.com/oauth/authorize' +SIGNIN_URL = 'https://api.twitter.com/oauth/authenticate' + + +class TwitterError(Exception): + '''Base class for Twitter errors''' + + @property + def message(self): + '''Returns the first argument used to construct this error.''' + return self.args[0] + + +class Status(object): + '''A class representing the Status structure used by the twitter API. + + The Status structure exposes the following properties: + + status.created_at + status.created_at_in_seconds # read only + status.favorited + status.in_reply_to_screen_name + status.in_reply_to_user_id + status.in_reply_to_status_id + status.truncated + status.source + status.id + status.text + status.location + status.relative_created_at # read only + status.user + status.urls + status.user_mentions + status.hashtags + status.geo + status.place + status.coordinates + status.contributors + ''' + def __init__(self, + created_at=None, + favorited=None, + id=None, + text=None, + location=None, + user=None, + in_reply_to_screen_name=None, + in_reply_to_user_id=None, + in_reply_to_status_id=None, + truncated=None, + source=None, + now=None, + urls=None, + user_mentions=None, + hashtags=None, + geo=None, + place=None, + coordinates=None, + contributors=None, + retweeted=None, + retweeted_status=None, + retweet_count=None): + '''An object to hold a Twitter status message. + + This class is normally instantiated by the twitter.Api class and + returned in a sequence. + + Note: Dates are posted in the form "Sat Jan 27 04:17:38 +0000 2007" + + Args: + created_at: + The time this status message was posted. 
[Optional] + favorited: + Whether this is a favorite of the authenticated user. [Optional] + id: + The unique id of this status message. [Optional] + text: + The text of this status message. [Optional] + location: + the geolocation string associated with this message. [Optional] + relative_created_at: + A human readable string representing the posting time. [Optional] + user: + A twitter.User instance representing the person posting the + message. [Optional] + now: + The current time, if the client choses to set it. + Defaults to the wall clock time. [Optional] + urls: + user_mentions: + hashtags: + geo: + place: + coordinates: + contributors: + retweeted: + retweeted_status: + retweet_count: + ''' + self.created_at = created_at + self.favorited = favorited + self.id = id + self.text = text + self.location = location + self.user = user + self.now = now + self.in_reply_to_screen_name = in_reply_to_screen_name + self.in_reply_to_user_id = in_reply_to_user_id + self.in_reply_to_status_id = in_reply_to_status_id + self.truncated = truncated + self.retweeted = retweeted + self.source = source + self.urls = urls + self.user_mentions = user_mentions + self.hashtags = hashtags + self.geo = geo + self.place = place + self.coordinates = coordinates + self.contributors = contributors + self.retweeted_status = retweeted_status + self.retweet_count = retweet_count + + def GetCreatedAt(self): + '''Get the time this status message was posted. + + Returns: + The time this status message was posted + ''' + return self._created_at + + def SetCreatedAt(self, created_at): + '''Set the time this status message was posted. + + Args: + created_at: + The time this status message was created + ''' + self._created_at = created_at + + created_at = property(GetCreatedAt, SetCreatedAt, + doc='The time this status message was posted.') + + def GetCreatedAtInSeconds(self): + '''Get the time this status message was posted, in seconds since the epoch. + + Returns: + The time this status message was posted, in seconds since the epoch. + ''' + return calendar.timegm(rfc822.parsedate(self.created_at)) + + created_at_in_seconds = property(GetCreatedAtInSeconds, + doc="The time this status message was " + "posted, in seconds since the epoch") + + def GetFavorited(self): + '''Get the favorited setting of this status message. + + Returns: + True if this status message is favorited; False otherwise + ''' + return self._favorited + + def SetFavorited(self, favorited): + '''Set the favorited state of this status message. + + Args: + favorited: + boolean True/False favorited state of this status message + ''' + self._favorited = favorited + + favorited = property(GetFavorited, SetFavorited, + doc='The favorited state of this status message.') + + def GetId(self): + '''Get the unique id of this status message. + + Returns: + The unique id of this status message + ''' + return self._id + + def SetId(self, id): + '''Set the unique id of this status message. 
+ + Args: + id: + The unique id of this status message + ''' + self._id = id + + id = property(GetId, SetId, + doc='The unique id of this status message.') + + def GetInReplyToScreenName(self): + return self._in_reply_to_screen_name + + def SetInReplyToScreenName(self, in_reply_to_screen_name): + self._in_reply_to_screen_name = in_reply_to_screen_name + + in_reply_to_screen_name = property(GetInReplyToScreenName, SetInReplyToScreenName, + doc='') + + def GetInReplyToUserId(self): + return self._in_reply_to_user_id + + def SetInReplyToUserId(self, in_reply_to_user_id): + self._in_reply_to_user_id = in_reply_to_user_id + + in_reply_to_user_id = property(GetInReplyToUserId, SetInReplyToUserId, + doc='') + + def GetInReplyToStatusId(self): + return self._in_reply_to_status_id + + def SetInReplyToStatusId(self, in_reply_to_status_id): + self._in_reply_to_status_id = in_reply_to_status_id + + in_reply_to_status_id = property(GetInReplyToStatusId, SetInReplyToStatusId, + doc='') + + def GetTruncated(self): + return self._truncated + + def SetTruncated(self, truncated): + self._truncated = truncated + + truncated = property(GetTruncated, SetTruncated, + doc='') + + def GetRetweeted(self): + return self._retweeted + + def SetRetweeted(self, retweeted): + self._retweeted = retweeted + + retweeted = property(GetRetweeted, SetRetweeted, + doc='') + + def GetSource(self): + return self._source + + def SetSource(self, source): + self._source = source + + source = property(GetSource, SetSource, + doc='') + + def GetText(self): + '''Get the text of this status message. + + Returns: + The text of this status message. + ''' + return self._text + + def SetText(self, text): + '''Set the text of this status message. + + Args: + text: + The text of this status message + ''' + self._text = text + + text = property(GetText, SetText, + doc='The text of this status message') + + def GetLocation(self): + '''Get the geolocation associated with this status message + + Returns: + The geolocation string of this status message. + ''' + return self._location + + def SetLocation(self, location): + '''Set the geolocation associated with this status message + + Args: + location: + The geolocation string of this status message + ''' + self._location = location + + location = property(GetLocation, SetLocation, + doc='The geolocation string of this status message') + + def GetRelativeCreatedAt(self): + '''Get a human redable string representing the posting time + + Returns: + A human readable string representing the posting time + ''' + fudge = 1.25 + delta = long(self.now) - long(self.created_at_in_seconds) + + if delta < (1 * fudge): + return 'about a second ago' + elif delta < (60 * (1/fudge)): + return 'about %d seconds ago' % (delta) + elif delta < (60 * fudge): + return 'about a minute ago' + elif delta < (60 * 60 * (1/fudge)): + return 'about %d minutes ago' % (delta / 60) + elif delta < (60 * 60 * fudge) or delta / (60 * 60) == 1: + return 'about an hour ago' + elif delta < (60 * 60 * 24 * (1/fudge)): + return 'about %d hours ago' % (delta / (60 * 60)) + elif delta < (60 * 60 * 24 * fudge) or delta / (60 * 60 * 24) == 1: + return 'about a day ago' + else: + return 'about %d days ago' % (delta / (60 * 60 * 24)) + + relative_created_at = property(GetRelativeCreatedAt, + doc='Get a human readable string representing ' + 'the posting time') + + def GetUser(self): + '''Get a twitter.User reprenting the entity posting this status message. 
+ + Returns: + A twitter.User reprenting the entity posting this status message + ''' + return self._user + + def SetUser(self, user): + '''Set a twitter.User reprenting the entity posting this status message. + + Args: + user: + A twitter.User reprenting the entity posting this status message + ''' + self._user = user + + user = property(GetUser, SetUser, + doc='A twitter.User reprenting the entity posting this ' + 'status message') + + def GetNow(self): + '''Get the wallclock time for this status message. + + Used to calculate relative_created_at. Defaults to the time + the object was instantiated. + + Returns: + Whatever the status instance believes the current time to be, + in seconds since the epoch. + ''' + if self._now is None: + self._now = time.time() + return self._now + + def SetNow(self, now): + '''Set the wallclock time for this status message. + + Used to calculate relative_created_at. Defaults to the time + the object was instantiated. + + Args: + now: + The wallclock time for this instance. + ''' + self._now = now + + now = property(GetNow, SetNow, + doc='The wallclock time for this status instance.') + + def GetGeo(self): + return self._geo + + def SetGeo(self, geo): + self._geo = geo + + geo = property(GetGeo, SetGeo, + doc='') + + def GetPlace(self): + return self._place + + def SetPlace(self, place): + self._place = place + + place = property(GetPlace, SetPlace, + doc='') + + def GetCoordinates(self): + return self._coordinates + + def SetCoordinates(self, coordinates): + self._coordinates = coordinates + + coordinates = property(GetCoordinates, SetCoordinates, + doc='') + + def GetContributors(self): + return self._contributors + + def SetContributors(self, contributors): + self._contributors = contributors + + contributors = property(GetContributors, SetContributors, + doc='') + + def GetRetweeted_status(self): + return self._retweeted_status + + def SetRetweeted_status(self, retweeted_status): + self._retweeted_status = retweeted_status + + retweeted_status = property(GetRetweeted_status, SetRetweeted_status, + doc='') + + def GetRetweetCount(self): + return self._retweet_count + + def SetRetweetCount(self, retweet_count): + self._retweet_count = retweet_count + + retweet_count = property(GetRetweetCount, SetRetweetCount, + doc='') + + def __ne__(self, other): + return not self.__eq__(other) + + def __eq__(self, other): + try: + return other and \ + self.created_at == other.created_at and \ + self.id == other.id and \ + self.text == other.text and \ + self.location == other.location and \ + self.user == other.user and \ + self.in_reply_to_screen_name == other.in_reply_to_screen_name and \ + self.in_reply_to_user_id == other.in_reply_to_user_id and \ + self.in_reply_to_status_id == other.in_reply_to_status_id and \ + self.truncated == other.truncated and \ + self.retweeted == other.retweeted and \ + self.favorited == other.favorited and \ + self.source == other.source and \ + self.geo == other.geo and \ + self.place == other.place and \ + self.coordinates == other.coordinates and \ + self.contributors == other.contributors and \ + self.retweeted_status == other.retweeted_status and \ + self.retweet_count == other.retweet_count + except AttributeError: + return False + + def __str__(self): + '''A string representation of this twitter.Status instance. + + The return value is the same as the JSON string representation. + + Returns: + A string representation of this twitter.Status instance. 
+ ''' + return self.AsJsonString() + + def AsJsonString(self): + '''A JSON string representation of this twitter.Status instance. + + Returns: + A JSON string representation of this twitter.Status instance + ''' + return simplejson.dumps(self.AsDict(), sort_keys=True) + + def AsDict(self): + '''A dict representation of this twitter.Status instance. + + The return value uses the same key names as the JSON representation. + + Return: + A dict representing this twitter.Status instance + ''' + data = {} + if self.created_at: + data['created_at'] = self.created_at + if self.favorited: + data['favorited'] = self.favorited + if self.id: + data['id'] = self.id + if self.text: + data['text'] = self.text + if self.location: + data['location'] = self.location + if self.user: + data['user'] = self.user.AsDict() + if self.in_reply_to_screen_name: + data['in_reply_to_screen_name'] = self.in_reply_to_screen_name + if self.in_reply_to_user_id: + data['in_reply_to_user_id'] = self.in_reply_to_user_id + if self.in_reply_to_status_id: + data['in_reply_to_status_id'] = self.in_reply_to_status_id + if self.truncated is not None: + data['truncated'] = self.truncated + if self.retweeted is not None: + data['retweeted'] = self.retweeted + if self.favorited is not None: + data['favorited'] = self.favorited + if self.source: + data['source'] = self.source + if self.geo: + data['geo'] = self.geo + if self.place: + data['place'] = self.place + if self.coordinates: + data['coordinates'] = self.coordinates + if self.contributors: + data['contributors'] = self.contributors + if self.hashtags: + data['hashtags'] = [h.text for h in self.hashtags] + if self.retweeted_status: + data['retweeted_status'] = self.retweeted_status.AsDict() + if self.retweet_count: + data['retweet_count'] = self.retweet_count + return data + + @staticmethod + def NewFromJsonDict(data): + '''Create a new instance based on a JSON dict. 
+ + Args: + data: A JSON dict, as converted from the JSON in the twitter API + Returns: + A twitter.Status instance + ''' + if 'user' in data: + user = User.NewFromJsonDict(data['user']) + else: + user = None + if 'retweeted_status' in data: + retweeted_status = Status.NewFromJsonDict(data['retweeted_status']) + else: + retweeted_status = None + urls = None + user_mentions = None + hashtags = None + if 'entities' in data: + if 'urls' in data['entities']: + urls = [Url.NewFromJsonDict(u) for u in data['entities']['urls']] + if 'user_mentions' in data['entities']: + user_mentions = [User.NewFromJsonDict(u) for u in data['entities']['user_mentions']] + if 'hashtags' in data['entities']: + hashtags = [Hashtag.NewFromJsonDict(h) for h in data['entities']['hashtags']] + + location_string = data.get('location', None) + + # get the /geo/coordinates/ data, parse to location string + if (data.get('geo', None) != None): + if (data['geo'].get('coordinates', None) != None): + if (len(data['geo']['coordinates']) == 2): + location_string = "%s, %s" % (data['geo']['coordinates'][0], data['geo']['coordinates'][1]) + + return Status(created_at=data.get('created_at', None), + favorited=data.get('favorited', None), + id=data.get('id', None), + text=data.get('text', None), + location=location_string,#data.get('location', None), + in_reply_to_screen_name=data.get('in_reply_to_screen_name', None), + in_reply_to_user_id=data.get('in_reply_to_user_id', None), + in_reply_to_status_id=data.get('in_reply_to_status_id', None), + truncated=data.get('truncated', None), + retweeted=data.get('retweeted', None), + source=data.get('source', None), + user=user, + urls=urls, + user_mentions=user_mentions, + hashtags=hashtags, + geo=data.get('geo', None), + place=data.get('place', None), + coordinates=data.get('coordinates', None), + contributors=data.get('contributors', None), + retweeted_status=retweeted_status, + retweet_count=data.get('retweet_count', None)) + + +class User(object): + '''A class representing the User structure used by the twitter API. 
+ + The User structure exposes the following properties: + + user.id + user.name + user.screen_name + user.location + user.description + user.profile_image_url + user.profile_background_tile + user.profile_background_image_url + user.profile_sidebar_fill_color + user.profile_background_color + user.profile_link_color + user.profile_text_color + user.protected + user.utc_offset + user.time_zone + user.url + user.status + user.statuses_count + user.followers_count + user.friends_count + user.favourites_count + user.geo_enabled + user.verified + user.lang + user.notifications + user.contributors_enabled + user.created_at + user.listed_count + ''' + def __init__(self, + id=None, + name=None, + screen_name=None, + location=None, + description=None, + profile_image_url=None, + profile_background_tile=None, + profile_background_image_url=None, + profile_sidebar_fill_color=None, + profile_background_color=None, + profile_link_color=None, + profile_text_color=None, + protected=None, + utc_offset=None, + time_zone=None, + followers_count=None, + friends_count=None, + statuses_count=None, + favourites_count=None, + url=None, + status=None, + geo_enabled=None, + verified=None, + lang=None, + notifications=None, + contributors_enabled=None, + created_at=None, + listed_count=None): + self.id = id + self.name = name + self.screen_name = screen_name + self.location = location + self.description = description + self.profile_image_url = profile_image_url + self.profile_background_tile = profile_background_tile + self.profile_background_image_url = profile_background_image_url + self.profile_sidebar_fill_color = profile_sidebar_fill_color + self.profile_background_color = profile_background_color + self.profile_link_color = profile_link_color + self.profile_text_color = profile_text_color + self.protected = protected + self.utc_offset = utc_offset + self.time_zone = time_zone + self.followers_count = followers_count + self.friends_count = friends_count + self.statuses_count = statuses_count + self.favourites_count = favourites_count + self.url = url + self.status = status + self.geo_enabled = geo_enabled + self.verified = verified + self.lang = lang + self.notifications = notifications + self.contributors_enabled = contributors_enabled + self.created_at = created_at + self.listed_count = listed_count + + def GetId(self): + '''Get the unique id of this user. + + Returns: + The unique id of this user + ''' + return self._id + + def SetId(self, id): + '''Set the unique id of this user. + + Args: + id: The unique id of this user. + ''' + self._id = id + + id = property(GetId, SetId, + doc='The unique id of this user.') + + def GetName(self): + '''Get the real name of this user. + + Returns: + The real name of this user + ''' + return self._name + + def SetName(self, name): + '''Set the real name of this user. + + Args: + name: The real name of this user + ''' + self._name = name + + name = property(GetName, SetName, + doc='The real name of this user.') + + def GetScreenName(self): + '''Get the short twitter name of this user. + + Returns: + The short twitter name of this user + ''' + return self._screen_name + + def SetScreenName(self, screen_name): + '''Set the short twitter name of this user. + + Args: + screen_name: the short twitter name of this user + ''' + self._screen_name = screen_name + + screen_name = property(GetScreenName, SetScreenName, + doc='The short twitter name of this user.') + + def GetLocation(self): + '''Get the geographic location of this user. 
+ + Returns: + The geographic location of this user + ''' + return self._location + + def SetLocation(self, location): + '''Set the geographic location of this user. + + Args: + location: The geographic location of this user + ''' + self._location = location + + location = property(GetLocation, SetLocation, + doc='The geographic location of this user.') + + def GetDescription(self): + '''Get the short text description of this user. + + Returns: + The short text description of this user + ''' + return self._description + + def SetDescription(self, description): + '''Set the short text description of this user. + + Args: + description: The short text description of this user + ''' + self._description = description + + description = property(GetDescription, SetDescription, + doc='The short text description of this user.') + + def GetUrl(self): + '''Get the homepage url of this user. + + Returns: + The homepage url of this user + ''' + return self._url + + def SetUrl(self, url): + '''Set the homepage url of this user. + + Args: + url: The homepage url of this user + ''' + self._url = url + + url = property(GetUrl, SetUrl, + doc='The homepage url of this user.') + + def GetProfileImageUrl(self): + '''Get the url of the thumbnail of this user. + + Returns: + The url of the thumbnail of this user + ''' + return self._profile_image_url + + def SetProfileImageUrl(self, profile_image_url): + '''Set the url of the thumbnail of this user. + + Args: + profile_image_url: The url of the thumbnail of this user + ''' + self._profile_image_url = profile_image_url + + profile_image_url= property(GetProfileImageUrl, SetProfileImageUrl, + doc='The url of the thumbnail of this user.') + + def GetProfileBackgroundTile(self): + '''Boolean for whether to tile the profile background image. + + Returns: + True if the background is to be tiled, False if not, None if unset. + ''' + return self._profile_background_tile + + def SetProfileBackgroundTile(self, profile_background_tile): + '''Set the boolean flag for whether to tile the profile background image. + + Args: + profile_background_tile: Boolean flag for whether to tile or not. 
+ ''' + self._profile_background_tile = profile_background_tile + + profile_background_tile = property(GetProfileBackgroundTile, SetProfileBackgroundTile, + doc='Boolean for whether to tile the background image.') + + def GetProfileBackgroundImageUrl(self): + return self._profile_background_image_url + + def SetProfileBackgroundImageUrl(self, profile_background_image_url): + self._profile_background_image_url = profile_background_image_url + + profile_background_image_url = property(GetProfileBackgroundImageUrl, SetProfileBackgroundImageUrl, + doc='The url of the profile background of this user.') + + def GetProfileSidebarFillColor(self): + return self._profile_sidebar_fill_color + + def SetProfileSidebarFillColor(self, profile_sidebar_fill_color): + self._profile_sidebar_fill_color = profile_sidebar_fill_color + + profile_sidebar_fill_color = property(GetProfileSidebarFillColor, SetProfileSidebarFillColor) + + def GetProfileBackgroundColor(self): + return self._profile_background_color + + def SetProfileBackgroundColor(self, profile_background_color): + self._profile_background_color = profile_background_color + + profile_background_color = property(GetProfileBackgroundColor, SetProfileBackgroundColor) + + def GetProfileLinkColor(self): + return self._profile_link_color + + def SetProfileLinkColor(self, profile_link_color): + self._profile_link_color = profile_link_color + + profile_link_color = property(GetProfileLinkColor, SetProfileLinkColor) + + def GetProfileTextColor(self): + return self._profile_text_color + + def SetProfileTextColor(self, profile_text_color): + self._profile_text_color = profile_text_color + + profile_text_color = property(GetProfileTextColor, SetProfileTextColor) + + def GetProtected(self): + return self._protected + + def SetProtected(self, protected): + self._protected = protected + + protected = property(GetProtected, SetProtected) + + def GetUtcOffset(self): + return self._utc_offset + + def SetUtcOffset(self, utc_offset): + self._utc_offset = utc_offset + + utc_offset = property(GetUtcOffset, SetUtcOffset) + + def GetTimeZone(self): + '''Returns the current time zone string for the user. + + Returns: + The descriptive time zone string for the user. + ''' + return self._time_zone + + def SetTimeZone(self, time_zone): + '''Sets the user's time zone string. + + Args: + time_zone: + The descriptive time zone to assign for the user. + ''' + self._time_zone = time_zone + + time_zone = property(GetTimeZone, SetTimeZone) + + def GetStatus(self): + '''Get the latest twitter.Status of this user. + + Returns: + The latest twitter.Status of this user + ''' + return self._status + + def SetStatus(self, status): + '''Set the latest twitter.Status of this user. + + Args: + status: + The latest twitter.Status of this user + ''' + self._status = status + + status = property(GetStatus, SetStatus, + doc='The latest twitter.Status of this user.') + + def GetFriendsCount(self): + '''Get the friend count for this user. + + Returns: + The number of users this user has befriended. + ''' + return self._friends_count + + def SetFriendsCount(self, count): + '''Set the friend count for this user. + + Args: + count: + The number of users this user has befriended. + ''' + self._friends_count = count + + friends_count = property(GetFriendsCount, SetFriendsCount, + doc='The number of friends for this user.') + + def GetListedCount(self): + '''Get the listed count for this user. + + Returns: + The number of lists this user belongs to. 
+ ''' + return self._listed_count + + def SetListedCount(self, count): + '''Set the listed count for this user. + + Args: + count: + The number of lists this user belongs to. + ''' + self._listed_count = count + + listed_count = property(GetListedCount, SetListedCount, + doc='The number of lists this user belongs to.') + + def GetFollowersCount(self): + '''Get the follower count for this user. + + Returns: + The number of users following this user. + ''' + return self._followers_count + + def SetFollowersCount(self, count): + '''Set the follower count for this user. + + Args: + count: + The number of users following this user. + ''' + self._followers_count = count + + followers_count = property(GetFollowersCount, SetFollowersCount, + doc='The number of users following this user.') + + def GetStatusesCount(self): + '''Get the number of status updates for this user. + + Returns: + The number of status updates for this user. + ''' + return self._statuses_count + + def SetStatusesCount(self, count): + '''Set the status update count for this user. + + Args: + count: + The number of updates for this user. + ''' + self._statuses_count = count + + statuses_count = property(GetStatusesCount, SetStatusesCount, + doc='The number of updates for this user.') + + def GetFavouritesCount(self): + '''Get the number of favourites for this user. + + Returns: + The number of favourites for this user. + ''' + return self._favourites_count + + def SetFavouritesCount(self, count): + '''Set the favourite count for this user. + + Args: + count: + The number of favourites for this user. + ''' + self._favourites_count = count + + favourites_count = property(GetFavouritesCount, SetFavouritesCount, + doc='The number of favourites for this user.') + + def GetGeoEnabled(self): + '''Get the setting of geo_enabled for this user. + + Returns: + True/False if Geo tagging is enabled + ''' + return self._geo_enabled + + def SetGeoEnabled(self, geo_enabled): + '''Set the latest twitter.geo_enabled of this user. + + Args: + geo_enabled: + True/False if Geo tagging is to be enabled + ''' + self._geo_enabled = geo_enabled + + geo_enabled = property(GetGeoEnabled, SetGeoEnabled, + doc='The value of twitter.geo_enabled for this user.') + + def GetVerified(self): + '''Get the setting of verified for this user. + + Returns: + True/False if user is a verified account + ''' + return self._verified + + def SetVerified(self, verified): + '''Set twitter.verified for this user. + + Args: + verified: + True/False if user is a verified account + ''' + self._verified = verified + + verified = property(GetVerified, SetVerified, + doc='The value of twitter.verified for this user.') + + def GetLang(self): + '''Get the setting of lang for this user. + + Returns: + language code of the user + ''' + return self._lang + + def SetLang(self, lang): + '''Set twitter.lang for this user. + + Args: + lang: + language code for the user + ''' + self._lang = lang + + lang = property(GetLang, SetLang, + doc='The value of twitter.lang for this user.') + + def GetNotifications(self): + '''Get the setting of notifications for this user. + + Returns: + True/False for the notifications setting of the user + ''' + return self._notifications + + def SetNotifications(self, notifications): + '''Set twitter.notifications for this user. 
+ + Args: + notifications: + True/False notifications setting for the user + ''' + self._notifications = notifications + + notifications = property(GetNotifications, SetNotifications, + doc='The value of twitter.notifications for this user.') + + def GetContributorsEnabled(self): + '''Get the setting of contributors_enabled for this user. + + Returns: + True/False contributors_enabled of the user + ''' + return self._contributors_enabled + + def SetContributorsEnabled(self, contributors_enabled): + '''Set twitter.contributors_enabled for this user. + + Args: + contributors_enabled: + True/False contributors_enabled setting for the user + ''' + self._contributors_enabled = contributors_enabled + + contributors_enabled = property(GetContributorsEnabled, SetContributorsEnabled, + doc='The value of twitter.contributors_enabled for this user.') + + def GetCreatedAt(self): + '''Get the setting of created_at for this user. + + Returns: + created_at value of the user + ''' + return self._created_at + + def SetCreatedAt(self, created_at): + '''Set twitter.created_at for this user. + + Args: + created_at: + created_at value for the user + ''' + self._created_at = created_at + + created_at = property(GetCreatedAt, SetCreatedAt, + doc='The value of twitter.created_at for this user.') + + def __ne__(self, other): + return not self.__eq__(other) + + def __eq__(self, other): + try: + return other and \ + self.id == other.id and \ + self.name == other.name and \ + self.screen_name == other.screen_name and \ + self.location == other.location and \ + self.description == other.description and \ + self.profile_image_url == other.profile_image_url and \ + self.profile_background_tile == other.profile_background_tile and \ + self.profile_background_image_url == other.profile_background_image_url and \ + self.profile_sidebar_fill_color == other.profile_sidebar_fill_color and \ + self.profile_background_color == other.profile_background_color and \ + self.profile_link_color == other.profile_link_color and \ + self.profile_text_color == other.profile_text_color and \ + self.protected == other.protected and \ + self.utc_offset == other.utc_offset and \ + self.time_zone == other.time_zone and \ + self.url == other.url and \ + self.statuses_count == other.statuses_count and \ + self.followers_count == other.followers_count and \ + self.favourites_count == other.favourites_count and \ + self.friends_count == other.friends_count and \ + self.status == other.status and \ + self.geo_enabled == other.geo_enabled and \ + self.verified == other.verified and \ + self.lang == other.lang and \ + self.notifications == other.notifications and \ + self.contributors_enabled == other.contributors_enabled and \ + self.created_at == other.created_at and \ + self.listed_count == other.listed_count + + except AttributeError: + return False + + def __str__(self): + '''A string representation of this twitter.User instance. + + The return value is the same as the JSON string representation. + + Returns: + A string representation of this twitter.User instance. + ''' + return self.AsJsonString() + + def AsJsonString(self): + '''A JSON string representation of this twitter.User instance. + + Returns: + A JSON string representation of this twitter.User instance + ''' + return simplejson.dumps(self.AsDict(), sort_keys=True) + + def AsDict(self): + '''A dict representation of this twitter.User instance. + + The return value uses the same key names as the JSON representation. 
+
+    Return:
+      A dict representing this twitter.User instance
+    '''
+    data = {}
+    if self.id:
+      data['id'] = self.id
+    if self.name:
+      data['name'] = self.name
+    if self.screen_name:
+      data['screen_name'] = self.screen_name
+    if self.location:
+      data['location'] = self.location
+    if self.description:
+      data['description'] = self.description
+    if self.profile_image_url:
+      data['profile_image_url'] = self.profile_image_url
+    if self.profile_background_tile is not None:
+      data['profile_background_tile'] = self.profile_background_tile
+    if self.profile_background_image_url:
+      data['profile_background_image_url'] = self.profile_background_image_url
+    if self.profile_sidebar_fill_color:
+      data['profile_sidebar_fill_color'] = self.profile_sidebar_fill_color
+    if self.profile_background_color:
+      data['profile_background_color'] = self.profile_background_color
+    if self.profile_link_color:
+      data['profile_link_color'] = self.profile_link_color
+    if self.profile_text_color:
+      data['profile_text_color'] = self.profile_text_color
+    if self.protected is not None:
+      data['protected'] = self.protected
+    if self.utc_offset:
+      data['utc_offset'] = self.utc_offset
+    if self.time_zone:
+      data['time_zone'] = self.time_zone
+    if self.url:
+      data['url'] = self.url
+    if self.status:
+      data['status'] = self.status.AsDict()
+    if self.friends_count:
+      data['friends_count'] = self.friends_count
+    if self.followers_count:
+      data['followers_count'] = self.followers_count
+    if self.statuses_count:
+      data['statuses_count'] = self.statuses_count
+    if self.favourites_count:
+      data['favourites_count'] = self.favourites_count
+    if self.geo_enabled:
+      data['geo_enabled'] = self.geo_enabled
+    if self.verified:
+      data['verified'] = self.verified
+    if self.lang:
+      data['lang'] = self.lang
+    if self.notifications:
+      data['notifications'] = self.notifications
+    if self.contributors_enabled:
+      data['contributors_enabled'] = self.contributors_enabled
+    if self.created_at:
+      data['created_at'] = self.created_at
+    if self.listed_count:
+      data['listed_count'] = self.listed_count
+
+    return data
+
+  @staticmethod
+  def NewFromJsonDict(data):
+    '''Create a new instance based on a JSON dict.
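The serialization helpers above mirror the key names of the JSON the API returns. A rough usage sketch, assuming the module is importable as `twitter` and with field values invented for illustration:

import twitter

# Build a User directly and serialize it; the values are made up.
user = twitter.User(id=718443, name='Kesuke Miyagi', screen_name='kesuke',
                    location='Okinawa, Japan')

d = user.AsDict()            # plain dict, same key names as the API JSON
assert d['screen_name'] == 'kesuke'

s = user.AsJsonString()      # JSON text via simplejson.dumps(..., sort_keys=True)
assert '"screen_name": "kesuke"' in s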
+ + Args: + data: + A JSON dict, as converted from the JSON in the twitter API + + Returns: + A twitter.User instance + ''' + if 'status' in data: + status = Status.NewFromJsonDict(data['status']) + else: + status = None + + return User(id=data.get('id', None), + name=data.get('name', None), + screen_name=data.get('screen_name', None), + location=data.get('location', None), + description=data.get('description', None), + statuses_count=data.get('statuses_count', None), + followers_count=data.get('followers_count', None), + favourites_count=data.get('favourites_count', None), + friends_count=data.get('friends_count', None), + profile_image_url=data.get('profile_image_url', None), + profile_background_tile = data.get('profile_background_tile', None), + profile_background_image_url = data.get('profile_background_image_url', None), + profile_sidebar_fill_color = data.get('profile_sidebar_fill_color', None), + profile_background_color = data.get('profile_background_color', None), + profile_link_color = data.get('profile_link_color', None), + profile_text_color = data.get('profile_text_color', None), + protected = data.get('protected', None), + utc_offset = data.get('utc_offset', None), + time_zone = data.get('time_zone', None), + url=data.get('url', None), + status=status, + geo_enabled=data.get('geo_enabled', None), + verified=data.get('verified', None), + lang=data.get('lang', None), + notifications=data.get('notifications', None), + contributors_enabled=data.get('contributors_enabled', None), + created_at=data.get('created_at', None), + listed_count=data.get('listed_count', None)) + +class List(object): + '''A class representing the List structure used by the twitter API. + + The List structure exposes the following properties: + + list.id + list.name + list.slug + list.description + list.full_name + list.mode + list.uri + list.member_count + list.subscriber_count + list.following + ''' + def __init__(self, + id=None, + name=None, + slug=None, + description=None, + full_name=None, + mode=None, + uri=None, + member_count=None, + subscriber_count=None, + following=None, + user=None): + self.id = id + self.name = name + self.slug = slug + self.description = description + self.full_name = full_name + self.mode = mode + self.uri = uri + self.member_count = member_count + self.subscriber_count = subscriber_count + self.following = following + self.user = user + + def GetId(self): + '''Get the unique id of this list. + + Returns: + The unique id of this list + ''' + return self._id + + def SetId(self, id): + '''Set the unique id of this list. + + Args: + id: + The unique id of this list. + ''' + self._id = id + + id = property(GetId, SetId, + doc='The unique id of this list.') + + def GetName(self): + '''Get the real name of this list. + + Returns: + The real name of this list + ''' + return self._name + + def SetName(self, name): + '''Set the real name of this list. + + Args: + name: + The real name of this list + ''' + self._name = name + + name = property(GetName, SetName, + doc='The real name of this list.') + + def GetSlug(self): + '''Get the slug of this list. + + Returns: + The slug of this list + ''' + return self._slug + + def SetSlug(self, slug): + '''Set the slug of this list. + + Args: + slug: + The slug of this list. + ''' + self._slug = slug + + slug = property(GetSlug, SetSlug, + doc='The slug of this list.') + + def GetDescription(self): + '''Get the description of this list. 
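Conversely, `User.NewFromJsonDict` builds a `User` from a decoded JSON payload, and the `__eq__` defined earlier compares instances field by field, so two users built from the same payload compare equal. A small sketch, with an invented payload:

payload = {'id': 12, 'name': 'Jack', 'screen_name': 'jack',
           'followers_count': 100, 'verified': True}

u1 = twitter.User.NewFromJsonDict(payload)
u2 = twitter.User.NewFromJsonDict(payload)

assert u1.screen_name == 'jack'
assert u1.status is None      # no embedded 'status' object in this payload
assert u1 == u2               # field-by-field comparison from __eq__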
+ + Returns: + The description of this list + ''' + return self._description + + def SetDescription(self, description): + '''Set the description of this list. + + Args: + description: + The description of this list. + ''' + self._description = description + + description = property(GetDescription, SetDescription, + doc='The description of this list.') + + def GetFull_name(self): + '''Get the full_name of this list. + + Returns: + The full_name of this list + ''' + return self._full_name + + def SetFull_name(self, full_name): + '''Set the full_name of this list. + + Args: + full_name: + The full_name of this list. + ''' + self._full_name = full_name + + full_name = property(GetFull_name, SetFull_name, + doc='The full_name of this list.') + + def GetMode(self): + '''Get the mode of this list. + + Returns: + The mode of this list + ''' + return self._mode + + def SetMode(self, mode): + '''Set the mode of this list. + + Args: + mode: + The mode of this list. + ''' + self._mode = mode + + mode = property(GetMode, SetMode, + doc='The mode of this list.') + + def GetUri(self): + '''Get the uri of this list. + + Returns: + The uri of this list + ''' + return self._uri + + def SetUri(self, uri): + '''Set the uri of this list. + + Args: + uri: + The uri of this list. + ''' + self._uri = uri + + uri = property(GetUri, SetUri, + doc='The uri of this list.') + + def GetMember_count(self): + '''Get the member_count of this list. + + Returns: + The member_count of this list + ''' + return self._member_count + + def SetMember_count(self, member_count): + '''Set the member_count of this list. + + Args: + member_count: + The member_count of this list. + ''' + self._member_count = member_count + + member_count = property(GetMember_count, SetMember_count, + doc='The member_count of this list.') + + def GetSubscriber_count(self): + '''Get the subscriber_count of this list. + + Returns: + The subscriber_count of this list + ''' + return self._subscriber_count + + def SetSubscriber_count(self, subscriber_count): + '''Set the subscriber_count of this list. + + Args: + subscriber_count: + The subscriber_count of this list. + ''' + self._subscriber_count = subscriber_count + + subscriber_count = property(GetSubscriber_count, SetSubscriber_count, + doc='The subscriber_count of this list.') + + def GetFollowing(self): + '''Get the following status of this list. + + Returns: + The following status of this list + ''' + return self._following + + def SetFollowing(self, following): + '''Set the following status of this list. + + Args: + following: + The following of this list. + ''' + self._following = following + + following = property(GetFollowing, SetFollowing, + doc='The following status of this list.') + + def GetUser(self): + '''Get the user of this list. + + Returns: + The owner of this list + ''' + return self._user + + def SetUser(self, user): + '''Set the user of this list. + + Args: + user: + The owner of this list. 
+ ''' + self._user = user + + user = property(GetUser, SetUser, + doc='The owner of this list.') + + def __ne__(self, other): + return not self.__eq__(other) + + def __eq__(self, other): + try: + return other and \ + self.id == other.id and \ + self.name == other.name and \ + self.slug == other.slug and \ + self.description == other.description and \ + self.full_name == other.full_name and \ + self.mode == other.mode and \ + self.uri == other.uri and \ + self.member_count == other.member_count and \ + self.subscriber_count == other.subscriber_count and \ + self.following == other.following and \ + self.user == other.user + + except AttributeError: + return False + + def __str__(self): + '''A string representation of this twitter.List instance. + + The return value is the same as the JSON string representation. + + Returns: + A string representation of this twitter.List instance. + ''' + return self.AsJsonString() + + def AsJsonString(self): + '''A JSON string representation of this twitter.List instance. + + Returns: + A JSON string representation of this twitter.List instance + ''' + return simplejson.dumps(self.AsDict(), sort_keys=True) + + def AsDict(self): + '''A dict representation of this twitter.List instance. + + The return value uses the same key names as the JSON representation. + + Return: + A dict representing this twitter.List instance + ''' + data = {} + if self.id: + data['id'] = self.id + if self.name: + data['name'] = self.name + if self.slug: + data['slug'] = self.slug + if self.description: + data['description'] = self.description + if self.full_name: + data['full_name'] = self.full_name + if self.mode: + data['mode'] = self.mode + if self.uri: + data['uri'] = self.uri + if self.member_count is not None: + data['member_count'] = self.member_count + if self.subscriber_count is not None: + data['subscriber_count'] = self.subscriber_count + if self.following is not None: + data['following'] = self.following + if self.user is not None: + data['user'] = self.user + return data + + @staticmethod + def NewFromJsonDict(data): + '''Create a new instance based on a JSON dict. + + Args: + data: + A JSON dict, as converted from the JSON in the twitter API + + Returns: + A twitter.List instance + ''' + if 'user' in data: + user = User.NewFromJsonDict(data['user']) + else: + user = None + return List(id=data.get('id', None), + name=data.get('name', None), + slug=data.get('slug', None), + description=data.get('description', None), + full_name=data.get('full_name', None), + mode=data.get('mode', None), + uri=data.get('uri', None), + member_count=data.get('member_count', None), + subscriber_count=data.get('subscriber_count', None), + following=data.get('following', None), + user=user) + +class DirectMessage(object): + '''A class representing the DirectMessage structure used by the twitter API. + + The DirectMessage structure exposes the following properties: + + direct_message.id + direct_message.created_at + direct_message.created_at_in_seconds # read only + direct_message.sender_id + direct_message.sender_screen_name + direct_message.recipient_id + direct_message.recipient_screen_name + direct_message.text + ''' + + def __init__(self, + id=None, + created_at=None, + sender_id=None, + sender_screen_name=None, + recipient_id=None, + recipient_screen_name=None, + text=None): + '''An object to hold a Twitter direct message. + + This class is normally instantiated by the twitter.Api class and + returned in a sequence. 
+ + Note: Dates are posted in the form "Sat Jan 27 04:17:38 +0000 2007" + + Args: + id: + The unique id of this direct message. [Optional] + created_at: + The time this direct message was posted. [Optional] + sender_id: + The id of the twitter user that sent this message. [Optional] + sender_screen_name: + The name of the twitter user that sent this message. [Optional] + recipient_id: + The id of the twitter that received this message. [Optional] + recipient_screen_name: + The name of the twitter that received this message. [Optional] + text: + The text of this direct message. [Optional] + ''' + self.id = id + self.created_at = created_at + self.sender_id = sender_id + self.sender_screen_name = sender_screen_name + self.recipient_id = recipient_id + self.recipient_screen_name = recipient_screen_name + self.text = text + + def GetId(self): + '''Get the unique id of this direct message. + + Returns: + The unique id of this direct message + ''' + return self._id + + def SetId(self, id): + '''Set the unique id of this direct message. + + Args: + id: + The unique id of this direct message + ''' + self._id = id + + id = property(GetId, SetId, + doc='The unique id of this direct message.') + + def GetCreatedAt(self): + '''Get the time this direct message was posted. + + Returns: + The time this direct message was posted + ''' + return self._created_at + + def SetCreatedAt(self, created_at): + '''Set the time this direct message was posted. + + Args: + created_at: + The time this direct message was created + ''' + self._created_at = created_at + + created_at = property(GetCreatedAt, SetCreatedAt, + doc='The time this direct message was posted.') + + def GetCreatedAtInSeconds(self): + '''Get the time this direct message was posted, in seconds since the epoch. + + Returns: + The time this direct message was posted, in seconds since the epoch. + ''' + return calendar.timegm(rfc822.parsedate(self.created_at)) + + created_at_in_seconds = property(GetCreatedAtInSeconds, + doc="The time this direct message was " + "posted, in seconds since the epoch") + + def GetSenderId(self): + '''Get the unique sender id of this direct message. + + Returns: + The unique sender id of this direct message + ''' + return self._sender_id + + def SetSenderId(self, sender_id): + '''Set the unique sender id of this direct message. + + Args: + sender_id: + The unique sender id of this direct message + ''' + self._sender_id = sender_id + + sender_id = property(GetSenderId, SetSenderId, + doc='The unique sender id of this direct message.') + + def GetSenderScreenName(self): + '''Get the unique sender screen name of this direct message. + + Returns: + The unique sender screen name of this direct message + ''' + return self._sender_screen_name + + def SetSenderScreenName(self, sender_screen_name): + '''Set the unique sender screen name of this direct message. + + Args: + sender_screen_name: + The unique sender screen name of this direct message + ''' + self._sender_screen_name = sender_screen_name + + sender_screen_name = property(GetSenderScreenName, SetSenderScreenName, + doc='The unique sender screen name of this direct message.') + + def GetRecipientId(self): + '''Get the unique recipient id of this direct message. + + Returns: + The unique recipient id of this direct message + ''' + return self._recipient_id + + def SetRecipientId(self, recipient_id): + '''Set the unique recipient id of this direct message. 
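`created_at_in_seconds` converts the RFC 822 style timestamp noted above into seconds since the epoch. The same conversion in isolation, under Python 2 where the `rfc822` module used by this code is available, looks roughly like this:

import calendar
import rfc822   # Python 2 stdlib module used by GetCreatedAtInSeconds above

created_at = 'Sat Jan 27 04:17:38 +0000 2007'   # format noted in the docstring above
epoch_seconds = calendar.timegm(rfc822.parsedate(created_at))
print epoch_seconds   # 1169871458, i.e. 2007-01-27 04:17:38 UTC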
+ + Args: + recipient_id: + The unique recipient id of this direct message + ''' + self._recipient_id = recipient_id + + recipient_id = property(GetRecipientId, SetRecipientId, + doc='The unique recipient id of this direct message.') + + def GetRecipientScreenName(self): + '''Get the unique recipient screen name of this direct message. + + Returns: + The unique recipient screen name of this direct message + ''' + return self._recipient_screen_name + + def SetRecipientScreenName(self, recipient_screen_name): + '''Set the unique recipient screen name of this direct message. + + Args: + recipient_screen_name: + The unique recipient screen name of this direct message + ''' + self._recipient_screen_name = recipient_screen_name + + recipient_screen_name = property(GetRecipientScreenName, SetRecipientScreenName, + doc='The unique recipient screen name of this direct message.') + + def GetText(self): + '''Get the text of this direct message. + + Returns: + The text of this direct message. + ''' + return self._text + + def SetText(self, text): + '''Set the text of this direct message. + + Args: + text: + The text of this direct message + ''' + self._text = text + + text = property(GetText, SetText, + doc='The text of this direct message') + + def __ne__(self, other): + return not self.__eq__(other) + + def __eq__(self, other): + try: + return other and \ + self.id == other.id and \ + self.created_at == other.created_at and \ + self.sender_id == other.sender_id and \ + self.sender_screen_name == other.sender_screen_name and \ + self.recipient_id == other.recipient_id and \ + self.recipient_screen_name == other.recipient_screen_name and \ + self.text == other.text + except AttributeError: + return False + + def __str__(self): + '''A string representation of this twitter.DirectMessage instance. + + The return value is the same as the JSON string representation. + + Returns: + A string representation of this twitter.DirectMessage instance. + ''' + return self.AsJsonString() + + def AsJsonString(self): + '''A JSON string representation of this twitter.DirectMessage instance. + + Returns: + A JSON string representation of this twitter.DirectMessage instance + ''' + return simplejson.dumps(self.AsDict(), sort_keys=True) + + def AsDict(self): + '''A dict representation of this twitter.DirectMessage instance. + + The return value uses the same key names as the JSON representation. + + Return: + A dict representing this twitter.DirectMessage instance + ''' + data = {} + if self.id: + data['id'] = self.id + if self.created_at: + data['created_at'] = self.created_at + if self.sender_id: + data['sender_id'] = self.sender_id + if self.sender_screen_name: + data['sender_screen_name'] = self.sender_screen_name + if self.recipient_id: + data['recipient_id'] = self.recipient_id + if self.recipient_screen_name: + data['recipient_screen_name'] = self.recipient_screen_name + if self.text: + data['text'] = self.text + return data + + @staticmethod + def NewFromJsonDict(data): + '''Create a new instance based on a JSON dict. 
+ + Args: + data: + A JSON dict, as converted from the JSON in the twitter API + + Returns: + A twitter.DirectMessage instance + ''' + return DirectMessage(created_at=data.get('created_at', None), + recipient_id=data.get('recipient_id', None), + sender_id=data.get('sender_id', None), + text=data.get('text', None), + sender_screen_name=data.get('sender_screen_name', None), + id=data.get('id', None), + recipient_screen_name=data.get('recipient_screen_name', None)) + +class Hashtag(object): + ''' A class represeinting a twitter hashtag + ''' + def __init__(self, + text=None): + self.text = text + + @staticmethod + def NewFromJsonDict(data): + '''Create a new instance based on a JSON dict. + + Args: + data: + A JSON dict, as converted from the JSON in the twitter API + + Returns: + A twitter.Hashtag instance + ''' + return Hashtag(text = data.get('text', None)) + +class Trend(object): + ''' A class representing a trending topic + ''' + def __init__(self, name=None, query=None, timestamp=None): + self.name = name + self.query = query + self.timestamp = timestamp + + def __str__(self): + return 'Name: %s\nQuery: %s\nTimestamp: %s\n' % (self.name, self.query, self.timestamp) + + @staticmethod + def NewFromJsonDict(data, timestamp = None): + '''Create a new instance based on a JSON dict + + Args: + data: + A JSON dict + timestamp: + Gets set as the timestamp property of the new object + + Returns: + A twitter.Trend object + ''' + return Trend(name=data.get('name', None), + query=data.get('query', None), + timestamp=timestamp) + +class Url(object): + '''A class representing an URL contained in a tweet''' + def __init__(self, + url=None, + expanded_url=None): + self.url = url + self.expanded_url = expanded_url + + @staticmethod + def NewFromJsonDict(data): + '''Create a new instance based on a JSON dict. + + Args: + data: + A JSON dict, as converted from the JSON in the twitter API + + Returns: + A twitter.Url instance + ''' + return Url(url=data.get('url', None), + expanded_url=data.get('expanded_url', None)) + +class Api(object): + '''A python interface into the Twitter API + + By default, the Api caches results for 1 minute. + + Example usage: + + To create an instance of the twitter.Api class, with no authentication: + + >>> import twitter + >>> api = twitter.Api() + + To fetch the most recently posted public twitter status messages: + + >>> statuses = api.GetPublicTimeline() + >>> print [s.user.name for s in statuses] + [u'DeWitt', u'Kesuke Miyagi', u'ev', u'Buzz Andersen', u'Biz Stone'] #... + + To fetch a single user's public status messages, where "user" is either + a Twitter "short name" or their user id. + + >>> statuses = api.GetUserTimeline(user) + >>> print [s.text for s in statuses] + + To use authentication, instantiate the twitter.Api class with a + consumer key and secret; and the oAuth key and secret: + + >>> api = twitter.Api(consumer_key='twitter consumer key', + consumer_secret='twitter consumer secret', + access_token_key='the_key_given', + access_token_secret='the_key_secret') + + To fetch your friends (after being authenticated): + + >>> users = api.GetFriends() + >>> print [u.name for u in users] + + To post a twitter status message (after being authenticated): + + >>> status = api.PostUpdate('I love python-twitter!') + >>> print status.text + I love python-twitter! 
+
+  There are many other methods, including:
+
+    >>> api.PostUpdates(status)
+    >>> api.PostDirectMessage(user, text)
+    >>> api.GetUser(user)
+    >>> api.GetReplies()
+    >>> api.GetUserTimeline(user)
+    >>> api.GetStatus(id)
+    >>> api.DestroyStatus(id)
+    >>> api.GetFriendsTimeline(user)
+    >>> api.GetFriends(user)
+    >>> api.GetFollowers()
+    >>> api.GetFeatured()
+    >>> api.GetDirectMessages()
+    >>> api.DestroyDirectMessage(id)
+    >>> api.DestroyFriendship(user)
+    >>> api.CreateFriendship(user)
+    >>> api.GetUserByEmail(email)
+    >>> api.VerifyCredentials()
+  '''
+
+  DEFAULT_CACHE_TIMEOUT = 60 # cache for 1 minute
+  _API_REALM = 'Twitter API'
+
+  def __init__(self,
+               consumer_key=None,
+               consumer_secret=None,
+               access_token_key=None,
+               access_token_secret=None,
+               input_encoding=None,
+               request_headers=None,
+               cache=DEFAULT_CACHE,
+               shortner=None,
+               base_url=None,
+               use_gzip_compression=False,
+               debugHTTP=False):
+    '''Instantiate a new twitter.Api object.
+
+    Args:
+      consumer_key:
+        Your Twitter user's consumer_key.
+      consumer_secret:
+        Your Twitter user's consumer_secret.
+      access_token_key:
+        The oAuth access token key value you retrieved
+        from running get_access_token.py.
+      access_token_secret:
+        The oAuth access token's secret, also retrieved
+        from the get_access_token.py run.
+      input_encoding:
+        The encoding used to encode input strings. [Optional]
+      request_headers:
+        A dictionary of additional HTTP request headers. [Optional]
+      cache:
+        The cache instance to use. Defaults to DEFAULT_CACHE.
+        Use None to disable caching. [Optional]
+      shortner:
+        The URL shortener instance to use. Defaults to None.
+        See shorten_url.py for an example shortener. [Optional]
+      base_url:
+        The base URL to use to contact the Twitter API.
+        Defaults to https://api.twitter.com/1. [Optional]
+      use_gzip_compression:
+        Set to True to enable gzip compression for any call
+        made to Twitter. Defaults to False. [Optional]
+      debugHTTP:
+        Set to True to enable debug output from urllib2 when performing
+        any HTTP requests. Defaults to False. [Optional]
+    '''
+    self.SetCache(cache)
+    self._urllib = urllib2
+    self._cache_timeout = Api.DEFAULT_CACHE_TIMEOUT
+    self._input_encoding = input_encoding
+    self._use_gzip = use_gzip_compression
+    self._debugHTTP = debugHTTP
+    self._oauth_consumer = None
+
+    self._InitializeRequestHeaders(request_headers)
+    self._InitializeUserAgent()
+    self._InitializeDefaultParameters()
+
+    if base_url is None:
+      self.base_url = 'https://api.twitter.com/1'
+    else:
+      self.base_url = base_url
+
+    if consumer_key is not None and (access_token_key is None or
+                                     access_token_secret is None):
+      print >> sys.stderr, 'Twitter now requires an oAuth Access Token for API calls.'
+      print >> sys.stderr, "If you're using this library from a command line utility, please"
+      print >> sys.stderr, 'run the included get_access_token.py tool to generate one.'
+
+      raise TwitterError('Twitter requires oAuth Access Token for all API access')
+
+    self.SetCredentials(consumer_key, consumer_secret, access_token_key, access_token_secret)
+
+  def SetCredentials(self,
+                     consumer_key,
+                     consumer_secret,
+                     access_token_key=None,
+                     access_token_secret=None):
+    '''Set the consumer_key and consumer_secret for this instance
+
+    Args:
+      consumer_key:
+        The consumer_key of the twitter account.
+      consumer_secret:
+        The consumer_secret for the twitter account.
+      access_token_key:
+        The oAuth access token key value you retrieved
+        from running get_access_token.py.
+ access_token_secret: + The oAuth access token's secret, also retrieved + from the get_access_token.py run. + ''' + self._consumer_key = consumer_key + self._consumer_secret = consumer_secret + self._access_token_key = access_token_key + self._access_token_secret = access_token_secret + self._oauth_consumer = None + + if consumer_key is not None and consumer_secret is not None and \ + access_token_key is not None and access_token_secret is not None: + self._signature_method_plaintext = oauth.SignatureMethod_PLAINTEXT() + self._signature_method_hmac_sha1 = oauth.SignatureMethod_HMAC_SHA1() + + self._oauth_token = oauth.Token(key=access_token_key, secret=access_token_secret) + self._oauth_consumer = oauth.Consumer(key=consumer_key, secret=consumer_secret) + + def ClearCredentials(self): + '''Clear the any credentials for this instance + ''' + self._consumer_key = None + self._consumer_secret = None + self._access_token_key = None + self._access_token_secret = None + self._oauth_consumer = None + + def GetPublicTimeline(self, + since_id=None, + include_rts=None, + include_entities=None): + '''Fetch the sequence of public twitter.Status message for all users. + + Args: + since_id: + Returns results with an ID greater than (that is, more recent + than) the specified ID. There are limits to the number of + Tweets which can be accessed through the API. If the limit of + Tweets has occured since the since_id, the since_id will be + forced to the oldest ID available. [Optional] + include_rts: + If True, the timeline will contain native retweets (if they + exist) in addition to the standard stream of tweets. [Optional] + include_entities: + If True, each tweet will include a node called "entities,". + This node offers a variety of metadata about the tweet in a + discreet structure, including: user_mentions, urls, and + hashtags. [Optional] + + Returns: + An sequence of twitter.Status instances, one for each message + ''' + parameters = {} + + if since_id: + parameters['since_id'] = since_id + if include_rts: + parameters['include_rts'] = 1 + if include_entities: + parameters['include_entities'] = 1 + + url = '%s/statuses/public_timeline.json' % self.base_url + json = self._FetchUrl(url, parameters=parameters) + data = self._ParseAndCheckTwitter(json) + return [Status.NewFromJsonDict(x) for x in data] + + def FilterPublicTimeline(self, + term, + since_id=None): + '''Filter the public twitter timeline by a given search term on + the local machine. + + Args: + term: + term to search by. + since_id: + Returns results with an ID greater than (that is, more recent + than) the specified ID. There are limits to the number of + Tweets which can be accessed through the API. If the limit of + Tweets has occured since the since_id, the since_id will be + forced to the oldest ID available. [Optional] + + Returns: + A sequence of twitter.Status instances, one for each message + containing the term + ''' + statuses = self.GetPublicTimeline(since_id) + results = [] + + for s in statuses: + if s.text.lower().find(term.lower()) != -1: + results.append(s) + + return results + + def GetSearch(self, + term=None, + geocode=None, + since_id=None, + per_page=15, + page=1, + lang="en", + show_user="true", + query_users=False): + '''Return twitter search results for a given term. + + Args: + term: + term to search by. Optional if you include geocode. + since_id: + Returns results with an ID greater than (that is, more recent + than) the specified ID. 
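`FilterPublicTimeline` does its matching locally: it pulls the public timeline once and keeps only the statuses whose text contains the term. A hedged usage sketch, relying only on the methods shown above:

api = twitter.Api()   # unauthenticated instance, as in the class docstring

# Client-side filtering: fetch the public timeline once, keep matching statuses.
for status in api.FilterPublicTimeline('python'):
  print status.text

# Roughly equivalent to:
matches = [s for s in api.GetPublicTimeline() if 'python' in s.text.lower()]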
There are limits to the number of + Tweets which can be accessed through the API. If the limit of + Tweets has occured since the since_id, the since_id will be + forced to the oldest ID available. [Optional] + geocode: + geolocation information in the form (latitude, longitude, radius) + [Optional] + per_page: + number of results to return. Default is 15 [Optional] + page: + Specifies the page of results to retrieve. + Note: there are pagination limits. [Optional] + lang: + language for results. Default is English [Optional] + show_user: + prefixes screen name in status + query_users: + If set to False, then all users only have screen_name and + profile_image_url available. + If set to True, all information of users are available, + but it uses lots of request quota, one per status. + + Returns: + A sequence of twitter.Status instances, one for each message containing + the term + ''' + # Build request parameters + parameters = {} + + if since_id: + parameters['since_id'] = since_id + + if term is None and geocode is None: + return [] + + if term is not None: + parameters['q'] = term + + if geocode is not None: + parameters['geocode'] = ','.join(map(str, geocode)) + + parameters['show_user'] = show_user + parameters['lang'] = lang + parameters['rpp'] = per_page + parameters['page'] = page + + # Make and send requests + url = 'http://search.twitter.com/search.json' + json = self._FetchUrl(url, parameters=parameters) + data = self._ParseAndCheckTwitter(json) + + results = [] + + for x in data['results']: + temp = Status.NewFromJsonDict(x) + + if query_users: + # Build user object with new request + temp.user = self.GetUser(urllib.quote(x['from_user'])) + else: + temp.user = User(screen_name=x['from_user'], profile_image_url=x['profile_image_url']) + + results.append(temp) + + # Return built list of statuses + return results # [Status.NewFromJsonDict(x) for x in data['results']] + + def GetTrendsCurrent(self, exclude=None): + '''Get the current top trending topics + + Args: + exclude: + Appends the exclude parameter as a request parameter. + Currently only exclude=hashtags is supported. [Optional] + + Returns: + A list with 10 entries. Each entry contains the twitter. + ''' + parameters = {} + if exclude: + parameters['exclude'] = exclude + url = '%s/trends/current.json' % self.base_url + json = self._FetchUrl(url, parameters=parameters) + data = self._ParseAndCheckTwitter(json) + + trends = [] + + for t in data['trends']: + for item in data['trends'][t]: + trends.append(Trend.NewFromJsonDict(item, timestamp = t)) + return trends + + def GetTrendsDaily(self, exclude=None, startdate=None): + '''Get the current top trending topics for each hour in a given day + + Args: + startdate: + The start date for the report. + Should be in the format YYYY-MM-DD. [Optional] + exclude: + Appends the exclude parameter as a request parameter. + Currently only exclude=hashtags is supported. [Optional] + + Returns: + A list with 24 entries. Each entry contains the twitter. + Trend elements that were trending at the corresponding hour of the day. 
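In `GetSearch`, the `geocode` tuple is simply comma-joined into the search API's `geocode` parameter, so a radius string such as '1mi' is passed through untouched. A sketch with invented coordinates:

api = twitter.Api()

# (latitude, longitude, radius) is serialized as "30.267,-97.743,1mi".
results = api.GetSearch(term='taco', geocode=(30.267, -97.743, '1mi'), per_page=5)
for status in results:
  print status.user.screen_name, status.text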
+ ''' + parameters = {} + if exclude: + parameters['exclude'] = exclude + if not startdate: + startdate = time.strftime('%Y-%m-%d', time.gmtime()) + parameters['date'] = startdate + url = '%s/trends/daily.json' % self.base_url + json = self._FetchUrl(url, parameters=parameters) + data = self._ParseAndCheckTwitter(json) + + trends = [] + + for i in xrange(24): + trends.append(None) + for t in data['trends']: + idx = int(time.strftime('%H', time.strptime(t, '%Y-%m-%d %H:%M'))) + trends[idx] = [Trend.NewFromJsonDict(x, timestamp = t) + for x in data['trends'][t]] + return trends + + def GetTrendsWeekly(self, exclude=None, startdate=None): + '''Get the top 30 trending topics for each day in a given week. + + Args: + startdate: + The start date for the report. + Should be in the format YYYY-MM-DD. [Optional] + exclude: + Appends the exclude parameter as a request parameter. + Currently only exclude=hashtags is supported. [Optional] + Returns: + A list with each entry contains the twitter. + Trend elements of trending topics for the corrsponding day of the week + ''' + parameters = {} + if exclude: + parameters['exclude'] = exclude + if not startdate: + startdate = time.strftime('%Y-%m-%d', time.gmtime()) + parameters['date'] = startdate + url = '%s/trends/weekly.json' % self.base_url + json = self._FetchUrl(url, parameters=parameters) + data = self._ParseAndCheckTwitter(json) + + trends = [] + + for i in xrange(7): + trends.append(None) + # use the epochs of the dates as keys for a dictionary + times = dict([(calendar.timegm(time.strptime(t, '%Y-%m-%d')),t) + for t in data['trends']]) + cnt = 0 + # create the resulting structure ordered by the epochs of the dates + for e in sorted(times.keys()): + trends[cnt] = [Trend.NewFromJsonDict(x, timestamp = times[e]) + for x in data['trends'][times[e]]] + cnt +=1 + return trends + + def GetFriendsTimeline(self, + user=None, + count=None, + page=None, + since_id=None, + retweets=None, + include_entities=None): + '''Fetch the sequence of twitter.Status messages for a user's friends + + The twitter.Api instance must be authenticated if the user is private. + + Args: + user: + Specifies the ID or screen name of the user for whom to return + the friends_timeline. If not specified then the authenticated + user set in the twitter.Api instance will be used. [Optional] + count: + Specifies the number of statuses to retrieve. May not be + greater than 100. [Optional] + page: + Specifies the page of results to retrieve. + Note: there are pagination limits. [Optional] + since_id: + Returns results with an ID greater than (that is, more recent + than) the specified ID. There are limits to the number of + Tweets which can be accessed through the API. If the limit of + Tweets has occured since the since_id, the since_id will be + forced to the oldest ID available. [Optional] + retweets: + If True, the timeline will contain native retweets. [Optional] + include_entities: + If True, each tweet will include a node called "entities,". + This node offers a variety of metadata about the tweet in a + discreet structure, including: user_mentions, urls, and + hashtags. 
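`GetTrendsDaily` returns a 24-slot list indexed by hour, and the slot is derived by parsing the timestamp keys of the trends payload. The indexing step in isolation, with an invented timestamp value:

import time

# Keys in data['trends'] look like '2011-03-14 13:00'; the hour becomes the list index.
t = '2011-03-14 13:00'
idx = int(time.strftime('%H', time.strptime(t, '%Y-%m-%d %H:%M')))
assert idx == 13

trends = [None] * 24
trends[idx] = ['Trend objects parsed for the 13:00 bucket would go here']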
[Optional] + + Returns: + A sequence of twitter.Status instances, one for each message + ''' + if not user and not self._oauth_consumer: + raise TwitterError("User must be specified if API is not authenticated.") + url = '%s/statuses/friends_timeline' % self.base_url + if user: + url = '%s/%s.json' % (url, user) + else: + url = '%s.json' % url + parameters = {} + if count is not None: + try: + if int(count) > 100: + raise TwitterError("'count' may not be greater than 100") + except ValueError: + raise TwitterError("'count' must be an integer") + parameters['count'] = count + if page is not None: + try: + parameters['page'] = int(page) + except ValueError: + raise TwitterError("'page' must be an integer") + if since_id: + parameters['since_id'] = since_id + if retweets: + parameters['include_rts'] = True + if include_entities: + parameters['include_entities'] = True + json = self._FetchUrl(url, parameters=parameters) + data = self._ParseAndCheckTwitter(json) + return [Status.NewFromJsonDict(x) for x in data] + + def GetUserTimeline(self, + id=None, + user_id=None, + screen_name=None, + since_id=None, + max_id=None, + count=None, + page=None, + include_rts=None, + include_entities=None): + '''Fetch the sequence of public Status messages for a single user. + + The twitter.Api instance must be authenticated if the user is private. + + Args: + id: + Specifies the ID or screen name of the user for whom to return + the user_timeline. [Optional] + user_id: + Specfies the ID of the user for whom to return the + user_timeline. Helpful for disambiguating when a valid user ID + is also a valid screen name. [Optional] + screen_name: + Specfies the screen name of the user for whom to return the + user_timeline. Helpful for disambiguating when a valid screen + name is also a user ID. [Optional] + since_id: + Returns results with an ID greater than (that is, more recent + than) the specified ID. There are limits to the number of + Tweets which can be accessed through the API. If the limit of + Tweets has occured since the since_id, the since_id will be + forced to the oldest ID available. [Optional] + max_id: + Returns only statuses with an ID less than (that is, older + than) or equal to the specified ID. [Optional] + count: + Specifies the number of statuses to retrieve. May not be + greater than 200. [Optional] + page: + Specifies the page of results to retrieve. + Note: there are pagination limits. [Optional] + include_rts: + If True, the timeline will contain native retweets (if they + exist) in addition to the standard stream of tweets. [Optional] + include_entities: + If True, each tweet will include a node called "entities,". + This node offers a variety of metadata about the tweet in a + discreet structure, including: user_mentions, urls, and + hashtags. 
[Optional] + + Returns: + A sequence of Status instances, one for each message up to count + ''' + parameters = {} + + if id: + url = '%s/statuses/user_timeline/%s.json' % (self.base_url, id) + elif user_id: + url = '%s/statuses/user_timeline.json?user_id=%d' % (self.base_url, user_id) + elif screen_name: + url = ('%s/statuses/user_timeline.json?screen_name=%s' % (self.base_url, + screen_name)) + elif not self._oauth_consumer: + raise TwitterError("User must be specified if API is not authenticated.") + else: + url = '%s/statuses/user_timeline.json' % self.base_url + + if since_id: + try: + parameters['since_id'] = long(since_id) + except: + raise TwitterError("since_id must be an integer") + + if max_id: + try: + parameters['max_id'] = long(max_id) + except: + raise TwitterError("max_id must be an integer") + + if count: + try: + parameters['count'] = int(count) + except: + raise TwitterError("count must be an integer") + + if page: + try: + parameters['page'] = int(page) + except: + raise TwitterError("page must be an integer") + + if include_rts: + parameters['include_rts'] = 1 + + if include_entities: + parameters['include_entities'] = 1 + + json = self._FetchUrl(url, parameters=parameters) + data = self._ParseAndCheckTwitter(json) + return [Status.NewFromJsonDict(x) for x in data] + + def GetStatus(self, id): + '''Returns a single status message. + + The twitter.Api instance must be authenticated if the + status message is private. + + Args: + id: + The numeric ID of the status you are trying to retrieve. + + Returns: + A twitter.Status instance representing that status message + ''' + try: + if id: + long(id) + except: + raise TwitterError("id must be an long integer") + url = '%s/statuses/show/%s.json' % (self.base_url, id) + json = self._FetchUrl(url) + data = self._ParseAndCheckTwitter(json) + return Status.NewFromJsonDict(data) + + def DestroyStatus(self, id): + '''Destroys the status specified by the required ID parameter. + + The twitter.Api instance must be authenticated and the + authenticating user must be the author of the specified status. + + Args: + id: + The numerical ID of the status you're trying to destroy. + + Returns: + A twitter.Status instance representing the destroyed status message + ''' + try: + if id: + long(id) + except: + raise TwitterError("id must be an integer") + url = '%s/statuses/destroy/%s.json' % (self.base_url, id) + json = self._FetchUrl(url, post_data={'id': id}) + data = self._ParseAndCheckTwitter(json) + return Status.NewFromJsonDict(data) + + def PostUpdate(self, status, in_reply_to_status_id=None): + '''Post a twitter status message from the authenticated user. + + The twitter.Api instance must be authenticated. + + Args: + status: + The message text to be posted. + Must be less than or equal to 140 characters. + in_reply_to_status_id: + The ID of an existing status that the status to be posted is + in reply to. This implicitly sets the in_reply_to_user_id + attribute of the resulting status to the user ID of the + message being replied to. Invalid/missing status IDs will be + ignored. [Optional] + Returns: + A twitter.Status instance representing the message posted. 
+ ''' + if not self._oauth_consumer: + raise TwitterError("The twitter.Api instance must be authenticated.") + + url = '%s/statuses/update.json' % self.base_url + + if isinstance(status, unicode) or self._input_encoding is None: + u_status = status + else: + u_status = unicode(status, self._input_encoding) + + if len(u_status) > CHARACTER_LIMIT: + raise TwitterError("Text must be less than or equal to %d characters. " + "Consider using PostUpdates." % CHARACTER_LIMIT) + + data = {'status': status} + if in_reply_to_status_id: + data['in_reply_to_status_id'] = in_reply_to_status_id + json = self._FetchUrl(url, post_data=data) + data = self._ParseAndCheckTwitter(json) + return Status.NewFromJsonDict(data) + + def PostUpdates(self, status, continuation=None, **kwargs): + '''Post one or more twitter status messages from the authenticated user. + + Unlike api.PostUpdate, this method will post multiple status updates + if the message is longer than 140 characters. + + The twitter.Api instance must be authenticated. + + Args: + status: + The message text to be posted. + May be longer than 140 characters. + continuation: + The character string, if any, to be appended to all but the + last message. Note that Twitter strips trailing '...' strings + from messages. Consider using the unicode \u2026 character + (horizontal ellipsis) instead. [Defaults to None] + **kwargs: + See api.PostUpdate for a list of accepted parameters. + + Returns: + A of list twitter.Status instance representing the messages posted. + ''' + results = list() + if continuation is None: + continuation = '' + line_length = CHARACTER_LIMIT - len(continuation) + lines = textwrap.wrap(status, line_length) + for line in lines[0:-1]: + results.append(self.PostUpdate(line + continuation, **kwargs)) + results.append(self.PostUpdate(lines[-1], **kwargs)) + return results + + def GetUserRetweets(self, count=None, since_id=None, max_id=None, include_entities=False): + '''Fetch the sequence of retweets made by a single user. + + The twitter.Api instance must be authenticated. + + Args: + count: + The number of status messages to retrieve. [Optional] + since_id: + Returns results with an ID greater than (that is, more recent + than) the specified ID. There are limits to the number of + Tweets which can be accessed through the API. If the limit of + Tweets has occured since the since_id, the since_id will be + forced to the oldest ID available. [Optional] + max_id: + Returns results with an ID less than (that is, older than) or + equal to the specified ID. [Optional] + include_entities: + If True, each tweet will include a node called "entities,". + This node offers a variety of metadata about the tweet in a + discreet structure, including: user_mentions, urls, and + hashtags. 
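`PostUpdates` splits a long message with `textwrap.wrap`, reserving room for the continuation string and appending it to every chunk except the last. The splitting step on its own, with an invented message and the 140-character limit the PostUpdate docstring describes:

import textwrap

CHARACTER_LIMIT = 140                 # per the PostUpdate docstring above
continuation = u'\u2026'              # horizontal ellipsis, as the docstring suggests
status = 'word ' * 60                 # something longer than 140 characters

line_length = CHARACTER_LIMIT - len(continuation)
lines = textwrap.wrap(status, line_length)

chunks = [line + continuation for line in lines[:-1]] + [lines[-1]]
assert all(len(c) <= CHARACTER_LIMIT for c in chunks)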
[Optional] + + Returns: + A sequence of twitter.Status instances, one for each message up to count + ''' + url = '%s/statuses/retweeted_by_me.json' % self.base_url + if not self._oauth_consumer: + raise TwitterError("The twitter.Api instance must be authenticated.") + parameters = {} + if count is not None: + try: + if int(count) > 100: + raise TwitterError("'count' may not be greater than 100") + except ValueError: + raise TwitterError("'count' must be an integer") + if count: + parameters['count'] = count + if since_id: + parameters['since_id'] = since_id + if include_entities: + parameters['include_entities'] = True + if max_id: + try: + parameters['max_id'] = long(max_id) + except: + raise TwitterError("max_id must be an integer") + json = self._FetchUrl(url, parameters=parameters) + data = self._ParseAndCheckTwitter(json) + return [Status.NewFromJsonDict(x) for x in data] + + def GetReplies(self, since=None, since_id=None, page=None): + '''Get a sequence of status messages representing the 20 most + recent replies (status updates prefixed with @twitterID) to the + authenticating user. + + Args: + since_id: + Returns results with an ID greater than (that is, more recent + than) the specified ID. There are limits to the number of + Tweets which can be accessed through the API. If the limit of + Tweets has occured since the since_id, the since_id will be + forced to the oldest ID available. [Optional] + page: + Specifies the page of results to retrieve. + Note: there are pagination limits. [Optional] + since: + + Returns: + A sequence of twitter.Status instances, one for each reply to the user. + ''' + url = '%s/statuses/replies.json' % self.base_url + if not self._oauth_consumer: + raise TwitterError("The twitter.Api instance must be authenticated.") + parameters = {} + if since: + parameters['since'] = since + if since_id: + parameters['since_id'] = since_id + if page: + parameters['page'] = page + json = self._FetchUrl(url, parameters=parameters) + data = self._ParseAndCheckTwitter(json) + return [Status.NewFromJsonDict(x) for x in data] + + def GetRetweets(self, statusid): + '''Returns up to 100 of the first retweets of the tweet identified + by statusid + + Args: + statusid: + The ID of the tweet for which retweets should be searched for + + Returns: + A list of twitter.Status instances, which are retweets of statusid + ''' + if not self._oauth_consumer: + raise TwitterError("The twitter.Api instsance must be authenticated.") + url = '%s/statuses/retweets/%s.json?include_entities=true&include_rts=true' % (self.base_url, statusid) + parameters = {} + json = self._FetchUrl(url, parameters=parameters) + data = self._ParseAndCheckTwitter(json) + return [Status.NewFromJsonDict(s) for s in data] + + def GetFriends(self, user=None, cursor=-1): + '''Fetch the sequence of twitter.User instances, one for each friend. + + The twitter.Api instance must be authenticated. + + Args: + user: + The twitter name or id of the user whose friends you are fetching. + If not specified, defaults to the authenticated user. 
[Optional] + + Returns: + A sequence of twitter.User instances, one for each friend + ''' + if not user and not self._oauth_consumer: + raise TwitterError("twitter.Api instance must be authenticated") + if user: + url = '%s/statuses/friends/%s.json' % (self.base_url, user) + else: + url = '%s/statuses/friends.json' % self.base_url + parameters = {} + parameters['cursor'] = cursor + json = self._FetchUrl(url, parameters=parameters) + data = self._ParseAndCheckTwitter(json) + return [User.NewFromJsonDict(x) for x in data['users']] + + def GetFriendIDs(self, user=None, cursor=-1): + '''Returns a list of twitter user id's for every person + the specified user is following. + + Args: + user: + The id or screen_name of the user to retrieve the id list for + [Optional] + + Returns: + A list of integers, one for each user id. + ''' + if not user and not self._oauth_consumer: + raise TwitterError("twitter.Api instance must be authenticated") + if user: + url = '%s/friends/ids/%s.json' % (self.base_url, user) + else: + url = '%s/friends/ids.json' % self.base_url + parameters = {} + parameters['cursor'] = cursor + json = self._FetchUrl(url, parameters=parameters) + data = self._ParseAndCheckTwitter(json) + return data + + def GetFollowerIDs(self, userid=None, cursor=-1): + '''Fetch the sequence of twitter.User instances, one for each follower + + The twitter.Api instance must be authenticated. + + Returns: + A sequence of twitter.User instances, one for each follower + ''' + url = 'http://twitter.com/followers/ids.json' + parameters = {} + parameters['cursor'] = cursor + if userid: + parameters['user_id'] = userid + json = self._FetchUrl(url, parameters=parameters) + data = self._ParseAndCheckTwitter(json) + return data + + def GetFollowers(self, user=None, page=None): + '''Fetch the sequence of twitter.User instances, one for each follower + + The twitter.Api instance must be authenticated. + + Args: + page: + Specifies the page of results to retrieve. + Note: there are pagination limits. [Optional] + + Returns: + A sequence of twitter.User instances, one for each follower + ''' + if not self._oauth_consumer: + raise TwitterError("twitter.Api instance must be authenticated") + #url = '%s/statuses/followers.json' % self.base_url + if user: + url = '%s/statuses/followers/%s.json' % (self.base_url, user) + else: + url = '%s/statuses/followers.json' % self.base_url + parameters = {} + if page: + parameters['page'] = page + json = self._FetchUrl(url, parameters=parameters) + data = self._ParseAndCheckTwitter(json) + return [User.NewFromJsonDict(x) for x in data] + + def GetFeatured(self): + '''Fetch the sequence of twitter.User instances featured on twitter.com + + The twitter.Api instance must be authenticated. + + Returns: + A sequence of twitter.User instances + ''' + url = '%s/statuses/featured.json' % self.base_url + json = self._FetchUrl(url) + data = self._ParseAndCheckTwitter(json) + return [User.NewFromJsonDict(x) for x in data] + + def UsersLookup(self, user_id=None, screen_name=None, users=None): + '''Fetch extended information for the specified users. + + Users may be specified either as lists of either user_ids, + screen_names, or twitter.User objects. The list of users that + are queried is the union of all specified parameters. + + The twitter.Api instance must be authenticated. + + Args: + user_id: + A list of user_ids to retrieve extended information. + [Optional] + screen_name: + A list of screen_names to retrieve extended information. 
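The ID-list methods above return whatever decoded JSON the endpoint produced, not twitter.User instances. A usage sketch for an authenticated instance, with a placeholder screen name:

# api is an authenticated twitter.Api instance.
friend_ids = api.GetFriendIDs('dewitt')      # decoded JSON, returned as-is
follower_ids = api.GetFollowerIDs()          # same, for the authenticated user

friends = api.GetFriends()                   # twitter.User objects, by contrast
print [u.screen_name for u in friends]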
+ [Optional] + users: + A list of twitter.User objects to retrieve extended information. + [Optional] + + Returns: + A list of twitter.User objects for the requested users + ''' + + if not self._oauth_consumer: + raise TwitterError("The twitter.Api instance must be authenticated.") + if not user_id and not screen_name and not users: + raise TwitterError("Specify at least on of user_id, screen_name, or users.") + url = '%s/users/lookup.json' % self.base_url + parameters = {} + uids = list() + if user_id: + uids.extend(user_id) + if users: + uids.extend([u.id for u in users]) + if len(uids): + parameters['user_id'] = ','.join(["%s" % u for u in uids]) + if screen_name: + parameters['screen_name'] = ','.join(screen_name) + json = self._FetchUrl(url, parameters=parameters) + data = self._ParseAndCheckTwitter(json) + return [User.NewFromJsonDict(u) for u in data] + + def GetUser(self, user): + '''Returns a single user. + + The twitter.Api instance must be authenticated. + + Args: + user: The twitter name or id of the user to retrieve. + + Returns: + A twitter.User instance representing that user + ''' + url = '%s/users/show/%s.json' % (self.base_url, user) + json = self._FetchUrl(url) + data = self._ParseAndCheckTwitter(json) + return User.NewFromJsonDict(data) + + def GetDirectMessages(self, since=None, since_id=None, page=None): + '''Returns a list of the direct messages sent to the authenticating user. + + The twitter.Api instance must be authenticated. + + Args: + since: + Narrows the returned results to just those statuses created + after the specified HTTP-formatted date. [Optional] + since_id: + Returns results with an ID greater than (that is, more recent + than) the specified ID. There are limits to the number of + Tweets which can be accessed through the API. If the limit of + Tweets has occured since the since_id, the since_id will be + forced to the oldest ID available. [Optional] + page: + Specifies the page of results to retrieve. + Note: there are pagination limits. [Optional] + + Returns: + A sequence of twitter.DirectMessage instances + ''' + url = '%s/direct_messages.json' % self.base_url + if not self._oauth_consumer: + raise TwitterError("The twitter.Api instance must be authenticated.") + parameters = {} + if since: + parameters['since'] = since + if since_id: + parameters['since_id'] = since_id + if page: + parameters['page'] = page + json = self._FetchUrl(url, parameters=parameters) + data = self._ParseAndCheckTwitter(json) + return [DirectMessage.NewFromJsonDict(x) for x in data] + + def PostDirectMessage(self, user, text): + '''Post a twitter direct message from the authenticated user + + The twitter.Api instance must be authenticated. + + Args: + user: The ID or screen name of the recipient user. + text: The message text to be posted. Must be less than 140 characters. + + Returns: + A twitter.DirectMessage instance representing the message posted + ''' + if not self._oauth_consumer: + raise TwitterError("The twitter.Api instance must be authenticated.") + url = '%s/direct_messages/new.json' % self.base_url + data = {'text': text, 'user': user} + json = self._FetchUrl(url, post_data=data) + data = self._ParseAndCheckTwitter(json) + return DirectMessage.NewFromJsonDict(data) + + def DestroyDirectMessage(self, id): + '''Destroys the direct message specified in the required ID parameter. + + The twitter.Api instance must be authenticated, and the + authenticating user must be the recipient of the specified direct + message. 
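# --- Usage sketch: bulk lookup and direct messages (editor's illustration) ----
# Hedged example for UsersLookup/GetUser/GetDirectMessages above; `api` is an
# authenticated twitter.Api instance and the screen names are placeholders.
users = api.UsersLookup(screen_name=['twitterapi', 'twitter'])
for u in users:
    print u.GetId(), u.GetName(), u.GetFollowersCount()

one = api.GetUser('twitterapi')         # single-user lookup by name or id
dms = api.GetDirectMessages()           # inbox of the authenticated user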
+ + Args: + id: The id of the direct message to be destroyed + + Returns: + A twitter.DirectMessage instance representing the message destroyed + ''' + url = '%s/direct_messages/destroy/%s.json' % (self.base_url, id) + json = self._FetchUrl(url, post_data={'id': id}) + data = self._ParseAndCheckTwitter(json) + return DirectMessage.NewFromJsonDict(data) + + def CreateFriendship(self, user): + '''Befriends the user specified in the user parameter as the authenticating user. + + The twitter.Api instance must be authenticated. + + Args: + The ID or screen name of the user to befriend. + Returns: + A twitter.User instance representing the befriended user. + ''' + url = '%s/friendships/create/%s.json' % (self.base_url, user) + json = self._FetchUrl(url, post_data={'user': user}) + data = self._ParseAndCheckTwitter(json) + return User.NewFromJsonDict(data) + + def DestroyFriendship(self, user): + '''Discontinues friendship with the user specified in the user parameter. + + The twitter.Api instance must be authenticated. + + Args: + The ID or screen name of the user with whom to discontinue friendship. + Returns: + A twitter.User instance representing the discontinued friend. + ''' + url = '%s/friendships/destroy/%s.json' % (self.base_url, user) + json = self._FetchUrl(url, post_data={'user': user}) + data = self._ParseAndCheckTwitter(json) + return User.NewFromJsonDict(data) + + def CreateFavorite(self, status): + '''Favorites the status specified in the status parameter as the authenticating user. + Returns the favorite status when successful. + + The twitter.Api instance must be authenticated. + + Args: + The twitter.Status instance to mark as a favorite. + Returns: + A twitter.Status instance representing the newly-marked favorite. + ''' + url = '%s/favorites/create/%s.json' % (self.base_url, status.id) + json = self._FetchUrl(url, post_data={'id': status.id}) + data = self._ParseAndCheckTwitter(json) + return Status.NewFromJsonDict(data) + + def DestroyFavorite(self, status): + '''Un-favorites the status specified in the ID parameter as the authenticating user. + Returns the un-favorited status in the requested format when successful. + + The twitter.Api instance must be authenticated. + + Args: + The twitter.Status to unmark as a favorite. + Returns: + A twitter.Status instance representing the newly-unmarked favorite. + ''' + url = '%s/favorites/destroy/%s.json' % (self.base_url, status.id) + json = self._FetchUrl(url, post_data={'id': status.id}) + data = self._ParseAndCheckTwitter(json) + return Status.NewFromJsonDict(data) + + def GetFavorites(self, + user=None, + page=None): + '''Return a list of Status objects representing favorited tweets. + By default, returns the (up to) 20 most recent tweets for the + authenticated user. + + Args: + user: + The twitter name or id of the user whose favorites you are fetching. + If not specified, defaults to the authenticated user. [Optional] + page: + Specifies the page of results to retrieve. + Note: there are pagination limits. 
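# --- Usage sketch: friendships and favorites (editor's illustration) ----------
# Hedged example for the Create*/Destroy* pairs above; `api` is an
# authenticated twitter.Api instance, 'twitterapi' is a placeholder, and
# GetUserTimeline is defined earlier in this module (not shown here).
friend = api.CreateFriendship('twitterapi')                # follow
tweets = api.GetUserTimeline(id=friend.GetId(), count=1)
if tweets:
    api.CreateFavorite(tweets[0])      # takes a twitter.Status instance
    api.DestroyFavorite(tweets[0])
api.DestroyFriendship('twitterapi')                        # unfollow again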
[Optional] + ''' + parameters = {} + + if page: + parameters['page'] = page + + if user: + url = '%s/favorites/%s.json' % (self.base_url, user) + elif not user and not self._oauth_consumer: + raise TwitterError("User must be specified if API is not authenticated.") + else: + url = '%s/favorites.json' % self.base_url + + json = self._FetchUrl(url, parameters=parameters) + data = self._ParseAndCheckTwitter(json) + return [Status.NewFromJsonDict(x) for x in data] + + def GetMentions(self, + since_id=None, + max_id=None, + page=None): + '''Returns the 20 most recent mentions (status containing @twitterID) + for the authenticating user. + + Args: + since_id: + Returns results with an ID greater than (that is, more recent + than) the specified ID. There are limits to the number of + Tweets which can be accessed through the API. If the limit of + Tweets has occured since the since_id, the since_id will be + forced to the oldest ID available. [Optional] + max_id: + Returns only statuses with an ID less than + (that is, older than) the specified ID. [Optional] + page: + Specifies the page of results to retrieve. + Note: there are pagination limits. [Optional] + + Returns: + A sequence of twitter.Status instances, one for each mention of the user. + ''' + + url = '%s/statuses/mentions.json' % self.base_url + + if not self._oauth_consumer: + raise TwitterError("The twitter.Api instance must be authenticated.") + + parameters = {} + + if since_id: + parameters['since_id'] = since_id + if max_id: + parameters['max_id'] = max_id + if page: + parameters['page'] = page + + json = self._FetchUrl(url, parameters=parameters) + data = self._ParseAndCheckTwitter(json) + return [Status.NewFromJsonDict(x) for x in data] + + def CreateList(self, user, name, mode=None, description=None): + '''Creates a new list with the give name + + The twitter.Api instance must be authenticated. + + Args: + user: + Twitter name to create the list for + name: + New name for the list + mode: + 'public' or 'private'. + Defaults to 'public'. [Optional] + description: + Description of the list. [Optional] + + Returns: + A twitter.List instance representing the new list + ''' + url = '%s/%s/lists.json' % (self.base_url, user) + parameters = {'name': name} + if mode is not None: + parameters['mode'] = mode + if description is not None: + parameters['description'] = description + json = self._FetchUrl(url, post_data=parameters) + data = self._ParseAndCheckTwitter(json) + return List.NewFromJsonDict(data) + + def DestroyList(self, user, id): + '''Destroys the list from the given user + + The twitter.Api instance must be authenticated. + + Args: + user: + The user to remove the list from. + id: + The slug or id of the list to remove. + Returns: + A twitter.List instance representing the removed list. + ''' + url = '%s/%s/lists/%s.json' % (self.base_url, user, id) + json = self._FetchUrl(url, post_data={'_method': 'DELETE'}) + data = self._ParseAndCheckTwitter(json) + return List.NewFromJsonDict(data) + + def CreateSubscription(self, owner, list): + '''Creates a subscription to a list by the authenticated user + + The twitter.Api instance must be authenticated. + + Args: + owner: + User name or id of the owner of the list being subscribed to. 
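# --- Usage sketch: favorites, mentions and lists (editor's illustration) ------
# Hedged example for the methods above; `api` is an authenticated twitter.Api
# instance and 'myname' stands in for the authenticated screen name. The
# new_list.id attribute is assumed from twitter.List (DestroyList accepts a
# slug or an id).
faves = api.GetFavorites()                     # up to 20 most recent favorites
mentions = api.GetMentions()                   # 20 most recent @-mentions

new_list = api.CreateList('myname', 'geo-sources', mode='private',
                          description='accounts that tweet coordinates')
api.DestroyList('myname', new_list.id)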
+ list: + The slug or list id to subscribe the user to + + Returns: + A twitter.List instance representing the list subscribed to + ''' + url = '%s/%s/%s/subscribers.json' % (self.base_url, owner, list) + json = self._FetchUrl(url, post_data={'list_id': list}) + data = self._ParseAndCheckTwitter(json) + return List.NewFromJsonDict(data) + + def DestroySubscription(self, owner, list): + '''Destroys the subscription to a list for the authenticated user + + The twitter.Api instance must be authenticated. + + Args: + owner: + The user id or screen name of the user that owns the + list that is to be unsubscribed from + list: + The slug or list id of the list to unsubscribe from + + Returns: + A twitter.List instance representing the removed list. + ''' + url = '%s/%s/%s/subscribers.json' % (self.base_url, owner, list) + json = self._FetchUrl(url, post_data={'_method': 'DELETE', 'list_id': list}) + data = self._ParseAndCheckTwitter(json) + return List.NewFromJsonDict(data) + + def GetSubscriptions(self, user, cursor=-1): + '''Fetch the sequence of Lists that the given user is subscribed to + + The twitter.Api instance must be authenticated. + + Args: + user: + The twitter name or id of the user + cursor: + "page" value that Twitter will use to start building the + list sequence from. -1 to start at the beginning. + Twitter will return in the result the values for next_cursor + and previous_cursor. [Optional] + + Returns: + A sequence of twitter.List instances, one for each list + ''' + if not self._oauth_consumer: + raise TwitterError("twitter.Api instance must be authenticated") + + url = '%s/%s/lists/subscriptions.json' % (self.base_url, user) + parameters = {} + parameters['cursor'] = cursor + + json = self._FetchUrl(url, parameters=parameters) + data = self._ParseAndCheckTwitter(json) + return [List.NewFromJsonDict(x) for x in data['lists']] + + def GetLists(self, user, cursor=-1): + '''Fetch the sequence of lists for a user. + + The twitter.Api instance must be authenticated. + + Args: + user: + The twitter name or id of the user whose friends you are fetching. + If the passed in user is the same as the authenticated user + then you will also receive private list data. + cursor: + "page" value that Twitter will use to start building the + list sequence from. -1 to start at the beginning. + Twitter will return in the result the values for next_cursor + and previous_cursor. [Optional] + + Returns: + A sequence of twitter.List instances, one for each list + ''' + if not self._oauth_consumer: + raise TwitterError("twitter.Api instance must be authenticated") + + url = '%s/%s/lists.json' % (self.base_url, user) + parameters = {} + parameters['cursor'] = cursor + + json = self._FetchUrl(url, parameters=parameters) + data = self._ParseAndCheckTwitter(json) + return [List.NewFromJsonDict(x) for x in data['lists']] + + def GetUserByEmail(self, email): + '''Returns a single user by email address. + + Args: + email: + The email of the user to retrieve. + + Returns: + A twitter.User instance representing that user + ''' + url = '%s/users/show.json?email=%s' % (self.base_url, email) + json = self._FetchUrl(url) + data = self._ParseAndCheckTwitter(json) + return User.NewFromJsonDict(data) + + def VerifyCredentials(self): + '''Returns a twitter.User instance if the authenticating user is valid. + + Returns: + A twitter.User instance representing that user if the + credentials are valid, None otherwise. 
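# --- Usage sketch: list subscriptions and lookups (editor's illustration) -----
# Hedged example for the subscription/list getters above; `api` is an
# authenticated twitter.Api instance, the owner names and slugs are placeholders.
api.CreateSubscription('twitterapi', 'team')      # owner, then list slug or id
subs = api.GetSubscriptions('myname')             # lists I am subscribed to
mine = api.GetLists('myname')                     # lists I own (incl. private)
print "%d subscriptions, %d own lists" % (len(subs), len(mine))
api.DestroySubscription('twitterapi', 'team')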
+ ''' + if not self._oauth_consumer: + raise TwitterError("Api instance must first be given user credentials.") + url = '%s/account/verify_credentials.json' % self.base_url + try: + json = self._FetchUrl(url, no_cache=True) + except urllib2.HTTPError, http_error: + if http_error.code == httplib.UNAUTHORIZED: + return None + else: + raise http_error + data = self._ParseAndCheckTwitter(json) + return User.NewFromJsonDict(data) + + def SetCache(self, cache): + '''Override the default cache. Set to None to prevent caching. + + Args: + cache: + An instance that supports the same API as the twitter._FileCache + ''' + if cache == DEFAULT_CACHE: + self._cache = _FileCache() + else: + self._cache = cache + + def SetUrllib(self, urllib): + '''Override the default urllib implementation. + + Args: + urllib: + An instance that supports the same API as the urllib2 module + ''' + self._urllib = urllib + + def SetCacheTimeout(self, cache_timeout): + '''Override the default cache timeout. + + Args: + cache_timeout: + Time, in seconds, that responses should be reused. + ''' + self._cache_timeout = cache_timeout + + def SetUserAgent(self, user_agent): + '''Override the default user agent + + Args: + user_agent: + A string that should be send to the server as the User-agent + ''' + self._request_headers['User-Agent'] = user_agent + + def SetXTwitterHeaders(self, client, url, version): + '''Set the X-Twitter HTTP headers that will be sent to the server. + + Args: + client: + The client name as a string. Will be sent to the server as + the 'X-Twitter-Client' header. + url: + The URL of the meta.xml as a string. Will be sent to the server + as the 'X-Twitter-Client-URL' header. + version: + The client version as a string. Will be sent to the server + as the 'X-Twitter-Client-Version' header. + ''' + self._request_headers['X-Twitter-Client'] = client + self._request_headers['X-Twitter-Client-URL'] = url + self._request_headers['X-Twitter-Client-Version'] = version + + def SetSource(self, source): + '''Suggest the "from source" value to be displayed on the Twitter web site. + + The value of the 'source' parameter must be first recognized by + the Twitter server. New source values are authorized on a case by + case basis by the Twitter development team. + + Args: + source: + The source name as a string. Will be sent to the server as + the 'source' parameter. + ''' + self._default_params['source'] = source + + def GetRateLimitStatus(self): + '''Fetch the rate limit status for the currently authorized user. + + Returns: + A dictionary containing the time the limit will reset (reset_time), + the number of remaining hits allowed before the reset (remaining_hits), + the number of hits allowed in a 60-minute period (hourly_limit), and + the time of the reset in seconds since The Epoch (reset_time_in_seconds). + ''' + url = '%s/account/rate_limit_status.json' % self.base_url + json = self._FetchUrl(url, no_cache=True) + data = self._ParseAndCheckTwitter(json) + return data + + def MaximumHitFrequency(self): + '''Determines the minimum number of seconds that a program must wait + before hitting the server again without exceeding the rate_limit + imposed for the currently authenticated user. + + Returns: + The minimum second interval that a program must use so as to not + exceed the rate_limit imposed for the user. 
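# --- Usage sketch: credentials, cache tuning, rate limits (editor's sketch) ---
# Hedged example for VerifyCredentials, the Set* helpers and GetRateLimitStatus
# above; `api` is an authenticated twitter.Api instance, the user-agent string
# is a placeholder.
me = api.VerifyCredentials()
if me is None:
    raise SystemExit("OAuth keys were rejected")

api.SetCacheTimeout(120)                               # reuse cached GETs for 2 minutes
api.SetUserAgent('geo-crawler/0.1 (python-twitter)')   # placeholder UA string

limits = api.GetRateLimitStatus()
print "remaining hits this hour:", limits['remaining_hits']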
+ ''' + rate_status = self.GetRateLimitStatus() + reset_time = rate_status.get('reset_time', None) + limit = rate_status.get('remaining_hits', None) + + if reset_time: + # put the reset time into a datetime object + reset = datetime.datetime(*rfc822.parsedate(reset_time)[:7]) + + # find the difference in time between now and the reset time + 1 hour + delta = reset + datetime.timedelta(hours=1) - datetime.datetime.utcnow() + + if not limit: + return int(delta.seconds) + + # determine the minimum number of seconds allowed as a regular interval + max_frequency = int(delta.seconds / limit) + 1 + + # return the number of seconds + return max_frequency + + return 60 + + def _BuildUrl(self, url, path_elements=None, extra_params=None): + # Break url into consituent parts + (scheme, netloc, path, params, query, fragment) = urlparse.urlparse(url) + + # Add any additional path elements to the path + if path_elements: + # Filter out the path elements that have a value of None + p = [i for i in path_elements if i] + if not path.endswith('/'): + path += '/' + path += '/'.join(p) + + # Add any additional query parameters to the query string + if extra_params and len(extra_params) > 0: + extra_query = self._EncodeParameters(extra_params) + # Add it to the existing query + if query: + query += '&' + extra_query + else: + query = extra_query + + # Return the rebuilt URL + return urlparse.urlunparse((scheme, netloc, path, params, query, fragment)) + + def _InitializeRequestHeaders(self, request_headers): + if request_headers: + self._request_headers = request_headers + else: + self._request_headers = {} + + def _InitializeUserAgent(self): + user_agent = 'Python-urllib/%s (python-twitter/%s)' % \ + (self._urllib.__version__, __version__) + self.SetUserAgent(user_agent) + + def _InitializeDefaultParameters(self): + self._default_params = {} + + def _DecompressGzippedResponse(self, response): + raw_data = response.read() + if response.headers.get('content-encoding', None) == 'gzip': + url_data = gzip.GzipFile(fileobj=StringIO.StringIO(raw_data)).read() + else: + url_data = raw_data + return url_data + + def _Encode(self, s): + if self._input_encoding: + return unicode(s, self._input_encoding).encode('utf-8') + else: + return unicode(s).encode('utf-8') + + def _EncodeParameters(self, parameters): + '''Return a string in key=value&key=value form + + Values of None are not included in the output string. + + Args: + parameters: + A dict of (key, value) tuples, where value is encoded as + specified by self._encoding + + Returns: + A URL-encoded string in "key=value&key=value" form + ''' + if parameters is None: + return None + else: + return urllib.urlencode(dict([(k, self._Encode(v)) for k, v in parameters.items() if v is not None])) + + def _EncodePostData(self, post_data): + '''Return a string in key=value&key=value form + + Values are assumed to be encoded in the format specified by self._encoding, + and are subsequently URL encoded. + + Args: + post_data: + A dict of (key, value) tuples, where value is encoded as + specified by self._encoding + + Returns: + A URL-encoded string in "key=value&key=value" form + ''' + if post_data is None: + return None + else: + return urllib.urlencode(dict([(k, self._Encode(v)) for k, v in post_data.items()])) + + def _ParseAndCheckTwitter(self, json): + """Try and parse the JSON returned from Twitter and return + an empty dictionary if there is any error. 
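# --- Usage sketch: spacing calls with MaximumHitFrequency (editor's sketch) ---
# Hedged example of the helper above: sleep the computed interval between
# timeline fetches so the hourly limit is never exhausted. `api` is an
# authenticated twitter.Api instance; GetUserTimeline is defined earlier in
# this module and the screen names are placeholders.
import time

for name in ['user_a', 'user_b', 'user_c']:
    tweets = api.GetUserTimeline(id=name, count=20)
    print "%s: %d tweets" % (name, len(tweets))
    time.sleep(api.MaximumHitFrequency())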
This is a purely + defensive check because during some Twitter network outages + it will return an HTML failwhale page.""" + try: + data = simplejson.loads(json) + self._CheckForTwitterError(data) + except ValueError: + if "Twitter / Over capacity" in json: + raise TwitterError("Capacity Error") + if "Twitter / Error" in json: + raise TwitterError("Technical Error") + raise TwitterError("json decoding") + + return data + + def _CheckForTwitterError(self, data): + """Raises a TwitterError if twitter returns an error message. + + Args: + data: + A python dict created from the Twitter json response + + Raises: + TwitterError wrapping the twitter error message if one exists. + """ + # Twitter errors are relatively unlikely, so it is faster + # to check first, rather than try and catch the exception + if 'error' in data: + raise TwitterError(data['error']) + + def _FetchUrl(self, + url, + post_data=None, + parameters=None, + no_cache=None, + use_gzip_compression=None): + '''Fetch a URL, optionally caching for a specified time. + + Args: + url: + The URL to retrieve + post_data: + A dict of (str, unicode) key/value pairs. + If set, POST will be used. + parameters: + A dict whose key/value pairs should encoded and added + to the query string. [Optional] + no_cache: + If true, overrides the cache on the current request + use_gzip_compression: + If True, tells the server to gzip-compress the response. + It does not apply to POST requests. + Defaults to None, which will get the value to use from + the instance variable self._use_gzip [Optional] + + Returns: + A string containing the body of the response. + ''' + # Build the extra parameters dict + extra_params = {} + if self._default_params: + extra_params.update(self._default_params) + if parameters: + extra_params.update(parameters) + + if post_data: + http_method = "POST" + else: + http_method = "GET" + + if self._debugHTTP: + _debug = 1 + else: + _debug = 0 + + http_handler = self._urllib.HTTPHandler(debuglevel=_debug) + https_handler = self._urllib.HTTPSHandler(debuglevel=_debug) + + opener = self._urllib.OpenerDirector() + opener.add_handler(http_handler) + opener.add_handler(https_handler) + + if use_gzip_compression is None: + use_gzip = self._use_gzip + else: + use_gzip = use_gzip_compression + + # Set up compression + if use_gzip and not post_data: + opener.addheaders.append(('Accept-Encoding', 'gzip')) + + if self._oauth_consumer is not None: + if post_data and http_method == "POST": + parameters = post_data.copy() + + req = oauth.Request.from_consumer_and_token(self._oauth_consumer, + token=self._oauth_token, + http_method=http_method, + http_url=url, parameters=parameters) + + req.sign_request(self._signature_method_hmac_sha1, self._oauth_consumer, self._oauth_token) + + headers = req.to_header() + + if http_method == "POST": + encoded_post_data = req.to_postdata() + else: + encoded_post_data = None + url = req.to_url() + else: + url = self._BuildUrl(url, extra_params=extra_params) + encoded_post_data = self._EncodePostData(post_data) + + # Open and return the URL immediately if we're not going to cache + if encoded_post_data or no_cache or not self._cache or not self._cache_timeout: + response = opener.open(url, encoded_post_data) + url_data = self._DecompressGzippedResponse(response) + opener.close() + else: + # Unique keys are a combination of the url and the oAuth Consumer Key + if self._consumer_key: + key = self._consumer_key + ':' + url + else: + key = url + + # See if it has been cached before + last_cached = 
self._cache.GetCachedTime(key) + + # If the cached version is outdated then fetch another and store it + if not last_cached or time.time() >= last_cached + self._cache_timeout: + try: + response = opener.open(url, encoded_post_data) + url_data = self._DecompressGzippedResponse(response) + self._cache.Set(key, url_data) + except urllib2.HTTPError, e: + print e + opener.close() + else: + url_data = self._cache.Get(key) + + # Always return the latest version + return url_data + +class _FileCacheError(Exception): + '''Base exception class for FileCache related errors''' + +class _FileCache(object): + + DEPTH = 3 + + def __init__(self,root_directory=None): + self._InitializeRootDirectory(root_directory) + + def Get(self,key): + path = self._GetPath(key) + if os.path.exists(path): + return open(path).read() + else: + return None + + def Set(self,key,data): + path = self._GetPath(key) + directory = os.path.dirname(path) + if not os.path.exists(directory): + os.makedirs(directory) + if not os.path.isdir(directory): + raise _FileCacheError('%s exists but is not a directory' % directory) + temp_fd, temp_path = tempfile.mkstemp() + temp_fp = os.fdopen(temp_fd, 'w') + temp_fp.write(data) + temp_fp.close() + if not path.startswith(self._root_directory): + raise _FileCacheError('%s does not appear to live under %s' % + (path, self._root_directory)) + if os.path.exists(path): + os.remove(path) + os.rename(temp_path, path) + + def Remove(self,key): + path = self._GetPath(key) + if not path.startswith(self._root_directory): + raise _FileCacheError('%s does not appear to live under %s' % + (path, self._root_directory )) + if os.path.exists(path): + os.remove(path) + + def GetCachedTime(self,key): + path = self._GetPath(key) + if os.path.exists(path): + return os.path.getmtime(path) + else: + return None + + def _GetUsername(self): + '''Attempt to find the username in a cross-platform fashion.''' + try: + return os.getenv('USER') or \ + os.getenv('LOGNAME') or \ + os.getenv('USERNAME') or \ + os.getlogin() or \ + 'nobody' + except (IOError, OSError), e: + return 'nobody' + + def _GetTmpCachePath(self): + username = self._GetUsername() + cache_directory = 'python.cache_' + username + return os.path.join(tempfile.gettempdir(), cache_directory) + + def _InitializeRootDirectory(self, root_directory): + if not root_directory: + root_directory = self._GetTmpCachePath() + root_directory = os.path.abspath(root_directory) + if not os.path.exists(root_directory): + os.mkdir(root_directory) + if not os.path.isdir(root_directory): + raise _FileCacheError('%s exists but is not a directory' % + root_directory) + self._root_directory = root_directory + + def _GetPath(self,key): + try: + hashed_key = md5(key).hexdigest() + except TypeError: + hashed_key = md5.new(key).hexdigest() + + return os.path.join(self._root_directory, + self._GetPrefix(hashed_key), + hashed_key) + + def _GetPrefix(self,hashed_key): + return os.path.sep.join(hashed_key[0:_FileCache.DEPTH]) diff --git a/src/main/python/twitter-graphs/twitterRelationGraphs.py b/src/main/python/twitter-graphs/twitterRelationGraphs.py new file mode 100644 index 0000000..3204c29 --- /dev/null +++ b/src/main/python/twitter-graphs/twitterRelationGraphs.py @@ -0,0 +1,503 @@ +#=============================================================================== +# !/usr/bin/env python +# +# twitterRelationGraphs.py +# +# This script calls the twitter API to obtain the followers and friends graph for +# a given set of users. It also obtains the tweets of these users. 
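# --- Sketch: the response cache behind _FetchUrl in the module above ----------
# Hedged illustration of the _FileCache class that closes the python-twitter
# module above: _FetchUrl keys entries by "<consumer key>:<url>", stores them
# under a per-user temp directory, and treats them as fresh while
# time.time() < GetCachedTime(key) + cache_timeout. The key and body here are
# placeholders, and _FileCache is a private helper used only for illustration.
import time
from twitter import _FileCache

cache = _FileCache()
key = 'CONSUMER_KEY:http://twitter.com/followers/ids.json'
cache.Set(key, '{"cached": "body"}')
print "age: %.1fs, body: %s" % (time.time() - cache.GetCachedTime(key),
                                cache.Get(key))
cache.Remove(key)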
+# +# Dependency: +# 1. Python 2.7 +# 2. Python-Twitter-0.8.2 http://code.google.com/p/python-twitter/ +# - Refer to their documentation for installation of the dependencies +# - Currently using a modified version of library at: +# - Rebuild and reinstall with the mod ver +# +# Copyright (c) 2011 Andy Luong. +#=============================================================================== + +import twitter, sys +from operator import itemgetter +import codecs +import time +import argparse +import httplib + +#=============================================================================== +# Globals +#=============================================================================== +api = None +args = None + +#=============================================================================== +# Book Keeping +#=============================================================================== +processedUsers = set() +processedOthers = set() +unprocessableUsers = set() +failGetTweets = 0 + +#=============================================================================== +# Output Files +#=============================================================================== +outFollowersGraph = None +outFollowersTweets = None +outFriendsGraph = None +outFriendsTweets = None +outProcessedUsers = None +outUnProcessedUsers = None + +#=============================================================================== +# Parameters +#=============================================================================== +# Minimum number of followers the target user must have before we process their graphs +# Number must be at least 1 to be processed +minNumFollowers = 50 + +# Minimum number of tweets the follower/friend must have +minNumTweets = 30 + +# Maximum number of followers the target user has (celebs...) +maxNumFollowers = 1000 + +# Number of users to extract +numUsersToProcess = 100 + +# Number of tweets to extract for follower/friend (max = 200) +numTweetsPerTimeline = 200 + +# Twitter API Call delay between GetTimeline calls +apiCallDelaySeconds = 3 + +# max number continuous of failures before we apiCallDelaySeconds +maxNumFailures = 3 + +# Twitter API Call delay after maxNumFailures +failDelaySeconds = 7 + +#=============================================================================== +# Authenticate twitter API calls +# This is required in order for most sophisticated API calls +# Obtain Keys after registering an application at: https://dev.twitter.com/apps +# Return: True if successful, False otherwise +#=============================================================================== +def authenticate(ck, cs, atk, ats): + global api + + api = twitter.Api( + consumer_key = ck, + consumer_secret = cs, + access_token_key = atk, + access_token_secret = ats) + + if not api.VerifyCredentials(): + print "There is an error with your authentication keys." + return False + + else: + print "You have been successfully authenticated." 
+ return True + +#=============================================================================== +# Checks how many remaining Twitter API are remaining +# If the 'remaining hits' is less than the minimum required, sleep script +#=============================================================================== +def checkApiLimit(): + if api == None: + return + attempts = 0 + while attempts < 3: + try: + attempts += 1 + minRequiredHits = numUsersToProcess - 2 + print "Remaining Hits: " + str(api.GetRateLimitStatus()[u'remaining_hits']) + + limit = api.GetRateLimitStatus()[u'remaining_hits'] + + if not outFollowersTweets == None or not outFriendsTweets == None: + limit -= minRequiredHits + + if limit <= 0: + limit_sleep = round(api.GetRateLimitStatus()[u'reset_time_in_seconds'] - time.time() + 5) + print "API limit reached, sleep for %d seconds..." % limit_sleep + time.sleep(limit_sleep) + + return + + except (twitter.TwitterError, httplib.BadStatusLine): + print "Error checking for Rate Limit, Sleeping for 60 seconds..." + time.sleep(60) + +#=============================================================================== +# Get the last 'numTweetsPerTimeline' tweets for a given user +# If the Twitter API call fails, try up 3 times +#=============================================================================== +def getTweets(user): + global failGetTweets + + print ("\tRetrieving Timeline for User: " + + str(user.GetId()) + "/" + user.GetName() + + "\tNumTweets: " + str(user.GetStatusesCount())) + + if user.GetId() in processedOthers: + print "\t\tAlready Extracted tweets, so skipping..." + return "SKIP" + + attempts = 0 + maxAttempts = 3 + + while attempts < maxAttempts: + try: + attempts += 1 + tweets = api.GetUserTimeline(id = user.GetId(), count = numTweetsPerTimeline) + print "\t\tSuccessfully Extracted %d tweets" % len(tweets) + failGetTweets = 0 + + return tweets + + except (KeyboardInterrupt, SystemExit): + raise + + except: + print ("\t\tAttempt %d: Failed retrieval for User: " + str(user.GetId())) % attempts + failGetTweets = failGetTweets + 1 + + time.sleep(2); + + if failGetTweets >= maxNumFailures: + print "\nDue to high API call failure, sleeping script for %d seconds\n" % failDelaySeconds + time.sleep(failDelaySeconds) + failGetTweets = 0 + + return [] + + +#=============================================================================== +# Remove users that have less than 'minNumTweets' of tweets and ... 
+# minNumFollowers <= # followers <= maxNumFollowers +# Sort the remaining follower by number of tweets (descending) +# Return: Subset of Followers +#=============================================================================== +def filterUsers(users): + users = [(user.GetStatusesCount(), user) for user in users if (user.GetStatusesCount() >= minNumTweets and + user.GetFollowersCount() >= minNumFollowers and + user.GetFollowersCount() <= maxNumFollowers)] + users = sorted(users, key = itemgetter(0), reverse = True) + users = [user[1] for user in users] + + return users + +#=============================================================================== +# For a given user, obtain their follower graph +# Notation: other = follower/friend +# Output the follower/friend graph (ID_user, SN_user, ID_other, SN_other) +# Output the follower/friend tweets, for each tweet (ID_other, SN_other, date, geo, text ) - Optional +#=============================================================================== +def outputGraphsAndTweets(user, fGraph, outGraph, outTweet): + for other in fGraph: + otherID = ("ID_" + other[0] + "\t" + + "SN_\'" + other[1] + "\'") + + #Write the graph + outGraph.write("ID_" + str(user.GetId()) + "\t" + + "SN_\'" + user.GetName() + "\'\t" + + otherID + + "\n") + + #Write the Tweets + if outTweet: + for tweet in other[2]: + + geo = tweet.GetLocation() + + if geo == None or geo == '': + geo = "NoCoords" + geo = "<" + geo + ">" + + tweet_s = (otherID + "\t" + + tweet.GetCreatedAt() + "\t" + + geo + "\t" + + tweet.GetText().lower()) + + outTweet.write(tweet_s + "\n") + +#=============================================================================== +# If the user has minNumFollowers <= # followers <= maxNumFollowers, find their followers and tweets +# Notation: other = follower/friend +# Return: list of followers/friends where each follower/friend(id, screen_name, list(tweets) ) +#=============================================================================== +def relationshipGraph(user, gType, bTweets): + global failGetTweets + + if gType.lower() == 'followers': + print "Processing Followers Graph for User: " + str(user.GetId()) + "/" + user.GetName() + elif gType.lower() == 'friends': + print "Processing Friends Graph for User: " + str(user.GetId()) + "/" + user.GetName() + + fg = [] + + if (user.GetFollowersCount() >= minNumFollowers and + user.GetFollowersCount() <= maxNumFollowers): + others = [] + attempts = 0 + maxAttempts = 2 + + while attempts < maxAttempts: + try: + attempts += 1 + + if gType.lower() == 'followers': + others = api.GetFollowers(user.GetId()) + break + elif gType.lower() == 'friends': + others = api.GetFriends(user.GetId()) + break + else: + print "Unsupported Relationship Graph Processing: " + gType + return [] + + except (KeyboardInterrupt, SystemExit): + raise + + except (twitter.TwitterError, httplib.BadStatusLine): + print ("\t\tAttempt %d: Failed retrieval for User: " + str(user.GetId())) % attempts + failGetTweets = failGetTweets + 1 + + if failGetTweets >= maxNumFailures: + print "\nDue to high API call failure, sleeping script for %d seconds\n" % failDelaySeconds + time.sleep(failDelaySeconds) + failGetTweets = 0 + + if others == []: + return [] + + others = filterUsers(others) + count = 0 + + for i in range(0, len(others)): + other = others[i] + + #Default Value such that we always output follower or friend graph + tweets = "SKIP" + + #Grab Tweets from Others + if bTweets: + #Delay between calls + time.sleep(apiCallDelaySeconds) + tweets = 
getTweets(other) + processedOthers.add(other.GetId()) + + if tweets == "SKIP": + fg += [(str(other.GetId()), other.GetName(), [])] + count += 1 + elif not tweets == []: + fg += [(str(other.GetId()), other.GetName(), tweets)] + count += 1 + + if count >= numUsersToProcess: + break + else: + print "\tSkipping User because of too few or many followers (%d)." % user.GetFollowersCount() + + return fg + +#=============================================================================== +# For each user, process their followers/friends graph, as specified +#=============================================================================== +def processUser(userID): + print "Processing User: " + userID + + try: + user = api.GetUser(userID) + + if args.f: + followers = relationshipGraph(user, 'followers', not outFollowersTweets == None) + outputGraphsAndTweets(user, followers, outFollowersGraph, outFollowersTweets) + + if args.g: + friends = relationshipGraph(user, 'friends', not outFriendsTweets == None) + outputGraphsAndTweets(user, friends, outFriendsGraph, outFriendsTweets) + + processedUsers.add(str(user.GetId())) + outProcessedUsers.write(userID + "\n") + + except (KeyboardInterrupt, SystemExit): + raise + + except (twitter.TwitterError, httplib.BadStatusLine): + print "\tUnprocessable User: ", userID + unprocessableUsers.add(userID) + outUnProcessedUsers.write(userID + "\n") + + +#=============================================================================== +# Closes output files +# Writes a log of processed users to 'processedUsers.log' +# Writes a log of unprocessed users to 'unprocessedUsers.log' +#=============================================================================== +def cleanup(): + if outFollowersGraph: + outFollowersGraph.close() + + if outFriendsGraph: + outFriendsGraph.close() + + if outFollowersTweets: + outFollowersTweets.close() + + if outFriendsTweets: + outFollowersTweets.close() + + if outUnProcessedUsers: + outUnProcessedUsers.close() + + if outProcessedUsers: + outProcessedUsers.close() + +#=============================================================================== +# Reads in the file with a list of users and attempts to process each user +#=============================================================================== +def processUsers(users_file): + print "Processing Users..." 
+ fin = open(users_file, 'r') + + + for user in fin: + user = user.strip() + + if not user in processedUsers and not len(user) == 0: + checkApiLimit() + processUser(user) + + fin.close() + +def main(): + parser = argparse.ArgumentParser(description = 'Build Twitter Relation Graphs.') + setup = parser.add_argument_group('Script Setup') + #graphs = parser.add_mutually_exclusive_group(required = True) + graphs = parser.add_argument_group('Relationships to Compute') + params = parser.add_argument_group('Script Parameters (Optional)') + + setup.add_argument('-k', + action = "store", + help = 'Authentication Keys Input File', + metavar = 'FILE', + required = True) + setup.add_argument('-u', + action = "store", + help = 'Users Input File', + metavar = 'FILE', + required = True) + graphs.add_argument('-f', + action = "store", + metavar = 'FILE', + help = 'Followers Graph File, Followers Tweets File (opt)', + nargs = '+') + graphs.add_argument('-g', + action = "store", + metavar = 'FILE', + help = 'Friends Graph File, Friends Tweets File (opt)', + nargs = '+') + #Parameters + params.add_argument('-minNumFollowers', + action = "store", + help = 'Minimum number of followers the target user must have before we process their graphs. Default %(default)s', + metavar = 'INT', + type = int, + default = 50) + params.add_argument('-minNumTweets', + action = "store", + help = 'Minimum number of tweets the follower/friend must have. Default %(default)s', + metavar = 'INT', + type = int, + default = 30,) + params.add_argument('-maxNumFollowers', + action = "store", + help = 'Maximum number of followers the target user has. Default %(default)s', + metavar = 'INT', + type = int, + default = 1000) + params.add_argument('-numUsersToProcess', + action = "store", + help = 'Number of followers/friends to extract for a user. Default %(default)s', + metavar = 'INT', + type = int, + default = 100) + params.add_argument('-numTweetsPerTimeline', + action = "store", + help = 'Number of tweets to extract for follower/friend (max = 200). Default %(default)s', + metavar = 'INT', + type = int, + default = 200) + params.add_argument('-apiCallDelaySeconds', + action = "store", + help = 'Twitter API Call delay between GetTimeline calls. Default %(default)s', + metavar = 'INT', + type = int, + default = 3) + params.add_argument('-maxNumFailures', + action = "store", + help = 'Max number continuous of failures before we apiCallDelaySeconds. Default %(default)s', + metavar = 'INT', + type = int, + default = 3) + params.add_argument('-failDelaySeconds', + action = "store", + help = 'Twitter API Call delay after maxNumFailures. 
Default %(default)s', + metavar = 'INT', + type = int, + default = 7) + global args + args = parser.parse_args() + #print args + #sys.exit() + + #Parse Key File + fin = open(args.k, 'r') + keys = fin.readlines() + fin.close() + if not authenticate(keys[0].strip(), + keys[1].strip(), + keys[2].strip(), + keys[3].strip()): + sys.exit() + + #Create output files + if args.f: + global outFollowersGraph, outFollowersTweets + outFollowersGraph = codecs.open(args.f[0], encoding = 'utf-8', mode = 'w') + if len(args.f) >= 2: + outFollowersTweets = codecs.open(args.f[1], encoding = 'utf-8', mode = 'w') + + + if args.g: + global outFriendsGraph, outFriendsTweets + outFriendsGraph = codecs.open(args.g[0], encoding = 'utf-8', mode = 'w') + if len(args.g) >= 2: + outFriendsTweets = codecs.open(args.g[1], encoding = 'utf-8', mode = 'w') + + global outUnProcessedUsers, outProcessedUsers + outUnProcessedUsers = codecs.open(args.u + '.unprocessedUsers.log', encoding = 'utf-8', mode = 'w') + outProcessedUsers = codecs.open(args.u + '.processedUsers.log', encoding = 'utf-8', mode = 'w') + + + #Set Parameters + global minNumTweets, minNumFollowers, maxNumFollowers, numUsersToProcess, numTweetsPerTimeline + global apiCallDelaySeconds, maxNumFailures, failDelaySeconds + minNumTweets = args.minNumTweets + minNumFollowers = max(1, args.minNumFollowers) + maxNumFollowers = args.maxNumFollowers + numUsersToProcess = args.numUsersToProcess + numTweetsPerTimeline = min(200, args.numTweetsPerTimeline) + apiCallDelaySeconds = args.apiCallDelaySeconds + maxNumFailures = args.maxNumFailures + failDelaySeconds = args.failDelaySeconds + + #Process Users + processUsers(args.u) + + #Cleanup + cleanup() + +main() + diff --git a/src/main/python/twitter_geotext_process.py b/src/main/python/twitter_geotext_process.py new file mode 100755 index 0000000..a55f8b2 --- /dev/null +++ b/src/main/python/twitter_geotext_process.py @@ -0,0 +1,156 @@ +#!/usr/bin/env python + +####### +####### twitter_geotext_process.py +####### +####### Copyright (c) 2010 Ben Wing. +####### + +import sys, re +import math +from optparse import OptionParser +from nlputil import * + +############################################################################ +# Documentation # +############################################################################ + +# This program reads in data from the Geotext corpus provided by +# Eisenstein et al., and converts it into the format used for the +# Wikigrounder experiments. + +############################################################################ +# Code # +############################################################################ + +# Debug level; if non-zero, output lots of extra information about how +# things are progressing. If > 1, even more info. 
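# --- Sketch: driving the twitterRelationGraphs.py script above ----------------
# Hedged example of how its argparse interface is meant to be used; file names
# and key values are placeholders. The key file holds four lines (consumer key,
# consumer secret, access token key, access token secret), and the users file
# holds one screen name or id per line; processed/unprocessable users are
# logged to users.txt.processedUsers.log and users.txt.unprocessedUsers.log.
with open('keys.txt', 'w') as f:          # placeholder credentials
    f.write('CONSUMER_KEY\nCONSUMER_SECRET\n'
            'ACCESS_TOKEN_KEY\nACCESS_TOKEN_SECRET\n')
with open('users.txt', 'w') as f:         # one target user per line
    f.write('USER_ID_OR_SCREEN_NAME\n')

# python twitterRelationGraphs.py -k keys.txt -u users.txt \
#        -f followers.graph followers.tweets \
#        -g friends.graph friends.tweets \
#        -minNumFollowers 50 -maxNumFollowers 1000 -numTweetsPerTimeline 200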
+debug = 0 + +# If true, print out warnings about strangely formatted input +show_warnings = True + +####################################################################### +### Utility functions ### +####################################################################### + +def uniprint(text): + '''Print Unicode text in UTF-8, so it can be output without errors''' + if type(text) is unicode: + print text.encode("utf-8") + else: + print text + +def warning(text): + '''Output a warning, formatting into UTF-8 as necessary''' + if show_warnings: + uniprint("Warning: %s" % text) + +####################################################################### +# Process files # +####################################################################### + +vocab_id_to_token = {} +user_id_to_token = {} +user_token_to_id = {} + +# Process vocabulary file +def read_vocab(filename): + id = 0 + for line in open(filename): + id += 1 + word = line.strip().split('\t')[0] + #errprint("%s: %s" % (id, word)) + vocab_id_to_token["%s" % id] = word + +# Process user file +def read_user_info(filename): + id = 666000000 + for line in open(filename): + id += 1 + userid = line.strip().split('\t')[0] + #errprint("%s: %s" % (id, userid)) + user_id_to_token[id] = userid + if userid in user_token_to_id: + errprint("User %s seen twice! Current ID=%s, former=%s" % ( + userid, id, user_token_to_id[userid])) + user_token_to_id[userid] = id + +# Process a file of data in "LDA" format +def process_lda_file(split, filename, userid_filename, artdat_file, + counts_file): + + userids = [] + for line in open(userid_filename): + userid, lat, long = line.strip().split('\t') + usernum = user_token_to_id[userid] + userids.append(usernum) + print >>artdat_file, ("%s\t%s\t%s\t\tMain\tno\tno\tno\t%s,%s\t1" % + (usernum, userid, split, lat, long)) + + userind = 0 + for line in open(filename): + line = line.strip() + args = line.split() + numtypes = args[0] + lat = args[1] + long = args[2] + usernum = userids[userind] + userind += 1 + print >>counts_file, "Article title: %s" % user_id_to_token[usernum] + print >>counts_file, "Article ID: %s" % usernum + for arg in args[3:]: + wordid, count = arg.split(':') + print >>counts_file, "%s = %s" % (vocab_id_to_token[wordid], count) + +####################################################################### +# Main code # +####################################################################### + +def main(): + op = OptionParser(usage="%prog [options] input_dir") + op.add_option("-i", "--input-dir", metavar="DIR", + help="Input dir with Geotext preprocessed files.") + op.add_option("-o", "--output-dir", metavar="DIR", + help="""Dir to output processed files.""") + op.add_option("-p", "--prefix", default="geotext-twitter-", + help="""Prefix to use for outputted files.""") + op.add_option("-d", "--debug", metavar="LEVEL", + help="Output debug info at given level") + + opts, args = op.parse_args() + + global debug + if opts.debug: + debug = int(opts.debug) + + if not opts.input_dir: + op.error("Must specify input dir using -i or --input-dir") + if not opts.output_dir: + op.error("Must specify output dir using -i or --output-dir") + + prefix = "%s/%s" % (opts.output_dir, opts.prefix) + artdat_file = open("%scombined-document-data.txt" % prefix, "w") + print >>artdat_file, "id\ttitle\tsplit\tredir\tnamespace\tis_list_of\tis_disambig\tis_list\tcoord\tincoming_links" + + counts_file = open("%scounts-only-coord-documents.txt" % prefix, "w") + + train_file = "%s/%s" % (opts.input_dir, "train.dat") + dev_file = 
"%s/%s" % (opts.input_dir, "dev.dat") + test_file = "%s/%s" % (opts.input_dir, "test.dat") + vocab_file = "%s/%s" % (opts.input_dir, "vocab_wc_dc") + userid_train_file = "%s/%s" % (opts.input_dir, "user_info.train") + userid_dev_file = "%s/%s" % (opts.input_dir, "user_info.dev") + userid_test_file = "%s/%s" % (opts.input_dir, "user_info.test") + userid_file = "%s/%s" % (opts.input_dir, "user_info") + + read_vocab(vocab_file) + read_user_info(userid_file) + process_lda_file('test', test_file, userid_test_file, artdat_file, counts_file) + process_lda_file('training', train_file, userid_train_file, artdat_file, counts_file) + process_lda_file('dev', dev_file, userid_dev_file, artdat_file, counts_file) + artdat_file.close() + counts_file.close() + +main() + diff --git a/src/main/python/twitter_to_lda.py b/src/main/python/twitter_to_lda.py new file mode 100755 index 0000000..c1bb273 --- /dev/null +++ b/src/main/python/twitter_to_lda.py @@ -0,0 +1,214 @@ +#!/usr/bin/env python + +####### +####### twitter_to_lda.py +####### +####### Copyright (c) 2011 Ben Wing. +####### + +import sys, re +import math +from optparse import OptionParser +from nlputil import * + +############################################################################ +# Quick Start # +############################################################################ + +# This program reads in data from the user_pos_word file and related files, +# generated by preproc/extract.py from the original file "full_text.txt" +# in the Geotext corpus of Eisenstein et al. 2010, located at +# +# http://www.ark.cs.cmu.edu/GeoText/ +# +# It then converts the data to "LDA" format, splitting it into train/dev/test +# files in the process. These files are then used by +# twitter_geotext_process.py. +# +# To run: +# +# (1) The files user_info.train, user_info.dev, and user_info.test are +# needed, to indicate what goes in which splits. These files need to +# be copied from the processed_data subdir of the Geotext corpus, e.g. +# +# cp processed_data/user_info.* processed-20-docthresh +# +# (2) Normally, cd to the directory containing the files output by extract.py, +# and execute e.g. +# +# PATH_TO_SCRIPT/twitter_to_lda.py -i . -o . +# +# This will generate the LDA-format files train.dat, dev.dat, and test.dat. + +############################################################################ +# Documentation # +############################################################################ + +# This program is intended to be used after extract.py (part of the Geotext +# corpus) is rerun, possibly changing settings, esp. doc_count_thresh, which +# we reduce below its original value of 40. +# +# The format of user_pos_word is a series of lines like this: +# +# 1 6 6 +# 1 7 7 +# 1 8 8 +# 1 9 1 +# 1 10 9 +# 1 11 10 +# 1 12 11 +# 1 13 12 +# 1 14 13 +# +# where each line contains "document-id, position, word-id" where a +# "document" is a series of tweets from a single user. +# +# Other files needed on input are "user_info", "user_info.train", +# "user_info.dev" and "user_info.test". +# +# user_info has lines like this: +# +# USER_79321756 47.528139 -122.197916 +# USER_6197f95d 40.2015 -74.806535 +# USER_ce270acf 40.668643 -73.981635 +# +# which identify the user and latitude/longitude of each "document". +# Note that there is no repetition in the user tags, i.e. each document +# is a given a separate "user tag" even if a particular user produced more +# than one document. The latitude/longitude pairs do repeat, sometimes. 
+# +# The other three user_info.* files are in the same format but list only +# the documents going into each of the three train/dev/test sets (identiable +# by the user tag). +# +# A sample line in LDA format is as follows: +# +# 259 40.817009 -73.947467 234:2 636:1 402:3 67:4 603:1 670:2 369:1 235:1 34:7 436:4 637:1 336:1 604:1 269:3 135:4 671:1 1:85 437:1 638:1 169:1 404:1 605:1 672:1 2:4 203:2 572:1 36:7 438:2 639:2 505:1 70:1 606:1 137:2 673:1 3:3 104:2 640:1 372:1 37:14 138:1 71:3 607:2 674:1 4:1 641:1 373:1 407:1 675:1 72:2 608:1 139:2 5:2 642:1 39:4 140:1 676:1 609:2 542:3 408:1 643:1 509:1 174:2 40:1 677:1 275:1 141:1 342:1 7:1 610:2 443:1 644:3 309:2 678:1 611:1 645:5 377:1 42:1 679:1 210:1 612:1 110:1 646:1 680:1 144:2 77:2 613:2 245:3 111:1 647:1 681:1 145:1 480:1 11:3 212:3 614:1 279:1 179:1 112:1 648:1 682:1 280:1 146:1 213:1 615:3 649:1 415:1 683:1 13:61 616:1 650:1 684:1 14:1 550:2 617:1 383:1 182:2 115:6 651:1 685:1 350:1 15:6 82:3 618:1 183:1 49:1 585:2 652:2 317:1 518:2 686:1 16:17 619:1 351:1 586:1 251:1 653:3 687:1 84:2 620:1 51:6 587:1 654:2 688:1 420:1 85:4 621:3 286:3 152:1 588:7 454:1 119:9 655:1 689:1 622:1 287:1 153:1 53:1 589:1 455:1 656:1 690:1 87:8 623:1 54:1 590:1 255:2 121:1 657:1 322:4 389:1 691:1 88:1 624:1 21:12 591:1 256:1 658:1 323:1 189:1 390:1 55:1 692:1 625:2 290:3 223:1 659:1 391:4 56:5 592:1 693:1 90:7 626:1 358:3 660:1 593:1 694:1 91:5 627:1 359:3 24:9 225:3 661:1 58:21 594:1 695:1 25:1 628:1 494:1 193:1 662:1 327:2 595:1 696:1 26:4 428:1 629:1 127:21 663:1 60:1 596:1 27:2 563:2 630:1 128:4 664:1 530:1 597:4 229:1 430:1 631:3 665:1 598:1 263:3 163:2 632:3 398:4 599:1 666:1 633:1 600:1 667:1 567:2 433:1 634:1 500:1 199:2 601:1 266:3 467:3 668:3 434:3 32:9 635:1 300:2 501:1 166:1 602:1 669:1 334:1 +# +# This gives first the number of word types seen in the document, followed +# by latitude and longitude, then a pair ID:COUNT for each word type, listing +# the type's ID and the count of how many times that word is seen in the +# document. + +############################################################################ +# Code # +############################################################################ + +# Debug level; if non-zero, output lots of extra information about how +# things are progressing. If > 1, even more info. +debug = 0 + +# If true, print out warnings about strangely formatted input +show_warnings = True + +####################################################################### +### Utility functions ### +####################################################################### + +def uniprint(text): + '''Print Unicode text in UTF-8, so it can be output without errors''' + if type(text) is unicode: + print text.encode("utf-8") + else: + print text + +def warning(text): + '''Output a warning, formatting into UTF-8 as necessary''' + if show_warnings: + uniprint("Warning: %s" % text) + +####################################################################### +# Process files # +####################################################################### + +vocab_id_to_token = {} +# Mapping from user tag (e.g. USER_ce270acf) to document ID (numbered +# starting at 1). +user_id_to_document = {} +# Mapping from document ID to latitude/longitude. +document_to_latitude = {} +document_to_longitude = {} +# Mapping from document ID to word-count dictionary. +document_to_word_count = {} +# Mapping from one of 'train', 'dev', or 'test' to a list of user tags. 
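# --- Sketch: composing one line in the LDA format described above -------------
# Hedged illustration only; the counts are made up and the coordinates reuse
# the sample line above. The layout follows the header comment: number of word
# types, latitude, longitude, then an ID:COUNT pair per word type.
counts = {'234': 2, '636': 1, '402': 3}                  # word-id -> count
lat, lon = 40.817009, -73.947467
pairs = ' '.join('%s:%s' % (w, c) for w, c in counts.iteritems())
print '%d %s %s %s' % (len(counts), lat, lon, pairs)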
+user_id_by_split = {} + +# Process combined user_info file +def read_user_info(filename): + id = 0 + for line in open(filename): + id += 1 + userid, lat, long = line.strip().split('\t') + #errprint("%s: %s" % (id, userid)) + if userid in user_id_to_document: + errprint("User %s seen twice! Current ID=%s, former=%s" % ( + userid, id, user_id_to_document[userid])) + user_id_to_document[userid] = id + document_to_latitude[id] = lat + document_to_longitude[id] = long + +# Process split user_info file +def read_user_info_split(split, filename): + user_id_by_split[split] = [] + for line in open(filename): + userid = line.strip().split('\t')[0] + user_id_by_split[split] += [userid] + +# Read user_pos_word +def read_user_pos_word(filename): + for line in open(filename): + doc, pos, word = line.strip().split('\t') + doc = int(doc) + if doc not in document_to_word_count: + document_to_word_count[doc] = intdict() + document_to_word_count[doc][word] += 1 + +# Output file in LDA format for given split +def output_lda_file(split, filename): + outfile = open(filename, "w") + for userid in user_id_by_split[split]: + docid = user_id_to_document[userid] + words = document_to_word_count[docid] + outfile.write("%s " % len(words)) + outfile.write("%s %s" % (document_to_latitude[docid], + document_to_longitude[docid])) + for word, count in words.iteritems(): + outfile.write(" %s:%s" % (word, count)) + outfile.write('\n') + outfile.close() + +####################################################################### +# Main code # +####################################################################### + +def main(): + op = OptionParser(usage="%prog [options] input_dir") + op.add_option("-i", "--input-dir", metavar="DIR", + help="Input dir with Geotext preprocessed files.") + op.add_option("-o", "--output-dir", metavar="DIR", + help="""Dir to output processed files.""") + op.add_option("-d", "--debug", metavar="LEVEL", + help="Output debug info at given level") + + opts, args = op.parse_args() + + global debug + if opts.debug: + debug = int(opts.debug) + + if not opts.input_dir: + op.error("Must specify input dir using -i or --input-dir") + if not opts.output_dir: + op.error("Must specify output dir using -i or --output-dir") + + splits = ['train', 'dev', 'test'] + read_user_info("%s/%s" % (opts.input_dir, "user_info")) + for split in splits: + read_user_info_split(split, "%s/%s.%s" % + (opts.input_dir, "user_info", split)) + read_user_pos_word("%s/%s" % (opts.input_dir, "user_pos_word")) + for split in splits: + output_lda_file(split, "%s/%s.dat" % (opts.output_dir, split)) + +main() + diff --git a/src/main/python/unescape_entities.py b/src/main/python/unescape_entities.py new file mode 100644 index 0000000..82f2f36 --- /dev/null +++ b/src/main/python/unescape_entities.py @@ -0,0 +1,32 @@ +import re, htmlentitydefs + +# NOTE: Courtesy of Frederik Lundh. +# +# http://effbot.org/zone/re-sub.htm#unescape-html + +## +# Removes HTML or XML character references and entities from a text string. +# +# @param text The HTML (or XML) source text. +# @return The plain text, as a Unicode string, if necessary. 
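# --- Usage sketch for unescape(), defined just below (editor's illustration) --
# Hedged example: named entities and decimal/hex character references in
# scraped text are folded back into plain Unicode.
sample = unescape(u"Tom &amp; Jerry &#169; caf&eacute;")
assert sample == u"Tom & Jerry \xa9 caf\xe9"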
+ +def unescape(text): + def fixup(m): + text = m.group(0) + if text[:2] == "&#": + # character reference + try: + if text[:3] == "&#x": + return unichr(int(text[3:-1], 16)) + else: + return unichr(int(text[2:-1])) + except ValueError: + pass + else: + # named entity + try: + text = unichr(htmlentitydefs.name2codepoint[text[1:-1]]) + except KeyError: + pass + return text # leave as is + return re.sub("&#?\w+;", fixup, text) diff --git a/src/main/resources/data/deu/stopwords.txt b/src/main/resources/data/deu/stopwords.txt new file mode 100644 index 0000000..edef220 --- /dev/null +++ b/src/main/resources/data/deu/stopwords.txt @@ -0,0 +1,231 @@ +aber +alle +allem +allen +aller +alles +als +also +am +an +ander +andere +anderem +anderen +anderer +anderes +anderm +andern +anderr +anders +auch +auf +aus +bei +bin +bis +bist +da +damit +dann +der +den +des +dem +die +das +daß +derselbe +derselben +denselben +desselben +demselben +dieselbe +dieselben +dasselbe +dazu +dein +deine +deinem +deinen +deiner +deines +denn +derer +dessen +dich +dir +du +dies +diese +diesem +diesen +dieser +dieses +doch +dort +durch +ein +eine +einem +einen +einer +eines +einig +einige +einigem +einigen +einiger +einiges +einmal +er +ihn +ihm +es +etwas +euer +eure +eurem +euren +eurer +eures +für +gegen +gewesen +hab +habe +haben +hat +hatte +hatten +hier +hin +hinter +ich +mich +mir +ihr +ihre +ihrem +ihren +ihrer +ihres +euch +im +in +indem +ins +ist +jede +jedem +jeden +jeder +jedes +jene +jenem +jenen +jener +jenes +jetzt +kann +kein +keine +keinem +keinen +keiner +keines +können +könnte +machen +man +manche +manchem +manchen +mancher +manches +mein +meine +meinem +meinen +meiner +meines +mit +muss +musste +nach +nicht +nichts +noch +nun +nur +ob +oder +ohne +sehr +sein +seine +seinem +seinen +seiner +seines +selbst +sich +sie +ihnen +sind +so +solche +solchem +solchen +solcher +solches +soll +sollte +sondern +sonst +über +um +und +uns +unse +unsem +unsen +unser +unses +unter +viel +vom +von +vor +während +war +waren +warst +was +weg +weil +weiter +welche +welchem +welchen +welcher +welches +wenn +werde +werden +wie +wieder +will +wir +wird +wirst +wo +wollen +wollte +würde +würden +zu +zum +zur +zwar +zwischen diff --git a/src/main/resources/data/eng/stopwords.txt b/src/main/resources/data/eng/stopwords.txt new file mode 100644 index 0000000..c6ca14f --- /dev/null +++ b/src/main/resources/data/eng/stopwords.txt @@ -0,0 +1,572 @@ +'s +a +a's +able +about +above +according +accordingly +across +actually +after +afterwards +again +against +ain't +all +allow +allows +almost +alone +along +already +also +although +always +am +among +amongst +an +and +another +any +anybody +anyhow +anyone +anything +anyway +anyways +anywhere +apart +appear +appreciate +appropriate +are +aren't +around +as +aside +ask +asking +associated +at +available +away +awfully +b +be +became +because +become +becomes +becoming +been +before +beforehand +behind +being +believe +below +beside +besides +best +better +between +beyond +both +brief +but +by +c +c'mon +c's +came +can +can't +cannot +cant +cause +causes +certain +certainly +changes +clearly +co +com +come +comes +concerning +consequently +consider +considering +contain +containing +contains +corresponding +could +couldn't +course +currently +d +definitely +described +despite +did +didn't +different +do +does +doesn't +doing +don't +done +down +downwards +during +e +each +edu +eg +eight +either +else +elsewhere +enough +entirely +especially +et +etc +even +ever +every +everybody +everyone 
+everything +everywhere +ex +exactly +example +except +f +far +few +fifth +first +five +followed +following +follows +for +former +formerly +forth +four +from +further +furthermore +g +get +gets +getting +given +gives +go +goes +going +gone +got +gotten +greetings +h +had +hadn't +happens +hardly +has +hasn't +have +haven't +having +he +he's +hello +help +hence +her +here +here's +hereafter +hereby +herein +hereupon +hers +herself +hi +him +himself +his +hither +hopefully +how +howbeit +however +i +i'd +i'll +i'm +i've +ie +if +ignored +immediate +in +inasmuch +inc +indeed +indicate +indicated +indicates +inner +insofar +instead +into +inward +is +isn't +it +it'd +it'll +it's +its +itself +j +just +k +keep +keeps +kept +know +knows +known +l +last +lately +later +latter +latterly +least +less +lest +let +let's +like +liked +likely +little +look +looking +looks +ltd +m +mainly +many +may +maybe +me +mean +meanwhile +merely +might +more +moreover +most +mostly +much +must +my +myself +n +name +namely +nd +near +nearly +necessary +need +needs +neither +never +nevertheless +new +next +nine +no +nobody +non +none +noone +nor +normally +not +nothing +novel +now +nowhere +o +obviously +of +off +often +oh +ok +okay +old +on +once +one +ones +only +onto +or +other +others +otherwise +ought +our +ours +ourselves +out +outside +over +overall +own +p +particular +particularly +per +perhaps +placed +please +plus +possible +presumably +probably +provides +q +que +quite +qv +r +rather +rd +re +really +reasonably +regarding +regardless +regards +relatively +respectively +right +s +said +same +saw +say +saying +says +second +secondly +see +seeing +seem +seemed +seeming +seems +seen +self +selves +sensible +sent +serious +seriously +seven +several +shall +she +should +shouldn't +since +six +so +some +somebody +somehow +someone +something +sometime +sometimes +somewhat +somewhere +soon +sorry +specified +specify +specifying +still +sub +such +sup +sure +t +t's +take +taken +tell +tends +th +than +thank +thanks +thanx +that +that's +thats +the +their +theirs +them +themselves +then +thence +there +there's +thereafter +thereby +therefore +therein +theres +thereupon +these +they +they'd +they'll +they're +they've +think +third +this +thorough +thoroughly +those +though +three +through +throughout +thru +thus +to +together +too +took +toward +towards +tried +tries +truly +try +trying +twice +two +u +un +under +unfortunately +unless +unlikely +until +unto +up +upon +us +use +used +useful +uses +using +usually +uucp +v +value +various +very +via +viz +vs +w +want +wants +was +wasn't +way +we +we'd +we'll +we're +we've +welcome +well +went +were +weren't +what +what's +whatever +when +whence +whenever +where +where's +whereafter +whereas +whereby +wherein +whereupon +wherever +whether +which +while +whither +who +who's +whoever +whole +whom +whose +why +will +willing +wish +with +within +without +won't +wonder +would +would +wouldn't +x +y +yes +yet +you +you'd +you'll +you're +you've +your +yours +yourself +yourselves +z +zero diff --git a/src/main/resources/data/eng/stopwords.txt.old b/src/main/resources/data/eng/stopwords.txt.old new file mode 100644 index 0000000..5312ae4 --- /dev/null +++ b/src/main/resources/data/eng/stopwords.txt.old @@ -0,0 +1,514 @@ +a +about +above +according +across +after +afterwards +again +against +ain't +albeit +all +almost +alone +along +already +also +although +always +am +among +amongst +an +and +another +any +anybody +anyhow +anyone +anything +anyway +anywhere +apart +are +aren't 
+around +as +at +av +be +became +because +become +becomes +becoming +been +before +beforehand +behind +being +below +beside +besides +between +beyond +both +but +by +can +can +cannot +canst +can't +certain +cf +choose +click +com +contrariwise +cos +could +couldn +couldn't +cu +day +did +didn't +do +does +doesn't +doesn't +doing +don +don't +dost +doth +double +down +dual +during +each +eight +eighteen +eighty +either +eleven +else +elsewhere +enough +et +etc +even +ever +every +everybody +everyone +everything +everywhere +except +excepted +excepting +exception +exclude +excluding +exclusive +far +farther +farthest +few +ff +fifteen +fifty +first +five +for +formerly +forth +forward +four +fourteen +fourty +from +front +further +furthermore +furthest +get +go +had +hadn't +halves +hardly +has +hasn't +hast +hath +have +haven't +he +he'd +he'll +hence +henceforth +her +here +hereabouts +hereafter +hereby +herein +here's +hereto +hereupon +hers +herself +he's +him +himself +hindmost +his +hither +hitherto +how +however +how's +howsoever +http +hundred +i +i'd +ie +if +i'll +i'm +in +inasmuch +inc +include +included +including +indeed +indoors +inside +insomuch +instead +into +inward +inwards +is +isn +isn't +it +it'd +it'll +its +it's +itself +i've +just +kg +kind +km +last +latter +latterly +less +lest +let +let's +like +little +ltd +many +may +maybe +me +meantime +meanwhile +might +more +moreover +most +mostly +mr +mrs +ms +much +must +my +myself +namely +need +neither +never +nevertheless +next +nine +ninetenn +ninety +no +nobody +none +nonetheless +noone +nope +nor +not +nothing +notwithstanding +now +nowadays +nowhere +o'clock +of +off +often +ok +on +once +one +one +only +onto +or +other +others +otherwise +ought +our +ours +ourselves +out +outside +over +own +per +perhaps +plenty +provide +quite +rather +really +round +said +sake +same +sang +save +saw +see +seeing +seem +seemed +seeming +seems +seen +seldom +selves +sent +seven +seventeen +seventy +several +shalt +she +she'd +she'll +she's +should +shouldn +shouldn't +shown +sideways +since +six +sixteen +sixty +slept +slew +slung +slunk +smote +so +some +somebody +somehow +someone +something +sometime +sometimes +somewhat +somewhere +spake +spat +spoke +spoken +sprang +sprung +stave +staves +still +such +supposing +ten +than +that +that'll +that's +the +thee +their +them +themselves +then +thence +thenceforth +there +thereabout +thereabouts +thereafter +thereby +there'd +therefore +therein +there'll +thereof +thereon +there's +thereto +thereupon +there've +these +they +they'd +they'll +they're +they've +thirteen +this +those +thou +though +thousand +three +thrice +thrity +through +throughout +thru +thus +thy +thyself +till +to +together +too +toward +towards +twelve +twenty +two +ugh +unable +under +underneath +unless +unlike +until +up +upon +upward +upwards +us +use +used +using +very +via +vs +want +was +wasn't +we +we'd +week +well +we'll +were +we're +weren't +we've +what +whatever +what's +whatsoever +when +whence +whenever +whensoever +where +whereabouts +whereafter +whereas +whereat +whereby +wherefore +wherefrom +wherein +whereinto +whereof +whereon +wheresoever +whereto +whereunto +whereupon +wherever +wherewith +whether +whew +which +whichever +whichsoever +while +whilst +whither +who +whoa +who'd +whoever +whole +whom +whomever +whomsoever +who's +whose +whosoever +who've +why +will +wilt +with +within +without +won +won't +worse +worst +would +wouldn't +would've +wow +www +ye +year +yet +yippee +you +you'd +you'll +your 
+you're +yours +yourself +yourselves +you've \ No newline at end of file diff --git a/src/main/resources/data/geo/country-codes.txt b/src/main/resources/data/geo/country-codes.txt new file mode 100644 index 0000000..94a6903 --- /dev/null +++ b/src/main/resources/data/geo/country-codes.txt @@ -0,0 +1,273 @@ +Afghanistan AF AF AFG 004 AFG .af +Albania AL AL ALB 008 ALB .al +Algeria AG DZ DZA 012 DZA .dz +American Samoa AQ AS ASM 016 ASM .as +Andorra AN AD AND 020 AND .ad +Angola AO AO AGO 024 AGO .ao +Anguilla AV AI AIA 660 AIA .ai +Antarctica AY AQ ATA 010 ATA .aq ISO defines as the territory south of 60 degrees south latitude +Antigua and Barbuda AC AG ATG 028 ATG .ag +Argentina AR AR ARG 032 ARG .ar +Armenia AM AM ARM 051 ARM .am +Aruba AA AW ABW 533 ABW .aw +Ashmore and Cartier Islands AT AU AUS 036 AUS - ISO includes with Australia +Australia AS AU AUS 036 AUS .au ISO includes Ashmore and Cartier Islands, Coral Sea Islands +Austria AU AT AUT 040 AUT .at +Azerbaijan AJ AZ AZE 031 AZE .az +Bahamas, The BF BS BHS 044 BHS .bs +Bahrain BA BH BHR 048 BHR .bh +Baker Island FQ UM UMI 581 UMI - ISO includes with the US Minor Outlying Islands +Bangladesh BG BD BGD 050 BGD .bd +Barbados BB BB BRB 052 BRB .bb +Bassas da India BS - - - - - administered as part of French Southern and Antarctic Lands; no ISO codes assigned +Belarus BO BY BLR 112 BLR .by +Belgium BE BE BEL 056 BEL .be +Belize BH BZ BLZ 084 BLZ .bz +Benin BN BJ BEN 204 BEN .bj +Bermuda BD BM BMU 060 BMU .bm +Bhutan BT BT BTN 064 BTN .bt +Bolivia BL BO BOL 068 BOL .bo +Bosnia and Herzegovina BK BA BIH 070 BIH .ba +Botswana BC BW BWA 072 BWA .bw +Bouvet Island BV BV BVT 074 BVT .bv +Brazil BR BR BRA 076 BRA .br +British Indian Ocean Territory IO IO IOT 086 IOT .io +British Virgin Islands VI VG VGB 092 VGB .vg +Brunei BX BN BRN 096 BRN .bn +Bulgaria BU BG BGR 100 BGR .bg +Burkina Faso UV BF BFA 854 BFA .bf +Burma BM MM MMR 104 MMR .mm ISO uses the name Myanmar +Burundi BY BI BDI 108 BDI .bi +Cambodia CB KH KHM 116 KHM .kh +Cameroon CM CM CMR 120 CMR .cm +Canada CA CA CAN 124 CAN .ca +Cape Verde CV CV CPV 132 CPV .cv +Cayman Islands CJ KY CYM 136 CYM .ky +Central African Republic CT CF CAF 140 CAF .cf +Chad CD TD TCD 148 TCD .td +Chile CI CL CHL 152 CHL .cl +China CH CN CHN 156 CHN .cn see also Taiwan +Christmas Island KT CX CXR 162 CXR .cx +Clipperton Island IP PF PYF 258 FYP - ISO includes with French Polynesia +Cocos (Keeling) Islands CK CC CCK 166 AUS .cc +Colombia CO CO COL 170 COL .co +Comoros CN KM COM 174 COM .km +Congo, Democratic Republic of the CG CD COD 180 COD .cd formerly Zaire +Congo, Republic of the CF CG COG 178 COG .cg +Cook Islands CW CK COK 184 COK .ck +Coral Sea Islands CR AU AUS 036 AUS - ISO includes with Australia +Costa Rica CS CR CRI 188 CRI .cr +Cote d'Ivoire IV CI CIV 384 CIV .ci +Croatia HR HR HRV 191 HRV .hr +Cuba CU CU CUB 192 CUB .cu +Cyprus CY CY CYP 196 CYP .cy +Czech Republic EZ CZ CZE 203 CZE .cz +Denmark DA DK DNK 208 DNK .dk +Djibouti DJ DJ DJI 262 DJI .dj +Dominica DO DM DMA 212 DMA .dm +Dominican Republic DR DO DOM 214 DOM .do +Ecuador EC EC ECU 218 ECU .ec +Egypt EG EG EGY 818 EGY .eg +El Salvador ES SV SLV 222 SLV .sv +Equatorial Guinea EK GQ GNQ 226 GNQ .gq +Eritrea ER ER ERI 232 ERI .er +Estonia EN EE EST 233 EST .ee +Ethiopia ET ET ETH 231 ETH .et +Europa Island EU - - - - - administered as part of French Southern and Antarctic Lands; no ISO codes assigned +Falkland Islands (Islas Malvinas) FK FK FLK 238 FLK .fk +Faroe Islands FO FO FRO 234 FRO .fo +Fiji FJ FJ FJI 242 FJI .fj +Finland FI FI FIN 
246 FIN .fi +France FR FR FRA 250 FRA .fr +France, Metropolitan - FX FXX 249 - .fx ISO limits to the European part of France, excluding French Guiana, French Polynesia, French Southern and Antarctic Lands, Guadeloupe, Martinique, Mayotte, New Caledonia, Reunion, Saint Pierre and Miquelon, Wallis and Futuna +French Guiana FG GF GUF 254 GUF .gf +French Polynesia FP PF PYF 258 PYF .pf ISO includes Clipperton Island +French Southern and Antarctic Lands FS TF ATF 260 ATF .tf FIPS 10-4 does not include the French-claimed portion of Antarctica (Terre Adelie) +Gabon GB GA GAB 266 GAB .ga +Gambia, The GA GM GMB 270 GMB .gm +Gaza Strip GZ PS PSE 275 PSE .ps ISO identifies as Occupied Palestinian Territory +Georgia GG GE GEO 268 GEO .ge +Germany GM DE DEU 276 DEU .de +Ghana GH GH GHA 288 GHA .gh +Gibraltar GI GI GIB 292 GIB .gi +Glorioso Islands GO - - - - - administered as part of French Southern and Antarctic Lands; no ISO codes assigned +Greece GR GR GRC 300 GRC .gr For its internal communications, the European Union recommends the use of the code EL in lieu of the ISO 3166-2 code of GR +Greenland GL GL GRL 304 GRL .gl +Grenada GJ GD GRD 308 GRD .gd +Guadeloupe GP GP GLP 312 GLP .gp +Guam GQ GU GUM 316 GUM .gu +Guatemala GT GT GTM 320 GTM .gt +Guernsey GK GG GGY 831 UK .gg +Guinea GV GN GIN 324 GIN .gn +Guinea-Bissau PU GW GNB 624 GNB .gw +Guyana GY GY GUY 328 GUY .gy +Haiti HA HT HTI 332 HTI .ht +Heard Island and McDonald Islands HM HM HMD 334 HMD .hm +Holy See (Vatican City) VT VA VAT 336 VAT .va +Honduras HO HN HND 340 HND .hn +Hong Kong HK HK HKG 344 HKG .hk +Howland Island HQ UM UMI 581 UMI - ISO includes with the US Minor Outlying Islands +Hungary HU HU HUN 348 HUN .hu +Iceland IC IS ISL 352 ISL .is +India IN IN IND 356 IND .in +Indonesia ID ID IDN 360 IDN .id +Iran IR IR IRN 364 IRN .ir +Iraq IZ IQ IRQ 368 IRQ .iq +Ireland EI IE IRL 372 IRL .ie +Isle of Man IM IM IMN 833 UK .im +Israel IS IL ISR 376 ISR .il +Italy IT IT ITA 380 ITA .it +Jamaica JM JM JAM 388 JAM .jm +Jan Mayen JN SJ SJM 744 SJM - ISO includes with Svalbard +Japan JA JP JPN 392 JPN .jp +Jarvis Island DQ UM UMI 581 UMI - ISO includes with the US Minor Outlying Islands +Jersey JE JE JEY 832 UK .je +Johnston Atoll JQ UM UMI 581 UMI - ISO includes with the US Minor Outlying Islands +Jordan JO JO JOR 400 JOR .jo +Juan de Nova Island JU - - - - - administered as part of French Southern and Antarctic Lands; no ISO codes assigned +Kazakhstan KZ KZ KAZ 398 KAZ .kz +Kenya KE KE KEN 404 KEN .ke +Kingman Reef KQ UM UMI 581 UMI - ISO includes with the US Minor Outlying Islands +Kiribati KR KI KIR 296 KIR .ki +Korea, North KN KP PRK 408 PRK .kp +Korea, South KS KR KOR 410 KOR .kr +Kosovo KV - - - - - ISO codes have not been designated +Kuwait KU KW KWT 414 KWT .kw +Kyrgyzstan KG KG KGZ 417 KGZ .kg +Laos LA LA LAO 418 LAO .la +Latvia LG LV LVA 428 LVA .lv +Lebanon LE LB LBN 422 LBN .lb +Lesotho LT LS LSO 426 LSO .ls +Liberia LI LR LBR 430 LBR .lr +Libya LY LY LBY 434 LBY .ly +Liechtenstein LS LI LIE 438 LIE .li +Lithuania LH LT LTU 440 LTU .lt +Luxembourg LU LU LUX 442 LUX .lu +Macau MC MO MAC 446 MAC .mo +Macedonia MK MK MKD 807 FYR .mk +Madagascar MA MG MDG 450 MDG .mg +Malawi MI MW MWI 454 MWI .mw +Malaysia MY MY MYS 458 MYS .my +Maldives MV MV MDV 462 MDV .mv +Mali ML ML MLI 466 MLI .ml +Malta MT MT MLT 470 MLT .mt +Marshall Islands RM MH MHL 584 MHL .mh +Martinique MB MQ MTQ 474 MTQ .mq +Mauritania MR MR MRT 478 MRT .mr +Mauritius MP MU MUS 480 MUS .mu +Mayotte MF YT MYT 175 FRA .yt +Mexico MX MX MEX 484 MEX .mx +Micronesia, Federated 
States of FM FM FSM 583 FSM .fm +Midway Islands MQ UM UMI 581 UMI - ISO includes with the US Minor Outlying Islands +Moldova MD MD MDA 498 MDA .md +Monaco MN MC MCO 492 MCO .mc +Mongolia MG MN MNG 496 MNG .mn +Montenegro MJ ME MNE 499 MNE .me +Montserrat MH MS MSR 500 MSR .ms +Morocco MO MA MAR 504 MAR .ma +Mozambique MZ MZ MOZ 508 MOZ .mz +Myanmar - - - - - - see Burma +Namibia WA NA NAM 516 NAM .na +Nauru NR NR NRU 520 NRU .nr +Navassa Island BQ UM UMI 581 US - ISO includes with the US Minor Outlying Islands +Nepal NP NP NPL 524 NPL .np +Netherlands NL NL NLD 528 NLD .nl +Netherlands Antilles NT AN ANT 530 ANT .an +New Caledonia NC NC NCL 540 NCL .nc +New Zealand NZ NZ NZL 554 NZL .nz +Nicaragua NU NI NIC 558 NIC .ni +Niger NG NE NER 562 NER .ne +Nigeria NI NG NGA 566 NGA .ng +Niue NE NU NIU 570 NIU .nu +Norfolk Island NF NF NFK 574 NFK .nf +Northern Mariana Islands CQ MP MNP 580 MNP .mp +Norway NO NO NOR 578 NOR .no +Oman MU OM OMN 512 OMN .om +Pakistan PK PK PAK 586 PAK .pk +Palau PS PW PLW 585 PLW .pw +Palmyra Atoll LQ UM UMI 581 UMI - ISO includes with the US Minor Outlying Islands +Panama PM PA PAN 591 PAN .pa +Papua New Guinea PP PG PNG 598 PNG .pg +Paracel Islands PF - - - - - +Paraguay PA PY PRY 600 PRY .py +Peru PE PE PER 604 PER .pe +Philippines RP PH PHL 608 PHL .ph +Pitcairn Islands PC PN PCN 612 PCN .pn +Poland PL PL POL 616 POL .pl +Portugal PO PT PRT 620 PRT .pt +Puerto Rico RQ PR PRI 630 PRI .pr +Qatar QA QA QAT 634 QAT .qa +Reunion RE RE REU 638 REU .re +Romania RO RO ROU 642 ROU .ro +Russia RS RU RUS 643 RUS .ru +Rwanda RW RW RWA 646 RWA .rw +Saint Barthelemy TB BL BLM 652 - .bl ccTLD .fr and .gp may also be used +Saint Helena SH SH SHN 654 SHN .sh +Saint Kitts and Nevis SC KN KNA 659 KNA .kn +Saint Lucia ST LC LCA 662 LCA .lc +Saint Martin RN MF MAF 663 - .mf ccTLD .fr and .gp may also be used +Saint Pierre and Miquelon SB PM SPM 666 SPM .pm +Saint Vincent and the Grenadines VC VC VCT 670 VCT .vc +Samoa WS WS WSM 882 WSM .ws +San Marino SM SM SMR 674 SMR .sm +Sao Tome and Principe TP ST STP 678 STP .st +Saudi Arabia SA SA SAU 682 SAU .sa +Senegal SG SN SEN 686 SEN .sn +Serbia RI RS SRB 688 - .rs +Seychelles SE SC SYC 690 SYC .sc +Sierra Leone SL SL SLE 694 SLE .sl +Singapore SN SG SGP 702 SGP .sg +Slovakia LO SK SVK 703 SVK .sk +Slovenia SI SI SVN 705 SVN .si +Solomon Islands BP SB SLB 090 SLB .sb +Somalia SO SO SOM 706 SOM .so +South Africa SF ZA ZAF 710 ZAF .za +South Georgia and the Islands SX GS SGS 239 SGS .gs +Spain SP ES ESP 724 ESP .es +Spratly Islands PG - - - - - +Sri Lanka CE LK LKA 144 LKA .lk +Sudan SU SD SDN 736 SDN .sd +Suriname NS SR SUR 740 SUR .sr +Svalbard SV SJ SJM 744 SJM .sj ISO includes Jan Mayen +Swaziland WZ SZ SWZ 748 SWZ .sz +Sweden SW SE SWE 752 SWE .se +Switzerland SZ CH CHE 756 CHE .ch +Syria SY SY SYR 760 SYR .sy +Taiwan TW TW TWN 158 TWN .tw +Tajikistan TI TJ TJK 762 TJK .tj +Tanzania TZ TZ TZA 834 TZA .tz +Thailand TH TH THA 764 THA .th +Timor-Leste TT TL TLS 626 TLS .tl +Togo TO TG TGO 768 TGO .tg +Tokelau TL TK TKL 772 TKL .tk +Tonga TN TO TON 776 TON .to +Trinidad and Tobago TD TT TTO 780 TTO .tt +Tromelin Island TE - - - - - administered as part of French Southern and Antarctic Lands; no ISO codes assigned +Tunisia TS TN TUN 788 TUN .tn +Turkey TU TR TUR 792 TUR .tr +Turkmenistan TX TM TKM 795 TKM .tm +Turks and Caicos Islands TK TC TCA 796 TCA .tc +Tuvalu TV TV TUV 798 TUV .tv +Uganda UG UG UGA 800 UGA .ug +Ukraine UP UA UKR 804 UKR .ua +United Arab Emirates AE AE ARE 784 ARE .ae +United Kingdom UK GB GBR 826 GBR .uk For its 
internal communications, the European Union recommends the use of the code UK in lieu of the ISO 3166-2 code of GB +United States US US USA 840 USA .us +United States Minor Outlying Islands - UM UMI 581 - .um ISO includes Baker Island, Howland Island, Jarvis Island, Johnston Atoll, Kingman Reef, Midway Islands, Navassa Island, Palmyra Atoll, Wake Island +Uruguay UY UY URY 858 URY .uy +Uzbekistan UZ UZ UZB 860 UZB .uz +Vanuatu NH VU VUT 548 VUT .vu +Venezuela VE VE VEN 862 VEN .ve +Vietnam VM VN VNM 704 VNM .vn +Virgin Islands VQ VI VIR 850 VIR .vi +Virgin Islands (UK) - - - - - .vg see British Virgin Islands +Virgin Islands (US) - - - - - .vi see Virgin Islands +Wake Island WQ UM UMI 581 UMI - ISO includes with the US Minor Outlying Islands +Wallis and Futuna WF WF WLF 876 WLF .wf +West Bank WE PS PSE 275 PSE .ps ISO identifies as Occupied Palestinian Territory +Western Sahara WI EH ESH 732 ESH .eh +Western Samoa - - - - - .ws see Samoa +Yemen YM YE YEM 887 YEM .ye +Zaire - - - - - - see Democratic Republic of the Congo +Zambia ZA ZM ZMB 894 ZMB .zm +Zimbabwe ZI ZW ZWE 716 ZWE .zw diff --git a/src/main/resources/data/por/stopwords.txt b/src/main/resources/data/por/stopwords.txt new file mode 100644 index 0000000..6b24778 --- /dev/null +++ b/src/main/resources/data/por/stopwords.txt @@ -0,0 +1,203 @@ +de +a +o +que +e +do +da +em +um +para +com +não +uma +os +no +se +na +por +mais +as +dos +como +mas +ao +ele +das +à +seu +sua +ou +quando +muito +nos +já +eu +também +só +pelo +pela +até +isso +ela +entre +depois +sem +mesmo +aos +seus +quem +nas +me +esse +eles +você +essa +num +nem +suas +meu +às +minha +numa +pelos +elas +qual +nós +lhe +deles +essas +esses +pelas +este +dele +tu +te +vocês +vos +lhes +meus +minhas +teu +tua +teus +tuas +nosso +nossa +nossos +nossas +dela +delas +esta +estes +estas +aquele +aquela +aqueles +aquelas +isto +aquilo +estou +está +estamos +estão +estive +esteve +estivemos +estiveram +estava +estávamos +estavam +estivera +estivéramos +esteja +estejamos +estejam +estivesse +estivéssemos +estivessem +estiver +estivermos +estiverem +hei +há +havemos +hão +houve +houvemos +houveram +houvera +houvéramos +haja +hajamos +hajam +houvesse +houvéssemos +houvessem +houver +houvermos +houverem +houverei +houverá +houveremos +houverão +houveria +houveríamos +houveriam +sou +somos +são +era +éramos +eram +fui +foi +fomos +foram +fora +fôramos +seja +sejamos +sejam +fosse +fôssemos +fossem +for +formos +forem +serei +será +seremos +serão +seria +seríamos +seriam +tenho +tem +temos +tém +tinha +tínhamos +tinham +tive +teve +tivemos +tiveram +tivera +tivéramos +tenha +tenhamos +tenham +tivesse +tivéssemos +tivessem +tiver +tivermos +tiverem +terei +terá +teremos +terão +teria +teríamos +teriam diff --git a/src/main/scala/opennlp/fieldspring/geolocate/CombinedModelCell.scala b/src/main/scala/opennlp/fieldspring/geolocate/CombinedModelCell.scala new file mode 100644 index 0000000..f608c7f --- /dev/null +++ b/src/main/scala/opennlp/fieldspring/geolocate/CombinedModelCell.scala @@ -0,0 +1,77 @@ +/////////////////////////////////////////////////////////////////////////////// +// CombinedModelCellGrid.scala +// +// Copyright (C) 2012 Stephen Roller, The University of Texas at Austin +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +/////////////////////////////////////////////////////////////////////////////// + +package opennlp.fieldspring.geolocate + +import opennlp.fieldspring.util.distances.spheredist +import opennlp.fieldspring.util.distances.SphereCoord +import opennlp.fieldspring.util.experiment._ +import opennlp.fieldspring.util.printutil.{errprint, warning} + +class CombinedModelCellGrid(table: SphereDocumentTable, + models: Seq[SphereCellGrid]) + extends SphereCellGrid(table) { + + override var total_num_cells: Int = models.map(_.total_num_cells).sum + override val num_training_passes: Int = models.map(_.num_training_passes).max + + var current_training_pass: Int = 0 + + override def begin_training_pass(pass: Int) = { + current_training_pass = pass + for (model <- models) { + if (pass <= model.num_training_passes) { + model.begin_training_pass(pass) + } + } + } + + def find_best_cell_for_document(doc: SphereDocument, + create_non_recorded: Boolean) = { + val candidates = + models.map(_.find_best_cell_for_document(doc, create_non_recorded)) + .filter(_ != null) + candidates.minBy((cell: SphereCell) => + spheredist(cell.get_center_coord, doc.coord)) + } + + def add_document_to_cell(document: SphereDocument) { + for (model <- models) { + if (current_training_pass <= model.num_training_passes) { + model.add_document_to_cell(document) + } + } + } + + def initialize_cells() { + } + + override def finish() { + for (model <- models) { + model.finish() + } + num_non_empty_cells = models.map(_.num_non_empty_cells).sum + } + + def iter_nonempty_cells(nonempty_word_dist: Boolean = false): Iterable[SphereCell] = { + models.map(_.iter_nonempty_cells(nonempty_word_dist)) + .reduce(_ ++ _) + } +} + + diff --git a/src/main/scala/opennlp/fieldspring/geolocate/GenerateKML.scala b/src/main/scala/opennlp/fieldspring/geolocate/GenerateKML.scala new file mode 100644 index 0000000..f9b42da --- /dev/null +++ b/src/main/scala/opennlp/fieldspring/geolocate/GenerateKML.scala @@ -0,0 +1,191 @@ +/////////////////////////////////////////////////////////////////////////////// +// GenerateKML.scala +// +// Copyright (C) 2010, 2011, 2012 Ben Wing, The University of Texas at Austin +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+/////////////////////////////////////////////////////////////////////////////// + +package opennlp.fieldspring.geolocate + +import java.io.{FileSystem=>_,_} + +import org.apache.hadoop.io._ + +import opennlp.fieldspring.util.argparser._ +import opennlp.fieldspring.util.experiment._ +import opennlp.fieldspring.util.printutil.{errprint, warning} + +import opennlp.fieldspring.gridlocate.DistDocument +import opennlp.fieldspring.gridlocate.GenericTypes._ + +import opennlp.fieldspring.worddist._ +import WordDist.memoizer._ + +class KMLParameters { + // Minimum and maximum colors + // FIXME: Allow these to be specified by command-line options + val kml_mincolor = Array(255.0, 255.0, 0.0) // yellow + val kml_maxcolor = Array(255.0, 0.0, 0.0) // red + + var kml_max_height: Double = _ + + var kml_transform: String = _ +} + +class GenerateKMLParameters( + parser: ArgParser = null +) extends GeolocateParameters(parser) { + //// Options used only in KML generation (--mode=generate-kml) + var kml_words = + ap.option[String]("k", "kml-words", "kw", + help = """Words to generate KML distributions for, when +--mode=generate-kml. Each word should be separated by a comma. A separate +file is generated for each word, using the value of '--kml-prefix' and adding +'.kml' to it.""") + // Same as above but a sequence + var split_kml_words:Seq[String] = _ + var kml_prefix = + ap.option[String]("kml-prefix", "kp", + default = "kml-dist.", + help = """Prefix to use for KML files outputted in --mode=generate-kml. +The actual filename is created by appending the word, and then the suffix +'.kml'. Default '%default'.""") + var kml_transform = + ap.option[String]("kml-transform", "kt", "kx", + default = "none", + choices = Seq("none", "log", "logsquared"), + help = """Type of transformation to apply to the probabilities +when generating KML (--mode=generate-kml), possibly to try and make the +low values more visible. Possibilities are 'none' (no transformation), +'log' (take the log), and 'logsquared' (negative of squared log). Default +'%default'.""") + var kml_max_height = + ap.option[Double]("kml-max-height", "kmh", + default = 2000000.0, + help = """Height of highest bar, in meters. Default %default.""") +} + + +/* A constructor that filters the distributions to contain only the words we + care about, to save memory and time. */ +class FilterUnigramWordDistConstructor( + factory: WordDistFactory, + filter_words: Seq[String], + ignore_case: Boolean, + stopwords: Set[String], + whitelist: Set[String], + minimum_word_count: Int = 1 + ) extends DefaultUnigramWordDistConstructor( + factory, ignore_case, stopwords, whitelist, minimum_word_count + ) { + + override def finish_before_global(dist: WordDist) { + super.finish_before_global(dist) + + val model = dist.asInstanceOf[UnigramWordDist].model + val oov = memoize_string("-OOV-") + + // Filter the words we don't care about, to save memory and time. 
+ for ((word, count) <- model.iter_items + if !(filter_words contains unmemoize_string(word))) { + model.remove_item(word) + model.add_item(oov, count) + } + } +} + +class WordCellTupleWritable extends + WritableComparable[WordCellTupleWritable] { + var word: String = _ + var index: RegularCellIndex = _ + + def set(word: String, index: RegularCellIndex) { + this.word = word + this.index = index + } + + def write(out: DataOutput) { + out.writeUTF(word) + out.writeInt(index.latind) + out.writeInt(index.longind) + } + + def readFields(in: DataInput) { + word = in.readUTF() + val latind = in.readInt() + val longind = in.readInt() + index = RegularCellIndex(latind, longind) + } + + // It hardly matters how we compare the cell indices. + def compareTo(other: WordCellTupleWritable) = + word.compareTo(other.word) +} + +class GenerateKMLDriver extends + GeolocateDriver with StandaloneExperimentDriverStats { + type TParam = GenerateKMLParameters + type TRunRes = Unit + + override def handle_parameters() { + super.handle_parameters() + need(params.kml_words, "kml-words") + params.split_kml_words = params.kml_words.split(',') + } + + override protected def initialize_word_dist_constructor( + factory: WordDistFactory) = { + if (word_dist_type != "unigram") + param_error("Only unigram word distributions supported with GenerateKML") + val the_stopwords = get_stopwords() + val the_whitelist = get_whitelist() + new FilterUnigramWordDistConstructor( + factory, + params.split_kml_words, + ignore_case = !params.preserve_case_words, + stopwords = the_stopwords, + whitelist = the_whitelist, + minimum_word_count = params.minimum_word_count) + } + + /** + * Do the actual KML generation. Some tracking info written to stderr. + * KML files created and written on disk. + */ + + def run_after_setup() { + val cdist_factory = new SphereCellDistFactory(params.lru_cache_size) + for (word <- params.split_kml_words) { + val celldist = cdist_factory.get_cell_dist(cell_grid, memoize_string(word)) + if (!celldist.normalized) { + warning("""Non-normalized distribution, apparently word %s not seen anywhere. +Not generating an empty KML file.""", word) + } else { + val kmlparams = new KMLParameters() + kmlparams.kml_max_height = params.kml_max_height + kmlparams.kml_transform = params.kml_transform + celldist.generate_kml_file("%s%s.kml" format (params.kml_prefix, word), + kmlparams) + } + } + } +} + +object GenerateKMLApp extends GeolocateApp("generate-kml") { + type TDriver = GenerateKMLDriver + // FUCKING TYPE ERASURE + def create_param_object(ap: ArgParser) = new TParam(ap) + def create_driver() = new TDriver() +} + diff --git a/src/main/scala/opennlp/fieldspring/geolocate/Geolocate.scala b/src/main/scala/opennlp/fieldspring/geolocate/Geolocate.scala new file mode 100644 index 0000000..aad6ad5 --- /dev/null +++ b/src/main/scala/opennlp/fieldspring/geolocate/Geolocate.scala @@ -0,0 +1,664 @@ +/////////////////////////////////////////////////////////////////////////////// +// Geolocate.scala +// +// Copyright (C) 2010, 2011, 2012 Ben Wing, The University of Texas at Austin +// Copyright (C) 2012 Mike Speriosu, The University of Texas at Austin +// Copyright (C) 2011 Stephen Roller, The University of Texas at Austin +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +/////////////////////////////////////////////////////////////////////////////// + +package opennlp.fieldspring.geolocate + +import util.matching.Regex +import util.Random +import math._ +import collection.mutable + +import opennlp.fieldspring.util.argparser._ +import opennlp.fieldspring.util.collectionutil._ +import opennlp.fieldspring.util.distances._ +import opennlp.fieldspring.util.experiment._ +import opennlp.fieldspring.util.ioutil.{FileHandler, LocalFileHandler} +import opennlp.fieldspring.util.osutil.output_resource_usage +import opennlp.fieldspring.util.printutil.errprint + +import opennlp.fieldspring.gridlocate._ +import GridLocateDriver.Debug._ + +import opennlp.fieldspring.worddist.{WordDist,WordDistFactory} +import opennlp.fieldspring.worddist.WordDist.memoizer._ + +/* + +This module is the main driver module for the Geolocate subproject. +The Geolocate subproject does document-level geolocation and is part +of Fieldspring. An underlying GridLocate framework is provided +for doing work of various sorts with documents that are amalgamated +into grids (e.g. over the Earth or over dates or times). This means +that code for Geolocate is split between `fieldspring.geolocate` and +`fieldspring.gridlocate`. See also GridLocate.scala. + +The Geolocate code works as follows: + +-- The main entry class is GeolocateDocumentApp. This is hooked into + GeolocateDocumentDriver. The driver classes implement the logic for + running the program -- in fact, this logic in the superclass + GeolocateDocumentTypeDriver so that a separate Hadoop driver can be + provided. The separate driver class is provided so that we can run + the geolocate app and other Fieldspring apps programmatically as well + as from the command line, and the complication of multiple driver + classes is (at least partly) due to supporting various apps, e.g. + GenerateKML (a separate app for generating KML files graphically + illustrating the corpus). The mechanism that implements the driver + class is in fieldspring.util.experiment. The actual entry point is + in ExperimentApp.main(), although the entire implementation is in + ExperimentApp.implement_main(). +-- The class GeolocateDocumentParameters holds descriptions of all of the + various command-line parameters, as well as the values of those + parameters when read from the command line (or alternatively, filled in + by another program using the programmatic interface). This inherits + from GeolocateParameters, which supplies parameters common to other + Fieldspring apps. Argument parsing is handled using + fieldspring.util.argparser, a custom argument-parsing package built on + top of Argot. +-- The driver class has three main methods. `handle_parameters` verifies + that valid combinations of parameters were specified. `setup_for_run` + creates some internal structures necessary for running, and + `run_after_setup` does the actual running. The reason for the separation + of the two is that only the former is used by the Hadoop driver. + (FIXME: Perhaps there's a better way of handling this.) 
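
The three-phase driver structure described in the last bullet above can be summarized schematically. The sketch below is illustrative only: the trait name and signatures are hypothetical stand-ins, not Fieldspring's actual ExperimentDriver API, and it is included purely to make the handle_parameters / setup_for_run / run_after_setup split concrete.

    // Schematic only: names and signatures here are hypothetical, not the
    // real Fieldspring classes.
    trait SketchDriver {
      /** Verify that a valid combination of parameters was specified. */
      def handle_parameters(): Unit

      /** Build the internal structures (document table, cell grid, ...)
       *  needed before any documents can be processed. */
      def setup_for_run(): Unit

      /** Do the actual run; assumes setup_for_run() has completed. */
      def run_after_setup(): Unit

      /** A standalone run chains all three phases; a Hadoop driver would
       *  stop after setup_for_run() and leave the rest to its mappers. */
      def run(): Unit = {
        handle_parameters()
        setup_for_run()
        run_after_setup()
      }
    }
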
+ +In order to support all the various command-line parameters, the logic for +doing geolocation is split up into various classes: + +-- Classes exist in `gridlocate` for an individual document (DistDocument), + the table of all documents (DistDocumentTable), the grid containing cells + into which the documents are placed (CellGrid), and the individual cells + in the grid (GeoCell). There also needs to be a class specifying a + coordinate identifying a document (e.g. time or latitude/longitude pair). + Specific versions of all of these are created for Geolocate, identified + by the word "Sphere" (SphereDocument, SphereCell, SphereCoord, etc.), + which is intended to indicate the fact that the grid refers to locations + on the surface of a sphere. +-- The cell grid class SphereGrid has subclasses for the different types of + grids (MultiRegularCellGrid, KDTreeCellGrid). +-- Different types of strategy objects (subclasses of + GeolocateDocumentStrategy, in turn a subclass of GridLocateDocumentStrategy) + implement the different inference methods specified using `--strategy`, + e.g. KLDivergenceStrategy or NaiveBayesDocumentStrategy. The driver method + `setup_for_run` creates the necessary strategy objects. +-- Evaluation is performed using different CellGridEvaluator objects, e.g. + RankedSphereCellGridEvaluator and MeanShiftSphereCellGridEvaluator. +*/ + +///////////////////////////////////////////////////////////////////////////// +// Evaluation strategies // +///////////////////////////////////////////////////////////////////////////// + +abstract class GeolocateDocumentStrategy( + sphere_grid: SphereCellGrid +) extends GridLocateDocumentStrategy[SphereCell, SphereCellGrid](sphere_grid) { } + +class CellDistMostCommonToponymGeolocateDocumentStrategy( + sphere_grid: SphereCellGrid +) extends GeolocateDocumentStrategy(sphere_grid) { + val cdist_factory = + new SphereCellDistFactory(sphere_grid.table.driver.params.lru_cache_size) + + def return_ranked_cells(_word_dist: WordDist, include: Iterable[SphereCell]) = { + val word_dist = UnigramStrategy.check_unigram_dist(_word_dist) + val wikipedia_table = sphere_grid.table.wikipedia_subtable + + // Look for a toponym, then a proper noun, then any word. + // FIXME: How can 'word' be null? + // FIXME: Use invalid_word + // FIXME: Should predicate be passed an index and have to do its own + // unmemoizing? + var maxword = word_dist.find_most_common_word( + word => word(0).isUpper && wikipedia_table.word_is_toponym(word)) + if (maxword == None) { + maxword = word_dist.find_most_common_word( + word => word(0).isUpper) + } + if (maxword == None) + maxword = word_dist.find_most_common_word(word => true) + cdist_factory.get_cell_dist(sphere_grid, maxword.get). 
+ get_ranked_cells(include) + } +} + +class LinkMostCommonToponymGeolocateDocumentStrategy( + sphere_grid: SphereCellGrid +) extends GeolocateDocumentStrategy(sphere_grid) { + def return_ranked_cells(_word_dist: WordDist, include: Iterable[SphereCell]) = { + val word_dist = UnigramStrategy.check_unigram_dist(_word_dist) + val wikipedia_table = sphere_grid.table.wikipedia_subtable + + var maxword = word_dist.find_most_common_word( + word => word(0).isUpper && wikipedia_table.word_is_toponym(word)) + if (maxword == None) { + maxword = word_dist.find_most_common_word( + word => wikipedia_table.word_is_toponym(word)) + } + if (debug("commontop")) + errprint(" maxword = %s", maxword) + val cands = + if (maxword != None) + wikipedia_table.construct_candidates( + unmemoize_string(maxword.get)) + else Seq[SphereDocument]() + if (debug("commontop")) + errprint(" candidates = %s", cands) + // Sort candidate list by number of incoming links + val candlinks = + (for (cand <- cands) yield (cand, + cand.asInstanceOf[WikipediaDocument].adjusted_incoming_links.toDouble)). + // sort by second element of tuple, in reverse order + sortWith(_._2 > _._2) + if (debug("commontop")) + errprint(" sorted candidates = %s", candlinks) + + def find_good_cells_for_coord(cands: Iterable[(SphereDocument, Double)]) = { + for { + (cand, links) <- candlinks + val cell = { + val retval = sphere_grid.find_best_cell_for_document(cand, false) + if (retval == null) + errprint("Strange, found no cell for candidate %s", cand) + retval + } + if (cell != null) + } yield (cell, links) + } + + // Convert to cells + val candcells = find_good_cells_for_coord(candlinks) + + if (debug("commontop")) + errprint(" cell candidates = %s", candcells) + + // Append random cells and remove duplicates + merge_numbered_sequences_uniquely(candcells, + new RandomGridLocateDocumentStrategy[SphereCell, SphereCellGrid](sphere_grid). + return_ranked_cells(word_dist, include)) + } +} + +class SphereAverageCellProbabilityStrategy( + sphere_grid: SphereCellGrid +) extends AverageCellProbabilityStrategy[ + SphereCell, SphereCellGrid +](sphere_grid) { + type TCellDistFactory = SphereCellDistFactory + def create_cell_dist_factory(lru_cache_size: Int) = + new SphereCellDistFactory(lru_cache_size) +} + +///////////////////////////////////////////////////////////////////////////// +// Main code // +///////////////////////////////////////////////////////////////////////////// + +/** + * Class retrieving command-line arguments or storing programmatic + * configuration parameters. + * + * @param parser If specified, should be a parser for retrieving the + * value of command-line arguments from the command line. Provided + * that the parser has been created and initialized by creating a + * previous instance of this same class with the same parser (a + * "shadow field" class), the variables below will be initialized with + * the values given by the user on the command line. Otherwise, they + * will be initialized with the default values for the parameters. + * Because they are vars, they can be freely set to other values. + * + */ +class GeolocateParameters(parser: ArgParser = null) extends + GridLocateParameters(parser) { + //// Options indicating how to generate the cells we compare against + var degrees_per_cell = + ap.option[Double]("degrees-per-cell", "dpc", metavar = "DEGREES", + default = 1.0, + help = """Size (in degrees, a floating-point number) of the tiling +cells that cover the Earth. Default %default. 
""") + var miles_per_cell = + ap.option[Double]("miles-per-cell", "mpc", metavar = "MILES", + help = """Size (in miles, a floating-point number) of the tiling +cells that cover the Earth. If given, it overrides the value of +--degrees-per-cell. No default, as the default of --degrees-per-cell +is used.""") + var km_per_cell = + ap.option[Double]("km-per-cell", "kpc", metavar = "KM", + help = """Size (in kilometers, a floating-point number) of the tiling +cells that cover the Earth. If given, it overrides the value of +--degrees-per-cell. No default, as the default of --degrees-per-cell +is used.""") + var width_of_multi_cell = + ap.option[Int]("width-of-multi-cell", metavar = "CELLS", default = 1, + help = """Width of the cell used to compute a statistical +distribution for geolocation purposes, in terms of number of tiling cells. +NOTE: It's unlikely you want to change this. It may be removed entirely in +later versions. In normal circumstances, the value is 1, i.e. use a single +tiling cell to compute each multi cell. If the value is more than +1, the multi cells overlap.""") + + //// Options for using KD trees, and related parameters + var kd_tree = + ap.flag("kd-tree", "kd", "kdtree", + help = """Specifies we should use a KD tree rather than uniform +grid cell.""") + + var kd_bucket_size = + ap.option[Int]("kd-bucket-size", "kdbs", "bucket-size", default = 200, + metavar = "INT", + help = """Bucket size before splitting a leaf into two children. +Default %default.""") + + var center_method = + ap.option[String]("center-method", "cm", metavar = "CENTER_METHOD", + default = "centroid", + choices = Seq("centroid", "center"), + help = """Chooses whether to use center or centroid for cell +center calculation. Options are either 'centroid' or 'center'. +Default '%default'.""") + + var kd_split_method = + ap.option[String]("kd-split-method", "kdsm", metavar = "SPLIT_METHOD", + default = "halfway", + choices = Seq("halfway", "median", "maxmargin"), + help = """Chooses which leaf-splitting method to use. Valid options are +'halfway', which splits into two leaves of equal degrees, 'median', which +splits leaves to have an equal number of documents, and 'maxmargin', +which splits at the maximum margin between two points. All splits are always +on the longest dimension. Default '%default'.""") + + var kd_use_backoff = + ap.flag("kd-backoff", "kd-use-backoff", + help = """Specifies if we should back off to larger cell distributions.""") + + var kd_interpolate_weight = + ap.option[Double]("kd-interpolate-weight", "kdiw", default = 0.0, + help = """Specifies the weight given to parent distributions. 
+Default value '%default' means no interpolation is used.""") + + //// Combining the kd-tree model with the cell-grid model + val combined_kd_grid = + ap.flag("combined-kd-grid", help = """Combine both the KD tree and +uniform grid cell models?""") + +} + +trait GeolocateDriver extends GridLocateDriver { + type TDoc = SphereDocument + type TCell = SphereCell + type TGrid = SphereCellGrid + type TDocTable = SphereDocumentTable + override type TParam <: GeolocateParameters + var degrees_per_cell = 0.0 + + override def handle_parameters() { + super.handle_parameters() + if (params.miles_per_cell < 0) + param_error("Miles per cell must be positive if specified") + if (params.km_per_cell < 0) + param_error("Kilometers per cell must be positive if specified") + if (params.degrees_per_cell < 0) + param_error("Degrees per cell must be positive if specified") + if (params.miles_per_cell > 0 && params.km_per_cell > 0) + param_error("Only one of --miles-per-cell and --km-per-cell can be given") + degrees_per_cell = + if (params.miles_per_cell > 0) + params.miles_per_cell / miles_per_degree + else if (params.km_per_cell > 0) + params.km_per_cell / km_per_degree + else + params.degrees_per_cell + if (params.width_of_multi_cell <= 0) + param_error("Width of multi cell must be positive") + } + + protected def initialize_document_table(word_dist_factory: WordDistFactory) = { + new SphereDocumentTable(this, word_dist_factory) + } + + protected def initialize_cell_grid(table: SphereDocumentTable) = { + if (params.combined_kd_grid) { + val kdcg = + KdTreeCellGrid(table, params.kd_bucket_size, params.kd_split_method, + params.kd_use_backoff, params.kd_interpolate_weight) + val mrcg = + new MultiRegularCellGrid(degrees_per_cell, + params.width_of_multi_cell, table) + new CombinedModelCellGrid(table, Seq(mrcg, kdcg)) + } else if (params.kd_tree) { + KdTreeCellGrid(table, params.kd_bucket_size, params.kd_split_method, + params.kd_use_backoff, params.kd_interpolate_weight) + } else { + new MultiRegularCellGrid(degrees_per_cell, + params.width_of_multi_cell, table) + } + } +} + +class GeolocateDocumentParameters( + parser: ArgParser = null +) extends GeolocateParameters(parser) { + var eval_format = + ap.option[String]("f", "eval-format", + default = "internal", + choices = Seq("internal", "raw-text", "pcl-travel"), + help = """Format of evaluation file(s). The evaluation files themselves +are specified using --eval-file. The following formats are +recognized: + +'internal' is the normal format. It means to consider documents to be +documents to evaluate, and to use the development or test set specified +in the document file as the set of documents to evaluate. There is +no eval file for this format. + +'raw-text' assumes that the eval file is simply raw text. (NOT YET +IMPLEMENTED.) + +'pcl-travel' is another alternative. 
It assumes that each evaluation file +is in PCL-Travel XML format, and uses each chapter in the evaluation +file as a document to evaluate.""") + + var strategy = + ap.multiOption[String]("s", "strategy", + default = Seq("partial-kl-divergence"), + aliasedChoices = Seq( + Seq("baseline"), + Seq("none"), + Seq("full-kl-divergence", "full-kldiv", "full-kl"), + Seq("partial-kl-divergence", "partial-kldiv", "partial-kl", "part-kl"), + Seq("symmetric-full-kl-divergence", "symmetric-full-kldiv", + "symmetric-full-kl", "sym-full-kl"), + Seq("symmetric-partial-kl-divergence", + "symmetric-partial-kldiv", "symmetric-partial-kl", "sym-part-kl"), + Seq("cosine-similarity", "cossim"), + Seq("partial-cosine-similarity", "partial-cossim", "part-cossim"), + Seq("smoothed-cosine-similarity", "smoothed-cossim"), + Seq("smoothed-partial-cosine-similarity", "smoothed-partial-cossim", + "smoothed-part-cossim"), + Seq("average-cell-probability", "avg-cell-prob", "acp"), + Seq("naive-bayes-with-baseline", "nb-base"), + Seq("naive-bayes-no-baseline", "nb-nobase")), + help = """Strategy/strategies to use for geolocation. +'baseline' means just use the baseline strategy (see --baseline-strategy). + +'none' means don't do any geolocation. Useful for testing the parts that +read in data and generate internal structures. + +'full-kl-divergence' (or 'full-kldiv') searches for the cell where the KL +divergence between the document and cell is smallest. + +'partial-kl-divergence' (or 'partial-kldiv') is similar but uses an +abbreviated KL divergence measure that only considers the words seen in the +document; empirically, this appears to work just as well as the full KL +divergence. + +'average-cell-probability' (or 'celldist') involves computing, for each word, +a probability distribution over cells using the word distribution of each cell, +and then combining the distributions over all words in a document, weighted by +the count the word in the document. + +'naive-bayes-with-baseline' and 'naive-bayes-no-baseline' use the Naive +Bayes algorithm to match a test document against a training document (e.g. +by assuming that the words of the test document are independent of each +other, if we are using a unigram word distribution). The "baseline" is +currently + +Default is 'partial-kl-divergence'. + +NOTE: Multiple --strategy options can be given, and each strategy will +be tried, one after the other.""") + + var coord_strategy = + ap.option[String]("coord-strategy", "cs", + default = "top-ranked", + choices = Seq("top-ranked", "mean-shift"), + help = """Strategy/strategies to use to choose the best coordinate for +a document. + +'top-ranked' means to choose the single best-ranked cell according to the +scoring strategy specified using '--strategy', and use its central point. + +'mean-shift' means to take the K best cells (according to '--k-best'), +and then compute a single point using the mean-shift algorithm. This +algorithm works by steadily shifting each point towards the others by +computing an average of the points surrounding a given point, weighted +by a function that drops off rapidly as the distance from the point +increases (specifically, the weighting is the same as for a Gaussian density, +with a parameter H, specified using '--mean-shift-window', that corresponds to +the standard deviation in the Gaussian distribution function). The idea is +that the points will eventually converge on the largest cluster within the +original points. 
The algorithm repeatedly moves the points closer to each +other until either the total standard deviation of the points (i.e. +approximately the average distance of the points from their mean) is less than +the value specified by '--mean-shift-max-stddev', or the number of iterations +exceeds '--mean-shift-max-iterations'. + +Default '%default'.""") + + var k_best = + ap.option[Int]("k-best", "kb", + default = 10, + help = """Value of K for use in the mean-shift algorithm +(see '--coord-strategy'). For this value of K, we choose the K best cells +and then apply the mean-shift algorithm to the central points of those cells. + +Default '%default'.""") + + var mean_shift_window = + ap.option[Double]("mean-shift-window", "msw", + default = 1.0, + help = """Window to use in the mean-shift algorithm +(see '--coord-strategy'). + +Default '%default'.""") + + var mean_shift_max_stddev = + ap.option[Double]("mean-shift-max-stddev", "msms", + default = 1e-10, + help = """Maximum allowed standard deviation (i.e. approximately the +average distance of the points from their mean) among the points selected by +the mean-shift algorithm (see '--coord-strategy'). + +Default '%default'.""") + + var mean_shift_max_iterations = + ap.option[Int]("mean-shift-max-iterations", "msmi", + default = 100, + help = """Maximum number of iterations in the mean-shift algorithm +(see '--coord-strategy'). + +Default '%default'.""") + + var baseline_strategy = + ap.multiOption[String]("baseline-strategy", "bs", + default = Seq("internal-link"), + aliasedChoices = Seq( + Seq("internal-link", "link"), + Seq("random"), + Seq("num-documents", "numdocs", "num-docs"), + Seq("link-most-common-toponym"), + Seq("cell-distribution-most-common-toponym", + "celldist-most-common-toponym")), + help = """Strategy to use to compute the baseline. + +'internal-link' (or 'link') means use number of internal links pointing to the +document or cell. + +'random' means choose randomly. + +'num-documents' (or 'num-docs' or 'numdocs'; only in cell-type matching) means +use number of documents in cell. + +'link-most-common-toponym' means to look for the toponym that occurs the +most number of times in the document, and then use the internal-link +baseline to match it to a location. + +'celldist-most-common-toponym' is similar, but uses the cell distribution +of the most common toponym. + +Default '%default'. + +NOTE: Multiple --baseline-strategy options can be given, and each strategy will +be tried, one after the other. Currently, however, the *-most-common-toponym +strategies cannot be mixed with other baseline strategies, or with non-baseline +strategies, since they require that --preserve-case-words be set internally.""") +} + +// FUCK ME. Have to make this abstract and GeolocateDocumentDriver a subclass +// so that the TParam can be overridden in HadoopGeolocateDocumentDriver. +trait GeolocateDocumentTypeDriver extends GeolocateDriver with + GridLocateDocumentDriver { + override type TParam <: GeolocateDocumentParameters + type TRunRes = + Seq[(String, GridLocateDocumentStrategy[SphereCell, SphereCellGrid], + CorpusEvaluator[_,_])] + + override def handle_parameters() { + super.handle_parameters() + + // The *-most-common-toponym strategies require case preserving + // (as if set by --preseve-case-words), while most other strategies want + // the opposite. So check to make sure we don't have a clash. 
+ if (params.strategy contains "baseline") { + var need_case = false + var need_no_case = false + for (bstrat <- params.baseline_strategy) { + if (bstrat.endsWith("most-common-toponym")) + need_case = true + else + need_no_case = true + } + if (need_case) { + if (params.strategy.length > 1 || need_no_case) { + // That's because we have to set --preserve-case-words, which we + // generally don't want set for other strategies and which affects + // the way we construct the training-document distributions. + param_error("Can't currently mix *-most-common-toponym baseline strategy with other strategies") + } + params.preserve_case_words = true + } + } + + if (params.eval_format == "raw-text") { + // FIXME!!!! + param_error("Raw-text reading not implemented yet") + } + + if (params.eval_format == "internal") { + if (params.eval_file.length > 0) + param_error("--eval-file should not be given when --eval-format=internal") + } else + need_seq(params.eval_file, "eval-file", "evaluation file(s)") + } + + override def create_strategy(stratname: String) = { + stratname match { + case "link-most-common-toponym" => + new LinkMostCommonToponymGeolocateDocumentStrategy(cell_grid) + case "celldist-most-common-toponym" => + new CellDistMostCommonToponymGeolocateDocumentStrategy(cell_grid) + case "average-cell-probability" => + new SphereAverageCellProbabilityStrategy(cell_grid) + case other => super.create_strategy(other) + } + } + + def create_strategies() = { + val strats_unflat = + for (stratname <- params.strategy) yield { + if (stratname == "baseline") { + for (basestratname <- params.baseline_strategy) yield { + val strategy = create_strategy(basestratname) + ("baseline " + basestratname, strategy) + } + } else { + val strategy = create_strategy(stratname) + Seq((stratname, strategy)) + } + } + strats_unflat.flatten filter { case (name, strat) => strat != null } + } + + /** + * Create the document evaluator object used to evaluate a given + * document. + * + * @param strategy Strategy object that implements the mechanism for + * scoring different pseudodocuments against a document. The + * strategy computes a ranked list of all pseudodocuments, with + * corresponding scores. The document evaluator then uses this to + * finish evaluating the document (e.g. picking the top-ranked one, + * applying the mean-shift algorithm, etc.). + * @param stratname Name of the strategy. + */ + def create_document_evaluator( + strategy: GridLocateDocumentStrategy[SphereCell, SphereCellGrid], + stratname: String) = { + // Generate reader object + if (params.eval_format == "pcl-travel") + new PCLTravelGeolocateDocumentEvaluator(strategy, stratname, this) + else if (params.coord_strategy == "top-ranked") + new RankedSphereCellGridEvaluator(strategy, stratname, this) + else + new MeanShiftSphereCellGridEvaluator(strategy, stratname, this, + params.k_best, params.mean_shift_window, + params.mean_shift_max_stddev, + params.mean_shift_max_iterations) + } + + /** + * Do the actual document geolocation. Results to stderr (see above), and + * also returned. 
+ * + * The current return type is as follows: + * + * Seq[(java.lang.String, GridLocateDocumentStrategy[SphereCell, SphereCellGrid], scala.collection.mutable.Map[evalobj.Document,opennlp.fieldspring.geolocate.EvaluationResult])] where val evalobj: opennlp.fieldspring.geolocate.CorpusEvaluator + * + * This means you get a sequence of tuples of + * (strategyname, strategy, results) + * where: + * strategyname = name of strategy as given on command line + * strategy = strategy object + * results = map listing results for each document (an abstract type + * defined in CorpusEvaluator; the result type EvaluationResult + * is practically an abstract type, too -- the most useful dynamic + * type in practice is DocumentEvaluationResult) + */ + def run_after_setup() = { + process_strategies(strategies)((stratname, strategy) => + create_document_evaluator(strategy, stratname)) + } +} + +class GeolocateDocumentDriver extends + GeolocateDocumentTypeDriver with StandaloneExperimentDriverStats { + override type TParam = GeolocateDocumentParameters +} + +abstract class GeolocateApp(appname: String) extends + GridLocateApp(appname) { + override type TDriver <: GeolocateDriver +} + +object GeolocateDocumentApp extends GeolocateApp("geolocate-document") { + type TDriver = GeolocateDocumentDriver + // FUCKING TYPE ERASURE + def create_param_object(ap: ArgParser) = new TParam(ap) + def create_driver() = new TDriver() +} + diff --git a/src/main/scala/opennlp/fieldspring/geolocate/Hadoop.scala b/src/main/scala/opennlp/fieldspring/geolocate/Hadoop.scala new file mode 100644 index 0000000..2512e71 --- /dev/null +++ b/src/main/scala/opennlp/fieldspring/geolocate/Hadoop.scala @@ -0,0 +1,469 @@ +/////////////////////////////////////////////////////////////////////////////// +// Hadoop.scala +// +// Copyright (C) 2011, 2012 Ben Wing, The University of Texas at Austin +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +/////////////////////////////////////////////////////////////////////////////// + +package opennlp.fieldspring.geolocate + +import collection.JavaConversions._ + +import org.apache.hadoop.io._ +import org.apache.hadoop.mapreduce._ +import org.apache.hadoop.mapreduce.lib.output.FileOutputFormat +import org.apache.hadoop.fs.Path + +import opennlp.fieldspring.util.argparser._ +import opennlp.fieldspring.util.distances._ +import opennlp.fieldspring.util.experiment.ExperimentMeteredTask +import opennlp.fieldspring.util.hadoop._ +import opennlp.fieldspring.util.ioutil.FileHandler +import opennlp.fieldspring.util.mathutil.{mean, median} +import opennlp.fieldspring.util.printutil.{errprint, warning} + +import opennlp.fieldspring.gridlocate.{CellGridEvaluator,FieldspringInfo,DistDocumentFileProcessor} + +/* Basic idea for hooking up Geolocate with Hadoop. Hadoop works in terms + of key-value pairs, as follows: + + (1) A preprocessor generates key-value pairs, which are passed to hadoop. 
Note that typically at this stage what's passed to Hadoop is not + in the form of a key-value pair but just some sort of item, e.g. a + line of text. This typically becomes the value of the key-value pair, + while something that most programs ignore becomes the key (e.g. the + item count of the item that was seen). Note that these keys and values, + as for all data passed around by Hadoop, are typed, and the type is + under the programmer's control. Hence, although the value is commonly + text, it may not be. + + (2) Hadoop creates a number of mappers and partitions the input items in + some way, passing some fraction of the input items to each mapper. + + (3) Each mapper iterates over its input items, and for each in turn, + generates a possibly-empty set of key-value output items. Note that + the types of the output keys and values may be totally different from + those of the input keys and values. The output key has much more + significance than the input key. + + (5) A "shuffle" step happens internally, where all output items are + grouped according to their keys, and the keys are further sorted. + (Or equivalently, the entire set of output items is sorted on their + keys, which will naturally place identical keys next to each other + as well as ensuring that non-identical keys are sorted, and then + sets of items with identical keys are transformed into single items + with the same key and a value consisting of a list of all the items + grouped together.) What actually happens is that the items are + sorted and grouped at the end of the map stage on the map node, + *before* being sent over the network; the overall "shuffle" then + simply involves merging. + + (4.5) To reduce the amount of data sent over the network, a combiner can + be defined, which runs on the map node after sorting but before + sending the data over. This is optional, and if it exists, does + a preliminary reduce. Depending on the task in question, this may + be exactly the same as the reducer. For example, if the reducer + simply adds up all of the items passed to it, the same function + can be used as a combiner, since a set of numbers can be added up + all at once or in parts. (An example of where this can't be done + is when the reducer needs to find the median of a set of items. + Computing the median involves selecting one of the items of a + set rather than mashing them all together, and which item is to + be selected cannot be known until the entire set is seen. Given + a subset, the median could be any value in the subset; hence the + entire subset must be sent along, and cannot in general be + "combined" in any way.) + + (6) A set of reducers are created, and the resulting grouped items are + partitioned based on their keys, with each reducer receiving one + sorted partition, i.e. a list of all the items whose keys were + assigned to that reducer, in sorted order, where the value of each + item (remember, items are key-value pairs) is a list of items + (all values associated with the same key in the items output by + the mapper). Each key is seen only once (assuming no crashes/restarts), + and only on a single reducer. The reducer then outputs its own + output pairs, typically by "reducing" the value list of each key + into a single item. + + (7) A post-processor might take these final output items and do something + with them.
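   (Editor's illustration of the combiner point in (4.5), as a plain-Scala
   sketch rather than actual Hadoop classes; the data values are made up:

     // Summing is safely combinable: merging partial sums gives the same
     // result as summing everything at once, so the reducer can double as
     // the combiner.
     def total(vals: Seq[Double]): Double = vals.sum

     // The median is not: the median of partial medians is in general not
     // the overall median, so the full value list must reach the reducer.
     def median(vals: Seq[Double]): Double = {
       val sorted = vals.sorted
       sorted(sorted.size / 2)
     }

     val data = Seq(1.0, 2.0, 3.0, 10.0, 20.0, 30.0)
     val parts = data.grouped(3).toSeq       // as if split across two mappers
     total(parts.map(p => total(p)))         // 66.0 == total(data)
     median(parts.map(p => median(p)))       // 20.0, but median(data) == 10.0
   )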
+ + Note about types: + + In general: + + MAP STAGE: + + Mapper input is of type A -> B + Mapper output is of type C -> D + Hence map() is of type (A -> B) -> Iterable[(C -> D)] + + Often types B and C are identical or related. + + COMBINE STAGE: + + The combiner is strictly an optimization, and the program must work + correctly regardless of whether the combiner is run or not -- or, for + that matter, if run multiple times. This means that the input and + output types of the combiner must be the same, and in most cases + the combiner must be idempotent (i.e. if its input is a previous + output, it should output is input unchanged; in other words, it + does nothing if run multiple times on the same input). + + Combiner input is of type C -> Iterable[D] + Combiner output is of type C -> Iterable[D] + Hence combine() is of type (C -> Iterable[D]) -> (C -> Iterable[D]) + + (The output of the cominber is grouped, just like its input from the + map output.) + + REDUCE STAGE: + + Reducer input is of type C -> Iterable[D] + Reducer output is of type E -> F + Hence reduce() is of type (C -> Iterable[D]) -> (E -> F) + + In our case, we assume that the mappers() do the real work and the + reducers just collect the stats and combine them. We can break a + big job in two ways: Either by partitioning the set of test documents + and having each mapper do a full evaluation on a limited number of + test documents, or by partitioning the grid and have each mapper + compare all test documents against a portion of the grid. A third + possibility is to combine both, where a mapper does a portion of + the test documents against a portion of the grid. + + OUR IMPLEMENTATION: + + Input values to map() are tuples (strategy, document). Output items + are have key = (cellgrid-details, strategy), value = result for + particular document (includes various items, including document, + predicted cell, true rank, various distances). No combiner, since + we have to compute a median, meaning we need all values. Reducer + computes mean/median for all values for a given cellgrid/strategy. + NOTE: For identifying a particular cell, we use indices, since we + can't pass pointers. For KD trees and such, we conceivably might have + to pass in to the reducer some complex details identifying the + cell grid parameters. If so, this probably would get passed first + to all reducers using the trick of creating a custom partitioner + that ensures the reducer gets this info first. +*/ + +/************************************************************************/ +/* General Hadoop code for Geolocate app */ +/************************************************************************/ + +abstract class HadoopGeolocateApp( + progname: String +) extends GeolocateApp(progname) with HadoopTextDBApp { + override type TDriver <: HadoopGeolocateDriver + + def corpus_suffix = + driver.params.eval_set + "-" + driver.document_file_suffix + def corpus_dirs = params.input_corpus + + override def initialize_hadoop_input(job: Job) { + super.initialize_hadoop_input(job) + FileOutputFormat.setOutputPath(job, new Path(params.outfile)) + } +} + +trait HadoopGeolocateParameters extends GeolocateParameters { + var fieldspring_dir = + ap.option[String]("fieldspring-dir", + help = """Directory to use in place of FIELDSPRING_DIR environment +variable (e.g. in Hadoop).""") + + var outfile = + ap.positional[String]("outfile", + help = """File to store evaluation results in.""") + +} + +/** + * Base mix-in for a Geolocate application using Hadoop. 
+ * + * @see HadoopGeolocateDriver + */ + +trait HadoopGeolocateDriver extends + GeolocateDriver with HadoopExperimentDriver { + override type TParam <: HadoopGeolocateParameters + + override def handle_parameters() { + super.handle_parameters() + need(params.fieldspring_dir, "fieldspring-dir") + FieldspringInfo.set_fieldspring_dir(params.fieldspring_dir) + } +} + +/************************************************************************/ +/* Hadoop implementation of geolocate-document */ +/************************************************************************/ + +class DocumentEvaluationMapper extends + Mapper[Object, Text, Text, DoubleWritable] with + HadoopExperimentMapReducer { + def progname = HadoopGeolocateDocumentApp.progname + type TContext = Mapper[Object, Text, Text, DoubleWritable]#Context + type TDriver = HadoopGeolocateDocumentDriver + // more type erasure crap + def create_param_object(ap: ArgParser) = new TParam(ap) + def create_driver() = new TDriver + + var evaluators: Iterable[CellGridEvaluator[SphereCoord,SphereDocument,_,_,_]] = _ + val task = new ExperimentMeteredTask(driver, "document", "evaluating") + + class HadoopDocumentFileProcessor( + context: TContext + ) extends DistDocumentFileProcessor( + driver.params.eval_set + "-" + driver.document_file_suffix, driver + ) { + override def get_shortfile = + filename_to_counter_name(driver.get_file_handler, + driver.get_configuration.get("mapred.input.dir")) + + /* #### FIXME!!! Need to redo things so that different splits are + separated into different files. */ + def handle_document(fieldvals: Seq[String]) = { + val table = driver.document_table + val doc = table.create_and_init_document(schema, fieldvals, false) + val retval = if (doc != null) { + doc.dist.finish_after_global() + var skipped = 0 + var not_skipped = 0 + for (e <- evaluators) { + val num_processed = task.num_processed + val doctag = "#%d" format (1 + num_processed) + if (e.would_skip_document(doc, doctag)) { + skipped += 1 + errprint("Skipped document %s because evaluator would skip it", + doc) + } else { + not_skipped += 1 + // Don't put side-effecting code inside of an assert! + val result = + e.evaluate_document(doc, doctag) + assert(result != null) + context.write(new Text(e.stratname), + new DoubleWritable(result.asInstanceOf[SphereDocumentEvaluationResult].pred_truedist)) + task.item_processed() + } + context.progress + } + if (skipped > 0 && not_skipped > 0) + warning("""Something strange: %s evaluator(s) skipped document, but %s evaluator(s) +didn't skip. Usually all or none should skip.""", skipped, not_skipped) + (not_skipped > 0) + } else false + context.progress + (retval, true) + } + + def process_lines(lines: Iterator[String], + filehand: FileHandler, file: String, + compression: String, realname: String) = + throw new IllegalStateException( + "process_lines should never be called here") + } + + var processor: HadoopDocumentFileProcessor = _ + override def init(context: TContext) { + super.init(context) + if (driver.params.eval_format != "internal") + driver.params.parser.error( + "For Hadoop, '--eval-format' must be 'internal'") + else { + evaluators = + for ((stratname, strategy) <- driver.strategies) + yield driver.create_document_evaluator(strategy, stratname). 
+ asInstanceOf[CellGridEvaluator[ + SphereCoord,SphereDocument,_,_,_]] + if (driver.params.input_corpus.length != 1) { + driver.params.parser.error( + "FIXME: For Hadoop, currently need exactly one corpus") + } else { + processor = new HadoopDocumentFileProcessor(context) + processor.read_schema_from_textdb(driver.get_file_handler, + driver.params.input_corpus(0)) + context.progress + } + } + } + + override def setup(context: TContext) { init(context) } + + override def map(key: Object, value: Text, context: TContext) { + processor.parse_row(value.toString) + context.progress + } +} + +class DocumentResultReducer extends + Reducer[Text, DoubleWritable, Text, DoubleWritable] { + + type TContext = Reducer[Text, DoubleWritable, Text, DoubleWritable]#Context + + var driver: HadoopGeolocateDocumentDriver = _ + + override def setup(context: TContext) { + driver = new HadoopGeolocateDocumentDriver + driver.set_task_context(context) + } + + override def reduce(key: Text, values: java.lang.Iterable[DoubleWritable], + context: TContext) { + val errordists = (for (v <- values) yield v.get).toSeq + val mean_dist = mean(errordists) + val median_dist = median(errordists) + context.write(new Text(key.toString + " mean"), new DoubleWritable(mean_dist)) + context.write(new Text(key.toString + " median"), new DoubleWritable(median_dist)) + } +} + +class HadoopGeolocateDocumentParameters( + parser: ArgParser = null +) extends GeolocateDocumentParameters(parser) with HadoopGeolocateParameters { +} + +/** + * Class for running the geolocate-document app using Hadoop. + */ + +class HadoopGeolocateDocumentDriver extends + GeolocateDocumentTypeDriver with HadoopGeolocateDriver { + override type TParam = HadoopGeolocateDocumentParameters +} + +object HadoopGeolocateDocumentApp extends + HadoopGeolocateApp("Fieldspring geolocate-document") { + type TDriver = HadoopGeolocateDocumentDriver + // FUCKING TYPE ERASURE + def create_param_object(ap: ArgParser) = new TParam(ap) + def create_driver() = new TDriver() + + def initialize_hadoop_classes(job: Job) { + job.setJarByClass(classOf[DocumentEvaluationMapper]) + job.setMapperClass(classOf[DocumentEvaluationMapper]) + job.setReducerClass(classOf[DocumentResultReducer]) + job.setOutputKeyClass(classOf[Text]) + job.setOutputValueClass(classOf[DoubleWritable]) + } +} + +// Old code. Probably won't ever be needed. If we feel the need to move +// to more complex types when serializing, we should switch to Avro rather +// than reinventing the wheel. + +// /** +// * Hadoop has a standard Writable class but it isn't so good for us, since +// * it assumes its read method +// */ +// trait HadoopGeolocateWritable[T] { +// def write(out: DataOutput): Unit +// def read(in: DataInput): T +// } +// +// /** +// * Class for writing out in a format suitable for Hadoop. Implements +// Hadoop's Writable interface. 
Because the +// */ +// +// abstract class RecordWritable() extends WritableComparable[RecordWritable] { +// } + +/* + +abstract class ObjectConverter { + type Type + type TWritable <: Writable + def makeWritable(): TWritable + def toWritable(obj: Type, w: TWritable) + def fromWritable(w: TWritable): obj +} + +object IntConverter { + type Type = Int + type TWritable = IntWritable + + def makeWritable() = new IntWritable + def toWritable(obj: Int, w: IntWritable) { w.set(obj) } + def fromWritable(w: TWritable) = w.get +} + +abstract class RecordWritable( + fieldtypes: Seq[Class] +) extends WritableComparable[RecordWritable] { + type Type + + var obj: Type = _ + var obj_set: Boolean = false + + def set(xobj: Type) { + obj = xobj + obj_set = true + } + + def get() = { + assert(obj_set) + obj + } + + def write(out: DataOutput) {} + def readFields(in: DataInput) {} + + val writables = new Array[Writable](fieldtypes.length) +} + + +object SphereDocumentConverter extends RecordWriterConverter { + type Type = SphereDocument + + def serialize(doc: SphereDocument) = doc.title + def deserialize(title: String) = FIXME + + def init() { + RecordWriterConverter.register_converter(SphereDocument, this) + } +} + + +class DocumentEvaluationResultWritable extends RecordWritable { + type Type = DocumentEvaluationResult + def to_properties(obj: Type) = + Seq(obj.document, obj.pred_cell, obj.true_rank, + obj.true_cell, obj.num_docs_in_true_cell, + obj.true_center, obj.true_truedist, obj.true_degdist, + obj.pred_center, obj.pred_truedist, obj.pred_degdist) + def from_properties(props: Seq[Any]) = { + val Seq(document, pred_cell, true_rank, + true_cell, num_docs_in_true_cell, + true_center, true_truedist, true_degdist, + pred_center, pred_truedist, pred_degdist) = props + new HadoopDocumentEvaluationResult( + document.asInstanceOf[SphereDocument], + pred_cell.asInstanceOf[GeoCell], + true_rank.asInstanceOf[Int], + true_cell.asInstanceOf[GeoCell], + num_docs_in_true_cell.asInstanceOf[Int], + true_center.asInstanceOf[SphereCoord], + true_truedist.asInstanceOf[Double], + true_degdist.asInstanceOf[Double], + pred_center.asInstanceOf[SphereCoord], + pred_truedist.asInstanceOf[Double], + pred_degdist.asInstanceOf[Double] + ) + } +} + +*/ diff --git a/src/main/scala/opennlp/fieldspring/geolocate/KDTreeCell.scala b/src/main/scala/opennlp/fieldspring/geolocate/KDTreeCell.scala new file mode 100644 index 0000000..6885721 --- /dev/null +++ b/src/main/scala/opennlp/fieldspring/geolocate/KDTreeCell.scala @@ -0,0 +1,224 @@ +/////////////////////////////////////////////////////////////////////////////// +// KDTreeCellGrid.scala +// +// Copyright (C) 2011, 2012 Stephen Roller, The University of Texas at Austin +// Copyright (C) 2011 Ben Wing, The University of Texas at Austin +// Copyright (C) 2011 Mike Speriosu, The University of Texas at Austin +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+/////////////////////////////////////////////////////////////////////////////// + +package opennlp.fieldspring.geolocate + +import scala.collection.JavaConversions._ +import scala.collection.mutable.Map + +import ags.utils.KdTree + +import opennlp.fieldspring.util.distances.SphereCoord +import opennlp.fieldspring.util.experiment._ +import opennlp.fieldspring.util.printutil.{errprint, warning} + +import opennlp.fieldspring.worddist.UnigramWordDist + +class KdTreeCell( + cellgrid: KdTreeCellGrid, + val kdleaf : KdTree +) extends RectangularCell(cellgrid) { + + def get_northeast_coord () : SphereCoord = { + new SphereCoord(kdleaf.minLimit(0), kdleaf.minLimit(1)) + } + + def get_southwest_coord () : SphereCoord = { + new SphereCoord(kdleaf.maxLimit(0), kdleaf.maxLimit(1)) + } + + def describe_indices () : String = { + "Placeholder" + } + + def describe_location () : String = { + get_boundary.toString + " (Center: " + get_center_coord + ")" + } +} + +object KdTreeCellGrid { + def apply(table: SphereDocumentTable, bucketSize: Int, splitMethod: String, + useBackoff: Boolean, interpolateWeight: Double = 0.0) : KdTreeCellGrid = { + new KdTreeCellGrid(table, bucketSize, splitMethod match { + case "halfway" => KdTree.SplitMethod.HALFWAY + case "median" => KdTree.SplitMethod.MEDIAN + case "maxmargin" => KdTree.SplitMethod.MAX_MARGIN + }, useBackoff, interpolateWeight) + } +} + +class KdTreeCellGrid(table: SphereDocumentTable, + bucketSize: Int, + splitMethod: KdTree.SplitMethod, + useBackoff: Boolean, + interpolateWeight: Double) + extends SphereCellGrid(table) { + /** + * Total number of cells in the grid. + */ + var total_num_cells: Int = 0 + var kdtree: KdTree = new KdTree(2, bucketSize, splitMethod) + + val nodes_to_cell: Map[KdTree, KdTreeCell] = Map() + val leaves_to_cell: Map[KdTree, KdTreeCell] = Map() + + override val num_training_passes: Int = 2 + var current_training_pass: Int = 0 + + override def begin_training_pass(pass: Int) = { + current_training_pass = pass + + if (pass == 1) { + // do nothing + } else if (pass == 2) { + // we've seen all the coordinates. we need to build up + // the entire kd-tree structure now, the centroids, and + // clean out the data. + + val task = new ExperimentMeteredTask(table.driver, "K-d tree structure", + "generating") + + // build the full kd-tree structure. + kdtree.balance + + for (node <- kdtree.getNodes) { + val c = new KdTreeCell(this, node) + nodes_to_cell.update(node, c) + task.item_processed() + } + task.finish() + + // no longer need to keep all our locations in memory. destroy + // them. to free up memory. + kdtree.annihilateData + } else { + // definitely should not get here + assert(false); + } + } + + def find_best_cell_for_document(doc: SphereDocument, + create_non_recorded: Boolean) = { + // FIXME: implementation note: the KD tree should tile the entire earth's surface, + // but there's a possibility of something going awry here if we've never + // seen a evaluation point before. + leaves_to_cell(kdtree.getLeaf(Array(doc.coord.lat, doc.coord.long))) + } + + /** + * Add the given document to the cell grid. 
+ */ + def add_document_to_cell(document: SphereDocument) { + if (current_training_pass == 1) { + kdtree.addPoint(Array(document.coord.lat, document.coord.long)) + } else if (current_training_pass == 2) { + val leaf = kdtree.getLeaf(Array(document.coord.lat, document.coord.long)) + var n = leaf + while (n != null) { + nodes_to_cell(n).add_document(document) + n = n.parent; + } + } else { + assert(false) + } + } + + /** + * Generate all non-empty cells. This will be called once (and only once), + * after all documents have been added to the cell grid by calling + * `add_document_to_cell`. The generation happens internally; but after + * this, `iter_nonempty_cells` should work properly. + */ + def initialize_cells() { + total_num_cells = kdtree.getLeaves.size + num_non_empty_cells = total_num_cells + + // need to finish generating all the word distributions + for (c <- nodes_to_cell.valuesIterator) { + c.finish() + } + + if (interpolateWeight > 0) { + // this is really gross. we need to interpolate all the nodes + // by modifying their worddist.count map. We are breaking + // so many levels of abstraction by doing this AND preventing + // us from using interpolation with bigrams :( + // + + // We'll do it top-down so dependencies are met. + + val iwtopdown = true + val nodes = + if (iwtopdown) kdtree.getNodes.toList + else kdtree.getNodes.reverse + + val task = new ExperimentMeteredTask(table.driver, "K-d tree cell", + "interpolating") + for (node <- nodes if node.parent != null) { + val cell = nodes_to_cell(node) + val wd = cell.combined_dist.word_dist + val model = wd.asInstanceOf[UnigramWordDist].model + + for ((k,v) <- model.iter_items) { + model.set_item(k, (1 - interpolateWeight) * v) + } + + val pcell = nodes_to_cell(node.parent) + val pwd = pcell.combined_dist.word_dist + val pmodel = pwd.asInstanceOf[UnigramWordDist].model + + for ((k,v) <- pmodel.iter_items) { + val oldv = if (model contains k) model.get_item(k) else 0.0 + val newv = oldv + interpolateWeight * v + if (newv > interpolateWeight) + model.set_item(k, newv) + } + + task.item_processed() + } + task.finish() + } + + // here we need to drop nonleaf nodes unless backoff is enabled. + val nodes = if (useBackoff) kdtree.getNodes else kdtree.getLeaves + for (node <- nodes) { + leaves_to_cell.update(node, nodes_to_cell(node)) + } + } + + /** + * Iterate over all non-empty cells. + * + * @param nonempty_word_dist If given, returned cells must also have a + * non-empty word distribution; otherwise, they just need to have at least + * one document in them. (Not all documents have word distributions, esp. + * when --max-time-per-stage has been set to a non-zero value so that we + * only load some subset of the word distributions for all documents. But + * even when not set, some documents may be listed in the document-data file + * but have no corresponding word counts given in the counts file.) 
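// --------------------------------------------------------------------------
// Editor's aside: the interpolation loop in initialize_cells() above
// effectively rewrites each non-root node's unigram counts, top-down, as
//
//   count'(w) = (1 - interpolateWeight) * count(w)
//                 + interpolateWeight * parentCount'(w)
//
// where parentCount' is the parent's already-interpolated count, plus a
// pruning wrinkle: when folding in the parent's mass, an entry is only
// stored if its new value exceeds interpolateWeight. A plain-Map sketch of
// the core step (hypothetical helper, ignoring that pruning threshold):

def interpolateWithParent(child: scala.collection.Map[String, Double],
    parent: scala.collection.Map[String, Double], lambda: Double) =
  (child.keySet ++ parent.keySet).map { k =>
    k -> ((1 - lambda) * child.getOrElse(k, 0.0) +
          lambda * parent.getOrElse(k, 0.0))
  }.toMap
// --------------------------------------------------------------------------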
+ */ + def iter_nonempty_cells(nonempty_word_dist: Boolean = false): Iterable[SphereCell] = { + val nodes = if (useBackoff) kdtree.getNodes else kdtree.getLeaves + for (leaf <- nodes + if (leaf.size() > 0 || !nonempty_word_dist)) + yield leaves_to_cell(leaf) + } +} + diff --git a/src/main/scala/opennlp/fieldspring/geolocate/MultiRegularCell.scala b/src/main/scala/opennlp/fieldspring/geolocate/MultiRegularCell.scala new file mode 100644 index 0000000..b06fc7c --- /dev/null +++ b/src/main/scala/opennlp/fieldspring/geolocate/MultiRegularCell.scala @@ -0,0 +1,554 @@ +/////////////////////////////////////////////////////////////////////////////// +// MultiRegularCell.scala +// +// Copyright (C) 2010, 2011 Ben Wing, The University of Texas at Austin +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +/////////////////////////////////////////////////////////////////////////////// + +package opennlp.fieldspring.geolocate + +import math._ +import collection.mutable + +import opennlp.fieldspring.util.collectionutil._ +import opennlp.fieldspring.util.distances._ +import opennlp.fieldspring.util.printutil.{errout, errprint} +import opennlp.fieldspring.util.experiment._ + +import opennlp.fieldspring.gridlocate.GridLocateDriver.Debug._ + +///////////////////////////////////////////////////////////////////////////// +// A regularly spaced grid // +///////////////////////////////////////////////////////////////////////////// + +/* + We divide the earth's surface into "tiling cells", all of which are the + same square size, running on latitude/longitude lines, and which have a + constant number of degrees on a size, set using the value of the command- + line option --degrees-per-cell. (Alternatively, the value of --miles-per-cell + or --km-per-cell are converted into degrees using 'miles_per_degree' or + 'km_per_degree', respectively, which specify the size of a degree at + the equator and is derived from the value for the Earth's radius.) + + In addition, we form a square of tiling cells in order to create a + "multi cell", which is used to compute a distribution over words. The + number of tiling cells on a side of a multi cell is determined by + --width-of-multi-cell. Note that if this is greater than 1, different + multi cells will overlap. + + To specify a cell, we use cell indices, which are derived from + coordinates by dividing by degrees_per_cell. Hence, if for example + degrees_per_cell is 2.0, then cell indices are in the range [-45,+45] + for latitude and [-90,+90) for longitude. Correspondingly, to convert + cell index to a SphereCoord, we multiply latitude and longitude by + degrees_per_cell. + + In general, an arbitrary coordinate will have fractional cell indices; + however, the cell indices of the corners of a cell (tiling or multi) + will be integers. Cells are canonically indexed and referred to by + the index of the southwest corner. 
In other words, for a given index, + the latitude or longitude of the southwest corner of the corresponding + cell (tiling or multi) is index*degrees_per_cell. For a tiling cell, + the cell includes all coordinates whose latitude or longitude is in + the half-open interval [index*degrees_per_cell, + (index+1)*degrees_per_cell). For a multi cell, the cell includes all + coordinates whose latitude or longitude is in the half-open interval + [index*degrees_per_cell, (index+width_of_multi_cell)*degrees_per_cell). + + Near the edges, tiling cells may be truncated. Multi cells will + wrap around longitudinally, and will still have the same number of + tiling cells, but may be smaller. + */ + +/** + * The index of a regular cell, using "cell index" integers, as described + * above. + */ +case class RegularCellIndex(latind: Int, longind: Int) { + def toFractional() = FractionalRegularCellIndex(latind, longind) +} + +object RegularCellIndex { + /* SCALABUG: Why do I need to specify RegularCellIndex as the return type + here? And why is this not required in the almost identical construction + in SphereCoord? I get this error (not here, but where the object is + created): + + [error] /Users/benwing/devel/fieldspring/src/main/scala/opennlp/fieldspring/geolocate/MultiRegularCell.scala:273: overloaded method apply needs result type + [error] RegularCellIndex(latind, longind) + [error] ^ + */ + def apply(cell_grid: MultiRegularCellGrid, latind: Int, longind: Int): + RegularCellIndex = { + require(valid(cell_grid, latind, longind)) + new RegularCellIndex(latind, longind) + } + + def valid(cell_grid: MultiRegularCellGrid, latind: Int, longind: Int) = ( + latind >= cell_grid.minimum_latind && + latind <= cell_grid.maximum_latind && + longind >= cell_grid.minimum_longind && + longind <= cell_grid.maximum_longind + ) + + def coerce_indices(cell_grid: MultiRegularCellGrid, latind: Int, + longind: Int) = { + var newlatind = latind + var newlongind = longind + if (newlatind > cell_grid.maximum_latind) + newlatind = cell_grid.maximum_latind + while (newlongind > cell_grid.maximum_longind) + newlongind -= (cell_grid.maximum_longind - cell_grid.minimum_longind + 1) + if (newlatind < cell_grid.minimum_latind) + newlatind = cell_grid.minimum_latind + while (newlongind < cell_grid.minimum_longind) + newlongind += (cell_grid.maximum_longind - cell_grid.minimum_longind + 1) + (newlatind, newlongind) + } + + def coerce(cell_grid: MultiRegularCellGrid, latind: Int, longind: Int) = { + val (newlatind, newlongind) = coerce_indices(cell_grid, latind, longind) + apply(cell_grid, newlatind, newlongind) + } +} + +/** + * Similar to `RegularCellIndex`, but for the case where the indices are + * fractional, representing a location other than at the corners of a + * cell. + */ +case class FractionalRegularCellIndex(latind: Double, longind: Double) { +} + +/** + * A cell where the cell grid is a MultiRegularCellGrid. (See that class.) + * + * @param cell_grid The CellGrid object for the grid this cell is in, + * an instance of MultiRegularCellGrid. 
+ * @param index Index of (the southwest corner of) this cell in the grid + */ + +class MultiRegularCell( + cell_grid: MultiRegularCellGrid, + val index: RegularCellIndex +) extends RectangularCell(cell_grid) { + + def get_southwest_coord() = + cell_grid.multi_cell_index_to_near_corner_coord(index) + + def get_northeast_coord() = + cell_grid.multi_cell_index_to_far_corner_coord(index) + + def describe_location() = { + "%s-%s" format (get_southwest_coord(), get_northeast_coord()) + } + + def describe_indices() = "%s,%s" format (index.latind, index.longind) + + /** + * For a given multi cell, iterate over the tiling cells in the multi cell. + * The return values are the indices of the southwest corner of each + * tiling cell. + */ + def iterate_tiling_cells() = { + // Be careful around the edges -- we need to truncate the latitude and + // wrap the longitude. The call to `coerce()` will automatically + // wrap the longitude, but we need to truncate the latitude ourselves, + // or else we'll end up repeating cells. + val max_offset = cell_grid.width_of_multi_cell - 1 + val maxlatind = cell_grid.maximum_latind min (index.latind + max_offset) + + for ( + i <- index.latind to maxlatind; + j <- index.longind to (index.longind + max_offset) + ) yield RegularCellIndex.coerce(cell_grid, i, j) + } +} + +/** + * Grid composed of possibly-overlapping multi cells, based on an underlying + * grid of regularly-spaced square cells tiling the earth. The multi cells, + * over which word distributions are computed for comparison with the word + * distribution of a given document, are composed of NxN tiles, where possibly + * N > 1. + * + * FIXME: We should abstract out the concept of a grid composed of tiles and + * a grid composed of overlapping conglomerations of tiles; this could be + * useful e.g. for KD trees or other representations where we might want to + * compare with cells at multiple levels of granularity. + * + * @param degrees_per_cell Size of each cell in degrees. Determined by the + * --degrees-per-cell option, unless --miles-per-cell is set, in which + * case it takes priority. + * @param width_of_multi_cell Size of multi cells in tiling cells, + * determined by the --width-of-multi-cell option. + */ +class MultiRegularCellGrid( + val degrees_per_cell: Double, + val width_of_multi_cell: Int, + table: SphereDocumentTable +) extends SphereCellGrid(table) { + + /** + * Size of each cell (vertical dimension; horizontal dimension only near + * the equator) in km. Determined from degrees_per_cell. + */ + val km_per_cell = degrees_per_cell * km_per_degree + + /* Set minimum, maximum latitude/longitude in indices (integers used to + index the set of cells that tile the earth). The actual maximum + latitude is exactly 90 (the North Pole). But if we set degrees per + cell to be a number that exactly divides 180, and we use + maximum_latitude = 90 in the following computations, then we would + end up with the North Pole in a cell by itself, something we probably + don't want. + */ + val maximum_index = + coord_to_tiling_cell_index(SphereCoord(maximum_latitude - 1e-10, + maximum_longitude)) + val maximum_latind = maximum_index.latind + val maximum_longind = maximum_index.longind + val minimum_index = + coord_to_tiling_cell_index(SphereCoord(minimum_latitude, minimum_longitude)) + val minimum_latind = minimum_index.latind + val minimum_longind = minimum_index.longind + + /** + * Mapping from index of southwest corner of multi cell to corresponding + * cell object. 
A "multi cell" is made up of a square of tiling cells, + * with the number of cells on a side determined by `width_of_multi_cell'. + * A word distribution is associated with each multi cell. + * + * We don't just create an array because we expect many cells to have no + * documents in them, esp. as we decrease the cell size. + */ + val corner_to_multi_cell = mutable.Map[RegularCellIndex, MultiRegularCell]() + + var total_num_cells = 0 + + /********** Conversion between Cell indices and SphereCoords **********/ + + /* The different functions vary depending on where in the particular cell + the SphereCoord is wanted, e.g. one of the corners or the center. */ + + /** + * Convert a coordinate to the indices of the southwest corner of the + * corresponding tiling cell. + */ + def coord_to_tiling_cell_index(coord: SphereCoord) = { + val latind = floor(coord.lat / degrees_per_cell).toInt + val longind = floor(coord.long / degrees_per_cell).toInt + RegularCellIndex(latind, longind) + } + + /** + * Convert a coordinate to the indices of the southwest corner of the + * corresponding multi cell. Note that if `width_of_multi_cell` > 1, + * there will be more than one multi cell containing the coordinate. + * In that case, we want the multi cell in which the coordinate is most + * centered. (For example, if `width_of_multi_cell` = 3, then each multi + * cell has 9 tiling cells in it, only one of which is in the center. + * A given coordinate will belong to only one tiling cell, and we want + * the multi cell which has that tiling cell in its center.) + */ + def coord_to_multi_cell_index(coord: SphereCoord) = { + // When width_of_multi_cell = 1, don't subtract anything. + // When width_of_multi_cell = 2, subtract 0.5*degrees_per_cell. + // When width_of_multi_cell = 3, subtract degrees_per_cell. + // When width_of_multi_cell = 4, subtract 1.5*degrees_per_cell. + // In general, subtract (width_of_multi_cell-1)/2.0*degrees_per_cell. + + // Compute the indices of the southwest cell + val subval = (width_of_multi_cell - 1) / 2.0 * degrees_per_cell + coord_to_tiling_cell_index( + SphereCoord(coord.lat - subval, coord.long - subval)) + } + + /** + * Convert a fractional cell index to the corresponding coordinate. Useful + * for indices not referring to the corner of a cell. + * + * @see #cell_index_to_coord + */ + def fractional_cell_index_to_coord(index: FractionalRegularCellIndex, + method: String = "coerce-warn") = { + SphereCoord(index.latind * degrees_per_cell, + index.longind * degrees_per_cell, method) + } + + /** + * Convert cell indices to the corresponding coordinate. This can also + * be used to find the coordinate of the southwest corner of a tiling cell + * or multi cell, as both are identified by the cell indices of + * their southwest corner. + */ + def cell_index_to_coord(index: RegularCellIndex, + method: String = "coerce-warn") = + fractional_cell_index_to_coord(index.toFractional, method) + + /** + * Add 'offset' to both latind and longind of 'index' and then convert to a + * coordinate. Coerce the coordinate to be within bounds. + */ + def offset_cell_index_to_coord(index: RegularCellIndex, + offset: Double) = { + fractional_cell_index_to_coord( + FractionalRegularCellIndex(index.latind + offset, index.longind + offset), + "coerce") + } + + /** + * Convert cell indices of a tiling cell to the coordinate of the + * near (i.e. southwest) corner of the cell. 
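// Editor's worked example for the conversions above (values chosen purely
// for illustration): with degrees_per_cell = 2.0 and width_of_multi_cell = 2,
// the coordinate (30.5, -96.3) gives
//
//   coord_to_tiling_cell_index -> RegularCellIndex(15, -49)
//     (floor(30.5/2.0) = 15, floor(-96.3/2.0) = -49), i.e. the tiling cell
//     whose southwest corner is (30.0, -98.0) and which covers the
//     half-open box lat [30, 32) x long [-98, -96); its center
//     (offset 0.5 in each index) is (31.0, -97.0);
//
//   coord_to_multi_cell_index -> RegularCellIndex(14, -49)
//     (1.0 degree, i.e. (width_of_multi_cell - 1)/2 * degrees_per_cell, is
//     subtracted from each coordinate first), i.e. the 2x2-tile multi cell
//     covering lat [28, 32) x long [-98, -94), the one in which the
//     coordinate is most nearly centered.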
+ */ + def tiling_cell_index_to_near_corner_coord(index: RegularCellIndex) = { + cell_index_to_coord(index) + } + + /** + * Convert cell indices of a tiling cell to the coordinate of the + * center of the cell. + */ + def tiling_cell_index_to_center_coord(index: RegularCellIndex) = { + offset_cell_index_to_coord(index, 0.5) + } + + /** + * Convert cell indices of a tiling cell to the coordinate of the + * far (i.e. northeast) corner of the cell. + */ + def tiling_cell_index_to_far_corner_coord(index: RegularCellIndex) = { + offset_cell_index_to_coord(index, 1.0) + } + /** + * Convert cell indices of a tiling cell to the coordinate of the + * near (i.e. southwest) corner of the cell. + */ + def multi_cell_index_to_near_corner_coord(index: RegularCellIndex) = { + cell_index_to_coord(index) + } + + /** + * Convert cell indices of a multi cell to the coordinate of the + * center of the cell. + */ + def multi_cell_index_to_center_coord(index: RegularCellIndex) = { + offset_cell_index_to_coord(index, width_of_multi_cell / 2.0) + } + + /** + * Convert cell indices of a multi cell to the coordinate of the + * far (i.e. northeast) corner of the cell. + */ + def multi_cell_index_to_far_corner_coord(index: RegularCellIndex) = { + offset_cell_index_to_coord(index, width_of_multi_cell) + } + + /** + * Convert cell indices of a multi cell to the coordinate of the + * northwest corner of the cell. + */ + def multi_cell_index_to_nw_corner_coord(index: RegularCellIndex) = { + cell_index_to_coord( + RegularCellIndex.coerce(this, index.latind + width_of_multi_cell, + index.longind)) + } + + /** + * Convert cell indices of a multi cell to the coordinate of the + * southeast corner of the cell. + */ + def multi_cell_index_to_se_corner_coord(index: RegularCellIndex) = { + cell_index_to_coord( + RegularCellIndex.coerce(this, index.latind, + index.longind + width_of_multi_cell)) + } + + /** + * Convert cell indices of a multi cell to the coordinate of the + * southwest corner of the cell. + */ + def multi_cell_index_to_sw_corner_coord(index: RegularCellIndex) = { + multi_cell_index_to_near_corner_coord(index) + } + + /** + * Convert cell indices of a multi cell to the coordinate of the + * northeast corner of the cell. + */ + def multi_cell_index_to_ne_corner_coord(index: RegularCellIndex) = { + multi_cell_index_to_far_corner_coord(index) + } + + /*************** End conversion functions *************/ + + /** + * For a given coordinate, iterate over the multi cells containing the + * coordinate. This first finds the tiling cell containing the + * coordinate and then finds the multi cells containing the tiling cell. + * The returned values are the indices of the (southwest corner of the) + * multi cells. + */ + def iterate_overlapping_multi_cells(coord: SphereCoord) = { + // The logic is almost exactly the same as in iterate_tiling_cells() + // except that the offset is negative. + val index = coord_to_tiling_cell_index(coord) + // In order to handle coordinates near the edges of the grid, we need to + // truncate the latitude ourselves, but coerce() handles the longitude + // wrapping. See iterate_tiling_cells(). 
+ val max_offset = width_of_multi_cell - 1 + val minlatind = minimum_latind max (index.latind - max_offset) + + for ( + i <- minlatind to index.latind; + j <- (index.longind - max_offset) to index.longind + ) yield RegularCellIndex.coerce(this, i, j) + } + + def find_best_cell_for_document(doc: SphereDocument, + create_non_recorded: Boolean) = { + assert(all_cells_computed) + val index = coord_to_multi_cell_index(doc.coord) + find_cell_for_cell_index(index, create = create_non_recorded, + record_created_cell = false) + } + + /** + * For a given multi cell index, find the corresponding cell. + * If no such cell exists, create one if `create` is true; + * else, return null. If a cell is created, record it in the + * grid if `record_created_cell` is true. + */ + protected def find_cell_for_cell_index(index: RegularCellIndex, + create: Boolean, record_created_cell: Boolean) = { + val cell = corner_to_multi_cell.getOrElse(index, null) + if (cell != null) + cell + else if (!create) null + else { + val newcell = new MultiRegularCell(this, index) + if (record_created_cell) { + num_non_empty_cells += 1 + corner_to_multi_cell(index) = newcell + } + newcell + } + } + + /** + * Add the document to the cell(s) it belongs to. This finds all the + * multi cells, creating them as necessary, and adds the document to each. + */ + def add_document_to_cell(doc: SphereDocument) { + for (index <- iterate_overlapping_multi_cells(doc.coord)) { + val cell = find_cell_for_cell_index(index, create = true, + record_created_cell = true) + if (debug("cell")) + errprint("Adding document %s to cell %s", doc, cell) + cell.add_document(doc) + } + } + + protected def initialize_cells() { + val task = new ExperimentMeteredTask(table.driver, "Earth-tiling cell", + "generating non-empty") + + for (i <- minimum_latind to maximum_latind view) { + for (j <- minimum_longind to maximum_longind view) { + total_num_cells += 1 + val cell = find_cell_for_cell_index(RegularCellIndex(i, j), + create = false, record_created_cell = false) + if (cell != null) { + cell.finish() + if (debug("cell")) + errprint("--> (%d,%d): %s", i, j, cell) + } + task.item_processed() + } + } + task.finish() + } + + def iter_nonempty_cells(nonempty_word_dist: Boolean = false) = { + assert(all_cells_computed) + for { + v <- corner_to_multi_cell.values + val empty = ( + if (nonempty_word_dist) v.combined_dist.is_empty_for_word_dist() + else v.combined_dist.is_empty()) + if (!empty) + } yield v + } + + /** + * Output a "ranking grid" of information so that a nice 3-D graph + * can be created showing the ranks of cells surrounding the true + * cell, out to a certain distance. + * + * @param pred_cells List of predicted cells, along with their scores. + * @param true_cell True cell. + * @param grsize Total size of the ranking grid. (For example, a total size + * of 21 will result in a ranking grid with the true cell and 10 + * cells on each side shown.) 
+ */ + def output_ranking_grid(pred_cells: Iterable[(MultiRegularCell, Double)], + true_cell: MultiRegularCell, grsize: Int) { + val (true_latind, true_longind) = + (true_cell.index.latind, true_cell.index.longind) + val min_latind = true_latind - grsize / 2 + val max_latind = min_latind + grsize - 1 + val min_longind = true_longind - grsize / 2 + val max_longind = min_longind + grsize - 1 + val grid = mutable.Map[RegularCellIndex, (MultiRegularCell, Double, Int)]() + for (((cell, score), rank) <- pred_cells zip (1 to pred_cells.size)) { + val (la, lo) = (cell.index.latind, cell.index.longind) + if (la >= min_latind && la <= max_latind && + lo >= min_longind && lo <= max_longind) + // FIXME: This assumes KL-divergence or similar scores, which have + // been negated to make larger scores better. + grid(cell.index) = (cell, -score, rank) + } + + errprint("Grid ranking, gridsize %dx%d", grsize, grsize) + errprint("NW corner: %s", + multi_cell_index_to_nw_corner_coord( + RegularCellIndex.coerce(this, max_latind, min_longind))) + errprint("SE corner: %s", + multi_cell_index_to_se_corner_coord( + RegularCellIndex.coerce(this, min_latind, max_longind))) + for (doit <- Seq(0, 1)) { + if (doit == 0) + errprint("Grid for ranking:") + else + errprint("Grid for goodness/distance:") + for (lat <- max_latind to min_latind) { + for (long <- fromto(min_longind, max_longind)) { + val cellvalrank = + grid.getOrElse(RegularCellIndex.coerce(this, lat, long), null) + if (cellvalrank == null) + errout(" %-8s", "empty") + else { + val (cell, value, rank) = cellvalrank + val showit = if (doit == 0) rank else value + if (lat == true_latind && long == true_longind) + errout("!%-8.6s", showit) + else + errout(" %-8.6s", showit) + } + } + errout("\n") + } + } + } +} diff --git a/src/main/scala/opennlp/fieldspring/geolocate/SphereCell.scala b/src/main/scala/opennlp/fieldspring/geolocate/SphereCell.scala new file mode 100644 index 0000000..914a1f8 --- /dev/null +++ b/src/main/scala/opennlp/fieldspring/geolocate/SphereCell.scala @@ -0,0 +1,236 @@ +/////////////////////////////////////////////////////////////////////////////// +// SphereCell.scala +// +// Copyright (C) 2011 Ben Wing, The University of Texas at Austin +// Copyright (C) 2012 Stephen Roller, The University of Texas at Austin +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +/////////////////////////////////////////////////////////////////////////////// + +package opennlp.fieldspring.geolocate + +import opennlp.fieldspring.util.distances._ + +import opennlp.fieldspring.gridlocate.{GeoCell,CellGrid} + +///////////////////////////////////////////////////////////////////////////// +// Cells in a grid // +///////////////////////////////////////////////////////////////////////////// + +abstract class SphereCell( + cell_grid: SphereCellGrid +) extends GeoCell[SphereCoord, SphereDocument](cell_grid) { + /** + * Generate KML for a single cell. 
+ */ + def generate_kml(xfprob: Double, xf_minprob: Double, xf_maxprob: Double, + params: KMLParameters): Iterable[xml.Elem] +} + +/** + * A cell in a polygonal shape. + * + * @param cell_grid The CellGrid object for the grid this cell is in. + */ +abstract class PolygonalCell( + cell_grid: SphereCellGrid +) extends SphereCell(cell_grid) { + /** + * Return the boundary of the cell as an Iterable of coordinates, tracing + * out the boundary vertex by vertex. The last coordinate should be the + * same as the first, as befits a closed shape. + */ + def get_boundary(): Iterable[SphereCoord] + + /** + * Return the "inner boundary" -- something echoing the actual boundary of the + * cell but with smaller dimensions. Used for outputting KML to make the + * output easier to read. + */ + def get_inner_boundary() = { + val center = get_center_coord() + for (coord <- get_boundary()) + yield SphereCoord((center.lat + coord.lat) / 2.0, + average_longitudes(center.long, coord.long)) + } + + /** + * Generate the KML placemark for the cell's name. Currently it's rectangular + * for rectangular cells. FIXME: Perhaps it should be generalized so it doesn't + * need to be redefined for differently-shaped cells. + * + * @param name The name to display in the placemark + */ + def generate_kml_name_placemark(name: String): xml.Elem + + def generate_kml(xfprob: Double, xf_minprob: Double, xf_maxprob: Double, + params: KMLParameters) = { + val offprob = xfprob - xf_minprob + val fracprob = offprob / (xf_maxprob - xf_minprob) + var coordtext = "\n" + for (coord <- get_inner_boundary()) { + coordtext += "%s,%s,%s\n" format ( + coord.long, coord.lat, fracprob * params.kml_max_height) + } + val name = + if (most_popular_document != null) most_popular_document.title + else "" + + // Placemark indicating name + val name_placemark = generate_kml_name_placemark(name) + + // Interpolate colors + val color = Array(0.0, 0.0, 0.0) + for (i <- 0 until 3) { + color(i) = (params.kml_mincolor(i) + + fracprob * (params.kml_maxcolor(i) - params.kml_mincolor(i))) + } + // Original color dc0155ff + //rgbcolor = "dc0155ff" + val revcol = color.reverse + val rgbcolor = "ff%02x%02x%02x" format ( + revcol(0).toInt, revcol(1).toInt, revcol(2).toInt) + + // Yield cylinder indicating probability by height and color + + // !!PY2SCALA: BEGIN_PASSTHRU + val cylinder_placemark = + + { "%s POLYGON" format name } + #bar + + + 1 + 1 + relativeToGround + + + { coordtext } + + + + + // !!PY2SCALA: END_PASSTHRU + Seq(name_placemark, cylinder_placemark) + } +} + +/** + * A cell in a rectangular shape. + * + * @param cell_grid The CellGrid object for the grid this cell is in. + */ +abstract class RectangularCell( + cell_grid: SphereCellGrid +) extends PolygonalCell(cell_grid) { + /** + * Return the coordinate of the southwest point of the rectangle. + */ + def get_southwest_coord(): SphereCoord + /** + * Return the coordinate of the northeast point of the rectangle. + */ + def get_northeast_coord(): SphereCoord + + /** + * Define the center based on the southwest and northeast points, + * or based on the centroid of the cell. 
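 * (Editor's note -- a concrete illustration with made-up numbers: when
 * params.center_method selects the centroid and the cell has seen two
 * documents at (10.0, 10.0) and (20.0, 30.0), get_center_coord() below
 * returns their running average (15.0, 20.0), regardless of where the
 * cell's corners lie; with the "center" method, or for an empty cell, it
 * falls back to the midpoint of the southwest and northeast corners.)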
+ */ + var centroid: Array[Double] = new Array[Double](2) + var num_docs: Int = 0 + + def get_center_coord() = { + if (num_docs == 0 || cell_grid.table.driver.params.center_method == "center") { + // use the actual cell center + // also, if we have an empty cell, there is no such thing as + // a centroid, so default to the center + val sw = get_southwest_coord() + val ne = get_northeast_coord() + SphereCoord((sw.lat + ne.lat) / 2.0, (sw.long + ne.long) / 2.0) + } else { + // use the centroid + SphereCoord(centroid(0) / num_docs, centroid(1) / num_docs); + } + } + + override def add_document(document: SphereDocument) { + num_docs += 1 + centroid(0) += document.coord.lat + centroid(1) += document.coord.long + super.add_document(document) + } + + + + /** + * Define the boundary given the specified southwest and northeast + * points. + */ + def get_boundary() = { + val sw = get_southwest_coord() + val ne = get_northeast_coord() + val center = get_center_coord() + val nw = SphereCoord(ne.lat, sw.long) + val se = SphereCoord(sw.lat, ne.long) + Seq(sw, nw, ne, se, sw) + } + + /** + * Generate the name placemark as a smaller rectangle within the + * larger rectangle. (FIXME: Currently it is exactly the size of + * the inner boundary. Perhaps this should be generalized, so + * that the definition of this function can be handled up at the + * polygonal-shaped-cell level.) + */ + def generate_kml_name_placemark(name: String) = { + val sw = get_southwest_coord() + val ne = get_northeast_coord() + val center = get_center_coord() + // !!PY2SCALA: BEGIN_PASSTHRU + // Because it tries to frob the # sign + + { name } + , + + + { ((center.lat + ne.lat) / 2).toString } + { ((center.lat + sw.lat) / 2).toString } + { ((center.long + ne.long) / 2).toString } + { ((center.long + sw.long) / 2).toString } + + + 16 + + + #bar + + { "%s,%s" format (center.long, center.lat) } + + + // !!PY2SCALA: END_PASSTHRU + } +} + +/** + * Abstract class for a grid of cells covering the earth. + */ +abstract class SphereCellGrid( + override val table: SphereDocumentTable +) extends CellGrid[SphereCoord, SphereDocument, SphereCell](table) { +} + diff --git a/src/main/scala/opennlp/fieldspring/geolocate/SphereCellDist.scala b/src/main/scala/opennlp/fieldspring/geolocate/SphereCellDist.scala new file mode 100644 index 0000000..bc05a76 --- /dev/null +++ b/src/main/scala/opennlp/fieldspring/geolocate/SphereCellDist.scala @@ -0,0 +1,129 @@ +/////////////////////////////////////////////////////////////////////////////// +// SphereCellDist.scala +// +// Copyright (C) 2011 Ben Wing, The University of Texas at Austin +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+/////////////////////////////////////////////////////////////////////////////// + +package opennlp.fieldspring.geolocate + +import math._ + +import opennlp.fieldspring.util.distances._ + +import opennlp.fieldspring.gridlocate.{WordCellDist,CellDistFactory} +import opennlp.fieldspring.worddist.WordDist.memoizer._ + +///////////////////////////////////////////////////////////////////////////// +// Cell distributions // +///////////////////////////////////////////////////////////////////////////// + +/** + * This is the Sphere-specific version of `WordCellDist`, which is used for + * a distribution over cells that is associated with a single word. This + * is used in particular for the strategies and Geolocate applications that + * need to invert the per-cell word distributions to obtain a per-word cell + * distribution: Specifically, the GenerateKML app (which generates a KML + * file showing the cell distribution of a given word across the Earth); + * the `average-cell-probability` strategy (which uses the inversion strategy + * to obtain a distribution of cells for each word in a document, and then + * averages them); and the `most-common-toponym` baseline strategy (which + * similarly uses the inversion strategy but obtains a single distribution of + * cells for the most common toponym in the document, and uses this + * distribution directly to generate the list of ranked cells). + + * Instances of this class are normally generated by the + * `SphereCellDistFactory` class, which handles caching of common words + * (which may get requested multiple times across a set of documents). + * The above cases generally create a factory and then request individual + * `SphereWordCellDist` objects using `get_cell_dist`, which may call + * `create_word_cell_dist` to create the actual `SphereWordCellDist`. 
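 * (Editor's sketch of the presumed calling pattern -- not taken from the
 * actual apps; the exact get_cell_dist signature lives in CellDistFactory,
 * and memoize_string is assumed to be the counterpart of the
 * unmemoize_string helper imported in this file:
 *
 *   val factory  = new SphereCellDistFactory(lru_cache_size = 400)
 *   val celldist = factory.get_cell_dist(cell_grid, memoize_string("austin"))
 *   celldist.generate_kml_file("austin.kml", kml_params)
 * )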
+ * + * @param word Word for which the cell is computed + * @param cellprobs Hash table listing probabilities associated with cells + */ + +class SphereWordCellDist( + cell_grid: SphereCellGrid, + word: Word +) extends WordCellDist[SphereCoord, SphereDocument, SphereCell]( + cell_grid, word) { + // Convert cell to a KML file showing the distribution + def generate_kml_file(filename: String, params: KMLParameters) { + val xform = if (params.kml_transform == "log") (x: Double) => log(x) + else if (params.kml_transform == "logsquared") (x: Double) => -log(x) * log(x) + else (x: Double) => x + + val xf_minprob = xform(cellprobs.values min) + val xf_maxprob = xform(cellprobs.values max) + + def yield_cell_kml() = { + for { + (cell, prob) <- cellprobs + kml <- cell.generate_kml(xform(prob), xf_minprob, xf_maxprob, params) + expr <- kml + } yield expr + } + + val allcellkml = yield_cell_kml() + + val kml = + + + + + + { unmemoize_string(word) } + 1 + { "Cell distribution for word '%s'" format unmemoize_string(word) } + + 42 + -102 + 0 + 5000000 + 53.454348562403 + 0 + + { allcellkml } + + + + + xml.XML.save(filename, kml) + } +} + +class SphereCellDistFactory( + lru_cache_size: Int +) extends CellDistFactory[SphereCoord, SphereDocument, SphereCell]( + lru_cache_size) { + type TCellDist = SphereWordCellDist + type TGrid = SphereCellGrid + def create_word_cell_dist(cell_grid: TGrid, word: Word) = + new TCellDist(cell_grid, word) +} + diff --git a/src/main/scala/opennlp/fieldspring/geolocate/SphereDocument.scala b/src/main/scala/opennlp/fieldspring/geolocate/SphereDocument.scala new file mode 100644 index 0000000..485a725 --- /dev/null +++ b/src/main/scala/opennlp/fieldspring/geolocate/SphereDocument.scala @@ -0,0 +1,190 @@ +/////////////////////////////////////////////////////////////////////////////// +// SphereDocument.scala +// +// Copyright (C) 2011, 2012 Ben Wing, The University of Texas at Austin +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+/////////////////////////////////////////////////////////////////////////////// + +package opennlp.fieldspring.geolocate + +import collection.mutable + +import opennlp.fieldspring.util.distances._ +import opennlp.fieldspring.util.textdbutil.Schema +import opennlp.fieldspring.util.printutil.warning + +import opennlp.fieldspring.gridlocate.{DistDocument,DistDocumentTable} +import opennlp.fieldspring.gridlocate.DistDocumentConverters._ + +import opennlp.fieldspring.worddist.WordDistFactory + +abstract class SphereDocument( + schema: Schema, + table: SphereDocumentTable +) extends DistDocument[SphereCoord](schema, table) { + var coord: SphereCoord = _ + def has_coord = coord != null + + override def set_field(name: String, value: String) { + name match { + case "coord" => coord = get_x_or_null[SphereCoord](value) + case _ => super.set_field(name, value) + } + } + + def distance_to_coord(coord2: SphereCoord) = spheredist(coord, coord2) + def degree_distance_to_coord(coord2: SphereCoord) = degree_dist(coord, coord2) + def output_distance(dist: Double) = km_and_miles(dist) +} + +/** + * A subtable holding SphereDocuments corresponding to a specific corpus + * type (e.g. Wikipedia or Twitter). + */ +abstract class SphereDocumentSubtable[TDoc <: SphereDocument]( + val table: SphereDocumentTable +) { + /** + * Create and return a document of the current type. + */ + def create_document(schema: Schema): TDoc + + /** + * Given the schema and field values read from a document file, create + * and return a document. Return value can be null if the document is + * to be skipped; otherwise, it will be recorded in the appropriate split. + */ + def create_and_init_document(schema: Schema, fieldvals: Seq[String], + record_in_table: Boolean) = { + val doc = create_document(schema) + if (doc != null) + doc.set_fields(fieldvals) + doc + } + + /** + * Do any subtable-specific operations needed after all documents have + * been loaded. + */ + def finish_document_loading() { } +} + +/** + * A DistDocumentTable specifically for documents with coordinates described + * by a SphereCoord (latitude/longitude coordinates on the Earth). + * We delegate the actual document creation to a subtable specific to the + * type of corpus (e.g. Wikipedia or Twitter). + */ +class SphereDocumentTable( + override val driver: GeolocateDriver, + word_dist_factory: WordDistFactory +) extends DistDocumentTable[SphereCoord, SphereDocument, SphereCellGrid]( + driver, word_dist_factory +) { + val corpus_type_to_subtable = + mutable.Map[String, SphereDocumentSubtable[_ <: SphereDocument]]() + + def register_subtable(corpus_type: String, subtable: + SphereDocumentSubtable[_ <: SphereDocument]) { + corpus_type_to_subtable(corpus_type) = subtable + } + + register_subtable("wikipedia", new WikipediaDocumentSubtable(this)) + register_subtable("twitter-tweet", new TwitterTweetDocumentSubtable(this)) + register_subtable("twitter-user", new TwitterUserDocumentSubtable(this)) + register_subtable("generic", new GenericSphereDocumentSubtable(this)) + + def wikipedia_subtable = + corpus_type_to_subtable("wikipedia").asInstanceOf[WikipediaDocumentSubtable] + + def create_document(schema: Schema): SphereDocument = { + throw new UnsupportedOperationException("This shouldn't be called directly; instead, use create_and_init_document()") + } + + override def imp_create_and_init_document(schema: Schema, + fieldvals: Seq[String], record_in_table: Boolean) = { + find_subtable(schema, fieldvals). 
+ create_and_init_document(schema, fieldvals, record_in_table) + } + + /** + * Find the subtable for the field values of a document as read from + * from a document file. Currently this simply locates the 'corpus-type' + * parameter and calls `find_subtable(java.lang.String)` to find + * the appropriate table. + */ + def find_subtable(schema: Schema, fieldvals: Seq[String]): + SphereDocumentSubtable[_ <: SphereDocument] = { + val cortype = schema.get_field_or_else(fieldvals, "corpus-type", "generic") + find_subtable(cortype) + } + + /** + * Find the document table for a given corpus type. + */ + def find_subtable(cortype: String) = { + if (corpus_type_to_subtable contains cortype) + corpus_type_to_subtable(cortype) + else { + warning("Unrecognized corpus type: %s", cortype) + corpus_type_to_subtable("generic") + } + } + + /** + * Iterate over all the subtables that exist. + */ + def iterate_subtables() = corpus_type_to_subtable.values + + override def finish_document_loading() { + for (subtable <- iterate_subtables()) + subtable.finish_document_loading() + super.finish_document_loading() + } +} + +/** + * A generic SphereDocument for when the corpus type is missing or + * unrecognized. (FIXME: Do we really need this? Should we just throw an + * error or ignore it?) + */ +class GenericSphereDocument( + schema: Schema, + subtable: GenericSphereDocumentSubtable +) extends SphereDocument(schema, subtable.table) { + var title: String = _ + + override def set_field(name: String, value: String) { + name match { + case "title" => title = value + case _ => super.set_field(name, value) + } + } + + def struct = + + { title } + { + if (has_coord) + { coord } + } + +} + +class GenericSphereDocumentSubtable( + table: SphereDocumentTable +) extends SphereDocumentSubtable[GenericSphereDocument](table) { + def create_document(schema: Schema) = + new GenericSphereDocument(schema, this) +} diff --git a/src/main/scala/opennlp/fieldspring/geolocate/SphereEvaluation.scala b/src/main/scala/opennlp/fieldspring/geolocate/SphereEvaluation.scala new file mode 100644 index 0000000..cfa9d65 --- /dev/null +++ b/src/main/scala/opennlp/fieldspring/geolocate/SphereEvaluation.scala @@ -0,0 +1,447 @@ +/////////////////////////////////////////////////////////////////////////////// +// SphereEvaluation.scala +// +// Copyright (C) 2010, 2011, 2012 Ben Wing, The University of Texas at Austin +// Copyright (C) 2012 Mike Speriosu, The University of Texas at Austin +// Copyright (C) 2011 Stephen Roller, The University of Texas at Austin +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+/////////////////////////////////////////////////////////////////////////////// + +package opennlp.fieldspring.geolocate + +import math.{round, floor} +import collection.mutable +import util.control.Breaks._ + +import opennlp.fieldspring.util.collectionutil.{DoubleTableByRange} +import opennlp.fieldspring.util.distances._ +import opennlp.fieldspring.util.experiment.ExperimentDriverStats +import opennlp.fieldspring.util.mathutil.{mean, median} +import opennlp.fieldspring.util.ioutil.{FileHandler} +import opennlp.fieldspring.util.osutil.output_resource_usage +import opennlp.fieldspring.util.printutil.{errprint, warning} +import opennlp.fieldspring.util.textutil.split_text_into_words + +import opennlp.fieldspring.gridlocate._ +import opennlp.fieldspring.gridlocate.GridLocateDriver.Debug._ + +///////////////////////////////////////////////////////////////////////////// +// General statistics on evaluation results // +///////////////////////////////////////////////////////////////////////////// + +//////// Statistics for geolocating documents + +/** + * A general trait for encapsulating SphereDocument-specific behavior. + * In this case, this is largely the computation of "degree distances" in + * addition to "true distances", and making sure results are output in + * miles and km. + */ +trait SphereDocumentEvalStats extends DocumentEvalStats { + // "True dist" means actual distance in km's or whatever. + // "Degree dist" is the distance in degrees. + val degree_dists = mutable.Buffer[Double]() + val oracle_degree_dists = mutable.Buffer[Double]() + + def record_predicted_degree_distance(pred_degree_dist: Double) { + degree_dists += pred_degree_dist + } + + def record_oracle_degree_distance(oracle_degree_dist: Double) { + oracle_degree_dists += oracle_degree_dist + } + + protected def output_result_with_units(kmdist: Double) = km_and_miles(kmdist) + + override def output_incorrect_results() { + super.output_incorrect_results() + errprint(" Mean degree error distance = %.2f degrees", + mean(degree_dists)) + errprint(" Median degree error distance = %.2f degrees", + median(degree_dists)) + errprint(" Median oracle true error distance = %s", + km_and_miles(median(oracle_true_dists))) + } +} + +/** + * SphereDocument version of `CoordDocumentEvalStats`. + */ +class CoordSphereDocumentEvalStats( + driver_stats: ExperimentDriverStats, + prefix: String +) extends CoordDocumentEvalStats(driver_stats, prefix) + with SphereDocumentEvalStats { +} + +/** + * SphereDocument version of `RankedDocumentEvalStats`. + */ +class RankedSphereDocumentEvalStats( + driver_stats: ExperimentDriverStats, + prefix: String, + max_rank_for_credit: Int = 10 +) extends RankedDocumentEvalStats(driver_stats, prefix, max_rank_for_credit) + with SphereDocumentEvalStats { +} + +/** + * SphereDocument version of `GroupedDocumentEvalStats`. This keeps separate + * sets of statistics for different subgroups of the test documents, i.e. + * those within particular ranges of one or more quantities of interest. 
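+ * For example, with a multi-regular grid the error distances are bucketed
+ * as fractions of a tiling-cell width in increments of 0.25 (see
+ * `dist_fraction_increment` below), so documents whose error is within a
+ * quarter of a cell width are tallied separately from those between one
+ * quarter and one half, and so on.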
+ */ +class GroupedSphereDocumentEvalStats( + driver_stats: ExperimentDriverStats, + cell_grid: SphereCellGrid, + results_by_range: Boolean, + is_ranked: Boolean +) extends GroupedDocumentEvalStats[ + SphereCoord, SphereDocument, SphereCell, SphereCellGrid, + SphereDocumentEvaluationResult +](driver_stats, cell_grid, results_by_range) { + override def create_stats(prefix: String) = { + if (is_ranked) + new RankedSphereDocumentEvalStats(driver_stats, prefix) + else + new CoordSphereDocumentEvalStats(driver_stats, prefix) + } + + val docs_by_degree_dist_to_true_center = + docmap("degree_dist_to_true_center") + + val docs_by_degree_dist_to_pred_center = + new DoubleTableByRange(dist_fractions_for_error_dist, + create_stats_for_range("degree_dist_to_pred_center", _)) + + override def record_one_result(stats: DocumentEvalStats, + res: SphereDocumentEvaluationResult) { + super.record_one_result(stats, res) + stats.asInstanceOf[SphereDocumentEvalStats]. + record_predicted_degree_distance(res.pred_degdist) + } + + override def record_one_oracle_result(stats: DocumentEvalStats, + res: SphereDocumentEvaluationResult) { + super.record_one_oracle_result(stats, res) + stats.asInstanceOf[SphereDocumentEvalStats]. + record_oracle_degree_distance(res.true_degdist) + } + + override def record_result_by_range(res: SphereDocumentEvaluationResult) { + super.record_result_by_range(res) + + /* FIXME: This code specific to MultiRegularCellGrid is kind of ugly. + Perhaps it should go elsewhere. + + FIXME: Also note that we don't actually make use of the info we + record here. See below. + */ + if (cell_grid.isInstanceOf[MultiRegularCellGrid]) { + val multigrid = cell_grid.asInstanceOf[MultiRegularCellGrid] + + /* For distance to center of true cell, which will be small (no more + than width_of_multi_cell * size-of-tiling-cell); we convert to + fractions of tiling-cell size and record in ranges corresponding + to increments of 0.25 (see above). */ + /* True distance (in both km and degrees) as a fraction of + cell size */ + val frac_true_truedist = res.true_truedist / multigrid.km_per_cell + val frac_true_degdist = res.true_degdist / multigrid.degrees_per_cell + /* Round the fractional distances to multiples of + dist_fraction_increment */ + val fracinc = dist_fraction_increment + val rounded_frac_true_truedist = + fracinc * floor(frac_true_degdist / fracinc) + val rounded_frac_true_degdist = + fracinc * floor(frac_true_degdist / fracinc) + res.record_result(docs_by_true_dist_to_true_center( + rounded_frac_true_truedist)) + res.record_result(docs_by_degree_dist_to_true_center( + rounded_frac_true_degdist)) + + /* For distance to center of predicted cell, which may be large, since + predicted cell may be nowhere near the true cell. Again we convert + to fractions of tiling-cell size and record in the ranges listed in + dist_fractions_for_error_dist (see above). */ + /* Predicted distance (in both km and degrees) as a fraction of + cell size */ + val frac_pred_truedist = res.pred_truedist / multigrid.km_per_cell + val frac_pred_degdist = res.pred_degdist / multigrid.degrees_per_cell + res.record_result(docs_by_true_dist_to_pred_center.get_collector( + frac_pred_truedist)) + res.record_result(docs_by_degree_dist_to_pred_center.get_collector( + frac_pred_degdist)) + } else if (cell_grid.isInstanceOf[KdTreeCellGrid]) { + // for kd trees, we do something similar to above, but round to the nearest km... 
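+      // (Unlike the multi-regular case there is no single cell width to
+      // normalize by, since kd-tree cells vary in size, so the raw error
+      // distances in km and degrees are used directly as the range keys.)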
+ val kdgrid = cell_grid.asInstanceOf[KdTreeCellGrid] + res.record_result(docs_by_true_dist_to_true_center( + round(res.true_truedist))) + res.record_result(docs_by_degree_dist_to_true_center( + round(res.true_degdist))) + } + } + + override def output_results_by_range() { + super.output_results_by_range() + errprint("") + + if (cell_grid.isInstanceOf[MultiRegularCellGrid]) { + val multigrid = cell_grid.asInstanceOf[MultiRegularCellGrid] + + for ( + (frac_truedist, obj) <- + docs_by_true_dist_to_true_center.toSeq sortBy (_._1) + ) { + val lowrange = frac_truedist * multigrid.km_per_cell + val highrange = ((frac_truedist + dist_fraction_increment) * + multigrid.km_per_cell) + errprint("") + errprint("Results for documents where distance to center") + errprint(" of true cell in km is in the range [%.2f,%.2f):", + lowrange, highrange) + obj.output_results() + } + errprint("") + for ( + (frac_degdist, obj) <- + docs_by_degree_dist_to_true_center.toSeq sortBy (_._1) + ) { + val lowrange = frac_degdist * multigrid.degrees_per_cell + val highrange = ((frac_degdist + dist_fraction_increment) * + multigrid.degrees_per_cell) + errprint("") + errprint("Results for documents where distance to center") + errprint(" of true cell in degrees is in the range [%.2f,%.2f):", + lowrange, highrange) + obj.output_results() + } + } + } +} + +///////////////////////////////////////////////////////////////////////////// +// Main evaluation code // +///////////////////////////////////////////////////////////////////////////// + +/** + * A general trait holding SphereDocument-specific code for storing the + * result of evaluation on a document. Here we simply compute the + * true and predicted "degree distances" -- i.e. measured in degrees, + * rather than in actual distance along a great circle. + */ +trait SphereDocumentEvaluationResult extends DocumentEvaluationResult[ + SphereCoord, SphereDocument, SphereCell, SphereCellGrid +] { + val xdocument: SphereDocument + /* The following must be declared as 'lazy' because 'xdocument' above isn't + initialized at creation time (which is impossible because traits can't + have construction parameters). */ + /** + * Distance in degrees between document's coordinate and central + * point of true cell + */ + lazy val true_degdist = xdocument.degree_distance_to_coord(true_center) + /** + * Distance in degrees between document's coordinate and predicted + * coordinate + */ + lazy val pred_degdist = xdocument.degree_distance_to_coord(pred_coord) +} + +/** + * Result of evaluating a SphereDocument using an algorithm that does + * cell-by-cell comparison and computes a ranking of all the cells. + * The predicted coordinate is the central point of the top-ranked cell, + * and the cell grid is derived from the cell. + * + * @param document document whose coordinate is predicted + * @param pred_cell top-ranked predicted cell in which the document should + * belong + * @param true_rank rank of the document's true cell among all of the + * predicted cell + */ +class RankedSphereDocumentEvaluationResult( + document: SphereDocument, + pred_cell: SphereCell, + true_rank: Int +) extends RankedDocumentEvaluationResult[ + SphereCoord, SphereDocument, SphereCell, SphereCellGrid + ]( + document, pred_cell, true_rank +) with SphereDocumentEvaluationResult { + val xdocument = document +} + +/** + * Result of evaluating a SphereDocument using an algorithm that + * predicts a coordinate that is not necessarily the central point of + * any cell (e.g. using a mean-shift algorithm). 
+ * + * @param document document whose coordinate is predicted + * @param cell_grid cell grid against which error comparison should be done + * @param pred_coord predicted coordinate of the document + */ +class CoordSphereDocumentEvaluationResult( + document: SphereDocument, + cell_grid: SphereCellGrid, + pred_coord: SphereCoord +) extends CoordDocumentEvaluationResult[ + SphereCoord, SphereDocument, SphereCell, SphereCellGrid + ]( + document, cell_grid, pred_coord +) with SphereDocumentEvaluationResult { + val xdocument = document +} + +/** + * Specialization of `RankedCellGridEvaluator` for SphereCoords (latitude/ + * longitude coordinates on the surface of a sphere). Class for evaluating + * (geolocating) a test document using a strategy that ranks the cells in the + * cell grid and picks the central point of the top-ranked one. + */ +class RankedSphereCellGridEvaluator( + strategy: GridLocateDocumentStrategy[SphereCell, SphereCellGrid], + stratname: String, + driver: GeolocateDocumentTypeDriver +) extends RankedCellGridEvaluator[ + SphereCoord, SphereDocument, SphereCell, SphereCellGrid, + SphereDocumentEvaluationResult +](strategy, stratname, driver) { + def create_grouped_eval_stats(driver: GridLocateDocumentDriver, + cell_grid: SphereCellGrid, results_by_range: Boolean) = + new GroupedSphereDocumentEvalStats( + driver, cell_grid, results_by_range, is_ranked = true) + def create_cell_evaluation_result(document: SphereDocument, + pred_cell: SphereCell, true_rank: Int) = + new RankedSphereDocumentEvaluationResult(document, pred_cell, true_rank) + + override def print_individual_result(doctag: String, document: SphereDocument, + result: SphereDocumentEvaluationResult, + pred_cells: Iterable[(SphereCell, Double)]) { + super.print_individual_result(doctag, document, result, pred_cells) + + assert(doctag(0) == '#') + if (debug("gridrank") || + (debuglist("gridrank") contains doctag.drop(1))) { + val grsize = debugval("gridranksize").toInt + if (!result.true_cell.isInstanceOf[MultiRegularCell]) + warning("Can't output ranking grid, cell not of right type") + else { + strategy.cell_grid.asInstanceOf[MultiRegularCellGrid]. + output_ranking_grid( + pred_cells.asInstanceOf[Iterable[(MultiRegularCell, Double)]], + result.true_cell.asInstanceOf[MultiRegularCell], grsize) + } + } + } +} + +/** + * Specialization of `MeanShiftCellGridEvaluator` for SphereCoords (latitude/ + * longitude coordinates on the surface of a sphere). Class for evaluating + * (geolocating) a test document using a mean-shift strategy, i.e. picking the + * K-best-ranked cells and using the mean-shift algorithm to derive a single + * point that hopefully should be in the center of the largest cluster. 
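+ *
+ * As a rough sketch of the idea (illustrative only; the real update is
+ * implemented by `SphereMeanShift`, whose details may differ), each
+ * iteration replaces the current estimate with a kernel-weighted average
+ * of the K best points, using window (bandwidth) h:
+ *
+ * {{{
+ * def shift_once(y: SphereCoord, pts: Seq[SphereCoord], h: Double) = {
+ *   val wts = pts.map(p =>
+ *     math.exp(-spheredist(y, p) * spheredist(y, p) / (2 * h * h)))
+ *   val lat = (pts, wts).zipped.map(_.lat * _).sum / wts.sum
+ *   val long = (pts, wts).zipped.map(_.long * _).sum / wts.sum
+ *   SphereCoord(lat, long)
+ * }
+ * }}}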
+ */ +class MeanShiftSphereCellGridEvaluator( + strategy: GridLocateDocumentStrategy[SphereCell, SphereCellGrid], + stratname: String, + driver: GeolocateDocumentTypeDriver, + k_best: Int, + mean_shift_window: Double, + mean_shift_max_stddev: Double, + mean_shift_max_iterations: Int +) extends MeanShiftCellGridEvaluator[ + SphereCoord, SphereDocument, SphereCell, SphereCellGrid, + SphereDocumentEvaluationResult +](strategy, stratname, driver, k_best, mean_shift_window, + mean_shift_max_stddev, mean_shift_max_iterations) { + def create_grouped_eval_stats(driver: GridLocateDocumentDriver, + cell_grid: SphereCellGrid, results_by_range: Boolean) = + new GroupedSphereDocumentEvalStats( + driver, cell_grid, results_by_range, is_ranked = false) + def create_coord_evaluation_result(document: SphereDocument, + cell_grid: SphereCellGrid, pred_coord: SphereCoord) = + new CoordSphereDocumentEvaluationResult(document, cell_grid, pred_coord) + def create_mean_shift_obj(h: Double, max_stddev: Double, + max_iterations: Int) = new SphereMeanShift(h, max_stddev, max_iterations) +} + +case class TitledDocument(title: String, text: String) +class TitledDocumentResult { } + +/** + * A class for geolocation where each test document is a chapter in a book + * in the PCL Travel corpus. + */ +class PCLTravelGeolocateDocumentEvaluator( + strategy: GridLocateDocumentStrategy[SphereCell, SphereCellGrid], + stratname: String, + driver: GeolocateDocumentTypeDriver +) extends CorpusEvaluator[ + TitledDocument, TitledDocumentResult +](stratname, driver) with DocumentIteratingEvaluator[ + TitledDocument, TitledDocumentResult +] { + def iter_documents(filehand: FileHandler, filename: String) = { + val dom = try { + // On error, just return, so that we don't have problems when called + // on the whole PCL corpus dir (which includes non-XML files). + // FIXME!! Needs to use the FileHandler somehow for Hadoop access. 
+ xml.XML.loadFile(filename) + } catch { + case _ => { + warning("Unable to parse XML filename: %s", filename) + null + } + } + + if (dom == null) Seq[TitledDocument]() + else for { + chapter <- dom \\ "div" if (chapter \ "@type").text == "chapter" + val (heads, nonheads) = chapter.child.partition(_.label == "head") + val headtext = (for (x <- heads) yield x.text) mkString "" + val text = (for (x <- nonheads) yield x.text) mkString "" + //errprint("Head text: %s", headtext) + //errprint("Non-head text: %s", text) + } yield TitledDocument(headtext, text) + } + + def evaluate_document(doc: TitledDocument, doctag: String) = { + val dist = driver.word_dist_factory.create_word_dist() + for (text <- Seq(doc.title, doc.text)) + dist.add_document(split_text_into_words(text, ignore_punc = true)) + dist.finish_before_global() + dist.finish_after_global() + val cells = + strategy.return_ranked_cells(dist, include = Iterable[SphereCell]()) + errprint("") + errprint("Document with title: %s", doc.title) + val num_cells_to_show = 5 + for ((rank, cellval) <- (1 to num_cells_to_show) zip cells) { + val (cell, vall) = cellval + if (debug("pcl-travel")) { + errprint(" Rank %d, goodness %g:", rank, vall) + errprint(cell.struct.toString) // indent=4 + } else + errprint(" Rank %d, goodness %g: %s", rank, vall, cell.shortstr) + } + + new TitledDocumentResult() + } + + def output_results(isfinal: Boolean = false) { + } +} + diff --git a/src/main/scala/opennlp/fieldspring/geolocate/TwitterDocument.scala b/src/main/scala/opennlp/fieldspring/geolocate/TwitterDocument.scala new file mode 100644 index 0000000..661d79f --- /dev/null +++ b/src/main/scala/opennlp/fieldspring/geolocate/TwitterDocument.scala @@ -0,0 +1,83 @@ +/////////////////////////////////////////////////////////////////////////////// +// TwitterDocument.scala +// +// Copyright (C) 2011 Ben Wing, The University of Texas at Austin +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+/////////////////////////////////////////////////////////////////////////////// + +package opennlp.fieldspring.geolocate + +import opennlp.fieldspring.util.textdbutil.Schema + +import opennlp.fieldspring.worddist.WordDist.memoizer._ + +class TwitterTweetDocument( + schema: Schema, + subtable: TwitterTweetDocumentSubtable +) extends SphereDocument(schema, subtable.table) { + var id = 0L + def title = id.toString + + override def set_field(field: String, value: String) { + field match { + case "title" => id = value.toLong + case _ => super.set_field(field, value) + } + } + + def struct = + + { id } + { + if (has_coord) + { coord } + } + +} + +class TwitterTweetDocumentSubtable( + table: SphereDocumentTable +) extends SphereDocumentSubtable[TwitterTweetDocument](table) { + def create_document(schema: Schema) = new TwitterTweetDocument(schema, this) +} + +class TwitterUserDocument( + schema: Schema, + subtable: TwitterUserDocumentSubtable +) extends SphereDocument(schema, subtable.table) { + var userind = blank_memoized_string + def title = unmemoize_string(userind) + + override def set_field(field: String, value: String) { + field match { + case "user" => userind = memoize_string(value) + case _ => super.set_field(field, value) + } + } + + def struct = + + { unmemoize_string(userind) } + { + if (has_coord) + { coord } + } + +} + +class TwitterUserDocumentSubtable( + table: SphereDocumentTable +) extends SphereDocumentSubtable[TwitterUserDocument](table) { + def create_document(schema: Schema) = new TwitterUserDocument(schema, this) +} diff --git a/src/main/scala/opennlp/fieldspring/geolocate/WikipediaDocument.scala b/src/main/scala/opennlp/fieldspring/geolocate/WikipediaDocument.scala new file mode 100644 index 0000000..16b9e21 --- /dev/null +++ b/src/main/scala/opennlp/fieldspring/geolocate/WikipediaDocument.scala @@ -0,0 +1,360 @@ +/////////////////////////////////////////////////////////////////////////////// +// WikipediaDocument.scala +// +// Copyright (C) 2010, 2011, 2012 Ben Wing, The University of Texas at Austin +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +/////////////////////////////////////////////////////////////////////////////// + +package opennlp.fieldspring.geolocate + +import collection.mutable + +import opennlp.fieldspring.util.collectionutil._ +import opennlp.fieldspring.util.textdbutil.Schema +import opennlp.fieldspring.util.printutil.{errprint, warning} +import opennlp.fieldspring.util.textutil.capfirst + +import opennlp.fieldspring.gridlocate.DistDocumentConverters._ +import opennlp.fieldspring.gridlocate.GridLocateDriver.Debug._ + +import opennlp.fieldspring.worddist.WordDist.memoizer._ + +/** + * A document corresponding to a Wikipedia article. + * + * Defined fields for Wikipedia: + * + * id: Wikipedia article ID (for display purposes only). + * incoming_links: Number of incoming links, or None if unknown. + * redir: If this is a redirect, document title that it redirects to; else + * an empty string. 
+ * + * Other Wikipedia params that we mostly ignore, and don't record: + * + * namespace: Namespace of document (e.g. "Main", "Wikipedia", "File"); if + * the namespace isn't Main, we currently don't record the article at all, + * even if it has a coordinate (e.g. some images do) + * is_list_of: Whether document title is "List of *" + * is_disambig: Whether document is a disambiguation page. + * is_list: Whether document is a list of any type ("List of *", disambig, + * or in Category or Book namespaces) + */ +class WikipediaDocument( + schema: Schema, + subtable: WikipediaDocumentSubtable +) extends SphereDocument(schema, subtable.table) { + var id = 0L + var incoming_links_value: Option[Int] = None + override def incoming_links = incoming_links_value + var redirind = blank_memoized_string + def redir = unmemoize_string(redirind) + var titleind = blank_memoized_string + def title = unmemoize_string(titleind) + + override def set_field(field: String, value: String) { + field match { + case "id" => id = value.toLong + case "title" => titleind = memoize_string(value) + case "redir" => redirind = memoize_string(value) + case "incoming_links" => incoming_links_value = get_int_or_none(value) + case _ => super.set_field(field, value) + } + } + + override def get_field(field: String) = { + field match { + case "id" => id.toString + case "redir" => redir + case "incoming_links" => put_int_or_none(incoming_links) + case _ => super.get_field(field) + } + } + + def struct = + + { title } + { id } + { + if (has_coord) + { coord } + } + { + if (redir.length > 0) + { redir } + } + + + override def toString = { + val redirstr = + if (redir.length > 0) ", redirect to %s".format(redir) else "" + "%s (id=%s%s)".format(super.toString, id, redirstr) + } + + def adjusted_incoming_links = + WikipediaDocument.adjust_incoming_links(incoming_links) +} + +object WikipediaDocument { + /** + * Compute the short form of a document name. If short form includes a + * division (e.g. "Tucson, Arizona"), return a tuple (SHORTFORM, DIVISION); + * else return a tuple (SHORTFORM, None). + */ + def compute_short_form(name: String) = { + val includes_div_re = """(.*?), (.*)$""".r + val includes_parentag_re = """(.*) \(.*\)$""".r + name match { + case includes_div_re(tucson, arizona) => (tucson, arizona) + case includes_parentag_re(tucson, city) => (tucson, null) + case _ => (name, null) + } + } + + def log_adjust_incoming_links(links: Int) = { + if (links == 0) // Whether from unknown count or count is actually zero + 0.01 // So we don't get errors from log(0) + else links + } + + def adjust_incoming_links(incoming_links: Option[Int]) = { + val ail = + incoming_links match { + case None => { + if (debug("some")) + warning("Strange, object has no link count") + 0 + } + case Some(il) => { + if (debug("some")) + errprint("--> Link count is %s", il) + il + } + } + ail + } +} + +/** + * Document table for documents corresponding to Wikipedia articles. + * + * Handling of redirect articles: + * + * (1) Documents that are redirects to articles without geotags (i.e. + * coordinates) should have been filtered out during preprocessing; + * we want to keep only content articles with coordinates, and redirect + * articles to such articles. (But in the future we will want to make + * use of non-geotagged articles, e.g. in label propagation.) + * (2) Documents that redirect to articles with coordinates have associated + * WikipediaDocument objects created for them. 
These objects have their + * `redir` field set to the name of the article redirected to. (Objects + * for non-redirect articles have this field blank.) + * (3) However, these objects should not appear in the lists of documents by + * split. + * (4) When we read documents in, when we encounter a non-redirect article, + * we call `record_document` to record it, and add it to the cell grid. + * For redirect articles, however, we simply note them in a list, to be + * processed later. + * (5) When we've loaded all documents, we go through the list of redirect + * articles and for each one, we look up the article pointed to and + * call `record_document` with the two articles. We do it this way + * because we don't know the order in which we will load a redirecting + * vs. redirected-to article. + * (6) The effect of the final `record_document` call for redirect articles + * is that (a) the incoming-link count of the redirecting article gets + * added to the redirected-to article, and (b) the name of the redirecting + * article gets recorded as an additional name of the redirected-to + * article. + * (7) Note that currently we don't actually keep a mapping of all the names + * of a given WikipediaDocument; instead, we have tables that + * map names of various sorts to associated articles. The articles + * pointed to in these maps are only content articles, except when there + * happen to be double redirects, i.e. redirects to other redirects. + * Wikipedia daemons actively remove such double redirects by pursuing + * the chain of redirects to the end. We don't do such following + * ourselves; hence we may have some redirect articles listed in the + * maps. (FIXME, we should probably ignore these articles rather than + * record them.) Note that this means that we don't need the + * WikipediaDocument objects for redirect articles once we've finished + * loading the table; they should end up garbage collected. + */ +class WikipediaDocumentSubtable( + override val table: SphereDocumentTable +) extends SphereDocumentSubtable[WikipediaDocument](table) { + def create_document(schema: Schema) = new WikipediaDocument(schema, this) + + override def create_and_init_document(schema: Schema, fieldvals: Seq[String], + record_in_table: Boolean) = { + /** + * FIXME: Perhaps we should filter the document file when we generate it, + * to remove stuff not in the Main namespace. We also need to remove + * the duplication between this function and would_add_document_to_list(). + */ + val namespace = schema.get_field_or_else(fieldvals, "namepace") + if (namespace != null && namespace != "Main") { + errprint("Skipped document %s, namespace %s is not Main", + schema.get_field_or_else(fieldvals, "title", "unknown title??"), + namespace) + null + } else { + val doc = create_document(schema) + doc.set_fields(fieldvals) + if (doc.redir.length > 0) { + if (record_in_table) + redirects += doc + null + } else { + if (record_in_table) + record_document(doc, doc) + doc + } + } + } + + // val wikipedia_fields = Seq("incoming_links", "redir") + + /** + * Mapping from document names to WikipediaDocument objects, using the actual + * case of the document. + */ + val name_to_document = mutable.Map[Word, WikipediaDocument]() + + /** + * Map from short name (lowercased) to list of documents. + * The short name for a document is computed from the document's name. If + * the document name has a comma, the short name is the part before the + * comma, e.g. the short name of "Springfield, Ohio" is "Springfield". 
+ * If the name has no comma, the short name is the same as the document + * name. The idea is that the short name should be the same as one of + * the toponyms used to refer to the document. + */ + val short_lower_name_to_documents = bufmap[Word, WikipediaDocument]() + + /** + * Map from tuple (NAME, DIV) for documents of the form "Springfield, Ohio", + * lowercased. + */ + val lower_name_div_to_documents = + bufmap[(Word, Word), WikipediaDocument]() + + /** + * For each toponym, list of documents matching the name. + */ + val lower_toponym_to_document = bufmap[Word, WikipediaDocument]() + + /** + * Mapping from lowercased document names to WikipediaDocument objects + */ + val lower_name_to_documents = bufmap[Word, WikipediaDocument]() + + /** + * Total # of incoming links for all documents in each split. + */ + val incoming_links_by_split = + table.driver.countermap("incoming_links_by_split") + + /** + * List of documents that are Wikipedia redirect articles, accumulated + * during loading and processed at the end. + */ + val redirects = mutable.Buffer[WikipediaDocument]() + + /** + * Look up a document named NAME and return the associated document. + * Note that document names are case-sensitive but the first letter needs to + * be capitalized. + */ + def lookup_document(name: String) = { + assert(name != null) + assert(name.length > 0) + name_to_document.getOrElse(memoize_string(capfirst(name)), + null.asInstanceOf[WikipediaDocument]) + } + + /** + * Record the document as having NAME as one of its names (there may be + * multiple names, due to redirects). Also add to related lists mapping + * lowercased form, short form, etc. + */ + def record_document_name(name: String, doc: WikipediaDocument) { + // Must pass in properly cased name + // errprint("name=%s, capfirst=%s", name, capfirst(name)) + // println("length=%s" format name.length) + // if (name.length > 1) { + // println("name(0)=0x%x" format name(0).toInt) + // println("name(1)=0x%x" format name(1).toInt) + // println("capfirst(0)=0x%x" format capfirst(name)(0).toInt) + // } + assert(name != null) + assert(name.length > 0) + assert(name == capfirst(name)) + name_to_document(memoize_string(name)) = doc + val loname = name.toLowerCase + val loname_word = memoize_string(loname) + lower_name_to_documents(loname_word) += doc + val (short, div) = WikipediaDocument.compute_short_form(loname) + val short_word = memoize_string(short) + if (div != null) { + val div_word = memoize_string(div) + lower_name_div_to_documents((short_word, div_word)) += doc + } + short_lower_name_to_documents(short_word) += doc + if (!(lower_toponym_to_document(loname_word) contains doc)) + lower_toponym_to_document(loname_word) += doc + if (short_word != loname_word && + !(lower_toponym_to_document(short_word) contains doc)) + lower_toponym_to_document(short_word) += doc + } + + /** + * Record either a normal document ('docfrom' same as 'docto') or a + * redirect ('docfrom' redirects to 'docto'). + */ + def record_document(docfrom: WikipediaDocument, docto: WikipediaDocument) { + record_document_name(docfrom.title, docto) + + // Handle incoming links. + val split = docto.split + val fromlinks = docfrom.adjusted_incoming_links + incoming_links_by_split(split) += fromlinks + if (docfrom.redir != "" && fromlinks != 0) { + // Add count of links pointing to a redirect to count of links + // pointing to the document redirected to, so that the total incoming + // link count of a document includes any redirects to that document. 
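+      // E.g. if a redirect with 5 incoming links points to an article that
+      // itself has 20, the article's recorded count becomes 25.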
+ docto.incoming_links_value = + Some(docto.adjusted_incoming_links + fromlinks) + } + } + + override def finish_document_loading() { + for (x <- redirects) { + val reddoc = lookup_document(x.redir) + if (reddoc != null) + record_document(x, reddoc) + } + /* FIXME: Consider setting the variable itself to null so that no + further additions can happen, to catch bad code. */ + redirects.clear() + super.finish_document_loading() + } + + def construct_candidates(toponym: String) = { + val lw = memoize_string(toponym.toLowerCase) + lower_toponym_to_document(lw) + } + + def word_is_toponym(word: String) = { + val lw = memoize_string(word.toLowerCase) + lower_toponym_to_document contains lw + } +} diff --git a/src/main/scala/opennlp/fieldspring/geolocate/toponym/Toponym.scala b/src/main/scala/opennlp/fieldspring/geolocate/toponym/Toponym.scala new file mode 100644 index 0000000..91f963f --- /dev/null +++ b/src/main/scala/opennlp/fieldspring/geolocate/toponym/Toponym.scala @@ -0,0 +1,1644 @@ +/////////////////////////////////////////////////////////////////////////////// +// Copyright (C) 2011 Ben Wing, The University of Texas at Austin +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +/////////////////////////////////////////////////////////////////////////////// + +package opennlp.fieldspring.geolocate.toponym + +import collection.mutable +import util.control.Breaks._ +import math._ + +import opennlp.fieldspring.util.argparser._ +import opennlp.fieldspring.util.collectionutil._ +import opennlp.fieldspring.util.distances +import opennlp.fieldspring.util.distances.SphereCoord +import opennlp.fieldspring.util.distances.spheredist +import opennlp.fieldspring.util.experiment._ +import opennlp.fieldspring.util.ioutil.FileHandler +import opennlp.fieldspring.util.textdbutil.Schema +import opennlp.fieldspring.util.osutil._ +import opennlp.fieldspring.util.printutil.{errout, errprint, warning} + +import opennlp.fieldspring.gridlocate._ +import opennlp.fieldspring.gridlocate.GridLocateDriver.Debug._ +import opennlp.fieldspring.geolocate._ + +import opennlp.fieldspring.worddist.{WordDist,WordDistFactory} +import opennlp.fieldspring.worddist.WordDist.memoizer._ + +/* FIXME: Eliminate this. */ +import GeolocateToponymApp.Params + +// A class holding the boundary of a geographic object. Currently this is +// just a bounding box, but eventually may be expanded to including a +// convex hull or more complex model. + +class Boundary(botleft: SphereCoord, topright: SphereCoord) { + override def toString = { + "%s-%s" format (botleft, topright) + } + + // def __repr__() = { + // "Boundary(%s)" format toString + // } + + def struct = + + def contains(coord: SphereCoord) = { + if (!(coord.lat >= botleft.lat && coord.lat <= topright.lat)) + false + else if (botleft.long <= topright.long) + (coord.long >= botleft.long && coord.long <= topright.long) + else { + // Handle case where boundary overlaps the date line. + (coord.long >= botleft.long && + coord.long <= topright.long + 360.) 
|| + (coord.long >= botleft.long - 360. && + coord.long <= topright.long) + } + } + + def square_area() = distances.square_area(botleft, topright) + + /** + * Iterate over the cells that overlap the boundary. + * + * @param cell_grid FIXME: Currently we need a cell grid of a certain type + * passed in so we can iterate over the cells. Fix this so that + * the algorithm is done differently, or something. + */ + def iter_nonempty_tiling_cells(cell_grid: TopoCellGrid) = { + val botleft_index = cell_grid.coord_to_tiling_cell_index(botleft) + val (latind1, longind1) = (botleft_index.latind, botleft_index.longind) + val topright_index = cell_grid.coord_to_tiling_cell_index(topright) + val (latind2, longind2) = (topright_index.latind, topright_index.longind) + for { + i <- latind1 to latind2 view + val it = if (longind1 <= longind2) longind1 to longind2 view + else (longind1 to cell_grid.maximum_longind view) ++ + (cell_grid.minimum_longind to longind2 view) + j <- it + val index = RegularCellIndex(i, j) + if (cell_grid.tiling_cell_to_documents contains index) + } yield index + } +} + +///////////// Locations //////////// + +// A general location (either locality or division). The following +// fields are defined: +// +// name: Name of location. +// altnames: List of alternative names of location. +// typ: Type of location (locality, agglomeration, country, state, +// territory, province, etc.) +// docmatch: Document corresponding to this location. +// div: Next higher-level division this location is within, or None. + +abstract class Location( + val name: String, + val altnames: Seq[String], + val typ: String +) { + var docmatch: TopoDocument = null + var div: Division = null + def toString(no_document: Boolean = false): String + def shortstr: String + def struct(no_document: Boolean = false): xml.Elem + def distance_to_coord(coord: SphereCoord): Double + def matches_coord(coord: SphereCoord): Boolean +} + +// A location corresponding to an entry in a gazetteer, with a single +// coordinate. +// +// The following fields are defined, in addition to those for Location: +// +// coord: Coordinates of the location, as a SphereCoord object. + +class Locality( + name: String, + val coord: SphereCoord, + altnames: Seq[String], + typ: String +) extends Location(name, altnames, typ) { + + def toString(no_document: Boolean = false) = { + var docmatch = "" + if (!no_document) + docmatch = ", match=%s" format docmatch + "Locality %s (%s) at %s%s" format ( + name, if (div != null) div.path.mkString("/") else "unknown", + coord, docmatch) + } + + // def __repr__() = { + // toString.encode("utf-8") + // } + + def shortstr = { + "Locality %s (%s)" format ( + name, if (div != null) div.path.mkString("/") else "unknown") + } + + def struct(no_document: Boolean = false) = + + { name } + { if (div != null) div.path.mkString("/") else "" } + { coord } + { + if (!no_document) + { if (docmatch != null) docmatch.struct else "none" } + } + + + def distance_to_coord(coord: SphereCoord) = spheredist(coord, coord) + + def matches_coord(coord: SphereCoord) = { + distance_to_coord(coord) <= Params.max_dist_for_close_match + } +} + +// A division higher than a single locality. According to the World +// gazetteer, there are three levels of divisions. For the U.S., this +// corresponds to country, state, county. +// +class Division( + // Tuple of same size as the level #, listing the path of divisions + // from highest to lowest, leading to this division. The last + // element is the same as the "name" of the division. 
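+  // E.g. a third-level U.S. division might have a path like
+  // Seq("United States", "Texas", "Travis County") (names illustrative).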
+ val path: Seq[String] +) extends Location(path(path.length - 1), Seq[String](), "unknown") { + + // 1, 2, or 3 for first, second, or third-level division + val level = path.length + // List of locations inside of the division. + var locs = mutable.Buffer[Locality]() + // List of locations inside of the division other than those + // rejected as outliers (too far from all other locations). + var goodlocs = mutable.Buffer[Locality]() + // Boundary object specifying the boundary of the area of the + // division. Currently in the form of a rectangular bounding box. + // Eventually may contain a convex hull or even more complex + // cell (e.g. set of convex cells). + var boundary: Boundary = null + // For cell-based Naive Bayes disambiguation, a distribution + // over the division's document and all locations within the cell. + var combined_dist: CombinedWordDist = null + + def toString(no_document: Boolean = false) = { + val docmatchstr = + if (no_document) "" else ", match=%s" format docmatch + "Division %s (%s)%s, boundary=%s" format ( + name, path.mkString("/"), docmatchstr, boundary) + } + + // def __repr__() = toString.encode("utf-8") + + def shortstr = { + ("Division %s" format name) + ( + if (level > 1) " (%s)" format (path.mkString("/")) else "") + } + + def struct(no_document: Boolean = false): xml.Elem = + + { name } + { path.mkString("/") } + { + if (!no_document) + { if (docmatch != null) docmatch.struct else "none" } + } + { boundary.struct } + + + def distance_to_coord(coord: SphereCoord) = java.lang.Double.NaN + + def matches_coord(coord: SphereCoord) = this contains coord + + // Compute the boundary of the geographic cell of this division, based + // on the points in the cell. + def compute_boundary() { + // Yield up all points that are not "outliers", where outliers are defined + // as points that are more than Params.max_dist_for_outliers away from all + // other points. + def iter_non_outliers() = { + // If not enough points, just return them; otherwise too much possibility + // that all of them, or some good ones, will be considered outliers. + if (locs.length <= 5) { + for (p <- locs) yield p + } else { + // FIXME: Actually look for outliers. + for (p <- locs) yield p + //for { + // p <- locs + // // Find minimum distance to all other points and check it. + // mindist = (for (x <- locs if !(x eq p)) yield spheredist(p, x)) min + // if (mindist <= Params.max_dist_for_outliers) + //} yield p + } + } + + if (debug("lots")) { + errprint("Computing boundary for %s, path %s, num points %s", + name, path, locs.length) + } + + goodlocs = iter_non_outliers() + // If we've somehow discarded all points, just use the original list + if (goodlocs.length == 0) { + if (debug("some")) { + warning("All points considered outliers? Division %s, path %s", + name, path) + } + goodlocs = locs + } + // FIXME! This will fail for a division that crosses the International + // Date Line. 
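+    // (E.g. a division whose localities sit at longitudes 179 and -179
+    // would get a bounding box running from -179 to 179, covering nearly
+    // the whole globe rather than the narrow strip around the date line.)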
+ val topleft = SphereCoord((for (x <- goodlocs) yield x.coord.lat) min, + (for (x <- goodlocs) yield x.coord.long) min) + val botright = SphereCoord((for (x <- goodlocs) yield x.coord.lat) max, + (for (x <- goodlocs) yield x.coord.long) max) + boundary = new Boundary(topleft, botright) + } + + def generate_word_dist(word_dist_factory: WordDistFactory) { + combined_dist = new CombinedWordDist(word_dist_factory) + for (loc <- Seq(this) ++ goodlocs if loc.docmatch != null) + yield combined_dist.add_document(loc.docmatch) + combined_dist.word_dist.finish_before_global() + combined_dist.word_dist.finish_after_global() + } + + def contains(coord: SphereCoord) = boundary contains coord +} + +class DivisionFactory(gazetteer: Gazetteer) { + // For each division, map from division's path to Division object. + val path_to_division = mutable.Map[Seq[String], Division]() + + // For each tiling cell, list of divisions that have territory in it + val tiling_cell_to_divisions = bufmap[RegularCellIndex, Division]() + + // Find the division for a point in the division with a given path, + // add the point to the division. Create the division if necessary. + // Return the corresponding Division. + def find_division_note_point(loc: Locality, path: Seq[String]): Division = { + val higherdiv = if (path.length > 1) + // Also note location in next-higher division. + find_division_note_point(loc, path.dropRight(1)) + else null + // Skip divisions where last element in path is empty; this is a + // reference to a higher-level division with no corresponding lower-level + // division. + if (path.last.length == 0) higherdiv + else { + val division = { + if (path_to_division contains path) + path_to_division(path) + else { + // If we haven't seen this path, create a new Division object. + // Record the mapping from path to division, and also from the + // division's "name" (name of lowest-level division in path) to + // the division. + val newdiv = new Division(path) + newdiv.div = higherdiv + path_to_division(path) = newdiv + gazetteer.record_division(path.last.toLowerCase, newdiv) + newdiv + } + } + division.locs += loc + division + } + } + + /** + * Finish all computations related to Divisions, after we've processed + * all points (and hence all points have been added to the appropriate + * Divisions). + */ + def finish_all() { + val divs_by_area = mutable.Buffer[(Division, Double)]() + for (division <- path_to_division.values) { + if (debug("lots")) { + errprint("Processing division named %s, path %s", + division.name, division.path) + } + division.compute_boundary() + val docmatch = gazetteer.cell_grid.table.asInstanceOf[TopoDocumentTable]. 
+ topo_subtable.find_match_for_division(division) + if (docmatch != null) { + if (debug("lots")) { + errprint("Matched document %s for division %s, path %s", + docmatch, division.name, division.path) + } + division.docmatch = docmatch + docmatch.location = division + } else { + if (debug("lots")) { + errprint("Couldn't find match for division %s, path %s", + division.name, division.path) + } + } + for (index <- + division.boundary.iter_nonempty_tiling_cells(gazetteer.cell_grid)) + tiling_cell_to_divisions(index) += division + if (debug("cell")) + divs_by_area += ((division, division.boundary.square_area())) + } + if (debug("cell")) { + // sort by second element of tuple, in reverse order + for ((div, area) <- divs_by_area sortWith (_._2 > _._2)) + errprint("%.2f square km: %s", area, div) + } + } +} + +class TopoCellGrid( + degrees_per_cell: Double, + width_of_multi_cell: Int, + table: TopoDocumentTable +) extends MultiRegularCellGrid(degrees_per_cell, width_of_multi_cell, table) { + + /** + * Mapping from tiling cell to documents in the cell. + */ + var tiling_cell_to_documents = bufmap[RegularCellIndex, SphereDocument]() + + /** + * Override so that we can keep a mapping of cell to documents in the cell. + * FIXME: Do we really need this? + */ + override def add_document_to_cell(doc: SphereDocument) { + val index = coord_to_tiling_cell_index(doc.coord) + tiling_cell_to_documents(index) += doc + super.add_document_to_cell(doc) + } +} + +class TopoDocument( + schema: Schema, + subtable: TopoDocumentSubtable +) extends WikipediaDocument(schema, subtable) { + // Cell-based distribution corresponding to this document. + var combined_dist: CombinedWordDist = null + // Corresponding location for this document. + var location: Location = null + + override def toString = { + var ret = super.toString + if (location != null) { + ret += (", matching location %s" format + location.toString(no_document = true)) + } + val divs = find_covering_divisions() + val top_divs = + for (div <- divs if div.level == 1) + yield div.toString(no_document = true) + val topdivstr = + if (top_divs.length > 0) + ", in top-level divisions %s" format (top_divs.mkString(", ")) + else + ", not in any top-level divisions" + ret + topdivstr + } + + override def shortstr = { + var str = super.shortstr + if (location != null) + str += ", matching %s" format location.shortstr + val divs = find_covering_divisions() + val top_divs = (for (div <- divs if div.level == 1) yield div.name) + if (top_divs.length > 0) + str += ", in top-level divisions %s" format (top_divs.mkString(", ")) + str + } + + override def struct = { + val xml = super.struct + + { xml.child } + { + if (location != null) + { location.struct(no_document = true) } + } + { + val divs = find_covering_divisions() + val top_divs = (for (div <- divs if div.level == 1) + yield div.struct(no_document = true)) + if (top_divs != null) + { top_divs } + else + none + } + + } + + def matches_coord(coord: SphereCoord) = { + if (distance_to_coord(coord) <= Params.max_dist_for_close_match) true + else if (location != null && location.isInstanceOf[Division] && + location.matches_coord(coord)) true + else false + } + + // Determine the cell word-distribution object for a given document: + // Create and populate one if necessary. 
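+  // If the document's matched location is a Division, use the division's
+  // combined distribution (building it lazily if needed); otherwise fall
+  // back to the distribution of the best cell for this document, creating
+  // an empty one if no such cell exists.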
+ def find_combined_word_dist(cell_grid: SphereCellGrid) = { + val loc = location + if (loc != null && loc.isInstanceOf[Division]) { + val div = loc.asInstanceOf[Division] + if (div.combined_dist == null) + div.generate_word_dist(cell_grid.table.word_dist_factory) + div.combined_dist + } else { + if (combined_dist == null) { + val cell = cell_grid.find_best_cell_for_document(this, false) + if (cell != null) + combined_dist = cell.combined_dist + else { + warning("Couldn't find existing cell distribution for document %s", + this) + combined_dist = new CombinedWordDist(table.word_dist_factory) + combined_dist.word_dist.finish_before_global() + combined_dist.word_dist.finish_after_global() + } + } + combined_dist + } + } + + // Find the divisions that cover the given document. + def find_covering_divisions() = { + val inds = subtable.gazetteer.cell_grid.coord_to_tiling_cell_index(coord) + val divs = subtable.gazetteer.divfactory.tiling_cell_to_divisions(inds) + (for (div <- divs if div contains coord) yield div) + } +} + +// Static class maintaining additional tables listing mapping between +// names, ID's and documents. See comments at WikipediaDocumentTable. +class TopoDocumentSubtable( + val topo_table: TopoDocumentTable +) extends WikipediaDocumentSubtable(topo_table) { + override def create_document(schema: Schema) = + new TopoDocument(schema, this) + + var gazetteer: Gazetteer = null + + /** + * Set the gazetteer. Must do it this way because creation of the + * gazetteer wants the TopoDocumentTable already created. + */ + def set_gazetteer(gaz: Gazetteer) { + gazetteer = gaz + } + + // Construct the list of possible candidate documents for a given toponym + override def construct_candidates(toponym: String) = { + val lotop = toponym.toLowerCase + val locs = ( + gazetteer.lower_toponym_to_location(lotop) ++ + gazetteer.lower_toponym_to_division(lotop)) + val documents = super.construct_candidates(toponym) + documents ++ ( + for {loc <- locs + if (loc.docmatch != null && !(documents contains loc.docmatch))} + yield loc.docmatch + ) + } + + override def word_is_toponym(word: String) = { + val lw = word.toLowerCase + (super.word_is_toponym(word) || + (gazetteer.lower_toponym_to_location contains lw) || + (gazetteer.lower_toponym_to_division contains lw)) + } + + // Find document matching name NAME for location LOC. NAME will generally + // be one of the names of LOC (either its canonical name or one of the + // alternate name). CHECK_MATCH is a function that is passed one arument, + // the document, and should return true if the location matches the document. + // PREFER_MATCH is used when two or more documents match. It is passed + // two arguments, the two documents. It should return TRUE if the first is + // to be preferred to the second. Return the document matched, or None. + + def find_one_document_match(loc: Location, name: String, + check_match: (TopoDocument) => Boolean, + prefer_match: (TopoDocument, TopoDocument) => Boolean): TopoDocument = { + + val loname = memoize_string(name.toLowerCase) + + // Look for any documents with same name (case-insensitive) as the + // location, check for matches + for (wiki_doc <- lower_name_to_documents(loname); + doc = wiki_doc.asInstanceOf[TopoDocument]) + if (check_match(doc)) return doc + + // Check whether there is a match for a document whose name is + // a combination of the location's name and one of the divisions that + // the location is in (e.g. 
"Augusta, Georgia" for a location named + // "Augusta" in a second-level division "Georgia"). + if (loc.div != null) { + for { + div <- loc.div.path + lodiv = memoize_string(div.toLowerCase) + wiki_doc <- lower_name_div_to_documents((loname, lodiv)) + doc = wiki_doc.asInstanceOf[TopoDocument] + } if (check_match(doc)) return doc + } + + // See if there is a match with any of the documents whose short name + // is the same as the location's name + val docs = short_lower_name_to_documents(loname) + if (docs != null) { + val gooddocs = + (for (wiki_doc <- docs; + doc = wiki_doc.asInstanceOf[TopoDocument]; + if check_match(doc)) + yield doc) + if (gooddocs.length == 1) + return gooddocs(0) // One match + else if (gooddocs.length > 1) { + // Multiple matches: Sort by preference, return most preferred one + if (debug("lots")) { + errprint("Warning: Saw %s toponym matches: %s", + gooddocs.length, gooddocs) + } + val sorteddocs = gooddocs sortWith (prefer_match(_, _)) + return sorteddocs(0) + } + } + + // No match. + return null + } + + // Find document matching location LOC. CHECK_MATCH and PREFER_MATCH are + // as above. Return the document matched, or None. + + def find_document_match(loc: Location, + check_match: (TopoDocument) => Boolean, + prefer_match: (TopoDocument, TopoDocument) => Boolean): TopoDocument = { + // Try to find a match for the canonical name of the location + val docmatch = find_one_document_match(loc, loc.name, check_match, + prefer_match) + if (docmatch != null) return docmatch + + // No match; try each of the alternate names in turn. + for (altname <- loc.altnames) { + val docmatch2 = find_one_document_match(loc, altname, check_match, + prefer_match) + if (docmatch2 != null) return docmatch2 + } + + // No match. + return null + } + + // Find document matching locality LOC; the two coordinates must be at most + // MAXDIST away from each other. + + def find_match_for_locality(loc: Locality, maxdist: Double) = { + + def check_match(doc: TopoDocument) = { + val dist = spheredist(loc.coord, doc.coord) + if (dist <= maxdist) true + else { + if (debug("lots")) { + errprint("Found document %s but dist %s > %s", + doc, dist, maxdist) + } + false + } + } + + def prefer_match(doc1: TopoDocument, doc2: TopoDocument) = { + spheredist(loc.coord, doc1.coord) < spheredist(loc.coord, doc2.coord) + } + + find_document_match(loc, check_match, prefer_match) + } + + // Find document matching division DIV; the document coordinate must be + // inside of the division's boundaries. + + def find_match_for_division(div: Division) = { + + def check_match(doc: TopoDocument) = { + if (doc.has_coord && (div contains doc.coord)) true + else { + if (debug("lots")) { + if (!doc.has_coord) { + errprint("Found document %s but no coordinate, so not in location named %s, path %s", + doc, div.name, div.path) + } else { + errprint("Found document %s but not in location named %s, path %s", + doc, div.name, div.path) + } + } + false + } + } + + def prefer_match(doc1: TopoDocument, doc2: TopoDocument) = { + val l1 = doc1.incoming_links + val l2 = doc2.incoming_links + // Prefer according to incoming link counts, if that info is available + if (l1 != None && l2 != None) l1.get > l2.get + else { + // FIXME: Do something smart here -- maybe check that location is + // farther in the middle of the bounding box (does this even make + // sense???) 
+ true + } + } + + find_document_match(div, check_match, prefer_match) + } +} + +/** + * A version of SphereDocumentTable that substitutes a TopoDocumentSubtable + * for the Wikipedia subtable. + */ +class TopoDocumentTable( + val topo_driver: GeolocateToponymDriver, + word_dist_factory: WordDistFactory +) extends SphereDocumentTable( + topo_driver, word_dist_factory +) { + val topo_subtable = new TopoDocumentSubtable(this) + override val wikipedia_subtable = topo_subtable +} + +class EvalStatsWithCandidateList( + driver_stats: ExperimentDriverStats, + prefix: String, + incorrect_reasons: Map[String, String], + max_individual_candidates: Int = 5 +) extends EvalStats(driver_stats, prefix, incorrect_reasons) { + + def record_result(correct: Boolean, reason: String, num_candidates: Int) { + super.record_result(correct, reason) + increment_counter("instances.total.by_candidate." + num_candidates) + if (correct) + increment_counter("instances.correct.by_candidate." + num_candidates) + else + increment_counter("instances.incorrect.by_candidate." + num_candidates) + } + + // SCALABUG: The need to write collection.Map here rather than simply + // Map seems clearly wrong. It seems the height of obscurity that + // "collection.Map" is the common supertype of plain "Map"; the use of + // overloaded "Map" seems to be the root of the problem. + def output_table_by_num_candidates(group: String, total: Long) { + for (i <- 0 to max_individual_candidates) + output_fraction(" With %d candidates" format i, + get_counter(group + "." + i), total) + val items = ( + for {counter <- list_counters(group, false, false) + key = counter.toInt + if key > max_individual_candidates + } + yield get_counter(group + "." + counter) + ).sum + output_fraction( + " With %d+ candidates" format (1 + max_individual_candidates), + items, total) + } + + override def output_correct_results() { + super.output_correct_results() + output_table_by_num_candidates("instances.correct", correct_instances) + } + + override def output_incorrect_results() { + super.output_incorrect_results() + output_table_by_num_candidates("instances.incorrect", incorrect_instances) + } +} + +object GeolocateToponymResults { + val incorrect_geolocate_toponym_reasons = Map( + "incorrect_with_no_candidates" -> + "Incorrect, with no candidates", + "incorrect_with_no_correct_candidates" -> + "Incorrect, with candidates but no correct candidates", + "incorrect_with_multiple_correct_candidates" -> + "Incorrect, with multiple correct candidates", + "incorrect_one_correct_candidate_missing_link_info" -> + "Incorrect, with one correct candidate, but link info missing", + "incorrect_one_correct_candidate" -> + "Incorrect, with one correct candidate") +} + +//////// Results for geolocating toponyms +class GeolocateToponymResults(driver_stats: ExperimentDriverStats) { + import GeolocateToponymResults._ + + // Overall statistics + val all_toponym = new EvalStatsWithCandidateList( + driver_stats, "", incorrect_geolocate_toponym_reasons) + // Statistics when toponym not same as true name of location + val diff_surface = new EvalStatsWithCandidateList( + driver_stats, "diff_surface", incorrect_geolocate_toponym_reasons) + // Statistics when toponym not same as true name or short form of location + val diff_short = new EvalStatsWithCandidateList( + driver_stats, "diff_short", incorrect_geolocate_toponym_reasons) + + def record_geolocate_toponym_result(correct: Boolean, toponym: String, + trueloc: String, reason: String, num_candidates: Int) { + 
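+    // Record into the overall statistics; additionally record into the
+    // "different surface form" statistics when the toponym differs from the
+    // true location's name, and into the "different short form" statistics
+    // when it also differs from that name's short form.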
all_toponym.record_result(correct, reason, num_candidates) + if (toponym != trueloc) { + diff_surface.record_result(correct, reason, num_candidates) + val (short, div) = WikipediaDocument.compute_short_form(trueloc) + if (toponym != short) + diff_short.record_result(correct, reason, num_candidates) + } + } + + def output_geolocate_toponym_results() { + errprint("Results for all toponyms:") + all_toponym.output_results() + errprint("") + errprint("Results for toponyms when different from true location name:") + diff_surface.output_results() + errprint("") + errprint("Results for toponyms when different from either true location name") + errprint(" or its short form:") + diff_short.output_results() + output_resource_usage() + } +} + +// Class of word in a file containing toponyms. Fields: +// +// word: The identity of the word. +// is_stop: true if it is a stopword. +// is_toponym: true if it is a toponym. +// coord: For a toponym with specified ground-truth coordinate, the +// coordinate. Else, null. +// location: true location if given, else null. +// context: Vector including the word and 10 words on other side. +// document: The document (document, etc.) of the word. Useful when a single +// file contains multiple such documents. +// +class GeogWord(val word: String) { + var is_stop = false + var is_toponym = false + var coord: SphereCoord = null + var location: String = null + var context: Array[(Int, String)] = null + var document: String = null +} + +abstract class GeolocateToponymStrategy { + def need_context(): Boolean + def compute_score(geogword: GeogWord, doc: TopoDocument): Double +} + +// Find each toponym explicitly mentioned as such and disambiguate it +// (find the correct geographic location) using the "link baseline", i.e. +// use the location with the highest number of incoming links. +class BaselineGeolocateToponymStrategy( + cell_grid: SphereCellGrid, + val baseline_strategy: String) extends GeolocateToponymStrategy { + def need_context() = false + + def compute_score(geogword: GeogWord, doc: TopoDocument) = { + val topo_table = cell_grid.table.asInstanceOf[TopoDocumentTable] + val params = topo_table.topo_driver.params + if (baseline_strategy == "internal-link") { + if (params.context_type == "cell") + doc.find_combined_word_dist(cell_grid).incoming_links + else + doc.adjusted_incoming_links + } else if (baseline_strategy == "num-documents") { + if (params.context_type == "cell") + doc.find_combined_word_dist(cell_grid).num_docs_for_links + else { + val location = doc.location + location match { + case x:Division => x.locs.length + case _ => 1 + } + } + } else random + } +} + +// Find each toponym explicitly mentioned as such and disambiguate it +// (find the correct geographic location) using Naive Bayes, possibly +// in conjunction with the baseline. +class NaiveBayesToponymStrategy( + cell_grid: SphereCellGrid, + val use_baseline: Boolean +) extends GeolocateToponymStrategy { + def need_context() = true + + def compute_score(geogword: GeogWord, doc: TopoDocument) = { + // FIXME FIXME!!! We are assuming that the baseline is "internal-link", + // regardless of its actual settings. 
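+    // In log space, the score computed below is
+    //
+    //   word_weight * (1 / W) * sum_i { w_i * log P(word_i | distobj) }
+    //     + baseline_weight * log(adjusted incoming links)
+    //
+    // where W is the total word weight, w_i is 1 under "equal" or
+    // "equal-words" weighting and 1 / (1 + dist_i) otherwise, and distobj is
+    // either the document's own distribution or the combined cell
+    // distribution, depending on --context-type.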
+ val thislinks = WikipediaDocument.log_adjust_incoming_links( + doc.adjusted_incoming_links) + val topo_table = cell_grid.table.asInstanceOf[TopoDocumentTable] + val params = topo_table.topo_driver.params + + val gen_distobj = + if (params.context_type == "document") doc.dist + else doc.find_combined_word_dist(cell_grid).word_dist + val distobj = UnigramStrategy.check_unigram_dist(gen_distobj) + var totalprob = 0.0 + var total_word_weight = 0.0 + val (word_weight, baseline_weight) = + if (!use_baseline) (1.0, 0.0) + else if (params.naive_bayes_weighting == "equal") (1.0, 1.0) + else (1 - params.naive_bayes_baseline_weight, + params.naive_bayes_baseline_weight) + for ((dist, word) <- geogword.context) { + val lword = + if (params.preserve_case_words) word else word.toLowerCase + val wordprob = + distobj.lookup_word(WordDist.memoizer.memoize_string(lword)) + + // Compute weight for each word, based on distance from toponym + val thisweight = + if (params.naive_bayes_weighting == "equal" || + params.naive_bayes_weighting == "equal-words") 1.0 + else 1.0 / (1 + dist) + + total_word_weight += thisweight + totalprob += thisweight * log(wordprob) + } + if (debug("some")) + errprint("Computed total word log-likelihood as %s", totalprob) + // Normalize probability according to the total word weight + if (total_word_weight > 0) + totalprob /= total_word_weight + // Combine word and prior (baseline) probability acccording to their + // relative weights + totalprob *= word_weight + totalprob += baseline_weight * log(thislinks) + if (debug("some")) + errprint("Computed total log-likelihood as %s", totalprob) + totalprob + } +} + +class ToponymEvaluationResult { } +case class GeogWordDocument(words: Iterable[GeogWord]) + +abstract class GeolocateToponymEvaluator( + strategy: GeolocateToponymStrategy, + stratname: String, + driver: GeolocateToponymDriver +) extends CorpusEvaluator[ + GeogWordDocument, ToponymEvaluationResult +](stratname, driver) with DocumentIteratingEvaluator[ + GeogWordDocument, ToponymEvaluationResult +] { + val toponym_results = new GeolocateToponymResults(driver) + + // Given an evaluation file, read in the words specified, including the + // toponyms. Mark each word with the "document" (e.g. document) that it's + // within. + def iter_geogwords(filehand: FileHandler, filename: String): GeogWordDocument + + // Retrieve the words yielded by iter_geogwords() and separate by "document" + // (e.g. document); yield each "document" as a list of such GeogWord objects. + // If compute_context, also generate the set of "context" words used for + // disambiguation (some window, e.g. size 20, of words around each + // toponym). 
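+  // The window size is controlled by --naive-bayes-context-len (default 10
+  // words on either side), stopwords are excluded, and each context word is
+  // paired with its distance from the toponym.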
+ def iter_documents(filehand: FileHandler, filename: String) = { + def return_word(word: GeogWord) = { + if (word.is_toponym) { + if (debug("lots")) { + errprint("Saw loc %s with true coordinates %s, true location %s", + word.word, word.coord, word.location) + } + } else { + if (debug("tons")) + errprint("Non-toponym %s", word.word) + } + word + } + + for ((k, g) <- iter_geogwords(filehand, filename).words.groupBy(_.document)) + yield { + if (k != null) + errprint("Processing document %s...", k) + val results = (for (word <- g) yield return_word(word)).toArray + + // Now compute context for words + val nbcl = driver.params.naive_bayes_context_len + if (strategy.need_context()) { + // First determine whether each word is a stopword + for (i <- 0 until results.length) { + // FIXME: Check that we aren't accessing a list or something with + // O(N) random access + // If a word tagged as a toponym is homonymous with a stopword, it + // still isn't a stopword. + results(i).is_stop = (results(i).coord == null && + (driver.stopwords contains results(i).word)) + } + // Now generate context for toponyms + for (i <- 0 until results.length) { + // FIXME: Check that we aren't accessing a list or something with + // O(N) random access + if (results(i).coord != null) { + // Select up to naive_bayes_context_len words on either side; + // skip stopwords. Associate each word with the distance away from + // the toponym. + val minind = 0 max i - nbcl + val maxind = results.length min i + nbcl + 1 + results(i).context = + (for { + (dist, x) <- ((i - minind until i - maxind) zip results.slice(minind, maxind)) + if (!(driver.stopwords contains x.word)) + } yield (dist, x.word)).toArray + } + } + } + + val geogwords = + (for (word <- results if word.coord != null) yield word).toIterable + new GeogWordDocument(geogwords) + } + } + + // Disambiguate the toponym, specified in GEOGWORD. Determine the possible + // locations that the toponym can map to, and call COMPUTE_SCORE on each one + // to determine a score. The best score determines the location considered + // "correct". Locations without a matching document are skipped. The + // location considered "correct" is compared with the actual correct + // location specified in the toponym, and global variables corresponding to + // the total number of toponyms processed and number correctly determined are + // incremented. Various debugging info is output if 'debug' is set. + // COMPUTE_SCORE is passed two arguments: GEOGWORD and the location to + // compute the score of. 
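+  // In other words, disambiguation is a simple argmax of COMPUTE_SCORE over
+  // the candidate documents returned by construct_candidates(); ties are
+  // broken in favor of the earlier candidate, and toponyms without a
+  // ground-truth coordinate are skipped entirely.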
+ + def disambiguate_toponym(geogword: GeogWord) { + val toponym = geogword.word + val coord = geogword.coord + if (coord == null) return // If no ground-truth, skip it + val documents = + driver.document_table.wikipedia_subtable.construct_candidates(toponym) + var bestscore = Double.MinValue + var bestdoc: TopoDocument = null + if (documents.length == 0) { + if (debug("some")) + errprint("Unable to find any possibilities for %s", toponym) + } else { + if (debug("some")) { + errprint("Considering toponym %s, coordinates %s", + toponym, coord) + errprint("For toponym %s, %d possible documents", + toponym, documents.length) + } + for (idoc <- documents) { + val doc = idoc.asInstanceOf[TopoDocument] + if (debug("some")) + errprint("Considering document %s", doc) + val thisscore = strategy.compute_score(geogword, doc) + if (thisscore > bestscore) { + bestscore = thisscore + bestdoc = doc + } + } + } + val correct = + if (bestdoc != null) + bestdoc.matches_coord(coord) + else + false + + val num_candidates = documents.length + + val reason = + if (correct) null + else { + if (num_candidates == 0) + "incorrect_with_no_candidates" + else { + val good_docs = + (for { idoc <- documents + val doc = idoc.asInstanceOf[TopoDocument] + if doc.matches_coord(coord) + } + yield doc) + if (good_docs == null) + "incorrect_with_no_correct_candidates" + else if (good_docs.length > 1) + "incorrect_with_multiple_correct_candidates" + else { + val gooddoc = good_docs(0) + if (gooddoc.incoming_links == None) + "incorrect_one_correct_candidate_missing_link_info" + else + "incorrect_one_correct_candidate" + } + } + } + + errout("Eval: Toponym %s (true: %s at %s),", toponym, geogword.location, + coord) + if (correct) + errprint("correct") + else + errprint("incorrect, reason = %s", reason) + + toponym_results.record_geolocate_toponym_result(correct, toponym, + geogword.location, reason, num_candidates) + + if (debug("some") && bestdoc != null) { + errprint("Best document = %s, score = %s, dist = %s, correct %s", + bestdoc, bestscore, bestdoc.distance_to_coord(coord), correct) + } + } + + def evaluate_document(doc: GeogWordDocument, doctag: String) = { + for (geogword <- doc.words) + disambiguate_toponym(geogword) + new ToponymEvaluationResult() + } + + def output_results(isfinal: Boolean = false) { + toponym_results.output_geolocate_toponym_results() + } +} + +class TRCoNLLGeolocateToponymEvaluator( + strategy: GeolocateToponymStrategy, + stratname: String, + driver: GeolocateToponymDriver +) extends GeolocateToponymEvaluator(strategy, stratname, driver) { + // Read a file formatted in TR-CONLL text format (.tr files). An example of + // how such files are fomatted is: + // + //... + //... + //last O I-NP JJ + //week O I-NP NN + //&equo;s O B-NP POS + //U.N. I-ORG I-NP NNP + //Security I-ORG I-NP NNP + //Council I-ORG I-NP NNP + //resolution O I-NP NN + //threatening O I-VP VBG + //a O I-NP DT + //ban O I-NP NN + //on O I-PP IN + //Sudanese I-MISC I-NP NNP + //flights O I-NP NNS + //abroad O I-ADVP RB + //if O I-SBAR IN + //Khartoum LOC + // >c1 NGA 15.5833333 32.5333333 Khartoum > Al Kharom > Sudan + // c2 NGA -17.8833333 30.1166667 Khartoum > Zimbabwe + // c3 NGA 15.5880556 32.5341667 Khartoum > Al Kharom > Sudan + // c4 NGA 15.75 32.5 Khartoum > Al Kharom > Sudan + //does O I-VP VBZ + //not O I-NP RB + //hand O I-NP NN + //over O I-PP IN + //three O I-NP CD + //men O I-NP NNS + //... + //... + // + // Yield GeogWord objects, one per word. 
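+  // Note that only the candidate line whose tag begins with '>' (the gold
+  // candidate, ">c1" in the example above) is parsed: it supplies the
+  // toponym's ground-truth coordinate and location string.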
+ def iter_geogwords(filehand: FileHandler, filename: String) = { + var in_loc = false + var wordstruct: GeogWord = null + val lines = filehand.openr(filename, errors = "replace") + def iter_1(): Stream[GeogWord] = { + if (lines.hasNext) { + val line = lines.next + try { + val ss = """\t""".r.split(line) + require(ss.length == 2) + val Array(word, ty) = ss + if (word != null) { + var toyield = null: GeogWord + if (in_loc) { + in_loc = false + toyield = wordstruct + } + wordstruct = new GeogWord(word) + wordstruct.document = filename + if (ty.startsWith("LOC")) { + in_loc = true + wordstruct.is_toponym = true + } else + toyield = wordstruct + if (toyield != null) + return toyield #:: iter_1() + } else if (in_loc && ty(0) == '>') { + val ss = """\t""".r.split(ty) + require(ss.length == 5) + val Array(_, lat, long, fulltop, _) = ss + wordstruct.coord = SphereCoord(lat.toDouble, long.toDouble) + wordstruct.location = fulltop + } + } catch { + case exc: Exception => { + errprint("Bad line %s", line) + errprint("Exception is %s", exc) + if (!exc.isInstanceOf[NumberFormatException]) + exc.printStackTrace() + } + } + return iter_1() + } else if (in_loc) + return wordstruct #:: Stream[GeogWord]() + else + return Stream[GeogWord]() + } + new GeogWordDocument(iter_1()) + } +} + +class WikipediaGeolocateToponymEvaluator( + strategy: GeolocateToponymStrategy, + stratname: String, + driver: GeolocateToponymDriver +) extends GeolocateToponymEvaluator(strategy, stratname, driver) { + def iter_geogwords(filehand: FileHandler, filename: String) = { + var title: String = null + val titlere = """Article title: (.*)$""".r + val linkre = """Link: (.*)$""".r + val lines = filehand.openr(filename, errors = "replace") + def iter_1(): Stream[GeogWord] = { + if (lines.hasNext) { + val line = lines.next + line match { + case titlere(mtitle) => { + title = mtitle + iter_1() + } + case linkre(mlink) => { + val args = mlink.split('|') + val truedoc = args(0) + var linkword = truedoc + if (args.length > 1) + linkword = args(1) + val word = new GeogWord(linkword) + word.is_toponym = true + word.location = truedoc + word.document = title + val doc = + driver.document_table.wikipedia_subtable. + lookup_document(truedoc) + if (doc != null) + word.coord = doc.coord + word #:: iter_1() + } + case _ => { + val word = new GeogWord(line) + word.document = title + word #:: iter_1() + } + } + } else + Stream[GeogWord]() + } + new GeogWordDocument(iter_1()) + } +} + +class Gazetteer(val cell_grid: TopoCellGrid) { + + // Factory object for creating new divisions relative to the gazetteer + val divfactory = new DivisionFactory(this) + + // For each toponym (name of location), value is a list of Locality + // items, listing gazetteer locations and corresponding matching documents. + val lower_toponym_to_location = bufmap[String,Locality]() + + // For each toponym corresponding to a division higher than a locality, + // list of divisions with this name. + val lower_toponym_to_division = bufmap[String,Division]() + + // Table of all toponyms seen in evaluation files, along with how many + // times seen. Used to determine when caching of certain + // toponym-specific values should be done. + //val toponyms_seen_in_eval_files = intmap[String]() + + /** + * Record mapping from name to Division. + */ + def record_division(name: String, div: Division) { + lower_toponym_to_division(name) += div + } + + // Given an evaluation file, count the toponyms seen and add to the + // global count in toponyms_seen_in_eval_files. 
+ // def count_toponyms_in_file(fname: String) { + // def count_toponyms(geogword: GeogWord) { + // toponyms_seen_in_eval_files(geogword.word.toLowerCase) += 1 + // } + // process_eval_file(fname, count_toponyms, compute_context = false, + // only_toponyms = true) + // } +} + +/** + * Gazetteer of the World-gazetteer format. + * + * @param filename File holding the World Gazetteer. + * + * @param cell_grid FIXME: Currently required for certain internal reasons. + * Fix so we don't need it, or it's created internally! + */ +class WorldGazetteer( + filehand: FileHandler, + filename: String, + cell_grid: TopoCellGrid +) extends Gazetteer(cell_grid) { + + // Find the document matching an entry in the gazetteer. + // The format of an entry is + // + // ID NAME ALTNAMES ORIG-SCRIPT-NAME TYPE POPULATION LAT LONG DIV1 DIV2 DIV3 + // + // where there is a tab character separating each field. Fields may + // be empty; but there will still be a tab character separating the + // field from others. + // + // The ALTNAMES specify any alternative names of the location, often + // including the equivalent of the original name without any accent + // characters. If there is more than one alternative name, the + // possibilities are separated by a comma and a space, e.g. + // "Dongshi, Dongshih, Tungshih". The ORIG-SCRIPT-NAME is the name + // in its original script, if that script is not Latin characters + // (e.g. names in Russia will be in Cyrillic). (For some reason, names + // in Chinese characters are listed in the ALTNAMES rather than the + // ORIG-SCRIPT-NAME.) + // + // LAT and LONG specify the latitude and longitude, respectively. + // These are given as integer values, where the actual value is found + // by dividing this integer value by 100. + // + // DIV1, DIV2 and DIV3 specify different-level divisions that a location is + // within, from largest to smallest. Typically the largest is a country. + // For locations in the U.S., the next two levels will be state and county, + // respectively. Note that such divisions also have corresponding entries + // in the gazetteer. However, these entries are somewhat lacking in that + // (1) no coordinates are given, and (2) only the top-level division (the + // country) is given, even for third-level divisions (e.g. counties in the + // U.S.). + // + // For localities, add them to the cell-map that covers the earth if + // ADD_TO_CELL_MAP is true. 
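+  // As an illustration only (this entry is made up, not taken from the
+  // gazetteer), a tab-separated line such as
+  //
+  //   1234<TAB>Dongshi<TAB>Dongshih, Tungshih<TAB><TAB>locality<TAB>50000<TAB>2418<TAB>12087<TAB>Taiwan<TAB><TAB>
+  //
+  // would be parsed into a Locality named "Dongshi" at coordinate
+  // (24.18, 120.87), with alternate names "Dongshih" and "Tungshih" and
+  // top-level division "Taiwan".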
+ + protected def match_world_gazetteer_entry(line: String) { + // Split on tabs, make sure at least 11 fields present and strip off + // extra whitespace + var fields = """\t""".r.split(line.trim) ++ Seq.fill(11)("") + fields = (for (x <- fields.slice(0, 11)) yield x.trim) + val Array(id, name, altnames, orig_script_name, typ, population, + lat, long, div1, div2, div3) = fields + + // Skip places without coordinates + if (lat == "" || long == "") { + if (debug("lots")) + errprint("Skipping location %s (div %s/%s/%s) without coordinates", + name, div1, div2, div3) + return + } + + if (lat == "0" && long == "9999") { + if (debug("lots")) + errprint("Skipping location %s (div %s/%s/%s) with bad coordinates", + name, div1, div2, div3) + return + } + + // Create and populate a Locality object + val loc = new Locality(name, SphereCoord(lat.toInt / 100., long.toInt / 100.), + typ = typ, altnames = if (altnames != null) ", ".r.split(altnames) else null) + loc.div = divfactory.find_division_note_point(loc, Seq(div1, div2, div3)) + if (debug("lots")) + errprint("Saw location %s (div %s/%s/%s) with coordinates %s", + loc.name, div1, div2, div3, loc.coord) + + // Record the location. For each name for the location (its + // canonical name and all alternates), add the location to the list of + // locations associated with the name. Record the name in lowercase + // for ease in matching. + for (name <- Seq(loc.name) ++ loc.altnames) { + val loname = name.toLowerCase + if (debug("lots")) + errprint("Noting lower_toponym_to_location for toponym %s, canonical name %s", + name, loc.name) + lower_toponym_to_location(loname) += loc + } + + // We start out looking for documents whose distance is very close, + // then widen until we reach params.max_dist_for_close_match. + var maxdist = 5 + var docmatch: TopoDocument = null + val topo_table = cell_grid.table.asInstanceOf[TopoDocumentTable] + val params = topo_table.topo_driver.params + + breakable { + while (maxdist <= params.max_dist_for_close_match) { + docmatch = + topo_table.topo_subtable.find_match_for_locality(loc, maxdist) + if (docmatch != null) break + maxdist *= 2 + } + } + + if (docmatch == null) { + if (debug("lots")) + errprint("Unmatched name %s", loc.name) + return + } + + // Record the match. + loc.docmatch = docmatch + docmatch.location = loc + if (debug("lots")) + errprint("Matched location %s (coord %s) with document %s, dist=%s", + (loc.name, loc.coord, docmatch, + spheredist(loc.coord, docmatch.coord))) + } + + // Read in the data from the World gazetteer in FILENAME and find the + // document matching each entry in the gazetteer. (Unimplemented: + // For localities, add them to the cell-map that covers the earth if + // ADD_TO_CELL_MAP is true.) 
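+  // Entries are matched one line at a time via match_world_gazetteer_entry();
+  // the metered task below respects --max-time-per-stage, so matching is cut
+  // short if that limit is reached.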
+ protected def read_world_gazetteer_and_match() { + val topo_table = cell_grid.table.asInstanceOf[TopoDocumentTable] + val params = topo_table.topo_driver.params + + val task = new ExperimentMeteredTask(cell_grid.table.driver, + "gazetteer entry", "matching", maxtime = params.max_time_per_stage) + errprint("Matching gazetteer entries in %s...", filename) + errprint("") + + // Match each entry in the gazetteer + breakable { + val lines = filehand.openr(filename) + try { + for (line <- lines) { + if (debug("lots")) + errprint("Processing line: %s", line) + match_world_gazetteer_entry(line) + if (task.item_processed()) + break + } + } finally { + lines.close() + } + } + + divfactory.finish_all() + task.finish() + output_resource_usage() + } + + // Upon creation, populate gazetteer from file + read_world_gazetteer_and_match() +} + +class GeolocateToponymParameters( + parser: ArgParser = null +) extends GeolocateParameters(parser) { + var eval_format = + ap.option[String]("f", "eval-format", + default = "wikipedia", + choices = Seq("wikipedia", "raw-text", "tr-conll"), + help = """Format of evaluation file(s). The evaluation files themselves +are specified using --eval-file. The following formats are +recognized: + +'wikipedia' is the normal format. The data file is in a format very +similar to that of the counts file, but has "toponyms" identified using +the prefix 'Link: ' followed either by a toponym name or the format +'DOCUMENT-NAME|TOPONYM', indicating a toponym (e.g. 'London') that maps +to a given document that disambiguates the toponym (e.g. 'London, +Ontario'). When a raw toponym is given, the document is assumed to have +the same name as the toponym. (The format is called 'wikipedia' because +the link format used here comes directly from the way that links are +specified in Wikipedia documents, and often the text itself comes from +preprocessing a Wikipedia dump.) The mapping here is used for evaluation +but not for constructing training data. + +'raw-text' assumes that the eval file is simply raw text. +(NOT YET IMPLEMENTED.) + +'tr-conll' is an alternative. It specifies the toponyms in a document +along with possible locations to map to, with the correct one identified. +As with the 'document' format, the correct location is used only for +evaluation, not for constructing training data; the other locations are +ignored.""") + + var gazetteer_file = + ap.option[String]("gazetteer-file", "gf", + help = """File containing gazetteer information to match.""") + var gazetteer_type = + ap.option[String]("gazetteer-type", "gt", + metavar = "FILE", + default = "world", choices = Seq("world", "db"), + help = """Type of gazetteer file specified using --gazetteer-file. +NOTE: type 'world' is the only one currently implemented. Default +'%default'.""") + + var strategy = + ap.multiOption[String]("s", "strategy", + default = Seq("baseline"), + aliasedChoices = Seq( + Seq("baseline"), + Seq("none"), + Seq("naive-bayes-with-baseline", "nb-base"), + Seq("naive-bayes-no-baseline", "nb-nobase")), + help = """Strategy/strategies to use for geolocating. +'baseline' means just use the baseline strategy (see --baseline-strategy). + +'none' means don't do any geolocating. Useful for testing the parts that +read in data and generate internal structures. 
+ +'naive-bayes-with-baseline' (or 'nb-base') means also use the words around the +toponym to be disambiguated, in a Naive-Bayes scheme, using the baseline as the +prior probability; 'naive-bayes-no-baseline' (or 'nb-nobase') means use uniform +prior probability. + +Default is 'baseline'. + +NOTE: Multiple --strategy options can be given, and each strategy will +be tried, one after the other.""") + + var baseline_strategy = + ap.multiOption[String]("baseline-strategy", "bs", + default = Seq("internal-link"), + aliasedChoices = Seq( + Seq("internal-link", "link"), + Seq("random"), + Seq("num-documents", "numdocs", "num-docs")), + help = """Strategy to use to compute the baseline. + +'internal-link' (or 'link') means use number of internal links pointing to the +document or cell. + +'random' means choose randomly. + +'num-documents' (or 'num-docs' or 'numdocs'; only in cell-type matching) means +use number of documents in cell. + +Default '%default'. + +NOTE: Multiple --baseline-strategy options can be given, and each strategy will +be tried, one after the other.""") + + var naive_bayes_context_len = + ap.option[Int]("naive-bayes-context-len", "nbcl", + default = 10, + help = """Number of words on either side of a toponym to use +in Naive Bayes matching, during toponym resolution. Default %default.""") + var max_dist_for_close_match = + ap.option[Double]("max-dist-for-close-match", "mdcm", + default = 80.0, + help = """Maximum number of km allowed when looking for a +close match for a toponym during toponym resolution. Default %default.""") + var max_dist_for_outliers = + ap.option[Double]("max-dist-for-outliers", "mdo", + default = 200.0, + help = """Maximum number of km allowed between a point and +any others in a division, during toponym resolution. Points farther away than +this are ignored as "outliers" (possible errors, etc.). NOTE: Not +currently implemented. Default %default.""") + var context_type = + ap.option[String]("context-type", "ct", + default = "cell-dist-document-links", + choices = Seq("document", "cell", "cell-dist-document-links"), + help = """Type of context used when doing disambiguation. +There are two cases where this choice applies: When computing a word +distribution, and when counting the number of incoming internal links. +'document' means use the document itself for both. 'cell' means use the +cell for both. 'cell-dist-document-links' means use the cell for +computing a word distribution, but the document for counting the number of +incoming internal links. (Note that this only applies when doing toponym +resolution. During document resolution, only cells are considered.) +Default '%default'.""") +} + +class GeolocateToponymDriver extends + GeolocateDriver with StandaloneExperimentDriverStats { + type TParam = GeolocateToponymParameters + type TRunRes = + Seq[(String, GeolocateToponymStrategy, CorpusEvaluator[_,_])] + + override def handle_parameters() { + super.handle_parameters() + + /* FIXME: Eliminate this. */ + GeolocateToponymApp.Params = params + + need(params.gazetteer_file, "gazetteer-file") + // FIXME! Can only currently handle World-type gazetteers. + if (params.gazetteer_type != "world") + param_error("Currently can only handle world-type gazetteers") + + if (params.strategy == Seq("baseline")) + () + + if (params.eval_format == "raw-text") { + // FIXME!!!! 
+ param_error("Raw-text reading not implemented yet") + } + + need_seq(params.eval_file, "eval-file", "evaluation file(s)") + } + + override protected def initialize_document_table( + word_dist_factory: WordDistFactory) = { + new TopoDocumentTable(this, word_dist_factory) + } + + override protected def initialize_cell_grid(table: SphereDocumentTable) = { + if (params.kd_tree) + param_error("Can't currently handle K-d trees") + new TopoCellGrid(degrees_per_cell, params.width_of_multi_cell, + table.asInstanceOf[TopoDocumentTable]) + } + + /** + * Do the actual toponym geolocation. Results to stderr (see above), and + * also returned. + * + * Return value very much like for run_after_setup() for document + * geolocation, but less useful info may be returned for each document + * processed. + */ + + def run_after_setup() = { + // errprint("Processing evaluation file(s) %s for toponym counts...", + // args.eval_file) + // process_dir_files(args.eval_file, count_toponyms_in_file) + // errprint("Number of toponyms seen: %s", + // toponyms_seen_in_eval_files.length) + // errprint("Number of toponyms seen more than once: %s", + // (for {(foo,count) <- toponyms_seen_in_eval_files + // if (count > 1)} yield foo).length) + // output_reverse_sorted_table(toponyms_seen_in_eval_files, + // outfile = sys.stderr) + + if (params.gazetteer_file != null) { + /* FIXME!!! */ + assert(cell_grid.isInstanceOf[TopoCellGrid]) + val gazetteer = + new WorldGazetteer(get_file_handler, params.gazetteer_file, + cell_grid.asInstanceOf[TopoCellGrid]) + // Bootstrapping issue: Creating the gazetteer requires that the + // TopoDocumentTable already exist, but the TopoDocumentTable wants + // a pointer to a gazetter, so have to set it afterwards. + document_table.wikipedia_subtable. + asInstanceOf[TopoDocumentSubtable].set_gazetteer(gazetteer) + } + + val strats_unflat = ( + for (stratname <- params.strategy) yield { + // Generate strategy object + if (stratname == "baseline") { + for (basestratname <- params.baseline_strategy) + yield ("baseline " + basestratname, + new BaselineGeolocateToponymStrategy(cell_grid, basestratname)) + } else { + val strategy = new NaiveBayesToponymStrategy(cell_grid, + use_baseline = (stratname == "naive-bayes-with-baseline")) + Seq((stratname, strategy)) + } + }) + val strats = strats_unflat reduce (_ ++ _) + process_strategies(strats)((stratname, strategy) => { + // Generate reader object + if (params.eval_format == "tr-conll") + new TRCoNLLGeolocateToponymEvaluator(strategy, stratname, this) + else + new WikipediaGeolocateToponymEvaluator(strategy, stratname, this) + }) + } +} + +object GeolocateToponymApp extends GeolocateApp("geolocate-toponyms") { + type TDriver = GeolocateToponymDriver + // FUCKING TYPE ERASURE + def create_param_object(ap: ArgParser) = new TParam(ap) + def create_driver() = new TDriver() + var Params: TParam = _ +} diff --git a/src/main/scala/opennlp/fieldspring/gridlocate/Cell.scala b/src/main/scala/opennlp/fieldspring/gridlocate/Cell.scala new file mode 100644 index 0000000..4cd37ff --- /dev/null +++ b/src/main/scala/opennlp/fieldspring/gridlocate/Cell.scala @@ -0,0 +1,480 @@ +/////////////////////////////////////////////////////////////////////////////// +// Cell.scala +// +// Copyright (C) 2010, 2011, 2012 Ben Wing, The University of Texas at Austin +// Copyright (C) 2011, 2012 Stephen Roller, The University of Texas at Austin +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +/////////////////////////////////////////////////////////////////////////////// + +package opennlp.fieldspring.gridlocate + +import opennlp.fieldspring.util.printutil.{errprint, warning} +import opennlp.fieldspring.util.experiment._ + +import opennlp.fieldspring.worddist.WordDistFactory +/* FIXME: Eliminate this. */ +import GridLocateDriver.Params + +///////////////////////////////////////////////////////////////////////////// +// Word distributions // +///////////////////////////////////////////////////////////////////////////// + +/** + * Distribution over words resulting from combining the individual + * distributions of a number of documents. We track the number of + * documents making up the distribution, as well as the total incoming link + * count for all of these documents. Note that some documents contribute + * to the link count but not the word distribution; hence, there are two + * concepts of "empty", depending on whether all contributing documents or + * only those that contributed to the word distribution are counted. + * (The primary reason for documents not contributing to the distribution + * is that they're not in the training set; see comments below. However, + * some documents simply don't have distributions defined for them in the + * document file -- e.g. if there was a problem extracting the document's + * words in the preprocessing stage.) + * + * Note that we embed the actual object describing the word distribution + * as a field in this object, rather than extending (subclassing) WordDist. + * The reason for this is that there are multiple types of WordDists, and + * so subclassing would require creating a different subclass for every + * such type, along with extra boilerplate functions to create objects of + * these subclasses. + */ +class CombinedWordDist(factory: WordDistFactory) { + /** The combined word distribution itself. */ + val word_dist = factory.create_word_dist() + /** Number of documents included in incoming-link computation. */ + var num_docs_for_links = 0 + /** Total number of incoming links. */ + var incoming_links = 0 + /** Number of documents included in word distribution. All such + * documents also contribute to the incoming link count. */ + var num_docs_for_word_dist = 0 + + /** True if no documents have contributed to the word distribution. + * This should generally be the same as if the distribution is empty + * (unless documents with an empty distribution were added??). */ + def is_empty_for_word_dist() = num_docs_for_word_dist == 0 + + /** True if the object is completely empty. This means no documents + * at all have been added using `add_document`. */ + def is_empty() = num_docs_for_links == 0 + + /** + * Add the given document to the total distribution seen so far. + * `partial` is a scaling factor (between 0.0 and 1.0) used for + * interpolating multiple distributions. + */ + def add_document(doc: DistDocument[_], partial: Double = 1.0) { + /* Formerly, we arranged things so that we were passed in all documents, + regardless of the split. 
The reason for this was that the decision + was made to accumulate link counts from all documents, even in the + evaluation set. + + Strictly, this is a violation of the "don't train on your evaluation + set" rule. The reason motivating this was that + + (1) The links are used only in Naive Bayes, and only in establishing + a prior probability. Hence they aren't the main indicator. + (2) Often, nearly all the link count for a given cell comes from + a particular document -- e.g. the Wikipedia article for the primary + city in the cell. If we pull the link count for this document + out of the cell because it happens to be in the evaluation set, + we will totally distort the link count for this cell. In a "real" + usage case, we would be testing against an unknown document, not + against a document in our training set that we've artificially + removed so as to construct an evaluation set, and this problem + wouldn't arise, so by doing this we are doing a more realistic + evaluation. + + Note that we do NOT include word counts from dev-set or test-set + documents in the word distribution for a cell. This keeps to the + above rule about only training on your training set, and is OK + because (1) each document in a cell contributes a similar amount of + word counts (assuming the documents are somewhat similar in size), + hence in a cell with multiple documents, each individual document + only computes a fairly small fraction of the total word counts; + (2) distributions are normalized in any case, so the exact number + of documents in a cell does not affect the distribution. + + However, once the corpora were separated into sub-corpora based on + the training/dev/test split, passing in all documents complicated + things, as it meant having to read all the sub-corpora. Furthermore, + passing in non-training documents into the K-d cell grid changes the + grids in ways that are not easily predictable -- a significantly + greater effect than simply changing the link counts. So (for the + moment at least) we don't do this any more. */ + assert (doc.split == "training") + + /* Add link count of document to cell. */ + doc.incoming_links match { + // Might be None, for unknown link count + case Some(x) => incoming_links += x + case _ => + } + num_docs_for_links += 1 + + if (doc.dist == null) { + if (Params.max_time_per_stage == 0.0 && Params.num_training_docs == 0) + warning("Saw document %s without distribution", doc) + } else { + word_dist.add_word_distribution(doc.dist, partial) + num_docs_for_word_dist += 1 + } + } +} + +///////////////////////////////////////////////////////////////////////////// +// Cells in a grid // +///////////////////////////////////////////////////////////////////////////// + +/** + * Abstract class for a general cell in a cell grid. + * + * @param cell_grid The CellGrid object for the grid this cell is in. + * @tparam TCoord The type of the coordinate object used to specify a + * a point somewhere in the grid. + * @tparam TDoc The type of documents stored in a cell in the grid. + */ +abstract class GeoCell[TCoord, TDoc <: DistDocument[TCoord]]( + val cell_grid: CellGrid[TCoord, TDoc, + _ <: GeoCell[TCoord, TDoc]] +) { + val combined_dist = + new CombinedWordDist(cell_grid.table.word_dist_factory) + var most_popular_document: TDoc = _ + var mostpopdoc_links = 0 + + /** + * Return a string describing the location of the cell in its grid, + * e.g. by its boundaries or similar. 
+ */ + def describe_location(): String + + /** + * Return a string describing the indices of the cell in its grid. + * Only used for debugging. + */ + def describe_indices(): String + + /** + * Return the coordinate of the "center" of the cell. This is the + * coordinate used in computing distances between arbitary points and + * given cells, for evaluation and such. For odd-shaped cells, the + * center can be more or less arbitrarily placed as long as it's somewhere + * central. + */ + def get_center_coord(): TCoord + + /** + * Return true if we have finished creating and populating the cell. + */ + def finished = combined_dist.word_dist.finished + /** + * Return a string representation of the cell. Generally does not need + * to be overridden. + */ + override def toString = { + val unfinished = if (finished) "" else ", unfinished" + val contains = + if (most_popular_document != null) + ", most-pop-doc %s(%d links)" format ( + most_popular_document, mostpopdoc_links) + else "" + + "GeoCell(%s%s%s, %d documents(dist), %d documents(links), %s types, %s tokens, %d links)" format ( + describe_location(), unfinished, contains, + combined_dist.num_docs_for_word_dist, + combined_dist.num_docs_for_links, + combined_dist.word_dist.model.num_types, + combined_dist.word_dist.model.num_tokens, + combined_dist.incoming_links) + } + + // def __repr__() = { + // toString.encode("utf-8") + // } + + /** + * Return a shorter string representation of the cell, for + * logging purposes. + */ + def shortstr = { + var str = "Cell %s" format describe_location() + val mostpop = most_popular_document + if (mostpop != null) + str += ", most-popular %s" format mostpop.shortstr + str + } + + /** + * Return an XML representation of the cell. Currently used only for + * debugging-output purposes, so the exact representation isn't too important. + */ + def struct() = + + { describe_location() } + { finished } + { + if (most_popular_document != null) + (most_popular_document.struct() + mostpopdoc_links) + } + { combined_dist.num_docs_for_word_dist } + { combined_dist.num_docs_for_links } + { combined_dist.incoming_links } + + + /** + * Add a document to the distribution for the cell. + */ + def add_document(doc: TDoc) { + assert(!finished) + combined_dist.add_document(doc) + if (doc.incoming_links != None && + doc.incoming_links.get > mostpopdoc_links) { + mostpopdoc_links = doc.incoming_links.get + most_popular_document = doc + } + } + + /** + * Finish any computations related to the cell's word distribution. + */ + def finish() { + assert(!finished) + combined_dist.word_dist.finish_before_global() + combined_dist.word_dist.finish_after_global() + } +} + +/** + * A mix-in trait for GeoCells that create their distribution by remembering + * all the documents that go into the distribution, and then generating + * the distribution from them at the end. + * + * NOTE: This is *not* the ideal way of doing things! It can cause + * out-of-memory errors for large corpora. It is better to create the + * distributions on the fly. Note that for K-d cells this may require + * two passes over the input corpus: One to note the documents that go into + * the cells and create the cells appropriately, and another to add the + * document distributions to those cells. If so, we should add a function + * to cell grids indicating whether they want the documents given to them + * in two passes, and modify the code in DistDocumentTable (DistDocument.scala) + * so that it does two passes over the documents if so requested. 
+ */ +trait DocumentRememberingCell[TCoord, TDoc <: DistDocument[TCoord]] { + this: GeoCell[TCoord, TDoc] => + + /** + * Return an Iterable over documents, listing the documents in the cell. + */ + def iterate_documents(): Iterable[TDoc] + + /** + * Generate the distribution for the cell from the documents in it. + */ + def generate_dist() { + assert(!finished) + for (doc <- iterate_documents()) + add_document(doc) + finish() + } +} + +/** + * Abstract class for a general grid of cells. The grid is defined over + * a continuous space (e.g. the surface of the Earth). The space is indexed + * by coordinates (of type TCoord). Each cell (of type TCell) covers + * some portion of the space. There is also a set of documents (of type + * TDoc), each of which is indexed by a coordinate and which has a + * distribution describing the contents of the document. The distributions + * of all the documents in a cell (i.e. whose coordinate is within the cell) + * are amalgamated to form the distribution of the cell. + * + * One example is the SphereCellGrid -- a grid of cells covering the Earth. + * ("Sphere" is used here in its mathematical meaning of the surface of a + * round ball.) Coordinates, of type SphereCoord, are pairs of latitude and + * longitude. Documents are of type SphereDocument and have a SphereCoord + * as their coordinate. Cells are of type SphereCell. Subclasses of + * SphereCellGrid refer to particular grid cell shapes. For example, the + * MultiRegularCellGrid consists of a regular tiling of the surface of the + * Earth into "rectangles" defined by minimum and maximum latitudes and + * longitudes. Most commonly, each tile is a cell, but it is possible for + * a cell to consist of an NxN square of tiles, in which case the cells + * overlap. Another subclass is KDTreeCellGrid, with rectangular cells of + * variable size so that the number of documents in a given cell stays more + * or less constant. + * + * Another possibility would be a grid indexed by years, where each cell + * corresponds to a particular range of years. + * + * In general, no assumptions are made about the shapes of cells in the grid, + * the number of dimensions in the grid, or whether the cells are overlapping. + * + * The following operations are used to populate a cell grid: + * + * (1) Documents are added one-by-one to a grid by calling + * `add_document_to_cell`. + * (2) After all documents have been added, `initialize_cells` is called + * to generate the cells and create their distribution. + * (3) After this, it should be possible to list the cells by calling + * `iter_nonempty_cells`. + */ +abstract class CellGrid[ + TCoord, + TDoc <: DistDocument[TCoord], + TCell <: GeoCell[TCoord, TDoc] +]( + val table: DistDocumentTable[TCoord, TDoc, _ <: CellGrid[TCoord, TDoc, TCell]] +) { + + /** + * Total number of cells in the grid. + */ + var total_num_cells: Int + + /* + * Number of times to pass over the training corpus + * and call add_document() + */ + val num_training_passes: Int = 1 + + /* + * Called before each new pass of the training. Usually not + * needed, but needed for KDCellGrid and possibly future CellGrids. + */ + def begin_training_pass(pass: Int) = {} + + /** + * Find the correct cell for the given document, based on the document's + * coordinates and other properties. If no such cell exists, return null + * if `create` is false. Else, create an empty cell to hold the + * coordinates -- but do *NOT* record the cell or otherwise alter the + * existing cell configuration. 
This situation where such a cell is needed + * is during evaluation. The cell is needed purely for comparing it against + * existing cells and determining its center. The reason for not recording + * such cells is to make sure that future evaluation results aren't affected. + */ + def find_best_cell_for_document(doc: TDoc, create_non_recorded: Boolean): + TCell + + /** + * Add the given document to the cell grid. + */ + def add_document_to_cell(document: TDoc): Unit + + /** + * Generate all non-empty cells. This will be called once (and only once), + * after all documents have been added to the cell grid by calling + * `add_document_to_cell`. The generation happens internally; but after + * this, `iter_nonempty_cells` should work properly. This is not meant + * to be called externally. + */ + protected def initialize_cells(): Unit + + /** + * Iterate over all non-empty cells. + * + * @param nonempty_word_dist If given, returned cells must also have a + * non-empty word distribution; otherwise, they just need to have at least + * one document in them. (Not all documents have word distributions, esp. + * when --max-time-per-stage has been set to a non-zero value so that we + * only load some subset of the word distributions for all documents. But + * even when not set, some documents may be listed in the document-data file + * but have no corresponding word counts given in the counts file.) + */ + def iter_nonempty_cells(nonempty_word_dist: Boolean = false): + Iterable[TCell] + + /** + * Iterate over all non-empty cells. + * + * @param nonempty_word_dist If given, returned cells must also have a + * non-empty word distribution; otherwise, they just need to have at least + * one document in them. (Not all documents have word distributions, esp. + * when --max-time-per-stage has been set to a non-zero value so that we + * only load some subset of the word distributions for all documents. But + * even when not set, some documents may be listed in the document-data file + * but have no corresponding word counts given in the counts file.) + */ + def iter_nonempty_cells_including(include: Iterable[TCell], + nonempty_word_dist: Boolean = false) = { + val cells = iter_nonempty_cells(nonempty_word_dist) + if (include.size == 0) + cells + else + include.toSeq union cells.toSeq + } + + /*********************** Not meant to be overridden *********************/ + + /* These are simply the sum of the corresponding counts + `num_docs_for_word_dist` and `num_docs_for_links` of each individual + cell. */ + var total_num_docs_for_word_dist = 0 + var total_num_docs_for_links = 0 + /* Set once finish() is called. */ + var all_cells_computed = false + /* Number of non-empty cells. */ + var num_non_empty_cells = 0 + + /** + * This function is called externally to initialize the cells. It is a + * wrapper around `initialize_cells()`, which is not meant to be called + * externally. Normally this does not need to be overridden. 
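+   * After initializing the cells it also accumulates per-cell document
+   * counts, reports summary statistics, and clears the training documents'
+   * distributions, which are no longer needed once the cells are built.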
+ */ + def finish() { + assert(!all_cells_computed) + + initialize_cells() + + all_cells_computed = true + + total_num_docs_for_links = 0 + total_num_docs_for_word_dist = 0 + + { // Put in a block to control scope of 'task' + val task = new ExperimentMeteredTask(table.driver, "non-empty cell", + "computing statistics of") + for (cell <- iter_nonempty_cells()) { + total_num_docs_for_word_dist += + cell.combined_dist.num_docs_for_word_dist + total_num_docs_for_links += + cell.combined_dist.num_docs_for_links + task.item_processed() + } + task.finish() + } + + errprint("Number of non-empty cells: %s", num_non_empty_cells) + errprint("Total number of cells: %s", total_num_cells) + errprint("Percent non-empty cells: %g", + num_non_empty_cells.toDouble / total_num_cells) + val recorded_training_docs_with_coordinates = + table.num_recorded_documents_with_coordinates_by_split("training").value + errprint("Training documents per non-empty cell: %g", + recorded_training_docs_with_coordinates.toDouble / num_non_empty_cells) + // Clear out the document distributions of the training set, since + // only needed when computing cells. + // + // FIXME: Could perhaps save more memory, or at least total memory used, + // by never creating these distributions at all, but directly adding + // them to the cells. Would require a bit of thinking when reading + // in the counts. + table.driver.heartbeat + table.clear_training_document_distributions() + table.driver.heartbeat + } +} diff --git a/src/main/scala/opennlp/fieldspring/gridlocate/CellDist.scala b/src/main/scala/opennlp/fieldspring/gridlocate/CellDist.scala new file mode 100644 index 0000000..3b9c645 --- /dev/null +++ b/src/main/scala/opennlp/fieldspring/gridlocate/CellDist.scala @@ -0,0 +1,195 @@ +/////////////////////////////////////////////////////////////////////////////// +// CellDist.scala +// +// Copyright (C) 2010, 2011, 2012 Ben Wing, The University of Texas at Austin +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +/////////////////////////////////////////////////////////////////////////////// + +package opennlp.fieldspring.gridlocate + +import collection.mutable + +import opennlp.fieldspring.util.collectionutil.{LRUCache, doublemap} +import opennlp.fieldspring.util.printutil.{errprint, warning} + +import opennlp.fieldspring.worddist.{WordDist,UnigramWordDist} +import opennlp.fieldspring.worddist.WordDist.memoizer._ + +/** + * A general distribution over cells, associating a probability with each + * cell. The caller needs to provide the probabilities. 
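+ *
+ * A minimal usage sketch (the cell names here are purely illustrative):
+ *
+ * {{{
+ * val dist = new CellDist(cell_grid)
+ * dist.set_cell_probabilities(Map(cell_a -> 0.7, cell_b -> 0.3))
+ * val ranked = dist.get_ranked_cells(Iterable()) // highest probability first
+ * }}}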
+ */ + +class CellDist[ + TCoord, + TDoc <: DistDocument[TCoord], + TCell <: GeoCell[TCoord, TDoc] +]( + val cell_grid: CellGrid[TCoord, TDoc, TCell] +) { + val cellprobs: mutable.Map[TCell, Double] = + mutable.Map[TCell, Double]() + + def set_cell_probabilities( + probs: collection.Map[TCell, Double]) { + cellprobs.clear() + cellprobs ++= probs + } + + def get_ranked_cells(include: Iterable[TCell]) = { + val probs = + if (include.size == 0) + cellprobs + else + // Elements on right override those on left + include.map((_, 0.0)).toMap ++ cellprobs.toMap + // sort by second element of tuple, in reverse order + probs.toSeq sortWith (_._2 > _._2) + } +} + +/** + * Distribution over cells that is associated with a word. This class knows + * how to populate its own probabilities, based on the relative probabilities + * of the word in the word distributions of the various cells. That is, + * if we have a set of cells, each with a word distribution, then we can + * imagine conceptually inverting the process to generate a cell distribution + * over words. Basically, for a given word, look to see what its probability + * is in all cells; normalize, and we have a cell distribution. + * + * Instances of this class are normally generated by a factory, specifically + * `CellDistFactory` or a subclass. Currently only used by `SphereWordCellDist` + * and `SphereCellDistFactory`; see them for info on how they are used. + * + * @param word Word for which the cell is computed + * @param cellprobs Hash table listing probabilities associated with cells + */ + +class WordCellDist[TCoord, + TDoc <: DistDocument[TCoord], + TCell <: GeoCell[TCoord, TDoc] +]( + cell_grid: CellGrid[TCoord, TDoc, TCell], + val word: Word +) extends CellDist[TCoord, TDoc, TCell](cell_grid) { + var normalized = false + + protected def init() { + // It's expensive to compute the value for a given word so we cache word + // distributions. + var totalprob = 0.0 + // Compute and store un-normalized probabilities for all cells + for (cell <- cell_grid.iter_nonempty_cells(nonempty_word_dist = true)) { + val word_dist = + UnigramStrategy.check_unigram_dist(cell.combined_dist.word_dist) + val prob = word_dist.lookup_word(word) + // Another way of handling zero probabilities. + /// Zero probabilities are just a bad idea. They lead to all sorts of + /// pathologies when trying to do things like "normalize". + //if (prob == 0.0) + // prob = 1e-50 + cellprobs(cell) = prob + totalprob += prob + } + // Normalize the probabilities; but if all probabilities are 0, then + // we can't normalize, so leave as-is. (FIXME When can this happen? + // It does happen when you use --mode=generate-kml and specify words + // that aren't seen. In other circumstances, the smoothing ought to + // ensure that 0 probabilities don't exist? Anything else I missed?) + if (totalprob != 0) { + normalized = true + for ((cell, prob) <- cellprobs) + cellprobs(cell) /= totalprob + } else + normalized = false + } + + init() +} + +/** + * Factory object for creating CellDists, i.e. objects describing a + * distribution over cells. You can create two types of CellDists, one for + * a single word and one based on a distribution of words. The former + * process returns a WordCellDist, which initializes the probability + * distribution over cells as described for that class. The latter process + * returns a basic CellDist. 
It works by retrieving WordCellDists for + * each of the words in the distribution, and then averaging all of these + * distributions, weighted according to probability of the word in the word + * distribution. + * + * The call to `get_cell_dist` on this class either locates a cached + * distribution or creates a new one, using `create_word_cell_dist`, + * which creates the actual `WordCellDist` class. + * + * @param lru_cache_size Size of the cache used to avoid creating a new + * WordCellDist for a given word when one is already available for that + * word. + */ + +abstract class CellDistFactory[ + TCoord, + TDoc <: DistDocument[TCoord], + TCell <: GeoCell[TCoord, TDoc] +]( + val lru_cache_size: Int +) { + type TCellDist <: WordCellDist[TCoord, TDoc, TCell] + type TGrid <: CellGrid[TCoord, TDoc, TCell] + def create_word_cell_dist(cell_grid: TGrid, word: Word): TCellDist + + var cached_dists: LRUCache[Word, TCellDist] = null + + /** + * Return a cell distribution over a single word, using a least-recently-used + * cache to optimize access. + */ + def get_cell_dist(cell_grid: TGrid, word: Word) = { + if (cached_dists == null) + cached_dists = new LRUCache(maxsize = lru_cache_size) + cached_dists.get(word) match { + case Some(dist) => dist + case None => { + val dist = create_word_cell_dist(cell_grid, word) + cached_dists(word) = dist + dist + } + } + } + + /** + * Return a cell distribution over a distribution over words. This works + * by adding up the distributions of the individual words, weighting by + * the count of the each word. + */ + def get_cell_dist_for_word_dist(cell_grid: TGrid, xword_dist: WordDist) = { + // FIXME!!! Figure out what to do if distribution is not a unigram dist. + // Can we break this up into smaller operations? Or do we have to + // make it an interface for WordDist? + val word_dist = xword_dist.asInstanceOf[UnigramWordDist] + val cellprobs = doublemap[TCell]() + for ((word, count) <- word_dist.model.iter_items) { + val dist = get_cell_dist(cell_grid, word) + for ((cell, prob) <- dist.cellprobs) + cellprobs(cell) += count * prob + } + val totalprob = (cellprobs.values sum) + for ((cell, prob) <- cellprobs) + cellprobs(cell) /= totalprob + val retval = new CellDist[TCoord, TDoc, TCell](cell_grid) + retval.set_cell_probabilities(cellprobs) + retval + } +} + diff --git a/src/main/scala/opennlp/fieldspring/gridlocate/DistDocument.scala b/src/main/scala/opennlp/fieldspring/gridlocate/DistDocument.scala new file mode 100644 index 0000000..ac5bf07 --- /dev/null +++ b/src/main/scala/opennlp/fieldspring/gridlocate/DistDocument.scala @@ -0,0 +1,940 @@ +/////////////////////////////////////////////////////////////////////////////// +// DistDocument.scala +// +// Copyright (C) 2010, 2011, 2012 Ben Wing, The University of Texas at Austin +// Copyright (C) 2011, 2012 Stephen Roller, The University of Texas at Austin +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+/////////////////////////////////////////////////////////////////////////////// + +package opennlp.fieldspring.gridlocate + +import collection.mutable +import util.matching.Regex +import util.control.Breaks._ + +import java.io._ + +import opennlp.fieldspring.util.collectionutil._ +import opennlp.fieldspring.util.textdbutil._ +import opennlp.fieldspring.util.distances._ +import opennlp.fieldspring.util.experiment._ +import opennlp.fieldspring.util.ioutil._ +import opennlp.fieldspring.util.osutil.output_resource_usage +import opennlp.fieldspring.util.printutil.{errprint, warning} +import opennlp.fieldspring.util.Serializer +import opennlp.fieldspring.util.textutil.capfirst + +import opennlp.fieldspring.worddist.{WordDist,WordDistFactory} + +import GridLocateDriver.Debug._ + +///////////////////////////////////////////////////////////////////////////// +// DistDocument tables // +///////////////////////////////////////////////////////////////////////////// + +/** + * A simple class holding properties referring to extra operations that + * may be needed during document loading, depending on the particular + * strategies and/or type of cell grids. All extra operations start + * out set to false. If anyone requests extra, we do it. + */ +class DocumentLoadingProperties { + var need_training_docs_in_memory_during_testing: Boolean = false + var need_two_passes_over_training_docs: Boolean = false + var need_dist_during_first_pass_over_training_docs: Boolean = false + var need_pass_over_eval_docs_during_training: Boolean = false + var need_dist_during_pass_over_eval_docs_during_training: Boolean = false +} + +////////////////////// DistDocument table + +/** + * Class maintaining tables listing all documents and mapping between + * names, ID's and documents. + */ +abstract class DistDocumentTable[ + TCoord : Serializer, + TDoc <: DistDocument[TCoord], + TGrid <: CellGrid[TCoord,TDoc,_] +]( + /* SCALABUG!!! Declaring TDoc <: DistDocument[TCoord] isn't sufficient + for Scala to believe that null is an OK value for TDoc, even though + DistDocument is a reference type and hence TDoc must be reference. + */ + val driver: GridLocateDriver, + val word_dist_factory: WordDistFactory +) { + /** + * Properties indicating whether we need to do more than simply do a + * single pass through training and eval documents. Set by individual + * strategies or cell grid types. + */ + val loading_props = new DocumentLoadingProperties + + /** + * List of documents in each split. + */ + val documents_by_split = bufmap[String, TDoc]() + + // Example of using TaskCounterWrapper directly for non-split values. + // val num_documents = new driver.TaskCounterWrapper("num_documents") + + /** # of records seen in each split. */ + val num_records_by_split = + driver.countermap("num_records_by_split") + /** # of records skipped in each split due to errors */ + val num_error_skipped_records_by_split = + driver.countermap("num_error_skipped_records_by_split") + /** # of records skipped in each split, due to issues other than errors + * (e.g. for Wikipedia documents, not being in the Main namespace). */ + val num_non_error_skipped_records_by_split = + driver.countermap("num_non_error_skipped_records_by_split") + /** # of documents seen in each split. This does not include skipped + * records (see above). */ + val num_documents_by_split = + driver.countermap("num_documents_by_split") + /** # of documents seen in each split skipped because lacking coordinates. 
+ * Note that although most callers skip documents without coordinates, + * there are at least some cases where callers request to include such + * documents. */ + val num_documents_skipped_because_lacking_coordinates_by_split = + driver.countermap("num_documents_skipped_because_lacking_coordinates_by_split") + /** # of documents seen in each split skipped because lacking coordinates, + * but which otherwise would have been recorded. Note that although most + * callers skip documents without coordinates, there are at least some + * cases where callers request to include such documents. In addition, + * some callers do not ask for documents to be recorded (this happens + * particularly with eval-set documents). */ + val num_would_be_recorded_documents_skipped_because_lacking_coordinates_by_split = + driver.countermap("num_would_be_recorded_documents_skipped_because_lacking_coordinates_by_split") + /** # of recorded documents seen in each split (i.e. those added to the + * cell grid). Non-recorded documents are generally those in the eval set. + */ + val num_recorded_documents_by_split = + driver.countermap("num_recorded_documents_by_split") + /** # of documents in each split with coordinates. */ + val num_documents_with_coordinates_by_split = + driver.countermap("num_documents_with_coordinates_by_split") + /** # of recorded documents in each split with coordinates. Non-recorded + * documents are generally those in the eval set. */ + val num_recorded_documents_with_coordinates_by_split = + driver.countermap("num_recorded_documents_with_coordinates_by_split") + /** # of word tokens for documents seen in each split. This does not + * include skipped records (see above). */ + val word_tokens_of_documents_by_split = + driver.countermap("word_tokens_of_documents_by_split") + /** # of word tokens for documents seen in each split skipped because + * lacking coordinates (see above). */ + val word_tokens_of_documents_skipped_because_lacking_coordinates_by_split = + driver.countermap("word_tokens_of_documents_skipped_because_lacking_coordinates_by_split") + /** # of word tokens for documents seen in each split skipped because + * lacking coordinates, but which otherwise would have been recorded + * (see above). Non-recorded documents are generally those in the + * eval set. */ + val word_tokens_of_would_be_recorded_documents_skipped_because_lacking_coordinates_by_split = + driver.countermap("word_tokens_of_would_be_recorded_documents_skipped_because_lacking_coordinates_by_split") + /** # of word tokens for recorded documents seen in each split (i.e. + * those added to the cell grid). Non-recorded documents are generally + * those in the eval set. */ + val word_tokens_of_recorded_documents_by_split = + driver.countermap("word_tokens_of_recorded_documents_by_split") + /** # of word tokens for documents in each split with coordinates. */ + val word_tokens_of_documents_with_coordinates_by_split = + driver.countermap("word_tokens_of_documents_with_coordinates_by_split") + /** # of word tokens for recorded documents in each split with coordinates. + * Non-recorded documents are generally those in the eval set. */ + val word_tokens_of_recorded_documents_with_coordinates_by_split = + driver.countermap("word_tokens_of_recorded_documents_with_coordinates_by_split") + + def create_document(schema: Schema): TDoc + + /** + * Implementation of `create_and_init_document`. Subclasses should + * override this if needed. External callers should call + * `create_and_init_document`, not this. 
Note also that the + * parameter `record_in_table` has a different meaning here -- it only + * refers to recording in subsidiary tables, subclasses, etc. The + * wrapping function `create_and_init_document` takes care of recording + * in the main table. + */ + protected def imp_create_and_init_document(schema: Schema, + fieldvals: Seq[String], record_in_table: Boolean) = { + val doc = create_document(schema) + if (doc != null) + doc.set_fields(fieldvals) + doc + } + + /** + * Create, initialize and return a document with the given fieldvals, + * loaded from a corpus with the given schema. Return value may be + * null, meaning that the given record was skipped (e.g. due to erroneous + * field values or for some other reason -- e.g. Wikipedia records not + * in the Main namespace are skipped). + * + * @param schema Schema of the corpus from which the record was loaded + * @param fieldvals Field values, taken from the record + * @param record_in_table If true, record the document in the table and + * in any subsidiary tables, subclasses, etc. This does not record + * the document in the cell grid; the caller needs to do that if + * needed. + * @param must_have_coord If true, the document must have a coordinate; + * if not, it will be skipped, and null will be returned. + */ + def create_and_init_document(schema: Schema, fieldvals: Seq[String], + record_in_table: Boolean, must_have_coord: Boolean = true) = { + val split = schema.get_field_or_else(fieldvals, "split", "unknown") + if (record_in_table) + num_records_by_split(split) += 1 + val doc = try { + imp_create_and_init_document(schema, fieldvals, record_in_table) + } catch { + case e:Exception => { + num_error_skipped_records_by_split(split) += 1 + throw e + } + } + if (doc == null) { + num_non_error_skipped_records_by_split(split) += 1 + doc + } else { + assert(doc.split == split) + assert(doc.dist != null) + val double_tokens = doc.dist.model.num_tokens + val tokens = double_tokens.toInt + // Partial counts should not occur in training documents. + assert(double_tokens == tokens) + if (record_in_table) { + num_documents_by_split(split) += 1 + word_tokens_of_documents_by_split(split) += tokens + } + if (!doc.has_coord && must_have_coord) { + errprint("Document %s skipped because it has no coordinate", doc) + num_documents_skipped_because_lacking_coordinates_by_split(split) += 1 + word_tokens_of_documents_skipped_because_lacking_coordinates_by_split(split) += tokens + if (record_in_table) { + num_would_be_recorded_documents_skipped_because_lacking_coordinates_by_split(split) += 1 + word_tokens_of_would_be_recorded_documents_skipped_because_lacking_coordinates_by_split(split) += tokens + } + // SCALABUG, same bug with null and a generic type inheriting from + // a reference type + null.asInstanceOf[TDoc] + } + if (doc.has_coord) { + num_documents_with_coordinates_by_split(split) += 1 + word_tokens_of_documents_with_coordinates_by_split(split) += tokens + } + if (record_in_table) { + // documents_by_split(split) += doc + num_recorded_documents_by_split(split) += 1 + word_tokens_of_recorded_documents_by_split(split) += tokens + } + if (doc.has_coord && record_in_table) { + num_recorded_documents_with_coordinates_by_split(split) += 1 + (word_tokens_of_recorded_documents_with_coordinates_by_split(split) + += tokens) + } + doc + } + } + + /** + * A file processor that reads corpora containing document metadata, + * creates a DistDocument for each document described, and adds it to + * this document table. 
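+   *
+   * In essence, each data row is turned into a document and added to the
+   * cell grid, roughly as follows:
+   *
+   * {{{
+   * val doc = create_and_init_document(schema, fieldvals, record_in_table = true)
+   * if (doc != null)
+   *   cell_grid.add_document_to_cell(doc)
+   * }}}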
+   *
+   * @param suffix Suffix specifying the type of document file wanted
+   *   (e.g. "counts" or "document-metadata")
+   * @param cell_grid Cell grid to add newly created DistDocuments to
+   */
+  class DistDocumentTableFileProcessor(
+    suffix: String, cell_grid: TGrid,
+    task: ExperimentMeteredTask
+  ) extends DistDocumentFileProcessor(suffix, driver) {
+    def handle_document(fieldvals: Seq[String]) = {
+      val doc = create_and_init_document(schema, fieldvals, true)
+      if (doc != null) {
+        assert(doc.dist != null)
+        cell_grid.add_document_to_cell(doc)
+        (true, true)
+      }
+      else (false, true)
+    }
+
+    def process_lines(lines: Iterator[String],
+        filehand: FileHandler, file: String,
+        compression: String, realname: String) = {
+      // Stop if we've reached the maximum
+      var should_stop = false
+      breakable {
+        for (line <- lines) {
+          if (!parse_row(line))
+            should_stop = true
+          if (task.item_processed())
+            should_stop = true
+          if ((driver.params.num_training_docs > 0 &&
+              task.num_processed >= driver.params.num_training_docs)) {
+            errprint("")
+            errprint("Stopping because limit of %s documents reached",
+              driver.params.num_training_docs)
+            should_stop = true
+          }
+          val sleep_at = debugval("sleep-at-docs")
+          if (sleep_at != "") {
+            if (task.num_processed == sleep_at.toInt) {
+              errprint("Reached %d documents, sleeping ...", task.num_processed)
+              Thread.sleep(5000)
+            }
+          }
+          if (should_stop)
+            break
+        }
+      }
+      (!should_stop, ())
+    }
+  }
+
+  /**
+   * Read the training documents from the given corpus. Documents listed in
+   * the document file(s) are created, listed in this table,
+   * and added to the cell grid corresponding to the table.
+   *
+   * @param filehand The FileHandler for working with the file.
+   * @param dir Directory containing the corpus.
+   * @param suffix Suffix specifying the type of document file wanted
+   *   (e.g. "counts" or "document-metadata")
+   * @param cell_grid Cell grid into which the documents are added.
+   */
+  def read_training_documents(filehand: FileHandler, dir: String,
+      suffix: String, cell_grid: TGrid) {
+
+    for (pass <- 1 to cell_grid.num_training_passes) {
+      cell_grid.begin_training_pass(pass)
+      val task =
+        new ExperimentMeteredTask(driver, "document", "reading pass " + pass,
+          maxtime = driver.params.max_time_per_stage)
+      val training_distproc =
+        new DistDocumentTableFileProcessor("training-" + suffix, cell_grid, task)
+      training_distproc.read_schema_from_textdb(filehand, dir)
+      training_distproc.process_files(filehand, Seq(dir))
+      task.finish()
+      output_resource_usage()
+    }
+  }
+
+  def clear_training_document_distributions() {
+    for (doc <- documents_by_split("training"))
+      doc.dist = null
+  }
+
+  def finish_document_loading() {
+    // Compute overall distribution values (e.g. back-off statistics).
+    errprint("Finishing global dist...")
+    word_dist_factory.finish_global_distribution()
+
+    // Now compute per-document values dependent on the overall distribution
+    // statistics just computed.
+    errprint("Finishing document dists...")
+    for ((split, table) <- documents_by_split) {
+      for (doc <- table) {
+        if (doc.dist != null)
+          doc.dist.finish_after_global()
+      }
+    }
+
+    // Now output statistics on number of documents seen, etc.
+ errprint("") + errprint("-------------------------------------------------------------------------") + errprint("Document/record/word token statistics:") + + var total_num_records = 0L + var total_num_error_skipped_records = 0L + var total_num_non_error_skipped_records = 0L + var total_num_documents = 0L + var total_num_documents_skipped_because_lacking_coordinates = 0L + var total_num_would_be_recorded_documents_skipped_because_lacking_coordinates = 0L + var total_num_recorded_documents = 0L + var total_num_documents_with_coordinates = 0L + var total_num_recorded_documents_with_coordinates = 0L + var total_word_tokens_of_documents = 0L + var total_word_tokens_of_documents_skipped_because_lacking_coordinates = 0L + var total_word_tokens_of_would_be_recorded_documents_skipped_because_lacking_coordinates = 0L + var total_word_tokens_of_recorded_documents = 0L + var total_word_tokens_of_documents_with_coordinates = 0L + var total_word_tokens_of_recorded_documents_with_coordinates = 0L + for (split <- num_records_by_split.keys) { + errprint("For split '%s':", split) + + val num_records = num_records_by_split(split).value + errprint(" %s records seen", num_records) + total_num_records += num_records + + val num_error_skipped_records = + num_error_skipped_records_by_split(split).value + errprint(" %s records skipped due to error seen", + num_error_skipped_records) + total_num_error_skipped_records += num_error_skipped_records + + val num_non_error_skipped_records = + num_non_error_skipped_records_by_split(split).value + errprint(" %s records skipped due to other than error seen", + num_non_error_skipped_records) + total_num_non_error_skipped_records += num_non_error_skipped_records + + def print_line(documents: String, num_documents: Long, + num_tokens: Long) { + errprint(" %s %s, %s total tokens, %.2f tokens/document", + num_documents, documents, num_tokens, + // Avoid division by zero + num_tokens.toDouble / (num_documents + 1e-100)) + } + + val num_documents = num_documents_by_split(split).value + val word_tokens_of_documents = + word_tokens_of_documents_by_split(split).value + print_line("documents seen", num_documents, word_tokens_of_documents) + total_num_documents += num_documents + total_word_tokens_of_documents += word_tokens_of_documents + + val num_recorded_documents = + num_recorded_documents_by_split(split).value + val word_tokens_of_recorded_documents = + word_tokens_of_recorded_documents_by_split(split).value + print_line("documents recorded", num_recorded_documents, + word_tokens_of_recorded_documents) + total_num_recorded_documents += num_recorded_documents + total_word_tokens_of_recorded_documents += + word_tokens_of_recorded_documents + + val num_documents_skipped_because_lacking_coordinates = + num_documents_skipped_because_lacking_coordinates_by_split(split).value + val word_tokens_of_documents_skipped_because_lacking_coordinates = + word_tokens_of_documents_skipped_because_lacking_coordinates_by_split( + split).value + print_line("documents skipped because lacking coordinates", + num_documents_skipped_because_lacking_coordinates, + word_tokens_of_documents_skipped_because_lacking_coordinates) + total_num_documents_skipped_because_lacking_coordinates += + num_documents_skipped_because_lacking_coordinates + total_word_tokens_of_documents_skipped_because_lacking_coordinates += + word_tokens_of_documents_skipped_because_lacking_coordinates + + val num_would_be_recorded_documents_skipped_because_lacking_coordinates = + 
num_would_be_recorded_documents_skipped_because_lacking_coordinates_by_split( + split).value + val word_tokens_of_would_be_recorded_documents_skipped_because_lacking_coordinates = + word_tokens_of_would_be_recorded_documents_skipped_because_lacking_coordinates_by_split( + split).value + print_line("would-be-recorded documents skipped because lacking coordinates", + num_would_be_recorded_documents_skipped_because_lacking_coordinates, + word_tokens_of_would_be_recorded_documents_skipped_because_lacking_coordinates) + total_num_would_be_recorded_documents_skipped_because_lacking_coordinates += + num_would_be_recorded_documents_skipped_because_lacking_coordinates + total_word_tokens_of_would_be_recorded_documents_skipped_because_lacking_coordinates += + word_tokens_of_would_be_recorded_documents_skipped_because_lacking_coordinates + + val num_documents_with_coordinates = + num_documents_with_coordinates_by_split(split).value + val word_tokens_of_documents_with_coordinates = + word_tokens_of_documents_with_coordinates_by_split(split).value + print_line("documents having coordinates seen", + num_documents_with_coordinates, + word_tokens_of_documents_with_coordinates) + total_num_documents_with_coordinates += num_documents_with_coordinates + total_word_tokens_of_documents_with_coordinates += + word_tokens_of_documents_with_coordinates + + val num_recorded_documents_with_coordinates = + num_recorded_documents_with_coordinates_by_split(split).value + val word_tokens_of_recorded_documents_with_coordinates = + word_tokens_of_recorded_documents_with_coordinates_by_split(split).value + print_line("documents having coordinates recorded", + num_recorded_documents_with_coordinates, + word_tokens_of_recorded_documents_with_coordinates) + total_num_recorded_documents_with_coordinates += + num_recorded_documents_with_coordinates + total_word_tokens_of_recorded_documents_with_coordinates += + word_tokens_of_recorded_documents_with_coordinates + } + + errprint("Total: %s records, %s skipped records (%s from error)", + total_num_records, + (total_num_error_skipped_records + total_num_non_error_skipped_records), + total_num_error_skipped_records) + errprint("Total: %s documents with %s total word tokens", + total_num_documents, total_word_tokens_of_documents) + errprint("Total: %s documents skipped because lacking coordinates,\n with %s total word tokens", + total_num_documents_skipped_because_lacking_coordinates, + total_word_tokens_of_documents_skipped_because_lacking_coordinates) + errprint("Total: %s would-be-recorded documents skipped because lacking coordinates,\n with %s total word tokens", + total_num_would_be_recorded_documents_skipped_because_lacking_coordinates, + total_word_tokens_of_would_be_recorded_documents_skipped_because_lacking_coordinates) + errprint("Total: %s recorded documents with %s total word tokens", + total_num_recorded_documents, total_word_tokens_of_recorded_documents) + errprint("Total: %s documents having coordinates with %s total word tokens", + total_num_documents_with_coordinates, + total_word_tokens_of_documents_with_coordinates) + errprint("Total: %s recorded documents having coordinates with %s total word tokens", + total_num_recorded_documents_with_coordinates, + total_word_tokens_of_recorded_documents_with_coordinates) + } +} + +///////////////////////////////////////////////////////////////////////////// +// DistDocuments // +///////////////////////////////////////////////////////////////////////////// + +/** + * An exception thrown to indicate an error during document 
creation + * (typically due to a bad field value). + */ +case class DocumentValidationException( + message: String, + cause: Option[Throwable] = None +) extends Exception(message) { + if (cause != None) + initCause(cause.get) + + /** + * Alternate constructor. + * + * @param message exception message + */ + def this(msg: String) = this(msg, None) + + /** + * Alternate constructor. + * + * @param message exception message + * @param cause wrapped, or nested, exception + */ + def this(msg: String, cause: Throwable) = this(msg, Some(cause)) +} + +/** + * A document with an associated coordinate placing it in a grid, with a word + * distribution describing the text of the document. For a "coordinate" + * referring to a location on the Earth, documents can come from Wikipedia + * articles, individual tweets, Twitter feeds (all tweets from a user), book + * chapters from travel stories, etc. For a "coordinate" referring to a + * point in time, documents might be biographical entries in an encyclopedia, + * snippets of text surrounding a date from an arbitrary web source, etc. + * + * The fields available depend on the source of the document (e.g. Wikipedia, + * Twitter, etc.). + * + * Defined general fields: + * + * corpus: Corpus of the document (stored as a fixed field in the schema). + * corpus-type: Corpus type of the corpus ('wikipedia', 'twitter', stored + * as a fixed field in the schema. + * title: Title of document. The combination of (corpus, title) needs to + * uniquely identify a document. + * coord: Coordinates of document. + * split: Evaluation split of document ("training", "dev", "test"), usually + * stored as a fixed field in the schema. + */ +abstract class DistDocument[TCoord : Serializer]( + val schema: Schema, + val table: DistDocumentTable[TCoord,_,_] +) { + + import DistDocumentConverters._ + + /** + * Title of the document -- something that uniquely identifies it, + * at least within all documents with a given corpus name. (The combination + * of corpus and title uniquely identifies the document.) Title can be + * taken from an actual title (e.g. in Wikipedia) or ID of some sort + * (e.g. a tweet ID), or simply taken from an assigned serial number. + * The underlying representation can be arbitrary -- `title` is just a + * function to produce a string representation. + * + * FIXME: The corpus name is stored in the schema, but we don't currently + * make any attempt to verify that corpus names are unique or implement + * any operations involving corpus names -- much less verify that titles + * are in fact unique for a given corpus name. There's in general no way + * to look up a document by title -- perhaps this is good since such + * lookup adds a big hash table and entails storing the documents to + * begin with, both of which things we want to avoid when possible. + */ + def title: String + /** + * True if this document has a coordinate associated with it. + */ + def has_coord: Boolean + /** + * Return the coordinate of a document with a coordinate. Results are + * undefined if the document has no coordinate (e.g. it might throw an + * error or return a default value such as `null`). + */ + def coord: TCoord + /** + * Return the evaluation split ("training", "dev" or "test") of the document. + * This was created before corpora were sub-divided by the value of this + * field, and hence it could be a property of the document. It's now a + * fixed value in the schema, but the field remains. 
+ */ + def split = schema.get_fixed_field("split", error_if_missing = true) + /** + * If this document has an incoming-link value associated with it (i.e. + * number of links pointing to it in some sort of link structure), return + * Some(NUM-LINKS); else return None. + * + * FIXME: This is used to establish a prior for the Naive Bayes strategy, + * and is computed on a cell level, which is why we have it here; but it + * seems too specific and tied to Wikipedia. Also, perhaps we want to + * split it into has_incoming_links and incoming_links (without the Option[] + * wrapping). We also need some comment of "corpus type" and a way to + * request whether a given corpus type has incoming links marked on it. + */ + def incoming_links: Option[Int] = None + + /** + * Object containing word distribution of this document. + */ + var dist: WordDist = _ + + /** + * Set the fields of the document to the given values. + * + * @param fieldvals A list of items, of the same length and in the same + * order as the corresponding schema. + * + * Note that we don't include the field values as a constructor parameter + * and set them during construction, because we run into bootstrapping + * problems. In particular, if subclasses declare and initialize + * additional fields, then those fields get initialized *after* the + * constructor runs, in which case the values determined from the field + * values get *overwritten* with their default values. Subtle and bad. + * So instead we make it so that the call to `set_fields` has to happen + * *after* construction. + */ + def set_fields(fieldvals: Seq[String]) { + for ((field, value) <- (schema.fieldnames zip fieldvals)) { + if (debug("rethrow")) + set_field(field, value) + else { + try { set_field(field, value) } + catch { + case e@_ => { + val msg = ("Bad value %s for field '%s': %s" format + (value, field, e.toString)) + if (debug("stack-trace") || debug("stacktrace")) + e.printStackTrace + throw new DocumentValidationException(msg, e) + } + } + } + } + } + + def get_fields(fields: Seq[String]) = { + for (field <- fields; + value = get_field(field); + if value != null) + yield value + } + + def set_field(field: String, value: String) { + field match { + case "counts" => { + // Set the distribution on the document. But don't use the eval + // set's distributions in computing global smoothing values and such, + // to avoid contaminating the results (training on your eval set). + // In addition, if this isn't the training or eval set, we shouldn't + // be loading at all. 
+ val is_training_set = (this.split == "training") + val is_eval_set = (this.split == table.driver.params.eval_set) + assert (is_training_set || is_eval_set) + table.word_dist_factory.constructor.initialize_distribution(this, + value, is_training_set) + dist.finish_before_global() + } + case _ => () // Just eat the other parameters + } + } + + def get_field(field: String) = { + field match { + case "title" => title + case "coord" => if (has_coord) put_x(coord) else null + case _ => null + } + } + + // def __repr__ = "DistDocument(%s)" format toString.encode("utf-8") + + def shortstr = "%s" format title + + override def toString = { + val coordstr = if (has_coord) " at %s".format(coord) else "" + val corpus_name = schema.get_fixed_field("corpus-name") + val corpusstr = if (corpus_name != null) "%s/".format(corpus_name) else "" + "%s%s%s".format(corpusstr, title, coordstr) + } + + def struct: scala.xml.Elem + + def distance_to_coord(coord2: TCoord): Double + + /** + * Output a distance with attached units + */ + def output_distance(dist: Double): String +} + + +///////////////////////////////////////////////////////////////////////////// +// Conversion functions // +///////////////////////////////////////////////////////////////////////////// + +object DistDocumentConverters { + def yesno_to_boolean(foo: String) = { + foo match { + case "yes" => true + case "no" => false + case _ => { + warning("Expected yes or no, saw '%s'", foo) + false + } + } + } + def boolean_to_yesno(foo: Boolean) = if (foo) "yes" else "no" + + def get_int_or_none(foo: String) = + if (foo == "") None else Option[Int](foo.toInt) + def put_int_or_none(foo: Option[Int]) = { + foo match { + case None => "" + case Some(x) => x.toString + } + } + + /** + * Convert an object of type `T` into a serialized (string) form, for + * storage purposes in a text file. Note that the construction + * `T : Serializer` means essentially "T must have a Serializer". + * More technically, it adds an extra implicit parameter list with a + * single parameter of type Serializer[T]. When the compiler sees a + * call to put_x[X] for some type X, it looks in the lexical environment + * to see if there is an object in scope of type Serializer[X] that is + * marked `implicit`, and if so, it gives the implicit parameter + * the value of that object; otherwise, you get a compile error. The + * function can then retrieve the implicit parameter's value using the + * construction `implicitly[Serializer[T]]`. The `T : Serializer` + * construction is technically known as a *context bound*. + */ + def put_x[T : Serializer](foo: T) = + implicitly[Serializer[T]].serialize(foo) + /** + * Convert the serialized form of the value of an object of type `T` + * back into that type. Throw an error if an invalid string was seen. + * See `put_x` for a description of the `Serializer` type and the *context + * bound* (denoted by a colon) that ties it to `T`. + * + * @see put_x + */ + def get_x[T : Serializer](foo: String) = + implicitly[Serializer[T]].deserialize(foo) + + /** + * Convert an object of type `Option[T]` into a serialized (string) form. + * See `put_x` for more information. The only difference between that + * function is that if the value is None, a blank string is written out; + * else, for a value Some(x), where `x` is a value of type `T`, `x` is + * written out using `put_x`. 
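+   *
+   * For example (an illustrative sketch, assuming some implicit
+   * `Serializer[Int]` is in scope):
+   *
+   * {{{
+   * put_x_or_none(Some(3))            // "3"
+   * put_x_or_none(None: Option[Int])  // ""
+   * get_x_or_none[Int]("3")           // Some(3)
+   * get_x_or_none[Int]("")            // None
+   * }}}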
+ * + * @see put_x + */ + def put_x_or_none[T : Serializer](foo: Option[T]) = { + foo match { + case None => "" + case Some(x) => put_x[T](x) + } + } + /** + * Convert a blank string into None, or a valid string that converts into + * type T into Some(value), where value is of type T. Throw an error if + * a non-blank, invalid string was seen. + * + * @see get_x + * @see put_x + */ + def get_x_or_none[T : Serializer](foo: String) = + if (foo == "") None + else Option[T](get_x[T](foo)) + + /** + * Convert an object of type `T` into a serialized (string) form. + * If the object has the value `null`, write out a blank string. + * Note that T must be a reference type (i.e. not a primitive + * type such as Int or Double), so that `null` is a valid value. + * + * @see put_x + * @see put_x + */ + def put_x_or_null[T >: Null : Serializer](foo: T) = { + if (foo == null) "" + else put_x[T](foo) + } + /** + * Convert a blank string into null, or a valid string that converts into + * type T into that value. Throw an error if a non-blank, invalid string + * was seen. Note that T must be a reference type (i.e. not a primitive + * type such as Int or Double), so that `null` is a valid value. + * + * @see get_x + * @see put_x + */ + def get_x_or_null[T >: Null : Serializer](foo: String) = + if (foo == "") null + else get_x[T](foo) +} + +///////////////////////////////////////////////////////////////////////////// +// DistDocument File Processors // +///////////////////////////////////////////////////////////////////////////// + +/** + * A file processor that reads document files from a corpora. + * + * @param suffix Suffix used for selecting the particular corpus from a + * directory + * @param dstats ExperimentDriverStats used for recording counters and such. + * Pass in null to not record counters. + */ +abstract class DistDocumentFileProcessor( + suffix: String, + val dstats: ExperimentDriverStats +) extends BasicTextDBProcessor[Unit](suffix) { + + /******** Counters to track what's going on ********/ + + var shortfile: String = _ + + def get_shortfile = shortfile + + def filename_to_counter_name(filehand: FileHandler, file: String) = { + var (_, base) = filehand.split_filename(file) + breakable { + while (true) { + val newbase = """\.[a-z0-9]*$""".r.replaceAllIn(base, "") + if (newbase == base) break + base = newbase + } + } + """[^a-zA-Z0-9]""".r.replaceAllIn(base, "_") + } + + def get_file_counter_name(counter: String) = + "byfile." + get_shortfile + "." + counter + + def increment_counter(counter: String, value: Long = 1) { + if (dstats != null) { + val file_counter = get_file_counter_name(counter) + dstats.increment_task_counter(file_counter, value) + dstats.increment_task_counter(counter, value) + dstats.increment_local_counter(file_counter, value) + dstats.increment_local_counter(counter, value) + } + } + + def increment_document_counter(counter: String) { + increment_counter(counter) + increment_counter("documents.total") + } + + /******** Main code ********/ + + /** + * Handle (e.g. create and record) a document. + * + * @param fieldvals Field values of the document as read from the + * document file. + * @return Tuple `(processed, keep_going)` where `processed` indicates + * whether the document was processed to completion (rather than skipped) + * and `keep_going` indicates whether processing of further documents + * should continue or stop. Note that the value of `processed` does + * *NOT* indicate whether there is an error in the field values. 
In + * that case, the error should be caught and rethrown as a + * DocumentValidationException, listing the field and value as well + * as the error. + */ + def handle_document(fieldvals: Seq[String]): (Boolean, Boolean) + + override def handle_bad_row(line: String, fieldvals: Seq[String]) { + increment_document_counter("documents.bad") + super.handle_bad_row(line, fieldvals) + } + + def process_row(fieldvals: Seq[String]): (Boolean, Boolean) = { + val (processed, keep_going) = + try { handle_document(fieldvals) } + catch { + case e:DocumentValidationException => { + warning("Line %s: %s", num_processed + 1, e.message) + return (false, true) + } + } + if (processed) + increment_document_counter("documents.processed") + else { + errprint("Skipped document %s", + schema.get_field_or_else(fieldvals, "title", "unknown title??")) + increment_document_counter("documents.skipped") + } + return (true, keep_going) + } + + override def begin_process_file(filehand: FileHandler, file: String) { + shortfile = filename_to_counter_name(filehand, file) + super.begin_process_file(filehand, file) + } + + override def end_process_file(filehand: FileHandler, file: String) { + def note(counter: String, english: String) { + if (dstats != null) { + val file_counter = get_file_counter_name(counter) + val value = dstats.get_task_counter(file_counter) + errprint("Number of %s for file %s: %s", english, file, + value) + } + } + + if (debug("per-document")) { + note("documents.processed", "documents processed") + note("documents.skipped", "documents skipped") + note("documents.bad", "bad documents") + note("documents.total", "total documents") + } + super.end_process_file(filehand, file) + } +} + +/** + * A writer class for writing DistDocuments out to a corpus. + * + * @param schema schema describing the fields in the document files + * @param suffix suffix used for identifying the particular corpus in a + * directory + */ +class DistDocumentWriter[TCoord : Serializer]( + schema: Schema, + suffix: String +) extends TextDBWriter(schema, suffix) { + def output_document(outstream: PrintStream, doc: DistDocument[TCoord]) { + schema.output_row(outstream, doc.get_fields(schema.fieldnames)) + } +} diff --git a/src/main/scala/opennlp/fieldspring/gridlocate/Evaluation.scala b/src/main/scala/opennlp/fieldspring/gridlocate/Evaluation.scala new file mode 100644 index 0000000..d2869db --- /dev/null +++ b/src/main/scala/opennlp/fieldspring/gridlocate/Evaluation.scala @@ -0,0 +1,1158 @@ +/////////////////////////////////////////////////////////////////////////////// +// Evaluation.scala +// +// Copyright (C) 2010, 2011, 2012 Ben Wing, The University of Texas at Austin +// Copyright (C) 2011 Stephen Roller, The University of Texas at Austin +// Copyright (C) 2012 Mike Speriosu, The University of Texas at Austin +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+/////////////////////////////////////////////////////////////////////////////// + +package opennlp.fieldspring.gridlocate + +import util.control.Breaks._ +import collection.mutable + +import opennlp.fieldspring.util.collectionutil._ +import opennlp.fieldspring.util.experiment.ExperimentDriverStats +import opennlp.fieldspring.util.mathutil._ +import opennlp.fieldspring.util.ioutil.{FileHandler, FileProcessor} +import opennlp.fieldspring.util.MeteredTask +import opennlp.fieldspring.util.osutil.{curtimehuman, output_resource_usage} +import opennlp.fieldspring.util.printutil.{errprint, warning} + +import GridLocateDriver.Debug._ + +///////////////////////////////////////////////////////////////////////////// +// General statistics on evaluation results // +///////////////////////////////////////////////////////////////////////////// + +// incorrect_reasons is a map from ID's for reasons to strings describing +// them. +class EvalStats( + driver_stats: ExperimentDriverStats, + prefix: String, + incorrect_reasons: Map[String, String] +) { + def construct_counter_name(name: String) = { + if (prefix == "") name + else prefix + "." + name + } + + def increment_counter(name: String) { + driver_stats.increment_local_counter(construct_counter_name(name)) + } + + def get_counter(name: String) = { + driver_stats.get_local_counter(construct_counter_name(name)) + } + + def list_counters(group: String, recursive: Boolean, + fully_qualified: Boolean = true) = + driver_stats.list_counters(construct_counter_name(group), recursive, + fully_qualified) + + def record_result(correct: Boolean, reason: String = null) { + if (reason != null) + assert(incorrect_reasons.keySet contains reason) + increment_counter("instances.total") + if (correct) + increment_counter("instances.correct") + else { + increment_counter("instances.incorrect") + if (reason != null) + increment_counter("instances.incorrect." + reason) + } + } + + def total_instances = get_counter("instances.total") + def correct_instances = get_counter("instances.correct") + def incorrect_instances = get_counter("instances.incorrect") + + def output_fraction(header: String, amount: Long, total: Long) { + if (amount > total) { + warning("Something wrong: Fractional quantity %s greater than total %s", + amount, total) + } + var percent = + if (total == 0) "indeterminate percent" + else "%5.2f%%" format (100 * amount.toDouble / total) + errprint("%s = %s/%s = %s", header, amount, total, percent) + } + + def output_correct_results() { + output_fraction("Percent correct", correct_instances, total_instances) + } + + def output_incorrect_results() { + output_fraction("Percent incorrect", incorrect_instances, total_instances) + for ((reason, descr) <- incorrect_reasons) { + output_fraction(" %s" format descr, + get_counter("instances.incorrect." 
+ reason), total_instances) + } + } + + def output_other_stats() { + for (ty <- driver_stats.list_local_counters("", true)) { + val count = driver_stats.get_local_counter(ty) + errprint("%s = %s", ty, count) + } + } + + def output_results() { + if (total_instances == 0) { + warning("Strange, no instances found at all; perhaps --eval-format is incorrect?") + return + } + errprint("Number of instances = %s", total_instances) + output_correct_results() + output_incorrect_results() + output_other_stats() + } +} + +class EvalStatsWithRank( + driver_stats: ExperimentDriverStats, + prefix: String, + max_rank_for_credit: Int = 10 +) extends EvalStats(driver_stats, prefix, Map[String, String]()) { + val incorrect_by_exact_rank = intmap[Int]() + val correct_by_up_to_rank = intmap[Int]() + var incorrect_past_max_rank = 0 + var total_credit = 0 + + def record_result(rank: Int) { + assert(rank >= 1) + val correct = rank == 1 + super.record_result(correct, reason = null) + if (rank <= max_rank_for_credit) { + total_credit += max_rank_for_credit + 1 - rank + incorrect_by_exact_rank(rank) += 1 + for (i <- rank to max_rank_for_credit) + correct_by_up_to_rank(i) += 1 + } else + incorrect_past_max_rank += 1 + } + + override def output_correct_results() { + super.output_correct_results() + val possible_credit = max_rank_for_credit * total_instances + output_fraction("Percent correct with partial credit", + total_credit, possible_credit) + for (i <- 2 to max_rank_for_credit) { + output_fraction(" Correct is at or above rank %s" format i, + correct_by_up_to_rank(i), total_instances) + } + } + + override def output_incorrect_results() { + super.output_incorrect_results() + for (i <- 2 to max_rank_for_credit) { + output_fraction(" Incorrect, with correct at rank %s" format i, + incorrect_by_exact_rank(i), + total_instances) + } + output_fraction(" Incorrect, with correct not in top %s" format + max_rank_for_credit, + incorrect_past_max_rank, total_instances) + } +} + +//////// Statistics for locating documents + +/** + * General class for the result of evaluating a document. Specifies a + * document, cell grid, and the predicted coordinate for the document. + * The reason that a cell grid needs to be given is that we need to + * retrieve the cell that the document belongs to in order to get the + * "central point" (center or centroid of the cell), and in general we + * may be operating with multiple cell grids (e.g. in the combination of + * uniform and k-D tree grids). (FIXME: I don't know if this is actually + * true.) + * + * FIXME: Perhaps we should redo the results in terms of pseudo-documents + * instead of cells. 
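+ *
+ * A minimal sketch of what gets computed (the document, grid and coordinate
+ * values here are hypothetical):
+ *
+ * {{{
+ * val res = new DocumentEvaluationResult(doc, cell_grid, pred_coord)
+ * res.true_truedist  // distance from doc's coordinate to its true cell's center
+ * res.pred_truedist  // distance from doc's coordinate to the predicted coordinate
+ * }}}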
+ * + * @tparam TCoord type of a coordinate + * @tparam TDoc type of a document + * @tparam TCell type of a cell + * @tparam TGrid type of a cell grid + * + * @param document document whose coordinate is predicted + * @param cell_grid cell grid against which error comparison should be done + * @param pred_coord predicted coordinate of the document + */ +class DocumentEvaluationResult[ + TCoord, + TDoc <: DistDocument[TCoord], + TCell <: GeoCell[TCoord, TDoc], + TGrid <: CellGrid[TCoord, TDoc, TCell] +]( + val document: TDoc, + val cell_grid: TGrid, + val pred_coord: TCoord +) { + /** + * True cell in the cell grid in which the document belongs + */ + val true_cell = cell_grid.find_best_cell_for_document(document, true) + /** + * Number of documents in the true cell + */ + val num_docs_in_true_cell = true_cell.combined_dist.num_docs_for_word_dist + /** + * Central point of the true cell + */ + val true_center = true_cell.get_center_coord() + /** + * "True distance" (rather than e.g. degree distance) between document's + * coordinate and central point of true cell + */ + val true_truedist = document.distance_to_coord(true_center) + /** + * "True distance" (rather than e.g. degree distance) between document's + * coordinate and predicted coordinate + */ + val pred_truedist = document.distance_to_coord(pred_coord) + + def record_result(stats: DocumentEvalStats) { + stats.record_predicted_distance(pred_truedist) + } +} + +/** + * Subclass of `DocumentEvaluationResult` where the predicted coordinate + * is a point, not necessarily the central point of one of the grid cells. + * + * @tparam TCoord type of a coordinate + * @tparam TDoc type of a document + * @tparam TCell type of a cell + * @tparam TGrid type of a cell grid + * + * @param document document whose coordinate is predicted + * @param cell_grid cell grid against which error comparison should be done + * @param pred_coord predicted coordinate of the document + */ +class CoordDocumentEvaluationResult[ + TCoord, + TDoc <: DistDocument[TCoord], + TCell <: GeoCell[TCoord, TDoc], + TGrid <: CellGrid[TCoord, TDoc, TCell] +]( + document: TDoc, + cell_grid: TGrid, + pred_coord: TCoord +) extends DocumentEvaluationResult[TCoord, TDoc, TCell, TGrid]( + document, cell_grid, pred_coord +) { + override def record_result(stats: DocumentEvalStats) { + super.record_result(stats) + // It doesn't really make sense to record a result as "correct" or + // "incorrect" but we need to record something; just do "false" + // FIXME: Fix the incorrect assumption here that "correct" or + // "incorrect" always exists. + stats.asInstanceOf[CoordDocumentEvalStats].record_result(false) + } +} + +/** + * Subclass of `DocumentEvaluationResult` where the predicted coordinate + * is specifically the central point of one of the grid cells. 
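+ *
+ * A sketch of how such a result might be constructed from a ranked list of
+ * cells, best first (the `ranked_cells` and `true_cell` values here are
+ * hypothetical):
+ *
+ * {{{
+ * val true_rank = 1 + ranked_cells.indexWhere(_ == true_cell)
+ * val res = new RankedDocumentEvaluationResult(doc, ranked_cells.head, true_rank)
+ * }}}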
+ * + * @tparam TCoord type of a coordinate + * @tparam TDoc type of a document + * @tparam TCell type of a cell + * @tparam TGrid type of a cell grid + * + * @param document document whose coordinate is predicted + * @param pred_cell top-ranked predicted cell in which the document should + * belong + * @param true_rank rank of the document's true cell among all of the + * predicted cell + */ +class RankedDocumentEvaluationResult[ + TCoord, + TDoc <: DistDocument[TCoord], + TCell <: GeoCell[TCoord, TDoc], + TGrid <: CellGrid[TCoord, TDoc, TCell] +]( + document: TDoc, + val pred_cell: TCell, + val true_rank: Int +) extends DocumentEvaluationResult[TCoord, TDoc, TCell, TGrid]( + document, pred_cell.cell_grid.asInstanceOf[TGrid], + pred_cell.get_center_coord() +) { + override def record_result(stats: DocumentEvalStats) { + super.record_result(stats) + stats.asInstanceOf[RankedDocumentEvalStats].record_true_rank(true_rank) + } +} + +/** + * A basic class for accumulating statistics from multiple evaluation + * results. + */ +trait DocumentEvalStats extends EvalStats { + // "True dist" means actual distance in km's or whatever. + val true_dists = mutable.Buffer[Double]() + val oracle_true_dists = mutable.Buffer[Double]() + + def record_predicted_distance(pred_true_dist: Double) { + true_dists += pred_true_dist + } + + def record_oracle_distance(oracle_true_dist: Double) { + oracle_true_dists += oracle_true_dist + } + + protected def output_result_with_units(result: Double): String + + override def output_incorrect_results() { + super.output_incorrect_results() + errprint(" Mean true error distance = %s", + output_result_with_units(mean(true_dists))) + errprint(" Median true error distance = %s", + output_result_with_units(median(true_dists))) + errprint(" Mean oracle true error distance = %s", + output_result_with_units(mean(oracle_true_dists))) + } +} + +/** + * A class for accumulating statistics from multiple evaluation results, + * where the results directly specify a coordinate (rather than e.g. a cell). + */ +abstract class CoordDocumentEvalStats( + driver_stats: ExperimentDriverStats, + prefix: String +) extends EvalStats(driver_stats, prefix, Map[String, String]()) + with DocumentEvalStats { +} + +/** + * A class for accumulating statistics from multiple evaluation results, + * including statistics on the rank of the true cell. + */ +abstract class RankedDocumentEvalStats( + driver_stats: ExperimentDriverStats, + prefix: String, + max_rank_for_credit: Int = 10 +) extends EvalStatsWithRank(driver_stats, prefix, max_rank_for_credit) + with DocumentEvalStats { + def record_true_rank(rank: Int) { + record_result(rank) + } +} + +/** + * Class for accumulating statistics from multiple document evaluation results, + * with separate sets of statistics for different intervals of error distances + * and number of documents in true cell. ("Grouped" in the sense that we may be + * computing not only results for the documents as a whole but also for various + * subgroups.) + * + * @tparam TCoord type of a coordinate + * @tparam TDoc type of a document + * @tparam TCell type of a cell + * @tparam TGrid type of a cell grid + * @tparam TEvalRes type of object holding result of evaluating a document + * + * @param driver_stats Object (possibly a trait) through which global-level + * program statistics can be accumulated (in a Hadoop context, this maps + * to counters). + * @param cell_grid Cell grid against which results were derived. 
+ * @param results_by_range If true, record more detailed range-by-range + * subresults. Not on by default because Hadoop may choke on the large + * number of counters created this way. + */ +abstract class GroupedDocumentEvalStats[ + TCoord, + TDoc <: DistDocument[TCoord], + TCell <: GeoCell[TCoord, TDoc], + TGrid <: CellGrid[TCoord, TDoc, TCell], + TEvalRes <: DocumentEvaluationResult[TCoord, TDoc, TCell, TGrid] +]( + driver_stats: ExperimentDriverStats, + cell_grid: TGrid, + results_by_range: Boolean +) { + def create_stats(prefix: String): DocumentEvalStats + def create_stats_for_range[T](prefix: String, range: T) = + create_stats(prefix + ".byrange." + range) + + val all_document = create_stats("") + + // naitr = "num documents in true cell" + val docs_by_naitr = new IntTableByRange(Seq(1, 10, 25, 100), + create_stats_for_range("num_documents_in_true_cell", _)) + + // Results for documents where the location is at a certain distance + // from the center of the true statistical cell. The key is measured in + // fractions of a tiling cell (determined by 'dist_fraction_increment', + // e.g. if dist_fraction_increment = 0.25 then values in the range of + // [0.25, 0.5) go in one bin, [0.5, 0.75) go in another, etc.). We measure + // distance is two ways: true distance (in km or whatever) and "degree + // distance", as if degrees were a constant length both latitudinally + // and longitudinally. + val dist_fraction_increment = 0.25 + def docmap(prefix: String) = + new SettingDefaultHashMap[Double, DocumentEvalStats]( + create_stats_for_range(prefix, _)) + val docs_by_true_dist_to_true_center = + docmap("true_dist_to_true_center") + + // Similar, but distance between location and center of top predicted + // cell. + val dist_fractions_for_error_dist = Seq( + 0.25, 0.5, 0.75, 1, 1.5, 2, 3, 4, 6, 8, + 12, 16, 24, 32, 48, 64, 96, 128, 192, 256, + // We're never going to see these + 384, 512, 768, 1024, 1536, 2048) + val docs_by_true_dist_to_pred_center = + new DoubleTableByRange(dist_fractions_for_error_dist, + create_stats_for_range("true_dist_to_pred_center", _)) + + def record_one_result(stats: DocumentEvalStats, res: TEvalRes) { + res.record_result(stats) + } + + def record_one_oracle_result(stats: DocumentEvalStats, res: TEvalRes) { + stats.record_oracle_distance(res.true_truedist) + } + + def record_result(res: TEvalRes) { + record_one_result(all_document, res) + record_one_oracle_result(all_document, res) + // Stephen says recording so many counters leads to crashes (at the 51st + // counter or something), so don't do it unless called for. + if (results_by_range) + record_result_by_range(res) + } + + def record_result_by_range(res: TEvalRes) { + val naitr = docs_by_naitr.get_collector(res.num_docs_in_true_cell) + record_one_result(naitr, res) + } + + def increment_counter(name: String) { + all_document.increment_counter(name) + } + + def output_results(all_results: Boolean = false) { + errprint("") + errprint("Results for all documents:") + all_document.output_results() + /* FIXME: This code specific to MultiRegularCellGrid is kind of ugly. + Perhaps it should go elsewhere. + + FIXME: Also note that we don't actually do anything here, because of + the 'if (false)'. See above. 
+ */ + //if (all_results) + if (false) + output_results_by_range() + // FIXME: Output median and mean of true and degree error dists; also + // maybe move this info info EvalByRank so that we can output the values + // for each category + errprint("") + output_resource_usage() + } + + def output_results_by_range() { + errprint("") + for ((lower, upper, obj) <- docs_by_naitr.iter_ranges()) { + errprint("") + errprint("Results for documents where number of documents") + errprint(" in true cell is in the range [%s,%s]:", + lower, upper - 1) + obj.output_results() + } + } +} + +///////////////////////////////////////////////////////////////////////////// +// Main evaluation code // +///////////////////////////////////////////////////////////////////////////// + +/** + * Abstract class for evaluating a corpus of test documents. + * Uses the command-line parameters to determine which documents + * should be skipped. + * + * @tparam TEvalDoc Type of document to evaluate. + * @tparam TEvalRes Type of result of evaluating a document. + * + * @param stratname Name of the strategy used for performing evaluation. + * This is output in various status messages. + * @param driver Driver class that encapsulates command-line parameters and + * such, in particular command-line parameters that allow a subset of the + * total set of documents to be evaluated. + */ +abstract class CorpusEvaluator[TEvalDoc, TEvalRes]( + stratname: String, + val driver: GridLocateDriver +) { + var documents_processed = 0 + val results = mutable.Map[TEvalDoc, TEvalRes]() + var skip_initial = driver.params.skip_initial_test_docs + var skip_n = 0 + + /** + * Return true if we should skip the next document due to parameters + * calling for certain documents in a certain sequence to be skipped. + */ + def would_skip_by_parameters() = { + var do_skip = false + if (skip_initial != 0) { + skip_initial -= 1 + do_skip = true + } else if (skip_n != 0) { + skip_n -= 1 + do_skip = true + } else + skip_n = driver.params.every_nth_test_doc - 1 + do_skip + } + + /** + * Return true if we should stop processing, given that `new_processed` + * items have already been processed. + */ + def would_stop_processing(new_processed: Int) = { + // If max # of docs reached, stop + val stop = (driver.params.num_test_docs > 0 && + new_processed >= driver.params.num_test_docs) + if (stop) { + errprint("") + errprint("Stopping because limit of %s documents reached", + driver.params.num_test_docs) + } + stop + } + + /** + * Return true if document would be skipped; false if processed and + * evaluated. + */ + def would_skip_document(doc: TEvalDoc, doctag: String) = false + + /** + * Evaluate a document. Return an object describing the results of the + * evaluation. + * + * @param document Document to evaluate. + * @param doctag A short string identifying the document (e.g. '#25'), + * to be printed out at the beginning of diagnostic lines describing + * the document and its evaluation results. + */ + def evaluate_document(doc: TEvalDoc, doctag: String): + TEvalRes + + /** + * Output results so far. If 'isfinal', this is the last call, so + * output more results. + */ + def output_results(isfinal: Boolean = false): Unit + + val task = new MeteredTask("document", "evaluating", + maxtime = driver.params.max_time_per_stage) + var last_elapsed = 0.0 + var last_processed = 0 + + /** Process a document. This checks to see whether we should evaluate + * the document (e.g. 
based on parameters indicating which documents
+ * to evaluate), and evaluates as necessary, storing the results into
+ * `results`.
+ *
+ * @param doc Document to be processed.
+ * @return Tuple `(processed, keep_going)` where `processed` indicates
+ * whether the document was processed or skipped, and `keep_going`
+ * indicates whether processing of further documents should continue or
+ * stop.
+ */
+ def process_document(doc: TEvalDoc): (Boolean, Boolean) = {
+ // errprint("Processing document: %s", doc)
+ val num_processed = task.num_processed
+ val doctag = "#%d" format (1 + num_processed)
+ if (would_skip_document(doc, doctag)) {
+ errprint("Skipped document %s", doc)
+ (false, true)
+ } else {
+ val do_skip = would_skip_by_parameters()
+ if (do_skip)
+ errprint("Passed over document %s", doctag)
+ else {
+ // Don't put side-effecting code inside of an assert!
+ val result = evaluate_document(doc, doctag)
+ assert(result != null)
+ results(doc) = result
+ }
+
+ if (task.item_processed())
+ (!do_skip, false)
+ else {
+ val new_elapsed = task.elapsed_time
+ val new_processed = task.num_processed
+
+ if (would_stop_processing(new_processed)) {
+ task.finish()
+ (!do_skip, false)
+ } else {
+ // If five minutes and ten documents have gone by, print out results
+ if ((new_elapsed - last_elapsed >= 300 &&
+ new_processed - last_processed >= 10)) {
+ errprint("Results after %d documents (strategy %s):",
+ task.num_processed, stratname)
+ output_results(isfinal = false)
+ errprint("End of results after %d documents (strategy %s):",
+ task.num_processed, stratname)
+ last_elapsed = new_elapsed
+ last_processed = new_processed
+ }
+ (!do_skip, true)
+ }
+ }
+ }
+ }
+
+ def finish() {
+ task.finish()
+
+ errprint("")
+ errprint("Final results for strategy %s: All %d documents processed:",
+ stratname, task.num_processed)
+ errprint("Ending operation at %s", curtimehuman())
+ output_results(isfinal = true)
+ errprint("Ending final results for strategy %s", stratname)
+ }
+
+ /** Process a set of files, extracting the documents in each one and
+ * evaluating them using `process_document`.
+ */
+ def process_files(filehand: FileHandler, files: Iterable[String]): Boolean
+}
+
+/**
+ * Abstract class for evaluating a test document by comparing it against each
+ * of the cells in a cell grid, where each cell has an associated
+ * pseudo-document created by amalgamating all of the training documents
+ * in the cell.
+ *
+ * More specifically, a collection of documents has been divided into
+ * "training" and "test" sets; the training set is used to construct a cell
+ * grid in which the training documents in a particular cell are amalgamated
+ * to form a pseudo-document, and evaluation of a test document proceeds by
+ * comparing it against each such pseudo-document in turn.
+ *
+ * This is the highest-level evaluation class that includes the concept of a
+ * coordinate that is associated with training and test documents, so that
+ * computation of error distances is possible.
+ *
+ * @tparam TCoord Type of the coordinate assigned to a document
+ * @tparam XTDoc Type of the training and test documents
+ * @tparam XTCell Type of a cell in a cell grid
+ * @tparam XTGrid Type of a cell grid
+ * @tparam TEvalRes Type of result of evaluating a document.
+ *
+ * @param strategy Object encapsulating the strategy used for performing
+ * evaluation.
+ * @param stratname Name of the strategy used for performing evaluation. 
+ * @param driver Driver class that encapsulates command-line parameters and + * such. + * + * Note that we are forced to use the strange names `XTDoc` and `XTGrid` + * because of an apparent Scala bug that prevents use of the more obvious + * names `TDoc` and `TGrid` due to a naming clash. Possibly there is a + * solution to this problem but if so I can't figure it out. + */ +abstract class CellGridEvaluator[ + TCoord, + XTDoc <: DistDocument[TCoord], + XTCell <: GeoCell[TCoord, XTDoc], + XTGrid <: CellGrid[TCoord, XTDoc, XTCell], + TEvalRes <: DocumentEvaluationResult[TCoord, XTDoc, XTCell, XTGrid] +]( + val strategy: GridLocateDocumentStrategy[XTCell, XTGrid], + val stratname: String, + override val driver: GridLocateDocumentDriver { + type TDoc = XTDoc; type TCell = XTCell; type TGrid = XTGrid + } +) extends CorpusEvaluator[XTDoc, TEvalRes](stratname, driver) { + def create_grouped_eval_stats( + driver: GridLocateDocumentDriver, + cell_grid: XTGrid, + results_by_range: Boolean + ): GroupedDocumentEvalStats[TCoord, XTDoc, XTCell, XTGrid, TEvalRes] + + val ranker = driver.create_ranker(strategy) + + val evalstats = create_grouped_eval_stats(driver, + strategy.cell_grid, results_by_range = driver.params.results_by_range) + + def output_results(isfinal: Boolean = false) { + evalstats.output_results(all_results = isfinal) + } + + /** + * A file processor that reads corpora containing document metadata and + * creates a DistDocument for each document described, and evaluates it. + * + * @param suffix Suffix specifying the type of document file wanted + * (e.g. "counts" or "document-metadata" + * @param cell_grid Cell grid to add newly created DistDocuments to + */ + class EvaluateCorpusFileProcessor( + suffix: String + ) extends DistDocumentFileProcessor(suffix, driver) { + def handle_document(fieldvals: Seq[String]) = { + val doc = driver.document_table.create_and_init_document( + schema, fieldvals, false) + if (doc == null) (false, true) + else { + doc.dist.finish_after_global() + process_document(doc) + } + } + + def process_lines(lines: Iterator[String], + filehand: FileHandler, file: String, + compression: String, realname: String) = { + var should_stop = false + breakable { + for (line <- lines) { + if (!parse_row(line)) { + should_stop = true + break + } + } + } + output_resource_usage() + (!should_stop, ()) + } + } + + def process_files(filehand: FileHandler, files: Iterable[String]): + Boolean = { + /* NOTE: `files` must actually be a list of directories, e.g. as + comes from the value of --input-corpus. 
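+       For example, a hypothetical invocation (one directory per value
+       given to --input-corpus; the path below is purely illustrative):
+         process_files(filehand, Seq("path/to/corpus-dir"))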
*/ + for (dir <- files) { + val fileproc = new EvaluateCorpusFileProcessor( + driver.params.eval_set + "-" + driver.document_file_suffix) + fileproc.read_schema_from_textdb(filehand, dir) + val (continue, _) = fileproc.process_files(filehand, Seq(dir)) + if (!continue) + return false + } + return true + } + + //title = None + //words = [] + //for line in openr(filename, errors="replace"): + // if (rematch("Article title: (.*)$", line)) + // if (title != null) + // yield (title, words) + // title = m_[1] + // words = [] + // else if (rematch("Link: (.*)$", line)) + // args = m_[1].split('|') + // truedoc = args[0] + // linkword = truedoc + // if (len(args) > 1) + // linkword = args[1] + // words.append(linkword) + // else: + // words.append(line) + //if (title != null) + // yield (title, words) + + override def would_skip_document(document: XTDoc, doctag: String) = { + if (document.dist == null) { + // This can (and does) happen when --max-time-per-stage is set, + // so that the counts for many documents don't get read in. + if (driver.params.max_time_per_stage == 0.0 && driver.params.num_training_docs == 0) + warning("Can't evaluate document %s without distribution", document) + true + } else false + } + + /** + * Compare the document to the pseudo-documents associated with each cell, + * using the strategy for this evaluator. Return a tuple + * (pred_cells, true_rank), where: + * + * pred_cells = List of predicted cells, from best to worst; each list + * entry is actually a tuple of (cell, score) where higher scores + * are better + * true_rank = Rank of true cell among predicted cells + * + * @param document Document to evaluate. + * @param true_cell Cell in the cell grid which contains the document. + */ + def return_ranked_cells(document: XTDoc, true_cell: XTCell) = { + if (driver.params.oracle_results) + (Iterable((true_cell, 0.0)), 1) + else { + def get_computed_results() = { + val cells = ranker.evaluate(document, include = Iterable[XTCell]()) + var rank = 1 + var broken = false + breakable { + for ((cell, value) <- cells) { + if (cell eq true_cell) { + broken = true + break + } + rank += 1 + } + } + if (!broken) + rank = 1000000000 + (cells, rank) + } + + get_computed_results() + } + } + + /** + * Actual implementation of code to evaluate a document. Optionally + * Return an object describing the results of the evaluation, and + * optionally print out information on these results. + * + * @param document Document to evaluate. + * @param doctag A short string identifying the document (e.g. '#25'), + * to be printed out at the beginning of diagnostic lines describing + * the document and its evaluation results. + * @param true_cell Cell in the cell grid which contains the document. + * @param want_indiv_results Whether we should print out individual + * evaluation results for the document. + */ + def imp_evaluate_document(document: XTDoc, doctag: String, + true_cell: XTCell, want_indiv_results: Boolean): TEvalRes + + /** + * Evaluate a document, record statistics about it, etc. Calls + * `imp_evaluate_document` to do the document evaluation and optionally + * print out information on the results, and records the results in + * `evalstat`. + * + * Return an object describing the results of the evaluation. + * + * @param document Document to evaluate. + * @param doctag A short string identifying the document (e.g. '#25'), + * to be printed out at the beginning of diagnostic lines describing + * the document and its evaluation results. 
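+ *
+ * A rough sketch of the flow this method implements (simplified from the
+ * implementation below; `doc` and the "#1" doctag are illustrative):
+ *
+ * {{{
+ * val true_cell =
+ *   strategy.cell_grid.find_best_cell_for_document(doc, true)
+ * val result = imp_evaluate_document(doc, "#1", true_cell,
+ *   want_indiv_results = true)
+ * evalstats.record_result(result)
+ * }}}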
+ */ + def evaluate_document(document: XTDoc, doctag: String): TEvalRes = { + assert(!would_skip_document(document, doctag)) + assert(document.dist.finished) + val true_cell = + strategy.cell_grid.find_best_cell_for_document(document, true) + if (debug("lots") || debug("commontop")) { + val naitr = true_cell.combined_dist.num_docs_for_word_dist + errprint("Evaluating document %s with %s word-dist documents in true cell", + document, naitr) + } + val want_indiv_results = + !driver.params.oracle_results && !driver.params.no_individual_results + val result = imp_evaluate_document(document, doctag, true_cell, + want_indiv_results) + evalstats.record_result(result) + if (result.num_docs_in_true_cell == 0) { + evalstats.increment_counter("documents.no_training_documents_in_cell") + } + result + } +} + +/** + * An implementation of `CellGridEvaluator` that compares the test + * document against each pseudo-document in the cell grid, ranks them by + * score and computes the document's location by the central point of the + * top-ranked cell. + * + * @tparam TCoord Type of the coordinate assigned to a document + * @tparam TDoc Type of the training and test documents + * @tparam TCell Type of a cell in a cell grid + * @tparam TGrid Type of a cell grid + * @tparam TEvalRes Type of result of evaluating a document. + * + * @param strategy Object encapsulating the strategy used for performing + * evaluation. + * @param stratname Name of the strategy used for performing evaluation. + * @param driver Driver class that encapsulates command-line parameters and + * such. + */ +abstract class RankedCellGridEvaluator[ + TCoord, + XTDoc <: DistDocument[TCoord], + XTCell <: GeoCell[TCoord, XTDoc], + XTGrid <: CellGrid[TCoord, XTDoc, XTCell], + TEvalRes <: DocumentEvaluationResult[TCoord, XTDoc, XTCell, XTGrid] +]( + strategy: GridLocateDocumentStrategy[XTCell, XTGrid], + stratname: String, + driver: GridLocateDocumentDriver { + type TDoc = XTDoc; type TCell = XTCell; type TGrid = XTGrid + } +) extends CellGridEvaluator[ + TCoord, XTDoc, XTCell, XTGrid, TEvalRes +](strategy, stratname, driver) { + /** + * Create an evaluation-result object describing the top-ranked + * predicted cell and the rank of the document's true cell among + * all predicted cells. + */ + def create_cell_evaluation_result(document: XTDoc, pred_cell: XTCell, + true_rank: Int): TEvalRes + + /** + * Print out the evaluation result, possibly along with some of the + * top-ranked cells. + */ + def print_individual_result(doctag: String, document: XTDoc, + result: TEvalRes, pred_cells: Iterable[(XTCell, Double)]) { + errprint("%s:Document %s:", doctag, document) + // errprint("%s:Document distribution: %s", doctag, document.dist) + errprint("%s: %d types, %f tokens", + doctag, document.dist.model.num_types, document.dist.model.num_tokens) + errprint("%s: true cell at rank: %s", doctag, + result.asInstanceOf[RankedDocumentEvaluationResult[_,_,_,_]].true_rank) + errprint("%s: true cell: %s", doctag, result.true_cell) + val num_cells_to_output = + if (driver.params.num_top_cells_to_output >= 0) + math.min(driver.params.num_top_cells_to_output, pred_cells.size) + else pred_cells.size + for (((cell, score), i) <- pred_cells.take(num_cells_to_output).zipWithIndex) { + errprint("%s: Predicted cell (at rank %s, kl-div %s): %s", + // FIXME: This assumes KL-divergence or similar scores, which have + // been negated to make larger scores better. 
+ doctag, i + 1, -score, cell) + } + + val num_nearest_neighbors = driver.params.num_nearest_neighbors + val kNN = pred_cells.take(num_nearest_neighbors).map { + case (cell, score) => cell } + val kNNranks = pred_cells.take(num_nearest_neighbors).zipWithIndex.map { + case ((cell, score), i) => (cell, i + 1) }.toMap + val closest_half_with_dists = + kNN.map(n => (n, document.distance_to_coord(n.get_center_coord))). + toSeq.sortWith(_._2 < _._2).take(num_nearest_neighbors/2) + + closest_half_with_dists.foreach { + case (cell, dist) => + errprint("%s: #%s close neighbor: %s; error distance: %s", + doctag, kNNranks(cell), cell.get_center_coord, + document.output_distance(dist)) + } + + errprint("%s: Distance %s to true cell center at %s", + doctag, document.output_distance(result.true_truedist), result.true_center) + errprint("%s: Distance %s to predicted cell center at %s", + doctag, document.output_distance(result.pred_truedist), result.pred_coord) + + val avg_dist_of_neighbors = mean(closest_half_with_dists.map(_._2)) + errprint("%s: Average distance from true cell center to %s closest cells' centers from %s best matches: %s", + doctag, (num_nearest_neighbors/2), num_nearest_neighbors, + document.output_distance(avg_dist_of_neighbors)) + + if (avg_dist_of_neighbors < result.pred_truedist) + driver.increment_local_counter("instances.num_where_avg_dist_of_neighbors_beats_pred_truedist.%s" format num_nearest_neighbors) + } + + def imp_evaluate_document(document: XTDoc, doctag: String, + true_cell: XTCell, want_indiv_results: Boolean): TEvalRes = { + val (pred_cells, true_rank) = return_ranked_cells(document, true_cell) + val result = + create_cell_evaluation_result(document, pred_cells.head._1, true_rank) + + if (debug("all-scores")) { + for (((cell, score), index) <- pred_cells.zipWithIndex) { + errprint("%s: %6d: Cell at %s: score = %g", doctag, index + 1, + cell.describe_indices(), score) + } + } + if (want_indiv_results) { + //val cells_for_average = pred_cells.zip(pred_cells.map(_._1.center)) + //for((cell, score) <- pred_cells) { + // val scell = cell.asInstanceOf[GeoCell[GeoCoord, GeoDoc]] + //} + print_individual_result(doctag, document, result, pred_cells) + } + + return result + } +} + +/** + * A general implementation of `CellGridEvaluator` that returns a single + * best point for a given test document. + * + * @tparam TCoord Type of the coordinate assigned to a document + * @tparam TDoc Type of the training and test documents + * @tparam TCell Type of a cell in a cell grid + * @tparam TGrid Type of a cell grid + * @tparam TEvalRes Type of result of evaluating a document. + * + * @param strategy Object encapsulating the strategy used for performing + * evaluation. + * @param stratname Name of the strategy used for performing evaluation. + * @param driver Driver class that encapsulates command-line parameters and + * such. + */ +abstract class CoordCellGridEvaluator[ + TCoord, + XTDoc <: DistDocument[TCoord], + XTCell <: GeoCell[TCoord, XTDoc], + XTGrid <: CellGrid[TCoord, XTDoc, XTCell], + TEvalRes <: DocumentEvaluationResult[TCoord, XTDoc, XTCell, XTGrid] +]( + strategy: GridLocateDocumentStrategy[XTCell, XTGrid], + stratname: String, + driver: GridLocateDocumentDriver { + type TDoc = XTDoc; type TCell = XTCell; type TGrid = XTGrid + } +) extends CellGridEvaluator[ + TCoord, XTDoc, XTCell, XTGrid, TEvalRes +](strategy, stratname, driver) { + /** + * Create an evaluation-result object describing the predicted coordinate. 
+ */ + def create_coord_evaluation_result(document: XTDoc, cell_grid: XTGrid, + pred_coord: TCoord): TEvalRes + + /** + * Print out the evaluation result. + */ + def print_individual_result(doctag: String, document: XTDoc, + result: TEvalRes) { + errprint("%s:Document %s:", doctag, document) + // errprint("%s:Document distribution: %s", doctag, document.dist) + errprint("%s: %d types, %f tokens", + doctag, document.dist.model.num_types, document.dist.model.num_tokens) + errprint("%s: true cell: %s", doctag, result.true_cell) + + errprint("%s: Distance %s to true cell center at %s", + doctag, document.output_distance(result.true_truedist), result.true_center) + errprint("%s: Distance %s to predicted cell center at %s", + doctag, document.output_distance(result.pred_truedist), result.pred_coord) + } + + def find_best_point(document: XTDoc, true_cell: XTCell): TCoord + + def imp_evaluate_document(document: XTDoc, doctag: String, + true_cell: XTCell, want_indiv_results: Boolean): TEvalRes = { + val pred_coord = find_best_point(document, true_cell) + val result = create_coord_evaluation_result(document, strategy.cell_grid, + pred_coord) + + if (want_indiv_results) + print_individual_result(doctag, document, result) + + return result + } +} + +/** + * An implementation of `CellGridEvaluator` that compares the test + * document against each pseudo-document in the cell grid, selects the + * top N ranked pseudo-documents for some N, and uses the mean-shift + * algorithm to determine a single point that is hopefully in the middle + * of the strongest cluster of points among the central points of the + * pseudo-documents. + * + * @tparam TCoord Type of the coordinate assigned to a document + * @tparam TDoc Type of the training and test documents + * @tparam TCell Type of a cell in a cell grid + * @tparam TGrid Type of a cell grid + * @tparam TEvalRes Type of result of evaluating a document. + * + * @param strategy Object encapsulating the strategy used for performing + * evaluation. + * @param stratname Name of the strategy used for performing evaluation. + * @param driver Driver class that encapsulates command-line parameters and + * such. + */ +abstract class MeanShiftCellGridEvaluator[ + TCoord, + XTDoc <: DistDocument[TCoord], + XTCell <: GeoCell[TCoord, XTDoc], + XTGrid <: CellGrid[TCoord, XTDoc, XTCell], + TEvalRes <: DocumentEvaluationResult[TCoord, XTDoc, XTCell, XTGrid] +]( + strategy: GridLocateDocumentStrategy[XTCell, XTGrid], + stratname: String, + driver: GridLocateDocumentDriver { + type TDoc = XTDoc; type TCell = XTCell; type TGrid = XTGrid + }, + k_best: Int, + mean_shift_window: Double, + mean_shift_max_stddev: Double, + mean_shift_max_iterations: Int +) extends CoordCellGridEvaluator[ + TCoord, XTDoc, XTCell, XTGrid, TEvalRes +](strategy, stratname, driver) { + def create_mean_shift_obj(h: Double, max_stddev: Double, + max_iterations: Int): MeanShift[TCoord] + + val mean_shift_obj = create_mean_shift_obj(mean_shift_window, + mean_shift_max_stddev, mean_shift_max_iterations) + + def find_best_point(document: XTDoc, true_cell: XTCell) = { + val (pred_cells, true_rank) = return_ranked_cells(document, true_cell) + val top_k = pred_cells.take(k_best).map(_._1.get_center_coord).toSeq + val shifted_values = mean_shift_obj.mean_shift(top_k) + mean_shift_obj.vec_mean(shifted_values) + } +} + +/** + * A trait used when '--eval-format' is not 'internal', i.e. the test documents + * don't come from the same corpus used to supply the training documents, + * but come from some separate text file. 
This is a general interface for + * iterating over files and returning the test documents in those files + * (possibly more than one per file). + */ +trait DocumentIteratingEvaluator[TEvalDoc, TEvalRes] extends + CorpusEvaluator[TEvalDoc, TEvalRes] { + /** + * Return an Iterable listing the documents retrievable from the given + * filename. + */ + def iter_documents(filehand: FileHandler, filename: String): + Iterable[TEvalDoc] + + class EvaluationFileProcessor extends FileProcessor[Unit] { + /* Process all documents in a given file. If return value is false, + processing was interrupted due to a limit being reached, and + no more files should be processed. */ + def process_file(filehand: FileHandler, filename: String): + (Boolean, Unit) = { + for (doc <- iter_documents(filehand, filename)) { + val (processed, keep_going) = process_document(doc) + if (!keep_going) + return (false, ()) + } + return (true, ()) + } + } + + def process_files(filehand: FileHandler, files: Iterable[String]) = { + val fileproc = new EvaluationFileProcessor + val (complete, _) = fileproc.process_files(filehand, files) + complete + } +} diff --git a/src/main/scala/opennlp/fieldspring/gridlocate/GridLocate.scala b/src/main/scala/opennlp/fieldspring/gridlocate/GridLocate.scala new file mode 100644 index 0000000..9f45cb2 --- /dev/null +++ b/src/main/scala/opennlp/fieldspring/gridlocate/GridLocate.scala @@ -0,0 +1,1258 @@ +/////////////////////////////////////////////////////////////////////////////// +// GridLocate.scala +// +// Copyright (C) 2010, 2011 Ben Wing, The University of Texas at Austin +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +/////////////////////////////////////////////////////////////////////////////// + +package opennlp.fieldspring.gridlocate + +import util.matching.Regex +import util.Random +import math._ +import collection.mutable + +import opennlp.fieldspring.util.argparser._ +import opennlp.fieldspring.util.collectionutil._ +import opennlp.fieldspring.util.textdbutil +import opennlp.fieldspring.util.distances._ +import opennlp.fieldspring.util.experiment._ +import opennlp.fieldspring.util.ioutil.{FileHandler, LocalFileHandler} +import opennlp.fieldspring.util.osutil.output_resource_usage +import opennlp.fieldspring.util.printutil.errprint + +import opennlp.fieldspring.perceptron._ +import opennlp.fieldspring.worddist._ + +import WordDist.memoizer._ +import GridLocateDriver.Debug._ + +/* + +This file contains the main driver module and associated strategy classes +for GridLocate projects. "GridLocate" means applications that involve +searching for the best value (in some space) for a given test document by +dividing the space into a grid of some sort (not necessarily regular, and +not necessarily even with non-overlapping grid cells), aggregating all +the documents in a given cell, and finding the best value by searching for +the best grid cell and then returning some representative point (e.g. the +center) as the best value. The original application was for geolocation, +i.e. 
assigning a latitude/longitude coordinate to a document, and the grid +was a regular tiling of the Earth's surface based on "squares" of a given +amount of latitude and longitude on each side. But other applications are +possible, e.g. locating the date of a given biographical document, where +the space ranges over dates in time (one-dimensional) rather than over the +Earth's surface (two-dimensional). + +*/ + +///////////////////////////////////////////////////////////////////////////// +// Structures // +///////////////////////////////////////////////////////////////////////////// + +// def print_structure(struct: Any, indent: Int = 0) { +// val indstr = " "*indent +// if (struct == null) +// errprint("%snull", indstr) +// else if (struct.isInstanceOf[Tuple2[Any,Any]]) { +// val (x,y) = struct.asInstanceOf[Tuple2[Any,Any]] +// print_structure(List(x,y), indent) +// } else if (!(struct.isInstanceOf[Seq[Any]]) || +// struct.asInstanceOf[Seq[Any]].length == 0) +// errprint("%s%s", indstr, struct) +// else { +// if (struct(0).isInstanceOf[String]) { +// errprint("%s%s:", indstr, struct.asInstanceOf[String](0)) +// indstr += " " +// indent += 2 +// struct = struct.slice(1) +// } +// for (s <- struct) { +// if (isinstance(s, Seq)) +// print_structure(s, indent + 2) +// else if (isinstance(s, tuple)) { +// val (key, value) = s +// if (isinstance(value, Seq)) { +// errprint("%s%s:", indstr, key) +// print_structure(value, indent + 2) +// } +// else +// errprint("%s%s: %s", indstr, key, value) +// } +// else +// errprint("%s%s", indstr, s) +// } +// } +// } + +object GenericTypes { + type GenericDistDocument = DistDocument[_] + type GenericGeoCell = GeoCell[_, _ <: GenericDistDocument] + type GenericCellGrid = CellGrid[_, _ <: GenericDistDocument, + _ <: GenericGeoCell] + type GenericDistDocumentTable = + DistDocumentTable[_, _ <: GenericDistDocument, _ <: GenericCellGrid] + type CellGenericCellGrid[TCell <: GenericGeoCell] = CellGrid[_, _ <: GenericDistDocument, + TCell] +} +import GenericTypes._ + +///////////////////////////////////////////////////////////////////////////// +// Evaluation strategies // +///////////////////////////////////////////////////////////////////////////// + +object UnigramStrategy { + def check_unigram_dist(word_dist: WordDist) = { + word_dist match { + case x: UnigramWordDist => x + case _ => throw new IllegalArgumentException("You must use a unigram word distribution with this strategy") + } + } +} + +/** + * Abstract class for reading documents from a test file and doing + * document grid-location on them (as opposed, e.g., to trying to locate + * individual words). + */ +abstract class GridLocateDocumentStrategy[ + TCell <: GenericGeoCell, + TGrid <: CellGenericCellGrid[TCell] +]( + val cell_grid: TGrid +) { + /** + * For a given word distribution (describing a test document), return + * an Iterable of tuples, each listing a particular cell on the Earth + * and a score of some sort. The cells given in `include` must be + * included in the list. Higher scores are better. The results should + * be in sorted order, with better cells earlier. + */ + def return_ranked_cells(word_dist: WordDist, include: Iterable[TCell]): + Iterable[(TCell, Double)] +} + +/** + * Class that implements a very simple baseline strategy -- pick a random + * cell. 
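+ *
+ * A minimal usage sketch (hypothetical names; assumes `grid` is an
+ * already-populated cell grid of a compatible type and `doc_dist` is a
+ * document's word distribution; explicit type arguments may be needed):
+ *
+ * {{{
+ * val strategy = new RandomGridLocateDocumentStrategy(grid)
+ * val ranked = strategy.return_ranked_cells(doc_dist, Iterable())
+ * val (some_cell, _) = ranked.head
+ * }}}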
+ */ + +class RandomGridLocateDocumentStrategy[ + TCell <: GenericGeoCell, + TGrid <: CellGenericCellGrid[TCell] +]( + cell_grid: TGrid +) extends GridLocateDocumentStrategy[TCell, TGrid](cell_grid) { + def return_ranked_cells(word_dist: WordDist, include: Iterable[TCell]) = { + val cells = cell_grid.iter_nonempty_cells_including(include) + val shuffled = (new Random()).shuffle(cells) + (for (cell <- shuffled) yield (cell, 0.0)) + } +} + + /** + * Class that implements a simple baseline strategy -- pick the "most + * popular" cell (the one either with the largest number of documents, or + * the most number of links pointing to it, if `internal_link` is true). + */ + + class MostPopularCellGridLocateDocumentStrategy[ + TCell <: GenericGeoCell, + TGrid <: CellGenericCellGrid[TCell] + ]( + cell_grid: TGrid, + internal_link: Boolean + ) extends GridLocateDocumentStrategy[TCell, TGrid](cell_grid) { + def return_ranked_cells(word_dist: WordDist, include: Iterable[TCell]) = { + (for (cell <- + cell_grid.iter_nonempty_cells_including(include)) + yield (cell, + (if (internal_link) + cell.combined_dist.incoming_links + else + cell.combined_dist.num_docs_for_links).toDouble)). + toArray sortWith (_._2 > _._2) + } + } + + /** + * Abstract class that implements a strategy for grid location that + * involves directly comparing the document distribution against each cell + * in turn and computing a score. + */ + abstract class PointwiseScoreStrategy[ + TCell <: GenericGeoCell, + TGrid <: CellGenericCellGrid[TCell] + ]( + cell_grid: TGrid + ) extends GridLocateDocumentStrategy[TCell, TGrid](cell_grid) { + /** + * Function to return the score of a document distribution against a + * cell. + */ + def score_cell(word_dist: WordDist, cell: TCell): Double + + /** + * Compare a word distribution (for a document, typically) against all + * cells. Return a sequence of tuples (cell, score) where 'cell' + * indicates the cell and 'score' the score. + */ + def return_ranked_cells_serially(word_dist: WordDist, + include: Iterable[TCell]) = { + /* + The non-parallel way of doing things; Stephen resurrected it when + merging the Dirichlet stuff. Attempting to use the parallel method + caused an assertion failure after about 1200 of 1895 documents using + GeoText. + */ + val buffer = mutable.Buffer[(TCell, Double)]() + + for (cell <- cell_grid.iter_nonempty_cells_including( + include, nonempty_word_dist = true)) { + if (debug("lots")) { + errprint("Nonempty cell at indices %s = location %s, num_documents = %s", + cell.describe_indices(), cell.describe_location(), + cell.combined_dist.num_docs_for_word_dist) + } + + val score = score_cell(word_dist, cell) + buffer += ((cell, score)) + } + buffer + } + + /** + * Compare a word distribution (for a document, typically) against all + * cells. Return a sequence of tuples (cell, score) where 'cell' + * indicates the cell and 'score' the score. 
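+ *
+ * A sketch of the parallel scoring performed here, using Scala parallel
+ * collections (the serial variant above does the equivalent with a loop):
+ *
+ * {{{
+ * val cells = cell_grid.iter_nonempty_cells_including(
+ *   include, nonempty_word_dist = true)
+ * cells.par.map(c => (c, score_cell(word_dist, c))).toBuffer
+ * }}}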
+ */ + def return_ranked_cells_parallel(word_dist: WordDist, + include: Iterable[TCell]) = { + val cells = cell_grid.iter_nonempty_cells_including( + include, nonempty_word_dist = true) + cells.par.map(c => (c, score_cell(word_dist, c))).toBuffer + } + + def return_ranked_cells(word_dist: WordDist, include: Iterable[TCell]) = { + // FIXME, eliminate this global reference + val parallel = !GridLocateDriver.Params.no_parallel + val cell_buf = { + if (parallel) + return_ranked_cells_parallel(word_dist, include) + else + return_ranked_cells_serially(word_dist, include) + } + + /* SCALABUG: + If written simply as 'cell_buf sortWith (_._2 < _._2)', + return type is mutable.Buffer. However, if written as an + if/then as follows, return type is Iterable, even though both + forks have the same type of mutable.buffer! + */ + val retval = cell_buf sortWith (_._2 > _._2) + + /* If doing things parallel, this code applies for debugging + (serial has the debugging code embedded into it). */ + if (parallel && debug("lots")) { + for ((cell, score) <- retval) + errprint("Nonempty cell at indices %s = location %s, num_documents = %s, score = %s", + cell.describe_indices(), cell.describe_location(), + cell.combined_dist.num_docs_for_word_dist, score) + } + retval + } + } + + /** + * Class that implements a strategy for document geolocation by computing + * the KL-divergence between document and cell (approximately, how much + * the word distributions differ). Note that the KL-divergence as currently + * implemented uses the smoothed word distributions. + * + * @param partial If true (the default), only do "partial" KL-divergence. + * This only computes the divergence involving words in the document + * distribution, rather than considering all words in the vocabulary. + * @param symmetric If true, do a symmetric KL-divergence by computing + * the divergence in both directions and averaging the two values. + * (Not by default; the comparison is fundamentally asymmetric in + * any case since it's comparing documents against cells.) + */ + class KLDivergenceStrategy[ + TCell <: GenericGeoCell, + TGrid <: CellGenericCellGrid[TCell] + ]( + cell_grid: TGrid, + partial: Boolean = true, + symmetric: Boolean = false + ) extends PointwiseScoreStrategy[TCell, TGrid](cell_grid) { + + var self_kl_cache: KLDivergenceCache = null + val slow = false + + def call_kl_divergence(self: WordDist, other: WordDist) = + self.kl_divergence(self_kl_cache, other, partial = partial) + + def score_cell(word_dist: WordDist, cell: TCell) = { + val cell_word_dist = cell.combined_dist.word_dist + var kldiv = call_kl_divergence(word_dist, cell_word_dist) + if (symmetric) { + val kldiv2 = cell_word_dist.kl_divergence(null, word_dist, + partial = partial) + kldiv = (kldiv + kldiv2) / 2.0 + } + // Negate so that higher scores are better + -kldiv + } + + override def return_ranked_cells(word_dist: WordDist, + include: Iterable[TCell]) = { + // This will be used by `score_cell` above. 
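+ // (Presumably caching the document-side state here avoids recomputing
+ // it once per cell in the scoring loop.)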
+ self_kl_cache = word_dist.get_kl_divergence_cache() + + val cells = super.return_ranked_cells(word_dist, include) + + if (debug("kldiv") && word_dist.isInstanceOf[FastSlowKLDivergence]) { + val fast_slow_dist = word_dist.asInstanceOf[FastSlowKLDivergence] + // Print out the words that contribute most to the KL divergence, for + // the top-ranked cells + val num_contrib_cells = 5 + val num_contrib_words = 25 + errprint("") + errprint("KL-divergence debugging info:") + for (((cell, _), i) <- cells.take(num_contrib_cells) zipWithIndex) { + val (_, contribs) = + fast_slow_dist.slow_kl_divergence_debug( + cell.combined_dist.word_dist, partial = partial, + return_contributing_words = true) + errprint(" At rank #%s, cell %s:", i + 1, cell) + errprint(" %30s %s", "Word", "KL-div contribution") + errprint(" %s", "-" * 50) + // sort by absolute value of second element of tuple, in reverse order + val items = (contribs.toArray sortWith ((x, y) => abs(x._2) > abs(y._2))). + take(num_contrib_words) + for ((word, contribval) <- items) + errprint(" %30s %s", word, contribval) + errprint("") + } + } + + cells + } + } + + /** + * Class that implements a strategy for document geolocation by computing + * the cosine similarity between the distributions of document and cell. + * FIXME: We really should transform the distributions by TF/IDF before + * doing this. + * + * @param smoothed If true, use the smoothed word distributions. (By default, + * use unsmoothed distributions.) + * @param partial If true, only do "partial" cosine similarity. + * This only computes the similarity involving words in the document + * distribution, rather than considering all words in the vocabulary. + */ + class CosineSimilarityStrategy[ + TCell <: GenericGeoCell, + TGrid <: CellGenericCellGrid[TCell] + ]( + cell_grid: TGrid, + smoothed: Boolean = false, + partial: Boolean = false + ) extends PointwiseScoreStrategy[TCell, TGrid](cell_grid) { + + def score_cell(word_dist: WordDist, cell: TCell) = { + var cossim = + word_dist.cosine_similarity(cell.combined_dist.word_dist, + partial = partial, smoothed = smoothed) + assert(cossim >= 0.0) + // Just in case of round-off problems + assert(cossim <= 1.002) + cossim = 1.002 - cossim + // Negate so that higher scores are better + -cossim + } + } + + /** Use a Naive Bayes strategy for comparing document and cell. 
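+ *
+ * The score assigned to each cell is, roughly, a weighted combination of
+ * log-probabilities; a simplified sketch of what `score_cell` below computes:
+ *
+ * {{{
+ * val logprob = word_weight * word_logprob + baseline_weight * baseline_logprob
+ * }}}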
*/ + class NaiveBayesDocumentStrategy[ + TCell <: GenericGeoCell, + TGrid <: CellGenericCellGrid[TCell] + ]( + cell_grid: TGrid, + use_baseline: Boolean = true + ) extends PointwiseScoreStrategy[TCell, TGrid](cell_grid) { + + def score_cell(word_dist: WordDist, cell: TCell) = { + val params = cell_grid.table.driver.params + // Determine respective weightings + val (word_weight, baseline_weight) = ( + if (use_baseline) { + if (params.naive_bayes_weighting == "equal") (1.0, 1.0) + else { + val bw = params.naive_bayes_baseline_weight.toDouble + ((1.0 - bw) / word_dist.model.num_tokens, bw) + } + } else (1.0, 0.0)) + + val word_logprob = + cell.combined_dist.word_dist.get_nbayes_logprob(word_dist) + val baseline_logprob = + log(cell.combined_dist.num_docs_for_links.toDouble / + cell_grid.total_num_docs_for_links) + val logprob = (word_weight * word_logprob + + baseline_weight * baseline_logprob) + logprob + } + } + + abstract class AverageCellProbabilityStrategy[ + TCell <: GenericGeoCell, + XTGrid <: CellGenericCellGrid[TCell] + ]( + cell_grid: XTGrid + ) extends GridLocateDocumentStrategy[TCell, XTGrid](cell_grid) { + type TCellDistFactory <: + CellDistFactory[_, _ <: GenericDistDocument, TCell] { type TGrid = XTGrid } + def create_cell_dist_factory(lru_cache_size: Int): TCellDistFactory + + val cdist_factory = + create_cell_dist_factory(cell_grid.table.driver.params.lru_cache_size) + + def return_ranked_cells(word_dist: WordDist, include: Iterable[TCell]) = { + val celldist = + cdist_factory.get_cell_dist_for_word_dist(cell_grid, word_dist) + celldist.get_ranked_cells(include) + } + } + + ///////////////////////////////////////////////////////////////////////////// + // Segmentation // + ///////////////////////////////////////////////////////////////////////////// + + // General idea: Keep track of best possible segmentations up to a maximum + // number of segments. Either do it using a maximum number of segmentations + // (e.g. 100 or 1000) or all within a given factor of the best score (the + // "beam width", e.g. 10^-4). Then given the existing best segmentations, + // we search for new segmentations with more segments by looking at all + // possible ways of segmenting each of the existing best segments, and + // finding the best score for each of these. This is a slow process -- for + // each segmentation, we have to iterate over all segments, and for each + // segment we have to look at all possible ways of splitting it, and for + // each split we have to look at all assignments of cells to the two + // new segments. It also seems that we're likely to consider the same + // segmentation multiple times. + // + // In the case of per-word cell dists, we can maybe speed things up by + // computing the non-normalized distributions over each paragraph and then + // summing them up as necessary. + + ///////////////////////////////////////////////////////////////////////////// + // Stopwords // + ///////////////////////////////////////////////////////////////////////////// + + object Stopwords { + val stopwords_file_in_tg = "src/main/resources/data/%s/stopwords.txt" + + // Read in the list of stopwords from the given filename. 
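+ // If `stopwords_filename` is null, the default is derived from the
+ // language code via `stopwords_file_in_tg` (and joined to the Fieldspring
+ // directory), e.g. for language = "eng":
+ //   stopwords_file_in_tg format "eng"
+ //     // => "src/main/resources/data/eng/stopwords.txt"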
+ def read_stopwords(filehand: FileHandler, stopwords_filename: String, + language: String) = { + def compute_stopwords_filename(filename: String) = { + if (filename != null) filename + else { + val tgdir = FieldspringInfo.get_fieldspring_dir + // Concatenate directory and rest in most robust way + filehand.join_filename(tgdir, stopwords_file_in_tg format language) + } + } + val filename = compute_stopwords_filename(stopwords_filename) + errprint("Reading stopwords from %s...", filename) + filehand.openr(filename).toSet + } + } + + ///////////////////////////////////////////////////////////////////////////// + // Whitelist // + ///////////////////////////////////////////////////////////////////////////// + + object Whitelist { + def read_whitelist(filehand: FileHandler, whitelist_filename: String): Set[String] = { + if(whitelist_filename == null || whitelist_filename.length == 0) + Nil.toSet + else + filehand.openr(whitelist_filename).toSet + } + } + + ///////////////////////////////////////////////////////////////////////////// + // Main code // + ///////////////////////////////////////////////////////////////////////////// + + /** + * General class retrieving command-line arguments or storing programmatic + * configuration parameters for a Cell-grid-based application. + * + * @param parser If specified, should be a parser for retrieving the + * value of command-line arguments from the command line. Provided + * that the parser has been created and initialized by creating a + * previous instance of this same class with the same parser (a + * "shadow field" class), the variables below will be initialized with + * the values given by the user on the command line. Otherwise, they + * will be initialized with the default values for the parameters. + * Because they are vars, they can be freely set to other values. + * + */ + class GridLocateParameters(parser: ArgParser = null) extends + ArgParserParameters(parser) { + protected val ap = + if (parser == null) new ArgParser("unknown") else parser + + var language = + ap.option[String]("language", "lang", + default = "eng", + metavar = "LANG", + aliasedChoices = Seq( + Seq("eng", "en"), + Seq("por", "pt"), + Seq("deu", "de") + ), + help = """Name of language of corpus. Currently used only to + initialize the value of the stopwords file, if not explicitly set. + Two- and three-letter ISO-639 codes can be used. Currently recognized: + English (en, eng); German (de, deu); Portuguese (pt, por).""") + + //// Input files + var stopwords_file = + ap.option[String]("stopwords-file", "sf", + metavar = "FILE", + help = """File containing list of stopwords. If not specified, + a default list of English stopwords (stored in the Fieldspring distribution) + is used.""") + + var whitelist_file = + ap.option[String]("whitelist-file", "wf", + metavar = "FILE", + help = """File containing a whitelist of words. If specified, ONLY + words on the list will be read from any corpora; other words will be ignored. + If not specified, all words (except those on the stopword list) will be + read.""") + + var input_corpus = + ap.multiOption[String]("i", "input-corpus", + metavar = "DIR", + help = """Directory containing an input corpus. Documents in the + corpus can be Wikipedia articles, individual tweets in Twitter, the set of all + tweets for a given user, etc. 
The corpus generally contains one or more + "views" on the raw data comprising the corpus, with different views + corresponding to differing ways of representing the original text of the + documents -- as raw, word-split text; as unigram word counts; as n-gram word + counts; etc. Each such view has a schema file and one or more document files. + The latter contains all the data for describing each document, including + title, split (training, dev or test) and other metadata, as well as the text + or word counts that are used to create the textual distribution of the + document. The document files are laid out in a very simple database format, + consisting of one document per line, where each line is composed of a fixed + number of fields, separated by TAB characters. (E.g. one field would list + the title, another the split, another all the word counts, etc.) A separate + schema file lists the name of each expected field. Some of these names + (e.g. "title", "split", "text", "coord") have pre-defined meanings, but + arbitrary names are allowed, so that additional corpus-specific information + can be provided (e.g. retweet info for tweets that were retweeted from some + other tweet, redirect info when a Wikipedia article is a redirect to another + article, etc.). + + Multiple such files can be given by specifying the option multiple + times.""") + var eval_file = + ap.multiOption[String]("e", "eval-file", + metavar = "FILE", + help = """File or directory containing files to evaluate on. + Multiple such files/directories can be given by specifying the option multiple + times. If a directory is given, all files in the directory will be + considered (but if an error occurs upon parsing a file, it will be ignored). + Each file is read in and then disambiguation is performed. Not used during + document geolocation when --eval-format=internal (the default).""") + + var num_nearest_neighbors = + ap.option[Int]("num-nearest-neighbors", "knn", default = 4, + help = """Number of nearest neighbors (k in kNN); default is %default.""") + + var num_top_cells_to_output = + ap.option[Int]("num-top-cells-to-output", "num-top-cells", default = 5, + help = """Number of nearest neighbor cells to output; default is %default; + -1 means output all""") + + var output_training_cell_dists = + ap.flag("output-training-cell-dists", "output-training-cells", + help = """Output the training cell distributions after they've been trained.""") + + //// Options indicating which documents to train on or evaluate + var eval_set = + ap.option[String]("eval-set", "es", metavar = "SET", + default = "dev", + aliasedChoices = Seq(Seq("dev", "devel"), Seq("test")), + help = """Set to use for evaluation during document geolocation when + when --eval-format=internal ('dev' or 'devel' for the development set, + 'test' for the test set). Default '%default'.""") + var num_training_docs = + ap.option[Int]("num-training-docs", "ntrain", metavar = "NUM", + default = 0, + help = """Maximum number of training documents to use. + 0 means no limit. Default 0, i.e. no limit.""") + var num_test_docs = + ap.option[Int]("num-test-docs", "ntest", metavar = "NUM", + default = 0, + help = """Maximum number of test (evaluation) documents to process. + 0 means no limit. Default 0, i.e. no limit.""") + var skip_initial_test_docs = + ap.option[Int]("skip-initial-test-docs", "skip-initial", metavar = "NUM", + default = 0, + help = """Skip this many test docs at beginning. Default 0, i.e. 
+ don't skip any documents.""") + var every_nth_test_doc = + ap.option[Int]("every-nth-test-doc", "every-nth", metavar = "NUM", + default = 1, + help = """Only process every Nth test doc. Default 1, i.e. + process all.""") + // def skip_every_n_test_docs = + // ap.option[Int]("skip-every-n-test-docs", "skip-n", default = 0, + // help = """Skip this many after each one processed. Default 0.""") + + //// Reranking options + var rerank = + ap.option[String]("rerank", + default = "none", + choices = Seq("none", "pointwise"), + help = """Type of reranking to do. Possibilities are + 'none', 'pointwise' (do pointwise reranking using a classifier). Default + is '%default'.""") + + var rerank_top_n = + ap.option[Int]("rerank-top-n", + default = 50, + help = """Number of top-ranked items to rerank. Default is %default.""") + + var rerank_classifier = + ap.option[String]("rerank-classifier", + default = "perceptron", + choices = Seq("perceptron", "avg-perceptron", "pa-perceptron", + "trivial"), + help = """Type of classifier to use for reranking. Possibilities are + 'perceptron' (perceptron using the basic algorithm); 'avg-perceptron' + (perceptron using the basic algorithm, where the weights from the various + rounds are averaged -- this usually improves results if the weights oscillate + around a certain error rate, rather than steadily improving); 'pa-perceptron' + (passive-aggressive perceptron, which usually leads to steady but gradually + dropping-off error rate improvements with increased number of rounds); + 'trivial' (a trivial pointwise reranker for testing purposes). + Default %default. + + For the perceptron classifiers, see also `--pa-variant`, + `--perceptron-error-threshold`, `--perceptron-aggressiveness` and + `--perceptron-rounds`.""") + + var pa_variant = + ap.option[Int]("pa-variant", + metavar = "INT", + choices = Seq(0, 1, 2), + help = """For passive-aggressive perceptron when reranking: variant + (0, 1, 2; default %default).""") + + var perceptron_error_threshold = + ap.option[Double]("perceptron-error-threshold", + metavar = "DOUBLE", + default = 1e-10, + help = """For perceptron when reranking: Total error threshold below + which training stops (default: %default).""") + + var perceptron_aggressiveness = + ap.option[Double]("perceptron-aggressiveness", + metavar = "DOUBLE", + default = 1.0, + help = """For perceptron: aggressiveness factor > 0.0 + (default: %default).""") + + var perceptron_rounds = + ap.option[Int]("perceptron-rounds", + metavar = "INT", + default = 10000, + help = """For perceptron: maximum number of training rounds + (default: %default).""") + + //// Options used when creating word distributions + var word_dist = + ap.option[String]("word-dist", "wd", + default = "pseudo-good-turing", + aliasedChoices = Seq( + Seq("pseudo-good-turing", "pgt"), + Seq("dirichlet"), + Seq("jelinek-mercer", "jelinek"), + Seq("unsmoothed-ngram")), + help = """Type of word distribution to use. Possibilities are + 'pseudo-good-turing' (a simplified version of Good-Turing over a unigram + distribution), 'dirichlet' (Dirichlet smoothing over a unigram distribution), + 'jelinek' or 'jelinek-mercer' (Jelinek-Mercer smoothing over a unigram + distribution), and 'unsmoothed-ngram' (an unsmoothed n-gram distribution). + Default '%default'. + + Note that all three involve some type of discounting, i.e. 
taking away a + certain amount of probability mass compared with the maximum-likelihood + distribution (which estimates 0 probability for words unobserved in a + particular document), so that unobserved words can be assigned positive + probability, based on their probability across all documents (i.e. their + global distribution). The difference is in how the discounting factor is + computed, as well as the default value for whether to do interpolation + (always mix the global distribution in) or back-off (use the global + distribution only for words not seen in the document). Jelinek-Mercer + and Dirichlet do interpolation by default, while pseudo-Good-Turing + does back-off; but this can be overridden using --interpolate. + Jelinek-Mercer uses a fixed discounting factor; Dirichlet uses a + discounting factor that gets smaller and smaller the larger the document, + while pseudo-Good-Turing uses a discounting factor that reserves total + mass for unobserved words that is equal to the total mass observed + for words seen once.""") + var interpolate = + ap.option[String]("interpolate", + default = "default", + aliasedChoices = Seq( + Seq("yes", "interpolate"), + Seq("no", "backoff"), + Seq("default")), + help = """Whether to do interpolation rather than back-off. + Possibilities are 'yes', 'no', and 'default' (which means 'yes' when doing + Dirichlet or Jelinek-Mercer smoothing, 'no' when doing pseudo-Good-Turing + smoothing).""") + var jelinek_factor = + ap.option[Double]("jelinek-factor", "jf", + default = 0.3, + help = """Smoothing factor when doing Jelinek-Mercer smoothing. + The higher the value, the more relative weight to give to the global + distribution vis-a-vis the document-specific distribution. This + should be a value between 0.0 (no smoothing at all) and 1.0 (total + smoothing, i.e. use only the global distribution and ignore + document-specific distributions entirely). Default %default.""") + var dirichlet_factor = + ap.option[Double]("dirichlet-factor", "df", + default = 500, + help = """Smoothing factor when doing Dirichlet smoothing. + The higher the value, the more relative weight to give to the global + distribution vis-a-vis the document-specific distribution. Default + %default.""") + var preserve_case_words = + ap.flag("preserve-case-words", "pcw", + help = """Don't fold the case of words used to compute and + match against document distributions. Note that in toponym resolution, + this applies only to words in documents (currently used only in Naive Bayes + matching), not to toponyms, which are always matched case-insensitively.""") + var include_stopwords_in_document_dists = + ap.flag("include-stopwords-in-document-dists", + help = """Include stopwords when computing word distributions.""") + var minimum_word_count = + ap.option[Int]("minimum-word-count", "mwc", metavar = "NUM", + default = 1, + help = """Minimum count of words to consider in word + distributions. Words whose count is less than this value are ignored.""") + var max_ngram_length = + ap.option[Int]("max-ngram-length", "mnl", metavar = "NUM", + default = 3, + help = """Maximum length of n-grams to generate when generating + n-grams from a raw document. Does not apply when reading in a corpus that + has already been parsed into n-grams (as is usually the case).""") + var tf_idf = + ap.flag("tf-idf", "tfidf", + help = """Adjust word counts according to TF-IDF weighting (i.e. 
+ downweight words that occur in many documents).""") + + //// Options used when doing Naive Bayes geolocation + var naive_bayes_weighting = + ap.option[String]("naive-bayes-weighting", "nbw", metavar = "STRATEGY", + default = "equal", + choices = Seq("equal", "equal-words", "distance-weighted"), + help = """Strategy for weighting the different probabilities + that go into Naive Bayes. If 'equal', do pure Naive Bayes, weighting the + prior probability (baseline) and all word probabilities the same. If + 'equal-words', weight all the words the same but collectively weight all words + against the baseline, giving the baseline weight according to --baseline-weight + and assigning the remainder to the words. If 'distance-weighted', similar to + 'equal-words' but don't weight each word the same as each other word; instead, + weight the words according to distance from the toponym.""") + var naive_bayes_baseline_weight = + ap.option[Double]("naive-bayes-baseline-weight", "nbbw", + metavar = "WEIGHT", + default = 0.5, + help = """Relative weight to assign to the baseline (prior + probability) when doing weighted Naive Bayes. Default %default.""") + + //// Options used when doing ACP geolocation + var lru_cache_size = + ap.option[Int]("lru-cache-size", "lru", metavar = "SIZE", + default = 400, + help = """Number of entries in the LRU cache. Default %default. + Used only when --strategy=average-cell-probability.""") + + //// Miscellaneous options for controlling internal operation + var no_parallel = + ap.flag("no-parallel", + help = """If true, don't do ranking computations in parallel.""") + var test_kl = + ap.flag("test-kl", + help = """If true, run both fast and slow KL-divergence variations and + test to make sure results are the same.""") + + //// Debugging/output options + var max_time_per_stage = + ap.option[Double]("max-time-per-stage", "mts", metavar = "SECONDS", + default = 0.0, + help = """Maximum time per stage in seconds. If 0, no limit. + Used for testing purposes. Default 0, i.e. no limit.""") + var no_individual_results = + ap.flag("no-individual-results", "no-results", + help = """Don't show individual results for each test document.""") + var results_by_range = + ap.flag("results-by-range", + help = """Show results by range (of error distances and number of + documents in true cell). Not on by default as counters are used for this, + and setting so many counters breaks some Hadoop installations.""") + var oracle_results = + ap.flag("oracle-results", + help = """Only compute oracle results (much faster).""") + var debug = + ap.option[String]("d", "debug", metavar = "FLAGS", + help = """Output debug info of the given types. Multiple debug + parameters can be specified, indicating different types of info to output. + Separate parameters by spaces, colons or semicolons. Params can be boolean, + if given alone, or valueful, if given as PARAM=VALUE. Certain params are + list-valued; multiple values are specified by including the parameter + multiple times, or by separating values by a comma. + + The best way to figure out the possible parameters is by reading the + source code. (Look for references to debug("foo") for boolean params, + debugval("foo") for valueful params, or debuglist("foo") for list-valued + params.) Some known debug flags: + + gridrank: For the given test document number (starting at 1), output + a grid of the predicted rank for cells around the true cell. + Multiple documents can have the rank output, e.g. 
--debug 'gridrank=45,58'
+(This will output info for documents 45 and 58.) This output can be
+postprocessed to generate nice graphs; this is used e.g. in Wing's thesis.
+
+gridranksize: Size of the grid, in numbers of documents on a side.
+This is a single number, and the grid will be a square centered on the
+true cell. (Default currently 11.)
+
+kldiv: Print out words contributing most to KL divergence.
+
+wordcountdocs: Regenerate document file, filtering out documents not
+seen in any counts file.
+
+some, lots, tons: General info of various sorts. (Document me.)
+
+cell: Print out info on each cell of the Earth as it's generated. Also
+triggers some additional info during toponym resolution. (Document me.)
+
+commontop: Extra info for debugging
+--baseline-strategy=link-most-common-toponym.
+
+pcl-travel: Extra info for debugging --eval-format=pcl-travel.
+""")
+
+}
+
+class DebugSettings {
+  // Debug params. Different params indicate different info to output.
+  // Specified using --debug. Multiple params are separated by spaces,
+  // colons or semicolons. Params can be boolean, if given alone, or
+  // valueful, if given as PARAM=VALUE. Certain params are list-valued;
+  // multiple values are specified by including the parameter multiple
+  // times, or by separating values by a comma.
+  val debug = booleanmap[String]()
+  val debugval = stringmap[String]()
+  val debuglist = bufmap[String, String]()
+
+  var list_debug_params = Set[String]()
+
+  // Register a list-valued debug param.
+  def register_list_debug_param(param: String) {
+    list_debug_params += param
+  }
+
+  def parse_debug_spec(debugspec: String) {
+    val params = """[:;\s]+""".r.split(debugspec)
+    // Allow params with values, and allow lists of values to be given
+    // by repeating the param
+    for (f <- params) {
+      if (f contains '=') {
+        val Array(param, value) = f.split("=", 2)
+        if (list_debug_params contains param) {
+          val values = value.split("[,]")
+          debuglist(param) ++= values
+        } else
+          debugval(param) = value
+      } else
+        debug(f) = true
+    }
+  }
+}
+
+/**
+ * Base class for programmatic access to document/etc. geolocation.
+ * Subclasses are for particular apps, e.g. GeolocateDocumentDriver for
+ * document-level geolocation.
+ *
+ * NOTE: Currently the code has some values stored in singleton objects,
+ * and no clear provided interface for resetting them. This basically
+ * means that there can be only one geolocation instance per JVM.
+ * By now, most of the singleton objects have been removed, and it should
+ * not be difficult to remove the final limitations so that multiple
+ * drivers per JVM (possibly not at the same time) can be done.
+ *
+ * Basic operation:
+ *
+ * 1. Create an instance of the appropriate subclass of GeolocateParameters
+ * (e.g. GeolocateDocumentParameters for document geolocation) and populate
+ * it with the appropriate parameters. Don't pass in any ArgParser instance,
+ * as is the default; that way, the parameters will get initialized to their
+ * default values, and you only have to change the ones you want to be
+ * non-default.
+ * 2. Call run(), passing in the instance you just created.
+ *
+ * NOTE: Currently, some of the fields of the GeolocateParameters-subclass
+ * are changed to more canonical values. If this is a problem, let me
+ * know and I'll fix it.
+ *
+ * Evaluation output is currently written to standard error, and info is
+ * also returned by the run() function. There are some scripts to parse the
+ * console output. See below.
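+ *
+ * For example, a rough sketch of the basic operation described above (the
+ * exact way the parameter object is handed to the driver may differ in
+ * detail):
+ *
+ * {{{
+ * val params = new GeolocateDocumentParameters
+ * // Change only the settings that should be non-default.
+ * params.input_corpus = Seq("path/to/corpus")  // hypothetical path
+ * val driver = new GeolocateDocumentDriver
+ * val results = driver.run(params)
+ * }}}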
+ */ + trait GridLocateDriver extends HadoopableArgParserExperimentDriver { + type TDoc <: DistDocument[_] + type TCell <: GeoCell[_, TDoc] + type TGrid <: CellGrid[_, TDoc, TCell] + type TDocTable <: DistDocumentTable[_, TDoc, TGrid] + override type TParam <: GridLocateParameters + + var stopwords: Set[String] = _ + var whitelist: Set[String] = _ + var cell_grid: TGrid = _ + var document_table: TDocTable = _ + var word_dist_factory: WordDistFactory = _ + var word_dist_constructor: WordDistConstructor = _ + var document_file_suffix: String = _ + var output_training_cell_dists: Boolean = _ + + /** + * Set the options to those as given. NOTE: Currently, some of the + * fields in this structure will be changed (canonicalized). See above. + * If options are illegal, an error will be signaled. + * + * @param options Object holding options to set + */ + def handle_parameters() { + /* FIXME: Eliminate this. */ + GridLocateDriver.Params = params + + if (params.debug != null) + parse_debug_spec(params.debug) + + need_seq(params.input_corpus, "input-corpus") + + if (params.jelinek_factor < 0.0 || params.jelinek_factor > 1.0) { + param_error("Value for --jelinek-factor must be between 0.0 and 1.0, but is %g" format params.jelinek_factor) + } + + // Need to have `document_file_suffix` set early on, but factory + // shouldn't be created till setup_for_run() because factory may + // depend on auxiliary parameters set during this stage (e.g. during + // GenerateKML). + document_file_suffix = initialize_word_dist_suffix() + } + + protected def initialize_document_table(word_dist_factory: WordDistFactory): + TDocTable + + protected def initialize_cell_grid(table: TDocTable): TGrid + + protected def word_dist_type = { + if (params.word_dist == "unsmoothed-ngram") "ngram" + else "unigram" + } + + protected def initialize_word_dist_suffix() = { + if (word_dist_type == "ngram") + textdbutil.ngram_counts_suffix + else + textdbutil.unigram_counts_suffix + } + + protected def get_stopwords() = { + if (params.include_stopwords_in_document_dists) Set[String]() + else stopwords + } + + protected def get_whitelist() = { + whitelist + } + + protected def initialize_word_dist_constructor(factory: WordDistFactory) = { + val the_stopwords = get_stopwords() + val the_whitelist = get_whitelist() + if (word_dist_type == "ngram") + new DefaultNgramWordDistConstructor( + factory, + ignore_case = !params.preserve_case_words, + stopwords = the_stopwords, + whitelist = the_whitelist, + minimum_word_count = params.minimum_word_count, + max_ngram_length = params.max_ngram_length) + else + new DefaultUnigramWordDistConstructor( + factory, + ignore_case = !params.preserve_case_words, + stopwords = the_stopwords, + whitelist = the_whitelist, + minimum_word_count = params.minimum_word_count) + } + + protected def initialize_word_dist_factory() = { + if (params.word_dist == "unsmoothed-ngram") + new UnsmoothedNgramWordDistFactory + else if (params.word_dist == "dirichlet") + new DirichletUnigramWordDistFactory(params.interpolate, + params.dirichlet_factor) + else if (params.word_dist == "jelinek-mercer") + new JelinekMercerUnigramWordDistFactory(params.interpolate, + params.jelinek_factor) + else + new PseudoGoodTuringUnigramWordDistFactory(params.interpolate) + } + + protected def read_stopwords() = { + Stopwords.read_stopwords(get_file_handler, params.stopwords_file, + params.language) + } + + protected def read_whitelist() = { + Whitelist.read_whitelist(get_file_handler, params.whitelist_file) + } + + protected def 
read_documents(table: TDocTable) { + for (fn <- params.input_corpus) + table.read_training_documents(get_file_handler, fn, + document_file_suffix, cell_grid) + table.finish_document_loading() + } + + def setup_for_run() { + stopwords = read_stopwords() + whitelist = read_whitelist() + word_dist_factory = initialize_word_dist_factory() + word_dist_constructor = initialize_word_dist_constructor(word_dist_factory) + word_dist_factory.set_word_dist_constructor(word_dist_constructor) + document_table = initialize_document_table(word_dist_factory) + cell_grid = initialize_cell_grid(document_table) + // This accesses the stopwords and whitelist through the pointer to this in + // document_table. + read_documents(document_table) + if (debug("stop-after-reading-dists")) { + errprint("Stopping abruptly because debug flag stop-after-reading-dists set") + output_resource_usage() + // We throw to top level before exiting because hprof tends to report + // too much garbage as if it were live. Unwinding the stack may fix + // some of that. If you don't want this unwinding, comment out the + // throw and uncomment the call to System.exit(). + throw new GridLocateAbruptExit + // System.exit(0) + } + cell_grid.finish() + if(params.output_training_cell_dists) { + for(cell <- cell_grid.iter_nonempty_cells(nonempty_word_dist = true)) { + print(cell.shortstr+"\t") + val word_dist = cell.combined_dist.word_dist + println(word_dist.toString) + } + } + } + + /** + * Given a list of strategies, process each in turn, evaluating all + * documents using the strategy. + * + * @param strategies List of (name, strategy) pairs, giving strategy + * names and objects. + * @param geneval Function to create an evaluator object to evaluate + * all documents, given a strategy. + * @tparam T Supertype of all the strategy objects. + */ + protected def process_strategies[T](strategies: Seq[(String, T)])( + geneval: (String, T) => CorpusEvaluator[_,_]) = { + for ((stratname, strategy) <- strategies) yield { + val evalobj = geneval(stratname, strategy) + // For --eval-format=internal, there is no eval file. To make the + // evaluation loop work properly, we pretend like there's a single + // eval file whose value is null. 
+ val iterfiles = + if (params.eval_file.length > 0) params.eval_file + else params.input_corpus + evalobj.process_files(get_file_handler, iterfiles) + evalobj.finish() + (stratname, strategy, evalobj) + } + } +} + +trait GridLocateDocumentDriver extends GridLocateDriver { + var strategies: Seq[(String, GridLocateDocumentStrategy[TCell, TGrid])] = _ + var rankers: Map[GridLocateDocumentStrategy[TCell, TGrid], + Ranker[TDoc, TCell]] = _ + + override def handle_parameters() { + super.handle_parameters() + if (params.perceptron_aggressiveness <= 0) + param_error("Perceptron aggressiveness value should be strictly greater than zero") + } + + def create_strategy(stratname: String) = { + stratname match { + case "random" => + new RandomGridLocateDocumentStrategy[TCell, TGrid](cell_grid) + case "internal-link" => + new MostPopularCellGridLocateDocumentStrategy[TCell, TGrid]( + cell_grid, true) + case "num-documents" => + new MostPopularCellGridLocateDocumentStrategy[TCell, TGrid]( + cell_grid, false) + case "naive-bayes-no-baseline" => + new NaiveBayesDocumentStrategy[TCell, TGrid](cell_grid, false) + case "naive-bayes-with-baseline" => + new NaiveBayesDocumentStrategy[TCell, TGrid](cell_grid, true) + case "cosine-similarity" => + new CosineSimilarityStrategy[TCell, TGrid](cell_grid, smoothed = false, + partial = false) + case "partial-cosine-similarity" => + new CosineSimilarityStrategy[TCell, TGrid](cell_grid, smoothed = false, + partial = true) + case "smoothed-cosine-similarity" => + new CosineSimilarityStrategy[TCell, TGrid](cell_grid, smoothed = true, + partial = false) + case "smoothed-partial-cosine-similarity" => + new CosineSimilarityStrategy[TCell, TGrid](cell_grid, smoothed = true, + partial = true) + case "full-kl-divergence" => + new KLDivergenceStrategy[TCell, TGrid](cell_grid, symmetric = false, + partial = false) + case "partial-kl-divergence" => + new KLDivergenceStrategy[TCell, TGrid](cell_grid, symmetric = false, + partial = true) + case "symmetric-full-kl-divergence" => + new KLDivergenceStrategy[TCell, TGrid](cell_grid, symmetric = true, + partial = false) + case "symmetric-partial-kl-divergence" => + new KLDivergenceStrategy[TCell, TGrid](cell_grid, symmetric = true, + partial = true) + case "none" => + null + case other => { + assert(false, "Internal error: Unhandled strategy %s" format other) + null + } + } + } + + def create_strategies(): Seq[(String, GridLocateDocumentStrategy[TCell, TGrid])] + + protected def create_pointwise_classifier_trainer() = { + params.rerank_classifier match { + case "pa-perceptron" => + new PassiveAggressiveBinaryPerceptronTrainer( + params.pa_variant, params.perceptron_aggressiveness, + error_threshold = params.perceptron_error_threshold, + max_iterations = params.perceptron_rounds) + case "perceptron" | "avg-perceptron" => + new BasicBinaryPerceptronTrainer( + params.perceptron_aggressiveness, + error_threshold = params.perceptron_error_threshold, + max_iterations = params.perceptron_rounds, + averaged = params.rerank_classifier == "avg-perceptron") + } + } + + protected def create_basic_ranker( + strategy: GridLocateDocumentStrategy[TCell, TGrid] + ) = { + new CellGridRanker[TDoc, TCell, TGrid](strategy) + } + + def create_ranker(strategy: GridLocateDocumentStrategy[TCell, TGrid]) = { + val basic_ranker = create_basic_ranker(strategy) + if (params.rerank == "none") basic_ranker + else params.rerank_classifier match { + case "trivial" => + new TrivialDistDocumentReranker[TDoc, TCell, TGrid]( + basic_ranker, params.rerank_top_n) + // FIXME!!! 
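+      // Perceptron-based rerank classifiers aren't wired up yet (see the
+      // FIXME above); fall back to the plain ranker for those settings.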
+ case _ => basic_ranker + } + } + + /** + * Set everything up for document grid-location. Create and save a + * sequence of strategy objects. + */ + override def setup_for_run() { + super.setup_for_run() + strategies = create_strategies() + } +} + +object GridLocateDriver { + var Params: GridLocateParameters = _ + val Debug: DebugSettings = new DebugSettings + + // Debug flags (from SphereCellGridEvaluator) -- need to set them + // here before we parse the command-line debug settings. (FIXME, should + // be a better way that introduces fewer long-range dependencies like + // this) + // + // gridrank: For the given test document number (starting at 1), output + // a grid of the predicted rank for cells around the true + // cell. Multiple documents can have the rank output, e.g. + // + // --debug 'gridrank=45,58' + // + // (This will output info for documents 45 and 58.) + // + // gridranksize: Size of the grid, in numbers of documents on a side. + // This is a single number, and the grid will be a square + // centered on the true cell. + register_list_debug_param("gridrank") + debugval("gridranksize") = "11" +} + +class GridLocateAbruptExit extends Throwable { } + +abstract class GridLocateApp(appname: String) extends + ExperimentDriverApp(appname) { + type TDriver <: GridLocateDriver + + override def run_program() = { + try { + super.run_program() + } catch { + case e:GridLocateAbruptExit => { + errprint("Caught abrupt exit throw, exiting") + 0 + } + } + } +} diff --git a/src/main/scala/opennlp/fieldspring/gridlocate/Reranker.scala b/src/main/scala/opennlp/fieldspring/gridlocate/Reranker.scala new file mode 100644 index 0000000..147cdfb --- /dev/null +++ b/src/main/scala/opennlp/fieldspring/gridlocate/Reranker.scala @@ -0,0 +1,224 @@ +package opennlp.fieldspring.gridlocate + +import opennlp.fieldspring.util.printutil._ +import opennlp.fieldspring.perceptron._ +/** + * A basic ranker. Given a test item, return a list of ranked answers from + * best to worst, with a score for each. The score must not increase from + * any answer to the next one. + */ +trait Ranker[TestItem, Answer] { + /** + * Evaluate a test item, returning a list of ranked answers from best to + * worst, with a score for each. The score must not increase from any + * answer to the next one. Any answers mentioned in `include` must be + * included in the returned list. + */ + def evaluate(item: TestItem, include: Iterable[Answer]): + Iterable[(Answer, Double)] +} + +/** + * A scoring binary classifier. Given a test item, return a score, with + * higher numbers indicating greater likelihood of being "positive". + */ +trait ScoringBinaryClassifier[TestItem] { + /** + * The value of `minimum_positive` indicates the dividing line between values + * that should be considered positive and those that should be considered + * negative; typically this will be 0 or 0.5. */ + def minimum_positive: Double + def score_item(item: TestItem): Double +} + +/** + * A reranker. This is a particular type of ranker that involves two + * steps: One to compute an initial ranking, and a second to do a more + * accurate reranking of some subset of the highest-ranked answers. The + * idea is that a more accurate value can be computed when there are fewer + * possible answers to distinguish -- assuming that the initial ranker + * is able to include the correct answer near the top significantly more + * often than actually at the top. 
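+ *
+ * For example, with `top_n` = 10 the initial ranker produces a full
+ * ranking, the top 10 answers are rescored and reordered by the reranker,
+ * and the remaining answers keep their original order below them (see
+ * `evaluate` in `PointwiseClassifyingReranker`).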
+ */ +trait Reranker[TestItem, Answer] extends Ranker[TestItem, Answer] { +} + +/** + * A pointwise reranker that uses a scoring classifier to assign a score + * to each possible answer to be reranked. The idea is that + */ +trait PointwiseClassifyingReranker[TestItem, RerankInstance, Answer] + extends Reranker[TestItem, Answer] { + /** Ranker for generating initial ranking. */ + protected def initial_ranker: Ranker[TestItem, Answer] + /** Scoring classifier for use in reranking. */ + protected def rerank_classifier: ScoringBinaryClassifier[RerankInstance] + + /** + * Create a reranking training instance to feed to the classifier, given + * a test item, a potential answer from the ranker, and whether the answer + * is correct. These training instances will be used to train the + * classifier. + */ + protected def create_rerank_instance(item: TestItem, possible_answer: Answer, + score: Double): RerankInstance + + /** + * Number of top-ranked items to submit to reranking. + */ + protected val top_n: Int + + /** + * Generate rerank training instances for a given ranker + * training instance. + */ + protected def get_rerank_training_instances(item: TestItem, + true_answer: Answer) = { + val answers = initial_ranker.evaluate(item, Iterable(true_answer)).take(top_n) + for {(possible_answer, score) <- answers + is_correct = possible_answer == true_answer + } + yield ( + create_rerank_instance(item, possible_answer, score), is_correct) + } + + protected def rerank_answers(item: TestItem, + answers: Iterable[(Answer, Double)]) = { + val new_scores = + for {(answer, score) <- answers + instance = create_rerank_instance(item, answer, score) + new_score = rerank_classifier.score_item(instance) + } + yield (answer, new_score) + new_scores.toSeq sortWith (_._2 > _._2) + } + + def evaluate(item: TestItem, include: Iterable[Answer]) = { + val initial_answers = initial_ranker.evaluate(item, include) + val (to_rerank, others) = initial_answers.splitAt(top_n) + rerank_answers(item, to_rerank) ++ others + } +} + +/** + * A pointwise reranker that uses a scoring classifier to assign a score + * to each possible answer to be reranked. The idea is that + */ +trait PointwiseClassifyingRerankerWithTrainingData[ + TestItem, RerankInstance, Answer] extends + PointwiseClassifyingReranker[TestItem, RerankInstance, Answer] { + /** + * Training data used to create the reranker. + */ + protected val training_data: Iterable[(TestItem, Answer)] + + /** + * Create the classifier used for reranking, given a set of training data + * (in the form of pairs of reranking instances and whether they represent + * correct answers). + */ + protected def create_rerank_classifier( + data: Iterable[(RerankInstance, Boolean)] + ): ScoringBinaryClassifier[RerankInstance] + + protected val rerank_classifier = { + val rerank_training_data = training_data.flatMap { + case (item, true_answer) => + get_rerank_training_instances(item, true_answer) + } + create_rerank_classifier(rerank_training_data) + } +} + +/** + * @tparam TDoc Type of the training and test documents + * @tparam TCell Type of a cell in a cell grid + * @tparam TGrid Type of a cell grid + * + * @param strategy Object encapsulating the strategy used for performing + * evaluation. 
+ */ +class CellGridRanker[ + TDoc <: DistDocument[_], + TCell <: GeoCell[_, TDoc], + TGrid <: CellGrid[_, TDoc, TCell] +]( + strategy: GridLocateDocumentStrategy[TCell, TGrid] +) extends Ranker[TDoc, TCell] { + def evaluate(item: TDoc, include: Iterable[TCell]) = + strategy.return_ranked_cells(item.dist, include) +} + +class DistDocumentRerankInstance[ + TDoc <: DistDocument[_], + TCell <: GeoCell[_, TDoc], + TGrid <: CellGrid[_, TDoc, TCell] +]( + doc: TDoc, cell: TCell, score: Double +) extends SparseFeatureVector(Map("score" -> score)) { +} + +abstract class DistDocumentReranker[ + TDoc <: DistDocument[_], + TCell <: GeoCell[_, TDoc], + TGrid <: CellGrid[_, TDoc, TCell] +]( + _initial_ranker: Ranker[TDoc, TCell], + _top_n: Int +) extends PointwiseClassifyingReranker[ + TDoc, + DistDocumentRerankInstance[TDoc, TCell, TGrid], + TCell] { + protected val top_n = _top_n + protected val initial_ranker = _initial_ranker + + protected def create_rerank_instance(item: TDoc, possible_answer: TCell, + score: Double) = + new DistDocumentRerankInstance[TDoc, TCell, TGrid]( + item, possible_answer, score) +} + +/** + * A trivial scoring binary classifier that simply returns the already + * existing score from a ranker. + */ +class TrivialScoringBinaryClassifier[TestItem <: SparseFeatureVector]( + val minimum_positive: Double +) extends ScoringBinaryClassifier[TestItem] { + def score_item(item: TestItem) = { + val retval = item("score") + errprint("Trivial scoring item %s = %s", item, retval) + retval + } +} + +class TrivialDistDocumentReranker[ + TDoc <: DistDocument[_], + TCell <: GeoCell[_, TDoc], + TGrid <: CellGrid[_, TDoc, TCell] +]( + _initial_ranker: Ranker[TDoc, TCell], + _top_n: Int +) extends DistDocumentReranker[TDoc, TCell, TGrid]( + _initial_ranker, _top_n +) { + val rerank_classifier = + new TrivialScoringBinaryClassifier[ + DistDocumentRerankInstance[TDoc, TCell, TGrid]]( + 0 // FIXME: This is incorrect but doesn't matter + ) +} + +abstract class PerceptronDistDocumentReranker[ + TDoc <: DistDocument[_], + TCell <: GeoCell[_, TDoc], + TGrid <: CellGrid[_, TDoc, TCell] +]( + _initial_ranker: Ranker[TDoc, TCell], + _top_n: Int +) extends DistDocumentReranker[TDoc, TCell, TGrid]( + _initial_ranker, _top_n +) { + // FIXME! + val rerank_classifier = null +} diff --git a/src/main/scala/opennlp/fieldspring/gridlocate/TextGrounderInfo.scala b/src/main/scala/opennlp/fieldspring/gridlocate/TextGrounderInfo.scala new file mode 100644 index 0000000..604c5cb --- /dev/null +++ b/src/main/scala/opennlp/fieldspring/gridlocate/TextGrounderInfo.scala @@ -0,0 +1,44 @@ +/////////////////////////////////////////////////////////////////////////////// +// FieldspringInfo.scala +// +// Copyright (C) 2011 Ben Wing, The University of Texas at Austin +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+/////////////////////////////////////////////////////////////////////////////// + +package opennlp.fieldspring.gridlocate + +import opennlp.fieldspring.util.printutil.errprint + +/** + Fieldspring-specific information (e.g. env vars). + */ + +object FieldspringInfo { + var fieldspring_dir: String = null + + def set_fieldspring_dir(dir: String) { + fieldspring_dir = dir + } + + def get_fieldspring_dir() = { + if (fieldspring_dir == null) + fieldspring_dir = System.getenv("FIELDSPRING_DIR") + if (fieldspring_dir == null) { + errprint("""FIELDSPRING_DIR must be set to the top-level directory where +Fieldspring is installed.""") + require(fieldspring_dir != null) + } + fieldspring_dir + } +} diff --git a/src/main/scala/opennlp/fieldspring/perceptron/Memoizer.scala b/src/main/scala/opennlp/fieldspring/perceptron/Memoizer.scala new file mode 100644 index 0000000..9e1b3cf --- /dev/null +++ b/src/main/scala/opennlp/fieldspring/perceptron/Memoizer.scala @@ -0,0 +1,126 @@ +/////////////////////////////////////////////////////////////////////////////// +// Memoizer.scala +// +// Copyright (C) 2012 Ben Wing, The University of Texas at Austin +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +/////////////////////////////////////////////////////////////////////////////// + +package opennlp.fieldspring.perceptron + +import collection.mutable + +/** + * A class for "memoizing" words, i.e. mapping them to some other type + * (e.g. Int) that should be faster to compare and potentially require + * less space. + */ +abstract class Memoizer { + /** + * The type of a memoized word. + */ + type Word + /** + * Map a word as a string to its memoized form. + */ + def memoize_string(word: String): Word + /** + * Map a word from its memoized form back to a string. + */ + def unmemoize_string(word: Word): String + + /** + * The type of a mutable map from memoized words to Ints. + */ + type WordIntMap + /** + * Create a mutable map from memoized words to Ints. + */ + def create_word_int_map(): WordIntMap + /** + * The type of a mutable map from memoized words to Doubles. + */ + type WordDoubleMap + /** + * Create a mutable map from memoized words to Doubles. + */ + def create_word_double_map(): WordDoubleMap + + lazy val blank_memoized_string = memoize_string("") + + def lowercase_memoized_word(word: Word) = + memoize_string(unmemoize_string(word).toLowerCase) +} + +/** + * The memoizer we actually use. Maps word strings to Ints. + * + * @param minimum_index Minimum index used, usually either 0 or 1. + */ +class IntStringMemoizer(val minimum_index: Int = 0) extends Memoizer { + type Word = Int + + protected var next_word_count: Word = minimum_index + + def number_of_entries = next_word_count - minimum_index + + // For replacing strings with ints. This should save space on 64-bit + // machines (string pointers are 8 bytes, ints are 4 bytes) and might + // also speed lookup. 
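+  //
+  // Example (illustrative): with `minimum_index = 1`, the first string
+  // memoized gets index 1, the next one 2, and so on:
+  //
+  //   val memo = new IntStringMemoizer(minimum_index = 1)
+  //   val id = memo.memoize_string("boston")    // 1
+  //   memo.unmemoize_string(id)                 // "boston"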
+ protected val word_id_map = mutable.Map[String,Word]() + //protected val word_id_map = trovescala.ObjectIntMap[String]() + + // Map in the opposite direction. + protected val id_word_map = mutable.Map[Word,String]() + //protected val id_word_map = trovescala.IntObjectMap[String]() + + def memoize_string(word: String) = { + val index = word_id_map.getOrElse(word, -1) + // println("Saw word=%s, index=%s" format (word, index)) + if (index != -1) index + else { + val newind = next_word_count + next_word_count += 1 + word_id_map(word) = newind + id_word_map(newind) = word + newind + } + } + + def unmemoize_string(word: Word) = id_word_map(word) + + //def create_word_int_map() = trovescala.IntIntMap() + //type WordIntMap = trovescala.IntIntMap + //def create_word_double_map() = trovescala.IntDoubleMap() + //type WordDoubleMap = trovescala.IntDoubleMap + def create_word_int_map() = mutable.Map[Word,Int]() + type WordIntMap = mutable.Map[Word,Int] + def create_word_double_map() = mutable.Map[Word,Double]() + type WordDoubleMap = mutable.Map[Word,Double] +} + +// /** +// * Version that uses Trove for extremely fast and memory-efficient hash +// * tables, making use of the Trove-Scala interface for easy access to the +// * Trove hash tables. +// */ +// class TroveIntStringMemoizer( +// minimum_index: Int = 0 +// ) extends IntStringMemoizer(minimum_index) { +// override protected val word_id_map = trovescala.ObjectIntMap[String]() +// override protected val id_word_map = trovescala.IntObjectMap[String]() +// override def create_word_int_map() = trovescala.IntIntMap() +// override type WordIntMap = trovescala.IntIntMap +// override def create_word_double_map() = trovescala.IntDoubleMap() +// override type WordDoubleMap = trovescala.IntDoubleMap +// } diff --git a/src/main/scala/opennlp/fieldspring/perceptron/Perceptron.scala b/src/main/scala/opennlp/fieldspring/perceptron/Perceptron.scala new file mode 100644 index 0000000..fa295a2 --- /dev/null +++ b/src/main/scala/opennlp/fieldspring/perceptron/Perceptron.scala @@ -0,0 +1,934 @@ + /////////////////////////////////////////////////////////////////////////////// +// Perceptron.scala +// +// Copyright (C) 2012 Ben Wing, The University of Texas at Austin +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +/////////////////////////////////////////////////////////////////////////////// + +package opennlp.fieldspring.perceptron + +/** + * A perceptron for binary classification. + * + * @author Ben Wing + */ + +import util.control.Breaks._ +import collection.mutable +import io.Source + +/** + * A vector of real-valued features. In general, features are indexed + * both by a non-negative integer and by a class label (i.e. a label for + * the class that is associated with a particular instance by a classifier). + * Commonly, the class label is ignored when looking up a feature's value. + * Some implementations might want to evaluate the features on-the-fly + * rather than store an actual vector of values. 
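+ *
+ * Whatever the representation, `dot_product(weights, label)` should equal
+ * the sum over all indices `i` of `apply(i, label) * weights(i)`, and
+ * `update_weights(weights, scale, label)` should add `scale * apply(i, label)`
+ * to each `weights(i)`.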
+ */ +trait FeatureVector { + /** Return the length of the feature vector. This is the number of weights + * that need to be created -- not necessarily the actual number of items + * stored in the vector (which will be different especially in the case + * of sparse vectors). */ + def length: Int + + /** Return the value at index `i`, for class `label`. */ + def apply(i: Int, label: Int): Double + + /** Return the squared magnitude of the feature vector for class `label`, + * i.e. dot product of feature vector with itself */ + def squared_magnitude(label: Int): Double + + /** Return the squared magnitude of the difference between the values of + * this feature vector for the two labels `label1` and `label2`. */ + def diff_squared_magnitude(label1: Int, label2: Int): Double + + /** Return the dot product of the given weight vector with the feature + * vector for class `label`. */ + def dot_product(weights: WeightVector, label: Int): Double + + /** Update a weight vector by adding a scaled version of the feature vector, + * with class `label`. */ + def update_weights(weights: WeightVector, scale: Double, label: Int) +} + +/** + * A feature vector that ignores the class label. + */ +trait SimpleFeatureVector extends FeatureVector { + /** Return the value at index `i`. */ + def apply(i: Int): Double + + def apply(i: Int, label: Int) = apply(i) +} + +/** + * A feature vector in which the features are stored densely, i.e. as + * an array of values. + */ +trait DenseFeatureVector extends FeatureVector { + def dot_product(weights: WeightVector, label: Int) = + (for (i <- 0 until length) yield apply(i, label)*weights(i)).sum + + def squared_magnitude(label: Int) = + (for (i <- 0 until length; va = apply(i, label)) yield va*va).sum + + def diff_squared_magnitude(label1: Int, label2: Int) = + (for (i <- 0 until length; va = apply(i, label1) - apply(i, label2)) + yield va*va).sum + + def update_weights(weights: WeightVector, scale: Double, label: Int) { + (0 until length).foreach(i => { weights(i) += scale*apply(i, label) }) + } +} + +/** + * A vector of real-valued features, stored explicitly. The values passed in + * are used exactly as the values of the feature; no additional term is + * inserted to handle a "bias" or "intercept" weight. + */ +class RawArrayFeatureVector( + values: WeightVector +) extends DenseFeatureVector with SimpleFeatureVector { + /** Add two feature vectors. */ + def +(other: SimpleFeatureVector) = { + val len = length + val res = new WeightVector(len) + for (i <- 0 until len) + res(i) = this(i) + other(i) + new RawArrayFeatureVector(res) + } + + /** Subtract two feature vectors. */ + def -(other: SimpleFeatureVector) = { + val len = length + val res = new WeightVector(len) + for (i <- 0 until len) + res(i) = this(i) - other(i) + new RawArrayFeatureVector(res) + } + + /** Scale a feature vector. */ + def *(scalar: Double) = { + val len = length + val res = new WeightVector(len) + for (i <- 0 until len) + res(i) = this(i)*scalar + new RawArrayFeatureVector(res) + } + + /** Return the length of the feature vector. */ + def length = values.length + + /** Return the value at index `i`. */ + def apply(i: Int) = values(i) + + def update(i: Int, value: Double) { values(i) = value } +} + +/** + * A vector of real-valued features, stored explicitly. An additional value + * set to a constant 1 is automatically stored at the end of the vector. 
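+ *
+ * For example:
+ *
+ * {{{
+ * val fv = new ArrayFeatureVector(Array(2.0, 3.0))
+ * fv.length  // 3: the two stored values plus the constant bias term
+ * fv(2)      // 1.0
+ * }}}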
+ */ +class ArrayFeatureVector( + values: WeightVector +) extends DenseFeatureVector with SimpleFeatureVector { + /** Return the length of the feature vector; + 1 including the extra bias + * term. */ + def length = values.length + 1 + + /** Return the value at index `i`, but return 1.0 at the last index. */ + def apply(i: Int) = { + if (i == values.length) 1.0 + else values(i) + } + + def update(i: Int, value: Double) { + if (i == values.length) { + if (value != 1.0) { + throw new IllegalArgumentException( + "Element at the last index (index %s) unmodifiable, fixed at 1.0" + format i) + } + } else { values(i) = value } + } +} + +/** + * A feature vector in which the features are stored sparsely, i.e. only + * the features with non-zero values are stored, using a hash table or + * similar. + */ +class SparseFeatureVector( + feature_values: Map[String, Double] +) extends SimpleFeatureVector { + protected val memoized_features = Map(0 -> 0.0) ++ // the intercept term + feature_values.map { + case (name, value) => + (SparseFeatureVector.feature_mapper.memoize_string(name), value) + } + + def length = { + // +1 because of the intercept term + SparseFeatureVector.feature_mapper.number_of_entries + 1 + } + + def apply(index: Int) = memoized_features.getOrElse(index, 0.0) + def apply(feature: String): Double = + apply(SparseFeatureVector.feature_mapper.memoize_string(feature)) + + def squared_magnitude(label: Int) = + memoized_features.map { + case (index, value) => value * value + }.sum + + def diff_squared_magnitude(label1: Int, label2: Int) = 0.0 + + def dot_product(weights: WeightVector, label: Int) = + memoized_features.map { + case (index, value) => value * weights(index) + }.sum + + def update_weights(weights: WeightVector, scale: Double, label: Int) { + memoized_features.map { + case (index, value) => weights(index) += scale * value + } + } + + override def toString = { + "SparseFeatureVector(%s)" format + memoized_features.filter { case (index, value) => value > 0}. + toSeq.sorted.map { + case (index, value) => + "%s(%s)=%.2f" format ( + SparseFeatureVector.feature_mapper.unmemoize_string(index), + index, value + ) + }.mkString(",") + } +} + +object SparseFeatureVector { + // Set the minimum index to 1 so we can use 0 for the intercept term + val feature_mapper = new IntStringMemoizer(minimum_index = 1) +} + +/** + * A sparse feature vector built up out of nominal strings. A global + * mapping table is maintained to convert between strings and array + * indices into a logical vector. + */ +class SparseNominalFeatureVector( + nominal_features: Iterable[String] +) extends SparseFeatureVector( + nominal_features.map((_, 1.0)).toMap +) { + override def toString = { + "SparseNominalFeatureVector(%s)" format + memoized_features.filter { case (index, value) => value > 0}. + toSeq.sorted.map { + case (index, value) => + "%s(%s)" format ( + SparseFeatureVector.feature_mapper.unmemoize_string(index), + index + ) + }.mkString(",") + } +} + +/** + * A data instance (a statistical unit), consisting of a feature vector + * specifying the characteristics of the instance and a label, to be + * predicted. + * + * @tparam LabelType type of the label (e.g. Int for classification, + * Double for regression, etc.). + */ +abstract class Instance[LabelType] { + /** Return the label. */ + def getLabel: LabelType + /** Return the feature vector. 
*/ + def getFeatures: FeatureVector +} + +/** + * A factory object for creating sparse nominal instances for classification, + * consisting of a nominal label and a set of nominal features. "Nominal" + * in this case means data described using an arbitrary string. Nominal + * features are either present or absent, and nominal labels have no ordering + * or other numerical significance. + */ +class SparseNominalInstanceFactory { + val label_mapper = new IntStringMemoizer(minimum_index = 0) + def label_to_index(label: String) = label_mapper.memoize_string(label) + def index_to_label(index: Int) = label_mapper.unmemoize_string(index) + def number_of_labels = label_mapper.number_of_entries + + def make_labeled_instance(features: Iterable[String], label: String) = { + val featvec = new SparseNominalFeatureVector(features) + val labelind = label_to_index(label) + (featvec, labelind) + } + + def get_csv_labeled_instances(source: Source) = { + val lines = source.getLines + for (line <- lines) yield { + val atts = line.split(",") + val label = atts.last + val features = atts.dropRight(1) + make_labeled_instance(features, label) + } + } +} + +trait LinearClassifier { + /** Return number of labels. */ + def number_of_labels: Int + + assert(number_of_labels >= 2) + + /** Classify a given instance, returning the class (a label from 0 to + * `number_of_labels`-1). */ + def classify(instance: FeatureVector): Int + + /** Score a given instance. Return a sequence of predicted scores, of + * the same length as the number of labels present. There is one score + * per label, and the maximum score corresponds to the single predicted + * label if such a prediction is desired. */ + def score(instance: FeatureVector): IndexedSeq[Double] +} + +/** + * A binary linear classifier, created from an array of weights. Normally + * created automatically by one of the trainer classes. + */ +class BinaryLinearClassifier ( + val weights: WeightVector +) extends LinearClassifier { + val number_of_labels = 2 + + /** Classify a given instance, returning the class, either 0 or 1. */ + def classify(instance: FeatureVector) = { + val sc = binary_score(instance) + if (sc > 0) 1 else 0 + } + + /** Score a given instance, returning a single real number. If the score + * is > 0, 1 is predicted, else 0. */ + def binary_score(instance: FeatureVector) = instance.dot_product(weights, 1) + + def score(instance: FeatureVector) = + IndexedSeq(0, binary_score(instance)) +} + +/** + * Class for training a linear classifier given a set of training instances and + * associated labels. + */ +trait LinearClassifierTrainer { + /** Create and initialize a vector of weights of length `len`. + * By default, initialized to all 0's, but could be changed. */ + def new_weights(len: Int) = new WeightVector(len) + + /** Create and initialize a vector of weights of length `len` to all 0's. */ + def new_zero_weights(len: Int) = new WeightVector(len) + + /** Check that all instances have the same length. */ + def check_sequence_lengths(data: Iterable[(FeatureVector, Int)]) { + val len = data.head._1.length + for ((inst, label) <- data) + assert(inst.length == len) + } + + /** Train a perceptron given a set of labeled instances. */ + def apply(data: Iterable[(FeatureVector, Int)], num_classes: Int): + LinearClassifier +} + +/** + * Class for training a binary perceptron given a set of training instances + * and associated labels. Use function application to train a new + * perceptron, e.g. `new BinaryPerceptronTrainer()(data)`. 
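+ *
+ * For example, a minimal sketch using the concrete trainer defined below
+ * (`BasicBinaryPerceptronTrainer`) on two hand-built instances:
+ *
+ * {{{
+ * val data = Seq(
+ *   (new RawArrayFeatureVector(Array(1.0, 0.0)), 1),
+ *   (new RawArrayFeatureVector(Array(0.0, 1.0)), 0))
+ * val trainer = new BasicBinaryPerceptronTrainer(1.0)
+ * val classifier = trainer(data)
+ * classifier.classify(new RawArrayFeatureVector(Array(1.0, 0.0)))  // 1
+ * }}}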
+ * + * The basic perceptron training algorithm, in all its variants, works as + * follows: + * + * 1. We do multiple iterations, and in each iteration we loop through the + * training instances. + * 2. We process the training instances one-by-one, and potentially update + * the weight vector each time we process a training instance. (Hence, + * the algorithm is "online" or sequential, as opposed to an "off-line" + * or batch algorithm that attempts to satisfy some globally optimal + * function, e.g. maximize the joint probability of seeing the entire + * training set. An off-line iterative algorithm updates the weight + * function once per iteration in a way that attempts to improve the + * overall performance of the algorithm on the entire training set.) + * 3. Each time we see a training instance, we run the prediction algorithm + * to see how we would do on that training instance. In general, if + * we produce the right answer, we make no changes to the weights. + * However, if we produce the wrong answer, we change the weights in + * such a way that we will subsequently do better on the given training + * instance, generally by adding to the weight vector a simple scalar + * multiple (possibly negative) of the feature vector associated with the + * training instance in question. + * 4. We repeat until no further change (or at least, the total change is + * less than some small value), or until we've done a maximum number of + * iterations. + * @param error_threshold Threshold that the sum of all scale factors for + * all instances must be below in order for training to stop. In + * practice, in order to succeed with a threshold such as 1e-10, the + * actual sum of scale factors must be 0. + * @param max_iterations Maximum number of iterations. Training stops either + * when the threshold constraint succeeds of the maximum number of + * iterations is reached. + */ +abstract class BinaryPerceptronTrainer( + averaged: Boolean = false, + error_threshold: Double = 1e-10, + max_iterations: Int = 1000 +) extends LinearClassifierTrainer { + assert(error_threshold >= 0) + assert(max_iterations > 0) + + /** Check that the arguments passed in are kosher, and return an array of + * the weights to be learned. */ + def initialize(data: Iterable[(FeatureVector, Int)]) = { + check_sequence_lengths(data) + for ((inst, label) <- data) + assert(label == 0 || label == 1) + new_weights(data.head._1.length) + } + + /** Return the scale factor used for updating the weight vector to a + * new weight vector. + * + * @param inst Instance we are currently processing. + * @param label True label of that instance. + * @param score Predicted score on that instance. + */ + def get_scale_factor(inst: FeatureVector, label: Int, score: Double): + Double + + /** Train a binary perceptron given a set of labeled instances. 
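+   * Returns a `BinaryLinearClassifier`; if `averaged` was requested, it is
+   * built from the weights averaged over all iterations.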
*/ + def apply(data: Iterable[(FeatureVector, Int)]) = { + val debug = false + val weights = initialize(data) + val avg_weights = new_zero_weights(weights.length) + def print_weights() { + Console.err.println("Weights: length=%s,max=%s,min=%s" format + (weights.length, weights.max, weights.min)) + // Console.err.println("Weights: [%s]" format weights.mkString(",")) + } + val len = weights.length + var iter = 0 + if (debug) + print_weights() + breakable { + while (true) { + iter += 1 + if (debug) + Console.err.println("Iteration %s" format iter) + var total_error = 0.0 + for ((inst, label) <- data) { + if (debug) + Console.err.println("Instance %s, label %s" format (inst, label)) + val score = inst.dot_product(weights, 1) + if (debug) + Console.err.println("Score %s" format score) + val scale = get_scale_factor(inst, label, score) + if (debug) + Console.err.println("Scale %s" format scale) + inst.update_weights(weights, scale, 1) + if (debug) + print_weights() + total_error += math.abs(scale) + } + if (averaged) + (0 until len).foreach(i => avg_weights(i) += weights(i)) + Console.err.println("Iteration %s, total_error %s" format (iter, total_error)) + if (total_error < error_threshold || iter >= max_iterations) + break + } + } + if (averaged) { + (0 until len).foreach(i => avg_weights(i) /= iter) + new BinaryLinearClassifier(avg_weights) + } else new BinaryLinearClassifier(weights) + } + + /** Train a perceptron given a set of labeled instances. */ + def apply(data: Iterable[(FeatureVector, Int)], num_classes: Int) = { + assert(num_classes == 2) + apply(data) + } +} + +/** Train a binary perceptron using the basic algorithm. See the above + * description of the general perceptron training algorithm. In this case, + * when we process an instance, if our prediction is wrong, we either + * push the weight up (if the correct prediction is positive) or down (if the + * correct prediction is negative), according to `alpha` times the feature + * vector of the instance we just evaluated on. + */ +class BasicBinaryPerceptronTrainer( + alpha: Double, + averaged: Boolean = false, + error_threshold: Double = 1e-10, + max_iterations: Int = 1000 +) extends BinaryPerceptronTrainer(averaged, error_threshold, max_iterations) { + def get_scale_factor(inst: FeatureVector, label: Int, score: Double) = { + val pred = if (score > 0) 1 else -1 + // Map from 0/1 to -1/1 + val symmetric_label = label*2 - 1 + alpha*(symmetric_label - pred) + } +} + +trait PassiveAggressivePerceptronTrainer { + val _variant: Int + val _aggressiveness_param: Double + + def compute_update_factor(loss: Double, sqmag: Double) = { + assert(_variant >= 0 && _variant <= 2) + assert(_aggressiveness_param > 0) + if (_variant == 0) + loss / sqmag + else if (_variant == 1) + _aggressiveness_param min (loss / sqmag) + else + loss / (sqmag + 1.0/(2.0*_aggressiveness_param)) + } + + /** Return set of "yes" labels associated with an instance. Currently only + * one yes label per instance, but this could be changed by redoing this + * function. */ + def yes_labels(label: Int, num_classes: Int) = + (0 until 0) ++ (label to label) + + /** Return set of "no" labels associated with an instance -- complement of + * the set of "yes" labels. */ + def no_labels(label: Int, num_classes: Int) = + (0 until label) ++ (label until num_classes) + +} + +/** Train a binary perceptron using the basic algorithm. See the above + * description of the general perceptron training algorithm. 
When processing + * a training instance, the algorithm is "passive" in the sense that it makes + * no changes if the prediction is correct (as in all perceptron training + * algorithms), and "aggressive" when a prediction is wrong in the sense that + * it changes the weight as much as necessary (but no more) to satisfy a + * given constraint. In this case, the idea is to change the weight as + * little as possible while ensuring that the prediction on the instance is + * not only correct but has a score that exceeds the minimally required score + * for correctness by at least as much as a given "margin". Hence, we + * essentially * try to progess as much as possible in each step (the + * constraint satisfaction) while also trying to preserve as much information + * as possible that was learned previously (the minimal constraint + * satisfaction). + * + * @param variant Variant 0 directly implements the algorithm just + * described. The other variants are designed for training sets that may + * not be linearly separable, and as a result are less aggressive. + * Variant 1 simply limits the total change to be no more than a given + * factor, while variant 2 scales the total change down relatively. In + * both cases, an "aggressiveness factor" needs to be given. + * @param aggressiveness_param As just described above. Higher values + * cause more aggressive changes to the weight vector during training. + */ +class PassiveAggressiveBinaryPerceptronTrainer( + variant: Int, + aggressiveness_param: Double = 20.0, + error_threshold: Double = 1e-10, + max_iterations: Int = 1000 +) extends BinaryPerceptronTrainer(false, error_threshold, max_iterations) + with PassiveAggressivePerceptronTrainer { + val _variant = variant; val _aggressiveness_param = aggressiveness_param + def get_scale_factor(inst: FeatureVector, label: Int, score: Double) = { + // Map from 0/1 to -1/1 + val symmetric_label = label*2 - 1 + val loss = 0.0 max (1.0 - symmetric_label*score) + val sqmag = inst.squared_magnitude(1) + compute_update_factor(loss, sqmag)*symmetric_label + } +} + +object Maxutil { + /** Return the argument producing the maximum when the function is applied + * to it. */ + def argmax[T](args: Iterable[T], fun: T => Double) = { + (args zip args.map(fun)).maxBy(_._2)._1 + } + + /** Return both the argument producing the maximum and the maximum value + * itself, when the function is applied to the arguments. */ + def argandmax[T](args: Iterable[T], fun: T => Double) = { + (args zip args.map(fun)).maxBy(_._2) + } + + /** Return the argument producing the minimum when the function is applied + * to it. */ + def argmin[T](args: Iterable[T], fun: T => Double) = { + (args zip args.map(fun)).minBy(_._2)._1 + } + + /** Return both the argument producing the minimum and the minimum value + * itself, when the function is applied to the arguments. */ + def argandmin[T](args: Iterable[T], fun: T => Double) = { + (args zip args.map(fun)).minBy(_._2) + } +} + +/** + * A multi-class perceptron with only a single set of weights for all classes. + * Note that the feature vector is passed the class in when a value is + * requested; it is assumed that class-specific features are handled + * automatically through this mechanism. + */ +class SingleWeightMultiClassLinearClassifier ( + val weights: WeightVector, + val number_of_labels: Int +) extends LinearClassifier { + + /** Classify a given instance, returning the class. 
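+   * The result is the label in `0 until number_of_labels` whose score
+   * (dot product of the instance with the shared weight vector) is highest.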
*/ + def classify(instance: FeatureVector) = + Maxutil.argmax[Int](0 until number_of_labels, score_class(instance, _)) + + /** Score a given instance for a single class. */ + def score_class(instance: FeatureVector, clazz: Int) = + instance.dot_product(weights, clazz) + + /** Score a given instance, returning an array of scores, one per class. */ + def score(instance: FeatureVector) = + (0 until number_of_labels).map(score_class(instance, _)).toArray +} + +/** + * A multi-class perceptron with a different set of weights for each class. + * Note that the feature vector is also passed the class in when a value is + * requested. + */ +class MultiClassLinearClassifier ( + val weights: IndexedSeq[WeightVector] +) extends LinearClassifier { + val number_of_labels = weights.length + + /** Classify a given instance, returning the class. */ + def classify(instance: FeatureVector) = + Maxutil.argmax[Int](0 until number_of_labels, score_class(instance, _)) + + /** Score a given instance for a single class. */ + def score_class(instance: FeatureVector, clazz: Int) = + instance.dot_product(weights(clazz), clazz) + + /** Score a given instance, returning an array of scores, one per class. */ + def score(instance: FeatureVector) = + (0 until number_of_labels).map(score_class(instance, _)).toArray +} + +/** + * Class for training a multi-class perceptron with only a single set of + * weights for all classes. + */ +abstract class SingleWeightMultiClassPerceptronTrainer( + error_threshold: Double = 1e-10, + max_iterations: Int = 1000 +) extends LinearClassifierTrainer { + assert(error_threshold >= 0) + assert(max_iterations > 0) + + /** Check that the arguments passed in are kosher, and return an array of + * the weights to be learned. */ + def initialize(data: Iterable[(FeatureVector, Int)], num_classes: Int) = { + assert(num_classes >= 2) + for ((inst, label) <- data) + assert(label >= 0 && label < num_classes) + new_weights(data.head._1.length) + } + + def apply(data: Iterable[(FeatureVector, Int)], num_classes: Int): + SingleWeightMultiClassLinearClassifier +} + +/** + * Class for training a passive-aggressive multi-class perceptron with only a + * single set of weights for all classes. + */ +class PassiveAggressiveSingleWeightMultiClassPerceptronTrainer( + variant: Int, + aggressiveness_param: Double = 20.0, + error_threshold: Double = 1e-10, + max_iterations: Int = 1000 +) extends SingleWeightMultiClassPerceptronTrainer( + error_threshold, max_iterations +) with PassiveAggressivePerceptronTrainer { + val _variant = variant; val _aggressiveness_param = aggressiveness_param + + /** + * Actually train a passive-aggressive single-weight multi-class + * perceptron. Note that, although we're passed in a single correct label + * per instance, the code below is written so that it can handle a set of + * correct labels; you'd just have to change `yes_labels` and `no_labels` + * and pass the appropriate set of correct labels in. 
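+   *
+   * Concretely, for each instance the code finds the "yes" label `r` with
+   * the lowest score and the "no" label `s` with the highest score,
+   * computes the hinge loss max(0, 1 - (score(r) - score(s))), and then
+   * updates the single shared weight vector by `+scale` on `r` and
+   * `-scale` on `s`, where `scale` comes from `compute_update_factor`.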
+ */ + def apply(data: Iterable[(FeatureVector, Int)], num_classes: Int) = { + val weights = initialize(data, num_classes) + val len = weights.length + var iter = 0 + breakable { + while (iter < max_iterations) { + var total_error = 0.0 + for ((inst, label) <- data) { + def dotprod(x: Int) = inst.dot_product(weights, x) + val yeslabs = yes_labels(label, num_classes) + val nolabs = no_labels(label, num_classes) + val (r,rscore) = Maxutil.argandmin[Int](yeslabs, dotprod(_)) + val (s,sscore) = Maxutil.argandmax[Int](nolabs, dotprod(_)) + val margin = rscore - sscore + val loss = 0.0 max (1.0 - margin) + val sqmagdiff = inst.diff_squared_magnitude(r, s) + val scale = compute_update_factor(loss, sqmagdiff) + inst.update_weights(weights, scale, r) + inst.update_weights(weights, -scale, s) + total_error += math.abs(scale) + } + if (total_error < error_threshold) + break + iter += 1 + } + } + new SingleWeightMultiClassLinearClassifier(weights, num_classes) + } +} + +/** + * Class for training a multi-class perceptron with separate weights for each + * class. + */ +abstract class MultiClassPerceptronTrainer( + error_threshold: Double = 1e-10, + max_iterations: Int = 1000 +) extends LinearClassifierTrainer { + assert(error_threshold >= 0) + assert(max_iterations > 0) + + /** Check that the arguments passed in are kosher, and return an array of + * the weights to be learned. */ + def initialize(data: Iterable[(FeatureVector, Int)], num_classes: Int) = { + assert(num_classes >= 2) + for ((inst, label) <- data) + assert(label >= 0 && label < num_classes) + val len = data.head._1.length + IndexedSeq[WeightVector]( + (for (i <- 0 until num_classes) yield new_weights(len)) :_*) + } + + def apply(data: Iterable[(FeatureVector, Int)], num_classes: Int): + MultiClassLinearClassifier +} + +/** + * Class for training a passive-aggressive multi-class perceptron with only a + * single set of weights for all classes. + */ +class PassiveAggressiveMultiClassPerceptronTrainer( + variant: Int, + aggressiveness_param: Double = 20.0, + error_threshold: Double = 1e-10, + max_iterations: Int = 1000 +) extends MultiClassPerceptronTrainer( + error_threshold, max_iterations +) with PassiveAggressivePerceptronTrainer { + val _variant = variant; val _aggressiveness_param = aggressiveness_param + + /** + * Actually train a passive-aggressive multi-weight multi-class + * perceptron. Note that, although we're passed in a single correct label + * per instance, the code below is written so that it can handle a set of + * correct labels; you'd just have to change `yes_labels` and `no_labels` + * and pass the appropriate set of correct labels in. 
+ */ + def apply(data: Iterable[(FeatureVector, Int)], num_classes: Int) = { + val weights = initialize(data, num_classes) + val len = weights(0).length + var iter = 0 + breakable { + while (iter < max_iterations) { + var total_error = 0.0 + for ((inst, label) <- data) { + def dotprod(x: Int) = inst.dot_product(weights(x), x) + val yeslabs = yes_labels(label, num_classes) + val nolabs = no_labels(label, num_classes) + val (r,rscore) = Maxutil.argandmin[Int](yeslabs, dotprod(_)) + val (s,sscore) = Maxutil.argandmax[Int](nolabs, dotprod(_)) + val margin = rscore - sscore + val loss = 0.0 max (1.0 - margin) + val rmag = inst.squared_magnitude(r) + val smag = inst.squared_magnitude(s) + val sqmagdiff = rmag + smag + val scale = compute_update_factor(loss, sqmagdiff) + inst.update_weights(weights(r), scale, r) + inst.update_weights(weights(s), -scale, s) + total_error += math.abs(scale) + } + if (total_error < error_threshold) + break + iter += 1 + } + } + new MultiClassLinearClassifier(weights) + } +} + +/** + * Class for training a cost-sensitive multi-class perceptron with only a + * single set of weights for all classes. + */ +abstract class CostSensitiveSingleWeightMultiClassPerceptronTrainer( + error_threshold: Double = 1e-10, + max_iterations: Int = 1000 +) extends SingleWeightMultiClassPerceptronTrainer { + assert(error_threshold >= 0) + assert(max_iterations > 0) + + def cost(correct: Int, predicted: Int): Double +} + +/** + * Class for training a passive-aggressive cost-sensitive multi-class + * perceptron with only a single set of weights for all classes. + */ +abstract class PassiveAggressiveCostSensitiveSingleWeightMultiClassPerceptronTrainer( + prediction_based: Boolean, + variant: Int, + aggressiveness_param: Double = 20.0, + error_threshold: Double = 1e-10, + max_iterations: Int = 1000 +) extends CostSensitiveSingleWeightMultiClassPerceptronTrainer( + error_threshold, max_iterations +) with PassiveAggressivePerceptronTrainer { + val _variant = variant; val _aggressiveness_param = aggressiveness_param + + /** + * Actually train a passive-aggressive single-weight multi-class + * perceptron. Note that, although we're passed in a single correct label + * per instance, the code below is written so that it can handle a set of + * correct labels; you'd just have to change `yes_labels` and `no_labels` + * and pass the appropriate set of correct labels in. + */ + def apply(data: Iterable[(FeatureVector, Int)], num_classes: Int) = { + val weights = initialize(data, num_classes) + val len = weights.length + var iter = 0 + val all_labs = 0 until num_classes + breakable { + while (iter < max_iterations) { + var total_error = 0.0 + for ((inst, label) <- data) { + def dotprod(x: Int) = inst.dot_product(weights, x) + val goldscore = dotprod(label) + val predlab = + if (prediction_based) + Maxutil.argmax[Int](all_labs, dotprod(_)) + else + Maxutil.argmax[Int](all_labs, + x=>(dotprod(x) - goldscore + math.sqrt(cost(label, x)))) + val loss = dotprod(predlab) - goldscore + + math.sqrt(cost(label, predlab)) + val sqmagdiff = inst.diff_squared_magnitude(label, predlab) + val scale = compute_update_factor(loss, sqmagdiff) + inst.update_weights(weights, scale, label) + inst.update_weights(weights, -scale, predlab) + total_error += math.abs(scale) + } + if (total_error < error_threshold) + break + iter += 1 + } + } + new SingleWeightMultiClassLinearClassifier(weights, num_classes) + } +} + +/** + * Class for training a cost-sensitive multi-class perceptron with a separate + * set of weights per class. 
+ */
+abstract class CostSensitiveMultiClassPerceptronTrainer(
+  error_threshold: Double = 1e-10,
+  max_iterations: Int = 1000
+) extends MultiClassPerceptronTrainer {
+  assert(error_threshold >= 0)
+  assert(max_iterations > 0)
+
+  def cost(correct: Int, predicted: Int): Double
+}
+
+/**
+ * Class for training a passive-aggressive cost-sensitive multi-class
+ * perceptron with a separate set of weights per class.
+ */
+abstract class PassiveAggressiveCostSensitiveMultiClassPerceptronTrainer(
+  prediction_based: Boolean,
+  variant: Int,
+  aggressiveness_param: Double = 20.0,
+  error_threshold: Double = 1e-10,
+  max_iterations: Int = 1000
+) extends CostSensitiveMultiClassPerceptronTrainer(
+  error_threshold, max_iterations
+) with PassiveAggressivePerceptronTrainer {
+  val _variant = variant; val _aggressiveness_param = aggressiveness_param
+
+  /**
+   * Actually train a passive-aggressive multi-weight multi-class
+   * perceptron. Note that, although we're passed in a single correct label
+   * per instance, the code below is written so that it can handle a set of
+   * correct labels; you'd just have to change `yes_labels` and `no_labels`
+   * and pass the appropriate set of correct labels in.
+   */
+  def apply(data: Iterable[(FeatureVector, Int)], num_classes: Int) = {
+    val weights = initialize(data, num_classes)
+    val len = weights(0).length
+    var iter = 0
+    val all_labs = 0 until num_classes
+    breakable {
+      while (iter < max_iterations) {
+        var total_error = 0.0
+        for ((inst, label) <- data) {
+          def dotprod(x: Int) = inst.dot_product(weights(x), x)
+          val goldscore = dotprod(label)
+          val predlab =
+            if (prediction_based)
+              Maxutil.argmax[Int](all_labs, dotprod(_))
+            else
+              Maxutil.argmax[Int](all_labs,
+                x=>(dotprod(x) - goldscore + math.sqrt(cost(label, x))))
+          val loss = dotprod(predlab) - goldscore +
+            math.sqrt(cost(label, predlab))
+          val rmag = inst.squared_magnitude(label)
+          val smag = inst.squared_magnitude(predlab)
+          val sqmagdiff = rmag + smag
+          val scale = compute_update_factor(loss, sqmagdiff)
+          inst.update_weights(weights(label), scale, label)
+          inst.update_weights(weights(predlab), -scale, predlab)
+          total_error += math.abs(scale)
+        }
+        if (total_error < error_threshold)
+          break
+        iter += 1
+      }
+    }
+    new MultiClassLinearClassifier(weights)
+  }
+}
diff --git a/src/main/scala/opennlp/fieldspring/perceptron/package.scala b/src/main/scala/opennlp/fieldspring/perceptron/package.scala
new file mode 100644
index 0000000..ca07bfc
--- /dev/null
+++ b/src/main/scala/opennlp/fieldspring/perceptron/package.scala
@@ -0,0 +1,24 @@
+///////////////////////////////////////////////////////////////////////////////
+// package.scala
+//
+// Copyright (C) 2012 Ben Wing, The University of Texas at Austin
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+/////////////////////////////////////////////////////////////////////////////// + +package opennlp.fieldspring + +package object perceptron { + type WeightVector = Array[Double] +} + diff --git a/src/main/scala/opennlp/fieldspring/poligrounder/Poligrounder.scala b/src/main/scala/opennlp/fieldspring/poligrounder/Poligrounder.scala new file mode 100644 index 0000000..e2a9f18 --- /dev/null +++ b/src/main/scala/opennlp/fieldspring/poligrounder/Poligrounder.scala @@ -0,0 +1,246 @@ +/////////////////////////////////////////////////////////////////////////////// +// Poligrounder.scala +// +// Copyright (C) 2010, 2011, 2012 Ben Wing, The University of Texas at Austin +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +/////////////////////////////////////////////////////////////////////////////// + +/* + +Basic idea: + +1. We specify a corpus and times to compare, e.g. + +poligrounder -i twitter-spritzer --from 201203051627/-1h --to 201203051801/3h + +will operate on the twitter-spritzer corpus and compare the hour +directly preceding March 5, 2012, 4:27pm with the three hours directly +following March 5, 2012, 6:01pm. + +Time can be specified either as simple absolute times (e.g. 201203051627) +or as a combination of a time and an offset, e.g. 201203051800-10h3m5s means +10 hours 3 minutes 5 seconds prior to 201203051800 (March 5, 2012, 6:00pm). +Absolute times are specified as YYYYMMDD[hh[mm[ss]]], i.e. a specific day +must be given, with optional hours, minutes or seconds, defaulting to the +earliest possible time when a portion is left out. Offsets and lengths +are specified using one or more combinations of number (possibly floating +point) and designator: + +s = second +m or mi = minute +h = hour +d = day +mo = month +y = year + +2. There may be different comparison methods, triggered by different command +line arguments. + +3. Currently we have code in `gridlocate` that reads documents in from a +corpus and amalgamates them using a grid of some sort. We can reuse this +to amalgate documents by time. E.g. if we want to compare two specific time +periods, we will have two corresponding cells, one for each period, and +throw away the remaining documents. In other cases where we might want to +look at distributions over a period of time, we will have more cells, at +(possibly more or less) regular intervals. 
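+
+For the example command line above, that means just two cells: a "before"
+cell covering 15:27-16:27 and an "after" cell covering 18:01-21:01 on
+March 5, 2012, with documents falling outside those intervals thrown away.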
+ +*/ +package opennlp.fieldspring.poligrounder + +import util.matching.Regex +import util.Random +import math._ +import collection.mutable + +import opennlp.fieldspring.{util=>tgutil} +import tgutil.argparser._ +import tgutil.collectionutil._ +import tgutil.textdbutil._ +import tgutil.distances._ +import tgutil.experiment._ +import tgutil.ioutil.{FileHandler, LocalFileHandler} +import tgutil.osutil.output_resource_usage +import tgutil.printutil.errprint +import tgutil.timeutil._ + +import opennlp.fieldspring.gridlocate._ +import GridLocateDriver.Debug._ + +import opennlp.fieldspring.worddist.{WordDist,WordDistFactory} +import opennlp.fieldspring.worddist.WordDist.memoizer._ + +/* + +This module is the main driver module for the Poligrounder subproject. +See GridLocate.scala. + +*/ + +///////////////////////////////////////////////////////////////////////////// +// Main code // +///////////////////////////////////////////////////////////////////////////// + +/** + * Class retrieving command-line arguments or storing programmatic + * configuration parameters. + * + * @param parser If specified, should be a parser for retrieving the + * value of command-line arguments from the command line. Provided + * that the parser has been created and initialized by creating a + * previous instance of this same class with the same parser (a + * "shadow field" class), the variables below will be initialized with + * the values given by the user on the command line. Otherwise, they + * will be initialized with the default values for the parameters. + * Because they are vars, they can be freely set to other values. + * + */ +class PoligrounderParameters(parser: ArgParser = null) extends + GridLocateParameters(parser) { + var from = ap.option[String]("f", "from", + help = """Chunk of start time to compare.""") + + var to = ap.option[String]("t", "to", + help = """Chunk of end time to compare.""") + + var min_prob = ap.option[Double]("min-prob", "mp", default = 0.0, + help = """Mininum probability when comparing distributions. + Default is 0.0, which means no restrictions.""") + + var max_items = ap.option[Int]("max-items", "mi", default = 200, + help = """Maximum number of items (words or n-grams) to output when + comparing distributions. Default is %default. This applies separately + to those items that have increased and decreased, meaning the total + number counting both kinds may be as much as twice the maximum.""") + + var ideological_user_corpus = ap.option[String]( + "ideological-user-corpus", "iuc", + help="""File containing corpus output from FindPolitical, listing + users and associated ideologies.""") + + var ideological_users: Map[String, Double] = _ + var ideological_users_liberal: Map[String, Double] = _ + var ideological_users_conservative: Map[String, Double] = _ + var ideological_categories: Seq[String] = _ + + // Unused, determined by --ideological-user-corpus. +// var mode = ap.option[String]("m", "mode", +// default = "combined", +// choices = Seq("combined", "ideo-users"), +// help = """How to compare distributions. Possible values are +// +// 'combined': For a given time period, combine all users into a single +// distribution. +// +// 'ideo-users': Retrieve the ideology of the users and use that to +// separate the users into liberal and conservative, and compare those +// separately.""") +} + +/** + * A simple field-text file processor that just records the users and ideology. 
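+ * Each row of the textdb is expected to supply a "user" field and a
+ * numeric "ideology" field; `handle_row` below maps such a row to
+ * `Some((user, ideology))`.
+ *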
+ * + * @param suffix Suffix used to select document metadata files in a directory + */ +class IdeoUserFileProcessor extends + TextDBProcessor[(String, Double)]("ideo-users") { + def handle_row(fieldvals: Seq[String]) = { + val user = schema.get_field(fieldvals, "user") + val ideology = + schema.get_field(fieldvals, "ideology").toDouble + Some((user, ideology)) + } +} + +class PoligrounderDriver extends + GridLocateDriver with StandaloneExperimentDriverStats { + type TParam = PoligrounderParameters + type TRunRes = Unit + type TDoc = TimeDocument + type TCell = TimeCell + type TGrid = TimeCellGrid + type TDocTable = TimeDocumentTable + + var degrees_per_cell = 0.0 + var from_chunk: (Long, Long) = _ + var to_chunk: (Long, Long) = _ + + override def handle_parameters() { + def parse_interval(param: String) = { + parse_date_interval(param) match { + case (Some((start, end)), "") => (start, end) + case (None, errmess) => param_error(errmess) + } + } + from_chunk = parse_interval(params.from) + to_chunk = parse_interval(params.to) + + if (params.ideological_user_corpus != null) { + val processor = new IdeoUserFileProcessor + val users = + processor.read_textdb(new LocalFileHandler, + params.ideological_user_corpus).flatten.toMap + params.ideological_users = users + params.ideological_users_liberal = + users filter { case (u, ideo) => ideo < 0.33 } + params.ideological_users_conservative = + users filter { case (u, ideo) => ideo > 0.66 } + params.ideological_categories = Seq("liberal", "conservative") + } else + params.ideological_categories = Seq("all") + + super.handle_parameters() + } + + override protected def initialize_word_dist_suffix() = { + super.initialize_word_dist_suffix() + "-tweets" + } + + protected def initialize_document_table(word_dist_factory: WordDistFactory) = { + new TimeDocumentTable(this, word_dist_factory) + } + + protected def initialize_cell_grid(table: TimeDocumentTable) = { + if (params.ideological_user_corpus == null) + new TimeCellGrid(from_chunk, to_chunk, Seq("all"), x => "all", table) + else + new TimeCellGrid(from_chunk, to_chunk, Seq("liberal", "conservative"), + x => { + if (params.ideological_users_liberal contains x.user) + "liberal" + else if (params.ideological_users_conservative contains x.user) + "conservative" + else + null + }, table) + } + + def run_after_setup() { + if (params.ideological_user_corpus == null) + DistributionComparer.compare_cells_2way( + cell_grid.asInstanceOf[TimeCellGrid], "all", + params.min_prob, params.max_items) + else + DistributionComparer.compare_cells_4way( + cell_grid.asInstanceOf[TimeCellGrid], "liberal", "conservative", + params.min_prob, params.max_items) + } +} + +object PoligrounderApp extends GridLocateApp("poligrounder") { + type TDriver = PoligrounderDriver + // FUCKING TYPE ERASURE + def create_param_object(ap: ArgParser) = new TParam(ap) + def create_driver() = new TDriver() +} + diff --git a/src/main/scala/opennlp/fieldspring/poligrounder/TimeCell.scala b/src/main/scala/opennlp/fieldspring/poligrounder/TimeCell.scala new file mode 100644 index 0000000..d5ca0fb --- /dev/null +++ b/src/main/scala/opennlp/fieldspring/poligrounder/TimeCell.scala @@ -0,0 +1,358 @@ +/////////////////////////////////////////////////////////////////////////////// +// TimeCell.scala +// +// Copyright (C) 2011 Ben Wing, The University of Texas at Austin +// Copyright (C) 2012 Stephen Roller, The University of Texas at Austin +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in 
compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+///////////////////////////////////////////////////////////////////////////////
+
+package opennlp.fieldspring.poligrounder
+
+import math._
+
+import collection.mutable
+
+import opennlp.fieldspring.util.collectionutil._
+import opennlp.fieldspring.util.textutil.format_float
+import opennlp.fieldspring.util.distances._
+import opennlp.fieldspring.util.printutil.{errout, errprint}
+import opennlp.fieldspring.util.experiment._
+
+import opennlp.fieldspring.gridlocate._
+import opennlp.fieldspring.gridlocate.GridLocateDriver.Debug._
+import opennlp.fieldspring.worddist._
+import opennlp.fieldspring.worddist.WordDist.memoizer._
+
+/////////////////////////////////////////////////////////////////////////////
+//                             Cells in a grid                              //
+/////////////////////////////////////////////////////////////////////////////
+
+class TimeCell(
+  from: Long,
+  to: Long,
+  cell_grid: TimeCellGrid
+) extends GeoCell[TimeCoord, TimeDocument](cell_grid) {
+  /**
+   * Return the boundary of the cell as a pair of coordinates, specifying the
+   * beginning and end.
+   */
+  def get_boundary() = (from, to)
+
+  def contains(time: TimeCoord) = from <= time.millis && time.millis < to
+
+  def get_center_coord = TimeCoord((to + from)/2)
+
+  def describe_location =
+    "%s - %s" format (format_time(from), format_time(to))
+
+  def describe_indices =
+    "%s/%s" format (format_time(from), format_interval(to - from))
+}
+
+/**
+ * Pair of intervals for comparing before and after for a particular class of
+ * tweets (e.g. those from liberal or conservative users).
+ */
+class TimeCellPair(
+  val category: String,
+  val before_chunk: (Long, Long),
+  val after_chunk: (Long, Long),
+  val grid: TimeCellGrid
+) {
+  val before_cell = new TimeCell(before_chunk._1, before_chunk._2, grid)
+  val after_cell = new TimeCell(after_chunk._1, after_chunk._2, grid)
+
+  def find_best_cell_for_coord(coord: TimeCoord) = {
+    if (before_cell contains coord) {
+      if (debug("cell"))
+        errprint("Putting document with category %s, coord %s in before-chunk %s",
+          category, coord, before_chunk)
+      before_cell
+    } else if (after_cell contains coord) {
+      if (debug("cell"))
+        errprint("Putting document with category %s, coord %s in after-chunk %s",
+          category, coord, after_chunk)
+      after_cell
+    } else {
+      if (debug("cell"))
+        errprint("Skipping document with category %s, coord %s because not in either before-chunk %s or after-chunk %s",
+          category, coord, before_chunk, after_chunk)
+      null
+    }
+  }
+}
+
+/**
+ * Class for a "grid" of time intervals.
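+ *
+ * The "grid" simply holds one `TimeCellPair` (a before-cell and an
+ * after-cell) per category. A purely illustrative sketch of constructing a
+ * single-category grid (the timestamps are arbitrary millisecond values and
+ * `table` is assumed to be an existing `TimeDocumentTable`):
+ * {{{
+ * val before = (1330905600000L, 1330909200000L)
+ * val after  = (1330916400000L, 1330927200000L)
+ * val grid = new TimeCellGrid(before, after, Seq("all"), x => "all", table)
+ * }}}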
+ */ +class TimeCellGrid( + before_chunk: (Long, Long), + after_chunk: (Long, Long), + categories: Seq[String], + category_of_doc: TimeDocument => String, + override val table: TimeDocumentTable +) extends CellGrid[TimeCoord, TimeDocument, TimeCell](table) { + + val pairs = categories.map { + x => (x, new TimeCellPair(x, before_chunk, after_chunk, this)) + }.toMap + var total_num_cells = 2 * categories.length + + def find_best_cell_for_document(doc: TimeDocument, + create_non_recorded: Boolean) = { + assert(!create_non_recorded) + val category = category_of_doc(doc) + if (category != null) + pairs(category).find_best_cell_for_coord(doc.coord) + else { + if (debug("cell")) + errprint("Skipping document %s because not in any category", doc) + null + } + } + + def add_document_to_cell(doc: TimeDocument) { + val cell = find_best_cell_for_document(doc, false) + if (cell != null) + cell.add_document(doc) + } + + def iter_nonempty_cells(nonempty_word_dist: Boolean = false) = { + for { + category <- categories + pair = pairs(category) + v <- List(pair.before_cell, pair.after_cell) + val empty = ( + if (nonempty_word_dist) v.combined_dist.is_empty_for_word_dist() + else v.combined_dist.is_empty()) + if (!empty) + } yield v + } + + def initialize_cells() { + for (category <- categories) { + val pair = pairs(category) + pair.before_cell.finish() + pair.after_cell.finish() + /* FIXME!!! + 1. Should this be is_empty or is_empty_for_word_dist? Do we even need + this distinction? + 2. Computation of num_non_empty_cells should happen automatically! + */ + if (!pair.before_cell.combined_dist.is_empty_for_word_dist) + num_non_empty_cells += 1 + if (!pair.after_cell.combined_dist.is_empty_for_word_dist) + num_non_empty_cells += 1 + for ((cell, name) <- + Seq((pair.before_cell, "before"), (pair.after_cell, "after"))) { + val comdist = cell.combined_dist + errprint("Number of documents in %s-chunk: %s", name, + comdist.num_docs_for_word_dist) + errprint("Number of types in %s-chunk: %s", name, + comdist.word_dist.model.num_types) + errprint("Number of tokens in %s-chunk: %s", name, + comdist.word_dist.model.num_tokens) + } + } + } +} + +abstract class DistributionComparer(min_prob: Double, max_items: Int) { + type Item + type Dist <: WordDist + + def get_pair(grid: TimeCellGrid, category: String) = + grid.pairs(category) + def get_dist(cell: TimeCell): Dist = + cell.combined_dist.word_dist.asInstanceOf[Dist] + def get_keys(dist: Dist) = dist.model.iter_keys.toSet + def lookup_item(dist: Dist, item: Item): Double + def format_item(item: Item): String + + def compare_cells_2way(grid: TimeCellGrid, category: String) { + val before_dist = get_dist(get_pair(grid, category).before_cell) + val after_dist = get_dist(get_pair(grid, category).after_cell) + + val itemdiff = + for { + rawitem <- get_keys(before_dist) ++ get_keys(after_dist) + item = rawitem.asInstanceOf[Item] + p = lookup_item(before_dist, item) + q = lookup_item(after_dist, item) + if p >= min_prob || q >= min_prob + } yield (item, before_dist.dunning_log_likelihood_2x1(item.asInstanceOf[before_dist.Item], after_dist), q - p) + + println("Items by 2-way log-likelihood for category '%s':" format category) + for ((item, dunning, prob) <- + itemdiff.toSeq.sortWith(_._2 > _._2).take(max_items)) { + println("%7s: %-20s (%8s, %8s = %8s - %8s)" format + (format_float(dunning), + format_item(item), + if (prob > 0) "increase" else "decrease", format_float(prob), + format_float(lookup_item(before_dist, item)), + format_float(lookup_item(after_dist, item)) + )) + } 
+ println("") + + val diff_up = itemdiff filter (_._3 > 0) + val diff_down = itemdiff filter (_._3 < 0) map (x => (x._1, x._2, x._3.abs)) + def print_diffs(diffs: Iterable[(Item, Double, Double)], + incdec: String, updown: String) { + println("") + println("Items that %s in probability:" format incdec) + println("------------------------------------") + for ((item, dunning, prob) <- + diffs.toSeq.sortWith(_._3 > _._3).take(max_items)) { + println("%s: %s - %s = %s%s (LL %s)" format + (format_item(item), + format_float(lookup_item(before_dist, item)), + format_float(lookup_item(after_dist, item)), + updown, format_float(prob), + format_float(dunning))) + } + println("") + } + print_diffs(diff_up, "increased", "+") + print_diffs(diff_down, "decreased", "-") + + } + + def compare_cells_4way(grid: TimeCellGrid, category1: String, + category2: String) { + val before_dist_1 = get_dist(get_pair(grid, category1).before_cell) + val after_dist_1 = get_dist(get_pair(grid, category1).after_cell) + val before_dist_2 = get_dist(get_pair(grid, category2).before_cell) + val after_dist_2 = get_dist(get_pair(grid, category2).after_cell) + + val cat13 = category1.slice(0,3) + val cat23 = category2.slice(0,3) + val cat18 = category1.slice(0,8) + val cat28 = category2.slice(0,8) + + val itemdiff = + for { + rawitem <- get_keys(before_dist_1) ++ get_keys(after_dist_1) ++ + get_keys(before_dist_2) ++ get_keys(after_dist_2) + item = rawitem.asInstanceOf[Item] + p1 = lookup_item(before_dist_1, item) + q1 = lookup_item(after_dist_1, item) + p2 = lookup_item(before_dist_2, item) + q2 = lookup_item(after_dist_2, item) + if p1 >= min_prob || q1 >= min_prob || p2 >= min_prob || q2 >= min_prob + abs1 = q1 - p1 + abs2 = q2 - p2 + pct1 = (q1 - p1)/p1*100 + pct2 = (q2 - p2)/p2*100 + change = { + if (pct1 > 0 && pct2 <= 0) "+"+cat13 + else if (pct1 <= 0 && pct2 > 0) "+"+cat23 + else if (pct1 < 0 && pct2 < 0) "-both" + else "+both" + } + } yield (item, before_dist_1.dunning_log_likelihood_2x2( + item.asInstanceOf[before_dist_1.Item], + after_dist_1, before_dist_2, after_dist_2), + p1, q1, p2, q2, abs1, abs2, pct1, pct2, change + ) + + println("%24s change %7s%% (+-%7.7s) / %7s%% (+-%7.7s)" format ( + "Items by 4-way log-lhood:", cat13, cat18, cat23, cat28)) + def fmt(x: Double) = format_float(x, include_plus = true) + for ((item, dunning, p1, q1, p2, q2, abs1, abs2, pct1, pct2, change) <- + itemdiff.toSeq.sortWith(_._2 > _._2).take(max_items)) { + println("%7s: %-15.15s %6s: %7s%% (%9s) / %7s%% (%9s)" format + (format_float(dunning), + format_item(item), + change, + fmt(pct1), fmt(abs1), + fmt(pct2), fmt(abs2) + )) + } + println("") + + type ItemDunProb = + (Item, Double, Double, Double, Double, Double, Double, + Double, Double, Double, String) + val diff_cat1 = itemdiff filter (_._11 == "+"+cat13) + val diff_cat2 = itemdiff filter (_._11 == "+"+cat23) + def print_diffs(diffs: Iterable[ItemDunProb], + category: String, updown: String) { + def print_diffs_1(msg: String, + comparefun: (ItemDunProb, ItemDunProb) => Boolean) { + println("") + println("%s leaning towards %8s:" format (msg, category)) + println("----------------------------------------------------------") + for ((item, dunning, p1, q1, p2, q2, abs1, abs2, pct1, pct2, change) <- + diffs.toSeq.sortWith(comparefun).take(max_items)) { + println("%-15.15s = LL %7s (%%chg-diff %7s%% = %7s%% - %7s%%)" format + (format_item(item), format_float(dunning), + fmt(pct1 - pct2), fmt(pct1), fmt(pct2))) + } + } + print_diffs_1("Items by 4-way log-lhood with difference", _._2 > 
_._2) + // print_diffs_1("Items with greatest difference", _._3 > _._3) + } + print_diffs(diff_cat1, category1, "+") + print_diffs(diff_cat2, category2, "-") + println("") + compare_cells_2way(grid, category1) + println("") + compare_cells_2way(grid, category2) + } +} + +class UnigramComparer(min_prob: Double, max_items: Int) extends + DistributionComparer(min_prob, max_items) { + type Item = Word + type Dist = UnigramWordDist + + def lookup_item(dist: Dist, item: Item) = dist.lookup_word(item) + def format_item(item: Item) = unmemoize_string(item) +} + +class NgramComparer(min_prob: Double, max_items: Int) extends + DistributionComparer(min_prob, max_items) { + import NgramStorage.Ngram + type Item = Ngram + type Dist = NgramWordDist + + def lookup_item(dist: Dist, item: Item) = dist.lookup_ngram(item) + def format_item(item: Item) = item mkString " " +} + +object DistributionComparer { + def get_comparer(grid: TimeCellGrid, category: String, min_prob: Double, + max_items: Int) = + grid.pairs(category).before_cell.combined_dist.word_dist match { + case _: UnigramWordDist => + new UnigramComparer(min_prob, max_items) + case _: NgramWordDist => + new NgramComparer(min_prob, max_items) + case _ => throw new IllegalArgumentException("Don't know how to compare this type of word distribution") + } + + def compare_cells_2way(grid: TimeCellGrid, category: String, min_prob: Double, + max_items: Int) { + val comparer = get_comparer(grid, category, min_prob, max_items) + comparer.compare_cells_2way(grid, category) + } + + def compare_cells_4way(grid: TimeCellGrid, category1: String, + category2: String, min_prob: Double, max_items: Int) { + val comparer = get_comparer(grid, category1, min_prob, max_items) + comparer.compare_cells_4way(grid, category1, category2) + } +} + diff --git a/src/main/scala/opennlp/fieldspring/poligrounder/TimeDocument.scala b/src/main/scala/opennlp/fieldspring/poligrounder/TimeDocument.scala new file mode 100644 index 0000000..4142a2c --- /dev/null +++ b/src/main/scala/opennlp/fieldspring/poligrounder/TimeDocument.scala @@ -0,0 +1,82 @@ +/////////////////////////////////////////////////////////////////////////////// +// TimeDocument.scala +// +// Copyright (C) 2011, 2012 Ben Wing, The University of Texas at Austin +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+/////////////////////////////////////////////////////////////////////////////// + +package opennlp.fieldspring.poligrounder + +import collection.mutable + +import opennlp.fieldspring.util.distances._ +import opennlp.fieldspring.util.textdbutil.Schema +import opennlp.fieldspring.util.printutil._ + +import opennlp.fieldspring.gridlocate.{DistDocument,DistDocumentTable,CellGrid} +import opennlp.fieldspring.gridlocate.DistDocumentConverters._ + +import opennlp.fieldspring.worddist.WordDistFactory + +class TimeDocument( + schema: Schema, + table: TimeDocumentTable +) extends DistDocument[TimeCoord](schema, table) { + var coord: TimeCoord = _ + var user: String = _ + def has_coord = coord != null + def title = if (coord != null) coord.toString else "unknown time" + + def struct = + + { + if (has_coord) + { coord } + } + + + override def set_field(name: String, value: String) { + name match { + case "min-timestamp" => coord = get_x_or_null[TimeCoord](value) + case "user" => user = value + case _ => super.set_field(name, value) + } + } + + def coord_as_double(coor: TimeCoord) = coor match { + case null => Double.NaN + case TimeCoord(x) => x.toDouble / 1000 + } + + def distance_to_coord(coord2: TimeCoord) = { + (coord_as_double(coord2) - coord_as_double(coord)).abs + } + def output_distance(dist: Double) = "%s seconds" format dist +} + +/** + * A DistDocumentTable specifically for documents with coordinates described + * by a TimeCoord. + * We delegate the actual document creation to a subtable specific to the + * type of corpus (e.g. Wikipedia or Twitter). + */ +class TimeDocumentTable( + override val driver: PoligrounderDriver, + word_dist_factory: WordDistFactory +) extends DistDocumentTable[TimeCoord, TimeDocument, TimeCellGrid]( + driver, word_dist_factory +) { + def create_document(schema: Schema) = new TimeDocument(schema, this) +} + diff --git a/src/main/scala/opennlp/fieldspring/postprocess/DocumentPinKMLGenerator.scala b/src/main/scala/opennlp/fieldspring/postprocess/DocumentPinKMLGenerator.scala new file mode 100644 index 0000000..ecffbe8 --- /dev/null +++ b/src/main/scala/opennlp/fieldspring/postprocess/DocumentPinKMLGenerator.scala @@ -0,0 +1,80 @@ +/////////////////////////////////////////////////////////////////////////////// +// DocumentPinKMLGenerator.scala +// +// Copyright (C) 2012 Mike Speriosu, The University of Texas at Austin +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+/////////////////////////////////////////////////////////////////////////////// + +package opennlp.fieldspring.postprocess + +import java.io._ +import javax.xml.datatype._ +import javax.xml.stream._ +import opennlp.fieldspring.tr.topo._ +import opennlp.fieldspring.tr.util.KMLUtil +import opennlp.fieldspring.tr.util.LogUtil +import scala.collection.JavaConversions._ +import org.clapper.argot._ + +object DocumentPinKMLGenerator { + + val factory = XMLOutputFactory.newInstance + val rand = new scala.util.Random + + import ArgotConverters._ + + val parser = new ArgotParser("fieldspring run opennlp.fieldspring.postprocess.DocumentPinKMLGenerator", preUsage = Some("Fieldspring")) + val inFile = parser.option[String](List("i", "input"), "input", "input file") + val kmlOutFile = parser.option[String](List("k", "kml"), "kml", "kml output file") + val tokenIndexOffset = parser.option[Int](List("o", "offset"), "offset", "token index offset") + + def main(args: Array[String]) { + try { + parser.parse(args) + } + catch { + case e: ArgotUsageException => println(e.message); sys.exit(0) + } + + if(inFile.value == None) { + println("You must specify an input file via -i.") + sys.exit(0) + } + if(kmlOutFile.value == None) { + println("You must specify a KML output file via -k.") + sys.exit(0) + } + val offset = if(tokenIndexOffset.value != None) tokenIndexOffset.value.get else 0 + + val outFile = new File(kmlOutFile.value.get) + val stream = new BufferedOutputStream(new FileOutputStream(outFile)) + val out = factory.createXMLStreamWriter(stream, "UTF-8") + + KMLUtil.writeHeader(out, inFile.value.get) + + for(line <- scala.io.Source.fromFile(inFile.value.get).getLines) { + val tokens = line.split("\t") + if(tokens.length >= 3+offset) { + val docName = tokens(1+offset) + val coordTextPair = tokens(2+offset).split(",") + val coord = Coordinate.fromDegrees(coordTextPair(0).toDouble, coordTextPair(1).toDouble) + KMLUtil.writePinPlacemark(out, docName, coord) + } + } + + KMLUtil.writeFooter(out) + + out.close + } +} diff --git a/src/main/scala/opennlp/fieldspring/postprocess/DocumentRankerByError.scala b/src/main/scala/opennlp/fieldspring/postprocess/DocumentRankerByError.scala new file mode 100644 index 0000000..9464fff --- /dev/null +++ b/src/main/scala/opennlp/fieldspring/postprocess/DocumentRankerByError.scala @@ -0,0 +1,58 @@ +/////////////////////////////////////////////////////////////////////////////// +// DocumentRankerByError.scala +// +// Copyright (C) 2012 Mike Speriosu, The University of Texas at Austin +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +/////////////////////////////////////////////////////////////////////////////// + +package opennlp.fieldspring.postprocess + +// This program takes a log file and outputs the document names to standard out, ranked by prediction error. 
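+//
+// Illustrative invocation (the log-file name is hypothetical):
+//
+//   fieldspring run opennlp.fieldspring.postprocess.DocumentRankerByError -l run.log
+//
+// Each output line is tab-separated, sorted by increasing error:
+//   docName <TAB> errorInKm <TAB> trueCoord <TAB> predCoord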
+ +import org.clapper.argot._ +import opennlp.fieldspring.tr.topo._ +import opennlp.fieldspring.tr.util.LogUtil + +object DocumentRankerByError { + + import ArgotConverters._ + + val parser = new ArgotParser("fieldspring run opennlp.fieldspring.postprocess.DocumentRankerByError", preUsage = Some("Fieldspring")) + val logFile = parser.option[String](List("l", "log"), "log", "log input file") + + def main(args: Array[String]) { + try { + parser.parse(args) + } + catch { + case e: ArgotUsageException => println(e.message); sys.exit(0) + } + + if(logFile.value == None) { + println("You must specify a log input file via -l.") + sys.exit(0) + } + + val docsAndErrors:List[(String, Double, Coordinate, Coordinate)] = + (for(pe <- LogUtil.parseLogFile(logFile.value.get)) yield { + val dist = pe.trueCoord.distanceInKm(pe.predCoord) + + (pe.docName, dist, pe.trueCoord, pe.predCoord) + }).sortWith((x, y) => x._2 < y._2) + + for((docName, dist, trueCoord, predCoord) <- docsAndErrors) { + println(docName+"\t"+dist+"\t"+trueCoord+"\t"+predCoord) + } + } +} diff --git a/src/main/scala/opennlp/fieldspring/postprocess/ErrorKMLGenerator.scala b/src/main/scala/opennlp/fieldspring/postprocess/ErrorKMLGenerator.scala new file mode 100644 index 0000000..4713794 --- /dev/null +++ b/src/main/scala/opennlp/fieldspring/postprocess/ErrorKMLGenerator.scala @@ -0,0 +1,86 @@ +/////////////////////////////////////////////////////////////////////////////// +// ErrorKMLGenerator.scala +// +// Copyright (C) 2011, 2012 Mike Speriosu, The University of Texas at Austin +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+/////////////////////////////////////////////////////////////////////////////// + +package opennlp.fieldspring.postprocess + +import java.io._ +import javax.xml.datatype._ +import javax.xml.stream._ +import opennlp.fieldspring.tr.topo._ +import opennlp.fieldspring.tr.util.KMLUtil +import opennlp.fieldspring.tr.util.LogUtil +import scala.collection.JavaConversions._ +import org.clapper.argot._ + +object ErrorKMLGenerator { + + val factory = XMLOutputFactory.newInstance + + import ArgotConverters._ + + val parser = new ArgotParser("fieldspring run opennlp.fieldspring.postprocess.ErrorKMLGenerator", preUsage = Some("Fieldspring")) + val logFile = parser.option[String](List("l", "log"), "log", "log input file") + val kmlOutFile = parser.option[String](List("k", "kml"), "kml", "kml output file") + val usePred = parser.option[String](List("p", "pred"), "pred", "show predicted rather than gold locations") + + def main(args: Array[String]) { + try { + parser.parse(args) + } + catch { + case e: ArgotUsageException => println(e.message); sys.exit(0) + } + + if(logFile.value == None) { + println("You must specify a log input file via -l.") + sys.exit(0) + } + if(kmlOutFile.value == None) { + println("You must specify a KML output file via -k.") + sys.exit(0) + } + + val rand = new scala.util.Random + + val outFile = new File(kmlOutFile.value.get) + val stream = new BufferedOutputStream(new FileOutputStream(outFile)) + val out = factory.createXMLStreamWriter(stream, "UTF-8") + + KMLUtil.writeHeader(out, "errors-at-"+(if(usePred.value == None) "true" else "pred")) + + for(pe <- LogUtil.parseLogFile(logFile.value.get)) { + val predCoord = Coordinate.fromDegrees(pe.predCoord.getLatDegrees() + (rand.nextDouble() - 0.5) * .1, + pe.predCoord.getLngDegrees() + (rand.nextDouble() - 0.5) * .1); + + //val dist = trueCoord.distanceInKm(predCoord) + + val coord1 = if(usePred.value == None) pe.trueCoord else predCoord + val coord2 = if(usePred.value == None) predCoord else pe.trueCoord + + KMLUtil.writeArcLinePlacemark(out, coord1, coord2); + KMLUtil.writePinPlacemark(out, pe.docName, coord1, "yellow"); + //KMLUtil.writePlacemark(out, pe.docName, coord1, KMLUtil.RADIUS); + KMLUtil.writePinPlacemark(out, pe.docName, coord2, "blue"); + //KMLUtil.writePolygon(out, pe.docName, coord, KMLUtil.SIDES, KMLUtil.RADIUS, math.log(dist) * KMLUtil.BARSCALE/2) + } + + KMLUtil.writeFooter(out) + + out.close + } +} diff --git a/src/main/scala/opennlp/fieldspring/postprocess/KNNKMLGenerator.scala b/src/main/scala/opennlp/fieldspring/postprocess/KNNKMLGenerator.scala new file mode 100644 index 0000000..867c036 --- /dev/null +++ b/src/main/scala/opennlp/fieldspring/postprocess/KNNKMLGenerator.scala @@ -0,0 +1,96 @@ +/////////////////////////////////////////////////////////////////////////////// +// KNNKMLGenerator.scala +// +// Copyright (C) 2012 Mike Speriosu, The University of Texas at Austin +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+/////////////////////////////////////////////////////////////////////////////// + +package opennlp.fieldspring.postprocess + +import java.io._ +import javax.xml.datatype._ +import javax.xml.stream._ +import opennlp.fieldspring.tr.topo._ +import opennlp.fieldspring.tr.util.KMLUtil +import opennlp.fieldspring.tr.util.LogUtil +import scala.collection.JavaConversions._ +import org.clapper.argot._ + +object KNNKMLGenerator { + + val factory = XMLOutputFactory.newInstance + val rand = new scala.util.Random + + import ArgotConverters._ + + val parser = new ArgotParser("fieldspring run opennlp.fieldspring.postprocess.KNNKMLGenerator", preUsage = Some("Fieldspring")) + val logFile = parser.option[String](List("l", "log"), "log", "log input file") + val kmlOutFile = parser.option[String](List("k", "kml"), "kml", "kml output file") + + def main(args: Array[String]) { + try { + parser.parse(args) + } + catch { + case e: ArgotUsageException => println(e.message); sys.exit(0) + } + + if(logFile.value == None) { + println("You must specify a log input file via -l.") + sys.exit(0) + } + if(kmlOutFile.value == None) { + println("You must specify a KML output file via -k.") + sys.exit(0) + } + + val outFile = new File(kmlOutFile.value.get) + val stream = new BufferedOutputStream(new FileOutputStream(outFile)) + val out = factory.createXMLStreamWriter(stream, "UTF-8") + + KMLUtil.writeHeader(out, "knn") + + for(pe <- LogUtil.parseLogFile(logFile.value.get)) { + + val jPredCoord = jitter(pe.predCoord) + + KMLUtil.writePinPlacemark(out, pe.docName, pe.trueCoord) + KMLUtil.writePinPlacemark(out, pe.docName, jPredCoord, "blue") + KMLUtil.writePlacemark(out, "#1", jPredCoord, KMLUtil.RADIUS*10) + KMLUtil.writeLinePlacemark(out, pe.trueCoord, jPredCoord, "redLine") + + for((neighbor, rank) <- pe.neighbors) { + val jNeighbor = jitter(neighbor) + /*if(rank == 1) { + KMLUtil.writePlacemark(out, "#1", neighbor, KMLUtil.RADIUS*10) + }*/ + if(rank != 1) { + KMLUtil.writePlacemark(out, "#"+rank, jNeighbor, KMLUtil.RADIUS*10) + KMLUtil.writePinPlacemark(out, pe.docName, jNeighbor, "green") + /*if(!neighbor.equals(pe.predCoord))*/ KMLUtil.writeLinePlacemark(out, pe.trueCoord, jNeighbor) + } + } + + } + + KMLUtil.writeFooter(out) + + out.close + } + + def jitter(coord:Coordinate): Coordinate = { + Coordinate.fromDegrees(coord.getLatDegrees() + (rand.nextDouble() - 0.5) * .1, + coord.getLngDegrees() + (rand.nextDouble() - 0.5) * .1); + } +} diff --git a/src/main/scala/opennlp/fieldspring/postprocess/WordRankerByAvgError.scala b/src/main/scala/opennlp/fieldspring/postprocess/WordRankerByAvgError.scala new file mode 100644 index 0000000..0d50e3c --- /dev/null +++ b/src/main/scala/opennlp/fieldspring/postprocess/WordRankerByAvgError.scala @@ -0,0 +1,110 @@ +/////////////////////////////////////////////////////////////////////////////// +// WordRankerByAvgError.scala +// +// Copyright (C) 2012 Mike Speriosu, The University of Texas at Austin +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
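+//
+// This tool reads a corpus file (-i) of tab-separated tweets and a list
+// file (-l) of lines of the form "docName<TAB>error". For each word it
+// averages, over the documents containing that word, the document's error
+// scaled by the word's relative frequency within the document; words seen
+// in fewer than the threshold number of documents (-t, default 10) are
+// dropped, and results are printed sorted by increasing value.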
+/////////////////////////////////////////////////////////////////////////////// + +package opennlp.fieldspring.postprocess + +import opennlp.fieldspring.util.Twokenize +import org.clapper.argot._ +import java.io._ + +object WordRankerByAvgError { + + import ArgotConverters._ + + val parser = new ArgotParser("fieldspring run opennlp.fieldspring.postprocess.WordRankerByAvgError", preUsage = Some("Fieldspring")) + val corpusFile = parser.option[String](List("i", "input"), "list", "corpus input file") + val listFile = parser.option[String](List("l", "list"), "list", "list input file") + val docThresholdOption = parser.option[Int](List("t", "threshold"), "threshold", "document frequency threshold") + + def main(args: Array[String]) { + try { + parser.parse(args) + } + catch { + case e: ArgotUsageException => println(e.message); sys.exit(0) + } + + if(corpusFile.value == None) { + println("You must specify a corpus input file via -i.") + sys.exit(0) + } + if(listFile.value == None) { + println("You must specify a list input file via -l.") + sys.exit(0) + } + + val docThreshold = if(docThresholdOption.value == None) 10 else docThresholdOption.value.get + + val docNamesAndErrors:Map[String, Double] = scala.io.Source.fromFile(listFile.value.get).getLines. + map(_.split("\t")).map(p => (p(0), p(1).toDouble)).toMap + + val in = new BufferedReader( + new InputStreamReader(new FileInputStream(corpusFile.value.get), "UTF8")) + + val wordsToErrors = new scala.collection.mutable.HashMap[String, Double] + val wordsToDocNames = new scala.collection.mutable.HashMap[String, scala.collection.immutable.HashSet[String]] + val wordDocToCount = new scala.collection.mutable.HashMap[(String, String), Int] + val docSizes = new scala.collection.mutable.HashMap[String, Int] + //val docNames = new scala.collection.mutable.HashSet[String] + + var line:String = in.readLine + + while(line != null) { + val tokens = line.split("\t") + if(tokens.length >= 6) { + val docName = tokens(0) + if(docNamesAndErrors.contains(docName)) { + //docNames += docName + val error = docNamesAndErrors(docName) + val tweet = tokens(5) + + val words = Twokenize(tweet).map(_.toLowerCase)//.toSet + + for(word <- words) { + if(!word.startsWith("@user_")) { + val prevDocSize = docSizes.getOrElse(docName, 0) + docSizes.put(docName, prevDocSize + 1) + val prevCount = wordDocToCount.getOrElse((word, docName), 0) + wordDocToCount.put((word, docName), prevCount + 1) + val prevSet = wordsToDocNames.getOrElse(word, new scala.collection.immutable.HashSet()) + wordsToDocNames.put(word, prevSet + docName) + //val prevError = wordsToErrors.getOrElse(word, 0.0) + //wordsToErrors.put(word, prevError + error) + } + } + } + } + line = in.readLine + } + in.close + + for((word, docNames) <- wordsToDocNames) { + for(docName <- docNames) { + val count = wordDocToCount((word, docName)) + val prevError = wordsToErrors.getOrElse(word, 0.0) + wordsToErrors.put(word, prevError + count.toDouble / docSizes(docName) * docNamesAndErrors(docName)) + } + } + + wordsToErrors.foreach(p => if(wordsToDocNames(p._1).size < docThreshold) wordsToErrors.remove(p._1)) + + wordsToErrors.foreach(p => wordsToErrors.put(p._1, p._2 / wordsToDocNames(p._1).size)) + + wordsToErrors.toList.sortWith((x, y) => x._2 < y._2).foreach(p => println(p._1+"\t"+p._2)) + } +} diff --git a/src/main/scala/opennlp/fieldspring/postprocess/WordRankerByAvgErrorUT.scala b/src/main/scala/opennlp/fieldspring/postprocess/WordRankerByAvgErrorUT.scala new file mode 100644 index 0000000..201a3ed --- /dev/null +++ 
b/src/main/scala/opennlp/fieldspring/postprocess/WordRankerByAvgErrorUT.scala @@ -0,0 +1,95 @@ +/////////////////////////////////////////////////////////////////////////////// +// WordRankerByAvgErrorUT.scala +// +// Copyright (C) 2012 Mike Speriosu, The University of Texas at Austin +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +/////////////////////////////////////////////////////////////////////////////// + +package opennlp.fieldspring.postprocess + +import opennlp.fieldspring.util.Twokenize +import org.clapper.argot._ +import java.io._ + +object WordRankerByAvgErrorUT { + + import ArgotConverters._ + + val parser = new ArgotParser("fieldspring run opennlp.fieldspring.postprocess.WordRankerByAvgError", preUsage = Some("Fieldspring")) + val corpusFile = parser.option[String](List("i", "input"), "list", "corpus input file") + val listFile = parser.option[String](List("l", "list"), "list", "list input file") + val docThresholdOption = parser.option[Int](List("t", "threshold"), "threshold", "document frequency threshold") + + def main(args: Array[String]) { + try { + parser.parse(args) + } + catch { + case e: ArgotUsageException => println(e.message); sys.exit(0) + } + + if(corpusFile.value == None) { + println("You must specify a corpus input file via -i.") + sys.exit(0) + } + if(listFile.value == None) { + println("You must specify a list input file via -l.") + sys.exit(0) + } + + val docThreshold = if(docThresholdOption.value == None) 10 else docThresholdOption.value.get + + val docNamesAndErrors:Map[String, Double] = scala.io.Source.fromFile(listFile.value.get).getLines. 
+ map(_.split("\t")).map(p => (p(0), p(1).toDouble)).toMap + + val in = new BufferedReader( + new InputStreamReader(new FileInputStream(corpusFile.value.get), "UTF8")) + + val wordsToErrors = new scala.collection.mutable.HashMap[String, Double] + val wordsToDocNames = new scala.collection.mutable.HashMap[String, scala.collection.immutable.HashSet[String]] + + var line:String = in.readLine + + while(line != null) { + val tokens = line.split("\t") + if(tokens.length >= 3) { + val docName = tokens(0) + if(docNamesAndErrors.contains(docName)) { + val error = docNamesAndErrors(docName) + val text = tokens(2) + + val wordsAndCounts:Map[String, Int] = text.split(" ").map(p => (p.split(":")(0), p.split(":")(1).toInt)).toMap + val docSize = text.split(" ").map(_.split(":")(1).toInt).sum + + for((word, count) <- wordsAndCounts) { + //if(!word.startsWith("@user_")) { + val prevError = wordsToErrors.getOrElse(word, 0.0) + wordsToErrors.put(word, prevError + error * count.toDouble/docSize) + val prevSet = wordsToDocNames.getOrElse(word, new scala.collection.immutable.HashSet()) + wordsToDocNames.put(word, prevSet + docName) + //} + } + } + } + line = in.readLine + } + in.close + + wordsToErrors.foreach(p => if(wordsToDocNames(p._1).size < docThreshold) wordsToErrors.remove(p._1)) + + wordsToErrors.foreach(p => wordsToErrors.put(p._1, p._2 / wordsToDocNames(p._1).size)) + + wordsToErrors.toList.sortWith((x, y) => x._2 < y._2).foreach(p => println(p._1+"\t"+p._2)) + } +} diff --git a/src/main/scala/opennlp/fieldspring/preprocess/ConvertTwitterInfochimps.scala b/src/main/scala/opennlp/fieldspring/preprocess/ConvertTwitterInfochimps.scala new file mode 100644 index 0000000..7cb25be --- /dev/null +++ b/src/main/scala/opennlp/fieldspring/preprocess/ConvertTwitterInfochimps.scala @@ -0,0 +1,570 @@ +/////////////////////////////////////////////////////////////////////////////// +// ConvertTwitterInfochimps.scala +// +// Copyright (C) 2011 Ben Wing, The University of Texas at Austin +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +/////////////////////////////////////////////////////////////////////////////// + +package opennlp.fieldspring.preprocess + +import collection.mutable + +import java.io.PrintStream + +import org.apache.commons.lang3.StringEscapeUtils._ + +import opennlp.fieldspring.util.argparser._ +import opennlp.fieldspring.util.collectionutil._ +import opennlp.fieldspring.util.experiment._ +import opennlp.fieldspring.util.ioutil.{FileHandler, LocalFileHandler, LineProcessor} +import opennlp.fieldspring.util.MeteredTask +import opennlp.fieldspring.util.osutil.output_resource_usage +import opennlp.fieldspring.util.printutil._ +import opennlp.fieldspring.util.textutil.with_commas +import opennlp.fieldspring.util.Twokenize + +/* + +Steps for converting Infochimps to our format: + +1) Input is a series of files, e.g. part-00000.gz, each about 180 MB. 
+2) Each line looks like this: + + +100000018081132545 20110807002716 25430513 GTheHardWay Niggas Lost in the Sauce ..smh better slow yo roll and tell them hoes to get a job nigga #MRIloveRatsIcanchange&amp;saveherassNIGGA <a href="http://twitter.com/download/android" rel="nofollow">Twitter for Android</a> en 42.330165 -83.045913 +The fields are: + +1) Tweet ID +2) Time +3) User ID +4) User name +5) Empty? +6) User name being replied to (FIXME: which JSON field is this?) +7) User ID for replied-to user name (but sometimes different ID's for same user name) +8) Empty? +9) Tweet text -- double HTML-encoded (e.g. & becomes &amp;) +10) HTML anchor text indicating a link of some sort, HTML-encoded (FIXME: which JSON field is this?) +11) Language, as a two-letter code +12) Latitude +13) Longitude +14) Empty? +15) Empty? +16) Empty? +17) Empty? +18) Empty? + + +3) We want to convert each to two files: one containing the article-data, + one containing the text. We can later convert the text to unigram counts, + bigram counts, etc. + +*/ + +///////////////////////////////////////////////////////////////////////////// +// Main code // +///////////////////////////////////////////////////////////////////////////// + +class ConvertTwitterInfochimpsParameters(ap: ArgParser) extends + ProcessFilesParameters(ap) { + var output_stats = + ap.flag("output-stats", + help = """If true, output time-based and user-based statistics on the +tweets in the input, rather than converting the files. Other flags may be +given to request additional statistics.""") + var output_min_stats = + ap.flag("output-min-stats", + help = """If true, output time-based statistics on the +tweets in the input, rather than converting the files. Other flags may be +given to request additional statistics.""") + var output_all_stats = + ap.flag("output-all-stats", + help = """If true, output all statistics on the tweets in the input, +rather than converting the files.""") + var user_stats = + ap.flag("user-stats", + help = """If true, extra statistics involving number of tweets per user +are computed.""") + var user_to_userid_stats = + ap.flag("user-to-userid-stats", + help = """If true, extra statistics involving user-to-userid matches +are computed.""") + var reply_user_stats = + ap.flag("reply-user-stats", + help = """If true, extra statistics involving who replies to whom +are computed.""") + val files = + ap.multiPositional[String]("files", + help = """File(s) to process for input.""") +} + +abstract class TwitterInfochimpsFileProcessor extends LineProcessor[Unit] { + def process_lines(lines: Iterator[String], + filehand: FileHandler, file: String, + compression: String, realname: String) = { + val task = new MeteredTask("tweet", "parsing") + var lineno = 0 + for (line <- lines) { + lineno += 1 + line.split("\t", -1).toList match { + case id :: time :: userid :: username :: _ :: + reply_username :: reply_userid :: _ :: text :: anchor :: lang :: + lat :: long :: _ :: _ :: _ :: _ :: _ :: Nil => { + //val raw_anchor = unescapeXml(anchor) + // Go ahead and leave it encoded, just to make sure no TAB chars; + // we don't really use it much anyway. 
+ val raw_anchor = anchor + assert(!(raw_anchor contains '\t')) + val metadata = + Seq("corpus-name"->"twitter-infochimps", + "id"->id, "title"->id, "split"->"training", + "coord"->("%s,%s" format (lat, long)),"time"->time, + "username"->username, "userid"->userid, + "reply_username"->reply_username, "reply_userid"->reply_userid, + "anchor"->raw_anchor, "lang"->lang) + process_line(metadata, text) + } + case _ => { + errprint("Bad line #%d: %s" format (lineno, line)) + errprint("Line length: %d" format line.split("\t", -1).length) + } + } + task.item_processed() + } + task.finish() + print_msg_heading("Memory/time usage:", blank_lines_before = 3) + output_resource_usage(dojava = false) + (true, ()) + } + + // To be implemented + + def process_line(metadata: Seq[(String, String)], text: String) +} + +class ConvertTwitterInfochimpsFileProcessor( + params: ConvertTwitterInfochimpsParameters, + suffix: String +) extends TwitterInfochimpsFileProcessor { + var outstream: PrintStream = _ + val compression_type = "bzip2" + var schema: Seq[String] = null + var current_filehand: FileHandler = _ + + override def begin_process_lines(lines: Iterator[String], + filehand: FileHandler, file: String, + compression: String, realname: String) { + val (_, outname) = filehand.split_filename(realname) + val out_text_name = "%s/twitter-infochimps-%s-%s.txt" format ( + params.output_dir, outname, suffix) + errprint("Text document file is %s..." format out_text_name) + current_filehand = filehand + outstream = filehand.openw(out_text_name, compression = compression_type) + super.begin_process_lines(lines, filehand, file, compression, realname) + } + + def process_line(metadata: Seq[(String, String)], text: String) { + val rawtext = unescapeHtml4(unescapeXml(text)) + val splittext = Twokenize(rawtext) + val outdata = metadata ++ Seq("text"->(splittext mkString " ")) + val (schema, fieldvals) = outdata.unzip + if (this.schema == null) { + // Output the schema file, first time we see a line + val schema_file_name = + "%s/twitter-infochimps-%s-schema.txt" format (params.output_dir, suffix) + val schema_stream = current_filehand.openw(schema_file_name) + errprint("Schema file is %s..." format schema_file_name) + schema_stream.println(schema mkString "\t") + schema_stream.close() + this.schema = schema + } else + assert (this.schema == schema) + outstream.println(fieldvals mkString "\t") + } + + override def end_process_file(filehand: FileHandler, file: String) { + outstream.close() + super.end_process_file(filehand, file) + } +} + +class TwitterStatistics(params: ConvertTwitterInfochimpsParameters) { + // Over-all statistics + var num_tweets = 0 + + val max_string_val = "\uFFFF" + val min_string_val = "" + // Time-based statistics + val tweets_by_day = intmap[String]() + val tweets_by_hour = intmap[String]() + val tweets_by_minute = intmap[String]() + // For percentages between 0 and 100, what's the earliest time such that + // this many percent of tweets are before it? + val quantile_times = stringmap[Int]() + // In general, when finding the minimum, we want the default to be greater + // than any possible value, and when finding the maximum, we want the + // default to be less than any possible value. For string comparison, + // the empty string is less than any other string, and a string beginning + // with Unicode 0xFFFF is greater than almost any other string. (0xFFFF + // is not a valid Unicode character.) 
+  var earliest_tweet_time = max_string_val
+  var latest_tweet_time = min_string_val
+  var earliest_tweet_id = max_string_val
+  var latest_tweet_id = min_string_val
+
+  // If tweet ID's are sorted by time, the following two should be
+  // identical to the earliest_tweet_id/latest_tweet_id. Tweet ID's
+  // may have exceeded Long range; if not, they probably will soon.
+  var num_lowest_tweet_id = BigInt("9"*50)
+  var num_lowest_tweet_time = ""
+  var num_highest_tweet_id = BigInt(-1)
+  var num_highest_tweet_time = ""
+
+  // Similar but we sort tweet ID's lexicographically.
+  var lex_lowest_tweet_id = max_string_val
+  var lex_lowest_tweet_time = ""
+  var lex_highest_tweet_id = ""
+  var lex_highest_tweet_time = ""
+
+  // User-based statistics
+  val tweets_by_user = intmap[String]()
+  val num_users_by_num_tweets = intmap[Int]()
+  val userid_by_user = intmapmap[String, String]()
+  val userid_by_user_min_time = stringmapmap[String, String](max_string_val)
+  val userid_by_user_max_time = stringmapmap[String, String](min_string_val)
+  val user_by_userid = intmapmap[String, String]()
+  val user_by_userid_min_time = stringmapmap[String, String](max_string_val)
+  val user_by_userid_max_time = stringmapmap[String, String](min_string_val)
+  val tweets_by_reply_user = intmap[String]()
+  val reply_userid_by_reply_user = intmapmap[String, String]()
+  val reply_userid_by_reply_user_min_time = stringmapmap[String, String](max_string_val)
+  val reply_userid_by_reply_user_max_time = stringmapmap[String, String](min_string_val)
+  val reply_user_by_reply_userid = intmapmap[String, String]()
+  val reply_user_by_reply_userid_min_time = stringmapmap[String, String](max_string_val)
+  val reply_user_by_reply_userid_max_time = stringmapmap[String, String](min_string_val)
+  val reply_user_by_user = intmapmap[String, String]()
+  val user_by_reply_user = intmapmap[String, String]()
+
+  def record_tweet(metadata: Seq[(String, String)], text: String) {
+    val tparams = metadata.toMap
+    val time = tparams("time")
+    val id = tparams("id")
+    val username = tparams("username")
+    val userid = tparams("userid")
+    val reply_username = tparams("reply_username")
+    val reply_userid = tparams("reply_userid")
+
+    num_tweets += 1
+
+    // Time format is YYYYMMDDHHmmSS, hence take(8) goes up through the day,
+    // take(10) up through the hour, etc.
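+    // For example, with time = "20110807002716":
+    //   time.take(8)  == "20110807"     (day)
+    //   time.take(10) == "2011080700"   (hour)
+    //   time.take(12) == "201108070027" (minute)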
+ tweets_by_day(time.take(8)) += 1 + tweets_by_hour(time.take(10)) += 1 + // tweets_by_minute(time.take(12)) += 1 + + if (time < earliest_tweet_time) { + earliest_tweet_time = time + earliest_tweet_id = id + } + if (time > latest_tweet_time) { + latest_tweet_time = time + latest_tweet_id = id + } + if (id < lex_lowest_tweet_id) { + lex_lowest_tweet_id = id + lex_lowest_tweet_time = time + } + if (id > lex_highest_tweet_id) { + lex_highest_tweet_id = id + lex_highest_tweet_time = time + } + val num_id = BigInt(id) + if (num_id < num_lowest_tweet_id) { + num_lowest_tweet_id = num_id + num_lowest_tweet_time = time + } + if (num_id > num_highest_tweet_id) { + num_highest_tweet_id = num_id + num_highest_tweet_time = time + } + + + def set_max_with_cur[T,U <% Ordered[U]](table: mutable.Map[T,U], + key: T, newval: U) { + if (table(key) < newval) + table(key) = newval + } + def set_min_with_cur[T,U <% Ordered[U]](table: mutable.Map[T,U], + key: T, newval: U) { + if (table(key) > newval) + table(key) = newval + } + if (params.user_stats) { + tweets_by_user(username) += 1 + tweets_by_reply_user(reply_username) += 1 + } + if (params.user_to_userid_stats) { + userid_by_user(username)(userid) += 1 + set_max_with_cur(userid_by_user_max_time(username), userid, time) + set_min_with_cur(userid_by_user_min_time(username), userid, time) + user_by_userid(userid)(username) += 1 + set_max_with_cur(user_by_userid_max_time(userid), username, time) + set_min_with_cur(user_by_userid_min_time(userid), username, time) + } + if (params.reply_user_stats && reply_username != "") { + if (params.user_to_userid_stats) { + reply_userid_by_reply_user(reply_username)(reply_userid) += 1 + set_max_with_cur(reply_userid_by_reply_user_max_time(reply_username), reply_userid, time) + set_min_with_cur(reply_userid_by_reply_user_min_time(reply_username), reply_userid, time) + reply_user_by_reply_userid(reply_userid)(reply_username) += 1 + set_max_with_cur(reply_user_by_reply_userid_max_time(reply_userid), reply_username, time) + set_min_with_cur(reply_user_by_reply_userid_min_time(reply_userid), reply_username, time) + } + reply_user_by_user(username)(reply_username) += 1 + user_by_reply_user(reply_username)(username) += 1 + } + } + + def finish_statistics() { + num_users_by_num_tweets.clear() + for ((user, count) <- tweets_by_user) + num_users_by_num_tweets(count) += 1 + quantile_times.clear() + var tweets_so_far = 0 + var next_quantile_to_set = 0 + for ((time, count) <- tweets_by_hour.toSeq.sorted) { + tweets_so_far += count + val percent_seen = 100*(tweets_so_far.toDouble / num_tweets) + while (next_quantile_to_set <= percent_seen) { + quantile_times(next_quantile_to_set) = time + next_quantile_to_set += 1 + } + } + } + + def print_statistics() { + finish_statistics() + + val how_many_summary = 10000 + val how_many_summary_str = with_commas(how_many_summary) + val how_many_detail = 100 + val how_many_detail_str = with_commas(how_many_detail) + + errprint("") + errprint("Earliest tweet: %s at %s" format + (earliest_tweet_id, earliest_tweet_time)) + errprint("Numerically lowest tweet ID: %s at %s" format + (num_lowest_tweet_id, num_lowest_tweet_time)) + errprint("Lexicographically lowest tweet ID: %s at %s" format + (lex_lowest_tweet_id, lex_lowest_tweet_time)) + errprint("") + errprint("Latest tweet: %s at %s" format + (latest_tweet_id, latest_tweet_time)) + errprint("Numerically highest tweet ID: %s at %s" format + (num_highest_tweet_id, num_highest_tweet_time)) + errprint("Lexicographically highest tweet ID: %s at %s" format + 
(lex_highest_tweet_id, lex_highest_tweet_time)) + errprint("") + errprint("Number of tweets: %s" format num_tweets) + if (tweets_by_user.size > 0) { + errprint("Number of users: %s" format tweets_by_user.size) + print_msg_heading( + "Top %s users by number of tweets:" format how_many_summary_str) + output_reverse_sorted_table(tweets_by_user, maxrows = how_many_summary) + print_msg_heading( + "Frequency of frequencies (number of users with given number of tweets):") + output_key_sorted_table(num_users_by_num_tweets) + } + + print_msg_heading("Tweets by day:") + output_key_sorted_table(tweets_by_day) + print_msg_heading("Tweets by hour:") + output_key_sorted_table(tweets_by_hour) + if (tweets_by_minute.size > 0) { + print_msg_heading("Tweets by minute:") + output_key_sorted_table(tweets_by_minute) + } + print_msg_heading("Tweet quantiles by time (minimum time for given percent of tweets):") + output_key_sorted_table(quantile_times) + + def reply_to_details(sending: String, _from: String, _to: String, + tweets_by_user_map: mutable.Map[String, Int], + tweets_by_reply_user_map: mutable.Map[String, Int]) { + if (tweets_by_user_map.size == 0) + return + print_msg_heading("Reply-to, for top %s %s users:" format + (how_many_detail_str, sending)) + for (((user, count), index0) <- + tweets_by_user_map.toSeq.sortWith(_._2 > _._2). + slice(0, how_many_detail). + zipWithIndex) { + val index = index0 + 1 + errprint("#%d: User %s (%d tweets %s, %d tweets %s):", + index, user, count, _from, tweets_by_reply_user_map(user), _to) + def output_table_for_user(header: String, + table: mutable.Map[String, mutable.Map[String, Int]]) { + if (table.size > 0) { + errprint("#%d: %s:" format (index, header)) + output_reverse_sorted_table(table(user), indent = " ") + } + } + output_table_for_user("Corresponding user ID's by tweet count:", + userid_by_user) + output_table_for_user("Users that this user replied to:", + reply_user_by_user) + output_table_for_user("Users that relied to this user:", + user_by_reply_user) + } + } + + reply_to_details("sending", "from", "to", tweets_by_user, + tweets_by_reply_user) + reply_to_details("receiving", "to", "from", tweets_by_reply_user, + tweets_by_user) + + def output_x_with_multi_y(header: String, xdesc: String, ydesc: String, + x_to_y: mutable.Map[String, mutable.Map[String, Int]], + x_to_y_min_time: mutable.Map[String, mutable.Map[String, String]], + x_to_y_max_time: mutable.Map[String, mutable.Map[String, String]] + ) { + if (x_to_y.size > 0) { + print_msg_heading(header) + val x_with_multi_y = + (for ((x, ys) <- x_to_y; if ys.size > 1) + yield (x, ys.size)) + for (((x, count), index) <- + x_with_multi_y.toSeq.sortWith(_._2 > _._2).zipWithIndex) { + errprint("#%d: %s %s (%d different %s's): (listed by num tweets)", + index + 1, xdesc, x, count, ydesc) + for ((y, count) <- x_to_y(x).toSeq.sortWith(_._2 > _._2)) { + errprint("%s%s = %s (from %s to %s)" format + (" ", y, count, x_to_y_min_time(x)(y), x_to_y_max_time(x)(y))) + } + } + } + } + + output_x_with_multi_y("Users with multiple user ID's:", + "user", "ID", userid_by_user, + userid_by_user_min_time, userid_by_user_max_time) + + output_x_with_multi_y("User ID's with multiple users:", + "ID", "user", user_by_userid, + user_by_userid_min_time, user_by_userid_max_time) + + output_x_with_multi_y("Reply users with multiple reply user ID's:", + "reply user", "reply ID", reply_userid_by_reply_user, + reply_userid_by_reply_user_min_time, reply_userid_by_reply_user_max_time) + + output_x_with_multi_y("Reply ID's with multiple reply 
users:", + "reply ID", "reply user", reply_user_by_reply_userid, + reply_user_by_reply_userid_min_time, reply_user_by_reply_userid_max_time) + + print_msg_heading("Memory/time usage:", blank_lines_before = 3) + output_resource_usage(dojava = false) + } +} + +class TwitterInfochimpsStatsFileProcessor( + params: ConvertTwitterInfochimpsParameters +) extends + TwitterInfochimpsFileProcessor { + var curfile: String = _ + var curfile_stats: TwitterStatistics = _ + val global_stats = new TwitterStatistics(params) + + override def begin_process_file(filehand: FileHandler, file: String) { + val (_, outname) = filehand.split_filename(file) + curfile = outname + curfile_stats = new TwitterStatistics(params) + super.begin_process_file(filehand, file) + } + + def process_line(metadata: Seq[(String, String)], text: String) { + curfile_stats.record_tweet(metadata, text) + global_stats.record_tweet(metadata, text) + } + + def print_curfile_stats() { + print_msg_heading("Statistics for file %s:" format curfile, + blank_lines_before = 6) + curfile_stats.print_statistics() + } + + def print_global_stats(is_final: Boolean = false) { + print_msg_heading( + "Statistics for all files%s:" format (if (is_final) "" else " so far"), + blank_lines_before = 6) + global_stats.print_statistics() + } + + override def end_process_file(filehand: FileHandler, file: String) { + print_curfile_stats() + print_global_stats() + super.end_process_file(filehand, file) + } + + override def end_processing(filehand: FileHandler, files: Iterable[String]) { + print_global_stats(is_final = true) + super.end_processing(filehand, files) + } +} + +class ConvertTwitterInfochimpsDriver extends + ProcessFilesDriver with StandaloneExperimentDriverStats { + type TParam = ConvertTwitterInfochimpsParameters + + override def handle_parameters() { + if (params.output_all_stats) { + params.output_stats = true + params.user_stats = true + params.user_to_userid_stats = true + params.reply_user_stats = true + } + if (params.reply_user_stats) { + params.user_stats = true + } + if (params.output_stats) { + params.user_stats = true + params.output_min_stats = true + } + if (!params.output_min_stats) + super.handle_parameters() + } + + override def run_after_setup() { + val filehand = get_file_handler + if (params.output_min_stats) + new TwitterInfochimpsStatsFileProcessor(params). + process_files(filehand, params.files) + else { + super.run_after_setup() + new ConvertTwitterInfochimpsFileProcessor(params, "text"). + process_files(filehand, params.files) + } + } +} + +object ConvertTwitterInfochimps extends + ExperimentDriverApp("ConvertTwitterInfochimps") { + type TDriver = ConvertTwitterInfochimpsDriver + + override def description = +"""Convert input files in the Infochimps Twitter corpus into files in the +format expected by Fieldspring. If --output-stats or a related argument +is given, output statistics to stderr rather than converting text. 
+""" + + def create_param_object(ap: ArgParser) = new TParam(ap) + def create_driver() = new TDriver +} diff --git a/src/main/scala/opennlp/fieldspring/preprocess/ExtractGeotaggedListFromWikiDump.scala b/src/main/scala/opennlp/fieldspring/preprocess/ExtractGeotaggedListFromWikiDump.scala new file mode 100644 index 0000000..62840b4 --- /dev/null +++ b/src/main/scala/opennlp/fieldspring/preprocess/ExtractGeotaggedListFromWikiDump.scala @@ -0,0 +1,152 @@ +/////////////////////////////////////////////////////////////////////////////// +// PreprocWikiDump.scala +// +// Copyright (C) 2012 Mike Speriosu, The University of Texas at Austin +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +/////////////////////////////////////////////////////////////////////////////// + +package opennlp.fieldspring.preprocess + +import opennlp.fieldspring.tr.topo._ + +import java.io._ +import org.apache.commons.compress.compressors.bzip2._ + +object ExtractGeotaggedListFromWikiDump { + + var MAX_COUNT = 0 + //val NEW_PAGE = " " + //val coordRE = """\|\s*latd|\|\s*lat_deg|\|\s*latG|\|\s*latitude|\{\{\s*Coord?|\|\s*Breitengrad|\{\{\s*Coordinate\s""".r + + val titleRE = """^\s{4}<title>(.*)\s*$""".r + val redirectRE = """^\s{4}\s*$""".r + val idRE = """^\s{4}(.*)\s*$""".r + + val coordRE = """^.*\{\{\s*(?:[Cc]oord|[Cc]oordinate)\s*\|.*$""".r + val coord_decimal = """^.*\{\{\s*(?:[Cc]oord|[Cc]oordinate)\s*\|\s*(-?\d+\.?\d*+)\s*\|\s*(-?\d+\.?\d*+)\s*\|.*$""".r + val coord_dms = """^.*\{\{\s*(?:[Cc]oord|[Cc]oordinate)?\s*\|\s*(\d+)\s*\|\s*(\d+)\s*\|\s*(\d+)\s*\|\s*([Nn]|[Ss])\s*\|\s*(\d+)\s*\|\s*(\d+)\s*\|\s*(\d+)\s*\|\s*([Ee]|[Ww])\s*.*$""".r + val coord_dm = """^.*\{\{\s*(?:[Cc]oord|[Cc]oordinate)?\s*\|\s*(\d+)\s*\|\s*(\d+)\s*\|\s*([Nn]|[Ss])\s*\|\s*(\d+)\s*\|\s*(\d+)\s*\|\s*([Ee]|[Ww])\s*.*$""".r + val coord_d = """^.*\{\{\s*(?:[Cc]oord|[Cc]oordinate)?\s*\|\s*(-?\d+\.?\d*+)\s*\|\s*([Nn]|[Ss])\s*\|\s*(-?\d+\.?\d*+)\s*\|\s*([Ee]|[Ww])\s*.*$""".r + val latRE = """^.*[Ll]atd\s*=\s*(\d+)\s*\|.*$""".r + val latdms = """^.*[Ll]atd\s*=\s*(\d+)\s*\|\s*[Ll]atm\s*=(\d+)\s*\|\s*[Ll]ats\s*=(\d+)\s*\|\s*[Ll]at[Nn][Ss]\s*=\s*([Nn]|[Ss])\s*\|\s*[Ll]ongd\s*=\s*(\d+)\s*\|\s*[Ll]ongm\s*=\s*(\d+)\s*\|\s*[Ll]ongs\s*=\s*(\d+)\s*\|\s*[Ll]ong[Ee][Ww]\s*=\s*([Ee]|[Ww])\s*.*$""".r + val latdm = """^.*[Ll]atd\s*=\s*(\d+)\s*\|\s*[Ll]atm\s*=(\d+)\s*\|\s*[Ll]at[Nn][Ss]\s*=\s*([Nn]|[Ss])\s*\|\s*[Ll]ongd\s*=\s*(\d+)\s*\|\s*[Ll]ongm\s*=\s*(\d+)\s*\|\s*[Ll]ong[Ee][Ww]\s*=\s*([Ee]|[Ww])\s*.*$""".r + val latd = """^.*[Ll]atd\s*=\s*(-?\d+\.?\d*+)\s*\|\s*[Ll]at[Nn][Ss]\s*=\s*([Nn]|[Ss])\s*\|\s*[Ll]ongd\s*=\s*(-?\d+\.?\d*+)\s*\|\s*[Ll]ong[Ee][Ww]\s*=\s*([Ee]|[Ww])\s*.*$""".r + + def main(args: Array[String]) { + val fileInputStream = new FileInputStream(new File(args(0))) + if(args.length >= 2) + MAX_COUNT = args(1).toInt + //fileInputStream.read(); // used to be null pointer without this + //fileInputStream.read(); + val cbzip2InputStream = new BZip2CompressorInputStream(fileInputStream) + val in = new BufferedReader(new 
InputStreamReader(cbzip2InputStream)) + + val redirectsOut = new BufferedWriter(new FileWriter("redirects.txt")) + + var totalPageCount = 0 + var geotaggedPageCount = 0 + var lookingForCoord = false + var lineCount = 0 + var title = "" + //var redirectTitle = "" + var id = "" + var line = in.readLine + while(line != null && (MAX_COUNT <= 0 || lineCount < MAX_COUNT)) { + + if(titleRE.findFirstIn(line) != None) { + val titleRE(t) = line + title = t + totalPageCount += 1 + /*if(totalPageCount % 10000 == 0) + println(line+" "+geotaggedPageCount+"/"+totalPageCount)*/ + lookingForCoord = true + //redirectTitle = null + } + + if(redirectRE.findFirstIn(line) != None) { + val redirectRE(r) = line + redirectsOut.write(title+"\t"+r+"\n") + } + + if(idRE.findFirstIn(line) != None) { + val idRE(i) = line + id = i + } + + //if(lookingForCoord) {// && coordRE.findFirstIn(line) != None) { + //println(title) + //println(line) + //geotaggedPageCount += 1 + //lookingForCoord = false + + var coord:Coordinate = null + + if(lookingForCoord) { + if(coordRE.findFirstIn(line) != None) { + coord = line match { + case coord_decimal(lat,lon) => { Coordinate.fromDegrees(lat.toDouble, lon.toDouble) } + case coord_dms(latd, latm, lats, ns, longd, longm, longs, ew) => { + val lat = (if(ns.equalsIgnoreCase("S")) -1 else 1) * latd.toDouble + latm.toDouble/60 + lats.toDouble/3600 + val lon = (if(ew.equalsIgnoreCase("W")) -1 else 1) * longd.toDouble + longm.toDouble/60 + longs.toDouble/3600 + Coordinate.fromDegrees(lat, lon) + } + case coord_dm(latd, latm, ns, longd, longm, ew) => { + val lat = (if(ns.equalsIgnoreCase("S")) -1 else 1) * latd.toDouble + latm.toDouble/60 + val lon = (if(ew.equalsIgnoreCase("W")) -1 else 1) * longd.toDouble + longm.toDouble/60 + Coordinate.fromDegrees(lat, lon) + } + case coord_d(lat,ns, lon, ew) => { Coordinate.fromDegrees((if(ns.equalsIgnoreCase("S")) -1 else 1)*lat.toDouble, (if(ew.equalsIgnoreCase("W")) -1 else 1)*lon.toDouble) } + case _ => null + } + } + + else if(latRE.findFirstIn(line) != None) { + coord = line match { + case latdms(latd, latm, lats, ns, longd, longm, longs, ew) => { + val lat = (if(ns.equalsIgnoreCase("S")) -1 else 1) * latd.toDouble + latm.toDouble/60 + lats.toDouble/3600 + val lon = (if(ew.equalsIgnoreCase("W")) -1 else 1) * longd.toDouble + longm.toDouble/60 + longs.toDouble/3600 + Coordinate.fromDegrees(lat, lon) + } + case latdm(latd, latm, ns, longd, longm, ew) => { + val lat = (if(ns.equalsIgnoreCase("S")) -1 else 1) * latd.toDouble + latm.toDouble/60 + val lon = (if(ew.equalsIgnoreCase("W")) -1 else 1) * longd.toDouble + longm.toDouble/60 + Coordinate.fromDegrees(lat, lon) + } + case latd(latd, ns, longd, ew) => { + val lat = (if(ns.equalsIgnoreCase("S")) -1 else 1) * latd.toDouble + val lon = (if(ew.equalsIgnoreCase("W")) -1 else 1) * longd.toDouble + Coordinate.fromDegrees(lat, lon) + } + case _ => null + } + } + + if(coord != null) { + println(id+"\t"+title+"\t"+coord) + geotaggedPageCount += 1 + lookingForCoord = false + } + } + + line = in.readLine + lineCount += 1 + } + + //println("Geotagged page count: "+geotaggedPageCount) + //println("Total page count: "+totalPageCount) + + redirectsOut.close + in.close + } +} diff --git a/src/main/scala/opennlp/fieldspring/preprocess/ExtractLinksFromWikiDump.scala b/src/main/scala/opennlp/fieldspring/preprocess/ExtractLinksFromWikiDump.scala new file mode 100644 index 0000000..33ad355 --- /dev/null +++ b/src/main/scala/opennlp/fieldspring/preprocess/ExtractLinksFromWikiDump.scala @@ -0,0 +1,367 @@ 
+/////////////////////////////////////////////////////////////////////////////// +// PreprocWikiDump.scala +// +// Copyright (C) 2012 Mike Speriosu, The University of Texas at Austin +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +/////////////////////////////////////////////////////////////////////////////// + +package opennlp.fieldspring.preprocess + +import opennlp.fieldspring.tr.util._ +import opennlp.fieldspring.tr.topo._ +import opennlp.fieldspring.tr.app._ +import opennlp.fieldspring.tr.topo.gaz._ + +import java.io._ +//import java.util._ +import java.util.ArrayList +import java.util.zip._ +import org.apache.commons.compress.compressors.bzip2._ + +import scala.collection.JavaConversions._ + +import org.clapper.argot._ +import ArgotConverters._ + +object ExtractLinksFromWikiDump { + + //var MAX_COUNT = 0//100000 + val windowSize = 20 + val THRESHOLD = 10.0 / 6372.8 // 10km in radians + + val parser = new ArgotParser("fieldspring run opennlp.fieldspring.preprocess.ExtractLinksFromWikiDump", preUsage = Some("Fieldspring")) + + val articleNamesIDsCoordsFile = parser.option[String](List("a", "art"), "art", "wiki article IDs, titles, and coordinates (as output by ExtractGeotaggedListFromWikiDump)") + val rawWikiInputFile = parser.option[String](List("w", "wiki"), "wiki", "raw wiki input file (bz2)") + val trInputFile = parser.option[String](List("i", "tr-input"), "tr-input", "toponym resolution corpus input path") + val gazInputFile = parser.option[String](List("g", "gaz"), "gaz", "serialized gazetteer input file") + val stoplistInputFile = parser.option[String](List("s", "stoplist"), "stoplist", "stopwords input file") + val redirectsInputFile = parser.option[String](List("r", "redirects"), "redirects", "redirects input file") + val linksOutputFile = parser.option[String](List("l", "links"), "links", "geotagged->geotagged link count output file") + val trainingInstanceOutputDir = parser.option[String](List("d", "training-dir"), "training-dir", "training instance output directory") + val maxCountOption = parser.option[Int](List("n", "max-count"), "max-count", "maximum number of lines to read (if unspecified, all will be read)") + + //val coordRE = """\|\s*latd|\|\s*lat_deg|\|\s*latG|\|\s*latitude|\{\{\s*Coord?|\|\s*Breitengrad|\{\{\s*Coordinate\s""".r + + val titleRE = """^\s{4}(.*)\s*$""".r + val idRE = """^\s{4}(.*)\s*$""".r + + val listRE = """^(\d+)\t([^\t]+)\t(-?\d+\.?\d*),(-?\d+\.?\d*)$""".r + val redirectRE = """^(.+)\t(.+)$""".r + + //val linkAndContextRE = """((?:\S+\s*){0,20})?(\[\[[^\|\]]+)(\|?[^\|\]]+)?\]\]((?:\s*\S+){0,20})?""".r + val tokenRE = """(?:\[\[(?:[^\|\]]+)?\|?(?:[^\|\]]+)\]\])|(?:\w[^ :&=;{}\|<>]*\w)""".r + val tokenOnlyRE = """\w[^ :&=;{}\|<>]*\w""".r + //val tokenRE = """(?:\[\[(?:[^\|\]]+)?\|?(?:[^\|\]]+)\]\])|\w+""".r + //val linkRE = """^\[\[([^\|\]]+)?\|?([^\|\]]+)\]\]$""".r + val linkRE = """^\[\[([^\|]+)(?:\|(.+))?\]\]$""".r + + val markupTokens = "nbsp,lt,gt,ref,br,thinsp,amp,url,deadurl,http,quot,cite".split(",").toSet + + def main(args: 
Array[String]) { + + try { + parser.parse(args) + } + catch { + case e: ArgotUsageException => println(e.message); sys.exit(0) + } + + //println(rawWikiInputFile.value.get) + + println("Reading output from ExtractGeotaggedListFromWikiDump from " + articleNamesIDsCoordsFile.value.get + " ...") + val articleNamesToIDsAndCoords = + (for(line <- scala.io.Source.fromFile(articleNamesIDsCoordsFile.value.get).getLines) yield { + line match { + case listRE(id, name, lat, lon) => Some((name, (id.toInt, Coordinate.fromDegrees(lat.toDouble, lon.toDouble)))) + case _ => None + } + }).flatten.toMap + + println("Reading redirects from ExtractGeotaggedListFromWikiDump from " + redirectsInputFile.value.get + " ...") + val redirects = + (for(line <- scala.io.Source.fromFile(redirectsInputFile.value.get).getLines) yield { + line match { + case redirectRE(title1, title2) => Some((title1, title2)) + case _ => None + } + }).flatten.toMap + + println("Reading toponyms from TR-CoNLL at " + trInputFile.value.get + " ...") + val toponyms:Set[String] = CorpusInfo.getCorpusInfo(trInputFile.value.get).map(_._1).toSet + + val links = new scala.collection.mutable.HashMap[(Int, Int), Int] // (location1.id, location2.id) => count + val toponymsToTrainingSets = new scala.collection.mutable.HashMap[String, ArrayList[String]] + + println("Reading serialized gazetteer from " + gazInputFile.value.get + " ...") + val gis = new GZIPInputStream(new FileInputStream(gazInputFile.value.get)) + val ois = new ObjectInputStream(gis) + val gnGaz = ois.readObject.asInstanceOf[GeoNamesGazetteer] + gis.close + + val stoplist:Set[String] = + if(stoplistInputFile.value != None) { + println("Reading stopwords file from " + stoplistInputFile.value.get + " ...") + scala.io.Source.fromFile(stoplistInputFile.value.get).getLines.toSet + } + else { + println("No stopwords file specified. 
Using an empty stopword list.") + Set() + } + + val fileInputStream = new FileInputStream(new File(rawWikiInputFile.value.get)) + val cbzip2InputStream = new BZip2CompressorInputStream(fileInputStream) + val in = new BufferedReader(new InputStreamReader(cbzip2InputStream)) + + //var totalPageCount = 0 + //var geotaggedPageCount = 0 + //var lookingForCoord = false + var lineCount = 0 + var pageTitle = "" + var id = "" + var line = in.readLine + val maxCount = if(maxCountOption.value != None) maxCountOption.value.get else 0 + while(line != null && (maxCount <= 0 || lineCount < maxCount)) { + + if(titleRE.findFirstIn(line) != None) { + val titleRE(t) = line + pageTitle = t + //totalPageCount += 1 + //if(totalPageCount % 10000 == 0) + // println(line+" "+geotaggedPageCount+"/"+totalPageCount) + //lookingForCoord = true + } + + if(idRE.findFirstIn(line) != None) { + val idRE(i) = line + id = i + } + + val tokArray:Array[String] = (for(token <- tokenRE.findAllIn(line)) yield { + if(markupTokens.contains(token)) + None + else + Some(token) + }).flatten.toArray + + for(tokIndex <- 0 until tokArray.size) { + val token = tokArray(tokIndex) + token match { + case linkRE(titleRaw,a) => { + + val title = redirects.getOrElse(titleRaw, titleRaw) + val titleLower = title.toLowerCase + + val idAndCoord = articleNamesToIDsAndCoords.getOrElse(title, null) + + if(idAndCoord != null) { + + // Count the link if the current page is also geotagged: + val thisIDAndCoord = articleNamesToIDsAndCoords.getOrElse(pageTitle, null) + if(thisIDAndCoord != null) { + val pair = (thisIDAndCoord._1, idAndCoord._1) + val prevCount = links.getOrElse(pair, 0) + links.put(pair, prevCount+1) + } + + // Extract and write context: + val looseLookupResult = looseLookup(gnGaz, titleLower) + val matchingToponym = looseLookupResult._1 + if(matchingToponym != null && toponyms(matchingToponym)) { + + val closestGazIndex = getClosestGazIndex(gnGaz, titleLower, idAndCoord._2, idAndCoord._1, looseLookupResult._2) + //val matchingToponym = closestGazIndexResult._1 + //if(toponyms(matchingToponym)) { + //val closestGazIndex = closestGazIndexResult._2 + if(closestGazIndex != -1) { + val context = getContextFeatures(tokArray, tokIndex, windowSize, stoplist) + if(context.size > 0) { + val strippedContext = looseLookupResult._3 + //print(matchingToponym+": ") + //context.foreach(f => print(f+",")) + //tokenRE.findAllIn(strippedContext).foreach(f => print(f+",")) + //println(closestGazIndex) + + val contextAndLabelArray = Array.concat(context, tokenRE.findAllIn(strippedContext).toArray) + //val contextAndLabel:List[String] = (context.toList ::: tokenRE.findAllIn(strippedContext).toList) ::: (closestGazIndex.toString :: Nil) + val contextAndLabelString = contextAndLabelArray.mkString(",")+","+closestGazIndex.toString + print(matchingToponym+": ") + println(contextAndLabelString) + + val prevAL = toponymsToTrainingSets.getOrElse(matchingToponym, new ArrayList[String]) + prevAL.add(contextAndLabelString) + toponymsToTrainingSets.put(matchingToponym, prevAL) + } + } + //} + + } + + } + } + case _ => //println(token) + } + } + + if(lineCount % 100001 == 100000) { + println("*------------------") + println("*Line number - "+lineCount) + println("*Current article - "+pageTitle+" ("+id+")") + println("*line.size - "+line.size) + println("*tokArray.size - "+tokArray.size) + println("*toponymsToTrainingSets.size - "+toponymsToTrainingSets.size) + println("*toponymsToTrainingSets biggest training set - "+toponymsToTrainingSets.map(p => p._2.size).max) + 
println("*links.size - "+links.size) + println("*storedDistances.size - "+storedDistances.size) + println("*------------------") + } + + + line = in.readLine + lineCount += 1 + } + + val dir = + if(trainingInstanceOutputDir.value.get != None) { + println("Outputting training instances to directory " + trainingInstanceOutputDir.value.get + " ...") + val dirFile:File = new File(trainingInstanceOutputDir.value.get) + if(!dirFile.exists) + dirFile.mkdir + if(trainingInstanceOutputDir.value.get.endsWith("/")) + trainingInstanceOutputDir.value.get + else + trainingInstanceOutputDir.value.get+"/" + } + else { + println("Outputting training instances to current working directory ...") + "" + } + for((toponym, trainingSet) <- toponymsToTrainingSets) { + val outFile = new File(dir + toponym.replaceAll(" ", "_")+".txt") + val out = new BufferedWriter(new FileWriter(outFile)) + for(line <- trainingSet) { + //for(feature <- context) out.write(feature+",") + out.write(line+"\n") + } + out.close + } + + println("Writing links and counts to "+(if(linksOutputFile.value != None) linksOutputFile.value.get else "links.dat")+" ...") + val out = new DataOutputStream(new FileOutputStream(if(linksOutputFile.value != None) linksOutputFile.value.get else "links.dat")) + for(((id1, id2), count) <- links) { + out.writeInt(id1); out.writeInt(id2); out.writeInt(count) + } + out.close + + in.close + + println("All done.") + } + + /*def findMatchingToponym(gnGaz:GeoNamesGazetteer, titleLower:String): String = { + + }*/ + + // (loc.id, article.id) -> distance + val storedDistances = new scala.collection.mutable.HashMap[(Int, Int), Double] + + // Returns index of closest entry in gazetteer + def getClosestGazIndex(gnGaz:GeoNamesGazetteer, name:String, coord:Coordinate, articleID:Int, candidates:java.util.List[Location]): Int = { + //val looseLookupResult = looseLookup(gnGaz, name) + //val candidates = looseLookupResult._2 + //val candidates = gnGaz.lookup( + if(candidates != null) { + var minDist = Double.PositiveInfinity + var bestIndex = -1 + + for(index <- 0 until candidates.size) { + val loc = candidates(index) + val key = (loc.getId, articleID) + val dist = + if(loc.getRegion.getRepresentatives.size > 1) { // Only use distance table for multipoint locations + if(storedDistances.contains(key)) storedDistances(key) + else { + val distComputed = loc.getRegion.distance(coord) + storedDistances.put(key, distComputed) + distComputed + } + } + else { + loc.getRegion.distance(coord) + } + if(dist < THRESHOLD && dist < minDist) { + minDist = dist + bestIndex = index + } + } + + bestIndex + } + else + -1 + } + + // Returns string that matched in gazetteer and list of candidates, and any context that was stripped off + def looseLookup(gnGaz:GeoNamesGazetteer, name:String): (String, java.util.List[Location], String) = { + + var nameToReturn:String = null + var listToReturn:java.util.List[Location] = null + var strippedContext = "" + + val firstAttempt = gnGaz.lookup(name) + if(firstAttempt != null) { + nameToReturn = name + listToReturn = firstAttempt + } + + val commaIndex = name.indexOf(",") + + if(commaIndex != -1) { + val nameBeforeComma = name.slice(0, commaIndex).trim + val secondAttempt = gnGaz.lookup(nameBeforeComma) + if(secondAttempt != null) { + nameToReturn = nameBeforeComma + listToReturn = secondAttempt + strippedContext = name.drop(commaIndex+1).trim + } + else { + val parenIndex = name.indexOf("(") + val nameBeforeParen = name.slice(0, parenIndex).trim + val thirdAttempt = gnGaz.lookup(nameBeforeParen) + 
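+        // Note that, unlike the comma case above, the result of this lookup is
+        // not checked for null, so listToReturn may be null here; downstream
+        // code such as getClosestGazIndex checks for a null candidate list.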
nameToReturn = nameBeforeParen + listToReturn = thirdAttempt + strippedContext = name.drop(parenIndex+1).trim + } + } + + (nameToReturn, listToReturn, strippedContext) + } + + def getContextFeatures(tokArray:Array[String], tokIndex:Int, windowSize:Int, stoplist:Set[String]): Array[String] = { + val startIndex = math.max(0, tokIndex - windowSize) + val endIndex = math.min(tokArray.length, tokIndex + windowSize + 1) + + val linksIncluded = Array.concat(tokArray.slice(startIndex, tokIndex), tokArray.slice(tokIndex + 1, endIndex))//.filterNot(stoplist(_)) + + // Remove link notation: + linksIncluded.map(t => t match { + case linkRE(title, a) => { + val anchor = if(a == null || a.trim.size == 0) title else a.trim + tokenOnlyRE.findAllIn(anchor).toArray//.split(" ") + } + case _ => tokenOnlyRE.findAllIn(t).toArray//.split(" ") + }).flatten.map(_.toLowerCase).filterNot(stoplist(_)) + } + +} diff --git a/src/main/scala/opennlp/fieldspring/preprocess/FindPolitical.scala b/src/main/scala/opennlp/fieldspring/preprocess/FindPolitical.scala new file mode 100644 index 0000000..0350d46 --- /dev/null +++ b/src/main/scala/opennlp/fieldspring/preprocess/FindPolitical.scala @@ -0,0 +1,631 @@ +// FindPolitical.scala +// +// Copyright (C) 2012 Ben Wing, The University of Texas at Austin +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +/////////////////////////////////////////////////////////////////////////////// + +package opennlp.fieldspring.preprocess + +import collection.mutable + +import java.io._ + +import org.apache.commons.logging + +import com.nicta.scoobi.Scoobi._ + +import opennlp.fieldspring.{util => tgutil} +import tgutil.argparser._ +import tgutil.textdbutil._ +import tgutil.hadoop._ +import tgutil.ioutil._ +import tgutil.collectionutil._ +import tgutil.osutil._ +import tgutil.printutil._ + +class FindPoliticalParams(ap: ArgParser) extends + ScoobiProcessFilesParams(ap) { + var political_twitter_accounts = ap.option[String]( + "political-twitter-accounts", "pta", + help="""File containing list of politicians and associated twitter + accounts, for identifying liberal and conservative tweeters.""") + var political_twitter_accounts_format = ap.option[String]( + "political-twitter-accounts-format", "ptaf", + default = "officeholders", + choices = Seq("officeholders", "ideo-users"), + help="""Format for file specified in --political-twitter-accounts. + Possibilities: 'officeholders' (a file containing data gleaned from + the Internet, specifying holders of political offices and their parties), + 'ideo-users' (output from a previous run of FindPolitical, in + Fieldspring corpus format, with ideology identified by a numeric + score).""") + var min_accounts = ap.option[Int]("min-accounts", default = 2, + help="""Minimum number of political accounts referenced by Twitter users + in order for users to be considered. 
Default %default.""") + var min_conservative = ap.option[Double]("min-conservative", "mc", + default = 0.75, + help="""Minimum ideology score to consider a user as an "ideological + conservative". On the ideology scale, greater values indicate more + conservative. Currently, the scale runs from 0 to 1; hence, this value + should be greater than 0.5. Default %default.""") + var max_liberal = ap.option[Double]("max-liberal", "ml", + help="""Maximum ideology score to consider a user as an "ideological + liberal". On the ideology scale, greater values indicate more + conservative. Currently, the scale runs from 0 to 1; hence, this value + should be less than 0.5. If unspecified, computed as the mirror image of + the value of '--min-conservative' (e.g. 0.25 if + --min-conservative=0.75).""") + var iterations = ap.option[Int]("iterations", "i", + default = 1, + help="""Number of iterations when generating ideological users.""") + var corpus_name = ap.option[String]("corpus-name", + help="""Name of output corpus; for identification purposes. + Default to name taken from input directory.""") + var include_text = ap.flag("include-text", + help="""Include text of users sending tweets referencing a feature.""") + var ideological_ref_type = ap.option[String]("ideological-ref-type", "ilt", + default = "retweets", choices = Seq("retweets", "mentions", "followers"), + help="""Type of references to other accounts to use when determining the + ideology of a user. Possibilities are 'retweets' (accounts that tweets + are retweeted from); 'mentions' (any @-mention of an account, including + retweets); 'following' (accounts that a user is following). Default + %default.""") + var political_feature_type = ap.multiOption[String]("political-feature-type", + "pft", + choices = Seq("retweets", "followers", "hashtags", "urls", "images", + "unigrams", "bigrams", "trigrams", "ngrams"), + aliasedChoices = Seq(Seq("user-mentions", "mentions")), + help="""Type of political features to track when searching for data that may + be associated with particular ideologies. Possibilities are 'retweets' + (accounts that tweets are retweeted from); 'mentions' (any @-mention of an + account, including retweets); 'following' (accounts that a user is + following); 'hashtags'; 'unigrams'; 'bigrams'; 'trigrams'; 'ngrams'. + DOCUMENT THESE; NOT YET IMPLEMENTED. Multiple features can be tracked + simultaneously by specifying this option multiple times.""") + // FIXME: Should be able to specify multiple features separated by commas. + // This requires that we fix argparser.scala to allow this. Probably + // should add an extra field to provide a way of splitting -- maybe a regexp, + // maybe a function. + // Schema for the input file, after file read + var schema: Schema = _ + + override def check_usage() { + if (political_twitter_accounts == null) + ap.error("--political-twitter-accounts must be specified") + if (!ap.specified("max-liberal")) + max_liberal = 1 - min_conservative + if (iterations <= 0) + ap.error("--iterations must be > 0") + } +} + +/** + * A simple field-text file processor that just records the users and ideology. 
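+ *
+ * Each input row is expected to supply at least a 'user' field and a numeric
+ * 'ideology' field; an illustrative row (field order comes from the corpus
+ * schema) might look like:
+ *   some_user<TAB>0.82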
+ * + * @param suffix Suffix used to select document metadata files in a directory + */ +class IdeoUserFileProcessor extends + TextDBProcessor[(String, Double)]("ideo-users") { + def handle_row(fieldvals: Seq[String]) = { + val user = schema.get_field(fieldvals, "user") + val ideology = + schema.get_field(fieldvals, "ideology").toDouble + Some((user.toLowerCase, ideology)) + } +} + +object FindPolitical extends + ScoobiProcessFilesApp[FindPoliticalParams] { + abstract class FindPoliticalAction(opts: FindPoliticalParams) + extends ScoobiProcessFilesAction { + val progname = "FindPolitical" + } + + /** + * Count of total number of references given a sequence of + * (data, weight, times) pairs of references to a particular data point. + */ + def count_refs[T](seq: Seq[(T, Double, Int)]) = seq.map(_._3).sum + /** + * Count of total weight given a sequence of (data, weight, times) pairs + * of references to a particular data point. + */ + def count_weight[T](seq: Seq[(T, Double, Int)]) = + seq.map{ case (_, weight, times) => weight*times }.sum + + /** + * Count of total number of accounts given a sequence of (data, times) pairs + * of references to a particular data point. + */ + def count_accounts[T](seq: Seq[(T, Double, Int)]) = seq.length + + + /** + * Description of a "politico" -- a politician along their party and + * known twitter accounts. + */ + case class Politico(last: String, first: String, title: String, + party: String, where: String, accounts: Seq[String]) { + def full_name = first + " " + last + } + implicit val politico_wire = + mkCaseWireFormat(Politico.apply _, Politico.unapply _) + + def encode_ideo_refs_map(seq: Seq[(String, Double, Int)]) = + (for ((account, ideology, count) <- seq sortWith (_._3 > _._3)) yield + ("%s:%.2f:%s" format ( + encode_string_for_count_map_field(account), ideology, count)) + ) mkString " " + + def empty_ideo_refs_map = Seq[(String, Double, Int)]() + + /** + * Description of a user and the accounts referenced, both political and + * nonpolitical, along with ideology. + * + * @param user Twitter account of the user + * @param ideology Computed ideology of the user (higher values indicate + * more conservative) + * @param ideo_refs Set of references to other accounts used in computing + * the ideology (either mentions, retweets or following, based on + * --ideological-ref-type); this is a sequence of tuples of + * (account, ideology, times), i.e. 
an account, its ideology and the number + * of times it was seen + * @param lib_ideo_refs Subset of `ideo_refs` that refer to liberal users + * @param cons_ideo_refs Subset of `ideo_refs` that refer to conservative users + * @param fields Field values of user's tweets (concatenated) + */ + case class IdeologicalUser(user: String, ideology: Double, + ideo_refs: Seq[(String, Double, Int)], + lib_ideo_refs: Seq[(String, Double, Int)], + cons_ideo_refs: Seq[(String, Double, Int)], + fields: Seq[String]) { + def get_feature_values(factory: IdeologicalUserAction, ty: String) = { + ty match { + case field@("retweets" | "user-mentions" | "hashtags") => + decode_count_map( + factory.user_subschema.get_field(fields, field)) + // case "followers" => FIXME + // case "unigrams" => FIXME + // case "bigrams" => FIXME + // case "trigrams" => FIXME + // case "ngrams" => FIXME + } + } + + def to_row(opts: FindPoliticalParams) = + Seq(user, "%.3f" format ideology, + count_accounts(ideo_refs), + count_refs(ideo_refs), + encode_ideo_refs_map(ideo_refs), + count_accounts(lib_ideo_refs), + count_refs(lib_ideo_refs), + encode_ideo_refs_map(lib_ideo_refs), + count_accounts(cons_ideo_refs), + count_refs(cons_ideo_refs), + encode_ideo_refs_map(cons_ideo_refs), + fields mkString "\t" + ) mkString "\t" + } + implicit val ideological_user_wire = + mkCaseWireFormat(IdeologicalUser.apply _, IdeologicalUser.unapply _) + + class IdeologicalUserAction(opts: FindPoliticalParams) extends + FindPoliticalAction(opts) { + val operation_category = "IdeologicalUser" + + val user_subschema_fieldnames = + opts.schema.fieldnames filterNot (_ == "user") + val user_subschema = new SubSchema(user_subschema_fieldnames, + opts.schema.fixed_values, opts.schema) + + def row_fields = + Seq("user", "ideology", + "num-ideo-accounts", "num-ideo-refs", "ideo-refs", + "num-lib-ideo-accounts", "num-lib-ideo-refs", "lib-ideo-refs", + "num-cons-ideo-accounts", "num-cons-ideo-refs", "cons-ideo-refs") ++ + user_subschema_fieldnames + + /** + * For a given user, determine if the user is an "ideological user" + * and if so, return an object describing the user. 
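+     *
+     * As an illustrative example of the ideology computation (the account
+     * names and counts here are made up): a user who retweets @somepol_r
+     * (ideology 1.0) three times and @somepol_d (ideology 0.0) once gets
+     * ideology (3 * 1.0 + 1 * 0.0) / 4 = 0.75, which counts as an
+     * "ideological conservative" under the default --min-conservative of 0.75.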
+ * + * @param line Line of data describing a user, from `ParseTweets --grouping=user` + * @param accounts Mapping of ideological accounts and their ideology + * @param include_extra_fields True if we should include extra fields + * in the object specifying the references to ideological users that + * were found; only if we're writing the objects out for human inspection, + * not when we're iterating further + */ + def get_ideological_user(line: String, accounts: Map[String, Double], + include_extra_fields: Boolean) = { + error_wrap(line, None: Option[IdeologicalUser]) { line => { + val fields = line.split("\t", -1) + + def subsetted_fields = + if (include_extra_fields) + user_subschema.map_original_fieldvals(fields) + else Seq[String]() + + // get list of (refs, times) pairs + val ideo_ref_field = + if (opts.ideological_ref_type == "mentions") "user-mentions" + else opts.ideological_ref_type + val ideo_refs = + decode_count_map(opts.schema.get_field(fields, ideo_ref_field)) + val text = opts.schema.get_field(fields, "text") + val user = opts.schema.get_field(fields, "user") + //errprint("For user %s, ideo_refs: %s", user, ideo_refs.toList) + // find references to a politician + val libcons_ideo_refs = + for {(ideo_ref, times) <- ideo_refs + lower_ideo_ref = ideo_ref.toLowerCase + if accounts contains lower_ideo_ref + ideology = accounts(lower_ideo_ref)} + yield (lower_ideo_ref, ideology, times) + //errprint("libcons_ideo_refs: %s", libcons_ideo_refs.toList) + val num_libcons_ideo_refs = count_refs(libcons_ideo_refs) + if (num_libcons_ideo_refs > 0) { + val ideology = count_weight(libcons_ideo_refs)/num_libcons_ideo_refs + if (include_extra_fields) { + val lib_ideo_refs = libcons_ideo_refs.filter { + case (lower_ideo_ref, ideology, times) => + ideology <= opts.max_liberal + } + val num_lib_ideo_refs = count_refs(lib_ideo_refs) + val cons_ideo_refs = libcons_ideo_refs.filter { + case (lower_ideo_ref, ideology, times) => + ideology >= opts.min_conservative + } + val num_cons_ideo_refs = count_refs(cons_ideo_refs) + val ideo_user = + IdeologicalUser(user, ideology, libcons_ideo_refs, lib_ideo_refs, + cons_ideo_refs, subsetted_fields) + Some(ideo_user) + } else { + val ideo_user = + IdeologicalUser(user, ideology, empty_ideo_refs_map, + empty_ideo_refs_map, empty_ideo_refs_map, Seq[String]()) + Some(ideo_user) + } + } else if (accounts contains user.toLowerCase) { + val ideology = accounts(user.toLowerCase) + val ideo_user = + IdeologicalUser(user, ideology, empty_ideo_refs_map, + empty_ideo_refs_map, empty_ideo_refs_map, subsetted_fields) + Some(ideo_user) + } else + None + }} + } + } + + /** + * A political data point -- a piece of data (e.g. user mention, retweet, + * hash tag, URL, n-gram, etc.) in a tweet by an ideological user. 
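+   *
+   * For example (illustrative), a data point might be the lowercased name of
+   * an account that many ideological users retweet; the fields below record
+   * how many accounts referenced it, how many times, and how those references
+   * break down between liberal and conservative users.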
+ * + * @param data Data of the data point + * @param ty Type of data point + * @param spellings Map of actual (non-lowercased) spellings of data point + * by usage + * @param num_accounts Total number of accounts referencing data point + * @param num_refs Total number of references to data point + * @param num_lib_accounts Number of accounts with a noticeably + * "liberal" ideology referencing data point + * @param num_lib_refs Number of references to data point from accounts + * with a noticeably "liberal" ideology + * @param num_cons_accounts Number of accounts with a noticeably + * "conservative" ideology referencing data point + * @param num_cons_refs Number of references to data point from accounts + * with a noticeably "conservative" ideology + * @param num_refs_ideo_weighted Sum of references weighted by ideology of + * person doing the referenceing, so that we can compute a weighted + * average to determine their ideology. + * @param num_mentions Total number of mentions + * @param num_lib_mentions Number of times mentioned by people with + * a noticeably "liberal" ideology + * @param num_conserv_mentions Number of times mentioned by people with + * a noticeably "conservative" ideology + * @param num_ideo_mentions Sum of mentions weighted by ideology of + * person doing the mentioning, so that we can compute a weighted + * average to determine their ideology. + * @param all_text Text of all users referencing the politico. + */ + case class PoliticalFeature(value: String, spellings: Map[String, Int], + num_accounts: Int, num_refs: Int, + num_lib_accounts: Int, num_lib_refs: Int, + num_cons_accounts: Int, num_cons_refs: Int, + num_refs_ideo_weighted: Double, all_text: Seq[String]) { + def to_row(opts: FindPoliticalParams) = + Seq(value, encode_count_map(spellings.toSeq), + num_accounts, num_refs, + num_lib_accounts, num_lib_refs, + num_cons_accounts, num_cons_refs, + num_refs_ideo_weighted/num_refs, + if (opts.include_text) all_text mkString " !! " else "(omitted)" + ) mkString "\t" + } + + object PoliticalFeature { + + def row_fields = + Seq("value", "spellings", "num-accounts", "num-refs", + "num-lib-accounts", "num-lib-refs", + "num-cons-accounts", "num-cons-refs", + "ideology", "all-text") + /** + * For a given ideological user, generate the "potential politicos": other + * people referenced, along with their ideological scores. + */ + def get_political_features(factory: IdeologicalUserAction, + user: IdeologicalUser, ty: String, + opts: FindPoliticalParams) = { + for {(ref, times) <- user.get_feature_values(factory, ty) + lcref = ref.toLowerCase } yield { + val is_lib = user.ideology <= opts.max_liberal + val is_conserv = user.ideology >= opts.min_conservative + PoliticalFeature( + lcref, Map(ref->times), 1, times, + if (is_lib) 1 else 0, + if (is_lib) times else 0, + if (is_conserv) 1 else 0, + if (is_conserv) times else 0, + times * user.ideology, + Seq("FIXME fill-in text maybe") + ) + } + } + + /** + * Merge two PoliticalFeature objects, which must refer to the same user. + * Add up the references and combine the set of spellings. 
+ */ + def merge_political_features(u1: PoliticalFeature, u2: PoliticalFeature) = { + assert(u1.value == u2.value) + PoliticalFeature(u1.value, combine_maps(u1.spellings, u2.spellings), + u1.num_accounts + u2.num_accounts, + u1.num_refs + u2.num_refs, + u1.num_lib_accounts + u2.num_lib_accounts, + u1.num_lib_refs + u2.num_lib_refs, + u1.num_cons_accounts + u2.num_cons_accounts, + u1.num_cons_refs + u2.num_cons_refs, + u1.num_refs_ideo_weighted + u2.num_refs_ideo_weighted, + u1.all_text ++ u2.all_text) + } + } + + implicit val political_feature = + mkCaseWireFormat(PoliticalFeature.apply _, PoliticalFeature.unapply _) + + class FindPoliticalDriver(opts: FindPoliticalParams) + extends FindPoliticalAction(opts) { + val operation_category = "Driver" + + /** + * Read the set of ideological accounts. Create a "Politico" object for + * each such account, and return a map from a normalized (lowercased) + * version of each account to the corresponding Politico object (which + * may refer to multiple accounts). + */ + def read_ideological_accounts(filename: String) = { + val politico = + """^([^ .]+)\. (.*?), (.*?) (-+ |(?:@[^ ]+ )+)([RDI?]) \((.*)\)$""".r + val all_accounts = + // Open the file and read line by line. + for ((line, lineind) <- (new LocalFileHandler).openr(filename).zipWithIndex + // Skip comments and blank lines + if !line.startsWith("#") && !(line.trim.length == 0)) yield { + lineno = lineind + 1 + line match { + // Match the line. + case politico(title, last, first, accountstr, party, where) => { + // Compute the list of normalized accounts. + val accounts = + if (accountstr.startsWith("-")) Seq[String]() + // `tail` removes the leading @; lowercase to normalize + else accountstr.split(" ").map(_.tail.toLowerCase).toSeq + val obj = Politico(last, first, title, party, where, accounts) + for (account <- accounts) yield (account, obj) + } + case _ => { + warning(line, "Unable to match") + Seq[(String, Politico)]() + } + } + } + lineno = 0 + // For each account read in, we generated multiple pairs; flatten and + // convert to a map. Reverse because the last of identical keys will end + // up in the map but we want the first one taken. + all_accounts.flatten.toSeq.reverse.toMap + } + + /** + * Convert map of accounts->politicos to accounts->ideology + */ + def politico_accounts_map_to_ideo_users_map( + accounts: Map[String, Politico]) = { + accounts. + filter { case (string, politico) => "DR".contains(politico.party) }. + map { case (string, politico) => + (string, politico.party match { case "D" => 0.0; case "R" => 1.0 }) } + } + + /* + 2. We go through users looking for references to these politicians. For + users that reference politicians, we can compute an "ideology" score of + the user by a weighted average of the references by the ideology of + the politicians. + 3. For each such user, look at all other people referenced -- the idea is + we want to look for people referenced a lot especially by users with + a consistent ideology (e.g. Glenn Beck or Rush Limbaugh for + conservatives), which we can then use to mark others as having a + given ideology. For each person, we generate a record with their + name, the number of times they were referenced and an ideology score + and merge these all together. + */ + } + + def create_params(ap: ArgParser) = new FindPoliticalParams(ap) + val progname = "FindPolitical" + + def run() { + // For testing + // errprint("Calling error_wrap ...") + // error_wrap(1,0) { _ / 0 } + val opts = init_scoobi_app() + /* + We are doing the following: + + 1. 
We are given a list of known politicians, their twitter accounts, and + their ideology -- either determined simply by their party, or using + the DW-NOMINATE score or similar. + 2. We go through users looking for references to these politicians. For + users that reference politicians, we can compute an "ideology" score of + the user by a weighted average of the references by the ideology of + the politicians. + 3. For each such user, look at all other people referenced -- the idea is + we want to look for people referenced a lot especially by users with + a consistent ideology (e.g. Glenn Beck or Rush Limbaugh for + conservatives), which we can then use to mark others as having a + given ideology. For each person, we generate a record with their + name, the number of times they were referenced and an ideology score + and merge these all together. + */ + val ptp = new FindPoliticalDriver(opts) + val filehand = new HadoopFileHandler(configuration) + if (opts.corpus_name == null) { + val (_, last_component) = filehand.split_filename(opts.input) + opts.corpus_name = last_component.replace("*", "_") + } + var accounts: Map[String, Double] = + if (opts.political_twitter_accounts_format == "officeholders") { + val politico_accounts = + ptp.read_ideological_accounts(opts.political_twitter_accounts) + ptp.politico_accounts_map_to_ideo_users_map(politico_accounts) + } + else { + val processor = new IdeoUserFileProcessor + processor.read_textdb(filehand, opts.political_twitter_accounts). + flatten.toMap + } + // errprint("Accounts: %s", accounts) + + val suffix = "tweets" + opts.schema = + TextDBProcessor.read_schema_from_textdb(filehand, opts.input, suffix) + + def output_directory_for_suffix(corpus_suffix: String) = + opts.output + "-" + corpus_suffix + + /** + * For the given sequence of lines and related info for writing output a corpus, + * return a tuple of two thunks: One for persisting the data, the other for + * fixing up the data into a proper corpus. + */ + def output_lines(lines: DList[String], corpus_suffix: String, + fields: Seq[String]) = { + val outdir = output_directory_for_suffix(corpus_suffix) + (TextOutput.toTextFile(lines, outdir), () => { + rename_output_files(outdir, opts.corpus_name, corpus_suffix) + output_schema_for_suffix(corpus_suffix, fields) + }) + } + + def output_schema_for_suffix(corpus_suffix: String, fields: Seq[String]) { + val outdir = output_directory_for_suffix(corpus_suffix) + val fixed_fields = + Map("corpus-name" -> opts.corpus_name, + "generating-app" -> "FindPolitical", + "corpus-type" -> "twitter-%s".format(corpus_suffix)) ++ + opts.non_default_params_string.toMap ++ + Map( + "ideological-ref-type" -> opts.ideological_ref_type, + "political-feature-type" -> "%s".format(opts.political_feature_type) + ) + val out_schema = new Schema(fields, fixed_fields) + out_schema.output_constructed_schema_file(filehand, outdir, + opts.corpus_name, corpus_suffix) + } + + var ideo_users: DList[IdeologicalUser] = null + + val ideo_fact = new IdeologicalUserAction(opts) + val matching_patterns = TextDBProcessor. + get_matching_patterns(filehand, opts.input, suffix) + val lines: DList[String] = TextInput.fromTextFile(matching_patterns: _*) + + errprint("Step 1, pass 0: %d ideological users on input", + accounts.size) + for (iter <- 1 to opts.iterations) { + errprint( + "Step 1, pass %d: Filter corpus for conservatives/liberals, compute ideology." 
+ format iter) + val last_pass = iter == opts.iterations + ideo_users = + lines.flatMap(ideo_fact.get_ideological_user(_, accounts, last_pass)) + if (!last_pass) { + accounts = + persist(ideo_users.materialize).map(x => + (x.user.toLowerCase, x.ideology)).toMap + errprint("Step 1, pass %d: %d ideological users on input", + iter, accounts.size) + errprint("Step 1, pass %d: done." format iter) + } + } + + val (ideo_users_persist, ideo_users_fixup) = + output_lines(ideo_users.map(_.to_row(opts)), "ideo-users", + ideo_fact.row_fields) + /* This is a separate function because including it inline in the for loop + below results in a weird deserialization error. */ + def handle_political_feature_type(ty: String) = { + errprint("Step 2: Handling feature type '%s' ..." format ty) + val political_features = ideo_users. + flatMap(PoliticalFeature. + get_political_features(ideo_fact, _, ty, opts)). + groupBy(_.value). + combine(PoliticalFeature.merge_political_features). + map(_._2) + output_lines(political_features.map(_.to_row(opts)), + "political-features-%s" format ty, PoliticalFeature.row_fields) + } + + errprint("Step 2: Generate political features.") + val (ft_persists, ft_fixups) = ( + for (ty <- opts.political_feature_type) yield + handle_political_feature_type(ty) + ).unzip + persist(Seq(ideo_users_persist) ++ ft_persists) + ideo_users_fixup() + for (fixup <- ft_fixups) + fixup() + errprint("Step 1, pass %d: done." format opts.iterations) + errprint("Step 2: done.") + + finish_scoobi_app(opts) + } + /* + + To build a classifier for conserv vs liberal: + + 1. Look for people retweeting congressmen or governor tweets, possibly at + some minimum level of retweeting (or rely on followers, for some + minimum number of people following) + 2. Make sure either they predominate having retweets from one party, + and/or use the DW-NOMINATE scores to pick out people whose average + ideology score of their retweets is near the extremes. + */ +} + diff --git a/src/main/scala/opennlp/fieldspring/preprocess/FrobTextDB.scala b/src/main/scala/opennlp/fieldspring/preprocess/FrobTextDB.scala new file mode 100644 index 0000000..8c1493b --- /dev/null +++ b/src/main/scala/opennlp/fieldspring/preprocess/FrobTextDB.scala @@ -0,0 +1,404 @@ +/////////////////////////////////////////////////////////////////////////////// +// FrobTextDB.scala +// +// Copyright (C) 2011 Ben Wing, The University of Texas at Austin +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+/////////////////////////////////////////////////////////////////////////////// + +package opennlp.fieldspring.preprocess + +import collection.mutable + +import java.io._ + +import opennlp.fieldspring.util.argparser._ +import opennlp.fieldspring.util.collectionutil._ +import opennlp.fieldspring.util.textdbutil._ +import opennlp.fieldspring.util.experiment._ +import opennlp.fieldspring.util.ioutil._ +import opennlp.fieldspring.util.MeteredTask +import opennlp.fieldspring.util.printutil.warning + +import opennlp.fieldspring.gridlocate.DistDocument + +///////////////////////////////////////////////////////////////////////////// +// Main code // +///////////////////////////////////////////////////////////////////////////// + +class FrobTextDBParameters(ap: ArgParser) extends + ProcessFilesParameters(ap) { + val input_dir = + ap.option[String]("i", "input-dir", + metavar = "DIR", + help = """Directory containing input corpus.""") + var input_suffix = + ap.option[String]("s", "input-suffix", + metavar = "DIR", + help = """Suffix used to select the appropriate files to operate on. +Defaults to 'unigram-counts' unless --convert-to-unigram-counts is given, +in which case it defaults to 'text'.""") + var output_suffix = + ap.option[String]("output-suffix", + metavar = "DIR", + help = """Suffix used when generating the output files. Defaults to +the value of --input-suffix, unless --convert-to-unigram-counts is given, +in which case it defaults to 'unigram-counts'.""") + val add_field = + ap.multiOption[String]("a", "add-field", + metavar = "FIELD=VALUE", + help = """Add a fixed field named FIELD, with the value VALUE.""") + val rename_field = + ap.multiOption[String]("rename-field", + metavar = "FIELD=NEWFIELD", + help = """Rename field FIELD to NEWFIELD.""") + val remove_field = + ap.multiOption[String]("r", "remove-field", + metavar = "FIELD", + help = """Remove a field, either from all rows (for a normal field) +or a fixed field.""") + val set_split_by_value = + ap.option[String]("set-split-by-value", + metavar = "SPLITFIELD,MAX-TRAIN-VAL,MAX-DEV-VAL", + help = """Set the "split" field to one of "training", "dev" or "test" +according to the value of another field (e.g. by time). For the field named +SPLITFIELD, values <= MAX-TRAIN-VAL go into the training split; +values <= MAX-DEV-VAL go into the dev split; and higher values go into the +test split. Comparison is lexicographically (i.e. string comparison, +rather than numeric).""") + val split_by_field = + ap.option[String]("split-by-field", + metavar = "FIELD", + help = """Divide the corpus into separate corpora according to the value +of the given field. (For example, the "split" field.) You can combine this +action with any of the other actions, and they will be done in the right +order.""") + val convert_to_unigram_counts = + ap.flag("convert-to-unigram-counts", + help = """If specified, convert the 'text' field to a 'counts' field +containing unigram counts.""") + + var split_field: String = null + var max_training_val: String = null + var max_dev_val: String = null +} + +/** + * A textdb processor that outputs fields as they come in, possibly modified + * in various ways. 
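+ *
+ * As a purely illustrative example (not from the original patch), a run of
+ * the FrobTextDB tool that drives this processor, using only options defined
+ * in FrobTextDBParameters above, might look like:
+ *
+ * {{{
+ * FrobTextDB -i /path/to/corpus --rename-field title=name \
+ *   --remove-field id --convert-to-unigram-counts \
+ *   --set-split-by-value time,20100801,20100901 ...
+ * }}}
+ *
+ * Here rows whose hypothetical `time` field is lexicographically <= 20100801
+ * would land in the training split, those <= 20100901 in dev, and the rest in
+ * test; the output-directory option is assumed to come from the
+ * ProcessFilesParameters base class and is elided above.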
+ * + * @param output_filehand FileHandler of the output corpus (directory is + * taken from parameters) + * @param params Parameters retrieved from the command-line arguments + */ +class FrobTextDBProcessor( + output_filehand: FileHandler, + params: FrobTextDBParameters +) extends BasicTextDBProcessor[Unit](params.input_suffix) { + val split_value_to_writer = mutable.Map[String, TextDBWriter]() + val split_value_to_outstream = mutable.Map[String, PrintStream]() + var unsplit_writer: TextDBWriter = _ + var unsplit_outstream: PrintStream = _ + + def frob_row(fieldvals: Seq[String]) = { + val docparams = mutable.LinkedHashMap[String, String]() + docparams ++= (rename_fields(schema.fieldnames) zip fieldvals) + for (field <- params.remove_field) + docparams -= field + if (params.split_field != null) { + if (docparams(params.split_field) <= params.max_training_val) + docparams("split") = "training" + else if (docparams(params.split_field) <= params.max_dev_val) + docparams("split") = "dev" + else + docparams("split") = "test" + } + if (params.convert_to_unigram_counts) { + val text = docparams("text") + docparams -= "text" + val counts = intmap[String]() + for (word <- text.split(" ", -1)) + counts(word) += 1 + val counts_text = encode_count_map(counts.toSeq) + docparams += (("counts", counts_text)) + } + docparams.toSeq + } + + def rename_fields(fieldnames: Seq[String]) = { + for (field <- fieldnames) yield { + var f = field + for (rename_field <- params.rename_field) { + val Array(oldfield, newfield) = rename_field.split("=", 2) + if (f == oldfield) + f = newfield + } + f + } + } + + def modify_fixed_values(fixed_values: Map[String, String]) = { + var (names, values) = fixed_values.toSeq.unzip + var new_fixed_values = (rename_fields(names) zip values).toMap + for (field <- params.remove_field) + new_fixed_values -= field + val new_fields = + for (add_field <- params.add_field) yield { + val Array(field, value) = add_field.split("=", 2) + (field -> value) + } + new_fixed_values ++= new_fields + new_fixed_values + } + + /** + * Find the writer and output stream for the given frobbed document, + * creating one or both as necessary. There will be one writer overall, + * and one output stream per input file. + */ + def get_unsplit_writer_and_outstream(fieldnames: Seq[String], + fieldvals: Seq[String]) = { + if (unsplit_writer == null) { + /* Construct a new schema. Create a new writer for this schema; + write the schema out; and record the writer in + `unsplit_writer`. + */ + val new_schema = + new Schema(fieldnames, modify_fixed_values(schema.fixed_values)) + unsplit_writer = new TextDBWriter(new_schema, params.output_suffix) + unsplit_writer.output_schema_file(output_filehand, params.output_dir, + schema_prefix) + } + if (unsplit_outstream == null) + unsplit_outstream = unsplit_writer.open_document_file(output_filehand, + params.output_dir, current_document_prefix) + (unsplit_writer, unsplit_outstream) + } + + /** + * Find the writer and output stream for the given frobbed document, + * based on the split field, creating one or both as necessary. + * There will be one writer for each possible split value, and one output + * stream per split value per input file. 
+ */ + def get_split_writer_and_outstream(fieldnames: Seq[String], + fieldvals: Seq[String]) = { + val split = + Schema.get_field_or_else(fieldnames, fieldvals, params.split_by_field) + if (split == null) + (null, null) + else { + if (!(split_value_to_writer contains split)) { + /* Construct a new schema where the split has been moved into a + fixed field. Create a new writer for this schema; write the + schema out; and record the writer in `split_value_to_writer`. + */ + val field_map = Schema.to_map(fieldnames, fieldvals) + field_map -= params.split_by_field + val (new_fieldnames, new_fieldvals) = Schema.from_map(field_map) + val new_fixed_values = + schema.fixed_values + (params.split_by_field -> split) + val new_schema = + new Schema(new_fieldnames, modify_fixed_values(new_fixed_values)) + val writer = new TextDBWriter(new_schema, params.output_suffix) + writer.output_schema_file(output_filehand, params.output_dir, + schema_prefix + "-" + split) + split_value_to_writer(split) = writer + } + if (!(split_value_to_outstream contains split)) { + val writer = split_value_to_writer(split) + split_value_to_outstream(split) = + writer.open_document_file(output_filehand, params.output_dir, + current_document_prefix + "-" + split) + } + (split_value_to_writer(split), split_value_to_outstream(split)) + } + } + + def process_lines(lines: Iterator[String], + filehand: FileHandler, file: String, + compression: String, realname: String) = { + val task = new MeteredTask("document", "frobbing") + for (line <- lines) { + task.item_processed() + parse_row(line) + } + task.finish() + (true, ()) + } + + override def end_process_file(filehand: FileHandler, file: String) { + /* Close the output stream(s), clearing the appropriate variable(s) so + that the necessary stream(s) will be re-created again for the + next input file. 
+ */ + if (params.split_by_field != null) { + for (outstream <- split_value_to_outstream.values) + outstream.close() + split_value_to_outstream.clear() + } else { + unsplit_outstream.close() + unsplit_outstream = null + } + super.end_process_file(filehand, file) + } + + def process_row(fieldvals: Seq[String]) = { + val (new_fieldnames, new_fieldvals) = frob_row(fieldvals).unzip + if (params.split_by_field != null) { + val (writer, outstream) = + get_split_writer_and_outstream(new_fieldnames, new_fieldvals) + if (writer == null) { + warning("Skipped row because can't find split field: %s", + new_fieldvals mkString "\t") + } else { + /* Remove the split field from the output, since it's constant + for all rows and is moved to the fixed fields */ + val field_map = Schema.to_map(new_fieldnames, new_fieldvals) + field_map -= params.split_by_field + val (nosplit_fieldnames, nosplit_fieldvals) = Schema.from_map(field_map) + assert(nosplit_fieldnames == writer.schema.fieldnames, + "resulting fieldnames %s should be same as schema fieldnames %s" + format (nosplit_fieldnames, writer.schema.fieldnames)) + writer.schema.output_row(outstream, nosplit_fieldvals) + } + } else { + val (writer, outstream) = + get_unsplit_writer_and_outstream(new_fieldnames, new_fieldvals) + writer.schema.output_row(outstream, new_fieldvals) + } + (true, true) + } +} + +class FrobTextDBDriver extends + ProcessFilesDriver with StandaloneExperimentDriverStats { + type TParam = FrobTextDBParameters + + override def handle_parameters() { + need(params.input_dir, "input-dir") + if (params.set_split_by_value != null) { + val Array(split_field, training_max, dev_max) = + params.set_split_by_value.split(",") + params.split_field = split_field + params.max_training_val = training_max + params.max_dev_val = dev_max + } + if (params.input_suffix == null) + params.input_suffix = + if (params.convert_to_unigram_counts) "text" + else "unigram-counts" + if (params.output_suffix == null) + params.output_suffix = + if (params.convert_to_unigram_counts) "unigram-counts" + else params.input_suffix + super.handle_parameters() + } + + override def run_after_setup() { + super.run_after_setup() + + val filehand = get_file_handler + val fileproc = + new FrobTextDBProcessor(filehand, params) + fileproc.read_schema_from_textdb(filehand, params.input_dir) + fileproc.process_files(filehand, Seq(params.input_dir)) + } +} + +object FrobTextDB extends + ExperimentDriverApp("FrobTextDB") { + type TDriver = FrobTextDBDriver + + override def description = +"""Modify a corpus by changing particular fields. Fields can be added +(--add-field) or removed (--remove-field); the "split" field can be +set based on the value of another field (--set-split-by-value); +the corpus can be changed from text to unigram counts +(--convert-to-unigram-counts); and it can be divided into sub-corpora +based on the value of a field, e.g. "split" (--split-by-field). +""" + + def create_param_object(ap: ArgParser) = new TParam(ap) + def create_driver() = new TDriver +} + +//class ScoobiConvertTextToUnigramCountsDriver extends +// BaseConvertTextToUnigramCountsDriver { +// +// def usage() { +// sys.error("""Usage: ScoobiConvertTextToUnigramCounts [-o OUTDIR | --outfile OUTDIR] [--group-by-user] INFILE ... +// +//Using Scoobi (front end to Hadoop), convert input files from raw-text format +//(one document per line) into unigram counts, in the format expected by +//Fieldspring. OUTDIR is the directory to store the results in, which must not +//exist already. 
If --group-by-user is given, a document is the concatenation +//of all tweets for a given user. Else, each individual tweet is a document. +//""") +// } +// +// def process_files(filehand: FileHandler, files: Seq[String]) { +// val task = new MeteredTask("tweet", "processing") +// var tweet_lineno = 0 +// val task2 = new MeteredTask("user", "processing") +// var user_lineno = 0 +// val out_counts_name = "%s/counts-only-coord-documents.txt" format (params.output_dir) +// errprint("Counts output file is %s..." format out_counts_name) +// val tweets = extractFromDelimitedTextFile("\t", params.files(0)) { +// case user :: title :: text :: Nil => (user, text) +// } +// val counts = tweets.groupByKey.map { +// case (user, tweets) => { +// val counts = intmap[String]() +// tweet_lineno += 1 +// for (tweet <- tweets) { +// val words = tweet.split(' ') +// for (word <- words) +// counts(word) += 1 +// } +// val result = +// (for ((word, count) <- counts) yield "%s:%s" format (word, count)). +// mkString(" ") +// (user, result) +// } +// } +// +// // Execute everything, and throw it into a directory +// DList.persist ( +// TextOutput.toTextFile(counts, out_counts_name) +// ) +// } +//} + +//abstract class ScoobiApp( +// progname: String +//) extends ExperimentDriverApp(progname) { +// override def main(orig_args: Array[String]) = withHadoopArgs(orig_args) { +// args => { +// /* Thread execution back to the ExperimentDriverApp. This will read +// command-line arguments, call initialize_parameters() to verify +// and canonicalize them, and then pass control back to us by +// calling run_program(), which we override. */ +// set_errout_prefix(progname + ": ") +// implement_main(args) +// } +// } +//} +// +//object ScoobiConvertTextToUnigramCounts extends +// ScoobiApp("Convert raw text to unigram counts") { +// type TDriver = ScoobiConvertTextToUnigramCountsDriver +// def create_param_object(ap: ArgParser) = new TParam(ap) +// def create_driver() = new TDriver +//} + diff --git a/src/main/scala/opennlp/fieldspring/preprocess/MergeMetadataAndOldCounts.scala b/src/main/scala/opennlp/fieldspring/preprocess/MergeMetadataAndOldCounts.scala new file mode 100644 index 0000000..02926dd --- /dev/null +++ b/src/main/scala/opennlp/fieldspring/preprocess/MergeMetadataAndOldCounts.scala @@ -0,0 +1,256 @@ +/////////////////////////////////////////////////////////////////////////////// +// MergeMetadataAndOldCounts.scala +// +// Copyright (C) 2011 Ben Wing, The University of Texas at Austin +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+/////////////////////////////////////////////////////////////////////////////// + +package opennlp.fieldspring.preprocess + +import util.matching.Regex +import util.control.Breaks._ +import collection.mutable + +import java.io.{InputStream, PrintStream} + +import opennlp.fieldspring.worddist.IdentityMemoizer._ +import opennlp.fieldspring.gridlocate.DistDocument + +import opennlp.fieldspring.util.argparser._ +import opennlp.fieldspring.util.collectionutil.DynamicArray +import opennlp.fieldspring.util.textdbutil._ +import opennlp.fieldspring.util.experiment._ +import opennlp.fieldspring.util.ioutil._ +import opennlp.fieldspring.util.MeteredTask +import opennlp.fieldspring.util.printutil.{errprint, warning} + +///////////////////////////////////////////////////////////////////////////// +// Main code // +///////////////////////////////////////////////////////////////////////////// + +class MMCParameters(ap: ArgParser) extends + ArgParserParameters(ap) { + val output_dir = + ap.option[String]("o", "output-dir", + metavar = "DIR", + help = """Directory to store output files in.""") + val input_dir = + ap.option[String]("i", "input-dir", + metavar = "FILE", + help = """Directory containing the input corpus, in old format.""") + val counts_file = + ap.option[String]("counts-file", + metavar = "FILE", + help = """File containing the word counts, in old format.""") + var output_file_prefix = + ap.option[String]("output-file-prefix", + metavar = "FILE", + help = """Prefix to add to files in the output corpus dir.""") +} + +/** + * A simple reader for word-count files in the old multi-line count format. + * This is stripped of debugging code, code to handle case-merging, stopwords + * and other modifications of the distributions, etc. All it does is + * read the distributions and pass them to `handle_document`. + */ + +trait SimpleUnigramWordDistConstructor { + val initial_dynarr_size = 1000 + val keys_dynarr = + new DynamicArray[Word](initial_alloc = initial_dynarr_size) + val values_dynarr = + new DynamicArray[Int](initial_alloc = initial_dynarr_size) + + def handle_document(title: String, keys: Array[Word], values: Array[Int], + num_words: Int): Boolean + + def read_word_counts(filehand: FileHandler, filename: String) { + errprint("Reading word counts from %s...", filename) + errprint("") + + var title: String = null + + // Written this way because there's another line after the for loop, + // corresponding to the else clause of the Python for loop + breakable { + for (line <- filehand.openr(filename)) { + if (line.startsWith("Article title: ")) { + if (title != null) { + if (!handle_document(title, keys_dynarr.array, values_dynarr.array, + keys_dynarr.length)) + break + } + // Extract title and set it + val titlere = "Article title: (.*)$".r + line match { + case titlere(ti) => title = ti + case _ => assert(false) + } + keys_dynarr.clear() + values_dynarr.clear() + } else if (line.startsWith("Article coordinates) ") || + line.startsWith("Article ID: ")) { + } else { + val linere = "(.*) = ([0-9]+)$".r + line match { + case linere(word, count) => { + // errprint("Saw1 %s,%s", word, count) + keys_dynarr += memoize_string(word) + values_dynarr += count.toInt + } + case _ => + warning("Strange line, can't parse: title=%s: line=%s", + title, line) + } + } + } + if (!handle_document(title, keys_dynarr.array, values_dynarr.array, + keys_dynarr.length)) + break + } + } +} + +/** + * A simple factory for reading in the unigram distributions in the old + * format. 
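+ *
+ * For orientation, the old multi-line format handled here looks roughly like
+ * the following (an illustrative reconstruction inferred from the parsing
+ * code in SimpleUnigramWordDistConstructor above, not an excerpt from a real
+ * data file):
+ *
+ * {{{
+ * Article title: Paris
+ * Article ID: 22989
+ * the = 57
+ * tower = 3
+ * seine = 2
+ * }}}
+ *
+ * A title line introduces each document, ID and coordinate lines are skipped,
+ * and every other line is a "word = count" pair.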
+ */ +class MMCUnigramWordDistHandler( + schema: Schema, + document_fieldvals: mutable.Map[String, Seq[String]], + filehand: FileHandler, + output_dir: String, + output_file_prefix: String +) extends SimpleUnigramWordDistConstructor { + val new_schema = new Schema(schema.fieldnames ++ Seq("counts"), + schema.fixed_values) + val writer = new TextDBWriter(new_schema, "unigram-counts") + writer.output_schema_file(filehand, output_dir, output_file_prefix) + val outstream = writer.open_document_file(filehand, output_dir, + output_file_prefix, compression = "bzip2") + + def handle_document(title: String, keys: Array[Word], values: Array[Int], + num_words: Int) = { + errprint("Handling document: %s", title) + val params = document_fieldvals.getOrElse(title, null) + if (params == null) + warning("Strange, can't find document %s in document file", title) + else { + val counts = + (for (i <- 0 until num_words) yield { + // errprint("Saw2 %s,%s", keys(i), values(i)) + ("%s:%s" format + (encode_string_for_count_map_field(unmemoize_string(keys(i))), + values(i))) + }). + mkString(" ") + val new_params = params ++ Seq(counts) + writer.schema.output_row(outstream, new_params) + } + true + } + + def finish() { + outstream.close() + } +} + +/** + * A simple field-text file processor that just records the documents read in, + * by title. + * + * @param suffix Suffix used to select document metadata files in a directory + */ +class MMCDocumentFileProcessor( + suffix: String +) extends BasicTextDBProcessor[Unit](suffix) { + val document_fieldvals = mutable.Map[String, Seq[String]]() + + def process_row(fieldvals: Seq[String]) = { + val params = (schema.fieldnames zip fieldvals).toMap + document_fieldvals(params("title")) = fieldvals + (true, true) + } + + def process_lines(lines: Iterator[String], + filehand: FileHandler, file: String, + compression: String, realname: String) = { + val task = new MeteredTask("document", "reading") + for (line <- lines) { + task.item_processed() + parse_row(line) + } + task.finish() + (true, ()) + } +} + +class MMCDriver extends ArgParserExperimentDriver { + type TParam = MMCParameters + type TRunRes = Unit + + val filehand = new LocalFileHandler + + def usage() { + sys.error("""Usage: MergeMetadataAndOldCounts [-o OUTDIR | --outfile OUTDIR] [--output-stats] INFILE ... + +Merge document-metadata files and old-style counts files into a new-style +counts file also containing the metadata. 
+""") + } + + def handle_parameters() { + need(params.output_dir, "output-dir") + need(params.input_dir, "input-dir") + need(params.counts_file, "counts-file") + } + + def setup_for_run() { } + + def run_after_setup() { + if (!filehand.make_directories(params.output_dir)) + param_error("Output dir %s must not already exist" format + params.output_dir) + + val fileproc = + new MMCDocumentFileProcessor(document_metadata_suffix) + fileproc.read_schema_from_textdb(filehand, params.input_dir) + + if (params.output_file_prefix == null) { + var (_, base) = filehand.split_filename(fileproc.schema_file) + params.output_file_prefix = base.replaceAll("-[^-]*$", "") + params.output_file_prefix = + params.output_file_prefix.stripSuffix("-document-metadata") + errprint("Setting new output-file prefix to '%s'", + params.output_file_prefix) + } + + fileproc.process_files(filehand, Seq(params.input_dir)) + + val counts_handler = + new MMCUnigramWordDistHandler(fileproc.schema, + fileproc.document_fieldvals, filehand, params.output_dir, + params.output_file_prefix) + counts_handler.read_word_counts(filehand, params.counts_file) + counts_handler.finish() + } +} + +object MergeMetadataAndOldCounts extends + ExperimentDriverApp("Merge document metadata files and old counts file") { + type TDriver = MMCDriver + def create_param_object(ap: ArgParser) = new TParam(ap) + def create_driver() = new TDriver +} diff --git a/src/main/scala/opennlp/fieldspring/preprocess/OldGroupCorpus.scala b/src/main/scala/opennlp/fieldspring/preprocess/OldGroupCorpus.scala new file mode 100644 index 0000000..9548d14 --- /dev/null +++ b/src/main/scala/opennlp/fieldspring/preprocess/OldGroupCorpus.scala @@ -0,0 +1,382 @@ +/////////////////////////////////////////////////////////////////////////////// +// GroupCorpus.scala +// +// Copyright (C) 2011 Ben Wing, The University of Texas at Austin +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+/////////////////////////////////////////////////////////////////////////////// + +package opennlp.fieldspring.preprocess + +import collection.mutable +import collection.JavaConversions._ + +import java.io._ + +import org.apache.hadoop.io._ +import org.apache.hadoop.util._ +import org.apache.hadoop.mapreduce._ +import org.apache.hadoop.mapreduce.lib.input.FileInputFormat +import org.apache.hadoop.mapreduce.lib.output.FileOutputFormat +import org.apache.hadoop.conf.{Configuration, Configured} +import org.apache.hadoop.fs._ + +import opennlp.fieldspring.util.argparser._ +import opennlp.fieldspring.util.collectionutil._ +import opennlp.fieldspring.util.textdbutil._ +import opennlp.fieldspring.util.experiment._ +import opennlp.fieldspring.util.hadoop._ +import opennlp.fieldspring.util.ioutil._ +import opennlp.fieldspring.util.mathutil.mean +import opennlp.fieldspring.util.MeteredTask +import opennlp.fieldspring.util.printutil.warning + +import opennlp.fieldspring.gridlocate.DistDocument + +///////////////////////////////////////////////////////////////////////////// +// Main code // +///////////////////////////////////////////////////////////////////////////// +/* + +We have a corpus consisting of a single directory and suffix identifying +the corpus. In the corpus is a schema file plus one or more document +files. We want to group all documents together that have the same value +for a particular field, usually the 'user' field (i.e. group together +all documents with the same user). + +NOTE: Do we want to do this before or after splitting on train/test/dev? +If we do this after, we end up with up to three documents per user -- one +per each split. Because of the way we chose the splits, they will be +offset in time. This could definitely be a useful test. If we do this +before, we end up with only one document per user, and then have to split +the users, perhaps randomly, perhaps by (average?) time as well. + +Note that no changes are needed in the coding of this program to make it work +regardless of whether we end up splitting before or after grouping. +When we split before grouping, we run the grouping program separately on +each subcorpus; when we split after, we run on the whole corpus. + +What about other fields? The other fields in the Infochimps corpus are as +follows, with a description of what to do in each case. + +1. 'id' = Tweet ID: Throw it away. +2. 'coord' = Coordinate: Take some sort of average. It might be enough + to simply average the latitude and longitudes. But we need to be careful + averaging longitudes, because they are cyclic (wrap around); e.g. +179 + longitude and -177 longitude are only 4 degrees apart in longitude; hence + the average should obviously be -179. However, a simple-minded averaging + produces +1 longitude, which is halfway around the world from either of + the original values (and exactly 180 degrees away from the correct value). + The standard rule in this case is to take advantage of modular arithmetic + to place the two values as close as possible -- which in all cases will + be no more than 180 degrees -- then average, and finally readjust to be + within the preferred range. In the example above, we could average + correctly by adding 360 to -177 to become +183; then the difference + between the the two values is only 4. The average is +181, which can be + readjusted to fall within range by subtracting 360, resulting in -179. 
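+   As a purely illustrative Scala sketch of this renumber-and-compare trick
+   (generalized to any number of points, anticipating the discussion just
+   below; not part of this patch):
+
+     // Average longitudes given in [-180, +180), renumbering onto a
+     // [0, 360) scale when that packs the points more tightly.
+     def averageLongitudes(longs: Seq[Double]): Double = {
+       val shifted = longs map (l => if (l < 0) l + 360 else l)
+       def spread(xs: Seq[Double]) = xs.max - xs.min
+       val best = if (spread(longs) <= spread(shifted)) longs else shifted
+       val avg = best.sum / best.size
+       if (avg >= 180) avg - 360 else avg   // readjust into [-180, +180)
+     }
+     // averageLongitudes(Seq(179.0, -177.0)) == -179.0
+
+   As discussed below, this is only guaranteed to behave sensibly when one of
+   the two spreads is <= 180 degrees.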
+ When there are more than two points, it's necessary to choose the + representation of the points that minimizes the distance between the + minimum and maximum. We probably also want to use some sort of outlier + detection mechanism and reject the outliers entirely, so they don't skew + the results; or choose the median instead of mean. The correct algorithms + for all this might be tricky, but are critical because the coordinate is + the basis of geolocation. Surely there must already exist algorithms to + do this in an efficient and robust manner, and surely there must already + exist implementations of these algorithms. For example, what did + Eisenstein et al. (2010) do to produce their GeoText corpus? They surely + struggled with this issue. + + Note that it seems that, in the majority of cases where all the values + are fairly close together, the maximum distance between points should be + less than 180 degrees. If so, we can number the points using both a + [-180,+180) scale and a [0, 360) scale, sort in each case and see what + the difference between minimum and maximum is. Choose the smaller of + the differences, and if it's <= 180, we know this is the correct scale, + and we can average using this scale. If both are > 180, it might work + as well in all cases to select the smaller one, but I'm not sure; that + would have to be proven. + +3. 'time' = Time: Either throw away, or produce an average, standard deviation, + min and max. +4. 'username' = There will be only one. +5. 'userid' = There may be more than one; if so, list any that get at least 25% + of the total, in order starting from the most common (break ties randomly). + Include 'other' as an additional possibility if all non-included values + together include at least 25% of the total, and sort 'other' along with the + others according to its percent. +6. 'reply_username', 'reply_userid' = Same as for 'userid' +7. 'anchor' = Same as for 'userid' +8. 'lang' = Same as for 'userid' +9. 'counts' = Combine counts according to keys, adding the values of cases + where the same key occurs in both records. +10. 'text' = Concatenate the text of the Tweets. It's probably good to + put an -EOT- ("end of Tweet") token after each Tweet so that they can + still be distinguished when concatenated together. + +So, we receive in the mapper single documents. We extract the group and +the remaining fields and output a pair where the key is the group (typically, +the user name), and the value is the remaining fields. The reducer then needs +to reduce the whole set of records down to a single one using the instructions +above, and output a single combined text record. Possibly some of the work +could also be done in a combiner, but it's probably not worth the extra effort +to construct such a combiner, since we'd have to complexify the intermediate +format away from text to some kind of binary format. + +*/ + +class GroupCorpusParameters(ap: ArgParser) extends + ArgParserParameters(ap) { + val input_dir = + ap.option[String]("i", "input-dir", + metavar = "DIR", + help = """Directory containing input corpus.""") + var input_suffix = + ap.option[String]("s", "input-suffix", + metavar = "DIR", + default = "unigram-counts", + help = """Suffix used to select the appropriate files to operate on. +Defaults to '%default'.""") + var output_suffix = + ap.option[String]("output-suffix", + metavar = "DIR", + help = """Suffix used when generating the output files. 
Defaults to +the value of --input-suffix.""") + val output_dir = + ap.positional[String]("output-dir", + help = """Directory to store output files in. It must not already +exist, and will be created (including any parent directories).""") + val field = + ap.option[String]("f", "field", + default = "username", + help = """Field to group on; default '%default'.""") +} + +class GroupCorpusDriver extends + HadoopableArgParserExperimentDriver with HadoopExperimentDriver { + type TParam = GroupCorpusParameters + type TRunRes = Unit + + def handle_parameters() { + need(params.input_dir, "input-dir") + if (params.output_suffix == null) + params.output_suffix = params.input_suffix + } + def setup_for_run() { } + + def run_after_setup() { } +} + +/** + * This file processor, the GroupCorpusMapReducer trait below and the + * subclass of this file processor together are a lot of work simply to read + * the schema from the input corpus in both the mapper and reducer. Perhaps + * we could save the schema and pass it to the reducer as the first item? + * Perhaps that might not make the code any simpler. + */ +class GroupCorpusFileProcessor( + context: TaskInputOutputContext[_,_,_,_], + driver: GroupCorpusDriver +) extends BasicTextDBProcessor[Unit](driver.params.input_suffix) { + def process_row(fieldvals: Seq[String]): (Boolean, Boolean) = + throw new IllegalStateException("This shouldn't be called") + + def process_lines(lines: Iterator[String], + filehand: FileHandler, file: String, + compression: String, realname: String) = + throw new IllegalStateException("This shouldn't be called") +} + +trait GroupCorpusMapReducer extends HadoopExperimentMapReducer { + def progname = GroupCorpus.progname + type TDriver = GroupCorpusDriver + // more type erasure crap + def create_param_object(ap: ArgParser) = new TParam(ap) + def create_driver() = new TDriver + def create_processor(context: TContext) = + new GroupCorpusFileProcessor(context, driver) + + var processor: GroupCorpusFileProcessor = _ + override def init(context: TContext) { + super.init(context) + processor = create_processor(context) + processor.read_schema_from_textdb(driver.get_file_handler, + driver.params.input_dir) + context.progress + } +} + +class GroupCorpusMapper extends + Mapper[Object, Text, Text, Text] with GroupCorpusMapReducer { + type TContext = Mapper[Object, Text, Text, Text]#Context + + class GroupCorpusMapFileProcessor( + map_context: TContext, + driver: GroupCorpusDriver + ) extends GroupCorpusFileProcessor(map_context, driver) { + override def process_row(fieldvals: Seq[String]) = { + map_context.write( + new Text(schema.get_field(fieldvals, driver.params.field)), + new Text(fieldvals.mkString("\t"))) + (true, true) + } + } + + override def create_processor(context: TContext) = + new GroupCorpusMapFileProcessor(context, driver) + + override def setup(context: TContext) { init(context) } + + override def map(key: Object, value: Text, context: TContext) { + processor.parse_row(value.toString) + context.progress + } +} + +class GroupCorpusReducer extends + Reducer[Text, Text, Text, NullWritable] with GroupCorpusMapReducer { + type TContext = Reducer[Text, Text, Text, NullWritable]#Context + + def average_coords(coords: Iterable[(Double, Double)]) = { + val (lats, longs) = coords.unzip + // FIXME: Insufficiently correct! 
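+    // A plain arithmetic mean misbehaves for longitudes that straddle the
+    // +/-180 boundary (e.g. +179 and -177 would average to +1 rather than
+    // -179); see the renumber-and-compare sketch in the header comment of
+    // this file for one way this could be made more robust.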
+ "%s,%s" format (mean(lats.toSeq), mean(longs.toSeq)) + } + + def compute_most_common(values: mutable.Buffer[String]) = { + val minimum_percent = 0.25 + val len = values.length + val counts = intmap[String]() + for (v <- values) + counts(v) += 1 + val other = counts.values.filter(_.toDouble/len < minimum_percent).sum + counts("other") = other + val sorted = counts.toSeq.sortWith(_._2 > _._2) + (for ((v, count) <- sorted if count.toDouble/len >= minimum_percent) + yield v).mkString(",") + } + + def combine_text(values: mutable.Buffer[String]) = { + values.mkString(" EOT ") + " EOT " + } + + def combine_counts(values: mutable.Buffer[String]) = { + val counts = intmap[String]() + val split_last_colon = "(.*):([^:]*)".r + for (vv <- values; v <- vv.split(" ")) { + v match { + case split_last_colon(word, count) => counts(word) += count.toInt + case _ => + warning("Saw bad item in counts field, without a colon in it: %s", + v) + } + } + (for ((word, count) <- counts) yield "%s:%s" format (word, count)). + mkString(" ") + } + + override def setup(context: TContext) { init(context) } + + override def reduce(key: Text, values: java.lang.Iterable[Text], + context: TContext) { + var num_tweets = 0 + val coords = mutable.Buffer[(Double, Double)]() + val times = mutable.Buffer[Long]() + val userids = mutable.Buffer[String]() + val reply_usernames = mutable.Buffer[String]() + val reply_userids = mutable.Buffer[String]() + val anchors = mutable.Buffer[String]() + val langs = mutable.Buffer[String]() + val countses = mutable.Buffer[String]() + val texts = mutable.Buffer[String]() + for (vv <- values) { + num_tweets += 1 + val fieldvals = vv.toString.split("\t") + for (v <- processor.schema.fieldnames zip fieldvals) { + v match { + case ("coord", coord) => { + val Array(lat, long) = coord.split(",") + coords += ((lat.toDouble, long.toDouble)) + } + case ("time", time) => times += time.toLong + case ("userid", userid) => userids += userid + case ("reply_username", reply_username) => + reply_usernames += reply_username + case ("reply_userid", reply_userid) => + reply_userids += reply_userid + case ("anchor", anchor) => anchors += anchor + case ("lang", lang) => langs += lang + case ("counts", counts) => countses += counts + case ("text", text) => texts += text + case (_, _) => { } + } + } + } + val output = mutable.Buffer[(String, String)]() + for (field <- processor.schema.fieldnames) { + field match { + case "coord" => output += (("coord", average_coords(coords))) + case "time" => { + output += (("mintime", times.min.toString)) + output += (("maxtime", times.max.toString)) + output += + (("avgtime", mean(times.map(_.toDouble).toSeq).toLong.toString)) + } + case "userid" => output += (("userid", compute_most_common(userids))) + case "reply_userid" => + output += (("reply_userid", compute_most_common(reply_userids))) + case "reply_username" => + output += (("reply_username", compute_most_common(reply_usernames))) + case "anchor" => output += (("anchor", compute_most_common(anchors))) + case "lang" => output += (("lang", compute_most_common(langs))) + case "counts" => output += (("counts", combine_counts(countses))) + case "text" => output += (("text", combine_text(texts))) + case _ => { } + } + } + val (outkeys, outvalues) = output.unzip + context.write(new Text("%s\t%s\t%s" format (key.toString, num_tweets, + outvalues.mkString("\t"))), null) + } +} + +object GroupCorpus extends + ExperimentDriverApp("GroupCorpus") with HadoopTextDBApp { + type TDriver = GroupCorpusDriver + + override def description = 
+"""Group rows in a corpus according to the value of a field (e.g. the "user" +field). The "text" and "counts" fields are combined appropriately. +""" + + // FUCKING TYPE ERASURE + def create_param_object(ap: ArgParser) = new TParam(ap) + def create_driver() = new TDriver() + + def corpus_suffix = driver.params.input_suffix + def corpus_dirs = Seq(driver.params.input_dir) + + override def initialize_hadoop_input(job: Job) { + super.initialize_hadoop_input(job) + FileOutputFormat.setOutputPath(job, new Path(driver.params.output_dir)) + } + + def initialize_hadoop_classes(job: Job) { + job.setJarByClass(classOf[GroupCorpusMapper]) + job.setMapperClass(classOf[GroupCorpusMapper]) + job.setReducerClass(classOf[GroupCorpusReducer]) + job.setOutputKeyClass(classOf[Text]) + job.setOutputValueClass(classOf[NullWritable]) + job.setMapOutputValueClass(classOf[Text]) + } +} + diff --git a/src/main/scala/opennlp/fieldspring/preprocess/ParseTweets.scala b/src/main/scala/opennlp/fieldspring/preprocess/ParseTweets.scala new file mode 100644 index 0000000..e058b2a --- /dev/null +++ b/src/main/scala/opennlp/fieldspring/preprocess/ParseTweets.scala @@ -0,0 +1,2222 @@ +// ParseTweets.scala +// +// Copyright (C) 2012 Stephen Roller, The University of Texas at Austin +// Copyright (C) 2012 Ben Wing, The University of Texas at Austin +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +/////////////////////////////////////////////////////////////////////////////// + +package opennlp.fieldspring.preprocess + +/* + * This program reads in tweets and processes them, optionally filtering + * and/or grouping them. Normally, tweets are read in JSON format, as + * directly pulled from the twitter API, and output in "textdb" format, + * i.e. as a simple TAB-separated database with one record per line and + * a separate schema file specifying the column names. However, other + * input and output formats can be used. The output normally includes the + * text of the tweets as well as unigrams and/or n-grams. + * + * See README.txt in the top-level directory for more info. + */ + +import collection.JavaConversions._ +import collection.mutable + +import java.io._ +import java.lang.Double.isNaN +import Double.NaN +import java.text.{SimpleDateFormat, ParseException} +import java.util.Date + +import net.liftweb +import org.apache.commons.logging +import org.apache.hadoop.fs.{FileSystem=>HFileSystem,_} + +import com.nicta.scoobi.Scoobi._ + +import opennlp.fieldspring.{util => tgutil} +import tgutil.Twokenize +import tgutil.argparser._ +import tgutil.collectionutil._ +import tgutil.textdbutil._ +import tgutil.ioutil.FileHandler +import tgutil.hadoop.HadoopFileHandler +import tgutil.printutil._ +import tgutil.textutil.with_commas +import tgutil.timeutil._ + +class ParseTweetsParams(ap: ArgParser) extends + ScoobiProcessFilesParams(ap) { + var grouping = ap.option[String]("grouping", "g", + default = "none", + choices = Seq("user", "time", "file", "none"), + help="""Mode for grouping tweets in the output. 
There are currently + four methods of grouping: `user`, `time` (i.e. all tweets within a given + timeslice, specified with `--timeslice`), `file` (all tweets within a + given input file) and `none` (no grouping; tweets are passed through + directly, after duplicated tweets have been removed). Default is + `%default`. See also `--filter-grouping`.""") + var filter_grouping = ap.option[String]("filter-grouping", "fg", + choices = Seq("user", "time", "file", "none"), + help="""Mode for grouping tweets for filtering by group (using + `--filter-groups` or `--cfilter-groups`). The possible modes are the + same as in `--grouping`. The default is the same as `--grouping`, but + it is possible to specify a different type of grouping when applying the + group-level filters.""") + var input_format = ap.option[String]("input-format", "if", + choices = Seq("json", "textdb", "raw-lines"), + default = "json", + help="""Format for input of tweets. Possibilities are + + -- `json` (Read in JSON-formatted tweets.) + + -- `textdb` (Read in data in textdb format, i.e. as a simple database + with one record per line, fields separated by TAB characters, and a + separate schema file indicating the names of the columns. Typically this + will result from a previous run of ParseTweets using the default textdb + output format.) + + -- `raw-lines` (Read in lines of raw text and treat them as "tweets". The + tweets will have no useful information in them except for the text and + file name, but this can still be useful for things like generating n-grams.) + """) + var output_format = ap.option[String]("output-format", "of", + choices = Seq("textdb", "stats", "json"), + default = "textdb", + help="""Format for output of tweets or tweet groups. Possibilities are + + -- `textdb` (Store in textdb format, i.e. as a simple database with one + record per line, fields separated by TAB characters, and a separate schema + file indicating the names of the columns.) + + -- `json` (Simply output JSON-formatted tweets directly, exactly as + received.) + + -- `stats` (Textdb-style output with statistics on the tweets, users, etc. + rather than outputting the tweets themselves.)""") + var corpus_name = ap.option[String]("corpus-name", + help="""Name of output corpus; for identification purposes. + Default to name taken from input directory.""") + var split = ap.option[String]("split", default = "training", + help="""Split (training, dev, test) to place data in. Default %default.""") + var timeslice_float = ap.option[Double]("timeslice", "time-slice", + default = 6.0, + help="""Number of seconds per timeslice when `--grouping=time`. + Can be a fractional number. Default %default.""") + // The following is set based on --timeslice + var timeslice: Long = _ + var filter_tweets = ap.option[String]("filter-tweets", + help="""Boolean expression used to filter tweets to be output. +Expression consists of one or more expressions, joined by the operators +AND, OR and NOT (which must be written all-caps to be recognized). The +order of precedence (from high to low) is + +-- comparison operators (<, <=, >, >=, WITHIN) +-- NOT +-- AND +-- OR + +Parentheses can be used for grouping or precedence. Expressions consist of +one of the following: + +-- A sequence of words, which matches a tweet if and only if that exact + sequence is found in the tweet. Matching happens on the word-by-word + level, after a tweet has been tokenized. Matching is case-insensitive; + use '--cfilter-tweets' for case-sensitive matching. 
Note that the use of + '--preserve-case' has no effect on the case sensitivity of filtering; + it rather affects whether the output is converted to lowercase or + left as-is. Any word that is quoted is treated as a literal + regardless of the characters in it; this can be used to treat words such + as "AND" literally. + +-- An expression specifying a one-sided restriction on the time of the tweet, + such as 'TIME < 20100802180502PDT' (earlier than August 2, 2010, 18:05:02 + Pacific Daylight Time) or 'TIME >= 2011:12:25:0905pm (at least as late as + December 25, 2011, 9:05pm local time). The operators can be <, <=, > or >=. + As for the time, either 12-hour or 24-hour time can be given, colons can + optionally be inserted anywhere for readability, the time zone can be + omitted or specified, and part or all of the time of day (hours, minutes, + seconds) can be omitted. Years must always be full (i.e. 4 digits). + +-- An expression specifying a two-sided restriction on the time of the tweet + (i.e. the tweet's time must be within a given interval). Either of the + following forms are allowed: + + -- 'TIME WITHIN 2010:08:02:1805PDT/2h' + -- 'TIME WITHIN (2010:08:02:0500pmPDT 2010:08:03:0930amPDT)' + + That is, the interval of time can be given either as a point of time plus + an offset, or as two points of time. The offset can be specified in + various ways, e.g. + + -- '1h' or '+1h' (1 hour) + -- '3m2s' or '3m+2s' or '+3m+2s' (3 minutes, 2 seconds) + -- '2.5h' (2.5 hours, i.e. 2 hours 30 minutes) + -- '5d2h30m' (5 days, 2 hours, 30 minutes) + -- '-3h' (-3 hours, i.e. 3 hours backwards from a given point of time) + -- '5d-1s' (5 days less 1 second) + + That is, an offset is a combination of individual components, each of + which is a number (possibly fractional or negative or with a prefixed + plus sign, which is ignored) plus a unit: 'd' = days, 'h' = hours, + 'm' = minutes, 's' = seconds. Negative offsets are allowed, to indicate + an interval backwards from a reference point. + +Examples: + +--filter-tweets "mitt romney OR obama" + +Look for any tweets containing the sequence "mitt romney" (in any case) or +"Obama". + +--filter-tweets "mitt AND romney OR barack AND obama" + +Look for any tweets containing either the words "mitt" and "romney" (in any +case and anywhere in the tweet) or the words "barack" and "obama". + +--filter-tweets "hillary OR bill AND clinton" + +Look for any tweets containing either the word "hillary" or both the words +"bill" and "clinton" (anywhere in the tweet). + +--filter-tweets "(hillary OR bill) AND clinton" + +Look for any tweets containing the word "clinton" as well as either the words +"bill" or "hillary".""") + var cfilter_tweets = ap.option[String]("cfilter-tweets", + help="""Boolean expression used to filter tweets to be output, with + case-sensitive matching. Format is identical to `--filter-tweets`.""") + var output_fields = ap.option[String]("output-fields", + default="default", + help="""Fields to output in textdb format. This should consist of one or + more directives, separated by spaces or commas. Directives are processed + sequentially. Each directive should be one of + + 1. A field name, meaning to include that field + + 2. A field set, meaning to include those fields; currently the recognized + sets are 'big-fields' (fields that may become arbitrarily large, + including 'user-mentions', 'retweets', 'hashtags', 'urls', 'text', + 'count') and 'small-fields' (all remaining fields). + + 3. 
A field name or field set with a preceding + sign, same as if the + + sign were omitted. + + 4. A field name or field set with a preceding - sign, meaning to exclude + the respective field(s). + + 5. The directive 'all', meaning to include all fields, canceling any + previous directives. + + 6. The directive 'none', meaning to include no fields, canceling any + previous directives. + + 7. The directive 'default', meaning to set the current fields to output + to the default (which may vary depending on other settings), canceling + any previous directives. + + Currently recognized fields: + + 'path': Path of file that the tweet came from + + 'user': User name + + 'id': Tweet ID + + 'min-timestamp': Earliest timestamp + + 'max-timestamp': Latest timestamp + + 'geo-timestamp': Earliest timestamp of tweet with corresponding location + + 'coord': Best latitude/longitude pair (corresponding to earliest tweet), + separated by a comma + + 'followers': Max followers + + 'following': Max following + + 'lang': Language used + + 'numtweets': Number of tweets merged + + 'user-mentions': List of @-mentions of users, along with counts + + 'retweets': List of users from which tweets were retweeted, with counts + + 'hashtags': List of hashtags, with counts + + 'urls': List of URL's, with counts + + 'text': Actual text of all tweets, separate by >> signs + + 'counts': All words, with counts + + The default is as follows: + + 1. For --input-format=raw-lines, include only 'path', 'numtweets', + 'text' and 'counts'. (Only these fields are meaningful.) + 2. Else, for --grouping=file, all small fields. + 3. Else everything. +""") + var filter_groups = ap.option[String]("filter-groups", + help="""Boolean expression used to filter on the grouped-tweet level. + This is like `--filter-tweets` but filters groups of tweets (grouped + according to `--grouping`), such that groups of tweets will be accepted + if *any* tweet matches the filter.""") + var cfilter_groups = ap.option[String]("cfilter-groups", + help="""Same as `--filter-groups` but does case-sensitive matching.""") + + var preserve_case = ap.flag("preserve-case", + help="""Don't lowercase words. This preserves the difference + between e.g. the name "Mark" and the word "mark".""") + var max_ngram = ap.option[Int]("max-ngram", "max-n-gram", "ngram", "n-gram", + default = 1, + help="""Largest size of n-grams to create. Default 1, i.e. distribution + only contains unigrams.""") + + // FIXME!! The following should really be controllable using the + // normal filter mechanism, rather than the special-case hacks below. + var geographic_only = ap.flag("geographic-only", "go", + help="""Filter out tweets that don't have a geotag.""") + var north_america_only = ap.flag("north-america-only", "nao", + help="""Filter out tweets that don't have a geotag or that have a + geotag outside of North America.""") + var filter_spammers = ap.flag("filter-spammers", "fs", + help="""Filter out tweets that don't have number of tweets within + [10, 1000], or number of followers >= 10, or number of people + following within [5, 1000]. This is an attempt to filter out + "spammers", i.e. 
accounts not associated with normal users.""") + + import ParseTweets.Tweet + + private def match_field(field: String) = + if (Tweet.all_fields contains field) + Some(Seq(field)) + else if (field == "big-fields") + Some(Tweet.big_fields) + else if (field == "small-fields") + Some(Tweet.small_fields) + else + None + + private object Field { + def unapply(spec: String) = { + if (spec.length > 0 && spec.head == '+') + match_field(spec.tail) + else + match_field(spec) + } + } + + private object NegField { + def unapply(spec: String) = { + if (spec.length > 0 && spec.head == '-') + match_field(spec.tail) + else + None + } + } + + def parse_output_fields(fieldspec: String) = { + val directives = fieldspec.split("[ ,]") + val incfields = mutable.LinkedHashSet[String]() + for (direc <- Array("default") ++ directives) { + direc match { + case "default" => { + incfields.clear() + incfields ++= ParseTweets.Tweet.default_fields(this) + } + case "all" => { + incfields.clear() + incfields ++= ParseTweets.Tweet.all_fields + } + case "none" => { + incfields.clear() + } + case Field(fields) => { + incfields ++= fields + } + case NegField(fields) => { + incfields --= fields + } + case x => { ap.usageError( + "Unrecognized directive '%s' in --output-fields" format x) + } + } + } + incfields.toSeq + } + var included_fields: Seq[String] = _ + var input_schema: Schema = _ + + /* Whether we are doing tweet-level filtering. To check whether doing + group-level filtering, check whether filter_grouping == "none". */ + var has_tweet_filtering: Boolean = _ + + override def check_usage() { + timeslice = (timeslice_float * 1000).toLong + has_tweet_filtering = filter_tweets != null || cfilter_tweets != null + val has_group_filtering = filter_groups != null || cfilter_groups != null + if (filter_grouping == null) { + if (has_group_filtering) + filter_grouping = grouping + else + filter_grouping = "none" + } + if (has_group_filtering && filter_grouping == "none") + ap.usageError("group-level filtering not possible when `--filter-grouping=none`") + if (!has_group_filtering && filter_grouping != "none") + ap.usageError("when not doing group-level filtering, `--filter-grouping` must be `none`") + if (output_format == "json" && grouping != "none") + ap.usageError("output grouping (--grouping) not possible when output format is JSON") + included_fields = parse_output_fields(output_fields) + } +} + +object ParseTweets extends ScoobiProcessFilesApp[ParseTweetsParams] { + + // TweetID = Twitter's numeric ID used to uniquely identify a tweet. + type TweetID = Long + + type Timestamp = Long + + val empty_map = Map[String, Int]() + + /** + * Data for a tweet or grouping of tweets. 
+ * + * @param json Raw JSON for tweet; only stored when --output-format=json + * @param path Path of file that the tweet came from + * @param text Text for tweet or tweets (a Seq in case of multiple tweets) + * @param user User name (FIXME: or one of them, when going by time; should + * do something smarter) + * @param id Tweet ID + * @param min_timestamp Earliest timestamp + * @param max_timestamp Latest timestamp + * @param geo_timestamp Earliest timestamp of tweet with corresponding + * location + * @param lat Best latitude (corresponding to the earliest tweet) + * @param long Best longitude (corresponding to the earliest tweet) + * @param followers Max followers + * @param following Max following + * @param lang Language used + * @param numtweets Number of tweets merged + * @param user_mentions Item-count map of all @-mentions + * @param retweets Like `user_mentions` but only for retweet mentions + * @param hashtags Item-count map of hashtags + * @param urls Item-count map of URL's + */ + case class Tweet( + json: String, + path: String, + text: Seq[String], + user: String, + id: TweetID, + min_timestamp: Timestamp, + max_timestamp: Timestamp, + geo_timestamp: Timestamp, + lat: Double, + long: Double, + followers: Int, + following: Int, + lang: String, + numtweets: Int, + user_mentions: Map[String, Int], + retweets: Map[String, Int], + hashtags: Map[String, Int], + urls: Map[String, Int] + /* NOTE: If you add a field here, you need to update a bunch of places, + including (of course) wherever a Tweet is created, but also + some less obvious places. In all: + + -- the doc string just above + -- the definition of to_row() and Tweet.row_fields() + -- parse_json_lift() below + -- merge_records() below + -- TweetFilterParser.main() below + */ + ) { + def to_row(tokenize_act: TokenizeCountAndFormat, opts: ParseTweetsParams) = { + import Encoder.{long => elong, _} + val optfields = mutable.Buffer[String]() + for (field <- opts.included_fields) { + val fieldval = field match { + case "user" => string(user) + case "id" => elong(id) + case "path" => string(path) + case "min-timestamp" => timestamp(min_timestamp) + case "max-timestamp" => timestamp(max_timestamp) + case "geo-timestamp" => timestamp(geo_timestamp) + case "coord" => { + // Latitude/longitude need to be combined into a single field, + // but only if both actually exist. 
+ if (!isNaN(lat) && !isNaN(long)) "%s,%s" format (lat, long) + else "" + } + case "followers" => int(followers) + case "following" => int(following) + case "lang" => string(lang) + case "numtweets" => int(numtweets) + case "user-mentions" => count_map(user_mentions) + case "retweets" => count_map(retweets) + case "hashtags" => count_map(hashtags) + case "urls" => count_map(urls) + case "text" => seq_string(text) + case "counts" => tokenize_act.emit_ngrams(text) + } + optfields += fieldval + } + optfields.toSeq mkString "\t" + } + } + + object Tweet extends ParseTweetsAction { + val operation_category = "Tweet" + + val small_fields = + Seq("user", "id", "path", "min-timestamp", "max-timestamp", + "geo-timestamp", "coord", "followers", "following", "lang", + "numtweets") + + val big_fields = + Seq("user-mentions", "retweets", "hashtags", "urls", "text", "counts") + + val all_fields = small_fields ++ big_fields + + def default_fields(opts: ParseTweetsParams) = { + if (opts.input_format == "raw-lines") + Seq("path", "numtweets", "text", "counts") + else if (opts.grouping == "file") + small_fields + else + all_fields + } + + def row_fields(opts: ParseTweetsParams) = opts.included_fields + + def from_row(schema: Schema, fields: Seq[String]) = { + var json = "" + var path = "" + var text = Seq[String]() + var user = "" + var id: TweetID = 0L + var min_timestamp: Timestamp = 0L + var max_timestamp: Timestamp = 0L + var geo_timestamp: Timestamp = 0L + var lat = NaN + var long = NaN + var followers = 0 + var following = 0 + var lang = "" + var numtweets = 1 + var user_mentions = empty_map + var retweets = empty_map + var hashtags = empty_map + var urls = empty_map + + import Decoder.{long => dlong,_} + schema.check_values_fit_schema(fields) + for ((name, x) <- schema.fieldnames zip fields) { + name match { + case "json" => json = string(x) + case "path" => path = string(x) + case "text" => text = seq_string(x) + case "user" => user = string(x) + case "id" => id = dlong(x) + case "min-timestamp" => min_timestamp = dlong(x) + case "max-timestamp" => max_timestamp = dlong(x) + case "geo-timestamp" => geo_timestamp = dlong(x) + case "coord" => { + if (x == "") + { lat = NaN; long = NaN } + else { + val Array(xlat, xlong) = x.split(",", -1) + lat = double(xlat) + long = double(xlong) + } + } + case "followers" => followers = int(x) + case "following" => following = int(x) + case "lang" => lang = string(x) + case "numtweets" => numtweets = int(x) + case "user-mentions" => user_mentions = count_map(x) + case "retweets" => retweets = count_map(x) + case "hashtags" => hashtags = count_map(x) + case "urls" => urls = count_map(x) + case "counts" => + { } // We don't record counts as they're built from text + case _ => + logger.warn("Unrecognized field %s with value %s" format + (name, x)) + } + } + Tweet(json, path, text, user, id, min_timestamp, max_timestamp, + geo_timestamp, lat, long, followers, following, lang, numtweets, + user_mentions, retweets, hashtags, urls) + } + + def from_raw_text(path: String, text: String) = { + Tweet("", path, Seq(text), "", 0L, 0L, 0L, 0L, NaN, NaN, 0, 0, "", + 1, empty_map, empty_map, + empty_map, empty_map) + } + } + implicit val tweetWire = mkCaseWireFormat(Tweet.apply _, Tweet.unapply _) + + /** + * A tweet along with ancillary data used for merging and filtering. + * + * @param output_key Key used for output grouping (--grouping). + * @param filter_key Key used for filter grouping (--filter-grouping). 
+ * @param matches Whether the tweet matches the group-level boolean filters + * (if any). + * @param tweet The tweet itself. + */ + case class Record( + output_key: String, + filter_key: String, + matches: Boolean, + tweet: Tweet + ) + implicit val recordWire = mkCaseWireFormat(Record.apply _, Record.unapply _) + + /** + * A generic action in the ParseTweets app. + */ + trait ParseTweetsAction extends ScoobiProcessFilesAction { + val progname = "ParseTweets" + + def create_parser(expr: String, foldcase: Boolean) = { + if (expr == null) null + else new TweetFilterParser(foldcase).parse(expr) + } + } + + import scala.util.parsing.combinator.lexical.StdLexical + import scala.util.parsing.combinator.syntactical._ + import scala.util.parsing.input.CharArrayReader.EofCh + + /** + * A class used for filtering tweets using a boolean expression. + * Parsing of the boolean expression uses Scala parsing combinators. + * + * To use, create an instance of this class; then call `parse` to + * parse an expression into an abstract syntax tree object. Then use + * the `matches` method on this object to match against a tweet. + */ + class TweetFilterParser(foldcase: Boolean) extends StandardTokenParsers { + sealed abstract class Expr { + /** + * Check if this expression matches the given sequence of words. + */ + def matches(tweet: Tweet): Boolean = { + // FIXME: When a filter is present, we may end up calling Twokenize + // 2 or 3 times (once when generating words or n-grams, once when + // implementing tweet-level filters, and once when implementing + // user-level filters). But the alternative is to pass around the + // tokenized text, which might not be any faster in a Hadoop env. + val tokenized = tweet.text flatMap (Twokenize(_)) + if (foldcase) + matches(tweet, tokenized map (_.toLowerCase)) + else + matches(tweet, tokenized) + } + + // Not meant to be called externally. Actually implement the matching, + // with the text explicitly given (so it can be downcased to implement + // case-insensitive matching). 
+ def matches(tweet: Tweet, text: Seq[String]): Boolean + } + + case class EConst(value: Seq[String]) extends Expr { + def matches(tweet: Tweet, text: Seq[String]) = text containsSlice value + } + + case class EAnd(left:Expr, right:Expr) extends Expr { + def matches(tweet: Tweet, text: Seq[String]) = + left.matches(tweet, text) && right.matches(tweet, text) + } + + case class EOr(left:Expr, right:Expr) extends Expr { + def matches(tweet: Tweet, text: Seq[String]) = + left.matches(tweet, text) || right.matches(tweet, text) + } + + case class ENot(e:Expr) extends Expr { + def matches(tweet: Tweet, text: Seq[String]) = + !e.matches(tweet, text) + } + + def time_compare(time1: Timestamp, op: String, time2: Timestamp) = { + op match { + case "<" => time1 < time2 + case "<=" => time1 <= time2 + case ">" => time1 > time2 + case ">=" => time1 >= time2 + } + } + + def time_compare(tw: Tweet, op: String, time: Timestamp): Boolean = { + assert(tw.min_timestamp == tw.max_timestamp) + time_compare(tw.min_timestamp, op, time) + } + + case class TimeCompare(op: String, time: Timestamp) extends Expr { + def matches(tweet: Tweet, text: Seq[String]) = + time_compare(tweet, op, time) + } + + case class TimeWithin(interval: (Timestamp, Timestamp)) extends Expr { + def matches(tweet: Tweet, text: Seq[String]) = { + val (start, end) = interval + time_compare(tweet, ">=", start) && + time_compare(tweet, "<", end) + } + } + + // NOT CURRENTLY USED, but potentially useful as an indicator of how to + // implement a parser for numbers. +// class ExprLexical extends StdLexical { +// override def token: Parser[Token] = floatingToken | super.token +// +// def floatingToken: Parser[Token] = +// rep1(digit) ~ optFraction ~ optExponent ^^ +// { case intPart ~ frac ~ exp => NumericLit( +// (intPart mkString "") :: frac :: exp :: Nil mkString "")} +// +// def chr(c:Char) = elem("", ch => ch==c ) +// def sign = chr('+') | chr('-') +// def optSign = opt(sign) ^^ { +// case None => "" +// case Some(sign) => sign +// } +// def fraction = '.' 
~ rep(digit) ^^ { +// case dot ~ ff => dot :: (ff mkString "") :: Nil mkString "" +// } +// def optFraction = opt(fraction) ^^ { +// case None => "" +// case Some(fraction) => fraction +// } +// def exponent = (chr('e') | chr('E')) ~ optSign ~ rep1(digit) ^^ { +// case e ~ optSign ~ exp => +// e :: optSign :: (exp mkString "") :: Nil mkString "" +// } +// def optExponent = opt(exponent) ^^ { +// case None => "" +// case Some(exponent) => exponent +// } +// } + + class FilterLexical extends StdLexical { + // see `token` in `Scanners` + override def token: Parser[Token] = + ( delim + | unquotedWordChar ~ rep( unquotedWordChar ) ^^ + { case first ~ rest => processIdent(first :: rest mkString "") } + | '\"' ~ rep( quotedWordChar ) ~ '\"' ^^ + { case '\"' ~ chars ~ '\"' => StringLit(chars mkString "") } + | EofCh ^^^ EOF + | '\"' ~> failure("unclosed string literal") + | failure("illegal character") + ) + + def isPrintable(ch: Char) = + !ch.isControl && !ch.isSpaceChar && !ch.isWhitespace && ch != EofCh + def isPrintableNonDelim(ch: Char) = + isPrintable(ch) && ch != '(' && ch != ')' + def unquotedWordChar = elem("unquoted word char", + ch => ch != '"' && isPrintableNonDelim(ch)) + def quotedWordChar = elem("quoted word char", + ch => ch != '"' && ch != '\n' && ch != EofCh) + + // // see `whitespace in `Scanners` + // def whitespace: Parser[Any] = rep( + // whitespaceChar + // // | '/' ~ '/' ~ rep( chrExcept(EofCh, '\n') ) + // ) + + override protected def processIdent(name: String) = + if (reserved contains name) Keyword(name) else StringLit(name) + } + + override val lexical = new FilterLexical + lexical.reserved ++= List("AND", "OR", "NOT", "TIME", "WITHIN") + lexical.delimiters ++= List("(", ")", "<", "<=", ">", ">=") + + def word = stringLit ^^ { + s => EConst(Seq(if (foldcase) s.toLowerCase else s)) + } + + def words = word.+ ^^ { + x => EConst(x.flatMap(_ match { case EConst(y) => y })) + } + + def compare_op = ( "<=" | "<" | ">=" | ">" ) + + def time = stringLit ^^ { s => (s, parse_date(s)) } ^? ( + { case (_, Some(x)) => x }, + { case (s, None) => "Unable to parse date %s" format s } ) + + def short_interval = stringLit ^^ { parse_date_interval(_) } ^? 
( + { case (Some((from, to)), "") => (from, to) }, + { case (None, errmess) => errmess } ) + + def full_interval = "(" ~> time ~ time <~ ")" ^^ { + case from ~ to => (from, to) } + + def interval = (short_interval | full_interval) + + def time_compare = "TIME" ~> compare_op ~ time ^^ { + case op ~ time => TimeCompare(op, time) + // case op~time if parse_time(time) => TimeCompare(op, time) + } + + def time_within = "TIME" ~> "WITHIN" ~> interval ^^ { + interval => TimeWithin(interval) + } + + def parens: Parser[Expr] = "(" ~> expr <~ ")" + + def not: Parser[ENot] = "NOT" ~> term ^^ { ENot(_) } + + def term = ( words | parens | not | time_compare | time_within ) + + def andexpr = term * ( + "AND" ^^^ { (a:Expr, b:Expr) => EAnd(a,b) } ) + + def orexpr = andexpr * ( + "OR" ^^^ { (a:Expr, b:Expr) => EOr(a,b) } ) + + def expr = ( orexpr | term ) + + def maybe_parse(s: String) = { + val tokens = new lexical.Scanner(s) + phrase(expr)(tokens) + } + + def parse(s: String): Expr = { + maybe_parse(s) match { + case Success(tree, _) => tree + case e: NoSuccess => + throw new + IllegalArgumentException("Bad syntax: %s: %s" format (s, e)) + } + } + + def test(exprstr: String, tweet: Tweet) = { + maybe_parse(exprstr) match { + case Success(tree, _) => + println("Tree: "+tree) + val v = tree.matches(tweet) + println("Eval: "+v) + case e: NoSuccess => errprint("%s\n" format e) + } + } + + //A main method for testing + def main(args: Array[String]) = { + val text = args(2) + val timestamp = parse_date(args(1)) match { + case Some(time) => time + case None => throw new IllegalArgumentException( + "Unable to parse date %s" format args(1)) + } + val tweet = + Tweet("", "", Seq(text), "user", 0, timestamp, timestamp, + timestamp, NaN, NaN, 0, 0, "unknown", 1, + empty_map, empty_map, empty_map, empty_map) + test(args(0), tweet) + } + } + + class ParseAndUniquifyTweets( + opts: ParseTweetsParams, + set_counters: Boolean = true + ) extends ParseTweetsAction { + + val operation_category = "Parse" + + def maybe_counter(counter: String, amount: Long = 1) { + if (set_counters) + bump_counter(counter, amount) + } + + // Used internally to force an exit when a problem in parse_json_lift + // occurs. + private class ParseJSonExit extends Exception { } + + /** + * Parse a JSON line into a tweet, using Lift. + * + * @return status and tweet. + */ + def parse_json_lift(path: String, line: String): (String, Tweet) = { + + /** + * Convert a Twitter timestamp, e.g. "Tue Jun 05 14:31:21 +0000 2012", + * into a time in milliseconds since the Epoch (Jan 1 1970, or so). + */ + def parse_time(timestring: String): Timestamp = { + val sdf = new SimpleDateFormat("EEE MMM dd HH:mm:ss ZZZZZ yyyy") + try { + sdf.parse(timestring) + sdf.getCalendar.getTimeInMillis + } catch { + case pe: ParseException => { + maybe_counter("unparsable date") + logger.warn("Error parsing date %s on line %s: %s\n%s" format ( + timestring, lineno, line, pe)) + 0 + } + } + } + + def parse_problem(e: Exception) = { + logger.warn("Error parsing line %s: %s\n%s" format ( + lineno, line, stack_trace_as_string(e))) + ("error", null) + } + + /** + * Retrieve a string along a path, checking to make sure the path + * exists. + */ + def force_string(value: liftweb.json.JValue, fields: String*) = { + var fieldval = value + var path = List[String]() + for (field <- fields) { + path :+= field + fieldval \= field + if (fieldval == liftweb.json.JNothing) { + val fieldpath = path mkString "." 
+ maybe_counter("ERROR: tweet with missing field %s" format fieldpath) + warning(line, "Can't find field path %s in tweet", fieldpath) + throw new ParseJSonExit + } + } + fieldval.values.toString + } + + /** + * Retrieve the list of entities of a particular type from a tweet, + * along with the indices referring to the entity. + * + * @param key the key referring to the type of entity + * @param subkey the subkey within the entity to return as the "value" + * of the entity + */ + def retrieve_entities_with_indices(parsed: liftweb.json.JValue, + key: String, subkey: String) = { + // Retrieve the raw list of entities based on the key. + val entities_raw = + (parsed \ "entities" \ key values). + asInstanceOf[List[Map[String, Any]]] + // For each entity, fetch the user actually mentioned. Sort and + // convert into a map counting mentions. Do it this way to count + // multiple mentions properly. + for { ent <- entities_raw + rawvalue = ent(subkey) + if rawvalue != null + value = rawvalue.toString + indices = ent("indices").asInstanceOf[List[Number]] + start = indices(0).intValue + end = indices(1).intValue + if { + if (value.length == 0) { + maybe_counter("zero length %s/%s seen" format (key, subkey)) + warning(line, + "Zero-length %s/%s in interval [%d,%d], skipped", + key, subkey, start, end) + false + } else true + } + } + yield (value, start, end) + } + + /** + * Retrieve the list of entities of a particular type from a tweet, + * as a map with counts (so as to count multiple identical entities + * properly). + * + * @param key the key referring to the type of entity + * @param subkey the subkey within the entity to return as the "value" + * of the entity + */ + def retrieve_entities(parsed: liftweb.json.JValue, + key: String, subkey: String) = { + list_to_item_count_map( + for { (value, start, end) <- + retrieve_entities_with_indices(parsed, key, subkey) } + yield value + ) + } + + try { + /* The result of parsing is a JValue, which is an abstract type with + subtypes for all of the types of objects found in JSON: maps, arrays, + strings, integers, doubles, booleans. Corresponding to the atomic + types are JString, JInt, JDouble, JBool. Corresponding to arrays is + JArray (which underlyingly holds something like a List[JValue]), and + corresponding to maps is JObject. JObject underlyingly holds + something like a List[JField], and a JField is a structure holding a + 'name' (a string) and a 'value' (a JValue). The following + operations can be done on these objects: + + 1. For JArray, JObject and JField, the method 'children' yields + a List[JValue] of their children. For JObject, as mentioned + above, this will be a list of JField objects. For JField + objects, this will be a one-element list of whatever was + in 'value'. + 2. JArray, JObject and JField can be indexed like an array. + Indexing JArray is obvious. Indexing JObject gives a JField + object, and indexing JField tries to index the object in the + field's 'value' element. + 3. You can directly fetch the components of a JField using the + accessors 'name' and 'value'. + + 4. You can retrieve the underlying Scala form of any object using + 'values'. This returns JArrays as a List, JObject as a Map, and + JField as a tuple of (name, value). The conversion is deep, in + that subobjects are recursively converted. 
However, the result + type isn't necessarily useful -- for atoms you get more or less + what you expect (although BigInt instead of Int), and for JField + it's just a tuple, but for JArray and JObject the type of the + expression is a path-dependent type ending in 'Value'. So you + will probably have to cast it using asInstanceOf[]. (For Maps, + consider using Map[String, Any], since you don't necessarily know + the type of the different elements, which may vary from element + to element. + + 5. For JObject, you can do a field lookup using the "\" operator, + as shown below. This yields a JValue, onto which you can do + another "\" if it happens to be another JObject. (That's because + "\", like 'apply', 'values' and 'children' is defined on the + JValue itself. When "\" is given a non-existent field name, + or run on a JInt or other atom, the result is JNothing. + 6. You can also pass a class (one of the subclasses of JValue) to + "\" instead of a string, e.g. classOf[JInt], to find children + with this type. BEWARE: It looks only at the children returned + by the 'children' method, which (for objects) are all of type + JField, so looking for JInt won't help even if you have some + key-value pairs where the value is an integer. + 7. You can also use "\\" to do multi-level lookup, i.e. this + recursively traipses down 'children' and 'children' of 'children' + and does the equivalent of "\" on each. The return value is a + List, with 'values' applied to each element of the List to + convert to Scala objects. Beware that you only get the items + themselves, without the field names. That is, if you search for + e.g. a JInt, you'll get back a List[Int] with a lot of numbers + in it -- but no indication of where they came from or what the + associated field name was. + 8. There's also 'extract' for getting a JObject as a case class, + as well as 'map', 'filter', '++', and other standard methods for + collections. + */ + val parsed = liftweb.json.parse(line) + if ((parsed \ "delete" values) != None) { + maybe_counter("tweet deletion notices skipped") + ("delete", null) + } else if ((parsed \ "limit" values) != None) { + maybe_counter("tweet limit notices skipped") + ("limit", null) + } else { + val user = force_string(parsed, "user", "screen_name") + val timestamp = parse_time(force_string(parsed, "created_at")) + val raw_text = force_string(parsed, "text") + val text = raw_text.replaceAll("\\s+", " ") + val followers = force_string(parsed, "user", "followers_count").toInt + val following = force_string(parsed, "user", "friends_count").toInt + val tweet_id = force_string(parsed, "id_str") + val lang = force_string(parsed, "user", "lang") + val (lat, long) = + if ((parsed \ "coordinates" \ "type" values).toString != "Point") { + (NaN, NaN) + } else { + val latlong: List[Number] = + (parsed \ "coordinates" \ "coordinates" values). + asInstanceOf[List[Number]] + (latlong(1).doubleValue, latlong(0).doubleValue) + } + + /////////////// HANDLE ENTITIES + + /* Entity types: + + user_mentions: @-mentions in the text; subkeys are + screen_name = Username of user + name = Display name of user + id_str = ID of user, as a string + id = ID of user, as a number + indices = indices in text of @-mention, including the @ + + urls: URLs mentioned in the text; subkeys are + url = raw URL (probably a shortened reference to bit.ly etc.) + expanded_url = actual URL + display_url = display form of URL (without initial http://, cut + off after a point with \u2026 (Unicode ...) 
+ indices = indices in text of raw URL + (NOTE: all URL's in the JSON text have /'s quoted as \/, and + display_url may not be present) + + hashtags: Hashtags mentioned in the text; subkeys are + text = text of the hashtag + indices = indices in text of hashtag, including the # + + media: Embedded objects + type = "photo" for images + indices = indices of URL for image + url, expanded_url, display_url = similar to URL mentions + media_url = photo on p.twimg.com? + media_url_https = https:// alias for media_url + id_str, id = some sort of tweet ID or similar? + sizes = map with keys "small", "medium", "large", "thumb"; + each has subkeys: + resize = "crop" or "fit" + h = height (as number) + w = width (as number) + + Example of URL's for photo: + url = http:\/\/t.co\/AO3mRYaG + expanded_url = http:\/\/twitter.com\/alejandraoraa\/status\/215758169589284864\/photo\/1 + display_url = pic.twitter.com\/AO3mRYaG + media_url = http:\/\/p.twimg.com\/Av6G6YBCAAAwD7J.jpg + media_url_https = https:\/\/p.twimg.com\/Av6G6YBCAAAwD7J.jpg + +"id_str":"215758169597673472" + */ + + // Retrieve "user_mentions", which is a list of mentions, each + // listing the span of text, the user actually mentioned, etc. -- + // along with whether the mention is a retweet (by looking for + // "RT" before the mention) + val user_mentions_raw = retrieve_entities_with_indices( + parsed, "user_mentions", "screen_name") + val user_mentions_retweets = + for { (screen_name, start, end) <- user_mentions_raw + namelen = screen_name.length + retweet = (start >= 3 && + raw_text.slice(start - 3, start) == "RT ") + } yield { + // Subtract one because of the initial @ in the index reference + if (end - start - 1 != namelen) { + maybe_counter("wrong length interval for screen name seen") + warning(line, "Strange indices [%d,%d] for screen name %s, length %d != %d, text context is '%s'", + start, end, screen_name, end - start - 1, namelen, + raw_text.slice(start, end)) + } + (screen_name, retweet) + } + val user_mentions_list = + for { (screen_name, retweet) <- user_mentions_retweets } + yield screen_name + val retweets_list = + for { (screen_name, retweet) <- user_mentions_retweets if retweet } + yield screen_name + val user_mentions = list_to_item_count_map(user_mentions_list) + val retweets = list_to_item_count_map(retweets_list) + + val hashtags = retrieve_entities(parsed, "hashtags", "text") + val urls = retrieve_entities(parsed, "urls", "expanded_url") + // map + // { case (url, count) => (url.replace("\\/", "/"), count) } + + ("success", + Tweet(if (opts.output_format == "json") line else "", + path, Seq(text), user, tweet_id.toLong, timestamp, + timestamp, timestamp, lat, long, followers, following, lang, 1, + user_mentions, retweets, hashtags, urls)) + } + } catch { + case jpe: liftweb.json.JsonParser.ParseException => { + maybe_counter("ERROR: lift-json parsing error") + parse_problem(jpe) + } + case npe: NullPointerException => { + maybe_counter("ERROR: NullPointerException when parsing") + parse_problem(npe) + } + case nfe: NumberFormatException => { + maybe_counter("ERROR: NumberFormatException when parsing") + parse_problem(nfe) + } + case _: ParseJSonExit => ("error", null) + case e: Exception => { + maybe_counter("ERROR: %s when parsing" format e.getClass.getName) + parse_problem(e); throw e + } + } + } + + def get_string(value: Map[String, Any], l1: String) = { + value(l1).asInstanceOf[String] + } + + def get_2nd_level_value[T](value: Map[String, Any], l1: String, + l2: String) = { + 
value(l1).asInstanceOf[java.util.LinkedHashMap[String,Any]](l2). + asInstanceOf[T] + } + + /* + * Parse a line (JSON or textdb) into a tweet. Return `null` if + * unable to parse. + */ + def parse_line(pathline: (String, String)) = { + val (path, line) = pathline + maybe_counter("total lines") + lineno += 1 + // For testing + if (opts.debug) + logger.debug("parsing JSON: %s" format line) + if (line.trim == "") { + maybe_counter("blank lines skipped") + null + } + else { + maybe_counter("total tweets parsed") + val (status, tweet) = opts.input_format match { + case "raw-lines" => ("success", Tweet.from_raw_text(path, line)) + case "json" => parse_json_lift(path, line) + case "textdb" => + error_wrap(line, ("error", null: Tweet)) { line => + ("success", + Tweet.from_row(opts.input_schema, line.split("\t", -1))) + } + } + if (status == "error") { + maybe_counter("total tweets unsuccessfully parsed") + } else if (status == "success") { + maybe_counter("total tweets successfully parsed") + } else { + maybe_counter("total tweet-related notices skipped during parsing") + } + tweet + } + } + + + /** + * Return true if this tweet is "valid" in that it doesn't have any + * out-of-range values (blank strings or 0-valued quantities). Note + * that we treat a case where both latitude and longitude are 0 as + * invalid even though technically such a place could exist. (FIXME, + * use NaN or something to indicate a missing latitude or longitude). + */ + def is_valid_tweet(tw: Tweet): Boolean = { + val valid = + // filters out invalid tweets, as well as trivial spam + tw.id != 0 && tw.min_timestamp != 0 && tw.max_timestamp != 0 && + tw.user != "" && !(tw.lat == 0.0 && tw.long == 0.0) + if (!valid) + maybe_counter("tweets skipped due to invalid fields") + valid + } + + /** + * Select the first tweet with the same ID. For various reasons we may + * have duplicates of the same tweet among our data. E.g. it seems that + * Twitter itself sometimes streams duplicates through its Streaming API, + * and data from different sources will almost certainly have duplicates. + * Furthermore, sometimes we want to get all the tweets even in the + * presence of flakiness that causes Twitter to sometimes bomb out in a + * Streaming session and take a while to restart, so we have two or three + * simultaneous streams going recording the same stuff, hoping that Twitter + * bombs out at different points in the different sessions (which is + * generally true). Then, all or almost all the tweets are available in + * the different streams, but there is a lot of duplication that needs to + * be tossed aside. + */ + def tweet_once(id_tweets: (TweetID, Iterable[Tweet])) = { + val (id, tweets) = id_tweets + val head = tweets.head + val skipped = tweets.tail.toSeq.length + if (skipped > 0) + maybe_counter("duplicate tweets skipped", skipped) + head + } + + lazy val filter_tweets_ast = + create_parser(opts.filter_tweets, foldcase = true) + lazy val cfilter_tweets_ast = + create_parser(opts.cfilter_tweets, foldcase = false) + + /** + * Apply any boolean filters given in `--filter-tweets` or + * `--cfilter-tweets`. + */ + def filter_tweet_by_tweet_filters(tweet: Tweet) = { + val good = + (filter_tweets_ast == null || (filter_tweets_ast matches tweet)) && + (cfilter_tweets_ast == null || (cfilter_tweets_ast matches tweet)) + if (!good) + maybe_counter("tweets skipped due to non-matching tweet-level filter") + good + } + + // Filter out duplicate tweets -- group by tweet ID and then take the + // first tweet for a given ID. 
Duplicate tweets occur for various + // reasons -- they are common even in a single Twitter stream, not to + // mention when two or more streams covering the same time period and + // source are combined to deal with the inevitable gaps in coverage + // resulting from brief Twitter failures or hiccups, periods when + // one of the scraping machines goes down, etc. + def filter_duplicates(values: DList[Tweet]) = + values.groupBy(_.id).map(tweet_once) + + def filter_tweets(values: DList[Tweet]) = { + // It's necessary to filter for duplicates when reading JSON tweets + // directly (see comment above), but a bad idea when reading other + // formats because tweet ID's may not be unique (particularly when + // reading "tweets" that actually stem from multiple grouped tweets, + // raw text, etc. where the ID field may simply have -1 in it). + // + // Likewise it may be a good idea to filter out JSON tweets with + // invalid fields in them, but not a good idea otherwise. + val deduplicated = + if (opts.input_format == "json") + filter_duplicates(values.filter(is_valid_tweet)) + else + values + deduplicated.filter(x => filter_tweet_by_tweet_filters(x)) + } + + def note_remaining_tweets(tweet: Tweet) = { + maybe_counter("tweets remaining after uniquifying and tweet-level filtering") + true + } + + /** + * Parse a set of JSON-formatted tweets. Input is (path, JSON). + */ + def apply(lines: DList[(String, String)]) = { + + // Parse JSON into tweet records. Filter out nulls (unparsable tweets). + val values_extracted = lines.map(parse_line).filter(_ != null) + + /* Filter duplicates, invalid tweets, tweets not matching any + tweet-level boolean filters. (User-level boolean filters get + applied later.) */ + val good_tweets = filter_tweets(values_extracted) + + good_tweets.filter(note_remaining_tweets) + } + } + + class GroupTweets(opts: ParseTweetsParams) + extends ParseTweetsAction { + + val operation_category = "Group" + + lazy val filter_groups_ast = + create_parser(opts.filter_groups, foldcase = true) + lazy val cfilter_groups_ast = + create_parser(opts.cfilter_groups, foldcase = false) + + /** + * Apply any boolean filters given in `--filter-groups` or + * `--cfilter-groups`. + */ + private def filter_tweet_by_group_filters(tweet: Tweet) = { + (filter_groups_ast == null || (filter_groups_ast matches tweet)) && + (cfilter_groups_ast == null || (cfilter_groups_ast matches tweet)) + } + + private def tweet_key(tweet: Tweet, keytype: String) = { + keytype match { + case "user" => tweet.user + case "file" => tweet.path + case "time" => + ((tweet.min_timestamp / opts.timeslice) * opts.timeslice).toString + case "none" => "" + } + } + + private def tweet_to_record(tweet: Tweet) = { + Record(tweet_key(tweet, opts.grouping), + tweet_key(tweet, opts.filter_grouping), + filter_tweet_by_group_filters(tweet), + tweet) + } + + /** + * Merge the data associated with two tweets or tweet combinations + * into a single tweet combination. Concatenate text. Find maximum + * numbers of followers and followees. Add number of tweets in each. + * For latitude and longitude, take the earliest provided values + * ("earliest" by timestamp and "provided" meaning not missing). 
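+ *
+ * Illustrative example (values hypothetical): merging a tweet with
+ * 10 followers at timestamp 100 into one with 25 followers at
+ * timestamp 50 gives followers = 25, min_timestamp = 50,
+ * max_timestamp = 100, and numtweets equal to the sum of the two
+ * numtweets values.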
+ */ + private def merge_records(tw1: Record, tw2: Record): Record = { + assert(tw1.output_key == tw2.output_key) + val t1 = tw1.tweet + val t2 = tw2.tweet + val id = if (t1.id != t2.id) -1L else t1.id + val lang = if (t1.lang != t2.lang) "[multiple]" else t1.lang + val path = if (t1.path != t2.path) "[multiple]" else t1.path + val (followers, following) = + (math.max(t1.followers, t2.followers), + math.max(t1.following, t2.following)) + val numtweets = t1.numtweets + t2.numtweets + // Avoid computing stuff we will never use + val text = + if (opts.included_fields contains "text") + t1.text ++ t2.text + else Seq[String]() + val user_mentions = + if (opts.included_fields contains "user-mentions") + combine_maps(t1.user_mentions, t2.user_mentions) + else empty_map + val retweets = + if (opts.included_fields contains "retweets") + combine_maps(t1.retweets, t2.retweets) + else empty_map + val hashtags = + if (opts.included_fields contains "hashtags") + combine_maps(t1.hashtags, t2.hashtags) + else empty_map + val urls = + if (opts.included_fields contains "urls") + combine_maps(t1.urls, t2.urls) + else empty_map + + val (lat, long, geo_timestamp) = + if (isNaN(t1.lat) && isNaN(t2.lat)) { + (t1.lat, t1.long, math.min(t1.geo_timestamp, t2.geo_timestamp)) + } else if (isNaN(t2.lat)) { + (t1.lat, t1.long, t1.geo_timestamp) + } else if (isNaN(t1.lat)) { + (t2.lat, t2.long, t2.geo_timestamp) + } else if (t1.geo_timestamp < t2.geo_timestamp) { + (t1.lat, t1.long, t1.geo_timestamp) + } else { + (t2.lat, t2.long, t2.geo_timestamp) + } + val min_timestamp = math.min(t1.min_timestamp, t2.min_timestamp) + val max_timestamp = math.max(t1.max_timestamp, t2.max_timestamp) + + // FIXME maybe want to track the different users + val tweet = + Tweet("", path, text, t1.user, id, min_timestamp, max_timestamp, + geo_timestamp, lat, long, followers, following, lang, numtweets, + user_mentions, retweets, hashtags, urls) + Record(tw1.output_key, "", tw1.matches || tw2.matches, tweet) + } + + /** + * Return true if tweet (combination) has a fully-specified latitude + * and longitude. + */ + private def has_latlong(tw: Tweet) = { + val good = !isNaN(tw.lat) && !isNaN(tw.long) + if (!good) + bump_counter("grouped tweets filtered due to missing lat/long") + good + } + + val MAX_NUMBER_FOLLOWING = 1000 + val MIN_NUMBER_FOLLOWING = 5 + val MIN_NUMBER_FOLLOWERS = 10 + val MAX_NUMBER_TWEETS = 1000 + val MIN_NUMBER_TWEETS = 10 + /** + * Return true if this tweet combination (tweets for a given user) + * appears to reflect a "spammer" user or some other user with + * sufficiently nonstandard behavior that we want to exclude them (e.g. + * a celebrity or an inactive user): Having too few or too many tweets, + * following too many or too few, or having too few followers. A spam + * account is likely to have too many tweets -- and even more, to send + * tweets to too many people (although we don't track this). A spam + * account sends much more than it receives, and may have no followers. + * A celebrity account receives much more than it sends, and tends to have + * a lot of followers. People who send too few tweets simply don't + * provide enough data. + * + * FIXME: We don't check for too many followers of a given account, but + * instead too many people that a given account is following. Perhaps + * this is backwards? 
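+ *
+ * Concretely (see the constants above): a user is kept only if their
+ * following count is between MIN_NUMBER_FOLLOWING and
+ * MAX_NUMBER_FOLLOWING, their follower count is at least
+ * MIN_NUMBER_FOLLOWERS, and their tweet count is between
+ * MIN_NUMBER_TWEETS and MAX_NUMBER_TWEETS.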
+ */ + private def is_nonspammer(tw: Tweet): Boolean = { + val good = + (tw.following >= MIN_NUMBER_FOLLOWING && + tw.following <= MAX_NUMBER_FOLLOWING) && + (tw.followers >= MIN_NUMBER_FOLLOWERS) && + (tw.numtweets >= MIN_NUMBER_TWEETS && + tw.numtweets <= MAX_NUMBER_TWEETS) + if (!good) + bump_counter("grouped tweets filtered due to failing following, followers, or min/max tweets restrictions") + good + } + + // bounding box for north america + val MIN_LAT = 25.0 + val MIN_LNG = -126.0 + val MAX_LAT = 49.0 + val MAX_LNG = -60.0 + + /** + * Return true if this tweet (combination) is located within the + * bounding box of North America. + */ + private def matches_north_america(tw: Tweet): Boolean = { + val good = (tw.lat >= MIN_LAT && tw.lat <= MAX_LAT) && + (tw.long >= MIN_LNG && tw.long <= MAX_LNG) + if (!good) + bump_counter("grouped tweets filtered due to outside North America") + good + } + + /** + * Return true if this tweet (combination) matches all of the + * misc. filters (--geographic-only, --north-america-only, + * --filter-spammers). + * + * FIXME: Eliminate misc. filters, make them possible using + * normal filter mechanism. + */ + private def matches_misc_filters(tw: Tweet): Boolean = { + (!opts.geographic_only || has_latlong(tw)) && + (!opts.north_america_only || matches_north_america(tw)) && + (!opts.filter_spammers || is_nonspammer(tw)) + } + + /** + * Return true if a grouped set of tweets matches group-level filters + * (i.e. if any of them individually matches the filters). + * We've already checked each individual tweet against the + * group-level filters, and grouped tweets based on the group-filtering + * key. Note that we only do things this way if we're not also + * grouping output on the same key; see below. + */ + private def matches_group_filters(records: Iterable[Record]) = { + val good = records.exists(_.matches) + if (!good) + bump_counter("grouped tweets filtered by group-level filters") + good + } + + /** + * Apply group filtering (--filter-grouping, --filter-groups, + * --cfilter-groups) in the absence of output grouping on the same key. + */ + private def do_group_filtering(records: DList[Record]) = { + records + .groupBy(_.filter_key) + .map(_._2) // throw away key + .filter(matches_group_filters) + .flatten + } + + /** + * Group output according to --grouping. + */ + private def do_output_grouping(records: DList[Record]) = { + records + .groupBy(_.output_key) + .combine(merge_records) + .map(_._2) // throw away key + } + + private def note_remaining_tweets(tweet: Tweet) = { + bump_counter("grouped tweets remaining after output grouping and group-level filtering") + bump_counter("ungrouped tweets remaining after output grouping and group-level filtering", + tweet.numtweets) + true + } + + def apply(tweets: DList[Tweet]) = { + /* Here we implement output grouping and group filtering. + + Possibilities: + + 1. No output grouping, no group filtering: + -- Do nothing. + + 2. Output grouping, no group filtering: + -- Group by output key; + -- Merge. + + 3. No output grouping, group filtering: + -- Group by filter key; + -- Filter to see if any tweet in group matches group filter + (_.exists(_.matches)); + -- Flatten. + + 4. Output grouping, filter grouping, same keys: + -- Group by output key; + -- Merge, merging the 'matches' values (a.matches || b.matches); + -- Filter results based on _.matches. + + 5. 
Output grouping, filter grouping, different keys: Combine 3+2: + -- Group by filter key; + -- Filter to see if any tweet in group matches group filter + (_.exists(_.matches)) + -- Flatten; + -- Group by output key; + -- Merge. + + Note that in the cases of (3) and (5), we need to iterate over + the result of grouping. Currently there are some issues doing + this in Scoobi. We do have a fix in our private version, but it + potentially can lead to memory errors. In general, the problem + is that we need to iterate twice through the list of grouped + tweets (once to check to see if any match, again to do the + flattening). Hadoop doesn't really have support for this, so + we have to buffer everything, which can lead to memory errors. + (Note that Hadoop 0.21+/2.0+ has a feature to allow multiple + iteration in this case, but it also does buffering. The only + difference is that it buffers only a certain amount in memory + and then spills the remainder to disk. Scoobi should probably do + the same.) + */ + + val grouped_tweets = + if (opts.grouping == "none" && opts.filter_grouping == "none") + tweets + else { + // convert to Record (which contains grouping keys and group-level + // filter results) + val records = tweets.map(tweet_to_record) + val grouped_records = + (opts.grouping, opts.filter_grouping) match { + case (_, "none") => do_output_grouping(records) + case ("none", _) => do_group_filtering(records) + case (x, y) if x == y => + do_output_grouping(records).filter(_.matches) + case (_, _) => + do_output_grouping(do_group_filtering(records)) + } + // convert back to Tweet + grouped_records.map(_.tweet) + } + + // Apply misc. filters, and note whatever remains. + // FIXME: Misc. filters should be doable using normal filter + // mechanism, rather than special-cased. + grouped_tweets + .filter(matches_misc_filters) + .filter(note_remaining_tweets) + } + } + + class TokenizeCountAndFormat(opts: ParseTweetsParams) + extends ParseTweetsAction { + + val operation_category = "Tokenize" + + /** + * Convert a word to lowercase. + */ + def normalize_word(orig_word: String) = { + val word = + if (opts.preserve_case) + orig_word + else + orig_word.toLowerCase + // word.startsWith("@") + if (word.contains("http://") || word.contains("https://")) + "-LINK-" + else + word + } + + /** + * Return true if word should be filtered out (post-normalization). + */ + def reject_word(word: String) = { + word == "-LINK-" + } + + /** + * Return true if ngram should be filtered out (post-normalization). + * Here we filter out things where every word should be filtered, or + * where the first or last word should be filtered (in such a case, all + * the rest will be contained in a one-size-down n-gram). + */ + def reject_ngram(ngram: Iterable[String]) = { + ngram.forall(reject_word) || reject_word(ngram.head) || + reject_word(ngram.last) + } + + /** + * Use Twokenize to break up a tweet into tokens and separate into ngrams. + */ + def break_tweet_into_ngrams(text: String): + Iterable[Iterable[String]] = { + val words = Twokenize(text) + val normwords = words.map(normalize_word) + + // Then, generate all possible ngrams up to a specified maximum length, + // where each ngram is a sequence of words. `sliding` overlays a sliding + // window of a given size on a sequence to generate successive + // subsequences -- exactly what we want to generate ngrams. So we + // generate all 1-grams, then all 2-grams, etc. up to the maximum size, + // and then concatenate the separate lists together (that's what `flatMap` + // does). 
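+ // Illustrative example (assuming max_ngram = 2): for
+ // normwords = Seq("i", "love", "tacos"), sliding(1) yields
+ // Seq("i"), Seq("love"), Seq("tacos") and sliding(2) yields
+ // Seq("i", "love"), Seq("love", "tacos"); flatMap concatenates
+ // these into one sequence of ngrams, which is then filtered
+ // through reject_ngram.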
+ (1 to opts.max_ngram).
+ flatMap(normwords.sliding(_)).filter(!reject_ngram(_))
+ }
+
+ /**
+ * Tokenize the text of a tweet (or merged group of tweets) into
+ * ngrams, count the occurrences of each ngram, and encode the
+ * resulting ngram-count pairs into a single string, which serves
+ * as the value of the "counts" field.
+ */
+ def emit_ngrams(tweet_text: Seq[String]): String = {
+ val ngrams =
+ tweet_text.flatMap(break_tweet_into_ngrams(_)).toSeq.
+ map(encode_ngram_for_count_map_field)
+ shallow_encode_count_map(list_to_item_count_map(ngrams).toSeq)
+ }
+
+ /**
+ * Given a tweet, tokenize the text into ngrams, count words and format
+ * the result as a field; then convert the whole into a record to be
+ * written out.
+ */
+ def tokenize_count_and_format(tweet: Tweet): String = {
+ tweet.to_row(this, opts)
+ }
+ }
+
+ /**
+ * @param ty Type of value
+ * @param key2 Second-level key to group on; typically, either we want the
+ * individual value to appear in the overall stats (at 2nd level) or not;
+ * when not, put the individual value in `value`, and put the value of
+ * `ty` in `key2`; otherwise, put the individual value in `key2`, and
+ * usually put a constant value (e.g. "") in `value` so all relevant
+ * tweets get grouped together.
+ * @param value See above.
+ */
+ case class FeatureValueStats(
+ ty: String,
+ key2: String,
+ value: String,
+ num_tweets: Int,
+ min_timestamp: Timestamp,
+ max_timestamp: Timestamp
+ ) {
+ def to_row(opts: ParseTweetsParams) = {
+ import Encoder._
+ Seq(
+ string(ty),
+ string(key2),
+ string(value),
+ int(num_tweets),
+ timestamp(min_timestamp),
+ timestamp(max_timestamp)
+ ) mkString "\t"
+ }
+ }
+
+ implicit val featureValueStatsWire =
+ mkCaseWireFormat(FeatureValueStats.apply _, FeatureValueStats.unapply _)
+
+ object FeatureValueStats {
+ def row_fields =
+ Seq(
+ "type",
+ "key2",
+ "value",
+ "num-tweets",
+ "min-timestamp",
+ "max-timestamp")
+
+ def from_row(row: String, opts: ParseTweetsParams) = {
+ import Decoder._
+ val Array(ty, key2, value, num_tweets, min_timestamp, max_timestamp) =
+ row.split("\t", -1)
+ FeatureValueStats(
+ string(ty),
+ string(key2),
+ string(value),
+ int(num_tweets),
+ timestamp(min_timestamp),
+ timestamp(max_timestamp)
+ )
+ }
+
+ def from_tweet(tweet: Tweet, ty: String, key2: String, value: String) = {
+ FeatureValueStats(ty, key2, value, 1, tweet.min_timestamp,
+ tweet.max_timestamp)
+ }
+
+ def merge_stats(x1: FeatureValueStats, x2: FeatureValueStats) = {
+ assert(x1.ty == x2.ty)
+ assert(x1.key2 == x2.key2)
+ assert(x1.value == x2.value)
+ FeatureValueStats(x1.ty, x1.key2, x1.value,
+ x1.num_tweets + x2.num_tweets,
+ math.min(x1.min_timestamp, x2.min_timestamp),
+ math.max(x1.max_timestamp, x2.max_timestamp))
+ }
+ }
+
+ /**
+ * Statistics on any tweet "feature" (e.g. user, language) that can be
+ * identified by a value of some type (e.g. string, number) and has an
+ * associated map of occurrences of values of the feature.
+ */ + case class FeatureStats( + ty: String, + key2: String, + lowest_value_by_sort: String, + highest_value_by_sort: String, + most_common_value: String, + most_common_count: Int, + least_common_value: String, + least_common_count: Int, + num_value_types: Int, + num_value_occurrences: Int + ) extends Ordered[FeatureStats] { + def compare(that: FeatureStats) = { + (ty compare that.ty) match { + case 0 => key2 compare that.key2 + case x => x + } + } + + def to_row(opts: ParseTweetsParams) = { + import Encoder._ + Seq( + string(ty), + string(key2), + string(lowest_value_by_sort), + string(highest_value_by_sort), + string(most_common_value), + int(most_common_count), + string(least_common_value), + int(least_common_count), + int(num_value_types), + int(num_value_occurrences), + double(num_value_occurrences.toDouble/num_value_types) + ) mkString "\t" + } + } + + implicit val featureStatsWire = + mkCaseWireFormat(FeatureStats.apply _, FeatureStats.unapply _) + + object FeatureStats { + def row_fields = + Seq( + "type", + "key2", + "lowest-value-by-sort", + "highest-value-by-sort", + "most-common-value", + "most-common-count", + "least-common-value", + "least-common-count", + "num-value-types", + "num-value-occurrences", + "avg-value-occurrences" + ) + + def from_row(row: String) = { + import Decoder._ + val Array(ty, key2, lowest_value_by_sort, highest_value_by_sort, + most_common_value, most_common_count, + least_common_value, least_common_count, + num_value_types, num_value_occurrences, avo) = + row.split("\t", -1) + FeatureStats( + string(ty), + string(key2), + string(lowest_value_by_sort), + string(highest_value_by_sort), + string(most_common_value), + int(most_common_count), + string(least_common_value), + int(least_common_count), + int(num_value_types), + int(num_value_occurrences) + ) + } + + def from_value_stats(vs: FeatureValueStats) = + FeatureStats(vs.ty, vs. key2, vs.value, vs.value, vs.value, vs.num_tweets, + vs.value, vs.num_tweets, 1, vs.num_tweets) + + def merge_stats(x1: FeatureStats, x2: FeatureStats) = { + assert(x1.ty == x2.ty) + assert(x1.key2 == x2.key2) + val (most_common_value, most_common_count) = + if (x1.most_common_count > x2.most_common_count) + (x1.most_common_value, x1.most_common_count) + else + (x2.most_common_value, x2.most_common_count) + val (least_common_value, least_common_count) = + if (x1.least_common_count < x2.least_common_count) + (x1.least_common_value, x1.least_common_count) + else + (x2.least_common_value, x2.least_common_count) + FeatureStats(x1.ty, x1.key2, + if (x1.lowest_value_by_sort < x2.lowest_value_by_sort) + x1.lowest_value_by_sort + else x2.lowest_value_by_sort, + if (x1.highest_value_by_sort > x2.highest_value_by_sort) + x1.highest_value_by_sort + else x2.highest_value_by_sort, + most_common_value, most_common_count, + least_common_value, least_common_count, + x1.num_value_types + x2.num_value_types, + x1.num_value_occurrences + x2.num_value_occurrences) + } + } + + class GetStats(opts: ParseTweetsParams) + extends ParseTweetsAction { + + import java.util.Calendar + + val operation_category = "GetStats" + + val date_fmts = Seq( + ("year", "yyyy"), // Ex: "2012" + ("year/month", "yyyy-MM (MMM)"), // Ex. "2012-07 (Jul)" + ("year/month/day", + "yyyy-MM-dd (MMM d)"), // Ex. "2012-07-05 (Jul 5)" + ("month", "'month' MM (MMM)"), // Ex. "month 07 (Jul)" + ("month/week", + "'month' MM (MMM), 'week' W"), // Ex. "month 07 (Jul), week 2" + ("month/day", "MM-dd (MMM d)"), // Ex. "07-05 (Jul 5)" + ("weekday", "'weekday' 'QQ' (EEE)"),// Ex. 
"weekday 2 (Mon)" + ("hour", "HHaa"), // Ex. "09am" + ("weekday/hour", + "'weekday' 'QQ' (EEE), HHaa"), // Ex. "weekday 2 (Mon), 23pm" + ("hour/weekday", + "HHaa, 'weekday' 'QQ' (EEE)"), // Ex. "23pm, weekday 2 (Mon)" + ("weekday/month", "'weekday' 'QQ' (EEE), 'month' MM (MMM)"), + // Ex. "weekday 5 (Thu), month 07 (Jul)" + ("month/weekday", "'month' MM (MMM), 'weekday' 'QQ' (EEE)") + // Ex. "month 07 (Jul), weekday 5 (Thu)" + ) + lazy val calinst = Calendar.getInstance + lazy val sdfs = + for ((engl, fmt) <- date_fmts) yield (engl, new SimpleDateFormat(fmt)) + + /** + * Format a timestamp according to a date format. This would be easy if + * not for the fact that SimpleDateFormat provides no way of inserting + * the numeric equivalent of a day of the week, which we want in order + * to make sorting turn out correctly. So we have to retrieve it using + * the `Calendar` class and shoehorn it in wherever the non-code QQ + * was inserted. + * + * NOTE, a quick guide to Java date-related classes: + * + * java.util.Date: A simple wrapper around a timestamp in "Epoch time", + * i.e. milliseconds after the Unix Epoch of Jan 1, 1970, 00:00:00 GMT. + * Formerly also used for converting timestamps into human-style + * dates, but all that stuff is long deprecated because of lack of + * internationalization support. + * java.util.Calendar: A class that supports conversion between timestamps + * and human-style dates, e.g. to figure out the year, month and day of + * a given timestamp. Supports time zones, daylight savings oddities, + * etc. Subclasses are supposed to represent different calendar systems + * but in reality there's only one, named GregorianCalendar but which + * actually supports both Gregorian (modern) and Julian (old-style, + * with no leap-year special-casing of years divisible by 100) calendars, + * with a configurable cross-over point. Support for other calendars + * (Islamic, Jewish, etc.) is provided by third-party libraries (e.g. + * JodaTime), which typically discard the java.util.Calendar framework + * and create their own. + * java.util.DateFormat: A class that supports conversion between + * timestamps and human-style dates nicely formatted into a string, e.g. + * generating strings like "Jul 23, 2012 08:05pm". Again, theoretically + * there are particular subclasses to support different formatting + * mechanisms but in reality there's only one, SimpleDateFormat. + * Readable dates are formatted using a template, and there's a good + * deal of localization-specific stuff under the hood that theoretically + * the programmer doesn't need to worry about. + */ + def format_date(time: Timestamp, fmt: SimpleDateFormat) = { + val output = fmt.format(new Date(time)) + calinst.setTimeInMillis(time) + val weekday = calinst.get(Calendar.DAY_OF_WEEK) + output.replace("QQ", weekday.toString) + } + + def stats_for_tweet(tweet: Tweet) = { + Seq(FeatureValueStats.from_tweet(tweet, "user", "user", tweet.user), + // Get a summary for all languages plus a summary for each lang + FeatureValueStats.from_tweet(tweet, "lang", tweet.lang, ""), + FeatureValueStats.from_tweet(tweet, "lang", "lang", tweet.lang)) ++ + sdfs.map { case (engl, fmt) => + FeatureValueStats.from_tweet( + tweet, engl, format_date(tweet.min_timestamp, fmt), "") } + } + + /** + * Compute statistics on a DList of tweets. + */ + def get_by_value(tweets: DList[Tweet]) = { + /* Operations: + + 1. For each tweet, and for each feature we're interested in getting + stats on (e.g. 
users, languages, etc.), generate a tuple
+ (keytype, key, value) that has the type of feature as `keytype`
+ (e.g. "user"), the value of the feature in `key` (e.g. the user
+ name), and some sort of stats object (e.g. `FeatureValueStats`),
+ giving statistics on that user (etc.) derived from the individual
+ tweet.
+
+ 2. Take the resulting DList and group by grouping key. Combine the
+ resulting `FeatureValueStats` together by adding their values or
+ taking max/min or whatever. (If there are multiple feature types,
+ we might have multiple classes involved, so we need to condition
+ on the feature type.)
+
+ 3. The resulting DList has one entry per feature value, giving
+ stats on all tweets corresponding to that feature value. We
+ want to aggregate again over feature type, to get statistics on
+ the whole type (e.g. how many different feature values, how
+ often they occur).
+ */
+ tweets.flatMap(x => stats_for_tweet(x)).
+ groupBy({ stats => (stats.ty, stats.key2, stats.value)}).
+ combine(FeatureValueStats.merge_stats).
+ map(_._2)
+ }
+
+ def get_by_type(values: DList[FeatureValueStats]) = {
+ values.map(FeatureStats.from_value_stats(_)).
+ groupBy({ stats => (stats.ty, stats.key2) }).
+ combine(FeatureStats.merge_stats).
+ map(_._2)
+ }
+ }
+
+ class ParseTweetsDriver(opts: ParseTweetsParams)
+ extends ParseTweetsAction {
+
+ val operation_category = "Driver"
+
+ def corpus_suffix = {
+ val dist_type = if (opts.max_ngram == 1) "unigram" else "ngram"
+ "%s-%s-counts-tweets" format (opts.split, dist_type)
+ }
+
+ /**
+ * Output a schema file of the appropriate name.
+ */
+ def output_schema(filehand: FileHandler) {
+ val filename = Schema.construct_schema_file(filehand,
+ opts.output, opts.corpus_name, corpus_suffix)
+ logger.info("Outputting a schema to %s ..."
format filename) + // We add the counts data to what to_row() normally outputs so we + // have to add the same field here + val fields = Tweet.row_fields(opts) + val fixed_fields = Map( + "corpus-name" -> opts.corpus_name, + "generating-app" -> "ParseTweets", + "corpus-type" -> ("twitter-%s" format + (if (opts.grouping == "none") "tweets" else opts.grouping))) ++ + opts.non_default_params_string.toMap ++ + Map( + "grouping" -> opts.grouping, + "output-format" -> opts.output_format, + "split" -> opts.split + ) ++ ( + if (opts.grouping == "time") + Map("corpus-timeslice" -> opts.timeslice.toString) + else + Map[String, String]() + ) + val schema = new Schema(fields, fixed_fields) + schema.output_schema_file(filehand, filename) + } + } + + def create_params(ap: ArgParser) = new ParseTweetsParams(ap) + val progname = "ParseTweets" + + private def grouping_type_english(opts: ParseTweetsParams, + grouptype: String) = { + grouptype match { + case "time" => + "time, with slices of %g seconds".format(opts.timeslice_float) + case "user" => + "user" + case "file" => + "file" + case "none" => + "not grouping" + } + } + + def run() { + val opts = init_scoobi_app() + val filehand = new HadoopFileHandler(configuration) + if (opts.corpus_name == null) { + val (_, last_component) = filehand.split_filename(opts.input) + opts.corpus_name = last_component.replace("*", "_") + } + errprint("ParseTweets: " + (opts.grouping match { + case "none" => "not grouping output" + case x => "grouping output by " + grouping_type_english(opts, x) + })) + errprint("ParseTweets: " + (opts.filter_grouping match { + case "none" => "no group-level filtering" + case x => "group-level filtering by " + grouping_type_english(opts, x) + })) + errprint("ParseTweets: " + (opts.has_tweet_filtering match { + case false => "no tweet-level filtering" + case true => "doing tweet-level filtering" + })) + errprint("ParseTweets: " + (opts.output_format match { + case "textdb" => "outputting in textdb format" + case "json" => "outputting as raw JSON" + case "stats" => "outputting statistics on tweets" + })) + val ptp = new ParseTweetsDriver(opts) + // Firstly we load up all the (new-line-separated) JSON or textdb lines. 
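+ // Each element of the resulting DList is a (path, line) pair, so
+ // downstream code (e.g. the "path" field of Tweet) can record which
+ // input file a given line came from.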
+ val lines: DList[(String, String)] = { + opts.input_format match { + case "textdb" => { + val insuffix = "tweets" + opts.input_schema = TextDBProcessor.read_schema_from_textdb( + filehand, opts.input, insuffix) + val matching_patterns = TextDBProcessor.get_matching_patterns( + filehand, opts.input, insuffix) + TextInput.fromTextFileWithPath(matching_patterns: _*) + } + case "json" | "raw-lines" => TextInput.fromTextFileWithPath(opts.input) + } + } + + errprint("ParseTweets: Generate tweets ...") + val tweets1 = new ParseAndUniquifyTweets(opts)(lines) + + /* Maybe group tweets */ + val tweets = new GroupTweets(opts)(tweets1) + + // Construct output directory for a given corpus suffix, based on + // user-provided output directory + def output_directory_for_suffix(corpus_suffix: String) = + opts.output + "-" + corpus_suffix + + // create a schema given a set of data fields plus user params + def create_schema(fields: Seq[String]) = + new Schema(fields, Map("corpus-name" -> opts.corpus_name)) + + // output lines of data in a DList to a corpus + def dlist_output_lines(lines: DList[String], + corpus_suffix: String, fields: Seq[String]) = { + // get output directory + val outdir = output_directory_for_suffix(corpus_suffix) + + // output data file + persist(TextOutput.toTextFile(lines, outdir)) + rename_output_files(outdir, opts.corpus_name, corpus_suffix) + + // output schema file + val out_schema = create_schema(fields) + out_schema.output_constructed_schema_file(filehand, outdir, + opts.corpus_name, corpus_suffix) + outdir + } + + // output lines of data in an Iterable to a corpus + def local_output_lines(lines: Iterable[String], + corpus_suffix: String, fields: Seq[String]) = { + // get output directory + val outdir = output_directory_for_suffix(corpus_suffix) + + // output data file + filehand.make_directories(outdir) + val outfile = TextDBProcessor.construct_output_file(filehand, outdir, + opts.corpus_name, corpus_suffix, ".txt") + val outstr = filehand.openw(outfile) + lines.map(outstr.println(_)) + outstr.close() + + // output schema file + val out_schema = create_schema(fields) + out_schema.output_constructed_schema_file(filehand, outdir, + opts.corpus_name, corpus_suffix) + outdir + } + + def rename_outfiles() { + rename_output_files(opts.output, opts.corpus_name, ptp.corpus_suffix) + } + + opts.output_format match { + case "json" => { + /* We're outputting JSON's directly. */ + persist(TextOutput.toTextFile(tweets.map(_.json), opts.output)) + rename_outfiles() + } + case "textdb" => { + val tfct = new TokenizeCountAndFormat(opts) + // Tokenize the combined text into words, possibly generate ngrams + // from them, count them up and output results formatted into a record. 
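+ // Each resulting record is a tab-separated textdb row whose fields
+ // are those listed in opts.included_fields (see Tweet.to_row).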
+ val nicely_formatted = tweets.map(tfct.tokenize_count_and_format) + persist(TextOutput.toTextFile(nicely_formatted, opts.output)) + rename_outfiles() + // create a schema + ptp.output_schema(filehand) + } + case "stats" => { + val get_stats = new GetStats(opts) + val by_value = get_stats.get_by_value(tweets) + val dlist_by_type = get_stats.get_by_type(by_value) + val by_type = persist(dlist_by_type.materialize).toSeq.sorted + val stats_suffix = "stats" + local_output_lines(by_type.map(_.to_row(opts)), + stats_suffix, FeatureStats.row_fields) + val userstat = by_type.filter(x => + x.ty == "user" && x.key2 == "user").toSeq(0) + errprint("\nCombined summary:") + errprint("%s tweets by %s users = %.2f tweets/user", + with_commas(userstat.num_value_occurrences), + with_commas(userstat.num_value_types), + userstat.num_value_occurrences.toDouble / userstat.num_value_types) + val monthstat = by_type.filter(_.ty == "year/month") + errprint("\nSummary by month:") + for (mo <- monthstat) + errprint("%-20s: %12s tweets", mo.key2, + with_commas(mo.num_value_occurrences)) + val daystat = by_type.filter(_.ty == "year/month/day") + errprint("\nSummary by day:") + for (d <- daystat) + errprint("%-20s: %12s tweets", d.key2, + with_commas(d.num_value_occurrences)) + errprint("\n") + } + } + + errprint("ParseTweets: done.") + + finish_scoobi_app(opts) + } + /* + + To build a classifier for conserv vs liberal: + + 1. Look for people retweeting congressmen or governor tweets, possibly at + some minimum level of retweeting (or rely on followers, for some + minimum number of people following) + 2. Make sure either they predominate having retweets from one party, + and/or use the DW-NOMINATE scores to pick out people whose average + ideology score of their retweets is near the extremes. + */ +} + diff --git a/src/main/scala/opennlp/fieldspring/preprocess/Permute.scala b/src/main/scala/opennlp/fieldspring/preprocess/Permute.scala new file mode 100644 index 0000000..f38b106 --- /dev/null +++ b/src/main/scala/opennlp/fieldspring/preprocess/Permute.scala @@ -0,0 +1,69 @@ +/////////////////////////////////////////////////////////////////////////////// +// Permute.scala +// +// Copyright (C) 2012 Stephen Roller, The University of Texas at Austin +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +/////////////////////////////////////////////////////////////////////////////// + +package opennlp.fieldspring.preprocess + +import util.Random +import com.nicta.scoobi.Scoobi._ +import java.io._ + +/* + * This program randomly permutes all the lines in a text file, using Hadoop + * and Scoobi. 
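+ *
+ * The approach (see run() below): pair each line with a random Double
+ * key, group by that key (since the keys are random, the grouped
+ * output comes back in an effectively random order), then drop the
+ * keys and write the lines back out.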
+ */ + +object Permute extends ScoobiApp { + val rnd = new Random + + def generate_key(line: String): (Double, String) = { + (rnd.nextDouble, line) + } + + def remove_key(kvs: (Double, Iterable[String])): Iterable[String] = { + val (key, values) = kvs + for (v <- values) + yield v + } + + def run() { + // make sure we get all the input + val (inputPath, outputPath) = + if (args.length == 2) { + (args(0), args(1)) + } else { + sys.error("Expecting input and output path.") + } + + // Firstly we load up all the (new-line-seperated) json lines + val lines: DList[String] = TextInput.fromTextFile(inputPath) + + // randomly generate keys + val with_keys = lines.map(generate_key) + + // sort by keys + val keys_sorted = with_keys.groupByKey + + // remove keys + val keys_removed = keys_sorted.flatMap(remove_key) + + // save to disk + persist(TextOutput.toTextFile(keys_removed, outputPath)) + + } +} + diff --git a/src/main/scala/opennlp/fieldspring/preprocess/ProcessFiles.scala b/src/main/scala/opennlp/fieldspring/preprocess/ProcessFiles.scala new file mode 100644 index 0000000..af24aa9 --- /dev/null +++ b/src/main/scala/opennlp/fieldspring/preprocess/ProcessFiles.scala @@ -0,0 +1,70 @@ +/////////////////////////////////////////////////////////////////////////////// +// ProcessFiles.scala +// +// Copyright (C) 2011 Ben Wing, The University of Texas at Austin +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +/////////////////////////////////////////////////////////////////////////////// + +package opennlp.fieldspring.preprocess + +import opennlp.fieldspring.util.argparser._ +import opennlp.fieldspring.util.experiment._ +import opennlp.fieldspring.util.ioutil._ + +/* + Common code for doing basic file-processing operations. + + FIXME: It's unclear there's enough code to justify factoring it out + like this. +*/ + +///////////////////////////////////////////////////////////////////////////// +// Main code // +///////////////////////////////////////////////////////////////////////////// + +/** + * Class for defining and retrieving command-line arguments. Consistent + * with "field-style" access to an ArgParser, this class needs to be + * instantiated twice with the same ArgParser object, before and after parsing + * the command line. The first instance defines the allowed arguments in the + * ArgParser, while the second one retrieves the values stored into the + * ArgParser as a result of parsing. + * + * @param ap ArgParser object. + */ +class ProcessFilesParameters(ap: ArgParser) extends + ArgParserParameters(ap) { + val output_dir = + ap.option[String]("o", "output-dir", + metavar = "DIR", + help = """Directory to store output files in. 
It must not already +exist, and will be created (including any parent directories).""") +} + +abstract class ProcessFilesDriver extends HadoopableArgParserExperimentDriver { + override type TParam <: ProcessFilesParameters + type TRunRes = Unit + + def handle_parameters() { + need(params.output_dir, "output-dir") + } + + def setup_for_run() { } + + def run_after_setup() { + if (!get_file_handler.make_directories(params.output_dir)) + param_error("Output dir %s must not already exist" format + params.output_dir) + } +} diff --git a/src/main/scala/opennlp/fieldspring/preprocess/ScoobiConvertTwitterInfochimps.scala b/src/main/scala/opennlp/fieldspring/preprocess/ScoobiConvertTwitterInfochimps.scala new file mode 100644 index 0000000..1ea44a6 --- /dev/null +++ b/src/main/scala/opennlp/fieldspring/preprocess/ScoobiConvertTwitterInfochimps.scala @@ -0,0 +1,111 @@ +/////////////////////////////////////////////////////////////////////////////// +// ScoobiConvertTwitterInfochimps.scala +// +// Copyright (C) 2011 Ben Wing, The University of Texas at Austin +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +/////////////////////////////////////////////////////////////////////////////// + +package opennlp.fieldspring.preprocess + +import java.io._ +import java.io.{FileSystem=>_,_} +import util.control.Breaks._ + +import org.apache.hadoop.io._ +import org.apache.hadoop.util._ +import org.apache.hadoop.mapreduce._ +import org.apache.hadoop.mapreduce.lib.input.FileInputFormat +import org.apache.hadoop.mapreduce.lib.output.FileOutputFormat +import org.apache.hadoop.conf.{Configuration, Configured} +import org.apache.hadoop.fs._ + +import com.nicta.scoobi.Scoobi._ + +/* + +Steps for converting Infochimps to our format: + +1) Input is a series of files, e.g. part-00000.gz, each about 180 MB. +2) Each line looks like this: + + +100000018081132545 20110807002716 25430513 GTheHardWay Niggas Lost in the Sauce ..smh better slow yo roll and tell them hoes to get a job nigga #MRIloveRatsIcanchange&amp;saveherassNIGGA <a href="http://twitter.com/download/android" rel="nofollow">Twitter for Android</a> en 42.330165 -83.045913 +The fields are: + +1) Tweet ID +2) Time +3) User ID +4) User name +5) Empty? +6) User name being replied to (FIXME: which JSON field is this?) +7) User ID for replied-to user name (but sometimes different ID's for same user name) +8) Empty? +9) Tweet text -- double HTML-encoded (e.g. & becomes &amp;) +10) HTML anchor text indicating a link of some sort, HTML-encoded (FIXME: which JSON field is this?) +11) Language, as a two-letter code +12) Latitude +13) Longitude +14) Empty? +15) Empty? +16) Empty? +17) Empty? + + +3) We want to convert each to two files: (1) containing the article-data + +*/ + +/** + * Convert files in the Infochimps Twitter corpus into files in our format. 
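+ *
+ * Roughly (the run() body below is the authoritative mapping), each
+ * tab-separated input line
+ *
+ *   id  time  userid  username  _  reply_username  reply_userid  _  text  anchor  lang  lat  long  ...
+ *
+ * becomes a metadata row
+ *
+ *   id  id  training  "lat,long"  time  username  userid  reply_username  reply_userid  anchor  lang
+ *
+ * plus a text row keyed by the same id, both written out tab-separated.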
+ */ + +object ScoobiConvertTwitterInfochimps extends ScoobiApp { + + def usage() { + sys.error("""Usage: ConvertTwitterInfochimps INFILE OUTDIR + +Convert input files in the Infochimps Twitter corpus into files in the +format expected by Fieldspring. INFILE is a single file or a glob. +OUTDIR is the directory to store the results in. +""") + } + + def run() { + if (args.length != 2) + usage() + val infile = args(0) + val outdir = args(1) + + val fields = List("id", "title", "split", "coord", "time", + "username", "userid", "reply_username", "reply_userid", "anchor", "lang") + + val tweets = fromDelimitedTextFile("\t", infile) { + case id :: time :: userid :: username :: _ :: + reply_username :: reply_userid :: _ :: text :: anchor :: lang :: + lat :: long :: _ :: _ :: _ :: _ => { + val metadata = List(id, id, "training", "%s,%s" format (lat, long), + time, username, userid, reply_username, reply_userid, anchor, lang) + val splittext = text.split(" ") + val textdata = List(id, id, splittext) + (metadata mkString "\t", textdata mkString "\t") + } + } + persist ( + TextOutput.toTextFile(tweets.map(_._1), + "%s-twitter-infochimps-combined-document-data.txt" format outdir), + TextOutput.toTextFile(tweets.map(_._2), + "%s-twitter-infochimps-text-data.txt" format outdir) + ) + } +} diff --git a/src/main/scala/opennlp/fieldspring/preprocess/ScoobiProcessFilesApp.scala b/src/main/scala/opennlp/fieldspring/preprocess/ScoobiProcessFilesApp.scala new file mode 100644 index 0000000..0639a44 --- /dev/null +++ b/src/main/scala/opennlp/fieldspring/preprocess/ScoobiProcessFilesApp.scala @@ -0,0 +1,244 @@ +// ScoobiProcessFilesApp.scala +// +// Copyright (C) 2012 Stephen Roller, The University of Texas at Austin +// Copyright (C) 2012 Ben Wing, The University of Texas at Austin +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +/////////////////////////////////////////////////////////////////////////////// + +/** + * This file provides support for a Scoobi application that processes files. 
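+ *
+ * Concretely, the classes below provide command-line parameter handling
+ * (ScoobiProcessFilesParams), per-record error wrapping, logging and Hadoop
+ * counters (ScoobiProcessFilesAction), and overall app scaffolding such as
+ * parameter echoing and renaming of Hadoop part-files
+ * (ScoobiProcessFilesApp).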
+ */ + +package opennlp.fieldspring.preprocess + +import java.io._ + +import org.apache.commons.logging.LogFactory +import org.apache.log4j.{Level=>JLevel,_} +import org.apache.hadoop.fs.{FileSystem => HFileSystem, Path, FileStatus} + +import com.nicta.scoobi.Scoobi._ +// import com.nicta.scoobi.testing.HadoopLogFactory +import com.nicta.scoobi.application.HadoopLogFactory + +import opennlp.fieldspring.util.argparser._ +import opennlp.fieldspring.util.osutil._ +import opennlp.fieldspring.util.collectionutil._ +import opennlp.fieldspring.util.printutil._ + +class ScoobiProcessFilesParams(val ap: ArgParser) { + var debug = ap.flag("debug", + help="""Output debug info about data processing.""") + var debug_file = ap.option[String]("debug-file", + help="""File to write debug info to, instead of stderr.""") + var input = ap.positional[String]("INPUT", + help = "Source directory to read files from.") + var output = ap.positional[String]("OUTPUT", + help = "Destination directory to place files in.") + + /** + * Check usage of command-line parameters and use `ap.usageError` to + * signal that an error occurred. + */ + def check_usage() {} + + def non_default_params = { + for (name <- ap.argNames if (ap.specified(name))) yield + (name, ap(name)) + } + def non_default_params_string = { + non_default_params.map { + case (param,x) => (param, "%s" format x) } + } +} + +/** + * An "action" -- simply used to encapsulate code to perform arbitrary + * operations. Encapsulating them like this allows information (e.g. + * command-line options) to be made available to the routines without + * using global variables, which won't work when the routines to be + * executed are done in a Hadoop task rather than on the client (which + * runs on the user's own machine, typically the Hadoop job node). + * You need to set the values of `progname` and `operation_category`, + * which appear in log messages (output using `warning`) and in + * counters (incremented using `bump_counter`). + */ +trait ScoobiProcessFilesAction { + + def progname: String + val operation_category: String + def full_operation_category = progname + "." + operation_category + + var lineno = 0 + lazy val logger = LogFactory.getLog(full_operation_category) + + def warning(line: String, fmt: String, args: Any*) { + logger.warn("Line %d: %s: %s" format + (lineno, fmt format (args: _*), line)) + } + + def bump_counter(counter: String, amount: Long = 1) { + incrCounter(full_operation_category, counter, amount) + } + + /** + * A class used internally by `error_wrap`. + */ + private class ErrorWrapper { + // errprint("Created an ErrorWrapper") + var lineno = 0 + } + + /** + * This is used to keep track of the line number. Theoretically we should + * be able to key off of `fun` itself but this doesn't actually work, + * because a new object is created each time to hold the environment of + * the function. However, using the class of the function works well, + * because internally each anonymous function is implemented by defining + * a new class that underlyingly implements the function. + */ + private val wrapper_map = + defaultmap[Class[_], ErrorWrapper](new ErrorWrapper, setkey = true) + + /** + * Wrapper function used to catch errors when doing line-oriented + * (or record-oriented) processing. It is passed a value (typically, + * the line or record to be processed), a function to process the + * value, and a default value to be returned upon error. 
If an error + * occurs during execution of the function, we log the line number, + * the value that triggered the error, and the error itself (including + * stack trace), and return the default. + * + * We need to create internal state in order to track the line number. + * This function is able to handle multiple overlapping or nested + * invocations of `error_wrap`, keyed (approximately) on the particular + * function invoked. + * + * @param value Value to process + * @param fun Function to use to process the value + * @param default Default value to be returned during error + */ + def error_wrap[T, U](value: T, default: => U)(fun: T => U) = { + // errprint("error_wrap called with fun %s", fun) + // errprint("class is %s", fun.getClass) + val wrapper = wrapper_map(fun.getClass) + // errprint("got wrapper %s", wrapper) + try { + wrapper.lineno += 1 + fun(value) + } catch { + case e: Exception => { + logger.warn("Line %d: %s: %s\n%s" format + (wrapper.lineno, e, value, stack_trace_as_string(e))) + default + } + } + } +} + +abstract class ScoobiProcessFilesApp[ParamType <: ScoobiProcessFilesParams] + extends ScoobiApp with ScoobiProcessFilesAction { + + // This is necessary to turn off the LibJars mechanism, which is somewhat + // buggy and interferes with assemblies. + override def upload = false + def create_params(ap: ArgParser): ParamType + val operation_category = "MainApp" + def output_command_line_parameters(arg_parser: ArgParser) { + // Output using errprint() rather than logger() so that the results + // stand out more. + errprint("") + errprint("Non-default parameter values:") + for (name <- arg_parser.argNames) { + if (arg_parser.specified(name)) + errprint("%30s: %s" format (name, arg_parser(name))) + } + errprint("") + errprint("Parameter values:") + for (name <- arg_parser.argNames) { + errprint("%30s: %s" format (name, arg_parser(name))) + //errprint("%30s: %s" format (name, arg_parser.getType(name))) + } + errprint("") + } + + def init_scoobi_app() = { + initialize_osutil() + val ap = new ArgParser(progname) + // This first call is necessary, even though it doesn't appear to do + // anything. In particular, this ensures that all arguments have been + // defined on `ap` prior to parsing. + create_params(ap) + // Here and below, output using errprint() rather than logger() so that + // the basic steps stand out more -- when accompanied by typical logger + // prefixes, they easily disappear. + errprint("Parsing args: %s" format (args mkString " ")) + ap.parse(args) + val Opts = create_params(ap) + Opts.check_usage() + enableCounterLogging() + if (Opts.debug) { + HadoopLogFactory.setQuiet(false) + HadoopLogFactory.setLogLevel(HadoopLogFactory.TRACE) + LogManager.getRootLogger().setLevel(JLevel.DEBUG.asInstanceOf[JLevel]) + } + if (Opts.debug_file != null) + set_errout_file(Opts.debug_file) + output_command_line_parameters(ap) + Opts + } + + def finish_scoobi_app(Opts: ParamType) { + errprint("All done with everything.") + errprint("") + output_resource_usage() + } + + /** + * Given a Hadoop-style path specification (specifying a single directory, + * a single file, or a glob), return a list of all files specified. + */ + def files_of_path_spec(spec: String) = { + /** + * Expand a file status possibly referring to a directory to a list of + * the files within. FIXME: Should this be recursive? + */ + def expand_dirs(status: FileStatus) = { + if (status.isDir) + configuration.fs.listStatus(status.getPath) + else + Array(status) + } + + configuration.fs.globStatus(new Path(spec)). 
+ flatMap(expand_dirs). + map(_.getPath.toString) + } + + def rename_output_files(dir: String, corpus_name: String, suffix: String) { + // Rename output files appropriately + errprint("Renaming output files ...") + val globpat = "%s/*-r-*" format dir + val fs = configuration.fs + for (file <- fs.globStatus(new Path(globpat))) { + val path = file.getPath + val basename = path.getName + val newname = "%s/%s-%s-%s.txt" format ( + dir, corpus_name, basename, suffix) + errprint("Renaming %s to %s" format (path, newname)) + fs.rename(path, new Path(newname)) + } + } +} + diff --git a/src/main/scala/opennlp/fieldspring/preprocess/ScoobiWordCount.scala b/src/main/scala/opennlp/fieldspring/preprocess/ScoobiWordCount.scala new file mode 100644 index 0000000..6b38786 --- /dev/null +++ b/src/main/scala/opennlp/fieldspring/preprocess/ScoobiWordCount.scala @@ -0,0 +1,47 @@ +package opennlp.fieldspring.preprocess + +import com.nicta.scoobi.Scoobi._ +// import com.nicta.scoobi.testing.HadoopLogFactory +import com.nicta.scoobi.application.HadoopLogFactory +import org.apache.commons.logging.LogFactory +import org.apache.hadoop.fs.FileSystem +import java.io._ + +object ScoobiWordCount extends ScoobiApp { + def run() { + // There's some magic here in the source code to make the get() call + // work -- there's an implicit conversion in object ScoobiConfiguration + // from a ScoobiConfiguration to a Hadoop Configuration, which has get() + // defined on it. Evidently implicit conversions in the companion object + // get made available automatically for classes or something? + System.err.println("mapred.job.tracker " + + configuration.get("mapred.job.tracker", "value not found")) + // System.err.println("job tracker " + jobTracker) + // System.err.println("file system " + fs) + System.err.println("configure file system " + configuration.fs) + System.err.println("file system key " + + configuration.get(FileSystem.FS_DEFAULT_NAME_KEY, "value not found")) + + val lines = + // Test fromTextFileWithPath, but currently appears to trigger an + // infinite loop. + // TextInput.fromTextFileWithPath(args(0)) + TextInput.fromTextFile(args(0)).map(x => (args(0), x)) + + def splitit(x: String) = { + HadoopLogFactory.setQuiet(false) + // val logger = LogFactory.getLog("foo.bar") + // logger.info("Processing " + x) + // System.err.println("Processing", x) + x.split(" ") + } + //val counts = lines.flatMap(_.split(" ")) + val counts = lines.map(_._2).flatMap(splitit) + .map(word => (word, 1)) + .groupByKey + .filter { case (word, lens) => word.length < 8 } + .filter { case (word, lens) => lens.exists(x => true) } + .combine((a: Int, b: Int) => a + b) + persist(toTextFile(counts, args(1))) + } +} diff --git a/src/main/scala/opennlp/fieldspring/preprocess/TwitterPullLocationVariance.scala b/src/main/scala/opennlp/fieldspring/preprocess/TwitterPullLocationVariance.scala new file mode 100644 index 0000000..0bc5e13 --- /dev/null +++ b/src/main/scala/opennlp/fieldspring/preprocess/TwitterPullLocationVariance.scala @@ -0,0 +1,177 @@ +/////////////////////////////////////////////////////////////////////////////// +// TwitterPullLocationVariance.scala +// +// Copyright (C) 2012 Stephen Roller, The University of Texas at Austin +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +/////////////////////////////////////////////////////////////////////////////// + +package opennlp.fieldspring.preprocess + +import net.liftweb.json +import com.nicta.scoobi.Scoobi._ +import java.io._ +import java.lang.Double.isNaN +import java.text.{SimpleDateFormat, ParseException} +import math.pow + +import opennlp.fieldspring.util.Twokenize +import opennlp.fieldspring.util.distances.{spheredist, SphereCoord} + +/* + * This program takes, as input, files which contain one tweet + * per line in json format as directly pulled from the twitter + * api. It outputs a folder that may be used as the + * --input-corpus argument of tg-geolocate. + */ + +object TwitterPullLocationVariance extends ScoobiApp { + type Tweet = (Double, Double) + type Record = (String, Tweet) + + def force_value(value: json.JValue): String = { + if ((value values) == null) + null + else + (value values) toString + } + + def is_valid_tweet(id_r: (String, Record)): Boolean = { + // filters out invalid tweets, as well as trivial spam + val (tw_id, (a, (lat, lng))) = id_r + a != "" && tw_id != "" && !isNaN(lat) && !isNaN(lng) + } + + def preferred_latlng(latlng1: (Double, Double), latlng2: (Double, Double)): (Double, Double) = { + val (lat1, lng1) = latlng1 + val (lat2, lng2) = latlng2 + if (!isNaN(lat1) && !isNaN(lng1)) + (lat1, lng1) + else + (lat2, lng2) + } + + val empty_tweet: (String, Record) = ("", ("", (Double.NaN, Double.NaN))) + + def parse_json(line: String): (String, Record) = { + try { + val parsed = json.parse(line) + val author = force_value(parsed \ "user" \ "screen_name") + val tweet_id = force_value(parsed \ "id_str") + val (blat, blng) = + try { + val bounding_box = (parsed \ "place" \ "bounding_box" \ "coordinates" values) + .asInstanceOf[List[List[List[Double]]]](0) + val bounding_box_sum = bounding_box.reduce((a, b) => List(a(1) + b(1), a(0) + b(0))) + (bounding_box_sum(0) / bounding_box.length, bounding_box_sum(1) / bounding_box.length) + } catch { + case npe: NullPointerException => (Double.NaN, Double.NaN) + case cce: ClassCastException => (Double.NaN, Double.NaN) + } + val (plat, plng) = + if ((parsed \ "coordinates" values) == null || + (force_value(parsed \ "coordinates" \ "type") != "Point")) { + (Double.NaN, Double.NaN) + } else { + val latlng: List[Number] = + (parsed \ "coordinates" \ "coordinates" values).asInstanceOf[List[Number]] + (latlng(1).doubleValue, latlng(0).doubleValue) + } + val (lat, lng) = preferred_latlng((plat, plng), (blat, blng)) + (tweet_id, (author, (lat, lng))) + } catch { + case jpe: json.JsonParser.ParseException => empty_tweet + case npe: NullPointerException => empty_tweet + case nfe: NumberFormatException => empty_tweet + } + } + + def tweet_once(id_rs: (String, Iterable[Record])): Record = { + val (id, rs) = id_rs + rs.head + } + + def has_latlng(r: Record): Boolean = { + val (a, (lat, lng)) = r + !isNaN(lat) && !isNaN(lng) + } + + def cartesian_product[T1, T2](A: Seq[T1], B: Seq[T2]): Iterable[(T1, T2)] = { + for (a <- A; b <- B) yield (a, b) + } + + def mean_variance_and_maxdistance(inpt: (String, Iterable[(Double, Double)])): + // Author, 
AvgLat, AvgLng, AvgDistance, DistanceVariance, MaxDistance + (String, Double, Double, Double, Double, Double) = { + val (author, latlngs_i) = inpt + val latlngs = latlngs_i.toSeq + val lats = latlngs.map(_._1) + val lngs = latlngs.map(_._2) + + val avgpoint = SphereCoord(lats.sum / lats.length, lngs.sum / lngs.length) + val allpoints = latlngs.map(ll => SphereCoord(ll._1, ll._2)) + val distances = allpoints.map(spheredist(_, avgpoint)) + val avgdistance = distances.sum / distances.length + val distancevariance = distances.map(x => pow(x - avgdistance, 2)).sum / distances.length + + val maxdistance = cartesian_product(allpoints, allpoints) + .map{case (a, b) => spheredist(a, b)}.max + + (author, avgpoint.lat, avgpoint.long, avgdistance, distancevariance, maxdistance) + } + + def nicely_format(r: (String, Double, Double, Double, Double, Double)): String = { + val (a, b, c, d, e, f) = r + Seq(a, b, c, d, e, f) mkString "\t" + } + + def checkpoint_str(r: Record): String = { + val (a, (lat, lng)) = r + a + "\t" + lat + "\t" + lng + } + + def from_checkpoint_to_record(s: String): Record = { + val s_a = s.split("\t") + (s_a(0), (s_a(1).toDouble, s_a(2).toDouble)) + } + + def run() { + val (inputPath, outputPath) = + if (args.length == 2) { + (args(0), args(1)) + } else { + sys.error("Expecting input and output path.") + } + + /* + val lines: DList[String] = TextInput.fromTextFile(inputPath) + + val values_extracted = lines.map(parse_json).filter(is_valid_tweet) + val single_tweets = values_extracted.groupByKey.map(tweet_once) + .filter(has_latlng) + + val checkpointed = single_tweets.map(checkpoint_str) + persist(TextOutput.toTextFile(checkpointed, inputPath + "-st")) + */ + + val single_tweets_lines: DList[String] = TextInput.fromTextFile(inputPath + "-st") + val single_tweets_reloaded = single_tweets_lines.map(from_checkpoint_to_record) + val grouped_by_author = single_tweets_reloaded.groupByKey + + val averaged = grouped_by_author.map(mean_variance_and_maxdistance) + + val nicely_formatted = averaged.map(nicely_format) + persist(TextOutput.toTextFile(nicely_formatted, outputPath)) + } +} + diff --git a/src/main/scala/opennlp/fieldspring/tr/app/ConvertCorpusToPlaintext.scala b/src/main/scala/opennlp/fieldspring/tr/app/ConvertCorpusToPlaintext.scala new file mode 100644 index 0000000..3d9021a --- /dev/null +++ b/src/main/scala/opennlp/fieldspring/tr/app/ConvertCorpusToPlaintext.scala @@ -0,0 +1,36 @@ +package opennlp.fieldspring.tr.app + +import java.io._ + +import opennlp.fieldspring.tr.topo._ +import opennlp.fieldspring.tr.text._ +import opennlp.fieldspring.tr.text.prep._ +import opennlp.fieldspring.tr.text.io._ + +import scala.collection.JavaConversions._ + +object ConvertCorpusToPlaintext extends App { + + val outDirName = if(args(1).endsWith("/")) args(1) else args(1)+"/" + val outDir = new File(outDirName) + if(!outDir.exists) + outDir.mkdir + + val tokenizer = new OpenNLPTokenizer + + val corpus = Corpus.createStoredCorpus + corpus.addSource(new TrXMLDirSource(new File(args(0)), tokenizer)) + corpus.setFormat(BaseApp.CORPUS_FORMAT.TRCONLL) + corpus.load + + for(doc <- corpus) { + val out = new BufferedWriter(new FileWriter(outDirName+doc.getId+".txt")) + for(sent <- doc) { + for(token <- sent) { + out.write(token.getForm+" ") + } + out.write("\n") + } + out.close + } +} diff --git a/src/main/scala/opennlp/fieldspring/tr/app/ConvertCorpusToToponymAsDoc.scala b/src/main/scala/opennlp/fieldspring/tr/app/ConvertCorpusToToponymAsDoc.scala new file mode 100644 index 0000000..bcdfa32 --- 
/dev/null +++ b/src/main/scala/opennlp/fieldspring/tr/app/ConvertCorpusToToponymAsDoc.scala @@ -0,0 +1,65 @@ +package opennlp.fieldspring.tr.app + +import java.io._ + +import opennlp.fieldspring.tr.topo._ +import opennlp.fieldspring.tr.text._ +import opennlp.fieldspring.tr.text.prep._ +import opennlp.fieldspring.tr.text.io._ +import opennlp.fieldspring.tr.util._ + +import scala.collection.JavaConversions._ + +object ConvertCorpusToToponymAsDoc extends App { + + val windowSize = if(args.length >= 2) args(1).toInt else 0 + + val alphanumRE = """^[a-zA-Z0-9]+$""".r + + val tokenizer = new OpenNLPTokenizer + + val corpus = Corpus.createStoredCorpus + corpus.addSource(new TrXMLDirSource(new File(args(0)), tokenizer)) + corpus.setFormat(BaseApp.CORPUS_FORMAT.TRCONLL) + corpus.load + + for(doc <- corpus) { + val docAsArray = TextUtil.getDocAsArray(doc) + var tokIndex = 0 + for(token <- docAsArray) { + if(token.isToponym && token.asInstanceOf[Toponym].hasGold) { + val goldCoord = token.asInstanceOf[Toponym].getGold.getRegion.getCenter + + val unigramCounts = getUnigramCounts(docAsArray, tokIndex, windowSize) + + print(doc.getId.drop(1)+"_"+tokIndex+"\t") + print(doc.getId+"_"+tokIndex+"\t") + print(goldCoord.getLatDegrees+","+goldCoord.getLngDegrees+"\t") + print("1\t\tMain\tno\tno\tno\t") + //print(token.getForm+":"+1+" ")\ + for((word, count) <- unigramCounts) { + print(word+":"+count+" ") + } + println + } + tokIndex += 1 + } + } + + def getUnigramCounts(docAsArray:Array[StoredToken], tokIndex:Int, windowSize:Int): Map[String, Int] = { + + val startIndex = math.max(0, tokIndex - windowSize) + val endIndex = math.min(docAsArray.length, tokIndex + windowSize + 1) + + val unigramCounts = new collection.mutable.HashMap[String, Int] + + for(rawToken <- docAsArray.slice(startIndex, endIndex)) { + for(token <- rawToken.getForm.split(" ")) { + val prevCount = unigramCounts.getOrElse(token, 0) + unigramCounts.put(token, prevCount + 1) + } + } + + unigramCounts.toMap + } +} diff --git a/src/main/scala/opennlp/fieldspring/tr/app/ConvertCorpusToUnigramCounts.scala b/src/main/scala/opennlp/fieldspring/tr/app/ConvertCorpusToUnigramCounts.scala new file mode 100644 index 0000000..b47bc07 --- /dev/null +++ b/src/main/scala/opennlp/fieldspring/tr/app/ConvertCorpusToUnigramCounts.scala @@ -0,0 +1,84 @@ +package opennlp.fieldspring.tr.app + +import java.io._ +import java.util.zip._ + +import opennlp.fieldspring.tr.util._ +import opennlp.fieldspring.tr.topo._ +import opennlp.fieldspring.tr.topo.gaz._ +import opennlp.fieldspring.tr.text._ +import opennlp.fieldspring.tr.text.prep._ +import opennlp.fieldspring.tr.text.io._ + +import scala.collection.JavaConversions._ + +object ConvertCorpusToUnigramCounts extends BaseApp { + + val alphanumRE = """^[a-z0-9]+$""".r + + //val tokenizer = new OpenNLPTokenizer + + def main(args:Array[String]) { + + initializeOptionsFromCommandLine(args); + + /*var corpus = Corpus.createStoredCorpus + + if(getCorpusFormat == BaseApp.CORPUS_FORMAT.PLAIN/**/) { + /* + val tokenizer = new OpenNLPTokenizer + //val recognizer = new OpenNLPRecognizer + //val gis = new GZIPInputStream(new FileInputStream(args(1))) + //val ois = new ObjectInputStream(gis) + //val gnGaz = ois.readObject.asInstanceOf[GeoNamesGazetteer] + //gis.close + corpus.addSource(new PlainTextSource( + new BufferedReader(new FileReader(args(0))), new OpenNLPSentenceDivider(), tokenizer)) + //corpus.addSource(new ToponymAnnotator(new PlainTextSource( + // new BufferedReader(new FileReader(args(0))), new OpenNLPSentenceDivider(), 
tokenizer), + // recognizer, gnGaz, null)) + corpus.setFormat(BaseApp.CORPUS_FORMAT.PLAIN) + */ + val importCorpus = new ImportCorpus + //if(args(0).endsWith("txt")) + corpus = importCorpus.doImport(getCorpusInputPath, , getCorpusFormat, false) + //else + // corpus = importCorpus + } + else if(getCorpusFormat == BaseApp.CORPUS_FORMAT.TRCONLL) { + corpus.addSource(new TrXMLDirSource(new File(args(0)), tokenizer)) + corpus.setFormat(BaseApp.CORPUS_FORMAT.TRCONLL) + corpus.load + } + //corpus.load*/ + + val corpus = TopoUtil.readStoredCorpusFromSerialized(getSerializedCorpusInputPath) + + var i = 0 + for(doc <- corpus) { + val unigramCounts = new collection.mutable.HashMap[String, Int] + for(sent <- doc) { + for(rawToken <- sent) { + for(token <- rawToken.getForm.split(" ")) { + val ltoken = token.toLowerCase + if(alphanumRE.findFirstIn(ltoken) != None) { + val prevCount = unigramCounts.getOrElse(ltoken, 0) + unigramCounts.put(ltoken, prevCount + 1) + } + } + } + } + + print(i/*doc.getId.drop(1)*/ +"\t") + print(doc.getId+"\t") + print("0,0\t") + print("1\t\tMain\tno\tno\tno\t") + for((word, count) <- unigramCounts) { + print(word+":"+count+" ") + } + println + i += 1 + } + + } +} diff --git a/src/main/scala/opennlp/fieldspring/tr/app/ConvertCwarToGoldCorpus.scala b/src/main/scala/opennlp/fieldspring/tr/app/ConvertCwarToGoldCorpus.scala new file mode 100644 index 0000000..7df81e0 --- /dev/null +++ b/src/main/scala/opennlp/fieldspring/tr/app/ConvertCwarToGoldCorpus.scala @@ -0,0 +1,135 @@ +package opennlp.fieldspring.tr.app + +import opennlp.fieldspring.tr.topo._ +import opennlp.fieldspring.tr.util._ +import opennlp.fieldspring.tr.topo.gaz._ +import opennlp.fieldspring.tr.text.io._ +import java.util.zip._ + +import java.io._ + +import scala.collection.JavaConversions._ + +object ConvertCwarToGoldCorpus extends App { + + val digitsRE = """^\d+$""".r + val floatRE = """^-?\d*\.\d*$""".r + val toponymRE = """^tgn,(\d+)-(.+)-]][^\s]*$""".r + + val corpusFiles = new File(args(0)).listFiles + val goldKml = scala.xml.XML.loadFile(args(1)) + val gazIn = args(2) + + + val tgnToCoord = + //(goldKml \\ "Placemark").foreach { placemark => + (for(placemark <- (goldKml \\ "Placemark")) yield { + var tgn = -1 + (placemark \\ "Data").foreach { data => + val name = (data \ "@name").text + if(name.equals("tgn")) { + val text = data.text.trim + if(digitsRE.findFirstIn(text) != None) + tgn = text.toInt + } + } + var coordsRaw = "" + (placemark \\ "coordinates").foreach { coordinates => + coordsRaw = coordinates.text.trim.dropRight(2) + } + + val coordsSplit = coordsRaw.split(",") + + if(tgn == -1 || coordsSplit.size != 2 || floatRE.findFirstIn(coordsSplit(0)) == None + || floatRE.findFirstIn(coordsSplit(1)) == None) + None + else { + val lng = coordsSplit(0).toDouble + val lat = coordsSplit(1).toDouble + Some((tgn, Coordinate.fromDegrees(lat, lng))) + } + }).flatten.toMap + + //tgnToCoord.foreach(println) + + + var gaz:GeoNamesGazetteer = null; + //println("Reading serialized GeoNames gazetteer from " + gazIn + " ...") + var ois:ObjectInputStream = null; + if(gazIn.toLowerCase().endsWith(".gz")) { + val gis = new GZIPInputStream(new FileInputStream(gazIn)) + ois = new ObjectInputStream(gis) + } + else { + val fis = new FileInputStream(gazIn) + ois = new ObjectInputStream(fis) + } + gaz = ois.readObject.asInstanceOf[GeoNamesGazetteer] + + println("") + println("") + + for(file <- corpusFiles) { + println(" ") + + for(line <- scala.io.Source.fromFile(file).getLines.map(_.trim).filter(_.length > 0)) { + println(" ") 
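+      // Each whitespace-separated token is either a tgn-annotated toponym
+      // (matched by toponymRE, whose gazetteer candidates are then compared
+      // against the gold TGN coordinate) or an ordinary word token.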
+ for(token <- line.split(" ")) { + if(toponymRE.findFirstIn(token) != None) { + val toponymRE(tgnRaw, formRaw) = token + val form = formRaw.replaceAll("-", " ") + val candidates = gaz.lookup(form.toLowerCase) + val goldCoord = tgnToCoord.getOrElse(tgnRaw.toInt, null) + if(candidates == null || goldCoord == null) { + for(tok <- form.split(" ").filter(t => CorpusXMLWriter.isSanitary(t))) + println(" ") + } + else { + var matchingCand:Location = null + for(cand <- candidates) { + if(cand.getRegion.getCenter.distanceInMi(goldCoord) < 5.0) { + matchingCand = cand + } + } + if(matchingCand == null) { + for(tok <- form.split(" ").filter(t => CorpusXMLWriter.isSanitary(t))) + println(" ") + } + //val formToWrite = if(CorpusXMLWriter.isSanitary(form)) form else "MALFORMED" + else if(CorpusXMLWriter.isSanitary(form)) { + println(" ") + println(" ") + for(cand <- candidates) { + val region = cand.getRegion + val center = region.getCenter + print(" ") + /*println(" ") + for(rep <- region.getRepresentatives) { + println(" ") + } + println(" ")*/ + } + println(" ") + println(" ") + } + } + } + else { + val strippedToken = TextUtil.stripPunc(token) + if(CorpusXMLWriter.isSanitary(/*BaseApp.CORPUS_FORMAT.PLAIN, */strippedToken)) + println(" ") + } + } + println(" ") + } + + println(" ") + } + + println("") +} diff --git a/src/main/scala/opennlp/fieldspring/tr/app/ConvertGeoTextToJSON.scala b/src/main/scala/opennlp/fieldspring/tr/app/ConvertGeoTextToJSON.scala new file mode 100644 index 0000000..f8b51ba --- /dev/null +++ b/src/main/scala/opennlp/fieldspring/tr/app/ConvertGeoTextToJSON.scala @@ -0,0 +1,12 @@ +package opennlp.fieldspring.tr.app + +import com.codahale.jerkson.Json._ + +object ConvertGeoTextToJSON extends App { + for(line <- scala.io.Source.fromFile(args(0), "ISO-8859-1").getLines) { + val tokens = line.split("\t") + println(generate(new tweet(tokens(3).toDouble, tokens(4).toDouble, tokens(5)))) + } +} + +case class tweet(val lat:Double, val lon:Double, val text:String) diff --git a/src/main/scala/opennlp/fieldspring/tr/app/CorpusErrorAnalyzer.scala b/src/main/scala/opennlp/fieldspring/tr/app/CorpusErrorAnalyzer.scala new file mode 100644 index 0000000..dbcc56f --- /dev/null +++ b/src/main/scala/opennlp/fieldspring/tr/app/CorpusErrorAnalyzer.scala @@ -0,0 +1,30 @@ +package opennlp.fieldspring.tr.app + +import java.io._ +import java.util.zip._ + +import opennlp.fieldspring.tr.util._ +import opennlp.fieldspring.tr.topo._ +import opennlp.fieldspring.tr.text._ +import opennlp.fieldspring.tr.text.prep._ +import opennlp.fieldspring.tr.topo.gaz._ +import opennlp.fieldspring.tr.text.io._ + +import scala.collection.JavaConversions._ + +object CorpusErrorAnalyzer extends BaseApp { + + def main(args:Array[String]) { + initializeOptionsFromCommandLine(args) + + val corpus = TopoUtil.readStoredCorpusFromSerialized(getSerializedCorpusInputPath) + + for(doc <- corpus) { + for(sent <- doc) { + for(toponym <- sent.getToponyms.filter(_.getAmbiguity > 0)) { + + } + } + } + } +} diff --git a/src/main/scala/opennlp/fieldspring/tr/app/CorpusInfo.scala b/src/main/scala/opennlp/fieldspring/tr/app/CorpusInfo.scala new file mode 100644 index 0000000..c4dc71b --- /dev/null +++ b/src/main/scala/opennlp/fieldspring/tr/app/CorpusInfo.scala @@ -0,0 +1,117 @@ +package opennlp.fieldspring.tr.app + +import java.io._ +import java.util.zip._ + +import opennlp.fieldspring.tr.util._ +import opennlp.fieldspring.tr.topo._ +import opennlp.fieldspring.tr.text._ +import opennlp.fieldspring.tr.text.prep._ +import 
opennlp.fieldspring.tr.topo.gaz._ +import opennlp.fieldspring.tr.text.io._ + +import scala.collection.JavaConversions._ + +object CorpusInfo { + + val DPC = 1.0 + + def getCorpusInfo(filename: String/* useNER:Boolean = false*/): collection.mutable.HashMap[String, collection.mutable.HashMap[Int, Int]] = { + + val corpus = if(filename.endsWith(".ser.gz")) TopoUtil.readStoredCorpusFromSerialized(filename) + else Corpus.createStoredCorpus + + /*if(useNER) { + + //System.out.println("Reading serialized GeoNames gazetteer from " + gazPath + " ...") + + val gis = new GZIPInputStream(new FileInputStream(gazPath)) + val ois = new ObjectInputStream(gis) + + val gnGaz = ois.readObject.asInstanceOf[GeoNamesGazetteer] + + corpus.addSource(new ToponymAnnotator(new ToponymRemover(new TrXMLDirSource(new File(filename), new OpenNLPTokenizer)), new OpenNLPRecognizer, gnGaz, null)) + }*/ + if(!filename.endsWith(".ser.gz")) { + corpus.addSource(new TrXMLDirSource(new File(filename), new OpenNLPTokenizer)) + corpus.setFormat(BaseApp.CORPUS_FORMAT.TRCONLL) + corpus.load + } + + + val topsToCellCounts = new collection.mutable.HashMap[String, collection.mutable.HashMap[Int, Int]] + + if(filename.endsWith(".ser.gz")) { + for(doc <- corpus) { + for(sent <- doc) { + for(toponym <- sent.getToponyms.filter(_.getAmbiguity > 0)/*.filter(_.hasGold)*/) { + val cellCounts = topsToCellCounts.getOrElse(toponym.getForm, new collection.mutable.HashMap[Int, Int]) + val cellNum = if(toponym.hasGold) TopoUtil.getCellNumber(toponym.getGold.getRegion.getCenter, DPC) else 0 + if(cellNum != -1) { + val prevCount = cellCounts.getOrElse(cellNum, 0) + cellCounts.put(cellNum, prevCount + 1) + topsToCellCounts.put(toponym.getForm, cellCounts) + } + } + } + } + } + + else { + for(doc <- corpus) { + for(sent <- doc) { + for(toponym <- sent.getToponyms.filter(_.getAmbiguity > 0).filter(_.hasGold)) { + val cellCounts = topsToCellCounts.getOrElse(toponym.getForm, new collection.mutable.HashMap[Int, Int]) + val cellNum = TopoUtil.getCellNumber(toponym.getGold.getRegion.getCenter, DPC) + if(cellNum != -1) { + val prevCount = cellCounts.getOrElse(cellNum, 0) + cellCounts.put(cellNum, prevCount + 1) + topsToCellCounts.put(toponym.getForm, cellCounts) + } + } + } + } + } + + + topsToCellCounts + } + + def printCorpusInfo(topsToCellCounts: collection.mutable.HashMap[String, collection.mutable.HashMap[Int, Int]]) { + + for((topForm, cellCounts) <- topsToCellCounts.toList.sortWith((x, y) => if(x._2.size != y._2.size) + x._2.size > y._2.size + else x._1 < y._1) ) { + print(topForm+": ") + cellCounts.toList.sortWith((x, y) => if(x._2 != y._2) x._2 > y._2 else x._1 < y._1) + .foreach(p => print("["+TopoUtil.getCellCenter(p._1, DPC)+":"+p._2+"] ")) + println + } + } + + def printCollapsedCorpusInfo(topsToCellCounts: collection.mutable.HashMap[String, collection.mutable.HashMap[Int, Int]]) { + + topsToCellCounts.toList.map(p => (p._1, p._2.toList.map(q => q._2).sum)).sortWith((x, y) => if(x._2 != y._2) + x._2 > y._2 + else x._1 < y._1). 
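+      // Prints one line per toponym form -- e.g. "London 37" -- i.e. the form
+      // followed by its total count summed over cells, most frequent first.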
+ foreach(p => println(p._1+" "+p._2)) + + + /*for((topForm, cellCounts) <- topsToCellCounts.toList.sortWith((x, y) => if(x._2.size != y._2.size) + x._2.size > y._2.size + else x._1 < y._1) ) { + print(topForm+": ") + cellCounts.toList.sortWith((x, y) => if(x._2 != y._2) x._2 > y._2 else x._1 < y._1) + .foreach(p => print("["+TopoUtil.getCellCenter(p._1, DPC)+":"+p._2+"] ")) + println + }*/ + } + + def main(args:Array[String]) = { + //if(args.length >= 2) + printCollapsedCorpusInfo(getCorpusInfo(args(0))) + //else + // printCollapsedCorpusInfo(getCorpusInfo(args(0))) + } + +} diff --git a/src/main/scala/opennlp/fieldspring/tr/app/FilterGeotaggedWiki.scala b/src/main/scala/opennlp/fieldspring/tr/app/FilterGeotaggedWiki.scala new file mode 100644 index 0000000..5f30a03 --- /dev/null +++ b/src/main/scala/opennlp/fieldspring/tr/app/FilterGeotaggedWiki.scala @@ -0,0 +1,64 @@ +package opennlp.fieldspring.tr.app + +import java.io._ +import java.util.zip._ + +import opennlp.fieldspring.tr.util._ +import opennlp.fieldspring.tr.topo._ +import opennlp.fieldspring.tr.topo.gaz._ +import opennlp.fieldspring.tr.text._ +import opennlp.fieldspring.tr.text.prep._ +import opennlp.fieldspring.tr.text.io._ + +import scala.collection.JavaConversions._ + +import org.apache.commons.compress.compressors.bzip2._ +import org.clapper.argot._ +import ArgotConverters._ + +object FilterGeotaggedWiki extends App { + val parser = new ArgotParser("fieldspring run opennlp.fieldspring.tr.app.FilterGeotaggedWiki", preUsage = Some("Fieldspring")) + + val wikiTextInputFile = parser.option[String](List("w", "wiki"), "wiki", "wiki text input file") + val wikiCorpusInputFile = parser.option[String](List("c", "corpus"), "corpus", "wiki corpus input file") + + try { + parser.parse(args) + } + catch { + case e: ArgotUsageException => println(e.message); sys.exit(0) + } + + val ids = new collection.mutable.HashSet[String] + + val fis = new FileInputStream(wikiCorpusInputFile.value.get) + fis.read; fis.read + val cbzis = new BZip2CompressorInputStream(fis) + val in = new BufferedReader(new InputStreamReader(cbzis)) + var curLine = in.readLine + while(curLine != null) { + ids += curLine.split("\t")(0) + curLine = in.readLine + } + in.close + + val wikiTextCorpus = Corpus.createStreamCorpus + + wikiTextCorpus.addSource(new WikiTextSource(new BufferedReader(new FileReader(wikiTextInputFile.value.get)))) + wikiTextCorpus.setFormat(BaseApp.CORPUS_FORMAT.WIKITEXT) + + for(doc <- wikiTextCorpus) { + if(ids contains doc.getId) { + println("Article title: " + doc.title) + println("Article ID: " + doc.getId) + for(sent <- doc) { + for(token <- sent) { + println(token.getOrigForm) + } + } + } + else { + for(sent <- doc) { for(token <- sent) {} } + } + } +} diff --git a/src/main/scala/opennlp/fieldspring/tr/app/GazEntryKMLPlotter.scala b/src/main/scala/opennlp/fieldspring/tr/app/GazEntryKMLPlotter.scala new file mode 100644 index 0000000..d48c2c1 --- /dev/null +++ b/src/main/scala/opennlp/fieldspring/tr/app/GazEntryKMLPlotter.scala @@ -0,0 +1,55 @@ +package opennlp.fieldspring.tr.app + +import java.io._ +import java.util.zip._ + +import opennlp.fieldspring.tr.util._ +import opennlp.fieldspring.tr.topo._ +import opennlp.fieldspring.tr.topo.gaz._ +import opennlp.fieldspring.tr.text._ +import opennlp.fieldspring.tr.text.prep._ +import opennlp.fieldspring.tr.text.io._ + +import scala.collection.JavaConversions._ + +object GazEntryKMLPlotter /*extends BaseApp*/ { + + def main(args:Array[String]) { + + val toponym = args(0).replaceAll("_", " ") + //val 
gaz = println("Reading serialized gazetteer from " + args(1) + " ...") + val gis = new GZIPInputStream(new FileInputStream(args(1))) + val ois = new ObjectInputStream(gis) + val gnGaz = ois.readObject.asInstanceOf[GeoNamesGazetteer] + gis.close + + val entries = gnGaz.lookup(toponym) + if(entries != null) { + var loc = entries(0) + for(entry <- entries) + if(entry.getRegion.getRepresentatives.size > 1) + loc = entry + if(loc != null) + for(coord <- loc.getRegion.getRepresentatives) { + println("") + println("#My_Style") + println("") + println(""+coord.getLngDegrees+","+coord.getLatDegrees+",0") + println("") + println("") + } + } + + /*initializeOptionsFromCommandLine(args) + + val corpus = TopoUtil.readStoredCorpusFromSerialized(getSerializedCorpusInputPath) + + for(doc <- corpus) { + for(sent <- doc) { + for(toponym <- sent.getToponyms.filter(_.getAmbiguity > 0)) { + + } + } + }*/ + } +} diff --git a/src/main/scala/opennlp/fieldspring/tr/app/GeoTextLabelProp.scala b/src/main/scala/opennlp/fieldspring/tr/app/GeoTextLabelProp.scala new file mode 100644 index 0000000..1682084 --- /dev/null +++ b/src/main/scala/opennlp/fieldspring/tr/app/GeoTextLabelProp.scala @@ -0,0 +1,125 @@ +package opennlp.fieldspring.tr.app + +import java.io._ + +import opennlp.fieldspring.tr.text._ +import opennlp.fieldspring.tr.text.io._ +import opennlp.fieldspring.tr.text.prep._ +import opennlp.fieldspring.tr.topo._ +import opennlp.fieldspring.tr.app._ +import opennlp.fieldspring.tr.util.TopoUtil + +import upenn.junto.app._ +import upenn.junto.config._ + +import gnu.trove._ + +import scala.collection.JavaConversions._ + +object GeoTextLabelProp extends BaseApp { + + import BaseApp._ + + val MIN_COUNT_THRESHOLD = 5 + + val DPC = 1.0 + + val CELL_ = "cell_" + val CELL_LABEL_ = "cell_label_" + val DOC_ = "doc_" + val UNI_ = "uni_" + val BI_ = "bi_" + val NGRAM_ = "ngram_" + val USER = "USER" + val USER_ = "USER_" + + //val nodeRE = """(.+)_([^_]+)""".r SAVE AND COMPILE + + def main(args: Array[String]) { + + this.initializeOptionsFromCommandLine(args) + checkExists(getSerializedCorpusInputPath) + + val stoplist:Set[String] = + if(getStoplistInputPath != null) scala.io.Source.fromFile(getStoplistInputPath).getLines.toSet + else Set() + + val corpus = TopoUtil.readStoredCorpusFromSerialized(getSerializedCorpusInputPath) + + val graph = createGraph(corpus, stoplist) + + JuntoRunner(graph, 1.0, .01, .01, getNumIterations, false) + + val docIdsToCells = new collection.mutable.HashMap[String, Int] + + for ((id, vertex) <- graph._vertices) { + //val nodeRE(nodeType,nodeId) = id + + //if(nodeType.equals(USER)) + if(id.startsWith(USER_)) + docIdsToCells.put(id, getGreatestCell(vertex.GetEstimatedLabelScores)) + } + + for(doc <- corpus.filter(d => (d.isDev || d.isTest) && docIdsToCells.containsKey(d.getId))) { + val cellNumber = docIdsToCells(doc.getId) + if(cellNumber != -1) { + val lat = ((cellNumber / 1000) * DPC) + DPC/2.0 + val lon = ((cellNumber % 1000) * DPC) + DPC/2.0 + doc.setSystemCoord(Coordinate.fromDegrees(lat, lon)) + } + } + + val eval = new EvaluateCorpus + eval.doEval(corpus, corpus, CORPUS_FORMAT.GEOTEXT, true) + } + + def getGreatestCell(estimatedLabelScores: TObjectDoubleHashMap[String]): Int = { + + estimatedLabelScores.keys(Array[String]()).filter(_.startsWith(CELL_LABEL_)).maxBy(estimatedLabelScores.get(_)).substring(CELL_LABEL_.length).toInt + + } + + def createGraph(corpus: StoredCorpus, stoplist: Set[String]) = { + val edges = getDocCellEdges(corpus) ::: getNgramDocEdges(corpus, stoplist) + val seeds = 
getCellSeeds(corpus) + GraphBuilder(edges, seeds) + } + + def getCellSeeds(corpus: StoredCorpus): List[Label] = { + (for (doc <- corpus.filter(_.isTrain)) yield { + val cellNumber = TopoUtil.getCellNumber(doc.getGoldCoord, DPC) + new Label(CELL_ + cellNumber, CELL_LABEL_ + cellNumber, 1.0) + }).toList + } + + def getDocCellEdges(corpus: StoredCorpus): List[Edge] = { + (corpus.filter(_.isTrain).map(doc => new Edge(doc.getId, CELL_ + TopoUtil.getCellNumber(doc.getGoldCoord, DPC), 1.0))).toList + } + + def getNgramDocEdges(corpus: StoredCorpus, stoplist: Set[String]): List[Edge] = { + + val ngramsToCounts = new collection.mutable.HashMap[String, Int] { override def default(s: String) = 0 } + val docIdsToNgrams = new collection.mutable.HashMap[String, collection.mutable.HashSet[String]] { + override def default(s: String) = new collection.mutable.HashSet + } + + for(doc <- corpus) { + for(sent <- doc) { + + val unigrams = (for(token <- sent) yield token.getForm).toList + val bigrams = if(unigrams.length >= 2) (for(bi <- unigrams.sliding(2)) yield bi(0)+" "+bi(1)).toList else Nil + val filteredUnigrams = unigrams.filterNot(stoplist.contains(_)) + + for(ngram <- filteredUnigrams ::: bigrams) { + ngramsToCounts.put(ngram, ngramsToCounts(ngram) + 1) + docIdsToNgrams.put(doc.getId, docIdsToNgrams(doc.getId) + ngram) + } + } + } + + (for(docId <- docIdsToNgrams.keys) yield + (for(ngram <- docIdsToNgrams(docId).filter(ngramsToCounts(_) >= MIN_COUNT_THRESHOLD)) yield + new Edge(NGRAM_ + ngram, docId, 1.0)).toList).toList.flatten + } + +} diff --git a/src/main/scala/opennlp/fieldspring/tr/app/GeoTextLabelPropDecoder.scala b/src/main/scala/opennlp/fieldspring/tr/app/GeoTextLabelPropDecoder.scala new file mode 100644 index 0000000..4576df1 --- /dev/null +++ b/src/main/scala/opennlp/fieldspring/tr/app/GeoTextLabelPropDecoder.scala @@ -0,0 +1,85 @@ +package opennlp.fieldspring.tr.app + +import java.io._ +import java.util._ + +import opennlp.fieldspring.tr.text._ +import opennlp.fieldspring.tr.text.io._ +import opennlp.fieldspring.tr.text.prep._ +import opennlp.fieldspring.tr.topo._ +import opennlp.fieldspring.tr.app._ +import opennlp.fieldspring.tr.util.TopoUtil + +import scala.collection.JavaConversions._ + +object GeoTextLabelPropDecoder extends BaseApp { + + import BaseApp._ + + def DPC = 1.0 + + def CELL_ = "cell_" + def CELL_LABEL_ = "cell_label_" + //def DOC_ = "doc_" + def USER_ = "USER_" + def UNI_ = "uni_" + def BI_ = "bi_" + + def main(args: Array[String]) = { + + this.initializeOptionsFromCommandLine(args) + this.doDecode + + } + + def doDecode() = { + checkExists(getSerializedCorpusInputPath) + checkExists(getGraphInputPath) + + val corpus = TopoUtil.readStoredCorpusFromSerialized(getSerializedCorpusInputPath) + + val docIdsToCells = new collection.mutable.HashMap[String, Int] + + val lines = scala.io.Source.fromFile(getGraphInputPath).getLines + + for(line <- lines) { + val tokens = line.split("\t") + + if(tokens.length >= 4 && tokens(0).startsWith(USER_)) { + val docId = tokens(0) + + val innertokens = tokens(3).split(" ") + + docIdsToCells.put(docId, findGreatestCell(innertokens)) + } + } + + for(document <- corpus) { + if(document.isDev || document.isTest) { + if(docIdsToCells.containsKey(document.getId)) { + val cellNumber = docIdsToCells(document.getId) + if(cellNumber != -1) { + val lat = ((cellNumber / 1000) * DPC) + DPC/2.0 + val lon = ((cellNumber % 1000) * DPC) + DPC/2.0 + document.setSystemCoord(Coordinate.fromDegrees(lat, lon)) + } + } + } + } + + val eval = new EvaluateCorpus + 
eval.doEval(corpus, corpus, CORPUS_FORMAT.GEOTEXT, true) + } + + def findGreatestCell(innertokens: Array[String]): Int = { + + for(innertoken <- innertokens) { + if(innertoken.startsWith(CELL_LABEL_)) { + return innertoken.substring(CELL_LABEL_.length).toInt + } + } + + return -1 + } + +} diff --git a/src/main/scala/opennlp/fieldspring/tr/app/GeoTextLabelPropPreproc.scala b/src/main/scala/opennlp/fieldspring/tr/app/GeoTextLabelPropPreproc.scala new file mode 100644 index 0000000..26d5d4c --- /dev/null +++ b/src/main/scala/opennlp/fieldspring/tr/app/GeoTextLabelPropPreproc.scala @@ -0,0 +1,179 @@ +package opennlp.fieldspring.tr.app + +import java.io._ +import java.util._ + +import opennlp.fieldspring.tr.text._ +import opennlp.fieldspring.tr.text.io._ +import opennlp.fieldspring.tr.text.prep._ +import opennlp.fieldspring.tr.app._ +import opennlp.fieldspring.tr.util.StopwordUtil +import opennlp.fieldspring.tr.util.TopoUtil + +import scala.collection.JavaConversions._ + +object GeoTextLabelPropPreproc extends BaseApp { + + import BaseApp._ + + def MIN_COUNT_THRESHOLD = 5 + + def DPC = 1.0 + + def CELL_ = "cell_" + def CELL_LABEL_ = "cell_label_" + def DOC_ = "doc_" + def UNI_ = "uni_" + def BI_ = "bi_" + + def main(args: Array[String]) { + + this.initializeOptionsFromCommandLine(args) + this.doPreproc + + } + + def doPreproc() { + checkExists(getSerializedCorpusInputPath) + + val stoplist:Set[String] = + if(getStoplistInputPath != null) StopwordUtil.populateStoplist(getStoplistInputPath) + else collection.immutable.Set[String]() + + val corpus = TopoUtil.readStoredCorpusFromSerialized(getSerializedCorpusInputPath) + + writeCellSeeds(corpus, getSeedOutputPath) + writeDocCellEdges(corpus, getGraphOutputPath) + writeNGramDocEdges(corpus, getGraphOutputPath, stoplist) + + //writeCellCellEdges(getGraphOutputPath) + } + + def writeCellSeeds(corpus: StoredCorpus, seedOutputPath: String) = { + val out = new FileWriter(seedOutputPath) + + val edges = (for (doc <- corpus.filter(_.isTrain)) yield { + val cellNumber = TopoUtil.getCellNumber(doc.getGoldCoord, DPC) + Tuple3(CELL_ + cellNumber, CELL_LABEL_ + cellNumber, 1.0) + }) + + edges.map(writeEdge(out, _)) + + out.close + } + + def writeDocCellEdges(corpus: StoredCorpus, graphOutputPath: String) = { + val out = new FileWriter(graphOutputPath) + + val edges = corpus.filter(_.isTrain).map(doc => Tuple3(doc.getId, CELL_ + TopoUtil.getCellNumber(doc.getGoldCoord, DPC), 1.0)) + + edges.map(writeEdge(out, _)) + + out.close + } + + def writeNGramDocEdges(corpus: StoredCorpus, graphOutputPath: String, stoplist: Set[String]) = { + + val docIdsToNGrams = new collection.mutable.HashMap[String, collection.mutable.HashSet[String]] + val unigramsToCounts = new collection.mutable.HashMap[String, Int] + val bigramsToCounts = new collection.mutable.HashMap[String, Int] + + for(document <- corpus) { + docIdsToNGrams.put(document.getId, new collection.mutable.HashSet[String]()) + for(sentence <- document) { + var prevUni:String = null + for(token <- sentence) { + val unigram = token.getForm + if(!stoplist.contains(unigram)) { + docIdsToNGrams.get(document.getId).get += UNI_ + unigram + if(!unigramsToCounts.contains(unigram)) { + unigramsToCounts.put(unigram, 1) + //println("saw " + unigram + " for the first time.") + } + else { + unigramsToCounts.put(unigram, unigramsToCounts.get(unigram).get + 1) + //println("saw " + unigram + " for the " + (unigramsToCounts.get(unigram).get + 1) + " time.") + } + } + if(prevUni != null) { + val bigram = prevUni + " " + unigram + 
docIdsToNGrams.get(document.getId).get += BI_ + bigram + if(!bigramsToCounts.contains(bigram)) + bigramsToCounts.put(bigram, 1) + else + bigramsToCounts.put(bigram, bigramsToCounts.get(bigram).get + 1) + } + prevUni = unigram + } + } + } + + val out = new FileWriter(graphOutputPath, true) + + for(docId <- docIdsToNGrams.keys) { + for(ngram <- docIdsToNGrams.get(docId).get) { + val shortNGram = + if(ngram.startsWith(BI_)) ngram.substring(BI_.length) + else if(ngram.startsWith(UNI_)) ngram.substring(UNI_.length) + else ngram + if((unigramsToCounts.contains(shortNGram) && unigramsToCounts.get(shortNGram).get >= MIN_COUNT_THRESHOLD) + || (bigramsToCounts.contains(shortNGram) && bigramsToCounts.get(shortNGram).get >= MIN_COUNT_THRESHOLD)) { + writeEdge(out, ngram, docId, 1.0) + //println("writing " + ngram + " " + docId) + } + } + } + + /*for(document <- corpus) { + for(sentence <- document) { + var prevUni:String = null + for(token <- sentence) { + if(!stoplist.contains(token.getForm)) + writeEdge(out, UNI_ + token.getForm, document.getId, 1.0) + if(prevUni != null) + writeEdge(out, BI_ + prevUni + " " + token.getForm, document.getId, 1.0) + prevUni = token.getForm + } + } + }*/ + + out.close + } + + def writeCellCellEdges(graphOutputPath: String) = { + val out = new FileWriter(graphOutputPath, true) + + var lon = 0.0 + while(lon < 360.0 / DPC) { + var lat = 0.0 + while(lat < 180.0 / DPC) { + val curCellNumber = TopoUtil.getCellNumber(lat, lon, DPC) + val leftCellNumber = TopoUtil.getCellNumber(lat, lon - DPC, DPC) + val rightCellNumber = TopoUtil.getCellNumber(lat, lon + DPC, DPC) + val topCellNumber = TopoUtil.getCellNumber(lat + DPC, lon, DPC) + val bottomCellNumber = TopoUtil.getCellNumber(lat - DPC, lon, DPC) + + writeEdge(out, CELL_ + curCellNumber, CELL_ + leftCellNumber, 1.0) + writeEdge(out, CELL_ + curCellNumber, CELL_ + rightCellNumber, 1.0) + if(topCellNumber >= 0) + writeEdge(out, CELL_ + curCellNumber, CELL_ + topCellNumber, 1.0) + if(bottomCellNumber >= 0) + writeEdge(out, CELL_ + curCellNumber, CELL_ + bottomCellNumber, 1.0) + + lat += DPC + } + + lon += DPC + } + + out.close + } + + def writeEdge(out: FileWriter, node1: String, node2: String, weight: Double) = { + out.write(node1 + "\t" + node2 + "\t" + weight + "\n") + } + + def writeEdge(out: FileWriter, e: Tuple3[String, String, Double]) = { + out.write(e._1 + "\t" + e._2 + "\t" + e._3 + "\n") + } +} diff --git a/src/main/scala/opennlp/fieldspring/tr/app/Preprocess.scala b/src/main/scala/opennlp/fieldspring/tr/app/Preprocess.scala new file mode 100644 index 0000000..5918fda --- /dev/null +++ b/src/main/scala/opennlp/fieldspring/tr/app/Preprocess.scala @@ -0,0 +1,49 @@ +/////////////////////////////////////////////////////////////////////////////// +// Copyright (C) 2010 Travis Brown, The University of Texas at Austin +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+/////////////////////////////////////////////////////////////////////////////// +package opennlp.fieldspring.tr.app + +import java.io._ + +import opennlp.fieldspring.tr.topo.gaz._ +import opennlp.fieldspring.tr.text._ +import opennlp.fieldspring.tr.text.io._ +import opennlp.fieldspring.tr.text.prep._ +import opennlp.fieldspring.tr.util.Constants + +object Preprocess extends App { + override def main(args: Array[String]) { + val divider = new OpenNLPSentenceDivider + val tokenizer = new OpenNLPTokenizer + val recognizer = new OpenNLPRecognizer + val gazetteer = new InMemoryGazetteer + + gazetteer.load(new WorldReader(new File( + Constants.getGazetteersDir() + File.separator + "dataen-fixed.txt.gz" + ))) + + val corpus = Corpus.createStreamCorpus + + val in = new BufferedReader(new FileReader(args(0))) + corpus.addSource( + new ToponymAnnotator(new PlainTextSource(in, divider, tokenizer, args(0)), + recognizer, gazetteer + )) + + val writer = new CorpusXMLWriter(corpus) + writer.write(new File(args(1))) + } +} + diff --git a/src/main/scala/opennlp/fieldspring/tr/app/ReprocessTrApp.scala b/src/main/scala/opennlp/fieldspring/tr/app/ReprocessTrApp.scala new file mode 100644 index 0000000..2b0c562 --- /dev/null +++ b/src/main/scala/opennlp/fieldspring/tr/app/ReprocessTrApp.scala @@ -0,0 +1,47 @@ +/////////////////////////////////////////////////////////////////////////////// +// Copyright (C) 2010 Travis Brown, The University of Texas at Austin +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+/////////////////////////////////////////////////////////////////////////////// +package opennlp.fieldspring.tr.app + +import java.io._ + +import opennlp.fieldspring.tr.eval._ +import opennlp.fieldspring.tr.resolver._ +import opennlp.fieldspring.tr.topo.gaz._ +import opennlp.fieldspring.tr.text._ +import opennlp.fieldspring.tr.text.io._ +import opennlp.fieldspring.tr.text.prep._ +import opennlp.fieldspring.tr.util.Constants + +object ReprocessTrApp { + def main(args: Array[String]) { + val tokenizer = new OpenNLPTokenizer + val recognizer = new OpenNLPRecognizer + + val gazetteer = new InMemoryGazetteer + gazetteer.load(new WorldReader(new File( + Constants.getGazetteersDir() + File.separator + "dataen-fixed.txt.gz" + ))) + + val corpus = Corpus.createStreamCorpus + val source = new TrXMLDirSource(new File(args(0)), tokenizer) + val stripped = new ToponymRemover(source) + corpus.addSource(new ToponymAnnotator(stripped, recognizer, gazetteer)) + + val writer = new CorpusXMLWriter(corpus) + writer.write(new File(args(1))) + } +} + diff --git a/src/main/scala/opennlp/fieldspring/tr/app/SplitDevTest.scala b/src/main/scala/opennlp/fieldspring/tr/app/SplitDevTest.scala new file mode 100644 index 0000000..a9a7d98 --- /dev/null +++ b/src/main/scala/opennlp/fieldspring/tr/app/SplitDevTest.scala @@ -0,0 +1,23 @@ +package opennlp.fieldspring.tr.app + +import java.io._ + +object SplitDevTest extends App { + val dir = new File(args(0)) + + val devDir = new File(dir.getCanonicalPath+"dev") + val testDir = new File(dir.getCanonicalPath+"test") + devDir.mkdir + testDir.mkdir + + val files = dir.listFiles + + var i = 1 + for(file <- files) { + if(i % 3 == 0) + file.renameTo(new File(testDir, file.getName)) + else + file.renameTo(new File(devDir, file.getName)) + i += 1 + } +} diff --git a/src/main/scala/opennlp/fieldspring/tr/app/SupervisedTRMaxentModelTrainer.scala b/src/main/scala/opennlp/fieldspring/tr/app/SupervisedTRMaxentModelTrainer.scala new file mode 100644 index 0000000..7284a33 --- /dev/null +++ b/src/main/scala/opennlp/fieldspring/tr/app/SupervisedTRMaxentModelTrainer.scala @@ -0,0 +1,221 @@ +package opennlp.fieldspring.tr.app + +import java.io._ +import java.util.zip._ + +import opennlp.fieldspring.tr.util._ +import opennlp.fieldspring.tr.topo._ +import opennlp.fieldspring.tr.topo.gaz._ +import opennlp.fieldspring.tr.text._ +import opennlp.fieldspring.tr.text.prep._ +import opennlp.fieldspring.tr.text.io._ + +import scala.collection.JavaConversions._ + +import org.apache.commons.compress.compressors.bzip2._ +import org.clapper.argot._ +import ArgotConverters._ + +import opennlp.maxent._ +import opennlp.maxent.io._ +import opennlp.model._ + +object SupervisedTRFeatureExtractor extends App { + val parser = new ArgotParser("fieldspring run opennlp.fieldspring.tr.app.SupervisedTRMaxentModelTrainer", preUsage = Some("Fieldspring")) + + val wikiCorpusInputFile = parser.option[String](List("c", "corpus"), "corpus", "wiki training corpus input file") + val wikiTextInputFile = parser.option[String](List("w", "wiki"), "wiki", "wiki text input file") + val trInputFile = parser.option[String](List("i", "tr-input"), "tr-input", "TR-CoNLL input path") + val gazInputFile = parser.option[String](List("g", "gaz"), "gaz", "serialized gazetteer input file") + val stoplistInputFile = parser.option[String](List("s", "stoplist"), "stoplist", "stopwords input file") + val modelsOutputDir = parser.option[String](List("d", "models-dir"), "models-dir", "models output directory") + //val thresholdParam = 
parser.option[Double](List("t", "threshold"), "threshold", "maximum distance threshold") + + val windowSize = 20 + val dpc = 1.0 + //val threshold = if(thresholdParam.value != None) thresholdParam.value.get else 1.0 + + val distanceTable = new DistanceTable + + try { + parser.parse(args) + } + catch { + case e: ArgotUsageException => println(e.message); sys.exit(0) + } + + println("Reading toponyms from TR-CoNLL at " + trInputFile.value.get + " ...") + val toponyms:Set[String] = CorpusInfo.getCorpusInfo(trInputFile.value.get).map(_._1).toSet + + toponyms.foreach(println) + + println("Reading Wikipedia geotags from " + wikiCorpusInputFile.value.get + "...") + val idsToCoords = new collection.mutable.HashMap[String, Coordinate] + val fis = new FileInputStream(wikiCorpusInputFile.value.get) + //fis.read; fis.read + val cbzis = new BZip2CompressorInputStream(fis) + val in = new BufferedReader(new InputStreamReader(cbzis)) + var curLine = in.readLine + while(curLine != null) { + val tokens = curLine.split("\t") + val coordTokens = tokens(2).split(",") + idsToCoords.put(tokens(0), Coordinate.fromDegrees(coordTokens(0).toDouble, coordTokens(1).toDouble)) + curLine = in.readLine + } + in.close + + println("Reading serialized gazetteer from " + gazInputFile.value.get + " ...") + val gis = new GZIPInputStream(new FileInputStream(gazInputFile.value.get)) + val ois = new ObjectInputStream(gis) + val gnGaz = ois.readObject.asInstanceOf[GeoNamesGazetteer] + gis.close + + println("Reading Wiki text corpus from " + wikiTextInputFile.value.get + " ...") + + val recognizer = new OpenNLPRecognizer + val tokenizer = new OpenNLPTokenizer + + val wikiTextCorpus = Corpus.createStreamCorpus + + wikiTextCorpus.addSource(new ToponymAnnotator(new WikiTextSource(new BufferedReader(new FileReader(wikiTextInputFile.value.get))), recognizer, gnGaz)) + wikiTextCorpus.setFormat(BaseApp.CORPUS_FORMAT.WIKITEXT) + + val stoplist:Set[String] = + if(stoplistInputFile.value != None) { + println("Reading stopwords file from " + stoplistInputFile.value.get + " ...") + scala.io.Source.fromFile(stoplistInputFile.value.get).getLines.toSet + } + else { + println("No stopwords file specified. 
Using an empty stopword list.") + Set() + } + + println("Building training sets for each toponym type...") + + val toponymsToTrainingSets = new collection.mutable.HashMap[String, List[(Array[String], String)]] + for(doc <- wikiTextCorpus) { + if(idsToCoords.containsKey(doc.getId)) { + val docCoord = idsToCoords(doc.getId) + println(doc.getId+" has a geotag: "+docCoord) + val docAsArray = TextUtil.getDocAsArray(doc) + var tokIndex = 0 + for(token <- docAsArray) { + if(token.isToponym && token.asInstanceOf[Toponym].getAmbiguity > 0) { + if(toponyms(token.getForm)) + println(token.getForm+" is a toponym we care about.") + else + println(token.getForm+" is a toponym, but we don't care about it.") + } + if(token.isToponym && token.asInstanceOf[Toponym].getAmbiguity > 0 && toponyms(token.getForm)) { + val toponym = token.asInstanceOf[Toponym] + //val bestCellNum = getBestCellNum(toponym, docCoord, dpc) + val bestCandIndex = getBestCandIndex(toponym, docCoord) + if(bestCandIndex != -1) { + val contextFeatures = TextUtil.getContextFeatures(docAsArray, tokIndex, windowSize, stoplist) + val prevSet = toponymsToTrainingSets.getOrElse(token.getForm, Nil) + print(toponym+": ") + contextFeatures.foreach(f => print(f+",")) + println(bestCandIndex) + + toponymsToTrainingSets.put(token.getForm, (contextFeatures, bestCandIndex.toString) :: prevSet) + } + } + tokIndex += 1 + } + } + else { + println(doc.getId+" does not have a geotag.") + for(sent <- doc) { for(token <- sent) {} } + } + } + + val dir = + if(modelsOutputDir.value.get != None) { + println("Training Maxent models for each toponym type, outputting to directory " + modelsOutputDir.value.get + " ...") + val dirFile:File = new File(modelsOutputDir.value.get) + if(!dirFile.exists) + dirFile.mkdir + if(modelsOutputDir.value.get.endsWith("/")) + modelsOutputDir.value.get + else + modelsOutputDir.value.get+"/" + } + else { + println("Training Maxent models for each toponym type, outputting to current working directory ...") + "" + } + for((toponym, trainingSet) <- toponymsToTrainingSets) { + val outFile = new File(dir + toponym.replaceAll(" ", "_")+".txt") + val out = new BufferedWriter(new FileWriter(outFile)) + for((context, label) <- trainingSet) { + for(feature <- context) out.write(feature+",") + out.write(label+"\n") + } + out.close + } + + println("All done.") + + def getBestCandIndex(toponym:Toponym, docCoord:Coordinate): Int = { + var index = 0 + var minDist = Double.PositiveInfinity + var bestIndex = -1 + val docRegion = new PointRegion(docCoord) + for(loc <- toponym.getCandidates) { + val dist = loc.getRegion.distanceInKm(docRegion) + if(dist < loc.getThreshold && dist < minDist) { + minDist = dist + bestIndex = index + } + index += 1 + } + bestIndex + } + + /*def getBestCellNum(toponym:Toponym, docCoord:Coordinate, dpc:Double): Int = { + for(loc <- toponym.getCandidates) { + if(loc.getRegion.distanceInKm(new PointRegion(docCoord)) < loc.getThreshold) { + return TopoUtil.getCellNumber(loc.getRegion.getCenter, dpc) + } + } + -1 + }*/ + +} + +object SupervisedTRMaxentModelTrainer extends App { + + val iterations = 10 + val cutoff = 2 + + val dir = new File(args(0)) + for(file <- dir.listFiles.filter(_.getName.endsWith(".txt"))) { + try { + val reader = new BufferedReader(new FileReader(file)) + val dataStream = new PlainTextByLineDataStream(reader) + val eventStream = new BasicEventStream(dataStream, ",") + + //GIS.PRINT_MESSAGES = false + val model = GIS.trainModel(eventStream, iterations, cutoff) + val modelWriter = new 
BinaryGISModelWriter(model, new File(file.getAbsolutePath.replaceAll(".txt", ".mxm"))) + modelWriter.persist() + modelWriter.close() + } catch { + case e: Exception => e.printStackTrace + } + } +} + +object MaxentEventStreamFactory { + def apply(iterator:Iterator[(Array[String], String)]): EventStream = { + new BasicEventStream(new DataStream { + def nextToken: AnyRef = { + val next = iterator.next + val featuresAndLabel = (next._1.toList ::: (next._2 :: Nil)).mkString(",") + println(featuresAndLabel) + featuresAndLabel + } + def hasNext: Boolean = iterator.hasNext + }, ",") + } +} diff --git a/src/main/scala/opennlp/fieldspring/tr/app/VisualizeCorpus.scala b/src/main/scala/opennlp/fieldspring/tr/app/VisualizeCorpus.scala new file mode 100644 index 0000000..623dbd9 --- /dev/null +++ b/src/main/scala/opennlp/fieldspring/tr/app/VisualizeCorpus.scala @@ -0,0 +1,327 @@ +package opennlp.fieldspring.tr.app + +import processing.core._ +//import processing.opengl._ +//import codeanticode.glgraphics._ +import de.fhpotsdam.unfolding._ +import de.fhpotsdam.unfolding.geo._ +import de.fhpotsdam.unfolding.events._ +import de.fhpotsdam.unfolding.utils._ +import de.fhpotsdam.unfolding.providers.Microsoft +import controlP5._ + +import opennlp.fieldspring.tr.text._ +import opennlp.fieldspring.tr.text.io._ +import opennlp.fieldspring.tr.text.prep._ +import opennlp.fieldspring.tr.util._ + +import java.io._ +import java.awt.event._ + +import scala.collection.JavaConversions._ + +class VisualizeCorpus extends PApplet { + + class TopoMention(val toponym:Toponym, + val context:String, + val docid:String) + + var mapDetail:de.fhpotsdam.unfolding.Map = null + var topoTextArea:Textarea = null + var checkbox:CheckBox = null + + var allButton:Button = null + var noneButton:Button = null + + val INIT_WIDTH = if(VisualizeCorpus.inputWidth > 0) VisualizeCorpus.inputWidth else 1280 + val INIT_HEIGHT = if(VisualizeCorpus.inputHeight > 0) VisualizeCorpus.inputHeight else 720 + + val RADIUS = 10 + + val BORDER_WIDTH = 10 + val TEXTAREA_WIDTH = 185 + val BUTTONPANEL_WIDTH = 35 + val CHECKBOX_WIDTH = 185 - BUTTONPANEL_WIDTH + var textareaHeight = INIT_HEIGHT - BORDER_WIDTH*2 + var mapWidth = INIT_WIDTH - TEXTAREA_WIDTH - CHECKBOX_WIDTH - BUTTONPANEL_WIDTH - BORDER_WIDTH*4 + var mapHeight = INIT_HEIGHT - BORDER_WIDTH*2 + var checkboxX = BORDER_WIDTH + BUTTONPANEL_WIDTH + var checkboxHeight = textareaHeight + var mapX = checkboxX + CHECKBOX_WIDTH + BORDER_WIDTH + var textareaX = mapWidth + CHECKBOX_WIDTH + BUTTONPANEL_WIDTH + BORDER_WIDTH*3 + + val SLIDER_WIDTH = 5 + + var BUTTON_WIDTH = 27 + var BUTTON_HEIGHT = 15 + val NONE_BUTTON_Y = BORDER_WIDTH + BUTTON_HEIGHT + 10 + + val CONTEXT_SIZE = 20 + + var cp5:ControlP5 = null + + val coordsMap = new scala.collection.mutable.HashMap[(Float, Float), List[TopoMention]] + var docList:Array[String] = null + var shownDocs:Set[String] = null + + var oldWidth = INIT_WIDTH + var oldHeight = INIT_HEIGHT + + var checkboxTotalHeight = 0 + + override def setup { + size(INIT_WIDTH, INIT_HEIGHT/*, GLConstants.GLGRAPHICS*/) + //frame.setResizable(true) + frame.setTitle("Corpus Visualizer") + //textMode(PConstants.SHAPE) + + cp5 = new ControlP5(this) + + mapDetail = new de.fhpotsdam.unfolding.Map(this, "detail", mapX, BORDER_WIDTH, mapWidth, mapHeight/*, true, false, new Microsoft.AerialProvider*/) + mapDetail.setZoomRange(2, 10) + //mapDetail.zoomToLevel(4) + mapDetail.zoomAndPanTo(new Location(25.0f, 12.0f), 2) // map center + //mapDetail.zoomAndPanTo(new Location(38.5f, -98.0f), 2) // USA 
center + val eventDispatcher = MapUtils.createDefaultEventDispatcher(this, mapDetail) + + topoTextArea = cp5.addTextarea("") + .setPosition(textareaX, BORDER_WIDTH) + .setSize(TEXTAREA_WIDTH, textareaHeight) + .setFont(createFont("arial",12)) + .setLineHeight(14) + .setColor(color(0)) + + allButton = cp5.addButton("all") + .setPosition(BORDER_WIDTH, BORDER_WIDTH) + .setSize(BUTTON_WIDTH, BUTTON_HEIGHT) + + noneButton = cp5.addButton("none") + .setPosition(BORDER_WIDTH, NONE_BUTTON_Y) + .setSize(BUTTON_WIDTH, BUTTON_HEIGHT) + + checkbox = cp5.addCheckBox("checkBox") + .setPosition(checkboxX, BORDER_WIDTH) + .setColorForeground(color(120)) + .setColorActive(color(0, 200, 0)) + .setColorLabel(color(0)) + .setSize(10, 10) + .setItemsPerRow(1) + .setSpacingColumn(30) + .setSpacingRow(10) + + cp5.addSlider("slider") + .setPosition(checkboxX + CHECKBOX_WIDTH - SLIDER_WIDTH, BORDER_WIDTH) + .setSize(SLIDER_WIDTH, checkboxHeight) + .setRange(0, 1) + .setValue(1) + .setLabelVisible(false) + .setSliderMode(Slider.FLEXIBLE) + .setHandleSize(40) + .setColorBackground(color(255)) + + class myMWListener(vc:VisualizeCorpus) extends MouseWheelListener { + def mouseWheelMoved(mwe:MouseWheelEvent) { + vc.mouseWheel(mwe.getWheelRotation) + } + } + + addMouseWheelListener(new myMWListener(this)) + + val tokenizer = new OpenNLPTokenizer + val corpus = TopoUtil.readStoredCorpusFromSerialized(VisualizeCorpus.inputFile) //Stored + + docList = + (for(doc <- corpus) yield { + val docArray = TextUtil.getDocAsArray(doc) + var tokIndex = 0 + for(token <- docArray) { + if(token.isToponym) { + val toponym = token.asInstanceOf[Toponym] + if(toponym.getAmbiguity > 0 && toponym.hasSelected) { + val coord = toponym.getSelected.getRegion.getCenter + val pair = (coord.getLatDegrees.toFloat, coord.getLngDegrees.toFloat) + val prevList = coordsMap.getOrElse(pair, Nil) + val context = TextUtil.getContext(docArray, tokIndex, CONTEXT_SIZE) + coordsMap.put(pair, (prevList ::: (new TopoMention(toponym, context, doc.getId) :: Nil))) + } + } + tokIndex += 1 + } + //checkbox.addItem(doc.getId, 0) + doc.getId + }).toArray + + docList.toList.sortBy(x => x).foreach(x => checkbox.addItem(x, 0)) + shownDocs = docList.toSet + + checkboxTotalHeight = docList.size * 20 + if(checkboxTotalHeight <= height) + cp5.getController("slider").hide + checkbox.activateAll + } + + var selectedCirc:(Float, Float, List[TopoMention]) = null + var oldSelectedCirc = selectedCirc + var onScreen:List[((Float, Float), List[TopoMention])] = Nil + val sb = new StringBuffer + + override def draw { + background(255) + + mapDetail.draw + + checkbox.setPosition(checkboxX, BORDER_WIDTH - + ((1.0 - cp5.getController("slider").getValue) * (checkboxTotalHeight - checkboxHeight)).toInt) + + onScreen = + (for(((lat,lng),rawTopolist) <- coordsMap) yield { + val topolist = rawTopolist.filter(tm => shownDocs(tm.docid)) + if(topolist.size > 0) { + val ufLoc:de.fhpotsdam.unfolding.geo.Location = new de.fhpotsdam.unfolding.geo.Location(lat, lng) + val xy:Array[Float] = mapDetail.getScreenPositionFromLocation(ufLoc) + if(xy(0) >= mapX + RADIUS && xy(0) <= mapX + mapWidth - RADIUS + && xy(1) >= BORDER_WIDTH + RADIUS && xy(1) <= mapHeight + BORDER_WIDTH - RADIUS) { + if(selectedCirc != null && lat == selectedCirc._1 && lng == selectedCirc._2) + fill(200, 0, 0, 100) + else + fill(0, 200, 0, 100) + ellipse(xy(0), xy(1), RADIUS*2, RADIUS*2) + fill(1) + val num = topolist.size + if(num < 10) + text(num, xy(0)-(RADIUS.toFloat/3.2).toInt, xy(1)+(RADIUS.toFloat/2.1).toInt) + else + text(num, 
xy(0)-(RADIUS.toFloat/1.3).toInt, xy(1)+(RADIUS.toFloat/2.1).toInt) + + Some(((lat,lng),topolist)) + } + else + None + } + else + None + }).flatten.toList + + if(selectedCirc != oldSelectedCirc) { + if(selectedCirc != null) { + val topoMentions = selectedCirc._3.filter(tm => shownDocs(tm.docid)) + sb.setLength(0) + sb.append(selectedCirc._3(0).toponym.getOrigForm) + sb.append(" (") + sb.append(topoMentions.size) + sb.append(")") + var i = 1 + for(topoMention <- topoMentions) { + sb.append("\n\n") + sb.append(i) + sb.append(". ") + if(!topoMention.context.startsWith("[[")) sb.append("...") + sb.append(topoMention.context) + if(!topoMention.context.endsWith("]]")) sb.append("...") + sb.append(" (") + sb.append(topoMention.docid) + sb.append(")") + i += 1 + } + topoTextArea.setText(sb.toString) + } + else + topoTextArea.setText("") + } + + oldSelectedCirc = selectedCirc + } + + var mousePressedX = -1.0 + var mousePressedY = -1.0 + override def mousePressed { + mousePressedX = mouseX + mousePressedY = mouseY + } + + override def mouseReleased { + + if(mouseX >= mapX && mouseX <= mapX + mapWidth + && mouseY >= BORDER_WIDTH && mouseY <= BORDER_WIDTH + mapHeight) { // clicked in map + var clickedCirc = false + for(((lat, lng), topolist) <- onScreen) { + val xy:Array[Float] = mapDetail.getScreenPositionFromLocation(new de.fhpotsdam.unfolding.geo.Location(lat, lng)) + if(PApplet.dist(mouseX, mouseY, xy(0), xy(1)) <= RADIUS) { + oldSelectedCirc = selectedCirc + selectedCirc = (lat, lng, topolist) + clickedCirc = true + } + } + if(mouseX == mousePressedX && mouseY == mousePressedY) { // didn't drag + if(selectedCirc != null/* && PApplet.dist(mouseX, mouseY, selectedCirc._1, selectedCirc._2) > RADIUS*/) + topoTextArea.scroll(0) + if(!clickedCirc) { + oldSelectedCirc = selectedCirc + selectedCirc = null + } + } + } + + if(mouseX >= BORDER_WIDTH && mouseX <= checkboxX + CHECKBOX_WIDTH + && mouseY >= 0 && mouseY <= BORDER_WIDTH + checkboxHeight) { + shownDocs = checkbox.getItems.filter(_.getState == true).map(_.getLabel).toSet + + if(selectedCirc != null) { + val topolist = coordsMap((selectedCirc._1, selectedCirc._2)).filter(tm => shownDocs(tm.docid)) + if(topolist.length > 0) { + selectedCirc = (selectedCirc._1, selectedCirc._2, topolist) + } + else { + oldSelectedCirc = selectedCirc + selectedCirc = null + } + } + } + } + + def controlEvent(e:ControlEvent) { + //try { + if(mouseX >= BORDER_WIDTH && mouseX <= BORDER_WIDTH + BUTTON_WIDTH) { + if(mouseY >= BORDER_WIDTH && mouseY <= BORDER_WIDTH + BUTTON_HEIGHT) { + checkbox.activateAll + shownDocs = docList.toSet + } + else if(mouseY >= NONE_BUTTON_Y && mouseY <= NONE_BUTTON_Y + BUTTON_HEIGHT) { + checkbox.deactivateAll + shownDocs.clear + } + } + /*} catch { + case e: Exception => //e.printStackTrace + }*/ + } + + def mouseWheel(delta:Int) { + if(mouseX >= BORDER_WIDTH && mouseX <= checkboxX + CHECKBOX_WIDTH - SLIDER_WIDTH + && mouseY >= 0 && mouseY <= BORDER_WIDTH + checkboxHeight) { + val slider = cp5.getController("slider") + slider.setValue(slider.getValue - delta.toFloat / 100) + } + else if(mouseX >= textareaX && mouseX <= textareaX + TEXTAREA_WIDTH + && mouseY >= BORDER_WIDTH && mouseY <= BORDER_WIDTH + textareaHeight) { + topoTextArea.scrolled(delta * 2) + } + } + +} + +object VisualizeCorpus extends PApplet { + + var inputFile:String = null + var inputWidth:Int = -1 + var inputHeight:Int = -1 + + def main(args:Array[String]) { + inputFile = args(0) + if(args.size >= 3) { + inputWidth = args(1).toInt + inputHeight = args(2).toInt + } + 
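+    // Hand off to the Processing runtime; VisualizeCorpus.setup reads inputFile, inputWidth, and inputHeight
+    // from these companion-object fields (falling back to 1280x720 when no size is given on the command line).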
PApplet.main(Array(/*"--present", */"opennlp.fieldspring.tr.app.VisualizeCorpus")) + } +} diff --git a/src/main/scala/opennlp/fieldspring/tr/model/AltBasicMinDistModel.scala b/src/main/scala/opennlp/fieldspring/tr/model/AltBasicMinDistModel.scala new file mode 100644 index 0000000..73365c3 --- /dev/null +++ b/src/main/scala/opennlp/fieldspring/tr/model/AltBasicMinDistModel.scala @@ -0,0 +1,50 @@ +/////////////////////////////////////////////////////////////////////////////// +// Copyright (C) 2010 Travis Brown, The University of Texas at Austin +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +/////////////////////////////////////////////////////////////////////////////// +package opennlp.fieldspring.tr.resolver + +import scala.collection.JavaConversions._ +import opennlp.fieldspring.tr.text._ + + +class AltBasicMinDistResolver extends Resolver { + def disambiguate(corpus: StoredCorpus): StoredCorpus = { + + /* Iterate over documents. */ + corpus.foreach { document => + + /* Collect a list of toponyms with candidates for each document. */ + val toponyms = document.flatMap(_.getToponyms).filter(_.getAmbiguity > 0).toList + + /* For each toponym, pick the best candidate. */ + toponyms.foreach { toponym => + + /* Compute all valid totals with indices. 
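+           A candidate's total is the sum, over every other toponym in the document, of the minimum
+           distance from that toponym's candidates to this candidate; the candidate with the smallest total wins.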
*/ + toponym.zipWithIndex.flatMap { case (candidate, idx) => + toponyms.filterNot(_ == toponym) match { + case Nil => None + case ts => Some(ts.map(_.map(_.distance(candidate)).min).sum, idx) + } + } match { + case Nil => () + case xs => toponym.setSelectedIdx(xs.min._2) + } + } + } + + return corpus + } +} + diff --git a/src/main/scala/opennlp/fieldspring/tr/resolver/BayesRuleResolver.scala b/src/main/scala/opennlp/fieldspring/tr/resolver/BayesRuleResolver.scala new file mode 100644 index 0000000..95e7624 --- /dev/null +++ b/src/main/scala/opennlp/fieldspring/tr/resolver/BayesRuleResolver.scala @@ -0,0 +1,113 @@ +package opennlp.fieldspring.tr.resolver + +import java.io._ + +import opennlp.fieldspring.tr.text._ +import opennlp.fieldspring.tr.topo._ +import opennlp.fieldspring.tr.util._ + +import opennlp.maxent._ +import opennlp.maxent.io._ +import opennlp.model._ + +import scala.collection.JavaConversions._ + +class BayesRuleResolver(val logFilePath:String, + val modelDirPath:String) extends Resolver { + + val DPC = 1.0 + val WINDOW_SIZE = 20 + + def disambiguate(corpus:StoredCorpus): StoredCorpus = { +/* + val modelDir = new File(modelDirPath) + + val toponymsToModels:Map[String, AbstractModel] = + (for(file <- modelDir.listFiles.filter(_.getName.endsWith(".mxm"))) yield { + val dataInputStream = new DataInputStream(new FileInputStream(file)); + val reader = new BinaryGISModelReader(dataInputStream) + val model = reader.getModel + + //println(file.getName.dropRight(4).replaceAll("_", " ")) + (file.getName.dropRight(4).replaceAll("_", " "), model) + }).toMap + + val ngramDists = LogUtil.getNgramDists(logFilePath) + //println(ngramDists.size) + + for(doc <- corpus) { + val docAsArray = TextUtil.getDocAsArray(doc) + var tokIndex = 0 + for(token <- docAsArray) { + if(token.isToponym && token.asInstanceOf[Toponym].getAmbiguity > 0) { + val toponym = token.asInstanceOf[Toponym] + + // P(l|d_c(t)) + val cellDistGivenLocalContext = + if(toponymsToModels.containsKey(toponym.getForm)) { + val contextFeatures = TextUtil.getContextFeatures(docAsArray, tokIndex, WINDOW_SIZE, Set[String]()) + + /*val d = */MaxentResolver.getCellDist(toponymsToModels(toponym.getForm), contextFeatures, + toponym.getCandidates.toList, DPC) + } + else + null + + var indexToSelect = -1 + var maxLogProb = Double.MinValue + var candIndex = 0 + for(cand <- toponym.getCandidates) { + val curCellNum = TopoUtil.getCellNumber(cand.getRegion.getCenter, DPC) + + val localContextComponent = + if(cellDistGivenLocalContext != null) + cellDistGivenLocalContext.getOrElse(curCellNum, 0.0) + else + 0.0 + + val DISCOUNT_FACTOR = 1.0E-300 + + val dist = ngramDists.getOrElse(curCellNum, null) + val logProbOfDocGivenLocation = + if(dist != null) { + val denom = dist.map(_._2).sum + val unkMass = DISCOUNT_FACTOR * (dist.size+1) + (for(word <- docAsArray.map(_.getForm.split(" ")).flatten) yield { + math.log((dist.getOrElse(word, unkMass) - DISCOUNT_FACTOR) / denom) + }).sum + } + else + 0.0 + //println(" = " + probOfDocGivenLocation) + + /*print(cand.getName + " in cell " + curCellNum + ": ") + if(dist == null) + println("NULL") + else + println(probOfDocGivenLocation)*/ + + val logProbOfLocation = math.log(localContextComponent) + logProbOfDocGivenLocation + + if(logProbOfLocation > maxLogProb) { + indexToSelect = candIndex + maxLogProb = logProbOfLocation + } + + candIndex += 1 + } + + if(indexToSelect >= 0) + toponym.setSelectedIdx(indexToSelect) + } + } + tokIndex += 1 + } + + // Backoff to DocDist: + val docDistResolver = new 
DocDistResolver(logFilePath) + docDistResolver.overwriteSelecteds = false + docDistResolver.disambiguate(corpus) +*/ + corpus + } +} diff --git a/src/main/scala/opennlp/fieldspring/tr/resolver/ConstructionTPPResolver.scala b/src/main/scala/opennlp/fieldspring/tr/resolver/ConstructionTPPResolver.scala new file mode 100644 index 0000000..d77bd97 --- /dev/null +++ b/src/main/scala/opennlp/fieldspring/tr/resolver/ConstructionTPPResolver.scala @@ -0,0 +1,68 @@ +package opennlp.fieldspring.tr.resolver + +import opennlp.fieldspring.tr.tpp._ +import opennlp.fieldspring.tr.text._ +import opennlp.fieldspring.tr.topo._ +import opennlp.fieldspring.tr.util._ + +import java.util.ArrayList + +import scala.collection.JavaConversions._ + +class ConstructionTPPResolver(val dpc:Double, + val threshold:Double, + val corpus:StoredCorpus, + val modelDirPath:String) + extends TPPResolver(new TPPInstance( + //new MaxentPurchaseCoster(corpus, modelDirPath), + new MultiPurchaseCoster(List(new GaussianPurchaseCoster,//new SimpleContainmentPurchaseCoster, + new MaxentPurchaseCoster(corpus, modelDirPath))), + new GaussianTravelCoster)) { + //new SimpleDistanceTravelCoster)) { + + def disambiguate(corpus:StoredCorpus): StoredCorpus = { + + for(doc <- corpus) { + + if(threshold < 0) + tppInstance.markets = (new GridMarketCreator(doc, dpc)).apply + else + tppInstance.markets = (new ClusterMarketCreator(doc, threshold)).apply + + // Apply a TPPSolver + val solver = new ConstructionTPPSolver + val tour = solver(tppInstance) + //println(doc.getId+" had a tour of length "+tour.size) + if(doc.getId.equals("d94")) { + solver.writeKML(tour, "d94-tour.kml") + } + + // Decode the tour into the corpus + val solutionMap = solver.getSolutionMap(tour) + + val docAsArray = TextUtil.getDocAsArrayNoFilter(doc) + + var tokIndex = 0 + for(token <- docAsArray) { + if(token.isToponym && token.asInstanceOf[Toponym].getAmbiguity > 0) { + val toponym = token.asInstanceOf[Toponym] + if(solutionMap.contains((doc.getId, tokIndex))) { + toponym.setSelectedIdx(solutionMap((doc.getId, tokIndex))) + //if(toponym.getSelectedIdx >= toponym.getAmbiguity) { + // println(tokIndex) + // println(toponym.getForm+": "+toponym.getSelectedIdx+" >= "+toponym.getAmbiguity) + //} + } + //else { + // println(doc.getId+": "+toponym.getForm) + //} + } + + tokIndex += 1 + } + + } + + corpus + } +} diff --git a/src/main/scala/opennlp/fieldspring/tr/resolver/DocDistResolver.scala b/src/main/scala/opennlp/fieldspring/tr/resolver/DocDistResolver.scala new file mode 100644 index 0000000..d00e73a --- /dev/null +++ b/src/main/scala/opennlp/fieldspring/tr/resolver/DocDistResolver.scala @@ -0,0 +1,36 @@ +package opennlp.fieldspring.tr.resolver + +import opennlp.fieldspring.tr.text._ +import opennlp.fieldspring.tr.topo._ +import opennlp.fieldspring.tr.util._ + +import scala.collection.JavaConversions._ + +class DocDistResolver(val logFilePath:String) extends Resolver { + + def disambiguate(corpus:StoredCorpus): StoredCorpus = { + + val predDocLocations = (for(pe <- LogUtil.parseLogFile(logFilePath)) yield { + (pe.docName, pe.predCoord) + }).toMap + + for(doc <- corpus) { + for(sent <- doc) { + for(toponym <- sent.getToponyms.filter(_.getAmbiguity > 0)) { + if(overwriteSelecteds || !toponym.hasSelected) { + val predDocLocation = predDocLocations.getOrElse(doc.getId, null) + if(predDocLocation != null) { + val indexToSelect = toponym.getCandidates.zipWithIndex.minBy( + p => p._1.getRegion.distance(predDocLocation))._2 + if(indexToSelect != -1) { + 
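+                  // Select the candidate whose region is closest to the document-level location predicted in the log file.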
toponym.setSelectedIdx(indexToSelect) + } + } + } + } + } + } + + corpus + } +} diff --git a/src/main/scala/opennlp/fieldspring/tr/resolver/HeuristicTPPResolver.scala b/src/main/scala/opennlp/fieldspring/tr/resolver/HeuristicTPPResolver.scala new file mode 100644 index 0000000..4f948fc --- /dev/null +++ b/src/main/scala/opennlp/fieldspring/tr/resolver/HeuristicTPPResolver.scala @@ -0,0 +1,186 @@ +package opennlp.fieldspring.tr.resolver + +import opennlp.fieldspring.tr.text._ +import opennlp.fieldspring.tr.topo._ +import opennlp.fieldspring.tr.util._ + +import java.util._ + +import scala.collection.JavaConversions._ + +class HeuristicTPPResolver extends Resolver { + + val DPC = 1.0 + + def disambiguate(corpus:StoredCorpus): StoredCorpus = { + + for(doc <- corpus) { + + println("\n---\n") + + val cellsToLocs = new HashMap[Int, HashSet[Location]] + val candsToNames = new HashMap[Location, HashSet[String]] + val unresolvedToponyms = new HashSet[String] + val localGaz = new HashMap[String, HashSet[Location]] + + for(sent <- doc) { + for(toponym <- sent.getToponyms.filter(_.getAmbiguity > 0)) { + unresolvedToponyms.add(toponym.getForm) + if(!localGaz.contains(toponym.getForm)) { + localGaz.put(toponym.getForm, new HashSet[Location]) + for(cand <- toponym) localGaz.get(toponym.getForm).add(cand) + } + for(cand <- toponym) { + if(!candsToNames.containsKey(cand)) + candsToNames.put(cand, new HashSet[String]) + candsToNames.get(cand).add(toponym.getForm) + + val cellNum = TopoUtil.getCellNumber(cand.getRegion.getCenter, DPC) + if(!cellsToLocs.containsKey(cellNum)) + cellsToLocs.put(cellNum, new HashSet[Location]) + cellsToLocs.get(cellNum).add(cand) + } + } + } + + val resolvedToponyms = new HashMap[String, Int] + + // Initialize tour to be empty + val tour = new ArrayList[Int] + + // While not all toponyms have been resolved, + while(!unresolvedToponyms.isEmpty) { + println(unresolvedToponyms.size + " unresolved toponyms. (First is " + unresolvedToponyms.toList(0) + ")") + // Iterate over ALL cellNums, choosing one to add to the tour + val bestCell = chooseBestCell(cellsToLocs, tour) + val locs = cellsToLocs.get(bestCell) + cellsToLocs.remove(bestCell) + + addCellToTour(tour, bestCell) + //println("Added " + bestCell + " (" + locs.size + " locations) to the tour, making it length " + tour.length) + + for(loc <- locs) { + println(candsToNames.get(loc).size) + for(name <- candsToNames.get(loc)) { + if(unresolvedToponyms.contains(name)) { + resolvedToponyms.put(name, bestCell) + println(" Resolved " + name + " to " + bestCell + ". " + cellsToLocs.size + " left.") + unresolvedToponyms.remove(name) + + for(otherLoc <- localGaz.get(name)) { + val otherCellNum = TopoUtil.getCellNumber(otherLoc.getRegion.getCenter, DPC) + if(cellsToLocs.contains(otherCellNum) && noneCanResolveTo(unresolvedToponyms, otherLoc, localGaz)) { + cellsToLocs.get(otherCellNum).remove(otherLoc) + println(" Removing " + name + " from " + otherCellNum) + if(cellsToLocs.get(otherCellNum).isEmpty) { + cellsToLocs.remove(otherCellNum) + println(" " + otherCellNum + " is now empty; removing. 
" + cellsToLocs.size + " left.") + } + } + } + + } + } + } + } + + // Iterate over document, setting each toponym token according to resolvedToponyms + for(sent <- doc) { + for(toponym <- sent.getToponyms.filter(_.getAmbiguity > 0)) { + if(resolvedToponyms.contains(toponym.getForm)) { + val cellToChoose = resolvedToponyms.get(toponym.getForm) + var index = 0 + for(cand <- toponym) { + val cellNum = TopoUtil.getCellNumber(cand.getRegion.getCenter, DPC) + if(cellNum == cellToChoose) { + toponym.setSelectedIdx(index) + } + index += 1 + } + } + } + } + } + + corpus + } + + def noneCanResolveTo(unresolvedToponyms:HashSet[String], otherLoc:Location, + localGaz:HashMap[String, HashSet[Location]]): Boolean = { + for(top <- unresolvedToponyms) { + for(cand <- localGaz.get(top)) { + if(cand == otherLoc) + return false + } + } + + true + } + + def chooseBestCell(cellsToLocs:HashMap[Int, HashSet[Location]], tour:ArrayList[Int]): Int = { + mostLocsRanker(cellsToLocs) + //leastDistAddedRanker(cellsToLocs, tour) + //comboRanker(cellsToLocs, tour) + } + + def mostLocsRanker(cellsToLocs:HashMap[Int, HashSet[Location]]): Int = { + val sortedCellsToLocs:List[(Int, HashSet[Location])] = cellsToLocs.toList.sortWith( + (x, y) => x._2.size > y._2.size) + + sortedCellsToLocs(0)._1 + } + + def leastDistAddedRanker(cellsToLocs:HashMap[Int, HashSet[Location]], tour:ArrayList[Int]): Int = { + val sortedCellsToLocs:List[(Int, HashSet[Location])] = cellsToLocs.toList.sortWith( + (x, y) => computeBestIndexAndDist(tour, x._1)._2 < computeBestIndexAndDist(tour, y._1)._2) + + sortedCellsToLocs(0)._1 + } + + def comboRanker(cellsToLocs:HashMap[Int, HashSet[Location]], tour:ArrayList[Int]): Int = { + val sortedCellsToLocs:List[(Int, HashSet[Location])] = cellsToLocs.toList.sortWith( + (x, y) => if(x._2.size != y._2.size) x._2.size > y._2.size + else computeBestIndexAndDist(tour, x._1)._2 < computeBestIndexAndDist(tour, y._1)._2) + + sortedCellsToLocs(0)._1 + } + + def addCellToTour(tour:ArrayList[Int], cellNum:Int) { + tour.add(computeBestIndexAndDist(tour, cellNum)._1, cellNum) + } + + def computeBestIndexAndDist(tour:ArrayList[Int], cellNum:Int): (Int, Double) = { + val newCellCenter = TopoUtil.getCellCenter(cellNum, DPC) + + if(tour.length == 0) + return (0, 0.0) + + if(tour.length == 1) + return (0, newCellCenter.distance(TopoUtil.getCellCenter(tour.get(0), DPC))) + + var optimalIndex = -1 + var minDistChange = Double.PositiveInfinity + for(index <- 0 to tour.length) { + var distChange = 0.0 + if(index == 0) + distChange = newCellCenter.distance(TopoUtil.getCellCenter(tour.get(0), DPC)) + else if(index == tour.length) + distChange = newCellCenter.distance(TopoUtil.getCellCenter(tour.get(tour.length-1), DPC)) + else { + val prevCellCenter = TopoUtil.getCellCenter(tour.get(index-1), DPC) + val nextCellCenter = TopoUtil.getCellCenter(tour.get(index), DPC) + distChange = prevCellCenter.distance(newCellCenter) + + newCellCenter.distance(nextCellCenter) + - prevCellCenter.distance(nextCellCenter) + } + + if(distChange < minDistChange) { + minDistChange = distChange + optimalIndex = index + } + } + + (optimalIndex, minDistChange) + } + +} diff --git a/src/main/scala/opennlp/fieldspring/tr/resolver/LabelPropResolver.scala b/src/main/scala/opennlp/fieldspring/tr/resolver/LabelPropResolver.scala new file mode 100644 index 0000000..64e6a14 --- /dev/null +++ b/src/main/scala/opennlp/fieldspring/tr/resolver/LabelPropResolver.scala @@ -0,0 +1,276 @@ +package opennlp.fieldspring.tr.resolver + +import opennlp.fieldspring.tr.text._ 
+import opennlp.fieldspring.tr.topo._ +import opennlp.fieldspring.tr.util._ + +import upenn.junto.app._ +import upenn.junto.config._ + +import gnu.trove._ + +import scala.collection.JavaConversions._ + +class LabelPropResolver( + val logFilePath:String, + val knn:Int) extends Resolver { + + val DPC = 1.0 + + val DOC = "doc_" + val TYPE = "type_" + val TOK = "tok_" + val LOC = "loc_" + val TPNM_TYPE = "tpnm_type_" + val CELL = "cell_" + val CELL_LABEL = "cell_label_" + + val nonemptyCellNums = new scala.collection.mutable.HashSet[Int]() + + val docTokNodeRE = """^doc_(.+)_tok_(.+)$""".r + + def disambiguate(corpus:StoredCorpus): StoredCorpus = { + + val graph = createGraph(corpus) + + JuntoRunner(graph, 1.0, .01, .01, 10, false) + + // Interpret output graph and setSelectedIdx of toponyms accordingly: + + val tokensToCells = + (for ((id, vertex) <- graph._vertices) yield { + if(docTokNodeRE.findFirstIn(id) != None) { + val docTokNodeRE(docid, tokidx) = id + //println(DOC+docid+"_"+TOK+tokidx+" "+getGreatestCell(vertex.GetEstimatedLabelScores)) + //println(id) + /*val scores = vertex.GetEstimatedLabelScores + val cellScorePairs = (for(key <- scores.keys) yield { + (key,scores.get(key.toString)) + }) + cellScorePairs.sortWith((x, y) => x._2 > y._2).foreach(x => print(x._1+":"+x._2+" ")) + println + println(getGreatestCell(vertex.GetEstimatedLabelScores))*/ + Some(((docid, tokidx.toInt), getGreatestCell(vertex.GetEstimatedLabelScores, nonemptyCellNums.toSet))) + } + else { + //println(id) + None + } + }).flatten.toMap + + for(doc <- corpus) { + var tokenIndex = -1 + for(sent <- doc) { + for(toponym <- sent.getToponyms.filter(_.getAmbiguity > 0)) { + tokenIndex += 1 + val predCell = tokensToCells.getOrElse((doc.getId, tokenIndex), -1) + if(predCell != -1) { + val indexToSelect = TopoUtil.getCorrectCandidateIndex(toponym, predCell, DPC) + if(indexToSelect != -1) { + toponym.setSelectedIdx(indexToSelect) + //println(toponym.getSelected) + } + } + } + } + } + + corpus + } + + def getGreatestCell(estimatedLabelScores: TObjectDoubleHashMap[String], nonemptyCellNums: Set[Int]): Int = { + estimatedLabelScores.keys(Array[String]()).filter(_.startsWith(CELL_LABEL)) + .map(_.drop(CELL_LABEL.length).toInt).filter(nonemptyCellNums(_)).maxBy(x => estimatedLabelScores.get(CELL_LABEL+x)) + } + + def createGraph(corpus:StoredCorpus) = { + val edges:List[Edge] = getEdges(corpus) + val seeds:List[Label] = getSeeds(corpus) + + GraphBuilder(edges, seeds) + } + + def getEdges(corpus:StoredCorpus): List[Edge] = { + getTokDocEdges(corpus) ::: + //getTokTokEdges(corpus) ::: + //getTokDocTypeEdges(corpus) ::: + //getDocTypeGlobalTypeEdges(corpus) ::: + getTokGlobalTypeEdges(corpus) ::: // alternative to the above two edge sets + getGlobalTypeLocationEdges(corpus) ::: + getLocationCellEdges(corpus) ::: + getCellCellEdges + } + + def getTokDocEdges(corpus:StoredCorpus): List[Edge] = { + var result = + (for(doc <- corpus) yield { + var tokenIndex = -1 + (for(sent <- doc) yield { + (for(toponym <- sent.getToponyms.filter(_.getAmbiguity > 0)) yield { + tokenIndex += 1 + new Edge(DOC+doc.getId, DOC+doc.getId+"_"+TOK+tokenIndex, 1.0) + }) + }).flatten + }).flatten.toList + //result.foreach(println) + result + } + + def getTokTokEdges(corpus:StoredCorpus): List[Edge] = { + var result = + (for(doc <- corpus) yield { + var prevTok:String = null + var tokenIndex = -1 + (for(sent <- doc) yield { + (for(toponym <- sent.getToponyms.filter(_.getAmbiguity > 0)) yield { + tokenIndex += 1 + val curTok = DOC+doc.getId+"_"+TOK+tokenIndex 
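+          // Chain this toponym token to the previous toponym token in the same document (sequential token-token edges).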
+ val edge = + if(prevTok != null) + Some(new Edge(prevTok, curTok, 1.0)) + else + None + prevTok = curTok + edge + }).flatten + }).flatten + }).flatten.toList + //result.foreach(println) + result + } + + def getTokDocTypeEdges(corpus:StoredCorpus): List[Edge] = { + val result = + (for(doc <- corpus) yield { + var tokenIndex = -1 + (for(sent <- doc) yield { + (for(toponym <- sent.getToponyms.filter(_.getAmbiguity > 0)) yield { + tokenIndex += 1 + new Edge(DOC+doc.getId+"_"+TOK+tokenIndex, DOC+doc.getId+"_"+TYPE+toponym.getForm, 1.0) + }) + }).flatten + }).flatten.toList + //result.foreach(println) + result + } + + def getDocTypeGlobalTypeEdges(corpus:StoredCorpus): List[Edge] = { + val result = (for(doc <- corpus) yield { + (for(sent <- doc) yield { + (for(toponym <- sent.getToponyms.filter(_.getAmbiguity > 0)) yield { + new Edge(DOC+doc.getId+"_"+TYPE+toponym.getForm, TPNM_TYPE+toponym.getForm, 1.0) + }) + }).flatten + }).flatten.toList + //result.foreach(println) + result + } + + // This is an alternative to using BOTH of the above two edge sets: + // Don't instantiate DocType edges and go straight from tokens to global types + def getTokGlobalTypeEdges(corpus:StoredCorpus): List[Edge] = { + val result = + (for(doc <- corpus) yield { + var tokenIndex = -1 + (for(sent <- doc) yield { + (for(toponym <- sent.getToponyms.filter(_.getAmbiguity > 0)) yield { + tokenIndex += 1 + new Edge(DOC+doc.getId+"_"+TOK+tokenIndex, TPNM_TYPE+toponym.getForm, 1.0) + }) + }).flatten + }).flatten.toList + //result.foreach(println) + result + } + + def getGlobalTypeLocationEdges(corpus:StoredCorpus): List[Edge] = { + val result = (for(doc <- corpus) yield { + (for(sent <- doc) yield { + (for(toponym <- sent.getToponyms.filter(_.getAmbiguity > 0)) yield { + (for(loc <- toponym.getCandidates) yield { + new Edge(TPNM_TYPE+toponym.getForm, LOC+loc.getId, 1.0) + }) + }).flatten + }).flatten + }).flatten.toList + //result.foreach(println) + result + } + + def getLocationCellEdges(corpus:StoredCorpus): List[Edge] = { + val result = (for(doc <- corpus) yield { + (for(sent <- doc) yield { + (for(toponym <- sent.getToponyms.filter(_.getAmbiguity > 0)) yield { + (for(loc <- toponym.getCandidates) yield { + (for(cellNum <- TopoUtil.getCellNumbers(loc, DPC)) yield { + nonemptyCellNums.add(cellNum) + new Edge(LOC+loc.getId, CELL+cellNum, 1.0) + }) + }).flatten + }).flatten + }).flatten + }).flatten.toList + + //result.foreach(println) + + result + } + + def getCellCellEdges: List[Edge] = { + var lon = 0.0 + var lat = 0.0 + val edges = new collection.mutable.ListBuffer[Edge] + while(lon < 360.0/DPC) { + lat = 0.0 + while(lat < 180.0/DPC) { + val curCellNumber = TopoUtil.getCellNumber(lat, lon, DPC) + //if(nonemptyCellNums contains curCellNumber) { + val leftCellNumber = TopoUtil.getCellNumber(lat, lon - DPC, DPC) + //if(nonemptyCellNums contains leftCellNumber) + edges.append(new Edge(CELL+curCellNumber, CELL+leftCellNumber, 1.0)) + val rightCellNumber = TopoUtil.getCellNumber(lat, lon + DPC, DPC) + //if(nonemptyCellNums contains rightCellNumber) + edges.append(new Edge(CELL+curCellNumber, CELL+rightCellNumber, 1.0)) + val topCellNumber = TopoUtil.getCellNumber(lat + DPC, lon, DPC) + //if(nonemptyCellNums contains topCellNumber) + edges.append(new Edge(CELL+curCellNumber, CELL+topCellNumber, 1.0)) + val bottomCellNumber = TopoUtil.getCellNumber(lat - DPC, lon, DPC) + //if(nonemptyCellNums contains bottomCellNumber) + edges.append(new Edge(CELL+curCellNumber, CELL+bottomCellNumber, 1.0)) + //} + lat += DPC + } + lon += DPC + 
} + //edges.foreach(println) + edges.toList + } + + def getSeeds(corpus:StoredCorpus): List[Label] = { + getCellCellLabelSeeds(corpus) ::: getDocCellLabelSeeds + } + + def getCellCellLabelSeeds(corpus:StoredCorpus): List[Label] = { + val result = (for(cellNum <- nonemptyCellNums) yield { + new Label(CELL+cellNum, CELL_LABEL+cellNum, 1.0) + }).toList + //result.foreach(println) + result + } + + def getDocCellLabelSeeds: List[Label] = { + val result = + if(logFilePath != null) { + (for(pe <- LogUtil.parseLogFile(logFilePath)) yield { + (for((cellNum, probMass) <- pe.getProbDistOverPredCells(knn, DPC)) yield { + new Label(DOC+pe.docName, CELL_LABEL+cellNum, probMass) + }) + }).flatten.toList + } + else + Nil + //result.foreach(println) + result + } + +} diff --git a/src/main/scala/opennlp/fieldspring/tr/resolver/MaxentResolver.scala b/src/main/scala/opennlp/fieldspring/tr/resolver/MaxentResolver.scala new file mode 100644 index 0000000..afe4784 --- /dev/null +++ b/src/main/scala/opennlp/fieldspring/tr/resolver/MaxentResolver.scala @@ -0,0 +1,131 @@ +package opennlp.fieldspring.tr.resolver + +import java.io._ + +import opennlp.maxent._ +import opennlp.maxent.io._ +import opennlp.model._ + +import opennlp.fieldspring.tr.text._ +import opennlp.fieldspring.tr.topo._ +import opennlp.fieldspring.tr.util._ + +import scala.collection.JavaConversions._ + +class MaxentResolver(val logFilePath:String, + val modelDirPath:String) extends Resolver { + + val windowSize = 20 + val dpc = 1.0 + + def disambiguate(corpus:StoredCorpus): StoredCorpus = { + + val modelDir = new File(modelDirPath) + + val toponymsToModels:Map[String, AbstractModel] = + (for(file <- modelDir.listFiles.filter(_.getName.endsWith(".mxm"))) yield { + val dataInputStream = new DataInputStream(new FileInputStream(file)); + val reader = new BinaryGISModelReader(dataInputStream) + val model = reader.getModel + + //println("Found model for "+file.getName.dropRight(4)) + + (file.getName.dropRight(4).replaceAll("_", " "), model) + }).toMap + + for(doc <- corpus) { + val docAsArray = TextUtil.getDocAsArray(doc) + var tokIndex = 0 + for(token <- docAsArray) { + if(token.isToponym && token.asInstanceOf[Toponym].getAmbiguity > 0 + && toponymsToModels.containsKey(token.getForm)) { + val toponym = token.asInstanceOf[Toponym] + val contextFeatures = TextUtil.getContextFeatures(docAsArray, tokIndex, windowSize, Set[String]()) + //print("Features for "+token.getForm+": ") + //contextFeatures.foreach(f => print(f+",")) + //println + //print("\n" + token.getForm+" ") + val bestIndex = MaxentResolver.getBestIndex(toponymsToModels(token.getForm), contextFeatures) + + //val bestIndex = MaxentResolver.getCellDist(toponymsToModels(toponym.getForm), contextFeatures, + // toponym.getCandidates.toList, dpc) + //println("best index for "+token.getForm+": "+bestIndex) + if(bestIndex != -1) + token.asInstanceOf[Toponym].setSelectedIdx(bestIndex) + } + tokIndex += 1 + } + } + + // Backoff to DocDist: + val docDistResolver = new DocDistResolver(logFilePath) + docDistResolver.overwriteSelecteds = false + docDistResolver.disambiguate(corpus) + + corpus + } +} + +object MaxentResolver { + def getBestIndex(model:AbstractModel, features:Array[String]): Int = { + //print(candidates.map(c => TopoUtil.getCellNumber(c.getRegion.getCenter, dpc)) + " ") + //val candCellNums = candidates.map(c => TopoUtil.getCellNumber(c.getRegion.getCenter, dpc)).toSet + //candCellNums.foreach(println) + //val cellNumToCandIndex = candidates.zipWithIndex + // .map(p => 
(TopoUtil.getCellNumber(p._1.getRegion.getCenter, dpc), p._2)).toMap + //cellNumToCandIndex.foreach(p => println(p._1+": "+p._2)) + //val labels = model.getDataStructures()(2).asInstanceOf[Array[String]] + //labels.foreach(l => print(l+",")) + //println + //val result = model.eval(features) + //result.foreach(r => print(r+",")) + //println + //features.foreach(f => print(f+",")) + //result.foreach(r => print(r+" ")) + //println + //val sortedResult = result.zipWithIndex.sortWith((x, y) => x._1 > y._1) + + //labels(sortedResult(0)._2).toInt + + //sortedResult.foreach(r => print(r+" ")) + /*for(i <- 0 until sortedResult.size) { // i should never make it past 0 + val label = labels(sortedResult(i)._2).toInt + if(candCellNums contains label) + return cellNumToCandIndex(label) + } + -1*/ + + getIndexToWeightMap(model, features).maxBy(_._2)._1//.toList.sortBy(_._2).get(0)._1 + } + + def getIndexToWeightMap(model:AbstractModel, features:Array[String]): Map[Int, Double] = { + val labels = model.getDataStructures()(2).asInstanceOf[Array[String]] + val result = model.eval(features).zipWithIndex + + (for(p <- result) yield { + (labels(p._2).toInt, p._1) + }).toMap + } + + /*def getCellDist(model:AbstractModel, features:Array[String], candidates:List[Location], dpc:Double): Map[Int, Double] = { + val candCellNums = candidates.map(c => TopoUtil.getCellNumber(c.getRegion.getCenter, dpc)).toSet + //candCellNums.foreach(n => print(n+",")) + //println + + val labels = model.getDataStructures()(2).asInstanceOf[Array[String]].map(_.toInt) + //labels.foreach(l => print(l+",")) + //println + val result = model.eval(features) + //result.foreach(r => print(r+",")) + //println + val relevantResult = result.zipWithIndex.filter(r => candCellNums contains labels(r._2)) + //relevantResult.foreach(r => print(r+",")) + //println + val normFactor = relevantResult.map(_._1).sum + /*val toReturn = */relevantResult.map(r => (labels(r._2), r._1 / normFactor)).toMap + //toReturn.foreach(r => print(r+",")) + //println + //toReturn + + }*/ +} diff --git a/src/main/scala/opennlp/fieldspring/tr/resolver/PopulationResolver.scala b/src/main/scala/opennlp/fieldspring/tr/resolver/PopulationResolver.scala new file mode 100644 index 0000000..1a3c4d1 --- /dev/null +++ b/src/main/scala/opennlp/fieldspring/tr/resolver/PopulationResolver.scala @@ -0,0 +1,29 @@ +package opennlp.fieldspring.tr.resolver + +import opennlp.fieldspring.tr.text._ +import opennlp.fieldspring.tr.topo._ +import opennlp.fieldspring.tr.util._ + +import scala.collection.JavaConversions._ + +class PopulationResolver extends Resolver { + + def disambiguate(corpus:StoredCorpus): StoredCorpus = { + + val rand = new scala.util.Random + + for(doc <- corpus) { + for(sent <- doc) { + for(toponym <- sent.getToponyms.filter(_.getAmbiguity > 0)) { + val maxPopLocPair = toponym.getCandidates.zipWithIndex.maxBy(_._1.getPopulation) + if(maxPopLocPair._1.getPopulation > 0) + toponym.setSelectedIdx(maxPopLocPair._2) + else + toponym.setSelectedIdx(rand.nextInt(toponym.getAmbiguity)) // back off to random + } + } + } + + corpus + } +} diff --git a/src/main/scala/opennlp/fieldspring/tr/resolver/ProbabilisticResolver.scala b/src/main/scala/opennlp/fieldspring/tr/resolver/ProbabilisticResolver.scala new file mode 100644 index 0000000..44a8368 --- /dev/null +++ b/src/main/scala/opennlp/fieldspring/tr/resolver/ProbabilisticResolver.scala @@ -0,0 +1,282 @@ +package opennlp.fieldspring.tr.resolver + +import java.io._ +import java.util.ArrayList + +import opennlp.fieldspring.tr.text._ +import 
opennlp.fieldspring.tr.topo._ +import opennlp.fieldspring.tr.util._ + +import opennlp.maxent._ +import opennlp.maxent.io._ +import opennlp.model._ + +import scala.collection.JavaConversions._ + +class ProbabilisticResolver(val logFilePath:String, + val modelDirPath:String, + val popComponentCoefficient:Double, + val dgProbOnly:Boolean, + val meProbOnly:Boolean) extends Resolver { + + val KNN = -1 + val DPC = 1.0 + val WINDOW_SIZE = 20 + + val C = 1.0E-4 // used in f(t)/(f(t)+C) + + def disambiguate(corpus:StoredCorpus): StoredCorpus = { + + val toponymLexicon:Lexicon[String] = TopoUtil.buildLexicon(corpus) + val weightsForWMD:ArrayList[ArrayList[Double]] = new ArrayList[ArrayList[Double]](toponymLexicon.size) + for(i <- 0 until toponymLexicon.size) weightsForWMD.add(null) + + val docIdToCellDist:Map[String, Map[Int, Double]] = + (for(pe <- LogUtil.parseLogFile(logFilePath)) yield { + (pe.docName, pe.getProbDistOverPredCells(KNN, DPC).toMap) + }).toMap + + val predDocLocations = (for(pe <- LogUtil.parseLogFile(logFilePath)) yield { + (pe.docName, pe.predCoord) + }).toMap + + val modelDir = new File(modelDirPath) + + val toponymsToModels:Map[String, AbstractModel] = + (for(file <- modelDir.listFiles.filter(_.getName.endsWith(".mxm"))) yield { + val dataInputStream = new DataInputStream(new FileInputStream(file)); + val reader = new BinaryGISModelReader(dataInputStream) + val model = reader.getModel + + //println(file.getName.dropRight(4).replaceAll("_", " ")) + (file.getName.dropRight(4).replaceAll("_", " "), model) + }).toMap + + var total = 0 + var toponymsToCounts = //new scala.collection.mutable.HashMap[String, Int] + (for(file <- modelDir.listFiles.filter(_.getName.endsWith(".txt"))) yield { + val count = scala.io.Source.fromFile(file).getLines.toList.size + total += count + + //println(file.getName.dropRight(4).replaceAll("_", " ") + " " + count) + + (file.getName.dropRight(4).replaceAll("_", " "), count) + }).toMap + + //println(total) + /*var total = 0 + for(doc <- corpus) { + for(sent <- doc) { + for(token <- sent) { + if(token.isToponym) { + val prevCount = toponymsToCounts.getOrElse(token.getForm, 0) + toponymsToCounts.put(token.getForm, prevCount + 1) + } + total += 1 + } + } + }*/ + + val toponymsToFrequencies = toponymsToCounts.map(p => (p._1, p._2.toDouble / total)).toMap + //toponymsToCounts.clear // trying to free up some memory + toponymsToCounts = null + //toponymsToFrequencies.foreach(p => println(p._1+": "+p._2)) + + for(doc <- corpus) { + val docAsArray = TextUtil.getDocAsArray(doc) + var tokIndex = 0 + for(token <- docAsArray) { + if(token.isToponym && token.asInstanceOf[Toponym].getAmbiguity > 0) { + val toponym = token.asInstanceOf[Toponym] + + //print("\n" + toponym.getForm + " ") + + // P(l|t,d_c(t)) + val candDistGivenLocalContext = + if(toponymsToModels.containsKey(toponym.getForm)) { + val contextFeatures = TextUtil.getContextFeatures(docAsArray, tokIndex, WINDOW_SIZE, Set[String]()) + + MaxentResolver.getIndexToWeightMap(toponymsToModels(toponym.getForm), contextFeatures) + + //println("getting a cell dist for "+toponym.getForm) + + /*val d = */ //MaxentResolver.getIndexToWeightMap(toponymsToModels(toponym.getForm), contextFeatures) + //println(d.size) + //d.foreach(println) + //d + + + } + else + null + + // P(l|d) + //val prev = docIdToCellDist.getOrElse(doc.getId, null) + val cellDistGivenDocument = filterAndNormalize(docIdToCellDist.getOrElse(doc.getId, null), toponym) + /*if(prev != null) { + println("prev size = " + prev.size) + println(" new size = " + 
cellDistGivenDocument.size) + println("-----") + }*/ + + val topFreq = toponymsToFrequencies.getOrElse(toponym.getForm, 0.0) + val lambda = topFreq / (topFreq + C)//0.7 + //println(toponym.getForm+" "+lambda) + + var indexToSelect = -1 + var maxProb = 0.0 + var candIndex = 0 + + val totalPopulation = toponym.getCandidates.map(_.getPopulation).sum + val maxPopulation = toponym.getCandidates.map(_.getPopulation).max + + var candDist = weightsForWMD.get(toponymLexicon.get(toponym.getForm)) + if(candDist == null) { + candDist = new ArrayList[Double](toponym.getAmbiguity) + for(i <- 0 until toponym.getAmbiguity) candDist.add(0.0) + } + + //print("\n"+toponym.getForm+" ") + + for(cand <- toponym.getCandidates) { + val curCellNum = TopoUtil.getCellNumber(cand.getRegion.getCenter, DPC) + + //print(" " + curCellNum) + + val localContextComponent = + if(candDistGivenLocalContext != null) + candDistGivenLocalContext.getOrElse(candIndex, 0.0) + else + 0.0 + + //print(localContextComponent + " ") + + val documentComponent = + if(cellDistGivenDocument != null && cellDistGivenDocument.size > 0) + cellDistGivenDocument.getOrElse(curCellNum, 0.0) + else + 0.0 + + /*if(localContextComponent == 0.0) { + if(documentComponent == 0.0) { + println("BOTH ZERO") + } + else { + println("LOCAL ZERO") + } + } + else if(documentComponent == 0.0) + println("DOC ZERO")*/ + + // Incorporate administrative level here + val adminLevelComponent = getAdminLevelComponent(cand, toponym.getCandidates.toList/*cand.getType, cand.getAdmin1Code*/) + + // P(l|t,d) + val probComponent = adminLevelComponent * (lambda * localContextComponent + (1-lambda) * documentComponent) + + val probOfLocation = + if(meProbOnly) + localContextComponent + else if(dgProbOnly) + documentComponent + else if(totalPopulation > 0) { + val popComponent = cand.getPopulation.toDouble / totalPopulation + popComponentCoefficient * popComponent + (1 - popComponentCoefficient) * probComponent + } + else probComponent + /*if(totalPopulation > 0 && maxPopulation.toDouble / totalPopulation > .89) + cand.getPopulation.toDouble / totalPopulation + else probComponent*/ + + candDist.set(candIndex, candDist.get(candIndex) + probOfLocation) + + //print(" " + probOfLocation) + + if(probOfLocation > maxProb) { + indexToSelect = candIndex + maxProb = probOfLocation + } + + candIndex += 1 + } + + weightsForWMD.set(toponymLexicon.get(toponym.getForm), candDist) + + /*if(indexToSelect == -1) { + val predDocLocation = predDocLocations.getOrElse(doc.getId, null) + if(predDocLocation != null) { + val indexToSelectBackoff = toponym.getCandidates.zipWithIndex.minBy(p => p._1.getRegion.distance(predDocLocation))._2 + if(indexToSelectBackoff != -1) { + indexToSelect = indexToSelectBackoff + } + } + }*/ + + toponym.setSelectedIdx(indexToSelect) + + } + tokIndex += 1 + } + } + + val out = new DataOutputStream(new FileOutputStream("probToWMD.dat")) + for(weights <- weightsForWMD/*.filterNot(x => x == null)*/) { + if(weights == null) + out.writeInt(0) + else { + val sum = weights.sum + out.writeInt(weights.size) + for(i <- 0 until weights.size) { + val newWeight = if(sum > 0) (weights.get(i) / sum) * weights.size else 1.0 + weights.set(i, newWeight) + out.writeDouble(newWeight) + //println(newWeight) + } + //println + } + } + out.close + + // Backoff to DocDist: + val docDistResolver = new DocDistResolver(logFilePath) + docDistResolver.overwriteSelecteds = false + docDistResolver.disambiguate(corpus) + + corpus + } + + def getAdminLevelComponent(loc:Location, candList:List[Location]): 
Double = { + val numerator = loc.getRegion.getRepresentatives.size + val denominator = candList.map(_.getRegion.getRepresentatives.size).sum + val frac = numerator.toDouble / denominator + frac + } + + def filterAndNormalize(dist:Map[Int, Double], toponym:Toponym): Map[Int, Double] = { + val cells = toponym.getCandidates.map(l => TopoUtil.getCellNumber(l.getRegion.getCenter, DPC)).toSet + val filteredDist = dist.filter(c => cells(c._1)) + val sum = filteredDist.map(_._2).sum + filteredDist.map(c => (c._1, c._2 / sum)) + } + + //val countryRE = """^\w\w\.\d\d$""".r + val usStateRE = """^US\.[A-Za-z][A-Za-z]$""".r + + def getAdminLevelComponentOld(locType:Location.Type, admin1Code:String): Double = { + if(locType == Location.Type.STATE) { + if(usStateRE.findFirstIn(admin1Code) != None) { // US State + .2 + } + else { // Country + .7 + } + } + else if(locType == Location.Type.CITY) { // City + 0.095 + } + else { // Other + 0.005 + } + } + +} diff --git a/src/main/scala/opennlp/fieldspring/tr/resolver/TPPResolver.scala b/src/main/scala/opennlp/fieldspring/tr/resolver/TPPResolver.scala new file mode 100644 index 0000000..c0e21f7 --- /dev/null +++ b/src/main/scala/opennlp/fieldspring/tr/resolver/TPPResolver.scala @@ -0,0 +1,7 @@ +package opennlp.fieldspring.tr.resolver + +import opennlp.fieldspring.tr.tpp._ + +abstract class TPPResolver(val tppInstance:TPPInstance) extends Resolver { + +} diff --git a/src/main/scala/opennlp/fieldspring/tr/resolver/ToponymAsDocDistResolver.scala b/src/main/scala/opennlp/fieldspring/tr/resolver/ToponymAsDocDistResolver.scala new file mode 100644 index 0000000..9b44faf --- /dev/null +++ b/src/main/scala/opennlp/fieldspring/tr/resolver/ToponymAsDocDistResolver.scala @@ -0,0 +1,42 @@ +package opennlp.fieldspring.tr.resolver + +import opennlp.fieldspring.tr.text._ +import opennlp.fieldspring.tr.topo._ +import opennlp.fieldspring.tr.util._ + +import scala.collection.JavaConversions._ + +class ToponymAsDocDistResolver(val logFilePath:String) extends Resolver { + + val docTokRE = """(.+)_([0-9]+)""".r + val alphanumRE = """^[a-zA-Z0-9]+$""".r + + def disambiguate(corpus:StoredCorpus): StoredCorpus = { + + val predLocations = (for(pe <- LogUtil.parseLogFile(logFilePath)) yield { + val docTokRE(docName, tokenIndex) = pe.docName + ((docName, tokenIndex.toInt), pe.predCoord) + }).toMap + + for(doc <- corpus) { + var tokenIndex = 0 + for(sent <- doc) { + for(token <- sent.filter(t => alphanumRE.findFirstIn(t.getForm) != None)) { + if(token.isToponym && token.asInstanceOf[Toponym].getAmbiguity > 0) { + val toponym = token.asInstanceOf[Toponym] + val predLocation = predLocations.getOrElse((doc.getId, tokenIndex), null) + if(predLocation != null) { + val indexToSelect = toponym.getCandidates.zipWithIndex.minBy(p => p._1.getRegion.distance(predLocation))._2 + if(indexToSelect != -1) { + toponym.setSelectedIdx(indexToSelect) + } + } + } + tokenIndex += 1 + } + } + } + + corpus + } +} diff --git a/src/main/scala/opennlp/fieldspring/tr/text/io/DynamicKMLWriter.scala b/src/main/scala/opennlp/fieldspring/tr/text/io/DynamicKMLWriter.scala new file mode 100644 index 0000000..d969ab2 --- /dev/null +++ b/src/main/scala/opennlp/fieldspring/tr/text/io/DynamicKMLWriter.scala @@ -0,0 +1,53 @@ +package opennlp.fieldspring.tr.text.io + +import java.io._ +import java.util._ +import javax.xml.datatype._ +import javax.xml.stream._ + +import opennlp.fieldspring.tr.text._ +import opennlp.fieldspring.tr.topo._ +import opennlp.fieldspring.tr.util._ + +import scala.collection.JavaConversions._ + +class 
DynamicKMLWriter(val corpus:StoredCorpus/*, + val outputGoldLocations:Boolean = false*/) { + + lazy val factory = XMLOutputFactory.newInstance + + val CONTEXT_SIZE = 20 + + def write(out:XMLStreamWriter) { + KMLUtil.writeHeader(out, "corpus") + + var globalTokIndex = 0 + var globalTopIndex = 1 + for(doc <- corpus) { + val docArray = TextUtil.getDocAsArray(doc) + var tokIndex = 0 + for(token <- docArray) { + if(token.isToponym) { + val toponym = token.asInstanceOf[Toponym] + if(toponym.getAmbiguity > 0 && toponym.hasSelected) { + val coord = toponym.getSelected.getRegion.getCenter + val context = TextUtil.getContext(docArray, tokIndex, CONTEXT_SIZE) + KMLUtil.writePinTimeStampPlacemark(out, toponym.getOrigForm, coord, context, globalTopIndex) + globalTopIndex += 1 + } + } + tokIndex += 1 + globalTokIndex += 1 + } + } + + KMLUtil.writeFooter(out) + out.close + } + + def write(file:File) { + val stream = new BufferedOutputStream(new FileOutputStream(file)) + this.write(this.factory.createXMLStreamWriter(stream, "UTF-8")) + stream.close() + } +} diff --git a/src/main/scala/opennlp/fieldspring/tr/text/io/GigawordSource.scala b/src/main/scala/opennlp/fieldspring/tr/text/io/GigawordSource.scala new file mode 100644 index 0000000..b5e10be --- /dev/null +++ b/src/main/scala/opennlp/fieldspring/tr/text/io/GigawordSource.scala @@ -0,0 +1,62 @@ +/////////////////////////////////////////////////////////////////////////////// +// Copyright (C) 2010 Travis Brown, The University of Texas at Austin +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+/////////////////////////////////////////////////////////////////////////////// +package opennlp.fieldspring.tr.text.io + +import java.io.BufferedReader +import java.io.File +import java.io.FileReader +import java.util.ArrayList +import java.util.List +import scala.collection.JavaConversions._ +import scala.collection.mutable.Buffer + +import opennlp.fieldspring.tr.text._ +import opennlp.fieldspring.tr.text.prep._ + +class GigawordSource( + reader: BufferedReader, + private val sentencesPerDocument: Int, + private val numberOfDocuments: Int) + extends TextSource(reader) { + + def this(reader: BufferedReader, sentencesPerDocument: Int) = + this(reader, sentencesPerDocument, Int.MaxValue) + def this(reader: BufferedReader) = this(reader, 50) + + val sentences = new Iterator[Sentence[Token]] { + var current = GigawordSource.this.readLine + def hasNext: Boolean = current != null + def next: Sentence[Token] = new Sentence[Token](null) { + val buffer = Buffer(new SimpleToken(current)) + current = GigawordSource.this.readLine + while (current.trim.length > 0) { + buffer += new SimpleToken(current) + current = GigawordSource.this.readLine + } + current = GigawordSource.this.readLine + + def tokens: java.util.Iterator[Token] = buffer.toIterator + } + }.grouped(sentencesPerDocument).take(numberOfDocuments) + + def hasNext: Boolean = sentences.hasNext + + def next: Document[Token] = new Document[Token](null) { + def iterator: java.util.Iterator[Sentence[Token]] = + sentences.next.toIterator + } +} + diff --git a/src/main/scala/opennlp/fieldspring/tr/text/io/WikiTextSource.scala b/src/main/scala/opennlp/fieldspring/tr/text/io/WikiTextSource.scala new file mode 100644 index 0000000..2180f8c --- /dev/null +++ b/src/main/scala/opennlp/fieldspring/tr/text/io/WikiTextSource.scala @@ -0,0 +1,52 @@ +package opennlp.fieldspring.tr.text.io + +import java.io._ + +import scala.collection.JavaConversions._ +import scala.collection.mutable.Buffer +import opennlp.fieldspring.tr.text._ +import opennlp.fieldspring.tr.text.prep._ + +class WikiTextSource( + reader: BufferedReader +) extends TextSource(reader) { + + val TITLE_PREFIX = "Article title: " + val TITLE_INDEX = TITLE_PREFIX.length + val ID_INDEX = "Article ID: ".length + + var id = "-1" + var title = "" + + val sentences = new Iterator[Sentence[Token]] { + var current = WikiTextSource.this.readLine + + def hasNext: Boolean = current != null + def next: Sentence[Token] = new Sentence[Token](null) { + if(current != null) { + title = current.drop(TITLE_INDEX).trim + current = WikiTextSource.this.readLine + id = current.drop(ID_INDEX).trim + current = WikiTextSource.this.readLine + } + val buffer = Buffer(new SimpleToken(current)) + current = WikiTextSource.this.readLine + while (current != null && !current.trim.startsWith(TITLE_PREFIX)) { + buffer += new SimpleToken(current) + current = WikiTextSource.this.readLine + } + + def tokens: java.util.Iterator[Token] = buffer.toIterator + } + }.grouped(1) // assume each document is a whole sentence, since we don't have sentence boundaries + + def hasNext: Boolean = sentences.hasNext + + def next: Document[Token] = { + new Document[Token](id, title) { + def iterator: java.util.Iterator[Sentence[Token]] = { + sentences.next.toIterator + } + } + } +} diff --git a/src/main/scala/opennlp/fieldspring/tr/topo/SphericalGeometry.scala b/src/main/scala/opennlp/fieldspring/tr/topo/SphericalGeometry.scala new file mode 100644 index 0000000..fdf5cd3 --- /dev/null +++ 
b/src/main/scala/opennlp/fieldspring/tr/topo/SphericalGeometry.scala @@ -0,0 +1,54 @@ +/////////////////////////////////////////////////////////////////////////////// +// Copyright (C) 2010 Travis Brown, The University of Texas at Austin +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +/////////////////////////////////////////////////////////////////////////////// +package opennlp.fieldspring.tr.topo + +import scala.io._ + +import scala.collection.JavaConversions._ + +import opennlp.fieldspring.tr.util.cluster._ + +object SphericalGeometry { + implicit def g: Geometry[Coordinate] = new Geometry[Coordinate] { + def distance(x: Coordinate)(y: Coordinate): Double = x.distance(y) + def centroid(ps: Seq[Coordinate]): Coordinate = Coordinate.centroid(ps) + } + + def main(args: Array[String]) { + val max = args(1).toInt + val k = args(2).toInt + val style = args(3) + + val cs = Source.fromFile(args(0)).getLines.map { line => + val Array(lat, lng) = line.split("\t").map(_.toDouble) + Coordinate.fromDegrees(lat, lng) + }.toIndexedSeq + println("Loaded...") + + val xs = scala.util.Random.shuffle(cs).take(max) + + println(Coordinate.centroid(xs)) + + val clusterer = new KMeans + val clusters = clusterer.cluster(xs, k) + clusters.foreach { + case c => println("" + + style + "" + + c.getLngDegrees + "," + c.getLatDegrees + + "") + } + } +} diff --git a/src/main/scala/opennlp/fieldspring/tr/topo/gaz/CorpusGazetteerReader.scala b/src/main/scala/opennlp/fieldspring/tr/topo/gaz/CorpusGazetteerReader.scala new file mode 100644 index 0000000..8408c43 --- /dev/null +++ b/src/main/scala/opennlp/fieldspring/tr/topo/gaz/CorpusGazetteerReader.scala @@ -0,0 +1,39 @@ +/////////////////////////////////////////////////////////////////////////////// +// Copyright (C) 2010 Travis Brown, The University of Texas at Austin +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+/////////////////////////////////////////////////////////////////////////////// +package opennlp.fieldspring.tr.topo.gaz + +import java.util.Iterator +import scala.collection.JavaConversions._ + +import opennlp.fieldspring.tr.text.Corpus +import opennlp.fieldspring.tr.text.Token +import opennlp.fieldspring.tr.topo.Location + +class CorpusGazetteerReader(private val corpus: Corpus[_ <: Token]) + extends GazetteerReader { + + private val it = corpus.flatMap(_.flatMap { + _.getToponyms.flatMap(_.getCandidates) + }).toIterator + + def hasNext: Boolean = it.hasNext + def next: Location = it.next + + def close() { + corpus.close() + } +} + diff --git a/src/main/scala/opennlp/fieldspring/tr/topo/gaz/geonames/GeoNamesParser.scala b/src/main/scala/opennlp/fieldspring/tr/topo/gaz/geonames/GeoNamesParser.scala new file mode 100644 index 0000000..09fcfdd --- /dev/null +++ b/src/main/scala/opennlp/fieldspring/tr/topo/gaz/geonames/GeoNamesParser.scala @@ -0,0 +1,34 @@ +/////////////////////////////////////////////////////////////////////////////// +// Copyright (C) 2010 Travis Brown, The University of Texas at Austin +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +/////////////////////////////////////////////////////////////////////////////// +package opennlp.fieldspring.tr.topo.gaz.geonames + +import java.io._ +import scala.collection.JavaConversions._ +import scala.io._ + +import opennlp.fieldspring.tr.text.Corpus +import opennlp.fieldspring.tr.text.Token +import opennlp.fieldspring.tr.topo.Location + +class GeoNamesParser(private val file: File) { + val locs = scala.collection.mutable.Map[String, List[(Double, Double)]]() + + Source.fromFile(file).getLines.foreach { line => + val Array(lat, lng) = line.split("\t").map(_.toDouble) + + } +} + diff --git a/src/main/scala/opennlp/fieldspring/tr/topo/util/CodeConverter.scala b/src/main/scala/opennlp/fieldspring/tr/topo/util/CodeConverter.scala new file mode 100644 index 0000000..4a0d094 --- /dev/null +++ b/src/main/scala/opennlp/fieldspring/tr/topo/util/CodeConverter.scala @@ -0,0 +1,76 @@ +/////////////////////////////////////////////////////////////////////////////// +// Copyright (C) 2010 Travis Brown, The University of Texas at Austin +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+/////////////////////////////////////////////////////////////////////////////// +package opennlp.fieldspring.tr.topo.util + +import java.io._ +import java.io.InputStream +import scala.collection.mutable.Map + +class CodeConverter(in: InputStream) { + def this() = this { + getClass.getResourceAsStream("/data/geo/country-codes.txt") + } + + case class Country( + name: String, + fips: Option[String], + iso: Option[(String, String, Int)], + stanag: Option[String], + tld: Option[String]) + + private val countriesF = Map[String, Country]() + private val countriesI = Map[String, Country]() + private val reader = new BufferedReader(new InputStreamReader(in)) + + private var line = reader.readLine + while (line != null) { + val fs = line.split("\t") + val country = Country( + fs(0), + if (fs(1) == "-") None else Some(fs(1)), + if (fs(2) == "-") None else Some(fs(2), fs(3), fs(4).toInt), + if (fs(5) == "-") None else Some(fs(5)), + if (fs(6) == "-") None else Some(fs(6)) + ) + country.fips match { + case Some(fips) => countriesF(fips) = country + case _ => + } + country.iso match { + case Some((iso2, _, _)) => countriesI(iso2) = country + case _ => + } + line = reader.readLine + } + reader.close() + + def convertFipsToIso2(code: String): Option[String] = + countriesF.get(code).flatMap(_.iso.map(_._1)) + + def convertIso2ToFips(code: String): Option[String] = + countriesI.get(code).flatMap(_.fips) +} + +object CodeConverter { + def main(args: Array[String]) { + val converter = new CodeConverter() + println(args(0) match { + case "f2i" => converter.convertIso2ToFips(args(1)) + case "i2f" => converter.convertFipsToIso2(args(1)) + }) + } +} + diff --git a/src/main/scala/opennlp/fieldspring/tr/tpp/ClusterMarketCreator.scala b/src/main/scala/opennlp/fieldspring/tr/tpp/ClusterMarketCreator.scala new file mode 100644 index 0000000..b0164be --- /dev/null +++ b/src/main/scala/opennlp/fieldspring/tr/tpp/ClusterMarketCreator.scala @@ -0,0 +1,111 @@ +package opennlp.fieldspring.tr.tpp + +import opennlp.fieldspring.tr.text._ +import opennlp.fieldspring.tr.topo._ +import opennlp.fieldspring.tr.util._ + +import java.util.ArrayList + +import scala.collection.JavaConversions._ + +class ClusterMarketCreator(doc:Document[StoredToken], val thresholdInKm:Double) extends MarketCreator(doc) { + + val threshold = thresholdInKm / 6372.8 + + override def apply:List[Market] = { + + // Initialize singleton clusters: + val clusters = new scala.collection.mutable.HashSet[Cluster] + val docAsArray = TextUtil.getDocAsArrayNoFilter(doc) + var tokIndex = 0 + var clusterIndex = 0 + for(token <- docAsArray) { + if(token.isToponym && token.asInstanceOf[Toponym].getAmbiguity > 0) { + val toponym = token.asInstanceOf[Toponym] + var gazIndex = 0 + for(loc <- toponym.getCandidates) { + //val topMen = new ToponymMention(doc.getId, tokIndex) + val potLoc = new PotentialLocation(doc.getId, tokIndex, gazIndex, loc) + + val cluster = new Cluster(clusterIndex) + cluster.add(potLoc) + clusters.add(cluster) + + clusterIndex += 1 + + gazIndex += 1 + } + } + tokIndex += 1 + } + + // Repeatedly merge until no more merging happens: + var atLeastOneMerge = true + while(atLeastOneMerge) { + atLeastOneMerge = false + for(cluster <- clusters.toArray) { + if(clusters.size >= 2 && clusters.contains(cluster)) { + val closest = clusters.minBy(c => if(c.equals(cluster)) Double.PositiveInfinity else cluster.distance(c)) + + if(cluster.distance(closest) <= threshold) { + cluster.merge(closest) + clusters.remove(closest) + atLeastOneMerge = true + } + } + } + } + 
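+    // At this point no remaining pair of clusters has centroids closer than the
+    // angular threshold (thresholdInKm / 6372.8), so each surviving cluster is a
+    // maximal greedy grouping of candidate locations for this document.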
+ // Turn clusters into Markets: + (for(cluster <- clusters) yield { + val tmsToPls = new scala.collection.mutable.HashMap[ToponymMention, PotentialLocation] + for(potLoc <- cluster.potLocs) + tmsToPls.put(new ToponymMention(potLoc.docId, potLoc.tokenIndex), potLoc) + new Market(cluster.id, tmsToPls.toMap) + }).toList + } +} + +class Cluster(val id:Int) { + + val potLocs = new ArrayList[PotentialLocation] + var centroid:Coordinate = null + + def add(potLoc:PotentialLocation) { + potLocs.add(potLoc) + val newCoord = potLoc.loc.getRegion.getCenter + if(size == 1) + centroid = newCoord + else { + val newLat = (centroid.getLat * (size-1) + newCoord.getLat) / size + val newLng = (centroid.getLng * (size-1) + newCoord.getLng) / size + centroid = Coordinate.fromRadians(newLat, newLng) + } + } + + def size = potLocs.size + + def distance(other:Cluster) = this.centroid.distance(other.centroid) + + // This cluster absorbs the cluster given as a parameter, keeping this.id: + def merge(other:Cluster) { + val newLat = (this.centroid.getLat * this.size + other.centroid.getLat * other.size) / (this.size + other.size) + val newLng = (this.centroid.getLng * this.size + other.centroid.getLng * other.size) / (this.size + other.size) + this.centroid = Coordinate.fromRadians(newLat, newLng) + + for(otherPotLoc <- other.potLocs) { + this.potLocs.add(otherPotLoc) + } + } + + override def equals(other:Any):Boolean = { + if(!other.isInstanceOf[Cluster]) + false + else { + val o = other.asInstanceOf[Cluster] + this.id.equals(o.id) + } + } + + override def hashCode:Int = id +} diff --git a/src/main/scala/opennlp/fieldspring/tr/tpp/ConstructionTPPSolver.scala b/src/main/scala/opennlp/fieldspring/tr/tpp/ConstructionTPPSolver.scala new file mode 100644 index 0000000..eb6a3ee --- /dev/null +++ b/src/main/scala/opennlp/fieldspring/tr/tpp/ConstructionTPPSolver.scala @@ -0,0 +1,132 @@ +package opennlp.fieldspring.tr.tpp + +import java.util.ArrayList + +import scala.collection.JavaConversions._ + +class ConstructionTPPSolver extends TPPSolver { + def apply(tppInstance:TPPInstance): List[MarketVisit] = { + + val tour = new ArrayList[MarketVisit] + + val unvisitedMarkets = new ArrayList[Market](tppInstance.markets) + + val unresolvedToponymMentions = getUnresolvedToponymMentions(tppInstance) + + while(!unresolvedToponymMentions.isEmpty) { + //println(unresolvedToponymMentions.size) + val bestMarketAndIndexes = chooseBestMarketAndIndexes(unvisitedMarkets, unresolvedToponymMentions, tour, tppInstance) + + insertIntoTour(tour, bestMarketAndIndexes._1, bestMarketAndIndexes._3, tppInstance) + + unvisitedMarkets.remove(bestMarketAndIndexes._2) + + resolveToponymMentions(bestMarketAndIndexes._1, unresolvedToponymMentions) + } + + //println("Tour had " + tour.size + " market visits.") + + tour.toList + } + + // First index is index in parameter named markets containing the chosen market; second index is optimal position in the tour to insert it + def chooseBestMarketAndIndexes(markets:ArrayList[Market], unresolvedToponymMentions:scala.collection.mutable.HashSet[ToponymMention], tour:ArrayList[MarketVisit], tppInstance:TPPInstance): (Market, Int, Int) = { + + //val mostUnresolvedToponymMentions = markets.map(m => countOverlap(m.locations, unresolvedToponymMentions)).max + val pc = tppInstance.purchaseCoster + val leastAveragePurchaseCost = markets.map(m => m.locations.map(l => pc(m, l._2)).sum/m.locations.size).min ////////// + //println(leastAveragePurchaseCost) + + //val potentialBestMarkets = markets.zipWithIndex.filter(p => 
countOverlap(p._1.locations, unresolvedToponymMentions) == mostUnresolvedToponymMentions) + val potentialBestMarkets = markets.zipWithIndex.filter(p => p._1.locations.map(l => pc(p._1, l._2)).sum/p._1.locations.size <= (leastAveragePurchaseCost+.000000001)) // Prevent rounding errors + + val r = potentialBestMarkets.map(p => (p, getBestCostIncreaseAndIndex(tour, p._1, tppInstance))).minBy(q => q._2._1) + + (r._1._1, r._1._2, r._2._2) + + + //markets.zipWithIndex.maxBy(p => p._1.locations.map(_._1).map(tm => if(unresolvedToponymMentions.contains(tm)) 1 else 0).sum) // market with the greatest number of goods (types) I haven't puchased yet; but this is bugged, so why does it work well? -- it doesn't seem to anymore after fixing other bugs + //markets.zipWithIndex.maxBy(_._1.locations.map(_._2).map(_.loc).toSet.size) // biggest market by types + //markets.zipWithIndex.maxBy(_._1.size) // biggest market by tokens + } + + def countOverlap(pls:Map[ToponymMention, PotentialLocation], urtms:scala.collection.mutable.HashSet[ToponymMention]): Int = { + var sum = 0 + for(tm <- pls.map(_._1)) + if(urtms.contains(tm)) + sum += 1 + + sum + } + + def insertIntoTour(tour:ArrayList[MarketVisit], market:Market, index:Int, tppInstance:TPPInstance) { + val marketVisit = new MarketVisit(market) + + // Buy everything at the new market + for((topMen, potLoc) <- market.locations) { + marketVisit.purchasedLocations.put(topMen, potLoc) + } + + val pc = tppInstance.purchaseCoster + + // Unbuy goods that have already been purchased elsewhere for the same or cheaper prices + for(existingMarketVisit <- tour) { + var index = 0 + val purLocs = marketVisit.purchasedLocations.toList + while(index < purLocs.size) { + //for((topMen, newPotLoc) <- marketVisit.purchasedLocations) { + val topMen = purLocs(index)._1 + val newPotLoc = purLocs(index)._2 + val prevPotLoc = existingMarketVisit.purchasedLocations.getOrElse(topMen, null) + if(prevPotLoc != null) { + if(pc(existingMarketVisit.market, prevPotLoc) <= pc(marketVisit.market, newPotLoc)) { + //print(purLocs.size+" => ") + marketVisit.purchasedLocations.remove(topMen) + //println(purLocs.size) + } + else { + existingMarketVisit.purchasedLocations.remove(topMen) + } + } + index += 1 + } + } + + if(marketVisit.purchasedLocations.size > 0) + tour.insert(index, marketVisit) // This puts the market in the place that minimizes the added travel cost + } + + def getBestCostIncreaseAndIndex(tour:ArrayList[MarketVisit], market:Market, tppInstance:TPPInstance): (Double, Int) = { + + val tc = tppInstance.travelCoster + + if(tour.size == 0) + (0.0, 0) + else if(tour.size == 1) + (tc(tour(0).market, market), 1) + else { + var minAddedCost = Double.PositiveInfinity + var bestIndex = -1 + for(index <- 0 to tour.size) { + var addedCost = 0.0 + if(index == 0) { + addedCost = tc(market, tour(0).market) + } + else if(index == tour.size) { + addedCost = tc(tour(tour.size-1).market, market) + } + else { + addedCost = tc(tour(index-1).market, market) + tc(market, tour(index).market) - tc(tour(index-1).market, tour(index).market) + } + + if(addedCost < minAddedCost) { + minAddedCost = addedCost + bestIndex = index + } + } + //println(minAddedCost+" at "+bestIndex) + (minAddedCost, bestIndex) + } + } + +} diff --git a/src/main/scala/opennlp/fieldspring/tr/tpp/GaussianPurchaseCoster.scala b/src/main/scala/opennlp/fieldspring/tr/tpp/GaussianPurchaseCoster.scala new file mode 100644 index 0000000..9c0943e --- /dev/null +++ b/src/main/scala/opennlp/fieldspring/tr/tpp/GaussianPurchaseCoster.scala @@ -0,0 
+1,27 @@ +package opennlp.fieldspring.tr.tpp + +import opennlp.fieldspring.tr.topo._ + +class GaussianPurchaseCoster extends PurchaseCoster { + + val VARIANCE_KM = 161.0 + val variance = VARIANCE_KM / 6372.8 + + def g(x:Double, y:Double) = GaussianUtil.g(x,y) + + val storedCosts = new scala.collection.mutable.HashMap[(Int, Int), Double] // (location.id, market.id) => distance + def cost(l:Location, m:Market): Double = { + val key = (l.getId, m.id) + if(storedCosts.contains(key)) + storedCosts(key) + else { + val cost = 1.0-g(l.getRegion.distance(m.centroid)/variance, 0)///max + storedCosts.put(key, cost) + cost + } + } + + def apply(m:Market, potLoc:PotentialLocation): Double = { + cost(potLoc.loc, m) + } +} diff --git a/src/main/scala/opennlp/fieldspring/tr/tpp/GaussianTravelCoster.scala b/src/main/scala/opennlp/fieldspring/tr/tpp/GaussianTravelCoster.scala new file mode 100644 index 0000000..4b92531 --- /dev/null +++ b/src/main/scala/opennlp/fieldspring/tr/tpp/GaussianTravelCoster.scala @@ -0,0 +1,15 @@ +package opennlp.fieldspring.tr.tpp + +import opennlp.fieldspring.tr.topo._ + +class GaussianTravelCoster extends TravelCoster { + + val VARIANCE_KM = 1610 + val variance = VARIANCE_KM / 6372.8 + + def g(x:Double, y:Double) = GaussianUtil.g(x,y) + + def apply(m1:Market, m2:Market): Double = { + 1.0-g(m1.centroid.distance(m2.centroid)/variance, 0) + } +} diff --git a/src/main/scala/opennlp/fieldspring/tr/tpp/GaussianUtil.scala b/src/main/scala/opennlp/fieldspring/tr/tpp/GaussianUtil.scala new file mode 100644 index 0000000..6474704 --- /dev/null +++ b/src/main/scala/opennlp/fieldspring/tr/tpp/GaussianUtil.scala @@ -0,0 +1,11 @@ +package opennlp.fieldspring.tr.tpp + +object GaussianUtil { + def left(sig_x:Double, sig_y:Double, rho:Double) = 1.0/(2*math.Pi*sig_x*sig_y*math.pow(1-rho*rho,.5)) + + def right(x:Double, y:Double, mu_x:Double, mu_y:Double, sig_x:Double, sig_y:Double, rho:Double) = math.exp(-1.0/(2*(1-rho*rho))*( math.pow(x-mu_x,2)/math.pow(sig_x,2) + math.pow(y-mu_y,2)/math.pow(sig_y,2) - (2*rho*(x-mu_x)*(y-mu_y))/(sig_x*sig_y))) + + def f(x:Double, y:Double, mu_x:Double, mu_y:Double, sig_x:Double, sig_y:Double, rho:Double) = left(sig_x,sig_y,rho) * right(x,y,mu_x,mu_y,sig_x,sig_y,rho) + + def g(x:Double,y:Double) = f(x,y,0,0,1,1,0) +} diff --git a/src/main/scala/opennlp/fieldspring/tr/tpp/GridMarketCreator.scala b/src/main/scala/opennlp/fieldspring/tr/tpp/GridMarketCreator.scala new file mode 100644 index 0000000..c6dc11e --- /dev/null +++ b/src/main/scala/opennlp/fieldspring/tr/tpp/GridMarketCreator.scala @@ -0,0 +1,39 @@ +package opennlp.fieldspring.tr.tpp + +import opennlp.fieldspring.tr.text._ +import opennlp.fieldspring.tr.util._ + +import scala.collection.JavaConversions._ + +class GridMarketCreator(doc:Document[StoredToken], val dpc:Double) extends MarketCreator(doc) { + override def apply:List[Market] = { + val cellNumsToPotLocs = new scala.collection.mutable.HashMap[Int, scala.collection.mutable.HashMap[ToponymMention, PotentialLocation]] + + val docAsArray = TextUtil.getDocAsArrayNoFilter(doc) + + var tokIndex = 0 + for(token <- docAsArray) { + if(token.isToponym && token.asInstanceOf[Toponym].getAmbiguity > 0) { + val toponym = token.asInstanceOf[Toponym] + var gazIndex = 0 + for(loc <- toponym.getCandidates) { + val topMen = new ToponymMention(doc.getId, tokIndex) + val potLoc = new PotentialLocation(doc.getId, tokIndex, gazIndex, loc) + + val cellNums = TopoUtil.getCellNumbers(loc, dpc) + for(cellNum <- cellNums) { + val potLocs = cellNumsToPotLocs.getOrElse(cellNum, 
new scala.collection.mutable.HashMap[ToponymMention, PotentialLocation]) + potLocs.put(topMen, potLoc) + cellNumsToPotLocs.put(cellNum, potLocs) + } + gazIndex += 1 + } + } + tokIndex += 1 + } + + (for(p <- cellNumsToPotLocs) yield { + new Market(p._1, p._2.toMap) + }).toList + } +} diff --git a/src/main/scala/opennlp/fieldspring/tr/tpp/MarketCreator.scala b/src/main/scala/opennlp/fieldspring/tr/tpp/MarketCreator.scala new file mode 100644 index 0000000..5c3b64a --- /dev/null +++ b/src/main/scala/opennlp/fieldspring/tr/tpp/MarketCreator.scala @@ -0,0 +1,7 @@ +package opennlp.fieldspring.tr.tpp + +import opennlp.fieldspring.tr.text._ + +abstract class MarketCreator(val doc:Document[StoredToken]) { + def apply:List[Market] +} diff --git a/src/main/scala/opennlp/fieldspring/tr/tpp/MaxentPurchaseCoster.scala b/src/main/scala/opennlp/fieldspring/tr/tpp/MaxentPurchaseCoster.scala new file mode 100644 index 0000000..e5e85fd --- /dev/null +++ b/src/main/scala/opennlp/fieldspring/tr/tpp/MaxentPurchaseCoster.scala @@ -0,0 +1,61 @@ +package opennlp.fieldspring.tr.tpp + +import opennlp.fieldspring.tr.text._ +import opennlp.fieldspring.tr.util._ +import opennlp.fieldspring.tr.resolver._ +import opennlp.maxent._ +import opennlp.maxent.io._ +import opennlp.model._ + +import java.io._ + +import scala.collection.JavaConversions._ + +class MaxentPurchaseCoster(corpus:StoredCorpus, modelDirPath:String) extends PurchaseCoster { + + val windowSize = 20 + + val modelDir = new File(modelDirPath) + + val toponymsToModels:Map[String, AbstractModel] = + (for(file <- modelDir.listFiles.filter(_.getName.endsWith(".mxm"))) yield { + val dataInputStream = new DataInputStream(new FileInputStream(file)); + val reader = new BinaryGISModelReader(dataInputStream) + val model = reader.getModel + + (file.getName.dropRight(4).replaceAll("_", " "), model) + }).toMap + + val potLocsToCosts = new scala.collection.mutable.HashMap[PotentialLocation, Double] + + for(doc <- corpus) { + val docAsArray = TextUtil.getDocAsArrayNoFilter(doc) + var tokIndex = 0 + for(token <- docAsArray) { + if(token.isToponym && token.asInstanceOf[Toponym].getAmbiguity > 0 + && toponymsToModels.containsKey(token.getForm)) { + val toponym = token.asInstanceOf[Toponym] + val contextFeatures = TextUtil.getContextFeatures(docAsArray, tokIndex, windowSize, Set[String]()) + + val indexToWeightMap = MaxentResolver.getIndexToWeightMap(toponymsToModels(token.getForm), contextFeatures) + //contextFeatures.foreach(f => print(f+",")); println + for((gazIndex, weight) <- indexToWeightMap.toList.sortBy(_._1)) { + val loc = toponym.getCandidates.get(gazIndex) + val potLoc = new PotentialLocation(doc.getId, tokIndex, gazIndex, loc) + //println(" "+gazIndex+": "+(1.0-weight)) + potLocsToCosts.put(potLoc, 1.0-weight) // Here's where the cost is defined in terms of the probability mass + } + + } + tokIndex += 1 + } + } + + def apply(m:Market, potLoc:PotentialLocation): Double = { + //if(m.locations.map(_._2).toSet.contains(potLoc)) { + potLocsToCosts.getOrElse(potLoc, 1.0) // Not sure what the default cost should be + //} + //else + // Double.PositiveInfinity + } +} diff --git a/src/main/scala/opennlp/fieldspring/tr/tpp/MultiPurchaseCoster.scala b/src/main/scala/opennlp/fieldspring/tr/tpp/MultiPurchaseCoster.scala new file mode 100644 index 0000000..d60f11b --- /dev/null +++ b/src/main/scala/opennlp/fieldspring/tr/tpp/MultiPurchaseCoster.scala @@ -0,0 +1,8 @@ +package opennlp.fieldspring.tr.tpp + +class MultiPurchaseCoster(val purchaseCosters:List[PurchaseCoster]) extends 
PurchaseCoster { + + def apply(m:Market, potLoc:PotentialLocation): Double = { + purchaseCosters.map(pc => pc(m, potLoc)).reduce(_*_) + } +} diff --git a/src/main/scala/opennlp/fieldspring/tr/tpp/PurchaseCoster.scala b/src/main/scala/opennlp/fieldspring/tr/tpp/PurchaseCoster.scala new file mode 100644 index 0000000..864c459 --- /dev/null +++ b/src/main/scala/opennlp/fieldspring/tr/tpp/PurchaseCoster.scala @@ -0,0 +1,6 @@ +package opennlp.fieldspring.tr.tpp + +abstract class PurchaseCoster { + + def apply(m:Market, potLoc:PotentialLocation): Double +} diff --git a/src/main/scala/opennlp/fieldspring/tr/tpp/SimpleContainmentPurchaseCoster.scala b/src/main/scala/opennlp/fieldspring/tr/tpp/SimpleContainmentPurchaseCoster.scala new file mode 100644 index 0000000..eb35545 --- /dev/null +++ b/src/main/scala/opennlp/fieldspring/tr/tpp/SimpleContainmentPurchaseCoster.scala @@ -0,0 +1,11 @@ +package opennlp.fieldspring.tr.tpp + +class SimpleContainmentPurchaseCoster extends PurchaseCoster { + + def apply(m:Market, potLoc:PotentialLocation): Double = { + if(m.locations.map(_._2).toSet.contains(potLoc)) + 1.0 + else + Double.PositiveInfinity + } +} diff --git a/src/main/scala/opennlp/fieldspring/tr/tpp/SimpleDistanceTravelCoster.scala b/src/main/scala/opennlp/fieldspring/tr/tpp/SimpleDistanceTravelCoster.scala new file mode 100644 index 0000000..18ca984 --- /dev/null +++ b/src/main/scala/opennlp/fieldspring/tr/tpp/SimpleDistanceTravelCoster.scala @@ -0,0 +1,32 @@ +package opennlp.fieldspring.tr.tpp + +import opennlp.fieldspring.tr.util._ + +class SimpleDistanceTravelCoster extends TravelCoster { + + val storedDistances = new scala.collection.mutable.HashMap[(Int, Int), Double] + val distanceTable = new DistanceTable + + def apply(m1:Market, m2:Market): Double = { + + if(storedDistances.contains((m1.id, m2.id))) { + //println(storedDistances((m1.id, m2.id))) + storedDistances((m1.id, m2.id)) + } + + else { + var minDist = Double.PositiveInfinity + for(loc1 <- m1.locations.map(_._2).map(_.loc)) { + for(loc2 <- m2.locations.map(_._2).map(_.loc)) { + val dist = distanceTable.distance(loc1, loc2) + if(dist < minDist) + minDist = dist + } + } + + storedDistances.put((m1.id, m2.id), minDist) + //println(minDist) + minDist + } + } +} diff --git a/src/main/scala/opennlp/fieldspring/tr/tpp/TPPInstance.scala b/src/main/scala/opennlp/fieldspring/tr/tpp/TPPInstance.scala new file mode 100644 index 0000000..5acd885 --- /dev/null +++ b/src/main/scala/opennlp/fieldspring/tr/tpp/TPPInstance.scala @@ -0,0 +1,48 @@ +package opennlp.fieldspring.tr.tpp + +import opennlp.fieldspring.tr.topo._ + +class TPPInstance(val purchaseCoster:PurchaseCoster, + val travelCoster:TravelCoster) { + + var markets:List[Market] = null +} + +class Market(val id:Int, + val locations:Map[ToponymMention, PotentialLocation]) { + + def size = locations.size + + lazy val centroid: Coordinate = { + val lat:Double = locations.map(_._2.loc.getRegion.getCenter.getLat).sum/locations.size + val lng:Double = locations.map(_._2.loc.getRegion.getCenter.getLng).sum/locations.size + Coordinate.fromRadians(lat, lng) + } +} + +class PotentialLocation(val docId:String, + val tokenIndex:Int, + val gazIndex:Int, + val loc:Location) { + + override def toString: String = { + docId+":"+tokenIndex+":"+gazIndex + } + + override def equals(other:Any):Boolean = { + if(!other.isInstanceOf[PotentialLocation]) + false + else { + val o = other.asInstanceOf[PotentialLocation] + this.docId.equals(o.docId) && this.tokenIndex == o.tokenIndex && this.gazIndex == o.gazIndex && 
this.loc.equals(o.loc) + } + } + + val S = 41*41 + val C = S*41 + + override def hashCode: Int = { + C * (C + tokenIndex) + S * (S + docId.hashCode) + 41 * (41 * gazIndex) + loc.getId + } + +} diff --git a/src/main/scala/opennlp/fieldspring/tr/tpp/TPPSolver.scala b/src/main/scala/opennlp/fieldspring/tr/tpp/TPPSolver.scala new file mode 100644 index 0000000..683cc0c --- /dev/null +++ b/src/main/scala/opennlp/fieldspring/tr/tpp/TPPSolver.scala @@ -0,0 +1,118 @@ +package opennlp.fieldspring.tr.tpp + +import java.util.ArrayList +import java.io._ +import javax.xml.datatype._ +import javax.xml.stream._ + +import opennlp.fieldspring.tr.topo._ +import opennlp.fieldspring.tr.util.KMLUtil + +import scala.collection.JavaConversions._ + +abstract class TPPSolver { + + def getUnresolvedToponymMentions(tppInstance:TPPInstance): scala.collection.mutable.HashSet[ToponymMention] = { + val utms = new scala.collection.mutable.HashSet[ToponymMention] + + for(market <- tppInstance.markets) { + for(tm <- market.locations.map(_._1)) { + utms.add(tm) + } + } + + utms + } + + def resolveToponymMentions(market:Market, unresolvedToponymMentions:scala.collection.mutable.HashSet[ToponymMention]) { + //print(unresolvedToponymMentions.size) + for(tm <- market.locations.map(_._1)) { + unresolvedToponymMentions.remove(tm) + } + //println(" --> " + unresolvedToponymMentions.size) + } + + def getSolutionMap(tour:List[MarketVisit]): Map[(String, Int), Int] = { + + val s = new scala.collection.mutable.HashSet[(String, Int)] + + (for(marketVisit <- tour) yield { + (for(potLoc <- marketVisit.purchasedLocations.map(_._2)) yield { + //if(s.contains((potLoc.docId, potLoc.tokenIndex))) + //println("Already had "+potLoc.docId+":"+potLoc.tokenIndex) + s.add((potLoc.docId, potLoc.tokenIndex)) + ((potLoc.docId, potLoc.tokenIndex), potLoc.gazIndex) + }) + }).flatten.toMap + } + + def writeKML(tour:List[MarketVisit], filename:String) { + val bw = new BufferedWriter(new FileWriter(filename)) + val factory = XMLOutputFactory.newInstance + val out = factory.createXMLStreamWriter(bw) + + KMLUtil.writeHeader(out, "tour") + + var prevMV:MarketVisit = null + var index = 0 + for(mv <- tour) { + val style = (index % 4) match { + case 0 => "yellow" + case 1 => "green" + case 2 => "blue" + case 3 => "white" + } + + for(purLoc <- mv.purchasedLocations.map(_._2)) { + //for(coord <- purLoc.loc.getRegion.getRepresentatives) { + val coord = purLoc.loc.getRegion.getCenter + KMLUtil.writePinPlacemark(out, purLoc.loc.getName+"("+purLoc+")", coord, style) + //} + } + if(index >= 1) { + KMLUtil.writeArcLinePlacemark(out, prevMV.centroid, mv.centroid) + } + + prevMV = mv + index += 1 + } + + KMLUtil.writeFooter(out) + + out.close + } + + def apply(tppInstance:TPPInstance): List[MarketVisit] +} + +class MarketVisit(val market:Market) { + val purchasedLocations = new scala.collection.mutable.HashMap[ToponymMention, PotentialLocation] + + // This should only be accessed after purchasedLocations has stopped changing: + lazy val centroid: Coordinate = { + val lat:Double = purchasedLocations.map(_._2.loc.getRegion.getCenter.getLat).sum/purchasedLocations.size + val lng:Double = purchasedLocations.map(_._2.loc.getRegion.getCenter.getLng).sum/purchasedLocations.size + Coordinate.fromRadians(lat, lng) + } +} + +class ToponymMention(val docId:String, + val tokenIndex:Int) { + + override def toString: String = { + docId+":"+tokenIndex + } + + override def equals(other:Any):Boolean = { + if(!other.isInstanceOf[ToponymMention]) + false + else { + val o = 
other.asInstanceOf[ToponymMention] + this.docId.equals(o.docId) && this.tokenIndex == o.tokenIndex + } + } + + override def hashCode: Int = { + 41 * (41 + tokenIndex) + docId.hashCode + } +} diff --git a/src/main/scala/opennlp/fieldspring/tr/tpp/TravelCoster.scala b/src/main/scala/opennlp/fieldspring/tr/tpp/TravelCoster.scala new file mode 100644 index 0000000..4a09e89 --- /dev/null +++ b/src/main/scala/opennlp/fieldspring/tr/tpp/TravelCoster.scala @@ -0,0 +1,6 @@ +package opennlp.fieldspring.tr.tpp + +abstract class TravelCoster { + + def apply(m1:Market, m2:Market): Double +} diff --git a/src/main/scala/opennlp/fieldspring/tr/util/DistanceTable.scala b/src/main/scala/opennlp/fieldspring/tr/util/DistanceTable.scala new file mode 100644 index 0000000..8432eb2 --- /dev/null +++ b/src/main/scala/opennlp/fieldspring/tr/util/DistanceTable.scala @@ -0,0 +1,32 @@ +package opennlp.fieldspring.tr.util + +import opennlp.fieldspring.tr.topo._ + +class DistanceTable { + + val storedDistances = new scala.collection.mutable.HashMap[(Int, Int), Double] + + def distance(l1:Location, l2:Location): Double = { + var leftLoc = l1 + var rightLoc = l2 + if(l1.getId > l2.getId) { + leftLoc = l2 + rightLoc = l1 + } + + if(leftLoc.getRegion.getRepresentatives.size == 1 && rightLoc.getRegion.getRepresentatives.size == 1) { + leftLoc.distance(rightLoc) + } + else { + val key = (leftLoc.getId, rightLoc.getId) + if(storedDistances.contains(key)) { + storedDistances(key) + } + else { + val dist = leftLoc.distance(rightLoc) + storedDistances.put(key, dist) + dist + } + } + } +} diff --git a/src/main/scala/opennlp/fieldspring/tr/util/LogUtil.scala b/src/main/scala/opennlp/fieldspring/tr/util/LogUtil.scala new file mode 100644 index 0000000..6bafdb4 --- /dev/null +++ b/src/main/scala/opennlp/fieldspring/tr/util/LogUtil.scala @@ -0,0 +1,143 @@ +package opennlp.fieldspring.tr.util + +import opennlp.fieldspring.tr.topo._ + +object LogUtil { + + val DPC = 1.0 + + val DOC_PREFIX = "Document " + val PRED_CELL_RANK_PREFIX = " Predicted cell (at rank " + val PRED_CELL_KL_PREFIX = ", kl-div " + val PRED_CELL_BOTTOM_LEFT_COORD_PREFIX = "): GeoCell((" + val TRUE_COORD_PREFIX = " at (" + val PRED_COORD_PREFIX = " predicted cell center at (" + val NEIGHBOR_PREFIX = " close neighbor: (" + + val CELL_BOTTOM_LEFT_COORD_PREFIX = "Cell (" + val NGRAM_DIST_PREFIX = "unseen mass, " + val ngramAndCountRE = """^(\S+)\=(\S+)$""".r + + def parseLogFile(filename: String): List[LogFileParseElement]/*List[(String, Coordinate, Coordinate, List[(Coordinate, Int)])]*/ = { + val lines = scala.io.Source.fromFile(filename).getLines + + var docName:String = null + var neighbors:List[(Coordinate, Int)] = null + var predCells:List[(Int, Double, Coordinate)] = null + var trueCoord:Coordinate = null + var predCoord:Coordinate = null + + (for(line <- lines) yield { + if(line.startsWith("#")) { + + if(line.contains(DOC_PREFIX)) { + var startIndex = line.indexOf(DOC_PREFIX) + DOC_PREFIX.length + var endIndex = line.indexOf(TRUE_COORD_PREFIX, startIndex) + docName = line.slice(startIndex, endIndex) + if(docName.contains("/")) docName = docName.drop(docName.indexOf("/")+1) + + startIndex = line.indexOf(TRUE_COORD_PREFIX) + TRUE_COORD_PREFIX.length + endIndex = line.indexOf(")", startIndex) + val rawCoords = line.slice(startIndex, endIndex).split(",") + trueCoord = Coordinate.fromDegrees(rawCoords(0).toDouble, rawCoords(1).toDouble) + + predCells = List() + neighbors = List() + + None + } + + else if(line.contains(PRED_CELL_RANK_PREFIX)) { + val rankStartIndex = 
line.indexOf(PRED_CELL_RANK_PREFIX) + PRED_CELL_RANK_PREFIX.length + val rankEndIndex = line.indexOf(PRED_CELL_KL_PREFIX, rankStartIndex) + val rank = line.slice(rankStartIndex, rankEndIndex).toInt + val klStartIndex = rankEndIndex + PRED_CELL_KL_PREFIX.length + val klEndIndex = line.indexOf(PRED_CELL_BOTTOM_LEFT_COORD_PREFIX, klStartIndex) + val kl = line.slice(klStartIndex, klEndIndex).toDouble + val blCoordStartIndex = klEndIndex + PRED_CELL_BOTTOM_LEFT_COORD_PREFIX.length + val blCoordEndIndex = line.indexOf(")", blCoordStartIndex) + val rawBlCoord = line.slice(blCoordStartIndex, blCoordEndIndex).split(",") + val blCoord = Coordinate.fromDegrees(rawBlCoord(0).toDouble, rawBlCoord(1).toDouble) + + predCells = predCells ::: ((rank, kl, blCoord) :: Nil) + + None + } + + else if(line.contains(NEIGHBOR_PREFIX)) { + val startIndex = line.indexOf(NEIGHBOR_PREFIX) + NEIGHBOR_PREFIX.length + val endIndex = line.indexOf(")", startIndex) + val rawCoords = line.slice(startIndex, endIndex).split(",") + val curNeighbor = Coordinate.fromDegrees(rawCoords(0).toDouble, rawCoords(1).toDouble) + val rankStartIndex = line.indexOf("#", 1)+1 + val rankEndIndex = line.indexOf(" ", rankStartIndex) + val rank = line.slice(rankStartIndex, rankEndIndex).toInt + + neighbors = neighbors ::: ((curNeighbor, rank) :: Nil) + + None + } + + else if(line.contains(PRED_COORD_PREFIX)) { + val startIndex = line.indexOf(PRED_COORD_PREFIX) + PRED_COORD_PREFIX.length + val endIndex = line.indexOf(")", startIndex) + val rawCoords = line.slice(startIndex, endIndex).split(",") + predCoord = Coordinate.fromDegrees(rawCoords(0).toDouble, rawCoords(1).toDouble) + + Some(new LogFileParseElement(docName, trueCoord, predCoord, predCells, neighbors)) + } + + else None + } + else None + }).flatten.toList + } + + def getNgramDists(filename: String): Map[Int, Map[String, Double]] = { + val lines = scala.io.Source.fromFile(filename).getLines + + (for(line <- lines) yield { + if(line.startsWith(CELL_BOTTOM_LEFT_COORD_PREFIX)) { + val blCoordStartIndex = CELL_BOTTOM_LEFT_COORD_PREFIX.length + val blCoordEndIndex = line.indexOf(")", blCoordStartIndex) + val rawBlCoord = line.slice(blCoordStartIndex, blCoordEndIndex).split(",") + val cellNum = TopoUtil.getCellNumber(rawBlCoord(0).toDouble, rawBlCoord(1).toDouble, DPC) + + val ngramDistRawStartIndex = line.indexOf(NGRAM_DIST_PREFIX, blCoordEndIndex) + NGRAM_DIST_PREFIX.length + val ngramDistRawEndIndex = line.indexOf(")", ngramDistRawStartIndex) + val dist = + (for(token <- line.slice(ngramDistRawStartIndex, ngramDistRawEndIndex).split(" ")) yield { + if(ngramAndCountRE.findFirstIn(token) != None) { + val ngramAndCountRE(ngram, count) = token + Some((ngram, count.toDouble)) + } + else + None + }).flatten.toMap + + Some((cellNum, dist)) + } + else + None + }).flatten.toMap + } + +} + +class LogFileParseElement( + val docName: String, + val trueCoord: Coordinate, + val predCoord: Coordinate, + val predCells: List[(Int, Double, Coordinate)], + val neighbors: List[(Coordinate, Int)]) { + + def getProbDistOverPredCells(knn:Int, dpc:Double): List[(Int, Double)] = { + var sum = 0.0 + val myKNN = if(knn < 0) predCells.size else knn + (for((rank, kl, blCoord) <- predCells.take(myKNN)) yield { + val unnormalized = math.exp(-kl) + sum += unnormalized + (TopoUtil.getCellNumber(blCoord, dpc), unnormalized) + }).map(p => (p._1, p._2/sum)).toList + } +} diff --git a/src/main/scala/opennlp/fieldspring/tr/util/StopwordUtil.scala b/src/main/scala/opennlp/fieldspring/tr/util/StopwordUtil.scala new file mode 
100644 index 0000000..c581b17 --- /dev/null +++ b/src/main/scala/opennlp/fieldspring/tr/util/StopwordUtil.scala @@ -0,0 +1,14 @@ +package opennlp.fieldspring.tr.util + +import java.io._ + +object StopwordUtil { + + def populateStoplist(filename: String): Set[String] = { + var stoplist:Set[String] = Set() + io.Source.fromFile(filename).getLines.foreach(line => stoplist += line) + stoplist.toSet() + stoplist + } + +} diff --git a/src/main/scala/opennlp/fieldspring/tr/util/TextUtil.scala b/src/main/scala/opennlp/fieldspring/tr/util/TextUtil.scala new file mode 100644 index 0000000..3cc26c5 --- /dev/null +++ b/src/main/scala/opennlp/fieldspring/tr/util/TextUtil.scala @@ -0,0 +1,80 @@ +package opennlp.fieldspring.tr.util + +import opennlp.fieldspring.tr.text._ +import opennlp.fieldspring.tr.text.prep._ +import opennlp.fieldspring.tr.text.io._ + +import scala.collection.JavaConversions._ + +object TextUtil { + + val alphanumRE = """^[a-zA-Z0-9 ]+$""".r + + def getDocAsArray(doc:Document[Token]): Array[Token] = { + (for(sent <- doc) yield { + (for(token <- sent.filter(t => alphanumRE.findFirstIn(t.getForm) != None)) yield { + token + }) + }).flatten.toArray + } + + def getDocAsArray(doc:Document[StoredToken]): Array[StoredToken] = { + (for(sent <- doc) yield { + (for(token <- sent.filter(t => alphanumRE.findFirstIn(t.getForm) != None)) yield { + token + }) + }).flatten.toArray + } + + def getDocAsArrayNoFilter(doc:Document[StoredToken]): Array[StoredToken] = { + (for(sent <- doc) yield { + (for(token <- sent/*.filter(t => alphanumRE.findFirstIn(t.getForm) != None)*/) yield { + token + }) + }).flatten.toArray + } + + def getContextFeatures(docAsArray:Array[Token], tokIndex:Int, windowSize:Int, stoplist:Set[String]): Array[String] = { + val startIndex = math.max(0, tokIndex - windowSize) + val endIndex = math.min(docAsArray.length, tokIndex + windowSize + 1) + + Array.concat(docAsArray.slice(startIndex, tokIndex).map(_.getForm), + docAsArray.slice(tokIndex + 1, endIndex).map(_.getForm)).filterNot(stoplist(_)) + } + + def getContextFeatures(docAsArray:Array[StoredToken], tokIndex:Int, windowSize:Int, stoplist:Set[String]): Array[String] = { + val startIndex = math.max(0, tokIndex - windowSize) + val endIndex = math.min(docAsArray.length, tokIndex + windowSize + 1) + + Array.concat(docAsArray.slice(startIndex, tokIndex).map(_.getForm), + docAsArray.slice(tokIndex + 1, endIndex).map(_.getForm)).filterNot(stoplist(_)) + } + + def getContext(docAsArray:Array[Token], tokIndex:Int, windowSize:Int): String = { + val startIndex = math.max(0, tokIndex - windowSize) + val endIndex = math.min(docAsArray.length, tokIndex + windowSize + 1) + + (docAsArray.slice(startIndex, tokIndex).map(_.getOrigForm).mkString("", " ", "") + + " [["+docAsArray(tokIndex).getOrigForm+"]] " + + docAsArray.slice(tokIndex + 1, endIndex).map(_.getOrigForm).mkString("", " ", "")).trim + } + + def getContext(docAsArray:Array[StoredToken], tokIndex:Int, windowSize:Int): String = { + val startIndex = math.max(0, tokIndex - windowSize) + val endIndex = math.min(docAsArray.length, tokIndex + windowSize + 1) + + (docAsArray.slice(startIndex, tokIndex).map(_.getOrigForm).mkString("", " ", "") + + " [["+docAsArray(tokIndex).getOrigForm+"]] " + + docAsArray.slice(tokIndex + 1, endIndex).map(_.getOrigForm).mkString("", " ", "")).trim + } + + def stripPunc(s: String): String = { + var toReturn = s.trim + while(toReturn.length > 0 && !Character.isLetterOrDigit(toReturn.charAt(0))) + toReturn = toReturn.substring(1) + while(toReturn.length > 0 && 
!Character.isLetterOrDigit(toReturn.charAt(toReturn.length-1))) + toReturn = toReturn.substring(0,toReturn.length()-1) + toReturn + } + +} diff --git a/src/main/scala/opennlp/fieldspring/tr/util/cluster/KMeans.scala b/src/main/scala/opennlp/fieldspring/tr/util/cluster/KMeans.scala new file mode 100644 index 0000000..194ab15 --- /dev/null +++ b/src/main/scala/opennlp/fieldspring/tr/util/cluster/KMeans.scala @@ -0,0 +1,86 @@ +/** + * Copyright (C) 2010 Travis Brown, The University of Texas at Austin + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. +*/ +package opennlp.fieldspring.tr.util.cluster + + +import java.io._ +import scala.math._ +import scala.collection.immutable.Vector +import scala.collection.mutable.Buffer +import scala.collection.JavaConversions._ +import scala.util.Random + +trait Geometry[A] { + def distance(x: A)(y: A): Double + def centroid(ps: Seq[A]): A + + def nearest(cs: Seq[A], p: A): Int = + cs.map(distance(p)(_)).zipWithIndex.min._2 +} + +trait Clusterer { + def clusterList[A](ps: java.util.List[A], k: Int)(implicit g: Geometry[A]): java.util.List[A] + def cluster[A](ps: Seq[A], k: Int)(implicit g: Geometry[A]): Seq[A] +} + +class KMeans extends Clusterer { + def clusterList[A](ps: java.util.List[A], k: Int)(implicit g: Geometry[A]): java.util.List[A] = { + cluster(ps.toIndexedSeq, k)(g) + } + + def cluster[A](ps: Seq[A], k: Int)(implicit g: Geometry[A]): Seq[A] = { + var ips = ps.toIndexedSeq + var cs = init(ips, k) + var as = ps.map(g.nearest(cs, _)) + var done = false + val clusters = IndexedSeq.fill(k)(Buffer[A]()) + while (!done) { + clusters.foreach(_.clear) + + as.zipWithIndex.foreach { case (i, j) => + clusters(i) += ips(j) + } + + cs = clusters.map(g.centroid(_)) + + val bs = ips.map(g.nearest(cs, _)) + done = as == bs + as = bs + } + cs + } + + def init[A](ps: Seq[A], k: Int): IndexedSeq[A] = { + (1 to k).map(_ => ps(Random.nextInt(ps.size))) + } +} + +object EuclideanGeometry { + type Point = (Double, Double) + + implicit def g = new Geometry[Point] { + def distance(x: Point)(y: Point): Double = + sqrt(pow(x._1 - y._1, 2) + pow(x._2 - y._2, 2)) + + def centroid(ps: Seq[Point]): Point = { + def pointPlus(x: Point, y: Point) = (x._1 + y._1, x._2 + y._2) + ps.reduceLeft(pointPlus) match { + case (a, b) => (a / ps.size, b / ps.size) + } + } + } +} + diff --git a/src/main/scala/opennlp/fieldspring/tr/util/sanity/CandidateCheck.scala b/src/main/scala/opennlp/fieldspring/tr/util/sanity/CandidateCheck.scala new file mode 100644 index 0000000..b2070c1 --- /dev/null +++ b/src/main/scala/opennlp/fieldspring/tr/util/sanity/CandidateCheck.scala @@ -0,0 +1,57 @@ +/////////////////////////////////////////////////////////////////////////////// +// Copyright (C) 2010 Travis Brown, The University of Texas at Austin +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +/////////////////////////////////////////////////////////////////////////////// +package opennlp.fieldspring.tr.util.sanity + +import java.io._ +import scala.collection.JavaConversions._ + +import opennlp.fieldspring.tr.text.Corpus +import opennlp.fieldspring.tr.text.Toponym +import opennlp.fieldspring.tr.text.io.TrXMLDirSource +import opennlp.fieldspring.tr.text.prep.OpenNLPTokenizer +import opennlp.fieldspring.tr.topo.Location + +object CandidateCheck extends App { + override def main(args: Array[String]) { + val tokenizer = new OpenNLPTokenizer + val corpus = Corpus.createStreamCorpus + val cands = scala.collection.mutable.Map[java.lang.String, java.util.List[Location]]() + + corpus.addSource(new TrXMLDirSource(new File(args(0)), tokenizer)) + corpus.foreach { _.foreach { _.getToponyms.foreach { + case toponym: Toponym => { + if (!cands.contains(toponym.getForm)) { + //println("Doesn't contain: " + toponym.getForm) + cands(toponym.getForm) = toponym.getCandidates + } else { + val prev = cands(toponym.getForm) + val here = toponym.getCandidates + //println("Contains: " + toponym.getForm) + if (prev.size != here.size) { + println("=====Size error for " + toponym.getForm + ": " + prev.size + " " + here.size) + } else { + prev.zip(here).foreach { case (p, h) => + println(p.getRegion.getCenter + " " + h.getRegion.getCenter) + //case (p, h) if p != h => println("=====Mismatch for " + toponym.getForm) + //case _ => () + } + } + } + } + }}} + } +} + diff --git a/src/main/scala/opennlp/fieldspring/util/MeteredTask.scala b/src/main/scala/opennlp/fieldspring/util/MeteredTask.scala new file mode 100644 index 0000000..77e4ce5 --- /dev/null +++ b/src/main/scala/opennlp/fieldspring/util/MeteredTask.scala @@ -0,0 +1,130 @@ +/////////////////////////////////////////////////////////////////////////////// +// MeteredTask.scala +// +// Copyright (C) 2011 Ben Wing, The University of Texas at Austin +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +/////////////////////////////////////////////////////////////////////////////// + +package opennlp.fieldspring.util + +import osutil._ +import printutil.errprint +import textutil._ +import timeutil.format_minutes_seconds + +///////////////////////////////////////////////////////////////////////////// +// Metered Tasks // +///////////////////////////////////////////////////////////////////////////// + +/** + * Class for tracking number of items processed in a long task, and + * reporting periodic status messages concerning how many items + * processed and how much real time involved. 
Call `item_processed` + * every time you have processed an item. + * + * @param item_name Generic name of the items being processed, for the + * status messages + * @param verb Transitive verb in its -ing form indicating what is being + * done to the items + * @param secs_between_output Number of elapsed seconds between successive + * periodic status messages + */ +class MeteredTask(item_name: String, verb: String, + secs_between_output: Double = 15, maxtime: Double = 0.0) { + val plural_item_name = pluralize(item_name) + var items_processed = 0 + // Whether we've already printed stats after the most recent item + // processed + var printed_stats = false + errprint("--------------------------------------------------------") + val first_time = curtimesecs() + var last_time = first_time + errprint("Beginning %s %s at %s.", verb, plural_item_name, + humandate_full(first_time)) + errprint("") + + def num_processed() = items_processed + + def elapsed_time() = curtimesecs() - first_time + + def item_unit() = { + if (items_processed == 1) + item_name + else + plural_item_name + } + + def print_elapsed_time_and_rate(curtime: Double = curtimesecs(), + nohuman: Boolean = false) { + /* Don't do anything if already printed for this item. */ + if (printed_stats) + return + printed_stats = true + val total_elapsed_secs = curtime - first_time + val attime = + if (nohuman) "" else "At %s: " format humandate_time(curtime) + errprint("%sElapsed time: %s, %s %s processed", + attime, + format_minutes_seconds(total_elapsed_secs, hours=false), + items_processed, item_unit()) + val items_per_second = items_processed.toDouble / total_elapsed_secs + val seconds_per_item = total_elapsed_secs / items_processed + errprint("Processing rate: %s items per second (%s seconds per item)", + format_float(items_per_second), + format_float(seconds_per_item)) + } + + def item_processed() = { + val curtime = curtimesecs() + items_processed += 1 + val total_elapsed_secs = curtime - first_time + val last_elapsed_secs = curtime - last_time + printed_stats = false + if (last_elapsed_secs >= secs_between_output) { + // Rather than directly recording the time, round it down to the + // nearest multiple of secs_between_output; else we will eventually + // see something like 0, 15, 45, 60, 76, 91, 107, 122, ... + // rather than like 0, 15, 45, 60, 76, 90, 106, 120, ... + val rounded_elapsed = + ((total_elapsed_secs / secs_between_output).toInt * + secs_between_output) + last_time = first_time + rounded_elapsed + print_elapsed_time_and_rate(curtime) + } + if (maxtime > 0 && total_elapsed_secs >= maxtime) { + errprint("Maximum time reached, interrupting processing") + print_elapsed_time_and_rate(curtime) + true + } + else false + } + + /** + * Output a message indicating that processing is finished, along with + * stats given number of items processed, time, and items/time, time/item. + * The total message looks like "Finished _doing_ _items_." where "doing" + * comes from the `doing` parameter to this function and should be a + * lower-case transitive verb in the -ing form. The actual value of + * "items" comes from the `item_name` constructor parameter to this + * class. 
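// Editor's note: an illustrative sketch (not part of the patch) of the intended
// calling pattern for MeteredTask, based on the class above: construct it, call
// item_processed() once per item (it returns true once `maxtime`, if nonzero,
// has been exceeded), and call finish() at the end. The `docs` collection is
// hypothetical; MeteredTask itself is assumed to be in scope (same package).
object MeteredTaskExample {
  import scala.util.control.Breaks._

  def processAll(docs: Seq[String]) {
    val task = new MeteredTask("document", "processing", secs_between_output = 30)
    breakable {
      for (doc <- docs) {
        // ... real per-document work would go here ...
        if (task.item_processed())
          break  // maxtime was reached; stop early
      }
    }
    task.finish()
  }
}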
*/ + def finish() = { + val curtime = curtimesecs() + errprint("") + errprint("Finished %s %s at %s.", verb, plural_item_name, + humandate_full(curtime)) + print_elapsed_time_and_rate(curtime, nohuman = true) + errprint("--------------------------------------------------------") + } +} + diff --git a/src/main/scala/opennlp/fieldspring/util/Serializer.scala b/src/main/scala/opennlp/fieldspring/util/Serializer.scala new file mode 100644 index 0000000..5998ec0 --- /dev/null +++ b/src/main/scala/opennlp/fieldspring/util/Serializer.scala @@ -0,0 +1,39 @@ +/////////////////////////////////////////////////////////////////////////////// +// Serializer.scala +// +// Copyright (C) 2011 Ben Wing, The University of Texas at Austin +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +/////////////////////////////////////////////////////////////////////////////// + +package opennlp.fieldspring.util + +/** A type class for converting to and from values in serialized form. */ +@annotation.implicitNotFound(msg = "No implicit Serializer defined for ${T}.") +trait Serializer[T] { + def deserialize(foo: String): T + def serialize(foo: T): String + /** + * Validate the serialized form of the string. Return true if valid, + * false otherwise. Can be overridden for efficiency. By default, + * simply tries to deserialize, and checks whether an error was thrown. + */ + def validate_serialized_form(foo: String): Boolean = { + try { + deserialize(foo) + } catch { + case _ => return false + } + return true + } +} diff --git a/src/main/scala/opennlp/fieldspring/util/WikiRelFreqs.scala b/src/main/scala/opennlp/fieldspring/util/WikiRelFreqs.scala new file mode 100644 index 0000000..52e20b5 --- /dev/null +++ b/src/main/scala/opennlp/fieldspring/util/WikiRelFreqs.scala @@ -0,0 +1,52 @@ +/////////////////////////////////////////////////////////////////////////////// +// WikiRelFreqs.scala +// +// Copyright (C) 2012 Mike Speriosu, The University of Texas at Austin +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
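// Editor's note: an illustrative Serializer[T] instance (not part of the patch),
// showing how the type class defined in Serializer.scala above is meant to be
// implemented. The object and method names below are hypothetical; only the
// three trait members are taken from the source.
object SerializerExample {
  implicit object IntSerializer extends Serializer[Int] {
    def deserialize(foo: String) = foo.toInt
    def serialize(foo: Int) = foo.toString
    // Cheaper than the default deserialize-and-catch check.
    override def validate_serialized_form(foo: String) = foo.matches("""-?\d+""")
  }

  // Round-trip a value through its serialized form.
  def roundTrip(n: Int): Int = {
    val s = implicitly[Serializer[Int]].serialize(n)   // e.g. "42"
    implicitly[Serializer[Int]].deserialize(s)         // back to the Int 42
  }
}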
+/////////////////////////////////////////////////////////////////////////////// + +package opennlp.fieldspring.util + +object WikiRelFreqs extends App { + + val geoFreqs = getFreqs(args(0)) + val allFreqs = getFreqs(args(1)) + + val relFreqs = allFreqs.map(p => (p._1, geoFreqs.getOrElse(p._1, 0.0) / p._2)).toList.sortWith((x, y) => if(x._2 != y._2) x._2 > y._2 else x._1 < y._1) + + relFreqs.foreach(println) + + def getFreqs(filename:String):Map[String, Double] = { + val wordCountRE = """^(\w+)\s=\s(\d+)$""".r + val lines = scala.io.Source.fromFile(filename).getLines + val freqs = new scala.collection.mutable.HashMap[String, Long] + var total = 0l + var lineCount = 0 + + for(line <- lines) { + if(wordCountRE.findFirstIn(line) != None) { + val wordCountRE(word, count) = line + val lowerWord = word.toLowerCase + val oldCount = freqs.getOrElse(lowerWord, 0l) + freqs.put(lowerWord, oldCount + count.toInt) + total += count.toInt + } + if(lineCount % 10000000 == 0) + println(filename+" "+lineCount) + lineCount += 1 + } + + freqs.map(p => (p._1, p._2.toDouble / total)).toMap + } +} diff --git a/src/main/scala/opennlp/fieldspring/util/argparser.scala b/src/main/scala/opennlp/fieldspring/util/argparser.scala new file mode 100644 index 0000000..98cc0b8 --- /dev/null +++ b/src/main/scala/opennlp/fieldspring/util/argparser.scala @@ -0,0 +1,1421 @@ +/////////////////////////////////////////////////////////////////////////////// +// argparser.scala +// +// Copyright (C) 2011, 2012 Ben Wing, The University of Texas at Austin +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +/////////////////////////////////////////////////////////////////////////////// + +package opennlp.fieldspring.util + +import scala.util.control.Breaks._ +import scala.collection.mutable + +import org.clapper.argot._ + +/** + This module implements an argument parser for Scala, which handles + both options (e.g. --output-file foo.txt) and positional arguments. + It is built on top of Argot and has an interface that is designed to + be quite similar to the argument-parsing mechanisms in Python. + + The parser tries to be easy to use, and in particular to emulate the + field-based method of accessing values used in Python. This leads to + the need to be very slightly tricky in the way that arguments are + declared; see below. + + The basic features here that Argot doesn't have are: + + (1) Argument specification is simplified through the use of optional + parameters to the argument-creation functions. + (2) A simpler and easier-to-use interface is provided for accessing + argument values given on the command line, so that the values can be + accessed as simple field references (re-assignable if needed). + (3) Default values can be given. + (4) "Choice"-type arguments can be specified (only one of a fixed number + of choices allowed). This can include multiple aliases for a given + choice. + (5) The "flag" type is simplified to just handle boolean flags, which is + the most common usage. 
(FIXME: Eventually perhaps we should consider + also allowing more general Argot-type flags.) + (6) The help text can be split across multiple lines using multi-line + strings, and extra whitespace/carriage returns will be absorbed + appropriately. In addition, directives can be given, such as %default + for the default value, %choices for the list of choices, %metavar + for the meta-variable (see below), %prog for the name of the program, + %% for a literal percent sign, etc. + (7) A reasonable default value is provided for the "meta-variable" + parameter for options (which specifies the type of the argument). + (8) Conversion functions are easier to specify, since one function suffices + for all types of arguments. + (9) A conversion function is provided for type Boolean for handling + valueful boolean options, where the value can be any of "yes, no, + y, n, true, false, t, f, on, off". + (10) There is no need to manually catch exceptions (e.g. from usage errors) + if you don't want to. By default, exceptions are caught + automatically and printed out nicely, and then the program exits + with return code 1. You can turn this off if you want to handle + exceptions yourself. + + In general, to parse arguments, you first create an object of type + ArgParser (call it `ap`), and then add options to it by calling + functions, typically: + + -- ap.option[T]() for a single-valued option of type T + -- ap.multiOption[T]() for a multi-valued option of type T (i.e. the option + can be specified multiple times on the command line, and all such + values will be accumulated into a List) + -- ap.flag() for a boolean flag + -- ap.positional[T]() for a positional argument (coming after all options) + -- ap.multiPositional[T]() for a multi-valued positional argument (i.e. + eating up any remaining positional argument given) + + There are two styles for accessing the values of arguments specified on the + command line. One possibility is to simply declare arguments by calling the + above functions, then parse a command line using `ap.parse()`, then retrieve + values using `ap.get[T]()` to get the value of a particular argument. + + However, this style is not as convenient as we'd like, especially since + the type must be specified. In the original Python API, once the + equivalent calls have been made to specify arguments and a command line + parsed, the argument values can be directly retrieved from the ArgParser + object as if they were fields; e.g. if an option `--outfile` were + declared using a call like `ap.option[String]("outfile", ...)`, then + after parsing, the value could simply be fetched using `ap.outfile`, + and assignment to `ap.outfile` would be possible, e.g. if the value + is to be defaulted from another argument. + + This functionality depends on the ability to dynamically intercept + field references and assignments, which doesn't currently exist in + Scala. However, it is possible to achieve a near-equivalent. It works + like this: + + 1) Functions like `ap.option[T]()` are set up so that the first time they + are called for a given ArgParser object and argument, they will note + the argument, and return the default value of this argument. If called + again after parsing, however, they will return the value specified in + the command line (or the default if no value was specified). (If called + again *before* parsing, they simply return the default value, as before.) + + 2) A class, e.g. ProgParams, is created to hold the values returned from + the command line. 
This class typically looks like this: + + class ProgParams(ap: ArgParser) { + var outfile = ap.option[String]("outfile", "o", ...) + var verbose = ap.flag("verbose", "v", ...) + ... + } + + 3) To parse a command line, we proceed as follows: + + a) Create an ArgParser object. + b) Create an instance of ProgParams, passing in the ArgParser object. + c) Call `parse()` on the ArgParser object, to parse a command line. + d) Create *another* instance of ProgParams, passing in the *same* + ArgParser object. + e) Now, the argument values specified on the command line can be + retrieved from this new instance of ProgParams simply using field + accesses, and new values can likewise be set using field accesses. + + Note how this actually works. When the first instance of ProgParams is + created, the initialization of the variables causes the arguments to be + specified on the ArgParser -- and the variables have the default values + of the arguments. When the second instance is created, after parsing, and + given the *same* ArgParser object, the respective calls to "initialize" + the arguments have no effect, but now return the values specified on the + command line. Because these are declared as `var`, they can be freely + re-assigned. Furthermore, because no actual reflection or any such thing + is done, the above scheme will work completely fine if e.g. ProgParams + subclasses another class that also declares some arguments (e.g. to + abstract out common arguments for multiple applications). In addition, + there is no problem mixing and matching the scheme described here with the + conceptually simpler scheme where argument values are retrieved using + `ap.get[T]()`. + + Work being considered: + + (1) Perhaps most important: Allow grouping of options. Groups would be + kept together in the usage message, displayed under the group name. + (2) Add constructor parameters to allow for specification of the other + things allowed in Argot, e.g. pre-usage, post-usage, whether to sort + arguments in the usage message or leave as-is. + (3) Provide an option to control how meta-variable generation works. + Normally, an unspecified meta-variable is derived from the + canonical argument name in all uppercase (e.g. FOO or STRATEGY), + but some people e.g. might instead want it derived from the argument + type (e.g. NUM or STRING). This would be controlled with a + constructor parameter to ArgParser. + (4) Add converters for other basic types, e.g. Float, Char, Byte, Short. + (5) Allow for something similar to Argot's typed flags (and related + generalizations). I'd call it `ap.typedFlag[T]()` or something + similar. But rather than use Argot's interface of "on" and "off" + flags, I'd prefer to follow the lead of Python's argparse, allowing + the "destination" argument name to be specified independently of + the argument name as it appears on the command line, so that + multiple arguments could write to the same place. I'd also add a + "const" parameter that stores an arbitrary constant value if a + flag is tripped, so that you could simulate the equivalent of a + limited-choice option using multiple flag options. In addition, + I'd add an optional "action" parameter that is passed in the old + and new values and returns the actual value to be stored; that + way, incrementing/decrementing or whatever could be implemented. 
+ Note that I believe it's better to separate the conversion and + action routines, unlike what Argot does -- that way, the action + routine can work with properly-typed values and doesn't have to + worry about how to convert them to/from strings. This also makes + it possible to supply action routines for all the various categories + of arguments (e.g. flags, options, multi-options), while keeping + the conversion routines simple -- the action routines necessarily + need to be differently-typed at least for single vs. multi-options, + but the conversion routines shouldn't have to worry about this. + In fact, to truly implement all the generality of Python's 'argparse' + routine, we'd want expanded versions of option[], multiOption[], + etc. that take both a source type (to which the raw values are + initially converted) and a destination type (final type of the + value stored), so that e.g. a multiOption can sum values into a + single 'accumulator' destination argument, or a single String + option can parse a classpath into a List of File objects, or + whatever. (In fact, however, I think it's better to dispense + with such complexity in the ArgParser and require instead that the + calling routine deal with it on its own. E.g. there's absolutely + nothing preventing a calling routine using field-style argument + values from declaring extra vars to hold destination values and + then e.g. simply fetching the classpath value, parsing it and + storing it, or fetching all values of a multiOption and summing + them. The minimal support for the Argot example of increment and + decrement flags would be something like a call `ap.multiFlag` + that accumulates a list of Boolean "true" values, one per + invocation. Then we just count the number of increment flags and + number of decrement flags given. If we cared about the relative + way that these two flags were interleaved, we'd need a bit more + support -- (1) a 'destination' argument to allow two options to + store into the same place; (2) a typed `ap.multiFlag[T]`; (3) + a 'const' argument to specify what value to store. Then our + destination gets a list of exactly which flags were invoked and + in what order. On the other hand, it's easily arguable that no + program should have such intricate option processing that requires + this -- it's unlikely the user will have a good understanding + of what these interleaved flags end up doing. + */ + +package object argparser { + /* + + NOTE: At one point, in place of the second scheme described above, there + was a scheme involving reflection. This didn't work as well, and ran + into various problems. One such problem is described here, because it + shows some potential limitations/bugs in Scala. In particular, in that + scheme, calls to `ap.option[T]()` and similar were declared using `def` + instead of `var`, and the first call to them was made using reflection. + Underlyingly, all defs, vars and vals look like functions, and fields + declared as `def` simply look like no-argument functions. Because the + return type can vary and generally is a simple type like Int or String, + there was no way to reliably recognize defs of this sort from other + variables, functions, etc. in the object. 
To make this recognition + reliable, I tried wrapping the return value in some other object, with + bidirectional implicit conversions to/from the wrapped value, something + like this: + + class ArgWrap[T](vall: T) extends ArgAny[T] { + def value = vall + def specified = true + } + + implicit def extractValue[T](arg: ArgAny[T]): T = arg.value + + implicit def wrapValue[T](vall: T): ArgAny[T] = new ArgWrap(vall) + + Unfortunately, this didn't work, for somewhat non-obvious reasons. + Specifically, the problems were: + + (1) Type unification between wrapped and unwrapped values fails. This is + a problem e.g. if I try to access a field value in an if-then-else + statements like this, I run into problems: + + val files = + if (use_val) + Params.files + else + Seq[String]() + + This unifies to AnyRef, not Seq[String], even if Params.files wraps + a Seq[String]. + + (2) Calls to methods on the wrapped values (e.g. strings) can fail in + weird ways. For example, the following fails: + + def words = ap.option[String]("words") + + ... + + val split_words = ap.words.split(',') + + However, it doesn't fail if the argument passed in is a string rather + than a character. In this case, if I have a string, I *can* call + split with a character as an argument - perhaps this fails in the case + of an implicit conversion because there is a split() implemented on + java.lang.String that takes only strings, whereas split() that takes + a character is stored in StringOps, which itself is handled using an + implicit conversion. + */ + + /** + * Implicit conversion function for Ints. Automatically selected + * for Int-type arguments. + */ + implicit def convertInt(rawval: String, name: String, ap: ArgParser) = { + try { rawval.toInt } + catch { + case e: NumberFormatException => + throw new ArgParserConversionException( + """Cannot convert argument "%s" to an integer.""" format rawval) + } + } + + /** + * Implicit conversion function for Doubles. Automatically selected + * for Double-type arguments. + */ + implicit def convertDouble(rawval: String, name: String, ap: ArgParser) = { + try { rawval.toDouble } + catch { + case e: NumberFormatException => + throw new ArgParserConversionException( + """Cannot convert argument "%s" to a floating-point number.""" + format rawval) + } + } + + /** + * Check restrictions on `value`, the parsed value for option named + * `name`. Restrictions can specify that the number must be greater than or + * equal, or strictly greater than, a given number -- and/or that the + * number must be less than or equal, or strictly less than, a given + * number. Signal an error if restrictions not met. + * + * @tparam T Numeric type of `value` (e.g. Int or Double). + * @param minposs Required: Minimum possible value for this numeric type; + * used as the default value for certain non-specified arguments. + * @param maxposs Required: Maximum possible value for this numeric type; + * used as the default value for certain non-specified arguments. + * @param gt If specified, `value` must be greater than the given number. + * @param ge If specified, `value` must be greater than or equal to the + * given number. + * @param lt If specified, `value` must be less than the given number. + * @param le If specified, `value` must be less than or equal to the + * given number. 
+ */ + def check_restrictions[T <% Ordered[T]](value: T, name: String, ap: ArgParser, + minposs: T, maxposs: T)(gt: T = minposs, ge: T = minposs, + lt: T = maxposs, le: T = maxposs) { + val has_lower_bound = !(gt == minposs && ge == minposs) + val has_upper_bound = !(lt == maxposs && le == maxposs) + val has_open_lower_bound = (gt != minposs) + val has_open_upper_bound = (lt != maxposs) + + def check_lower_bound(): String = { + if (!has_lower_bound) null + else if (has_open_lower_bound) { + if (value > gt) null + else "strictly greater than %s" format gt + } else { + if (value >= ge) null + else "at least %s" format ge + } + } + def check_upper_bound(): String = { + if (!has_upper_bound) null + else if (has_open_upper_bound) { + if (value < lt) null + else "strictly less than %s" format lt + } else { + if (value <= le) null + else "at most %s" format le + } + } + + val lowerstr = check_lower_bound() + val upperstr = check_upper_bound() + + def range_error(restriction: String) { + val msg = """Argument "%s" has value %s, but must be %s.""" format + (name, value, restriction) + throw new ArgParserRangeException(msg) + } + + if (lowerstr != null && upperstr != null) + range_error("%s and %s" format (lowerstr, upperstr)) + else if (lowerstr != null) + range_error("%s" format lowerstr) + else if (upperstr != null) + range_error("%s" format upperstr) + } + + /** + * Conversion function for Ints. Also check that the result meets the + * given restrictions (conditions). + */ + def convertRestrictedInt( + gt: Int = Int.MinValue, ge: Int = Int.MinValue, + lt: Int = Int.MaxValue, le: Int = Int.MaxValue + ) = { + (rawval: String, name: String, ap: ArgParser) => { + val retval = convertInt(rawval, name, ap) + check_restrictions[Int](retval, name, ap, Int.MinValue, Int.MaxValue)( + gt = gt, ge = ge, lt = lt, le = le) + retval + } + } + + /** + * Conversion function for Doubles. Also check that the result meets the + * given restrictions (conditions). + */ + def convertRestrictedDouble( + gt: Double = Double.NegativeInfinity, ge: Double = Double.NegativeInfinity, + lt: Double = Double.PositiveInfinity, le: Double = Double.PositiveInfinity + ) = { + (rawval: String, name: String, ap: ArgParser) => { + val retval = convertDouble(rawval, name, ap) + check_restrictions[Double](retval, name, ap, + Double.NegativeInfinity, Double.PositiveInfinity)( + gt = gt, ge = ge, lt = lt, le = le) + retval + } + } + + /** + * Conversion function for positive Int. Checks that the result is > 0. + */ + def convertPositiveInt = convertRestrictedInt(gt = 0) + /** + * Conversion function for non-negative Int. Check that the result is >= + * 0. + */ + def convertNonNegativeInt = convertRestrictedInt(ge = 0) + + /** + * Conversion function for positive Double. Checks that the result is > 0. + */ + def convertPositiveDouble = convertRestrictedDouble(gt = 0) + /** + * Conversion function for non-negative Double. Check that the result is >= + * 0. + */ + def convertNonNegativeDouble = convertRestrictedDouble(ge = 0) + + /** + * Implicit conversion function for Strings. Automatically selected + * for String-type arguments. + */ + implicit def convertString(rawval: String, name: String, ap: ArgParser) = { + rawval + } + + /** + * Implicit conversion function for Boolean arguments, used for options + * that take a value (rather than flags). 
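// Editor's note: a sketch (not part of the patch) of how the restricted
// converters above can be passed explicitly to option[T], so that out-of-range
// values raise ArgParserRangeException at parse time. The option names are
// hypothetical; assumes `import opennlp.fieldspring.util.argparser._`.
class RangeParams(ap: ArgParser) {
  // Must be strictly positive (convertPositiveInt = convertRestrictedInt(gt = 0)).
  var iterations = ap.option[Int]("iterations", default = 10,
    help = "Number of iterations; must be > 0. Default %default.")(
    convertPositiveInt, manifest[Int])
  // Must lie in the closed interval [0, 1].
  var weight = ap.option[Double]("weight", default = 0.5,
    help = "Interpolation weight in [0, 1]. Default %default.")(
    convertRestrictedDouble(ge = 0.0, le = 1.0), manifest[Double])
}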
+ */ + implicit def convertBoolean(rawval: String, name: String, ap: ArgParser) = { + rawval.toLowerCase match { + case "yes" => true + case "no" => false + case "y" => true + case "n" => false + case "true" => true + case "false" => false + case "t" => true + case "f" => false + case "on" => true + case "off" => false + case _ => throw new ArgParserConversionException( + ("""Cannot convert argument "%s" to a boolean. """ + + """Recognized values (case-insensitive) are """ + + """yes, no, y, n, true, false, t, f, on, off.""") format rawval) + } + } + + /** + * Superclass of all exceptions related to `argparser`. These exceptions + * are generally thrown during argument parsing. Normally, the exceptions + * are automatically caught, their message displayed, and then the + * program exited with code 1, indicating a problem. However, this + * behavior can be suppressed by setting the constructor parameter + * `catchErrors` on `ArgParser` to false. In such a case, the exceptions + * will be propagated to the caller, which should catch them and handle + * appropriately; otherwise, the program will be terminated with a stack + * trace. + * + * @param message Message of the exception + * @param cause If not None, an exception, used for exception chaining + * (when one exception is caught, wrapped in another exception and + * rethrown) + */ + class ArgParserException(val message: String, + val cause: Option[Throwable] = None) extends Exception(message) { + if (cause != None) + initCause(cause.get) + + /** + * Alternate constructor. + * + * @param message exception message + */ + def this(msg: String) = this(msg, None) + + /** + * Alternate constructor. + * + * @param message exception message + * @param cause wrapped, or nested, exception + */ + def this(msg: String, cause: Throwable) = this(msg, Some(cause)) + } + + /** + * Thrown to indicate usage errors. + * + * @param message fully fleshed-out usage string. + * @param cause exception, if propagating an exception + */ + class ArgParserUsageException( + message: String, + cause: Option[Throwable] = None + ) extends ArgParserException(message, cause) + + /** + * Thrown to indicate that ArgParser could not convert a command line + * argument to the desired type. + * + * @param message exception message + * @param cause exception, if propagating an exception + */ + class ArgParserConversionException( + message: String, + cause: Option[Throwable] = None + ) extends ArgParserException(message, cause) + + /** + * Thrown to indicate that a command line argument was outside of a + * required range. + * + * @param message exception message + * @param cause exception, if propagating an exception + */ + class ArgParserRangeException( + message: String, + cause: Option[Throwable] = None + ) extends ArgParserException(message, cause) + + /** + * Thrown to indicate that an invalid choice was given for a limited-choice + * argument. The message indicates both the problem and the list of + * possible choices. + * + * @param message exception message + */ + class ArgParserInvalidChoiceException(message: String, + cause: Option[Throwable] = None + ) extends ArgParserConversionException(message, cause) + + /** + * Thrown to indicate that ArgParser encountered a problem in the caller's + * argument specification, or something else indicating invalid coding. + * This indicates a bug in the caller's code. These exceptions are not + * automatically caught. 
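// Editor's note: an illustrative valueful boolean option (not part of the
// patch). Because the type parameter is Boolean, the implicit convertBoolean
// above is selected automatically, so on the command line the user may write,
// e.g., "--use-stoplist yes" or "--use-stoplist off". The option name is
// hypothetical; assumes `import opennlp.fieldspring.util.argparser._`.
class StoplistParams(ap: ArgParser) {
  var use_stoplist = ap.option[Boolean]("use-stoplist", default = true,
    help = "Whether to filter tokens against a stoplist. Default %default.")
}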
+ * + * @param message exception message + */ + class ArgParserCodingError(message: String, + cause: Option[Throwable] = None + ) extends ArgParserException("(CALLER BUG) " + message, cause) + + /** + * Thrown to indicate that ArgParser encountered a problem that should + * never occur under any circumstances, indicating a bug in the ArgParser + * code itself. These exceptions are not automatically caught. + * + * @param message exception message + */ + class ArgParserInternalError(message: String, + cause: Option[Throwable] = None + ) extends ArgParserException("(INTERNAL BUG) " + message, cause) + + /* Some static functions related to ArgParser; all are for internal use */ + protected object ArgParser { + // Given a list of aliases for an argument, return the canonical one + // (first one that's more than a single letter). + def canonName(name: Seq[String]): String = { + assert(name.length > 0) + for (n <- name) { + if (n.length > 1) return n + } + return name(0) + } + + // Compute the metavar for an argument. If the metavar has already + // been given, use it; else, use the upper case version of the + // canonical name of the argument. + def computeMetavar(metavar: String, name: Seq[String]) = { + if (metavar != null) metavar + else canonName(name).toUpperCase + } + + // Return a sequence of all the given strings that aren't null. + def nonNullVals(val1: String, val2: String, val3: String, val4: String, + val5: String, val6: String, val7: String, val8: String, + val9: String) = { + val retval = + Seq(val1, val2, val3, val4, val5, val6, val7, val8, val9) filter + (_ != null) + if (retval.length == 0) + throw new ArgParserCodingError( + "Need to specify at least one name for each argument") + retval + } + + // Combine `choices` and `aliasedChoices` into a larger list of the + // format of `aliasedChoices`. Note that before calling this, a special + // check must be done for the case where `choices` and `aliasedChoices` + // are both null, which includes that no limited-choice restrictions + // apply at all (and is actually the most common situation). + def canonicalizeChoicesAliases[T](choices: Seq[T], + aliasedChoices: Seq[Seq[T]]) = { + val newchoices = if (choices != null) choices else Seq[T]() + val newaliased = + if (aliasedChoices != null) aliasedChoices else Seq[Seq[T]]() + for (spellings <- newaliased) { + if (spellings.length == 0) + throw new ArgParserCodingError( + "Zero-length list of spellings not allowed in `aliasedChoices`:\n%s" + format newaliased) + } + newchoices.map(x => Seq(x)) ++ newaliased + } + + // Convert a list of choices in the format of `aliasedChoices` + // (a sequence of sequences, first item is the canonical spelling) + // into a mapping that canonicalizes choices. + def getCanonMap[T](aliasedChoices: Seq[Seq[T]]) = { + (for {spellings <- aliasedChoices + canon = spellings.head + spelling <- spellings} + yield (spelling, canon)).toMap + } + + // Return a human-readable list of all choices, based on the specifications + // of `choices` and `aliasedChoices`. If 'includeAliases' is true, include + // the aliases in the list of choices, in parens after the canonical name. 
+ def choicesList[T](choices: Seq[T], aliasedChoices: Seq[Seq[T]], + includeAliases: Boolean) = { + val fullaliased = + canonicalizeChoicesAliases(choices, aliasedChoices) + if (!includeAliases) + fullaliased.map(_.head) mkString ", " + else + ( + for { spellings <- fullaliased + canon = spellings.head + altspellings = spellings.tail + } + yield { + if (altspellings.length > 0) + "%s (%s)" format (canon, altspellings mkString "/") + else canon.toString + } + ) mkString ", " + } + + // Check that the given value passes any restrictions imposed by + // `choices` and/or `aliasedChoices`. If not, throw an exception. + def checkChoices[T](converted: T, + choices: Seq[T], aliasedChoices: Seq[Seq[T]]) = { + if (choices == null && aliasedChoices == null) converted + else { + val fullaliased = + canonicalizeChoicesAliases(choices, aliasedChoices) + val canonmap = getCanonMap(fullaliased) + if (canonmap contains converted) + canonmap(converted) + else + throw new ArgParserInvalidChoiceException( + "Choice '%s' not one of the recognized choices: %s" + format (converted, choicesList(choices, aliasedChoices, true))) + } + } + } + + /** + * Base class of all argument-wrapping classes. These are used to + * wrap the appropriate argument-category class from Argot, and return + * values by querying Argot for the value, returning the default value + * if Argot doesn't have a value recorded. + * + * NOTE that these classes are not meant to leak out to the user. They + * should be considered implementation detail only and subject to change. + * + * @param parser ArgParser for which this argument exists. + * @param name Name of the argument. + * @param default Default value of the argument, used when the argument + * wasn't specified on the command line. + * @tparam T Type of the argument (e.g. Int, Double, String, Boolean). + */ + + abstract protected class ArgAny[T]( + val parser: ArgParser, + val name: String, + val default: T + ) { + /** + * Return the value of the argument, if specified; else, the default + * value. */ + def value = { + if (overridden) + overriddenValue + else if (specified) + wrappedValue + else + default + } + + def setValue(newval: T) { + overriddenValue = newval + overridden = true + } + + /** + * When dereferenced as a function, also return the value. + */ + def apply() = value + + /** + * Whether the argument's value was specified. If not, the default + * value applies. + */ + def specified: Boolean + + /** + * Clear out any stored values so that future queries return the default. + */ + def clear() { + clearWrapped() + overridden = false + } + + /** + * Return the value of the underlying Argot object, assuming it exists + * (possibly error thrown if not). + */ + protected def wrappedValue: T + + /** + * Clear away the wrapped value. + */ + protected def clearWrapped() + + /** + * Value if the user explicitly set a value. + */ + protected var overriddenValue: T = _ + + /** + * Whether the user explicit set a value. + */ + protected var overridden: Boolean = false + + } + + /** + * Class for wrapping simple Boolean flags. + * + * @param parser ArgParser for which this argument exists. + * @param name Name of the argument. 
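// Editor's note (illustration only, not part of the patch): given
//   choices        = Seq("baseline")
//   aliasedChoices = Seq(Seq("development", "dev", "devel"))
// the helpers above behave as follows:
//   choicesList(choices, aliasedChoices, includeAliases = false)
//     -> "baseline, development"
//   choicesList(choices, aliasedChoices, includeAliases = true)
//     -> "baseline, development (dev/devel)"
//   checkChoices("devel", choices, aliasedChoices)
//     -> "development"   (non-canonical spellings are canonicalized)
//   checkChoices("held-out", choices, aliasedChoices)
//     -> throws ArgParserInvalidChoiceException listing the valid choices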
+ */ + + protected class ArgFlag( + parser: ArgParser, + name: String + ) extends ArgAny[Boolean](parser, name, false) { + var wrap: FlagOption[Boolean] = null + def wrappedValue = wrap.value.get + def specified = (wrap != null && wrap.value != None) + def clearWrapped() { if (wrap != null) wrap.reset() } + } + + /** + * Class for wrapping a single (non-multi) argument (either option or + * positional param). + * + * @param parser ArgParser for which this argument exists. + * @param name Name of the argument. + * @param default Default value of the argument, used when the argument + * wasn't specified on the command line. + * @param is_positional Whether this is a positional argument rather than + * option (default false). + * @tparam T Type of the argument (e.g. Int, Double, String, Boolean). + */ + + protected class ArgSingle[T]( + parser: ArgParser, + name: String, + default: T, + val is_positional: Boolean = false + ) extends ArgAny[T](parser, name, default) { + var wrap: SingleValueArg[T] = null + def wrappedValue = wrap.value.get + def specified = (wrap != null && wrap.value != None) + def clearWrapped() { if (wrap != null) wrap.reset() } + } + + /** + * Class for wrapping a multi argument (either option or positional param). + * + * @param parser ArgParser for which this argument exists. + * @param name Name of the argument. + * @param default Default value of the argument, used when the argument + * wasn't specified on the command line even once. + * @param is_positional Whether this is a positional argument rather than + * option (default false). + * @tparam T Type of the argument (e.g. Int, Double, String, Boolean). + */ + + protected class ArgMulti[T]( + parser: ArgParser, + name: String, + default: Seq[T], + val is_positional: Boolean = false + ) extends ArgAny[Seq[T]](parser, name, default) { + var wrap: MultiValueArg[T] = null + val wrapSingle = new ArgSingle[T](parser, name, null.asInstanceOf[T]) + def wrappedValue = wrap.value + def specified = (wrap != null && wrap.value.length > 0) + def clearWrapped() { if (wrap != null) wrap.reset() } + } + + /** + * Main class for parsing arguments from a command line. + * + * @param prog Name of program being run, for the usage mssage. + * @param description Text describing the operation of the program. It is + * placed between the line "Usage: ..." and the text describing the + * options and positional arguments; hence, it should not include either + * of these, just a description. + * @param preUsage Optional text placed before the usage message (e.g. + * a copyright and/or version string). + * @param postUsage Optional text placed after the usage message. + * @param return_defaults If true, field values in field-based value + * access always return the default value, even aft3r parsing. + */ + class ArgParser(prog: String, + description: String = "", + preUsage: String = "", + postUsage: String = "", + return_defaults: Boolean = false) { + import ArgParser._ + import ArgotConverters._ + /* The underlying ArgotParser object. */ + protected val argot = new ArgotParser(prog, + description = if (description.length > 0) Some(description) else None, + preUsage = if (preUsage.length > 0) Some(preUsage) else None, + postUsage = if (postUsage.length > 0) Some(postUsage) else None) + /* A map from the argument's canonical name to the subclass of ArgAny + describing the argument and holding its value. The canonical name + of options comes from the first non-single-letter name. 
The + canonical name of positional arguments is simply the name of the + argument. Iteration over the map yields keys in the order they + were added rather than random. */ + protected val argmap = mutable.LinkedHashMap[String, ArgAny[_]]() + /* The type of each argument. For multi options and multi positional + arguments this will be of type Seq. Because of type erasure, the + type of sequence must be stored separately, using argtype_multi. */ + protected val argtype = mutable.Map[String, Class[_]]() + /* For multi arguments, the type of each individual argument. */ + protected val argtype_multi = mutable.Map[String, Class[_]]() + /* Set specifying arguments that are positional arguments. */ + protected val argpositional = mutable.Set[String]() + /* Set specifying arguments that are flag options. */ + protected val argflag = mutable.Set[String]() + + /* NOTE NOTE NOTE: Currently we don't provide any programmatic way of + accessing the ArgAny-subclass object by name. This is probably + a good thing -- these objects can be viewed as internal + */ + /** + * Return the value of an argument, or the default if not specified. + * + * @param arg The canonical name of the argument, i.e. the first + * non-single-letter alias given. + * @return The value, of type Any. It must be cast to the appropriate + * type. + * @see #get[T] + */ + def apply(arg: String) = argmap(arg).value + + /** + * Return the value of an argument, or the default if not specified. + * + * @param arg The canonical name of the argument, i.e. the first + * non-single-letter alias given. + * @tparam T The type of the argument, which must match the type given + * in its definition + * + * @return The value, of type T. + */ + def get[T](arg: String) = argmap(arg).asInstanceOf[ArgAny[T]].value + + /** + * Explicitly set the value of an argument. + * + * @param arg The canonical name of the argument, i.e. the first + * non-single-letter alias given. + * @param value The new value of the argument. + * @tparam T The type of the argument, which must match the type given + * in its definition + * + * @return The value, of type T. + */ + def set[T](arg: String, value: T) { + argmap(arg).asInstanceOf[ArgAny[T]].setValue(value) + } + + /** + * Return the default value of an argument. + * + * @param arg The canonical name of the argument, i.e. the first + * non-single-letter alias given. + * @tparam T The type of the argument, which must match the type given + * in its definition + * + * @return The value, of type T. + */ + def defaultValue[T](arg: String) = + argmap(arg).asInstanceOf[ArgAny[T]].default + + /** + * Return whether an argument (either option or positional argument) + * exists with the given canonical name. + */ + def exists(arg: String) = argmap contains arg + + /** + * Return whether an argument exists with the given canonical name. + */ + def isOption(arg: String) = exists(arg) && !isPositional(arg) + + /** + * Return whether a positional argument exists with the given name. + */ + def isPositional(arg: String) = argpositional contains arg + + /** + * Return whether a flag option exists with the given canonical name. + */ + def isFlag(arg: String) = argflag contains arg + + /** + * Return whether a multi argument (either option or positional argument) + * exists with the given canonical name. + */ + def isMulti(arg: String) = argtype_multi contains arg + + /** + * Return whether the given argument's value was specified. If not, + * fetching the argument's value returns its default value instead. 
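// Editor's note: a sketch (not part of the patch) of the "programmatic" access
// style (get/specified), as opposed to the field-style access documented at the
// top of this file. Argument names and messages are hypothetical; assumes
// `import opennlp.fieldspring.util.argparser._`.
object GetStyleExample {
  def run(args: Seq[String]) {
    val ap = new ArgParser("demo")
    ap.option[Int]("iterations", "i", default = 10, help = "Number of iterations.")
    ap.flag("verbose", "v", help = "Print progress messages.")
    ap.parse(args)

    val iters = ap.get[Int]("iterations")   // look up by canonical name
    val verbose = ap.get[Boolean]("verbose")
    if (!ap.specified("iterations"))
      println("using the default iteration count, " + iters)
    if (verbose)
      println("will run %d iterations" format iters)
  }
}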
+ */ + def specified(arg: String) = argmap(arg).specified + + /** + * Return the type of the given argument. For multi arguments, the + * type will be Seq, and the type of the individual arguments can only + * be retrieved using `getMultiType`, due to type erasure. + */ + def getType(arg: String) = argtype(arg) + + /** + * Return the type of an individual argument value of a multi argument. + * The actual type of the multi argument is a Seq of the returned type. + */ + def getMultiType(arg: String) = argtype_multi(arg) + + /** + * Iterate over all defined arguments. + * + * @return an Iterable over the names of the arguments. The argument + * categories (e.g. option, multi-option, flag, etc.), argument + * types (e.g. Int, Boolean, Double, String, Seq[String]), default + * values and actual values can be retrieved using other functions. + */ + def argNames: Iterable[String] = { + for ((name, argobj) <- argmap) yield name + } + + protected def handle_argument[T : Manifest, U : Manifest]( + name: Seq[String], + default: U, + metavar: String, + choices: Seq[T], + aliasedChoices: Seq[Seq[T]], + help: String, + create_underlying: (String, String, String) => ArgAny[U], + is_multi: Boolean = false, + is_positional: Boolean = false, + is_flag: Boolean = false + ) = { + val canon = canonName(name) + if (return_defaults) + default + else if (argmap contains canon) + argmap(canon).asInstanceOf[ArgAny[U]].value + else { + val canon_metavar = computeMetavar(metavar, name) + val helpsplit = """(%%|%default|%choices|%allchoices|%metavar|%prog|%|[^%]+)""".r.findAllIn( + help.replaceAll("""\s+""", " ")) + val canon_help = + (for (s <- helpsplit) yield { + s match { + case "%default" => default.toString + case "%choices" => choicesList(choices, aliasedChoices, false) + case "%allchoices" => choicesList(choices, aliasedChoices, true) + case "%metavar" => canon_metavar + case "%%" => "%" + case "%prog" => this.prog + case _ => s + } + }) mkString "" + val underobj = create_underlying(canon, canon_metavar, canon_help) + argmap(canon) = underobj + argtype(canon) = manifest[U].erasure + if (is_multi) + argtype_multi(canon) = manifest[T].erasure + if (is_positional) + argpositional += canon + if (is_flag) + argflag += canon + default + } + } + + protected def argot_converter[T]( + convert: (String, String, ArgParser) => T, canon_name: String, + choices: Seq[T], aliasedChoices: Seq[Seq[T]]) = { + (rawval: String, argop: CommandLineArgument[T]) => { + val converted = convert(rawval, canon_name, this) + checkChoices(converted, choices, aliasedChoices) + } + } + + def optionSeq[T](name: Seq[String], + default: T = null.asInstanceOf[T], + metavar: String = null, + choices: Seq[T] = null, + aliasedChoices: Seq[Seq[T]] = null, + help: String = "") + (implicit convert: (String, String, ArgParser) => T, m: Manifest[T]) = { + def create_underlying(canon_name: String, canon_metavar: String, + canon_help: String) = { + val arg = new ArgSingle(this, canon_name, default) + arg.wrap = + (argot.option[T](name.toList, canon_metavar, canon_help) + (argot_converter(convert, canon_name, choices, aliasedChoices))) + arg + } + handle_argument[T,T](name, default, metavar, choices, aliasedChoices, + help, create_underlying _) + } + + /** + * Define a single-valued option of type T. Various standard types + * are recognized, e.g. String, Int, Double. (This is handled through + * the implicit `convert` argument.) Up to nine aliases for the + * option can be specified. 
Single-letter aliases are specified using + * a single dash, whereas longer aliases generally use two dashes. + * The "canonical" name of the option is the first non-single-letter + * alias given. + * + * @param name1 + * @param name2 + * @param name3 + * @param name4 + * @param name5 + * @param name6 + * @param name7 + * @param name8 + * @param name9 + * Up to nine aliases for the option; see above. + * + * @param default Default value, if option not specified; if not given, + * it will end up as 0, 0.0 or false for value types, null for + * reference types. + * @param metavar "Type" of the option, as listed in the usage string. + * This is so that the relevant portion of the usage string will say + * e.g. "--counts-file FILE File containing word counts." (The + * value of `metavar` would be "FILE".) If not given, automatically + * computed from the canonical option name by capitalizing it. + * @param choices List of possible choices for this option. If specified, + * it should be a sequence of possible choices that will be allowed, + * and only the choices that are either in this list of specified via + * `aliasedChoices` will be allowed. If neither `choices` nor + * `aliasedChoices` is given, all values will be allowed. + * @param aliasedChoices List of possible choices for this option, + * including alternative spellings (aliases). If specified, it should + * be a sequence of sequences, each of which specifies the possible + * alternative spellings for a given choice and where the first listed + * spelling is considered the "canonical" one. All choices that + * consist of any given spelling will be allowed, but any non-canonical + * spellings will be replaced by equivalent canonical spellings. + * For example, the choices of "dev", "devel" and "development" may + * all mean the same thing; regardless of how the user spells this + * choice, the same value will be passed to the program (whichever + * spelling comes first). Note that the value of `choices` lists + * additional choices, which are equivalent to choices listed in + * `aliasedChoices` without any alternative spellings. If both + * `choices` and `aliasedChoices` are omitted, all values will be + * allowed. + * @param help Help string for the option, shown in the usage string. + * @param convert Function to convert the raw option (a string) into + * a value of type `T`. The second and third parameters specify + * the name of the argument whose value is being converted, and the + * ArgParser object that the argument is defined on. Under normal + * circumstances, these parameters should not affect the result of + * the conversion function. For standard types, no conversion + * function needs to be specified, as the correct conversion function + * will be located automatically through Scala's 'implicit' mechanism. + * @tparam T The type of the option. For non-standard types, a + * converter must explicitly be given. (The standard types recognized + * are currently Int, Double, Boolean and String.) + * + * @return If class parameter `return_defaults` is true, the default + * value. Else, if the first time called, exits non-locally; this + * is used internally. Otherwise, the value of the parameter. 
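// Editor's note: an illustrative option declaration (not part of the patch)
// exercising the parameters documented above: a long and a single-letter alias,
// a default, limited choices with alternative spellings, and the %default /
// %choices help directives. The option itself ("--split") is hypothetical;
// assumes `import opennlp.fieldspring.util.argparser._`.
class SplitParams(ap: ArgParser) {
  var split = ap.option[String]("split", "s",
    default = "dev",
    aliasedChoices = Seq(
      Seq("dev", "devel", "development"),
      Seq("test"),
      Seq("train", "training")),
    help = """Which data split to evaluate on. Default %default;
              possible choices are %choices.""")
}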
+ */ + def option[T]( + name1: String, name2: String = null, name3: String = null, + name4: String = null, name5: String = null, name6: String = null, + name7: String = null, name8: String = null, name9: String = null, + default: T = null.asInstanceOf[T], + metavar: String = null, + choices: Seq[T] = null, + aliasedChoices: Seq[Seq[T]] = null, + help: String = "") + (implicit convert: (String, String, ArgParser) => T, m: Manifest[T]) = { + optionSeq[T](nonNullVals(name1, name2, name3, name4, name5, name6, + name7, name8, name9), + metavar = metavar, default = default, choices = choices, + aliasedChoices = aliasedChoices, help = help)(convert, m) + } + + def flagSeq(name: Seq[String], + help: String = "") = { + import ArgotConverters._ + def create_underlying(canon_name: String, canon_metavar: String, + canon_help: String) = { + val arg = new ArgFlag(this, canon_name) + arg.wrap = argot.flag[Boolean](name.toList, canon_help) + arg + } + handle_argument[Boolean,Boolean](name, false, null, Seq(true, false), + null, help, create_underlying _) + } + + /** + * Define a boolean flag option. Unlike other options, flags have no + * associated value. Instead, their type is always Boolean, with the + * value 'true' if the flag is specified, 'false' if not. + * + * @param name1 + * @param name2 + * @param name3 + * @param name4 + * @param name5 + * @param name6 + * @param name7 + * @param name8 + * @param name9 + * Up to nine aliases for the option; same as for `option[T]()`. + * + * @param help Help string for the option, shown in the usage string. + */ + def flag(name1: String, name2: String = null, name3: String = null, + name4: String = null, name5: String = null, name6: String = null, + name7: String = null, name8: String = null, name9: String = null, + help: String = "") = { + flagSeq(nonNullVals(name1, name2, name3, name4, name5, name6, + name7, name8, name9), + help = help) + } + + def multiOptionSeq[T](name: Seq[String], + default: Seq[T] = Seq[T](), + metavar: String = null, + choices: Seq[T] = null, + aliasedChoices: Seq[Seq[T]] = null, + help: String = "") + (implicit convert: (String, String, ArgParser) => T, m: Manifest[T]) = { + def create_underlying(canon_name: String, canon_metavar: String, + canon_help: String) = { + val arg = new ArgMulti[T](this, canon_name, default) + arg.wrap = + (argot.multiOption[T](name.toList, canon_metavar, canon_help) + (argot_converter(convert, canon_name, choices, aliasedChoices))) + arg + } + handle_argument[T,Seq[T]](name, default, metavar, choices, aliasedChoices, + help, create_underlying _, is_multi = true) + } + + /** + * Specify an option that can be repeated multiple times. The resulting + * option value will be a sequence (Seq) of all the values given on the + * command line (one value per occurrence of the option). If there are + * no occurrences of the option, the value will be an empty sequence. + * (NOTE: This is different from single-valued options, where the + * default value can be explicitly specified, and if not given, will be + * `null` for reference types. Here, `null` will never occur.) + * + * FIXME: There should be a way of allowing for specifying multiple values + * in a single argument, separated by spaces, commas, etc. We'd want the + * caller to be able to pass in a function to split the string. Currently + * Argot doesn't seem to have a way of allowing a converter function to + * take a single argument and stuff in multiple values, so we probably + * need to modify Argot. 
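// Editor's note: a caller-side sketch (not part of the patch) of the workaround
// alluded to in the FIXME above: until separator-splitting is supported
// natively, a program can declare an ordinary multiOption and split each
// occurrence itself after parsing. Names are hypothetical; assumes
// `import opennlp.fieldspring.util.argparser._`.
class LangParams(ap: ArgParser) {
  // Accepts e.g. --lang en --lang de,fr ; see the post-processing note below.
  var lang = ap.multiOption[String]("lang",
    help = "Language code(s); may be repeated, commas also accepted.")
}
// After ap.parse(...), the caller can flatten the comma-separated values:
//   val langs = params.lang.flatMap(_.split(",")).map(_.trim)
//   // --lang en --lang de,fr  ==>  Seq("en", "de", "fr")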
(At some point we should just incorporate the + * relevant parts of Argot directly.) + */ + def multiOption[T]( + name1: String, name2: String = null, name3: String = null, + name4: String = null, name5: String = null, name6: String = null, + name7: String = null, name8: String = null, name9: String = null, + default: Seq[T] = Seq[T](), + metavar: String = null, + choices: Seq[T] = null, + aliasedChoices: Seq[Seq[T]] = null, + help: String = "") + (implicit convert: (String, String, ArgParser) => T, m: Manifest[T]) = { + multiOptionSeq[T](nonNullVals(name1, name2, name3, name4, name5, name6, + name7, name8, name9), + default = default, metavar = metavar, choices = choices, + aliasedChoices = aliasedChoices, help = help)(convert, m) + } + + /** + * Specify a positional argument. Positional argument are processed + * in order. Optional argument must occur after all non-optional + * argument. The name of the argument is only used in the usage file + * and as the "name" parameter of the ArgSingle[T] object passed to + * the (implicit) conversion routine. Usually the name should be in + * all caps. + * + * @see #multiPositional[T] + */ + def positional[T](name: String, + default: T = null.asInstanceOf[T], + choices: Seq[T] = null, + aliasedChoices: Seq[Seq[T]] = null, + help: String = "", + optional: Boolean = false) + (implicit convert: (String, String, ArgParser) => T, m: Manifest[T]) = { + def create_underlying(canon_name: String, canon_metavar: String, + canon_help: String) = { + val arg = new ArgSingle(this, canon_name, default, is_positional = true) + arg.wrap = + (argot.parameter[T](canon_name, canon_help, optional) + (argot_converter(convert, canon_name, choices, aliasedChoices))) + arg + } + handle_argument[T,T](Seq(name), default, null, choices, aliasedChoices, + help, create_underlying _, is_positional = true) + } + + /** + * Specify any number of positional arguments. These must come after + * all other arguments. + * + * @see #positional[T]. + */ + def multiPositional[T](name: String, + default: Seq[T] = Seq[T](), + choices: Seq[T] = null, + aliasedChoices: Seq[Seq[T]] = null, + help: String = "", + optional: Boolean = true) + (implicit convert: (String, String, ArgParser) => T, m: Manifest[T]) = { + def create_underlying(canon_name: String, canon_metavar: String, + canon_help: String) = { + val arg = new ArgMulti[T](this, canon_name, default, is_positional = true) + arg.wrap = + (argot.multiParameter[T](canon_name, canon_help, optional) + (argot_converter(convert, canon_name, choices, aliasedChoices))) + arg + } + handle_argument[T,Seq[T]](Seq(name), default, null, choices, + aliasedChoices, help, create_underlying _, + is_multi = true, is_positional = true) + } + + /** + * Parse the given command-line arguments. Extracted values of the + * arguments can subsequently be obtained either using the `#get[T]` + * function, by directly treating the ArgParser object as if it were + * a hash table and casting the result, or by using a separate class + * to hold the extracted values in fields, as described above. The + * last method is the recommended one and generally the easiest-to- + * use for the consumer of the values. + * + * @param args Command-line arguments, from main() or the like + * @param catchErrors If true (the default), usage errors will + * be caught, a message outputted (without a stack trace), and + * the program will exit. Otherwise, the errors will be allowed + * through, and the application should catch them. 
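// Editor's note: a sketch (not part of the patch) of handling usage errors
// yourself by passing catchErrors = false, as described above; with the default
// of true, a bad command line prints the error message and exits with code 1.
// Names are hypothetical; assumes `import opennlp.fieldspring.util.argparser._`.
object ManualErrorHandling {
  def tryParse(ap: ArgParser, args: Array[String]): Option[ArgParser] = {
    try {
      ap.parse(args, catchErrors = false)
      Some(ap)
    } catch {
      case e: ArgParserException =>
        Console.err.println("bad command line: " + e.message)
        None
    }
  }
}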
+ */ + def parse(args: Seq[String], catchErrors: Boolean = true) = { + if (argmap.size == 0) + throw new ArgParserCodingError("No arguments initialized. If you thought you specified arguments, you might have defined the corresponding fields with 'def' instead of 'var' or 'val'.") + + def call_parse() { + // println(argmap) + try { + argot.parse(args.toList) + } catch { + case e: ArgotUsageException => { + throw new ArgParserUsageException(e.message, Some(e)) + } + } + } + + // Reset everything, in case the user explicitly set some values + // (which otherwise override values retrieved from parsing) + clear() + if (catchErrors) { + try { + call_parse() + } catch { + case e: ArgParserException => { + System.out.println(e.message) + System.exit(1) + } + } + } else call_parse() + } + + /** + * Clear all arguments back to their default values. + */ + def clear() { + for (obj <- argmap.values) { + obj.clear() + } + } + + def error(msg: String) = { + throw new ArgParserConversionException(msg) + } + + def usageError(msg: String) = { + throw new ArgParserUsageException(msg) + } + } +} + +object TestArgParser extends App { + import argparser._ + class MyParams(ap: ArgParser) { + /* An integer option named --foo, with a default value of 5. Can also + be specified using --spam or -f. */ + var foo = ap.option[Int]("foo", "spam", "f", default = 5, + help="""An integer-valued option. Default %default.""") + /* A string option named --bar, with a default value of "chinga". Can + also be specified using -b. */ + var bar = ap.option[String]("bar", "b", default = "chinga") + /* A string option named --baz, which can be given multiple times. + Default value is an empty sequence. */ + var baz = ap.multiOption[String]("baz") + /* A floating-point option named --tick, which can be given multiple times. + Default value is the sequence Seq(2.5, 5.0, 9.0), which will obtain + when the option hasn't been given at all. */ + var tick = ap.multiOption[Double]("tick", default = Seq(2.5, 5.0, 9.0), + help = """Option --tick, perhaps for specifying the position of +tick marks along the X axis. Multiple such options can be given. If +no marks are specified, the default is %default. Note that we can + freely insert + spaces and carriage + returns into the help text; whitespace is compressed + to a single space.""") + /* A flag --bezzaaf, alias -q. Value is true if given, false if not. */ + var bezzaaf = ap.flag("bezzaaf", "q") + /* An integer option --blop, with only the values 1, 2, 4 or 7 are + allowed. Default is 1. Note, in this case, if the default is + not given, it will end up as 0, even though this isn't technically + a valid choice. This could be considered a bug -- perhaps instead + we should default to the first choice listed, or throw an error. + (It could also be considered a possibly-useful hack to allow + detection of when no choice is given; but this can be determined + in a more reliable way using `ap.specified("blop")`.) + */ + var blop = ap.option[Int]("blop", default = 1, choices = Seq(1, 2, 4, 7), + help = """An integral argument with limited choices. Default is %default, +possible choices are %choices.""") + /* A string option --daniel, with only the values "mene", "tekel", and + "upharsin" allowed, but where values can be repeated, e.g. + --daniel mene --daniel mene --daniel tekel --daniel upharsin + . 
*/ + var daniel = ap.multiOption[String]("daniel", + choices = Seq("mene", "tekel", "upharsin")) + var strategy = + ap.multiOption[String]("s", "strategy", + aliasedChoices = Seq( + Seq("baseline"), + Seq("none"), + Seq("full-kl-divergence", "full-kldiv", "full-kl"), + Seq("partial-kl-divergence", "partial-kldiv", "partial-kl", "part-kl"), + Seq("symmetric-full-kl-divergence", "symmetric-full-kldiv", + "symmetric-full-kl", "sym-full-kl"), + Seq("symmetric-partial-kl-divergence", + "symmetric-partial-kldiv", "symmetric-partial-kl", "sym-part-kl"), + Seq("cosine-similarity", "cossim"), + Seq("partial-cosine-similarity", "partial-cossim", "part-cossim"), + Seq("smoothed-cosine-similarity", "smoothed-cossim"), + Seq("smoothed-partial-cosine-similarity", "smoothed-partial-cossim", + "smoothed-part-cossim"), + Seq("average-cell-probability", "avg-cell-prob", "acp"), + Seq("naive-bayes-with-baseline", "nb-base"), + Seq("naive-bayes-no-baseline", "nb-nobase")), + help = """A multi-string option. This is an actual option in +one of my research programs. Possible choices are %choices; the full list +of choices, including all aliases, is %allchoices.""") + /* A required positional argument. */ + var destfile = ap.positional[String]("DESTFILE", + help = "Destination file to store output in") + /* A multi-positional argument that sucks up all remaining arguments. */ + var files = ap.multiPositional[String]("FILES", help = "Files to process") + } + val ap = new ArgParser("test") + // This first call is necessary, even though it doesn't appear to do + // anything. In particular, this ensures that all arguments have been + // defined on `ap` prior to parsing. + new MyParams(ap) + // ap.parse(List("--foo", "7")) + ap.parse(args) + val Params = new MyParams(ap) + // Print out values of all arguments, whether options or positional. + // Also print out types and default values. + for (name <- ap.argNames) + println("%30s: %s (%s) (default=%s)" format ( + name, ap(name), ap.getType(name), ap.defaultValue[Any](name))) + // Examples of how to retrieve individual arguments + for (file <- Params.files) + println("Process file: %s" format file) + println("Maximum tick mark seen: %s" format (Params.tick max)) + // We can freely change the value of arguments if we want, since they're + // just vars. + if (Params.daniel contains "upharsin") + Params.bar = "chingamos" +} diff --git a/src/main/scala/opennlp/fieldspring/util/collectionutil.scala b/src/main/scala/opennlp/fieldspring/util/collectionutil.scala new file mode 100644 index 0000000..3b89e8b --- /dev/null +++ b/src/main/scala/opennlp/fieldspring/util/collectionutil.scala @@ -0,0 +1,706 @@ +/////////////////////////////////////////////////////////////////////////////// +// collectiontuil.scala +// +// Copyright (C) 2011 Ben Wing, The University of Texas at Austin +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+/////////////////////////////////////////////////////////////////////////////// + +/////////////////////////////////////////////////////////////////////////////// +// Copyright (C) 2011 Ben Wing, The University of Texas at Austin +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +/////////////////////////////////////////////////////////////////////////////// + +package opennlp.fieldspring.util + +import scala.util.control.Breaks._ +import scala.collection.mutable +import scala.collection.mutable.{Builder, MapBuilder} +import scala.collection.generic.CanBuildFrom + +/** + * A package containing various collection-related classes and functions. + * + * Provided: + * + * -- Default hash tables (which automatically provide a default value if + * no key exists; these already exist in Scala but this package makes + * them easier to use) + * -- Dynamic arrays (similar to ArrayBuilder but specialized to use + * primitive arrays underlyingly, and allow direct access to that array) + * -- LRU (least-recently-used) caches + * -- Hash tables by range (which keep track of a subtable for each range + * of numeric keys) + * -- A couple of miscellaneous functions: + * -- 'fromto', which does a range that is insensitive to order of its + * arguments + * -- 'merge_numbered_sequences_uniquely' + */ + +package object collectionutil { + + //////////////////////////////////////////////////////////////////////////// + // Default maps // + //////////////////////////////////////////////////////////////////////////// + + abstract class DefaultHashMap[T,U] extends mutable.HashMap[T,U] { + def getNoSet(key: T): U + } + + /** + * Create a default hash table, i.e. a hash table where accesses to + * undefined values automatically return 'defaultval'. This class also + * automatically sets the undefined key to 'defaultval' upon first + * access to that key. If you don't want this behavior, call getNoSet() + * or use the non-setting variant below. See the discussion below in + * defaultmap() for a discussion of when setting vs. non-setting is useful + * (in a nutshell, use the setting variant when type T is a mutable + * collection; otherwise, use the nonsetting variant). + */ + class SettingDefaultHashMap[T,U]( + create_default: T => U + ) extends DefaultHashMap[T,U] { + var internal_setkey = true + + override def default(key: T) = { + val buf = create_default(key) + if (internal_setkey) + this(key) = buf + buf + } + + /** + * Retrieve the value of 'key'. If the value isn't found, the + * default value (from 'defaultval') will be returned, but the + * key will *NOT* added to the table with that value. + * + * FIXME: This code should have the equivalent of + * synchronized(internal_setkey) around it so that it will work + * in a multi-threaded environment. + */ + def getNoSet(key: T) = { + val oi_setkey = internal_setkey + try { + internal_setkey = false + this(key) + } finally { internal_setkey = oi_setkey } + } + } + + /** + * Non-setting variant class for creating a default hash table. 
+ * See class SettingDefaultHashMap and function defaultmap(). + */ + class NonSettingDefaultHashMap[T,U]( + create_default: T => U + ) extends DefaultHashMap[T,U] { + override def default(key: T) = { + val buf = create_default(key) + buf + } + + def getNoSet(key: T) = this(key) + } + + /** + * Create a default hash map that maps keys of type T to values of type + * U, automatically returning 'defaultval' rather than throwing an exception + * if the key is undefined. + * + * @param defaultval The default value to return. Note the delayed + * evaluation using =>. This is done on purpose so that, for example, + * if we use mutable Buffers or Sets as the value type, things will + * work: We want a *different* empty vector or set each time we call + * default(), so that different keys get different empty vectors. + * Otherwise, adding an element to the buffer associated with one key + * will also add it to the buffers for other keys, which is not what + * we want. (Note, when mutable objects are used as values, you need to + * set the `setkey` parameter to true.) + * + * @param setkey indicates whether we set the key to the default upon + * access. This is necessary when the value is something mutable, but + * probably a bad idea otherwise, since looking up a nonexistent value + * in the table will cause a later contains() call to return true on the + * value. + * + * For example: + + val foo = defaultmap[String,Int](0, setkey = false) + foo("bar") -> 0 + foo contains "bar" -> false + + val foo = defaultmap[String,Int](0, setkey = true) + foo("bar") -> 0 + foo contains "bar" -> true (Probably not what we want) + + + val foo = defaultmap[String,mutable.Buffer[String]](mutable.Buffer(), setkey = false) + foo("myfoods") += "spam" + foo("myfoods") += "eggs" + foo("myfoods") += "milk" + foo("myfoods") -> ArrayBuffer(milk) (OOOOOPS) + + val foo = defaultmap[String,mutable.Buffer[String]](mutable.Buffer(), setkey = true) + foo("myfoods") += "spam" + foo("myfoods") += "eggs" + foo("myfoods") += "milk" + foo("myfoods") -> ArrayBuffer(spam, eggs, milk) (Good) + */ + def defaultmap[T,U](defaultval: => U, setkey: Boolean = false) = { + def create_default(key: T) = defaultval + if (setkey) new SettingDefaultHashMap[T,U](create_default _) + else new NonSettingDefaultHashMap[T,U](create_default _) + } + /** + * A default map where the values are mutable collections; need to have + * `setkey` true for the underlying call to `defaultmap`. + * + * @see #defaultmap[T,U] + */ + def collection_defaultmap[T,U](defaultval: => U) = + defaultmap[T,U](defaultval, setkey = true) + + /** + * A default map where the values are primitive types, and the default + * value is the "zero" value for the primitive (0, 0L, 0.0 or false). + * Note that this will "work", at least syntactically, for non-primitive + * types, but will set the default value to null, which is probably not + * what is wanted. (FIXME: Sort of. It works but if you try to print + * a result -- or more generally, pass to a function that accepts Any -- + * the boxed version shows up and you get null instead of 0.) + * + * @see #defaultmap[T,U] + */ + def primmap[T,@specialized U](default: U = null.asInstanceOf[U]) = { + defaultmap[T,U](default) + } + + /** + * A default map from type T to an Int, with 0 as the default value. 
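For example (a minimal sketch relying only on the constructors defined in this package), the integer-valued default map described here makes counting occurrences trivial, with no need to pre-initialize keys:

val wordCounts = intmap[String]()
for (word <- Seq("spam", "eggs", "spam"))
  wordCounts(word) += 1
// wordCounts("spam") == 2, and a lookup of an unseen key such as
// wordCounts("ham") returns 0 without adding the key (non-setting variant).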
+ * + * @see #defaultmap[T,U] + */ + def intmap[T](default: Int = 0) = defaultmap[T,Int](default) + //def intmap[T]() = primmap[T,Int]() + def longmap[T](default: Long = 0L) = defaultmap[T,Long](default) + def shortmap[T](default: Short = 0) = defaultmap[T,Short](default) + def bytemap[T](default: Byte = 0) = defaultmap[T,Byte](default) + def doublemap[T](default: Double = 0.0d) = defaultmap[T,Double](default) + def floatmap[T](default: Float = 0.0f) = defaultmap[T,Float](default) + def booleanmap[T](default: Boolean = false) = defaultmap[T,Boolean](default) + + /** + * A default map from type T to a string, with an empty string as the + * default value. + * + * @see #defaultmap[T,U] + */ + def stringmap[T](default: String = "") = defaultmap[T,String](default) + + /** + * A default map which maps from T to an (extendable) array of type U. + * The default value is an empty Buffer of type U. Calls of the sort + * `map(key) += item` will add the item to the Buffer stored as the + * value of the key rather than changing the value itself. (After doing + * this, the result of 'map(key)' will be the same collection, but the + * contents of the collection will be modified. On the other hand, in + * the case of the above maps, the result of 'map(key)' will be + * different.) + * + * @see #setmap[T,U] + * @see #mapmap[T,U,V] + */ + def bufmap[T,U]() = + collection_defaultmap[T,mutable.Buffer[U]](mutable.Buffer[U]()) + /** + * A default map which maps from T to a set of type U. The default + * value is an empty Set of type U. Calls of the sort + * `map(key) += item` will add the item to the Set stored as the + * value of the key rather than changing the value itself, similar + * to how `bufmap` works. + * + * @see #bufmap[T,U] + * @see #mapmap[T,U,V] + */ + def setmap[T,U]() = + collection_defaultmap[T,mutable.Set[U]](mutable.Set[U]()) + /** + * A default map which maps from T to a map from U to V. The default + * value is an empty Map of type U->V. Calls of the sort + * `map(key)(key2) = value2` will add the mapping `key2 -> value2` + * to the Map stored as the value of the key rather than changing + * the value itself, similar to how `bufmap` works. + * + * @see #bufmap[T,U] + * @see #setmap[T,U] + */ + def mapmap[T,U,V]() = + collection_defaultmap[T,mutable.Map[U,V]](mutable.Map[U,V]()) + /** + * A default map which maps from T to a `primmap` from U to V, i.e. + * another default map. The default value is a `primmap` of type + * U->V. Calls of the sort `map(key)(key2) += value2` will add to + * the value of the mapping `key2 -> value2` stored for `key` in the + * main map. If `key` hasn't been seen before, a new `primmap` is + * created, and if `key2` hasn't been seen before in the `primmap` + * associated with `key`, it will be initialized to the zero value + * for type V. + * + * FIXME: Warning, doesn't always do what you want, whereas e.g. + * `intmapmap` always will. 
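A short sketch of the collection-valued maps above (again using only constructors defined in this package): `bufmap` groups values by key, and `setmap` does the same while collapsing duplicates, with no explicit initialization of either.

val byInitial = bufmap[Char, String]()
for (city <- Seq("austin", "boston", "albany"))
  byInitial(city.head) += city
// byInitial('a') is now Buffer("austin", "albany")

val neighbors = setmap[String, String]()
neighbors("austin") += "dallas"
neighbors("austin") += "dallas"  // duplicate is absorbed by the Set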
+ * + * @see #mapmap[T,U,V] + */ + def primmapmap[T,U,V](default: V = null.asInstanceOf[V]) = + collection_defaultmap[T,mutable.Map[U,V]](primmap[U,V](default)) + + def intmapmap[T,U](default: Int = 0) = + collection_defaultmap[T,mutable.Map[U,Int]](intmap[U](default)) + def longmapmap[T,U](default: Long = 0L) = + collection_defaultmap[T,mutable.Map[U,Long]](longmap[U](default)) + def doublemapmap[T,U](default: Double = 0.0d) = + collection_defaultmap[T,mutable.Map[U,Double]](doublemap[U](default)) + def stringmapmap[T,U](default: String = "") = + collection_defaultmap[T,mutable.Map[U,String]](stringmap[U](default)) + + // Another way to do this, using subclassing. + // + // abstract class defaultmap[From,To] extends HashMap[From, To] { + // val defaultval: To + // override def default(key: From) = defaultval + // } + // + // class intmap[T] extends defaultmap[T, Int] { val defaultval = 0 } + // + + // The original way + // + // def booleanmap[String]() = { + // new HashMap[String, Boolean] { + // override def default(key: String) = false + // } + // } + + ///////////////////////////////////////////////////////////////////////////// + // Dynamic arrays // + ///////////////////////////////////////////////////////////////////////////// + + /** + A simple class like ArrayBuilder but which gets you direct access + to the underlying array and lets you easily reset things, so that you + can reuse a Dynamic Array multiple times without constantly creating + new objects. Also has specialization. + */ + class DynamicArray[@specialized T: ClassManifest](initial_alloc: Int = 8) { + protected val multiply_factor = 1.5 + var array = new Array[T](initial_alloc) + var length = 0 + def ensure_at_least(size: Int) { + if (array.length < size) { + var newsize = array.length + while (newsize < size) + newsize = (newsize * multiply_factor).toInt + val newarray = new Array[T](newsize) + System.arraycopy(array, 0, newarray, 0, length) + array = newarray + } + } + + def += (item: T) { + ensure_at_least(length + 1) + array(length) = item + length += 1 + } + + def clear() { + length = 0 + } + } + + ///////////////////////////////////////////////////////////////////////////// + // Sorted lists // + ///////////////////////////////////////////////////////////////////////////// + + // Return a tuple (keys, values) of lists of items corresponding to a hash + // table. Stored in sorted order according to the keys. Use + // lookup_sorted_list(key) to find the corresponding value. The purpose of + // doing this, rather than just directly using a hash table, is to save + // memory. + +// def make_sorted_list(table): +// items = sorted(table.items(), key=lambda x:x[0]) +// keys = [""]*len(items) +// values = [""]*len(items) +// for i in xrange(len(items)): +// item = items[i] +// keys[i] = item[0] +// values[i] = item[1] +// return (keys, values) +// +// // Given a sorted list in the tuple form (KEYS, VALUES), look up the item KEY. +// // If found, return the corresponding value; else return None. 
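Returning to the DynamicArray class defined above, a minimal usage sketch (nothing here beyond the members shown): the point is that the backing array is exposed and reusable, so a hot loop can avoid repeated allocation.

val accum = new DynamicArray[Double](initial_alloc = 4)
for (x <- Seq(1.0, 2.0, 3.0))
  accum += x
var sum = 0.0
// Only the first `length` slots of the underlying array are meaningful.
for (i <- 0 until accum.length)
  sum += accum.array(i)
accum.clear()  // reuse the same backing array for the next batch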
+// +// def lookup_sorted_list(sorted_list, key, default=None): +// (keys, values) = sorted_list +// i = bisect.bisect_left(keys, key) +// if i != len(keys) and keys[i] == key: +// return values[i] +// return default +// +// // A class that provides a dictionary-compatible interface to a sorted list +// +// class SortedList(object, UserDict.DictMixin): +// def __init__(self, table): +// self.sorted_list = make_sorted_list(table) +// +// def __len__(self): +// return len(self.sorted_list[0]) +// +// def __getitem__(self, key): +// retval = lookup_sorted_list(self.sorted_list, key) +// if retval is None: +// raise KeyError(key) +// return retval +// +// def __contains__(self, key): +// return lookup_sorted_list(self.sorted_list, key) is not None +// +// def __iter__(self): +// (keys, values) = self.sorted_list +// for x in keys: +// yield x +// +// def keys(self): +// return self.sorted_list[0] +// +// def itervalues(self): +// (keys, values) = self.sorted_list +// for x in values: +// yield x +// +// def iteritems(self): +// (keys, values) = self.sorted_list +// for (key, value) in izip(keys, values): +// yield (key, value) +// + +// ////////////////////////////////////////////////////////////////////////////// +// // Priority Queues // +// ////////////////////////////////////////////////////////////////////////////// +// +// // Priority queue implementation, based on Python heapq documentation. +// // Note that in Python 2.6 and on, there is a priority queue implementation +// // in the Queue module. +// class PriorityQueue(object): +// INVALID = 0 // mark an entry as deleted +// +// def __init__(self): +// self.pq = [] // the priority queue list +// self.counter = itertools.count(1) // unique sequence count +// self.task_finder = {} // mapping of tasks to entries +// +// def add_task(self, priority, task, count=None): +// if count is None: +// count = self.counter.next() +// entry = [priority, count, task] +// self.task_finder[task] = entry +// heappush(self.pq, entry) +// +// //Return the top-priority task. If 'return_priority' is false, just +// //return the task itself; otherwise, return a tuple (task, priority). 
+// def get_top_priority(self, return_priority=false): +// while true: +// priority, count, task = heappop(self.pq) +// if count is not PriorityQueue.INVALID: +// del self.task_finder[task] +// if return_priority: +// return (task, priority) +// else: +// return task +// +// def delete_task(self, task): +// entry = self.task_finder[task] +// entry[1] = PriorityQueue.INVALID +// +// def reprioritize(self, priority, task): +// entry = self.task_finder[task] +// self.add_task(priority, task, entry[1]) +// entry[1] = PriorityQueue.INVALID +// + /////////////////////////////////////////////////////////////////////////// + // Least-recently-used (LRU) Caches // + /////////////////////////////////////////////////////////////////////////// + + class LRUCache[T,U](maxsize: Int=1000) extends mutable.Map[T,U] + with mutable.MapLike[T,U,LRUCache[T,U]] { + val cache = mutable.LinkedHashMap[T,U]() + + // def length = return cache.length + + private def reprioritize(key: T) { + val value = cache(key) + cache -= key + cache(key) = value + } + + def get(key: T): Option[U] = { + if (cache contains key) { + reprioritize(key) + Some(cache(key)) + } + else None + } + + override def update(key: T, value: U) { + if (cache contains key) + reprioritize(key) + else { + while (cache.size >= maxsize) { + val (key2, value) = cache.head + cache -= key2 + } + cache(key) = value + } + } + + override def remove(key: T): Option[U] = cache.remove(key) + + def iterator: Iterator[(T, U)] = cache.iterator + + // All the rest Looks like pure boilerplate! Why necessary? + def += (kv: (T, U)): this.type = { + update(kv._1, kv._2); this } + def -= (key: T): this.type = { remove(key); this } + + override def empty = new LRUCache[T,U]() + } + + // This whole object looks like boilerplate! Why necessary? + object LRUCache extends { + def empty[T,U] = new LRUCache[T,U]() + + def apply[T,U](kvs: (T,U)*): LRUCache[T,U] = { + val m: LRUCache[T,U] = empty + for (kv <- kvs) m += kv + m + } + + def newBuilder[T,U]: Builder[(T,U), LRUCache[T,U]] = + new MapBuilder[T, U, LRUCache[T,U]](empty) + + implicit def canBuildFrom[T,U] + : CanBuildFrom[LRUCache[T,U], (T,U), LRUCache[T,U]] = + new CanBuildFrom[LRUCache[T,U], (T,U), LRUCache[T,U]] { + def apply(from: LRUCache[T,U]) = newBuilder[T,U] + def apply() = newBuilder[T,U] + } + } + + //////////////////////////////////////////////////////////////////////////// + // Hash tables by range // + //////////////////////////////////////////////////////////////////////////// + + // A table that groups all keys in a specific range together. Instead of + // directly storing the values for a group of keys, we store an object (termed a + // "collector") that the user can use to keep track of the keys and values. + // This way, the user can choose to use a list of values, a set of values, a + // table of keys and values, etc. + + // 'ranges' is a sorted list of numbers, indicating the + // boundaries of the ranges. One range includes all keys that are + // numerically below the first number, one range includes all keys that are + // at or above the last number, and there is a range going from each number + // up to, but not including, the next number. 'collector' is used to create + // the collectors used to keep track of keys and values within each range; + // it is either a type or a no-argument factory function. We only create + // ranges and collectors as needed. 'lowest_bound' is the value of the + // lower bound of the lowest range; default is 0. 
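Before continuing with the range tables described here, a quick sketch of the LRUCache defined above (only the members shown are used):

val cache = new LRUCache[String, Int](maxsize = 2)
cache("a") = 1
cache("b") = 2
cache.get("a")   // touching "a" makes "b" the least recently used entry
cache("c") = 3   // evicts "b"
// cache.get("b") == None, cache.get("a") == Some(1), cache.get("c") == Some(3)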
This is used only + // it iter_ranges() when returning the lower bound of the lowest range, + // and can be an item of any type, e.g. the number 0, the string "-infinity", + // etc. + abstract class TableByRange[Coll,Numtype <% Ordered[Numtype]]( + ranges: Seq[Numtype], + create: (Numtype)=>Coll + ) { + val min_value: Numtype + val max_value: Numtype + val items_by_range = mutable.Map[Numtype,Coll]() + var seen_negative = false + + def get_collector(key: Numtype) = { + // This somewhat scary-looking cast produces 0 for Int and 0.0 for + // Double. If you write it as 0.asInstanceOf[Numtype], you get a + // class-cast error when < is called if Numtype is Double because the + // result of the cast ends up being a java.lang.Integer which can't + // be cast to java.lang.Double. (FMH!!!) + if (key < null.asInstanceOf[Numtype]) + seen_negative = true + var lower_range = min_value + // upper_range = "infinity" + breakable { + for (i <- ranges) { + if (i <= key) + lower_range = i + else { + // upper_range = i + break + } + } + } + if (!(items_by_range contains lower_range)) + items_by_range(lower_range) = create(lower_range) + items_by_range(lower_range) + } + + /** + Return an iterator over ranges in the table. Each returned value is + a tuple (LOWER, UPPER, COLLECTOR), giving the lower and upper bounds + (inclusive and exclusive, respectively), and the collector item for this + range. The lower bound of the lowest range comes from the value of + 'lowest_bound' specified during creation, and the upper bound of the range + that is higher than any numbers specified during creation in the 'ranges' + list will be the string "infinity" if such a range is returned. + + The optional arguments 'unseen_between' and 'unseen_all' control the + behavior of this iterator with respect to ranges that have never been seen + (i.e. no keys in this range have been passed to 'get_collector'). If + 'unseen_all' is true, all such ranges will be returned; else if + 'unseen_between' is true, only ranges between the lowest and highest + actually-seen ranges will be returned. 
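To make this concrete, a sketch that combines the `IntTableByRange` subclass defined a little further below with `intmap` from earlier in this file (keys below the first boundary are avoided here to keep the example simple):

// Count values falling in the ranges [10,20) and [20,infinity).
val byRange = new IntTableByRange[mutable.Map[Int, Int]](Seq(10, 20),
  lower => intmap[Int]())
for (key <- Seq(12, 15, 27))
  byRange.get_collector(key)(key) += 1
for ((lower, upper, counts) <- byRange.iter_ranges())
  println("range [%s, %s): %s" format (lower, upper, counts))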
+ */ + def iter_ranges(unseen_between: Boolean=true, unseen_all: Boolean=false) = { + var highest_seen: Numtype = 0.asInstanceOf[Numtype] + val iteration_range = + (List(if (seen_negative) min_value else 0.asInstanceOf[Numtype]) ++ + ranges) zip + (ranges ++ List(max_value)) + for ((lower, upper) <- iteration_range) { + if (items_by_range contains lower) + highest_seen = upper + } + + var seen_any = false + for {(lower, upper) <- iteration_range + // FIXME SCALABUG: This is a bug in Scala if I have to do this + val collector = items_by_range.getOrElse(lower, null.asInstanceOf[Coll]) + if (collector != null || unseen_all || + (unseen_between && seen_any && + upper != max_value && upper <= highest_seen)) + val col2 = if (collector != null) collector else create(lower) + } + yield { + if (collector != null) seen_any = true + (lower, upper, col2) + } + } + } + + class IntTableByRange[Coll]( + ranges: Seq[Int], + create: (Int)=>Coll + ) extends TableByRange[Coll,Int](ranges, create) { + val min_value = java.lang.Integer.MIN_VALUE + val max_value = java.lang.Integer.MAX_VALUE + } + + class DoubleTableByRange[Coll]( + ranges: Seq[Double], + create: (Double)=>Coll + ) extends TableByRange[Coll,Double](ranges, create) { + val min_value = java.lang.Double.NEGATIVE_INFINITY + val max_value = java.lang.Double.POSITIVE_INFINITY + } + + +// ////////////////////////////////////////////////////////////////////////////// +// // Depth-, breadth-first search // +// ////////////////////////////////////////////////////////////////////////////// +// +// // General depth-first search. 'node' is the node to search, the top of a +// // tree. 'matches' indicates whether a given node matches. 'children' +// // returns a list of child nodes. +// def depth_first_search(node, matches, children): +// nodelist = [node] +// while len(nodelist) > 0: +// node = nodelist.pop() +// if matches(node): +// yield node +// nodelist.extend(reversed(children(node))) +// +// // General breadth-first search. 'node' is the node to search, the top of a +// // tree. 'matches' indicates whether a given node matches. 'children' +// // returns a list of child nodes. +// def breadth_first_search(node, matches, children): +// nodelist = deque([node]) +// while len(nodelist) > 0: +// node = nodelist.popLeft() +// if matches(node): +// yield node +// nodelist.extend(children(node)) +// + + ///////////////////////////////////////////////////////////////////////////// + // Misc. list/iterator functions // + ///////////////////////////////////////////////////////////////////////////// + + def fromto(from: Int, too: Int) = { + if (from <= too) (from to too) + else (too to from) + } + + // Return an iterator over all elements in all the given sequences, omitting + // elements seen more than once and keeping the order. + def merge_numbered_sequences_uniquely[A, B](seqs: Iterable[(A, B)]*) = { + val keys_seen = mutable.Set[A]() + for { + seq <- seqs + (s, vall) <- seq + if (!(keys_seen contains s)) + } yield { + keys_seen += s + (s, vall) + } + } + + /** + * Combine two maps, adding up the numbers where overlap occurs. + */ + def combine_maps[T, U <: Int](map1: Map[T, U], map2: Map[T, U]) = { + /* We need to iterate over one of the maps and add each element to the + other map, checking first to see if it already exists. Make sure + to iterate over the smallest map, so that repeated combination of + maps will have O(n) rather than worst-case O(N^2). 
*/ + if (map1.size > map2.size) + map1 ++ map2.map { case (k,v) => k -> (v + map1.getOrElse(k,0)) } + else + map2 ++ map1.map { case (k,v) => k -> (v + map2.getOrElse(k,0)) } + } + + /** + * Convert a list of items to a map counting how many of each item occurs. + */ + def list_to_item_count_map[T : Ordering](list: Seq[T]) = + list.sorted groupBy identity mapValues (_.size) +} + diff --git a/src/main/scala/opennlp/fieldspring/util/distances.scala b/src/main/scala/opennlp/fieldspring/util/distances.scala new file mode 100644 index 0000000..17e6ab3 --- /dev/null +++ b/src/main/scala/opennlp/fieldspring/util/distances.scala @@ -0,0 +1,322 @@ +/////////////////////////////////////////////////////////////////////////////// +// distances.scala +// +// Copyright (C) 2011, 2012 Ben Wing, The University of Texas at Austin +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +/////////////////////////////////////////////////////////////////////////////// + +package opennlp.fieldspring.util + +import math._ + +import printutil.warning +import mathutil.MeanShift + +/* + The coordinates of a point are spherical coordinates, indicating a + latitude and longitude. Latitude ranges from -90 degrees (south) to + +90 degrees (north), with the Equator at 0 degrees. Longitude ranges + from -180 degrees (west) to +179.9999999.... degrees (east). -180 and +180 + degrees correspond to the same north-south parallel, and we arbitrarily + choose -180 degrees over +180 degrees. 0 degrees longitude has been + arbitrarily chosen as the north-south parallel that passes through + Greenwich, England (near London). Note that longitude wraps around, but + latitude does not. Furthermore, the distance between latitude lines is + always the same (about 111 km per degree, or about 69 miles per degree), + but the distance between longitude lines varies according to the + latitude, ranging from about 111 km per degree at the Equator to 0 km + at the North and South Pole. +*/ + +/** + Singleton object holding information of various sorts related to distances + on the Earth and coordinates, objects for handling coordinates and cell + indices, and miscellaneous functions for computing distance and converting + between different coordinate formats. + + The following is contained: + + 1. Fixed information: e.g. radius of Earth in kilometers (km), number of + km per degree at the Equator, number of km per mile, minimum/maximum + latitude/longitude. + + 2. The SphereCoord class (holding a latitude/longitude pair) + + 3. Function spheredist() to compute spherical (great-circle) distance + between two SphereCoords; likewise degree_dist() to compute degree + distance between two SphereCoords + */ +package object distances { + + /***** Fixed values *****/ + + val minimum_latitude = -90.0 + val maximum_latitude = 90.0 + val minimum_longitude = -180.0 + val maximum_longitude = 180.0 - 1e-10 + + // Radius of the earth in km. Used to compute spherical distance in km, + // and km per degree of latitude/longitude. 
+ // val earth_radius_in_miles = 3963.191 + val earth_radius_in_km = 6376.774 + + // Number of kilometers per mile. + val km_per_mile = 1.609 + + // Number of km per degree, at the equator. For longitude, this is the + // same everywhere, but for latitude it is proportional to the degrees away + // from the equator. + val km_per_degree = Pi * 2 * earth_radius_in_km / 360. + + // Number of miles per degree, at the equator. + val miles_per_degree = km_per_degree / km_per_mile + + def km_and_miles(kmdist: Double) = { + "%.2f km (%.2f miles)" format (kmdist, kmdist / km_per_mile) + } + + // A 2-dimensional coordinate. + // + // The following fields are defined: + // + // lat, long: Latitude and longitude of coordinate. + + case class SphereCoord(lat: Double, long: Double) { + // Not sure why this code was implemented with coerce_within_bounds, + // but either always coerce, or check the bounds ... + require(SphereCoord.valid(lat, long)) + override def toString() = "(%.2f,%.2f)".format(lat, long) + } + + implicit object SphereCoord extends Serializer[SphereCoord] { + // Create a coord, with METHOD defining how to handle coordinates + // out of bounds. If METHOD = "validate", check within bounds, + // and abort if not. If "coerce", coerce within bounds (latitudes + // are cropped, longitudes are taken mod 360). If "coerce-warn", + // same as "coerce" but also issue a warning when coordinates are + // out of bounds. + def apply(lat: Double, long: Double, method: String) = { + val (newlat, newlong) = + method match { + case "coerce-warn" => { + if (!valid(lat, long)) + warning("Coordinates out of bounds: (%.2f,%.2f)", lat, long) + coerce(lat, long) + } + case "coerce" => coerce(lat, long) + case "validate" => (lat, long) + case _ => { + require(false, + "Invalid method to SphereCoord(): %s" format method) + (0.0, 0.0) + } + } + new SphereCoord(newlat, newlong) + } + + def valid(lat: Double, long: Double) = ( + lat >= minimum_latitude && + lat <= maximum_latitude && + long >= minimum_longitude && + long <= maximum_longitude + ) + + def coerce(lat: Double, long: Double) = { + var newlat = lat + var newlong = long + if (newlat > maximum_latitude) newlat = maximum_latitude + while (newlong > maximum_longitude) newlong -= 360. + if (newlat < minimum_latitude) newlat = minimum_latitude + while (newlong < minimum_longitude) newlong += 360. + (newlat, newlong) + } + + def deserialize(foo: String) = { + val Array(lat, long) = foo.split(",", -1) + SphereCoord(lat.toDouble, long.toDouble) + } + + def serialize(foo: SphereCoord) = "%s,%s".format(foo.lat, foo.long) + } + + // Compute spherical distance in km (along a great circle) between two + // coordinates. + + def spheredist(p1: SphereCoord, p2: SphereCoord): Double = { + if (p1 == null || p2 == null) return 1000000. + val thisRadLat = (p1.lat / 180.) * Pi + val thisRadLong = (p1.long / 180.) * Pi + val otherRadLat = (p2.lat / 180.) * Pi + val otherRadLong = (p2.long / 180.) * Pi + + val anglecos = (sin(thisRadLat)*sin(otherRadLat) + + cos(thisRadLat)*cos(otherRadLat)* + cos(otherRadLong-thisRadLong)) + // If the values are extremely close to each other, the resulting cosine + // value will be extremely close to 1. In reality, however, if the values + // are too close (e.g. the same), the computed cosine will be slightly + // above 1, and acos() will complain. So special-case this. 
+ if (abs(anglecos) > 1.0) { + if (abs(anglecos) > 1.000001) { + warning("Something wrong in computation of spherical distance, out-of-range cosine value %f", + anglecos) + return 1000000. + } else + return 0. + } + return earth_radius_in_km * acos(anglecos) + } + + def degree_dist(c1: SphereCoord, c2: SphereCoord) = { + sqrt((c1.lat - c2.lat) * (c1.lat - c2.lat) + + (c1.long - c2.long) * (c1.long - c2.long)) + } + + /** + * Square area in km^2 of a rectangle on the surface of a sphere made up + * of latitude and longitude lines. (Although the parameters below are + * described as bottom-left and top-right, respectively, the function as + * written is in fact insensitive to whether bottom-left/top-right or + * top-left/bottom-right pairs are given, and which order they are + * given. All that matters is that opposite corners are supplied. The + * use of `abs` below takes care of this.) + * + * @param botleft Coordinate of bottom left of rectangle + * @param topright Coordinate of top right of rectangle + */ + def square_area(botleft: SphereCoord, topright: SphereCoord) = { + var (lat1, lon1) = (botleft.lat, botleft.long) + var (lat2, lon2) = (topright.lat, topright.long) + lat1 = (lat1 / 180.) * Pi + lat2 = (lat2 / 180.) * Pi + lon1 = (lon1 / 180.) * Pi + lon2 = (lon2 / 180.) * Pi + + (earth_radius_in_km * earth_radius_in_km) * + abs(sin(lat1) - sin(lat2)) * + abs(lon1 - lon2) + } + + /** + * Average two longitudes. This is a bit tricky because of the way + * they wrap around. + */ + def average_longitudes(long1: Double, long2: Double): Double = { + if (long1 - long2 > 180.) + average_longitudes(long1 - 360., long2) + else if (long2 - long1 > 180.) + average_longitudes(long1, long2 - 360.) + else + (long1 + long2) / 2.0 + } + + class SphereMeanShift( + h: Double = 1.0, + max_stddev: Double = 1e-10, + max_iterations: Int = 100 + ) extends MeanShift[SphereCoord](h, max_stddev, max_iterations) { + def squared_distance(x: SphereCoord, y:SphereCoord) = { + val dist = spheredist(x, y) + dist * dist + } + + def weighted_sum(weights:Array[Double], points:Array[SphereCoord]) = { + val len = weights.length + var lat = 0.0 + var long = 0.0 + for (i <- 0 until len) { + val w = weights(i) + val c = points(i) + lat += c.lat * w + long += c.long * w + } + SphereCoord(lat, long) + } + + def scaled_sum(scalar:Double, points:Array[SphereCoord]) = { + var lat = 0.0 + var long = 0.0 + for (c <- points) { + lat += c.lat * scalar + long += c.long * scalar + } + SphereCoord(lat, long) + } + } + + // A 1-dimensional time coordinate (Epoch time, i.e. elapsed time since the + // Jan 1, 1970 Unix Epoch, in milliseconds). The use of a 64-bit long to + // represent Epoch time in milliseconds is common in Java and also used in + // Twitter (gives you over 300,000 years). An alternative is to use a + // double to represent seconds, which gets you approximately the same + // accuracy -- 52 bits of mantissa to represent any integer <= 2^52 + // exactly, similar to the approximately 53 bits worth of seconds you + // get when using milliseconds. + // + // Note that having this here is an important check on the correctness + // of the code elsewhere -- if by mistake you leave off the type + // parameters when calling a function, and the function asks for a type + // with a serializer and there's only one such type available, Scala + // automatically uses that one type. Hence code may work fine until you + // add a second serializable type, and then lots of compile errors. 
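As a quick aside on the spherical-distance helpers defined above, a usage sketch; the coordinates and the quoted distance are rough values for illustration only.

val austin = SphereCoord(30.27, -97.74)
val boston = SphereCoord(42.36, -71.06)
println(km_and_miles(spheredist(austin, boston)))  // on the order of 2700 km
println("degree distance: %.2f" format degree_dist(austin, boston))
// Longitude averaging is wraparound-aware:
average_longitudes(-179.0, 179.0)  // yields -180.0 rather than 0.0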
+ + case class TimeCoord(millis: Long) { + override def toString() = "%s (%s)" format (millis, format_time(millis)) + } + + implicit object TimeCoord extends Serializer[TimeCoord] { + def deserialize(foo: String) = TimeCoord(foo.toLong) + def serialize(foo: TimeCoord) = "%s".format(foo.millis) + } + + /** + * Convert a time in milliseconds since the Epoch into a more familiar + * format, e.g. Wed Jun 27 03:49:08 EDT 2012 in place of 1340783348365. + */ + def format_time(millis: Long) = new java.util.Date(millis).toString + + /** + * Convert an interval in milliseconds into a more familiar format, e.g. + * 3m5s in place of 185000. + */ + def format_interval(millis: Long) = { + val sec_part_as_milli = millis % 60000 + val sec_part = sec_part_as_milli / 1000.0 + val truncated_mins = millis / 60000 + val min_part = truncated_mins % 60 + val truncated_hours = truncated_mins / 60 + val hour_part = truncated_hours % 24 + val truncated_days = truncated_hours / 24 + var res = "" + if (res.length > 0 || truncated_days > 0) + res += " " + truncated_days + "days" + if (res.length > 0 || hour_part > 0) + res += " " + hour_part + "hr" + if (res.length > 0 || min_part > 0) + res += " " + min_part + "min" + if (res.length > 0 || sec_part > 0) { + val int_sec_part = sec_part.toInt + val secstr = + if (int_sec_part == sec_part) + int_sec_part + else + "%.2f" format sec_part + res += " " + secstr + "sec" + } + if (res.length > 0 && res(0) == ' ') + res.tail + else + res + } +} diff --git a/src/main/scala/opennlp/fieldspring/util/experiment.scala b/src/main/scala/opennlp/fieldspring/util/experiment.scala new file mode 100644 index 0000000..3c8a3fe --- /dev/null +++ b/src/main/scala/opennlp/fieldspring/util/experiment.scala @@ -0,0 +1,720 @@ +/////////////////////////////////////////////////////////////////////////////// +// experiment.scala +// +// Copyright (C) 2011 Ben Wing, The University of Texas at Austin +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +/////////////////////////////////////////////////////////////////////////////// + +package opennlp.fieldspring.util + +import scala.collection.mutable + +import argparser._ +import collectionutil._ +import ioutil.{FileHandler, LocalFileHandler} +import osutil._ +import printutil.{errprint, set_stdout_stderr_utf_8} +import textutil._ +import timeutil.format_minutes_seconds + +package object experiment { + /** + * A general experiment driver class for programmatic access to a program + * that runs experiments. + * + * Basic operation: + * + * 1. Create an instance of a class of type TParam (which is determined + * by the particular driver implementation) and populate it with the + * appropriate parameters. + * 2. Call `run()`, passing in the parameter object created in the previous + * step. The return value (of type TRunRes, again determined by + * the particular implementation) contains the results. 
+ * + * NOTE: Some driver implementations may change the values of some of the + * parameters recorded in the parameter object (particularly to + * canonicalize them). + * + * Note that `run()` is actually a convenience method that does three steps: + * + * 1. `set_parameters`, which notes the parameters passed in and verifies + * that their values are good. + * 2. `setup_for_run`, which does any internal setup necessary for + * running the experiment (e.g. reading files, creating internal + * structures). + * 3. `run_after_setup`, which executes the experiment. + * + * These three steps have been split out because some applications may need + * to access each step separately, or override some but not others. For + * example, to add Hadoop support to an application, `run_after_setup` + * might need to be replaced with a different implementation based on the + * MapReduce framework, while the other steps might stay more or less the + * same. + */ + + trait ExperimentDriver { + type TParam + type TRunRes + var params: TParam = _ + + /** + * Signal a parameter error. + */ + def param_error(string: String) { + throw new IllegalArgumentException(string) + } + + protected def param_needed(param: String, param_english: String = null) { + val mparam_english = + if (param_english == null) + param.replace("-", " ") + else + param_english + param_error("Must specify %s using --%s" format + (mparam_english, param.replace("_", "-"))) + } + + protected def need_seq(value: Seq[String], param: String, + param_english: String = null) { + if (value.length == 0) + param_needed(param, param_english) + } + + protected def need(value: String, param: String, + param_english: String = null) { + if (value == null || value.length == 0) + param_needed(param, param_english) + } + + def set_parameters(params: TParam) { + this.params = params + handle_parameters() + } + + def run(params: TParam) = { + set_parameters(params) + setup_for_run() + run_after_setup() + } + + def heartbeat() { + } + + /********************************************************************/ + /* Function to override below this line */ + /********************************************************************/ + + /** + * Verify and canonicalize the parameters passed in. Retrieve any other + * parameters from the environment. NOTE: Currently, some of the + * fields in this structure will be changed (canonicalized). See above. + * If parameter values are illegal, an error will be signaled. + */ + + protected def handle_parameters() + + /** + * Do any setup before actually implementing the experiment. This + * may mean, for example, loading files and creating any needed + * structures. + */ + + def setup_for_run() + + /** + * Actually run the experiment. We have separated out the run process + * into three steps because we might want to replace one of the + * components in a sub-implementation of an experiment. (For example, + * if we implement a Hadoop version of an experiment, typically we need + * to replace the run_after_setup component but leave the others.) + */ + + def run_after_setup(): TRunRes + } + + /** + * A mix-in that adds to a driver the ability to record statistics + * (specifically, counters that can be incremented and queried) about + * the experiment run. These counters may be tracked globally across + * a set of separate tasks, e.g. in Hadoop where multiple separate tasks + * may be run in parallel of different machines to completely a global + * job. Counters can be tracked both globally and per-task. 
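To ground the ExperimentDriver protocol defined above (set_parameters / setup_for_run / run_after_setup), here is a minimal, hypothetical driver; the class names are invented for illustration and only the trait members shown above are used.

class WordCountParams(var files: Seq[String] = Seq())

class WordCountDriver extends ExperimentDriver {
  type TParam = WordCountParams
  type TRunRes = Int

  protected def handle_parameters() {
    need_seq(params.files, "files")
  }
  def setup_for_run() { /* e.g. open files, build internal tables */ }
  def run_after_setup() = params.files.size  // stand-in for a real result
}

// new WordCountDriver().run(new WordCountParams(Seq("doc1.txt", "doc2.txt")))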
+ */ + abstract trait ExperimentDriverStats { + /** Set of counters under the given group, using fully-qualified names. */ + protected val counters_by_group = setmap[String,String]() + /** Set of counter groups under the given group, using fully-qualified + * names. */ + protected val counter_groups_by_group = setmap[String,String]() + /** Set of all counters seen. */ + protected val counters_by_name = mutable.Set[String]() + + protected def local_to_full_name(name: String) = + if (name == "") + local_counter_group + else + local_counter_group + "." + name + + /** + * Note that a counter of the given name exists. It's not necessary to + * do this before calling `increment_counter` or `get_counter` because + * unseen counters default to 0, but doing so ensures that the counter + * is seen when we enumerate counters even if it has never been incremented. + */ + def note_counter(name: String) { + if (!(counters_by_name contains name)) { + counters_by_name += name + note_counter_by_group(name, true) + } + } + + protected def split_counter(name: String) = { + val lastsep = name.lastIndexOf('.') + if (lastsep < 0) + ("", name) + else + (name.slice(0, lastsep), name.slice(lastsep + 1, name.length)) + } + + /** + * Note each part of the name in the next-higher group. For example, + * for a name "io.sort.spill.percent", we note the counter "percent" in + * group "io.sort.spill", but also the group "spill" in the group "io.sort", + * the group "sort" in the group "io", and the group "io" in the group "". + * + * @param name Name of the counter + * @param is_counter Whether this is an actual counter, or a group. + */ + private def note_counter_by_group(name: String, is_counter: Boolean) { + val (group, _) = split_counter(name) + if (is_counter) + counters_by_group(group) += name + else + counter_groups_by_group(group) += name + if (group != "") + note_counter_by_group(group, false) + } + + /** + * Enumerate all the counters in the given group, for counters + * which have been either incremented or noted (using `note_counter`). + * + * @param group Group to list counters under. Can be set to the + * empty string ("") to list from the top level. + * @param recursive If true, search recursively under the group; + * otherwise, only list those at the level immediately under the group. + * @param fully_qualified If true (the default), all counters returned + * are fully-qualified, i.e. including the entire counter name. + * Otherwise, only the portion after `group` is returned. + * @return An iterator over counters + */ + def list_counters(group: String, recursive: Boolean, + fully_qualified: Boolean = true) = { + val groups = Iterable(group) + val subgroups = + if (!recursive) + Iterable[String]() + else + list_counter_groups(group, true) + val fq = (groups.view ++ subgroups) flatMap (counters_by_group(_)) + if (fully_qualified) fq + else fq map (_.stripPrefix(group + ".")) + } + + def list_local_counters(group: String, recursive: Boolean, + fully_qualified: Boolean = true) = + (list_counters(local_to_full_name(group), recursive, fully_qualified) map + (_.stripPrefix(local_counter_group + "."))) + + /** + * Enumerate all the counter groups in the given group, for counters + * which have been either incremented or noted (using `note_counter`). + * + * @param group Group to list counter groups under. Can be set to the + * empty string ("") to list from the top level. + * @param recursive If true, search recursively under the group; + * otherwise, only list those at the level immediately under the group. 
+ * @param fully_qualified If true (the default), all counter groups returned + * are fully-qualified, i.e. including the entire counter name. + * Otherwise, only the portion after `group` is returned. + * @return An iterator over counter groups + */ + def list_counter_groups(group: String, recursive: Boolean, + fully_qualified: Boolean = true): Iterable[String] = { + val groups = counter_groups_by_group(group).view + val fq = + if (!recursive) + groups + else + groups ++ (groups flatMap (list_counter_groups(_, true))) + if (fully_qualified) fq + else fq map (_.stripPrefix(group + ".")) + } + + def list_local_counter_groups(group: String, recursive: Boolean, + fully_qualified: Boolean = true) = + (list_counter_groups(local_to_full_name(group), recursive, + fully_qualified) map (_.stripPrefix(local_counter_group + "."))) + + /** + * Increment the given local counter by 1. Local counters are those + * specific to this application rather than counters set by the overall + * framework. (The difference is that local counters are placed in their + * own group, as specified by `local_counter_group`.) + */ + def increment_local_counter(name: String) { + increment_local_counter(name, 1) + } + + /** + * Increment the given local counter by the given value. Local counters + * are those specific to this application rather than counters set by the + * overall framework. (The difference is that local counters are placed in + * their own group, as specified by `local_counter_group`.) + */ + def increment_local_counter(name: String, byvalue: Long) { + increment_counter(local_to_full_name(name), byvalue) + } + + /** + * Increment the given fully-named counter by 1. Global counters are those + * provided by the overall framework rather than specific to this + * application; hence there is rarely cause for changing them. + */ + def increment_counter(name: String) { + increment_counter(name, 1) + } + + /** + * Increment the given fully-named counter by the given value. + * + * @see increment_local_counter + */ + def increment_counter(name: String, byvalue: Long) { + note_counter(name) + imp_increment_counter(name, byvalue) + } + + /** + * Return the value of the given local counter. + * + * See `increment_local_counter` for a discussion of local counters, + * and `get_counter` for a discussion of caveats in a multi-process + * environment. + */ + def get_local_counter(name: String) = + get_counter(local_to_full_name(name)) + + /** + * Return the value of the given fully-named counter. Note: When operating + * in a multi-threaded or multi-process environment, it cannot be guaranteed + * that this value is up-to-date. This is especially the case e.g. when + * operating in Hadoop, where counters are maintained globally across + * tasks which are running on different machines on the network. + * + * @see get_local_counter + */ + def get_counter(name: String) = imp_get_counter(name) + + def construct_task_counter_name(name: String) = + "bytask." + get_task_id + "." + name + + def increment_task_counter(name: String, byvalue: Long = 1) { + increment_local_counter(construct_task_counter_name(name), byvalue) + } + + def get_task_counter(name: String) = { + get_local_counter(construct_task_counter_name(name)) + } + + /** + * A mechanism for wrapping task counters so that they can be stored + * in variables and incremented simply using +=. Note that access to + * them still needs to go through `value`, unfortunately. 
(Even marking + * `value` as implicit isn't enough as the function won't get invoked + * unless we're in an environment requiring an integral value. This means + * it won't get invoked in print statements, variable assignments, etc.) + * + * It would be nice to move this elsewhere; we'd have to pass in + * `driver_stats`, though. + * + * NOTE: The counters are task-specific because currently each task + * reads the entire set of training documents into memory. We could avoid + * this by splitting the tasks so that each task is commissioned to + * run over a specific portion of the Earth rather than a specific + * set of test documents. Note that if we further split things so that + * each task handled both a portion of test documents and a portion of + * the Earth, it would be somewhat trickier, depending on exactly how + * we write the code -- for a given set of test documents, different + * portions of the Earth would be reading in different training documents, + * so we'd presumably want their counts to add; but we might not want + * all counts to add. + */ + + class TaskCounterWrapper(name: String) { + def value = get_task_counter(name) + + def +=(incr: Long) { + increment_task_counter(name, incr) + } + } + + def create_counter_wrapper(prefix: String, split: String) = + new TaskCounterWrapper(prefix + "." + split) + + def countermap(prefix: String) = + new SettingDefaultHashMap[String, TaskCounterWrapper]( + create_counter_wrapper(prefix, _)) + + /******************* Override/implement below this line **************/ + + /** + * Group that local counters are placed in. + */ + val local_counter_group = "fieldspring" + + /** + * Return ID of current task. This is used for Hadoop or similar, to handle + * operations that may be run multiple times per task in an overall job. + * In a standalone environment, this should always return the same value. + */ + def get_task_id: Int + + /** + * Underlying implementation to increment the given counter by the + * given value. + */ + protected def imp_increment_counter(name: String, byvalue: Long) + + /** + * Underlying implementation to return the value of the given counter. + */ + protected def imp_get_counter(name: String): Long + } + + /** + * Implementation of driver-statistics mix-in that simply stores the + * counters locally. + */ + trait StandaloneExperimentDriverStats extends ExperimentDriverStats { + val counter_values = longmap[String]() + + def get_task_id = 0 + + protected def imp_increment_counter(name: String, incr: Long) { + counter_values(name) += incr + } + + protected def imp_get_counter(name: String) = counter_values(name) + } + + /** + * A general main application class for use in an application that performs + * experiments (currently used mostly for experiments in NLP -- i.e. natural + * language processing -- but not limited to this). The program is assumed + * to have various parameters controlling its operation, some of which come + * from command-line arguments, some from constants hard-coded into the source + * code, some from environment variables or configuration files, some computed + * from other parameters, etc. We divide them into two types: command-line + * (those coming from the command line) and ancillary (from any other source). + * Both types of parameters are output at the beginning of execution so that + * the researcher can see exactly which parameters this particular experiment + * was run with. 
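Before continuing with the ExperimentApp wrapper being described here, a short sketch of the counter mix-in defined above in standalone mode (the object name is invented; only members defined above are used):

object StatsDemo extends StandaloneExperimentDriverStats {
  def demo() {
    increment_local_counter("documents.processed")
    increment_local_counter("documents.skipped", 3)
    // Local counters live under the "fieldspring" group, so the full names
    // are fieldspring.documents.processed and fieldspring.documents.skipped.
    for (c <- list_local_counters("documents", recursive = false))
      println("%s = %s" format (c, get_local_counter(c)))
  }
}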
+ * + * NOTE: Although in common parlance the terms "argument" and "parameter" + * are often synonymous, we make a clear distinction between the two. + * We speak of "parameters" in general when referring to general settings + * that control the operation of a program, and "command-line arguments" + * when referring specifically to parameter setting controlled by arguments + * specified in the command-line invocation of a program. The latter can + * be either "options" (specified as e.g. '--outfile myfile.txt') or + * "positional arguments" (e.g. the arguments 'somefile.txt' and 'myfile.txt' + * in the command 'cp somefile.txt myfile.txt'). Although command-line + * arguments are one way of specifying parameters, parameters could also + * come from environment variables, from a file containing program settings, + * from the arguments to a function if the program is invoked through a + * function call, etc. + * + * The general operation of this class is as follows: + * + * (1) The command line is passed in, and we parse it. Command-line parsing + * uses the ArgParser class. We use "field-style" access to the + * arguments retrieved from the command line; this means that there is + * a separate class taking the ArgParser as a construction parameter, + * and containing vars, one per argument, each initialized using a + * method call on the ArgParser that sets up all the features of that + * particular argument. + * (2) Application verifies the parsed arguments passed in and initializes + * its parameters based on the parsed arguments and possibly other + * sources. + * (3) Application is run. + * + * A particular application customizes things as follows: + * + * (1) Consistent with "field-style" access, it creates a class that will + * hold the user-specified values of command-line arguments, and also + * initializes the ArgParser with the list of allowable arguments, + * types, default values, etc. `TParam` is the type of this class, + * and `create_param_object` must be implemented to create an instance + * of this class. + * (2) `initialize_parameters` must be implemented to handle validation + * of the command-line arguments and retrieval of any other parameters. + * (3) `run_program` must, of course, be implemented, to provide the + * actual behavior of the application. + */ + + /* SCALABUG: If this param isn't declared with a 'val', we get an error + below on the line creating ArgParser when trying to access progname, + saying "no such field". */ + abstract class ExperimentApp(val progname: String) { + val beginning_time = curtimesecs() + + // Things that must be implemented + + /** + * Class holding the declarations and received values of the command-line + * arguments. Needs to have an ArgParser object passed in to it, typically + * as a constructor parameter. + */ + type TParam + + /** + * Function to create an TParam, passing in the value of `arg_parser`. + */ + def create_param_object(ap: ArgParser): TParam + + /** + * Function to initialize and verify internal parameters from command-line + * arguments and other sources. + */ + def initialize_parameters() + + /** + * Function to run the actual app, after parameters have been set. + * @return Exit code of program (0 for successful completion, > 0 for + * an error + */ + def run_program(): Int + + // Things that may be overridden + + + /** + * Text describing the program, placed between the line beginning + * "Usage: ..." and the text describing the options and positional + * arguments. 
+ */ + def description = "" + + /** + * Output the values of "ancillary" parameters (see above) + */ + def output_ancillary_parameters() {} + + /** + * Output the values of "command-line" parameters (see above) + */ + def output_command_line_parameters() { + errprint("") + errprint("Non-default parameter values:") + for (name <- arg_parser.argNames) { + if (arg_parser.specified(name)) + errprint("%30s: %s", name, arg_parser(name)) + } + errprint("") + errprint("Parameter values:") + for (name <- arg_parser.argNames) { + errprint("%30s: %s", name, arg_parser(name)) + //errprint("%30s: %s", name, arg_parser.getType(name)) + } + errprint("") + } + + /** + * An instance of ArgParser, for parsing options + */ + val arg_parser = + new ArgParser(progname, description = description) + + /** + * A class for holding the parameters retrieved from the command-line + * arguments and elsewhere ("ancillary parameters"; see above). Note + * that the parameters that originate in command-line arguments are + * also stored in the ArgParser object; this class provides easy access + * to those parameters through the preferred "field-style" paradigm + * of the ArgParser, and also holds the ancillary parameters (if any). + */ + var params: TParam = _ + + /** + * Code to implement main entrance point. We move this to a separate + * function because a subclass might want to wrap the program entrance/ + * exit (e.g. a Hadoop driver). + * + * @param args Command-line arguments, as specified by user + * @return Exit code, typically 0 for successful completion, + * positive for errorThis needs to be called explicitly from the + */ + def implement_main(args: Array[String]) = { + initialize_osutil() + set_stdout_stderr_utf_8() + errprint("Beginning operation at %s" format humandate_full(beginning_time)) + errprint("Arguments: %s" format (args mkString " ")) + val shadow_fields = create_param_object(arg_parser) + arg_parser.parse(args) + params = create_param_object(arg_parser) + initialize_parameters() + output_command_line_parameters() + output_ancillary_parameters() + val retval = run_program() + val ending_time = curtimesecs() + errprint("Ending operation at %s" format humandate_full(ending_time)) + errprint("Program running time: %s", + format_minutes_seconds(ending_time - beginning_time)) + retval + } + + /** + * Actual entrance point from the JVM. Does nothing but call + * `implement_main`, then call `System.exit()` with the returned exit code. + * @see #implement_main + */ + def main(args: Array[String]) { + val retval = implement_main(args) + System.exit(retval) + } + } + + /** + * An object holding parameters retrieved from command-line arguments + * using the `argparser` module, following the preferred style of that + * module. + * + * @param parser the ArgParser object that does the actual parsing of + * command-line arguments. + */ + + abstract class ArgParserParameters(val parser: ArgParser) { + } + + /** + * A version of ExperimentDriver where the parameters come from command-line + * arguments and parsing of these arguments is done using an `ArgParser` + * object (from the `argparser` module). The object holding the parameters + * must be a subclass of `ArgParserParameters`, which holds a reference to + * the `ArgParser` object that does the parsing. Parameter errors are + * redirected to the `ArgParser` error handler. 
It's also assumed that, + * in the general case, some of the parameters (so-called "ancillary + * parameters") may come from sources other than command-line arguments; + * in such a case, the function `output_ancillary_parameters` can be + * overridden to output the values of these ancillary parameters. + */ + + trait ArgParserExperimentDriver extends ExperimentDriver { + override type TParam <: ArgParserParameters + + override def param_error(string: String) = { + params.parser.error(string) + } + + /** + * Output the values of some internal parameters. Only needed + * for debugging. + */ + def output_ancillary_parameters() {} + } + + /** + * An extended version for use when both Hadoop and standalone versions of + * the project will be created. + */ + + trait HadoopableArgParserExperimentDriver extends + ArgParserExperimentDriver with ExperimentDriverStats { + /** + * FileHandler object for this driver. + */ + private val local_file_handler = new LocalFileHandler + + /** + * The file handler object for abstracting file access using either the + * Hadoop or regular Java API. By default, references the regular API, + * but can be overridden. + */ + def get_file_handler: FileHandler = local_file_handler + } + + /** + * A general implementation of the ExperimentApp class that uses an + * ArgParserExperimentDriver to do the actual work, so that both + * command-line and programmatic access to the experiment-running + * program is possible. + * + * Most concrete implementations will only need to implement `TDriver`, + * `create_driver` and `create_param_object`. (The latter two functions will + * be largely boilerplate, and are only needed at all because of type + * erasure in Java.) + * + * FIXME: It's not clear that the separation between ExperimentApp and + * ExperimentDriverApp is really worthwhile. If not, merge the two. + */ + + abstract class ExperimentDriverApp(appname: String) extends + ExperimentApp(appname) { + type TDriver <: ArgParserExperimentDriver + + val driver = create_driver() + type TParam = driver.TParam + + def create_driver(): TDriver + + override def output_ancillary_parameters() { + driver.output_ancillary_parameters() + } + + def initialize_parameters() { + driver.set_parameters(params) + } + + def run_program() = { + driver.setup_for_run() + driver.run_after_setup() + 0 + } + } + + /** + * An extension of the MeteredTask class that calls `driver.heartbeat` + * every time an item is processed or we otherwise do something, to let + * Hadoop know that we're actually making progress. + */ + class ExperimentMeteredTask( + driver: ExperimentDriver, + item_name: String, + verb: String, + secs_between_output: Double = 15, + maxtime: Double = 0.0 + ) extends MeteredTask(item_name, verb, secs_between_output, maxtime) { + driver.heartbeat() // Also overkill, again won't hurt. + override def item_processed() = { + driver.heartbeat() + super.item_processed() + } + override def finish() = { + // This is kind of overkill, but shouldn't hurt. 
+ driver.heartbeat() + super.finish() + } + } +} diff --git a/src/main/scala/opennlp/fieldspring/util/hadoop.scala b/src/main/scala/opennlp/fieldspring/util/hadoop.scala new file mode 100644 index 0000000..356519c --- /dev/null +++ b/src/main/scala/opennlp/fieldspring/util/hadoop.scala @@ -0,0 +1,430 @@ +/////////////////////////////////////////////////////////////////////////////// +// hadoop.scala +// +// Copyright (C) 2011 Ben Wing, The University of Texas at Austin +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +/////////////////////////////////////////////////////////////////////////////// + +package opennlp.fieldspring.util + +import org.apache.hadoop.io._ +import org.apache.hadoop.util._ +import org.apache.hadoop.mapreduce._ +import org.apache.hadoop.mapreduce.lib.input.FileInputFormat +import org.apache.hadoop.mapreduce.lib.output.FileOutputFormat +import org.apache.hadoop.conf.{Configuration, Configured} +import org.apache.hadoop.fs._ + +// The following says to import everything except java.io.FileSystem, because +// it conflicts with Hadoop's FileSystem. (Technically, it imports everything +// but in the process aliases FileSystem to _, which has the effect of making +// it inaccessible. _ is special in Scala and has various meanings.) 
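+// (Purely as an illustration of the syntax: a hypothetical
+// "import java.io.{File=>JFile, _}" would likewise import everything, but
+// make File available under the alias JFile rather than hiding it.)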
+import java.io.{FileSystem=>_,_} +import java.net.URI + +import opennlp.fieldspring.util.argparser._ +import opennlp.fieldspring.util.collectionutil._ +import opennlp.fieldspring.util.textdbutil._ +import opennlp.fieldspring.util.experiment._ +import opennlp.fieldspring.util.ioutil._ +import opennlp.fieldspring.util.printutil.{errprint, set_errout_prefix} + +package object hadoop { + class HadoopFileHandler(conf: Configuration) extends FileHandler { + protected def get_file_system(filename: String) = { + FileSystem.get(URI.create(filename), conf) + } + + protected def file_not_found(path: String) = + throw new FileNotFoundException( + "No such file or directory: %s" format path) + + def get_raw_input_stream(filename: String) = + get_file_system(filename).open(new Path(filename)) + + def get_raw_output_stream(filename: String, append: Boolean) = { + val fs = get_file_system(filename) + val path = new Path(filename) + if (append) + fs.append(path) + else + fs.create(path) + } + + def split_filename(filename: String) = { + val path = new Path(filename) + (path.getParent.toString, path.getName) + } + + def join_filename(dir: String, file: String) = + new Path(dir, file).toString + + def is_directory(filename: String) = { + val status = get_file_system(filename).getFileStatus(new Path(filename)) + if (status == null) + file_not_found(filename) + status.isDir + } + + def make_directories(filename: String):Boolean = + get_file_system(filename).mkdirs(new Path(filename)) + + def list_files(dir: String) = { + val status = get_file_system(dir).listStatus(new Path(dir)) + if (status == null) + file_not_found(dir) + for (file <- status) + yield file.getPath.toString + } + } + + object HadoopExperimentConfiguration { + /* Prefix used for storing parameters in a Hadoop configuration */ + val hadoop_conf_prefix = "fieldspring." + + /** + * Convert the parameters in `parser` to Hadoop configuration settings in + * `conf`. + * + * @param prefix Prefix to prepend to the names of all parameters. + * @param parser ArgParser object to retrieve parameters from. + * @param conf Configuration object to store configuration settings into. + */ + def convert_parameters_to_hadoop_conf(prefix: String, parser: ArgParser, + conf: Configuration) { + for (name <- parser.argNames if parser.specified(name)) { + val confname = prefix + name + parser(name) match { + case e:Int => conf.setInt(confname, e) + case e:Long => conf.setLong(confname, e) + case e:Float => conf.setFloat(confname, e) + case e:Double => conf.setFloat(confname, e.toFloat) + case e:String => conf.set(confname, e) + case e:Boolean => conf.setBoolean(confname, e) + case e:Seq[_] => { + val multitype = parser.getMultiType(name) + if (multitype == classOf[String]) { + conf.setStrings(confname, parser.get[Seq[String]](name): _*) + } else + throw new UnsupportedOperationException( + "Don't know how to store sequence of type %s of parameter %s into a Hadoop Configuration" + format (multitype, name)) + } + case ty@_ => { + throw new UnsupportedOperationException( + "Don't know how to store type %s of parameter %s into a Hadoop Configuration" + format (ty, name)) + } + } + } + } + + /** + * Convert the relevant Hadoop configuration settings in `conf` + * into the given ArgParser. + * + * @param prefix Prefix to prepend to the names of all parameters. + * @param parser ArgParser object to store parameters into. The names + * of parameters to fetch are taken from this object. + * @param conf Configuration object to retrieve settings from. 
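+ *
+ * A rough usage sketch (assuming the parser's arguments have already
+ * been declared, and `conf` is the job's Configuration):
+ *
+ * {{{
+ * // On the client: store the parsed parameters into the configuration.
+ * convert_parameters_to_hadoop_conf(hadoop_conf_prefix, parser, conf)
+ * // In a map/reduce task: read them back into an ArgParser.
+ * convert_parameters_from_hadoop_conf(hadoop_conf_prefix, parser, conf)
+ * }}}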
+ */ + def convert_parameters_from_hadoop_conf(prefix: String, parser: ArgParser, + conf: Configuration) { + // Configuration.dumpConfiguration(conf, new PrintWriter(System.err)) + for {name <- parser.argNames + confname = prefix + name + if conf.getRaw(confname) != null} { + val confname = prefix + name + val ty = parser.getType(name) + if (ty == classOf[Int]) + parser.set[Int](name, conf.getInt(confname, parser.defaultValue[Int](name))) + else if (ty == classOf[Long]) + parser.set[Long](name, conf.getLong(confname, parser.defaultValue[Long](name))) + else if (ty == classOf[Float]) + parser.set[Float](name, conf.getFloat(confname, parser.defaultValue[Float](name))) + else if (ty == classOf[Double]) + parser.set[Double](name, conf.getFloat(confname, parser.defaultValue[Double](name).toFloat).toDouble) + else if (ty == classOf[String]) + parser.set[String](name, conf.get(confname, parser.defaultValue[String](name))) + else if (ty == classOf[Boolean]) + parser.set[Boolean](name, conf.getBoolean(confname, parser.defaultValue[Boolean](name))) + else if (ty == classOf[Seq[_]]) { + val multitype = parser.getMultiType(name) + if (multitype == classOf[String]) + parser.set[Seq[String]](name, conf.getStrings(confname, parser.defaultValue[Seq[String]](name): _*).toSeq) + else + throw new UnsupportedOperationException( + "Don't know how to fetch sequence of type %s of parameter %s from a Hadoop Configuration" + format (multitype, name)) + } else { + throw new UnsupportedOperationException( + "Don't know how to store fetch %s of parameter %s from a Hadoop Configuration" + format (ty, name)) + } + } + } + } + + trait HadoopExperimentDriverApp extends ExperimentDriverApp { + var hadoop_conf: Configuration = _ + + override type TDriver <: HadoopExperimentDriver + + /* Set by subclass -- Initialize the various classes for map and reduce */ + def initialize_hadoop_classes(job: Job) + + /* Set by subclass -- Set the settings for reading appropriate input files, + possibly based on command line arguments */ + def initialize_hadoop_input(job: Job) + + /* Called after command-line arguments have been read, verified, + canonicalized and stored into `arg_parser`. We convert the arguments + into configuration variables in the Hadoop configuration -- this is + one way to get "side data" passed into a mapper, and is designed + exactly for things like command-line arguments. (For big chunks of + side data, it's better to use the Hadoop file system.) Then we + tell Hadoop about the classes used to do map and reduce by calling + initialize_hadoop_classes(), set the input and output files, and + actually run the job. + */ + override def run_program() = { + import HadoopExperimentConfiguration._ + convert_parameters_to_hadoop_conf(hadoop_conf_prefix, arg_parser, + hadoop_conf) + val job = new Job(hadoop_conf, progname) + /* We have to call set_job() here now, and not earlier. This is the + "bootstrapping issue" alluded to in the comments on + HadoopExperimentDriver. We can't set the Job until it's created, + and we can't create the Job until after we have set the appropriate + Fieldspring configuration parameters from the command-line arguments -- + but, we need the driver already created in order to parse the + command-line arguments, because it participates in that process. 
*/ + driver.set_job(job) + initialize_hadoop_classes(job) + initialize_hadoop_input(job) + + if (job.waitForCompletion(true)) 0 else 1 + } + + class HadoopExperimentTool extends Configured with Tool { + override def run(args: Array[String]) = { + /* Set the Hadoop configuration object and then thread execution + back to the ExperimentApp. This will read command-line arguments, + call initialize_parameters() on GeolocateApp to verify + and canonicalize them, and then pass control back to us by + calling run_program(), which we override. */ + hadoop_conf = getConf() + set_errout_prefix(progname + ": ") + implement_main(args) + } + } + + override def main(args: Array[String]) { + val exitCode = ToolRunner.run(new HadoopExperimentTool(), args) + System.exit(exitCode) + } + } + + trait HadoopTextDBApp extends HadoopExperimentDriverApp { + def corpus_suffix: String + + def corpus_dirs: Iterable[String] + + def initialize_hadoop_input(job: Job) { + /* A very simple file processor that does nothing but note the files + seen, for Hadoop's benefit. */ + class RetrieveDocumentFilesFileProcessor( + suffix: String + ) extends TextDBLineProcessor[Unit](suffix) { + def process_lines(lines: Iterator[String], + filehand: FileHandler, file: String, + compression: String, realname: String) = { + errprint("Called with %s", file) + FileInputFormat.addInputPath(job, new Path(file)) + (true, ()) + } + } + + val fileproc = new RetrieveDocumentFilesFileProcessor( + // driver.params.eval_set + "-" + driver.document_file_suffix + corpus_suffix + ) + fileproc.process_files(driver.get_file_handler, corpus_dirs) + // FileOutputFormat.setOutputPath(job, new Path(params.outfile)) + } + } + + /** + * Base mix-in for an Experiment application using Hadoop. + * + * @see HadoopExperimentDriver + */ + + trait BaseHadoopExperimentDriver extends + HadoopableArgParserExperimentDriver { + /** + * FileHandler object for this driver. + */ + private lazy val hadoop_file_handler = + new HadoopFileHandler(get_configuration) + + override def get_file_handler: FileHandler = hadoop_file_handler + + // override type TParam <: HadoopExperimentParameters + + /* Implementation of the driver statistics mix-in (ExperimentDriverStats) + that store counters in Hadoop. find_split_counter needs to be + implemented. */ + + /** + * Find the Counter object for the given counter, split into the + * group and tail components. The way to do this depends on whether + * we're running the job driver on the client, or a map or reduce task + * on a tasktracker. + */ + protected def find_split_counter(group: String, tail: String): Counter + + def get_job_context: JobContext + + def get_configuration = get_job_context.getConfiguration + + def get_task_id = get_configuration.getInt("mapred.task.partition", -1) + + /** + * Find the Counter object for the given counter. + */ + protected def find_counter(name: String) = { + val (group, counter) = split_counter(name) + find_split_counter(group, counter) + } + + protected def imp_increment_counter(name: String, incr: Long) { + val counter = find_counter(name) + counter.increment(incr) + } + + protected def imp_get_counter(name: String) = { + val counter = find_counter(name) + counter.getValue() + } + } + + /** + * Mix-in for an Experiment application using Hadoop. This is a trait + * because it should be mixed into a class providing the implementation of + * an application in a way that is indifferent to whether it's being run + * stand-alone or in Hadoop. 
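+ *
+ * Concretely (a sketch; the variable names are hypothetical), the driver
+ * is simply told about its surroundings in one of two ways before any
+ * counter operations happen:
+ *
+ * {{{
+ * driver.set_job(job)              // client-side job-running code
+ * driver.set_task_context(context) // map/reduce task code, e.g. from setup()
+ * }}}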
+ * + * This is used both in map/reduce task code and in the client job-running + * code. In some ways it would be cleaner to have separate classes for + * task vs. client job code, but that would entail additional boilerplate + * for any individual apps as they'd have to create separate task and + * client job versions of each class along with a base superclass for the + * two. + */ + + trait HadoopExperimentDriver extends BaseHadoopExperimentDriver { + var job: Job = _ + var context: TaskInputOutputContext[_,_,_,_] = _ + + /** + * Set the task context, if we're running in the map or reduce task + * code on a tasktracker. (Both Mapper.Context and Reducer.Context are + * subclasses of TaskInputOutputContext.) + */ + def set_task_context(context: TaskInputOutputContext[_,_,_,_]) { + this.context = context + } + + /** + * Set the Job object, if we're running the job-running code on the + * client. (Note that we have to set the job like this, rather than have + * it passed in at creation time, e.g. through an abstract field, + * because of bootstrapping issues; explained in HadoopExperimentApp.) + */ + + def set_job(job: Job) { + this.job = job + } + + def get_job_context = { + if (context != null) context + else if (job != null) job + else need_to_set_context() + } + + def find_split_counter(group: String, counter: String) = { + if (context != null) + context.getCounter(group, counter) + else if (job != null) + job.getCounters.findCounter(group, counter) + else + need_to_set_context() + } + + def need_to_set_context() = + throw new IllegalStateException("Either task context or job needs to be set before any counter operations") + + override def heartbeat() { + if (context != null) + context.progress + } + } + + trait HadoopExperimentMapReducer { + type TContext <: TaskInputOutputContext[_,_,_,_] + type TDriver <: HadoopExperimentDriver + val driver = create_driver() + type TParam = driver.TParam + + def progname: String + + def create_param_object(ap: ArgParser): TParam + def create_driver(): TDriver + + /** Originally this was simply called 'setup', but used only for a + * trait that could be mixed into a mapper. Expanding this to allow + * it to be mixed into both a mapper and a reducer didn't cause problems + * but creating a subtrait that override this function did cause problems, + * a complicated message like this: + +[error] /Users/benwing/devel/fieldspring/src/main/scala/opennlp/fieldspring/preprocess/GroupCorpus.scala:208: overriding method setup in class Mapper of type (x$1: org.apache.hadoop.mapreduce.Mapper[java.lang.Object,org.apache.hadoop.io.Text,org.apache.hadoop.io.Text,org.apache.hadoop.io.Text]#Context)Unit; +[error] method setup in trait GroupCorpusMapReducer of type (context: GroupCorpusMapper.this.TContext)Unit cannot override a concrete member without a third member that's overridden by both (this rule is designed to prevent ``accidental overrides'') +[error] class GroupCorpusMapper extends +[error] ^ + + * so I got around it by defining the actual code in another method and + * making the setup() calls everywhere call this. (FIXME unfortunately this + * is error-prone.) 
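+ *
+ * A rough sketch of the intended pattern (the mapper below is
+ * hypothetical, not part of this codebase):
+ *
+ * {{{
+ * class MyMapper extends Mapper[Object, Text, Text, Text]
+ *     with HadoopExperimentMapReducer {
+ *   type TContext = Context
+ *   // ... definitions of progname, TDriver, create_driver, etc. ...
+ *   override def setup(context: Context) { init(context) }
+ * }
+ * }}}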
+ */ + def init(context: TContext) { + import HadoopExperimentConfiguration._ + + val conf = context.getConfiguration + val ap = new ArgParser(progname) + // Initialize set of parameters in `ap` + create_param_object(ap) + // Retrieve configuration values and store in `ap` + convert_parameters_from_hadoop_conf(hadoop_conf_prefix, ap, conf) + // Now create a class containing the stored configuration values + val params = create_param_object(ap) + driver.set_task_context(context) + context.progress + driver.set_parameters(params) + context.progress + driver.setup_for_run() + context.progress + } + } +} diff --git a/src/main/scala/opennlp/fieldspring/util/ioutil.scala b/src/main/scala/opennlp/fieldspring/util/ioutil.scala new file mode 100644 index 0000000..34ab36c --- /dev/null +++ b/src/main/scala/opennlp/fieldspring/util/ioutil.scala @@ -0,0 +1,943 @@ +/////////////////////////////////////////////////////////////////////////////// +// ioutil.scala +// +// Copyright (C) 2011 Ben Wing, The University of Texas at Austin +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +/////////////////////////////////////////////////////////////////////////////// + +package opennlp.fieldspring.util + +import scala.util.control.Breaks._ +import scala.collection.mutable + +// The following says to import everything except java.io.Console, because +// it conflicts with (and overrides) built-in scala.Console. (Technically, +// it imports everything but in the process aliases Console to _, which +// has the effect of making it inaccessible. _ is special in Scala and has +// various meanings.) +import java.io.{Console=>_,_} +import java.util.NoSuchElementException + +import org.apache.commons.compress.compressors.bzip2._ +import org.apache.commons.compress.compressors.gzip._ + +import printutil.{errprint, warning} +import textutil._ +import osutil._ + +/** + * A 'package object' declaration creates a new subpackage and puts the + * stuff here directly in package scope. This makes it possible to have + * functions in package scope instead of inside a class or object (i.e. + * singleton class). The functions here are accessed using + * 'import opennlp.fieldspring.util.ioutil._' outside of package 'util', + * and simply 'import ioutil._' inside of it. Note that this is named + * 'ioutil' instead of just 'io' to avoid possible conflicts with 'scala.io', + * which is visible by default as 'io'. (Merely declaring it doesn't cause + * a problem, as it overrides 'scala.io'; but people using 'io.*' either + * elsewhere in this package or anywhere that does an import of + * 'opennlp.fieldspring.util._', expecting it to refer to 'scala.io', will + * be surprised. 
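+ *
+ * As a small illustration of the facilities defined below (the filename
+ * is hypothetical), a possibly-compressed file can be read line by line
+ * through a FileHandler:
+ *
+ * {{{
+ * for (line <- local_file_handler.openr("corpus.txt.gz"))
+ *   errprint("%s", line)
+ * }}}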
+ */ + +package object ioutil { + + ////////////////////////////////////////////////////////////////////////////// + // File reading functions // + ////////////////////////////////////////////////////////////////////////////// + + case class FileFormatException( + message: String + ) extends Exception(message) { } + + /** + * Iterator that yields lines in a given encoding (by default, UTF-8) from + * an input stream, usually with any terminating newline removed and usually + * with automatic closing of the stream when EOF is reached. + * + * @param stream Input stream to read from. + * @param encoding Encoding of the text; by default, UTF-8. + * @param chomp If true (the default), remove any terminating newline. + * Any of LF, CRLF or CR will be removed at end of line. + * @param close If true (the default), automatically close the stream when + * EOF is reached. + * @param errors How to handle conversion errors. (FIXME: Not implemented.) + */ + class FileIterator( + stream: InputStream, + encoding: String = "UTF-8", + chomp: Boolean = true, + close: Boolean = true, + errors: String = "strict" + ) extends Iterator[String] { + var ireader = new InputStreamReader(stream, encoding) + var reader = + // Wrapping in a BufferedReader is necessary because readLine() doesn't + // exist on plain InputStreamReaders + /* if (bufsize > 0) new BufferedReader(ireader, bufsize) else */ + new BufferedReader(ireader) + var nextline: String = null + var hit_eof: Boolean = false + protected def getNextLine() = { + if (hit_eof) false + else { + nextline = reader.readLine() + if (nextline == null) { + hit_eof = true + if (close) + close() + false + } else { + if (chomp) { + if (nextline.endsWith("\r\n")) + nextline = nextline.dropRight(2) + else if (nextline.endsWith("\r")) + nextline = nextline.dropRight(1) + else if (nextline.endsWith("\n")) + nextline = nextline.dropRight(1) + } + true + } + } + } + + def hasNext = { + if (nextline != null) true + else if (reader == null) false + else getNextLine() + } + + def next() = { + if (!hasNext) Iterator.empty.next + else { + val ret = nextline + nextline = null + ret + } + } + + def close() { + if (reader != null) { + reader.close() + reader = null + } + } + } + + abstract class FileHandler { + /** + * Return an InputStream that reads from the given file, usually with + * buffering. + * + * @param filename Name of the file. + * @param bufsize Buffering size. If 0 (the default), the default + * buffer size is used. If > 0, the specified size is used. If + * < 0, there is no buffering. + */ + def get_input_stream(filename: String, bufsize: Int = 0) = { + val raw_in = get_raw_input_stream(filename) + if (bufsize < 0) + raw_in + else if (bufsize == 0) + new BufferedInputStream(raw_in) + else + new BufferedInputStream(raw_in, bufsize) + } + + /** + * Return an OutputStream that writes to the given file, usually with + * buffering. + * + * @param filename Name of the file. + * @param bufsize Buffering size. If 0 (the default), the default + * buffer size is used. If > 0, the specified size is used. If + * < 0, there is no buffering. 
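+ *
+ * For illustration (hypothetical filename):
+ * {{{
+ * val out = get_output_stream("results.txt", append = false)
+ * }}}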
+ */ + def get_output_stream(filename: String, append: Boolean, + bufsize: Int = 0) = { + val raw_out = get_raw_output_stream(filename, append) + if (bufsize < 0) + raw_out + else if (bufsize == 0) + new BufferedOutputStream(raw_out) + else + new BufferedOutputStream(raw_out, bufsize) + } + + /** + * Open a filename with the given encoding (by default, UTF-8) and + * optional decompression (by default, based on the filename), and + * return an iterator that yields lines, usually with any terminating + * newline removed and usually with automatic closing of the stream + * when EOF is reached. + * + * @param filename Name of file to read from. + * @param encoding Encoding of the text; by default, UTF-8. + * @param compression Compression of the file (by default, "byname"). + * Valid values are "none" (no compression), "byname" (use the + * extension of the filename to determine the compression), "gzip" + * and "bzip2". + * @param chomp If true (the default), remove any terminating newline. + * Any of LF, CRLF or CR will be removed at end of line. + * @param close If true (the default), automatically close the stream when + * EOF is reached. + * @param errors How to handle conversion errors. (FIXME: Not implemented.) + * @param bufsize Buffering size. If 0 (the default), the default + * buffer size is used. If > 0, the specified size is used. If + * < 0, there is no buffering. + * + * @return A tuple `(iterator, compression_type, uncompressed_filename)` + * where `iterator` is the iterator over lines, `compression_type` is + * a string indicating the actual compression of the file ("none", + * "gzip" or "bzip2") and `uncompressed_filename` is the name the + * uncompressed file would have (typically by removing the extension + * that indicates compression). + * + * @see `FileIterator`, `get_input_stream`, + * `get_input_stream_handling_compression` + */ + def openr_with_compression_info(filename: String, + encoding: String = "UTF-8", compression: String = "byname", + chomp: Boolean = true, close: Boolean = true, + errors: String = "strict", bufsize: Int = 0) = { + val (stream, comtype, realname) = + get_input_stream_handling_compression(filename, + compression=compression, bufsize=bufsize) + (new FileIterator(stream, encoding=encoding, chomp=chomp, close=close, + errors=errors), comtype, realname) + } + + /** + * Open a filename with the given encoding (by default, UTF-8) and + * optional decompression (by default, based on the filename), and + * return an iterator that yields lines, usually with any terminating + * newline removed and usually with automatic closing of the stream + * when EOF is reached. + * + * @param filename Name of file to read from. + * @param encoding Encoding of the text; by default, UTF-8. + * @param compression Compression of the file (by default, "byname"). + * Valid values are "none" (no compression), "byname" (use the + * extension of the filename to determine the compression), "gzip" + * and "bzip2". + * @param chomp If true (the default), remove any terminating newline. + * Any of LF, CRLF or CR will be removed at end of line. + * @param close If true (the default), automatically close the stream when + * EOF is reached. + * @param errors How to handle conversion errors. (FIXME: Not implemented.) + * @param bufsize Buffering size. If 0 (the default), the default + * buffer size is used. If > 0, the specified size is used. If + * < 0, there is no buffering. + * + * @return An iterator over lines. 
Use `openr_with_compression_info` to + * also get the actual type of compression and the uncompressed name + * of the file (minus any extension like .gz or .bzip2). The iterator + * will close itself automatically when EOF is reached if the `close` + * option is set to true (the default), or it can be closed explicitly + * using the `close()` method on the iterator. + * + * @see `FileIterator`, `openr_with_compression_info`, `get_input_stream`, + * `get_input_stream_handling_compression` + */ + def openr(filename: String, encoding: String = "UTF-8", + compression: String = "byname", chomp: Boolean = true, + close: Boolean = true, errors: String = "strict", bufsize: Int = 0) = { + val (iterator, _, _) = openr_with_compression_info(filename, + encoding=encoding, compression=compression, chomp=chomp, close=close, + errors=errors, bufsize=bufsize) + iterator + } + + /** + * Wrap an InputStream with optional decompression. It is strongly + * recommended that the InputStream be buffered. + * + * @param stream Input stream. + * @param compression Compression type. Valid values are "none" (no + * compression), "gzip", and "bzip2". + */ + def wrap_input_stream_with_compression(stream: InputStream, + compression: String) = { + if (compression == "none") stream + else if (compression == "gzip") new GzipCompressorInputStream(stream) + else if (compression == "bzip2") new BZip2CompressorInputStream(stream) + else throw new IllegalArgumentException( + "Invalid compression argument: %s" format compression) + } + + /** + * Wrap an OutputStream with optional compression. It is strongly + * recommended that the OutputStream be buffered. + * + * @param stream Output stream. + * @param compression Compression type. Valid values are "none" (no + * compression), "gzip", and "bzip2". + */ + def wrap_output_stream_with_compression(stream: OutputStream, + compression: String) = { + if (compression == "none") stream + else if (compression == "gzip") new GzipCompressorOutputStream(stream) + else if (compression == "bzip2") new BZip2CompressorOutputStream(stream) + else throw new IllegalArgumentException( + "Invalid compression argument: %s" format compression) + } + + /** + * Create an InputStream that reads from the given file, usually with + * buffering and automatic decompression. Either the decompression + * format can be given explicitly (including "none"), or the function can + * be instructed to use the extension of the filename to determine the + * compression format (e.g. ".gz" for gzip). + * + * @param filename Name of the file. + * @param compression Compression of the file (by default, "byname"). + * Valid values are "none" (no compression), "byname" (use the + * extension of the filename to determine the compression), "gzip" + * and "bzip2". + * @param bufsize Buffering size. If 0 (the default), the default + * buffer size is used. If > 0, the specified size is used. If + * < 0, there is no buffering. + * + * @return A tuple `(stream, compression_type, uncompressed_filename)` + * where `stream` is the stream to read from, `compression_type` is + * a string indicating the actual compression of the file ("none", + * "gzip" or "bzip2") and `uncompressed_filename` is the name the + * uncompressed file would have (typically by removing the extension + * that indicates compression). 
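+ *
+ * For illustration (hypothetical filename):
+ *
+ * {{{
+ * val (stream, comtype, realname) =
+ *   get_input_stream_handling_compression("corpus.txt.gz")
+ * // here comtype would be "gzip" and realname "corpus.txt"
+ * }}}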
+ */ + def get_input_stream_handling_compression(filename: String, + compression: String = "byname", bufsize: Int = 0) = { + val raw_in = get_input_stream(filename, bufsize) + val comtype = + if (compression == "byname") { + if (BZip2Utils.isCompressedFilename(filename)) "bzip2" + else if (GzipUtils.isCompressedFilename(filename)) "gzip" + else "none" + } else compression + val in = wrap_input_stream_with_compression(raw_in, comtype) + val realname = comtype match { + case "gzip" => GzipUtils.getUncompressedFilename(filename) + case "bzip2" => BZip2Utils.getUncompressedFilename(filename) + case _ => { + assert(comtype == "none", + "wrap_input_stream_with_compression should have verified value") + filename + } + } + (in, comtype, realname) + } + + /** + * Create an OutputStream that writes ito the given file, usually with + * buffering and automatic decompression. + * + * @param filename Name of the file. The filename will automatically + * have a suffix added to it to indicate compression, if compression + * is called for. + * @param compression Compression of the file (by default, "none"). + * Valid values are "none" (no compression), "gzip" and "bzip2". + * @param bufsize Buffering size. If 0 (the default), the default + * buffer size is used. If > 0, the specified size is used. If + * < 0, there is no buffering. + * + * @return A tuple `(stream, compressed_filename)`, `stream` is the + * stream to write to and `compressed_filename` is the actual name + * assigned to the file, including the compression suffix, if any. + */ + def get_output_stream_handling_compression(filename: String, + append: Boolean, compression: String = "none", bufsize: Int = 0) = { + val realname = compression match { + case "gzip" => GzipUtils.getCompressedFilename(filename) + case "bzip2" => BZip2Utils.getCompressedFilename(filename) + case "none" => filename + case _ => throw new IllegalArgumentException( + "Invalid compression argument: %s" format compression) + } + val raw_out = get_output_stream(realname, append, bufsize) + val out = wrap_output_stream_with_compression(raw_out, compression) + (out, realname) + } + + /** + * Open a file for writing, with optional compression (by default, no + * compression), and encoding (by default, UTF-8) and return a + * PrintStream that will write to the file. + * + * @param filename Name of file to write to. The filename will + * automatically have a suffix added to it to indicate compression, + * if compression is called for. + * @param encoding Encoding of the text; by default, UTF-8. + * @param compression Compression type. Valid values are "none" (no + * compression), "gzip", and "bzip2". + * @param bufsize Buffering size. If 0 (the default), the default + * buffer size is used. If > 0, the specified size is used. If + * < 0, there is no buffering. + * @param autoflush If true, automatically flush the PrintStream after + * every output call. (Note that if compression is in effect, the + * flush may not actually cause anything to get written.) + */ + def openw(filename: String, append: Boolean = false, + encoding: String = "UTF-8", compression: String = "none", + bufsize: Int = 0, autoflush: Boolean = false) = { + val (out, _) = + get_output_stream_handling_compression(filename, append, compression, + bufsize) + new PrintStream(out, autoflush, encoding) + } + + /* ----------- Abstract functions below this line ----------- */ + + /** + * Return an unbuffered InputStream that reads from the given file. 
+ */ + def get_raw_input_stream(filename: String): InputStream + + /** + * Return an unbuffered OutputStream that writes to the given file, + * overwriting an existing file. + */ + def get_raw_output_stream(filename: String, append: Boolean): OutputStream + /** + * Split a string naming a file into the directory it's in and the + * final component. + */ + def split_filename(filename: String): (String, String) + /** + * Join a string naming a directory to a string naming a file. If the + * file is relative, it is to be interpreted relative to the directory. + */ + def join_filename(dir: String, file: String): String + /** + * Is this file a directory? + */ + def is_directory(filename: String): Boolean + /** + * Create a directory, along with any missing parents. Returns true + * if the directory was created, false if it already exists. + */ + def make_directories(filename: String): Boolean + /** + * List the files in the given directory. + */ + def list_files(dir: String): Iterable[String] + + } + + class LocalFileHandler extends FileHandler { + def check_exists(filename: String) { + if (!new File(filename).exists) + throw new FileNotFoundException("%s (No such file or directory)" + format filename) + } + def get_raw_input_stream(filename: String) = new FileInputStream(filename) + def get_raw_output_stream(filename: String, append: Boolean) = + new FileOutputStream(filename, append) + def split_filename(filename: String) = { + val file = new File(filename) + (file.getParent, file.getName) + } + def join_filename(dir: String, file: String) = + new File(dir, file).toString + def is_directory(filename: String) = { + check_exists(filename) + new File(filename).isDirectory + } + def make_directories(filename: String): Boolean = + new File(filename).mkdirs + def list_files(dir: String) = { + check_exists(dir) + for (file <- new File(dir).listFiles) + yield file.toString + } + } + + val local_file_handler = new LocalFileHandler + + /* NOTE: Following is the original Python code, which worked slightly + differently and had a few additional features: + + -- You could pass in a list of files and it would iterate through + all files in turn; you could pass in no files, in which case it + would read from stdin. + -- You could specify the way of handling errors when doing Unicode + encoding. (FIXME: How do we do this in Java?) + -- You could also specify a read mode. This was primarily useful + for controlling the way that line endings are handled -- e.g. + "rU" or "U" turns on "universal newline" support, where the + various kinds of newline endings are automatically converted to + '\n'; and "rb", which turns on "binary" mode, which forces + newline conversion *not* to happen even on systems where it is + the default (particularly, on Windows, where text files are + terminated by '\r\n', which is normally converted to '\n' on + input). Currently, when 'chomp' is true, we automatically + chomp off all kinds of newlines (whether '\n', '\r' or '\r\n'); + otherwise, we do what the system wants to do by default. + -- You could specify "in-place modification". This is built into + the underlying 'fileinput' module in Python and works like the + similar feature in Perl. If you turn the feature on, the input + file (which cannot be stdin) is renamed upon input, and stdout + is opened so it writes to a file with the original name. 
+ The backup file is normally formed by appending '.bak', and + is deleted automatically on close; but if the 'backup' argument + is given, the backup file will be maintained, and will be named + by appending the string given as the value of the argument. + */ + + + ///// 1. chompopen(): + ///// + ///// A generator that yields lines from a file, with any terminating newline + ///// removed (but no other whitespace removed). Ensures that the file + ///// will be automatically closed under all circumstances. + ///// + ///// 2. openr(): + ///// + ///// Same as chompopen() but specifically open the file as 'utf-8' and + ///// return Unicode strings. + + //""" + //Test gopen + // + //import nlputil + //for line in nlputil.gopen("foo.txt"): + // print line + //for line in nlputil.gopen("foo.txt", chomp=true): + // print line + //for line in nlputil.gopen("foo.txt", encoding="utf-8"): + // print line + //for line in nlputil.gopen("foo.txt", encoding="utf-8", chomp=true): + // print line + //for line in nlputil.gopen("foo.txt", encoding="iso-8859-1"): + // print line + //for line in nlputil.gopen(["foo.txt"], encoding="iso-8859-1"): + // print line + //for line in nlputil.gopen(["foo.txt"], encoding="utf-8"): + // print line + //for line in nlputil.gopen(["foo.txt"], encoding="iso-8859-1", chomp=true): + // print line + //for line in nlputil.gopen(["foo.txt", "foo2.txt"], encoding="iso-8859-1", chomp=true): + // print line + //""" + +// // General function for opening a file, with automatic closure after iterating +// // through the lines. The encoding can be specified (e.g. "utf-8"), and if so, +// // the error-handling can be given. Whether to remove the final newline +// // (chomp=true) can be specified. The filename can be either a regular +// // filename (opened with open) or codecs.open(), or a list of filenames or +// // None, in which case the argument is passed to fileinput.input() +// // (if a non-empty list is given, opens the list of filenames one after the +// // other; if an empty list is given, opens stdin; if None is given, takes +// // list from the command-line arguments and proceeds as above). When using +// // fileinput.input(), the arguments "inplace", "backup" and "bufsize" can be +// // given, appropriate to that function (e.g. to do in-place filtering of a +// // file). In all cases, +// def gopen(filename, mode="r", encoding=None, errors="strict", chomp=false, +// inplace=0, backup="", bufsize=0): +// if isinstance(filename, basestring): +// def yieldlines(): +// if encoding is None: +// mgr = open(filename) +// else: +// mgr = codecs.open(filename, mode, encoding=encoding, errors=errors) +// with mgr as f: +// for line in f: +// yield line +// iterator = yieldlines() +// else: +// if encoding is None: +// openhook = None +// else: +// def openhook(filename, mode): +// return codecs.open(filename, mode, encoding=encoding, errors=errors) +// iterator = fileinput.input(filename, inplace=inplace, backup=backup, +// bufsize=bufsize, mode=mode, openhook=openhook) +// if chomp: +// for line in iterator: +// if line and line[-1] == "\n": line = line[:-1] +// yield line +// else: +// for line in iterator: +// yield line +// +// // Open a filename and yield lines, but with any terminating newline +// // removed (similar to "chomp" in Perl). Basically same as gopen() but +// // with defaults set differently. 
+// def chompopen(filename, mode="r", encoding=None, errors="strict", +// chomp=true, inplace=0, backup="", bufsize=0): +// return gopen(filename, mode=mode, encoding=encoding, errors=errors, +// chomp=chomp, inplace=inplace, backup=backup, bufsize=bufsize) +// +// // Open a filename with UTF-8-encoded input. Basically same as gopen() +// // but with defaults set differently. +// def uopen(filename, mode="r", encoding="utf-8", errors="strict", +// chomp=false, inplace=0, backup="", bufsize=0): +// return gopen(filename, mode=mode, encoding=encoding, errors=errors, +// chomp=chomp, inplace=inplace, backup=backup, bufsize=bufsize) +// + + /** + * Class that lets you process a series of files in turn; if any file + * names a directory, all files in the directory will be processed. + * If a file is given as 'null', that will be passed on unchanged. + * (Useful to signal input taken from an internal source.) + * + * @tparam T Type of result associated with a file. + */ + trait FileProcessor[T] { + + /** + * Process all files, calling `process_file` on each. + * + * @param files Files to process. If any file names a directory, + * all files in the directory will be processed. If any file + * is null, it will be passed on unchanged (see above; useful + * e.g. for specifying input from an internal source). + * @param output_messages If true, output messages indicating the + * files being processed. + * @return Tuple `(completed, values)` where `completed` is True if file + * processing continued to completion, false if interrupted because an + * invocation of `process_file` returned false, and `values` is a + * sequence of the individual values for each file. + */ + def process_files(filehand: FileHandler, files: Iterable[String], + output_messages: Boolean = true) = { + var broken = false + begin_processing(filehand, files) + val buf = mutable.Buffer[T]() + breakable { + def process_one_file(filename: String) { + if (output_messages && filename != null) + errprint("Processing file %s..." format filename) + begin_process_file(filehand, filename) + val (continue, value) = process_file(filehand, filename) + buf += value + if (!continue) { + // This works because of the way 'breakable' is implemented + // (dynamically-scoped). Might "break" (stop working) if break + // is made totally lexically-scoped. + broken = true + } + end_process_file(filehand, filename) + if (broken) + break + } + for (dir <- files) { + if (dir == null) + process_one_file(dir) + else { + if (filehand.is_directory(dir)) { + if (output_messages) + errprint("Processing directory %s..." format dir) + begin_process_directory(filehand, dir) + for (file <- list_files(filehand, dir)) { + process_one_file(file) + } + end_process_directory(filehand, dir) + } else process_one_file(dir) + } + } + } + end_processing(filehand, files) + (!broken, buf.toSeq) + } + + /*********************** MUST BE IMPLEMENTED *************************/ + + /** + * Process a given file. + * + * @param filehand The FileHandler for working with the file. + * @param file The file to process (possibly null, see above). + * @return A tuple of `(continue, result)` where `continue` is a Boolean + * (True if file processing should continue) and `result` is the result + * of processing this file. + */ + def process_file(filehand: FileHandler, file: String): (Boolean, T) + + /***************** MAY BE IMPLEMENTED (THROUGH OVERRIDING) ***************/ + + /** + * Called when about to begin processing all files in a directory. 
+ * Must be overridden, since it has an (empty) definition by default. + * + * @param filehand The FileHandler for working with the file. + * @param dir Directory being processed. + */ + def begin_process_directory(filehand: FileHandler, dir: String) { + } + + /** + * Called when finished processing all files in a directory. + * Must be overridden, since it has an (empty) definition by default. + * + * @param filehand The FileHandler for working with the file. + * @param dir Directory being processed. + */ + def end_process_directory(filehand: FileHandler, dir: String) { + } + + /** + * Called when about to begin processing a file. + * Must be overridden, since it has an (empty) definition by default. + * + * @param filehand The FileHandler for working with the file. + * @param file File being processed. + */ + def begin_process_file(filehand: FileHandler, file: String) { + } + + /** + * Called when finished processing a file. + * Must be overridden, since it has an (empty) definition by default. + * + * @param filehand The FileHandler for working with the file. + * @param file File being processed. + */ + def end_process_file(filehand: FileHandler, file: String) { + } + + /** + * Called when about to begin processing files. + * Must be overridden, since it has an (empty) definition by default. + * + * @param filehand The FileHandler for working with the file. + * @param dir Directory being processed. + */ + def begin_processing(filehand: FileHandler, files: Iterable[String]) { + } + + /** + * Called when finished processing all files. + * Must be overridden, since it has an (empty) definition by default. + * + * @param filehand The FileHandler for working with the file. + * @param dir Directory being processed. + */ + def end_processing(filehand: FileHandler, files: Iterable[String]) { + } + + /** + * List the files in a directory to be processed. + * + * @param filehand The FileHandler for working with the files. + * @param dir Directory being processed. + * + * @return The list of files to process. + */ + def list_files(filehand: FileHandler, dir: String) = + filehand.list_files(dir) + } + + /** + * Class that lets you process a series of text files in turn, using + * the same mechanism for processing the files themselves as in + * `FileProcessor`. + */ + trait LineProcessor[T] extends FileProcessor[T] { + /** + * Process a given file. + * + * @param filehand The FileHandler for working with the file. + * @param file The file to process (possibly null, see above). + * @return True if file processing should continue; false to + * abort any further processing. + */ + def process_file(filehand: FileHandler, file: String) = { + val (lines, compression, realname) = + filehand.openr_with_compression_info(file) + try { + begin_process_lines(lines, filehand, file, compression, realname) + process_lines(lines, filehand, file, compression, realname) + } finally { + lines.close() + } + } + + /*********************** MUST BE IMPLEMENTED *************************/ + + /** + * Called to process the lines of a file. + * Must be overridden, since it has an (empty) definition by default. + * + * @param lines Iterator over the lines in the file. + * @param filehand The FileHandler for working with the file. + * @param file The name of the file being processed. + * @param compression The compression of the file ("none" for no + * compression). + * @param realname The "real" name of the file, after any compression + * suffix (e.g. .gz, .bzip2) is removed. 
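+ *
+ * A rough sketch of a concrete subclass (hypothetical; it simply counts
+ * the lines in each file and always asks to continue):
+ *
+ * {{{
+ * class LineCounter extends LineProcessor[Int] {
+ *   def process_lines(lines: Iterator[String],
+ *       filehand: FileHandler, file: String,
+ *       compression: String, realname: String) = (true, lines.size)
+ * }
+ * }}}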
+ */ + def process_lines(lines: Iterator[String], + filehand: FileHandler, file: String, + compression: String, realname: String): (Boolean, T) + + /***************** MAY BE IMPLEMENTED (THROUGH OVERRIDING) ***************/ + + /** + * Called when about to begin processing the lines from a file. + * Must be overridden, since it has an (empty) definition by default. + * Note that this is generally called just once per file, just like + * `begin_process_file`; but this function has compression info and + * the line iterator available to it. + * + * @param filehand The FileHandler for working with the file. + * @param file File being processed. + */ + def begin_process_lines(lines: Iterator[String], + filehand: FileHandler, file: String, + compression: String, realname: String) { } + } + + //////////////////////////////////////////////////////////////////////////// + // File Splitting // + //////////////////////////////////////////////////////////////////////////// + + // Return the next file to output to, when the instances being output to the + // files are meant to be split according to SPLIT_FRACTIONS. The absolute + // quantities in SPLIT_FRACTIONS don't matter, only the values relative to + // the other values, i.e. [20, 60, 10] is the same as [4, 12, 2]. This + // function implements an algorithm that is deterministic (same results + // each time it is run), and spreads out the instances as much as possible. + // For example, if all values are equal, it will cycle successively through + // the different split files; if the values are [1, 1.5, 1], the output + // will be [1, 2, 3, 2, 1, 2, 3, ...]; etc. + + def next_split_set(split_fractions: Seq[Double]): Iterable[Int] = { + + val num_splits = split_fractions.length + val cumulative_articles = mutable.Seq.fill(num_splits)(0.0) + + // Normalize so that the smallest value is 1. + + val minval = split_fractions min + val normalized_split_fractions = + (for (value <- split_fractions) yield value.toDouble/minval) + + // The algorithm used is as follows. We cycle through the output sets in + // order; each time we return a set, we increment the corresponding + // cumulative count, but before returning a set, we check to see if the + // count has reached the corresponding fraction and skip this set if so. + // If we have run through an entire cycle without returning any sets, + // then for each set we subtract the fraction value from the cumulative + // value. This way, if the fraction value is not a whole number, then + // any fractional quantity (e.g. 0.6 for a value of 7.6) is left over, + // any will ensure that the total ratios still work out appropriately. + + def fuckme_no_yield(): Stream[Int] = { + var yieldme = mutable.Buffer[Int]() + for (j <- 0 until num_splits) { + //println("j=%s, this_output=%s" format (j, this_output)) + if (cumulative_articles(j) < normalized_split_fractions(j)) { + yieldme += j + cumulative_articles(j) += 1 + } + } + if (yieldme.length == 0) { + for (j <- 0 until num_splits) { + while (cumulative_articles(j) >= normalized_split_fractions(j)) + cumulative_articles(j) -= normalized_split_fractions(j) + } + } + yieldme.toStream ++ fuckme_no_yield() + } + fuckme_no_yield() + } + + //////////////////////////////////////////////////////////////////////////// + // Subprocesses // + //////////////////////////////////////////////////////////////////////////// + + /** + * Run a subprocess and capture its output. Arguments given are those + * that will be passed to the subprocess. 
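+ *
+ * For illustration (hypothetical command):
+ * {{{
+ * val listing = capture_subprocess_output("ls", "-l")
+ * }}}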
+ */ + + def capture_subprocess_output(args: String*) = { + val output = new StringBuilder() + val proc = new ProcessBuilder(args: _*).start() + val in = proc.getInputStream() + val br = new BufferedReader(new InputStreamReader(in)) + val cbuf = new Array[Char](100) + var numread = 0 + /* SCALABUG: The following compiles but will give incorrect results because + the result of an assignment is Unit! (You do get a warning but ...) + + while ((numread = br.read(cbuf, 0, cbuf.length)) != -1) + output.appendAll(cbuf, 0, numread) + + */ + numread = br.read(cbuf, 0, cbuf.length) + while (numread != -1) { + output.appendAll(cbuf, 0, numread) + numread = br.read(cbuf, 0, cbuf.length) + } + proc.waitFor() + in.close() + output.toString + } + + // The original Python implementation, which had more functionality: + + /* + Run the specified command; return its output (usually, the combined + stdout and stderr output) as a string. 'command' can either be a + string or a list of individual arguments. Optional argument 'shell' + indicates whether to pass the command to the shell to run. If + unspecified, it defaults to true if 'command' is a string, false if + a list. If optional arg 'input' is given, pass this string as the + stdin to the command. If 'include_stderr' is true (the default), + stderr will be included along with the output. If return code is + non-zero, throw CommandError if 'throw' is specified; else, return + tuple of (output, return-code). + */ + +// def backquote(command, input=None, shell=None, include_stderr=true, throw=true): +// //logdebug("backquote called: %s" % command) +// if shell is None: +// if isinstance(command, basestring): +// shell = true +// else: +// shell = false +// stderrval = STDOUT if include_stderr else PIPE +// if input is not None: +// popen = Popen(command, stdin=PIPE, stdout=PIPE, stderr=stderrval, +// shell=shell, close_fds=true) +// output = popen.communicate(input) +// else: +// popen = Popen(command, stdout=PIPE, stderr=stderrval, +// shell=shell, close_fds=true) +// output = popen.communicate() +// if popen.returncode != 0: +// if throw: +// if output[0]: +// outputstr = "Command's output:\n%s" % output[0] +// if outputstr[-1] != '\n': +// outputstr += '\n' +// errstr = output[1] +// if errstr and errstr[-1] != '\n': +// errstr += '\n' +// errmess = ("Error running command: %s\n\n%s\n%s" % +// (command, output[0], output[1])) +// //log.error(errmess) +// oserror(errmess, EINVAL) +// else: +// return (output[0], popen.returncode) +// return output[0] +// +// def oserror(mess, err): +// e = OSError(mess) +// e.errno = err +// raise e + +} diff --git a/src/main/scala/opennlp/fieldspring/util/mathutil.scala b/src/main/scala/opennlp/fieldspring/util/mathutil.scala new file mode 100644 index 0000000..8e32f63 --- /dev/null +++ b/src/main/scala/opennlp/fieldspring/util/mathutil.scala @@ -0,0 +1,90 @@ +/////////////////////////////////////////////////////////////////////////////// +// mathutil.scala +// +// Copyright (C) 2011, 2012 Ben Wing, The University of Texas at Austin +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+// See the License for the specific language governing permissions and +// limitations under the License. +/////////////////////////////////////////////////////////////////////////////// + +package opennlp.fieldspring.util + +import math._ + +package object mathutil { + /** + * Return the median value of a list. List will be sorted, so this is O(n). + */ + def median(list: Seq[Double]) = { + val sorted = list.sorted + val len = sorted.length + if (len % 2 == 1) + sorted(len / 2) + else { + val midp = len / 2 + 0.5*(sorted(midp-1) + sorted(midp)) + } + } + + /** + * Return the mean of a list. + */ + def mean(list: Seq[Double]) = { + list.sum / list.length + } + + def variance(x: Seq[Double]) = { + val m = mean(x) + mean(for (y <- x) yield ((y - m) * (y - m))) + } + + def stddev(x: Seq[Double]) = sqrt(variance(x)) + + abstract class MeanShift[Coord : Manifest]( + h: Double = 1.0, + max_stddev: Double = 1e-10, + max_iterations: Int = 100 + ) { + def squared_distance(x:Coord, y:Coord): Double + def weighted_sum(weights:Array[Double], points:Array[Coord]): Coord + def scaled_sum(scalar:Double, points:Array[Coord]): Coord + + def vec_mean(points:Array[Coord]) = scaled_sum(1.0/points.length, points) + + def vec_variance(points:Array[Coord]) = { + def m = vec_mean(points) + mean( + for (i <- 0 until points.length) yield squared_distance(m, points(i))) + } + + def mean_shift(list: Seq[Coord]):Array[Coord] = { + var next_stddev = max_stddev + 1 + var numiters = 0 + val points = list.toArray + val shifted = list.toArray + while (next_stddev >= max_stddev && numiters <= max_iterations) { + for (j <- 0 until points.length) { + val y = shifted(j) + val weights = + (for (i <- 0 until points.length) + yield exp(-squared_distance(y, points(i))/(h*h))) + val weight_sum = weights sum + val normalized_weights = weights.map(_ / weight_sum).toArray + shifted(j) = weighted_sum(normalized_weights, points) + } + numiters += 1 + next_stddev = sqrt(vec_variance(shifted)) + } + shifted + } + } +} diff --git a/src/main/scala/opennlp/fieldspring/util/osutil.scala b/src/main/scala/opennlp/fieldspring/util/osutil.scala new file mode 100644 index 0000000..72ac499 --- /dev/null +++ b/src/main/scala/opennlp/fieldspring/util/osutil.scala @@ -0,0 +1,199 @@ +/////////////////////////////////////////////////////////////////////////////// +// osutil.scala +// +// Copyright (C) 2011 Ben Wing, The University of Texas at Austin +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+/////////////////////////////////////////////////////////////////////////////// + +package opennlp.fieldspring.util + +import scala.collection.mutable + +import java.util.Date +import java.text.DateFormat +import java.io._ + +import ioutil._ +import printutil._ +import textutil._ +import timeutil.format_minutes_seconds + +package object osutil { + + //////////////////////////////////////////////////////////////////////////// + // Resource Usage // + //////////////////////////////////////////////////////////////////////////// + + /** + * Return floating-point value, number of seconds since the Epoch + **/ + def curtimesecs() = System.currentTimeMillis/1000.0 + + def curtimehuman() = (new Date()) toString + + def humandate_full(sectime: Double) = + (new Date((sectime*1000).toLong)) toString + def humandate_time(sectime: Double) = + DateFormat.getTimeInstance().format((sectime*1000).toLong) + + import java.lang.management._ + def getpid() = ManagementFactory.getRuntimeMXBean().getName().split("@")(0) + private var initialized = false + private def check_initialized() { + if (!initialized) + throw new IllegalStateException("""You must call initialize_osutil() +at the beginning of your program, in order to use get_program_time_usage()""") + } + /** + * Call this if you use `get_program_time_usage` or `output_resource_usage`. + * This is necessary in order to record the time at the beginning of the + * program. + */ + def initialize_osutil() { + // Simply calling this function is enough, because it will trigger the + // loading of the class associated with package object osutil, which + // will cause `beginning_prog_time` to get set. We set a flag to verify + // that this is done. + initialized = true + } + val beginning_prog_time = curtimesecs() + + def get_program_time_usage() = { + check_initialized() + curtimesecs() - beginning_prog_time + } + + /** + * Return memory usage as a + */ + def get_program_memory_usage(virtual: Boolean = false, + method: String = "auto"): (String, Long) = { + method match { + case "java" => (method, get_program_memory_usage_java()) + case "proc" => (method, get_program_memory_usage_proc(virtual = virtual)) + case "ps" => (method, get_program_memory_usage_ps(virtual = virtual)) + case "rusage" => (method, get_program_memory_usage_rusage()) + case "auto" => { + val procmem = get_program_memory_usage_proc(virtual = virtual) + if (procmem > 0) return ("proc", procmem) + val psmem = get_program_memory_usage_ps(virtual = virtual) + if (psmem > 0) return ("ps", psmem) + val rusagemem = get_program_memory_usage_rusage() + if (rusagemem > 0) return ("rusage", rusagemem) + return ("java", get_program_memory_usage_java()) + } + } + } + + def get_program_memory_usage_java() = { + System.gc() + System.gc() + val rt = Runtime.getRuntime + rt.totalMemory - rt.freeMemory + } + + def get_program_memory_usage_rusage() = { + // val res = resource.getrusage(resource.RUSAGE_SELF) + // // FIXME! This is "maximum resident set size". There are other more useful + // // values, but on the Mac at least they show up as 0 in this structure. + // // On Linux, alas, all values show up as 0 or garbage (e.g. negative). + // res.ru_maxrss + -1L + } + + def wrap_call[Ret](fn: => Ret, errval: Ret) = { + try { + fn + } catch { + case e@_ => { errprint("%s", e); errval } + } + } + + // Get memory usage by running 'ps'; getrusage() doesn't seem to work very + // well. The following seems to work on both Mac OS X and Linux, at least. 
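+  // For reference (an assumption about typical 'ps' output rather than a
+  // guarantee): `ps -p PID -o rss` prints a header line followed by the
+  // size in kilobytes, e.g.
+  //
+  //     RSS
+  //   123456
+  //
+  // so the code below skips the header line and multiplies the value by
+  // 1024 to report bytes.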
+ def get_program_memory_usage_ps(virtual: Boolean = false, + wraperr: Boolean = true): Long = { + if (wraperr) + return wrap_call(get_program_memory_usage_ps( + virtual=virtual, wraperr=false), -1L) + val header = if (virtual) "vsz" else "rss" + val pid = getpid() + val input = + capture_subprocess_output("ps", "-p", pid.toString, "-o", header) + val lines = input.split('\n') + for (line <- lines if line.trim != header.toUpperCase) + return 1024*line.trim.toLong + return -1L + } + + // Get memory usage by running 'proc'; this works on Linux and doesn't + // require spawning a subprocess, which can crash when your program is + // very large. + def get_program_memory_usage_proc(virtual: Boolean = false, + wraperr: Boolean = true): Long = { + if (wraperr) + return wrap_call(get_program_memory_usage_proc( + virtual=virtual, wraperr=false), -1L) + val header = if (virtual) "VmSize:" else "VmRSS:" + if (!((new File("/proc/self/status")).exists)) + return -1L + for (line <- local_file_handler.openr("/proc/self/status")) { + val trimline = line.trim + if (trimline.startsWith(header)) { + val size = ("""\s+""".r.split(trimline))(1).toLong + return 1024*size + } + } + return -1L + } + + def output_memory_usage(virtual: Boolean = false) { + for (method <- List("auto", "java", "proc", "ps", "rusage")) { + val (meth, mem) = + get_program_memory_usage(virtual = virtual, method = method) + val memtype = if (virtual) "virtual size" else "resident set size" + val methstr = if (method == "auto") "auto=%s" format meth else method + errout("Memory usage, %s (%s): ", memtype, methstr) + if (mem <= 0) + errprint("Unknown") + else + errprint("%s bytes", with_commas(mem)) + } + } + + def output_resource_usage(dojava: Boolean = true) { + errprint("Total elapsed time since program start: %s", + format_minutes_seconds(get_program_time_usage())) + val (vszmeth, vsz) = get_program_memory_usage(virtual = true, + method = "auto") + errprint("Memory usage, virtual memory size (%s): %s bytes", vszmeth, + with_commas(vsz)) + val (rssmeth, rss) = get_program_memory_usage(virtual = false, + method = "auto") + errprint("Memory usage, actual (i.e. resident set) (%s): %s bytes", rssmeth, + with_commas(rss)) + if (dojava) { + val (_, java) = get_program_memory_usage(virtual = false, + method = "java") + errprint("Memory usage, Java heap: %s bytes", with_commas(java)) + } else + System.gc() + } + + /* For testing the output_memory_usage() function. */ + object TestMemUsage extends App { + output_memory_usage() + } +} + diff --git a/src/main/scala/opennlp/fieldspring/util/printutil.scala b/src/main/scala/opennlp/fieldspring/util/printutil.scala new file mode 100644 index 0000000..db82edc --- /dev/null +++ b/src/main/scala/opennlp/fieldspring/util/printutil.scala @@ -0,0 +1,251 @@ +/////////////////////////////////////////////////////////////////////////////// +// printutil.scala +// +// Copyright (C) 2011 Ben Wing, The University of Texas at Austin +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+/////////////////////////////////////////////////////////////////////////////// + +package opennlp.fieldspring.util + +import scala.util.control.Breaks._ +import scala.collection.mutable + +// The following says to import everything except java.io.Console, because +// it conflicts with (and overrides) built-in scala.Console. (Technically, +// it imports everything but in the process aliases Console to _, which +// has the effect of making it inaccessible. _ is special in Scala and has +// various meanings.) +import java.io.{Console=>_,_} + +import textutil._ +import ioutil._ +import osutil._ + +package object printutil { + + //////////////////////////////////////////////////////////////////////////// + // Text output functions // + //////////////////////////////////////////////////////////////////////////// + + // This stuff sucks. Need to create new Print streams to get the expected + // UTF-8 output, since the existing System.out/System.err streams don't do it! + val stdout_stream = new PrintStream(System.out, true, "UTF-8") + val stderr_stream = new PrintStream(System.err, true, "UTF-8") + + /** + Set Java System.out and System.err, and Scala Console.out and Console.err, + so that they convert text to UTF-8 upon output (rather than e.g. MacRoman, + the default on Mac OS X). + */ + def set_stdout_stderr_utf_8() { + // Fuck me to hell, have to fix things up in a non-obvious way to + // get UTF-8 output on the Mac (default is MacRoman???). + System.setOut(stdout_stream) + System.setErr(stderr_stream) + Console.setOut(System.out) + Console.setErr(System.err) + } + + def uniprint(text: String, outfile: PrintStream = System.out) { + outfile.println(text) + } + def uniout(text: String, outfile: PrintStream = System.out) { + outfile.print(text) + } + + var errout_prefix = "" + + def set_errout_prefix(prefix: String) { + errout_prefix = prefix + } + + var need_prefix = true + + var errout_stream: PrintStream = System.err + + def set_errout_stream(stream: PrintStream) { + if (stream == null) + errout_stream = System.err + else + errout_stream = stream + } + + def get_errout_stream(file: String) = { + if (file == null) + System.err + else + (new LocalFileHandler).openw(file, append = true, bufsize = -1) + } + + def set_errout_file(file: String) { + set_errout_stream(get_errout_stream(file)) + } + + protected def format_outtext(format: String, args: Any*) = { + // If no arguments, assume that we've been passed a raw string to print, + // so print it directly rather than passing it to 'format', which might + // munge % signs + val outtext = + if (args.length == 0) format + else format format (args: _*) + if (need_prefix) + errout_prefix + outtext + else + outtext + } + + def errfile(file: String, format: String, args: Any*) { + val stream = get_errout_stream(file) + stream.println(format_outtext(format, args: _*)) + need_prefix = true + stream.flush() + if (stream != System.err) + stream.close() + } + + def errprint(format: String, args: Any*) { + errout_stream.println(format_outtext(format, args: _*)) + need_prefix = true + errout_stream.flush() + } + + def errout(format: String, args: Any*) { + val text = format_outtext(format, args: _*) + errout_stream.print(text) + need_prefix = text.last == '\n' + errout_stream.flush() + } + + /** + Output a warning, formatting into UTF-8 as necessary. + */ + def warning(format: String, args: Any*) { + errprint("Warning: " + format, args: _*) + } + + /** + Output a value, for debugging through print statements. 
+ Basically same as just caling errprint() or println() or whatever, + but useful because the call to debprint() more clearly identifies a + temporary piece of debugging code that should be removed when the + bug has been identified. + */ + def debprint(format: String, args: Any*) { + errprint("Debug: " + format, args: _*) + } + + def print_msg_heading(msg: String, blank_lines_before: Int = 1) { + for (x <- 0 until blank_lines_before) + errprint("") + errprint(msg) + errprint("-" * msg.length) + } + + /** + * Return the stack trace of an exception as a string. + */ + def stack_trace_as_string(e: Exception) = { + val writer = new StringWriter() + val pwriter = new PrintWriter(writer) + e.printStackTrace(pwriter) + pwriter.close() + writer.toString + } + + //////////////////////////////////////////////////////////////////////////// + // Table Output // + //////////////////////////////////////////////////////////////////////////// + + /** + * Given a list of tuples, output the list, one line per tuple. + * + * @param outfile If specified, send output to this stream instead of + * stdout. + * @param indent If specified, indent all rows by this string (usually + * some number of spaces). + * @param maxrows If specified, output at most this many rows. + */ + def output_tuple_list[T,U]( + items: Seq[(T,U)], outfile: PrintStream = System.out, + indent: String = "", maxrows: Int = -1) { + var its = items + if (maxrows >= 0) + its = its.slice(0, maxrows) + for ((key, value) <- its) + outfile.println("%s%s = %s" format (indent, key, value)) + } + + /** + * Given a list of tuples, where the second element of the tuple is a + * number and the first a key, output the list, sorted on the numbers from + * bigger to smaller. Within a given number, normally sort the items + * alphabetically. + * + * @param, keep_secondary_order If true, the original order of items is + * left instead of sorting secondarily. + * @param outfile If specified, send output to this stream instead of + * stdout. + * @param indent If specified, indent all rows by this string (usually + * some number of spaces). + * @param maxrows If specified, output at most this many rows. + */ + def output_reverse_sorted_list[T <% Ordered[T],U <% Ordered[U]]( + items: Seq[(T,U)], keep_secondary_order: Boolean = false, + outfile: PrintStream = System.out, indent: String = "", + maxrows: Int = -1) { + var its = items + if (!keep_secondary_order) + its = its sortBy (_._1) + its = its sortWith (_._2 > _._2) + output_tuple_list(its, outfile, indent, maxrows) + } + + /** + * Given a table with values that are numbers, output the table, sorted on + * the numbers from bigger to smaller. Within a given number, normally + * sort the items alphabetically. + * + * @param, keep_secondary_order If true, the original order of items is + * left instead of sorting secondarily. + * @param outfile If specified, send output to this stream instead of + * stdout. + * @param indent If specified, indent all rows by this string (usually + * some number of spaces). + * @param maxrows If specified, output at most this many rows. + */ + def output_reverse_sorted_table[T <% Ordered[T],U <% Ordered[U]]( + table: collection.Map[T,U], keep_secondary_order: Boolean = false, + outfile: PrintStream = System.out, indent: String = "", + maxrows: Int = -1) { + output_reverse_sorted_list(table toList, keep_secondary_order, + outfile, indent, maxrows) + } + + /** + * Output a table, sorted by its key. 
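+   *
+   * e.g. (a sketch):
+   * {{{
+   * output_key_sorted_table(Map("b" -> 2, "a" -> 1))
+   * // prints:
+   * // a = 1
+   * // b = 2
+   * }}}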
+ * + * @param outfile If specified, send output to this stream instead of + * stdout. + * @param indent If specified, indent all rows by this string (usually + * some number of spaces). + * @param maxrows If specified, output at most this many rows. + */ + def output_key_sorted_table[T <% Ordered[T],U]( + table: collection.Map[T,U], + outfile: PrintStream = System.out, indent: String = "", + maxrows: Int = -1) { + output_tuple_list(table.toSeq.sortBy (_._1), outfile, indent, + maxrows) + } +} diff --git a/src/main/scala/opennlp/fieldspring/util/textdbutil.scala b/src/main/scala/opennlp/fieldspring/util/textdbutil.scala new file mode 100644 index 0000000..6c77cba --- /dev/null +++ b/src/main/scala/opennlp/fieldspring/util/textdbutil.scala @@ -0,0 +1,937 @@ +/////////////////////////////////////////////////////////////////////////////// +// textdbutil.scala +// +// Copyright (C) 2012 Ben Wing, The University of Texas at Austin +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +/////////////////////////////////////////////////////////////////////////////// + +package opennlp.fieldspring.util + +import collection.mutable +import util.control.Breaks._ + +import java.io.PrintStream + +import printutil.{errprint, warning} +import ioutil._ + +package object textdbutil { + /** + * A line processor where each line is made up of a fixed number + * of fields, separated by some sort of separator (by default a tab + * character). No implementation is provided for `process_lines`, + * the driver function. This function in general should loop over the + * lines, calling `parse_row` on each one. + * + * @param split_re Regular expression used to split one field from another. + * By default a tab character. + */ + trait FieldLineProcessor[T] extends LineProcessor[T] { + val split_re: String = "\t" + + var all_num_processed = 0 + var all_num_bad = 0 + var num_processed = 0 + var num_bad = 0 + + var fieldnames: Seq[String] = _ + + /** + * Set the field names used for processing rows. + */ + def set_fieldnames(fieldnames: Seq[String]) { + this.fieldnames = fieldnames + } + + override def begin_process_file(filehand: FileHandler, file: String) { + num_processed = 0 + num_bad = 0 + super.begin_process_file(filehand, file) + } + + override def end_process_file(filehand: FileHandler, file: String) { + all_num_processed += num_processed + all_num_bad += num_bad + super.end_process_file(filehand, file) + } + + /** + * Parse a given row into fields. Call either #process_row or + * #handle_bad_row. + * + * @param line Raw text of line describing the row + * @return True if processing should continue, false if it should stop. 
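+     *
+     * (Rows whose field count does not match the schema are handed to
+     * `handle_bad_row`, and processing continues with the next row.)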
+ */ + def parse_row(line: String) = { + // println("[%s]" format line) + val fieldvals = line.split(split_re, -1) + if (fieldvals.length != fieldnames.length) { + handle_bad_row(line, fieldvals) + num_processed += 1 + true + } else { + val (good, keep_going) = process_row(fieldvals) + if (!good) + handle_bad_row(line, fieldvals) + num_processed += 1 + keep_going + } + } + + /*********************** MUST BE IMPLEMENTED *************************/ + + /** + * Called when a "good" row is seen (good solely in that it has the + * correct number of fields). + * + * @param fieldvals List of the string values for each field. + * + * @return Tuple `(good, keep_going)` where `good` indicates whether + * the given row was truly "good" (and hence processed, rather than + * skipped), and `keep_going` indicates whether processing of further + * rows should continue or stop. If the return value indicates that + * the row isn't actually good, `handle_bad_row` will be called. + */ + def process_row(fieldvals: Seq[String]): (Boolean, Boolean) + + /* Also, + + def process_lines(lines: Iterator[String], + filehand: FileHandler, file: String, + compression: String, realname: String): (Boolean, T) + + A simple implementation simply loops over all the lines and calls + parse_row() on each one. + */ + + /******************* MAY BE IMPLEMENTED (OVERRIDDEN) *******************/ + + /** + * Called when a bad row is seen. By default, output a warning. + * + * @param line Text of the row. + * @param fieldvals Field values parsed from the row. + */ + def handle_bad_row(line: String, fieldvals: Seq[String]) { + val lineno = num_processed + 1 + if (fieldnames.length != fieldvals.length) { + warning( + """Line %s: Bad record, expected %s fields, saw %s fields; + skipping line=%s""", lineno, fieldnames.length, fieldvals.length, + line) + } else { + warning("""Line %s: Bad record; skipping line=%s""", lineno, line) + } + num_bad += 1 + } + + /* Also, the same "may-be-implemented" functions from the superclass + LineProcessor. */ + } + + /** + * An object describing a textdb schema, i.e. a description of each of the + * fields in a textdb, along with "fixed fields" containing the same + * value for every row. + * + * @param fieldnames List of the name of each field + * @param fixed_values Map specifying additional fields possessing the + * same value for every row. This is optional, but usually at least + * the "corpus-name" field should be given, with the name of the corpus + * (currently used purely for identification purposes). 
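+   *
+   * e.g. (a sketch; the field and corpus names are hypothetical):
+   * {{{
+   * val schema = new Schema(Seq("title", "coord", "unigram-counts"),
+   *   Map("corpus-name" -> "enwiki"))
+   * }}}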
+ */ + class Schema( + val fieldnames: Seq[String], + val fixed_values: Map[String, String] = Map[String, String]() + ) { + + val field_indices = fieldnames.zipWithIndex.toMap + + def check_values_fit_schema(fieldvals: Seq[String]) { + if (fieldvals.length != fieldnames.length) + throw FileFormatException( + "Wrong-length line, expected %d fields, found %d: %s" format ( + fieldnames.length, fieldvals.length, fieldvals)) + } + + def get_field(fieldvals: Seq[String], key: String, + error_if_missing: Boolean = true) = + get_field_or_else(fieldvals, key, error_if_missing = error_if_missing) + + def get_field_or_else(fieldvals: Seq[String], key: String, + default: String = null, error_if_missing: Boolean = false): String = { + check_values_fit_schema(fieldvals) + if (field_indices contains key) + fieldvals(field_indices(key)) + else + get_fixed_field(key, default, error_if_missing) + } + + def get_fixed_field(key: String, default: String = null, + error_if_missing: Boolean = false) = { + if (fixed_values contains key) + fixed_values(key) + else + Schema.error_or_default(key, default, error_if_missing) + } + + /** + * Output the schema to a file. + */ + def output_schema_file(filehand: FileHandler, schema_file: String, + split_text: String = "\t") { + val schema_outstream = filehand.openw(schema_file) + schema_outstream.println(fieldnames mkString split_text) + for ((field, value) <- fixed_values) + schema_outstream.println(Seq(field, value) mkString split_text) + schema_outstream.close() + } + + /** + * Output the schema to a file. The file will be named + * `DIR/PREFIX-SUFFIX-schema.txt`. + */ + def output_constructed_schema_file(filehand: FileHandler, dir: String, + prefix: String, suffix: String, split_text: String = "\t") { + val schema_file = Schema.construct_schema_file(filehand, dir, prefix, + suffix) + output_schema_file(filehand, schema_file, split_text) + } + + /** + * Output a row describing a document. + * + * @param outstream The output stream to write to, as returned by + * `open_document_file`. + * @param fieldvals Iterable describing the field values to be written. + * There should be as many items as there are field names in the + * `fieldnames` field of the schema. + */ + def output_row(outstream: PrintStream, fieldvals: Seq[String], + split_text: String = "\t") { + assert(fieldvals.length == fieldnames.length, + "values %s (length %s) not same length as fields %s (length %s)" format + (fieldvals, fieldvals.length, fieldnames, + fieldnames.length)) + outstream.println(fieldvals mkString split_text) + } + } + + /** + * A Schema that can be used to select some fields from a larger schema. + * + * @param fieldnames Names of fields in this schema; should be a subset of + * the field names in `orig_schema` + * @param fixed_values Fixed values in this schema + * @param orig_schema Original schema from which fields have been selected. + */ + class SubSchema( + fieldnames: Seq[String], + fixed_values: Map[String, String] = Map[String, String](), + val orig_schema: Schema + ) extends Schema(fieldnames, fixed_values) { + val orig_field_indices = + orig_schema.field_indices.filterKeys(fieldnames contains _).values.toSet + + /** + * Given a set of field values corresponding to the original schema + * (`orig_schema`), produce a list of field values corresponding to this + * schema. + */ + def map_original_fieldvals(fieldvals: Seq[String]) = + fieldvals.zipWithIndex. + filter { case (x, ind) => orig_field_indices contains ind }. 
+ map { case (x, ind) => x } + } + + object Schema { + /** + * Construct the name of a schema file, based on the given file handler, + * directory, prefix and suffix. The file will end with "-schema.txt". + */ + def construct_schema_file(filehand: FileHandler, dir: String, + prefix: String, suffix: String) = + TextDBProcessor.construct_output_file(filehand, dir, prefix, + suffix, "-schema.txt") + + /** + * Read the given schema file. + * + * @param filehand File handler of schema file name. + * @param schema_file Name of the schema file. + * @param split_re Regular expression used to split the fields of the + * schema file, usually TAB. (There's only one row, and each field in + * the row gives the name of the corresponding field in the document + * file.) + */ + def read_schema_file(filehand: FileHandler, schema_file: String, + split_re: String = "\t") = { + val lines = filehand.openr(schema_file) + val fieldname_line = lines.next() + val fieldnames = fieldname_line.split(split_re, -1) + for (field <- fieldnames if field.length == 0) + throw new FileFormatException( + "Blank field name in schema file %s: fields are %s". + format(schema_file, fieldnames)) + var fixed_fields = Map[String,String]() + for (line <- lines) { + val fixed = line.split(split_re, -1) + if (fixed.length != 2) + throw new FileFormatException( + "For fixed fields (i.e. lines other than first) in schema file %s, should have two values (FIELD and VALUE), instead of %s". + format(schema_file, line)) + val Array(from, to) = fixed + if (from.length == 0) + throw new FileFormatException( + "Blank field name in fxed-value part of schema file %s: line is %s". + format(schema_file, line)) + fixed_fields += (from -> to) + } + new Schema(fieldnames, fixed_fields) + } + + def get_field(fieldnames: Seq[String], fieldvals: Seq[String], key: String, + error_if_missing: Boolean = true) = + get_field_or_else(fieldnames, fieldvals, key, + error_if_missing = error_if_missing) + + def get_field_or_else(fieldnames: Seq[String], fieldvals: Seq[String], + key: String, default: String = null, + error_if_missing: Boolean = false): String = { + assert(fieldvals.length == fieldnames.length) + var i = 0 + while (i < fieldnames.length) { + if (fieldnames(i) == key) return fieldvals(i) + i += 1 + } + return error_or_default(key, default, error_if_missing) + } + + protected def error_or_default(key: String, default: String, + error_if_missing: Boolean) = { + if (error_if_missing) { + throw new NoSuchElementException("key not found: %s" format key) + } else default + } + + /** + * Convert a set of field names and values to a map, to make it easier + * to work with them. The result is a mutable order-preserving map, + * which is important so that when converted back to separate lists of + * names and values, the values are still written out correctly. + * (The immutable order-preserving ListMap isn't sufficient since changing + * a field value results in the field getting moved to the end.) + * + */ + def to_map(fieldnames: Seq[String], fieldvals: Seq[String]) = + mutable.LinkedHashMap[String, String]() ++ (fieldnames zip fieldvals) + + /** + * Convert from a map back to a tuple of lists of field names and values. + */ + def from_map(map: mutable.Map[String, String]) = + map.toSeq.unzip + + } + + /** + * File processor for reading in a set of records in "textdb" format. + * The database has the following format: + * + * (1) The documents are stored as field-text files, separated by a TAB + * character. 
+ * (2) There is a corresponding schema file, which lists the names of + * each field, separated by a TAB character, as well as any + * "fixed" fields that have the same value for all rows (one per + * line, with the name, a TAB, and the value). + * (3) The document and schema files are identified by a suffix. + * The document files are named `DIR/PREFIX-SUFFIX.txt` + * (or `DIR/PREFIX-SUFFIX.txt.bz2` or similar, for compressed files), + * while the schema file is named `DIR/PREFIX-SUFFIX-schema.txt`. + * Note that the SUFFIX is set when the `TextDBLineProcessor` is + * created, and typically specifies the category of corpus being + * read (e.g. "text" for corpora containing text or "unigram-counts" + * for a corpus containing unigram counts). The directory is specified + * in a particular call to `process_files` or `read_schema_from_textdb`. + * The prefix is arbitrary and descriptive -- i.e. any files in the + * appropriate directory and with the appropriate suffix, regardless + * of prefix, will be loaded. The prefix of the currently-loading + * document file is available though the field `current_document_prefix`. + * + * The most common setup is to have the schema file and any document files + * placed in the same directory, although it's possible to have them in + * different directories or to have document files scattered across multiple + * directories. Note that the naming of the files allows for multiple + * document files in a single directory, as well as multiple corpora to + * coexist in the same directory, as long as they have different suffixes. + * This is often used to present different "views" onto the same corpus + * (e.g. one containing raw text, one containing unigram counts, etc.), or + * different splits (e.g. training vs. dev vs. test). (In fact, it is + * common to divide a corpus into sub-corpora according to the split. + * In such a case, document files will be named `DIR/PREFIX-SPLIT-SUFFIX.txt` + * or similar. This allows all files for all splits to be located using a + * suffix consisting only of the final "SUFFIX" part, while a particular + * split can be located using a larger prefix of the form "SPLIT-SUFFIX".) + * + * Generally, after creating a file processor of this sort, the schema + * file needs to be read using `read_schema_from_textdb`; then the document + * files can be processed using `process_files`. Most commonly, the same + * directory is passed to both functions. In more complicated setups, + * however, different directory names can be used; multiple calls to + * `process_files` can be made to process multiple directories; or + * individual file names can be given to `process_files` for maximum + * control. + * + * Various fields store things like the current directory and file prefix + * (the part before the suffix). + * + * @param suffix the suffix of the corpus files, as described above + * + */ + abstract class TextDBLineProcessor[T]( + suffix: String + ) extends LineProcessor[T] { + import TextDBProcessor._ + + /** + * Name of the schema file. + */ + var schema_file: String = _ + /** + * File handler of the schema file. + */ + var schema_filehand: FileHandler = _ + /** + * Directory of the schema file. + */ + var schema_dir: String = _ + /** + * Prefix of the schema file (see above). + */ + var schema_prefix: String = _ + /** + * Schema read from the schema file. + */ + var schema: Schema = _ + + /** + * Current document file being read. + */ + var current_document_file: String = _ + /** + * File handler of the current document file. 
+ */ + var current_document_filehand: FileHandler = _ + /** + * "Real name" of the current document file, after any compression suffix + * has been removed. + */ + var current_document_realname: String = _ + /** + * Type of compression of the current document file. + */ + var current_document_compression: String = _ + /** + * Directory of the current document file. + */ + var current_document_dir: String = _ + /** + * Prefix of the current document file (see above). + */ + var current_document_prefix: String = _ + + def set_schema(schema: Schema) { + this.schema = schema + } + + override def begin_process_lines(lines: Iterator[String], + filehand: FileHandler, file: String, + compression: String, realname: String) { + current_document_compression = compression + current_document_filehand = filehand + current_document_file = file + current_document_realname = realname + val (dir, base) = filehand.split_filename(realname) + current_document_dir = dir + current_document_prefix = base.stripSuffix("-" + suffix + ".txt") + super.begin_process_lines(lines, filehand, file, compression, realname) + } + + /** + * Locate and read the schema file of the appropriate suffix in the + * given directory. Set internal variables containing the schema file + * and schema. + */ + def read_schema_from_textdb(filehand: FileHandler, dir: String) { + schema_file = find_schema_file(filehand, dir, suffix) + schema_filehand = filehand + val (_, base) = filehand.split_filename(schema_file) + schema_dir = dir + schema_prefix = base.stripSuffix("-" + suffix + "-schema.txt") + val schema = Schema.read_schema_file(filehand, schema_file) + set_schema(schema) + } + + /** + * List only the document files of the appropriate suffix. + */ + override def list_files(filehand: FileHandler, dir: String) = { + val filter = make_document_file_suffix_regex(suffix) + val files = filehand.list_files(dir) + for (file <- files if filter.findFirstMatchIn(file) != None) yield file + } + } + + object TextDBProcessor { + val possible_compression_re = """(\.bz2|\.bzip2|\.gz|\.gzip)?$""" + /** + * For a given suffix, create a regular expression + * ([[scala.util.matching.Regex]]) that matches document files of the + * suffix. + */ + def make_document_file_suffix_regex(suffix: String) = { + val re_quoted_suffix = """-%s\.txt""" format suffix + (re_quoted_suffix + possible_compression_re).r + } + /** + * For a given suffix, create a regular expression + * ([[scala.util.matching.Regex]]) that matches schema files of the + * suffix. + */ + def make_schema_file_suffix_regex(suffix: String) = { + val re_quoted_suffix = """-%s-schema\.txt""" format suffix + (re_quoted_suffix + possible_compression_re).r + } + + /** + * Construct the name of a file (either schema or document file), based + * on the given file handler, directory, prefix, suffix and file ending. + * For example, if the file ending is "-schema.txt", the file will be + * named `DIR/PREFIX-SUFFIX-schema.txt`. + */ + def construct_output_file(filehand: FileHandler, dir: String, + prefix: String, suffix: String, file_ending: String) = { + val new_base = prefix + "-" + suffix + file_ending + filehand.join_filename(dir, new_base) + } + + /** + * Locate the schema file of the appropriate suffix in the given directory. 
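+     *
+     * e.g. (a sketch; the directory and prefix are hypothetical):
+     * {{{
+     * val sf = find_schema_file(filehand, "data/wiki", "unigram-counts")
+     * // e.g. "data/wiki/enwiki-unigram-counts-schema.txt"
+     * }}}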
+ */ + def find_schema_file(filehand: FileHandler, dir: String, suffix: String) = { + val schema_regex = make_schema_file_suffix_regex(suffix) + val all_files = filehand.list_files(dir) + val files = + (for (file <- all_files + if schema_regex.findFirstMatchIn(file) != None) yield file).toSeq + if (files.length == 0) + throw new FileFormatException( + "Found no schema files (matching %s) in directory %s" + format (schema_regex, dir)) + if (files.length > 1) + throw new FileFormatException( + "Found multiple schema files (matching %s) in directory %s: %s" + format (schema_regex, dir, files)) + files(0) + } + + /** + * Locate and read the schema file of the appropriate suffix in the + * given directory. + */ + def read_schema_from_textdb(filehand: FileHandler, dir: String, + suffix: String) = { + val schema_file = find_schema_file(filehand, dir, suffix) + Schema.read_schema_file(filehand, schema_file) + } + + /** + * Return a list of shell-style wildcard patterns matching all the document + * files in the given directory with the given suffix (including compressed + * files). + */ + def get_matching_patterns(filehand: FileHandler, dir: String, + suffix: String) = { + val possible_endings = Seq("", ".bz2", ".bzip2", ".gz", ".gzip") + for {ending <- possible_endings + full_ending = "-%s.txt%s" format (suffix, ending) + pattern = filehand.join_filename(dir, "*%s" format full_ending) + all_files = filehand.list_files(dir) + files = all_files.filter(_ endsWith full_ending) + if files.toSeq.length > 0} + yield pattern + } + } + + /** + * A basic file processor for reading a "corpus" of records in + * textdb format. You might want to use the higher-level + * `TextDBProcessor`. + * + * FIXME: Should probably eliminate this. + * + * @param suffix the suffix of the corpus files, as described above + */ + abstract class BasicTextDBProcessor[T]( + suffix: String + ) extends TextDBLineProcessor[T](suffix) with FieldLineProcessor[T] { + override def set_schema(schema: Schema) { + super.set_schema(schema) + set_fieldnames(schema.fieldnames) + } + } + + class ExitTextDBProcessor[T](val value: Option[T]) extends Throwable { } + + /** + * A file processor for reading in a "corpus" of records in textdb + * format (where each record or "row" is a single line, with fields + * separated by TAB) and processing each one. Each row is assumed to + * generate an object of type T. The result of calling `process_file` + * will be a Seq[T] of all objects, and the result of calling + * `process_files` to process all files will be a Seq[Seq[T]], one per + * file. + * + * You should implement `handle_row`, which is passed in the field + * values, and should return either `Some(x)` for x of type T, if + * you were able to process the row, or `None` otherwise. If you + * want to exit further processing, throw a `ExitTextDBProcessor(value)`, + * where `value` is either `Some(x)` or `None`, as for a normal return + * value. 
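+   *
+   * A minimal subclass (a sketch; "title" is a hypothetical field name):
+   * {{{
+   * class TitleReader extends TextDBProcessor[String]("text") {
+   *   def handle_row(fieldvals: Seq[String]) =
+   *     Some(schema.get_field(fieldvals, "title"))
+   * }
+   * }}}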
+ * + * @param suffix the suffix of the corpus files, as described above + */ + abstract class TextDBProcessor[T](suffix: String) extends + BasicTextDBProcessor[Seq[T]](suffix) { + + def handle_row(fieldvals: Seq[String]): Option[T] + + val values_so_far = mutable.ListBuffer[T]() + + def process_row(fieldvals: Seq[String]): (Boolean, Boolean) = { + try { + handle_row(fieldvals) match { + case Some(x) => { values_so_far += x; return (true, true) } + case None => { return (false, true) } + } + } catch { + case exit: ExitTextDBProcessor[T] => { + exit.value match { + case Some(x) => { values_so_far += x; return (true, false) } + case None => { return (false, false) } + } + } + } + } + + val task = new MeteredTask("document", "reading") + def process_lines(lines: Iterator[String], + filehand: FileHandler, file: String, + compression: String, realname: String) = { + var should_stop = false + values_so_far.clear() + breakable { + for (line <- lines) { + if (!parse_row(line)) { + should_stop = true + break + } + } + } + (!should_stop, values_so_far.toSeq) + } + + override def end_processing(filehand: FileHandler, files: Iterable[String]) { + task.finish() + super.end_processing(filehand, files) + } + + /** + * Read a corpus from a directory and return the result of processing the + * rows in the corpus. (If you want more control over the processing, + * call `read_schema_from_textdb` and then `process_files`. This allows, + * for example, reading files from multiple directories with a single + * schema, or determining whether processing was aborted early or allowed + * to run to completion.) + * + * @param filehand File handler object of the directory + * @param dir Directory to read + * + * @return A sequence of sequences of values. There is one inner sequence + * per file read in, and each such sequence contains all the values + * read from the file. (There may be fewer values than rows in a file + * if some rows were rejected by `handle_row`, and there may be fewer + * files read than exist in the directory if `handle_row` signalled that + * processing should stop.) + */ + def read_textdb(filehand: FileHandler, dir: String) = { + read_schema_from_textdb(filehand, dir) + val (finished, value) = process_files(filehand, Seq(dir)) + value + } + + /* + FIXME: Should be implemented. Requires that process_files() returns + the filename along with the value. (Should not be a problem for any + existing users of BasicTextDBProcessor, because AFAIK they + ignore the value.) + + def read_textdb_with_filenames(filehand: FileHandler, dir: String) = ... + */ + } + + /** + * Class for writing a "corpus" of documents. The corpus has the + * format described in `TextDBLineProcessor`. + * + * @param schema the schema describing the fields in the document file + * @param suffix the suffix of the corpus files, as described in + * `TextDBLineProcessor` + * + */ + class TextDBWriter( + val schema: Schema, + val suffix: String + ) { + /** + * Text used to separate fields. Currently this is always a tab + * character, and no provision is made for changing this. + */ + val split_text = "\t" + + /** + * Open a document file and return an output stream. The file will be + * named `DIR/PREFIX-SUFFIX.txt`, possibly with an additional suffix + * (e.g. `.bz2`), depending on the specified compression (which defaults + * to no compression). Call `output_row` to output a row describing + * a document. 
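+     *
+     * e.g. (a sketch; directory, prefix and field values are hypothetical):
+     * {{{
+     * val writer = new TextDBWriter(schema, "unigram-counts")
+     * val out = writer.open_document_file(filehand, "outdir", "enwiki")
+     * writer.schema.output_row(out, fieldvals)
+     * out.close()
+     * }}}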
+ */ + def open_document_file(filehand: FileHandler, dir: String, + prefix: String, compression: String = "none") = { + val file = TextDBProcessor.construct_output_file(filehand, dir, + prefix, suffix, ".txt") + filehand.openw(file, compression = compression) + } + + /** + * Output the schema to a file. The file will be named + * `DIR/PREFIX-SUFFIX-schema.txt`. + */ + def output_schema_file(filehand: FileHandler, dir: String, + prefix: String) { + schema.output_constructed_schema_file(filehand, dir, prefix, suffix, + split_text) + } + } + + val document_metadata_suffix = "document-metadata" + val unigram_counts_suffix = "unigram-counts" + val ngram_counts_suffix = "ngram-counts" + val text_suffix = "text" + + class EncodeDecode(val chars_to_encode: Seq[Char]) { + private val encode_chars_regex = "[%s]".format(chars_to_encode mkString "").r + private val encode_chars_map = + chars_to_encode.map(c => (c.toString, "%%%02X".format(c.toInt))).toMap + private val decode_chars_map = + encode_chars_map.toSeq.flatMap { + case (dec, enc) => Seq((enc, dec), (enc.toLowerCase, dec)) }.toMap + private val decode_chars_regex = + "(%s)".format(decode_chars_map.keys mkString "|").r + + def encode(str: String) = + encode_chars_regex.replaceAllIn(str, m => encode_chars_map(m.matched)) + def decode(str: String) = + decode_chars_regex.replaceAllIn(str, m => decode_chars_map(m.matched)) + } + + private val endec_string_for_count_map_field = + new EncodeDecode(Seq('%', ':', ' ', '\t', '\n', '\r', '\f')) + private val endec_string_for_sequence_field = + new EncodeDecode(Seq('%', '>', '\t', '\n', '\r', '\f')) + private val endec_string_for_whole_field = + new EncodeDecode(Seq('%', '\t', '\n', '\r', '\f')) + + /** + * Encode a word for placement inside a "counts" field. Colons and spaces + * are used for separation inside of a field, and tabs and newlines are used + * for separating fields and records. We need to escape all of these + * characters (normally whitespace should be filtered out during + * tokenization, but for some applications it won't necessarily). We do this + * using URL-style-encoding, e.g. replacing : by %3A; hence we also have to + * escape % signs. (We could equally well use HTML-style encoding; then we'd + * have to escape & instead of :.) Note that regardless of whether we use + * URL-style or HTML-style encoding, we probably want to do the encoding + * ourselves rather than use a predefined encoder. We could in fact use the + * presupplied URL encoder, but it would encode all sorts of stuff, which is + * unnecessary and would make the raw files harder to read. In the case of + * HTML-style encoding, : isn't even escaped, so that wouldn't work at all. + */ + def encode_string_for_count_map_field(word: String) = + endec_string_for_count_map_field.encode(word) + + /** + * Encode an n-gram into text suitable for the "counts" field. + The + * individual words are separated by colons, and each word is encoded + * using `encode_string_for_count_map_field`. We need to encode '\n' + * (record separator), '\t' (field separator), ' ' (separator between + * word/count pairs), ':' (separator between word and count), + * '%' (encoding indicator). + */ + def encode_ngram_for_count_map_field(ngram: Iterable[String]) = { + ngram.map(encode_string_for_count_map_field) mkString ":" + } + + /** + * Decode a word encoded using `encode_string_for_count_map_field`. 
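+   * e.g. decoding the illustrative string "new%20york%3A" yields "new york:".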
+ */ + def decode_string_for_count_map_field(word: String) = + endec_string_for_count_map_field.decode(word) + + /** + * Encode a string for placement in a field consisting of a sequence + * of strings. This is similar to `encode_string_for_count_map_field` except + * that we don't encode spaces. We encode '>' for use as a separator + * inside of a field (since it's almost certain not to occur, because + * we generally get HTML-encoded text; and even if not, it's fairly + * rare). + */ + def encode_string_for_sequence_field(word: String) = + endec_string_for_sequence_field.encode(word) + + /** + * Decode a string encoded using `encode_string_for_sequence_field`. + */ + def decode_string_for_sequence_field(word: String) = + endec_string_for_sequence_field.decode(word) + + /** + * Encode a string for placement in a field by itself. This is similar + * to `encode_word_for_sequence_field` except that we don't encode the > + * sign. + */ + def encode_string_for_whole_field(word: String) = + endec_string_for_whole_field.encode(word) + + /** + * Decode a string encoded using `encode_string_for_whole_field`. + */ + def decode_string_for_whole_field(word: String) = + endec_string_for_whole_field.decode(word) + + /** + * Decode an n-gram encoded using `encode_ngram_for_count_map_field`. + */ + def decode_ngram_for_count_map_field(ngram: String) = { + ngram.split(":", -1).map(decode_string_for_count_map_field) + } + + /** + * Split counts field into the encoded n-gram section and the word count. + */ + def shallow_split_count_map_field(field: String) = { + val last_colon = field.lastIndexOf(':') + if (last_colon < 0) + throw FileFormatException( + "Counts field must be of the form WORD:WORD:...:COUNT, but %s seen" + format field) + val count = field.slice(last_colon + 1, field.length).toInt + (field.slice(0, last_colon), count) + } + + /** + * Split counts field into n-gram and word count. + */ + def deep_split_count_map_field(field: String) = { + val (encoded_ngram, count) = shallow_split_count_map_field(field) + (decode_ngram_for_count_map_field(encoded_ngram), count) + } + + /** + * Serialize a sequence of (encoded-word, count) pairs into the format used + * in a corpus. The word or ngram must already have been encoded using + * `encode_string_for_count_map_field` or `encode_ngram_for_count_map_field`. + */ + def shallow_encode_count_map(seq: collection.Seq[(String, Int)]) = { + // Sorting isn't strictly necessary but ensures consistent output as well + // as putting the most significant items first, for visual confirmation. + (for ((word, count) <- seq sortWith (_._2 > _._2)) yield + ("%s:%s" format (word, count))) mkString " " + } + + /** + * Serialize a sequence of (word, count) pairs into the format used + * in a corpus. + */ + def encode_count_map(seq: collection.Seq[(String, Int)]) = { + shallow_encode_count_map(seq map { + case (word, count) => (encode_string_for_count_map_field(word), count) + }) + } + + /** + * Deserialize an encoded word-count map into a sequence of + * (word, count) pairs. 
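+   * e.g. (an illustrative input): decoding "the:10 new%20york:3" yields
+   * ("the", 10) and ("new york", 3).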
+ */ + def decode_count_map(encoded: String) = { + if (encoded.length == 0) + Array[(String, Int)]() + else + { + val wordcounts = encoded.split(" ") + for (wordcount <- wordcounts) yield { + val split_wordcount = wordcount.split(":", -1) + if (split_wordcount.length != 2) + throw FileFormatException( + "For unigram counts, items must be of the form WORD:COUNT, but %s seen" + format wordcount) + val Array(word, strcount) = split_wordcount + if (word.length == 0) + throw FileFormatException( + "For unigram counts, WORD in WORD:COUNT must not be empty, but %s seen" + format wordcount) + val count = strcount.toInt + val decoded_word = decode_string_for_count_map_field(word) + (decoded_word, count) + } + } + } + + object Encoder { + def count_map(x: collection.Map[String, Int]) = encode_count_map(x.toSeq) + def count_map_seq(x: collection.Seq[(String, Int)]) = encode_count_map(x) + def string(x: String) = encode_string_for_whole_field(x) + def string_in_seq(x: String) = encode_string_for_sequence_field(x) + def seq_string(x: collection.Seq[String]) = + x.map(encode_string_for_sequence_field) mkString ">>" + def timestamp(x: Long) = x.toString + def long(x: Long) = x.toString + def int(x: Int) = x.toString + def double(x: Double) = x.toString + } + + object Decoder { + def count_map(x: String) = decode_count_map(x).toMap + def count_map_seq(x: String) = decode_count_map(x) + def string(x: String) = decode_string_for_whole_field(x) + def seq_string(x: String) = + x.split(">>", -1).map(decode_string_for_sequence_field) + def timestamp(x: String) = x.toLong + def long(x: String) = x.toLong + def int(x: String) = x.toInt + def double(x: String) = x.toDouble + } + +} diff --git a/src/main/scala/opennlp/fieldspring/util/textutil.scala b/src/main/scala/opennlp/fieldspring/util/textutil.scala new file mode 100644 index 0000000..0fdbd95 --- /dev/null +++ b/src/main/scala/opennlp/fieldspring/util/textutil.scala @@ -0,0 +1,193 @@ +/////////////////////////////////////////////////////////////////////////////// +// TextUtil.scala +// +// Copyright (C) 2011 Ben Wing, The University of Texas at Austin +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +/////////////////////////////////////////////////////////////////////////////// + +package opennlp.fieldspring.util + +import scala.util.control.Breaks._ +import scala.util.matching.Regex +import math._ + +import printutil.warning + +package object textutil { + + //////////////////////////////////////////////////////////////////////////// + // String functions involving numbers // + //////////////////////////////////////////////////////////////////////////// + + /** + Convert a string to floating point, but don't crash on errors; + instead, output a warning. + */ + def safe_float(x: String) = { + try { + x.toDouble + } catch { + case _ => { + val y = x.trim() + if (y != "") warning("Expected number, saw %s", y) + 0. 
+ } + } + } + + // Originally based on code from: + // http://stackoverflow.com/questions/1823058/how-to-print-number-with-commas-as-thousands-separators-in-python-2-x + def with_commas(x: Long): String = { + var mx = x + if (mx < 0) + "-" + with_commas(-mx) + else { + var result = "" + while (mx >= 1000) { + val r = mx % 1000 + mx /= 1000 + result = ",%03d%s" format (r, result) + } + "%d%s" format (mx, result) + } + } + + // My own version + def with_commas(x: Double): String = { + val intpart = floor(x).toInt + val fracpart = x - intpart + with_commas(intpart) + ("%.2f" format fracpart).drop(1) + } + + /** + * Try to format a floating-point number using %f style (i.e. avoding + * scientific notation) and with a fixed number of significant digits + * after the decimal point. + * + * @param sigdigits Number of significant digits after decimal point + * to display. + * @param include_plus If true, include a + sign before positive numbers. + */ + def format_float(x: Double, sigdigits: Int = 2, + include_plus: Boolean = false) = { + var precision = sigdigits + if (x != 0) { + var xx = abs(x) + while (xx < 0.1) { + xx *= 10 + precision += 1 + } + } + val formatstr = + "%%%s.%sf" format (if (include_plus) "+" else "", precision) + formatstr format x + } + + //////////////////////////////////////////////////////////////////////////// + // Other string functions // + //////////////////////////////////////////////////////////////////////////// + + /** + * Split a string, similar to `str.split(delim_re, -1)`, but also + * return the delimiters. Return an Iterable of tuples `(text, delim)` + * where `delim` is the delimiter following each section of text. The + * last delimiter will be an empty string. + */ + def split_with_delim(str: String, delim_re: Regex): + Iterable[(String, String)] = { + // Find all occurrences of regexp, extract start and end positions, + // flatten, and add suitable positions for the start and end of the string. + // Adding the end-of-string position twice ensures that we get the empty + // delimiter at the end. + val delim_intervals = + Iterator(0) ++ + delim_re.findAllIn(str).matchData.flatMap(m => Iterator(m.start, m.end)) ++ + Iterator(str.length, str.length) + // Group into (start, end) pairs for both text and delimiter, extract + // the strings, group into text-delimiter tuples and convert to Iterable. + delim_intervals sliding 2 map { + case Seq(start, end) => str.slice(start, end) + } grouped 2 map { + case Seq(text, delim) => (text, delim) + } toIterable + } + + def split_text_into_words(text: String, ignore_punc: Boolean=false, + include_nl: Boolean=false) = { + // This regexp splits on whitespace, but also handles the following cases: + // 1. Any of , ; . etc. at the end of a word + // 2. Parens or quotes in words like (foo) or "bar" + // These punctuation characters are returned as separate words, unless + // 'ignore_punc' is given. Also, if 'include_nl' is given, newlines are + // returned as their own words; otherwise, they are treated like all other + // whitespace (i.e. ignored). + (for ((word, punc) <- + split_with_delim(text, """([,;."):]*(?:\s+|$)[("]*)""".r)) yield + Seq(word) ++ ( + for (p <- punc; if !(" \t\r\f\013" contains p)) yield ( + if (p == '\n') (if (include_nl) p.toString else "") + else (if (!ignore_punc) p.toString else "") + ) + ) + ) reduce (_ ++ _) filter (_ != "") + } + + + /** + Pluralize an English word, using a basic but effective algorithm. 
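+  e.g. pluralize("box") == "boxes", pluralize("city") == "cities",
+  pluralize("dog") == "dogs".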
+ */ + def pluralize(word: String) = { + val upper = word.last >= 'A' && word.last <= 'Z' + val lowerword = word.toLowerCase() + val ies_re = """.*[b-df-hj-np-tv-z]y$""".r + val es_re = """.*([cs]h|[sx])$""".r + lowerword match { + case ies_re() => + if (upper) word.dropRight(1) + "IES" + else word.dropRight(1) + "ies" + case es_re() => + if (upper) word + "ES" + else word + "es" + case _ => + if (upper) word + "S" + else word + "s" + } + } + + /** + Capitalize the first letter of string, leaving the remainder alone. + */ + def capfirst(st: String) = { + if (st == "") st else st(0).toString.capitalize + st.drop(1) + } + + /* + A simple object to make regexps a bit less awkward. Works like this: + + ("foo (.*)", "foo bar") match { + case Re(x) => println("matched 1 %s" format x) + case _ => println("no match 1") + } + + This will print out "matched 1 bar". + */ + + object Re { + def unapplySeq(x: Tuple2[String, String]) = { + val (re, str) = x + re.r.unapplySeq(str) + } + } +} + diff --git a/src/main/scala/opennlp/fieldspring/util/timeutil.scala b/src/main/scala/opennlp/fieldspring/util/timeutil.scala new file mode 100644 index 0000000..844ed2d --- /dev/null +++ b/src/main/scala/opennlp/fieldspring/util/timeutil.scala @@ -0,0 +1,158 @@ +/////////////////////////////////////////////////////////////////////////////// +// timeutil.scala +// +// Copyright (C) 2011 Ben Wing, The University of Texas at Austin +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +/////////////////////////////////////////////////////////////////////////////// + +package opennlp.fieldspring.util + +import math._ +import textutil.split_with_delim + +package object timeutil { + + def format_minutes_seconds(seconds: Double, hours: Boolean = true) = { + var secs = seconds + var mins = (secs / 60).toInt + secs = secs % 60 + val hourstr = { + if (!hours) "" + else { + val hours = (mins / 60).toInt + mins = mins % 60 + if (hours > 0) "%s hour%s " format (hours, if (hours == 1) "" else "s") + else "" + } + } + val secstr = (if (secs.toInt == secs) "%s" else "%1.1f") format secs + "%s%s minute%s %s second%s" format ( + hourstr, + mins, if (mins == 1) "" else "s", + secstr, if (secs == 1) "" else "s") + } + + /** + * Parse a date and return a time as milliseconds since the Epoch + * (Jan 1, 1970). Accepts various formats, all variations of the + * following: + * + * 20100802180500PST (= August 2, 2010, 18:05:00 Pacific Standard Time) + * 2010:08:02:0605pmPST (= same) + * 20100802:10:05pm (= same if current time zone is Eastern Daylight) + * + * That is, either 12-hour or 24-hour time can be given, colons can be + * inserted anywhere for readability, and the time zone can be omitted or + * specified. In addition, part or all of the time of day (hours, + * minutes, seconds) can be omitted. Years must always be full (i.e. + * 4 digits). 
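+   *
+   * For example, parse_date("20100802180500PST") returns Some(t), where t is
+   * that instant in milliseconds since the Epoch, while a date with an
+   * out-of-range field such as "20101302" (month 13) yields None.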
+ */ + def parse_date(datestr: String): Option[Long] = { + // Variants for the hour-minute-second portion + val hms_variants = List("", "HH", "HHmm", "HHmmss", "hhaa", "hhmmaa", + "hhmmssaa") + // Fully-specified format including date + val full_fmt = hms_variants.map("yyyyMMdd"+_) + // All formats, including variants with time zone specified + val all_fmt = full_fmt ++ full_fmt.map(_+"zz") + for (fmt <- all_fmt) { + val pos = new java.text.ParsePosition(0) + val formatter = new java.text.SimpleDateFormat(fmt) + // (Possibly we shouldn't do this?) This rejects nonstandardness, e.g. + // out-of-range values such as month 13 or hour 25; that's useful for + // error-checking in case someone messed up entering the date. + formatter.setLenient(false) + val canon_datestr = datestr.replace(":", "") + val date = formatter.parse(canon_datestr, pos) + if (date != null && pos.getIndex == canon_datestr.length) + return Some(date.getTime) + } + None + } + + /** + * Parse a time offset specification, e.g. "5h" (or "+5h") for 5 hours, + * or "3m2s" for "3 minutes 2 seconds". Negative values are allowed, + * to specify offsets going backwards in time. If able to parse, return + * a tuple (Some(millisecs),""); else return (None,errmess) specifying + * an error message. + */ + def parse_time_offset(str: String): (Option[Long], String) = { + if (str.length == 0) + return (None, "Time offset cannot be empty") + val sections = + split_with_delim(str.toLowerCase, "[a-z]+".r).filterNot { + case (text, delim) => text.length == 0 && delim.length == 0 } + val offset_secs = + (for ((valstr, units) <- sections) yield { + val multiplier = + units match { + case "s" => 1 + case "m" => 60 + case "h" => 60*60 + case "d" => 60*60*24 + case "w" => 60*60*24*7 + case "" => + return (None, "Missing units in component '%s' in time offset '%s'; should be e.g. '25s' or '10h30m'" + format (valstr + units, str)) + case _ => + return (None, "Unrecognized component '%s' in time offset '%s'; should be e.g. '25s' or '10h30m'" + format (valstr + units, str)) + } + val value = + try { + valstr.toDouble + } catch { + case e => return (None, + "Unable to convert value '%s' in component '%s' in time offset '%s': %s" format + (valstr, valstr + units, str, e)) + } + multiplier * value + }).sum + (Some((offset_secs * 1000).toLong), "") + } + + /** + * Parse a date and offset into an interval. Should be of the + * form TIME/LENGTH, e.g. '20100802180502PST/2h3m' (= starting at + * August 2, 2010, 18:05:02 Pacific Standard Time, ending exactly + * 2 hours 3 minutes later). Negative offsets are allowed, to indicate + * an interval backwards from a reference point. + * + * @return Tuple of `(Some((start, end)),"")` if able to parse, else + * return None along with an error message. 
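+   *
+   * For example, parse_date_interval("20100802180502PST/2h3m") yields
+   * (Some((start, start + 7380000L)), ""), where start is the parsed start
+   * time in milliseconds and 7380000 = (2*3600 + 3*60) * 1000.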
+ */ + def parse_date_interval(str: String): (Option[(Long, Long)], String) = { + val date_offset = str.split("/", -1) + if (date_offset.length != 2) + (None, "Time chunk %s must be of the format 'START/LENGTH'" + format str) + else { + val Array(datestr, offsetstr) = date_offset + val timelen = parse_time_offset(offsetstr) match { + case (Some(len), "") => len + case (None, errmess) => return (None, errmess) + } + parse_date(datestr) match { + case Some(date) => + (Some((date, date + timelen)), "") + case None => + (None, + "Can't parse time '%s'; should be something like 201008021805pm" + format datestr) + } + } + } +} + diff --git a/src/main/scala/opennlp/fieldspring/util/twokenize.scala b/src/main/scala/opennlp/fieldspring/util/twokenize.scala new file mode 100644 index 0000000..a0da428 --- /dev/null +++ b/src/main/scala/opennlp/fieldspring/util/twokenize.scala @@ -0,0 +1,212 @@ +package opennlp.fieldspring.util + +/* Fieldspring note: + + This has been taken directly from Jason Baldridge's twokenize.scala + version as of 2011-06-13, from here: + + https://bitbucket.org/jasonbaldridge/twokenize + + The only change so far has been adding this comment and the above + package statement. + + FIXME: Make Twokenize be a package retrievable by Maven. + + - Ben Wing (ben@benwing.com) November 2011 +*/ + +/* + TweetMotif is licensed under the Apache License 2.0: + http://www.apache.org/licenses/LICENSE-2.0.html + Copyright Brendan O'Connor, Michel Krieger, and David Ahn, 2009-2010. +*/ + +/* + + Scala port of Brendar O' Connor's twokenize.py + + This is not a direct port, as some changes were made in the aim of + simplicity. In the Python version, the @tokenize@ method returned a + Tokenization object which wrapped a Python List with some extra + methods. + + The @tokenize@ method given here receives a String and returns the + tokenized Array[String] of the input text Twokenize.tokenize("foobar + baz.") => ['foobar', 'baz', '.'] + + The main method reads from stdin like it's python counterpart and + calls the above method on each line + + > scalac twokenize.scala + > echo "@foo #bar \$1.00 isn't 8:42 p.m. Mr. baz." | scala Twokenize + @foo $1.0 #bar baz . + + - David Snyder (dsnyder@cs.utexas.edu) + April 2011 + + Modifications to more functional style, fix a few bugs, and making + output more like twokenize.py. Added abbrevations. Tweaked some + regex's to produce better tokens. + + - Jason Baldridge (jasonbaldridge@gmail.com) + June 2011 +*/ + +import scala.util.matching.Regex + +object Twokenize { + + val Contractions = """(?i)(\w+)(n't|'ve|'ll|'d|'re|'s|'m)$""".r + val Whitespace = """\s+""".r + + val punctChars = """['“\".?!,:;]""" + val punctSeq = punctChars+"""+""" + val entity = """&(amp|lt|gt|quot);""" + + // URLs + + // David: I give the Larry David eye to this whole URL regex + // (http://www.youtube.com/watch?v=2SmoBvg-etU) There are + // potentially better options, see: + // http://daringfireball.net/2010/07/improved_regex_for_matching_urls + // http://mathiasbynens.be/demo/url-regex + + val urlStart1 = """(https?://|www\.)""" + val commonTLDs = """(com|co\.uk|org|net|info|ca|ly)""" + val urlStart2 = """[A-Za-z0-9\.-]+?\.""" + commonTLDs + """(?=[/ \W])""" + val urlBody = """[^ \t\r\n<>]*?""" + val urlExtraCrapBeforeEnd = "("+punctChars+"|"+entity+")+?" 
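+  // Illustrative intent of the assembled `url` pattern below: it should match
+  // tokens such as "http://t.co/xyz" or "www.example.com/foo", while the
+  // trailing lookahead keeps sentence-final punctuation out of the matched URL.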
+ val urlEnd = """(\.\.+|[<>]|\s|$)""" + val url = """\b("""+urlStart1+"|"+urlStart2+")"+urlBody+"(?=("+urlExtraCrapBeforeEnd+")?"+urlEnd+")" + + // Numeric + val timeLike = """\d+:\d+""" + val numNum = """\d+\.\d+""" + val numberWithCommas = """(\d+,)+?\d{3}""" + """(?=([^,]|$))""" + + // Note the magic 'smart quotes' (http://en.wikipedia.org/wiki/Smart_quotes) + val edgePunctChars = """'\"“”‘’<>«»{}\(\)\[\]""" + val edgePunct = "[" + edgePunctChars + "]" + val notEdgePunct = "[a-zA-Z0-9]" + val EdgePunctLeft = new Regex("""(\s|^)("""+edgePunct+"+)("+notEdgePunct+")") + val EdgePunctRight = new Regex("("+notEdgePunct+")("+edgePunct+"""+)(\s|$)""") + + // Abbreviations + val boundaryNotDot = """($|\s|[“\"?!,:;]|""" + entity + ")" + val aa1 = """([A-Za-z]\.){2,}(?=""" + boundaryNotDot + ")" + val aa2 = """[^A-Za-z]([A-Za-z]\.){1,}[A-Za-z](?=""" + boundaryNotDot + ")" + val standardAbbreviations = """\b([Mm]r|[Mm]rs|[Mm]s|[Dd]r|[Ss]r|[Jj]r|[Rr]ep|[Ss]en|[Ss]t)\.""" + val arbitraryAbbrev = "(" + aa1 +"|"+ aa2 + "|" + standardAbbreviations + ")" + + val separators = "(--+|―)" + val decorations = """[♫]+""" + val thingsThatSplitWords = """[^\s\.,]""" + val embeddedApostrophe = thingsThatSplitWords+"""+'""" + thingsThatSplitWords + """+""" + + // Emoticons + val normalEyes = "(?iu)[:=]" + val wink = "[;]" + val noseArea = "(|o|O|-)" // rather tight precision, \S might be reasonable... + val happyMouths = """[D\)\]]""" + val sadMouths = """[\(\[]""" + val tongue = "[pP]" + val otherMouths = """[doO/\\]""" // remove forward slash if http://'s aren't cleaned + + val emoticon = "("+normalEyes+"|"+wink+")" + noseArea + "("+tongue+"|"+otherMouths+"|"+sadMouths+"|"+happyMouths+")" + + // We will be tokenizing using these regexps as delimiters + val Protected = new Regex( + "(" + Array( + emoticon, + url, + entity, + timeLike, + numNum, + numberWithCommas, + punctSeq, + arbitraryAbbrev, + separators, + decorations, + embeddedApostrophe + ).mkString("|") + ")" ) + + // The main work of tokenizing a tweet. + def simpleTokenize (text: String) = { + + // Do the no-brainers first + val splitPunctText = splitEdgePunct(text) + val textLength = splitPunctText.length + + // Find the matches for subsequences that should be protected + // (not further split), e.g. URLs, 1.0, U.N.K.L.E., 12:53 + val matches = Protected.findAllIn(splitPunctText).matchData.toList + + // The protected spans should not be split. + val protectedSpans = matches map (mat => Tuple2(mat.start, mat.end)) + + // Create a list of indices to create the "splittables", which can be + // split. 
We are taking protected spans like + // List((2,5), (8,10)) + // to create + /// List(0, 2, 5, 8, 10, 12) + // where, e.g., "12" here would be the textLength + val indices = + (0 :: (protectedSpans flatMap { case (start,end) => List(start,end) }) + ::: List(textLength)) + + // Group the indices and map them to their respective portion of the string + val splittableSpans = + indices.grouped(2) map { x => splitPunctText.slice(x(0),x(1)) } toList + + //The 'splittable' strings are safe to be further tokenized by whitespace + val splittables = splittableSpans map { str => str.trim.split(" ").toList } + + //Storing as List[List[String]] to make zip easier later on + val protecteds = protectedSpans map { + case(start,end) => List(splitPunctText.slice(start,end)) } + + // Reinterpolate the 'splittable' and 'protected' Lists, ensuring that + // additonal tokens from last splittable item get included + val zippedStr = + (if (splittables.length == protecteds.length) + splittables.zip(protecteds) map { pair => pair._1 ++ pair._2 } + else + ((splittables.zip(protecteds) map { pair => pair._1 ++ pair._2 }) ::: + List(splittables.last)) + ).flatten + + // Split based on special patterns (like contractions) and check all tokens are non empty + zippedStr.map(splitToken(_)).flatten.filter(_.length > 0) + } + + // 'foo' => ' foo ' + def splitEdgePunct (input: String) = { + val splitLeft = EdgePunctLeft.replaceAllIn(input,"$1$2 $3") + EdgePunctRight.replaceAllIn(splitLeft,"$1 $2$3") + } + + // "foo bar" => "foo bar" + def squeezeWhitespace (input: String) = Whitespace.replaceAllIn(input," ").trim + + // Final pass tokenization based on special patterns + def splitToken (token: String) = { + token match { + case Contractions(stem, contr) => List(stem.trim, contr.trim) + case token => List(token.trim) + } + } + + // Apply method allows it to be used as Twokenize(line) in Scala. + def apply (text: String): List[String] = simpleTokenize(squeezeWhitespace(text)) + + // Named for Java coders who would wonder what the heck the 'apply' method is for. + def tokenize (text: String): List[String] = apply(text) + + // Main method + def main (args: Array[String]) = { + io.Source.stdin.getLines foreach { + line => println(apply(line) reduceLeft(_ + " " + _)) + } + } + +} diff --git a/src/main/scala/opennlp/fieldspring/worddist/BigramWordDist.scala.bitrotted b/src/main/scala/opennlp/fieldspring/worddist/BigramWordDist.scala.bitrotted new file mode 100644 index 0000000..22380ca --- /dev/null +++ b/src/main/scala/opennlp/fieldspring/worddist/BigramWordDist.scala.bitrotted @@ -0,0 +1,410 @@ +/////////////////////////////////////////////////////////////////////////////// +// Copyright (C) 2011 Ben Wing, Thomas Darr, Andy Luong, Erik Skiles, The University of Texas at Austin +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +/////////////////////////////////////////////////////////////////////////////// + +//////// +//////// BigramWordDist.scala +//////// +//////// Copyright (c) 2011 Ben Wing. 
+//////// + +package opennlp.fieldspring.worddist + +import math._ +import collection.mutable +import util.control.Breaks._ + +import opennlp.fieldspring.util.collectionutil._ +import opennlp.fieldspring.util.ioutil.{FileHandler, FileFormatException} +import opennlp.fieldspring.util.printutil.{errprint, warning} + +import opennlp.fieldspring.gridlocate.GridLocateDriver.Debug._ +import opennlp.fieldspring.gridlocate.GenericTypes._ + +import WordDist.memoizer._ + +/** + * Bigram word distribution with a table listing counts for each word, + * initialized from the given key/value pairs. + * + * @param key Array holding keys, possibly over-sized, so that the internal + * arrays from DynamicArray objects can be used + * @param values Array holding values corresponding to each key, possibly + * oversize + * @param num_words Number of actual key/value pairs to be stored + * statistics. + */ + +abstract class BigramWordDist( + unigramKeys: Array[String], + unigramValues: Array[Int], + num_unigrams: Int, + bigramKeys: Array[(String, String)], + bigramValues: Array[Int], + num_bigrams: Int +) extends WordDist { + + val unicounts = create_word_int_map() + for (i <- 0 until num_unigrams) + unicounts(memoize_string(unigramKeys(i))) = unigramValues(i) + var num_word_tokens : Double = unicounts.values.sum + + def num_word_types = unicounts.size + + val bicounts = create_word_int_map() + for (i <- 0 until num_bigrams) + bicounts(memoize_bigram(bigramKeys(i))) = bigramValues(i) + var num_bigram_tokens = bicounts.values.sum + + def num_bigram_types = bicounts.size + + def innerToString: String + + /** + * Memoize a bigram. The words should be encoded to remove ':' signs, + * just as they are stored in the document file. + */ + def memoize_bigram(bigram: (String, String)) = { + val word1 = bigram._1 + val word2 = bigram._2 + assert(!(word1 contains ':')) + assert(!(word2 contains ':')) + memoize_string(word1 + ":" + word2) + } + + override def toString = { + val finished_str = + if (!finished) ", unfinished" else "" + val num_words_to_print = 15 + def get_words(counts: WordIntMap) = { + val need_dots = counts.size > num_words_to_print + val items = + for ((word, count) <- + counts.toSeq.sortWith(_._2 > _._2).view(0, num_words_to_print)) + yield "%s=%s" format (unmemoize_string(word), count) + (items mkString " ") + (if (need_dots) " ..." 
else "") + } + "BigramWordDist(%d unigram types, %d unigram tokens, %d bigram types, %d bigram tokens%s%s, %s, %s)" format ( + num_word_types, num_word_tokens, num_bigram_types, num_bigram_tokens, + innerToString, finished_str, get_words(unicounts), get_words(bicounts)) + } + + protected def imp_add_document(words: Iterable[String], + ignore_case: Boolean, stopwords: Set[String]) { + errprint("add_document") + var previous = ""; + unicounts(memoize_string(previous)) += 1 + for {word <- words + val wlower = if (ignore_case) word.toLowerCase() else word + if !stopwords(wlower) } { + unicounts(memoize_string(wlower)) += 1 + num_word_tokens += 1 + bicounts(memoize_string(previous + "_" + wlower)) += 1 + previous = wlower + } + } + + protected def imp_add_word_distribution(xworddist: WordDist) { + val worddist = xworddist.asInstanceOf[BigramWordDist] + for ((word, count) <- worddist.unicounts) + unicounts(word) += count + for ((bigram, count) <- worddist.bicounts) + bicounts(bigram) += count + num_word_tokens += worddist.num_word_tokens + num_bigram_tokens += worddist.num_bigram_tokens + if (debug("bigram")) + errprint("add_word_distribution: " + num_word_tokens + " " + + num_bigram_tokens) + } + + protected def imp_finish_before_global(minimum_word_count: Int) { + // make sure counts not null (eg article in coords file but not counts file) + if (unicounts == null || bicounts == null) return + + // If 'minimum_word_count' was given, then eliminate words whose count + // is too small. + if (minimum_word_count > 1) { + for ((word, count) <- unicounts if count < minimum_word_count) { + num_word_tokens -= count + unicounts -= word + } + for ((bigram, count) <- bicounts if count < minimum_word_count) { + num_bigram_tokens -= count + bicounts -= bigram + } + } + } + + /** + * This is a basic unigram implementation of the computation of the + * KL-divergence between this distribution and another distribution. + * Useful for checking against other, faster implementations. + * + * Computing the KL divergence is a bit tricky, especially in the + * presence of smoothing, which assigns probabilities even to words not + * seen in either distribution. We have to take into account: + * + * 1. Words in this distribution (may or may not be in the other). + * 2. Words in the other distribution that are not in this one. + * 3. Words in neither distribution but seen globally. + * 4. Words never seen at all. + * + * The computation of steps 3 and 4 depends heavily on the particular + * smoothing algorithm; in the absence of smoothing, these steps + * contribute nothing to the overall KL-divergence. + * + * @param xother The other distribution to compute against. + * @param partial If true, only do step 1 above. + * + * @return A tuple of (divergence, word_contribs) where the first + * value is the actual KL-divergence and the second is the map + * of word contributions as described above; will be null if + * not requested. + */ + protected def imp_kl_divergence(xother: WordDist, partial: Boolean) = { + val other = xother.asInstanceOf[BigramWordDist] + var kldiv = 0.0 + //val contribs = + // if (return_contributing_words) mutable.Map[Word, Double]() else null + // 1. 
+ for (word <- bicounts.keys) { + val p = lookup_word(word) + val q = other.lookup_word(word) + if (p <= 0.0 || q <= 0.0) + errprint("Warning: problematic values: p=%s, q=%s, word=%s", p, q, word) + else { + kldiv += p*(log(p) - log(q)) + if (debug("bigram")) + errprint("kldiv1: " + kldiv + " :p: " + p + " :q: " + q) + //if (return_contributing_words) + // contribs(word) = p*(log(p) - log(q)) + } + } + + if (partial) + kldiv + else { + // Step 2. + for (word <- other.bicounts.keys if !(bicounts contains word)) { + val p = lookup_bigram(word) + val q = other.lookup_bigram(word) + kldiv += p*(log(p) - log(q)) + if (debug("bigram")) + errprint("kldiv2: " + kldiv + " :p: " + p + " :q: " + q) + //if (return_contributing_words) + // contribs(word) = p*(log(p) - log(q)) + } + + val retval = kldiv + kl_divergence_34(other) + //(retval, contribs) + retval + } + } + + /** + * Steps 3 and 4 of KL-divergence computation. + * @see #kl_divergence + */ + def kl_divergence_34(other: BigramWordDist): Double + + def get_nbayes_logprob(xworddist: WordDist) = { + val worddist = xworddist.asInstanceOf[BigramWordDist] + var logprob = 0.0 + for ((word, count) <- worddist.bicounts) { + val value = lookup_bigram(word) + if (value <= 0) { + // FIXME: Need to figure out why this happens (perhaps the word was + // never seen anywhere in the training data? But I thought we have + // a case to handle that) and what to do instead. + errprint("Warning! For word %s, prob %s out of range", word, value) + } else + logprob += log(value) + } + // FIXME: Also use baseline (prior probability) + logprob + } + + def lookup_bigram(word: Word): Double + + def find_most_common_word(pred: String => Boolean) = { + val filtered = + (for ((word, count) <- unicounts if pred(unmemoize_string(word))) + yield (word, count)).toSeq + if (filtered.length == 0) None + else { + val (maxword, maxcount) = filtered maxBy (_._2) + Some(maxword) + } + } +} + +trait SimpleBigramWordDistConstructor extends WordDistConstructor { + /** + * Initial size of the internal DynamicArray objects; an optimization. + */ + protected val initial_dynarr_size = 1000 + /** + * Internal DynamicArray holding the keys (canonicalized words). + */ + protected val keys_dynarr = + new DynamicArray[String](initial_alloc = initial_dynarr_size) + /** + * Internal DynamicArray holding the values (word counts). + */ + protected val values_dynarr = + new DynamicArray[Int](initial_alloc = initial_dynarr_size) + /** + * Set of the raw, uncanonicalized words seen, to check that an + * uncanonicalized word isn't seen twice. (Canonicalized words may very + * well occur multiple times.) + */ + protected val raw_keys_set = mutable.Set[String]() + /** Same as `keys_dynarr`, for bigrams. */ + protected val bigram_keys_dynarr = + new DynamicArray[(String, String)](initial_alloc = initial_dynarr_size) + /** Same as `values_dynarr`, for bigrams. */ + protected val bigram_values_dynarr = + new DynamicArray[Int](initial_alloc = initial_dynarr_size) + /** Same as `raw_keys_set`, for bigrams. */ + protected val raw_bigram_keys_set = mutable.Set[(String, String)]() + + /** + * Called each time a unigram word count is seen. This can accept or + * reject the word (e.g. based on whether the count is high enough or + * the word is in a stopwords list), and optionally change the word into + * something else (e.g. the lowercased version or a generic -OOV-). + * + * @param doc Document whose distribution is being initialized. + * @param word Raw word seen. 
NOTE: Unlike in the case of UnigramWordDist, + * the word is still encoded in the way it's found in the document file. + * (In particular, this means that colons are replaced with %3A, and + * percent signs by %25.) + * @param count Raw count for the word + * @return A modified form of the word, or None to reject the word. + */ + def canonicalize_accept_unigram(doc: GenericDistDocument, + word: String, count: Int): Option[String] + + /** + * Called each time a bigram word count is seen. This can accept or + * reject the bigram, much like for `canonicalize_accept_unigram`. + * + * @param doc Document whose distribution is being initialized. + * @param bigram Tuple of `(word1, word2)` describing bigram. + * @param count Raw count for the bigram + * @return A modified form of the bigram, or None to reject the bigram. + */ + def canonicalize_accept_bigram(doc: GenericDistDocument, + bigram: (String, String), count: Int): Option[(String, String)] + + def parse_counts(doc: GenericDistDocument, countstr: String) { + keys_dynarr.clear() + values_dynarr.clear() + raw_keys_set.clear() + bigram_keys_dynarr.clear() + bigram_values_dynarr.clear() + raw_bigram_keys_set.clear() + val wordcounts = countstr.split(" ") + for (wordcount <- wordcounts) yield { + val split_wordcount = wordcount.split(":", -1) + def check_nonempty_word(word: String) { + if (word.length == 0) + throw FileFormatException( + "For unigram/bigram counts, WORD must not be empty, but %s seen" + format wordcount) + } + if (split_wordcount.length == 2) { + val Array(word, strcount) = split_wordcount + check_nonempty_word(word) + val count = strcount.toInt + if (raw_keys_set contains word) + throw FileFormatException( + "Word %s seen twice in same counts list" format word) + raw_keys_set += word + val opt_canon_word = canonicalize_accept_unigram(doc, word, count) + if (opt_canon_word != None) { + keys_dynarr += opt_canon_word.get + values_dynarr += count + } + } else if (split_wordcount.length == 3) { + val Array(word1, word2, strcount) = split_wordcount + check_nonempty_word(word1) + check_nonempty_word(word2) + val word = (word1, word2) + val count = strcount.toInt + if (raw_bigram_keys_set contains word) + throw FileFormatException( + "Word %s seen twice in same counts list" format word) + raw_bigram_keys_set += word + val opt_canon_word = canonicalize_accept_bigram(doc, word, count) + if (opt_canon_word != None) { + bigram_keys_dynarr += opt_canon_word.get + bigram_values_dynarr += count + } + } else + throw FileFormatException( + "For bigram counts, items must be of the form WORD:COUNT or WORD:WORD:COUNT, but %s seen" + format wordcount) + } + } + + def initialize_distribution(doc: GenericDistDocument, countstr: String, + is_training_set: Boolean) { + parse_counts(doc, countstr) + // Now set the distribution on the document; but don't use the test + // set's distributions in computing global smoothing values and such. + set_bigram_word_dist(doc, keys_dynarr.array, values_dynarr.array, + keys_dynarr.length, bigram_keys_dynarr.array, + bigram_values_dynarr.array, bigram_keys_dynarr.length, + is_training_set) + } + + def set_bigram_word_dist(doc: GenericDistDocument, + keys: Array[String], values: Array[Int], num_words: Int, + bigram_keys: Array[(String, String)], bigram_values: Array[Int], + num_bigrams: Int, is_training_set: Boolean) +} + +/** + * General factory for BigramWordDist distributions. 
+ */ +abstract class BigramWordDistFactory extends + WordDistFactory with SimpleBigramWordDistConstructor { + def canonicalize_word(doc: GenericDistDocument, word: String) = { + val lword = maybe_lowercase(doc, word) + if (!is_stopword(doc, lword)) lword else "-OOV-" + } + + def canonicalize_accept_unigram(doc: GenericDistDocument, raw_word: String, + count: Int) = { + val lword = maybe_lowercase(doc, raw_word) + /* minimum_word_count (--minimum-word-count) currently handled elsewhere. + FIXME: Perhaps should be handled here. */ + if (!is_stopword(doc, lword)) + Some(lword) + else + None + } + + def canonicalize_accept_bigram(doc: GenericDistDocument, + raw_bigram: (String, String), count: Int) = { + val cword1 = canonicalize_word(doc, raw_bigram._1) + val cword2 = canonicalize_word(doc, raw_bigram._2) + /* minimum_word_count (--minimum-word-count) currently handled elsewhere. + FIXME: Perhaps should be handled here. */ + Some((cword1, cword1)) + } +} + diff --git a/src/main/scala/opennlp/fieldspring/worddist/DirichletUnigramWordDist.scala b/src/main/scala/opennlp/fieldspring/worddist/DirichletUnigramWordDist.scala new file mode 100644 index 0000000..a72edbc --- /dev/null +++ b/src/main/scala/opennlp/fieldspring/worddist/DirichletUnigramWordDist.scala @@ -0,0 +1,47 @@ +/////////////////////////////////////////////////////////////////////////////// +// DirichletUnigramWordDist.scala +// +// Copyright (C) 2010, 2011, 2012 Ben Wing, The University of Texas at Austin +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +/////////////////////////////////////////////////////////////////////////////// + +package opennlp.fieldspring.worddist + +/** + * This class implements Dirichlet discounting, where the discount factor + * depends on the size of the document. + */ +class DirichletUnigramWordDistFactory( + interpolate_string: String, + val dirichlet_factor: Double + ) extends DiscountedUnigramWordDistFactory(interpolate_string != "no") { + def create_word_dist(note_globally: Boolean) = + new DirichletUnigramWordDist(this, note_globally) +} + +class DirichletUnigramWordDist( + factory: WordDistFactory, + note_globally: Boolean +) extends DiscountedUnigramWordDist( + factory, note_globally + ) { + override protected def imp_finish_after_global() { + unseen_mass = 1.0 - + (model.num_tokens.toDouble / + (model.num_tokens + + factory.asInstanceOf[DirichletUnigramWordDistFactory]. 
+ dirichlet_factor)) + super.imp_finish_after_global() + } +} diff --git a/src/main/scala/opennlp/fieldspring/worddist/DiscountedUnigramWordDist.scala b/src/main/scala/opennlp/fieldspring/worddist/DiscountedUnigramWordDist.scala new file mode 100644 index 0000000..518660d --- /dev/null +++ b/src/main/scala/opennlp/fieldspring/worddist/DiscountedUnigramWordDist.scala @@ -0,0 +1,365 @@ +/////////////////////////////////////////////////////////////////////////////// +// DiscountedUnigramWordDist.scala +// +// Copyright (C) 2010, 2011, 2012 Ben Wing, The University of Texas at Austin +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +/////////////////////////////////////////////////////////////////////////////// + +package opennlp.fieldspring.worddist + +import math._ + +import opennlp.fieldspring.util.collectionutil._ +import opennlp.fieldspring.util.printutil.errprint + +import opennlp.fieldspring.gridlocate.GridLocateDriver.Debug._ +// FIXME! For --tf-idf +import opennlp.fieldspring.gridlocate.GridLocateDriver +import opennlp.fieldspring.gridlocate.GenericTypes._ + +import WordDist.memoizer._ + +abstract class DiscountedUnigramWordDistFactory( + val interpolate: Boolean + ) extends UnigramWordDistFactory { + // Estimate of number of unseen word types for all documents + var total_num_unseen_word_types = 0 + + /** + * Overall probabilities over all documents of seeing a word in a document, + * for all words seen at least once in any document, computed using the + * empirical frequency of a word among all documents, adjusted by the mass + * to be assigned to globally unseen words (words never seen at all), i.e. + * the value in 'globally_unseen_word_prob'. We start out by storing raw + * counts, then adjusting them. + */ + var overall_word_probs = create_word_double_map() + var owp_adjusted = false + var document_freq = create_word_double_map() + var num_documents = 0 + var global_normalization_factor = 0.0 + + override def note_dist_globally(dist: WordDist) { + val udist = dist.asInstanceOf[DiscountedUnigramWordDist] + super.note_dist_globally(dist) + if (dist.note_globally) { + assert(!owp_adjusted) + for ((word, count) <- udist.model.iter_items) { + if (!(overall_word_probs contains word)) + total_num_word_types += 1 + // Record in overall_word_probs; note more tokens seen. + overall_word_probs(word) += count + // Our training docs should never have partial (interpolated) counts. + assert (count == count.toInt) + total_num_word_tokens += count.toInt + // Note document frequency of word + document_freq(word) += 1 + } + num_documents += 1 + } + if (debug("lots")) { + errprint("""For word dist, total tokens = %s, unseen_mass = %s, overall unseen mass = %s""", + udist.model.num_tokens, udist.unseen_mass, udist.overall_unseen_mass) + } + } + + // The total probability mass to be assigned to words not seen at all in + // any document, estimated using Good-Turing smoothing as the unadjusted + // empirical probability of having seen a word once. No longer used at + // all in the "new way". 
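+  // (Concretely, the Good-Turing estimate of that mass is N1/N: the number
+  // of word types seen exactly once divided by the total number of tokens.)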
+ // var globally_unseen_word_prob = 0.0 + + // For documents whose word counts are not known, use an empty list to + // look up in. + // unknown_document_counts = ([], []) + + def finish_global_distribution() { + /* We do in-place conversion of counts to probabilities. Make sure + this isn't done twice!! */ + assert (!owp_adjusted) + owp_adjusted = true + // A holdout from the "old way". + val globally_unseen_word_prob = 0.0 + if (GridLocateDriver.Params.tf_idf) { + for ((word, count) <- overall_word_probs) + overall_word_probs(word) = + count*math.log(num_documents/document_freq(word)) + } + global_normalization_factor = ((overall_word_probs.values) sum) + for ((word, count) <- overall_word_probs) + overall_word_probs(word) = ( + count.toDouble/global_normalization_factor*(1.0 - globally_unseen_word_prob)) + } +} + +abstract class DiscountedUnigramWordDist( + gen_factory: WordDistFactory, + note_globally: Boolean +) extends UnigramWordDist(gen_factory, note_globally) { + type TThis = DiscountedUnigramWordDist + type TKLCache = DiscountedUnigramKLDivergenceCache + def dufactory = gen_factory.asInstanceOf[DiscountedUnigramWordDistFactory] + + /** Total probability mass to be assigned to all words not + seen in the document. This indicates how much mass to "discount" from + the unsmoothed (maximum-likelihood) estimated language model of the + document. This can be document-specific and is one of the two basic + differences between the smoothing methods investigated here: + + 1. Jelinek-Mercer uses a constant value. + 2. Dirichlet uses a value that is related to document length, getting + smaller as document length increases. + 3. Pseudo Good-Turing, motivated by Good-Turing smoothing, computes this + mass as the unadjusted empirical probability of having seen a word + once. + + The other difference is whether to do interpolation or back-off. + This only affects words that exist in the unsmoothed model. With + interpolation, we mix the unsmoothed and global models, using the + discount value. With back-off, we only use the unsmoothed model in + such a case, using the global model only for words unseen in the + unsmoothed model. + + In other words: + + 1. With interpolation, we compute the probability as + + COUNTS[W]/TOTAL_TOKENS*(1 - UNSEEN_MASS) + + UNSEEN_MASS * overall_word_probs[W] + + For unseen words, only the second term is non-zero. + + 2. With back-off, for words with non-zero MLE counts, we compute + the probability as + + COUNTS[W]/TOTAL_TOKENS*(1 - UNSEEN_MASS) + + For other words, we compute the probability as + + UNSEEN_MASS * (overall_word_probs[W] / OVERALL_UNSEEN_MASS) + + The idea is that overall_word_probs[W] / OVERALL_UNSEEN_MASS is + an estimate of p(W | W not in A). We have to divide by + OVERALL_UNSEEN_MASS to make these probabilities be normalized + properly. We scale p(W | W not in A) by the total probability mass + we have available for all words not seen in A. + + 3. An additional possibility is that we are asked the probability of + a word never seen at all. The old code I wrote tried to assign + a non-zero probability to such words using the formula + + UNSEEN_MASS * globally_unseen_word_prob / NUM_UNSEEN_WORDS + + where NUM_UNSEEN_WORDS is an estimate of the total number of words + "exist" but haven't been seen in any documents. Based on Good-Turing + motivation, we used the number of words seen once in any document. + This certainly underestimates this number if not too many documents + have been seen but might be OK if many documents seen. 
+ + The paper on this subject suggests assigning zero probability to + such words and ignoring them entirely in calculations if/when they + occur in a query. + */ + var unseen_mass = 0.5 + /** + Probability mass assigned in 'overall_word_probs' to all words not seen + in the document. This is 1 - (sum over W in A of overall_word_probs[W]). + See above. + */ + var overall_unseen_mass = 1.0 + + def innerToString = ", %.2f unseen mass" format unseen_mass + + var normalization_factor = 0.0 + + /** + * Here we compute the value of `overall_unseen_mass`, which depends + * on the global `overall_word_probs` computed from all of the + * distributions. + */ + protected def imp_finish_after_global() { + val factory = dufactory + + // Make sure that overall_word_probs has been computed properly. + assert(factory.owp_adjusted) + + if (factory.interpolate) + overall_unseen_mass = 1.0 + else + overall_unseen_mass = 1.0 - ( + // NOTE NOTE NOTE! The toSeq needs to be added for some reason; if not, + // the computation yields different values, which cause a huge loss of + // accuracy (on the order of 10-15%). I have no idea why; I suspect a + // Scala bug. (SCALABUG) (Or, it used to occur when the code read + // 'counts.keys.toSeq'; who knows now.) + (for (ind <- model.iter_keys.toSeq) + yield factory.overall_word_probs(ind)) sum) + if (GridLocateDriver.Params.tf_idf) { + for ((word, count) <- model.iter_items) + model.set_item(word, + count*log(factory.num_documents/factory.document_freq(word))) + } + normalization_factor = model.num_tokens + //if (use_sorted_list) + // counts = new SortedList(counts) + if (debug("discount-factor") || debug("discountfactor")) + errprint("For distribution %s, norm_factor = %g, model.num_tokens = %s, unseen_mass = %g" + format (this, normalization_factor, model.num_tokens, unseen_mass)) + } + + def fast_kl_divergence(cache: KLDivergenceCache, other: WordDist, + partial: Boolean = false) = { + FastDiscountedUnigramWordDist.fast_kl_divergence( + this.asInstanceOf[TThis], cache.asInstanceOf[TKLCache], + other.asInstanceOf[TThis], interpolate = dufactory.interpolate, + partial = partial) + } + + def cosine_similarity(other: WordDist, partial: Boolean = false, + smoothed: Boolean = false) = { + if (smoothed) + FastDiscountedUnigramWordDist.fast_smoothed_cosine_similarity( + this.asInstanceOf[TThis], other.asInstanceOf[TThis], + partial = partial) + else + FastDiscountedUnigramWordDist.fast_cosine_similarity( + this.asInstanceOf[TThis], other.asInstanceOf[TThis], + partial = partial) + } + + def kl_divergence_34(other: UnigramWordDist) = { + val factory = dufactory + var overall_probs_diff_words = 0.0 + for (word <- other.model.iter_keys if !(model contains word)) { + overall_probs_diff_words += factory.overall_word_probs(word) + } + + inner_kl_divergence_34(other.asInstanceOf[TThis], + overall_probs_diff_words) + } + + /** + * Actual implementation of steps 3 and 4 of KL-divergence computation, given + * a value that we may want to compute as part of step 2. + */ + def inner_kl_divergence_34(other: TThis, + overall_probs_diff_words: Double) = { + var kldiv = 0.0 + + // 3. 
For words seen in neither dist but seen globally: + // You can show that this is + // + // factor1 = (log(self.unseen_mass) - log(self.overall_unseen_mass)) - + // (log(other.unseen_mass) - log(other.overall_unseen_mass)) + // factor2 = self.unseen_mass / self.overall_unseen_mass * factor1 + // kldiv = factor2 * (sum(words seen globally but not in either dist) + // of overall_word_probs[word]) + // + // The final sum + // = 1 - sum(words in self) overall_word_probs[word] + // - sum(words in other, not self) overall_word_probs[word] + // = self.overall_unseen_mass + // - sum(words in other, not self) overall_word_probs[word] + // + // So we just need the sum over the words in other, not self. + // + // Note that the above formula was derived using back-off, but it + // still applies in interpolation. For words seen in neither dist, + // the only difference between back-off and interpolation is that + // the "overall_unseen_mass" factors for all distributions are + // effectively 1.0 (and the corresponding log terms above disappear). + + val factor1 = ((log(unseen_mass) - log(overall_unseen_mass)) - + (log(other.unseen_mass) - log(other.overall_unseen_mass))) + val factor2 = unseen_mass / overall_unseen_mass * factor1 + val the_sum = overall_unseen_mass - overall_probs_diff_words + kldiv += factor2 * the_sum + + // 4. For words never seen at all: + /* The new way ignores these words entirely. + val p = (unseen_mass*factory.globally_unseen_word_prob / + factory.total_num_unseen_word_types) + val q = (other.unseen_mass*factory.globally_unseen_word_prob / + factory.total_num_unseen_word_types) + kldiv += factory.total_num_unseen_word_types*(p*(log(p) - log(q))) + */ + kldiv + } + + def lookup_word(word: Word) = { + val factory = dufactory + assert(finished) + if (factory.interpolate) { + val wordcount = if (model contains word) model.get_item(word) else 0.0 + // if (debug("some")) { + // errprint("Found counts for document %s, num word types = %s", + // doc, wordcounts(0).length) + // errprint("Unknown prob = %s, overall_unseen_mass = %s", + // unseen_mass, overall_unseen_mass) + // } + val owprob = factory.overall_word_probs.getOrElse(word, 0.0) + val mle_wordprob = wordcount.toDouble/normalization_factor + val wordprob = mle_wordprob*(1.0 - unseen_mass) + owprob*unseen_mass + if (debug("lots")) + errprint("Word %s, seen in document, wordprob = %s", + unmemoize_string(word), wordprob) + wordprob + } else { + val retval = + if (!(model contains word)) { + factory.overall_word_probs.get(word) match { + case None => { + /* + The old way: + + val wordprob = (unseen_mass*factory.globally_unseen_word_prob + / factory.total_num_unseen_word_types) + */ + /* The new way: Just return 0 */ + val wordprob = 0.0 + if (debug("lots")) + errprint("Word %s, never seen at all, wordprob = %s", + unmemoize_string(word), wordprob) + wordprob + } + case Some(owprob) => { + val wordprob = unseen_mass * owprob / overall_unseen_mass + //if (wordprob <= 0) + // warning("Bad values; unseen_mass = %s, overall_word_probs[word] = %s, overall_unseen_mass = %s", + // unseen_mass, factory.overall_word_probs[word], + // factory.overall_unseen_mass) + if (debug("lots")) + errprint("Word %s, seen but not in document, wordprob = %s", + unmemoize_string(word), wordprob) + wordprob + } + } + } else { + val wordcount = model.get_item(word) + //if (wordcount <= 0 or model.num_tokens <= 0 or unseen_mass >= 1.0) + // warning("Bad values; wordcount = %s, unseen_mass = %s", + // wordcount, unseen_mass) + // for ((word, count) <- 
self.counts) + // errprint("%s: %s", word, count) + val wordprob = wordcount.toDouble/normalization_factor*(1.0 - unseen_mass) + if (debug("lots")) + errprint("Word %s, seen in document, wordprob = %s", + unmemoize_string(word), wordprob) + wordprob + } + retval + } + } +} + diff --git a/src/main/scala/opennlp/fieldspring/worddist/FastDiscountedUnigramWordDist.scala b/src/main/scala/opennlp/fieldspring/worddist/FastDiscountedUnigramWordDist.scala new file mode 100644 index 0000000..32492ce --- /dev/null +++ b/src/main/scala/opennlp/fieldspring/worddist/FastDiscountedUnigramWordDist.scala @@ -0,0 +1,321 @@ +/////////////////////////////////////////////////////////////////////////////// +// FastDiscountedUnigramWordDist.scala +// +// Copyright (C) 2010, 2011, 2012 Ben Wing, The University of Texas at Austin +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +/////////////////////////////////////////////////////////////////////////////// + +package opennlp.fieldspring.worddist + +import scala.collection.mutable +import math.{log, sqrt} + +import opennlp.fieldspring.util.collectionutil.DynamicArray + +import WordDist.memoizer.Word + +/** + Fast implementation of KL-divergence and cosine-similarity algorithms + for use with discounted smoothing. This code was originally written + for Pseudo-Good-Turing, but it isn't specific to this algorithm -- + it works for any algorithm that involves discounting of a distribution + in order to interpolate the global distribution. + + This code was originally broken out of WordDist (before the + pseudo-Good-Turing code was separated out) so that it could be + rewritten in .pyc (Python with extra 'cdef' annotations that can be + converted to C and compiled down to machine language). With the + pseudo-Good-Turing code extracted, this should properly be merged + into PseudoGoodTuringSmoothedWordDist.scala, but keep separated for + the moment in case we need to convert it to Java, C++, etc. + */ + +class DiscountedUnigramKLDivergenceCache( + val worddist: DiscountedUnigramWordDist + ) extends KLDivergenceCache { + val self_size = worddist.model.num_types + val self_keys = worddist.model.iter_keys.toArray + val self_values = worddist.model.iter_items.map { case (k,v) => v}.toArray +} + +object FastDiscountedUnigramWordDist { + type TDist = DiscountedUnigramWordDist + + def get_kl_divergence_cache(self: TDist) = + new DiscountedUnigramKLDivergenceCache(self) + + /* + A fast implementation of KL-divergence that uses inline lookups as much + as possible. Uses cached values if possible to avoid garbage from + copying arrays. + + In normal operation of grid location, we repeatedly do KL divergence + with the same `self` distribution and different `other` distributions, + and we have to iterate over all key/value pairs in `self`, so caching + the `self` keys and values into arrays is useful. 
`cache` can be + null (no cache available) or a cache created using + `get_kl_divergence_cache`, which must have been called on `self`, + and NO CHANGES to `self` made between cache creation time and use time. + */ + def fast_kl_divergence(self: TDist, + cache: DiscountedUnigramKLDivergenceCache, + other: TDist, interpolate: Boolean, partial: Boolean = false): Double = { + + val the_cache = + if (cache == null) + get_kl_divergence_cache(self) + else + cache + assert(the_cache.worddist == self) + assert(the_cache.self_size == self.model.num_types) + val pkeys = the_cache.self_keys + val pvalues = the_cache.self_values + val pfact = (1.0 - self.unseen_mass)/self.model.num_tokens + val qfact = (1.0 - other.unseen_mass)/other.model.num_tokens + val pfact_unseen = self.unseen_mass / self.overall_unseen_mass + val qfact_unseen = other.unseen_mass / other.overall_unseen_mass + val factory = self.dufactory + /* Not needed in the new way + val qfact_globally_unseen_prob = (other.unseen_mass* + factory.globally_unseen_word_prob / + factory.total_num_unseen_word_types) + */ + val owprobs = factory.overall_word_probs + val pmodel = self.model + val qmodel = other.model + + // 1. + + val psize = self.model.num_types + + // FIXME!! p * log(p) is the same for all calls of fast_kl_divergence + // on this item, so we could cache it. Not clear it would save much + // time, though. + var kldiv = 0.0 + /* THIS IS THE INSIDE LOOP. THIS IS THE CODE BOTTLENECK. THIS IS IT. + + This code needs to scream. Hence we do extra setup above involving + arrays, to avoid having a function call through a function + pointer (through the "obvious" use of forEach()). FIXME: But see + comment above. + + Note that HotSpot is good about inlining function calls. + Hence we can assume that the calls to apply() below (e.g. + qcounts(word)) will be inlined. However, it's *very important* + to avoid doing anything that creates objects each iteration, + and best to avoid creating objects per call to fast_kl_divergence(). + This object creation will kill us, as it will trigger tons + and tons of garbage collection. + + Recent HotSpot implementations (6.0 rev 14 and above) have "escape + analysis" that *might* make the object creation magically vanish, + but don't count on it. + */ + var i = 0 + if (interpolate) { + while (i < psize) { + val word = pkeys(i) + val pcount = pvalues(i) + val qcount = qmodel.get_item(word) + val owprob = owprobs(word) + val p = pcount * pfact + owprob * pfact_unseen + val q = qcount * qfact + owprob * qfact_unseen + /* In the "new way" we have to notice when a word was never seen + at all, and ignore it. */ + if (q > 0.0) { + //if (p == 0.0) + // errprint("Warning: zero value: p=%s q=%s word=%s pcount=%s qcount=%s qfact=%s qfact_unseen=%s owprobs=%s", + // p, q, word, pcount, qcount, qfact, qfact_unseen, + // owprobs(word)) + kldiv += p * (log(p) - log(q)) + } + i += 1 + } + } else { + while (i < psize) { + val word = pkeys(i) + val pcount = pvalues(i) + val p = pcount * pfact + val q = { + val qcount = qmodel.get_item(word) + if (qcount != 0) qcount * qfact + else { + val owprob = owprobs(word) + /* The old way: + if (owprob != 0.0) owprob * qfact_unseen + else qfact_globally_unseen_prob + */ + /* The new way: No need for a globally unseen probability. */ + owprob * qfact_unseen + } + } + /* However, in the "new way" we have to notice when a word was never + seen at all, and ignore it. 
*/ + if (q > 0.0) { + //if (q == 0.0) + // errprint("Strange: word=%s qfact_globally_unseen_prob=%s qcount=%s qfact=%s", + // word, qfact_globally_unseen_prob, qcount, qfact) + //if (p == 0.0 || q == 0.0) + // errprint("Warning: zero value: p=%s q=%s word=%s pcount=%s qcount=%s qfact=%s qfact_unseen=%s owprobs=%s", + // p, q, word, pcount, qcount, qfact, qfact_unseen, + // owprobs(word)) + kldiv += p * (log(p) - log(q)) + } + i += 1 + } + } + + if (partial) + return kldiv + + // 2. + var overall_probs_diff_words = 0.0 + for ((word, qcount) <- qmodel.iter_items if !(pmodel contains word)) { + val word_overall_prob = owprobs(word) + val p = word_overall_prob * pfact_unseen + val q = qcount * qfact + kldiv += p * (log(p) - log(q)) + overall_probs_diff_words += word_overall_prob + } + + return kldiv + self.inner_kl_divergence_34(other, overall_probs_diff_words) + } + + // The older implementation that uses smoothed probabilities. + + /** + A fast implementation of cosine similarity that uses Cython declarations + and inlines lookups as much as possible. It's always "partial" in that it + ignores words neither in P nor Q, despite the fact that they have non-zero + probability due to smoothing. But with parameter "partial" to true we + proceed as with KL-divergence and ignore words not in P. + */ + def fast_smoothed_cosine_similarity(self: TDist, other: TDist, + partial: Boolean = false): Double = { + val pfact = (1.0 - self.unseen_mass)/self.model.num_tokens + val qfact = (1.0 - other.unseen_mass)/other.model.num_tokens + val qfact_unseen = other.unseen_mass / other.overall_unseen_mass + val factory = self.dufactory + /* Not needed in the new way + val qfact_globally_unseen_prob = (other.unseen_mass* + factory.globally_unseen_word_prob / + factory.total_num_unseen_word_types) + */ + val owprobs = factory.overall_word_probs + val pmodel = self.model + val qmodel = other.model + + // 1. + + // FIXME!! Length of p is the same for all calls of fast_cosine_similarity + // on this item, so we could cache it. Not clear it would save much + // time, though. + var pqsum = 0.0 + var p2sum = 0.0 + var q2sum = 0.0 + for ((word, pcount) <- pmodel.iter_items) { + val p = pcount * pfact + val q = { + val qcount = qmodel.get_item(word) + val owprob = owprobs(word) + qcount * qfact + owprob * qfact_unseen + } + //if (q == 0.0) + // errprint("Strange: word=%s qfact_globally_unseen_prob=%s qcount=%s qfact=%s", + // word, qfact_globally_unseen_prob, qcount, qfact) + //if (p == 0.0 || q == 0.0) + // errprint("Warning: zero value: p=%s q=%s word=%s pcount=%s qcount=%s qfact=%s qfact_unseen=%s owprobs=%s", + // p, q, word, pcount, qcount, qfact, qfact_unseen, + // owprobs(word)) + pqsum += p * q + p2sum += p * p + q2sum += q * q + } + + if (partial) + return pqsum / (sqrt(p2sum) * sqrt(q2sum)) + + // 2. + val pfact_unseen = self.unseen_mass / self.overall_unseen_mass + var overall_probs_diff_words = 0.0 + for ((word, qcount) <- qmodel.iter_items if !(pmodel contains word)) { + val word_overall_prob = owprobs(word) + val p = word_overall_prob * pfact_unseen + val q = qcount * qfact + pqsum += p * q + p2sum += p * p + q2sum += q * q + //overall_probs_diff_words += word_overall_prob + } + + // FIXME: This would be the remainder of the computation for words + // neither in P nor Q. We did a certain amount of math in the case of the + // KL-divergence to make it possible to do these steps efficiently. + // Probably similar math could make the steps here efficient as well, but + // unclear. 
+ + //kldiv += self.kl_divergence_34(other, overall_probs_diff_words) + //return kldiv + + return pqsum / (sqrt(p2sum) * sqrt(q2sum)) + } + + // The newer implementation that uses unsmoothed probabilities. + + /** + A fast implementation of cosine similarity that uses Cython declarations + and inlines lookups as much as possible. It's always "partial" in that it + ignores words neither in P nor Q, despite the fact that they have non-zero + probability due to smoothing. But with parameter "partial" to true we + proceed as with KL-divergence and ignore words not in P. + */ + def fast_cosine_similarity(self: TDist, other: TDist, + partial: Boolean = false) = { + val pfact = 1.0/self.model.num_tokens + val qfact = 1.0/other.model.num_tokens + // 1. + val pmodel = self.model + val qmodel = other.model + + // FIXME!! Length of p is the same for all calls of fast_cosine_similarity + // on this item, so we could cache it. Not clear it would save much + // time, though. + var pqsum = 0.0 + var p2sum = 0.0 + var q2sum = 0.0 + for ((word, pcount) <- pmodel.iter_items) { + val p = pcount * pfact + val q = qmodel.get_item(word) * qfact + //if (q == 0.0) + // errprint("Strange: word=%s qfact_globally_unseen_prob=%s qcount=%s qfact=%s", + // word, qfact_globally_unseen_prob, qcount, qfact) + //if (p == 0.0 || q == 0.0) + // errprint("Warning: zero value: p=%s q=%s word=%s pcount=%s qcount=%s qfact=%s qfact_unseen=%s owprobs=%s", + // p, q, word, pcount, qcount, qfact, qfact_unseen, + // owprobs(word)) + pqsum += p * q + p2sum += p * p + q2sum += q * q + } + + // 2. + if (!partial) + for ((word, qcount) <- qmodel.iter_items if !(pmodel contains word)) { + val q = qcount * qfact + q2sum += q * q + } + + if (pqsum == 0.0) 0.0 else pqsum / (sqrt(p2sum) * sqrt(q2sum)) + } +} diff --git a/src/main/scala/opennlp/fieldspring/worddist/JelinekMercerUnigramWordDist.scala b/src/main/scala/opennlp/fieldspring/worddist/JelinekMercerUnigramWordDist.scala new file mode 100644 index 0000000..9a2086b --- /dev/null +++ b/src/main/scala/opennlp/fieldspring/worddist/JelinekMercerUnigramWordDist.scala @@ -0,0 +1,44 @@ +/////////////////////////////////////////////////////////////////////////////// +// JelinekMercerUnigramWordDist.scala +// +// Copyright (C) 2010, 2011, 2012 Ben Wing, The University of Texas at Austin +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +/////////////////////////////////////////////////////////////////////////////// + +package opennlp.fieldspring.worddist + +/** + * This class implements Jelinek-Mercer discounting, the simplest type of + * discounting where we just use a constant discount factor. 
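+ *
+ * Concretely (a sketch in the notation of DiscountedUnigramWordDist): with
+ * interpolation the probability of a word W becomes
+ *
+ *   COUNTS[W]/TOTAL_TOKENS * (1 - jelinek_factor)
+ *     + jelinek_factor * overall_word_probs[W]
+ *
+ * since JelinekMercerUnigramWordDist below simply sets unseen_mass to the
+ * constant jelinek_factor.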
+ */ +class JelinekMercerUnigramWordDistFactory( + interpolate_string: String, + val jelinek_factor: Double + ) extends DiscountedUnigramWordDistFactory(interpolate_string != "no") { + def create_word_dist(note_globally: Boolean) = + new JelinekMercerUnigramWordDist(this, note_globally) +} + +class JelinekMercerUnigramWordDist( + factory: WordDistFactory, + note_globally: Boolean +) extends DiscountedUnigramWordDist( + factory, note_globally + ) { + override protected def imp_finish_after_global() { + unseen_mass = (factory.asInstanceOf[JelinekMercerUnigramWordDistFactory]. + jelinek_factor) + super.imp_finish_after_global() + } +} diff --git a/src/main/scala/opennlp/fieldspring/worddist/Memoizer.scala b/src/main/scala/opennlp/fieldspring/worddist/Memoizer.scala new file mode 100644 index 0000000..b6e2e0f --- /dev/null +++ b/src/main/scala/opennlp/fieldspring/worddist/Memoizer.scala @@ -0,0 +1,167 @@ +/////////////////////////////////////////////////////////////////////////////// +// Memoizer.scala +// +// Copyright (C) 2011, 2012 Ben Wing, The University of Texas at Austin +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +/////////////////////////////////////////////////////////////////////////////// + +package opennlp.fieldspring.worddist + +import collection.mutable + +import com.codahale.trove.{mutable => trovescala} + +import opennlp.fieldspring.gridlocate.GridLocateDriver.Debug._ + +import opennlp.fieldspring.util.collectionutil._ +import opennlp.fieldspring.util.printutil.errprint + +/** + * A class for "memoizing" words, i.e. mapping them to some other type + * (e.g. Int) that should be faster to compare and potentially require + * less space. + */ +abstract class Memoizer { + /** + * The type of a memoized word. + */ + type Word + /** + * Map a word as a string to its memoized form. + */ + def memoize_string(word: String): Word + /** + * Map a word from its memoized form back to a string. + */ + def unmemoize_string(word: Word): String + + /** + * The type of a mutable map from memoized words to Ints. + */ + type WordIntMap + /** + * Create a mutable map from memoized words to Ints. + */ + def create_word_int_map(): WordIntMap + /** + * The type of a mutable map from memoized words to Doubles. + */ + type WordDoubleMap + /** + * Create a mutable map from memoized words to Doubles. + */ + def create_word_double_map(): WordDoubleMap + + lazy val blank_memoized_string = memoize_string("") + + def lowercase_memoized_word(word: Word) = + memoize_string(unmemoize_string(word).toLowerCase) +} + +/** + * The memoizer we actually use. Maps word strings to Ints. Uses Trove + * for extremely fast and memory-efficient hash tables, making use of the + * Trove-Scala interface for easy access to the Trove hash tables. + */ +class IntStringMemoizer extends Memoizer { + type Word = Int + val invalid_word: Word = 0 + + protected var next_word_count: Word = 1 + + // For replacing strings with ints. 
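The constant discount that JelinekMercerUnigramWordDist installs as its unseen mass is the classic Jelinek-Mercer interpolation weight. As a hedged illustration of that idea over plain maps (not the DiscountedUnigramWordDist machinery, whose exact back-off bookkeeping is defined elsewhere), the per-word probability is a fixed mix of the document-level and global estimates:

object JelinekMercerSketch {
  /** p(w) = (1 - lambda) * p_doc(w) + lambda * p_global(w), with a constant
    * interpolation weight lambda playing the role of the jelinek_factor
    * above. The maps and names here are illustrative stand-ins. */
  def interpolatedProb(word: String, lambda: Double,
      docCounts: Map[String, Double], globalProbs: Map[String, Double]): Double = {
    val docTokens = docCounts.values.sum
    val pDoc = if (docTokens > 0) docCounts.getOrElse(word, 0.0) / docTokens else 0.0
    val pGlobal = globalProbs.getOrElse(word, 0.0)
    (1.0 - lambda) * pDoc + lambda * pGlobal
  }

  def main(args: Array[String]) {
    val doc = Map("austin" -> 3.0, "river" -> 1.0)
    val global = Map("austin" -> 0.01, "river" -> 0.02, "city" -> 0.05)
    println(interpolatedProb("city", 0.3, doc, global)) // unseen word backs off to the global estimate
  }
}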
This should save space on 64-bit + // machines (string pointers are 8 bytes, ints are 4 bytes) and might + // also speed lookup. + //protected val word_id_map = mutable.Map[String,Word]() + protected val word_id_map = trovescala.ObjectIntMap[String]() + + // Map in the opposite direction. + //protected val id_word_map = mutable.Map[Word,String]() + protected val id_word_map = trovescala.IntObjectMap[String]() + + def memoize_string(word: String) = { + val index = word_id_map.getOrElse(word, 0) + if (index != 0) index + else { + val newind = next_word_count + next_word_count += 1 + word_id_map(word) = newind + id_word_map(newind) = word + newind + } + } + + def unmemoize_string(word: Word) = id_word_map(word) + + def create_word_int_map() = trovescala.IntIntMap() + type WordIntMap = trovescala.IntIntMap + def create_word_double_map() = trovescala.IntDoubleMap() + type WordDoubleMap = trovescala.IntDoubleMap +} + +/** + * The memoizer we actually use. Maps word strings to Ints. Uses Trove + * for extremely fast and memory-efficient hash tables, making use of the + * Trove-Scala interface for easy access to the Trove hash tables. + */ +class TestIntStringMemoizer extends IntStringMemoizer { + override def memoize_string(word: String) = { + val cur_nwi = next_word_count + val index = super.memoize_string(word) + if (debug("memoize")) { + if (next_word_count != cur_nwi) + errprint("Memoizing new string %s to ID %s", word, index) + else + errprint("Memoizing existing string %s to ID %s", word, index) + } + assert(super.unmemoize_string(index) == word) + index + } + + override def unmemoize_string(word: Word) = { + if (!(id_word_map contains word)) { + errprint("Can't find ID %s in id_word_map", word) + errprint("Word map:") + var its = id_word_map.toList.sorted + for ((key, value) <- its) + errprint("%s = %s", key, value) + assert(false, "Exiting due to bad code in unmemoize_string") + null + } else { + val string = super.unmemoize_string(word) + if (debug("memoize")) + errprint("Unmemoizing existing ID %s to string %s", word, string) + assert(super.memoize_string(string) == word) + string + } + } +} + +/** + * A memoizer for testing that doesn't actually do anything -- the memoized + * words are also strings. This tests that we don't make any assumptions + * about memoized words being Ints. + */ +object IdentityMemoizer extends Memoizer { + type Word = String + val invalid_word: Word = null + def memoize_string(word: String): Word = word + def unmemoize_string(word: Word): String = word + + type WordIntMap = mutable.Map[Word, Int] + def create_word_int_map() = intmap[Word]() + type WordDoubleMap = mutable.Map[Word, Double] + def create_word_double_map() = doublemap[Word]() +} + diff --git a/src/main/scala/opennlp/fieldspring/worddist/NgramWordDist.scala b/src/main/scala/opennlp/fieldspring/worddist/NgramWordDist.scala new file mode 100644 index 0000000..b92ea9f --- /dev/null +++ b/src/main/scala/opennlp/fieldspring/worddist/NgramWordDist.scala @@ -0,0 +1,463 @@ +/////////////////////////////////////////////////////////////////////////////// +// NgramWordDist.scala +// +// Copyright (C) 2010, 2011, 2012 Ben Wing, The University of Texas at Austin +// Copyright (C) 2012 Mike Speriosu, The University of Texas at Austin +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
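The IntStringMemoizer above relies on the Trove-Scala bindings for its hash tables. As a rough, dependency-free sketch of the same idea (bidirectional string-to-int interning), assuming plain mutable.HashMaps are acceptable for illustration:

import scala.collection.mutable

/** A minimal bidirectional string interner: each distinct string gets a
  * small positive Int ID, and IDs map back to strings. Plain HashMaps are
  * used purely for illustration; the real class uses Trove tables for
  * memory efficiency. */
class SimpleStringInterner {
  private val stringToId = mutable.Map[String, Int]()
  private val idToString = mutable.Map[Int, String]()
  private var nextId = 1 // 0 is reserved as the "invalid" ID, as above

  def memoize(word: String): Int =
    stringToId.getOrElseUpdate(word, {
      val id = nextId
      nextId += 1
      idToString(id) = word
      id
    })

  def unmemoize(id: Int): String = idToString(id)
}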
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +/////////////////////////////////////////////////////////////////////////////// + +package opennlp.fieldspring.worddist + +import math._ +import collection.mutable +import util.control.Breaks._ + +import java.io._ + +import opennlp.fieldspring.util.collectionutil.DynamicArray +import opennlp.fieldspring.util.textdbutil +import opennlp.fieldspring.util.ioutil.{FileHandler, FileFormatException} +import opennlp.fieldspring.util.printutil.{errprint, warning} + +import opennlp.fieldspring.gridlocate.GridLocateDriver.Debug._ +import opennlp.fieldspring.gridlocate.GenericTypes._ + +import WordDist.memoizer._ + +object NgramStorage { + type Ngram = Iterable[String] +} + +/** + * An interface for storing and retrieving ngrams. + */ +trait NgramStorage extends ItemStorage[NgramStorage.Ngram] { + type Ngram = NgramStorage.Ngram +} + +/** + * An implementation for storing and retrieving ngrams using OpenNLP. + */ +class OpenNLPNgramStorer extends NgramStorage { + + import opennlp.tools.ngram._ + import opennlp.tools.util.StringList + + val model = new NGramModel() + + /**************************** Abstract functions ***********************/ + + /** + * Add an n-gram with the given count. If the n-gram exists already, + * add the count to the existing value. + */ + def add_item(ngram: Ngram, count: Double) { + if (count != count.toInt) + throw new IllegalArgumentException( + "Partial count %s not allowed in this class" format count) + val sl_ngram = new StringList(ngram.toSeq: _*) + /* OpenNLP only lets you either add 1 to a possibly non-existing n-gram + or set the count of an existing n-gram. */ + model.add(sl_ngram) + if (count != 1) { + val existing_count = model.getCount(sl_ngram) + model.setCount(sl_ngram, existing_count + count.toInt - 1) + } + } + + def set_item(ngram: Ngram, count: Double) { + if (count != count.toInt) + throw new IllegalArgumentException( + "Partial count %s not allowed in this class" format count) + val sl_ngram = new StringList(ngram.toSeq: _*) + /* OpenNLP only lets you either add 1 to a possibly non-existing n-gram + or set the count of an existing n-gram. */ + model.add(sl_ngram) + model.setCount(sl_ngram, count.toInt) + } + + /** + * Remove an n-gram, if it exists. + */ + def remove_item(ngram: Ngram) { + val sl_ngram = new StringList(ngram.toSeq: _*) + model.remove(sl_ngram) + } + + /** + * Return whether a given n-gram is stored. + */ + def contains(ngram: Ngram) = { + val sl_ngram = new StringList(ngram.toSeq: _*) + model.contains(sl_ngram) + } + + /** + * Return whether a given n-gram is stored. + */ + def get_item(ngram: Ngram) = { + val sl_ngram = new StringList(ngram.toSeq: _*) + model.getCount(sl_ngram) + } + + /** + * Iterate over all n-grams that are stored. + */ + def iter_keys = { + import collection.JavaConversions._ + // Iterators suck. Should not be exposed directly. + // Actually you can iterate over the model without using `iterator` + // but it doesn't appear to work right -- it generates the entire + // list before iterating over it. Doing it as below iterates over + // the list as it's generated. 
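The OpenNLP-backed storer above has to work around NGramModel's add/setCount interface and cannot hold fractional counts. A plain-Scala alternative keyed directly on token sequences might look like the following sketch; it mirrors the ItemStorage operations but is illustrative, not part of the diff.

import scala.collection.mutable

/** Illustrative n-gram store keyed on token sequences, allowing the
  * "partial" (fractional) counts that the OpenNLP model cannot hold. */
class MapNgramStore {
  type Ngram = Seq[String]
  private val counts = mutable.Map[Ngram, Double]().withDefaultValue(0.0)

  def addItem(ngram: Ngram, count: Double) { counts(ngram) += count }
  def setItem(ngram: Ngram, count: Double) { counts(ngram) = count }
  def removeItem(ngram: Ngram) { counts -= ngram }
  def contains(ngram: Ngram) = counts contains ngram
  def getItem(ngram: Ngram) = counts(ngram)
  def iterItems = counts.toIterable
  def numTokens = counts.values.sum
  def numTypes = counts.size
}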
+ for (x <- model.iterator.toIterable) + yield x.iterator.toIterable + } + + /** + * Iterate over all n-grams that are stored. + */ + def iter_items = { + import collection.JavaConversions._ + // Iterators suck. Should not be exposed directly. + // Actually you can iterate over the model without using `iterator` + // but it doesn't appear to work right -- it generates the entire + // list before iterating over it. Doing it as below iterates over + // the list as it's generated. + for (x <- model.iterator.toIterable) + yield (x.iterator.toIterable, model.getCount(x).toDouble) + } + + /** + * Total number of tokens stored. + */ + def num_tokens = model.numberOfGrams.toDouble + + /** + * Total number of n-gram types (i.e. number of distinct n-grams) + * stored for n-grams of size `len`. + */ + def num_types = model.size +} + +/** + * An implementation for storing and retrieving ngrams using OpenNLP. + */ +//class SimpleNgramStorer extends NgramStorage { +// +// /** +// * A sequence of separate maps (or possibly a "sorted list" of tuples, +// * to save memory?) of (ngram, count) items, specifying the counts of +// * all ngrams seen at least once. Each map contains the ngrams for +// * a particular value of N. The map for N = 0 is unused. +// * +// * These are given as double because in some cases they may store "partial" +// * counts (in particular, when the K-d tree code does interpolation on +// * cells). FIXME: This seems ugly, perhaps there is a better way? +// * +// * FIXME: Currently we store ngrams using a Word, which is simply an index +// * into a string table, where the ngram is shoehorned into a string using +// * the format of the Fieldspring corpus (words separated by colons, with +// * URL-encoding for embedded colons). This is very wasteful of space, +// * and inefficient too with extra encodings/decodings. We need a better +// * implementation with a trie and such. +// */ +// var counts = Vector(null, create_word_double_map()) +// var num_tokens = Vector(0.0, 0.0) +// def max_ngram_size = counts.length - 1 +// // var num_types = Vector(0) +// def num_types(n: Int) = counts(n).size +// +// def num_word_tokens = num_tokens(1) +// def num_word_types = num_types(1) +// +// def innerToString: String +// +// /** +// * Ensure that we will be able to store an N-gram for a given N. +// */ +// def ensure_ngram_fits(n: Int) { +// assert(counts.length == num_tokens.length) +// while (n > max_ngram_size) { +// counts :+= create_word_double_map() +// num_tokens :+= 0.0 +// } +// } +// +// /** +// * Record an n-gram (encoded into a string) and associated count. +// */ +// def record_encoded_ngram(egram: String, count: Int) { +// val n = egram.count(_ == ':') + 1 +// val mgram = memoize_string(egram) +// ensure_ngram_fits(n) +// counts(n)(mgram) += count +// num_tokens(n) += count +// } +// +//} + +/** + * Basic n-gram word distribution with tables listing counts for each n-gram. + * This is an abstract class because the smoothing algorithm (how to return + * probabilities for a given n-gram) isn't specified. This class takes care + * of storing the n-grams. + * + * FIXME: For unigrams storage is less of an issue, but for n-grams we + * may have multiple storage implementations, potentially swappable (i.e. + * we can specify them separately from e.g. the smoothing method of other + * properties). This suggests that at some point we may find it useful + * to outsource the storage implementation/management to another class. 
+ * + * Some terminology: + * + * Name Short name Scala type + * ------------------------------------------------------------------------- + * Unencoded, unmemoized ngram ngram Iterable[String] + * Unencoded ngram, memoized words mwgram Iterable[Word] + * Encoded, unmemoized ngram egram String + * Encoded, memoized ngram mgram Word + * + * Note that "memoizing", in its general sense, simply converts an ngram + * into a single number (of type "Word") to identify the ngram. There is + * no theoretical reason why we have to encode the ngram into a single string + * and then "memoize" the encoded string in this fashion. Memoizing could + * involve e.g. creating some sort of trie, with a pointer to the appropriate + * memory location (or index) in the trie serving as the value returned back + * after memoization. + * + * @param factory A `WordDistFactory` object used to create this distribution. + * The object also stores global properties of various sorts (e.g. for + * smothing). + * @param note_globally Whether n-grams added to this distribution should + * have an effect on the global statistics stored in the factory. + */ + +abstract class NgramWordDist( + factory: WordDistFactory, + note_globally: Boolean + ) extends WordDist(factory, note_globally) with FastSlowKLDivergence { + import NgramStorage.Ngram + type Item = Ngram + val model = new OpenNLPNgramStorer + + def innerToString: String + + override def toString = { + val finished_str = + if (!finished) ", unfinished" else "" + val num_words_to_print = 15 + val need_dots = model.num_types > num_words_to_print + val items = + for ((word, count) <- (model.iter_items.toSeq.sortWith(_._2 > _._2). + view(0, num_words_to_print))) + yield "%s=%s" format (word mkString " ", count) + val words = (items mkString " ") + (if (need_dots) " ..." else "") + "NgramWordDist(%d types, %s tokens%s%s, %s)" format ( + model.num_types, model.num_tokens, innerToString, finished_str, words) + } + + def lookup_ngram(ngram: Ngram): Double + + /** + * This is a basic unigram implementation of the computation of the + * KL-divergence between this distribution and another distribution, + * including possible debug information. + * + * Computing the KL divergence is a bit tricky, especially in the + * presence of smoothing, which assigns probabilities even to words not + * seen in either distribution. We have to take into account: + * + * 1. Words in this distribution (may or may not be in the other). + * 2. Words in the other distribution that are not in this one. + * 3. Words in neither distribution but seen globally. + * 4. Words never seen at all. + * + * The computation of steps 3 and 4 depends heavily on the particular + * smoothing algorithm; in the absence of smoothing, these steps + * contribute nothing to the overall KL-divergence. + * + */ + def slow_kl_divergence_debug(xother: WordDist, partial: Boolean = false, + return_contributing_words: Boolean = false): + (Double, collection.Map[String, Double]) = { + assert(false, "Not implemented") + (0.0, null) + } + + /** + * Steps 3 and 4 of KL-divergence computation. + * @see #slow_kl_divergence_debug + */ + def kl_divergence_34(other: NgramWordDist): Double + + def get_nbayes_logprob(xworddist: WordDist) = { + assert(false, "Not implemented") + 0.0 + } + + def find_most_common_word(pred: String => Boolean): Option[Word] = { + assert(false, "Not implemented") + None + } +} + +/** + * Class for constructing an n-gram word distribution from a corpus or other + * data source. 
+ */ +class DefaultNgramWordDistConstructor( + factory: WordDistFactory, + ignore_case: Boolean, + stopwords: Set[String], + whitelist: Set[String], + minimum_word_count: Int = 1, + max_ngram_length: Int = 3 +) extends WordDistConstructor(factory: WordDistFactory) { + import NgramStorage.Ngram + /** + * Internal map holding the encoded ngrams and counts. + */ + protected val parsed_ngrams = mutable.Map[String, Int]() + + /** + * Given a field containing the encoded representation of the n-grams of a + * document, parse and store internally. + */ + protected def parse_counts(countstr: String) { + parsed_ngrams.clear() + val ngramcounts = countstr.split(" ") + for (ngramcount <- ngramcounts) yield { + val (egram, count) = textdbutil.shallow_split_count_map_field(ngramcount) + if (parsed_ngrams contains egram) + throw FileFormatException( + "Ngram %s seen twice in same counts list" format egram) + parsed_ngrams(egram) = count + } + } + + var seen_documents = new scala.collection.mutable.HashSet[String]() + + /** + * Returns true if the n-gram was counted, false if it was ignored due to + * stoplisting and/or whitelisting. */ + protected def add_ngram_with_count(dist: NgramWordDist, + ngram: Ngram, count: Int): Boolean = { + val lgram = maybe_lowercase(ngram) + // FIXME: Not right with stopwords or whitelist + //if (!stopwords.contains(lgram) && + // (whitelist.size == 0 || whitelist.contains(lgram))) { + // dist.add_item(lgram, count) + // true + //} + //else + // false + dist.model.add_item(lgram, count) + return true + } + + protected def imp_add_document(gendist: WordDist, + words: Iterable[String], max_ngram_length: Int) { + val dist = gendist.asInstanceOf[NgramWordDist] + for (ngram <- (1 to max_ngram_length).flatMap(words.sliding(_))) + add_ngram_with_count(dist, ngram, 1) + } + + protected def imp_add_document(gendist: WordDist, + words: Iterable[String]) { + imp_add_document(gendist, words, max_ngram_length) + } + + protected def imp_add_word_distribution(gendist: WordDist, + genother: WordDist, partial: Double) { + // FIXME: Implement partial! + val dist = gendist.asInstanceOf[NgramWordDist].model + val other = genother.asInstanceOf[NgramWordDist].model + for ((ngram, count) <- other.iter_items) + dist.add_item(ngram, count) + } + + /** + * Incorporate a set of (key, value) pairs into the distribution. + * The number of pairs to add should be taken from `num_ngrams`, not from + * the actual length of the arrays passed in. The code should be able + * to handle the possibility that the same ngram appears multiple times, + * adding up the counts for each appearance of the ngram. 
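The words.sliding(n) idiom used in imp_add_document above is worth seeing in isolation. This small sketch shows how a token sequence expands into all n-grams up to a maximum length; counts are accumulated in a plain map purely for illustration.

object NgramExtractionSketch {
  /** All n-grams of length 1..maxLen from a token sequence, with counts. */
  def countNgrams(words: Seq[String], maxLen: Int): Map[Seq[String], Int] =
    (1 to maxLen)
      .flatMap(n => words.sliding(n).filter(_.length == n)) // guard against short inputs
      .groupBy(identity)
      .mapValues(_.size)
      .toMap

  def main(args: Array[String]) {
    val toks = Seq("the", "colorado", "river", "in", "austin")
    countNgrams(toks, 3).toSeq.sortBy(-_._2) foreach println
  }
}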
+ */ + protected def add_parsed_ngrams(gendist: WordDist, + grams: collection.Map[String, Int]) { + val dist = gendist.asInstanceOf[NgramWordDist] + assert(!dist.finished) + assert(!dist.finished_before_global) + var addedTypes = 0 + var addedTokens = 0 + var totalTokens = 0 + for ((egram, count) <- grams) { + val ngram = textdbutil.decode_ngram_for_count_map_field(egram) + if (add_ngram_with_count(dist, ngram, count)) { + addedTypes += 1 + addedTokens += count + } + totalTokens += count + } + // Way too much output to keep enabled + //errprint("Fraction of word types kept:"+(addedTypes.toDouble/num_ngrams)) + //errprint("Fraction of word tokens kept:"+(addedTokens.toDouble/totalTokens)) + } + + protected def imp_finish_before_global(dist: WordDist) { + val model = dist.asInstanceOf[NgramWordDist].model + val oov = Seq("-OOV-") + + /* Add the distribution to the global stats before eliminating + infrequent words. */ + factory.note_dist_globally(dist) + + // If 'minimum_word_count' was given, then eliminate n-grams whose count + // is too small by replacing them with -OOV-. + // FIXME!!! This should almost surely operate at the word level, not the + // n-gram level. + if (minimum_word_count > 1) { + for ((ngram, count) <- model.iter_items if count < minimum_word_count) { + model.remove_item(ngram) + model.add_item(oov, count) + } + } + } + + def maybe_lowercase(ngram: Ngram) = + if (ignore_case) ngram.map(_ toLowerCase) else ngram + + def initialize_distribution(doc: GenericDistDocument, countstr: String, + is_training_set: Boolean) { + parse_counts(countstr) + // Now set the distribution on the document; but don't use the test + // set's distributions in computing global smoothing values and such. + // + // FIXME: What is the purpose of first_time_document_seen??? When does + // it occur that we see a document multiple times? + var first_time_document_seen = !seen_documents.contains(doc.title) + + val dist = factory.create_word_dist(note_globally = + is_training_set && first_time_document_seen) + add_parsed_ngrams(dist, parsed_ngrams) + seen_documents += doc.title + doc.dist = dist + } +} + +/** + * General factory for NgramWordDist distributions. + */ +abstract class NgramWordDistFactory extends WordDistFactory { } diff --git a/src/main/scala/opennlp/fieldspring/worddist/PseudoGoodTuringBigramWordDist.scala.bitrotted b/src/main/scala/opennlp/fieldspring/worddist/PseudoGoodTuringBigramWordDist.scala.bitrotted new file mode 100644 index 0000000..c829562 --- /dev/null +++ b/src/main/scala/opennlp/fieldspring/worddist/PseudoGoodTuringBigramWordDist.scala.bitrotted @@ -0,0 +1,282 @@ +/////////////////////////////////////////////////////////////////////////////// +// Copyright (C) 2011 Ben Wing, The University of Texas at Austin +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +/////////////////////////////////////////////////////////////////////////////// + +//////// +//////// PseudoGoodTuringBigramWordDist.scala +//////// +//////// Copyright (c) 2011 Ben Wing. 
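The -OOV- folding in imp_finish_before_global above can be pictured on a plain count map as follows. This is a sketch, not the project's ItemStorage API, and it folds whole n-grams, which the FIXME above notes should arguably happen at the word level instead.

object OovFoldingSketch {
  /** Replace every item whose count falls below `minCount` with a single
    * -OOV- bucket, preserving the total token count. */
  def foldRareItems(counts: Map[String, Double], minCount: Double,
      oov: String = "-OOV-"): Map[String, Double] = {
    val (kept, rare) = counts.partition { case (_, c) => c >= minCount }
    if (rare.isEmpty) kept
    else kept + (oov -> (kept.getOrElse(oov, 0.0) + rare.values.sum))
  }

  def main(args: Array[String]) {
    val counts = Map("austin" -> 5.0, "waterloo" -> 1.0, "colorado" -> 2.0)
    println(foldRareItems(counts, minCount = 2.0))
    // Map(austin -> 5.0, colorado -> 2.0, -OOV- -> 1.0)
  }
}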
+//////// + +package opennlp.fieldspring.worddist + +import math._ +import collection.mutable +import util.control.Breaks._ + +import java.io._ + +import opennlp.fieldspring.util.collectionutil._ +import opennlp.fieldspring.util.printutil.errprint + +import opennlp.fieldspring.gridlocate.GridLocateDriver.Debug._ +import opennlp.fieldspring.gridlocate.GenericTypes._ + +import WordDist.memoizer._ + +/** + * This class implements a bigram version of the abstract factory for the + * simple Good-Turing code in PseudoGoodTuringSmoothedWordDist.scala. + */ +class PseudoGoodTuringBigramWordDistFactory extends BigramWordDistFactory { + // Total number of types seen once + var total_num_types_seen_once = 0 + + // Estimate of number of unseen word types for all articles + var total_num_unseen_word_types = 0 + + /** + * Overall probabilities over all articles of seeing a word in an article, + * for all words seen at least once in any article, computed using the + * empirical frequency of a word among all articles, adjusted by the mass + * to be assigned to globally unseen words (words never seen at all), i.e. + * the value in 'globally_unseen_word_prob'. We start out by storing raw + * counts, then adjusting them. + */ + var overall_word_probs = create_word_double_map() + var owp_adjusted = false + + // The total probability mass to be assigned to words not seen at all in + // any article, estimated using Good-Turing smoothing as the unadjusted + // empirical probability of having seen a word once. + var globally_unseen_word_prob = 0.0 + + // For articles whose word counts are not known, use an empty list to + // look up in. + // unknown_article_counts = ([], []) + + def finish_global_distribution() { + /* We do in-place conversion of counts to probabilities. Make sure + this isn't done twice!! */ + assert (!owp_adjusted) + owp_adjusted = true + // Now, adjust overall_word_probs accordingly. + //// FIXME: A simple calculation reveals that in the scheme where we use + //// globally_unseen_word_prob, total_num_types_seen_once cancels out and + //// we never actually have to compute it. 
+ total_num_types_seen_once = overall_word_probs.values count (_ == 1.0) + globally_unseen_word_prob = + total_num_types_seen_once.toDouble/total_num_word_tokens + for ((word, count) <- overall_word_probs) + overall_word_probs(word) = ( + count.toDouble/total_num_word_tokens*(1.0 - globally_unseen_word_prob)) + // A very rough estimate, perhaps totally wrong + total_num_unseen_word_types = + total_num_types_seen_once max (total_num_word_types/20) + if (debug("bigram")) + errprint("Total num types = %s, total num tokens = %s, total num_seen_once = %s, globally unseen word prob = %s, total mass = %s", + total_num_word_types, total_num_word_tokens, + total_num_types_seen_once, globally_unseen_word_prob, + globally_unseen_word_prob + (overall_word_probs.values sum)) + } + + def set_bigram_word_dist(doc: GenericDistDocument, + keys: Array[String], values: Array[Int], num_words: Int, + bigram_keys: Array[(String, String)], bigram_values: Array[Int], + num_bigrams: Int, is_training_set: Boolean) { + doc.dist = + new PseudoGoodTuringBigramWordDist(this, keys, values, num_words, + bigram_keys, bigram_values, num_bigrams, + note_globally = is_training_set) + } + + def create_word_dist() = + new PseudoGoodTuringBigramWordDist(this, Array[String](), Array[Int](), 0, + Array[(String, String)](), Array[Int](), 0) +} +/** + * Create a pseudo-Good-Turing smoothed word distribution given a table + * listing counts for each word, initialized from the given key/value pairs. + * + * @param key Array holding keys, possibly over-sized, so that the internal + * arrays from DynamicArray objects can be used + * @param values Array holding values corresponding to each key, possibly + * oversize + * @param num_words Number of actual key/value pairs to be stored + * @param note_globally If true, add the word counts to the global word count + * statistics. + */ + +class PseudoGoodTuringBigramWordDist( + val factory: PseudoGoodTuringBigramWordDistFactory, + unigramKeys: Array[String], + unigramValues: Array[Int], + num_unigrams: Int, + bigramKeys: Array[(String, String)], + bigramValues: Array[Int], + num_bigrams: Int, + val note_globally: Boolean = true +) extends BigramWordDist(unigramKeys, unigramValues, num_unigrams, + bigramKeys, bigramValues, num_bigrams) { + //val FastAlgorithms = FastPseudoGoodTuringSmoothedWordDist + type TThis = PseudoGoodTuringBigramWordDist + + if (note_globally) { + //assert(!factory.owp_adjusted) + for ((word, count) <- unicounts) { + if (!(factory.overall_word_probs contains word)) + factory.total_num_word_types += 1 + // Record in overall_word_probs; note more tokens seen. + factory.overall_word_probs(word) += count + factory.total_num_word_tokens += count + } + } + + /** Total probability mass to be assigned to all words not + seen in the article, estimated (motivated by Good-Turing + smoothing) as the unadjusted empirical probability of + having seen a word once. + */ + var unseen_mass = 0.5 + /** + Probability mass assigned in 'overall_word_probs' to all words not seen + in the article. This is 1 - (sum over W in A of overall_word_probs[W]). + The idea is that we compute the probability of seeing a word W in + article A as + + -- if W has been seen before in A, use the following: + COUNTS[W]/TOTAL_TOKENS*(1 - UNSEEN_MASS) + -- else, if W seen in any articles (W in 'overall_word_probs'), + use UNSEEN_MASS * (overall_word_probs[W] / OVERALL_UNSEEN_MASS). + The idea is that overall_word_probs[W] / OVERALL_UNSEEN_MASS is + an estimate of p(W | W not in A). 
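The global step in finish_global_distribution above (estimate the unseen mass from the words seen exactly once, then rescale the empirical probabilities so that mass is reserved for never-seen words) reduces to a few lines over plain maps. This sketch assumes raw counts in wordCounts and is not tied to the factory class above.

object GlobalUnseenMassSketch {
  /** Returns (perWordProbs, globallyUnseenProb): empirical word probabilities
    * scaled by (1 - unseenProb), where unseenProb is the unadjusted empirical
    * probability of having seen a word once (simple Good-Turing style).
    * Assumes a non-empty count map. */
  def globalProbs(wordCounts: Map[String, Double]): (Map[String, Double], Double) = {
    val totalTokens = wordCounts.values.sum
    val typesSeenOnce = wordCounts.values.count(_ == 1.0)
    val unseenProb = typesSeenOnce / totalTokens
    val probs = wordCounts.mapValues(c => c / totalTokens * (1.0 - unseenProb)).toMap
    (probs, unseenProb)
  }

  def main(args: Array[String]) {
    val counts = Map("austin" -> 3.0, "river" -> 1.0, "city" -> 1.0)
    val (probs, unseen) = globalProbs(counts)
    println(probs.values.sum + unseen) // approximately 1.0
  }
}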
We have to divide by + OVERALL_UNSEEN_MASS to make these probabilities be normalized + properly. We scale p(W | W not in A) by the total probability mass + we have available for all words not seen in A. + -- else, use UNSEEN_MASS * globally_unseen_word_prob / NUM_UNSEEN_WORDS, + where NUM_UNSEEN_WORDS is an estimate of the total number of words + "exist" but haven't been seen in any articles. One simple idea is + to use the number of words seen once in any article. This certainly + underestimates this number if not too many articles have been seen + but might be OK if many articles seen. + */ + var overall_unseen_mass = 1.0 + + def innerToString = ", %.2f unseen mass" format unseen_mass + + /** + * Here we compute the value of `overall_unseen_mass`, which depends + * on the global `overall_word_probs` computed from all of the + * distributions. + */ + + protected def imp_finish_after_global() { + // Make sure that overall_word_probs has been computed properly. + assert(factory.owp_adjusted) + + // Compute probabilities. Use a very simple version of Good-Turing + // smoothing where we assign to unseen words the probability mass of + // words seen once, and adjust all other probs accordingly. + val num_types_seen_once = unicounts.values count (_ == 1) + unseen_mass = + if (num_word_tokens > 0) + // If no words seen only once, we will have a problem if we assign 0 + // to the unseen mass, as unseen words will end up with 0 probability. + // However, if we assign a value of 1.0 to unseen_mass (which could + // happen in case all words seen exactly once), then we will end + // up assigning 0 probability to seen words. So we arbitrarily + // limit it to 0.5, which is pretty damn much mass going to unseen + // words. + 0.5 min ((1.0 max num_types_seen_once)/num_word_tokens) + else 0.5 + overall_unseen_mass = 1.0 - ( + (for (ind <- unicounts.keys) + yield factory.overall_word_probs(ind)) sum) + //if (use_sorted_list) + // counts = new SortedList(counts) + } + + override def finish(minimum_word_count: Int = 0) { + super.finish(minimum_word_count) + if (debug("lots")) { + errprint("""For word dist, total tokens = %s, unseen_mass = %s, overall unseen mass = %s""", + num_word_tokens, unseen_mass, overall_unseen_mass) + } + } + + def cosine_similarity(other: WordDist, partial: Boolean = false, + smoothed: Boolean = false) = { + throw new UnsupportedOperationException("Not implemented yet") + } + + def kl_divergence_34(other: BigramWordDist): Double = { + throw new UnsupportedOperationException("Not implemented yet") + } + + def lookup_word(word: Word) = { + assert(finished) + // if (debug("some")) { + // errprint("Found counts for document %s, num word types = %s", + // doc, wordcounts(0).length) + // errprint("Unknown prob = %s, overall_unseen_mass = %s", + // unseen_mass, overall_unseen_mass) + // } + val retval = unicounts.get(word) match { + case None => { + factory.overall_word_probs.get(word) match { + case None => { +if(debug("bigram")) + errprint("unseen_mass: %s, globally_unseen_word_prob %s, total_num_unseen_word_types %s", unseen_mass, factory.globally_unseen_word_prob, factory.total_num_unseen_word_types) + val wordprob = (unseen_mass*factory.globally_unseen_word_prob + / factory.total_num_unseen_word_types) + if (debug("bigram")) + errprint("Word %s, never seen at all, wordprob = %s", + unmemoize_string(word), wordprob) + wordprob + } + case Some(owprob) => { + val wordprob = unseen_mass * owprob / overall_unseen_mass + //if (wordprob <= 0) + // warning("Bad values; unseen_mass = %s, 
overall_word_probs[word] = %s, overall_unseen_mass = %s", + // unseen_mass, factory.overall_word_probs[word], + // factory.overall_unseen_mass) + if (debug("bigram")) + errprint("Word %s, seen but not in document, wordprob = %s", + unmemoize_string(word), wordprob) + wordprob + } + } + } + case Some(wordcount) => { + //if (wordcount <= 0 or num_word_tokens <= 0 or unseen_mass >= 1.0) + // warning("Bad values; wordcount = %s, unseen_mass = %s", + // wordcount, unseen_mass) + // for ((word, count) <- self.counts) + // errprint("%s: %s", word, count) + val wordprob = wordcount.toDouble/num_word_tokens*(1.0 - unseen_mass) + if (debug("bigram")) + errprint("Word %s, seen in document, wordprob = %s", + unmemoize_string(word), wordprob) + wordprob + } + } + retval + } + + def lookup_bigram(word: Word) = { + throw new UnsupportedOperationException("Not implemented yet") + } +} + diff --git a/src/main/scala/opennlp/fieldspring/worddist/PseudoGoodTuringUnigramWordDist.scala b/src/main/scala/opennlp/fieldspring/worddist/PseudoGoodTuringUnigramWordDist.scala new file mode 100644 index 0000000..08e7bd2 --- /dev/null +++ b/src/main/scala/opennlp/fieldspring/worddist/PseudoGoodTuringUnigramWordDist.scala @@ -0,0 +1,109 @@ +/////////////////////////////////////////////////////////////////////////////// +// PseudoGoodTuringUnigramWordDist.scala +// +// Copyright (C) 2010, 2011, 2012 Ben Wing, The University of Texas at Austin +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +/////////////////////////////////////////////////////////////////////////////// + +package opennlp.fieldspring.worddist + +import opennlp.fieldspring.util.printutil.errprint + +import opennlp.fieldspring.gridlocate.GridLocateDriver.Debug._ + +/** + * This class implements a simple version of Good-Turing smoothing where we + * assign probability mass to unseen words equal to the probability mass of + * all words seen once, and rescale the remaining probabilities accordingly. + * ("Proper" Good-Turing is more general and adjusts the probabilities of + * words seen N times according to the number of words seen N-1 times. + * FIXME: I haven't thought carefully enough to make sure that this + * simplified version actually makes theoretical sense, although I assume it + * does because it can be seen as a better version of add-one or + * add-some-value smoothing, where instead of just adding some arbitrary + * value to unseen words, we determine the amount to add in a way that makes + * theoretical sense (and in particular ensures that we don't assign 99% + * of the mass or whatever to unseen words, as can happen to add-one smoothing + * especially for bigrams or trigrams). 
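The three-way case analysis in lookup_word above is easier to see without the debug output. This is a condensed sketch over plain values; every parameter is a stand-in for the corresponding field of the distribution or factory above, not the real API.

object SmoothedLookupSketch {
  /** p(w) under the scheme described above:
    *  - seen in the document:  count/tokens * (1 - unseenMass)
    *  - seen only globally:    unseenMass * globalProb(w) / overallUnseenMass
    *  - never seen anywhere:   unseenMass * globallyUnseenProb / numUnseenTypes */
  def lookup(word: String,
      docCounts: Map[String, Double], docTokens: Double,
      globalProbs: Map[String, Double],
      unseenMass: Double, overallUnseenMass: Double,
      globallyUnseenProb: Double, numUnseenTypes: Double): Double =
    docCounts.get(word) match {
      case Some(count) => count / docTokens * (1.0 - unseenMass)
      case None => globalProbs.get(word) match {
        case Some(gp) => unseenMass * gp / overallUnseenMass
        case None     => unseenMass * globallyUnseenProb / numUnseenTypes
      }
    }
}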
+ */ +class PseudoGoodTuringUnigramWordDistFactory( + interpolate_string: String + ) extends DiscountedUnigramWordDistFactory(interpolate_string == "yes") { + // Total number of types seen once + var total_num_types_seen_once = 0 + + // The total probability mass to be assigned to words not seen at all in + // any document, estimated using Good-Turing smoothing as the unadjusted + // empirical probability of having seen a word once. + var globally_unseen_word_prob = 0.0 + + // For documents whose word counts are not known, use an empty list to + // look up in. + // unknown_document_counts = ([], []) + + override def finish_global_distribution() { + // Now, adjust overall_word_probs accordingly. + //// FIXME: A simple calculation reveals that in the scheme where we use + //// globally_unseen_word_prob, total_num_types_seen_once cancels out and + //// we never actually have to compute it. + /* Definitely no longer used in the "new way". + total_num_types_seen_once = overall_word_probs.values count (_ == 1.0) + globally_unseen_word_prob = + total_num_types_seen_once.toDouble/total_num_word_tokens + */ + // A very rough estimate, perhaps totally wrong + total_num_unseen_word_types = + total_num_types_seen_once max (total_num_word_types/20) + if (debug("tons")) + errprint("Total num types = %s, total num tokens = %s, total num_seen_once = %stotal mass = %s", + total_num_word_types, total_num_word_tokens, + total_num_types_seen_once, + (overall_word_probs.values sum)) + super.finish_global_distribution() + } + + def create_word_dist(note_globally: Boolean) = + new PseudoGoodTuringUnigramWordDist(this, note_globally) +} + +class PseudoGoodTuringUnigramWordDist( + factory: WordDistFactory, + note_globally: Boolean +) extends DiscountedUnigramWordDist( + factory, note_globally + ) { + /** + * Here we compute the value of `overall_unseen_mass`, which depends + * on the global `overall_word_probs` computed from all of the + * distributions. + */ + override protected def imp_finish_after_global() { + // Compute probabilities. Use a very simple version of Good-Turing + // smoothing where we assign to unseen words the probability mass of + // words seen once, and adjust all other probs accordingly. + val num_types_seen_once = model.iter_items count { case (k,v) => v == 1 } + unseen_mass = + if (model.num_tokens > 0) + // If no words seen only once, we will have a problem if we assign 0 + // to the unseen mass, as unseen words will end up with 0 probability. + // However, if we assign a value of 1.0 to unseen_mass (which could + // happen in case all words seen exactly once), then we will end + // up assigning 0 probability to seen words. So we arbitrarily + // limit it to 0.5, which is pretty damn much mass going to unseen + // words. 
+ 0.5 min ((1.0 max num_types_seen_once)/model.num_tokens) + else 0.5 + super.imp_finish_after_global() + } +} diff --git a/src/main/scala/opennlp/fieldspring/worddist/UnigramWordDist.scala b/src/main/scala/opennlp/fieldspring/worddist/UnigramWordDist.scala new file mode 100644 index 0000000..bb2b21c --- /dev/null +++ b/src/main/scala/opennlp/fieldspring/worddist/UnigramWordDist.scala @@ -0,0 +1,390 @@ +/////////////////////////////////////////////////////////////////////////////// +// UnigramWordDist.scala +// +// Copyright (C) 2010, 2011, 2012 Ben Wing, The University of Texas at Austin +// Copyright (C) 2012 Mike Speriosu, The University of Texas at Austin +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +/////////////////////////////////////////////////////////////////////////////// + +package opennlp.fieldspring.worddist + +import math._ +import collection.mutable +import util.control.Breaks._ + +import java.io._ + +import opennlp.fieldspring.util.collectionutil.DynamicArray +import opennlp.fieldspring.util.textdbutil +import opennlp.fieldspring.util.ioutil.{FileHandler, FileFormatException} +import opennlp.fieldspring.util.printutil.{errprint, warning} + +import opennlp.fieldspring.gridlocate.GridLocateDriver.Debug._ +import opennlp.fieldspring.gridlocate.GenericTypes._ + +import WordDist.memoizer._ + +/** + * An interface for storing and retrieving vocabulary items (e.g. words, + * n-grams, etc.). + * + * @tparam Item Type of the items stored. + */ +class UnigramStorage extends ItemStorage[Word] { + + /** + * A map (or possibly a "sorted list" of tuples, to save memory?) of + * (word, count) items, specifying the counts of all words seen + * at least once. These are given as double because in some cases + * they may store "partial" counts (in particular, when the K-d tree + * code does interpolation on cells). FIXME: This seems ugly, perhaps + * there is a better way? + */ + val counts = create_word_double_map() + var tokens_accurate = true + var num_tokens_val = 0.0 + + def add_item(item: Word, count: Double) { + counts(item) += count + num_tokens_val += count + } + + def set_item(item: Word, count: Double) { + counts(item) = count + tokens_accurate = false + } + + def remove_item(item: Word) { + counts -= item + tokens_accurate = false + } + + def contains(item: Word) = counts contains item + + def get_item(item: Word) = counts(item) + + def iter_items = counts.toIterable + + def iter_keys = counts.keys + + def num_tokens = { + if (!tokens_accurate) { + num_tokens_val = counts.values.sum + tokens_accurate = true + } + num_tokens_val + } + + def num_types = counts.size +} + +/** + * Unigram word distribution with a table listing counts for each word, + * initialized from the given key/value pairs. 
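The tokens_accurate flag in UnigramStorage above is a small caching trick: additions keep a running token total, while set_item/remove_item invalidate it so the next num_tokens call recomputes the sum. A stripped-down sketch of the same pattern, with a plain map and illustrative names:

import scala.collection.mutable

/** Count store that keeps a running token total, invalidated (and lazily
  * recomputed) when entries are overwritten or removed. */
class CachedTotalCounts {
  private val counts = mutable.Map[String, Double]().withDefaultValue(0.0)
  private var cachedTotal = 0.0
  private var totalAccurate = true

  def add(item: String, count: Double) { counts(item) += count; cachedTotal += count }
  def set(item: String, count: Double) { counts(item) = count; totalAccurate = false }
  def remove(item: String) { counts -= item; totalAccurate = false }

  def numTokens: Double = {
    if (!totalAccurate) { cachedTotal = counts.values.sum; totalAccurate = true }
    cachedTotal
  }
}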
+ * + * @param key Array holding keys, possibly over-sized, so that the internal + * arrays from DynamicArray objects can be used + * @param values Array holding values corresponding to each key, possibly + * oversize + * @param num_words Number of actual key/value pairs to be stored + * statistics. + */ + +abstract class UnigramWordDist( + factory: WordDistFactory, + note_globally: Boolean + ) extends WordDist(factory, note_globally) with FastSlowKLDivergence { + type Item = Word + val pmodel = new UnigramStorage() + val model = pmodel + + def innerToString: String + + override def toString = { + val finished_str = + if (!finished) ", unfinished" else "" + val num_words_to_print = 15 + val need_dots = model.num_types > num_words_to_print + val items = + for ((word, count) <- + model.iter_items.toSeq.sortWith(_._2 > _._2). + view(0, num_words_to_print)) + yield "%s=%s" format (unmemoize_string(word), count) + val words = (items mkString " ") + (if (need_dots) " ..." else "") + "UnigramWordDist(%d types, %s tokens%s%s, %s)" format ( + model.num_types, model.num_tokens, innerToString, + finished_str, words) + } + + /** + * This is a basic unigram implementation of the computation of the + * KL-divergence between this distribution and another distribution, + * including possible debug information. + * + * Computing the KL divergence is a bit tricky, especially in the + * presence of smoothing, which assigns probabilities even to words not + * seen in either distribution. We have to take into account: + * + * 1. Words in this distribution (may or may not be in the other). + * 2. Words in the other distribution that are not in this one. + * 3. Words in neither distribution but seen globally. + * 4. Words never seen at all. + * + * The computation of steps 3 and 4 depends heavily on the particular + * smoothing algorithm; in the absence of smoothing, these steps + * contribute nothing to the overall KL-divergence. + * + */ + def slow_kl_divergence_debug(xother: WordDist, partial: Boolean = false, + return_contributing_words: Boolean = false) = { + val other = xother.asInstanceOf[UnigramWordDist] + var kldiv = 0.0 + val contribs = + if (return_contributing_words) mutable.Map[String, Double]() else null + // 1. + for (word <- model.iter_keys) { + val p = lookup_word(word) + val q = other.lookup_word(word) + if (q == 0.0) + { } // This is OK, we just skip these words + else if (p <= 0.0 || q <= 0.0) + errprint("Warning: problematic values: p=%s, q=%s, word=%s", p, q, word) + else { + kldiv += p*(log(p) - log(q)) + if (return_contributing_words) + contribs(unmemoize_string(word)) = p*(log(p) - log(q)) + } + } + + if (partial) + (kldiv, contribs) + else { + // Step 2. + for (word <- other.model.iter_keys if !(model contains word)) { + val p = lookup_word(word) + val q = other.lookup_word(word) + kldiv += p*(log(p) - log(q)) + if (return_contributing_words) + contribs(unmemoize_string(word)) = p*(log(p) - log(q)) + } + + val retval = kldiv + kl_divergence_34(other) + (retval, contribs) + } + } + + /** + * Steps 3 and 4 of KL-divergence computation. + * @see #slow_kl_divergence_debug + */ + def kl_divergence_34(other: UnigramWordDist): Double + + def get_nbayes_logprob(xworddist: WordDist) = { + val worddist = xworddist.asInstanceOf[UnigramWordDist] + var logprob = 0.0 + for ((word, count) <- worddist.model.iter_items) { + val value = lookup_word(word) + if (value <= 0) { + // FIXME: Need to figure out why this happens (perhaps the word was + // never seen anywhere in the training data? 
But I thought we have + // a case to handle that) and what to do instead. + errprint("Warning! For word %s, prob %s out of range", word, value) + } else + logprob += log(value) + } + // FIXME: Also use baseline (prior probability) + logprob + } + + /** + * Return the probabilitiy of a given word in the distribution. + */ + def lookup_word(word: Word): Double + + /** + * Look for the most common word matching a given predicate. + * @param pred Predicate, passed the raw (unmemoized) form of a word. + * Should return true if a word matches. + * @return Most common word matching the predicate (wrapped with + * Some()), or None if no match. + */ + def find_most_common_word(pred: String => Boolean): Option[Word] = { + val filtered = + (for ((word, count) <- model.iter_items if pred(unmemoize_string(word))) + yield (word, count)).toSeq + if (filtered.length == 0) None + else { + val (maxword, maxcount) = filtered maxBy (_._2) + Some(maxword) + } + } +} + +class DefaultUnigramWordDistConstructor( + factory: WordDistFactory, + ignore_case: Boolean, + stopwords: Set[String], + whitelist: Set[String], + minimum_word_count: Int = 1 +) extends WordDistConstructor(factory: WordDistFactory) { + /** + * Initial size of the internal DynamicArray objects; an optimization. + */ + protected val initial_dynarr_size = 1000 + /** + * Internal DynamicArray holding the keys (canonicalized words). + */ + protected val keys_dynarr = + new DynamicArray[String](initial_alloc = initial_dynarr_size) + /** + * Internal DynamicArray holding the values (word counts). + */ + protected val values_dynarr = + new DynamicArray[Int](initial_alloc = initial_dynarr_size) + /** + * Set of the raw, uncanonicalized words seen, to check that an + * uncanonicalized word isn't seen twice. (Canonicalized words may very + * well occur multiple times.) + */ + protected val raw_keys_set = mutable.Set[String]() + + protected def parse_counts(countstr: String) { + keys_dynarr.clear() + values_dynarr.clear() + raw_keys_set.clear() + for ((word, count) <- textdbutil.decode_count_map(countstr)) { + /* FIXME: Is this necessary? */ + if (raw_keys_set contains word) + throw FileFormatException( + "Word %s seen twice in same counts list" format word) + raw_keys_set += word + keys_dynarr += word + values_dynarr += count + } + } + + var seen_documents = new scala.collection.mutable.HashSet[String]() + + // Returns true if the word was counted, false if it was ignored due to stoplisting + // and/or whitelisting + protected def add_word_with_count(model: UnigramStorage, word: String, + count: Int): Boolean = { + val lword = maybe_lowercase(word) + if (!stopwords.contains(lword) && + (whitelist.size == 0 || whitelist.contains(lword))) { + model.add_item(memoize_string(lword), count) + true + } + else + false + } + + protected def imp_add_document(dist: WordDist, words: Iterable[String]) { + val model = dist.asInstanceOf[UnigramWordDist].model + for (word <- words) + add_word_with_count(model, word, 1) + } + + protected def imp_add_word_distribution(dist: WordDist, other: WordDist, + partial: Double) { + // FIXME: Implement partial! + val model = dist.asInstanceOf[UnigramWordDist].model + val othermodel = other.asInstanceOf[UnigramWordDist].model + for ((word, count) <- othermodel.iter_items) + model.add_item(word, count) + } + + /** + * Actual implementation of `add_keys_values` by subclasses. + * External callers should use `add_keys_values`. 
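The case-folding plus stoplist/whitelist gate in add_word_with_count above is the whole token-filtering policy. As a standalone sketch over plain collections (illustrative names, no memoization):

object TokenFilterSketch {
  /** Returns the canonicalized token if it should be counted, or None if it
    * is stoplisted or fails the whitelist (an empty whitelist admits all). */
  def keepToken(word: String, ignoreCase: Boolean,
      stopwords: Set[String], whitelist: Set[String]): Option[String] = {
    val canon = if (ignoreCase) word.toLowerCase else word
    if (!stopwords.contains(canon) &&
        (whitelist.isEmpty || whitelist.contains(canon)))
      Some(canon)
    else
      None
  }

  def main(args: Array[String]) {
    val stop = Set("the", "of")
    println(Seq("The", "City", "of", "Austin").flatMap(keepToken(_, true, stop, Set())))
    // List(city, austin)
  }
}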
+ */ + protected def imp_add_keys_values(dist: WordDist, keys: Array[String], + values: Array[Int], num_words: Int) { + val model = dist.asInstanceOf[UnigramWordDist].model + var addedTypes = 0 + var addedTokens = 0 + var totalTokens = 0 + for (i <- 0 until num_words) { + if(add_word_with_count(model, keys(i), values(i))) { + addedTypes += 1 + addedTokens += values(i) + } + totalTokens += values(i) + } + // Way too much output to keep enabled + //errprint("Fraction of word types kept:"+(addedTypes.toDouble/num_words)) + //errprint("Fraction of word tokens kept:"+(addedTokens.toDouble/totalTokens)) + } + + /** + * Incorporate a set of (key, value) pairs into the distribution. + * The number of pairs to add should be taken from `num_words`, not from + * the actual length of the arrays passed in. The code should be able + * to handle the possibility that the same word appears multiple times, + * adding up the counts for each appearance of the word. + */ + protected def add_keys_values(dist: WordDist, + keys: Array[String], values: Array[Int], num_words: Int) { + assert(!dist.finished) + assert(!dist.finished_before_global) + assert(keys.length >= num_words) + assert(values.length >= num_words) + imp_add_keys_values(dist, keys, values, num_words) + } + + protected def imp_finish_before_global(gendist: WordDist) { + val dist = gendist.asInstanceOf[UnigramWordDist] + val model = dist.model + val oov = memoize_string("-OOV-") + + /* Add the distribution to the global stats before eliminating + infrequent words. */ + factory.note_dist_globally(dist) + + // If 'minimum_word_count' was given, then eliminate words whose count + // is too small. + if (minimum_word_count > 1) { + for ((word, count) <- model.iter_items if count < minimum_word_count) { + model.remove_item(word) + model.add_item(oov, count) + } + } + } + + def maybe_lowercase(word: String) = + if (ignore_case) word.toLowerCase else word + + def initialize_distribution(doc: GenericDistDocument, countstr: String, + is_training_set: Boolean) { + parse_counts(countstr) + // Now set the distribution on the document; but don't use the test + // set's distributions in computing global smoothing values and such. + // + // FIXME: What is the purpose of first_time_document_seen??? When does + // it occur that we see a document multiple times? + var first_time_document_seen = !seen_documents.contains(doc.title) + + val dist = factory.create_word_dist(note_globally = + is_training_set && first_time_document_seen) + add_keys_values(dist, keys_dynarr.array, values_dynarr.array, + keys_dynarr.length) + seen_documents += doc.title + doc.dist = dist + } +} + +/** + * General factory for UnigramWordDist distributions. + */ +abstract class UnigramWordDistFactory extends WordDistFactory { } diff --git a/src/main/scala/opennlp/fieldspring/worddist/UnsmoothedNgramWordDist.scala b/src/main/scala/opennlp/fieldspring/worddist/UnsmoothedNgramWordDist.scala new file mode 100644 index 0000000..455820b --- /dev/null +++ b/src/main/scala/opennlp/fieldspring/worddist/UnsmoothedNgramWordDist.scala @@ -0,0 +1,75 @@ +/////////////////////////////////////////////////////////////////////////////// +// UnsmoothedNgramWordDist.scala +// +// Copyright (C) 2010, 2011, 2012 Ben Wing, The University of Texas at Austin +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +/////////////////////////////////////////////////////////////////////////////// + +package opennlp.fieldspring.worddist + +class UnsmoothedNgramWordDistFactory extends NgramWordDistFactory { + def create_word_dist(note_globally: Boolean) = + new UnsmoothedNgramWordDist(this, note_globally) + + def finish_global_distribution() { + } +} + +class UnsmoothedNgramWordDist( + gen_factory: WordDistFactory, + note_globally: Boolean +) extends NgramWordDist(gen_factory, note_globally) { + import NgramStorage.Ngram + + type TThis = UnsmoothedNgramWordDist + + def innerToString = "" + + // For some reason, retrieving this value from the model is fantastically slow + var num_tokens = 0.0 + + protected def imp_finish_after_global() { + num_tokens = model.num_tokens + } + + def fast_kl_divergence(cache: KLDivergenceCache, other: WordDist, + partial: Boolean = false) = { + assert(false, "Not implemented") + 0.0 + } + + def cosine_similarity(other: WordDist, partial: Boolean = false, + smoothed: Boolean = false) = { + assert(false, "Not implemented") + 0.0 + } + + def kl_divergence_34(other: NgramWordDist) = { + assert(false, "Not implemented") + 0.0 + } + + /** + * Actual implementation of steps 3 and 4 of KL-divergence computation, given + * a value that we may want to compute as part of step 2. + */ + def inner_kl_divergence_34(other: TThis, + overall_probs_diff_words: Double) = { + assert(false, "Not implemented") + 0.0 + } + + def lookup_ngram(ngram: Ngram) = + model.get_item(ngram).toDouble / num_tokens +} diff --git a/src/main/scala/opennlp/fieldspring/worddist/WordDist.scala b/src/main/scala/opennlp/fieldspring/worddist/WordDist.scala new file mode 100644 index 0000000..e2f3e6f --- /dev/null +++ b/src/main/scala/opennlp/fieldspring/worddist/WordDist.scala @@ -0,0 +1,522 @@ +/////////////////////////////////////////////////////////////////////////////// +// WordDist.scala +// +// Copyright (C) 2010, 2011, 2012 Ben Wing, The University of Texas at Austin +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +/////////////////////////////////////////////////////////////////////////////// + +package opennlp.fieldspring.worddist + +import math._ + +import opennlp.fieldspring.util.ioutil.FileHandler +import opennlp.fieldspring.util.printutil.{errprint, warning} +import opennlp.fieldspring.util.Serializer + +import opennlp.fieldspring.gridlocate.GridLocateDriver.Debug._ +import opennlp.fieldspring.gridlocate.GenericTypes._ +// FIXME! 
For reference to GridLocateDriver.Params +import opennlp.fieldspring.gridlocate.GridLocateDriver + +import WordDist.memoizer._ + +// val use_sorted_list = false + +////////////////////////////////////////////////////////////////////////////// +// Word distributions // +////////////////////////////////////////////////////////////////////////////// + +/** + * A trait that adds an implementation of `#kl_divergence` in terms of + * a slow version with debugging info and a fast version, and optionally + * compares the two. + */ +trait FastSlowKLDivergence { + def get_kl_divergence_cache(): KLDivergenceCache + + /** + * This is a basic implementation of the computation of the KL-divergence + * between this distribution and another distribution, including possible + * debug information. Useful for checking against the other, faster + * implementation in `fast_kl_divergence`. + * + * @param xother The other distribution to compute against. + * @param partial If true, only compute the contribution involving words + * that exist in our distribution; otherwise we also have to take into + * account words in the other distribution even if we haven't seen them, + * and often also (esp. in the presence of smoothing) the contribution + * of all other words in the vocabulary. + * @param return_contributing_words If true, return a map listing + * the words (or n-grams, or whatever) in both distributions (or, for a + * partial KL-divergence, the words in our distribution) and the amount + * of total KL-divergence they compute, useful for debugging. + * + * @return A tuple of (divergence, word_contribs) where the first + * value is the actual KL-divergence and the second is the map + * of word contributions as described above; will be null if + * not requested. + */ + def slow_kl_divergence_debug(xother: WordDist, partial: Boolean = false, + return_contributing_words: Boolean = false): + (Double, collection.Map[String, Double]) + + /** + * Compute the KL-divergence using the "slow" algorithm of + * #slow_kl_divergence_debug, but without requesting or returning debug + * info. + */ + def slow_kl_divergence(other: WordDist, partial: Boolean = false) = { + val (kldiv, contribs) = + slow_kl_divergence_debug(other, partial, false) + kldiv + } + + /** + * A fast, optimized implementation of KL-divergence. See the discussion in + * `slow_kl_divergence_debug`. + */ + def fast_kl_divergence(cache: KLDivergenceCache, other: WordDist, + partial: Boolean = false): Double + + /** + * Check fast and slow KL-divergence versions against each other. + */ + def test_kl_divergence(cache: KLDivergenceCache, other: WordDist, + partial: Boolean = false) = { + val slow_kldiv = slow_kl_divergence(other, partial) + val fast_kldiv = fast_kl_divergence(cache, other, partial) + if (abs(fast_kldiv - slow_kldiv) > 1e-8) { + errprint("Fast KL-div=%s but slow KL-div=%s", fast_kldiv, slow_kldiv) + assert(fast_kldiv == slow_kldiv) + } + fast_kldiv + } + + /** + * The actual kl_divergence implementation. The value `test_kldiv` + * below can be set to true to compare fast and slow against either + * other, throwing an assertion failure if they are more than a very + * small amount different (the small amount rather than 0 to account for + * possible rounding error). 
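The test_kl_divergence pattern above (run the slow reference implementation and the optimized one on the same input, and fail loudly if they disagree beyond a small tolerance) is a generally useful harness. A hedged, generic sketch with a deliberately trivial pair of implementations:

object FastSlowCheckSketch {
  /** Run a reference and an optimized implementation on the same input and
    * abort if they disagree beyond a small floating-point tolerance. */
  def checked[A](input: A)(slow: A => Double)(fast: A => Double): Double = {
    val s = slow(input)
    val f = fast(input)
    if (math.abs(f - s) > 1e-8)
      sys.error("fast=%s but slow=%s for input %s" format (f, s, input))
    f
  }

  def main(args: Array[String]) {
    val xs = Seq(1.0, 2.0, 3.0, 4.0)
    // The same computation two ways: explicit loop vs. library sum.
    val slowSum = (v: Seq[Double]) => { var t = 0.0; for (x <- v) t += x; t }
    val fastSum = (v: Seq[Double]) => v.sum
    println(checked(xs)(slowSum)(fastSum))
  }
}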
+ */ + protected def imp_kl_divergence(cache: KLDivergenceCache, other: WordDist, + partial: Boolean) = { + val test_kldiv = GridLocateDriver.Params.test_kl + if (test_kldiv) + test_kl_divergence(cache, other, partial) + else + fast_kl_divergence(cache, other, partial) + } +} + +/** + * A class that controls how to construct a word distribution, whether the + * source comes from a data file, a document, another distribution, etc. + * The actual words added can be transformed in various ways, e.g. + * case-folding the words (typically by converting to all lowercase), ignoring + * words seen in a stoplist, converting some words to a generic -OOV-, + * eliminating words seen less than a minimum number of times, etc. + */ +abstract class WordDistConstructor(factory: WordDistFactory) { + /** + * Actual implementation of `add_document` by subclasses. + * External callers should use `add_document`. + */ + protected def imp_add_document(dist: WordDist, words: Iterable[String]) + + /** + * Actual implementation of `add_word_distribution` by subclasses. + * External callers should use `add_word_distribution`. + */ + protected def imp_add_word_distribution(dist: WordDist, other: WordDist, + partial: Double = 1.0) + + /** + * Actual implementation of `finish_before_global` by subclasses. + * External callers should use `finish_before_global`. + */ + protected def imp_finish_before_global(dist: WordDist) + + /** + * Incorporate a document into the distribution. The document is described + * by a sequence of words. + */ + def add_document(dist: WordDist, words: Iterable[String]) { + assert(!dist.finished) + assert(!dist.finished_before_global) + imp_add_document(dist, words) + } + + /** + * Incorporate the given distribution into our distribution. + * `partial` is a scaling factor (between 0.0 and 1.0) used for + * interpolating multiple distributions. + */ + def add_word_distribution(dist: WordDist, other: WordDist, + partial: Double = 1.0) { + assert(!dist.finished) + assert(!dist.finished_before_global) + assert(partial >= 0.0 && partial <= 1.0) + imp_add_word_distribution(dist, other, partial) + } + + /** + * Partly finish computation of distribution. This is called when the + * distribution has been completely populated with words, and no more + * modifications (e.g. incorporation of words or other distributions) will + * be made to the distribution. It should do any additional changes that + * depend on the distribution being complete, but which do not depend on + * the global word-distribution statistics having been computed. (These + * statistics can be computed only after *all* word distributions that + * are used to create these global statistics have been completely + * populated.) + * + * @see #finish_after_global + * + */ + def finish_before_global(dist: WordDist) { + assert(!dist.finished) + assert(!dist.finished_before_global) + imp_finish_before_global(dist) + dist.finished_before_global = true + } + + /** + * Create the word distribution of a document, given the value of the field + * describing the distribution (typically called "counts" or "text"). + * + * @param doc Document to set the distribution of. + * @param countstr String from the document file, describing the distribution. + * @param is_training_set True if this document is in the training set. + * Generally, global (e.g. back-off) statistics should be initialized + * only from training-set documents. 
+ */ + def initialize_distribution(doc: GenericDistDocument, countstr: String, + is_training_set: Boolean) +} + +class KLDivergenceCache { +} + +/** + * A factory object for WordDists (word distributions). Currently, there is + * only one factory object in a particular instance of the application + * (i.e. it's a singleton), but the particular factory used depends on a + * command-line parameter. + */ +abstract class WordDistFactory { + /** + * Total number of word types seen (size of vocabulary) + */ + var total_num_word_types = 0 + + /** + * Total number of word tokens seen + */ + var total_num_word_tokens = 0 + + /** + * Corresponding constructor object for building up the word distribution + */ + var constructor: WordDistConstructor = _ + + def set_word_dist_constructor(constructor: WordDistConstructor) { + this.constructor = constructor + } + + def create_word_dist(): WordDist = create_word_dist(note_globally = false) + + /** + * Create an empty word distribution. If `note_globally` is true, + * the distribution is meant to be added to the global word-distribution + * statistics (see below). + */ + def create_word_dist(note_globally: Boolean): WordDist + + /** + * Add the given distribution to the global word-distribution statistics, + * if any. + */ + def note_dist_globally(dist: WordDist) { } + + /** + * Finish computing any global word-distribution statistics, e.g. tables for + * doing back-off. This is called after all of the relevant WordDists + * have been created. In practice, the "relevant" distributions are those + * associated with training documents, which are read in + * during `read_word_counts`. + */ + def finish_global_distribution() +} + +object WordDist { + /** + * Object describing how we memoize words (i.e. convert them to Int + * indices, for faster operations on them). + * + * FIXME: Should probably be stored globally or at least elsewhere, since + * memoization is more general than just for words in word distributions, + * and is used elsewhere in the code for other things. We should probably + * move the memoization code into the `util` package. + */ + val memoizer = new IntStringMemoizer + //val memoizer = IdentityMemoizer +} + +/** + * An interface for storing and retrieving vocabulary items (e.g. words, + * n-grams, etc.). + * + * @tparam Item Type of the items stored. + */ +trait ItemStorage[Item] { + + /** + * Add an item with the given count. If the item exists already, + * add the count to the existing value. + */ + def add_item(item: Item, count: Double) + + /** + * Set the item to the given count. If the item exists already, + * replace its value with the given one. + */ + def set_item(item: Item, count: Double) + + /** + * Remove an item, if it exists. + */ + def remove_item(item: Item) + + /** + * Return whether a given item is stored. + */ + def contains(item: Item): Boolean + + /** + * Return the count of a given item. + */ + def get_item(item: Item): Double + + /** + * Iterate over all items that are stored. + */ + def iter_items: Iterable[(Item, Double)] + + /** + * Iterate over all keys that are stored. + */ + def iter_keys: Iterable[Item] + + /** + * Total number of tokens stored. + */ + def num_tokens: Double + + /** + * Total number of item types (i.e. number of distinct items) + * stored. + */ + def num_types: Int +} + +/** + * A word distribution, i.e. a statistical distribution over words in + * a document, cell, etc. 
+ */ +abstract class WordDist(factory: WordDistFactory, + val note_globally: Boolean) { + type Item + val model: ItemStorage[Item] + + /** + * Whether we have finished computing the distribution, and therefore can + * reliably do probability lookups. + */ + var finished = false + /** + * Whether we have already called `finish_before_global`. If so, this + * means we can't modify the distribution any more. + */ + var finished_before_global = false + + /** + * Incorporate a document into the distribution. + */ + def add_document(words: Iterable[String]) { + factory.constructor.add_document(this, words) + } + + /** + * Incorporate the given distribution into our distribution. + * `partial` is a scaling factor (between 0.0 and 1.0) used for + * interpolating multiple distributions. + */ + def add_word_distribution(other: WordDist, partial: Double = 1.0) { + factory.constructor.add_word_distribution(this, other, partial) + } + + /** + * Partly finish computation of distribution. This is called when the + * distribution has been completely populated with words, and no more + * modifications (e.g. incorporation of words or other distributions) will + * be made to the distribution. It should do any additional changes that + * depend on the distribution being complete, but which do not depend on + * the global word-distribution statistics having been computed. (These + * statistics can be computed only after *all* word distributions that + * are used to create these global statistics have been completely + * populated.) + * + * @see #finish_after_global + * + */ + def finish_before_global() { + factory.constructor.finish_before_global(this) + } + + /** + * Actual implementation of `finish_after_global` by subclasses. + * External callers should use `finish_after_global`. + */ + protected def imp_finish_after_global() + + /** + * Completely finish computation of the word distribution. This is called + * after finish_global_distribution() on the factory method, and can be + * used to compute values for the distribution that depend on the + * global word-distribution statistics. + */ + def finish_after_global() { + assert(!finished) + assert(finished_before_global) + imp_finish_after_global() + finished = true + } + + def dunning_log_likelihood_1(a: Double, b: Double, c: Double, d: Double) = { + val cprime = a+c + val dprime = b+d + val factor = (a+b)/(cprime+dprime) + val e1 = cprime*factor + val e2 = dprime*factor + val f1 = if (a > 0) a*math.log(a/e1) else 0.0 + val f2 = if (b > 0) b*math.log(b/e2) else 0.0 + val g2 = 2*(f1+f2) + g2 + } + + def dunning_log_likelihood_2(a: Double, b: Double, c: Double, d: Double) = { + val N = a+b+c+d + def m(x: Double) = if (x > 0) x*math.log(x) else 0.0 + 2*(m(a) + m(b) + m(c) + m(d) + m(N) - m(a+b) - m(a+c) - m(b+d) - m(c+d)) + } + + /** + * Return the Dunning log-likelihood of an item in two different corpora, + * indicating how much more likely the item is in this corpus than the + * other to be different. + */ + def dunning_log_likelihood_2x1(item: Item, other: WordDist) = { + val a = model.get_item(item).toDouble + // This cast is kind of ugly but I don't see a way around it. + val b = other.model.get_item(item.asInstanceOf[other.Item]).toDouble + val c = model.num_tokens.toDouble - a + val d = other.model.num_tokens.toDouble - b + val val1 = dunning_log_likelihood_1(a, b, c, d) + val val2 = dunning_log_likelihood_2(a, b, c, d) + if (debug("dunning")) + errprint("Comparing log-likelihood for %g %g %g %g = %g vs. 
%g", + a, b, c, d, val1, val2) + val1 + } + + def dunning_log_likelihood_2x2(item: Item, other_b: WordDist, + other_c: WordDist, other_d: WordDist) = { + val a = model.get_item(item).toDouble + // This cast is kind of ugly but I don't see a way around it. + val b = other_b.model.get_item(item.asInstanceOf[other_b.Item]).toDouble + val c = other_c.model.get_item(item.asInstanceOf[other_c.Item]).toDouble + val d = other_d.model.get_item(item.asInstanceOf[other_d.Item]).toDouble + val val1 = dunning_log_likelihood_1(a, b, c, d) + val val2 = dunning_log_likelihood_2(a, b, c, d) + if (debug("dunning")) + errprint("Comparing log-likelihood for %g %g %g %g = %g vs. %g", + a, b, c, d, val1, val2) + val2 + } + + /** + * Actual implementation of `kl_divergence` by subclasses. + * External callers should use `kl_divergence`. + */ + protected def imp_kl_divergence(cache: KLDivergenceCache, + other: WordDist, partial: Boolean): Double + + /** + * Compute the KL-divergence between this distribution and another + * distribution. + * + * @param cache Cached information about this distribution, used to + * speed up computations. Never needs to be supplied; null can always + * be given, to cause a new cache to be created. + * @param other The other distribution to compute against. + * @param partial If true, only compute the contribution involving words + * that exist in our distribution; otherwise we also have to take + * into account words in the other distribution even if we haven't + * seen them, and often also (esp. in the presence of smoothing) the + * contribution of all other words in the vocabulary. + * + * @return The KL-divergence value. + */ + def kl_divergence(cache: KLDivergenceCache, other: WordDist, + partial: Boolean = false) = { + assert(finished) + assert(other.finished) + imp_kl_divergence(cache, other, partial) + } + + def get_kl_divergence_cache(): KLDivergenceCache = null + + /** + * Compute the symmetric KL-divergence between two distributions by averaging + * the respective one-way KL-divergences in each direction. + * + * @param partial Same as in `kl_divergence`. + */ + def symmetric_kldiv(cache: KLDivergenceCache, other: WordDist, + partial: Boolean = false) = { + 0.5*this.kl_divergence(cache, other, partial) + + 0.5*this.kl_divergence(cache, other, partial) + } + + /** + * Implementation of the cosine similarity between this and another + * distribution, using either unsmoothed or smoothed probabilities. + * + * @param partial Same as in `kl_divergence`. + * @param smoothed If true, use smoothed probabilities, if smoothing exists; + * otherwise, do unsmoothed. + */ + def cosine_similarity(other: WordDist, partial: Boolean = false, + smoothed: Boolean = false): Double + + /** + * For a document described by its distribution 'worddist', return the + * log probability log p(worddist|other worddist) using a Naive Bayes + * algorithm. + * + * @param worddist Distribution of document. 
+ */ + def get_nbayes_logprob(worddist: WordDist): Double +} diff --git a/src/test/scala/opennlp/fieldspring/topo/Coordinate.scala b/src/test/scala/opennlp/fieldspring/topo/Coordinate.scala new file mode 100644 index 0000000..586c2af --- /dev/null +++ b/src/test/scala/opennlp/fieldspring/topo/Coordinate.scala @@ -0,0 +1,50 @@ +/////////////////////////////////////////////////////////////////////////////// +// Copyright (C) 2010 Travis Brown, The University of Texas at Austin +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +/////////////////////////////////////////////////////////////////////////////// +package opennlp.fieldspring.topo + +import org.specs._ +import org.specs.runner._ + +class CoordinateTest extends JUnit4(CoordinateSpec) +object CoordinateSpec extends Specification { + + "A degree-constructed coordinate" should { + val coordinate = Coordinate.fromDegrees(45, -45) + "have the correct radian value for latitude" in { + coordinate.getLat must_== math.Pi / 4 + } + + "have the correct radian value for longitude" in { + coordinate.getLng must_== -math.Pi / 4 + } + + "be equal to its radian-constructed equivalent" in { + coordinate must_== Coordinate.fromRadians(math.Pi / 4, -math.Pi / 4) + } + } + + "A coordinate at the origin" should { + val coordinate = Coordinate.fromDegrees(0, 0) + "have the correct angular distance from a coordinate 1 radian away horizontally" in { + coordinate.distance(Coordinate.fromRadians(0, 1)) must_== 1 + } + + "have the correct distance from a coordinate 1 radian away vertically" in { + coordinate.distance(Coordinate.fromRadians(1, 0)) must_== 1 + } + } +} + diff --git a/src/test/scala/testparse.scala b/src/test/scala/testparse.scala new file mode 100644 index 0000000..e0b51ff --- /dev/null +++ b/src/test/scala/testparse.scala @@ -0,0 +1,138 @@ +import scala.util.parsing.combinator.lexical.StdLexical +import scala.util.parsing.combinator.syntactical._ +import scala.util.parsing.input.CharArrayReader.EofCh + +sealed abstract class Expr { + def matches(x: Seq[String]): Boolean +} + +case class EConst(value: Seq[String]) extends Expr { + def matches(x: Seq[String]) = x containsSlice value +} + +case class EAnd(left:Expr, right:Expr) extends Expr { + def matches(x: Seq[String]) = left.matches(x) && right.matches(x) +} + +case class EOr(left:Expr, right:Expr) extends Expr { + def matches(x: Seq[String]) = left.matches(x) || right.matches(x) +} + +case class ENot(e:Expr) extends Expr { + def matches(x: Seq[String]) = !e.matches(x) +} + +class ExprLexical extends StdLexical { + override def token: Parser[Token] = floatingToken | super.token + + def floatingToken: Parser[Token] = + rep1(digit) ~ optFraction ~ optExponent ^^ + { case intPart ~ frac ~ exp => NumericLit( + (intPart mkString "") :: frac :: exp :: Nil mkString "")} + + def chr(c:Char) = elem("", ch => ch==c ) + def sign = chr('+') | chr('-') + def optSign = opt(sign) ^^ { + case None => "" + case Some(sign) => sign + } + def fraction = '.' 
~ rep(digit) ^^ { + case dot ~ ff => dot :: (ff mkString "") :: Nil mkString "" + } + def optFraction = opt(fraction) ^^ { + case None => "" + case Some(fraction) => fraction + } + def exponent = (chr('e') | chr('E')) ~ optSign ~ rep1(digit) ^^ { + case e ~ optSign ~ exp => e :: optSign :: (exp mkString "") :: Nil mkString "" + } + def optExponent = opt(exponent) ^^ { + case None => "" + case Some(exponent) => exponent + } +} + +class FilterLexical extends StdLexical { + // see `token` in `Scanners` + override def token: Parser[Token] = + ( delim + | unquotedWordChar ~ rep( unquotedWordChar ) ^^ + { case first ~ rest => processIdent(first :: rest mkString "") } + | '\"' ~ rep( quotedWordChar ) ~ '\"' ^^ + { case '\"' ~ chars ~ '\"' => StringLit(chars mkString "") } + | EofCh ^^^ EOF + | '\"' ~> failure("unclosed string literal") + | failure("illegal character") + ) + + def isPrintable(ch: Char) = + !ch.isControl && !ch.isSpaceChar && !ch.isWhitespace && ch != EofCh + def isPrintableNonDelim(ch: Char) = + isPrintable(ch) && ch != '(' && ch != ')' + def unquotedWordChar = elem("unquoted word char", + ch => ch != '"' && isPrintableNonDelim(ch)) + def quotedWordChar = elem("quoted word char", + ch => ch != '"' && ch != '\n' && ch != EofCh) + + // // see `whitespace in `Scanners` + // def whitespace: Parser[Any] = rep( + // whitespaceChar + // // | '/' ~ '/' ~ rep( chrExcept(EofCh, '\n') ) + // ) + + override protected def processIdent(name: String) = + if (reserved contains name) Keyword(name) else StringLit(name) +} + +object FilterParser extends StandardTokenParsers { + override val lexical = new FilterLexical + lexical.reserved ++= List("AND", "OR", "NOT") + // lexical.delimiters ++= List("&","|","!","(",")") + lexical.delimiters ++= List("(",")") + + def word = stringLit ^^ { s => EConst(Seq(s)) } + + def words = + word.+ ^^ { x => EConst(x.flatMap(_ match { case EConst(y) => y })) } + + + def parens: Parser[Expr] = "(" ~> expr <~ ")" + + def not: Parser[ENot] = "NOT" ~> term ^^ { ENot(_) } + + def term = ( words | parens | not ) + + def andexpr = term * ( + "AND" ^^^ { (a:Expr, b:Expr) => EAnd(a,b) } ) + + def orexpr = andexpr * ( + "OR" ^^^ { (a:Expr, b:Expr) => EOr(a,b) } ) + + def expr = ( orexpr | term ) + + def parse(s :String) = { + val tokens = new lexical.Scanner(s) + phrase(expr)(tokens) + } + + def apply(s: String): Expr = { + parse(s) match { + case Success(tree, _) => tree + case e: NoSuccess => + throw new IllegalArgumentException("Bad syntax: "+s) + } + } + + def test(exprstr: String, tweet: Seq[String]) = { + parse(exprstr) match { + case Success(tree, _) => + println("Tree: "+tree) + val v = tree.matches(tweet) + println("Eval: "+v) + case e: NoSuccess => Console.err.println(e) + } + } + + //A main method for testing + def main(args: Array[String]) = test(args(0), args(1).split("""\s""")) +}
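The filter grammar above bottoms out in Expr.matches, which is evaluated against a token sequence: adjacent unquoted words form a phrase matched with containsSlice, and AND/OR/NOT combine sub-expressions. A minimal usage sketch follows (FilterParserDemo, the query strings and the token sequence are invented for illustration; only FilterParser, Expr.matches and FilterParser.test come from the patch above):

// Hypothetical driver for the parser defined in testparse.scala.
object FilterParserDemo extends App {
  val tweet = Seq("severe", "flooding", "in", "austin", "texas")

  // FilterParser.apply turns a query string into an Expr tree
  // (it throws IllegalArgumentException on bad syntax).
  val expr = FilterParser("flooding AND (austin OR houston) AND NOT hoax")
  println(expr)                // EAnd(EAnd(EConst(...), EOr(...)), ENot(EConst(...)))
  println(expr.matches(tweet)) // true: "flooding" and "austin" occur, "hoax" does not

  // Adjacent unquoted words form a phrase that must appear as a contiguous slice.
  FilterParser.test("severe flooding OR drought", tweet) // prints the tree and "Eval: true"
}

Note that a quoted string is lexed as a single token (spaces included), so a quoted phrase only matches when the input sequence itself contains that exact multi-word token; against whitespace-tokenized input, an unquoted word sequence is the way to express a phrase.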

+ * A result must be within 2 ulps of the correctly rounded result. Results + * must be semi-monotonic. + * + * @param y - the ordinate coordinate + * @param x - the abscissa coordinate + * @return the theta component of the point (r, theta) in polar + * coordinates that corresponds to the point (x, y) in Cartesian coordinates. + */ + public static double atan2(double y, double x) { + if (Double.isNaN(y) || Double.isNaN(x)) { + return Double.NaN; + } else if (Double.isInfinite(y)) { + if (y > 0.0) { + if (Double.isInfinite(x)) { + if (x > 0.0) { + return Math.PI / 4.0; + } else { + return 3.0 * Math.PI / 4.0; + } + } else if (x != 0.0) { + return Math.PI / 2.0; + } + } + else { + if (Double.isInfinite(x)) { + if (x > 0.0) { + return -Math.PI / 4.0; + } else { + return -3.0 * Math.PI / 4.0; + } + } else if (x != 0.0) { + return -Math.PI / 2.0; + } + } + } else if (y == 0.0) { + if (x > 0.0) { + return y; + } else if (x < 0.0) { + return Math.PI; + } + } else if (Double.isInfinite(x)) { + if (x > 0.0) { + if (y > 0.0) { + return 0.0; + } else if (y < 0.0) { + return -0.0; + } + } + else { + if (y > 0.0) { + return Math.PI; + } else if (y < 0.0) { + return -Math.PI; + } + } + } else if (x == 0.0) { + if (y > 0.0) { + return Math.PI / 2.0; + } else if (y < 0.0) { + return -Math.PI / 2.0; + } + } + + // Implementation a simple version ported from a PASCAL implementation at + // . + double arcTangent; + + // Use arctan() avoiding division by zero. + if (Math.abs(x) > Math.abs(y)) { + arcTangent = FastTrig.atan(y / x); + } else { + arcTangent = FastTrig.atan(x / y); // -PI/4 <= a <= PI/4. + + if (arcTangent < 0.0) { + // a is negative, so we're adding. + arcTangent = -Math.PI / 2 - arcTangent; + } else { + arcTangent = Math.PI / 2 - arcTangent; + } + } + + // Adjust result to be from [-PI, PI] + if (x < 0.0) { + if (y < 0.0) { + arcTangent = arcTangent - Math.PI; + } else { + arcTangent = arcTangent + Math.PI; + } + } + + return arcTangent; + } +} + diff --git a/src/main/java/opennlp/fieldspring/tr/util/IOUtil.java b/src/main/java/opennlp/fieldspring/tr/util/IOUtil.java new file mode 100644 index 0000000..dcae2dd --- /dev/null +++ b/src/main/java/opennlp/fieldspring/tr/util/IOUtil.java @@ -0,0 +1,112 @@ +/////////////////////////////////////////////////////////////////////////////// +// Copyright (C) 2007 Jason Baldridge, The University of Texas at Austin +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +/////////////////////////////////////////////////////////////////////////////// +package opennlp.fieldspring.tr.util; + +import java.io.*; + +/** + * Handy methods for interacting with the file system and running commands. + * + * @author Jason Baldridge + * @created April 15, 2004 + */ +public class IOUtil { + + /** + * Write a string to a specified file. 
+ * + * @param contents The string containing the contents of the file + * @param outfile The File object identifying the location of the file + */ + public static void writeStringToFile(String contents, File outfile) { + try { + BufferedWriter bw = new BufferedWriter(new FileWriter(outfile)); + bw.write(contents); + bw.flush(); + bw.close(); + } catch (IOException ioe) { + System.out.println("Input error writing to " + outfile.getName()); + } + return; + } + + /** + * Reads a file as a single String all in one. Based on code snippet + * found on web (http://snippets.dzone.com/posts/show/1335). Reads raw bytes; + * doesn't do any conversion (e.g. using UTF-8 or whatever). + * + * @param infile The File object identifying the location of the file + */ + public static String readFileAsString(File infile) throws java.io.IOException { + byte[] buffer = new byte[(int) infile.length()]; + FileInputStream f = new FileInputStream(infile); + f.read(buffer); + return new String(buffer); + } + + /** + * Calls runCommand/2 assuming that wait=true. + * + * @param cmd The string containing the command to execute + */ + public static void runCommand (String cmd) { + runCommand(cmd, true); + } + + /** + * Run a command with the option of waiting for it to finish. + * + * @param cmd The string containing the command to execute + * @param wait True if the caller should wait for this thread to + * finish before continuing, false otherwise. + */ + public static void runCommand (String cmd, boolean wait) { + try { + System.out.println("Running command: "+ cmd); + Process proc = Runtime.getRuntime().exec(cmd); + + // This needs to be done, otherwise some processes fill up + // some Java buffer and make it so the spawned process + // doesn't complete. + BufferedReader br = + new BufferedReader(new InputStreamReader(proc.getInputStream())); + + String line = null; + while ((line = br.readLine()) != null) { + // while (br.readLine() != null) { + // just eat up the inputstream + + // Use this if you want to see the output from running + // the command. + System.out.println(line); + } + + if (wait) { + try { + proc.waitFor(); + } catch (InterruptedException e) { + Thread.currentThread().interrupt(); + } + } + proc.getInputStream().close(); + proc.getOutputStream().close(); + proc.getErrorStream().close(); + } catch (IOException e) { + System.out.println("Unable to run command: "+cmd); + } + } + +} diff --git a/src/main/java/opennlp/fieldspring/tr/util/KMLUtil.java b/src/main/java/opennlp/fieldspring/tr/util/KMLUtil.java new file mode 100644 index 0000000..e74b7a4 --- /dev/null +++ b/src/main/java/opennlp/fieldspring/tr/util/KMLUtil.java @@ -0,0 +1,387 @@ +/////////////////////////////////////////////////////////////////////////////// +// Copyright (C) 2010 Taesun Moon, The University of Texas at Austin +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+/////////////////////////////////////////////////////////////////////////////// +package opennlp.fieldspring.tr.util; + +import javax.xml.stream.XMLStreamException; +import javax.xml.stream.XMLStreamWriter; + +//import opennlp.fieldspring.old.topostructs.Coordinate; +import opennlp.fieldspring.tr.topo.*; +import java.text.*; + +/** + * Class of static methods for generating KML headers, footers and other stuff. + * + * @author tsmoon + */ +public class KMLUtil { + // Minimum number of pixels the (small) square region (NOT our Region) represented by each city must occupy on the screen for its label to appear: + private final static int MIN_LOD_PIXELS = 16; + + //public static double RADIUS = .15; + //public static double SPIRAL_RADIUS = .2; + public static double RADIUS = .05; + public static double SPIRAL_RADIUS = .05; + public static int SIDES = 10; + public static int BARSCALE = 50000; + + private final static DecimalFormat df = new DecimalFormat("#.####"); + + public static void writeWithCharacters(XMLStreamWriter w, String localName, String text) + throws XMLStreamException { + w.writeStartElement(localName); + w.writeCharacters(text); + w.writeEndElement(); + } + + public static void writeHeader(XMLStreamWriter w, String name) + throws XMLStreamException { + w.writeStartDocument("UTF-8", "1.0"); + w.writeStartElement("kml"); + w.writeDefaultNamespace("http://www.opengis.net/kml/2.2"); + w.writeNamespace("gx", "http://www.google.com/kml/ext/2.2"); + w.writeNamespace("kml", "http://www.opengis.net/kml/2.2"); + w.writeNamespace("atom", "http://www.w3.org/2005/Atom"); + w.writeStartElement("Document"); + + w.writeStartElement("Style"); + w.writeAttribute("id", "bar"); + w.writeStartElement("PolyStyle"); + KMLUtil.writeWithCharacters(w, "outline", "0"); + w.writeEndElement(); // PolyStyle + w.writeStartElement("LineStyle"); + KMLUtil.writeWithCharacters(w, "color", "ff0000ff"); + KMLUtil.writeWithCharacters(w, "width", "3"); + w.writeEndElement(); // LineStyle + w.writeStartElement("IconStyle"); + w.writeEmptyElement("Icon"); + w.writeEndElement(); // IconStyle + w.writeEndElement(); // Style + + w.writeStartElement("Style"); + w.writeAttribute("id", "whiteLine"); + w.writeStartElement("LineStyle"); + KMLUtil.writeWithCharacters(w, "color", "ffffffff"); + KMLUtil.writeWithCharacters(w, "width", "3"); + w.writeEndElement(); // LineStyle + w.writeEndElement(); // Style + + w.writeStartElement("Style"); + w.writeAttribute("id", "redLine"); + w.writeStartElement("LineStyle"); + KMLUtil.writeWithCharacters(w, "color", "ff0000ff"); + KMLUtil.writeWithCharacters(w, "width", "3"); + w.writeEndElement(); // LineStyle + w.writeEndElement(); // Style + + w.writeStartElement("Style"); + w.writeAttribute("id", "blue"); + w.writeStartElement("IconStyle"); + w.writeStartElement("Icon"); + KMLUtil.writeWithCharacters(w, "href", "http://maps.google.com/mapfiles/kml/paddle/blu-blank-lv.png"); + w.writeEndElement(); // Icon + w.writeEndElement(); // IconStyle + w.writeEndElement(); // Style + + w.writeStartElement("Style"); + w.writeAttribute("id", "yellow"); + w.writeStartElement("IconStyle"); + w.writeStartElement("Icon"); + KMLUtil.writeWithCharacters(w, "href", "http://maps.google.com/mapfiles/kml/paddle/ylw-blank-lv.png"); + w.writeEndElement(); // Icon + w.writeEndElement(); // IconStyle + w.writeEndElement(); // Style + + w.writeStartElement("Style"); + w.writeAttribute("id", "green"); + w.writeStartElement("IconStyle"); + w.writeStartElement("Icon"); + KMLUtil.writeWithCharacters(w, "href", 
"http://maps.google.com/mapfiles/kml/paddle/grn-blank-lv.png"); + w.writeEndElement(); // Icon + w.writeEndElement(); // IconStyle + w.writeEndElement(); // Style + + w.writeStartElement("Style"); + w.writeAttribute("id", "white"); + w.writeStartElement("IconStyle"); + w.writeStartElement("Icon"); + KMLUtil.writeWithCharacters(w, "href", "http://maps.google.com/mapfiles/kml/paddle/wht-blank-lv.png"); + w.writeEndElement(); // Icon + w.writeEndElement(); // IconStyle + w.writeEndElement(); // Style + + w.writeStartElement("Style"); + w.writeAttribute("id", "downArrowIcon"); + w.writeStartElement("IconStyle"); + w.writeStartElement("Icon"); + KMLUtil.writeWithCharacters(w, "href", "http://maps.google.com/mapfiles/kml/pal4/icon28.png"); + w.writeEndElement(); // Icon + w.writeEndElement(); // IconStyle + w.writeEndElement(); // Style + + w.writeStartElement("Style"); + w.writeAttribute("id", "smallDownArrowIcon"); + w.writeStartElement("IconStyle"); + KMLUtil.writeWithCharacters(w, "scale", "0.25"); + w.writeStartElement("Icon"); + KMLUtil.writeWithCharacters(w, "href", "http://maps.google.com/mapfiles/kml/pal4/icon28.png"); + w.writeEndElement(); // Icon + w.writeEndElement(); // IconStyle + w.writeEndElement(); // Style + + w.writeStartElement("Style"); + w.writeAttribute("id", "noIcon"); + w.writeStartElement("IconStyle"); + w.writeEmptyElement("Icon"); + w.writeEndElement(); // IconStyle + w.writeEndElement(); // Style + + w.writeStartElement("Style"); + w.writeAttribute("id", "context"); + w.writeStartElement("LabelStyle"); + KMLUtil.writeWithCharacters(w, "scale", "0"); + w.writeEndElement(); // LabelStyle + w.writeEndElement(); // Style + + w.writeStartElement("Folder"); + KMLUtil.writeWithCharacters(w, "name", name); + KMLUtil.writeWithCharacters(w, "open", "1"); + KMLUtil.writeWithCharacters(w, "description", "Distribution of place names found in " + name); + KMLUtil.writeLookAt(w, 42, -102, 0, 5000000, 53.45434856, 0); + } + + public static void writeFooter(XMLStreamWriter w) + throws XMLStreamException { + w.writeEndElement(); // Folder + w.writeEndElement(); // Document + w.writeEndElement(); // kml + } + + public static void writeCoordinate(XMLStreamWriter w, Coordinate coord, + int sides, double radius, double height) + throws XMLStreamException { + final double radianUnit = 2 * Math.PI / sides; + final double startRadian = radianUnit / 2; + + w.writeStartElement("coordinates"); + w.writeCharacters("\n"); + //for (double currentRadian = startRadian; currentRadian <= startRadian + 2 * Math.PI; currentRadian += radianUnit) { + for (double currentRadian = startRadian; currentRadian >= startRadian - 2 * Math.PI; currentRadian -= radianUnit) { + double lat = coord.getLatDegrees() + radius * Math.cos(currentRadian); + double lon = coord.getLngDegrees() + radius * Math.sin(currentRadian); + w.writeCharacters(df.format(lon) + "," + df.format(lat) + "," + df.format(height) + "\n"); + } + + w.writeEndElement(); // coordinates + } + + public static void writeRegion(XMLStreamWriter w, Coordinate coord, double radius) + throws XMLStreamException { + w.writeStartElement("Region"); + w.writeStartElement("LatLonAltBox"); + KMLUtil.writeWithCharacters(w, "north", df.format(coord.getLatDegrees() + radius)); + KMLUtil.writeWithCharacters(w, "south", df.format(coord.getLatDegrees() - radius)); + KMLUtil.writeWithCharacters(w, "east", df.format(coord.getLngDegrees() + radius)); + KMLUtil.writeWithCharacters(w, "west", df.format(coord.getLngDegrees() - radius)); + w.writeEndElement(); // LatLonAltBox 
+ w.writeStartElement("Lod"); + KMLUtil.writeWithCharacters(w, "minLodPixels", Integer.toString(KMLUtil.MIN_LOD_PIXELS)); + w.writeEndElement(); // Lod + w.writeEndElement(); // Region + } + + public static void writeLookAt(XMLStreamWriter w, double lat, double lon, + double alt, double range, double tilt, double heading) + throws XMLStreamException { + w.writeStartElement("LookAt"); + KMLUtil.writeWithCharacters(w, "latitude", "" + lat); + KMLUtil.writeWithCharacters(w, "longitude", "" + lon); + KMLUtil.writeWithCharacters(w, "altitude", "" + alt); + KMLUtil.writeWithCharacters(w, "range", "" + range); + KMLUtil.writeWithCharacters(w, "tilt", "" + tilt); + KMLUtil.writeWithCharacters(w, "heading", "" + heading); + w.writeEndElement(); // LookAt + } + + public static void writePlacemark(XMLStreamWriter w, String name, + Coordinate coord, double radius) throws XMLStreamException { + w.writeStartElement("Placemark"); + KMLUtil.writeWithCharacters(w, "name", name); + KMLUtil.writeRegion(w, coord, radius); + KMLUtil.writeWithCharacters(w, "styleUrl", "#bar"); + w.writeStartElement("Point"); + KMLUtil.writeWithCharacters(w, "coordinates", df.format(coord.getLngDegrees()) + "," + df.format(coord.getLatDegrees())); + w.writeEndElement(); // Point + w.writeEndElement(); // Placemark + } + + public static void writePinPlacemark(XMLStreamWriter w, String name, + Coordinate coord) throws XMLStreamException { + writePinPlacemark(w, name, coord, null); + } + + public static void writePinTimeStampPlacemark(XMLStreamWriter w, String name, + Coordinate coord, String context, + int timeIndex) throws XMLStreamException { + w.writeStartElement("Placemark"); + KMLUtil.writeWithCharacters(w, "name", name); + KMLUtil.writeWithCharacters(w, "description", context); + w.writeStartElement("TimeSpan"); + KMLUtil.writeWithCharacters(w, "begin", timeIndex + ""); + KMLUtil.writeWithCharacters(w, "end", (timeIndex + 5) + ""); + w.writeEndElement(); // TimeSpan + w.writeStartElement("Point"); + KMLUtil.writeWithCharacters(w, "coordinates", df.format(coord.getLngDegrees()) + "," + df.format(coord.getLatDegrees())); + w.writeEndElement(); // Point + w.writeEndElement(); // Placemark + } + + public static void writePinPlacemark(XMLStreamWriter w, String name, + Coordinate coord, String styleUrl) throws XMLStreamException { + w.writeStartElement("Placemark"); + KMLUtil.writeWithCharacters(w, "name", name); + //KMLUtil.writeRegion(w, coord, radius); + if(styleUrl != null && styleUrl.length() > 0) + KMLUtil.writeWithCharacters(w, "styleUrl", "#"+styleUrl); + w.writeStartElement("Point"); + KMLUtil.writeWithCharacters(w, "coordinates", df.format(coord.getLngDegrees()) + "," + df.format(coord.getLatDegrees())); + w.writeEndElement(); // Point + w.writeEndElement(); // Placemark + } + + public static void writeLinePlacemark(XMLStreamWriter w, Coordinate coord1, Coordinate coord2) + throws XMLStreamException { + writeLinePlacemark(w, coord1, coord2, "whiteLine"); + } + + public static void writeLinePlacemark(XMLStreamWriter w, Coordinate coord1, Coordinate coord2, String styleUrl) + throws XMLStreamException { + w.writeStartElement("Placemark"); + KMLUtil.writeWithCharacters(w, "styleUrl", "#"+styleUrl); + w.writeStartElement("LineString"); + KMLUtil.writeWithCharacters(w, "gx:altitudeOffset", "0"); + KMLUtil.writeWithCharacters(w, "extrude", "1"); + KMLUtil.writeWithCharacters(w, "tessellate", "1"); + KMLUtil.writeWithCharacters(w, "altitudeMode", "clampToGround"); + KMLUtil.writeWithCharacters(w, "gx:drawOrder", "0"); + 
w.writeStartElement("coordinates"); + w.writeCharacters(df.format(coord1.getLngDegrees())+","+df.format(coord1.getLatDegrees())+",0\n"); + w.writeCharacters(df.format(coord2.getLngDegrees())+","+df.format(coord2.getLatDegrees())+",0\n"); + w.writeEndElement(); // coordinates + w.writeEndElement(); // LineString + w.writeEndElement(); // Placemark + } + + public static void writeArcLinePlacemark(XMLStreamWriter w, Coordinate coord1, Coordinate coord2) + throws XMLStreamException { + w.writeStartElement("Placemark"); + KMLUtil.writeWithCharacters(w, "styleUrl", "#bar"); + w.writeStartElement("LineString"); + KMLUtil.writeWithCharacters(w, "gx:altitudeOffset", "0"); + KMLUtil.writeWithCharacters(w, "extrude", "0"); + KMLUtil.writeWithCharacters(w, "tessellate", "1"); + KMLUtil.writeWithCharacters(w, "altitudeMode", "relativeToGround"); + KMLUtil.writeWithCharacters(w, "gx:drawOrder", "0"); + w.writeStartElement("coordinates"); + double dist = coord1.distanceInKm(coord2); + double lngDiff = coord2.getLngDegrees() - coord1.getLngDegrees(); + double latDiff = coord2.getLatDegrees() - coord1.getLatDegrees(); + w.writeCharacters(df.format(coord1.getLngDegrees())+","+df.format(coord1.getLatDegrees())+",0\n"); + for(double i = .1; i <= .9; i += .1) { + w.writeCharacters(df.format(coord1.getLngDegrees()+lngDiff*i)+"," + +df.format(coord1.getLatDegrees()+latDiff*i)+"," + +(dist*dist)*.05/**(.5-Math.abs(.5-i))*/+"\n"); + } + //w.writeCharacters(df.format(coord1.getLngDegrees()+lngDiff*.9)+"," + // +df.format(coord1.getLatDegrees()+latDiff*.9)+","+(dist*20)+"\n"); + w.writeCharacters(df.format(coord2.getLngDegrees())+","+df.format(coord2.getLatDegrees())+",0"); + w.writeEndElement(); // coordinates + w.writeEndElement(); // LineString + w.writeEndElement(); // Placemark + } + + public static void writePolygon(XMLStreamWriter w, String name, + Coordinate coord, int sides, double radius, double height) + throws XMLStreamException { + /*w.writeStartElement("Placemark"); + KMLUtil.writeWithCharacters(w, "name", name); + KMLUtil.writeRegion(w, coord, radius); + KMLUtil.writeWithCharacters(w, "styleUrl", "#bar"); + w.writeStartElement("Point"); + KMLUtil.writeWithCharacters(w, "coordinates", df.format(coord.getLngDegrees()) + "," + df.format(coord.getLatDegrees())); + w.writeEndElement(); // Point + w.writeEndElement(); // Placemark*/ + + writePlacemark(w, name, coord, radius); + + w.writeStartElement("Placemark"); + KMLUtil.writeWithCharacters(w, "name", name + " POLYGON"); + KMLUtil.writeWithCharacters(w, "styleUrl", "#bar"); + w.writeStartElement("Style"); + w.writeStartElement("PolyStyle"); + KMLUtil.writeWithCharacters(w, "color", "dc0155ff"); + w.writeEndElement(); // PolyStyle + w.writeEndElement(); // Style + w.writeStartElement("Polygon"); + KMLUtil.writeWithCharacters(w, "extrude", "1"); + KMLUtil.writeWithCharacters(w, "tessellate", "1"); + KMLUtil.writeWithCharacters(w, "altitudeMode", "relativeToGround"); + w.writeStartElement("outerBoundaryIs"); + w.writeStartElement("LinearRing"); + KMLUtil.writeCoordinate(w, coord, sides, radius, height); + w.writeEndElement(); // LinearRing + w.writeEndElement(); // outerBoundaryIs + w.writeEndElement(); // Polygon + w.writeEndElement(); // Placemark + } + + public static void writeSpiralPoint(XMLStreamWriter w, String name, int id, + String context, Coordinate coord, double radius) + throws XMLStreamException { + w.writeStartElement("Placemark"); + KMLUtil.writeWithCharacters(w, "name", String.format("%s #%d", name, id + 1)); + KMLUtil.writeWithCharacters(w, 
"description", context); + KMLUtil.writeRegion(w, coord, radius); + KMLUtil.writeWithCharacters(w, "styleUrl", "#context"); + w.writeStartElement("Point"); + KMLUtil.writeWithCharacters(w, "coordinates", df.format(coord.getLngDegrees()) + "," + df.format(coord.getLatDegrees())); + w.writeEndElement(); // Point + w.writeEndElement(); // Placemark + } + + public static void writeFloatingPlacemark(XMLStreamWriter w, String name, + Coordinate coord, double height) + throws XMLStreamException { + KMLUtil.writeFloatingPlacemark(w, name, coord, height, "#smallDownArrowIcon"); + } + + public static void writeFloatingPlacemark(XMLStreamWriter w, String name, + Coordinate coord, double height, + String styleUrl) + throws XMLStreamException { + w.writeStartElement("Placemark"); + KMLUtil.writeWithCharacters(w, "name", name); + KMLUtil.writeWithCharacters(w, "visibility", "1"); + KMLUtil.writeLookAt(w, coord.getLatDegrees(), coord.getLngDegrees(), height, 500.656, 40.557, -148.412); + KMLUtil.writeWithCharacters(w, "styleUrl", styleUrl); + w.writeStartElement("Point"); + KMLUtil.writeWithCharacters(w, "altitudeMode", "relativeToGround"); + KMLUtil.writeWithCharacters(w, "coordinates", + df.format(coord.getLngDegrees()) + "," + df.format(coord.getLatDegrees()) + "," + df.format(height)); + w.writeEndElement(); // Point + w.writeEndElement(); // Placemark + } +} diff --git a/src/main/java/opennlp/fieldspring/tr/util/KMLUtil.java~ b/src/main/java/opennlp/fieldspring/tr/util/KMLUtil.java~ new file mode 100644 index 0000000..b32db9f --- /dev/null +++ b/src/main/java/opennlp/fieldspring/tr/util/KMLUtil.java~ @@ -0,0 +1,362 @@ +/////////////////////////////////////////////////////////////////////////////// +// Copyright (C) 2010 Taesun Moon, The University of Texas at Austin +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +/////////////////////////////////////////////////////////////////////////////// +package opennlp.fieldspring.util; + +import javax.xml.stream.XMLStreamException; +import javax.xml.stream.XMLStreamWriter; + +//import opennlp.fieldspring.old.topostructs.Coordinate; +import opennlp.fieldspring.tr.topo.*; +import java.text.*; + +/** + * Class of static methods for generating KML headers, footers and other stuff. 
+ * + * @author tsmoon + */ +public class KMLUtil { + // Minimum number of pixels the (small) square region (NOT our Region) represented by each city must occupy on the screen for its label to appear: + private final static int MIN_LOD_PIXELS = 16; + + //public static double RADIUS = .15; + //public static double SPIRAL_RADIUS = .2; + public static double RADIUS = .05; + public static double SPIRAL_RADIUS = .05; + public static int SIDES = 10; + public static int BARSCALE = 50000; + + private final static DecimalFormat df = new DecimalFormat("#.####"); + + public static void writeWithCharacters(XMLStreamWriter w, String localName, String text) + throws XMLStreamException { + w.writeStartElement(localName); + w.writeCharacters(text); + w.writeEndElement(); + } + + public static void writeHeader(XMLStreamWriter w, String name) + throws XMLStreamException { + w.writeStartDocument("UTF-8", "1.0"); + w.writeStartElement("kml"); + w.writeDefaultNamespace("http://www.opengis.net/kml/2.2"); + w.writeNamespace("gx", "http://www.google.com/kml/ext/2.2"); + w.writeNamespace("kml", "http://www.opengis.net/kml/2.2"); + w.writeNamespace("atom", "http://www.w3.org/2005/Atom"); + w.writeStartElement("Document"); + + w.writeStartElement("Style"); + w.writeAttribute("id", "bar"); + w.writeStartElement("PolyStyle"); + KMLUtil.writeWithCharacters(w, "outline", "0"); + w.writeEndElement(); // PolyStyle + w.writeStartElement("LineStyle"); + KMLUtil.writeWithCharacters(w, "color", "ff0000ff"); + KMLUtil.writeWithCharacters(w, "width", "3"); + w.writeEndElement(); // LineStyle + w.writeStartElement("IconStyle"); + w.writeEmptyElement("Icon"); + w.writeEndElement(); // IconStyle + w.writeEndElement(); // Style + + w.writeStartElement("Style"); + w.writeAttribute("id", "whiteLine"); + w.writeStartElement("LineStyle"); + KMLUtil.writeWithCharacters(w, "color", "ffffffff"); + KMLUtil.writeWithCharacters(w, "width", "3"); + w.writeEndElement(); // LineStyle + w.writeEndElement(); // Style + + w.writeStartElement("Style"); + w.writeAttribute("id", "redLine"); + w.writeStartElement("LineStyle"); + KMLUtil.writeWithCharacters(w, "color", "ff0000ff"); + KMLUtil.writeWithCharacters(w, "width", "3"); + w.writeEndElement(); // LineStyle + w.writeEndElement(); // Style + + w.writeStartElement("Style"); + w.writeAttribute("id", "blue"); + w.writeStartElement("IconStyle"); + w.writeStartElement("Icon"); + KMLUtil.writeWithCharacters(w, "href", "http://maps.google.com/mapfiles/kml/paddle/blu-blank-lv.png"); + w.writeEndElement(); // Icon + w.writeEndElement(); // IconStyle + w.writeEndElement(); // Style + + w.writeStartElement("Style"); + w.writeAttribute("id", "yellow"); + w.writeStartElement("IconStyle"); + w.writeStartElement("Icon"); + KMLUtil.writeWithCharacters(w, "href", "http://maps.google.com/mapfiles/kml/paddle/ylw-blank-lv.png"); + w.writeEndElement(); // Icon + w.writeEndElement(); // IconStyle + w.writeEndElement(); // Style + + w.writeStartElement("Style"); + w.writeAttribute("id", "green"); + w.writeStartElement("IconStyle"); + w.writeStartElement("Icon"); + KMLUtil.writeWithCharacters(w, "href", "http://maps.google.com/mapfiles/kml/paddle/grn-blank-lv.png"); + w.writeEndElement(); // Icon + w.writeEndElement(); // IconStyle + w.writeEndElement(); // Style + + w.writeStartElement("Style"); + w.writeAttribute("id", "downArrowIcon"); + w.writeStartElement("IconStyle"); + w.writeStartElement("Icon"); + KMLUtil.writeWithCharacters(w, "href", "http://maps.google.com/mapfiles/kml/pal4/icon28.png"); + 
w.writeEndElement(); // Icon + w.writeEndElement(); // IconStyle + w.writeEndElement(); // Style + + w.writeStartElement("Style"); + w.writeAttribute("id", "smallDownArrowIcon"); + w.writeStartElement("IconStyle"); + KMLUtil.writeWithCharacters(w, "scale", "0.25"); + w.writeStartElement("Icon"); + KMLUtil.writeWithCharacters(w, "href", "http://maps.google.com/mapfiles/kml/pal4/icon28.png"); + w.writeEndElement(); // Icon + w.writeEndElement(); // IconStyle + w.writeEndElement(); // Style + + w.writeStartElement("Style"); + w.writeAttribute("id", "noIcon"); + w.writeStartElement("IconStyle"); + w.writeEmptyElement("Icon"); + w.writeEndElement(); // IconStyle + w.writeEndElement(); // Style + + w.writeStartElement("Style"); + w.writeAttribute("id", "context"); + w.writeStartElement("LabelStyle"); + KMLUtil.writeWithCharacters(w, "scale", "0"); + w.writeEndElement(); // LabelStyle + w.writeEndElement(); // Style + + w.writeStartElement("Folder"); + KMLUtil.writeWithCharacters(w, "name", name); + KMLUtil.writeWithCharacters(w, "open", "1"); + KMLUtil.writeWithCharacters(w, "description", "Distribution of place names found in " + name); + KMLUtil.writeLookAt(w, 42, -102, 0, 5000000, 53.45434856, 0); + } + + public static void writeFooter(XMLStreamWriter w) + throws XMLStreamException { + w.writeEndElement(); // Folder + w.writeEndElement(); // Document + w.writeEndElement(); // kml + } + + public static void writeCoordinate(XMLStreamWriter w, Coordinate coord, + int sides, double radius, double height) + throws XMLStreamException { + final double radianUnit = 2 * Math.PI / sides; + final double startRadian = radianUnit / 2; + + w.writeStartElement("coordinates"); + w.writeCharacters("\n"); + //for (double currentRadian = startRadian; currentRadian <= startRadian + 2 * Math.PI; currentRadian += radianUnit) { + for (double currentRadian = startRadian; currentRadian >= startRadian - 2 * Math.PI; currentRadian -= radianUnit) { + double lat = coord.getLatDegrees() + radius * Math.cos(currentRadian); + double lon = coord.getLngDegrees() + radius * Math.sin(currentRadian); + w.writeCharacters(df.format(lon) + "," + df.format(lat) + "," + df.format(height) + "\n"); + } + + w.writeEndElement(); // coordinates + } + + public static void writeRegion(XMLStreamWriter w, Coordinate coord, double radius) + throws XMLStreamException { + w.writeStartElement("Region"); + w.writeStartElement("LatLonAltBox"); + KMLUtil.writeWithCharacters(w, "north", df.format(coord.getLatDegrees() + radius)); + KMLUtil.writeWithCharacters(w, "south", df.format(coord.getLatDegrees() - radius)); + KMLUtil.writeWithCharacters(w, "east", df.format(coord.getLngDegrees() + radius)); + KMLUtil.writeWithCharacters(w, "west", df.format(coord.getLngDegrees() - radius)); + w.writeEndElement(); // LatLonAltBox + w.writeStartElement("Lod"); + KMLUtil.writeWithCharacters(w, "minLodPixels", Integer.toString(KMLUtil.MIN_LOD_PIXELS)); + w.writeEndElement(); // Lod + w.writeEndElement(); // Region + } + + public static void writeLookAt(XMLStreamWriter w, double lat, double lon, + double alt, double range, double tilt, double heading) + throws XMLStreamException { + w.writeStartElement("LookAt"); + KMLUtil.writeWithCharacters(w, "latitude", "" + lat); + KMLUtil.writeWithCharacters(w, "longitude", "" + lon); + KMLUtil.writeWithCharacters(w, "altitude", "" + alt); + KMLUtil.writeWithCharacters(w, "range", "" + range); + KMLUtil.writeWithCharacters(w, "tilt", "" + tilt); + KMLUtil.writeWithCharacters(w, "heading", "" + heading); + w.writeEndElement(); 
// LookAt + } + + public static void writePlacemark(XMLStreamWriter w, String name, + Coordinate coord, double radius) throws XMLStreamException { + w.writeStartElement("Placemark"); + KMLUtil.writeWithCharacters(w, "name", name); + KMLUtil.writeRegion(w, coord, radius); + KMLUtil.writeWithCharacters(w, "styleUrl", "#bar"); + w.writeStartElement("Point"); + KMLUtil.writeWithCharacters(w, "coordinates", df.format(coord.getLngDegrees()) + "," + df.format(coord.getLatDegrees())); + w.writeEndElement(); // Point + w.writeEndElement(); // Placemark + } + + public static void writePinPlacemark(XMLStreamWriter w, String name, + Coordinate coord) throws XMLStreamException { + writePinPlacemark(w, name, coord, null); + } + + public static void writePinPlacemark(XMLStreamWriter w, String name, + Coordinate coord, String styleUrl) throws XMLStreamException { + w.writeStartElement("Placemark"); + //KMLUtil.writeWithCharacters(w, "name", name); + //KMLUtil.writeRegion(w, coord, radius); + if(styleUrl != null && styleUrl.length() > 0) + KMLUtil.writeWithCharacters(w, "styleUrl", "#"+styleUrl); + w.writeStartElement("Point"); + KMLUtil.writeWithCharacters(w, "coordinates", df.format(coord.getLngDegrees()) + "," + df.format(coord.getLatDegrees())); + w.writeEndElement(); // Point + w.writeEndElement(); // Placemark + } + + public static void writeLinePlacemark(XMLStreamWriter w, Coordinate coord1, Coordinate coord2) + throws XMLStreamException { + writeLinePlacemark(w, coord1, coord2, "whiteLine"); + } + + public static void writeLinePlacemark(XMLStreamWriter w, Coordinate coord1, Coordinate coord2, String styleUrl) + throws XMLStreamException { + w.writeStartElement("Placemark"); + KMLUtil.writeWithCharacters(w, "styleUrl", "#"+styleUrl); + w.writeStartElement("LineString"); + KMLUtil.writeWithCharacters(w, "gx:altitudeOffset", "0"); + KMLUtil.writeWithCharacters(w, "extrude", "1"); + KMLUtil.writeWithCharacters(w, "tessellate", "1"); + KMLUtil.writeWithCharacters(w, "altitudeMode", "clampToGround"); + KMLUtil.writeWithCharacters(w, "gx:drawOrder", "0"); + w.writeStartElement("coordinates"); + w.writeCharacters(df.format(coord1.getLngDegrees())+","+df.format(coord1.getLatDegrees())+",0\n"); + w.writeCharacters(df.format(coord2.getLngDegrees())+","+df.format(coord2.getLatDegrees())+",0\n"); + w.writeEndElement(); // coordinates + w.writeEndElement(); // LineString + w.writeEndElement(); // Placemark + } + + public static void writeArcLinePlacemark(XMLStreamWriter w, Coordinate coord1, Coordinate coord2) + throws XMLStreamException { + w.writeStartElement("Placemark"); + KMLUtil.writeWithCharacters(w, "styleUrl", "#bar"); + w.writeStartElement("LineString"); + KMLUtil.writeWithCharacters(w, "gx:altitudeOffset", "0"); + KMLUtil.writeWithCharacters(w, "extrude", "0"); + KMLUtil.writeWithCharacters(w, "tessellate", "1"); + KMLUtil.writeWithCharacters(w, "altitudeMode", "relativeToGround"); + KMLUtil.writeWithCharacters(w, "gx:drawOrder", "0"); + w.writeStartElement("coordinates"); + double dist = coord1.distanceInKm(coord2); + double lngDiff = coord2.getLngDegrees() - coord1.getLngDegrees(); + double latDiff = coord2.getLatDegrees() - coord1.getLatDegrees(); + w.writeCharacters(df.format(coord1.getLngDegrees())+","+df.format(coord1.getLatDegrees())+",0\n"); + for(double i = .1; i <= .9; i += .1) { + w.writeCharacters(df.format(coord1.getLngDegrees()+lngDiff*i)+"," + +df.format(coord1.getLatDegrees()+latDiff*i)+"," + +(dist*dist)*.05/**(.5-Math.abs(.5-i))*/+"\n"); + } + 
//w.writeCharacters(df.format(coord1.getLngDegrees()+lngDiff*.9)+"," + // +df.format(coord1.getLatDegrees()+latDiff*.9)+","+(dist*20)+"\n"); + w.writeCharacters(df.format(coord2.getLngDegrees())+","+df.format(coord2.getLatDegrees())+",0"); + w.writeEndElement(); // coordinates + w.writeEndElement(); // LineString + w.writeEndElement(); // Placemark + } + + public static void writePolygon(XMLStreamWriter w, String name, + Coordinate coord, int sides, double radius, double height) + throws XMLStreamException { + /*w.writeStartElement("Placemark"); + KMLUtil.writeWithCharacters(w, "name", name); + KMLUtil.writeRegion(w, coord, radius); + KMLUtil.writeWithCharacters(w, "styleUrl", "#bar"); + w.writeStartElement("Point"); + KMLUtil.writeWithCharacters(w, "coordinates", df.format(coord.getLngDegrees()) + "," + df.format(coord.getLatDegrees())); + w.writeEndElement(); // Point + w.writeEndElement(); // Placemark*/ + + writePlacemark(w, name, coord, radius); + + w.writeStartElement("Placemark"); + KMLUtil.writeWithCharacters(w, "name", name + " POLYGON"); + KMLUtil.writeWithCharacters(w, "styleUrl", "#bar"); + w.writeStartElement("Style"); + w.writeStartElement("PolyStyle"); + KMLUtil.writeWithCharacters(w, "color", "dc0155ff"); + w.writeEndElement(); // PolyStyle + w.writeEndElement(); // Style + w.writeStartElement("Polygon"); + KMLUtil.writeWithCharacters(w, "extrude", "1"); + KMLUtil.writeWithCharacters(w, "tessellate", "1"); + KMLUtil.writeWithCharacters(w, "altitudeMode", "relativeToGround"); + w.writeStartElement("outerBoundaryIs"); + w.writeStartElement("LinearRing"); + KMLUtil.writeCoordinate(w, coord, sides, radius, height); + w.writeEndElement(); // LinearRing + w.writeEndElement(); // outerBoundaryIs + w.writeEndElement(); // Polygon + w.writeEndElement(); // Placemark + } + + public static void writeSpiralPoint(XMLStreamWriter w, String name, int id, + String context, Coordinate coord, double radius) + throws XMLStreamException { + w.writeStartElement("Placemark"); + KMLUtil.writeWithCharacters(w, "name", String.format("%s #%d", name, id + 1)); + KMLUtil.writeWithCharacters(w, "description", context); + KMLUtil.writeRegion(w, coord, radius); + KMLUtil.writeWithCharacters(w, "styleUrl", "#context"); + w.writeStartElement("Point"); + KMLUtil.writeWithCharacters(w, "coordinates", df.format(coord.getLngDegrees()) + "," + df.format(coord.getLatDegrees())); + w.writeEndElement(); // Point + w.writeEndElement(); // Placemark + } + + public static void writeFloatingPlacemark(XMLStreamWriter w, String name, + Coordinate coord, double height) + throws XMLStreamException { + KMLUtil.writeFloatingPlacemark(w, name, coord, height, "#smallDownArrowIcon"); + } + + public static void writeFloatingPlacemark(XMLStreamWriter w, String name, + Coordinate coord, double height, + String styleUrl) + throws XMLStreamException { + w.writeStartElement("Placemark"); + KMLUtil.writeWithCharacters(w, "name", name); + KMLUtil.writeWithCharacters(w, "visibility", "1"); + KMLUtil.writeLookAt(w, coord.getLatDegrees(), coord.getLngDegrees(), height, 500.656, 40.557, -148.412); + KMLUtil.writeWithCharacters(w, "styleUrl", styleUrl); + w.writeStartElement("Point"); + KMLUtil.writeWithCharacters(w, "altitudeMode", "relativeToGround"); + KMLUtil.writeWithCharacters(w, "coordinates", + df.format(coord.getLngDegrees()) + "," + df.format(coord.getLatDegrees()) + "," + df.format(height)); + w.writeEndElement(); // Point + w.writeEndElement(); // Placemark + } +} diff --git 
a/src/main/java/opennlp/fieldspring/tr/util/Lexicon.java b/src/main/java/opennlp/fieldspring/tr/util/Lexicon.java
new file mode 100644
index 0000000..1633374
--- /dev/null
+++ b/src/main/java/opennlp/fieldspring/tr/util/Lexicon.java
@@ -0,0 +1,33 @@
+///////////////////////////////////////////////////////////////////////////////
+// Copyright (C) 2010 Travis Brown, The University of Texas at Austin
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+///////////////////////////////////////////////////////////////////////////////
+package opennlp.fieldspring.tr.util;
+
+import java.io.Serializable;
+import java.util.List;
+
+public interface Lexicon<A>
+  extends Serializable, Iterable<A> {
+  public boolean contains(A entry);
+  public int get(A entry);
+  public int getOrAdd(A entry);
+  public A atIndex(int index);
+  public int size();
+  public boolean isGrowing();
+  public void stopGrowing();
+  public void startGrowing();
+  public List<Integer> concatenate(Lexicon<A> other);
+}
+
diff --git a/src/main/java/opennlp/fieldspring/tr/util/MemoryUtil.java b/src/main/java/opennlp/fieldspring/tr/util/MemoryUtil.java
new file mode 100644
index 0000000..a9fb2d8
--- /dev/null
+++ b/src/main/java/opennlp/fieldspring/tr/util/MemoryUtil.java
@@ -0,0 +1,34 @@
+/*
+ * This class contains tools for checking memory usage during runtime.
+ */
+
+package opennlp.fieldspring.tr.util;
+
+public class MemoryUtil {
+    public static long getMemoryUsage(){
+        takeOutGarbage();
+        long totalMemory = Runtime.getRuntime().totalMemory();
+
+        takeOutGarbage();
+        long freeMemory = Runtime.getRuntime().freeMemory();
+
+        return (totalMemory - freeMemory);
+    }
+
+    private static void takeOutGarbage() {
+        collectGarbage();
+        collectGarbage();
+    }
+
+    private static void collectGarbage() {
+        try {
+            System.gc();
+            Thread.sleep(100);
+            System.runFinalization();
+            Thread.sleep(100);
+        }
+        catch (Exception ex){
+            ex.printStackTrace();
+        }
+    }
+}
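For reviewers skimming the patch, here is a minimal usage sketch of the two utilities above. The driver class and the sample tokens are hypothetical and not part of this patch; SimpleLexicon is the concrete implementation added further down in the diff.

    // Hypothetical driver, not part of the patch: intern a few tokens and
    // report approximate heap growth as measured by MemoryUtil.
    import opennlp.fieldspring.tr.util.MemoryUtil;
    import opennlp.fieldspring.tr.util.SimpleLexicon;

    public class LexiconMemoryDemo {
        public static void main(String[] args) {
            long before = MemoryUtil.getMemoryUsage();
            SimpleLexicon<String> lexicon = new SimpleLexicon<String>();
            for (String token : new String[] {"austin", "texas", "austin", "travis"}) {
                lexicon.getOrAdd(token); // assigns a stable index per distinct entry
            }
            long after = MemoryUtil.getMemoryUsage();
            System.out.println("distinct entries: " + lexicon.size());          // 3
            System.out.println("approximate bytes in use: " + (after - before));
        }
    }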
diff --git a/src/main/java/opennlp/fieldspring/tr/util/SimpleCountingLexicon.java b/src/main/java/opennlp/fieldspring/tr/util/SimpleCountingLexicon.java
new file mode 100644
index 0000000..7509ed2
--- /dev/null
+++ b/src/main/java/opennlp/fieldspring/tr/util/SimpleCountingLexicon.java
@@ -0,0 +1,89 @@
+///////////////////////////////////////////////////////////////////////////////
+// Copyright (C) 2010 Travis Brown, The University of Texas at Austin
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+///////////////////////////////////////////////////////////////////////////////
+package opennlp.fieldspring.tr.util;
+
+import java.io.IOException;
+import java.io.ObjectInputStream;
+import java.io.ObjectOutputStream;
+import java.io.Serializable;
+
+import java.util.ArrayList;
+import java.util.HashMap;
+import java.util.Iterator;
+import java.util.List;
+import java.util.Map;
+
+public class SimpleCountingLexicon<A>
+  extends SimpleLexicon<A> implements CountingLexicon<A> {
+
+  private static final long serialVersionUID = 42L;
+
+  private ArrayList<Integer> counts;
+
+  public SimpleCountingLexicon(int capacity) {
+    super(capacity);
+    this.counts = new ArrayList<Integer>(capacity);
+  }
+
+  public SimpleCountingLexicon() {
+    this(2048);
+  }
+
+  public int getOrAdd(A entry) {
+    Integer index = this.map.get(entry);
+    if (index == null) {
+      if (this.growing) {
+        index = this.entries.size();
+        this.map.put(entry, index);
+        this.entries.add(entry);
+        this.counts.add(1);
+      } else {
+        throw new UnsupportedOperationException("Cannot add to a non-growing lexicon.");
+      }
+    } else {
+      this.counts.set(index, this.counts.get(index) + 1);
+    }
+
+    return index;
+  }
+
+  public int count(A entry) {
+    return this.counts.get(this.map.get(entry));
+  }
+
+  public int countAtIndex(int index) {
+    return this.counts.get(index);
+  }
+
+  private void writeObject(ObjectOutputStream out) throws IOException {
+    out.writeBoolean(this.growing);
+    out.writeObject(this.entries);
+    out.writeObject(this.counts);
+  }
+
+  private void readObject(ObjectInputStream in)
+    throws IOException, ClassNotFoundException {
+    this.growing = in.readBoolean();
+    this.entries = (ArrayList<A>) in.readObject();
+    this.counts = (ArrayList<Integer>) in.readObject();
+    this.map = new HashMap<A, Integer>(this.entries.size());
+
+    for (int i = 0; i < this.entries.size(); i++) {
+      this.map.put(this.entries.get(i), i);
+    }
+  }
+}
+
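A short, hypothetical illustration of the counting semantics above (not part of the patch): getOrAdd both interns an entry and increments its count, and count/countAtIndex read the totals back.

    SimpleCountingLexicon<String> lexicon = new SimpleCountingLexicon<String>();
    lexicon.getOrAdd("london");                   // new entry: index 0, count 1
    lexicon.getOrAdd("paris");                    // new entry: index 1, count 1
    lexicon.getOrAdd("london");                   // existing entry: index 0 again, count 2
    int londonCount = lexicon.count("london");    // 2
    int parisCount = lexicon.countAtIndex(1);     // 1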
diff --git a/src/main/java/opennlp/fieldspring/tr/util/SimpleLexicon.java b/src/main/java/opennlp/fieldspring/tr/util/SimpleLexicon.java
new file mode 100644
index 0000000..6760681
--- /dev/null
+++ b/src/main/java/opennlp/fieldspring/tr/util/SimpleLexicon.java
@@ -0,0 +1,121 @@
+///////////////////////////////////////////////////////////////////////////////
+// Copyright (C) 2010 Travis Brown, The University of Texas at Austin
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+///////////////////////////////////////////////////////////////////////////////
+package opennlp.fieldspring.tr.util;
+
+import java.io.IOException;
+import java.io.ObjectInputStream;
+import java.io.ObjectOutputStream;
+import java.io.Serializable;
+
+import java.util.ArrayList;
+import java.util.HashMap;
+import java.util.Iterator;
+import java.util.List;
+import java.util.Map;
+
+public class SimpleLexicon<A> implements Lexicon<A> {
+
+  private static final long serialVersionUID = 42L;
+
+  protected Map<A, Integer> map;
+  protected ArrayList<A> entries;
+  protected boolean growing;
+
+  public SimpleLexicon(int capacity) {
+    this.map = new HashMap<A, Integer>(capacity, 0.5F);
+    this.entries = new ArrayList<A>(capacity);
+    this.growing = true;
+  }
+
+  public SimpleLexicon() {
+    this(2048);
+  }
+
+  public boolean contains(A entry) {
+    return this.map.containsKey(entry);
+  }
+
+  public int get(A entry) {
+    Integer index = this.map.get(entry);
+    return (index == null) ? -1 : index;
+  }
+
+  public int getOrAdd(A entry) {
+    Integer index = this.map.get(entry);
+    if (index == null) {
+      if (this.growing) {
+        index = this.entries.size();
+        this.map.put(entry, index);
+        this.entries.add(entry);
+      } else {
+        throw new UnsupportedOperationException("Cannot add to a non-growing lexicon.");
+      }
+    }
+    return index;
+  }
+
+  public A atIndex(int index) {
+    return this.entries.get(index);
+  }
+
+  public int size() {
+    return this.entries.size();
+  }
+
+  public boolean isGrowing() {
+    return this.growing;
+  }
+
+  public void stopGrowing() {
+    if (this.growing) {
+      this.entries.trimToSize();
+      this.growing = false;
+    }
+  }
+
+  public void startGrowing() {
+    this.growing = true;
+  }
+
+  public List<Integer> concatenate(Lexicon<A> other) {
+    List<Integer> conversions = new ArrayList<Integer>(other.size());
+    for (A entry : other) {
+      conversions.add(this.getOrAdd(entry));
+    }
+    return conversions;
+  }
+
+  public Iterator<A> iterator() {
+    return this.entries.iterator();
+  }
+
+  private void writeObject(ObjectOutputStream out) throws IOException {
+    out.writeBoolean(this.growing);
+    out.writeObject(this.entries);
+  }
+
+  private void readObject(ObjectInputStream in)
+    throws IOException, ClassNotFoundException {
+    this.growing = in.readBoolean();
+    this.entries = (ArrayList<A>) in.readObject();
+    this.map = new HashMap<A, Integer>(this.entries.size());
+
+    for (int i = 0; i < this.entries.size(); i++) {
+      this.map.put(this.entries.get(i), i);
+    }
+  }
+}
+
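Another hypothetical snippet (not part of the patch) showing the growth-control contract of SimpleLexicon: after stopGrowing(), lookups still work, get() returns -1 for unknown entries, and getOrAdd() refuses to add.

    SimpleLexicon<String> lexicon = new SimpleLexicon<String>();
    int austin = lexicon.getOrAdd("austin");
    lexicon.stopGrowing();                        // also trims internal storage
    assert lexicon.get("austin") == austin;
    assert lexicon.get("dallas") == -1;           // unknown entries are not added
    try {
        lexicon.getOrAdd("dallas");               // now throws
    } catch (UnsupportedOperationException expected) { }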
diff --git a/src/main/java/opennlp/fieldspring/tr/util/Span.java b/src/main/java/opennlp/fieldspring/tr/util/Span.java
new file mode 100644
index 0000000..56bcb9c
--- /dev/null
+++ b/src/main/java/opennlp/fieldspring/tr/util/Span.java
@@ -0,0 +1,48 @@
+///////////////////////////////////////////////////////////////////////////////
+// Copyright (C) 2010 Travis Brown, The University of Texas at Austin
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+///////////////////////////////////////////////////////////////////////////////
+package opennlp.fieldspring.tr.util;
+
+import java.util.ArrayList;
+import java.util.List;
+import java.io.*;
+
+public class Span<A> implements Serializable {
+
+  private static final long serialVersionUID = 42L;
+
+  private final int start;
+  private final int end;
+  private final A item;
+
+  public Span(int start, int end, A item) {
+    this.start = start;
+    this.end = end;
+    this.item = item;
+  }
+
+  public int getStart() {
+    return this.start;
+  }
+
+  public int getEnd() {
+    return this.end;
+  }
+
+  public A getItem() {
+    return this.item;
+  }
+}
+
diff --git a/src/main/java/opennlp/fieldspring/tr/util/StringDoublePair.java b/src/main/java/opennlp/fieldspring/tr/util/StringDoublePair.java
new file mode 100644
index 0000000..46631de
--- /dev/null
+++ b/src/main/java/opennlp/fieldspring/tr/util/StringDoublePair.java
@@ -0,0 +1,37 @@
+///////////////////////////////////////////////////////////////////////////////
+// Copyright (C) 2010 Taesun Moon, The University of Texas at Austin
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+///////////////////////////////////////////////////////////////////////////////
+
+package opennlp.fieldspring.tr.util;
+
+/**
+ *
+ * @author tsmoon
+ */
+public class StringDoublePair implements Comparable<StringDoublePair> {
+
+  public String stringValue;
+  public double doubleValue;
+
+  public StringDoublePair(String s, double d) {
+    stringValue = s;
+    doubleValue = d;
+  }
+
+  public int compareTo(StringDoublePair p) {
+    return stringValue.compareTo(p.stringValue);
+  }
+
+}
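A final hypothetical snippet (not part of the patch): Span marks a token range and carries an arbitrary payload, while StringDoublePair pairs a label with a score; note that its compareTo orders by the string field, not the double.

    Span<String> mention = new Span<String>(3, 5, "New York");
    System.out.println(mention.getStart() + ".." + mention.getEnd() + " -> " + mention.getItem());

    StringDoublePair scored = new StringDoublePair("New York", 0.87);
    System.out.println(scored.stringValue + " " + scored.doubleValue);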
diff --git a/src/main/java/opennlp/fieldspring/tr/util/StringEditMapper.java b/src/main/java/opennlp/fieldspring/tr/util/StringEditMapper.java
new file mode 100644
index 0000000..16486b1
--- /dev/null
+++ b/src/main/java/opennlp/fieldspring/tr/util/StringEditMapper.java
@@ -0,0 +1,62 @@
+///////////////////////////////////////////////////////////////////////////////
+// Copyright (C) 2010 Travis Brown, The University of Texas at Austin
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+///////////////////////////////////////////////////////////////////////////////
+package opennlp.fieldspring.tr.util;
+
+import java.util.ArrayList;
+import java.util.Collections;
+import java.util.Iterator;
+import java.util.List;
+
+import opennlp.fieldspring.tr.util.Span;
+
+public class StringEditMapper extends EditMapper<String> {
+  public StringEditMapper(List<String> s, List<String> t) {
+    super(s, t);
+  }
+
+  @Override
+  protected int delCost(String x) {
+    return x.length();
+  }
+
+  @Override
+  protected int insCost(String x) {
+    return x.length();
+  }
+
+  @Override
+  protected int subCost(String x, String y) {
+    int[][] ds = new int[x.length() + 1][y.length() + 1];
+    for (int i = 0; i <= x.length(); i++) { ds[i][0] = i; }
+    for (int j = 0; j <= y.length(); j++) { ds[0][j] = j; }
+
+    for (int i = 1; i <= x.length(); i++) {
+      for (int j = 1; j <= y.length(); j++) {
+        int del = ds[i - 1][j] + 1;
+        int ins = ds[i][j - 1] + 1;
+        int sub = ds[i - 1][j - 1] + (x.charAt(i - 1) == y.charAt(j - 1) ? 0 : 1);
+        ds[i][j] = StringEditMapper.minimum(del, ins, sub);
+      }
+    }
+
+    return ds[x.length()][y.length()];
+  }
+
+  private static int minimum(int a, int b, int c) {
+    return Math.min(Math.min(a, b), c);
+  }
+}
+
diff --git a/src/main/java/opennlp/fieldspring/tr/util/StringUtil.java b/src/main/java/opennlp/fieldspring/tr/util/StringUtil.java
new file mode 100644
index 0000000..a3e4523
--- /dev/null
+++ b/src/main/java/opennlp/fieldspring/tr/util/StringUtil.java
@@ -0,0 +1,226 @@
+///////////////////////////////////////////////////////////////////////////////
+// Copyright (C) 2007 Jason Baldridge, The University of Texas at Austin
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+///////////////////////////////////////////////////////////////////////////////
+package opennlp.fieldspring.tr.util;
+
+/**
+ * Handy utilities for dealing with strings.
+ * + * @author Jason Baldridge + * @version $Revision: 1.53 $, $Date: 2006/10/12 21:20:44 $ + */ +public class StringUtil { + + public static boolean containsAlphanumeric(String s) { + for(int i = 0; i < s.length(); i++) { + if(Character.isLetterOrDigit(s.charAt(i))) + return true; + } + return false; + } + + public static String[] splitAtLast(char sep, String s) { + int lastIndex = s.lastIndexOf(sep); + String[] wordtag = {s.substring(0, lastIndex), + s.substring(lastIndex + 1, s.length())}; + return wordtag; + } + + public static String join(Object[] array) { + return join(" ", array); + } + + public static String join(String sep, Object[] array) { + StringBuilder sb = new StringBuilder(array[0].toString()); + for (int i = 1; i < array.length; i++) { + sb.append(sep).append(array[i].toString()); + } + return sb.toString(); + } + + public static String join(String[] array) { + return join(" ", array); + } + + public static String join(String sep, String[] array) { + StringBuilder sb = new StringBuilder(array[0]); + for (int i = 1; i < array.length; i++) { + sb.append(sep).append(array[i]); + } + return sb.toString(); + } + + public static String join(int[] array) { + return join(" ", array); + } + + public static String join(String sep, int[] array) { + StringBuilder sb = new StringBuilder(Integer.toString(array[0])); + for (int i = 1; i < array.length; i++) { + sb.append(sep).append(array[i]); + } + return sb.toString(); + } + + public static String join(double[] array) { + return join(" ", array); + } + + public static String join(String sep, double[] array) { + StringBuilder sb = new StringBuilder(Double.toString(array[0])); + for (int i = 1; i < array.length; i++) { + sb.append(sep).append(array[i]); + } + return sb.toString(); + } + + public static String mergeJoin(String sep, String[] a1, String[] a2) { + if (a1.length != a2.length) { + System.out.println("Unable to merge String arrays of different length!"); + return join(a1); + } + + StringBuilder sb = new StringBuilder(a1[0]); + sb.append(sep).append(a2[0]); + for (int i = 1; i < a1.length; i++) { + sb.append(" ").append(a1[i]).append(sep).append(a2[i]); + } + return sb.toString(); + } + + /** + *