머신러닝 모델별 퍼포먼스 그래프

Machine learning
ggplot2
Published

March 3, 2023

아래와 같이 모델별로 머신러닝 성능을 정리한 표가 있다고 해보자.

이 데이터를 통해 머신러닝 모델별, 그리고 질병별 성능을 아래와 같이 시각화해보자. X축에는 모델이, Y축에는 성능이 와야 하고, 각 성능 지표별로 다른 선으로 구성할 것이다. 그리고 각 질병별로 이런 그래프가 나열되어야 한다.

우선 데이터 시각화를 위한 데이터프레임을 만들어야 한다. 저 표 형태 그대로는 데이터프레임으로 저장할 수 없다. 따라서 데이터 프레임 형태를 잘 구성하는 게 중요하다.

특정 범주별로 색상을 구분해주는 그래프를 만들기 위해선 X축과 범례에 들어갈 범주 데이터, 그리고 수치에 해당하는 숫자 데이터를 각각 다른 열로 구성해야 한다.

우선 각 모델별 성능별 점수가 질병별로 필요하므로, 질병 열을 구성해준다. 위의 표에서 질병 * 모델 * 성능을 봤을 때, 하나의 행이 질병 하나의 모델별 성능별 점수인 것을 알 수 있다. 따라서 질병에 따라 모델, 성능, 점수가 길게 늘어지는 long data가 되도록 구성한다.

require(data.table)
require(ggplot2)
require(ggthemes)
df <- data.table(
  disease = rep(c("Death","Heart disease","Stroke","Cancer","Hypertension","Diabetes"), each=12),
  model = rep(rep(c("Mod1","Mod2","Mod3","Mod4"),each=3),6),
  category = rep(c("AUC","Accuracy","F1"), 4*6),
  score = c(0.863, 0.725, 0.106, 0.873, 0.753, 0.116, 0.877, 0.802, 0.135, 0.887, 0.793, 0.135,
            0.718,  0.617,  0.213,  0.731,  0.653, 0.224,   0.731,  0.605,  0.216,  0.734,  0.614,  0.218,
            0.790,  0.684,  0.223,  0.794,  0.681, 0.223,   0.794,  0.713,  0.234,  0.795,  0.705,  0.233,
            0.774,  0.654,  0.123,  0.775,  0.669, 0.126,   0.776,  0.665,  0.125,  0.783,  0.690,  0.131,
            0.698,  0.615,  0.434,  0.720,  0.595, 0.449,   0.769,  0.662,  0.492,  0.770,  0.665,  0.493,
            0.706,  0.592,  0.324,  0.725,  0.615, 0.336,   0.806,  0.699,  0.400,  0.809,  0.726,  0.413
            )
)
df |> head(12)

하나의 질병 당 모델 네 종류, 그리고 모델 성능 세 종류가 필요하다. 그렇기 떄문에 하나의 질병을 12번씩 반복하여 만들어주면 된다.

다음으로 원하는 순서대로 출력하기 위해, 질병의 순서와 모델 성능의 순서를 정해준다.

df$disease <- factor(df$disease, levels=c("Death","Heart disease","Stroke","Cancer","Hypertension","Diabetes"))
df$category <- factor(df$category, levels=c("AUC","Accuracy","F1"))

그리고 나서 질병 * 모델 * 성능 지표 별 점수를 시각화해준다.

 df |> 
  ggplot(aes(x=model, y=score, color=category, group=category))+
  geom_point(aes(shape=category)) + 
  geom_line(aes(lty=category)) + 
  facet_grid(~disease) +
  scale_x_discrete(labels=c("Model 1", "Model 2", "Model 3", "Model 4"))+
  scale_y_continuous(limits = c(0,1),
                     breaks = seq(0,1,0.2),
                     expand = c(0,0)) + 
  scale_shape_discrete(labels=c("AUC","Accuracy","F1-score"),
                       name = NULL)+
  scale_color_tableau(labels=c("AUC","Accuracy","F1-score"),
                       name = NULL)+
  scale_linetype_discrete(labels=c("AUC","Accuracy","F1-score"),
                          name = NULL)+
  geom_text(aes(label=format(score,3)), vjust=-1, size=1.5) +
  theme_few() +
  theme(legend.position = "top",
        legend.direction = "horizontal",
        axis.title = element_blank(),
        legend.text = element_text(size=7),
        axis.text = element_text(size=4),
        strip.text = element_text(size=6, face = "bold")) +
  ggtitle("(A) Male (n=64,389)")

질병은 facet_grid()에 포함시켜 옆으로 나열하게끔 시각화 하였다.

Back to top