CN115249010A - Metric learning method, device, equipment and medium based on pseudo label - Google Patents
- Publication number: CN115249010A
- Application number: CN202210966909.7A
- Authority: CN (China)
- Prior art keywords: network, determining, student network, student, vectors
- Legal status: Pending (The legal status is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the status listed.)
Classifications
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06F—ELECTRIC DIGITAL DATA PROCESSING
- G06F40/00—Handling natural language data
- G06F40/10—Text processing
- G06F40/194—Calculation of difference between files
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06F—ELECTRIC DIGITAL DATA PROCESSING
- G06F40/00—Handling natural language data
- G06F40/30—Semantic analysis
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06Q—INFORMATION AND COMMUNICATION TECHNOLOGY [ICT] SPECIALLY ADAPTED FOR ADMINISTRATIVE, COMMERCIAL, FINANCIAL, MANAGERIAL OR SUPERVISORY PURPOSES; SYSTEMS OR METHODS SPECIALLY ADAPTED FOR ADMINISTRATIVE, COMMERCIAL, FINANCIAL, MANAGERIAL OR SUPERVISORY PURPOSES, NOT OTHERWISE PROVIDED FOR
- G06Q50/00—Information and communication technology [ICT] specially adapted for implementation of business processes of specific business sectors, e.g. utilities or tourism
- G06Q50/10—Services
- G06Q50/20—Education
- G06Q50/205—Education administration or guidance
-
- G—PHYSICS
- G09—EDUCATION; CRYPTOGRAPHY; DISPLAY; ADVERTISING; SEALS
- G09B—EDUCATIONAL OR DEMONSTRATION APPLIANCES; APPLIANCES FOR TEACHING, OR COMMUNICATING WITH, THE BLIND, DEAF OR MUTE; MODELS; PLANETARIA; GLOBES; MAPS; DIAGRAMS
- G09B19/00—Teaching not covered by other main groups of this subclass
- G09B19/0053—Computers, e.g. programming
-
- G—PHYSICS
- G09—EDUCATION; CRYPTOGRAPHY; DISPLAY; ADVERTISING; SEALS
- G09B—EDUCATIONAL OR DEMONSTRATION APPLIANCES; APPLIANCES FOR TEACHING, OR COMMUNICATING WITH, THE BLIND, DEAF OR MUTE; MODELS; PLANETARIA; GLOBES; MAPS; DIAGRAMS
- G09B5/00—Electrically-operated educational appliances
- G09B5/08—Electrically-operated educational appliances providing for individual presentation of information to a plurality of student stations
Abstract
The invention discloses a pseudo-label-based metric learning method, apparatus, device, and medium. The method includes: in an iteration cycle, inputting a target data set into both a teacher network and a student network, where the teacher network and the student network are used to identify class information of the target data set; determining pseudo labels containing contextualized semantic similarity from the first embedded vectors generated by the teacher network, and determining the Euclidean distances between the second embedded vectors generated by the student network; determining a loss function of the student network from the pseudo labels and the Euclidean distances; and adjusting the network parameters of the student network according to the loss function. This technical solution can improve the efficiency of identifying information in the target data set without relying on costly ground-truth labels or large-scale data sets.
Description
Technical Field
The invention relates to the field of computer technology, and in particular to a pseudo-label-based metric learning method, apparatus, device, and medium.
Background
Understanding the similarity between data items is central to many machine learning tasks, such as data retrieval, face verification, few-shot learning, and representation learning. In the prior art, similarity between data is generally identified using metric learning methods.
However, current metric learning relies heavily on training with large-scale data sets. These data sets require costly manual labeling, which limits the class diversity of the training data and in turn impairs the generalization ability of the learned model.
Current unsupervised metric learning methods synthesize class information for the training data by assigning a proxy class or discovering a pseudo class for each training sample, grouping the training data with K-means clustering, hierarchical clustering, or random walks. Although these methods can recover class information without ground-truth labels, the grouping procedures they rely on incur considerable computational cost.
Therefore, an urgent problem is how to improve the efficiency of identifying training data information without using costly ground-truth labels or large-scale data sets.
Disclosure of Invention
The invention provides a pseudo-label-based metric learning method, apparatus, device, and medium, which address the low efficiency and high cost of identifying training data information.
According to one aspect of the present invention, a pseudo-label-based metric learning method is provided, including:
inputting a target data set into a teacher network and a student network respectively in an iteration cycle, where the teacher network and the student network are used to identify class information of the target data set;
determining pseudo labels containing contextualized semantic similarity from the first embedded vectors generated by the teacher network, and determining the Euclidean distances between the second embedded vectors generated by the student network;
determining a loss function of the student network from the pseudo labels and the Euclidean distances; and
adjusting the network parameters of the student network according to the loss function.
According to another aspect of the present invention, a pseudo-label-based metric learning apparatus is provided, including:
a data input module, configured to input a target data set into a teacher network and a student network respectively in an iteration cycle, where the teacher network and the student network are used to identify class information of the target data set;
a first data processing module, configured to determine pseudo labels containing contextualized semantic similarity from the first embedded vectors generated by the teacher network, and to determine the Euclidean distances between the second embedded vectors generated by the student network;
a second data processing module, configured to determine a loss function of the student network from the pseudo labels and the Euclidean distances; and
a student network adjustment module, configured to adjust the network parameters of the student network according to the loss function.
According to another aspect of the present invention, an electronic device is provided, including:
at least one processor; and
a memory communicatively connected to the at least one processor, where
the memory stores a computer program executable by the at least one processor, and the computer program, when executed, enables the at least one processor to perform the pseudo-label-based metric learning method according to any embodiment of the present invention.
According to another aspect of the present invention, a computer-readable storage medium is provided, storing computer instructions that, when executed, cause a processor to implement the pseudo-label-based metric learning method according to any embodiment of the present invention.
In the technical solution of the embodiments of the present invention, a target data set is input into a teacher network and a student network respectively in an iteration cycle, the two networks being used to identify class information of the target data set; pseudo labels containing contextualized semantic similarity are determined from the first embedded vectors generated by the teacher network, and the Euclidean distances between the second embedded vectors generated by the student network are determined; a loss function of the student network is determined from the pseudo labels and the Euclidean distances; and the network parameters of the student network are adjusted according to the loss function. This addresses the low efficiency and high cost of identifying target data set information, and improves identification efficiency without costly ground-truth labels or large-scale data sets.
It should be understood that the statements in this section are not intended to identify key or critical features of the embodiments of the present invention, nor are they intended to limit the scope of the invention. Other features of the present invention will become apparent from the following description.
Drawings
To illustrate the technical solutions in the embodiments of the present invention more clearly, the drawings used in describing the embodiments are briefly introduced below. The drawings described below show only some embodiments of the present invention; a person skilled in the art may derive other drawings from them without creative effort.
Fig. 1 is a flowchart of a pseudo-label-based metric learning method according to the first embodiment of the present invention;
Fig. 2a is a flowchart of a pseudo-label-based metric learning method according to the second embodiment of the present invention;
Fig. 2b is a schematic flow diagram of the pseudo-label-based metric learning method according to the second embodiment of the present invention;
Fig. 3 is a schematic structural diagram of a pseudo-label-based metric learning apparatus according to the third embodiment of the present invention;
Fig. 4 is a schematic structural diagram of an electronic device implementing the pseudo-label-based metric learning method according to an embodiment of the present invention.
Detailed Description
To help those skilled in the art better understand the solutions of the present invention, the technical solutions in the embodiments of the present invention are described clearly and completely below with reference to the accompanying drawings. The described embodiments are only some, not all, of the embodiments of the present invention. All other embodiments obtained by a person skilled in the art from these embodiments without creative effort shall fall within the protection scope of the present invention.
It should be noted that the terms "first," "second," "object," and the like in the description and claims of the present invention and in the drawings described above are used for distinguishing between similar elements and not necessarily for describing a particular sequential or chronological order. It is to be understood that the data so used is interchangeable under appropriate circumstances such that the embodiments of the invention described herein are capable of operation in sequences other than those illustrated or described herein. Furthermore, the terms "comprises," "comprising," and "having," and any variations thereof, are intended to cover a non-exclusive inclusion, such that a process, method, system, article, or apparatus that comprises a list of steps or elements is not necessarily limited to those steps or elements expressly listed, but may include other steps or elements not expressly listed or inherent to such process, method, article, or apparatus.
Example one
Fig. 1 is a flowchart of a pseudo-label-based metric learning method according to the first embodiment of the present invention. This embodiment is applicable to quickly identifying class information of a target data set without costly ground-truth labels or large-scale data sets. The method may be performed by a pseudo-label-based metric learning apparatus, which may be implemented in hardware and/or software and configured in an electronic device, for example a computer device. As shown in Fig. 1, the method includes:
S110, inputting the target data set into a teacher network and a student network respectively in an iteration cycle; the teacher network and the student network are used to identify class information of the target data set.
The target data set may be a mini-batch drawn from a sample data set. The sample data set may be a text data set containing multiple text segments, in which case the target data set consists of text segments from the sample data set. For example, the target data set may be obtained from the sample data set by nearest neighbor search: q samples are randomly drawn from the sample data set as query samples, the k-1 nearest neighbors of each query sample are retrieved, and each query sample is combined with its nearest neighbors, yielding a target data set of size qk.
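The nearest-neighbor mini-batch construction described above can be sketched as follows. This is a minimal illustration, assuming the samples are already represented as fixed-length vectors (for example, produced by some existing encoder) and using brute-force Euclidean search; the function name `build_target_batch` and the `seed` parameter are hypothetical.

```python
import math
import random


def build_target_batch(embeddings, q, k, seed=0):
    """Sample q query indices and append each query's k-1 nearest neighbors.

    embeddings: list of equal-length float vectors for the sample data set.
    Returns a list of q*k indices (the target data set / mini-batch).
    """
    rng = random.Random(seed)
    n = len(embeddings)
    queries = rng.sample(range(n), q)  # q distinct query samples
    batch = []
    for i in queries:
        # Rank every other sample by Euclidean distance to the query.
        others = sorted(
            (j for j in range(n) if j != i),
            key=lambda j: math.dist(embeddings[i], embeddings[j]),
        )
        batch.append(i)                # the query itself
        batch.extend(others[: k - 1])  # plus its k-1 nearest neighbors
    return batch
```

Because different queries may share neighbors, the returned list can contain repeated indices; it always has exactly q*k entries.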
The class information may refer to information representing an association relationship between data in the target data set, and may be, for example, information representing similarity between data in the target data set.
S120, determining pseudo labels containing contextualized semantic similarity according to the first embedded vectors generated by the teacher network, and determining Euclidean distances among the second embedded vectors according to the second embedded vectors generated by the student network.
The first embedded vector is an embedding vector generated by the embedding layer of the teacher network. Contextualized semantic similarity represents how similar two data items are, taking their context into account. A pseudo label is a label that is not manually annotated; it is generated preliminarily by the teacher network and reflects the association between two data items. The student network can then be trained by using the pseudo labels as the supervisory signal for the relationships between data items.
The second embedded vector is an embedding vector generated by the embedding layer of the student network. The Euclidean distance reflects how far apart the embedding vectors that the same student embedding layer generates for different data items are.
S130, determining a loss function of the student network according to the pseudo label and the Euclidean distance.
Wherein the loss function may refer to a function for evaluating the performance of the student network.
Specifically, information about the lower-dimensional embedding space of each embedding layer in the student network, such as divergence information, can be calculated from the pseudo labels and the Euclidean distances, and the loss function of the student network is then determined from the divergence information and the Euclidean distances.
S140, adjusting the network parameters of the student network according to the loss function.
The network parameters of the student network are the parameters adjusted to optimize the student network, for example parameters updated according to gradients of the loss.
Specifically, updated network parameters of the student network can be obtained from its loss function (for example, by gradient descent on the loss), and the student network is then updated with the new parameters to obtain the updated student network.
In an optional implementation, the embodiment of the present invention may further include: determining network parameter values of the teacher network from the network parameters of the student network, and adjusting the network parameters of the teacher network according to these values so as to update the teacher network in the iteration cycle. The network parameters of the teacher network are the parameters adjusted to optimize the teacher network. For example, the network parameter values of the teacher network may be determined as a weighted combination of the original teacher network parameters and the updated student network parameters, specifically by the formula $\theta_t \leftarrow m\,\theta_t + (1-m)\,\theta_s$, where $\theta_t$ denotes the network parameters of the teacher network, $\theta_s$ denotes the updated network parameters of the student network, and $m \in [0,1]$ is a coefficient controlling the update rate of the parameters. Because the teacher network is updated from the student network in every iteration, the pseudo labels computed from the teacher's embedding layer are progressively refined, which provides an effective basis for improving the efficiency of identifying class information in the target data set.
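The teacher update $\theta_t \leftarrow m\,\theta_t + (1-m)\,\theta_s$ is an exponential moving average of the student parameters. A minimal sketch, assuming parameters are stored as a flat mapping from names to scalar values (a real network would apply the same rule tensor by tensor); the name `ema_update` and the default `m=0.999` are illustrative choices, not values taken from this document:

```python
def ema_update(teacher, student, m=0.999):
    """Return the new teacher parameters theta_t = m * theta_t + (1 - m) * theta_s.

    teacher, student: dicts mapping parameter names to float values
    (a simplification; real networks hold tensors per name).
    """
    if not 0.0 <= m <= 1.0:
        raise ValueError("m must lie in [0, 1]")
    # Weighted combination of the old teacher value and the updated student value.
    return {name: m * teacher[name] + (1.0 - m) * student[name] for name in teacher}
```

With m close to 1 the teacher changes slowly, which keeps the pseudo labels it produces stable from one iteration to the next; m = 1 freezes the teacher entirely.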
In another optional implementation, the embodiment of the present invention may further include: if the number of iteration cycles reaches a preset iteration cycle value, using the current student network as the target label identification network; and identifying class information of the target data set with the target label identification network to generate the class information corresponding to the target data set. The preset iteration cycle value is a preset threshold on the number of iterations. It may be a fixed value, or a variable value adjusted according to how well the student network fits: if training underfits, the preset value may be increased; if it overfits, the preset value may be decreased. The embodiments of the present invention do not limit this. The target label identification network is the student network obtained when the number of iteration cycles reaches the preset iteration cycle value. In the embodiments of the present invention, the data item having the highest contextualized semantic similarity to a given data item of the target data set may be used as the class information of that item.
Specifically, when the number of iteration cycles reaches the preset iteration cycle value, the current student network is used as the target label identification network, so that class information of the target data set can be identified with it and the corresponding class information generated quickly and accurately.
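As a simplified illustration of generating class information with the trained student network, the sketch below assigns each data item the index of its most similar other item. It assumes plain Euclidean distance between student embeddings as a stand-in for contextualized semantic similarity (a smaller distance corresponds to a higher Gaussian-kernel similarity); the function name `most_similar` is hypothetical.

```python
import math


def most_similar(z):
    """For each embedding z[i], return the index of the closest other embedding."""
    n = len(z)
    result = []
    for i in range(n):
        # Smallest Euclidean distance == highest Gaussian-kernel similarity.
        best = min((j for j in range(n) if j != i), key=lambda j: math.dist(z[i], z[j]))
        result.append(best)
    return result
```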
Example two
Fig. 2a is a flowchart of a pseudo-label-based metric learning method according to the second embodiment of the present invention, which refines the above embodiment. In this embodiment, the operation of determining pseudo labels containing contextualized semantic similarity from the first embedded vectors generated by the teacher network is refined, and may specifically include: calculating the pairing similarity between the first embedded vectors according to the formula $w^{P}_{ij} = \exp\left(-\lVert \mathbf{z}^{t}_{i} - \mathbf{z}^{t}_{j} \rVert_2^{2} / \sigma\right)$, where $\mathbf{z}^{t}_{i}$ denotes the first embedded vector generated by the teacher network for the i-th data in the target data set, $\mathbf{z}^{t}_{j}$ denotes the first embedded vector generated by the teacher network for the j-th data in the target data set, and $\sigma$ denotes the Gaussian kernel bandwidth; calculating the context similarity between the first embedded vectors according to their degree of overlap; and calculating the pseudo labels containing contextualized semantic similarity from the pairing similarity and the context similarity. As shown in Fig. 2a, the method includes:
S210, inputting the target data set into the teacher network and the student network respectively in an iteration cycle; the teacher network and the student network are used to identify class information of the target data set.
The student network includes an auxiliary layer and a calculation layer, which are two parallel embedding layers: the calculation layer $f_s$ and the auxiliary layer $g_s$. The auxiliary layer $g_s$ of the student network is dedicated to continuously updating the embedding layer $g_t$ of the teacher network and therefore has the same output dimension as $g_t$. In addition, the teacher network and the student network share the same backbone encoder.
Specifically, when the target data set is input into the teacher network, the teacher's embedding layer generates one first embedded vector for each data item; when the target data set is input into the student network, the auxiliary layer and the calculation layer each generate one embedding vector for each data item.
S220, calculating the pairing similarity between the first embedded vectors according to the formula:
$$w^{P}_{ij} = \exp\left(-\frac{\lVert \mathbf{z}^{t}_{i} - \mathbf{z}^{t}_{j} \rVert_2^{2}}{\sigma}\right)$$
where $\mathbf{z}^{t}_{i}$ denotes the first embedded vector generated by the teacher network for the i-th data in the target data set, $\mathbf{z}^{t}_{j}$ denotes the first embedded vector generated by the teacher network for the j-th data in the target data set, and $\sigma$ denotes the Gaussian kernel bandwidth. The pairing similarity reflects the semantic similarity between two data items.
It should be noted that, in the embodiments of the present invention, the i-th data may be data in the currently input target data set, and the j-th data may be data in the previously input target data set; after generating embedding vectors, each embedding layer may store them for use in subsequent calculations.
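A direct computation of the pairing similarity matrix, assuming the Gaussian kernel takes the form $\exp(-\lVert \mathbf{z}_i - \mathbf{z}_j \rVert_2^2/\sigma)$ over the teacher's first embedded vectors; the function name `pairing_similarity` is hypothetical:

```python
import math


def pairing_similarity(z, sigma):
    """W^P[i][j] = exp(-||z_i - z_j||^2 / sigma) for teacher embeddings z."""
    n = len(z)
    return [
        [math.exp(-math.dist(z[i], z[j]) ** 2 / sigma) for j in range(n)]
        for i in range(n)
    ]
```

Identical vectors receive similarity 1, and the similarity decays toward 0 as the squared distance grows relative to the bandwidth.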
S230, calculating the context similarity between the first embedded vectors according to the overlapping degree of the first embedded vectors.
The context similarity represents the degree of overlap between the neighborhoods of two data items. Illustratively, the context similarity between the first embedded vectors may be calculated by the formula
$$w^{C}_{ij} = \begin{cases} \dfrac{\lvert R_k(i) \cap R_k(j) \rvert}{\lvert R_k(i) \rvert}, & x_j \in R_k(i) \\ 0, & \text{otherwise} \end{cases}$$
where $R_k(i)$ and $R_k(j)$ denote the k-reciprocal nearest neighbor sets of data $x_i$ and $x_j$ in the target data set; for example, $R_k(i) = \{x_j \mid x_j \in N_k(i) \wedge x_i \in N_k(j)\}$, where $N_k(i)$ is the k-nearest-neighbor set of $x_i$ and $N_k(j)$ is the k-nearest-neighbor set of $x_j$.
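The context similarity can be sketched with k-reciprocal nearest neighbors. This sketch assumes that each item counts as its own nearest neighbor (so $x_i \in N_k(i)$), a common convention for k-reciprocal sets, and uses brute-force Euclidean search; the helper names `knn` and `context_similarity` are hypothetical.

```python
import math


def knn(z, k):
    """N_k(i): the k nearest items to each item, counting the item itself."""
    n = len(z)
    return [
        set(sorted(range(n), key=lambda j: math.dist(z[i], z[j]))[:k])
        for i in range(n)
    ]


def context_similarity(z, k):
    """W^C[i][j] = |R_k(i) & R_k(j)| / |R_k(i)| if x_j in R_k(i), else 0."""
    n = len(z)
    nk = knn(z, k)
    # k-reciprocal neighbors: j is in R_k(i) iff j in N_k(i) and i in N_k(j).
    r = [{j for j in nk[i] if i in nk[j]} for i in range(n)]
    w = [[0.0] * n for _ in range(n)]
    for i in range(n):
        for j in r[i]:
            # Fraction of i's reciprocal neighborhood shared with j's.
            w[i][j] = len(r[i] & r[j]) / len(r[i])
    return w
```

Note that the result is generally asymmetric, since the overlap is divided by the size of $R_k(i)$ only.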
S240, calculating a pseudo label containing contextualized semantic similarity according to the pairing similarity and the context similarity.
Illustratively, the pseudo label containing contextualized semantic similarity may be calculated by combining the pairing similarity and the context similarity.
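One simple way to combine the two similarities, offered purely as an assumption, is to symmetrize the context similarity (which is asymmetric because it divides by $\lvert R_k(i) \rvert$) and then average it with the pairing similarity. The function name `pseudo_labels` and the equal weighting are hypothetical:

```python
def pseudo_labels(w_pair, w_ctx):
    """Hypothetical combination: W_ij = (W^P_ij + (W^C_ij + W^C_ji) / 2) / 2."""
    n = len(w_pair)
    # Symmetrize the context term so that W_ij == W_ji.
    sym_ctx = [[0.5 * (w_ctx[i][j] + w_ctx[j][i]) for j in range(n)] for i in range(n)]
    # Average the two terms; with inputs in [0, 1], outputs also stay in [0, 1].
    return [[0.5 * (w_pair[i][j] + sym_ctx[i][j]) for j in range(n)] for i in range(n)]
```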
S250, determining a first Euclidean distance between the auxiliary layer vectors generated by the auxiliary layer of the student network.
In the embodiments of the present invention, the auxiliary layer vector generated for the i-th data may be denoted $\mathbf{z}^{g}_{i}$.
In an optional embodiment, determining the first Euclidean distance between the auxiliary layer vectors generated by the auxiliary layer of the student network includes: determining the first Euclidean distance according to the formula
$$d^{g}_{ij} = \frac{\lVert \mathbf{z}^{g}_{i} - \mathbf{z}^{g}_{j} \rVert_2}{\frac{1}{N}\sum_{l=1}^{N} \lVert \mathbf{z}^{g}_{i} - \mathbf{z}^{g}_{l} \rVert_2}$$
where $\mathbf{z}^{g}_{i}$ denotes the auxiliary layer vector generated by the student network for the i-th data in the target data set, $\mathbf{z}^{g}_{j}$ denotes the auxiliary layer vector generated by the student network for the j-th data in the target data set, and $N$ denotes the number of data items in the target data set.
S260, determining a second Euclidean distance between the calculation layer vectors generated by the calculation layer of the student network.
In the embodiments of the present invention, the calculation layer vector generated for the i-th data may be denoted $\mathbf{z}^{f}_{i}$.
In an optional embodiment, determining the second Euclidean distance between the calculation layer vectors generated by the calculation layer of the student network includes: determining the second Euclidean distance according to the formula
$$d^{f}_{ij} = \frac{\lVert \mathbf{z}^{f}_{i} - \mathbf{z}^{f}_{j} \rVert_2}{\frac{1}{N}\sum_{l=1}^{N} \lVert \mathbf{z}^{f}_{i} - \mathbf{z}^{f}_{l} \rVert_2}$$
where $\mathbf{z}^{f}_{i}$ denotes the calculation layer vector generated by the student network for the i-th data in the target data set, $\mathbf{z}^{f}_{j}$ denotes the calculation layer vector generated by the student network for the j-th data in the target data set, and $N$ denotes the number of data items in the target data set.
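Both the first and the second Euclidean distances can be computed by the same routine, applied once to the auxiliary layer vectors and once to the calculation layer vectors. This sketch assumes each raw distance is normalized by the item's mean distance to all N items of the target data set, which makes the result invariant to the overall scale of the embeddings; the function name `normalized_distances` is hypothetical.

```python
import math


def normalized_distances(v):
    """d_ij = ||v_i - v_j|| / ((1/N) * sum_l ||v_i - v_l||) for one embedding layer."""
    n = len(v)
    raw = [[math.dist(v[i], v[j]) for j in range(n)] for i in range(n)]
    out = []
    for i in range(n):
        mean_i = sum(raw[i]) / n  # row-wise mean distance (includes d_ii = 0)
        # Guard against a degenerate batch where all vectors coincide.
        out.append([d / mean_i if mean_i > 0 else 0.0 for d in raw[i]])
    return out
```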
S270, determining a first relaxed contrastive loss for the auxiliary layer from the pseudo labels and the first Euclidean distance, and determining a second relaxed contrastive loss for the calculation layer from the pseudo labels and the second Euclidean distance.
The first relaxed contrastive loss is a distance-based metric loss for the auxiliary layer of the student network. Illustratively, it may be determined by a formula that includes a flag term for mitigating overfitting; the flag is 0 when the training result is not overfitting.
The second relaxed contrastive loss is the corresponding distance-based metric loss for the calculation layer of the student network and may be determined by an analogous formula.
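A sketch of a relaxed contrastive loss, assuming the common form in which each pair is pulled together in proportion to its pseudo label $w_{ij}$ and pushed apart in proportion to $1 - w_{ij}$ while closer than a margin $\delta$. The margin value, the averaging over anchors, and the function name are assumptions, and the overfitting flag mentioned above is not modeled:

```python
def relaxed_contrastive_loss(w, d, delta=1.0):
    """Assumed form: (1/N) * sum_i sum_{j != i} [w_ij d_ij^2 + (1 - w_ij) max(0, delta - d_ij)^2].

    w: N x N pseudo-label matrix with entries in [0, 1].
    d: N x N distance matrix from one student embedding layer.
    """
    n = len(w)
    total = 0.0
    for i in range(n):
        for j in range(n):
            if i == j:
                continue  # skip self-pairs
            attract = w[i][j] * d[i][j] ** 2                           # pull similar pairs together
            repel = (1.0 - w[i][j]) * max(0.0, delta - d[i][j]) ** 2   # push dissimilar pairs past the margin
            total += attract + repel
    return total / n
```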
S280, determining the target divergence of the student network according to the first Euclidean distance and the second Euclidean distance.
The target divergence characterizes information in the lower-dimensional embedding space of the student network. Illustratively, it may be computed as a Kullback-Leibler divergence by a formula in which $\Psi(\cdot)$ denotes the Softmax operator.
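The target divergence can be sketched as a row-wise Kullback-Leibler divergence between Softmax distributions over negated distances, $\Psi(-d)$. This assumes that the auxiliary layer's distribution (first Euclidean distance) serves as the target and the calculation layer's (second Euclidean distance) as the approximation, and that self-pairs are excluded; both choices, and the function names, are hypothetical.

```python
import math


def softmax(xs):
    m = max(xs)  # subtract the max for numerical stability
    exps = [math.exp(x - m) for x in xs]
    s = sum(exps)
    return [e / s for e in exps]


def target_divergence(d_first, d_second):
    """Sum over rows i of KL(Psi(-d_first[i]) || Psi(-d_second[i])), self-pairs excluded."""
    n = len(d_first)
    total = 0.0
    for i in range(n):
        p = softmax([-d_first[i][j] for j in range(n) if j != i])   # target distribution
        q = softmax([-d_second[i][j] for j in range(n) if j != i])  # approximating distribution
        total += sum(pi * math.log(pi / qi) for pi, qi in zip(p, q))
    return total
```

Smaller distances map to larger probabilities, so the divergence is zero exactly when both layers rank and weight each item's neighbors the same way.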
S290, determining the loss function of the student network from the first relaxed contrastive loss, the second relaxed contrastive loss, and the target divergence.
Specifically, the loss function of the student network may be determined by a formula that combines these three terms.
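One plausible way to combine the three terms of S290, offered only as an assumption, is a weighted sum in which $W$ is the pseudo-label matrix, $D^{g}$ and $D^{f}$ are the first and second Euclidean distance matrices, and $\lambda$ is a hypothetical weight ($\lambda = 1$ gives a plain sum):

```latex
\mathcal{L}_{\mathrm{student}}
  = \mathcal{L}_{\mathrm{RC}}\!\left(W, D^{g}\right)
  + \mathcal{L}_{\mathrm{RC}}\!\left(W, D^{f}\right)
  + \lambda\, \mathcal{L}_{\mathrm{KL}}\!\left(D^{g}, D^{f}\right),
  \qquad \lambda \ge 0
```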
S2100, adjusting the network parameters of the student network according to the loss function.
S2110, determining network parameter values of the teacher network from the network parameters of the student network.
S2120, adjusting the network parameters of the teacher network according to the network parameter values, so as to update the teacher network in the iteration cycle.
S2130, if the number of iteration cycles reaches the preset iteration cycle value, using the current student network as the target label identification network.
S2140, identifying class information of the target data set with the target label identification network to generate the class information corresponding to the target data set.
According to the technical solution of this embodiment of the invention, in one iteration cycle the target data set is input to the teacher network and the student network respectively. The pairing similarity between the first embedded vectors is calculated, the context similarity between the first embedded vectors is calculated according to their degree of overlap, and a pseudo label containing contextualized semantic similarity is calculated from the pairing similarity and the context similarity. A first Euclidean distance between the auxiliary layer vectors is determined from the vectors generated by the auxiliary layer of the student network, and a second Euclidean distance between the calculation layer vectors is determined from the vectors generated by the calculation layer. A first relaxed contrast loss corresponding to the auxiliary layer is determined from the pseudo label and the first Euclidean distance, a second relaxed contrast loss corresponding to the calculation layer is determined from the pseudo label and the second Euclidean distance, the target divergence of the student network is determined from the first and second Euclidean distances, and the loss function of the student network is determined from the two relaxed contrast losses and the target divergence. Finally, the network parameters of the student network are adjusted according to the loss function, and the network parameter values of the teacher network are determined from the network parameters of the student network and used to update the teacher network in the iteration cycle. When the number of iteration cycles satisfies the preset iteration cycle value, the current student network is used as the target label identification network, which identifies the class information of the target data set and generates the corresponding class information. This addresses the low efficiency and high cost of identifying the class information of the target data set, and can improve identification efficiency without using costly ground-truth labels or large-scale data sets.
Fig. 2b is a schematic flowchart of a pseudo label-based metric learning method according to the second embodiment of the present invention. Specifically, in an iteration cycle, the teacher network generates first embedded vectors from the target data set; the pairing similarity between the first embedded vectors is calculated, the context similarity between them is calculated according to their degree of overlap, and a pseudo label containing contextualized semantic similarity is calculated from the pairing similarity and the context similarity. Meanwhile, the student network generates auxiliary layer vectors and calculation layer vectors from the target data set; a first Euclidean distance between the auxiliary layer vectors and a second Euclidean distance between the calculation layer vectors are then determined. Next, a first relaxed contrast loss corresponding to the auxiliary layer is determined from the pseudo label and the first Euclidean distance, a second relaxed contrast loss corresponding to the calculation layer is determined from the pseudo label and the second Euclidean distance, and the target divergence of the student network is determined from the first and second Euclidean distances. The loss function of the student network is then determined from the two relaxed contrast losses and the target divergence, and the network parameters of the student network are adjusted according to it. Finally, the network parameter values of the teacher network are determined from the network parameters of the student network to update the teacher network in the iteration cycle; this repeats until the number of iteration cycles satisfies the preset iteration cycle value, at which point the current student network is obtained.
EXAMPLE III
Fig. 3 is a schematic structural diagram of a pseudo label-based metric learning apparatus according to the third embodiment of the present invention. As shown in Fig. 3, the apparatus includes a data input module 310, a first data processing module 320, a second data processing module 330 and a student network adjusting module 340;
the data input module 310 is configured to input the target data set to the teacher network and the student network respectively in an iteration cycle, the teacher network and the student network being used to identify class information of the target data set;
the first data processing module 320 is configured to determine pseudo labels containing contextualized semantic similarity according to the first embedded vectors generated by the teacher network, and to determine Euclidean distances between the second embedded vectors generated by the student network;
the second data processing module 330 is configured to determine a loss function of the student network according to the pseudo labels and the Euclidean distances;
and the student network adjusting module 340 is configured to adjust a network parameter of the student network according to the loss function.
According to the technical solution of this embodiment of the invention, the target data set is input to the teacher network and the student network respectively in one iteration cycle, the two networks being used to identify class information of the target data set; pseudo labels containing contextualized semantic similarity are determined from the first embedded vectors generated by the teacher network, and Euclidean distances between the second embedded vectors are determined from the second embedded vectors generated by the student network; a loss function of the student network is determined from the pseudo labels and the Euclidean distances; and the network parameters of the student network are adjusted according to the loss function. This addresses the low efficiency and high cost of identifying class information of the target data set, and can improve identification efficiency without using costly ground-truth labels or large-scale data sets.
Optionally, the first data processing module 320 may include a pseudo label generating unit, configured to: calculate the pairing similarity between the first embedded vectors according to a formula whose terms are the first embedded vector generated by the teacher network from the i-th data in the target data set, the first embedded vector generated by the teacher network from the j-th data in the target data set, and σ, the Gaussian kernel bandwidth; calculate the context similarity between the first embedded vectors according to their degree of overlap; and calculate a pseudo label containing contextualized semantic similarity according to the pairing similarity and the context similarity.
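The pairing similarity formula is not reproduced here. One common way to write a Gaussian kernel with bandwidth σ is exp(-||z_i - z_j||² / σ), and the sketch below assumes that form. The text also leaves the context similarity and the combination rule unspecified, so the sketch substitutes a simple stand-in: the Jaccard overlap of each sample's k-nearest-neighbor set, averaged with the pairing similarity. `k`, `sigma` and all function names are illustrative assumptions.

```python
import numpy as np

def pairing_similarity(z, sigma=1.0):
    """Assumed Gaussian-kernel similarity exp(-||z_i - z_j||^2 / sigma) between teacher embeddings."""
    z = np.asarray(z, dtype=float)
    sq = np.sum((z[:, None, :] - z[None, :, :]) ** 2, axis=-1)  # squared Euclidean distances
    return np.exp(-sq / sigma)

def context_similarity(z, k=2):
    """Stand-in context similarity: Jaccard overlap of k-nearest-neighbor sets (assumption)."""
    z = np.asarray(z, dtype=float)
    d = np.linalg.norm(z[:, None, :] - z[None, :, :], axis=-1)
    n = len(z)
    # k + 1 neighbors so that each set also contains the sample itself (distance 0)
    knn = [set(np.argsort(d[i])[: k + 1].tolist()) for i in range(n)]
    c = np.zeros((n, n))
    for i in range(n):
        for j in range(n):
            c[i, j] = len(knn[i] & knn[j]) / len(knn[i] | knn[j])
    return c

def pseudo_labels(z, sigma=1.0, k=2):
    """Assumed combination: average of pairing and context similarity, giving values in [0, 1]."""
    return 0.5 * (pairing_similarity(z, sigma) + context_similarity(z, k))
```

For two well-separated clusters, pairs inside a cluster receive high pseudo labels from both terms, while cross-cluster pairs receive values close to zero.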
Optionally, the student network includes an auxiliary layer and a computing layer;
the first data processing module 320 may include a Euclidean distance determining unit, configured to determine a first Euclidean distance between the auxiliary layer vectors generated by the auxiliary layer of the student network, and to determine a second Euclidean distance between the calculation layer vectors generated by the calculation layer of the student network.
Optionally, the Euclidean distance determining unit may include a first Euclidean distance determining subunit and a second Euclidean distance determining subunit;
the first Euclidean distance determining subunit may be configured to determine, according to a formula, the first Euclidean distance between the auxiliary layer vectors, where the formula involves the auxiliary layer vector generated by the student network from the i-th data in the target data set, the auxiliary layer vector generated from the j-th data in the target data set, and N, the number of data in the target data set;
the second Euclidean distance determining subunit may be configured to determine, according to a formula, the second Euclidean distance between the calculation layer vectors, where the formula involves the calculation layer vector generated by the student network from the i-th data in the target data set, the calculation layer vector generated from the j-th data in the target data set, and N, the number of data in the target data set.
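The distance formulas are missing, but the presence of N suggests a normalization over the data set. As an assumption, the sketch below divides each raw Euclidean distance ||v_i - v_j|| by the mean distance from v_i to all N vectors, which makes the distances invariant to the overall scale of the embedding. The same function would be applied to the auxiliary layer vectors and to the calculation layer vectors; the function name and `eps` guard are illustrative.

```python
import numpy as np

def normalized_distances(v, eps=1e-12):
    """N x N matrix ||v_i - v_j|| / ((1/N) * sum_k ||v_i - v_k||) (assumed normalization)."""
    v = np.asarray(v, dtype=float)
    n = v.shape[0]
    raw = np.linalg.norm(v[:, None, :] - v[None, :, :], axis=-1)  # pairwise Euclidean distances
    row_mean = raw.sum(axis=1, keepdims=True) / n                # mean distance from v_i over all N vectors
    return raw / np.maximum(row_mean, eps)                        # eps guards against all-identical vectors
```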
Optionally, the second data processing module 330 may be specifically configured to: determine a first relaxed contrast loss corresponding to the auxiliary layer according to the pseudo label and the first Euclidean distance; determine a second relaxed contrast loss corresponding to the calculation layer according to the pseudo label and the second Euclidean distance; determine the target divergence of the student network according to the first Euclidean distance and the second Euclidean distance; and determine a loss function of the student network according to the first relaxed contrast loss, the second relaxed contrast loss and the target divergence.
Optionally, the pseudo label-based metric learning apparatus may further include a teacher network update module, configured to determine network parameter values of the teacher network according to the network parameters of the student network, and to adjust the network parameters of the teacher network according to these values so as to update the teacher network in the iteration cycle.
Optionally, the pseudo label-based metric learning apparatus may further include a post-processing module, configured to use the current student network as the target label identification network if the number of iteration cycles satisfies a preset iteration cycle value, and to perform class information identification on the target data set according to the target label identification network to generate class information corresponding to the target data set.
The pseudo label-based metric learning apparatus provided by this embodiment of the invention can execute the pseudo label-based metric learning method provided by any embodiment of the invention, and has the functional modules and beneficial effects corresponding to that method.
EXAMPLE IV
FIG. 4 illustrates a block diagram of an electronic device 410 that may be used to implement an embodiment of the invention. Electronic devices are intended to represent various forms of digital computers, such as laptops, desktops, workstations, personal digital assistants, servers, blade servers, mainframes, and other appropriate computers. The electronic device may also represent various forms of mobile devices, such as personal digital assistants, cellular phones, smart phones, wearable devices (e.g., helmets, glasses, watches, etc.), and other similar computing devices. The components shown herein, their connections and relationships, and their functions, are meant to be exemplary only, and are not meant to limit implementations of the inventions described and/or claimed herein.
As shown in Fig. 4, the electronic device 410 includes at least one processor 420 and a memory communicatively connected to the at least one processor 420, such as a read-only memory (ROM) 430 and a random access memory (RAM) 440. The memory stores computer programs executable by the at least one processor, and the processor 420 may perform various suitable actions and processes according to a computer program stored in the ROM 430 or a computer program loaded from the storage unit 490 into the RAM 440. The RAM 440 may also store various programs and data required for the operation of the electronic device 410. The processor 420, the ROM 430 and the RAM 440 are connected to each other by a bus 450. An input/output (I/O) interface 460 is also connected to the bus 450.
Various components in the electronic device 410 are connected to the I/O interface 460, including: an input unit 470 such as a keyboard, a mouse, etc.; an output unit 480 such as various types of displays, speakers, and the like; a storage unit 490, such as a magnetic disk, optical disk, or the like; and a communication unit 4100 such as a network card, a modem, a wireless communication transceiver, and the like. The communication unit 4100 allows the electronic device 410 to exchange information/data with other devices through a computer network such as the internet and/or various telecommunication networks.
The processor 420 may perform the pseudo label-based metric learning method, which comprises the following steps:
respectively inputting the target data sets into a teacher network and a student network in an iteration period; the teacher network and the student network are used for carrying out class information identification on the target data set;
determining pseudo labels containing contextualized semantic similarity according to the first embedded vectors generated by the teacher network, and determining Euclidean distances among the second embedded vectors according to the second embedded vectors generated by the student network;
determining a loss function of the student network according to the pseudo label and the Euclidean distance;
and adjusting the network parameters of the student network according to the loss function.
In some embodiments, the pseudo label-based metric learning method may be implemented as a computer program tangibly embodied in a computer-readable storage medium, such as the storage unit 490. In some embodiments, part or all of the computer program may be loaded onto and/or installed on the electronic device 410 via the ROM 430 and/or the communication unit 4100. When the computer program is loaded into the RAM 440 and executed by the processor 420, one or more steps of the pseudo label-based metric learning method described above may be performed. Alternatively, in other embodiments, the processor 420 may be configured to perform the pseudo label-based metric learning method by any other suitable means (e.g., by means of firmware).
Various implementations of the systems and techniques described above may be realized in digital electronic circuitry, integrated circuitry, field programmable gate arrays (FPGAs), application-specific integrated circuits (ASICs), application-specific standard products (ASSPs), systems on a chip (SOCs), complex programmable logic devices (CPLDs), computer hardware, firmware, software, and/or combinations thereof. These various implementations may include being implemented in one or more computer programs that are executable and/or interpretable on a programmable system including at least one programmable processor, which may be special or general purpose and which receives data and instructions from, and transmits data and instructions to, a storage system, at least one input device, and at least one output device.
A computer program for implementing the methods of the present invention may be written in any combination of one or more programming languages. These computer programs may be provided to a processor of a general purpose computer, special purpose computer, or other programmable data processing apparatus, such that the computer programs, when executed by the processor, cause the functions/acts specified in the flowchart and/or block diagram block or blocks to be performed. A computer program can execute entirely on a machine, partly on the machine, as a stand-alone software package, partly on the machine and partly on a remote machine or entirely on the remote machine or server.
In the context of the present invention, a computer-readable storage medium may be a tangible medium that can contain, or store a computer program for use by or in connection with an instruction execution system, apparatus, or device. A computer readable storage medium may include, but is not limited to, an electronic, magnetic, optical, electromagnetic, infrared, or semiconductor system, apparatus, or device, or any suitable combination of the foregoing. Alternatively, the computer readable storage medium may be a machine readable signal medium. More specific examples of a machine-readable storage medium would include an electrical connection based on one or more wires, a portable computer diskette, a hard disk, a Random Access Memory (RAM), a read-only memory (ROM), an erasable programmable read-only memory (EPROM or flash memory), an optical fiber, a compact disc read-only memory (CD-ROM), an optical storage device, a magnetic storage device, or any suitable combination of the foregoing.
To provide for interaction with a user, the systems and techniques described here can be implemented on an electronic device having: a display device (e.g., a CRT (cathode ray tube) or LCD (liquid crystal display) monitor) for displaying information to a user; and a keyboard and a pointing device (e.g., a mouse or a trackball) by which a user may provide input to the electronic device. Other kinds of devices may also be used to provide for interaction with a user; for example, feedback provided to the user can be any form of sensory feedback (e.g., visual feedback, auditory feedback, or tactile feedback); and input from the user can be received in any form, including acoustic, speech, or tactile input.
The systems and techniques described here can be implemented in a computing system that includes a back-end component (e.g., as a data server), or that includes a middleware component (e.g., an application server), or that includes a front-end component (e.g., a user computer having a graphical user interface or a web browser through which a user can interact with an implementation of the systems and techniques described here), or any combination of such back-end, middleware, or front-end components. The components of the system can be interconnected by any form or medium of digital data communication (e.g., a communication network). Examples of communication networks include: local Area Networks (LANs), wide Area Networks (WANs), blockchain networks, and the internet.
The computing system may include clients and servers. A client and server are generally remote from each other and typically interact through a communication network. The relationship of client and server arises by virtue of computer programs running on the respective computers and having a client-server relationship to each other. The server may be a cloud server, also called a cloud computing server or cloud host, which is a host product in a cloud computing service system that overcomes the difficult management and weak service scalability of traditional physical hosts and VPS (virtual private server) services.
It should be understood that various forms of the flows shown above may be used, with steps reordered, added, or deleted. For example, the steps described in the present invention may be executed in parallel, sequentially, or in different orders, and are not limited herein as long as the desired result of the technical solution of the present invention can be achieved.
The above-described embodiments should not be construed as limiting the scope of the invention. It should be understood by those skilled in the art that various modifications, combinations, sub-combinations and substitutions may be made in accordance with design requirements and other factors. Any modification, equivalent replacement, and improvement made within the spirit and principle of the present invention should be included in the protection scope of the present invention.
Claims (10)
1. A metric learning method based on pseudo labels is characterized by comprising the following steps:
respectively inputting the target data sets into a teacher network and a student network in an iteration period; the teacher network and the student network are used for carrying out class information identification on the target data set;
determining pseudo labels containing contextualized semantic similarity according to the first embedded vectors generated by the teacher network, and determining Euclidean distances among the second embedded vectors according to the second embedded vectors generated by the student network;
determining a loss function of the student network according to the pseudo label and the Euclidean distance;
and adjusting the network parameters of the student network according to the loss function.
2. The method of claim 1, wherein determining pseudo-labels containing contextualized semantic similarities from each first embedded vector generated by the teacher network comprises:
calculating the pairing similarity between the first embedded vectors according to a formula, wherein the formula involves the first embedded vector generated by the teacher network based on the i-th data in the target data set, the first embedded vector generated by the teacher network based on the j-th data in the target data set, and σ, the Gaussian kernel bandwidth;
calculating the context similarity between the first embedded vectors according to the overlapping degree of the first embedded vectors;
and calculating a pseudo label containing contextualized semantic similarity according to the pairing similarity and the context similarity.
3. The method of claim 1, wherein the student network comprises an auxiliary layer and a calculation layer;
the determining the Euclidean distances between the second embedded vectors according to the second embedded vectors generated by the student network comprises:
determining a first Euclidean distance between vectors of each auxiliary layer according to the vectors of each auxiliary layer generated by the auxiliary layer in the student network;
and determining a second Euclidean distance between vectors of each calculation layer according to the vectors of each calculation layer generated by the calculation layer in the student network.
4. The method of claim 3, wherein determining the first Euclidean distance between the auxiliary layer vectors according to the auxiliary layer vectors generated by the auxiliary layers in the student network comprises:
determining a first Euclidean distance between the auxiliary layer vectors according to a formula, wherein the formula involves the auxiliary layer vector generated by the student network based on the i-th data in the target data set, the auxiliary layer vector generated by the student network based on the j-th data in the target data set, and N, the number of data in the target data set;
the determining a second Euclidean distance between the calculation layer vectors according to the calculation layer vectors generated by the calculation layer in the student network comprises:
determining a second Euclidean distance between the calculation layer vectors according to a formula, wherein the formula involves the calculation layer vector generated by the student network based on the i-th data in the target data set, the calculation layer vector generated by the student network based on the j-th data in the target data set, and N, the number of data in the target data set.
5. The method of claim 4, wherein determining a loss function of the student network according to the pseudo label and the Euclidean distance comprises:
determining a first relaxed contrast loss corresponding to the auxiliary layer according to the pseudo label and the first Euclidean distance;
determining a second relaxed contrast loss corresponding to the calculation layer according to the pseudo label and the second Euclidean distance;
determining the target divergence of the student network according to the first Euclidean distance and the second Euclidean distance;
and determining a loss function of the student network according to the first relaxed contrast loss, the second relaxed contrast loss and the target divergence.
6. The method of claim 1, further comprising:
determining a network parameter value of a teacher network according to the network parameter of the student network;
and adjusting the network parameters of the teacher network according to the network parameter values so as to update the teacher network in the iteration period.
7. The method of claim 1, further comprising:
if the number of iteration cycles satisfies a preset iteration cycle value, taking the current student network as a target label identification network;
and performing class information identification on the target data set according to the target label identification network to generate class information corresponding to the target data set.
8. A pseudo tag-based metric learning apparatus, comprising:
the data input module is used for inputting the target data set into the teacher network and the student network respectively in an iteration cycle; the teacher network and the student network are used for identifying class information of the target data set;
the first data processing module is used for determining pseudo labels containing contextualized semantic similarity according to each first embedded vector generated by the teacher network and determining Euclidean distances among each second embedded vector according to each second embedded vector generated by the student network;
the second data processing module is used for determining a loss function of the student network according to the pseudo label and the Euclidean distance;
and the student network adjusting module is used for adjusting the network parameters of the student network according to the loss function.
9. An electronic device, characterized in that the electronic device comprises:
at least one processor; and
a memory communicatively coupled to the at least one processor; wherein,
the memory stores a computer program executable by the at least one processor, the computer program being executed by the at least one processor to enable the at least one processor to perform the pseudo label-based metric learning method of any one of claims 1-7.
10. A computer-readable storage medium having stored thereon computer instructions for causing a processor, when executing them, to implement the pseudo label-based metric learning method of any one of claims 1-7.
Priority Applications (1)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN202210966909.7A CN115249010A (en) | 2022-08-11 | 2022-08-11 | Metric learning method, device, equipment and medium based on pseudo label |
Applications Claiming Priority (1)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN202210966909.7A CN115249010A (en) | 2022-08-11 | 2022-08-11 | Metric learning method, device, equipment and medium based on pseudo label |
Publications (1)
Publication Number | Publication Date |
---|---|
CN115249010A | 2022-10-28 |
Family ID: 83700632
Family Applications (1)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN202210966909.7A Pending CN115249010A (en) | 2022-08-11 | 2022-08-11 | Metric learning method, device, equipment and medium based on pseudo label |
Country Status (1)
Country | Link |
---|---|
CN (1) | CN115249010A (en) |
2022-08-11: CN application CN202210966909.7A (CN115249010A), active, pending
Legal Events
Code | Title |
---|---|
PB01 | Publication |
SE01 | Entry into force of request for substantive examination |