Building a Decision Tree From Scratch with R

Decision trees are a foundational type of machine learning model, and they serve as the basis for more advanced ensemble methods such as random forests (bagging) and XGBoost (boosting). At their core, decision trees are trained recursively: the training data is repeatedly split at whichever point yields the greatest information gain.

There are many ways to measure information gain, but for this introduction we will build a simple regression tree that splits wherever the standard deviation of the response is reduced the most.
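
As a quick, made-up illustration of the idea: if the responses fall into two well-separated groups, the spread within each side of a good split is far smaller than the spread of the data as a whole.

# Toy response values that form two well-separated groups
y_toy <- c(1, 2, 3, 11, 12, 13)

# Spread of the full response (about 5.5)
sd(y_toy)

# Spread within each side of a split between the groups (1 each),
# so this split removes most of the variation
sd(y_toy[1:3])
sd(y_toy[4:6])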


First, let's define information gain as a function in R. This function accepts x, a vector containing the values of a given predictor; y, a vector of responses the same length as x; and split, the numeric split point under evaluation.

info_gain <- function(x, y, split) {
  
  # Partition the responses into values below and above the split point
  lo <- y[x <  split]
  hi <- y[x >= split]
  
  # Early return if either side has too few values for a standard deviation
  if (length(lo) < 2 || length(hi) < 2) {
    return(-Inf)
  }
  
  # Return the reduction in spread: the overall standard deviation
  # minus the combined standard deviation of the two sides
  sd(y) - (sd(lo) + sd(hi))
}
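
To sanity check the function, we can try it on a small made-up example (x_demo and y_demo below are hypothetical) where the response jumps once the predictor crosses 5:

# Hypothetical data: y jumps from 2 to 8 once x exceeds 5
x_demo <- 1:10
y_demo <- c(2, 2, 2, 2, 2, 8, 8, 8, 8, 8)

# A split right at the jump yields a large gain...
info_gain(x_demo, y_demo, split = 5.5)

# ...while a split far from the jump yields much less
info_gain(x_demo, y_demo, split = 2.5)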

Next, we define a function that finds the split which maximizes information gain. It accepts x and y just like the function above, along with trials, the number of candidate split points to evaluate.

best_split <- function(x, y, trials) {
  
  # Determine the minimum and maximum of x
  x_rng <- range(x)
  x_dif <- diff(x_rng)
  
  # Create a vector of candidate splits spanning 10% to 90% of the range
  splits <-
    seq(
      x_rng[1] + 0.1 * x_dif,
      x_rng[1] + 0.9 * x_dif,
      length.out = trials
    )
  
  # Evaluate the information gained by each candidate split
  info <-
    splits |>
    purrr::map_dbl(
      info_gain,
      x = x,
      y = y
    )
  
  # Return the split with the highest information gain
  splits[which.max(info)]
}
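
Using the same hypothetical x_demo and y_demo from earlier, best_split should land near the jump in the response as long as enough candidate splits are tried:

# With 20 candidate splits, the chosen split point should
# fall close to the jump between x = 5 and x = 6
best_split(x_demo, y_demo, trials = 20)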

Finally, we will create two functions which orchestrate the training and prediction of a model using recursion. The training function accepts x and y as above, plus trials and min_split arguments, which control the number of candidate splits to try at each node and the minimum number of observations required to keep splitting. The prediction function accepts a single x value and a fitted model, and walks the tree until it reaches a leaf.

train_recursive <- function(x, y, trials = 2, min_split = 4) {
  
  # Stop and return a leaf node once too few observations remain
  if (length(x) < min_split) {
    return(list(response = mean(y)))
  }
  
  best <- best_split(x, y, trials)
  left <- x < best
  
  # Guard against degenerate splits that would leave one side empty
  if (all(left) || !any(left)) {
    return(list(response = mean(y)))
  }
  
  # Otherwise store the split and recurse into each side,
  # passing the tuning arguments down the tree
  list(
    split = best,
    left  = train_recursive(x[left],  y[left],  trials, min_split),
    right = train_recursive(x[!left], y[!left], trials, min_split)
  )
}
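
Before moving on to prediction, it can help to peek at what the fitted object looks like. Training on the hypothetical x_demo and y_demo from earlier returns a nested list of split points with a mean response stored at each leaf; because there is no pruning, the printed tree may contain a few redundant splits.

# Fit a small tree and inspect its nested list structure
toy_tree <- train_recursive(x_demo, y_demo, trials = 20)
str(toy_tree)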

predict_recursive <- function(x, model) {
  
  # Leaf node: return the stored mean response
  if (is.null(model$split)) {
    return(model$response)
  }
  
  # Internal node: descend into the left or right child
  if (x < model$split) {
    predict_recursive(x, model$left)
  } else {
    predict_recursive(x, model$right)
  }
  
}
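
To make the recursion concrete, here is a tiny hand-built model (purely hypothetical) with one split and two leaves; a prediction descends one level and then returns the leaf's response:

# A minimal hand-built tree: one split at 5, a leaf on either side
tiny_tree <- list(
  split = 5,
  left  = list(response = 2),
  right = list(response = 8)
)

# 3 falls below the split, so we descend left and return 2
predict_recursive(3, tiny_tree)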

Now we can use them to train a simple model! Suppose we have the following data and we want to fit a curve to these points. There is an obvious repeating pattern (the sine term) riding on top of an overall upward linear trend. The oscillation would be awkward to capture with a plain linear regression, but our decision tree has no issue fitting this data!

# Training data
x <- seq(1, 50, 0.5)
y <- x * 0.1 + sin(x) + runif(length(x), -0.5, 0.5)

# Fit a model
mod <- train_recursive(x, y)

# Get predictions for each value of x
y_hat <- purrr::map_dbl(x, predict_recursive, model = mod)
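
As a quick visual check (using base graphics), we can plot the training points and overlay the tree's step-like predictions:

# Plot the training data and overlay the fitted step function
plot(x, y, pch = 16, col = "grey60")
lines(x, y_hat, col = "red", lwd = 2)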