Code improvements in Scala reference implementation #5

Open · wants to merge 8 commits into master
@@ -53,9 +53,11 @@ object MailCount {
// extract month from time string
m._1.substring(0, 7),
// extract email address from sender
- m._2.substring(m._2.lastIndexOf("<") + 1, m._2.length - 1) ) }
+ m._2.substring(m._2.lastIndexOf("<") + 1, m._2.length - 1),
+   // add counter to each record
+   1) }
// group by month and sender and count the number of records per group
- .groupBy(0, 1).reduceGroup { ms => ms.foldLeft(("","",0))( (c, m) => (m._1, m._2, c._3+1)) }
+ .groupBy(0, 1).sum(2)
// print the result
.print
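
The rewrite swaps the hand-rolled reduceGroup fold for a built-in aggregation: each record carries a constant 1, so groupBy(0, 1).sum(2) yields the per-group record count directly. A minimal, self-contained sketch of the same idiom (the sample months and addresses are made up):

  import org.apache.flink.api.scala._

  object GroupCountSketch {
    def main(args: Array[String]): Unit = {
      val env = ExecutionEnvironment.getExecutionEnvironment
      // (month, sender, counter) records, as produced by the map above
      val mails = env.fromElements(
        ("2015-01", "alice@example.org", 1),
        ("2015-01", "alice@example.org", 1),
        ("2015-02", "bob@example.org", 1))
      // summing the constant counter per (month, sender) group counts the records
      mails.groupBy(0, 1).sum(2).print()
      // prints (2015-01,alice@example.org,2) and (2015-02,bob@example.org,1)
    }
  }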

@@ -34,8 +34,8 @@ object ReplyGraph {
def main(args: Array[String]) {

// parse parameters
- val params = ParameterTool.fromArgs(args);
- val input = params.getRequired("input");
+ val params = ParameterTool.fromArgs(args)
+ val input = params.getRequired("input")

// set up the execution environment
val env = ExecutionEnvironment.getExecutionEnvironment
@@ -59,11 +59,11 @@ object ReplyGraph {

// compute reply connections by joining on messageId and reply-to
val replyConnections = addressMails
- .join(addressMails).where(2).equalTo(0) { (l,r) => (l._2, r._2) }
+ .join(addressMails).where(2).equalTo(0) { (l,r) => (l._2, r._2, 1) }

// count connections for each pair of addresses
replyConnections
- .groupBy(0,1).reduceGroup( cs => cs.foldLeft(("","",0))( (l,r) => (r._1, r._2, l._3+1) ) )
+ .groupBy(0,1).sum(2)
.print

}
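
Adding the counter inside the join closure keeps the result in the (address, address, count) shape that sum(2) expects. A plain-Scala sketch of the self-join semantics, assuming addressMails records are (messageId, senderAddress, inReplyTo) tuples as in the exercise; the sample data is invented:

  val addressMails = Seq(
    ("m1", "alice@example.org", "null"),  // original message
    ("m2", "bob@example.org", "m1"))      // m2 is a reply to m1
  val replyConnections = for {
    l <- addressMails
    r <- addressMails
    if l._3 == r._1                       // where(2) equalTo(0)
  } yield (l._2, r._2, 1)                 // (replier, original sender, 1)
  // grouped sum over the counter, the collections analogue of groupBy(0,1).sum(2)
  val counts = replyConnections
    .groupBy(t => (t._1, t._2))
    .map { case ((a, b), cs) => (a, b, cs.map(_._3).sum) }
  // counts contains (bob@example.org,alice@example.org,1)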
@@ -18,17 +18,14 @@

package com.dataArtisans.flinkTraining.exercises.dataSetScala.tfIdf

- import java.util.StringTokenizer
import java.util.regex.Pattern

import com.dataArtisans.flinkTraining.dataSetPreparation.MBoxParser
- import org.apache.flink.api.common.functions.{FlatMapFunction}
+ import org.apache.flink.api.common.functions.FlatMapFunction
import org.apache.flink.api.java.utils.ParameterTool
import org.apache.flink.api.scala._
import org.apache.flink.util.Collector

- import scala.collection.mutable.{HashMap, HashSet}

/**
* Scala reference implementation for the "TF-IDF" exercise of the Flink training.
* The task of the exercise is to compute the TF-IDF score for words in mails of the
@@ -40,137 +40,110 @@ import scala.collection.mutable.{HashMap, HashSet}
*/
object MailTFIDF {

- val STOP_WORDS: Array[String] = Array (
-   "the", "i", "a", "an", "at", "are", "am", "for", "and", "or", "is", "there", "it", "this",
-   "that", "on", "was", "by", "of", "to", "in", "to", "message", "not", "be", "with", "you",
-   "have", "as", "can")

def main(args: Array[String]) {

- // parse paramters
- val params = ParameterTool.fromArgs(args);
- val input = params.getRequired("input");
+ // parse parameters
+ val params = ParameterTool.fromArgs(args)
+ val input = params.getRequired("input")

+ val stopWords = List(
+   "the", "i", "a", "an", "at", "are", "am", "for", "and", "or", "is", "there", "it", "this",
+   "that", "on", "was", "by", "of", "to", "in", "to", "message", "not", "be", "with", "you",
+   "have", "as", "can")

+ // pattern for recognizing an acceptable 'word'
+ val wordPattern: Pattern = Pattern.compile("(\\p{Alpha})+")

// set up the execution environment
val env = ExecutionEnvironment.getExecutionEnvironment

// read messageId and body field of the input data
val mails = env.readCsvFile[(String, String)](
"/users/fhueske/data/flinkdevlistparsed/",
input,
lineDelimiter = MBoxParser.MAIL_RECORD_DELIM,
fieldDelimiter = MBoxParser.MAIL_FIELD_DELIM,
includedFields = Array(0,4)
)

// count mails in data set
- val mailCnt = mails.count
+ val mailCnt = mails.count()

// compute term-frequency (TF)
- val tf = mails
-   .flatMap(new TFComputer(STOP_WORDS))
+ val tf = mails.flatMap(new TFComputer(stopWords, wordPattern))

// compute document frequency (number of mails that contain a word at least once)
- val df = mails
-   // extract unique words from mails
-   .flatMap(new UniqueWordExtractor(STOP_WORDS))
-   // count number of mails for each word
-   .groupBy(0).reduce { (l,r) => (l._1, l._2 + r._2) }
+ val df = mails.flatMap(new UniqueWordExtractor(stopWords, wordPattern))
+   // group by the words
+   .groupBy(0)
+   // count the number of documents in each group (df)
+   .sum(1)
// compute TF-IDF score from TF, DF, and total number of mails
- val tfidf = tf.join(df).where(1).equalTo(0)
-   { (l, r) => (l._1, l._2, l._3 * (mailCnt.toDouble / r._2) ) }
+ val tfidf = tf
+   .join(df)
+   // where "word" from tf
+   .where(1)
+   // is equal to "word" from df
+   .equalTo(0) {
+     (l, r) => (l._1, l._2, l._3 * (mailCnt.toDouble / r._2))
+   }

// print the result
tfidf
-   .print
+   .print()

}
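
For reference, the joined score is tf(t, d) * (N / df(t)), where N is the total mail count; note that this training variant uses the raw inverse document frequency, while the textbook TF-IDF definition applies a logarithm to it. An illustrative helper (not part of the PR, names are hypothetical) that spells out the per-record arithmetic:

  // score for one (docId, word, tf) record after joining with its df
  def tfIdfScore(tf: Int, df: Int, mailCnt: Long): Double =
    tf * (mailCnt.toDouble / df)
  // the classic formulation would be: tf * math.log(mailCnt.toDouble / df)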

/**
- * Computes the frequency of each word in a mail.
- * Words consist only of alphabetical characters. Frequent words (stop words) are filtered out.
+ * Extracts the list of unique words in each document.
 *
- * @param stopWordsA Array of words that are filtered out.
+ * @param stopWords a list of stop words that should be omitted
+ * @param wordPattern pattern that defines a word
*/
- class TFComputer(stopWordsA: Array[String])
-   extends FlatMapFunction[(String, String), (String, String, Int)] {
-
-   val stopWords: HashSet[String] = new HashSet[String]
-   val wordCounts: HashMap[String, Int] = new HashMap[String, Int]
-   // initialize word pattern match for sequences of alphabetical characters
-   val wordPattern: Pattern = Pattern.compile("(\\p{Alpha})+")
-
-   // initialize set of stop words
-   for(sw <- stopWordsA) {
-     this.stopWords.add(sw)
-   }
+ class UniqueWordExtractor(stopWords: List[String], wordPattern: Pattern)
+   extends FlatMapFunction[(String, String), (String, Int)] {

+ def flatMap(mail: (String, String), out: Collector[(String, Int)]): Unit = {
+   val output = mail._2.toLowerCase
+     // split the body
+     .split(Array(' ', '\t', '\n', '\r', '\f'))
+     // filter out stop words and non-words
+     .filter(w => !stopWords.contains(w) && wordPattern.matcher(w).matches())
+     // keep each distinct word only once per document
+     .distinct
+   // emit every word that appeared in the document
+   output.foreach(m => out.collect((m, 1)))
+ }

- override def flatMap(t: (String, String), out: Collector[(String, String, Int)]): Unit = {
-   // clear word counts
-   wordCounts.clear
-
-   // split mail along whitespaces
-   val tokens = new StringTokenizer(t._2)
-   // for each word candidate
-   while (tokens.hasMoreTokens) {
-     // normalize word to lower case
-     val word = tokens.nextToken.toLowerCase
-     if (!stopWords.contains(word) && wordPattern.matcher(word).matches) {
-       // word candidate is not a stop word and matches the word pattern
-       // increase word count
-       val cnt = wordCounts.getOrElse(word, 0)
-       wordCounts.put(word, cnt+1)
-     }
-   }
-   // emit all word counts per document and word
-   for (wc <- wordCounts.iterator) {
-     out.collect( (t._1, wc._1, wc._2) )
-   }
- }
}
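
The new UniqueWordExtractor replaces the mutable HashSet bookkeeping with a split/filter/distinct pipeline. A plain-Scala walkthrough on a made-up body string (the shortened stop-word list is for illustration only):

  val stopWords = List("the", "a", "is")
  val wordPattern = java.util.regex.Pattern.compile("(\\p{Alpha})+")
  val body = "The cat is a CAT"
  val uniqueWords = body.toLowerCase
    .split(Array(' ', '\t', '\n', '\r', '\f'))
    .filter(w => !stopWords.contains(w) && wordPattern.matcher(w).matches())
    .distinct
  // uniqueWords: Array(cat), so each surviving word is emitted once as (word, 1)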

/**
- * Extracts the unique words in a mail.
- * Words consist only of alphabetical characters. Frequent words (stop words) are filtered out.
+ * Calculates the term frequency of each word per document.
 *
- * @param stopWordsA Array of words that are filtered out.
+ * @param stopWords a list of stop words that should be omitted
+ * @param wordPattern pattern that defines a word
*/
- class UniqueWordExtractor(stopWordsA: Array[String])
-   extends FlatMapFunction[(String, String), (String, Int) ] {
-
-   val stopWords: HashSet[String] = new HashSet[String]
-   val uniqueWords: HashSet[String] = new HashSet[String]
-   // initalize pattern to match words
-   val wordPattern: Pattern = Pattern.compile("(\\p{Alpha})+")
-
-   // initialize set of stop words
-   for(sw <- stopWordsA) {
-     this.stopWords.add(sw)
-   }
-
-   override def flatMap(t: (String, String), out: Collector[(String, Int)]): Unit = {
-     // clear unique words
-     uniqueWords.clear()
-
-     // split mail along whitespaces
-     val tokens = new StringTokenizer(t._2)
-     // for each word candidate
-     while(tokens.hasMoreTokens) {
-       // normalize word to lower case
-       val word = tokens.nextToken.toLowerCase
-       if (!stopWords.contains(word) && wordPattern.matcher(word).matches) {
-         // word candiate is not a stop word and matches the word pattern
-         uniqueWords.add(word)
-       }
-     }
-
-     // emit all words that occurred at least once
-     for(w <- uniqueWords) {
-       out.collect( (w, 1) )
-     }
-   }
+ class TFComputer(stopWords: List[String], wordPattern: Pattern)
+   extends FlatMapFunction[(String, String), (String, String, Int)] {
+
+   def flatMap(mail: (String, String), out: Collector[(String, String, Int)]): Unit = {
+     // extract the mail id
+     val id = mail._1
+     val output = mail._2.toLowerCase
+       // split the body
+       .split(Array(' ', '\t', '\n', '\r', '\f'))
+       // filter out stop words and non-words
+       .filter(w => !stopWords.contains(w) && wordPattern.matcher(w).matches())
+       // count the number of occurrences of each word in the document
+       .map(m => (m, 1)).groupBy(_._1).map {
+         case (item, count) => (item, count.foldLeft(0)(_ + _._2))
+       }
+     // emit the document id, a term, and its number of occurrences in the document
+     output.foreach(m => out.collect((id, m._1, m._2)))
+   }

}

}
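
The counting idiom in the new TFComputer can be checked in isolation; a minimal plain-Scala example with invented input:

  val words = Array("cat", "dog", "cat")
  val counts = words.map(w => (w, 1)).groupBy(_._1).map {
    case (word, pairs) => (word, pairs.foldLeft(0)(_ + _._2))
  }
  // counts: Map(cat -> 2, dog -> 1)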