feat: Synthetic difference in differences (#2095)
* Estimators for diff-in-diff, synthetic control and synthetic diff-in-diff
* add more params
* refactor
  Signed-off-by: Jason Wang <[email protected]>
* adding unit tests for linalg
* more unit tests
* Unit test for DiffInDiffEstimator
* more unit tests
* unit test for SyntheticControlEstimator
* unit test for SyntheticDiffInDiffEstimator
* logClass
* Python code gen
  Signed-off-by: Jason Wang <[email protected]>
* pyspark wrapper
* expose loss history
* fix bugs for synthetic control
* fix time effects for synthetic control estimator
* fix unit test
* add notebook
* fixing indexing logic
* add file headers
  Signed-off-by: Jason Wang <[email protected]>
* Add feature name to logClass call
* more scalastyle fixes
* More scalastyle and unit test fixes
* Python style fix
  Signed-off-by: Jason Wang <[email protected]>
* fix unit test
* fix more python style issue
* python style fix
* fix unit test
* Update core/src/main/scala/com/microsoft/azure/synapse/ml/causal/DiffInDiffEstimator.scala
  Co-authored-by: Mark Hamilton <[email protected]>
* Update core/src/main/scala/com/microsoft/azure/synapse/ml/causal/SyntheticControlEstimator.scala
  Co-authored-by: Mark Hamilton <[email protected]>
* Update core/src/main/scala/com/microsoft/azure/synapse/ml/causal/SyntheticDiffInDiffEstimator.scala
  Co-authored-by: Mark Hamilton <[email protected]>
* addressing comments
* extract some constants to findUnusedColumn
* Expose zeta as an optional parameter, also return the RMSE for unit weights and time weights fitting
* Replace constant TimeIdxCol and UnitIdxCol with findUnusedColumn
  Signed-off-by: Jason Wang <[email protected]>
* typo
* Adding notebook to sidebar
* fix bad merge
* address code review comments
* Update docs/Explore Algorithms/Causal Inference/Quickstart - Synthetic difference in differences.ipynb
  Co-authored-by: Mark Hamilton <[email protected]>
* clean synapse widget output state
* remove invalid image links

---------
Signed-off-by: Jason Wang <[email protected]>
Co-authored-by: Mark Hamilton <[email protected]>
1 parent 8ebf298 · commit cbc022c · 27 changed files with 2,710 additions and 32 deletions.
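As orientation before the file-by-file diff, the sketch below shows how the new basic estimator might be driven from Scala. It is a minimal sketch, not code from this commit: the setter names are assumed from the Has*Col convention used by DiffInDiffEstimatorParams (those traits are not part of this diff), and the column names of `df` are placeholders.

import com.microsoft.azure.synapse.ml.causal.DiffInDiffEstimator
import org.apache.spark.sql.DataFrame

// Minimal sketch; setter names below are assumed from the Has*Col convention.
def estimateDidEffect(df: DataFrame): (Double, Double) = {
  val estimator = new DiffInDiffEstimator()
    .setPostTreatmentCol("postTreatment") // assumed setter name
    .setTreatmentCol("treatment")         // assumed setter name
    .setOutcomeCol("outcome")             // assumed setter name

  val model = estimator.fit(df)
  val summary = model.getSummary          // DiffInDiffSummary introduced in this commit
  (summary.treatmentEffect, summary.standardError)
}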
@@ -0,0 +1,50 @@
# Copyright (C) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License. See LICENSE in project root for information.

import sys

if sys.version >= "3":
    basestring = str

from synapse.ml.causal._DiffInDiffModel import _DiffInDiffModel
from pyspark.ml.common import inherit_doc
from pyspark.sql import SparkSession, DataFrame
from pyspark import SparkContext, SQLContext


@inherit_doc
class DiffInDiffModel(_DiffInDiffModel):
    @staticmethod
    def _mapOption(option, func):
        # Apply func to the value inside a Scala Option, or return None if it is empty.
        return func(option.get()) if option.isDefined() else None

    @staticmethod
    def _unwrapOption(option):
        return DiffInDiffModel._mapOption(option, lambda x: x)

    def __init__(self, java_obj=None) -> None:
        super(DiffInDiffModel, self).__init__(java_obj=java_obj)

        ctx = SparkContext._active_spark_context
        sql_ctx = SQLContext.getOrCreate(ctx)

        # Eagerly surface the fitted DiffInDiffSummary from the JVM model as plain
        # Python attributes; optional fields become None when they are undefined.
        self.summary = java_obj.getSummary()
        self.treatmentEffect = self.summary.treatmentEffect()
        self.standardError = self.summary.standardError()
        self.timeIntercept = DiffInDiffModel._unwrapOption(self.summary.timeIntercept())
        self.unitIntercept = DiffInDiffModel._unwrapOption(self.summary.unitIntercept())
        self.timeWeights = DiffInDiffModel._mapOption(
            java_obj.getTimeWeights(), lambda x: DataFrame(x, sql_ctx)
        )
        self.unitWeights = DiffInDiffModel._mapOption(
            java_obj.getUnitWeights(), lambda x: DataFrame(x, sql_ctx)
        )
        self.timeRMSE = DiffInDiffModel._unwrapOption(self.summary.timeRMSE())
        self.unitRMSE = DiffInDiffModel._unwrapOption(self.summary.unitRMSE())
        self.zeta = DiffInDiffModel._unwrapOption(self.summary.zeta())
        self.lossHistoryTimeWeights = DiffInDiffModel._unwrapOption(
            self.summary.getLossHistoryTimeWeightsJava()
        )
        self.lossHistoryUnitWeights = DiffInDiffModel._unwrapOption(
            self.summary.getLossHistoryUnitWeightsJava()
        )
core/src/main/scala/com/microsoft/azure/synapse/ml/causal/BaseDiffInDiffEstimator.scala (183 additions, 0 deletions)
@@ -0,0 +1,183 @@
// Copyright (C) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License. See LICENSE in project root for information.

package com.microsoft.azure.synapse.ml.causal

import com.microsoft.azure.synapse.ml.causal.linalg.DVector
import com.microsoft.azure.synapse.ml.codegen.Wrappable
import com.microsoft.azure.synapse.ml.core.schema.DatasetExtensions
import com.microsoft.azure.synapse.ml.logging.{FeatureNames, SynapseMLLogging}
import com.microsoft.azure.synapse.ml.param.DataFrameParam
import org.apache.spark.SparkException
import org.apache.spark.ml.feature.VectorAssembler
import org.apache.spark.ml.param.{Param, ParamMap, Params}
import org.apache.spark.ml.regression.LinearRegression
import org.apache.spark.ml.util.Identifiable
import org.apache.spark.ml.{ComplexParamsReadable, ComplexParamsWritable, Estimator, Model}
import org.apache.spark.sql.types.{BooleanType, NumericType, StructField, StructType}
import org.apache.spark.sql.{DataFrame, Dataset}

import java.util

abstract class BaseDiffInDiffEstimator(override val uid: String)
  extends Estimator[DiffInDiffModel]
    with DiffInDiffEstimatorParams {

  private def validateFieldNumericOrBooleanType(field: StructField): Unit = {
    val dataType = field.dataType
    require(dataType.isInstanceOf[NumericType] || dataType == BooleanType,
      s"Column ${field.name} must be numeric type or boolean type, but got $dataType instead.")
  }

  protected def validateFieldNumericType(field: StructField): Unit = {
    val dataType = field.dataType
    require(dataType.isInstanceOf[NumericType],
      s"Column ${field.name} must be numeric type, but got $dataType instead.")
  }

  override def transformSchema(schema: StructType): StructType = {
    validateFieldNumericOrBooleanType(schema(getPostTreatmentCol))
    validateFieldNumericOrBooleanType(schema(getTreatmentCol))
    validateFieldNumericType(schema(getOutcomeCol))
    schema
  }

  override def copy(extra: ParamMap): Estimator[DiffInDiffModel] = defaultCopy(extra)

  // Picks an interaction column name that does not collide with existing columns.
  private[causal] val findInteractionCol = DatasetExtensions.findUnusedColumnName("interaction") _

  // Fits a linear model on the given feature columns. The tiny regParam keeps the fit
  // essentially ordinary least squares while remaining numerically stable.
  private[causal] def fitLinearModel(df: DataFrame,
                                     featureCols: Array[String],
                                     fitIntercept: Boolean,
                                     weightCol: Option[String] = None) = {
    val featuresCol = DatasetExtensions.findUnusedColumnName("features", df)
    val assembler = new VectorAssembler()
      .setInputCols(featureCols)
      .setOutputCol(featuresCol)

    val regression = weightCol
      .map(new LinearRegression().setWeightCol)
      .getOrElse(new LinearRegression())

    regression
      .setFeaturesCol(featuresCol)
      .setLabelCol(getOutcomeCol)
      .setFitIntercept(fitIntercept)
      .setLoss("squaredError")
      .setRegParam(1E-10)

    // Equivalent to regression.fit(assembler.transform(df)).
    assembler.transform _ andThen regression.fit apply df
  }
}

// Holds the fitted treatment effect and standard error, plus the optional diagnostics
// produced by the synthetic control and synthetic diff-in-diff variants
// (weights, intercepts, RMSEs, zeta, and loss histories).
case class DiffInDiffSummary(treatmentEffect: Double, standardError: Double,
                             timeWeights: Option[DVector] = None,
                             timeIntercept: Option[Double] = None,
                             timeRMSE: Option[Double] = None,
                             unitWeights: Option[DVector] = None,
                             unitIntercept: Option[Double] = None,
                             unitRMSE: Option[Double] = None,
                             zeta: Option[Double] = None,
                             lossHistoryTimeWeights: Option[List[Double]] = None,
                             lossHistoryUnitWeights: Option[List[Double]] = None) {
  import scala.collection.JavaConverters._

  def getLossHistoryTimeWeightsJava: Option[util.List[Double]] = {
    lossHistoryTimeWeights.map(_.asJava)
  }

  def getLossHistoryUnitWeightsJava: Option[util.List[Double]] = {
    lossHistoryUnitWeights.map(_.asJava)
  }
}

class DiffInDiffModel(override val uid: String)
  extends Model[DiffInDiffModel]
    with HasUnitCol
    with HasTimeCol
    with Wrappable
    with ComplexParamsWritable
    with SynapseMLLogging {

  logClass(FeatureNames.Causal)

  final val timeIndex = new DataFrameParam(this, "timeIndex", "time index")
  def getTimeIndex: DataFrame = $(timeIndex)
  def setTimeIndex(value: DataFrame): this.type = set(timeIndex, value)

  final val timeIndexCol = new Param[String](this, "timeIndexCol", "time index column")
  def getTimeIndexCol: String = $(timeIndexCol)
  def setTimeIndexCol(value: String): this.type = set(timeIndexCol, value)

  final val unitIndex = new DataFrameParam(this, "unitIndex", "unit index")
  def getUnitIndex: DataFrame = $(unitIndex)
  def setUnitIndex(value: DataFrame): this.type = set(unitIndex, value)

  final val unitIndexCol = new Param[String](this, "unitIndexCol", "unit index column")
  def getUnitIndexCol: String = $(unitIndexCol)
  def setUnitIndexCol(value: String): this.type = set(unitIndexCol, value)

  override protected lazy val pyInternalWrapper = true

  def this() = this(Identifiable.randomUID("DiffInDiffModel"))

  private final var summary: Option[DiffInDiffSummary] = None

  def getSummary: DiffInDiffSummary = summary.getOrElse {
    throw new SparkException(
      s"No summary available for this ${this.getClass.getSimpleName}")
  }

  private[causal] def setSummary(summary: Option[DiffInDiffSummary]): this.type = {
    this.summary = summary
    this
  }

  override def copy(extra: ParamMap): DiffInDiffModel = {
    copyValues(new DiffInDiffModel(uid), extra)
      .setSummary(this.summary)
      .setParent(parent)
  }

  // The model is a carrier for the fitted summary; transform is a pass-through.
  override def transform(dataset: Dataset[_]): DataFrame = dataset.toDF

  override def transformSchema(schema: StructType): StructType = schema

  // Joins the internal time index back onto the time weights so callers see the
  // original time values rather than integer indices.
  def getTimeWeights: Option[DataFrame] = {
    (get(timeIndex), getSummary.timeWeights) match {
      case (Some(idxDf), Some(timeWeights)) =>
        Some(
          idxDf.join(timeWeights, idxDf(getTimeIndexCol) === timeWeights("i"), "left_outer")
            .select(
              idxDf(getTimeCol),
              timeWeights("value")
            )
        )
      case _ =>
        None
    }
  }

  def getUnitWeights: Option[DataFrame] = {
    (get(unitIndex), getSummary.unitWeights) match {
      case (Some(idxDf), Some(unitWeights)) =>
        Some(
          idxDf.join(unitWeights, idxDf(getUnitIndexCol) === unitWeights("i"), "left_outer")
            .select(
              idxDf(getUnitCol),
              unitWeights("value")
            )
        )
      case _ =>
        None
    }
  }
}

object DiffInDiffModel extends ComplexParamsReadable[DiffInDiffModel]

trait DiffInDiffEstimatorParams extends Params
  with HasTreatmentCol
  with HasOutcomeCol
  with HasPostTreatmentCol
core/src/main/scala/com/microsoft/azure/synapse/ml/causal/CacheOps.scala (21 additions, 0 deletions)
@@ -0,0 +1,21 @@
// Copyright (C) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License. See LICENSE in project root for information.

package com.microsoft.azure.synapse.ml.causal

import breeze.linalg.{DenseVector => BDV}
import com.microsoft.azure.synapse.ml.causal.linalg.DVector

// Abstracts caching and checkpointing so the same iterative code can operate on
// either local breeze vectors or distributed DVectors. The defaults are no-ops.
trait CacheOps[T] {
  def checkpoint(data: T): T = data
  def cache(data: T): T = data
}

// Local breeze vectors need no caching or checkpointing.
object BDVCacheOps extends CacheOps[BDV[Double]] {
  override def checkpoint(data: BDV[Double]): BDV[Double] = data
  override def cache(data: BDV[Double]): BDV[Double] = data
}

// Distributed vectors delegate to Spark so lineage gets truncated and results reused.
object DVectorCacheOps extends CacheOps[DVector] {
  override def checkpoint(data: DVector): DVector = data.localCheckpoint(true)
  override def cache(data: DVector): DVector = data.cache
}
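CacheOps is a small type-class-style seam: iterative weight-fitting code can cache or checkpoint its working state without caring whether that state is a local breeze vector (both calls are no-ops) or a distributed DVector (the calls appear to delegate to the underlying Dataset's cache and localCheckpoint). The routine below is purely illustrative and not part of this commit; it only sketches how a solver might be written against the trait.

// Hypothetical caller, for illustration only: a generic fixed-point iteration
// that periodically checkpoints its state through the supplied CacheOps.
def iterate[T](init: T, steps: Int, ops: CacheOps[T])(update: T => T): T = {
  var state = ops.cache(init)
  for (i <- 1 to steps) {
    state = update(state)
    if (i % 10 == 0) state = ops.checkpoint(state) // truncates lineage for DVectors; no-op for BDV
  }
  state
}
// e.g. iterate(initialWeights, 100, DVectorCacheOps)(gradientStep) on the distributed path.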
core/src/main/scala/com/microsoft/azure/synapse/ml/causal/DiffInDiffEstimator.scala (54 additions, 0 deletions)
@@ -0,0 +1,54 @@
// Copyright (C) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License. See LICENSE in project root for information.

package com.microsoft.azure.synapse.ml.causal

import com.microsoft.azure.synapse.ml.codegen.Wrappable
import com.microsoft.azure.synapse.ml.core.schema.DatasetExtensions
import com.microsoft.azure.synapse.ml.logging.{FeatureNames, SynapseMLLogging}
import org.apache.spark.ml.util.Identifiable
import org.apache.spark.ml.{ComplexParamsReadable, ComplexParamsWritable}
import org.apache.spark.sql._
import org.apache.spark.sql.functions.col
import org.apache.spark.sql.types._

class DiffInDiffEstimator(override val uid: String)
  extends BaseDiffInDiffEstimator(uid)
    with ComplexParamsWritable
    with Wrappable
    with SynapseMLLogging {

  logClass(FeatureNames.Causal)

  def this() = this(Identifiable.randomUID("DiffInDiffEstimator"))

  override def fit(dataset: Dataset[_]): DiffInDiffModel = logFit({
    val interactionCol = findInteractionCol(dataset.columns.toSet)
    val postTreatment = col(getPostTreatmentCol)
    val treatment = col(getTreatmentCol)
    val outcome = col(getOutcomeCol)

    // Cast the indicator columns to integers and add the treatment x post-treatment
    // interaction term used by the two-way diff-in-diff regression.
    val didData = dataset.select(
        postTreatment.cast(IntegerType).as(getPostTreatmentCol),
        treatment.cast(IntegerType).as(getTreatmentCol),
        outcome.cast(DoubleType).as(getOutcomeCol)
      )
      .withColumn(interactionCol, treatment * postTreatment)

    val linearModel = fitLinearModel(
      didData,
      Array(getPostTreatmentCol, getTreatmentCol, interactionCol),
      fitIntercept = true
    )

    // The interaction coefficient (index 2) is the diff-in-diff treatment effect.
    val treatmentEffect = linearModel.coefficients(2)
    val standardError = linearModel.summary.coefficientStandardErrors(2)
    val summary = DiffInDiffSummary(treatmentEffect, standardError)

    copyValues(new DiffInDiffModel(this.uid))
      .setSummary(Some(summary))
      .setParent(this)
  }, dataset.columns.length)
}

object DiffInDiffEstimator extends ComplexParamsReadable[DiffInDiffEstimator]
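For reference, the regression fitted above is the standard two-way difference-in-differences specification; with the feature order Array(postTreatmentCol, treatmentCol, interactionCol), coefficients(2) is the interaction coefficient, which is the treatment effect reported in the summary:

Y_{it} = \beta_0 + \beta_1\,\mathrm{Post}_t + \beta_2\,\mathrm{Treat}_i + \beta_3\,(\mathrm{Treat}_i \times \mathrm{Post}_t) + \varepsilon_{it},
\qquad \hat{\tau}_{\mathrm{DiD}} = \hat{\beta}_3

Equivalently, \hat{\beta}_3 = (\bar{Y}_{\mathrm{treated,post}} - \bar{Y}_{\mathrm{treated,pre}}) - (\bar{Y}_{\mathrm{control,post}} - \bar{Y}_{\mathrm{control,pre}}), the difference of the before/after differences between treated and control groups.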