Skip to main content

General multi-feature classification library based on pytorch

Project description

nymph

基于Pytorch的多特征分类框架

概述

基于Pytorch的多特征序列标注和普通分类框架,包装的还算可以。可以直接照搬demo,拿csv文件去训练预测。

功能

  • 多特征分类(特征包括字符型、数值型,其中字符型最好是单个词而非词组或句子)
  • 输出详细分类详情

原理

  • 预处理:对各列非数值类数据分别构建词表并使用Embedding获得低维稠密向量,对数值类数据进行标准化,然后拼接获得各行对应向量
  • 模型:
    • 普通分类:全连接神经网络,NormClassifier(具体效果看特征)
    • 序列标注:Bi-LSTM-CRF,SeqClassifier(效果较好)
  • 预测:使用sklearn获取f1分数,并且获得各类别分类详情

安装

使用如下命令进行安装

pip install -U nymph

使用示例

训练数据

数据可见test.csv

如图:

test_data

普通分类

训练模型

源码如下,具体可参见train_demo_by_norm.py

# -*- coding: utf-8 -*-
import os

import pandas as pd
from nymph.data import NormDataset, split_dataset
from nymph.modules import NormClassifier

project_path = os.path.abspath(os.path.join(__file__, '../../'))
data_path = os.path.join(project_path, r'data\test.csv')
save_path = 'demo_saves'

if __name__ == '__main__':
    # 读取数据
    data = pd.read_csv(data_path)
    # 构建分类器
    classifier = NormClassifier()
    classifier.init_data_processor(data, target_name='label')

    # 构建数据集
    norm_ds = NormDataset(data)

    train_ratio = 0.7
    dev_ratio = 0.2
    test_ratio = 0.1

    train_ds, dev_ds, test_ds = split_dataset(norm_ds, (train_ratio, dev_ratio, test_ratio))

    # 训练模型
    # classifier.train(train_set=train_ds, dev_set=dev_ds, save_path=save_path)
    classifier.train(train_set=norm_ds, dev_set=norm_ds, save_path=save_path)

    # 测试模型
    test_score = classifier.score(norm_ds)
    print('test_score', test_score)

    # 预测模型
    pred = classifier.predict(norm_ds)
    print(pred)
训练结果

终端输出

train_demo_by_norm_result

预测模型

源码如下,具体可参见predict_demo_by_norm.py

# -*- coding: utf-8 -*-
import os

import pandas as pd
from nymph.data import NormDataset, split_dataset
from nymph.modules import NormClassifier

project_path = os.path.abspath(os.path.join(__file__, '../../'))
data_path = os.path.join(project_path, r'data\test.csv')
save_path = 'demo_saves'

if __name__ == '__main__':
    # 读取数据
    data = pd.read_csv(data_path)
    # 构建分类器
    classifier = NormClassifier()

    # 加载分类器
    classifier.load(save_path)

    # 构建数据集
    norm_ds = NormDataset(data)

    # 预测模型
    pred = classifier.predict(norm_ds)
    print(pred)

    # 获取各类别分类结果,并保存信息至文件中
    classifier.report(norm_ds, 'report.csv')

    # 对数据进行预测,并将数据和预测结果写入到新的文件中
    classifier.summary(norm_ds, 'summary.csv')
预测结果

如图:predict_demo_by_norm_result

report.csv内容

report

summary.csv内容

summary

序列标注

训练模型

源码如下,具体可参见train_demo_by_seq.py

# -*- coding: utf-8 -*-
import os

import pandas as pd
from nymph.data import SeqDataset, split_dataset
from nymph.modules import SeqClassifier

project_path = os.path.abspath(os.path.join(__file__, '../../'))
data_path = os.path.join(project_path, r'data\test.csv')
save_path = 'demo_saves_seq'


def split_fn(dataset: list):
    return list(range(len(dataset)+1))


if __name__ == '__main__':
    # 读取数据
    data = pd.read_csv(data_path)
    # 构建分类器
    classifier = SeqClassifier()
    classifier.init_data_processor(data, target_name='label')

    # 构建数据集
    norm_ds = SeqDataset(data, split_fn=split_fn, min_len=4)

    train_ratio = 0.7
    dev_ratio = 0.2
    test_ratio = 0.1

    train_ds, dev_ds, test_ds = split_dataset(norm_ds, (train_ratio, dev_ratio, test_ratio))

    # 训练模型
    # classifier.train(train_set=train_ds, dev_set=dev_ds, save_path=save_path)
    classifier.train(train_set=norm_ds, dev_set=norm_ds, save_path=save_path)

    # 测试模型
    test_score = classifier.score(norm_ds)
    print('test_score', test_score)

    # 预测模型
    pred = classifier.predict(norm_ds)
    print(pred)
训练结果

终端输出

train_demo_by_seq_result

预测模型

源码如下,具体可参见predict_demo_by_seq.py

# -*- coding: utf-8 -*-
import os

import pandas as pd
from nymph.data import SeqDataset, split_dataset
from nymph.modules import SeqClassifier

project_path = os.path.abspath(os.path.join(__file__, '../../'))
data_path = os.path.join(project_path, r'data\test.csv')
save_path = 'demo_saves_seq'


def split_fn(dataset: list):
    return list(range(len(dataset)+1))


if __name__ == '__main__':
    # 读取数据
    data = pd.read_csv(data_path)
    # 构建分类器
    classifier = SeqClassifier()

    # 加载分类器
    classifier.load(save_path)

    # 构建数据集
    seq_ds = SeqDataset(data, split_fn=split_fn, min_len=4)

    # 预测模型
    pred = classifier.predict(seq_ds)
    print(pred)

    # 获取各类别分类结果,并保存信息至文件中
    classifier.report(seq_ds, 'seq_demo_report.csv')

    # 对数据进行预测,并将数据和预测结果写入到新的文件中
    classifier.summary(seq_ds, 'seq_demo_summary.csv')

如图:predict_demo_by_seq_result

seq_demo_report.csv内容

seq_demo_report

seq_demo_summary.csv内容

seq_demo_summary

更新历史

  • 0.1.0: 初始化项目,增加全连接模型
  • 0.2.0: 增加序列标注模型,大幅重构项目结构与内部实现代码
  • 0.2.1: 更新代码,使GPU和CPU下同时可用
  • 0.2.2: 增加将训练过程的loss和f1值写入到TensorBoard中
  • 0.2.3: 增加Norm Classifier的TensorBoard功能

参考

  1. python - Sorting list based on values from another list? - Stack Overflow

Project details


Download files

Download the file for your platform. If you're not sure which to choose, learn more about installing packages.

Source Distribution

nymph-0.2.4.tar.gz (15.3 kB view details)

Uploaded Source

Built Distribution

nymph-0.2.4-py3-none-any.whl (25.1 kB view details)

Uploaded Python 3

File details

Details for the file nymph-0.2.4.tar.gz.

File metadata

  • Download URL: nymph-0.2.4.tar.gz
  • Upload date:
  • Size: 15.3 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/3.1.0 pkginfo/1.5.0.1 requests/2.22.0 setuptools/39.0.1 requests-toolbelt/0.9.1 tqdm/4.45.0 CPython/3.7.0

File hashes

Hashes for nymph-0.2.4.tar.gz
Algorithm Hash digest
SHA256 9ffe19769260b39902b09cb045207e221ddc9f649c817cde0363db55991b5125
MD5 d26a1b408cfec056ace2029a860ad4e1
BLAKE2b-256 2164fafca4971180142de839baa247df20ef6e9bb42a1fc35fe74996013f2fb7

See more details on using hashes here.

File details

Details for the file nymph-0.2.4-py3-none-any.whl.

File metadata

  • Download URL: nymph-0.2.4-py3-none-any.whl
  • Upload date:
  • Size: 25.1 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/3.1.0 pkginfo/1.5.0.1 requests/2.22.0 setuptools/39.0.1 requests-toolbelt/0.9.1 tqdm/4.45.0 CPython/3.7.0

File hashes

Hashes for nymph-0.2.4-py3-none-any.whl
Algorithm Hash digest
SHA256 1c3e5678562667ebb558b4da691dbb8417dc2bda19db3a3f5424aaa8d7c3f1f0
MD5 1580a11b66c7766af13a5a32d08f4147
BLAKE2b-256 799854d91d5c51ed56980b0c1b6e397ae3ee6201db01258c57ca38d0f6ccedd5

See more details on using hashes here.

Supported by

AWS AWS Cloud computing and Security Sponsor Datadog Datadog Monitoring Fastly Fastly CDN Google Google Download Analytics Microsoft Microsoft PSF Sponsor Pingdom Pingdom Monitoring Sentry Sentry Error logging StatusPage StatusPage Status page