Skip to main content

Some base methods for developing

Project description

tketool

这个Python模块主要用于文件和目录的操作,包括创建、读写、遍历等功能,同时还提供了对SFTP服务器上文件的路径获取。另外,它还包含了一些关于配置管理和日志记录的功能,如获取配置文件中的设置、日志的记录和查询等。此外,还提供了对数据进行哈希运算的方法和打印格式化表格的功能。

类成员:

name info
Service_Shelve Service_Shelve是一个配置管理类,主要用于获取配置文件中的设置,并实例化为各种资源对象。
Custom_Handler 这是一个自定义日志处理器,可以控制日志的输出方式,并维护最近的20条日志记录。
log_color_enum 这是一个定义日志颜色ANSI代码常量的枚举类,用于设置和重置日志文本颜色。
ConfigManager ConfigManager是一个处理配置文件读写的类,支持查询、获取配置项,并在不存在时添加默认值。

函数成员:

name info
create_folder_if_not_exists 这是一个创建指定路径文件夹的函数,如果路径已存在则不会报错,适用于多种操作系统。
write_file_line 这是一个将行列表写入指定文件的函数,可以选择是否移除每行的换行符。
read_file_lines 这是一个读取文本文件内容的函数,可以选择是否删除换行符,若读取失败返回空列表。
write_file 这是一个Python函数,用于将指定内容写入指定路径的文件,如果文件已存在,原内容会被覆盖。
read_file 这是一个Python函数,用于读取指定路径的文件内容,如果遇到错误则返回空列表。
enum_directories 这个函数用于遍历指定路径下的所有目录,并可选是否递归遍历子目录。
enum_files 这是一个可以递归枚举指定目录及其子目录下所有文件的生成器函数。
delete_files_in_directory 这是一个Python函数,用于递归删除指定目录及其子目录下的所有文件和链接。
get_file_path_in_sftpserver 这是一个Python函数,用于获取本地或SFTP服务器上文件的路径。
get_help_info 该函数用于从函数的文档字符串中提取参数的帮助信息。
hash_str 这是一个将任意对象转化为字符串并进行MD5哈希运算的函数。
hash_obj_strbase 这是一个Python函数,用于对各种类型的输入对象进行哈希处理并返回处理后的字符串形式。
set_logger 这个函数主要用于设置和替换目标日志记录器的处理程序。
log 这是一个使用当前日志处理器记录日志信息的函数,不支持控制台输出彩色日志。
log_multi_row 这是一个使用全局变量current_logger来记录多行日志的函数。
get_log_history 这是一个无参数的函数,用于获取当前操作句柄的日志历史记录,并返回一个列表。
convert_print_color 这是一个将输入参数转化为带颜色的字符串的函数,输入可以是字符串或包含字符串和颜色枚举的元组。
print_table 这个函数用于打印格式良好的表格,通过指定列名和数据行进行表格的构建。
get_config_instance 这个函数用于获取或创建指定配置文件的配置管理器实例,实现了单例模式。

create_folder_if_not_exists

创建一个文件夹,如果该文件夹不存在。这个函数的目标是为了确保当程序需要在某个特定的路径下写入文件或者创建新的文件夹时,这个路径确实存在。这也可以避免在程序试图访问不存在的路径时发生错误。

参数:

*args:一个或多个字符串,表示要创建的文件夹的路径。例如,create_folder_if_not_exists('folder1', 'folder2') 将在当前目录下创建一个名为 'folder1/folder2' 的文件夹。

返回值:

str:返回创建的文件夹的完整路径。

使用例子:

path = create_folder_if_not_exists('folder1', 'folder2')

注意:

该函数使用了os模块的os.path.join函数来连接路径。这是一个与平台无关的方式,可以在Linux、Windows等不同的操作系统上都能正常工作。同时,该函数也使用了os模块的os.makedirs函数来创建文件夹。其中的exist_ok参数设为True,表示如果路径已经存在,不会抛出错误,而是直接返回路径。

此外,该函数假定程序有相应的文件操作权限。如果没有,os.makedirs可能会抛出PermissionError异常。

write_file_line

这个函数的主要目的是将给定的行列表写入到指定的文件中。如果指定了ignor_n_char参数为True,那么在写入文件之前,会将行列表中的每一行中的"\n"字符进行移除。

参数:

path (str): 待写入的文件的路径。

lines (list): 需要写入的行列表,每个元素代表一行。

ignor_n_char (bool): 是否在写入文件之前移除每一行中的"\n"字符,默认为True。

返回类型:

无返回值

使用示例:

write_file_line("/tmp/test.txt", ["Hello\n", "World"], True)

注意事项:

在进行文件写入操作时,请确保拥有对应文件的写权限,否则可能会引发权限错误。

错误或异常:

如果给定的文件路径不存在,或者没有写权限,函数会抛出异常。

read_file_lines

这是一个简单的函数,用于读取文本文件的内容。此函数根据用户的需求,可以选择是否删除行尾的换行符。

参数:

path (str): 指定要打开的文件的路径。

replace_n_char (bool): 如果为True,则在返回的行尾会删除换行符。默认为False,即不删除换行符。

返回:

Liststr: 返回一个列表,其中包含文件中的所有行。如果文件打开失败,或者在读取文件过程中发生了其他错误,那么返回一个空列表。

示例:

lines = read_file_lines("example.txt", replace_n_char=True)

# 从"example.txt"文件中读取所有行,并删除每一行的行尾换行符。

错误和异常:

如果在尝试打开或读取文件的过程中发生了错误,此函数会打印出错误原因,并返回一个空列表。

write_file

这个函数是用来将特定的内容写入到指定的文件中。

Args:

path(str): 需要写入内容的文件的路径。

content(str): 需要写入的内容。

Returns:

None. 这个函数没有返回值,只是执行写入文件的操作。

Raises:

Exception: 如果写入文件过程中出现错误,会抛出异常。

Example:

write_file('example.txt', '这是一段示例内容。')

这个例子中,函数会将'这是一段示例内容。'这段话写入到'example.txt'这个文件中。

Note:

这个函数使用 'w' 模式打开文件,所以如果文件已经存在,它的原有内容会被新的内容覆盖。

read_file

此函数是用于读取指定路径的文件。

参数:

path (str): 文件的路径。

返回:

str: 返回读取的文件内容。如果读取过程中有任何错误,会返回一个空列表。

示例:

函数调用可以如下所示:

read_file('path/to/your/file.txt')

错误:

如果文件路径不存在或者文件无法打开,函数会捕获异常并打印错误信息。

enum_directories

enum_directories函数将遍历给定路径下的所有目录,并根据参数决定是否递归遍历子目录。

参数:

path (str): 需要遍历的目录路径。

recursive (bool, 可选): 是否递归遍历子目录。默认为False,表示只遍历顶层目录。

返回:

generator: 一个生成器,每次yield一个元组,包含两个元素,第一个元素是完整的目录路径,第二个元素是目录的名字。

使用方法:

for dir_path, dir_name in enum_directories('/path/to/directory', recursive=True):

print(f'Found directory: {dir_path} with name: {dir_name}')

错误或者bug:

如果输入的路径不存在或者不是一个目录,函数会抛出异常。

enum_files

这是一个枚举目录中文件的函数。

参数:

path (str): 需要枚举的目录路径。

recursive (bool): 是否需要递归枚举子目录中的文件,如果设为False,则只枚举顶层目录中的文件。

返回:

generator: 这是一个生成器函数,每次调用都会返回下一个文件的完整路径和文件名。

例子:

# 枚举当前目录下的所有文件

for full_path, file_name in enum_files('.'):

print(full_path, file_name)

# 枚举当前目录及其子目录下的所有文件

for full_path, file_name in enum_files('.', True):

print(full_path, file_name)

注意:

该函数依赖于os模块,使用前请确保已经正确导入os模块:import os。

delete_files_in_directory

删除给定目录下的所有文件和子目录。

这个函数会遍历给定目录,对于目录下的每一个文件或目录,如果是文件或者链接,直接删除;如果是目录,递归删除该目录及其包含的所有文件和子目录。

参数:

directory (str): 需要删除文件的目录路径

返回:

None

错误:

如果删除文件或目录过程中出现错误,会打印错误信息,但不会中断程序。

使用示例:

delete_files_in_directory('/path/to/directory')

注意:

调用这个函数需要谨慎,因为它会删除指定目录下的所有文件和子目录,且不可恢复。

get_file_path_in_sftpserver

这是一个Python函数,用于在SFTP服务器中获取文件的路径。如果在本地路径中找不到文件,那么函数将会尝试连接SFTP服务器并从中下载文件。

参数:

folder_path (str): 本地文件夹路径,用于和文件名拼接成文件的完整本地路径。

filename (str): 需要获取路径的文件的文件名。

endpoint (str): SFTP服务器的IP地址或者主机名。

access_user (str): 登陆SFTP服务器的用户名。

secret_pwd (str): 登陆SFTP服务器的密码。

remote_folder_path (str): 远程SFTP服务器文件夹路径,用于和文件名拼接成文件的完整远程路径。

port (int, optional): SFTP服务器的端口,默认为22。

返回:

str: 文件的本地路径。

示例:

local_path = get_file_path_in_sftpserver('/local/folder', 'test.txt', 'sftp.server.com', 'user', 'password', '/remote/folder')

print(local_path) \# 输出: '/local/folder/test.txt'

注意:

  • 如果本地文件不存在,且尝试从SFTP服务器下载时发生错误,那么函数会抛出异常。

  • 函数使用了paramiko库建立SSH连接和SFTP客户端,确保在使用前已安装此库。

  • 函数使用了自定义的log函数记录日志,请在同一个作用域内定义此函数,否则将会抛出NameError异常。

get_help_info

Extract parameters help info from function's docstring.

Service_Shelve

这个类Service_Shelve是一个配置管理类,它主要用于获取指定配置文件中的各种设置,如数据源路径、模型URL、API令牌等。

它通过读取配置文件,将配置内容实例化为config对象,并提供获取各种资源(如数据源、模型等)的方法。

示例:


service_shelve = Service_Shelve("config_file_path")

datasource = service_shelve.get_datasource()

model = service_shelve.get_llm_GLM6B()

方法列表:

  • \_\_init\_\_(self, config_file_path): 构造函数,接收一个config_file_path,这个路径是配置文件的路径。该函数将配置文件实例化为config对象。

  • get_config(self): 返回config对象。

  • get_datasource(self): 从配置文件中获取sample_source_path,然后根据路径实例化一个LocalDisk_NLSampleSource对象。

  • get_llm_GLM6B(self, buffer_version=-1): 从配置文件中获取glm_url配置,并将其实例化为ChatGLM对象。

  • get_llm_GPT4(self, buffer_version=-1): 从配置文件中获取openai_proxy、openai_token、openai_temperature配置,并将相关配置实例化为ChatGPT4对象。

  • get_llm_GPT35(self, buffer_version=-1): 从配置文件中获取openai_proxy、openai_token、openai_temperature配置,并将相关配置实例化为ChatGPT3对象。

  • get_ft_model(self, model_id, buffer_version=-1, **kwargs): 从配置文件中获取openai_proxy、openai_token、openai_temperature配置,并将相关配置与提供的model_id、**kwargs实例化为FineTuned_Completion_Model对象。

  • get_emb(self, buffer_version=-1, **kwargs): 从配置文件中获取openai_token、openai_proxy配置,并将相关配置实例化为Openai_embedding对象。

__init__

初始化Service_Shelve类。

这个类是一个服务存储器,负责管理和提供各种服务。这些服务可能包括数据源、多种语言模型、嵌入模型等。Service_Shelve通过读取一个配置文件来获取这些服务的配置信息,根据这些配置信息初始化相应的服务。

参数:

config_file_path (str): 配置文件的路径。这个配置文件中包含了各种服务的配置信息。

属性:

config_obj: 根据给定的配置文件路径生成的配置对象。这个配置对象之后会被用来获取各种服务的配置信息。

示例:

service_shelve = Service_Shelve("path/to/config/file")

glm6b_service = service_shelve.get_llm_GLM6B()

gpt4_service = service_shelve.get_llm_GPT4()

emb_service = service_shelve.get_emb()

get_config

获取配置对象。

此方法是类Service_Shelve的一个方法,用于获取一个保存了某配置文件信息的对象。该对象可根据需要获取配置文件中的指定信息。

返回:

obj: 一个配置文件对象。可以根据需求获取配置文件中的指定信息。

示例:

service_shelve = Service_Shelve(config_file_path="config/path")

config_obj = service_shelve.get_config()

get_datasource

该函数用于获取数据源。

获取数据源的具体方式是通过配置文件(由self.config_obj所代表)来获取sample_source_path的值,该值代表数据源在本地磁盘的存储路径。然后,使用LocalDisk_NLSampleSource类来实例化一个数据源对象,该对象的特点是从本地磁盘读取样本来源。

函数没有参数。

返回类型是LocalDisk_NLSampleSource类的一个实例,代表一个可以从本地磁盘读取样本来源的数据源对象。

使用示例:

service_shelve = Service_Shelve(config_file_path)

datasource = service_shelve.get_datasource()

注意:如果配置文件中没有sample_source_path的配置项,或者该配置项的值不是一个有效的本地磁盘路径,可能会在运行时抛出异常。

get_llm_GLM6B

此函数的目的是从配置文件获取GLM模型的URL,并返回GLM模型实例。

参数:

buffer_version(int, 可选): 版本号,默认为-1。

返回:

返回ChatGLM类的实例,这是一个机器学习模型。

例子:

service_shelve = Service_Shelve('config_file_path')

glm_model = service_shelve.get_llm_GLM6B()

注意:

需要先确保配置文件中存在"glm_url"这个配置项,并且其值是GLM模型的URL。

get_llm_GPT4

这个方法用于获取ChatGPT4模型实例。

参数:

buffer_version (int, optional): 版本号,默认是-1。

返回:

ChatGPT4: 返回带有指定参数的ChatGPT4实例。

解释:

该方法从配置对象中获取所需的配置信息(如代理设置、API令牌和温度参数)。然后,它使用这些配置信息创建一个新的ChatGPT4实例,并返回。

在函数内部,首先通过调用self.config_obj.get_config方法获取必要的配置,比如"openai_proxy"、"openai_token"和"openai_temperature"。

然后,将获取的temperature转换为float类型,并设置到配置字典config_dict中。

最后,使用获取到的api_token、proxys、config_dict和buffer_version作为参数,创建一个新的ChatGPT4对象,并返回。

例如:

service_shelve = Service_Shelve(config_file_path)

gpt4 = service_shelve.get_llm_GPT4(buffer_version=0)

注意:

如果配置文件中没有设置相关的配置信息,"openai_proxy"和"openai_token"将默认为"not_set","openai_temperature"将默认为"0.8"。

get_llm_GPT35

此函数的目的是获取GPT-3.5语言模型的实例。

参数:

buffer_version (int, 默认值为-1): 指定的缓存版本,如果没有指定,则默认为-1。

返回:

返回的是一个ChatGPT3实例。该实例是使用从配置对象中获取的代理和API令牌创建的,温度参数也从配置对象中获取。

使用示例:

service_shelve = Service_Shelve(config_file_path)

gpt3_5_llm = service_shelve.get_llm_GPT35(buffer_version=1)

注意事项:

  1. 这个函数从配置对象中获取了openai代理、openai令牌和openai温度这些配置。所以,应确保在调用此函数之前已经正确设置了这些配置。

  2. 如果在创建ChatGPT3实例时发生错误,此函数不会捕获和处理这些错误,所以你需要自行处理可能出现的异常。

get_ft_model

此函数的主要目的是根据指定的模型ID和其他参数,获取一个经过微调的模型。

参数:

model_id (str) : 模型ID。

buffer_version (int, 可选) : 缓冲区版本,默认为 -1。

**kwargs : 其他一些需要传递给模型的参数。

返回:

FineTuned_Completion_Model 对象:一个经过微调的模型对象。

示例:

service = Service_Shelve(config_file_path)

ft_model = service.get_ft_model('gpt-3.5-turbo')

result = ft_model.predict('Hello, how are you?')

注意:

  • 在使用这个函数时,需要保证已经配置了openai代理和openai的api token,以及openai的温度参数,否则可能无法正确获取模型。

  • 如果传递给此函数的buffer_version参数为负数,则会自动使用默认的buffer_version。

  • 如果函数无法获取指定的模型ID,则可能会抛出异常。

错误:

  • 如果获取的模型不存在或无法访问,将抛出异常。

get_emb

get_emb是Service_Shelve类的一个方法,用于获取嵌入模型。

参数:

buffer_version(int, 可选): 缓冲区版本,默认为-1,表示使用最新版本。

**kwargs(dict, 可选): 允许用户提供额外的配置参数。

返回:

返回值是一个Openai_embedding实例。此实例可用于获取特定文本的嵌入表示。

举例:

service_shelve = Service_Shelve(config_file_path="config.json")

emb_model = service_shelve.get_emb()

text_embedding = emb_model.get_embedding("This is a test sentence")

注意:

为了使用此函数,需要在配置文件中提供有效的OpenAI API令牌和代理服务器详细信息(如果使用)。如果这些信息未正确配置,函数可能会引发异常。

错误与异常:

如果给定的buffer_version无效,或者OpenAI API令牌或代理服务器详细信息未正确配置,此函数可能会引发异常。

hash_str

这是一个将字符串哈希化的函数。

Args:

target_str (str): 对这个字符串进行哈希化操作。

Returns:

str: 返回一个经过MD5哈希运算后的字符串。

注意:如果传入的不是字符串类型,会先将其转化为字符串类型然后进行哈希运算。

示例:

print(hash_str('hello world')) \# 输出: 5eb63bbbe01eeed093cb22bb8f5acdc3

函数并不会检查传入参数的有效性, 如果传入对象不能被转化为字符串, 则会抛出一个ValueError异常。

hash_obj_strbase

该函数的主要功能是对输入的对象进行哈希处理,转换成字符串的形式。它能够对不同类型的对象进行处理,包括字符串、整数、浮点数、字典以及可迭代的对象。

参数:

obj: 待处理的对象,可以是任何类型

返回:

对输入对象进行哈希处理后的字符串

处理过程:

  1. 如果输入的对象是字符串、整数、浮点数,则直接转换为字符串并进行哈希处理

  2. 如果输入的对象是字典,则将字典的键和值分别转换为列表,然后进行哈希处理,最后将处理后的键和值的哈希结果进行连接并进行哈希处理

  3. 如果输入的对象是可迭代的对象,则将每个元素转换为字符串并进行哈希处理,然后将所有元素的哈希结果连接成一个字符串并进行哈希处理

  4. 如果输入的对象是其他类型,则将对象的类名转换为字符串并进行哈希处理

注意事项:

这个函数没有处理递归引用的情况,例如,列表或字典等数据结构中包含自身的引用,这会导致无限递归。如果输入的对象中存在这种情况,函数可能会出现堆栈溢出的错误。

示例:

print(hash_obj_strbase("hello")) # 输出一个字符串的哈希值

print(hash_obj_strbase({"name": "Tom", "age": 20})) # 输出一个字典的哈希值

print(hash_obj_strbase([1, 2, 3])) # 输出一个列表的哈希值

print(hash_obj_strbase((1, 2, 3))) # 输出一个元组的哈希值

Custom_Handler

这个类是一个自定义的日志处理器,它继承自logging.Handler

这个处理器的主要功能是:将日志信息输出到标准输出,并且提供了一些特殊的处理,比如处理进度条显示和历史记录。

这个处理器的工作原理如下:

  • 当日志等级为62时,表示这是一个进度条结束的信号,它会将进度条的标志位设置为False,并输出‘process finish.’信息。

  • 当日志等级为60时,表示这是一个进度条开始的信号,它会将进度条的标志位设置为True,并在标准输出上显示进度条。

  • 当日志等级为61时,表示这是一个普通的日志信息,但是需要在进度条下方打印,所以它会先输出一个回车符\\r,然后输出日志信息,最后再输出进度条。

  • 对于其他日志等级,也会根据进度条的标志位来决定是在进度条下方打印,还是直接打印。

这个处理器还维护了一个最近输出的20条日志的历史记录队列。

这个类的使用方法如下:

import logging

handler = Custom_Handler()

logger = logging.getLogger(\_\_name\_\_)

logger.addHandler(handler)

\# 输出普通日志

logger.log(61, "This is a normal log.")

\# 开始进度条

logger.log(60, "Start progress bar...")

\# 结束进度条

logger.log(62, "End progress bar.")

__init__

Custom_Handler类的初始化函数。

该类是用于特殊的日志处理,包括处理进度条显示以及历史记录的保存等功能。在初始化的时候,会设定一些默认的参数。

Attributes:

in_processbar (bool): 一个标志位,用于判断当前是否在处理进度条显示。

history (deque): 一个双端队列,用于保存最近的20条历史记录。

processbar_str_temp (str): 一个临时字符串,用于保存当前的进度条显示。

示例:

handler = Custom_Handler()

logger = logging.getLogger('your_logger')

logger.addHandler(handler)

注意:

此类特定于处理带有进度条的日志。对于普通的日志,可能不适用。

emit

这是 Custom_Handler 类中的一个方法,用于将记录的日志信息输出至标准输出。此方法根据记录的日志级别进行不同的处理。

参数:

record: logging.LogRecord 对象,包含了所有要输出的日志信息。

返回类型: 无返回值。

函数流程:

  1. 如果日志级别为62,表明进度条已完成,它将向标准输出写入'\r'和'process finish.',然后返回。

  2. 如果日志级别为60,表明进度条正在进行中,它将替换消息中的换行符,然后将消息写入到标准输出,如果此时不在进度条中,它还会在消息前增加一个换行符。

  3. 如果日志级别为其他值,它会根据是否在进度条中进行不同处理。如果在进度条中,它将在消息前后各增加一个换行符并将消息和进度条一起写入到标准输出。如果不在进度条中,它将在消息前增加一个换行符,然后将消息写入到标准输出。

特殊处理:

如果日志级别为61,它将不会替换消息中的换行符。

注意:

无论何时,都会刷新标准输出,并将消息添加到历史队列中。

错误或者bug: 无特殊说明。

set_logger

这个函数用于设置日志记录器的处理程序。

参数列表:

target_logger (logging.Logger): 需要设置处理程序的目标日志记录器

返回类型:

此函数将遍历目标日志记录器中的所有处理程序,并将它们从记录器中移除。

然后,它将当前的处理程序添加到目标日志记录器中。

注意,此函数假定在调用此函数之前,已经创建并配置了名为current_handle的处理程序。

示例:

import logging

logger = logging.getLogger(__name__)

handler = logging.StreamHandler()

set_logger(logger, handler)

这段代码创建了一个名为__name__的日志记录器,并将默认流处理程序设置为处理程序。

然后,它使用set_logger函数将处理程序设置为当前处理程序。

注意: 此函数存在一个潜在的问题,那就是它会移除目标日志记录器中的所有处理程序,而不仅仅是要替换的那一个。

如果记录器在其他地方也在使用,这可能会导致问题。在使用此函数之前,最好确认目标记录器的处理程序是否真的需要被完全替换。

log

这是一个简单的日志记录函数。此函数首先通过全局变量current_logger获取当前的日志处理器,然后调用其info方法记录日志。此函数没有返回值。

参数:

str: 需要被记录的日志信息,类型为字符串。

示例:

log("This is a test log.") # 输出: This is a test log.

注意:

尽管在函数体中存在一个被注释掉的print语句,但请不要取消注释并使用,因为它可能会在不支持ANSI escape code的环境中造成问题。该语句设计为在控制台输出彩色日志,\033[0m为ANSI escape code,用于重置颜色。

错误和bug:

暂时没有发现错误和bug。

log_multi_row

这是一个记录多行日志的函数。

函数的工作原理是使用全局变量 current_logger,并调用其 log 方法,实现多行日志的记录。

参数:

str (str): 需要记录的多行日志的字符串。

返回类型:

无。

使用示例:

log_multi_row("这是一个\n多行日志")

注意事项:

  1. 本函数没有返回值,其功能仅仅是记录日志。

  2. 在使用本函数前,需要确保全局变量 current_logger 已经被正确初始化并可以使用。

  3. 由于 log 方法的第一个参数是 61,因此本函数可能仅适用于某些特定配置的日志系统。

get_log_history

这是一个函数,获取当前句柄的日志历史记录。

函数没有参数。

返回类型是列表,其中包含当前句柄的历史记录。

示例:

log_history = get_log_history()

注意:此函数使用了全局变量current_handle,确保在调用此函数前已正确初始化此全局变量。

没有已知错误或bug。

log_color_enum

这是一个枚举类log_color_enum,其目的是定义一系列关于日志颜色的常量。这些常量与ANSI颜色代码相对应。

枚举值包括:

  • DEFAULT: 默认颜色,无特殊颜色代码。

  • RED: 红色代码。

  • YELLOW: 黄色代码。

  • GREEN: 绿色代码。

  • BLUE: 蓝色代码。

  • MAGENTA: 洋红色代码。

  • CYAN: 青色代码。

每个枚举值都是由ANSI颜色代码字符串表示。例如,对于红色,其ANSI颜色代码是"\033[91m"。

使用示例:

print(f"{log_color_enum.RED.value}This is red text{log_color_enum.DEFAULT.value}")

在上述示例中,我们使用log_color_enum.RED.value获取红色的ANSI代码,然后将其添加到需要显示为红色的文本前面。然后,我们添加DEFAULT的ANSI代码,以重置颜色到默认状态。这样,任何在这之后打印的文本都将以默认颜色显示,而不是红色。

注意,在某些环境(如Windows的某些版本)中,ANSI颜色代码可能无法正常工作。可能需要使用第三方库(如colorama)来启用ANSI颜色支持。

没有已知的错误或bug。

convert_print_color

此函数接收多个参数(可能是字符串或者元组),并根据参数类型进行处理。如果参数是元组,且元组长度为2,元组的第二个元素是颜色枚举类型,那么该元组被认为是包含字符串和颜色枚举的元组,函数会将其转化为带有颜色的字符串;如果参数不符合前述条件,那么直接认为该参数是字符串,进行处理。

Args:

*args: 可变参数。每个参数可能是一个字符串,或者一个包含字符串和颜色枚举的元组。

Returns:

str: 返回由输入参数转化而来的字符串,如果参数是包含字符串和颜色枚举的元组,转化后的字符串将带有颜色。

Example:

convert_print_color('hello', ('world', log_color_enum.RED))

# 输出:'hello\033[31mworld\033[0m'

Note:

需要注意,对于元组参数,其第二个元素必须是颜色枚举类型,否则函数会抛出异常。

print_table

这个函数的主要作用是打印一个格式良好的表格,表格的列名由参数 table_col 提供,表格的数据由参数 rows 提供。

参数:

table_col (liststr): 表格列的名字,这是一个字符串列表。

rows (list[liststr]): 表格的数据,这是一个二维字符串列表,每个子列表代表一行数据。

truncate_string (int, optional): 如果某个单元格的字符串长度超过这个值,则会被截断。默认值为30。

返回值:

无。这个函数没有返回值,它的主要目标是提供一个优雅的表格输出。

例子:

cols = ["姓名", "年龄", "职业"]

data = [

["张三", "27", "工程师"],

["李四", "31", "医生"],

["王五", "25", "教师"]

]

print_table(cols, data)

注意事项:

  1. 当表格数据的行数和 table_col 的长度不一致时,将会引发 Exception

  2. truncate_string 参数不能为负值,否则会引发 Exception

ConfigManager

这是一个ConfigManager类,其目的是用来处理配置文件的读取和写入。它将读取指定的配置文件,并将其内容存储为字典格式以供查询。如果尝试获取不存在的配置项,它会使用默认值,并在配置文件中添加这个键值对。

以下是一个使用示例:

cm = ConfigManager("/path/to/config")

value = cm.get_config("key", "default")

类方法介绍:

  • \_\_init\_\_: 初始化方法,接收一个配置文件路径参数。在创建类实例时,会立即加载配置文件。

  • _load_configs: 私有方法,负责加载配置文件。如果文件存在,就读取所有行并将其转换为字典格式。如果文件不存在,就创建一个新的空文件。

  • _sanitize_string: 静态方法,接收一个字符串,去掉其开头和结尾的双引号或单引号,然后返回。

  • get_config: 公有方法,接收一个键和一个默认值参数。如果键存在于配置文件中,则返回对应的值;如果不存在,则返回默认值,并将这个键值对添加到配置文件中。

注意:如果配置文件的写入权限被禁止,这个类可能会抛出异常。

__init__

ConfigManager是一个用于配置管理的类。

创建对象时,用户需要提供配置文件的路径作为参数。在创建对象后,该类会自动加载指定路径下的配置文件,以便后续通过键值对的方式获取配置信息。

如果在获取配置信息时,指定的键不存在,该类会自动在配置文件末尾添加该键,并设置其值为默认值。

示例:

config_manager = ConfigManager("/path/to/config/file")

db_host = config_manager.get_config("db_host", "localhost")

参数:

config_file_path: 配置文件的路径。可以是相对路径或绝对路径。如果指定的路径不存在,会自动创建一个新的空白配置文件。

注意:

  1. 配置文件中的键和值都是字符串形式,键值对之间通过等号("=")分隔,例如:db_host=localhost

  2. 此类没有提供修改配置信息的方法,如果需要修改配置信息,需要直接修改配置文件,然后重新创建ConfigManager对象以加载新的配置信息。

_load_configs

_load_configs是一个私有方法,用于加载配置文件中的配置。如果配置文件存在,它会读取文件中的每一行,并通过等号('=')将每一行分割为key和value,并将它们加入到_config_map字典中。同时,为了确保配置项的值是干净的,我们使用_sanitize_string方法去除两边可能存在的引号。

如果配置文件不存在,此方法会创建一个新的空的配置文件。

注意,此方法没有返回值。

执行这个方法不需要任何参数。

这个方法在初始化ConfigManager类的时候会被自动调用,用于加载配置文件的内容,无需手动调用。

_sanitize_string

_sanitize_string是一个静态方法,用于处理输入的字符串,去除字符串两边的双引号和单引号。

参数:

s (str): 待处理的字符串。

返回:

str: 返回处理后的字符串。

例如:

输入字符串为' "example" ',调用此方法后,返回的字符串为'example'。

注意:

该方法仅处理字符串前后的引号,不处理字符串中间的引号。

get_config

此函数是用于获取指定键的配置值。如果键存在于配置映射中,则直接返回其对应的值;否则,会给出警告并在配置文件中添加键和默认值,并返回默认值。

参数:

key (str): 需要获取的配置的键。

default_value (str, 可选): 如果键不存在于配置映射中,将使用此默认值。默认值为""。

返回:

str: 返回的配置值,如果键不存在于配置映射中,返回的是默认值。

注意:

如果键不存在,此函数将自动在配置文件中创建该键,并设置其值为默认值,然后返回默认值。

get_config_instance

这个函数的目标是得到一个配置管理器实例。

如果文件名参数为空,函数将获取当前工作目录下名为'config.jconfig'的文件。

如果该文件名还未在_config_instance全局变量中,函数将创建一个新的配置管理器实例,

并将其在_config_instance中与该文件名关联起来。

这个函数使用了全局变量_config_instance来存储所有已创建的配置管理器实例,

并以文件名作为每个实例的键。这种设计模式被称为单例模式,通过它,可以确保对于同一个配置文件,

总是返回同一个配置管理器实例。

参数:

filename: str, optional, default=None

配置文件的名字。如果没有指定,将使用当前工作目录中名为'config.jconfig'的文件。

返回:

ConfigManager实例。根据提供的文件名,返回对应的配置管理器实例。如果该文件名还未在

_config_instance中,则创建一个新的实例并返回。

错误:

如果指定的文件不存在,ConfigManager的构造函数将抛出异常。

示例:

# 获取默认配置文件的配置管理器实例

config_instance = get_config_instance()

# 获取指定配置文件的配置管理器实例

config_instance = get_config_instance('my_config.jconfig')

注意:

此函数依赖于_config_instance全局变量和ConfigManager类,需要确保在使用此函数之前,

这些全局变量和类已经被正确地定义和初始化。

pyml

这个Python模块主要用于基于PyTorch的深度学习模型训练,提供了一套完整的训练、验证及优化工具。它包括模型训练类、优化类和状态跟踪类等组件,有利于实现模型的训练、优化和状态追踪。模块还使用了枚举类型来定义设备使用方式、数据类型和模型参数更新模式,提高代码的可读性和维护性。此外,模块还提供了插件调用和训练行为的定义工具,用于定制和扩展模型训练的行为。

类成员:

name info
pymodel_trainer 'pymodel_trainer'是一个基于PyTorch的模型训练类,负责模型的训练、验证和插件调用。
pytrainer_deepspeed 'pytrainer_deepspeed'是一个类,用于使用DeepSpeed优化库训练并优化PyTorch模型。
plugin_invoke_Enum plugin_invoke_Enum是一个表示插件调用状态的枚举类,涵盖了从插件执行开始到结束的各个阶段。
device_use_enum 这是一个名为device_use_enum的枚举类型类,主要用于定义设备使用方式,如自动或CPU。
dtype_enum dtype_enum是一个枚举类,用于标记和识别不同的数据类型,方便代码中的引用和比较。
update_mode_enum 这是一个枚举类update_mode_enum,用于定义不同的模型参数更新模式,提高代码的可读性和维护性。
global_state_board global_state_board是一个在训练过程中进行全局状态记录和管理的类。
epoch_state_board epoch_state_board是一个类,用于追踪和记录神经网络每个训练周期的状态信息。
step_state_board step_state_board是一个用于跟踪和记录训练过程中每一步的状态和相关信息的类。
trainer_plugin_base 这是一个抽象基类,提供训练插件的基础结构和工具,定义了训练插件的基本行为,并且需要子类实现特定训练行为。
pytrainer_accelerate 'pytrainer_accelerate'类用于在PyTorch和Accelerator库的帮助下,在指定的CPU或GPU上训练模型。

函数成员:

name info
convert_to_list 这是一个将整数、浮点数、列表或Tensor转换为列表的函数,能处理1D或2D的Tensor。
invoke_at 这是一个装饰器生成器,用于给函数添加识别标签,以便在特定类型的插件调用时触发。

pymodel_trainer

pymodel_trainer类是一个基于PyTorch的模型训练器,继承自trainer_plugin_base。该类主要负责模型的训练、验证以及训练进程的插件调用。

主要成员函数:

  • __init__: 初始化函数,构造模型训练器

  • train: 训练模型

  • evaluate: 对模型进行评价

  • invoke_model: 调用模型进行预测

  • calculate_loss: 计算损失

  • zero_grad: 重置梯度

  • step: 执行一步优化

  • backward: 执行反向传播

初始化函数__init__的参数列表如下

  • model: 需要训练的PyTorch模型

  • loss_obj: 用于计算模型损失的损失函数

  • update_mode: 模型参数更新的模式,每一步更新或者每一轮更新

  • output_folder: 模型输出文件夹的路径

  • plugins: 训练插件列表

  • optimizer_type: 优化器的类型,默认为"adamw"

  • learn_rate: 学习率,默认为0.01

train函数的参数列表如下

  • sample_set: 需要训练的样本集合

  • epoch: 训练的轮数

  • input_convert_func: 输入数据转换函数,默认不变

  • label_convert_func: 标签数据转换函数,默认不变

evaluate函数的参数列表如下

  • sample_set: 需要评价的样本集合

  • input_convert_func: 输入数据转换函数,默认不变

  • label_convert_func: 标签数据转换函数,默认不变

  • logit_convert_func: logits转换函数,默认不变

  • scores: 评价指标列表,默认包括准确率、精确率、召回率和F1分数

注:这个类没有明显的错误或bug。

Invoke

Invokepymodel_trainer类中的一个方法。该方法目前为空,没有执行任何操作。

参数:

  • global_state : global_state_board对象,用于存储所有全局状态,包括模型、优化器等。

  • epoch_state : epoch_state_board对象,用于存储当前epoch的状态,包括损失等。

  • step_state : step_state_board对象,用于存储当前步骤的状态,包括输入、输出、损失等。

返回类型:无

示例:

trainer = pymodel_trainer(...)

global_state = global_state_board(...)

epoch_state = epoch_state_board(...)

step_state = step_state_board(...)

trainer.Invoke(global_state, epoch_state, step_state)

注意:这个方法目前为空,没有执行任何操作。尚未发现错误或者bug。

__init__

这是一个pymodel_trainer类的初始化函数,用于初始化训练模型的各个参数和属性。

参数:

  • model: 一个torch.nn.Module对象,表示模型。

  • loss_obj: 损失函数对象,用于计算模型的损失。

  • update_mode: 更新模型的方式,有按步(Per_Step)和按批(Per_Epoch)两种,取值来自update_mode_enum,默认是按步更新。

  • output_folder: 模型输出的文件夹路径,如果为None,则按时间戳创建一个新的文件夹,否则使用给定的路径。

  • plugins: 插件列表,默认为空。插件可以用来扩展模型的功能。

  • optimizer_type: 优化器的类型,默认为"adamw"。优化器用于更新模型的参数。

  • learn_rate: 学习率,默认为0.01。学习率决定了模型参数更新的步长。

在这个初始化函数中,将会对模型、损失函数和优化器等重要属性进行初始化设置,并创建用于存储模型的文件夹。同时,也会根据插件列表来进行插件的初始化。

注意,该函数没有返回值。

使用示例:

model = torch.nn.Linear(10, 1)

loss_obj = torch.nn.MSELoss()

trainer = pymodel_trainer(model=model, loss_obj=loss_obj, output_folder='model_path')

在这个例子中,我们创建了一个线性模型和均方误差损失函数,并且使用pymodel_trainer类进行了初始化。在初始化过程中,我们传入了模型和损失函数,并指定了模型的输出文件夹路径。

错误和警告:

本函数中,optimizer_type的取值应为optimizer_dict中的键,否则会引发错误。

_invoke_plugin

_invoke_plugin 是一个私有方法,用于在训练过程中的某些特定阶段调用插件函数。例如,在每一个训练批次开始时,或者在每一个训练批次结束后,可能需要执行一些特定的操作(例如,记录日志、更新学习率等)。这些操作可以通过编写插件函数,并在合适的时机调用这些插件函数来实现。

参数:

  • plugin_enum: 一个枚举值,指定当前训练的阶段。例如,可以是 'Batch_begin'(在一个训练批次开始时),'Batch_end'(在一个训练批次结束后)等。

  • base_wall: 一个全局状态板对象,包含了全局的训练状态,例如,当前的模型、优化器、损失函数等。

  • epoch_wall: 一个周期状态板对象,包含了当前训练周期的状态,例如,当前周期的损失值、准确率等。

  • batch_wall: 一个步骤状态板对象,包含了当前训练批次的状态,例如,当前批次的输入数据、目标数据、模型输出等。

返回:

  • 无返回值

使用方法:

  • 这是一个私有方法,通常不会直接在类外部调用。而是在训练过程中的特定阶段,例如,一个训练批次开始时,调用 '_invoke_plugin(plugin_enum.Begin, base_wall, epoch_wall, batch_wall)',在一个训练批次结束后,调用 '_invoke_plugin(plugin_enum.End, base_wall, epoch_wall, batch_wall)' 等。

注意事项:

  • 如果 plugin_enum 对应的插件函数列表为空,或者不存在,那么 '_invoke_plugin' 方法会直接返回,不会执行任何操作。

可能的错误:

  • 如果 plugin_enum 不是预定义的枚举值,'_invoke_plugin' 方法可能无法正确地找到并执行对应的插件函数。

_statistics

这是一个名为_statistics的函数,它的主要功能是统计和记录训练过程中模型的参数信息。

参数:

  • global_state(global_state_board类型):全局状态板,主要用于存放全局级别的训练信息,如模型、优化器等。

  • epoch_state(epoch_state_board类型):历元状态板,主要用于存放每个历元级别的训练信息,如当前历元的损失值等。

  • step_state(step_state_board类型):步骤状态板,主要用于存放每个训练步骤级别的训练信息,如当前步骤的损失值等。

返回值:

函数首先创建了一个PrettyTable对象,用于格式化参数信息的输出。然后遍历优化器的参数组,统计每个参数的数量,并将其添加到PrettyTable中。在统计过程中,参数的总数量也被累加并存储在全局状态板的update_parameter_count属性中。最后,将总的参数数量也添加到PrettyTable中,并以日志的形式输出。

注意:此函数不会返回任何值,它的主要目的是统计和记录训练过程中的参数信息,以便于后期分析和调试。

示例代码:

\# 创建一个trainer对象

trainer = pymodel_trainer(model=model, loss_obj=loss, optimizer_type="adamw", learn_rate=0.01)

\# 创建全局、历元、步骤状态板

global_state = global_state_board(...)

epoch_state = epoch_state_board(...)

step_state = step_state_board(...)

\# 调用_statistics函数

trainer._statistics(global_state, epoch_state, step_state)

zero_grad

这个方法用来将优化器中的所有梯度清零。在训练神经网络时,我们需要在每个更新步骤之前清零梯度,因为PyTorch在.backward()方法中会累加梯度,而不是替换它们。

参数:

  • global_state (global_state_board): 全局状态板,包含了模型、优化器和其他全局状态信息。

  • epoch_state (epoch_state_board): 当前epoch的状态板,包含了当前epoch的信息,如当前epoch的损失,准确度等。

  • step_state (step_state_board): 当前步(batch)的状态板,包含了当前步的信息,如输入数据,目标标签,模型输出等。

返回类型:

无返回值。此方法主要用于更新全局状态板上的优化器的状态。

示例:


model_trainer = pymodel_trainer(model, loss_obj, optimizer_type="adam", learn_rate=0.01)




**for epoch in range(num_epochs):**

epoch_state = epoch_state_board(epoch)




**for step in range(num_steps):**

step_state = step_state_board(step)

model_trainer.zero_grad(global_state, epoch_state, step_state)

...

invoke_model

invoke_modelpymodel_trainer类中的一个方法,用于执行模型的前向传播过程。

参数:

  • global_state(global_state_board类型的实例): 保存全局状态信息的实例,包含当前正在训练的模型等信息。

  • epoch_state(epoch_state_board类型的实例): 保存某一训练轮次(epoch)的状态信息的实例。本方法中未使用。

  • step_state(step_state_board类型的实例): 保存某一训练步骤(step)的状态信息的实例,包含了本步骤的输入数据等信息。

返回值:

  • Tensor: 前向传播过程的输出结果。

示例:

\# 假设global_state, epoch_state, step_state已经初始化

output = trainer.invoke_model(global_state, epoch_state, step_state)

注意事项:

  • 本方法会根据step_state中的输入数据(converted_input)执行模型的前向传播过程,但并未处理模型的输出结果。具体的后处理过程(如损失函数的计算等)需要在调用本方法后自行进行。

  • epoch_state参数在本方法中未被使用,可以传入None

calculate_loss

calculate_losspymodel_trainer 类的一个成员方法。其功能是计算神经网络模型在给定输入和标签下的损失。

参数:

  • global_state: 一个 global_state_board 对象。包含了模型的全局信息,如模型对象,损失函数对象,优化器对象等。

  • epoch_state: 一个 epoch_state_board 对象。包含了当前epoch的信息,如当前epoch的损失,当前epoch的编号等。

  • step_state: 一个 step_state_board 对象。包含了当前step的信息,如当前step的输入,当前step的输出,当前step的标签等。

返回:

  • loss: 一个 Pytorch 的 Tensor 对象,表示模型在当前step的输入和标签下的损失。

示例使用:

trainer = pymodel_trainer(model=my_model, loss_obj=my_loss)  \# 创建训练器对象

global_state = global_state_board(...)

epoch_state = epoch_state_board(...)

step_state = step_state_board(...)

loss = trainer.calculate_loss(global_state, epoch_state, step_state)  \# 计算损失

注意:

  • 必须确保 global_state 中的 loss_obj 是一个有效的 Pytorch 的损失函数对象。

  • 必须确保 step_state 中的 logitconverted_label 是同形状的 Tensor,否则可能无法计算损失。

  • 目前没有发现该函数存在错误或bug。

backward

backward函数是pymodel_trainer类的一个方法。这个方法用于在给定的步骤状态下,将损失反向传播回模型中。反向传播是神经网络学习的关键步骤,它通过计算损失函数关于网络权重的梯度,来更新模型的参数。这个函数并不返回任何值,但它修改了step_state的状态。

参数:

  • global_state (global_state_board): 全局状态板,存储了训练过程的全局信息,如模型,优化器等。

  • epoch_state (epoch_state_board): 存储了当前训练周期的状态信息,如当前是第几个训练周期,当前训练周期的损失等。

  • step_state (step_state_board): 存储了当前步骤的状态信息,如当前是第几步,当前步骤的输入,输出,损失等。

返回值:

示例:

trainer = pymodel_trainer(model, loss_obj)




**for epoch in range(num_epochs):**




**for step in range(num_steps):**

\# forward pass

output = trainer.invoke_model(global_state, epoch_state, step_state)

\# calculate loss

loss = trainer.calculate_loss(global_state, epoch_state, step_state)

\# backward pass

trainer.backward(global_state, epoch_state, step_state)

错误和异常:

  • 如果step_state.loss_tensor不存在或者为None,会导致backward()函数调用失败。

step

step 方法是 pymodel_trainer 类的一个成员函数,用于执行一次模型的参数更新步骤。

参数列表:

  • global_state: 是一个 global_state_board 实例,储存全局状态,包括模型、优化器、损失函数等信息。

  • epoch_state: 是一个 epoch_state_board 实例,储存当前周期的状态,比如周期损失等。

  • step_state: 是一个 step_state_board 实例,储存当前批处理步骤的状态,如输入、输出、损失值等。在本函数中并未被使用。

此函数无返回值。

函数工作流程:

  • 调用 optimizerstep 方法,按优化器设定的更新策略更新模型的参数。

注意事项:

  • 本函数并未进行错误处理,如果在参数更新过程中出现错误,会引发运行时错误。

使用示例:

trainer = pymodel_trainer(

model=your_model,

loss_obj=your_loss,

optimizer_type='adamw',

)




**while training:**

...

trainer.step(global_state, epoch_state, step_state)

...

train

trainpymodel_trainer 类的一个方法,它负责训练模型。

参数:

  • sample_set (SampleSet): 训练集数据

  • epoch (int): 训练的轮次(epochs),默认为100

  • input_convert_func (function): 转换输入数据的函数,默认为恒等函数

  • label_convert_func (function): 转换标签数据的函数,默认为恒等函数

返回类型:

此函数没有返回值。

使用例子:

trainer = pymodel_trainer(model, loss_obj)

trainer.train(sample_set, epoch=50)

注意事项:

如果loss值不是一个tensor,将会抛出一个异常。在训练过程中,根据update_mode参数的设置,可能在每个步骤(step)或每个轮次(epoch)结束时更新模型参数。在每个步骤或轮次结束时,都会释放与该步骤或轮次相关的资源。

evaluate

这个函数是用于评估模型性能的方法。在评估过程中,它将对数据样本集合进行遍历,然后将每个样本送入模型进行预测。预测结果将被转化为类别标签,并与真实的类别标签进行比较,以计算各类评估指标,如准确率、精确率、召回率和F1分数等。

参数:

sample_set: SampleSet对象,用于存储需要进行评估的数据样本集合。

input_convert_func: 函数对象,将用于对输入数据进行预处理。默认为恒等函数。

label_convert_func: 函数对象,将用于对标签进行预处理。默认为恒等函数。

logit_convert_func: 函数对象,将用于对模型输出的logits进行处理,从而得到预测的类别标签。默认为恒等函数。

scores: 评估指标计算对象的列表。默认会至少包含AccuracyScores(),PrecisionScores(),RecallScores(),F1Scores()这四种。

返回:

这个函数将返回一个三元组,分别包含各类评价指标的结果、总体评价指标的结果以及混淆矩阵。

示例:

假设我们有一个已经训练好的模型model,和一个用于评估的数据集sample_set。我们可以通过以下代码来进行评估:

per_type_result, all_result, c_matrix = model.evaluate(sample_set)

这样我们就可以得到每类的评价指标,总体评价指标以及混淆矩阵。

pytrainer_deepspeed

pytrainer_deepspeed 是一个继承自 pymodel_trainer 的类,用于使用DeepSpeed深度学习优化库训练PyTorch模型。

这个类的目的是在PyTorch模型训练过程中使用DeepSpeed进行优化,包括自动混合精度训练、模型并行、激活checkpointing等特性。通过构造函数,可以选择使用的精度类型。

它使用了DeepSpeed的initialize方法,配置参数包含了对DeepSpeed的各种参数设置,如fp16训练、Zero优化、激活checkpointing等。

类中的_drive_batch_data方法用于在训练开始前,将批次数据移动到指定的设备上。init_deepspeed方法在训练开始时,初始化DeepSpeed的模型、优化器等参数。invoke_model方法用于前向传播,backward方法用于反向传播,step方法用于更新模型参数。

示例:


model = SomeTorchModel()

loss_obj = SomeLossObject()

trainer = pytrainer_deepspeed(model, loss_obj, precision_type=dtype_enum.BF16)

global_state = GlobalStateBoard(...)

epoch_state = EpochStateBoard(...)

step_state = StepStateBoard(...)

trainer.init_deepspeed(global_state, epoch_state, step_state)




**for epoch in range(num_epochs):**




**for batch in dataloader:**

step_state.update(batch)

trainer._drive_batch_data(global_state, epoch_state, step_state)

output = trainer.invoke_model(global_state, epoch_state, step_state)

loss = loss_obj(output, step_state.converted_label)

trainer.backward(global_state, epoch_state, step_state)

trainer.step(global_state, epoch_state, step_state)

注意: 这个类没有明确的错误处理机制,如果传入的参数不符合要求,可能导致错误。

__init__

初始化pytrainer_deepspeed类。

pytrainer_deepspeed是一个继承自pymodel_trainer的类,主要用于训练深度学习模型。该类通过抽象化处理,使得用户可以在不了解深度学习训练细节的情况下,便捷地进行模型训练。

参数:

model (torch.nn.Module): 需要训练的PyTorch模型。

loss_obj: 损失函数对象,用于在训练过程中计算预测值与真实值之间的误差。

precision_type (dtype_enum, 可选): 模型训练的精度类型。默认为dtype_enum.Auto,表示自动选择精度类型。dtype_enum中还包含Float32和BF16两种类型,分别表示使用32位浮点数和16位浮点数进行训练。

**kwargs: 其他参数。

该类的主要方法包括:

init_deepspeed: 初始化DeepSpeed引擎。

invoke_model: 调用模型进行前向传播。

backward: 调用模型进行反向传播。

step: 更新模型参数。

使用示例:

model = torch.nn.Linear(10, 1) \# 假设我们的模型是一个线性模型

loss_obj = torch.nn.MSELoss()  \# 我们使用均方误差作为损失函数

trainer = pytrainer_deepspeed(model, loss_obj, precision_type=dtype_enum.Float32)

_drive_batch_data

这是一个类的方法,负责驱动批次数据。

这个方法的主要功能是将输入数据从CPU转移到GPU上,用于模型的训练计算。

在分布式训练中,由于数据需要在不同的设备上进行计算,因此需要将数据转移到相应的设备上。

参数:

global_state: global_state_board对象,表示全局状态的信息,如模型、优化器等。

epoch_state: epoch_state_board对象,表示当前epoch的状态信息。

step_state: step_state_board对象,表示当前步骤的状态信息,其中包含了当前步骤的输入和标签数据。

返回值:

无返回值。

注意:

这个方法没有返回值,但它会修改step_state对象的converted_input和converted_label属性,

使其指向GPU上的内存地址。

示例:

假设我们有一个名为trainer的pytrainer_deepspeed对象,以及global_state, epoch_state和step_state这三个状态对象,我们可以像下面这样使用这个方法:

trainer._drive_batch_data(global_state, epoch_state, step_state)

在调用这个方法之后,step_state.converted_input和step_state.converted_label将被转移到GPU上,可以被用于模型的训练计算。

init_deepspeed

这个函数是pytrainer_deepspeed类的一部分,用来初始化DeepSpeed模型训练库。

DeepSpeed是微软开源的一个高性能分布式训练库,可以大幅度提升训练速度,同时减少所需的计算资源。这个函数会根据全局状态、时期状态以及步骤状态来设置DeepSpeed的配置参数。

参数:

  • global_state (global_state_board): 全局状态板,包含了训练过程中的全局信息,如模型、优化器等。

  • epoch_state (epoch_state_board): 时期状态板,包含了训练过程中某一轮时期的信息。

  • step_state (step_state_board): 每步训练的状态板,包含了训练过程中某一步的信息。

该函数不返回任何值,但会更改self.model_engineglobal_state.optimizer的值。

使用示例:

trainer = pytrainer_deepspeed(model, loss)

trainer.init_deepspeed(global_state, epoch_state, step_state)

注意:此函数无法单独使用,它依赖于pymodel_trainer类和pytrainer_deepspeed类的其他方法。

注意:在使用此函数前,确保你的环境中已安装了DeepSpeed库。

注意:此函数没有明显的错误或bug,但在使用时要确保传入的状态板(state_board)的类型和值是正确的,否则可能会引发错误。

invoke_model

这是pytrainer_deepspeed类中的一个成员函数,它的主要功能是调用深度学习模型进行前向传播。

参数:

  • global_state:global_state_board类型,包含全局状态信息,如模型、优化器等。

  • epoch_state:epoch_state_board类型,包含单个epoch的状态信息。

  • step_state:step_state_board类型,包含单个训练步骤的状态信息。

返回:

  • 返回模型在输入数据上的前向传播结果。

注意:

这个函数的作用主要是调用模型进行前向传播,并没有与模型相关的其他操作,如更新参数等。

例子:

trainer = pytrainer_deepspeed(model, loss_obj)

trainer.init_deepspeed(global_state, epoch_state, step_state)

output = trainer.invoke_model(global_state, epoch_state, step_state)

backward

这是一个负责执行模型的反向传播的函数。这个函数会使用deepSpeed引擎进行反向传播,从而优化模型的参数。

参数:

  • global_state: global_state_board对象,表示全局状态,包含了全局的设置和参数。

  • epoch_state: epoch_state_board对象,表示当前epoch的状态,包含了当前epoch的设置和参数。

  • step_state: step_state_board对象,表示当前步骤的状态,包含了当前步骤的设置和参数。

返回:

  • 这个函数没有返回值。

例子:

\# 假设已经有了global_state,epoch_state和step_state对象

trainer = pytrainer_deepspeed(model, loss_obj)

trainer.backward(global_state, epoch_state, step_state)

注意:

这个函数不会检查输入参数的合法性,如果输入的参数类型或者值不正确,可能会抛出异常。使用时需要保证输入参数的正确性。

step

这个方法是pytrainer_deepspeed类的一部分,该类是用于训练pytorch模型的。这个类扩展了pymodel_trainer类,添加了对deepspeed库的支持,这是一个用于加速深度学习训练过程的库。

step方法是每一步训练过程中的一个步骤,在每一轮训练的每个批次数据过后调用。

参数:

  • global_state: global_state_board对象,包含全局状态信息,如模型,优化器等。

  • epoch_state: epoch_state_board对象,包含当前epoch的状态信息。

  • step_state: step_state_board对象,包含当前步骤的状态信息。

无返回值。

这个方法的主要任务是执行模型的优化器的step操作,这会更新模型的权重。

注意:这个方法没有明确的错误处理或者异常抛出,如果在执行过程中有任何错误,都会直接导致程序终止运行。

代码示例:

trainer = pytrainer_deepspeed(model, loss_obj)




**for epoch in range(num_epochs):**




**for batch_data in data_loader:**

\# ... 此处省略了一些代码,包括数据预处理,模型前向计算,计算损失等步骤

trainer.step(global_state, epoch_state, step_state)

convert_to_list

将输入的变量转换为列表类型。

这个函数接收一个输入——可以是整数、浮点数、列表或者Tensor,并将其转换为列表。如果输入是整数或浮点数,它将被放入一个新的列表中并返回。如果输入已经是一个列表,那么函数会直接返回它。如果输入是一个Tensor,函数会将其转换为列表,如果Tensor是2D的,它将被flatten(降维)为1D列表。

注意:

假设传入的tensor都是1D或者2D的。

参数:

var: 输入变量,可以是int、float、list、Tensor类型。

返回:

列表,其中包含了原始输入的元素。如果输入是2D Tensor,返回的列表将是flatten后的1D列表。

可能抛出的错误:

ValueError:如果输入的不是int、float、list或Tensor类型,将抛出一个值错误。

plugin_invoke_Enum

plugin_invoke_Enum 是一个Python枚举类,用于表示插件调用的不同状态。

这个枚举类定义了8个成员:

  • Never: 表示插件从未被调用。

  • Epoch_begin: 表示在数据训练的每个epoch开始时调用插件。

  • Epoch_end: 表示在数据训练的每个epoch结束时调用插件。

  • Batch_begin: 表示在数据训练的每个batch开始时调用插件。

  • Batch_end: 表示在数据训练的每个batch结束时调用插件。

  • Begin: 表示在插件的执行开始时。

  • End: 表示在插件的执行结束时。

  • After_Backward: 表示在反向传播之后调用插件。

  • Update: 表示在更新训练模型参数时调用插件。

每个枚举成员都与一个整数值关联,这些值默认从0开始。

例如:

print(plugin_invoke_Enum.Epoch_begin)  \# 输出: Epoch_begin

print(plugin_invoke_Enum.Epoch_begin.value)  \# 输出: 1

注意:Python的枚举类是不可变的,因此不能给枚举成员赋值。

目前未发现此类存在错误或BUG。

device_use_enum

这是一个名为device_use_enum的类,此类继承自Enum枚举类。该类的主要目的是定义设备使用方式的枚举类型。设备使用方式可以是自动(Auto)或者是CPU(CPU)。

在使用这个枚举类时,可以直接通过类名调用枚举值,例如:

device = device_use_enum.Auto




**if device == device_use_enum.Auto:**

print("Device is set to Auto")

请注意,这个类没有明显的错误或bug,但是当需要定义更多设备使用方式时,需要在类内部添加。

dtype_enum

这是一个枚举类dtype_enum,它提供了一种枚举数据类型的方法,为不同的数据类型提供了枚举值。枚举值对应的是不同的数据类型,如自动类型、BF16类型、FP16类型、Float32类型等。

它主要用于标记和识别不同的数据类型,并可以在代码中方便的引用和比较。

示例:

**def process_data(data, dtype):**




**if dtype == dtype_enum.Auto:**

\# do something




**elif dtype == dtype_enum.BF16:**

\# do something




**elif dtype == dtype_enum.FP16:**

\# do something




**elif dtype == dtype_enum.Float32:**

\# do something

该类没有发现明显的错误或者bug。

update_mode_enum

这是一个枚举类update_mode_enum,它是枚举(Enum)类的子类。该类定义了两种更新模式:“Per_Step”和“Per_Epoch”,分别对应数值1和2。这种设计通常用于表示不同的更新策略或模式,使得代码更具可读性和维护性。在调用或使用这个枚举类时,可以使用update_mode_enum.Per_Step或update_mode_enum.Per_Epoch来表示不同的更新模式。

例如:

在某神经网络的训练过程中,我们可以根据需要选择不同的更新模式:





**if update_mode == update_mode_enum.Per_Step:**

\# 每一步都更新模型参数

model.update()




**elif update_mode == update_mode_enum.Per_Epoch:**

\# 每一个epoch结束后更新模型参数

model.update()

这样一来,通过使用本枚举类,我们的代码变得更易于理解和维护。

global_state_board

这是一个名为global_state_board的类,该类的主要作用是在训练过程中进行全局状态的记录和管理。

类的初始化函数需要接收的参数包括:

  • train_obj:训练对象,包含训练所需要的模型、损失函数、优化器等信息。

  • epoch_count:训练迭代的次数。

  • sample_set:训练集样本。

  • pb:进度条对象,用于展示训练进度。

  • input_convert_func:输入转换函数,用于处理输入数据。

  • label_convert_func:标签转换函数,用于处理标签数据。

此外,该类还提供了log方法,用于将日志信息添加到log_stack列表中。

示例:

\# 创建训练对象

trainer = SomeTrainer(model, loss, optimizer, out_folder)

\# 创建全局状态记录板对象

gsb = global_state_board(trainer, epoch_count, sample_set, progress_bar, input_convert_func, label_convert_func)

\# 添加日志

gsb.log('Start training...')

注意:类的使用应在训练流程的控制和管理上下文中,确保提供的训练对象和数据集等信息正确无误。

__init__

global_state_board是一个类,它用于记录和管理训练过程的全局状态信息。

类初始化方法如下:

__init__(self, train_obj, epoch_count, sample_set, pb, input_convert_func, label_convert_func):

参数:

train_obj: 训练对象,通常包含了模型、损失函数等信息。

epoch_count: 训练的轮数。

sample_set: 训练样本集。

pb: 进度条对象,用于在训练过程中显示训练进度。

input_convert_func: 输入转换函数,用于将原始输入数据转换为适合模型训练的格式。

label_convert_func: 标签转换函数,用于将原始标签数据转换为适合模型训练的格式。

该类的主要目标是在训练过程中收集和保存训练状态,包括训练对象的类型、模型、损失函数、优化器、训练的轮数、训练样本集、更新模式、进度条、输入和标签的转换函数等。同时,该类也负责记录训练日志、参数更新次数、插件数据、图表数据以及训练开始的时间。

示例:


\# 创建一个训练对象

train_obj = Trainer(model, loss, optimizer, out_folder)

\# 设置训练轮数

epoch_count = 100

\# 创建训练样本集

sample_set = SampleSet(data, labels)

\# 创建进度条对象

pb = ProgressBar()

\# 定义输入转换函数




**def input_convert_func(input_data):**

return input_data.reshape(-1, 1)

\# 定义标签转换函数




**def label_convert_func(label_data):**

return label_data.reshape(-1, 1)

\# 创建global_state_board对象

gsb = global_state_board(train_obj, epoch_count, sample_set, pb, input_convert_func, label_convert_func)

log

logglobal_state_board类的一个方法.

使用这个方法可以将日志信息添加到log_stack属性中. log_stack是一个列表,用于存储所有的日志信息. 当调用log方法时,传入的日志信息将被添加到列表的尾部.

参数:

  • self: 是指向类实例的引用. 在Python中,它是所有实例方法的第一个参数.

  • lg_s: 一个字符串类型的参数,它代表了要添加到log_stack的日志信息.

返回值:

  • 这个方法没有返回值.

用法:

下面是一个使用log方法的例子.

board = global_state_board(...)

board.log("Training started")

在这个例子中,我们创建了一个global_state_board的实例,然后使用log方法添加了一条日志信息"Training started".

注意:

  • 这个方法没有做任何的错误检查,所以当传入的lg_s不是字符串时,程序可能会崩溃.

epoch_state_board

这是一个名为epoch_state_board的类,用于在训练神经网络时追踪和记录每个训练周期(epoch)的状态。

每个epoch_state_board对象代表一个训练周期,它包含以下属性:

  • epoch_idx: 训练周期的索引值。

  • epoch_loss: 该训练周期的损失值。

  • start_time: 该训练周期的开始时间,表示为Unix时间戳。

  • end_time: 该训练周期的结束时间,表示为Unix时间戳。

  • plugin_datas: 一个字典,用于存储插件数据,键是插件名称,值是插件返回的任何数据。

在每个训练周期开始时,会创建一个新的epoch_state_board对象,然后在训练过程中更新其属性。

该类的主要目的是提供一种方便的方式来记录和访问训练信息,这样可以在训练过程中方便地进行调试和分析。

示例用法如下:

\# 假设我现在是在第5个训练周期

epoch_board = epoch_state_board(5)

\# 在训练过程中,我可以不断更新损失值

epoch_board.epoch_loss += loss_value

\# 当训练周期结束时,我可以记录结束时间

epoch_board.end_time = time.time()

\# 我也可以通过`plugin_datas`属性记录其他信息,如准确率

epoch_board.plugin_datas['accuracy'] = calculate_accuracy()

\# 在后续的代码中,我可以通过`epoch_board`对象来获取训练信息

print(f"Epoch {epoch_board.epoch_idx} loss: {epoch_board.epoch_loss}")

__init__

epoch_state_board类用于管理和跟踪每个训练周期(称为"epoch")的状态。它记录了训练周期的索引、该周期的损失函数值、训练开始和结束的时间以及相关的插件数据。

类参数:

  • epoch_idx:该训练周期的索引(也就是编号)。

类属性:

  • epoch_idx:存储传入的epoch_idx参数,表示当前训练周期的索引。

  • epoch_loss:在每个训练周期开始时,初始化为0,用于累积计算训练周期的总损失。

  • start_time:训练周期开始的时间,使用time.time()获取当前时间。

  • end_time:训练周期结束的时间,初始化为0。

  • plugin_datas:一个用于存储插件数据的字典,初始化为空。

使用示例:

\# 创建一个`epoch_state_board`对象,传入训练周期的索引作为参数

epoch_status = epoch_state_board(epoch_idx=1)

\# 在训练周期中,可以通过`epoch_status.epoch_loss`来累积损失值




**for data in train_data:**

loss = train_step(data)

epoch_status.epoch_loss += loss

\# 在训练周期结束后,可以通过`epoch_status.end_time`来设置结束时间

epoch_status.end_time = time.time()

注意:目前没有发现代码中存在的错误或bug。

step_state_board

step_state_board类是一个用于记录并跟踪训练过程中每一步的状态的类。这个类可以用来为每个训练步骤保存相关的信息,例如开始和结束时间、损失值等。这些信息有助于进一步的分析和调试。

该类的属性包括:

  • batch_idx:当前批次的索引

  • ori_item:原始的批次数据

  • start_time:当前步骤开始的时间

  • end_time:当前步骤结束的时间

  • converted_input:转换后的输入数据,由global_state.input_convert_func(batch_item)得到

  • converted_label:转换后的标签数据,由global_state.label_convert_func(batch_item)得到

  • logit:预测的输出

  • loss_value:当前步骤的损失值

  • loss_tensor:当前步骤的损失张量

  • plugin_datas:插件数据,用于保存额外的信息

使用方式如下:

global_state = some_global_state_board()  \# 初始化一个全局状态




**for i, data in enumerate(dataloader):**

step_state = step_state_board(i, data, global_state)  \# 创建步骤状态板

...

\# 在训练过程中更新步骤状态板的信息

step_state.end_time = time.time()

step_state.loss_value = some_loss

...

注意,converted_inputconverted_label的值应由global_state.input_convert_func(batch_item)global_state.label_convert_func(batch_item)得到,但在\_\_init\_\_方法中并未被赋值,需要在后续的训练过程中手动赋值。

__init__

step_state_board 是一个用于描述批处理状态的类。它用于跟踪批处理中的各种信息,如批处理索引、原始项目、开始和结束时间、转换后的输入和标签、日志、损失值、损失张量和插件数据等。

类的使用示例:

\# 初始化一个全局状态

global_state = global_state_board()

\# 初始化一个批处理状态

batch_state = step_state_board(batch_idx=0, batch_item=data, global_state=global_state)

构造函数 \_\_init\_\_ 的参数:

  • batch_idx : int 类型,批处理的索引。

  • batch_item : 数据类型不限,是批处理的原始数据项。

  • global_state : global_state_board 类型,描述全局的状态。

构造函数 __init__ 不返回任何值,它的目的是初始化 step_state_board 类的实例。

注意:目前函数里的 global_state.input_convert_func(batch_item)global_state.label_convert_func(batch_item) 两行代码被注释掉了,可能会影响 self.converted_inputself.converted_label 的赋值,需要根据实际情况决定是否启用这两行代码。

invoke_at

此函数是一个装饰器生成器,用于为函数添加一个特殊属性 _invoke_at,以便后续在特定的插件调用时识别和处理。

参数:

types (list): 包含plugin_invoke_Enum枚举类的列表,用于指示函数在哪些类型的插件调用时被触发。

返回:

返回一个装饰器,这个装饰器可以被用于装饰其他函数,给他们附加_invoke_at属性。

使用示例:

@invoke_at([plugin_invoke_Enum.Type1, plugin_invoke_Enum.Type2])

def some_function():

pass

上述代码会给 some_function 函数附加一个属性_invoke_at,其值为[plugin_invoke_Enum.Type1, plugin_invoke_Enum.Type2]。

注意事项:

  • 请确保 types 列表中的元素都是 plugin_invoke_Enum 枚举类的实例。

  • 被此函数装饰的函数在执行时,其实际行为不会被改变,即它仍然会按照原代码执行。

trainer_plugin_base

这是一个抽象基类,提供训练插件的基础结构和工具。这个类主要有两个方法:get_plugin_mapInvoke

# 类介绍

此类作为所有训练插件的基类,定义了训练插件的基础结构和行为。提供了一个接口,供子类实现特定的训练行为。

# 使用例子

class MyTrainerPlugin(trainer_plugin_base):

def Invoke(self, global_state: global_state_board, epoch_state: epoch_state_board, step_state: step_state_board):

# 实现特定的训练行为

# 方法介绍

  • get_plugin_map:返回一个字典,字典的键是枚举类型,值是方法列表,这些方法在该枚举类型下被调用。此方法用于整理和提供训练插件的调用行为。

  • Invoke:这是一个抽象方法,需要在子类中实现。此方法在训练过程中被调用,根据当前的全局状态、纪元状态和步骤状态进行特定的训练行为。

# 参数列表

  • global_state:全局状态板,提供了全局的状态信息,如训练总的纪元数、当前纪元数等。

  • epoch_state:纪元状态板,提供了当前纪元的状态信息,如纪元的开始时间、结束时间等。

  • step_state:步骤状态板,提供了当前步骤的状态信息,如步骤的开始时间、结束时间等。

# 返回类型介绍

  • get_plugin_map:返回一个字典,键是枚举类型,值是方法列表。

  • Invoke:没有返回值。

# 注意事项

  • 该类是一个抽象基类,不能直接实例化,只能作为基类被继承。

  • Invoke方法必须在子类中实现。

get_plugin_map

这个方法用于获取插件映射表。插件映射表是一个字典,其键是枚举值,对应的值是一个包含所有具有此枚举值的方法的列表。

插件映射表的获取过程如下:首先,我们遍历这个类的所有属性,然后检查每个属性是否具有'_invoke_at'属性。如果具有'_invoke_at'属性,我们就把这个属性看作是一个方法,并且将其'_invoke_at'属性的每个元素都看作是一个枚举值。然后,我们将这些枚举值和对应的方法添加到插件映射表中。

参数:

返回类型:

一个字典,其键是枚举值,对应的值是一个包含所有具有此枚举值的方法的列表。

注意:

如果同一个枚举值对应多个方法,那么这些方法将被添加到同一个列表中,并将此列表作为这个枚举值在插件映射表中的值。

示例:

如果我们有一个类,其内部定义了两个方法method1和method2,它们的'_invoke_at'属性都包含了同一个枚举值enum1,那么在调用get_plugin_map方法后,将返回一个插件映射表,其内容将是{enum1: [method1, method2]}。

Invoke

这是一个在基类trainer_plugin_base中的方法,名为Invoke。这个方法是一个抽象方法,用于在子类中被重写。

参数:

  • global_state (global_state_board): 全局状态对象,用于存储和传递全局的状态信息

  • epoch_state (epoch_state_board): epoch状态对象,用于存储和传递当前epoch的状态信息

  • step_state (step_state_board): step状态对象,用于存储和传递当前step的状态信息

返回:

这个方法没有返回值。

注意,这个方法过于抽象,具体的行为需要在子类中实现,因此可能会有不同的行为和返回值。如果你在使用这个方法的过程中遇到问题,可能是由于在子类中没有正确地重写这个方法。

例子:

**class MyTrainerPlugin(trainer_plugin_base):**




**def Invoke(self, global_state, epoch_state, step_state):**

\# 在这里实现你的逻辑

print(global_state, epoch_state, step_state)

在这个例子中,我们创建了一个新的训练插件MyTrainerPlugin,并重写了Invoke方法。在我们的实现中,我们只是简单地打印出了传入的状态信息。

pytrainer_accelerate

这是一个名为pytrainer_accelerate的类,继承自pymodel_trainer。该类的目标是提供一种在使用PyTorch框架和Accelerator库在GPU或CPU上训练模型的方式。用户可以通过指定参数,灵活地控制训练过程,如是否使用CPU进行训练,指定使用哪些GPU,以及每个GPU的最大使用率等。

该类在初始化时,将传入的模型和损失函数,以及其他参数绑定到自身。同时根据指定的设备类型创建Accelerator实例,并将模型分配到相应的设备上。

类中定义的_drive_batch_data方法,在每个训练批次开始时被调用,用于将输入和标签数据转移到用于训练的设备上。

backward方法则在计算反向传播时被调用,用于计算损失函数的梯度。

使用示例:

model = torch.nn.Sequential(...)

loss_obj = torch.nn.CrossEntropyLoss()

trainer = pytrainer_accelerate(model, loss_obj, use_cpu=False, use_gpu_ids=[0, 1])




**for epoch in range(num_epochs):**




**for batch in dataloader:**

trainer._drive_batch_data(...)

trainer.backward(...)

参数:

  • model (torch.nn.Module):待训练的模型。

  • loss_obj:用于模型训练的损失函数。

  • use_cpu (bool):是否使用CPU进行训练。默认为False。

  • use_gpu_ids (list):要用于训练的GPU的ID列表。如果是None,则使用所有可用的GPU。默认为None。

  • no_split_module_classes (list):不进行模型分割的模块类别列表。默认为空列表。

  • max_use_per_gpu (float):每个GPU的最大使用率,值在0.0到1.0之间,默认为1.0。

  • kwargs:其他参数。

注意点:

  • 当use_cpu为True时,无论use_gpu_ids的值为何,都只会使用CPU进行训练。

  • 当use_gpu_ids为None时,将使用所有可用的GPU进行训练,但是GPU的使用率仍然受max_use_per_gpu参数的限制。

  • 当指定no_split_module_classes时,对应的模块在分配到设备时,将保持整体,不会被分割。

__init__

pytrainer_accelerate是一个pytorch模型训练加速器类,继承自pymodel_trainer。它利用了Accelerator库来配合GPU或CPU进行分布式训练,并在模型训练的过程中对数据进行优化处理,提高模型训练速度和效率。

参数:

  • model: 需要被训练的torch.nn.Module模型对象。

  • loss_obj: 损失函数对象。

  • use_cpu: 布尔值,是否使用CPU进行训练。默认值为False。

  • use_gpu_ids: 列表,需要使用的GPU的ID。如果为None,则自动选择GPU。默认值为None。

  • no_split_module_classes: 列表,不希望分割的模块类。默认值为空列表。

  • max_use_per_gpu: 每个GPU最大使用的内存百分比,范围是0-1。默认值为1.0。

  • kwargs: 其他参数。

方法:

  • _drive_batch_data: 在每个批次训练开始时调用,将数据移动到加速器设备(CPU或GPU)上。

  • backward: 对损失函数进行反向传播。

示例:

model = torch.nn.Linear(10, 1)

loss_obj = torch.nn.MSELoss()

trainer = pytrainer_accelerate(model, loss_obj, use_cpu=False, use_gpu_ids=[0, 1])

在上述示例中,我们创建了一个线性模型和一个均方差损失函数,然后使用pytrainer_accelerate类来加速模型训练,我们选择了使用ID为0和1的两个GPU进行训练。

_drive_batch_data

这个类为 pytrainer_accelerate,它继承了 pymodel_trainer 类,主要针对模型的训练过程进行加速处理。主要通过使用 Accelerator 对象和 device_map 实现模型处理设备的自动调度,并在每个批次数据处理前,将数据转移到合适的处理设备。

函数 _drive_batch_datapytrainer_accelerate 的一个成员函数,它在每个批次数据处理前被调用。这个函数的主要作用是将数据(输入和标签)转移到合适的处理设备(CPU或CUDA设备)。数据的转移过程都在 move_to_drive 这个内部函数中完成,支持数据为 dict 类型或者其他类型。转移后的数据会保存在 step_state 对象的 converted_inputconverted_label 属性中。

_drive_batch_data 函数的参数列表如下:

  • global_state: global_state_board:全局状态板,用于存放全局状态信息。

  • epoch_state: epoch_state_board:当前周期状态板,用于存放当前周期的状态信息。

  • step_state: step_state_board:当前步骤状态板,用于存放当前步骤的状态信息。

这个函数没有返回值。

backward

backwardpytrainer_accelerate类的一个方法。pytrainer_accelerate类是一个模型训练器,它继承自pymodel_trainer类,并添加了accelerate库的支持,可以在CPU和GPU上进行高效训练。

它通过model参数接受一个pytorch模型,并通过loss_obj参数接受一个损失对象。此外,还可以通过参数use_cpuuse_gpu_idsno_split_module_classesmax_use_per_gpu等来指定训练的硬件环境和模型的部署方式。

该类的实例化方法通过Accelerator类创建一个加速器,并根据给定的硬件环境和模型部署方式将模型部署到对应的设备上。

backward方法是用于执行反向传播的。它接受三个参数,分别是global_stateepoch_statestep_state,这些状态对象包含了训练过程中的全局状态、当前epoch的状态和当前步骤的状态。它通过acceleratorbackward方法来执行反向传播,并更新状态对象中的损失张量。

参数:

  • global_state (global_state_board): 全局状态板,包含了训练过程中的全局状态。

  • epoch_state (epoch_state_board): epoch状态板,包含了当前epoch的状态。

  • step_state (step_state_board): 步骤状态板,包含了当前步骤的状态。

返回:

示例:

\# 创建一个模型和损失对象

model = torch.nn.Linear(10, 1)

loss_obj = torch.nn.MSELoss()

\# 创建一个训练器

trainer = pytrainer_accelerate(model, loss_obj, use_cpu=True)

\# 创建状态板

global_state = global_state_board(...)

epoch_state = epoch_state_board(...)

step_state = step_state_board(...)

\# 执行反向传播

trainer.backward(global_state, epoch_state, step_state)

trainerplugins

该Python模块主要用于机器学习模型的训练过程中的监控、调试、保存和评估。具体包括:记录和检查模型参数梯度,便于调试和监控训练过程;在训练过程中保存模型,可以根据步骤或损失值决定保存时机;加载并校验模型参数,将其应用到指定的PyTorch模型;实时监控和展示训练过程中的参数;评估训练集分类模型的性能,并绘制结果图表;记录和管理训练过程中的日志信息。

类成员:

name info
grad_log grad_log类是用于在训练过程中检查和记录模型参数梯度的工具,以便于调试和监控训练过程。
model_save 这是一个在训练过程中进行模型保存的插件,能根据步骤或损失值来决定保存模型。
dashboard_plugin "dashboard_plugin"是一个用于实时监控和展示训练过程参数的可视化仪表板类,基于flask框架提供web服务。
classification_evalution_trainset 这是一个用于评估训练集分类模型性能(包括准确率、精度、召回率和F1分数)并绘制结果图表的类。
log_plugin 'log_plugin'是一个用于在训练过程中记录和管理日志信息的类。

函数成员:

name info
load_saved_parameters 这是一个加载并校验模型参数,将其应用到指定PyTorch模型的函数。

grad_log

这个类是grad_log,它是trainer_plugin_base的子类,这个类用于在训练过程中检查模型参数的梯度,以便于调试和监控模型的训练过程。

主要的功能和用法如下:

  1. 在训练开始(Invoke2)和每个batch结束后(Invoke),检查模型的参数,如果参数的绝对值超过设定的阈值,就用_format_table函数将参数的信息格式化后记录下来。

  2. _format_table函数会将参数的信息(包括形状、最小值、最大值等)格式化为字符串,方便后续查看。

使用示例:

logger = grad_log(batch_check_freq=10)

trainer.add_plugin(logger)

trainer.train()

初始化的参数:

  • init_model_check:是否在训练开始时检查模型参数,默认为True。

  • batch_check_freq:每几个batch检查一次模型参数,默认为3。

  • check_fp16:是否检查16位浮点数,默认为True。

  • warning_level:设置梯度的警告级别,默认为1。

函数:

  • \_\_init\_\_:初始化函数,设置一些参数和阈值。

  • _in_range:检查值是否在设定的范围内。

  • _format_table:格式化参数的信息。

  • Invoke2:在训练开始时调用,检查模型参数。

  • Invoke:在batch结束后调用,检查模型参数。

注意:

这个类没有返回值,主要用于在训练过程中打印和记录参数信息。

__init__

初始化 grad_log 类实例。

grad_log 是一个用于检查 model 的参数和梯度值是否在安全范围内的 trainer 插件。如果参数或梯度值超出范围,将以指定的警告级别打印警告。该类可帮助我们了解模型训练过程中参数和梯度的变化情况,以便调整学习率等超参数,优化模型训练过程。

参数:

init_model_check (bool): 是否在训练开始时检查模型的初始参数值,默认为 True。

batch_check_freq (int): 每隔几个 batch 检查一次参数和梯度值,默认为 3。

check_fp16 (bool): 是否检查半精度浮点数 (float16) 的参数和梯度值,默认为 True。如果为 False,将检查单精度浮点数 (float32)。

warning_level (int): 警告级别,默认为 1。级别越高,安全范围越小,警告越频繁。

使用示例:

from trainer_plugin_base import trainer_plugin_base




**class my_trainer(trainer_plugin_base):**




**def \_\_init\_\_(self):**

super().\_\_init\_\_()

self.plugin = grad_log(init_model_check=True, batch_check_freq=1, check_fp16=False, warning_level=2)

在这个示例中,我们初始化了一个 my_trainer 类,它继承了 trainer_plugin_base 并使用 grad_log 插件。在每个 batch 训练结束后,都将检查模型的参数和梯度值,并以较高的警告级别打印警告。我们不检查 float16 参数,只检查 float32 参数。

_in_range

这是一个工具函数,它接受一个或多个值作为输入,并检查这些值是否在预先定义的范围内。

这个函数的主要目的是在梯度下降训练过程中,对模型参数和梯度进行检查,以确保它们没有超出浮点数可以表示的范围,防止因溢出等问题导致的训练失败。

参数:

*values (float): 一个或多个需要检查的浮点数值。

返回:

bool: 如果所有输入值都在预定义的范围内,返回True;否则返回False。

使用示例:

check_result = self._in_range(max_value, min_value)




**if not check_result:**

\# 如果检查失败,打印警告信息或进行其他处理

print("Warning: some values are out of range.")

注意:这个函数不会对输入值进行任何修改或处理,只进行范围检查。

_format_table

这是一个私有方法,用于格式化表格中的每一行数据。方法首先遍历每个单元格数据,判断该数据的类型,如果数据是浮点型,它会首先检查该数据是否在预设范围内,如果在范围内,精确到小数点后八位,否则,将其标记为异常值(以 "**" 包围)。如果数据是列表类型,则将每个元素转换为字符串,并用逗号连接。如果数据类型既不是浮点型也不是列表类型,则直接将该数据转换为字符串。

参数:

row_data: 数据列表,包含要格式化的行数据。

返回:

formatted_row: 格式化后的行数据列表。

例子:

假设_row_data = [123.456789, [1,2,3], "测试"]

经过_format_table方法处理后,得到formatted_row = ['123.45678900', '1,2,3', '测试']

注意:

这个方法主要被类方法Invoke和Invoke2使用,而不应该直接被调用。

Invoke2

Invoke2grad_log 类的一个方法,它在训练过程中的某个指定时点被调用。其主要目的是检查模型中各个参数的值是否在合理范围内,并将那些超出范围的参数进行记录和显示。

该方法接收三个参数,分别为 global_stateepoch_state,和 step_state

参数:

  • global_state (类型:global_state_board):一个存储全局状态的对象,包括了整个训练过程的相关信息,例如模型参数、优化器状态等。

  • epoch_state (类型:epoch_state_board):一个存储当前训练轮次(epoch)状态的对象,包括了当前轮次的相关信息,例如当前轮次的训练损失、准确率等。

  • step_state (类型:step_state_board):一个存储当前训练步骤(step)状态的对象,包括了当前步骤的相关信息,例如当前步骤的训练损失、梯度值等。

返回值:

  • 此方法无返回值。

在执行过程中,该方法首先检查每个需要进行梯度更新的模型参数,计算其均值、方差、最大值和最小值。然后判断这些值是否在设定的阈值范围内。如果参数值超出范围,则将该参数的名称、形状、最小值、最大值、均值、方差等信息添加到rows列表中。最后,如果rows列表不为空,即存在参数值超出范围的情况,则调用print_table函数,将这些信息以表格的形式打印出来。

注意:这个方法并不会改变模型参数的值,只是进行检查并输出超出范围的参数信息。如果希望在参数值超出范围时进行某种处理,需要在这个方法中添加相应的代码。

Invoke

这是一个触发方法,用于在每个训练批次结束时进行调用。该方法主要用于检查模型参数的梯度,包括梯度的最大值、最小值等。如果检测到的梯度值超出预定义的范围,该方法将对梯度值进行格式化并打印出来。

参数:

  • global_state(global_state_board): 全局状态板,包含模型的全局状态信息,如模型的参数、优化器等。

  • epoch_state(epoch_state_board): epoch状态板,包含当前epoch的状态信息,如当前epoch的损失值、准确率等。

  • step_state(step_state_board): step状态板,包含当前step的状态信息,如当前step的损失值、准确率等。

无返回值。

其中一段代码示例如下:

**for group in global_state.optimizer.param_groups:**




**for param_tensor in group['params']:**

max_value = param_tensor.data.max().item()

mini_value = param_tensor.data.min().item()

shape = list(param_tensor.data.size())

grad_mini = param_tensor.grad.min().item() if param_tensor.grad is not None else 0

grad_max = param_tensor.grad.max().item() if param_tensor.grad is not None else 0




**if not self._in_range(max_value, mini_value, grad_mini, grad_max):**

rows.append(self._format_table([shape, mini_value, max_value, grad_mini, grad_max]))

以上代码遍历优化器的参数组,对每一个参数张量,计算其数据的最大值、最小值以及梯度的最大值、最小值,然后判断这些值是否在允许的范围内,如果不在,则将这些值进行格式化并添加到rows列表中。

此方法没有已知的错误或bug。

model_save

这个类是一个模型的保存插件,在训练的过程中,不同的步骤下进行模型的保存。它继承了trainer_plugin_base基类。

主要功能:

  • 在训练过程的每个步骤结束时,每隔一定步数(save_per_step)进行模型保存

  • 在每一个训练周期结束的时候,如果该周期的损失值小于之前保存过的模型的损失值,那么就保存该模型

  • 具备模型保存路径不存在时,创建路径的能力

主要方法:

  • \_\_init\_\_(self, save_per_step=1, mini_loss_save=True, save_folder='checkpoint'):初始化方法,设置保存步长、是否保存最小损失模型以及模型保存路径

  • BeginInvoke(self, global_state: global_state_board, epoch_state: epoch_state_board, step_state: step_state_board):在训练开始时,创建模型保存的文件夹

  • save_model(self, global_state: global_state_board, path):保存模型的方法,会根据训练器的类型,选择不同的保存方法

  • Batch_end(self, global_state: global_state_board, epoch_state: epoch_state_board, step_state: step_state_board):在每个批次训练结束时,每隔一定步数进行模型的保存

  • epoch_end(self, global_state: global_state_board, epoch_state: epoch_state_board, step_state: step_state_board):在每个训练周期结束时,如果该周期的损失值小于之前保存过的模型的损失值,那么保存该模型

  • Invoke(self, base_wall, epoch_wall, batch_wall):在插件调用时被执行,但在该类中并未被实现

例子:

trainer = Trainer(...)

plugin = model_save(save_per_step=100, save_folder='./checkpoint')

trainer.add_plugin(plugin)

trainer.train(...)

注意事项:

  • 由于Invoke方法在该类中并未被实现,所以如果有需要在插件调用时执行的操作,请在子类中重写此方法。

__init__

初始化model_save类的实例对象。

此类是训练模型时的一个插件,用于保存训练过程中的模型。其中,每一步训练后都会保存一次模型,并且在每个周期结束后,如果损失降低,也会保存模型。

参数:

save_per_step(int, 默认值为1): 每训练多少步,保存一次模型。

mini_loss_save(bool, 默认值为True): 是否在每个训练周期结束后,如果损失有所降低,保存模型。

save_folder(str, 默认值为'checkpoint'): 保存模型的文件夹名称。

返回:

使用示例:

save_plugin = model_save(save_per_step=50, mini_loss_save=True, save_folder='model_checkpoints')

在这个示例中,model_save的实例在每50步训练后保存一次模型,如果训练周期结束后,损失有所降低,也会保存模型,所有的模型都保存在'model_checkpoints'文件夹下。

注意:

此类没有返回值。

BeginInvoke

这是一个model_save类中的BeginInvoke函数,其作用是在训练开始时创建模型保存的文件夹。

参数:

global_state (global_state_board): 存储全局状态相关信息的实例,包括当前模型、优化器等信息。

epoch_state (epoch_state_board): 存储当前周期(epoch)相关状态的实例,如当前周期的损失等。

step_state (step_state_board): 存储当前步骤相关状态的实例,如当前步骤的损失、正确率等。

返回:

无返回值

使用示例:

在训练开始时,首先创建一个model_save类的实例,然后调用这个函数来创建模型保存的文件夹。


save_plugin = model_save(save_per_step=100, save_folder='checkpoint')

save_plugin.BeginInvoke(global_state, epoch_state, step_state)

注意事项:

该函数无返回值,主要作用是创建保存模型的文件夹,如果文件夹已经存在则不会重复创建。

该函数没有错误处理,如果在创建文件夹时出现错误(如权限问题、磁盘空间不足等),可能会抛出异常。

save_model

本函数用于模型的保存。根据训练器的类型,将模型的需要优化的参数保存在指定的路径下。

参数:

  • global_state: global_state_board类型,包含全局状态信息,如训练器类型、模型等。

  • path: 字符串类型,用于指定模型保存的路径。

返回:

  • 无返回值

注意:本函数不会检查路径是否合法、是否有权限等,这些需要在调用前确认。

Batch_end

Batch_end是一个函数,其属于model_save类的方法。此函数在每个batch训练结束后被调用,用于定期保存当前全局状态下的模型参数。如果全局状态下的步数可以被预设的保存步数整除,则会保存模型。

参数:

  • global_state (global_state_board): 表示全局状态的对象,包含了模型、优化器等相关信息。

  • epoch_state (epoch_state_board): 表示当前epoch(训练周期)状态的对象,包含了当前epoch的损失、准确率等信息。

  • step_state (step_state_board): 表示当前步骤状态的对象,包含了当前步骤的损失、准确率等信息。

该函数没有返回值。

例如,如果设定save_per_step=100,那么每进行100步训练,函数就会执行以下操作:

  1. 构建保存路径,路径包括模型文件夹、保存文件夹和步数信息。

  2. 调用save_model方法保存模型

  3. 将保存信息添加到日志堆栈

注意:此函数不保证每次都能够成功保存模型,只有当全局状态下的步数能够被预设的保存步数整除时,才会进行保存。

epoch_end

此方法是model_save类的一个方法,旨在在每个训练周期结束时执行特定的操作。主要功能是在训练周期结束时,如果周期损失低于当前最小损失,则保存模型。

参数:

  • global_state (global_state_board类型): 全局状态对象,包含训练过程中全局的状态信息,如模型、优化器等。

  • epoch_state (epoch_state_board类型): 周期状态对象,包含训练过程中每个周期的状态信息,如周期损失等。

  • step_state (step_state_board类型): 步骤状态对象,包含训练过程中每个步骤的状态信息。

此方法没有返回值。

注意:

  • 使用此方法需要确保global_stateepoch_statestep_state对象已经被正确初始化,并在运行过程中被正确更新。

  • 此方法会改变self.mini_loss的值,用于记录当前训练过程中的最小损失值。

  • 如果保存的模型文件夹已经存在,则会删除文件夹中的所有文件,然后保存新的模型。

  • 此方法不会处理任何异常,如果在运行过程中出现异常,如文件读写错误等,需要在调用此方法的地方进行捕获和处理。

Invoke

Invoke方法是model_save类的一个空方法,它在这个类中没有具体的实现,并且在类的其他地方也没有被调用。这个方法可能是一个占位符,留待未来实现某些功能时使用。

参数:

  • base_wall: 未在类中使用,可能是未来实现某些功能时使用。

  • epoch_wall: 未在类中使用,可能是未来实现某些功能时使用。

  • batch_wall: 未在类中使用,可能是未来实现某些功能时使用。

返回:

  • 无返回类型

注意:

  • 此函数当前未实现任何功能,也未在类的其他地方被调用,可能存在一些占位的目的。

load_saved_parameters

该函数用于加载保存的模型参数。

参数:

model: PyTorch模型,该模型是需要加载参数的模型。

path: str,保存模型参数的路径。

该函数从给定的路径加载模型参数,并将它们加载到给定的模型中。函数首先确定模型参数的设备位置,

然后根据这个设备位置加载模型参数。如果加载的参数中包含'module'关键字,我们将直接使用该关键字对应的参数。

在加载过程中,该函数会检查每一个参数的形状是否与保存的参数形状相匹配。如果不匹配,将会打印一条警告日志信息。

最后,函数会打印加载了多少个参数。

注意:

  1. 此函数没有返回值,它直接修改传入的模型参数。

  2. 如果参数加载过程中出现任何错误,函数会抛出异常。

用法示例:

model = SomePyTorchModel()

path = "/path/to/saved/parameters"

load_saved_parameters(model, path)

dashboard_plugin

这是一个名为dashboard_plugin的类,它继承自trainer_plugin_base。这个类的主要目的是为训练过程提供一个可视化的仪表板,用于实时监视训练过程中的变化,如损失函数值、参数更新次数、模型的训练进度等。它使用了flask框架来启动一个web服务,使用者可以在本地浏览器中查看训练过程。这个类还提供了一些辅助函数,比如时间戳的转换、数据的缩放等。

使用这个类时,只需要在训练脚本中创建一个dashboard_plugin的实例,然后在训练循环中调用其相关方法即可。例如:

dashboard = dashboard_plugin(port=8080)




**for epoch in range(epochs):**




**for batch in dataloader:**

...

dashboard.step_cost_cal(global_state, epoch_state, step_state)

dashboard.epoch_end(global_state, epoch_state, step_state)

dashboard.start(global_state, epoch_state, step_state)

类内方法描述:

  1. \_\_init\_\_(self, port=None): 这是初始化方法,创建一个dashboard_plugin的实例。参数port是可选的,表示flask服务监听的端口号,默认为None。

  2. convert_time(self, timestamp): 这个方法用于将时间戳转换为本地时间。

  3. start(self, global_state: global_state_board, epoch_state: epoch_state_board, step_state: step_state_board): 这个方法在训练开始时调用,用于初始化一些状态和启动flask服务。

  4. _scale_data(self, data: list): 这是一个私有方法,用于缩放数据,使得它们能在仪表板上更好地显示。

  5. refresh_page(self, global_state: global_state_board): 这个方法在每次训练轮次结束时调用,用于刷新仪表板的显示。

  6. epoch_end(self, global_state: global_state_board, epoch_state: epoch_state_board, step_state: step_state_board): 这个方法在每个训练轮次结束时调用,用于记录本轮次的损失函数值。

  7. start_flask_server(self, global_state: global_state_board): 这个方法用于启动flask服务。

  8. step_cost_cal(self, global_state: global_state_board, epoch_state: epoch_state_board, step_state: step_state_board): 这个方法在每个训练步骤结束时调用,用于计算本步骤的耗时。

  9. epoch_cost_cal(self, global_state: global_state_board, epoch_state: epoch_state_board, step_state: step_state_board): 这个方法在每个训练轮次结束时调用,用于计算本轮次的耗时。

注意:这个类目前没有发现明显的bug或错误,但需要注意的是,如果训练数据量过大,可能会引起flask服务的内存溢出问题。

__init__

初始化dashboard_plugin类。这个类用于监测训练过程,可以实时查看训练过程中的损失、精度等信息,支持通过web页面查看。

参数:

port: int或None, 可选参数, 默认为None.

用于开启flask服务器的端口号。如果为None,将不会开启flask服务器。

属性:

epoch_data: list, 用于存储每个epoch的数据。

port: int或None, flask服务器的端口号。

step_count: int, 统计训练过程中的步数。

step_cost: float, 统计每步的耗时。

epoch_count: int, 统计训练过程中的epoch数。

epoch_cost: float, 统计每个epoch的耗时。

process_batch_count: int, 每次处理的batch数。

process_batch_index: int, 当前正在处理的batch的索引。

buffer_global_state: global_state_board对象, 用于存储全局状态。

template: jinja2模板对象, 用于渲染web页面。

示例:

db_plugin = dashboard_plugin(port=5000) # 创建一个dashboard_plugin对象,并指定flask服务器的端口为5000

convert_time

这是一个将时间戳转换为本地时间字符串的函数。

参数:

timestamp (int): 时间戳,代表某一时刻的秒数。

返回:

str: 本地时间字符串,格式为YYYY-MM-DD HH-MM-SS。

使用示例:

convert_time(1609459200)

'2021-01-01 08-00-00'

注意:

该函数默认的时区是东八区,即中国北京时间。

start

这是一个名为"start"的成员函数,它是dashboard_plugin类的一部分。这个函数在插件开始时被调用,用于初始化一些参数并可能启动一个Flask服务器。

参数:

global_state (global_state_board): 一个global_state_board对象,包含全局状态信息,如batch_count等。

epoch_state (epoch_state_board): 一个epoch_state_board对象,包含了当前时间步的训练状态信息。

step_state (step_state_board): 一个step_state_board对象,包含了每个训练步骤的状态信息。

返回:

无返回值。

异常:

无异常。

示例:

假设我们有一个dashboard_plugin对象dp,全局状态对象gs,时间步状态对象es和步骤状态对象ss,可以如下调用此函数:

dp.start(gs, es, ss)

注意:

这个函数可能会启动一个Flask服务器,这取决于端口是否被设置。如果端口被设置,Flask服务器将被启动用于处理web请求。

_scale_data

这个方法是用来对数据进行缩放的。其作用是对一个比较大的数据集进行缩放,使得它的长度可以适应图表的绘制。如果数据集的大小超过200,那么它会取第一个数据,然后对剩余的数据每两个进行一次平均,最后取最后一个数据。如果数据集的大小没有超过200,那么直接返回原数据集。

参数:

data: list,需要进行缩放处理的数据集。

返回:

list,长度被缩放后的数据集。

refresh_page

这个方法用于刷新训练进度的可视化仪表板页面。

参数:

global_state (global_state_board): 一个全局状态对象,包含所有全局级别的信息,如模型、优化器等的状态。

返回:

我们首先将历史的epoch数据进行缩放以适应可视化图表。然后,我们从全局状态对象中获取训练的相关信息,并且将它们组织成一个字典。接着,我们用这些信息来生成一个新的HTML页面,该页面可以显示在训练过程中的各种信息,包括模型的训练进度、损失和精度等。最后,我们将生成的HTML保存到指定的文件中。

这个方法没有返回值,它的目的是生成一个新的HTML页面来展示训练的状态。这个页面可以在训练过程中不断刷新,以达到实时查看训练进度的效果。

epoch_end

此函数的主要功能是在每个周期结束时更新相关状态信息。

函数参数:

  • global_state (global_state_board): 保存全局状态信息的类实例。

  • epoch_state (epoch_state_board): 保存周期状态信息的类实例。

  • step_state (step_state_board): 保存步骤状态信息的类实例。

返回值: 无

注:此函数没有返回值,它主要是在每个训练周期结束时,更新保存在实例中的状态信息,包括周期损失等。并且,该函数也会把当前的全局状态信息保存到buffer_global_state变量中,以备后续使用。

start_flask_server

此方法用于启动一个Flask服务器,以便于查看模型训练的实时报告。

参数:

global_state (global_state_board):

global_state对象,包含了模型训练的全局状态信息。

此方法不返回任何值。

方法中主要步骤如下:

  1. 创建一个Flask应用实例,设置静态文件夹为模型保存的文件夹。

  2. 设定Flask应用的日志级别,设置为ERROR,只有错误信息会被记录。

  3. 定义Flask应用的路由。当访问服务器的根目录('/')时,会调用serve_dashboard函数。该函数会刷新报告页面,并返回生成的HTML文件。

  4. 启动一个新线程来运行Flask应用,以便于主程序继续执行模型训练,而不会被阻塞。

注意:

这个方法会在主程序中开启一个新的线程来运行Flask服务器,以便于主程序继续执行模型训练。但是因为Flask服务器和主程序共享了全局状态对象,所以在多线程环境下可能会出现数据竞争的问题。目前的代码中并没有看到对全局状态对象的写操作,所以应该不会出现数据竞争的问题。但是如果后续有对全局状态对象的写操作,需要注意线程安全问题。

此外,如果同时启动了多个训练任务,由于所有任务都在同一个端口启动服务器,可能会出现端口冲突问题。可以考虑让用户在启动任务时指定端口,或者动态分配端口以避免冲突。

step_cost_cal

这个函数是用来计算每个步骤的耗时并更新处理批次的索引的。

参数:

global_state (global_state_board): 全局状态板,保存了全局状态的信息。

epoch_state (epoch_state_board): epoch状态板,保存了epoch状态的信息。

step_state (step_state_board): 步骤状态板,保存了步骤状态的信息。

返回值:

无返回值

epoch_cost_cal

该函数主要用于计算每个epoch的运算时间成本。该函数会在每个epoch结束时被调用。

参数:

global_state: global_state_board类的一个实例,表示全局状态信息。其中包含了训练过程中的各种全局性的信息。

epoch_state: epoch_state_board类的一个实例,表示当前epoch状态信息。其中包含了当前epoch的训练信息,如epoch的损失等。

step_state: step_state_board类的一个实例,表示当前step状态信息。其中包含了当前训练步骤的信息,如每步的开始和结束时间等。

返回:

无返回值。但是会修改类的属性,计算并更新self.epoch_count(完成的epoch数量)和self.epoch_cost(完成所有epoch花费的时间)。

注意事项:

  1. 该函数计算的是每个epoch的运算时间,即完成一个epoch所花费的时间。是通过结束时间减去开始时间得到的。

  2. 该函数会在每个epoch结束时被调用一次,所以self.epoch_count和self.epoch_cost会在每个epoch结束时更新。

classification_evalution_trainset

这是一个用于训练集分类评估的类classification_evalution_trainset,继承自trainer_plugin_base类。

这个类主要是用来计算和记录训练过程中的各类评估指标,包括准确率、精度、召回率和F1分数,并将各类评估指标的结果绘制成图表。它有三个主要的方法:begin、batch_end和epoch_end,分别在训练开始、每个批次结束和每个周期结束时被调用。

  • __init__方法是类的初始化方法,设置了一些基本的属性,包括logit_convert_func函数、epoch_result_true_label列表、epoch_result_predict_label列表、scores列表和evaluator评估器。

  • begin方法在训练开始时被调用,它会为每种评估指标创建一个空列表,用于存储指标的计算结果。

  • batch_end方法在每个批次结束时被调用,它会将当前批次的预测结果和真实结果添加到对应的列表中。

  • epoch_end方法在每个周期结束时被调用,它会计算当前周期的所有评估指标,然后将计算结果添加到图表数据中,并清空存储预测结果和真实结果的列表,为下一个周期的计算做准备。

使用示例:

classification_evaluator = classification_evalution_trainset()

classification_evaluator.begin(global_state, epoch_state, step_state)




**for batch in batches:**

classification_evaluator.batch_end(global_state, epoch_state, step_state)

classification_evaluator.epoch_end(global_state, epoch_state, step_state)

注意:这个类需要配合trainer_plugin_base类和Evaluator类使用,而且在使用前需要确保全局状态、周期状态和步骤状态的设置是正确的。

__init__

这是一个构造函数,用于初始化classification_evalution_trainset类的实例。

分类评估训练集类用于对分类模型的训练结果进行评估,包括准确度、精确度、召回率和F1分数等指标的计算。

该类的实例会在训练过程中进行调用,以记录每个批次和每个周期的训练结果,并在全局状态板上展示这些结果。

参数:

logit_convert_func (function, 默认值是lambda x: x): 一个函数,用于将预测结果(logit)转换为预测标签。默认的转换函数是lambda x: x,即不做任何转换。

属性:

logit_convert_func (function): logit转换函数。

epoch_result_true_label (list): 真实标签的列表,用于记录每个周期的真实标签。

epoch_result_predict_label (list): 预测标签的列表,用于记录每个周期的预测标签。

scores (list): 评估指标列表,包括准确度、精确度、召回率和F1分数。

evaluator (Evaluator): 评估器,用于计算每个周期的评估指标。

使用示例:

\# 创建一个分类评估训练集实例,使用自定义的logit转换函数

evaluator = classification_evalution_trainset(logit_convert_func=my_convert_func)

注意:

无已知错误或bug。

begin

这是一个类方法,用于在训练过程的开始阶段进行一些初始化工作。具体来说,该方法是在训练的每个epoch开始时被调用的,它会为每一个评估指标创建一个空的列表,用于存储每个epoch的评估结果。这些评估结果随后会被用于生成训练过程的图表数据。

参数:

  • global_state (global_state_board): 全局状态板实例,用于存储全局的训练状态,如学习率、训练集和验证集的损失等。

  • epoch_state (epoch_state_board): epoch状态板实例,用于存储当前epoch的训练状态,如当前epoch的损失、准确率等。

  • step_state (step_state_board): 步骤状态板实例,用于存储当前步骤的训练状态,如当前步的输入、输出、损失等。

返回:

  • 返回类型是None。

例子:

  • 如果你的评估指标有accuracy、precision、recall和f1 score,那么在开始阶段,global_state.chart_datas将会被初始化为{'accuracy': [], 'precision': [], 'recall': [], 'f1 score': []}。

batch_end

batch_end 是一个方法,用在每次批量处理数据后. 它从step_state中获取预测的输出(logit)和实际标签(true label),并以列表形式存储它们,以便在后续的评估阶段使用。

参数:

  • global_state (global_state_board):全局状态板,它存储了全局的数据,如整个训练过程的图表数据。

  • epoch_state (epoch_state_board):时代状态板,它存储了当前时代的数据。在这个方法中并未使用。

  • step_state (step_state_board):步骤状态板,它存储了当前步骤(或称之为批次)的数据,如logit(即模型的输出)和converted_label(即实际标签)。

返回:

在这个方法中,它首先将logit和实际标签转换为列表形式,然后将其添加到self.epoch_result_true_labelself.epoch_result_predict_label中,这两个列表存储了一个epoch的所有步骤(或批次)的结果,用于在epoch结束时进行模型评估。

注意:这个方法不返回任何值,它只是处理和存储数据。

示例:

以下是如何使用此方法的一个示例:

\# 假设我们有一个classification_evaluation_trainset的实例

trainer = classification_evaluation_trainset()

\# 假设我们有一些全局,时代和步骤状态

global_state = ...

epoch_state = ...

step_state = ...

\# 我们可以在每个批次结束后调用这个方法

trainer.batch_end(global_state, epoch_state, step_state)

epoch_end

这个函数定义了在每个训练周期结束时的操作。

函数参数:

global_state (global_state_board): 全局状态,用于存储整个训练过程中的信息,例如图表数据等。

epoch_state (epoch_state_board): 当前训练周期的状态,包含了当前周期的训练信息。

step_state (step_state_board): 当前训练步骤的状态,包含了当前训练步骤的信息。

该函数首先会根据当前训练周期的真实标签和预测标签,通过Evaluator计算出所有评价指标的结果,并存储在result中。然后,遍历result, 将每个评价指标的结果添加到全局状态的图表数据中。最后,清空当前训练周期的真实标签和预测标签,为下一周期的训练做准备。

无返回值。

使用示例:

在训练过程中,每当一个训练周期结束时,都会调用该函数来处理周期结束时的操作。

注意:

该函数假设step_state中的logit已经被转换成了标签格式,如果没有转换,可能会影响结果的正确性。

log_plugin

log_plugin 是一个继承自trainer_plugin_base的类。这个类的主要目的是在训练过程中进行日志记录,包括在每一个批次和每一个训练周期开始和结束的时候。

这个类的主要功能如下:

  • 在训练开始时,创建日志文件并设置文件路径。

  • 在每一个训练周期和批次开始时,记录堆栈中的日志信息。

  • 在每一个训练周期和批次结束时,记录训练的损失并清空堆栈中的日志信息。

  • 在训练结束时,将堆栈中的日志信息写入到日志文件中。

使用这个类的例子如下:

logger = log_plugin(show_in_batch=True, show_in_epoch=True)

logger.start(global_state, epoch_state, step_state)

\# 在训练过程中,可以使用下面的方法记录日志信息

logger._log("Training started.")

\# 在训练结束时,使用下面的方法将日志信息写入到文件中

logger.end(global_state, epoch_state, step_state)

类方法:

  • \_\_init\_\_(self, show_in_batch=False, show_in_epoch=True): 构造函数,初始化类的实例。

  • start(self, global_state: global_state_board, epoch_state: epoch_state_board, step_state: step_state_board): 在训练开始时调用,设置日志文件的路径。

  • end(self, global_state: global_state_board, epoch_state: epoch_state_board, step_state: step_state_board): 在训练结束时调用,将堆栈中的日志信息写入到日志文件中。

  • epoch_begin(self, global_state: global_state_board, epoch_state: epoch_state_board, step_state: step_state_board): 在每一个训练周期开始时调用,记录堆栈中的日志信息。

  • epoch_end(self, global_state: global_state_board, epoch_state: epoch_state_board, step_state: step_state_board): 在每一个训练周期结束时调用,记录训练的损失并清空堆栈中的日志信息。

  • batch_begin(self, global_state: global_state_board, epoch_state: epoch_state_board, step_state: step_state_board): 在每一个批次开始时调用,记录堆栈中的日志信息。

  • batch_end(self, global_state: global_state_board, epoch_state: epoch_state_board, step_state: step_state_board): 在每一个批次结束时调用,记录训练的损失并清空堆栈中的日志信息。

注意:

  • 需要保证global_state_board, epoch_state_boardstep_state_board已经被正确初始化。

Invoke

log_plugin类是一个用于记录训练过程的插件类。 它可以在训练开始,结束,每个epoch开始和结束,每个batch开始和结束时记录和输出训练信息。该类继承自trainer_plugin_base。用户可以通过设置show_in_batch和show_in_epoch来决定是否在每个batch或者epoch中输出训练信息。

使用例子如下:

log = log_plugin(show_in_batch=True, show_in_epoch=True)

log.start(global_state, epoch_state, step_state)

log.batch_begin(global_state, epoch_state, step_state)

log.batch_end(global_state, epoch_state, step_state)

log.epoch_begin(global_state, epoch_state, step_state)

log.epoch_end(global_state, epoch_state, step_state)

log.end(global_state, epoch_state, step_state)

def Invoke(self, global_state: global_state_board, epoch_state: epoch_state_board, step_state: step_state_board):

此函数是log_plugin类中的一个函数,它并没有具体的实现。这是一个抽象方法,需要在子类中重写。这个方法的目的是为了实现在训练过程中的某些特定时刻进行一些操作,比如在每个epoch开始时记录当前的loss值等。

参数:

global_state: global_state_board类型,表示全局状态,包含整个训练过程的信息,如模型保存的位置,训练的进度条等。

epoch_state: epoch_state_board类型,表示在当前epoch的状态,包含当前epoch的索引,当前epoch的loss值等。

step_state: step_state_board类型,表示在当前step的状态,包含当前step的索引,当前step的loss值等。

返回值: 无

注意: 由于这个函数在父类中没有具体的实现,所以如果在子类中没有重写这个函数,而又调用了这个函数的话,程序会报错。

_log

_log是一个私有方法,用于记录内容到日志文件中。

方法将接收的参数content追加到指定的日志文件中。如果日志文件不存在,会创建一个新的。

参数:

content (str):需要记录到日志文件中的内容。

返回:

例子:

下方是一个使用示例。创建一个log_plugin对象,然后调用_log方法记录内容到日志文件。


log_plugin_obj = log_plugin()

log_plugin_obj._log("This is a log message")

注意:

  • 这个方法不应该直接调用,它是一个内部方法,用于log_plugin类内部

  • 这个方法没有返回值,它的作用是副作用,即写入文件。

_log_stackcontent

这个方法的主要目的是将全局状态(global_state)中的日志堆栈内的所有消息进行日志记录,然后清空日志堆栈。

参数:

global_state (global_state_board): 训练过程的全局状态,包含训练的各种信息和状态,如日志堆栈。

返回:

注意这个方法没有返回值,但是会改变全局状态(global_state)中的日志堆栈,完成日志记录后会清空日志堆栈。

示例:

# 创建log_plugin对象

log_plugin_obj = log_plugin()

# 假设我们已经有了一个global_state对象

global_state = global_state_board()

# 在某一步骤添加日志信息到global_state的日志堆栈中

global_state.log_stack.append("Step 1 completed.")

# 调用_log_stackcontent方法进行日志记录

log_plugin_obj._log_stackcontent(global_state)

# 此时,"Step 1 completed."已经被记录到日志文件中,global_state中的日志堆栈已经被清空。

start

在训练开始时调用的方法。

该方法在训练开始时被调用。主要作用是设定日志文件的保存路径。日志文件将保存在模型文件夹下,文件名为"log.txt"。

参数:

global_state (global_state_board): 全局状态板,记录全局的训练状态,如模型文件夹路径等。

epoch_state (epoch_state_board): 该训练周期的状态板,记录了该训练周期的相关信息。

step_state (step_state_board): 当前训练步骤的状态板,记录了当前训练步骤的相关信息。

返回:

无返回值。

注意:

该方法不应该被直接调用,而应该作为训练插件在训练过程中的一个生命周期进行调用。

end

这是一个end方法,它是log_plugin类的一部分,该类是trainer_plugin_base的子类。该方法在训练的最后阶段被调用,用于输出并清空global_state的日志堆栈。

参数:

  • global_state (global_state_board): 全局状态板,包含全局的状态信息,如模型、优化器等。

  • epoch_state (epoch_state_board): 当前训练周期状态板,包含当前周期的状态信息,如当前周期的损失值、准确率等。

  • step_state (step_state_board): 当前步骤状态板,包含当前步骤的状态信息,如当前步骤的损失值、准确率等。

返回值:

该函数没有返回值。

示例:

**def end(self, global_state: global_state_board, epoch_state: epoch_state_board, step_state: step_state_board):**

self._log_stackcontent(global_state)

注意:

  • _log_stackcontentlog_plugin类的一个私有方法,它会输出并清空global_state的日志堆栈。

epoch_begin

epoch_beginlog_plugin类的一个方法,该方法在每个训练周期开始时被调用。它执行的主要操作是记录全局状态中的日志信息。log_plugin类主要用于在训练过程中实现日志记录功能,通过记录每一个epoch和batch的开始和结束状态,以及在这些状态下模型的损失值,帮助用户了解模型训练过程中的情况。

参数:

  • global_state (global_state_board类型): 训练过程中的全局状态,包含所有插件需要访问的全局信息。

  • epoch_state (epoch_state_board类型): 当前训练周期的状态信息,包括当前的训练周期数、训练周期损失等。

  • step_state (step_state_board类型): 当前步骤的状态信息,包括当前步骤数、步骤损失等。

无返回值。

示例:

假设我们正在训练一个模型,该模型的训练过程由一个log_plugin实例来监控和记录,那么在每个训练周期开始时,就会调用这个epoch_begin方法:

log = log_plugin()

log.epoch_begin(global_state, epoch_state, step_state)

在这个方法中,首先会清空global_state的日志堆栈,并将堆栈中的每一条消息记录到日志文件中。这样就可以在每个训练周期开始时,清空上一个训练周期的日志信息,并记录新的训练周期的日志信息。

epoch_end

epoch_endlog_plugin 类的一个方法,用于处理每个训练周期结束后的逻辑。

此方法主要用于记录和保存训练周期结束时的状态信息,如损失值,并将这些信息写入日志文件中。此外,此方法还会清空全局状态的日志栈,以准备下一次的训练周期。

参数:

  • global_state (global_state_board 类型): 全局状态板,用于存储全局的状态信息。

  • epoch_state (epoch_state_board 类型): 训练周期状态板,用于存储当前训练周期的状态信息。

  • step_state (step_state_board 类型): 训练步骤状态板,用于存储当前训练步骤的状态信息。

返回值:

  • 此方法没有返回值。

注意:

  • 此方法被 invoke_at 装饰器修饰,会在每个训练周期结束时被调用。

示例:

\# 创建 log_plugin 对象

log_plugin = log_plugin()

\# 创建全局状态板、训练周期状态板和训练步骤状态板

global_state = global_state_board()

epoch_state = epoch_state_board()

step_state = step_state_board()

\# 训练周期结束,调用 epoch_end 方法

log_plugin.epoch_end(global_state, epoch_state, step_state)

batch_begin

这是一个在批处理开始前执行的函数,功能是记录全局状态的日志。

它的参数包括全局状态、当前的epoch状态和step状态。

参数:

global_state(global_state_board): 存储全局状态信息的对象

epoch_state(epoch_state_board): 存储当前epoch状态信息的对象

step_state(step_state_board): 存储步骤状态信息的对象

返回:

无返回值

batch_end

在一次训练批次结束时会调用这个方法。用于记录该批次的训练情况,并将记录写入日志文件。

参数:

global_state: global_state_board类型,代表全局状态,包含了全局信息如模型、数据等。

epoch_state: epoch_state_board类型,代表当前的训练周期状态,包含了当前训练周期的信息。

step_state: step_state_board类型,代表当前训练步骤的状态,包含了当前步骤的信息。

返回值:

使用示例:

log_plugin.batch_end(global_state, epoch_state, step_state)

注意事项:

本方法没有返回值,主要功能是记录日志。

__init__

初始化log_plugin类。 这是一个用于训练过程中记录日志的插件,可以设定是否在每个batch和epoch开始和结束时记录日志。

参数:

show_in_batch(可选): 布尔值,如果为True,则在每个batch开始和结束时记录日志。 默认为False。

show_in_epoch(可选): 布尔值,如果为True,则在每个epoch开始和结束时记录日志。 默认为True。

属性:

save_file: 用于保存日志的文件的路径,具体路径在start函数中由global_state.model_folder指定。

show_in_batch: 是否在每个batch开始和结束时记录日志。

show_in_epoch: 是否在每个epoch开始和结束时记录日志。

例子:

log_plugin = log_plugin(show_in_batch=True, show_in_epoch=False)

在开始和结束每个batch时将日志记录在文件中,但是不在epoch开始和结束时记录。

mlsample

该Python模块主要负责数据集的获取、处理和管理。它提供了从多种数据源(如本地磁盘、SSH、Minio服务等)下载、上传和删除数据集的功能,以及提取文本数据、输出CSV文件、分割样本集、转换数据格式等操作的方法。此外,该模块还包含了一些用于对样本数据进行加载、存储、处理和管理的工具类,以及一些用于生成和保存随机样本集的函数,旨在提高处理各种数据源的代码的可维护性和扩展性。

类成员:

name info
LocalDisk_NLSampleSource 这是一个用于从本地磁盘加载、存储和处理样本数据的工具类。
NLSampleSourceBase 这是一个抽象基类,用于定义统一的样本源接口,提高处理各种数据源的代码可维护性和扩展性。
Memory_NLSampleSource Memory_NLSampleSource是一个在内存中维护数据字典的类,用于创建、检查、操作和删除数据集。
SSHSampleSource SSHSampleSource类是用于处理SSH源文件的工具,包括下载、更新、创建数据集等操作。
MinioSampleSource MinioSampleSource类是一个用于与Minio服务端交互的类,提供数据上传下载和检查功能。
SampleSet SampleSet类用于从指定来源获取、处理和管理样本数据,支持打乱、获取、跳过、批量获取和自定义操作等功能。

函数成员:

name info
set_list 这是一个根据给定的源类型和路径获取并设置数据集列表的函数。
download 这个函数的作用是从不同类型的数据源如"minio"或"ssh"下载指定的数据集到预设或默认路径。
upload 这是一个支持通过minio和ssh两种方式上传文件到指定路径的函数。
set_info 这是一个函数,用于打印指定样本集的基础信息和详细信息。
set_data_info 这个函数用于设置并打印出数据集中每个标签键的计数信息。
delete_set 这是一个删除指定路径下特定数据集的函数,如果未指定路径则会使用配置文件中的默认路径。
capture_str capture_str函数用于从指定文件夹中提取.txt和.pdf文件的文本数据,并保存到指定路径。
output_csv 这是一个从指定源路径读取样本数据,并将其保存到CSV文件的函数。
SplitSet 这是一个用于根据指定规则将原始样本集分割成多个子样本集的函数。
convert_jsonl 这个函数用于从指定数据源读取数据,通过指定函数转换为字典形式,并写入到目标jsonl文件中。
create_SampleSet 这是一个生成并保存指定数量和大小的随机样本集的函数。

set_list

此函数的目的是从不同的源(本地、Minio或SSH)获取并设置数据集列表。它首先根据源类型验证路径或配置,然后从源获取数据集列表信息,最后以表格形式打印出数据集的关键信息。

参数:

tsource (str):数据源类型,默认为"local"。可选值有"local"(本地),"minio"(Minio对象存储),"ssh"(SSH服务器)。

path (str):数据集的路径。如果未提供,将从配置实例中获取数据集路径。

match (str):匹配数据集名称的模式字符串。如果提供了该参数,函数只会返回匹配该模式的数据集。

返回:

此函数没有返回值。

可能的错误:

如果指定的路径不存在,或者配置信息不完整,函数会引发异常。

如果从源获取的数据集信息为空,函数也会引发异常。

示例:

set_list("local", "/path/to/dataset")

set_list("minio", "/path/to/dataset", "train*")

注意:

在使用"minio"或"ssh"作为数据源类型时,你需要提供完整的配置信息,如endpoint, access_key, secret_key等。

匹配字符串中可以使用通配符,例如"train*"匹配所有以"train"开头的数据集。

download

这个函数的主要目的是用来下载数据集的. 它主要根据tsource这个参数来决定数据源是哪个类型, 例如 "minio" 或者 "ssh". 如果tsource是"minio",那么会从minio的endpoint, 使用access key和secret key,从指定bucket中下载数据到预设的路径. 如果tsource是"ssh",那么会通过ssh从远程服务器的目标路径下载数据到预设的路径. 在两种情况中,如果没有给出预设的路径,那么会使用配置文件中的sample_source_path作为默认路径.

参数:

tsource (str): 数据源的类型, 可选的类型有 "minio" 和 "ssh".

set_name (str): 要下载的数据集的名称.

path (str, optional): 数据下载的预设路径. 如果没有给出, 那么会使用配置文件中的sample_source_path作为默认路径.

返回类型:

这个函数没有返回.

异常:

如果数据源的类型不支持, 那么会抛出一个Exception.

如果在配置文件中找不到对应的配置信息, 那么也会抛出一个Exception.

使用例子:

download('minio', 'dataset1', '/path/to/dataset1')

download('ssh', 'dataset2')

注意事项:

这个函数在运行时会检查配置文件中对应的数据源的配置信息是否都存在,如果不存在,那么会抛出异常。

对于ssh数据源,如果配置文件中没有提供ssh的端口信息,那么会默认使用22端口。

upload

此函数用于上传源文件到指定路径,通过读取配置文件,支持两种上传方式:minio和ssh。

参数:

tsource (str): 上传文件的类型,可以是"minio"或者"ssh"。

set_name (str, 可选): 集合名称。默认为None。

path (str, 可选): 文件路径。如果没有指定,则会从配置文件中读取。

返回:

错误:

Exception: 如果tsource不是"minio"和"ssh"之一,或者必要的配置信息不存在,将抛出异常。

例子:

upload("minio", "myset", "/path/to/myfile")

upload("ssh", "myset", "/path/to/myfile")

注意:

此函数依赖get_config_instance函数和MinioSampleSource、SSHSampleSource类,需要先定义这些依赖才能正常使用。

并且此函数没有进行参数合法性检查,调用者需要保证参数的正确性。

set_info

这个函数是用来打印一组特定样本的基础信息和详细信息的。基础信息包括其名称、数量、基础集、键、标签和描述等,详细信息则包括每一行数据的各个属性值以及属性的名称。

参数:

setname (str): 需要打印信息的样本集的名称。

count (int, 可选): 想要打印的样本行数,默认为5。

max_len (int, 可选): 每一个样本值打印的最大长度,如果超过这个长度则会被截断,截断的部分以...表示,默认为100。

path (str, 可选): 样本文件的存储路径。如果不提供,则会从config实例中获取"sample_source_path"的值。

返回值:

用法示例:

set_info('train', count=10, max_len=50, path='/path/to/data')

注意:

如果providing的路径不存在,将会引发异常。此外,如果setname没有在指定路径下找到,也会引发异常。

set_data_info

此函数用于设置数据信息,并以表格形式显示每个标签键的计数。

参数:

setname (str): 需要分析的数据集名称。

label_key (str): 数据集中的标签键。

path (str, optional): 数据集的路径。若未提供,则从配置中获取默认路径。

返回:

无返回值。但会打印出每个标签键的计数信息。

异常:

当提供的路径不存在时,会抛出异常。

示例:

set_data_info('train', 'label', '/path/to/dataset')

可以显示训练集中每个标签的计数信息。

注意:

这个函数没有处理数据集中可能存在的空标签的情况,如果数据集中存在空标签,可能会导致错误。

delete_set

删除指定的数据集。

这个函数会在指定的路径下查找并删除给定的数据集。如果没有提供路径,那么它会从配置文件中获取默认路径。

参数:

setname (str): 要删除的数据集的名称。

path (str, 可选): 数据集所在的文件路径。默认为None,此时会从配置文件中获取默认路径。

返回:

无返回

异常:

如果指定的路径不存在,将会引发异常。

示例:

delete_set('my_dataset', '/path/to/datasets')

"my_dataset deleted."

注意:

这个函数会永久性地从磁盘上删除数据集,所以在调用之前一定要确保备份重要数据。

capture_str

capture_str 是一个函数,用于从给定文件夹中提取文本数据,并将结果保存到特定路径的本地磁盘或内存中。

该函数将在给定的文件夹中查找所有.txt和.pdf文件,然后读取他们的内容。对于pdf文件,它将提取每一页的文本。如果遇到错误,它将记录错误信息,然后继续处理其他文件。最后,它将所有收集的数据保存到指定路径的本地磁盘或内存中。

参数:

setname (str): 创建的数据集的名称。

folderpath (str): 包含要提取数据的文件的文件夹路径。

path (str, 可选): 保存数据的路径。如果未提供,将使用配置实例的 "sample_source_path"。

返回:

None

使用示例:

capture_str("dataset1", "./data_folder", "./output_folder")

注意:

  • 该函数只处理.txt和.pdf文件,遇到其他类型的文件将引发异常并记录日志。

  • 当给定的路径不存在时,该函数将引发异常。

  • 如果在处理文件时遇到错误,该函数将记录错误信息,然后继续处理其他文件。

可能的错误或bug:

  • 如果文件夹中的文件数量过多,可能会导致内存问题。

  • 如果pdf文件的页数过多,可能会导致处理速度慢。

  • 如果文件内容包含无法正确解析的字符,可能会引发异常。

output_csv

这个函数的目的是从给定的源路径读取一定数量的样本数据,并将其输出保存到CSV文件中。

参数:

setname (str): 具有样本数据的集合名称。

tpath (str, 可选): 源路径。如果未提供,将使用配置文件中的"sample_source_path"。默认值为None。

path (str, 可选): 保存CSV文件的路径。默认值为"a.csv"。

count (int, 可选): 从setname中读取的样本数量。默认值为100。

返回类型:

此函数无返回值。执行完毕后,将在指定的路径保存CSV文件。

使用示例:

output_csv("mySet", tpath="/my/path", path="/my/path/a.csv", count=200)

注意:

如果提供的tpath不存在,或者在配置文件中没有定义"sample_source_path",则会引发异常。

source使用的是LocalDisk_NLSampleSource,此类专门用于从本地磁盘读取样本数据。

使用csv.writer将数据写入csv文件,可以保证有效的数据存储和读取。

SplitSet

该函数用于根据给定的名称-键字典对样本源进行分割。每个键值对应的样本都会生成一个新的独立样本集,新生成的样本集将会以"原样本集_键名"的形式命名。

参数:

samplesource(NLSampleSourceBase): 用于操作的样本来源。

ori_set_name(str): 需要被分割的原始样本集的名称。

key_func: 用于从样本中提取键的函数。

name_to_key_dict(dict): 用于划分新样本集的名称-键字典。字典的键将作为新样本集的名称,字典的值用于匹配样本。

need_shuffle(bool, optional): 是否需要对原始样本集进行乱序处理。默认为True。

返回值:

无返回值。

注意事项:

新生成的样本集将会包含原始样本集的所有元数据。

举例:

假设我们有一个名为"total"的样本集,我们想根据样本的标签将其分割成两个子集"positive"和"negative"。那么我们可以这样使用该函数:





**def get_label(sample):**

return sample['label']

name_to_key_dict = {

"positive": {1: 10000},

"negative": {0: 10000}

}

SplitSet(sample_source, "total", get_label, name_to_key_dict)

这样,"total"样本集就被分割成了"total_positive"和"total_negative"两个样本集,每个样本集包含了10000个样本。

LocalDisk_NLSampleSource

这个类是用来从本地磁盘加载和存储样本的一个工具类,它实现了NLSampleSourceBase接口。主要功能包括从本地磁盘获取样本数据、向本地磁盘写入样本数据、获取样本的元数据等。

类的主要成员变量如下:

  • base_folder: 存储样本的基础文件夹路径

  • int_size和shortint_size: 整数和短整数的字节大小

  • pointer_size: 指针的字节大小

  • header_node_size和file_size: 头节点和文件的大小

  • file_pool: 存储打开的文件对象的字典,key为文件名,value为文件对象

  • base_seek_dic和linked_seek_dic: 存储基础头部和链接头部的字典

主要方法如下:

  • __init__: 初始化一个LocalDisk_NLSampleSource实例

  • get_dir_list: 获取本地磁盘上存储样本的所有目录列表

  • get_file_date: 获取指定文件的日期

  • _try_get_file_obj: 尝试获得指定文件对象,如果不存在则会创建

  • __del__: 在删除对象时,关闭所有打开的文件

  • flush: 刷新文件缓冲区,保证所有的写操作都被写入磁盘

  • load_pointer_data: 加载指定的指针数据

  • create_new_set: 创建新的样本集

  • has_set: 判断是否存在指定的样本集

  • add_row: 向指定的样本集中添加一行数据

  • iter_data: 从指定的样本集中迭代读取数据

  • read_one_row: 从指定的样本集中读取一行数据

  • iter_pointer: 从指定的样本集中迭代读取指针数据

  • get_set_count: 获取指定样本集的样本数

  • get_metadata_keys: 获取指定样本集的元数据键值

  • print_set_info: 打印指定样本集的信息

  • delete_set: 删除指定的样本集

  • add_attachment: 向指定的样本集中添加附件

  • read_attachment: 读取指定样本集的附件信息

示例:

ld_nls = LocalDisk_NLSampleSource("samples")

ld_nls.create_new_set("set1", "description of set1", ["tag1", "tag2"], ["key1", "key2"])

ld_nls.add_row("set1", ["data1", "data2"])




**for data in ld_nls.iter_data("set1"):**

print(data)

get_dir_list

这个函数的主要目的是获取指定目录下的所有子目录信息。

参数:

self: 指代类实例的自身引用。

返回:

sets_infos: 返回一个字典,字典的key是子目录名,value是包含该子目录的相关信息的字典,包括'meta'、'count'、'filecount' 和 'create_date'等字段。

异常:

如果在执行过程中遇到错误,函数会抛出相应的异常。

使用示例:

假设我们有一个LocalDisk_NLSampleSource类的实例ldnss,我们可以这样使用这个函数:

dir_info = ldnss.get_dir_list()

这样就可以得到ldnss实例的基础目录下的所有子目录信息。

__init__

这是LocalDisk_NLSampleSource类的初始化函数,负责初始化该类的一些基本参数。

参数:

folder_path: 一个字符串,表示基础文件夹的路径。

该类主要负责操作本地文件夹中的数据,包括获取文件夹列表,获取文件的创建日期,读取和写入数据等操作。在初始化过程中,会对基本参数进行赋值,并确保基础文件夹存在。

示例:

数据源 = LocalDisk_NLSampleSource("/path/to/folder")

# 之后我们可以进行一些数据操作,例如获取文件夹列表:

dir_list = 数据源.get_dir_list()

注意事项:

如果提供的基础文件夹路径不存在,则会抛出异常。请确保提供一个存在的文件夹路径。

返回值: 无

错误/异常:

如果基础文件夹不存在,则会抛出异常。

get_file_date

该函数用于获取本地磁盘上的指定文件的最后修改日期。

参数:

name (str):文件名(不含扩展名)。函数将在类创建时设定的基础文件夹路径下寻找该文件。

返回:

datetime:返回一个datetime对象,表示该文件的最后修改日期。如果文件不存在,将抛出异常。

使用方法:

get_file_date("example_filename")

会返回名为"example_filename.dlib"文件的最后修改日期。

注意:

  1. 文件名应不包含扩展名,扩展名会在函数内部自动添加。

  2. 在文件系统支持的情况下,函数会返回文件的最后修改日期,而非创建日期。

_try_get_file_obj

这个函数是尝试获取文件对象的方法,主要用于在文件池中查找指定的文件并返回其文件对象。如果文件池中不存在该文件,那么会打开这个文件并将它添加到文件池中。

参数:

name (str): 指定的文件名。

file_index (int): 文件索引,默认值为0。

返回:

返回指定文件的文件对象。

示例:

这个函数的使用示例如下,主要在类的其他方法中被调用。





**def some_function(self, name: str):**

file_obj = self._try_get_file_obj(name)

\# 这里可以进行后续的操作,比如读取文件内容等。

注意:

  1. 这个函数无法处理文件打开失败的情况,如果文件不存在或者无法访问,会抛出异常。

  2. 这个函数不负责关闭文件,需要在调用者中显式地关闭文件。

__del__

这是一个析构函数,用于当该对象不再被使用或调用时,清理并释放资源。在这个函数中,我们遍历文件池中的所有文件名,并关闭对应的文件。

该函数没有参数和返回值。

注意:在Python中,析构函数的执行不是确定的,也就是说我们无法准确预知何时析构函数会被调用。因此,尽量避免在析构函数中执行关键任务,而应在程序的控制流中明确地执行这些任务。

示例:

**class LocalDisk_NLSampleSource(NLSampleSourceBase):**

...




**def \_\_del\_\_(self):**




**for name in self.file_pool.keys():**

self.file_pool[name].close()

...

在上面的例子中,我们在 \_\_del\_\_ 函数中关闭所有打开的文件。当对象不再被使用时,Python的垃圾回收器将会调用该函数,释放所占用的资源。

flush

刷新缓冲区。

这个函数遍历文件池中的所有文件,对每个文件执行刷新操作,将缓冲区中的内容写入到文件中。这对于确保数据的一致性和完整性非常重要,特别是在长时间运行、大量输入/输出操作的程序中。

注意:此函数不带任何参数,也没有返回值。

如果在刷新操作中遇到任何错误或异常,它将由 Python 的内置 IOError 异常来处理。

使用示例:

disk = LocalDisk_NLSampleSource('/path/to/directory')

# ... 执行一些文件操作

disk.flush() # 确保所有更改都已写入到文件中

_read_int

这个函数是从一个文件中读取一个整数。

参数:

f: 一个文件对象。这个文件对象应该已经打开,并且设置为二进制模式。

返回:

返回一个整数。这个整数是从文件中读取的,转换成大端字节序,然后转换成无符号整数。

例子:

file_obj = open('path_to_file', 'rb')

num = _read_int(file_obj)

print(num) # 打印出文件中读取的整数

注意:

这个函数没有做任何的错误处理。如果文件中没有足够的字节来读取一个整数,或者文件没有打开,或者文件不在正确的位置,这个函数可能会抛出异常。

_write_int

该方法用于将整数值写入文件。

参数:

f (file): 需要写入的文件对象

int_value (int): 需要写入的整数值

返回类型: 无

示例:

_write_int(file_obj, 100)

注意:

在使用此函数时,确保文件对象已正确打开并且处于正确的位置,否则可能会覆盖原有的数据.

_add_int_plusone

这个函数是用于在指定位置将整数值加一。

参数:

f: 文件对象, 需要进行读写操作的文件对象。

seekp: int, 指定的文件位置。

返回:

这个函数没有返回值。

使用例子:

假设我们有一个文件对象f,我们希望在文件的第10个字节的位置加1,我们可以这样使用这个函数:

_add_int_plusone(f, 10)

_read_shortint

此函数是用于从给定的文件对象中读取一个大小为shortint_size的整数,并返回该整数。

参数:

f: 文件对象,已经打开并可以读取。

返回:

返回从文件对象读取的整数。

示例:

f = open('test_file', 'rb')

value = _read_shortint(f)

print(value)

_write_shortint

这是一个Python函数,用于将整数值写入到文件中。

参数:

f: 文件句柄,文件必须以二进制写入模式打开

int_value: 需要写入的整数值

返回:

无返回值

使用示例:

with open("testfile", 'wb') as f:

_write_shortint(f, 10)

注意:

  • 这个函数不会关闭文件,因此需要在调用它后手动关闭文件。

  • 该函数使用big-endian字节顺序以无符号整数格式写入数据,这点需要注意与读取时的字节顺序保持一致。

_read_pointer

此函数用于从文件中读取一个“指针”。在这个上下文中,一个“指针”是一个由两部分组成的元组,其中包括一个页码和一个起始位置。这个函数首先读取一个短整数,然后读取一个普通的整数。

参数:

f -- 一个打开的文件或者类文件对象。

返回:

一个元组,第一个元素是页码,第二个元素是起始位置。

_write_pointer

此函数是向指定的文件对象中写入指针信息,包括页码和偏移量。

Args:

f: 需要写入的文件对象,必须是已打开且可以读写的文件对象。

page (int): 需要写入的页码信息,用于标识数据在文件中的位置。

seek (int): 需要写入的偏移量信息,用于标识数据在文件中的位置。

Returns:

无返回值。

Raises:

无特定异常,但如果文件对象无法写入或者参数类型不正确,会抛出异常。

Example:

# 打开一个文件,然后向其中写入页码和偏移量

with open('test.dlib', 'wb') as f:

_write_pointer(f, 5, 1024)

注意:

此函数不会自动关掉文件对象,需要在外部手动关闭。

_read_node

此函数用于从文件流中读取并返回一个序列化的对象。

参数:

f: 文件流对象。

返回:

返回从文件流中读取的已反序列化的对象。

使用示例:

**with open('filename', 'rb') as f:**

obj = _read_node(f)

注意:

  • 文件流对象f必须已经打开并可读。

  • 在文件流中的当前位置应该是序列化对象的开始位置。函数会从当前位置开始读取,而不是从文件的开头或结尾。

  • 此函数使用pickle模块进行反序列化,因此文件流中的数据必须是使用pickle序列化的。

_seek_to_node

这个函数的目的是读取一个文件的节点位置。函数首先保存当前的文件指针位置,然后读取该位置的两个整数值,分别是节点的长度和数据的实际长度。最后,函数将文件指针移动到实际数据的末尾,并返回原始的文件指针位置。

参数列表:

f: 文件对象。用于读取和定位节点。

返回类型:

int,返回的是文件对象的指针位置。

使用方法:

假设我们有一个文件对象f,我们可以这样调用此函数:


location = _seek_to_node(f)

这将返回指针位置,然后我们可以在其他函数中使用这个位置来读取或写入数据到文件对象f。

注意:

此函数假定文件对象f已经被打开了,并且可以读取。如果文件不可读或者没有打开,则使用此函数将会出错。此外,此函数会改变文件指针的位置,因此在使用后应当小心保存和恢复原始的文件指针位置。

_write_node

该函数的主要目的是将给定的节点数据(node)写入到指定的文件对象中(f)。数据的最大尺寸可以通过可选的size参数进行设置。

参数:

  • f: 要写入的文件对象,通常是一个已经打开的文件或者类文件的对象。

  • node: 要写入的节点数据,数据类型并未特定,可以是任意Python对象,这个对象会被pickle序列化存储。

  • size: 可选参数,用于设置要写入的数据的最大尺寸,单位为字节。如果Node的序列化后的长度超过此值,将会抛出异常。如果不设置此参数则不会对数据大小进行限制。

返回值:

  • 无。此函数没有返回值,但是会直接改变传入的文件对象,写入的数据将保存在文件的当前位置。

例子:


f = open('test.dat', 'wb')

node = {'name': 'test', 'value': 123}

_write_node(f, node, size=1024)

f.close()

这个例子中,创建了一个新的文件对象f,指向文件'test.dat',然后创建了一个字典作为节点数据,最后调用_write_node函数将节点数据写入到文件中,设置了最大的数据大小为1024字节。

注意事项:

  • 请确保在使用完文件后正确地关闭了它,以防止数据丢失或者文件被意外地修改。

  • 此函数没有做任何关于文件权限或者文件存在性的检查,所有这些都需要在调用此函数前完成。

  • 传入的node数据需要能够被pickle模块正确地序列化和反序列化,否则在读取数据时可能会出现问题。

_read_base_header

_read_base_header方法用于读取文件的基本头.

参数:

f: 文件对象, 已打开的文件对象,该文件对象需要有可读权限并且已经打开.

返回:

该函数返回一个包含六个元素的元组,元素分别为file_index, append_seek, data_start_seek, count, filecount, current_count.

其中,

file_index: 文件索引,为int类型,

append_seek: 从文件开始处到当前位置的偏移量,为元组类型,包括页码和偏移量,

data_start_seek: 从文件开始处到数据开始的偏移量,为int类型,

count: 数据数量,为int类型,

filecount: 文件数量,为int类型,

current_file_count: 当前文件数量,为int类型。

使用方法:

file_obj = open('sample_file', 'r')

base_header = _read_base_header(file_obj)

print('Base Header:', base_header)

file_obj.close()

注意:

对于无法打开或者读取的文件,该函数可能会抛出异常.

_read_linked_header

此函数用于读取并返回链接头数据。

参数:

f (file): 一个已打开的文件对象,该对象用于读取文件操作。

返回:

返回一个元组,其中包含三个整型数据,分别为文件索引、当前计数和数据开始检索位置。

示例:

file_index, current_count, data_start_seek = self._read_linked_header(file)

注意:

传入的文件对象应处于读取状态,并且文件的内容应符合链接头数据的预期格式,否则可能会引发异常或错误。

load_pointer_data

根据提供的数据集名称和指针,从基本文件读取并返回所需的数据。

参数:

name: str, 数据集的名称,该数据集应该存在于本地磁盘中。

pointer: tuple, 包含文件索引和开始搜索的位置的元组。

返回:

返回从指定位置读取的数据。

举例:

假设我们有一个名为'test'的数据集,并且我们想要从该数据集的第一行中读取数据,则可以这样做:

data = load_pointer_data('test', (0,0))

这将返回'test'数据集的第一行数据。

注意:

如果数据集不存在或者指针指向的位置没有数据,那么这个函数会引发异常。

create_new_set

创建一个新的数据集

这个方法用于创建一个新的数据集,包含数据集名称、描述、标签和关键字等信息,并在本地磁盘上创建相关文件以存储数据。

参数:

name : str

新数据集的名称,同时也会作为数据集所在文件夹的名称.

description : str

数据集的描述信息,可以包含数据集的用途、来源等信息.

tags : list of str

数据集的标签,可以用来分类和搜索数据集.

keys : list of str

数据集中每条数据的关键字,对应数据的多个字段, 表示每条数据的结构.

返回:

bool

如果数据集创建成功,返回True. 如果数据集已经存在,会抛出异常,不会返回False.

例子:

source = LocalDisk_NLSampleSource('/path/to/dataset')

source.create_new_set('my_dataset', 'This is a dataset for testing.', ['test', 'sample'], ['field1', 'field2'])

# 这将在/path/to/dataset/my_dataset下创建相关的文件,以存储数据集信息和数据.

注意:

  • 数据集名称中不能包含下划线('_'),否则会抛出异常.

  • 如果数据集已经存在,再次调用这个方法会抛出异常.

has_set

此函数是用来检查是否存在某个set的。set是数据的集合,每个set中包含许多行数据,每行数据可以是一组标签和数据。

Args:

name: str, 待检查的set的名称。

Returns:

bool类型。如果存在该set,返回True; 否则,返回False。

Examples:

检查一个名为'sample_set'的set是否存在:

disk = LocalDisk_NLSampleSource(folder_path)




**if disk.has_set('sample_set'):**

print("The set 'sample_set' exists.")




**else:**

print("The set 'sample_set' does not exist.")

注意:此函数不会对不存在的路径或无效的set名称做错误处理。如果提供了不存在的路径或无效的set名称,可能会抛出异常。在使用此函数时,应确保提供的set名称是合法且存在的。

add_row

该函数为一个指定的set添加一行数据,这个set的名字由参数'name'指定。

参数:

name: str - 要添加数据的set的名字

data: list - 要添加到set中的数据,应该是一个列表,列表的长度应该与set的元数据键的数量相同

返回:

bool - 如果数据添加成功,则返回True,否则返回False

注意:

  • 如果data参数不是一个列表,函数会抛出一个异常

  • 如果data的长度与set的元数据键的数量不相同,函数会抛出一个异常

iter_data

此函数遍历并返回指定名称的数据集中的所有数据。

参数:

name (str): 数据集的名称。

返回:

生成器: 返回一个生成器,该生成器按顺序产生数据集中的每一行数据。

使用示例:

ld = LocalDisk_NLSampleSource("my_folder_path")




**for data in ld.iter_data("my_dataset_name"):**

print(data)

注意:

如果数据集不存在,此函数会引发异常。数据集的名称是大小写敏感的。

read_one_row

def read_one_row(self, name: str):

这是一个读取本地磁盘上特定文件(在此文件中,数据被存储为行)中的一行数据的方法。读取的行数据是以序列化形式存储的。

参数:

name (str): 需要读取的文件的名字。

返回:

返回一个被反序列化的数据对象,这个对象包含了在指定文件中的一行数据。

使用示例:

disk_source = LocalDisk_NLSampleSource('path/to/folder')

row_data = disk_source.read_one_row('filename')

注意:

  • 使用这个函数需要确保文件已经存在并且包含至少一行数据。

  • 在使用这个函数读取数据前,数据应该已经被正确地序列化和保存到了文件中。

  • 文件的存放位置应该在类初始化时设定的文件夹内。

iter_pointer

这个函数是在LocalDisk_NLSampleSource类中定义的,其主要作用是迭代返回给定名称的数据集的指针。指针包含文件索引和每个节点在文件中的位置。

参数:

name: str 类型,需要迭代的数据集的名称。

返回:

生成器,每次迭代返回一个包含文件索引和节点在文件中的位置的元组。

举例:

source = LocalDisk_NLSampleSource(folder_path)




**for pointer in source.iter_pointer('dataset_name'):**

file_index, position = pointer

print(f"文件索引: {file_index}, 位置: {position}")

以上代码创建了一个LocalDisk_NLSampleSource对象,并使用iter_pointer函数迭代打印出'dataset_name' 数据集的每个节点指针的信息。

注意: 这个函数不会检查给定的数据集名称是否存在,如果不存在,则会抛出异常。

get_set_count

此函数用于获取指定名称的数据集的数据行数。

参数:

name: str,数据集的名称。

返回:

int,数据集的数据行数。

示例:

count = get_set_count('dataset_name')

注意:

如果提供的名称不存在于数据库中,该函数将引发异常。

get_metadata_keys

获取指定数据集的元数据键。

参数:

name (str): 要获取元数据键的数据集的名称。

返回:

dict: 返回元数据键的字典。

示例:

disk = LocalDisk_NLSampleSource(folder_path)

keys = disk.get_metadata_keys("my_dataset")

print(keys)

在这个示例中,我们首先创建一个LocalDisk_NLSampleSource对象,并指定一个文件夹路径。然后,我们调用get_metadata_keys方法,传入我们想要获取元数据键的数据集的名称。这将返回一个包含元数据键的字典。

注意: 如果数据集名称不存在,这个函数将抛出一个异常。因此,需要确保数据集名称的有效性。

print_set_info

该函数用于打印指定名称的数据集信息。

参数:

name: str, 需要打印信息的数据集的名称。

返回值:

使用方法:

首先, 我们需要一个LocalDisk_NLSampleSource类的对象, 然后调用此对象的print_set_info方法来打印数据集的信息.

例如:

data_source = LocalDisk_NLSampleSource("/path/to/your/dataset")

data_source.print_set_info('your_dataset_name')

这个函数主要用于调试和检查数据集状态, 它会打印数据集的以下信息:

  • 文件索引

  • 追加的起始位置

  • 数据开始的位置

  • 数据的计数

  • 文件的计数

  • 文件的当前计数

注意: 这个函数没有返回值, 它只是打印信息到控制台.

这个函数没有已知的错误或者bug.

delete_set

delete_set 方法是 LocalDisk_NLSampleSource 类的一个方法,用于删除具有指定名称的数据集。它首先关闭所有打开的文件对象,然后清空文件池,最后删除该数据集的文件夹。

参数:

name (str): 要删除的数据集的名称。

返回值:

此方法无返回值。

示例:

source = LocalDisk_NLSampleSource(folder_path)

source.delete_set('my_dataset')

注意:

此方法不会检查数据集是否存在,如果尝试删除不存在的数据集,将会引发异常。在调用此方法前,应先使用 has_set 方法检查数据集是否存在。

add_attachment

该方法是向指定的数据集中添加附件。

参数:

set_name (str): 指定的数据集名称。

key: 附件的关键字,用于检索数据。

data: 要添加的数据。

返回:

无返回值。

例子:

local_disk = LocalDisk_NLSampleSource('path_to_folder')

local_disk.add_attachment('some_set', 'key1', 'some_data')

注意:

在使用此方法之前,需要确保数据集已经存在,否则会抛出异常。

错误和bug:

暂无已知错误或bug。

read_attachment

这个函数是用于读取集合(或称为 set)的附件信息的。

参数:

set_name: str类型, 它是你想要读取附件的集合的名字。

返回:

返回一个列表,列表中包含集合的所有附件信息。

用法:

read_attachment('text_set')

注意:

如果集合不存在或者集合没有附件,会引发异常。

NLSampleSourceBase

这是一个基于abc.ABCMeta元类的NLSampleSourceBase抽象类。这个类定义了一个样本源的通用接口,包括创建新的样本集,检查样本集是否存在,添加行数据,获取元数据键,获取目录列表,迭代数据,迭代指针,删除样本集,加载指针数据,获取样本集的数量,添加附件,读取附件,读取一个行数据等方法。

此抽象类的目的是为了定义一个统一的样本源接口,不同的数据源可以实现这个接口,提供统一的访问方式。这对于大型项目中处理各种各样的数据源非常有用,可以大大提高代码的可维护性和扩展性。

以下是一个可能的使用例子:

**class MySampleSource(NLSampleSourceBase):**




**def create_new_set(self, name: str, description: str, tags: [str], keys: [str], base_set="") -> bool:**

\# 在这里实现你的逻辑

\# 实现其他的抽象方法...

注意,由于这是一个抽象类,你不能直接实例化它。你应该创建一个新的类,继承这个抽象类,并实现所有的抽象方法。

此类没有已知的错误或bug。

create_new_set

create_new_set是一个抽象方法,需要在子类中实现。这个方法的目的是创建一个新的数据集。

参数:

name (str): 新的数据集的名称。

description (str): 新的数据集的描述。

tags (str): 新的数据集的标签列表。

keys (str): 新的数据集的关键字列表。

base_set (str, optional): 基础数据集的名称,这个参数默认为空字符串。

返回:

bool: 如果新的数据集成功创建则返回True,否则返回False。

用法示例:

class MySampleSource(NLSampleSourceBase):

def create_new_set(self, name, description, tags, keys, base_set=""):

# 实现创建新数据集的逻辑

pass

source = MySampleSource()

source.create_new_set("my_new_set", "This is a new set.", ["tag1", "tag2"], ["key1", "key2"])

注意:

这个方法是抽象方法,不能直接调用,需要在子类中实现。如果在子类中没有实现这个方法,将会在运行时抛出NotImplementedError异常。

has_set

此方法用于检查是否存在指定名称的数据集。

参数:

name (str): 要检查的数据集的名称。

返回:

bool: 如果数据集存在则返回True,否则返回False。

示例:

nl_sample_source = NLSampleSourceBase()




**if nl_sample_source.has_set("my_dataset"):**

print("Dataset exists.")




**else:**

print("Dataset does not exist.")

add_row

此函数是抽象基类NLSampleSourceBase的一个抽象方法,需要在子类中实现。其主要目的是向给定名称的数据集中添加一行数据。

参数:

name: str,数据集的名称。

data: list,要添加的数据列表。

返回值:

bool,如果数据添加成功,则返回True,否则返回False。

示例:

class SampleSource(NLSampleSourceBase):

data_dict = {}

def add_row(self, name: str, data: []) -> bool:

if name in self.data_dict:

self.data_dict[name].append(data)

return True

return False

source = SampleSource()

source.add_row('dataset1', ['data'])

注意:

子类必须实现此方法,否则在实例化子类并调用此方法时会抛出TypeError。

get_metadata_keys

根据提供的名称,获取元数据键的抽象方法。

参数:

name (str): 数据集的名称。

返回:

dict: 返回一个字典,字典的键是元数据的键,值是具体的元数据内容。

注释:

这是一个抽象方法,需要在子类中实现。在使用此方法时,请确保提供的数据集名称存在,否则可能会抛出异常。

示例:

source = NLSampleSourceBaseSubClass() # 假设NLSampleSourceBaseSubClass是NLSampleSourceBase的子类

meta_keys = source.get_metadata_keys("example_set")

print(meta_keys)

{'key1': 'value1', 'key2': 'value2'}

get_dir_list

get_dir_list方法是一个抽象方法,需要在子类中实现。该方法的主要目的是获取所有数据集合的目录列表。

返回值:

返回一个字典。字典的键是数据集合的名称,值是这些数据集合的相关信息。相关信息可能包括但不限于数据集合的元数据和其他相关信息。

iter_data

这是一个抽象方法。子类需要实现这个方法以便提供对特定数据集的迭代访问。

参数:

name (str): 数据集的名称。

返回:

生成器: 返回一个可以用于迭代数据集每一行的生成器。

示例:

class MySampleSource(NLSampleSourceBase):

...

def iter_data(self, name: str):

with open(name, "r") as file:

for line in file:

yield line.strip().split(",")

...

my_source = MySampleSource()

for row in my_source.iter_data("my_dataset"):

print(row)

注意: 这个方法只是一个接口,具体实现应该由子类提供。如果子类没有提供特定的实现,那么调用这个方法时会抛出 NotImplementedError 异常。

iter_pointer

该函数是一个抽象方法,需要在子类中实现。它的主要目的是迭代并返回一个特定集合的指针。

参数:

name (str): 需要迭代的集合的名称。

返回:

此方法的返回类型依赖于具体实现,但通常它会返回一个可迭代的对象,例如列表或生成器。

注意:

由于这是一个抽象方法,它自身并不执行任何操作。具体的行为将在子类中定义。请参照子类的文档以获取更具体的使用细节和例子。

异常:

如果集合不存在,或者name参数不是字符串,那么实现此函数的子类可能会抛出异常。具体的行为和异常类型依赖于具体的实现。

delete_set

删除指定的数据集

此方法将删除具有给定名称的数据集。此操作是不可撤销的。

参数:

name (str): 要删除的数据集的名称。

返回:

此方法没有返回值。

注意:

此方法不会检查数据集是否存在,如果尝试删除不存在的数据集,它可能会引发错误。

例子:

sample_source = NLSampleSourceBase()

sample_source.delete_set('my_dataset')

此代码将删除名为'my_dataset'的数据集。如果此数据集不存在,将会引发一个错误。

可能的错误:

如果尝试删除不存在的数据集,将会引发一个错误。

注意:

由于此操作无法撤销,因此在调用此方法之前, 应该要小心确认是否真的需要删除此数据集。

load_pointer_data

这是一个抽象方法,由子类实现。该方法的目的是从给定名称的数据集中加载指定的数据。

Args:

name: str 类型,表示数据集的名称。

pointer: 指向数据集中特定数据的指针。

Raises:

由于这个函数是由子类实现的,因此可能会抛出任何类型的异常。具体的异常类型和处理方式依赖于子类的实现。

Returns:

此函数的返回类型取决于子类的具体实现。通常,它应该返回从数据集中加载的数据。

注意:

这是一个抽象方法,必须在子类中重写。如果在没有重写的情况下调用,将会引发 NotImplementedError。

示例:

以下示例假设我们有一个名为 MySampleSource 的子类,它实现了 load_pointer_data 方法。





**class MySampleSource(NLSampleSourceBase):**




**def load_pointer_data(self, name: str, pointer):**

return my_dataset[name][pointer]

source = MySampleSource()

data = source.load_pointer_data('my_dataset', 123)

get_set_count

该函数的主要目标是获取指定数据集的数量。

参数:

name: 字符串,指定的数据集的名称。

返回:

返回整数,表示指定数据集中的数据条目数。

使用示例:

示例代码:

nls = NLSampleSourceBase()

count = nls.get_set_count('dataset_name')

在这个示例中,我们首先实例化了一个NLSampleSourceBase对象,然后使用get_set_count方法来获取名为'dataset_name'的数据集的条目数量。

注意:

由于这是一个抽象方法,因此具体的实现可能会根据不同的子类变化。

在使用此方法时,需要确保数据集的名称是存在的,否则可能会抛出异常。

add_attachment

该函数是用于向指定的数据集添加附件。

参数:

set_name (str): 数据集的名称。

key: 附件的键值。该键值应该是可哈希且可用作字典的键。

data: 要添加的附件数据。

返回类型:

此方法没有返回值。

注意:

此方法可能会抛出由于无法找到指定的数据集或者无法添加附件数据导致的异常。

使用示例:

add_attachment("my_dataset", "my_key", my_data)

这会将my_data作为附件添加到名为"my_dataset"的数据集中,可以使用"my_key"作为键来获取这个附件。

此函数假定set_name和key不会为None,如果为None则可能会引发错误。

read_attachment

这是一个抽象方法,需要在子类中实现。该方法的功能是读取指定数据集的附件。

参数:

set_name (str): 要读取附件的数据集名称。

返回:

该方法的返回类型取决于具体实现,一般情况下应返回附件的数据。

注意:

由于这是一个抽象方法,如果在子类中没有实现该方法,那么在调用该方法时将会抛出 NotImplementedError 异常。

示例:

**class MySampleSource(NLSampleSourceBase):**




**def read_attachment(self, set_name: str):**

\# 实现具体的读取附件的逻辑,比如从硬盘中读取某个文件

return read_file(set_name)

read_one_row

该抽象方法用于读取指定数据集的一行数据。

Args:

set_name (str): 指定数据集的名称

Returns:

该方法应返回一个列表,列表元素为数据集中的一行数据。数据类型可能包括但不限于整型,浮点型,字符串型等。

Raises:

NotImplementedError: 如果该方法在子类中没有被实现,则会抛出此异常。

注意:

这是一个抽象方法,需要在子类中实现。具体的实现会根据数据存储和获取的方式有所差别,比如可能会从数据库中获取数据,也可能会从文件中读取数据等。因此,具体的实现需要根据实际的数据存储方式来决定。

arrange_dir_list

这是一个方法,用于整理目录列表。

这个方法首先通过调用'get_dir_list()'方法获取目录列表。然后,它会创建一个新的字典,其中包含目录列表中每个键的元数据、子项、数量和基础集。

接着,这个方法会遍历新字典中的每个项目,检查每个项目的基础集是否为空。如果基础集不为空,并且基础集在新的字典中,那么这个项目会被添加到其基础集的子项中。否则,这个项目将被置顶输出。

最后,方法将返回新字典中基础集为空的所有项目。

参数:

返回:

一个字典,其中包含新字典中基础集为空的所有项目。

注意:

尽管这个方法包含一段被注释掉的代码,但是这段代码基本上是将新的字典打印出来。这可以用于调试或检查新的字典。

错误和bug:

print_markdown_arrange_dir_list

此方法用于以markdown格式打印并整理目录列表。

参数:

path (str): 输出markdown文件的路径,默认为None,此时会创建一个名为"data_list.md"的文件。

max_length (int): markdown文档中,每行最大的字符长度。

返回:

None

此函数首先获取目录列表,然后根据目录列表生成一个新的字典,字典中的每个元素包括:

  • "meta": 元数据

  • "children": 子目录

  • "count": 数据集中的记录数

  • "row_sample": 示例行

对于每个键,函数都会读取一行数据作为示例,并打印其在markdown文件中的相关信息。

注意:此函数可能会覆盖已存在的文件,因此在使用时请确认文件路径的正确性。

flush

这个是一个抽象方法,具体的实现应由继承该基础类的子类进行定义。

方法简介:flush()方法的主要目的是清空或同步缓冲数据。一般在进行文件读写操作或数据库操作时,有可能会先把数据写入到缓冲区,等缓冲区满了或手动调用flush()方法时,才真正的将数据写入到文件或数据库。对于一些需要立即看到写入效果的操作,可能需要在写入后立即调用此方法。

参数列表:该方法不接受任何参数。

返回类型:由于这是一个抽象方法,所以没有具体的返回值类型,具体的返回值类型由子类实现的flush()方法决定。

使用示例:由于这是一个抽象方法,没有具体的使用例子。但是一般在子类中,可以按照如下方式进行实现和使用:

class ChildClass(NLSampleSourceBase):

def flush(self):

# 实现flush操作

print('flush')

child_class = ChildClass()

child_class.flush() # 输出: flush

注意: 这是一个抽象方法,如果子类没有实现这个方法,那么在实例化子类时,Python解释器会抛出TypeError的错误。

Memory_NLSampleSource

Memory_NLSampleSource是一个继承自NLSampleSourceBase基类的类,用于处理和存储数据样本。

类的主要方法包括创建新的数据集、检查数据集是否存在、添加数据行、获取元数据键、获取目录列表、遍历数据和指针、删除数据集、加载指针数据、获取数据集数量以及处理附件。

此类主要在内存中维护一个数据字典,对数据集的各种操作都是对这个字典的操作。使用前需要先实例化对象,再调用相应的方法。

主要方法说明:

  • \_\_init\_\_: 初始化方法,创建空的数据字典。

  • create_new_set: 创建新的数据集,如果数据集已存在则抛出异常。返回操作是否成功的布尔值。

  • has_set: 检查指定名称的数据集是否存在。返回布尔值。

  • add_row: 向指定数据集添加一行数据。返回操作是否成功的布尔值。

  • get_metadata_keys: 获取指定数据集的元数据键。返回一个包含元数据的字典。

  • get_dir_list: 获取所有数据集的目录列表。返回一个包含所有数据集的元数据和数量的字典。

  • iter_dataiter_pointer: 遍历指定数据集的数据和指针。返回一个迭代器。

  • delete_set: 删除指定的数据集。

  • load_pointer_data: 加载指定数据集的指针数据。返回加载的数据。

  • get_set_count: 获取指定数据集的数据数量。返回数据数量。

  • add_attachmentread_attachment: 操作附件。但当前版本不支持,调用时会抛出异常。

  • read_one_row: 读取指定数据集的第一行数据。返回数据。

注意:此类在操作附件时会抛出异常,因为当前版本还不支持。

使用示例:

mem_source = Memory_NLSampleSource()

mem_source.create_new_set('set1', 'description', ['tag1', 'tag2'], ['key1', 'key2'])

mem_source.add_row('set1', ['data1', 'data2'])

print(mem_source.get_dir_list())

__init__

这是一个基于内存的样本源类,实现了NLSampleSourceBase中定义的接口。

该类提供一种快速加载和处理内存中数据的方式,可以创建新的数据集,检查数据集是否存在,向数据集中添加行,获取数据集的元数据等操作。

属性:

  • datas(dict): 存储数据的字典,键是数据集的名字,值是一个字典,包含了数据集的各种元数据信息以及数据本身。

使用示例:


\# 创建一个内存样本源对象

mem_source = Memory_NLSampleSource()

\# 创建一个新的数据集

mem_source.create_new_set('dataset1', 'this is a test dataset', ['tag1', 'tag2'], ['key1', 'key2'])

\# 检查一个数据集是否存在

print(mem_source.has_set('dataset1'))  \# 输出: True

\# 向一个数据集中添加行

mem_source.add_row('dataset1', ['data1', 'data2'])

\# 获取一个数据集的元数据

print(mem_source.get_metadata_keys('dataset1'))  \# 输出: {'des': 'this is a test dataset', 'tags': ['tag1', 'tag2'], 'label_keys': ['key1', 'key2'], 'base_set': '', 'base_set_process': ''}

\# 获取所有数据集的列表

print(mem_source.get_dir_list())  \# 输出: {'dataset1': {'meta': {'des': 'this is a test dataset', 'tags': ['tag1', 'tag2'], 'label_keys': ['key1', 'key2'], 'base_set': '', 'base_set_process': ''}, 'count': 1}}

create_new_set

这个类的目的是模拟一个名为Memory_NLSampleSource的内存数据源。它提供一些基本的操作,如创建新的数据集、检查数据集是否存在、添加行、获取元数据键等。这个类最主要的用途是进行数据集的管理和数据操作。

函数 create_new_set 是用来创建新的数据集的。它需要以下输入参数:

  • name: str 类型,用于指定新数据集的名称。

  • description: str 类型,用于描述新的数据集。

  • tags: str 类型的列表,用于给新的数据集添加标签。

  • keys: str 类型的列表,用于指定新的数据集的键。

  • base_set: str 类型,是一个可选参数,默认为空字符串,用于指定新的数据集的基础集。

  • base_set_process: str 类型,是一个可选参数,默认为空字符串,用于指定新的数据集的基础集处理方式。

返回值为 bool 类型,如果新增数据集成功,返回True。

值得注意的是,如果在添加新的数据集时,存在名称相同的数据集,程序会抛出一个异常,提示"已存在相同的set"。

此外,这个类不支持添加附件和读取附件,如果尝试调用这两个方法,程序会抛出一个异常,提示"not support"。

has_set

该函数的主要作用是检查给定的名字是否存在于当前数据源中。

参数:

name (str): 需要查询的数据集名称。

返回:

bool: 如果数据集存在, 返回True, 否则返回False.

使用示例:

memory_sample_source = Memory_NLSampleSource()

memory_sample_source.create_new_set("test_set", "a description", ["tag1", "tag2"], ["key1", "key2"])

print(memory_sample_source.has_set("test_set"))  \# 输出: True

print(memory_sample_source.has_set("nonexistent_set"))  \# 输出: False

注意: 如果查询的数据集名称不存在, 该函数不会引发异常, 而只是返回False.

add_row

这是一个函数,名为 add_row,它的作用是在指定的数据集中添加一条新的数据。

参数:

  • name (str):数据集的名称。

  • data (list):要添加的数据,它是一个列表。

返回类型:

  • bool:如果数据成功添加到数据集中,则返回 True

示例:

\# 创建一个 Memory_NLSampleSource 实例

source = Memory_NLSampleSource()

\# 创建一个新的数据集

source.create_new_set(name='new_set', description='This is a new set', tags=['tag1', 'tag2'], keys=['key1', 'key2'])

\# 在新的数据集中添加数据

source.add_row(name='new_set', data=['data1', 'data2'])

注意:

  • 如果数据集名不存在,函数将会引发 KeyError。

  • 此函数不支持并发调用,如果多个线程同时调用此函数可能会导致数据错误。

get_metadata_keys

此函数用于获取指定数据集的元数据键及其对应的值。

参数:

name: str - 数据集的名称。

返回值:

dict - 返回一个字典,字典中的每个元素对应数据集的一种元数据,包括'描述'、'标签'、'标签键'、'基础集合'和'基础集合处理'。

示例:

get_metadata_keys("dataset_name") 返回值可能如下:

{

'des': '这是一个用于图像分类的数据集',

'tags': ['图像', '分类'],

'label_keys': ['猫', '狗', '鸟'],

'base_set': '原始数据集',

'base_set_process': '数据预处理步骤详细描述',

}

注意:

如果提供的数据集名称不存在于当前数据源中,此函数将抛出KeyError异常。

get_dir_list

get_dir_list函数是用来获取所有数据集的元数据以及每个数据集的数据数量。

遍历数据集(datas)的所有键值,对于每一个键值(x_key),

获取其元数据(通过调用get_metadata_keys方法)和数据数量(通过调用get_set_count方法)。

返回值是一个字典,键值是数据集的名称,值是一个字典,其中包含了元数据(存储在'meta'键值下)以及数据数量(存储在'count'键值下)。

函数不接受任何参数。

返回类型是一个字典,它的键是字符串类型(数据集的名称),值是一个字典,这个字典的键是字符串类型('meta'和'count'),值的类型分别是字典和整数。

例如:

假设当前的数据集是{'set1': {'name': 'set1', 'des': 'description1', 'tags': ['tag1', 'tag2'], 'label_keys': ['key1', 'key2'], 'base_set': '', 'base_set_process': '', 'data': [1, 2, 3]}}

那么这个函数的返回值将会是:

{'set1': {'meta': {'des': 'description1', 'tags': ['tag1', 'tag2'], 'label_keys': ['key1', 'key2'], 'base_set': '', 'base_set_process': ''}, 'count': 3}}

iter_data

这是一个用于迭代数据集中数据的函数。

参数:

name (str): 需要迭代的数据集的名称。

返回:

iterator: 返回一个迭代器,用于按顺序访问数据集中的所有数据。

示例:

mem_source = Memory_NLSampleSource()

mem_source.create_new_set('test', 'This is a test set', [], ['key1', 'key2'])

mem_source.add_row('test', ['value1', 'value2'])

>>> for data in mem_source.iter_data('test'):

... print(data)

['value1', 'value2']

注意:

  • 如果给定的数据集名称不在内存中,这个函数将会抛出一个KeyError。

  • 这个函数不会改变数据集或数据的状态,可以安全的多次调用。

iter_pointer

该方法作用于Memory_NLSampleSource类,主要用于生成数据集名称对应的数据索引迭代器。

参数:

name : str

数据集名称,用于指定需要生成索引迭代器的数据集。

返回:

一个生成器,会产生指定数据集中的数据索引,从0开始到数据集的长度。

使用示例:

memory_source = Memory_NLSampleSource()

memory_source.create_new_set("example_set","An example set",[],[],"")

memory_source.add_row("example_set", ["data1","data2","data3"])

pointer_iter = memory_source.iter_pointer("example_set")

for pointer in pointer_iter:

print(pointer)

上述代码中,首先创建了Memory_NLSampleSource的实例memory_source,然后创建了一个新的数据集"example_set"并添加了一些数据,然后使用iter_pointer方法生成了索引迭代器pointer_iter,最后用for循环遍历并打印出所有的索引。

注意:

该方法不会对输入的数据集名称进行检查,如果输入的数据集名称不存在于数据源中,会直接导致KeyError异常。使用时需要确保数据集名称的正确性。

delete_set

这是一个用于删除已存在的数据集的方法。

参数:

name: str类型,表示需要删除的数据集的名称。

返回:

无返回值。

例子:

delete_set('example_set')将会删除名为'example_set'的数据集。

注意:

如果传入的名称在数据集中不存在,将会引发KeyError错误。

load_pointer_data

此函数用于加载指定索引位置的数据。

参数:

name: str类型,数据集的名称

pointer: 数据索引位置

返回:

返回给定数据集中指定索引位置(pointer)的数据。

示例:

data_source = Memory_NLSampleSource()

data_source.create_new_set("test", "", [], [], "")

data_source.add_row("test", ["Hello, world!"])

data_source.load_pointer_data("test", 0)

['Hello, world!']

注意:

这个函数不会检查索引(pointer)是否在数据边界之内,如果提供的索引超过数据边界,会抛出IndexError。

get_set_count

这个函数用于获取指定数据集的数量。

参数:

name: str, 指定的数据集的名称。

返回:

int, 返回指定数据集的数量。

示例:

instance = Memory_NLSampleSource()

instance.create_new_set('name1', 'description1', ['tag1', 'tag2'], ['key1', 'key2'])

instance.add_row('name1', ['data1', 'data2'])

print(instance.get_set_count('name1'))  \# 输出1

add_attachment

这个方法是在类Memory_NLSampleSource中定义的,该类主要用于管理和存储数据集样本的相关信息,不过这个方法并未实现任何功能,尝试调用这个方法将会触发一个异常。

方法名称:add_attachment

参数:

set_name (str): 数据集的名字。

key: 该参数在当前方法中没有用到。

data: 该参数在当前方法中没有用到。

返回:

无返回值。

异常:

如果调用该方法,将会触发一个异常,提示"not support"。

注意:在当前版本中,这个方法并未实现,如果你尝试调用它,将会触发一个异常。

read_attachment

这个函数是在Memory_NLSampleSource类中定义的,用于读取指定数据集(set_name)的附加信息,但实际上并未实现这个功能,因此当调用此函数时,会抛出一个"Not support"的异常。

参数:

set_name (str): 需要读取附加信息的数据集名称。

返回:

抛出异常。

示例:

mem = Memory_NLSampleSource()

mem.read_attachment('dataset_name')

注意:

这个函数并未实现,调用时会抛出异常。

read_one_row

这个函数的目的是从给定的集合(set)中读取一行数据。

Args:

set_name: str类型,表示需要读取的数据集的名称。

Returns:

返回类型是list,返回从数据集中读取的第一行数据。

Usage:

例如,我们有一个名为'example_set'的数据集,我们可以通过以下方式读取第一行数据:

data_source = Memory_NLSampleSource()

first_row = data_source.read_one_row('example_set')

注意:

  • 如果给定的集合名称不存在,将会抛出一个KeyError异常。

  • 如果数据集为空(即没有数据行),此函数将返回一个空列表。

convert_jsonl

这个函数用于将数据源中的数据转换为jsonl文件。函数首先会从数据源中读取指定的“set_name”的样本集,然后使用“prompt_completion_fun”函数将每个样本的问题和答案转换为字典形式,最后将所有的行写入到“target_path”指定的文件中。

参数:

datasource(NLSampleSourceBase): 数据源,必须是NLSampleSourceBase类型或其子类的实例。

set_name(str): 从数据源中获取样本集的名称。

prompt_completion_fun(function): 用于将样本转换为字典形式的函数。该函数必须接受一个参数(样本),并返回一个元组,元组的第一个元素是问题,第二个元素是答案。

target_path(str): 输出jsonl文件的路径。

返回:

None

示例:

def prompt_completion(item):

return item.question, item.answer

datasource = MyDataSource()

convert_jsonl(datasource, "train", prompt_completion, "train.jsonl")

注意:这个函数不会检查“target_path”是否已经存在,如果存在,它会直接覆盖旧文件。

错误与异常:

如果数据源中不存在指定的“set_name”,函数会抛出异常。

如果“prompt_completion_fun”函数不能正确处理样本,也会抛出异常。

create_SampleSet

这是一个在内存中生成样本集并保存的函数。该函数首先创建一个新的样本集,然后在每次循环中生成随机输入样本和对应的随机标签,最后将这些样本添加到样本集中。

参数:

set_name (str): 新建样本集的名称。

input_size (int): 每个样本输入的大小。

count (int): 需要生成的样本数量。

返回值:

Memory_NLSampleSource: 包含新生成样本集的Memory_NLSampleSource对象。

示例:

创建一个名为'test_set',每个输入样本大小为10,样本数量为100的样本集:

sample_set = create_SampleSet('test_set', 10, 100)

注意:

该函数没有做参数类型和值的检查,如果传入的参数类型或值不正确,可能会抛出异常。

SSHSampleSource

SSHSampleSource类是一个用于处理SSH源文件的类,继承自LocalDisk_NLSampleSource类。这个类实现了各种对SSH源文件的操作,包括下载、更新、创建新的数据集、检查数据集是否存在、添加行、获取元数据键、迭代数据、获取远程目录列表等等。

注意:此类需要先安装paramiko库才能使用。

类的初始化函数参数介绍:

  • folder_path: 本地文件夹路径

  • endpoint: SSH服务器的IP地址或者主机名

  • access_user: SSH登录的用户名

  • secret_pwd: SSH登录的密码

  • target_path: SSH服务器上的目标文件夹路径

  • port: SSH服务器的端口,默认为22

使用示例:

ssh_source = SSHSampleSource('/local/path', 'ssh.server.com', 'user', 'password', '/remote/path')

ssh_source.download('dataset_name')

此类可能存在的问题:

  • 对于大文件的同步可能会有性能问题

  • 当SSH服务器连接问题时,可能会出现异常

  • 使用的SSH连接库paramiko没有对并发做优化,可能会有并发问题

instance_default

这是一个staticmethod,命名为instance_default的类方法。这个方法主要用于获取配置信息,并依据这些配置信息创建一个SSHSampleSource实例。

这个方法的工作流程是:

  1. 通过调用get_config_instance().get_config("ssh_samplesource_xxx")函数,获取必要的SSH连接参数,包括folder_path, endpoint, access_user, access_pwd和access_target_path等。

  2. 使用上述获取的参数创建并返回一个SSHSampleSource实例。

返回类型:

该方法返回一个SSHSampleSource类的实例。

使用示例:

sample_source = SSHSampleSource.instance_default()

注意事项:

在使用这个方法的过程中需要注意,所有的配置信息都需要在应用的配置文件中进行预设,并且这个方法在读取配置信息的时候不会进行任何的错误处理,所以如果配置信息不存在或者格式错误,都会导致程序运行错误。

__init__

MinioSampleSource

MinioSampleSource类是一个继承于LocalDisk_NLSampleSource的子类,主要用于实现与Minio服务端交互的操作,如上传数据、下载数据、检查数据是否存在等。Minio是一个开源的对象存储服务器,兼容亚马逊S3云存储服务接口,可以用于存储非结构化数据如照片、视频、日志文件、备份数据等。

此类的构造函数接收Minio服务端的端点、访问密钥、秘密密钥和桶名作为输入参数。此外,它还提供了用于数据集的创建、检查和下载的方法。

以下是使用此类的一些例子:

folder_path = 'my_folder'

endpoint = 'my-minio-endpoint'

access_key = 'my-access-key'

secret_key = 'my-secret-key'

bucket_name = 'my-bucket'

\# 创建一个新的MinioSampleSource实例

source = MinioSampleSource(folder_path, endpoint, access_key, secret_key, bucket_name)

\# 检查一个数据集是否存在

exists = source.has_set('my-dataset')

\# 如果不存在,则创建一个新的数据集




**if not exists:**

source.create_new_set('my-dataset', 'A description of my dataset', ['tag1', 'tag2'], ['key1', 'key2'])

\# 添加数据到数据集

data = {'key1': 'value1', 'key2': 'value2'}

source.add_row('my-dataset', data)

\# 下载数据集

source.download('my-dataset')

注意:该类的某些方法可能会抛出异常,使用时需要加以处理。

此类的方法列表如下:

  • \_\_init\_\_(self, folder_path, endpoint, access_key, secret_key, bucket_name): 构造函数,创建一个新的MinioSampleSource实例。

  • _join(self, *args): 私有方法,用于连接多个字符串并用'/'分隔。

  • _get_minio_client(self): 私有方法,获取Minio客户端对象。

  • _download_if_not_exsited(self, name: str): 私有方法,如果本地没有指定的数据集,则从Minio服务器下载。

  • _object_exsited(self, name): 私有方法,检查指定的数据集是否在Minio服务器上存在。

  • update(self, set_name=None): 更新指定的数据集,如果没有指定,则更新所有数据集。

  • create_new_set(self, name: str, description: str, tags: [str], keys: [str], base_set="", base_set_process=""): 创建一个新的数据集。

  • has_set(self, name: str): 检查指定的数据集是否存在。

  • add_row(self, name: str, data): 向指定的数据集添加一行数据。

  • get_metadata_keys(self, name: str): 获取指定数据集的元数据键。

  • iter_data(self, name: str): 返回一个迭代器,用于在指定的数据集中迭代数据。

  • get_remote_dir_list(self): 获取远程目录的列表。

  • download(self, name): 下载指定的数据集。

  • read_one_row(self, name: str): 读取指定数据集的一行数据。

请注意,本类需要有对指定Minio服务端的访问权限,并且在使用Minio服务端时,需要遵守其使用协议和条款。

__init__

这个类是用于处理Minio存储桶的接口类,继承了LocalDisk_NLSampleSource类,用于实现基于本地磁盘的样本源接口。通过此类,可以实现样本的读取、上传、下载等操作。并且,可以确保数据的一致性和完整性。

该类的初始化方法接受以下五个参数:

  • folder_path:本地的文件夹路径,用于存储从Minio存储桶下载的样本数据,或者需要上传到Minio存储桶的样本数据。

  • endpoint:Minio服务的URL,例如:"http://localhost:9000"。

  • access_key:用于访问Minio服务的Access Key。

  • secret_key:用于访问Minio服务的Secret Key。

  • bucket_name:Minio存储桶的名称,用于存储和获取数据。

初始化方法将创建Minio客户端对象,此对象在之后的方法中将用于与Minio服务进行交互。

实例化该类的例子如下:

folder_path = "/path/to/local/storage"

endpoint = "http://localhost:9000"

access_key = "your-access-key"

secret_key = "your-secret-key"

bucket_name = "your-bucket-name"

minio_sample_source = MinioSampleSource(

folder_path,

endpoint,

access_key,

secret_key,

bucket_name

)

注意:此类不负责Minio服务的启动和关闭,这些操作需要在实例化类之前/之后手动完成。

_join

这个方法的主要目的是把传入的多个字符串参数使用'/'连接起来。

Args:

*args: 一个或多个字符串参数。

Returns:

返回一个新的字符串,该字符串是由传入的各个字符串使用'/'连接而成。

Example:

_join("home", "user", "documents")

'home/user/documents'

_get_minio_client

这是一个获取Minio客户端的方法。

根据类中定义的endpoint、access_key、secret_key和bucket_name去初始化一个Minio的客户端。 如果客户端未被初始化或者是第一次调用,就会创建一个新的客户端。然后检查这个客户端所连接的bucket是否存在,如果不存在就会抛出一个异常。

Args:

self: 类的实例。

Returns:

返回一个已经初始化并且连接成功的Minio客户端。

Raises:

Exception: 如果bucket不存在的话,就会抛出异常。

_download_if_not_exsited

此方法主要用于检查及下载指定的文件集。

在我们的本地磁盘中,每个文件集都对应一个文件夹,此方法主要检查指定的文件集(名字为参数name)是否已经存在于本地磁盘。如果已经存在,那么就不再执行任何操作。如果不存在,则通过Minio客户端从远程Minio服务器下载该文件集,并保存到本地磁盘的指定文件夹中。

参数:

name: str,文件集的名称。

返回值:

无返回值。

例子:

假设我们的本地磁盘中没有名为'sample'的文件集,那么我们可以通过如下代码来下载它:

minio_source = MinioSampleSource(folder_path, endpoint, access_key, secret_key, bucket_name)

minio_source._download_if_not_exsited('sample')

注意事项:

如果远程Minio服务器中也没有名为'sample'的文件集,那么此方法会抛出一个异常。

错误和异常:

如果指定的文件集在远程Minio服务器中不存在,那么此方法会抛出一个异常。

_object_exsited

此函数用于检查指定名字的对象是否存在于Minio服务器的桶中。

参数:

name (str): 需要检查的对象的名字。

返回:

bool: 如果存在则返回True,否则返回False。

异常:

任何未处理的异常都将被捕获并返回False。

示例:

**def test_object_existed(self):**

\# 假设我们有一个已经初始化并配置过的MinioSampleSource对象

minio_source = MinioSampleSource(...)

\# 检查名为"my_object"的对象是否存在




**if minio_source._object_exsited("my_object"):**

print("my_object exists in the bucket.")




**else:**

print("my_object does not exist in the bucket.")

注意:

这是一个内部方法,通常不应直接调用。

update

此函数用于更新存储桶(bucket)中的对象(object),可以选择性地只更新特定的集合(set)。

参数:

set_name : str, 可选

要更新的集合名。如果未设置,将更新所有集合。

返回:

无返回值

使用范例:

source = MinioSampleSource(folder_path="my_folder", endpoint="my_endpoint", access_key="my_access_key",

secret_key="my_secret_key", bucket_name="my_bucket")

source.update(set_name="my_set")

注意:

  • 更新过程中,会首先判断本地文件和存储桶中的文件是否一致,只有在文件内容或大小发生变化时,才会上传新文件。

  • 如果存储桶中不存在要求的集合名,会抛出异常。

  • 本方法可能会消耗较大的网络流量和磁盘空间。

create_new_set

创建新的数据集

此方法是用来在Minio存储服务上创建一个新的数据集。此数据集包含描述,标签,键,基础数据集和基础数据集的处理方法。

参数:

name (str): 数据集的名称

description (str): 数据集的描述

tags (str): 数据集的标签列表

keys (str): 数据集的键列表

base_set (str, 可选): 基础数据集的名称. 默认为空字符串.

base_set_process (str, 可选): 基础数据集的处理方法. 默认为空字符串.

返回:

bool: 如果数据集成功创建,返回True

错误:

如果已经存在同名的数据集,将会抛出异常。

示例:

create_new_set("exampleSet", "This is an example set", ["exampleTag"], ["exampleKey"], base_set="baseSet", base_set_process="process")

has_set

该函数用于检查指定的集合名称是否在本地或远程存储桶中存在。

参数:

name: str - 需要检查的集合名称。

返回:

bool - 如果集合存在于本地或远程存储桶中,则返回True,否则返回False。

使用方法:

minio_source = MinioSampleSource(folder_path, endpoint, access_key, secret_key, bucket_name)

if minio_source.has_set("sample_set"):

print("集合存在")

else:

print("集合不存在")

注意:

  • 该函数会首先在本地检查集合是否存在,如果不存在,才会去远程存储桶检查。

  • 如果在检查远程存储桶时发生网络等错误,该函数可能会抛出异常。

add_row

此函数的功能是在指定的数据集中添加一行新的数据。

参数:

name: str, 数据集的名称。

data: 需要添加的数据。

返回:

bool,如果数据成功添加则返回True,否则返回False。

使用方法:

add_result = add_row('data_set_name', data)

注意:

在执行这个函数之前,首先会检查是否已经在本地存在此数据集,如果不存在,会从minio服务器上下载对应的数据集。

在添加数据之后,数据会被持久化保存在本地磁盘上。

get_metadata_keys

get_metadata_keys方法主要用于获取指定数据集的元数据键。

参数:

name (str): 指定的数据集名称。

返回:

dict: 返回一个字典,其中包含了指定数据集的所有元数据键。

使用示例:

minio_sample_source = MinioSampleSource(folder_path, endpoint, access_key, secret_key, bucket_name)

metadata_keys = minio_sample_source.get_metadata_keys('sample_dataset')

print(metadata_keys)

注意:

如果数据集在本地不存在,该方法将会先从远程服务器下载数据集到本地,然后再获取元数据键。

如果数据集在本地和远程服务器都不存在,将会抛出异常。

iter_data

此函数为迭代器, 用于迭代返回一个名称为"name"的数据集的所有数据。

Args:

name (str): 数据集的名字

Yields:

iter: 指向数据集中下一个元素的迭代器

举例:

假设有一个名为"sample_set"的数据集,使用方法如下:

source_connection = MinioSampleSource(folder_path, endpoint, access_key, secret_key, bucket_name)

for data in source_connection.iter_data("sample_set"):

# 在这里处理数据

print(data)

注意事项:

如果数据集不存在,或者数据集为空,那么这个函数将会返回一个空的迭代器。

get_remote_dir_list

get_remote_dir_list 是一个方法,用于获取远程目录列表,并提供有关这些目录的信息,包括元数据、计数和文件计数。

这个方法没有接收任何参数,返回的是一个字典,字典中的每个键是远程目录的名称,值是一个字典,包含'meta'(元数据节点),'count'(该目录中的对象数量),'filecount'(该目录中的文件数量)。

该方法首先初始化一个minio客户端,然后列出存储桶中的所有对象。对于每个对象,它会读取并解析相关的头部信息,包括基本头部和节点信息,然后将这些信息存储在一个字典中,并作为结果返回。

注意:这个方法在处理大量目录时可能会需要较长的时间,因为它需要发送网络请求来获取每个目录的信息。

示例:

minio_sample_source = MinioSampleSource(参数省略)

remote_dir_list = minio_sample_source.get_remote_dir_list()




**for dir_name, info in remote_dir_list.items():**

print(f"Directory: {dir_name}, Meta: {info['meta']}, Count: {info['count']}, File count: {info['filecount']}")

上述代码示例中,我们首先创建了一个MinioSampleSource对象,然后调用其get_remote_dir_list方法来获取远程目录列表。对于获取的每一个目录,我们都打印出了目录的名称,元数据,计数和文件计数。

注意: 如果minio服务器的响应时间过长,或者网络连接不稳定,这个方法可能会抛出异常。确保在调用这个方法时具有稳定的网络连接,并准备好处理可能出现的异常。

download

这个函数的目的是从远程存储下载数据集。

Args:

name (str): 需要下载的数据集的名称。

这个函数没有返回值。

使用方法如下:

minio_sample_source = MinioSampleSource(folder_path, endpoint, access_key, secret_key, bucket_name)

minio_sample_source.download('your_dataset_name')

在以上代码中,你需要提供有效的MinIO连接信息以及你想要下载的数据集的名称。函数会检查本地是否已有此数据集,如果没有,则从远程MinIO存储中下载。

请注意,这个函数会依赖MinIO API进行操作,如果在MinIO上不存在指定的数据集,或者MinIO服务器无法连接,函数会抛出异常。

read_one_row

此函数用于从指定的Minio云存储中读取一行数据。该函数首先生成需要的初始偏移量然后从远程数据对象中获取这些偏移量,接着读取一行数据的长度和实际长度,最后在按照实际长度从远程对象中读取数据。

参数:

name (str): 这是一个字符串类型的参数,用于指定待读取数据的集合名称。

返回:

返回从Minio云存储中读取的一行数据。

注意:

  1. 函数在执行过程中可能会遇到网络问题,导致从Minio云存储读取数据失败。需要在使用时处理这种可能的异常情况。

  2. 函数只读取一行数据,如果需要读取多行数据,需要多次调用该函数。

  3. 由于涉及到网络IO,函数的执行时间可能会比较长,需要在调用时考虑到性能问题。

示例:

# 创建MinioSampleSource对象

sample_source = MinioSampleSource(folder_path, endpoint, access_key, secret_key, bucket_name)

# 读取一行数据

row_data = sample_source.read_one_row("sample_set")

# 打印读取的数据

print(row_data)

SampleSet

这是一个名为SampleSet的类,主要用于管理和处理样本数据。

SampleSet类通过与样本来源交互,可以从样本来源获取所需的样本数据,同时还可以对样本数据进行各种操作,如shuffle(打乱顺序)、take(获取一定数量的样本)、skip(跳过一定数量的样本)、batch(将样本分批次获取)和func(对样本进行自定义操作)。

类的属性包括:

  • sample_source: 样本来源,用于获取样本数据

  • set_name: 样本集名称,用于标识当前的样本集

  • count: 样本数量

类的方法包括:

  • __init__: 初始化类,设置样本来源和样本集名称

  • __iter__: 迭代器方法,返回当前类的迭代器

  • sample_source: 返回样本来源

  • set_name: 返回样本集名称

  • count: 返回样本数量

  • _base_iter: 基础迭代器,用于在迭代过程中管理样本数据的获取和处理

  • shuffle: 打乱样本顺序

  • take: 获取一定数量的样本

  • skip: 跳过一定数量的样本

  • batch: 将样本分批次获取

  • func: 对样本进行自定义操作

示例:

# 创建一个SampleSet对象

sample_set = SampleSet(sample_source, 'train')

# 对样本进行打乱顺序

sample_set.shuffle()

# 获取10个样本

sample_set.take(10)

# 跳过5个样本

sample_set.skip(5)

# 将样本分批次获取,每批次5个样本

sample_set.batch(5)

# 对样本进行自定义操作

sample_set.func(lambda x: x * 2)

注意: 该类没有显著的错误或者bug,但在处理大规模数据时可能需要考虑内存和性能问题。

__init__

SampleSet类是一个用于管理和处理样本集的类。它提供了各种方法,如shuffletakeskipbatchfunc,以对样本集进行操作,如打乱样本、取样本、跳过样本、分批样本以及应用函数到样本。此外,它还保存了有关样本集的一些信息,如样本源、集合名称以及样本的数量等。

此类的初始化方法定义如下:

**def \_\_init\_\_(self, source_base: NLSampleSourceBase, set_name: str):**

## 参数

  • source_base (NLSampleSourceBase): 样本源的基类对象,用于获取样本数据和样本元数据等信息。

  • set_name (str): 样本集的名称。

## 属性

  • _sample_source (NLSampleSourceBase): 存储source_base参数,表示样本源。

  • _set_name (str): 存储set_name参数,表示样本集的名称。

  • _shuffle (bool): 初始化为False,表示样本是否需要打乱。

  • batch_count (int): 初始化为None,表示每批样本的数量。

  • _data_keys (list): 存储从样本源获取的样本元数据键。

  • _iter_keys (list): 初始化为空列表,用于存储迭代的键。

  • _loaded_pointer (bool): 初始化为False,表示是否已加载指针。

  • _func (list): 初始化为只包含_base_iter函数的列表,用于存储需要对样本进行的函数操作。

  • _count (int): 存储从样本源获取的样本集的数量。

## 使用示例

from nl_sample_source_base import NLSampleSourceBase

\# 创建样本源

sample_source = NLSampleSourceBase()

\# 创建样本集

sample_set = SampleSet(sample_source, 'train')

\# 打乱样本

sample_set.shuffle()

\# 取100个样本

sample_set.take(100)

\# 跳过10个样本

sample_set.skip(10)

\# 分成每批10个样本

sample_set.batch(10)

\# 对每个样本应用函数

sample_set.func(lambda x: x**2)

__iter__

该方法是一个特殊的迭代器方法,允许SampleSet对象进行迭代操作。当Python执行for...in...循环时,如果在for后面的对象是一个迭代器,那么Python将会自动调用这个方法。

返回:

返回一个函数对象,该函数对象通过调用self._func-1方法得到。这意味着,每次迭代都将使用self._func列表中的最后一个函数进行。根据SampleSet类中其他方法对self._func的操作,这个函数可能实现了一系列的数据操作,比如取样、跳过、批处理等。

例子:

假设我们有一个SampleSet对象sampleset,我们可以这么使用这个迭代器方法:





**for sample in sampleset:**

print(sample)

在这个例子中,每次迭代都会打印出一个样本。

警告:

在多线程环境中,由于self._func-1返回的函数对象可能会被其他线程修改,因此这个迭代器方法可能会产生预期之外的结果。为了避免这种情况,建议在单线程环境下使用这个方法,或者使用线程锁确保访问self._func的原子性。

sample_source

这是一个property函数,用于返回_sample_source属性。

返回:

_sample_source(NLSampleSourceBase类型): 返回提供样本集信息及访问的基本对象。

set_name

获取当前样本集的名称

这是一个简单的getter方法,不接受任何参数,返回结果是一个字符串,代表当前样本集的名称。

返回:

返回当前样本集的名称(set_name)。

count

count 是一个方法,用于获取 SampleSet 实例中的样本数量。

此方法无需任何输入参数。

返回值:

返回一个整数,表示 SampleSet 实例中的样本数量。

示例用法:

sample_set = SampleSet(source_base, set_name)

num_samples = sample_set.count()

此示例说明如何创建 SampleSet 类的实例并使用 count 方法获取样本数量。

不含已知错误或 bug。

_base_iter

_base_iter是一个私有生成器方法,主要用于数据集的迭代和随机打乱。

首次调用时,此方法会加载数据集的迭代指针并存储在_iter_keys列表中,其后的调用则直接从这个列表中读取。

如果_shuffle属性为True,则会对_iter_keys列表中的元素进行随机打乱。

每次迭代时,该方法都会根据当前的指针,从数据源中加载对应的数据,并以字典的形式返回。

注意,此方法是一个私有方法,仅供内部使用。

返回:

返回一个生成器,每次迭代返回一个含有数据的字典。

示例:

**for data in self._base_iter():**

process(data)

shuffle

此函数用于启用样本混洗功能。

函数shuffle(self)将实例变量self._shuffle标记为True,表示在生成器函数_base_iter(self)中,将对数据指针的迭代顺序进行随机打乱。

这个函数没有任何参数。

返回的是包含此函数的类的实例,即self,这样可以实现函数链式调用。

示例:

假设sample_set是SampleSet类的一个实例,

那么我们可以这样调用此函数:sample_set.shuffle()。

在这之后,当我们从sample_set中取出样本时,样本的顺序就会被随机打乱。

注意:这个函数必须在生成样本之前调用,如果在生成样本之后调用,将不会有任何效果。

在实际使用中,我们通常会这样使用:

sample_set.shuffle().batch(128)

这样可以先打乱样本的顺序,然后按照每128个样本为一组,进行分组操作。

take

take是一个实例方法,用于从样本集中获取指定数量的样本。

参数:

  • count: int

指定要从样本集中获取的样本数量。

返回:

  • SampleSet类的实例。通过调用此方法,可以在类实例上应用链式操作。

示例:

sample_set = SampleSet(source_base, set_name)

sample_set = sample_set.take(5)  \# 从样本集中获取5个样本




**for sample in sample_set:**

print(sample)

注意:

  • 此函数会更改内部计数器self._count的值,以反映从样本集中获取的样本数量。

  • 这个函数不会立即执行获取样本的操作,而是在迭代样本集时才会真正获取样本。这是通过在函数内部创建并添加一个新的生成器函数到self._func列表实现的。

  • 当请求的样本数量大于样本集中的可用样本数量时,此函数会将self._count设置为样本集的大小。

skip

skip函数用于在数据集中跳过指定数量的样本。

参数:

count (int): 需要跳过的样本数量。

返回:

SampleSet: 返回当前的SampleSet实例,便于链式操作。

用法示例:

sample_set = SampleSet(source_base, set_name)

sample_set.skip(10) # 跳过前10个样本

注意:

  1. 如果跳过的样本数量比当前样本集中的样本总数还要多,则所有样本都会被跳过,样本集的数量会被设置为0。

  2. 该函数会影响当前样本集的总样本数(self._count)。

batch

这个方法是SampleSet类的一个成员方法,其主要目标是将样本集按照指定的批量大小进行分批。

参数:

batch_count (int): 指定每批样本的数量。

返回:

SampleSet: 返回修改后的样本集对象,该对象的每个迭代都将产生一个包含指定数量样本的批次。

示例:

假设我们有一个SampleSet对象s,将其按每批10个样本进行分批,可以通过以下方式实现:

s.batch(10)

注意:

此方法可能会更改样本集的总数(self._count),这是因为它会将总数调整为能够容纳完整批次的最大数量。例如,如果样本总数是25,批次大小是10,那么总数将被减少到20,既然最后的5个样本不足以形成一个完整的批次,将不会被迭代产生出来。

此外,此方法会对样本集的迭代方式进行更改,使其在每次迭代时产生一个包含batch_count个样本的列表,而不是单个样本。

func

此函数是在SampleSet类中定义的一个方法,用于将给定的函数应用到SampleSet的每一个元素上。

参数:

func: 一个函数。这个函数将会应用到SampleSet的每一个元素上。

示例:

def double(x):

return x * 2

sample_set = SampleSet(source_base, set_name)

sample_set.func(double)

上述代码会将函数double应用到SampleSet的每一项上,即每一项都会乘以2。

注意:

func的参数类型和返回类型应与SampleSet中的元素类型一致。

返回:

SampleSet对象本身,用于链式调用。

utils

这个Python模块主要用于数据库状态表的操作和多线程处理。它提供了一系列函数,包括在数据库状态表中设置键值对,获取键对应的值,判断指定键是否存在,以及为指定键添加或更新值。此外,还提供了一个用于大文件中实现高效键值存储的类,一个显示和控制进度条的类,以及一个线程安全的原子计数器类。同时,它也支持多线程任务的并行处理。

类成员:

name info
chunk_file 这是一个用于在大文件中实现高效键值存储,不需加载整个文件到内存的类。
process_status_bar 这是一个显示和控制进度条的类,提供进度展示、时间转换、日志打印等功能。
AtomicCounter 这是一个线程安全的原子计数器类,用于在多线程环境下跟踪资源的使用,例如计数器或ID生成器。

函数成员:

name info
set_value 这是一个在数据库'state'表中设置键值对的函数,能替换或插入新记录,需确保表存在且参数正确。
get_value 这是一个从数据库中获取给定键对应值的函数,可能用于配置读取工具。
has_value 这是一个判断指定键是否在数据库状态表中存在的函数,存在返回True,不存在返回False。
value_add 这是一个用于数据库中为指定键添加或更新值的函数,需要确保依赖函数正确实现并注意异常处理。
do_multitask 这是一个多线程任务执行函数,用来并行处理任务以提高执行效率,结果和任务执行顺序可能不一致。

set_value

这是一个设置值的函数,它将会把key和value(键值对)存储在名为'state'的数据库表中。如果已经存在相同的key,那么它会用新的value替换旧的value,否则就插入一条新的记录。

参数:

key: 需要设置的键。它是任意可以被哈希化的对象,例如字符串、数字或者元组。

v: 与键对应的值。它也是一个任意的对象,可以是数字、字符串、列表、字典等等。

返回值:

这个函数没有返回值。

使用方法:

set_value('name', '张三')

set_value(1, 100)

注意事项:

使用这个函数前,需要确保已经创建了名为'state'的数据库表,且该表有'key'和'value'两列。如果表不存在或者列不正确,那么这个函数将会抛出异常。

此外,这个函数没有做任何的类型检查或者错误处理,所以调用者需要确保传入的参数是正确的,否则可能抛出异常。

get_value

这个函数用于从数据库中获取一个给定键对应的值。这个函数可能是一个配置读取工具的一部分,

用于获取某个配置项的值。

参数:

key: 要获取的键。这是一个字符串。

返回值:

返回一个元组。如果键在对象中,则返回一个包含值的元组,如果键不在对象中,返回一个空元组。

示例部分:

\# 获取 'my_key' 对应的值

value = get_value('my_key')




**if value:**

print('my_key 对应的值是 ', value)




**else:**

print('my_key 不在配置中')

注意,这个函数依赖于一个名为 _load 的函数,但是这个函数在这段代码中没有定义。我们假设 _load 函数的目的是加载一个包含配置信息的对象,并返回这个对象。

这个函数没有明显的错误或者bug。

has_value

该函数用于判断给定的键是否在数据库的状态表中存在。

参数:

key: 需要检查的键,作为字符串传入。

返回:

如果查询的键在状态表中存在,则返回True,否则返回False。

示例:

has_value('my_key')

True

has_value('nonexistent_key')

False

value_add

这是一个value_add函数,主要用于数据库中为指定的键添加值。如果该键尚不存在,则会先插入一条新的记录。

参数:

key: 需要添加值的键,数据类型为字符串。

v: 需要添加的值,数据类型为整型。

返回:

返回更新后键对应的值。

使用示例:

value_add('test', 1)

注意:

该函数依赖于_load()函数和get_value()函数,所以在调用value_add()函数前,请确保前述两个函数已被正确定义和实现。

这个函数没有进行异常处理,如果在执行SQL语句时出现错误,程序可能会崩溃。在后续版本中会考虑加入异常处理机制。

该函数在version 1.1版本中已更新。

chunk_file

这是一个chunk_file类,其主要目的是实现在一个大文件上的键值存储。它将所有的key和相应的value存储起来,每个value存储在文件的某个位置,这个位置由key指示。这个类的设计使我们可以在不必将整个文件加载到内存的情况下高效地获取键值。这在处理大数据集时非常有用,尤其是当数据集大到无法完全装入内存时。

示例使用方法如下:

创建一个实例:

cf = chunk_file('example')

添加键值对:

cf.add('key1', 'value1')

cf.add('key2', 'value2')

获取键值对:

value1 = cf.get('key1')  \# 返回'value1'

value2 = cf.get('key2')  \# 返回'value2'

该类的主要方法和属性如下:

  • \_\_init\_\_(file_path):初始化方法,file_path是将要存储数据的文件路径。

  • \_\_contains\_\_(item):判断item是否在当前的键值对中。

  • \_\_iter\_\_():返回一个迭代器,可以遍历所有的键值对。

  • \_\_getitem\_\_(key):获取给定key对应的值。

  • verify_data():验证所有的数据是否完整,因为该方法只是简单地遍历所有数据,所以总是返回True。

  • add(key, value):向文件中添加一对键值对,返回存储的值的长度。

  • get(key):获取给定key对应的值。

  • flush():将当前的键值对写入硬盘。

注意,这个类没有提供删除或修改数据的方法。另外,在使用这个类时,需要确保操作系统支持大文件的处理。

__init__

这是一个名为chunk_file的类,其目的是为了处理二进制文件,并将其分块存储以便更有效地进行搜索和访问。

这个类可以以键值对的形式存储和读取二进制数据,其中键用于标识特定的数据块,而值则是要存储的实际数据。

数据将被pickle化并存储在二进制文件中,同时还会创建一个索引字典,其中包含每个键对应的文件位置和长度信息。

此索引字典也会作为一个.key文件存储在磁盘上,以便在后续的会话中快速加载和访问。

类的初始化方法:\_\_init\_\_(self, file_path)

此方法是为了初始化一个chunk_file对象。

参数:

file_path:存储数据块的二进制文件的路径。

它会首先根据提供的file_path生成.key和.value文件的路径。

然后,它会尝试从.key文件加载索引字典,如果文件不存在,则会创建一个空的索引字典。

最后,它会打开.value文件以便后续的读写操作。

__contains__

这是一个魔法方法,用于支持 in 操作符。它检查一个给定的键(item)是否在 index_dict 字典中。这个字典存储了文件的索引信息,其中键是数据的标识符,值是数据在文件中的位置和长度。

参数:

item:需要查询的键。

返回:

bool:如果键在 index_dict 中,返回 True,否则返回 False。

__iter__

这是一个迭代器方法,用于将类的实例对象作为迭代器使用。

在每一次迭代时,它会返回一个元组,元组的第一个元素为键,第二个元素是通过get方法得到对应的值。

这个方法允许我们遍历整个索引字典,同时获取字典中每一个键对应的值。

例如:

chunk = chunk_file("your_file_path")




**for key, value in chunk:**

print(f"Key: {key}, Value: {value}")

注意:

如果getitem方法无法获取到key对应的值,get方法会返回None,这时迭代出的value也就是None。

参数列表:

返回类型介绍:

返回一个元组,元组的第一个元素为键,第二个元素是通过get方法得到对应的值。

错误或者bug:

暂无

__getitem__

该方法实现了类的特殊方法__getitem__,使得类对象可以通过key值索引查询到对应的value。

参数:

self: 类的实例。

key: 需要查询的键值。

返回:

通过pickle模块反序列化后的数据。

例子:

chunk = ChunkFile('some_file')

value = chunk['some_key'] # 使用__getitem__方法查询对应的value值。

注意:

如果查询的key值不存在于index_dict字典中,那么方法会返回None。

verify_data

此函数用于验证存储的数据是否可以正常读取,主要是验证存储数据的完整性。它会遍历类内部的所有键值对,尝试读取每个键的值,但并不实际返回任何值。如果在读取过程中没有发生任何错误,说明存储的数据是完整的,此时返回True。如果在读取过程中发生了错误,此函数会抛出异常。

参数:

返回:

Boolean : 如果所有数据都被成功读取,返回True。

使用示例:

cf = chunk_file('test_file')

cf.add('key1', 'value1')

cf.verify_data()

True

注意:

此函数没有捕获可能抛出的异常,因此在使用时需要配合try...except...结构使用,以捕获可能出现的读取错误。

错误与异常:

如果在读取数据时出现错误,此函数会抛出异常,具体的异常类型取决于错误的性质。

add

该函数主要用于向文件中添加数据,并且将数据存储位置及长度的信息存储到索引字典中。

参数:

  • key: 待添加数据的关键字,用于数据的检索。

  • value: 实际需要存储的数据。

返回值:

  • act_len: 实际存储数据的长度。

使用方法:

首先,函数创建一个BytesIO缓存对象,并使用pickle将待存储的数据序列化后存储到缓存对象中。接着,函数计算出序列化数据的实际长度,并将缓存对象的指针重新设置到初始位置。

然后,函数调整文件对象的指针到下一次写入位置,并将序列化后的数据写入到文件中。函数将数据的存储位置和长度作为值,关键字作为键,添加到索引字典中。最后,函数将文件对象的指针位置设置为文件的当前位置,并返回实际存储数据的长度。

get

这是一个用于从文件中获取值的函数。

对于给定的键,该函数首先检查键是否存在于索引字典中。如果键不存在,函数将返回None。如果键存在,它将找到与该键关联的值在文件中的位置和长度,然后在该位置读取长度为len的数据,并将数据加载回Python对象,然后返回。

函数参数:

key : 需要获取的数据的键。

返回类型:

函数返回与键关联的数据值。如果键不存在,则返回None。

注意:如果键不存在于索引中,函数将返回None。此外,函数假设存在键的值可以成功的被pickle模块加载回Python对象,如果不能加载,函数可能会抛出异常。

flush

这个函数是chunk_file 类的一个成员方法。该方法的主要作用是将当前对象的状态持久化到磁盘中。具体实现是,先刷新已打开文件对象的IO缓冲区,然后将索引字典index_dict保存到.key文件中。

函数没有输入参数,也没有返回值。

示例:

cf = chunk_file("/path/to/chunkfile")

cf.add("key", "value")

cf.flush()  \# 此时,磁盘上的`/path/to/chunkfile.key`文件中保存了索引字典

注意,调用该函数后,如果不再对chunk_file对象进行修改,那么就可以安全的关闭Python进程,不会丢失已添加到chunk_file对象中的数据。

process_status_bar

这是一个用于处理状态条显示的类,可以用于需要状态条展示进度的情况,显示进度条的长度、进度、剩余时间等信息。

属性:

_iter_stack: 用于储存状态的栈。

_process_bar_len: 进度条显示长度。

_print_str: 用于打印的字符串。

hidden: 是否隐藏进度条。

方法:

_cast_second_strformat(self, sec): 将给定的秒数转换为时:分:秒的格式。

flush(self): 根据状态栈更新并显示进度条。

iter_bar(self, iter_item, value=0, key=None, max=None): 开始一个新的进度显示,显示对应的迭代对象的进度。

start(self, key, max, value=0): 开始一个新的状态条显示。

set_value(self, v): 设置进度条的当前值。

one_done(self): 完成一个任务,进度值加1。

stop_current(self): 停止当前的状态条显示。

process_print(self, str): 打印自定义字符串。

print_log(self, str): 打印日志信息。

使用示例:

psb = process_status_bar(30)

for item in psb.iter_bar(range(100), key="Processing", max=100):

job_do_something(item)

psb.stop_current()

注意事项:

  • 当迭代对象为无限长度或者非可迭代对象时,需要手动设置max值。

  • 在使用iter_bar开始一个新的进度显示时,需要在完成任务后手动调用stop_current停止当前的状态条显示。

__init__

初始化一个进程状态栏(process status bar)类。

参数:

processbar_length (int, 默认为20): 用于指定进程状态栏的长度。

此类的主要目的是用于在控制台显示进程或任务的进度信息,例如,当你需要在循环中处理大量数据时,你可以使用此类来跟踪和显示进度。此类还可以显示每个任务的剩余时间和平均处理时间,这对于估计长时间运行的任务非常有用。

示例:


bar = process_status_bar(30)




**for i in bar.iter_bar(range(100), key="Processing"):**

time.sleep(0.1)  \# 模拟数据处理

在上述示例中,我们首先创建了一个长度为30的进程状态栏实例。然后我们使用iter_bar方法在循环中处理数据。在此方法中,我们需要传递一个可迭代对象以及一个关键字参数key来描述正在进行的任务。在每次循环迭代中,我们都会使用time.sleep来模拟数据处理的过程,处理完一个数据,状态栏会自动更新进度。

此类没有明显的错误或BUG,但是在处理无法预计长度的迭代对象(如生成器)时,可能无法正确显示进度。对于这种情况,你需要手动设置进度条的最大值。

_cast_second_strformat

此函数的主要目的是将秒数转为字符串格式的时间展示。

参数:

sec: int

输入的秒数

返回:

str

返回一个HH:MM:SS格式的字符串,其中HH、MM、SS分别表示小时、分钟和秒。

例如,输入3661秒,返回'01:01:01'。

注意:

此函数假设输入的秒数sec为非负整数。如果输入的sec为负数或者非整数,可能会产生无法预期的结果或错误。

flush

_flush_方法是process_status_bar类内部使用的方法,主要用于刷新和输出进度条的状态。这个方法首先检查是否需要隐藏进度条,如果需要则直接返回。然后,它会构建含有所有迭代器状态的字符串,包括进度百分比、剩余时间等信息。然后,它会更新当前迭代器的平均执行时间和剩余执行时间。最后,它会通过日志输出当前的进度条状态。这个方法没有参数和返回值。

这个方法可能会有一些小问题。比如,当迭代器的最大值为0时,进度条的长度可能会取整到0,导致进度条显示不正确。此外,当迭代器的值为0时,平均执行时间和剩余执行时间都会被设置为0,这可能不是我们期望的结果。在实际使用中需要注意这些问题。

iter_bar

iter_bar是一个成员函数,这个函数主要是为了在迭代过程中创建进度条。

参数:

iter_item (Iterable): 需要迭代的对象。

value (int, 可选): 迭代开始时的值,默认为0。

key (str, 可选): 进度条的名字,默认为None,在此情况下,会自动生成名字为"Iter i"(i为当前进度条所在的堆栈位置)。

max (int, 可选): 迭代对象的最大长度,默认为None,在此情况下,会尝试获取iter_item的长度作为最大长度。

返回:

generator: 对输入的迭代对象进行封装,每次迭代完成后,都会更新进度条。

例子:

bar = process_status_bar()

for i in bar.iter_bar(range(100)):

print(i)

这将打印出从0到99的数字,同时在控制台显示进度条。

注意在循环结束后,进度条会自动调用stop_current来结束。如果在循环过程中出现异常需要提前结束,你需要手动调用stop_current来清理进度条。

异常:

如果iter_item不是一个有限长度的可迭代对象,这个函数将抛出一个异常。

start

开始一个新的进度追踪过程。每个进度追踪过程用一个字典进行存储,其中包括进度追踪的键(标记)、最大值、当前值、开始时间、平均耗时和剩余时间。这些信息会在进度条中显示。

Args:

key (str): 进度追踪的标记,如果未指定,则默认为"Iter {len(self._iter_stack)}"格式的字符串。

max (int): 进度追踪的最大值。

value (int, optional): 进度追踪的当前值,默认为0。

Returns:

None

注意事项:

  1. 如果同一个进度条对象中,连续调用了多次start() 方法,但未对应调用stop(),那么会形成一个进度追踪的栈结构。

  2. 在进度条显示时,会依次展示栈中所有进度追踪的信息,并以 ">>" 分隔。

  3. 调用stop()方法时,会出栈最顶层的进度追踪过程。

set_value

这是一个更新进度条中当前任务进度的方法。

参数:

v: 这是一个整数,代表当前任务完成的进度。

返回:

这个函数没有返回值。

用法:

\# 创建一个进度条对象

p = process_status_bar()

\# 开始一个名为'task1',总进度为100的任务

p.start('task1', 100)

\# 设置任务'task1'完成了30的进度

p.set_value(30)

注意:

  1. 如果v超过了任务的总进度,可能会导致进度条显示错误。

  2. set_value方法只会更新最近一次start开始的任务的进度,不会影响其他任务。

one_done

one_done(self): 这个函数的主要职责是更新进度条的进度。

函数名称:one_done

函数目的:该函数会将进程栈中的最后一个元素的值增加1,然后更新进度条。每当一个任务完成时,此函数被调用一次。

参数列表:该函数不接受任何参数。

返回类型:无返回值。

使用示例:


bar = process_status_bar()




**for i in range(10):**

\# 进行一些操作

bar.one_done()

注意:该函数不会检查进度是否超过了最大值,因此使用时需要确保任务的总数不会超过预先设定的最大值。

stop_current

stop_current方法用于终止当前进度条,并将其从进度条栈中移除。如果所有的进度条都已经被移除,那么在日志中留下一个标记。

该方法没有参数。

返回类型: 无。

使用示例:

processbar = process_status_bar()

processbar.start('Task1', 10)




**for i in range(10):**

\# 执行相关任务

processbar.one_done()  \# 每完成一个子任务,调用一次one_done()

processbar.stop_current()  \# 完成所有子任务后,调用stop_current()结束当前进度条

注意:stop_current应当在完成所有子任务后调用,否则可能引发错误。

process_print

process_print 是一个成员方法,其主要功能是为进度条附加一个字符串,用于描述当前进度条的状态,例如:"正在执行某操作..."。这个方法会将传入的字符串作为状态信息显示在进度条的末尾。同时,此方法会调用 _flush_ 方法对进度条进行刷新,使得新的状态信息能够立即显示出来。

参数:

str: 字符串类型,用于描述当前进度条的状态信息。

返回值:

无返回值。

使用示例:

p = process_status_bar(processbar_length=20)

p.start(key='Step1', max=100)




**for i in range(100):**

time.sleep(0.1)  \# 模拟耗时操作

p.one_done()




**if i == 50:**

p.process_print('已完成一半工作')

p.stop_current()

注意事项:

本方法不对输入字符串str进行任何安全性或合法性检查,请确保输入的str为合法的字符串,且不含有可能破坏进度条显示效果的特殊字符。

print_log

print_log(self, str)方法是process_status_bar类的一个成员方法,用于打印日志信息。

参数:

str: 需要打印的日志信息,类型为字符串。

返回:

无返回值。

使用示例:

bar = process_status_bar()

bar.print_log("开始处理...")

...

bar.print_log("处理完成...")

以上代码演示了如何使用print_log方法打印日志信息。

注意: 该方法内部调用了log函数打印日志,但是这里没有提供log函数的定义,并且该方法的注释部分也被注释掉了,所以在使用这个方法前,需要保证log函数已经被定义,否则会引发NameError异常。

AtomicCounter

这是一个线程安全的原子计数器类,使用python的内置线程锁(threading.Lock)实现。

类的主要目的是在多线程环境下提供一个安全的自增操作。它用于在多线程中跟踪某些资源的使用情况,例如计数器或ID生成器等。

类的使用方式如下:

counter = AtomicCounter() \# 创建一个原子计数器实例

counter.increment() \# 自增1

counter.increment(3) \# 自增3

print(counter.Value) \# 获取当前值

主要方法:

  • __init__:类的构造函数,初始化value为0,_lock为线程锁。

  • increment(add_value=1):自增函数,参数为自增的值,默认为1。使用线程锁保证在多线程环境下的安全性。函数返回自增后的值。

  • Value:类的property属性,返回当前的计数值。

注意:目前类是线程安全的,但是在多进程环境下未经测试,可能会有问题。

__init__

初始化AtomicCounter类。

这个类是一个线程安全的计数器,用于在多线程环境中安全地增加一个计数值。它使用了一个线程锁来确保在增加计数值时的线程安全。

属性:

value: 计数器的当前值。初始值为0。

_lock: 一个线程锁,用于在增加计数值时保持线程安全。

使用示例:

counter = AtomicCounter()

counter.increment()

print(counter.Value)  \# 输出: 1

increment

此函数是为了增加AtomicCounter类的value值。

参数:

add_value (int, optional): 要增加的值,默认值为1

返回:

Incremented value after adding the add_value to the current value.

此函数通过使用线程锁确保了在多线程环境下的安全性,可以防止数据竞争。

示例:

counter = AtomicCounter()

counter.increment(5)

print(counter.Value) # 输出: 5

counter.increment()

print(counter.Value) # 输出: 6

注意:此函数不是线程安全的,如果在没有使用线程锁的情况下在多线程环境中使用,可能会导致数据的不一致。为了避免这种情况,应当始终在调用此函数时使用线程锁。

Value

这是一个类的属性方法,用于获取AtomicCounter类实例的当前值。

Attributes:

Returns:

返回AtomicCounter类实例的当前value值,为整数类型。

Examples:

假设我们有一个AtomicCounter类的实例counter,可以通过以下方式获取其当前值:

counter = AtomicCounter()

counter.increment(5)

5

print(counter.Value)

5

Notes:

这个方法是线程安全的,可以在多线程环境下安全使用。

do_multitask

这是一个多线程任务执行函数,目的是将任务分配给多个线程并行处理,以提高任务执行效率。使用生产者-消费者模型,一个线程负责将任务放入队列,多个工作线程从队列中取出任务执行,然后将结果放入结果队列,最后从结果队列中取出所有结果。

参数:

iterations: 可迭代对象,表示需要处理的任务集合

task_fun: 函数,表示处理任务的函数,接受一个参数,即从iterations中取出的任务

thread_count: int,可选参数,默认为3,表示工作线程的数目

max_queue_buffer: int,可选参数,默认为0,表示任务队列和结果队列的最大容量,如果为0,则队列容量无限制

返回:

生成器,每次生成一个元组,元组的第一个元素是任务,第二个元素是该任务的处理结果

示例:

def task_fun(x):

return x * x

for item, result in do_multitask(range(10), task_fun, thread_count=5):

print(f'任务:{item},结果:{result}')

注意事项:

  • 输入的任务集合必须是可迭代的

  • 处理任务的函数必须接收一个参数

  • 线程数必须是正整数

  • 队列最大容量必须是非负整数

  • 本函数不保证任务的执行顺序和结果的生成顺序与输入的任务集合的顺序一致

evaluate

该python模块主要用于模型性能的评估,具备定义评分系统基本接口的抽象类,以规范评分系统的基本结构和调用方式。其包含了一系列评估器类,能计算分类评价指标(包括混淆矩阵)并返回结果,还能计算并返回每个类别的精确度、召回率和F1分数等关键模型性能评估指标。

类成员:

name info
ScoreAbstractClass 这是一个定义评分系统基本接口的抽象类,规范评分系统的基本结构和调用方式。
Evaluator 这是一个用于计算分类评价指标(包括混淆矩阵)并返回结果的评估器类。
AccuracyScores "AccuracyScores"是一个继承自ScoreAbstractClass的类,用于计算并返回每个类别的精确度。
PrecisionScores 这是一个名为PrecisionScores的类,用于计算并返回混淆矩阵中各类别的精确度得分。
RecallScores RecallScores类用于计算和返回召回率,它有两个主要方法:返回名称的Name和基于混淆矩阵计算召回率的get_score
F1Scores 这是一个继承自ScoreAbstractClass的类,用于计算和返回F1分数(模型性能评估指标)。

函数成员:

name info

ScoreAbstractClass

这是一个抽象类ScoreAbstractClass,它定义了一个评分系统的基本接口。它使用了抽象类元类abc.ABCMeta,强制要求所有子类必须实现定义的抽象方法。这个类主要用于规范和定义评分系统的基本结构,为实际的评分系统提供统一的调用方式。

它定义了两个抽象方法:

  1. Name: 这是一个property装饰的方法,要求所有子类必须提供一个名为Name的属性。这个属性返回一个str类型,用于描述这个评分系统的名称。

  2. get_score: 这是一个需要子类实现的方法,它接收一个名为confusion_matrix的参数,并返回一个列表类型的评分结果。

示例:

class MyScore(ScoreAbstractClass):

@property

def Name(self) -> str:

return "MyScoreName"

def get_score(self, confusion_matrix) -> []:

return [1, 2, 3] # 自定义的评分逻辑

my_score = MyScore()

print(my_score.Name) # 输出: MyScoreName

print(my_score.get_score(None)) # 输出: [1, 2, 3]

Name

这是一个抽象方法,需要在子类中实现。

返回:

str:返回评分标准的名称。

注意事项:

这是一个抽象属性,不能直接使用,需要在子类中实现。

错误或异常:

如果在子类中没有实现这个方法,那么在实例化子类的时候,会抛出TypeError异常。

示例:

**class AccuracyScore(ScoreAbstractClass):**

@property




**def Name(self) -> str:**

return 'accuracy'

在上面的示例中,我们创建了一个名为AccuracyScore的子类,并在子类中实现了Name方法,返回了'accuracy'字符串。

get_score

这是一个抽象方法,需要在子类中实现。目的是根据输入的混淆矩阵,计算并返回分数。

参数:

confusion_matrix (list): 混淆矩阵,是一个二维列表。

返回:

list: 根据混淆矩阵计算得到的分数,返回值类型为列表。

示例:

class ScoreConcreteClass(ScoreAbstractClass):

@property

def Name(self):

return 'ScoreConcreteClass'

def get_score(self, confusion_matrix):

# 计算并返回分数

pass

score_concrete_class = ScoreConcreteClass()

score = score_concrete_class.get_score([[1, 0], [0, 1]])

Evaluator

这是一个评估器类(Evaluator),用于处理分类问题的评估。

这个类的主要目的是计算出分类问题的混淆矩阵,并基于这个混淆矩阵,计算出一系列分类评价指标的结果。

这个类需要在初始化时传入一个评价指标方法列表,这个列表中的每个元素都应该是一个继承自ScoreAbstractClass的实例对象,每个对象都应实现了get_score方法。

这个类主要有三个方法,分别是:get_confusion_matrix,get_per_result和get_all_result。

  • get_confusion_matrix方法:计算出分类问题的混淆矩阵。

参数有两个,分别是真实标签列表和预测标签列表。

返回两个值,一个是标签到索引的映射字典,另一个是混淆矩阵。

  • get_per_result方法:计算出每个类别的评价指标结果。

参数有三个,分别是真实标签列表,预测标签列表,以及一个可选的return_confusion_matrix参数,默认为False,表示是否返回混淆矩阵。

返回一个包含每个评价指标结果的字典,如果return_confusion_matrix为True,还会返回混淆矩阵。

  • get_all_result方法:计算出所有类别的平均评价指标结果。

参数有三个,分别是真实标签列表,预测标签列表,以及一个可选的return_confusion_matrix参数,默认为False,表示是否返回混淆矩阵。

返回一个包含所有类别平均评价指标结果的字典,如果return_confusion_matrix为True,还会返回混淆矩阵。

示例:

**class MyScore(ScoreAbstractClass):**




**def get_score(self, confusion_matrix):**

\# 实现自己的评价指标计算方法

pass

my_score = MyScore()

evaluator = Evaluator([my_score])

true_labels = [0, 1, 0, 1, 0, 1]

predict_labels = [0, 1, 1, 0, 0, 1]

print(evaluator.get_all_result(true_labels, predict_labels))

__init__

初始化Evaluator类。

Evaluator类是一个评价器,其目的在于计算分类模型的性能。它可以计算多个评分方法并返回结果。对于每个评分方法,Evaluator类将计算混淆矩阵,并根据这个混淆矩阵得到每个类别的评分结果。

参数:

score_methods (list[ScoreAbstractClass]): 评分方法类的列表,每个类都应该从ScoreAbstractClass继承,并实现get_score方法。

例子:

from sklearn.metrics import accuracy_score, precision_score

evaluator = Evaluator([accuracy_score, precision_score])

true_labels = [0, 1, 1, 1, 0]

predict_labels = [0, 1, 0, 1, 1]

result = evaluator.get_all_result(true_labels, predict_labels)

print(result)

{'accuracy_score': 0.6, 'precision_score': 0.6666666666666666}

get_confusion_matrix

该方法用于生成混淆矩阵。混淆矩阵是一种常用的模型评估工具,特别是在处理多类别分类问题时。混淆矩阵显示了模型预测的类别和实际类别的对应情况。

参数:

true_labels (list): 真实的标签列表。

predict_labels (list): 模型预测的标签列表。

返回:

(dict, list): 返回一个元组,第一个元素是一个字典,包含所有唯一标签及其在混淆矩阵中的索引;第二个元素是一个二维列表,表示混淆矩阵,其行和列的顺序与第一个元素中标签的顺序一致。

例子:

evaluator = Evaluator([ScoreClass()])

true_labels = ['dog', 'cat', 'dog', 'fish']

predict_labels = ['dog', 'fish', 'dog', 'cat']

evaluator.get_confusion_matrix(true_labels, predict_labels)

({'fish': 0, 'dog': 1, 'cat': 2}, [[0, 0, 1], [0, 2, 0], [1, 0, 0]])

在这个示例中,'fish', 'dog', 'cat'分别在混淆矩阵中的索引为0,1,2。混淆矩阵表示'fish'被预测为'cat'一次,'dog'被正确预测两次,'cat'被预测为'fish'一次。

注意:

如果输入的真实标签和预测标签的数量不一致,将可能出现错误。

get_per_result

该函数主要用于获取每个类别的评分结果。首先,它会生成一个混淆矩阵,然后根据所有预定义的评分方法计算每个类别的评分,并将结果存储在字典中。

参数:

true_labels (list): 真实标签列表。

predict_labels (list): 预测标签列表。

return_confusion_matrix (bool): 是否返回混淆矩阵,默认为False。

返回:

如果return_confusion_matrix为True,那么返回一个元组,包括一个字典和一个混淆矩阵。字典的键是评分名称,值是另一个字典,里面包含每个类别的评分。混淆矩阵是一个二维列表,表示混淆矩阵的每个元素。

如果return_confusion_matrix为False,那么只返回一个字典,其结构与上述相同。

示例:

evaluator = Evaluator([ScoreMethod1(), ScoreMethod2()])

result_dict, confusion_matrix = evaluator.get_per_result(true_labels, predict_labels, return_confusion_matrix=True)

print(result_dict)

print(confusion_matrix)

get_all_result

此函数是Evaluator类的一个方法,用于获取模型在所有类别上的预测结果。

参数:

true_labels (list): 真实标签的列表。

predict_labels (list): 模型预测的标签列表。

return_confusion_matrix (bool, 可选): 是否返回混淆矩阵。默认为False。

返回:

返回一个包含每一个评分方法名称及其对应结果的字典;当return_confusion_matrix设为True时,同时返回混淆矩阵。

此函数首先调用get_per_result方法获取每个类别的评分结果,然后计算所有类别的平均分,并存储为一个新的字典。字典的键是评分方法的名称,值是该评分方法在所有类别上的平均分。如果某类别的评分不存在,将其值设为None。

示例:


e = Evaluator([score_method1, score_method2])

result = e.get_all_result(true_labels, predict_labels)

AccuracyScores

这是一个名为AccuracyScores的类,该类继承自ScoreAbstractClass抽象类。AccuracyScores类的主要目的是通过传入的混淆矩阵来计算和返回每个类别的精度。

每个类别的精度是通过混淆矩阵的对角线上的元素(正确分类的数量)除以该类别总的预测数量来得到的。

如果某类别没有预测(即混淆矩阵的某行和为0),则其精度被设置为None。

类方法介绍:

  • Name:这是一个属性方法,返回评分类的名称,即"Accuracy"。

  • get_score:这个方法接受一个混淆矩阵作为输入,返回一个列表,列表中的每个元素代表每个类别的精度。如果某类别没有预测(即混淆矩阵的某行和为0),则其精度被设置为None。

使用示例:

假设我们有一个二分类问题的混淆矩阵[[5, 2], [3, 7]],我们可以通过以下方式使用 AccuracyScores 类来计算每个类别的精度:

confusion_matrix = [[5, 2], [3, 7]]

accuracy_scores = AccuracyScores()

accuracies = accuracy_scores.get_score(confusion_matrix)

print(accuracies)

输出结果为:[0.7142857142857143, 0.7],表示类别1的精度为0.71,类别2的精度为0.7。

Name

这是一个类方法,用于获取当前类的名称。

返回:

str: 返回字符串 "Accuracy",作为当前类的名称。

get_score

此函数是 AccuracyScores 类的一个方法,用于根据混淆矩阵计算并返回每个类别的精确度。

参数:

confusion_matrix (list): 是一个二维列表,表示混淆矩阵。混淆矩阵的行对应真实类别,列对应预测类别。每一行的总和即为该类别的总预测次数。

返回:

list: 一个列表,其中包含每个类别的精确度。列表的索引i对应的是第i类的精确度。如果某个类别的总预测次数为0,则返回None。

示例:

confusion_matrix = [[10, 2, 3], [0, 8, 1], [0, 1, 9]]

accuracy_scores = AccuracyScores()

print(accuracy_scores.get_score(confusion_matrix))  \# 输出:[0.6666666666666666, 0.8888888888888888, 0.9]

注意:

如果混淆矩阵的某一行(即某一类)的总预测次数为0,表示没有对该类进行预测,此时该类的精确度无法计算,我们将其设置为None。如果你希望在这种情况下返回其他值,可以修改以下代码:

accuracies.append(None)  \# 或设为你希望在这种情况下返回的其他值

PrecisionScores

这是一个名为PrecisionScores的类,它继承了ScoreAbstractClass抽象类。这个类的目的是用来计算和返回混淆矩阵中的准确度得分。

Properties:

Name : str

返回字符串"Precision",表示这个类的名称。

Functions:

get_score(self, confusion_matrix) -> list:

这个函数接收一个混淆矩阵作为输入,计算并返回一个列表,其中包含了每个类的精确度得分。

参数:

confusion_matrix : list

这是一个二维列表,代表混淆矩阵。它的行表示实际的类别,列表示预测的类别。

返回:

precisions : list

这是一个列表,其中包含了每个类的精确度得分。精确度得分是由真正例数除以预测正例数得到的。如果预测的正例数为0,则该类别的精确度得分为None。

例如:

如果我们有一个二分类问题的混淆矩阵[[1, 2], [0, 2]],那么对于第一个类别,真正例数为1,预测正例数为1,所以精确度得分为1/1=1。对于第二个类别,真正例数为2,预测正例数为4,所以精确度得分为2/4=0.5。所以,这个函数会返回列表[1, 0.5]。

注意:这个类需要ScoreAbstractClass抽象类作为父类,所以确保在使用前已经导入了这个抽象类。

Name

此函数是PrecisionScores类的一个方法,用于返回字符串"Precision"。

方法的返回类型为字符串。

Args:

self: 代表类的实例。

Returns:

返回字符串 "Precision"。

Example:

precision_scores = PrecisionScores()

name = precision_scores.Name

print(name) # 输出 "Precision"

get_score

该函数用于计算精确度分数。精确度分数是评估预测模型分类结果质量的一个重要指标,其定义为:在预测为正类的样本中,真实为正类的比例。

参数:

confusion_matrix (list): 混淆矩阵,二维数组,每一行对应真实类,每一列对应预测类。对角线元素表示预测结果与真实情况相符的数量。

返回:

list: 返回一个列表,列表的每一个元素对应一个类别的精确度。如果某个类别没有正例预测,则返回None。

举例:

假设我们有一个混淆矩阵[[5, 2, 0], [3, 7, 1], [2, 4, 9]],代表有3个类别,分别为类别0,类别1,类别2。那么,这个函数会返回一个列表,里面的三个值分别对应这三个类别的精确度。

RecallScores

这是一个名为 RecallScores 的类,继承自 ScoreAbstractClass。这个类的主要目的是为了计算并返回“召回率”(Recall)。

“召回率”是在所有正样本中,预测为正样本的比率。它是衡量模型对正样本的预测能力,也是评估模型性能的重要指标之一。

这个类有两个主要的方法,一个是 Name,一个是 get_score:

  • Name 方法是一个属性方法,返回的是字符串 "Recall"。

  • get_score 方法是用来根据输入的 confusion_matrix(混淆矩阵)计算召回率的。它返回的是一个召回率的列表,列表中的每一个元素对应混淆矩阵中每一行的召回率。

使用方法:

假设我们有一个混淆矩阵 cm,我们可以这样来获取召回率:

recall_scores = RecallScores()

recalls = recall_scores.get_score(cm)

print(recalls)

注意:

如果在计算召回率时,某一行的实际正样本总数(即混淆矩阵一行的和)为0,那么这一行的召回率会被设置为None。

参数列表:

  • confusion_matrix:2维列表,表示混淆矩阵,真实值和预测值的组合情况。

返回类型:

  • get_score 返回一个列表,列表中的每一个元素对应混淆矩阵中每一行的召回率,如果某行的实际正样本总数为0,则对应的召回率为None。

Name

这是一个属性方法,用于获取当前类的名称。

返回:

str: 返回字符串 "Recall",表示这个类的名称。

get_score

get_scoreRecallScores 类中的一个方法,用于从混淆矩阵中计算每一类的召回率。

召回率是一个重要的评价指标,用于衡量我们的模型预测正例的能力。召回率的公式为TP/(TP+FN),其中TP代表真正例(实际为正例且被正确预测为正例的数量),FN代表假负例(实际为正例但被错误预测为负例的数量)。因此,get_score方法通过遍历混淆矩阵的每一行(每一类),并计算每一行的真正例数除以行内元素总和(真正例数 + 假负例数)来得到每一类的召回率。

如果某一类的真实样本总数(行内元素总和)为0,我们将该类的召回率记为None。

参数:

confusion_matrix (list of list of int): 混淆矩阵,每一行代表一个类别,每一行内的元素总和代表该类别的真实样本总数,行内对角线元素代表该类别的真正例数。

返回:

list: 返回一个列表,包含每一类的召回率。

示例:

confusion_matrix = [[2, 1], [1, 2]]

执行 get_score(confusion_matrix)

返回结果为 [2/3, 2/3]

注意:

如果混淆矩阵的输入不正确(例如,不是一个二维矩阵),或者混淆矩阵中包含负数,代码可能会出错。

F1Scores

这是一个计算F1分数的类,继承自ScoreAbstractClass。F1分数是精确度和召回率的调和平均数,是评价模型性能的一种常用指标。

类方法:

  • Name:返回评分方法的名字,即"F1"。

  • get_score:根据混淆矩阵计算F1分数。

属性:

  • Name:评分方法的名字。

方法:

  • get_score(self, confusion_matrix):计算F1分数。

- 参数:

  • confusion_matrix:混淆矩阵。

- 返回:

  • f1_scores:F1分数列表。

使用示例:

f1_score_calculator = F1Scores()

f1_scores = f1_score_calculator.get_score(confusion_matrix)

注意:在计算F1分数时,如果精确度和召回率中任何一个为None,或者两者之和为0,则F1分数为None。

Name

此方法是F1Scores类的一个属性方法。这个方法没有参数,它返回一个字符串,代表这个评分类的名称,即"F1"。这个名称可能被用于报告或者和其他评分方法进行比较。

返回:

str: 返回"F1",代表F1分数。

示例:

f1_calculator = F1Scores()

print(f1_calculator.Name) # 输出 "F1"

注意:

这个方法不接受任何参数,也不会改变对象的状态,只是提供一个固定的字符串。同时,作为一个属性方法,你不需要在调用时加括号。

get_score

这个get_score方法是在F1Scores类中定义的,该类继承自ScoreAbstractClass。该方法的主要目的是计算给定混淆矩阵的F1得分。

参数:

confusion_matrix (list): 需要计算得分的混淆矩阵。混淆矩阵是一个二维数组,每个元素表示预测类别和真实类别的匹配情况。

返回:

list: F1得分的列表。每个元素对应混淆矩阵中某个类别的F1得分。如果计算不出精确度或召回率,或者它们的和为零,对应的F1得分为None。

此方法首先使用PrecisionScoresRecallScores类计算混淆矩阵的精确度和召回率。然后,对于每个类别,使用以下公式计算F1得分:

F1 = 2 * ((precision * recall) / (precision + recall))

示例:

confusion_matrix = [[1, 2], [3, 4]]

f1_scores_calculator = F1Scores()

f1_scores = f1_scores_calculator.get_score(confusion_matrix)

print(f1_scores) # 输出: [0.6666666666666666, 0.5714285714285715]

请注意,此方法假设PrecisionScoresRecallScoresget_score方法返回的列表长度与混淆矩阵中的类别数量相同。如果这个假设不成立,那么此方法可能会引发错误。

此方法还可能返回包含None的列表,这时因为无法计算某个类别的F1得分。在处理返回的F1得分列表时,应当特别注意这一点。

ml

这个Python模块主要用于处理和管理图形数据结构和实体映射,提供了基于图形的模型构建,实体间映射关系的添加和查询。同时,它还包含了用于计算两个字符串间最少编辑距离的函数和查找字符串列表中最匹配字符串的函数,这些功能通常用于信息检索和自然语言处理。

类成员:

name info
Model_Base 这是一个抽象基类,为子类提供保存和加载模型的功能和定义模型名称的接口。
graph 这是一个基于Model_Base的graph类,用于创建和管理具有关联数据的图形数据结构。
mapping 'mapping'类是一个实体映射管理类,能添加和查询左右实体间的映射关系,不支持删除映射。

函数成员:

name info
lcs 这是一个计算两个字符串间Levenshtein距离(即最少编辑距离)的函数,常用于信息检索和自然语言处理。
match_2_string_list 这个函数找出第二个字符串列表中与第一个字符串列表中每个字符串最匹配的字符串的索引。

Model_Base

这个 Model_Base 类是其它模型的基类,采用抽象基类 (abc.ABCMeta) 作为元类。模型需要实现 model_name 属性。此外,它还提供了保存和加载模型的方法。

# 类的介绍和目的

这个类用于定义模型的基本结构和功能,它包含两个主要的功能:保存模型和加载模型。

\_\_init\_\_ 方法初始化一个空的字典 save_variables 用于存储模型的参数。

model_name 是一个抽象属性,子类需要根据自己的情况来实现这个属性。

save 方法是用来保存模型的参数到文件或者流中,如果传入的是字符串,它会被认为是文件路径;如果传入的是一个流对象,它会直接将模型参数保存到这个流中。

load 方法是用来从文件或者流中加载模型的参数,如果传入的是字符串,它会被认为是文件路径;如果传入的是一个流对象,它会直接从这个流中加载模型参数。

# 示例





**class MyModel(Model_Base):**

@property




**def model_name(self):**

return 'MyModel'

model = MyModel()

model.save('mymodel.pkl')

model.load('mymodel.pkl')

在这个例子中,我们定义了一个 MyModel 类,并实现了 model_name 属性。然后我们创建了一个 MyModel 的实例,并调用 saveload 方法来保存和加载模型。

注意: 这个类没有处理文件或者流读写错误的情况,当文件路径无效,或者文件权限不足,或者流对象无效时,可能会抛出异常。

__init__

这是Model_Base类的初始化函数。

在创建一个Model_Base类的实例时,会调用这个函数。这里并没有特别的参数需要传入。

Model_Base是一个基础模型的抽象类,它有两个基础功能:保存模型和加载模型。其中“模型”的保存和加载是通过pickle库实现的,保存和加载的对象是一个字典save_variables

在这个初始化函数中,我们创建了一个空字典save_variables,它用于以后保存模型的变量。

示例:

model = Model_Base()

注意:

Model_Base是一个抽象类,不能直接实例化。在实际使用中,通常会创建它的子类,并实现model_name这个抽象属性。

model_name

这是一个抽象方法,需要在子类中实现。这个方法定义了一个只读的属性model_name。由于用了@property装饰器,可以直接用object.model_name来访问,无需调用object.model_name()。由于该方法是抽象的,因此子类必须提供具体的实现。这个方法没有任何参数和返回值,但会返回在子类中实现的具体模型名。

例如:

**class MyModel(Model_Base):**

@property




**def model_name(self):**

return "my_model"

在这个例子中,如果有一个MyModel的对象mm.model_name会返回"my_model"

注意:在子类中重写这个方法时,要使用@property装饰器。

save

这个方法是用于保存模型的状态的。

参数:

path_or_stream: 一个str类型或者一个file-like object. 如果是str类型, 则它应该是一个要保存的文件的路径名. 如果是file-like object, 则将直接在这个流上进行写入操作.

返回:

None

这个方法不会返回任何值, 但是它会在指定的路径或流上保存模型的状态. 这个状态可以使用相同类的load方法进行加载.

示例:

model = Model_Base()

model.save('model.pkl')

在此例中, 我们创建了一个Model_Base实例, 并使用save方法保存了它的状态到一个叫做'model.pkl'的文件.

注意:

  1. 如果path_or_stream是一个文件路径, 则必须有足够的权限来写入该文件.

  2. 如果path_or_stream是一个file-like object, 则必须已经被打开, 并且可以进行写入操作.

load

此函数用于加载先前保存的模型变量。模型变量是以pickle格式存储的,可以从文件路径或者IO流中读取。

参数:

path_or_stream (str or file-like object): 如果是字符串,则它代表了要加载模型变量的文件路径。如果是file-like对象,则直接从该对象中读取模型变量。

返回:

无返回值。该方法将直接修改类实例的save_variables属性。

使用示例:

model = Model_Base()

model.load('/path/to/saved/model.pkl') # 从文件路径加载模型变量

with open('/path/to/saved/model.pkl', 'rb') as f:

model.load(f) # 从文件流加载模型变量

注意:

在使用该方法前,请确保你已经正确实现了model_name属性(抽象方法)。否则将会抛出NotImplementedError异常。

graph

这是一个graph类,基于Model_Base,用于创建和管理图形数据结构。图中的每个节点和线都可以有关联的数据。

属性:

model_name: 返回类名。

方法:

__init__: 初始化方法,创建四个字典,分别存储节点,节点线索引,线和线的端点。

__getitem__: 通过节点id获取节点数据。

__setitem__: 设置节点id和其对应的数据。

__contains__: 检查节点id是否存在于节点字典中。

__iter__: 迭代器,遍历并返回节点字典中的所有项。

add_node: 添加节点,需要节点id和可选的节点数据。

add_line: 添加线,需要两个节点id和可选的线数据。线id将自动创建。

get_node: 通过节点id获取节点数据。

get_relations: 通过节点id获取与其相连的所有线以及线的另一端节点的数据。可以通过start_to参数控制是获取从该节点出发的线还是指向该节点的线。

使用示例:

g = graph() # 创建一个图

g.add_node('node1') # 添加一个id为'node1'的节点

g.add_node('node2', {'color': 'red'}) # 添加一个id为'node2'的节点,并关联一个数据

g.add_line('node1', 'node2', {'weight': 10}) # 添加一条从'node1'到'node2'的线,并关联一个数据

print(g.get_node('node1')) # 获取并打印'node1'的数据

print(g.get_relations('node1')) # 获取并打印'node1'出发的所有线的数据

注意:

  • add_node时,如果节点id已存在,将抛出异常。

  • add_line时,如果任一节点id不存在,线将不会被添加,也不会抛出异常。

model_name

这是一个图模型类(graph),存储和操作简单图形的数据结构。

该类继承于Model_Base基类。

该类的数据是存储在字典中,包括存储节点数据(nodes)、节点线索引(node_lines_index)、线条数据(lines)、线条端点(lines_endpoint)等。

每个节点的id和数据都在字典nodes中存储,每个线条的id和数据都在字典lines中存储,每个线条的端点存储在字典lines_endpoint中,每个节点对应的线条索引存储在字典node_lines_index中。

提供了以下方法:

  • add_node:添加新的节点到图中。

  • add_line:添加新的线条到图中。

  • get_node:通过节点ID获取节点的数据。

  • get_relations:获取与给定节点ID相关的所有线条数据。

以下是一个使用例子:

g = graph()  \# 创建一个图

g.add_node('node1', "data1")  \# 添加一个节点

g.add_node('node2', "data2")  \# 添加另一个节点

g.add_line('node1', 'node2', "line_data")  \# 添加一个线条,连接node1和node2

print(g.get_node('node1'))  \# 输出 "data1"

print(g.get_relations('node1'))  \# 输出所有与node1有关的线条数据

__init__

这是一个名为 graph 的类,继承自 Model_Base,用于构建和操作图形结构。

graph 类的初始化方法 \_\_init\_\_ 中,定义了四个字典,分别用于存储节点数据、节点与线条关系、线条数据以及线条的起止节点。

  • self.save_variables['nodes'] 存储节点数据,键为节点的id,值为节点的数据。

  • self.save_variables['node_lines_index'] 存储节点与线条的关系,键为节点id,值为该节点连接的所有线条的id列表。

  • self.save_variables['lines'] 存储线条数据,键为线条id,值为线条的数据。

  • self.save_variables['lines_endpoint'] 存储线条的起止节点,键为线条id,值为一个元组,元素为线条的起始节点id和结束节点id。

示例:

g = graph()

g.add_node(1, "node1")  \# 添加一个节点,节点id为1,节点数据为"node1"

g.add_node(2, "node2")  \# 添加一个节点,节点id为2,节点数据为"node2"

g.add_line(1, 2, "line between node1 and node2")  \# 添加一条线,连接节点1和节点2,线条数据为"line between node1 and node2"

print(g.get_node(1))  \# 输出"node1"

print(g.get_relations(1))  \# 输出[(2, 'line between node1 and node2')]

注意:

当你添加一个已存在的节点ID时,程序将抛出异常,需要注意捕捉和处理。

__getitem__

这是一个特殊方法,它允许我们使用带有键的索引操作符 (例如:obj[key]) 来获取对象中的数据。在此类中,该方法被用来从图的节点中获取数据。

参数:

item: 键值,这里指图中节点的ID。

返回:

返回与给定键(图中节点的ID)关联的数据。

示例:

创建一个名为graph的图对象,然后添加一些节点:

graph_instance = graph()

graph_instance.add_node('node1', data='data1')

graph_instance.add_node('node2', data='data2')

然后,我们可以通过以下方式获取节点数据:

node_data = graph_instance['node1']

打印node_data会输出 'data1'。

__setitem__

该方法定义了如何为图中的节点(node)设置数据。

参数:

key : 节点的唯一标识符,通常为节点的id。

value : 需要存储在节点中的数据。

返回值:

示例:

# 创建一个graph对象

g = graph()

# 将数据"hello"存储在id为1的节点中

g[1] = "hello"

# 输出id为1的节点数据

print(g[1]) # 输出 "hello"

__contains__

该方法用于检查节点是否在图中。

参数:

item: 需要查询的节点。

返回:

如果节点存在于图中,返回True,否则返回False。

请注意:

在图中,键是节点的ID,值是此节点的数据。

这种方法对于检查节点是否存在在图中是有用的。

例子:

g = graph()

g.add_node('node1')

'node1' in g

True

'node2' in g

False

__iter__

这是一个迭代器方法,它允许我们在“graph”类对象上进行迭代操作。它会遍历保存在“nodes”字典中的每个节点,并返回它们的键值对。

在Python中,如果一个类定义了__iter__方法,那么它就可以被视为可迭代对象。当我们使用for循环来遍历这个对象时,for循环会自动调用这个__iter__方法,获取一个迭代器,并使用这个迭代器来遍历我们想要的数据。

参数列表:

返回类型:

generator,返回字典“nodes”中的键值对

代码示例:


g = graph()

g.add_node('node1', 'data1')

g.add_node('node2', 'data2')




**for node_key, node_value in g:**

print(node_key, node_value)

当我们运行以上代码时,会输出:


node1 data1

node2 data2

注意:无已知错误或bug。

add_node

该方法用于向图中添加节点。

参数:

id (str): 要添加的节点的唯一标识。

data (optional): 与节点相关的额外数据,默认为None。

返回:

None

异常:

Exception: 如果节点id已经存在于图中,将会抛出异常。

示例:

g = graph()

g.add_node('node1', data={'name': 'node1', 'value': 1})

assert 'node1' in g

注意:

请确保每个节点的id都是唯一的,否则添加节点时将会抛出异常。

add_line

此方法用于添加一条连接两个节点的线路,并可以将数据附加到该线路上。

参数:

node_id1:需要连接的第一个节点的标识符。

node_id2:需要连接的第二个节点的标识符。

data:可选参数,默认为None。这将作为附加数据添加到线路上。

返回:

此方法没有返回值。

示例:

graph.add_line('node_1', 'node_2', 'line_data')

注意事项:

如果两个节点已经通过线路连接,该方法将重新添加新的线路。每条线路的名称都是唯一的,即使它们连接的是相同的节点。

get_node

该函数的目的是通过键来获取图中的节点。

参数:

key: 节点的键。

返回:

返回与给定键对应的节点。

示例:

graph_instance = graph()

graph_instance.add_node('node_1', data='This is node 1')

node = graph_instance.get_node('node_1')

print(node) # 输出: This is node 1

注意:

如果给定的键不存在于图中,将会引发KeyError异常。

get_relations

此方法用于获取指定节点的关联信息。包含与之相连的其他节点及其相应的线数据。

参数:

key: 一个字符串,代表要查询关系的节点id。

start_to: 一个布尔值,默认为True。当为True时,返回从给定节点出发的关系;当为False时,返回指向给定节点的关系。

返回:

返回一个元组列表,列表中的每个元组包含两个元素,第一个是关联节点的id,第二个是连接两个节点的线的数据。

例子:

假设我们有这样的一个图——节点A通过线1与节点B相连,通过线2与节点C相连。那么:

graph.get_relations('A') 会返回 [('B', data1), ('C', data2)]

graph.get_relations('A', start_to=False) 会返回空列表,因为没有线指向节点A。

注意:

如果给定的key在图中不存在,该函数将返回一个空列表。

mapping

mapping类是一个从Model_Base继承的子类,主要用于实现左右两个实体之间的映射。这个类主要含有三个方法:addleftright

'add'方法用于添加一个新的左实体到右实体的映射关系。这个映射关系被存储在'content'列表中,并且每个映射的索引会被添加到对应的左或者右实体的字典中。

'left'方法接受一个左实体作为参数,返回所有与这个左实体关联的右实体。

'right'方法接受一个右实体作为参数,返回所有与这个右实体关联的左实体。

一个例子如下:

\# 实例化mapping类

m = mapping()

\# 添加映射关系

m.add('a', '1')

m.add('b', '2')

m.add('a', '3')

\# 查询关联的实体

print(m.left('a'))  \# 输出:['1', '3']

print(m.right('2'))  \# 输出:['b']

无已知错误或者bug。注意:这个类不支持删除已经添加的映射关系,如果需要这个功能,请在子类中实现。

model_name

mapping类是Model_Base的子类,主要用于建立左右两个对象之间的映射关系。每个对象都包含一个字典属性和一个内容列表。

  • model_name方法是一个属性方法,返回类名"mapping"。

类的使用示例:

map_obj = mapping()

map_obj.add('a', '1')  \# 建立'a'和'1'的映射关系

map_obj.add('a', '2')  \# 建立'a'和'2'的映射关系

map_obj.add('b', '1')  \# 建立'b'和'1'的映射关系

print(map_obj.left('a'))  \# 输出['1', '2']

print(map_obj.right('1'))  \# 输出['a', 'b']

__init__

这是一个名为“mapping”的类的初始化函数,该类继承了Model_Base基类。该类主要用于建立和查询左右两个元素之间的映射关系。

在初始化过程中,首先调用基类的初始化函数,然后创建了三个用于保存映射关系的字典:'left_dic','right_dic'和'content'。

  • 'left_dic'用于保存左侧元素到右侧元素的映射关系,其形式为{左侧元素:[对应的右侧元素的索引]}

  • 'right_dic'用于保存右侧元素到左侧元素的映射关系,其形式为{右侧元素:[对应的左侧元素的索引]}

  • 'content'用于保存映射关系的内容,其形式为[[左侧元素,右侧元素]]

此外,这个函数没有明显的错误或者bug。

这是一个示例代码来说明如何使用这个类:


map = mapping()

map.add('a', '1')

map.add('b', '1')

map.add('a', '2')

print(map.left('a'))  \# 输出:['1', '2']

print(map.right('1'))  \# 输出:['a', 'b']

add

add 方法用于在映射模型类中添加新的映射关系。

参数:

left (任意类型): 左侧的映射元素。用作映射的键。

right (任意类型): 右侧的映射元素。用作映射的值。

返回值:

无返回值。

使用方法:

add 方法用于向映射模型中添加新的映射关系。例如:

mapping_obj = mapping()

mapping_obj.add('apple', '苹果')

在这个例子中,left 参数是 'apple',right 参数是 '苹果'。执行 add 方法后,在映射模型中添加了一个新的从 'apple' 到 '苹果' 的映射关系。

注意事项:

leftright 参数在映射模型中不存在时,会在映射模型的 'left_dic' 或 'right_dic' 属性中添加一个新的键,并将键对应的值设置为一个空列表。然后,为该列表添加包含映射关系的索引。

leftright 参数在映射模型中已存在时,只会在其对应的列表中添加新的映射关系索引。

left

这个函数用于获取与指定“left”对象相关联的所有“right”对象。

这个函数接受一个参数:

  • left: 我们想要查询的左边的对象。

返回值:

  • 如果left存在在left_dic字典中,函数会返回一个列表,其中包含与left关联的所有right对象。

  • 如果left不存在在left_dic字典中,函数会返回None。

举个例子,

假设我们有以下的映射关系:{'a': ['x', 'y'], 'b': ['y', 'z']},那么left('a')将返回['x', 'y']。

right

此方法用于获取映射关系中对应于给定右侧元素的左侧元素。

参数:

right: 需要查询映射关系的右侧元素。

返回:

如果右侧元素存在于映射关系中,则返回一个列表,其中包含对应于给定右侧元素的所有左侧元素;

如果右侧元素不存在于映射关系中,则返回None。

例如:

假设我们已经添加了如下映射关系: (1, 'a'), (2, 'a'), (3, 'b')

如果我们调用right('a'),那么将返回[1, 2]

如果我们调用right('b'),那么将返回[3]

如果我们调用right('c'),那么将返回None

错误或者bug:

没有发现错误或者bug。

lcs

此函数是用于计算两个字符串之间的Levenshtein距离,即将一个字符串转换为另一个字符串所需要的最少单字符编辑(插入、删除或替换)的数量。这种度量方式在许多领域都有应用,包括计算机科学中的信息检索和自然语言处理。

参数:

s1, s2: 需要计算Levenshtein距离的两个字符串。

返回:

返回计算得出的Levenshtein距离,归一化,即除以两个字符串中的最大长度,这样得出的值会在0和1之间,值越小表示两个字符串越相似。

示例:

s1 = "kitten"

s2 = "sitting"

lcs_distance = lcs(s1, s2)

print(lcs_distance) # 输出: 0.5714285714285714

注意:

此函数使用的是动态规划方法,时间复杂度为O(len(s1)*len(s2)),空间复杂度也为O(len(s1)*len(s2)),在处理大规模数据时需要注意。

match_2_string_list

这个函数的目的是找出列表2中与列表1中每个字符串最匹配的字符串的索引。匹配程度通过使用最长公共子序列(lcs)算法来评估。

参数:

list1: 一个由字符串组成的列表。将从列表2中寻找每个字符串的最佳匹配项。

list2: 一个由字符串组成的列表,我们将从这个列表中寻找与列表1中的字符串最匹配的字符串。

top_n: 整数。默认为1。表示返回匹配程度最高的前n个字符串的索引。

返回:

返回一个二维列表。列表的长度与list1一致。每个子列表都包含top_n个从list2中选出的与list1中对应字符串匹配程度最高的字符串的索引。

例如,如果list1 = ['abc', 'def'], list2 = ['abc', 'def', 'ghi', 'jkl'],top_n = 2,那么返回的结果可能是[[0, 1], [1, 0]]。

这意味着与'abc'最匹配的两个字符串在list2中的索引是0和1,与'def'最匹配的两个字符串在list2中的索引是1和0。

注意:

这个函数使用lcs(最长公共子序列)算法来评估字符串之间的匹配程度,因此在输入长度较大的字符串列表时,可能会花费较长时间。

此外,如果top_n的值设置得较大,且列表2的长度远大于top_n,可能会浪费一些计算资源,因为我们只关心匹配程度最高的top_n个结果。

示例:

list1 = ['abc', 'def']

list2 = ['abc', 'def', 'ghi', 'jkl']

top_n = 2

print(match_2_string_list(list1, list2, top_n))

结果:[[0, 1], [1, 0]]

lmc

该Python模块主要用于处理和管理文本数据,以及与OpenAI的聊天模型进行交互。它包含一系列类和函数,支持文本向量化,实现模型调用和控制,管理链式模型状态,获取词嵌入,以及处理API请求等功能。特别是,提供了与GPT-3.5和GPT-4进行文本交互的聊天模型类,以及用于自定义模型参数的细调完成模型类。

类成员:

name info
LLM_Plus LLM_Plus是LLM的子类,它增加了自定义功能,并重写了一些方法,主要用于调用和控制模型,支持使用缓存和代理。
Emb_Plus Emb_Plus类主要用于处理文本数据,将其转化为向量,并提供代币使用量管理与缓存功能。
lmc_linked_model 'lmc_linked_model'是一个用于管理和处理链式模型的各种状态、日志记录以及异常处理的类。
Openai_embedding Openai_embedding类用于通过OpenAI的API令牌获取词嵌入,并可将文本列表转换为浮点数列表。
ChatGLM ChatGLM类是一个从LLM_Plus基类继承的API请求类,用于实现与外部模型的交互。
OpenAI_Complete_Model OpenAI_Complete_Model类是用于与OpenAI对话模型交互的工具,包含构造请求、调用模型、解析响应等功能。
ChatGPT4 ChatGPT4类用于创建和管理OpenAI的GPT-4聊天模型实例,并通过API令牌实现模型交互。
ChatGPT3 这是一个继承自OpenAI模型的聊天模型类ChatGPT3,用于与Gpt-3.5进行交互生成文本。
FineTuned_Completion_Model 这是一个继承自OpenAI模型的细调完成模型类,用于自定义模型参数并设置默认的温度范围。

函数成员:

name info
update_prompt_folder 这是一个更新指定文件夹下文档内容的函数,处理完成后显示更新状态,遇错则抛出异常。

LLM_Plus

LLM_Plus是LLM的子类,主要目的是在LLM的基础上增加了一些自定义的功能,并对一些方法进行了重写或者添加。

这个类主要有以下几个属性:

  • proxy: 代理地址,默认为空。

  • model_name: 模型的名称。

  • use_buffer: 是否使用缓存,默认为False。

  • buffer_version: 缓冲版本,默认为-1。

  • call_dictionary_config: 调用字典配置。

  • price: 价格,是一个元组,默认为(0.0, 0.0)。

此外,LLM_Plus类还提供了以下几个方法:

  • \_\_init\_\_(self, model_name, version=-1, call_dict={}, price=(0.0, 0.0), proxy=None, **kwargs: Any): 构造函数,初始化LLM_Plus类的实例。

  • add_token_use(self, use: (int, int)): 添加token使用。

  • call_model(self, prompt, *args, **kwargs) -> Any: 抽象方法,需要子类重写,用于调用模型。

  • _call(self, *args: Any, **kwargs: Any) -> str: 私有方法,用于调用模型,并处理缓冲的情况。

  • _llm_type(self) -> str: 属性方法,返回模型名称。

使用示例:


llm_plus = LLM_Plus('my_model', 1.0, {'key': 'value'}, (0.1, 0.2), 'http://my_proxy.com')

llm_plus.add_token_use((10, 20))

result = llm_plus._call('my_prompt')

注意事项:

  • 如果想要使用缓存,需要将version参数设置为大于0的值。

  • call_model方法需要在子类中重写,否则会抛出NotImplementedError异常。

  • add_token_use方法在添加token使用时,会对使用量进行记录,并且会计算并记录使用的总成本。

  • _call方法在调用模型时,会先检查是否使用缓存,如果使用缓存且缓存中有对应的结果,则直接返回。否则,会调用call_model方法,并将结果存入缓存中。

  • 如果调用_call方法时传入的参数中包含call_dictionary_config中的键,则call_dictionary_config中的值会被覆盖。

__init__

初始化LLM_Plus类的实例。

LLM_Plus类是LLM类的子类,主要在LLM类的基础上增加了价格、使用缓存和代理等属性,以及对相应属性的操作方法。

使用示例:

llm_plus = LLM_Plus(model_name="example_model", version=1.0, call_dict={'key': 'value'}, price=(1.0, 2.0), proxy="http://127.0.0.1:8000")

参数列表:

model_name : str

模型名称。

version : int, 默认值为-1

模型的版本号。如果版本号大于0,则启用缓存。

call_dict : dict, 默认值为空字典

调用模型时的参数字典。

price : Tuple[float, float], 默认值为(0.0, 0.0)

模型的价格,为一个包含两个浮点数的元组。

proxy : str, 默认值为None

代理服务器的地址。

**kwargs : Any

存放其他参数的字典。

返回类型:

无返回值。

注意事项:

  1. 如果model_name为空或者version小于0,将不启用缓存。

  2. 当调用call_model方法时,需要确保call_dict中的参数正确,否则可能导致调用失败。

错误与异常:

  1. 如果指定的代理服务器无法连接,将导致网络调用失败。

  2. 如果price不是一个包含两个浮点数的元组,将引发TypeError。

add_token_use

此方法用于在调用模型后添加token使用量。

参数:

use: (int, int), 一个元组,其中包含两个整型值。其中第一个整型值表示输入的token使用量,第二个整型值表示输出的token使用量。

返回:

无返回值。

此方法首先获取当前日期,并以此构建一个特定的键名。如果该键名在内存中不存在值,则为其设置初始值0。并且为输入的token使用量和输出的token使用量设置初始价格。

之后,此方法会更新输入和输出的token使用量,并计算当前的成本。最后,该方法将会记录这些信息。

注意:此方法不会返回任何值,其主要目的是记录模型使用的token量和相应的成本。

call_model

该函数是一个抽象方法,需要在子类中实现。它的主要作用是调用模型。

参数:

prompt: 主要用于模型的提示语。

*args: 可变参数,可传入多个任意类型的参数。

**kwargs: 可变参数,可传入多个关键字参数。

返回类型:

Any: 返回任意类型的数据。

使用示例:

在子类中重写此方法,例如:

class MyModel(LLM_Plus):

**def call_model(self, prompt, *args, kwargs):

# 在这里实现模型的调用

result = my_model.predict(prompt)

return result

这样,当我们创建MyModel的实例并调用call_model方法时,就会调用我们定义的模型并返回预测结果。

注意:

由于这是一个抽象方法,所以在使用LLM_Plus类时,必须创建其子类并实现这个方法,否则将会引发TypeError。

_call

_call是一个私有方法,主要用于处理模型调用的请求。这个方法首先会检查是否使用缓存,如果使用并且缓存版本号与当前版本号相同,那么就会直接返回缓存的结果。如果不使用缓存或缓存版本号不同,那么就会重新构建关键字参数kwargs,并调用模型进行计算,然后返回计算结果。

参数:

  • *args: Any: 可变参数,用于传递给模型调用的参数。

  • **kwargs: Any: 可变关键字参数,用于传递给模型调用的关键字参数。

返回:

  • str: 模型调用的结果。

示例:

_call('input_text', key1='value1', key2='value2')

注意:

此方法是私有方法,一般不应该在类的外部直接调用。同时,因为此方法依赖于call_model抽象方法,所以在使用此类时,必须先实现call_model方法。

_llm_type

这是一个属性装饰器函数,用于返回LLM_Plus对象的model_name属性的值。此函数没有参数,并且返回一个字符串类型的值。

属性装饰器的功能是将一个方法变为只读的类属性,使其可以像访问属性一样调用方法,这个方法没有输入参数。

返回:

返回model_name属性的值,该值是一个字符串类型。

Emb_Plus

Emb_Plus类是一个嵌入类,它继承了Embeddings类。这个类用于处理文本数据,将文本转化为向量。它的主要功能包括通过指定的模型名和版本来初始化嵌入对象,添加代币使用量,嵌入字符串列表,嵌入文档,以及嵌入查询。

类变量:

  • proxy: 可选的代理设置,默认为空。

  • model_name: 可选的模型名称,默认为空。

  • use_buffer: 是否使用缓冲,默认为False。

  • buffer_version: 缓冲版本,默认为-1.0。

  • call_dictionary_config: 调用字典配置, 默认为空字典。

  • price: 嵌入的价格,默认为(0.0, 0.0)。

方法:

  • \_\_init\_\_: 初始化方法,用于设定模型名称、版本、调用字典、价格和代理等。

  • add_token_use: 添加代币使用量,根据模型名和当前日期来生成密钥名称,并将使用量添加到相应的密钥。

  • embed_str_list: 抽象方法,需要在子类中实现。接受一个字符串列表,返回一个嵌入向量的列表。

  • embed_documents: 接受一个文本列表,返回一个嵌入向量的列表。如果启用了缓冲,会从缓冲中获取数据;否则,会调用embed_str_list方法进行嵌入。

  • embed_query: 接受一个查询文本,返回一个嵌入向量。

使用例子:

embedder = Emb_Plus("model_name")

embedder.add_token_use((1, 2))

vector = embedder.embed_query("query text")

__init__

这是一个名为 Emb_Plus 的类,继承自 Embeddings 类。目的是实现特定的词嵌入功能。该类对父类进行了扩展,提供了额外的功能,如代理、价格、缓冲区版本、模型名称等。

类初始化方法 __init__ 的功能是初始化 Emb_Plus 实例的各项参数。

参数:

model_name: str, 模型名称。用于标识用户所使用的模型。

version: int, 默认为-1,表示缓冲版本号。当版本号大于0时,启用缓冲区。

call_dict: dict, 默认为空字典,表示调用字典配置。可以自定义传入的参数。

price: tuple, 默认为 (0.0, 0.0),表示使用代币的价格,包括进入和出去的价格。

proxy: str, 默认为 None,表示代理。如果需要使用代理则可以指定。

**kwargs: Any, 表示可以接受任意数量的额外参数。

返回:

示例:

emb_plus = Emb_Plus(model_name='model1', version=1, call_dict={'key': 'value'}, price=(1.0, 0.5), proxy='http://127.0.0.1:8080')

这个示例创建了一个 Emb_Plus 的实例,指定了模型名称、版本号、调用字典配置、价格和代理。

add_token_use

这个方法是Emb_Plus类的一部分,用于追踪和更新token的使用情况。

具体来说,它取得一个包含输入与输出token数量的元组,然后根据当前日期和模型名称生成一个键名,用于存储token使用的统计数据。

如果该键名在数据库中不存在,则会为该键创建初始值为0的数据。并且还会创建两个以"_use_in"和"_use_out"为后缀的键,用于跟踪输入和输出token的使用。

然后,方法会更新"_use_in"和"_use_out"的值,增加输入和输出token的数量。

最后,计算当前的成本,即每1000个输入token的价格加上每1000个输出token的价格,然后将计算结果写入日志中。

参数:

use: 一个元组,包含两个整数,分别表示输入和输出token的数量。

返回类型:

embed_str_list

此抽象方法用于嵌入字符串列表。

参数:

texts (Liststr): 要嵌入的字符串列表。

**kwargs: 其他可选参数。

返回:

List[List[float]]: 返回嵌入后的二维浮点数列表。

示例:

假设我们有一个实现了embed_str_list方法的Emb_Plus子类实例em。

texts = ["Hello World", "Machine Learning"]

em.embed_str_list(texts)

注意:

这是一个抽象方法,需要在子类中实现。

embed_documents

该函数的目标是嵌入给定的文档列表,并返回嵌入的结果。

参数:

texts: 文档列表,每个文档是一个字符串。

**kwargs: 这是一个额外的参数列表,这将被用来传递给embed_str_list方法,这是该类定义的一个抽象方法。

返回:

该函数返回一个列表,其中包含每个文档的嵌入结果。每个结果都是一个浮点数列表。

在函数的实现中,首先检查是否启用了缓冲区。如果启用了缓冲区,它将首先尝试从缓冲区中获取嵌入结果。如果在缓冲区中找到匹配项,并且其版本与当前的buffer_version匹配,那么将直接返回缓冲区中的值。

如果没有启用缓冲区,或者在缓冲区中没有找到匹配的项或版本不匹配,那么它将调用embed_str_list方法,使用texts和kwargs作为参数来获取嵌入结果。

在得到嵌入结果后,如果启用了缓冲区,它将将嵌入结果及其版本添加到缓冲区中,并立即刷新缓冲区。

最后,返回嵌入结果。

embed_query

该方法是将输入的文本进行向量化处理。

参数:

text (str): 需要进行向量化处理的输入文本。

返回:

List[float]: 返回处理后的向量,它是一个浮点数列表。

使用示例:

embed = Emb_Plus(...)

vector = embed.embed_query("需要转化的文本")

print(vector)

注意:

Output的维度取决于模型的类型和配置。例如,如果使用的是BERT模型,那么每个词的向量维度可能为768或1024等。

方法的主要步骤:

  1. 输入文本被封装为一个列表,即[text],然后传递给self.embed_documents方法进行处理。

  2. embed_documents方法返回的结果是一个列表的列表(即二维列表),这是因为它设计为处理多个文本输入。

  3. 但是embed_query只处理一个文本输入,所以它仅返回embed_documents返回的二维列表中的第一个元素(即一个一维向量)。

此外,如果已经缓存了text的向量化结果,并且缓存的版本和当前版本一致,embed_documents方法会直接返回缓存的结果,提高处理效率。

lmc_linked_model

lmc_linked_model是一个用于处理链式模型的类。

这个类主要通过将一系列的处理函数添加到不同的函数列表中,来实现对模型的初始化、正常处理、异常处理以及完成后的处理。这样,可以方便地处理模型的各种状态,并通过日志记录处理过程中的各种信息。

参数:

llm (LLM_Plus): 需要处理的模型。

属性:

llm (LLM_Plus): 需要处理的模型。

retry_count (int):重试次数。

invoke_times (int): 调用次数。

output_parser (BaseOutputParser): 输出解析器。

fix_output_parser (OutputFixingParser): 修复输出的解析器。

prompt_template (PromptTemplate): 提示模板。

init_func_list (list): 初始化函数列表。

norm_func_list (list): 正常处理函数列表。

completed_func_list (list): 完成处理函数列表。

exception_func_list (list): 异常处理函数列表。

方法:

在本类中,主要包含了一些设置方法,如设置重试次数、设置输出修复、设置调用次数、设置提示模板、设置输出解析器等。

另外,还包含了一些私有方法,用于处理函数列表的调用、处理日志状态等。

使用示例:

llm = LLM_Plus()  \# 假设这是一个已经初始化的模型

linked_model = lmc_linked_model(llm)

linked_model.set_retry(3)  \# 设置重试次数为3

linked_model.set_times(2)  \# 设置调用次数为2

\# 设置输出解析器

output_parser = BaseOutputParser()  \# 假设这是一个已经初始化的输出解析器

linked_model.set_output_parser(output_parser)

\# 设置提示模板

linked_model.set_prompt_template("Hello, {name}")

\# 使用该模型

result, logs = linked_model(prompt="Hello, World")

在这个示例中,我们首先创建了一个lmc_linked_model的实例,并设置了重试次数、调用次数、输出解析器和提示模板。然后,我们调用模型并得到结果和日志。

注意: 使用本类需要有一定的python编程基础和对正则表达式、异常处理等有一定了解。在使用过程中,请确保传入的模型和解析器等都是正确初始化的。

__init__

lmc_linked_model是一个类,用于处理和维护LLM_Plus模型的各种参数和功能。它可以记录和管理模型的调用次数,错误处理等。

\_\_init\_\_方法是该类的初始化方法,主要负责初始化各种属性和参数。

参数:

llm: LLM_Plus对象,需要处理的模型。

属性:

retry_count: 重试次数,默认为1,表示模型运行出错时,重试的次数。

invoke_times: 调用次数,默认为1,表示模型要调用的次数。

output_parser: 输出解析器,用于解析模型的输出结果,默认为None。

fix_output_parser: 修复输出解析器,用于处理模型的异常输出,默认为None。

prompt_template: 提示模板,用于生成模型的输入,默认为None。

init_func_list: 初始化函数列表,用于处理模型的初始化操作,默认为空列表[]。

norm_func_list: 正常函数列表,用于存储处理模型正常运行的函数,默认为空列表[]。

completed_func_list: 完成函数列表,用于处理模型运行完成后的操作,默认为空列表[]。

exception_func_list: 异常函数列表,用于处理模型运行出错的情况,默认为空列表[]。

在使用此类时,首先创建一个LLM_Plus对象,然后将其作为参数传递给lmc_linked_model的初始化方法来创建一个lmc_linked_model对象。然后可以通过这个对象来管理和操作模型,如设置重试次数、调用次数、解析器等。

例如:

llm = LLM_Plus()

lmc_model = lmc_linked_model(llm)

lmc_model.set_retry(3).set_times(5)

set_retry

设置重试次数。

此方法用于设定模型在调用过程中出现异常时的重试次数。

Args:

count: 一个整数,表示重试的次数。

Returns:

返回当前的lmc_linked_model实例,以支持链式调用。

示例:

假设我们有一个lmc_linked_model实例model,我们想要设置它的重试次数为3次,可以这样操作:

model.set_retry(3)

_fix_output

这个函数是用来修复输出的。将输出作为输入,尝试使用修复解析器来修复并输出结果。

参数:

output(str): 输出结果,需要被修复的字符串。

返回:

result : 返回通过fix_parser处理过的结果,如果处理过程中出现了异常,则返回None。

示例:

给定一个输出 "abc",假设我们有一个修复解析器,它将所有的"a"替换为"b",那么"_fix_output"函数将返回"bbc"。

注意:

如果在解析过程中抛出了异常,该函数会捕获异常并返回None,所以调用者需要对None做出正确的处理。

set_output_fix

这个函数的功能是设置输出修正。当output_parser被设置后,使用OutputFixingParser.from_llm生成一个修正输出的parser,并将其赋值给fix_output_parser。同时将一个名为add_fix的函数添加到exception_func_list列表中。这个add_fix函数的作用是使用fix_output_parser解析上一次的输出,并将解析结果返回。如果解析过程中抛出异常,函数将返回None。最后,函数返回自身以供链式操作。

参数:

无参数。

返回:

返回当前实例,支持链式操作。

举例:

model = LMC_Linked_Model(...)

model.set_output_fix()

注意事项:

在调用此函数前,必须要先设置output_parser,否则会抛出异常。

set_times

此方法用于设置模型调用的次数。

Args:

count (int): 指定的模型调用次数。

Returns:

self: 返回类的实例。

示例:

linked_model = lmc_linked_model(llm)

linked_model.set_times(5) # 设置模型调用次数为5次

注意:

输入的次数应为正整数,否则可能会导致程序错误。

set_prompt_template

这个方法是为了设置提示模板。提示模板是根据模板字符串生成的,其中大括号包围的部分会被替换为对应的输出值。

参数:

temp_str: 字符串类型,用于定义提示模板的模板字符串。其中的大括号部分(如"{something}")将被替换为对应的输出值。

返回:

返回当前对象自身,以便于链式调用。

使用示例:

lmc_linked_model_obj.set_prompt_template("The output is {output}")

注意事项:

请确保模板字符串中的大括号部分对应的key在模型的输出中确实存在,否则在模板字符串的填充过程中会出错。

set_output_parser

设置用于解析输出结果的解析器。

该函数主要用于设置解析器,用于提取模型输出结果中所需要的信息。通过这个函数,用户可以设置自定义的解析器,处理从模型返回的结果。

参数:

out_parser (BaseOutputParser): 用户自定义的解析器,需要继承自BaseOutputParser基类。

返回:

self,以便于链式调用其他方法。

示例:

# 创建一个自定义解析器

class MyOutputParser(BaseOutputParser):

def parse(self, output):

...

# 在模型中设置这个解析器

my_model.set_output_parser(MyOutputParser())

注意:

如果已经设置过解析器,再次调用该方法会引发异常。在这种情况下,需要先清除已有的解析器,再设置新的解析器。

set_enum_output_parser

此函数用于设置枚举输出解析器。该函数接收一个枚举对象或列表作为参数,之后根据参数类型创建并设置对应的枚举输出解析器。

参数:

enum_or_list (Enum或list): 用于创建输出解析器的枚举对象或列表。如果参数是枚举对象,则直接使用该枚举对象创建解析器。如果参数是列表,则首先将列表转换为枚举对象,然后使用新的枚举对象创建解析器。

返回:

None

使用示例:

# 使用枚举对象设置解析器

lmc_model.set_enum_output_parser(MyEnum)

# 使用列表设置解析器

lmc_model.set_enum_output_parser(['option1', 'option2', 'option3'])

注意:

如果enum_or_list既不是枚举对象也不是列表,则此函数会抛出异常。

set_dictionary_output_parser

这个函数是为了设置字典类型的输出解析,参数是一个键的元组列表,每个元组包含两部分:键的名称和键的描述。

函数会根据参数生成一系列的响应模式,然后基于这些响应模式生成一个结构化的输出解析器。

对于一个包含复杂字典类型的输出,这个函数可以帮助解析输出并将其转化成结构化的形式以方便进一步处理。

参数:

list_of_keys_tuple: 一个元组列表,每个元组包含两部分:键的名称和键的描述。

返回:

此函数没有返回值,但是会将新生成的解析器设置为self.output_parser。

示例:

set_dictionary_output_parser([("name", "姓名"), ("age", "年龄")])

这个样例代码会生成一个解析器,用于解析包含"name"和"age"两个字段的字典,"name"字段的描述是"姓名","age"字段的描述是"年龄"。

set_pydantic_output_parser

这个方法是为了设置pydantic的输出解析器。pydantic是一个数据验证库,可以用于从复杂的数据类型中解析数据。

参数:

pydantic_object: pydantic的模型对象,其实例的输出将被用作解析器的目标。

例子:

该方法与PydanticOutputParser类一起使用,接受一个pydantic模型对象,然后使用该模型对象的实例来解析输出。

**class UserModel(pydantic.BaseModel):**

name: str

age: int

lmc_model = lmc_linked_model(some_llm_model)

lmc_model.set_pydantic_output_parser(UserModel)

在上述示例中,用户创建了一个pydantic模型UserModel,这个模型有两个字段:name和age。然后在lmc_linked_model实例lmc_model中,调用set_pydantic_output_parser方法,将UserModel作为参数传入。这将设置lmc_model使用UserModel来解析其输出。

返回值:

该方法返回调用它的对象本身,以支持链式调用。

注意:

这个方法不会检查pydantic_object是否真的是pydantic的模型对象,如果传入非pydantic模型对象,可能会在运行时出现错误。

错误处理:

如果在运行时发生错误,通常是由于传入非pydantic模型对象或者模型对象的结构与输出数据不匹配引起的,这种情况下,错误消息将指向具体的问题,用户需要根据错误消息进行调整。

log_state

这个函数是lmc_linked_model类的一个方法,用于记录当前llm对象的状态信息。

函数首先输出一条Summary:的日志,然后遍历llm对象的状态state(),并将每个状态信息以name: value的形式记录到日志中。

函数没有参数,也没有返回值。

注意,函数在记录状态信息时,把状态信息都转化为字符串形式,即使状态信息的实际类型不是字符串。

_invoke_list_func

_invoke_list_func是一个私有的实例方法,它用于调用一系列给定的函数,并记录每个函数的输出和日志。

参数:

ori_input: 原始输入,用作每个函数的第一个参数。

history_output_list: 历史输出列表,用于存储每个函数的输出。每个函数的第二个参数将是列表中的最后一个输出。

log_list: 日志列表,用于存储每个函数的日志。

func_list: 函数列表,这是要被调用的函数列表。

返回:

如果所有函数都成功执行,则返回None。否则,返回引发异常的第一个函数的异常实例。

注意:本函数不会捕获并处理函数引发的异常,而是直接将异常返回。

示例:

def add_one(ori_input, last_output, log):

log['operation'] = 'add one'

return last_output + 1

def multiply_two(ori_input, last_output, log):

log['operation'] = 'multiply two'

return last_output * 2

history_output_list = [1]

log_list = []

_invoke_list_func(1, history_output_list, log_list, [add_one, multiply_two]) -> None

assert history_output_list == [1, 2, 4]

assert log_list == [{'operation': 'add one'}, {'operation': 'multiply two'}]

__call__

这个类的主要目标是对模型进行调用并处理输出。我们可以设置不同的输出解析器、重试次数、调用次数以及提示模板。此外,我们还可以在发生异常时添加修复函数。在每次调用模型之后,我们都会记录日志。

其中\_\_call\_\_函数是一个特殊的魔术方法,它使类的实例可以像函数一样被调用。

例如:

lmc = lmc_linked_model(llm)

lmc(prompt='hello world')

函数介绍:

__call__函数接收一个prompt参数和任意数量的关键字参数。它首先调用初始化函数列表处理输入,然后根据设置的次数和重试次数调用模型,并使用正常函数列表和异常函数列表处理模型的输出。最后,它会调用完成函数列表处理最终的输出。

参数列表:

  • prompt:可选参数,初始提示信息。

  • **kwargs:任意数量的关键字参数。

返回类型介绍:

这个函数返回两个列表,第一个列表是所有模型的输出,第二个列表是所有调用的日志。

注意,如果在调用模型或处理输出时发生异常,这个函数会直接抛出异常。

如果你需要在异常时处理输出,你可以添加异常函数。例如:

**def fix_output(ori_prompt, last_output, log):**

\# 修复输出的代码

return fixed_output

lmc.exception_func_list.append(fix_output)

然后,你可以像之前那样调用lmc,如果在调用模型或处理输出时发生异常,它会自动调用fix_output函数来修复输出。

Openai_embedding

Openai_embedding是一个继承自Emb_Plus的类,用于获取OpenAI的词嵌入。这个类需要通过使用OpenAI提供的API令牌才能工作,并且可以通过代理服务器进行网络连接。这个类的主要功能是将文本列表嵌入到浮点数列表中。

类的使用方法如下:

  1. 实例化类时,需要提供OpenAI的API令牌

  2. 如果需要通过代理服务器进行网络连接,还需要提供代理信息

  3. 使用embed_str_list方法,可以将文本列表嵌入到浮点数列表中

例子:

openai_emb = Openai_embedding(apitoken="your_openai_api_token")

texts = ["Hello world", "I love python"]

embeddings = openai_emb.embed_str_list(texts)

参数:

  • texts:需要嵌入的文本列表,类型为列表,列表的元素是字符串

  • apitoken:OpenAI的API令牌,类型为字符串

  • proxy:代理服务器信息,类型为字符串,格式为"ip:port",默认为None

  • **kwargs:其他参数,可以是任何类型

返回:

  • embed_str_list方法,返回一个列表,列表的元素是浮点数列表,其中每个浮点数列表代表一个文本的嵌入

注意:

  • 在使用该类时,需要确保你有一个有效的OpenAI API令牌,并且该令牌有权限访问OpenAI的词嵌入功能

  • 在使用代理服务器时,需要确保代理服务器可以正常访问OpenAI的服务器

embed_str_list

这个方法的主要用途是将一系列的文本通过OpenAI的词嵌入模型进行编码。

Args:

texts (Liststr): 这是一个字符串列表,每个元素是一个独立的文本,需要进行词嵌入编码的文本。

**kwargs: 可变长度的关键字参数,用于接收未知数量的关键字参数。

Returns:

List[List[float]]: 返回的是一个嵌套列表,外层列表的每个元素对应输入的每个文本,内层列表则包含了该文本通过词嵌入模型编码后得到的浮点数值。

示例:

假设我们有一个文本列表texts = ["hello", "world"],我们可以这样使用这个方法:

emb = Openai_embedding(apitoken="your_api_token")

embeddings = emb.embed_str_list(texts)

embeddings将会是一个包含两个子列表的列表,每个子列表都是一个浮点数的列表,这代表了"hello"和"world"这两个文本的词嵌入表示。

__init__

这是一个初始化 Openai_embedding 类的构造函数。

该函数主要用于初始化 Openai_embedding 类的实例对象。其中,Openai_embedding 类继承自 Emb_Plus 类,主要用于处理和管理 OpenAI 的嵌入向量。

参数:

apitoken (str): 用于访问 OpenAI API 的密钥。

**kwargs: 用于 Emb_Plus 类的其他关键字参数。

示例:

embedding = Openai_embedding(apitoken="your_openai_api_token")

text = "This is a test sentence."

vector = embedding.embed_str_list([text])

update_prompt_folder

这个函数的主要功能是更新指定根文件夹下的文档。它首先读取所有内容,并将其存储在名为all_content的字典中。然后,该函数通过遍历语言和内容来更新内容,找到最新的版本,并根据需要创建更新任务。

在更新完成后,该函数使用进度条显示更新的状态,并执行更新任务。如果在执行过程中发生错误,程序将抛出异常。

参数:

root_folder: str,需要更新的文件夹路径

llm: 一个可选参数,默认为None。如果没有提供,那么函数将使用get_init_llm()来初始化。

返回:

无返回值

错误:

如果翻译失败,函数将抛出异常。

使用示例:

假设我们有一个名为"prompts"的文件夹,我们希望更新其中的内容,我们可以按照以下方式使用此函数:

update_prompt_folder("prompts")

ChatGLM

这是一个名为ChatGLM的类,该类继承自基类LLM_Plus。该类主要用于实现和外部模型交互的API请求功能。

主要包含三个方法:\_\_init\_\__postcall_model

\_\_init\_\_方法用于初始化ChatGLM类的实例。需要传入一个url字符串参数(用于API调用的url)和**kwargs参数(用于传入模型的其他参数)。

_post方法用于向指定的url发送post请求。传入的参数包括一个url字符串和一个包含请求内容的字典。

call_model方法用于调用模型。传入的参数包括一个prompt字符串(作为模型的输入)和其他可选的参数。该方法会对模型的返回结果进行处理,如果请求成功返回模型的预测结果,否则返回错误信息。

类的使用示例:

chat_glm = ChatGLM("your_url")

prompt = "你好,世界"

result = chat_glm.call_model(prompt)

\# 返回模型的预测结果或错误信息

print(result)

注意:响应时间取决于模型的处理速度和网络状况,请在使用时确保网络通畅,以便及时获取请求结果。

__init__

这是ChatGLM类的初始化函数,用于初始化该类的实例。

本方法首先通过调用父类的初始化方法,初始化LLM_Plus类,并设置模型的名称为"GLM6B";

然后,设置类变量gurl的值为参数_url,该变量将被用于后续的网络请求。

Args:

_url (str): 用于网络请求的URL地址

**kwargs (Any): 可接收任何关键字参数,这些参数会被传递给父类LLM_Plus的初始化方法

例子:


chat_glm = ChatGLM(_url="http://example.com", token="my_token")

在此例子中,我们创建了一个ChatGLM的实例,_url参数设置为"http://example.com",

并且通过kwargs传递了一个名为token的参数值为"my_token"到父类LLM_Plus的初始化方法中。

_post

这是一个私有的_post方法,用于向服务器发送POST请求。

参数:

url: str, 服务器的URL地址。

query: Dict, 要发送的数据,以字典形式存在。

返回:

Any, 返回服务器的响应。

使用方法:

_post方法通常在类的内部使用,作为向服务器发送请求的工具函数。该函数使用了requests库的session对象进行网络请求,

在请求过程中,设置了请求头为"Content_Type": "application/json",并对请求进行了60秒的超时设置。

在请求成功后,该函数返回服务器的响应。

例如:

假设我们有一个名为'query'的字典,包含我们要发送的数据。我们可以这样调用_post方法:

response = self._post(url="http://example.com", query=query)

注意:

由于这是一个私有方法,所以通常只在类的内部使用。在类的外部调用可能会引发错误。

call_model

这个方法是ChatGLM类的一部分,用于调用模型并获取预测结果。

参数:

prompt (str): 输入的提示,模型将根据该提示生成预测结果。

*args: 变长参数,根据需要使用。

**kwargs: 变长关键字参数,可以传递任意数量的关键字参数。

返回:

predictions (Any): 如果请求成功(HTTP状态码为200),则返回模型的预测结果;否则返回错误提示信息"请求模型Error"。

使用示例:

glm = ChatGLM(_url='http://localhost:8000')

prompt = "你好"

print(glm.call_model(prompt))

OpenAI_Complete_Model

OpenAI_Complete_Model类是一个继承自LLM_Plus的类,用于实现与OpenAI对话模型的交互。在初始化时,它会根据提供的api token和模型名称初始化OpenAI客户端。这个类主要包含五个方法:_construct_query_invoke_model_parse_invoke_resultcall_modeladd_token_use

类初始化方法\_\_init\_\_:

- 参数:

  • apitoken(str): OpenAI的API认证令牌。

  • model_name(str): OpenAI模型的名称。

  • price(float): 调用模型的价格。

  • **kwargs(dict): 其他任意的关键字参数。

  • 返回: None

方法_construct_query:

  • 功能: 构造一个查询请求,用于进一步向模型发送。

- 参数:

  • prompt(str): 用户给模型的提示或问题。

  • 返回: 构造好的查询请求列表。

方法_invoke_model:

  • 功能: 使用OpenAI客户端调用聊天模型,并返回响应。

- 参数:

  • prompt(str): 用户给模型的提示或问题。

  • 返回: OpenAI聊天模型的响应。

方法_parse_invoke_result:

  • 功能: 解析模型响应,获取并返回模型的回答,并记录消耗的token数。

- 参数:

  • response(dict): OpenAI聊天模型的响应。

  • 返回: 模型的回答。

方法call_model:

  • 功能: 调用上述三个方法,完成从构造请求到获取模型回答的整个过程。

- 参数:

  • prompt(str): 用户给模型的提示或问题。

  • *args(tuple): 其他任意位置参数。

  • **kwargs(dict): 其他任意的关键字参数。

  • 返回: 模型的回答。

使用例子:

model = OpenAI_Complete_Model(token, 'text-davinci-002', 0.06)

prompt = 'Translate the following English text to French: {}'

result = model.call_model(prompt.format('Hello, World!'))

print(result)

__init__

这是一个初始化OpenAI_Complete_Model类的方法。该类继承自LLM_Plus类,用于与OpenAI API进行交互,获取模型预测的结果。

初始化方法需要用户提供API的token,模型名称,以及模型的价格。

如果用户希望使用代理,可以通过关键字参数proxy来设置。

参数:

apitoken: OpenAI平台的API token,类型为字符串,用于API调用的身份验证。

model_name: OpenAI平台的模型名称,类型为字符串,指定调用哪个模型。

price: 模型的价格,类型为数字,用于计算使用模型的费用。

**kwargs: 任意额外的关键字参数,可能包括代理设置,传给父类LLM_Plus的初始化方法。

返回:

无返回值。

示例:

model = OpenAI_Complete_Model('API_TOKEN', 'gpt-3', 0.06, proxy='http://localhost:8080')

result = model.call_model('Hello, World!')

注意事项:

在使用代理时,需要保证代理的可用性和安全性,否则可能会影响API的调用和结果。

_construct_query

这个方法是用于构建查询的。在OpenAI Complete模型中,查询是以一个列表的形式存在的,列表中的元素是一个字典,键为'role'和'content'。'role'是一个字符串,表示发送消息的角色,这里是'user','content'是一个字符串,表示用户输入的提示。

参数:

prompt: str类型,表示用户输入的提示。

返回:

返回一个列表,列表中的元素是一个字典,键为'role'和'content'。

示例:

**def _construct_query(self, "你好"):**

\# 返回: [{"role": "user", "content": "你好"}]

_invoke_model

该函数用于调用模型并获取响应。

参数:

prompt: str类型,传入的用户提示信息。

返回:

返回从OpenAI接口获取的响应结果,通常是模型生成的文本结果。

在此函数中,我们使用了OpenAI的chat.completions.create接口来调用我们的模型。我们将用户的提示信息(prompt)传入模型,并将模型的响应结果返回。返回的结果将在后续的_parse_invoke_result函数中进行解析。

_parse_invoke_result

此函数的目的是解析模型调用的响应,并从响应中抽取所需的信息。

该函数首先从响应中获取答案内容。接着,它获取输入(prompt)和补全所用的令牌(token)数量。最后,它会添加令牌的使用情况并返回答案。

参数:

response: OpenAI的模型调用响应。它是一个包含模型生成的文本、令牌的数量等信息的对象。

返回:

返回从响应中获取的答案内容,它是一个字符串。

注意:

这个函数没有错误处理机制,如果响应的结构与预期不符,可能会引发异常。例如,如果响应中没有"choices"键,将无法获取到答案内容。

call_model

该函数是模型类OpenAI_Complete_Model的一个方法,用于调用模型并返回模型的输出结果。

参数:

prompt (str): 用户的输入提示,模型将基于此提示生成相应的回答或完成相应的任务。

*args: 可变参数,根据具体需要传入。

**kwargs: 关键字参数,根据具体需要传入。

返回:

Any: 返回模型生成的回答或完成任务的结果。

用法示例:

model = OpenAI_Complete_Model(apitoken="your_api_token", model_name="gpt-3", price=0.05)

result = model.call_model(prompt="Translate the following English text to French: '{}'", *args, **kwargs)

注意:

在使用该函数时,需要确保已经正确设置了OpenAI的API密钥,并且已经选择了正确的模型。

ChatGPT4

ChatGPT4是一个继承自OpenAI_Complete_Model的类,用于创建并管理OpenAI的GPT-4聊天模型的实例。

这个类的主要目的是使用OpenAI的API,利用提供的API令牌,实现与GPT-4聊天模型的交互。

示例:

\# 使用API令牌初始化ChatGPT4实例

chatgpt = ChatGPT4(apitoken='your_openai_api_token')

\# 使用ChatGPT4实例进行一些操作,例如生成文本

generated_text = chatgpt.generate_text(input_text='Hello, world!')

参数:

  • apitoken: OpenAI的API令牌,是一个字符串,用于进行身份验证和API访问。

  • kwargs: 其他可选参数,可以传递给OpenAI_Complete_Model的初始化方法。

注意:

  • 请确保你的OpenAI API令牌是有效的,否则将无法使用GPT-4模型。

  • 这个类没有明确的返回类型,它的主要作用是创建和管理GPT-4模型的实例。

__init__

初始化ChatGPT4类的实例。

这个类是OpenAI_Complete_Model的子类,用于创建和管理GPT-4模型的实例。通过这个类,我们可以方便地调用和使用OpenAI的GPT-4模型进行各种任务。这个类在初始化时需要传入OpenAI的API令牌,这样才能正确地使用模型。

参数:

apitoken (str): OpenAI的API令牌,用于验证用户身份和调用模型。

**kwargs: 任意关键字参数,这些参数将直接传递给OpenAI_Complete_Model的构造函数。

例子:

model = ChatGPT4('YOUR_OPENAI_TOKEN')

output = model.generate_prompt('Hello, world')

注意:

请确保你的OpenAI API令牌是正确的,错误的令牌可能会导致无法调用模型。

当前版本的类并不支持修改GPT-4模型的配置,模型的temperature和max tokens是固定的。

ChatGPT3

这是一个继承自OpenAI_Complete_Model的聊天模型类ChatGPT3。主要用于实现和Gpt-3.5的交互,包括生成文本等。

参数:

apitoken: API访问密钥。用于验证和建立与OpenAI模型的连接。

**kwargs: 可以接受任意关键字参数。这些参数将传递给父类。

使用示例:

apitoken = "你的API密钥"

model = ChatGPT3(apitoken)

generated_text = model.generate("你想说的话")

注意:

  • 必须要有API访问密钥才能使用这个模型。

  • **kwargs 的参数将会传递给父类,具体取决于父类如何处理这些参数。

__init__

初始化ChatGPT3类。

此类是OpenAI_Complete_Model的子类,用于创建ChatGPT3对象。

ChatGPT3类实例化后,将创建一个与GPT-3.5-turbo模型的连接。

参数:

apitoken(str): OpenAI API的令牌。

kwargs(dict, optional): 可选参数,用于控制模型的具体行为。可能包含例如temperature、max_tokens等参数。

返回:

None

例子:

chatgpt = ChatGPT3(apitoken="your_api_token")

response = chatgpt.generate(prompt="Hello, world!")

注意:

此类需要有效的OpenAI API令牌才能使用。

FineTuned_Completion_Model

这是一个细调完成模型类,它继承自OpenAI_Complete_Model类。

细调完成模型类主要用于自定义OpenAI的模型参数。它的构造函数需要两个参数:模型ID和API令牌。在初始化时,它将模型的ID和API令牌传递给超类,同时设置模型的温度范围为0.03到0.06。

使用示例:


model = FineTuned_Completion_Model('text-davinci-002', 'my-api-token')

参数:

  • model_id: 一个字符串,表示OpenAI模型的ID

  • apitoken: 一个字符串,表示API的令牌

  • **kwargs: 任意数量的关键字参数

注意:尽管这个类已经设置了模型的温度范围,但是你仍然可以通过传入关键字参数来自定义设置。

注意:这个类没有明显的错误或bug,但是在使用时需要注意API的令牌安全。

请确保你的API令牌是正确且安全的,否则可能会导致无法访问模型的错误。

__init__

这是FineTuned_Completion_Model类的构造函数, 这个类是OpenAI_Complete_Model的子类, 用于实现微调模型的功能。

参数:

model_id: 用于微调的模型的ID

apitoken: 连接OpenAI API的令牌

**kwargs: 任意数量的关键字参数, 这些参数将传递给父类的构造函数。

返回:

无返回值

使用示例:

model = FineTuned_Completion_Model(model_id="text-davinci-001", apitoken="my-token", temperature=0.5)

注意:

我们在这里假设OpenAI_Complete_Model类的构造函数接受模型ID、API令牌和一个浮点数元组作为参数,如果实际情况并非如此,请根据实际情况进行修改。

tasks

这个Python模块主要用于代码翻译、审查和注释生成。它可以初始化并返回指定类型的语言模型,支持原文本的语言翻译。通过递归文本拆分和语言模型链接,可以对指定路径下的文件进行代码审查并生成审查报告。而且,它还包括一个能够自动为Python源码文件生成注释的类和命令行工具,以提高代码的可读性和可维护性。

类成员:

name info
comment_creator 'comment_creator'是一个自动为Python源码文件生成注释的类,以提高代码的可读性和可维护性。

函数成员:

name info
get_init_llm 这个函数用于根据配置初始化并返回"gpt4"或"glm"类型的llm模型实例。
translate 这是一个使用lmc_linked_model模型,将原文本翻译成指定语言的函数。
codereview 这是一个通过递归文本拆分和语言模型链接,对指定路径下的文件进行代码审查并生成审查报告的函数。
comment_creator_cmd 这个函数用于创建一个命令行工具,该工具可以为指定的文件添加注释。

get_init_llm

这个函数的主要功能是初始化llm模型。它首先从Service_Shelve类实例中获取配置,提取出llm模型的类型,然后根据llm模型的类型分别初始化对应的模型。目前支持的模型类型有"gpt4"和"glm"。

参数:

返回:

返回初始化后的llm模型实例。

错误处理:

如果配置中的llm模型类型不在支持的类型列表中("gpt4", "glm"),则会抛出异常。

示例:

llm = get_init_llm()

注意:

在使用这个函数之前,需要确保已经正确配置了llm模型的类型。

translate

此函数是一个翻译函数,通过调用lmc_linked_model模型进行语言翻译。

参数:

lang (str): 目标语言名称,以字符串形式。

ori_text (str): 需要被翻译的原文本,以字符串形式。

tllm (model, 可选): 用于翻译的模型。如果没有提供,函数将会使用get_init_llm()函数来获取初始模型。默认值为None。

返回:

str: 翻译后的文本。如果无法进行翻译,将返回None。

错误或bug:

如果提供的原始文本为空,或者目标语言无法识别,函数可能无法正常工作。此外,如果lmc_linked_model模型无法正常工作,也可能导致函数错误。

使用示例:

translate("zh", "Hello, world!") # 返回 "你好,世界!"

translate("fr", "Hello, world!") # 返回 "Bonjour, monde !"

codereview

这是一个通过路径、过滤器和附加标记来进行代码审查的函数。

函数首先检查过滤器和附加标记,如果没有设置或为空,则将其设置为默认值。然后,它会获取指定路径下所有的文件,并通过使用递归文本拆分器,将每个文件的内容拆分为大小为1000的块。

对于每个文件,如果文件内容为空,则跳过该文件并继续处理下一个文件。否则,函数会将文件路径和文件长度打印到日志中。

然后,该函数会逐个处理拆分后的文本块,对其进行语言模型链接,生成结果和日志。如果结果为空,则在日志中打印错误信息。否则,根据索引值将结果写入到文件中。

注意,对于每个文件,第一个放入结果的文件名将附加给定的附加标记,而其他的文件名将添加索引号和附加标记。

参数:

path: str. 文件路径。处理该路径下所有的文件。

filter: str,默认为'*'. 文件名过滤器,用于选择需要处理的文件。默认处理所有文件。

addition_mark: str, 默认为'.report'. 用于标记处理完成的文件的附加标记。

无返回值,但会生成表示处理结果的文件,文件名格式为“原始文件名+附加标记”。

错误或异常:

如果文件路径不存在或无法访问,该函数可能会抛出异常。

如果在处理文件时遇到错误,如读取文件失败或者写入文件失败,该函数可能会抛出异常。

示例:

codereview('/path/to/files', '*.txt', '.report')

comment_creator

这个类名为comment_creator,主要用于为Python源码文件添加自动生成的注释。

目的:

为了提高代码的可读性和可维护性,该类可以自动分析Python源码文件,并根据代码结构和内容生成对应的注释。

使用例子:

llm = LLM_Plus(...)

path = "./src"

cc = comment_creator(llm, path)

cc.fill_comment()

类方法:

  • \_\_init\_\_(self, llm: LLM_Plus, path): 初始化,接收一个LLM_Plus对象和Python源码文件路径,准备用于后续代码分析生成注释。

  • _parse_py_file(self, lines): 用于解析Python文件,通过ast模块获取代码的抽象语法树并进行分析。

  • add_l1_coment(self, l1_list, l1_mark, l2_list, l2_mark, m_obj, m1, m2): 用于生成对应的注释,需要输入当前处理的行等信息。

  • fill_comment(self): 对初始化时指定的路径下的所有Python文件进行遍历,自动添加注释。

注意:该类依赖LLM_Plus对象生成注释,LLM_Plus需预先训练好。并且注释生成可能不完全准确,需要根据实际代码内容进行修改调整。

待修复的问题:暂无已知bug。

__init__

这个类的名称是comment_creator,其主要目的是对Python源代码文件进行解析,并自动添加注释。

类的初始化函数需要两个参数:

  • llm (LLM_Plus):一个LLM_Plus类型的对象,用于调用其相关功能对源代码进行解析。

  • path (str):需要解析的Python源文件的路径。

初始化函数主要进行以下几个操作:

  • 获取并保存输入参数

  • 枚举path路径下所有的文件,并存储到all_file_task属性中

  • 创建一个进度条对象p_bar,用于后续的任务进度显示

  • 获取注释生成模型的提示字符串(模型的语言设置为中文),并创建一个lmc_linked_model对象l1_parser,设置其模型提示模板以及重试次数。

注意:

  • 在使用本类时,请确保提供的LLM_Plus对象和文件路径是有效的,不然可能会导致错误。

  • 本类尚未处理所有可能的错误情况和异常输入,因此在使用时请保证输入的参数正确性。

_parse_py_file

_parse_py_filecomment_creator类中的一个私有方法,这个方法的主要目的是解析Python文件,生成一个标记列表,标记列表中的每个元素表示原始代码行的类型。

参数:

lines (Liststr): 原始代码行的列表。

返回:

Liststr: 表示代码行类型的标记列表。可能的标记包括"class_start", "class", "func_start", "func", "comment"等等。如果某一行是类定义或函数定义的开始,该行的标记会在行号后加上"_start"。如果某一行是类定义或函数定义内部的一部分,该行的标记会是"class"或"func"。如果某一行是注释行,该行的标记会是"comment"。

使用方法:

该函数是内部使用的,通常不会直接调用。它被fill_comment方法调用,用于解析Python文件并生成标记列表,这些标记随后用于决定如何添加新的注释。

注意:

该函数使用Python的AST(Abstract Syntax Tree)模块进行源代码解析。因此,如果原始代码存在语法错误,该函数可能无法正确运行。

示例代码:


lines = read_file_lines("test.py")

line_marks = self._parse_py_file(lines)

可能的错误:

如果输入的原始代码存在语法错误,AST的解析过程可能会失败,导致函数无法正确运行。

_add_memeber_summary

这是一个为类成员添加概要信息的方法。

参数:

lines: 类型为list,包含了要解析的源代码行。

返回:

字符串类型, 返回从源代码行解析得到的类成员概要信息。

此函数的主要作用是分析源代码中的类成员,并为它们添加概要信息。它首先检查缓存中是否已存在此类成员的概要信息,如果存在,则直接从缓存中获取。如果不存在,则使用预定义的解析器从源代码行中提取信息,并将提取的信息添加到缓存中,最后从缓存中获取概要信息并返回。

注意:

使用此函数时,应确保传入的源代码行是有效的Python代码,否则可能会导致解析错误。

_add_module_summary

这是一个私有方法,其主要功能是为模块生成摘要。

参数:

lines: list, 需要生成摘要的模块的行列表。

返回:

str, 生成的模块摘要。

该函数主要通过以下步骤实现其功能:

  1. 首先,对输入的模块行列表进行清洗,去除两端空白字符,并将其合并为一个字符串,作为键值。

  2. 然后,检查该键值是否已经存在于缓存中。如果存在,直接从缓存中获取摘要;如果不存在,就使用summary_2_parser进行解析,获取摘要,并将其缓存起来。

  3. 最后,返回生成的摘要。

注意:该函数没有处理错误和异常的代码,因此调用者需要自行处理可能出现的错误和异常。

示例代码:

lines = ['def func1():', '    pass', 'def func2():', '    pass']

summary = _add_module_summary(lines)

print(summary)

add_l1_coment

这个函数的主要目的是在给定的python代码块中添加一级注释。

参数:

l1_list: list, 一级代码块,通常为类定义

l1_mark: list, 与l1_list长度相同,标记l1_list中的每一行代码

l2_list: list, 二级代码块,通常为函数定义

l2_mark: list, 与l2_list长度相同,标记l2_list中的每一行代码

m_obj: list, 主要的python代码块

m1: str, 主模块名

m2: str, 子模块名

返回值:

str, 生成的一级注释

这个函数首先判断l1_list和l2_list中哪一个是主要的代码块,然后根据主要的代码块生成一级注释的关键字key,

并检查是否已经为这个key生成过注释。如果已经生成过,就直接从缓存中取出并返回;如果没有生成过,就先移除主要代码块中的注释,

然后调用l1_parser生成一级注释,并添加到缓存中,最后返回生成的一级注释。

注意:

这个函数可能会抛出"list error."异常,当l1_list和l2_list都为空时会抛出这个异常,这是因为至少需要一个主要的代码块来生成注释。

fill_comment

fill_comment方法用于给Python代码中的类和函数添加一级注释。

这个方法首先创建了一个空的字典用于存储注释信息,然后遍历了所有的文件。对于每个Python文件,

这个方法会读取文件的所有行,然后使用_parse_py_file方法解析这些行并标记它们。

对于每个被标记为"class_start"或"func_start"的行,方法会创建一个缓冲区并尝试用add_l1_coment方法为其添加注释。

然后,这个方法会将新的注释添加到新行列表中,并在完成文件遍历后,将这些新行写回到原文件中。

参数:

返回:

注意:

  1. 这个方法不支持对嵌套的类和函数添加注释。

  2. 这个方法会直接修改原始文件,可能会导致原始代码丢失。在使用这个方法前,建议先备份原始文件。

extract_name

该函数的主要目的是从给定的代码字符串中提取类名或函数名。

参数:

code (str): 需要用来提取名字的代码字符串。

返回:

str: 如果找到匹配的类名或函数名则返回该名字,否则返回None。

使用示例:

name = extract_name('class MyClass:')

print(name) # 输出: MyClass

注意:

  • 只有当代码字符串以'class'或'def'开头时,才会尝试提取名字。

  • 如果代码字符串不符合预期格式,可能无法正确提取名字。

异常:

  • 无特别异常处理,如果代码字符串格式错误,可能无法返回预期结果。

该函数使用正则表达式匹配,可能对性能有一定影响,如果处理大量代码字符串,建议优化或者寻找其他方法。

_add_style_comment_content

此函数的主要目的是将指定的代码行添加到Markdown文档中以创建样式化的注释内容。

参数:

  • lines (Liststr): 代表源代码的字符串列表,通常每个字符串代表源代码的一行。

  • mdoc (Object): Markdown文档对象,用于写入和格式化源代码行。

该函数会遍历输入的源代码行,对每一行进行处理。如果行以三引号结束(即表示注释的结束),则移除三引号并执行以下操作:

  • 如果行为空,则忽略该行;

  • 如果行以冒号结束,则在Markdown文档中添加一行标记为粗体的文本;

  • 其他情况,直接在Markdown文档中添加该行。

此函数不返回任何值,但会直接修改传入的Markdown文档对象。

注意:此函数不会处理源代码行中的错误或bug,如果源代码行无法正确解析为Markdown格式,可能会引发错误。

export_document

该方法用于导出项目的文档,包括每个模块的介绍,类和函数成员的信息。

参数:

doc_path (str): 文档的导出路径。

返回:

None

使用方法:

创建一个 comment_creator 实例,然后调用该方法即可将文档导出到指定的路径,例如:

cc = comment_creator(llm, path)

cc.export_document("./doc.md")

注意:

本方法无法处理的异常情况包括:

  1. doc_path 的指定路径不合法或无写入权限。

  2. 类或函数的注释不规范,导致无法正确解析。

comment_creator_cmd

此函数用于创建用于添加注释的命令行工具。

参数:

path (str): 需要添加注释的文件的路径。

返回:

此函数无返回值。

函数执行步骤:

  1. 首先,函数会调用get_init_llm()函数获取一个初始的词法、语法和语义模型(llm)。

  2. 然后,使用此llm和文件路径作为参数创建一个CommentCreator实例。

  3. 最后,调用fill_comment()方法来对指定文件进行注释。

注意:

  1. 需要确保指定的文件路径存在,否则可能会引发异常。

  2. 本函数不会检查指定的文件是否已经存在注释,如果已经存在注释,则可能会重复添加。

  3. 本函数不会保存添加注释后的文件,如果需要保存,请在调用此函数后调用相应的保存方法。

示例:

comment_creator_cmd('/path/to/your/file')

prompts

该Python模块主要用于处理提示文件的读取、解析和写入操作。其中,get_prompt函数能从提示文件中获取基于特定语言的信息,read_prompt_file函数用于读取并解析文件内容,返回包含版本号、描述、参数和模板字符串的字典,而write_prompt_file函数则用于创建带有版本信息、描述、参数和模板的自动生成文件。

类成员:

name info

函数成员:

name info
get_prompt get_prompt函数用于获取并返回指定语言的提示文件内容或其详细信息。
read_prompt_file 这个函数读取并解析指定路径下文件的内容,返回一个包含版本号、介绍、参数和模板字符串的字典。
write_prompt_file 这是一个用于创建带有版本信息、描述、参数和模板的自动生成文件的函数。

get_prompt

get_prompt函数是用于获取提示文件的内容的函数。

参数:

key (str): 提示文件的键值。

lang (str, 可选): 提示文件的语言版本,默认为 "english"。

return_details (bool, 可选): 是否返回提示文件的详细内容,如果为 False,只返回文件中的模板字符串,默认为 False。

folder (str, 可选): 提示文件所在的文件夹路径,默认为 None,表示提示文件在安装包中。

返回:

str 或 dict: 如果return_details为 True,返回提示文件的详细内容,类型为字典;否则只返回文件中的模板字符串,类型为字符串。

raise:

IOError: 如果无法找到指定的文件,或者读取文件出错。

注意:

此函数将尝试从指定的文件夹或安装的包中获取提示文件。所以如果你安装了此包,请确保提示文件存在于正确的位置,否则可能会引发 IOError。

示例:

get_prompt('welcome', return_details=True)

# 返回 {'templatestr': 'Welcome to our system!', 'details': 'This is the welcome message showed to the user when they log in.'}

read_prompt_file

这个函数的主要作用是读取给定路径下的文件,解析其内容并以字典的形式返回。文件的内容主要包括:版本号、介绍、参数和模板字符串。

参数:

path (str): 文件的路径。

返回:

dict: 一个字典,包含版本号(version)、介绍(description)、参数(params)和模板字符串(templatestr)。

文件的格式应该如下:

version 1.0

这是介绍

参数1: 参数值1

参数2: 参数值2

start

这是模板字符串

这个函数首先会读取文件的所有行,然后依次处理每一行。在处理过程中,首先会获取版本号,然后是介绍,接着是参数。参数的处理会持续到遇到以'start'开头的行为止。最后,从'start'开始后的所有行会被处理为模板字符串。

示例:

output = read_prompt_file('/path/to/file')

print(output)

# 输出:

# {

# 'version': '1.0',

# 'description': '这是介绍',

# 'params': {

# '参数1': '参数值1',

# '参数2': '参数值2'

# },

# 'templatestr': '这是模板字符串'

# }

注意:

如果文件的格式不符合预期,这个函数可能会产生不可预知的行为。例如,如果没有'start'这一行,那么所有的行都会被当作参数来处理。如果参数没有用':'来分隔,那么处理参数的部分可能会出错。

write_prompt_file

这个函数是用于生成一个带有版本信息、描述、参数信息和模板的文件。通过这个函数,我们可以方便的创建用于自动生成文件的模板。

参数:

path (str): 要写入文件的路径。

version (str): 版本信息。

des (str): 文件描述信息。

params (dict): 参数列表,字典形式,包含参数的名称和对应的值。

str_template (str): 用于生成文件的模板。

返回类型:

无返回值。

使用示例:

path = './prompt.txt'

version = '1.0'

des = 'This is a prompt file.'

params = {'author': 'admin', 'date': '2021-01-01'}

str_template = 'Hello, world!'

write_prompt_file(path, version, des, params, str_template)

buffer

该Python模块主要用于提供基于键值对的数据存储与管理,具备数据缓存、Hash键管理、数据库连接管理等功能。其重要功能包括初始化和关闭Shelve数据库和普通数据库,缓存管理(包括缓存内容的保存、读取和删除),以及提供了一个装饰器用于优化相同输入参数的函数运行效率。通过这个模块,用户可以方便的进行数据的存取操作,并且支持线程安全。

类成员:

name info

函数成员:

name info
init_shelve 这是一个初始化Shelve数据库的函数,提供了加载、保存、删除键值对,以及检查键是否存在的功能。
close_shelve 这是一个关闭全局shelve对象并释放资源的函数,防止数据不一致和资源占用问题。
init_db init_db函数用于初始化数据库,并提供对数据库表buffer的增删查改操作接口。
close_db 这是一个关闭数据库连接的函数,用于释放数据库连接资源,如果关闭过程出现错误,会抛出异常。
flush flush函数用于将缓冲区内容存储到文件,并清空缓冲区,确保同一时间只有一个线程执行。
set_flush_freq 这是一个修改全局变量flush_freq值的函数,传入参数即为新的值。
get_hash_key 这个函数用于生成基于函数名和参数的唯一哈希键。
buffer_item 这是一个线程安全的函数,其功能是将键值对存入缓冲区,并在需要时将缓冲区内容写入磁盘。
get_buffer_item 这是一个线程安全的函数,用于通过键值从内存或文件缓存中获取项目,如果无法获取则抛出错误。
has_item_key 这是一个检查给定键(经过hash处理)是否存在于BUFFER_ITEMS或buffer文件中的函数。
remove_item 这是一个用于从缓存中移除指定项的函数,如果该项在缓存文件中存在,也会被删除。
buffer 这是一个装饰器函数,用于缓存并提高相同输入参数函数的运行效率,通过修改版本号可清除旧缓存。
get_path_for_key 这个函数用于根据给定的键构建并返回在缓冲文件夹中的路径。

init_shelve

这是一个初始化Shelve数据库的函数,Shelve数据库是一种简单的键值对存储。

函数首先检查存储Shelve文件的目录是否存在,如果不存在,则创建该目录。

然后,打开Shelve文件,并返回一个对象。

函数定义了四个内部函数:_load_buffer_file,_save_buffer_file,_delete_buffer_file和_has_buffer_file。

这些函数分别用于加载、保存、删除Shelve数据库中的键值对,以及检查一个键是否在数据库中。

此外,这四个内部函数也被赋值给了bb对象的四个属性,分别为:has_buffer_file,load_buffer_file,delete_buffer_file,save_buffer_file。

这样做的目的是为了在函数外部也能调用这四个内部函数。

函数最后返回了打开的Shelve对象。

函数没有参数。

返回值是一个Shelve对象。

注意:

  1. 如果删除的键不存在,_delete_buffer_file函数将会抛出KeyError异常。

  2. 如果保存的键已经存在,_save_buffer_file函数将会覆盖原有的值。

close_shelve

关闭shelve对象的函数。

这个函数会检查全局的SHELVE_OBJ对象是否存在,如果存在,就会关闭这个对象,并将其设置为None。这个函数主要用于确保shelve对象在不需要的时候被正确关闭,避免占用资源或者导致数据不一致的问题。

函数没有参数,也没有返回值。

示例:

close_shelve()

注意:这个函数依赖于全局变量SHELVE_OBJ,所以在使用之前需要确保SHELVE_OBJ已经被正确初始化。另外,如果已经关闭了SHELVE_OBJ,再次调用这个函数会导致错误。

待解决的问题:目前这个函数没有处理可能的异常,比如在关闭SHELVE_OBJ的时候可能会出现的IO错误。

init_db

这个函数init_db的主要作用是初始化数据库,并提供对数据库表buffer的增删查改的接口。

函数首先检查存放数据库的文件夹是否存在,如果不存在则创建。

然后建立连接池,这是一个字典,其键值为线程的ID,值为线程创建的SQLite数据库连接。

函数内部定义了四个操作数据库的函数:

  • _load_buffer_file(key): 从buffer表中加载具有特定键(key)的数据,返回的数据为pickle反序列化后的数据。如果键不存在,返回None。

  • _save_buffer_file(lists): 将数据保存到buffer表中,数据是一个列表,列表的每个元素是一个键值对,键为字符串,值为任意pickle序列化后的对象。

  • _delete_buffer_file(key): 从buffer表中删除具有特定键(key)的数据。

  • _has_buffer_file(key): 判断buffer表中是否存在具有特定键(key)的数据。

函数最后将这四个函数绑定到全局变量bb的相应方法上。

:param 无

:return tuple: 返回一个元组,元组的第一个元素为数据库连接,第二个元素为一个指向数据库的游标。

示例:


conn, cursor = init_db()

注意:函数没有进行异常处理,如果数据库操作失败,可能会抛出异常。

close_db

这个函数是用于关闭数据库连接的。在全局变量CONN_OBJ存在的情况下,它会调用CONN_OBJ的close方法来关闭连接,并将CONN_OBJ设置为None。

函数没有接收任何参数,也没有返回任何内容。主要的用途是在执行数据库操作后,释放数据库连接资源。

该函数假设CONN_OBJ具有close方法,如果CONN_OBJ不具有此方法,将会抛出异常。同时,该函数没有处理可能的数据库关闭异常,如果在关闭数据库连接时发生错误,该错误会被抛出。

示例代码:

# 初始化数据库连接

CONN_OBJ = create_db_connection()

# 执行数据库操作

...

# 在完成数据库操作后,关闭数据库连接

close_db()

flush

flush函数是用于把缓冲区的内容存储到文件中,然后清空缓冲区。

这个函数没有参数和返回值,但在执行过程中,如果没有导入任何模块,会抛出异常。

函数执行的步骤如下:

  1. 检查是否导入了任何模块,如果没有,抛出异常。

  2. 获取buffer_lock锁,保证同一时间只有一个线程可以执行这段代码。

  3. 从BUFFER_OPER_QUEUE队列中获取要保存的项目,队列中存储的是要保存的项目的key。

  4. 将要保存的项目以列表的形式传给save_buffer_file函数,该函数负责将项目存储到文件中。

  5. 最后清空BUFFER_OPER_QUEUE队列。

注意:

  • has_buffer_filebuffer_lockBUFFER_OPER_QUEUEBUFFER_ITEMS都应该是事先定义好的全局变量。

  • save_buffer_file应该是一个可以将项目存储到文件中的函数。

set_flush_freq

这是一个设置全局变量flush_freq的函数。它接收一个参数count,并将flush_freq设为count。

参数:

count: 任何可以赋值给flush_freq的值。

返回:

无返回值。

注意: 这个函数会改变全局变量flush_freq的值。确保在调用此函数时了解这一点,以避免可能的问题。

get_hash_key

这个函数用于生成一个基于函数名,参数(位置参数和关键词参数)的哈希键。

参数:

func_name (str): 函数名

*args (tuple): 可变位置参数,可以接受任意数量的位置参数

**kwargs (dict): 可变关键词参数,可以接受任意数量的关键词参数

返回:

hash_key (str): 返回一个字符串,该字符串以函数名作为前缀,并加上基于参数的哈希值。

例如:

假设我们有一个函数:def my_func(a, b, c=2, d=3):

我们可以这样使用 get_hash_key 函数:

hash_key = get_hash_key('my_func', 1, 2, c=3, d=4)

这将返回一个字符串,比如 "my_func_7fe02"

注意:这个函数依赖另外两个函数 hash_str 和 hash_obj_strbase 来生成哈希值,

如果这两个函数不存在或者报错,那么这个函数也会报错。

错误/异常:

如果输入的 func_name 不是字符串,或者 *args 或 **kwargs 中的元素不能被 hash_obj_strbase 正确处理,那么这个函数可能会抛出异常。

buffer_item

这个函数的作用是将一个键值对存放到缓冲区中。

参数:

key: str - 需要存放的键。

value: 需要存放的值,可以为任何类型。

这个函数没有返回值。

在函数内部,首先会对给定的键执行一个哈希操作来得到一个新的键(nkey)。然后将给定的值存放到以nkey为键的缓冲区中。并且,将nkey添加到操作队列中。

当操作队列的长度达到设定的刷新频率时,将调用flush()函数,将缓冲区的内容写入到磁盘中。

注意:该函数是线程安全的,使用了锁来确保在修改缓冲区和操作队列时不会出现数据竞争。

在调用这个函数时,需要确保给定的键是字符串类型,否则哈希操作可能会失败。

示例:

buffer_item('my_key', 'my_value')

上述代码将把键为'my_key',值为'my_value'的键值对存入缓冲区中。

get_buffer_item

此函数的作用是通过给定的键值,从缓存中获取对应的项目。如果在内存缓存中没有找到对应项目,它将尝试从缓存文件中加载。如果在文件中也没有找到,将抛出一个值错误。

参数:

key: str类型,输入的键值,用来在缓存中查找对应的项目。

返回:

返回对应键值的缓存项目。

异常:

如果没有导入任何模块,会抛出一个异常。

如果没有找到对应键值的缓存项目,会抛出一个值错误。

注意:

此函数是线程安全的,可以在多线程环境下同时访问。

示例:

item = get_buffer_item('my_key')

print(item)

has_item_key

这是一个检查是否存在指定键(通过hash处理)的函数. 它首先确认是否导入了模块,然后获取参数key的hash值。 最后,它会检查这个hash值是否在BUFFER_ITEMS中,或者是否在buffer file中.

参数:

key (str): 要检查的键.

返回:

Boolean: 如果键存在于BUFFER_ITEMS或buffer file中,则返回True,否则返回False.

异常:

如果没有导入模块,将引发异常.

remove_item

移除指定的缓存项。

此函数用于从BUFFER_ITEMS中移除指定的缓存项。如果缓存项在缓冲文件中也存在,也将一并删除。

Args:

key (str): 要移除的缓存项的键。

Raises:

Exception: 如果没有导入模块,将抛出异常。

注意:

此函数需要配合其他函数使用,如get_hash_key, has_buffer_file, delete_buffer_file等。

例如:

remove_item('test_key')

注意,如果没有提前导入必要的模块或准备好必要的配置,此函数可能会引发错误。

buffer

这是一个装饰器函数,用于缓存函数的运行结果。当函数的输入参数相同时,不会重复运行函数,而是直接返回缓存的结果,提高了程序的运行效率。

参数:

version (float): 缓存版本,默认为1.0。当函数逻辑发生改变,需要清除旧的缓存时,可以通过修改这个版本号来实现。

返回:

decorator (function): 返回一个装饰器,用于装饰其他函数。

使用示例:

@buffer(version=2.0)




**def add(x, y):**

return x + y

在此示例中,add函数被buffer装饰器装饰,当多次调用add(1, 2)时,实际上函数只运行了一次,其它次数直接返回了缓存的结果。

注意:

  1. 该装饰器使用了全局的BUFFER_ITEMS字典进行缓存,如果使用的地方较多,可能会占用较多的内存。

  2. 缓存的键是通过函数名和参数生成的哈希值,如果函数名或参数的字符串表示发生改变,可能会产生冲突。

get_path_for_key

此函数用于获取给定键的路径。

参数:

key (str): 用于构建最终路径的键。

返回:

str: 返回由缓冲文件夹路径和给定键组成的路径。如果缓冲文件夹路径在配置中没有找到,默认将使用"buffer"作为缓冲文件夹路径。

例子:

get_path_for_key('sample_key')

'/path/to/buffer_folder/sample_key'

注:

此函数依赖于get_config_instance()方法来获取配置实例,并从中读取"buffer_folder"配置。如果该配置不存在,它会默认使用"buffer"。因此,确保在调用此函数之前,已经正确配置了get_config_instance()方法。

markdowns

这个Python模块主要用于创建和管理Markdown文档,提供了一系列便利的工具类和方法。包括处理文档元素、生成不同等级的标题、处理未转换的字符串、转化代码为Markdown格式代码块、创建和管理Markdown格式的表格和图片等。此外,它还可以生成和管理基于图表的Markdown项目,包括流程图和甘特图,方便用户在Markdown文件中插入多种类型的图表。

类成员:

name info
flowchart_color_enum 这是一个名为flowchart_color_enum的枚举类,主要用于定义和管理流程图中使用的不同颜色的十六进制RGB值。
flowchart_shape_enum 这是一个枚举类,用于提供方便表示流程图中不同形状的字符串值。
MFlowchart MFlowchart类是用于生成和管理基于图表的Markdown项目,提供添加节点、设置节点属性和添加线条等功能。
FGantt FGantt类用于在Markdown文件中创建和管理甘特图,包括设定标题、日期格式以及添加项目和时间信息。
markdown_item_base 这是一个基础类markdown_item_base,主要用于处理Markdown文档元素,作为接口供子类继承,并实现具体方法。
MTitle "MTitle类用于生成、存储和操作不同等级的markdown格式标题。"
MTOC MTOC类是用于处理markdown文本中的行锚点,并将其替换为特定字符串的工具。
MSplit_line MSplit_line类用于处理Markdown语法中的分割线,其方法flush_row_anchor返回表示分割线的列表。
MFooter MFooter类用于处理Markdown文档的页脚信息,提供一个方法flush_row_anchor将输入的锚字符串格式化为特定的字符串列表。
MUnconvert_str 'MUnconvert_str'是一个处理和管理未转换字符串内容的类,提供初始化及返回原始字符串列表的功能。
MCode MCode类是用于将指定的代码内容和语言转化为markdown格式代码块的工具。
MTable MTable是一个方便用户创建和管理Markdown格式表格的类,包含初始化、添加行和刷新行锚三个主要方法。
MImage MImage类是用于生成包含Markdown格式图片的字符串的工具类。
Mstring Mstring类是用于存储和处理markdown格式字符串,提供添加和格式化字符串的方法。
markdowndoc markdowndoc是一个处理Markdown文档的类,能将字符串或Markdown项目转换为Markdown格式,然后写入到指定文件中。

函数成员:

name info

flowchart_color_enum

这个类是一个枚举类,名为flowchart_color_enum,它继承自str和Enum类。它主要是定义了一些代表不同颜色的字符串常量,这些字符串常量都是对应颜色的十六进制RGB值。这些颜色经常会被用在流程图中,因此这个枚举类的名称被命名为flowchart_color_enum。

使用这个枚举类时,你可以直接使用它的枚举值来表示颜色,而不需要记住颜色的十六进制RGB值。这会使你的代码更易读,更易维护。

类属性包含:

  • Red: 红色,其十六进制RGB值为"#FF0000"。

  • Yellow: 黄色,其十六进制RGB值为"#FFFF00"。

  • Blue: 蓝色,其十六进制RGB值为"#00BFFF"。

  • Orange: 橙色,其十六进制RGB值为"#FFA500"。

  • LightGreen: 浅绿色,其十六进制RGB值为"#90EE90"。

  • MediumPurple: 中紫色,其十六进制RGB值为"#9370DB"。

  • Auqamarin: 浅碧色,其十六进制RGB值为"#7FFFAA"。

  • DeepSkyBlue: 深天蓝色,其十六进制RGB值为"#00BFFF"。

  • NavajoWhite: 纳瓦霍白色,其十六进制RGB值为"#FFDEAD"。

以下是一个使用示例:

\# 初始化一个颜色变量

color = flowchart_color_enum.Blue

\# 打印颜色的RGB值

print(color.value)

这个枚举类没有任何已知的错误或bug。

flowchart_shape_enum

这是一个名为flowchart_shape_enum的枚举类,继承自strEnum

此类的主要目的是提供一个方便的方式来表示流程图中的形状。

例如,可以通过flowchart_shape_enum.Roundedges来表示具有圆角的形状,通过flowchart_shape_enum.Stadium来表示体育场形状等等。此枚举类中的每一个值都是一个字符串,这些字符串的形式是特定的,可以被用来在流程图生成器中创建对应的形状。

使用例子如下:

shape = flowchart_shape_enum.Circle

print(shape.value)  \# 输出: ((%%))

此类没有已知的错误或bug。

MFlowchart

MFlowchart 类继承自 markdown_item_base 类,主要用于生成和管理基于图表的Markdown项目。它提供了一系列的方法,供用户添加节点、设置节点颜色和形状,以及添加线条等。

此类使用的示例如下:

flowchart = MFlowchart(oriented_left2right=True)

flowchart.add_node(name='Node1', id='1', anchor_title='title1', icon='icon1')

flowchart.set_node_color(id='1', color=flowchart_color_enum.RED)

flowchart.set_node_shape(id='1', shape=flowchart_shape_enum.CIRCLE)

flowchart.add_line(id1='1', id2='2', message='message', dot_line=False)

主要方法介绍:

  • \_\_init\_\_(self, oriented_left2right=True): 类的构造方法,初始化一个新的 MFlowchart 实例。可选参数 oriented_left2right 决定图表是从左向右 (True) 还是从上到下 (False) 排列。

  • flush_row_anchor(self, anchor_str) -> [str]: 该方法生成流程图的字符串表示,主要用于Markdown的渲染。返回一个字符串列表,每个字符串是流程图的一行。

  • _convert_name(self, answer): 对节点或线条的名称进行处理,以保证其在Markdown中的正确显示。

  • add_node(self, name, id, anchor_title=None, icon=None): 添加一个新的节点。参数包括节点名称 (name),节点ID (id),锚点标题 (anchor_title) 和图标 (icon)。

  • set_node_color(self, id, color: flowchart_color_enum): 设置指定节点的颜色。颜色值需要从 flowchart_color_enum 枚举中选取。

  • set_node_shape(self, id, shape: flowchart_shape_enum): 设置指定节点的形状。形状值需要从 flowchart_shape_enum 枚举中选取。

  • add_line(self, id1, id2, message=None, dot_line=False): 添加一条从 id1 节点到 id2 节点的线。可选参数 message 是线条的文字说明,dot_line 决定线条是否为虚线。

flush_row_anchor

flush_row_anchorMFlowchart 类的一个成员函数,主要负责生成流程图的每一行内容。

参数:

anchor_str : str

该参数在函数内部并未使用,可能是历史代码遗留或者预留的接口,当前版本中没有实际意义。

返回:

Liststr

返回一个字符串列表,每个元素代表了流程图的一行。

该函数首先定义了流程图的基本格式,然后对 nodeslines 这两个属性进行遍历,根据这两个属性的内容生成流程图的节点和连线。

对于节点,它会检查节点的形状(node_shape)和图标(node_icon)是否被定义,如果被定义则使用定义的形状和图标,否则使用默认值。

对于连线,它会检查连线是否有标签(line[2])和是否为虚线(line[3]),然后生成相应的连线。

最后,它会为每个节点添加点击跳转链接(如果有的话)和节点颜色。

举个例子,如果我们有如下的流程图对象:

mf = MFlowchart()

mf.add_node("Start", "start")

mf.add_node("End", "end")

mf.add_line("start", "end", "Go")

那么 flush_row_anchor 函数会生成以下的列表:

[

'```mermaid\\n',

'graph LR\\n',

'start[Start] \\n',

'end[End] \\n',

'start -->|Go| end \\n',

'```\\n'

]

这个列表可以直接用于生成Markdown文件。

注意:本函数没有对输入参数做任何的错误检查和处理,因此在使用时需要保证输入的有效性。

__init__

初始化MFlowchart类的实例。

MFlowchart是一个用于创建和编辑Markdown格式的流程图的类。它可以定义节点、线条、节点颜色、节点形状等。

初始化方法中,我们定义了一些保存流程图信息的列表和字典,以及流程图的默认方向。

参数:

oriented_left2right(bool, 可选): 流程图的方向,默认为从左到右。如果为False,则流程图的方向为从上到下。

使用示例:

m_flowchart = MFlowchart(oriented_left2right=False)

m_flowchart.add_node("开始", "node1")

m_flowchart.add_node("结束", "node2")

m_flowchart.add_line("node1", "node2")

lines = m_flowchart.flush_row_anchor()

for line in lines:

print(line)

_convert_name

_convert_name函数是一个内部函数,主要用于对输入的名称进行处理和转换。

参数:

answer (str or None): 需要处理的名字,如果为None则直接返回None。如果名字以"/"开始,会在前面添加一个空格。同时,会将名字中的所有uncode字符替换为一个空格。

返回:

str or None: 返回处理后的名字,如果输入为None则直接返回None。

注意: 该函数并未处理可能存在的错误,比如answer不是字符串或None的情况,并且该函数也没有对uncode进行定义,可能需要在外部定义uncode变量并传入。

使用示例:

_convert_name("/my name") 会返回 " my name"

_convert_name("my name") 会返回 "my name"

_convert_name(None) 会返回 None

add_node

add_node方法用于向流程图中添加节点。

参数:

  • name (str): 节点的名字。

  • id (int/str): 节点的ID,每个节点的ID需要唯一。

  • anchor_title (str,可选): 锚点标题,用于在Markdown中创建导航,如果给出,则此节点将在流程图中为此链接创建一个导航。默认值是None。

  • icon (str,可选): 为节点添加图标,必须是可接受的图标名称。默认值是None。

返回:

  • 无返回值。

使用方法:

flowchart = MFlowchart()

flowchart.add_node(name="开始", id=1)

flowchart.add_node(name="结束", id=2, icon="fa:check")

在上面的例子中,我们首先创建了一个MFlowchart对象。然后,我们使用add_node方法添加了两个节点,一个名为“开始”的节点和一个名为“结束”带有图标的节点。

注意:所有节点的id必须是唯一的,否则将导致错误。

set_node_color

该方法用于设置流程图节点的颜色。

参数:

id: 流程图节点的标识符,用于区分不同的节点。

color: 流程图节点的颜色,应为flowchart_color_enum枚举类的一个实例,

该枚举类定义了一系列可用的颜色。

此方法无返回值。

例:

flowchart = MFlowchart()

flowchart.add_node('Node1', '1')

flowchart.set_node_color('1', flowchart_color_enum.Blue)

在上述例子中,我们首先创建了一个MFlowchart类的实例,然后添加了一个名为'Node1'的节点,并设定其id为'1',

最后我们调用了set_node_color方法将此节点的颜色设定为蓝色。

注意: 如果传入的id不在实例的节点id映射中,将不会进行任何操作。

set_node_shape

设置流程图中节点的形状

这个函数用于设置流程图中指定id的节点的形状。

参数:

id: 要设置的节点的id。这个id应该是一个已经添加到流程图中的节点的id。

shape: 要设置的形状。这个应该是一个flowchart_shape_enum枚举的实例,表示要设置的形状。

返回值:

这个函数没有返回值。

使用示例:

\# 创建一个MFlowchart实例

flowchart = MFlowchart()

\# 添加一个节点

flowchart.add_node("node1", "id1")

\# 设置该节点的形状

flowchart.set_node_shape("id1", flowchart_shape_enum.ELLIPSE)

注意:

如果传入的id不存在,这个函数没有任何效果。所以在调用这个函数之前,要确保id已经存在于流程图中。

add_line

在流程图中添加一条线。该函数将两个节点通过一条线进行连接,线上可以添加消息,并且可以指定线的类型(实线或虚线)。

参数:

id1: str, 起始节点的id。

id2: str, 结束节点的id。

message: str, 线上的消息,可选,默认为None。

dot_line: bool, 是否为虚线,可选,默认为False。如果为True,则添加的线为虚线,否则为实线。

返回类型:

无返回值。

示例:

以下代码将节点'id1'和节点'id2'通过一条实线连接,线上的消息为'message':

flowchart = MFlowchart()

flowchart.add_node('name1', 'id1')

flowchart.add_node('name2', 'id2')

flowchart.add_line('id1', 'id2', 'message')

注意:

在调用此函数前,确保参数中的起始节点id和结束节点id已经被添加到流程图中。如果这两个id对应的节点不存在,将抛出异常。

FGantt

FGantt类是一个继承自markdown_item_base的类,主要用于在Markdown文件中创建和处理甘特图。甘特图是一种用于描述项目计划执行时间的条形图。在这个类中,用户可以设定甘特图的标题,日期格式,并为甘特图添加项目和相关的时间信息。

以下是使用FGantt类的一个简单示例:

\# 创建一个甘特图实例,标题为"My Project",日期格式为"YYYY-MM-DD"

gantt_chart = FGantt("My Project", "YYYY-MM-DD")

\# 添加一个名为"Task1"的项目

gantt_chart.add_item("Task1")

\# 为"Task1"项目添加一个名为"start"的时间,日期为"2022-01-01"

gantt_chart.add_item_data("Task1", "start", "2022-01-01")

\# 输出甘特图的Markdown格式

print(''.join(gantt_chart.flush_row_anchor()))

属性:

  • self.Items:存储甘特图中所有项目的字典。

  • self.Title:甘特图的标题。

  • self.date_format:甘特图中使用的日期格式。

方法:

  • \_\_init\_\_(self, gantt_title, date_format='YYYY-MM-DD'):初始化方法,设置甘特图的标题和日期格式。

  • flush_row_anchor(self, anchor_str) -> [str]:返回一个包含整个甘特图Markdown格式的字符串列表。

  • add_item(self, name):在甘特图中添加一个新的项目。

  • add_item_data(self, key, date_name, date):为指定的项目添加时间信息。

错误和Bug:

暂时没有发现错误和Bug。

__init__

FGantt类的初始化方法。

该方法用于创建FGantt类的新实例,创建一个甘特图制作工具。甘特图是一种常用的项目管理工具,用于描述项目的各个阶段如何随时间推移进行。该类帮助用户创建和管理甘特图的数据,并提供将其输出到Markdown文件的方法。

参数:

gantt_title: str, 甘特图的标题。

date_format: str, 可选参数,默认为'YYYY-MM-DD',表示日期格式。用于解析和格式化甘特图中的日期数据。

返回:

无返回值。

示例:

gantt = FGantt('My Project')

gantt.add_item('Task 1')

gantt.add_item_data('Task 1', 'Start', '2020-01-01')

gantt.add_item_data('Task 1', 'End', '2020-02-01')

out_str = gantt.flush_row_anchor()

print('\\n'.join(out_str))

以上代码将创建一个名为"My Project"的甘特图项目,添加一个名为"Task 1"的任务,并为该任务设定开始和结束日期。然后,该代码将生成一个Markdown格式的甘特图,并打印出来。

注意:

目前还没有发现错误或bug。

flush_row_anchor

此函数用于生成带有甘特图信息的markdown文本。

参数:

anchor_str (str): 锚点字符串,但在函数内部并未使用此参数。

返回:

list: 返回一个包含markdown格式的甘特图信息的字符串列表。

此函数会生成一个markdown格式的甘特图。首先添加mermaid和gantt的声明,然后添加日期格式和标题。接着,遍历Items字典中所有的项目,为每一个项目创建一个section,然后在section中添加所有的时间段。最后添加结束的声明。

add_item

这是一个添加项目的方法,用于在Gantt图中插入一个新的项目。

参数:

name (str): 一个字符串,用于表示项目的名称。该名称将被用作项目在Gantt图中的标识。

返回:

None

示例:

# 创建一个新的Gantt图对象

g = FGantt('项目进度')

# 增加一个名为 '任务1' 的项目

g.add_item('任务1')

此方法不包含任何异常处理,如果传入的name参数不是字符串类型,程序可能会崩溃。

add_item_data

这是FGantt类的一个方法,用于向指定的项目添加具体的日期数据。

参数:

key (str): 在Items字典中的键,也是项目的名称。

date_name (str): 日期的名称,比如可以是项目的开始日期或结束日期等。

date (tuple): 一个包含具体日期信息的元组,第一个元素是日期的开始时间,第二个元素是日期的结束时间。

返回:

使用例子:

gantt = FGantt('Project Schedule')

gantt.add_item('Task1')

gantt.add_item_data('Task1', 'start_date', ('2021-01-01', '2021-01-31'))

注意:

  1. 请保证key和date_name的唯一性,否则可能会覆盖已有的数据。

  2. date的格式应符合提供给FGantt类的date_format参数。

markdown_item_base

此类为 markdown_item_base,这是一个基础类,主要用于处理markdown文档的各种元素。

类的方法:

  • anchor_pointer_outer : 返回一个空字符串,可能作为锚点调用的外部接口,返回类型为字符串。

  • flush_begin : 无返回值的方法,具体功能未知,可能在子类中做具体实现,或者作为某种复位或初始化方法。

  • flush_row_anchor : 接收一个字符串作为参数,返回一个空列表,可能在子类中做具体实现,用于处理markdown文档的行锚点。

注意:这个类的全部方法都没有具体实现,可能作为接口或者抽象类供其他markdown相关的子类继承并实现具体方法。

示例:

**class markdown_item_subclass(markdown_item_base):**




**def anchor_pointer_outer(self) -> str:**

return "\#anchor"




**def flush_begin(self):**

print("begin flush")




**def flush_row_anchor(self, anchor_str) -> [str]:**

return [anchor_str]

markdown_item = markdown_item_subclass()

markdown_item.flush_begin()  \# 输出 "begin flush"

print(markdown_item.flush_row_anchor("my_anchor"))  \# 输出 ["my_anchor"]

此类尚未发现错误或者bug。

anchor_pointer_outer

此方法是markdown_item_base类的一部分。

anchor_pointer_outer方法用于生成一个空的字符串,此方法没有参数。在具体实现中可能会被子类重写,用于生成特定的字符串。

参数:

返回:

str:返回一个空字符串

注意:本方法在当前类中未实现具体功能,返回的都是空字符串。

示例:

markdown_item = markdown_item_base()

print(markdown_item.anchor_pointer_outer())  \# 输出: ""

flush_begin

flush_begin是一个在markdown_item_base类中定义的方法,此方法没有具体的实现内容,可能是一个需要子类进行重写的抽象方法。

本方法没有输入参数,也没有返回值。

例如,如果我们在子类中重写此方法,可能如下所示:

**class markdown_item_sub(markdown_item_base):**




**def flush_begin(self):**

print("开始执行操作")

在此示例中,flush_begin方法被重写,当调用此方法时,会打印出"开始执行操作"。

flush_row_anchor

此函数是markdown_item_base类的一个方法, 它的主要作用是处理传入的锚字符串,并以字符串数组形式返回。

参数:

anchor_str(str): 传入的锚字符串。

返回:

liststr: 返回处理后的字符串数组,如果没有需要处理的内容,则返回空数组。

MTitle

这是一个名为MTitle的类,它继承自markdown_item_base类。该类主要用于创建、存储和操作markdown格式的标题。

标题在markdown中是非常重要的元素,它们帮助组织和突出内容的结构。这个类主要是用来生成不同等级的markdown标题的。

类的初始化函数接收两个参数:

  • title_str: 这是标题的文本内容,是一个字符串。

  • title_level: 这是标题的等级,是一个整数。Markdown允许1-6级的标题。标题级别越高,标题文本越小。

类中定义了一个名为flush_row_anchor的方法,这个方法接收一个名为anchor_str的参数,返回一个字符串列表。

这个方法的主要作用是用于生成标题的markdown格式的字符串。它使用级别数量的井号\#来表示标题的级别,然后添加标题文本。

例如创建一个一级标题:

title_item = MTitle("Hello, world!", 1)

print(title_item.flush_row_anchor("anchor"))

输出:['# Hello, world! \n']

请注意,由于这是一个基于类的设计,所以在使用这个类时,需要先实例化后再调用其方法。

__init__

初始化 MTitle 类的实例。

这是一个表示 Markdown 标题的类,通过设置标题内容和标题等级来创建标题。标题等级决定了标题的重要程度。

参数:

title_str (str): 标题的文本内容。

title_level (int): 标题的等级。等级越高,标题在页面中的显示越大。

示例:

MTitle("这是一个标题", 1)

这将会创建一个等级为 1 的标题,显示为 "这是一个标题"。

flush_row_anchor

该函数主要用于刷新标题行锚点的功能。

参数:

anchor_str: 锚点字符串,虽然在函数体中未使用,但可能在其他地方有调用。

返回类型:

返回一个列表,其中包含一个字符串,该字符串是Markdown格式的标题。标题的层级由self.title_level决定,标题内容由self.title_str决定。

使用示例:

mtitle = MTitle('测试标题', 1)

mtitle.flush_row_anchor('anchor')

输出: ['# 测试标题 \\n']

注意:

_convert_char函数将self.title_str中的特殊字符转换为Markdown可以识别的格式。此函数体中并未定义,因此应在同一作用域下被定义和引用。

MTOC

这个类名为MTOC,继承自markdown_item_base类。主要功能是处理markdown文本中的行锚点。

具体方法如下:

flush_row_anchor:

该方法用于清理Markdown中的行锚点,将其替换为特定的字符串。

参数:

anchor_str (str): 需要被清理的行锚点字符串。

返回:

liststr: 返回一个包含处理后的字符串的列表,返回的字符串是"[toc] \n"。

用法示例:

mtoc = MTOC()

result = mtoc.flush_row_anchor('\#example')

print(result)  \# 输出:['[toc] \\n']

注意:目前未发现此类的错误或bug。

flush_row_anchor

这是一个类方法,其主要功能是生成markdown的目录链接。

参数:

self: 对象本身的引用

anchor_str: 一个字符串,用于指示markdown文档中的锚点。

返回:

返回一个列表,其中只包含一个字符串。这个字符串是markdown目录的链接,格式为"[toc] \n"。

使用示例:

mtoc = MTOC()

anchor_str = 'example'

print(mtoc.flush_row_anchor(anchor_str)) # 输出:['[toc] \n']

注意:

无论传入什么样的anchor_str,此函数总是返回['[toc] \n']。也就是说,anchor_str参数在当前实现中并未使用。

这可能是一个设计错误或者是尚未完成的功能。如果你希望根据不同的anchor_str生成不同的目录链接,请完善此函数的实现。

MSplit_line

这是一个名为 MSplit_line 的类,它继承自 markdown_item_base。该类的主要用途在于处理并控制Markdown语法中的分割线。

这个类仅包含一个名为 flush_row_anchor 的方法。

方法 flush_row_anchor 的作用是返回一个包含Markdown分割线语法的列表。

例如:

msplit_line = MSplit_line()

print(msplit_line.flush_row_anchor("anchor")) # 输出['*** \n']

方法参数:

anchor_str: 一个字符串类型的参数,但在这个方法中并未被实际使用。

返回类型:

返回一个包含字符串的列表,字符串为 "*** \n",代表Markdown中的分割线。

注意与bug:

本方法的参数并未实际参与到业务逻辑中,考虑到方法的实现,可能需要重新审查和优化这个函数的设计。

flush_row_anchor

flush_row_anchorMSplit_line类的一个方法,该类继承自markdown_item_base

该方法的目标是生成一个markdown的分隔线。

参数:

anchor_str (str): 该参数在当前的方法中并未被使用,但可能在其他继承自同一基类的类中有所使用。

返回:

示例:


ms = MSplit_line()

print(ms.flush_row_anchor(""))

\# 输出:['*** \\n']

注意:

当前方法未使用到输入参数anchor_str,可能存在参数冗余的问题。

MFooter

这是一个名为MFooter的类,它继承自markdown_item_base基类。这个类主要用于处理Markdown文档中的页脚信息。

这个类有一个名为flush_row_anchor的方法,它会接受一个锚字符串(anchor string)作为输入,并返回一个包含特定格式的字符串列表。这个字符串列表中的第一项是"*** \n",第二项是当前的时间戳,格式为'%Y-%m-%d %H:%M:%S'。

flush_row_anchor方法的参数和返回类型如下:

- 参数:

  • anchor_str:一个字符串,表示输入的锚字符串。

- 返回类型:

  • 这个方法返回一个字符串列表,列表中的第一项是"*** \n",第二项是当前的时间戳。

使用这个类的示例:

m_footer = MFooter()

m_footer.flush_row_anchor('example')

在上述示例中,我们首先创建了一个MFooter类的实例,并调用flush_row_anchor方法处理指定的锚字符串。

注意:这个类和方法都没有明显的错误或bug,但是在使用过程中,要确保传入flush_row_anchoranchor_str参数是字符串类型,否则可能会引发类型错误。

flush_row_anchor

flush_row_anchorMFooter类的一个方法,其主要目的是生成markdown的footer部分,该部分包含一个链接和当前的时间。

参数:

anchor_str (str): 需要生成的markdown链接的字符串。

返回:

liststr: 返回一个包含两个字符串元素的列表。第一个元素是markdown的分隔线("*** \n"),第二个元素是当前的时间字符串(采用'%Y-%m-%d %H:%M:%S'格式),每个元素后都跟着一个换行符("\n")。

使用示例:

footer = MFooter()

footer.flush_row_anchor("https://www.example.com")

注意:

当前版本中,该方法并未使用到参数anchor_str,这可能是一个待修复的bug。

MUnconvert_str

MUnconvert_str 是一个从 markdown_item_base 类继承而来的类,主要用于处理和管理未转换的字符串内容。

这个类主要有两个方法:

  • \_\_init\_\_(self, content):初始化方法,接收一个 content 参数,表示未转换的字符串内容。

  • flush_row_anchor(self, anchor_str) -> [str]:该方法接收一个 anchor_str 参数,并且返回一个由 content 组成的列表。

示例:

unconvert_str = MUnconvert_str("Hello World")

result = unconvert_str.flush_row_anchor("some anchor")

print(result)  \# 输出 ["Hello World"]

注意:本类不负责对字符串内容进行任何转换处理,只是单纯的返回原始字符串列表。

__init__

初始化 MUnconvert_str 类的实例。

MUnconvert_str 类是一个处理 Markdown 项目的基类,它主要是用来处理未转换的字符串。

参数:

content (str): 未被转换的字符串内容。

属性:

content (str): 存储传入的字符串内容。

示例:

munconvert_str = MUnconvert_str("原始字符串")

print(munconvert_str.content) # 输出: 原始字符串

flush_row_anchor

这个类MUnconvert_strmarkdown_item_base的一个子类,主要用于处理不需要转换的字符串。

类中的flush_row_anchor函数,用于将内容返回为字符串列表。

函数参数:

anchor_str (str): 作为参数传入,但在当前函数中并未使用。

返回类型:

返回一个包含单个元素的列表,元素为类初始化时传入的content内容。

示例:

unconvert_str = MUnconvert_str('test')

print(unconvert_str.flush_row_anchor('any_str'))

['test']

注意:虽然函数定义中包含anchor_str参数,但在函数实现中并未使用,可能是由于历史版本遗留或者未来版本预留的参数。

MCode

MCode类继承自markdown_item_base基类,用于处理markdown中的代码块部分。

该类的主要目标是将给定的代码内容和代码语言转化为markdown格式的代码块。

类的初始化方法需要两个参数,code_content和code_language,其中code_content参数用于指定代码内容,

code_language参数用于指定代码语言,默认为"python"。

flush_row_anchor方法用于将code_content和code_language转化为markdown格式的代码块,返回一个字符串列表。

例如:

mc = MCode("print('Hello, world!')", "python")

mc.flush_row_anchor("anchor")




**输出**

['\\n', '```python \\n', "print('Hello, world!')", '\\n', '```\\n']

注意,使用该类时,确保传入的code_content已经被正确转义,否则可能会导致markdown格式错误。

参数:

code_content (str): 代码内容

code_language (str, optional): 代码语言,默认为"python"

返回:

flush_row_anchor方法返回一个包含markdown格式代码块的字符串列表。

异常:

无法正确处理非字符串类型的code_content或code_language参数时,可能会引发类型错误。

__init__

初始化 MCode 类的实例。

MCode 类是用于处理 markdown 中的代码块,提供将代码内容和代码语言封装成一个代码块的功能。

Args:

code_content (str): 需要封装的代码块的内容。

code_language (str, optional): 代码块的编程语言,默认为 "python"。

示例:

mcode = MCode("print('Hello World!')", "python")

这将创建一个包含 "print('Hello World!')" 的 Python 代码块。

注意:

目前不支持对代码块内容的语法检查,如果传入的代码内容有语法错误,可能会影响生成的 markdown 文件的正常显示。

flush_row_anchor

这个类MCode继承自markdown_item_base,用于处理markdown中的代码部分。用户可以通过提供代码内容和代码语言来创建一个MCode对象。默认的代码语言为Python。

例如:

mcode = MCode("print('Hello, World!')", "python")

此外,此类还提供了一种方法flush_row_anchor,用于生成代码块的markdown表示形式。

## 方法:flush_row_anchor

这个方法是用来生成代码块的markdown表示形式。

### 参数:

  • anchor_str: 字符串,指定代码行的锚点。 但是在此函数中并未使用此参数,可能是留待未来版本使用,或者是一个错误。

### 返回值:

  • 返回一个字符串列表,包括markdown代码块的开始标记、代码内容、结束标记。

### 例子:

假设我们有一个MCode对象,代码内容为“print('Hello, World!')”,代码语言为Python:

mcode = MCode("print('Hello, World!')", "python")

调用flush_row_anchor方法:

mcode.flush_row_anchor("any_string")

将返回:

['\\n', '```python \\n', "print('Hello, World!')", '\\n', '```\\n']

这是一个markdown代码块的表示形式,可以直接插入到markdown文件中。

### 注意:

当前版本中,anchor_str参数在方法中并未实际使用,这可能是一个错误或者是留给未来版本的特性。

MTable

MTable 是一个 markdown 表格生成类,继承自 markdown_item_base。这个类的目的是为了方便地创建和管理 markdown 格式的表格。它有三个主要的方法:初始化 (\_\_init\_\_),添加行 (add_row) 和刷新行锚 (flush_row_anchor)。

初始化方法 (__init__) 接受一个列名列表 (cols) 作为输入,创建一个新的 MTable 实例。在这个实例中,列名被存储在 self.cols 中,而表格的数据被保存在 self.datas 中。

添加行方法 (add_row) 接受一行数据 (row) 作为输入,并将其添加到 self.datas 中。

刷新行锚方法 (flush_row_anchor) 接受一个锚字符串 (anchor_str) 作为输入,并返回一个包含更新后的 markdown 表格的字符串列表。这个方法首先生成标题行和分隔符行,然后遍历 self.datas 中的每一行,将每一行的数据转化为字符串并用 <br> 替换所有的换行符,最后添加到返回的字符串列表中。

例如:

table = MTable(['姓名', '年龄'])

table.add_row(['小明', '18'])

table.add_row(['小红', '19'])

lines = table.flush_row_anchor('demo')

print(''.join(lines))

将生成以下的 markdown 表格:


| 姓名 | 年龄 |

| ---- | ---- |

| 小明 | 18 |

| 小红 | 19 |

__init__

初始化MTable类。

MTable类是一个用于处理和存储markdown表格数据的类。它使用一组列名初始化,并存储添加的行数据。最后,可以通过flush_row_anchor方法将表格数据转换为markdown格式。

参数:

cols: Liststr,一组字符串,用来初始化表格的列名。

属性:

cols: 用于存放表格的列名。

datas: 用于存储添加的行数据。

使用示例:

示例 1:

mtable = MTable(['Name', 'Age'])

mtable.add_row(['Tom', '30'])

mtable.add_row(['Jerry', '35'])

print(''.join(mtable.flush_row_anchor('')))

输出:

| Name | Age |

| ---- | ---- |

| Tom | 30 |

| Jerry| 35 |

add_row

在当前Markdown表格实例中添加一行数据。

参数:

row: 一个具有与表格列数相同的元素数量的列表。列表元素将被转换为字符串并添加到Markdown表格中。

返回:

示例:

table = MTable(['Name', 'Age'])

table.add_row(['John', 25])

table.add_row(['Sara', 30])

此时, Markdown表格中将添加两行数据。

flush_row_anchor

这是一个类方法,用于将已添加的行数据转换为Markdown表格的格式。

参数:

anchor_str: 锚点字符串,该参数实际上在函数内部并未使用。

返回:

使用方式:

mtable = MTable(['col1','col2'])

mtable.add_row(['row1_col1','row1_col2'])

mtable.add_row(['row2_col1','row2_col2'])

lines = mtable.flush_row_anchor('anchor')

for line in lines:

print(line)

MImage

MImage类是一个用于处理Markdown中的图片的类,继承自markdown_item_base。

构造函数接收一个路径作为初始化参数,用来指示图片的位置。

类中实现了flush_row_anchor方法,该方法用于生成一个包含图片的Markdown格式的字符串。

注:这个类暂时没有检测到任何错误或者BUG。

示例:


m_image = MImage('path_to_your_image')

markdown_str = m_image.flush_row_anchor('anchor_str')

print(markdown_str)

函数介绍:

  • flush_row_anchor(self, anchor_str) -> str

  • anchor_str: 锚点字符串。

  • 返回值: 返回一个列表,列表包含一个Markdown格式的字符串,字符串中包含了图片的HTML标签。

__init__

初始化 MImage 类。

MImage 是一个用于处理 Markdown 中图片的类,它继承自 markdown_item_base 基类。对象实例化时,需要传递一个图片路径参数。

参数:

path (str): 图片的路径。

例子:

mi = MImage('/path/to/image.jpg')

mi.flush_row_anchor('anchor_str')

flush_row_anchor

这是一个MImage类的方法,用于生成markdown的图片链接。

Args:

anchor_str: 一个锚字符串,用于标识图片。但是在当前的方法实现中,这个参数并未被使用。

Returns:

返回一个字符串列表,列表中只有一个元素,这个元素是一个markdown的图片链接。图片链接的形式是:

a{当前时间戳}

其中,a和当前时间戳组成了图片的标识符,图片路径是类在初始化时传入的路径。

注意:

此方法会忽略anchor_str参数。

使用此方法需要确保传入的图片路径是正确的,否则生成的markdown链接可能无法正常工作。

Mstring

# Mstring类是一个基于markdown的字符串处理类,它继承自markdown_item_base基础类。

# 这个类主要用来存储和处理markdown格式的字符串,提供一些常用的方法进行字符串的添加和格式化。

该类的主要方法包括:

  • \_\_init\_\_:初始化方法,用于创建一个Mstring实例,该实例内容为空。

  • set:添加字符串的方法,可以把指定的字符串content添加到实例的内容中。

  • set_bold:添加加粗字符串的方法,可以把指定的字符串content以加粗的格式添加到实例的内容中。

  • flush_row_anchor:返回实例内容的方法,返回的是一个包含所有字符串的列表。

一个使用Mstring类的例子:

ms = Mstring()

ms.set('Hello, world!')

ms.set_bold('I am bold text.')

print(ms.flush_row_anchor())

在上面的例子中,我们创建了一个Mstring实例ms,并分别添加了普通字符串和加粗字符串。最后我们打印出这个实例的所有内容。

注意:

在使用setset_bold方法时,输入的字符串会被_convert_char方法处理,这个方法的具体实现并没有在这里给出,可能会对字符串的内容进行一些转换或者过滤。

__init__

初始化Mstring类的实例。

Mstring类是一个处理Markdown语法的类,用户可以使用此类添加和处理字符串内容,并将其转换为Markdown格式。此类包括添加普通内容(set方法)和加粗内容(set_bold方法)等功能。初始化Mstring类的实例时,将创建一个空列表用于存储字符串内容。

示例:

mstring = Mstring()

mstring.set('hello')

mstring.set_bold('world')

mstring.contents

['hello', 'world']

注意:

类的方法中使用了_convert_char函数来处理特殊Markdown字符,但此函数在给定的代码段中未定义。如果未在其他地方定义,这可能会导致运行错误。

set

这是一个set函数,用于设置Mstring类的内容。具体做法是将输入的内容转化为特定格式,然后添加到类的内容列表中。

参数列表:

content: 需要添加的内容,类型为str。

返回类型:

无。

注意:

本函数使用了内部函数_convert_char(),用于将输入的内容转化为特定格式。如果输入的内容不能被_convert_char()接受,可能会抛出异常。

set_bold

将传入的字符串转化为markdown语言的加粗格式并添加到内容列表。

参数:

content (str): 需要设置为加粗的字符串。

返回:

无返回值

示例:

mstring = Mstring()

mstring.set_bold("Hello World")

assert mstring.contents == ["Hello World"]

flush_row_anchor

此方法的设计目的是获取Mstring实例的所有内容并返回。内容保存在实例的contents变量中,这是一个列表,所以此方法直接返回这个列表。

参数:

anchor_str (str): 这个参数在方法中没有被使用,可能是为了与其他方法保持一致的参数列表而设计的。

返回类型:

Liststr: 返回Mstring实例的所有内容,都保存在一个字符串列表中。

注意:

此方法可能存在错误或者bug,因为在方法中并没有使用到参数anchor_str,这可能会导致一些不可预见的问题。

markdowndoc

这是一个处理markdown文档的类markdowndoc。类的主要功能是接收字符串或者markdown项目(markdown_item_base类型),然后将它们转换成markdown格式的文本,并写入到指定的文件中。

类的主要属性包括:

  • items:存储markdown项目的列表

  • stream_lines:存储待写入文件的字符串列表

  • stream_lines_type:存储待写入字符串的类型列表

  • path:指定的写入文件的路径

类的主要方法包括:

  • flush:依次处理items中的markdown项目和stream_lines中的字符串,将它们转换成markdown格式,并写入到指定的文件中。

  • write:接收字符串或markdown项目,并将它们存储到相应的列表中。

使用示例:

doc = markdowndoc('example.md') \#创建一个实例,指定写入文件名为'example.md'

doc.write('hello world!')  \# 添加字符串到doc中

item = markdown_item_base('example')  \# 创建一个markdown项目

doc.write(item)  \# 添加markdown项目到doc中

doc.flush()  \# 写入到文件

以上示例将创建一个名为'example.md'的markdown文件,文件内容为转换后的字符串'hello world!'和markdown项目。

注意事项:

  • write方法只接收字符串和markdown_item_base类型的参数,如果输入其他类型的参数,将不会被处理。

  • flush方法需要在所有内容添加完毕后调用,以保证所有内容能够正确转换并写入文件。

__init__

初始化markdowndoc类的实例。

参数:

path (str): 文件的路径,用于保存markdown文档的内容。

属性:

items (list): 用于存储markdown项的列表,每个项是markdown_item_base类的实例。

stream_lines (list): 用于存储markdown文档的实际行内容的列表。

stream_lines_type (list): 用于存储与stream_lines对应的行类型的列表。-1表示该行是普通文本,其他值表示该行是markdown项,值为该项在items列表中的索引。

path (str): 文件路径。

用法示例:

md = markdowndoc("example.md")

md.write("Hello, world!")

md.flush() # 将"Hello, world!"写入example.md文件

注意:

这个类并没有处理文件路径无效或者无法写入的情况,使用者需要确保提供的路径是有效的,且具有写入权限。

flush

这是一个用于处理Markdown文档对象(markdowndoc类)的flush方法。此方法的主要目的是将当前文档对象中的所有项目(items)和流行(stream_lines)转换成最终的markdown文档输出行(doc_output_lines)。最后,这些输出行将被写入到目标文件中。

在此过程中,对于每个项目,会先调用flush_begin方法进行初始化操作。然后,对于每个流行,根据其类型进行不同的处理。如果流行类型为-1,表示这是一个普通字符串,直接进行字符的转换并添加到输出行中。如果流行类型不是-1,表示这是一个markdown项目,会使用该项目的flush_row_anchor方法进行处理并将处理结果添加到输出行中。

此方法无需显式的参数输入,同时也不会返回任何结果。

注意,此方法在使用前,请确保提供正确的markdown项目和流行,否则可能会导致生成的markdown文档格式错误。

参数:

返回:

错误和异常:

可能会因为提供的markdown项目或流行格式不正确导致生成的markdown文档格式错误。

示例:

doc = markdowndoc(path="path_to_your_file")

doc.write("Your Markdown content")

doc.flush()

以上示例会将字符串"Your Markdown content"写入到名为path_to_your_file的文件中。

write

这是一个write函数,用于将内容写入Markdown文件。

根据传入的内容类型,该函数会将内容添加到文档流中,并更新流类型列表。如果内容是字符串,那么它将被直接添加到文档流中,同时在流类型列表中添加-1。如果内容是markdown_item_base类的实例,那么它将被添加到items列表中,同时在文档流和流类型列表中添加相应的指针和类型。

参数:

content:写入的内容,可以是字符串或者markdown_item_base类的实例。

返回:

返回传入的内容。

例如,你可以这样调用write函数:

doc = markdowndoc('path_to_your_file')

doc.write('This is a markdown document.')

item = markdown_item_base('title', 'This is a title.')

doc.write(item)

这样,字符串和markdown_item_base实例都会被写入到markdown文件中。

Project details


Download files

Download the file for your platform. If you're not sure which to choose, learn more about installing packages.

Source Distributions

No source distribution files available for this release.See tutorial on generating distribution archives.

Built Distribution

tketool-0.8.0-py3-none-any.whl (342.3 kB view hashes)

Uploaded Python 3

Supported by

AWS AWS Cloud computing and Security Sponsor Datadog Datadog Monitoring Fastly Fastly CDN Google Google Download Analytics Microsoft Microsoft PSF Sponsor Pingdom Pingdom Monitoring Sentry Sentry Error logging StatusPage StatusPage Status page