跳转至

项目3: ML模型训练

难度: ⭐⭐⭐ 中级 时间: 4-5小时 涉及知识: scikit-learn, Pandas, 数据预处理, 模型评估


🎯 项目目标

完成一个端到端的机器学习项目: 1. 数据加载和探索 2. 数据预处理和特征工程 3. 训练多个模型 4. 模型评估和比较 5. 保存和加载模型


📋 需求

功能需求

Bash
# 命令行使用
python train.py data.csv --target price --model random_forest --output model.pkl

# 功能
- 自动识别数据类型(数值/分类)
- 自动处理缺失值
- 训练多个模型并比较
- 输出评估报告
- 保存最佳模型

技术要求

  • 使用pandas处理数据
  • 使用scikit-learn训练模型
  • 使用argparse处理命令行参数
  • 适当的日志记录
  • 模型持久化

🚀 实现步骤

步骤1: 环境准备

Bash
pip install pandas numpy scikit-learn matplotlib

步骤2: 数据加载和探索

Python
# ml_pipeline.py
import pandas as pd
import numpy as np
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler, LabelEncoder
from sklearn.impute import SimpleImputer
from sklearn.ensemble import RandomForestClassifier, RandomForestRegressor
from sklearn.linear_model import LogisticRegression, LinearRegression
from sklearn.metrics import accuracy_score, classification_report
from sklearn.metrics import mean_squared_error, r2_score
import pickle
import argparse
import logging

# Module-level logging setup: emit INFO and above from this pipeline.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

class MLPipeline:
    """End-to-end ML pipeline: load, preprocess, train, evaluate, persist.

    The task type (classification vs. regression) is inferred from the
    target column, and every fitted preprocessing step (imputer, scaler,
    label encoders) is kept so the exact same transforms can be re-applied
    at prediction time and round-tripped through save_model/load_model.
    """

    def __init__(self):
        self.data = None                      # full raw DataFrame as loaded
        self.X_train = None                   # scaled training features (ndarray)
        self.X_test = None                    # scaled test features (ndarray)
        self.y_train = None
        self.y_test = None
        self.scaler = StandardScaler()
        self.imputer = SimpleImputer(strategy='mean')
        self.label_encoders = {}              # column name -> fitted LabelEncoder
        self.model = None                     # set by train() or load_model()
        self.task_type = None                 # 'classification' or 'regression'

    def load_data(self, filepath, target_column):
        """Load a CSV/XLSX file and split it into features X and target y.

        Task type is inferred from the target: object dtype or fewer than
        10 unique values -> classification, otherwise regression.

        Args:
            filepath: Path to a .csv or .xlsx data file.
            target_column: Name of the column to predict.

        Returns:
            Tuple (X, y): feature DataFrame and target Series.

        Raises:
            ValueError: If the file extension is neither .csv nor .xlsx.
        """
        logger.info(f"Loading data from {filepath}")

        if filepath.endswith('.csv'):
            self.data = pd.read_csv(filepath)
        elif filepath.endswith('.xlsx'):
            self.data = pd.read_excel(filepath)
        else:
            raise ValueError("Unsupported file format")

        logger.info(f"Data shape: {self.data.shape}")
        logger.info(f"Columns: {self.data.columns.tolist()}")

        # Separate features from target.
        X = self.data.drop(columns=[target_column])
        y = self.data[target_column]

        # Heuristic task detection: string targets or low-cardinality
        # numeric targets are treated as class labels.
        if y.dtype == 'object' or y.nunique() < 10:
            self.task_type = 'classification'
        else:
            self.task_type = 'regression'

        logger.info(f"Task type: {self.task_type}")

        return X, y

    def preprocess(self, X, y):
        """Encode categoricals, split, impute, and scale the data.

        Fits the imputer and scaler on the TRAINING split only, so no
        test-set statistics leak into preprocessing. Results are stored on
        self (X_train/X_test/y_train/y_test).
        """
        logger.info("Preprocessing data...")

        # Work on a copy so the caller's DataFrame is not mutated in place.
        X = X.copy()

        # Label-encode categorical (object-dtype) features. Encoders are
        # fitted on the full column so categories that happen to land only
        # in the test split still get a code; the fitted encoders are kept
        # for predict().
        categorical_cols = X.select_dtypes(include=['object']).columns
        for col in categorical_cols:
            le = LabelEncoder()
            X[col] = le.fit_transform(X[col].astype(str))
            self.label_encoders[col] = le

        # Split BEFORE fitting any statistics-based preprocessing steps —
        # fitting the imputer/scaler on the full data would leak test-set
        # information into training.
        X_train, X_test, self.y_train, self.y_test = train_test_split(
            X, y, test_size=0.2, random_state=42
        )

        # Impute missing values: fit on train, apply the same statistics
        # to both splits.
        X_train = pd.DataFrame(
            self.imputer.fit_transform(X_train),
            columns=X.columns
        )
        X_test = pd.DataFrame(
            self.imputer.transform(X_test),
            columns=X.columns
        )

        # Standardize features: fit on train only.
        self.X_train = self.scaler.fit_transform(X_train)
        self.X_test = self.scaler.transform(X_test)

        logger.info(f"Training set size: {len(self.X_train)}")
        logger.info(f"Test set size: {len(self.X_test)}")

    def train(self, model_name='random_forest'):
        """Instantiate and fit the requested model on the training split.

        Args:
            model_name: One of 'random_forest', 'logistic_regression'
                (classification only) or 'linear_regression' (regression
                only).

        Raises:
            ValueError: If model_name is not valid for the current task.
        """
        logger.info(f"Training {model_name} model...")

        if self.task_type == 'classification':
            if model_name == 'random_forest':
                self.model = RandomForestClassifier(n_estimators=100, random_state=42)
            elif model_name == 'logistic_regression':
                self.model = LogisticRegression(random_state=42)
            else:
                raise ValueError(f"Unknown model: {model_name}")
        else:  # regression
            if model_name == 'random_forest':
                self.model = RandomForestRegressor(n_estimators=100, random_state=42)
            elif model_name == 'linear_regression':
                self.model = LinearRegression()
            else:
                raise ValueError(f"Unknown model: {model_name}")

        self.model.fit(self.X_train, self.y_train)
        logger.info("Training completed")

    def evaluate(self):
        """Score the fitted model on the held-out test split.

        Returns:
            dict: {'accuracy': float} for classification, or
            {'mse': float, 'r2': float} for regression.
        """
        predictions = self.model.predict(self.X_test)

        if self.task_type == 'classification':
            accuracy = accuracy_score(self.y_test, predictions)
            logger.info(f"Accuracy: {accuracy:.4f}")
            logger.info("\nClassification Report:")
            logger.info(classification_report(self.y_test, predictions))
            return {'accuracy': accuracy}
        else:
            mse = mean_squared_error(self.y_test, predictions)
            r2 = r2_score(self.y_test, predictions)
            logger.info(f"MSE: {mse:.4f}")
            logger.info(f"R2 Score: {r2:.4f}")
            return {'mse': mse, 'r2': r2}

    def save_model(self, filepath):
        """Pickle the model together with every fitted preprocessing step.

        Bundling the scaler/imputer/encoders guarantees predict() applies
        the exact transforms the model was trained with.
        """
        model_data = {
            'model': self.model,
            'scaler': self.scaler,
            'imputer': self.imputer,
            'label_encoders': self.label_encoders,
            'task_type': self.task_type
        }
        with open(filepath, 'wb') as f:
            pickle.dump(model_data, f)
        logger.info(f"Model saved to {filepath}")

    def load_model(self, filepath):
        """Restore a model bundle previously written by save_model().

        NOTE: pickle.load executes arbitrary code — only load files you
        trust (i.e. ones this pipeline itself produced).
        """
        with open(filepath, 'rb') as f:
            model_data = pickle.load(f)

        self.model = model_data['model']
        self.scaler = model_data['scaler']
        self.imputer = model_data['imputer']
        self.label_encoders = model_data['label_encoders']
        self.task_type = model_data['task_type']

        logger.info(f"Model loaded from {filepath}")

    def predict(self, X):
        """Apply the stored preprocessing to new data and predict.

        Args:
            X: DataFrame with the same feature columns (and column order)
               as the training data — TODO: confirm callers guarantee this;
               LabelEncoder.transform will raise on unseen categories.

        Returns:
            ndarray of model predictions.
        """
        # Copy so the caller's DataFrame is never mutated.
        X = X.copy()
        for col, le in self.label_encoders.items():
            if col in X.columns:
                X[col] = le.transform(X[col].astype(str))

        X = pd.DataFrame(self.imputer.transform(X), columns=X.columns)
        X = self.scaler.transform(X)

        return self.model.predict(X)

def main():
    """CLI entry point: load data, preprocess, train, evaluate, save."""
    parser = argparse.ArgumentParser(description='ML Training Pipeline')
    parser.add_argument('data', help='Path to data file')
    parser.add_argument('--target', required=True, help='Target column name')
    parser.add_argument(
        '--model',
        default='random_forest',
        choices=['random_forest', 'logistic_regression', 'linear_regression'],
        help='Model type',
    )
    parser.add_argument('--output', default='model.pkl', help='Output model file')
    args = parser.parse_args()

    # Run the full pipeline end to end.
    pipeline = MLPipeline()
    features, target = pipeline.load_data(args.data, args.target)
    pipeline.preprocess(features, target)
    pipeline.train(args.model)
    pipeline.evaluate()
    pipeline.save_model(args.output)

    print("\nTraining completed successfully!")

if __name__ == '__main__':
    main()

步骤3: 模型比较

Python
def compare_models(pipeline):
    """Train every model applicable to the pipeline's task and collect metrics.

    Args:
        pipeline: An MLPipeline whose data has already been preprocessed.

    Returns:
        dict mapping model name -> metrics dict from pipeline.evaluate().
    """
    candidates = {
        'classification': ['random_forest', 'logistic_regression'],
        'regression': ['random_forest', 'linear_regression'],
    }[pipeline.task_type]

    separator = '=' * 50
    results = {}
    for name in candidates:
        logger.info(f"\n{separator}")
        logger.info(f"Training {name}")
        logger.info(f"{separator}")

        pipeline.train(name)
        results[name] = pipeline.evaluate()

    # Summarize all collected metrics side by side.
    logger.info("\n" + separator)
    logger.info("Model Comparison")
    logger.info(separator)
    for name, metrics in results.items():
        logger.info(f"\n{name}:")
        for metric, value in metrics.items():
            logger.info(f"  {metric}: {value:.4f}")

    return results

📝 扩展挑战

  1. 超参数调优 - 使用GridSearchCV或RandomizedSearchCV
  2. 交叉验证 - 添加K-fold交叉验证
  3. 特征重要性 - 输出并可视化特征重要性
  4. 模型解释 - 使用SHAP或LIME解释模型
  5. Web界面 - 使用Streamlit创建交互界面

🎯 完成标准

  • 能加载和处理CSV/Excel数据
  • 自动识别分类/回归任务
  • 正确处理缺失值和分类特征
  • 能训练至少2种模型
  • 输出评估指标
  • 能保存和加载模型
  • 代码结构清晰

💡 提示

  • 先用简单的数据集测试(如鸢尾花、加州房价; 注意波士顿房价数据集已从scikit-learn 1.2中移除)
  • 注意数据泄露问题(先分割数据, 再只在训练集上拟合标准化、缺失值填充等预处理器)
  • 保存预处理参数,确保预测时使用相同处理
  • 记录实验参数和结果

📚 参考资源


🚀 下一步

完成后,尝试: - 项目4: API开发 - 参加Kaggle竞赛 - 学习深度学习项目

记住: 特征工程往往比调参更重要!