
01 - Data Engineering and Preprocessing (Comprehensive Edition)

⚠️ Freshness note: this chapter mentions cutting-edge models, prices, leaderboards, and similar information that can change quickly across versions; always defer to the original papers, official release pages, and API documentation.

Learning objectives: master data collection, cleaning, deduplication, tokenization, and efficient data-loading techniques for large-model training.


Table of Contents

  1. Data Pipeline Overview
  2. Data Collection and Sources
  3. Data Cleaning and Quality Control
  4. Deduplication Techniques
  5. Tokenization Engineering
  6. Efficient Data Loading
  7. Data Mixing Strategies

Data Pipeline Overview

1.1 The Complete Data Flow

Text Only
Raw data
┌──────────────────────┐
│   Data Collection    │  ← Common Crawl, GitHub, books, papers
└──────────┬───────────┘
┌──────────────────────┐
│   Text Extraction    │  ← HTML parsing, PDF extraction, format conversion
└──────────┬───────────┘
┌──────────────────────┐
│   Quality Filtering  │  ← language detection, quality scoring, spam removal
└──────────┬───────────┘
┌──────────────────────┐
│   Deduplication      │  ← MinHash, SimHash, exact dedup
└──────────┬───────────┘
┌──────────────────────┐
│   PII Removal        │  ← PII detection, toxicity filtering
└──────────┬───────────┘
┌──────────────────────┐
│   Tokenization       │  ← BPE training, text encoding
└──────────┬───────────┘
┌──────────────────────┐
│   Mixing & Packing   │  ← multi-source mixing, sequence packing
└──────────┬───────────┘
    Training data
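
The stages above can be chained per document. The skeleton below is an illustrative sketch only: the `pipeline` object and its method names are hypothetical placeholders for the components implemented later in this chapter, not a prescribed API.

Python
def process_document(raw_html, pipeline):
    """Illustrative skeleton: run one raw document through the stages above.

    `pipeline` is a hypothetical container bundling the extractor, filters,
    deduplicator, PII scrubber, and tokenizer built in the following sections.
    Returning None means the document is dropped at that stage.
    """
    text = pipeline.extract_text(raw_html)      # text extraction
    if not pipeline.passes_quality(text):       # quality filtering
        return None
    if pipeline.is_duplicate(text):             # deduplication
        return None
    text = pipeline.remove_pii(text)            # PII removal
    return pipeline.tokenize(text)              # tokenization (mixing/packing happens corpus-wide)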

1.2 Training Data Scale Comparison

Model | Training data volume | Data sources / notes
GPT-3 | 300B tokens | Common Crawl, WebText, Books, Wikipedia
LLaMA-2 | 2T tokens | publicly available data, deduplicated and filtered
GPT-4 | not disclosed | estimated >10T tokens
PaLM | 780B tokens | web pages, books, code, conversations
Chinchilla | 1.4T tokens | showed that data volume matters as much as model size

Data Collection and Sources

2.1 Common Data Sources

Python
class DataSources:
    """
    Common data sources for large-model training
    """

    SOURCES = {
        # Web data
        "common_crawl": {
            "description": "Web crawl data",
            "volume": "tens of PB of raw data",
            "quality": "low to medium; needs heavy filtering",
            "processing": "WET extraction, language detection, quality filtering"
        },

        # Code data
        "github": {
            "description": "Open-source code repositories",
            "volume": "hundreds of TB",
            "quality": "high, structured",
            "processing": "license filtering, deduplication, language classification"
        },

        # Book data
        "books": {
            "description": "Books and literary works",
            "volume": "tens of TB",
            "quality": "high, long-form text",
            "processing": "copyright checks, OCR correction"
        },

        # Academic data
        "academic": {
            "description": "Papers and academic literature",
            "volume": "several TB",
            "quality": "high, specialized",
            "processing": "PDF extraction, formula handling"
        },

        # Conversation data
        "conversation": {
            "description": "Dialogue and Q&A data",
            "volume": "several TB",
            "quality": "medium to high",
            "processing": "privacy filtering, quality screening"
        },

        # Encyclopedia data
        "wikipedia": {
            "description": "Wikipedia",
            "volume": "hundreds of GB",
            "quality": "high, knowledge-dense",
            "processing": "structured extraction, citation removal"
        }
    }

# Open datasets
OPEN_DATASETS = {
    "The Pile": "825GB of diverse English text",
    "C4": "cleaned Common Crawl",
    "OSCAR": "multilingual Common Crawl",
    "ROOTS": "BigScience multilingual corpus",
    "RedPajama": "open reproduction of the LLaMA training data",
    "RefinedWeb": "high-quality web data"
}
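
As a quick way to experiment with such corpora, the snippet below (an illustrative addition; the dataset name is only an example) streams an openly available corpus through the Hugging Face `datasets` library so nothing has to be downloaded up front.

Python
from datasets import load_dataset

# Stream a public corpus instead of downloading it entirely
ds = load_dataset("wikitext", "wikitext-103-raw-v1", split="train", streaming=True)

for i, sample in enumerate(ds):
    print(sample["text"][:80])  # each record exposes a "text" field
    if i >= 2:
        break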

2.2 Common Crawl Processing Pipeline

Python
class CommonCrawlProcessor:
    """
    Common Crawl data processing
    """

    def __init__(self):
        self.language_detector = None  # fastText language-ID model (loaded lazily)
        self.quality_classifier = None  # quality classifier

    def extract_text_from_html(self, html_content):
        """
        Extract body text from HTML
        """
        from bs4 import BeautifulSoup

        soup = BeautifulSoup(html_content, 'html.parser')

        # Remove scripts and styles
        for script in soup(["script", "style"]):
            script.decompose()

        # Extract text
        text = soup.get_text()

        # Clean up whitespace
        lines = (line.strip() for line in text.splitlines())  # generator chain: strip each line
        chunks = (phrase.strip() for line in lines for phrase in line.split("  "))
        text = '\n'.join(chunk for chunk in chunks if chunk)

        return text

    def detect_language(self, text):
        """
        Detect the language of a text
        """
        import fasttext

        # Load the pretrained language-ID model once and cache it
        if self.language_detector is None:
            self.language_detector = fasttext.load_model('lid.176.bin')

        predictions = self.language_detector.predict(text.replace('\n', ' '), k=1)
        lang = predictions[0][0].replace('__label__', '')
        score = predictions[1][0]

        return lang, score

    def quality_filter(self, text):
        """
        Quality filtering

        Filter criteria:
        1. Text length (too short or too long)
        2. Symbol ratio (too many symbols usually means spam)
        3. Ratio of repeated lines
        4. Word ratio (catches meaningless character soup)
        """
        import re

        # Length check
        if len(text) < 100 or len(text) > 100000:
            return False, "Length filter"

        # Symbol ratio
        symbol_ratio = len(re.findall(r'[^\w\s]', text)) / len(text)  # re.findall returns every match
        if symbol_ratio > 0.5:
            return False, "Symbol ratio filter"

        # Repeated-line check
        lines = text.split('\n')
        unique_lines = set(lines)
        if len(unique_lines) / len(lines) < 0.3:
            return False, "Repetition filter"

        # Word ratio (filters meaningless character combinations)
        words = re.findall(r'\b\w+\b', text)
        if len(words) / len(text.split()) < 0.5:
            return False, "Word ratio filter"

        return True, "Passed"

Data Cleaning and Quality Control

3.1 Quality Assessment Metrics

Python
class QualityMetrics:
    """
    Text quality assessment metrics
    """

    @staticmethod  # @staticmethod: callable without an instance
    def perplexity_score(text, language_model):
        """
        Compute perplexity with a language model

        Low perplexity  = looks like natural language
        High perplexity = likely junk text
        """
        import torch

        with torch.no_grad():  # disable gradients to save memory at inference time
            # assumes the model object exposes its tokenizer as language_model.tokenizer
            inputs = language_model.tokenizer(text, return_tensors='pt')
            outputs = language_model(**inputs, labels=inputs['input_ids'])
            perplexity = torch.exp(outputs.loss)

        return perplexity.item()

    @staticmethod
    def readability_score(text):
        """
        Readability score

        Uses standard measures such as Flesch Reading Ease
        """
        import textstat

        flesch = textstat.flesch_reading_ease(text)
        # 0-100; higher means easier to read

        return flesch

    @staticmethod
    def coherence_score(text):
        """
        Coherence score

        Checks topical consistency
        """
        # Simple implementation: lexical overlap between adjacent sentences
        sentences = text.split('.')
        if len(sentences) < 2:
            return 0

        overlap_scores = []
        for i in range(len(sentences) - 1):
            words1 = set(sentences[i].lower().split())
            words2 = set(sentences[i + 1].lower().split())

            if words1 and words2:
                overlap = len(words1 & words2) / len(words1 | words2)
                overlap_scores.append(overlap)

        return sum(overlap_scores) / len(overlap_scores) if overlap_scores else 0

class QualityClassifier:
    """
    Machine-learning-based quality classifier
    """

    def __init__(self):
        self.features = [
            'char_count',
            'word_count',
            'sentence_count',
            'avg_word_length',
            'symbol_ratio',
            'uppercase_ratio',
            'stopword_ratio',
            'unique_word_ratio'
        ]

    def extract_features(self, text):
        """
        Extract text features
        """
        import re

        features = {}

        # Basic statistics
        features['char_count'] = len(text)
        features['word_count'] = len(text.split())
        features['sentence_count'] = len(re.findall(r'[.!?]+', text))

        # Average word length
        words = text.split()
        features['avg_word_length'] = sum(len(w) for w in words) / len(words) if words else 0

        # Symbol ratio
        features['symbol_ratio'] = len(re.findall(r'[^\w\s]', text)) / len(text)

        # Uppercase ratio
        features['uppercase_ratio'] = sum(1 for c in text if c.isupper()) / len(text)

        # Stopword ratio
        stopwords = set(['the', 'a', 'an', 'in', 'on', 'at', 'to', 'for', 'of'])
        word_set = set(w.lower() for w in words)
        features['stopword_ratio'] = len(word_set & stopwords) / len(word_set) if word_set else 0

        # Lexical diversity
        features['unique_word_ratio'] = len(word_set) / len(words) if words else 0

        return features
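
The class above only extracts features; how the classifier itself is fitted is left open. Below is a minimal sketch, assuming labeled "good" and "bad" texts are available (for example, Wikipedia pages versus unfiltered crawl pages) and using scikit-learn's logistic regression.

Python
from sklearn.linear_model import LogisticRegression

def train_quality_model(good_texts, bad_texts):
    """Fit a simple logistic-regression quality model on hand-crafted features.

    good_texts / bad_texts are assumed to be lists of strings labeled 1 / 0.
    """
    extractor = QualityClassifier()
    X, y = [], []
    for label, texts in ((1, good_texts), (0, bad_texts)):
        for text in texts:
            feats = extractor.extract_features(text)
            X.append([feats[name] for name in extractor.features])
            y.append(label)

    model = LogisticRegression(max_iter=1000)
    model.fit(X, y)
    return model  # model.predict_proba(...)[:, 1] gives a quality score in [0, 1]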

3.2 PII (Personally Identifiable Information) Detection and Removal

Python
class PIIDetector:
    """
    PII detection and redaction
    """

    PII_PATTERNS = {
        'email': r'\b[A-Za-z0-9._%+-]+@[A-Za-z0-9.-]+\.[A-Za-z]{2,}\b',
        'phone': r'\b\d{3}[-.]?\d{3}[-.]?\d{4}\b',
        'ssn': r'\b\d{3}-\d{2}-\d{4}\b',
        'credit_card': r'\b(?:\d[ -]*?){13,16}\b',
        'ip_address': r'\b\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}\b'
    }

    def __init__(self):
        import spacy
        self.nlp = spacy.load('en_core_web_sm')

    def detect_pii(self, text):
        """
        Detect PII in a text
        """
        import re

        pii_found = []

        # Regex-based detection
        for pii_type, pattern in self.PII_PATTERNS.items():
            matches = re.finditer(pattern, text)
            for match in matches:
                pii_found.append({
                    'type': pii_type,
                    'value': match.group(),
                    'start': match.start(),
                    'end': match.end()
                })

        # NER-based detection (names, organizations, locations)
        doc = self.nlp(text)
        for ent in doc.ents:
            if ent.label_ in ['PERSON', 'ORG', 'GPE']:
                pii_found.append({
                    'type': ent.label_.lower(),
                    'value': ent.text,
                    'start': ent.start_char,
                    'end': ent.end_char
                })

        return pii_found

    def anonymize(self, text, pii_list=None):
        """
        Redact detected PII
        """
        if pii_list is None:
            pii_list = self.detect_pii(text)

        # Sort by position (replace back to front to avoid shifting offsets)
        pii_list.sort(key=lambda x: x['start'], reverse=True)  # lambda: anonymous function

        anonymized = text
        for pii in pii_list:
            replacement = f"[{pii['type'].upper()}]"
            anonymized = (
                anonymized[:pii['start']] +
                replacement +
                anonymized[pii['end']:]
            )

        return anonymized
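
A short usage example; the output shown is illustrative, since it depends on what the spaCy model actually tags as entities.

Python
detector = PIIDetector()  # requires: python -m spacy download en_core_web_sm

sample = "Contact John Smith at john.smith@example.com or 555-123-4567."
print(detector.anonymize(sample))
# Something like: "Contact [PERSON] at [EMAIL] or [PHONE]."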

Deduplication Techniques

4.1 Exact Deduplication

Python
class ExactDeduplication:
    """
    Exact deduplication
    """

    def __init__(self):
        self.seen_hashes = set()

    def compute_hash(self, text):
        """
        Hash a text
        """
        import hashlib
        return hashlib.md5(text.encode()).hexdigest()

    def is_duplicate(self, text):
        """
        Check whether the text is an exact duplicate
        """
        text_hash = self.compute_hash(text)
        if text_hash in self.seen_hashes:
            return True
        self.seen_hashes.add(text_hash)
        return False

class NearDeduplication:
    """
    Near deduplication (MinHash + LSH)
    """

    def __init__(self, num_hashes=128, num_bands=16):
        from collections import defaultdict

        self.num_hashes = num_hashes
        self.num_bands = num_bands
        self.rows_per_band = num_hashes // num_bands
        self.buckets = [defaultdict(list) for _ in range(num_bands)]  # defaultdict supplies a default value for missing keys

    def get_shingles(self, text, k=5):
        """
        Build k-gram shingles
        """
        words = text.split()
        shingles = set()
        for i in range(len(words) - k + 1):
            shingle = ' '.join(words[i:i+k])
            shingles.add(shingle)
        return shingles

    def compute_minhash(self, shingles):
        """
        Compute the MinHash signature
        """
        import hashlib

        signature = []
        for i in range(self.num_hashes):
            min_hash = float('inf')
            for shingle in shingles:
                # Simulate distinct hash functions by salting with the index
                hash_val = int(hashlib.md5(f"{shingle}:{i}".encode()).hexdigest(), 16)
                min_hash = min(min_hash, hash_val)
            signature.append(min_hash)

        return signature

    def is_near_duplicate(self, text, threshold=0.8):
        """
        Check whether the text is a near duplicate

        Uses LSH (locality-sensitive hashing) to narrow the candidate set
        """
        shingles = self.get_shingles(text)
        signature = self.compute_minhash(shingles)

        # LSH: bucket the signature band by band
        candidate_pairs = set()
        for band_idx in range(self.num_bands):
            start = band_idx * self.rows_per_band
            end = start + self.rows_per_band
            band_signature = tuple(signature[start:end])

            bucket = self.buckets[band_idx]
            if band_signature in bucket:
                candidate_pairs.update(bucket[band_signature])
            bucket[band_signature].append(text)

        # Verify candidates with exact Jaccard similarity
        for candidate in candidate_pairs:
            candidate_shingles = self.get_shingles(candidate)
            similarity = len(shingles & candidate_shingles) / len(shingles | candidate_shingles)
            if similarity > threshold:
                return True

        return False
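
In practice the hand-rolled MinHash above is slow; the datasketch library (listed under recommended tools at the end of this chapter) ships optimized MinHash and LSH index implementations. A minimal sketch of the same idea with datasketch:

Python
from datasketch import MinHash, MinHashLSH

def build_lsh_index(docs, threshold=0.8, num_perm=128, k=5):
    """docs: {doc_id: text}. Returns the LSH index plus per-document MinHashes."""
    lsh = MinHashLSH(threshold=threshold, num_perm=num_perm)
    minhashes = {}

    for doc_id, text in docs.items():
        words = text.split()
        # k-word shingles (short documents fall back to a single shingle)
        shingles = {' '.join(words[i:i + k]) for i in range(max(len(words) - k + 1, 1))}

        m = MinHash(num_perm=num_perm)
        for shingle in shingles:
            m.update(shingle.encode('utf-8'))

        lsh.insert(doc_id, m)
        minhashes[doc_id] = m

    return lsh, minhashes

# lsh.query(minhashes[some_id]) returns the ids of likely near-duplicates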

4.2 Document-Level Deduplication

Python
class DocumentDeduplication:
    """
    Document-level deduplication (SimHash)
    """

    def __init__(self, hashbits=64):
        self.hashbits = hashbits
        self.seen_hashes = []

    def compute_simhash(self, text):
        """
        Compute the SimHash fingerprint
        """
        import hashlib
        from collections import defaultdict

        # Tokenize and count term frequencies
        words = text.split()
        word_freq = defaultdict(int)
        for word in words:
            word_freq[word] += 1

        # Initialize the accumulator vector
        vector = [0] * self.hashbits

        # Weighted accumulation
        for word, freq in word_freq.items():
            # Hash the word
            hash_val = int(hashlib.md5(word.encode()).hexdigest(), 16)

            for i in range(self.hashbits):
                bit = (hash_val >> i) & 1
                if bit:
                    vector[i] += freq
                else:
                    vector[i] -= freq

        # Build the fingerprint
        fingerprint = 0
        for i in range(self.hashbits):
            if vector[i] > 0:
                fingerprint |= (1 << i)

        return fingerprint

    def hamming_distance(self, hash1, hash2):
        """
        Compute the Hamming distance
        """
        xor = hash1 ^ hash2
        return bin(xor).count('1')

    def is_duplicate(self, text, threshold=3):
        """
        Check whether the text is a duplicate

        threshold: Hamming-distance threshold; smaller is stricter
        """
        fingerprint = self.compute_simhash(text)

        for seen_fp in self.seen_hashes:
            if self.hamming_distance(fingerprint, seen_fp) <= threshold:
                return True

        self.seen_hashes.append(fingerprint)
        return False

Tokenization Engineering

5.1 Training a BPE Tokenizer

Python
class BPETokenizerTrainer:
    """
    BPE (Byte-Pair Encoding) training
    """

    def __init__(self, vocab_size=32000):
        self.vocab_size = vocab_size
        self.vocab = {}
        self.merges = []

    def train(self, texts):
        """
        Train the BPE tokenizer

        Args:
            texts: list of training texts
        """
        from collections import defaultdict

        # Initialize word frequencies (character level)
        word_freqs = defaultdict(int)
        for text in texts:
            words = text.split()
            for word in words:
                # Split each word into a character sequence
                chars = ' '.join(list(word)) + ' </w>'
                word_freqs[tuple(chars.split())] += 1

        # Initialize the vocabulary
        self.vocab = set()
        for word in word_freqs:
            self.vocab.update(word)

        # BPE merges
        num_merges = self.vocab_size - len(self.vocab)

        for i in range(num_merges):
            # Count adjacent-pair frequencies
            pairs = defaultdict(int)
            for word, freq in word_freqs.items():
                for j in range(len(word) - 1):
                    pairs[(word[j], word[j+1])] += freq

            if not pairs:
                break

            # Pick the most frequent pair
            best_pair = max(pairs, key=pairs.get)
            self.merges.append(best_pair)

            # Apply the merge
            new_vocab = set()
            new_word_freqs = defaultdict(int)

            for word, freq in word_freqs.items():
                new_word = []
                idx = 0
                while idx < len(word):
                    if idx < len(word) - 1 and (word[idx], word[idx+1]) == best_pair:
                        new_word.append(word[idx] + word[idx+1])
                        idx += 2
                    else:
                        new_word.append(word[idx])
                        idx += 1

                new_word_freqs[tuple(new_word)] += freq
                new_vocab.update(new_word)

            word_freqs = new_word_freqs
            self.vocab = new_vocab

            if (i + 1) % 1000 == 0:
                print(f"Merge {i+1}/{num_merges}: {best_pair} -> {best_pair[0] + best_pair[1]}")

        # Add special tokens
        special_tokens = ['<pad>', '<unk>', '<s>', '</s>', '<mask>']
        self.vocab.update(special_tokens)

        return self

    def encode(self, text):
        """
        Encode text into tokens
        """
        words = text.split()
        tokens = []

        for word in words:
            word_tokens = list(word) + ['</w>']

            # Apply the learned merge rules in order
            for merge in self.merges:
                new_tokens = []
                i = 0
                while i < len(word_tokens):
                    if i < len(word_tokens) - 1 and \
                       (word_tokens[i], word_tokens[i+1]) == merge:
                        new_tokens.append(merge[0] + merge[1])
                        i += 2
                    else:
                        new_tokens.append(word_tokens[i])
                        i += 1
                word_tokens = new_tokens

            tokens.extend(word_tokens)

        return tokens

    def decode(self, tokens):
        """
        Decode tokens back into text
        """
        text = ''.join(tokens)
        text = text.replace('</w>', ' ')
        return text.strip()
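
A toy run of the trainer above (the corpus and vocab_size are deliberately tiny; real tokenizers are trained on far larger corpora with vocabularies of 32k tokens or more):

Python
corpus = [
    "low lower lowest",
    "new newer newest",
    "wide wider widest",
]

bpe = BPETokenizerTrainer(vocab_size=50).train(corpus)
tokens = bpe.encode("lowest newest")
print(tokens)               # subword tokens, each word ending in </w>
print(bpe.decode(tokens))   # "lowest newest"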

5.2 Using Hugging Face Tokenizers

Python
from tokenizers import Tokenizer, models, pre_tokenizers, decoders, trainers

class ModernTokenizerTraining:
    """
    Training a tokenizer with modern tooling
    """

    @staticmethod
    def train_bpe(corpus_files, vocab_size=32000, output_path="tokenizer.json"):
        """
        Train a BPE tokenizer
        """
        # Initialize
        tokenizer = Tokenizer(models.BPE())

        # Pre-tokenizer
        tokenizer.pre_tokenizer = pre_tokenizers.ByteLevel(add_prefix_space=False)

        # Trainer
        trainer = trainers.BpeTrainer(
            vocab_size=vocab_size,
            special_tokens=["<pad>", "<unk>", "<s>", "</s>", "<mask>"],
            min_frequency=2
        )

        # Train
        tokenizer.train(files=corpus_files, trainer=trainer)

        # Decoder
        tokenizer.decoder = decoders.ByteLevel()

        # Save
        tokenizer.save(output_path)

        return tokenizer

    @staticmethod
    def train_sentencepiece(corpus_files, vocab_size=32000, model_prefix="spm"):
        """
        Train a SentencePiece tokenizer (BPE or Unigram)

        Advantages:
        - Works directly on raw text; no pre-tokenization required
        - Better multilingual support
        - Used by LLaMA, T5, and others
        """
        import sentencepiece as spm

        # Concatenate all input files
        import tempfile
        with tempfile.NamedTemporaryFile(mode='w', delete=False, suffix='.txt') as tmp:
            for file in corpus_files:
                with open(file, 'r', encoding='utf-8') as f:  # `with` closes the file automatically
                    tmp.write(f.read())
            combined_file = tmp.name

        # Train
        spm.SentencePieceTrainer.train(
            input=combined_file,
            model_prefix=model_prefix,
            vocab_size=vocab_size,
            model_type='bpe',  # or 'unigram'
            character_coverage=0.9995,
            num_threads=8,
            split_digits=True,
            allow_whitespace_only_pieces=True,
            byte_fallback=True,
            unk_piece='<unk>',
            bos_piece='<s>',
            eos_piece='</s>',
            pad_piece='<pad>'
        )

        return f"{model_prefix}.model"

Efficient Data Loading

6.1 Memory Mapping and Streaming

Python
import json
import mmap

class MemoryMappedDataset:
    """
    Memory-mapped dataset

    Good fit: datasets larger than available RAM
    """

    def __init__(self, data_path, index_path):
        # Open the data file and map it into memory
        self.file = open(data_path, 'rb')
        self.mm = mmap.mmap(self.file.fileno(), 0, access=mmap.ACCESS_READ)

        # Load the offset index
        with open(index_path, 'r') as f:
            self.index = json.load(f)

    def __getitem__(self, idx):  # dunder method: enables dataset[idx] subscript access
        # Look up the byte range in the index
        start, end = self.index[idx]

        # Read from the memory map
        self.mm.seek(start)
        data = self.mm.read(end - start)

        return json.loads(data.decode('utf-8'))  # json.loads: JSON string -> Python object

    def __len__(self):
        return len(self.index)

class StreamingDataset:
    """
    Streaming dataset

    Good fit: very large datasets loaded on demand
    """

    def __init__(self, data_paths):
        self.data_paths = data_paths

    def __iter__(self):
        for path in self.data_paths:
            with open(path, 'r', encoding='utf-8') as f:
                for line in f:
                    yield json.loads(line)  # yield turns __iter__ into a generator: lines are loaded lazily, never all at once
6.2 DataLoader Optimization

Python
import torch
from torch.utils.data import IterableDataset, DataLoader

class EfficientDataLoader:
    """
    Optimized DataLoader configurations
    """

    @staticmethod
    def create_dataloader(dataset, batch_size, num_workers=4):
        """
        Create an optimized DataLoader
        """
        return DataLoader(
            dataset,
            batch_size=batch_size,
            num_workers=num_workers,
            pin_memory=True,  # speeds up host-to-GPU transfers
            prefetch_factor=2,  # batches prefetched per worker
            persistent_workers=True,  # keep worker processes alive between epochs
            drop_last=True  # drop the final incomplete batch
        )

    @staticmethod
    def create_streaming_dataloader(dataset, batch_size):
        """
        Streaming DataLoader (for IterableDataset)
        """
        return DataLoader(
            dataset,
            batch_size=batch_size,
            num_workers=0,  # IterableDataset usually runs without extra workers
            pin_memory=True
        )

class PackedDataset(IterableDataset):
    """
    Sequence-packing dataset

    Packs several short sequences into one long sequence to improve training efficiency
    """

    def __init__(self, dataset, max_length=2048, tokenizer=None):
        self.dataset = dataset
        self.max_length = max_length
        self.tokenizer = tokenizer

    def _pack(self, buffer):
        # Append EOS, pad up to max_length, and truncate anything longer
        packed = buffer + [self.tokenizer.eos_token_id]
        packed = packed + [self.tokenizer.pad_token_id] * (self.max_length - len(packed))
        return {'input_ids': torch.tensor(packed[:self.max_length])}

    def __iter__(self):
        buffer = []
        buffer_length = 0

        for sample in self.dataset:
            tokens = self.tokenizer.encode(sample['text'])

            # If adding this sample would overflow, yield the current buffer first
            if buffer_length + len(tokens) > self.max_length:
                if buffer:
                    yield self._pack(buffer)  # yield makes __iter__ a generator

                buffer = tokens
                buffer_length = len(tokens)
            else:
                buffer.extend(tokens)
                buffer_length += len(tokens)

        # Flush whatever is left in the buffer at the end of the stream
        if buffer:
            yield self._pack(buffer)
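
Wiring the pieces together (the shard paths are hypothetical; the GPT-2 tokenizer is used only because it ships with Hugging Face Transformers and exposes encode, eos_token_id, and pad_token_id):

Python
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
tokenizer.pad_token = tokenizer.eos_token  # GPT-2 has no pad token by default

stream = StreamingDataset(["shard-000.jsonl", "shard-001.jsonl"])  # hypothetical shards with a "text" field
packed = PackedDataset(stream, max_length=2048, tokenizer=tokenizer)
loader = EfficientDataLoader.create_streaming_dataloader(packed, batch_size=8)

for batch in loader:
    print(batch['input_ids'].shape)  # (8, 2048)
    break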

Data Mixing Strategies

7.1 Mixing Multiple Data Sources

Python
import random

class DataMixer:
    """
    Multi-source data-mixing strategy
    """

    def __init__(self, datasets, weights, seed=42):
        """
        Args:
            datasets: dict of data sources {name: dataset}
            weights: sampling weights {name: weight}
        """
        self.datasets = datasets
        self.weights = weights
        self.seed = seed
        self.rng = random.Random(seed)
        # Keep one persistent iterator per source so sampling advances through each dataset
        self.iterators = {name: iter(ds) for name, ds in datasets.items()}

    def sample(self):
        """
        Sample a data source according to its weight
        """
        names = list(self.weights.keys())
        weights = [self.weights[name] for name in names]

        # Choose a source by weight
        chosen = self.rng.choices(names, weights=weights, k=1)[0]

        # Draw the next sample from the chosen source
        sample = next(self.iterators[chosen])

        return sample, chosen

    def create_mixed_iterator(self, total_samples):
        """
        Create a mixed-sampling iterator
        """
        for _ in range(total_samples):
            sample, source = self.sample()
            sample['source'] = source
            yield sample

# Example data-mixing configurations
DATA_MIXING_CONFIGS = {
    "general_purpose": {
        "web": 0.60,
        "code": 0.15,
        "books": 0.10,
        "academic": 0.10,
        "conversation": 0.05
    },
    "code_focused": {
        "code": 0.50,
        "web": 0.30,
        "books": 0.10,
        "academic": 0.05,
        "conversation": 0.05
    },
    "academic_focused": {
        "academic": 0.40,
        "books": 0.30,
        "web": 0.20,
        "code": 0.05,
        "conversation": 0.05
    }
}
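
An illustrative wiring of DataMixer with the "general_purpose" recipe (source names and shard paths are hypothetical; each source is assumed to yield dicts with a "text" field):

Python
sources = {
    name: StreamingDataset([f"{name}.jsonl"])  # hypothetical one-shard-per-source layout
    for name in DATA_MIXING_CONFIGS["general_purpose"]
}

mixer = DataMixer(sources, DATA_MIXING_CONFIGS["general_purpose"], seed=42)

for sample in mixer.create_mixed_iterator(total_samples=1000):
    pass  # hand each sample to tokenization / packing; sample['source'] records its origin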

7.2 Dynamic Data Curriculum

Python
class CurriculumLearning:
    """
    Curriculum learning: increase data difficulty progressively
    """

    def __init__(self, datasets_by_difficulty):
        """
        Args:
            datasets_by_difficulty: data sources ordered by difficulty
                [(difficulty_level, dataset), ...]
        """
        self.datasets = datasets_by_difficulty
        self.current_stage = 0

    def get_current_data(self, training_progress):
        """
        Select data according to training progress

        Args:
            training_progress: 0.0 ~ 1.0
        """
        # Map progress to a curriculum stage
        num_stages = len(self.datasets)
        target_stage = int(training_progress * num_stages)
        target_stage = min(target_stage, num_stages - 1)

        # Blend the current stage with the next one
        if target_stage < num_stages - 1:
            # Gradual transition
            transition_ratio = training_progress * num_stages - target_stage

            current_data = self.datasets[target_stage][1]
            next_data = self.datasets[target_stage + 1][1]

            # Mix according to transition_ratio
            # (MixedDataset is assumed to be a weighted-mixture wrapper defined elsewhere)
            return MixedDataset([current_data, next_data], [1 - transition_ratio, transition_ratio])
        else:
            return self.datasets[target_stage][1]

class UpsamplingStrategy:
    """
    Upsampling strategy: sample high-quality data more often
    """

    @staticmethod
    def quality_based_upsampling(dataset, quality_scores, target_ratio=2.0):
        """
        Quality-score-based upsampling

        Args:
            dataset: the original dataset
            quality_scores: per-sample quality scores
            target_ratio: sampling multiplier for the highest-quality data
        """
        # Compute sampling weights
        weights = [1.0 + (target_ratio - 1.0) * score for score in quality_scores]

        # Create a weighted sampler
        from torch.utils.data import WeightedRandomSampler
        sampler = WeightedRandomSampler(weights, len(dataset), replacement=True)

        return sampler
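
Plugging the sampler into a DataLoader; dataset and quality_scores are assumed to exist already, for example scores in [0, 1] produced by the quality classifier from Section 3.

Python
sampler = UpsamplingStrategy.quality_based_upsampling(dataset, quality_scores, target_ratio=2.0)

loader = DataLoader(
    dataset,
    batch_size=32,
    sampler=sampler,      # sampler and shuffle are mutually exclusive
    num_workers=4,
    pin_memory=True
)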

Summary

Key Points in Data Engineering

Text Only
1. Data quality > data quantity
   - High-quality data can compensate for limited volume
   - Junk data actively hurts model performance

2. Deduplication is critical
   - Duplicated data leads to overfitting
   - Near-duplicate removal matters even more than exact dedup

3. Tokenization shapes model capability
   - Vocabulary size affects model efficiency
   - Multilingual corpora need special handling

4. Data mixing needs a strategy
   - Different objectives call for different data ratios
   - Curriculum learning can improve training efficiency

5. Efficient loading keeps training fed
   - The input pipeline must not become the training bottleneck
   - Memory mapping and streaming handle datasets larger than RAM

Recommended Tools

Tool | Purpose
Hugging Face Datasets | dataset loading and processing
Tokenizers | fast tokenizer training
SentencePiece | multilingual tokenization
Datasketch | MinHash deduplication
Spark/Dask | large-scale data processing

Next up: 02 - Training Infrastructure, covering distributed training, mixed precision, and training optimization techniques.


Last updated: 2026-02-12 · Applies to: LLM Learning Tutorial v2026