-
Notifications
You must be signed in to change notification settings - Fork 607
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- several new data sets: Boxes, Mites, Lynx_Hare, Panda_nuts, Moralizing_gods - new depends and imports: shape, igraph, dagitty - new drawdag() DAG drawing function that uses igraph methods. Still raw. Previous function renamed drawdagitty().
- Loading branch information
Richard McElreath
committed
Sep 1, 2019
1 parent
b0c1913
commit d0342e3
Showing
20 changed files
with
2,190 additions
and
248 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,11 +1,11 @@ | ||
Package: rethinking | ||
Type: Package | ||
Title: Statistical Rethinking book package | ||
Version: 1.88 | ||
Date: 2019-07-24 | ||
Version: 1.90 | ||
Date: 2019-08-04 | ||
Author: Richard McElreath | ||
Maintainer: Richard McElreath <[email protected]> | ||
Imports: coda, MASS, mvtnorm, loo | ||
Depends: rstan (>= 2.10.0), parallel, methods, stats, graphics | ||
Imports: coda, MASS, mvtnorm, loo, shape, igraph | ||
Depends: rstan (>= 2.10.0), parallel, methods, stats, graphics, dagitty | ||
Description: Utilities for fitting and comparing models | ||
License: GPL (>= 3) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,248 +1,133 @@ | ||
# my DAG drawing function to extend dagitty plot method | ||
|
||
drawdag <- function( x , col_arrow="black" , col_segment="black" , col_labels="black" , cex=1 , lwd=1.5 , goodarrow=TRUE , xlim , ylim , shapes , col_shapes , radius=3 , add=FALSE , xkcd=FALSE , ... ){ | ||
require(dagitty) | ||
x <- as.dagitty( x ) | ||
dagitty:::.supportsTypes(x,c("dag","mag","pdag")) | ||
coords <- coordinates( x ) | ||
if( any( !is.finite( coords$x ) | !is.finite( coords$y ) ) ){ | ||
coords <- coordinates( graphLayout( x ) ) | ||
#stop("Please supply plot coordinates for graph! See ?coordinates and ?graphLayout.") | ||
} | ||
labels <- names(coords$x) | ||
if ( add==FALSE ) { | ||
par(mar=rep(0,4)) | ||
plot.new() | ||
par(new=TRUE) | ||
} | ||
wx <- sapply( paste0("mm",labels), | ||
function(s) strwidth(s,units="inches") ) | ||
wy <- sapply( paste0("\n",labels), | ||
function(s) strheight(s,units="inches") ) | ||
ppi.x <- dev.size("in")[1] / (max(coords$x)-min(coords$x)) | ||
ppi.y <- dev.size("in")[2] / (max(coords$y)-min(coords$y)) | ||
wx <- wx/ppi.x | ||
wy <- wy/ppi.y | ||
if ( missing(xlim) ) | ||
xlim <- c(min(coords$x-wx/2),max(coords$x+wx/2)) | ||
if ( missing(ylim) ) | ||
ylim <- c(-max(coords$y+wy/2),-min(coords$y-wy/2)) | ||
if (add==FALSE ) plot( NA, xlim=xlim, ylim=ylim, xlab="", ylab="", bty="n", | ||
xaxt="n", yaxt="n" ) | ||
wx <- sapply( labels, | ||
function(s) strwidth(paste0("xx",s)) ) | ||
wy <- sapply( labels, | ||
function(s) strheight(paste0("\n",s)) ) | ||
asp <- par("pin")[1]/diff(par("usr")[1:2]) / | ||
(par("pin")[2]/diff(par("usr")[3:4])) | ||
ex <- edges(x) | ||
ax1 <- rep(0,nrow(ex)) | ||
ax2 <- rep(0,nrow(ex)) | ||
ay1 <- rep(0,nrow(ex)) | ||
ay2 <- rep(0,nrow(ex)) | ||
axc <- rep(0,nrow(ex)) | ||
ayc <- rep(0,nrow(ex)) | ||
acode <- rep(2,nrow(ex)) | ||
has.control.point <- rep(FALSE,nrow(ex)) | ||
for( i in seq_len(nrow(ex)) ){ | ||
if( ex[i,3] == "<->" ){ | ||
acode[i] <- 3 | ||
has.control.point[i] <- TRUE | ||
} | ||
if( ex[i,3] == "--" ){ | ||
acode[i] <- 0 | ||
} | ||
l1 <- as.character(ex[i,1]); l2 <- as.character(ex[i,2]) | ||
x1 <- coords$x[l1]; y1 <- coords$y[l1] | ||
x2 <- coords$x[l2]; y2 <- coords$y[l2] | ||
if( is.na( ex[i,4] ) || is.na( ex[i,5] ) ){ | ||
cp <- dagitty:::.autoControlPoint( x1, y1, x2, y2, asp, | ||
.2*as.integer( acode[i]==3 ) ) | ||
} else { | ||
cp <- list(x=ex[i,4],y=ex[i,5]) | ||
has.control.point[i] <- TRUE | ||
} | ||
bi1 <- dagitty:::.lineSegBoxIntersect( x1-wx[l1]/2,y1-wy[l1]/2, | ||
x1+wx[l1]/2,y1+wy[l1]/2, x1, y1, cp$x, cp$y ) | ||
bi2 <- dagitty:::.lineSegBoxIntersect( x2-wx[l2]/2,y2-wy[l2]/2, | ||
x2+wx[l2]/2,y2+wy[l2]/2, cp$x, cp$y, x2, y2 ) | ||
if( length(bi1) == 2 ){ | ||
x1 <- bi1$x; y1 <- bi1$y | ||
} | ||
if( length(bi2) == 2 ){ | ||
x2 <- bi2$x; y2 <- bi2$y | ||
} | ||
ax1[i] <- x1; ax2[i] <- x2 | ||
ay1[i] <- y1; ay2[i] <- y2 | ||
axc[i] <- cp$x; ayc[i] <- cp$y | ||
} | ||
directed <- acode==2 & !has.control.point | ||
undirected <- acode==0 & !has.control.point | ||
|
||
arr.width <- 0.15 | ||
arr.type <- "curved" | ||
arr.adj <- 1 | ||
if ( xkcd==TRUE ) { | ||
for ( ii in 1:length(ax1[directed]) ) { | ||
xkcd_lines( c(ax1[directed][ii],ax2[directed][ii]) , c(-ay1[directed][ii],-ay2[directed][ii]) , col=col_segment , lwd=lwd*2 , lwdbg=lwd*4 , seg=10 ) | ||
}#ii | ||
arr.width <- arr.width * lwd | ||
arr.type <- "triangle" | ||
arr.adj <- 0.5 | ||
goodarrow <- TRUE | ||
} | ||
|
||
if ( goodarrow==TRUE ) { | ||
require(shape) | ||
shape::Arrows( ax1[directed], -ay1[directed], | ||
ax2[directed], -ay2[directed], arr.length=0.2 , arr.width=arr.width, col=col_arrow , lwd=lwd , arr.adj=arr.adj , arr.type=arr.type ) | ||
} else | ||
arrows( ax1[directed], -ay1[directed], | ||
ax2[directed], -ay2[directed], length=0.1, col=col_arrow , lwd=lwd ) | ||
|
||
segments( ax1[undirected], -ay1[undirected], ax2[undirected], -ay2[undirected], col=col_segment , lwd=lwd ) | ||
|
||
for( i in which( has.control.point ) ){ | ||
dag_arc( ax1[i], -ay1[i], | ||
ax2[i], -ay2[i], axc[i], -ayc[i], | ||
col=c( col_arrow , col_segment )[1+(acode[i]==0)], | ||
code=acode[i], length=0.1, lwd=lwd+(acode[i]==0) , goodarrow=goodarrow ) | ||
# my DAG drawing function that uses igraph library | ||
|
||
drawdag <- function( x , | ||
layout=layout_nicely , | ||
vertex.size=30 , | ||
vertex.label.family="sans" , | ||
edge.curved=0 , | ||
edge.color="black" , | ||
edge.width=1.5, | ||
edge.arrow.width=1.2 , | ||
interact=FALSE , | ||
... ) { | ||
the_edges <- dagitty::edges(x) | ||
the_vars <- unique( c( levels(the_edges[[1]]) , levels(the_edges[[2]]) ) ) | ||
# translate edges list to adjacency matrix | ||
n <- length(the_vars) | ||
adjmat <- adjmat_from_dag( x ) | ||
# draw | ||
mgraph <- graph_from_adjacency_matrix( adjmat , mode="directed" ) | ||
the_shapes <- rep("none",n) | ||
# any latent vars? | ||
unobs_vars <- dagitty::latents(x) | ||
if ( length(unobs_vars)>0 ) { | ||
for ( uv in unobs_vars ) the_shapes[ which(the_vars==uv) ] <- "circle" | ||
} | ||
|
||
# node shapes? | ||
# should be named list with "c" for circle or "b" for box | ||
if ( !missing(shapes) ) { | ||
if ( length(shapes)>0 ) { | ||
for ( i in 1:length(shapes) ) { | ||
the_label <- names(shapes)[i] | ||
j <- which( labels==the_label ) | ||
if ( missing(col_shapes) ) col_shapes <- col_labels[ min(j,length(col_labels)) ] | ||
if ( length(j)>0 ) { | ||
cpch <- 1 | ||
if ( shapes[[i]]=="fc" ) cpch <- 16 | ||
if ( shapes[[i]] %in% c("c","fc") ) | ||
#circle( coords$x[the_label] , -coords$y[the_label] , r=radius , lwd=lwd , col=col_shapes ) | ||
points( coords$x[the_label] , -coords$y[the_label] , cex=radius , lwd=lwd , col=col_shapes , pch=cpch ) | ||
} | ||
}#i | ||
}#>0 | ||
if ( interact==FALSE ) { | ||
if ( class(layout)!="matrix" ) | ||
the_layout <- do.call( layout , list(mgraph) ) | ||
else | ||
the_layout <- layout | ||
plot( mgraph , | ||
vertex.color = "white" , | ||
vertex.size = vertex.size , | ||
vertex.shape = the_shapes , | ||
vertex.label.family = vertex.label.family , | ||
edge.arrow.size = 0.6 , | ||
edge.curved = edge.curved , | ||
edge.color = edge.color , | ||
edge.width = edge.width , | ||
vertex.label = the_vars , | ||
vertex.label.color = "black" , | ||
seed = 1 , layout=the_layout , ... ) | ||
} else { | ||
tkplot( mgraph ) | ||
} | ||
# node labels | ||
text( coords$x, -coords$y[labels], labels , cex=cex , col=col_labels ) | ||
rownames(the_layout) <- the_vars | ||
colnames(the_layout) <- c("x-coord","y-coord") | ||
return(invisible(the_layout)) | ||
} | ||
|
||
circle <- function( x , y , r=1 , npts=100 , ... ) { | ||
theta <- seq( 0, 2*pi , length = npts ) | ||
lines( x = x + r * cos(theta) , y = y + r * sin(theta) , ... ) | ||
} | ||
|
||
dag_arc <- function (x1, y1, x2, y2, xm, ym, col = "gray", length = 0.1, | ||
code = 3, lwd = 1.5 , goodarrow=TRUE ) { | ||
x <- c(x1, xm, x2) | ||
y <- c(y1, ym, y2) | ||
res <- xspline(x, y, 1, draw = FALSE) | ||
lines( res , col = col , lwd=lwd ) | ||
nr <- length(res$x) | ||
if (code >= 3) { | ||
if ( goodarrow==TRUE ) { | ||
require(shape) | ||
shape::Arrows( res$x[4], res$y[4], res$x[1], res$y[1], arr.length=0.2 , arr.width=0.15, col=col , lwd=lwd , arr.adj=1 , arr.type="curved" ) | ||
} else | ||
arrows(res$x[1], res$y[1], res$x[4], res$y[4], col = col, | ||
code = 1, length = length, lwd = lwd) | ||
} | ||
if (code >= 2) { | ||
if ( goodarrow==TRUE ) { | ||
require(shape) | ||
shape::Arrows( res$x[nr-3], res$y[nr-3], res$x[nr], res$y[nr], arr.length=0.2 , arr.width=0.15, col=col , lwd=lwd , arr.adj=1 , arr.type="curved" ) | ||
} else | ||
arrows(res$x[nr - 3], res$y[nr - 3], res$x[nr], res$y[nr], | ||
col = col, code = 2, length = length, lwd = lwd) | ||
# take dagitty dag and translate to adjacency matrix | ||
adjmat_from_dag <- function( x ) { | ||
the_edges <- dagitty::edges(x) | ||
the_vars <- unique( c( levels(the_edges[[1]]) , levels(the_edges[[2]]) ) ) | ||
# translate edges list to adjacency matrix | ||
n <- length(the_vars) | ||
adjmat <- matrix( 0 , nrow=n , ncol=n ) | ||
rownames(adjmat) <- the_vars | ||
colnames(adjmat) <- the_vars | ||
for ( i in 1:nrow(the_edges) ) { | ||
node1 <- as.character( the_edges[i,1] ) | ||
node2 <- as.character( the_edges[i,2] ) | ||
tie_type <- as.character( the_edges[i,3] ) | ||
if ( tie_type=="->" ) { | ||
# make directed tie | ||
adjmat[ node1 , node2 ] <- 1 | ||
} | ||
} | ||
return(adjmat) | ||
} | ||
|
||
# function to map coordinates from dag x to reduced dag y | ||
# returns new dagitty | ||
# used by drawopenpaths | ||
dag_copy_coords <- function( x , y ) { | ||
orig <- coordinates(x) | ||
in_dag <- names( coordinates(y)$x ) | ||
new_coords <- coordinates(y) | ||
new_coords$x <- orig$x[ in_dag ] | ||
new_coords$y <- orig$y[ in_dag ] | ||
coordinates(y) <- new_coords | ||
return(y) | ||
} | ||
|
||
# function to overlay open paths, conditioning on list Z | ||
drawopenpaths <- function( x , Z=list() , col_arrow="red" , ... ) { | ||
x <- as.dagitty( x ) | ||
# get all paths | ||
path_list <- paths( x , Z=Z ) | ||
shapes <- list() | ||
if ( length(Z)>0 ) { | ||
for ( i in 1:length(Z) ) shapes[[ Z[[i]] ]] <- "c" | ||
# function to interactively draw layout for given dag | ||
sketchdag <- function( x , cleanup=1 , plot=TRUE , rescale=FALSE , grid=0.2 , ... ) { | ||
the_edges <- dagitty::edges(x) | ||
the_vars <- unique( c( levels(the_edges[[1]]) , levels(the_edges[[2]]) ) ) | ||
n <- length(the_vars) | ||
adjmat <- adjmat_from_dag( x ) | ||
print(the_vars) | ||
plot( NULL , xlim=c(-1,1) , ylim=c(-1,1) , type='n' , axes=FALSE , ann=FALSE ) | ||
if ( grid != FALSE ) { | ||
# draw alignment grid | ||
cseq <- seq( from=-1 , to=1 , by=grid ) | ||
for ( xc in cseq ) abline( h=xc , lwd=0.5 ) | ||
for ( yc in cseq ) abline( v=yc , lwd=0.5 ) | ||
} | ||
# draw open paths in highlight color | ||
for ( i in 1:length(path_list) ) { | ||
if ( path_list$open[i]==TRUE ) { | ||
path_dag <- dagitty( concat( "dag { " , path_list$paths[i] , " }" ) ) | ||
path_dag <- dag_copy_coords( x , path_dag ) | ||
drawdag( path_dag , col_arrow=col_arrow , add=TRUE , ... ) | ||
} | ||
}#i | ||
return(invisible(path_list)) | ||
pts <- matrix( NA , nrow=n , ncol=2 ) | ||
for ( i in 1:n ) { | ||
pt <- locator( 1 , type="p" , pch=the_vars[i] ) | ||
pts[i,1] <- pt$x | ||
pts[i,2] <- pt$y | ||
} | ||
rownames(pts) <- the_vars | ||
colnames(pts) <- c("x-coord","y-coord") | ||
if ( cleanup != FALSE ) pts <- round( pts , cleanup ) | ||
if ( plot==TRUE ) | ||
return(drawdag(x,layout=pts,...)) | ||
else | ||
return(pts) | ||
} | ||
|
||
# test code | ||
if (FALSE) { | ||
|
||
plot( NULL , xlim=c(-1,1) , ylim=c(-1,1) ) | ||
circle( 0 , 0 , 1 ) | ||
|
||
library(rethinking) | ||
library(dagitty) | ||
plant_dag <- dagitty( "dag { | ||
h0 -> h1 | ||
f -> h1 | ||
t -> f | ||
}") | ||
coordinates( plant_dag ) <- list( x=c(h0=0,t=2,f=1,h1=1) , y=c(h0=0,t=0,f=1,h1=2) ) | ||
#Other graph layouts: add_layout_, component_wise, layout_as_bipartite, layout_as_star, layout_as_tree, layout_in_circle, layout_nicely, layout_on_grid, layout_on_sphere, layout_randomly, layout_with_dh, layout_with_fr, layout_with_gem, layout_with_graphopt, layout_with_lgl, layout_with_mds, layout_with_sugiyama, layout_, merge_coords, norm_coords, normalize | ||
|
||
drawdag( plant_dag , cex=1.2 , col_labels=c("red","black","red","red") , col_arrow=c("red","black","red") , goodarrow=TRUE ) | ||
library(rethinking) | ||
|
||
exdag <- dagitty( "dag { | ||
z -> x -> y | ||
x <- U -> y | ||
}") | ||
coordinates( exdag ) <- list( x=c(z=0,x=1,y=2,U=1.5) , y=c(z=0,x=0,y=0,U=-1) ) | ||
drawdag( exdag ) | ||
exdag <- dagitty( "dag { | ||
U [unobserved] | ||
X [exposure] | ||
Y [outcome] | ||
Z -> X -> Y | ||
X <- U -> Y | ||
}") | ||
|
||
# drawing paths | ||
l <- drawdag( exdag , layout=layout_in_circle ) | ||
drawdag( exdag , layout=l ) | ||
|
||
g <- dagitty( "dag { | ||
x -> y | ||
a -> x | ||
b -> y | ||
a -> z <- b | ||
x [exposure] | ||
y [outcome] | ||
}" , layout=TRUE ) | ||
exdag2 <- dagitty( 'dag { | ||
G [exposure,unobserved] | ||
P [outcome] | ||
"G*" <- G -> P -> W -> RG | ||
RG -> "G*" | ||
}') | ||
|
||
drawdag( g , col_arrow="gray" ) | ||
|
||
drawopenpaths( g , col_arrow="black" ) | ||
drawopenpaths( g , Z=list("z") , col_arrow="black" ) | ||
|
||
# xkcd example | ||
exdag <- dagitty( "dag { | ||
z -> x -> y | ||
x <- U -> y | ||
}") | ||
coordinates( exdag ) <- list( x=c(z=0,x=1,y=2,U=1.5) , y=c(z=0,x=0,y=0,U=-1) ) | ||
drawdag( exdag , xkcd=TRUE , lwd=1.5 ) | ||
drawdag( exdag2 ) | ||
|
||
l <- sketchdag( exdag2 ) | ||
drawdag( exdag2 , layout=round(l,2) ) | ||
|
||
sketchdag( exdag2 , asp=0.6 ) | ||
|
||
} | ||
|
Oops, something went wrong.