Custom InputFormat
PriorFileInputFormat.java
package org.example;

import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.fs.Path;
import org.apache.hadoop.io.IntWritable;
import org.apache.hadoop.io.Text;
import org.apache.hadoop.mapreduce.InputSplit;
import org.apache.hadoop.mapreduce.JobContext;
import org.apache.hadoop.mapreduce.RecordReader;
import org.apache.hadoop.mapreduce.TaskAttemptContext;
import org.apache.hadoop.mapreduce.lib.input.FileInputFormat;
import org.apache.hadoop.mapreduce.lib.input.FileSplit;

import java.io.IOException;

public class PriorFileInputFormat extends FileInputFormat<Text, IntWritable> {
    @Override
    protected boolean isSplitable(JobContext context, Path file) {
        // Never split a file, so every training document becomes exactly one split.
        return false;
    }

    @Override
    public RecordReader<Text, IntWritable> createRecordReader(InputSplit split, TaskAttemptContext context)
            throws IOException, InterruptedException {
        // Return a custom RecordReader here. The framework calls initialize()
        // itself before reading, so calling it here is only defensive.
        PriorRecordReader recordReader = new PriorRecordReader();
        recordReader.initialize(split, context);
        return recordReader;
    }
    static class PriorRecordReader extends RecordReader<Text, IntWritable> {
        private FileSplit fileSplit;
        private Configuration configuration;
        private final Text key = new Text();
        private final IntWritable value = new IntWritable();
        private boolean isRead = false;
        private String flag;

        @Override
        public void initialize(InputSplit split, TaskAttemptContext context) throws IOException, InterruptedException {
            this.fileSplit = (FileSplit) split;
            this.configuration = context.getConfiguration();
            // The class label is the name of the directory containing the file.
            this.flag = this.fileSplit.getPath().getParent().getName();
        }

        @Override
        public boolean nextKeyValue() throws IOException, InterruptedException {
            // Emit exactly one (label, 1) pair per file, then report end of input.
            if (!isRead) {
                this.key.set(flag);
                this.value.set(1);
                isRead = true;
                return true;
            }
            return false;
        }

        @Override
        public Text getCurrentKey() throws IOException, InterruptedException {
            return key;
        }

        @Override
        public IntWritable getCurrentValue() throws IOException, InterruptedException {
            return value;
        }

        @Override
        public float getProgress() throws IOException, InterruptedException {
            return isRead ? 1 : 0;
        }

        @Override
        public void close() throws IOException {
        }
    }
}
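The label is derived from fileSplit.getPath().getParent().getName(), so the expected input layout is one sub-directory per class under the input directory, with that class's training documents inside. Since each file yields a single (label, 1) record, the map stage sees one input pair per training document. A minimal illustration of how the label falls out of the path (the directory and file names here are made up):

import org.apache.hadoop.fs.Path;

public class LabelDemo {
    public static void main(String[] args) {
        // Hypothetical layout: train/<class-label>/<document>
        Path doc = new Path("train/sports/doc1.txt");
        // This is exactly what PriorRecordReader.initialize() does to get the label.
        System.out.println(doc.getParent().getName()); // prints "sports"
    }
}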
PriorDriver.java
package org.example;

import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.fs.Path;
import org.apache.hadoop.io.IntWritable;
import org.apache.hadoop.io.Text;
import org.apache.hadoop.mapreduce.Job;
import org.apache.hadoop.mapreduce.Mapper;
import org.apache.hadoop.mapreduce.Reducer;
import org.apache.hadoop.mapreduce.lib.input.FileInputFormat;
import org.apache.hadoop.mapreduce.lib.output.FileOutputFormat;

import java.io.IOException;

public class PriorDriver {
    static class PriorMapper extends Mapper<Text, IntWritable, Text, IntWritable> {
        Text keyOut = new Text();
        IntWritable valueOut = new IntWritable(1);

        @Override
        protected void map(Text key, IntWritable value, Context context) throws IOException, InterruptedException {
            // This mapper could simply forward (key, value) unchanged; the copy is
            // kept only to spell out the usual processing pattern.
            keyOut.set(key.toString());
            valueOut.set(value.get());
            context.write(keyOut, valueOut);
        }
    }
    static class PriorReduce extends Reducer<Text, IntWritable, Text, IntWritable> {
        Text keyOut = new Text();
        IntWritable valueOut = new IntWritable();

        @Override
        protected void reduce(Text key, Iterable<IntWritable> values, Context context) throws IOException, InterruptedException {
            int sum = 0;
            for (IntWritable value : values) {
                // Unlike an ordinary Java iterator, Hadoop reuses a single IntWritable
                // here and only overwrites its contents on each step; to keep all the
                // values in a collection you would have to copy each one into a new
                // object inside the loop.
                sum += value.get();
            }
            // Write out (label, number of documents in that class).
            keyOut.set(key.toString());
            valueOut.set(sum);
            context.write(keyOut, valueOut);
        }
    }
    public static void main(String[] args) throws IOException, InterruptedException, ClassNotFoundException {
        Configuration configuration = new Configuration();
        Job job = Job.getInstance(configuration);
        job.setJarByClass(PriorDriver.class);
        // Hard-coded paths for convenience; a bad habit in real jobs.
        String inputFilePath = "G:/HW/hadoop/bayes/src/main/resources/train";
        String outputFilePath = "G:/HW/hadoop/bayes/src/main/resources/output";
        // Plug in the custom InputFormat.
        job.setInputFormatClass(PriorFileInputFormat.class);
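        // Because PriorFileInputFormat is declared as FileInputFormat<Text, IntWritable>,
        // the Mapper's input key/value types above must be Text/IntWritable to match.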
        // Recurse into the per-class sub-directories under the input path.
        FileInputFormat.setInputDirRecursive(job, true);
        // The input path for the whole job.
        FileInputFormat.setInputPaths(job, new Path(inputFilePath));
        job.setMapperClass(PriorMapper.class);
        job.setReducerClass(PriorReduce.class);
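        // Optional, not in the original code: the reduce is a pure sum, so the
        // same class would also be safe to use as a combiner to shrink the shuffle:
        // job.setCombinerClass(PriorReduce.class);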
        job.setMapOutputKeyClass(Text.class);
        job.setMapOutputValueClass(IntWritable.class);
        job.setOutputKeyClass(Text.class);
        job.setOutputValueClass(IntWritable.class);
        FileOutputFormat.setOutputPath(job, new Path(outputFilePath));
        // Submit the job and wait for it to finish.
        job.waitForCompletion(true);
    }
}
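The job's output is one tab-separated line per class, label followed by its document count, so the Naive Bayes prior is simply P(c) = count(c) / totalDocs. A follow-up step to read those counts back and normalize them might look like the sketch below. This is a minimal illustration, not part of the original project: the class name PriorReader is made up, and it assumes only the default TextOutputFormat tab separator and the standard part-* output file naming.

import java.io.BufferedReader;
import java.io.IOException;
import java.io.InputStreamReader;
import java.util.HashMap;
import java.util.Map;

import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.fs.FileStatus;
import org.apache.hadoop.fs.FileSystem;
import org.apache.hadoop.fs.Path;

public class PriorReader {
    // Reads the <label>\t<count> lines produced by PriorDriver and converts the
    // per-class document counts into prior probabilities P(c) = count(c) / totalDocs.
    public static Map<String, Double> readPriors(String outputDir) throws IOException {
        Configuration conf = new Configuration();
        FileSystem fs = FileSystem.get(conf);
        Map<String, Long> counts = new HashMap<>();
        long total = 0;
        for (FileStatus status : fs.listStatus(new Path(outputDir))) {
            if (!status.getPath().getName().startsWith("part-")) {
                continue; // skip _SUCCESS and other non-data files
            }
            try (BufferedReader reader = new BufferedReader(
                    new InputStreamReader(fs.open(status.getPath())))) {
                String line;
                while ((line = reader.readLine()) != null) {
                    // TextOutputFormat separates key and value with a tab by default.
                    String[] fields = line.split("\t");
                    long count = Long.parseLong(fields[1]);
                    counts.put(fields[0], count);
                    total += count;
                }
            }
        }
        Map<String, Double> priors = new HashMap<>();
        for (Map.Entry<String, Long> e : counts.entrySet()) {
            priors.put(e.getKey(), (double) e.getValue() / total);
        }
        return priors;
    }
}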
