사이킷런의 보스턴 주택 가격 데이터셋을 사용하여 다항 회귀 분석을 한다. 참고로 해당 데이터셋은 윤리적인 문제로 scikit-learn 1.2 버전 이후로 삭제되었다.


1. 라이브러리 및 데이터 로드

# 필요한 라이브러리
from sklearn.datasets import load_boston
from sklearn.preprocessing import PolynomialFeatures
from sklearn.model_selection import train_test_split
from sklearn.linear_model import LinearRegression
from sklearn.metrics import mean_squared_error
import pandas as pd  

# 보스턴 데이터셋을 boston_dataset 변수에 저장
boston = load_boston()


2. 변수 생성 및 학습, 테스트 데이터 나누기

가설함수를 2차 함수라고 가정하고, 그에 맞는 변수들을 생성해야한다. 과정은 다음과 같다.

# 다항 회귀, 가설함수를 2차 함수로 설정
polynomial_transformer = PolynomialFeatures(2)
# 2차 함수에 맞게 변수 생성 
polynomial_data = polynomial_transformer.fit_transform(boston.data)
# 생성된 변수에도 변수명을 부여
polynomial_feature_names = polynomial_transformer.get_feature_names(boston.feature_names)

# x, y 데이터프레임 형태로 변환
x = pd.DataFrame(polynomial_data, columns=polynomial_feature_names)
y = pd.DataFrame(boston.target, columns=[['MEDV']])
x(입력변수) 구조

x shape : (506, 105), y shape : (506, 1)


# 학습, 테스트용 데이터셋 나누기
x_train, x_test, y_train, y_test = train_test_split(x, y, test_size=0.2, random_state=42)

학습 데이터셋은 80%, 테스트 데이터셋은 20%로 데이터를 나누고 random_state를 42로 설정하여 데이터 분배 결과를 고정할 수 있도록 한다.

# 잘 나누어졌는지 shape로 확인
print(x_train.shape, x_test.shape, y_train.shape, y_test.shape)
출력 결과
(404, 105) (102, 105) (404, 1) (102, 1)


3. 모델링, 평가

# 선형 회귀 함수 사용
model = LinearRegression()
# 학습데이터로 모델 학습
model.fit(x_train, y_train)
# 학습이 완료된 모델에 테스트 데이터를 입력하여 y값을 예측
y_predict = model.predict(x_test)

# RMSE 값 구하기
mean_squared_error(y_test, y_test_prediction) ** 0.5
출력 결과
3.7661065104015696

보스턴 주택 가격 데이터셋을 학습, 테스트 8:2로 데이터를 나누고 다중 회귀(가설함수 2차 함수로 가정)로 분석한 결과 약 3.8 정도의 오차를 가지는 모델을 생성하였다. 이전에 범죄율만 가지고 주택 가격을 예측했을 때(RMSE : 약 7.8)와 입력 변수를 전체 다 사용했으나 선형으로 예측한 다중선형회귀(RMSE : 약 4.9) 보다 성능이 좋아졌음을 알 수 있다.


4. 최종 모델

# theta 0의 값
model.intercept_
# theta 1, 2, 3 ... 의 값
model.coef_

theta 0의 값 :

array([4.92670721e+08])

theta 1, 2, 3, … , 105의 값 :

array([[-4.92670936e+08, -7.22931535e+00,  6.64177060e-01,
        -4.72179590e+00,  3.67378882e+01,  2.85261009e+02,
         1.78393390e+01,  6.64572881e-01, -4.11341040e+00,
         2.66263484e+00, -8.43232394e-03,  7.64945359e+00,
         1.32201393e-01, -2.02052871e-01,  3.52764521e-04,
         7.38855141e-02,  5.84418414e-01,  2.58757422e+00,
        -2.02837847e+00,  1.75603439e-01, -5.40745692e-03,
        -1.89117490e-01,  2.71314302e-01, -3.52477416e-02,
         6.94397406e-01, -4.79466963e-04,  3.38903866e-02,
        -5.10155901e-04, -4.63119868e-03, -6.28083514e-02,
        -1.37674876e+00, -3.78123454e-03,  1.24102968e-03,
        -2.13795684e-02, -1.92521604e-02,  8.57218385e-04,
        -4.51568792e-03,  1.80318307e-04, -9.46247722e-03,
         5.45338803e-02, -1.02797894e-02, -2.15994391e-01,
         3.06860323e-01,  4.88014045e-03,  1.65323538e-01,
        -1.00727368e-03, -5.39816702e-04, -2.70282025e-02,
         4.05514997e-03, -1.83414642e-02,  3.67378833e+01,
        -3.67209648e+01, -5.26529847e+00, -3.03016356e-02,
        -9.23229861e-01,  1.28146102e-01, -9.10391216e-03,
        -8.38122909e-01,  1.11381162e-02, -2.41896558e-01,
        -9.84147266e+01, -8.70543217e+00, -4.60198145e-02,
         1.48030219e+01, -2.07999870e+00,  2.84684335e-01,
        -1.44996628e+01,  1.70896337e-03,  1.01429845e+00,
         1.13033641e+00, -6.77317315e-02, -3.26822751e-01,
        -1.83631659e-01, -1.24709743e-02, -5.64945789e-01,
        -6.57185143e-03, -3.68862299e-02, -1.49994397e-05,
         3.74969342e-04,  1.42119169e-02, -4.84839890e-04,
         4.45775932e-03, -6.50274375e-04, -8.34292370e-03,
         5.03210046e-01,  7.43943729e-02, -6.13610785e-03,
        -2.56856051e-01, -4.21354965e-03,  3.94347786e-02,
        -8.29077422e-02,  4.97195607e-03, -1.69892466e-01,
         4.15677692e-03, -2.95959896e-02, -4.70479767e-05,
         9.39935703e-03, -4.29702479e-04, -1.13419291e-03,
        -2.59300486e-02,  5.94455365e-03,  2.60203089e-02,
        -4.28552296e-05, -3.08363345e-04,  1.67517301e-02]])

이번 결과는 이전에 분석한 결과(입력값이 1개)와는 다르게 입력변수가 13개이기 때문에 시각적으로 표현하기 어렵다. 하지만 다항 회귀 모델의 세타값들은 위와 같이 구할 수 있었다.


5. Reference


👩🏻‍💻개인 공부 기록용 블로그입니다
오류나 틀린 부분이 있을 경우 댓글 혹은 메일로 따끔하게 지적해주시면 감사하겠습니다.

댓글남기기