Transformer dataset

Question    Votes: 0    Answers: 1

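I am using a transformer to build a model that generates code. I am confused about how to derive the inputs and outputs from my data in order to train a decoder-only transformer. My dataset is: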

[
    {
      "code": "def bubble_sort(arr):\n    n = len(arr)\n    for i in range(n):\n        for j in range(0, n-i-1):\n            if arr[j] > arr[j+1]:\n                arr[j], arr[j+1] = arr[j+1], arr[j]\n    return arr",
      "function_name": "bubble_sort",
      "docstring": "Bubble Sort repeatedly steps through the list, compares adjacent elements, and swaps them if they are in the wrong order. This pass through the list is repeated until the list is sorted.",
      "language": "python",
      "tags": ["sorting", "algorithm", "bubble sort"],
      "dataset": "sorting_algorithms"
    },
    {
      "code": "def insertion_sort(arr):\n    for i in range(1, len(arr)):\n        key = arr[i]\n        j = i - 1\n        while j >= 0 and key < arr[j]:\n            arr[j + 1] = arr[j]\n            j -= 1\n        arr[j + 1] = key\n    return arr",
      "function_name": "insertion_sort",
      "docstring": "Insertion Sort builds the sorted array  one element at a time by repeatedly picking the next element and inserting it into its correct position in the already sorted part of the array.",
      "language": "python",
      "tags": ["sorting", "algorithm", "insertion sort"],
      "dataset": "sorting_algorithms"
    },
    {
      "code": "def selection_sort(arr):\n    for i in range(len(arr)):\n        min_idx = i\n        for j in range(i+1, len(arr)):\n            if arr[j] < arr[min_idx]:\n                min_idx = j\n        arr[i], arr[min_idx] = arr[min_idx], arr[i]\n    return arr",
      "function_name": "selection_sort",
      "docstring": "Selection Sort repeatedly selects the smallest element from the unsorted part of the list and swaps it with the first unsorted element, effectively growing the sorted portion of the list.",
      "language": "python",
      "tags": ["sorting", "algorithm", "selection sort"],
      "dataset": "sorting_algorithms"
    },
    {
      "code": "def merge_sort(arr):\n    if len(arr) > 1:\n        mid = len(arr) // 2\n        L = arr[:mid]\n        R = arr[mid:]\n\n        merge_sort(L)\n        merge_sort(R)\n\n        i = j = k = 0\n\n        while i < len(L) and j < len(R):\n            if L[i] < R[j]:\n                arr[k] = L[i]\n                i += 1\n            else:\n                arr[k] = R[j]\n                j += 1\n            k += 1\n\n        while i < len(L):\n            arr[k] = L[i]\n            i += 1\n            k += 1\n\n        while j < len(R):\n            arr[k] = R[j]\n            j += 1\n            k += 1\n    return arr",
      "function_name": "merge_sort",
      "docstring": "Merge Sort is a divide-and-conquer algorithm that recursively splits the list into halves, sorts each half, and then merges the sorted halves back together.",
      "language": "python",
      "tags": ["sorting", "algorithm", "merge sort"],
      "dataset": "sorting_algorithms"
    },
    {
      "code": "def quick_sort(arr):\n    if len(arr) <= 1:\n        return arr\n    else:\n        pivot = arr[len(arr) // 2]\n        left = [x for x in arr if x < pivot]\n        middle = [x for x in arr if x == pivot]\n        right = [x for x in arr if x > pivot]\n        return quick_sort(left) + middle + quick_sort(right)",
      "function_name": "quick_sort",
      "docstring": "Quick Sort is a divide-and-conquer algorithm. It selects a 'pivot' element from the list, partitions the other elements into those less than the pivot and those greater, and then recursively sorts the sub-arrays.",
      "language": "python",
      "tags": ["sorting", "algorithm", "quick sort"],
      "dataset": "sorting_algorithms"
    },
    {
      "code": "def heapify(arr, n, i):\n    largest = i\n    left = 2 * i + 1\n    right = 2 * i + 2\n\n    if left < n and arr[left] > arr[largest]:\n        largest = left\n\n    if right < n and arr[right] > arr[largest]:\n        largest = right\n\n    if largest != i:\n        arr[i], arr[largest] = arr[largest], arr[i]\n        heapify(arr, n, largest)\n\n\ndef heap_sort(arr):\n    n = len(arr)\n    for i in range(n // 2 - 1, -1, -1):\n        heapify(arr, n, i)\n\n    for i in range(n-1, 0, -1):\n        arr[i], arr[0] = arr[0], arr[i]\n        heapify(arr, i, 0)\n    return arr",
      "function_name": "heap_sort",
      "docstring": "Heap Sort builds a max heap from the list, then repeatedly extracts the largest element (the root of the heap) and rebuilds the heap, thereby sorting the list.",
      "language": "python",
      "tags": ["sorting", "algorithm", "heap sort"],
      "dataset": "sorting_algorithms"
    },
    {
      "code": "def shell_sort(arr):\n    n = len(arr)\n    gap = n // 2\n    while gap > 0:\n        for i in range(gap, n):\n            temp = arr[i]\n            j = i\n            while j >= gap and arr[j - gap] > temp:\n                arr[j] = arr[j - gap]\n                j -= gap\n            arr[j] = temp\n        gap //= 2\n    return arr",
      "function_name": "shell_sort",
      "docstring": "Shell Sort is an extension of Insertion Sort that allows the exchange of far apart elements. It improves on Insertion Sort by comparing elements distant apart, gradually reducing the gap between elements to be compared.",
      "language": "python",
      "tags": ["sorting", "algorithm", "shell sort"],
      "dataset": "sorting_algorithms"
    },
    {
      "code": "def counting_sort_for_radix(arr, exp):\n    n = len(arr)\n    output = [0] * n\n    count = [0] * 10\n\n    for i in range(n):\n        index = arr[i] // exp\n        count[index % 10] += 1\n\n    for i in range(1, 10):\n        count[i] += count[i - 1]\n\n    i = n - 1\n    while i >= 0:\n        index = arr[i] // exp\n        output[count[index % 10] - 1] = arr[i]\n        count[index % 10] -= 1\n        i -= 1\n\n    for i in range(len(arr)):\n        arr[i] = output[i]\n\n\ndef radix_sort(arr):\n    max_val = max(arr)\n    exp = 1\n    while max_val // exp > 0:\n        counting_sort_for_radix(arr, exp)\n        exp *= 10\n    return arr",
      "function_name": "radix_sort",
      "docstring": "Radix Sort processes the list digit by digit, starting from the least significant digit to the most significant digit, grouping numbers by each digit's value. It uses Counting Sort as a subroutine to sort based on individual digits.",
      "language": "python",
      "tags": ["sorting", "algorithm", "radix sort"],
      "dataset": "sorting_algorithms"
    },
    {
      "code": "def counting_sort(arr):\n    max_val = max(arr)\n    count = [0] * (max_val + 1)\n    output = [0] * len(arr)\n\n    for num in arr:\n        count[num] += 1\n\n    for i in range(1, len(count)):\n        count[i] += count[i - 1]\n\n    for num in reversed(arr):\n        output[count[num] - 1] = num\n        count[num] -= 1\n\n    return output",
      "function_name": "counting_sort",
      "docstring": "Counting Sort counts the occurrences of each unique element in the list, and then uses this count information to place each element in its correct position in the output array.",
      "language": "python",
      "tags": ["sorting", "algorithm", "counting sort"],
      "dataset": "sorting_algorithms"
    },
    {
      "code": "def bucket_sort(arr):\n    bucket = []\n    n = len(arr)\n    for i in range(n):\n        bucket.append([])\n\n    for j in arr:\n        index = int(n * j)\n        bucket[index].append(j)\n\n    for i in range(n):\n        bucket[i] = sorted(bucket[i])\n\n    k = 0\n    for i in range(n):\n        for j in range(len(bucket[i])):\n            arr[k] = bucket[i][j]\n            k += 1\n    return arr",
      "function_name": "bucket_sort",
      "docstring": "Bucket Sort divides the elements into several buckets. Each bucket is then sorted individually, either using a different sorting algorithm or by recursively applying Bucket Sort.",
      "language": "python",
      "tags": ["sorting", "algorithm", "bucket sort"],
      "dataset": "sorting_algorithms"
    }
  ]

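So this is a simple dataset of sorting algorithms. Could someone help me convert it into input/output form?

What should the input and the output be when passing this dataset to the model?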

nlp ml
1 Answer
0 votes

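Suppose your encoder block accepts a 3D tensor of shape (n, seq_len, features), where:

  • n is the number of observations
  • seq_len usually corresponds to the length of the sentence/text sequence
  • features usually represents the embedding size

Each dict() here is a separate observation (one entry along the n axis).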

I would suggest using the code field in the decoder block, since you want to generate text (code) as the output.
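A minimal sketch of that split, assuming the records are already loaded as a list of dicts with the keys shown in the question. The helper name split_record and the shortened sample record are made up for illustration:

```python
# One shortened record in the same shape as the question's dataset
# (the code body is truncated here for brevity).
records = [
    {
        "code": "def bubble_sort(arr):\n    return arr",
        "function_name": "bubble_sort",
        "docstring": "Bubble Sort repeatedly steps through the list.",
        "language": "python",
        "tags": ["sorting", "algorithm", "bubble sort"],
        "dataset": "sorting_algorithms",
    },
]

# Fields that describe the function; these go to the encoder side.
METADATA_FIELDS = ["function_name", "docstring", "language", "tags", "dataset"]


def split_record(record):
    """Return (encoder_fields, decoder_text) for one observation."""
    encoder_fields = {name: record[name] for name in METADATA_FIELDS}
    decoder_text = record["code"]  # the text the model should learn to generate
    return encoder_fields, decoder_text


# One (metadata, code) pair per dict, i.e. one pair per observation.
pairs = [split_record(r) for r in records]
```

Each element of pairs then holds the conditioning fields and the target code for one observation.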

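Then things get a bit trickier. You have several different fields: function_name, docstring, language, tags, dataset, and each one holds a different type of data (categorical, free text, and so on). Working out how to exploit those differences effectively may take some creativity. The simpler approaches seem to be (1) treating them all as a single string (one embedding shared by all fields), or (2) using 5 separate embeddings, one per field, while keeping the same dimension (e.g. 512 per subword unit/word). This gives you one embedded sequence of shape (long_seq_len, embed_size), with embed_size e.g. 512, in case (1), or 5 embedded sequences of shape (seq_len_i, embed_size) in case (2), where seq_len_i is different for each of the 5 fields. In case (2) you can then concatenate them along the seq_len_i axis to get (sum_of_seq_len, embed_size).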

That way, for the encoder block you get a single observation of shape (1, long_seq_len, embed_size) in case (1), or (1, sum_of_seq_len, embed_size) in case (2).
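The two cases can be sketched with plain Python lists to make the shapes concrete. Everything here is an illustrative assumption: the whitespace tokenizer stands in for a real subword tokenizer, ToyEmbedding is a hypothetical random lookup table rather than a trained layer, and embed_size is 8 instead of 512 to keep the output readable:

```python
import random

EMBED_SIZE = 8  # the answer suggests e.g. 512; kept small for illustration

record = {
    "function_name": "bubble_sort",
    "docstring": "Bubble Sort repeatedly swaps adjacent elements.",
    "language": "python",
    "tags": ["sorting", "algorithm", "bubble sort"],
    "dataset": "sorting_algorithms",
}
FIELDS = ["function_name", "docstring", "language", "tags", "dataset"]


def field_text(value):
    # Tags are a list; join them so every field becomes a plain string.
    return " ".join(value) if isinstance(value, list) else str(value)


def tokenize(text):
    # Hypothetical stand-in for a real subword tokenizer.
    return text.split()


class ToyEmbedding:
    """Lookup table mapping each token to a random vector of length EMBED_SIZE."""

    def __init__(self, seed):
        self.rng = random.Random(seed)
        self.table = {}

    def __call__(self, tokens):
        rows = []
        for tok in tokens:
            if tok not in self.table:
                self.table[tok] = [self.rng.uniform(-1, 1) for _ in range(EMBED_SIZE)]
            rows.append(self.table[tok])
        return rows  # shape: (len(tokens), EMBED_SIZE)


def shape(batch):
    return (len(batch), len(batch[0]), len(batch[0][0]))


# Case (1): all fields joined into one string, one shared embedding table.
long_text = " ".join(field_text(record[f]) for f in FIELDS)
shared = ToyEmbedding(seed=0)
case1 = [shared(tokenize(long_text))]  # (1, long_seq_len, EMBED_SIZE)

# Case (2): one embedding table per field, concatenated along the sequence axis.
per_field = {f: ToyEmbedding(seed=i) for i, f in enumerate(FIELDS)}
concatenated = []
for f in FIELDS:
    concatenated.extend(per_field[f](tokenize(field_text(record[f]))))
case2 = [concatenated]  # (1, sum_of_seq_len, EMBED_SIZE)

print(shape(case1), shape(case2))
```

With a whitespace tokenizer both cases end up with the same sequence length; the difference is that case (2) lets the same word receive a different vector depending on which field it came from.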

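For the decoder, use the code field and treat it in a similar way: create an embedding of the same size as before (again, e.g. 512 for simplicity) to get a data block of shape (1, seq_len, 512).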
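The answer does not spell out how the code sequence becomes a training pair. One common approach, assumed here, is to shift the token sequence by one position so that the decoder predicts each token from the ones before it. In this sketch the character-level tokenizer and the <bos>/<eos> markers are placeholders for whatever tokenizer you actually use:

```python
BOS, EOS = "<bos>", "<eos>"  # hypothetical special tokens


def make_decoder_pair(code):
    """Build (decoder_input, target) by shifting the token sequence by one."""
    tokens = [BOS] + list(code) + [EOS]  # character-level tokens, for illustration only
    decoder_input = tokens[:-1]  # everything except the final token
    target = tokens[1:]          # the same sequence shifted left by one
    return decoder_input, target


code = "def f(x):\n    return x"
decoder_input, target = make_decoder_pair(code)

# Map tokens to integer ids so they can index an embedding table.
vocab = {tok: idx for idx, tok in enumerate(sorted(set(decoder_input + target)))}
input_ids = [vocab[t] for t in decoder_input]
target_ids = [vocab[t] for t in target]
```

At position i the model sees decoder_input[:i + 1] and is trained to output target[i], so input_ids and target_ids always have the same length.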
