---
title: "Regression Trees -- How splits are determined"
author: "D G Rossiter"
date: "`r Sys.Date()`"
output:
   html_document:
    toc: TRUE
    toc_float: TRUE
    theme: "lumen"
    code_folding: show
    number_sections: TRUE
    fig_keep: TRUE
    fig_height: 4
    fig_width: 6
    fig_align: 'center'
---

This document shows how `rpart` decides on the splits of a regression tree, by repeating its search by hand.

# Dataset

Load the example dataset:

```{r}
library(sp)
data(meuse)
names(meuse)
```

We want to predict the topsoil zinc concentration (`zinc`, mg kg^-1^) from the distance to the river (`dist.m`, m) and the relative elevation (`elev`, m).

# Top-level split

## Split on distance to river

We first try splitting on distance. The idea is to find the cutpoint at which the residual (within-group) sum of squares is lowest, i.e., the between-group sum of squares is highest. So we sort the unique distances and try each one as a threshold.
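
The equivalence between "lowest within-group" and "highest between-group" sum of squares is the usual sum-of-squares decomposition. Writing $z_i$ for the Zn value of observation $i$, $\bar z$ for the overall mean, and $n_L, \bar z_L$ and $n_R, \bar z_R$ for the sizes and means of the two groups on either side of a cutpoint (notation introduced here for illustration):

$$
\underbrace{\sum_{i=1}^{n}(z_i-\bar z)^2}_{\mathrm{TSS}}
= \underbrace{\sum_{i\in L}(z_i-\bar z_L)^2 + \sum_{i\in R}(z_i-\bar z_R)^2}_{\mathrm{RSS}}
+ \underbrace{n_L(\bar z_L-\bar z)^2 + n_R(\bar z_R-\bar z)^2}_{\text{between-group SS}},
\qquad R^2 = 1-\frac{\mathrm{RSS}}{\mathrm{TSS}}
$$

The TSS does not depend on the cutpoint, so the cutpoint with the lowest RSS also has the highest between-group sum of squares and the highest $R^2$.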

Before searching, we compute the total sum of squares (TSS), i.e., the residual sum of squares with no model, where every point is predicted by the overall mean:

```{r}
(tss <- sum((meuse$zinc - mean(meuse$zinc))^2))
```

Now try all thresholds, keeping the results in a data frame:

```{r, fig.width=12, fig.height=6}
(distances <- sort(unique(meuse$dist.m)))
(nd <- length(distances))
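results.df <- data.frame(distance=distances, rss.less=rep(0,nd), rss.more=rep(0,nd), rss=rep(0,nd), r.squared=rep(0,nd))
# each observed distance is a candidate threshold: points with dist.m < threshold
# go to the "less" branch, the rest to the "more" branch
# (the smallest distance leaves the "less" branch empty, so its R^2 is 0)
for (i in 1:nd) {
  branch.less <- meuse$zinc[meuse$dist.m < distances[i]]
  branch.more <- meuse$zinc[meuse$dist.m >= distances[i]]
  rss.less <- sum((branch.less-mean(branch.less))^2)
  rss.more <- sum((branch.more-mean(branch.more))^2)
  rss <- rss.less + rss.more   # within-group (residual) sum of squares
  results.df[i,2:5] <- c(rss.less, rss.more, rss, 1-rss/tss)
  }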
print(results.df)
(ix.r.squared.max <- which.max(results.df$r.squared))
print(results.df[ix.r.squared.max,])
(distance.r.squared.max <- results.df[ix.r.squared.max,"r.squared"])
(d.threshold.1 <- results.df[ix.r.squared.max,"distance"])
plot(r.squared ~ distance, data=results.df, type="h",
     col=ifelse(distance==d.threshold.1,"red","gray"))
```

The best cutoff for distance is `r d.threshold.1` m: points closer to the river than this form one group, the rest the other. This split would explain `r round(distance.r.squared.max*100, 1)`% of the variation in Zn concentration.
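
The loop above recomputes both group means for every threshold. A faster equivalent, sketched here under our own assumptions (the name `best.split.cumsum` is hypothetical and not part of `rpart`), sorts the data once and obtains every candidate RSS from running sums, using $\sum (z-\bar z)^2 = \sum z^2 - (\sum z)^2/n$. It keeps the convention of this document: thresholds are observed values and the "less" branch is `x < threshold`.

```{r}
# Hypothetical helper: best single threshold on predictor x for response z,
# evaluating every candidate cut in one pass over the sorted data.
best.split.cumsum <- function(z, x) {
  o <- order(x)
  z <- z[o]; x <- x[o]
  n <- length(z)
  k <- seq_len(n - 1)            # number of observations in the "less" branch
  cs1 <- cumsum(z); cs2 <- cumsum(z^2)
  s1 <- cs1[k]; s2 <- cs2[k]     # running sums for the "less" branch
  S1 <- cs1[n]; S2 <- cs2[n]     # totals over all observations
  rss.less <- s2 - s1^2 / k
  rss.more <- (S2 - s2) - (S1 - s1)^2 / (n - k)
  rss <- rss.less + rss.more
  ok <- x[k] < x[k + 1]          # a cut must fall between distinct x values
  if (!any(ok)) return(NULL)     # x is constant: no split is possible
  best <- which(ok)[which.min(rss[ok])]
  tss <- S2 - S1^2 / n
  list(threshold = x[best + 1],  # "less" branch is x < threshold
       rss = rss[best],
       r.squared = 1 - rss[best] / tss)
}
```

Applied as `best.split.cumsum(meuse$zinc, meuse$dist.m)`, it should return the same threshold and $R^2$ as the loop. Running sums of squares can lose precision when the values are large relative to their spread, so the direct loop remains the safer reference.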

## Split on elevation

But we also need to try the other possible splitting variable:

```{r, fig.width=12, fig.height=6}
(elevations <- sort(unique(meuse$elev)))
(nd <- length(elevations))
results.df <- data.frame(elevation=elevations, rss.less=rep(0,nd), rss.more=rep(0,nd), rss=rep(0,nd), r.squared=rep(0,nd))
for (i in 1:nd) {
  branch.less <- meuse$zinc[meuse$elev < elevations[i]]
  branch.more <- meuse$zinc[meuse$elev >= elevations[i]]
  rss.less <- sum((branch.less-mean(branch.less))^2)
  rss.more <- sum((branch.more-mean(branch.more))^2)
  rss <- sum(rss.less + rss.more)
  results.df[i,2:5] <- c(rss.less, rss.more, rss, 1-rss/tss)
  }
# print(results.df)
(ix.r.squared.max <- which.max(results.df$r.squared))
print(results.df[ix.r.squared.max,])
(elevation.r.squared.max <- results.df[ix.r.squared.max,"r.squared"])
(e.threshold.1 <- results.df[ix.r.squared.max,"elevation"])
plot(r.squared ~ elevation, data=results.df, type="h",
     col=ifelse(elevation==e.threshold.1,"red","gray"))
```

The best cutoff for elevation is `r e.threshold.1` m above the local river bed. This split would explain `r round(elevation.r.squared.max*100, 1)`% of the variation in Zn concentration.

Distance clearly gives a better split ($R^2$ = `r round(distance.r.squared.max,3)`) than elevation ($R^2$ = `r round(elevation.r.squared.max,3)`), so the first split is made on distance.
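
As a cross-check, `rpart` itself can be asked for a single split (`maxdepth = 1`). Its `splits` matrix lists the chosen (primary) split and the best competing split on the other variable, with the cut point in `index` and, as we understand it for `method = "anova"`, the proportional reduction in the node's sum of squares in `improve`; for the root node this should match the $R^2$ values above. We set `maxsurrogate = 0` only to keep the table short. Differences from the hand search are possible: `rpart` places the cut midway between adjacent observed values, and by default it will not create a node with fewer than `minbucket` (here 7) observations.

```{r}
library(rpart)
library(sp)
data(meuse)
# one split only; surrogate splits suppressed so only primary + competitor rows remain
stump <- rpart(zinc ~ dist.m + elev, data = meuse, method = "anova",
               control = rpart.control(maxdepth = 1, maxsurrogate = 0))
stump$splits   # rows named by variable: primary split first, then competitor(s)
```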

## Make the split

We now split the dataset into these two groups, but only if the increase in $R^2$ exceeds a user-specified threshold; in `rpart` this is the complexity parameter `cp`. Here the increase in $R^2$ is very large (the $R^2$ of the unsplit data is by definition 0), so we split.

```{r}
meuse.split1.less <- meuse[meuse$dist.m < d.threshold.1,]
meuse.split1.more <- meuse[meuse$dist.m >= d.threshold.1,]
nrow(meuse.split1.less); nrow(meuse.split1.more)
```

The first group (closer to the river) has fewer points.


# Second-level split

Now `rpart` repeats the same procedure within each group separately. For each candidate split, the overall $R^2$ of the tree with that split added is compared with the $R^2$ after the first split alone.
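
This recursion can be sketched in a few lines. It is our own simplified illustration, not `rpart`'s code: `node.rss`, `best.split` and `grow` are hypothetical names; predictors in `X` are assumed numeric; every observed value except the smallest is tried as a threshold, as in the loops of this document; a split is kept only if it lowers the whole-tree RSS by at least `cp` times the TSS, which mimics the role of `rpart`'s `cp`; and there is no minimum node size (unlike `rpart`'s `minsplit` and `minbucket`).

```{r}
# Hypothetical sketch of greedy recursive partitioning (not rpart's own code).
node.rss <- function(z) sum((z - mean(z))^2)

# Best split over all predictors in data frame X: every observed value except
# the smallest is tried as a threshold, with the "less" branch being x < threshold.
best.split <- function(z, X) {
  best <- list(rss = Inf)
  for (v in names(X)) {
    for (t in sort(unique(X[[v]]))[-1]) {
      less <- X[[v]] < t
      r <- node.rss(z[less]) + node.rss(z[!less])
      if (r < best$rss) best <- list(var = v, threshold = t, rss = r)
    }
  }
  best
}

# Grow a tree to max.depth, keeping a split only if it lowers the whole-tree
# RSS by at least cp * tss (tss = TSS of the full dataset, i.e., the root node).
grow <- function(z, X, tss, cp = 0.01, depth = 0, max.depth = 2) {
  node <- list(n = length(z), mean = mean(z), rss = node.rss(z))
  if (depth >= max.depth || length(unique(z)) < 2) return(node)
  s <- best.split(z, X)
  if (is.null(s$var) || (node$rss - s$rss) / tss < cp) return(node)
  less <- X[[s$var]] < s$threshold
  node$split <- s[c("var", "threshold")]
  node$less <- grow(z[less], X[less, , drop = FALSE], tss, cp, depth + 1, max.depth)
  node$more <- grow(z[!less], X[!less, , drop = FALSE], tss, cp, depth + 1, max.depth)
  node
}
```

For example, `grow(meuse$zinc, meuse[, c("dist.m", "elev")], tss = tss)` uses the same search and threshold convention as the hand calculations in this section, so it should arrive at the same splits.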

## Split of left group

We try both distance and elevation on the left-hand group.

```{r}
(distances <- sort(unique(meuse.split1.less$dist.m)))
nd <- length(distances)
```

All of these distances are less than the threshold used for the first split.

We now see how much further splitting on distance improves the fit.

```{r, fig.width=12, fig.height=6}
# TSS of this group only, so that r.squared is the proportion of this group's variation explained
tss.group <- sum((meuse.split1.less$zinc - mean(meuse.split1.less$zinc))^2)
results.df <- data.frame(distance=distances, rss.less=rep(0,nd), rss.more=rep(0,nd), rss=rep(0,nd), r.squared=rep(0,nd))
for (i in 1:nd) {
  branch.less <- meuse.split1.less$zinc[meuse.split1.less$dist.m < distances[i]]
  branch.more <- meuse.split1.less$zinc[meuse.split1.less$dist.m >= distances[i]]
  rss.less <- sum((branch.less-mean(branch.less))^2)
  rss.more <- sum((branch.more-mean(branch.more))^2)
  rss <- rss.less + rss.more
  results.df[i,2:5] <- c(rss.less, rss.more, rss, 1-rss/tss.group)
  }
print(results.df)
ix.r.squared.max <- which.max(results.df$r.squared)
print(results.df[ix.r.squared.max,])
distance.r.squared.max <- results.df[ix.r.squared.max,"r.squared"]
d.threshold.1.1 <- results.df[ix.r.squared.max,"distance"]
plot(r.squared ~ distance, data=results.df, type="h",
     col=ifelse(distance==d.threshold.1.1,"red","gray"))
```

The best cutoff for distance in this group is `r d.threshold.1.1` m, although it is very close to the next-further distance. This split explains `r round(100*distance.r.squared.max, 2)`% of the variation _in this group_.

But we also need to try the other possible splitting variable, elevation.

```{r}
(elevations <- sort(unique(meuse.split1.less$elev)))
nd <- length(elevations)
```

Although the first split was on distance, not elevation, this group contains only a subset of the elevations found in the full dataset.


```{r, fig.width=12, fig.height=6}
# TSS of this group only (the same group as above), for a within-group R^2
tss.group <- sum((meuse.split1.less$zinc - mean(meuse.split1.less$zinc))^2)
results.df <- data.frame(elevation=elevations, rss.less=rep(0,nd), rss.more=rep(0,nd), rss=rep(0,nd), r.squared=rep(0,nd))
for (i in 1:nd) {
  branch.less <- meuse.split1.less$zinc[meuse.split1.less$elev < elevations[i]]
  branch.more <- meuse.split1.less$zinc[meuse.split1.less$elev >= elevations[i]]
  rss.less <- sum((branch.less-mean(branch.less))^2)
  rss.more <- sum((branch.more-mean(branch.more))^2)
  rss <- rss.less + rss.more
  results.df[i,2:5] <- c(rss.less, rss.more, rss, 1-rss/tss.group)
  }
# print(results.df)
ix.r.squared.max.e <- which.max(results.df$r.squared)
print(results.df[ix.r.squared.max.e,])
elevation.r.squared.max <- results.df[ix.r.squared.max.e,"r.squared"]
e.threshold.1.1 <- results.df[ix.r.squared.max.e,"elevation"]
plot(r.squared ~ elevation, data=results.df, type="h",
     col=ifelse(elevation==e.threshold.1.1,"red","gray"))
```

The best cutoff for elevation in this group is `r e.threshold.1.1` m above the local river bed, although it is very close to another, similar elevation. This split explains `r round(100*elevation.r.squared.max, 2)`% of the variation _in this group_.

The two improvements are close, but now splitting this branch on elevation ($R^2$ = `r round(elevation.r.squared.max,3)`) gives a better $R^2$ _within this branch_ than splitting on distance ($R^2$ = `r round(distance.r.squared.max,3)`).

The $R^2$ values reported above are only for this group. To decide whether the split improves the overall model enough, we compute the $R^2$ of the proposed split over all the (new) groups and compare it with the $R^2$ after the first split, which we have already made.

```{r}
rss.right <- sum((meuse.split1.more$zinc - mean(meuse.split1.more$zinc))^2)
rss.left <- sum((meuse.split1.less$zinc - mean(meuse.split1.less$zinc))^2)
(r2.1 <- 1 - ((rss.right + rss.left)/tss))
rss.left.split <- sum(results.df[ix.r.squared.max.e,c("rss.less","rss.more")])
(r2.2 <- 1 - ((rss.right + rss.left.split)/tss))
```
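
The calculation in this chunk generalizes to any partition of the data: the overall $R^2$ is one minus the summed within-group RSS divided by the TSS. A small helper (our own; the name `partition.r2` is hypothetical) makes this explicit, where `leaf` is any vector labelling the group of each observation.

```{r}
# Overall R^2 of a partition: 1 - (sum of within-group RSS) / TSS
partition.r2 <- function(z, leaf) {
  tss <- sum((z - mean(z))^2)
  rss <- sum(tapply(z, leaf, function(g) sum((g - mean(g))^2)))
  1 - rss / tss
}
```

With the objects of this document, `partition.r2(meuse$zinc, meuse$dist.m < d.threshold.1)` should equal `r2.1`.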

The $R^2$ after this split is `r round(r2.2,3)`; the $R^2$ after the first split is `r round(r2.1,3)`; the improvement is `r round((r2.2 - r2.1), 3)`. Depending on the setting of the complexity parameter (CP), we would decide to split or not. This improvement is modest, only about `r round(100*(r2.2 - r2.1))`% of the total variation, i.e., corresponding to a CP of about `r round(r2.2 - r2.1, 2)`.
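
In `rpart` the same bookkeeping can be read from a fitted tree. For `method = "anova"`, the `dev` column of `fit$frame` is each node's sum of squares (so the root's `dev` is the TSS), and node $k$ has children $2k$ and $2k+1$. The sketch below computes, for every internal node, the gain in overall $R^2$ from its split, i.e., (RSS of the node minus RSS of its two children) / TSS, which is the quantity compared with `cp`. The settings `maxdepth = 2` and `cp = 0.001` are our own choice, so that both second-level splits can appear; the tree may differ in detail from the hand-built one because of `rpart`'s midpoint cut points and minimum node sizes.

```{r}
library(rpart)
library(sp)
data(meuse)
fit <- rpart(zinc ~ dist.m + elev, data = meuse, method = "anova",
             control = rpart.control(maxdepth = 2, cp = 0.001))
# node sums of squares, named by node number ("1" is the root)
dev <- setNames(fit$frame$dev, rownames(fit$frame))
internal <- rownames(fit$frame)[fit$frame$var != "<leaf>"]
leaves   <- rownames(fit$frame)[fit$frame$var == "<leaf>"]
# gain in overall R^2 from each internal node's split
gain <- sapply(internal, function(k) {
  kids <- as.character(2 * as.integer(k) + 0:1)
  unname((dev[k] - sum(dev[kids])) / dev["1"])
})
gain
# the gains add up to the R^2 of the whole tree
c(sum.of.gains = sum(gain), tree.r2 = 1 - sum(dev[leaves]) / dev[["1"]])
printcp(fit)
```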

Make the split:

```{r}
meuse.split1.less.split2.less <- meuse.split1.less[meuse.split1.less$elev < e.threshold.1.1,]
meuse.split1.less.split2.more <- meuse.split1.less[meuse.split1.less$elev >= e.threshold.1.1,]
nrow(meuse.split1.less.split2.less); nrow(meuse.split1.less.split2.more)
row.names(meuse.split1.less.split2.less)
row.names(meuse.split1.less.split2.more)
```


## Split of right group

Same procedure for the other group.

```{r, fig.width=12, fig.height=6}
(distances <- sort(unique(meuse.split1.more$dist.m)))
(nd <- length(distances))
# TSS of this group only, so that r.squared is the proportion of this group's variation explained
tss.group <- sum((meuse.split1.more$zinc - mean(meuse.split1.more$zinc))^2)
results.df <- data.frame(distance=distances, rss.less=rep(0,nd), rss.more=rep(0,nd), rss=rep(0,nd), r.squared=rep(0,nd))
for (i in 1:nd) {
  branch.less <- meuse.split1.more$zinc[meuse.split1.more$dist.m < distances[i]]
  branch.more <- meuse.split1.more$zinc[meuse.split1.more$dist.m >= distances[i]]
  rss.less <- sum((branch.less-mean(branch.less))^2)
  rss.more <- sum((branch.more-mean(branch.more))^2)
  rss <- rss.less + rss.more
  results.df[i,2:5] <- c(rss.less, rss.more, rss, 1-rss/tss.group)
  }
print(results.df)
ix.r.squared.max <- which.max(results.df$r.squared)
print(results.df[ix.r.squared.max,])
(distance.r.squared.max <- results.df[ix.r.squared.max,"r.squared"])
d.threshold.2.1 <- results.df[ix.r.squared.max,"distance"]
plot(r.squared ~ distance, data=results.df, type="h",
     col=ifelse(distance==d.threshold.2.1,"red","gray"))
```

The best cutoff for distance in this group is `r d.threshold.2.1` m, explaining `r round(100*distance.r.squared.max,2)`% of the variation _in this group_, although it is very close to the next-further distance.

But we also need to try the other possible splitting variable, elevation.

```{r, fig.width=12, fig.height=6}
(elevations <- sort(unique(meuse.split1.more$elev)))
nd <- length(elevations)
# TSS of this group only (the same group as above), for a within-group R^2
tss.group <- sum((meuse.split1.more$zinc - mean(meuse.split1.more$zinc))^2)
results.df <- data.frame(elevation=elevations, rss.less=rep(0,nd), rss.more=rep(0,nd), rss=rep(0,nd), r.squared=rep(0,nd))
for (i in 1:nd) {
  branch.less <-
    meuse.split1.more$zinc[meuse.split1.more$elev < elevations[i]]
  branch.more <-
    meuse.split1.more$zinc[meuse.split1.more$elev >= elevations[i]]
  rss.less <- sum((branch.less-mean(branch.less))^2)
  rss.more <- sum((branch.more-mean(branch.more))^2)
  rss <- rss.less + rss.more
  results.df[i,2:5] <- c(rss.less, rss.more, rss, 1-rss/tss.group)
  }
# print(results.df)
ix.r.squared.max.e <- which.max(results.df$r.squared)
print(results.df[ix.r.squared.max.e,])
elevation.r.squared.max <- results.df[ix.r.squared.max.e,"r.squared"]
e.threshold.1.2 <- results.df[ix.r.squared.max.e,"elevation"]
plot(r.squared ~ elevation, data=results.df, type="h",
     col=ifelse(elevation==e.threshold.1.2,"red","gray"))
```

The best cutoff for elevation in this group is `r e.threshold.1.2` m above the local river bed, explaining `r round(100*elevation.r.squared.max,2)`% of the variation _in this group_, although it is very close to the next-higher elevation.

The two improvements are close, but now splitting this branch on elevation ($R^2$ = `r round(elevation.r.squared.max,3)`) gives a better $R^2$ than splitting on distance ($R^2$ = `r round(distance.r.squared.max,3)`).

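Compute the $R^2$ of the proposed split over all the (new) groups and compare it with the $R^2$ from the first split we've already made. The split of the left branch is not included, so the left group enters unsplit.

```{r}
# RSS of the right group after its proposed split on elevation
rss.right.split <- sum(results.df[ix.r.squared.max.e,c("rss.less","rss.more")])
# left group unsplit (rss.left, from the first split) + right group split
(r2.2 <- 1 - ((rss.left + rss.right.split)/tss))
```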

The $R^2$ after this second split is `r round(r2.2,3)`; the $R^2$ after the first split is `r round(r2.1,3)`; the improvement is `r round((r2.2 - r2.1), 3)`. As long as this improvement exceeds the chosen CP (the `rpart` default is `cp = 0.01`), the split is made.

Make the split:

```{r}
meuse.split1.more.split2.less <- meuse.split1.more[meuse.split1.more$elev < e.threshold.1.2,]
meuse.split1.more.split2.more <- meuse.split1.more[meuse.split1.more$elev >= e.threshold.1.2,]
nrow(meuse.split1.more.split2.less); nrow(meuse.split1.more.split2.more)
row.names(meuse.split1.more.split2.less)
row.names(meuse.split1.more.split2.more)
```

# Predictions for the groups

We now have four groups at two levels. The prediction for each group is the mean of the target variable over the points in that group.
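
The same holds in `rpart`: for an anova tree, the fitted value of every observation is the mean of the response in its leaf. The sketch below checks this using `fit$where`, which gives, for each observation, the row of `fit$frame` corresponding to its leaf; `maxdepth = 2` is our own choice, to mirror the two levels built by hand.

```{r}
library(rpart)
library(sp)
data(meuse)
fit <- rpart(zinc ~ dist.m + elev, data = meuse, method = "anova",
             control = rpart.control(maxdepth = 2))
# mean zinc in each observation's leaf (ave() applies mean within groups)
leaf.mean <- ave(meuse$zinc, fit$where)
all.equal(unname(predict(fit)), leaf.mean)   # fitted values are leaf means
tapply(meuse$zinc, fit$where, mean)           # one prediction per leaf
```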

Top level (no split):

```{r}
mean(meuse$zinc)
```

First split:

```{r}
mean(meuse.split1.less$zinc)
mean(meuse.split1.more$zinc)
```

This is a big difference: points closer to the river than `r d.threshold.1` m have a mean Zn concentration about `r round(mean(meuse.split1.less$zinc)/mean(meuse.split1.more$zinc), 1)` times that of points further away.

Second split, left branch:

```{r}
mean(meuse.split1.less$zinc)
(pred.1.1 <- mean(meuse.split1.less.split2.less$zinc))
(pred.1.2 <- mean(meuse.split1.less.split2.more$zinc))
```

Close to the river, the highest Zn concentrations are found at elevations below `r e.threshold.1.1` m above the local river bed.

Second split, right branch:

```{r}
mean(meuse.split1.more$zinc)
(pred.2.1 <- mean(meuse.split1.more.split2.less$zinc))
(pred.2.2 <- mean(meuse.split1.more.split2.more$zinc))
```

Again elevation is important: far from the river, the second-highest Zn concentrations are found at elevations below `r e.threshold.1.2` m above the local river bed.

# Map

Add the predictions for each point and display:

```{r}
meuse$rf.pred <- 0
meuse[(meuse$dist.m < d.threshold.1) & (meuse$elev < e.threshold.1.1), "rf.pred"] <- pred.1.1
meuse[(meuse$dist.m < d.threshold.1) & (meuse$elev >= e.threshold.1.1), "rf.pred"] <- pred.1.2
meuse[(meuse$dist.m >= d.threshold.1) & (meuse$elev < e.threshold.1.2), "rf.pred"] <- pred.2.1
meuse[(meuse$dist.m >= d.threshold.1) & (meuse$elev >= e.threshold.1.2), "rf.pred"] <- pred.2.2
table(round(meuse$rf.pred,2))
```

```{r}
plot(meuse$y ~ meuse$x, cex=1.5*meuse$rf.pred/max(meuse$rf.pred), asp=1)
grid()
data(meuse.riv)
lines(meuse.riv)
```