feat: Synthetic difference in differences (#2095)
* Estimators for diff-in-diff, synthetic control and synthetic diff-in-diff

* add more params

* refactor

Signed-off-by: Jason Wang <[email protected]>

* adding unit tests for linalg

* more unit tests

* Unit test for DiffInDiffEstimator

* more unit tests

* unit test for SyntheticControlEstimator

* unit test for SyntheticDiffInDiffEstimator

* logClass

* Python code gen

Signed-off-by: Jason Wang <[email protected]>

* pyspark wrapper

* expose loss history

* fix bugs for synthetic control

* fix time effects for synthetic control estimator

* fix unit test

* add notebook

* fixing indexing logic

* add file headers

Signed-off-by: Jason Wang <[email protected]>

* Add feature name to logClass call

* more scalastyle fixes

* More scalastyle and unit test fixes

* Python style fix

Signed-off-by: Jason Wang <[email protected]>

* fix unit test

* fix more python style issues

* python style fix

* fix unit test

* Update core/src/main/scala/com/microsoft/azure/synapse/ml/causal/DiffInDiffEstimator.scala

Co-authored-by: Mark Hamilton <[email protected]>

* Update core/src/main/scala/com/microsoft/azure/synapse/ml/causal/SyntheticControlEstimator.scala

Co-authored-by: Mark Hamilton <[email protected]>

* Update core/src/main/scala/com/microsoft/azure/synapse/ml/causal/SyntheticDiffInDiffEstimator.scala

Co-authored-by: Mark Hamilton <[email protected]>

* addressing comments

* extract some constants to findUnusedColumn

* Expose zeta as an optional parameter; also return the RMSE from unit-weight and time-weight fitting (background on zeta follows this message)

* Replace constant TimeIdxCol and UnitIdxCol with findUnusedColumn

Signed-off-by: Jason Wang <[email protected]>

* typo

* Adding notebook to sidebar

* fix bad merge

* address code review comments

* Update docs/Explore Algorithms/Causal Inference/Quickstart - Synthetic difference in differences.ipynb

Co-authored-by: Mark Hamilton <[email protected]>

* clean synapse widget output state

* remove invalid image links

---------

Signed-off-by: Jason Wang <[email protected]>
Co-authored-by: Mark Hamilton <[email protected]>
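Background on zeta (standard synthetic diff-in-diff, per Arkhangelsky et al., "Synthetic Difference-in-Differences", AER 2021; the exact default this commit computes is not visible in the excerpt): the unit weights solve a ridge-regularized synthetic-control problem over the pre-treatment periods,

$$\hat{\omega} = \arg\min_{\omega_0,\; \omega \in \Delta} \sum_{t \le T_{\text{pre}}} \Big(\omega_0 + \sum_{i \in \text{controls}} \omega_i Y_{it} - \bar{Y}_{\text{treated},\,t}\Big)^2 + \zeta^2\, T_{\text{pre}}\, \lVert \omega \rVert_2^2,$$

where larger $\zeta$ shrinks the weights toward uniform. The paper's suggested default is $\zeta = (N_{\text{treated}}\, T_{\text{post}})^{1/4}\, \hat{\sigma}$, with $\hat{\sigma}$ the standard deviation of first-differenced pre-treatment control outcomes; exposing zeta lets callers override whatever default the estimator uses.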
memoryz and mhamilton723 authored Jan 12, 2024
1 parent 8ebf298 commit cbc022c
Showing 27 changed files with 2,710 additions and 32 deletions.
50 changes: 50 additions & 0 deletions core/src/main/python/synapse/ml/causal/DiffInDiffModel.py
@@ -0,0 +1,50 @@
# Copyright (C) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License. See LICENSE in project root for information.

import sys

if sys.version >= "3":
    basestring = str

from synapse.ml.causal._DiffInDiffModel import _DiffInDiffModel
from pyspark.ml.common import inherit_doc
from pyspark.sql import SparkSession, DataFrame
from pyspark import SparkContext, SQLContext


@inherit_doc
class DiffInDiffModel(_DiffInDiffModel):
    @staticmethod
    def _mapOption(option, func):
        return func(option.get()) if option.isDefined() else None

    @staticmethod
    def _unwrapOption(option):
        return DiffInDiffModel._mapOption(option, lambda x: x)

    def __init__(self, java_obj=None) -> None:
        super(DiffInDiffModel, self).__init__(java_obj=java_obj)

        ctx = SparkContext._active_spark_context
        sql_ctx = SQLContext.getOrCreate(ctx)

        self.summary = java_obj.getSummary()
        self.treatmentEffect = self.summary.treatmentEffect()
        self.standardError = self.summary.standardError()
        self.timeIntercept = DiffInDiffModel._unwrapOption(self.summary.timeIntercept())
        self.unitIntercept = DiffInDiffModel._unwrapOption(self.summary.unitIntercept())
        self.timeWeights = DiffInDiffModel._mapOption(
            java_obj.getTimeWeights(), lambda x: DataFrame(x, sql_ctx)
        )
        self.unitWeights = DiffInDiffModel._mapOption(
            java_obj.getUnitWeights(), lambda x: DataFrame(x, sql_ctx)
        )
        self.timeRMSE = DiffInDiffModel._unwrapOption(self.summary.timeRMSE())
        self.unitRMSE = DiffInDiffModel._unwrapOption(self.summary.unitRMSE())
        self.zeta = DiffInDiffModel._unwrapOption(self.summary.zeta())
        self.lossHistoryTimeWeights = DiffInDiffModel._unwrapOption(
            self.summary.getLossHistoryTimeWeightsJava()
        )
        self.lossHistoryUnitWeights = DiffInDiffModel._unwrapOption(
            self.summary.getLossHistoryUnitWeightsJava()
        )
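The attributes materialized above map one-to-one onto the Scala DiffInDiffSummary fields defined below. For orientation, a minimal Scala sketch of the flow this wrapper mirrors; the setPostTreatmentCol setter is assumed from the HasPostTreatmentCol trait, which this excerpt does not show:

// Hedged sketch: fit the plain diff-in-diff estimator, then read the summary.
val did = new DiffInDiffEstimator()
  .setTreatmentCol("treated")   // 0/1 or boolean: unit belongs to the treated group
  .setPostTreatmentCol("post")  // assumed setter; 0/1 or boolean: post-treatment period
  .setOutcomeCol("outcome")     // numeric outcome

val model = did.fit(df)         // df: a DataFrame containing the three columns
val summary = model.getSummary
println(s"effect = ${summary.treatmentEffect} (se = ${summary.standardError})")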
@@ -0,0 +1,183 @@
// Copyright (C) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License. See LICENSE in project root for information.

package com.microsoft.azure.synapse.ml.causal

import com.microsoft.azure.synapse.ml.causal.linalg.DVector
import com.microsoft.azure.synapse.ml.codegen.Wrappable
import com.microsoft.azure.synapse.ml.core.schema.DatasetExtensions
import com.microsoft.azure.synapse.ml.logging.{FeatureNames, SynapseMLLogging}
import com.microsoft.azure.synapse.ml.param.DataFrameParam
import org.apache.spark.SparkException
import org.apache.spark.ml.feature.VectorAssembler
import org.apache.spark.ml.param.{Param, ParamMap, Params}
import org.apache.spark.ml.regression.LinearRegression
import org.apache.spark.ml.util.Identifiable
import org.apache.spark.ml.{ComplexParamsReadable, ComplexParamsWritable, Estimator, Model}
import org.apache.spark.sql.types.{BooleanType, NumericType, StructField, StructType}
import org.apache.spark.sql.{DataFrame, Dataset}

import java.util

abstract class BaseDiffInDiffEstimator(override val uid: String)
  extends Estimator[DiffInDiffModel]
    with DiffInDiffEstimatorParams {

  private def validateFieldNumericOrBooleanType(field: StructField): Unit = {
    val dataType = field.dataType
    require(dataType.isInstanceOf[NumericType] || dataType == BooleanType,
      s"Column ${field.name} must be numeric type or boolean type, but got $dataType instead.")
  }

  protected def validateFieldNumericType(field: StructField): Unit = {
    val dataType = field.dataType
    require(dataType.isInstanceOf[NumericType],
      s"Column ${field.name} must be numeric type, but got $dataType instead.")
  }

  override def transformSchema(schema: StructType): StructType = {
    validateFieldNumericOrBooleanType(schema(getPostTreatmentCol))
    validateFieldNumericOrBooleanType(schema(getTreatmentCol))
    validateFieldNumericType(schema(getOutcomeCol))
    schema
  }

  override def copy(extra: ParamMap): Estimator[DiffInDiffModel] = defaultCopy(extra)

  private[causal] val findInteractionCol = DatasetExtensions.findUnusedColumnName("interaction") _

  private[causal] def fitLinearModel(df: DataFrame,
                                     featureCols: Array[String],
                                     fitIntercept: Boolean,
                                     weightCol: Option[String] = None) = {

    val featuresCol = DatasetExtensions.findUnusedColumnName("features", df)
    val assembler = new VectorAssembler()
      .setInputCols(featureCols)
      .setOutputCol(featuresCol)

    val regression = weightCol
      .map(new LinearRegression().setWeightCol)
      .getOrElse(new LinearRegression())

    regression
      .setFeaturesCol(featuresCol)
      .setLabelCol(getOutcomeCol)
      .setFitIntercept(fitIntercept)
      .setLoss("squaredError")
      .setRegParam(1E-10)

    assembler.transform _ andThen regression.fit apply df
  }
}

case class DiffInDiffSummary(treatmentEffect: Double, standardError: Double,
                             timeWeights: Option[DVector] = None,
                             timeIntercept: Option[Double] = None,
                             timeRMSE: Option[Double] = None,
                             unitWeights: Option[DVector] = None,
                             unitIntercept: Option[Double] = None,
                             unitRMSE: Option[Double] = None,
                             zeta: Option[Double] = None,
                             lossHistoryTimeWeights: Option[List[Double]] = None,
                             lossHistoryUnitWeights: Option[List[Double]] = None) {
  import scala.collection.JavaConverters._

  def getLossHistoryTimeWeightsJava: Option[util.List[Double]] = {
    lossHistoryTimeWeights.map(_.asJava)
  }

  def getLossHistoryUnitWeightsJava: Option[util.List[Double]] = {
    lossHistoryUnitWeights.map(_.asJava)
  }
}

class DiffInDiffModel(override val uid: String)
  extends Model[DiffInDiffModel]
    with HasUnitCol
    with HasTimeCol
    with Wrappable
    with ComplexParamsWritable
    with SynapseMLLogging {

  logClass(FeatureNames.Causal)

  final val timeIndex = new DataFrameParam(this, "timeIndex", "time index")
  def getTimeIndex: DataFrame = $(timeIndex)
  def setTimeIndex(value: DataFrame): this.type = set(timeIndex, value)

  final val timeIndexCol = new Param[String](this, "timeIndexCol", "time index column")
  def getTimeIndexCol: String = $(timeIndexCol)
  def setTimeIndexCol(value: String): this.type = set(timeIndexCol, value)

  final val unitIndex = new DataFrameParam(this, "unitIndex", "unit index")
  def getUnitIndex: DataFrame = $(unitIndex)
  def setUnitIndex(value: DataFrame): this.type = set(unitIndex, value)

  final val unitIndexCol = new Param[String](this, "unitIndexCol", "unit index column")
  def getUnitIndexCol: String = $(unitIndexCol)
  def setUnitIndexCol(value: String): this.type = set(unitIndexCol, value)

  override protected lazy val pyInternalWrapper = true

  def this() = this(Identifiable.randomUID("DiffInDiffModel"))

  private final var summary: Option[DiffInDiffSummary] = None

  def getSummary: DiffInDiffSummary = summary.getOrElse {
    throw new SparkException(
      s"No summary available for this ${this.getClass.getSimpleName}")
  }

  private[causal] def setSummary(summary: Option[DiffInDiffSummary]): this.type = {
    this.summary = summary
    this
  }

  override def copy(extra: ParamMap): DiffInDiffModel = {
    copyValues(new DiffInDiffModel(uid), extra)
      .setSummary(this.summary)
      .setParent(parent)
  }

  override def transform(dataset: Dataset[_]): DataFrame = dataset.toDF

  override def transformSchema(schema: StructType): StructType = schema

  def getTimeWeights: Option[DataFrame] = {
    (get(timeIndex), getSummary.timeWeights) match {
      case (Some(idxDf), Some(timeWeights)) =>
        Some(
          idxDf.join(timeWeights, idxDf(getTimeIndexCol) === timeWeights("i"), "left_outer")
            .select(
              idxDf(getTimeCol),
              timeWeights("value")
            )
        )
      case _ =>
        None
    }
  }

  def getUnitWeights: Option[DataFrame] = {
    (get(unitIndex), getSummary.unitWeights) match {
      case (Some(idxDf), Some(unitWeights)) =>
        Some(
          idxDf.join(unitWeights, idxDf(getUnitIndexCol) === unitWeights("i"), "left_outer")
            .select(
              idxDf(getUnitCol),
              unitWeights("value")
            )
        )
      case _ =>
        None
    }
  }
}

object DiffInDiffModel extends ComplexParamsReadable[DiffInDiffModel]

trait DiffInDiffEstimatorParams extends Params
  with HasTreatmentCol
  with HasOutcomeCol
  with HasPostTreatmentCol
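getTimeWeights and getUnitWeights join the stored index DataFrames back to the fitted weight vectors on the internal index column "i", recovering human-readable (time, weight) and (unit, weight) rows. A hedged sketch of inspecting them; both return None for the plain DiffInDiffEstimator, which never sets the indices or weights:

// Hedged sketch: weights are populated by the synthetic estimators only.
val model: DiffInDiffModel = ??? // e.g. fitted by SyntheticDiffInDiffEstimator
model.getTimeWeights.foreach(_.show()) // typically one row per pre-treatment period
model.getUnitWeights.foreach(_.show()) // typically one row per control ("donor") unit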
@@ -0,0 +1,21 @@
// Copyright (C) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License. See LICENSE in project root for information.

package com.microsoft.azure.synapse.ml.causal

import breeze.linalg.{DenseVector => BDV}
import com.microsoft.azure.synapse.ml.causal.linalg.DVector
trait CacheOps[T] {
  def checkpoint(data: T): T = data
  def cache(data: T): T = data
}

object BDVCacheOps extends CacheOps[BDV[Double]] {
  override def checkpoint(data: BDV[Double]): BDV[Double] = data
  override def cache(data: BDV[Double]): BDV[Double] = data
}

object DVectorCacheOps extends CacheOps[DVector] {
  override def checkpoint(data: DVector): DVector = data.localCheckpoint(true)
  override def cache(data: DVector): DVector = data.cache
}
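CacheOps is a small typeclass that abstracts persistence, so an iterative solver can run unchanged over an in-memory breeze vector (cache and checkpoint are no-ops) or a distributed DVector (cache persists, localCheckpoint truncates Spark lineage). A hedged sketch of the pattern; the iterate helper is illustrative and not part of this commit:

import breeze.linalg.{DenseVector => BDV}

// Run maxIter steps, periodically checkpointing through the CacheOps instance.
def iterate[T](init: T, step: T => T, maxIter: Int, ops: CacheOps[T]): T = {
  var state = ops.cache(init)
  for (i <- 1 to maxIter) {
    state = step(state)
    if (i % 20 == 0) state = ops.checkpoint(state) // no-op for BDV, lineage cut for DVector
  }
  state
}

// Local:       iterate(BDV.zeros[Double](5), (v: BDV[Double]) => v * 0.9, 100, BDVCacheOps)
// Distributed: pass a DVector, a DVector => DVector step, and DVectorCacheOps.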
@@ -0,0 +1,54 @@
// Copyright (C) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License. See LICENSE in project root for information.

package com.microsoft.azure.synapse.ml.causal

import com.microsoft.azure.synapse.ml.codegen.Wrappable
import com.microsoft.azure.synapse.ml.core.schema.DatasetExtensions
import com.microsoft.azure.synapse.ml.logging.{FeatureNames, SynapseMLLogging}
import org.apache.spark.ml.util.Identifiable
import org.apache.spark.ml.{ComplexParamsReadable, ComplexParamsWritable}
import org.apache.spark.sql._
import org.apache.spark.sql.functions.col
import org.apache.spark.sql.types._

class DiffInDiffEstimator(override val uid: String)
  extends BaseDiffInDiffEstimator(uid)
    with ComplexParamsWritable
    with Wrappable
    with SynapseMLLogging {

  logClass(FeatureNames.Causal)

  def this() = this(Identifiable.randomUID("DiffInDiffEstimator"))

  override def fit(dataset: Dataset[_]): DiffInDiffModel = logFit({
    val interactionCol = findInteractionCol(dataset.columns.toSet)
    val postTreatment = col(getPostTreatmentCol)
    val treatment = col(getTreatmentCol)
    val outcome = col(getOutcomeCol)

    val didData = dataset.select(
        postTreatment.cast(IntegerType).as(getPostTreatmentCol),
        treatment.cast(IntegerType).as(getTreatmentCol),
        outcome.cast(DoubleType).as(getOutcomeCol)
      )
      .withColumn(interactionCol, treatment * postTreatment)

    val linearModel = fitLinearModel(
      didData,
      Array(getPostTreatmentCol, getTreatmentCol, interactionCol),
      fitIntercept = true
    )

    val treatmentEffect = linearModel.coefficients(2)
    val standardError = linearModel.summary.coefficientStandardErrors(2)
    val summary = DiffInDiffSummary(treatmentEffect, standardError)

    copyValues(new DiffInDiffModel(this.uid))
      .setSummary(Some(summary))
      .setParent(this)
  }, dataset.columns.length)
}

object DiffInDiffEstimator extends ComplexParamsReadable[DiffInDiffEstimator]
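The fit above is the canonical two-by-two difference-in-differences regression. With post-treatment indicator $P$, treatment indicator $D$, and outcome $y$, the fitted model is

$$y = \beta_0 + \beta_1 P + \beta_2 D + \tau\,(P \cdot D) + \varepsilon,$$

and the interaction coefficient $\tau$ is the treatment effect. Because the features are assembled in the order [postTreatment, treatment, interaction] with the intercept fit separately, coefficients(2) and coefficientStandardErrors(2) pick out $\tau$ and its standard error. The near-zero setRegParam(1E-10) in fitLinearModel stabilizes the least-squares solve while leaving the fit effectively unregularized OLS.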
@@ -3,38 +3,14 @@

package com.microsoft.azure.synapse.ml.causal

-import com.microsoft.azure.synapse.ml.core.contracts.{HasFeaturesCol, HasLabelCol, HasWeightCol}
+import com.microsoft.azure.synapse.ml.core.contracts.{HasFeaturesCol, HasWeightCol}
 import com.microsoft.azure.synapse.ml.param.EstimatorParam
-import org.apache.spark.ml.classification.{LogisticRegression, ProbabilisticClassifier}
-import org.apache.spark.ml.{Estimator, Model}
 import org.apache.spark.ml.ParamInjections.HasParallelismInjected
-import org.apache.spark.ml.param.shared.{HasMaxIter, HasPredictionCol}
-import org.apache.spark.ml.param.{DoubleArrayParam, DoubleParam, Param, Params}
+import org.apache.spark.ml.classification.{LogisticRegression, ProbabilisticClassifier}
+import org.apache.spark.ml.param.shared.HasMaxIter
+import org.apache.spark.ml.param.{DoubleArrayParam, DoubleParam, Params}
 import org.apache.spark.ml.regression.Regressor
-
-trait HasTreatmentCol extends Params {
-  val treatmentCol = new Param[String](this, "treatmentCol", "treatment column")
-  def getTreatmentCol: String = $(treatmentCol)
-
-  /**
-    * Set name of the column which will be used as treatment
-    *
-    * @group setParam
-    */
-  def setTreatmentCol(value: String): this.type = set(treatmentCol, value)
-}
-
-trait HasOutcomeCol extends Params {
-  val outcomeCol: Param[String] = new Param[String](this, "outcomeCol", "outcome column")
-  def getOutcomeCol: String = $(outcomeCol)
-
-  /**
-    * Set name of the column which will be used as outcome
-    *
-    * @group setParam
-    */
-  def setOutcomeCol(value: String): this.type = set(outcomeCol, value)
-}
+import org.apache.spark.ml.{Estimator, Model}
 
 trait DoubleMLParams extends Params
   with HasTreatmentCol with HasOutcomeCol with HasFeaturesCol
@@ -86,7 +62,7 @@ trait DoubleMLParams extends Params
   def setSampleSplitRatio(value: Array[Double]): this.type = set(sampleSplitRatio, value)
 
   private[causal] object DoubleMLModelTypes extends Enumeration {
-    type TreatmentType = Value
+    type DoubleMLModelTypes = Value
     val Binary, Continuous = Value
   }

